diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index d1c942191c..9e5b6bc2c6 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -25,6 +25,7 @@ import pyarrow.compute as pc # ignore-banned-import from pyarrow.acero import Declaration as Decl +from narwhals._plan.arrow.guards import is_expression from narwhals._plan.common import ensure_list_str, temp from narwhals._plan.typing import NonCrossJoinStrategy, OneOrSeq from narwhals._utils import check_column_names_are_unique @@ -40,7 +41,7 @@ Sequence, ) - from typing_extensions import TypeAlias, TypeIs + from typing_extensions import TypeAlias from narwhals._arrow.typing import ( # type: ignore[attr-defined] AggregateOptions as _AggregateOptions, @@ -117,12 +118,8 @@ def cols_iter(names: Iterable[str], /) -> Iterator[Expr]: yield col(name) -def _is_expr(obj: Any) -> TypeIs[pc.Expression]: - return isinstance(obj, pc.Expression) - - def _parse_into_expr(into: IntoExpr, /, *, str_as_lit: bool = False) -> Expr: - if _is_expr(into): + if is_expression(into): return into if isinstance(into, str) and not str_as_lit: return col(into) @@ -223,7 +220,7 @@ def project(**named_exprs: IntoExpr) -> Decl: def _add_column(native: pa.Table, index: int, name: str, values: IntoExpr) -> Decl: - column = values if _is_expr(values) else lit(values) + column = values if is_expression(values) else lit(values) schema = native.schema schema_names = schema.names if index == 0: @@ -323,8 +320,8 @@ def _join_asof_strategy_to_tolerance( if strategy == "nearest": msg = "Only 'backward' and 'forward' strategies are currently supported for `pyarrow`" raise NotImplementedError(msg) - lower = fn.min_horizontal(fn.min_(left_on), fn.min_(right_on)) - upper = fn.max_horizontal(fn.max_(left_on), fn.max_(right_on)) + lower = fn.min_horizontal(fn.min(left_on), fn.min(right_on)) + upper = fn.max_horizontal(fn.max(left_on), fn.max(right_on)) scalar = fn.sub(lower, upper) if strategy == "backward" else fn.sub(upper, lower) tolerance: int = fn.cast(scalar, fn.I64).as_py() return tolerance diff --git a/narwhals/_plan/arrow/common.py b/narwhals/_plan/arrow/common.py index 5a6a47bd37..3588f3648e 100644 --- a/narwhals/_plan/arrow/common.py +++ b/narwhals/_plan/arrow/common.py @@ -2,27 +2,22 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, Generic +from typing import TYPE_CHECKING, ClassVar, Generic from narwhals._plan.arrow import compat from narwhals._plan.arrow.functions import random_indices +from narwhals._plan.arrow.guards import is_series from narwhals._typing_compat import TypeVar from narwhals._utils import Implementation, Version, _StoresNative if TYPE_CHECKING: import pyarrow as pa - from typing_extensions import Self, TypeIs + from typing_extensions import Self from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.arrow.typing import ChunkedArrayAny, Indices -def is_series(obj: Any) -> TypeIs[_StoresNative[ChunkedArrayAny]]: - from narwhals._plan.arrow.series import ArrowSeries - - return isinstance(obj, ArrowSeries) - - NativeT = TypeVar("NativeT", "pa.Table", "ChunkedArrayAny") diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index b80b59e039..7906de1c47 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -12,7 +12,11 @@ from narwhals._plan.arrow import acero, compat, functions as fn from narwhals._plan.arrow.common import ArrowFrameSeries as FrameSeries from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar -from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy, partition_by +from narwhals._plan.arrow.group_by import ( + ArrowGroupBy as GroupBy, + partition_by, + unique_keep_boolean_length_preserving, +) from narwhals._plan.arrow.pivot import pivot_table from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.common import temp @@ -142,7 +146,7 @@ def _unique( [`unsort_indices`]: https://github.com/narwhals-dev/narwhals/blob/9b9122b4ab38a6aebe2f09c29ad0f6191952a7a7/narwhals/_plan/arrow/functions.py#L1666-L1697 """ subset = tuple(subset or self.columns) - into_column_agg, mask = fn.unique_keep_boolean_length_preserving(keep) + into_column_agg, mask = unique_keep_boolean_length_preserving(keep) idx_name = temp.column_name(self.columns) df = self.select_names(*set(subset).union(order_by)) if order_by: @@ -227,7 +231,7 @@ def to_struct(self, name: str = "") -> Series: else: struct = fn.chunked_array([], pa.struct(native.schema)) else: - struct = fn.struct(native.column_names, native.columns) + struct = fn.struct.into_struct(native.columns, native.column_names) return Series.from_native(struct, name, version=self.version) def get_column(self, name: str) -> Series: diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 0d65d8c4b3..66b0b63ba7 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -6,7 +6,6 @@ import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import -from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._plan import common, expressions as ir from narwhals._plan._guards import ( is_function_expr, @@ -15,7 +14,7 @@ is_seq_column, ) from narwhals._plan.arrow import functions as fn -from narwhals._plan.arrow.group_by import AggSpec +from narwhals._plan.arrow.group_by import BOOLEAN_LENGTH_PRESERVING, AggSpec from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co from narwhals._plan.common import temp @@ -129,7 +128,7 @@ def __narwhals_namespace__(self) -> ArrowNamespace: def _with_native(self, native: Any, name: str, /) -> StoresNativeT_co: ... def cast(self, node: ir.Cast, frame: Frame, name: str) -> StoresNativeT_co: - data_type = narwhals_to_native_dtype(node.dtype, frame.version) + data_type = fn.dtype_native(node.dtype, frame.version) native = node.expr.dispatch(self, frame, name).native return self._with_native(fn.cast(native, data_type), name) @@ -160,7 +159,7 @@ def is_between( native = expr.dispatch(self, frame, name).native lower = lower_bound.dispatch(self, frame, "lower").native upper = upper_bound.dispatch(self, frame, "upper").native - result = fn.is_between(native, lower, upper, node.function.closed) + result = fn.is_between(native, lower, upper, closed=node.function.closed) return self._with_native(result, name) @overload @@ -202,18 +201,18 @@ def func(node: FExpr[Any], frame: Frame, name: str, /) -> StoresNativeT_co: return func def abs(self, node: FExpr[Abs], frame: Frame, name: str) -> StoresNativeT_co: - return self._unary_function(fn.abs_)(node, frame, name) + return self._unary_function(fn.abs)(node, frame, name) def not_(self, node: FExpr[Not], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(fn.not_)(node, frame, name) def all(self, node: FExpr[All], frame: Frame, name: str) -> StoresNativeT_co: - return self._unary_function(fn.all_)(node, frame, name) + return self._unary_function(fn.all)(node, frame, name) def any( self, node: FExpr[ir.boolean.Any], frame: Frame, name: str ) -> StoresNativeT_co: - return self._unary_function(fn.any_)(node, frame, name) + return self._unary_function(fn.any)(node, frame, name) def is_finite( self, node: FExpr[IsFinite], frame: Frame, name: str @@ -466,16 +465,16 @@ def last(self, node: Last, frame: Frame, name: str) -> Scalar: def arg_min(self, node: ArgMin, frame: Frame, name: str) -> Scalar: native = self._dispatch_expr(node.expr, frame, name).native - result = pc.index(native, fn.min_(native)) + result = pc.index(native, fn.min(native)) return self._with_native(result, name) def arg_max(self, node: ArgMax, frame: Frame, name: str) -> Scalar: native = self._dispatch_expr(node.expr, frame, name).native - result: NativeScalar = pc.index(native, fn.max_(native)) + result: NativeScalar = pc.index(native, fn.max(native)) return self._with_native(result, name) def sum(self, node: Sum, frame: Frame, name: str) -> Scalar: - result = fn.sum_(self._dispatch_expr(node.expr, frame, name).native) + result = fn.sum(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) def n_unique(self, node: NUnique, frame: Frame, name: str) -> Scalar: @@ -511,7 +510,7 @@ def len(self, node: Len, frame: Frame, name: str) -> Scalar: return self._with_native(result, name) def max(self, node: Max, frame: Frame, name: str) -> Scalar: - result: NativeScalar = fn.max_(self._dispatch_expr(node.expr, frame, name).native) + result: NativeScalar = fn.max(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) def mean(self, node: Mean, frame: Frame, name: str) -> Scalar: @@ -523,7 +522,7 @@ def median(self, node: Median, frame: Frame, name: str) -> Scalar: return self._with_native(result, name) def min(self, node: Min, frame: Frame, name: str) -> Scalar: - result: NativeScalar = fn.min_(self._dispatch_expr(node.expr, frame, name).native) + result: NativeScalar = fn.min(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) def null_count(self, node: FExpr[F.NullCount], frame: Frame, name: str) -> Scalar: @@ -580,7 +579,7 @@ def _boolean_length_preserving( sort_indices: pa.UInt64Array | None = None, ) -> Self: # NOTE: This subset of functions can be expressed as a mask applied to indices - into_column_agg, mask = fn.BOOLEAN_LENGTH_PRESERVING[type(node.function)] + into_column_agg, mask = BOOLEAN_LENGTH_PRESERVING[type(node.function)] idx_name = temp.column_name(frame) df = frame._with_columns([node.input[0].dispatch(self, frame, name)]) if sort_indices is not None: @@ -662,7 +661,8 @@ def sample_frac(self, node: FExpr[F.SampleFrac], frame: Frame, name: str) -> Sel return self.from_series(result) def drop_nulls(self, node: FExpr[F.DropNulls], frame: Frame, name: str) -> Self: - return self._vector_function(fn.drop_nulls)(node, frame, name) + series = self._dispatch_expr(node.input[0], frame, name) + return self.from_series(series.drop_nulls()) def mode_all(self, node: FExpr[F.ModeAll], frame: Frame, name: str) -> Self: return self._vector_function(fn.mode_all)(node, frame, name) @@ -750,8 +750,8 @@ def hist_bin_count( else: # NOTE: `Decimal` is not supported, but excluding it from the typing is surprisingly complicated # https://docs.rs/polars-core/0.52.0/polars_core/datatypes/enum.DataType.html#method.is_primitive_numeric - lower: NativeScalar = fn.min_(native) - upper: NativeScalar = fn.max_(native) + lower: NativeScalar = fn.min(native) + upper: NativeScalar = fn.max(native) if lower.equals(upper): # All data points are identical - use unit interval rhs = fn.lit(0.5) @@ -811,9 +811,8 @@ def from_python( dtype: IntoDType | None = None, version: Version = Version.MAIN, ) -> Self: - dtype_pa: pa.DataType | None = None - if dtype and dtype != version.dtypes.Unknown: - dtype_pa = narwhals_to_native_dtype(dtype, version) + unknown = version.dtypes.Unknown + dtype_pa = None if dtype == unknown else fn.dtype_native(dtype, version) return cls.from_native(fn.lit(value, dtype_pa), name, version) @classmethod @@ -884,7 +883,7 @@ def drop_nulls( # type: ignore[override] previous = node.input[0].dispatch(self, frame, name) if previous.native.is_valid: return previous - chunked = fn.chunked_array([[]], previous.native.type) + chunked = fn.chunked_array([], previous.native.type) return ArrowExpr.from_native(chunked, name, version=self.version) @property @@ -956,20 +955,20 @@ def unary( class ArrowCatNamespace(ExprCatNamespace["Frame", "Expr"], ArrowAccessor[ExprOrScalarT]): def get_categories(self, node: FExpr[GetCategories], frame: Frame, name: str) -> Expr: native = node.input[0].dispatch(self.compliant, frame, name).native - return ArrowExpr.from_native(fn.get_categories(native), name, self.version) + return ArrowExpr.from_native(fn.cat.get_categories(native), name, self.version) class ArrowListNamespace( ExprListNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] ): def len(self, node: FExpr[lists.Len], frame: Frame, name: str) -> Expr | Scalar: - return self.unary(fn.list_len)(node, frame, name) + return self.unary(fn.list.len)(node, frame, name) def get(self, node: FExpr[lists.Get], frame: Frame, name: str) -> Expr | Scalar: - return self.unary(fn.list_get, node.function.index)(node, frame, name) + return self.unary(fn.list.get, node.function.index)(node, frame, name) def unique(self, node: FExpr[lists.Unique], frame: Frame, name: str) -> Expr | Scalar: - return self.unary(fn.list_unique)(node, frame, name) + return self.unary(fn.list.unique)(node, frame, name) def contains( self, node: FExpr[lists.Contains], frame: Frame, name: str @@ -981,16 +980,16 @@ def contains( if isinstance(item, ArrowExpr): # Maybe one day, not now raise NotImplementedError - return self.with_native(fn.list_contains(prev.native, item.native), name) + return self.with_native(fn.list.contains(prev.native, item.native), name) def join(self, node: FExpr[lists.Join], frame: Frame, name: str) -> Expr | Scalar: separator, ignore_nulls = node.function.separator, node.function.ignore_nulls previous = node.input[0].dispatch(self.compliant, frame, name) result: ChunkedOrScalarAny if isinstance(previous, ArrowExpr): - result = fn.list_join(previous.native, separator, ignore_nulls=ignore_nulls) + result = fn.list.join(previous.native, separator, ignore_nulls=ignore_nulls) else: - result = fn.list_join_scalar( + result = fn.list.join_scalar( previous.native, separator, ignore_nulls=ignore_nulls ) return self.with_native(result, name) @@ -1006,11 +1005,11 @@ def sort(self, node: FExpr[lists.Sort], frame: Frame, name: str) -> Expr | Scala previous = node.input[0].dispatch(self.compliant, frame, name) result: ChunkedOrScalarAny if isinstance(previous, ArrowScalar): - result = fn.list_sort_scalar(previous.native, node.function.options) + result = fn.list.sort_scalar(previous.native, node.function.options) else: descending = node.function.options.descending nulls_last = node.function.options.nulls_last - result = fn.list_sort( + result = fn.list.sort( previous.native, descending=descending, nulls_last=nulls_last ) return self.with_native(result, name) @@ -1033,25 +1032,25 @@ class ArrowStringNamespace( def len_chars( self, node: FExpr[strings.LenChars], frame: Frame, name: str ) -> Expr | Scalar: - return self.unary(fn.str_len_chars)(node, frame, name) + return self.unary(fn.str.len_chars)(node, frame, name) def slice(self, node: FExpr[strings.Slice], frame: Frame, name: str) -> Expr | Scalar: offset, length = node.function.offset, node.function.length - return self.unary(fn.str_slice, offset, length)(node, frame, name) + return self.unary(fn.str.slice, offset, length)(node, frame, name) def zfill(self, node: FExpr[strings.ZFill], frame: Frame, name: str) -> Expr | Scalar: - return self.unary(fn.str_zfill, node.function.length)(node, frame, name) + return self.unary(fn.str.zfill, node.function.length)(node, frame, name) def contains( self, node: FExpr[strings.Contains], frame: Frame, name: str ) -> Expr | Scalar: pattern, literal = node.function.pattern, node.function.literal - return self.unary(fn.str_contains, pattern, literal=literal)(node, frame, name) + return self.unary(fn.str.contains, pattern, literal=literal)(node, frame, name) def ends_with( self, node: FExpr[strings.EndsWith], frame: Frame, name: str ) -> Expr | Scalar: - return self.unary(fn.str_ends_with, node.function.suffix)(node, frame, name) + return self.unary(fn.str.ends_with, node.function.suffix)(node, frame, name) def replace( self, node: FExpr[strings.Replace], frame: Frame, name: str @@ -1062,11 +1061,11 @@ def replace( prev = expr.dispatch(self.compliant, frame, name) value = other.dispatch(self.compliant, frame, name) if isinstance(value, ArrowScalar): - result = fn.str_replace( + result = fn.str.replace( prev.native, pattern, value.native.as_py(), literal=literal, n=n ) elif isinstance(prev, ArrowExpr): - result = fn.str_replace_vector( + result = fn.str.replace_vector( prev.native, pattern, value.native, literal=literal, n=n ) else: @@ -1084,32 +1083,32 @@ def replace_all( return self.replace(rewrite, frame, name) def split(self, node: FExpr[strings.Split], frame: Frame, name: str) -> Expr | Scalar: - return self.unary(fn.str_split, node.function.by)(node, frame, name) + return self.unary(fn.str.split, node.function.by)(node, frame, name) def starts_with( self, node: FExpr[strings.StartsWith], frame: Frame, name: str ) -> Expr | Scalar: - return self.unary(fn.str_starts_with, node.function.prefix)(node, frame, name) + return self.unary(fn.str.starts_with, node.function.prefix)(node, frame, name) def strip_chars( self, node: FExpr[strings.StripChars], frame: Frame, name: str ) -> Expr | Scalar: - return self.unary(fn.str_strip_chars, node.function.characters)(node, frame, name) + return self.unary(fn.str.strip_chars, node.function.characters)(node, frame, name) def to_uppercase( self, node: FExpr[strings.ToUppercase], frame: Frame, name: str ) -> Expr | Scalar: - return self.unary(fn.str_to_uppercase)(node, frame, name) + return self.unary(fn.str.to_uppercase)(node, frame, name) def to_lowercase( self, node: FExpr[strings.ToLowercase], frame: Frame, name: str ) -> Expr | Scalar: - return self.unary(fn.str_to_lowercase)(node, frame, name) + return self.unary(fn.str.to_lowercase)(node, frame, name) def to_titlecase( self, node: FExpr[strings.ToTitlecase], frame: Frame, name: str ) -> Expr | Scalar: - return self.unary(fn.str_to_titlecase)(node, frame, name) + return self.unary(fn.str.to_titlecase)(node, frame, name) to_date = not_implemented() to_datetime = not_implemented() @@ -1119,4 +1118,4 @@ class ArrowStructNamespace( ExprStructNamespace["Frame", "Expr | Scalar"], ArrowAccessor[ExprOrScalarT] ): def field(self, node: FExpr[FieldByName], frame: Frame, name: str) -> Expr | Scalar: - return self.unary(fn.struct_field, node.function.name)(node, frame, name) + return self.unary(fn.struct.field, node.function.name)(node, frame, name) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py deleted file mode 100644 index 627ecd1e8b..0000000000 --- a/narwhals/_plan/arrow/functions.py +++ /dev/null @@ -1,2015 +0,0 @@ -"""Native functions, aliased and/or with behavior aligned to `polars`.""" - -from __future__ import annotations - -import math -import typing as t -from collections.abc import Callable, Collection, Iterator, Sequence -from itertools import chain -from typing import TYPE_CHECKING, Any, Final, Literal, overload - -import pyarrow as pa # ignore-banned-import -import pyarrow.compute as pc # ignore-banned-import - -from narwhals._arrow.utils import ( - cast_for_truediv, - chunked_array as _chunked_array, - concat_tables as concat_tables, # noqa: PLC0414 - floordiv_compat as _floordiv, - narwhals_to_native_dtype as _dtype_native, -) -from narwhals._plan import common, expressions as ir -from narwhals._plan._guards import is_non_nested_literal -from narwhals._plan.arrow import compat, options as pa_options -from narwhals._plan.expressions import functions as F, operators as ops -from narwhals._plan.options import ExplodeOptions, SortOptions -from narwhals._utils import Version, no_default -from narwhals.exceptions import ShapeError - -if TYPE_CHECKING: - import datetime as dt - from collections.abc import Iterable, Mapping - - from typing_extensions import Self, TypeAlias, TypeIs, TypeVarTuple, Unpack - - from narwhals._arrow.typing import Incomplete - from narwhals._plan.arrow.acero import Field - from narwhals._plan.arrow.typing import ( - Array, - ArrayAny, - Arrow, - ArrowAny, - ArrowListT, - ArrowT, - BinaryComp, - BinaryFunction, - BinaryLogical, - BinaryNumericTemporal, - BinOp, - BooleanLengthPreserving, - BooleanScalar, - BoolType, - ChunkedArray, - ChunkedArrayAny, - ChunkedList, - ChunkedOrArray, - ChunkedOrArrayAny, - ChunkedOrArrayT, - ChunkedOrScalar, - ChunkedOrScalarAny, - ChunkedOrScalarT, - ChunkedStruct, - DataType, - DataTypeRemap, - DataTypeT, - DateScalar, - IntegerScalar, - IntegerType, - ListArray, - ListScalar, - ListTypeT, - NativeScalar, - NonListTypeT, - NumericScalar, - Predicate, - SameArrowT, - Scalar, - ScalarAny, - ScalarT, - StringScalar, - StringType, - StructArray, - UInt32Type, - UnaryFunction, - UnaryNumeric, - VectorFunction, - ) - from narwhals._plan.compliant.typing import SeriesT - from narwhals._plan.options import RankOptions, SortMultipleOptions - from narwhals._plan.typing import Seq - from narwhals._typing import NoDefault - from narwhals.typing import ( - ClosedInterval, - FillNullStrategy, - IntoArrowSchema, - IntoDType, - NonNestedLiteral, - NumericLiteral, - PythonLiteral, - UniqueKeepStrategy, - ) - - Ts = TypeVarTuple("Ts") - -# NOTE: Common data type instances to share -UI32: Final = pa.uint32() -I64: Final = pa.int64() -F64: Final = pa.float64() -BOOL: Final = pa.bool_() - -EMPTY: Final = "" -"""The empty string.""" - - -class MinMax(ir.AggExpr): - """Returns a `Struct({'min': ..., 'max': ...})`. - - https://arrow.apache.org/docs/python/generated/pyarrow.compute.min_max.html#pyarrow.compute.min_max - """ - - -IntoColumnAgg: TypeAlias = Callable[[str], ir.AggExpr] -"""Helper constructor for single-column aggregations.""" - -is_null = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_null) -is_not_null = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_valid) -is_nan = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_nan) -is_finite = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_finite) -not_ = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.invert) - - -@overload -def is_not_nan(native: ChunkedArrayAny) -> ChunkedArray[pa.BooleanScalar]: ... -@overload -def is_not_nan(native: ScalarAny) -> pa.BooleanScalar: ... -@overload -def is_not_nan(native: ChunkedOrScalarAny) -> ChunkedOrScalar[pa.BooleanScalar]: ... -@overload -def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[pa.BooleanScalar]: ... -def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[pa.BooleanScalar]: - return not_(is_nan(native)) - - -and_ = t.cast("BinaryLogical", pc.and_kleene) -or_ = t.cast("BinaryLogical", pc.or_kleene) -xor = t.cast("BinaryLogical", pc.xor) - -eq = t.cast("BinaryComp", pc.equal) -not_eq = t.cast("BinaryComp", pc.not_equal) -gt_eq = t.cast("BinaryComp", pc.greater_equal) -gt = t.cast("BinaryComp", pc.greater) -lt_eq = t.cast("BinaryComp", pc.less_equal) -lt = t.cast("BinaryComp", pc.less) - - -add = t.cast("BinaryNumericTemporal", pc.add) -sub = t.cast("BinaryNumericTemporal", pc.subtract) -multiply = pc.multiply -power = t.cast("BinaryFunction[NumericScalar, NumericScalar]", pc.power) -floordiv = _floordiv -abs_ = t.cast("UnaryNumeric", pc.abs) -exp = t.cast("UnaryNumeric", pc.exp) -sqrt = t.cast("UnaryNumeric", pc.sqrt) -ceil = t.cast("UnaryNumeric", pc.ceil) -floor = t.cast("UnaryNumeric", pc.floor) - - -def truediv(lhs: Incomplete, rhs: Incomplete) -> Incomplete: - return pc.divide(*cast_for_truediv(lhs, rhs)) - - -def modulus(lhs: Incomplete, rhs: Incomplete) -> Incomplete: - floor_div = floordiv(lhs, rhs) - return sub(lhs, multiply(floor_div, rhs)) - - -# TODO @dangotbanned: Somehow fix the typing on this -# - `_ArrowDispatch` is relying on the gradual typing -_DISPATCH_BINARY: Mapping[type[ops.Operator], BinOp] = { - # BinaryComp - ops.Eq: eq, - ops.NotEq: not_eq, - ops.Lt: lt, - ops.LtEq: lt_eq, - ops.Gt: gt, - ops.GtEq: gt_eq, - # BinaryFunction (well it should be) - ops.Add: add, # BinaryNumericTemporal - ops.Sub: sub, # pyarrow-stubs - ops.Multiply: multiply, # pyarrow-stubs - ops.TrueDivide: truediv, # [[Any, Any], Any] - ops.FloorDivide: floordiv, # [[ArrayOrScalar, ArrayOrScalar], Any] - ops.Modulus: modulus, # [[Any, Any], Any] - # BinaryLogical - ops.And: and_, - ops.Or: or_, - ops.ExclusiveOr: xor, -} - - -def bin_op( - function: Callable[[Any, Any], Any], /, *, reflect: bool = False -) -> Callable[[SeriesT, Any], SeriesT]: - """Attach a binary operator to `ArrowSeries`.""" - - def f(self: SeriesT, other: SeriesT | Any, /) -> SeriesT: - right = other.native if isinstance(other, type(self)) else lit(other) - return self._with_native(function(self.native, right)) - - def f_reflect(self: SeriesT, other: SeriesT | Any, /) -> SeriesT: - if isinstance(other, type(self)): - name = other.name - right: ArrowAny = other.native - else: - name = "literal" - right = lit(other) - return self.from_native(function(right, self.native), name, version=self.version) - - return f_reflect if reflect else f - - -_IS_BETWEEN: Mapping[ClosedInterval, tuple[BinaryComp, BinaryComp]] = { - "left": (gt_eq, lt), - "right": (gt, lt_eq), - "none": (gt, lt), - "both": (gt_eq, lt_eq), -} - - -@t.overload -def dtype_native(dtype: IntoDType, version: Version) -> pa.DataType: ... -@t.overload -def dtype_native(dtype: None, version: Version) -> None: ... -@t.overload -def dtype_native(dtype: IntoDType | None, version: Version) -> pa.DataType | None: ... -def dtype_native(dtype: IntoDType | None, version: Version) -> pa.DataType | None: - return dtype if dtype is None else _dtype_native(dtype, version) - - -@t.overload -def cast( - native: Scalar[Any], target_type: DataTypeT, *, safe: bool | None = ... -) -> Scalar[DataTypeT]: ... -@t.overload -def cast( - native: ChunkedArray[Any], target_type: DataTypeT, *, safe: bool | None = ... -) -> ChunkedArray[Scalar[DataTypeT]]: ... -@t.overload -def cast( - native: ChunkedOrScalar[Scalar[Any]], - target_type: DataTypeT, - *, - safe: bool | None = ..., -) -> ChunkedArray[Scalar[DataTypeT]] | Scalar[DataTypeT]: ... -def cast( - native: ChunkedOrScalar[Scalar[Any]], - target_type: DataTypeT, - *, - safe: bool | None = None, -) -> ChunkedArray[Scalar[DataTypeT]] | Scalar[DataTypeT]: - return pc.cast(native, target_type, safe=safe) - - -def cast_schema( - native: pa.Schema, target_types: DataType | Mapping[str, DataType] | DataTypeRemap -) -> pa.Schema: - if isinstance(target_types, pa.DataType): - return pa.schema((name, target_types) for name in native.names) - if _is_into_pyarrow_schema(target_types): - new_schema = native - for name, dtype in target_types.items(): - index = native.get_field_index(name) - new_schema.set(index, native.field(index).with_type(dtype)) - return new_schema - return pa.schema((fld.name, target_types.get(fld.type, fld.type)) for fld in native) - - -def cast_table( - native: pa.Table, target: DataType | IntoArrowSchema | DataTypeRemap -) -> pa.Table: - s = target if isinstance(target, pa.Schema) else cast_schema(native.schema, target) - return native.cast(s) - - -def has_large_string(data_types: Iterable[DataType], /) -> bool: - return any(pa.types.is_large_string(tp) for tp in data_types) - - -def string_type(data_types: Iterable[DataType] = (), /) -> StringType: - """Return a native string type, compatible with `data_types`. - - Until [apache/arrow#45717] is resolved, we need to upcast `string` to `large_string` when joining. - - [apache/arrow#45717]: https://github.com/apache/arrow/issues/45717 - """ - return pa.large_string() if has_large_string(data_types) else pa.string() - - -# NOTE: `mypy` isn't happy, but this broadcasting behavior is worth documenting -@t.overload -def struct(names: Iterable[str], columns: Iterable[ChunkedArrayAny]) -> ChunkedStruct: ... -@t.overload -def struct(names: Iterable[str], columns: Iterable[ArrayAny]) -> pa.StructArray: ... -@t.overload -def struct( # type: ignore[overload-overlap] - names: Iterable[str], columns: Iterable[ScalarAny] | Iterable[NonNestedLiteral] -) -> pa.StructScalar: ... -@t.overload -def struct( # type: ignore[overload-overlap] - names: Iterable[str], columns: Iterable[ChunkedArrayAny | NonNestedLiteral] -) -> ChunkedStruct: ... -@t.overload -def struct( - names: Iterable[str], columns: Iterable[ArrayAny | NonNestedLiteral] -) -> pa.StructArray: ... -@t.overload -def struct(names: Iterable[str], columns: Iterable[ArrowAny]) -> Incomplete: ... -def struct(names: Iterable[str], columns: Iterable[Incomplete]) -> Incomplete: - """Collect columns into a struct. - - Arguments: - names: Names of the struct fields to create. - columns: Value(s) to collect into a struct. Scalars will will be broadcast unless all - inputs are scalar. - """ - return pc.make_struct( - *columns, options=pc.MakeStructOptions(common.ensure_seq_str(names)) - ) - - -def struct_schema(native: Arrow[pa.StructScalar] | pa.StructType) -> pa.Schema: - """Get the struct definition as a schema.""" - tp = native.type if _is_arrow(native) else native - fields = tp.fields if compat.HAS_STRUCT_TYPE_FIELDS else list(tp) - return pa.schema(fields) - - -def struct_field_names(native: Arrow[pa.StructScalar] | pa.StructType) -> list[str]: - """Get the names of all struct fields.""" - tp = native.type if _is_arrow(native) else native - return tp.names if compat.HAS_STRUCT_TYPE_FIELDS else [f.name for f in tp] - - -@t.overload -def struct_field(native: ChunkedStruct, field: Field, /) -> ChunkedArrayAny: ... -@t.overload -def struct_field(native: StructArray, field: Field, /) -> ArrayAny: ... -@t.overload -def struct_field(native: pa.StructScalar, field: Field, /) -> ScalarAny: ... -@t.overload -def struct_field(native: SameArrowT, field: Field, /) -> SameArrowT: ... -@t.overload -def struct_field(native: ChunkedOrScalarAny, field: Field, /) -> ChunkedOrScalarAny: ... -def struct_field(native: ArrowAny, field: Field, /) -> ArrowAny: - """Retrieve one `Struct` field.""" - func = t.cast("Callable[[Any,Any], ArrowAny]", pc.struct_field) - return func(native, field) - - -@t.overload -def struct_fields(native: ChunkedStruct, *fields: Field) -> Seq[ChunkedArrayAny]: ... -@t.overload -def struct_fields(native: StructArray, *fields: Field) -> Seq[ArrayAny]: ... -@t.overload -def struct_fields(native: pa.StructScalar, *fields: Field) -> Seq[ScalarAny]: ... -@t.overload -def struct_fields(native: SameArrowT, *fields: Field) -> Seq[SameArrowT]: ... -def struct_fields(native: ArrowAny, *fields: Field) -> Seq[ArrowAny]: - """Retrieve multiple `Struct` fields.""" - func = t.cast("Callable[[Any,Any], ArrowAny]", pc.struct_field) - return tuple(func(native, name) for name in fields) - - -def get_categories(native: ArrowAny) -> ChunkedArrayAny: - da: Incomplete - if isinstance(native, pa.ChunkedArray): - da = native.unify_dictionaries().chunk(0) - else: - da = native - return chunked_array(da.dictionary) - - -class ExplodeBuilder: - """Tools for exploding lists. - - The complexity of these operations increases with: - - Needing to preserve null/empty elements - - All variants are cheaper if this can be skipped - - Exploding in the context of a table - - Where a single column is much simpler than multiple - """ - - options: ExplodeOptions - - def __init__(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> None: - self.options = ExplodeOptions(empty_as_null=empty_as_null, keep_nulls=keep_nulls) - - @classmethod - def from_options(cls, options: ExplodeOptions, /) -> Self: - obj = cls.__new__(cls) - obj.options = options - return obj - - @t.overload - def explode( - self, native: ChunkedList[DataTypeT] | ListScalar[DataTypeT] - ) -> ChunkedArray[Scalar[DataTypeT]]: ... - @t.overload - def explode(self, native: ListArray[DataTypeT]) -> Array[Scalar[DataTypeT]]: ... - @t.overload - def explode( - self, native: Arrow[ListScalar[DataTypeT]] - ) -> ChunkedOrArray[Scalar[DataTypeT]]: ... - def explode( - self, native: Arrow[ListScalar[DataTypeT]] - ) -> ChunkedOrArray[Scalar[DataTypeT]]: - """Explode list elements, expanding one-level into a new array. - - Equivalent to `polars.{Expr,Series}.explode`. - """ - safe = self._fill_with_null(native) if self.options.any() else native - if not isinstance(safe, pa.Scalar): - return _list_explode(safe) - return chunked_array(_list_explode(safe)) - - def explode_with_indices(self, native: ChunkedList | ListArray) -> pa.Table: - """Explode list elements, expanding one-level into a table indexing the origin. - - Returns a 2-column table, with names `"idx"` and `"values"`: - - >>> from narwhals._plan.arrow import functions as fn - >>> - >>> arr = fn.array([[1, 2, 3], None, [4, 5, 6], []]) - >>> fn.ExplodeBuilder().explode_with_indices(arr).to_pydict() - {'idx': [0, 0, 0, 1, 2, 2, 2, 3], 'values': [1, 2, 3, None, 4, 5, 6, None]} - # ^ Which sublist values come from ^ The exploded values themselves - """ - safe = self._fill_with_null(native) if self.options.any() else native - arrays = [_list_parent_indices(safe), _list_explode(safe)] - return concat_horizontal(arrays, ["idx", "values"]) - - def explode_column(self, native: pa.Table, column_name: str, /) -> pa.Table: - """Explode a list-typed column in the context of `native`.""" - ca = native.column(column_name) - if native.num_columns == 1: - return native.from_arrays([self.explode(ca)], [column_name]) - safe = self._fill_with_null(ca) if self.options.any() else ca - exploded = _list_explode(safe) - col_idx = native.schema.get_field_index(column_name) - if len(exploded) == len(native): - return native.set_column(col_idx, column_name, exploded) - return ( - native.remove_column(col_idx) - .take(_list_parent_indices(safe)) - .add_column(col_idx, column_name, exploded) - ) - - def explode_columns(self, native: pa.Table, subset: Collection[str], /) -> pa.Table: - """Explode multiple list-typed columns in the context of `native`.""" - subset = list(subset) - arrays = native.select(subset).columns - first = arrays[0] - first_len = list_len(first) - if self.options.any(): - mask = self._predicate(first_len) - first_safe = self._fill_with_null(first, mask) - it = ( - _list_explode(self._fill_with_null(arr, mask)) - for arr in self._iter_ensure_shape(first_len, arrays[1:]) - ) - else: - first_safe = first - it = ( - _list_explode(arr) - for arr in self._iter_ensure_shape(first_len, arrays[1:]) - ) - column_names = native.column_names - result = native - first_result = _list_explode(first_safe) - if len(first_result) == len(native): - # fastpath for all length-1 lists - # if only the first is length-1, then the others raise during iteration on either branch - for name, arr in zip(subset, chain([first_result], it)): - result = result.set_column(column_names.index(name), name, arr) - else: - result = result.drop_columns(subset).take(_list_parent_indices(first_safe)) - for name, arr in zip(subset, chain([first_result], it)): - result = result.append_column(name, arr) - result = result.select(column_names) - return result - - @classmethod - def explode_column_fast(cls, native: pa.Table, column_name: str, /) -> pa.Table: - """Explode a list-typed column in the context of `native`, ignoring empty and nulls.""" - return cls(empty_as_null=False, keep_nulls=False).explode_column( - native, column_name - ) - - def _iter_ensure_shape( - self, - first_len: ChunkedArray[pa.UInt32Scalar], - arrays: Iterable[ChunkedArrayAny], - /, - ) -> Iterator[ChunkedArrayAny]: - for arr in arrays: - if not first_len.equals(list_len(arr)): - msg = "exploded columns must have matching element counts" - raise ShapeError(msg) - yield arr - - def _predicate(self, lengths: ArrowAny, /) -> Arrow[pa.BooleanScalar]: - """Return True for each sublist length that indicates the original sublist should be replaced with `[None]`.""" - empty_as_null, keep_nulls = self.options.empty_as_null, self.options.keep_nulls - if empty_as_null and keep_nulls: - return or_(is_null(lengths), eq(lengths, lit(0))) - if empty_as_null: - return eq(lengths, lit(0)) - return is_null(lengths) - - def _fill_with_null( - self, native: ArrowListT, mask: Arrow[BooleanScalar] | NoDefault = no_default - ) -> ArrowListT: - """Replace each sublist in `native` with `[None]`, according to `self.options`. - - Arguments: - native: List-typed arrow data. - mask: An optional, pre-computed replacement mask. By default, this is generated from `native`. - """ - predicate = self._predicate(list_len(native)) if mask is no_default else mask - result: ArrowListT = when_then(predicate, lit([None], native.type), native) - return result - - -def implode(native: Arrow[Scalar[DataTypeT]]) -> pa.ListScalar[DataTypeT]: - """Aggregate values into a list. - - The returned list itself is a scalar value of `list` dtype. - """ - arr = array(native) - return pa.ListArray.from_arrays([0, len(arr)], arr)[0] - - -@t.overload -def _list_explode(native: ChunkedList[DataTypeT]) -> ChunkedArray[Scalar[DataTypeT]]: ... -@t.overload -def _list_explode( - native: ListArray[NonListTypeT] | ListScalar[NonListTypeT], -) -> Array[Scalar[NonListTypeT]]: ... -@t.overload -def _list_explode(native: ListArray[DataTypeT]) -> Array[Scalar[DataTypeT]]: ... -@t.overload -def _list_explode(native: ListScalar[ListTypeT]) -> ListArray[ListTypeT]: ... -def _list_explode(native: Arrow[ListScalar]) -> ChunkedOrArrayAny: - result: ChunkedOrArrayAny = pc.call_function("list_flatten", [native]) - return result - - -@t.overload -def _list_parent_indices(native: ChunkedList) -> ChunkedArray[pa.Int64Scalar]: ... -@t.overload -def _list_parent_indices(native: ListArray) -> pa.Int64Array: ... -def _list_parent_indices( - native: ChunkedOrArray[ListScalar], -) -> ChunkedOrArray[pa.Int64Scalar]: - """Don't use this withut handling nulls!""" - result: ChunkedOrArray[pa.Int64Scalar] = pc.call_function( - "list_parent_indices", [native] - ) - return result - - -@t.overload -def list_len(native: ChunkedList) -> ChunkedArray[pa.UInt32Scalar]: ... -@t.overload -def list_len(native: ListArray) -> pa.UInt32Array: ... -@t.overload -def list_len(native: ListScalar) -> pa.UInt32Scalar: ... -@t.overload -def list_len(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[pa.UInt32Scalar]: ... -@t.overload -def list_len(native: Arrow[ListScalar[Any]]) -> Arrow[pa.UInt32Scalar]: ... -def list_len(native: ArrowAny) -> ArrowAny: - length: Incomplete = pc.list_value_length - result: ArrowAny = length(native).cast(pa.uint32()) - return result - - -@t.overload -def list_get( - native: ChunkedList[DataTypeT], index: int -) -> ChunkedArray[Scalar[DataTypeT]]: ... -@t.overload -def list_get(native: ListArray[DataTypeT], index: int) -> Array[Scalar[DataTypeT]]: ... -@t.overload -def list_get(native: ListScalar[DataTypeT], index: int) -> Scalar[DataTypeT]: ... -@t.overload -def list_get(native: SameArrowT, index: int) -> SameArrowT: ... -@t.overload -def list_get(native: ChunkedOrScalarAny, index: int) -> ChunkedOrScalarAny: ... -def list_get(native: ArrowAny, index: int) -> ArrowAny: - list_get_: Incomplete = pc.list_element - result: ArrowAny = list_get_(native, index) - return result - - -_list_join = t.cast( - "Callable[[ChunkedOrArrayAny, Arrow[StringScalar] | str], ChunkedArray[StringScalar] | pa.StringArray]", - pc.binary_join, -) - - -# NOTE: Raised for native null-handling (https://github.com/apache/arrow/issues/48477) -@t.overload -def list_join( - native: ChunkedList[StringType], - separator: Arrow[StringScalar] | str, - *, - ignore_nulls: bool = ..., -) -> ChunkedArray[StringScalar]: ... -@t.overload -def list_join( - native: ListArray[StringType], - separator: Arrow[StringScalar] | str, - *, - ignore_nulls: bool = ..., -) -> pa.StringArray: ... -@t.overload -def list_join( - native: ChunkedOrArray[ListScalar[StringType]], - separator: str, - *, - ignore_nulls: bool = ..., -) -> ChunkedOrArray[StringScalar]: ... -def list_join( - native: ChunkedOrArrayAny, - separator: Arrow[StringScalar] | str, - *, - ignore_nulls: bool = True, -) -> ChunkedOrArrayAny: - """Join all string items in a sublist and place a separator between them. - - Each list of values in the first input is joined using each second input as separator. - If any input list is null or contains a null, the corresponding output will be null. - """ - from narwhals._plan.arrow.group_by import AggSpec - - # (1): Try to return *as-is* from `pc.binary_join` - result = _list_join(native, separator) - if not ignore_nulls or not result.null_count: - return result - is_null_sensitive = pc.and_not(result.is_null(), native.is_null()) - if array(is_null_sensitive, BOOL).true_count == 0: - return result - - # (2): Deal with only the bad kids - lists = native.filter(is_null_sensitive) - - # (2.1): We know that `[None]` should join as `""`, and that is the only length-1 list we could have after the filter - list_len_eq_1 = eq(list_len(lists), lit(1, UI32)) - has_a_len_1_null = any_(list_len_eq_1).as_py() - if has_a_len_1_null: - lists = when_then(list_len_eq_1, lit([EMPTY], lists.type), lists) - - # (2.2): Everything left falls into one of these boxes: - # - (2.1): `[""]` - # - (2.2): `["something", (str | None)*, None]` <--- We fix this here and hope for the best - # - (2.3): `[None, (None)*, None]` - idx, v = "idx", "values" - builder = ExplodeBuilder(empty_as_null=False, keep_nulls=False) - explode_w_idx = builder.explode_with_indices(lists) - implode_by_idx = AggSpec.implode(v).over(explode_w_idx.drop_null(), [idx]) - replacements = _list_join(implode_by_idx.column(v), separator) - - # (2.3): The cursed box 😨 - if len(replacements) != len(lists): - # This is a very unlucky case to hit, because we *can* detect the issue earlier - # but we *can't* join a table with a list in it. So we deal with the fallout now ... - # The end result is identical to (2.1) - indices_all = to_table(explode_w_idx.column(idx).unique(), idx) - indices_repaired = implode_by_idx.set_column(1, v, replacements) - replacements = ( - indices_all.join(indices_repaired, idx) - .sort_by(idx) - .column(v) - .fill_null(lit(EMPTY, lists.type.value_type)) - ) - return replace_with_mask(result, is_null_sensitive, replacements) - - -def list_join_scalar( - native: ListScalar[StringType], - separator: StringScalar | str, - *, - ignore_nulls: bool = True, -) -> StringScalar: - """Join all string items in a `ListScalar` and place a separator between them. - - Note: - Consider using `list_join` or `str_join` if you don't already have `native` in this shape. - """ - if ignore_nulls and native.is_valid: - native = implode(_list_explode(native).drop_null()) - result: StringScalar = pc.call_function("binary_join", [native, separator]) - return result - - -@overload -def list_unique(native: ChunkedList) -> ChunkedList: ... -@overload -def list_unique(native: ListScalar) -> ListScalar: ... -@overload -def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScalar]: ... -def list_unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScalar]: - """Get the unique/distinct values in the list. - - There's lots of tricky stuff going on in here, but for good reasons! - - Whenever possible, we want to avoid having to deal with these pesky guys: - - [["okay", None, "still fine"], None, []] - # ^^^^ ^^ - - - Those kinds of list elements are ignored natively - - `unique` is length-changing operation - - We can't use [`pc.replace_with_mask`] on a list - - We can't join when a table contains list columns [apache/arrow#43716] - - **But** - if we're lucky, and we got a non-awful list (or only one element) - then - most issues vanish. - - [`pc.replace_with_mask`]: https://arrow.apache.org/docs/python/generated/pyarrow.compute.replace_with_mask.html - [apache/arrow#43716]: https://github.com/apache/arrow/issues/43716 - """ - from narwhals._plan.arrow.group_by import AggSpec - - if isinstance(native, pa.Scalar): - scalar = t.cast("pa.ListScalar[Any]", native) - if scalar.is_valid and (len(scalar) > 1): - return implode(_list_explode(native).unique()) - return scalar - idx, v = "index", "values" - names = idx, v - len_not_eq_0 = not_eq(list_len(native), lit(0)) - can_fastpath = all_(len_not_eq_0, ignore_nulls=False).as_py() - if can_fastpath: - arrays = [_list_parent_indices(native), _list_explode(native)] - return AggSpec.unique(v).over_index(concat_horizontal(arrays, names), idx) - # Oh no - we caught a bad one! - # We need to split things into good/bad - and only work on the good stuff. - # `int_range` is acting like `parent_indices`, but doesn't give up when it see's `None` or `[]` - indexed = concat_horizontal([int_range(len(native)), native], names) - valid = indexed.filter(len_not_eq_0) - invalid = indexed.filter(or_(native.is_null(), not_(len_not_eq_0))) - # To keep track of where we started, our index needs to be exploded with the list elements - explode_with_index = ExplodeBuilder.explode_column_fast(valid, v) - valid_unique = AggSpec.unique(v).over(explode_with_index, [idx]) - # And now, because we can't join - we do a poor man's version of one 😉 - return concat_tables([valid_unique, invalid]).sort_by(idx).column(v) - - -def list_contains( - native: ChunkedOrScalar[ListScalar], item: NonNestedLiteral | ScalarAny -) -> ChunkedOrScalar[pa.BooleanScalar]: - from narwhals._plan.arrow.group_by import AggSpec - - if isinstance(native, pa.Scalar): - scalar = t.cast("pa.ListScalar[Any]", native) - if scalar.is_valid: - if len(scalar): - value_type = scalar.type.value_type - return any_(eq_missing(_list_explode(scalar), lit(item).cast(value_type))) - return lit(False, BOOL) - return lit(None, BOOL) - builder = ExplodeBuilder(empty_as_null=False, keep_nulls=False) - tbl = builder.explode_with_indices(native) - idx, name = tbl.column_names - contains = eq_missing(tbl.column(name), item) - l_contains = AggSpec.any(name).over_index(tbl.set_column(1, name, contains), idx) - # Here's the really key part: this mask has the same result we want to return - # So by filling the `True`, we can flip those to `False` if needed - # But if we were already `None` or `False` - then that's sticky - propagate_invalid: ChunkedArray[pa.BooleanScalar] = not_eq(list_len(native), lit(0)) - return replace_with_mask(propagate_invalid, propagate_invalid, l_contains) - - -def list_sort( - native: ChunkedList, *, descending: bool = False, nulls_last: bool = False -) -> ChunkedList: - """Sort the sublists in this column. - - Works in a similar way to `list_unique` and `list_join`. - - 1. Select only sublists that require sorting (`None`, 0-length, and 1-length lists are noops) - 2. Explode -> Sort -> Implode -> Concat - """ - from narwhals._plan.arrow.group_by import AggSpec - - idx, v = "idx", "values" - is_not_sorted = gt(list_len(native), lit(1)) - indexed = concat_horizontal([int_range(len(native)), native], [idx, v]) - exploded = ExplodeBuilder.explode_column_fast(indexed.filter(is_not_sorted), v) - indices = sort_indices( - exploded, idx, v, descending=[False, descending], nulls_last=nulls_last - ) - exploded_sorted = exploded.take(indices) - implode_by_idx = AggSpec.implode(v).over(exploded_sorted, [idx]) - passthrough = indexed.filter(fill_null(not_(is_not_sorted), True)) - return concat_tables([implode_by_idx, passthrough]).sort_by(idx).column(v) - - -def list_sort_scalar( - native: ListScalar[NonListTypeT], options: SortOptions | None = None -) -> pa.ListScalar[NonListTypeT]: - native = t.cast("pa.ListScalar[NonListTypeT]", native) - if native.is_valid and len(native) > 1: - arr = _list_explode(native) - return implode(arr.take(sort_indices(arr, options=options))) - return native - - -def str_join( - native: Arrow[StringScalar], separator: str, *, ignore_nulls: bool = True -) -> StringScalar: - """Vertically concatenate the string values in the column to a single string value.""" - if isinstance(native, pa.Scalar): - # already joined - return native - if ignore_nulls and native.null_count: - native = native.drop_null() - return list_join_scalar(implode(native), separator, ignore_nulls=False) - - -def str_len_chars(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: - len_chars: Incomplete = pc.utf8_length - result: ChunkedOrScalarAny = len_chars(native) - return result - - -def str_slice( - native: ChunkedOrScalarAny, offset: int, length: int | None = None -) -> ChunkedOrScalarAny: - stop = length if length is None else offset + length - return pc.utf8_slice_codeunits(native, offset, stop=stop) - - -def str_pad_start( - native: ChunkedOrScalarAny, length: int, fill_char: str = " " -) -> ChunkedOrScalarAny: # pragma: no cover - return pc.utf8_lpad(native, length, fill_char) - - -@t.overload -def str_find( - native: ChunkedArrayAny, - pattern: str, - *, - literal: bool = ..., - not_found: int | None = ..., -) -> ChunkedArray[IntegerScalar]: ... -@t.overload -def str_find( - native: Array, pattern: str, *, literal: bool = ..., not_found: int | None = ... -) -> Array[IntegerScalar]: ... -@t.overload -def str_find( - native: ScalarAny, pattern: str, *, literal: bool = ..., not_found: int | None = ... -) -> IntegerScalar: ... -def str_find( - native: Arrow[StringScalar], - pattern: str, - *, - literal: bool = False, - not_found: int | None = -1, -) -> Arrow[IntegerScalar]: - """Return the bytes offset of the first substring matching a pattern. - - To match `pl.Expr.str.find` behavior, pass `not_found=None`. - - Note: - `pyarrow` distinguishes null *inputs* with `None` and failed matches with `-1`. - """ - # NOTE: `pyarrow-stubs` uses concrete types here - fn_name = "find_substring" if literal else "find_substring_regex" - result: Arrow[IntegerScalar] = pc.call_function( - fn_name, [native], pa_options.match_substring(pattern) - ) - if not_found == -1: - return result - return when_then(eq(result, lit(-1)), lit(not_found, result.type), result) - - -_StringFunction0: TypeAlias = "Callable[[ChunkedOrScalarAny], ChunkedOrScalarAny]" -_StringFunction1: TypeAlias = "Callable[[ChunkedOrScalarAny, str], ChunkedOrScalarAny]" -str_starts_with = t.cast("_StringFunction1", pc.starts_with) -str_ends_with = t.cast("_StringFunction1", pc.ends_with) -str_to_uppercase = t.cast("_StringFunction0", pc.utf8_upper) -str_to_lowercase = t.cast("_StringFunction0", pc.utf8_lower) -str_to_titlecase = t.cast("_StringFunction0", pc.utf8_title) - - -def _str_split( - native: ArrowAny, by: str, n: int | None = None, *, literal: bool = True -) -> Arrow[ListScalar]: - name = "split_pattern" if literal else "split_pattern_regex" - result: Arrow[ListScalar] = pc.call_function( - name, [native], pa_options.split_pattern(by, n) - ) - return result - - -@t.overload -def str_split( - native: ChunkedArrayAny, by: str, *, literal: bool = ... -) -> ChunkedArray[ListScalar]: ... -@t.overload -def str_split( - native: ChunkedOrScalarAny, by: str, *, literal: bool = ... -) -> ChunkedOrScalar[ListScalar]: ... -@t.overload -def str_split(native: ArrayAny, by: str, *, literal: bool = ...) -> pa.ListArray[Any]: ... -@t.overload -def str_split(native: ArrowAny, by: str, *, literal: bool = ...) -> Arrow[ListScalar]: ... -def str_split(native: ArrowAny, by: str, *, literal: bool = True) -> Arrow[ListScalar]: - return _str_split(native, by, literal=literal) - - -@t.overload -def str_splitn( - native: ChunkedArrayAny, - by: str, - n: int, - *, - literal: bool = ..., - as_struct: bool = ..., -) -> ChunkedArray[ListScalar]: ... -@t.overload -def str_splitn( - native: ChunkedOrScalarAny, - by: str, - n: int, - *, - literal: bool = ..., - as_struct: bool = ..., -) -> ChunkedOrScalar[ListScalar]: ... -@t.overload -def str_splitn( - native: ArrayAny, by: str, n: int, *, literal: bool = ..., as_struct: bool = ... -) -> pa.ListArray[Any]: ... -@t.overload -def str_splitn( - native: ArrowAny, by: str, n: int, *, literal: bool = ..., as_struct: bool = ... -) -> Arrow[ListScalar]: ... -def str_splitn( - native: ArrowAny, by: str, n: int, *, literal: bool = True, as_struct: bool = False -) -> Arrow[ListScalar]: - """Split the string by a substring, restricted to returning at most `n` items.""" - result = _str_split(native, by, n, literal=literal) - if as_struct: - # NOTE: `polars` would return a struct w/ field names (`'field_0`, ..., 'field_n-1`) - msg = "TODO: `ArrowExpr.str.splitn`" - raise NotImplementedError(msg) - return result - - -@t.overload -def str_contains( - native: ChunkedArrayAny, pattern: str, *, literal: bool = ... -) -> ChunkedArray[pa.BooleanScalar]: ... -@t.overload -def str_contains( - native: ChunkedOrScalarAny, pattern: str, *, literal: bool = ... -) -> ChunkedOrScalar[pa.BooleanScalar]: ... -@t.overload -def str_contains( - native: ArrowAny, pattern: str, *, literal: bool = ... -) -> Arrow[pa.BooleanScalar]: ... -def str_contains( - native: ArrowAny, pattern: str, *, literal: bool = False -) -> Arrow[pa.BooleanScalar]: - """Check if the string contains a substring that matches a pattern.""" - name = "match_substring" if literal else "match_substring_regex" - result: Arrow[pa.BooleanScalar] = pc.call_function( - name, [native], pa_options.match_substring(pattern) - ) - return result - - -def str_strip_chars(native: Incomplete, characters: str | None) -> Incomplete: - if characters: - return pc.utf8_trim(native, characters) - return pc.utf8_trim_whitespace(native) - - -def str_replace( - native: Incomplete, pattern: str, value: str, *, literal: bool = False, n: int = 1 -) -> Incomplete: - fn = pc.replace_substring if literal else pc.replace_substring_regex - return fn(native, pattern, replacement=value, max_replacements=n) - - -def str_replace_all( - native: Incomplete, pattern: str, value: str, *, literal: bool = False -) -> Incomplete: - return str_replace(native, pattern, value, literal=literal, n=-1) - - -def str_replace_vector( - native: ChunkedArrayAny, - pattern: str, - replacements: ChunkedArrayAny, - *, - literal: bool = False, - n: int | None = 1, -) -> ChunkedArrayAny: - has_match = str_contains(native, pattern, literal=literal) - if not any_(has_match).as_py(): - # fastpath, no work to do - return native - match, match_replacements = filter_arrays(has_match, native, replacements) - if n is None or n == -1: - list_split_by = str_split(match, pattern, literal=literal) - else: - list_split_by = str_splitn(match, pattern, n + 1, literal=literal) - replaced = list_join(list_split_by, match_replacements, ignore_nulls=False) - if all_(has_match, ignore_nulls=False).as_py(): - return chunked_array(replaced) - return replace_with_mask(native, has_match, array(replaced)) - - -def str_zfill(native: ChunkedOrScalarAny, length: int) -> ChunkedOrScalarAny: - if compat.HAS_ZFILL: - zfill: Incomplete = pc.utf8_zero_fill # type: ignore[attr-defined] - result: ChunkedOrScalarAny = zfill(native, length) - else: - result = _str_zfill_compat(native, length) - return result - - -# TODO @dangotbanned: Finish tidying this up -def _str_zfill_compat( - native: ChunkedOrScalarAny, length: int -) -> Incomplete: # pragma: no cover - dtype = string_type([native.type]) - hyphen, plus = lit("-", dtype), lit("+", dtype) - - padded_remaining = str_pad_start(str_slice(native, 1), length - 1, "0") - padded_lt_length = str_pad_start(native, length, "0") - - binary_join: Incomplete = pc.binary_join_element_wise - if isinstance(native, pa.Scalar): - case_1: ArrowAny = hyphen # starts with hyphen and less than length - case_2: ArrowAny = plus # starts with plus and less than length - else: - arr_len = len(native) - case_1 = repeat_unchecked(hyphen, arr_len) - case_2 = repeat_unchecked(plus, arr_len) - - first_char = str_slice(native, 0, 1) - lt_length = lt(str_len_chars(native), lit(length)) - first_hyphen_lt_length = and_(eq(first_char, hyphen), lt_length) - first_plus_lt_length = and_(eq(first_char, plus), lt_length) - return when_then( - first_hyphen_lt_length, - binary_join(case_1, padded_remaining, ""), - when_then( - first_plus_lt_length, - binary_join(case_2, padded_remaining, ""), - when_then(lt_length, padded_lt_length, native), - ), - ) - - -@t.overload -def when_then( - predicate: ChunkedArray[BooleanScalar], then: ScalarAny -) -> ChunkedArrayAny: ... -@t.overload -def when_then(predicate: Array[BooleanScalar], then: ScalarAny) -> ArrayAny: ... -@t.overload -def when_then( - predicate: Predicate, then: SameArrowT, otherwise: SameArrowT | None -) -> SameArrowT: ... -@t.overload -def when_then(predicate: Predicate, then: ScalarAny, otherwise: ArrowT) -> ArrowT: ... -@t.overload -def when_then( - predicate: Predicate, then: ArrowT, otherwise: ScalarAny | NonNestedLiteral = ... -) -> ArrowT: ... -@t.overload -def when_then( - predicate: Predicate, then: ArrowAny, otherwise: ArrowAny | NonNestedLiteral = None -) -> Incomplete: ... -def when_then( - predicate: Predicate, then: ArrowAny, otherwise: ArrowAny | NonNestedLiteral = None -) -> Incomplete: - """Thin wrapper around `pyarrow.compute.if_else`. - - - Supports a 2-arg form, like `pl.when(...).then(...)` - - Accepts python literals, but only in the `otherwise` position - """ - if is_non_nested_literal(otherwise): - otherwise = lit(otherwise, then.type) - return pc.if_else(predicate, then, otherwise) - - -def any_(native: Incomplete, *, ignore_nulls: bool = True) -> pa.BooleanScalar: - return pc.any(native, min_count=0, skip_nulls=ignore_nulls) - - -def all_(native: Incomplete, *, ignore_nulls: bool = True) -> pa.BooleanScalar: - return pc.all(native, min_count=0, skip_nulls=ignore_nulls) - - -def sum_(native: Incomplete) -> NativeScalar: - return pc.sum(native, min_count=0) - - -def first(native: ChunkedOrArrayAny) -> NativeScalar: - return pc.first(native, options=pa_options.scalar_aggregate()) - - -def last(native: ChunkedOrArrayAny) -> NativeScalar: - return pc.last(native, options=pa_options.scalar_aggregate()) - - -min_ = pc.min -# TODO @dangotbanned: Wrap horizontal functions with correct typing -# Should only return scalar if all elements are as well -min_horizontal = pc.min_element_wise -max_ = pc.max -max_horizontal = pc.max_element_wise -mean = t.cast("Callable[[ChunkedOrArray[pc.NumericScalar]], pa.DoubleScalar]", pc.mean) -count = pc.count -median = pc.approximate_median -std = pc.stddev -var = pc.variance -quantile = pc.quantile - - -def mode_all(native: ChunkedArrayAny) -> ChunkedArrayAny: - struct = pc.mode(native, n=len(native)) - indices: pa.Int32Array = struct.field("count").dictionary_encode().indices # type: ignore[attr-defined] - index_true_modes = lit(0) - return chunked_array(struct.field("mode").filter(pc.equal(indices, index_true_modes))) - - -def mode_any(native: ChunkedArrayAny) -> NativeScalar: - return first(pc.mode(native, n=1).field("mode")) - - -def kurtosis_skew( - native: ChunkedArray[pc.NumericScalar], function: Literal["kurtosis", "skew"], / -) -> NativeScalar: - result: NativeScalar - if compat.HAS_KURTOSIS_SKEW: - if pa.types.is_null(native.type): - native = native.cast(F64) - result = getattr(pc, function)(native) - else: - non_null = native.drop_null() - if len(non_null) == 0: - result = lit(None, F64) - elif len(non_null) == 1: - result = lit(float("nan")) - elif function == "skew" and len(non_null) == 2: - result = lit(0.0, F64) - else: - m = sub(non_null, mean(non_null)) - m2 = mean(power(m, lit(2))) - if function == "kurtosis": - m4 = mean(power(m, lit(4))) - result = sub(pc.divide(m4, power(m2, lit(2))), lit(3)) - else: - m3 = mean(power(m, lit(3))) - result = pc.divide(m3, power(m2, lit(1.5))) - return result - - -def clip_lower( - native: ChunkedOrScalarAny, lower: ChunkedOrScalarAny -) -> ChunkedOrScalarAny: - return max_horizontal(native, lower) - - -def clip_upper( - native: ChunkedOrScalarAny, upper: ChunkedOrScalarAny -) -> ChunkedOrScalarAny: - return min_horizontal(native, upper) - - -def clip( - native: ChunkedOrScalarAny, lower: ChunkedOrScalarAny, upper: ChunkedOrScalarAny -) -> ChunkedOrScalarAny: - return clip_lower(clip_upper(native, upper), lower) - - -def n_unique(native: Any) -> pa.Int64Scalar: - return count(native, mode="all") - - -@t.overload -def round(native: ChunkedOrScalarAny, decimals: int = ...) -> ChunkedOrScalarAny: ... -@t.overload -def round(native: ChunkedOrArrayT, decimals: int = ...) -> ChunkedOrArrayT: ... -def round(native: ArrowAny, decimals: int = 0) -> ArrowAny: - return pc.round(native, decimals, round_mode="half_towards_infinity") - - -def log(native: ChunkedOrScalarAny, base: float = math.e) -> ChunkedOrScalarAny: - return t.cast("ChunkedOrScalarAny", pc.logb(native, lit(base))) - - -def reverse(native: ChunkedOrArrayT) -> ChunkedOrArrayT: - """Unlike other slicing ops, `[::-1]` creates a full-copy. - - https://github.com/apache/arrow/issues/19103#issuecomment-1377671886 - """ - return native[::-1] - - -def cum_sum(native: ChunkedOrArrayT) -> ChunkedOrArrayT: - return pc.cumulative_sum(native, skip_nulls=True) - - -def cum_min(native: ChunkedOrArrayT) -> ChunkedOrArrayT: - return pc.cumulative_min(native, skip_nulls=True) - - -def cum_max(native: ChunkedOrArrayT) -> ChunkedOrArrayT: - return pc.cumulative_max(native, skip_nulls=True) - - -def cum_prod(native: ChunkedOrArrayT) -> ChunkedOrArrayT: - return pc.cumulative_prod(native, skip_nulls=True) - - -def cum_count(native: ChunkedArrayAny) -> ChunkedArrayAny: - return cum_sum(is_not_null(native).cast(pa.uint32())) - - -_CUMULATIVE: Mapping[type[F.CumAgg], Callable[[ChunkedArrayAny], ChunkedArrayAny]] = { - F.CumSum: cum_sum, - F.CumCount: cum_count, - F.CumMin: cum_min, - F.CumMax: cum_max, - F.CumProd: cum_prod, -} - - -def cumulative(native: ChunkedArrayAny, f: F.CumAgg, /) -> ChunkedArrayAny: - func = _CUMULATIVE[type(f)] - return func(native) if not f.reverse else reverse(func(reverse(native))) - - -def diff(native: ChunkedOrArrayT, n: int = 1) -> ChunkedOrArrayT: - # pyarrow.lib.ArrowInvalid: Vector kernel cannot execute chunkwise and no chunked exec function was defined - return ( - pc.pairwise_diff(native, n) - if isinstance(native, pa.Array) - else chunked_array(pc.pairwise_diff(native.combine_chunks(), n)) - ) - - -def shift( - native: ChunkedArrayAny, n: int, *, fill_value: NonNestedLiteral = None -) -> ChunkedArrayAny: - if n == 0: - return native - arr = native - if n > 0: - filled = repeat_like(fill_value, n, arr) - arrays = [filled, *arr.slice(length=arr.length() - n).chunks] - else: - filled = repeat_like(fill_value, -n, arr) - arrays = [*arr.slice(offset=-n).chunks, filled] - return pa.chunked_array(arrays) - - -def rank(native: ChunkedArrayAny, rank_options: RankOptions) -> ChunkedArrayAny: - arr = native if compat.RANK_ACCEPTS_CHUNKED else array(native) - if rank_options.method == "average": - # Adapted from https://github.com/pandas-dev/pandas/blob/f4851e500a43125d505db64e548af0355227714b/pandas/core/arrays/arrow/array.py#L2290-L2316 - order = pa_options.ORDER[rank_options.descending] - min = preserve_nulls(arr, pc.rank(arr, order, tiebreaker="min").cast(F64)) - max = pc.rank(arr, order, tiebreaker="max").cast(F64) - ranked = pc.divide(pc.add(min, max), lit(2, F64)) - else: - ranked = preserve_nulls(native, pc.rank(arr, options=rank_options.to_arrow())) - return chunked_array(ranked) - - -def null_count(native: ChunkedOrArrayAny) -> pa.Int64Scalar: - return pc.count(native, mode="only_null") - - -def preserve_nulls( - before: ChunkedOrArrayAny, after: ChunkedOrArrayT, / -) -> ChunkedOrArrayT: - return when_then(is_not_null(before), after) if before.null_count else after - - -drop_nulls = t.cast("VectorFunction[...]", pc.drop_null) - - -def is_only_nulls(native: ChunkedOrArrayAny, *, nan_is_null: bool = False) -> bool: - """Return True if `native` has no non-null values (and optionally include NaN).""" - return array(native.is_null(nan_is_null=nan_is_null), BOOL).false_count == 0 - - -_FILL_NULL_STRATEGY: Mapping[FillNullStrategy, UnaryFunction] = { - "forward": pc.fill_null_forward, - "backward": pc.fill_null_backward, -} - - -def _fill_null_forward_limit(native: ChunkedArrayAny, limit: int) -> ChunkedArrayAny: - SENTINEL = lit(-1) # noqa: N806 - is_not_null = native.is_valid() - index = int_range(len(native), chunked=False) - index_not_null = cum_max(when_then(is_not_null, index, SENTINEL)) - # NOTE: The correction here is for nulls at either end of the array - # They should be preserved when the `strategy` would need an out-of-bounds index - not_oob = not_eq(index_not_null, SENTINEL) - index_not_null = when_then(not_oob, index_not_null) - beyond_limit = gt(sub(index, index_not_null), lit(limit)) - return when_then(or_(is_not_null, beyond_limit), native, native.take(index_not_null)) - - -@t.overload -def fill_null( - native: ChunkedOrScalarT, value: NonNestedLiteral | ArrowAny -) -> ChunkedOrScalarT: ... -@t.overload -def fill_null( - native: ChunkedOrArrayT, value: ScalarAny | NonNestedLiteral | ChunkedOrArrayT -) -> ChunkedOrArrayT: ... -@t.overload -def fill_null( - native: ChunkedOrScalarAny, value: ChunkedOrScalarAny | NonNestedLiteral -) -> ChunkedOrScalarAny: ... -def fill_null(native: ArrowAny, value: ArrowAny | NonNestedLiteral) -> ArrowAny: - fill_value: Incomplete = value - result: ArrowAny = pc.fill_null(native, fill_value) - return result - - -@t.overload -def fill_nan( - native: ChunkedOrScalarT, value: NonNestedLiteral | ArrowAny -) -> ChunkedOrScalarT: ... -@t.overload -def fill_nan(native: SameArrowT, value: NonNestedLiteral | ArrowAny) -> SameArrowT: ... -def fill_nan(native: ArrowAny, value: NonNestedLiteral | ArrowAny) -> Incomplete: - return when_then(is_not_nan(native), native, value) - - -def fill_null_forward(native: ChunkedArrayAny) -> ChunkedArrayAny: - return fill_null_with_strategy(native, "forward") - - -def fill_null_with_strategy( - native: ChunkedArrayAny, strategy: FillNullStrategy, limit: int | None = None -) -> ChunkedArrayAny: - null_count = native.null_count - if null_count == 0 or (null_count == len(native)): - return native - if limit is None: - return _FILL_NULL_STRATEGY[strategy](native) - if strategy == "forward": - return _fill_null_forward_limit(native, limit) - return reverse(_fill_null_forward_limit(reverse(native), limit)) - - -def _ensure_all_replaced( - native: ChunkedOrScalarAny, unmatched: ArrowAny -) -> ValueError | None: - if not any_(unmatched).as_py(): - return None - msg = ( - "replace_strict did not replace all non-null values.\n\n" - f"The following did not get replaced: {chunked_array(native).filter(array(unmatched)).unique().to_pylist()}" - ) - return ValueError(msg) - - -def replace_strict( - native: ChunkedOrScalarAny, - old: Seq[Any], - new: Seq[Any], - dtype: pa.DataType | None = None, -) -> ChunkedOrScalarAny: - if isinstance(native, pa.Scalar): - idxs: ArrayAny = array(pc.index_in(native, pa.array(old))) - result: ChunkedOrScalarAny = pa.array(new).take(idxs)[0] - else: - idxs = pc.index_in(native, pa.array(old)) - result = chunked_array(pa.array(new).take(idxs)) - if err := _ensure_all_replaced(native, and_(is_not_null(native), is_null(idxs))): - raise err - return result.cast(dtype) if dtype else result - - -def replace_strict_default( - native: ChunkedOrScalarAny, - old: Seq[Any], - new: Seq[Any], - default: ChunkedOrScalarAny, - dtype: pa.DataType | None = None, -) -> ChunkedOrScalarAny: - idxs = pc.index_in(native, pa.array(old)) - result = pa.array(new).take(array(idxs)) - result = when_then(is_null(idxs), default, result.cast(dtype) if dtype else result) - return chunked_array(result) if isinstance(native, pa.ChunkedArray) else result[0] - - -@overload -def replace_with_mask( - native: ChunkedOrArrayT, mask: Predicate, replacements: ChunkedOrArrayAny -) -> ChunkedOrArrayT: ... -@overload -def replace_with_mask( - native: ChunkedOrArrayAny, mask: Predicate, replacements: ChunkedOrArrayAny -) -> ChunkedOrArrayAny: ... -def replace_with_mask( - native: ChunkedOrArrayAny, mask: Predicate, replacements: ChunkedOrArrayAny -) -> ChunkedOrArrayAny: - """Replace elements of `native`, at positions defined by `mask`. - - The length of `replacements` must equal the number of `True` values in `mask`. - """ - if isinstance(native, pa.ChunkedArray): - args = [array(p) for p in (native, mask, replacements)] - return chunked_array(pc.call_function("replace_with_mask", args)) - args = [native, array(mask), array(replacements)] - result: ChunkedOrArrayAny = pc.call_function("replace_with_mask", args) - return result - - -@t.overload -def is_between( - native: ChunkedArray[ScalarT], - lower: ChunkedOrScalar[ScalarT] | NumericLiteral, - upper: ChunkedOrScalar[ScalarT] | NumericLiteral, - closed: ClosedInterval, -) -> ChunkedArray[pa.BooleanScalar]: ... -@t.overload -def is_between( - native: ChunkedOrScalar[ScalarT], - lower: ChunkedOrScalar[ScalarT] | NumericLiteral, - upper: ChunkedOrScalar[ScalarT] | NumericLiteral, - closed: ClosedInterval, -) -> ChunkedOrScalar[pa.BooleanScalar]: ... -def is_between( - native: ChunkedOrScalar[ScalarT], - lower: ChunkedOrScalar[ScalarT] | NumericLiteral, - upper: ChunkedOrScalar[ScalarT] | NumericLiteral, - closed: ClosedInterval, -) -> ChunkedOrScalar[pa.BooleanScalar]: - fn_lhs, fn_rhs = _IS_BETWEEN[closed] - low, high = (el if _is_arrow(el) else lit(el) for el in (lower, upper)) - out: ChunkedOrScalar[pa.BooleanScalar] = and_( - fn_lhs(native, low), fn_rhs(native, high) - ) - return out - - -@t.overload -def is_in( - values: ChunkedArrayAny, /, other: ChunkedOrArrayAny -) -> ChunkedArray[pa.BooleanScalar]: ... -@t.overload -def is_in(values: ArrayAny, /, other: ChunkedOrArrayAny) -> Array[pa.BooleanScalar]: ... -@t.overload -def is_in(values: ScalarAny, /, other: ChunkedOrArrayAny) -> pa.BooleanScalar: ... -@t.overload -def is_in( - values: ChunkedOrScalarAny, /, other: ChunkedOrArrayAny -) -> ChunkedOrScalarAny: ... -def is_in(values: ArrowAny, /, other: ChunkedOrArrayAny) -> ArrowAny: - """Check if elements of `values` are present in `other`. - - Roughly equivalent to [`polars.Expr.is_in`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.is_in.html) - - Returns a mask with `len(values)` elements. - """ - # NOTE: Stubs don't include a `ChunkedArray` return - # NOTE: Replaced ambiguous parameter name (`value_set`) - is_in_: Incomplete = pc.is_in - return is_in_(values, other) # type: ignore[no-any-return] - - -@t.overload -def eq_missing( - native: ChunkedArrayAny, other: NonNestedLiteral | ArrowAny -) -> ChunkedArray[pa.BooleanScalar]: ... -@t.overload -def eq_missing( - native: ArrayAny, other: NonNestedLiteral | ArrowAny -) -> Array[pa.BooleanScalar]: ... -@t.overload -def eq_missing( - native: ScalarAny, other: NonNestedLiteral | ArrowAny -) -> pa.BooleanScalar: ... -@t.overload -def eq_missing( - native: ChunkedOrScalarAny, other: NonNestedLiteral | ArrowAny -) -> ChunkedOrScalarAny: ... -def eq_missing(native: ArrowAny, other: NonNestedLiteral | ArrowAny) -> ArrowAny: - """Equivalent to `native == other` where `None == None`. - - This differs from default `eq` where null values are propagated. - - Note: - Unique to `pyarrow`, this wrapper will ensure `None` uses `native.type`. - """ - if isinstance(other, (pa.Array, pa.ChunkedArray)): - return is_in(native, other) - item = array(other if isinstance(other, pa.Scalar) else lit(other, native.type)) - return is_in(native, item) - - -def ir_min_max(name: str, /) -> MinMax: - return MinMax(expr=ir.col(name)) - - -def _boolean_is_unique( - indices: ChunkedArrayAny, aggregated: ChunkedStruct, / -) -> ChunkedArrayAny: - min, max = struct_fields(aggregated, "min", "max") - return and_(is_in(indices, min), is_in(indices, max)) - - -def _boolean_is_duplicated( - indices: ChunkedArrayAny, aggregated: ChunkedStruct, / -) -> ChunkedArrayAny: - return not_(_boolean_is_unique(indices, aggregated)) - - -BOOLEAN_LENGTH_PRESERVING: Mapping[ - type[ir.boolean.BooleanFunction], tuple[IntoColumnAgg, BooleanLengthPreserving] -] = { - ir.boolean.IsFirstDistinct: (ir.min, is_in), - ir.boolean.IsLastDistinct: (ir.max, is_in), - ir.boolean.IsUnique: (ir_min_max, _boolean_is_unique), - ir.boolean.IsDuplicated: (ir_min_max, _boolean_is_duplicated), -} -_UNIQUE_KEEP_BOOLEAN_LENGTH_PRESERVING: Mapping[ - UniqueKeepStrategy, type[ir.boolean.BooleanFunction] -] = { - "any": ir.boolean.IsFirstDistinct, - "first": ir.boolean.IsFirstDistinct, - "last": ir.boolean.IsLastDistinct, - "none": ir.boolean.IsUnique, -} - - -def unique_keep_boolean_length_preserving( - keep: UniqueKeepStrategy, -) -> tuple[IntoColumnAgg, BooleanLengthPreserving]: - return BOOLEAN_LENGTH_PRESERVING[_UNIQUE_KEEP_BOOLEAN_LENGTH_PRESERVING[keep]] - - -def binary( - lhs: ChunkedOrScalarAny, op: type[ops.Operator], rhs: ChunkedOrScalarAny -) -> ChunkedOrScalarAny: - return _DISPATCH_BINARY[op](lhs, rhs) - - -@t.overload -def concat_str( - *arrays: ChunkedArrayAny, separator: str = ..., ignore_nulls: bool = ... -) -> ChunkedArray[StringScalar]: ... -@t.overload -def concat_str( - *arrays: ArrayAny, separator: str = ..., ignore_nulls: bool = ... -) -> Array[StringScalar]: ... -@t.overload -def concat_str( - *arrays: ScalarAny, separator: str = ..., ignore_nulls: bool = ... -) -> StringScalar: ... -def concat_str( - *arrays: ArrowAny, separator: str = "", ignore_nulls: bool = False -) -> Arrow[StringScalar]: - """Horizontally arrow data into a single string column.""" - dtype = string_type(obj.type for obj in arrays) - it = (obj.cast(dtype) for obj in arrays) - concat: Incomplete = pc.binary_join_element_wise - join = pa_options.join(ignore_nulls=ignore_nulls) - return concat(*it, lit(separator, dtype), options=join) # type: ignore[no-any-return] - - -def random_indices( - end: int, /, n: int, *, with_replacement: bool = False, seed: int | None = None -) -> ArrayAny: - """Generate `n` random indices within the range `[0, end)`.""" - # NOTE: Review this path if anything changes upstream - # https://github.com/apache/arrow/issues/47288#issuecomment-3597653670 - if with_replacement: - rand_values = pc.random(n, initializer="system" if seed is None else seed) - return round(multiply(rand_values, lit(end - 1))).cast(I64) - - import numpy as np # ignore-banned-import - - return array(np.random.default_rng(seed).choice(np.arange(end), n, replace=False)) - - -@overload -def sort_indices( - native: ChunkedOrArrayAny, *, options: SortOptions | None -) -> pa.UInt64Array: ... -@overload -def sort_indices( - native: ChunkedOrArrayAny, *, descending: bool = ..., nulls_last: bool = ... -) -> pa.UInt64Array: ... -@overload -def sort_indices( - native: pa.Table, - *by: Unpack[tuple[str, Unpack[tuple[str, ...]]]], - options: SortOptions | SortMultipleOptions | None, -) -> pa.UInt64Array: ... -@overload -def sort_indices( - native: pa.Table, - *by: Unpack[tuple[str, Unpack[tuple[str, ...]]]], - descending: bool | Sequence[bool] = ..., - nulls_last: bool = ..., -) -> pa.UInt64Array: ... -def sort_indices( - native: ChunkedOrArrayAny | pa.Table, - *by: str, - options: SortOptions | SortMultipleOptions | None = None, - descending: bool | Sequence[bool] = False, - nulls_last: bool = False, -) -> pa.UInt64Array: - """Return the indices that would sort an array or table. - - Arguments: - native: Any non-scalar arrow data. - *by: Column(s) to sort by. Only applicable to `Table` and must use at least one name. - options: An *already-parsed* options instance. - **Has higher precedence** than `descending` and `nulls_last`. - descending: Sort in descending order. When sorting by multiple columns, - can be specified per column by passing a sequence of booleans. - nulls_last: Place null values last. - - Notes: - Most commonly used as input for `take`, which forms a `sort_by` operation. - """ - if not isinstance(native, pa.Table): - if options: - descending = options.descending - nulls_last = options._ensure_single_nulls_last("pyarrow") - a_opts = pa_options.array_sort(descending=descending, nulls_last=nulls_last) - return pc.array_sort_indices(native, options=a_opts) - opts = ( - options.to_arrow(by) - if options - else pa_options.sort(*by, descending=descending, nulls_last=nulls_last) - ) - return pc.sort_indices(native, options=opts) - - -def unsort_indices(indices: pa.UInt64Array, /) -> pa.Int64Array: - """Return the inverse permutation of the given indices. - - Arguments: - indices: The output of `sort_indices`. - - Examples: - We can use this pair of functions to recreate a windowed `pl.row_index` - - >>> import polars as pl - >>> data = {"by": [5, 2, 5, None]} - >>> df = pl.DataFrame(data) - >>> df.select( - ... pl.row_index().over(order_by="by", descending=True, nulls_last=False) - ... ).to_series().to_list() - [1, 3, 2, 0] - - Now in `pyarrow` - - >>> import pyarrow as pa - >>> from narwhals._plan.arrow.functions import sort_indices, unsort_indices - >>> df = pa.Table.from_pydict(data) - >>> unsort_indices( - ... sort_indices(df, "by", descending=True, nulls_last=False) - ... ).to_pylist() - [1, 3, 2, 0] - """ - return ( - pc.inverse_permutation(indices.cast(pa.int64())) # type: ignore[attr-defined] - if compat.HAS_SCATTER - else int_range(len(indices), chunked=False).take(pc.sort_indices(indices)) - ) - - -@overload -def int_range( - start: int = ..., - end: int | None = ..., - step: int = ..., - /, - *, - dtype: IntegerType = ..., - chunked: Literal[True] = ..., -) -> ChunkedArray[IntegerScalar]: ... -@overload -def int_range( - start: int = ..., - end: int | None = ..., - step: int = ..., - /, - *, - chunked: Literal[False], -) -> pa.Int64Array: ... -@overload -def int_range( - start: int = ..., - end: int | None = ..., - step: int = ..., - /, - *, - dtype: IntegerType = ..., - chunked: Literal[False], -) -> Array[IntegerScalar]: ... -def int_range( - start: int = 0, - end: int | None = None, - step: int = 1, - /, - *, - dtype: IntegerType = I64, - chunked: bool = True, -) -> ChunkedOrArray[IntegerScalar]: - if end is None: - end = start - start = 0 - if not compat.HAS_ARANGE: # pragma: no cover - import numpy as np # ignore-banned-import - - arr = pa.array(np.arange(start, end, step), type=dtype) - else: - int_range_: Incomplete = pa.arange # type: ignore[attr-defined] - arr = t.cast("ArrayAny", int_range_(start, end, step)).cast(dtype) - return arr if not chunked else pa.chunked_array([arr]) - - -def date_range( - start: dt.date, - end: dt.date, - interval: int, # (* assuming the `Interval` part is solved) - *, - closed: ClosedInterval = "both", -) -> ChunkedArray[DateScalar]: - start_i = pa.scalar(start).cast(pa.int32()).as_py() - end_i = pa.scalar(end).cast(pa.int32()).as_py() - ca = int_range(start_i, end_i + 1, interval, dtype=pa.int32()) - if closed == "both": - return ca.cast(pa.date32()) - if closed == "left": - ca = ca.slice(length=ca.length() - 1) - elif closed == "none": - ca = ca.slice(1, length=ca.length() - 1) - else: - ca = ca.slice(1) - return ca.cast(pa.date32()) - - -def linear_space( - start: float, end: float, num_samples: int, *, closed: ClosedInterval = "both" -) -> ChunkedArray[pc.NumericScalar]: - """Based on [`new_linear_space_f64`]. - - [`new_linear_space_f64`]: https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/crates/polars-ops/src/series/ops/linear_space.rs#L62-L94 - """ - if num_samples < 0: - msg = f"Number of samples, {num_samples}, must be non-negative." - raise ValueError(msg) - if num_samples == 0: - return chunked_array([[]], F64) - if num_samples == 1: - if closed == "none": - value = (end + start) * 0.5 - elif closed in {"left", "both"}: - value = float(start) - else: - value = float(end) - return chunked_array([[value]], F64) - n = num_samples - span = float(end - start) - if closed == "none": - d = span / (n + 1) - start = start + d - elif closed == "left": - d = span / n - elif closed == "right": - start = start + span / n - d = span / n - else: - d = span / (n - 1) - ca: ChunkedArray[pc.NumericScalar] = multiply(int_range(0, n).cast(F64), lit(d)) - ca = add(ca, lit(start, F64)) - return ca # noqa: RET504 - - -def repeat(value: ScalarAny | NonNestedLiteral, n: int) -> ArrayAny: - value = value if isinstance(value, pa.Scalar) else lit(value) - return repeat_unchecked(value, n) - - -def repeat_unchecked(value: ScalarAny, /, n: int) -> ArrayAny: - repeat_: Incomplete = pa.repeat - result: ArrayAny = repeat_(value, n) - return result - - -def repeat_like(value: NonNestedLiteral, n: int, native: ArrowAny) -> ArrayAny: - return repeat_unchecked(lit(value, native.type), n) - - -def nulls_like(n: int, native: ArrowAny) -> ArrayAny: - """Create a strongly-typed Array instance with all elements null. - - Uses the type of `native`. - """ - result: ArrayAny = pa.nulls(n, native.type) - return result - - -def zeros(n: int, /) -> pa.Int64Array: - return pa.repeat(0, n) - - -SearchSortedSide: TypeAlias = Literal["left", "right"] - - -# NOTE @dangotbanned: (wish) replacing `np.searchsorted`? -@t.overload -def search_sorted( - native: ChunkedOrArrayT, - element: ChunkedOrArray[NumericScalar] | Sequence[float], - *, - side: SearchSortedSide = ..., -) -> ChunkedOrArrayT: ... -# NOTE: scalar case may work with only `partition_nth_indices`? -@t.overload -def search_sorted( - native: ChunkedOrArrayT, element: float, *, side: SearchSortedSide = ... -) -> ScalarAny: ... -def search_sorted( - native: ChunkedOrArrayT, - element: ChunkedOrArray[NumericScalar] | Sequence[float] | float, - *, - side: SearchSortedSide = "left", -) -> ChunkedOrArrayT | ScalarAny: - """Find indices where elements should be inserted to maintain order.""" - import numpy as np # ignore-banned-import - - indices = np.searchsorted(element, native, side=side) - if isinstance(indices, np.generic): - return lit(indices) - if isinstance(native, pa.ChunkedArray): - return chunked_array([indices]) - return array(indices) - - -def hist_bins( - native: ChunkedArrayAny, - bins: Sequence[float] | ChunkedArray[NumericScalar], - *, - include_breakpoint: bool, -) -> Mapping[str, Iterable[Any]]: - """Bin values into buckets and count their occurrences. - - Notes: - Assumes that the following edge cases have been handled: - - `len(bins) >= 2` - - `bins` increase monotonically - - `bin[0] != bin[-1]` - - `native` contains values that are non-null (including NaN) - """ - if len(bins) == 2: - upper = bins[1] - count = array(is_between(native, bins[0], upper, closed="both"), BOOL).true_count - if include_breakpoint: - return {"breakpoint": [upper], "count": [count]} - return {"count": [count]} - - # lowest bin is inclusive - # NOTE: `np.unique` behavior sorts first - value_counts = ( - when_then(not_eq(native, lit(bins[0])), search_sorted(native, bins), 1) - .sort() - .value_counts() - ) - values, counts = struct_fields(value_counts, "values", "counts") - bin_count = len(bins) - int_range_ = int_range(1, bin_count, chunked=False) - mask = is_in(int_range_, values) - replacements = counts.filter(is_in(values, int_range_)) - counts = replace_with_mask(zeros(bin_count - 1), mask, replacements) - - if include_breakpoint: - return {"breakpoint": bins[1:], "count": counts} - return {"count": counts} - - -def hist_zeroed_data( - arg: int | Sequence[float], *, include_breakpoint: bool -) -> Mapping[str, Iterable[Any]]: - # NOTE: If adding `linear_space` and `zeros` to `CompliantNamespace`, consider moving this. - n = arg if isinstance(arg, int) else len(arg) - 1 - if not include_breakpoint: - return {"count": zeros(n)} - bp = linear_space(0, 1, arg, closed="right") if isinstance(arg, int) else arg[1:] - return {"breakpoint": bp, "count": zeros(n)} - - -@overload -def lit(value: Any) -> NativeScalar: ... -@overload -def lit(value: Any, dtype: BoolType) -> pa.BooleanScalar: ... -@overload -def lit(value: Any, dtype: UInt32Type) -> pa.UInt32Scalar: ... -@overload -def lit(value: Any, dtype: DataType | None = ...) -> NativeScalar: ... -def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: - return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) - - -# TODO @dangotbanned: Report `ListScalar.values` bug upstream -# See `tests/plan/list_unique_test.py::test_list_unique_scalar[None-None]` -@overload -def array(data: ArrowAny, /) -> ArrayAny: ... -@overload -def array(data: Arrow[BooleanScalar], dtype: BoolType, /) -> pa.BooleanArray: ... -@overload -def array( - data: Iterable[PythonLiteral], dtype: DataType | None = None, / -) -> ArrayAny: ... -def array( - data: ArrowAny | Iterable[PythonLiteral], dtype: DataType | None = None, / -) -> ArrayAny: - """Convert `data` into an Array instance. - - Note: - `dtype` is **not used** for existing `pyarrow` data, but it can be used to signal - the concrete `Array` subclass that is returned. - To actually changed the type, use `cast` instead. - """ - if isinstance(data, pa.ChunkedArray): - return data.combine_chunks() - if isinstance(data, pa.Array): - return data - if isinstance(data, pa.Scalar): - if isinstance(data, pa.ListScalar) and data.is_valid is False: - return pa.array([None], data.type) - return pa.array([data], data.type) - return pa.array(data, dtype) - - -def chunked_array( - data: ArrowAny | list[Iterable[Any]], dtype: DataType | None = None, / -) -> ChunkedArrayAny: - return _chunked_array(array(data) if isinstance(data, pa.Scalar) else data, dtype) - - -def concat_horizontal( - arrays: Collection[ChunkedOrArrayAny], names: Collection[str] -) -> pa.Table: - """Concatenate `arrays` as columns in a new table.""" - table: Incomplete = pa.Table.from_arrays - result: pa.Table = table(arrays, names) - return result - - -def concat_vertical( - arrays: Iterable[ChunkedOrArrayAny], dtype: DataType | None = None, / -) -> ChunkedArrayAny: - """Concatenate `arrays` into a new array.""" - v_concat: Incomplete = pa.chunked_array - result: ChunkedArrayAny = v_concat(arrays, dtype) - return result - - -def to_table(array: ChunkedOrArrayAny, name: str = "") -> pa.Table: - """Equivalent to `Series.to_frame`, but with an option to insert a name for the column.""" - return concat_horizontal((array,), (name,)) - - -def _is_into_pyarrow_schema(obj: Mapping[Any, Any]) -> TypeIs[Mapping[str, DataType]]: - return ( - (first := next(iter(obj.items())), None) - and isinstance(first[0], str) - and isinstance(first[1], pa.DataType) - ) - - -def _is_arrow(obj: Arrow[ScalarT] | Any) -> TypeIs[Arrow[ScalarT]]: - return isinstance(obj, (pa.Scalar, pa.Array, pa.ChunkedArray)) - - -def filter_arrays( - predicate: ChunkedOrArray[BooleanScalar] | pc.Expression, - *arrays: Unpack[Ts], - ignore_nulls: bool = True, -) -> tuple[Unpack[Ts]]: - """Apply the same filter to multiple arrays, returning them independently. - - Note: - The typing here is a minefield. You'll get an `*arrays`-length `tuple[ChunkedArray, ...]`. - """ - table: Incomplete = pa.Table.from_arrays - tmp = [str(i) for i in range(len(arrays))] - result = table(arrays, tmp).filter(predicate, "drop" if ignore_nulls else "emit_null") - return t.cast("tuple[Unpack[Ts]]", tuple(result.columns)) diff --git a/narwhals/_plan/arrow/functions/__init__.py b/narwhals/_plan/arrow/functions/__init__.py new file mode 100644 index 0000000000..72e34f48ca --- /dev/null +++ b/narwhals/_plan/arrow/functions/__init__.py @@ -0,0 +1,254 @@ +"""Native functions, aliased and/or with behavior aligned to `polars`.""" + +from __future__ import annotations + +from narwhals._plan.arrow.functions import ( + _categorical as cat, + _lists as list, + _strings as str, + _struct as struct, + meta, +) +from narwhals._plan.arrow.functions._aggregation import ( + count, + first, + kurtosis_skew, + last, + max, + mean, + median, + min, + mode_all, + mode_any, + n_unique, + null_count, + quantile, + std, + sum, + var, +) +from narwhals._plan.arrow.functions._arithmetic import ( + abs, + add, + exp, + floordiv, + log, + modulus, + multiply, + power, + sqrt, + sub, + truediv, +) +from narwhals._plan.arrow.functions._bin_op import ( + and_, + binary, + eq, + gt, + gt_eq, + lt, + lt_eq, + not_eq, + or_, + xor, +) +from narwhals._plan.arrow.functions._boolean import ( + all, + any, + eq_missing, + is_between, + is_finite, + is_in, + is_nan, + is_not_nan, + is_not_null, + is_null, + is_only_nulls, + not_, +) +from narwhals._plan.arrow.functions._construction import ( + array, + chunked_array, + concat_horizontal, + concat_tables, + concat_vertical, + lit, + to_table, +) +from narwhals._plan.arrow.functions._cumulative import ( + cum_count, + cum_max, + cum_min, + cum_prod, + cum_sum, + cumulative, +) +from narwhals._plan.arrow.functions._dtypes import ( + BOOL, + DATE, + F64, + I32, + I64, + U32, + cast, + cast_table, + dtype_native, + string_type, +) +from narwhals._plan.arrow.functions._horizontal import max_horizontal, min_horizontal +from narwhals._plan.arrow.functions._lists import ExplodeBuilder +from narwhals._plan.arrow.functions._multiplex import ( + fill_nan, + fill_null, + fill_null_with_strategy, + preserve_nulls, + replace_strict, + replace_strict_default, + replace_with_mask, + when_then, +) +from narwhals._plan.arrow.functions._ranges import date_range, int_range, linear_space +from narwhals._plan.arrow.functions._repeat import ( + nulls_like, + repeat, + repeat_like, + repeat_unchecked, + zeros, +) +from narwhals._plan.arrow.functions._round import ( + ceil, + clip, + clip_lower, + clip_upper, + floor, + round, +) +from narwhals._plan.arrow.functions._sort import ( + random_indices, + reverse, + sort_indices, + unsort_indices, +) +from narwhals._plan.arrow.functions._vector import ( + diff, + hist_bins, + hist_zeroed_data, + rank, + search_sorted, + shift, +) + +__all__ = [ + "BOOL", + "DATE", + "F64", + "I32", + "I64", + "U32", + "ExplodeBuilder", + "abs", + "add", + "all", + "and_", + "any", + "array", + "binary", + "cast", + "cast_table", + "cat", + "ceil", + "chunked_array", + "clip", + "clip_lower", + "clip_upper", + "concat_horizontal", + "concat_tables", + "concat_vertical", + "count", + "cum_count", + "cum_max", + "cum_min", + "cum_prod", + "cum_sum", + "cumulative", + "date_range", + "diff", + "dtype_native", + "eq", + "eq_missing", + "exp", + "fill_nan", + "fill_null", + "fill_null_with_strategy", + "first", + "floor", + "floordiv", + "gt", + "gt_eq", + "hist_bins", + "hist_zeroed_data", + "int_range", + "is_between", + "is_finite", + "is_in", + "is_nan", + "is_not_nan", + "is_not_null", + "is_null", + "is_only_nulls", + "kurtosis_skew", + "last", + "linear_space", + "list", + "lit", + "log", + "lt", + "lt_eq", + "max", + "max_horizontal", + "mean", + "median", + "meta", + "min", + "min_horizontal", + "mode_all", + "mode_any", + "modulus", + "multiply", + "n_unique", + "not_", + "not_eq", + "null_count", + "nulls_like", + "or_", + "power", + "preserve_nulls", + "quantile", + "random_indices", + "rank", + "repeat", + "repeat_like", + "repeat_unchecked", + "replace_strict", + "replace_strict_default", + "replace_with_mask", + "reverse", + "round", + "search_sorted", + "shift", + "sort_indices", + "sqrt", + "std", + "str", + "string_type", + "struct", + "sub", + "sum", + "to_table", + "truediv", + "unsort_indices", + "var", + "when_then", + "xor", + "zeros", +] diff --git a/narwhals/_plan/arrow/functions/_aggregation.py b/narwhals/_plan/arrow/functions/_aggregation.py new file mode 100644 index 0000000000..28bacfc2bc --- /dev/null +++ b/narwhals/_plan/arrow/functions/_aggregation.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING, Literal + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan.arrow import compat, options as pa_options +from narwhals._plan.arrow.functions import _categorical as cat +from narwhals._plan.arrow.functions._arithmetic import power, sub +from narwhals._plan.arrow.functions._construction import array, chunked_array, lit +from narwhals._plan.arrow.functions._dtypes import F64 +from narwhals._plan.arrow.functions.meta import call + +if TYPE_CHECKING: + from collections.abc import Callable + + from narwhals._plan.arrow.typing import ( + Arrow, + ChunkedArray, + ChunkedArrayAny, + ChunkedOrArray, + ChunkedOrArrayAny, + DataTypeT, + Scalar, + ScalarAny, + ) + +__all__ = [ + "count", + "first", + "implode", + "kurtosis_skew", + "last", + "max", + "mean", + "median", + "min", + "mode_all", + "mode_any", + "n_unique", + "null_count", + "quantile", + "std", + "sum", + "var", +] + + +min = pc.min +"""Get the minimal value in this array.""" +max = pc.max +"""Get the maximum value in this array.""" +mean = t.cast("Callable[[ChunkedOrArray[pc.NumericScalar]], pa.DoubleScalar]", pc.mean) +"""Reduce this array to the mean value.""" +count = pc.count +"""Return the number of non-null elements in this array.""" +median = pc.approximate_median +"""Get the median of this array.""" +std = pc.stddev +"""Get the standard deviation of this array.""" +var = pc.variance +"""Get the variance of this array.""" +quantile = pc.quantile +"""Get the quantile value of this array.""" + + +def sum(native: ChunkedOrArrayAny) -> ScalarAny: + """Reduce this array to the sum value.""" + opts = pa_options.scalar_aggregate(ignore_nulls=True) + result: ScalarAny = call("sum", native, options=opts) + return result + + +def first(native: ChunkedOrArrayAny) -> ScalarAny: + """Get the first element of this array.""" + return pc.first(native, options=pa_options.scalar_aggregate()) + + +def last(native: ChunkedOrArrayAny) -> ScalarAny: + """Get the last element of this array.""" + return pc.last(native, options=pa_options.scalar_aggregate()) + + +def implode(native: Arrow[Scalar[DataTypeT]]) -> pa.ListScalar[DataTypeT]: + """Aggregate values into a list. + + Arguments: + native: Any arrow data. + + The returned list *itself* is a scalar value of `list` dtype. + """ + arr = array(native) + return pa.ListArray.from_arrays([0, len(arr)], arr)[0] + + +def kurtosis_skew( + native: ChunkedArray[pc.NumericScalar], function: Literal["kurtosis", "skew"], / +) -> ScalarAny: + """Compute the kurtosis or sample skewness of this array.""" + result: ScalarAny + if compat.HAS_KURTOSIS_SKEW: + if pa.types.is_null(native.type): + native = native.cast(F64) + result = call(function, native) + else: + non_null = native.drop_null() + if len(non_null) == 0: + result = lit(None, F64) + elif len(non_null) == 1: + result = lit(float("nan")) + elif function == "skew" and len(non_null) == 2: + result = lit(0.0, F64) + else: + m = sub(non_null, mean(non_null)) + m2 = mean(power(m, lit(2))) + if function == "kurtosis": + m4 = mean(power(m, lit(4))) + result = sub(pc.divide(m4, power(m2, lit(2))), lit(3)) + else: + m3 = mean(power(m, lit(3))) + result = pc.divide(m3, power(m2, lit(1.5))) + return result + + +def n_unique(native: ChunkedOrArrayAny) -> pa.Int64Scalar: + """Return the number of unique values in this array.""" + return pc.count_distinct(native, mode="all") + + +def null_count(native: ChunkedOrArrayAny) -> pa.Int64Scalar: + """Count the null values in this array.""" + return pc.count(native, mode="only_null") + + +def mode_any(native: ChunkedOrArrayAny) -> ScalarAny: + """Compute the most occurring value(s) and return *any* one of them.""" + return first(pc.mode(native, n=1).field("mode")) + + +def mode_all(native: ChunkedOrArrayAny) -> ChunkedArrayAny: + """Compute the most occurring value(s) and return *all* of them.""" + struct_arr = pc.mode(native, n=len(native)) + indices = cat.encode(struct_arr.field("count")) + index_true_modes = lit(0) + return chunked_array( + struct_arr.field("mode").filter(pc.equal(indices, index_true_modes)) + ) diff --git a/narwhals/_plan/arrow/functions/_arithmetic.py b/narwhals/_plan/arrow/functions/_arithmetic.py new file mode 100644 index 0000000000..1c0ae7988f --- /dev/null +++ b/narwhals/_plan/arrow/functions/_arithmetic.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import math +import typing as t +from typing import TYPE_CHECKING + +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._arrow.utils import floordiv_compat as _floordiv +from narwhals._plan.arrow.functions._dtypes import F64, is_integer +from narwhals._plan.arrow.functions.meta import call + +if TYPE_CHECKING: + from narwhals._plan.arrow.typing import ( + BinaryFunction, + BinaryNumericTemporal, + ChunkedOrScalarAny, + NumericScalar, + UnaryNumeric, + ) + +__all__ = [ + "abs", + "add", + "exp", + "floordiv", + "log", + "modulus", + "multiply", + "power", + "sqrt", + "sub", + "truediv", +] + +add = t.cast("BinaryNumericTemporal", pc.add) +"""Equivalent to `lhs + rhs`.""" +sub = t.cast("BinaryNumericTemporal", pc.subtract) +"""Equivalent to `lhs - rhs`.""" +multiply = t.cast("BinaryNumericTemporal", pc.multiply) +"""Equivalent to `lhs * rhs`.""" +floordiv = t.cast("BinaryNumericTemporal", _floordiv) +"""Equivalent to `lhs // rhs`.""" +power = t.cast("BinaryFunction[NumericScalar, NumericScalar]", pc.power) +"""Equivalent to `lhs ** rhs`.""" +sqrt = t.cast("UnaryNumeric", pc.sqrt) +"""Compute the square root of the elements.""" +abs = t.cast("UnaryNumeric", pc.abs) +"""Compute absolute values.""" +exp = t.cast("UnaryNumeric", pc.exp) +"""Compute the exponential, element-wise.""" + + +def truediv(lhs: ChunkedOrScalarAny, rhs: ChunkedOrScalarAny, /) -> ChunkedOrScalarAny: + """Equivalent to `lhs / rhs`.""" + if is_integer(lhs.type) and is_integer(rhs.type): + lhs, rhs = lhs.cast(F64, safe=False), rhs.cast(F64, safe=False) + result: ChunkedOrScalarAny = call("divide", lhs, rhs) + return result + + +def modulus(lhs: ChunkedOrScalarAny, rhs: ChunkedOrScalarAny, /) -> ChunkedOrScalarAny: + """Equivalent to `lhs % rhs`.""" + result: ChunkedOrScalarAny = sub(lhs, multiply(floordiv(lhs, rhs), rhs)) + return result + + +def log(native: ChunkedOrScalarAny, base: float = math.e) -> ChunkedOrScalarAny: + """Compute the logarithm to a given base.""" + result: ChunkedOrScalarAny = call("logb", native, base) + return result diff --git a/narwhals/_plan/arrow/functions/_bin_op.py b/narwhals/_plan/arrow/functions/_bin_op.py new file mode 100644 index 0000000000..86a014e25c --- /dev/null +++ b/narwhals/_plan/arrow/functions/_bin_op.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING + +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan.arrow.functions import _arithmetic as arith +from narwhals._plan.expressions import operators as ops + +if TYPE_CHECKING: + from collections.abc import Mapping + + from narwhals._plan.arrow.typing import ( + BinaryComp, + BinaryLogical, + BinOp, + ChunkedOrScalarAny, + ) + +__all__ = ["and_", "binary", "eq", "gt", "gt_eq", "lt", "lt_eq", "not_eq", "or_", "xor"] + +and_ = t.cast("BinaryLogical", pc.and_kleene) +"""Equivalent to `lhs & rhs`.""" +or_ = t.cast("BinaryLogical", pc.or_kleene) +"""Equivalent to `lhs | rhs`.""" +xor = t.cast("BinaryLogical", pc.xor) +"""Equivalent to `lhs ^ rhs`.""" + +eq = t.cast("BinaryComp", pc.equal) +"""Equivalent to `lhs == rhs`.""" +not_eq = t.cast("BinaryComp", pc.not_equal) +"""Equivalent to `lhs != rhs`.""" +gt_eq = t.cast("BinaryComp", pc.greater_equal) +"""Equivalent to `lhs >= rhs`.""" +gt = t.cast("BinaryComp", pc.greater) +"""Equivalent to `lhs > rhs`.""" +lt_eq = t.cast("BinaryComp", pc.less_equal) +"""Equivalent to `lhs <= rhs`.""" +lt = t.cast("BinaryComp", pc.less) +"""Equivalent to `lhs < rhs`.""" + + +def binary( + lhs: ChunkedOrScalarAny, op: type[ops.Operator], rhs: ChunkedOrScalarAny +) -> ChunkedOrScalarAny: + """Dispatch a binary operator type to a native function, providing `lhs` and `rhs` as operands.""" + return _DISPATCH_BINARY[op](lhs, rhs) + + +# TODO @dangotbanned: Somehow fix the typing on this +# - `_ArrowDispatch` is relying on the gradual typing +_DISPATCH_BINARY: Mapping[type[ops.Operator], BinOp] = { + # BinaryComp + ops.Eq: eq, + ops.NotEq: not_eq, + ops.Lt: lt, + ops.LtEq: lt_eq, + ops.Gt: gt, + ops.GtEq: gt_eq, + # BinaryFunction (well it should be) + ops.Add: arith.add, # BinaryNumericTemporal + ops.Sub: arith.sub, # BinaryNumericTemporal + ops.Multiply: arith.multiply, # BinaryNumericTemporal + ops.TrueDivide: arith.truediv, # [[ChunkedOrScalarAny, ChunkedOrScalarAny], ChunkedOrScalarAny] + ops.FloorDivide: arith.floordiv, # BinaryNumericTemporal + ops.Modulus: arith.modulus, # [[ChunkedOrScalarAny, ChunkedOrScalarAny], ChunkedOrScalarAny] + # BinaryLogical + ops.And: and_, + ops.Or: or_, + ops.ExclusiveOr: xor, +} diff --git a/narwhals/_plan/arrow/functions/_boolean.py b/narwhals/_plan/arrow/functions/_boolean.py new file mode 100644 index 0000000000..4ec8a122e0 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_boolean.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING, overload + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan.arrow.functions._bin_op import and_, gt, gt_eq, lt, lt_eq +from narwhals._plan.arrow.functions._construction import array, lit +from narwhals._plan.arrow.functions._dtypes import BOOL +from narwhals._plan.arrow.guards import is_arrow + +if TYPE_CHECKING: + from collections.abc import Mapping + + from narwhals._arrow.typing import Incomplete + from narwhals._plan.arrow.typing import ( + Array, + ArrayAny, + Arrow, + ArrowAny, + BinaryComp, + BooleanScalar, + ChunkedArray, + ChunkedArrayAny, + ChunkedOrArrayAny, + ChunkedOrScalar, + ChunkedOrScalarAny, + ScalarAny, + ScalarT, + UnaryFunction, + ) + from narwhals.typing import ClosedInterval, NonNestedLiteral, NumericLiteral + + +__all__ = [ + "all", + "any", + "eq_missing", + "is_between", + "is_finite", + "is_in", + "is_nan", + "is_not_nan", + "is_not_null", + "is_null", + "is_only_nulls", + "not_", +] + + +def any(native: Arrow[BooleanScalar], *, ignore_nulls: bool = True) -> pa.BooleanScalar: + """Return whether any values in `native` are True. + + Arguments: + native: Boolean-typed arrow data. + ignore_nulls: If set to `True` (default), null values are ignored. + If there are no non-null values, the output is `False`. + + If set to `False`, [Kleene logic] is used to deal with nulls; + if the column contains any null values and no `True` values, + the output is null. + + [Kleene logic]: https://en.wikipedia.org/wiki/Three-valued_logic + """ + ca = t.cast("ChunkedArray[pa.BooleanScalar]", native) + return pc.any(ca, min_count=0, skip_nulls=ignore_nulls) + + +def all(native: Arrow[BooleanScalar], *, ignore_nulls: bool = True) -> pa.BooleanScalar: + """Return whether all values in `native` are True. + + Arguments: + native: Boolean-typed arrow data. + ignore_nulls: If set to `True` (default), null values are ignored. + If there are no non-null values, the output is `True`. + + If set to `False`, [Kleene logic] is used to deal with nulls; + if the column contains any null values and no `False` values, + the output is null. + + [Kleene logic]: https://en.wikipedia.org/wiki/Three-valued_logic + """ + ca = t.cast("ChunkedArray[pa.BooleanScalar]", native) + return pc.all(ca, min_count=0, skip_nulls=ignore_nulls) + + +is_null = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_null) +is_not_null = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_valid) +is_nan = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_nan) +is_finite = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.is_finite) +not_ = t.cast("UnaryFunction[ScalarAny, pa.BooleanScalar]", pc.invert) +"""Invert Boolean-typed arrow data.""" + + +@overload +def is_not_nan(native: ChunkedArrayAny) -> ChunkedArray[pa.BooleanScalar]: ... +@overload +def is_not_nan(native: ScalarAny) -> pa.BooleanScalar: ... +@overload +def is_not_nan(native: ChunkedOrScalarAny) -> ChunkedOrScalar[pa.BooleanScalar]: ... +@overload +def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[pa.BooleanScalar]: ... +def is_not_nan(native: Arrow[ScalarAny]) -> Arrow[pa.BooleanScalar]: + return not_(is_nan(native)) + + +def is_only_nulls(native: ChunkedOrArrayAny, *, nan_is_null: bool = False) -> bool: + """Return True if `native` has 0 non-null values (and optionally include NaN).""" + return array(native.is_null(nan_is_null=nan_is_null), BOOL).false_count == 0 + + +@overload +def is_between( + native: ChunkedArray[ScalarT], + lower: ChunkedOrScalar[ScalarT] | NumericLiteral, + upper: ChunkedOrScalar[ScalarT] | NumericLiteral, + *, + closed: ClosedInterval = "both", +) -> ChunkedArray[pa.BooleanScalar]: ... +@overload +def is_between( + native: ChunkedOrScalar[ScalarT], + lower: ChunkedOrScalar[ScalarT] | NumericLiteral, + upper: ChunkedOrScalar[ScalarT] | NumericLiteral, + *, + closed: ClosedInterval = "both", +) -> ChunkedOrScalar[pa.BooleanScalar]: ... +def is_between( + native: ChunkedOrScalar[ScalarT], + lower: ChunkedOrScalar[ScalarT] | NumericLiteral, + upper: ChunkedOrScalar[ScalarT] | NumericLiteral, + *, + closed: ClosedInterval = "both", +) -> ChunkedOrScalar[pa.BooleanScalar]: + """Check if `native` is between the given `lower` and `upper` bounds.""" + fn_lhs, fn_rhs = _IS_BETWEEN[closed] + low, high = (el if is_arrow(el) else lit(el) for el in (lower, upper)) + out: ChunkedOrScalar[pa.BooleanScalar] = and_( + fn_lhs(native, low), fn_rhs(native, high) + ) + return out + + +@overload +def is_in( + values: ChunkedArrayAny, /, other: ChunkedOrArrayAny +) -> ChunkedArray[pa.BooleanScalar]: ... +@overload +def is_in(values: ArrayAny, /, other: ChunkedOrArrayAny) -> Array[pa.BooleanScalar]: ... +@overload +def is_in(values: ScalarAny, /, other: ChunkedOrArrayAny) -> pa.BooleanScalar: ... +@overload +def is_in( + values: ChunkedOrScalarAny, /, other: ChunkedOrArrayAny +) -> ChunkedOrScalarAny: ... +def is_in(values: ArrowAny, /, other: ChunkedOrArrayAny) -> ArrowAny: + """Check if elements of `values` are present in `other`. + + Roughly equivalent to [`polars.Expr.is_in`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.is_in.html) + + Returns a mask with `len(values)` elements. + """ + # NOTE: Stubs don't include a `ChunkedArray` return + # NOTE: Replaced ambiguous parameter name (`value_set`) + is_in_: Incomplete = pc.is_in + return is_in_(values, other) # type: ignore[no-any-return] + + +@overload +def eq_missing( + native: ChunkedArrayAny, other: NonNestedLiteral | ArrowAny +) -> ChunkedArray[pa.BooleanScalar]: ... +@overload +def eq_missing( + native: ArrayAny, other: NonNestedLiteral | ArrowAny +) -> Array[pa.BooleanScalar]: ... +@overload +def eq_missing( + native: ScalarAny, other: NonNestedLiteral | ArrowAny +) -> pa.BooleanScalar: ... +@overload +def eq_missing( + native: ChunkedOrScalarAny, other: NonNestedLiteral | ArrowAny +) -> ChunkedOrScalarAny: ... +def eq_missing(native: ArrowAny, other: NonNestedLiteral | ArrowAny) -> ArrowAny: + """Equivalent to `native == other` where `None == None`. + + This differs from default `eq` where null values are propagated. + + Note: + Unique to `pyarrow`, this wrapper will ensure `None` uses `native.type`. + """ + if isinstance(other, (pa.Array, pa.ChunkedArray)): + return is_in(native, other) + item = array(other if isinstance(other, pa.Scalar) else lit(other, native.type)) + return is_in(native, item) + + +_IS_BETWEEN: Mapping[ClosedInterval, tuple[BinaryComp, BinaryComp]] = { + "left": (gt_eq, lt), + "right": (gt, lt_eq), + "none": (gt, lt), + "both": (gt_eq, lt_eq), +} diff --git a/narwhals/_plan/arrow/functions/_categorical.py b/narwhals/_plan/arrow/functions/_categorical.py new file mode 100644 index 0000000000..597bd7a8a2 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_categorical.py @@ -0,0 +1,81 @@ +"""Categorical function namespace.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, overload + +import pyarrow as pa # ignore-banned-import + +from narwhals._plan.arrow.functions._construction import array + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from narwhals._plan.arrow.typing import ( + ArrayAny, + ArrowAny, + ChunkedArrayAny, + ChunkedOrArrayHashable, + ) + + Incomplete: TypeAlias = Any + + +__all__ = ["encode", "get_categories"] + + +def get_categories(native: ArrowAny, /) -> ChunkedArrayAny: + """Get the categories stored in the data type. + + Arguments: + native: Dictionary-typed arrow data. + """ + da: Incomplete + if isinstance(native, pa.ChunkedArray): + da = native.unify_dictionaries().chunk(0) + else: + da = native + return pa.chunked_array([da.dictionary]) + + +@overload +def encode(native: ChunkedOrArrayHashable, /) -> pa.Int32Array: ... +@overload +def encode( + native: ChunkedOrArrayHashable, /, *, include_categories: Literal[True] +) -> tuple[ArrayAny, pa.Int32Array]: ... +def encode( + native: ChunkedOrArrayHashable, /, *, include_categories: bool = False +) -> tuple[ArrayAny, pa.Int32Array] | pa.Int32Array: + """Return a [dictionary-encoded] version of the input array. + + Arguments: + native: An arrow array. + include_categories: Include the [`dictionary`] (categories) array in the output. + + Examples: + >>> from narwhals._plan.arrow import functions as fn + >>> values = [None, "foo", "bar", "foo", None, "foo", "ham"] + >>> arr = fn.array(values) + + Return the underlying [`indices`] *into* the ordered categories + + >>> fn.cat.encode(arr).to_pylist() + [0, 1, 2, 1, 0, 1, 3] + + Return the category each index refers to by specifying `include_categories=True` + + >>> categories, indices = fn.cat.encode(arr, include_categories=True) + >>> categories.to_pylist(), indices.to_pylist() + ([None, 'foo', 'bar', 'ham'], [0, 1, 2, 1, 0, 1, 3]) + + [dictionary-encoded]: https://arrow.apache.org/docs/python/generated/pyarrow.compute.dictionary_encode.html#pyarrow.compute.dictionary_encode + [`indices`]: https://arrow.apache.org/docs/python/generated/pyarrow.DictionaryArray.html#pyarrow.DictionaryArray.indices + [`dictionary`]: https://arrow.apache.org/docs/python/generated/pyarrow.DictionaryArray.html#pyarrow.DictionaryArray.dictionary + """ + da: Incomplete = array(native.dictionary_encode("encode")) + indices: pa.Int32Array = da.indices + if not include_categories: + return indices + categories: ArrayAny = da.dictionary + return categories, indices diff --git a/narwhals/_plan/arrow/functions/_construction.py b/narwhals/_plan/arrow/functions/_construction.py new file mode 100644 index 0000000000..64248e4380 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_construction.py @@ -0,0 +1,192 @@ +"""Creating Arrow data and converting between representations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, overload + +import pyarrow as pa # ignore-banned-import + +from narwhals._arrow.utils import concat_tables + +if TYPE_CHECKING: + from collections.abc import Collection, Iterable + + from typing_extensions import TypeAlias + + from narwhals._plan.arrow.typing import ( + ArrayAny, + Arrow, + ArrowAny, + BooleanScalar, + BoolType, + ChunkedArrayAny, + ChunkedOrArrayAny, + DataType, + IntoChunkedArray, + ScalarAny, + UInt32Type, + ) + from narwhals.typing import PythonLiteral + + +__all__ = [ + "array", + "chunked_array", + "concat_horizontal", + "concat_tables", + "concat_vertical", + "lit", + "to_table", +] + +Incomplete: TypeAlias = Any + + +@overload +def lit(value: Any, /) -> ScalarAny: ... +@overload +def lit(value: Any, /, dtype: BoolType) -> pa.BooleanScalar: ... +@overload +def lit(value: Any, /, dtype: UInt32Type) -> pa.UInt32Scalar: ... +@overload +def lit(value: Any, /, dtype: DataType | None = ...) -> ScalarAny: ... +def lit(value: Any, /, dtype: DataType | None = None) -> ScalarAny: + """Convert `value` into a [`Scalar`]. + + Note: + Feel free to add more `@overload`s, but avoid matching on `value`'s type. + If you need this, use [`pa.scalar`] directly but [pyarrow-stubs#208] may cause issues. + + [`Scalar`]: https://arrow.apache.org/docs/python/generated/pyarrow.Scalar.html + [`pa.scalar`]: https://arrow.apache.org/docs/python/generated/pyarrow.scalar.html + [pyarrow-stubs#208]: https://github.com/zen-xu/pyarrow-stubs/pull/208 + """ + return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) + + +# TODO @dangotbanned: Report `ListScalar.values` bug upstream +# See `tests/plan/list_unique_test.py::test_list_unique_scalar[None-None]` +@overload +def array(data: ArrowAny, /) -> ArrayAny: ... +@overload +def array(data: Arrow[BooleanScalar], dtype: BoolType, /) -> pa.BooleanArray: ... +@overload +def array( + data: Iterable[PythonLiteral], dtype: DataType | None = None, / +) -> ArrayAny: ... +def array( + data: ArrowAny | Iterable[PythonLiteral], dtype: DataType | None = None, / +) -> ArrayAny: + """Convert `data` into an [`Array`]. + + Note: + `dtype` is **not used** for existing `pyarrow` data, but it can be used to signal + the concrete `Array` subclass that is returned. + To actually changed the type, use `cast` instead. + + [`Array`]: https://arrow.apache.org/docs/python/generated/pyarrow.Array.html + """ + if isinstance(data, pa.ChunkedArray): + return data.combine_chunks() + if isinstance(data, pa.Array): + return data + if isinstance(data, pa.Scalar): + if isinstance(data, pa.ListScalar) and data.is_valid is False: + return pa.array([None], data.type) + return pa.array([data], data.type) + return pa.array(data, dtype) + + +def chunked_array( + data: IntoChunkedArray, dtype: DataType | None = None, / +) -> ChunkedArrayAny: + """Convert `data` into a [`ChunkedArray`]. + + Arguments: + data: Anything than can be coerced into an array. + A *little* more forgiving than [`pa.chunked_array`]. + dtype: A native `DataType`. + + Examples: + The result of `lit` and `array` can be passed in directly for the same result + + >>> import pyarrow as pa + >>> from narwhals._plan.arrow import functions as fn + >>> one = fn.lit(1) + >>> ones = fn.array(one) + >>> fn.chunked_array(one).equals(fn.chunked_array(ones)) + True + >>> fn.chunked_array(ones) # doctest: +ELLIPSIS + + [ + [ + 1 + ] + ] + + An empty list and a `DataType` produce an empty array + + >>> fn.chunked_array([], pa.string()) # doctest: +ELLIPSIS + + [ + + ] + + Chunks can be specified using nested lists + + >>> short = fn.chunked_array([[1], [2, 2], [3]]) + >>> [c.to_pylist() for c in short.chunks] + [[1], [2, 2], [3]] + + Which is equivalent to using `array` for each chunk + + >>> longer = fn.chunked_array([fn.array([1]), fn.array([2, 2]), fn.array([3])]) + >>> short.equals(longer) + True + + If you're feeling funky, all of these guys work as well + + >>> import numpy as np + >>> import pandas as pd + >>> import polars as pl + >>> im_surprised_too = [ + ... pl.Series([1]), + ... np.array([2]), + ... pd.Series([3]), + ... fn.array([4]), + ... ] + >>> [c.to_pylist() for c in fn.chunked_array(im_surprised_too).chunks] + [[1], [2], [3], [4]] + + [`ChunkedArray`]: https://arrow.apache.org/docs/python/generated/pyarrow.ChunkedArray.html + [`pa.chunked_array`]: https://arrow.apache.org/docs/python/generated/pyarrow.chunked_array.html#pyarrow.chunked_array + """ + arr = array(data) if isinstance(data, pa.Scalar) else data + if isinstance(arr, pa.ChunkedArray): + return arr + func: Incomplete = pa.chunked_array + result: ChunkedArrayAny = func([arr] if not isinstance(arr, list) else arr, dtype) + return result + + +def concat_horizontal( + arrays: Collection[ChunkedOrArrayAny], names: Collection[str] +) -> pa.Table: + """Concatenate `arrays` as columns in a new table.""" + table: Incomplete = pa.Table.from_arrays + result: pa.Table = table(arrays, names) + return result + + +def concat_vertical( + arrays: Iterable[ChunkedOrArrayAny], dtype: DataType | None = None, / +) -> ChunkedArrayAny: + """Concatenate `arrays` into a new array.""" + v_concat: Incomplete = pa.chunked_array + result: ChunkedArrayAny = v_concat(arrays, dtype) + return result + + +def to_table(array: ChunkedOrArrayAny, name: str = "") -> pa.Table: + """Equivalent to `Series.to_frame`, but with an option to insert a name for the column.""" + return concat_horizontal((array,), (name,)) diff --git a/narwhals/_plan/arrow/functions/_cumulative.py b/narwhals/_plan/arrow/functions/_cumulative.py new file mode 100644 index 0000000000..684fd70b63 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_cumulative.py @@ -0,0 +1,60 @@ +"""https://arrow.apache.org/docs/python/api/compute.html#cumulative-functions.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan.arrow.functions._boolean import is_not_null +from narwhals._plan.arrow.functions._dtypes import U32 +from narwhals._plan.arrow.functions._sort import reverse +from narwhals._plan.expressions import functions as F + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + from narwhals._plan.arrow.typing import ChunkedArrayAny, ChunkedOrArrayT + + +__all__ = ["cum_count", "cum_max", "cum_min", "cum_prod", "cum_sum", "cumulative"] + + +def cum_sum(native: ChunkedOrArrayT) -> ChunkedOrArrayT: + """Get an array with the cumulative sum computed at every element.""" + return pc.cumulative_sum(native, skip_nulls=True) + + +def cum_min(native: ChunkedOrArrayT) -> ChunkedOrArrayT: + """Get an array with the cumulative min computed at every element.""" + return pc.cumulative_min(native, skip_nulls=True) + + +def cum_max(native: ChunkedOrArrayT) -> ChunkedOrArrayT: + """Get an array with the cumulative max computed at every element.""" + return pc.cumulative_max(native, skip_nulls=True) + + +def cum_prod(native: ChunkedOrArrayT) -> ChunkedOrArrayT: + """Get an array with the cumulative product computed at every element.""" + return pc.cumulative_prod(native, skip_nulls=True) + + +def cum_count(native: ChunkedArrayAny) -> ChunkedArrayAny: + """Return the cumulative count of the non-null values in the array.""" + return cum_sum(is_not_null(native).cast(U32)) + + +def cumulative(native: ChunkedArrayAny, f: F.CumAgg, /) -> ChunkedArrayAny: + """Dispatch on the cumulative function `f`.""" + func = _CUMULATIVE[type(f)] + return func(native) if not f.reverse else reverse(func(reverse(native))) + + +_CUMULATIVE: Mapping[type[F.CumAgg], Callable[[ChunkedArrayAny], ChunkedArrayAny]] = { + F.CumSum: cum_sum, + F.CumCount: cum_count, + F.CumMin: cum_min, + F.CumMax: cum_max, + F.CumProd: cum_prod, +} diff --git a/narwhals/_plan/arrow/functions/_dtypes.py b/narwhals/_plan/arrow/functions/_dtypes.py new file mode 100644 index 0000000000..4066d6ee18 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_dtypes.py @@ -0,0 +1,134 @@ +"""Native data types, conversion and casting.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Final, overload + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._arrow.utils import narwhals_to_native_dtype as _dtype_native + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + + from typing_extensions import TypeIs + + from narwhals._plan.arrow.typing import ( + ChunkedArray, + ChunkedOrScalar, + DataType, + DataTypeRemap, + DataTypeT, + Scalar, + StringType, + ) + from narwhals._utils import Version + from narwhals.typing import IntoArrowSchema, IntoDType + +__all__ = [ # noqa: RUF022 + "BOOL", + "DATE", + "F64", + "I32", + "I64", + "U32", + "cast", + "cast_table", + "dtype_native", + "string_type", + # Not to be exported to `functions.__all__` + "is_integer", + "is_large_string", +] + +# NOTE: Common data type instances to share. +# Names use an uppercase equivalent to [short repr codes] +# (https://github.com/pola-rs/polars/blob/5deaf7e9074fdc8f7f0082974cc956acf645af62/crates/polars-core/src/datatypes/dtype.rs#L1127-L1187) +U32: Final = pa.uint32() +I32: Final = pa.int32() +I64: Final = pa.int64() +F64: Final = pa.float64() +BOOL: Final = pa.bool_() +DATE: Final = pa.date32() + +is_integer: Final = pa.types.is_integer +is_large_string: Final = pa.types.is_large_string + + +@overload +def dtype_native(dtype: IntoDType, /, version: Version) -> DataType: ... +@overload +def dtype_native(dtype: None, /, version: Version) -> None: ... +@overload +def dtype_native(dtype: IntoDType | None, /, version: Version) -> DataType | None: ... +def dtype_native(dtype: IntoDType | None, /, version: Version) -> DataType | None: + """Convert a Narwhals `DType` to a `pyarrow.DataType`, or passthrough `None`.""" + return dtype if dtype is None else _dtype_native(dtype, version) + + +@overload +def cast(native: Scalar[Any], dtype: DataTypeT, /) -> Scalar[DataTypeT]: ... +@overload +def cast( + native: ChunkedArray[Any], dtype: DataTypeT, / +) -> ChunkedArray[Scalar[DataTypeT]]: ... +@overload +def cast( + native: ChunkedOrScalar[Scalar[Any]], dtype: DataTypeT, / +) -> ChunkedArray[Scalar[DataTypeT]] | Scalar[DataTypeT]: ... +def cast( + native: ChunkedOrScalar[Scalar[Any]], dtype: DataTypeT, / +) -> ChunkedArray[Scalar[DataTypeT]] | Scalar[DataTypeT]: + """Cast arrow data to the specified dtype.""" + return pc.cast(native, dtype) + + +def cast_table( + native: pa.Table, dtypes: DataType | IntoArrowSchema | DataTypeRemap, / +) -> pa.Table: + """Cast Table column(s) to the specified dtype(s). + + Similar to [`pl.DataFrame.cast`]. + + Arguments: + native: An arrow table. + dtypes: Mapping of column names (or dtypes) to dtypes, or a single dtype + to which all columns will be cast. + + [`pl.DataFrame.cast`]: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.cast.html#polars.DataFrame.cast + """ + s = dtypes if isinstance(dtypes, pa.Schema) else _cast_schema(native.schema, dtypes) + return native.cast(s) + + +def string_type(dtypes: Iterable[DataType] = (), /) -> StringType: + """Return a native string type, compatible with `dtypes`. + + Until [apache/arrow#45717] is resolved, we need to upcast `string` to `large_string` when joining. + + [apache/arrow#45717]: https://github.com/apache/arrow/issues/45717 + """ + return pa.large_string() if any(is_large_string(tp) for tp in dtypes) else pa.string() + + +def _cast_schema( + native: pa.Schema, dtypes: DataType | Mapping[str, DataType] | DataTypeRemap +) -> pa.Schema: + if isinstance(dtypes, pa.DataType): + return pa.schema((name, dtypes) for name in native.names) + if _is_into_pyarrow_schema(dtypes): + new_schema = native + for name, dtype in dtypes.items(): + index = native.get_field_index(name) + new_schema.set(index, native.field(index).with_type(dtype)) + return new_schema + return pa.schema((fld.name, dtypes.get(fld.type, fld.type)) for fld in native) + + +def _is_into_pyarrow_schema(obj: Mapping[Any, Any]) -> TypeIs[Mapping[str, DataType]]: + return ( + (first := next(iter(obj.items())), None) + and isinstance(first[0], str) + and isinstance(first[1], pa.DataType) + ) diff --git a/narwhals/_plan/arrow/functions/_horizontal.py b/narwhals/_plan/arrow/functions/_horizontal.py new file mode 100644 index 0000000000..5b6f4aecf3 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_horizontal.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import pyarrow.compute as pc # ignore-banned-import + +__all__ = ["max_horizontal", "min_horizontal"] + +# TODO @dangotbanned: Wrap horizontal functions with correct typing +# Should only return scalar if all elements are as well +# NOTE: Changing typing will propagate to a lot of places (so be careful!): +# - `_round.{clip,clip_lower,clip_upper}` +# - `acero.join_asof_tables` +# - `ArrowNamespace.{min,max}_horizontal` +# - `ArrowSeries.rolling_var` +min_horizontal = pc.min_element_wise +max_horizontal = pc.max_element_wise diff --git a/narwhals/_plan/arrow/functions/_lists.py b/narwhals/_plan/arrow/functions/_lists.py new file mode 100644 index 0000000000..b38b910134 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_lists.py @@ -0,0 +1,547 @@ +"""List namespace functions.""" + +from __future__ import annotations + +import builtins +import typing as t +from itertools import chain +from typing import TYPE_CHECKING, Any, Final, overload + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan.arrow.functions._aggregation import implode +from narwhals._plan.arrow.functions._bin_op import eq, gt, not_eq, or_ +from narwhals._plan.arrow.functions._boolean import all, any, eq_missing, is_null, not_ +from narwhals._plan.arrow.functions._construction import ( + array, + chunked_array, + concat_horizontal, + concat_tables, + lit, + to_table, +) +from narwhals._plan.arrow.functions._dtypes import BOOL, U32 +from narwhals._plan.arrow.functions._multiplex import ( + fill_null, + replace_with_mask, + when_then, +) +from narwhals._plan.arrow.functions._ranges import int_range +from narwhals._plan.arrow.functions._sort import sort_indices +from narwhals._plan.arrow.functions.meta import call +from narwhals._plan.options import ExplodeOptions, SortOptions +from narwhals._utils import no_default +from narwhals.exceptions import ShapeError + +if TYPE_CHECKING: + from collections.abc import Callable, Collection, Iterable, Iterator + + from typing_extensions import Self + + from narwhals._plan.arrow.typing import ( + Array, + Arrow, + ArrowAny, + ArrowListT, + BooleanScalar, + ChunkedArray, + ChunkedArrayAny, + ChunkedList, + ChunkedOrArray, + ChunkedOrArrayAny, + ChunkedOrScalar, + ChunkedOrScalarAny, + DataTypeT, + ListArray, + ListScalar, + ListTypeT, + NonListTypeT, + SameArrowT, + Scalar, + ScalarAny, + StringScalar, + StringType, + ) + from narwhals._typing import NoDefault + from narwhals.typing import NonNestedLiteral + + +__all__ = [ + "ExplodeBuilder", + "contains", + "get", + "join", + "join_scalar", + "len", + "sort", + "sort_scalar", + "unique", +] + + +class ExplodeBuilder: + """Tools for exploding lists. + + The complexity of these operations increases with: + - Needing to preserve null/empty elements + - All variants are cheaper if this can be skipped + - Exploding in the context of a table + - Where a single column is much simpler than multiple + """ + + options: ExplodeOptions + + def __init__(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> None: + self.options = ExplodeOptions(empty_as_null=empty_as_null, keep_nulls=keep_nulls) + + @classmethod + def from_options(cls, options: ExplodeOptions, /) -> Self: + obj = cls.__new__(cls) + obj.options = options + return obj + + @overload + def explode( + self, native: ChunkedList[DataTypeT] | ListScalar[DataTypeT] + ) -> ChunkedArray[Scalar[DataTypeT]]: ... + @overload + def explode(self, native: ListArray[DataTypeT]) -> Array[Scalar[DataTypeT]]: ... + @overload + def explode( + self, native: Arrow[ListScalar[DataTypeT]] + ) -> ChunkedOrArray[Scalar[DataTypeT]]: ... + def explode( + self, native: Arrow[ListScalar[DataTypeT]] + ) -> ChunkedOrArray[Scalar[DataTypeT]]: + """Explode list elements, expanding one-level into a new array. + + Equivalent to `polars.{Expr,Series}.explode`. + """ + safe = self._fill_with_null(native) if self.options.any() else native + if not isinstance(safe, pa.Scalar): + return _explode(safe) + return chunked_array(_explode(safe)) + + def explode_with_indices(self, native: ChunkedList | ListArray) -> pa.Table: + """Explode list elements, expanding one-level into a table indexing the origin. + + Returns a 2-column table, with names `"idx"` and `"values"`: + + >>> from narwhals._plan.arrow import functions as fn + >>> + >>> arr = fn.array([[1, 2, 3], None, [4, 5, 6], []]) + >>> fn.ExplodeBuilder().explode_with_indices(arr).to_pydict() + {'idx': [0, 0, 0, 1, 2, 2, 2, 3], 'values': [1, 2, 3, None, 4, 5, 6, None]} + # ^ Which sublist values come from ^ The exploded values themselves + """ + safe = self._fill_with_null(native) if self.options.any() else native + arrays = [_list_parent_indices(safe), _explode(safe)] + return concat_horizontal(arrays, ["idx", "values"]) + + def explode_column(self, native: pa.Table, column_name: str, /) -> pa.Table: + """Explode a list-typed column in the context of `native`.""" + ca = native.column(column_name) + if native.num_columns == 1: + return native.from_arrays([self.explode(ca)], [column_name]) + safe = self._fill_with_null(ca) if self.options.any() else ca + exploded = _explode(safe) + col_idx = native.schema.get_field_index(column_name) + if exploded.length() == native.num_rows: + return native.set_column(col_idx, column_name, exploded) + return ( + native.remove_column(col_idx) + .take(_list_parent_indices(safe)) + .add_column(col_idx, column_name, exploded) + ) + + def explode_columns(self, native: pa.Table, subset: Collection[str], /) -> pa.Table: + """Explode multiple list-typed columns in the context of `native`.""" + subset = list(subset) + arrays = native.select(subset).columns + first = arrays[0] + first_len = len(first) + if self.options.any(): + mask = self._predicate(first_len) + first_safe = self._fill_with_null(first, mask) + it = ( + _explode(self._fill_with_null(arr, mask)) + for arr in self._iter_ensure_shape(first_len, arrays[1:]) + ) + else: + first_safe = first + it = (_explode(arr) for arr in self._iter_ensure_shape(first_len, arrays[1:])) + column_names = native.column_names + result = native + first_result = _explode(first_safe) + if first_result.length() == native.num_rows: + # fastpath for all length-1 lists + # if only the first is length-1, then the others raise during iteration on either branch + for name, arr in zip(subset, chain([first_result], it)): + result = result.set_column(column_names.index(name), name, arr) + else: + result = result.drop_columns(subset).take(_list_parent_indices(first_safe)) + for name, arr in zip(subset, chain([first_result], it)): + result = result.append_column(name, arr) + result = result.select(column_names) + return result + + @classmethod + def explode_column_fast(cls, native: pa.Table, column_name: str, /) -> pa.Table: + """Explode a list-typed column in the context of `native`, ignoring empty and nulls.""" + return cls(empty_as_null=False, keep_nulls=False).explode_column( + native, column_name + ) + + def _iter_ensure_shape( + self, + first_len: ChunkedArray[pa.UInt32Scalar], + arrays: Iterable[ChunkedArrayAny], + /, + ) -> Iterator[ChunkedArrayAny]: + for arr in arrays: + if not first_len.equals(len(arr)): + msg = "exploded columns must have matching element counts" + raise ShapeError(msg) + yield arr + + def _predicate(self, lengths: ArrowAny, /) -> Arrow[pa.BooleanScalar]: + """Return True for each sublist length that indicates the original sublist should be replaced with `[None]`.""" + empty_as_null, keep_nulls = self.options.empty_as_null, self.options.keep_nulls + if empty_as_null and keep_nulls: + return or_(is_null(lengths), eq(lengths, lit(0))) + if empty_as_null: + return eq(lengths, lit(0)) + return is_null(lengths) + + def _fill_with_null( + self, native: ArrowListT, mask: Arrow[BooleanScalar] | NoDefault = no_default + ) -> ArrowListT: + """Replace each sublist in `native` with `[None]`, according to `self.options`. + + Arguments: + native: List-typed arrow data. + mask: An optional, pre-computed replacement mask. By default, this is generated from `native`. + """ + predicate = self._predicate(len(native)) if mask is no_default else mask + result: ArrowListT = when_then(predicate, lit([None], native.type), native) + return result + + +@overload +def len(native: ChunkedList) -> ChunkedArray[pa.UInt32Scalar]: ... +@overload +def len(native: ListArray) -> pa.UInt32Array: ... +@overload +def len(native: ListScalar) -> pa.UInt32Scalar: ... +@overload +def len(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[pa.UInt32Scalar]: ... +@overload +def len(native: Arrow[ListScalar[Any]]) -> Arrow[pa.UInt32Scalar]: ... +def len(native: ArrowAny) -> ArrowAny: + """Return the number of elements in each sublist. + + Null values count towards the total. + + Arguments: + native: List-typed arrow data. + + Important: + This is **not** [`builtins.len`]! + + [`builtins.len`]: https://docs.python.org/3/library/functions.html#len + """ + result: ArrowAny = call("list_value_length", native).cast(U32) + return result + + +@overload +def get( + native: ChunkedList[DataTypeT], index: int +) -> ChunkedArray[Scalar[DataTypeT]]: ... +@overload +def get(native: ListArray[DataTypeT], index: int) -> Array[Scalar[DataTypeT]]: ... +@overload +def get(native: ListScalar[DataTypeT], index: int) -> Scalar[DataTypeT]: ... +@overload +def get(native: SameArrowT, index: int) -> SameArrowT: ... +@overload +def get(native: ChunkedOrScalarAny, index: int) -> ChunkedOrScalarAny: ... +def get(native: ArrowAny, index: int) -> ArrowAny: + """Get the value by index in the sublists. + + Arguments: + native: List-typed arrow data. + index: Index to return per sublist. + """ + result: ArrowAny = call("list_element", native, index) + return result + + +EMPTY: Final = "" +"""The empty string.""" + + +# NOTE: Raised for native null-handling (https://github.com/apache/arrow/issues/48477) +@overload +def join( + native: ChunkedList[StringType], + separator: Arrow[StringScalar] | str, + *, + ignore_nulls: bool = ..., +) -> ChunkedArray[StringScalar]: ... +@overload +def join( + native: ListArray[StringType], + separator: Arrow[StringScalar] | str, + *, + ignore_nulls: bool = ..., +) -> pa.StringArray: ... +@overload +def join( + native: ChunkedOrArray[ListScalar[StringType]], + separator: str, + *, + ignore_nulls: bool = ..., +) -> ChunkedOrArray[StringScalar]: ... +def join( + native: ChunkedOrArrayAny, + separator: Arrow[StringScalar] | str, + *, + ignore_nulls: bool = True, +) -> ChunkedOrArrayAny: + """Join all string items in a sublist and place a separator between them. + + Arguments: + native: List-typed arrow data, where the inner type is String. + separator: String to separate the items with + ignore_nulls: If set to False, null values will be propagated. + If the sub-list contains any null values, the output is None. + """ + from narwhals._plan.arrow.group_by import AggSpec + + # (1): Try to return *as-is* from `pc.binary_join` + result = _list_join(native, separator) + if not ignore_nulls or not result.null_count: + return result + is_null_sensitive = pc.and_not(result.is_null(), native.is_null()) + if array(is_null_sensitive, BOOL).true_count == 0: + return result + + # (2): Deal with only the bad kids + lists = native.filter(is_null_sensitive) + + # (2.1): We know that `[None]` should join as `""`, and that is the only length-1 list we could have after the filter + list_len_eq_1 = eq(len(lists), lit(1, U32)) + has_a_len_1_null = any(list_len_eq_1).as_py() + if has_a_len_1_null: + lists = when_then(list_len_eq_1, lit([EMPTY], lists.type), lists) + + # (2.2): Everything left falls into one of these boxes: + # - (2.1): `[""]` + # - (2.2): `["something", (str | None)*, None]` <--- We fix this here and hope for the best + # - (2.3): `[None, (None)*, None]` + idx, v = "idx", "values" + builder = ExplodeBuilder(empty_as_null=False, keep_nulls=False) + explode_w_idx = builder.explode_with_indices(lists) + implode_by_idx = AggSpec.implode(v).over(explode_w_idx.drop_null(), [idx]) + replacements = _list_join(implode_by_idx.column(v), separator) + + # (2.3): The cursed box 😨 + if builtins.len(replacements) != builtins.len(lists): + # This is a very unlucky case to hit, because we *can* detect the issue earlier + # but we *can't* join a table with a list in it. So we deal with the fallout now ... + # The end result is identical to (2.1) + indices_all = to_table(explode_w_idx.column(idx).unique(), idx) + indices_repaired = implode_by_idx.set_column(1, v, replacements) + replacements = ( + indices_all.join(indices_repaired, idx) + .sort_by(idx) + .column(v) + .fill_null(lit(EMPTY, lists.type.value_type)) + ) + return replace_with_mask(result, is_null_sensitive, replacements) + + +def join_scalar( + native: ListScalar[StringType], + separator: StringScalar | str, + *, + ignore_nulls: bool = True, +) -> StringScalar: + """Join all string items in a `ListScalar` and place a separator between them. + + Note: + Consider using `list_join` or `str_join` if you don't already have `native` in this shape. + """ + if ignore_nulls and native.is_valid: + native = implode(_explode(native).drop_null()) + result: StringScalar = call("binary_join", native, separator) + return result + + +@overload +def unique(native: ChunkedList) -> ChunkedList: ... +@overload +def unique(native: ListScalar) -> ListScalar: ... +@overload +def unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScalar]: ... +def unique(native: ChunkedOrScalar[ListScalar]) -> ChunkedOrScalar[ListScalar]: + """Get the distinct values in each sublist. + + Arguments: + native: List-typed arrow data. + + There's lots of tricky stuff going on in here, but for good reasons! + + Whenever possible, we want to avoid having to deal with these pesky guys: + + [["okay", None, "still fine"], None, []] + # ^^^^ ^^ + + - Those kinds of list elements are ignored natively + - `unique` is length-changing operation + - We can't use [`pc.replace_with_mask`] on a list + - We can't join when a table contains list columns [apache/arrow#43716] + + **But** - if we're lucky, and we got a non-awful list (or only one element) - then + most issues vanish. + + [`pc.replace_with_mask`]: https://arrow.apache.org/docs/python/generated/pyarrow.compute.replace_with_mask.html + [apache/arrow#43716]: https://github.com/apache/arrow/issues/43716 + """ + from narwhals._plan.arrow.group_by import AggSpec + + if isinstance(native, pa.Scalar): + scalar = _typing_list_scalar(native) + if scalar.is_valid and (builtins.len(scalar) > 1): + return implode(_explode(native).unique()) + return scalar + idx, v = "index", "values" + names = idx, v + len_not_eq_0 = not_eq(len(native), lit(0)) + can_fastpath = all(len_not_eq_0, ignore_nulls=False).as_py() + if can_fastpath: + arrays = [_list_parent_indices(native), _explode(native)] + return AggSpec.unique(v).over_index(concat_horizontal(arrays, names), idx) + # Oh no - we caught a bad one! + # We need to split things into good/bad - and only work on the good stuff. + # `int_range` is acting like `parent_indices`, but doesn't give up when it see's `None` or `[]` + indexed = concat_horizontal([int_range(native.length()), native], names) + valid = indexed.filter(len_not_eq_0) + invalid = indexed.filter(or_(native.is_null(), not_(len_not_eq_0))) + # To keep track of where we started, our index needs to be exploded with the list elements + explode_with_index = ExplodeBuilder.explode_column_fast(valid, v) + valid_unique = AggSpec.unique(v).over(explode_with_index, [idx]) + # And now, because we can't join - we do a poor man's version of one 😉 + return concat_tables([valid_unique, invalid]).sort_by(idx).column(v) + + +def contains( + native: ChunkedOrScalar[ListScalar], item: NonNestedLiteral | ScalarAny +) -> ChunkedOrScalar[pa.BooleanScalar]: + """Check if sublists contain the given item. + + Arguments: + native: List-typed arrow data. + item: Item that will be checked for membership + """ + from narwhals._plan.arrow.group_by import AggSpec + + if isinstance(native, pa.Scalar): + scalar = _typing_list_scalar(native) + if scalar.is_valid: + if builtins.len(scalar): + value_type = scalar.type.value_type + return any(eq_missing(_explode(scalar), lit(item).cast(value_type))) + return lit(False, BOOL) + return lit(None, BOOL) + builder = ExplodeBuilder(empty_as_null=False, keep_nulls=False) + tbl = builder.explode_with_indices(native) + idx, name = tbl.column_names + contains = eq_missing(tbl.column(name), item) + l_contains = AggSpec.any(name).over_index(tbl.set_column(1, name, contains), idx) + # Here's the really key part: this mask has the same result we want to return + # So by filling the `True`, we can flip those to `False` if needed + # But if we were already `None` or `False` - then that's sticky + propagate_invalid: ChunkedArray[pa.BooleanScalar] = not_eq(len(native), lit(0)) + return replace_with_mask(propagate_invalid, propagate_invalid, l_contains) + + +def sort( + native: ChunkedList, *, descending: bool = False, nulls_last: bool = False +) -> ChunkedList: + """Sort the sublists in this column. + + Works in a similar way to `list_unique` and `list_join`. + + 1. Select only sublists that require sorting (`None`, 0-length, and 1-length lists are noops) + 2. Explode -> Sort -> Implode -> Concat + """ + from narwhals._plan.arrow.group_by import AggSpec + + idx, v = "idx", "values" + is_not_sorted = gt(len(native), lit(1)) + indexed = concat_horizontal([int_range(native.length()), native], [idx, v]) + exploded = ExplodeBuilder.explode_column_fast(indexed.filter(is_not_sorted), v) + indices = sort_indices( + exploded, idx, v, descending=[False, descending], nulls_last=nulls_last + ) + exploded_sorted = exploded.take(indices) + implode_by_idx = AggSpec.implode(v).over(exploded_sorted, [idx]) + passthrough = indexed.filter(fill_null(not_(is_not_sorted), True)) + return concat_tables([implode_by_idx, passthrough]).sort_by(idx).column(v) + + +# TODO @dangotbanned: Docstring? +def sort_scalar( + native: ListScalar[NonListTypeT], options: SortOptions | None = None +) -> pa.ListScalar[NonListTypeT]: + native = _typing_list_scalar(native) + if native.is_valid and builtins.len(native) > 1: + arr = _explode(native) + return implode(arr.take(sort_indices(arr, options=options))) + return native + + +def _typing_list_scalar(native: ListScalar[DataTypeT], /) -> pa.ListScalar[DataTypeT]: + """**Runtime noop**. + + Just performs a useful `typing.cast`: + + pa.Scalar[pa.ListType[DataTypeT]] # This isn't a real thing at runtime + pa.ListScalar[DataTypeT] # Defines: `values`, `__len__` + """ + return t.cast("pa.ListScalar[DataTypeT]", native) + + +_list_join = t.cast( + "Callable[[ChunkedOrArrayAny, Arrow[StringScalar] | str], ChunkedArray[StringScalar] | pa.StringArray]", + pc.binary_join, +) + + +@overload +def _explode(native: ChunkedList[DataTypeT]) -> ChunkedArray[Scalar[DataTypeT]]: ... +@overload +def _explode( + native: ListArray[NonListTypeT] | ListScalar[NonListTypeT], +) -> Array[Scalar[NonListTypeT]]: ... +@overload +def _explode(native: ListArray[DataTypeT]) -> Array[Scalar[DataTypeT]]: ... +@overload +def _explode(native: ListScalar[ListTypeT]) -> ListArray[ListTypeT]: ... +def _explode(native: Arrow[ListScalar]) -> ChunkedOrArrayAny: + result: ChunkedOrArrayAny = call("list_flatten", native) + return result + + +@overload +def _list_parent_indices(native: ChunkedList) -> ChunkedArray[pa.Int64Scalar]: ... +@overload +def _list_parent_indices(native: ListArray) -> pa.Int64Array: ... +def _list_parent_indices( + native: ChunkedOrArray[ListScalar], +) -> ChunkedOrArray[pa.Int64Scalar]: + result: ChunkedOrArray[pa.Int64Scalar] = call("list_parent_indices", native) + return result diff --git a/narwhals/_plan/arrow/functions/_multiplex.py b/narwhals/_plan/arrow/functions/_multiplex.py new file mode 100644 index 0000000000..c77e1daaf4 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_multiplex.py @@ -0,0 +1,239 @@ +"""Conditional [selection] and [fill/replacement] functions. + +[selection]: https://arrow.apache.org/docs/python/api/compute.html#selecting-multiplexing +[fill/replacement]: https://arrow.apache.org/docs/python/api/compute.html#structural-transforms +""" + +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING, Any, overload + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan._guards import is_non_nested_literal +from narwhals._plan.arrow.functions._arithmetic import sub +from narwhals._plan.arrow.functions._bin_op import and_, gt, not_eq, or_ +from narwhals._plan.arrow.functions._boolean import any, is_not_nan, is_not_null, is_null +from narwhals._plan.arrow.functions._construction import array, chunked_array, lit +from narwhals._plan.arrow.functions._cumulative import cum_max +from narwhals._plan.arrow.functions._ranges import int_range +from narwhals._plan.arrow.functions._sort import reverse +from narwhals._plan.arrow.functions.meta import call + +if TYPE_CHECKING: + from collections.abc import Mapping + + from narwhals._arrow.typing import Incomplete + from narwhals._plan.arrow.typing import ( + Array, + ArrayAny, + ArrowAny, + ArrowT, + BooleanScalar, + ChunkedArray, + ChunkedArrayAny, + ChunkedOrArrayAny, + ChunkedOrArrayT, + ChunkedOrScalarAny, + ChunkedOrScalarT, + Predicate, + SameArrowT, + ScalarAny, + UnaryFunction, + ) + from narwhals._plan.typing import Seq + from narwhals.typing import FillNullStrategy, NonNestedLiteral + + +__all__ = [ + "fill_nan", + "fill_null", + "fill_null_with_strategy", + "preserve_nulls", + "replace_strict", + "replace_strict_default", + "replace_with_mask", + "when_then", +] + + +@overload +def when_then( + predicate: ChunkedArray[BooleanScalar], then: ScalarAny +) -> ChunkedArrayAny: ... +@overload +def when_then(predicate: Array[BooleanScalar], then: ScalarAny) -> ArrayAny: ... +@overload +def when_then( + predicate: Predicate, then: SameArrowT, otherwise: SameArrowT | None +) -> SameArrowT: ... +@overload +def when_then(predicate: Predicate, then: ScalarAny, otherwise: ArrowT) -> ArrowT: ... +@overload +def when_then( + predicate: Predicate, then: ArrowT, otherwise: ScalarAny | NonNestedLiteral = ... +) -> ArrowT: ... +@overload +def when_then( + predicate: Predicate, then: ArrowAny, otherwise: ArrowAny | NonNestedLiteral = None +) -> Incomplete: ... +def when_then( + predicate: Predicate, then: ArrowAny, otherwise: ArrowAny | NonNestedLiteral = None +) -> Incomplete: + """Return elements from `then` or `otherwise` depending on `predicate`. + + Thin wrapper around [`pc.if_else`], with two tweaks + *some* typing: + - Supports a 2-argument form, like `pl.when(...).then(...)` + - Accepts python literals, but only in the `otherwise` position + + [`pc.if_else`]: https://arrow.apache.org/docs/python/generated/pyarrow.compute.if_else.html + """ + if is_non_nested_literal(otherwise): + otherwise = lit(otherwise, then.type) + return pc.if_else(predicate, then, otherwise) + + +@overload +def replace_with_mask( + native: ChunkedOrArrayT, mask: Predicate, replacements: ChunkedOrArrayAny +) -> ChunkedOrArrayT: ... +@overload +def replace_with_mask( + native: ChunkedOrArrayAny, mask: Predicate, replacements: ChunkedOrArrayAny +) -> ChunkedOrArrayAny: ... +def replace_with_mask( + native: ChunkedOrArrayAny, mask: Predicate, replacements: ChunkedOrArrayAny +) -> ChunkedOrArrayAny: + """Replace elements of `native`, at positions defined by `mask`. + + The length of `replacements` must equal the number of `True` values in `mask`. + """ + args = (array(p) for p in (native, mask, replacements)) + result: ChunkedOrArrayAny = call("replace_with_mask", *args) + if isinstance(native, pa.ChunkedArray): + return chunked_array(result) + return result + + +def replace_strict( + native: ChunkedOrScalarAny, + old: Seq[Any], + new: Seq[Any], + dtype: pa.DataType | None = None, +) -> ChunkedOrScalarAny: + """Replace all values (`old`) by different values (`new`). + + Raises if any values in `native` were not replaced. + """ + if isinstance(native, pa.Scalar): + idxs: ArrayAny = array(pc.index_in(native, pa.array(old))) + result: ChunkedOrScalarAny = pa.array(new).take(idxs)[0] + else: + idxs = pc.index_in(native, pa.array(old)) + result = chunked_array(pa.array(new).take(idxs)) + if err := _ensure_all_replaced(native, and_(is_not_null(native), is_null(idxs))): + raise err + return result.cast(dtype) if dtype else result + + +def replace_strict_default( + native: ChunkedOrScalarAny, + old: Seq[Any], + new: Seq[Any], + default: ChunkedOrScalarAny, + dtype: pa.DataType | None = None, +) -> ChunkedOrScalarAny: + """Replace all values (`old`) by different values (`new`). + + Sets any values that were not replaced in `native` to `default`. + """ + idxs = pc.index_in(native, pa.array(old)) + result = pa.array(new).take(array(idxs)) + result = when_then(is_null(idxs), default, result.cast(dtype) if dtype else result) + return chunked_array(result) if isinstance(native, pa.ChunkedArray) else result[0] + + +# TODO @dangotbanned: Avoid using `TypeVar` constraints on `after` +# Only used in `_vector.rank` and `ChunkedOrArrayT` erases `Array[...]` +def preserve_nulls( + before: ChunkedOrArrayAny, after: ChunkedOrArrayT, / +) -> ChunkedOrArrayT: + """Propagate nulls positionally from `before` to `after`.""" + return when_then(is_not_null(before), after) if before.null_count else after + + +@t.overload +def fill_null( + native: ChunkedOrScalarT, value: NonNestedLiteral | ArrowAny +) -> ChunkedOrScalarT: ... +@t.overload +def fill_null( + native: ChunkedOrArrayT, value: ScalarAny | NonNestedLiteral | ChunkedOrArrayT +) -> ChunkedOrArrayT: ... +@t.overload +def fill_null( + native: ChunkedOrScalarAny, value: ChunkedOrScalarAny | NonNestedLiteral +) -> ChunkedOrScalarAny: ... +def fill_null(native: ArrowAny, value: ArrowAny | NonNestedLiteral) -> ArrowAny: + """Fill null values with `value`.""" + fill_value: Incomplete = value + result: ArrowAny = pc.fill_null(native, fill_value) + return result + + +@t.overload +def fill_nan( + native: ChunkedOrScalarT, value: NonNestedLiteral | ArrowAny +) -> ChunkedOrScalarT: ... +@t.overload +def fill_nan(native: SameArrowT, value: NonNestedLiteral | ArrowAny) -> SameArrowT: ... +def fill_nan(native: ArrowAny, value: NonNestedLiteral | ArrowAny) -> Incomplete: + """Fill floating point NaN values with `value`.""" + return when_then(is_not_nan(native), native, value) + + +def fill_null_with_strategy( + native: ChunkedArrayAny, strategy: FillNullStrategy, limit: int | None = None +) -> ChunkedArrayAny: + """Fill null values with `strategy`, optionally to at most `limit` consecutive null values.""" + null_count = native.null_count + if null_count == 0 or (null_count == len(native)): + return native + if limit is None: + return _FILL_NULL_STRATEGY[strategy](native) + if strategy == "forward": + return _fill_null_forward_limit(native, limit) + return reverse(_fill_null_forward_limit(reverse(native), limit)) + + +def _ensure_all_replaced( + native: ChunkedOrScalarAny, unmatched: ArrowAny +) -> ValueError | None: + if not any(unmatched).as_py(): + return None + msg = ( + "replace_strict did not replace all non-null values.\n\n" + f"The following did not get replaced: {chunked_array(native).filter(array(unmatched)).unique().to_pylist()}" + ) + return ValueError(msg) + + +def _fill_null_forward_limit(native: ChunkedArrayAny, limit: int) -> ChunkedArrayAny: + SENTINEL = lit(-1) # noqa: N806 + is_not_null = native.is_valid() + index = int_range(len(native), chunked=False) + index_not_null = cum_max(when_then(is_not_null, index, SENTINEL)) + # NOTE: The correction here is for nulls at either end of the array + # They should be preserved when the `strategy` would need an out-of-bounds index + not_oob = not_eq(index_not_null, SENTINEL) + index_not_null = when_then(not_oob, index_not_null) + beyond_limit = gt(sub(index, index_not_null), lit(limit)) + return when_then(or_(is_not_null, beyond_limit), native, native.take(index_not_null)) + + +_FILL_NULL_STRATEGY: Mapping[FillNullStrategy, UnaryFunction] = { + "forward": pc.fill_null_forward, + "backward": pc.fill_null_backward, +} diff --git a/narwhals/_plan/arrow/functions/_ranges.py b/narwhals/_plan/arrow/functions/_ranges.py new file mode 100644 index 0000000000..7df84a9463 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_ranges.py @@ -0,0 +1,148 @@ +"""Range generation functions.""" + +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING, Any, Literal, overload + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan.arrow import compat +from narwhals._plan.arrow.functions._arithmetic import add, multiply +from narwhals._plan.arrow.functions._construction import chunked_array, lit +from narwhals._plan.arrow.functions._dtypes import DATE, F64, I32, I64 + +if TYPE_CHECKING: + import datetime as dt + + from typing_extensions import TypeAlias + + from narwhals._plan.arrow.typing import ( + Array, + ArrayAny, + ChunkedArray, + ChunkedOrArray, + DateScalar, + IntegerScalar, + IntegerType, + ) + from narwhals.typing import ClosedInterval + + +__all__ = ["date_range", "int_range", "linear_space"] + +Incomplete: TypeAlias = Any + + +@overload +def int_range( + start: int = ..., + end: int | None = ..., + step: int = ..., + /, + *, + dtype: IntegerType = ..., + chunked: Literal[True] = ..., +) -> ChunkedArray[IntegerScalar]: ... +@overload +def int_range( + start: int = ..., + end: int | None = ..., + step: int = ..., + /, + *, + chunked: Literal[False], +) -> pa.Int64Array: ... +@overload +def int_range( + start: int = ..., + end: int | None = ..., + step: int = ..., + /, + *, + dtype: IntegerType = ..., + chunked: Literal[False], +) -> Array[IntegerScalar]: ... +def int_range( + start: int = 0, + end: int | None = None, + step: int = 1, + /, + *, + dtype: IntegerType = I64, + chunked: bool = True, +) -> ChunkedOrArray[IntegerScalar]: + """Generate a range of integers.""" + if end is None: + end = start + start = 0 + if not compat.HAS_ARANGE: # pragma: no cover + import numpy as np # ignore-banned-import + + arr = pa.array(np.arange(start, end, step), type=dtype) + else: + int_range_: Incomplete = pa.arange # type: ignore[attr-defined] + arr = t.cast("ArrayAny", int_range_(start, end, step)).cast(dtype) + return arr if not chunked else pa.chunked_array([arr]) + + +def date_range( + start: dt.date, end: dt.date, interval: int, *, closed: ClosedInterval = "both" +) -> ChunkedArray[DateScalar]: + """Generate a range of dates. + + Note: + `interval` is the number of full days. + """ + start_i = pa.scalar(start).cast(I32).as_py() + end_i = pa.scalar(end).cast(I32).as_py() + ca = int_range(start_i, end_i + 1, interval, dtype=I32) + if closed == "both": + return ca.cast(DATE) + if closed == "left": + ca = ca.slice(length=ca.length() - 1) + elif closed == "none": + ca = ca.slice(1, length=ca.length() - 1) + else: + ca = ca.slice(1) + return ca.cast(DATE) + + +def linear_space( + start: float, end: float, num_samples: int, *, closed: ClosedInterval = "both" +) -> ChunkedArray[pc.NumericScalar]: + """Generate a range of evenly-spaced floats. + + Based on [`new_linear_space_f64`]. + + [`new_linear_space_f64`]: https://github.com/pola-rs/polars/blob/1684cc09dfaa46656dfecc45ab866d01aa69bc78/crates/polars-ops/src/series/ops/linear_space.rs#L62-L94 + """ + if num_samples < 0: + msg = f"Number of samples, {num_samples}, must be non-negative." + raise ValueError(msg) + if num_samples == 0: + return chunked_array([], F64) + if num_samples == 1: + if closed == "none": + value = (end + start) * 0.5 + elif closed in {"left", "both"}: + value = float(start) + else: + value = float(end) + return chunked_array(lit(value, F64)) + n = num_samples + span = float(end - start) + if closed == "none": + d = span / (n + 1) + start = start + d + elif closed == "left": + d = span / n + elif closed == "right": + start = start + span / n + d = span / n + else: + d = span / (n - 1) + ca = multiply(int_range(0, n).cast(F64), lit(d)) + ca = add(ca, lit(start, F64)) + return ca # noqa: RET504 diff --git a/narwhals/_plan/arrow/functions/_repeat.py b/narwhals/_plan/arrow/functions/_repeat.py new file mode 100644 index 0000000000..17726697e2 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_repeat.py @@ -0,0 +1,66 @@ +"""Create known-length `pa.Array`s, filled with a single value. + +Note: + These wrappers should be preferred when the lack of precision in input types causes + false negatives and/or LSP hangs in the `pyarrow-stubs` overloads. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pyarrow as pa # ignore-banned-import + +from narwhals._plan.arrow.functions._construction import lit + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from narwhals._plan.arrow.typing import ArrayAny, ArrowAny, ScalarAny + from narwhals.typing import NonNestedLiteral + +Incomplete: TypeAlias = Any + +__all__ = ["nulls_like", "repeat", "repeat_like", "repeat_unchecked", "zeros"] + + +def nulls_like(n: int, /, native: ArrowAny) -> ArrayAny: + """Create an Array of length `n` filled with nulls. + + Uses the type of `native`, where `pa.nulls` defaults to `pa.NullType`. + """ + result: ArrayAny = pa.nulls(n, native.type) + return result + + +def repeat(value: ScalarAny | NonNestedLiteral, /, n: int) -> ArrayAny: + """Create an Array of length `n` filled with the given `value`. + + Adds an additional check and coerces `NonNestedLiteral` through `pa.Scalar`. + + Tip: + If you *already* know `pa.Scalar` is the only possible input, + use `repeat_unchecked` instead. + """ + value = value if isinstance(value, pa.Scalar) else lit(value) + return repeat_unchecked(value, n) + + +def repeat_like(value: NonNestedLiteral, /, n: int, native: ArrowAny) -> ArrayAny: + """Create an Array of length `n` filled with the given `value`. + + Uses the type of `native`. + """ + return repeat_unchecked(lit(value, native.type), n) + + +def repeat_unchecked(value: ScalarAny, /, n: int) -> ArrayAny: + """Create an Array of length `n` filled with the given `value`.""" + repeat_: Incomplete = pa.repeat + result: ArrayAny = repeat_(value, n) + return result + + +def zeros(n: int, /) -> pa.Int64Array: + """Create an Array of length `n` filled with zeros.""" + return pa.repeat(0, n) diff --git a/narwhals/_plan/arrow/functions/_round.py b/narwhals/_plan/arrow/functions/_round.py new file mode 100644 index 0000000000..ad0d98592f --- /dev/null +++ b/narwhals/_plan/arrow/functions/_round.py @@ -0,0 +1,61 @@ +"""Round underlying floating point data. + +This group is derived from the (rust) polars [feature] [`round_series`]. + +[feature]: https://docs.rs/polars/latest/polars/#compile-times-and-opt-in-features +[`round_series`]: https://github.com/search?q=repo%3Apola-rs%2Fpolars+path%3A%2F%5Ecrates%5C%2Fpolars-plan%5C%2Fsrc%5C%2Fdsl%5C%2F%2F+%23%5Bcfg%28feature+%3D+%22round_series%22%29%5D&type=code +""" + +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING, overload + +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan.arrow.functions._horizontal import max_horizontal, min_horizontal + +if TYPE_CHECKING: + from narwhals._plan.arrow.typing import ( + ArrowAny, + ChunkedOrArrayT, + ChunkedOrScalarAny, + UnaryNumeric, + ) + +__all__ = ["ceil", "clip", "clip_lower", "clip_upper", "floor", "round"] + +ceil = t.cast("UnaryNumeric", pc.ceil) +"""Rounds up to the nearest integer value.""" +floor = t.cast("UnaryNumeric", pc.floor) +"""Rounds down to the nearest integer value.""" + + +@overload +def round(native: ChunkedOrScalarAny, decimals: int = 0) -> ChunkedOrScalarAny: ... +@overload +def round(native: ChunkedOrArrayT, decimals: int = 0) -> ChunkedOrArrayT: ... +def round(native: ArrowAny, decimals: int = 0) -> ArrowAny: + """Round underlying floating point data by `decimals` digits.""" + return pc.round(native, decimals, round_mode="half_towards_infinity") + + +def clip_lower( + native: ChunkedOrScalarAny, lower: ChunkedOrScalarAny +) -> ChunkedOrScalarAny: + """Limit values to at-least `lower`.""" + return max_horizontal(native, lower) + + +def clip_upper( + native: ChunkedOrScalarAny, upper: ChunkedOrScalarAny +) -> ChunkedOrScalarAny: + """Limit values to at-most `upper`.""" + return min_horizontal(native, upper) + + +def clip( + native: ChunkedOrScalarAny, lower: ChunkedOrScalarAny, upper: ChunkedOrScalarAny +) -> ChunkedOrScalarAny: + """Set values outside the given boundaries to the boundary value.""" + return clip_lower(clip_upper(native, upper), lower) diff --git a/narwhals/_plan/arrow/functions/_sort.py b/narwhals/_plan/arrow/functions/_sort.py new file mode 100644 index 0000000000..f900e574b8 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_sort.py @@ -0,0 +1,149 @@ +"""Functions for manipulating the order of arrays.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, overload + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan.arrow import compat, options as pa_options +from narwhals._plan.arrow.functions._arithmetic import multiply +from narwhals._plan.arrow.functions._construction import array, lit +from narwhals._plan.arrow.functions._dtypes import I64 +from narwhals._plan.arrow.functions._ranges import int_range +from narwhals._plan.arrow.functions._round import round +from narwhals._plan.arrow.functions.meta import call + +if TYPE_CHECKING: + from collections.abc import Sequence + + from typing_extensions import Unpack + + from narwhals._plan.arrow.typing import ArrayAny, ChunkedOrArrayAny, ChunkedOrArrayT + from narwhals._plan.options import SortMultipleOptions, SortOptions + + +__all__ = ["random_indices", "reverse", "sort_indices", "unsort_indices"] + + +@overload +def sort_indices( + native: ChunkedOrArrayAny, *, options: SortOptions | None +) -> pa.UInt64Array: ... +@overload +def sort_indices( + native: ChunkedOrArrayAny, *, descending: bool = ..., nulls_last: bool = ... +) -> pa.UInt64Array: ... +@overload +def sort_indices( + native: pa.Table, + *by: Unpack[tuple[str, Unpack[tuple[str, ...]]]], + options: SortOptions | SortMultipleOptions | None, +) -> pa.UInt64Array: ... +@overload +def sort_indices( + native: pa.Table, + *by: Unpack[tuple[str, Unpack[tuple[str, ...]]]], + descending: bool | Sequence[bool] = ..., + nulls_last: bool = ..., +) -> pa.UInt64Array: ... +def sort_indices( + native: ChunkedOrArrayAny | pa.Table, + *by: str, + options: SortOptions | SortMultipleOptions | None = None, + descending: bool | Sequence[bool] = False, + nulls_last: bool = False, +) -> pa.UInt64Array: + """Return the indices that would sort an array or table. + + Arguments: + native: Any non-scalar arrow data. + *by: Column(s) to sort by. Only applicable to `Table` and must use at least one name. + options: An *already-parsed* options instance. + **Has higher precedence** than `descending` and `nulls_last`. + descending: Sort in descending order. When sorting by multiple columns, + can be specified per column by passing a sequence of booleans. + nulls_last: Place null values last. + + Notes: + Most commonly used as input for `take`, which forms a `sort_by` operation. + """ + if not isinstance(native, pa.Table): + if options: + descending = options.descending + nulls_last = options._ensure_single_nulls_last("pyarrow") + a_opts = pa_options.array_sort(descending=descending, nulls_last=nulls_last) + return pc.array_sort_indices(native, options=a_opts) + opts = ( + options.to_arrow(by) + if options + else pa_options.sort(*by, descending=descending, nulls_last=nulls_last) + ) + return pc.sort_indices(native, options=opts) + + +def unsort_indices(indices: pa.UInt64Array, /) -> pa.Int64Array: + """Return the inverse permutation of the given indices. + + Arguments: + indices: The output of `sort_indices`. + + Examples: + We can use this pair of functions to recreate a windowed [`pl.row_index`] + + >>> import polars as pl + >>> data = {"by": [5, 2, 5, None]} + >>> df = pl.DataFrame(data) + >>> df.select( + ... pl.row_index().over(order_by="by", descending=True, nulls_last=False) + ... ).to_series().to_list() + [1, 3, 2, 0] + + Now in `pyarrow` + + >>> import pyarrow as pa + >>> from narwhals._plan.arrow.functions import sort_indices, unsort_indices + >>> df = pa.Table.from_pydict(data) + >>> unsort_indices( + ... sort_indices(df, "by", descending=True, nulls_last=False) + ... ).to_pylist() + [1, 3, 2, 0] + + [`pl.row_index`]: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.row_index.html + """ + return ( + call("inverse_permutation", indices.cast(I64)) + if compat.HAS_SCATTER + else int_range(len(indices), chunked=False).take(pc.sort_indices(indices)) + ) + + +def random_indices( + end: int, /, n: int, *, with_replacement: bool = False, seed: int | None = None +) -> ArrayAny: + """Generate `n` random indices within the range `[0, end)`. + + Note: + Review this path if anything changes [upstream]. + + [upstream]: https://github.com/apache/arrow/issues/47288#issuecomment-3597653670 + """ + if with_replacement: + rand_values = pc.random(n, initializer="system" if seed is None else seed) + return round(multiply(rand_values, lit(end - 1))).cast(I64) + + import numpy as np # ignore-banned-import + + return array(np.random.default_rng(seed).choice(np.arange(end), n, replace=False)) + + +def reverse(native: ChunkedOrArrayT) -> ChunkedOrArrayT: + """Return the array in reverse order. + + Important: + Implemented via slicing, but unlike other slicing operations, [triggers a full-copy]. + + [triggers a full-copy]: https://github.com/apache/arrow/issues/19103#issuecomment-1377671886 + """ + return native[::-1] diff --git a/narwhals/_plan/arrow/functions/_strings.py b/narwhals/_plan/arrow/functions/_strings.py new file mode 100644 index 0000000000..a6e0a5518f --- /dev/null +++ b/narwhals/_plan/arrow/functions/_strings.py @@ -0,0 +1,367 @@ +"""String namespace functions.""" + +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING, Any, Final, overload + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan.arrow import compat, options as pa_options +from narwhals._plan.arrow.functions import _lists as list_ +from narwhals._plan.arrow.functions._aggregation import implode +from narwhals._plan.arrow.functions._bin_op import and_, eq, lt +from narwhals._plan.arrow.functions._boolean import all, any +from narwhals._plan.arrow.functions._construction import ( + array, + chunked_array, + concat_horizontal, + lit, +) +from narwhals._plan.arrow.functions._dtypes import string_type +from narwhals._plan.arrow.functions._multiplex import replace_with_mask, when_then +from narwhals._plan.arrow.functions._repeat import repeat_unchecked +from narwhals._plan.arrow.functions.meta import call + +if TYPE_CHECKING: + from collections.abc import Callable + + from typing_extensions import TypeAlias + + from narwhals._arrow.typing import Incomplete + from narwhals._plan.arrow.typing import ( + Array, + ArrayAny, + Arrow, + ArrowAny, + ChunkedArray, + ChunkedArrayAny, + ChunkedOrScalar, + ChunkedOrScalarAny, + IntegerScalar, + ListScalar, + ScalarAny, + StringScalar, + ) + + _StringFunction0: TypeAlias = "Callable[[ChunkedOrScalarAny], ChunkedOrScalarAny]" + _StringFunction1: TypeAlias = ( + "Callable[[ChunkedOrScalarAny, str], ChunkedOrScalarAny]" + ) + + +__all__ = [ + "concat_str", + "contains", + "ends_with", + "find", + "join", + "len_chars", + "pad_start", + "replace", + "replace_all", + "replace_vector", + "slice", + "split", + "splitn", + "starts_with", + "strip_chars", + "to_lowercase", + "to_titlecase", + "to_uppercase", + "zfill", +] + +starts_with = t.cast("_StringFunction1", pc.starts_with) +"""Check if string values start with a substring.""" + +ends_with = t.cast("_StringFunction1", pc.ends_with) +"""Check if string values end with a substring.""" + +to_uppercase = t.cast("_StringFunction0", pc.utf8_upper) +"""Modify strings to their uppercase equivalent.""" + +to_lowercase = t.cast("_StringFunction0", pc.utf8_lower) +"""Modify strings to their lowercase equivalent.""" + +to_titlecase = t.cast("_StringFunction0", pc.utf8_title) +"""Modify strings to their titlecase equivalent.""" + + +@overload +def concat_str( + *arrays: ChunkedArrayAny, separator: str = ..., ignore_nulls: bool = ... +) -> ChunkedArray[StringScalar]: ... +@overload +def concat_str( + *arrays: ArrayAny, separator: str = ..., ignore_nulls: bool = ... +) -> Array[StringScalar]: ... +@overload +def concat_str( + *arrays: ScalarAny, separator: str = ..., ignore_nulls: bool = ... +) -> StringScalar: ... +def concat_str( + *arrays: ArrowAny, separator: str = "", ignore_nulls: bool = False +) -> Arrow[StringScalar]: + """Horizontally concatenate arrow data into a single string column.""" + dtype = string_type(obj.type for obj in arrays) + it = (obj.cast(dtype) for obj in arrays) + sep = lit(separator, dtype) + join = pa_options.join(ignore_nulls=ignore_nulls) + result: Arrow[StringScalar] = call("binary_join_element_wise", *it, sep, options=join) + return result + + +def join( + native: Arrow[StringScalar], separator: str, *, ignore_nulls: bool = True +) -> StringScalar: + """Vertically concatenate the string values in the column to a single string value.""" + if isinstance(native, pa.Scalar): + # already joined + return native + if ignore_nulls and native.null_count: + native = native.drop_null() + return list_.join_scalar(implode(native), separator, ignore_nulls=False) + + +def len_chars(native: ChunkedOrScalarAny) -> ChunkedOrScalarAny: + """Return the length of each string as the number of characters.""" + result: ChunkedOrScalarAny = call("utf8_length", native) + return result + + +def slice( + native: ChunkedOrScalarAny, offset: int, length: int | None = None +) -> ChunkedOrScalarAny: + """Extract a substring from each string value.""" + stop = length if length is None else offset + length + return pc.utf8_slice_codeunits(native, offset, stop=stop) + + +def pad_start( + native: ChunkedOrScalarAny, length: int, fill_char: str = " " +) -> ChunkedOrScalarAny: # pragma: no cover + """Pad the start of the string until it reaches the given length.""" + return pc.utf8_lpad(native, length, fill_char) + + +@overload +def find( + native: ChunkedArrayAny, + pattern: str, + *, + literal: bool = ..., + not_found: int | None = ..., +) -> ChunkedArray[IntegerScalar]: ... +@overload +def find( + native: Array, pattern: str, *, literal: bool = ..., not_found: int | None = ... +) -> Array[IntegerScalar]: ... +@overload +def find( + native: ScalarAny, pattern: str, *, literal: bool = ..., not_found: int | None = ... +) -> IntegerScalar: ... +def find( + native: Arrow[StringScalar], + pattern: str, + *, + literal: bool = False, + not_found: int | None = -1, +) -> Arrow[IntegerScalar]: + """Return the bytes offset of the first substring matching a pattern. + + To match `pl.Expr.str.find` behavior, pass `not_found=None`. + + Note: + `pyarrow` distinguishes null *inputs* with `None` and failed matches with `-1`. + """ + # NOTE: `pyarrow-stubs` uses concrete types here + name = "find_substring" if literal else "find_substring_regex" + result: Arrow[IntegerScalar] = call( + name, native, options=pa_options.match_substring(pattern) + ) + if not_found == -1: + return result + return when_then(eq(result, lit(-1)), lit(not_found, result.type), result) + + +def _split( + native: ArrowAny, by: str, n: int | None = None, *, literal: bool = True +) -> Arrow[ListScalar]: + name = "split_pattern" if literal else "split_pattern_regex" + result: Arrow[ListScalar] = call( + name, native, options=pa_options.split_pattern(by, n) + ) + return result + + +@overload +def split( + native: ChunkedArrayAny, by: str, *, literal: bool = ... +) -> ChunkedArray[ListScalar]: ... +@overload +def split( + native: ChunkedOrScalarAny, by: str, *, literal: bool = ... +) -> ChunkedOrScalar[ListScalar]: ... +@overload +def split(native: ArrayAny, by: str, *, literal: bool = ...) -> pa.ListArray[Any]: ... +@overload +def split(native: ArrowAny, by: str, *, literal: bool = ...) -> Arrow[ListScalar]: ... +def split(native: ArrowAny, by: str, *, literal: bool = True) -> Arrow[ListScalar]: + """Split the string by a substring.""" + return _split(native, by, literal=literal) + + +# TODO @dangotbanned: Support and default to `as_struct=True` +# `polars` would return a struct w/ field names (`'field_0', ..., 'field_n-1'`) +@overload +def splitn( + native: ChunkedArrayAny, + by: str, + n: int, + *, + literal: bool = ..., + as_struct: bool = ..., +) -> ChunkedArray[ListScalar]: ... +@overload +def splitn( + native: ChunkedOrScalarAny, + by: str, + n: int, + *, + literal: bool = ..., + as_struct: bool = ..., +) -> ChunkedOrScalar[ListScalar]: ... +@overload +def splitn( + native: ArrayAny, by: str, n: int, *, literal: bool = ..., as_struct: bool = ... +) -> pa.ListArray[Any]: ... +@overload +def splitn( + native: ArrowAny, by: str, n: int, *, literal: bool = ..., as_struct: bool = ... +) -> Arrow[ListScalar]: ... +def splitn( + native: ArrowAny, by: str, n: int, *, literal: bool = True, as_struct: bool = False +) -> Arrow[ListScalar]: + """Split the string by a substring, restricted to returning at most `n` items.""" + result = _split(native, by, n, literal=literal) + if as_struct: + msg = "TODO: `ArrowExpr.str.splitn`" + raise NotImplementedError(msg) + return result + + +@overload +def contains( + native: ChunkedArrayAny, pattern: str, *, literal: bool = ... +) -> ChunkedArray[pa.BooleanScalar]: ... +@overload +def contains( + native: ChunkedOrScalarAny, pattern: str, *, literal: bool = ... +) -> ChunkedOrScalar[pa.BooleanScalar]: ... +@overload +def contains( + native: ArrowAny, pattern: str, *, literal: bool = ... +) -> Arrow[pa.BooleanScalar]: ... +def contains( + native: ArrowAny, pattern: str, *, literal: bool = False +) -> Arrow[pa.BooleanScalar]: + """Check if the string contains a substring that matches a pattern.""" + name = "match_substring" if literal else "match_substring_regex" + result: Arrow[pa.BooleanScalar] = call( + name, native, options=pa_options.match_substring(pattern) + ) + return result + + +def strip_chars(native: Incomplete, characters: str | None) -> Incomplete: + """Remove leading and trailing characters.""" + if characters: + return pc.utf8_trim(native, characters) + return pc.utf8_trim_whitespace(native) + + +def replace( + native: Incomplete, pattern: str, value: str, *, literal: bool = False, n: int = 1 +) -> Incomplete: + """Replace the first matching regex/literal substring with a new string value.""" + fn = pc.replace_substring if literal else pc.replace_substring_regex + return fn(native, pattern, replacement=value, max_replacements=n) + + +def replace_all( + native: Incomplete, pattern: str, value: str, *, literal: bool = False +) -> Incomplete: + """Replace all matching regex/literal substrings with a new string value.""" + return replace(native, pattern, value, literal=literal, n=-1) + + +def replace_vector( + native: ChunkedArrayAny, + pattern: str, + replacements: ChunkedArrayAny, + *, + literal: bool = False, + n: int | None = 1, +) -> ChunkedArrayAny: + """Replace the first matching regex/literal substring with the adjacent string in `replacements`.""" + has_match = contains(native, pattern, literal=literal) + if not any(has_match).as_py(): + # fastpath, no work to do + return native + match, match_replacements = ( + concat_horizontal([native, replacements], ["0", "1"]).filter(has_match).columns + ) + if n is None or n == -1: + list_split_by = split(match, pattern, literal=literal) + else: + list_split_by = splitn(match, pattern, n + 1, literal=literal) + replaced = list_.join(list_split_by, match_replacements, ignore_nulls=False) + if all(has_match, ignore_nulls=False).as_py(): + return chunked_array(replaced) + return replace_with_mask(native, has_match, array(replaced)) + + +def zfill(native: ChunkedOrScalarAny, length: int) -> ChunkedOrScalarAny: + """Pad the start of the string with zeros until it reaches the given length.""" + if compat.HAS_ZFILL: + zfill: Incomplete = pc.utf8_zero_fill # type: ignore[attr-defined] + result: ChunkedOrScalarAny = zfill(native, length) + else: + result = _zfill_compat(native, length) + return result + + +# TODO @dangotbanned: Finish tidying this up +def _zfill_compat( + native: ChunkedOrScalarAny, length: int +) -> Incomplete: # pragma: no cover + dtype = string_type([native.type]) + hyphen, plus = lit("-", dtype), lit("+", dtype) + + padded_remaining = pad_start(slice(native, 1), length - 1, "0") + padded_lt_length = pad_start(native, length, "0") + + if isinstance(native, pa.Scalar): + case_1: ArrowAny = hyphen # starts with hyphen and less than length + case_2: ArrowAny = plus # starts with plus and less than length + else: + arr_len = len(native) + case_1 = repeat_unchecked(hyphen, arr_len) + case_2 = repeat_unchecked(plus, arr_len) + + first_char = slice(native, 0, 1) + lt_length = lt(len_chars(native), lit(length)) + first_hyphen_lt_length = and_(eq(first_char, hyphen), lt_length) + first_plus_lt_length = and_(eq(first_char, plus), lt_length) + join_: Final = "binary_join_element_wise" + return when_then( + first_hyphen_lt_length, + call(join_, case_1, padded_remaining, ""), + when_then( + first_plus_lt_length, + call(join_, case_2, padded_remaining, ""), + when_then(lt_length, padded_lt_length, native), + ), + ) diff --git a/narwhals/_plan/arrow/functions/_struct.py b/narwhals/_plan/arrow/functions/_struct.py new file mode 100644 index 0000000000..d9861ab904 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_struct.py @@ -0,0 +1,133 @@ +"""Struct function namespace, and some helpers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, overload + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan import common +from narwhals._plan.arrow import compat +from narwhals._plan.arrow.functions.meta import call +from narwhals._plan.arrow.guards import is_arrow + +if TYPE_CHECKING: + from collections.abc import Iterable + + from narwhals._plan.arrow.acero import Field + from narwhals._plan.arrow.typing import ( + ArrayAny, + Arrow, + ArrowAny, + ChunkedArrayAny, + ChunkedOrScalarAny, + ChunkedStruct, + SameArrowT, + ScalarAny, + Struct, + StructArray, + ) + from narwhals._plan.typing import Seq + from narwhals.typing import NonNestedLiteral + +__all__ = ["field", "field_names", "fields", "into_struct", "schema"] + + +@overload +def into_struct( + columns: Iterable[ChunkedArrayAny], names: Iterable[str] +) -> ChunkedStruct: ... +@overload +def into_struct(columns: Iterable[ArrayAny], names: Iterable[str]) -> pa.StructArray: ... +@overload +def into_struct( + columns: Iterable[ScalarAny], names: Iterable[str] +) -> pa.StructScalar: ... +@overload +def into_struct( + columns: Iterable[ChunkedArrayAny | NonNestedLiteral], names: Iterable[str] +) -> ChunkedStruct: ... +def into_struct( + columns: Iterable[ArrowAny | NonNestedLiteral], names: Iterable[str] +) -> Struct: + """Collect columns into a struct. + + Arguments: + columns: Value(s) to collect into a struct. Scalars will will be broadcast unless all + inputs are scalar. + names: Name(s) to assign to each struct field. + + Note: + Roughly [`polars.struct`] but `names` must be resolved ahead of time. + + [`polars.struct`]: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.struct.html + """ + options = _make_names(common.ensure_seq_str(names)) + result: Struct = call("make_struct", *columns, options=options) + return result + + +def schema(native: Arrow[pa.StructScalar] | pa.StructType, /) -> pa.Schema: + """Get the struct definition as a schema. + + Arguments: + native: Struct-typed arrow data, or a `StructType` *itself*. + """ + tp = native.type if is_arrow(native) else native + fields = tp.fields if compat.HAS_STRUCT_TYPE_FIELDS else list(tp) + return pa.schema(fields) + + +def field_names(native: Arrow[pa.StructScalar] | pa.StructType, /) -> list[str]: + """Get the names of each field in a struct. + + Arguments: + native: Struct-typed arrow data, or a `StructType` *itself*. + """ + tp = native.type if is_arrow(native) else native + return tp.names if compat.HAS_STRUCT_TYPE_FIELDS else [f.name for f in tp] + + +@overload +def field(native: ChunkedStruct, name: Field, /) -> ChunkedArrayAny: ... +@overload +def field(native: StructArray, name: Field, /) -> ArrayAny: ... +@overload +def field(native: pa.StructScalar, name: Field, /) -> ScalarAny: ... +@overload +def field(native: SameArrowT, name: Field, /) -> SameArrowT: ... +@overload +def field(native: ChunkedOrScalarAny, name: Field, /) -> ChunkedOrScalarAny: ... +def field(native: ArrowAny, name: Field, /) -> ArrowAny: + """Retrieve a single field from a struct as a new array/scalar. + + Arguments: + native: Struct-typed arrow data. + name: Name of the struct field to retrieve. + """ + result: ArrowAny = call("struct_field", native, options=_get_name(name)) + return result + + +@overload +def fields(native: ChunkedStruct, *names: Field) -> Seq[ChunkedArrayAny]: ... +@overload +def fields(native: StructArray, *names: Field) -> Seq[ArrayAny]: ... +@overload +def fields(native: pa.StructScalar, *names: Field) -> Seq[ScalarAny]: ... +@overload +def fields(native: SameArrowT, *names: Field) -> Seq[SameArrowT]: ... +def fields(native: ArrowAny, *names: Field) -> Seq[ArrowAny]: + """Retrieve multiple fields from a struct as new array/scalar(s). + + Arguments: + native: Struct-typed arrow data. + names: Names of the struct fields to retrieve. + """ + f = pc.get_function("struct_field") + return tuple["ArrowAny", ...](f.call([native], _get_name(nm)) for nm in names) + + +_make_names = pc.MakeStructOptions +_get_name = pc.StructFieldOptions diff --git a/narwhals/_plan/arrow/functions/_vector.py b/narwhals/_plan/arrow/functions/_vector.py new file mode 100644 index 0000000000..98dd566409 --- /dev/null +++ b/narwhals/_plan/arrow/functions/_vector.py @@ -0,0 +1,180 @@ +"""Non-scalar functions, which need to observe the context surrounding each element. + +Currently a subset of Arrow's [Array-wise ("vector") functions]. + +[Array-wise ("vector") functions]: https://arrow.apache.org/docs/cpp/compute.html#array-wise-vector-functions +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, overload + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +from narwhals._plan.arrow import compat, options as pa_options +from narwhals._plan.arrow.functions import _struct as struct +from narwhals._plan.arrow.functions._bin_op import not_eq +from narwhals._plan.arrow.functions._boolean import is_between, is_in +from narwhals._plan.arrow.functions._construction import array, chunked_array, lit +from narwhals._plan.arrow.functions._dtypes import BOOL, F64 +from narwhals._plan.arrow.functions._multiplex import ( + preserve_nulls, + replace_with_mask, + when_then, +) +from narwhals._plan.arrow.functions._ranges import int_range, linear_space +from narwhals._plan.arrow.functions._repeat import repeat_like, zeros + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping, Sequence + + from narwhals._plan.arrow.typing import ( + ArrayAny, + ChunkedArray, + ChunkedArrayAny, + ChunkedOrArray, + ChunkedOrArrayT, + NumericScalar, + ScalarAny, + SearchSortedSide, + ) + from narwhals._plan.options import RankOptions + from narwhals.typing import NonNestedLiteral + +__all__ = ["diff", "hist_bins", "hist_zeroed_data", "rank", "search_sorted", "shift"] + + +def diff(native: ChunkedOrArrayT, n: int = 1) -> ChunkedOrArrayT: + """Calculate the first discrete difference between shifted items. + + Arguments: + native: An arrow array. + n: Number of slots to shift. + """ + return ( + pc.pairwise_diff(native, n) + if isinstance(native, pa.Array) + else chunked_array(pc.pairwise_diff(native.combine_chunks(), n)) + ) + + +def shift( + native: ChunkedOrArrayT, n: int, *, fill_value: NonNestedLiteral = None +) -> ChunkedOrArrayT: + """Shift values by the given number of indices. + + Arguments: + native: An arrow array. + n: Number of indices to shift forward. If a negative value is passed, values + are shifted in the opposite direction instead. + fill_value: Fill the resulting null values with this value. + """ + if n == 0: + return native + n_abs = abs(n) + filled = repeat_like(fill_value, n_abs, native) + forward = n > 0 + sliced = native.slice(length=len(native) - n) if forward else native.slice(n_abs) + if isinstance(sliced, pa.ChunkedArray): + chunks: list[ArrayAny] = sliced.chunks + return pa.chunked_array((filled, *chunks) if forward else (*chunks, filled)) + return pa.concat_arrays((filled, sliced) if forward else (sliced, filled)) + + +def rank(native: ChunkedOrArrayT, options: RankOptions) -> ChunkedOrArrayT: + """Assign ranks to `native`, dealing with ties according to `options`.""" + arr = native if compat.RANK_ACCEPTS_CHUNKED else array(native) + if options.method == "average": + # Adapted from https://github.com/pandas-dev/pandas/blob/f4851e500a43125d505db64e548af0355227714b/pandas/core/arrays/arrow/array.py#L2290-L2316 + order = pa_options.ORDER[options.descending] + min = preserve_nulls(arr, pc.rank(arr, order, tiebreaker="min").cast(F64)) + max = pc.rank(arr, order, tiebreaker="max").cast(F64) + ranked = pc.divide(pc.add(min, max), lit(2, F64)) + else: + ranked = preserve_nulls(native, pc.rank(arr, options=options.to_arrow())) + if isinstance(native, pa.ChunkedArray): + return chunked_array(ranked) + return ranked + + +# NOTE @dangotbanned: (wish) replacing `np.searchsorted`? +@overload +def search_sorted( + native: ChunkedOrArrayT, + element: ChunkedOrArray[NumericScalar] | Sequence[float], + *, + side: SearchSortedSide = ..., +) -> ChunkedOrArrayT: ... +# NOTE: scalar case may work with only `partition_nth_indices`? +@overload +def search_sorted( + native: ChunkedOrArrayT, element: float, *, side: SearchSortedSide = ... +) -> ScalarAny: ... +def search_sorted( + native: ChunkedOrArrayT, + element: ChunkedOrArray[NumericScalar] | Sequence[float] | float, + *, + side: SearchSortedSide = "left", +) -> ChunkedOrArrayT | ScalarAny: + """Find indices where elements should be inserted to maintain order.""" + import numpy as np # ignore-banned-import + + indices = np.searchsorted(element, native, side=side) + if isinstance(indices, np.generic): + return lit(indices) + if isinstance(native, pa.ChunkedArray): + return chunked_array(indices) + return array(indices) + + +def hist_bins( + native: ChunkedArrayAny, + bins: Sequence[float] | ChunkedArray[NumericScalar], + *, + include_breakpoint: bool, +) -> Mapping[str, Iterable[Any]]: + """Bin values into buckets and count their occurrences. + + Notes: + Assumes that the following edge cases have been handled: + - `len(bins) >= 2` + - `bins` increase monotonically + - `bin[0] != bin[-1]` + - `native` contains values that are non-null (including NaN) + """ + if len(bins) == 2: + upper = bins[1] + count = array(is_between(native, bins[0], upper, closed="both"), BOOL).true_count + if include_breakpoint: + return {"breakpoint": [upper], "count": [count]} + return {"count": [count]} + + # lowest bin is inclusive + # NOTE: `np.unique` behavior sorts first + value_counts = ( + when_then(not_eq(native, lit(bins[0])), search_sorted(native, bins), 1) + .sort() + .value_counts() + ) + values, counts = struct.fields(value_counts, "values", "counts") + bin_count = len(bins) + int_range_ = int_range(1, bin_count, chunked=False) + mask = is_in(int_range_, values) + replacements = counts.filter(is_in(values, int_range_)) + counts = replace_with_mask(zeros(bin_count - 1), mask, replacements) + + if include_breakpoint: + return {"breakpoint": bins[1:], "count": counts} + return {"count": counts} + + +def hist_zeroed_data( + arg: int | Sequence[float], *, include_breakpoint: bool +) -> Mapping[str, Iterable[Any]]: + # NOTE: If adding `linear_space` and `zeros` to `CompliantNamespace`, consider moving this. + n = arg if isinstance(arg, int) else len(arg) - 1 + if not include_breakpoint: + return {"count": zeros(n)} + bp = linear_space(0, 1, arg, closed="right") if isinstance(arg, int) else arg[1:] + return {"breakpoint": bp, "count": zeros(n)} diff --git a/narwhals/_plan/arrow/functions/meta.py b/narwhals/_plan/arrow/functions/meta.py new file mode 100644 index 0000000000..631d2d584f --- /dev/null +++ b/narwhals/_plan/arrow/functions/meta.py @@ -0,0 +1,51 @@ +"""Functions about functions.""" + +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +if TYPE_CHECKING: + from typing_extensions import LiteralString, TypeAlias + + from narwhals._plan.arrow.typing import ArrowAny + from narwhals.typing import PythonLiteral + + +__all__ = ["call"] + +Incomplete: TypeAlias = t.Any + +_PackComputeArgsElement: TypeAlias = ( + "PythonLiteral | ArrowAny | pa.RecordBatch | pa.Table" +) +"""[`_pack_compute_args`] covers every possible input types to a `pyarrow.compute` function. + +This version just excludes `np.ndarray` (*for now*). + +[`_pack_compute_args`]: https://github.com/apache/arrow/blob/29586f4d28c50a4344f14a78dc7e091ab635fa72/python/pyarrow/_compute.pyx#L488-L520 +""" + + +def call( + name: LiteralString, + *args: _PackComputeArgsElement, + options: pc.FunctionOptions | None = None, +) -> Incomplete: + """Call a [`pyarrow.compute`] function by name. + + Escape hatch to use when typing falls apart. + + Arguments: + name: Name of the function to call. + *args: Arguments to the function. + options: A [`pc.FunctionOptions`] instance to pass to the function. + + [`pyarrow.compute`]: https://arrow.apache.org/docs/dev/python/generated/pyarrow.compute.call_function.html + [`pc.FunctionOptions`]: https://arrow.apache.org/docs/dev/python/api/compute.html#compute-options + """ + call_function: Incomplete = pc.call_function + return call_function(name, args, options=options) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index 93bc2c2a50..a2ab4a9d88 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -17,25 +17,37 @@ from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Mapping, Sequence + from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from typing_extensions import Self, TypeAlias from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.typing import ( - ArrayAny, + BooleanLengthPreserving, ChunkedArray, ChunkedArrayAny, ChunkedList, ChunkedOrScalarAny, + ChunkedStruct, Indices, ListScalar, ScalarAny, ) from narwhals._plan.expressions import NamedIR from narwhals._plan.typing import Seq + from narwhals.typing import UniqueKeepStrategy Incomplete: TypeAlias = Any +IntoColumnAgg: TypeAlias = "Callable[[str], ir.AggExpr]" +"""Helper constructor for single-column aggregations.""" + + +class MinMax(ir.AggExpr): + """Returns a `Struct({'min': ..., 'max': ...})`. + + https://arrow.apache.org/docs/python/generated/pyarrow.compute.min_max.html#pyarrow.compute.min_max + """ + SUPPORTED_AGG: Mapping[type[agg.AggExpr], acero.Aggregation] = { agg.Sum: "hash_sum", @@ -50,7 +62,7 @@ agg.NUnique: "hash_count_distinct", agg.First: "hash_first", agg.Last: "hash_last", - fn.MinMax: "hash_min_max", + MinMax: "hash_min_max", } SUPPORTED_LIST_AGG: Mapping[type[ir.lists.Aggregation], type[agg.AggExpr]] = { ir.lists.Mean: agg.Mean, @@ -225,14 +237,14 @@ def agg_list(self, native: ChunkedList | ListScalar) -> ChunkedOrScalarAny: func = HASH_TO_SCALAR_NAME[self._function] if not scalar.is_valid: return fn.lit(None, SCALAR_OUTPUT_TYPE.get(func, scalar.type.value_type)) - result = pc.call_function(func, [scalar.values], self._option) + result = fn.meta.call(func, scalar.values, options=self._option) return result result = self.over_index(fn.ExplodeBuilder().explode_with_indices(native), "idx") result = fn.when_then(native.is_valid(), result) if self._is_n_unique(): # NOTE: Exploding `[]` becomes `[None]` - so we need to adjust the unique count *iff* we were unlucky - is_sublist_empty = fn.eq(fn.list_len(native), fn.lit(0)) - if fn.any_(is_sublist_empty).as_py(): + is_sublist_empty = fn.eq(fn.list.len(native), fn.lit(0)) + if fn.any(is_sublist_empty).as_py(): result = fn.when_then(is_sublist_empty, fn.lit(0), result) return result @@ -306,7 +318,7 @@ def agg_over(self, irs: Seq[NamedIR], sort_indices: Indices | None = None) -> Fr if by.null_count: temp_name = temp.column_name({*column_names, *agg_names}) key_names = [temp_name] - native = native.append_column(temp_name, dictionary_encode(by)) + native = native.append_column(temp_name, fn.cat.encode(by)) compliant = from_native(native) else: partitions = native.select(key_names) @@ -315,7 +327,7 @@ def agg_over(self, irs: Seq[NamedIR], sort_indices: Indices | None = None) -> Fr for orig_name, by in zip(key_names, partitions.columns): if by.null_count: by_name = next(it_temp_names) - native = native.append_column(by_name, dictionary_encode(by)) + native = native.append_column(by_name, fn.cat.encode(by)) else: by_name = orig_name by_names.append(by_name) @@ -334,24 +346,6 @@ def agg_over(self, irs: Seq[NamedIR], sort_indices: Indices | None = None) -> Fr ) -@overload -def dictionary_encode(native: ChunkedArrayAny, /) -> pa.Int32Array: ... -@overload -def dictionary_encode( - native: ChunkedArrayAny, /, *, include_values: Literal[True] -) -> tuple[ArrayAny, pa.Int32Array]: ... -def dictionary_encode( - native: ChunkedArrayAny, /, *, include_values: bool = False -) -> tuple[ArrayAny, pa.Int32Array] | pa.Int32Array: - """Extra typing for `pc.dictionary_encode`.""" - da: Incomplete = native.dictionary_encode("encode").combine_chunks() - indices: pa.Int32Array = da.indices - if not include_values: - return indices - values: ArrayAny = da.dictionary - return values, indices - - def _composite_key(native: pa.Table, *, separator: str = "") -> ChunkedArray: """Horizontally join columns to *seed* a unique key per row combination.""" dtype = fn.string_type(native.schema.types) @@ -374,7 +368,7 @@ def _partition_by_one( native: pa.Table, by: str, *, include_key: bool = True ) -> Iterator[pa.Table]: """Optimized path for single-column partition.""" - values, indices = dictionary_encode(native.column(by), include_values=True) + values, indices = fn.cat.encode(native.column(by), include_categories=True) if not include_key: native = native.remove_column(native.schema.get_field_index(by)) for idx in range(len(values)): @@ -434,3 +428,47 @@ def _generate_hash_to_scalar_name() -> Mapping[acero.Aggregation, acero.Aggregat [Hash aggregate]: https://arrow.apache.org/docs/dev/cpp/compute.html#grouped-aggregations-group-by [Scalar aggregate]: https://arrow.apache.org/docs/dev/cpp/compute.html#aggregations """ + + +def _ir_min_max(name: str, /) -> MinMax: + return MinMax(expr=ir.col(name)) + + +def _boolean_is_unique( + indices: ChunkedArrayAny, aggregated: ChunkedStruct, / +) -> ChunkedArrayAny: + min, max = aggregated.flatten() + return fn.and_(fn.is_in(indices, min), fn.is_in(indices, max)) + + +def _boolean_is_duplicated( + indices: ChunkedArrayAny, aggregated: ChunkedStruct, / +) -> ChunkedArrayAny: + return fn.not_(_boolean_is_unique(indices, aggregated)) + + +# TODO @dangotbanned: Replace with a function for export? +BOOLEAN_LENGTH_PRESERVING: Mapping[ + type[ir.boolean.BooleanFunction], tuple[IntoColumnAgg, BooleanLengthPreserving] +] = { + ir.boolean.IsFirstDistinct: (ir.min, fn.is_in), + ir.boolean.IsLastDistinct: (ir.max, fn.is_in), + ir.boolean.IsUnique: (_ir_min_max, _boolean_is_unique), + ir.boolean.IsDuplicated: (_ir_min_max, _boolean_is_duplicated), +} + + +def unique_keep_boolean_length_preserving( + keep: UniqueKeepStrategy, +) -> tuple[IntoColumnAgg, BooleanLengthPreserving]: + return BOOLEAN_LENGTH_PRESERVING[_UNIQUE_KEEP_BOOLEAN_LENGTH_PRESERVING[keep]] + + +_UNIQUE_KEEP_BOOLEAN_LENGTH_PRESERVING: Mapping[ + UniqueKeepStrategy, type[ir.boolean.BooleanFunction] +] = { + "any": ir.boolean.IsFirstDistinct, + "first": ir.boolean.IsFirstDistinct, + "last": ir.boolean.IsLastDistinct, + "none": ir.boolean.IsUnique, +} diff --git a/narwhals/_plan/arrow/guards.py b/narwhals/_plan/arrow/guards.py new file mode 100644 index 0000000000..553b2b82b6 --- /dev/null +++ b/narwhals/_plan/arrow/guards.py @@ -0,0 +1,37 @@ +"""Backend-specific type guards.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pyarrow as pa # ignore-banned-import +import pyarrow.compute as pc # ignore-banned-import + +if TYPE_CHECKING: + from typing_extensions import TypeIs + + from narwhals._plan.arrow import acero + from narwhals._plan.arrow.typing import Arrow, ChunkedArrayAny, ScalarT + from narwhals._utils import _StoresNative + +__all__ = ["is_arrow", "is_expression", "is_series"] + + +def is_series(obj: Any) -> TypeIs[_StoresNative[ChunkedArrayAny]]: + """Return True if `obj` is a (Compliant) ArrowSeries.""" + from narwhals._plan.arrow.series import ArrowSeries + + return isinstance(obj, ArrowSeries) + + +def is_arrow(obj: Arrow[ScalarT] | Any) -> TypeIs[Arrow[ScalarT]]: + """Return True if `obj` is a (Native) Arrow data container.""" + return isinstance(obj, (pa.Scalar, pa.Array, pa.ChunkedArray)) + + +def is_expression(obj: Any) -> TypeIs[acero.Expr]: + """Return True if `obj` is a (Native) [`Expression`]. + + [`Expression`]: https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html + """ + return isinstance(obj, pc.Expression) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 2622c48c00..28e2bf2ce7 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -173,7 +173,9 @@ def concat_str( aligned = (ser.native for ser in self._expr.align(exprs)) separator = node.function.separator ignore_nulls = node.function.ignore_nulls - result = fn.concat_str(*aligned, separator=separator, ignore_nulls=ignore_nulls) + result = fn.str.concat_str( + *aligned, separator=separator, ignore_nulls=ignore_nulls + ) if isinstance(result, pa.Scalar): return self._scalar.from_native(result, name, self.version) return self._expr.from_native(result, name, self.version) diff --git a/narwhals/_plan/arrow/pivot.py b/narwhals/_plan/arrow/pivot.py index 2f169f762d..a4311e0021 100644 --- a/narwhals/_plan/arrow/pivot.py +++ b/narwhals/_plan/arrow/pivot.py @@ -22,7 +22,7 @@ import pyarrow as pa - from narwhals._plan.arrow.typing import ChunkedArray, StringScalar + from narwhals._plan.arrow.typing import ChunkedArray, ChunkedOrScalarAny, StringScalar from narwhals.typing import PivotAgg @@ -75,13 +75,15 @@ def _format_on_columns_titles(on_columns: pa.Table, /) -> ChunkedArray[StringSca # NOTE: Variation of https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.intersperse seps = (SEP,) * on_columns.num_columns - interspersed = chain.from_iterable(zip(seps, on_columns.itercolumns())) + interspersed: chain[ChunkedOrScalarAny] = chain.from_iterable( + zip(seps, on_columns.itercolumns()) + ) # skip the first separator, we just need the zip-terminating iterable to be the columns next(interspersed) - func = "binary_join_element_wise" - args = [LB, *interspersed, RB, EMPTY] opts = pa_options.join(ignore_nulls=False) - result: ChunkedArray[StringScalar] = pc.call_function(func, args, opts) + result: ChunkedArray[StringScalar] = fn.meta.call( + "binary_join_element_wise", LB, *interspersed, RB, EMPTY, options=opts + ) return result @@ -122,7 +124,7 @@ def _pivot( pivot = acero.group_by_table(native, index, specs) flat = pivot.flatten() if len(values) == 1: - names = [*index, *fn.struct_field_names(pivot.column(values[0]))] + names = [*index, *fn.struct.field_names(pivot.column(values[0]))] else: names = _replace_flatten_names(flat.column_names, values, on_columns, separator) return flat.rename_columns(names) diff --git a/narwhals/_plan/arrow/series.py b/narwhals/_plan/arrow/series.py index fdce41e16f..8518ac733c 100644 --- a/narwhals/_plan/arrow/series.py +++ b/narwhals/_plan/arrow/series.py @@ -5,7 +5,7 @@ import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc -from narwhals._arrow.utils import narwhals_to_native_dtype, native_to_narwhals_dtype +from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow import compat, functions as fn, options from narwhals._plan.arrow.common import ArrowFrameSeries as FrameSeries from narwhals._plan.compliant.accessors import SeriesStructNamespace as StructNamespace @@ -17,14 +17,15 @@ from narwhals.schema import Schema if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Callable, Iterable import polars as pl from typing_extensions import Self, TypeAlias from narwhals._plan.arrow.dataframe import ArrowDataFrame as DataFrame from narwhals._plan.arrow.namespace import ArrowNamespace as Namespace - from narwhals._plan.arrow.typing import ChunkedArrayAny + from narwhals._plan.arrow.typing import ArrowAny, ChunkedArrayAny + from narwhals._plan.compliant.typing import SeriesT from narwhals.dtypes import DType from narwhals.typing import ( FillNullStrategy, @@ -38,6 +39,27 @@ Incomplete: TypeAlias = Any +def bin_op( + function: Callable[[Any, Any], Any], /, *, reflect: bool = False +) -> Callable[[SeriesT, Any], SeriesT]: + """Attach a binary operator to `ArrowSeries`.""" + + def f(self: SeriesT, other: SeriesT | Any, /) -> SeriesT: + right = other.native if isinstance(other, type(self)) else fn.lit(other) + return self._with_native(function(self.native, right)) + + def f_reflect(self: SeriesT, other: SeriesT | Any, /) -> SeriesT: + if isinstance(other, type(self)): + name = other.name + right: ArrowAny = other.native + else: + name = "literal" + right = fn.lit(other) + return self.from_native(function(right, self.native), name, version=self.version) + + return f_reflect if reflect else f + + class ArrowSeries(FrameSeries["ChunkedArrayAny"], CompliantSeries["ChunkedArrayAny"]): _name: str @@ -90,7 +112,7 @@ def from_iterable( return cls.from_native(fn.chunked_array([data], dtype_pa), name, version=version) def cast(self, dtype: IntoDType) -> Self: - dtype_pa = narwhals_to_native_dtype(dtype, self.version) + dtype_pa = fn.dtype_native(dtype, self.version) return self._with_native(fn.cast(self.native, dtype_pa)) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: @@ -124,32 +146,32 @@ def has_nulls(self) -> bool: def null_count(self) -> int: return self.native.null_count - __add__ = fn.bin_op(fn.add) - __and__ = fn.bin_op(fn.and_) - __eq__ = fn.bin_op(fn.eq) - __floordiv__ = fn.bin_op(fn.floordiv) - __ge__ = fn.bin_op(fn.gt_eq) - __gt__ = fn.bin_op(fn.gt) - __le__ = fn.bin_op(fn.lt_eq) - __lt__ = fn.bin_op(fn.lt) - __mod__ = fn.bin_op(fn.modulus) - __mul__ = fn.bin_op(fn.multiply) - __ne__ = fn.bin_op(fn.not_eq) - __or__ = fn.bin_op(fn.or_) - __pow__ = fn.bin_op(fn.power) - __rfloordiv__ = fn.bin_op(fn.floordiv, reflect=True) - __radd__ = fn.bin_op(fn.add, reflect=True) - __rand__ = fn.bin_op(fn.and_, reflect=True) - __rmod__ = fn.bin_op(fn.modulus, reflect=True) - __rmul__ = fn.bin_op(fn.multiply, reflect=True) - __ror__ = fn.bin_op(fn.or_, reflect=True) - __rpow__ = fn.bin_op(fn.power, reflect=True) - __rsub__ = fn.bin_op(fn.sub, reflect=True) - __rtruediv__ = fn.bin_op(fn.truediv, reflect=True) - __rxor__ = fn.bin_op(fn.xor, reflect=True) - __sub__ = fn.bin_op(fn.sub) - __truediv__ = fn.bin_op(fn.truediv) - __xor__ = fn.bin_op(fn.xor) + __add__ = bin_op(fn.add) + __and__ = bin_op(fn.and_) + __eq__ = bin_op(fn.eq) + __floordiv__ = bin_op(fn.floordiv) + __ge__ = bin_op(fn.gt_eq) + __gt__ = bin_op(fn.gt) + __le__ = bin_op(fn.lt_eq) + __lt__ = bin_op(fn.lt) + __mod__ = bin_op(fn.modulus) + __mul__ = bin_op(fn.multiply) + __ne__ = bin_op(fn.not_eq) + __or__ = bin_op(fn.or_) + __pow__ = bin_op(fn.power) + __rfloordiv__ = bin_op(fn.floordiv, reflect=True) + __radd__ = bin_op(fn.add, reflect=True) + __rand__ = bin_op(fn.and_, reflect=True) + __rmod__ = bin_op(fn.modulus, reflect=True) + __rmul__ = bin_op(fn.multiply, reflect=True) + __ror__ = bin_op(fn.or_, reflect=True) + __rpow__ = bin_op(fn.power, reflect=True) + __rsub__ = bin_op(fn.sub, reflect=True) + __rtruediv__ = bin_op(fn.truediv, reflect=True) + __rxor__ = bin_op(fn.xor, reflect=True) + __sub__ = bin_op(fn.sub) + __truediv__ = bin_op(fn.truediv) + __xor__ = bin_op(fn.xor) def __invert__(self) -> Self: return self._with_native(pc.invert(self.native)) @@ -279,13 +301,13 @@ def zip_with(self, mask: Self, other: Self | None) -> Self: return self._with_native(fn.when_then(predicate, self.native, right)) def all(self) -> bool: - return fn.all_(self.native).as_py() + return fn.all(self.native).as_py() def any(self) -> bool: - return fn.any_(self.native).as_py() + return fn.any(self.native).as_py() def sum(self) -> float: - result: float = fn.sum_(self.native).as_py() + result: float = fn.sum(self.native).as_py() return result def count(self) -> int: @@ -347,18 +369,18 @@ def unnest(self) -> DataFrame: if len(native): table = pa.Table.from_struct_array(native) else: - table = fn.struct_schema(native).empty_table() + table = fn.struct.schema(native).empty_table() else: # pragma: no cover # NOTE: Too strict, doesn't allow `Array[StructScalar]` rec_batch: Incomplete = pa.RecordBatch.from_struct_array batches = (rec_batch(chunk) for chunk in native.chunks) - table = pa.Table.from_batches(batches, fn.struct_schema(native)) + table = pa.Table.from_batches(batches, fn.struct.schema(native)) return namespace(self)._dataframe.from_native(table, self.version) # name overriding *may* be wrong def field(self, name: str) -> ArrowSeries: - return self.with_native(fn.struct_field(self.native, name), name) + return self.with_native(fn.struct.field(self.native, name), name) @property def schema(self) -> Schema: - return Schema.from_arrow(fn.struct_schema(self.native)) + return Schema.from_arrow(fn.struct.schema(self.native)) diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py index c2befc214b..4de7699d1c 100644 --- a/narwhals/_plan/arrow/typing.py +++ b/narwhals/_plan/arrow/typing.py @@ -1,7 +1,7 @@ from __future__ import annotations # ruff: noqa: PLC0414 -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, Protocol, overload from narwhals._typing_compat import TypeVar @@ -28,7 +28,12 @@ from typing_extensions import ParamSpec, TypeAlias from narwhals._native import NativeDataFrame, NativeSeries - from narwhals.typing import SizedMultiIndexSelector as _SizedMultiIndexSelector + from narwhals._plan.typing import OneOrIterable + from narwhals._translate import ArrowStreamExportable + from narwhals.typing import ( + SizedMultiIndexSelector as _SizedMultiIndexSelector, + _AnyDArray, + ) UInt32Type: TypeAlias = "Uint32Type" StringType: TypeAlias = "_StringType | _LargeStringType" @@ -37,8 +42,6 @@ IntegerScalar: TypeAlias = "Scalar[IntegerType]" DateScalar: TypeAlias = "Scalar[Date32Type]" ListScalar: TypeAlias = "Scalar[pa.ListType[DataTypeT_co]]" - BooleanScalar: TypeAlias = "Scalar[BoolType]" - """Only use this for a parameter type, not as a return type!""" NumericScalar: TypeAlias = "pc.NumericScalar" PrimitiveNumericType: TypeAlias = "types._Integer | types._Floating" @@ -48,8 +51,8 @@ BasicType: TypeAlias = ( "NumericOrTemporalType | StringOrBinaryType | BoolType | lib.NullType" ) - NonListNestedType: TypeAlias = "pa.StructType | pa.DictionaryType[Any, Any] | pa.MapType[Any, Any] | pa.UnionType" - NonListType: TypeAlias = "BasicType | NonListNestedType" + NonListNestedType: TypeAlias = "pa.StructType | pa.DictionaryType[Any, Any, Any] | pa.MapType[Any, Any, Any] | pa.UnionType" + NonListType: TypeAlias = "IntoHashableType | NonListNestedType" NestedType: TypeAlias = "NonListNestedType | pa.ListType[Any]" NonListTypeT = TypeVar("NonListTypeT", bound="NonListType") ListTypeT = TypeVar("ListTypeT", bound="pa.ListType[Any]") @@ -63,6 +66,11 @@ def column(self, *args: Any, **kwds: Any) -> NativeArrowSeries: ... @property def columns(self) -> Sequence[NativeArrowSeries]: ... + class _NumpyArray(Protocol): + def __array__(self) -> _AnyDArray: ... + + # TODO @dangotbanned: Move out of `TYPE_CHECKING` for docs after (3, 10) minimum + # https://github.com/narwhals-dev/narwhals/issues/3204 P = ParamSpec("P") class UnaryFunctionP(Protocol[P]): @@ -83,6 +91,15 @@ def __call__( ) -> ChunkedArrayAny: ... +BooleanScalar: TypeAlias = "Scalar[BoolType]" +"""Only use this for a parameter type, not as a return type!""" + +IntoHashableType: TypeAlias = "BasicType | pa.DictionaryType[Any, Any, Any]" +"""Types that can be encoded into a dictionary.""" + +IntoHashableScalar: TypeAlias = "Scalar[IntoHashableType]" +"""Values that can be encoded into a dictionary.""" + ScalarT = TypeVar("ScalarT", bound="pa.Scalar[Any]", default="pa.Scalar[Any]") ScalarPT_contra = TypeVar( "ScalarPT_contra", @@ -131,50 +148,53 @@ def __call__( class BinaryFunction(Protocol[ScalarPT_contra, ScalarRT_co]): @overload def __call__( - self, x: ChunkedArray[ScalarPT_contra], y: ChunkedArray[ScalarPT_contra], / + self, lhs: ChunkedArray[ScalarPT_contra], rhs: ChunkedArray[ScalarPT_contra], / ) -> ChunkedArray[ScalarRT_co]: ... @overload def __call__( - self, x: Array[ScalarPT_contra], y: Array[ScalarPT_contra], / + self, lhs: Array[ScalarPT_contra], rhs: Array[ScalarPT_contra], / ) -> Array[ScalarRT_co]: ... @overload - def __call__(self, x: ScalarPT_contra, y: ScalarPT_contra, /) -> ScalarRT_co: ... + def __call__(self, lhs: ScalarPT_contra, rhs: ScalarPT_contra, /) -> ScalarRT_co: ... @overload def __call__( - self, x: ChunkedArray[ScalarPT_contra], y: ScalarPT_contra, / + self, lhs: ChunkedArray[ScalarPT_contra], rhs: ScalarPT_contra, / ) -> ChunkedArray[ScalarRT_co]: ... @overload def __call__( - self, x: Array[ScalarPT_contra], y: ScalarPT_contra, / + self, lhs: Array[ScalarPT_contra], rhs: ScalarPT_contra, / ) -> Array[ScalarRT_co]: ... @overload def __call__( - self, x: ScalarPT_contra, y: ChunkedArray[ScalarPT_contra], / + self, lhs: ScalarPT_contra, rhs: ChunkedArray[ScalarPT_contra], / ) -> ChunkedArray[ScalarRT_co]: ... @overload def __call__( - self, x: ScalarPT_contra, y: Array[ScalarPT_contra], / + self, lhs: ScalarPT_contra, rhs: Array[ScalarPT_contra], / ) -> Array[ScalarRT_co]: ... @overload def __call__( - self, x: ChunkedArray[ScalarPT_contra], y: Array[ScalarPT_contra], / + self, lhs: ChunkedArray[ScalarPT_contra], rhs: Array[ScalarPT_contra], / ) -> ChunkedArray[ScalarRT_co]: ... @overload def __call__( - self, x: Array[ScalarPT_contra], y: ChunkedArray[ScalarPT_contra], / + self, lhs: Array[ScalarPT_contra], rhs: ChunkedArray[ScalarPT_contra], / ) -> ChunkedArray[ScalarRT_co]: ... @overload def __call__( - self, x: ChunkedOrScalar[ScalarPT_contra], y: ChunkedOrScalar[ScalarPT_contra], / + self, + lhs: ChunkedOrScalar[ScalarPT_contra], + rhs: ChunkedOrScalar[ScalarPT_contra], + /, ) -> ChunkedOrScalar[ScalarRT_co]: ... @overload def __call__( - self, x: Arrow[ScalarPT_contra], y: Arrow[ScalarPT_contra], / + self, lhs: Arrow[ScalarPT_contra], rhs: Arrow[ScalarPT_contra], / ) -> Arrow[ScalarRT_co]: ... def __call__( - self, x: Arrow[ScalarPT_contra], y: Arrow[ScalarPT_contra], / + self, lhs: Arrow[ScalarPT_contra], rhs: Arrow[ScalarPT_contra], / ) -> Arrow[ScalarRT_co]: ... @@ -208,10 +228,22 @@ class BinaryLogical(BinaryFunction["BooleanScalar", "pa.BooleanScalar"], Protoco ChunkedOrScalarT = TypeVar("ChunkedOrScalarT", ChunkedArrayAny, ScalarAny) Indices: TypeAlias = "_SizedMultiIndexSelector[ChunkedOrArray[pc.IntegerScalar]]" +# Common spellings for complicated types ChunkedStruct: TypeAlias = "ChunkedArray[pa.StructScalar]" StructArray: TypeAlias = "pa.StructArray | Array[pa.StructScalar]" ChunkedList: TypeAlias = "ChunkedArray[ListScalar[DataTypeT_co]]" +Struct: TypeAlias = "ChunkedStruct | pa.StructArray | pa.StructScalar" +"""(Concrete) Struct-typed arrow data.""" + ListArray: TypeAlias = "Array[ListScalar[DataTypeT_co]]" +ChunkedOrArrayHashable: TypeAlias = "ChunkedOrArray[IntoHashableScalar]" +"""Arrow arrays that can be [dictionary-encoded]. + +Boolean, Null, Numeric, Temporal, Binary or String-typed, + Dictionary ([no-op]). + +[dictionary-encoded]: https://arrow.apache.org/cookbook/py/create.html#store-categorical-data +[no-op]: https://arrow.apache.org/docs/cpp/compute.html#associative-transforms +""" Arrow: TypeAlias = "ChunkedOrScalar[ScalarT_co] | Array[ScalarT_co]" ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny" @@ -221,6 +253,10 @@ class BinaryLogical(BinaryFunction["BooleanScalar", "pa.BooleanScalar"], Protoco Predicate: TypeAlias = "Arrow[BooleanScalar]" """Any `pyarrow` container that wraps boolean.""" +IntoChunkedArray: TypeAlias = ( + "ArrowAny | list[Iterable[Any]] | OneOrIterable[ArrowStreamExportable | _NumpyArray]" +) + NativeScalar: TypeAlias = ScalarAny BinOp: TypeAlias = Callable[..., ChunkedOrScalarAny] StoresNativeT_co = TypeVar( @@ -239,3 +275,5 @@ class BinaryLogical(BinaryFunction["BooleanScalar", "pa.BooleanScalar"], Protoco `"average"` requires calculating both `"min"` and `"max"`. """ + +SearchSortedSide: TypeAlias = Literal["left", "right"]