# Source code for modelrunner.storage.base

"""
Base classes for managing hierarchical storage in which data is stored.

The storage classes provide low-level abstraction to store data in a hierarchical format
and should thus not be used directly. Instead, the user typically interacts with
:class:`~modelrunner.storage.group.StorageGroup` objects, e.g., those returned by
:func:`~modelrunner.storage.tools.open_storage`.

The role of `StorageBase` is to ensure access rights and provide an interface that can
be specified easily by subclasses to provide new storage formats. In contrast, the
interface of `StorageGroup` is more user-friendly and provides additional convenience
methods.

The main structure of the storage is a hierarchical tree of *groups*, which can contain
other groups or specific data items. Currently, items can be either arrays or arbitrary
objects, which are serialized transparently. Moreover, each group and each item can have
attributes, i.e., a mapping with string keys and arbitrary values that are also
serialized transparently. Note that keys starting with double underscores are reserved
for internal use and should thus not be used.

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de> 
"""

from __future__ import annotations

import logging
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Any, Collection, Literal, Sequence

import numcodecs
import numpy as np
from numpy.typing import ArrayLike, DTypeLike

from .access_modes import AccessError, AccessMode, ModeType, _access_closed
from .attributes import Attrs, AttrsLike, decode_attrs, encode_attr
from .utils import encode_class

if TYPE_CHECKING:
    from .group import StorageGroup


