Source code for neuralib.persistence.persistence

"""
Persistence Class
=================

:author:
    Ta-Shun Su

Define a persistence class
---------------------------

Import

>>> import numpy as np
>>> from neuralib import persistence

Define class

>>> @persistence.persistence_class
... class Example:
...     # key. use filename keyword to display key in filename.
...     use_animal: str = persistence.field(validator=True, filename=True)
...     use_session: str = persistence.field(validator=True)
...     use_date: str = persistence.field(validator=True, filename=True)
...     # data
...     channels: list[int]
...     data: np.ndarray

Load/Save

>>> example = Example(use_animal='A00', use_session='', use_date='1234')
>>> persistence.save(example, 'example.pkl')
>>> example_2 = persistence.load(Example, 'example.pkl') # example and example_2 should have identical content


Cooperate with PersistenceOptions
----------------------------------

>>> from neuralib.persistence.cli_persistence import PersistenceOptions
>>> class ExampleHandle(PersistenceOptions[Example]):
...     def empty_cache(self) -> Example:
...         return Example(use_animal='A00', use_session='', use_date='1234') # with attribute initialization
...     def compute_cache(self, result: Example) -> Example:
...         result.channels = [0, 1, 2]
...         result.data = np.array(result.channels)
...         return result

Dynamic generated methods for persistence class
------------------------------------------------

1. `__init__({foreach persistence.field})`

    >>> class Example:
    ...     ... # as same as above
    ...     def __init__(self, use_animal:str,  use_session: str, use_date: str): # auto generated
    ...         ...

2. `__str__` return filename

    >>> class Example:
    ...     ... # as same as above
    ...     def __str__(self): # auto generated
    ...         return filename(self)

3. `__repr__`

    >>> class Example:
    ...     ... # as same as above
    ...     def __repr__(self): # auto generated
    ...         return 'Example{' + f'use_animal={self.use_animal}, use_session={self.use_session}, use_date={self.use_date}' + "}"

4. `_replace`, generated when the persistence class defines an empty `_replace` method. It works like `NamedTuple._replace`.

    >>> class Example:
    ...     ... # as same as above
    ...     def _replace(self, **kwargs): pass # empty method
    ...     def _replace(self, *,  # replaced by generated
    ...                  use_animal=missing,
    ...                  use_session=missing,
    ...                  use_date=missing,
    ...                  channels=missing,
    ...                  data=missing) -> Example:
    ...         ...

Auto increment field
---------------------

For some reason you may want to save persistence results that come from the same data source but have
different contents, which are usually generated by a random or shuffle process. To save them all
separately, you need a field value that keeps track of which n-th persistence result it is.
:func:`autoinc_field` is provided to help in this case.


>>> @persistence_class
... class Result:
...     a: str = field(validator=True, filename=True)
...     b: int = autoinc_field()
...     c: str
...     def __init__(self, a: str, b: int = None): # auto generated signature
...     def _replace(self, *, a: str, c:str): # auto generated signature

There are some rules when a persistence class has an autoinc_field.

1. only one autoinc_field is allowed in a persistence class.
2. the field type must be int.
3. an error is raised when loading a result whose autoinc field is not resolved.
4. the autoinc field is resolved automatically when saving. Its value is max(found) + 1

Pickle format
--------------

A persistence class is transformed into a dict by `as_dict`, and that dict is the root object saved into
a pickle file.

**IMPORTANT**

If a persistence class has a custom `__init__` function whose signature differs from the auto generated one,
you need to define a classmethod `from_dict` for creating that persistence class.

>>> @persistence.persistence_class
... class Example:
...     a: int = persistence.field(validator=True, filename=True)
...     b: int = persistence.field(validator=True, filename=True)
...     c: int
...     def __init__(self): ... # custom __init__
...     @classmethod
...     def from_dict(cls, data:dict[str, Any]) -> 'Example':
...         # data = {'a', 'b', 'c'}


"""

