Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
fadddc3
feat: add list aggregate methods
raisadz Nov 28, 2025
040527b
xfail old pandas and skip if no pyarrow
raisadz Nov 28, 2025
9b45555
add the new methods to the polars list namespace
raisadz Nov 28, 2025
7bdd2d2
xfail old polars for median and pyspark for non implemented methods
raisadz Nov 28, 2025
74934c0
unxfail modin
raisadz Nov 28, 2025
7efe2d2
add no cover for non-pyarrow backends
raisadz Nov 28, 2025
8e04fc1
link pyspark and ibis issues
raisadz Nov 28, 2025
7ab1ebf
add sum/mean/median for PySpark
raisadz Nov 28, 2025
6b57809
xfail pyspark[connect], add no cover for sqlframe
raisadz Nov 28, 2025
7dfc9fa
xfail pyspark connect
raisadz Nov 28, 2025
1243980
handle empty lists for pyarrow, tests for empty lists
raisadz Nov 29, 2025
7eca29a
undo typo
raisadz Nov 29, 2025
545abf8
add None case
raisadz Nov 29, 2025
74de5c6
fix list_agg and tests
raisadz Nov 29, 2025
65072ff
adjust duckdb
raisadz Nov 30, 2025
ca4c794
adjust pyarrow and tests
raisadz Nov 30, 2025
3251865
add `try_divide` for pyspark mean, set min duckdb version for lambda_…
raisadz Nov 30, 2025
eff085c
xfail old duckdb for sum
raisadz Nov 30, 2025
fdcf3f3
fux typing
raisadz Nov 30, 2025
146c458
adjust pyspark median
raisadz Nov 30, 2025
4d18654
fix typing
raisadz Nov 30, 2025
9975334
adjust ibis sum
raisadz Dec 1, 2025
be000f5
mix docstrings, add a test where there is a mismatch for median
raisadz Dec 1, 2025
a470f8f
Merge remote-tracking branch 'upstream/main' into feat/list-agg
raisadz Dec 1, 2025
9b53f03
xfail median for pyarrow and python below 3.10
raisadz Dec 1, 2025
a549576
add no cover
raisadz Dec 1, 2025
76c70ff
update the error msg
raisadz Dec 1, 2025
d716fd1
Merge remote-tracking branch 'upstream/main' into feat/list-agg
raisadz Dec 3, 2025
aa7bad3
Merge remote-tracking branch 'upstream/main' into feat/list-agg
raisadz Dec 8, 2025
54d7041
Update narwhals/_spark_like/expr_list.py
raisadz Dec 9, 2025
47f5a4b
remove a minimum 3.10 python version
raisadz Dec 9, 2025
9bcdebe
remove xfail from tests
raisadz Dec 9, 2025
3ab7639
skip old Python on windows tests for median
raisadz Dec 12, 2025
2c913d7
Merge remote-tracking branch 'upstream/main' into feat/list-agg
raisadz Dec 12, 2025
687c4ae
add no cover
raisadz Dec 12, 2025
c851f10
modify list_agg as suggested
raisadz Dec 13, 2025
310daa6
simplify tests
raisadz Dec 13, 2025
eca1c02
Merge remote-tracking branch 'upstream/main' into feat/list-agg
raisadz Dec 13, 2025
398c350
fix typing
raisadz Dec 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api-reference/expr_list.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
- contains
- get
- len
- max
- mean
- median
- min
- sum
- unique
show_source: false
show_bases: false
5 changes: 5 additions & 0 deletions docs/api-reference/series_list.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
- contains
- get
- len
- max
- mean
- median
- min
- sum
- unique
show_source: false
show_bases: false
17 changes: 16 additions & 1 deletion narwhals/_arrow/series_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import ArrowSeriesNamespace
from narwhals._arrow.utils import ArrowSeriesNamespace, list_agg
from narwhals._compliant.any_namespace import ListNamespace
from narwhals._utils import not_implemented

Expand All @@ -20,5 +20,20 @@ def len(self) -> ArrowSeries:
def get(self, index: int) -> ArrowSeries:
    """Return the element at position ``index`` of each list."""
    element = pc.list_element(self.native, index)
    return self.with_native(element)

