"""
Infrastructure for managing classes with parameters.
One aim is to allow easy management of inheritance of parameters.
.. autosummary::
:nosignatures:
Parameter
DeprecatedParameter
HideParameter
Parameterized
get_all_parameters
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations
import copy
import logging
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, Container, Dict, Iterator, List, Optional, Union
import numpy as np
from ..utils import hybridmethod, import_class
class NoValueType:
    """marker type whose single instance signals that a parameter has no value"""

    def __repr__(self) -> str:
        # shown in place of the usual `<... object at 0x...>` representation
        return "NoValue"


# module-wide sentinel; code in this module tests for it by identity (`is NoValue`)
NoValue = NoValueType()
def auto_type(value):
    """convert value to float or int if reasonable

    Values that can be interpreted as numbers are converted to `int` when they
    represent an integer and to `float` otherwise. Values that cannot be interpreted
    as numbers are returned unchanged.

    Args:
        value:
            The value to convert, e.g., a string holding a number

    Returns:
        The value converted to `int` or `float`, or the unchanged `value` if it cannot
        be converted to a number.
    """
    try:
        float_val = float(value)
    except (TypeError, ValueError, OverflowError):
        # not a number, or a number too large to be represented as float (`10**400`)
        return value

    try:
        int_val = int(value)
    except (ValueError, OverflowError):
        # e.g., strings like "1.5" or non-finite values like `math.inf` or `math.nan`
        return float_val

    if int_val == float_val:
        return int_val  # an integer that can be represented exactly as a float

    # Integers beyond float precision (e.g. `2**53 + 1`) compare unequal to their
    # rounded float representation; keep them as exact integers instead of returning
    # a lossy float
    try:
        if float(int_val) == float_val:
            return int_val
    except OverflowError:
        pass  # `int_val` cannot be represented as a float at all
    return float_val
@dataclass
class Parameter:
    """class representing a single parameter

    Args:
        name (str):
            The name of the parameter
        default_value:
            The default value of the parameter
        cls:
            The type of the parameter, which is used for conversion. The conversion and
            parsing of values can be disabled by using the default class `object`.
        description (str):
            A string describing the impact of this parameter. This description appears
            in the parameter help.
        choices (container):
            A list or set of values that the parameter can take. Values (including the
            default value) that are not in this list will be rejected. Note that values
            are checked after they have been converted by `cls`, so specifying `cls` is
            particularly important to convert command line parameters from strings.
        required (bool):
            Whether the parameter is required
        hidden (bool):
            Whether the parameter is hidden in the description summary
        extra (dict):
            Extra arguments that are stored with the parameter
    """

    name: str
    default_value: Any = None
    cls: type | Callable = object
    description: str = ""
    choices: Container | None = None
    required: bool = False
    hidden: bool = False
    extra: dict[str, Any] = field(default_factory=dict)

    def _check_value(self, value) -> None:
        """checks whether the value is acceptable

        `None` is always accepted, as is any value when `choices` is not set.

        Raises:
            ValueError: if `value` is not contained in `choices`
        """
        if value is not None and self.choices is not None and value not in self.choices:
            raise ValueError(f"Default value `{value}` not in `{self.choices}`")

    def __post_init__(self):
        """check default values and cls

        Raises:
            TypeError: if `cls` cannot be applied to the default value
            ValueError: if `cls` rejects the default value or the converted default
                value is not in `choices`
        """
        if self.cls is not object and not any(
            self.default_value is v for v in {None, NoValue}
        ):
            # check whether the default value is of the correct type
            try:
                converted_value = self.cls(self.default_value)
            except TypeError as err:
                raise TypeError(
                    f"Parameter {self.name} of type {self.cls} has invalid default "
                    f"value: {self.default_value}"
                ) from err
            except ValueError as err:
                # e.g. `int("abc")`; keep the exception type but name the parameter
                raise ValueError(
                    f"Parameter {self.name} of type {self.cls} has invalid default "
                    f"value: {self.default_value}"
                ) from err
            self._check_value(converted_value)

            if isinstance(converted_value, np.ndarray):
                # numpy arrays are checked for each individual value
                valid_default = np.allclose(
                    converted_value, self.default_value, equal_nan=True
                )
            else:
                # other values are compared directly. Note that we also check identity
                # to capture the case where the value is `math.nan`, where the direct
                # comparison (nan == nan) would evaluate to False
                valid_default = (
                    converted_value is self.default_value
                    or converted_value == self.default_value
                )
            if not valid_default:
                logging.warning(
                    f"Default value `{self.name}` is not of type `{self.cls.__name__}`"
                )

    def __getstate__(self):
        """return the state for pickling, storing `cls` by its import path"""
        # replace the object class by its class path
        return {
            "name": str(self.name),
            "default_value": self.convert(),
            "cls": self.cls.__module__ + "." + self.cls.__name__,
            "description": self.description,
            "choices": self.choices,
            "required": self.required,
            "hidden": self.hidden,
            "extra": self.extra,
        }

    def __setstate__(self, state):
        """restore the state produced by :meth:`__getstate__`"""
        # restore the object from the class path
        state["cls"] = import_class(state["cls"])
        # restore the state
        self.__dict__.update(state)

    @property
    def short_description(self) -> str:
        """return only the first sentence of the description"""
        return self.description.split(". ", 1)[0]

    def convert(self, value=NoValue, *, strict: bool = True):
        """converts a `value` into the correct type for this parameter. If `value` is
        not given, the default value is converted.

        Note that this does not make a copy of the values, which could lead to
        unexpected effects where the default value is changed by an instance.

        Args:
            value:
                The value to convert
            strict (bool):
                Flag indicating whether conversion to the type indicated by `cls` is
                enforced. If `False`, the original value is returned when conversion
                fails.

        Returns:
            The converted value, which is of type `self.cls`

        Raises:
            ValueError: if the conversion fails in strict mode or the converted value
                is not in `choices`
        """
        if value is NoValue:
            value = self.default_value

        if value is NoValue or value is None:
            pass  # treat these values special
        elif self.cls is object:
            value = auto_type(value)
        else:
            try:
                value = self.cls(value)
            except (TypeError, ValueError) as err:
                if strict:
                    raise ValueError(
                        f"Could not convert {value!r} to {self.cls.__name__} for "
                        f"parameter '{self.name}'"
                    ) from err
                # else: just return the value unchanged

        self._check_value(value)
        return value

    def _argparser_add(self, parser):
        """add a command line option for this parameter to a parser

        Hidden parameters are skipped. Boolean parameters become flags, list-, tuple-
        and set-like parameters accept multiple values, and all other parameters take
        a single value. That value is passed through `cls` unless `cls` is `object` or
        :func:`auto_type`.

        Args:
            parser:
                The object receiving the option; it needs to provide `add_argument`
                and `add_mutually_exclusive_group` (presumably an
                :class:`argparse.ArgumentParser` or an argument group)
        """
        if self.hidden:
            return  # hidden parameters are not exposed on the command line

        if self.description:
            description = self.description
        else:
            description = f"Parameter `{self.name}`"

        arg_name = "--" + self.name
        kwargs = {
            "required": self.required,
            "choices": self.choices,
            "default": self.default_value,
            "help": description,
        }

        # `cls` may be any callable (e.g. `auto_type`), for which `issubclass` raises
        # a `TypeError`, so only classes are considered to be collections
        try:
            is_collection = issubclass(self.cls, (list, tuple, set))
        except TypeError:
            is_collection = False

        if self.cls is bool:
            # parameter is a boolean that we want to adjust
            if self.default_value is False:
                # allow enabling the parameter
                parser.add_argument(
                    arg_name, action="store_true", default=False, help=description
                )
            elif self.default_value is True:
                # allow disabling the parameter
                parser.add_argument(
                    f"--no-{self.name}",
                    dest=self.name,
                    action="store_false",
                    default=True,
                    help=description,
                )
            else:
                # no default value => allow setting it
                flag_parser = parser.add_mutually_exclusive_group(required=True)
                flag_parser.add_argument(
                    arg_name, dest=self.name, action="store_true", help=description
                )
                flag_parser.add_argument(
                    f"--no-{self.name}", dest=self.name, action="store_false"
                )
                # in python 3.9, we could use `argparse.BooleanOptionalAction`
        elif is_collection:
            parser.add_argument(arg_name, metavar="VALUE", nargs="*", **kwargs)
        elif self.cls is object or self.cls is auto_type:
            parser.add_argument(arg_name, metavar="VALUE", **kwargs)
        else:
            parser.add_argument(arg_name, type=self.cls, metavar="VALUE", **kwargs)
class DeprecatedParameter(Parameter):
    """a parameter that can still be used normally but is deprecated

    It adds no behavior to :class:`Parameter`; its distinct type lets
    :meth:`Parameterized.get_parameters` filter deprecated parameters.
    """
class HideParameter:
    """a helper class that allows hiding parameters of the parent classes

    This parameter will still appear in the :attr:`parameters` dictionary, but it will
    typically not be visible to the user, e.g., when calling :meth:`show_parameters`.
    """

    def __init__(self, name: str):
        """
        Args:
            name (str):
                The name of the parameter that should be hidden
        """
        self.name = name

    def _argparser_add(self, parser) -> None:
        """intentionally a no-op: hiding a parameter adds no command line option"""
        return None
# a single entry of a parameter list: either a parameter definition or a marker that
# hides a parameter inherited from a parent class
_ParameterEntry = Union[Parameter, HideParameter]
ParameterListType = List[_ParameterEntry]

# user-supplied parameter values: a dictionary mapping parameter names to values, or
# `None` to use only the defaults
ParameterInputType = Optional[Dict[str, Any]]
class Parameterized:
    """a mixin that manages the parameters of a class"""

    parameters_default: ParameterListType = []
    """list: parameters (with default values) of this subclass"""

    _parameters_default_full: ParameterListType = []
    """list: all parameters (including those of parent classes)"""

    _subclasses: dict[str, type[Parameterized]] = {}
    """dict: a dictionary of all classes inheriting from `Parameterized`"""

    def __init__(self, parameters: ParameterInputType = None, *, strict: bool = True):
        """initialize the parameters of the object

        Args:
            parameters (dict):
                A dictionary of parameters to change the defaults. The allowed
                parameters can be obtained from
                :meth:`~Parameterized.get_parameters` or displayed by calling
                :meth:`~Parameterized.show_parameters`.
            strict (bool):
                Flag indicating whether parameters are strictly interpreted. If `True`,
                only parameters listed in `parameters_default` can be set and their type
                will be enforced.
        """
        # set logger if this has not happened, yet
        if not hasattr(self, "_logger"):
            self._logger = logging.getLogger(self.__class__.__name__)

        # set parameters if they have not been initialized, yet
        if not hasattr(self, "parameters"):
            self.parameters = self._parse_parameters(
                parameters, include_deprecated=True, check_validity=strict
            )

    def __init_subclass__(cls, **kwargs) -> None:  # @NoSelf
        """register all subclasses to reconstruct them later

        This also merges the parameters of the subclass with those of its parents
        (stored in `_parameters_default_full`) and appends a summary of the parameters
        to the class docstring.
        """
        # normalize the parameters_default attribute to be a list of `Parameter`
        if hasattr(cls, "parameters_default") and isinstance(
            cls.parameters_default, dict
        ):
            # default parameters are given as a dictionary
            cls.parameters_default = [
                Parameter(*args) for args in cls.parameters_default.items()
            ]

        # combine parameters with those of the parent class
        parameters_default: dict[str, Parameter] = {}
        for p in cls._parameters_default_full + cls.parameters_default:
            if isinstance(p, HideParameter):
                if p.name in parameters_default:
                    # the entry is a copy (see below), so the parent is not affected
                    parameters_default[p.name].hidden = True
            else:
                parameters_default[p.name] = copy.copy(p)
        cls._parameters_default_full = list(parameters_default.values())
        # Note that `_parameters_default_full` also includes hidden parameters

        # append the list of parameters to the end of the docstring
        parameter_doc = list(
            cls._get_parameters_str(
                description=True,
                sort=True,
                short_description=True,
                template=" * `{name}`: {description} (default={value!r})",
                template_object=" * `{name}`: {description} (default={value!r})",
            )
        )
        if parameter_doc:
            extra_doc = "Parameters Dictionary:\n" + "\n".join(parameter_doc)
            if cls.__doc__:
                cls.__doc__ += "\n\n" + extra_doc
            else:
                cls.__doc__ = extra_doc

        # register this subclass
        super().__init_subclass__(**kwargs)
        if cls is not Parameterized:
            if cls.__name__ in cls._subclasses:
                warnings.warn(f"Redefining class `{cls.__name__}`")
            cls._subclasses[cls.__name__] = cls

    @classmethod
    def get_parameters(
        cls,
        include_hidden: bool = False,
        include_deprecated: bool = False,
        sort: bool = True,
    ) -> dict[str, Parameter]:
        """return a dictionary of parameters that the class supports

        Args:
            include_hidden (bool):
                Include hidden parameters
            include_deprecated (bool):
                Include deprecated parameters
            sort (bool):
                Return ordered dictionary with sorted keys

        Returns:
            dict: a dictionary mapping names to instances of :class:`Parameter`
        """
        # collect the parameters from the class hierarchy
        parameters: dict[str, Parameter] = {}
        for p in cls._parameters_default_full:
            if isinstance(p, HideParameter):
                # `__init_subclass__` already resolves `HideParameter` entries, so this
                # only matters if `_parameters_default_full` was set differently. Do
                # not mutate the shared `Parameter` object and ignore unknown names.
                if p.name in parameters:
                    if include_hidden:
                        hidden_param = copy.copy(parameters[p.name])
                        hidden_param.hidden = True
                        parameters[p.name] = hidden_param
                    else:
                        del parameters[p.name]
            else:
                parameters[p.name] = p

        # filter parameters based on hidden and deprecated flags
        def show(p):
            """helper function to decide whether a parameter will be shown"""
            # show based on hidden flag?
            show1 = include_hidden or not p.hidden
            # show based on deprecated flag?
            show2 = include_deprecated or not isinstance(p, DeprecatedParameter)
            return show1 and show2

        # filter parameters based on `show`
        result = {
            name: parameter for name, parameter in parameters.items() if show(parameter)
        }

        if sort:
            result = dict(sorted(result.items()))
        return result

    @classmethod
    def _parse_parameters(
        cls,
        parameters: ParameterInputType = None,
        *,
        check_validity: bool = True,
        allow_hidden: bool = True,
        include_deprecated: bool = False,
    ) -> dict[str, Any]:
        """parse parameters from a given dictionary

        Args:
            parameters (dict):
                A dictionary of parameters that will be parsed.
            check_validity (bool):
                Determines whether a `ValueError` is raised if there are keys in
                parameters that are not in the defaults. If `False`, additional items
                are simply stored in `self.parameters`. This flag is also passed as
                `strict` to :meth:`Parameter.convert`.
            allow_hidden (bool):
                Allow setting hidden parameters
            include_deprecated (bool):
                Include deprecated parameters

        Returns:
            dict: The parsed parameters

        Raises:
            ValueError: if a required parameter is missing or, when `check_validity`
                is set, if unknown parameters are given or values cannot be converted
        """
        if parameters is None:
            parameters = {}
        else:
            parameters = parameters.copy()  # do not modify the original

        # obtain all possible parameters
        param_objs = cls.get_parameters(
            include_hidden=allow_hidden, include_deprecated=include_deprecated
        )

        # initialize parameters with default ones from all parent classes
        result: dict[str, Any] = {}
        for name, param_obj in param_objs.items():
            if not allow_hidden and param_obj.hidden:
                continue  # skip hidden parameters
            if param_obj.required and name not in parameters:
                raise ValueError(f"Require parameter `{name}`")
            # take value from parameters or set default value
            value = parameters.pop(name, NoValue)
            # convert parameter to correct type
            result[name] = param_obj.convert(value, strict=check_validity)

        # update parameters with the supplied ones
        if check_validity and parameters:
            raise ValueError(
                f"Parameters `{sorted(parameters.keys())}` were provided for an "
                f"instance but are not defined for the class `{cls.__name__}`"
            )
        else:
            result.update(parameters)  # add remaining parameters

        return result

    @classmethod
    def get_parameter_default(cls, name):  # @NoSelf
        """return the default value for the parameter with `name`

        Args:
            name (str): The parameter name

        Raises:
            KeyError: if no parameter with this name is defined
        """
        for p in cls._parameters_default_full:
            if isinstance(p, Parameter) and p.name == name:
                return p.default_value

        raise KeyError(f"Parameter `{name}` is not defined")

    @classmethod
    def _get_parameters_str(
        cls,
        *,
        description: bool = False,
        sort: bool = False,
        show_hidden: bool = False,
        show_deprecated: bool = False,
        short_description: bool = False,
        parameter_values: ParameterInputType = None,
        template: str | None = None,
        template_object: str | None = None,
    ) -> Iterator[str]:
        """private method showing all parameters in human readable format

        Args:
            description (bool):
                Flag determining whether the parameter description is shown.
            sort (bool):
                Flag determining whether the parameters are sorted
            show_hidden (bool):
                Flag determining whether hidden parameters are shown
            show_deprecated (bool):
                Flag determining whether deprecated parameters are shown
            short_description (bool):
                Whether to show a shortened version of the description
            parameter_values (dict):
                A dictionary with values to show. Parameters not in this dictionary are
                shown with their default value.
            template (str):
                Format string used for parameters with a type other than `object`. It
                may use the fields `name`, `type`, `value`, and `description`.
            template_object (str):
                Format string used for parameters whose type is `object`, supporting
                the same fields as `template`.

        All flags default to `False`.

        Yields:
            str: one formatted line per parameter
        """
        # set the templates for displaying the data
        if template is None:
            template = "{name}: {type} = {value!r}"
            if description:
                template += " ({description})"
        if template_object is None:
            template_object = "{name} = {value!r}"
            if description:
                template_object += " ({description})"

        # iterate over all parameters
        params = cls.get_parameters(
            include_hidden=show_hidden, include_deprecated=show_deprecated, sort=sort
        )
        for param in params.values():
            # initialize the data to show
            data = {
                "name": param.name,
                "type": param.cls.__name__,
                "description": (
                    param.short_description if short_description else param.description
                ),
            }

            # determine the value to show
            if parameter_values is None:
                data["value"] = param.default_value
            else:
                # parameters missing from `parameter_values` are shown with defaults
                data["value"] = parameter_values.get(param.name, param.default_value)

            # print the data to stdout
            if param.cls is object:
                yield template_object.format(**data)
            else:
                yield template.format(**data)

    @hybridmethod
    def show_parameters(  # @NoSelf
        cls,
        description: bool = False,
        sort: bool = False,
        show_hidden: bool = False,
        show_deprecated: bool = False,
    ) -> None:
        """show all parameters in human readable format

        Args:
            description (bool):
                Flag determining whether the parameter description is shown.
            sort (bool):
                Flag determining whether the parameters are sorted
            show_hidden (bool):
                Flag determining whether hidden parameters are shown
            show_deprecated (bool):
                Flag determining whether deprecated parameters are shown

        All flags default to `False`.
        """
        for line in cls._get_parameters_str(
            description=description,
            sort=sort,
            show_hidden=show_hidden,
            show_deprecated=show_deprecated,
        ):
            print(line)

    @show_parameters.instancemethod  # type: ignore
    def show_parameters(
        self,
        description: bool = False,
        sort: bool = False,
        show_hidden: bool = False,
        show_deprecated: bool = False,
        default_value: bool = False,
    ) -> None:
        """show all parameters in human readable format

        Args:
            description (bool):
                Flag determining whether the parameter description is shown.
            sort (bool):
                Flag determining whether the parameters are sorted
            show_hidden (bool):
                Flag determining whether hidden parameters are shown
            show_deprecated (bool):
                Flag determining whether deprecated parameters are shown
            default_value (bool):
                Flag determining whether the default values or the current values are
                shown

        All flags default to `False`.
        """
        for line in self._get_parameters_str(
            description=description,
            sort=sort,
            show_hidden=show_hidden,
            show_deprecated=show_deprecated,
            parameter_values=None if default_value else self.parameters,
        ):
            print(line)
def get_all_parameters(data: str = "name") -> dict[str, Any]:
    """get a dictionary with all parameters of all registered classes

    Args:
        data (str):
            Determines what data is returned. Possible values are 'name', 'value', or
            'description', to return the respective information about the parameters.

    Returns:
        dict: A dictionary mapping the name of each registered subclass of
        :class:`Parameterized` either to the set of its parameter names
        (`data == 'name'`) or to a dictionary mapping parameter names to default values
        (`data == 'value'`) or to descriptions (`data == 'description'`).

    Raises:
        ValueError: if `data` is not one of the supported values
    """
    # validate `data` upfront so invalid values are reported even when no classes
    # have been registered, yet
    if data not in {"name", "value", "description"}:
        raise ValueError(f"Cannot interpret data `{data}`")

    result: dict[str, Any] = {}
    for cls_name, cls in Parameterized._subclasses.items():
        parameters = cls.get_parameters()
        if data == "name":
            result[cls_name] = set(parameters.keys())
        elif data == "value":
            result[cls_name] = {k: v.default_value for k, v in parameters.items()}
        else:  # data == "description"
            result[cls_name] = {k: v.description for k, v in parameters.items()}
    return result