import abc
import inspect
import sys
from collections.abc import Callable, Iterator
from pathlib import Path
from typing import Any, Generic, Type, TypeVar, Union, get_type_hints

from neuralib.util.func import PARA_TYPE, create_fn

# Names exported by ``from ... import *``.
__all__ = [
    'field',
    'autoinc_field',
    'persistence_class',
    'ensure_persistence_class',
    'as_dict',
    'from_dict',
    'load',
    'save',
    'filename',
    'AutoIncFieldNotResolvedError',
    'auto_generated_content',
    'PersistenceHandler',
    'PickleHandler',
    'GzipHandler',
]

# Type variable standing for the instance type of a persistence class.
T = TypeVar('T')
# Path-like argument accepted by the module level load/save helpers.
P = Union[str, Path]

# Sentinel for "value not given": it is the default of every generated ``_replace``
# keyword, and a filename field whose value is ``missing`` is rendered as the ``*`` wildcard.
missing = inspect.Parameter.empty

# Signature of a custom field validator; it is called with the two field values to compare.
VALIDATOR = Callable[[Any, Any], bool]


def field(validator: bool | VALIDATOR = False,
          filename_prefix: str = '',
          filename: bool | Callable[..., str] = False) -> Any:
    """Declare an exported field of a persistence class.

    Such fields become ``__init__`` arguments and are used as keys to find the
    corresponding persistence file.

    :param validator: validate this field. ``True`` compares with ``__eq__``; a callable is
        used as a customized validator.
    :param filename_prefix: prefix word put before the value in the filename.
    :param filename: display this value in the filename. Can be a callable that returns the string.
    :return: the field descriptor (typed as ``Any`` so it can be assigned to any annotated attribute).
    """
    return PersistentField(
        validator=validator,
        filename_prefix=filename_prefix,
        filename=filename,
        init=True,
    )
def autoinc_field(filename_prefix: str = '') -> Any:
    """Declare an auto increment field.

    It can be used to save and distinguish persistence results that come from the same
    source data but different random/shuffle processes.

    :param filename_prefix: prefix word put before the value in the filename.
    :return: the field descriptor; it is validated, shown in the filename and an ``__init__`` argument.
    """
    options = dict(validator=True, filename=True, init=True, autoinc=True)
    # noinspection PyTypeChecker
    return PersistentField(filename_prefix=filename_prefix, **options)
class PersistentField(Generic[T]):
    """Exported field of a persistence class.

    It stores the per-field options declared through :func:`field` or :func:`autoinc_field`.
    """
    __slots__ = ('field_name', 'field_type', 'validator', 'filename_prefix', 'filename', 'init', 'autoinc', 'optional')

    def __init__(self,
                 validator: bool | VALIDATOR = False,
                 filename_prefix: str = '',
                 filename: bool | Callable[..., str] = False,
                 init=True,
                 autoinc=False,
                 optional=False):
        """
        :param validator: validate this field. use __eq__ by default. Can be a callable as a customized validator.
        :param filename_prefix: prefix word of *filename*
        :param filename: display this value on filename. Can be a callable that return the string.
        :param init: put this field into class __init__
        :param autoinc: is this field an auto increment field
        :param optional: does this field has default value
        """
        # Is this field part of the identity check? (False, True, or a custom callable)
        self.validator = validator
        # The filename prefix word of this field value.
        self.filename_prefix = filename_prefix
        # Does this field show in filename? (False, True, or a callable producing the text)
        self.filename = filename
        # Is this field an __init__ argument?
        self.init = init
        # Is this field the auto increment field?
        self.autoinc = autoinc
        # Does this field have a default value?
        self.optional = optional

    def __set_name__(self, owner: type, name: str):
        """Record the attribute name and its annotated type when bound in a class body.

        :param owner: the class that declares this field.
        :param name: attribute name of this field.
        :raise RuntimeError: an autoinc field is not annotated as ``int``.
        """
        self.field_name = name
        field_type = get_type_hints(owner).get(name, Any)
        if self.autoinc and field_type is not int:
            raise RuntimeError(f'type of autoinc field {name} should be int, but {field_type}')
        # Always record the resolved type, so an autoinc field never leaves the slot unset.
        self.field_type = field_type

    def validate(self, v: T, u: T) -> bool:
        """Validate these two values against each other.

        :param v: reference value
        :param u: tested value
        :return: True if validation is disabled or the values match.
        :raise RuntimeError: *validator* is neither a bool nor a callable.
        """
        validator = self.validator
        if validator is False:
            return True
        if validator is True:
            return v == u
        if callable(validator):
            return validator(v, u)

        name = getattr(self, 'field_name', '?')
        raise RuntimeError(f'field {name} validator should be bool or callable, but {validator!r}')
