"""
Persistence Class
=================
:author:
Ta-Shun Su
Define a persistence class
---------------------------
Import
>>> import numpy as np
>>> from neuralib import persistence
Define class
>>> @persistence.persistence_class
... class Example:
...     # keys; ``filename=True`` shows the key value in the filename.
...     use_animal: str = persistence.field(validator=True, filename=True)
...     use_session: str = persistence.field(validator=True)
...     use_date: str = persistence.field(validator=True, filename=True)
...     # data
...     channels: list[int]
...     data: np.ndarray
Load/Save
>>> example = Example(use_animal='A00', use_session='', use_date='1234')
>>> persistence.save(example, 'example.pkl')
>>> example_2 = persistence.load(Example, 'example.pkl')  # example and example_2 should have identical content
Cooperate with PersistenceOptions
----------------------------------
>>> from neuralib.persistence.cli_persistence import PersistenceOptions
>>> class ExampleHandle(PersistenceOptions[Example]):
...     def empty_cache(self) -> Example:
...         return Example(use_animal='A00', use_session='', use_date='1234')  # with attribute initialization
...     def compute_cache(self, result: Example) -> Example:
...         result.channels = [0, 1, 2]
...         result.data = np.array(result.channels)
...         return result
Dynamically generated methods for persistence class
----------------------------------------------------
1. `__init__({foreach persistence.field})`
>>> class Example:
...     ...  # same as above
...     def __init__(self, use_animal: str, use_session: str, use_date: str):  # auto generated
...         ...
2. `__str__` returns the filename
>>> class Example:
...     ...  # same as above
...     def __str__(self):  # auto generated
...         return filename(self)
3. `__repr__`
>>> class Example:
...     ...  # same as above
...     def __repr__(self):  # auto generated
...         return 'Example{' + f'use_animal={self.use_animal}, use_session={self.use_session}, use_date={self.use_date}' + "}"
4. `_replace`, generated when the persistence class defines a (placeholder) `_replace` method.
It works like `NamedTuple._replace`.
>>> class Example:
...     ...  # same as above
...     def _replace(self, **kwargs): pass  # placeholder method
...     def _replace(self, *,  # replaced by the generated one
...                  use_animal=missing,
...                  use_session=missing,
...                  use_date=missing,
...                  channels=missing,
...                  data=missing) -> Example:
...         ...
Auto increment field
---------------------
Sometimes you want to save persistence results that come from the same data source but have different
contents, usually generated by a random or shuffle process. To save them all separately, you need a
field value that keeps track of which n-th persistence result it is. :func:`autoinc_field` is provided
for this case.
>>> @persistence_class
... class Result:
...     a: str = field(validator=True, filename=True)
...     b: int = autoinc_field()
...     c: str
...     # generated: def __init__(self, a: str, b: int = None)
...     # generated (only if `_replace` is defined): def _replace(self, *, a=missing, b=missing, c=missing)
The following rules apply when a persistence class has an autoinc_field.
1. Only one autoinc_field is allowed in a persistence class.
2. Only ``int`` is allowed as its field type.
3. An error is raised when a result is looked up by field values while the autoinc field is unresolved.
4. The autoinc field is resolved automatically when saving without an explicit path. Its value is max(found) + 1 (0 when nothing is found).
Pickle format
--------------
A persistence instance is converted into a dict by `as_dict`; that dict is the root object written into
the pickle file.
**IMPORTANT**
If a persistence class has a custom `__init__` function whose signature differs from the auto generated one,
you need to define a classmethod `from_dict` for creating that persistence class.
>>> @persistence.persistence_class
... class Example:
...     a: int = persistence.field(validator=True, filename=True)
...     b: int = persistence.field(validator=True, filename=True)
...     c: int
...     def __init__(self): ...  # custom __init__
...     @classmethod
...     def from_dict(cls, data: dict[str, Any]) -> 'Example':
...         # data = {'a': ..., 'b': ..., 'c': ...}
...         ...
"""
import abc
import inspect
import sys
from collections.abc import Callable, Iterator
from pathlib import Path
from typing import Any, Generic, Type, TypeVar, Union, get_type_hints
from neuralib.util.func import PARA_TYPE, create_fn
__all__ = [
    'field',
    'autoinc_field',
    'persistence_class',
    'ensure_persistence_class',
    'as_dict',
    'from_dict',
    'load',
    'load_by',
    'save',
    'filename',
    'AutoIncFieldNotResolvedError',
    'auto_generated_content',
    'PersistenceHandler',
    'PickleHandler',
    'GzipHandler',
]