def min(self) -> ArrowSeries:
    """Return the minimum of each list, computed by ``list_agg`` with ``"min"``."""
    per_list_min = list_agg(self.native, "min")
    return self.with_native(per_list_min)

def max(self) -> ArrowSeries:
    """Return the maximum of each list, computed by ``list_agg`` with ``"max"``."""
    per_list_max = list_agg(self.native, "max")
    return self.with_native(per_list_max)

def mean(self) -> ArrowSeries:
    """Return the mean of each list, computed by ``list_agg`` with ``"mean"``."""
    per_list_mean = list_agg(self.native, "mean")
    return self.with_native(per_list_mean)

def median(self) -> ArrowSeries:
    """Return the median of each list.

    The value comes from ``list_agg`` with the ``"approximate_median"``
    aggregation, so it is whatever that aggregation reports per list.
    """
    per_list_median = list_agg(self.native, "approximate_median")
    return self.with_native(per_list_median)

def sum(self) -> ArrowSeries:
    """Return the sum of each list, computed by ``list_agg`` with ``"sum"``."""
    per_list_sum = list_agg(self.native, "sum")
    return self.with_native(per_list_sum)

# `list.unique` and `list.contains` are bound to the `not_implemented()` marker
# here instead of being defined as methods.
unique = not_implemented()
contains = not_implemented()
38 changes: 38 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping
from typing import Literal

from typing_extensions import TypeAlias, TypeIs

Expand Down Expand Up @@ -494,3 +495,40 @@ def arange(start: int, end: int, step: int) -> ArrayAny:
return pa.array(np.arange(start, end, step))
# NOTE: Added in https://github.com/apache/arrow/pull/46778
return pa.arange(start, end, step) # type: ignore[attr-defined]


# Aggregate each list of a list-typed chunked array down to one value per list.
# `func` names the group-by aggregation to apply; the aggregated column is then
# read back under the name `values_{func}`.
def list_agg(
array: ChunkedArrayAny,
func: Literal["min", "max", "mean", "approximate_median", "sum"],
) -> ChunkedArrayAny:
# Alias of `lit` annotated as `Incomplete`; presumably this relaxes type
# checking where it is later called with a `type=` keyword - confirm.
lit_: Incomplete = lit
# Only "sum" is given explicit options (min_count=0); every other `func` is
# passed without options.
aggregation = (
("values", func, pc.ScalarAggregateOptions(min_count=0))
if func == "sum"
else ("values", func)
)
# Pair the flattened list values with the index of the list each value came
# from, aggregate per parent index, and sort by that index.
# NOTE(review): lists that contribute no flattened values (null or empty)
# presumably produce no group, so `agg` may be shorter than `array` - confirm.
agg = pa.array(
pa.Table.from_arrays(
[pc.list_flatten(array), pc.list_parent_indices(array)],
names=["values", "offsets"],
)
.group_by("offsets")
.aggregate([aggregation])
.sort_by("offsets")
.column(f"values_{func}")
)
Comment on lines +500 to +519
Copy link
Copy Markdown
Member

@dangotbanned dangotbanned Nov 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@raisadz I'm pretty excited by this! 😄

