import re
import zipfile
from pathlib import Path
from typing import Literal, Required, TypedDict, cast, get_args
from urllib.request import urlopen
import numpy as np
import requests
import tensorflow as tf
import tensorflow.keras # pyright: ignore[reportMissingModuleSource]
from neuralib.io import CASCADE_MODEL_CACHE_DIRECTORY
from neuralib.typing import PathLike
from neuralib.util.utils import ensure_dir
from neuralib.util.verbose import fprint
from ruamel.yaml import YAML
from scipy.ndimage import binary_dilation, gaussian_filter
__all__ = [
'CASCADE_MODEL_TYPE',
'CascadeModelConfig',
'cascade_predict',
'CascadeSpikePrediction'
]
CASCADE_MODEL_TYPE = Literal[
'Global_EXC_1Hz_smoothing500ms',
'Global_EXC_1Hz_smoothing1000ms',
'Zebrafish_1Hz_smoothing1000ms',
'Global_EXC_2Hz_smoothing300ms',
'Global_EXC_2Hz_smoothing500ms',
'Global_EXC_2Hz_smoothing1000ms',
'Global_EXC_2.5Hz_smoothing400ms_high_noise',
'Global_EXC_3Hz_smoothing400ms',
'Global_EXC_3Hz_smoothing400ms_high_noise',
'Global_EXC_3Hz_smoothing400ms_causalkernel',
'Global_EXC_4.25Hz_smoothing300ms',
'Global_EXC_4.25Hz_smoothing300ms_high_noise',
'Global_EXC_4.25Hz_smoothing300ms_causalkernel',
'Global_EXC_5Hz_smoothing200ms',
'Global_EXC_5Hz_smoothing200ms_causalkernel',
'Global_EXC_6Hz_smoothing200ms',
'Global_EXC_6Hz_smoothing200ms_causalkernel',
'Global_EXC_7Hz_smoothing200ms',
'Global_EXC_7Hz_smoothing200ms_causalkernel',
'Global_EXC_7.5Hz_smoothing200ms_high_noise',
'Global_EXC_7.5Hz_smoothing200ms',
'Global_EXC_7.5Hz_smoothing200ms_causalkernel',
'OGB_zf_pDp_7.5Hz_smoothing200ms',
'OGB_zf_pDp_7.5Hz_smoothing200ms_causalkernel',
'Global_EXC_10Hz_smoothing50ms',
'Global_EXC_10Hz_smoothing50ms_causalkernel',
'Global_EXC_10Hz_smoothing100ms',
'Global_EXC_10Hz_smoothing100ms_causalkernel',
'Global_EXC_10Hz_smoothing200ms',
'Global_EXC_10Hz_smoothing200ms_causalkernel',
'Global_EXC_12.5Hz_smoothing100ms',
'Global_EXC_12.5Hz_smoothing100ms_causalkernel',
'Global_EXC_12.5Hz_smoothing200ms',
'Global_EXC_12.5Hz_smoothing200ms_causalkernel',
'Global_EXC_15Hz_smoothing50ms',
'Global_EXC_15Hz_smoothing50ms_causalkernel',
'Global_EXC_15Hz_smoothing100ms_high_noise',
'Global_EXC_15Hz_smoothing100ms',
'Global_EXC_15Hz_smoothing100ms_causalkernel',
'Global_EXC_15Hz_smoothing200ms',
'Global_EXC_15Hz_smoothing200ms_causalkernel',
'Global_INH_15Hz_smoothing100ms',
'Global_EXC_17.5Hz_smoothing100ms',
'Global_EXC_17.5Hz_smoothing100ms_causalkernel',
'Global_EXC_17.5Hz_smoothing200ms',
'Global_EXC_17.5Hz_smoothing200ms_causalkernel',
'Global_EXC_20Hz_smoothing100ms',
'Global_EXC_20Hz_smoothing100ms_causalkernel',
'Global_EXC_20Hz_smoothing200ms',
'Global_EXC_20Hz_smoothing200ms_causalkernel',
'Global_EXC_25Hz_smoothing100ms',
'Global_EXC_25Hz_smoothing100ms_causalkernel',
'Global_EXC_25Hz_smoothing50ms',
'Global_EXC_25Hz_smoothing50ms_causalkernel',
'Global_EXC_30Hz_smoothing25ms',
'Global_EXC_30Hz_smoothing25ms_causalkernel',
'Global_EXC_30Hz_smoothing50ms',
'Global_EXC_30Hz_smoothing50ms_high_noise',
'Global_EXC_30Hz_smoothing50ms_causalkernel',
'Global_EXC_30Hz_smoothing100ms',
'Global_EXC_30Hz_smoothing100ms_causalkernel',
'Global_EXC_30Hz_smoothing200ms',
'Global_EXC_30Hz_smoothing100ms_causalkernel_high_noise',
'Global_EXC_30Hz_smoothing100ms_high_noise',
'Global_EXC_30Hz_smoothing200ms_causalkernel_high_noise',
'Global_EXC_40Hz_smoothing25ms_causalkernel',
'Global_EXC_40Hz_smoothing25ms',
'Global_EXC_40Hz_smoothing25ms_high_noise',
'Global_EXC_40Hz_smoothing50ms',
'Global_EXC_40Hz_smoothing50ms_high_noise',
'Global_EXC_40Hz_smoothing50ms_causalkernel',
'Global_INH_30Hz_smoothing50ms',
'Global_INH_30Hz_smoothing100ms',
'Global_EXC_30Hz_smoothing50ms_asymmetric_window_1_frame',
'Global_EXC_30Hz_smoothing50ms_asymmetric_window_2_frames',
'Global_EXC_30Hz_smoothing50ms_asymmetric_window_4_frames',
'Global_EXC_30Hz_smoothing50ms_asymmetric_window_6_frames',
'Global_EXC_30Hz_smoothing50ms_asymmetric_window_8_frames',
'GCaMP6f_mouse_30Hz_smoothing200ms',
'Spinal_cord_excitatory_30Hz_smoothing50ms',
'Spinal_cord_inhibitory_30Hz_smoothing50ms',
'Spinal_cord_excitatory_3Hz_smoothing400ms_high_noise',
'Spinal_cord_inhibitory_3Hz_smoothing400ms_high_noise',
'Spinal_cord_excitatory_2.5Hz_smoothing400ms',
'Spinal_cord_inhibitory_2.5Hz_smoothing400ms',
'GC8_EXC_5Hz_smoothing400ms_high_noise',
'GC8_EXC_5Hz_smoothing800ms_high_noise',
'GC8_EXC_7.5Hz_smoothing100ms_high_noise',
'GC8_EXC_7.5Hz_smoothing200ms_high_noise',
'GC8_EXC_10Hz_smoothing150ms_high_noise',
'GC8_EXC_10Hz_smoothing75ms_high_noise',
'GC8_EXC_15Hz_smoothing100ms_high_noise',
'GC8_EXC_15Hz_smoothing50ms_high_noise',
'GC8_EXC_30Hz_smoothing25ms_high_noise',
'GC8_EXC_30Hz_smoothing50ms_high_noise',
'GC8_EXC_40Hz_smoothing15ms_high_noise',
'GC8_EXC_40Hz_smoothing30ms_high_noise'
]
[docs]
def cascade_predict(dff: np.ndarray,
model_type: CASCADE_MODEL_TYPE,
*,
threshold: int | bool = 0,
padding: float = 0,
verbose: bool = True,
chunks_mode_limit: float = 10,
cache_dir: PathLike | None = None) -> np.ndarray:
"""
Spike prediction using Cascade pretrained model
:param dff: dF/F activity to be predicted. `Array[float, [N, F]|F]`
:param model_type: ``MODEL_TYPE``
:param threshold: Allowed values: 0, 1 or False.
0: All negative values are set to 0.
1 or True: Threshold signal to set every signal which is smaller than the expected signal size of an action potential to zero (with dilated mask)
False: No thresholding. The result can contain negative values as well
:param padding: Value which is inserted for datapoints, where no prediction can be made
(because of window around timepoint of prediction).
Default value: np.nan, another recommended value would be 0 which circumvents some problems with following analysis.
:param verbose: Verbose of model information
:param chunks_mode_limit: Decrease the number if memory issue dealing with large input arrays.
:param cache_dir: Cache directory for saving the model. If None, then used default under `~/.cache/neuralib`
:return: Spiking probability as predicted by the model. `Array[float, [N, F]|F]`
"""
cascade = CascadeSpikePrediction(
dff,
model_type,
threshold=threshold,
padding=padding,
verbose=verbose,
chunks_mode_limit=chunks_mode_limit,
cache_dir=cache_dir
)
spike = cascade.run_spike_prediction()
return spike
[docs]
class CascadeModelConfig(TypedDict, total=False):
model_name: str
"""Name of the model"""
sampling_rate: Required[int]
"""Sampling rate in Hz"""
training_datasets: Required[list[str]]
"""Dataset of ground truth data (in folder 'Ground_truth')"""
placeholder_1: int
"""protect formatting"""
noise_levels: Required[list[int]]
"""Noise levels for training (integers, normally 1-9)"""
placeholder_2: int
"""protect formatting"""
smoothing: Required[float]
"""Standard deviation of Gaussian smoothing in time (sec)"""
causal_kernel: Required[int]
"""Smoothing kernel is symmetric in time (0) or is causal (1)"""
windowsize: Required[int]
"""Windowsize in timepoints"""
before_frac: Required[float]
"""Fraction of timepoints before prediction point (0-1)"""
filter_sizes: list[int]
"""Filter sizes for each convolutional layer"""
filter_numbers: list[int]
"""Filter numbers for each convolutional layer"""
dense_expansion: int
"""For dense layer"""
loss_function: str
"""gradient-descent loss function"""
optimizer: str
"""Adagrad"""
nr_of_epochs: int
"""Number of training epochs per model"""
ensemble_size: Required[int]
"""Number of models trained for one noise level"""
batch_size: Required[int]
"""Batch size"""
training_finished: Literal['Yes', 'No', 'Running']
"""Yes / No / Running"""
verbose: Required[int]
"""level of status messages (0: minimal, 1: standard, 2: most, 3: all)"""
[docs]
class CascadeSpikePrediction:
[docs]
def __init__(
self,
dff: np.ndarray,
model_type: CASCADE_MODEL_TYPE,
*,
threshold: int | bool = 0,
padding: float = 0,
verbose: bool = True,
chunks_mode_limit: float = 10,
cache_dir: PathLike | None = None
):
"""
:param dff: dF/F activity to be predicted. `Array[float, [N, F]|F]`
:param model_type: ``MODEL_TYPE``
:param threshold: Allowed values: 0, 1 or False.
0: All negative values are set to 0.
1 or True: Threshold signal to set every signal which is smaller than the expected signal size of an action potential to zero (with dilated mask)
False: No thresholding. The result can contain negative values as well
:param padding: Value which is inserted for datapoints, where no prediction can be made (because of window around timepoint of prediction).
Default value: np.nan, another recommended value would be 0 which circumvents some problems with following analysis.
:param verbose: Verbose of model information
:param chunks_mode_limit: Decrease the number if memory issue dealing with large input arrays.
:param cache_dir: Cache directory for saving the model. If None, then used default under `~/.cache/neuralib`
"""
self.dff = dff
self.model_type = model_type
if cache_dir is not None:
cache_dir = Path(cache_dir)
self.cache_dir = cache_dir or CASCADE_MODEL_CACHE_DIRECTORY
# model io
if not self.available_model_yaml.exists():
self._download_model_yaml()
self._check_instance()
if not self.model_dir.exists():
self._download_model()
# predict instance
self.threshold = threshold
self.padding = padding
self.verbose = verbose
self.chunks_mode_limit = chunks_mode_limit
def _download_model_yaml(self):
url = 'https://raw.githubusercontent.com/HelmchenLabSoftware/Cascade/master/Pretrained_models/available_models.yaml'
response = requests.get(url)
if response.status_code == 200:
ensure_dir(self.cache_dir)
out = self.available_model_yaml
with open(out, 'wb') as file:
file.write(response.content)
fprint(f'YAML file saved successfully as {out}', vtype='io')
else:
raise RuntimeError(f'Failed to download the file. Status code: {response.status_code}')
def _check_instance(self):
available_models = self.get_available_models()
if self.model_type not in available_models:
raise ValueError(f'{self.model_type} should be one of the {available_models}')
# check update
if set(available_models) != set(get_args(CASCADE_MODEL_TYPE)):
fprint('update MODEL_TYPE, new model released', vtype='warning')
def _download_model(self):
link = self.model_link
with urlopen(link) as response:
data = response.read()
tmp_file = self.cache_dir / 'tmp_zipped_model.zip'
with open(tmp_file, 'wb') as f:
f.write(data)
with zipfile.ZipFile(tmp_file, "r") as zip_ref:
zip_ref.extractall(path=self.model_dir)
tmp_file.unlink()
fprint(f'Pretrained model was saved in folder {self.model_dir}', vtype='io')
# ============== #
# All Model Info #
# ============== #
@property
def available_model_yaml(self) -> Path:
"""All models link/info yaml"""
return self.cache_dir / 'available_models.yaml'
[docs]
def get_available_models(self) -> list[CASCADE_MODEL_TYPE]:
"""Get all the available in :attr:`available_model_yaml`"""
content = YAML().load(self.available_model_yaml)
models = list(content.keys())
return models
# ============== #
# A Model Config #
# ============== #
@property
def model_link(self) -> str:
"""Link of specified model"""
with open(self.available_model_yaml) as file:
config = YAML().load(file)
return config[self.model_type]['Link']
@property
def model_dir(self) -> Path:
"""Directory of specified model"""
return self.cache_dir / self.model_type
@property
def config_file(self):
"""Config filepath of specified model"""
return self.model_dir / 'config.yaml'
[docs]
def get_config(self) -> CascadeModelConfig:
"""``ModelConfig`` of specified model"""
with open(self.config_file) as file:
return YAML().load(file)
# ================ #
# Spike Prediction #
# ================ #
[docs]
def run_spike_prediction(self) -> np.ndarray:
"""
Spike prediction
:return: Spiking probability as predicted by the model. `Array[float, [N, F]|F]`
"""
total_array_size = self.dff.itemsize * self.dff.size * 64 / 1e9
if total_array_size < self.chunks_mode_limit:
spike = self._predict(self.dff, self.threshold, self.padding, self.verbose)
else:
n_neurons = self.dff.shape[0]
n_frames = self.dff.shape[1]
spike = np.zeros((n_neurons, n_frames))
nb_chunks = int(np.ceil(total_array_size / 10))
chunks = np.array_split(range(n_neurons), nb_chunks)
for part in range(nb_chunks):
part_dff = self.dff[chunks[part], :]
spike[chunks[part], :] = self._predict(part_dff, self.threshold, self.padding, self.verbose)
return spike
def _predict(self, dff, threshold=0, padding: float = np.nan, verbose: bool = True):
# reshape if only a single neuron's activity is provided
if len(dff.shape) == 1:
dff = np.expand_dims(dff, 0)
cfg = self.get_config()
cfg_verbose = cfg["verbose"]
training_data = cfg["training_datasets"]
ensemble_size = cfg["ensemble_size"]
batch_size = cfg["batch_size"]
sampling_rate = cfg["sampling_rate"]
before_frac = cfg["before_frac"]
window_size = cfg["windowsize"]
noise_levels_model = cfg["noise_levels"]
smoothing = cfg["smoothing"]
causal_kernel = cfg["causal_kernel"]
# calculate noise levels for each trace
trace_noise_levels = calculate_noise_levels(dff, sampling_rate)
# Get model paths as dictionary (key: noise_level) with lists of model path
model_dict = get_model_paths(self.model_dir)
if cfg_verbose:
msg = (f'The selected model was trained on {len(training_data)} datasets, '
f'with {ensemble_size} ensembles for each noise level, at a sampling rate of {sampling_rate} Hz')
if causal_kernel:
msg += ", with a resampled ground truth that was smoothed with a causal kernel"
else:
msg += ", with a resampled ground truth that was smoothed with a Gaussian kernel"
msg += f'of a standard deviation of {str(int(1000 * smoothing))} ms.'
fprint(msg)
fprint(f'Loaded model was trained at frame rate {sampling_rate} Hz')
fprint(f'Given argument traces contains {dff.shape[0]} neurons and {dff.shape[1]} frames.')
noise_mean = str(int(np.nanmean(trace_noise_levels * 100)) / 100)
noise_std = str(int(np.nanstd(trace_noise_levels * 100)) / 100)
fprint(f'Noise levels (mean, std; in standard units): {noise_mean}, {noise_std}')
# XX has shape: (neurons, timepoints, windows)
XX = preprocess_traces(dff, before_frac=before_frac, window_size=window_size)
Y_predict = np.zeros((XX.shape[0], XX.shape[1]))
# Compute difference of noise levels between each neuron and each model; find the best fit
differences = np.array(trace_noise_levels)[:, None] - np.array(noise_levels_model)[None, :]
relative_differences = np.min(differences, axis=1)
if np.mean(relative_differences) > 2:
fprint(
f"WARNING: The available models cannot match the experimentally obtained noise levels (difference: {str(np.mean(relative_differences))}),"
f"Please check that the computation of dF/F is performed correctly. Otherwise, please reach out and ask for pretrained models with higher noise level models "
f"(see: https://github.com/HelmchenLabSoftware/Cascade/issues/61).", vtype='warning'
)
best_model_for_each_neuron = np.argmin(np.abs(differences), axis=1)
# Use for each noise level the matching model
for i, noise_level in enumerate(noise_levels_model):
# select neurons which have this noise level:
neuron_idx = np.where(best_model_for_each_neuron == i)[0]
if cfg_verbose:
print(f'\nPredictions for noise level {noise_level}')
if len(neuron_idx) == 0: # no neurons were selected
if cfg_verbose:
print(f"\tNo neurons for this noise level: {noise_level}")
continue # jump to next noise level
# load keras models for the given noise level
models = []
for model_path in model_dict[noise_level]:
models.append(tf.keras.models.load_model(model_path))
# select neurons and merge neurons and timepoints into one dimension
XX_sel = XX[neuron_idx, :, :]
XX_sel = np.reshape(XX_sel, (XX_sel.shape[0] * XX_sel.shape[1], XX_sel.shape[2]))
XX_sel = np.expand_dims(XX_sel, axis=2) # add empty third dimension to match training shape
for j, model in enumerate(models):
if cfg_verbose:
print("\t... ensemble", j)
prediction_flat = model.predict(XX_sel, batch_size, verbose=cfg_verbose)
prediction = np.reshape(prediction_flat, (len(neuron_idx), XX.shape[1]))
Y_predict[neuron_idx, :] += prediction / len(models) # average predictions
# remove models from memory
tensorflow.keras.backend.clear_session()
# handle threshold
if threshold is False:
fprint('Skipping the thresholding. There can be negative values in the result.', vtype='warning')
elif threshold == 1: # or True
# Cut off noise floor (lower than 1/e of a single action potential)
# find out empirically how large a single AP is (depends on frame rate and smoothing)
single_spike = np.zeros(1001)
single_spike[501] = 1
single_spike_smoothed = gaussian_filter(single_spike.astype(float), sigma=smoothing * sampling_rate)
threshold_value = np.max(single_spike_smoothed) / np.exp(1)
# Set everything below threshold to zero.
# Use binary dilation to avoid clipping of true events.
for neuron in range(Y_predict.shape[0]):
# ignore warning because of nan's in Y_predict in comparison with value
with np.errstate(invalid="ignore"):
activity_mask = Y_predict[neuron, :] > threshold_value
activity_mask = cast(np.ndarray, binary_dilation(activity_mask, iterations=int(smoothing * sampling_rate))).astype(bool)
Y_predict[neuron, ~activity_mask] = 0
Y_predict[Y_predict < 0] = 0 # set possible negative values in dilated mask to 0
elif threshold == 0:
# ignore warning because of nan's in Y_predict in comparison with value
with np.errstate(invalid="ignore"):
Y_predict[Y_predict < 0] = 0
else:
raise ValueError(f'Invalid value of threshold "{threshold}". Only 0, 1 (or True) or False allowed')
# NaN or 0 for first and last datapoints, for which no predictions can be made
Y_predict[:, 0: int(before_frac * window_size)] = padding
Y_predict[:, -int((1 - before_frac) * window_size):] = padding
return Y_predict
def preprocess_traces(dff: np.ndarray,
before_frac: float,
window_size: int) -> np.ndarray:
"""
Transform dF/F data into a format that can be used by the deep network.
For each time point, a window of the size 'window_size' of the dF/F is extracted.
:param dff: `Array[float, [N, F]]`
:param before_frac: positioning of the window around the current time point; 0.5 means center position
:param window_size: size of the receptive window of the deep network
:return: a matrix with `Array[float, [N, F, W]]`
"""
start = int(before_frac * window_size - 1)
end = dff.shape[1] - window_size + start + 1
# extract a moving window from the calcium trace
window_indexes = (
np.expand_dims(np.arange(window_size), 0) +
np.expand_dims(np.arange(dff.shape[1] - window_size + 1), 0).T
)
X = np.full((*dff.shape, window_size), np.nan)
X[:, start:end, :] = dff[:, window_indexes]
return X
def get_model_paths(model_path: Path) -> dict[int, list[Path]]:
"""Find all models in the model folder and return as dictionary"""
all_models = sorted(list(model_path.glob('*.h5')))
if len(all_models) == 0:
raise FileNotFoundError(f'No models (*.h5 files) were found in the specified folder "{model_path}".')
# dictionary with key for noise level, entries are lists of models
model_dict = {}
for model_path in all_models:
noise_level = int(re.findall(r"_NoiseLevel_(\d+)", str(model_path))[0])
if noise_level not in model_dict:
model_dict[noise_level] = list()
model_dict[noise_level].append(model_path)
return model_dict
def calculate_noise_levels(dff: np.ndarray, frame_rate: float) -> np.ndarray:
"""
Computes the noise levels for each neuron of the input matrix 'dF_traces'.
The noise level is computed as the median absolute dF/F difference
between two subsequent time points. This is a outlier-robust measurement
that converges to the simple standard deviation of the dF/F trace for
uncorrelated and outlier-free dF/F traces.
Afterwards, the value is divided by the square root of the frame rate
in order to make it comparable across recordings with different frame rates.
:param dff: Fluorescence changed dF/F. `Array[float, [N, F]]`
:param frame_rate: frame rate
:return: vector of noise levels for all neurons
"""
noise_levels = np.nanmedian(np.abs(np.diff(dff, axis=-1)), axis=-1) / np.sqrt(frame_rate)
return noise_levels * 100 # scale noise levels to percent