"""
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import io
from abc import ABCMeta, abstractmethod
from io import StringIO
from pathlib import Path
from typing import Any, Sequence
import numpy as np
from numpy.lib.recfunctions import (
structured_to_unstructured,
unstructured_to_structured,
)
from numpy.typing import ArrayLike, DTypeLike
from ..access_modes import AccessError, ModeType
from ..utils import decode_binary, encode_binary
from .memory import MemoryStorage
from .utils import simplify_data
class TextStorageBase(MemoryStorage, metaclass=ABCMeta):
    """base class for storage that stores data in a text file

    All data is kept in memory and is only written to the file when the storage is
    flushed or closed. Subclasses define the concrete text format by implementing
    :meth:`_read_data_from_fp` and :meth:`_write_data_to_fp`.
    """

    def __init__(
        self,
        path: str | Path,
        *,
        mode: ModeType = "read",
        simplify: bool = True,
        **kwargs,
    ):
        """
        Args:
            path (str or :class:`~pathlib.Path`):
                File path to the file
            mode (str or :class:`~modelrunner.storage.access_modes.AccessMode`):
                The file mode with which the storage is accessed. Determines allowed
                operations.
            simplify (bool):
                Flag indicating whether the data is stored in a simplified form
            **kwargs:
                Additional flags kept in `self._write_flags` (presumably consumed by
                subclasses when writing; verify in the concrete implementations)

        Raises:
            FileExistsError: if the file mode is `x` and the file already exists
        """
        super().__init__(mode=mode)
        self.simplify = simplify
        self._path = Path(path)
        self._write_flags = kwargs
        self._modified = False

        if self.mode.file_mode in {"r", "x", "a"}:
            if self._path.exists():
                if self.mode.file_mode == "x":
                    raise FileExistsError(f"File `{path}` already exists")
                # read content from file
                with open(self._path) as fp:
                    data = self._read_data_from_fp(fp)
                # interpret empty files correctly
                self._data = {} if data is None else data

    def __repr__(self):
        return f'{self.__class__.__name__}("{self._path}", mode="{self.mode.name}")'

    def _serialize_to_text(self, simplify: bool) -> str:
        """serialize the in-memory data into a string

        Args:
            simplify (bool):
                Whether the data is passed through `simplify_data` before writing

        Returns:
            str: The serialized representation of the stored data
        """
        data = simplify_data(self._data) if simplify else self._data
        with StringIO() as fp:
            self._write_data_to_fp(fp, data)
            return fp.getvalue()

    def flush(self) -> None:
        """write (cached) data to storage

        Raises:
            :class:`AccessError`: if the storage was modified, but the access mode
            does not permit writing to the file
        """
        if self.mode.file_mode in {"x", "a", "w"}:
            # Write the data to the writeable file. Note that we do not check the
            # self._modified flag since it might not capture all changes, e.g., when an
            # item (attribute, array, or object) was modified in place.
            # Serialize *before* opening the file: opening with mode "w" truncates it,
            # so a failing serialization would otherwise destroy existing content.
            text = self._serialize_to_text(self.simplify)
            with open(self._path, mode="w") as fp:
                fp.write(text)
            self._modified = False  # reset modifications
        elif self._modified:
            # The storage was modified, but it cannot be written to the file. This
            # should not happen, but it's better to throw an explicit error
            raise AccessError("Cannot write to file")

    def close(self) -> None:
        """close the file and write the data to the file"""
        self.flush()
        super().close()

    def to_text(self, simplify: bool | None = None) -> str:
        """serialize the data and return it as a string

        Args:
            simplify (bool):
                Flag indicating whether the data is stored in a simplified form. If
                `None`, the object-level value is used.

        Returns:
            str: The serialized data
        """
        if simplify is None:
            simplify = self.simplify
        return self._serialize_to_text(simplify)

    @abstractmethod
    def _read_data_from_fp(self, fp: io.TextIOBase):
        """read data from an open file

        Args:
            fp (:class:`io.TextIOBase`): The opened text file

        Returns:
            The data read from the file. `None` is accepted by the caller and is
            interpreted as an empty storage.
        """

    @abstractmethod
    def _write_data_to_fp(self, fp: io.TextIOBase, data) -> None:
        """write data to an open file

        Args:
            fp (:class:`io.TextIOBase`): The opened text file (may be a StringIO)
            data: The data to write
        """

    def _write_attr(self, loc: Sequence[str], name: str, value: str) -> None:
        super()._write_attr(loc, name, value)
        self._modified = True

    def _read_array(
        self, loc: Sequence[str], *, copy: bool, index: int | None = None
    ) -> np.ndarray:
        """read an array (or one entry of a dynamic array) from the storage

        Args:
            loc (sequence of str): Location of the array item
            copy (bool): Whether a copy of the stored data is requested
            index (int, optional): Index of the entry of a dynamic array to read

        Returns:
            :class:`~numpy.ndarray`: The array, possibly viewed as a record array

        Raises:
            RuntimeError: if the location does not contain array data
        """
        item = self[loc]
        arr = item["data"] if index is None else item["data"][index]

        # minimal sanity check; entries of dynamic arrays with shape () are stored as
        # plain scalars by `_extend_dynamic_array`, so accept those when indexing
        if not (hasattr(arr, "__iter__") or (index is not None and np.isscalar(arr))):
            raise RuntimeError(f"No array at `/{'/'.join(loc)}`")

        dtype = decode_binary(item["dtype"])
        if dtype.names is not None:
            arr = unstructured_to_structured(np.asarray(arr), dtype=dtype, copy=copy)
        elif copy:
            arr = np.array(arr, dtype=dtype)
        else:
            # `np.array(..., copy=False)` raises on NumPy >= 2 whenever a copy is
            # unavoidable (e.g., for lists loaded from a text file)
            arr = np.asarray(arr, dtype=dtype)

        if item.get("record_array", False):
            arr = arr.view(np.recarray)
        return arr  # type: ignore

    def _write_array(self, loc: Sequence[str], arr: np.ndarray) -> None:
        """store a copy of an array, converting structured arrays to plain arrays

        Args:
            loc (sequence of str): Location where the array is written
            arr (:class:`~numpy.ndarray`): The array to store
        """
        parent, name = self._get_parent(loc, check_write=True)
        # capture properties of the input here since `arr` might be replaced below
        dtype = arr.dtype
        is_record_array = isinstance(arr, np.recarray)
        if dtype.names is not None:
            # structured array
            arr = structured_to_unstructured(arr)

        parent[name] = {
            "data": np.array(arr, copy=True),
            "dtype": encode_binary(dtype, binary=False),
        }
        if is_record_array:
            parent[name]["record_array"] = True
        self._modified = True

    def _create_dynamic_array(
        self,
        loc: Sequence[str],
        shape: tuple[int, ...],
        dtype: DTypeLike,
        *,
        record_array: bool = False,
    ) -> None:
        """create an empty dynamic array whose entries all have the given shape

        Raises:
            RuntimeError: if an item already exists at the location
        """
        parent, name = self._get_parent(loc, check_write=True)
        if name in parent:
            raise RuntimeError(f"Array `/{'/'.join(loc)}` already exists")
        parent[name] = {
            "data": [],
            "shape": shape,
            "dtype": encode_binary(np.dtype(dtype), binary=False),
        }
        if record_array:
            parent[name]["record_array"] = True
        self._modified = True

    def _extend_dynamic_array(self, loc: Sequence[str], arr: ArrayLike) -> None:
        """append one entry to a dynamic array

        Raises:
            TypeError: if the shape or dtype of `arr` does not match the stored ones
        """
        item = self[loc]

        # check data shape that is stored at this position
        data = np.asanyarray(arr)
        stored_shape = tuple(item["shape"])
        if stored_shape != data.shape:
            raise TypeError(f"Shape mismatch ({stored_shape} != {data.shape})")

        # convert the data to the correct format
        stored_dtype = decode_binary(item["dtype"])
        if not np.issubdtype(data.dtype, stored_dtype):
            raise TypeError(f"Dtype mismatch ({data.dtype} != {stored_dtype})")
        if data.dtype.names is not None:
            # structured array
            data = structured_to_unstructured(data)

        # append the data to the dynamic array; scalars are stored as plain values
        if data.ndim == 0:
            item["data"].append(data.item())
        else:
            item["data"].append(np.array(data, copy=True))
        self._modified = True

    def _read_object(self, loc: Sequence[str]) -> Any:
        return self.codec.decode(self[loc]["data"])

    def _write_object(self, loc: Sequence[str], obj: Any) -> None:
        parent, name = self._get_parent(loc, check_write=True)
        parent[name] = {"data": self.codec.encode(obj)}
        self._modified = True