class AutoIncFieldNotResolvedError(RuntimeError):
    """Raised when a persistence's autoinc field is not resolved, so the requested
    operation cannot be done.

    The offending object is kept in ``instance`` and the field name in ``field``.
    """

    def __init__(self, instance, field: str | PersistentField, message: str | None = None):
        """
        :param instance: the persistence instance (or type) whose autoinc field is unresolved.
        :param field: the autoinc field, or its name.
        :param message: error message. A default one is built when omitted.
        """
        field_name = field.field_name if isinstance(field, PersistentField) else field
        if message is None:
            message = f'{type(instance).__name__} autoinc field {field_name} is not resolved'

        super().__init__(message)
        self.instance = instance
        self.field = field_name
[docs] def persistence_class(cls: type | None = None, /, *, name: str | None = None, filename_field_splitter='-'): """A class decorator. Decorated class ... :param cls: Persistence class. :param name: class name as filename. :param filename_field_splitter: the field splitter on filename. :return: """ def decorator(cls: type): cls._ast_persistence_cls_info_ = pc = PersistentClass(cls, name, filename_field_splitter) prev_autoinc_field = None for attr_name, attr_type in cls.__annotations__.items(): if isinstance((attr_value := getattr(cls, attr_name, None)), PersistentField): delattr(cls, attr_name) f = attr_value else: f = PersistentField(init=False) f.field_name = attr_name f.field_type = attr_type f.optional = attr_value is not None if f.autoinc: if prev_autoinc_field is None: prev_autoinc_field = f else: raise RuntimeError('duplicated auto_inc_field') pc.fields.append(f) if cls.__init__ == object.__init__: cls.__init__ = _persistence_class_init(pc) if cls.__str__ == object.__str__: cls.__str__ = _persistence_class_str(pc) if cls.__repr__ == object.__repr__: cls.__repr__ = _persistence_class_repr(pc) if hasattr(cls, '_replace'): cls._replace = _persistence_class_replace(pc) if sys.version_info >= (3, 10): if not hasattr(cls, '__match_args__'): cls.__match_args__ = tuple([it.field_name for it in pc.fields]) return cls if cls is None: return decorator else: return decorator(cls)
class PersistentClass(Generic[T]):
    """Persistence information class.

    It keeps the decorated class, the name used in filenames, the filename field
    splitter and the ordered list of fields.
    """
    __slots__ = ('persistence_cls', 'cls_name', 'filename_field_splitter', 'fields')

    def __init__(self, persistence_cls: type, persistence_name: str | None = None, filename_field_splitter='-'):
        """
        :param persistence_cls: the decorated class.
        :param persistence_name: name used in filenames. Default is the class name.
        :param filename_field_splitter: the field splitter on filename.
        """
        self.persistence_cls = persistence_cls
        self.cls_name = persistence_name if persistence_name is not None else persistence_cls.__name__
        self.filename_field_splitter = filename_field_splitter
        self.fields: list[PersistentField] = []

    def fields_name(self) -> list[str]:
        """:return: field names in declaration order."""
        return [it.field_name for it in self.fields]

    def get_field(self, name: str) -> PersistentField | None:
        """:return: the field named *name*, or None if there is no such field."""
        return next((f for f in self.fields if f.field_name == name), None)

    def autoinc_field(self) -> PersistentField | None:
        """:return: the auto increment field, or None if the class has none."""
        return next((f for f in self.fields if f.autoinc), None)

    def is_autoinc_field_resolved(self, result: T | None, **kwargs) -> bool:
        """Is the autoinc field resolved?

        :param result: persistence instance.
        :param kwargs: overwrote keywords.
        :return: True if there is no autoinc field, its value on *result* is not None,
            or it is given in *kwargs*.
        """
        af = self.autoinc_field()
        if af is None:
            return True
        return getattr(result, af.field_name, None) is not None or af.field_name in kwargs

    def validate(self, v: T, u: T | dict[str, Any]) -> bool:
        """Validate that data *u* is as same as data *v*, field by field.

        :param v: reference data
        :param u: tested data, either an instance or a dictionary of field values
        :return: True when every field passes its validator.
        :raise FileNotFoundError: a field fails validation.
        """
        u_is_dict = isinstance(u, dict)
        for f in self.fields:
            fv = getattr(v, f.field_name, None)
            fu = u.get(f.field_name, None) if u_is_dict else getattr(u, f.field_name, None)

            if not f.validate(fv, fu):
                raise FileNotFoundError(f'{self.cls_name}.{f.field_name} validate fail: {fv} != {fu}')

        return True

    def filename(self, data: T | None, **kwargs) -> str:
        """Build filename for persistence instance.

        :param data: persistence instance.
        :param kwargs: overwrote keywords. The value ``missing`` makes the field a ``*`` wildcard.
        :return: filename.
        :raise RuntimeError: filename field missing.
        :raise TypeError: wrong field.filename type.
        """
        # legacy name
        ret = ['cache', self.cls_name]

        for f in self.fields:
            if not f.filename:
                continue

            if f.field_name in kwargs:
                field_value = kwargs[f.field_name]
            elif data is None or not hasattr(data, f.field_name):
                raise RuntimeError(f'field {f.field_name} is required for filename')
            else:
                field_value = getattr(data, f.field_name)

            if field_value is missing:
                s = '*'
            elif f.filename is True:
                s = str(field_value)
            elif callable(f.filename):
                s = f.filename(field_value)
                if s is None:
                    # the callable asked to leave this field out of the filename
                    continue
                s = str(s)
            else:
                raise TypeError(
                    f'field {f.field_name} filename option should be bool or callable, '
                    f'but {type(f.filename).__name__}'
                )

            ret.append(f.filename_prefix + s)

        return self.filename_field_splitter.join(ret)


