From 8550da724784770b2ca12346ebecad299c4fa2c4 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 14:36:14 +0000 Subject: [PATCH 01/20] wip --- narwhals/_compliant/expr.py | 8 --- narwhals/_compliant/group_by.py | 6 ++- narwhals/_duckdb/expr.py | 5 +- narwhals/_duckdb/expr_name.py | 1 - narwhals/_duckdb/namespace.py | 14 +----- narwhals/_duckdb/selectors.py | 2 - narwhals/_expression_parsing.py | 84 +++++++++++++++++++++++-------- narwhals/_spark_like/expr.py | 14 ++---- narwhals/_spark_like/expr_name.py | 1 - narwhals/_spark_like/namespace.py | 13 ----- narwhals/_spark_like/selectors.py | 2 - narwhals/expr.py | 2 +- narwhals/functions.py | 18 ++++--- narwhals/group_by.py | 8 +-- narwhals/selectors.py | 20 +++++--- 15 files changed, 103 insertions(+), 95 deletions(-) diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 7fe41baf6f..6960e24b12 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -272,14 +272,6 @@ def __invert__(self) -> Self: ... def broadcast( self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL] ) -> Self: ... - def _is_multi_output_agg(self) -> bool: - """Return `True` for multi-output aggregations. - - Here we skip the keys, else they would appear duplicated in the output: - - df.group_by("a").agg(nw.all().mean()) - """ - return self._function_name.split("->", maxsplit=1)[0] in {"all", "selector"} class EagerExpr( diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 866e98c6e8..56aafd3914 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -147,7 +147,11 @@ def _evaluate_expr(self, expr: LazyExprT_contra, /) -> Iterator[NativeExprT_co]: else output_names ) native_exprs = expr(self.compliant) - if expr._is_multi_output_agg(): + assert expr._metadata is not None # noqa: S101 + if expr._metadata.expansion_kind.is_multi_unnamed(): + # Exclude keys from expansion. 
For example, in + # `df.group_by('a').agg(nw.all().sum())`, column 'a' only appears in the + # output as a grouping key - it does not get included in `nw.all().sum()`. for native_expr, name, alias in zip(native_exprs, output_names, aliases): if name not in self._keys: yield native_expr.alias(alias) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index bd2a53d1f2..2bf5248f13 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -49,19 +49,20 @@ class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]): _implementation = Implementation.DUCKDB _depth = 0 # Unused, just for compatibility with CompliantExpr + _function_name = "" # Unused, just for compatibility with CompliantExpr def __init__( self: Self, call: Callable[[DuckDBLazyFrame], Sequence[duckdb.Expression]], *, - function_name: str, + # Unused, just for compatibility with CompliantExpr + function_name: str = "", evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, backend_version: tuple[int, ...], version: Version, ) -> None: self._call = call - self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._backend_version = backend_version diff --git a/narwhals/_duckdb/expr_name.py b/narwhals/_duckdb/expr_name.py index 2095a3b9a1..151602bb17 100644 --- a/narwhals/_duckdb/expr_name.py +++ b/narwhals/_duckdb/expr_name.py @@ -58,7 +58,6 @@ def _from_alias_output_names( ) -> DuckDBExpr: return self._compliant_expr.__class__( call=self._compliant_expr._call, - function_name=self._compliant_expr._function_name, evaluate_output_names=self._compliant_expr._evaluate_output_names, alias_output_names=alias_output_names, backend_version=self._compliant_expr._backend_version, diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 121b91916d..b6aec3dd8f 100644 --- a/narwhals/_duckdb/namespace.py
+++ b/narwhals/_duckdb/namespace.py @@ -120,7 +120,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return self._expr( call=func, - function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -134,7 +133,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return self._expr( call=func, - function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -148,7 +146,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return self._expr( call=func, - function_name="or_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -162,7 +159,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return self._expr( call=func, - function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -176,7 +172,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return self._expr( call=func, - function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -190,7 +185,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return self._expr( call=func, - function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -209,7 +203,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBExpr( call=func, - function_name="mean_horizontal", 
evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -235,7 +228,6 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: return self._expr( func, - function_name="lit", evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, backend_version=self._backend_version, @@ -248,7 +240,6 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: return self._expr( call=func, - function_name="len", evaluate_output_names=lambda _df: ["len"], alias_output_names=None, backend_version=self._backend_version, @@ -289,7 +280,6 @@ def then(self: Self, value: DuckDBExpr | Any) -> DuckDBThen: return DuckDBThen( self, - function_name="whenthen", evaluate_output_names=getattr( value, "_evaluate_output_names", lambda _df: ["literal"] ), @@ -304,7 +294,7 @@ def __init__( self: Self, call: DuckDBWhen, *, - function_name: str, + function_name: str = '', evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, backend_version: tuple[int, ...], @@ -313,7 +303,6 @@ def __init__( self._backend_version = backend_version self._version = version self._call = call - self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names @@ -322,5 +311,4 @@ def otherwise(self: Self, value: DuckDBExpr | Any) -> DuckDBExpr: # callable object of type `DuckDBWhen`, base class has the attribute as # only a `Callable` self._call._otherwise_value = value # type: ignore[attr-defined] - self._function_name = "whenotherwise" return self diff --git a/narwhals/_duckdb/selectors.py b/narwhals/_duckdb/selectors.py index 77b5bc2d15..dcfd3113b1 100644 --- a/narwhals/_duckdb/selectors.py +++ b/narwhals/_duckdb/selectors.py @@ -27,7 +27,6 @@ def _selector( ) -> DuckDBSelector: return DuckDBSelector( call, - function_name="selector", 
evaluate_output_names=evaluate_output_names, alias_output_names=None, backend_version=self._backend_version, @@ -46,7 +45,6 @@ class DuckDBSelector( # type: ignore[misc] def _to_expr(self: Self) -> DuckDBExpr: return DuckDBExpr( self._call, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index d8aa822185..52f64ead7c 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -121,7 +121,7 @@ def evaluate_output_names_and_aliases( if expr._alias_output_names is None else expr._alias_output_names(output_names) ) - if expr._is_multi_output_agg(): + if exclude and expr._metadata is not None and expr._metadata.expansion_kind.is_multi_unnamed(): output_names, aliases = zip( *[(x, alias) for x, alias in zip(output_names, aliases) if x not in exclude] ) @@ -183,22 +183,44 @@ def is_scalar_like( return kind in {ExprKind.AGGREGATION, ExprKind.LITERAL} +class ExpansionKind(Enum): + """Describe what kind of expansion the expression performs.""" + + SINGLE = auto() + """e.g. `nw.col('a'), nw.sum_horizontal(nw.all())`""" + + MULTINAMED = auto() + """e.g. `nw.col('a', 'b')`""" + + MULTIUNNAMED = auto() + """e.g. 
`nw.all()`, nw.nth(0, 1)""" + + def is_multi_unnamed(self) -> bool: + return self is ExpansionKind.MULTIUNNAMED + + +def is_multi_output( + expansion_kind: ExpansionKind, +) -> TypeIs[Literal[ExpansionKind.MULTINAMED, ExpansionKind.MULTIUNNAMED]]: + return expansion_kind in {ExpansionKind.MULTINAMED, ExpansionKind.MULTIUNNAMED} + + class ExprMetadata: - __slots__ = ("_is_multi_output", "_kind", "_n_open_windows") + __slots__ = ("_expansion_kind", "_kind", "_n_open_windows") def __init__( - self, kind: ExprKind, /, *, n_open_windows: int, is_multi_output: bool + self, kind: ExprKind, /, *, n_open_windows: int, expansion_kind: ExpansionKind ) -> None: self._kind: ExprKind = kind self._n_open_windows = n_open_windows - self._is_multi_output = is_multi_output + self._expansion_kind = expansion_kind def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never: # pragma: no cover msg = f"Cannot subclass {cls.__name__!r}" raise TypeError(msg) def __repr__(self) -> str: - return f"ExprMetadata(kind: {self._kind}, n_open_windows: {self._n_open_windows}, is_multi_output: {self._is_multi_output})" + return f"ExprMetadata(kind: {self._kind}, n_open_windows: {self._n_open_windows}, expansion_kind: {self._expansion_kind})" @property def kind(self) -> ExprKind: @@ -209,15 +231,13 @@ def n_open_windows(self) -> int: return self._n_open_windows @property - def is_multi_output(self) -> bool: - return self._is_multi_output + def expansion_kind(self) -> ExpansionKind: + return self._expansion_kind def with_kind(self, kind: ExprKind, /) -> ExprMetadata: """Change metadata kind, leaving all other attributes the same.""" return ExprMetadata( - kind, - n_open_windows=self._n_open_windows, - is_multi_output=self._is_multi_output, + kind, n_open_windows=self._n_open_windows, expansion_kind=self._expansion_kind ) def with_extra_open_window(self) -> ExprMetadata: @@ -225,7 +245,7 @@ def with_extra_open_window(self) -> ExprMetadata: return ExprMetadata( self.kind, 
n_open_windows=self._n_open_windows + 1, - is_multi_output=self._is_multi_output, + expansion_kind=self._expansion_kind, ) def with_kind_and_extra_open_window(self, kind: ExprKind, /) -> ExprMetadata: @@ -233,18 +253,31 @@ def with_kind_and_extra_open_window(self, kind: ExprKind, /) -> ExprMetadata: return ExprMetadata( kind, n_open_windows=self._n_open_windows + 1, - is_multi_output=self._is_multi_output, + expansion_kind=self._expansion_kind, ) @staticmethod def simple_selector() -> ExprMetadata: - # e.g. nw.col('a'), nw.nth(0) # noqa: ERA001 - return ExprMetadata(ExprKind.TRANSFORM, n_open_windows=0, is_multi_output=False) + # e.g. `nw.col('a')`, `nw.nth(0)` + return ExprMetadata( + ExprKind.TRANSFORM, n_open_windows=0, expansion_kind=ExpansionKind.SINGLE + ) + + @staticmethod + def multi_output_selector_named() -> ExprMetadata: + # e.g. `nw.col('a', 'b')` + return ExprMetadata( + ExprKind.TRANSFORM, n_open_windows=0, expansion_kind=ExpansionKind.MULTINAMED + ) @staticmethod - def multi_output_selector() -> ExprMetadata: - # e.g. nw.col('a', 'b'), nw.nth(0, 1), nw.all(), nw.selectors.matches('foo') # noqa: ERA001 - return ExprMetadata(ExprKind.TRANSFORM, n_open_windows=0, is_multi_output=True) + def multi_output_selector_unnamed() -> ExprMetadata: + # e.g. `nw.all()` + return ExprMetadata( + ExprKind.TRANSFORM, + n_open_windows=0, + expansion_kind=ExpansionKind.MULTIUNNAMED, + ) def combine_metadata( @@ -267,13 +300,13 @@ def combine_metadata( has_aggregations = False has_literals = False result_n_open_windows = 0 - result_is_multi_output = False + result_expansion_kind = ExpansionKind.SINGLE for i, arg in enumerate(args): if isinstance(arg, str) and not str_as_lit: has_transforms_or_windows = True elif is_expr(arg): - if arg._metadata.is_multi_output: + if is_multi_output(arg._metadata.expansion_kind): if i > 0 and not allow_multi_output: # Left-most argument is always allowed to be multi-output. 
msg = ( @@ -281,8 +314,15 @@ def combine_metadata( "are not supported in this context." ) raise MultiOutputExpressionError(msg) - if not to_single_output: - result_is_multi_output = True + if not to_single_output: # pragma: no cover + if i != 0 and arg._metadata.expansion_kind != result_expansion_kind: + msg = "Safety assertion failed, please report a bug." + raise AssertionError(msg) + # Preserve expansion kind. e.g. + # - `nw.all() + nw.col('a')` + # - `nw.selectors.datetime() - nw.selectors.numeric() + # preserve the expansion kind of the left-hand-side. + result_expansion_kind = arg._metadata.expansion_kind if arg._metadata.n_open_windows: result_n_open_windows += 1 kind = arg._metadata.kind @@ -321,7 +361,7 @@ def combine_metadata( return ExprMetadata( result_kind, n_open_windows=result_n_open_windows, - is_multi_output=result_is_multi_output, + expansion_kind=result_expansion_kind, ) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 22b8461d2e..055e2a60cf 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -41,12 +41,13 @@ class SparkLikeExpr(LazyExpr["SparkLikeLazyFrame", "Column"]): _depth = 0 # Unused, just for compatibility with CompliantExpr + _function_name = "" # Unused, just for compatibility with CompliantExpr def __init__( self: Self, call: Callable[[SparkLikeLazyFrame], Sequence[Column]], *, - function_name: str, + function_name: str = "", evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, backend_version: tuple[int, ...], @@ -54,7 +55,6 @@ def __init__( implementation: Implementation, ) -> None: self._call = call - self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._backend_version = backend_version @@ -77,7 +77,6 @@ def func(df: SparkLikeLazyFrame) -> Sequence[Column]: return self.__class__( func, - 
function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, @@ -127,7 +126,6 @@ def __narwhals_namespace__(self: Self) -> SparkLikeNamespace: # pragma: no cove def _with_metadata(self, metadata: ExprMetadata) -> Self: expr = self.__class__( self._call, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, @@ -145,7 +143,6 @@ def _with_window_function( ) -> Self: result = self.__class__( self._call, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, @@ -161,7 +158,7 @@ def from_column_names( evaluate_column_names: Callable[[SparkLikeLazyFrame], Sequence[str]], /, *, - function_name: str, + function_name: str = "", # noqa: ARG003 context: _FullContext, ) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: @@ -169,7 +166,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return cls( func, - function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, backend_version=context._backend_version, @@ -187,7 +183,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return cls( func, - function_name="nth", evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], alias_output_names=None, backend_version=context._backend_version, @@ -214,7 +209,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return self.__class__( func, - function_name=f"{self._function_name}->{expr_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, @@ -340,7 +334,6 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: return self.__class__( self._call, - function_name=self._function_name, 
evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, backend_version=self._backend_version, @@ -535,7 +528,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return self.__class__( func, - function_name=self._function_name + "->over", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, diff --git a/narwhals/_spark_like/expr_name.py b/narwhals/_spark_like/expr_name.py index 43aa5a192f..4efd3897a8 100644 --- a/narwhals/_spark_like/expr_name.py +++ b/narwhals/_spark_like/expr_name.py @@ -58,7 +58,6 @@ def _from_alias_output_names( ) -> SparkLikeExpr: return self._compliant_expr.__class__( self._compliant_expr._call, - function_name=self._compliant_expr._function_name, evaluate_output_names=self._compliant_expr._evaluate_output_names, alias_output_names=alias_output_names, backend_version=self._compliant_expr._backend_version, diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 6988ba9cc3..243c55954d 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -60,7 +60,6 @@ def _lit(df: SparkLikeLazyFrame) -> list[Column]: return self._expr( call=_lit, - function_name="lit", evaluate_output_names=lambda _df: ["literal"], alias_output_names=None, backend_version=self._backend_version, @@ -74,7 +73,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return self._expr( func, - function_name="len", evaluate_output_names=lambda _df: ["len"], alias_output_names=None, backend_version=self._backend_version, @@ -89,7 +87,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return self._expr( call=func, - function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -104,7 +101,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return self._expr( call=func, 
- function_name="any_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -121,7 +117,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return self._expr( call=func, - function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -150,7 +145,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return self._expr( call=func, - function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -165,7 +159,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return self._expr( call=func, - function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -180,7 +173,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return self._expr( call=func, - function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -275,7 +267,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return self._expr( call=func, - function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), backend_version=self._backend_version, @@ -325,7 +316,6 @@ def then(self: Self, value: SparkLikeExpr | Any) -> SparkLikeThen: return SparkLikeThen( self, - function_name="whenthen", evaluate_output_names=getattr( value, "_evaluate_output_names", lambda _df: ["literal"] ), @@ -341,7 +331,6 @@ def __init__( self: Self, call: SparkLikeWhen, *, - function_name: str, evaluate_output_names: 
Callable[[SparkLikeLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, backend_version: tuple[int, ...], @@ -351,7 +340,6 @@ def __init__( self._backend_version = backend_version self._version = version self._call = call - self._function_name = function_name self._evaluate_output_names = evaluate_output_names self._alias_output_names = alias_output_names self._implementation = implementation @@ -361,5 +349,4 @@ def otherwise(self: Self, value: SparkLikeExpr | Any) -> SparkLikeExpr: # callable object of type `SparkLikeWhen`, base class has the attribute as # only a `Callable` self._call._otherwise_value = value # type: ignore[attr-defined] - self._function_name = "whenotherwise" return self diff --git a/narwhals/_spark_like/selectors.py b/narwhals/_spark_like/selectors.py index 2b832ef358..9b84b53761 100644 --- a/narwhals/_spark_like/selectors.py +++ b/narwhals/_spark_like/selectors.py @@ -25,7 +25,6 @@ def _selector( ) -> SparkLikeSelector: return SparkLikeSelector( call, - function_name="selector", evaluate_output_names=evaluate_output_names, alias_output_names=None, backend_version=self._backend_version, @@ -43,7 +42,6 @@ class SparkLikeSelector(CompliantSelector["SparkLikeLazyFrame", "Column"], Spark def _to_expr(self: Self) -> SparkLikeExpr: return SparkLikeExpr( self._call, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, diff --git a/narwhals/expr.py b/narwhals/expr.py index 2a402f1957..e31aa914d4 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1580,7 +1580,7 @@ def over( next_meta = ExprMetadata( kind, n_open_windows=n_open_windows, - is_multi_output=current_meta.is_multi_output, + expansion_kind=current_meta.expansion_kind, ) return self.__class__( diff --git a/narwhals/functions.py b/narwhals/functions.py index 6b56e4ca01..0f4509ef45 100644 --- a/narwhals/functions.py +++ 
b/narwhals/functions.py @@ -12,6 +12,7 @@ from typing import cast from typing import overload +from narwhals._expression_parsing import ExpansionKind from narwhals._expression_parsing import ExprKind from narwhals._expression_parsing import ExprMetadata from narwhals._expression_parsing import apply_n_ary_operation @@ -1195,7 +1196,7 @@ def func(plx: Any) -> Any: func, ExprMetadata.simple_selector() if len(flat_names) == 1 - else ExprMetadata.multi_output_selector(), + else ExprMetadata.multi_output_selector_named(), ) @@ -1233,7 +1234,7 @@ def exclude(*names: str | Iterable[str]) -> Expr: def func(plx: Any) -> Any: return plx.exclude(exclude_names) - return Expr(func, ExprMetadata.multi_output_selector()) + return Expr(func, ExprMetadata.multi_output_selector_unnamed()) def nth(*indices: int | Sequence[int]) -> Expr: @@ -1275,7 +1276,7 @@ def func(plx: Any) -> Any: func, ExprMetadata.simple_selector() if len(flat_indices) == 1 - else ExprMetadata.multi_output_selector(), + else ExprMetadata.multi_output_selector_unnamed(), ) @@ -1300,7 +1301,7 @@ def all_() -> Expr: | 1 4 0.246 | └──────────────────┘ """ - return Expr(lambda plx: plx.all(), ExprMetadata.multi_output_selector()) + return Expr(lambda plx: plx.all(), ExprMetadata.multi_output_selector_unnamed()) # Add underscore so it doesn't conflict with builtin `len` @@ -1334,7 +1335,10 @@ def func(plx: Any) -> Any: return plx.len() return Expr( - func, ExprMetadata(ExprKind.AGGREGATION, n_open_windows=0, is_multi_output=False) + func, + ExprMetadata( + ExprKind.AGGREGATION, n_open_windows=0, expansion_kind=ExpansionKind.SINGLE + ), ) @@ -1804,7 +1808,9 @@ def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr: return Expr( lambda plx: plx.lit(value, dtype), - ExprMetadata(ExprKind.LITERAL, n_open_windows=0, is_multi_output=False), + ExprMetadata( + ExprKind.LITERAL, n_open_windows=0, expansion_kind=ExpansionKind.SINGLE + ), ) diff --git a/narwhals/group_by.py b/narwhals/group_by.py index 
bc8e626e4b..51e3b624b6 100644 --- a/narwhals/group_by.py +++ b/narwhals/group_by.py @@ -84,9 +84,9 @@ def agg(self: Self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFra raise InvalidOperationError(msg) plx = self._df.__narwhals_namespace__() compliant_aggs = ( - *(x._to_compliant_expr(plx) for x in flat_aggs), + *(x._to_compliant_expr(plx)._with_metadata(x._metadata) for x in flat_aggs), *( - value._to_compliant_expr(plx).alias(key) + value._to_compliant_expr(plx).alias(key)._with_metadata(value._metadata) for key, value in named_aggs.items() ), ) @@ -174,9 +174,9 @@ def agg(self: Self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFra raise InvalidOperationError(msg) plx = self._df.__narwhals_namespace__() compliant_aggs = ( - *(x._to_compliant_expr(plx) for x in flat_aggs), + *(x._to_compliant_expr(plx)._with_metadata(x._metadata) for x in flat_aggs), *( - value._to_compliant_expr(plx).alias(key) + value._to_compliant_expr(plx).alias(key)._with_metadata(value._metadata) for key, value in named_aggs.items() ), ) diff --git a/narwhals/selectors.py b/narwhals/selectors.py index 848478af60..d8603fe3c7 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -96,7 +96,7 @@ def by_dtype(*dtypes: DType | type[DType] | Iterable[DType | type[DType]]) -> Se flattened = flatten(dtypes) return Selector( lambda plx: plx.selectors.by_dtype(flattened), - ExprMetadata.multi_output_selector(), + ExprMetadata.multi_output_selector_unnamed(), ) @@ -130,7 +130,8 @@ def matches(pattern: str) -> Selector: 1 456 5.5 """ return Selector( - lambda plx: plx.selectors.matches(pattern), ExprMetadata.multi_output_selector() + lambda plx: plx.selectors.matches(pattern), + ExprMetadata.multi_output_selector_unnamed(), ) @@ -161,7 +162,7 @@ def numeric() -> Selector: └─────┴─────┘ """ return Selector( - lambda plx: plx.selectors.numeric(), ExprMetadata.multi_output_selector() + lambda plx: plx.selectors.numeric(), ExprMetadata.multi_output_selector_unnamed() ) 
@@ -196,7 +197,7 @@ def boolean() -> Selector: └──────────────────┘ """ return Selector( - lambda plx: plx.selectors.boolean(), ExprMetadata.multi_output_selector() + lambda plx: plx.selectors.boolean(), ExprMetadata.multi_output_selector_unnamed() ) @@ -227,7 +228,7 @@ def string() -> Selector: └─────┘ """ return Selector( - lambda plx: plx.selectors.string(), ExprMetadata.multi_output_selector() + lambda plx: plx.selectors.string(), ExprMetadata.multi_output_selector_unnamed() ) @@ -260,7 +261,8 @@ def categorical() -> Selector: └─────┘ """ return Selector( - lambda plx: plx.selectors.categorical(), ExprMetadata.multi_output_selector() + lambda plx: plx.selectors.categorical(), + ExprMetadata.multi_output_selector_unnamed(), ) @@ -284,7 +286,9 @@ def all() -> Selector: 0 1 x False 1 2 y True """ - return Selector(lambda plx: plx.selectors.all(), ExprMetadata.multi_output_selector()) + return Selector( + lambda plx: plx.selectors.all(), ExprMetadata.multi_output_selector_unnamed() + ) def datetime( @@ -344,7 +348,7 @@ def datetime( """ return Selector( lambda plx: plx.selectors.datetime(time_unit=time_unit, time_zone=time_zone), - ExprMetadata.multi_output_selector(), + ExprMetadata.multi_output_selector_unnamed(), ) From f7f0a9026a8d6d5a53616e7e6e1ea8ac8272cac2 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 14:52:32 +0000 Subject: [PATCH 02/20] lint --- narwhals/_duckdb/namespace.py | 2 +- narwhals/_expression_parsing.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index b6aec3dd8f..b4affa444d 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -294,7 +294,7 @@ def __init__( self: Self, call: DuckDBWhen, *, - function_name: str = '', + function_name: str = "", evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], 
Sequence[str]] | None, backend_version: tuple[int, ...], diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 52f64ead7c..c1a48e7b33 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -121,7 +121,11 @@ def evaluate_output_names_and_aliases( if expr._alias_output_names is None else expr._alias_output_names(output_names) ) - if exclude and expr._metadata is not None and expr._metadata.expansion_kind.is_multi_unnamed(): + if ( + exclude + and expr._metadata is not None + and expr._metadata.expansion_kind.is_multi_unnamed() + ): output_names, aliases = zip( *[(x, alias) for x, alias in zip(output_names, aliases) if x not in exclude] ) From 2a799c878397e4db3791199a9bc31e405462bc0c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 15:02:40 +0000 Subject: [PATCH 03/20] remove unnecessary check --- narwhals/_expression_parsing.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index c1a48e7b33..cafcb92f42 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -121,14 +121,16 @@ def evaluate_output_names_and_aliases( if expr._alias_output_names is None else expr._alias_output_names(output_names) ) - if ( - exclude - and expr._metadata is not None - and expr._metadata.expansion_kind.is_multi_unnamed() - ): - output_names, aliases = zip( - *[(x, alias) for x, alias in zip(output_names, aliases) if x not in exclude] - ) + if exclude: + assert expr._metadata is not None # noqa: S101 + if expr._metadata.expansion_kind.is_multi_unnamed(): + output_names, aliases = zip( + *[ + (x, alias) + for x, alias in zip(output_names, aliases) + if x not in exclude + ] + ) return output_names, aliases From 40aea92d481ca534e11be8ed83c83276422792e2 Mon Sep 17 00:00:00 2001 From: Marco Gorelli 
<33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 15:08:48 +0000 Subject: [PATCH 04/20] clean up duckdb --- narwhals/_duckdb/expr.py | 165 +++++++++----------------------- narwhals/_duckdb/expr_dt.py | 26 ++--- narwhals/_duckdb/expr_list.py | 2 +- narwhals/_duckdb/expr_str.py | 33 +++---- narwhals/_duckdb/expr_struct.py | 3 +- 5 files changed, 67 insertions(+), 162 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 2bf5248f13..34b989e337 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -160,7 +160,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: def _from_call( self: Self, call: Callable[..., duckdb.Expression], - expr_name: str, **expressifiable_args: Self | Any, ) -> Self: """Create expression from callable. @@ -185,7 +184,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return self.__class__( func, - function_name=f"{self._function_name}->{expr_name}", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, @@ -208,138 +206,80 @@ def _with_window_function( return result def __and__(self: Self, other: DuckDBExpr) -> Self: - return self._from_call( - lambda _input, other: _input & other, - "__and__", - other=other, - ) + return self._from_call(lambda _input, other: _input & other, other=other) def __or__(self: Self, other: DuckDBExpr) -> Self: - return self._from_call( - lambda _input, other: _input | other, - "__or__", - other=other, - ) + return self._from_call(lambda _input, other: _input | other, other=other) def __add__(self: Self, other: DuckDBExpr) -> Self: - return self._from_call( - lambda _input, other: _input + other, - "__add__", - other=other, - ) + return self._from_call(lambda _input, other: _input + other, other=other) def __truediv__(self: Self, other: DuckDBExpr) -> Self: - return self._from_call( - lambda _input, other: _input / other, - "__truediv__", - 
other=other, - ) + return self._from_call(lambda _input, other: _input / other, other=other) def __rtruediv__(self: Self, other: DuckDBExpr) -> Self: return self._from_call( - lambda _input, other: other.__truediv__(_input), "__rtruediv__", other=other + lambda _input, other: other.__truediv__(_input), other=other ).alias("literal") def __floordiv__(self: Self, other: DuckDBExpr) -> Self: return self._from_call( - lambda _input, other: _input.__floordiv__(other), - "__floordiv__", - other=other, + lambda _input, other: _input.__floordiv__(other), other=other ) def __rfloordiv__(self: Self, other: DuckDBExpr) -> Self: return self._from_call( - lambda _input, other: other.__floordiv__(_input), "__rfloordiv__", other=other + lambda _input, other: other.__floordiv__(_input), other=other ).alias("literal") def __mod__(self: Self, other: DuckDBExpr) -> Self: - return self._from_call( - lambda _input, other: _input.__mod__(other), - "__mod__", - other=other, - ) + return self._from_call(lambda _input, other: _input.__mod__(other), other=other) def __rmod__(self: Self, other: DuckDBExpr) -> Self: return self._from_call( - lambda _input, other: other.__mod__(_input), "__rmod__", other=other + lambda _input, other: other.__mod__(_input), other=other ).alias("literal") def __sub__(self: Self, other: DuckDBExpr) -> Self: - return self._from_call( - lambda _input, other: _input - other, - "__sub__", - other=other, - ) + return self._from_call(lambda _input, other: _input - other, other=other) def __rsub__(self: Self, other: DuckDBExpr) -> Self: return self._from_call( - lambda _input, other: other.__sub__(_input), "__rsub__", other=other + lambda _input, other: other.__sub__(_input), other=other ).alias("literal") def __mul__(self: Self, other: DuckDBExpr) -> Self: - return self._from_call( - lambda _input, other: _input * other, - "__mul__", - other=other, - ) + return self._from_call(lambda _input, other: _input * other, other=other) def __pow__(self: Self, other: DuckDBExpr) 
-> Self: - return self._from_call( - lambda _input, other: _input**other, - "__pow__", - other=other, - ) + return self._from_call(lambda _input, other: _input**other, other=other) def __rpow__(self: Self, other: DuckDBExpr) -> Self: return self._from_call( - lambda _input, other: other.__pow__(_input), "__rpow__", other=other + lambda _input, other: other.__pow__(_input), other=other ).alias("literal") def __lt__(self: Self, other: DuckDBExpr) -> Self: - return self._from_call( - lambda _input, other: _input < other, - "__lt__", - other=other, - ) + return self._from_call(lambda _input, other: _input < other, other=other) def __gt__(self: Self, other: DuckDBExpr) -> Self: - return self._from_call( - lambda _input, other: _input > other, - "__gt__", - other=other, - ) + return self._from_call(lambda _input, other: _input > other, other=other) def __le__(self: Self, other: DuckDBExpr) -> Self: - return self._from_call( - lambda _input, other: _input <= other, - "__le__", - other=other, - ) + return self._from_call(lambda _input, other: _input <= other, other=other) def __ge__(self: Self, other: DuckDBExpr) -> Self: - return self._from_call( - lambda _input, other: _input >= other, - "__ge__", - other=other, - ) + return self._from_call(lambda _input, other: _input >= other, other=other) def __eq__(self: Self, other: DuckDBExpr) -> Self: # type: ignore[override] - return self._from_call( - lambda _input, other: _input == other, - "__eq__", - other=other, - ) + return self._from_call(lambda _input, other: _input == other, other=other) def __ne__(self: Self, other: DuckDBExpr) -> Self: # type: ignore[override] - return self._from_call( - lambda _input, other: _input != other, - "__ne__", - other=other, - ) + return self._from_call(lambda _input, other: _input != other, other=other) def __invert__(self: Self) -> Self: invert = cast("Callable[..., duckdb.Expression]", operator.invert) - return self._from_call(invert, "__invert__") + return self._from_call(invert) def 
alias(self: Self, name: str) -> Self: def alias_output_names(names: Sequence[str]) -> Sequence[str]: @@ -350,7 +290,6 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: return self.__class__( self._call, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=alias_output_names, backend_version=self._backend_version, @@ -358,10 +297,10 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: ) def abs(self: Self) -> Self: - return self._from_call(lambda _input: FunctionExpression("abs", _input), "abs") + return self._from_call(lambda _input: FunctionExpression("abs", _input)) def mean(self: Self) -> Self: - return self._from_call(lambda _input: FunctionExpression("mean", _input), "mean") + return self._from_call(lambda _input: FunctionExpression("mean", _input)) def skew(self: Self) -> Self: def func(_input: duckdb.Expression) -> duckdb.Expression: @@ -379,22 +318,16 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: ) ) - return self._from_call(func, "skew") + return self._from_call(func) def median(self: Self) -> Self: - return self._from_call( - lambda _input: FunctionExpression("median", _input), "median" - ) + return self._from_call(lambda _input: FunctionExpression("median", _input)) def all(self: Self) -> Self: - return self._from_call( - lambda _input: FunctionExpression("bool_and", _input), "all" - ) + return self._from_call(lambda _input: FunctionExpression("bool_and", _input)) def any(self: Self) -> Self: - return self._from_call( - lambda _input: FunctionExpression("bool_or", _input), "any" - ) + return self._from_call(lambda _input: FunctionExpression("bool_or", _input)) def quantile( self: Self, @@ -407,7 +340,7 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: msg = "Only linear interpolation methods are supported for DuckDB quantile." 
raise NotImplementedError(msg) - return self._from_call(func, "quantile") + return self._from_call(func) def clip(self: Self, lower_bound: Any, upper_bound: Any) -> Self: def func( @@ -417,12 +350,10 @@ def func( "greatest", FunctionExpression("least", _input, upper_bound), lower_bound ) - return self._from_call( - func, "clip", lower_bound=lower_bound, upper_bound=upper_bound - ) + return self._from_call(func, lower_bound=lower_bound, upper_bound=upper_bound) def sum(self: Self) -> Self: - return self._from_call(lambda _input: FunctionExpression("sum", _input), "sum") + return self._from_call(lambda _input: FunctionExpression("sum", _input)) def n_unique(self: Self) -> Self: def func(_input: duckdb.Expression) -> duckdb.Expression: @@ -436,15 +367,13 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: ), ) - return self._from_call(func, "n_unique") + return self._from_call(func) def count(self: Self) -> Self: - return self._from_call( - lambda _input: FunctionExpression("count", _input), "count" - ) + return self._from_call(lambda _input: FunctionExpression("count", _input)) def len(self: Self) -> Self: - return self._from_call(lambda _input: FunctionExpression("count"), "len") + return self._from_call(lambda _input: FunctionExpression("count")) def std(self: Self, ddof: int) -> Self: def _std(_input: duckdb.Expression) -> duckdb.Expression: @@ -456,7 +385,7 @@ def _std(_input: duckdb.Expression) -> duckdb.Expression: / (FunctionExpression("sqrt", (n_samples - ddof))) # type: ignore[operator] ) - return self._from_call(_std, "std") + return self._from_call(_std) def var(self: Self, ddof: int) -> Self: def _var(_input: duckdb.Expression) -> duckdb.Expression: @@ -464,18 +393,17 @@ def _var(_input: duckdb.Expression) -> duckdb.Expression: # NOTE: Not implemented Error: Unable to transform python value of type '' to DuckDB LogicalType return FunctionExpression("var_pop", _input) * n_samples / (n_samples - ddof) # type: ignore[operator, no-any-return] - 
return self._from_call(_var, "var") + return self._from_call(_var) def max(self: Self) -> Self: - return self._from_call(lambda _input: FunctionExpression("max", _input), "max") + return self._from_call(lambda _input: FunctionExpression("max", _input)) def min(self: Self) -> Self: - return self._from_call(lambda _input: FunctionExpression("min", _input), "min") + return self._from_call(lambda _input: FunctionExpression("min", _input)) def null_count(self: Self) -> Self: return self._from_call( lambda _input: FunctionExpression("sum", _input.isnull().cast("int")), - "null_count", ) def over( @@ -505,7 +433,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return self.__class__( func, - function_name=self._function_name + "->over", evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, @@ -513,26 +440,22 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: ) def is_null(self: Self) -> Self: - return self._from_call(lambda _input: _input.isnull(), "is_null") + return self._from_call(lambda _input: _input.isnull()) def is_nan(self: Self) -> Self: - return self._from_call( - lambda _input: FunctionExpression("isnan", _input), "is_nan" - ) + return self._from_call(lambda _input: FunctionExpression("isnan", _input)) def is_finite(self: Self) -> Self: - return self._from_call( - lambda _input: FunctionExpression("isfinite", _input), "is_finite" - ) + return self._from_call(lambda _input: FunctionExpression("isfinite", _input)) def is_in(self: Self, other: Sequence[Any]) -> Self: return self._from_call( - lambda _input: FunctionExpression("contains", lit(other), _input), "is_in" + lambda _input: FunctionExpression("contains", lit(other), _input) ) def round(self: Self, decimals: int) -> Self: return self._from_call( - lambda _input: FunctionExpression("round", _input, lit(decimals)), "round" + lambda _input: FunctionExpression("round", _input, lit(decimals)) ) def 
cum_sum(self, *, reverse: bool) -> Self: @@ -581,14 +504,14 @@ def fill_null( def func(_input: duckdb.Expression, value: Any) -> duckdb.Expression: return CoalesceOperator(_input, value) - return self._from_call(func, "fill_null", value=value) + return self._from_call(func, value=value) def cast(self: Self, dtype: DType | type[DType]) -> Self: def func(_input: duckdb.Expression) -> duckdb.Expression: native_dtype = narwhals_to_native_dtype(dtype, self._version) return _input.cast(DuckDBPyType(native_dtype)) - return self._from_call(func, "cast") + return self._from_call(func) @property def str(self: Self) -> DuckDBExprStringNamespace: diff --git a/narwhals/_duckdb/expr_dt.py b/narwhals/_duckdb/expr_dt.py index e5e75a8004..08465de819 100644 --- a/narwhals/_duckdb/expr_dt.py +++ b/narwhals/_duckdb/expr_dt.py @@ -18,85 +18,79 @@ def __init__(self: Self, expr: DuckDBExpr) -> None: def year(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("year", _input), "year" + lambda _input: FunctionExpression("year", _input) ) def month(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("month", _input), "month" + lambda _input: FunctionExpression("month", _input) ) def day(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("day", _input), "day" + lambda _input: FunctionExpression("day", _input) ) def hour(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("hour", _input), "hour" + lambda _input: FunctionExpression("hour", _input) ) def minute(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("minute", _input), "minute" + lambda _input: FunctionExpression("minute", _input) ) def second(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("second", _input), "second" + lambda 
_input: FunctionExpression("second", _input) ) def millisecond(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("millisecond", _input) - FunctionExpression("second", _input) * lit(1_000), - "millisecond", ) def microsecond(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("microsecond", _input) - FunctionExpression("second", _input) * lit(1_000_000), - "microsecond", ) def nanosecond(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("nanosecond", _input) - FunctionExpression("second", _input) * lit(1_000_000_000), - "nanosecond", ) def to_string(self: Self, format: str) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("strftime", _input, lit(format)), - "to_string", ) def weekday(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("isodow", _input), "weekday" + lambda _input: FunctionExpression("isodow", _input) ) def ordinal_day(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("dayofyear", _input), "ordinal_day" + lambda _input: FunctionExpression("dayofyear", _input) ) def date(self: Self) -> DuckDBExpr: - return self._compliant_expr._from_call(lambda _input: _input.cast("date"), "date") + return self._compliant_expr._from_call(lambda _input: _input.cast("date")) def total_minutes(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("datepart", lit("minute"), _input), - "total_minutes", ) def total_seconds(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: lit(60) * FunctionExpression("datepart", lit("minute"), _input) + FunctionExpression("datepart", lit("second"), _input), - "total_seconds", ) def total_milliseconds(self: Self) -> DuckDBExpr: @@ -104,7 +98,6 @@ def total_milliseconds(self: 
Self) -> DuckDBExpr: lambda _input: lit(60_000) * FunctionExpression("datepart", lit("minute"), _input) + FunctionExpression("datepart", lit("millisecond"), _input), - "total_milliseconds", ) def total_microseconds(self: Self) -> DuckDBExpr: @@ -112,7 +105,6 @@ def total_microseconds(self: Self) -> DuckDBExpr: lambda _input: lit(60_000_000) * FunctionExpression("datepart", lit("minute"), _input) + FunctionExpression("datepart", lit("microsecond"), _input), - "total_microseconds", ) def total_nanoseconds(self: Self) -> DuckDBExpr: diff --git a/narwhals/_duckdb/expr_list.py b/narwhals/_duckdb/expr_list.py index 45cd93df40..de7290c1de 100644 --- a/narwhals/_duckdb/expr_list.py +++ b/narwhals/_duckdb/expr_list.py @@ -16,5 +16,5 @@ def __init__(self: Self, expr: DuckDBExpr) -> None: def len(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("len", _input), "len" + lambda _input: FunctionExpression("len", _input) ) diff --git a/narwhals/_duckdb/expr_str.py b/narwhals/_duckdb/expr_str.py index 191ea40291..a3d77b05aa 100644 --- a/narwhals/_duckdb/expr_str.py +++ b/narwhals/_duckdb/expr_str.py @@ -20,14 +20,12 @@ def __init__(self: Self, expr: DuckDBExpr) -> None: def starts_with(self: Self, prefix: str) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("starts_with", _input, lit(prefix)), - "starts_with", + lambda _input: FunctionExpression("starts_with", _input, lit(prefix)) ) def ends_with(self: Self, suffix: str) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("ends_with", _input, lit(suffix)), - "ends_with", + lambda _input: FunctionExpression("ends_with", _input, lit(suffix)) ) def contains(self: Self, pattern: str, *, literal: bool) -> DuckDBExpr: @@ -36,7 +34,7 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: return FunctionExpression("contains", _input, lit(pattern)) return FunctionExpression("regexp_matches", _input, 
lit(pattern)) - return self._compliant_expr._from_call(func, "contains") + return self._compliant_expr._from_call(func) def slice(self: Self, offset: int, length: int) -> DuckDBExpr: def func(_input: duckdb.Expression) -> duckdb.Expression: @@ -52,27 +50,26 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: else lit(length) + offset_lit, ) - return self._compliant_expr._from_call(func, "slice") + return self._compliant_expr._from_call(func) def split(self: Self, by: str) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("str_split", _input, lit(by)), - "split", + lambda _input: FunctionExpression("str_split", _input, lit(by)) ) def len_chars(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("length", _input), "len_chars" + lambda _input: FunctionExpression("length", _input) ) def to_lowercase(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("lower", _input), "to_lowercase" + lambda _input: FunctionExpression("lower", _input) ) def to_uppercase(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("upper", _input), "to_uppercase" + lambda _input: FunctionExpression("upper", _input) ) def strip_chars(self: Self, characters: str | None) -> DuckDBExpr: @@ -83,8 +80,7 @@ def strip_chars(self: Self, characters: str | None) -> DuckDBExpr: "trim", _input, lit(string.whitespace if characters is None else characters), - ), - "strip_chars", + ) ) def replace_all(self: Self, pattern: str, value: str, *, literal: bool) -> DuckDBExpr: @@ -92,14 +88,10 @@ def replace_all(self: Self, pattern: str, value: str, *, literal: bool) -> DuckD return self._compliant_expr._from_call( lambda _input: FunctionExpression( "regexp_replace", _input, lit(pattern), lit(value), lit("g") - ), - "replace_all", + ) ) return self._compliant_expr._from_call( - lambda _input: FunctionExpression( - 
"replace", _input, lit(pattern), lit(value) - ), - "replace_all", + lambda _input: FunctionExpression("replace", _input, lit(pattern), lit(value)) ) def replace(self: Self, pattern: str, value: str, *, literal: bool, n: int) -> Never: @@ -112,6 +104,5 @@ def to_datetime(self: Self, format: str | None) -> DuckDBExpr: raise NotImplementedError(msg) return self._compliant_expr._from_call( - lambda _input: FunctionExpression("strptime", _input, lit(format)), - "to_datetime", + lambda _input: FunctionExpression("strptime", _input, lit(format)) ) diff --git a/narwhals/_duckdb/expr_struct.py b/narwhals/_duckdb/expr_struct.py index 1f750e1326..2c4dc114c4 100644 --- a/narwhals/_duckdb/expr_struct.py +++ b/narwhals/_duckdb/expr_struct.py @@ -18,6 +18,5 @@ def __init__(self: Self, expr: DuckDBExpr) -> None: def field(self: Self, name: str) -> DuckDBExpr: return self._compliant_expr._from_call( - lambda _input: FunctionExpression("struct_extract", _input, lit(name)), - "field", + lambda _input: FunctionExpression("struct_extract", _input, lit(name)) ).alias(name) From bac71ce174c66c17aba03414639313807739263f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 15:13:49 +0000 Subject: [PATCH 05/20] clean up spark-like --- narwhals/_spark_like/expr.py | 121 +++++++++++----------------- narwhals/_spark_like/expr_dt.py | 26 +++--- narwhals/_spark_like/expr_list.py | 2 +- narwhals/_spark_like/expr_str.py | 28 +++---- narwhals/_spark_like/expr_struct.py | 2 +- 5 files changed, 71 insertions(+), 108 deletions(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 055e2a60cf..ca6e18c5eb 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -193,7 +193,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: def _from_call( self: Self, call: Callable[..., Column], - expr_name: str, **expressifiable_args: Self | Any, ) -> Self: def func(df: SparkLikeLazyFrame) -> 
list[Column]: @@ -217,113 +216,87 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: ) def __eq__(self: Self, other: SparkLikeExpr) -> Self: # type: ignore[override] - return self._from_call( - lambda _input, other: _input.__eq__(other), "__eq__", other=other - ) + return self._from_call(lambda _input, other: _input.__eq__(other), other=other) def __ne__(self: Self, other: SparkLikeExpr) -> Self: # type: ignore[override] - return self._from_call( - lambda _input, other: _input.__ne__(other), "__ne__", other=other - ) + return self._from_call(lambda _input, other: _input.__ne__(other), other=other) def __add__(self: Self, other: SparkLikeExpr) -> Self: - return self._from_call( - lambda _input, other: _input.__add__(other), "__add__", other=other - ) + return self._from_call(lambda _input, other: _input.__add__(other), other=other) def __sub__(self: Self, other: SparkLikeExpr) -> Self: - return self._from_call( - lambda _input, other: _input.__sub__(other), "__sub__", other=other - ) + return self._from_call(lambda _input, other: _input.__sub__(other), other=other) def __rsub__(self: Self, other: SparkLikeExpr) -> Self: return self._from_call( - lambda _input, other: other.__sub__(_input), "__rsub__", other=other + lambda _input, other: other.__sub__(_input), other=other ).alias("literal") def __mul__(self: Self, other: SparkLikeExpr) -> Self: - return self._from_call( - lambda _input, other: _input.__mul__(other), "__mul__", other=other - ) + return self._from_call(lambda _input, other: _input.__mul__(other), other=other) def __truediv__(self: Self, other: SparkLikeExpr) -> Self: return self._from_call( - lambda _input, other: _input.__truediv__(other), "__truediv__", other=other + lambda _input, other: _input.__truediv__(other), other=other ) def __rtruediv__(self: Self, other: SparkLikeExpr) -> Self: return self._from_call( - lambda _input, other: other.__truediv__(_input), "__rtruediv__", other=other + lambda _input, other: other.__truediv__(_input), 
other=other ).alias("literal") def __floordiv__(self: Self, other: SparkLikeExpr) -> Self: def _floordiv(_input: Column, other: Column) -> Column: return self._F.floor(_input / other) - return self._from_call(_floordiv, "__floordiv__", other=other) + return self._from_call(_floordiv, other=other) def __rfloordiv__(self: Self, other: SparkLikeExpr) -> Self: def _rfloordiv(_input: Column, other: Column) -> Column: return self._F.floor(other / _input) - return self._from_call(_rfloordiv, "__rfloordiv__", other=other).alias("literal") + return self._from_call(_rfloordiv, other=other).alias("literal") def __pow__(self: Self, other: SparkLikeExpr) -> Self: - return self._from_call( - lambda _input, other: _input.__pow__(other), "__pow__", other=other - ) + return self._from_call(lambda _input, other: _input.__pow__(other), other=other) def __rpow__(self: Self, other: SparkLikeExpr) -> Self: return self._from_call( - lambda _input, other: other.__pow__(_input), "__rpow__", other=other + lambda _input, other: other.__pow__(_input), other=other ).alias("literal") def __mod__(self: Self, other: SparkLikeExpr) -> Self: - return self._from_call( - lambda _input, other: _input.__mod__(other), "__mod__", other=other - ) + return self._from_call(lambda _input, other: _input.__mod__(other), other=other) def __rmod__(self: Self, other: SparkLikeExpr) -> Self: return self._from_call( - lambda _input, other: other.__mod__(_input), "__rmod__", other=other + lambda _input, other: other.__mod__(_input), other=other ).alias("literal") def __ge__(self: Self, other: SparkLikeExpr) -> Self: - return self._from_call( - lambda _input, other: _input.__ge__(other), "__ge__", other=other - ) + return self._from_call(lambda _input, other: _input.__ge__(other), other=other) def __gt__(self: Self, other: SparkLikeExpr) -> Self: - return self._from_call( - lambda _input, other: _input > other, "__gt__", other=other - ) + return self._from_call(lambda _input, other: _input > other, other=other) def 
__le__(self: Self, other: SparkLikeExpr) -> Self: - return self._from_call( - lambda _input, other: _input.__le__(other), "__le__", other=other - ) + return self._from_call(lambda _input, other: _input.__le__(other), other=other) def __lt__(self: Self, other: SparkLikeExpr) -> Self: - return self._from_call( - lambda _input, other: _input.__lt__(other), "__lt__", other=other - ) + return self._from_call(lambda _input, other: _input.__lt__(other), other=other) def __and__(self: Self, other: SparkLikeExpr) -> Self: - return self._from_call( - lambda _input, other: _input.__and__(other), "__and__", other=other - ) + return self._from_call(lambda _input, other: _input.__and__(other), other=other) def __or__(self: Self, other: SparkLikeExpr) -> Self: - return self._from_call( - lambda _input, other: _input.__or__(other), "__or__", other=other - ) + return self._from_call(lambda _input, other: _input.__or__(other), other=other) def __invert__(self: Self) -> Self: invert = cast("Callable[..., Column]", operator.invert) - return self._from_call(invert, "__invert__") + return self._from_call(invert) def abs(self: Self) -> Self: - return self._from_call(self._F.abs, "abs") + return self._from_call(self._F.abs) def alias(self: Self, name: str) -> Self: def alias_output_names(names: Sequence[str]) -> Sequence[str]: @@ -342,10 +315,10 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: ) def all(self: Self) -> Self: - return self._from_call(self._F.bool_and, "all") + return self._from_call(self._F.bool_and) def any(self: Self) -> Self: - return self._from_call(self._F.bool_or, "any") + return self._from_call(self._F.bool_or) def cast(self: Self, dtype: DType | type[DType]) -> Self: def _cast(_input: Column) -> Column: @@ -354,16 +327,16 @@ def _cast(_input: Column) -> Column: ) return _input.cast(spark_dtype) - return self._from_call(_cast, "cast") + return self._from_call(_cast) def count(self: Self) -> Self: - return self._from_call(self._F.count, "count") + 
return self._from_call(self._F.count) def max(self: Self) -> Self: - return self._from_call(self._F.max, "max") + return self._from_call(self._F.max) def mean(self: Self) -> Self: - return self._from_call(self._F.mean, "mean") + return self._from_call(self._F.mean) def median(self: Self) -> Self: def _median(_input: Column) -> Column: @@ -377,19 +350,19 @@ def _median(_input: Column) -> Column: return self._F.median(_input) - return self._from_call(_median, "median") + return self._from_call(_median) def min(self: Self) -> Self: - return self._from_call(self._F.min, "min") + return self._from_call(self._F.min) def null_count(self: Self) -> Self: def _null_count(_input: Column) -> Column: return self._F.count_if(self._F.isnull(_input)) - return self._from_call(_null_count, "null_count") + return self._from_call(_null_count) def sum(self: Self) -> Self: - return self._from_call(self._F.sum, "sum") + return self._from_call(self._F.sum) def std(self: Self, ddof: int) -> Self: from functools import partial @@ -406,7 +379,7 @@ def std(self: Self, ddof: int) -> Self: implementation=self._implementation, ) - return self._from_call(func, "std") + return self._from_call(func) def var(self: Self, ddof: int) -> Self: from functools import partial @@ -423,7 +396,7 @@ def var(self: Self, ddof: int) -> Self: implementation=self._implementation, ) - return self._from_call(func, "var") + return self._from_call(func) def clip( self: Self, @@ -446,11 +419,11 @@ def _clip_both( return self._F.when(result > upper_bound, upper_bound).otherwise(result) if lower_bound is None: - return self._from_call(_clip_upper, "clip", upper_bound=upper_bound) + return self._from_call(_clip_upper, upper_bound=upper_bound) if upper_bound is None: - return self._from_call(_clip_lower, "clip", lower_bound=lower_bound) + return self._from_call(_clip_lower, lower_bound=lower_bound) return self._from_call( - _clip_both, "clip", lower_bound=lower_bound, upper_bound=upper_bound + _clip_both, 
lower_bound=lower_bound, upper_bound=upper_bound ) def is_finite(self: Self) -> Self: @@ -466,36 +439,36 @@ def _is_finite(_input: Column) -> Column: None ) - return self._from_call(_is_finite, "is_finite") + return self._from_call(_is_finite) def is_in(self: Self, values: Sequence[Any]) -> Self: def _is_in(_input: Column) -> Column: return _input.isin(values) if values else self._F.lit(False) # noqa: FBT003 - return self._from_call(_is_in, "is_in") + return self._from_call(_is_in) def is_unique(self: Self) -> Self: def _is_unique(_input: Column) -> Column: # Create a window spec that treats each value separately return self._F.count("*").over(self._Window.partitionBy(_input)) == 1 - return self._from_call(_is_unique, "is_unique") + return self._from_call(_is_unique) def len(self: Self) -> Self: def _len(_input: Column) -> Column: # Use count(*) to count all rows including nulls return self._F.count("*") - return self._from_call(_len, "len") + return self._from_call(_len) def round(self: Self, decimals: int) -> Self: def _round(_input: Column) -> Column: return self._F.round(_input, decimals) - return self._from_call(_round, "round") + return self._from_call(_round) def skew(self: Self) -> Self: - return self._from_call(self._F.skewness, "skew") + return self._from_call(self._F.skewness) def n_unique(self: Self) -> Self: def _n_unique(_input: Column) -> Column: @@ -503,7 +476,7 @@ def _n_unique(_input: Column) -> Column: self._F.isnull(_input).cast(self._native_dtypes.IntegerType()) ) - return self._from_call(_n_unique, "n_unique") + return self._from_call(_n_unique) def over( self: Self, @@ -536,7 +509,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: ) def is_null(self: Self) -> Self: - return self._from_call(self._F.isnull, "is_null") + return self._from_call(self._F.isnull) def is_nan(self: Self) -> Self: def _is_nan(_input: Column) -> Column: @@ -544,7 +517,7 @@ def _is_nan(_input: Column) -> Column: self._F.isnan(_input) ) - return 
self._from_call(_is_nan, "is_nan") + return self._from_call(_is_nan) def cum_sum(self, *, reverse: bool) -> Self: def func( @@ -577,7 +550,7 @@ def fill_null( def _fill_null(_input: Column, value: Column) -> Column: return self._F.ifnull(_input, value) - return self._from_call(_fill_null, "fill_null", value=value) + return self._from_call(_fill_null, value=value) def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self: if center: diff --git a/narwhals/_spark_like/expr_dt.py b/narwhals/_spark_like/expr_dt.py index 9b1cec04d9..2b60097f3b 100644 --- a/narwhals/_spark_like/expr_dt.py +++ b/narwhals/_spark_like/expr_dt.py @@ -14,25 +14,25 @@ def __init__(self: Self, expr: SparkLikeExpr) -> None: self._compliant_expr = expr def date(self: Self) -> SparkLikeExpr: - return self._compliant_expr._from_call(self._compliant_expr._F.to_date, "date") + return self._compliant_expr._from_call(self._compliant_expr._F.to_date) def year(self: Self) -> SparkLikeExpr: - return self._compliant_expr._from_call(self._compliant_expr._F.year, "year") + return self._compliant_expr._from_call(self._compliant_expr._F.year) def month(self: Self) -> SparkLikeExpr: - return self._compliant_expr._from_call(self._compliant_expr._F.month, "month") + return self._compliant_expr._from_call(self._compliant_expr._F.month) def day(self: Self) -> SparkLikeExpr: - return self._compliant_expr._from_call(self._compliant_expr._F.day, "day") + return self._compliant_expr._from_call(self._compliant_expr._F.day) def hour(self: Self) -> SparkLikeExpr: - return self._compliant_expr._from_call(self._compliant_expr._F.hour, "hour") + return self._compliant_expr._from_call(self._compliant_expr._F.hour) def minute(self: Self) -> SparkLikeExpr: - return self._compliant_expr._from_call(self._compliant_expr._F.minute, "minute") + return self._compliant_expr._from_call(self._compliant_expr._F.minute) def second(self: Self) -> SparkLikeExpr: - return 
self._compliant_expr._from_call(self._compliant_expr._F.second, "second") + return self._compliant_expr._from_call(self._compliant_expr._F.second) def millisecond(self: Self) -> SparkLikeExpr: def _millisecond(_input: Column) -> Column: @@ -40,28 +40,26 @@ def _millisecond(_input: Column) -> Column: (self._compliant_expr._F.unix_micros(_input) % 1_000_000) / 1000 ) - return self._compliant_expr._from_call(_millisecond, "millisecond") + return self._compliant_expr._from_call(_millisecond) def microsecond(self: Self) -> SparkLikeExpr: def _microsecond(_input: Column) -> Column: return self._compliant_expr._F.unix_micros(_input) % 1_000_000 - return self._compliant_expr._from_call(_microsecond, "microsecond") + return self._compliant_expr._from_call(_microsecond) def nanosecond(self: Self) -> SparkLikeExpr: def _nanosecond(_input: Column) -> Column: return (self._compliant_expr._F.unix_micros(_input) % 1_000_000) * 1000 - return self._compliant_expr._from_call(_nanosecond, "nanosecond") + return self._compliant_expr._from_call(_nanosecond) def ordinal_day(self: Self) -> SparkLikeExpr: - return self._compliant_expr._from_call( - self._compliant_expr._F.dayofyear, "ordinal_day" - ) + return self._compliant_expr._from_call(self._compliant_expr._F.dayofyear) def weekday(self: Self) -> SparkLikeExpr: def _weekday(_input: Column) -> Column: # PySpark's dayofweek returns 1-7 for Sunday-Saturday return (self._compliant_expr._F.dayofweek(_input) + 6) % 7 - return self._compliant_expr._from_call(_weekday, "weekday") + return self._compliant_expr._from_call(_weekday) diff --git a/narwhals/_spark_like/expr_list.py b/narwhals/_spark_like/expr_list.py index 006a4ff659..bc125ce083 100644 --- a/narwhals/_spark_like/expr_list.py +++ b/narwhals/_spark_like/expr_list.py @@ -13,4 +13,4 @@ def __init__(self: Self, expr: SparkLikeExpr) -> None: self._compliant_expr = expr def len(self: Self) -> SparkLikeExpr: - return self._compliant_expr._from_call(self._compliant_expr._F.array_size, 
"len") + return self._compliant_expr._from_call(self._compliant_expr._F.array_size) diff --git a/narwhals/_spark_like/expr_str.py b/narwhals/_spark_like/expr_str.py index 36bf2990f1..2dc2121035 100644 --- a/narwhals/_spark_like/expr_str.py +++ b/narwhals/_spark_like/expr_str.py @@ -15,7 +15,7 @@ def __init__(self: Self, expr: SparkLikeExpr) -> None: self._compliant_expr = expr def len_chars(self: Self) -> SparkLikeExpr: - return self._compliant_expr._from_call(self._compliant_expr._F.char_length, "len") + return self._compliant_expr._from_call(self._compliant_expr._F.char_length) def replace_all( self: Self, pattern: str, value: str, *, literal: bool @@ -32,7 +32,7 @@ def func(_input: Column) -> Column: self._compliant_expr._F.lit(value), # pyright: ignore[reportArgumentType] ) - return self._compliant_expr._from_call(func, "replace") + return self._compliant_expr._from_call(func) def strip_chars(self: Self, characters: str | None) -> SparkLikeExpr: import string @@ -43,22 +43,20 @@ def func(_input: Column) -> Column: _input, self._compliant_expr._F.lit(to_remove) ) - return self._compliant_expr._from_call(func, "strip") + return self._compliant_expr._from_call(func) def starts_with(self: Self, prefix: str) -> SparkLikeExpr: return self._compliant_expr._from_call( lambda _input: self._compliant_expr._F.startswith( _input, self._compliant_expr._F.lit(prefix) - ), - "starts_with", + ) ) def ends_with(self: Self, suffix: str) -> SparkLikeExpr: return self._compliant_expr._from_call( lambda _input: self._compliant_expr._F.endswith( _input, self._compliant_expr._F.lit(suffix) - ), - "ends_with", + ) ) def contains(self: Self, pattern: str, *, literal: bool) -> SparkLikeExpr: @@ -70,7 +68,7 @@ def func(_input: Column) -> Column: ) return contains_func(_input, self._compliant_expr._F.lit(pattern)) - return self._compliant_expr._from_call(func, "contains") + return self._compliant_expr._from_call(func) def slice(self: Self, offset: int, length: int | None) -> 
SparkLikeExpr: # From the docs: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.substring.html @@ -88,23 +86,18 @@ def func(_input: Column) -> Column: ) return _input.substr(_offset, _length) - return self._compliant_expr._from_call(func, "slice") + return self._compliant_expr._from_call(func) def split(self: Self, by: str) -> SparkLikeExpr: return self._compliant_expr._from_call( - lambda _input: self._compliant_expr._F.split(_input, by), - "split", + lambda _input: self._compliant_expr._F.split(_input, by) ) def to_uppercase(self: Self) -> SparkLikeExpr: - return self._compliant_expr._from_call( - self._compliant_expr._F.upper, "to_uppercase" - ) + return self._compliant_expr._from_call(self._compliant_expr._F.upper) def to_lowercase(self: Self) -> SparkLikeExpr: - return self._compliant_expr._from_call( - self._compliant_expr._F.lower, "to_lowercase" - ) + return self._compliant_expr._from_call(self._compliant_expr._F.lower) def to_datetime(self: Self, format: str | None) -> SparkLikeExpr: F = self._compliant_expr._F # noqa: N806 @@ -119,7 +112,6 @@ def to_datetime(self: Self, format: str | None) -> SparkLikeExpr: function = partial(F.to_timestamp, format=format) return self._compliant_expr._from_call( lambda _input: function(F.replace(_input, F.lit("T"), F.lit(" "))), - "to_datetime", ) diff --git a/narwhals/_spark_like/expr_struct.py b/narwhals/_spark_like/expr_struct.py index dab13daec2..ff1dad874e 100644 --- a/narwhals/_spark_like/expr_struct.py +++ b/narwhals/_spark_like/expr_struct.py @@ -17,4 +17,4 @@ def field(self: Self, name: str) -> SparkLikeExpr: def func(_input: Column) -> Column: return _input.getField(name) - return self._compliant_expr._from_call(func, "field").alias(name) + return self._compliant_expr._from_call(func).alias(name) From c8f483d53dbf2ca32f3d0e3ee67c1bf5a93d481d Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 19:15:35 
+0000 Subject: [PATCH 06/20] snake case --- narwhals/_expression_parsing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index cafcb92f42..6e63d78930 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -195,20 +195,20 @@ class ExpansionKind(Enum): SINGLE = auto() """e.g. `nw.col('a'), nw.sum_horizontal(nw.all())`""" - MULTINAMED = auto() + MULTI_NAMED = auto() """e.g. `nw.col('a', 'b')`""" - MULTIUNNAMED = auto() + MULTI_UNNAMED = auto() """e.g. `nw.all()`, nw.nth(0, 1)""" def is_multi_unnamed(self) -> bool: - return self is ExpansionKind.MULTIUNNAMED + return self is ExpansionKind.MULTI_UNNAMED def is_multi_output( expansion_kind: ExpansionKind, -) -> TypeIs[Literal[ExpansionKind.MULTINAMED, ExpansionKind.MULTIUNNAMED]]: - return expansion_kind in {ExpansionKind.MULTINAMED, ExpansionKind.MULTIUNNAMED} +) -> TypeIs[Literal[ExpansionKind.MULTI_NAMED, ExpansionKind.MULTI_UNNAMED]]: + return expansion_kind in {ExpansionKind.MULTI_NAMED, ExpansionKind.MULTI_UNNAMED} class ExprMetadata: @@ -273,7 +273,7 @@ def simple_selector() -> ExprMetadata: def multi_output_selector_named() -> ExprMetadata: # e.g. 
`nw.col('a', 'b')` return ExprMetadata( - ExprKind.TRANSFORM, n_open_windows=0, expansion_kind=ExpansionKind.MULTINAMED + ExprKind.TRANSFORM, n_open_windows=0, expansion_kind=ExpansionKind.MULTI_NAMED ) @staticmethod @@ -282,7 +282,7 @@ def multi_output_selector_unnamed() -> ExprMetadata: return ExprMetadata( ExprKind.TRANSFORM, n_open_windows=0, - expansion_kind=ExpansionKind.MULTIUNNAMED, + expansion_kind=ExpansionKind.MULTI_UNNAMED, ) From 3fd7284e7ccbe84c5f726053ab56c389eec1b262 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 19:31:49 +0000 Subject: [PATCH 07/20] factor out `resolve_expansion_kind` --- narwhals/_expression_parsing.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 6e63d78930..0aae57d40c 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -320,15 +320,13 @@ def combine_metadata( "are not supported in this context." ) raise MultiOutputExpressionError(msg) - if not to_single_output: # pragma: no cover - if i != 0 and arg._metadata.expansion_kind != result_expansion_kind: - msg = "Safety assertion failed, please report a bug." - raise AssertionError(msg) - # Preserve expansion kind. e.g. - # - `nw.all() + nw.col('a')` - # - `nw.selectors.datetime() - nw.selectors.numeric() - # preserve the expansion kind of the left-hand-side. 
- result_expansion_kind = arg._metadata.expansion_kind + if not to_single_output: + if i == 0: + result_expansion_kind = arg._metadata.expansion_kind + else: + result_expansion_kind = resolve_expansion_kind( + result_expansion_kind, arg._metadata.expansion_kind + ) if arg._metadata.n_open_windows: result_n_open_windows += 1 kind = arg._metadata.kind @@ -371,6 +369,15 @@ def combine_metadata( ) +def resolve_expansion_kind(lhs: ExpansionKind, rhs: ExpansionKind) -> ExpansionKind: + if lhs is ExpansionKind.MULTI_UNNAMED and rhs is ExpansionKind.MULTI_UNNAMED: + # e.g. nw.selectors.all() - nw.selectors.numeric(). + return ExpansionKind.MULTI_UNNAMED + # Don't attempt anything more complex, keep it simple and raise in the face of ambiguity. + msg = f"Unsupported ExpansionKind combination, got {lhs} and {rhs}, please report a bug." # pragma: no cover + raise AssertionError(msg) # pragma: no cover + + def combine_metadata_binary_op(lhs: Expr, rhs: IntoExpr) -> ExprMetadata: # We may be able to allow multi-output rhs in the future: # https://github.com/narwhals-dev/narwhals/issues/2244. 
From d06b4e544023d01f582159b4ee02ae490cb2a8d0 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 20:28:51 +0000 Subject: [PATCH 08/20] completely remove function_name from duckdb and pyspark --- Makefile | 4 ++-- narwhals/_arrow/namespace.py | 24 ++++++++++++++++++++++++ narwhals/_compliant/expr.py | 11 ----------- narwhals/_compliant/namespace.py | 28 ++++------------------------ narwhals/_dask/namespace.py | 25 +++++++++++++++++++++++++ narwhals/_duckdb/expr.py | 8 -------- narwhals/_duckdb/namespace.py | 22 ++++++++++++++++++++++ narwhals/_pandas_like/namespace.py | 25 +++++++++++++++++++++++++ narwhals/_spark_like/expr.py | 2 -- narwhals/_spark_like/namespace.py | 22 ++++++++++++++++++++++ 10 files changed, 124 insertions(+), 47 deletions(-) diff --git a/Makefile b/Makefile index 5a0feb5f93..fc297828bf 100644 --- a/Makefile +++ b/Makefile @@ -21,6 +21,6 @@ help: ## Display this help screen .PHONY: typing typing: ## Run typing checks # install duckdb nightly so mypy recognises duckdb.SQLExpression - $(VENV_BIN)/uv pip install -U --pre duckdb - $(VENV_BIN)/uv pip install -e . --group typing + $(VENV_BIN)/uv pip install -U --pre duckdb --offline + $(VENV_BIN)/uv pip install -e . 
--group typing --offline $(VENV_BIN)/mypy diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index a44a045d9d..ad3b7e3f73 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -1,11 +1,13 @@ from __future__ import annotations import operator +from functools import partial from functools import reduce from itertools import chain from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -27,7 +29,10 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation +from narwhals.utils import exclude_column_names +from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing import Callable @@ -60,6 +65,25 @@ def __init__( self._version = version # --- selection --- + def all(self) -> ArrowExpr: + return self._expr.from_column_names( + get_column_names, function_name="all", context=self + ) + + def col(self, *column_names: str) -> ArrowExpr: + return self._expr.from_column_names( + passthrough_column_names(column_names), function_name="col", context=self + ) + + def exclude(self, excluded_names: Container[str]) -> ArrowExpr: + return self._expr.from_column_names( + partial(exclude_column_names, names=excluded_names), + function_name="exclude", + context=self, + ) + + def nth(self, *column_indices: int) -> ArrowExpr: + return self._expr.from_column_indices(*column_indices, context=self) def len(self: Self) -> ArrowExpr: # coverage bug? 
this is definitely hit diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 6960e24b12..eebdfa8b6b 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -95,17 +95,6 @@ def __narwhals_expr__(self) -> None: ... def __narwhals_namespace__( self, ) -> CompliantNamespace[CompliantFrameT, Self]: ... - @classmethod - def from_column_names( - cls, - evaluate_column_names: Callable[[CompliantFrameT], Sequence[str]], - /, - *, - function_name: str, - context: _FullContext, - ) -> Self: ... - @classmethod - def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self: ... def _with_metadata(self, metadata: ExprMetadata) -> Self: ... diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index f5449ec404..43fcdf9090 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -1,6 +1,5 @@ from __future__ import annotations -from functools import partial from typing import TYPE_CHECKING from typing import Any from typing import Container @@ -13,9 +12,6 @@ from narwhals._compliant.typing import EagerDataFrameT from narwhals._compliant.typing import EagerExprT from narwhals._compliant.typing import EagerSeriesT_co -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from narwhals._compliant.selectors import CompliantSelectorNamespace @@ -31,26 +27,10 @@ class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): _backend_version: tuple[int, ...] 
_version: Version - def all(self) -> CompliantExprT: - return self._expr.from_column_names( - get_column_names, function_name="all", context=self - ) - - def col(self, *column_names: str) -> CompliantExprT: - return self._expr.from_column_names( - passthrough_column_names(column_names), function_name="col", context=self - ) - - def exclude(self, excluded_names: Container[str]) -> CompliantExprT: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - context=self, - ) - - def nth(self, *column_indices: int) -> CompliantExprT: - return self._expr.from_column_indices(*column_indices, context=self) - + def all(self) -> CompliantExprT: ... + def col(self, *column_names: str) -> CompliantExprT: ... + def exclude(self, excluded_names: Container[str]) -> CompliantExprT: ... + def nth(self, *column_indices: int) -> CompliantExprT: ... def len(self) -> CompliantExprT: ... def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ... def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... 
diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 9ea92b75c9..da11b74055 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -1,10 +1,12 @@ from __future__ import annotations import operator +from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -24,6 +26,9 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation +from narwhals.utils import exclude_column_names +from narwhals.utils import get_column_names +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing_extensions import Self @@ -54,6 +59,26 @@ def __init__( self._backend_version = backend_version self._version = version + def all(self) -> DaskExpr: + return self._expr.from_column_names( + get_column_names, function_name="all", context=self + ) + + def col(self, *column_names: str) -> DaskExpr: + return self._expr.from_column_names( + passthrough_column_names(column_names), function_name="col", context=self + ) + + def exclude(self, excluded_names: Container[str]) -> DaskExpr: + return self._expr.from_column_names( + partial(exclude_column_names, names=excluded_names), + function_name="exclude", + context=self, + ) + + def nth(self, *column_indices: int) -> DaskExpr: + return self._expr.from_column_indices(*column_indices, context=self) + def lit(self: Self, value: Any, dtype: DType | None) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: if dtype is not None: diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 34b989e337..fbd751fa38 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -55,8 +55,6 @@ def __init__( self: Self, call: 
Callable[[DuckDBLazyFrame], Sequence[duckdb.Expression]], *, - # Unused, just for compatibility with CompliantExpr - function_name: str = "", evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, backend_version: tuple[int, ...], @@ -86,7 +84,6 @@ def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover def _with_metadata(self, metadata: ExprMetadata) -> Self: expr = self.__class__( self._call, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, @@ -111,7 +108,6 @@ def func(df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: return self.__class__( func, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, @@ -124,7 +120,6 @@ def from_column_names( evaluate_column_names: Callable[[DuckDBLazyFrame], Sequence[str]], /, *, - function_name: str, context: _FullContext, ) -> Self: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: @@ -132,7 +127,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return cls( func, - function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, backend_version=context._backend_version, @@ -150,7 +144,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return cls( func, - function_name="nth", evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], alias_output_names=None, backend_version=context._backend_version, @@ -196,7 +189,6 @@ def _with_window_function( ) -> Self: result = self.__class__( self._call, - function_name=self._function_name, evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, diff --git a/narwhals/_duckdb/namespace.py 
b/narwhals/_duckdb/namespace.py index b4affa444d..412892950e 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -1,10 +1,12 @@ from __future__ import annotations import operator +from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -24,6 +26,9 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation +from narwhals.utils import exclude_column_names +from narwhals.utils import get_column_names +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: import duckdb @@ -51,6 +56,23 @@ def selectors(self: Self) -> DuckDBSelectorNamespace: def _expr(self) -> type[DuckDBExpr]: return DuckDBExpr + def col(self, *column_names: str) -> DuckDBExpr: + return self._expr.from_column_names( + passthrough_column_names(column_names), context=self + ) + + def all(self) -> DuckDBExpr: + return self._expr.from_column_names(get_column_names, context=self) + + def exclude(self, excluded_names: Container[str]) -> DuckDBExpr: + return self._expr.from_column_names( + partial(exclude_column_names, names=excluded_names), + context=self, + ) + + def nth(self, *column_indices: int) -> DuckDBExpr: + return self._expr.from_column_indices(*column_indices, context=self) + def concat( self: Self, items: Iterable[DuckDBLazyFrame], diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index dc525cb634..0c6aacb5d8 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -1,10 +1,12 @@ from __future__ import annotations import operator +from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import 
Callable +from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -21,7 +23,10 @@ from narwhals._pandas_like.utils import extract_dataframe_comparand from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import vertical_concat +from narwhals.utils import exclude_column_names +from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing_extensions import Self @@ -61,6 +66,26 @@ def __init__( self._version = version # --- selection --- + def all(self) -> PandasLikeExpr: + return self._expr.from_column_names( + get_column_names, function_name="all", context=self + ) + + def col(self, *column_names: str) -> PandasLikeExpr: + return self._expr.from_column_names( + passthrough_column_names(column_names), function_name="col", context=self + ) + + def exclude(self, excluded_names: Container[str]) -> PandasLikeExpr: + return self._expr.from_column_names( + partial(exclude_column_names, names=excluded_names), + function_name="exclude", + context=self, + ) + + def nth(self, *column_indices: int) -> PandasLikeExpr: + return self._expr.from_column_indices(*column_indices, context=self) + def lit(self: Self, value: Any, dtype: DType | None) -> PandasLikeExpr: def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: pandas_series = self._series.from_iterable( diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index ca6e18c5eb..e801be04a4 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -47,7 +47,6 @@ def __init__( self: Self, call: Callable[[SparkLikeLazyFrame], Sequence[Column]], *, - function_name: str = "", evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, backend_version: tuple[int, ...], @@ -158,7 +157,6 @@ def 
from_column_names( evaluate_column_names: Callable[[SparkLikeLazyFrame], Sequence[str]], /, *, - function_name: str = "", # noqa: ARG003 context: _FullContext, ) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 243c55954d..1c691c925e 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -1,10 +1,12 @@ from __future__ import annotations import operator +from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -17,6 +19,9 @@ from narwhals._spark_like.selectors import SparkLikeSelectorNamespace from narwhals._spark_like.utils import maybe_evaluate_expr from narwhals._spark_like.utils import narwhals_to_native_dtype +from narwhals.utils import exclude_column_names +from narwhals.utils import get_column_names +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from sqlframe.base.column import Column @@ -47,6 +52,23 @@ def selectors(self: Self) -> SparkLikeSelectorNamespace: def _expr(self) -> type[SparkLikeExpr]: return SparkLikeExpr + def col(self, *column_names: str) -> SparkLikeExpr: + return self._expr.from_column_names( + passthrough_column_names(column_names), context=self + ) + + def all(self) -> SparkLikeExpr: + return self._expr.from_column_names(get_column_names, context=self) + + def exclude(self, excluded_names: Container[str]) -> SparkLikeExpr: + return self._expr.from_column_names( + partial(exclude_column_names, names=excluded_names), + context=self, + ) + + def nth(self, *column_indices: int) -> SparkLikeExpr: + return self._expr.from_column_indices(*column_indices, context=self) + def lit(self: Self, value: object, dtype: DType | None) -> SparkLikeExpr: def _lit(df: SparkLikeLazyFrame) -> list[Column]: 
column = df._F.lit(value) From 291bd1082889933d42e599326016c6003ea05310 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 21:01:29 +0000 Subject: [PATCH 09/20] remove --offline --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index fc297828bf..5a0feb5f93 100644 --- a/Makefile +++ b/Makefile @@ -21,6 +21,6 @@ help: ## Display this help screen .PHONY: typing typing: ## Run typing checks # install duckdb nightly so mypy recognises duckdb.SQLExpression - $(VENV_BIN)/uv pip install -U --pre duckdb --offline - $(VENV_BIN)/uv pip install -e . --group typing --offline + $(VENV_BIN)/uv pip install -U --pre duckdb + $(VENV_BIN)/uv pip install -e . --group typing $(VENV_BIN)/mypy From 385fd13a432ae0814a896c2752fce456001a5ce3 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 21:04:04 +0000 Subject: [PATCH 10/20] remove outdated arg --- narwhals/_duckdb/namespace.py | 1 - 1 file changed, 1 deletion(-) diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 412892950e..e6a54826cd 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -316,7 +316,6 @@ def __init__( self: Self, call: DuckDBWhen, *, - function_name: str = "", evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, backend_version: tuple[int, ...], From 2d4dd8a980ad9cf22ef7ab02047df65aa15e4243 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Mar 2025 12:58:26 +0000 Subject: [PATCH 11/20] chore(typing): Ignore issues from #2263 - https://github.com/narwhals-dev/narwhals/pull/2263#discussion_r2009100508 - https://github.com/narwhals-dev/narwhals/pull/2263#discussion_r2009101659 --- narwhals/_duckdb/expr.py | 6 +++--- 1 file 
changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index fbd751fa38..b6940fa90b 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -43,7 +43,7 @@ from narwhals.utils import _FullContext with contextlib.suppress(ImportError): # requires duckdb>=1.3.0 - from duckdb import SQLExpression + from duckdb import SQLExpression # type: ignore # noqa: PGH003 class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]): @@ -459,7 +459,7 @@ def func( order_by_sql = generate_order_by_sql(*order_by, ascending=not reverse) partition_by_sql = generate_partition_by_sql(*partition_by) sql = f"sum ({_input}) over ({partition_by_sql} {order_by_sql} rows between unbounded preceding and current row)" - return SQLExpression(sql) + return SQLExpression(sql) # type: ignore # noqa: PGH003 return self._with_window_function(func) @@ -482,7 +482,7 @@ def func( partition_by_sql = generate_partition_by_sql(*partition_by) window = f"({partition_by_sql} {order_by_sql} rows between {start} and {end})" sql = f"case when count({_input}) over {window} >= {min_samples} then sum({_input}) over {window} else null end" - return SQLExpression(sql) + return SQLExpression(sql) # type: ignore # noqa: PGH003 return self._with_window_function(func) From a58c583177bd456af4b29588ab43278f2129149c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Mar 2025 13:02:22 +0000 Subject: [PATCH 12/20] refactor: Add `DepthTrackingExpr` First part of https://github.com/narwhals-dev/narwhals/pull/2266#discussion_r2008942878 --- narwhals/_compliant/expr.py | 39 ++++++++++++++++++++++++-------- narwhals/_compliant/group_by.py | 25 ++++++++++++-------- narwhals/_compliant/selectors.py | 5 ---- narwhals/_dask/expr.py | 9 +++++--- narwhals/_duckdb/expr.py | 2 -- narwhals/_expression_parsing.py | 18 --------------- narwhals/_pandas_like/expr.py | 3 +-- narwhals/_spark_like/expr.py | 3 --- 
narwhals/utils.py | 8 ------- 9 files changed, 51 insertions(+), 61 deletions(-) diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index eebdfa8b6b..44c0e8ebdd 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -84,8 +84,6 @@ class CompliantExpr(Protocol38[CompliantFrameT, CompliantSeriesOrNativeExprT_co] _version: Version _evaluate_output_names: Callable[[CompliantFrameT], Sequence[str]] _alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None - _depth: int - _function_name: str _metadata: ExprMetadata | None def __call__( @@ -263,15 +261,39 @@ def broadcast( ) -> Self: ... +class DepthTrackingExpr( + CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co], + Protocol38[CompliantFrameT, CompliantSeriesOrNativeExprT_co], +): + _depth: int + _function_name: str + + def _is_elementary(self) -> bool: + """Check if expr is elementary. + + Examples: + - nw.col('a').mean() # depth 1 + - nw.mean('a') # depth 1 + - nw.len() # depth 0 + + as opposed to, say + + - nw.col('a').filter(nw.col('b')>nw.col('c')).max() + + Elementary expressions are the only ones supported properly in + pandas, PyArrow, and Dask. 
+ """ + return self._depth < 2 + + def __repr__(self) -> str: # pragma: no cover + return f"{type(self).__name__}(depth={self._depth}, function_name={self._function_name})" + + class EagerExpr( - CompliantExpr[EagerDataFrameT, EagerSeriesT], + DepthTrackingExpr[EagerDataFrameT, EagerSeriesT], Protocol38[EagerDataFrameT, EagerSeriesT], ): _call: Callable[[EagerDataFrameT], Sequence[EagerSeriesT]] - _depth: int - _function_name: str - _evaluate_output_names: Any - _alias_output_names: Any _call_kwargs: dict[str, Any] def __init__( @@ -291,9 +313,6 @@ def __init__( def __call__(self, df: EagerDataFrameT) -> Sequence[EagerSeriesT]: return self._call(df) - def __repr__(self) -> str: # pragma: no cover - return f"{type(self).__name__}(depth={self._depth}, function_name={self._function_name})" - def __narwhals_namespace__( self, ) -> EagerNamespace[EagerDataFrameT, EagerSeriesT, Self]: ... diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 56aafd3914..ebde9d5c14 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -14,17 +14,18 @@ from typing import TypeVar from narwhals._compliant.typing import CompliantDataFrameT_co -from narwhals._compliant.typing import CompliantExprAny from narwhals._compliant.typing import CompliantExprT_contra from narwhals._compliant.typing import CompliantFrameT_co from narwhals._compliant.typing import CompliantLazyFrameT_co +from narwhals._compliant.typing import EagerExprT_contra from narwhals._compliant.typing import LazyExprT_contra from narwhals._compliant.typing import NativeExprT_co -from narwhals._expression_parsing import is_elementary_expression if TYPE_CHECKING: from typing_extensions import TypeAlias + from narwhals._compliant.expr import DepthTrackingExpr + if not TYPE_CHECKING: # pragma: no cover if sys.version_info >= (3, 9): from typing import Protocol as Protocol38 @@ -49,6 +50,10 @@ NarwhalsAggregation: TypeAlias = Literal[ "sum", "mean", "median", "max", "min", 
"std", "var", "len", "n_unique", "count" ] +DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]" +DepthTrackingExprT_contra = TypeVar( + "DepthTrackingExprT_contra", bound=DepthTrackingExprAny, contravariant=True +) _RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)") @@ -75,8 +80,8 @@ def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ... class DepthTrackingGroupBy( - CompliantGroupBy[CompliantFrameT_co, CompliantExprT_contra], - Protocol38[CompliantFrameT_co, CompliantExprT_contra, NativeAggregationT_co], + CompliantGroupBy[CompliantFrameT_co, DepthTrackingExprT_contra], + Protocol38[CompliantFrameT_co, DepthTrackingExprT_contra, NativeAggregationT_co], ): """`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`.""" @@ -87,7 +92,7 @@ class DepthTrackingGroupBy( - `Dask` *may* return a `Callable` instead of a `str` referring to one. """ - def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None: + def _ensure_all_simple(self, exprs: Sequence[DepthTrackingExprT_contra]) -> None: for expr in exprs: if not self._is_simple(expr): name = self.compliant._implementation.name.lower() @@ -104,9 +109,9 @@ def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None: raise ValueError(msg) @classmethod - def _is_simple(cls, expr: CompliantExprAny, /) -> bool: + def _is_simple(cls, expr: DepthTrackingExprAny, /) -> bool: """Return `True` is we can efficiently use `expr` in a native `group_by` context.""" - return is_elementary_expression(expr) and cls._leaf_name(expr) in cls._REMAP_AGGS + return expr._is_elementary() and cls._leaf_name(expr) in cls._REMAP_AGGS @classmethod def _remap_expr_name( @@ -123,14 +128,14 @@ def _remap_expr_name( return cls._REMAP_AGGS.get(name, name) @classmethod - def _leaf_name(cls, expr: CompliantExprAny, /) -> NarwhalsAggregation | Any: + def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any: """Return the 
last function name in the chain defined by `expr`.""" return _RE_LEAF_NAME.sub("", expr._function_name) class EagerGroupBy( - DepthTrackingGroupBy[CompliantDataFrameT_co, CompliantExprT_contra, str], - Protocol38[CompliantDataFrameT_co, CompliantExprT_contra], + DepthTrackingGroupBy[CompliantDataFrameT_co, EagerExprT_contra, str], + Protocol38[CompliantDataFrameT_co, EagerExprT_contra], ): def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ... diff --git a/narwhals/_compliant/selectors.py b/narwhals/_compliant/selectors.py index 83faee137c..6d74340b4a 100644 --- a/narwhals/_compliant/selectors.py +++ b/narwhals/_compliant/selectors.py @@ -20,7 +20,6 @@ from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module from narwhals.utils import is_compliant_dataframe -from narwhals.utils import is_tracks_depth if not TYPE_CHECKING: # pragma: no cover # TODO @dangotbanned: Remove after dropping `3.8` (#2084) @@ -302,10 +301,6 @@ def names(df: FrameT) -> Sequence[str]: def __invert__(self: Self) -> CompliantSelector[FrameT, SeriesOrExprT]: return self.selectors.all() - self # type: ignore[no-any-return] - def __repr__(self: Self) -> str: # pragma: no cover - s = f"depth={self._depth}, " if is_tracks_depth(self._implementation) else "" - return f"{type(self).__name__}({s}function_name={self._function_name})" - def _eval_lhs_rhs( df: CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any], diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 466edd9044..63976df4da 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -8,6 +8,7 @@ from typing import Sequence from narwhals._compliant import LazyExpr +from narwhals._compliant.expr import DepthTrackingExpr from narwhals._dask.expr_dt import DaskExprDateTimeNamespace from narwhals._dask.expr_name import DaskExprNameNamespace from narwhals._dask.expr_str import DaskExprStringNamespace @@ -16,7 +17,6 @@ from narwhals._dask.utils import 
narwhals_to_native_dtype from narwhals._expression_parsing import ExprKind from narwhals._expression_parsing import evaluate_output_names_and_aliases -from narwhals._expression_parsing import is_elementary_expression from narwhals._pandas_like.utils import native_to_narwhals_dtype from narwhals.exceptions import ColumnNotFoundError from narwhals.exceptions import InvalidOperationError @@ -42,7 +42,10 @@ from narwhals.utils import _FullContext -class DaskExpr(LazyExpr["DaskLazyFrame", "dx.Series"]): +class DaskExpr( + LazyExpr["DaskLazyFrame", "dx.Series"], + DepthTrackingExpr["DaskLazyFrame", "dx.Series"], +): _implementation: Implementation = Implementation.DASK def __init__( @@ -573,7 +576,7 @@ def over( # which we can always easily support, as it doesn't require grouping. def func(df: DaskLazyFrame) -> Sequence[dx.Series]: return self(df.sort(*order_by, descending=False, nulls_last=False)) - elif not is_elementary_expression(self): # pragma: no cover + elif not self._is_elementary(): # pragma: no cover msg = ( "Only elementary expressions are supported for `.over` in dask.\n\n" "Please see: " diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index b6940fa90b..03a1c9c466 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -48,8 +48,6 @@ class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]): _implementation = Implementation.DUCKDB - _depth = 0 # Unused, just for compatibility with CompliantExpr - _function_name = "" # Unused, just for compatibility with CompliantExpr def __init__( self: Self, diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 0aae57d40c..abc8b7a771 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -46,24 +46,6 @@ def is_expr(obj: Any) -> TypeIs[Expr]: return isinstance(obj, Expr) -def is_elementary_expression(expr: CompliantExpr[Any, Any]) -> bool: - """Check if expr is elementary. 
- - Examples: - - nw.col('a').mean() # depth 1 - - nw.mean('a') # depth 1 - - nw.len() # depth 0 - - as opposed to, say - - - nw.col('a').filter(nw.col('b')>nw.col('c')).max() - - Elementary expressions are the only ones supported properly in - pandas, PyArrow, and Dask. - """ - return expr._depth < 2 - - def combine_evaluate_output_names( *exprs: CompliantExpr[CompliantFrameT, Any], ) -> Callable[[CompliantFrameT], Sequence[str]]: diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index b8766fec31..e878a2051a 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -8,7 +8,6 @@ from narwhals._compliant import EagerExpr from narwhals._expression_parsing import evaluate_output_names_and_aliases -from narwhals._expression_parsing import is_elementary_expression from narwhals._pandas_like.group_by import PandasLikeGroupBy from narwhals._pandas_like.series import PandasLikeSeries from narwhals.exceptions import ColumnNotFoundError @@ -214,7 +213,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: for s in results: s._scatter_in_place(sorting_indices, s) return results - elif not is_elementary_expression(self): + elif not self._is_elementary(): msg = ( "Only elementary expressions are supported for `.over` in pandas-like backends.\n\n" "Please see: " diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index e801be04a4..9daa6fcf81 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -40,9 +40,6 @@ class SparkLikeExpr(LazyExpr["SparkLikeLazyFrame", "Column"]): - _depth = 0 # Unused, just for compatibility with CompliantExpr - _function_name = "" # Unused, just for compatibility with CompliantExpr - def __init__( self: Self, call: Callable[[SparkLikeLazyFrame], Sequence[Column]], diff --git a/narwhals/utils.py b/narwhals/utils.py index 10e9e29634..5e75519ba1 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -52,7 +52,6 @@ from typing_extensions import 
LiteralString from typing_extensions import ParamSpec from typing_extensions import Self - from typing_extensions import TypeAlias from typing_extensions import TypeIs from narwhals._compliant import CompliantExpr @@ -86,8 +85,6 @@ P = ParamSpec("P") R = TypeVar("R") - _TracksDepth: TypeAlias = "Literal[Implementation.DASK,Implementation.CUDF,Implementation.MODIN,Implementation.PANDAS,Implementation.PYSPARK]" - class _SupportsVersion(Protocol): __version__: str @@ -1514,11 +1511,6 @@ def supports_arrow_c_stream(obj: Any) -> TypeIs[ArrowStreamExportable]: return _hasattr_static(obj, "__arrow_c_stream__") -def is_tracks_depth(obj: Implementation, /) -> TypeIs[_TracksDepth]: # pragma: no cover - # Return `True` for implementations that utilize `CompliantExpr._depth`. - return obj.is_pandas_like() or obj in {Implementation.PYARROW, Implementation.DASK} - - # TODO @dangotbanned: Extend with runtime behavior for `v1.*` # See `narwhals.exceptions.NarwhalsUnstableWarning` def unstable(fn: _Fn, /) -> _Fn: From 274fd6027db46dd2c7b7ac68965d0b73bf4649ee Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Mar 2025 13:15:29 +0000 Subject: [PATCH 13/20] revert: Add both classmethods back `from_column_indices` was removed without reason? --- narwhals/_compliant/expr.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 44c0e8ebdd..c5541f172e 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -93,6 +93,19 @@ def __narwhals_expr__(self) -> None: ... def __narwhals_namespace__( self, ) -> CompliantNamespace[CompliantFrameT, Self]: ... + @classmethod + def from_column_names( + cls, + evaluate_column_names: Callable[[CompliantFrameT], Sequence[str]], + /, + *, + context: _FullContext, + **kwds: Any, + ) -> Self: ... 
+ @classmethod + def from_column_indices( + cls: type[Self], *column_indices: int, context: _FullContext + ) -> Self: ... def _with_metadata(self, metadata: ExprMetadata) -> Self: ... From e333fa3c56c0e1e741c32468392829b4b793cd65 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Mar 2025 13:30:37 +0000 Subject: [PATCH 14/20] refactor: Add `DepthTrackingNamespace` Second part of https://github.com/narwhals-dev/narwhals/pull/2266#discussion_r2008942878 --- narwhals/_arrow/expr.py | 2 +- narwhals/_compliant/expr.py | 11 +++++- narwhals/_compliant/namespace.py | 57 +++++++++++++++++++++++++++++--- narwhals/_dask/expr.py | 2 +- narwhals/_dask/namespace.py | 4 +-- narwhals/_pandas_like/expr.py | 2 +- 6 files changed, 67 insertions(+), 11 deletions(-) diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index d914b6445f..af6431b496 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -60,8 +60,8 @@ def from_column_names( evaluate_column_names: Callable[[ArrowDataFrame], Sequence[str]], /, *, - function_name: str, context: _FullContext, + function_name: str = "", ) -> Self: def func(df: ArrowDataFrame) -> list[ArrowSeries]: try: diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index c5541f172e..2e54f296b8 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -100,7 +100,6 @@ def from_column_names( /, *, context: _FullContext, - **kwds: Any, ) -> Self: ... @classmethod def from_column_indices( @@ -281,6 +280,16 @@ class DepthTrackingExpr( _depth: int _function_name: str + @classmethod + def from_column_names( + cls: type[Self], + evaluate_column_names: Callable[[CompliantFrameT], Sequence[str]], + /, + *, + context: _FullContext, + function_name: str = "", + ) -> Self: ... + def _is_elementary(self) -> bool: """Check if expr is elementary. 
diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 43fcdf9090..bc50e25cd3 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -1,19 +1,27 @@ from __future__ import annotations +from functools import partial from typing import TYPE_CHECKING from typing import Any from typing import Container from typing import Iterable from typing import Literal from typing import Protocol +from typing import TypeVar from narwhals._compliant.typing import CompliantExprT from narwhals._compliant.typing import CompliantFrameT from narwhals._compliant.typing import EagerDataFrameT from narwhals._compliant.typing import EagerExprT from narwhals._compliant.typing import EagerSeriesT_co +from narwhals.utils import exclude_column_names +from narwhals.utils import get_column_names +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from narwhals._compliant.expr import DepthTrackingExpr from narwhals._compliant.selectors import CompliantSelectorNamespace from narwhals.dtypes import DType from narwhals.utils import Implementation @@ -21,16 +29,31 @@ __all__ = ["CompliantNamespace", "EagerNamespace"] +DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]" +DepthTrackingExprT = TypeVar("DepthTrackingExprT", bound=DepthTrackingExprAny) + class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): _implementation: Implementation _backend_version: tuple[int, ...] _version: Version - def all(self) -> CompliantExprT: ... - def col(self, *column_names: str) -> CompliantExprT: ... - def exclude(self, excluded_names: Container[str]) -> CompliantExprT: ... - def nth(self, *column_indices: int) -> CompliantExprT: ... 
+ def all(self) -> CompliantExprT: + return self._expr.from_column_names(get_column_names, context=self) + + def col(self, *column_names: str) -> CompliantExprT: + return self._expr.from_column_names( + passthrough_column_names(column_names), context=self + ) + + def exclude(self, excluded_names: Container[str]) -> CompliantExprT: + return self._expr.from_column_names( + partial(exclude_column_names, names=excluded_names), context=self + ) + + def nth(self, *column_indices: int) -> CompliantExprT: + return self._expr.from_column_indices(*column_indices, context=self) + def len(self) -> CompliantExprT: ... def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ... def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... @@ -58,8 +81,32 @@ def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ... def _expr(self) -> type[CompliantExprT]: ... +# Implement common to `Arrow`, `Dask`, `Pandas` +# Use directly for `DaskNamespace` +class DepthTrackingNamespace( + CompliantNamespace[CompliantFrameT, DepthTrackingExprT], + Protocol[CompliantFrameT, DepthTrackingExprT], +): + def all(self) -> DepthTrackingExprT: + return self._expr.from_column_names( + get_column_names, function_name="all", context=self + ) + + def col(self, *column_names: str) -> DepthTrackingExprT: + return self._expr.from_column_names( + passthrough_column_names(column_names), function_name="col", context=self + ) + + def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT: + return self._expr.from_column_names( + partial(exclude_column_names, names=excluded_names), + function_name="exclude", + context=self, + ) + + class EagerNamespace( - CompliantNamespace[EagerDataFrameT, EagerExprT], + DepthTrackingNamespace[EagerDataFrameT, EagerExprT], Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT], ): @property diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 63976df4da..045dd07de1 100644 --- a/narwhals/_dask/expr.py +++ 
b/narwhals/_dask/expr.py @@ -118,8 +118,8 @@ def from_column_names( evaluate_column_names: Callable[[DaskLazyFrame], Sequence[str]], /, *, - function_name: str, context: _FullContext, + function_name: str = "", ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: try: diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index da11b74055..5423b302f7 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -14,7 +14,7 @@ import dask.dataframe as dd import pandas as pd -from narwhals._compliant import CompliantNamespace +from narwhals._compliant.namespace import DepthTrackingNamespace from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.expr import DaskExpr from narwhals._dask.selectors import DaskSelectorNamespace @@ -42,7 +42,7 @@ import dask_expr as dx -class DaskNamespace(CompliantNamespace[DaskLazyFrame, "DaskExpr"]): +class DaskNamespace(DepthTrackingNamespace[DaskLazyFrame, "DaskExpr"]): _implementation: Implementation = Implementation.DASK @property diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index db4aeae557..3b1e456128 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -105,8 +105,8 @@ def from_column_names( evaluate_column_names: Callable[[PandasLikeDataFrame], Sequence[str]], /, *, - function_name: str, context: _FullContext, + function_name: str = "", ) -> Self: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: try: From 5e66ae59c710f24d03f174bd406e119c8916b588 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Mar 2025 13:36:32 +0000 Subject: [PATCH 15/20] revert: Undo copy/paste from (https://github.com/narwhals-dev/narwhals/pull/2266/commits/d06b4e544023d01f582159b4ee02ae490cb2a8d0) Third part of https://github.com/narwhals-dev/narwhals/pull/2266#discussion_r2008942878 --- narwhals/_arrow/namespace.py | 26 -------------------------- 
narwhals/_dask/namespace.py | 25 ------------------------- narwhals/_duckdb/namespace.py | 22 ---------------------- narwhals/_pandas_like/namespace.py | 26 -------------------------- narwhals/_spark_like/namespace.py | 22 ---------------------- 5 files changed, 121 deletions(-) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index ad3b7e3f73..a8b65123f1 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -1,13 +1,11 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from itertools import chain from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -29,10 +27,7 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing import Callable @@ -64,27 +59,6 @@ def __init__( self._implementation = Implementation.PYARROW self._version = version - # --- selection --- - def all(self) -> ArrowExpr: - return self._expr.from_column_names( - get_column_names, function_name="all", context=self - ) - - def col(self, *column_names: str) -> ArrowExpr: - return self._expr.from_column_names( - passthrough_column_names(column_names), function_name="col", context=self - ) - - def exclude(self, excluded_names: Container[str]) -> ArrowExpr: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - context=self, - ) - - def nth(self, *column_indices: int) -> ArrowExpr: - return self._expr.from_column_indices(*column_indices, 
context=self) - def len(self: Self) -> ArrowExpr: # coverage bug? this is definitely hit return self._expr( # pragma: no cover diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 5423b302f7..c7da90d356 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -1,12 +1,10 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -26,9 +24,6 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing_extensions import Self @@ -59,26 +54,6 @@ def __init__( self._backend_version = backend_version self._version = version - def all(self) -> DaskExpr: - return self._expr.from_column_names( - get_column_names, function_name="all", context=self - ) - - def col(self, *column_names: str) -> DaskExpr: - return self._expr.from_column_names( - passthrough_column_names(column_names), function_name="col", context=self - ) - - def exclude(self, excluded_names: Container[str]) -> DaskExpr: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - context=self, - ) - - def nth(self, *column_indices: int) -> DaskExpr: - return self._expr.from_column_indices(*column_indices, context=self) - def lit(self: Self, value: Any, dtype: DType | None) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: if dtype is not None: diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index e6a54826cd..e2f6e6790d 
100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -1,12 +1,10 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -26,9 +24,6 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: import duckdb @@ -56,23 +51,6 @@ def selectors(self: Self) -> DuckDBSelectorNamespace: def _expr(self) -> type[DuckDBExpr]: return DuckDBExpr - def col(self, *column_names: str) -> DuckDBExpr: - return self._expr.from_column_names( - passthrough_column_names(column_names), context=self - ) - - def all(self) -> DuckDBExpr: - return self._expr.from_column_names(get_column_names, context=self) - - def exclude(self, excluded_names: Container[str]) -> DuckDBExpr: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - context=self, - ) - - def nth(self, *column_indices: int) -> DuckDBExpr: - return self._expr.from_column_indices(*column_indices, context=self) - def concat( self: Self, items: Iterable[DuckDBLazyFrame], diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 0c6aacb5d8..33095f040a 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -1,12 +1,10 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import 
Iterable from typing import Literal from typing import Sequence @@ -23,10 +21,7 @@ from narwhals._pandas_like.utils import extract_dataframe_comparand from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import vertical_concat -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing_extensions import Self @@ -65,27 +60,6 @@ def __init__( self._backend_version = backend_version self._version = version - # --- selection --- - def all(self) -> PandasLikeExpr: - return self._expr.from_column_names( - get_column_names, function_name="all", context=self - ) - - def col(self, *column_names: str) -> PandasLikeExpr: - return self._expr.from_column_names( - passthrough_column_names(column_names), function_name="col", context=self - ) - - def exclude(self, excluded_names: Container[str]) -> PandasLikeExpr: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - context=self, - ) - - def nth(self, *column_indices: int) -> PandasLikeExpr: - return self._expr.from_column_indices(*column_indices, context=self) - def lit(self: Self, value: Any, dtype: DType | None) -> PandasLikeExpr: def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: pandas_series = self._series.from_iterable( diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 1c691c925e..243c55954d 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -1,12 +1,10 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ 
-19,9 +17,6 @@ from narwhals._spark_like.selectors import SparkLikeSelectorNamespace from narwhals._spark_like.utils import maybe_evaluate_expr from narwhals._spark_like.utils import narwhals_to_native_dtype -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from sqlframe.base.column import Column @@ -52,23 +47,6 @@ def selectors(self: Self) -> SparkLikeSelectorNamespace: def _expr(self) -> type[SparkLikeExpr]: return SparkLikeExpr - def col(self, *column_names: str) -> SparkLikeExpr: - return self._expr.from_column_names( - passthrough_column_names(column_names), context=self - ) - - def all(self) -> SparkLikeExpr: - return self._expr.from_column_names(get_column_names, context=self) - - def exclude(self, excluded_names: Container[str]) -> SparkLikeExpr: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - context=self, - ) - - def nth(self, *column_indices: int) -> SparkLikeExpr: - return self._expr.from_column_indices(*column_indices, context=self) - def lit(self: Self, value: object, dtype: DType | None) -> SparkLikeExpr: def _lit(df: SparkLikeLazyFrame) -> list[Column]: column = df._F.lit(value) From 94cc379edf25f749693306ac0a126517f61bb9ad Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Mar 2025 13:40:56 +0000 Subject: [PATCH 16/20] ci(typing): ignore, unused ignore - https://github.com/narwhals-dev/narwhals/pull/2266/commits/2d4dd8a980ad9cf22ef7ab02047df65aa15e4243 - https://github.com/narwhals-dev/narwhals/actions/runs/14019249312/job/39248935271?pr=2266 --- narwhals/_duckdb/expr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 03a1c9c466..15279c5b79 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -43,7 +43,7 @@ from 
narwhals.utils import _FullContext with contextlib.suppress(ImportError): # requires duckdb>=1.3.0 - from duckdb import SQLExpression # type: ignore # noqa: PGH003 + from duckdb import SQLExpression # type: ignore[attr-defined, unused-ignore] class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]): @@ -457,7 +457,7 @@ def func( order_by_sql = generate_order_by_sql(*order_by, ascending=not reverse) partition_by_sql = generate_partition_by_sql(*partition_by) sql = f"sum ({_input}) over ({partition_by_sql} {order_by_sql} rows between unbounded preceding and current row)" - return SQLExpression(sql) # type: ignore # noqa: PGH003 + return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore] return self._with_window_function(func) @@ -480,7 +480,7 @@ def func( partition_by_sql = generate_partition_by_sql(*partition_by) window = f"({partition_by_sql} {order_by_sql} rows between {start} and {end})" sql = f"case when count({_input}) over {window} >= {min_samples} then sum({_input}) over {window} else null end" - return SQLExpression(sql) # type: ignore # noqa: PGH003 + return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore] return self._with_window_function(func) From b5b1c92dd8d35a4371300893537ceaed929eaa67 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Mar 2025 13:46:39 +0000 Subject: [PATCH 17/20] test: Kinda unbreak `sqlframe` test Was broken for me locally following https://github.com/narwhals-dev/narwhals/pull/2263#discussion_r2009101659 --- tests/expr_and_series/str/to_datetime_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/expr_and_series/str/to_datetime_test.py b/tests/expr_and_series/str/to_datetime_test.py index 412485d01f..e8a969df2d 100644 --- a/tests/expr_and_series/str/to_datetime_test.py +++ b/tests/expr_and_series/str/to_datetime_test.py @@ -14,6 +14,7 @@ from tests.utils import PYARROW_VERSION from tests.utils 
import assert_equal_data from tests.utils import is_pyarrow_windows_no_tzdata +from tests.utils import is_windows if TYPE_CHECKING: from tests.utils import Constructor @@ -214,7 +215,10 @@ def test_to_datetime_tz_aware( if "pandas" in str(constructor) and PANDAS_VERSION < (1,): # "Cannot pass a tz argument when parsing strings with timezone information." pytest.skip() - if is_pyarrow_windows_no_tzdata(constructor): + if is_pyarrow_windows_no_tzdata(constructor) or ( + "sqlframe" in str(constructor) and format is not None and is_windows() + ): + # NOTE: For `sqlframe` see https://github.com/narwhals-dev/narwhals/pull/2263#discussion_r2009101659 pytest.skip() if "cudf" in str(constructor): # cuDF does not yet support timezone-aware datetimes From 8fa9f1cedb58b7ff26271294ed3e941fc21cb0a4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Mar 2025 15:25:51 +0000 Subject: [PATCH 18/20] refactor: Hide `CompliantExpr` internals from `LazyGroupBy` https://github.com/narwhals-dev/narwhals/pull/2266#discussion_r2009144087 --- narwhals/_compliant/expr.py | 13 +++++++++++++ narwhals/_compliant/group_by.py | 6 +----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 2e54f296b8..918e9b744a 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -271,6 +271,19 @@ def __invert__(self) -> Self: ... def broadcast( self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL] ) -> Self: ... + def _is_multi_output_unnamed(self) -> bool: + """Return `True` for multi-output aggregations without names. + + For example, column `'a'` only appears in the output as a grouping key: + + df.group_by('a').agg(nw.all().sum()) + + It does not get included in: + + nw.all().sum(). 
+ """ + assert self._metadata is not None # noqa: S101 + return self._metadata.expansion_kind.is_multi_unnamed() class DepthTrackingExpr( diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index ebde9d5c14..30d4656a1e 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -152,11 +152,7 @@ def _evaluate_expr(self, expr: LazyExprT_contra, /) -> Iterator[NativeExprT_co]: else output_names ) native_exprs = expr(self.compliant) - assert expr._metadata is not None # noqa: S101 - if expr._metadata.expansion_kind.is_multi_unnamed(): - # Exclude keys from expansion. For example, in - # `df.group_by('a').agg(nw.all().sum())`, column 'a' only appears in the - # output as a grouping key - is does not get included in `nw.all().sum()`. + if expr._is_multi_output_unnamed(): for native_expr, name, alias in zip(native_exprs, output_names, aliases): if name not in self._keys: yield native_expr.alias(alias) From 3129332fd260e62dcc99c086053947c0ce2d24db Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Mar 2025 15:38:40 +0000 Subject: [PATCH 19/20] Remove comment Forgot to remove this --- narwhals/_compliant/namespace.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index bc50e25cd3..a1d5065880 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -81,8 +81,6 @@ def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ... def _expr(self) -> type[CompliantExprT]: ... 
-# Implement common to `Arrow`, `Dask`, `Pandas` -# Use directly for `DaskNamespace` class DepthTrackingNamespace( CompliantNamespace[CompliantFrameT, DepthTrackingExprT], Protocol[CompliantFrameT, DepthTrackingExprT], From 3cdc967d46dcb45369cfb79f9eac7a49fd455e22 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 23 Mar 2025 17:08:44 +0000 Subject: [PATCH 20/20] redo (94cc379edf25f749693306ac0a126517f61bb9ad) --- narwhals/_duckdb/expr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 85256c5c04..a628090ddd 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -455,7 +455,7 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression: sql = ( f"lag({window_inputs.expr}, {n}) over ({partition_by_sql} {order_by_sql})" ) - return SQLExpression(sql) + return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore] return self._with_window_function(func) @@ -470,7 +470,7 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression: else: partition_by_sql = f"partition by {window_inputs.expr}" sql = f"row_number() over({partition_by_sql} {order_by_sql}) == 1" - return SQLExpression(sql) + return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore] return self._with_window_function(func) @@ -485,7 +485,7 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression: else: partition_by_sql = f"partition by {window_inputs.expr}" sql = f"row_number() over({partition_by_sql} {order_by_sql}) == 1" - return SQLExpression(sql) + return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore] return self._with_window_function(func) @@ -494,7 +494,7 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression: order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True) partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by) sql = f"lag({window_inputs.expr}) over 
({partition_by_sql} {order_by_sql})" - return window_inputs.expr - SQLExpression(sql) + return window_inputs.expr - SQLExpression(sql) # type: ignore[no-any-return, unused-ignore] return self._with_window_function(func)