"""Defines a class storing data in various storages.
Requires the optional :mod:`zarr` module.
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import base64
import json
import pickle
import shutil
from collections.abc import Collection, Sequence
from pathlib import Path
from typing import Any, Union
import numpy as np
import zarr
from numpy.typing import ArrayLike, DTypeLike
# This module is written against the zarr v3 API. Check the major version with
# an explicit exception instead of `assert`, which is stripped under `python -O`
# and would let an incompatible zarr version slip through silently.
zarr_version = int(zarr.__version__.split(".", 1)[0])
if zarr_version != 3:
    raise ImportError(
        f"ZarrStorage requires zarr version 3, but found version {zarr.__version__}"
    )
from zarr.abc.store import Store
from ..access_modes import ModeType
from ..attributes import AttrsLike
from ..base import StorageBase
# type alias covering both kinds of zarr objects: groups and arrays
zarrElement = zarr.Group | zarr.Array
class ZarrStorage(StorageBase):
    """Storage that stores data in a zarr file or database.

    The data is kept in a :mod:`zarr` hierarchy, which lives either in a directory
    on the file system, in a zip archive, or in a :class:`~zarr.abc.store.Store`
    supplied by the caller.
    """

    # NOTE(review): only paths without suffix, `.zarr`, and `.zip` are handled in
    # `__init__`; paths with other suffixes (including "sqldb" and "lmdb") raise a
    # `ValueError`. The list is kept unchanged since other code might rely on it.
    extensions = ["zarr", "zip", "sqldb", "lmdb"]

    # prefix marking stored strings that contain base64-encoded pickle data
    _PICKLE_MARKER = "__pickle__:"

    def __init__(self, store_or_path: str | Path | Store, *, mode: ModeType = "read"):
        """
        Args:
            store_or_path (str or :class:`~pathlib.Path` or :class:`~zarr.abc.store.Store`):
                File path to the file/folder or a :mod:`zarr` Store
            mode (str or :class:`~modelrunner.storage.access_modes.AccessMode`):
                The file mode with which the storage is accessed. Determines allowed
                operations.

        Raises:
            FileExistsError: If the path exists, but the mode demands a new file
            ValueError: If the suffix of the path does not match a supported store
            TypeError: If `store_or_path` is neither a path nor a zarr store
        """
        super().__init__(mode=mode)

        if isinstance(store_or_path, (str, Path)):
            # open zarr storage on file system; we own the store and must close it
            self._close = True
            path = Path(store_or_path)
            file_mode = self.mode.file_mode

            if file_mode == "x" and path.exists():
                raise FileExistsError(f"File `{path}` already exists")

            if path.suffix in {"", ".zarr"}:
                # path seems to be a directory or a zarr directory => LocalStore
                if path.is_dir() and file_mode == "w":
                    self._logger.info("Delete directory `%s`", path)
                    shutil.rmtree(path)  # remove the directory to reinstate it
                if file_mode == "r":
                    self._logger.info("DirectoryStore is always opened writable")
                self._store = zarr.storage.LocalStore(path)

            elif path.suffix == ".zip":
                # create a ZipStore
                if file_mode == "x":
                    # the file cannot exist at this point (checked above), so
                    # exclusive creation is equivalent to a plain write
                    file_mode = "w"
                elif file_mode == "w" and path.exists():
                    self._logger.info("Delete file `%s`", path)
                    path.unlink()
                self._store = zarr.storage.ZipStore(path, mode=file_mode)

            else:
                # fail early with a clear message instead of an AttributeError
                # caused by the missing `self._store` further below
                raise ValueError(
                    f"Unsupported file extension `{path.suffix}` of path `{path}`"
                )

        elif isinstance(store_or_path, Store):
            # use already opened zarr storage, which the caller needs to close
            self._close = False
            self._store = store_or_path

        else:
            raise TypeError(f"Unknown store `{store_or_path}`")

        # everything but pure reading is mapped to zarr's append mode
        zarr_mode = "r" if self.mode.file_mode == "r" else "a"
        self._root = zarr.open_group(store=self._store, mode=zarr_mode)

    @property
    def can_update(self) -> bool:
        """Bool: indicates whether the storage supports updating items."""
        return not isinstance(self._store, zarr.storage.ZipStore)

    def __repr__(self) -> str:
        # `close` sets the root group to None, so fall back to the store itself
        store = self._store if self._root is None else self._root.store
        return f'{self.__class__.__name__}({store}, mode="{self.mode.name}")'

    def close(self) -> None:
        """Close the storage and the underlying store if it was opened here."""
        if self._close:
            self._store.close()
        self._root = None
        super().close()

    def _get_parent(
        self, loc: Sequence[str], *, create: bool = True
    ) -> tuple[zarr.Group, str]:
        """Get the parent group for a particular location.

        Args:
            loc (list of str):
                The location in the storage where the group will be created
            create (bool):
                Whether missing intermediate groups are created on the fly. If
                `False`, a missing intermediate group raises a :class:`KeyError`.

        Returns:
            (group, str):
                A tuple consisting of the parent group and the name of the current item

        Raises:
            KeyError: If `loc` is empty or if a group is missing and `create=False`
        """
        try:
            path, name = loc[:-1], loc[-1]
        except IndexError as err:
            raise KeyError(f"Location `/{'/'.join(loc)}` has no parent") from err

        parent = self._root
        for part in path:
            try:
                parent = parent[part]
            except KeyError:
                if not create:
                    raise
                parent = parent.create_group(part, overwrite=False)
        return parent, name

    def __getitem__(self, loc: Sequence[str]) -> Any:
        """Return the zarr object at `loc`; an empty location yields the root."""
        if len(loc) == 0:
            return self._root
        # pure look-up: do not create intermediate groups as a side effect, so
        # missing locations consistently raise KeyError (also in read-only mode)
        parent, name = self._get_parent(loc, create=False)
        return parent[name]

    def keys(self, loc: Sequence[str] | None = None) -> Collection[str]:
        """Return the names of all items stored in the group at `loc`.

        Args:
            loc (list of str, optional):
                Location of the group. The root group is used if omitted or empty.
        """
        group = self[loc] if loc else self._root
        # materialize the keys, since zarr may return a lazy iterator that can
        # only be consumed once, whereas callers expect a proper collection
        return list(group.keys())

    def is_group(self, loc: Sequence[str], *, ignore_cls: bool = False) -> bool:
        """Check whether the item at `loc` is a zarr group.

        Args:
            loc (list of str): Location of the item
            ignore_cls (bool): Accepted as part of the signature; not used here
        """
        return isinstance(self[loc], zarr.Group)

    def _create_group(self, loc: Sequence[str]) -> None:
        """Create a new group at `loc`, including missing parent groups."""
        parent, name = self._get_parent(loc)
        parent.create_group(name)

    def _read_attrs(self, loc: Sequence[str]) -> AttrsLike:
        """Return the attributes attached to the item at `loc`."""
        return self[loc].attrs  # type: ignore

    def _write_attr(self, loc: Sequence[str], name: str, value: str) -> None:
        """Set the attribute `name` of the item at `loc` to `value`."""
        self[loc].attrs[name] = value

    def _read_array(
        self, loc: Sequence[str], *, copy: bool, index: int | None = None
    ) -> np.ndarray:
        """Read an array (or one entry of it) from the storage.

        Args:
            loc (list of str): Location of the array
            copy (bool): Whether an additional copy of the loaded data is made
            index (int, optional): If given, only `array[index]` is returned

        Returns:
            :class:`~numpy.ndarray`: The data, as a record array if it was marked
            accordingly when it was written

        Raises:
            RuntimeError: If the item at `loc` is not a zarr array
        """
        arr_like = self[loc]
        if not isinstance(arr_like, zarr.Array):
            raise RuntimeError(
                f"Found {arr_like.__class__} at location `/{'/'.join(loc)}`"
            )
        is_recarray = arr_like.attrs.get("__recarray__", False)

        # Load the data into memory first. Passing the zarr array directly to
        # `np.array(..., copy=False)` fails, since the conversion requires a copy.
        data = arr_like[...] if index is None else arr_like[index]

        # convert it into the right type
        arr = np.array(data, copy=True) if copy else np.asarray(data)
        if is_recarray:
            arr = arr.view(np.recarray)
        return arr

    def _write_array(self, loc: Sequence[str], arr: np.ndarray) -> None:
        """Write `arr` to `loc`, either updating an existing array or creating one.

        Record arrays are marked with the attribute `__recarray__` when they are
        created, so `_read_array` can restore their type.
        """
        parent, name = self._get_parent(loc)

        if name in parent:
            # update an existing array assuming it has the same shape. The credentials
            # for this operation need to be checked by the caller!
            parent[name][...] = arr
        else:
            # create a new array element
            element = parent.create_array(name, data=arr, overwrite=True)
            if isinstance(arr, np.recarray):
                element.attrs["__recarray__"] = True

    def _create_dynamic_array(
        self,
        loc: Sequence[str],
        shape: tuple[int, ...],
        dtype: DTypeLike,
        record_array: bool = False,
    ) -> None:
        """Create an empty array that can be extended along its first axis.

        Args:
            loc (list of str): Location of the new array
            shape (tuple of int): Shape of a single entry of the array
            dtype: Data type of the entries
            record_array (bool): Whether the data should be read as a record array

        Raises:
            RuntimeError: If an array already exists at `loc`
        """
        parent, name = self._get_parent(loc)
        item_shape = tuple(shape)  # also accept other sequences, e.g., lists
        try:
            # start with zero entries and store each entry in its own chunk
            element = parent.zeros(
                name=name,
                shape=(0, *item_shape),
                chunks=(1, *item_shape),
                dtype=dtype,
            )
        except zarr.errors.ContainsArrayError as err:
            raise RuntimeError(f"Array `/{'/'.join(loc)}` already exists") from err
        else:
            if record_array:
                element.attrs["__recarray__"] = True

    def _extend_dynamic_array(self, loc: Sequence[str], data: ArrayLike) -> None:
        """Append `data` as one additional entry to the dynamic array at `loc`."""
        arr_obj = self[loc]
        arr_obj.append([data])  # wrap data to add the leading axis

    def _read_object(self, loc: Sequence[str]) -> Any:
        """Read an object that was stored by `_write_object`.

        Warning:
            Data carrying the pickle marker is unpickled, which can execute
            arbitrary code. Only read storages from trusted sources.
        """
        stored = self[loc][0].item()
        marker = self._PICKLE_MARKER

        # Check if the object was pickled (starts with pickle marker)
        if stored.startswith(marker):
            obj_data = stored[len(marker) :]  # Remove the marker
            # SECURITY: unpickling untrusted data is unsafe (see docstring)
            return pickle.loads(base64.b64decode(obj_data))
        # Otherwise it's JSON
        return json.loads(stored)

    def _write_object(self, loc: Sequence[str], obj: Any) -> None:
        """Store an arbitrary object as a string in a one-element array.

        The object is serialized as JSON if possible. Otherwise it is pickled,
        base64-encoded, and prefixed with a marker so `_read_object` can tell the
        two formats apart.
        """
        # Try JSON serialization first
        try:
            obj_enc = json.dumps(obj)
        except (TypeError, ValueError):
            # Fall back to pickle for non-JSON-serializable objects
            pickled = pickle.dumps(obj)
            obj_enc = self._PICKLE_MARKER + base64.b64encode(pickled).decode("ascii")

        data = np.array([obj_enc])
        parent, name = self._get_parent(loc)
        parent.create_array(name, data=data, overwrite=True)