# Type variable standing for a persistence class / persistence instance type.
T = TypeVar('T')

# Path-like argument accepted by the module-level load/save helpers.
P = Union[str, Path]

# Sentinel meaning "no value given": it is the default of the generated ``_replace``
# parameters, and passing it as a field value when building a filename turns that
# field into the glob wildcard ``*``.
missing = inspect.Parameter.empty

# Signature of a custom field validator: ``validator(v, u) -> bool``.
VALIDATOR = Callable[[Any, Any], bool]
def field(validator: bool | VALIDATOR = False,
          filename_prefix: str = '',
          filename: bool | Callable[..., str] = False) -> Any:
    """Declare an exported (key) field of a persistence class.

    Key fields become arguments of the generated ``__init__``, can be checked during
    validation and can be shown in the generated filename.

    :param validator: ``True`` compares values with ``==``; ``False`` (default) skips the
        check; a callable ``(v, u) -> bool`` acts as a customized validator.
    :param filename_prefix: text put in front of this field's value in the filename.
    :param filename: ``True`` shows the value (via ``str``) in the filename; a callable maps
        the value to its filename text (returning ``None`` leaves the field out).
    :return: a field marker to assign as the class attribute value.
    """
    return PersistentField(
        validator=validator,
        filename_prefix=filename_prefix,
        filename=filename,
        init=True,
    )
def autoinc_field(filename_prefix: str = '') -> Any:
    """Declare the auto-increment field of a persistence class.

    It distinguishes results that come from the same source data but were produced by a
    random/shuffle process. The field is always validated with ``==`` and always shown in
    the filename. It is an ``__init__`` argument defaulting to ``None`` (unresolved), must be
    annotated as ``int``, and is resolved to ``max(existing) + 1`` when saving without an
    explicit path.

    :param filename_prefix: text put in front of the field value in the filename.
    :return: a field marker to assign as the class attribute value.
    """
    # noinspection PyTypeChecker
    return PersistentField(True, filename_prefix, True, init=True, autoinc=True)
class PersistentField(Generic[T]):
    """Exported field of a persistence class."""

    __slots__ = ('field_name', 'field_type', 'validator', 'filename_prefix', 'filename', 'init', 'autoinc', 'optional')

    def __init__(self,
                 validator: bool | VALIDATOR = False,
                 filename_prefix: str = '',
                 filename: bool | Callable[..., str] = False,
                 init=True,
                 autoinc=False,
                 optional=False):
        """
        :param validator: validate this field. ``True`` uses ``__eq__``; ``False`` skips it.
            Can be a callable as a customized validator.
        :param filename_prefix: prefix word of *filename*
        :param filename: display this value on filename. Can be a callable that returns the string.
        :param init: put this field into the generated class ``__init__``.
        :param autoinc: mark this field as the auto-increment field.
        :param optional: whether this field has a default value.
        """
        # Is this field an identifying key (True / False / custom callable)?
        self.validator = validator
        # Prefix word put before this field's value in the filename.
        self.filename_prefix = filename_prefix
        # Does this field show in the filename (True / False / callable formatter)?
        self.filename = filename
        # Is this field an argument of the generated __init__?
        self.init = init
        # Is this the auto-increment field?
        self.autoinc = autoinc
        # Does this field have a default value?
        self.optional = optional

    def __set_name__(self, owner: type, name: str):
        """Record the field name and its resolved type hint when assigned in a class body.

        :raise RuntimeError: an autoinc field is not annotated as ``int``.
        """
        self.field_name = name
        # resolve once and store it for every field, autoinc included, so the slot is always set
        field_type = get_type_hints(owner).get(name, Any)
        if self.autoinc and field_type is not int:
            raise RuntimeError(f'type of autoinc field {name} should be int, but {field_type}')
        self.field_type = field_type

    def validate(self, v: T, u: T) -> bool:
        """Validate these two values.

        :param v: reference value.
        :param u: tested value.
        :return: ``True`` when the values agree (always ``True`` when validation is disabled).
        :raise RuntimeError: ``validator`` is neither a bool nor a callable.
        """
        if self.validator is False:
            return True
        if self.validator is True:
            return v == u
        if callable(self.validator):
            return self.validator(v, u)
        raise RuntimeError(f'field {getattr(self, "field_name", "?")} has an invalid validator: {self.validator!r}')
