From 4d33b68c2f060f5d465301b63474173bdae2807c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 15 Sep 2025 18:19:46 +0000 Subject: [PATCH 01/93] feat(expr-ir): Getting started on `GroupBy` Mapping things out a bit, no compliant yet --- narwhals/_plan/dataframe.py | 7 +++++++ narwhals/_plan/group_by.py | 38 +++++++++++++++++++++++++++++++++++++ narwhals/_plan/typing.py | 3 +++ 3 files changed, 48 insertions(+) create mode 100644 narwhals/_plan/group_by.py diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 8f06f1e5c9..8feeb7f2d3 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -5,6 +5,7 @@ from narwhals._plan import _expansion, _parse from narwhals._plan.contexts import ExprContext from narwhals._plan.expr import _parse_sort_by +from narwhals._plan.group_by import GroupBy from narwhals._plan.series import Series from narwhals._plan.typing import ( IntoExpr, @@ -138,3 +139,9 @@ def to_dict( def __len__(self) -> int: return len(self._compliant) + + def group_by( + self, *by: OneOrIterable[IntoExpr], **named_by: IntoExpr + ) -> GroupBy[Self]: + exprs = _parse.parse_into_seq_of_expr_ir(*by, **named_by) + return GroupBy(self, exprs) diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py new file mode 100644 index 0000000000..fb152f304a --- /dev/null +++ b/narwhals/_plan/group_by.py @@ -0,0 +1,38 @@ +"""Refresher on `rust` impl. 
+ +- [`resolve_group_by`] has the dsl algo + - Depends on some `expr_expansion` functions I've implemented + - `group_by_dynamic` is there also (but not doing that) + - ooooh [auto-implode] +- [`dsl_to_ir::to_alp_impl`] was the caller of ^^^^^ + + + +[`resolve_group_by`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1125-L1227 +[auto-implode]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1197-L1203 +[`dsl_to_ir::to_alp_impl`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L459-L509 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic + +from narwhals._plan import _parse +from narwhals._plan.typing import DataFrameT + +if TYPE_CHECKING: + from narwhals._plan.expressions import ExprIR + from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq + + +class GroupBy(Generic[DataFrameT]): + _frame: DataFrameT + _keys: Seq[ExprIR] + + def __init__(self, frame: DataFrameT, keys: Seq[ExprIR], /) -> None: + self._frame = frame + self._keys = keys + + def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT: + exprs = _parse.parse_into_seq_of_expr_ir(*aggs, **named_aggs) # noqa: F841 + raise NotImplementedError diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 0efb81ea81..ac16adffea 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -10,6 +10,7 @@ from narwhals import dtypes from narwhals._plan._expr_ir import ExprIR, NamedIR, SelectorIR from narwhals._plan._function import Function + from narwhals._plan.dataframe import DataFrame from narwhals._plan.expr import Expr from narwhals._plan.expressions import operators as ops from narwhals._plan.expressions.functions import RollingWindow @@ -25,6 +26,7 
@@ ) __all__ = [ + "DataFrameT", "FunctionT", "IntoExpr", "IntoExprColumn", @@ -107,3 +109,4 @@ IntoExprColumn: TypeAlias = "Expr | Series[t.Any] | str" IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" OneOrIterable: TypeAlias = "T | t.Iterable[T]" +DataFrameT = TypeVar("DataFrameT", bound="DataFrame[t.Any, t.Any]") From 3718690ba86e6e7331e9226f13746e859bdcd400 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 15 Sep 2025 22:19:13 +0000 Subject: [PATCH 02/93] feat(DRAFT): mock up `resolve_group_by` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit There's a few gaps, but overall surprised how much was reusable 🥳 --- narwhals/_plan/group_by.py | 57 ++++++++++++++++++++++++++++++++++---- narwhals/_plan/schema.py | 13 +++++++++ 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index fb152f304a..38e6f393fe 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -5,12 +5,17 @@ - `group_by_dynamic` is there also (but not doing that) - ooooh [auto-implode] - [`dsl_to_ir::to_alp_impl`] was the caller of ^^^^^ - - +- Misc recent important PRs + - `1.32.1` + - [Remove `Context` from logical layer] + - `1.32.0` + - [Make `Selector` a concrete part of the DSL] [`resolve_group_by`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1125-L1227 [auto-implode]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1197-L1203 [`dsl_to_ir::to_alp_impl`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L459-L509 +[Remove `Context` from logical layer]: https://github.com/pola-rs/polars/pull/23863 +[Make `Selector` a concrete part of the 
DSL]: https://github.com/pola-rs/polars/pull/23351 """ from __future__ import annotations @@ -18,11 +23,18 @@ from typing import TYPE_CHECKING, Generic from narwhals._plan import _parse +from narwhals._plan._expansion import ( + ensure_valid_exprs, + into_named_irs, + rewrite_projections, +) +from narwhals._plan.schema import FrozenSchema, freeze_schema from narwhals._plan.typing import DataFrameT if TYPE_CHECKING: - from narwhals._plan.expressions import ExprIR + from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq + from narwhals.schema import Schema class GroupBy(Generic[DataFrameT]): @@ -34,5 +46,40 @@ def __init__(self, frame: DataFrameT, keys: Seq[ExprIR], /) -> None: self._keys = keys def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT: - exprs = _parse.parse_into_seq_of_expr_ir(*aggs, **named_aggs) # noqa: F841 - raise NotImplementedError + keys_named_irs, aggs_named_irs, result_schema = resolve_group_by( # noqa: RUF059 + self._keys, + _parse.parse_into_seq_of_expr_ir(*aggs, **named_aggs), + self._frame.schema, + ) + msg = "`GroupBy.agg` needs a compliant-level to dispatch to" + raise NotImplementedError(msg) + + +def resolve_group_by( + input_keys: Seq[ExprIR], input_aggs: Seq[ExprIR], schema: Schema +) -> tuple[Seq[NamedIR], Seq[NamedIR], FrozenSchema]: + input_schema = freeze_schema(schema) + + # "Initialize schema from keys" + keys = rewrite_projections(input_keys, keys=(), schema=input_schema) + key_names = ensure_valid_exprs(keys, input_schema) + keys_named_irs = into_named_irs(keys, key_names) + output_schema = input_schema._select(keys_named_irs) + + # "Add aggregation column(s)" # noqa: ERA001 + # TODO @dangotbanned: Figure out if/when `keys: GroupByKeys` got out of sync + aggs = rewrite_projections(input_aggs, keys=key_names, schema=input_schema) # type: ignore[arg-type] + aggs_names = ensure_valid_exprs(aggs, input_schema) + aggs_named_irs = 
into_named_irs(aggs, aggs_names) + aggs_schema = input_schema._select(aggs_named_irs) + + # "Coerce aggregation column(s) into List unless not needed (auto-implode)" # noqa: ERA001 + # TODO @dangotbanned: seems to just be a schema transform, maybe not important for now? + + # "Final output_schema" + result_schema = output_schema.merge(aggs_schema) + + # "Make sure aggregation columns do not contain keys or index columns" + # TODO @dangotbanned: Probably just the keys part? + # *index columns* seems to be rolling/dynamic only + return keys_named_irs, aggs_named_irs, result_schema diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 4dbf5e6ef3..7f66560de6 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -77,6 +77,19 @@ def _with_columns(self, exprs: Seq[NamedIR]) -> tuple[Seq[NamedIR], FrozenSchema items = self return tuple(exprs_out), freeze_schema(items) + def merge(self, other: FrozenSchema, /) -> FrozenSchema: + """Return a new schema, merging `other` with `self`. + + Merging logic (from [`Schema.merge`]): + - Fields that occur in `self` but not `other` are unmodified + - Fields that occur in `other` but not `self` are appended, in order, to the end of `self` + - Fields that occur in both `self` and `other` are updated with the dtype from `other`, but keep their original index + + [`Schema.merge`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-schema/src/schema.rs#L265-L274 + """ + msg = "`FrozenSchema.merge` has some fancy logic I need to twiddle around with first!" 
+ raise NotImplementedError(msg) + @property def __immutable_hash__(self) -> int: if hasattr(self, _IMMUTABLE_HASH_NAME): From f70c0215687d3ede9cc0520535fc9ddce9f318cc Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:45:49 +0000 Subject: [PATCH 03/93] fix: re-sync `GroupByKeys` --- narwhals/_plan/_expansion.py | 6 ++---- narwhals/_plan/group_by.py | 3 +-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index fb2dd390a8..b8e7e27ecd 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -87,7 +87,7 @@ Excluded: TypeAlias = "frozenset[str]" """Internally use a `set`, then freeze before returning.""" -GroupByKeys: TypeAlias = "Seq[ExprIR]" +GroupByKeys: TypeAlias = "Seq[str]" """Represents group_by keys. - Originates from `polars_plan::plans::conversion::dsl_to_ir::resolve_group_by` @@ -326,9 +326,7 @@ def prepare_excluded( exclude: set[str] = set() if flags.has_exclude: exclude.update(*(e.names for e in origin.iter_left() if isinstance(e, Exclude))) - for group_by_key in keys: - if name := group_by_key.meta.output_name(raise_if_undetermined=False): - exclude.add(name) + exclude.update(keys) return frozenset(exclude) diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index 38e6f393fe..3388db2ce4 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -67,8 +67,7 @@ def resolve_group_by( output_schema = input_schema._select(keys_named_irs) # "Add aggregation column(s)" # noqa: ERA001 - # TODO @dangotbanned: Figure out if/when `keys: GroupByKeys` got out of sync - aggs = rewrite_projections(input_aggs, keys=key_names, schema=input_schema) # type: ignore[arg-type] + aggs = rewrite_projections(input_aggs, keys=key_names, schema=input_schema) aggs_names = ensure_valid_exprs(aggs, input_schema) aggs_named_irs = into_named_irs(aggs, aggs_names) aggs_schema = 
input_schema._select(aggs_named_irs) From 3828ea48327613e8d7ed31157cb0d3c480d5a4d8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:52:52 +0000 Subject: [PATCH 04/93] feat: Make `rewrite_projections(keys)` optional --- narwhals/_plan/_expansion.py | 6 +++--- narwhals/_plan/group_by.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index b8e7e27ecd..06a9f39dd5 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -169,7 +169,7 @@ def prepare_projection( `exprs`, rewritten using `Column(name)` only. """ frozen_schema = freeze_schema(schema) - rewritten = rewrite_projections(tuple(exprs), keys=(), schema=frozen_schema) + rewritten = rewrite_projections(tuple(exprs), schema=frozen_schema) output_names = ensure_valid_exprs(rewritten, frozen_schema) return rewritten, frozen_schema, output_names @@ -202,7 +202,7 @@ def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames: def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: if is_horizontal_reduction(child): - rewrites = rewrite_projections(child.input, keys=(), schema=schema) + rewrites = rewrite_projections(child.input, schema=schema) return common.replace(child, input=rewrites) return child @@ -275,7 +275,7 @@ def expand_selector(selector: SelectorIR, schema: FrozenSchema) -> Columns: def rewrite_projections( input: Seq[ExprIR], # `FunctionExpr.input` /, - keys: GroupByKeys, + keys: GroupByKeys = (), *, schema: FrozenSchema, ) -> Seq[ExprIR]: diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index 3388db2ce4..4949ea9d32 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -61,7 +61,7 @@ def resolve_group_by( input_schema = freeze_schema(schema) # "Initialize schema from keys" - keys = rewrite_projections(input_keys, keys=(), 
schema=input_schema) + keys = rewrite_projections(input_keys, schema=input_schema) key_names = ensure_valid_exprs(keys, input_schema) keys_named_irs = into_named_irs(keys, key_names) output_schema = input_schema._select(keys_named_irs) From feb766153e890d92f0784c0c3a1cb19a9b677909 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 16 Sep 2025 17:31:24 +0000 Subject: [PATCH 05/93] feat: Add `FrozenSchema.merge` lol didn't realise it was just describing python dict behavior --- narwhals/_plan/schema.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 7f66560de6..9e0214f9fa 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -87,8 +87,7 @@ def merge(self, other: FrozenSchema, /) -> FrozenSchema: [`Schema.merge`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-schema/src/schema.rs#L265-L274 """ - msg = "`FrozenSchema.merge` has some fancy logic I need to twiddle around with first!" - raise NotImplementedError(msg) + return freeze_schema(self._mapping | other._mapping) @property def __immutable_hash__(self) -> int: From 561909c3a6860fec815d35983970b5e1966f3586 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 16 Sep 2025 17:52:08 +0000 Subject: [PATCH 06/93] chore: more informative placeholder error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Everything here seems to be working already? 
😱 May as well show it off ```py >>> df.group_by("a", nwp.nth(2, 8)).agg(nwp.mean("d", "e", "g").name.suffix("_mean")) NotImplementedError: TODO: `GroupBy.agg` needs a `CompliantGroupBy` to dispatch to: keys: (a=col('a'), c=col('c'), i=col('i')) aggs: (d_mean=col('d').mean(), e_mean=col('e').mean(), g_mean=col('g').mean()) result_schema: FrozenSchema([ ('a', String), ('c', Int64), ('i', Unknown), ('d_mean', Unknown), ('e_mean', Unknown), ('g_mean', Unknown), ]) ``` --- narwhals/_plan/group_by.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index 4949ea9d32..f81601e85e 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -46,12 +46,17 @@ def __init__(self, frame: DataFrameT, keys: Seq[ExprIR], /) -> None: self._keys = keys def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT: - keys_named_irs, aggs_named_irs, result_schema = resolve_group_by( # noqa: RUF059 + keys_named_irs, aggs_named_irs, result_schema = resolve_group_by( self._keys, _parse.parse_into_seq_of_expr_ir(*aggs, **named_aggs), self._frame.schema, ) - msg = "`GroupBy.agg` needs a compliant-level to dispatch to" + msg = ( + "TODO: `GroupBy.agg` needs a `CompliantGroupBy` to dispatch to:\n\n" + f"keys:\n{keys_named_irs!r}\n\n" + f"aggs:\n{aggs_named_irs!r}\n\n" + f"result_schema:\n{result_schema!r}" + ) raise NotImplementedError(msg) From 7a811b6bba846059be8016ab480286f491b1fdc9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 18 Sep 2025 14:01:39 +0000 Subject: [PATCH 07/93] feat(DRAFT): Start spec-ing `CompliantGroupBy` Quite different to current version(s) --- narwhals/_plan/protocols.py | 46 +++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 11a17eb081..b107faea75 100644 --- a/narwhals/_plan/protocols.py +++ 
b/narwhals/_plan/protocols.py @@ -7,6 +7,7 @@ from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT, Seq from narwhals._typing_compat import TypeVar from narwhals._utils import Version +from narwhals.exceptions import ComputeError if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs @@ -69,7 +70,10 @@ SeriesT = TypeVar("SeriesT", bound=SeriesAny) SeriesT_co = TypeVar("SeriesT_co", bound=SeriesAny, covariant=True) FrameT = TypeVar("FrameT", bound=FrameAny) +FrameT_co = TypeVar("FrameT_co", bound=FrameAny, covariant=True) FrameT_contra = TypeVar("FrameT_contra", bound=FrameAny, contravariant=True) +DataFrameT = TypeVar("DataFrameT", bound=DataFrameAny) +DataFrameT_co = TypeVar("DataFrameT_co", bound=DataFrameAny, covariant=True) NamespaceT_co = TypeVar("NamespaceT_co", bound="NamespaceAny", covariant=True) EagerExprT_co = TypeVar("EagerExprT_co", bound=EagerExprAny, covariant=True) @@ -531,6 +535,8 @@ class CompliantBaseFrame(StoresVersion, Protocol[ColumnT_co, NativeFrameT]): def __narwhals_namespace__(self) -> Any: ... @property + def _group_by(self) -> type[CompliantGroupBy[Self]]: ... + @property def native(self) -> NativeFrameT: return self._native @@ -561,6 +567,8 @@ class CompliantDataFrame( CompliantBaseFrame[SeriesT, NativeDataFrameT], Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], ): + @property + def _group_by(self) -> type[DataFrameGroupBy[Self]]: ... @classmethod def from_dict( cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None @@ -581,6 +589,44 @@ def __len__(self) -> int: ... def with_row_index(self, name: str) -> Self: ... +class CompliantGroupBy(Protocol[FrameT_co]): + @property + def compliant(self) -> FrameT_co: ... + def agg(self, *args: Any, **kwds: Any) -> FrameT_co: ... 
+ + +class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]): + _keys: Seq[NamedIR] + _keys_names: Seq[str] + + @classmethod + def by_names( + cls, df: DataFrameT, names: Seq[str], / + ) -> DataFrameGroupBy[DataFrameT]: ... + + # TODO @dangotbanned: Plan how projection should work + @classmethod + def by_named_irs( + cls, df: DataFrameT, irs: Seq[NamedIR], / + ) -> DataFrameGroupBy[DataFrameT]: ... + + def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: ... + @property + def keys(self) -> Seq[NamedIR]: + return self._keys + + @property + def keys_names(self) -> Seq[str]: + if names := self._keys_names: + return names + if keys := self.keys: + return tuple(e.name for e in keys) + msg = "at least one key is required in a group_by operation" + raise ComputeError(msg) + + def agg(self, irs: Seq[NamedIR]) -> DataFrameT: ... + + class EagerDataFrame( CompliantDataFrame[SeriesT, NativeDataFrameT, NativeSeriesT], Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], From 6d3c0a96302a9a07dec453ef5c0ed9311654b252 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 18 Sep 2025 14:02:41 +0000 Subject: [PATCH 08/93] feat(DRAFT): Implement some of `ArrowGroupBy` --- narwhals/_plan/arrow/dataframe.py | 48 ++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 27a02bc2ed..12ce44e528 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -4,11 +4,14 @@ import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import +from typing_extensions import Self from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries as Series -from narwhals._plan.protocols import EagerDataFrame, namespace +from narwhals._plan.expressions import NamedIR 
+from narwhals._plan.protocols import DataFrameGroupBy, EagerDataFrame, namespace +from narwhals._plan.typing import Seq from narwhals._utils import Version from narwhals.schema import Schema @@ -34,6 +37,10 @@ def __narwhals_namespace__(self) -> ArrowNamespace: return ArrowNamespace(self._version) + @property + def _group_by(self) -> type[ArrowGroupBy]: + return ArrowGroupBy + @property def columns(self) -> list[str]: return self.native.column_names @@ -113,3 +120,42 @@ def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self: else: native = native.append_column(name, chunked) return self._with_native(native) + + +class ArrowGroupBy(DataFrameGroupBy[ArrowDataFrame]): + """What narwhals is doing. + + - Keys are handled only at compliant + - `ParseKeysGroupBy` does weird stuff + - But has a fast path for all `str` keys + - Aggs are handled in both levels + - Some compliant have more restrictions + """ + + _df: ArrowDataFrame + _grouped: pa.TableGroupBy + _keys: Seq[NamedIR] + _keys_names: Seq[str] + + @classmethod + def by_names(cls, df: ArrowDataFrame, names: Seq[str], /) -> Self: + obj = cls.__new__(cls) + obj._df = df + obj._keys = () + obj._keys_names = names + obj._grouped = pa.TableGroupBy(df.native, list(names)) + return obj + + @classmethod + def by_named_irs(cls, df: ArrowDataFrame, irs: Seq[NamedIR], /) -> Self: + raise NotImplementedError + + @property + def compliant(self) -> ArrowDataFrame: + return self._df + + def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]: + raise NotImplementedError + + def agg(self, irs: Seq[NamedIR]) -> ArrowDataFrame: + raise NotImplementedError From e71d092223cffe40fcf752a7a78f1decfe8040d9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 18 Sep 2025 14:04:49 +0000 Subject: [PATCH 09/93] feat(DRAFT): Fill out more of `GroupBy.agg` --- narwhals/_plan/_expr_ir.py | 14 +++++++++ narwhals/_plan/group_by.py | 61 +++++++++++++++++++++++++++++++------- 2 files 
changed, 65 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index 0646520102..d163134c80 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -290,3 +290,17 @@ def is_elementwise_top_level(self) -> bool: if is_literal(ir): return ir.is_scalar return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.TernaryExpr, expr.Cast)) + + def is_column(self, *, allow_aliasing: bool = False) -> bool: + """Return True if wrapping a single `Column` node. + + Note: + Multi-output (including selectors) expressions have been expanded at this stage. + + Arguments: + allow_aliasing: If False (default), any aliasing is not considered to be column selection. + """ + from narwhals._plan.expressions import Column + + ir = self.expr + return isinstance(ir, Column) and ((self.name == ir.name) or allow_aliasing) diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index f81601e85e..fc4f3cd8bc 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -20,7 +20,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic +from typing import TYPE_CHECKING, Generic, NamedTuple from narwhals._plan import _parse from narwhals._plan._expansion import ( @@ -46,23 +46,53 @@ def __init__(self, frame: DataFrameT, keys: Seq[ExprIR], /) -> None: self._keys = keys def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT: - keys_named_irs, aggs_named_irs, result_schema = resolve_group_by( + frame = self._frame + resolved = resolve_group_by( self._keys, _parse.parse_into_seq_of_expr_ir(*aggs, **named_aggs), - self._frame.schema, + frame.schema, ) - msg = ( - "TODO: `GroupBy.agg` needs a `CompliantGroupBy` to dispatch to:\n\n" - f"keys:\n{keys_named_irs!r}\n\n" - f"aggs:\n{aggs_named_irs!r}\n\n" - f"result_schema:\n{result_schema!r}" + compliant = frame._compliant + compliant_gb = compliant._group_by + # Do we need to project first? 
+ if not all(key.is_column() for key in resolved.keys): + msg = fmt_group_by_error( + "Need to sketch out non-projecting keys group by first", + resolved.keys, + resolved.aggs, + resolved.result_schema, + ) + raise NotImplementedError(msg) + grouped = compliant_gb.by_named_irs(compliant, resolved.keys) + else: # noqa: RET506 + # If not, we can just use the resolved key names as a fast-path + grouped = compliant_gb.by_names(compliant, resolved.keys_names) + msg = fmt_group_by_error( + "`GroupBy.agg` needs a `CompliantGroupBy.agg` to dispatch to", + resolved.keys, + resolved.aggs, + resolved.result_schema, ) raise NotImplementedError(msg) + return grouped.agg(resolved.aggs) + + +class _TempGroupByStuff(NamedTuple): + """Trying to organize info that's useful to keep from `resolve_group_by`. + + Important: + Not a long-term thing! + """ + + keys: Seq[NamedIR] + aggs: Seq[NamedIR] + keys_names: Seq[str] + result_schema: FrozenSchema def resolve_group_by( input_keys: Seq[ExprIR], input_aggs: Seq[ExprIR], schema: Schema -) -> tuple[Seq[NamedIR], Seq[NamedIR], FrozenSchema]: +) -> _TempGroupByStuff: input_schema = freeze_schema(schema) # "Initialize schema from keys" @@ -86,4 +116,15 @@ def resolve_group_by( # "Make sure aggregation columns do not contain keys or index columns" # TODO @dangotbanned: Probably just the keys part? 
# *index columns* seems to be rolling/dynamic only - return keys_named_irs, aggs_named_irs, result_schema + return _TempGroupByStuff(keys_named_irs, aggs_named_irs, key_names, result_schema) + + +def fmt_group_by_error( + message: str, /, keys: Seq[NamedIR], aggs: Seq[NamedIR], schema: FrozenSchema +) -> str: + return ( + f"TODO: {message}:\n\n" + f"keys:\n{keys!r}\n\n" + f"aggs:\n{aggs!r}\n\n" + f"result_schema:\n{schema!r}" + ) From 8aaf9a94ecdebbc1b5422faa8542d7b515bdf0fe Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 18 Sep 2025 18:55:01 +0000 Subject: [PATCH 10/93] fix: avoid `typing_extensions` import oops https://github.com/narwhals-dev/narwhals/actions/runs/17838467107/job/50721552166?pr=3143 --- narwhals/_plan/arrow/dataframe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 12ce44e528..a885ab7bd2 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -4,7 +4,6 @@ import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import -from typing_extensions import Self from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow import functions as fn From d9b918f36181f5f98f669f792c5c60c047122a47 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 18 Sep 2025 20:40:35 +0000 Subject: [PATCH 11/93] refactor: Move `ArrowGroupBy` Gonna need space for the mini translator --- narwhals/_plan/arrow/dataframe.py | 46 +++----------------------- narwhals/_plan/arrow/group_by.py | 55 +++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 42 deletions(-) create mode 100644 narwhals/_plan/arrow/group_by.py diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index a885ab7bd2..1114d3c351 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ 
b/narwhals/_plan/arrow/dataframe.py @@ -7,9 +7,10 @@ from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow import functions as fn +from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.expressions import NamedIR -from narwhals._plan.protocols import DataFrameGroupBy, EagerDataFrame, namespace +from narwhals._plan.protocols import EagerDataFrame, namespace from narwhals._plan.typing import Seq from narwhals._utils import Version from narwhals.schema import Schema @@ -37,8 +38,8 @@ def __narwhals_namespace__(self) -> ArrowNamespace: return ArrowNamespace(self._version) @property - def _group_by(self) -> type[ArrowGroupBy]: - return ArrowGroupBy + def _group_by(self) -> type[GroupBy]: + return GroupBy @property def columns(self) -> list[str]: @@ -119,42 +120,3 @@ def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self: else: native = native.append_column(name, chunked) return self._with_native(native) - - -class ArrowGroupBy(DataFrameGroupBy[ArrowDataFrame]): - """What narwhals is doing. 
- - - Keys are handled only at compliant - - `ParseKeysGroupBy` does weird stuff - - But has a fast path for all `str` keys - - Aggs are handled in both levels - - Some compliant have more restrictions - """ - - _df: ArrowDataFrame - _grouped: pa.TableGroupBy - _keys: Seq[NamedIR] - _keys_names: Seq[str] - - @classmethod - def by_names(cls, df: ArrowDataFrame, names: Seq[str], /) -> Self: - obj = cls.__new__(cls) - obj._df = df - obj._keys = () - obj._keys_names = names - obj._grouped = pa.TableGroupBy(df.native, list(names)) - return obj - - @classmethod - def by_named_irs(cls, df: ArrowDataFrame, irs: Seq[NamedIR], /) -> Self: - raise NotImplementedError - - @property - def compliant(self) -> ArrowDataFrame: - return self._df - - def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]: - raise NotImplementedError - - def agg(self, irs: Seq[NamedIR]) -> ArrowDataFrame: - raise NotImplementedError diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py new file mode 100644 index 0000000000..48ca0c7ff8 --- /dev/null +++ b/narwhals/_plan/arrow/group_by.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pyarrow as pa # ignore-banned-import + +from narwhals._plan.protocols import DataFrameGroupBy + +if TYPE_CHECKING: + from collections.abc import Iterator + + from typing_extensions import Self + + from narwhals._plan.arrow.dataframe import ArrowDataFrame + from narwhals._plan.expressions import NamedIR + from narwhals._plan.typing import Seq + + +class ArrowGroupBy(DataFrameGroupBy["ArrowDataFrame"]): + """What narwhals is doing. 
+ + - Keys are handled only at compliant + - `ParseKeysGroupBy` does weird stuff + - But has a fast path for all `str` keys + - Aggs are handled in both levels + - Some compliant have more restrictions + """ + + _df: ArrowDataFrame + _grouped: pa.TableGroupBy + _keys: Seq[NamedIR] + _keys_names: Seq[str] + + @classmethod + def by_names(cls, df: ArrowDataFrame, names: Seq[str], /) -> Self: + obj = cls.__new__(cls) + obj._df = df + obj._keys = () + obj._keys_names = names + obj._grouped = pa.TableGroupBy(df.native, list(names)) + return obj + + @classmethod + def by_named_irs(cls, df: ArrowDataFrame, irs: Seq[NamedIR], /) -> Self: + raise NotImplementedError + + @property + def compliant(self) -> ArrowDataFrame: + return self._df + + def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]: + raise NotImplementedError + + def agg(self, irs: Seq[NamedIR]) -> ArrowDataFrame: + raise NotImplementedError From 767261c44e1724dc7c2fd7089d447feb58c0b1dd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 19 Sep 2025 17:17:44 +0000 Subject: [PATCH 12/93] feat(DRAFT): Simple cases working? 
Borrowing some ideas from #2528, #2680 --- narwhals/_plan/arrow/dataframe.py | 10 +- narwhals/_plan/arrow/group_by.py | 183 ++++++++++++++++++++++++++++-- narwhals/_plan/group_by.py | 9 +- 3 files changed, 181 insertions(+), 21 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 1114d3c351..667030a6fb 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, overload +from typing import TYPE_CHECKING, Any, Literal, cast, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -106,6 +106,14 @@ def drop(self, columns: Sequence[str]) -> Self: to_drop = list(columns) return self._with_native(self.native.drop(to_drop)) + def rename(self, mapping: Mapping[str, str]) -> Self: + names: dict[str, str] | list[str] + if fn.BACKEND_VERSION >= (17,): + names = cast("dict[str, str]", mapping) + else: # pragma: no cover + names = [mapping.get(c, c) for c in self.columns] + return self._with_native(self.native.rename_columns(names)) + # NOTE: Use instead of `with_columns` for trivial cases def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self: native = self.native diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 48ca0c7ff8..31faa47e1a 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -1,31 +1,183 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import +from narwhals._plan import expressions as ir +from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import DataFrameGroupBy +from narwhals._utils import Implementation, requires if TYPE_CHECKING: - from collections.abc 
import Iterator + from collections.abc import Iterator, Mapping - from typing_extensions import Self + from typing_extensions import Self, TypeAlias + from narwhals._arrow.typing import ( # type: ignore[attr-defined] + AggregateOptions, + Aggregation, + ) + from narwhals._compliant.typing import NarwhalsAggregation as _NarwhalsAggregation from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq + NarwhalsAggregation: TypeAlias = Literal[_NarwhalsAggregation, "first", "last"] + InputName: TypeAlias = str + NativeName: TypeAlias = str + OutputName: TypeAlias = str + NativeAggSpec: TypeAlias = tuple[InputName, Aggregation, AggregateOptions | None] + RenameSpec: TypeAlias = tuple[NativeName, OutputName] -class ArrowGroupBy(DataFrameGroupBy["ArrowDataFrame"]): - """What narwhals is doing. - - Keys are handled only at compliant - - `ParseKeysGroupBy` does weird stuff - - But has a fast path for all `str` keys - - Aggs are handled in both levels - - Some compliant have more restrictions - """ +BACKEND_VERSION = Implementation.PYARROW._backend_version() + + +# TODO @dangotbanned: Missing `nw.col("a").len()` +SUPPORTED_AGG: Mapping[type[agg.AggExpr], Aggregation] = { + agg.Sum: "sum", + agg.Mean: "mean", + agg.Median: "approximate_median", + agg.Max: "max", + agg.Min: "min", + agg.Std: "stddev", + agg.Var: "variance", + agg.Count: "count", + agg.NUnique: "count_distinct", + agg.First: "first", + agg.Last: "last", +} + + +SUPPORTED_IR: Mapping[type[ir.Len], Aggregation] = {ir.Len: "count"} +SUPPORTED_FUNCTION: Mapping[type[ir.boolean.BooleanFunction], Aggregation] = { + ir.boolean.All: "all", + ir.boolean.Any: "any", +} + +REMAINING: tuple[Aggregation, ...] 
= ( + "count_all", # Count the number of rows in each group + "distinct", # Keep the distinct values in each group + "first_last", # Compute the first and last of values in each group + "list", # List all values in each group + "min_max", # Compute the minimum and maximum of values in each group + "one", # Get one value from each group + "product", # Compute the product of values in each group + "tdigest", # Compute approximate quantiles of values in each group +) +"""Available [native aggs] we haven't used (excluding `first`, `last`) + +[native aggs]: https://arrow.apache.org/docs/python/compute.html#grouped-aggregations +""" + + +REQUIRES_PYARROW_20: tuple[ + Literal["kurtosis"], Literal["pivot_wider"], Literal["skew"] +] = ( + "kurtosis", # Compute the kurtosis of values in each group + "pivot_wider", # Pivot values according to a pivot key column + "skew", # Compute the skewness of values in each group +) +"""https://arrow.apache.org/docs/20.0/python/compute.html#grouped-aggregations""" + + +def _ensure_single_thread( + grouped: pa.TableGroupBy, expr: ir.OrderableAggExpr, / +) -> pa.TableGroupBy: + """First/last require disabling threading.""" + if BACKEND_VERSION >= (14, 0) and grouped._use_threads: + # NOTE: Stubs say `_table` is a method, but at runtime it is a property + grouped = pa.TableGroupBy(grouped._table, grouped.keys, use_threads=False) # type: ignore[arg-type] + elif BACKEND_VERSION < (14, 0): # pragma: no cover + msg = ( + f"Using `{expr!r}` in a `group_by().agg(...)` context is only available in 'pyarrow>=14.0.0', " + f"found version {requires._unparse_version(BACKEND_VERSION)!r}.\n\n" + f"See https://github.com/apache/arrow/issues/36709" + ) + raise NotImplementedError(msg) + return grouped + +def group_by_error( + expr: ArrowAggExpr, + reason: Literal[ + "too complex", + "unsupported aggregation", + "unsupported function", + "unsupported expression", + ], +) -> NotImplementedError: + if reason == "too complex": + msg = "Non-trivial complex 
aggregation found" + else: + msg = reason.title() + msg = f"{msg} in 'pyarrow.Table':\n\n{expr.named_ir!r}" + return NotImplementedError(msg) + + +class ArrowAggExpr: + def __init__(self, named_ir: NamedIR, /) -> None: + self.named_ir: NamedIR = named_ir + + @property + def output_name(self) -> OutputName: + return self.named_ir.name + + def _parse_agg_expr( + self, expr: agg.AggExpr, grouped: pa.TableGroupBy + ) -> tuple[InputName, Aggregation, AggregateOptions | None, pa.TableGroupBy]: + if agg_name := SUPPORTED_AGG.get(type(expr)): + option: AggregateOptions | None = None + if isinstance(expr, (agg.Std, agg.Var)): + # NOTE: Only branch which needs an instance (for `ddof`) + option = pc.VarianceOptions(ddof=expr.ddof) + elif isinstance(expr, agg.NUnique): + option = pc.CountOptions(mode="all") + elif isinstance(expr, agg.Count): + option = pc.CountOptions(mode="only_valid") + elif isinstance(expr, (agg.First, agg.Last)): + option = pc.ScalarAggregateOptions(skip_nulls=False) + # NOTE: Only branch which needs access to `pa.TableGroupBy` + grouped = _ensure_single_thread(grouped, expr) + if isinstance(expr.expr, ir.Column): + return expr.expr.name, agg_name, option, grouped + raise group_by_error(self, "too complex") + raise group_by_error(self, "unsupported aggregation") + + def _parse_function_expr(self, expr: ir.FunctionExpr) -> NativeAggSpec: + if isinstance(expr.function, (ir.boolean.All, ir.boolean.Any)): + agg_name = SUPPORTED_FUNCTION[type(expr.function)] + option = pc.ScalarAggregateOptions(min_count=0) + if len(expr.input) == 1 and isinstance(expr.input[0], ir.Column): + return expr.input[0].name, agg_name, option + raise group_by_error(self, "too complex") + raise group_by_error(self, "unsupported function") + + def _rename_spec(self, input_name: InputName, agg_name: Aggregation, /) -> RenameSpec: + # `pyarrow` auto-generates the lhs + # we want to overwrite that later with rhs + return f"{input_name}_{agg_name}", self.output_name + + def to_native( + 
self, grouped: pa.TableGroupBy + ) -> tuple[pa.TableGroupBy, NativeAggSpec, RenameSpec]: + expr = self.named_ir.expr + if isinstance(expr, agg.AggExpr): + input_name, agg_name, option, grouped = self._parse_agg_expr(expr, grouped) + elif isinstance(expr, ir.Len): + msg = "Need to investigate https://github.com/narwhals-dev/narwhals/blob/0fb045536f5b56b978f354f8178b292301e9598c/narwhals/_arrow/group_by.py#L132-L141" + raise NotImplementedError(msg) + elif isinstance(expr, ir.FunctionExpr): + input_name, agg_name, option = self._parse_function_expr(expr) + else: + raise group_by_error(self, "unsupported expression") + agg_spec = input_name, agg_name, option + return grouped, agg_spec, self._rename_spec(input_name, agg_name) + + +class ArrowGroupBy(DataFrameGroupBy["ArrowDataFrame"]): _df: ArrowDataFrame _grouped: pa.TableGroupBy _keys: Seq[NamedIR] @@ -52,4 +204,11 @@ def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]: raise NotImplementedError def agg(self, irs: Seq[NamedIR]) -> ArrowDataFrame: - raise NotImplementedError + gb = self._grouped + aggs: list[NativeAggSpec] = [] + renames: list[RenameSpec] = [] + for e in irs: + gb, agg_spec, rename = ArrowAggExpr(e).to_native(gb) + aggs.append(agg_spec) + renames.append(rename) + return self.compliant._with_native(gb.aggregate(aggs)).rename(dict(renames)) diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index fc4f3cd8bc..d4c4e471d1 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -67,14 +67,7 @@ def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFra else: # noqa: RET506 # If not, we can just use the resolved key names as a fast-path grouped = compliant_gb.by_names(compliant, resolved.keys_names) - msg = fmt_group_by_error( - "`GroupBy.agg` needs a `CompliantGroupBy.agg` to dispatch to", - resolved.keys, - resolved.aggs, - resolved.result_schema, - ) - raise NotImplementedError(msg) - return grouped.agg(resolved.aggs) + return 
self._frame._from_compliant(grouped.agg(resolved.aggs)) class _TempGroupByStuff(NamedTuple): From 648d5d9016144464f92d30f658bd9b16644e4145 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 19 Sep 2025 17:52:01 +0000 Subject: [PATCH 13/93] feat(expr-ir): Add missing `Expr.len` woops Making it a separate node rather than having a flag https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/dsl/mod.rs#L872-L889 --- narwhals/_plan/arrow/expr.py | 8 ++++++++ narwhals/_plan/arrow/group_by.py | 3 ++- narwhals/_plan/expr.py | 3 +++ narwhals/_plan/expressions/aggregation.py | 5 ++++- narwhals/_plan/protocols.py | 7 +++++++ tests/plan/compliant_test.py | 5 +++++ 6 files changed, 29 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 57ec5196d6..2caf3e69b2 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -35,6 +35,7 @@ Count, First, Last, + Len, Max, Mean, Median, @@ -296,6 +297,10 @@ def count(self, node: Count, frame: Frame, name: str) -> Scalar: result = fn.count(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) + def len(self, node: Len, frame: Frame, name: str) -> Scalar: + result = fn.count(self._dispatch_expr(node.expr, frame, name).native, mode="all") + return self._with_native(result, name) + def max(self, node: Max, frame: Frame, name: str) -> Scalar: result: NativeScalar = fn.max_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) @@ -460,6 +465,9 @@ def count(self, node: Count, frame: Frame, name: str) -> Scalar: native = node.expr.dispatch(self, frame, name).native return self._with_native(pa.scalar(1 if native.is_valid else 0), name) + def len(self, node: Len, frame: Frame, name: str) -> Scalar: + return self._with_native(pa.scalar(1), name) + filter = not_implemented() over = not_implemented() 
over_ordered = not_implemented() diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 31faa47e1a..19a58f9623 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -45,6 +45,7 @@ agg.Std: "stddev", agg.Var: "variance", agg.Count: "count", + agg.Len: "count", agg.NUnique: "count_distinct", agg.First: "first", agg.Last: "last", @@ -133,7 +134,7 @@ def _parse_agg_expr( if isinstance(expr, (agg.Std, agg.Var)): # NOTE: Only branch which needs an instance (for `ddof`) option = pc.VarianceOptions(ddof=expr.ddof) - elif isinstance(expr, agg.NUnique): + elif isinstance(expr, (agg.NUnique, agg.Len)): option = pc.CountOptions(mode="all") elif isinstance(expr, agg.Count): option = pc.CountOptions(mode="only_valid") diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index b0f369bd77..7695c1d92f 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -103,6 +103,9 @@ def exclude(self, *names: OneOrIterable[str]) -> Self: def count(self) -> Self: return self._from_ir(agg.Count(expr=self._ir)) + def len(self) -> Self: + return self._from_ir(agg.Len(expr=self._ir)) + def max(self) -> Self: return self._from_ir(agg.Max(expr=self._ir)) diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index 263ca300e5..0f26a82c10 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -33,7 +33,10 @@ def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: # fmt: off -class Count(AggExpr): ... +class Count(AggExpr): + """Non-null count.""" +class Len(AggExpr): + """Null-inclusive count.""" class Max(AggExpr): ... class Mean(AggExpr): ... class Median(AggExpr): ... 
diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index b107faea75..f1084e42fc 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -272,6 +272,9 @@ def quantile( def count( self, node: agg.Count, frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def len( + self, node: agg.Len, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... def max( self, node: agg.Max, frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... @@ -378,6 +381,10 @@ def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: """Returns 0 if null, else 1.""" ... + def len(self, node: agg.Len, frame: FrameT_contra, name: str) -> Self: + """Returns 1.""" + ... + def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index ffada70747..b8e9a9ccd6 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -398,6 +398,11 @@ def _ids_ir(expr: nwp.Expr | Any) -> str: raises=NotImplementedError, ), ), + pytest.param( + [nwp.col("g").len(), nwp.col("m").last(), nwp.col("h").count()], + {"g": [3], "m": [2], "h": [1]}, + id="len-count-with-nulls", + ), ], ids=_ids_ir, ) From 2682b10b82d720dfe95b58ac908925d9ec223659 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 19 Sep 2025 18:42:26 +0000 Subject: [PATCH 14/93] feat(expr-ir): Support `nw.len()` https://github.com/narwhals-dev/narwhals/blob/0fb045536f5b56b978f354f8178b292301e9598c/narwhals/_arrow/group_by.py#L132-L141 --- narwhals/_plan/arrow/group_by.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 19a58f9623..5325dfdc7b 100644 --- a/narwhals/_plan/arrow/group_by.py 
+++ b/narwhals/_plan/arrow/group_by.py @@ -5,6 +5,7 @@ import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import +from narwhals._compliant.typing import NarwhalsAggregation as _NarwhalsAggregation from narwhals._plan import expressions as ir from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import DataFrameGroupBy @@ -19,17 +20,19 @@ AggregateOptions, Aggregation, ) - from narwhals._compliant.typing import NarwhalsAggregation as _NarwhalsAggregation from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq - NarwhalsAggregation: TypeAlias = Literal[_NarwhalsAggregation, "first", "last"] - InputName: TypeAlias = str - NativeName: TypeAlias = str - OutputName: TypeAlias = str - NativeAggSpec: TypeAlias = tuple[InputName, Aggregation, AggregateOptions | None] - RenameSpec: TypeAlias = tuple[NativeName, OutputName] + +NarwhalsAggregation: TypeAlias = Literal[_NarwhalsAggregation, "first", "last"] +InputName: TypeAlias = "str | tuple[()]" +"""`()` can be used with `"count_all"`.""" + +NativeName: TypeAlias = str +OutputName: TypeAlias = str +NativeAggSpec: TypeAlias = "tuple[InputName, Aggregation, AggregateOptions | None]" +RenameSpec: TypeAlias = tuple[NativeName, OutputName] BACKEND_VERSION = Implementation.PYARROW._backend_version() @@ -52,14 +55,13 @@ } -SUPPORTED_IR: Mapping[type[ir.Len], Aggregation] = {ir.Len: "count"} +SUPPORTED_IR: Mapping[type[ir.Len], Aggregation] = {ir.Len: "count_all"} SUPPORTED_FUNCTION: Mapping[type[ir.boolean.BooleanFunction], Aggregation] = { ir.boolean.All: "all", ir.boolean.Any: "any", } REMAINING: tuple[Aggregation, ...] 
= ( - "count_all", # Count the number of rows in each group "distinct", # Keep the distinct values in each group "first_last", # Compute the first and last of values in each group "list", # List all values in each group @@ -159,7 +161,8 @@ def _parse_function_expr(self, expr: ir.FunctionExpr) -> NativeAggSpec: def _rename_spec(self, input_name: InputName, agg_name: Aggregation, /) -> RenameSpec: # `pyarrow` auto-generates the lhs # we want to overwrite that later with rhs - return f"{input_name}_{agg_name}", self.output_name + old = f"{input_name}_{agg_name}" if input_name else agg_name + return old, self.output_name def to_native( self, grouped: pa.TableGroupBy @@ -168,8 +171,7 @@ def to_native( if isinstance(expr, agg.AggExpr): input_name, agg_name, option, grouped = self._parse_agg_expr(expr, grouped) elif isinstance(expr, ir.Len): - msg = "Need to investigate https://github.com/narwhals-dev/narwhals/blob/0fb045536f5b56b978f354f8178b292301e9598c/narwhals/_arrow/group_by.py#L132-L141" - raise NotImplementedError(msg) + input_name, agg_name, option = ((), "count_all", None) elif isinstance(expr, ir.FunctionExpr): input_name, agg_name, option = self._parse_function_expr(expr) else: From e1c3145b1a4144dba563f940e72bfd0260544dee Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 19 Sep 2025 18:53:48 +0000 Subject: [PATCH 15/93] feat(expr-ir): support auto-implode https://github.com/narwhals-dev/narwhals/issues/2660#issuecomment-2958240770 --- narwhals/_plan/arrow/group_by.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 5325dfdc7b..37d0a12f01 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -172,6 +172,8 @@ def to_native( input_name, agg_name, option, grouped = self._parse_agg_expr(expr, grouped) elif isinstance(expr, ir.Len): input_name, agg_name, option = ((), "count_all", None) + elif 
isinstance(expr, ir.Column): + input_name, agg_name, option = (expr.name, "list", None) elif isinstance(expr, ir.FunctionExpr): input_name, agg_name, option = self._parse_function_expr(expr) else: From 5d366072b00980a20dfeafe68f5f4c7d40231bc0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 19 Sep 2025 19:07:04 +0000 Subject: [PATCH 16/93] feat(DRAFT): Support `nw.col("a").unique()` in `group_by` `pyarrow` has the same behavior as `polars` --- narwhals/_plan/arrow/group_by.py | 22 +++++++++++++--------- narwhals/_plan/expressions/__init__.py | 2 ++ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 37d0a12f01..ac78f17fa5 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -56,13 +56,13 @@ SUPPORTED_IR: Mapping[type[ir.Len], Aggregation] = {ir.Len: "count_all"} -SUPPORTED_FUNCTION: Mapping[type[ir.boolean.BooleanFunction], Aggregation] = { +SUPPORTED_FUNCTION: Mapping[type[ir.Function], Aggregation] = { ir.boolean.All: "all", ir.boolean.Any: "any", + ir.functions.Unique: "distinct", } REMAINING: tuple[Aggregation, ...] 
= ( - "distinct", # Keep the distinct values in each group "first_last", # Compute the first and last of values in each group "list", # List all values in each group "min_max", # Compute the minimum and maximum of values in each group @@ -150,13 +150,17 @@ def _parse_agg_expr( raise group_by_error(self, "unsupported aggregation") def _parse_function_expr(self, expr: ir.FunctionExpr) -> NativeAggSpec: - if isinstance(expr.function, (ir.boolean.All, ir.boolean.Any)): - agg_name = SUPPORTED_FUNCTION[type(expr.function)] - option = pc.ScalarAggregateOptions(min_count=0) - if len(expr.input) == 1 and isinstance(expr.input[0], ir.Column): - return expr.input[0].name, agg_name, option - raise group_by_error(self, "too complex") - raise group_by_error(self, "unsupported function") + func = expr.function + if agg_name := SUPPORTED_FUNCTION.get(type(func)): + if isinstance(func, (ir.boolean.All, ir.boolean.Any)): + option = pc.ScalarAggregateOptions(min_count=0) + else: + option = None + else: + raise group_by_error(self, "unsupported function") + if len(expr.input) == 1 and isinstance(expr.input[0], ir.Column): + return expr.input[0].name, agg_name, option + raise group_by_error(self, "too complex") def _rename_spec(self, input_name: InputName, agg_name: Aggregation, /) -> RenameSpec: # `pyarrow` auto-generates the lhs diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 237ee36e81..4444bbd6be 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -5,6 +5,7 @@ NamedIR, SelectorIR, ) +from narwhals._plan._function import Function from narwhals._plan.expressions import ( aggregation, boolean, @@ -60,6 +61,7 @@ "Exclude", "ExprIR", "Filter", + "Function", "FunctionExpr", "IndexColumns", "InvertSelector", From 1aa2464d1178950d8792f9675d30dcc40e87cccc Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 19 Sep 2025 21:07:30 +0000 Subject: 
[PATCH 17/93] test: Port over `tests/frame/group_by_test` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wasn't expecting so much to be working already 🥳 🥳 🥳 --- tests/plan/group_by_test.py | 367 ++++++++++++++++++++++++++++++++++++ 1 file changed, 367 insertions(+) create mode 100644 tests/plan/group_by_test.py diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py new file mode 100644 index 0000000000..2a7c050668 --- /dev/null +++ b/tests/plan/group_by_test.py @@ -0,0 +1,367 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals as nw +from narwhals import _plan as nwp +from narwhals._plan import selectors as npcs +from narwhals.exceptions import InvalidOperationError +from tests.utils import PYARROW_VERSION, assert_equal_data as _assert_equal_data + +pytest.importorskip("pyarrow") + + +import pyarrow as pa + +if TYPE_CHECKING: + from collections.abc import Mapping + + +def dataframe(data: dict[str, Any], /) -> nwp.DataFrame[Any, Any]: + return nwp.DataFrame.from_native(pa.table(data)) + + +def assert_equal_data(result: Any, expected: Mapping[str, Any]) -> None: + if isinstance(result, nwp.DataFrame): + result = result.to_dict(as_series=False) + _assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("attr", "expected"), + [ + ("sum", {"a": [1, 2], "b": [3, 3]}), + ("mean", {"a": [1, 2], "b": [1.5, 3]}), + ("max", {"a": [1, 2], "b": [2, 3]}), + ("min", {"a": [1, 2], "b": [1, 3]}), + ("std", {"a": [1, 2], "b": [0.707107, None]}), + ("var", {"a": [1, 2], "b": [0.5, None]}), + ("len", {"a": [1, 2], "b": [3, 1]}), + ("n_unique", {"a": [1, 2], "b": [3, 1]}), + ("count", {"a": [1, 2], "b": [2, 1]}), + ], +) +def test_group_by_depth_1_agg(attr: str, expected: dict[str, list[Any]]) -> None: + data = {"a": [1, 1, 1, 2], "b": [1, None, 2, 3]} + expr = getattr(nwp.col("b"), attr)() + result = dataframe(data).group_by("a").agg(expr).sort("a") 
+ assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("values", "expected"), + [ + ( + {"x": [True, True, True, False, False, False]}, + {"all": [True, False, False], "any": [True, True, False]}, + ), + ( + {"x": [True, None, False, None, None, None]}, + {"all": [True, False, True], "any": [True, False, False]}, + ), + ], + ids=["not-nullable", "nullable"], +) +def test_group_by_depth_1_agg_bool_ops( + values: dict[str, list[bool]], expected: dict[str, list[bool]] +) -> None: + data = {"a": [1, 1, 2, 2, 3, 3], **values} + result = ( + dataframe(data) + .group_by("a") + .agg(nwp.col("x").all().alias("all"), nwp.col("x").any().alias("any")) + .sort("a") + ) + assert_equal_data(result, {"a": [1, 2, 3], **expected}) + + +@pytest.mark.parametrize( + ("attr", "ddof"), [("std", 0), ("var", 0), ("std", 2), ("var", 2)] +) +def test_group_by_depth_1_std_var(attr: str, ddof: int) -> None: + data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]} + _pow = 0.5 if attr == "std" else 1 + expected = { + "a": [1, 2], + "b": [ + (sum((v - 5) ** 2 for v in [4, 5, 6]) / (3 - ddof)) ** _pow, + (sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) / (3 - ddof)) ** _pow, + ], + } + expr = getattr(nwp.col("b"), attr)(ddof=ddof) + result = dataframe(data).group_by("a").agg(expr).sort("a") + assert_equal_data(result, expected) + + +def test_group_by_median() -> None: + data = {"a": [1, 1, 1, 2, 2, 2], "b": [5, 4, 6, 7, 3, 2]} + result = dataframe(data).group_by("a").agg(nwp.col("b").median()).sort("a") + expected = {"a": [1, 2], "b": [5, 3]} + assert_equal_data(result, expected) + + +def test_group_by_n_unique_w_missing() -> None: + data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]} + result = ( + dataframe(data) + .group_by("a") + .agg( + nwp.col("b").n_unique(), + c_n_unique=nwp.col("c").n_unique(), + c_n_min=nwp.col("b").min(), + d_n_unique=nwp.col("d").n_unique(), + ) + .sort("a") + ) + expected = { + "a": [1, 2], + "b": [2, 1], + "c_n_unique": [1, 
1], + "c_n_min": [4, 5], + "d_n_unique": [1, 1], + } + assert_equal_data(result, expected) + + +def test_group_by_simple_named() -> None: + data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} + df = dataframe(data) + result = ( + df.group_by("a").agg(b_min=nwp.col("b").min(), b_max=nwp.col("b").max()).sort("a") + ) + expected = {"a": [1, 2], "b_min": [4, 6], "b_max": [5, 6]} + assert_equal_data(result, expected) + + +def test_group_by_simple_unnamed() -> None: + data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} + df = dataframe(data) + result = df.group_by("a").agg(nwp.col("b").min(), nwp.col("c").max()).sort("a") + expected = {"a": [1, 2], "b": [4, 6], "c": [7, 1]} + assert_equal_data(result, expected) + + +def test_group_by_multiple_keys() -> None: + data = {"a": [1, 1, 2], "b": [4, 4, 6], "c": [7, 2, 1]} + df = dataframe(data) + result = ( + df.group_by("a", "b") + .agg(c_min=nwp.col("c").min(), c_max=nwp.col("c").max()) + .sort("a") + ) + expected = {"a": [1, 2], "b": [4, 6], "c_min": [2, 1], "c_max": [7, 1]} + assert_equal_data(result, expected) + + +def test_key_with_nulls() -> None: + data = {"b": [4, 5, None], "a": [1, 2, 3]} + result = ( + dataframe(data) + .group_by("b") + .agg(nwp.len(), nwp.col("a").min()) + .sort("a") + .with_columns(nwp.col("b").cast(nw.Float64)) + ) + expected = {"b": [4.0, 5, None], "len": [1, 1, 1], "a": [1, 2, 3]} + assert_equal_data(result, expected) + + +@pytest.mark.xfail(reason="Not implemented `drop_null_keys`") +def test_key_with_nulls_ignored() -> None: # pragma: no cover + data = {"b": [4, 5, None], "a": [1, 2, 3]} + result = ( + dataframe(data) + .group_by("b", drop_null_keys=True) + .agg(nwp.len(), nwp.col("a").min()) + .sort("a") + .with_columns(nwp.col("b").cast(nw.Float64)) + ) + expected = {"b": [4.0, 5], "len": [1, 1], "a": [1, 2]} + assert_equal_data(result, expected) + + +def test_no_agg() -> None: + data = {"a": [1, 1, 3], "b": [4, 4, 6], "c": [7.0, 8.0, 9.0]} + result = dataframe(data).group_by(["a", 
"b"]).agg().sort("a", "b") + expected = {"a": [1, 3], "b": [4, 6]} + assert_equal_data(result, expected) + + +@pytest.mark.xfail( + PYARROW_VERSION < (15,), + reason=( + "The defaults for grouping by categories in pandas are different.\n\n" + "https://github.com/narwhals-dev/narwhals/issues/1078" + ), +) +def test_group_by_categorical() -> None: # pragma: no cover + data = {"g1": ["a", "a", "b", "b"], "g2": ["x", "y", "x", "z"], "x": [1, 2, 3, 4]} + df = dataframe(data) + result = ( + df.with_columns( + g1=nwp.col("g1").cast(nw.Categorical()), + g2=nwp.col("g2").cast(nw.Categorical()), + ) + .group_by(["g1", "g2"]) + .agg(nwp.col("x").sum()) + .sort("x") + ) + assert_equal_data(result, data) + + +# TODO @dangotbanned: Align the error to `InvalidOperation` +def test_group_by_shift_raises() -> None: + data = {"a": [1, 2, 3], "b": [1, 1, 2]} + df = dataframe(data) + with pytest.raises((InvalidOperationError, NotImplementedError)): + df.group_by("b").agg(nwp.col("a").shift(1)) + + +@pytest.mark.xfail( + reason="First column rename is shadowed by the second.", raises=KeyError +) +def test_double_same_aggregation() -> None: + df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6]}) + result = df.group_by("a").agg(c=nwp.col("b").mean(), d=nwp.col("b").mean()).sort("a") + expected = {"a": [1, 2], "c": [4.5, 6], "d": [4.5, 6]} + assert_equal_data(result, expected) + + +def test_fancy_functions() -> None: + df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6]}) + result = df.group_by("a").agg(nwp.all().std(ddof=0)).sort("a") + expected = {"a": [1, 2], "b": [0.5, 0.0]} + assert_equal_data(result, expected) + result = df.group_by("a").agg(npcs.numeric().std(ddof=0)).sort("a") + assert_equal_data(result, expected) + result = df.group_by("a").agg(npcs.matches("b").std(ddof=0)).sort("a") + assert_equal_data(result, expected) + result = df.group_by("a").agg(npcs.matches("b").std(ddof=0).alias("c")).sort("a") + expected = {"a": [1, 2], "c": [0.5, 0.0]} + assert_equal_data(result, expected) + 
result = ( + df.group_by("a") + .agg(npcs.matches("b").std(ddof=0).name.map(lambda _x: "c")) + .sort("a") + ) + assert_equal_data(result, expected) + + +XFAIL_NOT_IMPL_EXPR_KEYS = pytest.mark.xfail( + reason="TODO: Expr group_by keys", raises=NotImplementedError +) + + +@pytest.mark.parametrize( + ("keys", "aggs", "expected", "sort_by"), + [ + pytest.param( + [nwp.col("a").abs(), nwp.col("a").abs().alias("a_with_alias")], + [nwp.col("x").sum()], + {"a": [1, 2], "a_with_alias": [1, 2], "x": [5, 5]}, + ["a"], + marks=XFAIL_NOT_IMPL_EXPR_KEYS, + ), + pytest.param( + [nwp.col("a").alias("x")], + [nwp.col("x").mean().alias("y")], + {"x": [-1, 1, 2], "y": [4.0, 0.5, 2.5]}, + ["x"], + marks=XFAIL_NOT_IMPL_EXPR_KEYS, + ), + ( # NOTE: This one is fine as it just selects + [nwp.col("a")], + [nwp.col("a").count().alias("foo-bar"), nwp.all().sum()], + {"a": [-1, 1, 2], "foo-bar": [1, 2, 2], "x": [4, 1, 5], "y": [1.5, 0, 0]}, + ["a"], + ), + pytest.param( + [nwp.col("a", "y").abs()], + [nwp.col("x").sum()], + {"a": [1, 1, 2], "y": [0.5, 1.5, 1], "x": [1, 4, 5]}, + ["a", "y"], + marks=XFAIL_NOT_IMPL_EXPR_KEYS, + ), + pytest.param( + [nwp.col("a").abs().alias("y")], + [nwp.all().sum().name.suffix("c")], + {"y": [1, 2], "ac": [1, 4], "xc": [5, 5]}, + ["y"], + marks=XFAIL_NOT_IMPL_EXPR_KEYS, + ), + pytest.param( + [npcs.by_dtype(nw.Float64()).abs()], + [npcs.numeric().sum()], + {"y": [0.5, 1.0, 1.5], "a": [2, 4, -1], "x": [1, 5, 4]}, + ["y"], + marks=XFAIL_NOT_IMPL_EXPR_KEYS, + ), + ], +) +def test_group_by_expr( + keys: list[nwp.Expr], + aggs: list[nwp.Expr], + expected: dict[str, list[Any]], + sort_by: list[str], +) -> None: + data = {"a": [1, 1, 2, 2, -1], "x": [0, 1, 2, 3, 4], "y": [0.5, -0.5, 1.0, -1.0, 1.5]} + df = dataframe(data) + result = df.group_by(*keys).agg(*aggs).sort(*sort_by) + assert_equal_data(result, expected) + + +def test_group_by_selector() -> None: + data = { + "a": [1, 1, 1], + "b": [4, 4, 6], + "c": ["foo", "foo", "bar"], + "x": [7.5, 8.5, 9.0], + } + 
result = ( + dataframe(data) + .group_by(npcs.by_dtype(nw.Int64), "c") + .agg(nwp.col("x").mean()) + .sort("a", "b") + ) + expected = {"a": [1, 1], "b": [4, 6], "c": ["foo", "bar"], "x": [8.0, 9.0]} + assert_equal_data(result, expected) + + +def test_renaming_edge_case() -> None: + data = {"a": [0, 0, 0], "_a_tmp": [1, 2, 3], "b": [4, 5, 6]} + result = dataframe(data).group_by(nwp.col("a")).agg(nwp.all().min()) + expected = {"a": [0], "_a_tmp": [1], "b": [4]} + assert_equal_data(result, expected) + + +@pytest.mark.xfail( + reason="First column rename is shadowed by the second.", raises=KeyError +) +def test_group_by_len_1_column() -> None: + """Based on a failure from marimo. + + - https://github.com/marimo-team/marimo/blob/036fd3ff89ef3a0e598bebb166637028024f98bc/tests/_plugins/ui/_impl/tables/test_narwhals.py#L1098-L1108 + - https://github.com/marimo-team/marimo/blob/036fd3ff89ef3a0e598bebb166637028024f98bc/marimo/_plugins/ui/_impl/tables/narwhals_table.py#L163-L188 + """ + data = {"a": [1, 2, 1, 2, 3, 4]} + expected = {"a": [1, 2, 3, 4], "len": [2, 2, 1, 1], "len_a": [2, 2, 1, 1]} + result = ( + dataframe(data).group_by("a").agg(nwp.len(), nwp.len().alias("len_a")).sort("a") + ) + assert_equal_data(result, expected) + + +def test_top_level_len() -> None: + # https://github.com/holoviz/holoviews/pull/6567#issuecomment-3178743331 + df = dataframe({"gender": ["m", "f", "f"], "weight": [4, 5, 6], "age": [None, 8, 9]}) + result = df.group_by(["gender"]).agg(nwp.all().len()).sort("gender") + expected = {"gender": ["f", "m"], "weight": [2, 1], "age": [2, 1]} + assert_equal_data(result, expected) + result = ( + df.group_by("gender") + .agg(nwp.col("weight").len(), nwp.col("age").len()) + .sort("gender") + ) + assert_equal_data(result, expected) From 45a816f63fd2f6d8c68ad7b57a7dff20ff5c9f81 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 19 Sep 2025 21:14:10 +0000 Subject: [PATCH 18/93] cov 
https://github.com/narwhals-dev/narwhals/actions/runs/17869832358/job/50820920705?pr=3143 --- tests/plan/group_by_test.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index 2a7c050668..7a51f72e7d 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -23,10 +23,8 @@ def dataframe(data: dict[str, Any], /) -> nwp.DataFrame[Any, Any]: return nwp.DataFrame.from_native(pa.table(data)) -def assert_equal_data(result: Any, expected: Mapping[str, Any]) -> None: - if isinstance(result, nwp.DataFrame): - result = result.to_dict(as_series=False) - _assert_equal_data(result, expected) +def assert_equal_data(result: nwp.DataFrame, expected: Mapping[str, Any]) -> None: + _assert_equal_data(result.to_dict(as_series=False), expected) @pytest.mark.parametrize( From 8f2ad5014a86f9cd61f75bab93eb2893908a37bf Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 19 Sep 2025 21:20:40 +0000 Subject: [PATCH 19/93] chore: Update todo --- narwhals/_plan/arrow/group_by.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index ac78f17fa5..18b0f76b02 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -38,7 +38,6 @@ BACKEND_VERSION = Implementation.PYARROW._backend_version() -# TODO @dangotbanned: Missing `nw.col("a").len()` SUPPORTED_AGG: Mapping[type[agg.AggExpr], Aggregation] = { agg.Sum: "sum", agg.Mean: "mean", @@ -53,9 +52,10 @@ agg.First: "first", agg.Last: "last", } - - -SUPPORTED_IR: Mapping[type[ir.Len], Aggregation] = {ir.Len: "count_all"} +SUPPORTED_IR: Mapping[type[ir.ExprIR], Aggregation] = { + ir.Len: "count_all", + ir.Column: "list", +} SUPPORTED_FUNCTION: Mapping[type[ir.Function], Aggregation] = { ir.boolean.All: "all", ir.boolean.Any: "any", @@ -64,13 +64,12 @@ REMAINING: 
tuple[Aggregation, ...] = ( "first_last", # Compute the first and last of values in each group - "list", # List all values in each group "min_max", # Compute the minimum and maximum of values in each group "one", # Get one value from each group "product", # Compute the product of values in each group "tdigest", # Compute approximate quantiles of values in each group ) -"""Available [native aggs] we haven't used (excluding `first`, `last`) +"""Available [native aggs] we haven't used. [native aggs]: https://arrow.apache.org/docs/python/compute.html#grouped-aggregations """ From 46504560e5aa1944508ccac9ba6f1b5c6301f2f9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 19 Sep 2025 21:52:23 +0000 Subject: [PATCH 20/93] chore: Add todo for `drop_null_keys=True` --- narwhals/_plan/dataframe.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 8feeb7f2d3..7e265bc2d7 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -141,7 +141,13 @@ def __len__(self) -> int: return len(self._compliant) def group_by( - self, *by: OneOrIterable[IntoExpr], **named_by: IntoExpr + self, + *by: OneOrIterable[IntoExpr], + drop_null_keys: bool = False, + **named_by: IntoExpr, ) -> GroupBy[Self]: + if drop_null_keys: + msg = "TODO: `drop_null_keys=True`" + raise NotImplementedError(msg) exprs = _parse.parse_into_seq_of_expr_ir(*by, **named_by) return GroupBy(self, exprs) From 4a52cecfbc0fec898e2d7917dfee11702c468287 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 20 Sep 2025 11:20:02 +0000 Subject: [PATCH 21/93] feat(DRAFT): start custom `pa.TableGroupBy` impl Just pushing this as tests are working. 
Useful changes to follow: - Column renaming stuff will be avoidable - we just use `ArrowAggExpr.output_name` - Awkward stuff `first`, `last`, `_ensure_single_thread` can be avoided - `use_threads` was always available on `Declaration.to_table` - Whether we need to use can just be an `__ior__` --- narwhals/_plan/arrow/group_by.py | 64 +++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 18b0f76b02..0c3c9bd5e5 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Literal import pyarrow as pa # ignore-banned-import +import pyarrow.acero as pac import pyarrow.compute as pc # ignore-banned-import from narwhals._compliant.typing import NarwhalsAggregation as _NarwhalsAggregation @@ -10,9 +11,10 @@ from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import DataFrameGroupBy from narwhals._utils import Implementation, requires +from narwhals.exceptions import ComputeError if TYPE_CHECKING: - from collections.abc import Iterator, Mapping + from collections.abc import Iterable, Iterator, Mapping from typing_extensions import Self, TypeAlias @@ -85,6 +87,8 @@ """https://arrow.apache.org/docs/20.0/python/compute.html#grouped-aggregations""" +# NOTE: Was available internally in `pyarrow==13` +# https://github.com/apache/arrow/blob/b7d2f7ffca66c868bd2fce5b3749c6caa002a7f0/python/pyarrow/acero.py#L302-L308 def _ensure_single_thread( grouped: pa.TableGroupBy, expr: ir.OrderableAggExpr, / ) -> pa.TableGroupBy: @@ -219,4 +223,60 @@ def agg(self, irs: Seq[NamedIR]) -> ArrowDataFrame: gb, agg_spec, rename = ArrowAggExpr(e).to_native(gb) aggs.append(agg_spec) renames.append(rename) - return self.compliant._with_native(gb.aggregate(aggs)).rename(dict(renames)) + result = _aggregate( + self.compliant.native, + list(self.keys_names), + aggs, + 
use_threads=gb._use_threads, + ) + return self.compliant._with_native(result).rename(dict(renames)) + + +_HASH: Literal["hash_"] = "hash_" + + +# TODO @dangotbanned: need to pass in the second element of `RenameSpec` + use that for `aggr_name` +def _aggregate( + df: pa.Table, + keys: list[str | pc.Expression], + aggregations: Iterable[NativeAggSpec], + *, + use_threads: bool, +) -> pa.Table: + """Adapted from [`pa.TableGroupBy.aggregate`]. + + [`pa.TableGroupBy.aggregate`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/table.pxi#L6600-L6626 + """ + if not keys: + # NOTE: We guard against this earlier, but `pyarrow` allows empty keys at this stage + msg = "at least one key is required in a group_by operation" + raise ComputeError(msg) + group_by_aggrs = [] + for aggr in aggregations: + target, func, opt = aggr + # Ensure target is a list + if isinstance(target, str): + target = [target] + # Ensure aggregate function is hash_ + # NOTE: Currently always the case, but probably want to invert that + hash_func = f"{_HASH}{func}" + # Determine output field name + aggr_name = "_".join((*target, func)) # <<<<<<<<<<<<<<<< replace me!!! 
+ group_by_aggrs.append((target, hash_func, opt, aggr_name)) + return _group_by(df, group_by_aggrs, keys, use_threads=use_threads) + + +def _group_by( + table: pa.Table, + aggregates: Any, + keys: list[str | pc.Expression], + *, + use_threads: bool = True, +) -> pa.Table: + decl = pac.Declaration.from_sequence( + [ + pac.Declaration("table_source", pac.TableSourceNodeOptions(table)), + pac.Declaration("aggregate", pac.AggregateNodeOptions(aggregates, keys=keys)), + ] + ) + return decl.to_table(use_threads=use_threads) From ce18f518ac21d5484d5014282047125690960fc9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 20 Sep 2025 12:41:39 +0000 Subject: [PATCH 22/93] fix: Avoid shadowed output aggregation names --- narwhals/_plan/arrow/group_by.py | 163 +++++++++++++------------------ tests/plan/group_by_test.py | 6 -- 2 files changed, 70 insertions(+), 99 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 0c3c9bd5e5..89bfbd4164 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -6,12 +6,10 @@ import pyarrow.acero as pac import pyarrow.compute as pc # ignore-banned-import -from narwhals._compliant.typing import NarwhalsAggregation as _NarwhalsAggregation from narwhals._plan import expressions as ir from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import DataFrameGroupBy from narwhals._utils import Implementation, requires -from narwhals.exceptions import ComputeError if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping @@ -26,50 +24,48 @@ from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq +Incomplete: TypeAlias = Any -NarwhalsAggregation: TypeAlias = Literal[_NarwhalsAggregation, "first", "last"] -InputName: TypeAlias = "str | tuple[()]" -"""`()` can be used with `"count_all"`.""" - -NativeName: TypeAlias = str +AceroTarget: TypeAlias = 
"tuple[()] | list[str]" +NativeAggSpec: TypeAlias = "tuple[AceroTarget, Aggregation, AggregateOptions | None]" OutputName: TypeAlias = str -NativeAggSpec: TypeAlias = "tuple[InputName, Aggregation, AggregateOptions | None]" -RenameSpec: TypeAlias = tuple[NativeName, OutputName] +AceroAggSpec: TypeAlias = ( + "tuple[AceroTarget, Aggregation, AggregateOptions | None, OutputName]" +) BACKEND_VERSION = Implementation.PYARROW._backend_version() - SUPPORTED_AGG: Mapping[type[agg.AggExpr], Aggregation] = { - agg.Sum: "sum", - agg.Mean: "mean", - agg.Median: "approximate_median", - agg.Max: "max", - agg.Min: "min", - agg.Std: "stddev", - agg.Var: "variance", - agg.Count: "count", - agg.Len: "count", - agg.NUnique: "count_distinct", - agg.First: "first", - agg.Last: "last", + agg.Sum: "hash_sum", + agg.Mean: "hash_mean", + agg.Median: "hash_approximate_median", + agg.Max: "hash_max", + agg.Min: "hash_min", + agg.Std: "hash_stddev", + agg.Var: "hash_variance", + agg.Count: "hash_count", + agg.Len: "hash_count", + agg.NUnique: "hash_count_distinct", + agg.First: "hash_first", + agg.Last: "hash_last", } SUPPORTED_IR: Mapping[type[ir.ExprIR], Aggregation] = { - ir.Len: "count_all", - ir.Column: "list", + ir.Len: "hash_count_all", + ir.Column: "hash_list", } SUPPORTED_FUNCTION: Mapping[type[ir.Function], Aggregation] = { - ir.boolean.All: "all", - ir.boolean.Any: "any", - ir.functions.Unique: "distinct", + ir.boolean.All: "hash_all", + ir.boolean.Any: "hash_any", + ir.functions.Unique: "hash_distinct", } REMAINING: tuple[Aggregation, ...] 
= ( - "first_last", # Compute the first and last of values in each group - "min_max", # Compute the minimum and maximum of values in each group - "one", # Get one value from each group - "product", # Compute the product of values in each group - "tdigest", # Compute approximate quantiles of values in each group + "hash_first_last", # Compute the first and last of values in each group + "hash_min_max", # Compute the minimum and maximum of values in each group + "hash_one", # Get one value from each group + "hash_product", # Compute the product of values in each group + "hash_tdigest", # Compute approximate quantiles of values in each group ) """Available [native aggs] we haven't used. @@ -87,8 +83,7 @@ """https://arrow.apache.org/docs/20.0/python/compute.html#grouped-aggregations""" -# NOTE: Was available internally in `pyarrow==13` -# https://github.com/apache/arrow/blob/b7d2f7ffca66c868bd2fce5b3749c6caa002a7f0/python/pyarrow/acero.py#L302-L308 +# TODO @dangotbanned: Factor out to just a `bool.__ior__` def _ensure_single_thread( grouped: pa.TableGroupBy, expr: ir.OrderableAggExpr, / ) -> pa.TableGroupBy: @@ -133,7 +128,7 @@ def output_name(self) -> OutputName: def _parse_agg_expr( self, expr: agg.AggExpr, grouped: pa.TableGroupBy - ) -> tuple[InputName, Aggregation, AggregateOptions | None, pa.TableGroupBy]: + ) -> tuple[AceroTarget, Aggregation, AggregateOptions | None, pa.TableGroupBy]: if agg_name := SUPPORTED_AGG.get(type(expr)): option: AggregateOptions | None = None if isinstance(expr, (agg.Std, agg.Var)): @@ -148,7 +143,7 @@ def _parse_agg_expr( # NOTE: Only branch which needs access to `pa.TableGroupBy` grouped = _ensure_single_thread(grouped, expr) if isinstance(expr.expr, ir.Column): - return expr.expr.name, agg_name, option, grouped + return [expr.expr.name], agg_name, option, grouped raise group_by_error(self, "too complex") raise group_by_error(self, "unsupported aggregation") @@ -162,31 +157,25 @@ def _parse_function_expr(self, expr: ir.FunctionExpr) 
-> NativeAggSpec: else: raise group_by_error(self, "unsupported function") if len(expr.input) == 1 and isinstance(expr.input[0], ir.Column): - return expr.input[0].name, agg_name, option + return [expr.input[0].name], agg_name, option raise group_by_error(self, "too complex") - def _rename_spec(self, input_name: InputName, agg_name: Aggregation, /) -> RenameSpec: - # `pyarrow` auto-generates the lhs - # we want to overwrite that later with rhs - old = f"{input_name}_{agg_name}" if input_name else agg_name - return old, self.output_name - - def to_native( - self, grouped: pa.TableGroupBy - ) -> tuple[pa.TableGroupBy, NativeAggSpec, RenameSpec]: + def to_native(self, grouped: pa.TableGroupBy) -> tuple[pa.TableGroupBy, AceroAggSpec]: expr = self.named_ir.expr + input_name: AceroTarget = () + option: AggregateOptions | None = None if isinstance(expr, agg.AggExpr): input_name, agg_name, option, grouped = self._parse_agg_expr(expr, grouped) - elif isinstance(expr, ir.Len): - input_name, agg_name, option = ((), "count_all", None) - elif isinstance(expr, ir.Column): - input_name, agg_name, option = (expr.name, "list", None) elif isinstance(expr, ir.FunctionExpr): input_name, agg_name, option = self._parse_function_expr(expr) + elif isinstance(expr, (ir.Len, ir.Column)): + agg_name = SUPPORTED_IR[type(expr)] + if isinstance(expr, ir.Column): + input_name = [expr.name] else: raise group_by_error(self, "unsupported expression") - agg_spec = input_name, agg_name, option - return grouped, agg_spec, self._rename_spec(input_name, agg_name) + agg_spec = input_name, agg_name, option, self.output_name + return grouped, agg_spec class ArrowGroupBy(DataFrameGroupBy["ArrowDataFrame"]): @@ -217,66 +206,54 @@ def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]: def agg(self, irs: Seq[NamedIR]) -> ArrowDataFrame: gb = self._grouped - aggs: list[NativeAggSpec] = [] - renames: list[RenameSpec] = [] + aggs: list[AceroAggSpec] = [] for e in irs: - gb, agg_spec, rename = 
ArrowAggExpr(e).to_native(gb) + gb, agg_spec = ArrowAggExpr(e).to_native(gb) aggs.append(agg_spec) - renames.append(rename) result = _aggregate( self.compliant.native, list(self.keys_names), aggs, use_threads=gb._use_threads, ) - return self.compliant._with_native(result).rename(dict(renames)) + return self.compliant._with_native(result) -_HASH: Literal["hash_"] = "hash_" - - -# TODO @dangotbanned: need to pass in the second element of `RenameSpec` + use that for `aggr_name` def _aggregate( df: pa.Table, - keys: list[str | pc.Expression], - aggregations: Iterable[NativeAggSpec], + keys: list[str], + aggregations: Iterable[ # TODO @dangotbanned: Revisit after replacing `_ensure_single_thread` + AceroAggSpec + ], *, use_threads: bool, ) -> pa.Table: - """Adapted from [`pa.TableGroupBy.aggregate`]. - - [`pa.TableGroupBy.aggregate`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/table.pxi#L6600-L6626 - """ - if not keys: - # NOTE: We guard against this earlier, but `pyarrow` allows empty keys at this stage - msg = "at least one key is required in a group_by operation" - raise ComputeError(msg) - group_by_aggrs = [] - for aggr in aggregations: - target, func, opt = aggr - # Ensure target is a list - if isinstance(target, str): - target = [target] - # Ensure aggregate function is hash_ - # NOTE: Currently always the case, but probably want to invert that - hash_func = f"{_HASH}{func}" - # Determine output field name - aggr_name = "_".join((*target, func)) # <<<<<<<<<<<<<<<< replace me!!! 
- group_by_aggrs.append((target, hash_func, opt, aggr_name)) - return _group_by(df, group_by_aggrs, keys, use_threads=use_threads) + """Adapted from [`pa.TableGroupBy.aggregate`](https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/table.pxi#L6600-L6626).""" + aggs = list(aggregations) if not isinstance(aggregations, list) else aggregations + return _group_by(df, keys, aggs, use_threads=use_threads) def _group_by( table: pa.Table, - aggregates: Any, - keys: list[str | pc.Expression], + keys: list[str], + aggregates: list[AceroAggSpec], *, use_threads: bool = True, ) -> pa.Table: - decl = pac.Declaration.from_sequence( - [ - pac.Declaration("table_source", pac.TableSourceNodeOptions(table)), - pac.Declaration("aggregate", pac.AggregateNodeOptions(aggregates, keys=keys)), - ] - ) - return decl.to_table(use_threads=use_threads) + """Backport of [apache/arrow#36768]. + + `first` and `last` were [broken in `pyarrow==13`]. + + Also allows us to specify our own aliases for aggregate output columns. 
+ + [apache/arrow#36768]: https://github.com/apache/arrow/pull/36768 + [broken in `pyarrow==13`]: https://github.com/apache/arrow/issues/36709 + """ + # NOTE: Stubs are (incorrectly) invariant + aggs: Incomplete = aggregates + keys_: Incomplete = keys + decls = [ + pac.Declaration("table_source", pac.TableSourceNodeOptions(table)), + pac.Declaration("aggregate", pac.AggregateNodeOptions(aggs, keys=keys_)), + ] + return pac.Declaration.from_sequence(decls).to_table(use_threads=use_threads) diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index 7a51f72e7d..457848f2d4 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -217,9 +217,6 @@ def test_group_by_shift_raises() -> None: df.group_by("b").agg(nwp.col("a").shift(1)) -@pytest.mark.xfail( - reason="First column rename is shadowed by the second.", raises=KeyError -) def test_double_same_aggregation() -> None: df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6]}) result = df.group_by("a").agg(c=nwp.col("b").mean(), d=nwp.col("b").mean()).sort("a") @@ -334,9 +331,6 @@ def test_renaming_edge_case() -> None: assert_equal_data(result, expected) -@pytest.mark.xfail( - reason="First column rename is shadowed by the second.", raises=KeyError -) def test_group_by_len_1_column() -> None: """Based on a failure from marimo. 
From 16148b2cb8f1e7e98d3d7287cccbcc5c2debcff1 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 20 Sep 2025 13:47:01 +0000 Subject: [PATCH 23/93] feat(expr-ir): Rewrite, fix ordered aggregations --- narwhals/_plan/arrow/group_by.py | 123 +++++++++++-------------------- 1 file changed, 42 insertions(+), 81 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 89bfbd4164..bc6939d673 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -9,10 +9,10 @@ from narwhals._plan import expressions as ir from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import DataFrameGroupBy -from narwhals._utils import Implementation, requires +from narwhals._utils import Implementation if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Mapping + from collections.abc import Iterator, Mapping from typing_extensions import Self, TypeAlias @@ -83,24 +83,6 @@ """https://arrow.apache.org/docs/20.0/python/compute.html#grouped-aggregations""" -# TODO @dangotbanned: Factor out to just a `bool.__ior__` -def _ensure_single_thread( - grouped: pa.TableGroupBy, expr: ir.OrderableAggExpr, / -) -> pa.TableGroupBy: - """First/last require disabling threading.""" - if BACKEND_VERSION >= (14, 0) and grouped._use_threads: - # NOTE: Stubs say `_table` is a method, but at runtime it is a property - grouped = pa.TableGroupBy(grouped._table, grouped.keys, use_threads=False) # type: ignore[arg-type] - elif BACKEND_VERSION < (14, 0): # pragma: no cover - msg = ( - f"Using `{expr!r}` in a `group_by().agg(...)` context is only available in 'pyarrow>=14.0.0', " - f"found version {requires._unparse_version(BACKEND_VERSION)!r}.\n\n" - f"See https://github.com/apache/arrow/issues/36709" - ) - raise NotImplementedError(msg) - return grouped - - def group_by_error( expr: ArrowAggExpr, reason: Literal[ @@ -121,14 +103,17 @@ def 
group_by_error( class ArrowAggExpr: def __init__(self, named_ir: NamedIR, /) -> None: self.named_ir: NamedIR = named_ir + self.use_threads: bool = True + """See https://github.com/apache/arrow/issues/36709""" + self.spec: AceroAggSpec @property def output_name(self) -> OutputName: return self.named_ir.name def _parse_agg_expr( - self, expr: agg.AggExpr, grouped: pa.TableGroupBy - ) -> tuple[AceroTarget, Aggregation, AggregateOptions | None, pa.TableGroupBy]: + self, expr: agg.AggExpr + ) -> tuple[AceroTarget, Aggregation, AggregateOptions | None]: if agg_name := SUPPORTED_AGG.get(type(expr)): option: AggregateOptions | None = None if isinstance(expr, (agg.Std, agg.Var)): @@ -140,10 +125,9 @@ def _parse_agg_expr( option = pc.CountOptions(mode="only_valid") elif isinstance(expr, (agg.First, agg.Last)): option = pc.ScalarAggregateOptions(skip_nulls=False) - # NOTE: Only branch which needs access to `pa.TableGroupBy` - grouped = _ensure_single_thread(grouped, expr) + self.use_threads = False if isinstance(expr.expr, ir.Column): - return [expr.expr.name], agg_name, option, grouped + return [expr.expr.name], agg_name, option raise group_by_error(self, "too complex") raise group_by_error(self, "unsupported aggregation") @@ -160,12 +144,12 @@ def _parse_function_expr(self, expr: ir.FunctionExpr) -> NativeAggSpec: return [expr.input[0].name], agg_name, option raise group_by_error(self, "too complex") - def to_native(self, grouped: pa.TableGroupBy) -> tuple[pa.TableGroupBy, AceroAggSpec]: + def parse(self) -> Self: expr = self.named_ir.expr input_name: AceroTarget = () option: AggregateOptions | None = None if isinstance(expr, agg.AggExpr): - input_name, agg_name, option, grouped = self._parse_agg_expr(expr, grouped) + input_name, agg_name, option = self._parse_agg_expr(expr) elif isinstance(expr, ir.FunctionExpr): input_name, agg_name, option = self._parse_function_expr(expr) elif isinstance(expr, (ir.Len, ir.Column)): @@ -174,13 +158,12 @@ def to_native(self, grouped: 
pa.TableGroupBy) -> tuple[pa.TableGroupBy, AceroAgg input_name = [expr.name] else: raise group_by_error(self, "unsupported expression") - agg_spec = input_name, agg_name, option, self.output_name - return grouped, agg_spec + self.spec = input_name, agg_name, option, self.output_name + return self class ArrowGroupBy(DataFrameGroupBy["ArrowDataFrame"]): _df: ArrowDataFrame - _grouped: pa.TableGroupBy _keys: Seq[NamedIR] _keys_names: Seq[str] @@ -190,7 +173,6 @@ def by_names(cls, df: ArrowDataFrame, names: Seq[str], /) -> Self: obj._df = df obj._keys = () obj._keys_names = names - obj._grouped = pa.TableGroupBy(df.native, list(names)) return obj @classmethod @@ -205,55 +187,34 @@ def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]: raise NotImplementedError def agg(self, irs: Seq[NamedIR]) -> ArrowDataFrame: - gb = self._grouped aggs: list[AceroAggSpec] = [] + use_threads: bool = True for e in irs: - gb, agg_spec = ArrowAggExpr(e).to_native(gb) - aggs.append(agg_spec) - result = _aggregate( - self.compliant.native, - list(self.keys_names), - aggs, - use_threads=gb._use_threads, - ) - return self.compliant._with_native(result) - - -def _aggregate( - df: pa.Table, - keys: list[str], - aggregations: Iterable[ # TODO @dangotbanned: Revisit after replacing `_ensure_single_thread` - AceroAggSpec - ], - *, - use_threads: bool, -) -> pa.Table: - """Adapted from [`pa.TableGroupBy.aggregate`](https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/table.pxi#L6600-L6626).""" - aggs = list(aggregations) if not isinstance(aggregations, list) else aggregations - return _group_by(df, keys, aggs, use_threads=use_threads) - - -def _group_by( - table: pa.Table, - keys: list[str], - aggregates: list[AceroAggSpec], - *, - use_threads: bool = True, -) -> pa.Table: - """Backport of [apache/arrow#36768]. - - `first` and `last` were [broken in `pyarrow==13`]. - - Also allows us to specify our own aliases for aggregate output columns. 
- - [apache/arrow#36768]: https://github.com/apache/arrow/pull/36768 - [broken in `pyarrow==13`]: https://github.com/apache/arrow/issues/36709 - """ - # NOTE: Stubs are (incorrectly) invariant - aggs: Incomplete = aggregates - keys_: Incomplete = keys - decls = [ - pac.Declaration("table_source", pac.TableSourceNodeOptions(table)), - pac.Declaration("aggregate", pac.AggregateNodeOptions(aggs, keys=keys_)), - ] - return pac.Declaration.from_sequence(decls).to_table(use_threads=use_threads) + expr = ArrowAggExpr(e).parse() + use_threads = use_threads and expr.use_threads + aggs.append(expr.spec) + return self.compliant._with_native(self._agg(aggs, use_threads=use_threads)) + + def _agg(self, agg_specs: list[AceroAggSpec], /, *, use_threads: bool) -> pa.Table: + """Adapted from [`pa.TableGroupBy.aggregate`] and [`pa.acero._group_by`]. + + - Backport of [apache/arrow#36768]. + - `first` and `last` were [broken in `pyarrow==13`]. + - Also allows us to specify our own aliases for aggregate output columns. 
+ - Fixes [narwhals-dev/narwhals#1612] + + [`pa.TableGroupBy.aggregate`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/table.pxi#L6600-L6626 + [`pa.acero._group_by`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/acero.py#L412-L418 + [apache/arrow#36768]: https://github.com/apache/arrow/pull/36768 + [broken in `pyarrow==13`]: https://github.com/apache/arrow/issues/36709 + [narwhals-dev/narwhals#1612]: https://github.com/narwhals-dev/narwhals/issues/1612 + """ + df = self.compliant.native + # NOTE: Stubs are (incorrectly) invariant + keys: Incomplete = list(self.keys_names) + aggs: Incomplete = agg_specs + decls = [ + pac.Declaration("table_source", pac.TableSourceNodeOptions(df)), + pac.Declaration("aggregate", pac.AggregateNodeOptions(aggs, keys=keys)), + ] + return pac.Declaration.from_sequence(decls).to_table(use_threads=use_threads) From ce86f8f76f0628fbe643860d5d14947df6b3a308 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 20 Sep 2025 13:50:49 +0000 Subject: [PATCH 24/93] test: Port over `first`, `last` group_by tests From #2528 https://github.com/narwhals-dev/narwhals/blob/0fb045536f5b56b978f354f8178b292301e9598c/tests/frame/group_by_test.py#L686-L781 --- tests/plan/group_by_test.py | 68 ++++++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index 457848f2d4..a91f5c1e17 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -16,7 +16,7 @@ import pyarrow as pa if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Mapping, Sequence def dataframe(data: dict[str, Any], /) -> nwp.DataFrame[Any, Any]: @@ -357,3 +357,69 @@ def test_top_level_len() -> None: .sort("gender") ) assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "aggs", 
"expected", "pre_sort"), + [ + (["a"], ["b"], {"a": [1, 2, 3, 4], "b": [1, 2, 4, 6]}, None), + (["a"], ["b"], {"a": [1, 2, 3, 4], "b": [1, 3, 5, 6]}, {"descending": True}), + (["a"], ["c"], {"a": [1, 2, 3, 4], "c": [None, "A", None, "B"]}, None), + ( + ["a"], + ["c"], + {"a": [1, 2, 3, 4], "c": [None, "A", "B", "B"]}, + {"nulls_last": True}, + ), + ], + ids=["no-sort", "sort-descending", "NA-order-nulls-first", "NA-order-nulls-last"], +) +def test_group_by_agg_first( + keys: Sequence[str], + aggs: Sequence[str], + expected: Mapping[str, Any], + pre_sort: Mapping[str, Any] | None, +) -> None: + data = { + "a": [1, 2, 2, 3, 3, 4], + "b": [1, 2, 3, 4, 5, 6], + "c": [None, "A", "A", None, "B", "B"], + } + df = dataframe(data) + if pre_sort: + df = df.sort(aggs, **pre_sort) + result = df.group_by(keys).agg(nwp.col(aggs).first()).sort(keys) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "aggs", "expected", "pre_sort"), + [ + (["a"], ["b"], {"a": [1, 2, 3, 4], "b": [1, 3, 5, 6]}, None), + (["a"], ["b"], {"a": [1, 2, 3, 4], "b": [1, 2, 4, 6]}, {"descending": True}), + (["a"], ["c"], {"a": [1, 2, 3, 4], "c": [None, "A", "B", "B"]}, None), + ( + ["a"], + ["c"], + {"a": [1, 2, 3, 4], "c": [None, "A", None, "B"]}, + {"nulls_last": True}, + ), + ], + ids=["no-sort", "sort-descending", "NA-order-nulls-first", "NA-order-nulls-last"], +) +def test_group_by_agg_last( + keys: Sequence[str], + aggs: Sequence[str], + expected: Mapping[str, Any], + pre_sort: Mapping[str, Any] | None, +) -> None: + data = { + "a": [1, 2, 2, 3, 3, 4], + "b": [1, 2, 3, 4, 5, 6], + "c": [None, "A", "A", None, "B", "B"], + } + df = dataframe(data) + if pre_sort: + df = df.sort(aggs, **pre_sort) + result = df.group_by(keys).agg(nwp.col(aggs).last()).sort(keys) + assert_equal_data(result, expected) From 581e51123622d6802cc583c5bd9aff88155e65d8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 20 Sep 2025 14:27:53 +0000 
Subject: [PATCH 25/93] test: Add failing `drop_null_keys`, `__iter__` tests --- narwhals/_plan/group_by.py | 8 +++++- tests/plan/group_by_test.py | 56 +++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index d4c4e471d1..56f0be1142 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -20,7 +20,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, NamedTuple +from typing import TYPE_CHECKING, Any, Generic, NamedTuple from narwhals._plan import _parse from narwhals._plan._expansion import ( @@ -32,6 +32,8 @@ from narwhals._plan.typing import DataFrameT if TYPE_CHECKING: + from collections.abc import Iterator + from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq from narwhals.schema import Schema @@ -69,6 +71,10 @@ def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFra grouped = compliant_gb.by_names(compliant, resolved.keys_names) return self._frame._from_compliant(grouped.agg(resolved.aggs)) + def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: + msg = "Not Implemented `GroupBy.__iter__`" + raise NotImplementedError(msg) + class _TempGroupByStuff(NamedTuple): """Trying to organize info that's useful to keep from `resolve_group_by`. 
diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index a91f5c1e17..b44fad0454 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -27,6 +27,26 @@ def assert_equal_data(result: nwp.DataFrame, expected: Mapping[str, Any]) -> Non _assert_equal_data(result.to_dict(as_series=False), expected) +@pytest.mark.xfail(reason="Not implemented `__iter__`", raises=NotImplementedError) +def test_group_by_iter() -> None: # pragma: no cover + data = {"a": [1, 1, 3], "b": [4, 4, 6], "c": [7.0, 8.0, 9.0]} + df = dataframe(data) + expected_keys: list[tuple[int, ...]] = [(1,), (3,)] + keys = [] + for key, sub_df in df.group_by("a"): + if key == (1,): + expected = {"a": [1, 1], "b": [4, 4], "c": [7.0, 8.0]} + assert_equal_data(sub_df, expected) + assert isinstance(sub_df, nwp.DataFrame) + keys.append(key) + assert sorted(keys) == sorted(expected_keys) + expected_keys = [(1, 4), (3, 6)] + keys = [key for key, _ in df.group_by("a", "b")] + assert sorted(keys) == sorted(expected_keys) + keys = [key for key, _ in df.group_by("a", "b")] + assert sorted(keys) == sorted(expected_keys) + + @pytest.mark.parametrize( ("attr", "expected"), [ @@ -180,6 +200,42 @@ def test_key_with_nulls_ignored() -> None: # pragma: no cover assert_equal_data(result, expected) +@pytest.mark.xfail( + reason="Not implemented `drop_null_keys`, `__iter__`", raises=NotImplementedError +) +def test_key_with_nulls_iter() -> None: # pragma: no cover + data = { + "b": [None, "4", "5", None, "7"], + "a": [None, 1, 2, 3, 4], + "c": [None, "4", "3", None, None], + } + result = dict(dataframe(data).group_by("b", "c", drop_null_keys=True).__iter__()) + + assert len(result) == 2 + assert_equal_data(result[("4", "4")], {"b": ["4"], "a": [1], "c": ["4"]}) + assert_equal_data(result[("5", "3")], {"b": ["5"], "a": [2], "c": ["3"]}) + + result = dict(dataframe(data).group_by("b", "c", drop_null_keys=False).__iter__()) + assert_equal_data(result[("4", "4")], {"b": ["4"], "a": [1], "c": 
["4"]}) + assert_equal_data(result[("5", "3")], {"b": ["5"], "a": [2], "c": ["3"]}) + assert len(result) == 4 + + +@pytest.mark.xfail(reason="Not implemented `drop_null_keys`, `Expr` as keys") +@pytest.mark.parametrize( + "keys", [[nwp.col("a").abs()], ["a", nwp.col("a").abs().alias("a_test")]] +) +def test_group_by_raise_drop_null_keys_with_exprs( + keys: list[nwp.Expr | str], +) -> None: # pragma: no cover + data = {"a": [1, 1, 2, 2, -1], "x": [0, 1, 2, 3, 4], "y": [0.5, -0.5, 1.0, -1.0, 1.5]} + df = dataframe(data) + with pytest.raises( + NotImplementedError, match="drop_null_keys cannot be True when keys contains Expr" + ): + df.group_by(*keys, drop_null_keys=True) # type: ignore[call-overload,unused-ignore] + + def test_no_agg() -> None: data = {"a": [1, 1, 3], "b": [4, 4, 6], "c": [7.0, 8.0, 9.0]} result = dataframe(data).group_by(["a", "b"]).agg().sort("a", "b") From 4b77500764ca1a872d5852cf31521bfe9a302eb9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 20 Sep 2025 15:34:16 +0000 Subject: [PATCH 26/93] feat(expr-ir): Support `group_by(drop_null_keys=True)` `ArrowDataFrame.drop_nulls` is shorter and waaaaaaay more efficient than `main` --- narwhals/_plan/arrow/dataframe.py | 15 ++++++++++++--- narwhals/_plan/arrow/group_by.py | 6 +++++- narwhals/_plan/dataframe.py | 22 ++++++++++++++++++---- narwhals/_plan/group_by.py | 10 ++++++++-- narwhals/_plan/protocols.py | 4 +++- tests/plan/group_by_test.py | 9 +++------ 6 files changed, 49 insertions(+), 17 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 667030a6fb..17da99163b 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import reduce from typing import TYPE_CHECKING, Any, Literal, cast, overload import pyarrow as pa # ignore-banned-import @@ -12,7 +13,7 @@ from narwhals._plan.expressions import 
NamedIR from narwhals._plan.protocols import EagerDataFrame, namespace from narwhals._plan.typing import Seq -from narwhals._utils import Version +from narwhals._utils import Version, parse_columns_to_drop from narwhals.schema import Schema if TYPE_CHECKING: @@ -102,10 +103,18 @@ def get_column(self, name: str) -> Series: chunked = self.native.column(name) return Series.from_native(chunked, name, version=self.version) - def drop(self, columns: Sequence[str]) -> Self: - to_drop = list(columns) + def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: + to_drop = parse_columns_to_drop(self, columns, strict=strict) return self._with_native(self.native.drop(to_drop)) + def drop_nulls(self, subset: Sequence[str] | None) -> Self: + if subset is None: + native = self.native.drop_null() + else: + to_drop = reduce(pc.or_, (pc.field(name).is_null() for name in subset)) + native = self.native.filter(~to_drop) + return self._with_native(native) + def rename(self, mapping: Mapping[str, str]) -> Self: names: dict[str, str] | list[str] if fn.BACKEND_VERSION >= (17,): diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index bc6939d673..bb307152ec 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -168,8 +168,12 @@ class ArrowGroupBy(DataFrameGroupBy["ArrowDataFrame"]): _keys_names: Seq[str] @classmethod - def by_names(cls, df: ArrowDataFrame, names: Seq[str], /) -> Self: + def by_names( + cls, df: ArrowDataFrame, names: Seq[str], /, *, drop_null_keys: bool = False + ) -> Self: obj = cls.__new__(cls) + if drop_null_keys: + df = df.drop_nulls(names) obj._df = df obj._keys = () obj._keys_names = names diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 7e265bc2d7..55b75e3e00 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload +from typing_extensions import Self + 
from narwhals._plan import _expansion, _parse from narwhals._plan.contexts import ExprContext from narwhals._plan.expr import _parse_sort_by @@ -19,6 +21,8 @@ from narwhals.schema import Schema if TYPE_CHECKING: + from collections.abc import Sequence + import pyarrow as pa from typing_extensions import Self @@ -97,6 +101,12 @@ def sort( named_irs = _expansion.into_named_irs(irs, output_names) return self._from_compliant(self._compliant.sort(named_irs, opts)) + def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: + raise NotImplementedError + + def drop_nulls(self, subset: str | Sequence[str] | None = None) -> Self: + raise NotImplementedError + class DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]): _compliant: CompliantDataFrame[Any, NativeDataFrameT, NativeSeriesT] @@ -146,8 +156,12 @@ def group_by( drop_null_keys: bool = False, **named_by: IntoExpr, ) -> GroupBy[Self]: - if drop_null_keys: - msg = "TODO: `drop_null_keys=True`" - raise NotImplementedError(msg) exprs = _parse.parse_into_seq_of_expr_ir(*by, **named_by) - return GroupBy(self, exprs) + return GroupBy(self, exprs, drop_null_keys=drop_null_keys) + + def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: + return self._from_compliant(self._compliant.drop(columns, strict=strict)) + + def drop_nulls(self, subset: str | Sequence[str] | None = None) -> Self: + subset = [subset] if isinstance(subset, str) else subset + return self._from_compliant(self._compliant.drop_nulls(subset)) diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index 56f0be1142..cc417e879e 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -42,10 +42,14 @@ class GroupBy(Generic[DataFrameT]): _frame: DataFrameT _keys: Seq[ExprIR] + _drop_null_keys: bool - def __init__(self, frame: DataFrameT, keys: Seq[ExprIR], /) -> None: + def __init__( + self, frame: DataFrameT, keys: Seq[ExprIR], /, *, drop_null_keys: bool = False + ) -> 
None: self._frame = frame self._keys = keys + self._drop_null_keys = drop_null_keys def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT: frame = self._frame @@ -68,7 +72,9 @@ def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFra grouped = compliant_gb.by_named_irs(compliant, resolved.keys) else: # noqa: RET506 # If not, we can just use the resolved key names as a fast-path - grouped = compliant_gb.by_names(compliant, resolved.keys_names) + grouped = compliant_gb.by_names( + compliant, resolved.keys_names, drop_null_keys=self._drop_null_keys + ) return self._frame._from_compliant(grouped.agg(resolved.aggs)) def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index f1084e42fc..567ec52a9e 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -568,6 +568,8 @@ def _evaluate_irs( def select(self, irs: Seq[NamedIR]) -> Self: ... def with_columns(self, irs: Seq[NamedIR]) -> Self: ... def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: ... + def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: ... + def drop_nulls(self, subset: Sequence[str] | None) -> Self: ... class CompliantDataFrame( @@ -608,7 +610,7 @@ class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]): @classmethod def by_names( - cls, df: DataFrameT, names: Seq[str], / + cls, df: DataFrameT, names: Seq[str], /, *, drop_null_keys: bool = False ) -> DataFrameGroupBy[DataFrameT]: ... 
# TODO @dangotbanned: Plan how projection should work diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index b44fad0454..538c2718b0 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -186,8 +186,7 @@ def test_key_with_nulls() -> None: assert_equal_data(result, expected) -@pytest.mark.xfail(reason="Not implemented `drop_null_keys`") -def test_key_with_nulls_ignored() -> None: # pragma: no cover +def test_key_with_nulls_ignored() -> None: data = {"b": [4, 5, None], "a": [1, 2, 3]} result = ( dataframe(data) @@ -200,9 +199,7 @@ def test_key_with_nulls_ignored() -> None: # pragma: no cover assert_equal_data(result, expected) -@pytest.mark.xfail( - reason="Not implemented `drop_null_keys`, `__iter__`", raises=NotImplementedError -) +@pytest.mark.xfail(reason="Not implemented `__iter__`", raises=NotImplementedError) def test_key_with_nulls_iter() -> None: # pragma: no cover data = { "b": [None, "4", "5", None, "7"], @@ -221,7 +218,7 @@ def test_key_with_nulls_iter() -> None: # pragma: no cover assert len(result) == 4 -@pytest.mark.xfail(reason="Not implemented `drop_null_keys`, `Expr` as keys") +@pytest.mark.xfail(reason="Not implemented `Expr` as keys") @pytest.mark.parametrize( "keys", [[nwp.col("a").abs()], ["a", nwp.col("a").abs().alias("a_test")]] ) From 0a94af61564cbb300acf09a6553c752d10957d18 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 20 Sep 2025 15:39:40 +0000 Subject: [PATCH 27/93] fix: avoid `typing_extensions` import (again) - https://github.com/narwhals-dev/narwhals/actions/runs/17881703473/job/50850187525?pr=3143 - https://github.com/narwhals-dev/narwhals/pull/3143/commits/8aaf9a94ecdebbc1b5422faa8542d7b515bdf0fe --- narwhals/_plan/dataframe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 55b75e3e00..80784d9c5b 100644 --- a/narwhals/_plan/dataframe.py +++ 
b/narwhals/_plan/dataframe.py @@ -2,8 +2,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload -from typing_extensions import Self - from narwhals._plan import _expansion, _parse from narwhals._plan.contexts import ExprContext from narwhals._plan.expr import _parse_sort_by From 874e736ad2d5c7f8d2377da68314064199eb576a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 21 Sep 2025 20:40:01 +0000 Subject: [PATCH 28/93] feat(expr-ir): Reject `drop_null_keys` with `Expr` --- narwhals/_plan/dataframe.py | 13 +++++++++++++ narwhals/_plan/group_by.py | 3 +++ tests/plan/group_by_test.py | 7 ++----- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 80784d9c5b..ba8d7b0e9d 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -148,6 +148,19 @@ def to_dict( def __len__(self) -> int: return len(self._compliant) + @overload + def group_by( + self, + *by: OneOrIterable[IntoExpr], + drop_null_keys: Literal[False] = ..., + **named_by: IntoExpr, + ) -> GroupBy[Self]: ... + + @overload + def group_by( + self, *by: OneOrIterable[str], drop_null_keys: Literal[True] + ) -> GroupBy[Self]: ... + def group_by( self, *by: OneOrIterable[IntoExpr], diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index cc417e879e..503020e375 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -62,6 +62,9 @@ def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFra compliant_gb = compliant._group_by # Do we need to project first? 
if not all(key.is_column() for key in resolved.keys): + if self._drop_null_keys: + msg = "drop_null_keys cannot be True when keys contains Expr or Series" + raise NotImplementedError(msg) msg = fmt_group_by_error( "Need to sketch out non-projecting keys group by first", resolved.keys, diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index 538c2718b0..ae175835e8 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -218,19 +218,16 @@ def test_key_with_nulls_iter() -> None: # pragma: no cover assert len(result) == 4 -@pytest.mark.xfail(reason="Not implemented `Expr` as keys") @pytest.mark.parametrize( "keys", [[nwp.col("a").abs()], ["a", nwp.col("a").abs().alias("a_test")]] ) -def test_group_by_raise_drop_null_keys_with_exprs( - keys: list[nwp.Expr | str], -) -> None: # pragma: no cover +def test_group_by_raise_drop_null_keys_with_exprs(keys: list[nwp.Expr | str]) -> None: data = {"a": [1, 1, 2, 2, -1], "x": [0, 1, 2, 3, 4], "y": [0.5, -0.5, 1.0, -1.0, 1.5]} df = dataframe(data) with pytest.raises( NotImplementedError, match="drop_null_keys cannot be True when keys contains Expr" ): - df.group_by(*keys, drop_null_keys=True) # type: ignore[call-overload,unused-ignore] + df.group_by(*keys, drop_null_keys=True).agg(nwp.sum("y")) # type: ignore[call-overload] def test_no_agg() -> None: From 42ccd1428a0f46ff51f7b52d3a78d843f8635d18 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 21 Sep 2025 23:02:11 +0000 Subject: [PATCH 29/93] refactor: Return `NamedIR` from `prepare_projection` Make a lot of stuff simpler --- narwhals/_plan/_expansion.py | 13 +-- narwhals/_plan/_rewrites.py | 5 +- narwhals/_plan/dataframe.py | 8 +- narwhals/_plan/group_by.py | 41 +++---- tests/plan/expr_expansion_test.py | 182 +++++++++++++++++------------- tests/plan/expr_rewrites_test.py | 7 +- tests/plan/utils.py | 15 ++- 7 files changed, 140 insertions(+), 131 deletions(-) diff --git 
a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index 06a9f39dd5..7d505ae653 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -154,8 +154,8 @@ def with_multiple_columns(self) -> ExpansionFlags: def prepare_projection( - exprs: Sequence[ExprIR], schema: IntoFrozenSchema -) -> tuple[Seq[ExprIR], FrozenSchema, OutputNames]: + exprs: Sequence[ExprIR], /, keys: GroupByKeys = (), *, schema: IntoFrozenSchema +) -> tuple[Seq[NamedIR], FrozenSchema]: """Expand IRs into named column selections. **Primary entry-point**, will be used by `select`, `with_columns`, @@ -163,15 +163,14 @@ def prepare_projection( Arguments: exprs: IRs that *may* contain things like `Columns`, `SelectorIR`, `Exclude`, etc. + keys: Names of `group_by` columns. schema: Scope to expand multi-column selectors in. - - Returns: - `exprs`, rewritten using `Column(name)` only. """ frozen_schema = freeze_schema(schema) - rewritten = rewrite_projections(tuple(exprs), schema=frozen_schema) + rewritten = rewrite_projections(tuple(exprs), keys=keys, schema=frozen_schema) output_names = ensure_valid_exprs(rewritten, frozen_schema) - return rewritten, frozen_schema, output_names + named_irs = into_named_irs(rewritten, output_names) + return named_irs, frozen_schema def into_named_irs(exprs: Seq[ExprIR], names: OutputNames) -> Seq[NamedIR]: diff --git a/narwhals/_plan/_rewrites.py b/narwhals/_plan/_rewrites.py index ae23fa4b9b..fd26364e66 100644 --- a/narwhals/_plan/_rewrites.py +++ b/narwhals/_plan/_rewrites.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from narwhals._plan._expansion import into_named_irs, prepare_projection +from narwhals._plan._expansion import prepare_projection from narwhals._plan._guards import ( is_aggregation, is_binary_expr, @@ -31,8 +31,7 @@ def rewrite_all( - Currently we do a full traversal of each tree per-rewrite function - There's no caching *after* `prepare_projection` yet """ - out_irs, _, names = 
prepare_projection(parse_into_seq_of_expr_ir(*exprs), schema) - named_irs = into_named_irs(out_irs, names) + named_irs, _ = prepare_projection(parse_into_seq_of_expr_ir(*exprs), schema=schema) return tuple(map_ir(ir, *rewrites) for ir in named_irs) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index ba8d7b0e9d..90e4add512 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -71,10 +71,9 @@ def _project( /, ) -> tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: """Temp, while these parts aren't connected, this is easier for testing.""" - irs, schema_frozen, output_names = _expansion.prepare_projection( - _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema + named_irs, schema_frozen = _expansion.prepare_projection( + _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self.schema ) - named_irs = _expansion.into_named_irs(irs, output_names) return schema_frozen.project(named_irs, context) def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: @@ -95,8 +94,7 @@ def sort( sort, opts = _parse_sort_by( by, *more_by, descending=descending, nulls_last=nulls_last ) - irs, _, output_names = _expansion.prepare_projection(sort, self.schema) - named_irs = _expansion.into_named_irs(irs, output_names) + named_irs, _ = _expansion.prepare_projection(sort, schema=self.schema) return self._from_compliant(self._compliant.sort(named_irs, opts)) def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index 503020e375..8167ca02bd 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -23,18 +23,14 @@ from typing import TYPE_CHECKING, Any, Generic, NamedTuple from narwhals._plan import _parse -from narwhals._plan._expansion import ( - ensure_valid_exprs, - into_named_irs, - rewrite_projections, -) -from narwhals._plan.schema import FrozenSchema, freeze_schema +from narwhals._plan._expansion 
import prepare_projection from narwhals._plan.typing import DataFrameT if TYPE_CHECKING: from collections.abc import Iterator from narwhals._plan.expressions import ExprIR, NamedIR + from narwhals._plan.schema import FrozenSchema from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq from narwhals.schema import Schema @@ -101,30 +97,21 @@ class _TempGroupByStuff(NamedTuple): def resolve_group_by( input_keys: Seq[ExprIR], input_aggs: Seq[ExprIR], schema: Schema ) -> _TempGroupByStuff: - input_schema = freeze_schema(schema) - - # "Initialize schema from keys" - keys = rewrite_projections(input_keys, schema=input_schema) - key_names = ensure_valid_exprs(keys, input_schema) - keys_named_irs = into_named_irs(keys, key_names) - output_schema = input_schema._select(keys_named_irs) - - # "Add aggregation column(s)" # noqa: ERA001 - aggs = rewrite_projections(input_aggs, keys=key_names, schema=input_schema) - aggs_names = ensure_valid_exprs(aggs, input_schema) - aggs_named_irs = into_named_irs(aggs, aggs_names) - aggs_schema = input_schema._select(aggs_named_irs) - - # "Coerce aggregation column(s) into List unless not needed (auto-implode)" # noqa: ERA001 - # TODO @dangotbanned: seems to just be a schema transform, maybe not important for now? - - # "Final output_schema" + # > Initialize schema from keys + keys, input_schema = prepare_projection(input_keys, schema=schema) + keys_names = tuple(e.name for e in keys) + output_schema = input_schema._select(keys) + + # > Add aggregation column(s) + aggs, _ = prepare_projection(input_aggs, keys_names, schema=input_schema) + aggs_schema = input_schema._select(aggs) + # > Final output_schema result_schema = output_schema.merge(aggs_schema) - # "Make sure aggregation columns do not contain keys or index columns" + # > Make sure aggregation columns do not contain keys or index columns # TODO @dangotbanned: Probably just the keys part? 
- # *index columns* seems to be rolling/dynamic only - return _TempGroupByStuff(keys_named_irs, aggs_named_irs, key_names, result_schema) + # *index columns* seems to be rolling/dynamic only + return _TempGroupByStuff(keys, aggs, keys_names, result_schema) def fmt_group_by_error( diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index a80724ff86..203c39911b 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -16,7 +16,7 @@ from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.schema import freeze_schema from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError -from tests.plan.utils import assert_expr_ir_equal +from tests.plan.utils import assert_expr_ir_equal, named_ir if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -257,19 +257,28 @@ def test_replace_selector( @pytest.mark.parametrize( ("into_exprs", "expected"), [ - ("a", [nwp.col("a")]), - (nwp.col("b", "c", "d"), [nwp.col("b"), nwp.col("c"), nwp.col("d")]), - (nwp.nth(6), [nwp.col("g")]), - (nwp.nth(9, 8, -5), [nwp.col("j"), nwp.col("i"), nwp.col("p")]), - ( + pytest.param("a", [nwp.col("a")], id="Col"), + pytest.param( + nwp.col("b", "c", "d"), + [nwp.col("b"), nwp.col("c"), nwp.col("d")], + id="Columns", + ), + pytest.param(nwp.nth(6), [nwp.col("g")], id="Nth"), + pytest.param( + nwp.nth(9, 8, -5), + [nwp.col("j"), nwp.col("i"), nwp.col("p")], + id="IndexColumns", + ), + pytest.param( [nwp.nth(2).alias("c again"), nwp.nth(-1, -2).name.to_uppercase()], [ - nwp.col("c").alias("c again"), - nwp.col("u").alias("U"), - nwp.col("s").alias("S"), + named_ir("c again", nwp.col("c")), + named_ir("U", nwp.col("u")), + named_ir("S", nwp.col("s")), ], + id="Nth-Alias-IndexColumns-Uppercase", ), - ( + pytest.param( nwp.all(), [ nwp.col("a"), @@ -293,82 +302,89 @@ def test_replace_selector( nwp.col("s"), nwp.col("u"), ], + id="All", ), - ( + pytest.param( (ndcs.numeric() - 
ndcs.by_dtype(nw.Float32(), nw.Float64())) .cast(nw.Int64) .mean() .name.suffix("_mean"), [ - nwp.col("a").cast(nw.Int64()).mean().alias("a_mean"), - nwp.col("b").cast(nw.Int64()).mean().alias("b_mean"), - nwp.col("c").cast(nw.Int64()).mean().alias("c_mean"), - nwp.col("d").cast(nw.Int64()).mean().alias("d_mean"), - nwp.col("e").cast(nw.Int64()).mean().alias("e_mean"), - nwp.col("f").cast(nw.Int64()).mean().alias("f_mean"), - nwp.col("g").cast(nw.Int64()).mean().alias("g_mean"), - nwp.col("h").cast(nw.Int64()).mean().alias("h_mean"), + named_ir("a_mean", nwp.col("a").cast(nw.Int64()).mean()), + named_ir("b_mean", nwp.col("b").cast(nw.Int64()).mean()), + named_ir("c_mean", nwp.col("c").cast(nw.Int64()).mean()), + named_ir("d_mean", nwp.col("d").cast(nw.Int64()).mean()), + named_ir("e_mean", nwp.col("e").cast(nw.Int64()).mean()), + named_ir("f_mean", nwp.col("f").cast(nw.Int64()).mean()), + named_ir("g_mean", nwp.col("g").cast(nw.Int64()).mean()), + named_ir("h_mean", nwp.col("h").cast(nw.Int64()).mean()), ], + id="Selector-SUB-Cast-Mean-Suffix", ), - ( + pytest.param( nwp.col("u").alias("1").alias("2").alias("3").alias("4").name.keep(), - # NOTE: Would be nice to rewrite with less intermediate steps - # but retrieving the root name is enough for now - [nwp.col("u").alias("1").alias("2").alias("3").alias("4").alias("u")], + [named_ir("u", nwp.col("u"))], + id="Alias-Etc-Keep", ), - ( + pytest.param( ( (ndcs.numeric() ^ (ndcs.matches(r"[abcdg]") | ndcs.by_name("i", "f"))) * 100 ).name.suffix("_mult_100"), [ - (nwp.col("e") * nwp.lit(100)).alias("e_mult_100"), - (nwp.col("h") * nwp.lit(100)).alias("h_mult_100"), - (nwp.col("j") * nwp.lit(100)).alias("j_mult_100"), + named_ir("e_mult_100", (nwp.col("e") * nwp.lit(100))), + named_ir("h_mult_100", (nwp.col("h") * nwp.lit(100))), + named_ir("j_mult_100", (nwp.col("j") * nwp.lit(100))), ], + id="Selector-XOR-OR-BinaryExpr-Suffix", ), - ( + pytest.param( ndcs.by_dtype(nw.Duration()) .dt.total_minutes() .name.map(lambda nm: 
f"total_mins: {nm!r} ?"), - [nwp.col("q").dt.total_minutes().alias("total_mins: 'q' ?")], + [named_ir("total_mins: 'q' ?", nwp.col("q").dt.total_minutes())], + id="ByDType-TotalMins-Name-Map", ), - ( + pytest.param( nwp.col("f", "g") .cast(nw.String) .str.starts_with("1") .all() .name.suffix("_all_starts_with_1"), [ - nwp.col("f") - .cast(nw.String) - .str.starts_with("1") - .all() - .alias("f_all_starts_with_1"), - nwp.col("g") - .cast(nw.String) - .str.starts_with("1") - .all() - .alias("g_all_starts_with_1"), + named_ir( + "f_all_starts_with_1", + nwp.col("f").cast(nw.String).str.starts_with("1").all(), + ), + named_ir( + "g_all_starts_with_1", + nwp.col("g").cast(nw.String).str.starts_with("1").all(), + ), ], + id="Cast-StartsWith-All-Suffix", ), - ( + pytest.param( nwp.col("a", "b") .first() .over("c", "e", order_by="d") .name.suffix("_first_over_part_order_1"), [ - nwp.col("a") - .first() - .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]) - .alias("a_first_over_part_order_1"), - nwp.col("b") - .first() - .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]) - .alias("b_first_over_part_order_1"), + named_ir( + "a_first_over_part_order_1", + nwp.col("a") + .first() + .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]), + ), + named_ir( + "b_first_over_part_order_1", + nwp.col("b") + .first() + .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]), + ), ], + id="First-Over-Partitioned-Ordered-Suffix", ), - ( + pytest.param( nwp.exclude(BIG_EXCLUDE), [ nwp.col("c"), @@ -379,42 +395,48 @@ def test_replace_selector( nwp.col("i"), nwp.col("j"), ], + id="Exclude", ), - ( + pytest.param( nwp.exclude(BIG_EXCLUDE).name.suffix("_2"), [ - nwp.col("c").alias("c_2"), - nwp.col("d").alias("d_2"), - nwp.col("f").alias("f_2"), - nwp.col("g").alias("g_2"), - nwp.col("h").alias("h_2"), - nwp.col("i").alias("i_2"), - nwp.col("j").alias("j_2"), + named_ir("c_2", nwp.col("c")), + named_ir("d_2", nwp.col("d")), + named_ir("f_2", nwp.col("f")), + 
named_ir("g_2", nwp.col("g")), + named_ir("h_2", nwp.col("h")), + named_ir("i_2", nwp.col("i")), + named_ir("j_2", nwp.col("j")), ], + id="Exclude-Suffix", ), - ( + pytest.param( nwp.col("c").alias("c_min_over_order_by").min().over(order_by=ndcs.string()), [ - nwp.col("c") - .alias("c_min_over_order_by") - .min() - .over(order_by=[nwp.col("k")]) + named_ir( + "c_min_over_order_by", + nwp.col("c").min().over(order_by=[nwp.col("k")]), + ) ], + id="Alias-Min-Over-Order-By-Selector", ), pytest.param( (ndcs.by_name("a", "b", "c") / nwp.col("e").first()) .over("g", "f", order_by="f") .name.prefix("hi_"), [ - (nwp.col("a") / nwp.col("e").first()) - .over("g", "f", order_by="f") - .alias("hi_a"), - (nwp.col("b") / nwp.col("e").first()) - .over("g", "f", order_by="f") - .alias("hi_b"), - (nwp.col("c") / nwp.col("e").first()) - .over("g", "f", order_by="f") - .alias("hi_c"), + named_ir( + "hi_a", + (nwp.col("a") / nwp.col("e").first()).over("g", "f", order_by="f"), + ), + named_ir( + "hi_b", + (nwp.col("b") / nwp.col("e").first()).over("g", "f", order_by="f"), + ), + named_ir( + "hi_c", + (nwp.col("c") / nwp.col("e").first()).over("g", "f", order_by="f"), + ), ], id="Selector-BinaryExpr-Over-Prefix", ), @@ -426,7 +448,7 @@ def test_prepare_projection( schema_1: dict[str, DType], ) -> None: irs_in = parse_into_seq_of_expr_ir(into_exprs) - actual, _, _ = prepare_projection(irs_in, schema_1) + actual, _ = prepare_projection(irs_in, schema=schema_1) assert len(actual) == len(expected) for lhs, rhs in zip(actual, expected): assert_expr_ir_equal(lhs, rhs) @@ -451,7 +473,7 @@ def test_prepare_projection_duplicate(expr: nwp.Expr, schema_1: dict[str, DType] irs = parse_into_seq_of_expr_ir(expr.alias("dupe")) pattern = re.compile(r"\.alias\(.dupe.\)") with pytest.raises(DuplicateError, match=pattern): - prepare_projection(irs, schema_1) + prepare_projection(irs, schema=schema_1) @pytest.mark.parametrize( @@ -517,7 +539,7 @@ def test_prepare_projection_column_not_found( pattern = 
re.compile(rf"not found: {re.escape(repr(missing))}") irs = parse_into_seq_of_expr_ir(into_exprs) with pytest.raises(ColumnNotFoundError, match=pattern): - prepare_projection(irs, schema_1) + prepare_projection(irs, schema=schema_1) @pytest.mark.parametrize( @@ -554,15 +576,15 @@ def test_prepare_projection_horizontal_alias( expr = function(into_exprs) alias_1 = expr.alias("alias(x1)") irs = parse_into_seq_of_expr_ir(alias_1) - out_irs, _, _ = prepare_projection(irs, schema_1) + out_irs, _ = prepare_projection(irs, schema=schema_1) assert len(out_irs) == 1 - assert out_irs[0] == function("a", "b", "c").alias("alias(x1)")._ir + assert out_irs[0] == named_ir("alias(x1)", function("a", "b", "c")) alias_2 = alias_1.alias("alias(x2)") irs = parse_into_seq_of_expr_ir(alias_2) - out_irs, _, _ = prepare_projection(irs, schema_1) + out_irs, _ = prepare_projection(irs, schema=schema_1) assert len(out_irs) == 1 - assert out_irs[0] == function("a", "b", "c").alias("alias(x1)").alias("alias(x2)")._ir + assert out_irs[0] == named_ir("alias(x2)", function("a", "b", "c")) @pytest.mark.parametrize( @@ -574,4 +596,4 @@ def test_prepare_projection_index_error( irs = parse_into_seq_of_expr_ir(into_exprs) pattern = re.compile(r"invalid.+index.+nth", re.DOTALL | re.IGNORECASE) with pytest.raises(ComputeError, match=pattern): - prepare_projection(irs, schema_1) + prepare_projection(irs, schema=schema_1) diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index bf810aa176..455fecd114 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -14,7 +14,7 @@ ) from narwhals._plan.expressions.window import Over from narwhals.exceptions import InvalidOperationError -from tests.plan.utils import assert_expr_ir_equal +from tests.plan.utils import assert_expr_ir_equal, named_ir if TYPE_CHECKING: from narwhals._plan.typing import IntoExpr @@ -79,11 +79,6 @@ def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: 
assert_expr_ir_equal(lhs, rhs) -def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> ir.NamedIR[ir.ExprIR]: - """Helper constructor for test compare.""" - return ir.NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name) - - def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: expected = ( named_ir("a", nwp.col("a")), diff --git a/tests/plan/utils.py b/tests/plan/utils.py index bf6135ee2f..d1ae2ce95e 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -36,9 +36,18 @@ def assert_expr_ir_equal( """ lhs = _unwrap_ir(actual) if isinstance(expected, str): - assert repr(lhs) == expected + assert repr(lhs) == expected, ( + f"\nlhs:\n {lhs!r}\n\nexpected:\n {expected!r}" + ) elif isinstance(actual, ir.NamedIR) and isinstance(expected, ir.NamedIR): - assert actual == expected + assert actual == expected, ( + f"\nactual:\n {actual!r}\n\nexpected:\n {expected!r}" + ) else: rhs = expected._ir if isinstance(expected, nwp.Expr) else expected - assert lhs == rhs + assert lhs == rhs, f"\nlhs:\n {lhs!r}\n\nrhs:\n {rhs!r}" + + +def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> ir.NamedIR[ir.ExprIR]: + """Helper constructor for test compare.""" + return ir.NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name) From 451498c4e98212481fd51209b013fd61c0159a09 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 21 Sep 2025 23:32:38 +0000 Subject: [PATCH 30/93] feat(expr-ir): *Almost* all `Expr` key tests passing! 
--- narwhals/_plan/arrow/expr.py | 5 ++- narwhals/_plan/arrow/group_by.py | 8 +++- narwhals/_plan/dataframe.py | 3 ++ narwhals/_plan/group_by.py | 9 +---- narwhals/_plan/protocols.py | 1 + tests/plan/group_by_test.py | 65 ++++++++++++++++++++++++-------- 6 files changed, 66 insertions(+), 25 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 2caf3e69b2..b547ed57fa 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -55,7 +55,7 @@ Not, ) from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr - from narwhals._plan.expressions.functions import FillNull, Pow + from narwhals._plan.expressions.functions import Abs, FillNull, Pow from narwhals.typing import Into1DArray, IntoDType, PythonLiteral Expr: TypeAlias = "ArrowExpr" @@ -112,6 +112,9 @@ def func(node: FunctionExpr[Any], frame: Frame, name: str) -> StoresNativeT_co: return func + def abs(self, node: FunctionExpr[Abs], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(pc.abs)(node, frame, name) + def not_(self, node: FunctionExpr[Not], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(pc.invert)(node, frame, name) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index bb307152ec..02ccc0f480 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -9,6 +9,7 @@ from narwhals._plan import expressions as ir from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import DataFrameGroupBy +from narwhals._plan.schema import freeze_schema from narwhals._utils import Implementation if TYPE_CHECKING: @@ -181,7 +182,12 @@ def by_names( @classmethod def by_named_irs(cls, df: ArrowDataFrame, irs: Seq[NamedIR], /) -> Self: - raise NotImplementedError + obj = cls.__new__(cls) + irs_all, _ = freeze_schema(df.schema)._with_columns(irs) + obj._df = df.with_columns(irs_all) + obj._keys = irs + obj._keys_names = 
() + return obj @property def compliant(self) -> ArrowDataFrame: diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 90e4add512..a9f1cd5777 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -80,6 +80,9 @@ def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: named_irs, _ = self._project(exprs, named_exprs, ExprContext.SELECT) return self._from_compliant(self._compliant.select(named_irs)) + # NOTE: Want to be able to call `with_columns` at compliant level - and still get the right schema + # - Currently it acts like select in `group_by` + # - Doing some gymnastics to workaround for now def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: named_irs, _ = self._project(exprs, named_exprs, ExprContext.WITH_COLUMNS) return self._from_compliant(self._compliant.with_columns(named_irs)) diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index 8167ca02bd..c7844bb0b4 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -61,15 +61,8 @@ def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFra if self._drop_null_keys: msg = "drop_null_keys cannot be True when keys contains Expr or Series" raise NotImplementedError(msg) - msg = fmt_group_by_error( - "Need to sketch out non-projecting keys group by first", - resolved.keys, - resolved.aggs, - resolved.result_schema, - ) - raise NotImplementedError(msg) grouped = compliant_gb.by_named_irs(compliant, resolved.keys) - else: # noqa: RET506 + else: # If not, we can just use the resolved key names as a fast-path grouped = compliant_gb.by_names( compliant, resolved.keys_names, drop_null_keys=self._drop_null_keys diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 567ec52a9e..e403b409b1 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -203,6 +203,7 @@ def _with_native(self, native: Any, name: str, /) -> Self: 
return self.from_native(native, name or self.name, self.version) # series & scalar + def abs(self, node: FunctionExpr[F.Abs], frame: FrameT_contra, name: str) -> Self: ... def cast(self, node: ir.Cast, frame: FrameT_contra, name: str) -> Self: ... def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index ae175835e8..8918218dd2 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -47,6 +47,16 @@ def test_group_by_iter() -> None: # pragma: no cover assert sorted(keys) == sorted(expected_keys) +def test_group_by_nw_all() -> None: + df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]}) + result = df.group_by("a").agg(nwp.all().sum()).sort("a") + expected = {"a": [1, 2], "b": [9, 6], "c": [15, 9]} + assert_equal_data(result, expected) + result = df.group_by("a").agg(nwp.all().sum().name.suffix("_sum")).sort("a") + expected = {"a": [1, 2], "b_sum": [9, 6], "c_sum": [15, 9]} + assert_equal_data(result, expected) + + @pytest.mark.parametrize( ("attr", "expected"), [ @@ -274,6 +284,36 @@ def test_double_same_aggregation() -> None: assert_equal_data(result, expected) +def test_all_kind_of_aggs() -> None: + df = dataframe({"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]}) + result = ( + df.group_by("a") + .agg( + c=nwp.col("b").mean(), + d=nwp.col("b").mean(), + e=nwp.col("b").std(ddof=1), + f=nwp.col("b").std(ddof=2), + g=nwp.col("b").var(ddof=2), + h=nwp.col("b").var(ddof=2), + i=nwp.col("b").n_unique(), + ) + .sort("a") + ) + + variance_num = sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) + expected = { + "a": [1, 2], + "c": [5, 10 / 3], + "d": [5, 10 / 3], + "e": [1, (variance_num / (3 - 1)) ** 0.5], + "f": [2**0.5, (variance_num) ** 0.5], # denominator is 1 (=3-2) + "g": [2.0, variance_num], # denominator is 1 (=3-2) + "h": [2.0, variance_num], # 
denominator is 1 (=3-2) + "i": [3, 2], + } + assert_equal_data(result, expected) + + def test_fancy_functions() -> None: df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6]}) result = df.group_by("a").agg(nwp.all().std(ddof=0)).sort("a") @@ -294,54 +334,49 @@ def test_fancy_functions() -> None: assert_equal_data(result, expected) -XFAIL_NOT_IMPL_EXPR_KEYS = pytest.mark.xfail( - reason="TODO: Expr group_by keys", raises=NotImplementedError -) - - +# TODO @dangotbanned: Investigate the single failing case @pytest.mark.parametrize( ("keys", "aggs", "expected", "sort_by"), [ - pytest.param( + ( [nwp.col("a").abs(), nwp.col("a").abs().alias("a_with_alias")], [nwp.col("x").sum()], {"a": [1, 2], "a_with_alias": [1, 2], "x": [5, 5]}, ["a"], - marks=XFAIL_NOT_IMPL_EXPR_KEYS, ), pytest.param( [nwp.col("a").alias("x")], [nwp.col("x").mean().alias("y")], {"x": [-1, 1, 2], "y": [4.0, 0.5, 2.5]}, ["x"], - marks=XFAIL_NOT_IMPL_EXPR_KEYS, + marks=pytest.mark.xfail( + reason="AssertionError: Mismatch at index 0: -1.0 != 4.0" + ), + id="FIXME", ), - ( # NOTE: This one is fine as it just selects + ( [nwp.col("a")], [nwp.col("a").count().alias("foo-bar"), nwp.all().sum()], {"a": [-1, 1, 2], "foo-bar": [1, 2, 2], "x": [4, 1, 5], "y": [1.5, 0, 0]}, ["a"], ), - pytest.param( + ( [nwp.col("a", "y").abs()], [nwp.col("x").sum()], {"a": [1, 1, 2], "y": [0.5, 1.5, 1], "x": [1, 4, 5]}, ["a", "y"], - marks=XFAIL_NOT_IMPL_EXPR_KEYS, ), - pytest.param( + ( [nwp.col("a").abs().alias("y")], [nwp.all().sum().name.suffix("c")], {"y": [1, 2], "ac": [1, 4], "xc": [5, 5]}, ["y"], - marks=XFAIL_NOT_IMPL_EXPR_KEYS, ), - pytest.param( + ( [npcs.by_dtype(nw.Float64()).abs()], [npcs.numeric().sum()], {"y": [0.5, 1.0, 1.5], "a": [2, 4, -1], "x": [1, 5, 4]}, ["y"], - marks=XFAIL_NOT_IMPL_EXPR_KEYS, ), ], ) From 780af661e7f1b560878772bca37ae0a86b86e8d9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 22 Sep 2025 13:45:51 +0000 Subject: [PATCH 31/93] fix(DRAFT): 
Roughly port over `ParseExprKeysGroupBy` - A lot of room for improvement - But its enough to pass the tests --- narwhals/_plan/arrow/group_by.py | 27 +++++++++++++++++++++++---- tests/plan/group_by_test.py | 26 ++++++++++++++++++++------ 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 02ccc0f480..214ad41061 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -24,6 +24,7 @@ from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq +from narwhals._plan.common import replace Incomplete: TypeAlias = Any @@ -167,6 +168,7 @@ class ArrowGroupBy(DataFrameGroupBy["ArrowDataFrame"]): _df: ArrowDataFrame _keys: Seq[NamedIR] _keys_names: Seq[str] + _keys_names_original: Seq[str] @classmethod def by_names( @@ -178,15 +180,29 @@ def by_names( obj._df = df obj._keys = () obj._keys_names = names + obj._keys_names_original = () return obj @classmethod def by_named_irs(cls, df: ArrowDataFrame, irs: Seq[NamedIR], /) -> Self: obj = cls.__new__(cls) - irs_all, _ = freeze_schema(df.schema)._with_columns(irs) - obj._df = df.with_columns(irs_all) - obj._keys = irs + _, schema = freeze_schema(df.schema)._with_columns(irs) + tmp_name_length = max(len(str(c)) for c in schema) + 1 + key_names_orig: list[str] = [] + + def _temporary_name(key: str) -> str: + # 5 is the length of `__tmp` + len__tmp = 5 + alias = f"_{key}_tmp{'_' * (tmp_name_length - len(key) - len__tmp)}" + key_names_orig.append(key) + return alias + + safe_keys = tuple(replace(key, name=_temporary_name(key.name)) for key in irs) + irs_final, _ = freeze_schema(df.schema)._with_columns(safe_keys) + obj._df = df.with_columns(irs_final) + obj._keys = safe_keys obj._keys_names = () + obj._keys_names_original = tuple(key_names_orig) return obj @property @@ -203,7 +219,10 @@ def agg(self, irs: Seq[NamedIR]) -> ArrowDataFrame: 
expr = ArrowAggExpr(e).parse() use_threads = use_threads and expr.use_threads aggs.append(expr.spec) - return self.compliant._with_native(self._agg(aggs, use_threads=use_threads)) + result = self.compliant._with_native(self._agg(aggs, use_threads=use_threads)) + if original := self._keys_names_original: + return result.rename(dict(zip(self.keys_names, original))) + return result def _agg(self, agg_specs: list[AceroAggSpec], /, *, use_threads: bool) -> pa.Table: """Adapted from [`pa.TableGroupBy.aggregate`] and [`pa.acero._group_by`]. diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index 8918218dd2..dee20afa4e 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -334,7 +334,6 @@ def test_fancy_functions() -> None: assert_equal_data(result, expected) -# TODO @dangotbanned: Investigate the single failing case @pytest.mark.parametrize( ("keys", "aggs", "expected", "sort_by"), [ @@ -344,15 +343,11 @@ def test_fancy_functions() -> None: {"a": [1, 2], "a_with_alias": [1, 2], "x": [5, 5]}, ["a"], ), - pytest.param( + ( [nwp.col("a").alias("x")], [nwp.col("x").mean().alias("y")], {"x": [-1, 1, 2], "y": [4.0, 0.5, 2.5]}, ["x"], - marks=pytest.mark.xfail( - reason="AssertionError: Mismatch at index 0: -1.0 != 4.0" - ), - id="FIXME", ), ( [nwp.col("a")], @@ -392,6 +387,25 @@ def test_group_by_expr( assert_equal_data(result, expected) +def test_group_by_expr_2757684799() -> None: + """From [narwhals-dev/narwhals#2325-2757684799]. 
+ + The **incorrect** result is: + + {'b': [2, 1], 'a': [2, 1], 'c': [2.0, 1.0]} + + [narwhals-dev/narwhals#2325-2757684799]: https://github.com/narwhals-dev/narwhals/pull/2325#pullrequestreview-2757684799 + """ + data: dict[str, Any] = {"a": [1, 1, 2], "b": [4, 5, 6], "unrelated": [10, -1, -9]} + df = dataframe(data) + keys = nwp.col("a").alias("b"), "a" + aggs = nwp.col("b").mean().alias("c") + expected = {"b": [2, 1], "a": [2, 1], "c": [6.0, 4.5]} + + result = df.group_by(keys).agg(aggs).sort("b", descending=True) + assert_equal_data(result, expected) + + def test_group_by_selector() -> None: data = { "a": [1, 1, 1], From e0debe56b46163cd92e5b4f3a62a5a7260a394b4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 22 Sep 2025 13:54:43 +0000 Subject: [PATCH 32/93] refactor: Slightly simplify temp name Gonna replace with something more generalized later --- narwhals/_plan/arrow/group_by.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 214ad41061..0cee61176c 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -187,13 +187,12 @@ def by_names( def by_named_irs(cls, df: ArrowDataFrame, irs: Seq[NamedIR], /) -> Self: obj = cls.__new__(cls) _, schema = freeze_schema(df.schema)._with_columns(irs) - tmp_name_length = max(len(str(c)) for c in schema) + 1 + len__tmp = 5 + tmp_name_length = (max(len(c) for c in schema) + 1) - len__tmp key_names_orig: list[str] = [] def _temporary_name(key: str) -> str: - # 5 is the length of `__tmp` - len__tmp = 5 - alias = f"_{key}_tmp{'_' * (tmp_name_length - len(key) - len__tmp)}" + alias = f"_{key}_tmp{'_' * (tmp_name_length - len(key))}" key_names_orig.append(key) return alias From 8b622d473a1ffd0285eb67f40f7cd5742f5cf4cb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 22 Sep 2025 16:21:27 +0000 
Subject: [PATCH 33/93] feat: Add temp column name utils --- narwhals/_plan/common.py | 119 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 0b4267f214..bbf6f4385c 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -6,15 +6,20 @@ from collections.abc import Iterable from decimal import Decimal from operator import attrgetter +from secrets import token_hex from typing import TYPE_CHECKING, cast, overload from narwhals._plan._guards import is_iterable_reject +from narwhals._utils import _hasattr_static from narwhals.dtypes import DType +from narwhals.exceptions import NarwhalsError from narwhals.utils import Version if TYPE_CHECKING: from collections.abc import Iterator - from typing import Any, Callable, TypeVar + from typing import Any, Callable, ClassVar, TypeVar + + from typing_extensions import TypeIs from narwhals._plan.typing import ( DTypeT, @@ -23,6 +28,7 @@ NonNestedDTypeT, OneOrIterable, ) + from narwhals._utils import _StoresColumns from narwhals.typing import NonNestedDType, NonNestedLiteral T = TypeVar("T") @@ -115,3 +121,114 @@ def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: yield from flatten_hash_safe(element) else: yield element # type: ignore[misc] + + +def _has_columns(obj: Any) -> TypeIs[_StoresColumns]: + return _hasattr_static(obj, "columns") + + +class temp: # noqa: N801 + """Temporary mini namespace for temporary utils.""" + + _MAX_ITERATIONS: ClassVar[int] = 100 + + @classmethod + def column_name( + cls, + source: _StoresColumns | Iterable[str], + /, + *, + prefix: str = "nw", + n_chars: int = 16, + ) -> str: + """Generate a single, unique column name that is not present in `source`.""" + columns = cls._into_columns(source) + prefix, n_bytes = cls._parse_prefix_n_bytes(prefix, n_chars) + for _ in range(cls._MAX_ITERATIONS): + token = f"{prefix}{token_hex(n_bytes)}" + if token not in 
columns: + return token + raise cls._failed_generation_error(columns, n_bytes) + + @classmethod + def column_names( + cls, + source: _StoresColumns | Iterable[str], + /, + *, + prefix: str = "nw", + n_chars: int = 16, + ) -> Iterator[str]: + """Yields unique column names that are not present in `source`. + + Any column name returned will be unique among those that preceded it. + """ + columns = cls._into_columns(source) + prefix, n_bytes = cls._parse_prefix_n_bytes(prefix, n_chars) + n_failed: int = 0 + while n_failed <= cls._MAX_ITERATIONS: + token = f"{prefix}{token_hex(n_bytes)}" + if token not in columns: + columns.add(token) + n_failed = 0 + yield token + else: + n_failed += 1 + raise cls._failed_generation_error(columns, n_bytes) + + @staticmethod + def _into_columns(source: _StoresColumns | Iterable[str], /) -> set[str]: + return set(source.columns if _has_columns(source) else source) + + @staticmethod + def _parse_prefix_n_bytes(prefix: str, n_chars: int, /) -> tuple[str, int]: + prefix = prefix or "nw" + n_bytes = (n_chars - len(prefix)) // 2 + if n_bytes < 2: + msg = ( + f"Temporary column name generation requires at least 4 characters to store random bytes, \n" + f"but not enough room with: {prefix=}, {n_chars=}.\n\n" + "Hint: Maybe try\n- a shorter `prefix`?\n- a higher `n_chars`?" + ) + raise NarwhalsError(msg) + return prefix, n_bytes + + @classmethod + def _failed_generation_error( + cls, columns: Iterable[str], n_bytes: int, / + ) -> NarwhalsError: # pragma: no cover + """Takes some work to trigger this, but it's possible 😅. + + Examples: + >>> import itertools + >>> from narwhals._plan.common import temp + >>> it = temp.column_names(["a", "b", "c"], prefix="long_prefix") + >>> list(itertools.islice(it, 100_000)) # doctest:+SKIP + Traceback (most recent call last): + ... 
+ NarwhalsError: Was unable to generate a column name with `n_bytes=2` within 100 iterations, + that was not present in existing (60246) columns: + [ + 'a', + 'b', + 'c', + 'long_prefix0000', + 'long_prefix0003', + 'long_prefix0004', + 'long_prefix0005', + 'long_prefix0006', + 'long_prefix0007', + 'long_prefix0008', + ..., + ] + """ + import reprlib + + current = sorted(columns) + truncated = reprlib.Repr(indent=4, maxlist=10).repr(current) + msg = ( + "Was unable to generate a column name with " + f"`{n_bytes=}` within {cls._MAX_ITERATIONS} iterations, \n" + f"that was not present in existing ({len(current)}) columns:\n{truncated}" + ) + return NarwhalsError(msg) From cad1a471e6f04f37ef43bf3c21dc7f8443af8743 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 22 Sep 2025 16:37:09 +0000 Subject: [PATCH 34/93] refactor: Replace temp naming stuff --- narwhals/_plan/arrow/group_by.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 0cee61176c..34498c847e 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -1,5 +1,6 @@ from __future__ import annotations +from itertools import chain from typing import TYPE_CHECKING, Any, Literal import pyarrow as pa # ignore-banned-import @@ -7,6 +8,7 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._plan import expressions as ir +from narwhals._plan.common import replace, temp from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import DataFrameGroupBy from narwhals._plan.schema import freeze_schema @@ -24,7 +26,6 @@ from narwhals._plan.arrow.dataframe import ArrowDataFrame from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq -from narwhals._plan.common import replace Incomplete: TypeAlias = Any @@ -186,22 +187,14 @@ def by_names( @classmethod def 
by_named_irs(cls, df: ArrowDataFrame, irs: Seq[NamedIR], /) -> Self: obj = cls.__new__(cls) - _, schema = freeze_schema(df.schema)._with_columns(irs) - len__tmp = 5 - tmp_name_length = (max(len(c) for c in schema) + 1) - len__tmp - key_names_orig: list[str] = [] - - def _temporary_name(key: str) -> str: - alias = f"_{key}_tmp{'_' * (tmp_name_length - len(key))}" - key_names_orig.append(key) - return alias - - safe_keys = tuple(replace(key, name=_temporary_name(key.name)) for key in irs) + keys_names = tuple(key.name for key in irs) + unique_names = temp.column_names(chain(keys_names, df.columns)) + safe_keys = tuple(replace(key, name=name) for key, name in zip(irs, unique_names)) irs_final, _ = freeze_schema(df.schema)._with_columns(safe_keys) obj._df = df.with_columns(irs_final) obj._keys = safe_keys obj._keys_names = () - obj._keys_names_original = tuple(key_names_orig) + obj._keys_names_original = keys_names return obj @property From 57c3e6d2cb67f9be4fb9e2f50ff9b2af4704633e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 22 Sep 2025 17:34:42 +0000 Subject: [PATCH 35/93] test: Add `Expr.unique` group_by tests --- tests/plan/group_by_test.py | 42 +++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index dee20afa4e..4ef74b77d0 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -18,6 +18,8 @@ if TYPE_CHECKING: from collections.abc import Mapping, Sequence + from narwhals._plan.typing import IntoExpr + def dataframe(data: dict[str, Any], /) -> nwp.DataFrame[Any, Any]: return nwp.DataFrame.from_native(pa.table(data)) @@ -522,3 +524,43 @@ def test_group_by_agg_last( df = df.sort(aggs, **pre_sort) result = df.group_by(keys).agg(nwp.col(aggs).last()).sort(keys) assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "aggs", "expected"), + [ + (["a"], [nwp.col("b").unique()], 
{"a": ["a", "b", "c"], "b": [[1], [2, 3], [3]]}), + ( + ["a"], + [nwp.col("b", "d").unique()], + { + "a": ["a", "b", "c"], + "b": [[1], [2, 3], [3]], + "d": [["three", "one"], ["three"], ["one"]], + }, + ), + ( + ["d", "c"], + [npcs.string().unique(), nwp.col("b").first().alias("b_first")], + { + "d": ["one", "one", "three", "three", "three"], + "c": [1, 3, 2, 4, 5], + "a": [["c"], ["a"], ["b"], ["b"], ["a"]], + "b_first": [3, 1, 3, 2, 1], + }, + ), + ], + ids=["Unique-Single", "Unique-Multi", "Unique-Selector-Fancy"], +) +def test_group_by_agg_unique( + keys: Sequence[str], aggs: Sequence[IntoExpr], expected: Mapping[str, Any] +) -> None: + data = { + "a": ["a", "b", "a", "b", "c"], + "b": [1, 2, 1, 3, 3], + "c": [5, 4, 3, 2, 1], + "d": ["three", "three", "one", "three", "one"], + } + df = dataframe(data) + result = df.group_by(keys).agg(aggs).sort(keys) + assert_equal_data(result, expected) From 87cd4a8170c8ac33d537dc7494aa176f2d883455 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:22:39 +0000 Subject: [PATCH 36/93] fix: Use `operator.or_` instead of `pyarrow.compute.or_` Didn't show up as an issue until trying to drop multiple keys Surprised pyarrow doesnt support this? 
--- narwhals/_plan/arrow/dataframe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 17da99163b..00f7ea86ed 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator from functools import reduce from typing import TYPE_CHECKING, Any, Literal, cast, overload @@ -111,7 +112,7 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: if subset is None: native = self.native.drop_null() else: - to_drop = reduce(pc.or_, (pc.field(name).is_null() for name in subset)) + to_drop = reduce(operator.or_, (pc.field(name).is_null() for name in subset)) native = self.native.filter(~to_drop) return self._with_native(native) From d49bcce3bf4137f45f5e84b0f93cfc398fe6df6e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:36:54 +0000 Subject: [PATCH 37/93] =?UTF-8?q?test:=20Steal=20some=20of=20the=20`polars?= =?UTF-8?q?`=20test=20suite=20=F0=9F=98=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/plan/group_by_test.py | 65 +++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index 4ef74b77d0..4039361083 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING, Any import pytest @@ -564,3 +565,67 @@ def test_group_by_agg_unique( df = dataframe(data) result = df.group_by(keys).agg(aggs).sort(keys) assert_equal_data(result, expected) + + +def test_group_by_args() -> None: + """Adapted from [upstream]. 
+ + [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L302-L325 + """ + data = { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + df = dataframe(data) + + # Single column name + assert df.group_by("a").agg("b").columns == ["a", "b"] + # Column names as list + expected = ["a", "b", "c"] + assert df.group_by(["a", "b"]).agg("c").columns == expected + # Column names as positional arguments + assert df.group_by("a", "b").agg("c").columns == expected + # With keyword argument + assert df.group_by("a", "b", drop_null_keys=True).agg("c").columns == expected + # Multiple aggregations as list + assert df.group_by("a").agg(["b", "c"]).columns == expected + # Multiple aggregations as positional arguments + assert df.group_by("a").agg("b", "c").columns == expected + # Multiple aggregations as keyword arguments + assert df.group_by("a").agg(q="b", r="c").columns == ["a", "q", "r"] + + +def test_group_by_all() -> None: + """Adapted from [upstream]. + + [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L568-L577 + """ + data = {"a": [1, 2], "b": [1, 2]} + df = dataframe(data) + expected = {"a": [1, 2], "b": [1, 2], "a_agg": [1, 2]} + result = df.group_by(nwp.all()).agg(nwp.col("a").max().name.suffix("_agg")).sort("a") + assert_equal_data(result, expected) + + +def test_group_by_input_independent_with_len_23868() -> None: + """Adapted from [upstream]. 
+ + [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L1476-L1484 + """ + data = {"a": ["A", "B", "C"]} + expected = {"literal": ["G"], "len": [3]} + result = dataframe(data).group_by(nwp.lit("G")).agg(nwp.len()) + assert_equal_data(result, expected) + + +def test_group_by_series_lit_22103() -> None: + """Adapted from [upstream], but rejecting for now. + + [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L1406-L1424 + """ + data = {"g": [0, 1]} + series = nwp.Series.from_native(pa.chunked_array([[42, 2, 3]])) + df = dataframe(data) + with pytest.raises(NotImplementedError, match=re.escape("foo=lit(Series)")): + df.group_by("g").agg(foo=series) From 9d72311dae1ea5d906cb18e46cf1fd6c1973ca56 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:45:38 +0000 Subject: [PATCH 38/93] test: `df.group_by(**named_by)` --- tests/plan/group_by_test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index 4039361083..c177502814 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -629,3 +629,17 @@ def test_group_by_series_lit_22103() -> None: df = dataframe(data) with pytest.raises(NotImplementedError, match=re.escape("foo=lit(Series)")): df.group_by("g").agg(foo=series) + + +def test_group_by_named() -> None: + """Adapted from [upstream]. 
+ + [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L878-884 + """ + data = {"a": [1, 1, 2, 2, 3, 3], "b": range(6)} + df = dataframe(data) + result = df.group_by(z=nwp.col("a") * 2).agg(nwp.col("b").min()).sort("b") + expected = ( + df.group_by((nwp.col("a") * 2).alias("z")).agg(nwp.col("b").min()).sort("b") + ) + assert_equal_data(result, expected.to_dict(as_series=False)) From 479eee6864214de8e6ab6d05fb8295e7de131c1a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:23:07 +0000 Subject: [PATCH 39/93] revert: Don't introduce unused type var --- narwhals/_plan/protocols.py | 1 - 1 file changed, 1 deletion(-) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index e403b409b1..4332a83f74 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -73,7 +73,6 @@ FrameT_co = TypeVar("FrameT_co", bound=FrameAny, covariant=True) FrameT_contra = TypeVar("FrameT_contra", bound=FrameAny, contravariant=True) DataFrameT = TypeVar("DataFrameT", bound=DataFrameAny) -DataFrameT_co = TypeVar("DataFrameT_co", bound=DataFrameAny, covariant=True) NamespaceT_co = TypeVar("NamespaceT_co", bound="NamespaceAny", covariant=True) EagerExprT_co = TypeVar("EagerExprT_co", bound=EagerExprAny, covariant=True) From 4391a6fa5b65c0f5d0a606864de503e803cd570b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:25:33 +0000 Subject: [PATCH 40/93] chore: Remove completed todo --- narwhals/_plan/protocols.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 4332a83f74..e66d8525cc 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -612,13 +612,10 @@ class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]): def by_names( cls, 
df: DataFrameT, names: Seq[str], /, *, drop_null_keys: bool = False ) -> DataFrameGroupBy[DataFrameT]: ... - - # TODO @dangotbanned: Plan how projection should work @classmethod def by_named_irs( cls, df: DataFrameT, irs: Seq[NamedIR], / ) -> DataFrameGroupBy[DataFrameT]: ... - def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: ... @property def keys(self) -> Seq[NamedIR]: From 668db86daf3430904f84075e5b75355aa66b5523 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:54:00 +0000 Subject: [PATCH 41/93] docs: Trim `Schema.merge` Added these rules before I realised it described `dict` --- narwhals/_plan/schema.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 9e0214f9fa..288e876bbf 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -78,14 +78,9 @@ def _with_columns(self, exprs: Seq[NamedIR]) -> tuple[Seq[NamedIR], FrozenSchema return tuple(exprs_out), freeze_schema(items) def merge(self, other: FrozenSchema, /) -> FrozenSchema: - """Return a new schema, merging `other` with `self`. + """Return a new schema, merging `other` with `self` (see [upstream]). - Merging logic (from [`Schema.merge`]): - - Fields that occur in `self` but not `other` are unmodified - - Fields that occur in `other` but not `self` are appended, in order, to the end of `self` - - Fields that occur in both `self` and `other` are updated with the dtype from `other`, but keep their original index - - [`Schema.merge`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-schema/src/schema.rs#L265-L274 + [upstream]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-schema/src/schema.rs#L265-L274. 
""" return freeze_schema(self._mapping | other._mapping) From ad4babd8d9887953a667ede5408464cf51bc33b4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 23 Sep 2025 14:15:11 +0000 Subject: [PATCH 42/93] chore: Clean up unused in `arrow.group_by` --- narwhals/_plan/arrow/group_by.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 34498c847e..7ca0b4997a 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -63,27 +63,15 @@ ir.functions.Unique: "hash_distinct", } -REMAINING: tuple[Aggregation, ...] = ( - "hash_first_last", # Compute the first and last of values in each group - "hash_min_max", # Compute the minimum and maximum of values in each group - "hash_one", # Get one value from each group - "hash_product", # Compute the product of values in each group - "hash_tdigest", # Compute approximate quantiles of values in each group -) -"""Available [native aggs] we haven't used. - -[native aggs]: https://arrow.apache.org/docs/python/compute.html#grouped-aggregations -""" - - -REQUIRES_PYARROW_20: tuple[ - Literal["kurtosis"], Literal["pivot_wider"], Literal["skew"] -] = ( +REQUIRES_PYARROW_20: tuple[Literal["kurtosis"], Literal["skew"]] = ( "kurtosis", # Compute the kurtosis of values in each group - "pivot_wider", # Pivot values according to a pivot key column "skew", # Compute the skewness of values in each group ) -"""https://arrow.apache.org/docs/20.0/python/compute.html#grouped-aggregations""" +"""They don't show in [our version of the stubs], but are possible in [`pyarrow>=20`]. 
+ +[our version of the stubs]: https://github.com/narwhals-dev/narwhals/issues/2124#issuecomment-3191374210 +[`pyarrow>=20`]: https://arrow.apache.org/docs/20.0/python/compute.html#grouped-aggregations +""" def group_by_error( From a940e0592e55b73459630367ae0551e1eeba3050 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 23 Sep 2025 18:02:39 +0000 Subject: [PATCH 43/93] test: Add `test_group_by_exclude_keys` --- tests/plan/group_by_test.py | 38 +++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index c177502814..70999c548d 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -643,3 +643,41 @@ def test_group_by_named() -> None: df.group_by((nwp.col("a") * 2).alias("z")).agg(nwp.col("b").min()).sort("b") ) assert_equal_data(result, expected.to_dict(as_series=False)) + + +def test_group_by_exclude_keys() -> None: + # `group_by(keys)` and `exclude` share some logic + data = { + "a": ["A", "B", "A"], + "b": [1, 2, 3], + "c": [9, 2, 4], + "d": [8, 7, 8], + "e": [None, 9, 7], + "f": [True, False, None], + "g": [False, None, False], + "h": [None, None, True], + "j": [12.1, None, 4.0], + "k": [42, 10, None], + "l": [4, 5, 6], + "m": [0, 1, 2], + } + df = dataframe(data).with_columns( + npcs.boolean().fill_null(False), npcs.numeric().fill_null(0) + ) + exclude = "b", "c", "d", "e", "f", "g", "j", "k", "l", "m" + result = df.group_by(nwp.exclude(exclude)).agg(npcs.all().sum()).sort("a", "h") + expected = { + "a": ["A", "A", "B"], + "h": [False, True, False], + "b": [1, 3, 2], + "c": [9, 4, 2], + "d": [8, 8, 7], + "e": [0, 7, 9], + "f": [1, 0, 0], + "g": [0, 0, 0], + "j": [12.1, 4.0, 0.0], + "k": [42, 0, 10], + "l": [4, 6, 5], + "m": [0, 2, 1], + } + assert_equal_data(result, expected) From 3a0617dac1302f979d52105bc6ab686dc4f5a77e Mon Sep 17 00:00:00 2001 From: dangotbanned 
<125183946+dangotbanned@users.noreply.github.com> Date: Tue, 23 Sep 2025 18:28:11 +0000 Subject: [PATCH 44/93] refactor: Tweak `prepare_excluded` https://github.com/narwhals-dev/narwhals/pull/3143#discussion_r2372360062 --- narwhals/_plan/_expansion.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index 7d505ae653..cbb7321d7a 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -322,11 +322,10 @@ def prepare_excluded( origin: ExprIR, keys: GroupByKeys, flags: ExpansionFlags, / ) -> Excluded: """Huge simplification of https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L484-L555.""" - exclude: set[str] = set() - if flags.has_exclude: - exclude.update(*(e.names for e in origin.iter_left() if isinstance(e, Exclude))) - exclude.update(keys) - return frozenset(exclude) + gb_keys = frozenset(keys) + if not flags.has_exclude: + return gb_keys + return gb_keys.union(*(e.names for e in origin.iter_left() if isinstance(e, Exclude))) def _all_columns_match(origin: ExprIR, /, columns: Columns) -> bool: From 28601df8a72d10cbca3b214e946a3792a014eaf7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 23 Sep 2025 21:35:57 +0000 Subject: [PATCH 45/93] refactor: Refining schema projections --- narwhals/_plan/_expansion.py | 2 +- narwhals/_plan/arrow/group_by.py | 3 +- narwhals/_plan/dataframe.py | 34 ++++++++-------------- narwhals/_plan/group_by.py | 15 ++-------- narwhals/_plan/schema.py | 49 +++++++++++++------------------- 5 files changed, 36 insertions(+), 67 deletions(-) diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index cbb7321d7a..0e9d83e6b9 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -158,7 +158,7 @@ def prepare_projection( ) -> tuple[Seq[NamedIR], 
FrozenSchema]: """Expand IRs into named column selections. - **Primary entry-point**, will be used by `select`, `with_columns`, + **Primary entry-point**, for `select`, `with_columns`, and any other context that requires resolving expression names. Arguments: diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 7ca0b4997a..265c5808b2 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -178,8 +178,7 @@ def by_named_irs(cls, df: ArrowDataFrame, irs: Seq[NamedIR], /) -> Self: keys_names = tuple(key.name for key in irs) unique_names = temp.column_names(chain(keys_names, df.columns)) safe_keys = tuple(replace(key, name=name) for key, name in zip(irs, unique_names)) - irs_final, _ = freeze_schema(df.schema)._with_columns(safe_keys) - obj._df = df.with_columns(irs_final) + obj._df = df.with_columns(freeze_schema(df.schema).with_columns_irs(safe_keys)) obj._keys = safe_keys obj._keys_names = () obj._keys_names_original = keys_names diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index a9f1cd5777..37c1c6e8f8 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload -from narwhals._plan import _expansion, _parse -from narwhals._plan.contexts import ExprContext +from narwhals._plan import _parse +from narwhals._plan._expansion import prepare_projection from narwhals._plan.expr import _parse_sort_by from narwhals._plan.group_by import GroupBy from narwhals._plan.series import Series @@ -24,10 +24,7 @@ import pyarrow as pa from typing_extensions import Self - from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.protocols import CompliantBaseFrame, CompliantDataFrame - from narwhals._plan.schema import FrozenSchema - from narwhals._plan.typing import Seq from narwhals.typing import NativeFrame @@ -63,29 +60,22 @@ def _from_compliant(cls, compliant: 
CompliantBaseFrame[Any, NativeFrameT], /) -> def to_native(self) -> NativeFrameT: return self._compliant.native - def _project( - self, - exprs: tuple[OneOrIterable[IntoExpr], ...], - named_exprs: dict[str, Any], - context: ExprContext, - /, - ) -> tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: - """Temp, while these parts aren't connected, this is easier for testing.""" - named_irs, schema_frozen = _expansion.prepare_projection( + def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: + named_irs, schema = prepare_projection( _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self.schema ) - return schema_frozen.project(named_irs, context) - - def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: - named_irs, _ = self._project(exprs, named_exprs, ExprContext.SELECT) - return self._from_compliant(self._compliant.select(named_irs)) + return self._from_compliant(self._compliant.select(schema.select_irs(named_irs))) # NOTE: Want to be able to call `with_columns` at compliant level - and still get the right schema # - Currently it acts like select in `group_by` # - Doing some gymnastics to workaround for now def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: - named_irs, _ = self._project(exprs, named_exprs, ExprContext.WITH_COLUMNS) - return self._from_compliant(self._compliant.with_columns(named_irs)) + named_irs, schema = prepare_projection( + _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self.schema + ) + return self._from_compliant( + self._compliant.with_columns(schema.with_columns_irs(named_irs)) + ) def sort( self, @@ -97,7 +87,7 @@ def sort( sort, opts = _parse_sort_by( by, *more_by, descending=descending, nulls_last=nulls_last ) - named_irs, _ = _expansion.prepare_projection(sort, schema=self.schema) + named_irs, _ = prepare_projection(sort, schema=self.schema) return self._from_compliant(self._compliant.sort(named_irs, opts)) def drop(self, columns: 
Sequence[str], *, strict: bool = True) -> Self: diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index c7844bb0b4..fc0d6312f8 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -93,11 +93,11 @@ def resolve_group_by( # > Initialize schema from keys keys, input_schema = prepare_projection(input_keys, schema=schema) keys_names = tuple(e.name for e in keys) - output_schema = input_schema._select(keys) + output_schema = input_schema.select(keys) # > Add aggregation column(s) aggs, _ = prepare_projection(input_aggs, keys_names, schema=input_schema) - aggs_schema = input_schema._select(aggs) + aggs_schema = input_schema.select(aggs) # > Final output_schema result_schema = output_schema.merge(aggs_schema) @@ -105,14 +105,3 @@ def resolve_group_by( # TODO @dangotbanned: Probably just the keys part? # *index columns* seems to be rolling/dynamic only return _TempGroupByStuff(keys, aggs, keys_names, result_schema) - - -def fmt_group_by_error( - message: str, /, keys: Seq[NamedIR], aggs: Seq[NamedIR], schema: FrozenSchema -) -> str: - return ( - f"TODO: {message}:\n\n" - f"keys:\n{keys!r}\n\n" - f"aggs:\n{aggs!r}\n\n" - f"result_schema:\n{schema!r}" - ) diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 288e876bbf..6ea494aca5 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -1,9 +1,8 @@ from __future__ import annotations -from collections import deque from collections.abc import Mapping from functools import lru_cache -from itertools import chain, repeat +from itertools import chain from types import MappingProxyType from typing import TYPE_CHECKING, Any, TypeVar, overload @@ -16,7 +15,6 @@ from typing_extensions import TypeAlias - from narwhals._plan.contexts import ExprContext from narwhals._plan.typing import Seq from narwhals.dtypes import DType @@ -41,16 +39,14 @@ class FrozenSchema(Immutable): __slots__ = ("_mapping",) _mapping: MappingProxyType[str, DType] - def project( - self, 
exprs: Seq[NamedIR], context: ExprContext - ) -> tuple[Seq[NamedIR], FrozenSchema]: - if context.is_select(): - return exprs, self._select(exprs) - if context.is_with_columns(): - return self._with_columns(exprs) - raise TypeError(context) + def merge(self, other: FrozenSchema, /) -> FrozenSchema: + """Return a new schema, merging `other` with `self` (see [upstream]). + + [upstream]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-schema/src/schema.rs#L265-L274. + """ + return freeze_schema(self._mapping | other._mapping) - def _select(self, exprs: Seq[NamedIR]) -> FrozenSchema: + def select(self, exprs: Seq[NamedIR]) -> FrozenSchema: """Return a new schema, equivalent to performing `df.select(*exprs)`. Arguments: @@ -64,25 +60,20 @@ def _select(self, exprs: Seq[NamedIR]) -> FrozenSchema: default = Unknown() return freeze_schema((name, self.get(name, default)) for name in names) - def _with_columns(self, exprs: Seq[NamedIR]) -> tuple[Seq[NamedIR], FrozenSchema]: - exprs_out = deque[NamedIR]() - named: dict[str, NamedIR[Any]] = {e.name: e for e in exprs} - items: IntoFrozenSchema - for name in self: - exprs_out.append(named.pop(name, NamedIR.from_name(name))) - if named: - items = chain(self.items(), zip(named, repeat(Unknown(), len(named)))) - exprs_out.extend(named.values()) - else: - items = self - return tuple(exprs_out), freeze_schema(items) + def select_irs(self, exprs: Seq[NamedIR]) -> Seq[NamedIR]: + return exprs - def merge(self, other: FrozenSchema, /) -> FrozenSchema: - """Return a new schema, merging `other` with `self` (see [upstream]). 
+ def with_columns(self, exprs: Seq[NamedIR]) -> FrozenSchema: + # similar to `merge`, but preserving known `DType`s + names = (e.name for e in exprs) + default = Unknown() + miss = {name: default for name in names if name not in self} + return freeze_schema(self._mapping | miss) - [upstream]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-schema/src/schema.rs#L265-L274. - """ - return freeze_schema(self._mapping | other._mapping) + def with_columns_irs(self, exprs: Seq[NamedIR]) -> Seq[NamedIR]: + named: dict[str, NamedIR[Any]] = {e.name: e for e in exprs} + it = (named.pop(name, NamedIR.from_name(name)) for name in self) + return tuple(chain(it, named.values())) @property def __immutable_hash__(self) -> int: From 6122245451aabe9dbcf8244e2c5d0d6cd7c4b6f4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 23 Sep 2025 21:45:02 +0000 Subject: [PATCH 46/93] refactor: Clean up `resolve_group_by` a bit Still need to decide what to do with named tuple --- narwhals/_plan/group_by.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index fc0d6312f8..66d3f9355b 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -30,9 +30,8 @@ from collections.abc import Iterator from narwhals._plan.expressions import ExprIR, NamedIR - from narwhals._plan.schema import FrozenSchema + from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq - from narwhals.schema import Schema class GroupBy(Generic[DataFrameT]): @@ -88,20 +87,10 @@ class _TempGroupByStuff(NamedTuple): def resolve_group_by( - input_keys: Seq[ExprIR], input_aggs: Seq[ExprIR], schema: Schema + input_keys: Seq[ExprIR], input_aggs: Seq[ExprIR], input_schema: IntoFrozenSchema ) -> _TempGroupByStuff: - # > Initialize schema from keys - keys, 
input_schema = prepare_projection(input_keys, schema=schema) + keys, schema = prepare_projection(input_keys, schema=input_schema) keys_names = tuple(e.name for e in keys) - output_schema = input_schema.select(keys) - - # > Add aggregation column(s) - aggs, _ = prepare_projection(input_aggs, keys_names, schema=input_schema) - aggs_schema = input_schema.select(aggs) - # > Final output_schema - result_schema = output_schema.merge(aggs_schema) - - # > Make sure aggregation columns do not contain keys or index columns - # TODO @dangotbanned: Probably just the keys part? - # *index columns* seems to be rolling/dynamic only + aggs, _ = prepare_projection(input_aggs, keys_names, schema=schema) + result_schema = schema.select(keys).merge(schema.select(aggs)) return _TempGroupByStuff(keys, aggs, keys_names, result_schema) From 8d1220d3b3b15839818d2c6e6b7568286b6130c5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 25 Sep 2025 15:02:50 +0000 Subject: [PATCH 47/93] perf: Skip synthesized `FrozenSchema.__init__` --- narwhals/_plan/schema.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 6ea494aca5..889ac73c91 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from collections.abc import ItemsView, Iterator, KeysView, ValuesView - from typing_extensions import TypeAlias + from typing_extensions import Never, TypeAlias from narwhals._plan.typing import Seq from narwhals.dtypes import DType @@ -39,6 +39,10 @@ class FrozenSchema(Immutable): __slots__ = ("_mapping",) _mapping: MappingProxyType[str, DType] + def __init_subclass__(cls, *_: Never, **__: Never) -> Never: + msg = f"Cannot subclass {cls.__name__!r}" + raise TypeError(msg) + def merge(self, other: FrozenSchema, /) -> FrozenSchema: """Return a new schema, merging `other` with `self` (see [upstream]). 
@@ -90,7 +94,9 @@ def names(self) -> FrozenColumns: @staticmethod def _from_mapping(mapping: MappingProxyType[str, DType], /) -> FrozenSchema: - return FrozenSchema(_mapping=mapping) + obj = FrozenSchema.__new__(FrozenSchema) + object.__setattr__(obj, "_mapping", mapping) + return obj @staticmethod def _from_hash_safe(items: _FrozenSchemaHash, /) -> FrozenSchema: From 6dcfa4feb787f44516922be0eeade7a921571ca9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 25 Sep 2025 15:11:24 +0000 Subject: [PATCH 48/93] feat: Accept more in `freeze_schema`, `IntoFrozenSchema`, `prepare_projection` Can accommodate - any narwhals/compliant-level frame - any reasonable transformation on a schema - the result of `freeze_schema` can be passed *back* into itself for free --- narwhals/_plan/dataframe.py | 6 +++--- narwhals/_plan/group_by.py | 4 +--- narwhals/_plan/schema.py | 19 +++++++++++++++---- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 37c1c6e8f8..7d9467436a 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -62,7 +62,7 @@ def to_native(self) -> NativeFrameT: def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: named_irs, schema = prepare_projection( - _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self.schema + _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self ) return self._from_compliant(self._compliant.select(schema.select_irs(named_irs))) @@ -71,7 +71,7 @@ def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: # - Doing some gymnastics to workaround for now def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: named_irs, schema = prepare_projection( - _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self.schema + _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self ) 
return self._from_compliant( self._compliant.with_columns(schema.with_columns_irs(named_irs)) @@ -87,7 +87,7 @@ def sort( sort, opts = _parse_sort_by( by, *more_by, descending=descending, nulls_last=nulls_last ) - named_irs, _ = prepare_projection(sort, schema=self.schema) + named_irs, _ = prepare_projection(sort, schema=self) return self._from_compliant(self._compliant.sort(named_irs, opts)) def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index 66d3f9355b..1bba08231e 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -49,9 +49,7 @@ def __init__( def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT: frame = self._frame resolved = resolve_group_by( - self._keys, - _parse.parse_into_seq_of_expr_ir(*aggs, **named_aggs), - frame.schema, + self._keys, _parse.parse_into_seq_of_expr_ir(*aggs, **named_aggs), frame ) compliant = frame._compliant compliant_gb = compliant._group_by diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 889ac73c91..6ab379c793 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -4,23 +4,25 @@ from functools import lru_cache from itertools import chain from types import MappingProxyType -from typing import TYPE_CHECKING, Any, TypeVar, overload +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload from narwhals._plan._expr_ir import NamedIR from narwhals._plan._immutable import _IMMUTABLE_HASH_NAME, Immutable +from narwhals._utils import _hasattr_static from narwhals.dtypes import Unknown if TYPE_CHECKING: from collections.abc import ItemsView, Iterator, KeysView, ValuesView - from typing_extensions import Never, TypeAlias + from typing_extensions import Never, TypeAlias, TypeIs from narwhals._plan.typing import Seq from narwhals.dtypes import DType + from narwhals.typing import IntoSchema IntoFrozenSchema: TypeAlias = ( - "Mapping[str, DType] | 
Iterator[tuple[str, DType]] | FrozenSchema" + "IntoSchema | Iterator[tuple[str, DType]] | FrozenSchema | HasSchema" ) """A schema to freeze, or an already frozen one. @@ -138,6 +140,15 @@ def __repr__(self) -> str: return f"{type(self).__name__}([{nl}{indent}{items}{sep}{nl}])" +class HasSchema(Protocol): + @property + def schema(self) -> IntoSchema: ... + + +def has_schema(obj: Any) -> TypeIs[HasSchema]: + return _hasattr_static(obj, "schema") + + @overload def freeze_schema(mapping: IntoFrozenSchema, /) -> FrozenSchema: ... @overload @@ -147,7 +158,7 @@ def freeze_schema( ) -> FrozenSchema: if isinstance(iterable, FrozenSchema): return iterable - into = iterable or schema + into = iterable.schema if has_schema(iterable) else (iterable or schema) hashable = tuple(into.items() if isinstance(into, Mapping) else into) return _freeze_schema_cache(hashable) From d7cf2d6d774c65fbdd5bf542634c9b84b6d9a45a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 25 Sep 2025 18:40:46 +0000 Subject: [PATCH 49/93] refactor(DRAFT): Add `Grouper`/`Resolver` concepts - Currently just replaces the narwhals-level bits - Next part is making use of them for compliant-level --- narwhals/_plan/dataframe.py | 7 +- narwhals/_plan/group_by.py | 147 +++++++++++++++++++++++++++--------- 2 files changed, 117 insertions(+), 37 deletions(-) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 7d9467436a..55aa72219b 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -5,7 +5,7 @@ from narwhals._plan import _parse from narwhals._plan._expansion import prepare_projection from narwhals._plan.expr import _parse_sort_by -from narwhals._plan.group_by import GroupBy +from narwhals._plan.group_by import GroupBy, Grouped from narwhals._plan.series import Series from narwhals._plan.typing import ( IntoExpr, @@ -158,8 +158,9 @@ def group_by( drop_null_keys: bool = False, **named_by: IntoExpr, ) -> 
GroupBy[Self]: - exprs = _parse.parse_into_seq_of_expr_ir(*by, **named_by) - return GroupBy(self, exprs, drop_null_keys=drop_null_keys) + return Grouped.by(*by, drop_null_keys=drop_null_keys, **named_by).to_group_by( + self + ) def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: return self._from_compliant(self._compliant.drop(columns, strict=strict)) diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index 1bba08231e..3eea2edaa9 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -20,49 +20,43 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generic, NamedTuple +from typing import TYPE_CHECKING, Any, Generic, Protocol from narwhals._plan import _parse from narwhals._plan._expansion import prepare_projection from narwhals._plan.typing import DataFrameT +from narwhals._typing_compat import TypeVar if TYPE_CHECKING: from collections.abc import Iterator + from typing_extensions import Self + from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq +ResolverT_co = TypeVar("ResolverT_co", bound="Resolver", covariant=True) + class GroupBy(Generic[DataFrameT]): _frame: DataFrameT - _keys: Seq[ExprIR] - _drop_null_keys: bool + _grouper: Grouped - def __init__( - self, frame: DataFrameT, keys: Seq[ExprIR], /, *, drop_null_keys: bool = False - ) -> None: + def __init__(self, frame: DataFrameT, grouper: Grouped, /) -> None: self._frame = frame - self._keys = keys - self._drop_null_keys = drop_null_keys + self._grouper = grouper def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT: frame = self._frame - resolved = resolve_group_by( - self._keys, _parse.parse_into_seq_of_expr_ir(*aggs, **named_aggs), frame - ) + resolved = self._grouper.agg(*aggs, **named_aggs).resolve(frame) compliant = frame._compliant compliant_gb = 
compliant._group_by - # Do we need to project first? - if not all(key.is_column() for key in resolved.keys): - if self._drop_null_keys: - msg = "drop_null_keys cannot be True when keys contains Expr or Series" - raise NotImplementedError(msg) + if resolved.requires_projection(): grouped = compliant_gb.by_named_irs(compliant, resolved.keys) else: - # If not, we can just use the resolved key names as a fast-path grouped = compliant_gb.by_names( - compliant, resolved.keys_names, drop_null_keys=self._drop_null_keys + compliant, resolved.key_names, drop_null_keys=resolved._drop_null_keys ) return self._frame._from_compliant(grouped.agg(resolved.aggs)) @@ -71,24 +65,109 @@ def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: raise NotImplementedError(msg) -class _TempGroupByStuff(NamedTuple): - """Trying to organize info that's useful to keep from `resolve_group_by`. +class Grouper(Protocol[ResolverT_co]): + """Revised interface focused on the state change + expression projections. - Important: - Not a long-term thing! + - Uses `Expr` everywhere (no need to duplicate layers) + - Resolver only needs schema (neither needs a frame, but can use one to get `schema`) """ - keys: Seq[NamedIR] - aggs: Seq[NamedIR] - keys_names: Seq[str] - result_schema: FrozenSchema + _keys: Seq[ExprIR] + _aggs: Seq[ExprIR] + _drop_null_keys: bool + + @classmethod + def by( + cls, + *by: OneOrIterable[IntoExpr], + drop_null_keys: bool = False, + **named_by: IntoExpr, + ) -> Self: + obj = cls.__new__(cls) + obj._keys = _parse.parse_into_seq_of_expr_ir(*by, **named_by) + obj._drop_null_keys = drop_null_keys + return obj + + def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> Self: + self._aggs = _parse.parse_into_seq_of_expr_ir(*aggs, **named_aggs) + return self + + @property + def _resolver(self) -> type[ResolverT_co]: ... 
+ + def resolve(self, context: IntoFrozenSchema, /) -> ResolverT_co: + return self._resolver.from_grouper(self, context) + + +class Resolver(Protocol): + _schema_in: FrozenSchema + _keys: Seq[NamedIR] + _aggs: Seq[NamedIR] + _key_names: Seq[str] + _schema: FrozenSchema + _drop_null_keys: bool + + @property + def keys(self) -> Seq[NamedIR]: + return self._keys + + @property + def aggs(self) -> Seq[NamedIR]: + return self._aggs + + @property + def key_names(self) -> Seq[str]: + return self._key_names + + @classmethod + def from_grouper(cls, grouper: Grouper[Self], context: IntoFrozenSchema, /) -> Self: + obj = cls.__new__(cls) + keys, schema_in = prepare_projection(grouper._keys, schema=context) + obj._keys, obj._schema_in = keys, schema_in + obj._key_names = tuple(e.name for e in keys) + obj._aggs, _ = prepare_projection(grouper._aggs, obj._key_names, schema=schema_in) + obj._schema = schema_in.select(keys).merge(schema_in.select(obj._aggs)) + obj._drop_null_keys = grouper._drop_null_keys + return obj + + def requires_projection(self, *, allow_aliasing: bool = False) -> bool: + """Return True is group keys contain anything that is not a column selection. + + Notes: + If False is returned, we can just use the resolved key names as a fast-path to group. + + Arguments: + allow_aliasing: If False (default), any aliasing is not considered to be column selection. 
+ """ + if not all(key.is_column(allow_aliasing=allow_aliasing) for key in self.keys): + if self._drop_null_keys: + msg = "drop_null_keys cannot be True when keys contains Expr or Series" + raise NotImplementedError(msg) + return True + return False + + +class Grouped(Grouper["Resolved"]): + """Narwhals-level `GroupBy` builder.""" + + _keys: Seq[ExprIR] + _aggs: Seq[ExprIR] + _drop_null_keys: bool + + @property + def _resolver(self) -> type[Resolved]: + return Resolved + def to_group_by(self, frame: DataFrameT, /) -> GroupBy[DataFrameT]: + return GroupBy(frame, self) -def resolve_group_by( - input_keys: Seq[ExprIR], input_aggs: Seq[ExprIR], input_schema: IntoFrozenSchema -) -> _TempGroupByStuff: - keys, schema = prepare_projection(input_keys, schema=input_schema) - keys_names = tuple(e.name for e in keys) - aggs, _ = prepare_projection(input_aggs, keys_names, schema=schema) - result_schema = schema.select(keys).merge(schema.select(aggs)) - return _TempGroupByStuff(keys, aggs, keys_names, result_schema) + +class Resolved(Resolver): + """Narwhals-level `GroupBy` resolver.""" + + _schema_in: FrozenSchema + _keys: Seq[NamedIR] + _aggs: Seq[NamedIR] + _key_names: Seq[str] + _schema: FrozenSchema + _drop_null_keys: bool From 3d0670b003ebdafcc50bd417f71a1a0d85d51b77 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 25 Sep 2025 22:01:14 +0000 Subject: [PATCH 50/93] refactor: Move loads of stuff up from `arrow` --- narwhals/_plan/arrow/dataframe.py | 3 + narwhals/_plan/arrow/group_by.py | 52 ++------- narwhals/_plan/group_by.py | 108 ++--------------- narwhals/_plan/protocols.py | 188 ++++++++++++++++++++++++++++-- 4 files changed, 199 insertions(+), 152 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 00f7ea86ed..fa1c806c44 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -34,6 +34,9 @@ class 
ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]): + _native: pa.Table + _version: Version + def __narwhals_namespace__(self) -> ArrowNamespace: from narwhals._plan.arrow.namespace import ArrowNamespace diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 265c5808b2..aec96c0a73 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -1,6 +1,5 @@ from __future__ import annotations -from itertools import chain from typing import TYPE_CHECKING, Any, Literal import pyarrow as pa # ignore-banned-import @@ -8,10 +7,8 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._plan import expressions as ir -from narwhals._plan.common import replace, temp from narwhals._plan.expressions import aggregation as agg -from narwhals._plan.protocols import DataFrameGroupBy -from narwhals._plan.schema import freeze_schema +from narwhals._plan.protocols import EagerDataFrameGroupBy from narwhals._utils import Implementation if TYPE_CHECKING: @@ -23,7 +20,7 @@ AggregateOptions, Aggregation, ) - from narwhals._plan.arrow.dataframe import ArrowDataFrame + from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq @@ -153,45 +150,20 @@ def parse(self) -> Self: return self -class ArrowGroupBy(DataFrameGroupBy["ArrowDataFrame"]): - _df: ArrowDataFrame +class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]): + _df: Frame _keys: Seq[NamedIR] - _keys_names: Seq[str] - _keys_names_original: Seq[str] - - @classmethod - def by_names( - cls, df: ArrowDataFrame, names: Seq[str], /, *, drop_null_keys: bool = False - ) -> Self: - obj = cls.__new__(cls) - if drop_null_keys: - df = df.drop_nulls(names) - obj._df = df - obj._keys = () - obj._keys_names = names - obj._keys_names_original = () - return obj - - @classmethod - def by_named_irs(cls, df: ArrowDataFrame, irs: Seq[NamedIR], /) -> Self: - obj = cls.__new__(cls) 
- keys_names = tuple(key.name for key in irs) - unique_names = temp.column_names(chain(keys_names, df.columns)) - safe_keys = tuple(replace(key, name=name) for key, name in zip(irs, unique_names)) - obj._df = df.with_columns(freeze_schema(df.schema).with_columns_irs(safe_keys)) - obj._keys = safe_keys - obj._keys_names = () - obj._keys_names_original = keys_names - return obj + _key_names: Seq[str] + _key_names_original: Seq[str] @property - def compliant(self) -> ArrowDataFrame: + def compliant(self) -> Frame: return self._df - def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]: + def __iter__(self) -> Iterator[tuple[Any, Frame]]: raise NotImplementedError - def agg(self, irs: Seq[NamedIR]) -> ArrowDataFrame: + def agg(self, irs: Seq[NamedIR]) -> Frame: aggs: list[AceroAggSpec] = [] use_threads: bool = True for e in irs: @@ -199,8 +171,8 @@ def agg(self, irs: Seq[NamedIR]) -> ArrowDataFrame: use_threads = use_threads and expr.use_threads aggs.append(expr.spec) result = self.compliant._with_native(self._agg(aggs, use_threads=use_threads)) - if original := self._keys_names_original: - return result.rename(dict(zip(self.keys_names, original))) + if original := self._key_names_original: + return result.rename(dict(zip(self.key_names, original))) return result def _agg(self, agg_specs: list[AceroAggSpec], /, *, use_threads: bool) -> pa.Table: @@ -219,7 +191,7 @@ def _agg(self, agg_specs: list[AceroAggSpec], /, *, use_threads: bool) -> pa.Tab """ df = self.compliant.native # NOTE: Stubs are (incorrectly) invariant - keys: Incomplete = list(self.keys_names) + keys: Incomplete = list(self.key_names) aggs: Incomplete = agg_specs decls = [ pac.Declaration("table_source", pac.TableSourceNodeOptions(df)), diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index 3eea2edaa9..0b609dd4ec 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -20,24 +20,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, 
Generic, Protocol +from typing import TYPE_CHECKING, Any, Generic -from narwhals._plan import _parse -from narwhals._plan._expansion import prepare_projection +from narwhals._plan.protocols import GroupByResolver, Grouper from narwhals._plan.typing import DataFrameT -from narwhals._typing_compat import TypeVar if TYPE_CHECKING: from collections.abc import Iterator - from typing_extensions import Self - from narwhals._plan.expressions import ExprIR, NamedIR - from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema + from narwhals._plan.schema import FrozenSchema from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq -ResolverT_co = TypeVar("ResolverT_co", bound="Resolver", covariant=True) - class GroupBy(Generic[DataFrameT]): _frame: DataFrameT @@ -50,103 +44,15 @@ def __init__(self, frame: DataFrameT, grouper: Grouped, /) -> None: def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT: frame = self._frame resolved = self._grouper.agg(*aggs, **named_aggs).resolve(frame) - compliant = frame._compliant - compliant_gb = compliant._group_by - if resolved.requires_projection(): - grouped = compliant_gb.by_named_irs(compliant, resolved.keys) - else: - grouped = compliant_gb.by_names( - compliant, resolved.key_names, drop_null_keys=resolved._drop_null_keys - ) - return self._frame._from_compliant(grouped.agg(resolved.aggs)) + return frame._from_compliant( + frame._compliant.group_by_resolver(resolved).agg(resolved.aggs) + ) def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: msg = "Not Implemented `GroupBy.__iter__`" raise NotImplementedError(msg) -class Grouper(Protocol[ResolverT_co]): - """Revised interface focused on the state change + expression projections. 
- - - Uses `Expr` everywhere (no need to duplicate layers) - - Resolver only needs schema (neither needs a frame, but can use one to get `schema`) - """ - - _keys: Seq[ExprIR] - _aggs: Seq[ExprIR] - _drop_null_keys: bool - - @classmethod - def by( - cls, - *by: OneOrIterable[IntoExpr], - drop_null_keys: bool = False, - **named_by: IntoExpr, - ) -> Self: - obj = cls.__new__(cls) - obj._keys = _parse.parse_into_seq_of_expr_ir(*by, **named_by) - obj._drop_null_keys = drop_null_keys - return obj - - def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> Self: - self._aggs = _parse.parse_into_seq_of_expr_ir(*aggs, **named_aggs) - return self - - @property - def _resolver(self) -> type[ResolverT_co]: ... - - def resolve(self, context: IntoFrozenSchema, /) -> ResolverT_co: - return self._resolver.from_grouper(self, context) - - -class Resolver(Protocol): - _schema_in: FrozenSchema - _keys: Seq[NamedIR] - _aggs: Seq[NamedIR] - _key_names: Seq[str] - _schema: FrozenSchema - _drop_null_keys: bool - - @property - def keys(self) -> Seq[NamedIR]: - return self._keys - - @property - def aggs(self) -> Seq[NamedIR]: - return self._aggs - - @property - def key_names(self) -> Seq[str]: - return self._key_names - - @classmethod - def from_grouper(cls, grouper: Grouper[Self], context: IntoFrozenSchema, /) -> Self: - obj = cls.__new__(cls) - keys, schema_in = prepare_projection(grouper._keys, schema=context) - obj._keys, obj._schema_in = keys, schema_in - obj._key_names = tuple(e.name for e in keys) - obj._aggs, _ = prepare_projection(grouper._aggs, obj._key_names, schema=schema_in) - obj._schema = schema_in.select(keys).merge(schema_in.select(obj._aggs)) - obj._drop_null_keys = grouper._drop_null_keys - return obj - - def requires_projection(self, *, allow_aliasing: bool = False) -> bool: - """Return True is group keys contain anything that is not a column selection. - - Notes: - If False is returned, we can just use the resolved key names as a fast-path to group. 
- - Arguments: - allow_aliasing: If False (default), any aliasing is not considered to be column selection. - """ - if not all(key.is_column(allow_aliasing=allow_aliasing) for key in self.keys): - if self._drop_null_keys: - msg = "drop_null_keys cannot be True when keys contains Expr or Series" - raise NotImplementedError(msg) - return True - return False - - class Grouped(Grouper["Resolved"]): """Narwhals-level `GroupBy` builder.""" @@ -162,7 +68,7 @@ def to_group_by(self, frame: DataFrameT, /) -> GroupBy[DataFrameT]: return GroupBy(frame, self) -class Resolved(Resolver): +class Resolved(GroupByResolver): """Narwhals-level `GroupBy` resolver.""" _schema_in: FrozenSchema diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index e66d8525cc..183f851aea 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -1,10 +1,21 @@ from __future__ import annotations from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized +from itertools import chain from typing import TYPE_CHECKING, Any, Literal, Protocol, overload -from narwhals._plan.common import flatten_hash_safe -from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT, Seq +from typing_extensions import Self + +from narwhals._plan._expansion import prepare_projection +from narwhals._plan._parse import parse_into_seq_of_expr_ir +from narwhals._plan.common import flatten_hash_safe, replace, temp +from narwhals._plan.typing import ( + IntoExpr, + NativeDataFrameT, + NativeFrameT, + NativeSeriesT, + Seq, +) from narwhals._typing_compat import TypeVar from narwhals._utils import Version from narwhals.exceptions import ComputeError @@ -16,6 +27,7 @@ from narwhals._plan.dataframe import BaseFrame, DataFrame from narwhals._plan.expressions import ( BinaryExpr, + ExprIR, FunctionExpr, NamedIR, aggregation as agg, @@ -26,6 +38,7 @@ from narwhals._plan.expressions.ranges import IntRange from narwhals._plan.expressions.strings import ConcatStr from 
narwhals._plan.options import SortMultipleOptions + from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema from narwhals._plan.series import Series from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType @@ -50,6 +63,8 @@ ColumnT = TypeVar("ColumnT") ColumnT_co = TypeVar("ColumnT_co", covariant=True) +ResolverT_co = TypeVar("ResolverT_co", bound="GroupByResolver", covariant=True) + ExprAny: TypeAlias = "CompliantExpr[Any, Any]" ScalarAny: TypeAlias = "CompliantScalar[Any, Any]" SeriesAny: TypeAlias = "CompliantSeries[Any]" @@ -582,6 +597,13 @@ def _group_by(self) -> type[DataFrameGroupBy[Self]]: ... def from_dict( cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None ) -> Self: ... + def group_by( + self, *by: OneOrIterable[IntoExpr], **named_by: IntoExpr + ) -> DataFrameGroupBy[Self]: ... + def group_by_resolver( + self, resolver: GroupByResolver, / + ) -> DataFrameGroupBy[Self]: ... + def group_by_names(self, names: Seq[str], /) -> DataFrameGroupBy[Self]: ... def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: ... @overload def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... @@ -606,27 +628,27 @@ def agg(self, *args: Any, **kwds: Any) -> FrameT_co: ... class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]): _keys: Seq[NamedIR] - _keys_names: Seq[str] + _key_names: Seq[str] @classmethod - def by_names( - cls, df: DataFrameT, names: Seq[str], /, *, drop_null_keys: bool = False + def from_resolver( + cls, df: DataFrameT, resolver: GroupByResolver, / ) -> DataFrameGroupBy[DataFrameT]: ... @classmethod - def by_named_irs( - cls, df: DataFrameT, irs: Seq[NamedIR], / + def by_names( + cls, df: DataFrameT, names: Seq[str], / ) -> DataFrameGroupBy[DataFrameT]: ... def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: ... 
@property def keys(self) -> Seq[NamedIR]: return self._keys + # TODO @dangotbanned: Review if this can be dropped/reduced + # now it *also* defined in `GroupByResolver` @property - def keys_names(self) -> Seq[str]: - if names := self._keys_names: + def key_names(self) -> Seq[str]: + if names := self._key_names: return names - if keys := self.keys: - return tuple(e.name for e in keys) msg = "at least one key is required in a group_by operation" raise ComputeError(msg) @@ -637,7 +659,26 @@ class EagerDataFrame( CompliantDataFrame[SeriesT, NativeDataFrameT, NativeSeriesT], Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], ): + @property + def _group_by(self) -> type[EagerDataFrameGroupBy[Self]]: ... def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... + def group_by( + self, *by: OneOrIterable[IntoExpr], **named_by: IntoExpr + ) -> EagerDataFrameGroupBy[Self]: + msg = ( + "Not Implemented `EagerDataFrame.group_by`.\n\n" + "TODO: Just needs a lil bit of planning.\nShould be quite similar to narwhals-level version, (excluding `drop_null_keys`)" + ) + raise NotImplementedError(msg) + + def group_by_resolver( + self, resolver: GroupByResolver, / + ) -> EagerDataFrameGroupBy[Self]: + return self._group_by.from_resolver(self, resolver) + + def group_by_names(self, names: Seq[str], /) -> EagerDataFrameGroupBy[Self]: + return self._group_by.by_names(self, names) + def select(self, irs: Seq[NamedIR]) -> Self: return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) @@ -645,6 +686,40 @@ def with_columns(self, irs: Seq[NamedIR]) -> Self: return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) +class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDataFrameT]): + _df: EagerDataFrameT + _key_names: Seq[str] + _key_names_original: Seq[str] + + @classmethod + def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self: + obj = cls.__new__(cls) + obj._df = df + obj._keys = () 
+ obj._key_names = names + obj._key_names_original = () + return obj + + @classmethod + def from_resolver( + cls, df: EagerDataFrameT, resolver: GroupByResolver, / + ) -> EagerDataFrameGroupBy[EagerDataFrameT]: + key_names = resolver.key_names + if not resolver.requires_projection(): + df = df.drop_nulls(key_names) if resolver._drop_null_keys else df + return cls.by_names(df, key_names) + obj = cls.__new__(cls) + unique_names = temp.column_names(chain(key_names, df.columns)) + safe_keys = tuple( + replace(key, name=name) for key, name in zip(resolver.keys, unique_names) + ) + obj._df = df.with_columns(resolver._schema_in.with_columns_irs(safe_keys)) + obj._keys = safe_keys + obj._key_names = tuple(e.name for e in safe_keys) + obj._key_names_original = key_names + return obj + + class CompliantSeries(StoresVersion, Protocol[NativeSeriesT]): _native: NativeSeriesT _name: str @@ -702,3 +777,94 @@ def __len__(self) -> int: def to_list(self) -> list[Any]: ... def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: ... + + +class Grouper(Protocol[ResolverT_co]): + """Revised interface focused on the state change + expression projections. + + - Uses `Expr` everywhere (no need to duplicate layers) + - Resolver only needs schema (neither needs a frame, but can use one to get `schema`) + """ + + _keys: Seq[ExprIR] + _aggs: Seq[ExprIR] + _drop_null_keys: bool + + @classmethod + def by( + cls, + *by: OneOrIterable[IntoExpr], + drop_null_keys: bool = False, + **named_by: IntoExpr, + ) -> Self: + obj = cls.__new__(cls) + obj._keys = parse_into_seq_of_expr_ir(*by, **named_by) + obj._drop_null_keys = drop_null_keys + return obj + + def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> Self: + self._aggs = parse_into_seq_of_expr_ir(*aggs, **named_aggs) + return self + + @property + def _resolver(self) -> type[ResolverT_co]: ... 
+ + def resolve(self, context: IntoFrozenSchema, /) -> ResolverT_co: + return self._resolver.from_grouper(self, context) + + +class GroupByResolver(Protocol): + _schema_in: FrozenSchema + _keys: Seq[NamedIR] + _aggs: Seq[NamedIR] + _key_names: Seq[str] + _schema: FrozenSchema + _drop_null_keys: bool + + @property + def keys(self) -> Seq[NamedIR]: + return self._keys + + @property + def aggs(self) -> Seq[NamedIR]: + return self._aggs + + @property + def key_names(self) -> Seq[str]: + if names := self._key_names: + return names + if keys := self.keys: + return tuple(e.name for e in keys) + msg = "at least one key is required in a group_by operation" + raise ComputeError(msg) + + @property + def schema(self) -> FrozenSchema: + return self._schema + + @classmethod + def from_grouper(cls, grouper: Grouper[Self], context: IntoFrozenSchema, /) -> Self: + obj = cls.__new__(cls) + keys, schema_in = prepare_projection(grouper._keys, schema=context) + obj._keys, obj._schema_in = keys, schema_in + obj._key_names = tuple(e.name for e in keys) + obj._aggs, _ = prepare_projection(grouper._aggs, obj.key_names, schema=schema_in) + obj._schema = schema_in.select(keys).merge(schema_in.select(obj._aggs)) + obj._drop_null_keys = grouper._drop_null_keys + return obj + + def requires_projection(self, *, allow_aliasing: bool = False) -> bool: + """Return True is group keys contain anything that is not a column selection. + + Notes: + If False is returned, we can just use the resolved key names as a fast-path to group. + + Arguments: + allow_aliasing: If False (default), any aliasing is not considered to be column selection. 
+ """ + if not all(key.is_column(allow_aliasing=allow_aliasing) for key in self.keys): + if self._drop_null_keys: + msg = "drop_null_keys cannot be True when keys contains Expr or Series" + raise NotImplementedError(msg) + return True + return False From a234cc749cc2cd55c84e746e4aa5d504ee2f9af7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 25 Sep 2025 22:14:04 +0000 Subject: [PATCH 51/93] =?UTF-8?q?=F0=9F=98=A0=F0=9F=98=A0=F0=9F=98=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://github.com/narwhals-dev/narwhals/actions/runs/18021842115/job/51280865880 --- narwhals/_plan/protocols.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 183f851aea..b15f5dde39 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -4,8 +4,6 @@ from itertools import chain from typing import TYPE_CHECKING, Any, Literal, Protocol, overload -from typing_extensions import Self - from narwhals._plan._expansion import prepare_projection from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.common import flatten_hash_safe, replace, temp From 9d17006e299da3f2c48a17c23601b13a282ce9e7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Sep 2025 14:28:05 +0000 Subject: [PATCH 52/93] refactor: Define `group_by_agg` Tricky to do this in a way that doesn't end up duplicating lots of logic/layers --- narwhals/_plan/group_by.py | 40 +++++++++------- narwhals/_plan/protocols.py | 92 +++++++++++++++++++++---------------- 2 files changed, 76 insertions(+), 56 deletions(-) diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index 0b609dd4ec..2318c247e7 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -22,14 +22,16 @@ from typing import TYPE_CHECKING, Any, Generic -from 
narwhals._plan.protocols import GroupByResolver, Grouper +from narwhals._plan._parse import parse_into_seq_of_expr_ir +from narwhals._plan.protocols import GroupByResolver as Resolved, Grouper from narwhals._plan.typing import DataFrameT if TYPE_CHECKING: from collections.abc import Iterator - from narwhals._plan.expressions import ExprIR, NamedIR - from narwhals._plan.schema import FrozenSchema + from typing_extensions import Self + + from narwhals._plan.expressions import ExprIR from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq @@ -43,9 +45,10 @@ def __init__(self, frame: DataFrameT, grouper: Grouped, /) -> None: def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT: frame = self._frame - resolved = self._grouper.agg(*aggs, **named_aggs).resolve(frame) return frame._from_compliant( - frame._compliant.group_by_resolver(resolved).agg(resolved.aggs) + self._grouper.agg(*aggs, **named_aggs) + .resolve(frame) + .evaluate(frame._compliant) ) def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: @@ -60,20 +63,25 @@ class Grouped(Grouper["Resolved"]): _aggs: Seq[ExprIR] _drop_null_keys: bool + @classmethod + def by( + cls, + *by: OneOrIterable[IntoExpr], + drop_null_keys: bool = False, + **named_by: IntoExpr, + ) -> Self: + obj = cls.__new__(cls) + obj._keys = parse_into_seq_of_expr_ir(*by, **named_by) + obj._drop_null_keys = drop_null_keys + return obj + + def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> Self: + self._aggs = parse_into_seq_of_expr_ir(*aggs, **named_aggs) + return self + @property def _resolver(self) -> type[Resolved]: return Resolved def to_group_by(self, frame: DataFrameT, /) -> GroupBy[DataFrameT]: return GroupBy(frame, self) - - -class Resolved(GroupByResolver): - """Narwhals-level `GroupBy` resolver.""" - - _schema_in: FrozenSchema - _keys: Seq[NamedIR] - _aggs: Seq[NamedIR] - _key_names: Seq[str] - _schema: FrozenSchema - _drop_null_keys: bool diff --git 
a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index b15f5dde39..93223cc116 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -1,3 +1,5 @@ +"""TODO: Split this module up into `narwhals._plan.compliant.*`.""" + from __future__ import annotations from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized @@ -591,17 +593,31 @@ class CompliantDataFrame( ): @property def _group_by(self) -> type[DataFrameGroupBy[Self]]: ... + @property + def _grouper(self) -> type[Grouped]: + return Grouped + @classmethod def from_dict( cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None ) -> Self: ... - def group_by( - self, *by: OneOrIterable[IntoExpr], **named_by: IntoExpr - ) -> DataFrameGroupBy[Self]: ... - def group_by_resolver( - self, resolver: GroupByResolver, / - ) -> DataFrameGroupBy[Self]: ... - def group_by_names(self, names: Seq[str], /) -> DataFrameGroupBy[Self]: ... + def group_by_agg( + self, by: OneOrIterable[IntoExpr], aggs: OneOrIterable[IntoExpr], / + ) -> Self: + """Compliant-level `group_by(by).agg(agg)`, allows `Expr`.""" + return self._grouper.by(by).agg(aggs).resolve(self).evaluate(self) + + def group_by_names(self, names: Seq[str], /) -> DataFrameGroupBy[Self]: + """Compliant-level `group_by`, allowing only `str` keys.""" + return self._group_by.by_names(self, names) + + def group_by_resolver(self, resolver: GroupByResolver, /) -> DataFrameGroupBy[Self]: + """Narwhals-level resolved `group_by`. + + `keys`, `aggs` are already parsed and projections planned. + """ + return self._group_by.from_resolver(self, resolver) + def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: ... @overload def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... @@ -621,7 +637,7 @@ def with_row_index(self, name: str) -> Self: ... class CompliantGroupBy(Protocol[FrameT_co]): @property def compliant(self) -> FrameT_co: ... - def agg(self, *args: Any, **kwds: Any) -> FrameT_co: ... 
+ def agg(self, irs: Seq[NamedIR]) -> FrameT_co: ... class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]): @@ -641,8 +657,6 @@ def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: ... def keys(self) -> Seq[NamedIR]: return self._keys - # TODO @dangotbanned: Review if this can be dropped/reduced - # now it *also* defined in `GroupByResolver` @property def key_names(self) -> Seq[str]: if names := self._key_names: @@ -650,8 +664,6 @@ def key_names(self) -> Seq[str]: msg = "at least one key is required in a group_by operation" raise ComputeError(msg) - def agg(self, irs: Seq[NamedIR]) -> DataFrameT: ... - class EagerDataFrame( CompliantDataFrame[SeriesT, NativeDataFrameT, NativeSeriesT], @@ -660,23 +672,6 @@ class EagerDataFrame( @property def _group_by(self) -> type[EagerDataFrameGroupBy[Self]]: ... def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... - def group_by( - self, *by: OneOrIterable[IntoExpr], **named_by: IntoExpr - ) -> EagerDataFrameGroupBy[Self]: - msg = ( - "Not Implemented `EagerDataFrame.group_by`.\n\n" - "TODO: Just needs a lil bit of planning.\nShould be quite similar to narwhals-level version, (excluding `drop_null_keys`)" - ) - raise NotImplementedError(msg) - - def group_by_resolver( - self, resolver: GroupByResolver, / - ) -> EagerDataFrameGroupBy[Self]: - return self._group_by.from_resolver(self, resolver) - - def group_by_names(self, names: Seq[str], /) -> EagerDataFrameGroupBy[Self]: - return self._group_by.by_names(self, names) - def select(self, irs: Seq[NamedIR]) -> Self: return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) @@ -789,19 +784,13 @@ class Grouper(Protocol[ResolverT_co]): _drop_null_keys: bool @classmethod - def by( - cls, - *by: OneOrIterable[IntoExpr], - drop_null_keys: bool = False, - **named_by: IntoExpr, - ) -> Self: + def by(cls, *by: OneOrIterable[IntoExpr]) -> Self: obj = cls.__new__(cls) - obj._keys = parse_into_seq_of_expr_ir(*by, 
**named_by) - obj._drop_null_keys = drop_null_keys + obj._keys = parse_into_seq_of_expr_ir(*by) return obj - def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> Self: - self._aggs = parse_into_seq_of_expr_ir(*aggs, **named_aggs) + def agg(self, *aggs: OneOrIterable[IntoExpr]) -> Self: + self._aggs = parse_into_seq_of_expr_ir(*aggs) return self @property @@ -811,7 +800,9 @@ def resolve(self, context: IntoFrozenSchema, /) -> ResolverT_co: return self._resolver.from_grouper(self, context) -class GroupByResolver(Protocol): +class GroupByResolver: + """Narwhals-level `GroupBy` resolver.""" + _schema_in: FrozenSchema _keys: Seq[NamedIR] _aggs: Seq[NamedIR] @@ -840,6 +831,9 @@ def key_names(self) -> Seq[str]: def schema(self) -> FrozenSchema: return self._schema + def evaluate(self, frame: DataFrameT) -> DataFrameT: + return frame.group_by_resolver(self).agg(self.aggs) + @classmethod def from_grouper(cls, grouper: Grouper[Self], context: IntoFrozenSchema, /) -> Self: obj = cls.__new__(cls) @@ -866,3 +860,21 @@ def requires_projection(self, *, allow_aliasing: bool = False) -> bool: raise NotImplementedError(msg) return True return False + + +class Grouped(Grouper["Resolved"]): + """Compliant-level `GroupBy` builder.""" + + _keys: Seq[ExprIR] + _aggs: Seq[ExprIR] + _drop_null_keys: bool = False + + @property + def _resolver(self) -> type[Resolved]: + return Resolved + + +class Resolved(GroupByResolver): + """Compliant-level `GroupBy` resolver.""" + + _drop_null_keys: bool = False From 8eb5db091df20288d95b67b4158e248c8da4f82f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Sep 2025 15:13:44 +0000 Subject: [PATCH 53/93] feat(DRAFT): Almost direct port of `ArrowGroupBy.__iter__` - Enough to pass the remaining tests - Hoping to reduce the complexity in `__iter__` next --- narwhals/_plan/arrow/dataframe.py | 6 ++++++ narwhals/_plan/arrow/group_by.py | 22 +++++++++++++++++++++- 
narwhals/_plan/group_by.py | 6 ++++-- narwhals/_plan/protocols.py | 1 + tests/plan/group_by_test.py | 6 ++---- 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index fa1c806c44..cbfeabac61 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -141,3 +141,9 @@ def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self: else: native = native.append_column(name, chunked) return self._with_native(native) + + def select_names(self, *column_names: str) -> Self: + return self._with_native(self.native.select(list(column_names))) + + def row(self, index: int) -> tuple[Any, ...]: + return tuple(col[index] for col in self.native.itercolumns()) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index aec96c0a73..7062ff8d66 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -6,7 +6,9 @@ import pyarrow.acero as pac import pyarrow.compute as pc # ignore-banned-import +from narwhals._arrow.utils import cast_to_comparable_string_types from narwhals._plan import expressions as ir +from narwhals._plan.common import temp from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import EagerDataFrameGroupBy from narwhals._utils import Implementation @@ -161,7 +163,25 @@ def compliant(self) -> Frame: return self._df def __iter__(self) -> Iterator[tuple[Any, Frame]]: - raise NotImplementedError + col_token = temp.column_name(self.compliant) + null_token = f"__null_{col_token}_value__" + table = self.compliant.native + it, separator_scalar = cast_to_comparable_string_types( + *(table[key] for key in self.key_names), separator="" + ) + concat_str: Incomplete = pc.binary_join_element_wise + key_values = concat_str( + *it, separator_scalar, null_handling="replace", null_replacement=null_token + ) + table = table.add_column(i=0, field_=col_token, column=key_values) + for 
v in pc.unique(key_values): + t = self.compliant._with_native( + table.filter(pc.equal(table[col_token], v)).drop([col_token]) + ) + row = t.select_names(*self.key_names).row(0) + group_key = tuple(el.as_py() for el in row) + partition = t.select_names(*self.compliant.columns) + yield group_key, partition def agg(self, irs: Seq[NamedIR]) -> Frame: aggs: list[AceroAggSpec] = [] diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index 2318c247e7..f47ef9b8b3 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -52,8 +52,10 @@ def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFra ) def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: - msg = "Not Implemented `GroupBy.__iter__`" - raise NotImplementedError(msg) + frame = self._frame + resolver = self._grouper.agg().resolve(frame) + for key, df in frame._compliant.group_by_resolver(resolver): + yield key, frame._from_compliant(df) class Grouped(Grouper["Resolved"]): diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 93223cc116..0f4e2045b8 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -581,6 +581,7 @@ def _evaluate_irs( self, nodes: Iterable[NamedIR[ir.ExprIR]], / ) -> Iterator[ColumnT_co]: ... def select(self, irs: Seq[NamedIR]) -> Self: ... + def select_names(self, *column_names: str) -> Self: ... def with_columns(self, irs: Seq[NamedIR]) -> Self: ... def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: ... def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: ... 
diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index 70999c548d..ef59e4f050 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -30,8 +30,7 @@ def assert_equal_data(result: nwp.DataFrame, expected: Mapping[str, Any]) -> Non _assert_equal_data(result.to_dict(as_series=False), expected) -@pytest.mark.xfail(reason="Not implemented `__iter__`", raises=NotImplementedError) -def test_group_by_iter() -> None: # pragma: no cover +def test_group_by_iter() -> None: data = {"a": [1, 1, 3], "b": [4, 4, 6], "c": [7.0, 8.0, 9.0]} df = dataframe(data) expected_keys: list[tuple[int, ...]] = [(1,), (3,)] @@ -212,8 +211,7 @@ def test_key_with_nulls_ignored() -> None: assert_equal_data(result, expected) -@pytest.mark.xfail(reason="Not implemented `__iter__`", raises=NotImplementedError) -def test_key_with_nulls_iter() -> None: # pragma: no cover +def test_key_with_nulls_iter() -> None: data = { "b": [None, "4", "5", None, "7"], "a": [None, 1, 2, 3, 4], From 890732e5cbc13cb71b3b9ad64c5659085c2d09d7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Sep 2025 16:58:52 +0000 Subject: [PATCH 54/93] test: Add (failing) `test_group_by_expr_iter` Bug isnt present on `main`, probably need to redo the resolve stuff to fix --- tests/plan/group_by_test.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index ef59e4f050..960b5ecb93 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -229,6 +229,32 @@ def test_key_with_nulls_iter() -> None: assert len(result) == 4 +@pytest.mark.xfail( + reason="Temporary alias column present in result", raises=AssertionError +) +def test_group_by_expr_iter() -> None: + data = { + "b": [None, "4", "5", None, "7"], + "a": [None, 1, 2, 3, 4], + "c": ["1", "4", "3", "1", "1"], + } + + expected = { + ("1",): {"b": [None, None, "7"], "a": [None, 3, 4], 
"c": ["1", "1", "1"]}, + ("3",): {"b": ["5"], "a": [2], "c": ["3"]}, + ("4",): {"b": ["4"], "a": [1], "c": ["4"]}, + } + grouped = dataframe(data).group_by(nwp.col("c").alias("d")) + result = dict(sorted((k, df.sort("c").to_dict(as_series=False)) for k, df in grouped)) + assert len(result) == len(expected) + assert result.keys() == expected.keys() + # NOTE: The bug means that zipping will break, as one side has more columns + result_p1 = next(iter(result.values())) + expected_p1 = next(iter(expected.values())) + assert result_p1 == expected_p1 + _assert_equal_data(result, expected) # type: ignore[arg-type] + + @pytest.mark.parametrize( "keys", [[nwp.col("a").abs()], ["a", nwp.col("a").abs().alias("a_test")]] ) From 3776e3a416e2fd419f589b72cc600355cc4a91af Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Sep 2025 18:03:38 +0000 Subject: [PATCH 55/93] fix: Select the right columns in `__iter__` Didn't realize on main we store two `ArrowDataFrame`s https://github.com/narwhals-dev/narwhals/blob/63c5022e347cef3f821f725350cd9d39e6e476c6/narwhals/_arrow/group_by.py#L139 https://github.com/narwhals-dev/narwhals/blob/63c5022e347cef3f821f725350cd9d39e6e476c6/narwhals/_arrow/group_by.py#L158 --- narwhals/_plan/arrow/group_by.py | 2 +- narwhals/_plan/protocols.py | 3 +++ tests/plan/group_by_test.py | 3 --- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 7062ff8d66..9edf3bcc58 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -180,7 +180,7 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: ) row = t.select_names(*self.key_names).row(0) group_key = tuple(el.as_py() for el in row) - partition = t.select_names(*self.compliant.columns) + partition = t.select_names(*self._column_names_original) yield group_key, partition def agg(self, irs: Seq[NamedIR]) -> Frame: diff --git 
a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 0f4e2045b8..1e7dc0fd4d 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -684,6 +684,7 @@ class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDat _df: EagerDataFrameT _key_names: Seq[str] _key_names_original: Seq[str] + _column_names_original: Seq[str] @classmethod def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self: @@ -692,6 +693,7 @@ def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self: obj._keys = () obj._key_names = names obj._key_names_original = () + obj._column_names_original = tuple(df.columns) return obj @classmethod @@ -711,6 +713,7 @@ def from_resolver( obj._keys = safe_keys obj._key_names = tuple(e.name for e in safe_keys) obj._key_names_original = key_names + obj._column_names_original = resolver._schema_in.names return obj diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index 960b5ecb93..3a5ffd162a 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -229,9 +229,6 @@ def test_key_with_nulls_iter() -> None: assert len(result) == 4 -@pytest.mark.xfail( - reason="Temporary alias column present in result", raises=AssertionError -) def test_group_by_expr_iter() -> None: data = { "b": [None, "4", "5", None, "7"], From 5b1fa00a077e129c271fce7366131c451c88274d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Sep 2025 18:22:46 +0000 Subject: [PATCH 56/93] chore: Tidy up some comments/notes/docs --- narwhals/_plan/_expansion.py | 5 ++--- narwhals/_plan/dataframe.py | 3 --- narwhals/_plan/group_by.py | 20 -------------------- narwhals/_plan/protocols.py | 4 ++++ narwhals/_plan/schema.py | 4 ++++ 5 files changed, 10 insertions(+), 26 deletions(-) diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index 0e9d83e6b9..6cbf061a98 100644 --- a/narwhals/_plan/_expansion.py +++ 
b/narwhals/_plan/_expansion.py @@ -88,10 +88,9 @@ """Internally use a `set`, then freeze before returning.""" GroupByKeys: TypeAlias = "Seq[str]" -"""Represents group_by keys. +"""Represents `group_by` keys. -- Originates from `polars_plan::plans::conversion::dsl_to_ir::resolve_group_by` -- Not fully utilized in `narwhals` version yet +They need to be excluded from expansion. """ OutputNames: TypeAlias = "Seq[str]" diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 55aa72219b..1b6975d589 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -66,9 +66,6 @@ def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: ) return self._from_compliant(self._compliant.select(schema.select_irs(named_irs))) - # NOTE: Want to be able to call `with_columns` at compliant level - and still get the right schema - # - Currently it acts like select in `group_by` - # - Doing some gymnastics to workaround for now def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: named_irs, schema = prepare_projection( _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py index f47ef9b8b3..5e95bd484e 100644 --- a/narwhals/_plan/group_by.py +++ b/narwhals/_plan/group_by.py @@ -1,23 +1,3 @@ -"""Refresher on `rust` impl. 
- -- [`resolve_group_by`] has the dsl algo - - Depends on some `expr_expansion` functions I've implemented - - `group_by_dynamic` is there also (but not doing that) - - ooooh [auto-implode] -- [`dsl_to_ir::to_alp_impl`] was the caller of ^^^^^ -- Misc recent important PRs - - `1.32.1` - - [Remove `Context` from logical layer] - - `1.32.0` - - [Make `Selector` a concrete part of the DSL] - -[`resolve_group_by`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1125-L1227 -[auto-implode]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1197-L1203 -[`dsl_to_ir::to_alp_impl`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L459-L509 -[Remove `Context` from logical layer]: https://github.com/pola-rs/polars/pull/23863 -[Make `Selector` a concrete part of the DSL]: https://github.com/pola-rs/polars/pull/23351 -""" - from __future__ import annotations from typing import TYPE_CHECKING, Any, Generic diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 1e7dc0fd4d..e1d8317522 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -840,6 +840,10 @@ def evaluate(self, frame: DataFrameT) -> DataFrameT: @classmethod def from_grouper(cls, grouper: Grouper[Self], context: IntoFrozenSchema, /) -> Self: + """Loosely based on [`resolve_group_by`]. 
+ + [`resolve_group_by`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1125-L1227 + """ obj = cls.__new__(cls) keys, schema_in = prepare_projection(grouper._keys, schema=context) obj._keys, obj._schema_in = keys, schema_in diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 6ab379c793..67433db06b 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -77,6 +77,10 @@ def with_columns(self, exprs: Seq[NamedIR]) -> FrozenSchema: return freeze_schema(self._mapping | miss) def with_columns_irs(self, exprs: Seq[NamedIR]) -> Seq[NamedIR]: + """Required for `_concat_horizontal`-based `with_columns`. + + Fills in any unreferenced columns present in `self`, but not in `exprs` as selections. + """ named: dict[str, NamedIR[Any]] = {e.name: e for e in exprs} it = (named.pop(name, NamedIR.from_name(name)) for name in self) return tuple(chain(it, named.values())) From 0fac62255b8020a0ed503774f619e2325a566545 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Sep 2025 18:44:41 +0000 Subject: [PATCH 57/93] excessive comments --- narwhals/_plan/arrow/group_by.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 9edf3bcc58..bab14e491a 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -163,23 +163,38 @@ def compliant(self) -> Frame: return self._df def __iter__(self) -> Iterator[tuple[Any, Frame]]: + # random column name col_token = temp.column_name(self.compliant) + # random null fill value null_token = f"__null_{col_token}_value__" + # native table = self.compliant.native + # get key columns, cast everything to str? 
+ # make sure all either string or all large_string + # separator also has to be that string type it, separator_scalar = cast_to_comparable_string_types( *(table[key] for key in self.key_names), separator="" ) + # join those strings horizontally to generate a single key column concat_str: Incomplete = pc.binary_join_element_wise key_values = concat_str( *it, separator_scalar, null_handling="replace", null_replacement=null_token ) - table = table.add_column(i=0, field_=col_token, column=key_values) + # add that column (of `key_values`) back to the table + table_w_key = table.add_column(i=0, field_=col_token, column=key_values) + # iterate over the unique keys in the `key_values` array for v in pc.unique(key_values): + # filter the keyed table to rows that have the same key (`t`) + # then drop the temporary key on the result t = self.compliant._with_native( - table.filter(pc.equal(table[col_token], v)).drop([col_token]) + table_w_key.filter(pc.equal(table_w_key[col_token], v)).drop([col_token]) ) + # subset this new table to only the actual key name columns + # then convert the first row to `tuple[pa.Scalar, ...]` row = t.select_names(*self.key_names).row(0) + # convert those scalars to python literals group_key = tuple(el.as_py() for el in row) + # select (all) columns from (`t`) that we started with at `.group_by()``, ignoring new keys/aliases partition = t.select_names(*self._column_names_original) yield group_key, partition From c5497361fa69716d9a24266208e6e2ba1775dde6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Sep 2025 18:47:19 +0000 Subject: [PATCH 58/93] perf: Use `pc.Expression` instead of eager predicate --- narwhals/_plan/arrow/group_by.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index bab14e491a..48626e26e7 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -187,7 
+187,7 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: # filter the keyed table to rows that have the same key (`t`) # then drop the temporary key on the result t = self.compliant._with_native( - table_w_key.filter(pc.equal(table_w_key[col_token], v)).drop([col_token]) + table_w_key.filter(pc.field(col_token) == v).drop([col_token]) ) # subset this new table to only the actual key name columns # then convert the first row to `tuple[pa.Scalar, ...]` From 5ef5c5335158054fd1dffa4692d5be5fdc6ccdcd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 26 Sep 2025 18:50:17 +0000 Subject: [PATCH 59/93] perf: `remove_column` instead of `drop` Avoids iterating over schema --- narwhals/_plan/arrow/group_by.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 48626e26e7..667aa913d3 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -187,7 +187,7 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: # filter the keyed table to rows that have the same key (`t`) # then drop the temporary key on the result t = self.compliant._with_native( - table_w_key.filter(pc.field(col_token) == v).drop([col_token]) + table_w_key.filter(pc.field(col_token) == v).remove_column(0) ) # subset this new table to only the actual key name columns # then convert the first row to `tuple[pa.Scalar, ...]` From 70523ff823e06121b5d681327a6f9db3faac1cde Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 27 Sep 2025 14:58:53 +0000 Subject: [PATCH 60/93] feat: Add `arrow.acero` module --- narwhals/_plan/arrow/acero.py | 196 +++++++++++++++++++++++++++++++ narwhals/_plan/arrow/group_by.py | 73 ++++-------- 2 files changed, 218 insertions(+), 51 deletions(-) create mode 100644 narwhals/_plan/arrow/acero.py diff --git a/narwhals/_plan/arrow/acero.py 
b/narwhals/_plan/arrow/acero.py new file mode 100644 index 0000000000..ecbee87237 --- /dev/null +++ b/narwhals/_plan/arrow/acero.py @@ -0,0 +1,196 @@ +"""Sugar for working with [Acero]. + +[`pyarrow.acero`] has some building blocks for constructing queries, but is +quite verbose when used directly. + +This module aligns some apis to look more like `polars`. + +[Acero]: https://arrow.apache.org/docs/cpp/acero/overview.html +[`pyarrow.acero`]: https://arrow.apache.org/docs/python/api/acero.html +""" + +from __future__ import annotations + +import operator +from functools import reduce +from itertools import chain +from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union + +import pyarrow as pa # ignore-banned-import +import pyarrow.acero as pac +import pyarrow.compute as pc # ignore-banned-import +from pyarrow.acero import Declaration as Decl + +from narwhals.typing import SingleColSelector + +if TYPE_CHECKING: + from collections.abc import Iterable + + from typing_extensions import TypeAlias + + from narwhals._arrow.typing import ( # type: ignore[attr-defined] + AggregateOptions as _AggregateOptions, + Aggregation as _Aggregation, + ) + from narwhals._plan.typing import Seq + from narwhals.typing import NonNestedLiteral + +T = TypeVar("T") +OneOrListOrTuple: TypeAlias = Union[T, list[T], tuple[T, ...]] +"""WARNING: Don't use this unless there is a runtime check for exactly `list | tuple`.""" + + +Incomplete: TypeAlias = Any +Expr: TypeAlias = pc.Expression +IntoExpr: TypeAlias = "Expr | NonNestedLiteral" +Field: TypeAlias = Union[Expr, SingleColSelector] +"""Anything that passes as a single item in [`_compute._ensure_field_ref`]. 
+ +[`_compute._ensure_field_ref`]: https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_compute.pyx#L1507-L1531 +""" + +AggKeys: TypeAlias = "Iterable[Field] | None" + +Target: TypeAlias = OneOrListOrTuple[Field] +Aggregation: TypeAlias = "_Aggregation" +Opts: TypeAlias = "_AggregateOptions | None" +OutputName: TypeAlias = str +AggSpec: TypeAlias = tuple[Target, Aggregation, Opts, OutputName] + + +# TODO @dangotbanned: Rename +def pc_expr(into: IntoExpr, /, *, str_as_lit: bool = False) -> Expr: + if isinstance(into, pc.Expression): + return into + if isinstance(into, str) and not str_as_lit: + return pc.field(into) + arg: Incomplete = into + return pc.scalar(arg) + + +def _parse_all_horizontal(predicates: Seq[Expr], constraints: dict[str, Any], /) -> Expr: + if not constraints and len(predicates) == 1: + return predicates[0] + it = ( + pc.field(name) == pc_expr(v, str_as_lit=True) for name, v in constraints.items() + ) + return reduce(operator.and_, chain(predicates, it)) + + +# TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`) +def table_source(native: pa.Table, /) -> Decl: + """A Source node which accepts a table.""" + return Decl("table_source", options=pac.TableSourceNodeOptions(native)) + + +def _aggregate(agg_specs: Iterable[AggSpec], /, keys: AggKeys = None) -> Decl: + # NOTE: See https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_acero.pyx#L167-L192 + aggs: Incomplete = agg_specs + keys_: Incomplete = keys + return Decl("aggregate", pac.AggregateNodeOptions(aggs, keys=keys_)) + + +# TODO @dangotbanned: Plan +# TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`) +def aggregate(aggs: Iterable[AggSpec], /) -> Decl: + """Scalar aggregate. + + Reduce an array or scalar input to a single scalar output (e.g. 
computing the mean of a column) + """ + return _aggregate(aggs) + + +# TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`) +def group_by(keys: AggKeys, aggs: Iterable[AggSpec], /) -> Decl: + """Hash aggregate. + + Like GROUP BY in SQL and first partition data based on one or more key columns, + then reduce the data in each partition. + """ + return _aggregate(aggs, keys=keys) + + +def filter(*predicates: Expr, **constraints: IntoExpr) -> Decl: + """Selects rows where all expressions evaulate to True. + + Arguments: + predicates: [`Expression`](s) which must all have a return type of boolean. + constraints: Column filters; use `name = value` to filter columns by the supplied value. + + Notes: + - Uses logic similar to [`polars`] for an AND-reduction + - Elements where the filter does not evaluate to True are discarded, **including nulls** + + [`Expression`]: https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html + [`polars`]: https://github.com/pola-rs/polars/blob/d0914d416ce4e1dfcb5f946875ffd1181e31c493/py-polars/polars/_utils/parse/expr.py#L199-L242 + """ + expr = _parse_all_horizontal(predicates, constraints) + return Decl("filter", options=pac.FilterNodeOptions(expr)) + + +# TODO @dangotbanned: Plan +def select(*exprs: IntoExpr, **named_exprs: IntoExpr) -> Decl: + raise NotImplementedError + + +# TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`) +def project(**named_exprs: Expr) -> Decl: + """Make a node which executes expressions on input batches, producing batches of the same length with new columns. + + This is the option class for the "project" node factory. + + The "project" operation rearranges, deletes, transforms, and + creates columns. Each output column is computed by evaluating + an expression against the source record batch. These must be + scalar expressions (expressions consisting of scalar literals, + field references and scalar functions, i.e. 
elementwise functions + that return one value for each input row independent of the value + of all other rows). + """ + # NOTE: Both just need to be sized and iterable + names: Incomplete = named_exprs.keys() + exprs: Incomplete = named_exprs.values() + return Decl("project", options=pac.ProjectNodeOptions(exprs, names)) + + +# TODO @dangotbanned: Find which option class this uses +def order_by( + sort_keys: tuple[tuple[str, Literal["ascending", "descending"]], ...] = (), + *, + null_placement: Literal["at_start", "at_end"] = "at_end", +) -> Decl: + return Decl( + "order_by", pac.OrderByNodeOptions(sort_keys, null_placement=null_placement) + ) + + +# TODO @dangotbanned: Docs +def collect(*declarations: Decl, use_threads: bool = True) -> pa.Table: + # NOTE: stubs + docs say `list`, but impl allows any iterable + decls: Incomplete = declarations + return Decl.from_sequence(decls).to_table(use_threads=use_threads) + + +# NOTE: Composite functions are suffixed with `_table` +def group_by_table( + native: pa.Table, keys: AggKeys, aggs: Iterable[AggSpec], *, use_threads: bool +) -> pa.Table: + """Adapted from [`pa.TableGroupBy.aggregate`] and [`pa.acero._group_by`]. + + - Backport of [apache/arrow#36768]. + - `first` and `last` were [broken in `pyarrow==13`]. + - Also allows us to specify our own aliases for aggregate output columns. 
+ - Fixes [narwhals-dev/narwhals#1612] + + [`pa.TableGroupBy.aggregate`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/table.pxi#L6600-L6626 + [`pa.acero._group_by`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/acero.py#L412-L418 + [apache/arrow#36768]: https://github.com/apache/arrow/pull/36768 + [broken in `pyarrow==13`]: https://github.com/apache/arrow/issues/36709 + [narwhals-dev/narwhals#1612]: https://github.com/narwhals-dev/narwhals/issues/1612 + """ + return collect(table_source(native), group_by(keys, aggs), use_threads=use_threads) + + +# TODO @dangotbanned: Docs? +def filter_table(native: pa.Table, *predicates: Expr, **constraints: Any) -> pa.Table: + return collect(table_source(native), filter(*predicates, **constraints)) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 667aa913d3..f15e73be30 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -3,11 +3,11 @@ from typing import TYPE_CHECKING, Any, Literal import pyarrow as pa # ignore-banned-import -import pyarrow.acero as pac import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import cast_to_comparable_string_types from narwhals._plan import expressions as ir +from narwhals._plan.arrow import acero from narwhals._plan.common import temp from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import EagerDataFrameGroupBy @@ -18,27 +18,16 @@ from typing_extensions import Self, TypeAlias - from narwhals._arrow.typing import ( # type: ignore[attr-defined] - AggregateOptions, - Aggregation, - ) from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq Incomplete: TypeAlias = Any -AceroTarget: TypeAlias = "tuple[()] | list[str]" -NativeAggSpec: TypeAlias = "tuple[AceroTarget, Aggregation, 
AggregateOptions | None]" -OutputName: TypeAlias = str -AceroAggSpec: TypeAlias = ( - "tuple[AceroTarget, Aggregation, AggregateOptions | None, OutputName]" -) - BACKEND_VERSION = Implementation.PYARROW._backend_version() -SUPPORTED_AGG: Mapping[type[agg.AggExpr], Aggregation] = { +SUPPORTED_AGG: Mapping[type[agg.AggExpr], acero.Aggregation] = { agg.Sum: "hash_sum", agg.Mean: "hash_mean", agg.Median: "hash_approximate_median", @@ -52,11 +41,11 @@ agg.First: "hash_first", agg.Last: "hash_last", } -SUPPORTED_IR: Mapping[type[ir.ExprIR], Aggregation] = { +SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = { ir.Len: "hash_count_all", ir.Column: "hash_list", } -SUPPORTED_FUNCTION: Mapping[type[ir.Function], Aggregation] = { +SUPPORTED_FUNCTION: Mapping[type[ir.Function], acero.Aggregation] = { ir.boolean.All: "hash_all", ir.boolean.Any: "hash_any", ir.functions.Unique: "hash_distinct", @@ -95,17 +84,17 @@ def __init__(self, named_ir: NamedIR, /) -> None: self.named_ir: NamedIR = named_ir self.use_threads: bool = True """See https://github.com/apache/arrow/issues/36709""" - self.spec: AceroAggSpec + self.spec: acero.AggSpec @property - def output_name(self) -> OutputName: + def output_name(self) -> acero.OutputName: return self.named_ir.name def _parse_agg_expr( self, expr: agg.AggExpr - ) -> tuple[AceroTarget, Aggregation, AggregateOptions | None]: + ) -> tuple[acero.Target, acero.Aggregation, acero.Opts]: if agg_name := SUPPORTED_AGG.get(type(expr)): - option: AggregateOptions | None = None + option: acero.Opts = None if isinstance(expr, (agg.Std, agg.Var)): # NOTE: Only branch which needs an instance (for `ddof`) option = pc.VarianceOptions(ddof=expr.ddof) @@ -121,7 +110,9 @@ def _parse_agg_expr( raise group_by_error(self, "too complex") raise group_by_error(self, "unsupported aggregation") - def _parse_function_expr(self, expr: ir.FunctionExpr) -> NativeAggSpec: + def _parse_function_expr( + self, expr: ir.FunctionExpr + ) -> tuple[acero.Target, 
acero.Aggregation, acero.Opts]: func = expr.function if agg_name := SUPPORTED_FUNCTION.get(type(func)): if isinstance(func, (ir.boolean.All, ir.boolean.Any)): @@ -136,8 +127,8 @@ def _parse_function_expr(self, expr: ir.FunctionExpr) -> NativeAggSpec: def parse(self) -> Self: expr = self.named_ir.expr - input_name: AceroTarget = () - option: AggregateOptions | None = None + input_name: acero.Target = () + option: acero.Opts = None if isinstance(expr, agg.AggExpr): input_name, agg_name, option = self._parse_agg_expr(expr) elif isinstance(expr, ir.FunctionExpr): @@ -168,7 +159,7 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: # random null fill value null_token = f"__null_{col_token}_value__" # native - table = self.compliant.native + table: pa.Table = self.compliant.native # get key columns, cast everything to str? # make sure all either string or all large_string # separator also has to be that string type @@ -187,7 +178,7 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: # filter the keyed table to rows that have the same key (`t`) # then drop the temporary key on the result t = self.compliant._with_native( - table_w_key.filter(pc.field(col_token) == v).remove_column(0) + acero.filter_table(table_w_key, pc.field(col_token) == v).remove_column(0) ) # subset this new table to only the actual key name columns # then convert the first row to `tuple[pa.Scalar, ...]` @@ -199,37 +190,17 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: yield group_key, partition def agg(self, irs: Seq[NamedIR]) -> Frame: - aggs: list[AceroAggSpec] = [] + aggs: list[acero.AggSpec] = [] use_threads: bool = True for e in irs: expr = ArrowAggExpr(e).parse() use_threads = use_threads and expr.use_threads aggs.append(expr.spec) - result = self.compliant._with_native(self._agg(aggs, use_threads=use_threads)) + native = self.compliant.native + key_names = self.key_names + result = self.compliant._with_native( + acero.group_by_table(native, key_names, aggs, use_threads=use_threads) + 
) if original := self._key_names_original: - return result.rename(dict(zip(self.key_names, original))) + return result.rename(dict(zip(key_names, original))) return result - - def _agg(self, agg_specs: list[AceroAggSpec], /, *, use_threads: bool) -> pa.Table: - """Adapted from [`pa.TableGroupBy.aggregate`] and [`pa.acero._group_by`]. - - - Backport of [apache/arrow#36768]. - - `first` and `last` were [broken in `pyarrow==13`]. - - Also allows us to specify our own aliases for aggregate output columns. - - Fixes [narwhals-dev/narwhals#1612] - - [`pa.TableGroupBy.aggregate`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/table.pxi#L6600-L6626 - [`pa.acero._group_by`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/acero.py#L412-L418 - [apache/arrow#36768]: https://github.com/apache/arrow/pull/36768 - [broken in `pyarrow==13`]: https://github.com/apache/arrow/issues/36709 - [narwhals-dev/narwhals#1612]: https://github.com/narwhals-dev/narwhals/issues/1612 - """ - df = self.compliant.native - # NOTE: Stubs are (incorrectly) invariant - keys: Incomplete = list(self.key_names) - aggs: Incomplete = agg_specs - decls = [ - pac.Declaration("table_source", pac.TableSourceNodeOptions(df)), - pac.Declaration("aggregate", pac.AggregateNodeOptions(aggs, keys=keys)), - ] - return pac.Declaration.from_sequence(decls).to_table(use_threads=use_threads) From e1482669b4569f615ed36739c576393b675787cf Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 27 Sep 2025 15:09:42 +0000 Subject: [PATCH 61/93] refactor: Use a single options class --- narwhals/_plan/arrow/group_by.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index f15e73be30..49d1ecaed6 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -1,6 +1,6 
@@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, ClassVar, Literal import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -148,6 +148,9 @@ class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]): _keys: Seq[NamedIR] _key_names: Seq[str] _key_names_original: Seq[str] + _ITER_CONCAT_STR: ClassVar[pc.JoinOptions] = pc.JoinOptions( + null_handling="replace", null_replacement="__nw_null_value__" + ) @property def compliant(self) -> Frame: @@ -156,8 +159,6 @@ def compliant(self) -> Frame: def __iter__(self) -> Iterator[tuple[Any, Frame]]: # random column name col_token = temp.column_name(self.compliant) - # random null fill value - null_token = f"__null_{col_token}_value__" # native table: pa.Table = self.compliant.native # get key columns, cast everything to str? @@ -168,9 +169,7 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: ) # join those strings horizontally to generate a single key column concat_str: Incomplete = pc.binary_join_element_wise - key_values = concat_str( - *it, separator_scalar, null_handling="replace", null_replacement=null_token - ) + key_values = concat_str(*it, separator_scalar, options=self._ITER_CONCAT_STR) # add that column (of `key_values`) back to the table table_w_key = table.add_column(i=0, field_=col_token, column=key_values) # iterate over the unique keys in the `key_values` array From 1f389df692854eb35abb4ef03444924f6180d21c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 27 Sep 2025 15:28:00 +0000 Subject: [PATCH 62/93] refactor: renaming/aliasing --- narwhals/_plan/arrow/group_by.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 49d1ecaed6..25191afa03 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ 
-158,30 +158,32 @@ def compliant(self) -> Frame: def __iter__(self) -> Iterator[tuple[Any, Frame]]: # random column name - col_token = temp.column_name(self.compliant) + temp_name = temp.column_name(self.compliant) + temp_expr = pc.field(temp_name) # native table: pa.Table = self.compliant.native + key_names = self.key_names # get key columns, cast everything to str? # make sure all either string or all large_string # separator also has to be that string type - it, separator_scalar = cast_to_comparable_string_types( - *(table[key] for key in self.key_names), separator="" + it, separator = cast_to_comparable_string_types( + *(table[key] for key in key_names), separator="" ) # join those strings horizontally to generate a single key column concat_str: Incomplete = pc.binary_join_element_wise - key_values = concat_str(*it, separator_scalar, options=self._ITER_CONCAT_STR) - # add that column (of `key_values`) back to the table - table_w_key = table.add_column(i=0, field_=col_token, column=key_values) - # iterate over the unique keys in the `key_values` array - for v in pc.unique(key_values): + composite_values = concat_str(*it, separator, options=self._ITER_CONCAT_STR) + # add that column (of `composite_values`) back to the table + re_keyed = table.add_column(0, temp_name, composite_values) + # iterate over the unique keys in the `composite_values` array + from_native = self.compliant._with_native + for v in pc.unique(composite_values): # filter the keyed table to rows that have the same key (`t`) # then drop the temporary key on the result - t = self.compliant._with_native( - acero.filter_table(table_w_key, pc.field(col_token) == v).remove_column(0) - ) + predicate = temp_expr == v + t = from_native(acero.filter_table(re_keyed, predicate).remove_column(0)) # subset this new table to only the actual key name columns # then convert the first row to `tuple[pa.Scalar, ...]` - row = t.select_names(*self.key_names).row(0) + row = t.select_names(*key_names).row(0) # convert 
those scalars to python literals group_key = tuple(el.as_py() for el in row) # select (all) columns from (`t`) that we started with at `.group_by()``, ignoring new keys/aliases From 9ba003857265da67d71b3b2616db53ade98d4b5a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 27 Sep 2025 16:20:31 +0000 Subject: [PATCH 63/93] refactor: Split out, rewrite composite key concat --- narwhals/_plan/arrow/group_by.py | 58 ++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 25191afa03..20b0b69cf3 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -1,13 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Literal +from typing import TYPE_CHECKING, Any, Final, Literal import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import -from narwhals._arrow.utils import cast_to_comparable_string_types from narwhals._plan import expressions as ir -from narwhals._plan.arrow import acero +from narwhals._plan.arrow import acero, functions as fn from narwhals._plan.common import temp from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import EagerDataFrameGroupBy @@ -19,6 +18,7 @@ from typing_extensions import Self, TypeAlias from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame + from narwhals._plan.arrow.typing import ChunkedArray from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq @@ -143,14 +143,38 @@ def parse(self) -> Self: return self +_NULL_FILL: Final = pc.JoinOptions( + null_handling="replace", null_replacement="__nw_null_value__" +) + + +def concat_str( + native: pa.Table, + subset: Seq[str], + *, + separator: str = "", + options: pc.JoinOptions = _NULL_FILL, +) -> ChunkedArray: + # get key columns, casting everything 
to str + # docs says "list-like", runtime supports iterable + df = native.select(subset) # pyright: ignore[reportArgumentType] + schema = df.schema + dtype = ( + pa.string() + if not any(pa.types.is_large_string(tp) for tp in schema.types) + else pa.large_string() + ) + schema = pa.schema((name, dtype) for name in schema.names) + sep = fn.lit(separator, dtype) + concat: Incomplete = pc.binary_join_element_wise + return concat(*df.cast(schema).itercolumns(), sep, options=options) # type: ignore[no-any-return] + + class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]): _df: Frame _keys: Seq[NamedIR] _key_names: Seq[str] _key_names_original: Seq[str] - _ITER_CONCAT_STR: ClassVar[pc.JoinOptions] = pc.JoinOptions( - null_handling="replace", null_replacement="__nw_null_value__" - ) @property def compliant(self) -> Frame: @@ -160,22 +184,12 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: # random column name temp_name = temp.column_name(self.compliant) temp_expr = pc.field(temp_name) - # native - table: pa.Table = self.compliant.native - key_names = self.key_names - # get key columns, cast everything to str? 
- # make sure all either string or all large_string - # separator also has to be that string type - it, separator = cast_to_comparable_string_types( - *(table[key] for key in key_names), separator="" - ) - # join those strings horizontally to generate a single key column - concat_str: Incomplete = pc.binary_join_element_wise - composite_values = concat_str(*it, separator, options=self._ITER_CONCAT_STR) - # add that column (of `composite_values`) back to the table - re_keyed = table.add_column(0, temp_name, composite_values) - # iterate over the unique keys in the `composite_values` array + + native = self.compliant.native + composite_values = concat_str(native, self.key_names) + re_keyed = native.add_column(0, temp_name, composite_values) from_native = self.compliant._with_native + # iterate over the unique keys in the `composite_values` array for v in pc.unique(composite_values): # filter the keyed table to rows that have the same key (`t`) # then drop the temporary key on the result @@ -183,7 +197,7 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: t = from_native(acero.filter_table(re_keyed, predicate).remove_column(0)) # subset this new table to only the actual key name columns # then convert the first row to `tuple[pa.Scalar, ...]` - row = t.select_names(*key_names).row(0) + row = t.select_names(*self.key_names).row(0) # convert those scalars to python literals group_key = tuple(el.as_py() for el in row) # select (all) columns from (`t`) that we started with at `.group_by()``, ignoring new keys/aliases From 42872e830cac2258658b375aaec5e98947a2ec64 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 27 Sep 2025 16:25:44 +0000 Subject: [PATCH 64/93] refactor: Use `unique` method --- narwhals/_plan/arrow/group_by.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 20b0b69cf3..be7cc9617d 100644 --- 
a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -189,12 +189,10 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: composite_values = concat_str(native, self.key_names) re_keyed = native.add_column(0, temp_name, composite_values) from_native = self.compliant._with_native - # iterate over the unique keys in the `composite_values` array - for v in pc.unique(composite_values): + for v in composite_values.unique(): # filter the keyed table to rows that have the same key (`t`) # then drop the temporary key on the result - predicate = temp_expr == v - t = from_native(acero.filter_table(re_keyed, predicate).remove_column(0)) + t = from_native(acero.filter_table(re_keyed, temp_expr == v).remove_column(0)) # subset this new table to only the actual key name columns # then convert the first row to `tuple[pa.Scalar, ...]` row = t.select_names(*self.key_names).row(0) From d098acc81fe00948670ac96b5f7ec51ab1c105ce Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 27 Sep 2025 16:29:05 +0000 Subject: [PATCH 65/93] refactor: Remove aliasing that doesn't save lines --- narwhals/_plan/arrow/group_by.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index be7cc9617d..5a34ad9c0a 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -181,13 +181,10 @@ def compliant(self) -> Frame: return self._df def __iter__(self) -> Iterator[tuple[Any, Frame]]: - # random column name temp_name = temp.column_name(self.compliant) temp_expr = pc.field(temp_name) - - native = self.compliant.native - composite_values = concat_str(native, self.key_names) - re_keyed = native.add_column(0, temp_name, composite_values) + composite_values = concat_str(self.compliant.native, self.key_names) + re_keyed = self.compliant.native.add_column(0, temp_name, composite_values) from_native = 
self.compliant._with_native for v in composite_values.unique(): # filter the keyed table to rows that have the same key (`t`) From 113b6a5a43d00f97c6d436d89f99fef0c9762565 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 27 Sep 2025 16:32:06 +0000 Subject: [PATCH 66/93] typo --- narwhals/_plan/arrow/acero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index ecbee87237..1a836fc370 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -111,7 +111,7 @@ def group_by(keys: AggKeys, aggs: Iterable[AggSpec], /) -> Decl: def filter(*predicates: Expr, **constraints: IntoExpr) -> Decl: - """Selects rows where all expressions evaulate to True. + """Selects rows where all expressions evaluate to True. Arguments: predicates: [`Expression`](s) which must all have a return type of boolean. From b35be602fd23c01635750b9080184517e05c42f5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 27 Sep 2025 17:20:51 +0000 Subject: [PATCH 67/93] perf: Remove unnecessary `remove_column` both downstream tables selections will exclude it anyway --- narwhals/_plan/arrow/group_by.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 5a34ad9c0a..c61f1ca09d 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -189,7 +189,7 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: for v in composite_values.unique(): # filter the keyed table to rows that have the same key (`t`) # then drop the temporary key on the result - t = from_native(acero.filter_table(re_keyed, temp_expr == v).remove_column(0)) + t = from_native(acero.filter_table(re_keyed, temp_expr == v)) # subset this new table to only the actual key name columns # then convert the first row to 
`tuple[pa.Scalar, ...]` row = t.select_names(*self.key_names).row(0) From 23565a3a6a795fce654efde25e287dfd6ba49295 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 27 Sep 2025 20:52:43 +0000 Subject: [PATCH 68/93] perf: Cached, lazy-loaded options --- narwhals/_plan/arrow/acero.py | 3 +- narwhals/_plan/arrow/group_by.py | 46 ++++++++------------ narwhals/_plan/arrow/options.py | 75 ++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 28 deletions(-) create mode 100644 narwhals/_plan/arrow/options.py diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index 1a836fc370..006f10f89b 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -53,7 +53,8 @@ Target: TypeAlias = OneOrListOrTuple[Field] Aggregation: TypeAlias = "_Aggregation" -Opts: TypeAlias = "_AggregateOptions | None" +AggregateOptions: TypeAlias = "_AggregateOptions" +Opts: TypeAlias = "AggregateOptions | None" OutputName: TypeAlias = str AggSpec: TypeAlias = tuple[Target, Aggregation, Opts, OutputName] diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index c61f1ca09d..dee50b6d5e 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -6,7 +6,7 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._plan import expressions as ir -from narwhals._plan.arrow import acero, functions as fn +from narwhals._plan.arrow import acero, functions as fn, options from narwhals._plan.common import temp from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import EagerDataFrameGroupBy @@ -93,36 +93,28 @@ def output_name(self) -> acero.OutputName: def _parse_agg_expr( self, expr: agg.AggExpr ) -> tuple[acero.Target, acero.Aggregation, acero.Opts]: - if agg_name := SUPPORTED_AGG.get(type(expr)): - option: acero.Opts = None - if isinstance(expr, (agg.Std, agg.Var)): - # NOTE: Only branch which 
needs an instance (for `ddof`) - option = pc.VarianceOptions(ddof=expr.ddof) - elif isinstance(expr, (agg.NUnique, agg.Len)): - option = pc.CountOptions(mode="all") - elif isinstance(expr, agg.Count): - option = pc.CountOptions(mode="only_valid") - elif isinstance(expr, (agg.First, agg.Last)): - option = pc.ScalarAggregateOptions(skip_nulls=False) - self.use_threads = False - if isinstance(expr.expr, ir.Column): - return [expr.expr.name], agg_name, option + tp = type(expr) + if not (agg_name := SUPPORTED_AGG.get(tp)): + raise group_by_error(self, "unsupported aggregation") + if not isinstance(expr.expr, ir.Column): raise group_by_error(self, "too complex") - raise group_by_error(self, "unsupported aggregation") + if issubclass(tp, agg.OrderableAggExpr): + self.use_threads = False + option = ( + options.variance(expr.ddof) + if isinstance(expr, (agg.Std, agg.Var)) + else options.AGG.get(tp) + ) + return ([expr.expr.name], agg_name, option) def _parse_function_expr( self, expr: ir.FunctionExpr ) -> tuple[acero.Target, acero.Aggregation, acero.Opts]: - func = expr.function - if agg_name := SUPPORTED_FUNCTION.get(type(func)): - if isinstance(func, (ir.boolean.All, ir.boolean.Any)): - option = pc.ScalarAggregateOptions(min_count=0) - else: - option = None - else: + tp = type(expr.function) + if not (agg_name := SUPPORTED_FUNCTION.get(tp)): raise group_by_error(self, "unsupported function") if len(expr.input) == 1 and isinstance(expr.input[0], ir.Column): - return [expr.input[0].name], agg_name, option + return [expr.input[0].name], agg_name, options.FUNCTION.get(tp) raise group_by_error(self, "too complex") def parse(self) -> Self: @@ -153,7 +145,7 @@ def concat_str( subset: Seq[str], *, separator: str = "", - options: pc.JoinOptions = _NULL_FILL, + join_options: pc.JoinOptions = _NULL_FILL, ) -> ChunkedArray: # get key columns, casting everything to str # docs says "list-like", runtime supports iterable @@ -167,7 +159,7 @@ def concat_str( schema = pa.schema((name, 
dtype) for name in schema.names) sep = fn.lit(separator, dtype) concat: Incomplete = pc.binary_join_element_wise - return concat(*df.cast(schema).itercolumns(), sep, options=options) # type: ignore[no-any-return] + return concat(*df.cast(schema).itercolumns(), sep, options=join_options) # type: ignore[no-any-return] class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]): @@ -186,7 +178,7 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: composite_values = concat_str(self.compliant.native, self.key_names) re_keyed = self.compliant.native.add_column(0, temp_name, composite_values) from_native = self.compliant._with_native - for v in composite_values.unique(): + for v in composite_values.unique(): # TODO @dangotbanned: Can more of the stuff inside the loop be done in `acero`? # filter the keyed table to rows that have the same key (`t`) # then drop the temporary key on the result t = from_native(acero.filter_table(re_keyed, temp_expr == v)) diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py new file mode 100644 index 0000000000..abff7077d1 --- /dev/null +++ b/narwhals/_plan/arrow/options.py @@ -0,0 +1,75 @@ +"""Cached `pyarrow.compute` options classes, using `polars` defaults.""" + +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Any, Literal + +import pyarrow.compute as pc # ignore-banned-import + +if TYPE_CHECKING: + from collections.abc import Mapping + + from narwhals._plan import expressions as ir + from narwhals._plan.arrow import acero + from narwhals._plan.expressions import aggregation as agg + + +__all__ = ["AGG", "FUNCTION", "count", "scalar_aggregate", "variance"] + + +AGG: Mapping[type[agg.AggExpr], acero.AggregateOptions] +FUNCTION: Mapping[type[ir.Function], acero.AggregateOptions] + + +@functools.cache +def count( + mode: Literal["only_valid", "only_null", "all"] = "only_valid", +) -> pc.CountOptions: + return pc.CountOptions(mode) + + +# pyarrow defaults to ignore_nulls +# polars 
doesn't mention +@functools.cache +def variance( + ddof: int = 1, *, ignore_nulls: bool = True, min_count: int = 0 +) -> pc.VarianceOptions: + return pc.VarianceOptions(ddof=ddof, skip_nulls=ignore_nulls, min_count=min_count) + + +@functools.cache +def scalar_aggregate( + *, ignore_nulls: bool = False, min_count: int = 0 +) -> pc.ScalarAggregateOptions: + return pc.ScalarAggregateOptions(skip_nulls=ignore_nulls, min_count=min_count) + + +# ruff: noqa: PLW0603 +# NOTE: Using globals for lazy-loading cache +if not TYPE_CHECKING: + + def __getattr__(name: str) -> Any: + if name == "AGG": + from narwhals._plan.expressions import aggregation as agg + + global AGG + AGG = { + agg.NUnique: count("all"), + agg.Len: count("all"), + agg.Count: count("only_valid"), + agg.First: scalar_aggregate(), + agg.Last: scalar_aggregate(), + } + return AGG + if name == "FUNCTION": + from narwhals._plan.expressions import boolean + + global FUNCTION + FUNCTION = { + boolean.All: scalar_aggregate(ignore_nulls=True), + boolean.Any: scalar_aggregate(ignore_nulls=True), + } + return FUNCTION + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) From 7099e0483f24d0ea96c5665054ba47e09252c232 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 27 Sep 2025 21:36:44 +0000 Subject: [PATCH 69/93] fix: Simplify, fix, optimize `ArrowDataFrame.row` No longer need to manually convert `pa.Scalar` (at least locally) --- narwhals/_plan/arrow/dataframe.py | 4 +++- narwhals/_plan/arrow/group_by.py | 5 +---- narwhals/_plan/dataframe.py | 3 +++ narwhals/_plan/protocols.py | 1 + tests/plan/compliant_test.py | 17 +++++++++++++++++ 5 files changed, 25 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index cbfeabac61..b588b59180 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -2,6 +2,7 @@ import operator from functools import 
reduce +from itertools import chain from typing import TYPE_CHECKING, Any, Literal, cast, overload import pyarrow as pa # ignore-banned-import @@ -146,4 +147,5 @@ def select_names(self, *column_names: str) -> Self: return self._with_native(self.native.select(list(column_names))) def row(self, index: int) -> tuple[Any, ...]: - return tuple(col[index] for col in self.native.itercolumns()) + row = self.native.slice(index, 1) + return tuple(chain.from_iterable(row.to_pydict().values())) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index dee50b6d5e..dd9ad0731b 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -183,10 +183,7 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: # then drop the temporary key on the result t = from_native(acero.filter_table(re_keyed, temp_expr == v)) # subset this new table to only the actual key name columns - # then convert the first row to `tuple[pa.Scalar, ...]` - row = t.select_names(*self.key_names).row(0) - # convert those scalars to python literals - group_key = tuple(el.as_py() for el in row) + group_key = t.select_names(*self.key_names).row(0) # select (all) columns from (`t`) that we started with at `.group_by()``, ignoring new keys/aliases partition = t.select_names(*self._column_names_original) yield group_key, partition diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 1b6975d589..4b49c04ec0 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -165,3 +165,6 @@ def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: def drop_nulls(self, subset: str | Sequence[str] | None = None) -> Self: subset = [subset] if isinstance(subset, str) else subset return self._from_compliant(self._compliant.drop_nulls(subset)) + + def row(self, index: int) -> tuple[Any, ...]: + return self._compliant.row(index) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 
e1d8317522..67b1bec221 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -633,6 +633,7 @@ def to_dict( ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... def __len__(self) -> int: ... def with_row_index(self, name: str) -> Self: ... + def row(self, index: int) -> tuple[Any, ...]: ... class CompliantGroupBy(Protocol[FrameT_co]): diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index b8e9a9ccd6..7b7113e450 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -522,6 +522,23 @@ def test_first_last_expr_with_columns( assert_equal_data(result, {"result": expected_broadcast}) +@pytest.mark.parametrize( + ("index", "expected"), [(3, (None, 12, 0.9, 3, 3)), (1, (2, 5, 1.0, 1, 1))] +) +def test_row_is_py_literal( + data_indexed: dict[str, Any], index: int, expected: tuple[PythonLiteral, ...] +) -> None: + frame = nwp.DataFrame.from_native(pa.table(data_indexed)) + result = frame.row(index) + assert all(v is None or isinstance(v, (int, float)) for v in result) + assert result == expected + pytest.importorskip("polars") + import polars as pl + + polars_result = pl.DataFrame(data_indexed).row(index) + assert result == polars_result + + if TYPE_CHECKING: def test_protocol_expr() -> None: From 55b1cafab67f9f95db2d9faea96d20bd64fb65d1 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 27 Sep 2025 21:40:07 +0000 Subject: [PATCH 70/93] refactor: Clean up more of `__iter__` --- narwhals/_plan/arrow/group_by.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index dd9ad0731b..0cc7bdd5be 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -174,19 +174,15 @@ def compliant(self) -> Frame: def __iter__(self) -> Iterator[tuple[Any, Frame]]: temp_name = temp.column_name(self.compliant) - temp_expr = pc.field(temp_name) 
composite_values = concat_str(self.compliant.native, self.key_names) re_keyed = self.compliant.native.add_column(0, temp_name, composite_values) from_native = self.compliant._with_native for v in composite_values.unique(): # TODO @dangotbanned: Can more of the stuff inside the loop be done in `acero`? - # filter the keyed table to rows that have the same key (`t`) - # then drop the temporary key on the result - t = from_native(acero.filter_table(re_keyed, temp_expr == v)) - # subset this new table to only the actual key name columns - group_key = t.select_names(*self.key_names).row(0) - # select (all) columns from (`t`) that we started with at `.group_by()``, ignoring new keys/aliases - partition = t.select_names(*self._column_names_original) - yield group_key, partition + t = from_native(acero.filter_table(re_keyed, pc.field(temp_name) == v)) + yield ( + t.select_names(*self.key_names).row(0), + t.select_names(*self._column_names_original), + ) def agg(self, irs: Seq[NamedIR]) -> Frame: aggs: list[acero.AggSpec] = [] From ceb3b4ec4932ec2c1c5ca3a0cc56fe4f8a51de3f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 28 Sep 2025 13:49:43 +0000 Subject: [PATCH 71/93] feat: Align the two `concat_str` impls In the process, generalizes some special case handling into: - 2x `cast_*` utilities -an issue linked function for why? 
- 2x cached `join*` options functions `cast_schema` (and `cast_table`) accept similar values to https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.cast.html --- narwhals/_plan/arrow/functions.py | 71 +++++++++++++++++++++++-------- narwhals/_plan/arrow/group_by.py | 29 +++---------- narwhals/_plan/arrow/options.py | 20 ++++++++- narwhals/_plan/arrow/typing.py | 7 +-- 4 files changed, 84 insertions(+), 43 deletions(-) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 7a16404d3d..1fd1942b2c 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -13,11 +13,12 @@ chunked_array as _chunked_array, floordiv_compat as floordiv, ) +from narwhals._plan.arrow import options from narwhals._plan.expressions import operators as ops from narwhals._utils import Implementation if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Mapping, Sequence + from collections.abc import Iterable, Mapping from typing_extensions import TypeIs @@ -38,17 +39,20 @@ ChunkedOrScalar, ChunkedOrScalarAny, DataType, + DataTypeRemap, DataTypeT, IntegerScalar, IntegerType, + LargeStringType, NativeScalar, Scalar, ScalarAny, ScalarT, StringScalar, + StringType, UnaryFunction, ) - from narwhals.typing import ClosedInterval + from narwhals.typing import ClosedInterval, IntoArrowSchema BACKEND_VERSION = Implementation.PYARROW._backend_version() @@ -133,6 +137,41 @@ def cast( return pc.cast(native, target_type, safe=safe) +def cast_schema( + native: pa.Schema, target_types: DataType | Mapping[str, DataType] | DataTypeRemap +) -> pa.Schema: + if isinstance(target_types, pa.DataType): + return pa.schema((name, target_types) for name in native.names) + if _is_into_pyarrow_schema(target_types): + new_schema = native + for name, dtype in target_types.items(): + index = native.get_field_index(name) + new_schema.set(index, native.field(index).with_type(dtype)) + return new_schema + return 
pa.schema((fld.name, target_types.get(fld.type, fld.type)) for fld in native) + + +def cast_table( + native: pa.Table, target: DataType | IntoArrowSchema | DataTypeRemap +) -> pa.Table: + s = target if isinstance(target, pa.Schema) else cast_schema(native.schema, target) + return native.cast(s) + + +def has_large_string(data_types: Iterable[DataType], /) -> bool: + return any(pa.types.is_large_string(tp) for tp in data_types) + + +def string_type(data_types: Iterable[DataType] = (), /) -> StringType | LargeStringType: + """Return a native string type, compatible with `data_types`. + + Until [apache/arrow#45717] is resolved, we need to upcast `string` to `large_string` when joining. + + [apache/arrow#45717]: https://github.com/apache/arrow/issues/45717 + """ + return pa.large_string() if has_large_string(data_types) else pa.string() + + def any_(native: Any) -> pa.BooleanScalar: return pc.any(native, min_count=0) @@ -180,21 +219,11 @@ def binary( def concat_str( *arrays: ChunkedArrayAny, separator: str = "", ignore_nulls: bool = False ) -> ChunkedArray[StringScalar]: - fn: Incomplete = pc.binary_join_element_wise - it, sep = _cast_to_comparable_string_types(arrays, separator) - return fn(*it, sep, null_handling="skip" if ignore_nulls else "emit_null") # type: ignore[no-any-return] - - -def _cast_to_comparable_string_types( - arrays: Sequence[ChunkedArrayAny], /, separator: str -) -> tuple[Iterator[ChunkedArray[StringScalar]], StringScalar]: - # Ensure `chunked_arrays` are either all `string` or all `large_string`. 
- dtype = ( - pa.string() - if not any(pa.types.is_large_string(obj.type) for obj in arrays) - else pa.large_string() - ) - return (obj.cast(dtype) for obj in arrays), pa.scalar(separator, dtype) + dtype = string_type(obj.type for obj in arrays) + it = (obj.cast(dtype) for obj in arrays) + concat: Incomplete = pc.binary_join_element_wise + join = options.join(ignore_nulls=ignore_nulls) + return concat(*it, lit(separator, dtype), options=join) # type: ignore[no-any-return] def int_range( @@ -260,3 +289,11 @@ def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: from narwhals._plan.arrow.series import ArrowSeries return isinstance(obj, ArrowSeries) + + +def _is_into_pyarrow_schema(obj: Mapping[Any, Any]) -> TypeIs[Mapping[str, DataType]]: + return ( + (first := next(iter(obj.items())), None) + and isinstance(first[0], str) + and isinstance(first[1], pa.DataType) + ) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 0cc7bdd5be..fba273fda9 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Final, Literal +from typing import TYPE_CHECKING, Any, Literal import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -135,31 +135,16 @@ def parse(self) -> Self: return self -_NULL_FILL: Final = pc.JoinOptions( - null_handling="replace", null_replacement="__nw_null_value__" -) - - def concat_str( - native: pa.Table, - subset: Seq[str], - *, - separator: str = "", - join_options: pc.JoinOptions = _NULL_FILL, + native: pa.Table, subset: Seq[str], *, separator: str = "" ) -> ChunkedArray: - # get key columns, casting everything to str - # docs says "list-like", runtime supports iterable + # NOTE: docs says "list-like", runtime supports iterable df = native.select(subset) # pyright: ignore[reportArgumentType] - schema = df.schema - dtype = ( - pa.string() - if not 
any(pa.types.is_large_string(tp) for tp in schema.types) - else pa.large_string() - ) - schema = pa.schema((name, dtype) for name in schema.names) - sep = fn.lit(separator, dtype) + dtype = fn.string_type(df.schema.types) + it = fn.cast_table(df, dtype).itercolumns() concat: Incomplete = pc.binary_join_element_wise - return concat(*df.cast(schema).itercolumns(), sep, options=join_options) # type: ignore[no-any-return] + join = options.join_replace_nulls() + return concat(*it, fn.lit(separator, dtype), options=join) # type: ignore[no-any-return] class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]): diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index abff7077d1..e661691ef3 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -15,7 +15,15 @@ from narwhals._plan.expressions import aggregation as agg -__all__ = ["AGG", "FUNCTION", "count", "scalar_aggregate", "variance"] +__all__ = [ + "AGG", + "FUNCTION", + "count", + "join", + "join_replace_nulls", + "scalar_aggregate", + "variance", +] AGG: Mapping[type[agg.AggExpr], acero.AggregateOptions] @@ -45,6 +53,16 @@ def scalar_aggregate( return pc.ScalarAggregateOptions(skip_nulls=ignore_nulls, min_count=min_count) +@functools.cache +def join(*, ignore_nulls: bool = False) -> pc.JoinOptions: + return pc.JoinOptions(null_handling="skip" if ignore_nulls else "emit_null") + + +@functools.cache +def join_replace_nulls(*, replacement: str = "__nw_null_value__") -> pc.JoinOptions: + return pc.JoinOptions(null_handling="replace", null_replacement=replacement) + + # ruff: noqa: PLW0603 # NOTE: Using globals for lazy-loading cache if not TYPE_CHECKING: diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index e633e6560e..f46c862cc7 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Mapping 
from typing import TYPE_CHECKING, Any, Protocol, overload from narwhals._typing_compat import TypeVar @@ -14,8 +14,8 @@ Int16Type, Int32Type, Int64Type, - LargeStringType, - StringType, + LargeStringType as LargeStringType, # noqa: PLC0414 + StringType as StringType, # noqa: PLC0414 Uint8Type, Uint16Type, Uint32Type, @@ -117,3 +117,4 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot NativeScalar: TypeAlias = ScalarAny BinOp: TypeAlias = Callable[..., ChunkedOrScalarAny] StoresNativeT_co = TypeVar("StoresNativeT_co", bound=StoresNative[Any], covariant=True) +DataTypeRemap: TypeAlias = Mapping[DataType, DataType] From f8882c7c5225e8d0277d1fece966134657b4c79c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 28 Sep 2025 14:22:14 +0000 Subject: [PATCH 72/93] feat: More `acero`, select before `concat_str` --- narwhals/_plan/arrow/acero.py | 38 ++++++++++++++++++++++++++------ narwhals/_plan/arrow/group_by.py | 17 ++++++-------- 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index 006f10f89b..dc6cceb52b 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -24,7 +24,7 @@ from narwhals.typing import SingleColSelector if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Collection, Iterable from typing_extensions import TypeAlias @@ -32,7 +32,7 @@ AggregateOptions as _AggregateOptions, Aggregation as _Aggregation, ) - from narwhals._plan.typing import Seq + from narwhals._plan.typing import OneOrIterable, Seq from narwhals.typing import NonNestedLiteral T = TypeVar("T") @@ -134,7 +134,28 @@ def select(*exprs: IntoExpr, **named_exprs: IntoExpr) -> Decl: raise NotImplementedError -# TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`) +def select_names(column_names: OneOrIterable[str], *more_names: str) -> Decl: + """`select` where all 
args are column names.""" + if not more_names: + if isinstance(column_names, str): + return _project((pc.field(column_names),), (column_names,)) + more_names = tuple(column_names) + elif isinstance(column_names, str): + more_names = column_names, *more_names + else: + msg = f"Passing both iterable and positional inputs is not supported.\n{column_names=}\n{more_names=}" + raise NotImplementedError(msg) + return _project([pc.field(name) for name in more_names], more_names) + + +def _project(exprs: Collection[Expr], names: Collection[str]) -> Decl: + # NOTE: Both just need to be `Sized` and `Iterable` + exprs_: Incomplete = exprs + names_: Incomplete = names + return Decl("project", options=pac.ProjectNodeOptions(exprs_, names_)) + + +# TODO @dangotbanned: Docs def project(**named_exprs: Expr) -> Decl: """Make a node which executes expressions on input batches, producing batches of the same length with new columns. @@ -148,10 +169,7 @@ def project(**named_exprs: Expr) -> Decl: that return one value for each input row independent of the value of all other rows). """ - # NOTE: Both just need to be sized and iterable - names: Incomplete = named_exprs.keys() - exprs: Incomplete = named_exprs.values() - return Decl("project", options=pac.ProjectNodeOptions(exprs, names)) + return _project(names=named_exprs.keys(), exprs=named_exprs.values()) # TODO @dangotbanned: Find which option class this uses @@ -195,3 +213,9 @@ def group_by_table( # TODO @dangotbanned: Docs? 
def filter_table(native: pa.Table, *predicates: Expr, **constraints: Any) -> pa.Table: return collect(table_source(native), filter(*predicates, **constraints)) + + +def select_names_table( + native: pa.Table, column_names: OneOrIterable[str], *more_names: str +) -> pa.Table: + return collect(table_source(native), select_names(column_names, *more_names)) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index fba273fda9..13168ccaac 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -135,13 +135,9 @@ def parse(self) -> Self: return self -def concat_str( - native: pa.Table, subset: Seq[str], *, separator: str = "" -) -> ChunkedArray: - # NOTE: docs says "list-like", runtime supports iterable - df = native.select(subset) # pyright: ignore[reportArgumentType] - dtype = fn.string_type(df.schema.types) - it = fn.cast_table(df, dtype).itercolumns() +def concat_str(native: pa.Table, *, separator: str = "") -> ChunkedArray: + dtype = fn.string_type(native.schema.types) + it = fn.cast_table(native, dtype).itercolumns() concat: Incomplete = pc.binary_join_element_wise join = options.join_replace_nulls() return concat(*it, fn.lit(separator, dtype), options=join) # type: ignore[no-any-return] @@ -159,10 +155,11 @@ def compliant(self) -> Frame: def __iter__(self) -> Iterator[tuple[Any, Frame]]: temp_name = temp.column_name(self.compliant) - composite_values = concat_str(self.compliant.native, self.key_names) - re_keyed = self.compliant.native.add_column(0, temp_name, composite_values) + native = self.compliant.native + composite_values = concat_str(acero.select_names_table(native, self.key_names)) + re_keyed = native.add_column(0, temp_name, composite_values) from_native = self.compliant._with_native - for v in composite_values.unique(): # TODO @dangotbanned: Can more of the stuff inside the loop be done in `acero`? 
+ for v in composite_values.unique(): t = from_native(acero.filter_table(re_keyed, pc.field(temp_name) == v)) yield ( t.select_names(*self.key_names).row(0), From 92f03d5eae153b9ee49b79f8078966b791463bb4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 28 Sep 2025 21:20:41 +0000 Subject: [PATCH 73/93] refactor: Replace `ArrowAggExpr` -> `AggSpec` --- narwhals/_plan/arrow/acero.py | 26 ++++-- narwhals/_plan/arrow/group_by.py | 142 +++++++++++++++---------------- 2 files changed, 91 insertions(+), 77 deletions(-) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index dc6cceb52b..163297327d 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -11,10 +11,11 @@ from __future__ import annotations +import functools import operator from functools import reduce from itertools import chain -from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union +from typing import TYPE_CHECKING, Any, Final, Literal, TypeVar, Union import pyarrow as pa # ignore-banned-import import pyarrow.acero as pac @@ -32,6 +33,7 @@ AggregateOptions as _AggregateOptions, Aggregation as _Aggregation, ) + from narwhals._plan.arrow.group_by import AggSpec from narwhals._plan.typing import OneOrIterable, Seq from narwhals.typing import NonNestedLiteral @@ -51,12 +53,24 @@ AggKeys: TypeAlias = "Iterable[Field] | None" -Target: TypeAlias = OneOrListOrTuple[Field] +Target: TypeAlias = ( + "OneOrListOrTuple[Expr] | OneOrListOrTuple[str] | OneOrListOrTuple[int]" +) Aggregation: TypeAlias = "_Aggregation" AggregateOptions: TypeAlias = "_AggregateOptions" Opts: TypeAlias = "AggregateOptions | None" OutputName: TypeAlias = str -AggSpec: TypeAlias = tuple[Target, Aggregation, Opts, OutputName] + +_THREAD_UNSAFE: Final = frozenset[Aggregation]( + ("hash_first", "hash_last", "first", "last") +) + + +# NOTE: ATOW there are 304 valid function names, 46 can be used for some kind of agg +# Due to expr 
expansion, it is very likely that we have repeat runs +@functools.lru_cache(maxsize=128) +def can_thread(function_name: str, /) -> bool: + return function_name not in _THREAD_UNSAFE # TODO @dangotbanned: Rename @@ -191,9 +205,7 @@ def collect(*declarations: Decl, use_threads: bool = True) -> pa.Table: # NOTE: Composite functions are suffixed with `_table` -def group_by_table( - native: pa.Table, keys: AggKeys, aggs: Iterable[AggSpec], *, use_threads: bool -) -> pa.Table: +def group_by_table(native: pa.Table, keys: AggKeys, aggs: Iterable[AggSpec]) -> pa.Table: """Adapted from [`pa.TableGroupBy.aggregate`] and [`pa.acero._group_by`]. - Backport of [apache/arrow#36768]. @@ -207,6 +219,8 @@ def group_by_table( [broken in `pyarrow==13`]: https://github.com/apache/arrow/issues/36709 [narwhals-dev/narwhals#1612]: https://github.com/narwhals-dev/narwhals/issues/1612 """ + aggs = tuple(aggs) + use_threads = all(spec.use_threads for spec in aggs) return collect(table_source(native), group_by(keys, aggs), use_threads=use_threads) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 13168ccaac..9e07d9751a 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -6,11 +6,11 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._plan import expressions as ir +from narwhals._plan._guards import is_agg_expr, is_function_expr from narwhals._plan.arrow import acero, functions as fn, options from narwhals._plan.common import temp from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import EagerDataFrameGroupBy -from narwhals._utils import Implementation if TYPE_CHECKING: from collections.abc import Iterator, Mapping @@ -24,9 +24,6 @@ Incomplete: TypeAlias = Any - -BACKEND_VERSION = Implementation.PYARROW._backend_version() - SUPPORTED_AGG: Mapping[type[agg.AggExpr], acero.Aggregation] = { agg.Sum: "hash_sum", agg.Mean: "hash_mean", @@ -62,77 +59,86 @@ """ -def 
group_by_error( - expr: ArrowAggExpr, - reason: Literal[ - "too complex", - "unsupported aggregation", - "unsupported function", - "unsupported expression", - ], -) -> NotImplementedError: - if reason == "too complex": - msg = "Non-trivial complex aggregation found" - else: - msg = reason.title() - msg = f"{msg} in 'pyarrow.Table':\n\n{expr.named_ir!r}" - return NotImplementedError(msg) - +class AggSpec: + __slots__ = ("agg", "name", "option", "target") -class ArrowAggExpr: - def __init__(self, named_ir: NamedIR, /) -> None: - self.named_ir: NamedIR = named_ir - self.use_threads: bool = True - """See https://github.com/apache/arrow/issues/36709""" - self.spec: acero.AggSpec + def __init__( + self, + target: acero.Target, + agg: acero.Aggregation, + option: acero.Opts = None, + name: acero.OutputName = "", + ) -> None: + self.target = target + self.agg = agg + self.option = option + self.name = name or str(target) @property - def output_name(self) -> acero.OutputName: - return self.named_ir.name + def use_threads(self) -> bool: + """See https://github.com/apache/arrow/issues/36709.""" + return acero.can_thread(self.agg) + + def __iter__(self) -> Iterator[acero.Target | acero.Aggregation | acero.Opts]: + """Let's us duck-type as a 4-tuple.""" + yield from (self.target, self.agg, self.option, self.name) + + @classmethod + def from_named_ir(cls, named_ir: NamedIR) -> Self: + return cls.from_expr_ir(named_ir.expr, named_ir.name) - def _parse_agg_expr( - self, expr: agg.AggExpr - ) -> tuple[acero.Target, acero.Aggregation, acero.Opts]: + @classmethod + def from_agg_expr(cls, expr: agg.AggExpr, name: acero.OutputName) -> Self: tp = type(expr) if not (agg_name := SUPPORTED_AGG.get(tp)): - raise group_by_error(self, "unsupported aggregation") + raise group_by_error(name, expr, "unsupported aggregation") if not isinstance(expr.expr, ir.Column): - raise group_by_error(self, "too complex") - if issubclass(tp, agg.OrderableAggExpr): - self.use_threads = False + raise 
group_by_error(name, expr, "too complex") option = ( options.variance(expr.ddof) if isinstance(expr, (agg.Std, agg.Var)) else options.AGG.get(tp) ) - return ([expr.expr.name], agg_name, option) + return cls([expr.expr.name], agg_name, option, name) - def _parse_function_expr( - self, expr: ir.FunctionExpr - ) -> tuple[acero.Target, acero.Aggregation, acero.Opts]: + @classmethod + def from_function_expr(cls, expr: ir.FunctionExpr, name: acero.OutputName) -> Self: tp = type(expr.function) - if not (agg_name := SUPPORTED_FUNCTION.get(tp)): - raise group_by_error(self, "unsupported function") - if len(expr.input) == 1 and isinstance(expr.input[0], ir.Column): - return [expr.input[0].name], agg_name, options.FUNCTION.get(tp) - raise group_by_error(self, "too complex") - - def parse(self) -> Self: - expr = self.named_ir.expr - input_name: acero.Target = () - option: acero.Opts = None - if isinstance(expr, agg.AggExpr): - input_name, agg_name, option = self._parse_agg_expr(expr) - elif isinstance(expr, ir.FunctionExpr): - input_name, agg_name, option = self._parse_function_expr(expr) - elif isinstance(expr, (ir.Len, ir.Column)): - agg_name = SUPPORTED_IR[type(expr)] - if isinstance(expr, ir.Column): - input_name = [expr.name] - else: - raise group_by_error(self, "unsupported expression") - self.spec = input_name, agg_name, option, self.output_name - return self + if not (fn_name := SUPPORTED_FUNCTION.get(tp)): + raise group_by_error(name, expr, "unsupported function") + args = expr.input + if not (len(args) == 1 and isinstance(args[0], ir.Column)): + raise group_by_error(name, expr, "too complex") + return cls(args[0].name, fn_name, options.FUNCTION.get(tp), name) + + @classmethod + def from_expr_ir(cls, expr: ir.ExprIR, name: acero.OutputName) -> Self: + if is_agg_expr(expr): + return cls.from_agg_expr(expr, name) + if is_function_expr(expr): + return cls.from_function_expr(expr, name) + if not isinstance(expr, (ir.Len, ir.Column)): + raise group_by_error(name, expr, 
"unsupported expression") + fn_name = SUPPORTED_IR[type(expr)] + return cls([expr.name] if isinstance(expr, ir.Column) else (), fn_name, name=name) + + +def group_by_error( + name: str, + expr: ir.ExprIR, + reason: Literal[ + "too complex", + "unsupported aggregation", + "unsupported function", + "unsupported expression", + ], +) -> NotImplementedError: + if reason == "too complex": + msg = "Non-trivial complex aggregation found" + else: + msg = reason.title() + msg = f"{msg} in 'pyarrow.Table':\n\n{name}={expr!r}" + return NotImplementedError(msg) def concat_str(native: pa.Table, *, separator: str = "") -> ChunkedArray: @@ -167,17 +173,11 @@ def __iter__(self) -> Iterator[tuple[Any, Frame]]: ) def agg(self, irs: Seq[NamedIR]) -> Frame: - aggs: list[acero.AggSpec] = [] - use_threads: bool = True - for e in irs: - expr = ArrowAggExpr(e).parse() - use_threads = use_threads and expr.use_threads - aggs.append(expr.spec) - native = self.compliant.native + compliant = self.compliant + native = compliant.native key_names = self.key_names - result = self.compliant._with_native( - acero.group_by_table(native, key_names, aggs, use_threads=use_threads) - ) + specs = (AggSpec.from_named_ir(e) for e in irs) + result = compliant._with_native(acero.group_by_table(native, key_names, specs)) if original := self._key_names_original: return result.rename(dict(zip(key_names, original))) return result From 9d766286d235568d6d09982cee071d69f46c20b2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 12:33:00 +0000 Subject: [PATCH 74/93] fix: Oops infinite repeats --- narwhals/_plan/options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 6f77674dff..ea77ce6e1e 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -170,7 +170,7 @@ def to_multiple(self, n_repeat: int = 1, /) -> SortMultipleOptions: nulls: Seq[bool] = 
(self.nulls_last,) else: desc = tuple(repeat(self.descending, n_repeat)) - nulls = tuple(repeat(self.nulls_last)) + nulls = tuple(repeat(self.nulls_last, n_repeat)) return SortMultipleOptions(descending=desc, nulls_last=nulls) From aae39366aecd202bdab80e18b30db95b871dc972 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 13:45:35 +0000 Subject: [PATCH 75/93] lil progress on `order_by`, `sort_by` In theory, we should be able to compose `over()` using combinations of: - aggregate - both scalar and hash - order_by - project - hashjoin --- narwhals/_plan/arrow/acero.py | 26 +++++++++++++++++--------- narwhals/_plan/arrow/typing.py | 3 ++- narwhals/_plan/options.py | 29 +++++++++++++++++++++-------- narwhals/_plan/typing.py | 2 ++ 4 files changed, 42 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index 163297327d..7b0f9f31b4 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -15,7 +15,7 @@ import operator from functools import reduce from itertools import chain -from typing import TYPE_CHECKING, Any, Final, Literal, TypeVar, Union +from typing import TYPE_CHECKING, Any, Final, TypeVar, Union import pyarrow as pa # ignore-banned-import import pyarrow.acero as pac @@ -34,7 +34,8 @@ Aggregation as _Aggregation, ) from narwhals._plan.arrow.group_by import AggSpec - from narwhals._plan.typing import OneOrIterable, Seq + from narwhals._plan.arrow.typing import NullPlacement + from narwhals._plan.typing import OneOrIterable, Order, Seq from narwhals.typing import NonNestedLiteral T = TypeVar("T") @@ -186,15 +187,22 @@ def project(**named_exprs: Expr) -> Decl: return _project(names=named_exprs.keys(), exprs=named_exprs.values()) -# TODO @dangotbanned: Find which option class this uses -def order_by( - sort_keys: tuple[tuple[str, Literal["ascending", "descending"]], ...] 
= (), +def _order_by( + sort_keys: Iterable[tuple[str, Order]] = (), *, - null_placement: Literal["at_start", "at_end"] = "at_end", + null_placement: NullPlacement = "at_end", ) -> Decl: - return Decl( - "order_by", pac.OrderByNodeOptions(sort_keys, null_placement=null_placement) - ) + # NOTE: There's no runtime type checking of `sort_keys` wrt shape + # Just need to be `Iterable`and unpack like a 2-tuple + # https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_compute.pyx#L77-L88 + keys: Incomplete = sort_keys + return Decl("order_by", pac.OrderByNodeOptions(keys, null_placement=null_placement)) + + +# TODO @dangotbanned: Utilize `SortMultipleOptions.to_arrow_acero` +def sort_by(*args: Any, **kwds: Any) -> Decl: + msg = "Should convert from polars args -> use `_order_by" + raise NotImplementedError(msg) # TODO @dangotbanned: Docs diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index f46c862cc7..e11e9d45c1 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Callable, Mapping -from typing import TYPE_CHECKING, Any, Protocol, overload +from typing import TYPE_CHECKING, Any, Literal, Protocol, overload from narwhals._typing_compat import TypeVar from narwhals._utils import _StoresNative as StoresNative @@ -118,3 +118,4 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot BinOp: TypeAlias = Callable[..., ChunkedOrScalarAny] StoresNativeT_co = TypeVar("StoresNativeT_co", bound=StoresNative[Any], covariant=True) DataTypeRemap: TypeAlias = Mapping[DataType, DataType] +NullPlacement: TypeAlias = Literal["at_start", "at_end"] diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index ea77ce6e1e..303d07a097 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -9,10 +9,12 @@ if TYPE_CHECKING: from collections.abc import Iterable, 
Sequence + import pyarrow.acero import pyarrow.compute as pc from typing_extensions import Self, TypeAlias - from narwhals._plan.typing import Accessor, OneOrIterable, Seq + from narwhals._plan.arrow.typing import NullPlacement + from narwhals._plan.typing import Accessor, OneOrIterable, Order, Seq from narwhals.typing import RankMethod DispatchOrigin: TypeAlias = Literal["expr", "__narwhals_namespace__"] @@ -193,9 +195,9 @@ def parse( nulls = (nulls_last,) if isinstance(nulls_last, bool) else tuple(nulls_last) return SortMultipleOptions(descending=desc, nulls_last=nulls) - def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: - import pyarrow.compute as pc - + def _to_arrow_args( + self, by: Sequence[str] + ) -> tuple[Sequence[tuple[str, Order]], NullPlacement]: first = self.nulls_last[0] if len(self.nulls_last) != 1 and any(x != first for x in self.nulls_last[1:]): msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {self.nulls_last!r}" @@ -204,12 +206,23 @@ def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: descending: Iterable[bool] = repeat(self.descending[0], len(by)) else: descending = self.descending - sorting: list[tuple[str, Literal["ascending", "descending"]]] = [ + sorting = tuple[tuple[str, "Order"]]( (key, "descending" if desc else "ascending") for key, desc in zip(by, descending) - ] - placement: Literal["at_start", "at_end"] = "at_end" if first else "at_start" - return pc.SortOptions(sort_keys=sorting, null_placement=placement) + ) + return sorting, "at_end" if first else "at_start" + + def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: + import pyarrow.compute as pc + + sort_keys, placement = self._to_arrow_args(by) + return pc.SortOptions(sort_keys=sort_keys, null_placement=placement) + + def to_arrow_acero(self, by: Sequence[str]) -> pyarrow.acero.Declaration: + from narwhals._plan.arrow import acero + + sort_keys, placement = self._to_arrow_args(by) + return acero._order_by(sort_keys, null_placement=placement) 
class RankOptions(Immutable): diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index ac16adffea..94ad494ca1 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -110,3 +110,5 @@ IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" OneOrIterable: TypeAlias = "T | t.Iterable[T]" DataFrameT = TypeVar("DataFrameT", bound="DataFrame[t.Any, t.Any]") + +Order: TypeAlias = t.Literal["ascending", "descending"] From cb51c6741062f5037c11fd37fbd5a7eb68915f46 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 14:03:20 +0000 Subject: [PATCH 76/93] docs: Leave more useful notes in `group_by` --- narwhals/_plan/arrow/group_by.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 9e07d9751a..92cac30dfb 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -24,6 +24,9 @@ Incomplete: TypeAlias = Any +# NOTE: Unless stated otherwise, all aggregations have 2 variants: +# - `` (pc.Function.kind == "scalar_aggregate") +# - `hash_` (pc.Function.kind == "hash_aggregate") SUPPORTED_AGG: Mapping[type[agg.AggExpr], acero.Aggregation] = { agg.Sum: "hash_sum", agg.Mean: "hash_mean", @@ -40,18 +43,15 @@ } SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = { ir.Len: "hash_count_all", - ir.Column: "hash_list", + ir.Column: "hash_list", # `hash_aggregate` only } SUPPORTED_FUNCTION: Mapping[type[ir.Function], acero.Aggregation] = { ir.boolean.All: "hash_all", ir.boolean.Any: "hash_any", - ir.functions.Unique: "hash_distinct", + ir.functions.Unique: "hash_distinct", # `hash_aggregate` only } -REQUIRES_PYARROW_20: tuple[Literal["kurtosis"], Literal["skew"]] = ( - "kurtosis", # Compute the kurtosis of values in each group - "skew", # Compute the skewness of values in each group -) +REQUIRES_PYARROW_20: tuple[Literal["kurtosis"], Literal["skew"]] = 
("kurtosis", "skew") """They don't show in [our version of the stubs], but are possible in [`pyarrow>=20`]. [our version of the stubs]: https://github.com/narwhals-dev/narwhals/issues/2124#issuecomment-3191374210 From d046981a23f02a48f1a2c188d11710486e627b1e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 14:31:12 +0000 Subject: [PATCH 77/93] refine typing --- narwhals/_plan/arrow/acero.py | 26 ++++++++++---------------- narwhals/_plan/arrow/group_by.py | 4 ++-- narwhals/_plan/typing.py | 3 ++- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index 7b0f9f31b4..a0677a8cb6 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -15,13 +15,14 @@ import operator from functools import reduce from itertools import chain -from typing import TYPE_CHECKING, Any, Final, TypeVar, Union +from typing import TYPE_CHECKING, Any, Final, Union import pyarrow as pa # ignore-banned-import import pyarrow.acero as pac import pyarrow.compute as pc # ignore-banned-import from pyarrow.acero import Declaration as Decl +from narwhals._plan.typing import OneOrSeq from narwhals.typing import SingleColSelector if TYPE_CHECKING: @@ -38,11 +39,6 @@ from narwhals._plan.typing import OneOrIterable, Order, Seq from narwhals.typing import NonNestedLiteral -T = TypeVar("T") -OneOrListOrTuple: TypeAlias = Union[T, list[T], tuple[T, ...]] -"""WARNING: Don't use this unless there is a runtime check for exactly `list | tuple`.""" - - Incomplete: TypeAlias = Any Expr: TypeAlias = pc.Expression IntoExpr: TypeAlias = "Expr | NonNestedLiteral" @@ -52,11 +48,7 @@ [`_compute._ensure_field_ref`]: https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_compute.pyx#L1507-L1531 """ -AggKeys: TypeAlias = "Iterable[Field] | None" - -Target: TypeAlias = ( - "OneOrListOrTuple[Expr] | OneOrListOrTuple[str] | 
OneOrListOrTuple[int]" -) +Target: TypeAlias = OneOrSeq[Field] Aggregation: TypeAlias = "_Aggregation" AggregateOptions: TypeAlias = "_AggregateOptions" Opts: TypeAlias = "AggregateOptions | None" @@ -99,11 +91,11 @@ def table_source(native: pa.Table, /) -> Decl: return Decl("table_source", options=pac.TableSourceNodeOptions(native)) -def _aggregate(agg_specs: Iterable[AggSpec], /, keys: AggKeys = None) -> Decl: +def _aggregate(aggs: Iterable[AggSpec], /, keys: Iterable[Field] | None = None) -> Decl: # NOTE: See https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_acero.pyx#L167-L192 - aggs: Incomplete = agg_specs + aggs_: Incomplete = aggs keys_: Incomplete = keys - return Decl("aggregate", pac.AggregateNodeOptions(aggs, keys=keys_)) + return Decl("aggregate", pac.AggregateNodeOptions(aggs_, keys=keys_)) # TODO @dangotbanned: Plan @@ -117,7 +109,7 @@ def aggregate(aggs: Iterable[AggSpec], /) -> Decl: # TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`) -def group_by(keys: AggKeys, aggs: Iterable[AggSpec], /) -> Decl: +def group_by(keys: Iterable[Field], aggs: Iterable[AggSpec], /) -> Decl: """Hash aggregate. Like GROUP BY in SQL and first partition data based on one or more key columns, @@ -213,7 +205,9 @@ def collect(*declarations: Decl, use_threads: bool = True) -> pa.Table: # NOTE: Composite functions are suffixed with `_table` -def group_by_table(native: pa.Table, keys: AggKeys, aggs: Iterable[AggSpec]) -> pa.Table: +def group_by_table( + native: pa.Table, keys: Iterable[Field], aggs: Iterable[AggSpec] +) -> pa.Table: """Adapted from [`pa.TableGroupBy.aggregate`] and [`pa.acero._group_by`]. - Backport of [apache/arrow#36768]. 
diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 92cac30dfb..13c49b79b1 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -99,7 +99,7 @@ def from_agg_expr(cls, expr: agg.AggExpr, name: acero.OutputName) -> Self: if isinstance(expr, (agg.Std, agg.Var)) else options.AGG.get(tp) ) - return cls([expr.expr.name], agg_name, option, name) + return cls(expr.expr.name, agg_name, option, name) @classmethod def from_function_expr(cls, expr: ir.FunctionExpr, name: acero.OutputName) -> Self: @@ -120,7 +120,7 @@ def from_expr_ir(cls, expr: ir.ExprIR, name: acero.OutputName) -> Self: if not isinstance(expr, (ir.Len, ir.Column)): raise group_by_error(name, expr, "unsupported expression") fn_name = SUPPORTED_IR[type(expr)] - return cls([expr.name] if isinstance(expr, ir.Column) else (), fn_name, name=name) + return cls(expr.name if isinstance(expr, ir.Column) else (), fn_name, name=name) def group_by_error( diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 94ad494ca1..8e591b8c0b 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -97,7 +97,7 @@ T = TypeVar("T") -Seq: TypeAlias = "tuple[T,...]" +Seq: TypeAlias = tuple[T, ...] """Immutable Sequence. Using instead of `Sequence`, as a `list` can be passed there (can't break immutability promise). 
@@ -109,6 +109,7 @@ IntoExprColumn: TypeAlias = "Expr | Series[t.Any] | str" IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" OneOrIterable: TypeAlias = "T | t.Iterable[T]" +OneOrSeq: TypeAlias = t.Union[T, Seq[T]] DataFrameT = TypeVar("DataFrameT", bound="DataFrame[t.Any, t.Any]") Order: TypeAlias = t.Literal["ascending", "descending"] From 6bacae18e40f9431948775af611fa376e625fe0a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 15:08:59 +0000 Subject: [PATCH 78/93] address some `acero` todos --- narwhals/_plan/arrow/acero.py | 62 ++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index a0677a8cb6..90adc8a592 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -5,6 +5,9 @@ This module aligns some apis to look more like `polars`. +Notes: + - Functions suffixed with `_table` all handle composition and collection internally + [Acero]: https://arrow.apache.org/docs/cpp/acero/overview.html [`pyarrow.acero`]: https://arrow.apache.org/docs/python/api/acero.html """ @@ -15,7 +18,7 @@ import operator from functools import reduce from itertools import chain -from typing import TYPE_CHECKING, Any, Final, Union +from typing import TYPE_CHECKING, Any, Final, Union, cast import pyarrow as pa # ignore-banned-import import pyarrow.acero as pac @@ -26,7 +29,7 @@ from narwhals.typing import SingleColSelector if TYPE_CHECKING: - from collections.abc import Collection, Iterable + from collections.abc import Callable, Collection, Iterable from typing_extensions import TypeAlias @@ -57,6 +60,9 @@ _THREAD_UNSAFE: Final = frozenset[Aggregation]( ("hash_first", "hash_last", "first", "last") ) +col = pc.field +lit = cast("Callable[[NonNestedLiteral], Expr]", pc.scalar) +"""Alias for `pyarrow.compute.scalar`.""" # NOTE: ATOW there are 304 valid function names, 46 can be used 
for some kind of agg @@ -66,21 +72,20 @@ def can_thread(function_name: str, /) -> bool: return function_name not in _THREAD_UNSAFE -# TODO @dangotbanned: Rename -def pc_expr(into: IntoExpr, /, *, str_as_lit: bool = False) -> Expr: +def _parse_into_expr(into: IntoExpr, /, *, str_as_lit: bool = False) -> Expr: if isinstance(into, pc.Expression): return into if isinstance(into, str) and not str_as_lit: - return pc.field(into) - arg: Incomplete = into - return pc.scalar(arg) + return col(into) + return lit(into) def _parse_all_horizontal(predicates: Seq[Expr], constraints: dict[str, Any], /) -> Expr: if not constraints and len(predicates) == 1: return predicates[0] it = ( - pc.field(name) == pc_expr(v, str_as_lit=True) for name, v in constraints.items() + col(name) == _parse_into_expr(v, str_as_lit=True) + for name, v in constraints.items() ) return reduce(operator.and_, chain(predicates, it)) @@ -119,19 +124,6 @@ def group_by(keys: Iterable[Field], aggs: Iterable[AggSpec], /) -> Decl: def filter(*predicates: Expr, **constraints: IntoExpr) -> Decl: - """Selects rows where all expressions evaluate to True. - - Arguments: - predicates: [`Expression`](s) which must all have a return type of boolean. - constraints: Column filters; use `name = value` to filter columns by the supplied value. 
- - Notes: - - Uses logic similar to [`polars`] for an AND-reduction - - Elements where the filter does not evaluate to True are discarded, **including nulls** - - [`Expression`]: https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html - [`polars`]: https://github.com/pola-rs/polars/blob/d0914d416ce4e1dfcb5f946875ffd1181e31c493/py-polars/polars/_utils/parse/expr.py#L199-L242 - """ expr = _parse_all_horizontal(predicates, constraints) return Decl("filter", options=pac.FilterNodeOptions(expr)) @@ -145,14 +137,14 @@ def select_names(column_names: OneOrIterable[str], *more_names: str) -> Decl: """`select` where all args are column names.""" if not more_names: if isinstance(column_names, str): - return _project((pc.field(column_names),), (column_names,)) + return _project((col(column_names),), (column_names,)) more_names = tuple(column_names) elif isinstance(column_names, str): more_names = column_names, *more_names else: msg = f"Passing both iterable and positional inputs is not supported.\n{column_names=}\n{more_names=}" raise NotImplementedError(msg) - return _project([pc.field(name) for name in more_names], more_names) + return _project([col(name) for name in more_names], more_names) def _project(exprs: Collection[Expr], names: Collection[str]) -> Decl: @@ -197,14 +189,19 @@ def sort_by(*args: Any, **kwds: Any) -> Decl: raise NotImplementedError(msg) -# TODO @dangotbanned: Docs def collect(*declarations: Decl, use_threads: bool = True) -> pa.Table: + """Compose and evaluate a logical plan. + + Arguments: + *declarations: One or more `Declaration` nodes to execute as a pipeline. + **The first node must be a `table_source`**. + use_threads: Pass `False` if `declarations` contains any order-dependent aggregation(s). 
+ """ # NOTE: stubs + docs say `list`, but impl allows any iterable decls: Incomplete = declarations return Decl.from_sequence(decls).to_table(use_threads=use_threads) -# NOTE: Composite functions are suffixed with `_table` def group_by_table( native: pa.Table, keys: Iterable[Field], aggs: Iterable[AggSpec] ) -> pa.Table: @@ -226,8 +223,21 @@ def group_by_table( return collect(table_source(native), group_by(keys, aggs), use_threads=use_threads) -# TODO @dangotbanned: Docs? def filter_table(native: pa.Table, *predicates: Expr, **constraints: Any) -> pa.Table: + """Selects rows where all expressions evaluate to True. + + Arguments: + native: source table + predicates: [`Expression`](s) which must all have a return type of boolean. + constraints: Column filters; use `name = value` to filter columns by the supplied value. + + Notes: + - Uses logic similar to [`polars`] for an AND-reduction + - Elements where the filter does not evaluate to True are discarded, **including nulls** + + [`Expression`]: https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html + [`polars`]: https://github.com/pola-rs/polars/blob/d0914d416ce4e1dfcb5f946875ffd1181e31c493/py-polars/polars/_utils/parse/expr.py#L199-L242 + """ return collect(table_source(native), filter(*predicates, **constraints)) From f568ce09e867e5ea46675ce6cb972024e0f3ea94 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 16:16:18 +0000 Subject: [PATCH 79/93] improve `project` parsing + docs --- narwhals/_plan/arrow/acero.py | 47 +++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index 90adc8a592..b9eeb3f8c2 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -29,7 +29,7 @@ from narwhals.typing import SingleColSelector if TYPE_CHECKING: - from collections.abc import Callable, Collection, Iterable + 
from collections.abc import Callable, Collection, Iterable, Iterator from typing_extensions import TypeAlias @@ -80,6 +80,15 @@ def _parse_into_expr(into: IntoExpr, /, *, str_as_lit: bool = False) -> Expr: return lit(into) +def _parse_into_iter_expr(inputs: Iterable[IntoExpr], /) -> Iterator[Expr]: + for into_expr in inputs: + yield _parse_into_expr(into_expr) + + +def _parse_into_seq_of_expr(inputs: Iterable[IntoExpr], /) -> Seq[Expr]: + return tuple(_parse_into_iter_expr(inputs)) + + def _parse_all_horizontal(predicates: Seq[Expr], constraints: dict[str, Any], /) -> Expr: if not constraints and len(predicates) == 1: return predicates[0] @@ -90,9 +99,11 @@ def _parse_all_horizontal(predicates: Seq[Expr], constraints: dict[str, Any], /) return reduce(operator.and_, chain(predicates, it)) -# TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`) def table_source(native: pa.Table, /) -> Decl: - """A Source node which accepts a table.""" + """Start building a logical plan, using `native` as the source table. + + All calls to `collect` must use this as the first `Declaration`. + """ return Decl("table_source", options=pac.TableSourceNodeOptions(native)) @@ -154,21 +165,25 @@ def _project(exprs: Collection[Expr], names: Collection[str]) -> Decl: return Decl("project", options=pac.ProjectNodeOptions(exprs_, names_)) -# TODO @dangotbanned: Docs -def project(**named_exprs: Expr) -> Decl: - """Make a node which executes expressions on input batches, producing batches of the same length with new columns. +def project(**named_exprs: IntoExpr) -> Decl: + """Similar to `select`, but more rigid. + + Arguments: + **named_exprs: Inputs composed of any combination of + + - Column names or `pc.field` + - Python literals or `pc.scalar` (for `str` literals) + - [Scalar functions] applied to the above - This is the option class for the "project" node factory. + Notes: + - [`Expression`]s have no concept of aliasing, therefore, all inputs must be `**named_exprs`. 
+ - Always returns a table with the same length, scalar literals are broadcast unconditionally. - The "project" operation rearranges, deletes, transforms, and - creates columns. Each output column is computed by evaluating - an expression against the source record batch. These must be - scalar expressions (expressions consisting of scalar literals, - field references and scalar functions, i.e. elementwise functions - that return one value for each input row independent of the value - of all other rows). + [`Expression`]: https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html + [Scalar functions]: https://arrow.apache.org/docs/cpp/compute.html#element-wise-scalar-functions """ - return _project(names=named_exprs.keys(), exprs=named_exprs.values()) + exprs = _parse_into_seq_of_expr(named_exprs.values()) + return _project(names=named_exprs.keys(), exprs=exprs) def _order_by( @@ -228,7 +243,7 @@ def filter_table(native: pa.Table, *predicates: Expr, **constraints: Any) -> pa. Arguments: native: source table - predicates: [`Expression`](s) which must all have a return type of boolean. + predicates: [`Expression`]s which must all have a return type of boolean. constraints: Column filters; use `name = value` to filter columns by the supplied value. 
Notes: From f04146dfb89216849a0c75bf2fb802c5efd40876 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 16:24:36 +0000 Subject: [PATCH 80/93] finish most remaining `acero` todos --- narwhals/_plan/arrow/acero.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index b9eeb3f8c2..768248e312 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -114,22 +114,18 @@ def _aggregate(aggs: Iterable[AggSpec], /, keys: Iterable[Field] | None = None) return Decl("aggregate", pac.AggregateNodeOptions(aggs_, keys=keys_)) -# TODO @dangotbanned: Plan -# TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`) def aggregate(aggs: Iterable[AggSpec], /) -> Decl: - """Scalar aggregate. + """May only use [Scalar aggregate] functions. - Reduce an array or scalar input to a single scalar output (e.g. computing the mean of a column) + [Scalar aggregate]: https://arrow.apache.org/docs/cpp/compute.html#aggregations """ return _aggregate(aggs) -# TODO @dangotbanned: Docs (currently copy/paste from `pyarrow`) def group_by(keys: Iterable[Field], aggs: Iterable[AggSpec], /) -> Decl: - """Hash aggregate. + """May only use [Hash aggregate] functions, requires grouping. - Like GROUP BY in SQL and first partition data based on one or more key columns, - then reduce the data in each partition. 
+ [Hash aggregate]: https://arrow.apache.org/docs/cpp/compute.html#grouped-aggregations-group-by """ return _aggregate(aggs, keys=keys) @@ -139,11 +135,6 @@ def filter(*predicates: Expr, **constraints: IntoExpr) -> Decl: return Decl("filter", options=pac.FilterNodeOptions(expr)) -# TODO @dangotbanned: Plan -def select(*exprs: IntoExpr, **named_exprs: IntoExpr) -> Decl: - raise NotImplementedError - - def select_names(column_names: OneOrIterable[str], *more_names: str) -> Decl: """`select` where all args are column names.""" if not more_names: From 575f07f9253cbf09b15d527155a01ec02656a657 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 18:02:59 +0000 Subject: [PATCH 81/93] make `__getattr__` more visible --- narwhals/_plan/arrow/options.py | 44 +++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index e661691ef3..8998b288a2 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -1,4 +1,8 @@ -"""Cached `pyarrow.compute` options classes, using `polars` defaults.""" +"""Cached `pyarrow.compute` options classes, using `polars` defaults. + +Important: + `AGG` and `FUNCTION` mappings are constructed on first `__getattr__` access. 
+""" from __future__ import annotations @@ -63,31 +67,39 @@ def join_replace_nulls(*, replacement: str = "__nw_null_value__") -> pc.JoinOpti return pc.JoinOptions(null_handling="replace", null_replacement=replacement) +def _generate_agg() -> Mapping[type[agg.AggExpr], acero.AggregateOptions]: + from narwhals._plan.expressions import aggregation as agg + + return { + agg.NUnique: count("all"), + agg.Len: count("all"), + agg.Count: count("only_valid"), + agg.First: scalar_aggregate(), + agg.Last: scalar_aggregate(), + } + + +def _generate_function() -> Mapping[type[ir.Function], acero.AggregateOptions]: + from narwhals._plan.expressions import boolean + + return { + boolean.All: scalar_aggregate(ignore_nulls=True), + boolean.Any: scalar_aggregate(ignore_nulls=True), + } + + # ruff: noqa: PLW0603 # NOTE: Using globals for lazy-loading cache if not TYPE_CHECKING: def __getattr__(name: str) -> Any: if name == "AGG": - from narwhals._plan.expressions import aggregation as agg - global AGG - AGG = { - agg.NUnique: count("all"), - agg.Len: count("all"), - agg.Count: count("only_valid"), - agg.First: scalar_aggregate(), - agg.Last: scalar_aggregate(), - } + AGG = _generate_agg() return AGG if name == "FUNCTION": - from narwhals._plan.expressions import boolean - global FUNCTION - FUNCTION = { - boolean.All: scalar_aggregate(ignore_nulls=True), - boolean.Any: scalar_aggregate(ignore_nulls=True), - } + FUNCTION = _generate_function() return FUNCTION msg = f"module {__name__!r} has no attribute {name!r}" raise AttributeError(msg) From 2e04ed5b13b1a69974f341687569626016c70a23 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 18:07:51 +0000 Subject: [PATCH 82/93] refactor: Move new `DataFrame` impls up --- narwhals/_plan/dataframe.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 4b49c04ec0..8956c33457 100644 --- 
a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -88,10 +88,11 @@ def sort( return self._from_compliant(self._compliant.sort(named_irs, opts)) def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: - raise NotImplementedError + return self._from_compliant(self._compliant.drop(columns, strict=strict)) def drop_nulls(self, subset: str | Sequence[str] | None = None) -> Self: - raise NotImplementedError + subset = [subset] if isinstance(subset, str) else subset + return self._from_compliant(self._compliant.drop_nulls(subset)) class DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]): @@ -159,12 +160,5 @@ def group_by( self ) - def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: - return self._from_compliant(self._compliant.drop(columns, strict=strict)) - - def drop_nulls(self, subset: str | Sequence[str] | None = None) -> Self: - subset = [subset] if isinstance(subset, str) else subset - return self._from_compliant(self._compliant.drop_nulls(subset)) - def row(self, index: int) -> tuple[Any, ...]: return self._compliant.row(index) From e111cef2999d4d12a0f67d5618353ee1769b3066 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 18:23:38 +0000 Subject: [PATCH 83/93] docs: Explain new group bits --- narwhals/_plan/protocols.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 67b1bec221..8818243c32 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -778,7 +778,7 @@ def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: class Grouper(Protocol[ResolverT_co]): - """Revised interface focused on the state change + expression projections. + """`GroupBy` helper for collecting and forwarding `Expr`s for projection. 
- Uses `Expr` everywhere (no need to duplicate layers) - Resolver only needs schema (neither needs a frame, but can use one to get `schema`) @@ -802,6 +802,7 @@ def agg(self, *aggs: OneOrIterable[IntoExpr]) -> Self: def _resolver(self) -> type[ResolverT_co]: ... def resolve(self, context: IntoFrozenSchema, /) -> ResolverT_co: + """Project keys and aggs in `context`, expanding all `Expr` -> `NamedIR`.""" return self._resolver.from_grouper(self, context) @@ -837,6 +838,7 @@ def schema(self) -> FrozenSchema: return self._schema def evaluate(self, frame: DataFrameT) -> DataFrameT: + """Perform the `group_by` on `frame`.""" return frame.group_by_resolver(self).agg(self.aggs) @classmethod From 0f63dbf53ac8711aeb490ec8fa4e02509fb8a351 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 18:28:13 +0000 Subject: [PATCH 84/93] chore: move all group_by stuff to the end Will make for an easier diff in the PR that splits up this mess --- narwhals/_plan/protocols.py | 150 ++++++++++++++++++------------------ 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 8818243c32..cff5e790e8 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -636,37 +636,6 @@ def with_row_index(self, name: str) -> Self: ... def row(self, index: int) -> tuple[Any, ...]: ... -class CompliantGroupBy(Protocol[FrameT_co]): - @property - def compliant(self) -> FrameT_co: ... - def agg(self, irs: Seq[NamedIR]) -> FrameT_co: ... - - -class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]): - _keys: Seq[NamedIR] - _key_names: Seq[str] - - @classmethod - def from_resolver( - cls, df: DataFrameT, resolver: GroupByResolver, / - ) -> DataFrameGroupBy[DataFrameT]: ... - @classmethod - def by_names( - cls, df: DataFrameT, names: Seq[str], / - ) -> DataFrameGroupBy[DataFrameT]: ... 
- def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: ... - @property - def keys(self) -> Seq[NamedIR]: - return self._keys - - @property - def key_names(self) -> Seq[str]: - if names := self._key_names: - return names - msg = "at least one key is required in a group_by operation" - raise ComputeError(msg) - - class EagerDataFrame( CompliantDataFrame[SeriesT, NativeDataFrameT, NativeSeriesT], Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], @@ -681,43 +650,6 @@ def with_columns(self, irs: Seq[NamedIR]) -> Self: return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) -class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDataFrameT]): - _df: EagerDataFrameT - _key_names: Seq[str] - _key_names_original: Seq[str] - _column_names_original: Seq[str] - - @classmethod - def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self: - obj = cls.__new__(cls) - obj._df = df - obj._keys = () - obj._key_names = names - obj._key_names_original = () - obj._column_names_original = tuple(df.columns) - return obj - - @classmethod - def from_resolver( - cls, df: EagerDataFrameT, resolver: GroupByResolver, / - ) -> EagerDataFrameGroupBy[EagerDataFrameT]: - key_names = resolver.key_names - if not resolver.requires_projection(): - df = df.drop_nulls(key_names) if resolver._drop_null_keys else df - return cls.by_names(df, key_names) - obj = cls.__new__(cls) - unique_names = temp.column_names(chain(key_names, df.columns)) - safe_keys = tuple( - replace(key, name=name) for key, name in zip(resolver.keys, unique_names) - ) - obj._df = df.with_columns(resolver._schema_in.with_columns_irs(safe_keys)) - obj._keys = safe_keys - obj._key_names = tuple(e.name for e in safe_keys) - obj._key_names_original = key_names - obj._column_names_original = resolver._schema_in.names - return obj - - class CompliantSeries(StoresVersion, Protocol[NativeSeriesT]): _native: NativeSeriesT _name: str @@ -777,6 +709,74 @@ def to_list(self) -> list[Any]: 
... def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: ... +class CompliantGroupBy(Protocol[FrameT_co]): + @property + def compliant(self) -> FrameT_co: ... + def agg(self, irs: Seq[NamedIR]) -> FrameT_co: ... + + +class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]): + _keys: Seq[NamedIR] + _key_names: Seq[str] + + @classmethod + def from_resolver( + cls, df: DataFrameT, resolver: GroupByResolver, / + ) -> DataFrameGroupBy[DataFrameT]: ... + @classmethod + def by_names( + cls, df: DataFrameT, names: Seq[str], / + ) -> DataFrameGroupBy[DataFrameT]: ... + def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: ... + @property + def keys(self) -> Seq[NamedIR]: + return self._keys + + @property + def key_names(self) -> Seq[str]: + if names := self._key_names: + return names + msg = "at least one key is required in a group_by operation" + raise ComputeError(msg) + + +class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDataFrameT]): + _df: EagerDataFrameT + _key_names: Seq[str] + _key_names_original: Seq[str] + _column_names_original: Seq[str] + + @classmethod + def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self: + obj = cls.__new__(cls) + obj._df = df + obj._keys = () + obj._key_names = names + obj._key_names_original = () + obj._column_names_original = tuple(df.columns) + return obj + + @classmethod + def from_resolver( + cls, df: EagerDataFrameT, resolver: GroupByResolver, / + ) -> EagerDataFrameGroupBy[EagerDataFrameT]: + key_names = resolver.key_names + if not resolver.requires_projection(): + df = df.drop_nulls(key_names) if resolver._drop_null_keys else df + return cls.by_names(df, key_names) + obj = cls.__new__(cls) + unique_names = temp.column_names(chain(key_names, df.columns)) + safe_keys = tuple( + replace(key, name=name) for key, name in zip(resolver.keys, unique_names) + ) + obj._df = df.with_columns(resolver._schema_in.with_columns_irs(safe_keys)) + obj._keys = 
safe_keys + obj._key_names = tuple(e.name for e in safe_keys) + obj._key_names_original = key_names + obj._column_names_original = resolver._schema_in.names + return obj + + class Grouper(Protocol[ResolverT_co]): """`GroupBy` helper for collecting and forwarding `Expr`s for projection. @@ -873,7 +873,13 @@ def requires_projection(self, *, allow_aliasing: bool = False) -> bool: return False -class Grouped(Grouper["Resolved"]): +class Resolved(GroupByResolver): + """Compliant-level `GroupBy` resolver.""" + + _drop_null_keys: bool = False + + +class Grouped(Grouper[Resolved]): """Compliant-level `GroupBy` builder.""" _keys: Seq[ExprIR] @@ -883,9 +889,3 @@ class Grouped(Grouper["Resolved"]): @property def _resolver(self) -> type[Resolved]: return Resolved - - -class Resolved(GroupByResolver): - """Compliant-level `GroupBy` resolver.""" - - _drop_null_keys: bool = False From 6b3018e5854c4271ea0b6ec9f67f79e1c766164f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 18:36:54 +0000 Subject: [PATCH 85/93] minor nits --- narwhals/_plan/typing.py | 1 - tests/plan/group_by_test.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 8e591b8c0b..2a734488a6 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -111,5 +111,4 @@ OneOrIterable: TypeAlias = "T | t.Iterable[T]" OneOrSeq: TypeAlias = t.Union[T, Seq[T]] DataFrameT = TypeVar("DataFrameT", bound="DataFrame[t.Any, t.Any]") - Order: TypeAlias = t.Literal["ascending", "descending"] diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index 3a5ffd162a..22fce71089 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -245,7 +245,7 @@ def test_group_by_expr_iter() -> None: result = dict(sorted((k, df.sort("c").to_dict(as_series=False)) for k, df in grouped)) assert len(result) == len(expected) assert result.keys() == 
expected.keys() - # NOTE: The bug means that zipping will break, as one side has more columns + # NOTE: The bug this is trying to avoid regressing on would break zipping, as one side has more columns result_p1 = next(iter(result.values())) expected_p1 = next(iter(expected.values())) assert result_p1 == expected_p1 @@ -278,7 +278,7 @@ def test_no_agg() -> None: "https://github.com/narwhals-dev/narwhals/issues/1078" ), ) -def test_group_by_categorical() -> None: # pragma: no cover +def test_group_by_categorical() -> None: data = {"g1": ["a", "a", "b", "b"], "g2": ["x", "y", "x", "z"], "x": [1, 2, 3, 4]} df = dataframe(data) result = ( From e0a3684ba4f77a59fe45d2914c74b4cff25cf344 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 29 Sep 2025 21:44:21 +0000 Subject: [PATCH 86/93] feat: Improve error reporting Resolves (https://github.com/narwhals-dev/narwhals/pull/3143#discussion_r2388886025) --- narwhals/_plan/arrow/group_by.py | 34 +++++++++++++++--------------- tests/plan/group_by_test.py | 36 +++++++++++++++++++++++++------- 2 files changed, 46 insertions(+), 24 deletions(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 13c49b79b1..c878f344ed 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -8,9 +8,11 @@ from narwhals._plan import expressions as ir from narwhals._plan._guards import is_agg_expr, is_function_expr from narwhals._plan.arrow import acero, functions as fn, options -from narwhals._plan.common import temp +from narwhals._plan.common import dispatch_method_name, temp from narwhals._plan.expressions import aggregation as agg from narwhals._plan.protocols import EagerDataFrameGroupBy +from narwhals._utils import Implementation +from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from collections.abc import Iterator, Mapping @@ -91,7 +93,7 @@ def from_named_ir(cls, named_ir: NamedIR) -> Self: def 
from_agg_expr(cls, expr: agg.AggExpr, name: acero.OutputName) -> Self: tp = type(expr) if not (agg_name := SUPPORTED_AGG.get(tp)): - raise group_by_error(name, expr, "unsupported aggregation") + raise group_by_error(name, expr) if not isinstance(expr.expr, ir.Column): raise group_by_error(name, expr, "too complex") option = ( @@ -105,7 +107,7 @@ def from_agg_expr(cls, expr: agg.AggExpr, name: acero.OutputName) -> Self: def from_function_expr(cls, expr: ir.FunctionExpr, name: acero.OutputName) -> Self: tp = type(expr.function) if not (fn_name := SUPPORTED_FUNCTION.get(tp)): - raise group_by_error(name, expr, "unsupported function") + raise group_by_error(name, expr) args = expr.input if not (len(args) == 1 and isinstance(args[0], ir.Column)): raise group_by_error(name, expr, "too complex") @@ -118,27 +120,25 @@ def from_expr_ir(cls, expr: ir.ExprIR, name: acero.OutputName) -> Self: if is_function_expr(expr): return cls.from_function_expr(expr, name) if not isinstance(expr, (ir.Len, ir.Column)): - raise group_by_error(name, expr, "unsupported expression") + raise group_by_error(name, expr) fn_name = SUPPORTED_IR[type(expr)] return cls(expr.name if isinstance(expr, ir.Column) else (), fn_name, name=name) def group_by_error( - name: str, - expr: ir.ExprIR, - reason: Literal[ - "too complex", - "unsupported aggregation", - "unsupported function", - "unsupported expression", - ], -) -> NotImplementedError: + column_name: str, expr: ir.ExprIR, reason: Literal["too complex"] | None = None +) -> InvalidOperationError: + backend = Implementation.PYARROW if reason == "too complex": - msg = "Non-trivial complex aggregation found" + msg = "Non-trivial complex aggregation found, which" else: - msg = reason.title() - msg = f"{msg} in 'pyarrow.Table':\n\n{name}={expr!r}" - return NotImplementedError(msg) + if is_function_expr(expr): + func_name = repr(expr.function) + else: + func_name = dispatch_method_name(type(expr)) + msg = f"`{func_name}()`" + msg = f"{msg} is not supported 
in a `group_by` context for {backend!r}:\n{column_name}={expr!r}" + return InvalidOperationError(msg) def concat_str(native: pa.Table, *, separator: str = "") -> ChunkedArray: diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index 22fce71089..2b60c118db 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -293,12 +293,34 @@ def test_group_by_categorical() -> None: assert_equal_data(result, data) -# TODO @dangotbanned: Align the error to `InvalidOperation` -def test_group_by_shift_raises() -> None: - data = {"a": [1, 2, 3], "b": [1, 1, 2]} - df = dataframe(data) - with pytest.raises((InvalidOperationError, NotImplementedError)): - df.group_by("b").agg(nwp.col("a").shift(1)) +@pytest.mark.parametrize( + ("agg", "message_body", "expected_repr"), + [ + (nwp.col("a").shift(1), r"shift.+not.+group_by.+pyarrow.+", "col('a').shift("), + ( + nwp.col("a").arg_max(), + r"arg_max.+not.+group_by.+pyarrow.+", + "col('a').arg_max(", + ), + ( + nwp.col("a").max().over("b"), + r"over.+not.+group_by.+pyarrow.+", + "col('a').max().over([col('b')])", + ), + ( + nwp.col("a").drop_nulls().abs().mean(), + r"complex aggregation found.+not.+group_by.+pyarrow.+", + "col('a').drop_nulls().abs().mean()", + ), + ], +) +def test_group_by_unsupported_raises( + agg: nwp.Expr, message_body: str, expected_repr: str +) -> None: + df = dataframe({"a": [1, 2, 3], "b": [1, 1, 2]}) + pat = re.compile(rf"{message_body}{re.escape(expected_repr)}", re.DOTALL) + with pytest.raises(InvalidOperationError, match=pat): + df.group_by("b").agg(agg) def test_double_same_aggregation() -> None: @@ -648,7 +670,7 @@ def test_group_by_series_lit_22103() -> None: data = {"g": [0, 1]} series = nwp.Series.from_native(pa.chunked_array([[42, 2, 3]])) df = dataframe(data) - with pytest.raises(NotImplementedError, match=re.escape("foo=lit(Series)")): + with pytest.raises(InvalidOperationError, match=re.escape("foo=lit(Series)")): df.group_by("g").agg(foo=series) From 
bfd0fe87736b9b7f74714d5702acef6985b1579c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 30 Sep 2025 13:10:17 +0000 Subject: [PATCH 87/93] feat: Improve `temp.column_name` error message --- narwhals/_plan/common.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index bbf6f4385c..c5bd6cde84 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -131,6 +131,7 @@ class temp: # noqa: N801 """Temporary mini namespace for temporary utils.""" _MAX_ITERATIONS: ClassVar[int] = 100 + _MIN_RANDOM_CHARS: ClassVar[int] = 4 @classmethod def column_name( @@ -180,18 +181,31 @@ def column_names( def _into_columns(source: _StoresColumns | Iterable[str], /) -> set[str]: return set(source.columns if _has_columns(source) else source) - @staticmethod - def _parse_prefix_n_bytes(prefix: str, n_chars: int, /) -> tuple[str, int]: + @classmethod + def _parse_prefix_n_bytes(cls, prefix: str, n_chars: int, /) -> tuple[str, int]: prefix = prefix or "nw" - n_bytes = (n_chars - len(prefix)) // 2 - if n_bytes < 2: - msg = ( - f"Temporary column name generation requires at least 4 characters to store random bytes, \n" - f"but not enough room with: {prefix=}, {n_chars=}.\n\n" - "Hint: Maybe try\n- a shorter `prefix`?\n- a higher `n_chars`?" 
- ) - raise NarwhalsError(msg) - return prefix, n_bytes + if not (available := n_chars - len(prefix)) or available < cls._MIN_RANDOM_CHARS: + raise cls._not_enough_room_error(prefix, n_chars) + return prefix, available // 2 + + @classmethod + def _not_enough_room_error(cls, prefix: str, n_chars: int, /) -> NarwhalsError: + len_prefix = len(prefix) + available_chars = n_chars - len_prefix + if available_chars < 0: + visualize = "" + else: + okay = "✔" * available_chars + bad = "✖" * (cls._MIN_RANDOM_CHARS - available_chars) + visualize = f"\n Preview: '{prefix}{okay}{bad}'" + msg = ( + f"Temporary column name generation requires {len_prefix} characters for the prefix " + f"and at least {cls._MIN_RANDOM_CHARS} more to store random bytes:{visualize}\n\n" + f"Hint: Maybe try\n" + f"- a shorter `prefix` than {prefix!r}?\n" + f"- a higher `n_chars` than {n_chars!r}?" + ) + return NarwhalsError(msg) @classmethod def _failed_generation_error( From 1c4210bb338c34d115f20fadb4794f17e5cb4a1f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 30 Sep 2025 14:09:17 +0000 Subject: [PATCH 88/93] test: Add tests for `temp.column_name` Towards https://github.com/narwhals-dev/narwhals/pull/3143#discussion_r2372394167 --- tests/plan/temp_test.py | 99 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 tests/plan/temp_test.py diff --git a/tests/plan/temp_test.py b/tests/plan/temp_test.py new file mode 100644 index 0000000000..ec644761e7 --- /dev/null +++ b/tests/plan/temp_test.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import random +import re +import string + +# ruff: noqa: S311 +from collections import deque +from itertools import product, repeat +from typing import TYPE_CHECKING, NamedTuple + +import hypothesis.strategies as st +import pytest +from hypothesis import given + +import narwhals as nw +from narwhals._plan.common import temp +from narwhals._utils import 
qualified_type_name +from narwhals.exceptions import NarwhalsError + +pytest.importorskip("pyarrow") +pytest.importorskip("polars") + + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from narwhals._utils import _StoresColumns + + +class MockStoresColumns(NamedTuple): + columns: Sequence[str] + + +_COLUMNS = ("abc", "XYZ", "nw2929023", "column", string.hexdigits) +_EMPTY_SCHEMA = nw.Schema((name, nw.Int64()) for name in _COLUMNS) + + +sources = pytest.mark.parametrize( + "source", + [ + _COLUMNS, + MockStoresColumns(columns=_COLUMNS), + deque(_COLUMNS), + nw.from_dict({}, _EMPTY_SCHEMA, backend="pyarrow"), + dict.fromkeys(_COLUMNS), + set(_COLUMNS), + nw.from_dict({}, _EMPTY_SCHEMA, backend="polars").to_native(), + ], + ids=qualified_type_name, +) + + +@sources +def test_temp_column_name_sources(source: _StoresColumns | Iterable[str]) -> None: + name = temp.column_name(source) + assert name not in _COLUMNS + + +@given(n_chars=st.integers(6, 106)) +@pytest.mark.slow +def test_column_name_n_chars(n_chars: int) -> None: + name = temp.column_name(_COLUMNS, n_chars=n_chars) + assert name not in _COLUMNS + + +@pytest.mark.parametrize( + ("prefix", "n_chars"), + [ + ("nw", random.randint(0, 5)), + ("col", random.randint(0, 4)), + ("NW_", random.randint(0, 3)), + ("join", random.randint(0, 2)), + ("__tmp", random.randint(0, 1)), + ("longer", random.randint(-5, 0)), + ("", random.randint(0, 5)), + ], +) +def test_column_name_requires_more_characters(prefix: str, n_chars: int) -> None: + pattern = re.compile( + rf"temp.+column.+name.+requires.+try.+shorter.+{prefix}.+higher.+{n_chars}", + re.IGNORECASE | re.DOTALL, + ) + with pytest.raises(NarwhalsError, match=pattern): + temp.column_name(_COLUMNS, prefix=prefix, n_chars=n_chars) + + +def test_column_name_failed_unique() -> None: + hex_lower = string.hexdigits.strip(string.ascii_uppercase) + every_possible_name_65k = [ + f"nw{e1}{e2}{e3}{e4}" for e1, e2, e3, e4 in product(*repeat(hex_lower, 4)) + ] + 
n_many_columns = len(every_possible_name_65k) + + pattern = re.compile( + rf"unable.+generate.+name.+within.+existing.+{n_many_columns}.+columns", re.DOTALL + ) + with pytest.raises(NarwhalsError, match=pattern): + temp.column_name(every_possible_name_65k, prefix="nw", n_chars=6) From 1fe1a892c465b805cfc4aff245e4c392e783042d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:33:42 +0000 Subject: [PATCH 89/93] fix: Omit `indent` before `3.12` Resolves https://github.com/narwhals-dev/narwhals/pull/3143#discussion_r2391682435 --- narwhals/_plan/common.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index c5bd6cde84..398bfa2e87 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -132,6 +132,13 @@ class temp: # noqa: N801 _MAX_ITERATIONS: ClassVar[int] = 100 _MIN_RANDOM_CHARS: ClassVar[int] = 4 + _REPRLIB_REPR_KWDS: ClassVar[dict[str, Any]] = ( + {"indent": 4, "maxlist": 10} if sys.version_info >= (3, 12) else {"maxlist": 10} + ) + """Version-dependent arguments for `reprlib.Repr`. + + See https://github.com/python/cpython/issues/92734 + """ @classmethod def column_name( @@ -210,7 +217,7 @@ def _not_enough_room_error(cls, prefix: str, n_chars: int, /) -> NarwhalsError: @classmethod def _failed_generation_error( cls, columns: Iterable[str], n_bytes: int, / - ) -> NarwhalsError: # pragma: no cover + ) -> NarwhalsError: """Takes some work to trigger this, but it's possible 😅. 
Examples: @@ -239,7 +246,7 @@ def _failed_generation_error( import reprlib current = sorted(columns) - truncated = reprlib.Repr(indent=4, maxlist=10).repr(current) + truncated = reprlib.Repr(**cls._REPRLIB_REPR_KWDS).repr(current) msg = ( "Was unable to generate a column name with " f"`{n_bytes=}` within {cls._MAX_ITERATIONS} iterations, \n" From 0c8921bd168c0c3c5caf91c1e34ffc8a7f8ec60a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:56:24 +0000 Subject: [PATCH 90/93] fix: welp no constructor --- narwhals/_plan/common.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 398bfa2e87..3ca9d43f97 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -16,6 +16,7 @@ from narwhals.utils import Version if TYPE_CHECKING: + import reprlib from collections.abc import Iterator from typing import Any, Callable, ClassVar, TypeVar @@ -127,18 +128,24 @@ def _has_columns(obj: Any) -> TypeIs[_StoresColumns]: return _hasattr_static(obj, "columns") +def _reprlib_repr_backport() -> reprlib.Repr: + # 3.12 added `indent` https://github.com/python/cpython/issues/92734 + # but also a useful constructor https://github.com/python/cpython/issues/94343 + import reprlib + + if sys.version_info >= (3, 12): + return reprlib.Repr(indent=4, maxlist=10) + else: # pragma: no cover # noqa: RET505 + obj = reprlib.Repr() + obj.maxlist = 10 + return obj + + class temp: # noqa: N801 """Temporary mini namespace for temporary utils.""" _MAX_ITERATIONS: ClassVar[int] = 100 _MIN_RANDOM_CHARS: ClassVar[int] = 4 - _REPRLIB_REPR_KWDS: ClassVar[dict[str, Any]] = ( - {"indent": 4, "maxlist": 10} if sys.version_info >= (3, 12) else {"maxlist": 10} - ) - """Version-dependent arguments for `reprlib.Repr`. 
- - See https://github.com/python/cpython/issues/92734 - """ @classmethod def column_name( @@ -243,10 +250,8 @@ def _failed_generation_error( ..., ] """ - import reprlib - current = sorted(columns) - truncated = reprlib.Repr(**cls._REPRLIB_REPR_KWDS).repr(current) + truncated = _reprlib_repr_backport().repr(current) msg = ( "Was unable to generate a column name with " f"`{n_bytes=}` within {cls._MAX_ITERATIONS} iterations, \n" From 8b605ede62942303523ca3df04872f4cfac8b337 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 30 Sep 2025 17:14:47 +0000 Subject: [PATCH 91/93] test: Add `test_temp_column_names_failed_unique` well, move it out of the docstring is more accurate i suppose --- narwhals/_plan/common.py | 33 ++++----------------------------- tests/plan/temp_test.py | 21 ++++++++++++++++----- 2 files changed, 20 insertions(+), 34 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 3ca9d43f97..ef2955b441 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -163,7 +163,7 @@ def column_name( token = f"{prefix}{token_hex(n_bytes)}" if token not in columns: return token - raise cls._failed_generation_error(columns, n_bytes) + raise cls._failed_generation_error(columns, n_chars) @classmethod def column_names( @@ -189,7 +189,7 @@ def column_names( yield token else: n_failed += 1 - raise cls._failed_generation_error(columns, n_bytes) + raise cls._failed_generation_error(columns, n_chars) @staticmethod def _into_columns(source: _StoresColumns | Iterable[str], /) -> set[str]: @@ -223,38 +223,13 @@ def _not_enough_room_error(cls, prefix: str, n_chars: int, /) -> NarwhalsError: @classmethod def _failed_generation_error( - cls, columns: Iterable[str], n_bytes: int, / + cls, columns: Iterable[str], n_chars: int, / ) -> NarwhalsError: - """Takes some work to trigger this, but it's possible 😅. 
- - Examples: - >>> import itertools - >>> from narwhals._plan.common import temp - >>> it = temp.column_names(["a", "b", "c"], prefix="long_prefix") - >>> list(itertools.islice(it, 100_000)) # doctest:+SKIP - Traceback (most recent call last): - ... - NarwhalsError: Was unable to generate a column name with `n_bytes=2` within 100 iterations, - that was not present in existing (60246) columns: - [ - 'a', - 'b', - 'c', - 'long_prefix0000', - 'long_prefix0003', - 'long_prefix0004', - 'long_prefix0005', - 'long_prefix0006', - 'long_prefix0007', - 'long_prefix0008', - ..., - ] - """ current = sorted(columns) truncated = _reprlib_repr_backport().repr(current) msg = ( "Was unable to generate a column name with " - f"`{n_bytes=}` within {cls._MAX_ITERATIONS} iterations, \n" + f"`{n_chars=}` within {cls._MAX_ITERATIONS} iterations, \n" f"that was not present in existing ({len(current)}) columns:\n{truncated}" ) return NarwhalsError(msg) diff --git a/tests/plan/temp_test.py b/tests/plan/temp_test.py index ec644761e7..a05fb70e8f 100644 --- a/tests/plan/temp_test.py +++ b/tests/plan/temp_test.py @@ -6,7 +6,7 @@ # ruff: noqa: S311 from collections import deque -from itertools import product, repeat +from itertools import islice, product, repeat from typing import TYPE_CHECKING, NamedTuple import hypothesis.strategies as st @@ -59,7 +59,7 @@ def test_temp_column_name_sources(source: _StoresColumns | Iterable[str]) -> Non @given(n_chars=st.integers(6, 106)) @pytest.mark.slow -def test_column_name_n_chars(n_chars: int) -> None: +def test_temp_column_name_n_chars(n_chars: int) -> None: name = temp.column_name(_COLUMNS, n_chars=n_chars) assert name not in _COLUMNS @@ -76,7 +76,7 @@ def test_column_name_n_chars(n_chars: int) -> None: ("", random.randint(0, 5)), ], ) -def test_column_name_requires_more_characters(prefix: str, n_chars: int) -> None: +def test_temp_column_name_requires_more_characters(prefix: str, n_chars: int) -> None: pattern = re.compile( 
rf"temp.+column.+name.+requires.+try.+shorter.+{prefix}.+higher.+{n_chars}", re.IGNORECASE | re.DOTALL, @@ -85,7 +85,7 @@ def test_column_name_requires_more_characters(prefix: str, n_chars: int) -> None temp.column_name(_COLUMNS, prefix=prefix, n_chars=n_chars) -def test_column_name_failed_unique() -> None: +def test_temp_column_name_failed_unique() -> None: hex_lower = string.hexdigits.strip(string.ascii_uppercase) every_possible_name_65k = [ f"nw{e1}{e2}{e3}{e4}" for e1, e2, e3, e4 in product(*repeat(hex_lower, 4)) @@ -93,7 +93,18 @@ def test_column_name_failed_unique() -> None: n_many_columns = len(every_possible_name_65k) pattern = re.compile( - rf"unable.+generate.+name.+within.+existing.+{n_many_columns}.+columns", re.DOTALL + rf"unable.+generate.+name.+n_chars=6.+within.+existing.+{n_many_columns}.+columns", + re.DOTALL, ) with pytest.raises(NarwhalsError, match=pattern): temp.column_name(every_possible_name_65k, prefix="nw", n_chars=6) + + +def test_temp_column_names_failed_unique() -> None: + it = temp.column_names(["a", "b", "c"], prefix="long_prefix", n_chars=16) + pattern = re.compile( + r"unable.+generate.+name.+n_chars=16.+within.+existing.+.+columns.+\.\.\.", + re.DOTALL, + ) + with pytest.raises(NarwhalsError, match=pattern): + list(islice(it, 100_000)) From cef3e0670ef2e208b3bfb071487c78de83b25e1f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 30 Sep 2025 17:39:03 +0000 Subject: [PATCH 92/93] test: More `temp.column_names` tests --- tests/plan/temp_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/plan/temp_test.py b/tests/plan/temp_test.py index a05fb70e8f..9dd7a0e42f 100644 --- a/tests/plan/temp_test.py +++ b/tests/plan/temp_test.py @@ -57,6 +57,13 @@ def test_temp_column_name_sources(source: _StoresColumns | Iterable[str]) -> Non assert name not in _COLUMNS +@sources +def test_temp_column_names_sources(source: _StoresColumns | Iterable[str]) -> None: + it = 
temp.column_names(source) + name = next(it) + assert name not in _COLUMNS + + @given(n_chars=st.integers(6, 106)) @pytest.mark.slow def test_temp_column_name_n_chars(n_chars: int) -> None: @@ -64,6 +71,15 @@ def test_temp_column_name_n_chars(n_chars: int) -> None: assert name not in _COLUMNS +@given(n_new_names=st.integers(10_000, 100_000)) +@pytest.mark.slow +def test_temp_column_names_always_new_names(n_new_names: int) -> None: + it = temp.column_names(_COLUMNS) + new_names = set(islice(it, n_new_names)) + assert len(new_names) == n_new_names + assert new_names.isdisjoint(_COLUMNS) + + @pytest.mark.parametrize( ("prefix", "n_chars"), [ From 79e49f475510483901f37e98151b17cd1cb4ad7d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 30 Sep 2025 22:04:10 +0000 Subject: [PATCH 93/93] docs: Add meaningful examples to `temp.column_name` Still need to do `temp.column_names` as well That one is quite different to #3147, but was needed in this PR --- narwhals/_plan/common.py | 49 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index ef2955b441..defe398f95 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -156,7 +156,48 @@ def column_name( prefix: str = "nw", n_chars: int = 16, ) -> str: - """Generate a single, unique column name that is not present in `source`.""" + """Generate a single, unique column name that is not present in `source`. + + Arguments: + source: Source of columns to check for uniqueness. + prefix: Prepend the name with this string. + n_chars: Total number of characters used by the name (including `prefix`). 
+ + Examples: + >>> import narwhals as nw + >>> from narwhals._plan.common import temp + >>> columns = "abc", "xyz" + >>> temp.column_name(columns) # doctest: +SKIP + 'nwf65daf7ceb3c2f' + + Limit the number of characters that the name uses + + >>> temp.column_name(columns, n_chars=8) # doctest: +SKIP + 'nw388b5d' + + Make the name easier to trace back + + >>> temp.column_name(columns, prefix="_its_a_me_") # doctest: +SKIP + '_its_a_me_0ea2b0' + + Pass in a `DataFrame` directly, and let us get the columns for you + + >>> df = nw.from_dict({"foo": [1, 2], "bar": [6.0, 7.0]}, backend="polars") + >>> df.with_row_index(temp.column_name(df, prefix="idx_")) # doctest: +SKIP + ┌────────────────────────────────┐ + | Narwhals DataFrame | + |--------------------------------| + |shape: (2, 3) | + |┌──────────────────┬─────┬─────┐| + |│ idx_bae5e1b22963 ┆ foo ┆ bar │| + |│ --- ┆ --- ┆ --- │| + |│ u32 ┆ i64 ┆ f64 │| + |╞══════════════════╪═════╪═════╡| + |│ 0 ┆ 1 ┆ 6.0 │| + |│ 1 ┆ 2 ┆ 7.0 │| + |└──────────────────┴─────┴─────┘| + └────────────────────────────────┘ + """ columns = cls._into_columns(source) prefix, n_bytes = cls._parse_prefix_n_bytes(prefix, n_chars) for _ in range(cls._MAX_ITERATIONS): @@ -165,6 +206,7 @@ def column_name( return token raise cls._failed_generation_error(columns, n_chars) + # TODO @dangotbanned: Write examples @classmethod def column_names( cls, @@ -177,6 +219,11 @@ def column_names( """Yields unique column names that are not present in `source`. Any column name returned will be unique among those that preceded it. + + Arguments: + source: Source of columns to check for uniqueness. + prefix: Prepend the name with this string. + n_chars: Total number of characters used by the name (including `prefix`). """ columns = cls._into_columns(source) prefix, n_bytes = cls._parse_prefix_n_bytes(prefix, n_chars)