Source code for neuralib.util.dataframe

from __future__ import annotations

import abc
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload

import numpy as np
import polars as pl
from neuralib.util.verbose import printdf
from polars.dataframe.group_by import GroupBy
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from collections.abc import Collection, Iterable, Sequence
    from typing import Concatenate, ParamSpec, Self

    from polars import _typing as pty

    P = ParamSpec('P')

__all__ = ['DataFrameWrapper',
           'helper_with_index_column',
           'assert_polars_equal_verbose']


[docs] class DataFrameWrapper(metaclass=abc.ABCMeta): """ Abstract wrapper class for a `polars.DataFrame`, enabling convenient and composable dataframe operations in a subclassable, object-oriented interface. This base class is intended to be inherited by custom data structures whose core data is represented as a `polars.DataFrame`. It provides a suite of standard dataframe operations (e.g., filtering, sorting, renaming, joining) that return the wrapper instance (`Self`), preserving method chaining and encapsulation. This allows users to write clean, expressive logic using their custom wrapper class while still leveraging the full power of Polars. Subclasses **must** implement the `dataframe` method to get or set the internal `polars.DataFrame`. Examples -------- A minimal subclass that wraps a Polars DataFrame: >>> class MyTable(DataFrameWrapper): ... def __init__(self, data: pl.DataFrame): ... self._data = data ... ... def dataframe(self, dataframe: pl.DataFrame = None, may_inplace=True): ... if dataframe is None: ... return self._data ... if may_inplace: ... self._data = dataframe ... return self ... else: ... return MyTable(dataframe) >>> df = pl.DataFrame({'a': [1, 2, 3], 'b': [10, 20, 30]}) >>> t = MyTable(df) >>> t = t.filter(pl.col("a") > 1).rename({"b": "B"}) >>> print(t.dataframe()) shape: (2, 2) ┌─────┬─────┐ │ a │ B │ ├─────┼─────┤ │ 2 │ 20 │ │ 3 │ 30 │ └─────┴─────┘ Notes ----- - All supported operations delegate to the underlying `polars.DataFrame` and return the modified wrapper instance. - The actual `dataframe` storage and logic is delegated to subclasses via the abstract `dataframe()` getter/setter method. - This class is designed for flexible and extensible use in applications such as data modeling, pipelines, or typed schema handling. Supported Operations -------------------- - Accessors: `columns`, `schema`, `__len__`, `__array__`, `__dataframe__` - Indexing: `__getitem__` - Structure: `filter`, `drop`, `drop_nulls`, `fill_null`, `fill_nan`, `select`, `with_columns`, `with_row_index`, `rename`, `slice`, `head`, `tail`, `limit`, `sort` - Aggregation: `group_by` - Partitioning: `partition_by` - Joining: `join` - Transformation: `pipe`, `clone`, `lazy` See Also -------- polars.DataFrame : The underlying DataFrame API polars.Expr : Expression system used throughout the API """ @overload def dataframe(self) -> pl.DataFrame: pass @overload def dataframe(self, dataframe: pl.DataFrame, may_inplace: bool = True) -> Self: pass
[docs] @abc.abstractmethod def dataframe(self, dataframe: pl.DataFrame | None = None, may_inplace: bool = True) -> pl.DataFrame | Self: """ Getter/setter for the internal Polars DataFrame. :param dataframe: Optional new dataframe to set. :param may_inplace: If True, update current instance. Otherwise, return new instance. :return: The current dataframe or a modified wrapper instance. """ pass
def __len__(self) -> int: """See `polars.DataFrame.__len__`.""" return len(self.dataframe()) @property def columns(self) -> list[str]: """See `polars.DataFrame.columns`.""" return self.dataframe().columns @property def schema(self) -> pl.Schema: """See `polars.DataFrame.schema`.""" return self.dataframe().schema def __array__(self, *args, **kwargs) -> np.ndarray: """See `polars.DataFrame.__array__`.""" return self.dataframe().__array__(*args, **kwargs) def __dataframe__(self, *args, **kwargs): """See `polars.DataFrame.__dataframe__`.""" return self.dataframe().__dataframe__(*args, **kwargs) def __getitem__(self, item: Any) -> Any: """See `polars.DataFrame.__getitem__`.""" return self.dataframe().__getitem__(item)
[docs] def lazy(self) -> LazyDataFrameWrapper[Self]: """Wrap dataframe in a lazy wrapper.""" return LazyDataFrameWrapper(self, self.dataframe().lazy())
[docs] def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> Self: """See `polars.DataFrame.rename`.""" return self.dataframe(self.dataframe().rename(mapping))
[docs] def filter(self, *predicates: pty.IntoExprColumn | Iterable[pty.IntoExprColumn] | bool | list[bool] | np.ndarray, **constraints: Any) -> Self: """See `polars.DataFrame.filter`.""" return self.dataframe(self.dataframe().filter(*predicates, **constraints))
[docs] def slice(self, offset: int, length: int | None = None) -> Self: """See `polars.DataFrame.slice`.""" return self.dataframe(self.dataframe().slice(offset, length))
[docs] def head(self, n: int = 5) -> Self: """See `polars.DataFrame.head`.""" return self.dataframe(self.dataframe().head(n))
[docs] def tail(self, n: int = 5) -> Self: """See `polars.DataFrame.tail`.""" return self.dataframe(self.dataframe().tail(n))
[docs] def limit(self, n: int = 5) -> Self: """See `polars.DataFrame.limit`.""" return self.dataframe(self.dataframe().limit(n))
[docs] def sort(self, by: pty.IntoExpr | Iterable[pty.IntoExpr], *more_by: pty.IntoExpr, descending: bool | Sequence[bool] = False, nulls_last: bool | Sequence[bool] = False, multithreaded: bool = True, maintain_order: bool = False) -> Self: """See `polars.DataFrame.sort`.""" return self.dataframe(self.dataframe().sort( by, *more_by, descending=descending, nulls_last=nulls_last, multithreaded=multithreaded, maintain_order=maintain_order ))
[docs] def drop(self, *columns: pty.ColumnNameOrSelector | Iterable[pty.ColumnNameOrSelector], strict: bool = True) -> Self: """See `polars.DataFrame.drop`.""" return self.dataframe(self.dataframe().drop(*columns, strict=strict))
[docs] def drop_nulls(self, subset: pty.ColumnNameOrSelector | Collection[pty.ColumnNameOrSelector]) -> Self: """See `polars.DataFrame.drop_nulls`.""" return self.dataframe(self.dataframe().drop_nulls(subset))
[docs] def fill_null(self, value: Any | pl.Expr | None = None, strategy: pty.FillNullStrategy | None = None, limit: int | None = None, **kwargs) -> Self: """See `polars.DataFrame.fill_null`.""" return self.dataframe(self.dataframe().fill_null(value, strategy, limit, **kwargs))
[docs] def fill_nan(self, value: pl.Expr | int | float | None = None) -> Self: """See `polars.DataFrame.fill_nan`.""" return self.dataframe(self.dataframe().fill_nan(value))
[docs] def clear(self, n: int = 5) -> Self: """See `polars.DataFrame.clear`.""" return self.dataframe(self.dataframe().clear(n))
[docs] def clone(self) -> Self: """Clone the wrapper.""" return self.dataframe(self.dataframe(), may_inplace=False)
@overload def partition_by(self, by: pty.ColumnNameOrSelector | Iterable[pty.ColumnNameOrSelector], *more_by: pty.ColumnNameOrSelector, maintain_order: bool = True, include_key: bool = True, as_dict: Literal[False] = ...) -> list[Self]: ... @overload def partition_by(self, by: pty.ColumnNameOrSelector | Iterable[pty.ColumnNameOrSelector], *more_by: pty.ColumnNameOrSelector, maintain_order: bool = ..., include_key: bool = ..., as_dict: Literal[True]) -> dict[tuple[object, ...], Self]: ...
[docs] def partition_by(self, by, *more_by, as_dict=False, **kwargs): """See `polars.DataFrame.partition_by`.""" if as_dict: dataframe = self.dataframe().partition_by(by, *more_by, as_dict=True, **kwargs) return {k: self.dataframe(it, may_inplace=False) for k, it in dataframe.items()} else: dataframe = self.dataframe().partition_by(by, *more_by, as_dict=False, **kwargs) return [self.dataframe(it, may_inplace=False) for it in dataframe]
[docs] def select(self, *exprs: pty.IntoExpr | Iterable[pty.IntoExpr], **named_exprs: pty.IntoExpr) -> Self: """See `polars.DataFrame.select`.""" return self.dataframe(self.dataframe().select(*exprs, **named_exprs))
[docs] def with_columns(self, *exprs: pty.IntoExpr | Iterable[pty.IntoExpr], **named_exprs: pty.IntoExpr) -> Self: """See `polars.DataFrame.with_columns`.""" return self.dataframe(self.dataframe().with_columns(*exprs, **named_exprs))
[docs] def with_row_index(self, name: str = "index", offset: int = 0) -> Self: """See `polars.DataFrame.with_row_index`.""" return self.dataframe(self.dataframe().with_row_index(name, offset))
[docs] def join(self, other: pl.DataFrame | DataFrameWrapper, on: str | pl.Expr | Sequence[str | pl.Expr], how: pty.JoinStrategy = "inner", *, left_on=None, right_on=None, suffix: str = "_right", validate: pty.JoinValidation = "m:m", join_nulls: bool = False, coalesce: bool | None = None) -> Self: """See `polars.DataFrame.join`.""" if isinstance(other, DataFrameWrapper): other = other.dataframe() return self.dataframe(self.dataframe().join( other, on, how=how, left_on=left_on, right_on=right_on, suffix=suffix, validate=validate, nulls_equal=join_nulls, coalesce=coalesce ))
[docs] def pipe(self, function: Callable[Concatenate[pl.DataFrame, P], pl.DataFrame], *args: P.args, **kwargs: P.kwargs) -> Self: """See `polars.DataFrame.pipe`.""" return self.dataframe(self.dataframe().pipe(function, *args, **kwargs))
[docs] def group_by(self, *by: pty.IntoExpr | Iterable[pty.IntoExpr], maintain_order: bool = False, **named_by: pty.IntoExpr) -> GroupBy: """See `polars.DataFrame.group_by`.""" return self.dataframe().group_by(*by, maintain_order=maintain_order, **named_by)
T = TypeVar('T', bound=DataFrameWrapper) class LazyDataFrameWrapper(Generic[T]): __slots__ = '__wrapper', '__lazy' def __init__(self, wrapper: T, lazy: pl.LazyFrame): self.__wrapper = wrapper self.__lazy = lazy @property def columns(self) -> list[str]: return self.__lazy.columns @property def schema(self) -> pl.Schema: return self.__lazy.schema def lazy(self) -> LazyDataFrameWrapper[T]: return self def collect(self, **kwargs) -> T: return self.__wrapper.dataframe(self.__lazy.collect(**kwargs)) def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.rename(mapping)) def slice(self, offset: int, length: int | None = None) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.slice(offset, length)) def head(self, n: int = 5) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.head(n)) def tail(self, n: int = 5) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.tail(n)) def limit(self, n: int = 5) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.limit(n)) def clear(self, n: int = 0) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.clear(n)) def filter(self, *predicates: pty.IntoExprColumn | Iterable[pty.IntoExprColumn] | bool | list[bool] | np.ndarray, **constraints: Any) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.filter(*predicates, **constraints)) def sort(self, by: pty.IntoExpr | Iterable[pty.IntoExpr], *more_by: pty.IntoExpr, descending: bool | Sequence[bool] = False, nulls_last: bool | Sequence[bool] = False, multithreaded: bool = True, maintain_order: bool = False) -> LazyDataFrameWrapper[T]: df = self.__lazy.sort( by, *more_by, descending=descending, nulls_last=nulls_last, multithreaded=multithreaded, maintain_order=maintain_order ) return LazyDataFrameWrapper(self.__wrapper, df) def drop(self, *columns: pty.ColumnNameOrSelector | Iterable[pty.ColumnNameOrSelector], strict: bool = True) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.drop(*columns, strict=strict)) def drop_nulls(self, subset: pty.ColumnNameOrSelector | Collection[pty.ColumnNameOrSelector]) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.drop_nulls(subset)) def fill_null(self, value: Any | pl.Expr | None = None, strategy: pty.FillNullStrategy | None = None, limit: int | None = None, **kwargs) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.fill_null(value, strategy, limit, **kwargs)) def fill_nan(self, value: pl.Expr | int | float | None = None) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.fill_nan(value)) def select(self, *exprs: pty.IntoExpr | Iterable[pty.IntoExpr], **named_exprs: pty.IntoExpr) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.select(*exprs, **named_exprs)) def with_columns(self, *exprs: pty.IntoExpr | Iterable[pty.IntoExpr], **named_exprs: pty.IntoExpr) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.with_columns(*exprs, **named_exprs)) def with_row_index(self, name: str = "index", offset: int = 0) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.with_row_index(name, offset)) def join(self, other: pl.DataFrame | pl.LazyFrame | DataFrameWrapper, on: str | pl.Expr | Sequence[str | pl.Expr], how: pty.JoinStrategy = "inner", *, left_on=None, right_on=None, suffix: str = "_right", validate: pty.JoinValidation = "m:m", join_nulls: bool = False, coalesce: bool | None = None) -> LazyDataFrameWrapper[T]: if isinstance(other, DataFrameWrapper): other = other.dataframe() if not isinstance(other, pl.LazyFrame): other = other.lazy() df = self.__lazy.join( other, on, how=how, left_on=left_on, right_on=right_on, suffix=suffix, validate=validate, nulls_equal=join_nulls, coalesce=coalesce ) return LazyDataFrameWrapper(self.__wrapper, df) def pipe(self, function: Callable[Concatenate[pl.LazyFrame, P], pl.LazyFrame], *args: P.args, **kwargs: P.kwargs) -> LazyDataFrameWrapper[T]: return LazyDataFrameWrapper(self.__wrapper, self.__lazy.pipe(function, *args, **kwargs))
[docs] def helper_with_index_column(df: T, column: str, index: int | list[int] | np.ndarray | T, maintain_order: bool = False, strict: bool = False) -> T: """ A help function to do the filter on an index column. :param df: :param column: index column :param index: index array :param maintain_order: keep the ordering of *index* in the returned dataframe. :param strict: all index in *index* should present in the returned dataframe. Otherwise, an error will be raised. :return: :raise RuntimeError: strict mode fail. """ if isinstance(index, (int, np.integer)): index_values = np.asarray([index]) elif isinstance(index, type(df)): index_values = index[column].to_numpy() else: index_values = np.asarray(index) if strict: if len(miss := np.setdiff1d(index_values, df.dataframe()[column].unique().to_numpy())) > 0: raise RuntimeError(f'missing {column}: {list(miss)}') if maintain_order: _column = '_' + column index_frame = pl.DataFrame( {column: index_values}, schema_overrides={column: df.schema[column]} ).with_row_index(_column) ret = df.lazy().join(index_frame, on=column, how='left') ret = ret.filter(pl.col(_column).is_not_null()) return ret.sort(_column).drop(_column).collect() else: return df.filter(pl.col(column).is_in(index_values))
[docs] def assert_polars_equal_verbose(df1: pl.DataFrame, df2: pl.DataFrame, **kwargs): """ Assert that two Polars DataFrames are equal and provide detailed diagnostics if they differ :param df1: The first Polars DataFrame to compare :param df2: The second Polars DataFrame to compare :param kwargs: keyword arguments passed to :func:`~neuralib.util.verbose.printdf()` :return: """ try: assert_frame_equal(df1, df2) print('DataFrames are equal.') except AssertionError as e: print('DataFrames are NOT equal.') # shape print('\nShape mismatch:') print(f'df1: {df1.shape}') print(f'df2: {df2.shape}') # column if df1.columns != df2.columns: print('\nColumn mismatch:') print(f'df1 columns: {df1.columns}') print(f'df2 columns: {df2.columns}') raise e df1_extra = df1.join(df2, on=df1.columns, how='anti') df2_extra = df2.join(df1, on=df1.columns, how='anti') if df1_extra.height > 0: print('\nRows in df1 not in df2:') printdf(df1_extra, **kwargs) if df2_extra.height > 0: print('\nRows in df2 not in df1:') printdf(df2_extra, **kwargs) # If shapes match, show cell-wise diff if df1.shape == df2.shape: print('\nCell-wise differences (non-equal values):') diffs = _highlight_cell_differences(df1, df2) printdf(diffs, **kwargs) raise e
def _highlight_cell_differences(df1: pl.DataFrame, df2: pl.DataFrame) -> pl.DataFrame: return df1.select([ pl.when(df1[col] != df2[col]) .then(pl.lit('df1=') + df1[col].cast(str) + ', df2=' + df2[col].cast(str)) .otherwise('') .alias(col) for col in df1.columns ])