class AutoIncFieldNotResolvedError(RuntimeError):
    """Raised when the autoinc field of a persistence instance is still unresolved,
    so an operation that needs its value (such as building a filepath) cannot proceed.

    :ivar instance: the persistence instance (or class) involved.
    :ivar field: name of the unresolved autoinc field.
    """

    def __init__(self, instance, field: str | PersistentField, message: str | None = None):
        """
        :param instance: persistence instance (or class) whose autoinc field is unresolved.
        :param field: the autoinc field, given by name or as its :class:`PersistentField`.
        :param message: custom message; a default one is built when omitted.
        """
        field_name = field.field_name if isinstance(field, PersistentField) else field
        default_message = f'{type(instance).__name__} autoinc field {field_name} is not resolved'
        super().__init__(default_message if message is None else message)
        self.instance = instance
        self.field = field_name
def persistence_class(cls: type | None = None, /, *,
                      name: str | None = None,
                      filename_field_splitter='-'):
    """Class decorator that turns *cls* into a persistence class.

    Every annotated attribute becomes a field. An attribute assigned a :func:`field` /
    :func:`autoinc_field` marker is an exported field, and the marker is removed from the
    class. Any other annotated attribute is a data field (not an ``__init__`` argument).
    The decorator then does the following:

    - generates ``__init__``, ``__str__`` and ``__repr__`` when the class does not define them;
    - replaces ``_replace`` when the class defines one;
    - sets ``__match_args__`` on Python >= 3.10 when it is absent.

    It can be used bare (``@persistence_class``) or with options (``@persistence_class(name=...)``).

    :param cls: persistence class (given when used as a bare decorator).
    :param name: class name used in the filename; defaults to ``cls.__name__``.
    :param filename_field_splitter: the field splitter on filename.
    :return: the decorated class, or a decorator when *cls* is omitted.
    :raise RuntimeError: more than one autoinc field is declared.
    """

    def decorator(target: type):
        info = PersistentClass(target, name, filename_field_splitter)
        target._ast_persistence_cls_info_ = info

        seen_autoinc = False
        for attr_name, attr_type in target.__annotations__.items():
            declared = getattr(target, attr_name, None)
            if isinstance(declared, PersistentField):
                # the marker is not a descriptor; drop it so instances don't see it
                delattr(target, attr_name)
                current = declared
            else:
                current = PersistentField(init=False)
                current.field_name = attr_name
                current.field_type = attr_type
                current.optional = declared is not None

            if current.autoinc:
                if seen_autoinc:
                    raise RuntimeError('duplicated auto_inc_field')
                seen_autoinc = True

            info.fields.append(current)

        generated = (
            ('__init__', _persistence_class_init),
            ('__str__', _persistence_class_str),
            ('__repr__', _persistence_class_repr),
        )
        for dunder, factory in generated:
            # only fill in methods the class did not define itself
            if getattr(target, dunder) == getattr(object, dunder):
                setattr(target, dunder, factory(info))

        if hasattr(target, '_replace'):
            target._replace = _persistence_class_replace(info)

        if sys.version_info >= (3, 10) and not hasattr(target, '__match_args__'):
            target.__match_args__ = tuple(f.field_name for f in info.fields)

        return target

    return decorator if cls is None else decorator(cls)
