Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
0daaac9
feat(DRAFT): Start adding `rank`
dangotbanned Nov 9, 2025
1f406d6
test: `rank(descending=True)` as well
dangotbanned Nov 9, 2025
c7bec7f
test: Add placeholder for `test_rank_expr_partition_by`
dangotbanned Nov 9, 2025
cc1caea
test: Add `test_rank_expr_order_by`
dangotbanned Nov 9, 2025
b587b4e
refactor: Add `RankOptions.to_arrow`
dangotbanned Nov 9, 2025
1efa41a
test: Add `test_rank_expr_order_by_3177`
dangotbanned Nov 9, 2025
b21a126
tidy up
dangotbanned Nov 9, 2025
20b6cbf
feat(DRAFT): Start adding `with_row_index_by`
dangotbanned Nov 9, 2025
8974740
feat: Impl `ArrowDataFrame.with_row_index_by`
dangotbanned Nov 9, 2025
6fe0af8
perf: Massively simplify `pyarrow<20` branch
dangotbanned Nov 9, 2025
7c74c3f
fix(typing): Avoid `[no-any-return]`
dangotbanned Nov 9, 2025
7e82c40
feat: Support `ArrowExpr.rank(method="average")`
dangotbanned Nov 9, 2025
51b336f
feat: Consistently raise (when needed) on empty expansions
dangotbanned Nov 10, 2025
60ad2cf
chore: Drive-by add `CompliantSeries.len`
dangotbanned Nov 10, 2025
43209a2
feat: Mostly ready `over(*partition_by, order_by=...)`
dangotbanned Nov 10, 2025
081bff1
test: add test for respecting `nulls_last`
dangotbanned Nov 10, 2025
228551c
fix: Always respect `null_last`, enforce `polars` default
dangotbanned Nov 10, 2025
9e82fd6
shrinky-dink
dangotbanned Nov 10, 2025
0633b1a
Merge branch 'oh-nodes' into expr-ir/over-partition-by-order-by
dangotbanned Nov 11, 2025
f21a862
chore(expr-ir): Fill out more of `CompliantExpr` (#3304)
dangotbanned Nov 11, 2025
411b4e9
refactor: Sorting cleanup part 1
dangotbanned Nov 11, 2025
2948a12
refactor: Sorting cleanup part 2
dangotbanned Nov 11, 2025
560ae93
refactor: Sorting cleanup part 3
dangotbanned Nov 12, 2025
cf1ae6d
test: Cover more order_by variants
dangotbanned Nov 12, 2025
a1f33f3
refactor: Sorting cleanup part 4
dangotbanned Nov 12, 2025
feaf002
feat(expr-ir): Add `Series.scatter` (#3305)
dangotbanned Nov 13, 2025
35cf26a
refactor: Clean up `rank`
dangotbanned Nov 13, 2025
0f73d22
perf: Optimize `sort_by` some more
dangotbanned Nov 13, 2025
5faa499
test: More `sort_by` coverage
dangotbanned Nov 14, 2025
58f4837
refactor(perf): Simplify `over(order_by=...)`
dangotbanned Nov 14, 2025
9d7f09a
docs: Explain `broadcast` a lil bit
dangotbanned Nov 14, 2025
fadcf30
feat(DRAFT): Support `is_{first,last}_distinct().over(*partition_by)`
dangotbanned Nov 14, 2025
df6229f
test: Remove unused fixtures
dangotbanned Nov 14, 2025
ed99711
feat(DRAFT): Support `is_first_distinct().over(*partition_by, order_b…
dangotbanned Nov 14, 2025
01ea3e1
test: Make first/last distinct tests easier to follow
dangotbanned Nov 15, 2025
cfe07af
fix: `Expr.is_{first,last}_distinct.over(*partition_by, order_by=...,…
dangotbanned Nov 15, 2025
c1e376a
refactor: Avoid unnecessary aliasing
dangotbanned Nov 15, 2025
aabe956
feat(DRAFT): Start adding `is_in`
dangotbanned Nov 15, 2025
77c8a3f
feat: Support `Series.is_in`
dangotbanned Nov 15, 2025
c0c9410
refactor: Use `ArrowSeries.is_in`
dangotbanned Nov 15, 2025
b734820
refactor: Add `CompliantDataFrame.group_by_agg_irs`
dangotbanned Nov 15, 2025
04c55be
fix: xfail non-selecting expressions in `over(*partition_by)`
dangotbanned Nov 15, 2025
bdba098
feat: Support `is_in_series`
dangotbanned Nov 15, 2025
a32c322
chore: Add repr for debugging
dangotbanned Nov 16, 2025
a6937c1
feat: Add `Series.has_nulls`
dangotbanned Nov 16, 2025
14e3b58
feat: Support unordered expressions for over windows
dangotbanned Nov 16, 2025
3761472
test: Update bad test
dangotbanned Nov 16, 2025
2bcd3ff
feat: Support more complex expressions in over windows
dangotbanned Nov 16, 2025
5eb6b78
remove finished todo
dangotbanned Nov 16, 2025
82a0867
typo
dangotbanned Nov 16, 2025
1c8fdb4
refactor: Rename to `Indices`
dangotbanned Nov 17, 2025
e9df821
refactor: Merge the 3x `is_{first,last}_distinct` impls
dangotbanned Nov 17, 2025
b5345e4
chore: Huge cleanup of sort/indices
dangotbanned Nov 17, 2025
c6cf386
oooh missed a spot
dangotbanned Nov 17, 2025
67d2003
refactor: Switch more over to `functions.sort_indices`
dangotbanned Nov 17, 2025
5ab7d9f
feat: Impl `is_in_seq`
dangotbanned Nov 17, 2025
b3129b3
test: xfail `is_in_expr`
dangotbanned Nov 17, 2025
3958645
feat(DRAFT): Prep for `is_in_expr`
dangotbanned Nov 17, 2025
ad68560
feat: Support `Expr.is_in(Expr)`
dangotbanned Nov 17, 2025
911d8d7
🧹🧹🧹
dangotbanned Nov 17, 2025
0a552cc
fix: Partial support for nulls in `over(*partition_by)`
dangotbanned Nov 17, 2025
e0fd01b
docs: Note difference from upstream
dangotbanned Nov 18, 2025
a8564a4
feat: Give more context on `MultiOutputExpressionError`
dangotbanned Nov 18, 2025
4dcd15a
chore: Explain `Incomplete`
dangotbanned Nov 18, 2025
017e918
refactor: Move `is_seq_column` guard
dangotbanned Nov 18, 2025
da21c44
somewhat less complex
dangotbanned Nov 18, 2025
1c24115
test: Add tests for (#3316)
dangotbanned Nov 19, 2025
212f1ce
feat: Fully support nulls in `over(*partition_by)`
dangotbanned Nov 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ ba = "ba" # Used as column name in docstring examples (way too much?)
iy = "iy" # Used as column name (once in a test)
pn = "pn" # Used in docs: pn = PandasLikeNamespace(...)
TYP = "TYP" # Used in flake8 rule
arange = "arange" # Used in numpy, polars, pyarrow

[files]
extend-exclude = ["tests/data/*"]
37 changes: 23 additions & 14 deletions narwhals/_plan/_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
binary_expr_multi_output_error,
column_not_found_error,
duplicate_error,
expand_multi_output_error,
selectors_not_found_error,
)
from narwhals._plan.expressions import (
Alias,
Expand All @@ -57,7 +59,6 @@
from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema, freeze_schema
from narwhals._typing_compat import assert_never
from narwhals._utils import check_column_names_are_unique, zip_strict
from narwhals.exceptions import MultiOutputExpressionError

if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Iterator, Sequence
Expand Down Expand Up @@ -99,7 +100,12 @@ def prepare_projection(


def expand_selector_irs_names(
selectors: Sequence[SelectorIR], /, ignored: Ignored = (), *, schema: IntoFrozenSchema
selectors: Sequence[SelectorIR],
/,
ignored: Ignored = (),
*,
schema: IntoFrozenSchema,
require_any: bool = False,
) -> OutputNames:
"""Expand selector-only input into the column names that match.

Expand All @@ -110,11 +116,15 @@ def expand_selector_irs_names(
selectors: IRs that **only** contain subclasses of `SelectorIR`.
ignored: Names of `group_by` columns.
schema: Scope to expand selectors in.
require_any: Raise if the entire expansion selected zero columns.
"""
names = tuple(Expander(schema, ignored).iter_expand_selector_names(selectors))
if len(names) != len(set(names)):
# NOTE: Can't easily reuse `duplicate_error`, falling back to main for now
check_column_names_are_unique(names)
expander = Expander(schema, ignored)
if names := tuple(expander.iter_expand_selector_names(selectors)):
if len(names) != len(set(names)):
# NOTE: Can't easily reuse `duplicate_error`, falling back to main for now
check_column_names_are_unique(names)
elif require_any:
raise selectors_not_found_error(selectors, expander.schema)
return names


Expand Down Expand Up @@ -245,15 +255,14 @@ def _expand_inner(self, children: Seq[ExprIR], /) -> Iterator[ExprIR]:
for child in children:
yield from self._expand_recursive(child)

def _expand_only(self, child: ExprIR, /) -> ExprIR:
def _expand_only(self, origin: ExprIR, child: ExprIR, /) -> ExprIR:
# used by
# - `_expand_combination` (ExprIR fields)
# - `_expand_function_expr` (all others that have len(inputs)>=2, call on non-root)
iterable = self._expand_recursive(child)
first = next(iterable)
if second := next(iterable, None):
msg = f"Multi-output expressions are not supported in this context, got: `{second!r}`" # pragma: no cover
raise MultiOutputExpressionError(msg) # pragma: no cover
raise expand_multi_output_error(origin, child, first, second, *iterable)
return first

# TODO @dangotbanned: It works, but all this class-specific branching belongs in the classes themselves
Expand All @@ -268,16 +277,16 @@ def _expand_combination(self, origin: Combination, /) -> Iterator[Combination]:
elif isinstance(origin, ir.SortBy):
changes["by"] = tuple(self._expand_inner(origin.by))
else:
changes["by"] = self._expand_only(origin.by)
changes["by"] = self._expand_only(origin, origin.by)
replaced = common.replace(origin, **changes)
for root in self._expand_recursive(replaced.expr):
yield common.replace(replaced, expr=root)
elif isinstance(origin, ir.BinaryExpr):
yield from self._expand_binary_expr(origin)
elif isinstance(origin, ir.TernaryExpr):
changes["truthy"] = self._expand_only(origin.truthy)
changes["predicate"] = self._expand_only(origin.predicate)
changes["falsy"] = self._expand_only(origin.falsy)
changes["truthy"] = self._expand_only(origin, origin.truthy)
changes["predicate"] = self._expand_only(origin, origin.predicate)
changes["falsy"] = self._expand_only(origin, origin.falsy)
yield origin.__replace__(**changes)
else:
assert_never(origin)
Expand Down Expand Up @@ -316,7 +325,7 @@ def _expand_function_expr(
yield origin.__replace__(input=reduced)
else:
if non_root := origin.input[1:]:
children = tuple(self._expand_only(child) for child in non_root)
children = tuple(self._expand_only(origin, child) for child in non_root)
else:
children = ()
for root in self._expand_recursive(origin.input[0]):
Expand Down
13 changes: 12 additions & 1 deletion narwhals/_plan/_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def is_selector(obj: Any) -> TypeIs[Selector]:
return isinstance(obj, _selectors().Selector)


def is_expr_column(obj: Any) -> TypeIs[Expr]:
    """Indicate if the given object is a basic/unaliased column."""
    if not is_expr(obj):
        return False
    return obj.meta.is_column()

Expand Down Expand Up @@ -136,8 +136,19 @@ def is_literal(obj: Any) -> TypeIs[ir.Literal[Any]]:
# TODO @dangotbanned: Coverage
# Used in `ArrowNamespace._vertical`, but only horizontal is covered
def is_tuple_of(obj: Any, tp: type[T]) -> TypeIs[Seq[T]]:  # pragma: no cover
    """Return True if the **first** element of the tuple `obj` is an instance of `tp`.

    Only the head is inspected; callers rely on the tuple being homogeneous.
    """
    if not isinstance(obj, tuple) or not obj:
        return False
    return isinstance(obj[0], tp)


def is_re_pattern(obj: Any) -> TypeIs[re.Pattern[str]]:
    """Indicate if the given object is a compiled regular expression."""
    return isinstance(obj, re.Pattern)


def is_seq_column(exprs: Seq[ir.ExprIR], /) -> TypeIs[Seq[ir.Column]]:
    """Return True if **every** element is a `Column`.

    Use this for detecting fastpaths in sub-expressions, that can rely on
    every element in `exprs` having a resolved `name` attribute.
    """
    # Lowercase local avoids the `noqa: N806` the capitalized alias needed.
    column_tp = _ir().Column
    return all(isinstance(e, column_tp) for e in exprs)
67 changes: 35 additions & 32 deletions narwhals/_plan/arrow/acero.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
import pyarrow.compute as pc # ignore-banned-import
from pyarrow.acero import Declaration as Decl

from narwhals._plan.common import ensure_list_str, flatten_hash_safe, temp
from narwhals._plan.options import SortMultipleOptions
from narwhals._plan.common import ensure_list_str, temp
from narwhals._plan.typing import NonCrossJoinStrategy, OneOrSeq
from narwhals._utils import check_column_names_are_unique
from narwhals.typing import JoinStrategy, SingleColSelector
Expand All @@ -48,13 +47,8 @@
Aggregation as _Aggregation,
)
from narwhals._plan.arrow.group_by import AggSpec
from narwhals._plan.arrow.typing import (
ArrowAny,
JoinTypeSubset,
NullPlacement,
ScalarAny,
)
from narwhals._plan.typing import OneOrIterable, Order, Seq
from narwhals._plan.arrow.typing import ArrowAny, JoinTypeSubset, ScalarAny
from narwhals._plan.typing import OneOrIterable, Seq
from narwhals.typing import NonNestedLiteral

Incomplete: TypeAlias = Any
Expand Down Expand Up @@ -238,29 +232,6 @@ def prepend_column(native: pa.Table, name: str, values: IntoExpr) -> Decl:
return _add_column(native, 0, name, values)


def _order_by(
    sort_keys: Iterable[tuple[str, Order]] = (),
    *,
    null_placement: NullPlacement = "at_end",
) -> Decl:
    """Build an acero ``order_by`` node from ``(name, order)`` sort keys."""
    # NOTE: There's no runtime type checking of `sort_keys` wrt shape
    # Just need to be `Iterable` and unpack like a 2-tuple
    # https://github.com/apache/arrow/blob/9b96bdbc733d62f0375a2b1b9806132abc19cd3f/python/pyarrow/_compute.pyx#L77-L88
    # `Incomplete` sidesteps the stricter stub signature for `OrderByNodeOptions`.
    keys: Incomplete = sort_keys
    return Decl("order_by", pac.OrderByNodeOptions(keys, null_placement=null_placement))


def sort_by(
    by: OneOrIterable[str],
    *more_by: str,
    descending: OneOrIterable[bool] = False,
    nulls_last: bool = False,
) -> Decl:
    """Build an acero sort declaration over one or more key columns."""
    names = tuple(flatten_hash_safe((by, more_by)))
    options = SortMultipleOptions.parse(descending=descending, nulls_last=nulls_last)
    return options.to_arrow_acero(names)


def _join_options(
how: NonCrossJoinStrategy,
left_on: OneOrIterable[str],
Expand Down Expand Up @@ -406,6 +377,38 @@ def join_tables(
return collect(_hashjoin(left, right, opts), ensure_unique_column_names=True)


# TODO @dangotbanned: Adapt this into a `Declaration` that handles more of `ArrowGroupBy.agg_over`
def join_inner_tables(left: pa.Table, right: pa.Table, on: list[str]) -> pa.Table:
    """Fast path for use with `over`.

    Has almost zero branching and the bodies of helper functions are inlined.

    Arguments:
        left: Table providing the join keys.
        right: Table whose non-key columns are appended to `left`'s output.
        on: Key column names; assumed present in both tables.

    Eventually want to adapt this into:

        goal = declare(
            join_inner(
                declare(table_source(compliant.native), select_names(key_names)),
                declare(table_source(ordered), group_by(key_names, specs)),
            ),
            select_names(agg_names),
        )
    """
    # `Incomplete` alias: presumably works around an overly strict stub for
    # `HashJoinNodeOptions` (generator passed as `right_output`) — TODO confirm.
    tp: Incomplete = pac.HashJoinNodeOptions
    opts = tp(
        "inner",
        left_keys=on,
        right_keys=on,
        left_output=left.schema.names,
        # Key columns are emitted once (from the left side only).
        right_output=(name for name in right.schema.names if name not in on),
        output_suffix_for_right="_right",
    )
    lhs, rhs = pac.TableSourceNodeOptions(left), pac.TableSourceNodeOptions(right)
    decl = Decl("hashjoin", opts, [Decl("table_source", lhs), Decl("table_source", rhs)])
    result = decl.to_table()
    # Guard against suffix collisions producing duplicate output names.
    check_column_names_are_unique(result.column_names)
    return result


def join_cross_tables(
left: pa.Table, right: pa.Table, suffix: str = "_right", *, coalesce_keys: bool = True
) -> pa.Table:
Expand Down
61 changes: 61 additions & 0 deletions narwhals/_plan/arrow/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Behavior shared by two or more classes."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar, Generic

from narwhals._plan.arrow.functions import BACKEND_VERSION
from narwhals._typing_compat import TypeVar
from narwhals._utils import Implementation, Version, _StoresNative

if TYPE_CHECKING:
import pyarrow as pa
from typing_extensions import Self, TypeIs

from narwhals._plan.arrow.namespace import ArrowNamespace
from narwhals._plan.arrow.typing import ChunkedArrayAny, Indices


def is_series(obj: Any) -> TypeIs[_StoresNative[ChunkedArrayAny]]:
    """Indicate if `obj` is an `ArrowSeries`, narrowed to its native-storage protocol."""
    # Deferred import — presumably avoids a circular import at module load; confirm.
    from narwhals._plan.arrow.series import ArrowSeries

    return isinstance(obj, ArrowSeries)


NativeT = TypeVar("NativeT", "pa.Table", "ChunkedArrayAny")


class ArrowFrameSeries(Generic[NativeT]):
    """Behavior shared by the pyarrow-backed frame and series wrappers.

    `NativeT` is the wrapped pyarrow object (`pa.Table` or a chunked array).
    """

    implementation: ClassVar = Implementation.PYARROW
    _native: NativeT
    _version: Version

    @property
    def native(self) -> NativeT:
        """The wrapped pyarrow object."""
        return self._native

    def __narwhals_namespace__(self) -> ArrowNamespace:
        from narwhals._plan.arrow.namespace import ArrowNamespace

        return ArrowNamespace(self._version)

    def _with_native(self, native: NativeT) -> Self:
        # Base class cannot rebuild itself; subclasses override this.
        msg = f"{type(self).__name__}._with_native"
        raise NotImplementedError(msg)

    if BACKEND_VERSION >= (18,):

        def _gather(self, indices: Indices) -> NativeT:
            return self.native.take(indices)
    else:
        # Older pyarrow's `take` is stricter about the index container,
        # so hand it a list when given a tuple.
        def _gather(self, indices: Indices) -> NativeT:
            if isinstance(indices, tuple):
                return self.native.take(list(indices))
            return self.native.take(indices)

    def gather(self, indices: Indices | _StoresNative[ChunkedArrayAny]) -> Self:
        """Select by position, accepting raw indices or a wrapped series."""
        resolved = indices.native if is_series(indices) else indices
        return self._with_native(self._gather(resolved))

    def slice(self, offset: int, length: int | None = None) -> Self:
        """Return a wrapped slice of the native object."""
        return self._with_native(self.native.slice(offset=offset, length=length))
49 changes: 34 additions & 15 deletions narwhals/_plan/arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@

from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._plan.arrow import acero, functions as fn
from narwhals._plan.arrow.common import ArrowFrameSeries as FrameSeries
from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy, partition_by
from narwhals._plan.arrow.series import ArrowSeries as Series
from narwhals._plan.compliant.dataframe import EagerDataFrame
from narwhals._plan.compliant.typing import namespace
from narwhals._plan.expressions import NamedIR
from narwhals._utils import Implementation, Version
from narwhals._utils import Version, generate_repr
from narwhals.schema import Schema

if TYPE_CHECKING:
Expand All @@ -25,29 +26,31 @@
import polars as pl
from typing_extensions import Self

from narwhals._arrow.typing import ChunkedArrayAny
from narwhals._plan.arrow.namespace import ArrowNamespace
from narwhals._plan.arrow.typing import ChunkedArrayAny
from narwhals._plan.compliant.group_by import GroupByResolver
from narwhals._plan.expressions import ExprIR, NamedIR
from narwhals._plan.options import SortMultipleOptions
from narwhals._plan.typing import NonCrossJoinStrategy
from narwhals.dtypes import DType
from narwhals.typing import IntoSchema


class ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]):
implementation = Implementation.PYARROW
_native: pa.Table
_version: Version
class ArrowDataFrame(
FrameSeries["pa.Table"], EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]
):
def __repr__(self) -> str:
return generate_repr(f"nw.{type(self).__name__}", self.native.__repr__())

def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._plan.arrow.namespace import ArrowNamespace

return ArrowNamespace(self._version)
def _with_native(self, native: pa.Table) -> Self:
return self.from_native(native, self.version)

@property
def _group_by(self) -> type[GroupBy]:
return GroupBy

def group_by_resolver(self, resolver: GroupByResolver, /) -> GroupBy:
return self._group_by.from_resolver(self, resolver)

@property
def columns(self) -> list[str]:
return self.native.column_names
Expand Down Expand Up @@ -98,14 +101,26 @@ def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[Series]
from_named_ir = ns._expr.from_named_ir
yield from ns._expr.align(from_named_ir(e, self) for e in nodes)

def sort(self, by: Sequence[str], options: SortMultipleOptions) -> Self:
native = self.native
indices = pc.sort_indices(native.select(list(by)), options=options.to_arrow(by))
return self._with_native(native.take(indices))
def sort(self, by: Sequence[str], options: SortMultipleOptions | None = None) -> Self:
    """Sort rows by the `by` columns; `None` defers to `fn.sort_indices` defaults."""
    return self.gather(fn.sort_indices(self.native, *by, options=options))

def with_row_index(self, name: str) -> Self:
    """Prepend a `name` column of row indices generated by `fn.int_range`."""
    return self._with_native(self.native.add_column(0, name, fn.int_range(len(self))))

def with_row_index_by(
    self,
    name: str,
    order_by: Sequence[str],
    *,
    descending: bool = False,
    nulls_last: bool = False,
) -> Self:
    """Prepend a `name` column giving each row's position under `order_by` ordering.

    Rows are NOT reordered: sort indices are computed, then inverted via
    `fn.unsort_indices` (presumably the permutation inverse — confirm) so the
    new column maps each original row to its rank in the sorted order.
    """
    indices = fn.sort_indices(
        self.native, *order_by, nulls_last=nulls_last, descending=descending
    )
    column = fn.unsort_indices(indices)
    return self._with_native(self.native.add_column(0, name, column))

def get_column(self, name: str) -> Series:
chunked = self.native.column(name)
return Series.from_native(chunked, name, version=self.version)
Expand Down Expand Up @@ -168,6 +183,10 @@ def join_cross(self, other: Self, *, suffix: str = "_right") -> Self:
result = acero.join_cross_tables(self.native, other.native, suffix=suffix)
return self._with_native(result)

def join_inner(self, other: Self, on: list[str], /) -> Self:
    """Less flexible, but more direct equivalent to `join(how="inner", left_on=...)`."""
    return self._with_native(acero.join_inner_tables(self.native, other.native, on))

def filter(self, predicate: NamedIR) -> Self:
mask: pc.Expression | ChunkedArrayAny
resolved = Expr.from_named_ir(predicate, self)
Expand Down
Loading
Loading