|
15 | 15 | is_seq_column, |
16 | 16 | ) |
17 | 17 | from narwhals._plan.arrow import functions as fn |
| 18 | +from narwhals._plan.arrow.group_by import AggSpec |
18 | 19 | from narwhals._plan.arrow.series import ArrowSeries as Series |
19 | | -from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co |
| 20 | +from narwhals._plan.arrow.typing import ( |
| 21 | + ChunkedOrArrayAny, |
| 22 | + ChunkedOrScalarAny, |
| 23 | + NativeScalar, |
| 24 | + StoresNativeT_co, |
| 25 | +) |
20 | 26 | from narwhals._plan.common import temp |
21 | 27 | from narwhals._plan.compliant.accessors import ( |
22 | 28 | ExprCatNamespace, |
@@ -994,11 +1000,47 @@ def join(self, node: FExpr[lists.Join], frame: Frame, name: str) -> Expr | Scala |
994 | 1000 | ) |
995 | 1001 | return self.with_native(result, name) |
996 | 1002 |
|
997 | | - min = not_implemented() |
998 | | - max = not_implemented() |
999 | | - mean = not_implemented() |
1000 | | - median = not_implemented() |
1001 | | - sum = not_implemented() |
| 1003 | + def aggregate( |
| 1004 | + self, node: FExpr[lists.Aggregation], frame: Frame, name: str |
| 1005 | + ) -> Expr | Scalar: |
| 1006 | + previous = node.input[0].dispatch(self.compliant, frame, name) |
| 1007 | + func = node.function |
| 1008 | + if isinstance(previous, ArrowScalar): |
| 1009 | + msg = f"TODO: ArrowScalar.{func!r}" |
| 1010 | + raise NotImplementedError(msg) |
| 1011 | + |
| 1012 | + native = previous.native |
| 1013 | + lists = native |
| 1014 | + # TODO @dangotbanned: Experiment with explode step |
| 1015 | + # These options are to mirror `main`, but setting them to `True` may simplify everything after? |
| 1016 | + builder = fn.ExplodeBuilder(empty_as_null=False, keep_nulls=False) |
| 1017 | + explode_w_idx = builder.explode_with_indices(lists) |
| 1018 | + idx, v = "idx", "values" |
| 1019 | + agg_result = ( |
| 1020 | + AggSpec._from_agg(type(func), v) |
| 1021 | + .over(explode_w_idx, [idx]) |
| 1022 | + .sort_by(idx) |
| 1023 | + .column(v) |
| 1024 | + ) |
| 1025 | + dtype: pa.DataType = agg_result.type |
| 1026 | + non_empty_mask = fn.not_eq(fn.list_len(lists), fn.lit(0)) |
| 1027 | + base_array: ChunkedOrArrayAny |
| 1028 | + if isinstance(func, ir.lists.Sum): |
| 1029 | + # Make sure sum of empty list is 0. |
| 1030 | + base_array = fn.when_then(fn.is_not_null(non_empty_mask), fn.lit(0, dtype)) |
| 1031 | + else: |
| 1032 | + base_array = fn.repeat_unchecked(fn.lit(None, dtype), len(lists)) |
| 1033 | + replaced = fn.replace_with_mask( |
| 1034 | + base_array, fn.fill_null(non_empty_mask, False), agg_result |
| 1035 | + ) |
| 1036 | + result = fn.chunked_array(replaced) |
| 1037 | + return self.with_native(result, name) |
| 1038 | + |
| 1039 | + min = aggregate |
| 1040 | + max = aggregate |
| 1041 | + mean = aggregate |
| 1042 | + median = aggregate |
| 1043 | + sum = aggregate |
1002 | 1044 |
|
1003 | 1045 |
|
1004 | 1046 | class ArrowStringNamespace( |
|
0 commit comments