class PersistentClass(Generic[T]):
    """Persistence information attached to a class by :func:`persistence_class`."""

    __slots__ = ('persistence_cls', 'cls_name', 'filename_field_splitter', 'fields')

    def __init__(self,
                 persistence_cls: type,
                 persistence_name: str | None = None,
                 filename_field_splitter='-'):
        """
        :param persistence_cls: the decorated class.
        :param persistence_name: name used in filenames; defaults to ``persistence_cls.__name__``.
        :param filename_field_splitter: separator between filename parts.
        """
        self.persistence_cls = persistence_cls
        self.cls_name = persistence_name if persistence_name is not None else persistence_cls.__name__
        self.filename_field_splitter = filename_field_splitter
        # fields in annotation order; filled in by the decorator
        self.fields: list[PersistentField] = []

    def fields_name(self) -> list[str]:
        """:return: names of all fields, in declaration order."""
        return [it.field_name for it in self.fields]

    def get_field(self, name: str) -> PersistentField | None:
        """:return: the field called *name*, or ``None`` when there is none."""
        for f in self.fields:
            if f.field_name == name:
                return f
        return None

    def autoinc_field(self) -> PersistentField | None:
        """:return: the autoinc field, or ``None`` when the class has none."""
        for f in self.fields:
            if f.autoinc:
                return f
        return None

    def is_autoinc_field_resolved(self, result: T | None, **kwargs) -> bool:
        """Check whether the autoinc field has a value.

        :param result: persistence instance, or ``None``.
        :param kwargs: overriding field values; giving the autoinc field here counts as resolved.
        :return: ``True`` when the class has no autoinc field, when *result* holds a
            non-``None`` value for it, or when it appears in *kwargs*.
        """
        if (af := self.autoinc_field()) is None:
            return True
        return getattr(result, af.field_name, None) is not None or af.field_name in kwargs

    def validate(self, v: T, u: T | dict[str, Any]) -> bool:
        """Validate that tested data *u* agrees with reference data *v* on every field.

        :param v: reference data
        :param u: tested data (an instance or a ``{field_name: value}`` dict)
        :return: ``True`` when every field validates; it never returns ``False``.
        :raise FileNotFoundError: a field fails validation.
        """
        for f in self.fields:
            fv = getattr(v, f.field_name, None)
            if isinstance(u, dict):
                fu = u.get(f.field_name, None)
            else:
                fu = getattr(u, f.field_name, None)
            if not f.validate(fv, fu):
                # NOTE(review): FileNotFoundError presumably lets callers treat a mismatch like a missing cache — confirm
                raise FileNotFoundError(f'{self.cls_name}.{f.field_name} validate fail: {fv} != {fu}')
        return True

    def filename(self, data: T | None, **kwargs) -> str:
        """Build the filename for a persistence instance.

        :param data: persistence instance.
        :param kwargs: overwrote keywords. A :data:`missing` value renders as the glob wildcard ``*``.
        :return: filename.
        :raise RuntimeError: filename field missing.
        :raise TypeError: wrong field.filename type.
        """
        # legacy name prefix
        ret = ['cache', self.cls_name]
        for f in self.fields:
            if not f.filename:
                continue

            if f.field_name in kwargs:
                field_value = kwargs[f.field_name]
            elif data is None or not hasattr(data, f.field_name):
                raise RuntimeError(f'field {f.field_name} is required for filename')
            else:
                field_value = getattr(data, f.field_name)

            if field_value is missing:
                s = '*'
            elif f.filename is True:
                s = str(field_value)
            elif callable(f.filename):
                if (s := f.filename(field_value)) is None:
                    continue  # the formatter chose to leave this field out
                s = str(s)
            else:
                raise TypeError(f'{self.cls_name}.{f.field_name}: unsupported filename option {f.filename!r}')

            ret.append(f.filename_prefix + s)
        return self.filename_field_splitter.join(ret)
def _persistence_class_init(pc: PersistentClass):
    """Generate an ``__init__`` function for a persistent class.

    The generated function takes every ``init`` field as a parameter, in field order, and
    assigns each one to the attribute of the same name. The autoinc field gets the default
    ``None`` (unresolved); every other parameter is required.

    :raise RuntimeError: a required ``init`` field is declared after the autoinc field; the
        generated signature would otherwise put a required parameter after a defaulted one.
    """
    init_fields: list[PARA_TYPE] = ['self']
    autoinc_name = None
    for f in pc.fields:
        if not f.init:
            continue
        if f.autoinc:
            autoinc_name = f.field_name
            init_fields.append((f.field_name, None, 'None'))
        else:
            if autoinc_name is not None:
                raise RuntimeError(f'{pc.cls_name}.{f.field_name} must be declared before autoinc field '
                                   f'{autoinc_name}, which has a default in the generated __init__')
            init_fields.append(f.field_name)

    names = [it[0] if isinstance(it, tuple) else it for it in init_fields[1:]]
    # a class without init fields still needs a non-empty function body
    body = '\n'.join(f'self.{name} = {name}' for name in names) or 'pass'
    return create_fn('__init__', init_fields, body,
                     locals=dict(missing=missing))
def _persistence_class_str(pc: PersistentClass):
    """Generate a ``__str__`` that returns the persistence filename of the instance."""
    signature = (['self'], str)
    body = 'return pc.filename(self)'
    return create_fn('__str__', signature, body, locals={'pc': pc})
def _persistence_class_repr(pc: PersistentClass):
    """Generate a ``__repr__`` for a persistent class.

    The output looks like ``Name{a=..., b=...}``. It lists the ``init`` fields, each value
    formatted through an f-string (i.e. ``str``). The generated source is a single ``return``
    of implicitly concatenated string literals.
    """
    field_parts = [f'f"{f.field_name}={{self.{f.field_name}}}"' for f in pc.fields if f.init]
    tokens = [
        f'return ("{pc.cls_name}{{"',
        ' ", " '.join(field_parts),  # a literal ", " between consecutive fields
        '"}")',
    ]
    source = ' '.join(token for token in tokens if token)
    return create_fn('__repr__', (['self'], str), source)
