Skip to content

Commit c8d09ed

Browse files
committed
feat(DRAFT): Porting (#3332)
Tried to keep everything as close to original as possible Next step is simplifying everything and fixing `list.sum`
1 parent 2c2fa08 commit c8d09ed

5 files changed

Lines changed: 100 additions & 20 deletions

File tree

narwhals/_plan/arrow/expr.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,14 @@
1515
is_seq_column,
1616
)
1717
from narwhals._plan.arrow import functions as fn
18+
from narwhals._plan.arrow.group_by import AggSpec
1819
from narwhals._plan.arrow.series import ArrowSeries as Series
19-
from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co
20+
from narwhals._plan.arrow.typing import (
21+
ChunkedOrArrayAny,
22+
ChunkedOrScalarAny,
23+
NativeScalar,
24+
StoresNativeT_co,
25+
)
2026
from narwhals._plan.common import temp
2127
from narwhals._plan.compliant.accessors import (
2228
ExprCatNamespace,
@@ -994,11 +1000,47 @@ def join(self, node: FExpr[lists.Join], frame: Frame, name: str) -> Expr | Scala
9941000
)
9951001
return self.with_native(result, name)
9961002

997-
min = not_implemented()
998-
max = not_implemented()
999-
mean = not_implemented()
1000-
median = not_implemented()
1001-
sum = not_implemented()
1003+
def aggregate(
1004+
self, node: FExpr[lists.Aggregation], frame: Frame, name: str
1005+
) -> Expr | Scalar:
1006+
previous = node.input[0].dispatch(self.compliant, frame, name)
1007+
func = node.function
1008+
if isinstance(previous, ArrowScalar):
1009+
msg = f"TODO: ArrowScalar.{func!r}"
1010+
raise NotImplementedError(msg)
1011+
1012+
native = previous.native
1013+
lists = native
1014+
# TODO @dangotbanned: Experiment with explode step
1015+
# These options are to mirror `main`, but setting them to `True` may simplify everything after?
1016+
builder = fn.ExplodeBuilder(empty_as_null=False, keep_nulls=False)
1017+
explode_w_idx = builder.explode_with_indices(lists)
1018+
idx, v = "idx", "values"
1019+
agg_result = (
1020+
AggSpec._from_agg(type(func), v)
1021+
.over(explode_w_idx, [idx])
1022+
.sort_by(idx)
1023+
.column(v)
1024+
)
1025+
dtype: pa.DataType = agg_result.type
1026+
non_empty_mask = fn.not_eq(fn.list_len(lists), fn.lit(0))
1027+
base_array: ChunkedOrArrayAny
1028+
if isinstance(func, ir.lists.Sum):
1029+
# Make sure sum of empty list is 0.
1030+
base_array = fn.when_then(fn.is_not_null(non_empty_mask), fn.lit(0, dtype))
1031+
else:
1032+
base_array = fn.repeat_unchecked(fn.lit(None, dtype), len(lists))
1033+
replaced = fn.replace_with_mask(
1034+
base_array, fn.fill_null(non_empty_mask, False), agg_result
1035+
)
1036+
result = fn.chunked_array(replaced)
1037+
return self.with_native(result, name)
1038+
1039+
min = aggregate
1040+
max = aggregate
1041+
mean = aggregate
1042+
median = aggregate
1043+
sum = aggregate
10021044

10031045

