diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index fb2dd390a8..6cbf061a98 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -87,11 +87,10 @@ Excluded: TypeAlias = "frozenset[str]" """Internally use a `set`, then freeze before returning.""" -GroupByKeys: TypeAlias = "Seq[ExprIR]" -"""Represents group_by keys. +GroupByKeys: TypeAlias = "Seq[str]" +"""Represents `group_by` keys. -- Originates from `polars_plan::plans::conversion::dsl_to_ir::resolve_group_by` -- Not fully utilized in `narwhals` version yet +They need to be excluded from expansion. """ OutputNames: TypeAlias = "Seq[str]" @@ -154,24 +153,23 @@ def with_multiple_columns(self) -> ExpansionFlags: def prepare_projection( - exprs: Sequence[ExprIR], schema: IntoFrozenSchema -) -> tuple[Seq[ExprIR], FrozenSchema, OutputNames]: + exprs: Sequence[ExprIR], /, keys: GroupByKeys = (), *, schema: IntoFrozenSchema +) -> tuple[Seq[NamedIR], FrozenSchema]: """Expand IRs into named column selections. - **Primary entry-point**, will be used by `select`, `with_columns`, + **Primary entry-point**, for `select`, `with_columns`, and any other context that requires resolving expression names. Arguments: exprs: IRs that *may* contain things like `Columns`, `SelectorIR`, `Exclude`, etc. + keys: Names of `group_by` columns. schema: Scope to expand multi-column selectors in. - - Returns: - `exprs`, rewritten using `Column(name)` only. 
""" frozen_schema = freeze_schema(schema) - rewritten = rewrite_projections(tuple(exprs), keys=(), schema=frozen_schema) + rewritten = rewrite_projections(tuple(exprs), keys=keys, schema=frozen_schema) output_names = ensure_valid_exprs(rewritten, frozen_schema) - return rewritten, frozen_schema, output_names + named_irs = into_named_irs(rewritten, output_names) + return named_irs, frozen_schema def into_named_irs(exprs: Seq[ExprIR], names: OutputNames) -> Seq[NamedIR]: @@ -202,7 +200,7 @@ def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames: def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: if is_horizontal_reduction(child): - rewrites = rewrite_projections(child.input, keys=(), schema=schema) + rewrites = rewrite_projections(child.input, schema=schema) return common.replace(child, input=rewrites) return child @@ -275,7 +273,7 @@ def expand_selector(selector: SelectorIR, schema: FrozenSchema) -> Columns: def rewrite_projections( input: Seq[ExprIR], # `FunctionExpr.input` /, - keys: GroupByKeys, + keys: GroupByKeys = (), *, schema: FrozenSchema, ) -> Seq[ExprIR]: @@ -323,13 +321,10 @@ def prepare_excluded( origin: ExprIR, keys: GroupByKeys, flags: ExpansionFlags, / ) -> Excluded: """Huge simplification of https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L484-L555.""" - exclude: set[str] = set() - if flags.has_exclude: - exclude.update(*(e.names for e in origin.iter_left() if isinstance(e, Exclude))) - for group_by_key in keys: - if name := group_by_key.meta.output_name(raise_if_undetermined=False): - exclude.add(name) - return frozenset(exclude) + gb_keys = frozenset(keys) + if not flags.has_exclude: + return gb_keys + return gb_keys.union(*(e.names for e in origin.iter_left() if isinstance(e, Exclude))) def _all_columns_match(origin: ExprIR, /, columns: Columns) -> bool: diff --git 
a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index 0646520102..d163134c80 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -290,3 +290,17 @@ def is_elementwise_top_level(self) -> bool: if is_literal(ir): return ir.is_scalar return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.TernaryExpr, expr.Cast)) + + def is_column(self, *, allow_aliasing: bool = False) -> bool: + """Return True if wrapping a single `Column` node. + + Note: + Multi-output (including selectors) expressions have been expanded at this stage. + + Arguments: + allow_aliasing: If False (default), any aliasing is not considered to be column selection. + """ + from narwhals._plan.expressions import Column + + ir = self.expr + return isinstance(ir, Column) and ((self.name == ir.name) or allow_aliasing) diff --git a/narwhals/_plan/_rewrites.py b/narwhals/_plan/_rewrites.py index ae23fa4b9b..fd26364e66 100644 --- a/narwhals/_plan/_rewrites.py +++ b/narwhals/_plan/_rewrites.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from narwhals._plan._expansion import into_named_irs, prepare_projection +from narwhals._plan._expansion import prepare_projection from narwhals._plan._guards import ( is_aggregation, is_binary_expr, @@ -31,8 +31,7 @@ def rewrite_all( - Currently we do a full traversal of each tree per-rewrite function - There's no caching *after* `prepare_projection` yet """ - out_irs, _, names = prepare_projection(parse_into_seq_of_expr_ir(*exprs), schema) - named_irs = into_named_irs(out_irs, names) + named_irs, _ = prepare_projection(parse_into_seq_of_expr_ir(*exprs), schema=schema) return tuple(map_ir(ir, *rewrites) for ir in named_irs) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py new file mode 100644 index 0000000000..768248e312 --- /dev/null +++ b/narwhals/_plan/arrow/acero.py @@ -0,0 +1,253 @@ +"""Sugar for working with [Acero]. 
+ +[`pyarrow.acero`] has some building blocks for constructing queries, but is +quite verbose when used directly. + +This module aligns some apis to look more like `polars`. + +Notes: + - Functions suffixed with `_table` all handle composition and collection internally + +[Acero]: https://arrow.apache.org/docs/cpp/acero/overview.html +[`pyarrow.acero`]: https://arrow.apache.org/docs/python/api/acero.html +""" + +from __future__ import annotations + +import functools +import operator +from functools import reduce +from itertools import chain +from typing import TYPE_CHECKING, Any, Final, Union, cast + +import pyarrow as pa # ignore-banned-import +import pyarrow.acero as pac +import pyarrow.compute as pc # ignore-banned-import +from pyarrow.acero import Declaration as Decl + +from narwhals._plan.typing import OneOrSeq +from narwhals.typing import SingleColSelector + +if TYPE_CHECKING: + from collections.abc import Callable, Collection, Iterable, Iterator + + from typing_extensions import TypeAlias + + from narwhals._arrow.typing import ( # type: ignore[attr-defined] + AggregateOptions as _AggregateOptions, + Aggregation as _Aggregation, + ) + from narwhals._plan.arrow.group_by import AggSpec + from narwhals._plan.arrow.typing import NullPlacement + from narwhals._plan.typing import OneOrIterable, Order, Seq + from narwhals.typing import NonNestedLiteral + +Incomplete: TypeAlias = Any +Expr: TypeAlias = pc.Expression +IntoExpr: TypeAlias = "Expr | NonNestedLiteral" +Field: TypeAlias = Union[Expr, SingleColSelector] +"""Anything that passes as a single item in [`_compute._ensure_field_ref`]. 
+ +[`_compute._ensure_field_ref`]: https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_compute.pyx#L1507-L1531 +""" + +Target: TypeAlias = OneOrSeq[Field] +Aggregation: TypeAlias = "_Aggregation" +AggregateOptions: TypeAlias = "_AggregateOptions" +Opts: TypeAlias = "AggregateOptions | None" +OutputName: TypeAlias = str + +_THREAD_UNSAFE: Final = frozenset[Aggregation]( + ("hash_first", "hash_last", "first", "last") +) +col = pc.field +lit = cast("Callable[[NonNestedLiteral], Expr]", pc.scalar) +"""Alias for `pyarrow.compute.scalar`.""" + + +# NOTE: ATOW there are 304 valid function names, 46 can be used for some kind of agg +# Due to expr expansion, it is very likely that we have repeat runs +@functools.lru_cache(maxsize=128) +def can_thread(function_name: str, /) -> bool: + return function_name not in _THREAD_UNSAFE + + +def _parse_into_expr(into: IntoExpr, /, *, str_as_lit: bool = False) -> Expr: + if isinstance(into, pc.Expression): + return into + if isinstance(into, str) and not str_as_lit: + return col(into) + return lit(into) + + +def _parse_into_iter_expr(inputs: Iterable[IntoExpr], /) -> Iterator[Expr]: + for into_expr in inputs: + yield _parse_into_expr(into_expr) + + +def _parse_into_seq_of_expr(inputs: Iterable[IntoExpr], /) -> Seq[Expr]: + return tuple(_parse_into_iter_expr(inputs)) + + +def _parse_all_horizontal(predicates: Seq[Expr], constraints: dict[str, Any], /) -> Expr: + if not constraints and len(predicates) == 1: + return predicates[0] + it = ( + col(name) == _parse_into_expr(v, str_as_lit=True) + for name, v in constraints.items() + ) + return reduce(operator.and_, chain(predicates, it)) + + +def table_source(native: pa.Table, /) -> Decl: + """Start building a logical plan, using `native` as the source table. + + All calls to `collect` must use this as the first `Declaration`. 
+ """ + return Decl("table_source", options=pac.TableSourceNodeOptions(native)) + + +def _aggregate(aggs: Iterable[AggSpec], /, keys: Iterable[Field] | None = None) -> Decl: + # NOTE: See https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_acero.pyx#L167-L192 + aggs_: Incomplete = aggs + keys_: Incomplete = keys + return Decl("aggregate", pac.AggregateNodeOptions(aggs_, keys=keys_)) + + +def aggregate(aggs: Iterable[AggSpec], /) -> Decl: + """May only use [Scalar aggregate] functions. + + [Scalar aggregate]: https://arrow.apache.org/docs/cpp/compute.html#aggregations + """ + return _aggregate(aggs) + + +def group_by(keys: Iterable[Field], aggs: Iterable[AggSpec], /) -> Decl: + """May only use [Hash aggregate] functions, requires grouping. + + [Hash aggregate]: https://arrow.apache.org/docs/cpp/compute.html#grouped-aggregations-group-by + """ + return _aggregate(aggs, keys=keys) + + +def filter(*predicates: Expr, **constraints: IntoExpr) -> Decl: + expr = _parse_all_horizontal(predicates, constraints) + return Decl("filter", options=pac.FilterNodeOptions(expr)) + + +def select_names(column_names: OneOrIterable[str], *more_names: str) -> Decl: + """`select` where all args are column names.""" + if not more_names: + if isinstance(column_names, str): + return _project((col(column_names),), (column_names,)) + more_names = tuple(column_names) + elif isinstance(column_names, str): + more_names = column_names, *more_names + else: + msg = f"Passing both iterable and positional inputs is not supported.\n{column_names=}\n{more_names=}" + raise NotImplementedError(msg) + return _project([col(name) for name in more_names], more_names) + + +def _project(exprs: Collection[Expr], names: Collection[str]) -> Decl: + # NOTE: Both just need to be `Sized` and `Iterable` + exprs_: Incomplete = exprs + names_: Incomplete = names + return Decl("project", options=pac.ProjectNodeOptions(exprs_, names_)) + + +def project(**named_exprs: 
IntoExpr) -> Decl: + """Similar to `select`, but more rigid. + + Arguments: + **named_exprs: Inputs composed of any combination of + + - Column names or `pc.field` + - Python literals or `pc.scalar` (for `str` literals) + - [Scalar functions] applied to the above + + Notes: + - [`Expression`]s have no concept of aliasing, therefore, all inputs must be `**named_exprs`. + - Always returns a table with the same length, scalar literals are broadcast unconditionally. + + [`Expression`]: https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html + [Scalar functions]: https://arrow.apache.org/docs/cpp/compute.html#element-wise-scalar-functions + """ + exprs = _parse_into_seq_of_expr(named_exprs.values()) + return _project(names=named_exprs.keys(), exprs=exprs) + + +def _order_by( + sort_keys: Iterable[tuple[str, Order]] = (), + *, + null_placement: NullPlacement = "at_end", +) -> Decl: + # NOTE: There's no runtime type checking of `sort_keys` wrt shape + # Just need to be `Iterable`and unpack like a 2-tuple + # https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_compute.pyx#L77-L88 + keys: Incomplete = sort_keys + return Decl("order_by", pac.OrderByNodeOptions(keys, null_placement=null_placement)) + + +# TODO @dangotbanned: Utilize `SortMultipleOptions.to_arrow_acero` +def sort_by(*args: Any, **kwds: Any) -> Decl: + msg = "Should convert from polars args -> use `_order_by" + raise NotImplementedError(msg) + + +def collect(*declarations: Decl, use_threads: bool = True) -> pa.Table: + """Compose and evaluate a logical plan. + + Arguments: + *declarations: One or more `Declaration` nodes to execute as a pipeline. + **The first node must be a `table_source`**. + use_threads: Pass `False` if `declarations` contains any order-dependent aggregation(s). 
+ """ + # NOTE: stubs + docs say `list`, but impl allows any iterable + decls: Incomplete = declarations + return Decl.from_sequence(decls).to_table(use_threads=use_threads) + + +def group_by_table( + native: pa.Table, keys: Iterable[Field], aggs: Iterable[AggSpec] +) -> pa.Table: + """Adapted from [`pa.TableGroupBy.aggregate`] and [`pa.acero._group_by`]. + + - Backport of [apache/arrow#36768]. + - `first` and `last` were [broken in `pyarrow==13`]. + - Also allows us to specify our own aliases for aggregate output columns. + - Fixes [narwhals-dev/narwhals#1612] + + [`pa.TableGroupBy.aggregate`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/table.pxi#L6600-L6626 + [`pa.acero._group_by`]: https://github.com/apache/arrow/blob/0e7e70cfdef4efa287495272649c071a700c34fa/python/pyarrow/acero.py#L412-L418 + [apache/arrow#36768]: https://github.com/apache/arrow/pull/36768 + [broken in `pyarrow==13`]: https://github.com/apache/arrow/issues/36709 + [narwhals-dev/narwhals#1612]: https://github.com/narwhals-dev/narwhals/issues/1612 + """ + aggs = tuple(aggs) + use_threads = all(spec.use_threads for spec in aggs) + return collect(table_source(native), group_by(keys, aggs), use_threads=use_threads) + + +def filter_table(native: pa.Table, *predicates: Expr, **constraints: Any) -> pa.Table: + """Selects rows where all expressions evaluate to True. + + Arguments: + native: source table + predicates: [`Expression`]s which must all have a return type of boolean. + constraints: Column filters; use `name = value` to filter columns by the supplied value. 
+ + Notes: + - Uses logic similar to [`polars`] for an AND-reduction + - Elements where the filter does not evaluate to True are discarded, **including nulls** + + [`Expression`]: https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html + [`polars`]: https://github.com/pola-rs/polars/blob/d0914d416ce4e1dfcb5f946875ffd1181e31c493/py-polars/polars/_utils/parse/expr.py#L199-L242 + """ + return collect(table_source(native), filter(*predicates, **constraints)) + + +def select_names_table( + native: pa.Table, column_names: OneOrIterable[str], *more_names: str +) -> pa.Table: + return collect(table_source(native), select_names(column_names, *more_names)) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 27a02bc2ed..b588b59180 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -1,15 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, overload +import operator +from functools import reduce +from itertools import chain +from typing import TYPE_CHECKING, Any, Literal, cast, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow import functions as fn +from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy from narwhals._plan.arrow.series import ArrowSeries as Series +from narwhals._plan.expressions import NamedIR from narwhals._plan.protocols import EagerDataFrame, namespace -from narwhals._utils import Version +from narwhals._plan.typing import Seq +from narwhals._utils import Version, parse_columns_to_drop from narwhals.schema import Schema if TYPE_CHECKING: @@ -29,11 +35,18 @@ class ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]): + _native: pa.Table + _version: Version + def __narwhals_namespace__(self) -> ArrowNamespace: from narwhals._plan.arrow.namespace 
import ArrowNamespace return ArrowNamespace(self._version) + @property + def _group_by(self) -> type[GroupBy]: + return GroupBy + @property def columns(self) -> list[str]: return self.native.column_names @@ -95,10 +108,26 @@ def get_column(self, name: str) -> Series: chunked = self.native.column(name) return Series.from_native(chunked, name, version=self.version) - def drop(self, columns: Sequence[str]) -> Self: - to_drop = list(columns) + def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: + to_drop = parse_columns_to_drop(self, columns, strict=strict) return self._with_native(self.native.drop(to_drop)) + def drop_nulls(self, subset: Sequence[str] | None) -> Self: + if subset is None: + native = self.native.drop_null() + else: + to_drop = reduce(operator.or_, (pc.field(name).is_null() for name in subset)) + native = self.native.filter(~to_drop) + return self._with_native(native) + + def rename(self, mapping: Mapping[str, str]) -> Self: + names: dict[str, str] | list[str] + if fn.BACKEND_VERSION >= (17,): + names = cast("dict[str, str]", mapping) + else: # pragma: no cover + names = [mapping.get(c, c) for c in self.columns] + return self._with_native(self.native.rename_columns(names)) + # NOTE: Use instead of `with_columns` for trivial cases def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self: native = self.native @@ -113,3 +142,10 @@ def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self: else: native = native.append_column(name, chunked) return self._with_native(native) + + def select_names(self, *column_names: str) -> Self: + return self._with_native(self.native.select(list(column_names))) + + def row(self, index: int) -> tuple[Any, ...]: + row = self.native.slice(index, 1) + return tuple(chain.from_iterable(row.to_pydict().values())) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 57ec5196d6..b547ed57fa 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ 
-35,6 +35,7 @@ Count, First, Last, + Len, Max, Mean, Median, @@ -54,7 +55,7 @@ Not, ) from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr - from narwhals._plan.expressions.functions import FillNull, Pow + from narwhals._plan.expressions.functions import Abs, FillNull, Pow from narwhals.typing import Into1DArray, IntoDType, PythonLiteral Expr: TypeAlias = "ArrowExpr" @@ -111,6 +112,9 @@ def func(node: FunctionExpr[Any], frame: Frame, name: str) -> StoresNativeT_co: return func + def abs(self, node: FunctionExpr[Abs], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(pc.abs)(node, frame, name) + def not_(self, node: FunctionExpr[Not], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(pc.invert)(node, frame, name) @@ -296,6 +300,10 @@ def count(self, node: Count, frame: Frame, name: str) -> Scalar: result = fn.count(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) + def len(self, node: Len, frame: Frame, name: str) -> Scalar: + result = fn.count(self._dispatch_expr(node.expr, frame, name).native, mode="all") + return self._with_native(result, name) + def max(self, node: Max, frame: Frame, name: str) -> Scalar: result: NativeScalar = fn.max_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) @@ -460,6 +468,9 @@ def count(self, node: Count, frame: Frame, name: str) -> Scalar: native = node.expr.dispatch(self, frame, name).native return self._with_native(pa.scalar(1 if native.is_valid else 0), name) + def len(self, node: Len, frame: Frame, name: str) -> Scalar: + return self._with_native(pa.scalar(1), name) + filter = not_implemented() over = not_implemented() over_ordered = not_implemented() diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 7a16404d3d..1fd1942b2c 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -13,11 +13,12 @@ chunked_array as 
_chunked_array, floordiv_compat as floordiv, ) +from narwhals._plan.arrow import options from narwhals._plan.expressions import operators as ops from narwhals._utils import Implementation if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Mapping, Sequence + from collections.abc import Iterable, Mapping from typing_extensions import TypeIs @@ -38,17 +39,20 @@ ChunkedOrScalar, ChunkedOrScalarAny, DataType, + DataTypeRemap, DataTypeT, IntegerScalar, IntegerType, + LargeStringType, NativeScalar, Scalar, ScalarAny, ScalarT, StringScalar, + StringType, UnaryFunction, ) - from narwhals.typing import ClosedInterval + from narwhals.typing import ClosedInterval, IntoArrowSchema BACKEND_VERSION = Implementation.PYARROW._backend_version() @@ -133,6 +137,41 @@ def cast( return pc.cast(native, target_type, safe=safe) +def cast_schema( + native: pa.Schema, target_types: DataType | Mapping[str, DataType] | DataTypeRemap +) -> pa.Schema: + if isinstance(target_types, pa.DataType): + return pa.schema((name, target_types) for name in native.names) + if _is_into_pyarrow_schema(target_types): + new_schema = native + for name, dtype in target_types.items(): + index = native.get_field_index(name) + new_schema.set(index, native.field(index).with_type(dtype)) + return new_schema + return pa.schema((fld.name, target_types.get(fld.type, fld.type)) for fld in native) + + +def cast_table( + native: pa.Table, target: DataType | IntoArrowSchema | DataTypeRemap +) -> pa.Table: + s = target if isinstance(target, pa.Schema) else cast_schema(native.schema, target) + return native.cast(s) + + +def has_large_string(data_types: Iterable[DataType], /) -> bool: + return any(pa.types.is_large_string(tp) for tp in data_types) + + +def string_type(data_types: Iterable[DataType] = (), /) -> StringType | LargeStringType: + """Return a native string type, compatible with `data_types`. + + Until [apache/arrow#45717] is resolved, we need to upcast `string` to `large_string` when joining. 
+ + [apache/arrow#45717]: https://github.com/apache/arrow/issues/45717 + """ + return pa.large_string() if has_large_string(data_types) else pa.string() + + def any_(native: Any) -> pa.BooleanScalar: return pc.any(native, min_count=0) @@ -180,21 +219,11 @@ def binary( def concat_str( *arrays: ChunkedArrayAny, separator: str = "", ignore_nulls: bool = False ) -> ChunkedArray[StringScalar]: - fn: Incomplete = pc.binary_join_element_wise - it, sep = _cast_to_comparable_string_types(arrays, separator) - return fn(*it, sep, null_handling="skip" if ignore_nulls else "emit_null") # type: ignore[no-any-return] - - -def _cast_to_comparable_string_types( - arrays: Sequence[ChunkedArrayAny], /, separator: str -) -> tuple[Iterator[ChunkedArray[StringScalar]], StringScalar]: - # Ensure `chunked_arrays` are either all `string` or all `large_string`. - dtype = ( - pa.string() - if not any(pa.types.is_large_string(obj.type) for obj in arrays) - else pa.large_string() - ) - return (obj.cast(dtype) for obj in arrays), pa.scalar(separator, dtype) + dtype = string_type(obj.type for obj in arrays) + it = (obj.cast(dtype) for obj in arrays) + concat: Incomplete = pc.binary_join_element_wise + join = options.join(ignore_nulls=ignore_nulls) + return concat(*it, lit(separator, dtype), options=join) # type: ignore[no-any-return] def int_range( @@ -260,3 +289,11 @@ def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: from narwhals._plan.arrow.series import ArrowSeries return isinstance(obj, ArrowSeries) + + +def _is_into_pyarrow_schema(obj: Mapping[Any, Any]) -> TypeIs[Mapping[str, DataType]]: + return ( + (first := next(iter(obj.items())), None) + and isinstance(first[0], str) + and isinstance(first[1], pa.DataType) + ) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py new file mode 100644 index 0000000000..c878f344ed --- /dev/null +++ b/narwhals/_plan/arrow/group_by.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from typing import 
TYPE_CHECKING, Any, Literal + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan import expressions as ir +from narwhals._plan._guards import is_agg_expr, is_function_expr +from narwhals._plan.arrow import acero, functions as fn, options +from narwhals._plan.common import dispatch_method_name, temp +from narwhals._plan.expressions import aggregation as agg +from narwhals._plan.protocols import EagerDataFrameGroupBy +from narwhals._utils import Implementation +from narwhals.exceptions import InvalidOperationError + +if TYPE_CHECKING: + from collections.abc import Iterator, Mapping + + from typing_extensions import Self, TypeAlias + + from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame + from narwhals._plan.arrow.typing import ChunkedArray + from narwhals._plan.expressions import NamedIR + from narwhals._plan.typing import Seq + +Incomplete: TypeAlias = Any + +# NOTE: Unless stated otherwise, all aggregations have 2 variants: +# - `` (pc.Function.kind == "scalar_aggregate") +# - `hash_` (pc.Function.kind == "hash_aggregate") +SUPPORTED_AGG: Mapping[type[agg.AggExpr], acero.Aggregation] = { + agg.Sum: "hash_sum", + agg.Mean: "hash_mean", + agg.Median: "hash_approximate_median", + agg.Max: "hash_max", + agg.Min: "hash_min", + agg.Std: "hash_stddev", + agg.Var: "hash_variance", + agg.Count: "hash_count", + agg.Len: "hash_count", + agg.NUnique: "hash_count_distinct", + agg.First: "hash_first", + agg.Last: "hash_last", +} +SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = { + ir.Len: "hash_count_all", + ir.Column: "hash_list", # `hash_aggregate` only +} +SUPPORTED_FUNCTION: Mapping[type[ir.Function], acero.Aggregation] = { + ir.boolean.All: "hash_all", + ir.boolean.Any: "hash_any", + ir.functions.Unique: "hash_distinct", # `hash_aggregate` only +} + +REQUIRES_PYARROW_20: tuple[Literal["kurtosis"], Literal["skew"]] = ("kurtosis", "skew") +"""They don't show in [our version of the 
stubs], but are possible in [`pyarrow>=20`]. + +[our version of the stubs]: https://github.com/narwhals-dev/narwhals/issues/2124#issuecomment-3191374210 +[`pyarrow>=20`]: https://arrow.apache.org/docs/20.0/python/compute.html#grouped-aggregations +""" + + +class AggSpec: + __slots__ = ("agg", "name", "option", "target") + + def __init__( + self, + target: acero.Target, + agg: acero.Aggregation, + option: acero.Opts = None, + name: acero.OutputName = "", + ) -> None: + self.target = target + self.agg = agg + self.option = option + self.name = name or str(target) + + @property + def use_threads(self) -> bool: + """See https://github.com/apache/arrow/issues/36709.""" + return acero.can_thread(self.agg) + + def __iter__(self) -> Iterator[acero.Target | acero.Aggregation | acero.Opts]: + """Let's us duck-type as a 4-tuple.""" + yield from (self.target, self.agg, self.option, self.name) + + @classmethod + def from_named_ir(cls, named_ir: NamedIR) -> Self: + return cls.from_expr_ir(named_ir.expr, named_ir.name) + + @classmethod + def from_agg_expr(cls, expr: agg.AggExpr, name: acero.OutputName) -> Self: + tp = type(expr) + if not (agg_name := SUPPORTED_AGG.get(tp)): + raise group_by_error(name, expr) + if not isinstance(expr.expr, ir.Column): + raise group_by_error(name, expr, "too complex") + option = ( + options.variance(expr.ddof) + if isinstance(expr, (agg.Std, agg.Var)) + else options.AGG.get(tp) + ) + return cls(expr.expr.name, agg_name, option, name) + + @classmethod + def from_function_expr(cls, expr: ir.FunctionExpr, name: acero.OutputName) -> Self: + tp = type(expr.function) + if not (fn_name := SUPPORTED_FUNCTION.get(tp)): + raise group_by_error(name, expr) + args = expr.input + if not (len(args) == 1 and isinstance(args[0], ir.Column)): + raise group_by_error(name, expr, "too complex") + return cls(args[0].name, fn_name, options.FUNCTION.get(tp), name) + + @classmethod + def from_expr_ir(cls, expr: ir.ExprIR, name: acero.OutputName) -> Self: + 
if is_agg_expr(expr): + return cls.from_agg_expr(expr, name) + if is_function_expr(expr): + return cls.from_function_expr(expr, name) + if not isinstance(expr, (ir.Len, ir.Column)): + raise group_by_error(name, expr) + fn_name = SUPPORTED_IR[type(expr)] + return cls(expr.name if isinstance(expr, ir.Column) else (), fn_name, name=name) + + +def group_by_error( + column_name: str, expr: ir.ExprIR, reason: Literal["too complex"] | None = None +) -> InvalidOperationError: + backend = Implementation.PYARROW + if reason == "too complex": + msg = "Non-trivial complex aggregation found, which" + else: + if is_function_expr(expr): + func_name = repr(expr.function) + else: + func_name = dispatch_method_name(type(expr)) + msg = f"`{func_name}()`" + msg = f"{msg} is not supported in a `group_by` context for {backend!r}:\n{column_name}={expr!r}" + return InvalidOperationError(msg) + + +def concat_str(native: pa.Table, *, separator: str = "") -> ChunkedArray: + dtype = fn.string_type(native.schema.types) + it = fn.cast_table(native, dtype).itercolumns() + concat: Incomplete = pc.binary_join_element_wise + join = options.join_replace_nulls() + return concat(*it, fn.lit(separator, dtype), options=join) # type: ignore[no-any-return] + + +class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]): + _df: Frame + _keys: Seq[NamedIR] + _key_names: Seq[str] + _key_names_original: Seq[str] + + @property + def compliant(self) -> Frame: + return self._df + + def __iter__(self) -> Iterator[tuple[Any, Frame]]: + temp_name = temp.column_name(self.compliant) + native = self.compliant.native + composite_values = concat_str(acero.select_names_table(native, self.key_names)) + re_keyed = native.add_column(0, temp_name, composite_values) + from_native = self.compliant._with_native + for v in composite_values.unique(): + t = from_native(acero.filter_table(re_keyed, pc.field(temp_name) == v)) + yield ( + t.select_names(*self.key_names).row(0), + t.select_names(*self._column_names_original), + ) + + def 
agg(self, irs: Seq[NamedIR]) -> Frame: + compliant = self.compliant + native = compliant.native + key_names = self.key_names + specs = (AggSpec.from_named_ir(e) for e in irs) + result = compliant._with_native(acero.group_by_table(native, key_names, specs)) + if original := self._key_names_original: + return result.rename(dict(zip(key_names, original))) + return result diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py new file mode 100644 index 0000000000..8998b288a2 --- /dev/null +++ b/narwhals/_plan/arrow/options.py @@ -0,0 +1,105 @@ +"""Cached `pyarrow.compute` options classes, using `polars` defaults. + +Important: + `AGG` and `FUNCTION` mappings are constructed on first `__getattr__` access. +""" + +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Any, Literal + +import pyarrow.compute as pc # ignore-banned-import + +if TYPE_CHECKING: + from collections.abc import Mapping + + from narwhals._plan import expressions as ir + from narwhals._plan.arrow import acero + from narwhals._plan.expressions import aggregation as agg + + +__all__ = [ + "AGG", + "FUNCTION", + "count", + "join", + "join_replace_nulls", + "scalar_aggregate", + "variance", +] + + +AGG: Mapping[type[agg.AggExpr], acero.AggregateOptions] +FUNCTION: Mapping[type[ir.Function], acero.AggregateOptions] + + +@functools.cache +def count( + mode: Literal["only_valid", "only_null", "all"] = "only_valid", +) -> pc.CountOptions: + return pc.CountOptions(mode) + + +# pyarrow defaults to ignore_nulls +# polars doesn't mention +@functools.cache +def variance( + ddof: int = 1, *, ignore_nulls: bool = True, min_count: int = 0 +) -> pc.VarianceOptions: + return pc.VarianceOptions(ddof=ddof, skip_nulls=ignore_nulls, min_count=min_count) + + +@functools.cache +def scalar_aggregate( + *, ignore_nulls: bool = False, min_count: int = 0 +) -> pc.ScalarAggregateOptions: + return pc.ScalarAggregateOptions(skip_nulls=ignore_nulls, min_count=min_count) + 
+ +@functools.cache +def join(*, ignore_nulls: bool = False) -> pc.JoinOptions: + return pc.JoinOptions(null_handling="skip" if ignore_nulls else "emit_null") + + +@functools.cache +def join_replace_nulls(*, replacement: str = "__nw_null_value__") -> pc.JoinOptions: + return pc.JoinOptions(null_handling="replace", null_replacement=replacement) + + +def _generate_agg() -> Mapping[type[agg.AggExpr], acero.AggregateOptions]: + from narwhals._plan.expressions import aggregation as agg + + return { + agg.NUnique: count("all"), + agg.Len: count("all"), + agg.Count: count("only_valid"), + agg.First: scalar_aggregate(), + agg.Last: scalar_aggregate(), + } + + +def _generate_function() -> Mapping[type[ir.Function], acero.AggregateOptions]: + from narwhals._plan.expressions import boolean + + return { + boolean.All: scalar_aggregate(ignore_nulls=True), + boolean.Any: scalar_aggregate(ignore_nulls=True), + } + + +# ruff: noqa: PLW0603 +# NOTE: Using globals for lazy-loading cache +if not TYPE_CHECKING: + + def __getattr__(name: str) -> Any: + if name == "AGG": + global AGG + AGG = _generate_agg() + return AGG + if name == "FUNCTION": + global FUNCTION + FUNCTION = _generate_function() + return FUNCTION + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index e633e6560e..e11e9d45c1 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Protocol, overload +from collections.abc import Callable, Mapping +from typing import TYPE_CHECKING, Any, Literal, Protocol, overload from narwhals._typing_compat import TypeVar from narwhals._utils import _StoresNative as StoresNative @@ -14,8 +14,8 @@ Int16Type, Int32Type, Int64Type, - LargeStringType, - StringType, + LargeStringType as LargeStringType, # noqa: PLC0414 + StringType 
as StringType, # noqa: PLC0414 Uint8Type, Uint16Type, Uint32Type, @@ -117,3 +117,5 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot NativeScalar: TypeAlias = ScalarAny BinOp: TypeAlias = Callable[..., ChunkedOrScalarAny] StoresNativeT_co = TypeVar("StoresNativeT_co", bound=StoresNative[Any], covariant=True) +DataTypeRemap: TypeAlias = Mapping[DataType, DataType] +NullPlacement: TypeAlias = Literal["at_start", "at_end"] diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 0b4267f214..defe398f95 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -6,15 +6,21 @@ from collections.abc import Iterable from decimal import Decimal from operator import attrgetter +from secrets import token_hex from typing import TYPE_CHECKING, cast, overload from narwhals._plan._guards import is_iterable_reject +from narwhals._utils import _hasattr_static from narwhals.dtypes import DType +from narwhals.exceptions import NarwhalsError from narwhals.utils import Version if TYPE_CHECKING: + import reprlib from collections.abc import Iterator - from typing import Any, Callable, TypeVar + from typing import Any, Callable, ClassVar, TypeVar + + from typing_extensions import TypeIs from narwhals._plan.typing import ( DTypeT, @@ -23,6 +29,7 @@ NonNestedDTypeT, OneOrIterable, ) + from narwhals._utils import _StoresColumns from narwhals.typing import NonNestedDType, NonNestedLiteral T = TypeVar("T") @@ -115,3 +122,161 @@ def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: yield from flatten_hash_safe(element) else: yield element # type: ignore[misc] + + +def _has_columns(obj: Any) -> TypeIs[_StoresColumns]: + return _hasattr_static(obj, "columns") + + +def _reprlib_repr_backport() -> reprlib.Repr: + # 3.12 added `indent` https://github.com/python/cpython/issues/92734 + # but also a useful constructor https://github.com/python/cpython/issues/94343 + import reprlib + + if 
sys.version_info >= (3, 12): + return reprlib.Repr(indent=4, maxlist=10) + else: # pragma: no cover # noqa: RET505 + obj = reprlib.Repr() + obj.maxlist = 10 + return obj + + +class temp: # noqa: N801 + """Temporary mini namespace for temporary utils.""" + + _MAX_ITERATIONS: ClassVar[int] = 100 + _MIN_RANDOM_CHARS: ClassVar[int] = 4 + + @classmethod + def column_name( + cls, + source: _StoresColumns | Iterable[str], + /, + *, + prefix: str = "nw", + n_chars: int = 16, + ) -> str: + """Generate a single, unique column name that is not present in `source`. + + Arguments: + source: Source of columns to check for uniqueness. + prefix: Prepend the name with this string. + n_chars: Total number of characters used by the name (including `prefix`). + + Examples: + >>> import narwhals as nw + >>> from narwhals._plan.common import temp + >>> columns = "abc", "xyz" + >>> temp.column_name(columns) # doctest: +SKIP + 'nwf65daf7ceb3c2f' + + Limit the number of characters that the name uses + + >>> temp.column_name(columns, n_chars=8) # doctest: +SKIP + 'nw388b5d' + + Make the name easier to trace back + + >>> temp.column_name(columns, prefix="_its_a_me_") # doctest: +SKIP + '_its_a_me_0ea2b0' + + Pass in a `DataFrame` directly, and let us get the columns for you + + >>> df = nw.from_dict({"foo": [1, 2], "bar": [6.0, 7.0]}, backend="polars") + >>> df.with_row_index(temp.column_name(df, prefix="idx_")) # doctest: +SKIP + ┌────────────────────────────────┐ + | Narwhals DataFrame | + |--------------------------------| + |shape: (2, 3) | + |┌──────────────────┬─────┬─────┐| + |│ idx_bae5e1b22963 ┆ foo ┆ bar │| + |│ --- ┆ --- ┆ --- │| + |│ u32 ┆ i64 ┆ f64 │| + |╞══════════════════╪═════╪═════╡| + |│ 0 ┆ 1 ┆ 6.0 │| + |│ 1 ┆ 2 ┆ 7.0 │| + |└──────────────────┴─────┴─────┘| + └────────────────────────────────┘ + """ + columns = cls._into_columns(source) + prefix, n_bytes = cls._parse_prefix_n_bytes(prefix, n_chars) + for _ in range(cls._MAX_ITERATIONS): + token = 
f"{prefix}{token_hex(n_bytes)}" + if token not in columns: + return token + raise cls._failed_generation_error(columns, n_chars) + + # TODO @dangotbanned: Write examples + @classmethod + def column_names( + cls, + source: _StoresColumns | Iterable[str], + /, + *, + prefix: str = "nw", + n_chars: int = 16, + ) -> Iterator[str]: + """Yields unique column names that are not present in `source`. + + Any column name returned will be unique among those that preceded it. + + Arguments: + source: Source of columns to check for uniqueness. + prefix: Prepend the name with this string. + n_chars: Total number of characters used by the name (including `prefix`). + """ + columns = cls._into_columns(source) + prefix, n_bytes = cls._parse_prefix_n_bytes(prefix, n_chars) + n_failed: int = 0 + while n_failed <= cls._MAX_ITERATIONS: + token = f"{prefix}{token_hex(n_bytes)}" + if token not in columns: + columns.add(token) + n_failed = 0 + yield token + else: + n_failed += 1 + raise cls._failed_generation_error(columns, n_chars) + + @staticmethod + def _into_columns(source: _StoresColumns | Iterable[str], /) -> set[str]: + return set(source.columns if _has_columns(source) else source) + + @classmethod + def _parse_prefix_n_bytes(cls, prefix: str, n_chars: int, /) -> tuple[str, int]: + prefix = prefix or "nw" + if not (available := n_chars - len(prefix)) or available < cls._MIN_RANDOM_CHARS: + raise cls._not_enough_room_error(prefix, n_chars) + return prefix, available // 2 + + @classmethod + def _not_enough_room_error(cls, prefix: str, n_chars: int, /) -> NarwhalsError: + len_prefix = len(prefix) + available_chars = n_chars - len_prefix + if available_chars < 0: + visualize = "" + else: + okay = "✔" * available_chars + bad = "✖" * (cls._MIN_RANDOM_CHARS - available_chars) + visualize = f"\n Preview: '{prefix}{okay}{bad}'" + msg = ( + f"Temporary column name generation requires {len_prefix} characters for the prefix " + f"and at least {cls._MIN_RANDOM_CHARS} more to store random 
bytes:{visualize}\n\n" + f"Hint: Maybe try\n" + f"- a shorter `prefix` than {prefix!r}?\n" + f"- a higher `n_chars` than {n_chars!r}?" + ) + return NarwhalsError(msg) + + @classmethod + def _failed_generation_error( + cls, columns: Iterable[str], n_chars: int, / + ) -> NarwhalsError: + current = sorted(columns) + truncated = _reprlib_repr_backport().repr(current) + msg = ( + "Was unable to generate a column name with " + f"`{n_chars=}` within {cls._MAX_ITERATIONS} iterations, \n" + f"that was not present in existing ({len(current)}) columns:\n{truncated}" + ) + return NarwhalsError(msg) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 8f06f1e5c9..8956c33457 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -2,9 +2,10 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload -from narwhals._plan import _expansion, _parse -from narwhals._plan.contexts import ExprContext +from narwhals._plan import _parse +from narwhals._plan._expansion import prepare_projection from narwhals._plan.expr import _parse_sort_by +from narwhals._plan.group_by import GroupBy, Grouped from narwhals._plan.series import Series from narwhals._plan.typing import ( IntoExpr, @@ -18,13 +19,12 @@ from narwhals.schema import Schema if TYPE_CHECKING: + from collections.abc import Sequence + import pyarrow as pa from typing_extensions import Self - from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.protocols import CompliantBaseFrame, CompliantDataFrame - from narwhals._plan.schema import FrozenSchema - from narwhals._plan.typing import Seq from narwhals.typing import NativeFrame @@ -60,27 +60,19 @@ def _from_compliant(cls, compliant: CompliantBaseFrame[Any, NativeFrameT], /) -> def to_native(self) -> NativeFrameT: return self._compliant.native - def _project( - self, - exprs: tuple[OneOrIterable[IntoExpr], ...], - named_exprs: dict[str, Any], - context: ExprContext, - /, - ) -> 
tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: - """Temp, while these parts aren't connected, this is easier for testing.""" - irs, schema_frozen, output_names = _expansion.prepare_projection( - _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema - ) - named_irs = _expansion.into_named_irs(irs, output_names) - return schema_frozen.project(named_irs, context) - def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: - named_irs, _ = self._project(exprs, named_exprs, ExprContext.SELECT) - return self._from_compliant(self._compliant.select(named_irs)) + named_irs, schema = prepare_projection( + _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self + ) + return self._from_compliant(self._compliant.select(schema.select_irs(named_irs))) def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: - named_irs, _ = self._project(exprs, named_exprs, ExprContext.WITH_COLUMNS) - return self._from_compliant(self._compliant.with_columns(named_irs)) + named_irs, schema = prepare_projection( + _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), schema=self + ) + return self._from_compliant( + self._compliant.with_columns(schema.with_columns_irs(named_irs)) + ) def sort( self, @@ -92,10 +84,16 @@ def sort( sort, opts = _parse_sort_by( by, *more_by, descending=descending, nulls_last=nulls_last ) - irs, _, output_names = _expansion.prepare_projection(sort, self.schema) - named_irs = _expansion.into_named_irs(irs, output_names) + named_irs, _ = prepare_projection(sort, schema=self) return self._from_compliant(self._compliant.sort(named_irs, opts)) + def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: + return self._from_compliant(self._compliant.drop(columns, strict=strict)) + + def drop_nulls(self, subset: str | Sequence[str] | None = None) -> Self: + subset = [subset] if isinstance(subset, str) else subset + return self._from_compliant(self._compliant.drop_nulls(subset)) + class 
DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]): _compliant: CompliantDataFrame[Any, NativeDataFrameT, NativeSeriesT] @@ -138,3 +136,29 @@ def to_dict( def __len__(self) -> int: return len(self._compliant) + + @overload + def group_by( + self, + *by: OneOrIterable[IntoExpr], + drop_null_keys: Literal[False] = ..., + **named_by: IntoExpr, + ) -> GroupBy[Self]: ... + + @overload + def group_by( + self, *by: OneOrIterable[str], drop_null_keys: Literal[True] + ) -> GroupBy[Self]: ... + + def group_by( + self, + *by: OneOrIterable[IntoExpr], + drop_null_keys: bool = False, + **named_by: IntoExpr, + ) -> GroupBy[Self]: + return Grouped.by(*by, drop_null_keys=drop_null_keys, **named_by).to_group_by( + self + ) + + def row(self, index: int) -> tuple[Any, ...]: + return self._compliant.row(index) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index b0f369bd77..7695c1d92f 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -103,6 +103,9 @@ def exclude(self, *names: OneOrIterable[str]) -> Self: def count(self) -> Self: return self._from_ir(agg.Count(expr=self._ir)) + def len(self) -> Self: + return self._from_ir(agg.Len(expr=self._ir)) + def max(self) -> Self: return self._from_ir(agg.Max(expr=self._ir)) diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 237ee36e81..4444bbd6be 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -5,6 +5,7 @@ NamedIR, SelectorIR, ) +from narwhals._plan._function import Function from narwhals._plan.expressions import ( aggregation, boolean, @@ -60,6 +61,7 @@ "Exclude", "ExprIR", "Filter", + "Function", "FunctionExpr", "IndexColumns", "InvertSelector", diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index 263ca300e5..0f26a82c10 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -33,7 
+33,10 @@ def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: # fmt: off -class Count(AggExpr): ... +class Count(AggExpr): + """Non-null count.""" +class Len(AggExpr): + """Null-inclusive count.""" class Max(AggExpr): ... class Mean(AggExpr): ... class Median(AggExpr): ... diff --git a/narwhals/_plan/group_by.py b/narwhals/_plan/group_by.py new file mode 100644 index 0000000000..5e95bd484e --- /dev/null +++ b/narwhals/_plan/group_by.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic + +from narwhals._plan._parse import parse_into_seq_of_expr_ir +from narwhals._plan.protocols import GroupByResolver as Resolved, Grouper +from narwhals._plan.typing import DataFrameT + +if TYPE_CHECKING: + from collections.abc import Iterator + + from typing_extensions import Self + + from narwhals._plan.expressions import ExprIR + from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq + + +class GroupBy(Generic[DataFrameT]): + _frame: DataFrameT + _grouper: Grouped + + def __init__(self, frame: DataFrameT, grouper: Grouped, /) -> None: + self._frame = frame + self._grouper = grouper + + def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> DataFrameT: + frame = self._frame + return frame._from_compliant( + self._grouper.agg(*aggs, **named_aggs) + .resolve(frame) + .evaluate(frame._compliant) + ) + + def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: + frame = self._frame + resolver = self._grouper.agg().resolve(frame) + for key, df in frame._compliant.group_by_resolver(resolver): + yield key, frame._from_compliant(df) + + +class Grouped(Grouper["Resolved"]): + """Narwhals-level `GroupBy` builder.""" + + _keys: Seq[ExprIR] + _aggs: Seq[ExprIR] + _drop_null_keys: bool + + @classmethod + def by( + cls, + *by: OneOrIterable[IntoExpr], + drop_null_keys: bool = False, + **named_by: IntoExpr, + ) -> Self: + obj = cls.__new__(cls) + obj._keys = parse_into_seq_of_expr_ir(*by, **named_by) + 
obj._drop_null_keys = drop_null_keys + return obj + + def agg(self, *aggs: OneOrIterable[IntoExpr], **named_aggs: IntoExpr) -> Self: + self._aggs = parse_into_seq_of_expr_ir(*aggs, **named_aggs) + return self + + @property + def _resolver(self) -> type[Resolved]: + return Resolved + + def to_group_by(self, frame: DataFrameT, /) -> GroupBy[DataFrameT]: + return GroupBy(frame, self) diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 6f77674dff..303d07a097 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -9,10 +9,12 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence + import pyarrow.acero import pyarrow.compute as pc from typing_extensions import Self, TypeAlias - from narwhals._plan.typing import Accessor, OneOrIterable, Seq + from narwhals._plan.arrow.typing import NullPlacement + from narwhals._plan.typing import Accessor, OneOrIterable, Order, Seq from narwhals.typing import RankMethod DispatchOrigin: TypeAlias = Literal["expr", "__narwhals_namespace__"] @@ -170,7 +172,7 @@ def to_multiple(self, n_repeat: int = 1, /) -> SortMultipleOptions: nulls: Seq[bool] = (self.nulls_last,) else: desc = tuple(repeat(self.descending, n_repeat)) - nulls = tuple(repeat(self.nulls_last)) + nulls = tuple(repeat(self.nulls_last, n_repeat)) return SortMultipleOptions(descending=desc, nulls_last=nulls) @@ -193,9 +195,9 @@ def parse( nulls = (nulls_last,) if isinstance(nulls_last, bool) else tuple(nulls_last) return SortMultipleOptions(descending=desc, nulls_last=nulls) - def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: - import pyarrow.compute as pc - + def _to_arrow_args( + self, by: Sequence[str] + ) -> tuple[Sequence[tuple[str, Order]], NullPlacement]: first = self.nulls_last[0] if len(self.nulls_last) != 1 and any(x != first for x in self.nulls_last[1:]): msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {self.nulls_last!r}" @@ -204,12 +206,23 @@ def to_arrow(self, by: Sequence[str]) -> 
pc.SortOptions: descending: Iterable[bool] = repeat(self.descending[0], len(by)) else: descending = self.descending - sorting: list[tuple[str, Literal["ascending", "descending"]]] = [ + sorting = tuple[tuple[str, "Order"]]( (key, "descending" if desc else "ascending") for key, desc in zip(by, descending) - ] - placement: Literal["at_start", "at_end"] = "at_end" if first else "at_start" - return pc.SortOptions(sort_keys=sorting, null_placement=placement) + ) + return sorting, "at_end" if first else "at_start" + + def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: + import pyarrow.compute as pc + + sort_keys, placement = self._to_arrow_args(by) + return pc.SortOptions(sort_keys=sort_keys, null_placement=placement) + + def to_arrow_acero(self, by: Sequence[str]) -> pyarrow.acero.Declaration: + from narwhals._plan.arrow import acero + + sort_keys, placement = self._to_arrow_args(by) + return acero._order_by(sort_keys, null_placement=placement) class RankOptions(Immutable): diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 11a17eb081..cff5e790e8 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -1,12 +1,24 @@ +"""TODO: Split this module up into `narwhals._plan.compliant.*`.""" + from __future__ import annotations from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized +from itertools import chain from typing import TYPE_CHECKING, Any, Literal, Protocol, overload -from narwhals._plan.common import flatten_hash_safe -from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT, Seq +from narwhals._plan._expansion import prepare_projection +from narwhals._plan._parse import parse_into_seq_of_expr_ir +from narwhals._plan.common import flatten_hash_safe, replace, temp +from narwhals._plan.typing import ( + IntoExpr, + NativeDataFrameT, + NativeFrameT, + NativeSeriesT, + Seq, +) from narwhals._typing_compat import TypeVar from narwhals._utils import Version +from 
narwhals.exceptions import ComputeError if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs @@ -15,6 +27,7 @@ from narwhals._plan.dataframe import BaseFrame, DataFrame from narwhals._plan.expressions import ( BinaryExpr, + ExprIR, FunctionExpr, NamedIR, aggregation as agg, @@ -25,6 +38,7 @@ from narwhals._plan.expressions.ranges import IntRange from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.options import SortMultipleOptions + from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema from narwhals._plan.series import Series from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType @@ -49,6 +63,8 @@ ColumnT = TypeVar("ColumnT") ColumnT_co = TypeVar("ColumnT_co", covariant=True) +ResolverT_co = TypeVar("ResolverT_co", bound="GroupByResolver", covariant=True) + ExprAny: TypeAlias = "CompliantExpr[Any, Any]" ScalarAny: TypeAlias = "CompliantScalar[Any, Any]" SeriesAny: TypeAlias = "CompliantSeries[Any]" @@ -69,7 +85,9 @@ SeriesT = TypeVar("SeriesT", bound=SeriesAny) SeriesT_co = TypeVar("SeriesT_co", bound=SeriesAny, covariant=True) FrameT = TypeVar("FrameT", bound=FrameAny) +FrameT_co = TypeVar("FrameT_co", bound=FrameAny, covariant=True) FrameT_contra = TypeVar("FrameT_contra", bound=FrameAny, contravariant=True) +DataFrameT = TypeVar("DataFrameT", bound=DataFrameAny) NamespaceT_co = TypeVar("NamespaceT_co", bound="NamespaceAny", covariant=True) EagerExprT_co = TypeVar("EagerExprT_co", bound=EagerExprAny, covariant=True) @@ -199,6 +217,7 @@ def _with_native(self, native: Any, name: str, /) -> Self: return self.from_native(native, name or self.name, self.version) # series & scalar + def abs(self, node: FunctionExpr[F.Abs], frame: FrameT_contra, name: str) -> Self: ... def cast(self, node: ir.Cast, frame: FrameT_contra, name: str) -> Self: ... def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... 
def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... @@ -268,6 +287,9 @@ def quantile( def count( self, node: agg.Count, frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def len( + self, node: agg.Len, frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... def max( self, node: agg.Max, frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... @@ -374,6 +396,10 @@ def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: """Returns 0 if null, else 1.""" ... + def len(self, node: agg.Len, frame: FrameT_contra, name: str) -> Self: + """Returns 1.""" + ... + def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) @@ -531,6 +557,8 @@ class CompliantBaseFrame(StoresVersion, Protocol[ColumnT_co, NativeFrameT]): def __narwhals_namespace__(self) -> Any: ... @property + def _group_by(self) -> type[CompliantGroupBy[Self]]: ... + @property def native(self) -> NativeFrameT: return self._native @@ -553,18 +581,44 @@ def _evaluate_irs( self, nodes: Iterable[NamedIR[ir.ExprIR]], / ) -> Iterator[ColumnT_co]: ... def select(self, irs: Seq[NamedIR]) -> Self: ... + def select_names(self, *column_names: str) -> Self: ... def with_columns(self, irs: Seq[NamedIR]) -> Self: ... def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: ... + def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: ... + def drop_nulls(self, subset: Sequence[str] | None) -> Self: ... class CompliantDataFrame( CompliantBaseFrame[SeriesT, NativeDataFrameT], Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], ): + @property + def _group_by(self) -> type[DataFrameGroupBy[Self]]: ... + @property + def _grouper(self) -> type[Grouped]: + return Grouped + @classmethod def from_dict( cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None ) -> Self: ... 
+ def group_by_agg( + self, by: OneOrIterable[IntoExpr], aggs: OneOrIterable[IntoExpr], / + ) -> Self: + """Compliant-level `group_by(by).agg(agg)`, allows `Expr`.""" + return self._grouper.by(by).agg(aggs).resolve(self).evaluate(self) + + def group_by_names(self, names: Seq[str], /) -> DataFrameGroupBy[Self]: + """Compliant-level `group_by`, allowing only `str` keys.""" + return self._group_by.by_names(self, names) + + def group_by_resolver(self, resolver: GroupByResolver, /) -> DataFrameGroupBy[Self]: + """Narwhals-level resolved `group_by`. + + `keys`, `aggs` are already parsed and projections planned. + """ + return self._group_by.from_resolver(self, resolver) + def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: ... @overload def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... @@ -579,12 +633,15 @@ def to_dict( ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... def __len__(self) -> int: ... def with_row_index(self, name: str) -> Self: ... + def row(self, index: int) -> tuple[Any, ...]: ... class EagerDataFrame( CompliantDataFrame[SeriesT, NativeDataFrameT, NativeSeriesT], Protocol[SeriesT, NativeDataFrameT, NativeSeriesT], ): + @property + def _group_by(self) -> type[EagerDataFrameGroupBy[Self]]: ... def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... def select(self, irs: Seq[NamedIR]) -> Self: return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) @@ -650,3 +707,185 @@ def __len__(self) -> int: def to_list(self) -> list[Any]: ... def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: ... + + +class CompliantGroupBy(Protocol[FrameT_co]): + @property + def compliant(self) -> FrameT_co: ... + def agg(self, irs: Seq[NamedIR]) -> FrameT_co: ... 
+ + +class DataFrameGroupBy(CompliantGroupBy[DataFrameT], Protocol[DataFrameT]): + _keys: Seq[NamedIR] + _key_names: Seq[str] + + @classmethod + def from_resolver( + cls, df: DataFrameT, resolver: GroupByResolver, / + ) -> DataFrameGroupBy[DataFrameT]: ... + @classmethod + def by_names( + cls, df: DataFrameT, names: Seq[str], / + ) -> DataFrameGroupBy[DataFrameT]: ... + def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: ... + @property + def keys(self) -> Seq[NamedIR]: + return self._keys + + @property + def key_names(self) -> Seq[str]: + if names := self._key_names: + return names + msg = "at least one key is required in a group_by operation" + raise ComputeError(msg) + + +class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDataFrameT]): + _df: EagerDataFrameT + _key_names: Seq[str] + _key_names_original: Seq[str] + _column_names_original: Seq[str] + + @classmethod + def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self: + obj = cls.__new__(cls) + obj._df = df + obj._keys = () + obj._key_names = names + obj._key_names_original = () + obj._column_names_original = tuple(df.columns) + return obj + + @classmethod + def from_resolver( + cls, df: EagerDataFrameT, resolver: GroupByResolver, / + ) -> EagerDataFrameGroupBy[EagerDataFrameT]: + key_names = resolver.key_names + if not resolver.requires_projection(): + df = df.drop_nulls(key_names) if resolver._drop_null_keys else df + return cls.by_names(df, key_names) + obj = cls.__new__(cls) + unique_names = temp.column_names(chain(key_names, df.columns)) + safe_keys = tuple( + replace(key, name=name) for key, name in zip(resolver.keys, unique_names) + ) + obj._df = df.with_columns(resolver._schema_in.with_columns_irs(safe_keys)) + obj._keys = safe_keys + obj._key_names = tuple(e.name for e in safe_keys) + obj._key_names_original = key_names + obj._column_names_original = resolver._schema_in.names + return obj + + +class Grouper(Protocol[ResolverT_co]): + """`GroupBy` helper for 
collecting and forwarding `Expr`s for projection. + + - Uses `Expr` everywhere (no need to duplicate layers) + - Resolver only needs schema (neither needs a frame, but can use one to get `schema`) + """ + + _keys: Seq[ExprIR] + _aggs: Seq[ExprIR] + _drop_null_keys: bool + + @classmethod + def by(cls, *by: OneOrIterable[IntoExpr]) -> Self: + obj = cls.__new__(cls) + obj._keys = parse_into_seq_of_expr_ir(*by) + return obj + + def agg(self, *aggs: OneOrIterable[IntoExpr]) -> Self: + self._aggs = parse_into_seq_of_expr_ir(*aggs) + return self + + @property + def _resolver(self) -> type[ResolverT_co]: ... + + def resolve(self, context: IntoFrozenSchema, /) -> ResolverT_co: + """Project keys and aggs in `context`, expanding all `Expr` -> `NamedIR`.""" + return self._resolver.from_grouper(self, context) + + +class GroupByResolver: + """Narwhals-level `GroupBy` resolver.""" + + _schema_in: FrozenSchema + _keys: Seq[NamedIR] + _aggs: Seq[NamedIR] + _key_names: Seq[str] + _schema: FrozenSchema + _drop_null_keys: bool + + @property + def keys(self) -> Seq[NamedIR]: + return self._keys + + @property + def aggs(self) -> Seq[NamedIR]: + return self._aggs + + @property + def key_names(self) -> Seq[str]: + if names := self._key_names: + return names + if keys := self.keys: + return tuple(e.name for e in keys) + msg = "at least one key is required in a group_by operation" + raise ComputeError(msg) + + @property + def schema(self) -> FrozenSchema: + return self._schema + + def evaluate(self, frame: DataFrameT) -> DataFrameT: + """Perform the `group_by` on `frame`.""" + return frame.group_by_resolver(self).agg(self.aggs) + + @classmethod + def from_grouper(cls, grouper: Grouper[Self], context: IntoFrozenSchema, /) -> Self: + """Loosely based on [`resolve_group_by`]. 
+ + [`resolve_group_by`]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs#L1125-L1227 + """ + obj = cls.__new__(cls) + keys, schema_in = prepare_projection(grouper._keys, schema=context) + obj._keys, obj._schema_in = keys, schema_in + obj._key_names = tuple(e.name for e in keys) + obj._aggs, _ = prepare_projection(grouper._aggs, obj.key_names, schema=schema_in) + obj._schema = schema_in.select(keys).merge(schema_in.select(obj._aggs)) + obj._drop_null_keys = grouper._drop_null_keys + return obj + + def requires_projection(self, *, allow_aliasing: bool = False) -> bool: + """Return True is group keys contain anything that is not a column selection. + + Notes: + If False is returned, we can just use the resolved key names as a fast-path to group. + + Arguments: + allow_aliasing: If False (default), any aliasing is not considered to be column selection. + """ + if not all(key.is_column(allow_aliasing=allow_aliasing) for key in self.keys): + if self._drop_null_keys: + msg = "drop_null_keys cannot be True when keys contains Expr or Series" + raise NotImplementedError(msg) + return True + return False + + +class Resolved(GroupByResolver): + """Compliant-level `GroupBy` resolver.""" + + _drop_null_keys: bool = False + + +class Grouped(Grouper[Resolved]): + """Compliant-level `GroupBy` builder.""" + + _keys: Seq[ExprIR] + _aggs: Seq[ExprIR] + _drop_null_keys: bool = False + + @property + def _resolver(self) -> type[Resolved]: + return Resolved diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 4dbf5e6ef3..67433db06b 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -1,28 +1,28 @@ from __future__ import annotations -from collections import deque from collections.abc import Mapping from functools import lru_cache -from itertools import chain, repeat +from itertools import chain from types import MappingProxyType -from typing import 
TYPE_CHECKING, Any, TypeVar, overload +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload from narwhals._plan._expr_ir import NamedIR from narwhals._plan._immutable import _IMMUTABLE_HASH_NAME, Immutable +from narwhals._utils import _hasattr_static from narwhals.dtypes import Unknown if TYPE_CHECKING: from collections.abc import ItemsView, Iterator, KeysView, ValuesView - from typing_extensions import TypeAlias + from typing_extensions import Never, TypeAlias, TypeIs - from narwhals._plan.contexts import ExprContext from narwhals._plan.typing import Seq from narwhals.dtypes import DType + from narwhals.typing import IntoSchema IntoFrozenSchema: TypeAlias = ( - "Mapping[str, DType] | Iterator[tuple[str, DType]] | FrozenSchema" + "IntoSchema | Iterator[tuple[str, DType]] | FrozenSchema | HasSchema" ) """A schema to freeze, or an already frozen one. @@ -41,16 +41,18 @@ class FrozenSchema(Immutable): __slots__ = ("_mapping",) _mapping: MappingProxyType[str, DType] - def project( - self, exprs: Seq[NamedIR], context: ExprContext - ) -> tuple[Seq[NamedIR], FrozenSchema]: - if context.is_select(): - return exprs, self._select(exprs) - if context.is_with_columns(): - return self._with_columns(exprs) - raise TypeError(context) + def __init_subclass__(cls, *_: Never, **__: Never) -> Never: + msg = f"Cannot subclass {cls.__name__!r}" + raise TypeError(msg) - def _select(self, exprs: Seq[NamedIR]) -> FrozenSchema: + def merge(self, other: FrozenSchema, /) -> FrozenSchema: + """Return a new schema, merging `other` with `self` (see [upstream]). + + [upstream]: https://github.com/pola-rs/polars/blob/cdd247aaba8db3332be0bd031e0f31bc3fc33f77/crates/polars-schema/src/schema.rs#L265-L274. + """ + return freeze_schema(self._mapping | other._mapping) + + def select(self, exprs: Seq[NamedIR]) -> FrozenSchema: """Return a new schema, equivalent to performing `df.select(*exprs)`. 
Arguments: @@ -64,18 +66,24 @@ def _select(self, exprs: Seq[NamedIR]) -> FrozenSchema: default = Unknown() return freeze_schema((name, self.get(name, default)) for name in names) - def _with_columns(self, exprs: Seq[NamedIR]) -> tuple[Seq[NamedIR], FrozenSchema]: - exprs_out = deque[NamedIR]() + def select_irs(self, exprs: Seq[NamedIR]) -> Seq[NamedIR]: + return exprs + + def with_columns(self, exprs: Seq[NamedIR]) -> FrozenSchema: + # similar to `merge`, but preserving known `DType`s + names = (e.name for e in exprs) + default = Unknown() + miss = {name: default for name in names if name not in self} + return freeze_schema(self._mapping | miss) + + def with_columns_irs(self, exprs: Seq[NamedIR]) -> Seq[NamedIR]: + """Required for `_concat_horizontal`-based `with_columns`. + + Fills in any unreferenced columns present in `self`, but not in `exprs` as selections. + """ named: dict[str, NamedIR[Any]] = {e.name: e for e in exprs} - items: IntoFrozenSchema - for name in self: - exprs_out.append(named.pop(name, NamedIR.from_name(name))) - if named: - items = chain(self.items(), zip(named, repeat(Unknown(), len(named)))) - exprs_out.extend(named.values()) - else: - items = self - return tuple(exprs_out), freeze_schema(items) + it = (named.pop(name, NamedIR.from_name(name)) for name in self) + return tuple(chain(it, named.values())) @property def __immutable_hash__(self) -> int: @@ -92,7 +100,9 @@ def names(self) -> FrozenColumns: @staticmethod def _from_mapping(mapping: MappingProxyType[str, DType], /) -> FrozenSchema: - return FrozenSchema(_mapping=mapping) + obj = FrozenSchema.__new__(FrozenSchema) + object.__setattr__(obj, "_mapping", mapping) + return obj @staticmethod def _from_hash_safe(items: _FrozenSchemaHash, /) -> FrozenSchema: @@ -134,6 +144,15 @@ def __repr__(self) -> str: return f"{type(self).__name__}([{nl}{indent}{items}{sep}{nl}])" +class HasSchema(Protocol): + @property + def schema(self) -> IntoSchema: ... 
+ + +def has_schema(obj: Any) -> TypeIs[HasSchema]: + return _hasattr_static(obj, "schema") + + @overload def freeze_schema(mapping: IntoFrozenSchema, /) -> FrozenSchema: ... @overload @@ -143,7 +162,7 @@ def freeze_schema( ) -> FrozenSchema: if isinstance(iterable, FrozenSchema): return iterable - into = iterable or schema + into = iterable.schema if has_schema(iterable) else (iterable or schema) hashable = tuple(into.items() if isinstance(into, Mapping) else into) return _freeze_schema_cache(hashable) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 0efb81ea81..2a734488a6 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -10,6 +10,7 @@ from narwhals import dtypes from narwhals._plan._expr_ir import ExprIR, NamedIR, SelectorIR from narwhals._plan._function import Function + from narwhals._plan.dataframe import DataFrame from narwhals._plan.expr import Expr from narwhals._plan.expressions import operators as ops from narwhals._plan.expressions.functions import RollingWindow @@ -25,6 +26,7 @@ ) __all__ = [ + "DataFrameT", "FunctionT", "IntoExpr", "IntoExprColumn", @@ -95,7 +97,7 @@ T = TypeVar("T") -Seq: TypeAlias = "tuple[T,...]" +Seq: TypeAlias = tuple[T, ...] """Immutable Sequence. Using instead of `Sequence`, as a `list` can be passed there (can't break immutability promise). 
@@ -107,3 +109,6 @@ IntoExprColumn: TypeAlias = "Expr | Series[t.Any] | str" IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" OneOrIterable: TypeAlias = "T | t.Iterable[T]" +OneOrSeq: TypeAlias = t.Union[T, Seq[T]] +DataFrameT = TypeVar("DataFrameT", bound="DataFrame[t.Any, t.Any]") +Order: TypeAlias = t.Literal["ascending", "descending"] diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index ffada70747..7b7113e450 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -398,6 +398,11 @@ def _ids_ir(expr: nwp.Expr | Any) -> str: raises=NotImplementedError, ), ), + pytest.param( + [nwp.col("g").len(), nwp.col("m").last(), nwp.col("h").count()], + {"g": [3], "m": [2], "h": [1]}, + id="len-count-with-nulls", + ), ], ids=_ids_ir, ) @@ -517,6 +522,23 @@ def test_first_last_expr_with_columns( assert_equal_data(result, {"result": expected_broadcast}) +@pytest.mark.parametrize( + ("index", "expected"), [(3, (None, 12, 0.9, 3, 3)), (1, (2, 5, 1.0, 1, 1))] +) +def test_row_is_py_literal( + data_indexed: dict[str, Any], index: int, expected: tuple[PythonLiteral, ...] 
+) -> None: + frame = nwp.DataFrame.from_native(pa.table(data_indexed)) + result = frame.row(index) + assert all(v is None or isinstance(v, (int, float)) for v in result) + assert result == expected + pytest.importorskip("polars") + import polars as pl + + polars_result = pl.DataFrame(data_indexed).row(index) + assert result == polars_result + + if TYPE_CHECKING: def test_protocol_expr() -> None: diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index a80724ff86..203c39911b 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -16,7 +16,7 @@ from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.schema import freeze_schema from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError -from tests.plan.utils import assert_expr_ir_equal +from tests.plan.utils import assert_expr_ir_equal, named_ir if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -257,19 +257,28 @@ def test_replace_selector( @pytest.mark.parametrize( ("into_exprs", "expected"), [ - ("a", [nwp.col("a")]), - (nwp.col("b", "c", "d"), [nwp.col("b"), nwp.col("c"), nwp.col("d")]), - (nwp.nth(6), [nwp.col("g")]), - (nwp.nth(9, 8, -5), [nwp.col("j"), nwp.col("i"), nwp.col("p")]), - ( + pytest.param("a", [nwp.col("a")], id="Col"), + pytest.param( + nwp.col("b", "c", "d"), + [nwp.col("b"), nwp.col("c"), nwp.col("d")], + id="Columns", + ), + pytest.param(nwp.nth(6), [nwp.col("g")], id="Nth"), + pytest.param( + nwp.nth(9, 8, -5), + [nwp.col("j"), nwp.col("i"), nwp.col("p")], + id="IndexColumns", + ), + pytest.param( [nwp.nth(2).alias("c again"), nwp.nth(-1, -2).name.to_uppercase()], [ - nwp.col("c").alias("c again"), - nwp.col("u").alias("U"), - nwp.col("s").alias("S"), + named_ir("c again", nwp.col("c")), + named_ir("U", nwp.col("u")), + named_ir("S", nwp.col("s")), ], + id="Nth-Alias-IndexColumns-Uppercase", ), - ( + pytest.param( nwp.all(), [ nwp.col("a"), @@ -293,82 +302,89 @@ def 
test_replace_selector( nwp.col("s"), nwp.col("u"), ], + id="All", ), - ( + pytest.param( (ndcs.numeric() - ndcs.by_dtype(nw.Float32(), nw.Float64())) .cast(nw.Int64) .mean() .name.suffix("_mean"), [ - nwp.col("a").cast(nw.Int64()).mean().alias("a_mean"), - nwp.col("b").cast(nw.Int64()).mean().alias("b_mean"), - nwp.col("c").cast(nw.Int64()).mean().alias("c_mean"), - nwp.col("d").cast(nw.Int64()).mean().alias("d_mean"), - nwp.col("e").cast(nw.Int64()).mean().alias("e_mean"), - nwp.col("f").cast(nw.Int64()).mean().alias("f_mean"), - nwp.col("g").cast(nw.Int64()).mean().alias("g_mean"), - nwp.col("h").cast(nw.Int64()).mean().alias("h_mean"), + named_ir("a_mean", nwp.col("a").cast(nw.Int64()).mean()), + named_ir("b_mean", nwp.col("b").cast(nw.Int64()).mean()), + named_ir("c_mean", nwp.col("c").cast(nw.Int64()).mean()), + named_ir("d_mean", nwp.col("d").cast(nw.Int64()).mean()), + named_ir("e_mean", nwp.col("e").cast(nw.Int64()).mean()), + named_ir("f_mean", nwp.col("f").cast(nw.Int64()).mean()), + named_ir("g_mean", nwp.col("g").cast(nw.Int64()).mean()), + named_ir("h_mean", nwp.col("h").cast(nw.Int64()).mean()), ], + id="Selector-SUB-Cast-Mean-Suffix", ), - ( + pytest.param( nwp.col("u").alias("1").alias("2").alias("3").alias("4").name.keep(), - # NOTE: Would be nice to rewrite with less intermediate steps - # but retrieving the root name is enough for now - [nwp.col("u").alias("1").alias("2").alias("3").alias("4").alias("u")], + [named_ir("u", nwp.col("u"))], + id="Alias-Etc-Keep", ), - ( + pytest.param( ( (ndcs.numeric() ^ (ndcs.matches(r"[abcdg]") | ndcs.by_name("i", "f"))) * 100 ).name.suffix("_mult_100"), [ - (nwp.col("e") * nwp.lit(100)).alias("e_mult_100"), - (nwp.col("h") * nwp.lit(100)).alias("h_mult_100"), - (nwp.col("j") * nwp.lit(100)).alias("j_mult_100"), + named_ir("e_mult_100", (nwp.col("e") * nwp.lit(100))), + named_ir("h_mult_100", (nwp.col("h") * nwp.lit(100))), + named_ir("j_mult_100", (nwp.col("j") * nwp.lit(100))), ], + 
id="Selector-XOR-OR-BinaryExpr-Suffix", ), - ( + pytest.param( ndcs.by_dtype(nw.Duration()) .dt.total_minutes() .name.map(lambda nm: f"total_mins: {nm!r} ?"), - [nwp.col("q").dt.total_minutes().alias("total_mins: 'q' ?")], + [named_ir("total_mins: 'q' ?", nwp.col("q").dt.total_minutes())], + id="ByDType-TotalMins-Name-Map", ), - ( + pytest.param( nwp.col("f", "g") .cast(nw.String) .str.starts_with("1") .all() .name.suffix("_all_starts_with_1"), [ - nwp.col("f") - .cast(nw.String) - .str.starts_with("1") - .all() - .alias("f_all_starts_with_1"), - nwp.col("g") - .cast(nw.String) - .str.starts_with("1") - .all() - .alias("g_all_starts_with_1"), + named_ir( + "f_all_starts_with_1", + nwp.col("f").cast(nw.String).str.starts_with("1").all(), + ), + named_ir( + "g_all_starts_with_1", + nwp.col("g").cast(nw.String).str.starts_with("1").all(), + ), ], + id="Cast-StartsWith-All-Suffix", ), - ( + pytest.param( nwp.col("a", "b") .first() .over("c", "e", order_by="d") .name.suffix("_first_over_part_order_1"), [ - nwp.col("a") - .first() - .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]) - .alias("a_first_over_part_order_1"), - nwp.col("b") - .first() - .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]) - .alias("b_first_over_part_order_1"), + named_ir( + "a_first_over_part_order_1", + nwp.col("a") + .first() + .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]), + ), + named_ir( + "b_first_over_part_order_1", + nwp.col("b") + .first() + .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]), + ), ], + id="First-Over-Partitioned-Ordered-Suffix", ), - ( + pytest.param( nwp.exclude(BIG_EXCLUDE), [ nwp.col("c"), @@ -379,42 +395,48 @@ def test_replace_selector( nwp.col("i"), nwp.col("j"), ], + id="Exclude", ), - ( + pytest.param( nwp.exclude(BIG_EXCLUDE).name.suffix("_2"), [ - nwp.col("c").alias("c_2"), - nwp.col("d").alias("d_2"), - nwp.col("f").alias("f_2"), - nwp.col("g").alias("g_2"), - nwp.col("h").alias("h_2"), - nwp.col("i").alias("i_2"), - 
nwp.col("j").alias("j_2"), + named_ir("c_2", nwp.col("c")), + named_ir("d_2", nwp.col("d")), + named_ir("f_2", nwp.col("f")), + named_ir("g_2", nwp.col("g")), + named_ir("h_2", nwp.col("h")), + named_ir("i_2", nwp.col("i")), + named_ir("j_2", nwp.col("j")), ], + id="Exclude-Suffix", ), - ( + pytest.param( nwp.col("c").alias("c_min_over_order_by").min().over(order_by=ndcs.string()), [ - nwp.col("c") - .alias("c_min_over_order_by") - .min() - .over(order_by=[nwp.col("k")]) + named_ir( + "c_min_over_order_by", + nwp.col("c").min().over(order_by=[nwp.col("k")]), + ) ], + id="Alias-Min-Over-Order-By-Selector", ), pytest.param( (ndcs.by_name("a", "b", "c") / nwp.col("e").first()) .over("g", "f", order_by="f") .name.prefix("hi_"), [ - (nwp.col("a") / nwp.col("e").first()) - .over("g", "f", order_by="f") - .alias("hi_a"), - (nwp.col("b") / nwp.col("e").first()) - .over("g", "f", order_by="f") - .alias("hi_b"), - (nwp.col("c") / nwp.col("e").first()) - .over("g", "f", order_by="f") - .alias("hi_c"), + named_ir( + "hi_a", + (nwp.col("a") / nwp.col("e").first()).over("g", "f", order_by="f"), + ), + named_ir( + "hi_b", + (nwp.col("b") / nwp.col("e").first()).over("g", "f", order_by="f"), + ), + named_ir( + "hi_c", + (nwp.col("c") / nwp.col("e").first()).over("g", "f", order_by="f"), + ), ], id="Selector-BinaryExpr-Over-Prefix", ), @@ -426,7 +448,7 @@ def test_prepare_projection( schema_1: dict[str, DType], ) -> None: irs_in = parse_into_seq_of_expr_ir(into_exprs) - actual, _, _ = prepare_projection(irs_in, schema_1) + actual, _ = prepare_projection(irs_in, schema=schema_1) assert len(actual) == len(expected) for lhs, rhs in zip(actual, expected): assert_expr_ir_equal(lhs, rhs) @@ -451,7 +473,7 @@ def test_prepare_projection_duplicate(expr: nwp.Expr, schema_1: dict[str, DType] irs = parse_into_seq_of_expr_ir(expr.alias("dupe")) pattern = re.compile(r"\.alias\(.dupe.\)") with pytest.raises(DuplicateError, match=pattern): - prepare_projection(irs, schema_1) + 
prepare_projection(irs, schema=schema_1) @pytest.mark.parametrize( @@ -517,7 +539,7 @@ def test_prepare_projection_column_not_found( pattern = re.compile(rf"not found: {re.escape(repr(missing))}") irs = parse_into_seq_of_expr_ir(into_exprs) with pytest.raises(ColumnNotFoundError, match=pattern): - prepare_projection(irs, schema_1) + prepare_projection(irs, schema=schema_1) @pytest.mark.parametrize( @@ -554,15 +576,15 @@ def test_prepare_projection_horizontal_alias( expr = function(into_exprs) alias_1 = expr.alias("alias(x1)") irs = parse_into_seq_of_expr_ir(alias_1) - out_irs, _, _ = prepare_projection(irs, schema_1) + out_irs, _ = prepare_projection(irs, schema=schema_1) assert len(out_irs) == 1 - assert out_irs[0] == function("a", "b", "c").alias("alias(x1)")._ir + assert out_irs[0] == named_ir("alias(x1)", function("a", "b", "c")) alias_2 = alias_1.alias("alias(x2)") irs = parse_into_seq_of_expr_ir(alias_2) - out_irs, _, _ = prepare_projection(irs, schema_1) + out_irs, _ = prepare_projection(irs, schema=schema_1) assert len(out_irs) == 1 - assert out_irs[0] == function("a", "b", "c").alias("alias(x1)").alias("alias(x2)")._ir + assert out_irs[0] == named_ir("alias(x2)", function("a", "b", "c")) @pytest.mark.parametrize( @@ -574,4 +596,4 @@ def test_prepare_projection_index_error( irs = parse_into_seq_of_expr_ir(into_exprs) pattern = re.compile(r"invalid.+index.+nth", re.DOTALL | re.IGNORECASE) with pytest.raises(ComputeError, match=pattern): - prepare_projection(irs, schema_1) + prepare_projection(irs, schema=schema_1) diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index bf810aa176..455fecd114 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -14,7 +14,7 @@ ) from narwhals._plan.expressions.window import Over from narwhals.exceptions import InvalidOperationError -from tests.plan.utils import assert_expr_ir_equal +from tests.plan.utils import assert_expr_ir_equal, named_ir if TYPE_CHECKING: 
from narwhals._plan.typing import IntoExpr @@ -79,11 +79,6 @@ def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: assert_expr_ir_equal(lhs, rhs) -def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> ir.NamedIR[ir.ExprIR]: - """Helper constructor for test compare.""" - return ir.NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name) - - def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: expected = ( named_ir("a", nwp.col("a")), diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py new file mode 100644 index 0000000000..2b60c118db --- /dev/null +++ b/tests/plan/group_by_test.py @@ -0,0 +1,726 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals as nw +from narwhals import _plan as nwp +from narwhals._plan import selectors as npcs +from narwhals.exceptions import InvalidOperationError +from tests.utils import PYARROW_VERSION, assert_equal_data as _assert_equal_data + +pytest.importorskip("pyarrow") + + +import pyarrow as pa + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + from narwhals._plan.typing import IntoExpr + + +def dataframe(data: dict[str, Any], /) -> nwp.DataFrame[Any, Any]: + return nwp.DataFrame.from_native(pa.table(data)) + + +def assert_equal_data(result: nwp.DataFrame, expected: Mapping[str, Any]) -> None: + _assert_equal_data(result.to_dict(as_series=False), expected) + + +def test_group_by_iter() -> None: + data = {"a": [1, 1, 3], "b": [4, 4, 6], "c": [7.0, 8.0, 9.0]} + df = dataframe(data) + expected_keys: list[tuple[int, ...]] = [(1,), (3,)] + keys = [] + for key, sub_df in df.group_by("a"): + if key == (1,): + expected = {"a": [1, 1], "b": [4, 4], "c": [7.0, 8.0]} + assert_equal_data(sub_df, expected) + assert isinstance(sub_df, nwp.DataFrame) + keys.append(key) + assert sorted(keys) == sorted(expected_keys) + expected_keys = [(1, 4), (3, 6)] + keys = [key 
for key, _ in df.group_by("a", "b")] + assert sorted(keys) == sorted(expected_keys) + keys = [key for key, _ in df.group_by("a", "b")] + assert sorted(keys) == sorted(expected_keys) + + +def test_group_by_nw_all() -> None: + df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]}) + result = df.group_by("a").agg(nwp.all().sum()).sort("a") + expected = {"a": [1, 2], "b": [9, 6], "c": [15, 9]} + assert_equal_data(result, expected) + result = df.group_by("a").agg(nwp.all().sum().name.suffix("_sum")).sort("a") + expected = {"a": [1, 2], "b_sum": [9, 6], "c_sum": [15, 9]} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("attr", "expected"), + [ + ("sum", {"a": [1, 2], "b": [3, 3]}), + ("mean", {"a": [1, 2], "b": [1.5, 3]}), + ("max", {"a": [1, 2], "b": [2, 3]}), + ("min", {"a": [1, 2], "b": [1, 3]}), + ("std", {"a": [1, 2], "b": [0.707107, None]}), + ("var", {"a": [1, 2], "b": [0.5, None]}), + ("len", {"a": [1, 2], "b": [3, 1]}), + ("n_unique", {"a": [1, 2], "b": [3, 1]}), + ("count", {"a": [1, 2], "b": [2, 1]}), + ], +) +def test_group_by_depth_1_agg(attr: str, expected: dict[str, list[Any]]) -> None: + data = {"a": [1, 1, 1, 2], "b": [1, None, 2, 3]} + expr = getattr(nwp.col("b"), attr)() + result = dataframe(data).group_by("a").agg(expr).sort("a") + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("values", "expected"), + [ + ( + {"x": [True, True, True, False, False, False]}, + {"all": [True, False, False], "any": [True, True, False]}, + ), + ( + {"x": [True, None, False, None, None, None]}, + {"all": [True, False, True], "any": [True, False, False]}, + ), + ], + ids=["not-nullable", "nullable"], +) +def test_group_by_depth_1_agg_bool_ops( + values: dict[str, list[bool]], expected: dict[str, list[bool]] +) -> None: + data = {"a": [1, 1, 2, 2, 3, 3], **values} + result = ( + dataframe(data) + .group_by("a") + .agg(nwp.col("x").all().alias("all"), nwp.col("x").any().alias("any")) + .sort("a") + ) + 
assert_equal_data(result, {"a": [1, 2, 3], **expected}) + + +@pytest.mark.parametrize( + ("attr", "ddof"), [("std", 0), ("var", 0), ("std", 2), ("var", 2)] +) +def test_group_by_depth_1_std_var(attr: str, ddof: int) -> None: + data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]} + _pow = 0.5 if attr == "std" else 1 + expected = { + "a": [1, 2], + "b": [ + (sum((v - 5) ** 2 for v in [4, 5, 6]) / (3 - ddof)) ** _pow, + (sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) / (3 - ddof)) ** _pow, + ], + } + expr = getattr(nwp.col("b"), attr)(ddof=ddof) + result = dataframe(data).group_by("a").agg(expr).sort("a") + assert_equal_data(result, expected) + + +def test_group_by_median() -> None: + data = {"a": [1, 1, 1, 2, 2, 2], "b": [5, 4, 6, 7, 3, 2]} + result = dataframe(data).group_by("a").agg(nwp.col("b").median()).sort("a") + expected = {"a": [1, 2], "b": [5, 3]} + assert_equal_data(result, expected) + + +def test_group_by_n_unique_w_missing() -> None: + data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]} + result = ( + dataframe(data) + .group_by("a") + .agg( + nwp.col("b").n_unique(), + c_n_unique=nwp.col("c").n_unique(), + c_n_min=nwp.col("b").min(), + d_n_unique=nwp.col("d").n_unique(), + ) + .sort("a") + ) + expected = { + "a": [1, 2], + "b": [2, 1], + "c_n_unique": [1, 1], + "c_n_min": [4, 5], + "d_n_unique": [1, 1], + } + assert_equal_data(result, expected) + + +def test_group_by_simple_named() -> None: + data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} + df = dataframe(data) + result = ( + df.group_by("a").agg(b_min=nwp.col("b").min(), b_max=nwp.col("b").max()).sort("a") + ) + expected = {"a": [1, 2], "b_min": [4, 6], "b_max": [5, 6]} + assert_equal_data(result, expected) + + +def test_group_by_simple_unnamed() -> None: + data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} + df = dataframe(data) + result = df.group_by("a").agg(nwp.col("b").min(), nwp.col("c").max()).sort("a") + expected = {"a": [1, 2], "b": [4, 6], "c": [7, 
1]} + assert_equal_data(result, expected) + + +def test_group_by_multiple_keys() -> None: + data = {"a": [1, 1, 2], "b": [4, 4, 6], "c": [7, 2, 1]} + df = dataframe(data) + result = ( + df.group_by("a", "b") + .agg(c_min=nwp.col("c").min(), c_max=nwp.col("c").max()) + .sort("a") + ) + expected = {"a": [1, 2], "b": [4, 6], "c_min": [2, 1], "c_max": [7, 1]} + assert_equal_data(result, expected) + + +def test_key_with_nulls() -> None: + data = {"b": [4, 5, None], "a": [1, 2, 3]} + result = ( + dataframe(data) + .group_by("b") + .agg(nwp.len(), nwp.col("a").min()) + .sort("a") + .with_columns(nwp.col("b").cast(nw.Float64)) + ) + expected = {"b": [4.0, 5, None], "len": [1, 1, 1], "a": [1, 2, 3]} + assert_equal_data(result, expected) + + +def test_key_with_nulls_ignored() -> None: + data = {"b": [4, 5, None], "a": [1, 2, 3]} + result = ( + dataframe(data) + .group_by("b", drop_null_keys=True) + .agg(nwp.len(), nwp.col("a").min()) + .sort("a") + .with_columns(nwp.col("b").cast(nw.Float64)) + ) + expected = {"b": [4.0, 5], "len": [1, 1], "a": [1, 2]} + assert_equal_data(result, expected) + + +def test_key_with_nulls_iter() -> None: + data = { + "b": [None, "4", "5", None, "7"], + "a": [None, 1, 2, 3, 4], + "c": [None, "4", "3", None, None], + } + result = dict(dataframe(data).group_by("b", "c", drop_null_keys=True).__iter__()) + + assert len(result) == 2 + assert_equal_data(result[("4", "4")], {"b": ["4"], "a": [1], "c": ["4"]}) + assert_equal_data(result[("5", "3")], {"b": ["5"], "a": [2], "c": ["3"]}) + + result = dict(dataframe(data).group_by("b", "c", drop_null_keys=False).__iter__()) + assert_equal_data(result[("4", "4")], {"b": ["4"], "a": [1], "c": ["4"]}) + assert_equal_data(result[("5", "3")], {"b": ["5"], "a": [2], "c": ["3"]}) + assert len(result) == 4 + + +def test_group_by_expr_iter() -> None: + data = { + "b": [None, "4", "5", None, "7"], + "a": [None, 1, 2, 3, 4], + "c": ["1", "4", "3", "1", "1"], + } + + expected = { + ("1",): {"b": [None, None, "7"], "a": 
[None, 3, 4], "c": ["1", "1", "1"]}, + ("3",): {"b": ["5"], "a": [2], "c": ["3"]}, + ("4",): {"b": ["4"], "a": [1], "c": ["4"]}, + } + grouped = dataframe(data).group_by(nwp.col("c").alias("d")) + result = dict(sorted((k, df.sort("c").to_dict(as_series=False)) for k, df in grouped)) + assert len(result) == len(expected) + assert result.keys() == expected.keys() + # NOTE: The bug this is trying to avoid regressing on would break zipping, as one side has more columns + result_p1 = next(iter(result.values())) + expected_p1 = next(iter(expected.values())) + assert result_p1 == expected_p1 + _assert_equal_data(result, expected) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "keys", [[nwp.col("a").abs()], ["a", nwp.col("a").abs().alias("a_test")]] +) +def test_group_by_raise_drop_null_keys_with_exprs(keys: list[nwp.Expr | str]) -> None: + data = {"a": [1, 1, 2, 2, -1], "x": [0, 1, 2, 3, 4], "y": [0.5, -0.5, 1.0, -1.0, 1.5]} + df = dataframe(data) + with pytest.raises( + NotImplementedError, match="drop_null_keys cannot be True when keys contains Expr" + ): + df.group_by(*keys, drop_null_keys=True).agg(nwp.sum("y")) # type: ignore[call-overload] + + +def test_no_agg() -> None: + data = {"a": [1, 1, 3], "b": [4, 4, 6], "c": [7.0, 8.0, 9.0]} + result = dataframe(data).group_by(["a", "b"]).agg().sort("a", "b") + expected = {"a": [1, 3], "b": [4, 6]} + assert_equal_data(result, expected) + + +@pytest.mark.xfail( + PYARROW_VERSION < (15,), + reason=( + "The defaults for grouping by categories in pandas are different.\n\n" + "https://github.com/narwhals-dev/narwhals/issues/1078" + ), +) +def test_group_by_categorical() -> None: + data = {"g1": ["a", "a", "b", "b"], "g2": ["x", "y", "x", "z"], "x": [1, 2, 3, 4]} + df = dataframe(data) + result = ( + df.with_columns( + g1=nwp.col("g1").cast(nw.Categorical()), + g2=nwp.col("g2").cast(nw.Categorical()), + ) + .group_by(["g1", "g2"]) + .agg(nwp.col("x").sum()) + .sort("x") + ) + assert_equal_data(result, data) + + 
+@pytest.mark.parametrize( + ("agg", "message_body", "expected_repr"), + [ + (nwp.col("a").shift(1), r"shift.+not.+group_by.+pyarrow.+", "col('a').shift("), + ( + nwp.col("a").arg_max(), + r"arg_max.+not.+group_by.+pyarrow.+", + "col('a').arg_max(", + ), + ( + nwp.col("a").max().over("b"), + r"over.+not.+group_by.+pyarrow.+", + "col('a').max().over([col('b')])", + ), + ( + nwp.col("a").drop_nulls().abs().mean(), + r"complex aggregation found.+not.+group_by.+pyarrow.+", + "col('a').drop_nulls().abs().mean()", + ), + ], +) +def test_group_by_unsupported_raises( + agg: nwp.Expr, message_body: str, expected_repr: str +) -> None: + df = dataframe({"a": [1, 2, 3], "b": [1, 1, 2]}) + pat = re.compile(rf"{message_body}{re.escape(expected_repr)}", re.DOTALL) + with pytest.raises(InvalidOperationError, match=pat): + df.group_by("b").agg(agg) + + +def test_double_same_aggregation() -> None: + df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6]}) + result = df.group_by("a").agg(c=nwp.col("b").mean(), d=nwp.col("b").mean()).sort("a") + expected = {"a": [1, 2], "c": [4.5, 6], "d": [4.5, 6]} + assert_equal_data(result, expected) + + +def test_all_kind_of_aggs() -> None: + df = dataframe({"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]}) + result = ( + df.group_by("a") + .agg( + c=nwp.col("b").mean(), + d=nwp.col("b").mean(), + e=nwp.col("b").std(ddof=1), + f=nwp.col("b").std(ddof=2), + g=nwp.col("b").var(ddof=2), + h=nwp.col("b").var(ddof=2), + i=nwp.col("b").n_unique(), + ) + .sort("a") + ) + + variance_num = sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) + expected = { + "a": [1, 2], + "c": [5, 10 / 3], + "d": [5, 10 / 3], + "e": [1, (variance_num / (3 - 1)) ** 0.5], + "f": [2**0.5, (variance_num) ** 0.5], # denominator is 1 (=3-2) + "g": [2.0, variance_num], # denominator is 1 (=3-2) + "h": [2.0, variance_num], # denominator is 1 (=3-2) + "i": [3, 2], + } + assert_equal_data(result, expected) + + +def test_fancy_functions() -> None: + df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6]}) + 
result = df.group_by("a").agg(nwp.all().std(ddof=0)).sort("a") + expected = {"a": [1, 2], "b": [0.5, 0.0]} + assert_equal_data(result, expected) + result = df.group_by("a").agg(npcs.numeric().std(ddof=0)).sort("a") + assert_equal_data(result, expected) + result = df.group_by("a").agg(npcs.matches("b").std(ddof=0)).sort("a") + assert_equal_data(result, expected) + result = df.group_by("a").agg(npcs.matches("b").std(ddof=0).alias("c")).sort("a") + expected = {"a": [1, 2], "c": [0.5, 0.0]} + assert_equal_data(result, expected) + result = ( + df.group_by("a") + .agg(npcs.matches("b").std(ddof=0).name.map(lambda _x: "c")) + .sort("a") + ) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "aggs", "expected", "sort_by"), + [ + ( + [nwp.col("a").abs(), nwp.col("a").abs().alias("a_with_alias")], + [nwp.col("x").sum()], + {"a": [1, 2], "a_with_alias": [1, 2], "x": [5, 5]}, + ["a"], + ), + ( + [nwp.col("a").alias("x")], + [nwp.col("x").mean().alias("y")], + {"x": [-1, 1, 2], "y": [4.0, 0.5, 2.5]}, + ["x"], + ), + ( + [nwp.col("a")], + [nwp.col("a").count().alias("foo-bar"), nwp.all().sum()], + {"a": [-1, 1, 2], "foo-bar": [1, 2, 2], "x": [4, 1, 5], "y": [1.5, 0, 0]}, + ["a"], + ), + ( + [nwp.col("a", "y").abs()], + [nwp.col("x").sum()], + {"a": [1, 1, 2], "y": [0.5, 1.5, 1], "x": [1, 4, 5]}, + ["a", "y"], + ), + ( + [nwp.col("a").abs().alias("y")], + [nwp.all().sum().name.suffix("c")], + {"y": [1, 2], "ac": [1, 4], "xc": [5, 5]}, + ["y"], + ), + ( + [npcs.by_dtype(nw.Float64()).abs()], + [npcs.numeric().sum()], + {"y": [0.5, 1.0, 1.5], "a": [2, 4, -1], "x": [1, 5, 4]}, + ["y"], + ), + ], +) +def test_group_by_expr( + keys: list[nwp.Expr], + aggs: list[nwp.Expr], + expected: dict[str, list[Any]], + sort_by: list[str], +) -> None: + data = {"a": [1, 1, 2, 2, -1], "x": [0, 1, 2, 3, 4], "y": [0.5, -0.5, 1.0, -1.0, 1.5]} + df = dataframe(data) + result = df.group_by(*keys).agg(*aggs).sort(*sort_by) + assert_equal_data(result, expected) + + +def 
test_group_by_expr_2757684799() -> None: + """From [narwhals-dev/narwhals#2325-2757684799]. + + The **incorrect** result is: + + {'b': [2, 1], 'a': [2, 1], 'c': [2.0, 1.0]} + + [narwhals-dev/narwhals#2325-2757684799]: https://github.com/narwhals-dev/narwhals/pull/2325#pullrequestreview-2757684799 + """ + data: dict[str, Any] = {"a": [1, 1, 2], "b": [4, 5, 6], "unrelated": [10, -1, -9]} + df = dataframe(data) + keys = nwp.col("a").alias("b"), "a" + aggs = nwp.col("b").mean().alias("c") + expected = {"b": [2, 1], "a": [2, 1], "c": [6.0, 4.5]} + + result = df.group_by(keys).agg(aggs).sort("b", descending=True) + assert_equal_data(result, expected) + + +def test_group_by_selector() -> None: + data = { + "a": [1, 1, 1], + "b": [4, 4, 6], + "c": ["foo", "foo", "bar"], + "x": [7.5, 8.5, 9.0], + } + result = ( + dataframe(data) + .group_by(npcs.by_dtype(nw.Int64), "c") + .agg(nwp.col("x").mean()) + .sort("a", "b") + ) + expected = {"a": [1, 1], "b": [4, 6], "c": ["foo", "bar"], "x": [8.0, 9.0]} + assert_equal_data(result, expected) + + +def test_renaming_edge_case() -> None: + data = {"a": [0, 0, 0], "_a_tmp": [1, 2, 3], "b": [4, 5, 6]} + result = dataframe(data).group_by(nwp.col("a")).agg(nwp.all().min()) + expected = {"a": [0], "_a_tmp": [1], "b": [4]} + assert_equal_data(result, expected) + + +def test_group_by_len_1_column() -> None: + """Based on a failure from marimo. 
+ + - https://github.com/marimo-team/marimo/blob/036fd3ff89ef3a0e598bebb166637028024f98bc/tests/_plugins/ui/_impl/tables/test_narwhals.py#L1098-L1108 + - https://github.com/marimo-team/marimo/blob/036fd3ff89ef3a0e598bebb166637028024f98bc/marimo/_plugins/ui/_impl/tables/narwhals_table.py#L163-L188 + """ + data = {"a": [1, 2, 1, 2, 3, 4]} + expected = {"a": [1, 2, 3, 4], "len": [2, 2, 1, 1], "len_a": [2, 2, 1, 1]} + result = ( + dataframe(data).group_by("a").agg(nwp.len(), nwp.len().alias("len_a")).sort("a") + ) + assert_equal_data(result, expected) + + +def test_top_level_len() -> None: + # https://github.com/holoviz/holoviews/pull/6567#issuecomment-3178743331 + df = dataframe({"gender": ["m", "f", "f"], "weight": [4, 5, 6], "age": [None, 8, 9]}) + result = df.group_by(["gender"]).agg(nwp.all().len()).sort("gender") + expected = {"gender": ["f", "m"], "weight": [2, 1], "age": [2, 1]} + assert_equal_data(result, expected) + result = ( + df.group_by("gender") + .agg(nwp.col("weight").len(), nwp.col("age").len()) + .sort("gender") + ) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "aggs", "expected", "pre_sort"), + [ + (["a"], ["b"], {"a": [1, 2, 3, 4], "b": [1, 2, 4, 6]}, None), + (["a"], ["b"], {"a": [1, 2, 3, 4], "b": [1, 3, 5, 6]}, {"descending": True}), + (["a"], ["c"], {"a": [1, 2, 3, 4], "c": [None, "A", None, "B"]}, None), + ( + ["a"], + ["c"], + {"a": [1, 2, 3, 4], "c": [None, "A", "B", "B"]}, + {"nulls_last": True}, + ), + ], + ids=["no-sort", "sort-descending", "NA-order-nulls-first", "NA-order-nulls-last"], +) +def test_group_by_agg_first( + keys: Sequence[str], + aggs: Sequence[str], + expected: Mapping[str, Any], + pre_sort: Mapping[str, Any] | None, +) -> None: + data = { + "a": [1, 2, 2, 3, 3, 4], + "b": [1, 2, 3, 4, 5, 6], + "c": [None, "A", "A", None, "B", "B"], + } + df = dataframe(data) + if pre_sort: + df = df.sort(aggs, **pre_sort) + result = 
df.group_by(keys).agg(nwp.col(aggs).first()).sort(keys) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "aggs", "expected", "pre_sort"), + [ + (["a"], ["b"], {"a": [1, 2, 3, 4], "b": [1, 3, 5, 6]}, None), + (["a"], ["b"], {"a": [1, 2, 3, 4], "b": [1, 2, 4, 6]}, {"descending": True}), + (["a"], ["c"], {"a": [1, 2, 3, 4], "c": [None, "A", "B", "B"]}, None), + ( + ["a"], + ["c"], + {"a": [1, 2, 3, 4], "c": [None, "A", None, "B"]}, + {"nulls_last": True}, + ), + ], + ids=["no-sort", "sort-descending", "NA-order-nulls-first", "NA-order-nulls-last"], +) +def test_group_by_agg_last( + keys: Sequence[str], + aggs: Sequence[str], + expected: Mapping[str, Any], + pre_sort: Mapping[str, Any] | None, +) -> None: + data = { + "a": [1, 2, 2, 3, 3, 4], + "b": [1, 2, 3, 4, 5, 6], + "c": [None, "A", "A", None, "B", "B"], + } + df = dataframe(data) + if pre_sort: + df = df.sort(aggs, **pre_sort) + result = df.group_by(keys).agg(nwp.col(aggs).last()).sort(keys) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "aggs", "expected"), + [ + (["a"], [nwp.col("b").unique()], {"a": ["a", "b", "c"], "b": [[1], [2, 3], [3]]}), + ( + ["a"], + [nwp.col("b", "d").unique()], + { + "a": ["a", "b", "c"], + "b": [[1], [2, 3], [3]], + "d": [["three", "one"], ["three"], ["one"]], + }, + ), + ( + ["d", "c"], + [npcs.string().unique(), nwp.col("b").first().alias("b_first")], + { + "d": ["one", "one", "three", "three", "three"], + "c": [1, 3, 2, 4, 5], + "a": [["c"], ["a"], ["b"], ["b"], ["a"]], + "b_first": [3, 1, 3, 2, 1], + }, + ), + ], + ids=["Unique-Single", "Unique-Multi", "Unique-Selector-Fancy"], +) +def test_group_by_agg_unique( + keys: Sequence[str], aggs: Sequence[IntoExpr], expected: Mapping[str, Any] +) -> None: + data = { + "a": ["a", "b", "a", "b", "c"], + "b": [1, 2, 1, 3, 3], + "c": [5, 4, 3, 2, 1], + "d": ["three", "three", "one", "three", "one"], + } + df = dataframe(data) + result = df.group_by(keys).agg(aggs).sort(keys) + 
+    assert_equal_data(result, expected)
+
+
+def test_group_by_args() -> None:
+    """Adapted from [upstream].
+
+    [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L302-L325
+    """
+    data = {
+        "a": ["a", "b", "a", "b", "b", "c"],
+        "b": [1, 2, 3, 4, 5, 6],
+        "c": [6, 5, 4, 3, 2, 1],
+    }
+    df = dataframe(data)
+
+    # Single column name
+    assert df.group_by("a").agg("b").columns == ["a", "b"]
+    # Column names as list
+    expected = ["a", "b", "c"]
+    assert df.group_by(["a", "b"]).agg("c").columns == expected
+    # Column names as positional arguments
+    assert df.group_by("a", "b").agg("c").columns == expected
+    # With keyword argument
+    assert df.group_by("a", "b", drop_null_keys=True).agg("c").columns == expected
+    # Multiple aggregations as list
+    assert df.group_by("a").agg(["b", "c"]).columns == expected
+    # Multiple aggregations as positional arguments
+    assert df.group_by("a").agg("b", "c").columns == expected
+    # Multiple aggregations as keyword arguments
+    assert df.group_by("a").agg(q="b", r="c").columns == ["a", "q", "r"]
+
+
+def test_group_by_all() -> None:
+    """Adapted from [upstream].
+
+    [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L568-L577
+    """
+    data = {"a": [1, 2], "b": [1, 2]}
+    df = dataframe(data)
+    expected = {"a": [1, 2], "b": [1, 2], "a_agg": [1, 2]}
+    result = df.group_by(nwp.all()).agg(nwp.col("a").max().name.suffix("_agg")).sort("a")
+    assert_equal_data(result, expected)
+
+
+def test_group_by_input_independent_with_len_23868() -> None:
+    """Adapted from [upstream].
+
+    [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L1476-L1484
+    """
+    data = {"a": ["A", "B", "C"]}
+    expected = {"literal": ["G"], "len": [3]}
+    result = dataframe(data).group_by(nwp.lit("G")).agg(nwp.len())
+    assert_equal_data(result, expected)
+
+
+def test_group_by_series_lit_22103() -> None:
+    """Adapted from [upstream], but rejecting for now.
+
+    [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L1406-L1424
+    """
+    data = {"g": [0, 1]}
+    series = nwp.Series.from_native(pa.chunked_array([[42, 2, 3]]))
+    df = dataframe(data)
+    with pytest.raises(InvalidOperationError, match=re.escape("foo=lit(Series)")):
+        df.group_by("g").agg(foo=series)
+
+
+def test_group_by_named() -> None:
+    """Adapted from [upstream].
+
+    [upstream]: https://github.com/pola-rs/polars/blob/04dbc94c36f75ed05bb19587f2226e240ec1775f/py-polars/tests/unit/operations/test_group_by.py#L878-L884
+    """
+    data = {"a": [1, 1, 2, 2, 3, 3], "b": range(6)}
+    df = dataframe(data)
+    result = df.group_by(z=nwp.col("a") * 2).agg(nwp.col("b").min()).sort("b")
+    expected = (
+        df.group_by((nwp.col("a") * 2).alias("z")).agg(nwp.col("b").min()).sort("b")
+    )
+    assert_equal_data(result, expected.to_dict(as_series=False))
+
+
+def test_group_by_exclude_keys() -> None:
+    # `group_by(keys)` and `exclude` share some logic
+    data = {
+        "a": ["A", "B", "A"],
+        "b": [1, 2, 3],
+        "c": [9, 2, 4],
+        "d": [8, 7, 8],
+        "e": [None, 9, 7],
+        "f": [True, False, None],
+        "g": [False, None, False],
+        "h": [None, None, True],
+        "j": [12.1, None, 4.0],
+        "k": [42, 10, None],
+        "l": [4, 5, 6],
+        "m": [0, 1, 2],
+    }
+    df = dataframe(data).with_columns(
+        npcs.boolean().fill_null(False), npcs.numeric().fill_null(0)
+    )
+    exclude = "b", "c", "d", "e", "f", "g", "j", "k", "l", "m"
+    result = df.group_by(nwp.exclude(exclude)).agg(npcs.all().sum()).sort("a", "h")
+    expected = {
+        "a": ["A", "A", "B"],
+        "h": [False, True, False],
+        "b": [1, 3, 2],
+        "c": [9, 4, 2],
+        "d": [8, 8, 7],
+        "e": [0, 7, 9],
+        "f": [1, 0, 0],
+        "g": [0, 0, 0],
+        "j": [12.1, 4.0, 0.0],
+        "k": [42, 0, 10],
+        "l": [4, 6, 5],
+        "m": [0, 2, 1],
+    }
+    assert_equal_data(result, expected)
diff --git a/tests/plan/temp_test.py b/tests/plan/temp_test.py
new file mode 100644
index 0000000000..9dd7a0e42f
--- /dev/null
+++ b/tests/plan/temp_test.py
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+import random
+import re
+import string
+
+# ruff: noqa: S311
+from collections import deque
+from itertools import islice, product, repeat
+from typing import TYPE_CHECKING, NamedTuple
+
+import hypothesis.strategies as st
+import pytest
+from hypothesis import given
+
+import narwhals as nw
+from narwhals._plan.common import temp
+from narwhals._utils import qualified_type_name
+from narwhals.exceptions import NarwhalsError
+
+pytest.importorskip("pyarrow")
+pytest.importorskip("polars")
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable, Sequence
+
+    from narwhals._utils import _StoresColumns
+
+
+class MockStoresColumns(NamedTuple):
+    columns: Sequence[str]
+
+
+_COLUMNS = ("abc", "XYZ", "nw2929023", "column", string.hexdigits)
+_EMPTY_SCHEMA = nw.Schema((name, nw.Int64()) for name in _COLUMNS)
+
+
+sources = pytest.mark.parametrize(
+    "source",
+    [
+        _COLUMNS,
+        MockStoresColumns(columns=_COLUMNS),
+        deque(_COLUMNS),
+        nw.from_dict({}, _EMPTY_SCHEMA, backend="pyarrow"),
+        dict.fromkeys(_COLUMNS),
+        set(_COLUMNS),
+        nw.from_dict({}, _EMPTY_SCHEMA, backend="polars").to_native(),
+    ],
+    ids=qualified_type_name,
+)
+
+
+@sources
+def test_temp_column_name_sources(source: _StoresColumns | Iterable[str]) -> None:
+    name = temp.column_name(source)
+    assert name not in _COLUMNS
+
+
+@sources
+def test_temp_column_names_sources(source: _StoresColumns | Iterable[str]) -> None:
+    it = temp.column_names(source)
+    name = next(it)
+    assert name not in _COLUMNS
+
+
+@given(n_chars=st.integers(6, 106))
+@pytest.mark.slow
+def test_temp_column_name_n_chars(n_chars: int) -> None:
+    name = temp.column_name(_COLUMNS, n_chars=n_chars)
+    assert name not in _COLUMNS
+
+
+@given(n_new_names=st.integers(10_000, 100_000))
+@pytest.mark.slow
+def test_temp_column_names_always_new_names(n_new_names: int) -> None:
+    it = temp.column_names(_COLUMNS)
+    new_names = set(islice(it, n_new_names))
+    assert len(new_names) == n_new_names
+    assert new_names.isdisjoint(_COLUMNS)
+
+
+@pytest.mark.parametrize(
+    ("prefix", "n_chars"),
+    [
+        ("nw", random.randint(0, 5)),
+        ("col", random.randint(0, 4)),
+        ("NW_", random.randint(0, 3)),
+        ("join", random.randint(0, 2)),
+        ("__tmp", random.randint(0, 1)),
+        ("longer", random.randint(-5, 0)),
+        ("", random.randint(0, 5)),
+    ],
+)
+def test_temp_column_name_requires_more_characters(prefix: str, n_chars: int) -> None:
+    pattern = re.compile(
+        rf"temp.+column.+name.+requires.+try.+shorter.+{prefix}.+higher.+{n_chars}",
+        re.IGNORECASE | re.DOTALL,
+    )
+    with pytest.raises(NarwhalsError, match=pattern):
+        temp.column_name(_COLUMNS, prefix=prefix, n_chars=n_chars)
+
+
+def test_temp_column_name_failed_unique() -> None:
+    hex_lower = string.hexdigits.strip(string.ascii_uppercase)
+    every_possible_name_65k = [
+        f"nw{e1}{e2}{e3}{e4}" for e1, e2, e3, e4 in product(*repeat(hex_lower, 4))
+    ]
+    n_many_columns = len(every_possible_name_65k)
+
+    pattern = re.compile(
+        rf"unable.+generate.+name.+n_chars=6.+within.+existing.+{n_many_columns}.+columns",
+        re.DOTALL,
+    )
+    with pytest.raises(NarwhalsError, match=pattern):
+        temp.column_name(every_possible_name_65k, prefix="nw", n_chars=6)
+
+
+def test_temp_column_names_failed_unique() -> None:
+    it = temp.column_names(["a", "b", "c"], prefix="long_prefix", n_chars=16)
+    pattern = re.compile(
+        r"unable.+generate.+name.+n_chars=16.+within.+existing.+.+columns.+\.\.\.",
+        re.DOTALL,
+    )
+    with pytest.raises(NarwhalsError, match=pattern):
+        list(islice(it, 100_000))
diff --git a/tests/plan/utils.py b/tests/plan/utils.py
index bf6135ee2f..d1ae2ce95e 100644
--- a/tests/plan/utils.py
+++ b/tests/plan/utils.py
@@ -36,9 +36,18 @@ def assert_expr_ir_equal(
     """
     lhs = _unwrap_ir(actual)
     if isinstance(expected, str):
-        assert repr(lhs) == expected
+        assert repr(lhs) == expected, (
+            f"\nlhs:\n {lhs!r}\n\nexpected:\n {expected!r}"
+        )
     elif isinstance(actual, ir.NamedIR) and isinstance(expected, ir.NamedIR):
-        assert actual == expected
+        assert actual == expected, (
+            f"\nactual:\n {actual!r}\n\nexpected:\n {expected!r}"
+        )
     else:
         rhs = expected._ir if isinstance(expected, nwp.Expr) else expected
-        assert lhs == rhs
+        assert lhs == rhs, f"\nlhs:\n {lhs!r}\n\nrhs:\n {rhs!r}"
+
+
+def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> ir.NamedIR[ir.ExprIR]:
+    """Helper constructor for test compare."""
+    return ir.NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name)