def _persistence_class_replace(pc: PersistentClass):
    """Generate a ``NamedTuple._replace``-like ``_replace`` for a persistent class.

    Every parameter is keyword-only and defaults to :data:`missing`; a ``missing`` argument
    keeps the current value. ``init`` fields (the autoinc field included) are passed
    positionally, in field order, to the class constructor. The remaining non-autoinc fields
    are then set as attributes on the new instance.
    """
    ctor_fields = [f.field_name for f in pc.fields if f.init]
    attr_fields = [f.field_name for f in pc.fields if not (f.init or f.autoinc)]

    lines = [f'ret = {pc.cls_name}(']
    lines.extend(f'{n} if {n} is not missing else self.{n},' for n in ctor_fields)
    lines.append(')')
    lines.extend(f'ret.{n} = {n} if {n} is not missing else self.{n}' for n in attr_fields)
    lines.append('return ret')

    params: list[PARA_TYPE] = ['self', '*', *((n, None, 'missing') for n in ctor_fields + attr_fields)]
    namespace = {pc.cls_name: pc.persistence_cls, 'missing': missing}
    return create_fn('_replace', (params, pc.cls_name), '\n'.join(lines), locals=namespace)
def auto_generated_content(**kwargs):
    """Mark a function whose body is auto generated.

    Calling it always fails, because the real body is supposed to be generated elsewhere.

    :param kwargs: accepted and ignored.
    :raise RuntimeError: always.
    """
    message = 'It is auto generated content'
    raise RuntimeError(message)
def ensure_persistence_class(data: T | type[T]) -> PersistentClass[T]:
    """Return the persistence info of *data*, failing when it is not a persistence class.

    :param data: persistence instance or persistence class.
    :return: the :class:`PersistentClass` attached by :func:`persistence_class`.
    :raise RuntimeError: not a persistence class.
    """
    cls = data if isinstance(data, type) else type(data)
    try:
        return cls._ast_persistence_cls_info_
    except AttributeError as e:
        raise RuntimeError(f'not a persistence_class : {cls.__name__}') from e
def as_dict(data: object) -> dict[str, Any]:
    """Convert a persistence instance into a ``{field_name: value}`` dict.

    Fields that are not set on *data* (attribute lookup raises AttributeError) are left out.

    :param data: persistence instance.
    :return: dict in field declaration order.
    :raise TypeError: *data* is ``None``.
    :raise RuntimeError: *data* is not a persistence instance.
    """
    if data is None:
        raise TypeError('data is None')
    unset = object()
    pairs = (
        (f.field_name, getattr(data, f.field_name, unset))
        for f in ensure_persistence_class(data).fields
    )
    return {name: value for name, value in pairs if value is not unset}
def from_dict(data_cls: type[T], d: dict[str, Any]) -> T:
    """Rebuild a *data_cls* instance from dictionary *d* (the inverse of :func:`as_dict`).

    The ``init`` fields are passed to the constructor as keywords, and the other fields found
    in *d* are then set as attributes. When the constructor raises TypeError and *data_cls*
    defines ``from_dict``, that classmethod builds the instance instead.

    :param data_cls: persistence class.
    :param d: field values.
    :return: persistence instance.
    :raise KeyError: an ``init`` field is missing from *d*.
    :raise TypeError: construction failed and *data_cls* has no ``from_dict``.
    :raise RuntimeError: *data_cls* is not a persistence class.
    """
    info = ensure_persistence_class(data_cls)

    init_kwargs = {}
    for f in info.fields:
        if not f.init:
            continue
        key = f.field_name
        try:
            init_kwargs[key] = d[key]
            continue
        except KeyError:
            pass
        raise KeyError(f'missing required field : {key}')

    try:
        instance = data_cls(**init_kwargs)
    except TypeError:
        if not hasattr(data_cls, 'from_dict'):
            raise
        return _from_dict_factory(data_cls, info, d)
    return _from_dict_builtin(instance, info, d)
