Source code for modelrunner.storage.backend.text_base

"""
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

import io
from abc import ABCMeta, abstractmethod
from collections.abc import Sequence
from io import StringIO
from pathlib import Path
from typing import Any

import numpy as np
from numpy.lib.recfunctions import (
    structured_to_unstructured,
    unstructured_to_structured,
)
from numpy.typing import ArrayLike, DTypeLike

from ..access_modes import AccessError, ModeType
from ..utils import decode_binary, encode_binary
from .memory import MemoryStorage
from .utils import simplify_data


class TextStorageBase(MemoryStorage, metaclass=ABCMeta):
    """Base class for storage that stores data in a text file.

    Existing file content is read when the storage is created. Changes are only
    written back to the file by :meth:`flush`, which :meth:`close` calls
    automatically.

    Subclasses need to implement :meth:`_read_data_from_fp` and
    :meth:`_write_data_to_fp` to define the actual text format.
    """

    def __init__(
        self,
        path: str | Path,
        *,
        mode: ModeType = "read",
        simplify: bool = True,
        **kwargs,
    ):
        """
        Args:
            path (str or :class:`~pathlib.Path`):
                File path to the file
            mode (str or :class:`~modelrunner.storage.access_modes.AccessMode`):
                The file mode with which the storage is accessed. Determines
                allowed operations.
            simplify (bool):
                Flag indicating whether the data is stored in a simplified form
            **kwargs:
                Additional arguments that are stored in `_write_flags`. They are
                not used by this base class; presumably subclasses consume them
                when writing data.

        Raises:
            FileExistsError: If the file mode is `"x"` and the file already exists
        """
        super().__init__(mode=mode)

        self.simplify = simplify
        self._path = Path(path)
        self._write_flags = kwargs
        self._modified = False

        if self.mode.file_mode in {"r", "x", "a"} and self._path.exists():
            if self.mode.file_mode == "x":
                raise FileExistsError(f"File `{path}` already exists")

            # read content from file
            with self._path.open() as fp:
                data = self._read_data_from_fp(fp)
            # interpret empty files correctly
            self._data = {} if data is None else data

    def __repr__(self):
        return f'{self.__class__.__name__}("{self._path}", mode="{self.mode.name}")'

    def flush(self) -> None:
        """Write (cached) data to storage.

        Raises:
            :class:`~modelrunner.storage.access_modes.AccessError`:
                If the data was modified although the file mode does not allow
                writing to the file
        """
        if self.mode.file_mode in {"x", "a", "w"}:
            # Write the data to the writeable file. Note that we do not check the
            # self._modified flag since it might not capture all changes, e.g., when
            # an item (attribute, array, or object) was modified in place
            data = simplify_data(self._data) if self.simplify else self._data
            with self._path.open("w") as fp:
                self._write_data_to_fp(fp, data)
            self._modified = False  # reset modifications

        elif self._modified:
            # The storage was modified, but it cannot be written to the file. This
            # should not happen, but it's better to throw an explicit error
            raise AccessError("Cannot write modifications to file opened read-only")

    def close(self) -> None:
        """Close the file and write the data to the file."""
        self.flush()
        super().close()

    def to_text(self, simplify: bool | None = None) -> str:
        """Serialize the data and return it as a string.

        Args:
            simplify (bool):
                Flag indicating whether the data is stored in a simplified form.
                If `None`, the object-level value is used.

        Returns:
            str: The text that :meth:`_write_data_to_fp` produces for the data
        """
        if simplify is None:
            simplify = self.simplify
        # use the resolved flag so an explicit argument overrides `self.simplify`
        data = simplify_data(self._data) if simplify else self._data

        with StringIO() as fp:
            self._write_data_to_fp(fp, data)
            return fp.getvalue()

    @abstractmethod
    def _read_data_from_fp(self, fp: io.TextIOBase):
        """Read data from an open file.

        Args:
            fp (:class:`io.TextIOBase`):
                The opened text file
        """

    @abstractmethod
    def _write_data_to_fp(self, fp: io.TextIOBase, data) -> None:
        """Write data to an open file.

        Args:
            fp (:class:`io.TextIOBase`):
                The opened text file
            data:
                The data to write
        """

    def _write_attr(self, loc: Sequence[str], name: str, value: str) -> None:
        """Write an attribute via the parent class and mark storage as modified."""
        super()._write_attr(loc, name, value)
        self._modified = True

    def _read_array(
        self, loc: Sequence[str], *, copy: bool, index: int | None = None
    ) -> np.ndarray:
        """Read an array (or one entry of it) from the given location.

        Args:
            loc (sequence of str):
                Location of the array in the storage
            copy (bool):
                Whether a copy of the stored data is requested
            index (int, optional):
                If given, only the entry `data[index]` is returned

        Returns:
            :class:`~numpy.ndarray`: The array converted to the stored dtype

        Raises:
            RuntimeError: If the stored data is not iterable
        """
        # read the data from the location
        if index is None:
            arr = self[loc]["data"]
        else:
            arr = self[loc]["data"][index]

        if not hasattr(arr, "__iter__"):  # minimal sanity check
            raise RuntimeError(f"No array at `/{'/'.join(loc)}`")

        dtype = decode_binary(self[loc]["dtype"])
        if dtype.names is not None:
            # structured arrays are stored in their unstructured form
            arr = unstructured_to_structured(np.asarray(arr), dtype=dtype, copy=copy)
        elif copy:
            arr = np.array(arr, dtype=dtype, copy=True)
        else:
            arr = np.asarray(arr, dtype=dtype)

        if self[loc].get("record_array", False):
            arr = arr.view(np.recarray)
        return arr  # type: ignore

    def _write_array(self, loc: Sequence[str], arr: np.ndarray) -> None:
        """Store a copy of an array together with its encoded dtype.

        Args:
            loc (sequence of str):
                Location where the array is stored
            arr (:class:`~numpy.ndarray`):
                The array to store. Structured arrays are converted to their
                unstructured form; the original dtype is stored alongside.
        """
        parent, name = self._get_parent(loc, check_write=True)

        # extract information here since `arr` might be replaced below
        dtype = arr.dtype
        is_record_array = isinstance(arr, np.recarray)
        if dtype.names is not None:
            # structured array
            arr = structured_to_unstructured(arr)

        parent[name] = {
            "data": np.array(arr, copy=True),
            "dtype": encode_binary(dtype, binary=False),
        }
        if is_record_array:
            parent[name]["record_array"] = True
        self._modified = True

    def _create_dynamic_array(
        self,
        loc: Sequence[str],
        shape: tuple[int, ...],
        dtype: DTypeLike,
        *,
        record_array: bool = False,
    ) -> None:
        """Create an empty dynamic array that can be extended later.

        Args:
            loc (sequence of str):
                Location where the dynamic array is created
            shape (tuple of int):
                Shape of the individual items that will be appended
            dtype:
                Data type of the items
            record_array (bool):
                Flag stored with the array; arrays read from this location are
                viewed as :class:`~numpy.recarray` if it is set

        Raises:
            RuntimeError: If an entry already exists at the location
        """
        parent, name = self._get_parent(loc, check_write=True)
        if name in parent:
            raise RuntimeError(f"Array `/{'/'.join(loc)}` already exists")

        parent[name] = {
            "data": [],
            "shape": shape,
            "dtype": encode_binary(np.dtype(dtype), binary=False),
        }
        if record_array:
            parent[name]["record_array"] = True
        self._modified = True

    def _extend_dynamic_array(self, loc: Sequence[str], arr: ArrayLike) -> None:
        """Append one item to a dynamic array.

        Args:
            loc (sequence of str):
                Location of the dynamic array
            arr (array-like):
                The item to append. Its shape must match the stored shape and
                its dtype must be a sub-dtype of the stored dtype.

        Raises:
            TypeError: If the shape or the dtype do not match the stored values
        """
        item = self[loc]

        # check data shape that is stored at this position
        data = np.asanyarray(arr)
        stored_shape = tuple(item["shape"])
        if stored_shape != data.shape:
            raise TypeError(f"Shape mismatch ({stored_shape} != {data.shape})")

        # convert the data to the correct format
        stored_dtype = decode_binary(item["dtype"])
        if not np.issubdtype(data.dtype, stored_dtype):
            raise TypeError(f"Dtype mismatch ({data.dtype} != {stored_dtype})")
        if data.dtype.names is not None:
            # structured array
            data = structured_to_unstructured(data)

        # append the data to the dynamic array
        if data.ndim == 0:
            # store scalars as plain python values
            item["data"].append(data.item())
        else:
            item["data"].append(np.array(data, copy=True))
        self._modified = True

    def _read_object(self, loc: Sequence[str]) -> Any:
        """Decode and return the object stored at the given location."""
        return self.codec.decode(self[loc]["data"])

    def _write_object(self, loc: Sequence[str], obj: Any) -> None:
        """Encode an object with the storage codec and store it at the location."""
        parent, name = self._get_parent(loc, check_write=True)
        parent[name] = {"data": self.codec.encode(obj)}
        self._modified = True