def _persistence_class_init(pc: PersistentClass):
    """generate an __init__ function for persistent class.

    Every init field becomes a parameter assigned to the attribute of the same name;
    an autoinc field gets the default ``None``.
    """
    init_fields: list[PARA_TYPE] = ['self']
    for f in pc.fields:
        if f.init:
            init_fields.append((f.field_name, None, 'None') if f.autoinc else f.field_name)

    code = []
    for name in init_fields[1:]:
        if isinstance(name, tuple):
            # (name, type, default) parameter form
            name = name[0]
        code.append(f'self.{name} = {name}')

    return create_fn('__init__', init_fields, '\n'.join(code), locals=dict(missing=missing))


def _persistence_class_str(pc: PersistentClass):
    """generate a __str__ function for persistent class, which returns its filename."""
    return create_fn('__str__', (['self'], str), 'return pc.filename(self)', locals={'pc': pc})


def _persistence_class_repr(pc: PersistentClass):
    """generate a __repr__ function for persistent class, which lists the init fields."""
    init_fields = [f.field_name for f in pc.fields if f.init]

    # The pieces are adjacent string literals in the generated source.
    code = [f'return ("{pc.cls_name}' + '{"']
    for i, name in enumerate(init_fields):
        comma = '", " ' if i > 0 else ''
        code.append(f'{comma}f"{name}={{self.{name}}}"')
    code.append('"}")')

    return create_fn('__repr__', (['self'], str), ' '.join(code))


