Source code for neuralib.rastermap.run

import pickle
from pathlib import Path
from typing import Literal, cast

import numpy as np
from neuralib.typing import PathLike
from scipy.stats import zscore

from .core import RasterMapResult, RasterOptions

__all__ = ['DATA_TYPE', 'run_rastermap']

DATA_TYPE = Literal['cellular', 'wfield']


[docs] def run_rastermap(neural_activity: np.ndarray, bin_size: int, dtype: DATA_TYPE, *, options: RasterOptions | None = None, svd_components: int = 128, svd_cache: PathLike | None = None, invalid_svd_cache: bool = False, filename: str | None = None, save_path: str | None = None, **kwargs) -> RasterMapResult: """ Run rastermap :param neural_activity: neural activity. `Array[float, [N, T]]` | `Array[Any, [T, H, W]]` :param bin_size: number of neurons to bin over :param dtype: :attr:`~neuralib.model.rastermap.run.DATA_TYPE` :param options: :class:`~neuralib.model.rastermap.core.RasterOptions` :param svd_components: number of SVD components to use (for widefield dtype) :param svd_cache: path to cache SVD file (for widefield dtype) :param invalid_svd_cache: invalid (recompute) SVD cached file (for widefield dtype) :param filename: optional specify if GUI re-launch :param save_path: optional specify if GUI re-launch """ rastermap = RunRastermap(neural_activity, bin_size, dtype, options=options, svd_components=svd_components, svd_cache=svd_cache, invalid_svd_cache=invalid_svd_cache, filename=filename, save_path=save_path) return rastermap.run(**kwargs)
class RunRastermap: DEFAULT_OPTIONS: RasterOptions = { 'n_clusters': 50, 'n_PCs': 128, 'locality': 0.75, 'time_lag_window': 5, 'grid_upsample': 10, } def __init__(self, neural_activity: np.ndarray, bin_size: int, dtype: DATA_TYPE, *, options: RasterOptions | None = None, svd_components: int = 128, svd_cache: PathLike | None = None, invalid_svd_cache: bool = False, filename: str | None = None, save_path: str | None = None): """ :param neural_activity: neural activity. `Array[float, [N, T]]` | `Array[Any, [T, H, W]]` :param bin_size: number of neurons to bin over :param options: :class:`~neuralib.model.rastermap.core.RasterOptions` :param svd_components: number of SVD components to use (for widefield dtype) :param svd_cache: path to cache SVD file (for widefield dtype) :param invalid_svd_cache: invalid (recompute) SVD cached file (for widefield dtype) :param filename: optional specify if GUI re-launch :param save_path: optional specify if GUI re-launch """ self.neural_activity = neural_activity self.bin_size = bin_size self.dtype = dtype self.svd_components = svd_components self.svd_cache = svd_cache self.invalid_svd_cache = invalid_svd_cache self.filename = filename self.save_path = save_path self._options = options or self.DEFAULT_OPTIONS @property def options(self) -> RasterOptions: return self._options def run(self, **kwargs) -> RasterMapResult: """run rastermap with the given dtype""" match self.dtype: case 'cellular': ret = self.run_cellular(**kwargs) case 'wfield': ret = self.run_wfield(**kwargs) case _: raise NotImplementedError('') if self.save_path is not None: ret.save(self.save_path) return ret def run_cellular(self, **kwargs) -> RasterMapResult: """cellular input run. `Array[float, [N, T]]`""" try: import rastermap.utils from rastermap import Rastermap except ImportError: raise RuntimeError('pip install rastermap first') opt = self.options model = Rastermap(n_clusters=opt['n_clusters'], n_PCs=opt['n_PCs'], locality=opt['locality'], time_lag_window=opt['time_lag_window'], grid_upsample=opt['grid_upsample'], **kwargs) model.fit(self.neural_activity) embedding = model.embedding isort = model.isort sn = rastermap.utils.bin1d(self.neural_activity[isort], bin_size=self.bin_size, axis=0) ret = RasterMapResult( filename=self.filename, save_path=self.save_path, isort=isort, embedding=embedding, ops=opt, user_clusters=[], super_neurons=sn ) return ret def run_wfield(self, **kwargs): """wfield input run. `Array[Any, [T, H, W]]`""" from neuralib.widefield import compute_singular_vector # pyright: ignore[reportMissingImports] try: import rastermap.utils from rastermap import Rastermap except ImportError: raise RuntimeError('pip install rastermap first') # opt = self.options model = Rastermap( n_clusters=opt['n_clusters'], n_PCs=opt['n_PCs'], locality=opt['locality'], time_lag_window=opt['time_lag_window'], grid_upsample=opt['grid_upsample'], **kwargs ) # svd cache if self.svd_cache is not None: file = Path(self.svd_cache) if file.exists() and not self.invalid_svd_cache: with file.open("rb") as f: sv = pickle.load(f) else: sv = compute_singular_vector(self.neural_activity, self.svd_components) with file.open('wb') as f: pickle.dump(sv, f) else: sv = compute_singular_vector(self.neural_activity, self.svd_components) usv = sv.singular_value * sv.left_vector vsv = sv.right_vector # fit model.fit(Usv=usv, Vsv=vsv) embedding = model.embedding isort = model.isort Vsv_sub = model.Vsv U_sn = rastermap.utils.bin1d(sv.left_vector[isort], bin_size=self.bin_size, axis=0) sn = U_sn @ Vsv_sub.T sn = zscore(sn, axis=1) ret = RasterMapResult( filename=self.filename, save_path=self.save_path, isort=isort, embedding=embedding, ops=opt, user_clusters=[], super_neurons=cast(np.ndarray, sn) ) return ret