diff --git a/_typos.toml b/_typos.toml index 74e66bbe0a..8f80cbcf3d 100644 --- a/_typos.toml +++ b/_typos.toml @@ -5,6 +5,7 @@ ba = "ba" # Used as column name in docstring examples (way too much?) iy = "iy" # Used as column name (once in a test) pn = "pn" # Used in docs: pn = PandasLikeNamespace(...) TYP = "TYP" # Used in flake8 rule +arange = "arange" # Used in numpy, polars, pyarrow [files] extend-exclude = ["tests/data/*"] diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index abfeb6c1b0..d6e2f4cdaf 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -45,6 +45,8 @@ binary_expr_multi_output_error, column_not_found_error, duplicate_error, + expand_multi_output_error, + selectors_not_found_error, ) from narwhals._plan.expressions import ( Alias, @@ -57,7 +59,6 @@ from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema, freeze_schema from narwhals._typing_compat import assert_never from narwhals._utils import check_column_names_are_unique, zip_strict -from narwhals.exceptions import MultiOutputExpressionError if TYPE_CHECKING: from collections.abc import Collection, Iterable, Iterator, Sequence @@ -99,7 +100,12 @@ def prepare_projection( def expand_selector_irs_names( - selectors: Sequence[SelectorIR], /, ignored: Ignored = (), *, schema: IntoFrozenSchema + selectors: Sequence[SelectorIR], + /, + ignored: Ignored = (), + *, + schema: IntoFrozenSchema, + require_any: bool = False, ) -> OutputNames: """Expand selector-only input into the column names that match. @@ -110,11 +116,15 @@ def expand_selector_irs_names( selectors: IRs that **only** contain subclasses of `SelectorIR`. ignored: Names of `group_by` columns. schema: Scope to expand selectors in. + require_any: Raise if the entire expansion selected zero columns. 
""" - names = tuple(Expander(schema, ignored).iter_expand_selector_names(selectors)) - if len(names) != len(set(names)): - # NOTE: Can't easily reuse `duplicate_error`, falling back to main for now - check_column_names_are_unique(names) + expander = Expander(schema, ignored) + if names := tuple(expander.iter_expand_selector_names(selectors)): + if len(names) != len(set(names)): + # NOTE: Can't easily reuse `duplicate_error`, falling back to main for now + check_column_names_are_unique(names) + elif require_any: + raise selectors_not_found_error(selectors, expander.schema) return names @@ -245,15 +255,14 @@ def _expand_inner(self, children: Seq[ExprIR], /) -> Iterator[ExprIR]: for child in children: yield from self._expand_recursive(child) - def _expand_only(self, child: ExprIR, /) -> ExprIR: + def _expand_only(self, origin: ExprIR, child: ExprIR, /) -> ExprIR: # used by # - `_expand_combination` (ExprIR fields) # - `_expand_function_expr` (all others that have len(inputs)>=2, call on non-root) iterable = self._expand_recursive(child) first = next(iterable) if second := next(iterable, None): - msg = f"Multi-output expressions are not supported in this context, got: `{second!r}`" # pragma: no cover - raise MultiOutputExpressionError(msg) # pragma: no cover + raise expand_multi_output_error(origin, child, first, second, *iterable) return first # TODO @dangotbanned: It works, but all this class-specific branching belongs in the classes themselves @@ -268,16 +277,16 @@ def _expand_combination(self, origin: Combination, /) -> Iterator[Combination]: elif isinstance(origin, ir.SortBy): changes["by"] = tuple(self._expand_inner(origin.by)) else: - changes["by"] = self._expand_only(origin.by) + changes["by"] = self._expand_only(origin, origin.by) replaced = common.replace(origin, **changes) for root in self._expand_recursive(replaced.expr): yield common.replace(replaced, expr=root) elif isinstance(origin, ir.BinaryExpr): yield from self._expand_binary_expr(origin) elif 
isinstance(origin, ir.TernaryExpr): - changes["truthy"] = self._expand_only(origin.truthy) - changes["predicate"] = self._expand_only(origin.predicate) - changes["falsy"] = self._expand_only(origin.falsy) + changes["truthy"] = self._expand_only(origin, origin.truthy) + changes["predicate"] = self._expand_only(origin, origin.predicate) + changes["falsy"] = self._expand_only(origin, origin.falsy) yield origin.__replace__(**changes) else: assert_never(origin) @@ -316,7 +325,7 @@ def _expand_function_expr( yield origin.__replace__(input=reduced) else: if non_root := origin.input[1:]: - children = tuple(self._expand_only(child) for child in non_root) + children = tuple(self._expand_only(origin, child) for child in non_root) else: children = () for root in self._expand_recursive(origin.input[0]): diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index da6b84880d..8231af57ac 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -76,7 +76,7 @@ def is_selector(obj: Any) -> TypeIs[Selector]: return isinstance(obj, _selectors().Selector) -def is_column(obj: Any) -> TypeIs[Expr]: +def is_expr_column(obj: Any) -> TypeIs[Expr]: """Indicate if the given object is a basic/unaliased column.""" return is_expr(obj) and obj.meta.is_column() @@ -136,8 +136,19 @@ def is_literal(obj: Any) -> TypeIs[ir.Literal[Any]]: # TODO @dangotbanned: Coverage # Used in `ArrowNamespace._vertical`, but only horizontal is covered def is_tuple_of(obj: Any, tp: type[T]) -> TypeIs[Seq[T]]: # pragma: no cover + """Return True if the **first** element of the tuple `obj` is an instance of `tp`.""" return bool(isinstance(obj, tuple) and obj and isinstance(obj[0], tp)) def is_re_pattern(obj: Any) -> TypeIs[re.Pattern[str]]: return isinstance(obj, re.Pattern) + + +def is_seq_column(exprs: Seq[ir.ExprIR], /) -> TypeIs[Seq[ir.Column]]: + """Return True if **every** element is a `Column`. 
+ + Use this for detecting fastpaths in sub-expressions, that can rely on + every element in `exprs` having a resolved `name` attribute. + """ + Column = _ir().Column # noqa: N806 + return all(isinstance(e, Column) for e in exprs) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index 990a688d98..541d38671a 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -25,8 +25,7 @@ import pyarrow.compute as pc # ignore-banned-import from pyarrow.acero import Declaration as Decl -from narwhals._plan.common import ensure_list_str, flatten_hash_safe, temp -from narwhals._plan.options import SortMultipleOptions +from narwhals._plan.common import ensure_list_str, temp from narwhals._plan.typing import NonCrossJoinStrategy, OneOrSeq from narwhals._utils import check_column_names_are_unique from narwhals.typing import JoinStrategy, SingleColSelector @@ -48,13 +47,8 @@ Aggregation as _Aggregation, ) from narwhals._plan.arrow.group_by import AggSpec - from narwhals._plan.arrow.typing import ( - ArrowAny, - JoinTypeSubset, - NullPlacement, - ScalarAny, - ) - from narwhals._plan.typing import OneOrIterable, Order, Seq + from narwhals._plan.arrow.typing import ArrowAny, JoinTypeSubset, ScalarAny + from narwhals._plan.typing import OneOrIterable, Seq from narwhals.typing import NonNestedLiteral Incomplete: TypeAlias = Any @@ -238,29 +232,6 @@ def prepend_column(native: pa.Table, name: str, values: IntoExpr) -> Decl: return _add_column(native, 0, name, values) -def _order_by( - sort_keys: Iterable[tuple[str, Order]] = (), - *, - null_placement: NullPlacement = "at_end", -) -> Decl: - # NOTE: There's no runtime type checking of `sort_keys` wrt shape - # Just need to be `Iterable`and unpack like a 2-tuple - # https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_compute.pyx#L77-L88 - keys: Incomplete = sort_keys - return Decl("order_by", pac.OrderByNodeOptions(keys, 
null_placement=null_placement)) - - -def sort_by( - by: OneOrIterable[str], - *more_by: str, - descending: OneOrIterable[bool] = False, - nulls_last: bool = False, -) -> Decl: - return SortMultipleOptions.parse( - descending=descending, nulls_last=nulls_last - ).to_arrow_acero(tuple(flatten_hash_safe((by, more_by)))) - - def _join_options( how: NonCrossJoinStrategy, left_on: OneOrIterable[str], @@ -406,6 +377,38 @@ def join_tables( return collect(_hashjoin(left, right, opts), ensure_unique_column_names=True) +# TODO @dangotbanned: Adapt this into a `Declaration` that handles more of `ArrowGroupBy.agg_over` +def join_inner_tables(left: pa.Table, right: pa.Table, on: list[str]) -> pa.Table: + """Fast path for use with `over`. + + Has almost zero branching and the bodys of helper functions are inlined. + + Eventually want to adapt this into: + + goal = declare( + join_inner( + declare(table_source(compliant.native), select_names(key_names)), + declare(table_source(ordered), group_by(key_names, specs)), + ), + select_names(agg_names), + ) + """ + tp: Incomplete = pac.HashJoinNodeOptions + opts = tp( + "inner", + left_keys=on, + right_keys=on, + left_output=left.schema.names, + right_output=(name for name in right.schema.names if name not in on), + output_suffix_for_right="_right", + ) + lhs, rhs = pac.TableSourceNodeOptions(left), pac.TableSourceNodeOptions(right) + decl = Decl("hashjoin", opts, [Decl("table_source", lhs), Decl("table_source", rhs)]) + result = decl.to_table() + check_column_names_are_unique(result.column_names) + return result + + def join_cross_tables( left: pa.Table, right: pa.Table, suffix: str = "_right", *, coalesce_keys: bool = True ) -> pa.Table: diff --git a/narwhals/_plan/arrow/common.py b/narwhals/_plan/arrow/common.py new file mode 100644 index 0000000000..fdbe173f2c --- /dev/null +++ b/narwhals/_plan/arrow/common.py @@ -0,0 +1,61 @@ +"""Behavior shared by two or more classes.""" + +from __future__ import annotations + +from typing import 
TYPE_CHECKING, Any, ClassVar, Generic + +from narwhals._plan.arrow.functions import BACKEND_VERSION +from narwhals._typing_compat import TypeVar +from narwhals._utils import Implementation, Version, _StoresNative + +if TYPE_CHECKING: + import pyarrow as pa + from typing_extensions import Self, TypeIs + + from narwhals._plan.arrow.namespace import ArrowNamespace + from narwhals._plan.arrow.typing import ChunkedArrayAny, Indices + + +def is_series(obj: Any) -> TypeIs[_StoresNative[ChunkedArrayAny]]: + from narwhals._plan.arrow.series import ArrowSeries + + return isinstance(obj, ArrowSeries) + + +NativeT = TypeVar("NativeT", "pa.Table", "ChunkedArrayAny") + + +class ArrowFrameSeries(Generic[NativeT]): + implementation: ClassVar = Implementation.PYARROW + _native: NativeT + _version: Version + + @property + def native(self) -> NativeT: + return self._native + + def __narwhals_namespace__(self) -> ArrowNamespace: + from narwhals._plan.arrow.namespace import ArrowNamespace + + return ArrowNamespace(self._version) + + def _with_native(self, native: NativeT) -> Self: + msg = f"{type(self).__name__}._with_native" + raise NotImplementedError(msg) + + if BACKEND_VERSION >= (18,): + + def _gather(self, indices: Indices) -> NativeT: + return self.native.take(indices) + else: + + def _gather(self, indices: Indices) -> NativeT: + rows = list(indices) if isinstance(indices, tuple) else indices + return self.native.take(rows) + + def gather(self, indices: Indices | _StoresNative[ChunkedArrayAny]) -> Self: + ca = self._gather(indices.native if is_series(indices) else indices) + return self._with_native(ca) + + def slice(self, offset: int, length: int | None = None) -> Self: + return self._with_native(self.native.slice(offset=offset, length=length)) diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 03043b0e3e..2c0b2e37e8 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -10,13 +10,14 @@ from 
narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow import acero, functions as fn +from narwhals._plan.arrow.common import ArrowFrameSeries as FrameSeries from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy, partition_by from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.compliant.dataframe import EagerDataFrame from narwhals._plan.compliant.typing import namespace from narwhals._plan.expressions import NamedIR -from narwhals._utils import Implementation, Version +from narwhals._utils import Version, generate_repr from narwhals.schema import Schema if TYPE_CHECKING: @@ -25,8 +26,8 @@ import polars as pl from typing_extensions import Self - from narwhals._arrow.typing import ChunkedArrayAny - from narwhals._plan.arrow.namespace import ArrowNamespace + from narwhals._plan.arrow.typing import ChunkedArrayAny + from narwhals._plan.compliant.group_by import GroupByResolver from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.options import SortMultipleOptions from narwhals._plan.typing import NonCrossJoinStrategy @@ -34,20 +35,22 @@ from narwhals.typing import IntoSchema -class ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]): - implementation = Implementation.PYARROW - _native: pa.Table - _version: Version +class ArrowDataFrame( + FrameSeries["pa.Table"], EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"] +): + def __repr__(self) -> str: + return generate_repr(f"nw.{type(self).__name__}", self.native.__repr__()) - def __narwhals_namespace__(self) -> ArrowNamespace: - from narwhals._plan.arrow.namespace import ArrowNamespace - - return ArrowNamespace(self._version) + def _with_native(self, native: pa.Table) -> Self: + return self.from_native(native, self.version) @property def _group_by(self) -> type[GroupBy]: return GroupBy + def group_by_resolver(self, resolver: 
GroupByResolver, /) -> GroupBy: + return self._group_by.from_resolver(self, resolver) + @property def columns(self) -> list[str]: return self.native.column_names @@ -98,14 +101,26 @@ def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[Series] from_named_ir = ns._expr.from_named_ir yield from ns._expr.align(from_named_ir(e, self) for e in nodes) - def sort(self, by: Sequence[str], options: SortMultipleOptions) -> Self: - native = self.native - indices = pc.sort_indices(native.select(list(by)), options=options.to_arrow(by)) - return self._with_native(native.take(indices)) + def sort(self, by: Sequence[str], options: SortMultipleOptions | None = None) -> Self: + return self.gather(fn.sort_indices(self.native, *by, options=options)) def with_row_index(self, name: str) -> Self: return self._with_native(self.native.add_column(0, name, fn.int_range(len(self)))) + def with_row_index_by( + self, + name: str, + order_by: Sequence[str], + *, + descending: bool = False, + nulls_last: bool = False, + ) -> Self: + indices = fn.sort_indices( + self.native, *order_by, nulls_last=nulls_last, descending=descending + ) + column = fn.unsort_indices(indices) + return self._with_native(self.native.add_column(0, name, column)) + def get_column(self, name: str) -> Series: chunked = self.native.column(name) return Series.from_native(chunked, name, version=self.version) @@ -168,6 +183,10 @@ def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: result = acero.join_cross_tables(self.native, other.native, suffix=suffix) return self._with_native(result) + def join_inner(self, other: Self, on: list[str], /) -> Self: + """Less flexible, but more direct equivalent to join(how="inner", left_on=...)`.""" + return self._with_native(acero.join_inner_tables(self.native, other.native, on)) + def filter(self, predicate: NamedIR) -> Self: mask: pc.Expression | ChunkedArrayAny resolved = Expr.from_named_ir(predicate, self) diff --git a/narwhals/_plan/arrow/expr.py 
b/narwhals/_plan/arrow/expr.py index b34cf963ca..31181b7a95 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -7,6 +7,7 @@ from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._plan import expressions as ir +from narwhals._plan._guards import is_function_expr, is_seq_column from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co @@ -15,18 +16,26 @@ from narwhals._plan.compliant.expr import EagerExpr from narwhals._plan.compliant.scalar import EagerScalar from narwhals._plan.compliant.typing import namespace +from narwhals._plan.expressions.boolean import ( + IsFirstDistinct, + IsInExpr, + IsInSeq, + IsInSeries, + IsLastDistinct, +) from narwhals._plan.expressions.functions import NullCount from narwhals._utils import Implementation, Version, _StoresNative, not_implemented from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Sequence from typing_extensions import Self, TypeAlias - from narwhals._arrow.typing import ChunkedArrayAny, Incomplete + from narwhals._arrow.typing import Incomplete from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace + from narwhals._plan.arrow.typing import ChunkedArrayAny, P, VectorFunction from narwhals._plan.expressions.aggregation import ( ArgMax, ArgMin, @@ -48,13 +57,11 @@ All, IsBetween, IsFinite, - IsFirstDistinct, - IsLastDistinct, IsNan, IsNull, Not, ) - from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr + from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr as FExpr from narwhals._plan.expressions.functions import ( Abs, CumAgg, @@ -62,8 +69,10 @@ FillNull, NullCount, Pow, + Rank, Shift, ) + from narwhals._plan.typing import 
Seq from narwhals.typing import Into1DArray, IntoDType, PythonLiteral Expr: TypeAlias = "ArrowExpr" @@ -87,14 +96,14 @@ def cast(self, node: ir.Cast, frame: Frame, name: str) -> StoresNativeT_co: native = node.expr.dispatch(self, frame, name).native return self._with_native(fn.cast(native, data_type), name) - def pow(self, node: FunctionExpr[Pow], frame: Frame, name: str) -> StoresNativeT_co: + def pow(self, node: FExpr[Pow], frame: Frame, name: str) -> StoresNativeT_co: base, exponent = node.function.unwrap_input(node) base_ = base.dispatch(self, frame, "base").native exponent_ = exponent.dispatch(self, frame, "exponent").native return self._with_native(pc.power(base_, exponent_), name) def fill_null( - self, node: FunctionExpr[FillNull], frame: Frame, name: str + self, node: FExpr[FillNull], frame: Frame, name: str ) -> StoresNativeT_co: expr, value = node.function.unwrap_input(node) native = expr.dispatch(self, frame, name).native @@ -102,7 +111,7 @@ def fill_null( return self._with_native(pc.fill_null(native, value_), name) def is_between( - self, node: FunctionExpr[IsBetween], frame: Frame, name: str + self, node: FExpr[IsBetween], frame: Frame, name: str ) -> StoresNativeT_co: expr, lower_bound, upper_bound = node.function.unwrap_input(node) native = expr.dispatch(self, frame, name).native @@ -113,40 +122,60 @@ def is_between( def _unary_function( self, fn_native: Callable[[Any], Any], / - ) -> Callable[[FunctionExpr[Any], Frame, str], StoresNativeT_co]: - def func(node: FunctionExpr[Any], frame: Frame, name: str) -> StoresNativeT_co: + ) -> Callable[[FExpr[Any], Frame, str], StoresNativeT_co]: + def func(node: FExpr[Any], frame: Frame, name: str) -> StoresNativeT_co: native = node.input[0].dispatch(self, frame, name).native return self._with_native(fn_native(native), name) return func - def abs(self, node: FunctionExpr[Abs], frame: Frame, name: str) -> StoresNativeT_co: + def abs(self, node: FExpr[Abs], frame: Frame, name: str) -> StoresNativeT_co: return 
self._unary_function(pc.abs)(node, frame, name) - def not_(self, node: FunctionExpr[Not], frame: Frame, name: str) -> StoresNativeT_co: + def not_(self, node: FExpr[Not], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(pc.invert)(node, frame, name) - def all(self, node: FunctionExpr[All], frame: Frame, name: str) -> StoresNativeT_co: + def all(self, node: FExpr[All], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(fn.all_)(node, frame, name) def any( - self, node: FunctionExpr[ir.boolean.Any], frame: Frame, name: str + self, node: FExpr[ir.boolean.Any], frame: Frame, name: str ) -> StoresNativeT_co: return self._unary_function(fn.any_)(node, frame, name) def is_finite( - self, node: FunctionExpr[IsFinite], frame: Frame, name: str + self, node: FExpr[IsFinite], frame: Frame, name: str ) -> StoresNativeT_co: return self._unary_function(fn.is_finite)(node, frame, name) - def is_nan( - self, node: FunctionExpr[IsNan], frame: Frame, name: str + def is_in_expr( + self, node: FExpr[IsInExpr], frame: Frame, name: str ) -> StoresNativeT_co: - return self._unary_function(fn.is_nan)(node, frame, name) + expr, other = node.function.unwrap_input(node) + right = other.dispatch(self, frame, name).native + if isinstance(right, pa.Scalar): + right = fn.array(right) + result = fn.is_in(expr.dispatch(self, frame, name).native, right) + return self._with_native(result, name) + + def is_in_series( + self, node: FExpr[IsInSeries[ChunkedArrayAny]], frame: Frame, name: str + ) -> StoresNativeT_co: + native = node.input[0].dispatch(self, frame, name).native + other = node.function.other.unwrap().to_native() + return self._with_native(fn.is_in(native, other), name) - def is_null( - self, node: FunctionExpr[IsNull], frame: Frame, name: str + def is_in_seq( + self, node: FExpr[IsInSeq], frame: Frame, name: str ) -> StoresNativeT_co: + native = node.input[0].dispatch(self, frame, name).native + other = fn.array(node.function.other) + return 
self._with_native(fn.is_in(native, other), name) + + def is_nan(self, node: FExpr[IsNan], frame: Frame, name: str) -> StoresNativeT_co: + return self._unary_function(fn.is_nan)(node, frame, name) + + def is_null(self, node: FExpr[IsNull], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(fn.is_null)(node, frame, name) def binary_expr(self, node: BinaryExpr, frame: Frame, name: str) -> StoresNativeT_co: @@ -166,6 +195,14 @@ def ternary_expr( result = pc.if_else(when.native, then.native, otherwise.native) return self._with_native(result, name) + exp = not_implemented() # type: ignore[misc] + log = not_implemented() # type: ignore[misc] + sqrt = not_implemented() # type: ignore[misc] + round = not_implemented() # type: ignore[misc] + clip = not_implemented() # type: ignore[misc] + drop_nulls = not_implemented() # type: ignore[misc] + replace_strict = not_implemented() # type: ignore[misc] + class ArrowExpr( # type: ignore[misc] _ArrowDispatch["ArrowExpr | ArrowScalar"], @@ -216,6 +253,15 @@ def _dispatch_expr(self, node: ir.ExprIR, frame: Frame, name: str) -> Series: """ return node.dispatch(self, frame, name).to_series() + def _vector_function( + self, fn_native: VectorFunction[P], *args: P.args, **kwds: P.kwargs + ) -> Callable[[FExpr[Any], Frame, str], Self]: + def func(node: FExpr[Any], frame: Frame, name: str, /) -> Self: # type: ignore[type-var, misc] + native = self._dispatch_expr(node.input[0], frame, name).native + return self._with_native(fn_native(native, *args, **kwds), name) + + return func + @property def native(self) -> ChunkedArrayAny: return self._evaluated.native @@ -233,19 +279,25 @@ def __len__(self) -> int: return len(self._evaluated) def sort(self, node: ir.Sort, frame: Frame, name: str) -> Expr: - native = self._dispatch_expr(node.expr, frame, name).native - sorted_indices = pc.array_sort_indices(native, options=node.options.to_arrow()) - return self._with_native(native.take(sorted_indices), name) + series = 
self._dispatch_expr(node.expr, frame, name) + opts = node.options + result = series.sort(descending=opts.descending, nulls_last=opts.nulls_last) + return self.from_series(result) def sort_by(self, node: ir.SortBy, frame: Frame, name: str) -> Expr: + if is_seq_column(node.by): + # fastpath, roughly the same as `DataFrame.sort`, but only taking indices + # of a single column + keys: Sequence[str] = tuple(e.name for e in node.by) + df = frame + else: + it_names = temp.column_names(frame) + by = (self._dispatch_expr(e, frame, nm) for e, nm in zip(node.by, it_names)) + df = namespace(self)._concat_horizontal(by) + keys = df.columns + indices = fn.sort_indices(df.native, *keys, options=node.options) series = self._dispatch_expr(node.expr, frame, name) - it_names = temp.column_names(frame) - by = (self._dispatch_expr(e, frame, nm) for e, nm in zip(node.by, it_names)) - df = namespace(self)._concat_horizontal((series, *by)) - names = df.columns[1:] - indices = pc.sort_indices(df.native, options=node.options.to_arrow(names)) - result: ChunkedArrayAny = df.native.column(0).take(indices) - return self._with_native(result, name) + return self.from_series(series.gather(indices)) def filter(self, node: ir.Filter, frame: Frame, name: str) -> Expr: return self._with_native( @@ -329,56 +381,77 @@ def min(self, node: Min, frame: Frame, name: str) -> Scalar: result: NativeScalar = fn.min_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) + def null_count(self, node: FExpr[NullCount], frame: Frame, name: str) -> Scalar: + native = self._dispatch_expr(node.input[0], frame, name).native + return self._with_native(fn.null_count(native), name) + # TODO @dangotbanned: top-level, complex-ish nodes # - [ ] Over # - [x] `over_ordered` # - [x] `group_by`, `join` # - [x] `over` (with partitions) - # - [ ] `over_ordered` (with partitions) + # - [x] `over_ordered` (with partitions) + # - [ ] fix: join on nulls after 
https://github.com/narwhals-dev/narwhals/issues/3300 # - [ ] `map_batches` # - [x] elementwise # - [ ] scalar # - [ ] `rolling_expr` has 4 variants - def over(self, node: ir.WindowExpr, frame: Frame, name: str) -> Self: - resolved = ( - frame._grouper.by_irs(*node.partition_by) - # TODO @dangotbanned: Clean this up so the re-alias isn't needed - .agg_irs(node.expr.alias(name)) - .resolve(frame) - ) - by_names = resolved.key_names - result = ( - frame.select_names(*by_names) - .join(resolved.evaluate(frame), how="left", left_on=by_names) - .get_column(name) - .native - ) - return self._with_native(result, name) + def over( + self, + node: ir.WindowExpr, + frame: Frame, + name: str, + *, + sort_indices: pa.UInt64Array | None = None, + ) -> Self: + expr = node.expr + by = node.partition_by + if is_function_expr(expr) and isinstance( + expr.function, (IsFirstDistinct, IsLastDistinct) + ): + return self._is_first_last_distinct( + expr, frame, name, by, sort_indices=sort_indices + ) + resolved = frame._grouper.by_irs(*by).agg_irs(expr.alias(name)).resolve(frame) + results = frame.group_by_resolver(resolved).agg_over(resolved.aggs, sort_indices) + return self.from_series(results.get_column(name)) def over_ordered( self, node: ir.OrderedWindowExpr, frame: Frame, name: str ) -> Self | Scalar: + by = node.order_by_names() + indices = fn.sort_indices(frame.native, *by, options=node.sort_options) if node.partition_by: - msg = f"Need to implement `group_by`, `join` for:\n{node!r}" - raise NotImplementedError(msg) + return self.over(node, frame, name, sort_indices=indices) + evaluated = node.expr.dispatch(self, frame.gather(indices), name) + if isinstance(evaluated, ArrowScalar): + return evaluated + return self.from_series(evaluated.broadcast(len(frame)).gather(indices)) - # NOTE: Converting `over(order_by=..., options=...)` into the right shape for `DataFrame.sort` - sort_by = tuple(node.order_by_names()) - options = node.sort_options.to_multiple(len(sort_by)) + def 
_is_first_last_distinct( + self, + node: FExpr[IsFirstDistinct | IsLastDistinct], + frame: Frame, + name: str, + partition_by: Seq[ir.ExprIR] = (), + *, + sort_indices: pa.UInt64Array | None = None, + ) -> Self: idx_name = temp.column_name(frame) - sorted_context = frame.with_row_index(idx_name).sort(sort_by, options) - evaluated = node.expr.dispatch(self, sorted_context.drop([idx_name]), name) - if isinstance(evaluated, ArrowScalar): - # NOTE: We're already sorted, defer broadcasting to the outer context - # Wouldn't be suitable for partitions, but will be fine here - # - https://github.com/narwhals-dev/narwhals/pull/2528/commits/2ae42458cae91f4473e01270919815fcd7cb9667 - # - https://github.com/narwhals-dev/narwhals/pull/2528/commits/b8066c4c57d4b0b6c38d58a0f5de05eefc2cae70 - return self._with_native(evaluated.native, name) - indices = pc.sort_indices(sorted_context.get_column(idx_name).native) - height = len(sorted_context) - result = evaluated.broadcast(height).native.take(indices) - return self._with_native(result, name) + df = frame._with_columns([node.input[0].dispatch(self, frame, name)]) + if sort_indices is not None: + column = fn.unsort_indices(sort_indices) + df = df._with_native(df.native.add_column(0, idx_name, column)) + else: + df = df.with_row_index(idx_name) + agg = fn.IS_FIRST_LAST_DISTINCT[type(node.function)](idx_name) + if not (partition_by or sort_indices is not None): + distinct = df.group_by_names((name,)).agg((ir.named_ir(idx_name, agg),)) + else: + distinct = df.group_by_agg_irs((ir.col(name), *partition_by), agg) + index = df.to_series().alias(name) + return self.from_series(index.is_in(distinct.get_column(idx_name))) # NOTE: Can't implement in `EagerExpr`, since it doesn't derive `ExprDispatch` def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: @@ -389,7 +462,7 @@ def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: series = self._dispatch_expr(node.input[0], frame, 
name) udf = node.function.function result: Series | Into1DArray = udf(series) - if not fn.is_series(result): + if not isinstance(result, Series): result = Series.from_numpy(result, name, version=self.version) if dtype := node.function.return_dtype: result = result.cast(dtype) @@ -398,50 +471,43 @@ def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: def rolling_expr(self, node: ir.RollingExpr, frame: Frame, name: str) -> Self: raise NotImplementedError - def shift(self, node: ir.FunctionExpr[Shift], frame: Frame, name: str) -> Self: - series = self._dispatch_expr(node.input[0], frame, name) - return self._with_native(fn.shift(series.native, node.function.n), name) + def shift(self, node: FExpr[Shift], frame: Frame, name: str) -> Self: + return self._vector_function(fn.shift, node.function.n)(node, frame, name) - def diff(self, node: ir.FunctionExpr[Diff], frame: Frame, name: str) -> Self: - series = self._dispatch_expr(node.input[0], frame, name) - return self._with_native(fn.diff(series.native), name) + def diff(self, node: FExpr[Diff], frame: Frame, name: str) -> Self: + return self._vector_function(fn.diff)(node, frame, name) - def _cumulative(self, node: ir.FunctionExpr[CumAgg], frame: Frame, name: str) -> Self: - series = self._dispatch_expr(node.input[0], frame, name) - return self._with_native(fn.cumulative(series.native, node.function), name) + def rank(self, node: FExpr[Rank], frame: Frame, name: str) -> Self: + return self._vector_function(fn.rank, node.function.options)(node, frame, name) + + def _cumulative(self, node: FExpr[CumAgg], frame: Frame, name: str) -> Self: + native = self._dispatch_expr(node.input[0], frame, name).native + func = fn.CUMULATIVE[type(node.function)] + if not node.function.reverse: + result = func(native) + else: + result = fn.reverse(func(fn.reverse(native))) + return self._with_native(result, name) cum_count = _cumulative cum_min = _cumulative cum_max = _cumulative cum_prod = _cumulative cum_sum = 
_cumulative - - def _is_first_last_distinct( - self, - node: FunctionExpr[IsFirstDistinct | IsLastDistinct], - frame: Frame, - name: str, - ) -> Self: - idx_name = temp.column_name([name]) - expr_ir = fn.IS_FIRST_LAST_DISTINCT[type(node.function)](idx_name) - series = self._dispatch_expr(node.input[0], frame, name) - df = series.to_frame().with_row_index(idx_name) - distinct_index = ( - df.group_by_names((name,)) - .agg((ir.named_ir(idx_name, expr_ir),)) - .get_column(idx_name) - .native - ) - return self._with_native(fn.is_in(df.to_series().native, distinct_index), name) - is_first_distinct = _is_first_last_distinct is_last_distinct = _is_first_last_distinct - def null_count( - self, node: ir.FunctionExpr[NullCount], frame: Frame, name: str - ) -> Scalar: - series = self._dispatch_expr(node.input[0], frame, name) - return self._with_native(fn.lit(series.native.null_count), name) + # ewm_mean = not_implemented() # noqa: ERA001 + hist_bins = not_implemented() + hist_bin_count = not_implemented() + mode = not_implemented() + unique = not_implemented() + fill_null_with_strategy = not_implemented() + kurtosis = not_implemented() + skew = not_implemented() + gather_every = not_implemented() + is_duplicated = not_implemented() + is_unique = not_implemented() class ArrowScalar( @@ -525,9 +591,7 @@ def count(self, node: Count, frame: Frame, name: str) -> Scalar: native = node.expr.dispatch(self, frame, name).native return self._with_native(pa.scalar(1 if native.is_valid else 0), name) - def null_count( - self, node: ir.FunctionExpr[NullCount], frame: Frame, name: str - ) -> Self: + def null_count(self, node: FExpr[NullCount], frame: Frame, name: str) -> Self: native = node.input[0].dispatch(self, frame, name).native return self._with_native(pa.scalar(0 if native.is_valid else 1), name) @@ -535,6 +599,7 @@ def null_count( over = not_implemented() over_ordered = not_implemented() map_batches = not_implemented() + rank = not_implemented() # length_preserving rolling_expr = 
not_implemented() diff = not_implemented() diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 455c53a517..a17cb0ebc7 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -3,8 +3,8 @@ from __future__ import annotations import typing as t -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, overload +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, Final, Literal, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -15,7 +15,7 @@ floordiv_compat as floordiv, ) from narwhals._plan import expressions as ir -from narwhals._plan.arrow import options +from narwhals._plan.arrow import options as pa_options from narwhals._plan.expressions import functions as F, operators as ops from narwhals._utils import Implementation @@ -26,7 +26,6 @@ from typing_extensions import TypeAlias, TypeIs from narwhals._arrow.typing import Incomplete, PromoteOptions - from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.arrow.typing import ( Array, ArrayAny, @@ -37,6 +36,7 @@ BinOp, ChunkedArray, ChunkedArrayAny, + ChunkedOrArray, ChunkedOrArrayAny, ChunkedOrArrayT, ChunkedOrScalar, @@ -56,9 +56,23 @@ StringType, UnaryFunction, ) + from narwhals._plan.options import RankOptions, SortMultipleOptions, SortOptions from narwhals.typing import ClosedInterval, IntoArrowSchema, PythonLiteral BACKEND_VERSION = Implementation.PYARROW._backend_version() +"""Static backend version for `pyarrow`.""" + +RANK_ACCEPTS_CHUNKED: Final = BACKEND_VERSION >= (14,) + +HAS_SCATTER: Final = BACKEND_VERSION >= (20,) +"""`pyarrow.compute.scatter` added in https://github.com/apache/arrow/pull/44394""" + +HAS_ARANGE: Final = BACKEND_VERSION >= (21,) +"""`pyarrow.arange` added in https://github.com/apache/arrow/pull/46778""" + + +I64: Final = pa.int64() +F64: Final = pa.float64() IntoColumnAgg: 
TypeAlias = Callable[[str], ir.AggExpr] """Helper constructor for single-column aggregations.""" @@ -216,7 +230,7 @@ def n_unique(native: Any) -> pa.Int64Scalar: return count(native, mode="all") -def _reverse(native: ChunkedOrArrayT) -> ChunkedOrArrayT: +def reverse(native: ChunkedOrArrayT) -> ChunkedOrArrayT: """Unlike other slicing ops, `[::-1]` creates a full-copy. https://github.com/apache/arrow/issues/19103#issuecomment-1377671886 @@ -224,13 +238,6 @@ def _reverse(native: ChunkedOrArrayT) -> ChunkedOrArrayT: return native[::-1] -def cumulative(native: ChunkedArrayAny, cum_agg: F.CumAgg, /) -> ChunkedArrayAny: - func = _CUMULATIVE[type(cum_agg)] - if not cum_agg.reverse: - return func(native) - return _reverse(func(_reverse(native))) - - def cum_sum(native: ChunkedOrArrayT) -> ChunkedOrArrayT: return pc.cumulative_sum(native, skip_nulls=True) @@ -251,7 +258,7 @@ def cum_count(native: ChunkedArrayAny) -> ChunkedArrayAny: return cum_sum(is_not_null(native).cast(pa.uint32())) -_CUMULATIVE: Mapping[type[F.CumAgg], Callable[[ChunkedArrayAny], ChunkedArrayAny]] = { +CUMULATIVE: Mapping[type[F.CumAgg], Callable[[ChunkedArrayAny], ChunkedArrayAny]] = { F.CumSum: cum_sum, F.CumCount: cum_count, F.CumMin: cum_min, @@ -280,6 +287,35 @@ def shift(native: ChunkedArrayAny, n: int) -> ChunkedArrayAny: return pa.chunked_array(arrays) +def rank(native: ChunkedArrayAny, rank_options: RankOptions) -> ChunkedArrayAny: + arr = native if RANK_ACCEPTS_CHUNKED else array(native) + if rank_options.method == "average": + # Adapted from https://github.com/pandas-dev/pandas/blob/f4851e500a43125d505db64e548af0355227714b/pandas/core/arrays/arrow/array.py#L2290-L2316 + order = pa_options.ORDER[rank_options.descending] + min = preserve_nulls(arr, pc.rank(arr, order, tiebreaker="min").cast(F64)) + max = pc.rank(arr, order, tiebreaker="max").cast(F64) + ranked = pc.divide(pc.add(min, max), lit(2, F64)) + else: + ranked = preserve_nulls(native, pc.rank(arr, 
options=rank_options.to_arrow())) + return chunked_array(ranked) + + +def null_count(native: ChunkedOrArrayAny) -> pa.Int64Scalar: + return pc.count(native, mode="only_null") + + +def has_nulls(native: ChunkedOrArrayAny) -> bool: + return bool(native.null_count) + + +def preserve_nulls( + before: ChunkedOrArrayAny, after: ChunkedOrArrayT, / +) -> ChunkedOrArrayT: + if has_nulls(before): + after = pc.if_else(before.is_null(), lit(None, after.type), after) + return after + + def is_between( native: ChunkedOrScalar[ScalarT], lower: ChunkedOrScalar[ScalarT], @@ -323,29 +359,109 @@ def concat_str( dtype = string_type(obj.type for obj in arrays) it = (obj.cast(dtype) for obj in arrays) concat: Incomplete = pc.binary_join_element_wise - join = options.join(ignore_nulls=ignore_nulls) + join = pa_options.join(ignore_nulls=ignore_nulls) return concat(*it, lit(separator, dtype), options=join) # type: ignore[no-any-return] +def sort_indices( + native: ChunkedOrArrayAny | pa.Table, + *order_by: str, + descending: bool | Sequence[bool] = False, + nulls_last: bool = False, + options: SortOptions | SortMultipleOptions | None = None, +) -> pa.UInt64Array: + """Return the indices that would sort an array or table.""" + opts = ( + options.to_arrow(order_by) + if options + else pa_options.sort(*order_by, descending=descending, nulls_last=nulls_last) + ) + return pc.sort_indices(native, options=opts) + + +def unsort_indices(indices: pa.UInt64Array, /) -> pa.Int64Array: + """Return the inverse permutation of the given indices. + + Arguments: + indices: The output of `sort_indices`. + + Examples: + We can use this pair of functions to recreate a windowed `pl.row_index` + + >>> import polars as pl + >>> data = {"by": [5, 2, 5, None]} + >>> df = pl.DataFrame(data) + >>> df.select( + ... pl.row_index().over(order_by="by", descending=True, nulls_last=False) + ... 
).to_series().to_list() + [1, 3, 2, 0] + + Now in `pyarrow` + + >>> import pyarrow as pa + >>> from narwhals._plan.arrow.functions import sort_indices, unsort_indices + >>> df = pa.Table.from_pydict(data) + >>> unsort_indices( + ... sort_indices(df, "by", descending=True, nulls_last=False) + ... ).to_pylist() + [1, 3, 2, 0] + """ + return ( + pc.inverse_permutation(indices.cast(pa.int64())) # type: ignore[attr-defined] + if HAS_SCATTER + else int_range(len(indices), chunked=False).take(pc.sort_indices(indices)) + ) + + +@overload +def int_range( + start: int = ..., + end: int | None = ..., + step: int = ..., + /, + *, + dtype: IntegerType = ..., + chunked: Literal[True] = ..., +) -> ChunkedArray[IntegerScalar]: ... +@overload +def int_range( + start: int = ..., + end: int | None = ..., + step: int = ..., + /, + *, + chunked: Literal[False], +) -> pa.Int64Array: ... +@overload +def int_range( + start: int = ..., + end: int | None = ..., + step: int = ..., + /, + *, + dtype: IntegerType = ..., + chunked: Literal[False], +) -> Array[IntegerScalar]: ... 
def int_range( start: int = 0, end: int | None = None, step: int = 1, /, *, - dtype: IntegerType = pa.int64(), # noqa: B008 -) -> ChunkedArray[IntegerScalar]: + dtype: IntegerType = I64, + chunked: bool = True, +) -> ChunkedOrArray[IntegerScalar]: if end is None: end = start start = 0 - if BACKEND_VERSION < (21, 0, 0): # pragma: no cover + if not HAS_ARANGE: # pragma: no cover import numpy as np # ignore-banned-import arr = pa.array(np.arange(start=start, stop=end, step=step), type=dtype) else: - int_range_: Incomplete = t.cast("Incomplete", pa.arange) # type: ignore[attr-defined] + int_range_: Incomplete = pa.arange # type: ignore[attr-defined] arr = t.cast("ArrayAny", int_range_(start=start, stop=end, step=step)).cast(dtype) - return pa.chunked_array([arr]) + return arr if not chunked else pa.chunked_array([arr]) def date_range( @@ -433,12 +549,6 @@ def concat_diagonal(tables: Iterable[pa.Table]) -> pa.Table: return pa.concat_tables(tables, promote=True) -def is_series(obj: t.Any) -> TypeIs[ArrowSeries]: - from narwhals._plan.arrow.series import ArrowSeries - - return isinstance(obj, ArrowSeries) - - def _is_into_pyarrow_schema(obj: Mapping[Any, Any]) -> TypeIs[Mapping[str, DataType]]: return ( (first := next(iter(obj.items())), None) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 776f4d3d60..cce6463239 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from itertools import chain +from typing import TYPE_CHECKING, Any, Literal, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import @@ -16,12 +17,17 @@ from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from collections.abc import Iterator, Mapping, Sequence + from collections.abc import Collection, Iterator, Mapping, Sequence from typing_extensions import Self, TypeAlias 
from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame - from narwhals._plan.arrow.typing import ChunkedArray + from narwhals._plan.arrow.typing import ( + ArrayAny, + ChunkedArray, + ChunkedArrayAny, + Indices, + ) from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq @@ -139,16 +145,21 @@ def group_by_error( return InvalidOperationError(msg) +def multiple_null_partitions_error(column_names: Collection[str]) -> NotImplementedError: + backend = Implementation.PYARROW + msg = ( + f"`over(*partition_by)` where multiple columns contain null values is not yet supported for {backend!r}\n" + f"Got: {list(column_names)!r}" + ) + return NotImplementedError(msg) + + class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]): _df: Frame _keys: Seq[NamedIR] _key_names: Seq[str] _key_names_original: Seq[str] - @property - def compliant(self) -> Frame: - return self._df - def __iter__(self) -> Iterator[tuple[Any, Frame]]: by = self.key_names from_native = self.compliant._with_native @@ -169,6 +180,65 @@ def agg(self, irs: Seq[NamedIR]) -> Frame: return result.rename(dict(zip(key_names, original))) return result + def agg_over(self, irs: Seq[NamedIR], sort_indices: Indices | None = None) -> Frame: + key_names = list(self.key_names) + compliant = self.compliant + native = compliant.native + column_names = native.column_names + agg_names = (e.name for e in irs) + from_native = compliant._with_native + + # Handle null values in partitions, trying to avoid any work if possible + if len(key_names) == 1: + by = native.column(key_names[0]) + if by.null_count: + temp_name = temp.column_name({*column_names, *agg_names}) + key_names = [temp_name] + native = native.append_column(temp_name, dictionary_encode(by)) + compliant = from_native(native) + else: + partitions = native.select(key_names) + it_temp_names = temp.column_names(chain(column_names, agg_names)) + by_names: list[str] = [] + for orig_name, by in zip(key_names, partitions.columns): + if 
by.null_count: + by_name = next(it_temp_names) + native = native.append_column(by_name, dictionary_encode(by)) + else: + by_name = orig_name + by_names.append(by_name) + if by_names != key_names: + key_names = by_names + compliant = from_native(native) + + # If `order_by` was used, we can now apply the new order to the aggregation only + ordered = native if sort_indices is None else compliant._gather(sort_indices) + specs = (AggSpec.from_named_ir(e) for e in irs) + windowed = from_native(acero.group_by_table(ordered, key_names, specs)) + return ( + compliant.select_names(*key_names) + .join_inner(windowed, key_names) + .drop(key_names) + ) + + +@overload +def dictionary_encode(native: ChunkedArrayAny, /) -> pa.Int32Array: ... +@overload +def dictionary_encode( + native: ChunkedArrayAny, /, *, include_values: Literal[True] +) -> tuple[ArrayAny, pa.Int32Array]: ... +def dictionary_encode( + native: ChunkedArrayAny, /, *, include_values: bool = False +) -> tuple[ArrayAny, pa.Int32Array] | pa.Int32Array: + """Extra typing for `pc.dictionary_encode`.""" + da: Incomplete = native.dictionary_encode("encode").combine_chunks() + indices: pa.Int32Array = da.indices + if not include_values: + return indices + values: ArrayAny = da.dictionary + return values, indices + def _composite_key(native: pa.Table, *, separator: str = "") -> ChunkedArray: """Horizontally join columns to *seed* a unique key per row combination.""" @@ -192,11 +262,10 @@ def _partition_by_one( native: pa.Table, by: str, *, include_key: bool = True ) -> Iterator[pa.Table]: """Optimized path for single-column partition.""" - arr_dict: Incomplete = fn.array(native.column(by).dictionary_encode("encode")) - indices: pa.Int32Array = arr_dict.indices + values, indices = dictionary_encode(native.column(by), include_values=True) if not include_key: native = native.remove_column(native.schema.get_field_index(by)) - for idx in range(len(arr_dict.dictionary)): + for idx in range(len(values)): # NOTE: Acero filter 
doesn't support `null_selection_behavior="emit_null"` # Is there any reasonable way to do this in Acero? yield native.filter(pc.equal(pa.scalar(idx), indices)) diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py index d0257c8c41..83e73eff87 100644 --- a/narwhals/_plan/arrow/options.py +++ b/narwhals/_plan/arrow/options.py @@ -9,23 +9,30 @@ import functools from typing import TYPE_CHECKING, Any, Literal -import pyarrow.compute as pc # ignore-banned-import +import pyarrow.compute as pc + +from narwhals._utils import zip_strict if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Mapping, Sequence from narwhals._plan import expressions as ir from narwhals._plan.arrow import acero + from narwhals._plan.arrow.typing import NullPlacement, RankMethodSingle from narwhals._plan.expressions import aggregation as agg + from narwhals._plan.typing import Order, Seq __all__ = [ "AGG", "FUNCTION", + "array_sort", "count", "join", "join_replace_nulls", + "rank", "scalar_aggregate", + "sort", "variance", ] @@ -33,6 +40,17 @@ AGG: Mapping[type[agg.AggExpr], acero.AggregateOptions] FUNCTION: Mapping[type[ir.Function], acero.AggregateOptions] +_NULLS_LAST = True +_NULLS_FIRST = False +_ASC = False +_DESC = True + +NULL_PLACEMENT: Mapping[bool, NullPlacement] = { + _NULLS_LAST: "at_end", + _NULLS_FIRST: "at_start", +} +ORDER: Mapping[bool, Order] = {_ASC: "ascending", _DESC: "descending"} + @functools.cache def count( @@ -67,6 +85,59 @@ def join_replace_nulls(*, replacement: str = "__nw_null_value__") -> pc.JoinOpti return pc.JoinOptions(null_handling="replace", null_replacement=replacement) +@functools.cache +def array_sort( + *, descending: bool = False, nulls_last: bool = False +) -> pc.ArraySortOptions: + return pc.ArraySortOptions( + order=ORDER[descending], null_placement=NULL_PLACEMENT[nulls_last] + ) + + +@functools.lru_cache(maxsize=16) +def _sort_key(by: str, *, descending: bool = False) -> tuple[str, Order]: + 
return by, ORDER[descending] + + +@functools.lru_cache(maxsize=8) +def _sort_keys_every( + by: tuple[str, ...], *, descending: bool = False +) -> Seq[tuple[str, Order]]: + if len(by) == 1: + return (_sort_key(by[0], descending=descending),) + order = ORDER[descending] + return tuple((key, order) for key in by) + + +def _sort_keys( + by: tuple[str, ...], *, descending: bool | Sequence[bool] +) -> Seq[tuple[str, Order]]: + if not isinstance(descending, bool) and len(descending) == 1: + descending = descending[0] + if isinstance(descending, bool): + return _sort_keys_every(by, descending=descending) + it = zip_strict(by, descending) + return tuple(_sort_key(key, descending=desc) for (key, desc) in it) + + +def sort( + *by: str, descending: bool | Sequence[bool] = False, nulls_last: bool = False +) -> pc.SortOptions: + keys = _sort_keys(by, descending=descending) + return pc.SortOptions(sort_keys=keys, null_placement=NULL_PLACEMENT[nulls_last]) + + +@functools.cache +def rank( + method: RankMethodSingle, *, descending: bool = False, nulls_last: bool = True +) -> pc.RankOptions: + return pc.RankOptions( + sort_keys=ORDER[descending], + null_placement=NULL_PLACEMENT[nulls_last], + tiebreaker=("first" if method == "ordinal" else method), + ) + + def _generate_agg() -> Mapping[type[agg.AggExpr], acero.AggregateOptions]: from narwhals._plan.expressions import aggregation as agg diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index 2c15f91bb5..ffd68b660a 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -2,11 +2,14 @@ from typing import TYPE_CHECKING, Any +import pyarrow.compute as pc + from narwhals._arrow.utils import narwhals_to_native_dtype, native_to_narwhals_dtype -from narwhals._plan.arrow import functions as fn +from narwhals._plan.arrow import functions as fn, options +from narwhals._plan.arrow.common import ArrowFrameSeries as FrameSeries from narwhals._plan.compliant.series import CompliantSeries from 
narwhals._plan.compliant.typing import namespace -from narwhals._utils import Implementation, Version +from narwhals._utils import Version, generate_repr from narwhals.dependencies import is_numpy_array_1d if TYPE_CHECKING: @@ -15,23 +18,20 @@ import polars as pl from typing_extensions import Self - from narwhals._arrow.typing import ChunkedArrayAny from narwhals._plan.arrow.dataframe import ArrowDataFrame as DataFrame - from narwhals._plan.arrow.namespace import ArrowNamespace + from narwhals._plan.arrow.typing import ChunkedArrayAny from narwhals.dtypes import DType from narwhals.typing import Into1DArray, IntoDType, _1DArray -class ArrowSeries(CompliantSeries["ChunkedArrayAny"]): - implementation = Implementation.PYARROW - _native: ChunkedArrayAny - _version: Version +class ArrowSeries(FrameSeries["ChunkedArrayAny"], CompliantSeries["ChunkedArrayAny"]): _name: str - def __narwhals_namespace__(self) -> ArrowNamespace: - from narwhals._plan.arrow.namespace import ArrowNamespace + def __repr__(self) -> str: + return generate_repr(f"nw.{type(self).__name__}", self.native.__repr__()) - return ArrowNamespace(self._version) + def _with_native(self, native: ChunkedArrayAny) -> Self: + return self.from_native(native, self.name, version=self.version) def to_frame(self) -> DataFrame: return namespace(self)._dataframe.from_dict({self.name: self.native}) @@ -78,3 +78,19 @@ def from_iterable( def cast(self, dtype: IntoDType) -> Self: dtype_pa = narwhals_to_native_dtype(dtype, self.version) return self._with_native(fn.cast(self.native, dtype_pa)) + + def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: + opts = options.array_sort(descending=descending, nulls_last=nulls_last) + indices = pc.array_sort_indices(self.native, options=opts) + return self._with_native(self._gather(indices)) + + def scatter(self, indices: Self, values: Self) -> Self: + mask = fn.is_in(fn.int_range(len(self), chunked=False), indices.native) + replacements = 
fn.array(values._gather(pc.sort_indices(indices.native))) + return self._with_native(pc.replace_with_mask(self.native, mask, replacements)) + + def is_in(self, other: Self) -> Self: + return self._with_native(fn.is_in(self.native, other.native)) + + def has_nulls(self) -> bool: + return bool(self.native.null_count) diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index 1d6baf7002..ad2d42cb16 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -1,5 +1,6 @@ from __future__ import annotations +# ruff: noqa: PLC0414 from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, Protocol, overload @@ -15,16 +16,17 @@ Int16Type, Int32Type, Int64Type, - LargeStringType as LargeStringType, # noqa: PLC0414 - StringType as StringType, # noqa: PLC0414 + LargeStringType as LargeStringType, + StringType as StringType, Uint8Type, Uint16Type, Uint32Type, Uint64Type, ) - from typing_extensions import TypeAlias + from typing_extensions import ParamSpec, TypeAlias from narwhals._native import NativeDataFrame, NativeSeries + from narwhals.typing import SizedMultiIndexSelector as _SizedMultiIndexSelector StringScalar: TypeAlias = "Scalar[StringType | LargeStringType]" IntegerType: TypeAlias = "Int8Type | Int16Type | Int32Type | Int64Type | Uint8Type | Uint16Type | Uint32Type | Uint64Type" @@ -40,6 +42,13 @@ def column(self, *args: Any, **kwds: Any) -> NativeArrowSeries: ... @property def columns(self) -> Sequence[NativeArrowSeries]: ... + P = ParamSpec("P") + + class VectorFunction(Protocol[P]): + def __call__( + self, native: ChunkedArrayAny, *args: P.args, **kwds: P.kwargs + ) -> ChunkedArrayAny: ... 
+ ScalarT = TypeVar("ScalarT", bound="pa.Scalar[Any]", default="pa.Scalar[Any]") ScalarPT_contra = TypeVar( @@ -141,6 +150,8 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot ChunkedOrScalarAny: TypeAlias = "ChunkedOrScalar[ScalarAny]" ChunkedOrArrayAny: TypeAlias = "ChunkedOrArray[ScalarAny]" ChunkedOrArrayT = TypeVar("ChunkedOrArrayT", ChunkedArrayAny, ArrayAny) +Indices: TypeAlias = "_SizedMultiIndexSelector[ChunkedOrArray[pc.IntegerScalar]]" + Arrow: TypeAlias = "ChunkedOrScalar[ScalarT_co] | Array[ScalarT_co]" ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny" NativeScalar: TypeAlias = ScalarAny @@ -153,3 +164,9 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot "inner", "left outer", "full outer", "left anti", "left semi" ] """Only the `pyarrow` `JoinType`'s we use in narwhals""" + +RankMethodSingle: TypeAlias = Literal["min", "max", "dense", "ordinal"] +"""Subset of `narwhals` `RankMethod` that is supported directly in `pyarrow`. + +`"average"` requires calculating both `"min"` and `"max"`. +""" diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index b2e3c415b0..3ffd943fe5 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -225,6 +225,10 @@ def column_names( source: Source of columns to check for uniqueness. prefix: Prepend the name with this string. n_chars: Total number of characters used by the name (including `prefix`). + + Notes: + When a `source` is an `Iterator`, it will only be consumed if the result of + `temp.column_names` advances at least once.
""" columns = cls._into_columns(source) prefix, n_bytes = cls._parse_prefix_n_bytes(prefix, n_chars) diff --git a/narwhals/_plan/compliant/column.py b/narwhals/_plan/compliant/column.py index f2bb799409..2669a598db 100644 --- a/narwhals/_plan/compliant/column.py +++ b/narwhals/_plan/compliant/column.py @@ -59,7 +59,15 @@ def align( for e in exprs: yield e.broadcast(length) - def broadcast(self, length: LengthT, /) -> SeriesT: ... + def broadcast(self, length: LengthT, /) -> SeriesT: + """Repeat a `Scalar`, or unwrap an `Expr` into a `Series`. + + For `Scalar`, this is always safe, but will be less efficient than if we can operate on (`Scalar`, `Series`). + + For `Expr`, mismatched `length` will raise, but the operation is otherwise free. + """ + ... + @classmethod def from_series(cls, series: SeriesT, /) -> Self: ... def to_series(self) -> SeriesT: ... diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index 80f0b7cbe0..c439cda262 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, overload from narwhals._plan.compliant.group_by import Grouped @@ -14,7 +15,7 @@ ) if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Mapping, Sequence + from collections.abc import Iterator, Mapping, Sequence import polars as pl from typing_extensions import Self, TypeAlias @@ -67,6 +68,9 @@ def select(self, irs: Seq[NamedIR]) -> Self: ... def select_names(self, *column_names: str) -> Self: ... def sort(self, by: Sequence[str], options: SortMultipleOptions) -> Self: ... def with_columns(self, irs: Seq[NamedIR]) -> Self: ... + def with_row_index_by( + self, name: str, order_by: Sequence[str], *, nulls_last: bool = False + ) -> Self: ... 
class CompliantDataFrame( @@ -108,6 +112,17 @@ def group_by_agg( """Compliant-level `group_by(by).agg(agg)`, allows `Expr`.""" return self._grouper.by(by).agg(aggs).resolve(self).evaluate(self) + def group_by_agg_irs( + self, by: OneOrIterable[ir.ExprIR], aggs: OneOrIterable[ir.ExprIR], / + ) -> Self: + """Compliant-level `group_by(by).agg(agg)`, allows `ExprIR`. + + Useful for rewriting `over(*partition_by)`. + """ + by = (by,) if not isinstance(by, Iterable) else by + aggs = (aggs,) if not isinstance(aggs, Iterable) else aggs + return self._grouper.by_irs(*by).agg_irs(*aggs).resolve(self).evaluate(self) + def group_by_names(self, names: Seq[str], /) -> DataFrameGroupBy[Self]: """Compliant-level `group_by`, allowing only `str` keys.""" return self._group_by.by_names(self, names) @@ -153,6 +168,7 @@ def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: def to_series(self, index: int = 0) -> SeriesT: ... def to_polars(self) -> pl.DataFrame: ... def with_row_index(self, name: str) -> Self: ... + def slice(self, offset: int, length: int | None = None) -> Self: ... class EagerDataFrame( @@ -162,6 +178,12 @@ class EagerDataFrame( def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... @property def _group_by(self) -> type[EagerDataFrameGroupBy[Self]]: ... 
+ + def group_by_resolver( + self, resolver: GroupByResolver, / + ) -> EagerDataFrameGroupBy[Self]: + return self._group_by.from_resolver(self, resolver) + def select(self, irs: Seq[NamedIR]) -> Self: return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 8e1b8e0e88..aead4a8e1a 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -13,7 +13,7 @@ from narwhals._utils import Version if TYPE_CHECKING: - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from narwhals._plan import expressions as ir from narwhals._plan.compliant.scalar import CompliantScalar @@ -34,6 +34,8 @@ Not, ) +Incomplete: TypeAlias = Any + class CompliantExpr(HasVersion, Protocol[FrameT_contra, SeriesT_co]): """Everything common to `Expr`/`Series` and `Scalar` literal values.""" @@ -120,6 +122,9 @@ def cum_prod( def cum_sum( self, node: FunctionExpr[F.CumSum], frame: FrameT_contra, name: str ) -> Self: ... + def rank( + self, node: FunctionExpr[F.Rank], frame: FrameT_contra, name: str + ) -> Self: ... # series -> scalar def all( self, node: FunctionExpr[boolean.All], frame: FrameT_contra, name: str @@ -175,13 +180,74 @@ def std( def var( self, node: agg.Var, frame: FrameT_contra, name: str ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def kurtosis( + self, node: FunctionExpr[F.Kurtosis], frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + def skew( + self, node: FunctionExpr[F.Skew], frame: FrameT_contra, name: str + ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ... + + # mixed/todo + def clip( + self, node: FunctionExpr[F.Clip], frame: FrameT_contra, name: str + ) -> Self: ... + def drop_nulls( + self, node: FunctionExpr[F.DropNulls], frame: FrameT_contra, name: str + ) -> Self: ... + def exp(self, node: FunctionExpr[F.Exp], frame: FrameT_contra, name: str) -> Self: ... 
+ def fill_null_with_strategy( + self, node: FunctionExpr[F.FillNullWithStrategy], frame: FrameT_contra, name: str + ) -> Self: ... + def hist_bins( + self, node: FunctionExpr[F.HistBins], frame: FrameT_contra, name: str + ) -> Self: ... + def hist_bin_count( + self, node: FunctionExpr[F.HistBinCount], frame: FrameT_contra, name: str + ) -> Self: ... + def is_duplicated( + self, node: FunctionExpr[boolean.IsDuplicated], frame: FrameT_contra, name: str + ) -> Self: ... + def is_in_expr( + self, node: FunctionExpr[boolean.IsInExpr], frame: FrameT_contra, name: str + ) -> Self: ... + def is_in_seq( + self, node: FunctionExpr[boolean.IsInSeq], frame: FrameT_contra, name: str + ) -> Self: ... + def is_unique( + self, node: FunctionExpr[boolean.IsUnique], frame: FrameT_contra, name: str + ) -> Self: ... + def log(self, node: FunctionExpr[F.Log], frame: FrameT_contra, name: str) -> Self: ... + def mode( + self, node: FunctionExpr[F.Mode], frame: FrameT_contra, name: str + ) -> Self: ... + def replace_strict( + self, node: FunctionExpr[F.ReplaceStrict], frame: FrameT_contra, name: str + ) -> Self: ... + def round( + self, node: FunctionExpr[F.Round], frame: FrameT_contra, name: str + ) -> Self: ... + def sqrt( + self, node: FunctionExpr[F.Sqrt], frame: FrameT_contra, name: str + ) -> Self: ... + def unique( + self, node: FunctionExpr[F.Unique], frame: FrameT_contra, name: str + ) -> Self: ... class EagerExpr( EagerBroadcast[SeriesT], CompliantExpr[FrameT_contra, SeriesT], Protocol[FrameT_contra, SeriesT], -): ... +): + def gather_every( + self, node: FunctionExpr[F.GatherEvery], frame: FrameT_contra, name: str + ) -> Self: ... + def is_in_series( + self, + node: FunctionExpr[boolean.IsInSeries[Incomplete]], + frame: FrameT_contra, + name: str, + ) -> Self: ... 
class LazyExpr( diff --git a/narwhals/_plan/compliant/group_by.py b/narwhals/_plan/compliant/group_by.py index adac2bb402..be09f16eb5 100644 --- a/narwhals/_plan/compliant/group_by.py +++ b/narwhals/_plan/compliant/group_by.py @@ -60,6 +60,17 @@ class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDat _key_names_original: Seq[str] _column_names_original: Seq[str] + @property + def compliant(self) -> EagerDataFrameT: + return self._df + + def agg_over(self, irs: Seq[NamedIR[Any]]) -> EagerDataFrameT: + """Perform a windowed aggregation. + + Returns the re-joined aggregation results. + """ + ... + @classmethod def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self: obj = cls.__new__(cls) @@ -71,9 +82,7 @@ def by_names(cls, df: EagerDataFrameT, names: Seq[str], /) -> Self: return obj @classmethod - def from_resolver( - cls, df: EagerDataFrameT, resolver: GroupByResolver, / - ) -> EagerDataFrameGroupBy[EagerDataFrameT]: + def from_resolver(cls, df: EagerDataFrameT, resolver: GroupByResolver, /) -> Self: key_names = resolver.key_names if not resolver.requires_projection(): df = df.drop_nulls(key_names) if resolver._drop_null_keys else df diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py index bdab2e03e6..3e86327240 100644 --- a/narwhals/_plan/compliant/scalar.py +++ b/narwhals/_plan/compliant/scalar.py @@ -4,13 +4,13 @@ from narwhals._plan.compliant.expr import CompliantExpr, EagerExpr, LazyExpr from narwhals._plan.compliant.typing import FrameT_contra, LengthT, SeriesT, SeriesT_co +from narwhals._utils import not_implemented if TYPE_CHECKING: from typing_extensions import Self from narwhals._plan import expressions as ir from narwhals._plan.expressions import FunctionExpr, aggregation as agg - from narwhals._plan.expressions.boolean import IsFirstDistinct, IsLastDistinct from narwhals._plan.expressions.functions import EwmMean, NullCount, Shift from narwhals._utils import Version from 
narwhals.typing import IntoDType, PythonLiteral @@ -35,6 +35,29 @@ def _with_evaluated(self, evaluated: Any, name: str) -> Self: obj._version = self.version return obj + # NOTE: Constant behaviors with scalars observed in `polars` + + def _always_nan(self, node: ir.ExprIR, frame: Any, name: str) -> Self: + return self.from_python(float("nan"), name, dtype=None, version=self.version) + + def _always_noop(self, node: ir.ExprIR, frame: Any, name: str) -> Self: + return self._with_evaluated(self._evaluated, name) + + def _always_true(self, node: ir.ExprIR, frame: Any, name: str) -> Self: + return self.from_python(True, name, dtype=None, version=self.version) + + def _always_false(self, node: ir.ExprIR, frame: Any, name: str) -> Self: + return self.from_python(False, name, dtype=None, version=self.version) + + def _always_null(self, node: ir.ExprIR, frame: Any, name: str) -> Self: + return self.from_python(None, name, dtype=None, version=self.version) + + def _always_zero(self, node: ir.ExprIR, frame: Any, name: str) -> Self: + return self.from_python(0, name, dtype=None, version=self.version) + + def _always_one(self, node: ir.ExprIR, frame: Any, name: str) -> Self: + return self.from_python(1, name, dtype=None, version=self.version) + @property def name(self) -> str: return self._name @@ -49,11 +72,6 @@ def from_python( dtype: IntoDType | None, version: Version, ) -> Self: ... 
- def arg_max(self, node: agg.ArgMax, frame: FrameT_contra, name: str) -> Self: - return self.from_python(0, name, dtype=None, version=self.version) - - def arg_min(self, node: agg.ArgMin, frame: FrameT_contra, name: str) -> Self: - return self.from_python(0, name, dtype=None, version=self.version) def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: """Returns 0 if null, else 1.""" @@ -64,40 +82,12 @@ def ewm_mean( ) -> Self: return self._cast_float(node.input[0], frame, name) - def first(self, node: agg.First, frame: FrameT_contra, name: str) -> Self: - return self._with_evaluated(self._evaluated, name) - - def is_first_distinct( - self, node: FunctionExpr[IsFirstDistinct], frame: FrameT_contra, name: str - ) -> Self: - return self.from_python(True, name, dtype=None, version=self.version) - - def is_last_distinct( - self, node: FunctionExpr[IsLastDistinct], frame: FrameT_contra, name: str - ) -> Self: - return self.from_python(True, name, dtype=None, version=self.version) - - def last(self, node: agg.Last, frame: FrameT_contra, name: str) -> Self: - return self._with_evaluated(self._evaluated, name) - - def len(self, node: agg.Len, frame: FrameT_contra, name: str) -> Self: - return self.from_python(1, name, dtype=None, version=self.version) - - def max(self, node: agg.Max, frame: FrameT_contra, name: str) -> Self: - return self._with_evaluated(self._evaluated, name) - def mean(self, node: agg.Mean, frame: FrameT_contra, name: str) -> Self: return self._cast_float(node.expr, frame, name) def median(self, node: agg.Median, frame: FrameT_contra, name: str) -> Self: return self._cast_float(node.expr, frame, name) - def min(self, node: agg.Min, frame: FrameT_contra, name: str) -> Self: - return self._with_evaluated(self._evaluated, name) - - def n_unique(self, node: agg.NUnique, frame: FrameT_contra, name: str) -> Self: - return self.from_python(1, name, dtype=None, version=self.version) - def null_count( self, node: FunctionExpr[NullCount], 
frame: FrameT_contra, name: str ) -> Self: @@ -112,22 +102,30 @@ def shift(self, node: FunctionExpr[Shift], frame: FrameT_contra, name: str) -> S return self._with_evaluated(self._evaluated, name) return self.from_python(None, name, dtype=None, version=self.version) - def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: - return self._with_evaluated(self._evaluated, name) - - def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self: - return self._with_evaluated(self._evaluated, name) - - def std(self, node: agg.Std, frame: FrameT_contra, name: str) -> Self: - return self.from_python(None, name, dtype=None, version=self.version) - - def sum(self, node: agg.Sum, frame: FrameT_contra, name: str) -> Self: - return self._with_evaluated(self._evaluated, name) - - def var(self, node: agg.Var, frame: FrameT_contra, name: str) -> Self: - return self.from_python(None, name, dtype=None, version=self.version) - - # NOTE: `Filter` behaves the same, (maybe) no need to override + arg_max = _always_zero # type: ignore[misc] + arg_min = _always_zero # type: ignore[misc] + is_first_distinct = _always_true # type: ignore[misc] + is_last_distinct = _always_true # type: ignore[misc] + is_unique = _always_true # type: ignore[misc] + is_duplicated = _always_false # type: ignore[misc] + n_unique = _always_one # type: ignore[misc] + std = _always_null # type: ignore[misc] + var = _always_null # type: ignore[misc] + first = _always_noop # type: ignore[misc] + max = _always_noop # type: ignore[misc] + min = _always_noop # type: ignore[misc] + last = _always_noop # type: ignore[misc] + sort = _always_noop # type: ignore[misc] + sort_by = _always_noop # type: ignore[misc] + sum = _always_noop # type: ignore[misc] + mode = _always_noop # type: ignore[misc] + unique = _always_noop # type: ignore[misc] + kurtosis = _always_nan # type: ignore[misc] + skew = _always_nan # type: ignore[misc] + fill_null_with_strategy = not_implemented() # type: ignore[misc] + 
hist_bins = not_implemented() # type: ignore[misc] + hist_bin_count = not_implemented() # type: ignore[misc] + len = _always_one # type: ignore[misc] class EagerScalar( @@ -140,6 +138,8 @@ def __len__(self) -> int: def to_python(self) -> PythonLiteral: ... + gather_every = not_implemented() # type: ignore[misc] + class LazyScalar( CompliantScalar[FrameT_contra, SeriesT], diff --git a/narwhals/_plan/compliant/series.py b/narwhals/_plan/compliant/series.py index 7333479d8a..f9c33523ff 100644 --- a/narwhals/_plan/compliant/series.py +++ b/narwhals/_plan/compliant/series.py @@ -4,7 +4,7 @@ from narwhals._plan.compliant.typing import HasVersion from narwhals._plan.typing import NativeSeriesT -from narwhals._utils import Version +from narwhals._utils import Version, _StoresNative if TYPE_CHECKING: from collections.abc import Iterable @@ -15,7 +15,7 @@ from narwhals._plan.series import Series from narwhals._typing import _EagerAllowedImpl from narwhals.dtypes import DType - from narwhals.typing import Into1DArray, IntoDType, _1DArray + from narwhals.typing import Into1DArray, IntoDType, SizedMultiIndexSelector, _1DArray Incomplete: TypeAlias = Any @@ -28,6 +28,9 @@ class CompliantSeries(HasVersion, Protocol[NativeSeriesT]): def __len__(self) -> int: return len(self.native) + def len(self) -> int: + return len(self.native) + def __narwhals_namespace__(self) -> Incomplete: ... def __narwhals_series__(self) -> Self: return self @@ -72,6 +75,18 @@ def alias(self, name: str) -> Self: return self.from_native(self.native, name, version=self.version) def cast(self, dtype: IntoDType) -> Self: ... + def gather( + self, + indices: SizedMultiIndexSelector[NativeSeriesT] | _StoresNative[NativeSeriesT], + ) -> Self: ... + def has_nulls(self) -> bool: ... + def is_empty(self) -> bool: + return len(self) == 0 + + def is_in(self, other: Self) -> Self: ... + def scatter(self, indices: Self, values: Self) -> Self: ... + def slice(self, offset: int, length: int | None = None) -> Self: ... 
+ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: ... def to_frame(self) -> Incomplete: ... def to_list(self) -> list[Any]: ... def to_narwhals(self) -> Series[NativeSeriesT]: diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 2bb9a0860a..59a22c5438 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -5,7 +5,6 @@ from narwhals._plan import _parse from narwhals._plan._expansion import expand_selector_irs_names, prepare_projection from narwhals._plan.common import ensure_seq_str, temp -from narwhals._plan.exceptions import group_by_no_keys_error from narwhals._plan.group_by import GroupBy, Grouped from narwhals._plan.options import SortMultipleOptions from narwhals._plan.series import Series @@ -107,7 +106,7 @@ def sort( nulls_last: OneOrIterable[bool] = False, ) -> Self: s_irs = _parse.parse_into_seq_of_selector_ir(by, *more_by) - names = expand_selector_irs_names(s_irs, schema=self) + names = expand_selector_irs_names(s_irs, schema=self, require_any=True) opts = SortMultipleOptions.parse(descending=descending, nulls_last=nulls_last) return self._with_compliant(self._compliant.sort(names, opts)) @@ -115,15 +114,18 @@ def drop( self, *columns: OneOrIterable[ColumnNameOrSelector], strict: bool = True ) -> Self: s_ir = _parse.parse_into_combined_selector_ir(*columns, require_all=strict) - names = expand_selector_irs_names((s_ir,), schema=self) - return self._with_compliant(self._compliant.drop(names)) + if names := expand_selector_irs_names((s_ir,), schema=self): + compliant = self._compliant.drop(names) + else: + compliant = self._compliant._with_native(self.to_native()) + return self._with_compliant(compliant) def drop_nulls( self, subset: OneOrIterable[ColumnNameOrSelector] | None = None ) -> Self: if subset is not None: s_irs = _parse.parse_into_seq_of_selector_ir(subset) - subset = expand_selector_irs_names(s_irs, schema=self) + subset = expand_selector_irs_names(s_irs, schema=self) 
or None return self._with_compliant(self._compliant.drop_nulls(subset)) def rename(self, mapping: Mapping[str, str]) -> Self: @@ -132,6 +134,13 @@ def rename(self, mapping: Mapping[str, str]) -> Self: def collect_schema(self) -> Schema: return self.schema + def with_row_index( + self, name: str = "index", *, order_by: OneOrIterable[ColumnNameOrSelector] + ) -> Self: + by_selectors = _parse.parse_into_seq_of_selector_ir(order_by) + by_names = expand_selector_irs_names(by_selectors, schema=self, require_any=True) + return self._with_compliant(self._compliant.with_row_index_by(name, by_names)) + class DataFrame( BaseFrame[NativeDataFrameT_co], Generic[NativeDataFrameT_co, NativeSeriesT] @@ -281,12 +290,23 @@ def partition_by( include_key: bool = True, ) -> list[Self]: by_selectors = _parse.parse_into_seq_of_selector_ir(by, *more_by) - names = expand_selector_irs_names(by_selectors, schema=self) - if not names: - raise group_by_no_keys_error() + names = expand_selector_irs_names(by_selectors, schema=self, require_any=True) partitions = self._compliant.partition_by(names, include_key=include_key) return [self._with_compliant(p) for p in partitions] + def with_row_index( + self, + name: str = "index", + *, + order_by: OneOrIterable[ColumnNameOrSelector] | None = None, + ) -> Self: + if order_by is None: + return self._with_compliant(self._compliant.with_row_index(name)) + return super().with_row_index(name, order_by=order_by) + + def slice(self, offset: int, length: int | None = None) -> Self: # pragma: no cover + return type(self)(self._compliant.slice(offset=offset, length=length)) + def _is_join_strategy(obj: Any) -> TypeIs[JoinStrategy]: return obj in {"inner", "left", "full", "cross", "anti", "semi"} diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 75ecd02a87..05348372c0 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -26,7 +26,9 @@ from narwhals._plan._function import Function from 
narwhals._plan.expressions.operators import Operator from narwhals._plan.options import SortOptions + from narwhals._plan.schema import FrozenSchema from narwhals._plan.typing import IntoExpr, Seq + from narwhals.typing import IntoSchema # NOTE: Using verbose names to start with @@ -225,6 +227,39 @@ def column_index_error( return ColumnNotFoundError(msg) -def group_by_no_keys_error() -> ComputeError: +# TODO @dangotbanned: Remove or get coverage for failing: +# - `GroupByResolver.key_names` +# - `DataFrameGroupBy.key_names` +def group_by_no_keys_error() -> ComputeError: # pragma: no cover msg = "at least one key is required in a group_by operation" return ComputeError(msg) + + +def format_expressions(*exprs: ir.ExprIR, indent: int = 2) -> str: + """Return an indented list of `exprs` reprs.""" + indent_str = " " * indent + return "\n".join(f"{indent_str}{e!r}" for e in exprs) + + +def selectors_not_found_error( + selectors: Collection[ir.SelectorIR], schema: IntoSchema | FrozenSchema +) -> ColumnNotFoundError: + msg = "Found no columns when expanding:" + if len(selectors) == 1: + msg = f"{msg} {next(iter(selectors))!r}" + else: + msg = f"{msg}\n{format_expressions(*selectors)}" + items = dict(schema) + msg = f"{msg}\n\nHint: Did you mean one of these columns: {items!r}?" 
+ return ColumnNotFoundError(msg) + + +def expand_multi_output_error( + origin: ir.ExprIR, child: ir.ExprIR, *expanded: ir.ExprIR +) -> MultiOutputExpressionError: + msg = ( + "Multi-output expressions are not supported in this context.\n" + f"Got `{origin!r}`, but `{child!r}` expanded into {len(expanded)} outputs:\n" + f"{format_expressions(*expanded)}" + ) + return MultiOutputExpressionError(msg) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 5dfa9e3770..6b5798f8b1 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -2,7 +2,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar from narwhals._plan import common, expressions as ir from narwhals._plan._guards import is_expr, is_series @@ -28,7 +28,10 @@ from narwhals.exceptions import ComputeError if TYPE_CHECKING: - from typing_extensions import Self + from collections.abc import Callable + from typing import TypeVar + + from typing_extensions import Concatenate, ParamSpec, Self from narwhals._plan._function import Function from narwhals._plan.expressions.categorical import ExprCatNamespace @@ -49,6 +52,9 @@ TemporalLiteral, ) + P = ParamSpec("P") + R = TypeVar("R") + class Expr: _ir: ir.ExprIR @@ -416,16 +422,23 @@ def is_between( ir.boolean.IsBetween(closed=closed).to_function_expr(self._ir, *it) ) - def is_in(self, other: Iterable[Any]) -> Self: + def is_in(self, other: Iterable[Any] | Expr) -> Self: if is_series(other): return self._with_unary(ir.boolean.IsInSeries.from_series(other)) if isinstance(other, Iterable): return self._with_unary(ir.boolean.IsInSeq.from_iterable(other)) if is_expr(other): - return self._with_unary(ir.boolean.IsInExpr(other=other._ir)) - msg = f"`is_in` only supports iterables, got: {type(other).__name__}" + return self._from_ir( + ir.boolean.IsInExpr().to_function_expr(self._ir, other._ir) + ) + msg = f"`is_in` only supports 
iterables or Expr, got: {type(other).__name__}" raise TypeError(msg) + def pipe( + self, function: Callable[Concatenate[Self, P], R], *args: P.args, **kwds: P.kwargs + ) -> R: + return function(self, *args, **kwds) + def _with_binary( self, op: type[ops.Operator], diff --git a/narwhals/_plan/expressions/boolean.py b/narwhals/_plan/expressions/boolean.py index f6042a391d..b6bc93dd88 100644 --- a/narwhals/_plan/expressions/boolean.py +++ b/narwhals/_plan/expressions/boolean.py @@ -7,15 +7,16 @@ from narwhals._plan._function import Function, HorizontalFunction from narwhals._plan.options import FEOptions, FunctionOptions +from narwhals._plan.typing import NativeSeriesT from narwhals._typing_compat import TypeVar if TYPE_CHECKING: from typing_extensions import Self from narwhals._plan._expr_ir import ExprIR - from narwhals._plan.expressions.expr import FunctionExpr, Literal # noqa: F401 + from narwhals._plan.expressions.expr import FunctionExpr, Literal from narwhals._plan.series import Series - from narwhals._plan.typing import NativeSeriesT, Seq # noqa: F401 + from narwhals._plan.typing import Seq from narwhals.typing import ClosedInterval OtherT = TypeVar("OtherT") @@ -48,15 +49,13 @@ def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR, Exp return expr, lower_bound, upper_bound -class IsIn(BooleanFunction, t.Generic[OtherT]): +class IsInSeq(BooleanFunction): __slots__ = ("other",) - other: OtherT + other: Seq[t.Any] def __repr__(self) -> str: return "is_in" - -class IsInSeq(IsIn["Seq[t.Any]"]): @classmethod def from_iterable(cls, other: t.Iterable[t.Any], /) -> IsInSeq: if not isinstance(other, (str, bytes)): @@ -65,8 +64,13 @@ def from_iterable(cls, other: t.Iterable[t.Any], /) -> IsInSeq: raise TypeError(msg) -# NOTE: Shouldn't be allowed for lazy backends (maybe besides `polars`) -class IsInSeries(IsIn["Literal[Series[NativeSeriesT]]"]): +class IsInSeries(BooleanFunction, t.Generic[NativeSeriesT]): + __slots__ = ("other",) + other: 
Literal[Series[NativeSeriesT]] + + def __repr__(self) -> str: + return "is_in" + @classmethod def from_series(cls, other: Series[NativeSeriesT], /) -> IsInSeries[NativeSeriesT]: from narwhals._plan.expressions.literal import SeriesLiteral @@ -74,11 +78,19 @@ def from_series(cls, other: Series[NativeSeriesT], /) -> IsInSeries[NativeSeries return IsInSeries(other=SeriesLiteral(value=other).to_literal()) -# NOTE: Placeholder for allowing `Expr` iff it passes `.meta.is_column()` -class IsInExpr(IsIn[ExprT], t.Generic[ExprT]): - def __init__(self, *, other: ExprT) -> None: - msg = ( - "`is_in` doesn't accept expressions as an argument, as opposed to Polars. " - "You should provide an iterable instead." - ) - raise NotImplementedError(msg) +class IsInExpr(BooleanFunction): + """N-ary (expr, other). + + Note: + If we get to a stage where `narwhals` has wide support for `list`, and + accepts them in `lit(...)` - *consider* [restricting to non-equal types]. + + [restricting to non-equal types]: https://github.com/pola-rs/polars/pull/22178 + """ + + def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: + expr, other = node.input + return expr, other + + def __repr__(self) -> str: + return "is_in" diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 03550afb7c..4a31389170 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -1,20 +1,20 @@ from __future__ import annotations import enum -from itertools import repeat from typing import TYPE_CHECKING from narwhals._plan._immutable import Immutable +from narwhals._utils import Implementation +from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Sequence - import pyarrow.acero import pyarrow.compute as pc from typing_extensions import Self - from narwhals._plan.arrow.typing import NullPlacement - from narwhals._plan.typing import Accessor, OneOrIterable, Order, 
Seq + from narwhals._plan.typing import Accessor, OneOrIterable, Seq + from narwhals._typing import Backend from narwhals.typing import RankMethod @@ -155,22 +155,10 @@ def __repr__(self) -> str: args = f"descending={self.descending!r}, nulls_last={self.nulls_last!r}" return f"{type(self).__name__}({args})" - def to_arrow(self) -> pc.ArraySortOptions: - import pyarrow.compute as pc + def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: + from narwhals._plan.arrow.options import sort - return pc.ArraySortOptions( - order=("descending" if self.descending else "ascending"), - null_placement=("at_end" if self.nulls_last else "at_start"), - ) - - def to_multiple(self, n_repeat: int = 1, /) -> SortMultipleOptions: - if n_repeat == 1: - desc: Seq[bool] = (self.descending,) - nulls: Seq[bool] = (self.nulls_last,) - else: - desc = tuple(repeat(self.descending, n_repeat)) - nulls = tuple(repeat(self.nulls_last, n_repeat)) - return SortMultipleOptions(descending=desc, nulls_last=nulls) + return sort(*by, descending=self.descending, nulls_last=self.nulls_last) class SortMultipleOptions(Immutable): @@ -192,36 +180,18 @@ def parse( nulls = (nulls_last,) if isinstance(nulls_last, bool) else tuple(nulls_last) return SortMultipleOptions(descending=desc, nulls_last=nulls) - def _to_arrow_args( - self, by: Sequence[str] - ) -> tuple[Sequence[tuple[str, Order]], NullPlacement]: + def _ensure_single_nulls_last(self, backend: Backend, /) -> bool: first = self.nulls_last[0] if len(self.nulls_last) != 1 and any(x != first for x in self.nulls_last[1:]): - msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {self.nulls_last!r}" # pragma: no cover + msg = f"{Implementation.from_backend(backend)!r} does not support multiple values for `nulls_last`, got: {self.nulls_last!r}" # pragma: no cover raise NotImplementedError(msg) - if len(self.descending) == 1: - descending: Iterable[bool] = repeat(self.descending[0], len(by)) - else: - descending = self.descending - sorting = 
tuple[tuple[str, "Order"]]( - (key, "descending" if desc else "ascending") - for key, desc in zip(by, descending) - ) - return sorting, "at_end" if first else "at_start" + return first def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: - import pyarrow.compute as pc + from narwhals._plan.arrow.options import sort - sort_keys, placement = self._to_arrow_args(by) - return pc.SortOptions(sort_keys=sort_keys, null_placement=placement) - - def to_arrow_acero( - self, by: Sequence[str] - ) -> pyarrow.acero.Declaration: # pragma: no cover - from narwhals._plan.arrow import acero - - sort_keys, placement = self._to_arrow_args(by) - return acero._order_by(sort_keys, null_placement=placement) + nulls_last = self._ensure_single_nulls_last("pyarrow") + return sort(*by, descending=self.descending, nulls_last=nulls_last) class RankOptions(Immutable): @@ -229,6 +199,14 @@ class RankOptions(Immutable): method: RankMethod descending: bool + def to_arrow(self) -> pc.RankOptions: + if self.method == "average": # pragma: no cover + msg = f"`RankOptions.to_arrow` is not compatible with {self.method=}." + raise InvalidOperationError(msg) + from narwhals._plan.arrow.options import rank + + return rank(self.method, descending=self.descending) + class EWMOptions(Immutable): """Deviates from polars, since we aren't pre-computing alpha. diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index 65c7197895..6dca244806 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, overload from narwhals._plan import expressions as ir -from narwhals._plan._guards import is_column, is_re_pattern +from narwhals._plan._guards import is_expr_column, is_re_pattern from narwhals._plan.common import flatten_hash_safe from narwhals._plan.expr import Expr, ExprV1 from narwhals._plan.expressions import operators as ops, selectors as s_ir @@ -87,7 +87,7 @@ def __and__(self, other: Self) -> Self: ... 
@overload def __and__(self, other: Any) -> Expr: ... def __and__(self, other: Any) -> Self | Expr: - if is_column(other): # @polars>=2.0: remove + if is_expr_column(other): # @polars>=2.0: remove other = by_name(other.meta.output_name()) if isinstance(other, type(self)): op = ops.And() @@ -102,7 +102,7 @@ def __or__(self, other: Self) -> Self: ... @overload def __or__(self, other: Any) -> Expr: ... def __or__(self, other: Any) -> Self | Expr: - if is_column(other): # @polars>=2.0: remove + if is_expr_column(other): # @polars>=2.0: remove other = by_name(other.meta.output_name()) if isinstance(other, type(self)): op = ops.Or() @@ -110,7 +110,7 @@ def __or__(self, other: Any) -> Self | Expr: return self.as_expr().__or__(other) def __ror__(self, other: Any) -> Expr: # type: ignore[override] - if is_column(other): + if is_expr_column(other): other = by_name(other.meta.output_name()) return self.as_expr().__ror__(other) @@ -133,7 +133,7 @@ def __xor__(self, other: Self) -> Self: ... @overload def __xor__(self, other: Any) -> Expr: ... 
def __xor__(self, other: Any) -> Self | Expr: - if is_column(other): # @polars>=2.0: remove + if is_expr_column(other): # @polars>=2.0: remove other = by_name(other.meta.output_name()) if isinstance(other, type(self)): op = ops.ExclusiveOr() @@ -141,7 +141,7 @@ def __xor__(self, other: Any) -> Self | Expr: return self.as_expr().__xor__(other) def __rxor__(self, other: Any) -> Expr: # type: ignore[override] - if is_column(other): # @polars>=2.0: remove + if is_expr_column(other): # @polars>=2.0: remove other = by_name(other.meta.output_name()) return self.as_expr().__rxor__(other) diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index ae7337dc80..0220253087 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -1,21 +1,32 @@ from __future__ import annotations +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, ClassVar, Generic -from narwhals._plan.typing import NativeSeriesT, NativeSeriesT_co -from narwhals._utils import Implementation, Version, is_eager_allowed +from narwhals._plan._guards import is_series +from narwhals._plan.typing import NativeSeriesT, NativeSeriesT_co, OneOrIterable +from narwhals._utils import ( + Implementation, + Version, + generate_repr, + is_eager_allowed, + qualified_type_name, +) from narwhals.dependencies import is_pyarrow_chunked_array if TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Iterator import polars as pl - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from narwhals._plan.compliant.series import CompliantSeries - from narwhals._typing import EagerAllowed, IntoBackend + from narwhals._plan.dataframe import DataFrame + from narwhals._typing import EagerAllowed, IntoBackend, _EagerAllowedImpl from narwhals.dtypes import DType - from narwhals.typing import IntoDType + from narwhals.typing import IntoDType, NonNestedLiteral, SizedMultiIndexSelector + +Incomplete: TypeAlias = Any class 
Series(Generic[NativeSeriesT_co]): @@ -34,9 +45,16 @@ def dtype(self) -> DType: def name(self) -> str: return self._compliant.name + @property + def implementation(self) -> _EagerAllowedImpl: + return self._compliant.implementation + def __init__(self, compliant: CompliantSeries[NativeSeriesT_co], /) -> None: self._compliant = compliant + def __repr__(self) -> str: + return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__()) + @classmethod def from_iterable( cls: type[Series[Any]], @@ -72,6 +90,14 @@ def from_native( raise NotImplementedError(type(native)) + # NOTE: `Incomplete` until `CompliantSeries` can avoid a cyclic dependency back to `CompliantDataFrame` + # Currently an issue on `main` and leads to a lot of intermittent warnings + def to_frame(self) -> DataFrame[Incomplete, NativeSeriesT_co]: + import narwhals._plan.dataframe as _df + + # NOTE: Missing placeholder for `DataFrameV1` + return _df.DataFrame(self._compliant.to_frame()) + def to_native(self) -> NativeSeriesT_co: return self._compliant.native @@ -90,6 +116,63 @@ def alias(self, name: str) -> Self: def __len__(self) -> int: return len(self._compliant) + def gather(self, indices: SizedMultiIndexSelector[Self]) -> Self: # pragma: no cover + if len(indices) == 0: + return self.slice(0, 0) + rows = indices._compliant if isinstance(indices, Series) else indices + return type(self)(self._compliant.gather(rows)) + + def has_nulls(self) -> bool: # pragma: no cover + return self._compliant.has_nulls() + + def slice(self, offset: int, length: int | None = None) -> Self: # pragma: no cover + return type(self)(self._compliant.slice(offset=offset, length=length)) + + def sort( + self, *, descending: bool = False, nulls_last: bool = False + ) -> Self: # pragma: no cover + result = self._compliant.sort(descending=descending, nulls_last=nulls_last) + return type(self)(result) + + def is_empty(self) -> bool: # pragma: no cover + return self._compliant.is_empty() + + def _unwrap_compliant( + 
self, other: Series[Any], / + ) -> CompliantSeries[NativeSeriesT_co]: + compliant = other._compliant + if isinstance(compliant, type(self._compliant)): + return compliant + msg = f"Expected {qualified_type_name(self._compliant)!r}, got {qualified_type_name(compliant)!r}" + raise NotImplementedError(msg) + + def _parse_into_compliant( + self, other: Series[Any] | Iterable[Any], / + ) -> CompliantSeries[NativeSeriesT_co]: + if is_series(other): + return self._unwrap_compliant(other) + return self._compliant.from_iterable(other, version=self.version) + + def scatter( + self, + indices: Self | OneOrIterable[int], + values: Self | OneOrIterable[NonNestedLiteral], + ) -> Self: + if not isinstance(indices, Iterable): + indices = [indices] + indices_ = self._parse_into_compliant(indices) + if indices_.is_empty(): + return self + if not is_series(values) and ( + not isinstance(values, Iterable) or isinstance(values, str) + ): + values = [values] + result = self._compliant.scatter(indices_, self._parse_into_compliant(values)) + return type(self)(result) + + def is_in(self, other: Iterable[Any]) -> Self: + return type(self)(self._compliant.is_in(self._parse_into_compliant(other))) + class SeriesV1(Series[NativeSeriesT_co]): _version: ClassVar[Version] = Version.V1 diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 72324aeaac..b6e640adf8 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -728,6 +728,19 @@ def test_expand_binary_expr_combination_invalid(df_1: Frame) -> None: df_1.project(ten_to_nine) +def test_expand_function_expr_multi_invalid(df_1: Frame) -> None: + first_column = re.escape("col('a')") + last_selected_column = re.escape("col('h')") + found = rf".+{first_column}.+{last_selected_column}\Z" + pattern = re_compile( + rf"not supported.+context.+{first_column}\.is_in.+ncs\.integer.+ncs\.integer.+expanded into 8 outputs{found}" + ) + with pytest.raises(MultiOutputExpressionError, 
match=pattern): + df_1.project(nwp.col("a").is_in(ncs.integer())) + with pytest.raises(MultiOutputExpressionError, match=r"expanded into 20 outputs"): + df_1.project(nwp.col("d").is_in(nwp.all())) + + def test_over_order_by_names() -> None: expr = nwp.col("a").first().over(order_by=ncs.string()) e_ir = expr._ir diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 7df7e5b957..b588b966b8 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -406,12 +406,6 @@ def test_is_in_series() -> None: [ ("words", pytest.raises(TypeError, match=r"str \| bytes.+str")), (b"words", pytest.raises(TypeError, match=r"str \| bytes.+bytes")), - ( - nwp.col("b"), - pytest.raises( - NotImplementedError, match=re.compile(r"iterable instead", re.IGNORECASE) - ), - ), ( 999, pytest.raises( diff --git a/tests/plan/frame_partition_by_test.py b/tests/plan/frame_partition_by_test.py index 429a78fb20..3c6010578b 100644 --- a/tests/plan/frame_partition_by_test.py +++ b/tests/plan/frame_partition_by_test.py @@ -8,7 +8,7 @@ import narwhals as nw from narwhals._plan import Selector, selectors as ncs from narwhals._utils import zip_strict -from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError +from narwhals.exceptions import ColumnNotFoundError, DuplicateError from tests.plan.utils import assert_equal_data, dataframe, re_compile if TYPE_CHECKING: @@ -135,7 +135,7 @@ def test_partition_by_duplicate_names(data: Data) -> None: def test_partition_by_fully_empty_selector(data: Data) -> None: df = dataframe(data) with pytest.raises( - ComputeError, match=r"at least one key is required in a group_by operation" + ColumnNotFoundError, match=re_compile(r"ncs.array.+ncs.struct.+ncs.duration") ): df.partition_by(ncs.array(ncs.numeric()), ncs.struct(), ncs.duration()) diff --git a/tests/plan/frame_with_row_index_test.py b/tests/plan/frame_with_row_index_test.py new file mode 100644 index 0000000000..f56369454c --- /dev/null 
+++ b/tests/plan/frame_with_row_index_test.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan.selectors as ncs +from narwhals.exceptions import ColumnNotFoundError +from tests.plan.utils import assert_equal_data, dataframe, re_compile + +if TYPE_CHECKING: + from narwhals._plan.typing import ColumnNameOrSelector, OneOrIterable + + +def test_with_row_index_eager() -> None: + data = {"abc": ["foo", "bars"], "xyz": [100, 200], "const": [42, 42]} + result = dataframe(data).with_row_index() + expected = {"index": [0, 1], **data} + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("order_by", "expected_index"), + [ + (["a"], [0, 2, 1]), + (ncs.first(), [0, 2, 1]), + (ncs.string(), [0, 2, 1]), + (["c"], [2, 0, 1]), + (ncs.last(), [2, 0, 1]), + (ncs.integer() - ncs.by_index(1), [2, 0, 1]), + (["a", "c"], [1, 2, 0]), + ([ncs.first(), "c"], [1, 2, 0]), + (["a", ncs.by_name("c")], [1, 2, 0]), + (["c", "a"], [2, 0, 1]), + ([ncs.by_index(-1, 0)], [2, 0, 1]), + ([ncs.last(), ncs.first()], [2, 0, 1]), + ], +) +def test_with_row_index_by( + order_by: OneOrIterable[ColumnNameOrSelector], expected_index: list[int] +) -> None: + # https://github.com/narwhals-dev/narwhals/issues/3289 + data = {"a": ["A", "B", "A"], "b": [1, 2, 3], "c": [9, 2, 4]} + result = dataframe(data).with_row_index(name="index", order_by=order_by).sort("b") + expected = {"index": expected_index, **data} + assert_equal_data(result, expected) + + +def test_with_row_index_by_invalid() -> None: + data = {"a": ["A", "B", "A"], "b": [1, 2, 3], "c": [9, 2, 4]} + df = dataframe(data) + + with pytest.raises(ColumnNotFoundError, match=re.escape("not found: ['d']")): + df.with_row_index(order_by="d") + + with pytest.raises(ColumnNotFoundError, match=re.escape("not found: ['e']")): + df.with_row_index(order_by=["e", "b"]) + + with pytest.raises(ColumnNotFoundError, match=r"Invalid column index 5"): + 
df.with_row_index(order_by=ncs.by_index(5)) + + +def test_with_row_index_by_empty_selection() -> None: + data = {"a": ["A", "B", "A"], "b": [1, 2, 3], "c": [9, 2, 4]} + df = dataframe(data) + with pytest.raises(ColumnNotFoundError, match=re.escape("ncs.datetime(")): + df.with_row_index(order_by=ncs.datetime()) + + schema = re.escape("{'a': String, 'b': Int64, 'c': Int64}") + pattern = re_compile(rf"ncs.float\(\).*ncs.temporal\(\).*Hint:.+{schema}") + with pytest.raises(ColumnNotFoundError, match=pattern): + df.with_row_index(order_by=[ncs.float(), ncs.temporal()]) diff --git a/tests/plan/is_first_last_distinct_test.py b/tests/plan/is_first_last_distinct_test.py index 0b23eac761..74019caa73 100644 --- a/tests/plan/is_first_last_distinct_test.py +++ b/tests/plan/is_first_last_distinct_test.py @@ -21,26 +21,6 @@ def data_indexed(data: Data) -> Data: return data | {"i": [None, 1, 2, 3, 4, 5]} -@pytest.fixture -def data_alt_1() -> Data: - return {"a": [1, 1, 2, 2, 2], "b": [1, 3, 3, 2, 3]} - - -@pytest.fixture -def data_alt_1_indexed(data_alt_1: Data) -> Data: - return data_alt_1 | {"i": [0, 1, 2, 3, 4]} - - -@pytest.fixture -def data_alt_2() -> Data: - return {"a": [1, 1, 2, 2, 2], "b": [1, 2, 2, 2, 1]} - - -@pytest.fixture -def data_alt_2_indexed(data_alt_2: Data) -> Data: - return data_alt_2 | {"i": [None, 1, 2, 3, 4]} - - @pytest.fixture def expected() -> Data: return { @@ -54,14 +34,6 @@ def expected_invert(expected: Data) -> Data: return {k: [not el for el in v] for k, v in expected.items()} -# NOTE: Isn't supported on `main` for `pyarrow` + lots of other cases (non-elementary group-by agg) -# Could be interesting to attempt here? 
-XFAIL_PARTITIONED_ORDER_BY = pytest.mark.xfail( - reason="Not supporting `over(*partition_by, order_by=...)` yet", - raises=NotImplementedError, -) - - def test_is_first_distinct(data: Data, expected: Data) -> None: result = dataframe(data).select(nwp.all().is_first_distinct()) assert_equal_data(result, expected) @@ -92,43 +64,106 @@ def test_is_last_distinct_order_by(data_indexed: Data, expected_invert: Data) -> assert_equal_data(result, expected_invert) -@XFAIL_PARTITIONED_ORDER_BY -def test_is_first_distinct_partitioned_order_by( - data_alt_1_indexed: Data, -) -> None: # pragma: no cover - expected = {"b": [True, True, True, True, False]} +# NOTE: Everything from here onwards is not supported on `main` + + +@pytest.fixture +def grouped() -> Data: + return { + "group": ["A", "A", "B", "B", "B"], + "value_1": [1, 3, 3, 2, 3], + "value_2": [1, 3, 3, 3, 3], + "o_asc": [0, 1, 2, 3, 4], + "o_null": [0, 1, 2, None, 4], + } + + +GROUP = "group" +VALUE_1 = "value_1" +VALUE_2 = "value_2" +ORDER_ASC = "o_asc" +ORDER_NULL = "o_null" + + +# NOTE: For `pyarrow`, the result is identical to `order_by`, because the index is already in order +def test_is_first_last_distinct_partitioned(grouped: Data) -> None: + expected = { + GROUP: ["A", "A", "B", "B", "B"], + VALUE_1: [1, 3, 3, 2, 3], + "is_first_distinct": [True, True, True, True, False], + "is_last_distinct": [True, True, False, True, True], + } + df = dataframe(grouped).drop(VALUE_2, ORDER_NULL) + value = nwp.col(VALUE_1) result = ( - dataframe(data_alt_1_indexed) - .select(nwp.col("b").is_first_distinct().over("a", order_by="i"), "i") - .sort("i") - .drop("i") + df.with_columns( + is_first_distinct=value.is_first_distinct().over(GROUP), + is_last_distinct=value.is_last_distinct().over(GROUP), + ) + .sort(ORDER_ASC) + .drop(ORDER_ASC) ) assert_equal_data(result, expected) -@XFAIL_PARTITIONED_ORDER_BY -def test_is_last_distinct_partitioned_order_by( - data_alt_1_indexed: Data, -) -> None: # pragma: no cover - expected = 
{"b": [True, True, False, True, True]} +# NOTE: This works the same as `polars` +def test_is_first_last_distinct_partitioned_order_by_desc(grouped: Data) -> None: + expected = { + GROUP: ["A", "A", "B", "B", "B"], + VALUE_2: [1, 3, 3, 3, 3], + # (1) Same result + "first_distinct": [True, True, True, False, False], + "last_distinct_desc": [True, True, True, False, False], + # (2) Same result + "last_distinct": [True, True, False, False, True], + "first_distinct_desc": [True, True, False, False, True], + } + df = dataframe(grouped).drop(VALUE_1, ORDER_NULL) + value = nwp.col(VALUE_2) + first = value.is_first_distinct() + last = value.is_last_distinct() + result = ( - dataframe(data_alt_1_indexed) - .select(nwp.col("b").is_last_distinct().over("a", order_by="i"), "i") - .sort("i") - .drop("i") + df.with_columns( + first_distinct=first.over(GROUP, order_by=ORDER_ASC), + last_distinct_desc=last.over(GROUP, order_by=ORDER_ASC, descending=True), + last_distinct=last.over(GROUP, order_by=ORDER_ASC), + first_distinct_desc=first.over(GROUP, order_by=ORDER_ASC, descending=True), + ) + .sort(ORDER_ASC) + .drop(ORDER_ASC) ) assert_equal_data(result, expected) -@XFAIL_PARTITIONED_ORDER_BY -def test_is_last_distinct_partitioned_order_by_nulls( - data_alt_2_indexed: Data, -) -> None: # pragma: no cover - expected = {"b": [True, True, False, True, True]} +# NOTE: `polars` *currently* ignores the `nulls_last` argument +# https://github.com/pola-rs/polars/issues/24989 +def test_is_first_last_distinct_partitioned_order_by_nulls(grouped: Data) -> None: + expected = { + GROUP: ["A", "A", "B", "B", "B"], + VALUE_2: [1, 3, 3, 3, 3], + "first_distinct_nulls_first": [True, True, False, True, False], + "last_distinct_nulls_first": [True, True, False, False, True], + "first_distinct_nulls_last": [True, True, True, False, False], + "last_distinct_nulls_last": [True, True, False, True, False], + } + df = dataframe(grouped).drop(VALUE_1) + value = nwp.col(VALUE_2) + first = 
value.is_first_distinct() + last = value.is_last_distinct() result = ( - dataframe(data_alt_2_indexed) - .select(nwp.col("b").is_last_distinct().over("a", order_by="i"), "i") - .sort("i") - .drop("i") + df.with_columns( + first_distinct_nulls_first=first.over(GROUP, order_by=ORDER_NULL), + last_distinct_nulls_first=last.over(GROUP, order_by=ORDER_NULL), + first_distinct_nulls_last=first.over( + GROUP, order_by=ORDER_NULL, nulls_last=True + ), + last_distinct_nulls_last=last.over( + GROUP, order_by=ORDER_NULL, nulls_last=True + ), + ) + .sort(ORDER_ASC) + .drop(ORDER_ASC, ORDER_NULL) ) + assert_equal_data(result, expected) diff --git a/tests/plan/is_in_test.py b/tests/plan/is_in_test.py new file mode 100644 index 0000000000..4d5fb48364 --- /dev/null +++ b/tests/plan/is_in_test.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from collections import deque +from typing import TYPE_CHECKING, Any, Literal + +import pytest + +import narwhals as nw +import narwhals._plan as nwp +from narwhals._plan import selectors as ncs +from tests.plan.utils import assert_equal_data, assert_equal_series, dataframe, series + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from tests.conftest import Data + +pytest.importorskip("pyarrow") + +import pyarrow as pa + + +@pytest.fixture +def data() -> Data: + return {"a": [1, 4, 2, 5], "b": [1, 0, 2, 0], "c": [None, "hi", "hello", "howdy"]} + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwp.col("a").is_in([4, 5]), {"a": [False, True, False, True]}), + (nwp.col("a").is_in([]), {"a": [False, False, False, False]}), + (nwp.col("b").is_in(deque([0, 1])), {"b": [True, True, False, True]}), + (nwp.col("c").is_in(("howdy", None)), {"c": [True, False, False, True]}), + ( + ncs.integer().is_in([5, 6, 0]), + {"a": [False, False, False, True], "b": [False, True, False, True]}, + ), + (ncs.string().last().is_in(iter(["howdy"])), {"c": [True]}), + ( + (nwp.col("b").max() + nwp.col("a")).is_in(range(5, 
10)), + {"b": [False, True, False, True]}, + ), + ], +) +def test_expr_is_in(data: Data, expr: nwp.Expr, expected: Data) -> None: + result = dataframe(data).select(expr) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwp.col("a").is_in(nwp.col("b")), {"a": [True, False, True, False]}), + (nwp.col("a").is_in(nwp.nth(1)), {"a": [True, False, True, False]}), + (nwp.col("b").is_in(nwp.col("a") - 5), {"b": [False, True, False, True]}), + ( + (nwp.col("b").max() + nwp.col("a")).is_in(nwp.int_range(5, 10)), + {"b": [False, True, False, True]}, + ), + (nwp.col("a").last().is_in(ncs.first()), {"a": [True]}), + ( + (nwp.col("a").last() - nwp.col("b").first()).is_in( + ncs.integer() - ncs.first() + ), + {"a": [False]}, + ), + ( + (ncs.integer() + 4).is_in(nwp.nth(0).filter(nwp.col("b") < 2)), + {"a": [True, False, False, False], "b": [True, True, False, True]}, + ), + ( + ncs.string().is_in(nwp.lit(None, nw.String)), + {"c": [True, False, False, False]}, + ), + ], +) +def test_expr_is_in_expr(data: Data, expr: nwp.Expr, expected: Data) -> None: + df = dataframe(data) + assert_equal_data(df.select(expr), expected) + + +@pytest.mark.parametrize( + ("column", "other", "expected"), + [ + ("a", [4, 5], [False, True, False, True]), + ("a", [], [False, False, False, False]), + ("b", deque([0, 1]), [True, True, False, True]), + ("c", ("howdy", None), [True, False, False, True]), + ("b", series([2]), [False, False, True, False]), + ("c", pa.array(["hi", "hello"]), [False, True, True, False]), + ], +) +def test_ser_is_in( + data: Data, + column: Literal["a", "b", "c"], + other: Iterable[Any], + expected: Sequence[Any], +) -> None: + result = series(data[column]).alias(column).is_in(other) + assert_equal_series(result, expected, column) + + +def test_is_in_other(data: Data) -> None: + df = dataframe(data) + with pytest.raises(TypeError, match=r"is_in.+doesn't accept.+str"): + df.with_columns(contains=nwp.col("a").is_in("sets")) + + 
+def test_expr_is_in_series(data: Data) -> None: + df = dataframe(data) + + a = nwp.col("a") + a_first = a.first() + a_last = a.last() + a_ser = df.get_column("a") + b_ser = df.get_column("b") + + assert_equal_data(df.filter(a.is_in(b_ser)), {"a": [1, 2], "b": [1, 2]}) + assert_equal_data(df.select(a_last.is_in(b_ser)), {"a": [False]}) + assert_equal_data(df.select(a_first.is_in(b_ser)), {"a": [True]}) + assert_equal_data(df.select((a_last - a_first).is_in(a_ser)), {"a": [True]}) + assert_equal_data(df.select((a_last - a_first).is_in(b_ser)), {"a": [False]}) diff --git a/tests/plan/over_test.py b/tests/plan/over_test.py index c561b8d47e..ccea0eec77 100644 --- a/tests/plan/over_test.py +++ b/tests/plan/over_test.py @@ -1,19 +1,28 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from operator import methodcaller +from typing import TYPE_CHECKING, Any, Literal, TypeVar import pytest pytest.importorskip("pyarrow") + import narwhals as nw import narwhals._plan as nwp from narwhals._plan import selectors as ncs +from narwhals._utils import zip_strict from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_equal_data, dataframe, re_compile if TYPE_CHECKING: + from collections.abc import Callable, Mapping, Sequence + + from _pytest.mark import ParameterSet + from typing_extensions import TypeAlias + from narwhals._plan.typing import IntoExprColumn, OneOrIterable + from narwhals.typing import NonNestedLiteral from tests.conftest import Data @@ -263,3 +272,271 @@ def test_null_count_over() -> None: .over(ncs.integer() - ncs.by_name("c")) ) assert_equal_data(result, expected) + + +@pytest.fixture +def data_groups() -> Data: + return { + "a": ["a", "b", "d", "d", "b", "c"], + "b": [1, 2, 1, 5, 3, 3], + "c": [5, 4, 3, 6, 2, 1], + "i": [0, 1, 2, 3, 4, 5], + } + + +v = nwp.col("v") +p1 = nwp.col("p1") +p2 = nwp.col("p2") +p3 = nwp.col("p3") + + +@pytest.mark.parametrize( + ("expr", "result_values"), + [ + 
(v.first().over(p1, order_by="i", descending=True), [5, 2, 6, 6, 2, 1]), + (v.last().over(p1, p2, order_by="i", descending=True), [5, 4, 3, 6, 2, 1]), + (v.first().over(p3, p2, order_by="i"), [5, 4, 3, 6, 6, 1]), + ( + ( + v.first().over( + nwp.when(p2.is_null()).then(2).when(p2 == 1).then(p2), + p3, + order_by="i", + descending=True, + ) + ), + [5, 4, 3, 2, 2, 1], + ), + ], +) +def test_over_partition_by_nulls_order_by( + expr: nwp.Expr, result_values: list[Any] +) -> None: + data = { + "p1": ["a", "b", None, None, "b", "c"], + "p2": [1, 2, 1, None, None, None], + "p3": [None, 1, 1, 2, 2, None], + "v": [5, 4, 3, 6, 2, 1], + "i": [0, 1, 2, 3, 4, 5], + } + expected = data | {"result": result_values} + result = dataframe(data).with_columns(result=expr).sort("i") + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("expr", "result_values"), + [ + ( + nwp.col("c").first().over("a", order_by="i", descending=True), + [5, 2, 6, 6, 2, 1], + ), + (nwp.col("c").first().over("a", order_by="i"), [5, 4, 3, 3, 4, 1]), + ( + nwp.col("c").mean().over(ncs.integer(), order_by="i"), + [5.0, 4.0, 3.0, 6.0, 2.0, 1.0], + ), + ( + nwp.col("c").min().over(ncs.first(), order_by=[ncs.first(), ncs.last()]), + [5, 2, 3, 3, 2, 1], + ), + ], +) +def test_over_partition_by_order_by( + data_groups: Data, expr: nwp.Expr, result_values: list[Any] +) -> None: + expected = data_groups | {"result": result_values} + df = dataframe(data_groups) + result = df.with_columns(result=expr).sort("i") + assert_equal_data(result, expected) + + +def _ensure_list(arg: T | list[T]) -> list[T]: + return [arg] if not isinstance(arg, list) else arg + + +ValueColumn: TypeAlias = Literal["v1", "v2", "v3"] +OrderColumn: TypeAlias = Literal["o1", "o2", "o3", "o4", "o5"] +Agg: TypeAlias = Literal["first", "last"] +T = TypeVar("T") + +_AGG_EXPR_METHOD: Mapping[Agg, Callable[[nwp.Expr], nwp.Expr]] = { + "first": methodcaller("first"), + "last": methodcaller("last"), +} + + +@pytest.fixture(scope="module") 
+def data_order() -> Mapping[str, list[NonNestedLiteral]]: + return { + "o1": [0, 1, 2, 3], + "o2": ["y", "y", "x", "a"], + "o3": [None, 5, 2, 5], + "o4": ["L", "M", "A", None], + "o5": [1, None, None, -1], + "v1": [12, 1, 5, 2], + "v2": ["under", "water", "unicorn", "magic"], + "v3": [5.9, 1.2, 22.9, 999.1], + } + + +def order_case( + columns: ValueColumn | list[ValueColumn], + aggregation: Agg, + /, + order_by: OrderColumn | Sequence[OrderColumn], + *, + descending: bool = False, + nulls_last: bool = False, + expected: NonNestedLiteral | list[NonNestedLiteral], +) -> ParameterSet: + """Generate `Expr`s and an expected dataset for ordered aggregations. + + Covers both `over(order_by=...)` and `sort_by(...)` to ensure their results are identical in a + select context. + """ + # Encoding argument combinations into column names and the shared test id + ordering = f"{order_by}-{'desc' if descending else 'asc'}-{'nulls_last' if nulls_last else 'nulls_first'}" + suffix_over = f"_{aggregation}-over-{ordering}" + suffix_sort_by = f"_sort_by-{ordering}-{aggregation}" + test_id = f"{columns}_{aggregation}-{ordering}" + + # Generating what our expected dataset should be + names_values = list(zip_strict(_ensure_list(columns), _ensure_list(expected))) + result_data = { + f"{name}{suffix}": [expect] + for suffix in (suffix_over, suffix_sort_by) + for name, expect in names_values + } + + # Finally, all the expressions + cols = nwp.col(columns) + agg = _AGG_EXPR_METHOD[aggregation] + over = ( + cols.pipe(agg) + .over(order_by=order_by, descending=descending, nulls_last=nulls_last) + .name.suffix(suffix_over) + ) + sort_by = ( + cols.sort_by(order_by, descending=descending, nulls_last=nulls_last) + .pipe(agg) + .name.suffix(suffix_sort_by) + ) + return pytest.param([over, sort_by], result_data, id=test_id) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + order_case("v1", "first", order_by="o4", expected=2), + order_case("v1", "first", order_by="o4", descending=True, 
expected=2), + order_case("v1", "first", order_by="o4", nulls_last=True, expected=5), + order_case( + "v1", "first", order_by="o4", descending=True, nulls_last=True, expected=1 + ), + order_case("v2", "last", order_by=["o3", "o5"], expected="magic"), + order_case( + "v2", "last", order_by=["o3", "o5"], descending=True, expected="unicorn" + ), + order_case( + "v2", "last", order_by=["o3", "o5"], nulls_last=True, expected="under" + ), + order_case( + "v2", + "last", + order_by=["o3", "o5"], + descending=True, + nulls_last=True, + expected="under", + ), + order_case(["v3", "v2"], "last", order_by=["o2", "o5"], expected=[5.9, "under"]), + order_case( + ["v3", "v2"], + "first", + order_by=["o2", "o5"], + descending=True, + expected=[1.2, "water"], + ), + order_case( + ["v3", "v2"], + "first", + order_by=["o2", "o5"], + nulls_last=True, + expected=[999.1, "magic"], + ), + order_case( + ["v3", "v2"], + "last", + order_by=["o5", "o2"], + nulls_last=True, + descending=True, + expected=[22.9, "unicorn"], + ), + ], +) +def test_over_order_by_sort_by_asc_desc_nulls_first_last( + expr: OneOrIterable[nwp.Expr], + expected: Data, + data_order: Mapping[str, list[NonNestedLiteral]], +) -> None: + result = dataframe(data_order).select(expr) + assert_equal_data(result, expected) + + +def test_over_partition_by_order_by_asc_desc_nulls_first_24989() -> None: + # Adapted from https://github.com/pola-rs/polars/issues/24989 + data = {"a": [1, 1, 2, 2], "b": [4, 5, 6, 7], "i": [1, None, 2, 3]} + b_first = nwp.col("b").first() + df = dataframe(data) + result = df.with_columns( + asc_nulls_first=b_first.over("a", order_by="i"), + asc_nulls_last=b_first.over("a", order_by="i", nulls_last=True), + desc_nulls_first=b_first.over("a", order_by="i", descending=True), + desc_nulls_last=b_first.over("a", order_by="i", descending=True, nulls_last=True), + ).sort("b") + expected = { + "a": [1, 1, 2, 2], + "b": [4, 5, 6, 7], + "i": [1, None, 2, 3], + "asc_nulls_first": [5, 5, 6, 6], + 
"asc_nulls_last": [4, 4, 6, 6], + "desc_nulls_first": [5, 5, 7, 7], + "desc_nulls_last": [4, 4, 7, 7], + } + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + pytest.param( + nwp.col("b").mean().alias("c").over(nwp.col("b") - nwp.col("a")), + {"a": [1, 1, 2], "b": [4, 5, 6], "d": [10, -1, -9], "c": [4.0, 5.5, 5.5]}, + id="unordered", + ), + pytest.param( + nwp.col("b") + .last() + .over(nwp.col("b") - nwp.col("a"), order_by="d") + .alias("f"), + {"a": [1, 1, 2], "b": [4, 5, 6], "d": [10, -1, -9], "f": [4, 5, 5]}, + id="ordered", + ), + pytest.param( + nwp.col("a") + .first() + .over( + nwp.when(d=10).then(nwp.col("d").last()).otherwise("d"), + order_by="b", + descending=True, + ) + .alias("e"), + {"a": [1, 1, 2], "b": [4, 5, 6], "d": [10, -1, -9], "e": [2, 1, 2]}, + id="ordered-agg-ordered-partition", + ), + ], +) +def test_over_partition_by_projection(expr: nwp.Expr, expected: Data) -> None: + data = {"a": [1, 1, 2], "b": [4, 5, 6], "d": [10, -1, -9]} + result = dataframe(data).with_columns(expr).sort("b") + assert_equal_data(result, expected) diff --git a/tests/plan/rank_test.py b/tests/plan/rank_test.py new file mode 100644 index 0000000000..319b5d2415 --- /dev/null +++ b/tests/plan/rank_test.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals._plan as nwp +from narwhals.exceptions import InvalidOperationError +from tests.plan.utils import assert_equal_data, dataframe + +if TYPE_CHECKING: + from collections.abc import Iterator, Mapping, Sequence + + from typing_extensions import TypeAlias + + from narwhals.typing import RankMethod + +Data: TypeAlias = "dict[str, Sequence[float | None]]" + +ASC = False +DESC = True + + +@pytest.fixture(params=["average", "min", "max", "dense", "ordinal"]) +def rank_method(request: pytest.FixtureRequest) -> RankMethod: + method: RankMethod = request.param + return method + + +def _generate_data() -> 
Iterator[Data]: + a_int = [3, 6, 1, 1, None, 6] + a_float = [3.1, 6.1, 1.5, 1.5, None, 6.1] + for column in (a_int, a_float): + yield {"a": column, "b": [1, 1, 2, 1, 2, 2], "i": [1, 2, 3, 4, 5, 6]} + + +@pytest.fixture(params=_generate_data(), scope="module", ids=["int", "float"]) +def data(request: pytest.FixtureRequest) -> Data: + data_: Data = request.param + return data_ + + +EXPECTED: Mapping[tuple[RankMethod, bool], Sequence[float | None]] = { + ("average", ASC): [3.0, 4.5, 1.5, 1.5, None, 4.5], + ("average", DESC): [3.0, 1.5, 4.5, 4.5, None, 1.5], + ("min", ASC): [3, 4, 1, 1, None, 4], + ("min", DESC): [3, 1, 4, 4, None, 1], + ("max", ASC): [3, 5, 2, 2, None, 5], + ("max", DESC): [3, 2, 5, 5, None, 2], + ("dense", ASC): [2, 3, 1, 1, None, 3], + ("dense", DESC): [2, 1, 3, 3, None, 1], + ("ordinal", ASC): [3, 4, 1, 2, None, 5], + ("ordinal", DESC): [3, 1, 4, 5, None, 2], +} +EXPECTED_PARTITION_BY: Mapping[tuple[RankMethod, bool], Sequence[float | None]] = { + ("average", ASC): [2.0, 3.0, 1.0, 1.0, None, 2.0], + ("average", DESC): [2.0, 1.0, 2.0, 3.0, None, 1.0], + ("min", ASC): [2, 3, 1, 1, None, 2], + ("min", DESC): [2, 1, 2, 3, None, 1], + ("max", ASC): [2, 3, 1, 1, None, 2], + ("max", DESC): [2, 1, 2, 3, None, 1], + ("dense", ASC): [2, 3, 1, 1, None, 2], + ("dense", DESC): [2, 1, 2, 3, None, 1], + ("ordinal", ASC): [2, 3, 1, 1, None, 2], + ("ordinal", DESC): [2, 1, 2, 3, None, 1], +} +EXPECTED_ORDER_BY: Mapping[tuple[RankMethod, bool], Sequence[float | None]] = { + ("average", ASC): [3.0, 4.5, 1.5, 1.5, None, 4.5], + ("average", DESC): [3.0, 1.5, 4.5, 4.5, None, 1.5], + ("min", ASC): [3, 4, 1, 1, None, 4], + ("min", DESC): [3, 1, 4, 4, None, 1], + ("max", ASC): [3, 5, 2, 2, None, 5], + ("max", DESC): [3, 2, 5, 5, None, 2], + ("dense", ASC): [2, 3, 1, 1, None, 3], + ("dense", DESC): [2, 1, 3, 3, None, 1], + ("ordinal", ASC): [3, 4, 1, 2, None, 5], + ("ordinal", DESC): [3, 1, 4, 5, None, 2], +} + + +@pytest.mark.parametrize("descending", [ASC, DESC], 
ids=["asc", "desc"]) +def test_rank_expr(rank_method: RankMethod, data: Data, *, descending: bool) -> None: + result = dataframe(data).select(nwp.col("a").rank(rank_method, descending=descending)) + assert_equal_data(result, {"a": EXPECTED[rank_method, descending]}) + + +@pytest.mark.xfail( + reason="`ArrowExpr.rank().over(*partition_by)` is not implemented on main", + raises=InvalidOperationError, +) +@pytest.mark.parametrize("descending", [ASC, DESC], ids=["asc", "desc"]) +def test_rank_expr_partition_by( + rank_method: RankMethod, data: Data, *, descending: bool +) -> None: # pragma: no cover + # `test_rank_expr_in_over_context` + result = dataframe(data).select( + nwp.col("a").rank(rank_method, descending=descending).over("b") + ) + assert_equal_data(result, {"a": EXPECTED_PARTITION_BY[rank_method, descending]}) + + +@pytest.mark.parametrize("descending", [ASC, DESC], ids=["asc", "desc"]) +def test_rank_expr_order_by( + rank_method: RankMethod, data: Data, *, descending: bool +) -> None: + result = dataframe(data).select( + nwp.col("a").rank(rank_method, descending=descending).over(order_by="i") + ) + assert_equal_data(result, {"a": EXPECTED_ORDER_BY[rank_method, descending]}) + + +def test_rank_expr_order_by_3177() -> None: + # NOTE: #3177 + data = {"a": [1, 1, 2, 2, 3, 3], "b": [3, None, 4, 3, 5, 6], "i": list(range(6))} + df = dataframe(data) + result = df.with_columns(c=nwp.col("a").rank("ordinal").over(order_by="b")).sort("i") + expected = { + "a": [1, 1, 2, 2, 3, 3], + "b": [3, None, 4, 3, 5, 6], + "i": [0, 1, 2, 3, 4, 5], + "c": [2, 1, 4, 3, 5, 6], + } + assert_equal_data(result, expected) + + data = {"i": [0, 1, 2], "j": [1, 2, 1]} + df = dataframe(data) + result = ( + df.with_columns(z=nwp.col("j").rank("min").over(order_by="i")) + .sort("i") + .select("z") + ) + expected = {"z": [1.0, 3.0, 1.0]} + assert_equal_data(result, expected) diff --git a/tests/plan/series_scatter_test.py b/tests/plan/series_scatter_test.py new file mode 100644 index 
0000000000..f00adefc3c --- /dev/null +++ b/tests/plan/series_scatter_test.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals._plan as nwp +from tests.plan.utils import assert_equal_data, assert_equal_series, dataframe, series + +if TYPE_CHECKING: + from collections.abc import Sequence + + from narwhals._plan.typing import OneOrIterable + + +@pytest.mark.parametrize( + ("data", "indices", "values", "expected"), + [ + ([1, 2, 3], [0, 1], [999, 888], [999, 888, 3]), + ([142, 124, 13], [0, 2, 1], series([142, 124, 13]), [142, 13, 124]), + ([1, 2, 3], 0, 999, [999, 2, 3]), + ( + [16, 12, 10, 9, 6, 5, 2], + [6, 1, 0, 5, 3, 2, 4], + series([16, 12, 10, 9, 6, 5, 2]), + [10, 12, 5, 6, 2, 9, 16], + ), + ([5.5, 9.2, 1.0], (), (), [5.5, 9.2, 1.0]), + ], + ids=["lists", "single-series", "integer", "unordered-indices", "empty-indices"], +) +def test_scatter( + data: list[Any], + indices: int | Sequence[int], + values: OneOrIterable[int], + expected: list[Any], +) -> None: + ser = series(data).alias("ser") + if isinstance(values, nwp.Series): + assert ser.implementation is values.implementation + assert_equal_series(ser.scatter(indices, values), expected, "ser") + + +def test_scatter_unchanged() -> None: + df = dataframe({"a": [1, 2, 3], "b": [142, 124, 132]}) + a = df.get_column("a") + b = df.get_column("b") + df.with_columns(a.scatter([0, 1], [999, 888]), b.scatter([0, 2, 1], [142, 124, 132])) + assert_equal_data(df, {"a": [1, 2, 3], "b": [142, 124, 132]}) + + +def test_scatter_2862() -> None: + ser = series([1, 2, 3]).alias("a") + assert_equal_series(ser.scatter(1, 999), [1, 999, 3], "a") + assert_equal_series(ser.scatter([0, 2], [999, 888]), [999, 2, 888], "a") + assert_equal_series(ser.scatter([2, 0], [999, 888]), [888, 2, 999], "a") diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 23bfcff219..09acbf750f 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -207,8 +207,10 @@ 
def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> ir.NamedIR[ir.ExprIR]: return ir.NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name) -def dataframe(data: dict[str, Any], /) -> nwp.DataFrame[pa.Table, pa.ChunkedArray[Any]]: - return nwp.DataFrame.from_native(pa.table(data)) +def dataframe( + data: Mapping[str, Any], / +) -> nwp.DataFrame[pa.Table, pa.ChunkedArray[Any]]: + return nwp.DataFrame.from_native(pa.Table.from_pydict(data)) def series(values: Iterable[Any], /) -> nwp.Series[pa.ChunkedArray[Any]]: @@ -221,6 +223,12 @@ def assert_equal_data( _assert_equal_data(result.to_dict(as_series=False), expected) +def assert_equal_series( + result: nwp.Series[Any], expected: Sequence[Any], name: str +) -> None: + assert_equal_data(result.to_frame(), {name: expected}) + + def re_compile( pattern: str, flags: _Flags = re.DOTALL | re.IGNORECASE ) -> re.Pattern[str]: