"""
Defines a class storing data in various storages
Requires the optional :mod:`zarr` module.
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import shutil
from pathlib import Path
from typing import Any, Collection, Sequence, Union
import numpy as np
import zarr
from numpy.typing import ArrayLike, DTypeLike
from zarr._storage.store import Store
from ..access_modes import ModeType
from ..attributes import AttrsLike
from ..base import StorageBase
zarrElement = Union[zarr.Group, zarr.Array]
[docs]class ZarrStorage(StorageBase):
"""storage that stores data in an zarr file or database"""
extensions = ["zarr", "zip", "sqldb", "lmdb"]
def __init__(self, store_or_path: str | Path | Store, *, mode: ModeType = "read"):
"""
Args:
store_or_path (str or :class:`~pathlib.Path` or :class:`~zarr._storage.store.Store`):
File path to the file/folder or a :mod:`zarr` Store
mode (str or :class:`~modelrunner.storage.access_modes.AccessMode`):
The file mode with which the storage is accessed. Determines allowed
operations.
"""
super().__init__(mode=mode)
if isinstance(store_or_path, (str, Path)):
# open zarr storage on file system
self._close = True
path = Path(store_or_path)
if self.mode.file_mode == "x" and path.exists():
raise FileExistsError(f"File `{path}` already exists")
if path.suffix in {"", ".zarr"}:
# path seems to be a directory or a zarr direction => DirectoryStore
if path.is_dir() and self.mode.file_mode == "w":
self._logger.info(f"Delete directory `{path}`")
shutil.rmtree(path) # remove the directory to reinstate it
if self.mode.file_mode == "r":
self._logger.info(f"DirectoryStore is always opened writable")
self._store = zarr.DirectoryStore(path)
elif path.suffix == ".zip":
# create a ZipStore
file_mode = self.mode.file_mode
if path.exists():
if file_mode == "w":
self._logger.info(f"Delete file `{path}`")
path.unlink()
self._store = zarr.storage.ZipStore(path, mode=file_mode)
elif path.suffix == ".sqldb":
# create a SQLiteStore
if self.mode.file_mode == "w" and path.exists():
self._logger.info(f"Delete file `{path}`")
path.unlink()
self._store = zarr.SQLiteStore(path)
elif isinstance(store_or_path, Store):
# use already opened zarr storage
self._close = False
self._store = store_or_path
else:
raise TypeError(f"Unknown store `{store_or_path}`")
self._root = zarr.group(store=self._store, overwrite=False)
@property
def can_update(self) -> bool:
"""bool: indicates whether the storage supports updating items"""
return not isinstance(self._store, zarr.storage.ZipStore)
def __repr__(self):
return f'{self.__class__.__name__}({self._root.store}, mode="{self.mode.name}")'
[docs] def close(self) -> None:
if self._close:
self._store.close()
self._root = None
super().close()
def _get_parent(self, loc: Sequence[str]) -> tuple[zarr.Group, str]:
"""get the parent group for a particular location
Args:
loc (list of str):
The location in the storage where the group will be created
Returns:
(group, str):
A tuple consisting of the parent group and the name of the current item
"""
try:
path, name = loc[:-1], loc[-1]
except IndexError:
raise KeyError(f"Location `/{'/'.join(loc)}` has no parent")
parent = self._root
for part in path:
try:
parent = parent[part]
except KeyError:
parent = parent.create_group(part, overwrite=False)
return parent, name
def __getitem__(self, loc: Sequence[str]) -> Any:
if len(loc) == 0:
return self._root
else:
parent, name = self._get_parent(loc)
return parent[name]
[docs] def keys(self, loc: Sequence[str] | None = None) -> Collection[str]:
if loc:
return self[loc].keys() # type: ignore
else:
return self._root.keys() # type: ignore
[docs] def is_group(self, loc: Sequence[str], *, ignore_cls: bool = False) -> bool:
return isinstance(self[loc], zarr.hierarchy.Group)
def _create_group(self, loc: Sequence[str]) -> None:
parent, name = self._get_parent(loc)
parent.create_group(name)
def _read_attrs(self, loc: Sequence[str]) -> AttrsLike:
return self[loc].attrs # type: ignore
def _write_attr(self, loc: Sequence[str], name: str, value: str) -> None:
self[loc].attrs[name] = value
def _read_array(
self, loc: Sequence[str], *, copy: bool, index: int | None = None
) -> np.ndarray:
arr_like = self[loc]
if not isinstance(arr_like, zarr.Array):
raise RuntimeError(
f"Found {arr_like.__class__} at location `/{'/'.join(loc)}`"
)
is_recarray = arr_like.attrs.get("__recarray__", False)
if index is not None:
arr_like = arr_like[index]
# convert it into the right type
arr = np.array(arr_like, copy=copy)
if is_recarray:
arr = arr.view(np.recarray)
return arr
def _write_array(self, loc: Sequence[str], arr: np.ndarray) -> None:
parent, name = self._get_parent(loc)
if name in parent:
# update an existing array assuming it has the same shape. The credentials
# for this operation need to be checked by the caller!
parent[name][...] = arr
else:
# create a new array element
if arr.dtype == object:
el = parent.array(name, arr, object_codec=self.codec, overwrite=True)
else:
el = parent.array(name, arr, overwrite=True)
if isinstance(arr, np.recarray):
el.attrs["__recarray__"] = True
def _create_dynamic_array(
self,
loc: Sequence[str],
shape: tuple[int, ...],
dtype: DTypeLike,
record_array: bool = False,
) -> None:
parent, name = self._get_parent(loc)
try:
if dtype == object:
element = parent.zeros(
name,
shape=(0,) + shape,
chunks=(1,) + shape,
dtype=dtype,
object_codec=self.codec,
overwrite=True,
)
else:
element = parent.zeros(
name, shape=(0,) + shape, chunks=(1,) + shape, dtype=dtype
)
except zarr.errors.ContainsArrayError:
raise RuntimeError(f"Array `/{'/'.join(loc)}` already exists")
else:
if record_array:
element.attrs["__recarray__"] = True
def _extend_dynamic_array(self, loc: Sequence[str], data: ArrayLike) -> None:
arr_obj = self[loc]
arr_obj.append([data])
def _read_object(self, loc: Sequence[str]) -> Any:
return self[loc][0]
def _write_object(self, loc: Sequence[str], obj: Any) -> None:
arr: np.ndarray = np.empty(1, dtype=object) # encode object in an array
arr[0] = obj
parent, name = self._get_parent(loc)
parent.array(name, arr, object_codec=self.codec, overwrite=True)