Source code for modelrunner.run.compatibility.version1

"""
Contains code necessary for loading results from format version 1

.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

import json
import warnings
from abc import ABCMeta
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar, Union

import numpy as np

from ...model import ModelBase
from ..results import Result
from .triage import guess_format, normalize_zarr_store

if TYPE_CHECKING:
    import zarr

    zarrElement = Union[zarr.Group, zarr.Array]


[docs]class NoData: """helper class that marks data omission""" ...
TState = TypeVar("TState", bound="StateBase")
[docs]class StateBase(metaclass=ABCMeta): """Base class for specifying the state of a simulation A state contains values of all degrees of freedom of a physical system (called the `data`) and some additional information (called `attributes`). The `data` is mutable and often a numpy array or a collection of numpy arrays. Conversely, the `attributes` are stroed in a dictionary with immutable values. To allow flexible storage, we define the properties `_state_data` and `_state_attributes`, which by default return `attributes` and `data` directly, but may be overwritten to process the data before storage (e.g., by additional serialization). """ _state_format_version = 2 """int: number indicating the version of the file format""" _state_classes: dict[str, type[StateBase]] = {} """dict: class-level list of all subclasses of StateBase""" def _state_init(self, attributes: dict[str, Any], data=NoData) -> None: """initialize the state with attributes and (optionally) data Args: attributes (dict): Additional (unserialized) attributes data: The data of the degerees of freedom of the physical system """ if data is not NoData: self._state_data = data if attributes: self._state_attributes = attributes
[docs] @classmethod def from_data(cls: type[TState], attributes: dict[str, Any], data=NoData) -> TState: """create instance of any state class from attributes and data Args: attributes (dict): Additional (unserialized) attributes data: The data of the degerees of freedom of the physical system Returns: The object containing the given attributes and data """ # copy attributes since they are modified in this function attributes = attributes.copy() cls_name = attributes["__class__"] if cls_name != cls.__name__: raise RuntimeError(f"Expected `{cls.__name__}` but data had `{cls_name}`") format_version = attributes.pop("__version__", 0) if format_version != 1: warnings.warn(f"File format version mismatch ({format_version} != 1)") # create a new object without calling __init__, which might be overwriten by # the subclass and not follow our interface obj = cls.__new__(cls) obj._state_init(attributes, data) return obj
@classmethod def _from_simple_objects(cls, content: dict[str, Any]) -> StateBase: """create state from text data Args: content: The loaded data """ cls_name = content["attributes"]["__class__"] if cls_name == "ArrayState": return ArrayState.from_data(content["attributes"], content["data"]) elif cls_name == "ArrayCollectionState": col_data = tuple( np.array(content["data"][label]) for label in content["attributes"]["labels"] ) return ArrayCollectionState.from_data(content["attributes"], col_data) elif cls_name == "DictState": dict_data: dict[str, StateBase] = {} for label, substate in content["data"].items(): dict_data[label] = cls._from_simple_objects(substate) return DictState.from_data(content["attributes"], dict_data) elif cls_name == "ObjectState": return ObjectState.from_data(content["attributes"], content["data"]) else: raise TypeError(f"Do not know how to load `{cls_name}`") @classmethod def _from_zarr(cls, zarr_element: zarrElement, *, index=...) -> StateBase: """create instance of correct subclass from data stored in zarr""" # determine the class that knows how to read this data cls_name = zarr_element.attrs["__class__"] attributes = zarr_element.attrs.asdict() # create object form zarr data if cls_name == "ArrayState": data = zarr_element[index] return ArrayState.from_data(attributes, data) elif cls_name == "ArrayCollectionState": data = data = tuple( zarr_element[label][index] for label in zarr_element.attrs["labels"] ) return ArrayCollectionState.from_data(attributes, data) elif cls_name == "DictState": data = { label: cls._from_zarr(zarr_element[label], index=index) for label in zarr_element.attrs["__keys__"] } return DictState.from_data(attributes, data) elif cls_name == "ObjectState": if zarr_element.shape == () and index is ...: data = zarr_element[index].item() else: data = zarr_element[index] return ObjectState.from_data(attributes, data) else: raise TypeError(f"Do not know how to load `{cls_name}`")
[docs]class ArrayState(StateBase): ...
[docs]class ArrayCollectionState(StateBase): ...
[docs]class DictState(StateBase): ...
[docs]class ObjectState(StateBase): ...
def _Result_from_simple_objects( content: dict[str, Any], model: ModelBase | None = None ) -> Result: """read result from simple object (like loaded from a JSON file) using version 1 Args: content (dict): Data from which the result is restored model (:class:`ModelBase`): Model associated with the result """ format_version = content.pop("__version__", None) if format_version != 1: raise RuntimeError(f"Cannot read format version {format_version}") return Result.from_data( model_data=content.get("model", {}), result=StateBase._from_simple_objects(content["state"]), info=content.get("info"), ) def _Result_from_zarr( zarr_element: zarrElement, *, index=..., model: ModelBase | None = None ) -> Result: """create result from data stored in zarr""" attributes = {key: json.loads(value) for key, value in zarr_element.attrs.items()} # extract version information from attributes format_version = attributes.pop("__version__", None) if format_version != 1: raise RuntimeError(f"Cannot read format version {format_version}") info = attributes.pop("__info__", {}) # load additional info # the remaining attributes correspond to the model model_data = attributes # load state state = StateBase._from_zarr(zarr_element["state"]) return Result.from_data(model_data=model_data, result=state, model=model, info=info)
[docs]def result_from_file_v1(store: Path, *, label: str = "data", **kwargs) -> Result: """load object from a file using format version 1 Args: store (Path): Path of file to read fmt (str): Explicit file format. Determined from `store` if omitted. label (str): Name of the node in which the data was stored. This applies to some hierarchical storage formats. """ fmt = guess_format(store) if fmt == "json": with open(store) as fp: content = json.load(fp) return _Result_from_simple_objects(content, **kwargs) elif fmt == "yaml": import yaml with open(store) as fp: content = yaml.safe_load(fp) return _Result_from_simple_objects(content, **kwargs) elif fmt == "zarr": import zarr # @Reimport zarr_store = normalize_zarr_store(store, mode="r") root = zarr.open_group(zarr_store, mode="r") return _Result_from_zarr(root[label], **kwargs) else: raise NotImplementedError(f"Format `{fmt}` not implemented")