def _from_dict_builtin(ret: T, info: PersistentClass, d: dict[str, Any]) -> T:
    """Set the non-``init`` fields present in *d* as attributes of the freshly built *ret*.

    Fields absent from *d* are skipped. *ret* is mutated in place and returned.
    """
    for f in info.fields:
        if f.init:
            continue
        try:
            value = d[f.field_name]
        except KeyError:
            continue
        setattr(ret, f.field_name, value)
    return ret
def _from_dict_factory(data_cls: type[T], info: PersistentClass, d: dict[str, Any]) -> T:
    """Build an instance through the class's own ``from_dict`` classmethod.

    This is used when constructing *data_cls* from its ``init`` fields failed. Only the keys
    of *d* that are declared fields are forwarded.

    :raise TypeError: ``from_dict`` returned something other than a *data_cls* instance.
    """
    kwargs = {}
    for f in info.fields:
        try:
            kwargs[f.field_name] = d[f.field_name]
        except KeyError:
            pass
    ret = data_cls.from_dict(kwargs)
    if not isinstance(ret, data_cls):
        raise TypeError(f'{data_cls.__name__}.from_dict returned {type(ret).__name__}, '
                        f'expected {data_cls.__name__}')
    return ret
def save(data: object, path: P) -> None:
    """Save persistence instance *data* to file *path* in pickle format.

    The parent directory is created when missing. The autoinc field is not resolved here
    because an explicit path is given.

    :param data: persistence instance
    :param path: filepath
    :raise AutoIncFieldNotResolvedError: *data*'s autoinc field is unresolved.
    :raise IsADirectoryError: *path* is a directory.
    """
    target = Path(path)
    handler = PickleHandler(type(data), target.parent)
    handler.save_persistence(data, target)
def load(data_cls: type[T], path: P) -> T:
    """Load a *data_cls* instance from pickle file *path* (without validation).

    :param data_cls: data class type
    :param path: filepath
    :return: persistence instance
    :raise IsADirectoryError: *path* is a directory.
    """
    file = Path(path)
    handler = PickleHandler(data_cls, file.parent)
    return handler.load_persistence(file)
def load_by(data_cls: type[T], path: P, **kwargs) -> T:
    """Load a *data_cls* instance from directory *path*, locating the file by field values.

    The filename is built from *kwargs* alone (no template instance), so every filename field
    and the autoinc field (if any) must be given.

    :param data_cls: data class type
    :param path: directory that contains the pickle file
    :param kwargs: data fields
    :return: persistence instance
    :raise AutoIncFieldNotResolvedError: the autoinc field is not given in *kwargs*.
    :raise RuntimeError: a filename field is missing from *kwargs*.
    """
    directory = Path(path)
    # *path* is the directory holding the file, i.e. the save root itself
    # (``path.parent`` would look one level too high).
    handler = PickleHandler(data_cls, directory)
    return handler.load_persistence(handler.filepath(None, **kwargs))
def filename(result: T | type[T], **kwargs) -> str:
    """Get the persistence filename (without extension) of an instance or class.

    :param result: persistence instance, or a persistence class when every filename field
        is given in *kwargs*.
    :param kwargs: overwrite fields.
    :return: filename
    :raise RuntimeError: filename field missing.
    :raise TypeError: wrong field.filename type.
    :raise AutoIncFieldNotResolvedError: *result*'s autoinc field not resolved.
    """
    info = ensure_persistence_class(result)
    instance: T | None = None if isinstance(result, type) else result

    if not info.is_autoinc_field_resolved(instance, **kwargs):
        autoinc = info.autoinc_field()
        if autoinc is None:
            raise RuntimeError('missing autoinc field')
        raise AutoIncFieldNotResolvedError(
            result, autoinc,
            f'cannot generate filepath without autoinc field {autoinc.field_name} keywords')

    return info.filename(instance, **kwargs)