class StorageBase(metaclass=ABCMeta):
    """base class for storing data

    Subclasses implement the abstract low-level methods (:meth:`keys`,
    :meth:`is_group`, :meth:`_create_group`, :meth:`_read_attrs`, and
    :meth:`_write_attr`) and may override the item-specific hooks (`_read_array`,
    `_write_array`, `_create_dynamic_array`, `_extend_dynamic_array`, `_read_object`,
    `_write_object`), which raise :class:`NotImplementedError` by default. The public
    methods of this class check the rights granted by :attr:`mode` before delegating
    to these hooks.
    """

    extensions: list[str] = []
    """list of str: all file extensions supported by this storage"""
    default_codec = numcodecs.Pickle()
    """:class:`numcodecs.Codec`: the default codec used for encoding binary data"""

    mode: AccessMode
    """:class:`~modelrunner.storage.access_modes.AccessMode`: access mode"""
    _codec: numcodecs.abc.Codec
    """:class:`numcodecs.Codec`: the specific codec used for encoding binary data"""

    def __init__(self, *, mode: ModeType = "read"):
        """
        Args:
            mode (str or :class:`~modelrunner.storage.access_modes.AccessMode`):
                The file mode with which the storage is accessed. Determines allowed
                operations.
        """
        self.mode = AccessMode.parse(mode)
        self._logger = logging.getLogger(self.__class__.__name__)

    def close(self) -> None:
        """closes the storage, potentially writing data to a persistent place

        The base implementation only replaces the access mode by the special
        `_access_closed` mode, which is what :attr:`closed` checks for.
        """
        self.mode = _access_closed

    @property
    def closed(self) -> bool:
        """bool: determines whether the storage has been closed"""
        return self.mode is _access_closed

    @property
    def can_update(self) -> bool:
        """bool: indicates whether the storage supports updating items"""
        # we are using a property instead of an attribute to make this read-only
        return True

    def flush(self) -> None:
        """write (cached) data to storage

        The base implementation does nothing; subclasses may override it.
        """

    @property
    def codec(self) -> numcodecs.abc.Codec:
        """:class:`~numcodecs.abc.Codec`: A codec used to encode binary data

        The codec is determined lazily on first access. A codec configuration stored
        in the internal `__codec__` attribute of the storage root takes precedence.
        Otherwise, :attr:`default_codec` is used and, if the access mode allows
        setting attributes, its configuration is stored in the root so that later
        accesses (which read `__codec__` first) use the same codec.
        """
        try:
            return self._codec
        except AttributeError:
            pass  # the codec has not been determined yet

        root_attrs = self._read_attrs([])
        if "__codec__" in root_attrs:
            self._codec = numcodecs.get_codec(root_attrs["__codec__"])
        else:
            self._codec = self.default_codec
            if self.mode.set_attrs:
                self._write_attr([], "__codec__", self._codec.get_config())
        return self._codec

    @abstractmethod
    def keys(self, loc: Sequence[str]) -> Collection[str]:
        """return all sub-items defined at a given location

        Args:
            loc (sequence of str):
                A list of strings determining the location in the storage

        Returns:
            list: a list of all items defined at this location
        """

    def __contains__(self, loc: Sequence[str]) -> bool:
        """determine whether anything is stored at a given location

        Args:
            loc (sequence of str):
                A list of strings determining the location in the storage

        Returns:
            bool: `True` if `loc` is the root or is listed by :meth:`keys` of its parent
        """
        if not loc:
            return True  # the root is always contained in the storage
        try:
            return loc[-1] in self.keys(loc[:-1])
        except KeyError:
            # presumably raised by `keys` when the parent location is missing
            return False

    @abstractmethod
    def is_group(self, loc: Sequence[str]) -> bool:
        """determine whether the location is a group

        Args:
            loc (sequence of str):
                A list of strings determining the location in the storage

        Returns:
            bool: `True` if the location is a group
        """

    @abstractmethod
    def _create_group(self, loc: Sequence[str]) -> None:
        """create a group at a particular location

        Args:
            loc (sequence of str):
                A list of strings determining the location in the storage
        """

    def create_group(
        self,
        loc: Sequence[str],
        *,
        attrs: Attrs | None = None,
        cls: type | None = None,
    ) -> StorageGroup:
        """create a new group at a particular location

        Missing parent groups are created as well.

        Args:
            loc (list of str):
                The location in the storage where the group will be created
            attrs (dict, optional):
                Attributes stored with the group
            cls (type):
                A class associated with this group. The class will be used to
                re-create the object when this group is later accessed directly.

        Returns:
            :class:`StorageGroup`: The reference of the new group

        Raises:
            AccessError: if the group exists and the mode forbids overwriting, or if
                it does not exist and the mode forbids insertion
        """
        from .group import StorageGroup  # @Reimport to avoid circular import

        if loc in self:
            # the group already exists; continue only if we may overwrite things
            if not self.mode.overwrite:
                raise AccessError(f"Group `/{'/'.join(loc)}` already exists")
        else:
            # group needs to be created
            if not self.mode.insert:
                raise AccessError(f"No right to insert group `/{'/'.join(loc)}`")
            # create all missing parent groups and finally the group itself
            for depth in range(1, len(loc) + 1):
                if loc[:depth] not in self:
                    self._create_group(loc[:depth])

        self._write_item_attrs(loc, attrs, cls=cls)
        return StorageGroup(self, loc)

    def ensure_group(self, loc: Sequence[str]) -> None:
        """ensures that a group exists in the storage

        If the group is not already in the storage, it is created (recursively).

        Args:
            loc (list of str):
                The group location in the storage

        Raises:
            AccessError: if the group is missing and the mode forbids insertion
        """
        if loc not in self:
            # check whether we can insert a group
            if not self.mode.insert:
                raise AccessError(f"No right to insert group `/{'/'.join(loc)}`")
            # create group
            self.create_group(loc)

    @abstractmethod
    def _read_attrs(self, loc: Sequence[str]) -> AttrsLike:
        """read attributes at a particular location

        Args:
            loc (sequence of str):
                A list of strings determining the location in the storage
        """

    def read_attrs(self, loc: Sequence[str]) -> Attrs:
        """read attributes associated with a particular location

        Internal attributes (whose names start with a double underscore) are omitted.

        Args:
            loc (list of str):
                The location in the storage where the attributes are read

        Returns:
            dict: A copy of the attributes at this location

        Raises:
            AccessError: if the mode forbids reading
        """
        if not self.mode.read:
            raise AccessError("No right to read attributes")
        attrs = {
            k: v for k, v in self._read_attrs(loc).items() if not k.startswith("__")
        }
        return decode_attrs(attrs)

    @abstractmethod
    def _write_attr(self, loc: Sequence[str], name: str, value: str) -> None:
        """write a single attribute to a particular location

        Args:
            loc (list of str):
                The location in the storage where the attributes are written
            name (str):
                Name of the attribute
            value (str):
                Value of the attribute
        """

    def write_attrs(self, loc: Sequence[str], attrs: Attrs | None) -> None:
        """write attributes to a particular location

        Args:
            loc (list of str):
                The location in the storage where the attributes are written
            attrs (dict):
                The attributes to be added to this location. Internal attributes
                (names starting with a double underscore) are written verbatim, all
                others are encoded with :func:`encode_attr` first.

        Raises:
            AccessError: if the mode forbids setting attributes
        """
        # check whether we can insert anything
        if not self.mode.set_attrs:
            raise AccessError(f"No right to set attributes of `/{'/'.join(loc)}`")
        # check whether there are actually any attributes to be written
        if not attrs:
            return

        for name, value in attrs.items():
            if name.startswith("__"):
                # do not encode internal attributes
                self._write_attr(loc, name, value)
            else:
                # serialize and encode all foreign attributes
                self._write_attr(loc, name, encode_attr(value))

    def _write_item_attrs(
        self,
        loc: Sequence[str],
        attrs: Attrs | None,
        *,
        item_type: Literal["array", "dynamic_array", "object"] | None = None,
        cls: type | None = None,
    ) -> None:
        """write attributes to a particular location

        Args:
            loc (list of str):
                The location in the storage where the attributes are written
            attrs (dict):
                The attributes to be added to this location. The dictionary is not
                modified.
            item_type (str):
                Information about the type of the item
            cls (type):
                Class information that needs to be stored alongside
        """
        # work on a copy so the internal keys added below do not leak into the
        # caller's dictionary (which might be reused for writing other items)
        attrs = {} if attrs is None else dict(attrs)
        if item_type is not None:
            attrs.setdefault("__type__", str(item_type))
        if cls is not None:
            attrs.setdefault("__class__", encode_class(cls))
        self.write_attrs(loc, attrs)

    def _check_write_access(self, loc: Sequence[str], *, name: str = "item") -> None:
        """check whether we can safely write to a location

        If the location does not exist yet, its parent group is created.

        Args:
            loc (list of str):
                The location in the storage where the item will be written
            name (str):
                A name of the item appearing in error messages

        Raises:
            RuntimeError: if `loc` is the root or the storage cannot update items
            AccessError: if the mode forbids overwriting or inserting the item
        """
        if not loc:
            raise RuntimeError(f"Cannot write {name} to the storage root")
        elif loc in self:
            # check whether we can overwrite the existing array
            if not self.can_update:
                raise RuntimeError("Storage does not support updating items")
            if not self.mode.overwrite:
                raise AccessError(f"{name} `/{'/'.join(loc)}` already exists in {self}")
        else:
            # check whether we can insert a new array
            if not self.mode.insert:
                raise AccessError(f"No right to insert {name} at `/{'/'.join(loc)}`")
            # make sure the parent group exists
            self.ensure_group(loc[:-1])

    def _read_array(
        self,
        loc: Sequence[str],
        *,
        copy: bool,
        index: int | None = None,
    ) -> np.ndarray:
        """read an array from a particular location

        Args:
            loc (list of str):
                The location in the storage where the array is read
            copy (bool):
                Determines whether a copy of the data is returned. Set this flag to
                `False` for better performance in cases where the array is not
                modified.
            index (int, optional):
                An index denoting the subarray that will be read

        Returns:
            :class:`~numpy.ndarray`: An array containing the data.
        """
        raise NotImplementedError(f"Cannot read arrays from {self.__class__.__name__}")

    def read_array(
        self,
        loc: Sequence[str],
        *,
        out: np.ndarray | None = None,
        index: int | None = None,
    ) -> np.ndarray:
        """read an array from a particular location

        Args:
            loc (list of str):
                The location in the storage where the array is read
            out (array):
                An array to which the results are written
            index (int, optional):
                An index denoting the subarray that will be read

        Returns:
            :class:`~numpy.ndarray`:
                An array containing the data. Identical to `out` if specified.

        Raises:
            AccessError: if the mode forbids reading
        """
        if not self.mode.read:
            raise AccessError("No right to read array")

        if out is None:
            # the returned array is owned by the caller, so request a copy
            out = self._read_array(loc, index=index, copy=True)
        else:
            # data is copied into `out`, so a view suffices; `...` (unlike `:`) also
            # supports zero-dimensional output arrays
            out[...] = self._read_array(loc, index=index, copy=False)
        return out

    def _write_array(self, loc: Sequence[str], arr: np.ndarray) -> None:
        raise NotImplementedError(f"Cannot write arrays in {self.__class__.__name__}")

    def write_array(
        self,
        loc: Sequence[str],
        arr: np.ndarray,
        *,
        attrs: Attrs | None = None,
        cls: type | None = None,
    ) -> None:
        """write an array to a particular location

        Args:
            loc (list of str):
                The location in the storage where the array is written
            arr (:class:`~numpy.ndarray`):
                The array that will be written
            attrs (dict, optional):
                Attributes stored with the array
            cls (type):
                A class associated with this array. The class will be used to
                re-create the object when this array is later accessed. If no class
                is supplied, a generic `~modelrunner.storage.utils.Array` will be
                returned.
        """
        self._check_write_access(loc, name="array")
        self._write_array(loc, arr)
        self._write_item_attrs(loc, attrs, cls=cls, item_type="array")

    def _create_dynamic_array(
        self,
        loc: Sequence[str],
        shape: tuple[int, ...],
        *,
        dtype: DTypeLike,
        record_array: bool = False,
    ) -> None:
        raise NotImplementedError(f"No dynamic arrays for {self.__class__.__name__}")

    def create_dynamic_array(
        self,
        loc: Sequence[str],
        shape: tuple[int, ...],
        *,
        dtype: DTypeLike = float,
        record_array: bool = False,
        attrs: Attrs | None = None,
        cls: type | None = None,
    ) -> None:
        """creates a dynamic array of flexible size

        Args:
            loc (list of str):
                The location in the storage where the dynamic array is created
            shape (tuple of int):
                The shape of the individual arrays. A singular axis is prepended to
                the shape, which can then be extended subsequently.
            dtype:
                The data type of the array to be written
            record_array (bool):
                Flag indicating whether the array is of type :class:`~numpy.recarray`
            attrs (dict, optional):
                Attributes stored with the array
            cls (type):
                A class associated with this array
        """
        self._check_write_access(loc, name="array")
        self._create_dynamic_array(
            loc, tuple(shape), dtype=dtype, record_array=record_array
        )
        self._write_item_attrs(loc, attrs, cls=cls, item_type="dynamic_array")

    def _extend_dynamic_array(self, loc: Sequence[str], arr: ArrayLike) -> None:
        raise NotImplementedError(f"No dynamic arrays for {self.__class__.__name__}")

    def extend_dynamic_array(self, loc: Sequence[str], arr: ArrayLike) -> None:
        """extend a dynamic array previously created

        Args:
            loc (list of str):
                The location in the storage where the dynamic array is located
            arr (array):
                The array that will be appended to the dynamic array

        Raises:
            AccessError: if the mode forbids appending to dynamic arrays
            RuntimeError: if the item at `loc` is not marked as a dynamic array
        """
        if not self.mode.dynamic_append:
            raise AccessError(f"Cannot append data to dynamic array `/{'/'.join(loc)}`")
        if self._read_attrs(loc).get("__type__") != "dynamic_array":
            raise RuntimeError(f"Cannot extend array at `/{'/'.join(loc)}`")
        self._extend_dynamic_array(loc, arr)

    def _read_object(self, loc: Sequence[str]) -> Any:
        raise NotImplementedError(f"Cannot read objects from {self.__class__.__name__}")

    def read_object(self, loc: Sequence[str]) -> Any:
        """read an object from a particular location

        Args:
            loc (list of str):
                The location in the storage where the object is read

        Returns:
            The object that has been read from the storage

        Raises:
            AccessError: if the mode forbids reading
            RuntimeError: if the item at `loc` is not marked as an object
        """
        if not self.mode.read:
            raise AccessError("No right to read object")
        if self._read_attrs(loc).get("__type__") != "object":
            raise RuntimeError(f"No object stored at `/{'/'.join(loc)}`")
        return self._read_object(loc)

    def _write_object(self, loc: Sequence[str], obj: Any) -> None:
        raise NotImplementedError(f"Cannot write objects in {self.__class__.__name__}")

    def write_object(
        self,
        loc: Sequence[str],
        obj: Any,
        *,
        attrs: Attrs | None = None,
        cls: type | None = None,
    ) -> None:
        """write an object to a particular location

        Args:
            loc (list of str):
                The location in the storage where the object is written
            obj:
                The object that will be written
            attrs (dict, optional):
                Attributes stored with the object
            cls (type):
                A class associated with this object. The class will be used to
                re-create the object when this object is later accessed. If no class
                is supplied, a generic python object will be returned.
        """
        self._check_write_access(loc, name="object")
        self._write_object(loc, obj)
        self._write_item_attrs(loc, attrs, cls=cls, item_type="object")