def _persistence_class_replace(pc: PersistentClass):
    """generate a _replace function for persistent class.

    The generated function builds a new instance from the init fields, then copies the
    data fields, taking each value from the keyword when given and from ``self`` otherwise.
    """
    init_fields = [f.field_name for f in pc.fields if f.init]
    data_fields = [f.field_name for f in pc.fields if not f.init and not f.autoinc]

    # NOTE(review): cls_name is emitted as an identifier in the generated code;
    # a custom name that is not a valid identifier presumably breaks here -- confirm.
    code = [f'ret = {pc.cls_name}(']
    for name in init_fields:
        code.append(f'{name} if {name} is not missing else self.{name},')
    code.append(')')
    for name in data_fields:
        code.append(f'ret.{name} = {name} if {name} is not missing else self.{name}')
    code.append('return ret')

    replace_fields: list[PARA_TYPE] = ['self', '*']
    for name in init_fields + data_fields:
        replace_fields.append((name, None, 'missing'))

    return create_fn('_replace', (replace_fields, pc.cls_name), '\n'.join(code),
                     locals={pc.cls_name: pc.persistence_cls, 'missing': missing})
def auto_generated_content(**kwargs):
    """Mark a function whose function body is auto generated.

    :param kwargs: blackhole, every keyword is ignored.
    :raise RuntimeError: always.
    """
    message = 'It is auto generated content'
    raise RuntimeError(message)
def ensure_persistence_class(data: T | type[T]) -> PersistentClass[T]:
    """Ensure **data** is a persistence class and return its information.

    :param data: instance or type
    :return: persistence info
    :raise RuntimeError: not a persistence class
    """
    data_cls = data if isinstance(data, type) else type(data)

    try:
        return data_cls._ast_persistence_cls_info_
    except AttributeError as e:
        raise RuntimeError(f'not a persistence_class : {data_cls.__name__}') from e
def as_dict(data: object) -> dict[str, Any]:
    """Transform persistence *data* into a dictionary keyed by field name.

    Fields whose attribute is not set on *data* are left out.

    :param data: persistence instance
    :return: dict
    :raise TypeError: *data* is None
    :raise RuntimeError: *data* is not a persistence class instance
    """
    if data is None:
        raise TypeError('data is None')

    info = ensure_persistence_class(data)

    ret = {}
    for f in info.fields:
        try:
            value = getattr(data, f.field_name)
        except AttributeError:
            # attribute never assigned: skip it
            continue
        ret[f.field_name] = value
    return ret
def from_dict(data_cls: type[T], d: dict[str, Any]) -> T:
    """Transform a dictionary into a *data_cls* instance.

    The init fields are passed to ``data_cls(...)``. When that call raises TypeError and
    *data_cls* has a ``from_dict`` attribute, that factory is used instead.

    :param data_cls: persistence class
    :param d: field values keyed by field name
    :return: persistence instance
    :raise KeyError: an init field is missing in *d*
    :raise TypeError: initialization failed and *data_cls* has no ``from_dict``
    """
    info = ensure_persistence_class(data_cls)

    def get_or_raise(key):
        try:
            value = d[key]
        except KeyError:
            pass
        else:
            return value
        # raised outside the handler, so the original KeyError is not chained
        raise KeyError(f'missing required field : {key}')

    init = {}
    for f in info.fields:
        if f.init:
            init[f.field_name] = get_or_raise(f.field_name)

    try:
        ret = data_cls(**init)
    except TypeError:
        if not hasattr(data_cls, 'from_dict'):
            raise
        return _from_dict_factory(data_cls, info, d)

    return _from_dict_builtin(ret, info, d)
