"""
Defines a class storing data in memory.
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import copy
from typing import Any, Collection, Sequence
import numpy as np
from numpy.typing import ArrayLike, DTypeLike
from ..access_modes import AccessError, ModeType
from ..attributes import Attrs
from ..base import StorageBase
class MemoryStorage(StorageBase):
    """store items in memory

    All data is kept in nested dictionaries. Groups are plain dictionaries,
    whereas arrays and objects are dictionaries that hold their payload under the
    key `"data"`. Attributes of an item are stored in a dictionary under the key
    `"__attrs__"`.
    """

    _data: dict[str, Any]

    def __init__(self, *, mode: ModeType = "insert"):
        """
        Args:
            mode (str or :class:`~modelrunner.storage.access_modes.AccessMode`):
                The file mode with which the storage is accessed. Determines
                allowed operations.
        """
        super().__init__(mode=mode)
        self._data = {}

    def clear(self) -> None:
        """truncate the storage by removing all stored data"""
        # a fresh dictionary is assigned; the previous one is not modified
        self._data = {}

    def _get_parent(
        self, loc: Sequence[str], *, check_write: bool = False
    ) -> tuple[dict, str]:
        """get the parent group for a particular location

        NOTE(review): Missing intermediate groups are created as empty
        dictionaries while walking `loc`. Since :meth:`__getitem__` uses this
        method, this also happens during read access.

        Args:
            loc (list of str):
                The location in the storage whose parent is determined
            check_write (bool):
                If `True`, check that the parent is a group and that writing the
                item would not overwrite existing data when overwriting is
                disabled by the access mode.

        Returns:
            (group, str):
                A tuple consisting of the parent group and the name of the
                current item

        Raises:
            TypeError: If the parent is not a group/dictionary
            KeyError: If `loc` is empty and thus has no parent
            AccessError: If `check_write` is set, the item already exists, and
                the access mode does not allow overwriting
        """
        if check_write and not self.is_group(loc[:-1]):
            raise TypeError(f"Location `/{'/'.join(loc[:-1])}` is not a group")

        # walk down the hierarchy, creating missing groups on the way
        parent = self._data
        for part in loc[:-1]:
            try:
                parent = parent[part]
            except KeyError:
                if isinstance(parent, dict):
                    parent[part] = {}
                    parent = parent[part]
                else:
                    raise TypeError(
                        f"Cannot add item to `/{'/'.join(loc)}`"
                    ) from None

        if not isinstance(parent, dict):
            raise TypeError(f"Cannot add item to `/{'/'.join(loc)}`")

        try:
            name = loc[-1]
        except IndexError:
            # an empty location denotes the root, which has no parent
            raise KeyError(f"Location `/{'/'.join(loc)}` has no parent") from None

        if check_write and not self.mode.overwrite and name in parent:
            raise AccessError(f"Overwriting `/{'/'.join(loc)}` disabled")

        return parent, name

    def __getitem__(self, loc: Sequence[str]) -> Any:
        """return the raw item stored at a location

        Args:
            loc (list of str):
                The location in the storage. An empty location denotes the root.

        Returns:
            The stored item; the dictionary of all data for an empty location

        Raises:
            KeyError: If the item does not exist in its parent group
        """
        if not loc:
            return self._data

        parent, name = self._get_parent(loc)
        try:
            return parent[name]
        except KeyError as err:
            raise KeyError(f"Item `{name}` not in group {parent.keys()}") from err

    def keys(self, loc: Sequence[str]) -> Collection[str]:
        """return the names of all items stored at a location

        Args:
            loc (list of str):
                The location in the storage. An empty location denotes the root.

        Returns:
            list of str: All keys that do not start with a double underscore
        """
        # `self[loc]` already returns the root dictionary for an empty location;
        # keys starting with `__` (e.g., `__attrs__`) are internal and thus hidden
        return [key for key in self[loc].keys() if not key.startswith("__")]

    def is_group(self, loc: Sequence[str]) -> bool:
        """determine whether the item at a location is a group

        Args:
            loc (list of str):
                The location in the storage

        Returns:
            bool: `True` if the item is a dictionary whose attributes do not
            contain the key `__type__`
        """
        item = self[loc]
        if not isinstance(item, dict):
            return False  # no group, since it's not a dictionary
        # dictionaries are usually groups, unless they have the `__type__` attribute
        return "__type__" not in item.get("__attrs__", {})

    def _create_group(self, loc: Sequence[str]) -> None:
        """create an empty group at a location

        Args:
            loc (list of str):
                The location in the storage where the group will be created
        """
        parent, name = self._get_parent(loc, check_write=True)
        parent[name] = {}

    def _read_attrs(self, loc: Sequence[str]) -> Attrs:
        """read the attributes associated with a location

        Args:
            loc (list of str):
                The location in the storage

        Returns:
            dict: The stored attributes; an empty dictionary if there are none

        Raises:
            RuntimeError: If the stored attributes are not a dictionary
        """
        attrs = self[loc].get("__attrs__", {})
        if not isinstance(attrs, dict):
            raise RuntimeError(f"No attributes at `/{'/'.join(loc)}`")
        return attrs

    def _write_attr(self, loc: Sequence[str], name: str, value: str) -> None:
        """write a single attribute of the item at a location

        Args:
            loc (list of str):
                The location in the storage
            name (str):
                Name of the attribute
            value (str):
                Value of the attribute
        """
        item = self[loc]
        if "__attrs__" not in item:
            item["__attrs__"] = {name: value}
        else:
            item["__attrs__"][name] = value

    def _read_array(
        self, loc: Sequence[str], *, copy: bool, index: int | None = None
    ) -> np.ndarray:
        """read an array from a location

        Args:
            loc (list of str):
                The location in the storage
            copy (bool):
                If `True`, a copy of the data is returned. Otherwise, the data is
                only copied when a conversion to an array requires it.
            index (int, optional):
                If given, only this entry of the stored data is returned

        Returns:
            :class:`~numpy.ndarray`: The stored data
        """
        stored = self[loc]["data"]
        arr = stored if index is None else stored[index]
        if copy:
            return np.array(arr, copy=True)
        # `np.array(..., copy=False)` raises on numpy>=2 when a copy cannot be
        # avoided (e.g., for the lists backing dynamic arrays); `np.asarray` only
        # copies if necessary on all numpy versions
        return np.asarray(arr)

    def _write_array(self, loc: Sequence[str], arr: np.ndarray) -> None:
        """write a copy of an array to a location

        Args:
            loc (list of str):
                The location in the storage
            arr (:class:`~numpy.ndarray`):
                The array that is stored
        """
        parent, name = self._get_parent(loc, check_write=True)
        parent[name] = {"data": np.copy(arr)}

    def _create_dynamic_array(
        self,
        loc: Sequence[str],
        shape: tuple[int, ...],
        dtype: DTypeLike,
        *,
        record_array: bool = False,
    ) -> None:
        """create an empty array that can be extended item by item

        Args:
            loc (list of str):
                The location in the storage
            shape (tuple of int):
                The shape that each appended item must have
            dtype:
                The data type that each appended item must be compatible with
            record_array (bool):
                If `True`, the flag `record_array` is stored alongside the data

        Raises:
            RuntimeError: If an item already exists at this location
        """
        parent, name = self._get_parent(loc, check_write=True)
        if name in parent:
            raise RuntimeError(f"Array `/{'/'.join(loc)}` already exists")

        parent[name] = {
            "data": [],
            "shape": tuple(shape),
            "dtype": np.dtype(dtype),
        }
        if record_array:
            parent[name]["record_array"] = True

    def _extend_dynamic_array(self, loc: Sequence[str], arr: ArrayLike) -> None:
        """append an item to a dynamic array

        Args:
            loc (list of str):
                The location of the dynamic array in the storage
            arr (array-like):
                The data that is appended

        Raises:
            TypeError: If shape or dtype of the data do not match the values
                stored when the dynamic array was created
        """
        item = self[loc]

        # check data shape that is stored at this position
        data = np.asanyarray(arr)
        stored_shape = item["shape"]
        if stored_shape != data.shape:
            raise TypeError(f"Shape mismatch ({stored_shape} != {data.shape})")

        # check that the data has a compatible type
        stored_dtype = item["dtype"]
        if not np.issubdtype(data.dtype, stored_dtype):
            raise TypeError(f"Dtype mismatch ({data.dtype} != {stored_dtype})")

        # append the data to the dynamic array
        if data.ndim == 0:
            # zero-dimensional data is stored as a plain python scalar
            item["data"].append(data.item())
        else:
            item["data"].append(np.array(data, copy=True))

    def _read_object(self, loc: Sequence[str]) -> Any:
        """read an object from a location

        Args:
            loc (list of str):
                The location in the storage

        Returns:
            The stored object itself (not a copy)
        """
        return self[loc]["data"]

    def _write_object(self, loc: Sequence[str], obj: Any) -> None:
        """write a deep copy of an object to a location

        Args:
            loc (list of str):
                The location in the storage
            obj:
                The object that is stored
        """
        parent, name = self._get_parent(loc, check_write=True)
        parent[name] = {"data": copy.deepcopy(obj)}