class PersistenceHandler(Generic[T], metaclass=abc.ABCMeta):
    """The handler for loading and saving persistence instances.

    Concrete handlers provide :attr:`persistence_class`, :attr:`save_root` and the storage
    format (:meth:`_save_persistence` / :meth:`_load_persistence`). This base class builds
    filenames and filepaths and resolves the autoinc field.
    """

    @property
    @abc.abstractmethod
    def persistence_class(self) -> type[T]:
        """:return: type T"""
        pass

    @property
    def persistence_info(self) -> PersistentClass[T]:
        """information for persistence class"""
        return ensure_persistence_class(self.persistence_class)

    @property
    @abc.abstractmethod
    def save_root(self) -> Path:
        """saving directory"""
        pass

    def filename(self, result: T | None, **kwargs) -> str:
        """Build the filename for a persistence instance.

        Unlike :meth:`filepath`, this does not check the autoinc field. An unresolved
        (``None``) autoinc value is rendered with ``str`` like any other value.

        :param result: persistence instance, or ``None`` when *kwargs* gives every filename field.
        :param kwargs: overwrite field value in *result*; :data:`missing` (or ``'*'``) gives a wildcard.
        :return: file name of *result*
        :raise RuntimeError: a filename field is neither on *result* nor in *kwargs*.
        :raise TypeError: wrong field.filename type.
        """
        cls_info = ensure_persistence_class(self.persistence_class)
        return cls_info.filename(result, **kwargs)

    def filepath(self, result: T | None, **kwargs) -> Path:
        """Build the filepath (under :attr:`save_root`) for a persistence instance.

        :param result: persistence instance
        :param kwargs: overwrite field value in *result*.
        :return: file path of *result*
        :raise AutoIncFieldNotResolvedError: (a RuntimeError) *result*'s autoinc field is not
            resolved and not given in *kwargs*.
        :raise RuntimeError: errors from :meth:`filename`.
        """
        info = ensure_persistence_class(self.persistence_class)
        if not info.is_autoinc_field_resolved(result, **kwargs):
            af = info.autoinc_field()
            if af is None:
                raise RuntimeError('missing autoinc field')
            raise AutoIncFieldNotResolvedError(result, af,
                                               f'cannot generate filepath without autoinc field {af.field_name} keywords')
        name = self.filename(result, **kwargs)
        return self.save_root / name

    def validate(self, ref: T, res: T) -> bool:
        """Validate *res* against reference *ref* field by field.

        :return: ``True`` on success.
        :raise FileNotFoundError: a field fails validation.
        """
        return ensure_persistence_class(ref).validate(ref, res)

    def save_persistence(self, result: T, path: str | Path | None = None) -> T:
        """Save persistence *result* under **path**.

        When *path* is ``None`` the file goes to :meth:`filepath`. An unresolved autoinc field is
        first set to one more than the largest value among the saved results that share the
        other filename fields (0 when there are none).

        :param result: persistence instance.
        :param path: save path; when given, the autoinc field must already be resolved.
        :return: *result*. autoinc field will be resolved after saving.
        :raise AutoIncFieldNotResolvedError: *path* is given but the autoinc field is unresolved.
        :raise IsADirectoryError: the target path is a directory.
        """
        info = ensure_persistence_class(result)
        if path is None:
            if not info.is_autoinc_field_resolved(result):
                f = info.autoinc_field()
                if f is None:
                    raise RuntimeError('missing autoinc field')
                # stream the matching results so they are not all kept in memory at once
                found = (it for _, it in self.load_all(result, **{f.field_name: '*'}))
                u = max((
                    value
                    for it in found
                    if isinstance(value := getattr(it, f.field_name, 0), int)
                ), default=-1) + 1
                setattr(result, f.field_name, u)
            path = self.filepath(result)
        else:
            if not info.is_autoinc_field_resolved(result):
                f = info.autoinc_field()
                if f is None:
                    raise RuntimeError('missing autoinc field')
                raise AutoIncFieldNotResolvedError(result, f)
            if isinstance(path, str):
                path = Path(path)

        if path.is_dir():
            raise IsADirectoryError(str(path))
        path.parent.mkdir(parents=True, exist_ok=True)
        self._save_persistence(result, path)
        return result

    @abc.abstractmethod
    def _save_persistence(self, result: T, path: Path):
        pass

    def load_persistence(self, path: str | Path | T | dict[str, Any]) -> T:
        """Load a persistence instance from **path** without validation.

        :param path: a file path (``str`` or :class:`Path`), a dict of field values, or a
            persistence instance; the latter two are turned into a path by :meth:`filepath`.
        :return: persistence instance
        :raise IsADirectoryError: the resolved path is a directory.
        :raise AutoIncFieldNotResolvedError: a dict/instance lacks the autoinc field value.
        """
        data_cls = self.persistence_class
        ensure_persistence_class(data_cls)

        if isinstance(path, dict):
            load_path = self.filepath(None, **path)
        elif isinstance(path, Path):
            load_path = path
        elif isinstance(path, str):
            load_path = Path(path)
        else:
            load_path = self.filepath(path)

        if load_path.is_dir():
            raise IsADirectoryError(str(load_path))
        return self._load_persistence(load_path)

    @abc.abstractmethod
    def _load_persistence(self, path: Path) -> T:
        pass

    def load_all(self, result: T | None, **kwargs) -> Iterator[tuple[Path, T]]:
        """Lazily load all persistent results under *save_root* whose filename matches.

        The filename is used as a glob pattern. *missing* is used to make a field become a wildcard field.

        >>> template = Example(use_animal='A00', use_session='test', use_date='20200101')
        >>> # find all animal A00's persistent result.
        >>> found = PickleHandler(Example, Path('')).load_all(template, use_date=missing)
        """
        for file in self.save_root.glob(self.filename(result, **kwargs)):
            yield file, self.load_persistence(file)