def _from_dict_builtin(ret: T, info: PersistentClass, d: dict[str, Any]) -> T:
    """After initialized *ret*, set attributes for the remaining (non-init) fields.

    :param ret: initialized persistence instance
    :param info: persistence information of *ret*
    :param d: field values keyed by field name; absent fields are skipped
    :return: *ret*
    """
    for f in info.fields:
        if not f.init and f.field_name in d:
            setattr(ret, f.field_name, d[f.field_name])

    return ret


def _from_dict_factory(data_cls: type[T], info: PersistentClass, d: dict[str, Any]) -> T:
    """When *data_cls* initialize failed, try invoking its `from_dict` function.

    :param data_cls: persistence class providing ``from_dict``
    :param info: persistence information of *data_cls*
    :param d: field values keyed by field name; only declared fields are passed on
    :return: persistence instance
    :raise TypeError: ``from_dict`` does not return a *data_cls* instance
    """
    kwargs = {f.field_name: d[f.field_name] for f in info.fields if f.field_name in d}

    ret = data_cls.from_dict(kwargs)
    if not isinstance(ret, data_cls):
        raise TypeError(
            f'{data_cls.__name__}.from_dict should return {data_cls.__name__}, '
            f'but {type(ret).__name__}'
        )

    return ret
def save(data: object, path: P) -> None:
    """Save persistence **data** as a pickle file at **path**.

    :param data: persistence instance
    :param path: filepath
    """
    target = Path(path)
    handler = PickleHandler(type(data), target.parent)
    handler.save_persistence(data, target)
def load(data_cls: type[T], path: P) -> T:
    """Load data as **data_cls** from the pickle file **path**.

    :param data_cls: data class type
    :param path: filepath
    :return: persistence instance
    """
    source = Path(path)
    handler = PickleHandler(data_cls, source.parent)
    return handler.load_persistence(source)
def load_by(data_cls: type[T], path: P, **kwargs) -> T:
    """Load **data_cls** from directory **path** with fields **kwargs**.

    :param data_cls: data class type
    :param path: directory that contains the persistence files
    :param kwargs: data fields used to build the filename
    :return: persistence instance
    :raise RuntimeError: a filename field is missing in *kwargs*
    """
    root = Path(path)
    if not root.is_dir():
        # Keep the former behaviour for callers that pass a file-like path:
        # search beside it.
        root = root.parent

    # *path* is the directory itself, so it is the save root (not its parent).
    handler = PickleHandler(data_cls, root)
    return handler.load_persistence(handler.filepath(None, **kwargs))
def filename(result: T | type[T], **kwargs) -> str:
    """Get data persistence filename.

    :param result: persistence instance or persistence class
    :param kwargs: overwrite fields.
    :return: filename
    :raise RuntimeError: not a persistence class, or filename field missing.
    :raise TypeError: wrong field.filename type.
    :raise AutoIncFieldNotResolvedError: *result*'s autoinc field not resolved.
    """
    cls_info = ensure_persistence_class(result)

    # a type carries no field values; everything has to come from kwargs then
    result_instance: T | None = None if isinstance(result, type) else result

    if not cls_info.is_autoinc_field_resolved(result_instance, **kwargs):
        af = cls_info.autoinc_field()
        if af is None:
            raise RuntimeError('missing autoinc field')
        raise AutoIncFieldNotResolvedError(
            result, af, f'cannot generate filepath without autoinc field {af.field_name} keywords'
        )

    return cls_info.filename(result_instance, **kwargs)