10041046
class ArrowStringNamespace(

narwhals/_plan/arrow/functions.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,17 @@ def explode(
442442
return chunked_array(_list_explode(safe))
443443

444444
def explode_with_indices(self, native: ChunkedList | ListArray) -> pa.Table:
445+
"""Explode list elements, expanding one-level into a table indexing the origin.
446+
447+
Returns a 2-column table, with names `"idx"` and `"values"`:
448+
449+
>>> from narwhals._plan.arrow import functions as fn
450+
>>>
451+
>>> arr = fn.array([[1, 2, 3], None, [4, 5, 6], []])
452+
>>> fn.ExplodeBuilder().explode_with_indices(arr).to_pydict()
453+
{'idx': [0, 0, 0, 1, 2, 2, 2, 3], 'values': [1, 2, 3, None, 4, 5, 6, None]}
454+
# ^ Which sublist we came from ^ The exploded values themselves
455+
"""
445456
safe = self._fill_with_null(native) if self.options.any() else native
446457
arrays = [_list_parent_indices(safe), _list_explode(safe)]
447458
return concat_horizontal(arrays, ["idx", "values"])
@@ -1042,6 +1053,12 @@ def _str_zfill_compat(
10421053
)
10431054

10441055

1056+
@t.overload
1057+
def when_then(
1058+
predicate: ChunkedArray[BooleanScalar], then: ScalarAny
1059+
) -> ChunkedArrayAny: ...
1060+
@t.overload
1061+
def when_then(predicate: Array[BooleanScalar], then: ScalarAny) -> ArrayAny: ...
10451062
@t.overload
10461063
def when_then(
10471064
predicate: Predicate, then: SameArrowT, otherwise: SameArrowT | None
@@ -1059,6 +1076,11 @@ def when_then(
10591076
def when_then(
10601077
predicate: Predicate, then: ArrowAny, otherwise: ArrowAny | NonNestedLiteral = None
10611078
) -> Incomplete:
1079+
"""Thin wrapper around `pyarrow.compute.if_else`.
1080+
1081+
- Supports a 2-arg form, like `pl.when(...).then(...)`
1082+
- Accepts python literals, but only in the `otherwise` position
1083+
"""
10621084
if is_non_nested_literal(otherwise):
10631085
otherwise = lit(otherwise, then.type)
10641086
return pc.if_else(predicate, then, otherwise)

narwhals/_plan/arrow/group_by.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from narwhals._plan.common import temp
1414
from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy
1515
from narwhals._plan.expressions import aggregation as agg
16-
from narwhals._utils import Implementation
16+
from narwhals._utils import Implementation, qualified_type_name
1717
from narwhals.exceptions import InvalidOperationError
1818

1919
if TYPE_CHECKING:
@@ -51,6 +51,13 @@
5151
agg.Last: "hash_last",
5252
fn.MinMax: "hash_min_max",
5353
}
54+
SUPPORTED_LIST_AGG: Mapping[type[ir.lists.Aggregation], type[agg.AggExpr]] = {
55+
ir.lists.Mean: agg.Mean,
56+
ir.lists.Median: agg.Median,
57+
ir.lists.Max: agg.Max,
58+
ir.lists.Min: agg.Min,
59+
ir.lists.Sum: agg.Sum,
60+
}
5461
SUPPORTED_IR: Mapping[type[ir.ExprIR], acero.Aggregation] = {
5562
ir.Len: "hash_count_all",
5663
ir.Column: "hash_list", # `hash_aggregate` only
@@ -141,6 +148,14 @@ def from_expr_ir(cls, expr: ir.ExprIR, name: acero.OutputName) -> Self:
141148
def _from_function(cls, tp: type[ir.Function], name: str) -> Self:
142149
return cls(name, SUPPORTED_FUNCTION[tp], options.FUNCTION.get(tp), name)
143150

151+
@classmethod
152+
def _from_agg(cls, tp: type[ir.lists.Aggregation | agg.AggExpr], name: str) -> Self:
153+
tp_agg = SUPPORTED_LIST_AGG[tp] if issubclass(tp, ir.lists.ListFunction) else tp
154+
if tp_agg in {agg.Std, agg.Var}:
155+
msg = f"TODO: {qualified_type_name(agg)!r} needs access to `ddof`, so can't be passed in without an instance"
156+
raise NotImplementedError(msg)
157+
return cls(name, SUPPORTED_AGG[tp_agg], options.AGG.get(tp_agg), name)
158+
144159
@classmethod
145160
def any(cls, name: str) -> Self:
146161
return cls._from_function(ir.boolean.Any, name)

narwhals/_plan/expressions/lists.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from narwhals.exceptions import InvalidOperationError
1212

1313
if TYPE_CHECKING:
14-
from typing_extensions import Self
14+
from typing_extensions import Self, TypeAlias
1515

1616
from narwhals._plan.expr import Expr
1717
from narwhals._plan.expressions import ExprIR, FunctionExpr as FExpr
@@ -45,6 +45,9 @@ def unwrap_input(self, node: FExpr[Self], /) -> tuple[ExprIR, ExprIR]:
4545
return expr, item
4646

4747

48+
Aggregation: TypeAlias = "Min | Max | Mean | Median | Sum"
49+
50+
4851
class IRListNamespace(IRNamespace):
4952
len: ClassVar = Len
5053
unique: ClassVar = Unique

tests/plan/list_agg_test.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,34 +29,32 @@ def data_median(data: Data) -> Data:
2929
cast_a = a.cast(nw.List(nw.Int32))
3030

3131

32-
XFAIL_NOT_IMPL = pytest.mark.xfail(
33-
reason="TODO: ArrowExpr.list.<agg>", raises=NotImplementedError
34-
)
35-
36-
37-
@XFAIL_NOT_IMPL
3832
@pytest.mark.parametrize(
3933
("exprs", "expected"),
4034
[
4135
(a.list.max(), {"a": [4, -1, None, None, None]}),
4236
(a.list.mean(), {"a": [2.75, -1, None, None, None]}),
4337
(a.list.min(), {"a": [2, -1, None, None, None]}),
44-
(a.list.sum(), {"a": [11, -1, None, 0, 0]}),
38+
pytest.param(
39+
a.list.sum(),
40+
{"a": [11, -1, None, 0, 0]},
41+
marks=pytest.mark.xfail(
42+
reason="Mismatch at index 3, key a: None != 0", raises=AssertionError
43+
),
44+
),
4545
],
46+
ids=["max", "mean", "min", "sum"],
4647
)
47-
def test_list_agg(
48-
data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data
49-
) -> None: # pragma: no cover
48+
def test_list_agg(data: Data, exprs: OneOrIterable[nwp.Expr], expected: Data) -> None:
5049
df = dataframe(data).with_columns(cast_a)
5150
result = df.select(exprs)
5251
assert_equal_data(result, expected)
5352

5453

55-
@XFAIL_NOT_IMPL
5654
@pytest.mark.xfail(
5755
is_windows() and sys.version_info < (3, 10), reason="Old pyarrow windows bad?"
5856
)
59-
def test_list_median(data_median: Data) -> None: # pragma: no cover
57+
def test_list_median(data_median: Data) -> None:
6058
df = dataframe(data_median).with_columns(cast_a)
6159
result = df.select(a.list.median())
6260

0 commit comments

Comments
 (0)