+1 from me on (#3332 (review))


I've just tried this out with the test case for list.unique:

data = {"a": [[2, 2, 3, None, None], None, [], [None]]}

The result for that should be:

[[None, 2, 3], None, [], [None]]

But using list_agg seems to have dropped 2/4 lists and all nulls 🤔

import pyarrow as pa

data = {"a": [[2, 2, 3, None, None], None, [], [None]]}
ca = pa.chunked_array([pa.array(data["a"])])
result = list_agg(ca, "distinct").to_pylist()
print(result)
[[2, 3], []]

I managed to get slightly closer to what we want, by passing in options for the group_by:

Show list_agg_opts

from typing import Any

import pyarrow as pa
import pyarrow.compute as pc

def list_agg_opts(
    array: pa.ChunkedArray[Any], func: Any, options: Any = None
) -> pa.ChunkedArray[Any]:
    return (
        pa.Table.from_arrays(
            [pc.list_flatten(array), pc.list_parent_indices(array)],
            names=["values", "offsets"],
        )
        .group_by("offsets")
        .aggregate([("values", func, options)])  # <-------
        .column(f"values_{func}")
    )

These are the correct results for 2/4 of the lists 🎉

But where did the other 2 go? 😳

result = list_agg_opts(ca, "distinct", pc.CountOptions("all")).to_pylist()
print(result)
[[2, 3, None], [None]]

Edit: I missed it myself lol, fixed in (d8363e1)

Comment on lines +510 to 519
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a comment rather than anything else, but I want to mention it - maybe someone has ideas.

In the case of multiple expressions on the same list column, we would end up running the same (expensive?) operations (flatten + group by + sort). I wonder if for the case of expressions there is a way to execute them together.

In code, I would expect the following:

frame.select(
    a_sum = nw.col("a").list.sum(),
    a_mean = nw.col("a").list.mean()
)

to execute as:

agg = pa.array(
    pa.Table.from_arrays(
        [pc.list_flatten(array), pc.list_parent_indices(array)],
        names=["values", "offsets"],
    )
    .group_by("offsets")
    .aggregate([("values", "sum"), ("values", "mean")])
    .sort_by("offsets")
)
... # Here get the 2 columns

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe someone has ideas.

100% agree on the idea, but probably too complex in current narwhals IMO.

This issue isn't unique to list.* expressions, we'd have a similar problem for these:

Example

Efficient:

frame.select(
    nw.col("a", "b").sum().over("c"),
)

Same intention, but worse performance for pandas, pyarrow vs the first spelling:

frame.select(
    nw.col("a").sum().over("c"),
    nw.col("b").sum().over("c")
)

Why are these too hard?

I think what you're describing corresponds to common subplan/expression elimination (see datafusion optimizer rules).

The main challenges are:

  • knowing that list.sum becomes op_1, ..., op_n -> group_by -> sort is an implementation detail (+ specific to pyarrow)
    • not something that an Expr can "learn"
  • Expressions don't "know" anything about the context they're evaluated in
    • Are we in select?
    • What other expressions come before/after me?
    • Am I doing work that can be reused?

@FBruzzesi if this sounds similar to what I mentioned RE (#2572) and a LogicalPlan - you'd have the right hunch 😄

Basically, the problem would be much easier to generalize if we had:

frame.select(
    a_sum = nw.col("a").list.sum(),
    a_mean = nw.col("a").list.mean()
)

Outputs:

nw_plan = Select(
    inputs=[
        Alias(name="a_sum", expr=ListSum(Column(name="a"))),
        Alias(name="a_mean", expr=ListMean(Column(name="a"))),
    ]
)

But then we transform the pyarrow plan (expressions) into:

Warning - verbosity overload and invented data model

pyarrow_list_sum = Sort(
    inputs=[
        GroupBy(
            inputs=[
                HConcat(
                    inputs=[ListExplode(Column(name="a")), ListOffsets(Column(name="a"))],
                    names=["values", "offsets"],
                )
            ],
            keys=["offsets"],
            aggs=[Alias(name="a_sum", expr=Sum(Column(name="value")))],
        )
    ],
    by=["offsets"],
)


pyarrow_list_mean = Sort(
    inputs=[
        GroupBy(
            inputs=[
                HConcat(
                    inputs=[ListExplode(Column(name="a")), ListOffsets(Column(name="a"))],
                    names=["values", "offsets"],
                )
            ],
            keys=["offsets"],
            aggs=[Alias(name="a_mean", expr=Mean(Column(name="value")))],
        )
    ],
    by=["offsets"],
)

Then detecting what is shared between list.sum and list.mean is actually less complex than you might think.
And as a bonus, we'd have something that is generalized to expressions that rely on group_by 🙂

# Mask comparing each list's length against 0: True for non-empty lists, False
# for empty ones. NOTE(review): presumably null for null lists, which is why
# the mask is null-checked and null-filled below - confirm.
non_empty_mask = pa.array(pc.not_equal(pc.list_value_length(array), lit(0)))
if func == "sum":
# Make sure sum of empty list is 0.
# Positions whose mask entry is null start as None; all others start as 0.
base_array = pc.if_else(non_empty_mask.is_null(), None, 0)
else:
# Every position starts as a null of the aggregated column's type.
base_array = pa.repeat(lit_(None, type=agg.type), len(array))
# Write the aggregated values into the positions selected by the mask; null
# mask entries are treated as False, so those positions keep their base value.
return pa.chunked_array(
[
pc.replace_with_mask(
base_array,
non_empty_mask.fill_null(False), # type: ignore[arg-type]
agg,
)
]
)
5 changes: 5 additions & 0 deletions narwhals/_compliant/any_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def get(self, index: int) -> CompliantT_co: ...
def len(self) -> CompliantT_co: ...
def unique(self) -> CompliantT_co: ...
def contains(self, item: NonNestedLiteral) -> CompliantT_co: ...
# Per-list aggregation members; each is a body-less stub that returns the
# namespace's compliant type.
def min(self) -> CompliantT_co: ...
def max(self) -> CompliantT_co: ...
def mean(self) -> CompliantT_co: ...
def median(self) -> CompliantT_co: ...
def sum(self) -> CompliantT_co: ...


class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,21 @@ def contains(self, item: NonNestedLiteral) -> EagerExprT:
def get(self, index: int) -> EagerExprT:
    """Delegate ``list.get`` to the series-level ``list`` namespace."""
    compliant = self.compliant
    return compliant._reuse_series_namespace("list", "get", index=index)

def min(self) -> EagerExprT:
    """Delegate ``list.min`` to the series-level ``list`` namespace."""
    compliant = self.compliant
    return compliant._reuse_series_namespace("list", "min")

def max(self) -> EagerExprT:
    """Delegate ``list.max`` to the series-level ``list`` namespace."""
    compliant = self.compliant
    return compliant._reuse_series_namespace("list", "max")

def mean(self) -> EagerExprT:
    """Delegate ``list.mean`` to the series-level ``list`` namespace."""
    compliant = self.compliant
    return compliant._reuse_series_namespace("list", "mean")

def median(self) -> EagerExprT:
    """Delegate ``list.median`` to the series-level ``list`` namespace."""
    compliant = self.compliant
    return compliant._reuse_series_namespace("list", "median")

def sum(self) -> EagerExprT:
    """Delegate ``list.sum`` to the series-level ``list`` namespace."""
    compliant = self.compliant
    return compliant._reuse_series_namespace("list", "sum")


class CompliantExprNameNamespace( # type: ignore[misc]
_ExprNamespace[CompliantExprT_co],
Expand Down
26 changes: 25 additions & 1 deletion narwhals/_duckdb/expr_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import ListNamespace
from narwhals._duckdb.utils import F, lit, when
from narwhals._duckdb.utils import F, col, lambda_expr, lit, when
from narwhals._utils import requires

if TYPE_CHECKING:
Expand Down Expand Up @@ -40,3 +40,27 @@ def get(self, index: int) -> DuckDBExpr:
return self.compliant._with_elementwise(
lambda expr: F("list_extract", expr, lit(index + 1))
)

def min(self) -> DuckDBExpr:
    """Return the minimum of each list using the ``list_min`` function."""

    def _list_min(expr: Expression) -> Expression:
        return F("list_min", expr)

    return self.compliant._with_elementwise(_list_min)

def max(self) -> DuckDBExpr:
    """Return the maximum of each list using the ``list_max`` function."""

    def _list_max(expr: Expression) -> Expression:
        return F("list_max", expr)

    return self.compliant._with_elementwise(_list_max)

def mean(self) -> DuckDBExpr:
    """Return the mean of each list using the ``list_avg`` function."""

    def _list_avg(expr: Expression) -> Expression:
        return F("list_avg", expr)

    return self.compliant._with_elementwise(_list_avg)

def median(self) -> DuckDBExpr:
    """Return the median of each list using the ``list_median`` function."""

    def _list_median(expr: Expression) -> Expression:
        return F("list_median", expr)

    return self.compliant._with_elementwise(_list_median)

@requires.backend_version((1, 2))
def sum(self) -> DuckDBExpr:
    """Return the sum of each list, ignoring null elements.

    Null elements are filtered out with ``list_filter`` first. A list with
    no remaining elements yields ``0``; otherwise ``list_sum`` of the
    filtered list is returned.
    """

    def _sum_non_null(expr: Expression) -> Expression:
        element = col("_")
        non_null = F("list_filter", expr, lambda_expr(element, element.isnotnull()))
        total = F("list_sum", non_null)
        has_no_values = F("array_length", non_null) == lit(0)
        return when(has_no_values, lit(0)).otherwise(total)

    return self.compliant._with_callable(_sum_non_null)
26 changes: 26 additions & 0 deletions narwhals/_ibis/expr_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

from typing import TYPE_CHECKING

from ibis import cases, literal

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import ListNamespace
from narwhals._utils import not_implemented

if TYPE_CHECKING:
import ibis.expr.types as ir
Expand All @@ -27,3 +30,26 @@ def _get(expr: ir.ArrayColumn) -> ir.Column:
return expr[index]

return self.compliant._with_callable(_get)

def min(self) -> IbisExpr:
    """Return the minimum of each list via the array ``mins`` method."""

    def _mins(expr: ir.ArrayColumn) -> ir.Value:
        return expr.mins()

    return self.compliant._with_callable(_mins)

def max(self) -> IbisExpr:
    """Return the maximum of each list via the array ``maxs`` method."""

    def _maxs(expr: ir.ArrayColumn) -> ir.Value:
        return expr.maxs()

    return self.compliant._with_callable(_maxs)

def mean(self) -> IbisExpr:
    """Return the mean of each list via the array ``means`` method."""

    def _means(expr: ir.ArrayColumn) -> ir.Value:
        return expr.means()

    return self.compliant._with_callable(_means)

def sum(self) -> IbisExpr:
    """Return the sum of each list.

    The number of non-null elements decides the result: when that count is
    null the result is null, when it is ``0`` the result is ``0``, and
    otherwise the result is ``sums()`` of the original list.
    """

    def func(expr: ir.ArrayColumn) -> ir.Value:
        # Named `non_null_count` rather than `len` so the builtin is not shadowed.
        non_null_count = expr.filter(lambda x: x.notnull()).length()
        return cases(
            (non_null_count.isnull(), literal(None)),
            (non_null_count == literal(0), literal(0)),
            else_=expr.sums(),
        )

    return self.compliant._with_callable(func)

# `list.median` is bound to the `not_implemented()` marker for this backend.
median = not_implemented()
40 changes: 40 additions & 0 deletions narwhals/_pandas_like/series_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from narwhals._utils import not_implemented

if TYPE_CHECKING:
from typing import Literal

from narwhals._pandas_like.series import PandasLikeSeries


Expand Down Expand Up @@ -40,3 +42,41 @@ def get(self, index: int) -> PandasLikeSeries:
result = self.native.list[index]
result.name = self.native.name
return self.with_native(result)

def _agg(
    self, func: Literal["min", "max", "mean", "approximate_median", "sum"]
) -> PandasLikeSeries:
    """Aggregate each list with ``func`` through the pyarrow ``list_agg`` helper.

    Raises:
        NotImplementedError: If the series dtype is not pyarrow-backed.
    """
    native = self.native
    backend = get_dtype_backend(native.dtype, self.compliant._implementation)
    if backend != "pyarrow":  # pragma: no cover
        msg = "Only pyarrow backend is currently supported."
        raise NotImplementedError(msg)

    from narwhals._arrow.utils import list_agg, native_to_narwhals_dtype

    aggregated = list_agg(native.array._pa_array, func)
    # Map the arrow result type to a narwhals dtype, then to the native
    # pyarrow-backed dtype used to build the output series.
    out_dtype = narwhals_to_native_dtype(
        native_to_narwhals_dtype(aggregated.type, self.version),
        "pyarrow",
        self.implementation,
        self.version,
    )
    series_cls = type(native)
    # Index and name are carried over from the input series.
    return self.with_native(
        series_cls(aggregated, dtype=out_dtype, index=native.index, name=native.name)
    )

def min(self) -> PandasLikeSeries:
    """Return the minimum of each list via ``_agg("min")``."""
    return self._agg("min")

def max(self) -> PandasLikeSeries:
    """Return the maximum of each list via ``_agg("max")``."""
    return self._agg("max")

def mean(self) -> PandasLikeSeries:
    """Return the mean of each list via ``_agg("mean")``."""
    return self._agg("mean")

def median(self) -> PandasLikeSeries:
    """Return the median of each list via ``_agg("approximate_median")``."""
    return self._agg("approximate_median")

def sum(self) -> PandasLikeSeries:
    """Return the sum of each list via ``_agg("sum")``."""
    return self._agg("sum")
10 changes: 10 additions & 0 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,16 @@ def len(self) -> CompliantT: ...

unique: Method[CompliantT]

# Per-list aggregations, declared as `Method` attributes (like `unique` above)
# rather than defined with a body here.
# NOTE(review): presumably these are forwarded to the native polars list
# namespace - confirm against how `Method` attributes are resolved.
max: Method[CompliantT]

mean: Method[CompliantT]

median: Method[CompliantT]

min: Method[CompliantT]

sum: Method[CompliantT]


class PolarsStructNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
_accessor: ClassVar[Accessor] = "struct"
Expand Down
49 changes: 49 additions & 0 deletions narwhals/_spark_like/expr_list.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import operator
from typing import TYPE_CHECKING

from narwhals._compliant import LazyExprNamespace
Expand Down Expand Up @@ -33,3 +34,51 @@ def _get(expr: Column) -> Column:
return expr.getItem(index)

return self.compliant._with_elementwise(_get)

def min(self) -> SparkLikeExpr:
    """Return the minimum of each array using ``array_min``."""

    def _array_min(expr: Column) -> Column:
        return self.compliant._F.array_min(expr)

    return self.compliant._with_elementwise(_array_min)

def max(self) -> SparkLikeExpr:
    """Return the maximum of each array.

    The array is passed through ``array_compact`` before ``array_max``.
    """

    def _array_max(expr: Column) -> Column:
        functions = self.compliant._F
        compacted = functions.array_compact(expr)
        return functions.array_max(compacted)

    return self.compliant._with_elementwise(_array_max)

def sum(self) -> SparkLikeExpr:
    """Return the sum of each array.

    The array is passed through ``array_compact`` and then folded with
    addition, starting from the literal ``0.0``.
    """

    def _array_sum(expr: Column) -> Column:
        functions = self.compliant._F
        compacted = functions.array_compact(expr)
        start = functions.lit(0.0)
        return functions.aggregate(compacted, start, operator.add)

    return self.compliant._with_elementwise(_array_sum)

def mean(self) -> SparkLikeExpr:
    """Return the mean of each array.

    The compacted array is folded with addition starting from ``0.0``, and
    the total is divided by the compacted array's size with ``try_divide``.
    """

    def _array_mean(expr: Column) -> Column:
        functions = self.compliant._F
        compacted = functions.array_compact(expr)
        total = functions.aggregate(compacted, functions.lit(0.0), operator.add)
        count = functions.array_size(compacted)
        return functions.try_divide(total, count)

    return self.compliant._with_elementwise(_array_mean)

def median(self) -> SparkLikeExpr:
    """Return the median of each array.

    The array is sorted and compacted. The result is null when the size is
    null or ``0``, the middle element when the size is odd, and the average
    of the two middle elements otherwise.
    """

    def _array_median(expr: Column) -> Column:  # pragma: no cover
        # sqlframe issue: https://github.com/eakmanrq/sqlframe/issues/548
        functions = self.compliant._F
        values = functions.array_compact(functions.sort_array(expr))
        count = functions.array_size(values)
        upper = (count / 2).cast("int")
        middle_value = values[upper]
        middle_pair_mean = (values[upper - 1] + values[upper]) / 2
        is_missing = (count.isNull()) | (count == 0)
        is_odd = count % 2 == 1
        return (
            functions.when(is_missing, functions.lit(None))
            .when(is_odd, middle_value)
            .otherwise(middle_pair_mean)
        )

    return self.compliant._with_elementwise(_array_median)
Loading
Loading