Source code for neuralib.persistence.cli_persistence

import abc
from pathlib import Path
from typing import Generic, TypeVar, get_args, get_origin

from argclz import argument, copy_argument
from neuralib.persistence import AutoIncFieldNotResolvedError, PersistenceHandler, PickleHandler, persistence
from neuralib.util.verbose import fprint, print_load, print_save

__all__ = ['PersistenceOptions',
           'get_options_and_cache']

T = TypeVar('T')


def persistence_filename(cache: object) -> str:
    return persistence.filename(cache) + '.pkl'


[docs] class PersistenceOptions(Generic[T], metaclass=abc.ABCMeta): """The Option class that handle one kind of cache class T, including load cache, create cache, and save cache. """ GROUP_CACHE = 'Persistence options' invalid_cache: bool = argument( '--invalid-cache', group=GROUP_CACHE, help='invalid persistence data' ) @property def persistence_class(self) -> type[T]: # https://stackoverflow.com/a/50101934 for t in getattr(type(self), '__orig_bases__', ()): if get_origin(t) == PersistenceOptions: return get_args(t)[0] raise TypeError('unable to retrieve cache class T')
[docs] def persistence_handler(self, dest: Path) -> PersistenceHandler[T]: """ :param dest: save root directory. :return: """ return PickleHandler(self.persistence_class, dest)
[docs] @abc.abstractmethod def empty_cache(self) -> T: """create an empty cache which only initialize required fields. :return: cache instance """ pass
[docs] def find_cache(self, result: T, dest: Path | None = None, validator=False) -> list[T]: """Find the persistence. for all fields >>> template = self.empty_cache() >>> template.a = 1 # want to find all cache whose `a` equals to 1 >>> template.b = field_missing # want to find all cache and don't matter what `b` is >>> found = self.find_cache(template) :param result: :param dest: save root directory. :param validator: :return: """ handler = self.persistence_handler(Path() if dest is None else dest) ret = [] for file, found in handler.load_all(result): if validator: if not self.validate_cache(file, found): continue ret.append(found) return ret
[docs] def save_cache(self, result: T, dest: Path, force=True): """ :param result: :param dest: save root directory. :param force: :return: """ save_path = dest / persistence_filename(result) if save_path.exists() and not force: raise FileExistsError(str(save_path)) save_path.parent.mkdir(parents=True, exist_ok=True) persistence.save(result, save_path)
[docs] def load_cache(self, result: T | None = None, error_when_missing=False, dest: Path | None = None, **kwargs) -> T: """load persistence from disk according to *result*'s required fields. :param result: persistence instance with necessary fields filled. :param error_when_missing: do not try to generate the cache when cache missing. :param dest: save root directory. :param kwargs: overwrite field value in *result*. :return: persistence instance. :raise FileNotFoundError: error_when_missing and file not found. """ do_load = not self.invalid_cache handler = self.persistence_handler(Path() if dest is None else dest) if result is None: result = self.empty_cache() # load/save path try: output_file = handler.filepath(result, **kwargs) except AutoIncFieldNotResolvedError as e: if error_when_missing: raise FileNotFoundError from e do_load = False output_file = None if do_load and output_file is not None: try: print_load(output_file) ref = result result = handler.load_persistence(output_file) if handler.validate(ref, result) and self.validate_cache(output_file, result): return result except FileNotFoundError: if error_when_missing: raise except (TypeError, ValueError, KeyError, AttributeError, RuntimeError) as e: fprint(repr(e), vtype='error') elif error_when_missing: raise FileNotFoundError(output_file) result = self.compute_cache(result) handler.save_persistence(result, output_file) output_file = handler.filepath(result) print_save(output_file) return result
[docs] def validate_cache(self, result_path: Path, result: T) -> bool: """Validating loaded cache instance. Once validating fail (return False), goto :meth:`_compute_cache`. :param result_path: :param result: :return: False if validating fail. :raise TypeError: if validating fail. :raise ValueError: if validating fail. :raise KeyError: if validating fail. :raise AttributeError: if validating fail. :raise RuntimeError: if validating fail. """ return True
[docs] @abc.abstractmethod def compute_cache(self, result: T) -> T: """Compute cache according to *cache*'s required fields. :param result: :return: computed result """ pass
[docs] def get_options_and_cache(opt_cls: type[PersistenceOptions[T]], ref, error_when_missing: bool = False, invalid_cache: bool = False, **kwargs) -> T: """ copy the arguments from ``PersistenceOptions`` (class that compute the cache) to ``ApplyOptions`` Can be used for analysis apply two different cached files :param opt_cls: ``PersistenceOptions`` :param ref: reference class for applying the cache :param error_when_missing: do not try to generate the cache when cache missing :param invalid_cache: invalid_cache pass to ``PersistenceOptions`` :return: """ return ( copy_argument(opt_cls(), ref, invalid_cache=invalid_cache) .load_cache(error_when_missing=error_when_missing, **kwargs) )