diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index 82102e189d..97f5fb9842 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -62,7 +62,15 @@ Target: TypeAlias = OneOrSeq[Field] Aggregation: TypeAlias = Union[ - "_Aggregation", Literal["hash_kurtosis", "hash_skew", "kurtosis", "skew"] + "_Aggregation", + Literal[ + "hash_kurtosis", + "hash_skew", + "hash_pivot_wider", + "kurtosis", + "skew", + "pivot_wider", + ], ] AggregateOptions: TypeAlias = "_AggregateOptions" Opts: TypeAlias = "AggregateOptions | None" diff --git a/narwhals/_plan/arrow/compat.py b/narwhals/_plan/arrow/compat.py index 5aef07b8e8..7bd2607f3d 100644 --- a/narwhals/_plan/arrow/compat.py +++ b/narwhals/_plan/arrow/compat.py @@ -20,7 +20,7 @@ TAKE_ACCEPTS_TUPLE: Final = BACKEND_VERSION >= (18,) HAS_STRUCT_TYPE_FIELDS: Final = BACKEND_VERSION >= (18,) -"""`pyarrow.StructType.fields` added in https://github.com/apache/arrow/pull/43481""" +"""`pyarrow.StructType.{fields,names}` added in https://github.com/apache/arrow/pull/43481""" HAS_SCATTER: Final = BACKEND_VERSION >= (20,) """`pyarrow.compute.scatter` added in https://github.com/apache/arrow/pull/44394""" diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index c1f1bc2bb9..31e220a2a5 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -13,6 +13,7 @@ from narwhals._plan.arrow.common import ArrowFrameSeries as FrameSeries from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy, partition_by +from narwhals._plan.arrow.pivot import pivot_table from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.common import temp from narwhals._plan.compliant.dataframe import EagerDataFrame @@ -36,7 +37,7 @@ from narwhals._plan.typing import NonCrossJoinStrategy from narwhals._typing import 
_LazyAllowedImpl from narwhals.dtypes import DType - from narwhals.typing import IntoSchema, UniqueKeepStrategy + from narwhals.typing import IntoSchema, PivotAgg, UniqueKeepStrategy Incomplete: TypeAlias = Any @@ -325,6 +326,27 @@ def partition_by(self, by: Sequence[str], *, include_key: bool = True) -> list[S partitions = partition_by(self.native, by, include_key=include_key) return [from_native(df) for df in partitions] + def pivot( + self, + on: Sequence[str], + on_columns: Self, + *, + index: Sequence[str], + values: Sequence[str], + aggregate_function: PivotAgg | None = None, + separator: str = "_", + ) -> Self: + result = pivot_table( + self.native, + list(on), + on_columns.native, + index, + values, + aggregate_function, + separator, + ) + return self._with_native(result) + def with_array(table: pa.Table, name: str, column: ChunkedOrArrayAny) -> pa.Table: column_names = table.column_names diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 0dd438b98b..627ecd1e8b 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -334,6 +334,12 @@ def struct_schema(native: Arrow[pa.StructScalar] | pa.StructType) -> pa.Schema: return pa.schema(fields) +def struct_field_names(native: Arrow[pa.StructScalar] | pa.StructType) -> list[str]: + """Get the names of all struct fields.""" + tp = native.type if _is_arrow(native) else native + return tp.names if compat.HAS_STRUCT_TYPE_FIELDS else [f.name for f in tp] + + @t.overload def struct_field(native: ChunkedStruct, field: Field, /) -> ChunkedArrayAny: ... 
@t.overload @@ -1574,6 +1580,7 @@ def concat_str( def concat_str( *arrays: ArrowAny, separator: str = "", ignore_nulls: bool = False ) -> Arrow[StringScalar]: + """Horizontally concatenate arrow data into a single string column.""" dtype = string_type(obj.type for obj in arrays) it = (obj.cast(dtype) for obj in arrays) concat: Incomplete = pc.binary_join_element_wise diff --git a/narwhals/_plan/arrow/pivot.py b/narwhals/_plan/arrow/pivot.py new file mode 100644 index 0000000000..2f169f762d --- /dev/null +++ b/narwhals/_plan/arrow/pivot.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import re +from itertools import chain +from typing import TYPE_CHECKING, Any, cast + +import pyarrow.compute as pc + +from narwhals._plan.arrow import ( + acero, + compat, + functions as fn, + group_by, + options as pa_options, +) +from narwhals._plan.arrow.group_by import AggSpec +from narwhals._plan.common import temp +from narwhals._plan.expressions import aggregation as agg + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping, Sequence + + import pyarrow as pa + + from narwhals._plan.arrow.typing import ChunkedArray, StringScalar + from narwhals.typing import PivotAgg + + +SUPPORTED_PIVOT_AGG: Mapping[PivotAgg, type[agg.AggExpr]] = { + "min": agg.Min, + "max": agg.Max, + "first": agg.First, + "last": agg.Last, + "sum": agg.Sum, + "mean": agg.Mean, + "median": agg.Median, + "len": agg.Len, +} + + +def pivot_table( + native: pa.Table, + on: list[str], + on_columns: pa.Table, + /, + index: Sequence[str], + values: Sequence[str], + aggregate_function: PivotAgg | None, + separator: str, +) -> pa.Table: + """Create a spreadsheet-style `pivot` table. + + Supports multiple-`on` and aggregations. 
+ """ + if len(on) == 1: + on_column = on_columns.column(0) + on_one = on[0] + target = native + else: + on_column = _format_on_columns_titles(on_columns) + on_one = temp.column_name(native.column_names) + target = acero.join_inner_tables( + native, on_columns.append_column(on_one, on_column), on + ).drop(on) + if aggregate_function: + target = _aggregate(target, on_one, index, values, aggregate_function) + return _pivot(target, on_one, on_column.to_pylist(), index, values, separator) + + +def _format_on_columns_titles(on_columns: pa.Table, /) -> ChunkedArray[StringScalar]: + dtype = fn.string_type(on_columns.schema.types) + on_columns = fn.cast_table(on_columns, dtype) + parts = '{"', '"}', "", '","' + LB, RB, EMPTY, SEP = (fn.lit(s, dtype) for s in parts) # noqa: N806 + + # NOTE: Variation of https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.intersperse + seps = (SEP,) * on_columns.num_columns + interspersed = chain.from_iterable(zip(seps, on_columns.itercolumns())) + # skip the first separator, we just need the zip-terminating iterable to be the columns + next(interspersed) + func = "binary_join_element_wise" + args = [LB, *interspersed, RB, EMPTY] + opts = pa_options.join(ignore_nulls=False) + result: ChunkedArray[StringScalar] = pc.call_function(func, args, opts) + return result + + +def _replace_flatten_names( + column_names: list[str], + /, + on_columns_names: Sequence[str], + values: Sequence[str], + separator: str, +) -> list[str]: + """Replace the separator used in unnested struct columns. + + [`pa.Table.flatten`] *unconditionally* uses the separator `"."`, so we *likely* need to fix that here. 
+ + [`pa.Table.flatten`]: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.flatten + """ + if separator == ".": + return column_names + p_on_columns = "|".join(re.escape(name) for name in on_columns_names) + p_values = "|".join(re.escape(name) for name in values) + pattern = re.compile(rf"^(?P<a>{p_on_columns})\.(?P<b>{p_values})\Z") + repl = rf"\g<a>{separator}\g<b>" + return [pattern.sub(repl, s) for s in column_names] + + +def _pivot( + native: pa.Table, + on: str, + on_columns: Sequence[Any], + /, + index: Sequence[str], + values: Sequence[str], + separator: str, +) -> pa.Table: + """Perform a single-`on`, non-aggregating `pivot`.""" + options = _pivot_wider_options(on_columns) + specs = (AggSpec((on, name), "hash_pivot_wider", options, name) for name in values) + pivot = acero.group_by_table(native, index, specs) + flat = pivot.flatten() + if len(values) == 1: + names = [*index, *fn.struct_field_names(pivot.column(values[0]))] + else: + names = _replace_flatten_names(flat.column_names, values, on_columns, separator) + return flat.rename_columns(names) + + +def _aggregate( + native: pa.Table, + on: str, + /, + index: Sequence[str], + values: Sequence[str], + aggregate_function: PivotAgg, +) -> pa.Table: + tp_agg = SUPPORTED_PIVOT_AGG[aggregate_function] + agg_func = group_by.SUPPORTED_AGG[tp_agg] + option = pa_options.AGG.get(tp_agg) + specs = (AggSpec(value, agg_func, option) for value in values) + return acero.group_by_table(native, [*index, on], specs) + + +def _pivot_wider_options(on_columns: Sequence[Any]) -> pc.FunctionOptions: + """Tries to wrap [`pc.PivotWiderOptions`], and raises if we're on an old `pyarrow`. 
+ + [`pc.PivotWiderOptions`]: https://arrow.apache.org/docs/python/generated/pyarrow.compute.PivotWiderOptions.html + """ + if compat.HAS_PIVOT_WIDER and (tp := getattr(pc, "PivotWiderOptions")): # noqa: B009 + tp_options = cast("Callable[..., pc.FunctionOptions]", tp) + return tp_options(on_columns, unexpected_key_behavior="raise") + msg = f"`pivot` requires `pyarrow>=20`, got {compat.BACKEND_VERSION!r}" + raise NotImplementedError(msg) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index f709b21b52..0dba40f48e 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -109,14 +109,14 @@ def flatten_hash_safe(iterable: Iterable[OneOrIterable[Any]], /) -> Iterator[Any yield element -def _not_one_or_iterable_str_error(obj: Any, /) -> TypeError: # pragma: no cover +def _not_one_or_iterable_str_error(obj: Any, /) -> TypeError: msg = f"Expected one or an iterable of strings, but got: {qualified_type_name(obj)!r}\n{obj!r}" return TypeError(msg) def ensure_seq_str(obj: OneOrIterable[str], /) -> Seq[str]: if not isinstance(obj, Iterable): - raise _not_one_or_iterable_str_error(obj) # pragma: no cover + raise _not_one_or_iterable_str_error(obj) return (obj,) if isinstance(obj, str) else tuple(obj) diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index 6400e6bd5c..7a935f0c8e 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -45,7 +45,7 @@ from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl from narwhals._utils import Implementation, Version from narwhals.dtypes import DType - from narwhals.typing import IntoSchema, UniqueKeepStrategy + from narwhals.typing import IntoSchema, PivotAgg, UniqueKeepStrategy Incomplete: TypeAlias = Any @@ -152,6 +152,10 @@ def __narwhals_dataframe__(self) -> Self: def lazy(self, backend: _LazyAllowedImpl | None, **kwds: Any) -> LazyFrameAny: ... @property def shape(self) -> tuple[int, int]: ... 
+ @property + def width(self) -> int: + return self.shape[-1] + def __len__(self) -> int: ... @property def _group_by(self) -> type[DataFrameGroupBy[Self]]: ... @@ -222,6 +226,16 @@ def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: ... def partition_by( self, by: Sequence[str], *, include_key: bool = True ) -> list[Self]: ... + def pivot( + self, + on: Sequence[str], + on_columns: Self, + *, + index: Sequence[str], + values: Sequence[str], + aggregate_function: PivotAgg | None = None, + separator: str = "_", + ) -> Self: ... def row(self, index: int) -> tuple[Any, ...]: ... @overload def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 78675b4e3a..134e46ec1e 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Sequence from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, get_args, overload from narwhals._plan import _parse @@ -26,7 +27,7 @@ PartialSeries, Seq, ) -from narwhals._utils import Implementation, Version, generate_repr +from narwhals._utils import Implementation, Version, generate_repr, qualified_type_name from narwhals.dependencies import is_pyarrow_table from narwhals.exceptions import InvalidOperationError, ShapeError from narwhals.schema import Schema @@ -37,11 +38,12 @@ IntoDType, IntoSchema, JoinStrategy, + PivotAgg, UniqueKeepStrategy, ) if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Mapping, Sequence + from collections.abc import Iterable, Iterator, Mapping from io import BytesIO import polars as pl @@ -261,6 +263,22 @@ def fn(values: Iterable[Any], /) -> Series[NativeSeriesT]: return fn + def _parse_into_compliant_series( + self, other: Series[Any] | Iterable[Any], /, name: str = "" + ) -> CompliantSeries[NativeSeriesT]: + if columns := self.columns: + compliant = 
self.get_column(columns[0])._parse_into_compliant(other) + return compliant if not name or compliant.name else compliant.alias(name) + else: # pragma: no cover # noqa: RET505 + tp_series = self.__narwhals_namespace__()._series + if not is_series(other): + return tp_series.from_iterable(other, version=self.version, name=name) + s = other._compliant + if isinstance(s, tp_series): + return s + msg = f"Expected {qualified_type_name(tp_series)!r}, got {qualified_type_name(s)!r}" + raise NotImplementedError(msg) + @overload @classmethod def from_native( @@ -452,6 +470,80 @@ def partition_by( partitions = self._compliant.partition_by(names, include_key=include_key) return [self._with_compliant(p) for p in partitions] + # TODO @dangotbanned: (Follow-up) Accept selectors in `on`, `index`, `values` + def pivot( + self, + on: OneOrIterable[str], + on_columns: Sequence[str] | Series | Self | None = None, + *, + index: OneOrIterable[str] | None = None, + values: OneOrIterable[str] | None = None, + aggregate_function: PivotAgg | None = None, + sort_columns: bool = False, + separator: str = "_", + ) -> Self: + from narwhals._plan import functions as F + + on_, index_, values_ = normalize_pivot_args( + on, index=index, values=values, frame_columns=self.columns + ) + dtype_str = self.version.dtypes.String() + on_cols: EagerDataFrame[IncompleteCyclic, NativeDataFrameT_co, NativeSeriesT] + + if on_columns is None: + nw_on_cols = self.select(F.col(name).cast(dtype_str) for name in on_).unique( + on_, maintain_order=True + ) + if sort_columns: + nw_on_cols = nw_on_cols.sort(on_) + on_cols = nw_on_cols._compliant + elif isinstance(on_columns, DataFrame): + on_cols = on_columns._compliant + else: + on_cols = ( + self._parse_into_compliant_series(on_columns, on_[0]) + .cast(dtype_str) + .to_frame() + ) + + if len(on_) != on_cols.width: + msg = "`pivot` expected `on` and `on_columns` to have the same amount of columns." 
+ raise InvalidOperationError(msg) + if on_ != tuple(on_cols.columns): + msg = "`pivot` has mismatching column names between `on` and `on_columns`." + raise InvalidOperationError(msg) + + return self._with_compliant( + self._compliant.pivot( + on_, + on_cols, + index=index_, + values=values_, + aggregate_function=aggregate_function, + separator=separator, + ) + ) + + def sort( + self, + by: OneOrIterable[ColumnNameOrSelector], + *more_by: ColumnNameOrSelector, + descending: OneOrIterable[bool] = False, + nulls_last: OneOrIterable[bool] = False, + ) -> Self: + if ( + not more_by + and _is_sort_by_one(by, self.columns) + and isinstance(descending, bool) + and isinstance(nulls_last, bool) + ): + return self._with_compliant( + self.to_series() + ._compliant.sort(descending=descending, nulls_last=nulls_last) + .to_frame() + ) + return super().sort(by, *more_by, descending=descending, nulls_last=nulls_last) + def unique( self, subset: OneOrIterable[ColumnNameOrSelector] | None = None, @@ -469,11 +561,18 @@ def unique( s_irs, schema=schema, require_any=True ) if order_by is None: - return self._with_compliant( - self._compliant.unique( + if len(schema) == 1 and keep in {"any", "first"}: + # NOTE: Fastpath for single-column frame + result = ( + self.to_series() + ._compliant.unique(maintain_order=maintain_order) + .to_frame() + ) + else: + result = self._compliant.unique( subset_names, keep=keep, maintain_order=maintain_order ) - ) + return self._with_compliant(result) s_irs = _parse.parse_into_seq_of_selector_ir(order_by) by_names = expand_selector_irs_names(s_irs, schema=schema, require_any=True) return self._with_compliant( @@ -531,6 +630,18 @@ def sample( return type(self)(result) +def _is_sort_by_one( + by: OneOrIterable[ColumnNameOrSelector], frame_columns: list[str] +) -> bool: + """Return True if requested to sort a single-column DataFrame - without consuming iterators.""" + columns = frame_columns + if len(columns) != 1: + return False + return (isinstance(by, 
str) and by in columns) or ( + isinstance(by, Sequence) and len(by) == 1 and by[0] in columns + ) + + def _is_join_strategy(obj: Any) -> TypeIs[JoinStrategy]: return obj in {"inner", "left", "full", "cross", "anti", "semi"} @@ -579,3 +690,40 @@ def normalize_join_on( raise ValueError(msg) on = ensure_seq_str(on) return on, on + + +def normalize_pivot_args( + on: OneOrIterable[str], + *, + index: OneOrIterable[str] | None, + values: OneOrIterable[str] | None, + frame_columns: list[str], +) -> tuple[Seq[str], Seq[str], Seq[str]]: + """Derive a pivot specification from optional arguments. + + Returns in the order: + + (on, index, values) + """ + columns = frame_columns + on_ = ensure_seq_str(on) + if not on_: + msg = "`pivot` called without `on` columns." + raise InvalidOperationError(msg) + if index is None: + if values is None: + msg = "At least one of `values` and `index` must be passed" + raise ValueError(msg) + values_ = ensure_seq_str(values) + index_ = tuple( + nm for nm in columns if nm in set(columns).difference(on_, values_) + ) + elif values is None: + index_ = ensure_seq_str(index) + values_ = tuple( + nm for nm in columns if nm in set(columns).difference(on_, index_) + ) + else: + index_ = ensure_seq_str(index) + values_ = ensure_seq_str(values) + return on_, index_, values_ diff --git a/tests/plan/pivot_test.py b/tests/plan/pivot_test.py new file mode 100644 index 0000000000..8a52dccaaa --- /dev/null +++ b/tests/plan/pivot_test.py @@ -0,0 +1,515 @@ +from __future__ import annotations + +from collections import deque + +# ruff: noqa: FBT001 +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals._plan as nwp +import narwhals._plan.selectors as ncs +from narwhals._utils import Implementation +from narwhals.exceptions import InvalidOperationError, NarwhalsError +from tests.plan.utils import assert_equal_data, dataframe, re_compile +from tests.utils import PYARROW_VERSION + +if TYPE_CHECKING: + from narwhals._plan.typing import 
OneOrIterable + from narwhals.typing import PivotAgg + from tests.conftest import Data + +# TODO @dangotbanned: Add more values for `median`, like in +# https://github.com/narwhals-dev/narwhals/blob/b3d8c7349bbf7ecb7f11ea590c334e12d5c1d43e/tests/plan/list_agg_test.py#L24-L34 +XFAIL_PYARROW_MEDIAN = pytest.mark.xfail( + reason="Tried to use `'approximate_median'` but groups are too small", + raises=(AssertionError, NotImplementedError), +) + +# TODO @dangotbanned: Consider fixing this? +# The `pandas` impl on `main` has the same issue +XFAIL_ALWAYS_ZERO_AGG = pytest.mark.xfail( + reason="`sum` & `len` are special-cased in `polars` to always return 0 instead on `None`", + raises=(AssertionError, NotImplementedError), +) + + +@pytest.fixture(scope="module") +def scores() -> Data: + """Dataset 1 `pl.DataFrame.pivot` docstring.""" + return { + "name": ["Cady", "Cady", "Karen", "Karen"], + "subject": ["maths", "physics", "maths", "physics"], + "test_1": [98, 99, 61, 58], + "test_2": [100, 100, 60, 60], + } + + +@pytest.fixture(scope="module") +def data() -> Data: + return { + "idx_1": [1, 2, 1, 1, 2, 2], + "idx_2": [1, 2, 2, 1, 2, 1], + "on_lower": ["b", "b", "a", "a", "a", "a"], + "on_upper": ["X", "Y", "Y", "Y", "X", "Y"], + "foo": [7, 1, 0, 1, 2, 2], + "bar": [9, 4, 0, 2, 0, 0], + } + + +@pytest.fixture(scope="module") +def data_no_dups() -> Data: + return { + "idx_1": [1, 1, 2, 2], + "on_lower": ["a", "b", "a", "b"], + "foo": [1, 2, 3, 4], + "bar": ["x", "y", "z", "w"], + } + + +@pytest.fixture(scope="module") +def data_no_dups_unordered(data_no_dups: Data) -> Data: + """Variant of `data_no_dups` to support tests without needing `aggregate_function`. 
+ + - `"on_lower"` has an order to test `sort_columns=True` + - `"on_upper"` is added for `on: list[str]` name generation + """ + return data_no_dups | { + "on_lower": ["b", "a", "b", "a"], + "on_upper": ["X", "X", "Y", "Y"], + } + + +def assert_names_match_polars( + input_data: Data, + on: list[str], + index: str | list[str], + values: list[str], + aggregate_function: PivotAgg | None = None, + *, + result_columns: list[str], +) -> None: + """Ensure the complex renaming cases match upstream.""" + pytest.importorskip("polars") + import polars as pl + + df = pl.DataFrame(input_data) + + pl_result = df.pivot( + on=on, values=values, index=index, aggregate_function=aggregate_function + ) + assert result_columns == pl_result.columns + + +def require_pyarrow_20( + df: nwp.DataFrame[Any, Any], request: pytest.FixtureRequest +) -> None: + request.applymarker( + pytest.mark.xfail( + ( + df.implementation is Implementation.PYARROW + and PYARROW_VERSION < (20, 0, 0) + ), + reason="pyarrow too old for `pivot` support", + raises=NotImplementedError, + ) + ) + + +@pytest.mark.parametrize( + ("agg_func", "expected"), + [ + ( + "min", + { + "idx_1": [1, 2], + "foo_a": [0, 2], + "foo_b": [7, 1], + "bar_a": [0, 0], + "bar_b": [9, 4], + }, + ), + ( + "max", + { + "idx_1": [1, 2], + "foo_a": [1, 2], + "foo_b": [7, 1], + "bar_a": [2, 0], + "bar_b": [9, 4], + }, + ), + ( + "first", + { + "idx_1": [1, 2], + "foo_a": [0, 2], + "foo_b": [7, 1], + "bar_a": [0, 0], + "bar_b": [9, 4], + }, + ), + ( + "last", + { + "idx_1": [1, 2], + "foo_a": [1, 2], + "foo_b": [7, 1], + "bar_a": [2, 0], + "bar_b": [9, 4], + }, + ), + ( + "sum", + { + "idx_1": [1, 2], + "foo_a": [1, 4], + "foo_b": [7, 1], + "bar_a": [2, 0], + "bar_b": [9, 4], + }, + ), + ( + "mean", + { + "idx_1": [1, 2], + "foo_a": [0.5, 2.0], + "foo_b": [7.0, 1.0], + "bar_a": [1.0, 0.0], + "bar_b": [9.0, 4.0], + }, + ), + pytest.param( + "median", + { + "idx_1": [1, 2], + "foo_a": [0.5, 2.0], + "foo_b": [7.0, 1.0], + "bar_a": [1.0, 0.0], + 
"bar_b": [9.0, 4.0], + }, + marks=XFAIL_PYARROW_MEDIAN, + ), + ( + "len", + { + "idx_1": [1, 2], + "foo_a": [2, 2], + "foo_b": [1, 1], + "bar_a": [2, 2], + "bar_b": [1, 1], + }, + ), + ], +) +@pytest.mark.parametrize( + ("on", "index"), + [("on_lower", "idx_1"), (deque(["on_lower"]), dict.fromkeys(["idx_1"]).keys())], +) +def test_pivot_agg( + data: Data, + on: OneOrIterable[str], + index: OneOrIterable[str], + agg_func: PivotAgg, + expected: Data, + request: pytest.FixtureRequest, +) -> None: + df = dataframe(data) + require_pyarrow_20(df, request) + result = df.pivot( + on, + index=index, + values=["foo", "bar"], + aggregate_function=agg_func, + sort_columns=True, + ) + + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("sort_columns", "expected"), + [ + (True, ["idx_1", "foo_a", "foo_b", "bar_a", "bar_b"]), + (False, ["idx_1", "foo_b", "foo_a", "bar_b", "bar_a"]), + ], +) +def test_pivot_sort_columns( + data_no_dups_unordered: Data, + sort_columns: bool, + expected: list[str], + request: pytest.FixtureRequest, +) -> None: + df = dataframe(data_no_dups_unordered) + require_pyarrow_20(df, request) + values = ["foo", "bar"] + result = df.pivot("on_lower", index="idx_1", values=values, sort_columns=sort_columns) + assert result.columns == expected + + +@pytest.mark.parametrize( + ("on", "values", "expected"), + [ + ( + ["on_lower", "on_upper"], + ["foo"], + ["idx_1", '{"b","X"}', '{"a","X"}', '{"b","Y"}', '{"a","Y"}'], + ), + ( + ["on_lower", "on_upper"], + ["foo", "bar"], + [ + "idx_1", + 'foo_{"b","X"}', + 'foo_{"a","X"}', + 'foo_{"b","Y"}', + 'foo_{"a","Y"}', + 'bar_{"b","X"}', + 'bar_{"a","X"}', + 'bar_{"b","Y"}', + 'bar_{"a","Y"}', + ], + ), + ], + ids=["single-values", "multiple-values"], +) +def test_pivot_on_multiple_names( + data_no_dups_unordered: Data, + on: list[str], + values: list[str], + expected: list[str], + request: pytest.FixtureRequest, +) -> None: + index = "idx_1" + data_ = data_no_dups_unordered + df = dataframe(data_) + 
require_pyarrow_20(df, request) + result = df.pivot(on, values=values, index=index) + assert result.columns == expected + assert_names_match_polars(data_, on, index, values, result_columns=result.columns) + + +@pytest.mark.parametrize( + ("on", "values", "expected"), + [ + ( + ["on_lower", "on_upper"], + ["foo"], + ["idx_1", '{"b","X"}', '{"b","Y"}', '{"a","Y"}', '{"a","X"}'], + ), + ( + ["on_lower", "on_upper"], + ["foo", "bar"], + [ + "idx_1", + 'foo_{"b","X"}', + 'foo_{"b","Y"}', + 'foo_{"a","Y"}', + 'foo_{"a","X"}', + 'bar_{"b","X"}', + 'bar_{"b","Y"}', + 'bar_{"a","Y"}', + 'bar_{"a","X"}', + ], + ), + ], + ids=["single-values", "multiple-values"], +) +def test_pivot_on_multiple_names_agg( + data: Data, + on: list[str], + values: list[str], + expected: list[str], + request: pytest.FixtureRequest, +) -> None: + index = "idx_1" + df = dataframe(data) + require_pyarrow_20(df, request) + result = df.pivot(on, values=values, aggregate_function="min", index=index) + assert result.columns == expected + assert_names_match_polars( + data, on, index, values, "min", result_columns=result.columns + ) + + +def test_pivot_no_agg_duplicated(data: Data, request: pytest.FixtureRequest) -> None: + df = dataframe(data) + require_pyarrow_20(df, request) + with pytest.raises((ValueError, NarwhalsError)): + df.pivot("on_lower", index="idx_1") + + +def test_pivot_no_agg_no_duplicates( + data_no_dups: Data, request: pytest.FixtureRequest +) -> None: + df = dataframe(data_no_dups) + require_pyarrow_20(df, request) + result = df.pivot("on_lower", index="idx_1") + expected = { + "idx_1": [1, 2], + "foo_a": [1, 3], + "foo_b": [2, 4], + "bar_a": ["x", "z"], + "bar_b": ["y", "w"], + } + assert_equal_data(result, expected) + + +@pytest.mark.xfail( + reason=( + "BUG: Incorrect results, `pyarrow` not consistent with `polars` and `pandas`.\n" + "https://github.com/apache/arrow/issues/48679" + ), + raises=(AssertionError, NotImplementedError), +) +def test_pivot_no_values() -> None: + # 
https://github.com/pola-rs/polars/blob/473951bcf8c49fc23bee5ee7b8853b5dd063cb9d/py-polars/tests/unit/operations/test_pivot.py#L39-L65 + data = { + "foo": ["A", "A", "B", "B", "C"], + "bar": ["k", "l", "m", "n", "o"], + "N1": [1, 2, 2, 4, 2], + "N2": [1, 2, 2, 4, 2], + } + df = dataframe(data) + result = df.pivot(on="bar", index="foo") + expected = { + "foo": ["A", "B", "C"], + "N1_k": [1, None, None], + "N1_l": [2, None, None], + "N1_m": [None, 2, None], + "N1_n": [None, 4, None], # < these 2 are flipped for pyarrow? + "N1_o": [None, None, 2], # < + "N2_k": [1, None, None], + "N2_l": [2, None, None], + "N2_m": [None, 2, None], + "N2_n": [None, 4, None], # < and down here as well + "N2_o": [None, None, 2], # < + } + assert_equal_data(result, expected) + + +def test_pivot_no_index_no_values(data_no_dups: Data) -> None: + df = dataframe(data_no_dups) + with pytest.raises( + ValueError, match=re_compile(r"at least one of.+values.+index.+must") + ): + df.pivot("on_lower") + + +def test_pivot_on_invalid(data: Data) -> None: + df = dataframe(data) + with pytest.raises(InvalidOperationError, match=r"`pivot` called without `on`"): + df.pivot([], index="idx_1", values="foo") + + +def test_pivot_on_columns_invalid(data: Data) -> None: + df = dataframe(data) + on_1 = "on_lower" + on_2 = ["on_lower", "on_upper"] + index = "idx_1" + + df_1 = df.select(on_1) + df_2 = df.select(on_2) + df_2_mismatch = df_2.rename(dict(zip(on_2, reversed(on_2)))) + ser = df.get_column(on_1) + + same_n_cols = r"expected `on`.+`on_columns`.+same amount of columns" + mismatch_names = r"mismatching column names between `on` and `on_columns`" + + with pytest.raises(InvalidOperationError, match=same_n_cols): + df.pivot(on_1, df_2, index=index) + with pytest.raises(InvalidOperationError, match=same_n_cols): + df.pivot(on_2, df_1, index=index) + with pytest.raises(InvalidOperationError, match=same_n_cols): + df.pivot(on_2, ser, index=index) + with pytest.raises(InvalidOperationError, 
match=mismatch_names): + df.pivot(on_2, df_2_mismatch, index=index) + with pytest.raises(InvalidOperationError, match=mismatch_names): + df.pivot([on_1], ser.alias("bad"), index=index) + + +def test_pivot_non_iterable_invalid() -> None: + small = {"a": [1], "b": [2], "c": [3]} + df = dataframe(small) + match = re_compile(r"expected one or.+iterable.+string.+got.+int") + with pytest.raises(TypeError, match=match): + df.pivot(1, index="b", values="c") # type: ignore[arg-type] + with pytest.raises(TypeError, match=match): + df.pivot("a", index=2, values="c") # type: ignore[arg-type] + with pytest.raises(TypeError, match=match): + df.pivot("a", index="b", values=3) # type: ignore[arg-type] + with pytest.raises(TypeError, match=match): + df.pivot(1, index=2, values="c") # type: ignore[arg-type] + with pytest.raises(TypeError, match=match): + df.pivot("a", index=2, values=3) # type: ignore[arg-type] + with pytest.raises(TypeError, match=match): + df.pivot(1, index="b", values=3) # type: ignore[arg-type] + + +def test_pivot_implicit_index(data_no_dups: Data, request: pytest.FixtureRequest) -> None: + df = dataframe(data_no_dups) + require_pyarrow_20(df, request) + expected = { + "idx_1": [1, 1, 2, 2], + "bar": ["x", "y", "w", "z"], + "a": [1.0, None, None, 3.0], + "b": [None, 2.0, 4.0, None], + } + result = df.pivot("on_lower", values="foo").sort(ncs.by_index(0, 1)) + assert_equal_data(result, expected) + + +def test_pivot_test_scores_1(scores: Data, request: pytest.FixtureRequest) -> None: + df = dataframe(scores) + require_pyarrow_20(df, request) + expected = {"name": ["Cady", "Karen"], "maths": [98, 61], "physics": [99, 58]} + result = df.pivot("subject", index="name", values="test_1") + assert_equal_data(result, expected) + result = df.pivot( + "subject", on_columns=["maths", "physics"], index="name", values="test_1" + ) + assert_equal_data(result, expected) + + +def test_pivot_test_scores_2(scores: Data, request: pytest.FixtureRequest) -> None: + df = 
dataframe(scores) + require_pyarrow_20(df, request) + expected = { + "name": ["Cady", "Karen"], + "test_1_maths": [98, 61], + "test_1_physics": [99, 58], + "test_2_maths": [100, 60], + "test_2_physics": [100, 60], + } + result = df.pivot("subject", values=["test_1", "test_2"]) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("agg_fn", "expected_rows"), + [ + pytest.param( + "sum", [("a", 6, 0, 0), ("b", 0, 8, 10)], marks=XFAIL_ALWAYS_ZERO_AGG + ), + pytest.param( + "len", [("a", 2, 0, 0), ("b", 0, 2, 1)], marks=XFAIL_ALWAYS_ZERO_AGG + ), + ("first", [("a", 2, None, None), ("b", None, None, 10)]), + ("min", [("a", 2, None, None), ("b", None, 8, 10)]), + ("max", [("a", 4, None, None), ("b", None, 8, 10)]), + ("mean", [("a", 3.0, None, None), ("b", None, 8.0, 10.0)]), + ], +) +def test_pivot_aggregate( + agg_fn: PivotAgg, expected_rows: list[tuple[Any, ...]], request: pytest.FixtureRequest +) -> None: + # https://github.com/pola-rs/polars/blob/473951bcf8c49fc23bee5ee7b8853b5dd063cb9d/py-polars/tests/unit/operations/test_pivot.py#L89-L112 + df = dataframe( + {"a": [1, 1, 2, 2, 3], "b": ["a", "a", "b", "b", "b"], "c": [2, 4, None, 8, 10]} + ) + require_pyarrow_20(df, request) + result = df.pivot( + "a", index="b", values="c", aggregate_function=agg_fn, sort_columns=True + ) + result_rows = [*zip(*result.to_dict(as_series=False).values())] + assert result_rows == expected_rows