class PickleHandler(PersistenceHandler[T]):
    """Persistence handler storing instances as plain pickle files.

    Support field type: all python objects (anything picklable).

    Files hold the :func:`as_dict` dict. Legacy files that hold a pickled instance are still
    accepted on load.

    .. warning:: unpickling can execute arbitrary code; only load files from trusted sources.
    """

    def __init__(self, data_cls: type[T], save_root: Path, ext: str = '.pkl'):
        """
        :param data_cls: persistence class handled.
        :param save_root: directory holding the files; a ``str`` is converted to :class:`Path`.
        :param ext: extension appended to the generated filename.
        :raise RuntimeError: *data_cls* is not a persistence class.
        """
        ensure_persistence_class(data_cls)
        # accept str too: save_root is used with ``/`` and ``.glob``, which str lacks
        self._save_path = Path(save_root)
        self._data_cls = data_cls
        self._ext = ext

    @property
    def persistence_class(self) -> type[T]:
        return self._data_cls

    @property
    def save_root(self) -> Path:
        return self._save_path

    def filename(self, result: T | None, **kwargs) -> str:
        """Filename of *result* with this handler's extension appended."""
        return super().filename(result, **kwargs) + self._ext

    def _save_persistence(self, result: T, path: Path):
        import pickle
        with path.open('wb') as file:
            pickle.dump(as_dict(result), file)

    def _load_persistence(self, path: Path) -> T:
        import pickle
        # NOTE: pickle.load runs arbitrary code from the file; never point this at untrusted data
        with path.open('rb') as file:
            ret = pickle.load(file)
        data_cls = self.persistence_class
        if isinstance(ret, dict):
            return from_dict(data_cls, ret)
        elif isinstance(ret, data_cls):  # for old persistent pickle file
            return ret
        else:
            raise TypeError(f'not a {data_cls.__name__} for cache {path} : {ret}')
class GzipHandler(PersistenceHandler[T]):
    """Persistence handler storing instances as gzip-compressed pickle files.

    Support field type: all python objects (anything picklable).

    .. warning:: unpickling can execute arbitrary code; only load files from trusted sources.
    """

    def __init__(self, data_cls: type[T], save_root: Path, ext: str = '.pkl.gz',
                 compression: int = 9):
        """
        :param data_cls: persistence class handled.
        :param save_root: directory holding the files; a ``str`` is converted to :class:`Path`.
        :param ext: extension appended to the generated filename.
        :param compression: gzip ``compresslevel`` used when saving.
        :raise RuntimeError: *data_cls* is not a persistence class.
        """
        ensure_persistence_class(data_cls)
        # accept str too: save_root is used with ``/`` and ``.glob``, which str lacks
        self._save_path = Path(save_root)
        self._data_cls = data_cls
        self._ext = ext
        self._cmp = compression

    @property
    def persistence_class(self) -> type[T]:
        return self._data_cls

    @property
    def save_root(self) -> Path:
        return self._save_path

    def filename(self, result: T | None, **kwargs) -> str:
        """Filename of *result* with this handler's extension appended."""
        return super().filename(result, **kwargs) + self._ext

    def _save_persistence(self, result: T, path: Path):
        import gzip
        import pickle
        with gzip.open(path, 'wb', compresslevel=self._cmp) as file:
            pickle.dump(as_dict(result), file)

    def _load_persistence(self, path: Path) -> T:
        import gzip
        import pickle
        # NOTE: pickle.load runs arbitrary code from the file; never point this at untrusted data
        with gzip.open(path, 'rb') as file:
            ret = pickle.load(file)
        data_cls = self.persistence_class
        # same accepted payloads as PickleHandler: the as_dict dict, or a pickled instance
        if isinstance(ret, dict):
            return from_dict(data_cls, ret)
        if isinstance(ret, data_cls):
            return ret
        raise TypeError(f'not a {data_cls.__name__} for cache {path} : {ret}')