class PersistenceHandler(Generic[T], metaclass=abc.ABCMeta):
    """The handler for loading and saving persistence instance.

    Subclasses provide the persistence class, the saving directory and the actual
    file reading/writing.
    """

    @property
    @abc.abstractmethod
    def persistence_class(self) -> type[T]:
        """:return: type T"""
        pass

    @property
    def persistence_info(self) -> PersistentClass[T]:
        """information for persistence class"""
        return ensure_persistence_class(self.persistence_class)

    @property
    @abc.abstractmethod
    def save_root(self) -> Path:
        """saving directory"""
        pass

    def filename(self, result: T | None, **kwargs) -> str:
        """Build filename for persistence instance.

        :param result: persistence instance
        :param kwargs: overwrite field value in *result*.
        :return: file name of *result*
        """
        return ensure_persistence_class(self.persistence_class).filename(result, **kwargs)

    def filepath(self, result: T | None, **kwargs) -> Path:
        """Build filepath for persistence instance.

        :param result: persistence instance
        :param kwargs: overwrite field value in *result*.
        :return: file path of *result*, under :attr:`save_root`
        :raise AutoIncFieldNotResolvedError: *result*'s autoinc field not resolved
        :raise RuntimeError: errors from :meth:`filename`
        """
        info = ensure_persistence_class(self.persistence_class)

        if not info.is_autoinc_field_resolved(result, **kwargs):
            af = info.autoinc_field()
            if af is None:
                raise RuntimeError('missing autoinc field')
            raise AutoIncFieldNotResolvedError(
                result, af, f'cannot generate filepath without autoinc field {af.field_name} keywords'
            )

        return self.save_root / self.filename(result, **kwargs)

    def validate(self, ref: T, res: T) -> bool:
        """Validate *res* against the reference *ref* with the field validators of *ref*'s class.

        :param ref: reference data
        :param res: tested data
        :return: True when every field passes.
        """
        return ensure_persistence_class(ref).validate(ref, res)

    def save_persistence(self, result: T, path: str | Path | None = None) -> T:
        """Save persistence *result* under **path**.

        :param result: persistence instance
        :param path: save path. When omitted, the path is built from *result*, and an
            unresolved autoinc field is set to one more than the largest value found
            among the saved results (0 if there is none).
        :return: *result*. autoinc field will be resolved after saving.
        :raise AutoIncFieldNotResolvedError: *path* is given but the autoinc field is not resolved.
        :raise IsADirectoryError: *path* is a directory.
        """
        info = ensure_persistence_class(result)

        unresolved = None
        if not info.is_autoinc_field_resolved(result):
            unresolved = info.autoinc_field()
            if unresolved is None:
                raise RuntimeError('missing autoinc field')

        if path is None:
            if unresolved is not None:
                key = unresolved.field_name
                used = []
                # '*' turns the autoinc field into a wildcard, matching every saved sibling
                for _, saved in self.load_all(result, **{key: '*'}):
                    value = getattr(saved, key, 0)
                    if isinstance(value, int):
                        used.append(value)
                setattr(result, key, max(used, default=-1) + 1)

            path = self.filepath(result)
        elif unresolved is not None:
            raise AutoIncFieldNotResolvedError(result, unresolved)

        if isinstance(path, str):
            path = Path(path)

        if path.is_dir():
            raise IsADirectoryError(str(path))

        path.parent.mkdir(parents=True, exist_ok=True)
        self._save_persistence(result, path)
        return result

    @abc.abstractmethod
    def _save_persistence(self, result: T, path: Path):
        """Write *result* into the file *path*."""
        pass

    def load_persistence(self, path: Path | T | dict[str, Any]) -> T:
        """Load a persistence instance without validation.

        :param path: load from path. A dict is taken as field values and a persistence
            instance as a template; the path is built from them.
        :return: persistence instance
        :raise IsADirectoryError: the resolved path is a directory.
        """
        data_cls = self.persistence_class
        ensure_persistence_class(data_cls)

        if isinstance(path, dict):
            load_path = self.filepath(None, **path)
        elif isinstance(path, Path):
            load_path = path
        elif isinstance(path, str):
            load_path = Path(path)
        else:
            load_path = self.filepath(path)

        if load_path.is_dir():
            raise IsADirectoryError(str(load_path))

        return self._load_persistence(load_path)

    @abc.abstractmethod
    def _load_persistence(self, path: Path) -> T:
        """Read a persistence instance from the file *path*."""
        pass

    def load_all(self, result: T | None, **kwargs) -> Iterator[tuple[Path, T]]:
        """Load all persistent results under *save_root* that match the filename pattern.

        *missing* is used to make a field becomes a wildcard field.

        >>> template = Example(use_animal='A00', use_session='test', use_date='20200101')
        >>> # find all animal A00's persistent result.
        >>> found = PickleHandler(Example, Path('')).load_all(template, use_date=missing)

        :param result: template persistence instance
        :param kwargs: overwrite field value in *result*.
        :return: iterator of (file path, loaded instance)
        """
        pattern = self.filename(result, **kwargs)
        for file in self.save_root.glob(pattern):
            yield file, self.load_persistence(file)
