Source code for modelrunner.storage.trajectory

"""
Classes that describe time-dependent data, i.e., trajectories.

.. autosummary::
   :nosignatures:

   TrajectoryWriter
   Trajectory

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

from pathlib import Path
from typing import Any, Iterator, Literal

import numpy as np

from modelrunner.storage.access_modes import ModeType

from .access_modes import AccessError
from .base import StorageBase
from .group import StorageGroup
from .tools import open_storage
from .utils import Location, storage_actions


[docs]class TrajectoryWriter: """writes trajectories into a storage Stored data can then be read using :class:`Trajectory`. Example: .. code-block:: python # write data using context manager with TrajectoryWriter("test.zarr") as writer: for t, data in simulation: writer.append(data, t) # append to same file using explicit class interface writer = TrajectoryWriter("test.zarr", mode="append") writer.append(data0) writer.append(data1) writer.close() # read data trajectory = Trajectory("test.zarr") assert trajectory[-1] == data1 """ _item_type: Literal["array", "object"] def __init__( self, storage: str | Path | StorageGroup | StorageBase, loc: Location = "trajectory", *, attrs: dict[str, Any] | None = None, mode: ModeType | None = None, ): """ Args: store: Store or path to directory in file system or name of zip file. loc (str or list of str): The location in the storage where the trajectory data is written. attrs (dict): Additional attributes stored in the trajectory. mode (str or :class:`~modelrunner.storage.access_modes.AccessMode`): The file mode with which the storage is accessed. Determines allowed operations. The meaning of the special (default) value `None` depends on whether the file given by `store` already exists. If yes, a RuntimeError is raised, otherwise the choice corresponds to `mode="full"` and thus creates a new trajectory. If the file exists, use `mode="truncate"` to overwrite file or `mode="append"` to insert new data into the file. """ # create the root group where we store all the data if mode is None: if isinstance(storage, (str, Path)) and Path(storage).exists(): raise RuntimeError( 'Storage already exists. Use `mode="truncate"` to overwrite file ' 'or `mode="append"` to insert new data into the file.' ) mode = "full" self._storage = open_storage(storage, mode=mode) if self._storage.mode.insert: self._trajectory = self._storage.create_group(loc, cls=Trajectory) elif self._storage.mode.dynamic_append: self._trajectory = StorageGroup(self._storage, loc) else: raise AccessError(f"Cannot insert data. Open storage with write access") # make sure we don't overwrite data if "times" in self._trajectory or "data" in self._trajectory: if not self._storage.mode.dynamic_append: raise OSError("Storage already contains data and we cannot append") self._item_type = self._trajectory.attrs["item_type"] if attrs is not None: self._trajectory.write_attrs(attrs=attrs) @property def times(self) -> np.ndarray: """:class:`~numpy.ndarray`: Time points written so far""" return self._trajectory.read_array("time")
[docs] def append(self, data: Any, time: float | None = None) -> None: """append data to the trajectory Args: data: The data to append to the trajectory time (float, optional): The associated time point. If omitted, the last time point is incremented by one. """ if "data" not in self._trajectory: # initialize new trajectory if isinstance(data, np.ndarray): dtype = data.dtype shape: tuple[int, ...] = data.shape self._item_type = "array" else: dtype = object shape = () self._item_type = "object" self._trajectory.create_dynamic_array("data", shape=shape, dtype=dtype) self._trajectory.create_dynamic_array("time", shape=(), dtype=float) self._trajectory.write_attrs(None, {"item_type": self._item_type}) if time is None: time = 0.0 else: if time is None: time = float(self._trajectory.read_array("time", index=-1)) + 1.0 if self._item_type == "array": self._trajectory.extend_dynamic_array("data", data) elif self._item_type == "object": arr: np.ndarray = np.empty((), dtype=object) arr[...] = data self._trajectory.extend_dynamic_array("data", arr) else: raise NotImplementedError self._trajectory.extend_dynamic_array("time", time)
[docs] def close(self): self._storage.close()
def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close()
[docs]class Trajectory: """Reads trajectories written with :class:`TrajectoryWriter` The class permits direct access to indivdual items in the trajcetory using the square bracket notation. It is also possible to iterate over all items. Attributes: times (:class:`~numpy.ndarray`): Time points at which data is available """ _item_type: Literal["array", "object"] def __init__(self, storage: StorageGroup, loc: Location = "trajectory"): """ Args: storage (MutableMapping or string): Store or path to directory in file system or name of zip file. loc (str or list of str): The location in the storage where the trajectory data is read. """ # open the storage self._storage = open_storage(storage, mode="read") self._loc = self._storage._get_loc(loc) # self._storage = storage#StorageGroup(storage, loc) # read some intial data from storage self.attrs = self._storage.read_attrs(self._loc) self._item_type = self.attrs["item_type"] self.times = self._storage.read_array(self._loc + ["time"]) # check temporal ordering if np.any(np.diff(self.times) < 0): raise ValueError(f"Times are not monotonously increasing: {self.times}")
[docs] def close(self) -> None: """close the openend storage""" self._storage.close()
def __len__(self) -> int: return len(self.times) def _get_item(self, t_index: int) -> Any: """return the data object corresponding to the given time index Load the data given an index, i.e., the data at time `self.times[t_index]`. Args: t_index (int): The index of the data to load Returns: The requested item """ if t_index < 0: t_index += len(self) if not 0 <= t_index < len(self): raise IndexError("Time index out of range") res = self._storage.read_array(self._loc + ["data"], index=t_index) if self._item_type == "array": return res elif self._item_type == "object": return res.item() else: raise NotImplementedError def __getitem__(self, key: int) -> Any: """return field at given index or a list of fields for a slice""" if isinstance(key, int): return self._get_item(key) else: raise TypeError("Unknown key type") def __iter__(self) -> Iterator[Any]: """iterate over all stored fields""" for i in range(len(self)): yield self[i]
storage_actions.register("read_item", Trajectory, Trajectory)