"""
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
from typing import Any, Collection, Iterator
import numpy as np
from numpy.typing import ArrayLike, DTypeLike
from .attributes import Attrs
from .base import StorageBase
from .utils import Array, Location, decode_class, encode_class, storage_actions
# TODO: Provide .attrs attribute with a descriptor protocol (implemented by the backend)
# TODO: Provide a simple viewer of the tree structure (e.g. a `tree` method)
[docs]class StorageGroup:
"""refers to a group within a storage"""
def __init__(self, storage: StorageBase | StorageGroup, loc: Location = None):
"""
Args:
storage (:class:`StorageBase` or :class:`StorageGroup`):
The storage where the group is defined. If this is a
:class:`StorageGroup` itself, `loc` is interpreted relative to that
group
loc (str or list of str):
Denotes the location (path) of the group within the storage
"""
self.loc = [] # initialize empty location, since `loc` is relative to root
self.loc = self._get_loc(loc)
if isinstance(storage, StorageBase):
self._storage = storage
elif isinstance(storage, StorageGroup):
self.loc = storage.loc + self.loc
self._storage = storage._storage
else:
raise TypeError(
f"Cannot interprete `storage` of type `{storage.__class__}`"
)
assert isinstance(self._storage, StorageBase)
if self._storage.closed:
raise RuntimeError("Cannot access group in closed storage")
if self.loc not in self._storage:
raise RuntimeError(
f'"/{"/".join(self.loc)}" is not in storage. Available root items are: '
f"{list(self._storage.keys(loc=[]))}"
)
if not self.is_group():
raise RuntimeError(f'"/{"/".join(self.loc)}" is not a group')
def __repr__(self):
return f'StorageGroup(storage={self._storage}, loc="/{"/".join(self.loc)}")'
@property
def parent(self) -> StorageGroup:
""":class:`StorageGroup`: Parent group
Raises:
RuntimeError: If current group is root group
"""
if self.loc:
return StorageGroup(self._storage, loc=self.loc[:-1])
else:
raise RuntimeError("Root group has no parent")
[docs] def tree(self) -> None:
"""print the hierarchical storage as a tree structure"""
vertic = "│ "
cross = "├──"
corner = "└──"
space = " "
def print_tree(loc: list[str], header: str = ""):
"""recursive function printing information about one group"""
group = StorageGroup(self._storage, loc)
for i, key in enumerate(sorted(group.keys())):
last = i == len(group) - 1
if self._storage.is_group(loc + [key]):
cls = self._storage._read_attrs(loc).get("__class__")
if cls is None:
# item is a sub group
print(header + (corner if last else cross) + key)
print_tree(
loc + [key], header=header + (space if last else vertic)
)
else:
# item contains information to restore a certain class
print(header + (corner if last else cross) + f"{key} ({cls})")
else:
# item is a simple, scalar item
print(header + (corner if last else cross) + key)
if self.loc:
print("/" + "/".join(self.loc))
print_tree(self.loc)
def _get_loc(self, loc: Location) -> list[str]:
"""return a normalized location from various input
Args:
loc (str or list of str):
location in a general formation. For instance, "/" is interpreted as a
group separator.
Returns:
list of str: A list of the individual location items
"""
# TODO: use regex to check whether loc is only alphanumerical and has no "/"
def parse_loc(loc_data) -> list[str]:
if loc_data is None or loc_data == "":
return []
elif isinstance(loc_data, str):
return loc_data.strip("/").split("/")
else:
return sum((parse_loc(k) for k in loc_data), start=list())
return self.loc + parse_loc(loc)
def __getitem__(self, loc: Location) -> Any:
"""read state or trajectory from storage"""
loc_list = self._get_loc(loc)
if self._storage.is_group(loc_list): # storage points to a group
if "__class__" not in self._storage._read_attrs(loc_list):
# group does not contain class information => just return a subgroup
return StorageGroup(self._storage, loc_list)
# reconstruct objected stored at this place
return self.read_item(loc, use_class=True)
[docs] def get(self, loc: Location, default: Any = None) -> Any:
try:
return self[loc]
except KeyError:
return default
def __setitem__(self, loc: Location, obj: Any) -> None:
self.write_item(loc, obj)
[docs] def keys(self) -> Collection[str]:
"""return name of all stored items in this group"""
return self._storage.keys(self.loc)
def __len__(self) -> int:
return len(self.keys())
def __iter__(self) -> Iterator[Any]:
"""iterate over all stored items in this group"""
for loc in self.keys():
yield self[loc]
def __contains__(self, loc: Location):
"""check wether a particular item is contained in this group"""
return self._get_loc(loc) in self._storage
[docs] def items(self) -> Iterator[tuple[str, Any]]:
"""iterate over stored items, yielding the location and item of each"""
for loc in self.keys():
yield loc, self[loc]
[docs] def read_attrs(self, loc: Location = None) -> Attrs:
"""read attributes associated with a particular location
Args:
loc (str or list of str):
The location in the storage where the attributes are read
Returns:
dict: A copy of the attributes at this location
"""
return self._storage.read_attrs(self._get_loc(loc))
[docs] def write_attrs(self, loc: Location = None, attrs: Attrs | None = None) -> None:
"""write attributes to a particular location
Args:
loc (str or list of str):
The location in the storage where the attributes are written
attrs (dict):
The attributes to be added to this location
"""
self._storage.write_attrs(self._get_loc(loc), attrs=attrs)
@property
def attrs(self) -> Attrs:
"""dict: the attributes associated with this group"""
return self.read_attrs()
[docs] def get_class(self, loc: Location = None) -> type | None:
"""get the class associated with a particular location
Class information can be written using the `cls` attribute of `write_array`,
`write_object`, and similar functions.
Args:
loc (str or list of str):
The location where the class information is read from
Retruns: the class associated with the lcoation
"""
loc_list = self._get_loc(loc)
attrs = self._storage._read_attrs(loc_list)
return decode_class(attrs.get("__class__"))
[docs] def read_item(self, loc: Location, *, use_class: bool = True) -> Any:
"""read an item from a particular location
Args:
loc (str or list of str):
The location where the item is read from
use_class (bool):
If `True`, looks for class information in the attributes and evokes a
potentially registered hook to instantiate the associated object. If
`False`, only the current data or object is returned.
Returns:
The reconstructed python object
"""
loc_list = self._get_loc(loc)
if use_class:
cls = self.get_class(loc)
if cls is not None:
# create object using a registered action
read_item = storage_actions.get(cls, "read_item")
return read_item(self._storage, loc_list)
# read the item using the generic classes
obj_type = self._storage._read_attrs(loc_list).get("__type__")
if obj_type in {"array", "dynamic_array"}:
arr = self._storage.read_array(loc_list)
return Array(arr, attrs=self._storage.read_attrs(loc_list))
elif obj_type == "object":
return self._storage.read_object(loc_list)
else:
raise RuntimeError(f"Cannot read objects of type `{obj_type}`")
[docs] def write_item(
self,
loc: Location,
item: Any,
*,
attrs: Attrs | None = None,
use_class: bool = True,
) -> None:
"""write an item to a particular location
Args:
loc (sequence of str):
The location where the item is written to
item:
The item that will be written
attrs (dict, optional):
Attributes stored with the object
use_class (bool):
If `True`, looks for class information in the attributes and evokes a
potentially registered hook to instantiate the associated object. If
`False`, only the current data or object is returned.
"""
# try writing the object using the class definition
if use_class:
try:
write_item = storage_actions.get(item.__class__, "write_item")
except RuntimeError:
pass # fall back to the generic writing
else:
loc_list = self._get_loc(loc)
write_item(self._storage, loc_list, item)
self._storage._write_attr(
loc_list, "__class__", encode_class(item.__class__)
)
return
# write the object using generic writers
if isinstance(item, np.ndarray):
self.write_array(loc, item, attrs=attrs)
else:
self.write_object(loc, item, attrs=attrs)
[docs] def is_group(self, loc: Location = None) -> bool:
"""determine whether the location is a group
Args:
loc (sequence of str):
A list of strings determining the location in the storage
Returns:
bool: `True` if the loation is a group
"""
return self._storage.is_group(self._get_loc(loc))
[docs] def open_group(self, loc: Location) -> StorageGroup:
"""open an existing group at a particular location
Args:
loc (str or list of str):
The location where the group will be opened
Returns:
:class:`StorageGroup`: The reference to the group
"""
loc_list = self._get_loc(loc)
if not self._storage.is_group(loc_list):
raise TypeError(f"`/{'/'.join(loc_list)}` is not a group")
return StorageGroup(self._storage, loc_list)
[docs] def create_group(
self,
loc: Location,
*,
attrs: Attrs | None = None,
cls: type | None = None,
) -> StorageGroup:
"""create a new group at a particular location
Args:
loc (str or list of str):
The location where the group will be created
attrs (dict, optional):
Attributes stored with the group
cls (type):
A class associated with this group
Returns:
:class:`StorageGroup`: The reference of the new group
"""
loc_list = self._get_loc(loc)
return self._storage.create_group(loc_list, attrs=attrs, cls=cls)
[docs] def read_array(
self,
loc: Location,
*,
out: np.ndarray | None = None,
index: int | None = None,
) -> np.ndarray:
"""read an array from a particular location
Args:
loc (str or list of str):
The location where the array is created
out (array, optional):
An array to which the results are written
index (int, optional):
An index denoting the subarray that will be read
Returns:
:class:`~numpy.ndarray`:
An array containing the data. Identical to `out` if specified.
"""
loc_list = self._get_loc(loc)
return self._storage.read_array(loc_list, out=out, index=index)
[docs] def write_array(
self,
loc: Location,
arr: np.ndarray,
*,
attrs: Attrs | None = None,
cls: type | None = None,
):
"""write an array to a particular location
Args:
loc (str or list of str):
The location where the array is read
arr (:class:`~numpy.ndarray`):
The array that will be written
attrs (dict, optional):
Attributes stored with the array
cls (type):
A class associated with this array
"""
loc_list = self._get_loc(loc)
self._storage.write_array(loc_list, arr, attrs=attrs, cls=cls)
[docs] def create_dynamic_array(
self,
loc: Location,
*,
arr: np.ndarray | None = None,
shape: tuple[int, ...] | None = None,
dtype: DTypeLike = float,
record_array: bool = False,
attrs: Attrs | None = None,
cls: type | None = None,
):
"""creates a dynamic array of flexible size
Args:
loc (str or list of str):
The location where the dynamic array is created
shape (tuple of int):
The shape of the individual arrays. A singular axis is prepended to the
shape, which can then be extended subsequently.
dtype:
The data type of the array to be written
record_array (bool):
Flag indicating whether the array is of type :class:`~numpy.recarray`
attrs (dict, optional):
Attributes stored with the array
cls (type):
A class associated with this array
"""
if arr is not None:
if shape is not None:
raise TypeError("Cannot set `arr` and `shape` simultanously")
shape = arr.shape
dtype = arr.dtype
record_array = isinstance(arr, np.recarray)
if shape is None:
raise TypeError("Either `arr` or `shape` need to be specified")
self._storage.create_dynamic_array(
self._get_loc(loc),
shape,
dtype=dtype,
record_array=record_array,
attrs=attrs,
cls=cls,
)
[docs] def extend_dynamic_array(self, loc: Location, data: ArrayLike):
"""extend a dynamic array previously created
Args:
loc (str or list of str):
The location where the dynamic array is located
arr (array):
The array that will be appended to the dynamic array
"""
self._storage.extend_dynamic_array(self._get_loc(loc), data)
[docs] def read_object(self, loc: Location) -> Any:
"""read an object from a particular location
Args:
loc (str or list of str):
The location where the object is created
Returns:
The object that has been read from the storage
"""
return self._storage.read_object(self._get_loc(loc))
[docs] def write_object(
self,
loc: Location,
obj: Any,
*,
attrs: Attrs | None = None,
cls: type | None = None,
):
"""write an object to a particular location
Args:
loc (str or list of str):
The location where the object is read
obj:
The object that will be written
attrs (dict, optional):
Attributes stored with the object
cls (type):
A class associated with this object
"""
loc_list = self._get_loc(loc)
self._storage.write_object(loc_list, obj, attrs=attrs, cls=cls)