from __future__ import annotations
import abc
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload
import numpy as np
import polars as pl
from neuralib.util.verbose import printdf
from polars.dataframe.group_by import GroupBy
from polars.testing import assert_frame_equal
if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Sequence
from typing import Concatenate, ParamSpec, Self
from polars import _typing as pty
P = ParamSpec('P')
# Public API of this module (what ``from ... import *`` exposes).
# NOTE(review): LazyDataFrameWrapper (returned by DataFrameWrapper.lazy()) is not listed here;
# confirm that keeping it out of the star-export is intentional.
__all__ = [
    'DataFrameWrapper',
    'helper_with_index_column',
    'assert_polars_equal_verbose',
]
[docs]
class DataFrameWrapper(metaclass=abc.ABCMeta):
    """
    Abstract wrapper class for a `polars.DataFrame`, enabling convenient and composable
    dataframe operations in a subclassable, object-oriented interface.

    This base class is intended to be inherited by custom data structures whose core data
    is represented as a `polars.DataFrame`. It provides a suite of standard dataframe
    operations (e.g., filtering, sorting, renaming, joining) that return the wrapper
    instance (`Self`), preserving method chaining and encapsulation.

    This allows users to write clean, expressive logic using their custom wrapper class
    while still leveraging the full power of Polars.

    Subclasses **must** implement the `dataframe` method to get or set the internal
    `polars.DataFrame`.

    Examples
    --------
    A minimal subclass that wraps a Polars DataFrame:

    >>> class MyTable(DataFrameWrapper):
    ...     def __init__(self, data: pl.DataFrame):
    ...         self._data = data
    ...
    ...     def dataframe(self, dataframe: pl.DataFrame = None, may_inplace=True):
    ...         if dataframe is None:
    ...             return self._data
    ...         if may_inplace:
    ...             self._data = dataframe
    ...             return self
    ...         else:
    ...             return MyTable(dataframe)

    >>> df = pl.DataFrame({'a': [1, 2, 3], 'b': [10, 20, 30]})
    >>> t = MyTable(df)
    >>> t = t.filter(pl.col("a") > 1).rename({"b": "B"})
    >>> print(t.dataframe())
    shape: (2, 2)
    ┌─────┬─────┐
    │ a   ┆ B   │
    │ --- ┆ --- │
    │ i64 ┆ i64 │
    ╞═════╪═════╡
    │ 2   ┆ 20  │
    │ 3   ┆ 30  │
    └─────┴─────┘

    Notes
    -----
    - All supported operations delegate to the underlying `polars.DataFrame`.
    - Each operation passes its result back through ``self.dataframe(result)``, i.e. with the
      default ``may_inplace=True``. Whether the wrapper is updated in place or a new instance
      is returned is therefore decided by the subclass.
    - ``clone`` and ``partition_by`` explicitly request new instances (``may_inplace=False``).
    - The actual `dataframe` storage and logic is delegated to subclasses via the abstract
      `dataframe()` getter/setter method.
    - This class is designed for flexible and extensible use in applications such as
      data modeling, pipelines, or typed schema handling.

    Supported Operations
    --------------------
    - Accessors: `columns`, `schema`, `__len__`, `__array__`, `__dataframe__`
    - Indexing: `__getitem__`
    - Structure: `filter`, `drop`, `drop_nulls`, `fill_null`, `fill_nan`, `select`,
      `with_columns`, `with_row_index`, `rename`, `slice`, `head`, `tail`, `limit`, `sort`,
      `clear`
    - Aggregation: `group_by`
    - Partitioning: `partition_by`
    - Joining: `join`
    - Transformation: `pipe`, `clone`, `lazy`

    See Also
    --------
    polars.DataFrame : The underlying DataFrame API
    polars.Expr : Expression system used throughout the API
    """

    @overload
    def dataframe(self) -> pl.DataFrame:
        pass

    @overload
    def dataframe(self, dataframe: pl.DataFrame, may_inplace: bool = True) -> Self:
        pass

    @abc.abstractmethod
    def dataframe(self, dataframe: pl.DataFrame | None = None, may_inplace: bool = True) -> pl.DataFrame | Self:
        """
        Getter/setter for the internal Polars DataFrame.

        :param dataframe: Optional new dataframe to set. When omitted (``None``), act as a getter.
        :param may_inplace: If True, update current instance. Otherwise, return new instance.
        :return: The current dataframe (getter) or a wrapper holding *dataframe* (setter).
        """
        pass

    def __len__(self) -> int:
        """See `polars.DataFrame.__len__`."""
        return len(self.dataframe())

    @property
    def columns(self) -> list[str]:
        """See `polars.DataFrame.columns`."""
        return self.dataframe().columns

    @property
    def schema(self) -> pl.Schema:
        """See `polars.DataFrame.schema`."""
        return self.dataframe().schema

    def __array__(self, *args, **kwargs) -> np.ndarray:
        """See `polars.DataFrame.__array__`."""
        return self.dataframe().__array__(*args, **kwargs)

    def __dataframe__(self, *args, **kwargs):
        """See `polars.DataFrame.__dataframe__`."""
        return self.dataframe().__dataframe__(*args, **kwargs)

    def __getitem__(self, item: Any) -> Any:
        """See `polars.DataFrame.__getitem__`. The result is returned unwrapped."""
        return self.dataframe().__getitem__(item)

    def lazy(self) -> LazyDataFrameWrapper[Self]:
        """Wrap dataframe in a lazy wrapper; ``collect()`` hands the result back to this wrapper."""
        return LazyDataFrameWrapper(self, self.dataframe().lazy())

    def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> Self:
        """See `polars.DataFrame.rename`."""
        return self.dataframe(self.dataframe().rename(mapping))

    def filter(self, *predicates: pty.IntoExprColumn | Iterable[pty.IntoExprColumn] | bool | list[bool] | np.ndarray,
               **constraints: Any) -> Self:
        """See `polars.DataFrame.filter`."""
        return self.dataframe(self.dataframe().filter(*predicates, **constraints))

    def slice(self, offset: int, length: int | None = None) -> Self:
        """See `polars.DataFrame.slice`."""
        return self.dataframe(self.dataframe().slice(offset, length))

    def head(self, n: int = 5) -> Self:
        """See `polars.DataFrame.head`."""
        return self.dataframe(self.dataframe().head(n))

    def tail(self, n: int = 5) -> Self:
        """See `polars.DataFrame.tail`."""
        return self.dataframe(self.dataframe().tail(n))

    def limit(self, n: int = 5) -> Self:
        """See `polars.DataFrame.limit`."""
        return self.dataframe(self.dataframe().limit(n))

    def sort(self,
             by: pty.IntoExpr | Iterable[pty.IntoExpr],
             *more_by: pty.IntoExpr,
             descending: bool | Sequence[bool] = False,
             nulls_last: bool | Sequence[bool] = False,
             multithreaded: bool = True,
             maintain_order: bool = False) -> Self:
        """See `polars.DataFrame.sort`."""
        return self.dataframe(self.dataframe().sort(
            by,
            *more_by,
            descending=descending,
            nulls_last=nulls_last,
            multithreaded=multithreaded,
            maintain_order=maintain_order
        ))

    def drop(self, *columns: pty.ColumnNameOrSelector | Iterable[pty.ColumnNameOrSelector],
             strict: bool = True) -> Self:
        """See `polars.DataFrame.drop`."""
        return self.dataframe(self.dataframe().drop(*columns, strict=strict))

    def drop_nulls(self, subset: pty.ColumnNameOrSelector | Collection[pty.ColumnNameOrSelector] | None = None) -> Self:
        """
        See `polars.DataFrame.drop_nulls`.

        :param subset: Column name(s)/selector(s) to inspect for nulls. ``None`` (the polars
            default) considers every column.
        """
        return self.dataframe(self.dataframe().drop_nulls(subset))

    def fill_null(self, value: Any | pl.Expr | None = None,
                  strategy: pty.FillNullStrategy | None = None,
                  limit: int | None = None, **kwargs) -> Self:
        """See `polars.DataFrame.fill_null`. Extra keyword arguments are forwarded unchanged."""
        return self.dataframe(self.dataframe().fill_null(value, strategy, limit, **kwargs))

    def fill_nan(self, value: pl.Expr | int | float | None = None) -> Self:
        """See `polars.DataFrame.fill_nan`."""
        return self.dataframe(self.dataframe().fill_nan(value))

    def clear(self, n: int = 0) -> Self:
        """
        See `polars.DataFrame.clear`.

        :param n: Number of null-filled rows to keep. Defaults to 0 (an empty frame with the
            same schema), matching polars and ``LazyDataFrameWrapper.clear``.
        """
        return self.dataframe(self.dataframe().clear(n))

    def clone(self) -> Self:
        """Clone the wrapper: a new wrapper instance (``may_inplace=False``) around the current dataframe."""
        return self.dataframe(self.dataframe(), may_inplace=False)

    @overload
    def partition_by(self, by: pty.ColumnNameOrSelector | Iterable[pty.ColumnNameOrSelector],
                     *more_by: pty.ColumnNameOrSelector,
                     maintain_order: bool = True,
                     include_key: bool = True,
                     as_dict: Literal[False] = ...) -> list[Self]:
        ...

    @overload
    def partition_by(self, by: pty.ColumnNameOrSelector | Iterable[pty.ColumnNameOrSelector],
                     *more_by: pty.ColumnNameOrSelector,
                     maintain_order: bool = ...,
                     include_key: bool = ...,
                     as_dict: Literal[True]) -> dict[tuple[object, ...], Self]:
        ...

    def partition_by(self, by, *more_by, as_dict=False, **kwargs):
        """
        See `polars.DataFrame.partition_by`.

        Every partition is wrapped in a new wrapper instance (``may_inplace=False``), so this
        wrapper is never modified.
        """
        if as_dict:
            dataframe = self.dataframe().partition_by(by, *more_by, as_dict=True, **kwargs)
            return {k: self.dataframe(it, may_inplace=False) for k, it in dataframe.items()}
        else:
            dataframe = self.dataframe().partition_by(by, *more_by, as_dict=False, **kwargs)
            return [self.dataframe(it, may_inplace=False) for it in dataframe]

    def select(self, *exprs: pty.IntoExpr | Iterable[pty.IntoExpr],
               **named_exprs: pty.IntoExpr) -> Self:
        """See `polars.DataFrame.select`."""
        return self.dataframe(self.dataframe().select(*exprs, **named_exprs))

    def with_columns(self, *exprs: pty.IntoExpr | Iterable[pty.IntoExpr],
                     **named_exprs: pty.IntoExpr) -> Self:
        """See `polars.DataFrame.with_columns`."""
        return self.dataframe(self.dataframe().with_columns(*exprs, **named_exprs))

    def with_row_index(self, name: str = "index", offset: int = 0) -> Self:
        """See `polars.DataFrame.with_row_index`."""
        return self.dataframe(self.dataframe().with_row_index(name, offset))

    def join(self, other: pl.DataFrame | DataFrameWrapper,
             on: str | pl.Expr | Sequence[str | pl.Expr] | None = None,
             how: pty.JoinStrategy = "inner", *,
             left_on=None,
             right_on=None,
             suffix: str = "_right",
             validate: pty.JoinValidation = "m:m",
             join_nulls: bool = False,
             coalesce: bool | None = None) -> Self:
        """
        See `polars.DataFrame.join`.

        :param other: Right-hand side; a wrapper is unwrapped to its dataframe first.
        :param on: Join column(s) present on both sides. May be omitted (``None``, the polars
            default) when *left_on* / *right_on* are given instead.
        :param join_nulls: Forwarded to polars as ``nulls_equal`` (whether null keys match).
        """
        if isinstance(other, DataFrameWrapper):
            other = other.dataframe()
        return self.dataframe(self.dataframe().join(
            other,
            on,
            how=how,
            left_on=left_on,
            right_on=right_on,
            suffix=suffix,
            validate=validate,
            nulls_equal=join_nulls,
            coalesce=coalesce
        ))

    def pipe(self, function: Callable[Concatenate[pl.DataFrame, P], pl.DataFrame],
             *args: P.args,
             **kwargs: P.kwargs) -> Self:
        """See `polars.DataFrame.pipe`. *function* must return a `polars.DataFrame`."""
        return self.dataframe(self.dataframe().pipe(function, *args, **kwargs))

    def group_by(self, *by: pty.IntoExpr | Iterable[pty.IntoExpr],
                 maintain_order: bool = False,
                 **named_by: pty.IntoExpr) -> GroupBy:
        """See `polars.DataFrame.group_by`. The returned polars ``GroupBy`` is not wrapped."""
        return self.dataframe().group_by(*by, maintain_order=maintain_order, **named_by)
