Source code for neuralib.rastermap.plot

from __future__ import annotations

from typing import Literal, NamedTuple, Self, cast

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.patches import Patch
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from neuralib.plot import plot_figure
from neuralib.rastermap import RasterMapResult
from neuralib.typing import PathLike
from scipy.interpolate import interp1d

__all__ = [
    'plot_rastermap',
    'plot_cellular_spatial',
    'plot_wfield_spatial',
    'RasterMapPlot',
    'Covariant'
]


[docs] def plot_rastermap(result: RasterMapResult, act_time: np.ndarray, *, time_range: tuple[float, float] | None = None, covars: list[Covariant] | None = None, figsize: tuple[float, float] = (8, 6), event_colors: dict[str, str] | None = None, output: PathLike | None = None): """ plot the rastermap result with behavioral measurements :param result: :class:`~.core.RasterMapResult` :param act_time: neural activity time array. should be the same T as neural_activity when run the rastermap :param time_range: time range for plotting (START,END) :param covars: list of :class:`~Covariant` :param figsize: figure size :param event_colors: event color dict. {event_name: color} :param output: output path for figure save. If None then show """ plotter = RasterMapPlot(result, act_time, time_range, covars) plotter.plot_rastermap(figsize, event_colors, output)
[docs] def plot_cellular_spatial(result: RasterMapResult, xpos: np.ndarray, ypos: np.ndarray, ax: Axes | None = None, output: PathLike | None = None, **kwargs): """ Plot spatial location of each cell cluster by rastermap :param result: :class:`~.core.RasterMapResult` :param xpos: soma central X position.`Array[float, N]` :param ypos: soma central Y position.`Array[float, N]` :param ax: ``Axes`` :param output: output path for figure save. If None then show :param kwargs: additional arguments pass to ``ax.set()`` :return: """ if ax is None: _, ax = plt.subplots() ax.scatter(xpos, ypos, s=8, c=result.embedding, cmap="gist_ncar", alpha=0.25) ax.invert_yaxis() ax.set(**kwargs) ax.set_aspect('equal') if output is not None: plt.savefig(output) else: plt.show()
[docs] def plot_wfield_spatial(result: RasterMapResult, width: int, height: int, ax: Axes | None = None, output: PathLike | None = None, **kwargs): """ Plot spatial location of each pixel cluster by rastermap :param result: :class:`~.core.RasterMapResult` :param width: sequence image width :param height: sequence image height :param ax: ``Axes`` :param output: output path for figure save. If None then show :param kwargs: additional arguments pass to ``ax.set()`` """ if ax is None: _, ax = plt.subplots() x = np.arange(width) y = np.arange(height) xpos, ypos = np.meshgrid(x, y) # Array[float, [W, H]] ax.scatter(xpos, ypos, s=1, c=result.embedding, cmap="gist_ncar", alpha=0.25) ax.invert_yaxis() ax.set(**kwargs) ax.set_aspect('equal') if output is not None: plt.savefig(output) else: plt.show()
[docs] class RasterMapPlot: """Plot the rastermap result with behavioral measurements"""
[docs] def __init__(self, result: RasterMapResult, act_time: np.ndarray, time_range: tuple[float, float] | None = None, covars: list[Covariant] | None = None): """ :param result: class:`~.core.RasterMapResult` :param act_time: neural activity time array. should be the same T as neural_activity when run the rastermap :param time_range: time range for plotting (START,END) :param covars: list of :class:`~Covariant` """ self.raster = result self.covars = covars self._covars_check() self.time_range = time_range or (act_time[0], act_time[-1]) if time_range is not None: self.act_mask = np.logical_and(time_range[0] < act_time, act_time < time_range[1]) else: self.act_mask = np.ones_like(act_time, dtype=np.bool_) self.act_time = act_time[self.act_mask]
def _covars_check(self): if self.covars is not None: for cov in self.covars: if cov.dtype == 'continuous': if cov.value is None: raise ValueError('value is required for continuous covariant') assert cov.time.shape == cov.value.shape, 'time/value shape mismatch' elif cov.dtype == 'event': assert cov.time.shape[1] == 2, f'event time shape should be [E, 2]: {cov.time.shape}' else: raise ValueError(f'unknown dtype: {cov.dtype}') @property def super_neurons(self) -> np.ndarray: """rastermap sorted 2D array. `Array[float, [N, T]]`""" if self.raster.super_neurons is None: raise RuntimeError('rastermap result has no super neuron data') return self.raster.super_neurons[:, self.act_mask]
[docs] def process_continuous(self) -> list[Covariant]: """process behavioral measurements, select time range and do the interpolation same shape as neural activity""" if self.covars is None: return [] return [ cov.masking_time(self.time_range).interp_activity(self.act_time) for cov in self.covars if cov.dtype == 'continuous' ]
[docs] def plot_rastermap(self, figsize: tuple[float, float] = (8, 6), event_colors: dict[str, str] | None = None, output: PathLike | None = None): if self.covars is not None: covars = self.process_continuous() n_covars = len(covars) else: covars = [] n_covars = 1 height_ratios = [1] * n_covars + [7] with plot_figure(output, n_covars + 1, 1, figsize=figsize, gridspec_kw={'height_ratios': height_ratios}, tight_layout=False, sharex=True) as _ax: axes = _ax.ravel() if isinstance(_ax, np.ndarray) else [_ax] # continuous dtype if self.covars is not None: for i, cov in enumerate(covars): if cov.dtype == 'continuous': if cov.value is None: raise ValueError('value is required for continuous covariant') ax = cast(Axes, axes[i]) ax.plot(cov.time, cov.value, color='k') ax.set_xlim(self.time_range) ax.axis('off') ax.set_title(cov.name) else: cast(Axes, axes[0]).axis('off') # rastermap ax = cast(Axes, axes[n_covars]) ax.imshow( self.super_neurons, cmap='gray_r', vmin=0, vmax=0.8, aspect='auto', interpolation='none', extent=(self.time_range[0], self.time_range[1], self.raster.n_clusters, 0), ) ax.set(xlabel='time(s)', ylabel='rastermap clusters') # event segments if self.covars is not None: self.plot_segments(ax=ax, event_colors=event_colors) # colormap n_clusters = self.raster.n_clusters cluster_colors = plt.get_cmap('gist_ncar', n_clusters) cb_ax = inset_axes(ax, width='2%', height='100%', loc='right', bbox_to_anchor=(0.05, 0., 1, 1), bbox_transform=ax.transAxes, borderpad=0) cb_ax.imshow( np.arange(n_clusters)[:, np.newaxis], cmap=cluster_colors, aspect='auto' ) cb_ax.axis('off')
[docs] def plot_segments(self, ax: Axes, event_colors: dict[str, str] | None = None): """ Plot event segments as vertical spans on the axis :param ax: matplotlib Axes object :param event_colors: optional mapping from event names to colors """ legend_patches = [] event_colors = event_colors or {} if self.covars is None: return for i, cov in enumerate(self.covars): if cov.dtype == 'event': color = event_colors.get(cov.name, None) for start, end in cov.time: if self.time_range is not None: t0 = self.time_range[0] t1 = self.time_range[1] if end < t0 or start > t1: continue # outside view start = max(start, t0) end = min(end, t1) ax.axvspan(start, end, color=color, alpha=0.4) legend_patches.append(Patch(facecolor=color, alpha=0.4, label=cov.name)) if legend_patches: ax.legend(handles=legend_patches, loc='upper right', fontsize=10, frameon=False)
[docs] class Covariant(NamedTuple): """ Covariant variable that can be plotted alongside rastermap results. Supports two types: - ``'continuous'``: time-series data (e.g., velocity, position) - ``'event'``: discrete time segments (e.g., trial periods, behavioral events) :param name: name of the covariant variable :param dtype: type of the covariant variable (``'event'`` or ``'continuous'``) :param time: time array. ``Array[float, T]`` for continuous dtype or ``Array[float, [E, 2]]`` for on/off event dtype :param value: value array (only for continuous dtype). ``Array[float, T]`` """ name: str dtype: Literal['event', 'continuous'] time: np.ndarray value: np.ndarray | None = None
[docs] def masking_time(self, t: tuple[float, float]) -> Self: """ Mask data to a specific time range (continuous dtype only). :param t: (START,END) time range :return: new Covariant with data filtered to the time range :raises ValueError: if called on event dtype """ if self.dtype == 'event': raise ValueError('method only available for continuous dtype') if self.value is None: raise ValueError('value is required for continuous dtype') mx = np.logical_and(t[0] < self.time, self.time < t[1]) return self._replace(time=self.time[mx], value=self.value[mx])
[docs] def interp_activity(self, act_time: np.ndarray) -> Self: """ Interpolate data to match another activity time array (continuous dtype only). :param act_time: activity time array to interpolate to. `Array[float, T']` :return: new Covariant with interpolated values matching act_time :raises ValueError: if called on event dtype """ if self.dtype == 'event': raise ValueError('method only available for continuous dtype') if self.value is None: raise ValueError('value is required for continuous dtype') v = interp1d(self.time, self.value, bounds_error=False, fill_value=0)(act_time) return self._replace(time=act_time, value=v)