from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Literal, Self, cast
import numpy as np
import polars as pl
from neuralib.atlas.cellatlas import load_cellatlas
from neuralib.atlas.data import load_bg_volumes
from neuralib.atlas.map import DEFAULT_FAMILY_DICT, NUM_MERGE_LAYER, merge_until_level
from neuralib.atlas.typing import HEMISPHERE_TYPE, Channel
from neuralib.atlas.util import get_margin_merge_level
from neuralib.typing import PathLike
from neuralib.util.dataframe import DataFrameWrapper
from neuralib.util.utils import ensure_dir
from neuralib.util.verbose import print_load, print_save
from polars.exceptions import ColumnNotFoundError
# public API exported by ``from ... import *``
__all__ = [
    'ROIS_NORM_TYPE',  # normalization literal type
    'RoiClassifierDataFrame',  # per-roi table
    'RoiNormalizedDataFrame',  # per-area normalized table
    'RoiSubregionDataFrame',  # per-source subregion table
]
ROIS_NORM_TYPE = Literal[
    'channel',  # per-source fraction (%)
    'volume',  # counts per brain-region volume
    'cell',  # counts per number of neurons in the region
    'none',  # raw counts
]
"""Roi normalized type (normalization method used by ``RoiClassifierDataFrame.to_normalized``)"""
[docs]
class RoiClassifierDataFrame(DataFrameWrapper):
"""
RoiClassifierDataFrame with each roi (rows)
- Required fields:
- ``acronym`` - area acronym
- ``AP_location``: anterior-posterior coordinates (mm)
- ``DV_location``: dorsal-ventral coordinates (mm)
- ``ML_location``: medial-lateral coordinates (mm)
- ``channel``: fluorescence channel (i.e., gfp, rfp, mcherry, ...)
- ``source``: source name (i.e., if circuit tracing, you can give source tracing area...)
"""
_required_fields = ('acronym', 'AP_location', 'DV_location', 'ML_location', 'channel', 'source')
_valid_classified_fields = ('acronym', 'tree_0', 'tree_1', 'tree_2', 'tree_3', 'tree_4', 'family')
[docs]
def __init__(self, df: pl.DataFrame, *,
cached_dir: PathLike | None = None,
invalid_post_processing_cache: bool = False):
"""
:param df: DataFrame with required fields
:param cached_dir: create cached directory
:param invalid_post_processing_cache: invalid post processing cache if cached_dir is not None
"""
self._df = df
self._cached_dir = cached_dir
self._invalid_post_processing_cache = invalid_post_processing_cache
for field in self._required_fields:
if field not in df.columns:
raise RuntimeError(f'field not found: {field} -> {df.columns}')
self.__allow_inplace = True
def __repr__(self):
return repr(self.dataframe())
[docs]
def dataframe(self, dataframe: pl.DataFrame | None = None, may_inplace: bool = True) -> pl.DataFrame | RoiClassifierDataFrame: # pyright: ignore[reportIncompatibleMethodOverride]
if dataframe is None:
return self._df
else:
return RoiClassifierDataFrame(dataframe, cached_dir=self._cached_dir)
@property
def channels(self) -> list[Channel]:
"""list of channel names"""
return self.dataframe()['channel'].unique().to_list()
@property
def n_channels(self) -> int:
"""number of channel"""
return len(self.channels)
@property
def channel_counts(self) -> pl.DataFrame:
"""channel counts dataframe"""
return self['channel'].value_counts()
@property
def source_counts(self) -> pl.DataFrame:
"""source counts dataframe"""
return self['source'].value_counts()
@property
def sources(self) -> list[Channel]:
"""list of source names"""
return self.dataframe()['source'].unique().to_list()
@property
def n_sources(self) -> int:
"""number of source"""
return len(self.sources)
@property
def is_overlapped_channel(self) -> bool:
"""whether there is overlapped channel"""
return 'overlap' in self.channels
[docs]
def get_channel_source_dict(self) -> dict[str, str]:
"""get channel (key): source (value) dict"""
return {
it[0]: it[1]
for it in self['channel', 'source'].unique().iter_rows()
}
[docs]
def get_classified_column(self, field: int | str | None, strict: bool = True) -> str:
"""classified column name, with the given
:param field: int: ``tree_[level]``; str: ``get column``; None: ``acronym``
:param strict: strict check for the pre-defined classified column
"""
match field:
case int():
col = f'tree_{field}'
case str():
col = field
case None:
col = 'acronym'
case _:
raise TypeError(f'{type(field)}')
if strict:
if col not in self._valid_classified_fields:
raise ValueError(f'invalid field: {field}')
if col not in self.dataframe().columns:
raise ColumnNotFoundError(f'{col} not found: {self.dataframe().columns}')
return col
[docs]
def post_processing(self, *,
filter_capital: bool = True,
tree: bool = True,
family: bool = True,
hemisphere: bool = True,
copy_overlap: bool = True,
filter_injection: tuple[str, str] | None = None) -> Self:
"""
Load the post-processing dataframe
:param filter_capital: filter only ``acronym`` contain capital letters
:param tree: with customized hierarchical tree structure based on allen brain
:param family: with column ``family`` ('HB', 'HY', 'TH', 'MB', 'CB', 'CTXpl', 'HPF', 'ISOCORTEX', 'OLF', 'CTXsp'
:param hemisphere: with column ``hemisphere`` with which hemisphere
:param copy_overlap: Copy overlap channels counts to individual channels, only set true if overlapped roi is not counted for individual channels
:param filter_injection: filter out the injection site labelled
:return:
"""
if self._cached_dir is not None:
file = ensure_dir(self._cached_dir) / 'parsed_roi.csv'
if file.exists() and self._invalid_post_processing_cache:
file.unlink()
# load
if file.exists():
df = pl.read_csv(file)
print_load(file)
return cast(Self, RoiClassifierDataFrame(
df, cached_dir=self._cached_dir,
invalid_post_processing_cache=self._invalid_post_processing_cache
))
# write
ret = self._post_processing(filter_capital, tree, family, hemisphere, copy_overlap, filter_injection)
ret._df.write_csv(file)
print_save(file)
return ret
# return
return self._post_processing(filter_capital, tree, family, hemisphere, copy_overlap, filter_injection)
def _post_processing(self, filter_capital, tree, family, hemisphere, copy_overlap, filter_injection) -> Self:
ret = self
if filter_capital:
ret = ret.filter_capital_name()
if tree:
ret = ret.with_tree_columns()
if family:
ret = ret.with_family_columns()
if hemisphere:
ret = ret.with_hemisphere_column()
if copy_overlap:
ret = ret.with_overlap_copy()
self.__allow_inplace = False
if filter_injection is not None:
ret = ret.filter_injection_site(area=filter_injection[0], hemisphere=filter_injection[1])
return ret
[docs]
def filter_injection_site(self, area: str, hemisphere: str) -> Self:
"""
filter out the injection site labelled
:param area: brain area
:param hemisphere: which hemisphere
:return:
"""
expr1 = pl.col('acronym').str.starts_with(area)
expr2 = pl.col('hemisphere') == hemisphere
return self.filter(~(expr1 & expr2))
[docs]
def filter_capital_name(self) -> Self:
"""filter only ``acronym`` contain capital letters"""
return self.filter(pl.col('acronym').str.contains(r'[A-Z]+'))
[docs]
def with_tree_columns(self) -> Self:
"""with customized hierarchical tree structure based on allen brain
.. seealso::
Wang et al 2020, https://doi.org/10.1016/j.cell.2020.04.007
"""
acronym = self['acronym']
return self.with_columns(
pl.Series(name=f'tree_{level}', values=merge_until_level(acronym, level))
for level in range(NUM_MERGE_LAYER)
)
[docs]
def with_family_columns(self) -> Self:
"""with column ``family`` ('HB', 'HY', 'TH', 'MB', 'CB', 'CTXpl', 'HPF', 'ISOCORTEX', 'OLF', 'CTXsp')"""
def get_family(row) -> str:
for name, family in DEFAULT_FAMILY_DICT.items():
if row in family:
return name
return 'unknown'
return self.with_columns(pl.col('tree_0').map_elements(get_family, return_dtype=pl.Utf8).alias('family'))
[docs]
def with_hemisphere_column(self, invert: bool = False) -> Self:
"""
with column ``hemisphere`` with which hemisphere
:param invert: invert hemisphere. Default ML >= 0 (ipsi), ML < 0 (contra)
:return:
"""
if invert:
expr = pl.when(pl.col('ML_location') < 0)
else:
expr = pl.when(pl.col('ML_location') >= 0)
return self.with_columns(expr.then(pl.lit('ipsi')).otherwise(pl.lit('contra')).alias('hemisphere'))
[docs]
def with_overlap_copy(self) -> Self:
"""Copy overlap channels counts to individual channels, only used if overlapped roi is not counted for individual channels"""
if not self.__allow_inplace:
raise RuntimeError('recurrent copy overlap')
ret: list[pl.DataFrame] = [self._df]
for channel, source in self.get_channel_source_dict().items():
if channel != 'overlap':
df = (
self._df.filter(pl.col('channel') == 'overlap')
.with_columns(pl.lit(channel).alias('channel'))
.with_columns(pl.lit(source).alias('source'))
)
ret.append(df)
return cast(Self, RoiClassifierDataFrame(pl.concat(ret), cached_dir=self._cached_dir))
# ==================== #
# Normalized DataFrame #
# ==================== #
[docs]
def to_normalized(self, norm: ROIS_NORM_TYPE,
level: int | str | None, *,
source: str | None = None,
top_area: int | None = None,
rest_as_others: bool = False,
hemisphere: HEMISPHERE_TYPE = 'both',
animal: str | None = None,
volume_norm_backend: Literal['cellatlas', 'brainglobe'] = 'cellatlas') -> RoiNormalizedDataFrame:
"""
To the normalized dataframe (example as volume normalized) ::
┌─────────┬────────┬────────┬───────────┬────────────┬────────────────┬────────────┐
│ source ┆ tree_2 ┆ counts ┆ fraction ┆ hemisphere ┆ Volumes [mm^3] ┆ normalized │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ u32 ┆ f64 ┆ str ┆ f64 ┆ f64 │
╞═════════╪════════╪════════╪═══════════╪════════════╪════════════════╪════════════╡
│ overlap ┆ ACA ┆ 1208 ┆ 29.997517 ┆ both ┆ 5.222484 ┆ 231.307537 │
│ pRSC ┆ ACA ┆ 3296 ┆ 22.822324 ┆ both ┆ 5.222484 ┆ 631.117254 │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ pRSC ┆ VIS ┆ 4035 ┆ 27.939344 ┆ both ┆ 12.957203 ┆ 311.409797 │
│ overlap ┆ VIS ┆ 628 ┆ 15.594736 ┆ both ┆ 12.957203 ┆ 48.46725 │
│ aRSC ┆ VIS ┆ 3865 ┆ 12.627005 ┆ both ┆ 12.957203 ┆ 298.289682 │
└─────────┴────────┴────────┴───────────┴────────────┴────────────────┴────────────┘
:param norm: :attr:`~neuralib.atlas.ccf.classifier.ROIS_NORM_TYPE`
:param level: tree level for determine which level of classified column
:param source: filter only the given ``source``
:param top_area: filter only the given top areas (sorted based on ``fraction``)
:param rest_as_others: vertical concat a region called **other** with the roi filtered by ``top area``
:param hemisphere: filter only the given ``hemisphere``
:param animal: with animal id column ``animal``
:param volume_norm_backend: volume normalization backend. {'cellatlas', 'brainglobe'}
:return: ``RoiNormalizedDataFrame``
"""
cls_col = self.get_classified_column(level)
if hemisphere != 'both':
cols = ['source', cls_col, 'hemisphere']
else:
cols = ['source', cls_col]
df = (
self.dataframe().select(cols).group_by(cols).agg(pl.col(cls_col).count().alias('counts'))
.with_columns((pl.col('counts') / pl.col('counts').sum().over('source') * 100).alias('fraction'))
.sort('fraction', descending=True)
)
#
if hemisphere != 'both':
df = df.filter(pl.col('hemisphere') == hemisphere)
else:
df = df.with_columns(pl.lit(hemisphere).alias('hemisphere')) # add back
#
if source is not None:
df = df.filter(pl.col('source') == source)
#
if top_area is not None:
if source is not None:
df = self._filter_top_region_single_source(df, top_area, source, rest_as_others)
else:
df = self._filter_top_region_all_source(df, top_area, level)
ret = RoiNormalizedDataFrame(df, cls_col, norm)
#
match norm:
case 'volume':
ret = ret.with_density_column(backend=volume_norm_backend)
case 'cell':
ret = ret.with_cell_density_column()
case 'channel':
ret = ret.with_columns(pl.col('fraction').alias('normalized')) # copy from fraction
case 'none':
pass
case _:
raise ValueError(f'invalid norm method: {norm}')
#
if animal is not None:
ret = ret.with_animal_column(animal)
return ret
def _filter_top_region_single_source(self, df, top_area, source, rest_as_others) -> pl.DataFrame:
if top_area > df.shape[0]:
print(f'{top_area} areas exceed, thus use all areas instead')
else:
df = df[:top_area]
others = df['fraction'].sum()
if rest_as_others and (100 - others) > 0:
other_perc = max(0, 100 - others)
total_counts = self.source_counts.filter(pl.col('source') == source)['count'].item()
other_counts = total_counts - df['counts'].sum()
schema = {df.columns[i]: df.dtypes[i] for i in range(df.shape[1])}
row = pl.DataFrame([[source, 'other', other_counts, other_perc, df['hemisphere'][0]]], schema=schema,
orient='row') # `other` row
df = pl.concat([df, row])
return df
def _filter_top_region_all_source(self, df, top_area, level) -> pl.DataFrame:
if top_area > df.shape[0]:
print(f'{top_area} areas exceed, thus use all areas instead')
ref_df = df
else:
ref_df = df[:top_area]
cls_col = self.get_classified_column(level)
region = ref_df[cls_col].unique()
return df.filter(pl.col(cls_col).is_in(region))
[docs]
def to_subregion(self, region: str, *,
unit: Literal['counts', 'fraction'] = 'fraction',
source_order: tuple[str, ...] | None = None,
show_col: str | None = None,
animal: str | None = None,
normalize: bool = True) -> RoiSubregionDataFrame:
"""
To the subregion dataframe (example as Visual region: VIS) ::
┌─────────┬───────────┬───────────┬───────────┬───┬──────────┬──────────┬──────────┬──────────┐
│ source ┆ VISam ┆ VISp ┆ VISpm ┆ … ┆ VISal ┆ VISpor ┆ VISli ┆ VISpl │
│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 ┆ f64 ┆ ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞═════════╪═══════════╪═══════════╪═══════════╪═══╪══════════╪══════════╪══════════╪══════════╡
│ overlap ┆ 39.649682 ┆ 15.127389 ┆ 28.025478 ┆ … ┆ 3.025478 ┆ 2.707006 ┆ 1.592357 ┆ 0.159236 │
│ aRSC ┆ 32.160414 ┆ 28.952135 ┆ 23.05304 ┆ … ┆ 6.080207 ┆ 1.293661 ┆ 2.069858 ┆ 0.07762 │
│ pRSC ┆ 25.947955 ┆ 27.95539 ┆ 27.459727 ┆ … ┆ 3.122677 ┆ 2.973978 ┆ 1.982652 ┆ 1.016109 │
└─────────┴───────────┴───────────┴───────────┴───┴──────────┴──────────┴──────────┴──────────┘
:param region: region name
:param unit: value unit. {'counts', 'fraction'}. default is 'fraction'
:param source_order: source order in dataframe (rows)
:param show_col: force set show col to which level. use case: if a low level area name is classified and show in high level (i.e., TH).
:param animal: with animal column as subregion dataframe
:param normalize: do subregion normalization. True: normalize per source within region (each source sums to 100% within this region),
False: normalize per source to total input counts (fraction of total input for that source). Default is True
:return:
"""
orig_df = self._df
source_order = source_order or tuple(orig_df['source'].unique().to_list())
# query based on lowest tree level
query_col = get_margin_merge_level(orig_df, region, 'lowest')
# show based on highest tree level
show_col = show_col or get_margin_merge_level(orig_df, region, 'highest')
s, lv = show_col.rsplit('_', 1)
next_lv = f'{s}_{int(lv) + 1}'
show_col = f'{s}_{int(lv)}' if next_lv not in orig_df.columns else next_lv
df = (
orig_df
.filter(pl.col(query_col) == region)
.select(['source', show_col])
.group_by(['source', show_col])
.agg(pl.col(show_col).count().alias('counts'))
)
if normalize:
# normalize per source within this region (each source sums to 100% within the region)
df = df.with_columns((pl.col('counts') / pl.col('counts').sum().over('source') * 100).alias('fraction'))
else:
# normalize per source to total input counts for that source
source_totals = orig_df.group_by('source').len(name='total_input_counts')
df = df.join(source_totals, on='source')
df = df.with_columns((pl.col('counts') / pl.col('total_input_counts') * 100).alias('fraction'))
df = df.drop('total_input_counts')
df = df.sort('fraction', descending=True)
# sort
idx = {val: idx for idx, val in enumerate(source_order)}
sort_expr = pl.col('source').replace(idx)
# profile
roi_profile = (
df.group_by('source').agg(pl.col('counts').sum().alias('counts'))
.join(orig_df.group_by('source').len(name='total'), on='source')
.with_columns((pl.col('counts') / pl.col('total')).alias('total_fraction'))
.sort(sort_expr)
)
# main result
ret = (
df.pivot(show_col, index='source', values=unit, aggregate_function='first')
.fill_null(0)
.sort(sort_expr)
)
subregion = RoiSubregionDataFrame(region, ret, roi_profile)
if animal is not None:
subregion = subregion.with_animal_column(animal)
return subregion
[docs]
class RoiNormalizedDataFrame(DataFrameWrapper):
"""
RoiNormalizedDataFrame with each area per row (unique ``source``, ``hemisphere``)
- Required fields:
- ``counts``: roi counts
- ``fraction``: roi fraction for individual sources (aka. per channel(source) normalized)
- ``hemisphere``: which hemisphere
- area column field {'acronym', 'tree_0', 'tree_1', 'tree_2', 'tree_3', 'tree_4'}
- Optional field:
- normalization-specific fields (if not 'none'): ``normalized``, ``Volumes [mm^3]``, ``volume_mm3``, ``n_neurons``
"""
_required_fields = ('counts', 'fraction', 'hemisphere')
[docs]
def __init__(self, df: pl.DataFrame,
classified_column: str,
normalized: ROIS_NORM_TYPE):
"""
:param df: DataFrame with required fields
:param classified_column: classified column for the brain area
:param normalized: :attr:`~neuralib.atlas.ccf.classifier.ROIS_NORM_TYPE`
"""
for field in self._required_fields:
if field not in df.columns:
raise RuntimeError(f'field not found: {field}')
self._df = df
self._classified_column = classified_column
self._normalized: ROIS_NORM_TYPE = normalized
def __repr__(self):
return repr(self.dataframe())
@property
def classified_column(self) -> str:
"""region classified column name"""
return self._classified_column
@property
def normalized(self) -> ROIS_NORM_TYPE:
"""normalization type"""
return self._normalized
@property
def value_column(self) -> str:
"""value column based on the ``normalized``"""
match self._normalized:
case 'volume' | 'cell' | 'channel':
return 'normalized'
case 'none':
return 'counts'
case _:
raise ValueError(f'invalid normalized method: {self._normalized}')
@property
def normalized_unit(self) -> str:
"""unit based on the ``normalized``"""
match self._normalized:
case 'volume':
return 'density (cells-mm3)'
case 'cell':
return 'cell density (%)'
case 'channel':
return 'fraction (%)'
case 'none':
return 'counts'
case _:
raise ValueError(f'invalid normalized unit: {self._normalized}')
[docs]
def dataframe(self, dataframe: pl.DataFrame | None = None, may_inplace: bool = True) -> pl.DataFrame | RoiNormalizedDataFrame: # pyright: ignore[reportIncompatibleMethodOverride]
"""
RoiNormalizedDataFrame (Volume normalized as example) ::
┌─────────┬────────┬────────┬───────────┬────────────┬────────────────┬────────────┐
│ source ┆ tree_2 ┆ counts ┆ fraction ┆ hemisphere ┆ Volumes [mm^3] ┆ normalized │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ u32 ┆ f64 ┆ str ┆ f64 ┆ f64 │
╞═════════╪════════╪════════╪═══════════╪════════════╪════════════════╪════════════╡
│ overlap ┆ ACA ┆ 1208 ┆ 29.997517 ┆ both ┆ 5.222484 ┆ 231.307537 │
│ pRSC ┆ ACA ┆ 3296 ┆ 22.822324 ┆ both ┆ 5.222484 ┆ 631.117254 │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ pRSC ┆ VIS ┆ 4035 ┆ 27.939344 ┆ both ┆ 12.957203 ┆ 311.409797 │
│ overlap ┆ VIS ┆ 628 ┆ 15.594736 ┆ both ┆ 12.957203 ┆ 48.46725 │
│ aRSC ┆ VIS ┆ 3865 ┆ 12.627005 ┆ both ┆ 12.957203 ┆ 298.289682 │
└─────────┴────────┴────────┴───────────┴────────────┴────────────────┴────────────┘
"""
if dataframe is None:
return self._df
else:
ret = RoiNormalizedDataFrame(dataframe, self._classified_column, self._normalized)
return ret
[docs]
def with_density_column(self, backend: Literal['cellatlas', 'brainglobe'] = 'cellatlas') -> Self:
"""
:param backend: Volume information calculated from which backend. {'cellatlas', 'brainglobe'}
:return:
"""
match backend:
case 'cellatlas':
df_cellatlas = load_cellatlas().select('Volumes [mm^3]', 'acronym').rename(
{'acronym': self.classified_column})
return (self.join(df_cellatlas, on=self.classified_column)
.with_columns((pl.col('counts') / pl.col('Volumes [mm^3]')).alias('normalized')))
case 'brainglobe':
df_bg_volume = load_bg_volumes().select('volume_mm3', 'acronym').rename(
{'acronym': self.classified_column})
return (self.join(df_bg_volume, on=self.classified_column)
.with_columns((pl.col('counts') / pl.col('volume_mm3')).alias('normalized')))
case _:
raise ValueError(f'invalid backend: {backend}')
[docs]
def with_cell_density_column(self) -> Self:
"""Normalized to number of neurons foreach brain region (based on ``CellAtlas`` data source)"""
df_cellatlas = load_cellatlas().select('n_neurons', 'acronym').rename({'acronym': self.classified_column})
return (self.join(df_cellatlas, on=self.classified_column)
.with_columns((pl.col('counts') / pl.col('n_neurons')).alias('normalized')))
[docs]
def with_animal_column(self, animal) -> Self:
"""with animal id column"""
return self.with_columns(pl.lit(animal).alias('animal'))
[docs]
def filter_areas(self, areas: str | list[str]) -> Self:
"""filter the dataframe with specified areas"""
if isinstance(areas, str):
areas = [areas]
ret = self.filter(pl.col(self.classified_column).is_in(areas))
if ret._df.is_empty():
raise ValueError(f'{areas} not found')
return ret
[docs]
def filter_sources(self, source: str | list[str]) -> Self:
"""filter the dataframe with specified sources"""
if isinstance(source, str):
source = [source]
ret = self.filter(pl.col('source').is_in(source))
if ret._df.is_empty():
raise ValueError(f'{source} not found')
return ret
[docs]
def to_bias_index(self, source_a: str, source_b: str) -> pl.DataFrame:
"""
Bias Index dataframe used to determine bias within two sources
(positive value toward ``source a`` and negative value toward ``source b``) ::
┌────────┬────────────┐
│ tree_2 ┆ bias_index │
│ --- ┆ --- │
│ str ┆ f64 │
╞════════╪════════════╡
│ ATN ┆ -1.192889 │
│ VIS ┆ -1.145786 │
│ CLA ┆ -0.86059 │
│ SUB ┆ -0.478069 │
│ STRd ┆ -0.463589 │
│ … ┆ … │
│ ENT ┆ 0.580593 │
│ AUD ┆ 0.610688 │
│ PTLp ┆ 1.292926 │
│ MO ┆ 1.945567 │
│ SS ┆ 2.163074 │
└────────┴────────────┘
:param source_a: source a string
:param source_b: source b string
:return:
"""
expr_calc = (pl.col(source_a) / pl.col(source_b)).map_elements(np.log2, return_dtype=pl.Float64)
df = (
self._df
.select(self.classified_column, 'source', 'fraction')
.sort(self.classified_column, 'source')
.pivot(values='fraction', index=self.classified_column, on='source', aggregate_function='first')
.fill_null(0)
.with_columns(expr_calc.alias('bias_index'))
.filter(~pl.col('bias_index').is_infinite()) # log2 index calc fail
.select(self.classified_column, 'bias_index')
.sort(by='bias_index')
)
return df
[docs]
def to_winner(self, sources: Sequence[str]) -> pl.DataFrame:
"""
Winner dataframe used for plotting (i.e., ternary plot) ::
┌────────┬─────────┬──────┬──────┬───────┬────────┐
│ tree_2 ┆ overlap ┆ pRSC ┆ aRSC ┆ total ┆ winner │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ u32 ┆ u32 ┆ u32 ┆ u32 ┆ str │
╞════════╪═════════╪══════╪══════╪═══════╪════════╡
│ ACA ┆ 1208 ┆ 3296 ┆ 5761 ┆ 9057 ┆ aRSC │
│ VIS ┆ 628 ┆ 4035 ┆ 3865 ┆ 7900 ┆ pRSC │
│ MO ┆ 460 ┆ 714 ┆ 5829 ┆ 6543 ┆ aRSC │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ AUD ┆ 44 ┆ 165 ┆ 534 ┆ 699 ┆ aRSC │
│ TEa ┆ 34 ┆ 206 ┆ 358 ┆ 564 ┆ aRSC │
└────────┴─────────┴──────┴──────┴───────┴────────┘
:param sources: source sequences for calculating the total. The above case should be specified as ['aRSC', 'pRSC']
:return: Winner dataframe
"""
df = (
self._df
.pivot(values=self.value_column, on='source', index=self.classified_column, aggregate_function='first')
.fill_nan(0)
.fill_null(0)
.with_columns(pl.sum_horizontal(sources).alias('total'))
)
region_counts = df.select(sources).to_numpy()
winner_idx = np.argmax(region_counts, axis=1).astype(int)
df = df.with_columns(pl.Series([sources[idx] for idx in winner_idx]).alias('winner'))
return df
class RoiSubregionDataFrame(DataFrameWrapper):
    """RoiSubregionDataFrame with each source per row, column shows the subregions"""

    _profile_required_fields = ('source', 'counts', 'total', 'total_fraction')

    # identifier columns of the subregion dataframe, i.e. columns that are NOT subregions
    # (``animal`` is added by :meth:`with_animal_column`)
    _id_columns = ('source', 'animal')

    def __init__(self, region: str, df: pl.DataFrame, profile: pl.DataFrame):
        """
        :param region: region name
        :param df: subregion dataframe
        :param profile: profile dataframe
        :raises RuntimeError: if a required field is missing in ``profile``
        """
        self._region = region
        self._df = df
        self._profile = profile

        for field in self._profile_required_fields:
            if field not in profile.columns:
                raise RuntimeError(f'field not found: {field} -> {profile.columns}')

    def __repr__(self):
        return repr(self.dataframe())

    def dataframe(self, dataframe: pl.DataFrame | None = None, may_inplace: bool = True) -> pl.DataFrame | RoiSubregionDataFrame:  # pyright: ignore[reportIncompatibleMethodOverride]
        """
        RoiSubregionDataFrame (VIS as example)::

            ┌─────────┬───────────┬───────────┬───────────┬───┬──────────┬──────────┬──────────┬──────────┐
            │ source  ┆ VISam     ┆ VISp      ┆ VISpm     ┆ … ┆ VISal    ┆ VISpor   ┆ VISli    ┆ VISpl    │
            │ ---     ┆ ---       ┆ ---       ┆ ---       ┆   ┆ ---      ┆ ---      ┆ ---      ┆ ---      │
            │ str     ┆ f64       ┆ f64       ┆ f64       ┆   ┆ f64      ┆ f64      ┆ f64      ┆ f64      │
            ╞═════════╪═══════════╪═══════════╪═══════════╪═══╪══════════╪══════════╪══════════╪══════════╡
            │ overlap ┆ 39.649682 ┆ 15.127389 ┆ 28.025478 ┆ … ┆ 3.025478 ┆ 2.707006 ┆ 1.592357 ┆ 0.159236 │
            │ aRSC    ┆ 32.160414 ┆ 28.952135 ┆ 23.05304  ┆ … ┆ 6.080207 ┆ 1.293661 ┆ 2.069858 ┆ 0.07762  │
            │ pRSC    ┆ 25.947955 ┆ 27.95539  ┆ 27.459727 ┆ … ┆ 3.122677 ┆ 2.973978 ┆ 1.982652 ┆ 1.016109 │
            └─────────┴───────────┴───────────┴───────────┴───┴──────────┴──────────┴──────────┴──────────┘

        :param dataframe: if None, return the underlying ``polars.DataFrame``; otherwise wrap it with the same
            region and profile
        :param may_inplace: not used by this implementation (kept for interface compatibility)
        """
        if dataframe is None:
            return self._df
        else:
            ret = RoiSubregionDataFrame(self._region, dataframe, self._profile)
            return ret

    @property
    def region(self) -> str:
        """region name"""
        return self._region

    @property
    def subregion(self) -> list[str]:
        """list of subregion names (all columns except ``source`` and ``animal``)"""
        return [col for col in self._df.columns if col not in self._id_columns]

    @property
    def n_subregion(self) -> int:
        """number of subregion"""
        return len(self.subregion)

    @property
    def profile(self) -> pl.DataFrame:
        """
        with channel-wise profile::

            ┌─────────┬────────┬───────┬────────────────┐
            │ source  ┆ counts ┆ total ┆ total_fraction │
            │ ---     ┆ ---    ┆ ---   ┆ ---            │
            │ str     ┆ u32    ┆ u32   ┆ f64            │
            ╞═════════╪════════╪═══════╪════════════════╡
            │ overlap ┆ 628    ┆ 4027  ┆ 0.155947       │
            │ aRSC    ┆ 3865   ┆ 30609 ┆ 0.12627        │
            │ pRSC    ┆ 4035   ┆ 14442 ┆ 0.279393       │
            └─────────┴────────┴───────┴────────────────┘
        """
        return self._profile

    @property
    def sources(self) -> list[str]:
        """list of source names"""
        return self['source'].to_list()

    def filter_overlap(self) -> Self:
        """filter out overlap source (in both the subregion dataframe and the profile)

        Returns a new object; this instance (including its profile) is left unchanged.
        """
        expr = pl.col('source') != 'overlap'
        ret = self.filter(expr)
        ret._profile = self._profile.filter(expr)
        return ret

    def with_animal_column(self, animal) -> Self:
        """with animal id column"""
        return self.with_columns(pl.lit(animal).alias('animal'))

    def to_dict(self, as_series: bool = True) -> dict[str, Any]:
        """to subregion:value dict (identifier columns ``source``/``animal`` excluded)"""
        return self._df.select(self.subregion).to_dict(as_series=as_series)

    def to_numpy(self) -> np.ndarray:
        """to value array. `Array[float, [n_source, n_subregion]]` (identifier columns excluded)"""
        return self._df.select(self.subregion).to_numpy()