# Type variable for a concrete DataFrameWrapper subclass: LazyDataFrameWrapper[T].collect() and
# helper_with_index_column() are annotated to return the same wrapper type they were given.
T = TypeVar('T', bound=DataFrameWrapper)
class LazyDataFrameWrapper(Generic[T]):
    """
    Lazy counterpart of :class:`DataFrameWrapper`.

    Holds the originating wrapper together with a `polars.LazyFrame`. Every operation returns
    a new `LazyDataFrameWrapper` around the extended query plan. Nothing is computed until
    :meth:`collect` runs the plan and hands the result back to the originating wrapper.
    """

    __slots__ = '__wrapper', '__lazy'

    def __init__(self, wrapper: T, lazy: pl.LazyFrame):
        """
        :param wrapper: The wrapper that receives the collected result.
        :param lazy: The lazy query plan.
        """
        self.__wrapper = wrapper
        self.__lazy = lazy

    @property
    def columns(self) -> list[str]:
        """
        Column names of the query result.

        ``collect_schema()`` resolves the schema explicitly. ``LazyFrame.columns`` would emit a
        PerformanceWarning in polars >= 1.0.
        """
        return self.__lazy.collect_schema().names()

    @property
    def schema(self) -> pl.Schema:
        """Schema of the query result (resolved via ``LazyFrame.collect_schema``)."""
        return self.__lazy.collect_schema()

    def lazy(self) -> LazyDataFrameWrapper[T]:
        """Already lazy: return self."""
        return self

    def collect(self, **kwargs) -> T:
        """
        Execute the query (`polars.LazyFrame.collect`) and pass the result to
        ``wrapper.dataframe(result)``.

        ``dataframe`` is called with its default ``may_inplace=True``. Depending on the
        subclass, the originating wrapper may therefore be updated in place.

        :param kwargs: Forwarded to `polars.LazyFrame.collect`.
        """
        return self.__wrapper.dataframe(self.__lazy.collect(**kwargs))

    def rename(self, mapping: dict[str, str] | Callable[[str], str]) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.rename`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.rename(mapping))

    def slice(self, offset: int, length: int | None = None) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.slice`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.slice(offset, length))

    def head(self, n: int = 5) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.head`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.head(n))

    def tail(self, n: int = 5) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.tail`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.tail(n))

    def limit(self, n: int = 5) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.limit`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.limit(n))

    def clear(self, n: int = 0) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.clear`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.clear(n))

    def filter(self, *predicates: pty.IntoExprColumn | Iterable[pty.IntoExprColumn] | bool | list[bool] | np.ndarray,
               **constraints: Any) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.filter`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.filter(*predicates, **constraints))

    def sort(self,
             by: pty.IntoExpr | Iterable[pty.IntoExpr],
             *more_by: pty.IntoExpr,
             descending: bool | Sequence[bool] = False,
             nulls_last: bool | Sequence[bool] = False,
             multithreaded: bool = True,
             maintain_order: bool = False) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.sort`."""
        df = self.__lazy.sort(
            by,
            *more_by,
            descending=descending,
            nulls_last=nulls_last,
            multithreaded=multithreaded,
            maintain_order=maintain_order
        )
        return LazyDataFrameWrapper(self.__wrapper, df)

    def drop(self, *columns: pty.ColumnNameOrSelector | Iterable[pty.ColumnNameOrSelector],
             strict: bool = True) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.drop`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.drop(*columns, strict=strict))

    def drop_nulls(self,
                   subset: pty.ColumnNameOrSelector | Collection[pty.ColumnNameOrSelector] | None = None
                   ) -> LazyDataFrameWrapper[T]:
        """
        See `polars.LazyFrame.drop_nulls`.

        :param subset: Columns to inspect for nulls. ``None`` (the polars default) means all columns.
        """
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.drop_nulls(subset))

    def fill_null(self, value: Any | pl.Expr | None = None,
                  strategy: pty.FillNullStrategy | None = None,
                  limit: int | None = None, **kwargs) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.fill_null`. Extra keyword arguments are forwarded unchanged."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.fill_null(value, strategy, limit, **kwargs))

    def fill_nan(self, value: pl.Expr | int | float | None = None) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.fill_nan`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.fill_nan(value))

    def select(self, *exprs: pty.IntoExpr | Iterable[pty.IntoExpr],
               **named_exprs: pty.IntoExpr) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.select`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.select(*exprs, **named_exprs))

    def with_columns(self, *exprs: pty.IntoExpr | Iterable[pty.IntoExpr],
                     **named_exprs: pty.IntoExpr) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.with_columns`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.with_columns(*exprs, **named_exprs))

    def with_row_index(self, name: str = "index", offset: int = 0) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.with_row_index`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.with_row_index(name, offset))

    def join(self, other: pl.DataFrame | pl.LazyFrame | DataFrameWrapper,
             on: str | pl.Expr | Sequence[str | pl.Expr] | None = None,
             how: pty.JoinStrategy = "inner", *,
             left_on=None,
             right_on=None,
             suffix: str = "_right",
             validate: pty.JoinValidation = "m:m",
             join_nulls: bool = False,
             coalesce: bool | None = None) -> LazyDataFrameWrapper[T]:
        """
        See `polars.LazyFrame.join`.

        :param other: Right-hand side. A wrapper is unwrapped and an eager frame is made lazy.
        :param on: Join column(s) present on both sides. May be omitted (``None``) when
            *left_on* / *right_on* are given instead.
        :param join_nulls: Forwarded to polars as ``nulls_equal``.
        """
        if isinstance(other, DataFrameWrapper):
            other = other.dataframe()
        if not isinstance(other, pl.LazyFrame):
            other = other.lazy()
        df = self.__lazy.join(
            other,
            on,
            how=how,
            left_on=left_on,
            right_on=right_on,
            suffix=suffix,
            validate=validate,
            nulls_equal=join_nulls,
            coalesce=coalesce
        )
        return LazyDataFrameWrapper(self.__wrapper, df)

    def pipe(self, function: Callable[Concatenate[pl.LazyFrame, P], pl.LazyFrame],
             *args: P.args,
             **kwargs: P.kwargs) -> LazyDataFrameWrapper[T]:
        """See `polars.LazyFrame.pipe`. *function* must return a `polars.LazyFrame`."""
        return LazyDataFrameWrapper(self.__wrapper, self.__lazy.pipe(function, *args, **kwargs))
[docs]
def helper_with_index_column(df: T,
                             column: str,
                             index: int | list[int] | np.ndarray | T,
                             maintain_order: bool = False,
                             strict: bool = False) -> T:
    """
    A helper function to filter *df* on an index column.

    :param df: wrapper to filter.
    :param column: name of the index column in *df*.
    :param index: index value(s) to keep: a scalar integer, an array-like of values, or a
        dataframe (any :class:`DataFrameWrapper` or ``polars.DataFrame``) whose *column*
        supplies the values.
    :param maintain_order: if True, return rows in the order of *index*. Each occurrence of a
        value in *index* contributes its matching rows, so repeated index values repeat rows.
        Otherwise, the order of *df* is kept.
    :param strict: if True, every value of *index* must be present in ``df[column]``.
        Otherwise, an error is raised.
    :return: the filtered wrapper, produced through *df*'s own ``filter`` /
        ``lazy().collect()`` operations.
    :raise RuntimeError: strict mode fail.
    """
    if isinstance(index, (int, np.integer)):
        index_values = np.asarray([index])
    elif isinstance(index, (DataFrameWrapper, pl.DataFrame)):
        # Any frame-like index contributes its *column*, not the whole (2-D) table.
        index_values = index[column].to_numpy()
    else:
        index_values = np.asarray(index)

    if strict:
        present = df.dataframe()[column].unique().to_numpy()
        if len(missing := np.setdiff1d(index_values, present)) > 0:
            raise RuntimeError(f'missing {column}: {list(missing)}')

    if not maintain_order:
        return df.filter(pl.col(column).is_in(index_values))

    # Temporary column holding each value's position in *index*. Prefix underscores until the
    # name is unused, so an existing column is never shadowed or suffixed by the join.
    order_column = '_' + column
    while order_column in df.columns:
        order_column = '_' + order_column

    index_frame = pl.DataFrame(
        {column: index_values},
        schema_overrides={column: df.schema[column]},
    ).with_row_index(order_column)

    ret = df.lazy().join(index_frame, on=column, how='left')
    ret = ret.filter(pl.col(order_column).is_not_null())  # keep only rows whose value is in *index*
    # Stable sort: rows sharing one index position keep the order in which the join produced them.
    return ret.sort(order_column, maintain_order=True).drop(order_column).collect()
[docs]
def assert_polars_equal_verbose(df1: pl.DataFrame, df2: pl.DataFrame, **kwargs):
    """
    Assert that two Polars DataFrames are equal and provide detailed diagnostics if they differ.

    Equality is checked with ``polars.testing.assert_frame_equal`` (default settings). On
    success ``DataFrames are equal.`` is printed. On failure the differences are printed and
    the original ``AssertionError`` is re-raised. The printed differences cover:

    - shape, column names and dtypes;
    - rows present on only one side;
    - when the shapes agree, a cell-wise diff.

    :param df1: The first Polars DataFrame to compare
    :param df2: The second Polars DataFrame to compare
    :param kwargs: keyword arguments passed to :func:`~neuralib.util.verbose.printdf()`
    :raise AssertionError: if the two frames are not equal.
    """
    try:
        assert_frame_equal(df1, df2)
    except AssertionError:
        print('DataFrames are NOT equal.')
        _print_frame_differences(df1, df2, **kwargs)
        raise  # bare raise keeps the original traceback from assert_frame_equal
    print('DataFrames are equal.')


def _print_frame_differences(df1: pl.DataFrame, df2: pl.DataFrame, **kwargs) -> None:
    """Print shape, column, dtype, row-level and cell-level differences between *df1* and *df2*."""
    if df1.shape != df2.shape:
        print('\nShape mismatch:')
        print(f'df1: {df1.shape}')
        print(f'df2: {df2.shape}')

    if df1.columns != df2.columns:
        print('\nColumn mismatch:')
        print(f'df1 columns: {df1.columns}')
        print(f'df2 columns: {df2.columns}')
        return

    # The anti-joins below use every column as a join key. Mismatched key dtypes can make the
    # join itself fail and hide the real assertion, so report them and stop here.
    schema1, schema2 = df1.schema, df2.schema
    dtype_mismatch = [(name, dtype, schema2[name]) for name, dtype in schema1.items() if dtype != schema2[name]]
    if dtype_mismatch:
        print('\nDtype mismatch:')
        for name, dtype1, dtype2 in dtype_mismatch:
            print(f'{name}: df1={dtype1}, df2={dtype2}')
        return

    if not df1.columns:
        return  # no columns to join or compare on

    # nulls_equal=True: identical rows containing nulls must not be reported as unmatched.
    df1_extra = df1.join(df2, on=df1.columns, how='anti', nulls_equal=True)
    df2_extra = df2.join(df1, on=df1.columns, how='anti', nulls_equal=True)

    if df1_extra.height > 0:
        print('\nRows in df1 not in df2:')
        printdf(df1_extra, **kwargs)

    if df2_extra.height > 0:
        print('\nRows in df2 not in df1:')
        printdf(df2_extra, **kwargs)

    # If shapes match, show cell-wise diff
    if df1.shape == df2.shape:
        print('\nCell-wise differences (non-equal values):')
        diffs = _highlight_cell_differences(df1, df2)
        printdf(diffs, **kwargs)
def _highlight_cell_differences(df1: pl.DataFrame, df2: pl.DataFrame) -> pl.DataFrame:
    """
    Build a frame, shaped like *df1*, that describes every cell where *df1* and *df2* differ.

    Both frames must have the same height and column names. Cells are compared through their
    string representation with null-aware inequality (``ne_missing``). This means null-vs-value
    differences are reported and mismatched column dtypes do not raise.

    Equal cells become ``''``. Differing cells become ``'df1=<v1>, df2=<v2>'``, with nulls
    rendered as ``null``.

    NOTE(review): casting to String assumes the columns are castable (nested List/Struct columns
    may not be) — confirm for the frames this is used with.
    """
    exprs = []
    for col in df1.columns:
        left = df1[col].cast(pl.String)
        right = df2[col].cast(pl.String)
        differs = left.ne_missing(right)  # True also when exactly one side is null
        label = pl.lit('df1=') + left.fill_null('null') + pl.lit(', df2=') + right.fill_null('null')
        exprs.append(pl.when(differs).then(label).otherwise(pl.lit('')).alias(col))
    return df1.select(exprs)