Source code for modelrunner.storage.backend.text_base

"""
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

import io
from abc import ABCMeta, abstractmethod
from io import StringIO
from pathlib import Path
from typing import Any, Sequence

import numpy as np
from numpy.lib.recfunctions import (
    structured_to_unstructured,
    unstructured_to_structured,
)
from numpy.typing import ArrayLike, DTypeLike

from ..access_modes import AccessError, ModeType
from ..utils import decode_binary, encode_binary
from .memory import MemoryStorage
from .utils import simplify_data


class TextStorageBase(MemoryStorage, metaclass=ABCMeta):
    """base class for storage that stores data in a text file

    Note that the data is only written once the storage is closed (or when
    :meth:`flush` is called explicitly). Subclasses define the concrete text format
    by implementing :meth:`_read_data_from_fp` and :meth:`_write_data_to_fp`.
    """

    def __init__(
        self,
        path: str | Path,
        *,
        mode: ModeType = "read",
        simplify: bool = True,
        **kwargs,
    ):
        """
        Args:
            path (str or :class:`~pathlib.Path`):
                File path to the file
            mode (str or :class:`~modelrunner.storage.access_modes.AccessMode`):
                The file mode with which the storage is accessed. Determines allowed
                operations.
            simplify (bool):
                Flag indicating whether the data is stored in a simplified form
            **kwargs:
                Extra flags kept in `self._write_flags`; presumably consumed by
                subclasses when writing the file (not used in this class)

        Raises:
            FileExistsError: if `mode` has file mode "x" and the file already exists
        """
        super().__init__(mode=mode)
        self.simplify = simplify
        self._path = Path(path)
        self._write_flags = kwargs
        self._modified = False

        if self.mode.file_mode in {"r", "x", "a"}:
            if self._path.exists():
                if self.mode.file_mode == "x":
                    raise FileExistsError(f"File `{path}` already exists")

                # read content from file
                with open(self._path) as fp:
                    data = self._read_data_from_fp(fp)

                # interpret empty files correctly (readers may return `None`)
                self._data = {} if data is None else data

    def __repr__(self):
        return (
            f'{self.__class__.__name__}("{self._path}", ' f'mode="{self.mode.name}")'
        )

    def flush(self) -> None:
        """write (cached) data to storage

        Raises:
            AccessError: if the storage was modified but the file mode does not
                permit writing
        """
        if self.mode.file_mode in {"x", "a", "w"}:
            # Write the data to the writeable file. Note that we do not check the
            # self._modified flag since it might not capture all changes, e.g., when an
            # item (attribute, array, or object) was modified in place
            data = simplify_data(self._data) if self.simplify else self._data
            with open(self._path, mode="w") as fp:
                self._write_data_to_fp(fp, data)
            self._modified = False  # reset modifications

        elif self._modified:
            # The storage was modified, but it cannot be written to the file. This
            # should not happen, but it's better to throw an explicit error
            raise AccessError("Cannot write to file")

    def close(self) -> None:
        """close the file and write the data to the file"""
        self.flush()
        super().close()

    def to_text(self, simplify: bool | None = None) -> str:
        """serialize the data and return it as a string

        Args:
            simplify (bool):
                Flag indicating whether the data is stored in a simplified form. If
                `None`, the object-level value is used.

        Returns:
            str: The serialized data in the text format of this storage
        """
        if simplify is None:
            simplify = self.simplify
        # use the resolved argument (not `self.simplify`) so callers can override it
        data = simplify_data(self._data) if simplify else self._data
        with StringIO() as fp:
            self._write_data_to_fp(fp, data)
            return fp.getvalue()

    @abstractmethod
    def _read_data_from_fp(self, fp: io.TextIOBase):
        """read data from an open file

        Args:
            fp (:class:`io.TextIOBase`):
                The opened text file

        Returns:
            The data read from the file; `None` is interpreted as an empty storage
        """

    @abstractmethod
    def _write_data_to_fp(self, fp: io.TextIOBase, data) -> None:
        """write data to an open file

        Args:
            fp (:class:`io.TextIOBase`):
                The opened text file
            data:
                The data to write
        """

    def _write_attr(self, loc: Sequence[str], name: str, value: str) -> None:
        """write an attribute and mark the storage as modified"""
        super()._write_attr(loc, name, value)
        self._modified = True

    def _read_array(
        self, loc: Sequence[str], *, copy: bool, index: int | None = None
    ) -> np.ndarray:
        """read an array (or one entry of a dynamic array) from the storage

        Args:
            loc (sequence of str):
                The location of the array
            copy (bool):
                Whether the returned array must be a copy of the stored data
            index (int, optional):
                If given, only this entry of a dynamic array is returned

        Returns:
            :class:`~numpy.ndarray`: The array, viewed as a record array if it was
            stored as one

        Raises:
            RuntimeError: if no array data is stored at `loc`
        """
        item = self[loc]
        # read the data from the location
        if index is None:
            arr = item["data"]
        else:
            arr = item["data"][index]

        # minimal sanity check: array data is stored as a (nested) sequence, except
        # for single entries of dynamic arrays of 0-d items, which are stored as
        # plain scalars by `_extend_dynamic_array`
        if not (hasattr(arr, "__iter__") or (index is not None and np.isscalar(arr))):
            raise RuntimeError(f"No array at `/{'/'.join(loc)}`")

        dtype = decode_binary(item["dtype"])
        if dtype.names is not None:
            # structured arrays are stored in unstructured form (see `_write_array`)
            arr = unstructured_to_structured(np.asarray(arr), dtype=dtype, copy=copy)
        elif copy:
            arr = np.array(arr, dtype=dtype, copy=True)
        else:
            # `np.asarray` copies only when necessary; `np.array(..., copy=False)`
            # raises in NumPy >= 2 whenever a copy is unavoidable (e.g., for lists)
            arr = np.asarray(arr, dtype=dtype)

        if item.get("record_array", False):
            arr = arr.view(np.recarray)
        return arr  # type: ignore

    def _write_array(self, loc: Sequence[str], arr: np.ndarray) -> None:
        """write a copy of an array to the storage

        Structured arrays are stored in unstructured form together with their dtype,
        so they can be restored by :meth:`_read_array`.
        """
        parent, name = self._get_parent(loc, check_write=True)
        # extract properties here since `arr` might be converted below, which can
        # change its dtype and (depending on the numpy version) its array subclass
        dtype = arr.dtype
        is_record_array = isinstance(arr, np.recarray)
        if dtype.names is not None:
            # structured array
            arr = structured_to_unstructured(arr)

        parent[name] = {
            "data": np.array(arr, copy=True),
            "dtype": encode_binary(dtype, binary=False),
        }
        if is_record_array:
            parent[name]["record_array"] = True
        self._modified = True

    def _create_dynamic_array(
        self,
        loc: Sequence[str],
        shape: tuple[int, ...],
        dtype: DTypeLike,
        *,
        record_array: bool = False,
    ) -> None:
        """create an empty dynamic array whose entries have the given shape & dtype

        Raises:
            RuntimeError: if an item already exists at `loc`
        """
        parent, name = self._get_parent(loc, check_write=True)
        if name in parent:
            raise RuntimeError(f"Array `/{'/'.join(loc)}` already exists")
        parent[name] = {
            "data": [],
            "shape": shape,
            "dtype": encode_binary(np.dtype(dtype), binary=False),
        }
        if record_array:
            parent[name]["record_array"] = True
        self._modified = True

    def _extend_dynamic_array(self, loc: Sequence[str], arr: ArrayLike) -> None:
        """append a copy of `arr` as a new entry to a dynamic array

        Raises:
            TypeError: if the shape or dtype of `arr` does not match the stored ones
        """
        item = self[loc]

        # check data shape that is stored at this position
        data = np.asanyarray(arr)
        stored_shape = tuple(item["shape"])
        if stored_shape != data.shape:
            raise TypeError(f"Shape mismatch ({stored_shape} != {data.shape})")

        # convert the data to the correct format
        stored_dtype = decode_binary(item["dtype"])
        if not np.issubdtype(data.dtype, stored_dtype):
            raise TypeError(f"Dtype mismatch ({data.dtype} != {stored_dtype})")
        if data.dtype.names is not None:
            # structured array
            data = structured_to_unstructured(data)

        # append the data to the dynamic array; 0-d entries are stored as scalars
        if data.ndim == 0:
            item["data"].append(data.item())
        else:
            item["data"].append(np.array(data, copy=True))
        self._modified = True

    def _read_object(self, loc: Sequence[str]) -> Any:
        """read and decode an object using the storage codec"""
        return self.codec.decode(self[loc]["data"])

    def _write_object(self, loc: Sequence[str], obj: Any) -> None:
        """encode an object using the storage codec and write it to the storage"""
        parent, name = self._get_parent(loc, check_write=True)
        parent[name] = {"data": self.codec.encode(obj)}
        self._modified = True