class PickleHandler(PersistenceHandler[T]):
    """Handler that stores the persistence as a pickle file.

    Support field type: all python objects.
    """

    def __init__(self, data_cls: type[T], save_root: Path, ext: str = '.pkl'):
        """
        :param data_cls: persistence class
        :param save_root: saving directory
        :param ext: file extension appended to the filename
        :raise RuntimeError: *data_cls* is not a persistence class
        """
        ensure_persistence_class(data_cls)
        self._data_cls = data_cls
        self._save_path = save_root
        self._ext = ext

    @property
    def persistence_class(self) -> type[T]:
        return self._data_cls

    @property
    def save_root(self) -> Path:
        return self._save_path

    def filename(self, result: T | None, **kwargs) -> str:
        """:return: filename of *result* followed by the file extension."""
        stem = super().filename(result, **kwargs)
        return stem + self._ext

    def _save_persistence(self, result: T, path: Path):
        import pickle

        content = as_dict(result)
        with path.open('wb') as file:
            pickle.dump(content, file)

    def _load_persistence(self, path: Path) -> T:
        import pickle

        with path.open('rb') as file:
            ret = pickle.load(file)

        data_cls = self.persistence_class
        if isinstance(ret, dict):
            return from_dict(data_cls, ret)
        if isinstance(ret, data_cls):
            # for old persistent pickle file
            return ret
        raise TypeError(f'not a {data_cls.__name__} for cache {path} : {ret}')
class GzipHandler(PersistenceHandler[T]):
    """Handler that stores the persistence as a gzip compressed pickle file.

    Support field type: all python objects.
    """

    def __init__(self, data_cls: type[T], save_root: Path, ext: str = '.pkl.gz', compression: int = 9):
        """
        :param data_cls: persistence class
        :param save_root: saving directory
        :param ext: file extension appended to the filename
        :param compression: gzip compress level
        :raise RuntimeError: *data_cls* is not a persistence class
        """
        ensure_persistence_class(data_cls)
        self._save_path = save_root
        self._data_cls = data_cls
        self._ext = ext
        self._cmp = compression

    @property
    def persistence_class(self) -> type[T]:
        return self._data_cls

    @property
    def save_root(self) -> Path:
        return self._save_path

    def filename(self, result: T | None, **kwargs) -> str:
        """:return: filename of *result* followed by the file extension."""
        return super().filename(result, **kwargs) + self._ext

    def _save_persistence(self, result: T, path: Path):
        import gzip
        import pickle

        with gzip.open(path, 'wb', compresslevel=self._cmp) as file:
            pickle.dump(as_dict(result), file)

    def _load_persistence(self, path: Path) -> T:
        """Read a persistence instance from the gzip pickle file *path*.

        :raise TypeError: the file content is neither a dict nor a persistence instance.
        """
        import gzip
        import pickle

        # NOTE: unpickling can run arbitrary code; only load files from a trusted source.
        with gzip.open(path, 'rb') as file:
            ret = pickle.load(file)

        # Same content handling as the plain pickle handler.
        data_cls = self.persistence_class
        if isinstance(ret, dict):
            return from_dict(data_cls, ret)
        elif isinstance(ret, data_cls):
            # a pickled instance instead of a field dictionary
            return ret
        else:
            raise TypeError(f'not a {data_cls.__name__} for cache {path} : {ret}')