From e6a8d72decc57c0f324f249707f25bb341b4be13 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:57:29 +0000 Subject: [PATCH 01/36] refactor: Add `expressions` subpackage --- narwhals/_plan/expressions/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 narwhals/_plan/expressions/__init__.py diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From 8ad4aeda3ddec8f260ba90b6f50b7b6e78d864e2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:57:49 +0000 Subject: [PATCH 02/36] refactor: Move `aggregation.py` --- narwhals/_plan/arrow/expr.py | 26 +++++++++---------- narwhals/_plan/dummy.py | 2 +- narwhals/_plan/exceptions.py | 2 +- narwhals/_plan/expr.py | 2 +- .../_plan/{ => expressions}/aggregation.py | 0 narwhals/_plan/protocols.py | 3 ++- 6 files changed, 18 insertions(+), 17 deletions(-) rename narwhals/_plan/{ => expressions}/aggregation.py (100%) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index d8a163d120..f47e01b56f 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -27,7 +27,19 @@ from narwhals._arrow.typing import ChunkedArrayAny, Incomplete from narwhals._plan import boolean, expr - from narwhals._plan.aggregation import ( + from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame + from narwhals._plan.arrow.namespace import ArrowNamespace + from narwhals._plan.boolean import All, IsBetween, IsFinite, IsNan, IsNull, Not + from narwhals._plan.expr import ( + AnonymousExpr, + BinaryExpr, + FunctionExpr, + OrderedWindowExpr, + RollingExpr, + TernaryExpr, + WindowExpr, + ) + from narwhals._plan.expressions.aggregation import ( ArgMax, ArgMin, Count, @@ -43,18 +55,6 @@ Sum, Var, ) - from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame - from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.boolean import All, IsBetween, IsFinite, IsNan, IsNull, Not - from narwhals._plan.expr import ( - AnonymousExpr, - BinaryExpr, - FunctionExpr, - OrderedWindowExpr, - RollingExpr, - TernaryExpr, - WindowExpr, - ) from narwhals._plan.functions import FillNull, Pow from narwhals.typing import Into1DArray, IntoDType, PythonLiteral diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 0a1e469917..961b0fe7b5 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload from narwhals._plan import ( - aggregation as agg, boolean, expr, expr_expansion, @@ -18,6 +17,7 @@ from narwhals._plan._guards import is_column, is_expr, is_series from narwhals._plan.common import into_dtype from narwhals._plan.contexts import ExprContext +from narwhals._plan.expressions import aggregation as agg from narwhals._plan.options import ( EWMOptions, RankOptions, diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 127471720d..d85513fe8a 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -24,9 +24,9 @@ import pandas as pd import polars as pl - from narwhals._plan.aggregation import AggExpr from narwhals._plan.common import ExprIR, Function from narwhals._plan.expr import FunctionExpr, WindowExpr + from narwhals._plan.expressions.aggregation import AggExpr from narwhals._plan.operators import Operator from narwhals._plan.options import SortOptions from narwhals._plan.typing import IntoExpr, Seq diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index ecf4efe136..45a1958123 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -6,9 +6,9 @@ # - Literal import typing as t -from narwhals._plan.aggregation import AggExpr, OrderableAggExpr from narwhals._plan.common import ExprIR, SelectorIR, collect from narwhals._plan.exceptions import function_expr_invalid_operation_error +from narwhals._plan.expressions.aggregation import AggExpr, OrderableAggExpr from narwhals._plan.name import KeepName, RenameAlias from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ( diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/expressions/aggregation.py similarity index 100% rename from narwhals/_plan/aggregation.py rename to narwhals/_plan/expressions/aggregation.py diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 951dfa850e..42c38fb7b8 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -11,10 +11,11 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs - from narwhals._plan import aggregation as agg, boolean, expr, functions as F + from narwhals._plan import boolean, expr, functions as F from narwhals._plan.boolean import IsBetween, IsFinite, IsNan, IsNull, Not from narwhals._plan.dummy import BaseFrame, DataFrame, Series from narwhals._plan.expr import BinaryExpr, FunctionExpr, RangeExpr + from narwhals._plan.expressions import aggregation as agg from narwhals._plan.options import SortMultipleOptions from narwhals._plan.ranges import IntRange from narwhals._plan.strings import ConcatStr From 1c62a3bdd29d59f34348384ff9cb78e8ac505970 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:59:08 +0000 Subject: [PATCH 03/36] refactor: Move `name.py` --- narwhals/_plan/dummy.py | 4 ++-- narwhals/_plan/expr.py | 2 +- narwhals/_plan/{ => expressions}/name.py | 0 3 files changed, 3 insertions(+), 3 deletions(-) rename narwhals/_plan/{ => expressions}/name.py (100%) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 961b0fe7b5..18c88ad96a 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -39,9 +39,9 @@ from narwhals._plan.categorical import ExprCatNamespace from narwhals._plan.common import ExprIR, Function, NamedIR + from narwhals._plan.expressions.name import ExprNameNamespace from narwhals._plan.lists import ExprListNamespace from narwhals._plan.meta import IRMetaNamespace - from narwhals._plan.name import ExprNameNamespace from narwhals._plan.protocols import ( CompliantBaseFrame, CompliantDataFrame, @@ -574,7 +574,7 @@ def name(self) -> ExprNameNamespace: >>> str(renamed._ir) "RenameAlias(expr=Columns(names=[a, b]), function=Suffix(suffix='_changed'))" """ - from narwhals._plan.name import ExprNameNamespace + from narwhals._plan.expressions.name import ExprNameNamespace return ExprNameNamespace(_expr=self) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 45a1958123..5c513644fc 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -9,7 +9,7 @@ from narwhals._plan.common import ExprIR, SelectorIR, collect from narwhals._plan.exceptions import function_expr_invalid_operation_error from narwhals._plan.expressions.aggregation import AggExpr, OrderableAggExpr -from narwhals._plan.name import KeepName, RenameAlias +from narwhals._plan.expressions.name import KeepName, RenameAlias from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ( FunctionT_co, diff --git a/narwhals/_plan/name.py b/narwhals/_plan/expressions/name.py similarity index 100% rename from narwhals/_plan/name.py rename to narwhals/_plan/expressions/name.py From 36ca08820e48bb8603133acebfdfc8f72d341248 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:03:25 +0000 Subject: [PATCH 04/36] refactor: Move `expr.py` --- narwhals/_plan/_guards.py | 4 ++-- narwhals/_plan/arrow/expr.py | 21 +++++++++++---------- narwhals/_plan/arrow/namespace.py | 5 +++-- narwhals/_plan/boolean.py | 2 +- narwhals/_plan/common.py | 12 ++++++------ narwhals/_plan/demo.py | 5 +++-- narwhals/_plan/dummy.py | 3 +-- narwhals/_plan/exceptions.py | 2 +- narwhals/_plan/expr_expansion.py | 2 +- narwhals/_plan/{ => expressions}/expr.py | 0 narwhals/_plan/functions.py | 6 +++--- narwhals/_plan/literal.py | 4 ++-- narwhals/_plan/meta.py | 18 +++++++++--------- narwhals/_plan/operators.py | 6 +++--- narwhals/_plan/protocols.py | 6 +++--- narwhals/_plan/ranges.py | 4 ++-- narwhals/_plan/selectors.py | 4 ++-- narwhals/_plan/when_then.py | 4 ++-- narwhals/_plan/window.py | 6 +++--- tests/plan/expr_expansion_test.py | 2 +- tests/plan/expr_parsing_test.py | 5 +++-- tests/plan/expr_rewrites_test.py | 2 +- 22 files changed, 63 insertions(+), 60 deletions(-) rename narwhals/_plan/{ => expressions}/expr.py (100%) diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index 867d16d397..bfdb762ebd 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -11,8 +11,8 @@ if TYPE_CHECKING: from typing_extensions import TypeIs - from narwhals._plan import expr from narwhals._plan.dummy import Expr, Series + from narwhals._plan.expressions import expr from narwhals._plan.protocols import CompliantSeries from narwhals._plan.typing import NativeSeriesT, Seq from narwhals.typing import NonNestedLiteral @@ -38,7 +38,7 @@ def _dummy(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 def _expr(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 - from narwhals._plan import expr + from narwhals._plan.expressions import expr return expr diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index f47e01b56f..fec13b488e 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -26,19 +26,11 @@ from typing_extensions import Self, TypeAlias from narwhals._arrow.typing import ChunkedArrayAny, Incomplete - from narwhals._plan import boolean, expr + from narwhals._plan import boolean from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.boolean import All, IsBetween, IsFinite, IsNan, IsNull, Not - from narwhals._plan.expr import ( - AnonymousExpr, - BinaryExpr, - FunctionExpr, - OrderedWindowExpr, - RollingExpr, - TernaryExpr, - WindowExpr, - ) + from narwhals._plan.expressions import expr from narwhals._plan.expressions.aggregation import ( ArgMax, ArgMin, @@ -55,6 +47,15 @@ Sum, Var, ) + from narwhals._plan.expressions.expr import ( + AnonymousExpr, + BinaryExpr, + FunctionExpr, + OrderedWindowExpr, + RollingExpr, + TernaryExpr, + WindowExpr, + ) from narwhals._plan.functions import FillNull, Pow from narwhals.typing import Into1DArray, IntoDType, PythonLiteral diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index f7bfaaa330..2b6e0ca778 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -19,13 +19,14 @@ from collections.abc import Callable, Iterable, Iterator, Sequence from narwhals._arrow.typing import ChunkedArrayAny - from narwhals._plan import expr, functions as F + from narwhals._plan import functions as F from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.boolean import AllHorizontal, AnyHorizontal from narwhals._plan.dummy import Series as NwSeries - from narwhals._plan.expr import FunctionExpr, RangeExpr + from narwhals._plan.expressions import expr + from narwhals._plan.expressions.expr import FunctionExpr, RangeExpr from narwhals._plan.ranges import IntRange from narwhals._plan.strings import ConcatStr from narwhals.typing import ConcatMethod, NonNestedLiteral, PythonLiteral diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index 23f7d27dd3..fc9d863fdc 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -13,7 +13,7 @@ from narwhals._plan.common import ExprIR from narwhals._plan.dummy import Series - from narwhals._plan.expr import FunctionExpr, Literal # noqa: F401 + from narwhals._plan.expressions.expr import FunctionExpr, Literal # noqa: F401 from narwhals._plan.typing import NativeSeriesT, Seq # noqa: F401 from narwhals.typing import ClosedInterval diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index f73f21b26b..3364de4e80 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -34,7 +34,7 @@ from typing_extensions import Self, TypeAlias from narwhals._plan.dummy import Expr, Selector - from narwhals._plan.expr import Alias, Cast, Column, FunctionExpr + from narwhals._plan.expressions.expr import Alias, Cast, Column, FunctionExpr from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.protocols import Ctx, FrameT_contra, R_co from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -266,12 +266,12 @@ def meta(self) -> IRMetaNamespace: return IRMetaNamespace(_ir=self) def cast(self, dtype: DType) -> Cast: - from narwhals._plan.expr import Cast + from narwhals._plan.expressions.expr import Cast return Cast(expr=self, dtype=dtype) def alias(self, name: str) -> Alias: - from narwhals._plan.expr import Alias + from narwhals._plan.expressions.expr import Alias return Alias(expr=self, name=name) @@ -319,7 +319,7 @@ def from_name(name: str, /) -> NamedIR[Column]: Intended to be used in `with_columns` from a `FrozenSchema`'s keys. """ - from narwhals._plan.expr import col + from narwhals._plan.expressions.expr import col return NamedIR(expr=col(name), name=name) @@ -353,7 +353,7 @@ def is_elementwise_top_level(self) -> bool: [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/properties.rs#L16-L44 """ - from narwhals._plan import expr + from narwhals._plan.expressions import expr ir = self.expr if is_function_expr(ir): @@ -414,7 +414,7 @@ def is_scalar(self) -> bool: return self.function_options.returns_scalar() def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: - from narwhals._plan.expr import FunctionExpr + from narwhals._plan.expressions.expr import FunctionExpr return FunctionExpr(input=inputs, function=self, options=self.function_options) diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index ab89b97c96..d60d7a9651 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -3,9 +3,10 @@ import builtins import typing as t -from narwhals._plan import _guards, boolean, expr, expr_parsing as parse, functions as F +from narwhals._plan import _guards, boolean, expr_parsing as parse, functions as F from narwhals._plan.common import into_dtype, py_to_narwhals_dtype -from narwhals._plan.expr import All, Len +from narwhals._plan.expressions import expr +from narwhals._plan.expressions.expr import All, Len from narwhals._plan.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.ranges import IntRange from narwhals._plan.strings import ConcatStr diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 18c88ad96a..f3ff304c42 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -8,7 +8,6 @@ from narwhals._plan import ( boolean, - expr, expr_expansion, expr_parsing as parse, functions as F, @@ -17,7 +16,7 @@ from narwhals._plan._guards import is_column, is_expr, is_series from narwhals._plan.common import into_dtype from narwhals._plan.contexts import ExprContext -from narwhals._plan.expressions import aggregation as agg +from narwhals._plan.expressions import aggregation as agg, expr from narwhals._plan.options import ( EWMOptions, RankOptions, diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index d85513fe8a..0507cb51c5 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -25,8 +25,8 @@ import polars as pl from narwhals._plan.common import ExprIR, Function - from narwhals._plan.expr import FunctionExpr, WindowExpr from narwhals._plan.expressions.aggregation import AggExpr + from narwhals._plan.expressions.expr import FunctionExpr, WindowExpr from narwhals._plan.operators import Operator from narwhals._plan.options import SortOptions from narwhals._plan.typing import IntoExpr, Seq diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 2fdf92ef7d..0eaafca10b 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -51,7 +51,7 @@ column_not_found_error, duplicate_error, ) -from narwhals._plan.expr import ( +from narwhals._plan.expressions.expr import ( Alias, All, Columns, diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expressions/expr.py similarity index 100% rename from narwhals/_plan/expr.py rename to narwhals/_plan/expressions/expr.py diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 570d75b4d0..b936b602ce 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -14,7 +14,7 @@ from typing_extensions import Self from narwhals._plan.common import ExprIR - from narwhals._plan.expr import AnonymousExpr, FunctionExpr, RollingExpr + from narwhals._plan.expressions.expr import AnonymousExpr, FunctionExpr, RollingExpr from narwhals._plan.options import EWMOptions, RankOptions, RollingOptionsFixedWindow from narwhals._plan.typing import Seq, Udf from narwhals.dtypes import DType @@ -31,7 +31,7 @@ class RollingWindow(Function, options=FunctionOptions.length_preserving): options: RollingOptionsFixedWindow def to_function_expr(self, *inputs: ExprIR) -> RollingExpr[Self]: - from narwhals._plan.expr import RollingExpr + from narwhals._plan.expressions.expr import RollingExpr options = self.function_options return RollingExpr(input=inputs, function=self, options=options) @@ -176,7 +176,7 @@ def function_options(self) -> FunctionOptions: return options def to_function_expr(self, *inputs: ExprIR) -> AnonymousExpr: - from narwhals._plan.expr import AnonymousExpr + from narwhals._plan.expressions.expr import AnonymousExpr options = self.function_options return AnonymousExpr(input=inputs, function=self, options=options) diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index 22c170508c..9032d49cd4 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -10,7 +10,7 @@ from typing_extensions import TypeIs from narwhals._plan.dummy import Series - from narwhals._plan.expr import Literal + from narwhals._plan.expressions.expr import Literal from narwhals.dtypes import DType @@ -30,7 +30,7 @@ def is_scalar(self) -> bool: return False def to_literal(self) -> Literal[LiteralT]: - from narwhals._plan.expr import Literal + from narwhals._plan.expressions.expr import Literal return Literal(value=self) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index ce78165c00..d20529cfc1 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -20,7 +20,7 @@ from typing_extensions import TypeIs from narwhals._plan.common import ExprIR - from narwhals._plan.expr import Column + from narwhals._plan.expressions.expr import Column class IRMetaNamespace(IRNamespace): @@ -86,7 +86,7 @@ def _expr_to_leaf_column_names_iter(ir: ExprIR) -> Iterator[str]: def _expr_to_leaf_column_exprs_iter(ir: ExprIR) -> Iterator[ExprIR]: - from narwhals._plan import expr + from narwhals._plan.expressions import expr for outer in ir.iter_root_names(): if isinstance(outer, (expr.Column, expr.All)): @@ -102,7 +102,7 @@ def _expr_to_leaf_column_name(ir: ExprIR) -> str | ComputeError: msg = "no root column name found" return ComputeError(msg) leaf = leaves[0] - from narwhals._plan import expr + from narwhals._plan.expressions import expr if isinstance(leaf, expr.Column): return leaf.name @@ -119,7 +119,7 @@ def root_names_unique(irs: Iterable[ExprIR], /) -> set[str]: @lru_cache(maxsize=32) def _expr_output_name(ir: ExprIR) -> str | ComputeError: - from narwhals._plan import expr + from narwhals._plan.expressions import expr for e in ir.iter_output_name(): if isinstance(e, (expr.Column, expr.Alias, expr.Literal, expr.Len)): @@ -144,7 +144,7 @@ def get_single_leaf_name(ir: ExprIR) -> str | ComputeError: [`polars_plan::utils::get_single_leaf`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/utils.rs#L151-L168 """ - from narwhals._plan import expr + from narwhals._plan.expressions import expr for e in ir.iter_right(): if isinstance(e, (expr.WindowExpr, expr.SortBy, expr.Filter)): @@ -159,7 +159,7 @@ def get_single_leaf_name(ir: ExprIR) -> str | ComputeError: def _has_multiple_outputs(ir: ExprIR) -> bool: - from narwhals._plan import expr + from narwhals._plan.expressions import expr return isinstance(ir, (expr.Columns, expr.IndexColumns, expr.SelectorIR, expr.All)) @@ -175,13 +175,13 @@ def has_expr_ir(ir: ExprIR, *matches: type[ExprIR]) -> bool: def is_column(ir: ExprIR) -> TypeIs[Column]: - from narwhals._plan.expr import Column + from narwhals._plan.expressions.expr import Column return isinstance(ir, Column) def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: - from narwhals._plan import expr + from narwhals._plan.expressions import expr from narwhals._plan.literal import is_literal_scalar return ( @@ -196,7 +196,7 @@ def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: def _is_column_selection(ir: ExprIR, *, allow_aliasing: bool) -> bool: - from narwhals._plan import expr + from narwhals._plan.expressions import expr return isinstance(ir, (expr.Column, expr._ColumnSelection, expr.SelectorIR)) or ( allow_aliasing and isinstance(ir, (expr.Alias, expr.KeepName, expr.RenameAlias)) diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index 78b33b042f..b8e9fc2b65 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -17,7 +17,7 @@ from typing_extensions import Self from narwhals._plan.common import ExprIR - from narwhals._plan.expr import BinaryExpr, BinarySelector + from narwhals._plan.expressions.expr import BinaryExpr, BinarySelector from narwhals._plan.typing import ( LeftSelectorT, LeftT, @@ -45,7 +45,7 @@ def __init_subclass__( def to_binary_expr( self, left: LeftT, right: RightT, / ) -> BinaryExpr[LeftT, Self, RightT]: - from narwhals._plan.expr import BinaryExpr + from narwhals._plan.expressions.expr import BinaryExpr if right.meta.has_multiple_outputs(): raise binary_expr_multi_output_error(left, self, right) @@ -74,7 +74,7 @@ class SelectorOperator(Operator, func=None): def to_binary_selector( self, left: LeftSelectorT, right: RightSelectorT, / ) -> BinarySelector[LeftSelectorT, Self, RightSelectorT]: - from narwhals._plan.expr import BinarySelector + from narwhals._plan.expressions.expr import BinarySelector return BinarySelector(left=left, op=self, right=right) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 42c38fb7b8..8f1fdcd9a6 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -11,11 +11,11 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs - from narwhals._plan import boolean, expr, functions as F + from narwhals._plan import boolean, functions as F from narwhals._plan.boolean import IsBetween, IsFinite, IsNan, IsNull, Not from narwhals._plan.dummy import BaseFrame, DataFrame, Series - from narwhals._plan.expr import BinaryExpr, FunctionExpr, RangeExpr - from narwhals._plan.expressions import aggregation as agg + from narwhals._plan.expressions import aggregation as agg, expr + from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr, RangeExpr from narwhals._plan.options import SortMultipleOptions from narwhals._plan.ranges import IntRange from narwhals._plan.strings import ConcatStr diff --git a/narwhals/_plan/ranges.py b/narwhals/_plan/ranges.py index 4f8e49b531..7dc0faa42e 100644 --- a/narwhals/_plan/ranges.py +++ b/narwhals/_plan/ranges.py @@ -8,13 +8,13 @@ if TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.expr import RangeExpr + from narwhals._plan.expressions.expr import RangeExpr from narwhals.dtypes import IntegerType class RangeFunction(Function, config=FEOptions.namespaced()): def to_function_expr(self, *inputs: ExprIR) -> RangeExpr[Self]: - from narwhals._plan.expr import RangeExpr + from narwhals._plan.expressions.expr import RangeExpr return RangeExpr(input=inputs, function=self, options=self.function_options) diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index 4aa5f58a3d..8dc44d70f6 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -19,7 +19,7 @@ from typing import TypeVar from narwhals._plan import dummy - from narwhals._plan.expr import RootSelector + from narwhals._plan.expressions.expr import RootSelector from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType from narwhals.typing import TimeUnit @@ -31,7 +31,7 @@ class Selector(Immutable): def to_selector(self) -> RootSelector: - from narwhals._plan.expr import RootSelector + from narwhals._plan.expressions.expr import RootSelector return RootSelector(selector=self) diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index 62e0da3d2a..b32f9c4f77 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from narwhals._plan.common import ExprIR - from narwhals._plan.expr import TernaryExpr + from narwhals._plan.expressions.expr import TernaryExpr from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq @@ -114,6 +114,6 @@ def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> TernaryExpr: - from narwhals._plan.expr import TernaryExpr + from narwhals._plan.expressions.expr import TernaryExpr return TernaryExpr(predicate=predicate, truthy=truthy, falsy=falsy) diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index fd27743948..329b36078c 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from narwhals._plan.common import ExprIR - from narwhals._plan.expr import OrderedWindowExpr, WindowExpr + from narwhals._plan.expressions.expr import OrderedWindowExpr, WindowExpr from narwhals._plan.options import SortOptions from narwhals._plan.typing import Seq from narwhals.exceptions import InvalidOperationError @@ -43,7 +43,7 @@ def _validate_over( return None def to_window_expr(self, expr: ExprIR, partition_by: Seq[ExprIR], /) -> WindowExpr: - from narwhals._plan.expr import WindowExpr + from narwhals._plan.expressions.expr import WindowExpr if err := self._validate_over(expr, partition_by): raise err @@ -57,7 +57,7 @@ def to_ordered_window_expr( sort_options: SortOptions, /, ) -> OrderedWindowExpr: - from narwhals._plan.expr import OrderedWindowExpr + from narwhals._plan.expressions.expr import OrderedWindowExpr if err := self._validate_over(expr, partition_by, order_by, sort_options): raise err diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 8396db6c74..ba45c5a7b1 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -7,13 +7,13 @@ import narwhals as nw from narwhals._plan import demo as nwd, selectors as ndcs -from narwhals._plan.expr import Alias, Columns from narwhals._plan.expr_expansion import ( prepare_projection, replace_selector, rewrite_special_aliases, ) from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir +from narwhals._plan.expressions.expr import Alias, Columns from narwhals._plan.schema import freeze_schema from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError from tests.plan.utils import assert_expr_ir_equal diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 52068da571..2e3a69df90 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -11,11 +11,12 @@ import narwhals as nw import narwhals._plan.demo as nwd -from narwhals._plan import boolean, expr, functions as F, operators as ops +from narwhals._plan import boolean, functions as F, operators as ops from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import Expr, Series -from narwhals._plan.expr import BinaryExpr, FunctionExpr, RangeExpr from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir +from narwhals._plan.expressions import expr +from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr, RangeExpr from narwhals._plan.literal import SeriesLiteral from narwhals.exceptions import ( InvalidIntoExprError, diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index 740d966818..964d03aeac 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -8,12 +8,12 @@ from narwhals._plan import demo as nwd, expr_parsing as parse, selectors as ndcs from narwhals._plan._guards import is_expr from narwhals._plan.common import ExprIR, NamedIR -from narwhals._plan.expr import WindowExpr from narwhals._plan.expr_rewrites import ( rewrite_all, rewrite_binary_agg_over, rewrite_elementwise_over, ) +from narwhals._plan.expressions.expr import WindowExpr from narwhals._plan.window import Over from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_expr_ir_equal From 1241580489ef2d385311c98181555ccc45c6c8f2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:09:42 +0000 Subject: [PATCH 05/36] refactor: Move `boolean`, `functions` --- narwhals/_plan/arrow/expr.py | 14 ++++++++++---- narwhals/_plan/arrow/namespace.py | 5 ++--- narwhals/_plan/demo.py | 4 ++-- narwhals/_plan/dummy.py | 10 ++-------- narwhals/_plan/expr_parsing.py | 2 +- narwhals/_plan/{ => expressions}/boolean.py | 0 narwhals/_plan/expressions/expr.py | 2 +- narwhals/_plan/{ => expressions}/functions.py | 0 narwhals/_plan/protocols.py | 10 +++++++--- narwhals/_plan/typing.py | 2 +- tests/plan/expr_parsing_test.py | 4 ++-- 11 files changed, 28 insertions(+), 25 deletions(-) rename narwhals/_plan/{ => expressions}/boolean.py (100%) rename narwhals/_plan/{ => expressions}/functions.py (100%) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index fec13b488e..bb715ea7c5 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -26,11 +26,9 @@ from typing_extensions import Self, TypeAlias from narwhals._arrow.typing import ChunkedArrayAny, Incomplete - from narwhals._plan import boolean from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.boolean import All, IsBetween, IsFinite, IsNan, IsNull, Not - from narwhals._plan.expressions import expr + from narwhals._plan.expressions import boolean, expr from narwhals._plan.expressions.aggregation import ( ArgMax, ArgMin, @@ -47,6 +45,14 @@ Sum, Var, ) + from narwhals._plan.expressions.boolean import ( + All, + IsBetween, + IsFinite, + IsNan, + IsNull, + Not, + ) from narwhals._plan.expressions.expr import ( AnonymousExpr, BinaryExpr, @@ -56,7 +62,7 @@ TernaryExpr, WindowExpr, ) - from narwhals._plan.functions import FillNull, Pow + from narwhals._plan.expressions.functions import FillNull, Pow from narwhals.typing import Into1DArray, IntoDType, PythonLiteral Expr: TypeAlias = "ArrowExpr" diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 2b6e0ca778..55f7f4017d 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -19,13 +19,12 @@ from collections.abc import Callable, Iterable, Iterator, Sequence from narwhals._arrow.typing import ChunkedArrayAny - from narwhals._plan import functions as F from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.series import ArrowSeries as Series - from narwhals._plan.boolean import AllHorizontal, AnyHorizontal from narwhals._plan.dummy import Series as NwSeries - from narwhals._plan.expressions import expr + from narwhals._plan.expressions import expr, functions as F + from narwhals._plan.expressions.boolean import AllHorizontal, AnyHorizontal from narwhals._plan.expressions.expr import FunctionExpr, RangeExpr from narwhals._plan.ranges import IntRange from narwhals._plan.strings import ConcatStr diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index d60d7a9651..d72d877759 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -3,9 +3,9 @@ import builtins import typing as t -from narwhals._plan import _guards, boolean, expr_parsing as parse, functions as F +from narwhals._plan import _guards, expr_parsing as parse from narwhals._plan.common import into_dtype, py_to_narwhals_dtype -from narwhals._plan.expressions import expr +from narwhals._plan.expressions import boolean, expr, functions as F from narwhals._plan.expressions.expr import All, Len from narwhals._plan.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.ranges import IntRange diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index f3ff304c42..1686c36ceb 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -6,17 +6,11 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload -from narwhals._plan import ( - boolean, - expr_expansion, - expr_parsing as parse, - functions as F, - operators as ops, -) +from narwhals._plan import expr_expansion, expr_parsing as parse, operators as ops from narwhals._plan._guards import is_column, is_expr, is_series from narwhals._plan.common import into_dtype from narwhals._plan.contexts import ExprContext -from narwhals._plan.expressions import aggregation as agg, expr +from narwhals._plan.expressions import aggregation as agg, boolean, expr, functions as F from narwhals._plan.options import ( EWMOptions, RankOptions, diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py index 1e450f2307..303017d700 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/expr_parsing.py @@ -164,7 +164,7 @@ def _parse_constraints(constraints: dict[str, IntoExpr], /) -> Iterator[ExprIR]: def _combine_predicates(predicates: Iterator[ExprIR], /) -> ExprIR: - from narwhals._plan.boolean import AllHorizontal + from narwhals._plan.expressions.boolean import AllHorizontal first = next(predicates, None) if not first: diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/expressions/boolean.py similarity index 100% rename from narwhals/_plan/boolean.py rename to narwhals/_plan/expressions/boolean.py diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 5c513644fc..ced5a0be9a 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -31,7 +31,7 @@ if t.TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.functions import MapBatches # noqa: F401 + from narwhals._plan.expressions.functions import MapBatches # noqa: F401 from narwhals._plan.literal import LiteralValue from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions from narwhals._plan.protocols import Ctx, FrameT_contra, R_co diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/expressions/functions.py similarity index 100% rename from narwhals/_plan/functions.py rename to narwhals/_plan/expressions/functions.py diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 8f1fdcd9a6..15f6d0a31d 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -11,10 +11,14 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs - from narwhals._plan import boolean, functions as F - from narwhals._plan.boolean import IsBetween, IsFinite, IsNan, IsNull, Not from narwhals._plan.dummy import BaseFrame, DataFrame, Series - from narwhals._plan.expressions import aggregation as agg, expr + from narwhals._plan.expressions import ( + aggregation as agg, + boolean, + expr, + functions as F, + ) + from narwhals._plan.expressions.boolean import IsBetween, IsFinite, IsNan, IsNull, Not from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr, RangeExpr from narwhals._plan.options import SortMultipleOptions from narwhals._plan.ranges import IntRange diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index b7e0736e15..d82ca745c4 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -11,7 +11,7 @@ from narwhals._plan import operators as ops from narwhals._plan.common import ExprIR, Function, IRNamespace, NamedIR, SelectorIR from narwhals._plan.dummy import Expr, Series - from narwhals._plan.functions import RollingWindow + from narwhals._plan.expressions.functions import RollingWindow from narwhals._plan.ranges import RangeFunction from narwhals.typing import ( NativeDataFrame, diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 2e3a69df90..70567917f3 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -11,11 +11,11 @@ import narwhals as nw import narwhals._plan.demo as nwd -from narwhals._plan import boolean, functions as F, operators as ops +from narwhals._plan import operators as ops from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import Expr, Series from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir -from narwhals._plan.expressions import expr +from narwhals._plan.expressions import boolean, expr, functions as F from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr, RangeExpr from narwhals._plan.literal import SeriesLiteral from narwhals.exceptions import ( From 46de12bf0c34f53296e045de1d2e09efcd21ac29 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:17:00 +0000 Subject: [PATCH 06/36] refactor: Move remaining `Function`-based modules --- narwhals/_plan/arrow/namespace.py | 4 ++-- narwhals/_plan/demo.py | 4 ++-- narwhals/_plan/dummy.py | 22 +++++++++---------- .../_plan/{ => expressions}/categorical.py | 0 narwhals/_plan/expressions/expr.py | 2 +- narwhals/_plan/{ => expressions}/lists.py | 0 narwhals/_plan/{ => expressions}/ranges.py | 0 narwhals/_plan/{ => expressions}/strings.py | 0 narwhals/_plan/{ => expressions}/struct.py | 0 narwhals/_plan/{ => expressions}/temporal.py | 0 narwhals/_plan/{ => expressions}/window.py | 0 narwhals/_plan/protocols.py | 4 ++-- narwhals/_plan/typing.py | 2 +- tests/plan/expr_rewrites_test.py | 2 +- 14 files changed, 20 insertions(+), 20 deletions(-) rename narwhals/_plan/{ => expressions}/categorical.py (100%) rename narwhals/_plan/{ => expressions}/lists.py (100%) rename narwhals/_plan/{ => expressions}/ranges.py (100%) rename narwhals/_plan/{ => expressions}/strings.py (100%) rename narwhals/_plan/{ => expressions}/struct.py (100%) rename narwhals/_plan/{ => expressions}/temporal.py (100%) rename narwhals/_plan/{ => expressions}/window.py (100%) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 55f7f4017d..8981e549a9 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -26,8 +26,8 @@ from narwhals._plan.expressions import expr, functions as F from narwhals._plan.expressions.boolean import AllHorizontal, AnyHorizontal from narwhals._plan.expressions.expr import FunctionExpr, RangeExpr - from narwhals._plan.ranges import IntRange - from narwhals._plan.strings import ConcatStr + from narwhals._plan.expressions.ranges import IntRange + from narwhals._plan.expressions.strings import ConcatStr from narwhals.typing import ConcatMethod, NonNestedLiteral, PythonLiteral diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index d72d877759..92f40ba471 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -7,9 +7,9 @@ from narwhals._plan.common import into_dtype, py_to_narwhals_dtype from narwhals._plan.expressions import boolean, expr, functions as F from narwhals._plan.expressions.expr import All, Len +from narwhals._plan.expressions.ranges import IntRange +from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.literal import ScalarLiteral, SeriesLiteral -from narwhals._plan.ranges import IntRange -from narwhals._plan.strings import ConcatStr from narwhals._plan.when_then import When from narwhals._utils import Version, flatten diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 1686c36ceb..f95f1bf9e1 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -11,6 +11,7 @@ from narwhals._plan.common import into_dtype from narwhals._plan.contexts import ExprContext from narwhals._plan.expressions import aggregation as agg, boolean, expr, functions as F +from narwhals._plan.expressions.window import Over from narwhals._plan.options import ( EWMOptions, RankOptions, @@ -20,7 +21,6 @@ ) from narwhals._plan.selectors import by_name from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT -from narwhals._plan.window import Over from narwhals._utils import Version, generate_repr from narwhals.dependencies import is_pyarrow_chunked_array, is_pyarrow_table from narwhals.exceptions import ComputeError, InvalidOperationError @@ -30,10 +30,13 @@ import pyarrow as pa from typing_extensions import Never, Self - from narwhals._plan.categorical import ExprCatNamespace from narwhals._plan.common import ExprIR, Function, NamedIR + from narwhals._plan.expressions.categorical import ExprCatNamespace + from narwhals._plan.expressions.lists import ExprListNamespace from narwhals._plan.expressions.name import ExprNameNamespace - from narwhals._plan.lists import ExprListNamespace + from narwhals._plan.expressions.strings import ExprStringNamespace + from narwhals._plan.expressions.struct import ExprStructNamespace + from narwhals._plan.expressions.temporal import ExprDateTimeNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.protocols import ( CompliantBaseFrame, @@ -41,9 +44,6 @@ CompliantSeries, ) from narwhals._plan.schema import FrozenSchema - from narwhals._plan.strings import ExprStringNamespace - from narwhals._plan.struct import ExprStructNamespace - from narwhals._plan.temporal import ExprDateTimeNamespace from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq, Udf from narwhals.dtypes import DType from narwhals.typing import ( @@ -573,31 +573,31 @@ def name(self) -> ExprNameNamespace: @property def cat(self) -> ExprCatNamespace: - from narwhals._plan.categorical import ExprCatNamespace + from narwhals._plan.expressions.categorical import ExprCatNamespace return ExprCatNamespace(_expr=self) @property def struct(self) -> ExprStructNamespace: - from narwhals._plan.struct import ExprStructNamespace + from narwhals._plan.expressions.struct import ExprStructNamespace return ExprStructNamespace(_expr=self) @property def dt(self) -> ExprDateTimeNamespace: - from narwhals._plan.temporal import ExprDateTimeNamespace + from narwhals._plan.expressions.temporal import ExprDateTimeNamespace return ExprDateTimeNamespace(_expr=self) @property def list(self) -> ExprListNamespace: - from narwhals._plan.lists import ExprListNamespace + from narwhals._plan.expressions.lists import ExprListNamespace return ExprListNamespace(_expr=self) @property def str(self) -> ExprStringNamespace: - from narwhals._plan.strings import ExprStringNamespace + from narwhals._plan.expressions.strings import ExprStringNamespace return ExprStringNamespace(_expr=self) diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/expressions/categorical.py similarity index 100% rename from narwhals/_plan/categorical.py rename to narwhals/_plan/expressions/categorical.py diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index ced5a0be9a..b621300b0c 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -32,11 +32,11 @@ from typing_extensions import Self from narwhals._plan.expressions.functions import MapBatches # noqa: F401 + from narwhals._plan.expressions.window import Window from narwhals._plan.literal import LiteralValue from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions from narwhals._plan.protocols import Ctx, FrameT_contra, R_co from narwhals._plan.selectors import Selector - from narwhals._plan.window import Window from narwhals.dtypes import DType __all__ = [ diff --git a/narwhals/_plan/lists.py b/narwhals/_plan/expressions/lists.py similarity index 100% rename from narwhals/_plan/lists.py rename to narwhals/_plan/expressions/lists.py diff --git a/narwhals/_plan/ranges.py b/narwhals/_plan/expressions/ranges.py similarity index 100% rename from narwhals/_plan/ranges.py rename to narwhals/_plan/expressions/ranges.py diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/expressions/strings.py similarity index 100% rename from narwhals/_plan/strings.py rename to narwhals/_plan/expressions/strings.py diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/expressions/struct.py similarity index 100% rename from narwhals/_plan/struct.py rename to narwhals/_plan/expressions/struct.py diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/expressions/temporal.py similarity index 100% rename from narwhals/_plan/temporal.py rename to narwhals/_plan/expressions/temporal.py diff --git a/narwhals/_plan/window.py b/narwhals/_plan/expressions/window.py similarity index 100% rename from narwhals/_plan/window.py rename to narwhals/_plan/expressions/window.py diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 15f6d0a31d..3de4ffc543 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -20,9 +20,9 @@ ) from narwhals._plan.expressions.boolean import IsBetween, IsFinite, IsNan, IsNull, Not from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr, RangeExpr + from narwhals._plan.expressions.ranges import IntRange + from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.options import SortMultipleOptions - from narwhals._plan.ranges import IntRange - from narwhals._plan.strings import ConcatStr from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType from narwhals.typing import ( diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index d82ca745c4..897bd0df2f 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -12,7 +12,7 @@ from narwhals._plan.common import ExprIR, Function, IRNamespace, NamedIR, SelectorIR from narwhals._plan.dummy import Expr, Series from narwhals._plan.expressions.functions import RollingWindow - from narwhals._plan.ranges import RangeFunction + from narwhals._plan.expressions.ranges import RangeFunction from narwhals.typing import ( NativeDataFrame, NativeFrame, diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index 964d03aeac..df0a390312 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -14,7 +14,7 @@ rewrite_elementwise_over, ) from narwhals._plan.expressions.expr import WindowExpr -from narwhals._plan.window import Over +from narwhals._plan.expressions.window import Over from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_expr_ir_equal From 7368936b5510d3d176827b53071984e702cd236d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:23:57 +0000 Subject: [PATCH 07/36] refactor: Mostly move everything else --- narwhals/_plan/arrow/functions.py | 2 +- narwhals/_plan/arrow/namespace.py | 2 +- narwhals/_plan/demo.py | 2 +- narwhals/_plan/dummy.py | 12 +++++++++--- narwhals/_plan/exceptions.py | 2 +- narwhals/_plan/expressions/boolean.py | 2 +- narwhals/_plan/expressions/expr.py | 4 ++-- narwhals/_plan/{ => expressions}/literal.py | 0 narwhals/_plan/{ => expressions}/operators.py | 0 narwhals/_plan/{ => expressions}/selectors.py | 0 narwhals/_plan/meta.py | 2 +- narwhals/_plan/typing.py | 2 +- tests/plan/compliant_test.py | 4 +++- tests/plan/expr_expansion_test.py | 3 ++- tests/plan/expr_parsing_test.py | 5 ++--- tests/plan/expr_rewrites_test.py | 3 ++- 16 files changed, 27 insertions(+), 18 deletions(-) rename narwhals/_plan/{ => expressions}/literal.py (100%) rename narwhals/_plan/{ => expressions}/operators.py (100%) rename narwhals/_plan/{ => expressions}/selectors.py (100%) diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py index 83ecafae5f..7a16404d3d 100644 --- a/narwhals/_plan/arrow/functions.py +++ b/narwhals/_plan/arrow/functions.py @@ -13,7 +13,7 @@ chunked_array as _chunked_array, floordiv_compat as floordiv, ) -from narwhals._plan import operators as ops +from narwhals._plan.expressions import operators as ops from narwhals._utils import Implementation if TYPE_CHECKING: diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 8981e549a9..678de4d0ff 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -10,7 +10,7 @@ from narwhals._plan._guards import is_tuple_of from narwhals._plan.arrow import functions as fn from narwhals._plan.common import collect -from narwhals._plan.literal import is_literal_scalar +from narwhals._plan.expressions.literal import is_literal_scalar from narwhals._plan.protocols import EagerNamespace from narwhals._utils import Version from narwhals.exceptions import InvalidOperationError diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 92f40ba471..ddf025cf98 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -7,9 +7,9 @@ from narwhals._plan.common import into_dtype, py_to_narwhals_dtype from narwhals._plan.expressions import boolean, expr, functions as F from narwhals._plan.expressions.expr import All, Len +from narwhals._plan.expressions.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.expressions.ranges import IntRange from narwhals._plan.expressions.strings import ConcatStr -from narwhals._plan.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.when_then import When from narwhals._utils import Version, flatten diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index f95f1bf9e1..bf638bb2ed 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -6,11 +6,18 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload -from narwhals._plan import expr_expansion, expr_parsing as parse, operators as ops +from narwhals._plan import expr_expansion, expr_parsing as parse from narwhals._plan._guards import is_column, is_expr, is_series from narwhals._plan.common import into_dtype from narwhals._plan.contexts import ExprContext -from narwhals._plan.expressions import aggregation as agg, boolean, expr, functions as F +from narwhals._plan.expressions import ( + aggregation as agg, + boolean, + expr, + functions as F, + operators as ops, +) +from narwhals._plan.expressions.selectors import by_name from narwhals._plan.expressions.window import Over from narwhals._plan.options import ( EWMOptions, @@ -19,7 +26,6 @@ SortOptions, rolling_options, ) -from narwhals._plan.selectors import by_name from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT from narwhals._utils import Version, generate_repr from narwhals.dependencies import is_pyarrow_chunked_array, is_pyarrow_table diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 0507cb51c5..f61bb8e148 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -27,7 +27,7 @@ from narwhals._plan.common import ExprIR, Function from narwhals._plan.expressions.aggregation import AggExpr from narwhals._plan.expressions.expr import FunctionExpr, WindowExpr - from narwhals._plan.operators import Operator + from narwhals._plan.expressions.operators import Operator from narwhals._plan.options import SortOptions from narwhals._plan.typing import IntoExpr, Seq diff --git a/narwhals/_plan/expressions/boolean.py b/narwhals/_plan/expressions/boolean.py index fc9d863fdc..c3cbda1a1e 100644 --- a/narwhals/_plan/expressions/boolean.py +++ b/narwhals/_plan/expressions/boolean.py @@ -68,7 +68,7 @@ def from_iterable(cls, other: t.Iterable[t.Any], /) -> IsInSeq: class IsInSeries(IsIn["Literal[Series[NativeSeriesT]]"]): @classmethod def from_series(cls, other: Series[NativeSeriesT], /) -> IsInSeries[NativeSeriesT]: - from narwhals._plan.literal import SeriesLiteral + from narwhals._plan.expressions.literal import SeriesLiteral return IsInSeries(other=SeriesLiteral(value=other).to_literal()) diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index b621300b0c..a6db604eba 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -32,11 +32,11 @@ from typing_extensions import Self from narwhals._plan.expressions.functions import MapBatches # noqa: F401 + from narwhals._plan.expressions.literal import LiteralValue + from narwhals._plan.expressions.selectors import Selector from narwhals._plan.expressions.window import Window - from narwhals._plan.literal import LiteralValue from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions from narwhals._plan.protocols import Ctx, FrameT_contra, R_co - from narwhals._plan.selectors import Selector from narwhals.dtypes import DType __all__ = [ diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/expressions/literal.py similarity index 100% rename from narwhals/_plan/literal.py rename to narwhals/_plan/expressions/literal.py diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/expressions/operators.py similarity index 100% rename from narwhals/_plan/operators.py rename to narwhals/_plan/expressions/operators.py diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/expressions/selectors.py similarity index 100% rename from narwhals/_plan/selectors.py rename to narwhals/_plan/expressions/selectors.py diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index d20529cfc1..e941a2b2da 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -182,7 +182,7 @@ def is_column(ir: ExprIR) -> TypeIs[Column]: def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: from narwhals._plan.expressions import expr - from narwhals._plan.literal import is_literal_scalar + from narwhals._plan.expressions.literal import is_literal_scalar return ( isinstance(ir, expr.Literal) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 897bd0df2f..f854d6cd44 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -8,9 +8,9 @@ from typing_extensions import TypeAlias from narwhals import dtypes - from narwhals._plan import operators as ops from narwhals._plan.common import ExprIR, Function, IRNamespace, NamedIR, SelectorIR from narwhals._plan.dummy import Expr, Series + from narwhals._plan.expressions import operators as ops from narwhals._plan.expressions.functions import RollingWindow from narwhals._plan.expressions.ranges import RangeFunction from narwhals.typing import ( diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index dc548968a4..801ce794c8 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -4,13 +4,15 @@ import pytest +from narwhals._plan.expressions import selectors as ndcs + pytest.importorskip("pyarrow") pytest.importorskip("numpy") import numpy as np import pyarrow as pa import narwhals as nw -from narwhals._plan import demo as nwd, selectors as ndcs +from narwhals._plan import demo as nwd from narwhals._plan._guards import is_expr from narwhals._plan.dummy import DataFrame from narwhals._utils import Version diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index ba45c5a7b1..fcd81b486c 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -6,13 +6,14 @@ import pytest import narwhals as nw -from narwhals._plan import demo as nwd, selectors as ndcs +from narwhals._plan import demo as nwd from narwhals._plan.expr_expansion import ( prepare_projection, replace_selector, rewrite_special_aliases, ) from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir +from narwhals._plan.expressions import selectors as ndcs from narwhals._plan.expressions.expr import Alias, Columns from narwhals._plan.schema import freeze_schema from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 70567917f3..b820fce29b 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -11,13 +11,12 @@ import narwhals as nw import narwhals._plan.demo as nwd -from narwhals._plan import operators as ops from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import Expr, Series from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir -from narwhals._plan.expressions import boolean, expr, functions as F +from narwhals._plan.expressions import boolean, expr, functions as F, operators as ops from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr, RangeExpr -from narwhals._plan.literal import SeriesLiteral +from narwhals._plan.expressions.literal import SeriesLiteral from narwhals.exceptions import ( InvalidIntoExprError, InvalidOperationError, diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index df0a390312..916826e1ae 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -5,7 +5,7 @@ import pytest import narwhals as nw -from narwhals._plan import demo as nwd, expr_parsing as parse, selectors as ndcs +from narwhals._plan import demo as nwd, expr_parsing as parse from narwhals._plan._guards import is_expr from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.expr_rewrites import ( @@ -13,6 +13,7 @@ rewrite_binary_agg_over, rewrite_elementwise_over, ) +from narwhals._plan.expressions import selectors as ndcs from narwhals._plan.expressions.expr import WindowExpr from narwhals._plan.expressions.window import Over from narwhals.exceptions import InvalidOperationError From 4ddf7e4c4d8e5216bf8baa0e294a0fe2f28504ba Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:33:35 +0000 Subject: [PATCH 08/36] refactor: Rename `_plan.demo` -> `_plan.functions` --- narwhals/_plan/common.py | 4 ++-- narwhals/_plan/dummy.py | 4 ++-- narwhals/_plan/expr_parsing.py | 4 ++-- narwhals/_plan/{demo.py => functions.py} | 2 +- narwhals/_plan/meta.py | 2 +- tests/plan/compliant_test.py | 2 +- tests/plan/expr_expansion_test.py | 2 +- tests/plan/expr_parsing_test.py | 2 +- tests/plan/expr_rewrites_test.py | 2 +- tests/plan/meta_test.py | 2 +- 10 files changed, 13 insertions(+), 13 deletions(-) rename narwhals/_plan/{demo.py => functions.py} (98%) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 3364de4e80..ed022aa842 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -180,7 +180,7 @@ def iter_left(self) -> Iterator[ExprIR]: """Yield nodes root->leaf. Examples: - >>> from narwhals._plan import demo as nwd + >>> from narwhals._plan import functions as nwd >>> >>> a = nwd.col("a") >>> b = a.alias("b") @@ -215,7 +215,7 @@ def iter_right(self) -> Iterator[ExprIR]: Identical to `iter_left` for root nodes. Examples: - >>> from narwhals._plan import demo as nwd + >>> from narwhals._plan import functions as nwd >>> >>> a = nwd.col("a") >>> b = a.alias("b") diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index bf638bb2ed..d0f7e095f0 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -567,9 +567,9 @@ def name(self) -> ExprNameNamespace: """Specialized expressions for modifying the name of existing expressions. Examples: - >>> from narwhals._plan import demo as nw + >>> from narwhals._plan import functions as nwd >>> - >>> renamed = nw.col("a", "b").name.suffix("_changed") + >>> renamed = nwd.col("a", "b").name.suffix("_changed") >>> str(renamed._ir) "RenameAlias(expr=Columns(names=[a, b]), function=Suffix(suffix='_changed'))" """ diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py index 303017d700..05e503eb68 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/expr_parsing.py @@ -88,7 +88,7 @@ def parse_into_expr_ir( input: IntoExpr, *, str_as_lit: bool = False, dtype: IntoDType | None = None ) -> ExprIR: """Parse a single input into an `ExprIR` node.""" - from narwhals._plan import demo as nwd + from narwhals._plan import functions as nwd if is_expr(input): expr = input @@ -157,7 +157,7 @@ def _parse_named_inputs(named_inputs: dict[str, IntoExpr], /) -> Iterator[ExprIR def _parse_constraints(constraints: dict[str, IntoExpr], /) -> Iterator[ExprIR]: - from narwhals._plan import demo as nwd + from narwhals._plan import functions as nwd for name, value in constraints.items(): yield (nwd.col(name) == value)._ir diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/functions.py similarity index 98% rename from narwhals/_plan/demo.py rename to narwhals/_plan/functions.py index ddf025cf98..49bebf0d53 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/functions.py @@ -129,7 +129,7 @@ def when( """Start a `when-then-otherwise` expression. Examples: - >>> from narwhals._plan import demo as nwd + >>> from narwhals._plan import functions as nwd >>> nwd.when(nwd.col("y") == "b").then(1) nw._plan.Expr(main): diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index e941a2b2da..01d218fcd4 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -51,7 +51,7 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: """Get the output name of this expression. Examples: - >>> from narwhals._plan import demo as nwd + >>> from narwhals._plan import functions as nwd >>> >>> a = nwd.col("a") >>> b = a.alias("b") diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 801ce794c8..7b1caa426f 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -12,7 +12,7 @@ import pyarrow as pa import narwhals as nw -from narwhals._plan import demo as nwd +from narwhals._plan import functions as nwd from narwhals._plan._guards import is_expr from narwhals._plan.dummy import DataFrame from narwhals._utils import Version diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index fcd81b486c..040aa03197 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -6,7 +6,7 @@ import pytest import narwhals as nw -from narwhals._plan import demo as nwd +from narwhals._plan import functions as nwd from narwhals._plan.expr_expansion import ( prepare_projection, replace_selector, diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index b820fce29b..1b2e4fa30c 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -10,7 +10,7 @@ import pytest import narwhals as nw -import narwhals._plan.demo as nwd +import narwhals._plan.functions as nwd from narwhals._plan.common import ExprIR, Function from narwhals._plan.dummy import Expr, Series from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index 916826e1ae..fcca17c0b1 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -5,7 +5,7 @@ import pytest import narwhals as nw -from narwhals._plan import demo as nwd, expr_parsing as parse +from narwhals._plan import expr_parsing as parse, functions as nwd from narwhals._plan._guards import is_expr from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.expr_rewrites import ( diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index e783e55c31..c09fcb3604 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -4,7 +4,7 @@ import pytest -import narwhals._plan.demo as nwd +import narwhals._plan.functions as nwd from tests.utils import POLARS_VERSION if TYPE_CHECKING: From ebef1d5e3bf41297e28835e4e408c1abfb7ab97d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:41:00 +0000 Subject: [PATCH 09/36] refactor: Split out `series.py` --- narwhals/_plan/_guards.py | 3 +- narwhals/_plan/arrow/namespace.py | 2 +- narwhals/_plan/dummy.py | 63 ++----------------------- narwhals/_plan/expressions/boolean.py | 2 +- narwhals/_plan/expressions/literal.py | 2 +- narwhals/_plan/functions.py | 3 +- narwhals/_plan/protocols.py | 5 +- narwhals/_plan/series.py | 67 +++++++++++++++++++++++++++ narwhals/_plan/typing.py | 3 +- tests/plan/expr_parsing_test.py | 3 +- 10 files changed, 85 insertions(+), 68 deletions(-) create mode 100644 narwhals/_plan/series.py diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index bfdb762ebd..f7aa223a41 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -11,9 +11,10 @@ if TYPE_CHECKING: from typing_extensions import TypeIs - from narwhals._plan.dummy import Expr, Series + from narwhals._plan.dummy import Expr from narwhals._plan.expressions import expr from narwhals._plan.protocols import CompliantSeries + from narwhals._plan.series import Series from narwhals._plan.typing import NativeSeriesT, Seq from narwhals.typing import NonNestedLiteral diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index 678de4d0ff..fc67bbc36e 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -22,12 +22,12 @@ from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.series import ArrowSeries as Series - from narwhals._plan.dummy import Series as NwSeries from narwhals._plan.expressions import expr, functions as F from narwhals._plan.expressions.boolean import AllHorizontal, AnyHorizontal from narwhals._plan.expressions.expr import FunctionExpr, RangeExpr from narwhals._plan.expressions.ranges import IntRange from narwhals._plan.expressions.strings import ConcatStr + from narwhals._plan.series import Series as NwSeries from narwhals.typing import ConcatMethod, NonNestedLiteral, PythonLiteral diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index d0f7e095f0..acff9ceea7 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -3,7 +3,7 @@ from __future__ import annotations import math -from collections.abc import Iterable, Iterator, Mapping, Sequence +from collections.abc import Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload from narwhals._plan import expr_expansion, expr_parsing as parse @@ -26,9 +26,10 @@ SortOptions, rolling_options, ) +from narwhals._plan.series import Series from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT from narwhals._utils import Version, generate_repr -from narwhals.dependencies import is_pyarrow_chunked_array, is_pyarrow_table +from narwhals.dependencies import is_pyarrow_table from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.schema import Schema @@ -44,20 +45,14 @@ from narwhals._plan.expressions.struct import ExprStructNamespace from narwhals._plan.expressions.temporal import ExprDateTimeNamespace from narwhals._plan.meta import IRMetaNamespace - from narwhals._plan.protocols import ( - CompliantBaseFrame, - CompliantDataFrame, - CompliantSeries, - ) + from narwhals._plan.protocols import CompliantBaseFrame, CompliantDataFrame from narwhals._plan.schema import FrozenSchema from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq, Udf - from narwhals.dtypes import DType from narwhals.typing import ( ClosedInterval, FillNullStrategy, IntoDType, NativeFrame, - NativeSeries, NumericLiteral, RankMethod, RollingInterpolationMethod, @@ -834,53 +829,3 @@ def to_dict( def __len__(self) -> int: return len(self._compliant) - - -class Series(Generic[NativeSeriesT]): - _compliant: CompliantSeries[NativeSeriesT] - _version: ClassVar[Version] = Version.MAIN - - @property - def version(self) -> Version: - return self._version - - @property - def dtype(self) -> DType: - return self._compliant.dtype - - @property - def name(self) -> str: - return self._compliant.name - - # NOTE: Gave up on trying to get typing working for now - @classmethod - def from_native( - cls, native: NativeSeries, name: str = "", / - ) -> Series[pa.ChunkedArray[Any]]: - if is_pyarrow_chunked_array(native): - from narwhals._plan.arrow.series import ArrowSeries - - return ArrowSeries.from_native( - native, name, version=cls._version - ).to_narwhals() - - raise NotImplementedError(type(native)) - - @classmethod - def _from_compliant(cls, compliant: CompliantSeries[NativeSeriesT], /) -> Self: - obj = cls.__new__(cls) - obj._compliant = compliant - return obj - - def to_native(self) -> NativeSeriesT: - return self._compliant.native - - def to_list(self) -> list[Any]: - return self._compliant.to_list() - - def __iter__(self) -> Iterator[Any]: - yield from self.to_native() - - -class SeriesV1(Series[NativeSeriesT]): - _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/expressions/boolean.py b/narwhals/_plan/expressions/boolean.py index c3cbda1a1e..49ad2bd2ca 100644 --- a/narwhals/_plan/expressions/boolean.py +++ b/narwhals/_plan/expressions/boolean.py @@ -12,8 +12,8 @@ from typing_extensions import Self from narwhals._plan.common import ExprIR - from narwhals._plan.dummy import Series from narwhals._plan.expressions.expr import FunctionExpr, Literal # noqa: F401 + from narwhals._plan.series import Series from narwhals._plan.typing import NativeSeriesT, Seq # noqa: F401 from narwhals.typing import ClosedInterval diff --git a/narwhals/_plan/expressions/literal.py b/narwhals/_plan/expressions/literal.py index 9032d49cd4..7d46c8436c 100644 --- a/narwhals/_plan/expressions/literal.py +++ b/narwhals/_plan/expressions/literal.py @@ -9,8 +9,8 @@ if TYPE_CHECKING: from typing_extensions import TypeIs - from narwhals._plan.dummy import Series from narwhals._plan.expressions.expr import Literal + from narwhals._plan.series import Series from narwhals.dtypes import DType diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 49bebf0d53..c48e541567 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -14,7 +14,8 @@ from narwhals._utils import Version, flatten if t.TYPE_CHECKING: - from narwhals._plan.dummy import Expr, Series + from narwhals._plan.dummy import Expr + from narwhals._plan.series import Series from narwhals._plan.typing import IntoExpr, IntoExprColumn, NativeSeriesT from narwhals.dtypes import IntegerType from narwhals.typing import IntoDType, NonNestedLiteral diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 3de4ffc543..bdadb708da 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs - from narwhals._plan.dummy import BaseFrame, DataFrame, Series + from narwhals._plan.dummy import BaseFrame, DataFrame from narwhals._plan.expressions import ( aggregation as agg, boolean, @@ -23,6 +23,7 @@ from narwhals._plan.expressions.ranges import IntRange from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.options import SortMultipleOptions + from narwhals._plan.series import Series from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType from narwhals.typing import ( @@ -608,7 +609,7 @@ def name(self) -> str: return self._name def to_narwhals(self) -> Series[NativeSeriesT]: - from narwhals._plan.dummy import Series + from narwhals._plan.series import Series return Series[NativeSeriesT]._from_compliant(self) diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py new file mode 100644 index 0000000000..1ab9366ea3 --- /dev/null +++ b/narwhals/_plan/series.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar, Generic + +from narwhals._plan.typing import NativeSeriesT +from narwhals._utils import Version +from narwhals.dependencies import is_pyarrow_chunked_array + +if TYPE_CHECKING: + from collections.abc import Iterator + + import pyarrow as pa + from typing_extensions import Self + + from narwhals._plan.protocols import CompliantSeries + from narwhals.dtypes import DType + from narwhals.typing import NativeSeries + + +class Series(Generic[NativeSeriesT]): + _compliant: CompliantSeries[NativeSeriesT] + _version: ClassVar[Version] = Version.MAIN + + @property + def version(self) -> Version: + return self._version + + @property + def dtype(self) -> DType: + return self._compliant.dtype + + @property + def name(self) -> str: + return self._compliant.name + + # NOTE: Gave up on trying to get typing working for now + @classmethod + def from_native( + cls, native: NativeSeries, name: str = "", / + ) -> Series[pa.ChunkedArray[Any]]: + if is_pyarrow_chunked_array(native): + from narwhals._plan.arrow.series import ArrowSeries + + return ArrowSeries.from_native( + native, name, version=cls._version + ).to_narwhals() + + raise NotImplementedError(type(native)) + + @classmethod + def _from_compliant(cls, compliant: CompliantSeries[NativeSeriesT], /) -> Self: + obj = cls.__new__(cls) + obj._compliant = compliant + return obj + + def to_native(self) -> NativeSeriesT: + return self._compliant.native + + def to_list(self) -> list[Any]: + return self._compliant.to_list() + + def __iter__(self) -> Iterator[Any]: + yield from self.to_native() + + +class SeriesV1(Series[NativeSeriesT]): + _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index f854d6cd44..b1007b839d 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -9,10 +9,11 @@ from narwhals import dtypes from narwhals._plan.common import ExprIR, Function, IRNamespace, NamedIR, SelectorIR - from narwhals._plan.dummy import Expr, Series + from narwhals._plan.dummy import Expr from narwhals._plan.expressions import operators as ops from narwhals._plan.expressions.functions import RollingWindow from narwhals._plan.expressions.ranges import RangeFunction + from narwhals._plan.series import Series from narwhals.typing import ( NativeDataFrame, NativeFrame, diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 1b2e4fa30c..beb72acc35 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -12,11 +12,12 @@ import narwhals as nw import narwhals._plan.functions as nwd from narwhals._plan.common import ExprIR, Function -from narwhals._plan.dummy import Expr, Series +from narwhals._plan.dummy import Expr from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir from narwhals._plan.expressions import boolean, expr, functions as F, operators as ops from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr, RangeExpr from narwhals._plan.expressions.literal import SeriesLiteral +from narwhals._plan.series import Series from narwhals.exceptions import ( InvalidIntoExprError, InvalidOperationError, From 77e8cba24623ee65860b82572584e75c3beddffe Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:52:22 +0000 Subject: [PATCH 10/36] refactor: Split out `dataframe.py` --- narwhals/_plan/_guards.py | 10 +- narwhals/_plan/arrow/dataframe.py | 4 +- narwhals/_plan/dataframe.py | 146 ++++++++++++++++++++++++++++++ narwhals/_plan/dummy.py | 135 +-------------------------- narwhals/_plan/protocols.py | 2 +- tests/plan/compliant_test.py | 2 +- 6 files changed, 162 insertions(+), 137 deletions(-) create mode 100644 narwhals/_plan/dataframe.py diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index f7aa223a41..7d32a61da8 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -44,6 +44,12 @@ def _expr(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 return expr +def _series(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 + from narwhals._plan import series + + return series + + def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]: return obj is None or isinstance(obj, _NON_NESTED_LITERAL_TPS) @@ -58,7 +64,7 @@ def is_column(obj: Any) -> TypeIs[Expr]: def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]]: - return isinstance(obj, _dummy().Series) + return isinstance(obj, _series().Series) def is_compliant_series( @@ -68,7 +74,7 @@ def is_compliant_series( def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSeries]: - return isinstance(obj, (str, bytes, _dummy().Series)) or is_compliant_series(obj) + return isinstance(obj, (str, bytes, _series().Series)) or is_compliant_series(obj) def is_window_expr(obj: Any) -> TypeIs[expr.WindowExpr]: diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index fc61e69acc..fad9a119f9 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -21,7 +21,7 @@ from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.common import ExprIR, NamedIR - from narwhals._plan.dummy import DataFrame as NwDataFrame + from narwhals._plan.dataframe import DataFrame as NwDataFrame from narwhals._plan.options import SortMultipleOptions from narwhals._plan.typing import Seq from narwhals.dtypes import DType @@ -50,7 +50,7 @@ def __len__(self) -> int: return self.native.num_rows def to_narwhals(self) -> NwDataFrame[pa.Table, ChunkedArrayAny]: - from narwhals._plan.dummy import DataFrame + from narwhals._plan.dataframe import DataFrame return DataFrame[pa.Table, "ChunkedArrayAny"]._from_compliant(self) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py new file mode 100644 index 0000000000..bd16b4ca24 --- /dev/null +++ b/narwhals/_plan/dataframe.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload + +from narwhals._plan import expr_expansion, expr_parsing as parse +from narwhals._plan.contexts import ExprContext +from narwhals._plan.dummy import _parse_sort_by +from narwhals._plan.series import Series +from narwhals._plan.typing import ( + IntoExpr, + NativeDataFrameT, + NativeFrameT, + NativeSeriesT, + OneOrIterable, +) +from narwhals._utils import Version, generate_repr +from narwhals.dependencies import is_pyarrow_table +from narwhals.schema import Schema + +if TYPE_CHECKING: + import pyarrow as pa + from typing_extensions import Self + + from narwhals._plan.common import ExprIR, NamedIR + from narwhals._plan.protocols import CompliantBaseFrame, CompliantDataFrame + from narwhals._plan.schema import FrozenSchema + from narwhals._plan.typing import Seq + from narwhals.typing import NativeFrame + + +class BaseFrame(Generic[NativeFrameT]): + _compliant: CompliantBaseFrame[Any, NativeFrameT] + _version: ClassVar[Version] = Version.MAIN + + @property + def version(self) -> Version: + return self._version + + @property + def schema(self) -> Schema: + return Schema(self._compliant.schema.items()) + + @property + def columns(self) -> list[str]: + return self._compliant.columns + + def __repr__(self) -> str: # pragma: no cover + return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__()) + + @classmethod + def from_native(cls, native: Any, /) -> Self: + raise NotImplementedError + + @classmethod + def _from_compliant(cls, compliant: CompliantBaseFrame[Any, NativeFrameT], /) -> Self: + obj = cls.__new__(cls) + obj._compliant = compliant + return obj + + def to_native(self) -> NativeFrameT: + return self._compliant.native + + def _project( + self, + exprs: tuple[OneOrIterable[IntoExpr], ...], + named_exprs: dict[str, Any], + context: ExprContext, + /, + ) -> tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: + """Temp, while these parts aren't connected, this is easier for testing.""" + irs, schema_frozen, output_names = expr_expansion.prepare_projection( + parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema + ) + named_irs = expr_expansion.into_named_irs(irs, output_names) + return schema_frozen.project(named_irs, context) + + def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: + named_irs, schema_projected = self._project( + exprs, named_exprs, ExprContext.SELECT + ) + return self._from_compliant(self._compliant.select(named_irs)) + + def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: + named_irs, schema_projected = self._project( + exprs, named_exprs, ExprContext.WITH_COLUMNS + ) + return self._from_compliant(self._compliant.with_columns(named_irs)) + + def sort( + self, + by: OneOrIterable[str], + *more_by: str, + descending: OneOrIterable[bool] = False, + nulls_last: OneOrIterable[bool] = False, + ) -> Self: + sort, opts = _parse_sort_by( + by, *more_by, descending=descending, nulls_last=nulls_last + ) + irs, schema_frozen, output_names = expr_expansion.prepare_projection( + sort, self.schema + ) + named_irs = expr_expansion.into_named_irs(irs, output_names) + return self._from_compliant(self._compliant.sort(named_irs, opts)) + + +class DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]): + _compliant: CompliantDataFrame[Any, NativeDataFrameT, NativeSeriesT] + + @property + def _series(self) -> type[Series[NativeSeriesT]]: + return Series[NativeSeriesT] + + # NOTE: Gave up on trying to get typing working for now + @classmethod + def from_native( # type: ignore[override] + cls, native: NativeFrame, / + ) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: + if is_pyarrow_table(native): + from narwhals._plan.arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame.from_native(native, cls._version).to_narwhals() + + raise NotImplementedError(type(native)) + + @overload + def to_dict( + self, *, as_series: Literal[True] = ... + ) -> dict[str, Series[NativeSeriesT]]: ... + @overload + def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + @overload + def to_dict( + self, *, as_series: bool + ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: ... + def to_dict( + self, *, as_series: bool = True + ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: + if as_series: + return { + key: self._series._from_compliant(value) + for key, value in self._compliant.to_dict(as_series=as_series).items() + } + return self._compliant.to_dict(as_series=as_series) + + def __len__(self) -> int: + return len(self._compliant) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index acff9ceea7..939b26c0bc 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -4,12 +4,11 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload +from typing import TYPE_CHECKING, Any, ClassVar, overload -from narwhals._plan import expr_expansion, expr_parsing as parse +from narwhals._plan import expr_parsing as parse from narwhals._plan._guards import is_column, is_expr, is_series from narwhals._plan.common import into_dtype -from narwhals._plan.contexts import ExprContext from narwhals._plan.expressions import ( aggregation as agg, boolean, @@ -26,18 +25,13 @@ SortOptions, rolling_options, ) -from narwhals._plan.series import Series -from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT -from narwhals._utils import Version, generate_repr -from narwhals.dependencies import is_pyarrow_table +from narwhals._utils import Version from narwhals.exceptions import ComputeError, InvalidOperationError -from narwhals.schema import Schema if TYPE_CHECKING: - import pyarrow as pa from typing_extensions import Never, Self - from narwhals._plan.common import ExprIR, Function, NamedIR + from narwhals._plan.common import ExprIR, Function from narwhals._plan.expressions.categorical import ExprCatNamespace from narwhals._plan.expressions.lists import ExprListNamespace from narwhals._plan.expressions.name import ExprNameNamespace @@ -45,14 +39,11 @@ from narwhals._plan.expressions.struct import ExprStructNamespace from narwhals._plan.expressions.temporal import ExprDateTimeNamespace from narwhals._plan.meta import IRMetaNamespace - from narwhals._plan.protocols import CompliantBaseFrame, CompliantDataFrame - from narwhals._plan.schema import FrozenSchema from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq, Udf from narwhals.typing import ( ClosedInterval, FillNullStrategy, IntoDType, - NativeFrame, NumericLiteral, RankMethod, RollingInterpolationMethod, @@ -711,121 +702,3 @@ class ExprV1(Expr): class SelectorV1(Selector): _version: ClassVar[Version] = Version.V1 - - -class BaseFrame(Generic[NativeFrameT]): - _compliant: CompliantBaseFrame[Any, NativeFrameT] - _version: ClassVar[Version] = Version.MAIN - - @property - def version(self) -> Version: - return self._version - - @property - def schema(self) -> Schema: - return Schema(self._compliant.schema.items()) - - @property - def columns(self) -> list[str]: - return self._compliant.columns - - def __repr__(self) -> str: # pragma: no cover - return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__()) - - @classmethod - def from_native(cls, native: Any, /) -> Self: - raise NotImplementedError - - @classmethod - def _from_compliant(cls, compliant: CompliantBaseFrame[Any, NativeFrameT], /) -> Self: - obj = cls.__new__(cls) - obj._compliant = compliant - return obj - - def to_native(self) -> NativeFrameT: - return self._compliant.native - - def _project( - self, - exprs: tuple[OneOrIterable[IntoExpr], ...], - named_exprs: dict[str, Any], - context: ExprContext, - /, - ) -> tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: - """Temp, while these parts aren't connected, this is easier for testing.""" - irs, schema_frozen, output_names = expr_expansion.prepare_projection( - parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema - ) - named_irs = expr_expansion.into_named_irs(irs, output_names) - return schema_frozen.project(named_irs, context) - - def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: - named_irs, schema_projected = self._project( - exprs, named_exprs, ExprContext.SELECT - ) - return self._from_compliant(self._compliant.select(named_irs)) - - def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: - named_irs, schema_projected = self._project( - exprs, named_exprs, ExprContext.WITH_COLUMNS - ) - return self._from_compliant(self._compliant.with_columns(named_irs)) - - def sort( - self, - by: OneOrIterable[str], - *more_by: str, - descending: OneOrIterable[bool] = False, - nulls_last: OneOrIterable[bool] = False, - ) -> Self: - sort, opts = _parse_sort_by( - by, *more_by, descending=descending, nulls_last=nulls_last - ) - irs, schema_frozen, output_names = expr_expansion.prepare_projection( - sort, self.schema - ) - named_irs = expr_expansion.into_named_irs(irs, output_names) - return self._from_compliant(self._compliant.sort(named_irs, opts)) - - -class DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]): - _compliant: CompliantDataFrame[Any, NativeDataFrameT, NativeSeriesT] - - @property - def _series(self) -> type[Series[NativeSeriesT]]: - return Series[NativeSeriesT] - - # NOTE: Gave up on trying to get typing working for now - @classmethod - def from_native( # type: ignore[override] - cls, native: NativeFrame, / - ) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: - if is_pyarrow_table(native): - from narwhals._plan.arrow.dataframe import ArrowDataFrame - - return ArrowDataFrame.from_native(native, cls._version).to_narwhals() - - raise NotImplementedError(type(native)) - - @overload - def to_dict( - self, *, as_series: Literal[True] = ... - ) -> dict[str, Series[NativeSeriesT]]: ... - @overload - def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... - @overload - def to_dict( - self, *, as_series: bool - ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: ... - def to_dict( - self, *, as_series: bool = True - ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: - if as_series: - return { - key: self._series._from_compliant(value) - for key, value in self._compliant.to_dict(as_series=as_series).items() - } - return self._compliant.to_dict(as_series=as_series) - - def __len__(self) -> int: - return len(self._compliant) diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index bdadb708da..9d94879759 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs - from narwhals._plan.dummy import BaseFrame, DataFrame + from narwhals._plan.dataframe import BaseFrame, DataFrame from narwhals._plan.expressions import ( aggregation as agg, boolean, diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 7b1caa426f..6f00772988 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -14,7 +14,7 @@ import narwhals as nw from narwhals._plan import functions as nwd from narwhals._plan._guards import is_expr -from narwhals._plan.dummy import DataFrame +from narwhals._plan.dataframe import DataFrame from narwhals._utils import Version from narwhals.exceptions import ComputeError from tests.utils import assert_equal_data From 3c37d691ee820db71a329a174ee799e486a4303d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:57:46 +0000 Subject: [PATCH 11/36] refactor: Rename `_plan.dummy` -> `_plan.expr` --- narwhals/_plan/_guards.py | 6 +++--- narwhals/_plan/common.py | 12 ++++++------ narwhals/_plan/dataframe.py | 2 +- narwhals/_plan/{dummy.py => expr.py} | 2 -- narwhals/_plan/expressions/categorical.py | 2 +- narwhals/_plan/expressions/lists.py | 2 +- narwhals/_plan/expressions/name.py | 2 +- narwhals/_plan/expressions/selectors.py | 20 ++++++++++---------- narwhals/_plan/expressions/strings.py | 2 +- narwhals/_plan/expressions/struct.py | 2 +- narwhals/_plan/expressions/temporal.py | 2 +- narwhals/_plan/functions.py | 2 +- narwhals/_plan/typing.py | 2 +- narwhals/_plan/when_then.py | 2 +- tests/plan/compliant_test.py | 2 +- tests/plan/expr_expansion_test.py | 2 +- tests/plan/expr_parsing_test.py | 2 +- tests/plan/expr_rewrites_test.py | 2 +- tests/plan/meta_test.py | 2 +- tests/plan/utils.py | 2 +- 20 files changed, 35 insertions(+), 37 deletions(-) rename narwhals/_plan/{dummy.py => expr.py} (99%) diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index 7d32a61da8..4437f93eac 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from typing_extensions import TypeIs - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr from narwhals._plan.expressions import expr from narwhals._plan.protocols import CompliantSeries from narwhals._plan.series import Series @@ -33,9 +33,9 @@ def _dummy(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 - from narwhals._plan import dummy + from narwhals._plan import expr - return dummy + return expr def _expr(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index ed022aa842..f0d1a7112e 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -33,7 +33,7 @@ from typing_extensions import Self, TypeAlias - from narwhals._plan.dummy import Expr, Selector + from narwhals._plan.expr import Expr, Selector from narwhals._plan.expressions.expr import Alias, Cast, Column, FunctionExpr from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.protocols import Ctx, FrameT_contra, R_co @@ -153,9 +153,9 @@ def dispatch( return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return] def to_narwhals(self, version: Version = Version.MAIN) -> Expr: - from narwhals._plan import dummy + from narwhals._plan import expr - tp = dummy.Expr if version is Version.MAIN else dummy.ExprV1 + tp = expr.Expr if version is Version.MAIN else expr.ExprV1 return tp._from_ir(self) @property @@ -281,11 +281,11 @@ def _repr_html_(self) -> str: class SelectorIR(ExprIR, config=ExprIROptions.no_dispatch()): def to_narwhals(self, version: Version = Version.MAIN) -> Selector: - from narwhals._plan import dummy + from narwhals._plan import expr if version is Version.MAIN: - return dummy.Selector._from_ir(self) - return dummy.SelectorV1._from_ir(self) + return expr.Selector._from_ir(self) + return expr.SelectorV1._from_ir(self) def matches_column(self, name: str, dtype: DType) -> bool: """Return True if we can select this column. diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index bd16b4ca24..cbf8326ec0 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -4,7 +4,7 @@ from narwhals._plan import expr_expansion, expr_parsing as parse from narwhals._plan.contexts import ExprContext -from narwhals._plan.dummy import _parse_sort_by +from narwhals._plan.expr import _parse_sort_by from narwhals._plan.series import Series from narwhals._plan.typing import ( IntoExpr, diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/expr.py similarity index 99% rename from narwhals/_plan/dummy.py rename to narwhals/_plan/expr.py index 939b26c0bc..37f2846b69 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/expr.py @@ -1,5 +1,3 @@ -"""Mock version of current narwhals API.""" - from __future__ import annotations import math diff --git a/narwhals/_plan/expressions/categorical.py b/narwhals/_plan/expressions/categorical.py index 13791bed16..d89e3da75d 100644 --- a/narwhals/_plan/expressions/categorical.py +++ b/narwhals/_plan/expressions/categorical.py @@ -5,7 +5,7 @@ from narwhals._plan.common import ExprNamespace, Function, IRNamespace if TYPE_CHECKING: - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr # fmt: off diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index f4a45f217f..168f50dadf 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -6,7 +6,7 @@ from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr # fmt: off diff --git a/narwhals/_plan/expressions/name.py b/narwhals/_plan/expressions/name.py index 4147f20450..24bc648cde 100644 --- a/narwhals/_plan/expressions/name.py +++ b/narwhals/_plan/expressions/name.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from narwhals._compliant.typing import AliasName - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr class KeepName(common.ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): diff --git a/narwhals/_plan/expressions/selectors.py b/narwhals/_plan/expressions/selectors.py index 8dc44d70f6..5d6bfa8292 100644 --- a/narwhals/_plan/expressions/selectors.py +++ b/narwhals/_plan/expressions/selectors.py @@ -18,7 +18,7 @@ from datetime import timezone from typing import TypeVar - from narwhals._plan import dummy + from narwhals._plan import expr from narwhals._plan.expressions.expr import RootSelector from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType @@ -153,30 +153,30 @@ def matches_column(self, name: str, dtype: DType) -> bool: return isinstance(dtype, dtypes.String) -def all() -> dummy.Selector: +def all() -> expr.Selector: return All().to_selector().to_narwhals() -def by_dtype(*dtypes: OneOrIterable[DType | type[DType]]) -> dummy.Selector: +def by_dtype(*dtypes: OneOrIterable[DType | type[DType]]) -> expr.Selector: return ByDType.from_dtypes(*dtypes).to_selector().to_narwhals() -def by_name(*names: OneOrIterable[str]) -> dummy.Selector: +def by_name(*names: OneOrIterable[str]) -> expr.Selector: return Matches.from_names(*names).to_selector().to_narwhals() -def boolean() -> dummy.Selector: +def boolean() -> expr.Selector: return Boolean().to_selector().to_narwhals() -def categorical() -> dummy.Selector: +def categorical() -> expr.Selector: return Categorical().to_selector().to_narwhals() def datetime( time_unit: OneOrIterable[TimeUnit] | None = None, time_zone: OneOrIterable[str | timezone | None] = ("*", None), -) -> dummy.Selector: +) -> expr.Selector: return ( Datetime.from_time_unit_and_time_zone(time_unit, time_zone) .to_selector() @@ -184,13 +184,13 @@ def datetime( ) -def matches(pattern: str) -> dummy.Selector: +def matches(pattern: str) -> expr.Selector: return Matches.from_string(pattern).to_selector().to_narwhals() -def numeric() -> dummy.Selector: +def numeric() -> expr.Selector: return Numeric().to_selector().to_narwhals() -def string() -> dummy.Selector: +def string() -> expr.Selector: return String().to_selector().to_narwhals() diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index 4c1f4af303..94812a8c8f 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -6,7 +6,7 @@ from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr # fmt: off diff --git a/narwhals/_plan/expressions/struct.py b/narwhals/_plan/expressions/struct.py index 2a3eca0b27..b978ed4295 100644 --- a/narwhals/_plan/expressions/struct.py +++ b/narwhals/_plan/expressions/struct.py @@ -6,7 +6,7 @@ from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr class StructFunction(Function, accessor="struct"): ... diff --git a/narwhals/_plan/expressions/temporal.py b/narwhals/_plan/expressions/temporal.py index bd21388728..e17bdbfb22 100644 --- a/narwhals/_plan/expressions/temporal.py +++ b/narwhals/_plan/expressions/temporal.py @@ -10,7 +10,7 @@ from typing_extensions import TypeAlias, TypeIs from narwhals._duration import IntervalUnit - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr from narwhals.typing import TimeUnit PolarsTimeUnit: TypeAlias = Literal["ns", "us", "ms"] diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index c48e541567..9ef91db398 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -14,7 +14,7 @@ from narwhals._utils import Version, flatten if t.TYPE_CHECKING: - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr from narwhals._plan.series import Series from narwhals._plan.typing import IntoExpr, IntoExprColumn, NativeSeriesT from narwhals.dtypes import IntegerType diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index b1007b839d..3f1a9ea313 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -9,7 +9,7 @@ from narwhals import dtypes from narwhals._plan.common import ExprIR, Function, IRNamespace, NamedIR, SelectorIR - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr from narwhals._plan.expressions import operators as ops from narwhals._plan.expressions.functions import RollingWindow from narwhals._plan.expressions.ranges import RangeFunction diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index b32f9c4f77..0853d4cd20 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -4,7 +4,7 @@ from narwhals._plan._guards import is_expr from narwhals._plan._immutable import Immutable -from narwhals._plan.dummy import Expr +from narwhals._plan.expr import Expr from narwhals._plan.expr_parsing import ( parse_into_expr_ir, parse_predicates_constraints_into_expr_ir, diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 6f00772988..ef179346c3 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from collections.abc import Sequence - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr from narwhals.typing import PythonLiteral diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 040aa03197..fec5a59b16 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -23,7 +23,7 @@ from collections.abc import Iterable, Sequence from narwhals._plan.common import ExprIR - from narwhals._plan.dummy import Expr, Selector + from narwhals._plan.expr import Expr, Selector from narwhals._plan.typing import IntoExpr, MapIR from narwhals.dtypes import DType diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index beb72acc35..993024a3ab 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -12,7 +12,7 @@ import narwhals as nw import narwhals._plan.functions as nwd from narwhals._plan.common import ExprIR, Function -from narwhals._plan.dummy import Expr +from narwhals._plan.expr import Expr from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir from narwhals._plan.expressions import boolean, expr, functions as F, operators as ops from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr, RangeExpr diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index fcca17c0b1..ab1598a9bc 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -20,7 +20,7 @@ from tests.plan.utils import assert_expr_ir_equal if TYPE_CHECKING: - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr from narwhals._plan.typing import IntoExpr from narwhals.dtypes import DType diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index c09fcb3604..c4b256a696 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -8,7 +8,7 @@ from tests.utils import POLARS_VERSION if TYPE_CHECKING: - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr pytest.importorskip("polars") import polars as pl diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 4eaf98db9f..25d0acd3f4 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from typing_extensions import LiteralString - from narwhals._plan.dummy import Expr + from narwhals._plan.expr import Expr def _unwrap_ir(obj: Expr | ExprIR | NamedIR) -> ExprIR: From 1c12bd1dd9e864f8cd803abe22932b4487aad486 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 16:13:04 +0000 Subject: [PATCH 12/36] refactor: Rename `expr_parsing` -> `_parse` --- narwhals/_plan/{expr_parsing.py => _parse.py} | 0 narwhals/_plan/dataframe.py | 4 +-- narwhals/_plan/expr.py | 32 +++++++++---------- narwhals/_plan/expr_rewrites.py | 6 ++-- narwhals/_plan/functions.py | 20 ++++++------ narwhals/_plan/when_then.py | 4 +-- tests/plan/expr_expansion_test.py | 2 +- tests/plan/expr_parsing_test.py | 2 +- tests/plan/expr_rewrites_test.py | 6 ++-- 9 files changed, 37 insertions(+), 39 deletions(-) rename narwhals/_plan/{expr_parsing.py => _parse.py} (100%) diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/_parse.py similarity index 100% rename from narwhals/_plan/expr_parsing.py rename to narwhals/_plan/_parse.py diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index cbf8326ec0..642ea083c7 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload -from narwhals._plan import expr_expansion, expr_parsing as parse +from narwhals._plan import _parse, expr_expansion from narwhals._plan.contexts import ExprContext from narwhals._plan.expr import _parse_sort_by from narwhals._plan.series import Series @@ -69,7 +69,7 @@ def _project( ) -> tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: """Temp, while these parts aren't connected, this is easier for testing.""" irs, schema_frozen, output_names = expr_expansion.prepare_projection( - parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema + _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema ) named_irs = expr_expansion.into_named_irs(irs, output_names) return schema_frozen.project(named_irs, context) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 37f2846b69..8e406bd22e 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -4,8 +4,12 @@ from collections.abc import Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any, ClassVar, overload -from narwhals._plan import expr_parsing as parse from narwhals._plan._guards import is_column, is_expr, is_series +from narwhals._plan._parse import ( + parse_into_expr_ir, + parse_into_seq_of_expr_ir, + parse_predicates_constraints_into_expr_ir, +) from narwhals._plan.common import into_dtype from narwhals._plan.expressions import ( aggregation as agg, @@ -56,7 +60,7 @@ def _parse_sort_by( descending: OneOrIterable[bool] = False, nulls_last: OneOrIterable[bool] = False, ) -> tuple[Seq[ExprIR], SortMultipleOptions]: - sort_by = parse.parse_into_seq_of_expr_ir(by, *more_by) + sort_by = parse_into_seq_of_expr_ir(by, *more_by) if length_changing := next((e for e in sort_by if e.is_scalar), None): msg = f"All expressions sort keys must preserve length, but got:\n{length_changing!r}" raise InvalidOperationError(msg) @@ -158,9 +162,9 @@ def over( msg = "At least one of `partition_by` or `order_by` must be specified." raise TypeError(msg) if partition_by: - partition = parse.parse_into_seq_of_expr_ir(*partition_by) + partition = parse_into_seq_of_expr_ir(*partition_by) if order_by is not None: - by = parse.parse_into_seq_of_expr_ir(order_by) + by = parse_into_seq_of_expr_ir(order_by) options = SortOptions(descending=descending, nulls_last=nulls_last) node = Over().to_ordered_window_expr(self._ir, partition, by, options) else: @@ -186,7 +190,7 @@ def sort_by( def filter( self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any ) -> Self: - by = parse.parse_predicates_constraints_into_expr_ir(*predicates, **constraints) + by = parse_predicates_constraints_into_expr_ir(*predicates, **constraints) return self._from_ir(expr.Filter(expr=self._ir, by=by)) def _with_unary(self, function: Function, /) -> Self: @@ -238,7 +242,7 @@ def fill_null( limit: int | None = None, ) -> Self: if strategy is None: - ir = parse.parse_into_expr_ir(value, str_as_lit=True) + ir = parse_into_expr_ir(value, str_as_lit=True) return self._from_ir(F.FillNull().to_function_expr(self._ir, ir)) return self._with_unary(F.FillNullWithStrategy(strategy=strategy, limit=limit)) @@ -263,11 +267,8 @@ def clip( lower_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None, upper_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None, ) -> Self: - return self._from_ir( - F.Clip().to_function_expr( - self._ir, *parse.parse_into_seq_of_expr_ir(lower_bound, upper_bound) - ) - ) + it = parse_into_seq_of_expr_ir(lower_bound, upper_bound) + return self._from_ir(F.Clip().to_function_expr(self._ir, *it)) def cum_count(self, *, reverse: bool = False) -> Self: return self._with_unary(F.CumCount(reverse=reverse)) @@ -430,7 +431,7 @@ def is_between( upper_bound: IntoExpr, closed: ClosedInterval = "both", ) -> Self: - it = parse.parse_into_seq_of_expr_ir(lower_bound, upper_bound) + it = parse_into_seq_of_expr_ir(lower_bound, upper_bound) return self._from_ir( boolean.IsBetween(closed=closed).to_function_expr(self._ir, *it) ) @@ -453,7 +454,7 @@ def _with_binary( str_as_lit: bool = False, reflect: bool = False, ) -> Self: - other_ir = parse.parse_into_expr_ir(other, str_as_lit=str_as_lit) + other_ir = parse_into_expr_ir(other, str_as_lit=str_as_lit) args = (self._ir, other_ir) if not reflect else (other_ir, self._ir) return self._from_ir(op().to_binary_expr(*args)) @@ -530,12 +531,11 @@ def __rxor__(self, other: IntoExprColumn | int | bool) -> Self: return self._with_binary(ops.ExclusiveOr, other, reflect=True) def __pow__(self, exponent: IntoExprColumn | float) -> Self: - exp = parse.parse_into_expr_ir(exponent) + exp = parse_into_expr_ir(exponent) return self._from_ir(F.Pow().to_function_expr(self._ir, exp)) def __rpow__(self, base: IntoExprColumn | float) -> Self: - base_ = parse.parse_into_expr_ir(base) - return self._from_ir(F.Pow().to_function_expr(base_, self._ir)) + return self._from_ir(F.Pow().to_function_expr(parse_into_expr_ir(base), self._ir)) def __invert__(self) -> Self: return self._with_unary(boolean.Not()) diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/expr_rewrites.py index 597e8afc21..655632738e 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/expr_rewrites.py @@ -4,13 +4,13 @@ from typing import TYPE_CHECKING -from narwhals._plan import expr_parsing as parse from narwhals._plan._guards import ( is_aggregation, is_binary_expr, is_function_expr, is_window_expr, ) +from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.common import NamedIR, map_ir, replace from narwhals._plan.expr_expansion import into_named_irs, prepare_projection @@ -31,9 +31,7 @@ def rewrite_all( - Currently we do a full traversal of each tree per-rewrite function - There's no caching *after* `prepare_projection` yet """ - out_irs, _, names = prepare_projection( - parse.parse_into_seq_of_expr_ir(*exprs), schema - ) + out_irs, _, names = prepare_projection(parse_into_seq_of_expr_ir(*exprs), schema) named_irs = into_named_irs(out_irs, names) return tuple(map_ir(ir, *rewrites) for ir in named_irs) diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 9ef91db398..5a38289a4f 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -3,7 +3,7 @@ import builtins import typing as t -from narwhals._plan import _guards, expr_parsing as parse +from narwhals._plan import _guards, _parse from narwhals._plan.common import into_dtype, py_to_narwhals_dtype from narwhals._plan.expressions import boolean, expr, functions as F from narwhals._plan.expressions.expr import All, Len @@ -81,32 +81,32 @@ def sum(*columns: str) -> Expr: def all_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: - it = parse.parse_into_seq_of_expr_ir(*exprs) + it = _parse.parse_into_seq_of_expr_ir(*exprs) return boolean.AllHorizontal().to_function_expr(*it).to_narwhals() def any_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: - it = parse.parse_into_seq_of_expr_ir(*exprs) + it = _parse.parse_into_seq_of_expr_ir(*exprs) return boolean.AnyHorizontal().to_function_expr(*it).to_narwhals() def sum_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: - it = parse.parse_into_seq_of_expr_ir(*exprs) + it = _parse.parse_into_seq_of_expr_ir(*exprs) return F.SumHorizontal().to_function_expr(*it).to_narwhals() def min_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: - it = parse.parse_into_seq_of_expr_ir(*exprs) + it = _parse.parse_into_seq_of_expr_ir(*exprs) return F.MinHorizontal().to_function_expr(*it).to_narwhals() def max_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: - it = parse.parse_into_seq_of_expr_ir(*exprs) + it = _parse.parse_into_seq_of_expr_ir(*exprs) return F.MaxHorizontal().to_function_expr(*it).to_narwhals() def mean_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: - it = parse.parse_into_seq_of_expr_ir(*exprs) + it = _parse.parse_into_seq_of_expr_ir(*exprs) return F.MeanHorizontal().to_function_expr(*it).to_narwhals() @@ -116,7 +116,7 @@ def concat_str( separator: str = "", ignore_nulls: bool = False, ) -> Expr: - it = parse.parse_into_seq_of_expr_ir(exprs, *more_exprs) + it = _parse.parse_into_seq_of_expr_ir(exprs, *more_exprs) return ( ConcatStr(separator=separator, ignore_nulls=ignore_nulls) .to_function_expr(*it) @@ -136,7 +136,7 @@ def when( nw._plan.Expr(main): .when([(col('y')) == (lit(str: b))]).then(lit(int: 1)).otherwise(lit(null)) """ - condition = parse.parse_predicates_constraints_into_expr_ir( + condition = _parse.parse_predicates_constraints_into_expr_ir( *predicates, **constraints ) return When._from_ir(condition) @@ -158,6 +158,6 @@ def int_range( raise NotImplementedError(msg) return ( IntRange(step=step, dtype=into_dtype(dtype)) - .to_function_expr(*parse.parse_into_seq_of_expr_ir(start, end)) + .to_function_expr(*_parse.parse_into_seq_of_expr_ir(start, end)) .to_narwhals() ) diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index 0853d4cd20..4bf4447b16 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -4,11 +4,11 @@ from narwhals._plan._guards import is_expr from narwhals._plan._immutable import Immutable -from narwhals._plan.expr import Expr -from narwhals._plan.expr_parsing import ( +from narwhals._plan._parse import ( parse_into_expr_ir, parse_predicates_constraints_into_expr_ir, ) +from narwhals._plan.expr import Expr if TYPE_CHECKING: from narwhals._plan.common import ExprIR diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index fec5a59b16..dbf89b0b97 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -7,12 +7,12 @@ import narwhals as nw from narwhals._plan import functions as nwd +from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.expr_expansion import ( prepare_projection, replace_selector, rewrite_special_aliases, ) -from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir from narwhals._plan.expressions import selectors as ndcs from narwhals._plan.expressions.expr import Alias, Columns from narwhals._plan.schema import freeze_schema diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 993024a3ab..fd63c4a355 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -11,9 +11,9 @@ import narwhals as nw import narwhals._plan.functions as nwd +from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.common import ExprIR, Function from narwhals._plan.expr import Expr -from narwhals._plan.expr_parsing import parse_into_seq_of_expr_ir from narwhals._plan.expressions import boolean, expr, functions as F, operators as ops from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr, RangeExpr from narwhals._plan.expressions.literal import SeriesLiteral diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index ab1598a9bc..bf46387d84 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -5,7 +5,7 @@ import pytest import narwhals as nw -from narwhals._plan import expr_parsing as parse, functions as nwd +from narwhals._plan import _parse, functions as nwd from narwhals._plan._guards import is_expr from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.expr_rewrites import ( @@ -44,8 +44,8 @@ def schema_2() -> dict[str, DType]: def _to_window_expr(into_expr: IntoExpr, *partition_by: IntoExpr) -> WindowExpr: return WindowExpr( - expr=parse.parse_into_expr_ir(into_expr), - partition_by=parse.parse_into_seq_of_expr_ir(*partition_by), + expr=_parse.parse_into_expr_ir(into_expr), + partition_by=_parse.parse_into_seq_of_expr_ir(*partition_by), options=Over(), ) From 57dec59d8bc5ece7894a1861da62e77ce8537dee Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 11 Sep 2025 16:23:45 +0000 Subject: [PATCH 13/36] ci: Update `pre-commit` exclude https://results.pre-commit.ci/run/github/760058710/1757607493.hhgBLFBsSIeLGiRbomtdHg --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9f9f105658..2ea270264b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -77,8 +77,8 @@ repos: narwhals/stable/v./_?dtypes.py| narwhals/.*__init__.py| narwhals/.*typing\.py| - narwhals/_plan/demo\.py| - narwhals/_plan/ranges\.py| + narwhals/_plan/functions\.py| + narwhals/_plan/expressions/ranges\.py| narwhals/_plan/schema\.py ) - id: pull-request-target From abaa2cc10677ac6331426f8892d01f4c4309c48c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 11:02:30 +0000 Subject: [PATCH 14/36] chore(ruff): partial bump fixes --- narwhals/_plan/dataframe.py | 12 +++--------- pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 642ea083c7..a50132a61a 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -75,15 +75,11 @@ def _project( return schema_frozen.project(named_irs, context) def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: - named_irs, schema_projected = self._project( - exprs, named_exprs, ExprContext.SELECT - ) + named_irs, _ = self._project(exprs, named_exprs, ExprContext.SELECT) return self._from_compliant(self._compliant.select(named_irs)) def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: - named_irs, schema_projected = self._project( - exprs, named_exprs, ExprContext.WITH_COLUMNS - ) + named_irs, _ = self._project(exprs, named_exprs, ExprContext.WITH_COLUMNS) return self._from_compliant(self._compliant.with_columns(named_irs)) def sort( @@ -96,9 +92,7 @@ def sort( sort, opts = _parse_sort_by( by, *more_by, descending=descending, nulls_last=nulls_last ) - irs, schema_frozen, output_names = expr_expansion.prepare_projection( - sort, self.schema - ) + irs, _, output_names = expr_expansion.prepare_projection(sort, self.schema) named_irs = expr_expansion.into_named_irs(irs, output_names) return self._from_compliant(self._compliant.sort(named_irs, opts)) diff --git a/pyproject.toml b/pyproject.toml index 3b5db19672..fadf125896 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,7 +174,6 @@ ignore = [ "FBT003", # boolean-positional-value-in-call (We enforce at definition site when it is a flag, not a value (e.g. `lit(False)`)) "FIX", # flake8-fixme "PD010", # pandas-use-of-dot-pivot-or-unstack - "PD901", # pandas-df-variable-name (This is a auxiliary library so dataframe variables have no concrete business meaning) "PLC0415", # `import` should be at the top-level of a file "PLR0913", # too-many-arguments "PLR2004", # magic-value-comparison @@ -209,6 +208,7 @@ extend-ignore-names = [ "C901", # complex-structure "PLR0912", # too-many-branches "PLR0916", # too-many-boolean-expressions + "RUF043", # temp ignore until sync ] "tpch/tests/*" = ["S101"] "utils/*" = ["S311"] From 978cd82b5e39579289fb34a7bd01492901f94a9d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 11:26:53 +0000 Subject: [PATCH 15/36] feat: export to `expressions` --- narwhals/_plan/expressions/__init__.py | 70 ++++++++++++++++++++++++++ narwhals/_plan/expressions/expr.py | 6 --- 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index e69de29bb2..e3f89c4c83 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from narwhals._plan.common import ExprIR, SelectorIR # prob should move into package? +from narwhals._plan.expressions.aggregation import AggExpr, OrderableAggExpr +from narwhals._plan.expressions.expr import ( + Alias, + All, + AnonymousExpr, + BinaryExpr, + BinarySelector, + Cast, + Column, + Columns, + Exclude, + Filter, + FunctionExpr, + IndexColumns, + Len, + Literal, + Nth, + OrderedWindowExpr, + RollingExpr, + RootSelector, + Sort, + SortBy, + TernaryExpr, + WindowExpr, + _ColumnSelection, # if needs exposing, make it public! + col, + cols, + index_columns, + nth, +) +from narwhals._plan.expressions.name import KeepName, RenameAlias + +__all__ = [ + "AggExpr", + "Alias", + "All", + "AnonymousExpr", + "BinaryExpr", + "BinarySelector", + "Cast", + "Column", + "Columns", + "Exclude", + "ExprIR", + "Filter", + "FunctionExpr", + "IndexColumns", + "KeepName", + "Len", + "Literal", + "Nth", + "OrderableAggExpr", + "OrderedWindowExpr", + "RenameAlias", + "RollingExpr", + "RootSelector", + "SelectorIR", + "Sort", + "SortBy", + "TernaryExpr", + "WindowExpr", + "_ColumnSelection", + "col", + "cols", + "index_columns", + "nth", +] diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index a6db604eba..083173549e 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -8,8 +8,6 @@ from narwhals._plan.common import ExprIR, SelectorIR, collect from narwhals._plan.exceptions import function_expr_invalid_operation_error -from narwhals._plan.expressions.aggregation import AggExpr, OrderableAggExpr -from narwhals._plan.expressions.name import KeepName, RenameAlias from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ( FunctionT_co, @@ -40,7 +38,6 @@ from narwhals.dtypes import DType __all__ = [ - "AggExpr", "Alias", "All", "AnonymousExpr", @@ -53,12 +50,9 @@ "Filter", "FunctionExpr", "IndexColumns", - "KeepName", "Len", "Literal", "Nth", - "OrderableAggExpr", - "RenameAlias", "RollingExpr", "RootSelector", "SelectorIR", From a8673f6699c98bc31d519b34aec9bcf263ddb5ea Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 11:27:46 +0000 Subject: [PATCH 16/36] refactor: Update imports --- narwhals/_plan/_guards.py | 36 ++++++++++++++++---------------- narwhals/_plan/dataframe.py | 3 ++- narwhals/_plan/expr_expansion.py | 6 ++++-- narwhals/_plan/meta.py | 17 +++++++-------- 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index 4437f93eac..0f62942ab8 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -11,8 +11,8 @@ if TYPE_CHECKING: from typing_extensions import TypeIs + from narwhals._plan import expressions as ir from narwhals._plan.expr import Expr - from narwhals._plan.expressions import expr from narwhals._plan.protocols import CompliantSeries from narwhals._plan.series import Series from narwhals._plan.typing import NativeSeriesT, Seq @@ -32,14 +32,14 @@ ) -def _dummy(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 - from narwhals._plan import expr +def _ir(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 + from narwhals._plan import expressions as ir - return expr + return ir def _expr(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 - from narwhals._plan.expressions import expr + from narwhals._plan import expr return expr @@ -55,7 +55,7 @@ def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]: def is_expr(obj: Any) -> TypeIs[Expr]: - return isinstance(obj, _dummy().Expr) + return isinstance(obj, _expr().Expr) def is_column(obj: Any) -> TypeIs[Expr]: @@ -77,32 +77,32 @@ def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSerie return isinstance(obj, (str, bytes, _series().Series)) or is_compliant_series(obj) -def is_window_expr(obj: Any) -> TypeIs[expr.WindowExpr]: - return isinstance(obj, _expr().WindowExpr) +def is_window_expr(obj: Any) -> TypeIs[ir.WindowExpr]: + return isinstance(obj, _ir().WindowExpr) -def is_function_expr(obj: Any) -> TypeIs[expr.FunctionExpr[Any]]: - return isinstance(obj, _expr().FunctionExpr) +def is_function_expr(obj: Any) -> TypeIs[ir.FunctionExpr[Any]]: + return isinstance(obj, _ir().FunctionExpr) -def is_binary_expr(obj: Any) -> TypeIs[expr.BinaryExpr]: - return isinstance(obj, _expr().BinaryExpr) +def is_binary_expr(obj: Any) -> TypeIs[ir.BinaryExpr]: + return isinstance(obj, _ir().BinaryExpr) -def is_agg_expr(obj: Any) -> TypeIs[expr.AggExpr]: - return isinstance(obj, _expr().AggExpr) +def is_agg_expr(obj: Any) -> TypeIs[ir.AggExpr]: + return isinstance(obj, _ir().AggExpr) -def is_aggregation(obj: Any) -> TypeIs[expr.AggExpr | expr.FunctionExpr[Any]]: +def is_aggregation(obj: Any) -> TypeIs[ir.AggExpr | ir.FunctionExpr[Any]]: """Superset of `ExprIR.is_scalar`, excludes literals & len.""" return is_agg_expr(obj) or (is_function_expr(obj) and obj.is_scalar) -def is_literal(obj: Any) -> TypeIs[expr.Literal[Any]]: - return isinstance(obj, _expr().Literal) +def is_literal(obj: Any) -> TypeIs[ir.Literal[Any]]: + return isinstance(obj, _ir().Literal) -def is_horizontal_reduction(obj: Any) -> TypeIs[expr.FunctionExpr[Any]]: +def is_horizontal_reduction(obj: Any) -> TypeIs[ir.FunctionExpr[Any]]: return is_function_expr(obj) and obj.options.is_input_wildcard_expansion() diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index a50132a61a..ebfceb093e 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -21,7 +21,8 @@ import pyarrow as pa from typing_extensions import Self - from narwhals._plan.common import ExprIR, NamedIR + from narwhals._plan.common import NamedIR + from narwhals._plan.expressions import ExprIR from narwhals._plan.protocols import CompliantBaseFrame, CompliantDataFrame from narwhals._plan.schema import FrozenSchema from narwhals._plan.typing import Seq diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index 0eaafca10b..fb5f271308 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -45,21 +45,23 @@ from narwhals._plan import common, meta from narwhals._plan._guards import is_horizontal_reduction from narwhals._plan._immutable import Immutable -from narwhals._plan.common import ExprIR, NamedIR, SelectorIR +from narwhals._plan.common import NamedIR from narwhals._plan.exceptions import ( column_index_error, column_not_found_error, duplicate_error, ) -from narwhals._plan.expressions.expr import ( +from narwhals._plan.expressions import ( Alias, All, Columns, Exclude, + ExprIR, IndexColumns, KeepName, Nth, RenameAlias, + SelectorIR, _ColumnSelection, col, cols, diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 01d218fcd4..7dd82debb0 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -19,8 +19,7 @@ from typing_extensions import TypeIs - from narwhals._plan.common import ExprIR - from narwhals._plan.expressions.expr import Column + from narwhals._plan.expressions import Column, ExprIR class IRMetaNamespace(IRNamespace): @@ -86,7 +85,7 @@ def _expr_to_leaf_column_names_iter(ir: ExprIR) -> Iterator[str]: def _expr_to_leaf_column_exprs_iter(ir: ExprIR) -> Iterator[ExprIR]: - from narwhals._plan.expressions import expr + from narwhals._plan import expressions as expr for outer in ir.iter_root_names(): if isinstance(outer, (expr.Column, expr.All)): @@ -102,7 +101,7 @@ def _expr_to_leaf_column_name(ir: ExprIR) -> str | ComputeError: msg = "no root column name found" return ComputeError(msg) leaf = leaves[0] - from narwhals._plan.expressions import expr + from narwhals._plan import expressions as expr if isinstance(leaf, expr.Column): return leaf.name @@ -119,7 +118,7 @@ def root_names_unique(irs: Iterable[ExprIR], /) -> set[str]: @lru_cache(maxsize=32) def _expr_output_name(ir: ExprIR) -> str | ComputeError: - from narwhals._plan.expressions import expr + from narwhals._plan import expressions as expr for e in ir.iter_output_name(): if isinstance(e, (expr.Column, expr.Alias, expr.Literal, expr.Len)): @@ -144,7 +143,7 @@ def get_single_leaf_name(ir: ExprIR) -> str | ComputeError: [`polars_plan::utils::get_single_leaf`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/utils.rs#L151-L168 """ - from narwhals._plan.expressions import expr + from narwhals._plan import expressions as expr for e in ir.iter_right(): if isinstance(e, (expr.WindowExpr, expr.SortBy, expr.Filter)): @@ -159,7 +158,7 @@ def get_single_leaf_name(ir: ExprIR) -> str | ComputeError: def _has_multiple_outputs(ir: ExprIR) -> bool: - from narwhals._plan.expressions import expr + from narwhals._plan import expressions as expr return isinstance(ir, (expr.Columns, expr.IndexColumns, expr.SelectorIR, expr.All)) @@ -181,7 +180,7 @@ def is_column(ir: ExprIR) -> TypeIs[Column]: def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: - from narwhals._plan.expressions import expr + from narwhals._plan import expressions as expr from narwhals._plan.expressions.literal import is_literal_scalar return ( @@ -196,7 +195,7 @@ def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: def _is_column_selection(ir: ExprIR, *, allow_aliasing: bool) -> bool: - from narwhals._plan.expressions import expr + from narwhals._plan import expressions as expr return isinstance(ir, (expr.Column, expr._ColumnSelection, expr.SelectorIR)) or ( allow_aliasing and isinstance(ir, (expr.Alias, expr.KeepName, expr.RenameAlias)) From da83726af1677e72b515f17d3e2d93665c694da5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 12:02:39 +0000 Subject: [PATCH 17/36] refactor: Use more `from narwhals._plan import expressions as ir` --- narwhals/_plan/_guards.py | 4 + narwhals/_plan/exceptions.py | 39 +++++----- narwhals/_plan/expr.py | 36 ++++----- narwhals/_plan/expressions/__init__.py | 2 + narwhals/_plan/functions.py | 13 ++-- narwhals/_plan/meta.py | 101 ++++++++++++------------- narwhals/_plan/when_then.py | 12 +-- 7 files changed, 104 insertions(+), 103 deletions(-) diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index 0f62942ab8..e8b27ed376 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -63,6 +63,10 @@ def is_column(obj: Any) -> TypeIs[Expr]: return is_expr(obj) and obj.meta.is_column() +def is_column_ir(obj: Any) -> TypeIs[ir.Column]: + return isinstance(obj, _ir().Column) + + def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]]: return isinstance(obj, _series().Series) diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index f61bb8e148..e378f8c80c 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -24,9 +24,8 @@ import pandas as pd import polars as pl - from narwhals._plan.common import ExprIR, Function - from narwhals._plan.expressions.aggregation import AggExpr - from narwhals._plan.expressions.expr import FunctionExpr, WindowExpr + from narwhals._plan import expressions as ir + from narwhals._plan.common import Function from narwhals._plan.expressions.operators import Operator from narwhals._plan.options import SortOptions from narwhals._plan.typing import IntoExpr, Seq @@ -37,13 +36,13 @@ # TODO @dangotbanned: Use arguments in error message -def agg_scalar_error(agg: AggExpr, scalar: ExprIR, /) -> InvalidOperationError: # noqa: ARG001 +def agg_scalar_error(agg: ir.AggExpr, scalar: ir.ExprIR, /) -> InvalidOperationError: # noqa: ARG001 msg = "Can't apply aggregations to scalar-like expressions." return InvalidOperationError(msg) def function_expr_invalid_operation_error( - function: Function, parent: ExprIR + function: Function, parent: ir.ExprIR ) -> InvalidOperationError: msg = f"Cannot use `{function!r}()` on aggregated expression `{parent!r}`." return InvalidOperationError(msg) @@ -57,7 +56,9 @@ def hist_bins_monotonic_error(bins: Seq[float]) -> ComputeError: # noqa: ARG001 # NOTE: Always underlining `right`, since the message refers to both types of exprs # Assuming the most recent as the issue -def binary_expr_shape_error(left: ExprIR, op: Operator, right: ExprIR) -> ShapeError: +def binary_expr_shape_error( + left: ir.ExprIR, op: Operator, right: ir.ExprIR +) -> ShapeError: lhs_op = f"{left!r} {op!r} " rhs = repr(right) indent = len(lhs_op) * " " @@ -71,7 +72,7 @@ def binary_expr_shape_error(left: ExprIR, op: Operator, right: ExprIR) -> ShapeE # TODO @dangotbanned: Share the right underline code w/ `binary_expr_shape_error` def binary_expr_multi_output_error( - left: ExprIR, op: Operator, right: ExprIR + left: ir.ExprIR, op: Operator, right: ir.ExprIR ) -> MultiOutputExpressionError: lhs_op = f"{left!r} {op!r} " rhs = repr(right) @@ -86,7 +87,7 @@ def binary_expr_multi_output_error( def binary_expr_length_changing_error( - left: ExprIR, op: Operator, right: ExprIR + left: ir.ExprIR, op: Operator, right: ir.ExprIR ) -> LengthChangingExprError: lhs, rhs = repr(left), repr(right) op_s = f" {op!r} " @@ -103,9 +104,9 @@ def binary_expr_length_changing_error( # TODO @dangotbanned: Use arguments in error message def over_nested_error( - expr: WindowExpr, # noqa: ARG001 - partition_by: Seq[ExprIR], # noqa: ARG001 - order_by: Seq[ExprIR] = (), # noqa: ARG001 + expr: ir.WindowExpr, # noqa: ARG001 + partition_by: Seq[ir.ExprIR], # noqa: ARG001 + order_by: Seq[ir.ExprIR] = (), # noqa: ARG001 sort_options: SortOptions | None = None, # noqa: ARG001 ) -> InvalidOperationError: msg = "Cannot nest `over` statements." @@ -114,9 +115,9 @@ def over_nested_error( # TODO @dangotbanned: Use arguments in error message def over_elementwise_error( - expr: FunctionExpr[Function], - partition_by: Seq[ExprIR], # noqa: ARG001 - order_by: Seq[ExprIR] = (), # noqa: ARG001 + expr: ir.FunctionExpr, + partition_by: Seq[ir.ExprIR], # noqa: ARG001 + order_by: Seq[ir.ExprIR] = (), # noqa: ARG001 sort_options: SortOptions | None = None, # noqa: ARG001 ) -> InvalidOperationError: msg = f"Cannot use `over` on expressions which are elementwise.\n{expr!r}" @@ -125,9 +126,9 @@ def over_elementwise_error( # TODO @dangotbanned: Use arguments in error message def over_row_separable_error( - expr: FunctionExpr[Function], - partition_by: Seq[ExprIR], # noqa: ARG001 - order_by: Seq[ExprIR] = (), # noqa: ARG001 + expr: ir.FunctionExpr, + partition_by: Seq[ir.ExprIR], # noqa: ARG001 + order_by: Seq[ir.ExprIR] = (), # noqa: ARG001 sort_options: SortOptions | None = None, # noqa: ARG001 ) -> InvalidOperationError: msg = f"Cannot use `over` on expressions which change length.\n{expr!r}" @@ -169,7 +170,7 @@ def is_iterable_polars_error( return TypeError(msg) -def duplicate_error(exprs: Seq[ExprIR]) -> DuplicateError: +def duplicate_error(exprs: Seq[ir.ExprIR]) -> DuplicateError: INDENT = "\n " # noqa: N806 names = [_output_name(expr) for expr in exprs] duplicates = {k for k, v in Counter(names).items() if v > 1} @@ -184,7 +185,7 @@ def duplicate_error(exprs: Seq[ExprIR]) -> DuplicateError: return DuplicateError(msg) -def _output_name(expr: ExprIR) -> str: +def _output_name(expr: ir.ExprIR) -> str: return expr.meta.output_name() diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 8e406bd22e..9a0210af2e 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -4,6 +4,7 @@ from collections.abc import Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any, ClassVar, overload +from narwhals._plan import expressions as ir from narwhals._plan._guards import is_column, is_expr, is_series from narwhals._plan._parse import ( parse_into_expr_ir, @@ -14,7 +15,6 @@ from narwhals._plan.expressions import ( aggregation as agg, boolean, - expr, functions as F, operators as ops, ) @@ -33,7 +33,7 @@ if TYPE_CHECKING: from typing_extensions import Never, Self - from narwhals._plan.common import ExprIR, Function + from narwhals._plan.common import Function from narwhals._plan.expressions.categorical import ExprCatNamespace from narwhals._plan.expressions.lists import ExprListNamespace from narwhals._plan.expressions.name import ExprNameNamespace @@ -59,7 +59,7 @@ def _parse_sort_by( *more_by: IntoExpr, descending: OneOrIterable[bool] = False, nulls_last: OneOrIterable[bool] = False, -) -> tuple[Seq[ExprIR], SortMultipleOptions]: +) -> tuple[Seq[ir.ExprIR], SortMultipleOptions]: sort_by = parse_into_seq_of_expr_ir(by, *more_by) if length_changing := next((e for e in sort_by if e.is_scalar), None): msg = f"All expressions sort keys must preserve length, but got:\n{length_changing!r}" @@ -71,7 +71,7 @@ def _parse_sort_by( # NOTE: Overly simplified placeholders for mocking typing # Entirely ignoring namespace + function binding class Expr: - _ir: ExprIR + _ir: ir.ExprIR _version: ClassVar[Version] = Version.MAIN def __repr__(self) -> str: @@ -85,9 +85,9 @@ def _repr_html_(self) -> str: return self._ir._repr_html_() @classmethod - def _from_ir(cls, ir: ExprIR, /) -> Self: + def _from_ir(cls, expr_ir: ir.ExprIR, /) -> Self: obj = cls.__new__(cls) - obj._ir = ir + obj._ir = expr_ir return obj @property @@ -101,7 +101,7 @@ def cast(self, dtype: IntoDType) -> Self: return self._from_ir(self._ir.cast(into_dtype(dtype))) def exclude(self, *names: OneOrIterable[str]) -> Self: - return self._from_ir(expr.Exclude.from_names(self._ir, *names)) + return self._from_ir(ir.Exclude.from_names(self._ir, *names)) def count(self) -> Self: return self._from_ir(agg.Count(expr=self._ir)) @@ -156,8 +156,8 @@ def over( descending: bool = False, nulls_last: bool = False, ) -> Self: - node: expr.WindowExpr | expr.OrderedWindowExpr - partition: Seq[ExprIR] = () + node: ir.WindowExpr | ir.OrderedWindowExpr + partition: Seq[ir.ExprIR] = () if not (partition_by) and order_by is None: msg = "At least one of `partition_by` or `order_by` must be specified." raise TypeError(msg) @@ -173,7 +173,7 @@ def over( def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: options = SortOptions(descending=descending, nulls_last=nulls_last) - return self._from_ir(expr.Sort(expr=self._ir, options=options)) + return self._from_ir(ir.Sort(expr=self._ir, options=options)) def sort_by( self, @@ -185,13 +185,13 @@ def sort_by( keys, opts = _parse_sort_by( by, *more_by, descending=descending, nulls_last=nulls_last ) - return self._from_ir(expr.SortBy(expr=self._ir, by=keys, options=opts)) + return self._from_ir(ir.SortBy(expr=self._ir, by=keys, options=opts)) def filter( self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any ) -> Self: by = parse_predicates_constraints_into_expr_ir(*predicates, **constraints) - return self._from_ir(expr.Filter(expr=self._ir, by=by)) + return self._from_ir(ir.Filter(expr=self._ir, by=by)) def _with_unary(self, function: Function, /) -> Self: return self._from_ir(function.to_function_expr(self._ir)) @@ -242,8 +242,8 @@ def fill_null( limit: int | None = None, ) -> Self: if strategy is None: - ir = parse_into_expr_ir(value, str_as_lit=True) - return self._from_ir(F.FillNull().to_function_expr(self._ir, ir)) + e = parse_into_expr_ir(value, str_as_lit=True) + return self._from_ir(F.FillNull().to_function_expr(self._ir, e)) return self._with_unary(F.FillNullWithStrategy(strategy=strategy, limit=limit)) def shift(self, n: int) -> Self: @@ -593,15 +593,15 @@ def str(self) -> ExprStringNamespace: class Selector(Expr): - _ir: expr.SelectorIR + _ir: ir.SelectorIR def __repr__(self) -> str: return f"nw._plan.Selector({self.version.name.lower()}):\n{self._ir!r}" @classmethod - def _from_ir(cls, ir: expr.SelectorIR, /) -> Self: # type: ignore[override] + def _from_ir(cls, selector_ir: ir.SelectorIR, /) -> Self: # type: ignore[override] obj = cls.__new__(cls) - obj._ir = ir + obj._ir = selector_ir return obj def _to_expr(self) -> Expr: @@ -650,7 +650,7 @@ def __xor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: return self._to_expr() ^ other def __invert__(self) -> Self: - return self._from_ir(expr.InvertSelector(selector=self._ir)) + return self._from_ir(ir.InvertSelector(selector=self._ir)) def __add__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, type(self)): diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index e3f89c4c83..671bb336e1 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -15,6 +15,7 @@ Filter, FunctionExpr, IndexColumns, + InvertSelector, Len, Literal, Nth, @@ -48,6 +49,7 @@ "Filter", "FunctionExpr", "IndexColumns", + "InvertSelector", "KeepName", "Len", "Literal", diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 5a38289a4f..b8c64c2b84 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -3,10 +3,9 @@ import builtins import typing as t -from narwhals._plan import _guards, _parse +from narwhals._plan import _guards, _parse, expressions as ir from narwhals._plan.common import into_dtype, py_to_narwhals_dtype -from narwhals._plan.expressions import boolean, expr, functions as F -from narwhals._plan.expressions.expr import All, Len +from narwhals._plan.expressions import boolean, functions as F from narwhals._plan.expressions.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.expressions.ranges import IntRange from narwhals._plan.expressions.strings import ConcatStr @@ -23,13 +22,13 @@ def col(*names: str | t.Iterable[str]) -> Expr: flat = tuple(flatten(names)) - node = expr.col(flat[0]) if builtins.len(flat) == 1 else expr.cols(*flat) + node = ir.col(flat[0]) if builtins.len(flat) == 1 else ir.cols(*flat) return node.to_narwhals() def nth(*indices: int | t.Sequence[int]) -> Expr: flat = tuple(flatten(indices)) - node = expr.nth(flat[0]) if builtins.len(flat) == 1 else expr.index_columns(*flat) + node = ir.nth(flat[0]) if builtins.len(flat) == 1 else ir.index_columns(*flat) return node.to_narwhals() @@ -49,11 +48,11 @@ def lit( def len() -> Expr: - return Len().to_narwhals() + return ir.Len().to_narwhals() def all() -> Expr: - return All().to_narwhals() + return ir.All().to_narwhals() def exclude(*names: str | t.Iterable[str]) -> Expr: diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 7dd82debb0..b0d6e7f6c9 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -10,6 +10,7 @@ from itertools import chain from typing import TYPE_CHECKING, Literal, overload +from narwhals._plan._guards import is_column_ir, is_literal from narwhals._plan.common import IRNamespace from narwhals.exceptions import ComputeError from narwhals.utils import Version @@ -17,9 +18,7 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator - from typing_extensions import TypeIs - - from narwhals._plan.expressions import Column, ExprIR + from narwhals._plan.expressions import ExprIR class IRMetaNamespace(IRNamespace): @@ -29,7 +28,7 @@ def has_multiple_outputs(self) -> bool: return any(_has_multiple_outputs(e) for e in self._ir.iter_left()) def is_column(self) -> bool: - return is_column(self._ir) + return is_column_ir(self._ir) def is_column_selection(self, *, allow_aliasing: bool = False) -> bool: return all( @@ -77,23 +76,23 @@ def root_names(self) -> list[str]: return list(_expr_to_leaf_column_names_iter(self._ir)) -def _expr_to_leaf_column_names_iter(ir: ExprIR) -> Iterator[str]: - for e in _expr_to_leaf_column_exprs_iter(ir): +def _expr_to_leaf_column_names_iter(expr: ExprIR, /) -> Iterator[str]: + for e in _expr_to_leaf_column_exprs_iter(expr): result = _expr_to_leaf_column_name(e) if isinstance(result, str): yield result -def _expr_to_leaf_column_exprs_iter(ir: ExprIR) -> Iterator[ExprIR]: - from narwhals._plan import expressions as expr +def _expr_to_leaf_column_exprs_iter(expr: ExprIR, /) -> Iterator[ExprIR]: + from narwhals._plan import expressions as ir - for outer in ir.iter_root_names(): - if isinstance(outer, (expr.Column, expr.All)): + for outer in expr.iter_root_names(): + if isinstance(outer, (ir.Column, ir.All)): yield outer -def _expr_to_leaf_column_name(ir: ExprIR) -> str | ComputeError: - leaves = list(_expr_to_leaf_column_exprs_iter(ir)) +def _expr_to_leaf_column_name(expr: ExprIR, /) -> str | ComputeError: + leaves = list(_expr_to_leaf_column_exprs_iter(expr)) if not len(leaves) <= 1: msg = "found more than one root column name" return ComputeError(msg) @@ -101,40 +100,42 @@ def _expr_to_leaf_column_name(ir: ExprIR) -> str | ComputeError: msg = "no root column name found" return ComputeError(msg) leaf = leaves[0] - from narwhals._plan import expressions as expr + from narwhals._plan import expressions as ir - if isinstance(leaf, expr.Column): + if isinstance(leaf, ir.Column): return leaf.name - if isinstance(leaf, expr.All): + if isinstance(leaf, ir.All): msg = "wildcard has no root column name" return ComputeError(msg) msg = f"Expected unreachable, got {type(leaf).__name__!r}\n\n{leaf}" return ComputeError(msg) -def root_names_unique(irs: Iterable[ExprIR], /) -> set[str]: - return set(chain.from_iterable(_expr_to_leaf_column_names_iter(e) for e in irs)) +def root_names_unique(exprs: Iterable[ExprIR], /) -> set[str]: + return set(chain.from_iterable(_expr_to_leaf_column_names_iter(e) for e in exprs)) @lru_cache(maxsize=32) -def _expr_output_name(ir: ExprIR) -> str | ComputeError: - from narwhals._plan import expressions as expr +def _expr_output_name(expr: ExprIR, /) -> str | ComputeError: + from narwhals._plan import expressions as ir - for e in ir.iter_output_name(): - if isinstance(e, (expr.Column, expr.Alias, expr.Literal, expr.Len)): + for e in expr.iter_output_name(): + if isinstance(e, (ir.Column, ir.Alias, ir.Literal, ir.Len)): return e.name - if isinstance(e, (expr.All, expr.KeepName, expr.RenameAlias)): + if isinstance(e, (ir.All, ir.KeepName, ir.RenameAlias)): msg = "cannot determine output column without a context for this expression" return ComputeError(msg) - if isinstance(e, (expr.Columns, expr.IndexColumns, expr.Nth)): + if isinstance(e, (ir.Columns, ir.IndexColumns, ir.Nth)): msg = "this expression may produce multiple output names" return ComputeError(msg) continue - msg = f"unable to find root column name for expr '{ir!r}' when calling 'output_name'" + msg = ( + f"unable to find root column name for expr '{expr!r}' when calling 'output_name'" + ) return ComputeError(msg) -def get_single_leaf_name(ir: ExprIR) -> str | ComputeError: +def get_single_leaf_name(expr: ExprIR, /) -> str | ComputeError: """Find the name at the start of an expression. Normal iteration would just return the first root column it found. @@ -143,60 +144,54 @@ def get_single_leaf_name(ir: ExprIR) -> str | ComputeError: [`polars_plan::utils::get_single_leaf`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/utils.rs#L151-L168 """ - from narwhals._plan import expressions as expr + from narwhals._plan import expressions as ir - for e in ir.iter_right(): - if isinstance(e, (expr.WindowExpr, expr.SortBy, expr.Filter)): + for e in expr.iter_right(): + if isinstance(e, (ir.WindowExpr, ir.SortBy, ir.Filter)): return get_single_leaf_name(e.expr) - if isinstance(e, expr.BinaryExpr): + if isinstance(e, ir.BinaryExpr): return get_single_leaf_name(e.left) # NOTE: `polars` doesn't include `Literal` here - if isinstance(e, (expr.Column, expr.Len)): + if isinstance(e, (ir.Column, ir.Len)): return e.name - msg = f"unable to find a single leaf column in expr '{ir!r}'" + msg = f"unable to find a single leaf column in expr '{expr!r}'" return ComputeError(msg) -def _has_multiple_outputs(ir: ExprIR) -> bool: - from narwhals._plan import expressions as expr +def _has_multiple_outputs(expr: ExprIR, /) -> bool: + from narwhals._plan import expressions as ir - return isinstance(ir, (expr.Columns, expr.IndexColumns, expr.SelectorIR, expr.All)) + return isinstance(expr, (ir.Columns, ir.IndexColumns, ir.SelectorIR, ir.All)) -def has_expr_ir(ir: ExprIR, *matches: type[ExprIR]) -> bool: +def has_expr_ir(expr: ExprIR, *matches: type[ExprIR]) -> bool: """Return True if any node in the tree is in type `matches`. Based on [`polars_plan::utils::has_expr`] [`polars_plan::utils::has_expr`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/utils.rs#L70-L77 """ - return any(isinstance(e, matches) for e in ir.iter_right()) - - -def is_column(ir: ExprIR) -> TypeIs[Column]: - from narwhals._plan.expressions.expr import Column - - return isinstance(ir, Column) + return any(isinstance(e, matches) for e in expr.iter_right()) -def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: - from narwhals._plan import expressions as expr +def _is_literal(expr: ExprIR, /, *, allow_aliasing: bool) -> bool: + from narwhals._plan import expressions as ir from narwhals._plan.expressions.literal import is_literal_scalar return ( - isinstance(ir, expr.Literal) - or (allow_aliasing and isinstance(ir, expr.Alias)) + is_literal(expr) + or (allow_aliasing and isinstance(expr, ir.Alias)) or ( - isinstance(ir, expr.Cast) - and is_literal_scalar(ir.expr) - and isinstance(ir.expr.dtype, Version.MAIN.dtypes.Datetime) + isinstance(expr, ir.Cast) + and is_literal_scalar(expr.expr) + and isinstance(expr.expr.dtype, Version.MAIN.dtypes.Datetime) ) ) -def _is_column_selection(ir: ExprIR, *, allow_aliasing: bool) -> bool: - from narwhals._plan import expressions as expr +def _is_column_selection(expr: ExprIR, /, *, allow_aliasing: bool) -> bool: + from narwhals._plan import expressions as ir - return isinstance(ir, (expr.Column, expr._ColumnSelection, expr.SelectorIR)) or ( - allow_aliasing and isinstance(ir, (expr.Alias, expr.KeepName, expr.RenameAlias)) + return isinstance(expr, (ir.Column, ir._ColumnSelection, ir.SelectorIR)) or ( + allow_aliasing and isinstance(expr, (ir.Alias, ir.KeepName, ir.RenameAlias)) ) diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index 4bf4447b16..781087e9fc 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -28,8 +28,8 @@ def _from_expr(expr: Expr, /) -> When: return When(condition=expr._ir) @staticmethod - def _from_ir(ir: ExprIR, /) -> When: - return When(condition=ir) + def _from_ir(expr_ir: ExprIR, /) -> When: + return When(condition=expr_ir) class Then(Immutable, Expr): @@ -56,8 +56,8 @@ def _ir(self) -> ExprIR: # type: ignore[override] return self._otherwise() @classmethod - def _from_ir(cls, ir: ExprIR, /) -> Expr: # type: ignore[override] - return Expr._from_ir(ir) + def _from_ir(cls, expr_ir: ExprIR, /) -> Expr: # type: ignore[override] + return Expr._from_ir(expr_ir) def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] if is_expr(value): @@ -104,8 +104,8 @@ def _ir(self) -> ExprIR: # type: ignore[override] return self._otherwise() @classmethod - def _from_ir(cls, ir: ExprIR, /) -> Expr: # type: ignore[override] - return Expr._from_ir(ir) + def _from_ir(cls, expr_ir: ExprIR, /) -> Expr: # type: ignore[override] + return Expr._from_ir(expr_ir) def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] if is_expr(value): From 9b18b1363b003075d4c750db965f956f1d023f23 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 12:20:54 +0000 Subject: [PATCH 18/36] refactor: Even more import updates --- narwhals/_plan/arrow/expr.py | 37 ++++++++---------- narwhals/_plan/expressions/__init__.py | 2 + narwhals/_plan/protocols.py | 53 +++++++++++++------------- 3 files changed, 44 insertions(+), 48 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index bb715ea7c5..b5d3b85360 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -9,7 +9,7 @@ from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co -from narwhals._plan.common import ExprIR, NamedIR +from narwhals._plan.common import NamedIR from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch, namespace from narwhals._utils import ( Implementation, @@ -26,9 +26,10 @@ from typing_extensions import Self, TypeAlias from narwhals._arrow.typing import ChunkedArrayAny, Incomplete + from narwhals._plan import expressions as ir from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.expressions import boolean, expr + from narwhals._plan.expressions import boolean from narwhals._plan.expressions.aggregation import ( ArgMax, ArgMin, @@ -53,15 +54,7 @@ IsNull, Not, ) - from narwhals._plan.expressions.expr import ( - AnonymousExpr, - BinaryExpr, - FunctionExpr, - OrderedWindowExpr, - RollingExpr, - TernaryExpr, - WindowExpr, - ) + from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr from narwhals._plan.expressions.functions import FillNull, Pow from narwhals.typing import Into1DArray, IntoDType, PythonLiteral @@ -81,7 +74,7 @@ def __narwhals_namespace__(self) -> ArrowNamespace: return ArrowNamespace(self.version) def _with_native(self, native: Any, name: str, /) -> StoresNativeT_co: ... - def cast(self, node: expr.Cast, frame: Frame, name: str) -> StoresNativeT_co: + def cast(self, node: ir.Cast, frame: Frame, name: str) -> StoresNativeT_co: data_type = narwhals_to_native_dtype(node.dtype, frame.version) native = node.expr.dispatch(self, frame, name).native return self._with_native(fn.cast(native, data_type), name) @@ -154,7 +147,7 @@ def binary_expr(self, node: BinaryExpr, frame: Frame, name: str) -> StoresNative return self._with_native(result, name) def ternary_expr( - self, node: TernaryExpr, frame: Frame, name: str + self, node: ir.TernaryExpr, frame: Frame, name: str ) -> StoresNativeT_co: when = node.predicate.dispatch(self, frame, name) then = node.truthy.dispatch(self, frame, name) @@ -199,7 +192,7 @@ def _with_native(self, result: ChunkedOrScalarAny, name: str, /) -> Scalar | Sel return ArrowScalar.from_native(result, name, version=self.version) return self.from_native(result, name or self.name, self.version) - def _dispatch_expr(self, node: ExprIR, frame: Frame, name: str) -> Series: + def _dispatch_expr(self, node: ir.ExprIR, frame: Frame, name: str) -> Series: """Use instead of `_dispatch` *iff* an operation isn't natively supported on `ChunkedArray`. There is no need to broadcast, as they may have a cheaper impl elsewhere (`CompliantScalar` or `ArrowScalar`). @@ -225,12 +218,12 @@ def broadcast(self, length: int, /) -> Series: def __len__(self) -> int: return len(self._evaluated) - def sort(self, node: expr.Sort, frame: Frame, name: str) -> Expr: + def sort(self, node: ir.Sort, frame: Frame, name: str) -> Expr: native = self._dispatch_expr(node.expr, frame, name).native sorted_indices = pc.array_sort_indices(native, options=node.options.to_arrow()) return self._with_native(native.take(sorted_indices), name) - def sort_by(self, node: expr.SortBy, frame: Frame, name: str) -> Expr: + def sort_by(self, node: ir.SortBy, frame: Frame, name: str) -> Expr: series = self._dispatch_expr(node.expr, frame, name) by = ( self._dispatch_expr(e, frame, f"_{idx}") @@ -242,7 +235,7 @@ def sort_by(self, node: expr.SortBy, frame: Frame, name: str) -> Expr: result: ChunkedArrayAny = df.native.column(0).take(indices) return self._with_native(result, name) - def filter(self, node: expr.Filter, frame: Frame, name: str) -> Expr: + def filter(self, node: ir.Filter, frame: Frame, name: str) -> Expr: return self._with_native( self._dispatch_expr(node.expr, frame, name).native.filter( self._dispatch_expr(node.by, frame, name).native @@ -326,11 +319,11 @@ def min(self, node: Min, frame: Frame, name: str) -> Scalar: # - [x] `map_batches` is defined in `EagerExpr`, might be simpler here than on main # - [ ] `rolling_expr` has 4 variants - def over(self, node: WindowExpr, frame: Frame, name: str) -> Self: + def over(self, node: ir.WindowExpr, frame: Frame, name: str) -> Self: raise NotImplementedError def over_ordered( - self, node: OrderedWindowExpr, frame: Frame, name: str + self, node: ir.OrderedWindowExpr, frame: Frame, name: str ) -> Self | Scalar: if node.partition_by: msg = f"Need to implement `group_by`, `join` for:\n{node!r}" @@ -354,7 +347,7 @@ def over_ordered( return self._with_native(result, name) # NOTE: Can't implement in `EagerExpr`, since it doesn't derive `ExprDispatch` - def map_batches(self, node: AnonymousExpr, frame: Frame, name: str) -> Self: + def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: if node.is_scalar: # NOTE: Just trying to avoid redoing the whole API for `Series` msg = "Only elementwise is currently supported" @@ -368,7 +361,7 @@ def map_batches(self, node: AnonymousExpr, frame: Frame, name: str) -> Self: result = result.cast(dtype) return self.from_series(result) - def rolling_expr(self, node: RollingExpr, frame: Frame, name: str) -> Self: + def rolling_expr(self, node: ir.RollingExpr, frame: Frame, name: str) -> Self: raise NotImplementedError @@ -421,7 +414,7 @@ def from_series(cls, series: Series) -> Self: msg = f"Too long {len(series)!r}" raise InvalidOperationError(msg) - def _dispatch_expr(self, node: ExprIR, frame: Frame, name: str) -> Series: + def _dispatch_expr(self, node: ir.ExprIR, frame: Frame, name: str) -> Series: msg = f"Expected unreachable, but hit at: {node!r}" raise InvalidOperationError(msg) diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 671bb336e1..4ba67a03da 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -20,6 +20,7 @@ Literal, Nth, OrderedWindowExpr, + RangeExpr, RollingExpr, RootSelector, Sort, @@ -56,6 +57,7 @@ "Nth", "OrderableAggExpr", "OrderedWindowExpr", + "RangeExpr", "RenameAlias", "RollingExpr", "RootSelector", diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 9d94879759..fec4cad043 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -3,7 +3,7 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized from typing import TYPE_CHECKING, Any, Literal, Protocol, overload -from narwhals._plan.common import ExprIR, NamedIR, flatten_hash_safe +from narwhals._plan.common import NamedIR, flatten_hash_safe from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT, Seq from narwhals._typing_compat import TypeVar from narwhals._utils import Version @@ -11,15 +11,16 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs + from narwhals._plan import expressions as ir from narwhals._plan.dataframe import BaseFrame, DataFrame from narwhals._plan.expressions import ( + BinaryExpr, + FunctionExpr, aggregation as agg, boolean, - expr, functions as F, ) from narwhals._plan.expressions.boolean import IsBetween, IsFinite, IsNan, IsNull, Not - from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr, RangeExpr from narwhals._plan.expressions.ranges import IntRange from narwhals._plan.expressions.strings import ConcatStr from narwhals._plan.options import SortMultipleOptions @@ -168,13 +169,13 @@ def _length_required( class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): @classmethod - def from_ir(cls, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: + def from_ir(cls, node: ir.ExprIR, frame: FrameT_contra, name: str) -> R_co: obj = cls.__new__(cls) obj._version = frame.version return node.dispatch(obj, frame, name) @classmethod - def from_named_ir(cls, named_ir: NamedIR[ExprIR], frame: FrameT_contra) -> R_co: + def from_named_ir(cls, named_ir: NamedIR[ir.ExprIR], frame: FrameT_contra) -> R_co: return cls.from_ir(named_ir.expr, frame, named_ir.name) # NOTE: Needs to stay `covariant` and never be used as a parameter @@ -197,7 +198,7 @@ def _with_native(self, native: Any, name: str, /) -> Self: return self.from_native(native, name or self.name, self.version) # series & scalar - def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... + def cast(self, node: ir.Cast, frame: FrameT_contra, name: str) -> Self: ... def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... def fill_null( @@ -217,24 +218,24 @@ def is_null( ) -> Self: ... def binary_expr(self, node: BinaryExpr, frame: FrameT_contra, name: str) -> Self: ... def ternary_expr( - self, node: expr.TernaryExpr, frame: FrameT_contra, name: str + self, node: ir.TernaryExpr, frame: FrameT_contra, name: str ) -> Self: ... - def over(self, node: expr.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... + def over(self, node: ir.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... # NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr` # e.g. `nw.col("a").first().over(order_by="b")` def over_ordered( - self, node: expr.OrderedWindowExpr, frame: FrameT_contra, name: str + self, node: ir.OrderedWindowExpr, frame: FrameT_contra, name: str ) -> Self | CompliantScalar[FrameT_contra, SeriesT_co]: ... def map_batches( - self, node: expr.AnonymousExpr, frame: FrameT_contra, name: str + self, node: ir.AnonymousExpr, frame: FrameT_contra, name: str ) -> Self: ... def rolling_expr( - self, node: expr.RollingExpr, frame: FrameT_contra, name: str + self, node: ir.RollingExpr, frame: FrameT_contra, name: str ) -> Self: ... # series only (section 3) - def sort(self, node: expr.Sort, frame: FrameT_contra, name: str) -> Self: ... - def sort_by(self, node: expr.SortBy, frame: FrameT_contra, name: str) -> Self: ... - def filter(self, node: expr.Filter, frame: FrameT_contra, name: str) -> Self: ... + def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: ... + def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self: ... + def filter(self, node: ir.Filter, frame: FrameT_contra, name: str) -> Self: ... # series -> scalar def first( self, node: agg.First, frame: FrameT_contra, name: str @@ -334,7 +335,7 @@ def last(self, node: agg.Last, frame: FrameT_contra, name: str) -> Self: """Returns self.""" return self._with_evaluated(self._evaluated, name) - def _cast_float(self, node: ExprIR, frame: FrameT_contra, name: str) -> Self: + def _cast_float(self, node: ir.ExprIR, frame: FrameT_contra, name: str) -> Self: """`polars` interpolates a single scalar as a float.""" dtype = self.version.dtypes.Float64() return self.cast(node.cast(dtype), frame, name) @@ -372,10 +373,10 @@ def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: """Returns 0 if null, else 1.""" ... - def sort(self, node: expr.Sort, frame: FrameT_contra, name: str) -> Self: + def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) - def sort_by(self, node: expr.SortBy, frame: FrameT_contra, name: str) -> Self: + def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) # NOTE: `Filter` behaves the same, (maybe) no need to override @@ -445,11 +446,11 @@ def _frame(self) -> type[FrameT]: ... def _expr(self) -> type[ExprT_co]: ... @property def _scalar(self) -> type[ScalarT_co]: ... - def col(self, node: expr.Column, frame: FrameT, name: str) -> ExprT_co: ... + def col(self, node: ir.Column, frame: FrameT, name: str) -> ExprT_co: ... def lit( - self, node: expr.Literal[Any], frame: FrameT, name: str + self, node: ir.Literal[Any], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... - def len(self, node: expr.Len, frame: FrameT, name: str) -> ScalarT_co: ... + def len(self, node: ir.Len, frame: FrameT, name: str) -> ScalarT_co: ... def any_horizontal( self, node: FunctionExpr[boolean.AnyHorizontal], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... @@ -472,7 +473,7 @@ def concat_str( self, node: FunctionExpr[ConcatStr], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... def int_range( - self, node: RangeExpr[IntRange], frame: FrameT, name: str + self, node: ir.RangeExpr[IntRange], frame: FrameT, name: str ) -> ExprT_co: ... @@ -497,16 +498,16 @@ def _is_dataframe(self, obj: Any) -> TypeIs[EagerDataFrameT]: @overload def lit( - self, node: expr.Literal[NonNestedLiteral], frame: EagerDataFrameT, name: str + self, node: ir.Literal[NonNestedLiteral], frame: EagerDataFrameT, name: str ) -> EagerScalarT_co: ... @overload def lit( - self, node: expr.Literal[Series[Any]], frame: EagerDataFrameT, name: str + self, node: ir.Literal[Series[Any]], frame: EagerDataFrameT, name: str ) -> EagerExprT_co: ... def lit( - self, node: expr.Literal[Any], frame: EagerDataFrameT, name: str + self, node: ir.Literal[Any], frame: EagerDataFrameT, name: str ) -> EagerExprT_co | EagerScalarT_co: ... - def len(self, node: expr.Len, frame: EagerDataFrameT, name: str) -> EagerScalarT_co: + def len(self, node: ir.Len, frame: EagerDataFrameT, name: str) -> EagerScalarT_co: return self._scalar.from_python( len(frame), name or node.name, dtype=None, version=frame.version ) @@ -548,7 +549,7 @@ def _with_native(self, native: NativeFrameT) -> Self: @property def schema(self) -> Mapping[str, DType]: ... def _evaluate_irs( - self, nodes: Iterable[NamedIR[ExprIR]], / + self, nodes: Iterable[NamedIR[ir.ExprIR]], / ) -> Iterator[ColumnT_co]: ... def select(self, irs: Seq[NamedIR]) -> Self: ... def with_columns(self, irs: Seq[NamedIR]) -> Self: ... From ba70130116fdbd612d476b0f3615a162e9e81bd5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 13:02:29 +0000 Subject: [PATCH 19/36] refactor: Rename `expr_rewrites` -> `_rewrites` --- narwhals/_plan/{expr_rewrites.py => _rewrites.py} | 0 tests/plan/expr_rewrites_test.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) rename narwhals/_plan/{expr_rewrites.py => _rewrites.py} (100%) diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/_rewrites.py similarity index 100% rename from narwhals/_plan/expr_rewrites.py rename to narwhals/_plan/_rewrites.py diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index bf46387d84..e4481c0ded 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -7,12 +7,12 @@ import narwhals as nw from narwhals._plan import _parse, functions as nwd from narwhals._plan._guards import is_expr -from narwhals._plan.common import ExprIR, NamedIR -from narwhals._plan.expr_rewrites import ( +from narwhals._plan._rewrites import ( rewrite_all, rewrite_binary_agg_over, rewrite_elementwise_over, ) +from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.expressions import selectors as ndcs from narwhals._plan.expressions.expr import WindowExpr from narwhals._plan.expressions.window import Over From 2a7079dbc19cb4d32a298937973e5afa03fff045 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 13:07:38 +0000 Subject: [PATCH 20/36] refactor: Rename `expr_expansion` -> `_expansion` --- narwhals/_plan/{expr_expansion.py => _expansion.py} | 0 narwhals/_plan/_rewrites.py | 4 ++-- narwhals/_plan/dataframe.py | 10 +++++----- tests/plan/expr_expansion_test.py | 7 +++---- 4 files changed, 10 insertions(+), 11 deletions(-) rename narwhals/_plan/{expr_expansion.py => _expansion.py} (100%) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/_expansion.py similarity index 100% rename from narwhals/_plan/expr_expansion.py rename to narwhals/_plan/_expansion.py diff --git a/narwhals/_plan/_rewrites.py b/narwhals/_plan/_rewrites.py index 655632738e..de32ef69ca 100644 --- a/narwhals/_plan/_rewrites.py +++ b/narwhals/_plan/_rewrites.py @@ -1,9 +1,10 @@ -"""Post-`expr_expansion` rewrites, in a similar style.""" +"""Post-`_expansion` rewrites, in a similar style.""" from __future__ import annotations from typing import TYPE_CHECKING +from narwhals._plan._expansion import into_named_irs, prepare_projection from narwhals._plan._guards import ( is_aggregation, is_binary_expr, @@ -12,7 +13,6 @@ ) from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.common import NamedIR, map_ir, replace -from narwhals._plan.expr_expansion import into_named_irs, prepare_projection if TYPE_CHECKING: from collections.abc import Sequence diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index ebfceb093e..040f598b35 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload -from narwhals._plan import _parse, expr_expansion +from narwhals._plan import _expansion, _parse from narwhals._plan.contexts import ExprContext from narwhals._plan.expr import _parse_sort_by from narwhals._plan.series import Series @@ -69,10 +69,10 @@ def _project( /, ) -> tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: """Temp, while these parts aren't connected, this is easier for testing.""" - irs, schema_frozen, output_names = expr_expansion.prepare_projection( + irs, schema_frozen, output_names = _expansion.prepare_projection( _parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs), self.schema ) - named_irs = expr_expansion.into_named_irs(irs, output_names) + named_irs = _expansion.into_named_irs(irs, output_names) return schema_frozen.project(named_irs, context) def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: @@ -93,8 +93,8 @@ def sort( sort, opts = _parse_sort_by( by, *more_by, descending=descending, nulls_last=nulls_last ) - irs, _, output_names = expr_expansion.prepare_projection(sort, self.schema) - named_irs = expr_expansion.into_named_irs(irs, output_names) + irs, _, output_names = _expansion.prepare_projection(sort, self.schema) + named_irs = _expansion.into_named_irs(irs, output_names) return self._from_compliant(self._compliant.sort(named_irs, opts)) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index dbf89b0b97..7067c52fd9 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -7,14 +7,13 @@ import narwhals as nw from narwhals._plan import functions as nwd -from narwhals._plan._parse import parse_into_seq_of_expr_ir -from narwhals._plan.expr_expansion import ( +from narwhals._plan._expansion import ( prepare_projection, replace_selector, rewrite_special_aliases, ) -from narwhals._plan.expressions import selectors as ndcs -from narwhals._plan.expressions.expr import Alias, Columns +from narwhals._plan._parse import parse_into_seq_of_expr_ir +from narwhals._plan.expressions import Alias, Columns, selectors as ndcs from narwhals._plan.schema import freeze_schema from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError from tests.plan.utils import assert_expr_ir_equal From cc3fdf6a74fa29e0f99a39d764a6174d162896d9 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 13:10:26 +0000 Subject: [PATCH 21/36] chore: export some common modules --- narwhals/_plan/expressions/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 4ba67a03da..cf6323224a 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations from narwhals._plan.common import ExprIR, SelectorIR # prob should move into package? +from narwhals._plan.expressions import aggregation, functions, operators, selectors from narwhals._plan.expressions.aggregation import AggExpr, OrderableAggExpr from narwhals._plan.expressions.expr import ( Alias, @@ -67,8 +68,12 @@ "TernaryExpr", "WindowExpr", "_ColumnSelection", + "aggregation", "col", "cols", + "functions", "index_columns", "nth", + "operators", + "selectors", ] From 800e77228d5a8aba0b6064814d3d901d14b5d8b6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 13:21:18 +0000 Subject: [PATCH 22/36] refactor: Make `meta` depend on `expressions` --- narwhals/_plan/_guards.py | 4 ---- narwhals/_plan/meta.py | 43 +++++++++++++-------------------------- 2 files changed, 14 insertions(+), 33 deletions(-) diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index e8b27ed376..0f62942ab8 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -63,10 +63,6 @@ def is_column(obj: Any) -> TypeIs[Expr]: return is_expr(obj) and obj.meta.is_column() -def is_column_ir(obj: Any) -> TypeIs[ir.Column]: - return isinstance(obj, _ir().Column) - - def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]]: return isinstance(obj, _series().Series) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index b0d6e7f6c9..715011175e 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -10,16 +10,16 @@ from itertools import chain from typing import TYPE_CHECKING, Literal, overload -from narwhals._plan._guards import is_column_ir, is_literal +from narwhals._plan import expressions as ir +from narwhals._plan._guards import is_literal from narwhals._plan.common import IRNamespace +from narwhals._plan.expressions.literal import is_literal_scalar from narwhals.exceptions import ComputeError from narwhals.utils import Version if TYPE_CHECKING: from collections.abc import Iterable, Iterator - from narwhals._plan.expressions import ExprIR - class IRMetaNamespace(IRNamespace): """Methods to modify and traverse existing expressions.""" @@ -28,7 +28,7 @@ def has_multiple_outputs(self) -> bool: return any(_has_multiple_outputs(e) for e in self._ir.iter_left()) def is_column(self) -> bool: - return is_column_ir(self._ir) + return isinstance(self._ir, ir.Column) def is_column_selection(self, *, allow_aliasing: bool = False) -> bool: return all( @@ -76,22 +76,20 @@ def root_names(self) -> list[str]: return list(_expr_to_leaf_column_names_iter(self._ir)) -def _expr_to_leaf_column_names_iter(expr: ExprIR, /) -> Iterator[str]: +def _expr_to_leaf_column_names_iter(expr: ir.ExprIR, /) -> Iterator[str]: for e in _expr_to_leaf_column_exprs_iter(expr): result = _expr_to_leaf_column_name(e) if isinstance(result, str): yield result -def _expr_to_leaf_column_exprs_iter(expr: ExprIR, /) -> Iterator[ExprIR]: - from narwhals._plan import expressions as ir - +def _expr_to_leaf_column_exprs_iter(expr: ir.ExprIR, /) -> Iterator[ir.ExprIR]: for outer in expr.iter_root_names(): if isinstance(outer, (ir.Column, ir.All)): yield outer -def _expr_to_leaf_column_name(expr: ExprIR, /) -> str | ComputeError: +def _expr_to_leaf_column_name(expr: ir.ExprIR, /) -> str | ComputeError: leaves = list(_expr_to_leaf_column_exprs_iter(expr)) if not len(leaves) <= 1: msg = "found more than one root column name" @@ -100,8 +98,6 @@ def _expr_to_leaf_column_name(expr: ExprIR, /) -> str | ComputeError: msg = "no root column name found" return ComputeError(msg) leaf = leaves[0] - from narwhals._plan import expressions as ir - if isinstance(leaf, ir.Column): return leaf.name if isinstance(leaf, ir.All): @@ -111,14 +107,12 @@ def _expr_to_leaf_column_name(expr: ExprIR, /) -> str | ComputeError: return ComputeError(msg) -def root_names_unique(exprs: Iterable[ExprIR], /) -> set[str]: +def root_names_unique(exprs: Iterable[ir.ExprIR], /) -> set[str]: return set(chain.from_iterable(_expr_to_leaf_column_names_iter(e) for e in exprs)) @lru_cache(maxsize=32) -def _expr_output_name(expr: ExprIR, /) -> str | ComputeError: - from narwhals._plan import expressions as ir - +def _expr_output_name(expr: ir.ExprIR, /) -> str | ComputeError: for e in expr.iter_output_name(): if isinstance(e, (ir.Column, ir.Alias, ir.Literal, ir.Len)): return e.name @@ -135,7 +129,7 @@ def _expr_output_name(expr: ExprIR, /) -> str | ComputeError: return ComputeError(msg) -def get_single_leaf_name(expr: ExprIR, /) -> str | ComputeError: +def get_single_leaf_name(expr: ir.ExprIR, /) -> str | ComputeError: """Find the name at the start of an expression. Normal iteration would just return the first root column it found. @@ -144,8 +138,6 @@ def get_single_leaf_name(expr: ExprIR, /) -> str | ComputeError: [`polars_plan::utils::get_single_leaf`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/utils.rs#L151-L168 """ - from narwhals._plan import expressions as ir - for e in expr.iter_right(): if isinstance(e, (ir.WindowExpr, ir.SortBy, ir.Filter)): return get_single_leaf_name(e.expr) @@ -158,13 +150,11 @@ def get_single_leaf_name(expr: ExprIR, /) -> str | ComputeError: return ComputeError(msg) -def _has_multiple_outputs(expr: ExprIR, /) -> bool: - from narwhals._plan import expressions as ir - +def _has_multiple_outputs(expr: ir.ExprIR, /) -> bool: return isinstance(expr, (ir.Columns, ir.IndexColumns, ir.SelectorIR, ir.All)) -def has_expr_ir(expr: ExprIR, *matches: type[ExprIR]) -> bool: +def has_expr_ir(expr: ir.ExprIR, *matches: type[ir.ExprIR]) -> bool: """Return True if any node in the tree is in type `matches`. Based on [`polars_plan::utils::has_expr`] @@ -174,10 +164,7 @@ def has_expr_ir(expr: ExprIR, *matches: type[ExprIR]) -> bool: return any(isinstance(e, matches) for e in expr.iter_right()) -def _is_literal(expr: ExprIR, /, *, allow_aliasing: bool) -> bool: - from narwhals._plan import expressions as ir - from narwhals._plan.expressions.literal import is_literal_scalar - +def _is_literal(expr: ir.ExprIR, /, *, allow_aliasing: bool) -> bool: return ( is_literal(expr) or (allow_aliasing and isinstance(expr, ir.Alias)) @@ -189,9 +176,7 @@ def _is_literal(expr: ExprIR, /, *, allow_aliasing: bool) -> bool: ) -def _is_column_selection(expr: ExprIR, /, *, allow_aliasing: bool) -> bool: - from narwhals._plan import expressions as ir - +def _is_column_selection(expr: ir.ExprIR, /, *, allow_aliasing: bool) -> bool: return isinstance(expr, (ir.Column, ir._ColumnSelection, ir.SelectorIR)) or ( allow_aliasing and isinstance(expr, (ir.Alias, ir.KeepName, ir.RenameAlias)) ) From a6a130dc940c6f12a74304c65593479f3cc19115 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 16:24:11 +0000 Subject: [PATCH 23/36] refactor: Move `common.map_ir` -> `_rewrites` Not referenced anywhere else --- narwhals/_plan/_rewrites.py | 16 ++++++++++++++-- narwhals/_plan/common.py | 13 ------------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/narwhals/_plan/_rewrites.py b/narwhals/_plan/_rewrites.py index de32ef69ca..83e86b74c0 100644 --- a/narwhals/_plan/_rewrites.py +++ b/narwhals/_plan/_rewrites.py @@ -12,14 +12,14 @@ is_window_expr, ) from narwhals._plan._parse import parse_into_seq_of_expr_ir -from narwhals._plan.common import NamedIR, map_ir, replace +from narwhals._plan.common import NamedIR, replace if TYPE_CHECKING: from collections.abc import Sequence from narwhals._plan.common import ExprIR from narwhals._plan.schema import IntoFrozenSchema - from narwhals._plan.typing import IntoExpr, MapIR, Seq + from narwhals._plan.typing import IntoExpr, MapIR, NamedOrExprIRT, Seq def rewrite_all( @@ -83,3 +83,15 @@ def rewrite_binary_agg_over(window: ExprIR, /) -> ExprIR: binary_expr = window.expr return replace(binary_expr, right=replace(window, expr=binary_expr.right)) return window + + +def map_ir( + origin: NamedOrExprIRT, function: MapIR, *more_functions: MapIR +) -> NamedOrExprIRT: + """Apply one or more functions, sequentially, to all of `origin`'s children.""" + if more_functions: + result = origin + for fn in (function, *more_functions): + result = result.map_ir(fn) + return result + return origin.map_ir(function) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index f0d1a7112e..8e13c91031 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -19,7 +19,6 @@ FunctionT, IRNamespaceT, MapIR, - NamedOrExprIRT, NonNestedDTypeT, OneOrIterable, Seq, @@ -478,18 +477,6 @@ def collect(iterable: Seq[T] | Iterable[T], /) -> Seq[T]: return iterable if isinstance(iterable, tuple) else tuple(iterable) -def map_ir( - origin: NamedOrExprIRT, function: MapIR, *more_functions: MapIR -) -> NamedOrExprIRT: - """Apply one or more functions, sequentially, to all of `origin`'s children.""" - if more_functions: - result = origin - for fn in (function, *more_functions): - result = result.map_ir(fn) - return result - return origin.map_ir(function) - - def _map_ir_child(obj: ExprIR | Seq[ExprIR], fn: MapIR, /) -> ExprIR | Seq[ExprIR]: return obj.map_ir(fn) if isinstance(obj, ExprIR) else tuple(e.map_ir(fn) for e in obj) From aab0f1d28eb36364a3696aaedeb4458e84fef413 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 16:40:06 +0000 Subject: [PATCH 24/36] refactor: remove `collect` Had more use prior to (#3066) --- narwhals/_plan/arrow/namespace.py | 3 +-- narwhals/_plan/common.py | 5 ----- narwhals/_plan/expressions/expr.py | 7 +++---- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index fc67bbc36e..e4f68f27db 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -9,7 +9,6 @@ from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._plan._guards import is_tuple_of from narwhals._plan.arrow import functions as fn -from narwhals._plan.common import collect from narwhals._plan.expressions.literal import is_literal_scalar from narwhals._plan.protocols import EagerNamespace from narwhals._utils import Version @@ -225,7 +224,7 @@ def gen(objs: Iterable[Frame | Series]) -> Iterator[tuple[ChunkedArrayAny, str]] return self._dataframe.from_native(native, self.version) def _concat_vertical(self, items: Iterable[Frame | Series]) -> Frame | Series: - collected = collect(items) + collected = items if isinstance(items, tuple) else tuple(items) if is_tuple_of(collected, self._series): sers = collected chunked = fn.concat_vertical_chunked(ser.native for ser in sers) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 8e13c91031..27346905b7 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -472,11 +472,6 @@ def into_dtype(dtype: DTypeT | type[NonNestedDTypeT], /) -> DTypeT | NonNestedDT return dtype -def collect(iterable: Seq[T] | Iterable[T], /) -> Seq[T]: - """Collect `iterable` into a `tuple`, *iff* it is not one already.""" - return iterable if isinstance(iterable, tuple) else tuple(iterable) - - def _map_ir_child(obj: ExprIR | Seq[ExprIR], fn: MapIR, /) -> ExprIR | Seq[ExprIR]: return obj.map_ir(fn) if isinstance(obj, ExprIR) else tuple(e.map_ir(fn) for e in obj) diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 083173549e..d33b603ccf 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -6,7 +6,7 @@ # - Literal import typing as t -from narwhals._plan.common import ExprIR, SelectorIR, collect +from narwhals._plan.common import ExprIR, SelectorIR, flatten_hash_safe from narwhals._plan.exceptions import function_expr_invalid_operation_error from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ( @@ -23,7 +23,6 @@ SelectorT, Seq, ) -from narwhals._utils import flatten from narwhals.exceptions import InvalidOperationError if t.TYPE_CHECKING: @@ -143,8 +142,8 @@ class Exclude(_ColumnSelection, child=("expr",)): @staticmethod def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: - flat = flatten(names) - return Exclude(expr=expr, names=collect(flat)) + flat: t.Iterator[str] = flatten_hash_safe(names) + return Exclude(expr=expr, names=tuple(flat)) def __repr__(self) -> str: return f"{self.expr!r}.exclude({list(self.names)!r})" From 391311127a2a4bfecd4b1459d86a067f01aac8ab Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 17:08:32 +0000 Subject: [PATCH 25/36] tweak `into_dtype` --- narwhals/_plan/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 27346905b7..23c51502ee 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -466,9 +466,9 @@ def into_dtype(dtype: type[NonNestedDTypeT], /) -> NonNestedDTypeT: ... @overload def into_dtype(dtype: DTypeT, /) -> DTypeT: ... def into_dtype(dtype: DTypeT | type[NonNestedDTypeT], /) -> DTypeT | NonNestedDTypeT: + # NOTE: `mypy` needs to learn intersections if isinstance(dtype, type) and issubclass(dtype, DType): - # NOTE: `mypy` needs to learn intersections - return dtype() # type: ignore[return-value] + return cast("NonNestedDTypeT", dtype()) return dtype From 4e39a6d3d92941453affe36c8f4ea2e4b07c25b5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 17:20:56 +0000 Subject: [PATCH 26/36] refactor: more `common` removal prep --- narwhals/_plan/common.py | 12 ++++++------ narwhals/_plan/expr.py | 9 ++++----- narwhals/_plan/expressions/aggregation.py | 4 ++-- narwhals/_plan/functions.py | 9 ++++----- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 23c51502ee..4549abc34e 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -56,7 +56,7 @@ def replace(obj: T, /, **changes: Any) -> T: Incomplete: TypeAlias = "Any" -def _pascal_to_snake_case(s: str) -> str: +def pascal_to_snake_case(s: str) -> str: """Convert a PascalCase, camelCase string to snake_case. Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 @@ -77,7 +77,7 @@ def _re_repl_snake(match: re.Match[str], /) -> str: def _dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: config = tp.__expr_ir_config__ - name = config.override_name or _pascal_to_snake_case(tp.__name__) + name = config.override_name or pascal_to_snake_case(tp.__name__) return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name @@ -120,6 +120,10 @@ def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: return _ +def _map_ir_child(obj: ExprIR | Seq[ExprIR], fn: MapIR, /) -> ExprIR | Seq[ExprIR]: + return obj.map_ir(fn) if isinstance(obj, ExprIR) else tuple(e.map_ir(fn) for e in obj) + + class ExprIR(Immutable): """Anything that can be a node on a graph of expressions.""" @@ -472,10 +476,6 @@ def into_dtype(dtype: DTypeT | type[NonNestedDTypeT], /) -> DTypeT | NonNestedDT return dtype -def _map_ir_child(obj: ExprIR | Seq[ExprIR], fn: MapIR, /) -> ExprIR | Seq[ExprIR]: - return obj.map_ir(fn) if isinstance(obj, ExprIR) else tuple(e.map_ir(fn) for e in obj) - - # TODO @dangotbanned: Review again and try to work around (https://github.com/microsoft/pyright/issues/10673#issuecomment-3033789021) # The issue is `T` possibly being `Iterable` # Ignoring here still leaks the issue to the caller, where you need to annotate the base case diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 9a0210af2e..06e457c54b 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -4,14 +4,13 @@ from collections.abc import Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any, ClassVar, overload -from narwhals._plan import expressions as ir +from narwhals._plan import common, expressions as ir from narwhals._plan._guards import is_column, is_expr, is_series from narwhals._plan._parse import ( parse_into_expr_ir, parse_into_seq_of_expr_ir, parse_predicates_constraints_into_expr_ir, ) -from narwhals._plan.common import into_dtype from narwhals._plan.expressions import ( aggregation as agg, boolean, @@ -98,7 +97,7 @@ def alias(self, name: str) -> Self: return self._from_ir(self._ir.alias(name)) def cast(self, dtype: IntoDType) -> Self: - return self._from_ir(self._ir.cast(into_dtype(dtype))) + return self._from_ir(self._ir.cast(common.into_dtype(dtype))) def exclude(self, *names: OneOrIterable[str]) -> Self: return self._from_ir(ir.Exclude.from_names(self._ir, *names)) @@ -372,7 +371,7 @@ def replace_strict( before = tuple(old) after = tuple(new) if return_dtype is not None: - return_dtype = into_dtype(return_dtype) + return_dtype = common.into_dtype(return_dtype) function = F.ReplaceStrict(old=before, new=after, return_dtype=return_dtype) return self._with_unary(function) @@ -388,7 +387,7 @@ def map_batches( returns_scalar: bool = False, ) -> Self: if return_dtype is not None: - return_dtype = into_dtype(return_dtype) + return_dtype = common.into_dtype(return_dtype) return self._with_unary( F.MapBatches( function=function, diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index b1f47ca1d7..a0ba57f7f1 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan.common import ExprIR, _pascal_to_snake_case +from narwhals._plan.common import ExprIR, pascal_to_snake_case from narwhals._plan.exceptions import agg_scalar_error if TYPE_CHECKING: @@ -20,7 +20,7 @@ def is_scalar(self) -> bool: return True def __repr__(self) -> str: - return f"{self.expr!r}.{_pascal_to_snake_case(type(self).__name__)}()" + return f"{self.expr!r}.{pascal_to_snake_case(type(self).__name__)}()" def iter_output_name(self) -> Iterator[ExprIR]: yield from self.expr.iter_output_name() diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index b8c64c2b84..57ea8c4492 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -3,8 +3,7 @@ import builtins import typing as t -from narwhals._plan import _guards, _parse, expressions as ir -from narwhals._plan.common import into_dtype, py_to_narwhals_dtype +from narwhals._plan import _guards, _parse, common, expressions as ir from narwhals._plan.expressions import boolean, functions as F from narwhals._plan.expressions.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.expressions.ranges import IntRange @@ -41,9 +40,9 @@ def lit( msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." raise TypeError(msg) if dtype is None: - dtype = py_to_narwhals_dtype(value, Version.MAIN) + dtype = common.py_to_narwhals_dtype(value, Version.MAIN) else: - dtype = into_dtype(dtype) + dtype = common.into_dtype(dtype) return ScalarLiteral(value=value, dtype=dtype).to_literal().to_narwhals() @@ -156,7 +155,7 @@ def int_range( msg = f"{eager=}" raise NotImplementedError(msg) return ( - IntRange(step=step, dtype=into_dtype(dtype)) + IntRange(step=step, dtype=common.into_dtype(dtype)) .to_function_expr(*_parse.parse_into_seq_of_expr_ir(start, end)) .to_narwhals() ) From 9615a6cfa7e4a6d8a19169bddc9ae9a21311115b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 17:23:18 +0000 Subject: [PATCH 27/36] rename `IRMetaNamespace` -> `MetaNamespace` don't need an `Expr` version, and this name gets exposed --- narwhals/_plan/common.py | 8 ++++---- narwhals/_plan/expr.py | 8 ++++---- narwhals/_plan/meta.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 4549abc34e..ad9a51f336 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -34,7 +34,7 @@ from narwhals._plan.expr import Expr, Selector from narwhals._plan.expressions.expr import Alias, Cast, Column, FunctionExpr - from narwhals._plan.meta import IRMetaNamespace + from narwhals._plan.meta import MetaNamespace from narwhals._plan.protocols import Ctx, FrameT_contra, R_co from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -263,10 +263,10 @@ def iter_output_name(self) -> Iterator[ExprIR]: yield from self.iter_right() @property - def meta(self) -> IRMetaNamespace: - from narwhals._plan.meta import IRMetaNamespace + def meta(self) -> MetaNamespace: + from narwhals._plan.meta import MetaNamespace - return IRMetaNamespace(_ir=self) + return MetaNamespace(_ir=self) def cast(self, dtype: DType) -> Cast: from narwhals._plan.expressions.expr import Cast diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 06e457c54b..b9a3eafc70 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -39,7 +39,7 @@ from narwhals._plan.expressions.strings import ExprStringNamespace from narwhals._plan.expressions.struct import ExprStructNamespace from narwhals._plan.expressions.temporal import ExprDateTimeNamespace - from narwhals._plan.meta import IRMetaNamespace + from narwhals._plan.meta import MetaNamespace from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq, Udf from narwhals.typing import ( ClosedInterval, @@ -540,10 +540,10 @@ def __invert__(self) -> Self: return self._with_unary(boolean.Not()) @property - def meta(self) -> IRMetaNamespace: - from narwhals._plan.meta import IRMetaNamespace + def meta(self) -> MetaNamespace: + from narwhals._plan.meta import MetaNamespace - return IRMetaNamespace.from_expr(self) + return MetaNamespace.from_expr(self) @property def name(self) -> ExprNameNamespace: diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 715011175e..4d67eb60fc 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -21,7 +21,7 @@ from collections.abc import Iterable, Iterator -class IRMetaNamespace(IRNamespace): +class MetaNamespace(IRNamespace): """Methods to modify and traverse existing expressions.""" def has_multiple_outputs(self) -> bool: From 9121d788f7285558e3ce3d7afe32ffff6c145efe Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 17:45:52 +0000 Subject: [PATCH 28/36] refactor: Split out `_function.py` --- narwhals/_plan/_function.py | 83 +++++++++++++++++++++++ narwhals/_plan/common.py | 70 +------------------ narwhals/_plan/exceptions.py | 2 +- narwhals/_plan/expr.py | 2 +- narwhals/_plan/expressions/boolean.py | 2 +- narwhals/_plan/expressions/categorical.py | 3 +- narwhals/_plan/expressions/functions.py | 2 +- narwhals/_plan/expressions/lists.py | 3 +- narwhals/_plan/expressions/ranges.py | 3 +- narwhals/_plan/expressions/strings.py | 3 +- narwhals/_plan/expressions/struct.py | 3 +- narwhals/_plan/expressions/temporal.py | 3 +- narwhals/_plan/typing.py | 3 +- tests/plan/expr_parsing_test.py | 3 +- 14 files changed, 106 insertions(+), 79 deletions(-) create mode 100644 narwhals/_plan/_function.py diff --git a/narwhals/_plan/_function.py b/narwhals/_plan/_function.py new file mode 100644 index 0000000000..59605d63d9 --- /dev/null +++ b/narwhals/_plan/_function.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar + +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import _dispatch_getter, _dispatch_method_name, replace +from narwhals._plan.options import FEOptions, FunctionOptions + +if TYPE_CHECKING: + from typing import Any, Callable + + from typing_extensions import Self, TypeAlias + + from narwhals._plan.expressions import ExprIR, FunctionExpr + from narwhals._plan.typing import Accessor, FunctionT + +__all__ = ["Function", "HorizontalFunction"] + +Incomplete: TypeAlias = "Any" + + +def _dispatch_generate_function( + tp: type[FunctionT], / +) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: + getter = _dispatch_getter(tp) + + def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: + return getter(ctx)(node, frame, name) + + return _ + + +class Function(Immutable): + """Shared by expr functions and namespace functions. + + https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114 + """ + + _function_options: ClassVar[staticmethod[[], FunctionOptions]] = staticmethod( + FunctionOptions.default + ) + __expr_ir_config__: ClassVar[FEOptions] = FEOptions.default() + __expr_ir_dispatch__: ClassVar[ + staticmethod[[Incomplete, FunctionExpr[Self], Incomplete, str], Incomplete] + ] + + @property + def function_options(self) -> FunctionOptions: + return self._function_options() + + @property + def is_scalar(self) -> bool: + return self.function_options.returns_scalar() + + def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: + from narwhals._plan.expressions.expr import FunctionExpr + + return FunctionExpr(input=inputs, function=self, options=self.function_options) + + def __init_subclass__( + cls: type[Self], + *args: Any, + accessor: Accessor | None = None, + options: Callable[[], FunctionOptions] | None = None, + config: FEOptions | None = None, + **kwds: Any, + ) -> None: + super().__init_subclass__(*args, **kwds) + if accessor: + config = replace(config or FEOptions.default(), accessor_name=accessor) + if options: + cls._function_options = staticmethod(options) + if config: + cls.__expr_ir_config__ = config + cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls)) + + def __repr__(self) -> str: + return _dispatch_method_name(type(self)) + + +class HorizontalFunction( + Function, options=FunctionOptions.horizontal, config=FEOptions.namespaced() +): ... diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index ad9a51f336..b6e76c5c42 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -10,9 +10,8 @@ from narwhals._plan._guards import is_function_expr, is_iterable_reject, is_literal from narwhals._plan._immutable import Immutable -from narwhals._plan.options import ExprIROptions, FEOptions, FunctionOptions +from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ( - Accessor, DTypeT, ExprIRT, ExprIRT2, @@ -32,8 +31,9 @@ from typing_extensions import Self, TypeAlias + from narwhals._plan._function import Function from narwhals._plan.expr import Expr, Selector - from narwhals._plan.expressions.expr import Alias, Cast, Column, FunctionExpr + from narwhals._plan.expressions.expr import Alias, Cast, Column from narwhals._plan.meta import MetaNamespace from narwhals._plan.protocols import Ctx, FrameT_contra, R_co from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -109,17 +109,6 @@ def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: return _ -def _dispatch_generate_function( - tp: type[FunctionT], / -) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: - getter = _dispatch_getter(tp) - - def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: - return getter(ctx)(node, frame, name) - - return _ - - def _map_ir_child(obj: ExprIR | Seq[ExprIR], fn: MapIR, /) -> ExprIR | Seq[ExprIR]: return obj.map_ir(fn) if isinstance(obj, ExprIR) else tuple(e.map_ir(fn) for e in obj) @@ -394,59 +383,6 @@ def _with_unary(self, function: Function, /) -> Expr: return self._expr._with_unary(function) -class Function(Immutable): - """Shared by expr functions and namespace functions. - - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114 - """ - - _function_options: ClassVar[staticmethod[[], FunctionOptions]] = staticmethod( - FunctionOptions.default - ) - __expr_ir_config__: ClassVar[FEOptions] = FEOptions.default() - __expr_ir_dispatch__: ClassVar[ - staticmethod[[Incomplete, FunctionExpr[Self], Incomplete, str], Incomplete] - ] - - @property - def function_options(self) -> FunctionOptions: - return self._function_options() - - @property - def is_scalar(self) -> bool: - return self.function_options.returns_scalar() - - def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: - from narwhals._plan.expressions.expr import FunctionExpr - - return FunctionExpr(input=inputs, function=self, options=self.function_options) - - def __init_subclass__( - cls: type[Self], - *args: Any, - accessor: Accessor | None = None, - options: Callable[[], FunctionOptions] | None = None, - config: FEOptions | None = None, - **kwds: Any, - ) -> None: - super().__init_subclass__(*args, **kwds) - if accessor: - config = replace(config or FEOptions.default(), accessor_name=accessor) - if options: - cls._function_options = staticmethod(options) - if config: - cls.__expr_ir_config__ = config - cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls)) - - def __repr__(self) -> str: - return _dispatch_method_name(type(self)) - - -class HorizontalFunction( - Function, options=FunctionOptions.horizontal, config=FEOptions.namespaced() -): ... - - def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: dtypes = version.dtypes mapping: dict[type[NonNestedLiteral], type[NonNestedDType]] = { diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index e378f8c80c..8f4348aaa3 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -25,7 +25,7 @@ import polars as pl from narwhals._plan import expressions as ir - from narwhals._plan.common import Function + from narwhals._plan._function import Function from narwhals._plan.expressions.operators import Operator from narwhals._plan.options import SortOptions from narwhals._plan.typing import IntoExpr, Seq diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index b9a3eafc70..c6242d8503 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -32,7 +32,7 @@ if TYPE_CHECKING: from typing_extensions import Never, Self - from narwhals._plan.common import Function + from narwhals._plan._function import Function from narwhals._plan.expressions.categorical import ExprCatNamespace from narwhals._plan.expressions.lists import ExprListNamespace from narwhals._plan.expressions.name import ExprNameNamespace diff --git a/narwhals/_plan/expressions/boolean.py b/narwhals/_plan/expressions/boolean.py index 49ad2bd2ca..a11ff4569e 100644 --- a/narwhals/_plan/expressions/boolean.py +++ b/narwhals/_plan/expressions/boolean.py @@ -4,7 +4,7 @@ # - Any import typing as t -from narwhals._plan.common import Function, HorizontalFunction +from narwhals._plan._function import Function, HorizontalFunction from narwhals._plan.options import FEOptions, FunctionOptions from narwhals._typing_compat import TypeVar diff --git a/narwhals/_plan/expressions/categorical.py b/narwhals/_plan/expressions/categorical.py index d89e3da75d..5b3849bcdb 100644 --- a/narwhals/_plan/expressions/categorical.py +++ b/narwhals/_plan/expressions/categorical.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, ClassVar -from narwhals._plan.common import ExprNamespace, Function, IRNamespace +from narwhals._plan._function import Function +from narwhals._plan.common import ExprNamespace, IRNamespace if TYPE_CHECKING: from narwhals._plan.expr import Expr diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index b936b602ce..94c9cebd6c 100644 --- a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import Function, HorizontalFunction +from narwhals._plan._function import Function, HorizontalFunction from narwhals._plan.exceptions import hist_bins_monotonic_error from narwhals._plan.options import FunctionFlags, FunctionOptions diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 168f50dadf..7e50458c25 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, ClassVar -from narwhals._plan.common import ExprNamespace, Function, IRNamespace +from narwhals._plan._function import Function +from narwhals._plan.common import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: diff --git a/narwhals/_plan/expressions/ranges.py b/narwhals/_plan/expressions/ranges.py index 7dc0faa42e..f73b85f7e3 100644 --- a/narwhals/_plan/expressions/ranges.py +++ b/narwhals/_plan/expressions/ranges.py @@ -2,12 +2,13 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import ExprIR, Function +from narwhals._plan._function import Function from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: from typing_extensions import Self + from narwhals._plan.common import ExprIR from narwhals._plan.expressions.expr import RangeExpr from narwhals.dtypes import IntegerType diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index 94812a8c8f..bd1e6eda97 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, ClassVar -from narwhals._plan.common import ExprNamespace, Function, HorizontalFunction, IRNamespace +from narwhals._plan._function import Function, HorizontalFunction +from narwhals._plan.common import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: diff --git a/narwhals/_plan/expressions/struct.py b/narwhals/_plan/expressions/struct.py index b978ed4295..02009b12e8 100644 --- a/narwhals/_plan/expressions/struct.py +++ b/narwhals/_plan/expressions/struct.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, ClassVar -from narwhals._plan.common import ExprNamespace, Function, IRNamespace +from narwhals._plan._function import Function +from narwhals._plan.common import ExprNamespace, IRNamespace from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: diff --git a/narwhals/_plan/expressions/temporal.py b/narwhals/_plan/expressions/temporal.py index e17bdbfb22..6aa5a61c8c 100644 --- a/narwhals/_plan/expressions/temporal.py +++ b/narwhals/_plan/expressions/temporal.py @@ -3,7 +3,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal from narwhals._duration import Interval -from narwhals._plan.common import ExprNamespace, Function, IRNamespace +from narwhals._plan._function import Function +from narwhals._plan.common import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 3f1a9ea313..97afc235da 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -8,7 +8,8 @@ from typing_extensions import TypeAlias from narwhals import dtypes - from narwhals._plan.common import ExprIR, Function, IRNamespace, NamedIR, SelectorIR + from narwhals._plan._function import Function + from narwhals._plan.common import ExprIR, IRNamespace, NamedIR, SelectorIR from narwhals._plan.expr import Expr from narwhals._plan.expressions import operators as ops from narwhals._plan.expressions.functions import RollingWindow diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index fd63c4a355..00870e0084 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -12,7 +12,7 @@ import narwhals as nw import narwhals._plan.functions as nwd from narwhals._plan._parse import parse_into_seq_of_expr_ir -from narwhals._plan.common import ExprIR, Function +from narwhals._plan.common import ExprIR from narwhals._plan.expr import Expr from narwhals._plan.expressions import boolean, expr, functions as F, operators as ops from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr, RangeExpr @@ -32,6 +32,7 @@ from typing_extensions import TypeAlias + from narwhals._plan._function import Function from narwhals._plan.typing import IntoExpr, IntoExprColumn, OperatorFn, Seq From 11f5612b05452a1bcd6986286f06925f7b4eb2dd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 18:51:40 +0000 Subject: [PATCH 29/36] refactor: Split out `expressions.namespace.py` Barely referenced outside this package --- narwhals/_plan/common.py | 30 ----------------- narwhals/_plan/expressions/categorical.py | 2 +- narwhals/_plan/expressions/lists.py | 2 +- narwhals/_plan/expressions/name.py | 5 +-- narwhals/_plan/expressions/namespace.py | 41 +++++++++++++++++++++++ narwhals/_plan/expressions/strings.py | 2 +- narwhals/_plan/expressions/struct.py | 2 +- narwhals/_plan/expressions/temporal.py | 2 +- narwhals/_plan/meta.py | 2 +- narwhals/_plan/typing.py | 3 +- 10 files changed, 52 insertions(+), 39 deletions(-) create mode 100644 narwhals/_plan/expressions/namespace.py diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index b6e76c5c42..a6f6ef2394 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -16,7 +16,6 @@ ExprIRT, ExprIRT2, FunctionT, - IRNamespaceT, MapIR, NonNestedDTypeT, OneOrIterable, @@ -31,7 +30,6 @@ from typing_extensions import Self, TypeAlias - from narwhals._plan._function import Function from narwhals._plan.expr import Expr, Selector from narwhals._plan.expressions.expr import Alias, Cast, Column from narwhals._plan.meta import MetaNamespace @@ -355,34 +353,6 @@ def is_elementwise_top_level(self) -> bool: return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.TernaryExpr, expr.Cast)) -class IRNamespace(Immutable): - __slots__ = ("_ir",) - _ir: ExprIR - - @classmethod - def from_expr(cls, expr: Expr, /) -> Self: - return cls(_ir=expr._ir) - - -class ExprNamespace(Immutable, Generic[IRNamespaceT]): - __slots__ = ("_expr",) - _expr: Expr - - @property - def _ir_namespace(self) -> type[IRNamespaceT]: - raise NotImplementedError - - @property - def _ir(self) -> IRNamespaceT: - return self._ir_namespace.from_expr(self._expr) - - def _to_narwhals(self, ir: ExprIR, /) -> Expr: - return self._expr._from_ir(ir) - - def _with_unary(self, function: Function, /) -> Expr: - return self._expr._with_unary(function) - - def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: dtypes = version.dtypes mapping: dict[type[NonNestedLiteral], type[NonNestedDType]] = { diff --git a/narwhals/_plan/expressions/categorical.py b/narwhals/_plan/expressions/categorical.py index 5b3849bcdb..7c59fd4443 100644 --- a/narwhals/_plan/expressions/categorical.py +++ b/narwhals/_plan/expressions/categorical.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, ClassVar from narwhals._plan._function import Function -from narwhals._plan.common import ExprNamespace, IRNamespace +from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace if TYPE_CHECKING: from narwhals._plan.expr import Expr diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 7e50458c25..604e054a5e 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, ClassVar from narwhals._plan._function import Function -from narwhals._plan.common import ExprNamespace, IRNamespace +from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: diff --git a/narwhals/_plan/expressions/name.py b/narwhals/_plan/expressions/name.py index 24bc648cde..a8460cd6dd 100644 --- a/narwhals/_plan/expressions/name.py +++ b/narwhals/_plan/expressions/name.py @@ -4,6 +4,7 @@ from narwhals._plan import common from narwhals._plan._immutable import Immutable +from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import ExprIROptions if TYPE_CHECKING: @@ -52,7 +53,7 @@ def __call__(self, name: str, /) -> str: return f"{name}{self.suffix}" -class IRNameNamespace(common.IRNamespace): +class IRNameNamespace(IRNamespace): def keep(self) -> KeepName: return KeepName(expr=self._ir) @@ -72,7 +73,7 @@ def to_uppercase(self) -> RenameAlias: return self.map(str.upper) -class ExprNameNamespace(common.ExprNamespace[IRNameNamespace]): +class ExprNameNamespace(ExprNamespace[IRNameNamespace]): @property def _ir_namespace(self) -> type[IRNameNamespace]: return IRNameNamespace diff --git a/narwhals/_plan/expressions/namespace.py b/narwhals/_plan/expressions/namespace.py new file mode 100644 index 0000000000..be548dba4c --- /dev/null +++ b/narwhals/_plan/expressions/namespace.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic + +from narwhals._plan._immutable import Immutable +from narwhals._plan.typing import IRNamespaceT + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._plan._function import Function + from narwhals._plan.expr import Expr + from narwhals._plan.expressions import ExprIR + + +class IRNamespace(Immutable): + __slots__ = ("_ir",) + _ir: ExprIR + + @classmethod + def from_expr(cls, expr: Expr, /) -> Self: + return cls(_ir=expr._ir) + + +class ExprNamespace(Immutable, Generic[IRNamespaceT]): + __slots__ = ("_expr",) + _expr: Expr + + @property + def _ir_namespace(self) -> type[IRNamespaceT]: + raise NotImplementedError + + @property + def _ir(self) -> IRNamespaceT: + return self._ir_namespace.from_expr(self._expr) + + def _to_narwhals(self, ir: ExprIR, /) -> Expr: + return self._expr._from_ir(ir) + + def _with_unary(self, function: Function, /) -> Expr: + return self._expr._with_unary(function) diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index bd1e6eda97..6e60a7b530 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, ClassVar from narwhals._plan._function import Function, HorizontalFunction -from narwhals._plan.common import ExprNamespace, IRNamespace +from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: diff --git a/narwhals/_plan/expressions/struct.py b/narwhals/_plan/expressions/struct.py index 02009b12e8..e3625adb8a 100644 --- a/narwhals/_plan/expressions/struct.py +++ b/narwhals/_plan/expressions/struct.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, ClassVar from narwhals._plan._function import Function -from narwhals._plan.common import ExprNamespace, IRNamespace +from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: diff --git a/narwhals/_plan/expressions/temporal.py b/narwhals/_plan/expressions/temporal.py index 6aa5a61c8c..11a87599ab 100644 --- a/narwhals/_plan/expressions/temporal.py +++ b/narwhals/_plan/expressions/temporal.py @@ -4,7 +4,7 @@ from narwhals._duration import Interval from narwhals._plan._function import Function -from narwhals._plan.common import ExprNamespace, IRNamespace +from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 4d67eb60fc..a5d905ca85 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -12,8 +12,8 @@ from narwhals._plan import expressions as ir from narwhals._plan._guards import is_literal -from narwhals._plan.common import IRNamespace from narwhals._plan.expressions.literal import is_literal_scalar +from narwhals._plan.expressions.namespace import IRNamespace from narwhals.exceptions import ComputeError from narwhals.utils import Version diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 97afc235da..3f60a1ea98 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -9,10 +9,11 @@ from narwhals import dtypes from narwhals._plan._function import Function - from narwhals._plan.common import ExprIR, IRNamespace, NamedIR, SelectorIR + from narwhals._plan.common import ExprIR, NamedIR, SelectorIR from narwhals._plan.expr import Expr from narwhals._plan.expressions import operators as ops from narwhals._plan.expressions.functions import RollingWindow + from narwhals._plan.expressions.namespace import IRNamespace from narwhals._plan.expressions.ranges import RangeFunction from narwhals._plan.series import Series from narwhals.typing import ( From 3d90dc7209f62e21d283effec530f728f8f5592d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 20:09:02 +0000 Subject: [PATCH 30/36] export to `_plan`, update doctests --- narwhals/_plan/__init__.py | 55 +++++++++++++++++++++++++++++++++++++ narwhals/_plan/common.py | 12 ++++---- narwhals/_plan/expr.py | 4 +-- narwhals/_plan/functions.py | 4 +-- narwhals/_plan/meta.py | 4 +-- 5 files changed, 67 insertions(+), 12 deletions(-) diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index 9d48db4f9f..afeff442c0 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -1 +1,56 @@ from __future__ import annotations + +from narwhals._plan.dataframe import DataFrame +from narwhals._plan.expr import Expr, Selector +from narwhals._plan.expressions import selectors +from narwhals._plan.functions import ( + all, + all_horizontal, + any_horizontal, + col, + concat_str, + exclude, + int_range, + len, + lit, + max, + max_horizontal, + mean, + mean_horizontal, + median, + min, + min_horizontal, + nth, + sum, + sum_horizontal, + when, +) +from narwhals._plan.series import Series + +__all__ = [ + "DataFrame", + "Expr", + "Selector", + "Series", + "all", + "all_horizontal", + "any_horizontal", + "col", + "concat_str", + "exclude", + "int_range", + "len", + "lit", + "max", + "max_horizontal", + "mean", + "mean_horizontal", + "median", + "min", + "min_horizontal", + "nth", + "selectors", + "sum", + "sum_horizontal", + "when", +] diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index a6f6ef2394..65cbef4a72 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -170,12 +170,12 @@ def iter_left(self) -> Iterator[ExprIR]: """Yield nodes root->leaf. Examples: - >>> from narwhals._plan import functions as nwd + >>> from narwhals import _plan as nw >>> - >>> a = nwd.col("a") + >>> a = nw.col("a") >>> b = a.alias("b") >>> c = b.min().alias("c") - >>> d = c.over(nwd.col("e"), nwd.col("f")) + >>> d = c.over(nw.col("e"), nw.col("f")) >>> >>> list(a._ir.iter_left()) [col('a')] @@ -205,12 +205,12 @@ def iter_right(self) -> Iterator[ExprIR]: Identical to `iter_left` for root nodes. Examples: - >>> from narwhals._plan import functions as nwd + >>> from narwhals import _plan as nw >>> - >>> a = nwd.col("a") + >>> a = nw.col("a") >>> b = a.alias("b") >>> c = b.min().alias("c") - >>> d = c.over(nwd.col("e"), nwd.col("f")) + >>> d = c.over(nw.col("e"), nw.col("f")) >>> >>> list(a._ir.iter_right()) [col('a')] diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index c6242d8503..294e5a1670 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -550,9 +550,9 @@ def name(self) -> ExprNameNamespace: """Specialized expressions for modifying the name of existing expressions. Examples: - >>> from narwhals._plan import functions as nwd + >>> from narwhals import _plan as nw >>> - >>> renamed = nwd.col("a", "b").name.suffix("_changed") + >>> renamed = nw.col("a", "b").name.suffix("_changed") >>> str(renamed._ir) "RenameAlias(expr=Columns(names=[a, b]), function=Suffix(suffix='_changed'))" """ diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 57ea8c4492..1c265d9ed0 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -128,9 +128,9 @@ def when( """Start a `when-then-otherwise` expression. Examples: - >>> from narwhals._plan import functions as nwd + >>> from narwhals import _plan as nw - >>> nwd.when(nwd.col("y") == "b").then(1) + >>> nw.when(nw.col("y") == "b").then(1) nw._plan.Expr(main): .when([(col('y')) == (lit(str: b))]).then(lit(int: 1)).otherwise(lit(null)) """ diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index a5d905ca85..bb7a4315b3 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -49,9 +49,9 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: """Get the output name of this expression. Examples: - >>> from narwhals._plan import functions as nwd + >>> from narwhals import _plan as nw >>> - >>> a = nwd.col("a") + >>> a = nw.col("a") >>> b = a.alias("b") >>> c = b.min().alias("c") >>> From a1f20ca602ca05c45c689d8956e198d9bfd39a41 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 13 Sep 2025 20:51:33 +0000 Subject: [PATCH 31/36] test: Update imports --- narwhals/_plan/expressions/__init__.py | 9 ++- tests/plan/compliant_test.py | 31 ++++----- tests/plan/expr_expansion_test.py | 38 +++++------ tests/plan/expr_parsing_test.py | 87 +++++++++++++------------- tests/plan/expr_rewrites_test.py | 18 +++--- tests/plan/meta_test.py | 13 ++-- tests/plan/utils.py | 19 +++--- 7 files changed, 107 insertions(+), 108 deletions(-) diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index cf6323224a..c9a7fe8f6f 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -1,7 +1,13 @@ from __future__ import annotations from narwhals._plan.common import ExprIR, SelectorIR # prob should move into package? -from narwhals._plan.expressions import aggregation, functions, operators, selectors +from narwhals._plan.expressions import ( + aggregation, + boolean, + functions, + operators, + selectors, +) from narwhals._plan.expressions.aggregation import AggExpr, OrderableAggExpr from narwhals._plan.expressions.expr import ( Alias, @@ -69,6 +75,7 @@ "WindowExpr", "_ColumnSelection", "aggregation", + "boolean", "col", "cols", "functions", diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index ef179346c3..89f511e214 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -4,7 +4,7 @@ import pytest -from narwhals._plan.expressions import selectors as ndcs +from narwhals._plan import selectors as ndcs pytest.importorskip("pyarrow") pytest.importorskip("numpy") @@ -12,9 +12,7 @@ import pyarrow as pa import narwhals as nw -from narwhals._plan import functions as nwd -from narwhals._plan._guards import is_expr -from narwhals._plan.dataframe import DataFrame +from narwhals import _plan as nwd from narwhals._utils import Version from narwhals.exceptions import ComputeError from tests.utils import assert_equal_data @@ -22,7 +20,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from narwhals._plan.expr import Expr from narwhals.typing import PythonLiteral @@ -66,8 +63,8 @@ def data_indexed() -> dict[str, Any]: } -def _ids_ir(expr: Expr | Any) -> str: - if is_expr(expr): +def _ids_ir(expr: nwd.Expr | Any) -> str: + if isinstance(expr, nwd.Expr): return repr(expr._ir) return repr(expr) @@ -405,10 +402,12 @@ def _ids_ir(expr: Expr | Any) -> str: ids=_ids_ir, ) def test_select( - expr: Expr | Sequence[Expr], expected: dict[str, Any], data_small: dict[str, Any] + expr: nwd.Expr | Sequence[nwd.Expr], + expected: dict[str, Any], + data_small: dict[str, Any], ) -> None: frame = pa.table(data_small) - df = DataFrame.from_native(frame) + df = nwd.DataFrame.from_native(frame) result = df.select(expr).to_dict(as_series=False) assert_equal_data(result, expected) @@ -477,19 +476,21 @@ def test_select( ], ) def test_with_columns( - expr: Expr | Sequence[Expr], expected: dict[str, Any], data_smaller: dict[str, Any] + expr: nwd.Expr | Sequence[nwd.Expr], + expected: dict[str, Any], + data_smaller: dict[str, Any], ) -> None: frame = pa.table(data_smaller) - df = DataFrame.from_native(frame) + df = nwd.DataFrame.from_native(frame) result = df.with_columns(expr).to_dict(as_series=False) assert_equal_data(result, expected) -def first(*names: str) -> Expr: +def first(*names: str) -> nwd.Expr: return nwd.col(*names).first() -def last(*names: str) -> Expr: +def last(*names: str) -> nwd.Expr: return nwd.col(*names).last() @@ -505,12 +506,12 @@ def last(*names: str) -> Expr: ], ) def test_first_last_expr_with_columns( - data_indexed: dict[str, Any], agg: Expr, expected: PythonLiteral + data_indexed: dict[str, Any], agg: nwd.Expr, expected: PythonLiteral ) -> None: """Related https://github.com/narwhals-dev/narwhals/pull/2528#discussion_r2225930065.""" height = len(next(iter(data_indexed.values()))) expected_broadcast = height * [expected] - frame = DataFrame.from_native(pa.table(data_indexed)) + frame = nwd.DataFrame.from_native(pa.table(data_indexed)) expr = agg.over(order_by="idx").alias("result") result = frame.with_columns(expr).select("result").to_dict(as_series=False) assert_equal_data(result, {"result": expected_broadcast}) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 7067c52fd9..d881a3a132 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -6,14 +6,14 @@ import pytest import narwhals as nw -from narwhals._plan import functions as nwd +from narwhals import _plan as nwd +from narwhals._plan import expressions as ir, selectors as ndcs from narwhals._plan._expansion import ( prepare_projection, replace_selector, rewrite_special_aliases, ) from narwhals._plan._parse import parse_into_seq_of_expr_ir -from narwhals._plan.expressions import Alias, Columns, selectors as ndcs from narwhals._plan.schema import freeze_schema from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError from tests.plan.utils import assert_expr_ir_equal @@ -21,8 +21,6 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence - from narwhals._plan.common import ExprIR - from narwhals._plan.expr import Expr, Selector from narwhals._plan.typing import IntoExpr, MapIR from narwhals.dtypes import DType @@ -110,7 +108,7 @@ def udf_name_map(name: str) -> str: ), ], ) -def test_rewrite_special_aliases_single(expr: Expr, expected: str) -> None: +def test_rewrite_special_aliases_single(expr: nwd.Expr, expected: str) -> None: # NOTE: We can't use `output_name()` without resolving these rewrites # Once they're done, `output_name()` just peeks into `Alias(name=...)` ir_input = expr._ir @@ -126,10 +124,10 @@ def test_rewrite_special_aliases_single(expr: Expr, expected: str) -> None: def alias_replace_guarded(name: str) -> MapIR: # pragma: no cover """Guards against repeatedly creating the same alias.""" - def fn(ir: ExprIR) -> ExprIR: - if isinstance(ir, Alias) and ir.name != name: - return Alias(expr=ir.expr, name=name) - return ir + def fn(e_ir: ir.ExprIR) -> ir.ExprIR: + if isinstance(e_ir, ir.Alias) and e_ir.name != name: + return ir.Alias(expr=e_ir.expr, name=name) + return e_ir return fn @@ -143,10 +141,10 @@ def alias_replace_unguarded(name: str) -> MapIR: # pragma: no cover - *Pragmatically*, it might require an extra iteration to detect a cycle """ - def fn(ir: ExprIR) -> ExprIR: - if isinstance(ir, Alias): - return Alias(expr=ir.expr, name=name) - return ir + def fn(e_ir: ir.ExprIR) -> ir.ExprIR: + if isinstance(e_ir, ir.Alias): + return ir.Alias(expr=e_ir.expr, name=name) + return e_ir return fn @@ -180,7 +178,7 @@ def fn(ir: ExprIR) -> ExprIR: ), ], ) -def test_map_ir_recursive(expr: Expr, function: MapIR, expected: Expr) -> None: +def test_map_ir_recursive(expr: nwd.Expr, function: MapIR, expected: nwd.Expr) -> None: actual = expr._ir.map_ir(function) assert_expr_ir_equal(actual, expected) @@ -190,7 +188,7 @@ def test_map_ir_recursive(expr: Expr, function: MapIR, expected: Expr) -> None: [ (nwd.col("a"), nwd.col("a")), (nwd.col("a").max().alias("z"), nwd.col("a").max().alias("z")), - (ndcs.string(), Columns(names=("k",))), + (ndcs.string(), ir.Columns(names=("k",))), ( ndcs.by_dtype(nw.Datetime("ms"), nw.Date, nw.List(nw.String)), nwd.col("n", "s"), @@ -248,7 +246,9 @@ def test_map_ir_recursive(expr: Expr, function: MapIR, expected: Expr) -> None: ], ) def test_replace_selector( - expr: Selector | Expr, expected: Expr | ExprIR, schema_1: dict[str, DType] + expr: nwd.Selector | nwd.Expr, + expected: nwd.Expr | ir.ExprIR, + schema_1: dict[str, DType], ) -> None: actual = replace_selector(expr._ir, schema=freeze_schema(**schema_1)) assert_expr_ir_equal(actual, expected) @@ -422,7 +422,7 @@ def test_replace_selector( ) def test_prepare_projection( into_exprs: IntoExpr | Sequence[IntoExpr], - expected: Sequence[Expr], + expected: Sequence[nwd.Expr], schema_1: dict[str, DType], ) -> None: irs_in = parse_into_seq_of_expr_ir(into_exprs) @@ -447,7 +447,7 @@ def test_prepare_projection( *MULTI_OUTPUT_EXPRS, ], ) -def test_prepare_projection_duplicate(expr: Expr, schema_1: dict[str, DType]) -> None: +def test_prepare_projection_duplicate(expr: nwd.Expr, schema_1: dict[str, DType]) -> None: irs = parse_into_seq_of_expr_ir(expr.alias("dupe")) pattern = re.compile(r"\.alias\(.dupe.\)") with pytest.raises(DuplicateError, match=pattern): @@ -547,7 +547,7 @@ def test_prepare_projection_column_not_found( ) def test_prepare_projection_horizontal_alias( into_exprs: IntoExpr | Iterable[IntoExpr], - function: Callable[..., Expr], + function: Callable[..., nwd.Expr], schema_1: dict[str, DType], ) -> None: # NOTE: See https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2139965411 diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 00870e0084..ad10e3a7cb 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -3,21 +3,18 @@ import operator import re from collections import deque -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any import pytest import narwhals as nw -import narwhals._plan.functions as nwd +from narwhals import _plan as nwd +from narwhals._plan import expressions as ir from narwhals._plan._parse import parse_into_seq_of_expr_ir -from narwhals._plan.common import ExprIR -from narwhals._plan.expr import Expr -from narwhals._plan.expressions import boolean, expr, functions as F, operators as ops -from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr, RangeExpr +from narwhals._plan.expressions import functions as F, operators as ops from narwhals._plan.expressions.literal import SeriesLiteral -from narwhals._plan.series import Series from narwhals.exceptions import ( InvalidIntoExprError, InvalidOperationError, @@ -58,7 +55,7 @@ def test_parsing( exprs: Seq[IntoExpr | Iterable[IntoExpr]], named_exprs: dict[str, IntoExpr] ) -> None: assert all( - isinstance(node, ExprIR) + isinstance(node, ir.ExprIR) for node in parse_into_seq_of_expr_ir(*exprs, **named_exprs) ) @@ -66,8 +63,8 @@ def test_parsing( @pytest.mark.parametrize( ("function", "ir_node"), [ - (nwd.all_horizontal, boolean.AllHorizontal), - (nwd.any_horizontal, boolean.AnyHorizontal), + (nwd.all_horizontal, ir.boolean.AllHorizontal), + (nwd.any_horizontal, ir.boolean.AnyHorizontal), (nwd.sum_horizontal, F.SumHorizontal), (nwd.min_horizontal, F.MinHorizontal), (nwd.max_horizontal, F.MaxHorizontal), @@ -85,18 +82,18 @@ def test_parsing( ], ) def test_function_expr_horizontal( - function: Callable[..., Expr], + function: Callable[..., nwd.Expr], ir_node: type[Function], args: Seq[IntoExpr | Iterable[IntoExpr]], ) -> None: variadic = function(*args) sequence = function(args) - assert isinstance(variadic, Expr) - assert isinstance(sequence, Expr) + assert isinstance(variadic, nwd.Expr) + assert isinstance(sequence, nwd.Expr) variadic_node = variadic._ir sequence_node = sequence._ir unrelated_node = nwd.lit(1)._ir - assert isinstance(variadic_node, FunctionExpr) + assert isinstance(variadic_node, ir.FunctionExpr) assert isinstance(variadic_node.function, ir_node) assert variadic_node == sequence_node assert sequence_node != unrelated_node @@ -161,13 +158,13 @@ def test_invalid_agg_non_elementwise() -> None: def test_agg_non_elementwise_range_special() -> None: e = nwd.int_range(0, 100) - assert isinstance(e._ir, RangeExpr) + assert isinstance(e._ir, ir.RangeExpr) e = nwd.int_range(nwd.len(), dtype=nw.UInt32).alias("index") - ir = e._ir - assert isinstance(ir, expr.Alias) - assert isinstance(ir.expr, RangeExpr) - assert isinstance(ir.expr.input[0], expr.Literal) - assert isinstance(ir.expr.input[1], expr.Len) + e_ir = e._ir + assert isinstance(e_ir, ir.Alias) + assert isinstance(e_ir.expr, ir.RangeExpr) + assert isinstance(e_ir.expr.input[0], ir.Literal) + assert isinstance(e_ir.expr.input[1], ir.Len) def test_invalid_int_range() -> None: @@ -249,8 +246,8 @@ def test_invalid_binary_expr_length_changing() -> None: a.map_batches(lambda x: x) / b.gather_every(1, 0) -def _is_expr_ir_binary_expr(expr: Expr) -> bool: - return isinstance(expr._ir, BinaryExpr) +def _is_expr_ir_binary_expr(expr: nwd.Expr) -> bool: + return isinstance(expr._ir, ir.BinaryExpr) def test_binary_expr_length_changing_agg() -> None: @@ -291,10 +288,10 @@ def test_is_in_seq(into_iter: IntoIterable) -> None: expected = 1, 2, 3 other = into_iter(list(expected)) expr = nwd.col("a").is_in(other) - ir = expr._ir - assert isinstance(ir, FunctionExpr) - assert isinstance(ir.function, boolean.IsInSeq) - assert ir.function.other == expected + e_ir = expr._ir + assert isinstance(e_ir, ir.FunctionExpr) + assert isinstance(e_ir.function, ir.boolean.IsInSeq) + assert e_ir.function.other == expected def test_is_in_series() -> None: @@ -302,12 +299,12 @@ def test_is_in_series() -> None: import pyarrow as pa native = pa.chunked_array([pa.array([1, 2, 3])]) - other = Series.from_native(native) + other = nwd.Series.from_native(native) expr = nwd.col("a").is_in(other) - ir = expr._ir - assert isinstance(ir, FunctionExpr) - assert isinstance(ir.function, boolean.IsInSeries) - assert ir.function.other.unwrap().to_native() is native + e_ir = expr._ir + assert isinstance(e_ir, ir.FunctionExpr) + assert isinstance(e_ir.function, ir.boolean.IsInSeries) + assert e_ir.function.other.unwrap().to_native() is native @pytest.mark.parametrize( @@ -395,15 +392,15 @@ def test_lit_series_roundtrip() -> None: data = ["a", "b", "c"] native = pa.chunked_array([pa.array(data)]) - series = Series.from_native(native) + series = nwd.Series.from_native(native) lit_series = nwd.lit(series) assert lit_series.meta.is_literal() - ir = lit_series._ir - assert isinstance(ir, expr.Literal) - assert isinstance(ir.dtype, nw.String) - assert isinstance(ir.value, SeriesLiteral) - unwrapped = ir.unwrap() - assert isinstance(unwrapped, Series) + e_ir = lit_series._ir + assert isinstance(e_ir, ir.Literal) + assert isinstance(e_ir.dtype, nw.String) + assert isinstance(e_ir.value, SeriesLiteral) + unwrapped = e_ir.unwrap() + assert isinstance(unwrapped, nwd.Series) assert isinstance(unwrapped.to_native(), pa.ChunkedArray) assert unwrapped.to_list() == data @@ -445,8 +442,8 @@ def test_operators_left_right( } result_1 = function(arg_1, arg_2) result_2 = function(arg_2, arg_1) - assert isinstance(result_1, Expr) - assert isinstance(result_2, Expr) + assert isinstance(result_1, nwd.Expr) + assert isinstance(result_2, nwd.Expr) ir_1 = result_1._ir ir_2 = result_2._ir if op in {ops.Eq, ops.NotEq}: @@ -454,9 +451,9 @@ def test_operators_left_right( else: assert ir_1 != ir_2 if issubclass(op, ops.Operator): - assert isinstance(ir_1, BinaryExpr) + assert isinstance(ir_1, ir.BinaryExpr) assert isinstance(ir_1.op, op) - assert isinstance(ir_2, BinaryExpr) + assert isinstance(ir_2, ir.BinaryExpr) op_inverse = inverse.get(op, op) assert isinstance(ir_2.op, op_inverse) if op in {ops.Eq, ops.NotEq, *inverse}: @@ -466,8 +463,8 @@ def test_operators_left_right( assert ir_1.left == ir_2.right assert ir_1.right == ir_2.left else: - assert isinstance(ir_1, FunctionExpr) + assert isinstance(ir_1, ir.FunctionExpr) assert isinstance(ir_1.function, op) - assert isinstance(ir_2, FunctionExpr) + assert isinstance(ir_2, ir.FunctionExpr) assert isinstance(ir_2.function, op) assert tuple(reversed(ir_2.input)) == ir_1.input diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index e4481c0ded..ce78c4f72f 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -5,22 +5,19 @@ import pytest import narwhals as nw -from narwhals._plan import _parse, functions as nwd -from narwhals._plan._guards import is_expr +from narwhals import _plan as nwd +from narwhals._plan import _parse, expressions as ir, selectors as ndcs from narwhals._plan._rewrites import ( rewrite_all, rewrite_binary_agg_over, rewrite_elementwise_over, ) -from narwhals._plan.common import ExprIR, NamedIR -from narwhals._plan.expressions import selectors as ndcs -from narwhals._plan.expressions.expr import WindowExpr +from narwhals._plan.common import NamedIR from narwhals._plan.expressions.window import Over from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_expr_ir_equal if TYPE_CHECKING: - from narwhals._plan.expr import Expr from narwhals._plan.typing import IntoExpr from narwhals.dtypes import DType @@ -42,8 +39,8 @@ def schema_2() -> dict[str, DType]: } -def _to_window_expr(into_expr: IntoExpr, *partition_by: IntoExpr) -> WindowExpr: - return WindowExpr( +def _to_window_expr(into_expr: IntoExpr, *partition_by: IntoExpr) -> ir.WindowExpr: + return ir.WindowExpr( expr=_parse.parse_into_expr_ir(into_expr), partition_by=_parse.parse_into_seq_of_expr_ir(*partition_by), options=Over(), @@ -83,10 +80,9 @@ def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: assert_expr_ir_equal(lhs, rhs) -def named_ir(name: str, expr: Expr | ExprIR, /) -> NamedIR[ExprIR]: +def named_ir(name: str, expr: nwd.Expr | ir.ExprIR, /) -> NamedIR[ir.ExprIR]: """Helper constructor for test compare.""" - ir = expr._ir if is_expr(expr) else expr - return NamedIR(expr=ir, name=name) + return NamedIR(expr=expr._ir if isinstance(expr, nwd.Expr) else expr, name=name) def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index c4b256a696..d360907b6d 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -1,15 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import pytest -import narwhals._plan.functions as nwd +from narwhals import _plan as nwd from tests.utils import POLARS_VERSION -if TYPE_CHECKING: - from narwhals._plan.expr import Expr - pytest.importorskip("polars") import polars as pl @@ -51,7 +46,9 @@ (nwd.all().mean().sort_by("d"), pl.all().mean().sort_by("d"), ["d"]), ], ) -def test_meta_root_names(nw_expr: Expr, pl_expr: pl.Expr, expected: list[str]) -> None: +def test_meta_root_names( + nw_expr: nwd.Expr, pl_expr: pl.Expr, expected: list[str] +) -> None: pl_result = pl_expr.meta.root_names() nw_result = nw_expr.meta.root_names() assert nw_result == expected @@ -179,7 +176,7 @@ def test_meta_root_names(nw_expr: Expr, pl_expr: pl.Expr, expected: list[str]) - ), ], ) -def test_meta_output_name(nw_expr: Expr, pl_expr: pl.Expr, expected: str) -> None: +def test_meta_output_name(nw_expr: nwd.Expr, pl_expr: pl.Expr, expected: str) -> None: pl_result = pl_expr.meta.output_name() nw_result = nw_expr.meta.output_name() assert nw_result == expected diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 25d0acd3f4..16bd3430c0 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -2,19 +2,18 @@ from typing import TYPE_CHECKING -from narwhals._plan._guards import is_expr -from narwhals._plan.common import ExprIR, NamedIR +from narwhals import _plan as nwd +from narwhals._plan import expressions as ir +from narwhals._plan.common import NamedIR if TYPE_CHECKING: from typing_extensions import LiteralString - from narwhals._plan.expr import Expr - -def _unwrap_ir(obj: Expr | ExprIR | NamedIR) -> ExprIR: - if is_expr(obj): +def _unwrap_ir(obj: nwd.Expr | ir.ExprIR | NamedIR) -> ir.ExprIR: + if isinstance(obj, nwd.Expr): return obj._ir - if isinstance(obj, ExprIR): + if isinstance(obj, ir.ExprIR): return obj if isinstance(obj, NamedIR): return obj.expr @@ -22,7 +21,9 @@ def _unwrap_ir(obj: Expr | ExprIR | NamedIR) -> ExprIR: def assert_expr_ir_equal( - actual: Expr | ExprIR | NamedIR, expected: Expr | ExprIR | NamedIR | LiteralString, / + actual: nwd.Expr | ir.ExprIR | NamedIR, + expected: nwd.Expr | ir.ExprIR | NamedIR | LiteralString, + /, ) -> None: """Assert that `actual` is equivalent to `expected`. @@ -40,5 +41,5 @@ def assert_expr_ir_equal( elif isinstance(actual, NamedIR) and isinstance(expected, NamedIR): assert actual == expected else: - rhs = expected._ir if is_expr(expected) else expected + rhs = expected._ir if isinstance(expected, nwd.Expr) else expected assert lhs == rhs From a6617d546bbfbee47c7fe372f3ff176857871b11 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Sep 2025 16:07:26 +0000 Subject: [PATCH 32/36] refactor: remove `nwd` alias --- narwhals/_plan/_parse.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/_parse.py b/narwhals/_plan/_parse.py index 05e503eb68..ef622707e5 100644 --- a/narwhals/_plan/_parse.py +++ b/narwhals/_plan/_parse.py @@ -88,14 +88,14 @@ def parse_into_expr_ir( input: IntoExpr, *, str_as_lit: bool = False, dtype: IntoDType | None = None ) -> ExprIR: """Parse a single input into an `ExprIR` node.""" - from narwhals._plan import functions as nwd + from narwhals._plan import col, lit if is_expr(input): expr = input elif isinstance(input, str) and not str_as_lit: - expr = nwd.col(input) + expr = col(input) else: - expr = nwd.lit(input, dtype=dtype) + expr = lit(input, dtype=dtype) return expr._ir @@ -157,10 +157,10 @@ def _parse_named_inputs(named_inputs: dict[str, IntoExpr], /) -> Iterator[ExprIR def _parse_constraints(constraints: dict[str, IntoExpr], /) -> Iterator[ExprIR]: - from narwhals._plan import functions as nwd + from narwhals._plan import col for name, value in constraints.items(): - yield (nwd.col(name) == value)._ir + yield (col(name) == value)._ir def _combine_predicates(predicates: Iterator[ExprIR], /) -> ExprIR: From 5ad77ee75afe31cfba7a25246e51caed3de25b04 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Sep 2025 16:11:08 +0000 Subject: [PATCH 33/36] test: Use `nwp` for `narwhals._plan` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `nwd` was for both `demo.py` and `dummy.py` modules - They no longer exist Starting to think the `d` stood for *dan* 😂 --- tests/plan/compliant_test.py | 244 ++++++++++++------------- tests/plan/expr_expansion_test.py | 288 +++++++++++++++--------------- tests/plan/expr_parsing_test.py | 200 ++++++++++----------- tests/plan/expr_rewrites_test.py | 66 +++---- tests/plan/meta_test.py | 66 +++---- tests/plan/utils.py | 12 +- 6 files changed, 438 insertions(+), 438 deletions(-) diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 89f511e214..ffada70747 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -12,7 +12,7 @@ import pyarrow as pa import narwhals as nw -from narwhals import _plan as nwd +from narwhals import _plan as nwp from narwhals._utils import Version from narwhals.exceptions import ComputeError from tests.utils import assert_equal_data @@ -63,8 +63,8 @@ def data_indexed() -> dict[str, Any]: } -def _ids_ir(expr: nwd.Expr | Any) -> str: - if isinstance(expr, nwd.Expr): +def _ids_ir(expr: nwp.Expr | Any) -> str: + if isinstance(expr, nwp.Expr): return repr(expr._ir) return repr(expr) @@ -86,60 +86,60 @@ def _ids_ir(expr: nwd.Expr | Any) -> str: @pytest.mark.parametrize( ("expr", "expected"), [ - (nwd.col("a"), {"a": ["A", "B", "A"]}), - (nwd.col("a", "b"), {"a": ["A", "B", "A"], "b": [1, 2, 3]}), - (nwd.lit(1), {"literal": [1]}), - (nwd.lit(2.0), {"literal": [2.0]}), - (nwd.lit(None, nw.String), {"literal": [None]}), - (nwd.col("a", "b").first(), {"a": ["A"], "b": [1]}), - (nwd.col("d").max(), {"d": [8]}), - ([nwd.len(), nwd.nth(3).last()], {"len": [3], "d": [8]}), + (nwp.col("a"), {"a": ["A", "B", "A"]}), + (nwp.col("a", "b"), {"a": ["A", "B", "A"], "b": [1, 2, 3]}), + (nwp.lit(1), {"literal": [1]}), + (nwp.lit(2.0), {"literal": [2.0]}), + (nwp.lit(None, nw.String), {"literal": [None]}), + (nwp.col("a", "b").first(), {"a": ["A"], "b": [1]}), + (nwp.col("d").max(), {"d": [8]}), + ([nwp.len(), nwp.nth(3).last()], {"len": [3], "d": [8]}), ( - [nwd.len().alias("e"), nwd.nth(3).last(), nwd.nth(2)], + [nwp.len().alias("e"), nwp.nth(3).last(), nwp.nth(2)], {"e": [3, 3, 3], "d": [8, 8, 8], "c": [9, 2, 4]}, ), - (nwd.col("b").sort(descending=True).alias("b_desc"), {"b_desc": [3, 2, 1]}), - (nwd.col("c").filter(a="B"), {"c": [2]}), + (nwp.col("b").sort(descending=True).alias("b_desc"), {"b_desc": [3, 2, 1]}), + (nwp.col("c").filter(a="B"), {"c": [2]}), ( - [nwd.nth(0, 1).filter(nwd.col("c") >= 4), nwd.col("d").last() - 4], + [nwp.nth(0, 1).filter(nwp.col("c") >= 4), nwp.col("d").last() - 4], {"a": ["A", "A"], "b": [1, 3], "d": [4, 4]}, ), - (nwd.col("b").cast(nw.Float64()), {"b": [1.0, 2.0, 3.0]}), - (nwd.lit(1).cast(nw.Float64).alias("literal_cast"), {"literal_cast": [1.0]}), + (nwp.col("b").cast(nw.Float64()), {"b": [1.0, 2.0, 3.0]}), + (nwp.lit(1).cast(nw.Float64).alias("literal_cast"), {"literal_cast": [1.0]}), pytest.param( - nwd.lit(1).cast(nw.Float64()).name.suffix("_cast"), + nwp.lit(1).cast(nw.Float64()).name.suffix("_cast"), {"literal_cast": [1.0]}, marks=XFAIL_REWRITE_SPECIAL_ALIASES, ), - ([ndcs.string().first(), nwd.col("b")], {"a": ["A", "A", "A"], "b": [1, 2, 3]}), + ([ndcs.string().first(), nwp.col("b")], {"a": ["A", "A", "A"], "b": [1, 2, 3]}), ( - nwd.col("c", "d") + nwp.col("c", "d") .sort_by("a", "b", descending=[True, False]) .cast(nw.Float32()) .name.to_uppercase(), {"C": [2.0, 9.0, 4.0], "D": [7.0, 8.0, 8.0]}, ), - ([nwd.int_range(5)], {"literal": [0, 1, 2, 3, 4]}), - ([nwd.int_range(nwd.len())], {"literal": [0, 1, 2]}), - (nwd.int_range(nwd.len() * 5, 20).alias("lol"), {"lol": [15, 16, 17, 18, 19]}), - (nwd.int_range(nwd.col("b").min() + 4, nwd.col("d").last()), {"b": [5, 6, 7]}), - (nwd.col("b") ** 2, {"b": [1, 4, 9]}), + ([nwp.int_range(5)], {"literal": [0, 1, 2, 3, 4]}), + ([nwp.int_range(nwp.len())], {"literal": [0, 1, 2]}), + (nwp.int_range(nwp.len() * 5, 20).alias("lol"), {"lol": [15, 16, 17, 18, 19]}), + (nwp.int_range(nwp.col("b").min() + 4, nwp.col("d").last()), {"b": [5, 6, 7]}), + (nwp.col("b") ** 2, {"b": [1, 4, 9]}), ( - [2 ** nwd.col("b"), (nwd.lit(2.0) ** nwd.nth(1)).alias("lit")], + [2 ** nwp.col("b"), (nwp.lit(2.0) ** nwp.nth(1)).alias("lit")], {"literal": [2, 4, 8], "lit": [2, 4, 8]}, ), pytest.param( [ - nwd.col("b").is_between(2, 3, "left").alias("left"), - nwd.col("b").is_between(2, 3, "right").alias("right"), - nwd.col("b").is_between(2, 3, "none").alias("none"), - nwd.col("b").is_between(2, 3, "both").alias("both"), - nwd.col("c").is_between( - nwd.col("c").mean() - 1, 7 - nwd.col("b"), "both" + nwp.col("b").is_between(2, 3, "left").alias("left"), + nwp.col("b").is_between(2, 3, "right").alias("right"), + nwp.col("b").is_between(2, 3, "none").alias("none"), + nwp.col("b").is_between(2, 3, "both").alias("both"), + nwp.col("c").is_between( + nwp.col("c").mean() - 1, 7 - nwp.col("b"), "both" ), - nwd.col("c") + nwp.col("c") .alias("c_right") - .is_between(nwd.col("c").mean() - 1, 7 - nwd.col("b"), "right"), + .is_between(nwp.col("c").mean() - 1, 7 - nwp.col("b"), "right"), ], { "left": [False, True, False], @@ -153,12 +153,12 @@ def _ids_ir(expr: nwd.Expr | Any) -> str: ), pytest.param( [ - nwd.col("e").fill_null(0).alias("e_0"), - nwd.col("e").fill_null(nwd.col("b")).alias("e_b"), - nwd.col("e").fill_null(nwd.col("b").last()).alias("e_b_last"), - nwd.col("e") + nwp.col("e").fill_null(0).alias("e_0"), + nwp.col("e").fill_null(nwp.col("b")).alias("e_b"), + nwp.col("e").fill_null(nwp.col("b").last()).alias("e_b_last"), + nwp.col("e") .sort(nulls_last=True) - .fill_null(nwd.col("d").last() - nwd.col("c")) + .fill_null(nwp.col("d").last() - nwp.col("c")) .alias("e_sort_wild"), ], { @@ -169,88 +169,88 @@ def _ids_ir(expr: nwd.Expr | Any) -> str: }, id="sort", ), - (nwd.col("e", "d").is_null().any(), {"e": [True], "d": [False]}), + (nwp.col("e", "d").is_null().any(), {"e": [True], "d": [False]}), ( - [(~nwd.col("e", "d").is_null()).all(), "b"], + [(~nwp.col("e", "d").is_null()).all(), "b"], {"e": [False, False, False], "d": [True, True, True], "b": [1, 2, 3]}, ), pytest.param( - nwd.when(d=8).then("c"), {"c": [9, None, 4]}, id="When-otherwise-none" + nwp.when(d=8).then("c"), {"c": [9, None, 4]}, id="When-otherwise-none" ), pytest.param( - nwd.when(nwd.col("e").is_null()) - .then(nwd.col("b") + nwd.col("c")) + nwp.when(nwp.col("e").is_null()) + .then(nwp.col("b") + nwp.col("c")) .otherwise(50), {"b": [10, 50, 50]}, id="When-otherwise-native-broadcast", ), pytest.param( - nwd.when(nwd.col("a") == nwd.lit("C")) - .then(nwd.lit("c")) - .when(nwd.col("a") == nwd.lit("D")) - .then(nwd.lit("d")) - .when(nwd.col("a") == nwd.lit("B")) - .then(nwd.lit("b")) - .when(nwd.col("a") == nwd.lit("A")) - .then(nwd.lit("a")) + nwp.when(nwp.col("a") == nwp.lit("C")) + .then(nwp.lit("c")) + .when(nwp.col("a") == nwp.lit("D")) + .then(nwp.lit("d")) + .when(nwp.col("a") == nwp.lit("B")) + .then(nwp.lit("b")) + .when(nwp.col("a") == nwp.lit("A")) + .then(nwp.lit("a")) .alias("A"), {"A": ["a", "b", "a"]}, id="When-then-x4", ), pytest.param( - nwd.when(nwd.col("c") > 5, b=1).then(999), + nwp.when(nwp.col("c") > 5, b=1).then(999), {"literal": [999, None, None]}, id="When-multiple-predicates", ), pytest.param( - nwd.when(nwd.col("b") == nwd.col("c"), nwd.col("d").mean() > nwd.col("d")) + nwp.when(nwp.col("b") == nwp.col("c"), nwp.col("d").mean() > nwp.col("d")) .then(123) - .when(nwd.lit(True), ~nwd.nth(4).is_null()) + .when(nwp.lit(True), ~nwp.nth(4).is_null()) .then(456) - .otherwise(nwd.col("c")), + .otherwise(nwp.col("c")), {"literal": [9, 123, 456]}, id="When-multiple-predicates-mixed-broadcast", ), pytest.param( - nwd.when(nwd.lit(True)).then("c"), + nwp.when(nwp.lit(True)).then("c"), {"c": [9, 2, 4]}, id="When-literal-then-column", ), pytest.param( - nwd.when(nwd.lit(True)).then(nwd.col("c").mean()), + nwp.when(nwp.lit(True)).then(nwp.col("c").mean()), {"c": [5.0]}, id="When-literal-then-agg", ), pytest.param( [ - nwd.when(nwd.lit(True)).then(nwd.col("e").last()), - nwd.col("b").sort(descending=True), + nwp.when(nwp.lit(True)).then(nwp.col("e").last()), + nwp.col("b").sort(descending=True), ], {"e": [7, 7, 7], "b": [3, 2, 1]}, id="When-literal-then-agg-broadcast", ), pytest.param( [ - nwd.all_horizontal( - nwd.col("b") < nwd.col("c"), - nwd.col("a") != nwd.lit("B"), - nwd.col("e").cast(nw.Boolean), - nwd.lit(True), + nwp.all_horizontal( + nwp.col("b") < nwp.col("c"), + nwp.col("a") != nwp.lit("B"), + nwp.col("e").cast(nw.Boolean), + nwp.lit(True), ), - nwd.nth(1).last().name.suffix("_last"), + nwp.nth(1).last().name.suffix("_last"), ], {"b": [None, False, True], "b_last": [3, 3, 3]}, id="all-horizontal-mixed-broadcast", ), pytest.param( [ - nwd.all_horizontal(nwd.lit(True), nwd.lit(True)).alias("a"), - nwd.all_horizontal(nwd.lit(False), nwd.lit(True)).alias("b"), - nwd.all_horizontal(nwd.lit(False), nwd.lit(False)).alias("c"), - nwd.all_horizontal(nwd.lit(None, nw.Boolean), nwd.lit(True)).alias("d"), - nwd.all_horizontal(nwd.lit(None, nw.Boolean), nwd.lit(False)).alias("e"), - nwd.all_horizontal( - nwd.lit(None, nw.Boolean), nwd.lit(None, nw.Boolean) + nwp.all_horizontal(nwp.lit(True), nwp.lit(True)).alias("a"), + nwp.all_horizontal(nwp.lit(False), nwp.lit(True)).alias("b"), + nwp.all_horizontal(nwp.lit(False), nwp.lit(False)).alias("c"), + nwp.all_horizontal(nwp.lit(None, nw.Boolean), nwp.lit(True)).alias("d"), + nwp.all_horizontal(nwp.lit(None, nw.Boolean), nwp.lit(False)).alias("e"), + nwp.all_horizontal( + nwp.lit(None, nw.Boolean), nwp.lit(None, nw.Boolean) ).alias("f"), ], { @@ -265,9 +265,9 @@ def _ids_ir(expr: nwd.Expr | Any) -> str: ), pytest.param( [ - nwd.any_horizontal("f", "g"), - nwd.any_horizontal("g", "h"), - nwd.any_horizontal(nwd.lit(False), nwd.col("g").last()).alias( + nwp.any_horizontal("f", "g"), + nwp.any_horizontal("g", "h"), + nwp.any_horizontal(nwp.lit(False), nwp.col("g").last()).alias( "False-False" ), ], @@ -280,9 +280,9 @@ def _ids_ir(expr: nwd.Expr | Any) -> str: ), pytest.param( [ - nwd.any_horizontal(nwd.lit(None, nw.Boolean), "i").alias("None-None"), - nwd.any_horizontal(nwd.lit(True), "i").alias("True-None"), - nwd.any_horizontal(nwd.lit(False), "i").alias("False-None"), + nwp.any_horizontal(nwp.lit(None, nw.Boolean), "i").alias("None-None"), + nwp.any_horizontal(nwp.lit(True), "i").alias("True-None"), + nwp.any_horizontal(nwp.lit(False), "i").alias("False-None"), ], { "None-None": [None, None, None], @@ -294,15 +294,15 @@ def _ids_ir(expr: nwd.Expr | Any) -> str: ), pytest.param( [ - nwd.col("b").alias("a"), - nwd.col("l").alias("b"), - nwd.col("m").alias("i"), - nwd.any_horizontal(nwd.sum("b", "l").cast(nw.Boolean)).alias("any"), - nwd.all_horizontal(nwd.sum("b", "l").cast(nw.Boolean)).alias("all"), - nwd.max_horizontal(nwd.sum("b"), nwd.sum("l")).alias("max"), - nwd.min_horizontal(nwd.sum("b"), nwd.sum("l")).alias("min"), - nwd.sum_horizontal(nwd.sum("b"), nwd.sum("l")).alias("sum"), - nwd.mean_horizontal(nwd.sum("b"), nwd.sum("l")).alias("mean"), + nwp.col("b").alias("a"), + nwp.col("l").alias("b"), + nwp.col("m").alias("i"), + nwp.any_horizontal(nwp.sum("b", "l").cast(nw.Boolean)).alias("any"), + nwp.all_horizontal(nwp.sum("b", "l").cast(nw.Boolean)).alias("all"), + nwp.max_horizontal(nwp.sum("b"), nwp.sum("l")).alias("max"), + nwp.min_horizontal(nwp.sum("b"), nwp.sum("l")).alias("min"), + nwp.sum_horizontal(nwp.sum("b"), nwp.sum("l")).alias("sum"), + nwp.mean_horizontal(nwp.sum("b"), nwp.sum("l")).alias("mean"), ], { "a": [1, 2, 3], @@ -318,39 +318,39 @@ def _ids_ir(expr: nwd.Expr | Any) -> str: id="sumh_broadcasting", ), pytest.param( - nwd.mean_horizontal("j", nwd.col("k"), "e"), + nwp.mean_horizontal("j", nwp.col("k"), "e"), {"j": [27.05, 9.5, 5.5]}, id="mean_horizontal-null", ), pytest.param( - nwd.sum_horizontal("j", nwd.col("k"), "e"), + nwp.sum_horizontal("j", nwp.col("k"), "e"), {"j": [54.1, 19.0, 11.0]}, id="sum_horizontal-null", ), pytest.param( - nwd.concat_str(nwd.col("b") * 2, "n", nwd.col("o"), separator=" "), + nwp.concat_str(nwp.col("b") * 2, "n", nwp.col("o"), separator=" "), {"b": ["2 dogs play", "4 cats swim", None]}, id="concat_str-preserve_nulls", ), pytest.param( - nwd.concat_str( - nwd.col("b") * 2, "n", nwd.col("o"), separator=" ", ignore_nulls=True + nwp.concat_str( + nwp.col("b") * 2, "n", nwp.col("o"), separator=" ", ignore_nulls=True ), {"b": ["2 dogs play", "4 cats swim", "6 walk"]}, id="concat_str-ignore_nulls", ), pytest.param( - nwd.concat_str("a", nwd.lit("a")), + nwp.concat_str("a", nwp.lit("a")), {"a": ["Aa", "Ba", "Aa"]}, id="concat_str-lit", ), pytest.param( - nwd.concat_str( - nwd.lit("a"), - nwd.lit("b"), - nwd.lit("c"), - nwd.lit("d"), - nwd.col("e").last() + 13, + nwp.concat_str( + nwp.lit("a"), + nwp.lit("b"), + nwp.lit("c"), + nwp.lit("d"), + nwp.col("e").last() + 13, separator="|", ), {"literal": ["a|b|c|d|20"]}, @@ -358,7 +358,7 @@ def _ids_ir(expr: nwd.Expr | Any) -> str: ), pytest.param( [ - nwd.col("a") + nwp.col("a") .alias("...") .map_batches( lambda s: s.from_iterable( @@ -368,13 +368,13 @@ def _ids_ir(expr: nwd.Expr | Any) -> str: ), is_elementwise=True, ), - nwd.col("a"), + nwp.col("a"), ], {"funky": ["string", "string", "last"], "a": ["A", "B", "A"]}, id="map_batches-series", ), pytest.param( - nwd.col("b") + nwp.col("b") .map_batches(lambda s: s.to_numpy() + 1, nw.Float64(), is_elementwise=True) .sum(), {"b": [9.0]}, @@ -388,7 +388,7 @@ def _ids_ir(expr: nwd.Expr | Any) -> str: id="map_batches-selector", ), pytest.param( - nwd.col("j", "k") + nwp.col("j", "k") .fill_null(15) .map_batches(lambda s: (s.to_numpy().max()), returns_scalar=True), {"j": [15], "k": [42]}, @@ -402,12 +402,12 @@ def _ids_ir(expr: nwd.Expr | Any) -> str: ids=_ids_ir, ) def test_select( - expr: nwd.Expr | Sequence[nwd.Expr], + expr: nwp.Expr | Sequence[nwp.Expr], expected: dict[str, Any], data_small: dict[str, Any], ) -> None: frame = pa.table(data_small) - df = nwd.DataFrame.from_native(frame) + df = nwp.DataFrame.from_native(frame) result = df.select(expr).to_dict(as_series=False) assert_equal_data(result, expected) @@ -416,7 +416,7 @@ def test_select( ("expr", "expected"), [ ( - ["d", nwd.col("a"), "b", nwd.col("e")], + ["d", nwp.col("a"), "b", nwp.col("e")], { "a": ["A", "B", "A"], "b": [1, 2, 3], @@ -439,9 +439,9 @@ def test_select( ), ( [ - nwd.col("e").fill_null(nwd.col("e").last()), - nwd.col("f").sort(), - nwd.nth(1).max(), + nwp.col("e").fill_null(nwp.col("e").last()), + nwp.col("f").sort(), + nwp.nth(1).max(), ], { "a": ["A", "B", "A"], @@ -454,11 +454,11 @@ def test_select( ), pytest.param( [ - nwd.col("a").alias("a?"), + nwp.col("a").alias("a?"), ndcs.by_name("a"), - nwd.col("b").cast(nw.Float64).name.suffix("_float"), - nwd.col("c").max() + 1, - nwd.sum_horizontal(1, "d", nwd.col("b"), nwd.lit(3)), + nwp.col("b").cast(nw.Float64).name.suffix("_float"), + nwp.col("c").max() + 1, + nwp.sum_horizontal(1, "d", nwp.col("b"), nwp.lit(3)), ], { "a": ["A", "B", "A"], @@ -476,22 +476,22 @@ def test_select( ], ) def test_with_columns( - expr: nwd.Expr | Sequence[nwd.Expr], + expr: nwp.Expr | Sequence[nwp.Expr], expected: dict[str, Any], data_smaller: dict[str, Any], ) -> None: frame = pa.table(data_smaller) - df = nwd.DataFrame.from_native(frame) + df = nwp.DataFrame.from_native(frame) result = df.with_columns(expr).to_dict(as_series=False) assert_equal_data(result, expected) -def first(*names: str) -> nwd.Expr: - return nwd.col(*names).first() +def first(*names: str) -> nwp.Expr: + return nwp.col(*names).first() -def last(*names: str) -> nwd.Expr: - return nwd.col(*names).last() +def last(*names: str) -> nwp.Expr: + return nwp.col(*names).last() @pytest.mark.parametrize( @@ -506,12 +506,12 @@ def last(*names: str) -> nwd.Expr: ], ) def test_first_last_expr_with_columns( - data_indexed: dict[str, Any], agg: nwd.Expr, expected: PythonLiteral + data_indexed: dict[str, Any], agg: nwp.Expr, expected: PythonLiteral ) -> None: """Related https://github.com/narwhals-dev/narwhals/pull/2528#discussion_r2225930065.""" height = len(next(iter(data_indexed.values()))) expected_broadcast = height * [expected] - frame = nwd.DataFrame.from_native(pa.table(data_indexed)) + frame = nwp.DataFrame.from_native(pa.table(data_indexed)) expr = agg.over(order_by="idx").alias("result") result = frame.with_columns(expr).select("result").to_dict(as_series=False) assert_equal_data(result, {"result": expected_broadcast}) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index d881a3a132..a80724ff86 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -6,7 +6,7 @@ import pytest import narwhals as nw -from narwhals import _plan as nwd +from narwhals import _plan as nwp from narwhals._plan import expressions as ir, selectors as ndcs from narwhals._plan._expansion import ( prepare_projection, @@ -52,9 +52,9 @@ def schema_1() -> dict[str, DType]: MULTI_OUTPUT_EXPRS = ( - pytest.param(nwd.col("a", "b", "c")), + pytest.param(nwp.col("a", "b", "c")), pytest.param(ndcs.numeric() - ndcs.matches("[d-j]")), - pytest.param(nwd.nth(0, 1, 2)), + pytest.param(nwp.nth(0, 1, 2)), pytest.param(ndcs.by_dtype(nw.Int64, nw.Int32, nw.Int16)), pytest.param(ndcs.by_name("a", "b", "c")), ) @@ -74,19 +74,19 @@ def udf_name_map(name: str) -> str: @pytest.mark.parametrize( ("expr", "expected"), [ - (nwd.col("a").name.to_uppercase(), "A"), - (nwd.col("B").name.to_lowercase(), "b"), - (nwd.col("c").name.suffix("_after"), "c_after"), - (nwd.col("d").name.prefix("before_"), "before_d"), + (nwp.col("a").name.to_uppercase(), "A"), + (nwp.col("B").name.to_lowercase(), "b"), + (nwp.col("c").name.suffix("_after"), "c_after"), + (nwp.col("d").name.prefix("before_"), "before_d"), ( - nwd.col("aBcD EFg hi").name.map(udf_name_map), + nwp.col("aBcD EFg hi").name.map(udf_name_map), "original='aBcD EFg hi' | upper='ABCD EFG HI' | lower='abcd efg hi' | title='Abcd Efg Hi'", ), - (nwd.col("a").min().alias("b").over("c").alias("d").max().name.keep(), "a"), + (nwp.col("a").min().alias("b").over("c").alias("d").max().name.keep(), "a"), ( ( - nwd.col("hello") - .sort_by(nwd.col("ignore me")) + nwp.col("hello") + .sort_by(nwp.col("ignore me")) .max() .over("ignore me as well") .first() @@ -96,7 +96,7 @@ def udf_name_map(name: str) -> str: ), ( ( - nwd.col("start") + nwp.col("start") .alias("next") .sort() .round() @@ -108,7 +108,7 @@ def udf_name_map(name: str) -> str: ), ], ) -def test_rewrite_special_aliases_single(expr: nwd.Expr, expected: str) -> None: +def test_rewrite_special_aliases_single(expr: nwp.Expr, expected: str) -> None: # NOTE: We can't use `output_name()` without resolving these rewrites # Once they're done, `output_name()` just peeks into `Alias(name=...)` ir_input = expr._ir @@ -152,33 +152,33 @@ def fn(e_ir: ir.ExprIR) -> ir.ExprIR: @pytest.mark.parametrize( ("expr", "function", "expected"), [ - (nwd.col("a"), alias_replace_guarded("never"), nwd.col("a")), - (nwd.col("a"), alias_replace_unguarded("never"), nwd.col("a")), - (nwd.col("a").alias("b"), alias_replace_guarded("c"), nwd.col("a").alias("c")), - (nwd.col("a").alias("b"), alias_replace_unguarded("c"), nwd.col("a").alias("c")), + (nwp.col("a"), alias_replace_guarded("never"), nwp.col("a")), + (nwp.col("a"), alias_replace_unguarded("never"), nwp.col("a")), + (nwp.col("a").alias("b"), alias_replace_guarded("c"), nwp.col("a").alias("c")), + (nwp.col("a").alias("b"), alias_replace_unguarded("c"), nwp.col("a").alias("c")), ( - nwd.col("a").alias("d").first().over("b", order_by="c").alias("e"), + nwp.col("a").alias("d").first().over("b", order_by="c").alias("e"), alias_replace_guarded("d"), - nwd.col("a").alias("d").first().over("b", order_by="c").alias("d"), + nwp.col("a").alias("d").first().over("b", order_by="c").alias("d"), ), ( - nwd.col("a").alias("d").first().over("b", order_by="c").alias("e"), + nwp.col("a").alias("d").first().over("b", order_by="c").alias("e"), alias_replace_unguarded("d"), - nwd.col("a").alias("d").first().over("b", order_by="c").alias("d"), + nwp.col("a").alias("d").first().over("b", order_by="c").alias("d"), ), ( - nwd.col("a").alias("e").abs().alias("f").sort().alias("g"), + nwp.col("a").alias("e").abs().alias("f").sort().alias("g"), alias_replace_guarded("e"), - nwd.col("a").alias("e").abs().alias("e").sort().alias("e"), + nwp.col("a").alias("e").abs().alias("e").sort().alias("e"), ), ( - nwd.col("a").alias("e").abs().alias("f").sort().alias("g"), + nwp.col("a").alias("e").abs().alias("f").sort().alias("g"), alias_replace_unguarded("e"), - nwd.col("a").alias("e").abs().alias("e").sort().alias("e"), + nwp.col("a").alias("e").abs().alias("e").sort().alias("e"), ), ], ) -def test_map_ir_recursive(expr: nwd.Expr, function: MapIR, expected: nwd.Expr) -> None: +def test_map_ir_recursive(expr: nwp.Expr, function: MapIR, expected: nwp.Expr) -> None: actual = expr._ir.map_ir(function) assert_expr_ir_equal(actual, expected) @@ -186,17 +186,17 @@ def test_map_ir_recursive(expr: nwd.Expr, function: MapIR, expected: nwd.Expr) - @pytest.mark.parametrize( ("expr", "expected"), [ - (nwd.col("a"), nwd.col("a")), - (nwd.col("a").max().alias("z"), nwd.col("a").max().alias("z")), + (nwp.col("a"), nwp.col("a")), + (nwp.col("a").max().alias("z"), nwp.col("a").max().alias("z")), (ndcs.string(), ir.Columns(names=("k",))), ( ndcs.by_dtype(nw.Datetime("ms"), nw.Date, nw.List(nw.String)), - nwd.col("n", "s"), + nwp.col("n", "s"), ), - (ndcs.string() | ndcs.boolean(), nwd.col("k", "m")), + (ndcs.string() | ndcs.boolean(), nwp.col("k", "m")), ( ~(ndcs.numeric() | ndcs.string()), - nwd.col("l", "m", "n", "o", "p", "q", "r", "s", "u"), + nwp.col("l", "m", "n", "o", "p", "q", "r", "s", "u"), ), ( ( @@ -204,14 +204,14 @@ def test_map_ir_recursive(expr: nwd.Expr, function: MapIR, expected: nwd.Expr) - - (ndcs.categorical() | ndcs.by_name("a", "b") | ndcs.matches("[fqohim]")) ^ ndcs.by_name("u", "a", "b", "d", "e", "f", "g") ).name.suffix("_after"), - nwd.col("a", "b", "c", "f", "j", "k", "l", "n", "r", "s").name.suffix( + nwp.col("a", "b", "c", "f", "j", "k", "l", "n", "r", "s").name.suffix( "_after" ), ), ( (ndcs.matches("[a-m]") & ~ndcs.numeric()).sort(nulls_last=True).first() - != nwd.lit(None), - nwd.col("k", "l", "m").sort(nulls_last=True).first() != nwd.lit(None), + != nwp.lit(None), + nwp.col("k", "l", "m").sort(nulls_last=True).first() != nwp.lit(None), ), ( ( @@ -220,9 +220,9 @@ def test_map_ir_recursive(expr: nwd.Expr, function: MapIR, expected: nwd.Expr) - .over("k", order_by=ndcs.by_dtype(nw.Date()) | ndcs.boolean()) ), ( - nwd.col("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + nwp.col("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") .mean() - .over(nwd.col("k"), order_by=nwd.col("m", "n")) + .over(nwp.col("k"), order_by=nwp.col("m", "n")) ), ), ( @@ -235,10 +235,10 @@ def test_map_ir_recursive(expr: nwd.Expr, function: MapIR, expected: nwd.Expr) - .name.to_uppercase() ), ( - nwd.col("l", "o") + nwp.col("l", "o") .dt.timestamp("us") .min() - .over(nwd.col("k", "m")) + .over(nwp.col("k", "m")) .last() .name.to_uppercase() ), @@ -246,8 +246,8 @@ def test_map_ir_recursive(expr: nwd.Expr, function: MapIR, expected: nwd.Expr) - ], ) def test_replace_selector( - expr: nwd.Selector | nwd.Expr, - expected: nwd.Expr | ir.ExprIR, + expr: nwp.Selector | nwp.Expr, + expected: nwp.Expr | ir.ExprIR, schema_1: dict[str, DType], ) -> None: actual = replace_selector(expr._ir, schema=freeze_schema(**schema_1)) @@ -257,41 +257,41 @@ def test_replace_selector( @pytest.mark.parametrize( ("into_exprs", "expected"), [ - ("a", [nwd.col("a")]), - (nwd.col("b", "c", "d"), [nwd.col("b"), nwd.col("c"), nwd.col("d")]), - (nwd.nth(6), [nwd.col("g")]), - (nwd.nth(9, 8, -5), [nwd.col("j"), nwd.col("i"), nwd.col("p")]), + ("a", [nwp.col("a")]), + (nwp.col("b", "c", "d"), [nwp.col("b"), nwp.col("c"), nwp.col("d")]), + (nwp.nth(6), [nwp.col("g")]), + (nwp.nth(9, 8, -5), [nwp.col("j"), nwp.col("i"), nwp.col("p")]), ( - [nwd.nth(2).alias("c again"), nwd.nth(-1, -2).name.to_uppercase()], + [nwp.nth(2).alias("c again"), nwp.nth(-1, -2).name.to_uppercase()], [ - nwd.col("c").alias("c again"), - nwd.col("u").alias("U"), - nwd.col("s").alias("S"), + nwp.col("c").alias("c again"), + nwp.col("u").alias("U"), + nwp.col("s").alias("S"), ], ), ( - nwd.all(), + nwp.all(), [ - nwd.col("a"), - nwd.col("b"), - nwd.col("c"), - nwd.col("d"), - nwd.col("e"), - nwd.col("f"), - nwd.col("g"), - nwd.col("h"), - nwd.col("i"), - nwd.col("j"), - nwd.col("k"), - nwd.col("l"), - nwd.col("m"), - nwd.col("n"), - nwd.col("o"), - nwd.col("p"), - nwd.col("q"), - nwd.col("r"), - nwd.col("s"), - nwd.col("u"), + nwp.col("a"), + nwp.col("b"), + nwp.col("c"), + nwp.col("d"), + nwp.col("e"), + nwp.col("f"), + nwp.col("g"), + nwp.col("h"), + nwp.col("i"), + nwp.col("j"), + nwp.col("k"), + nwp.col("l"), + nwp.col("m"), + nwp.col("n"), + nwp.col("o"), + nwp.col("p"), + nwp.col("q"), + nwp.col("r"), + nwp.col("s"), + nwp.col("u"), ], ), ( @@ -300,21 +300,21 @@ def test_replace_selector( .mean() .name.suffix("_mean"), [ - nwd.col("a").cast(nw.Int64()).mean().alias("a_mean"), - nwd.col("b").cast(nw.Int64()).mean().alias("b_mean"), - nwd.col("c").cast(nw.Int64()).mean().alias("c_mean"), - nwd.col("d").cast(nw.Int64()).mean().alias("d_mean"), - nwd.col("e").cast(nw.Int64()).mean().alias("e_mean"), - nwd.col("f").cast(nw.Int64()).mean().alias("f_mean"), - nwd.col("g").cast(nw.Int64()).mean().alias("g_mean"), - nwd.col("h").cast(nw.Int64()).mean().alias("h_mean"), + nwp.col("a").cast(nw.Int64()).mean().alias("a_mean"), + nwp.col("b").cast(nw.Int64()).mean().alias("b_mean"), + nwp.col("c").cast(nw.Int64()).mean().alias("c_mean"), + nwp.col("d").cast(nw.Int64()).mean().alias("d_mean"), + nwp.col("e").cast(nw.Int64()).mean().alias("e_mean"), + nwp.col("f").cast(nw.Int64()).mean().alias("f_mean"), + nwp.col("g").cast(nw.Int64()).mean().alias("g_mean"), + nwp.col("h").cast(nw.Int64()).mean().alias("h_mean"), ], ), ( - nwd.col("u").alias("1").alias("2").alias("3").alias("4").name.keep(), + nwp.col("u").alias("1").alias("2").alias("3").alias("4").name.keep(), # NOTE: Would be nice to rewrite with less intermediate steps # but retrieving the root name is enough for now - [nwd.col("u").alias("1").alias("2").alias("3").alias("4").alias("u")], + [nwp.col("u").alias("1").alias("2").alias("3").alias("4").alias("u")], ), ( ( @@ -322,30 +322,30 @@ def test_replace_selector( * 100 ).name.suffix("_mult_100"), [ - (nwd.col("e") * nwd.lit(100)).alias("e_mult_100"), - (nwd.col("h") * nwd.lit(100)).alias("h_mult_100"), - (nwd.col("j") * nwd.lit(100)).alias("j_mult_100"), + (nwp.col("e") * nwp.lit(100)).alias("e_mult_100"), + (nwp.col("h") * nwp.lit(100)).alias("h_mult_100"), + (nwp.col("j") * nwp.lit(100)).alias("j_mult_100"), ], ), ( ndcs.by_dtype(nw.Duration()) .dt.total_minutes() .name.map(lambda nm: f"total_mins: {nm!r} ?"), - [nwd.col("q").dt.total_minutes().alias("total_mins: 'q' ?")], + [nwp.col("q").dt.total_minutes().alias("total_mins: 'q' ?")], ), ( - nwd.col("f", "g") + nwp.col("f", "g") .cast(nw.String) .str.starts_with("1") .all() .name.suffix("_all_starts_with_1"), [ - nwd.col("f") + nwp.col("f") .cast(nw.String) .str.starts_with("1") .all() .alias("f_all_starts_with_1"), - nwd.col("g") + nwp.col("g") .cast(nw.String) .str.starts_with("1") .all() @@ -353,66 +353,66 @@ def test_replace_selector( ], ), ( - nwd.col("a", "b") + nwp.col("a", "b") .first() .over("c", "e", order_by="d") .name.suffix("_first_over_part_order_1"), [ - nwd.col("a") + nwp.col("a") .first() - .over(nwd.col("c"), nwd.col("e"), order_by=[nwd.col("d")]) + .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]) .alias("a_first_over_part_order_1"), - nwd.col("b") + nwp.col("b") .first() - .over(nwd.col("c"), nwd.col("e"), order_by=[nwd.col("d")]) + .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]) .alias("b_first_over_part_order_1"), ], ), ( - nwd.exclude(BIG_EXCLUDE), + nwp.exclude(BIG_EXCLUDE), [ - nwd.col("c"), - nwd.col("d"), - nwd.col("f"), - nwd.col("g"), - nwd.col("h"), - nwd.col("i"), - nwd.col("j"), + nwp.col("c"), + nwp.col("d"), + nwp.col("f"), + nwp.col("g"), + nwp.col("h"), + nwp.col("i"), + nwp.col("j"), ], ), ( - nwd.exclude(BIG_EXCLUDE).name.suffix("_2"), + nwp.exclude(BIG_EXCLUDE).name.suffix("_2"), [ - nwd.col("c").alias("c_2"), - nwd.col("d").alias("d_2"), - nwd.col("f").alias("f_2"), - nwd.col("g").alias("g_2"), - nwd.col("h").alias("h_2"), - nwd.col("i").alias("i_2"), - nwd.col("j").alias("j_2"), + nwp.col("c").alias("c_2"), + nwp.col("d").alias("d_2"), + nwp.col("f").alias("f_2"), + nwp.col("g").alias("g_2"), + nwp.col("h").alias("h_2"), + nwp.col("i").alias("i_2"), + nwp.col("j").alias("j_2"), ], ), ( - nwd.col("c").alias("c_min_over_order_by").min().over(order_by=ndcs.string()), + nwp.col("c").alias("c_min_over_order_by").min().over(order_by=ndcs.string()), [ - nwd.col("c") + nwp.col("c") .alias("c_min_over_order_by") .min() - .over(order_by=[nwd.col("k")]) + .over(order_by=[nwp.col("k")]) ], ), pytest.param( - (ndcs.by_name("a", "b", "c") / nwd.col("e").first()) + (ndcs.by_name("a", "b", "c") / nwp.col("e").first()) .over("g", "f", order_by="f") .name.prefix("hi_"), [ - (nwd.col("a") / nwd.col("e").first()) + (nwp.col("a") / nwp.col("e").first()) .over("g", "f", order_by="f") .alias("hi_a"), - (nwd.col("b") / nwd.col("e").first()) + (nwp.col("b") / nwp.col("e").first()) .over("g", "f", order_by="f") .alias("hi_b"), - (nwd.col("c") / nwd.col("e").first()) + (nwp.col("c") / nwp.col("e").first()) .over("g", "f", order_by="f") .alias("hi_c"), ], @@ -422,7 +422,7 @@ def test_replace_selector( ) def test_prepare_projection( into_exprs: IntoExpr | Sequence[IntoExpr], - expected: Sequence[nwd.Expr], + expected: Sequence[nwp.Expr], schema_1: dict[str, DType], ) -> None: irs_in = parse_into_seq_of_expr_ir(into_exprs) @@ -435,19 +435,19 @@ def test_prepare_projection( @pytest.mark.parametrize( "expr", [ - nwd.all(), - nwd.nth(1, 2, 3), - nwd.col("a", "b", "c"), + nwp.all(), + nwp.nth(1, 2, 3), + nwp.col("a", "b", "c"), ndcs.boolean() | ndcs.categorical(), (ndcs.by_name("a", "b") | ndcs.string()), - (nwd.col("b", "c") & nwd.col("a")), - nwd.col("a", "b").min().over("c", order_by="e"), + (nwp.col("b", "c") & nwp.col("a")), + nwp.col("a", "b").min().over("c", order_by="e"), (~ndcs.by_dtype(nw.Int64()) - ndcs.datetime()), - nwd.nth(6, 2).abs().cast(nw.Int32()) + 10, + nwp.nth(6, 2).abs().cast(nw.Int32()) + 10, *MULTI_OUTPUT_EXPRS, ], ) -def test_prepare_projection_duplicate(expr: nwd.Expr, schema_1: dict[str, DType]) -> None: +def test_prepare_projection_duplicate(expr: nwp.Expr, schema_1: dict[str, DType]) -> None: irs = parse_into_seq_of_expr_ir(expr.alias("dupe")) pattern = re.compile(r"\.alias\(.dupe.\)") with pytest.raises(DuplicateError, match=pattern): @@ -457,12 +457,12 @@ def test_prepare_projection_duplicate(expr: nwd.Expr, schema_1: dict[str, DType] @pytest.mark.parametrize( ("into_exprs", "missing"), [ - ([nwd.col("y", "z")], ["y", "z"]), - ([nwd.col("a", "b", "z")], ["z"]), - ([nwd.col("x", "b", "a")], ["x"]), + ([nwp.col("y", "z")], ["y", "z"]), + ([nwp.col("a", "b", "z")], ["z"]), + ([nwp.col("x", "b", "a")], ["x"]), ( [ - nwd.col( + nwp.col( [ "a", "b", @@ -491,18 +491,18 @@ def test_prepare_projection_duplicate(expr: nwd.Expr, schema_1: dict[str, DType] ["FIVE"], ), ( - [nwd.col("a").min().over("c").alias("y"), nwd.col("one").alias("b").last()], + [nwp.col("a").min().over("c").alias("y"), nwp.col("one").alias("b").last()], ["one"], ), - ([nwd.col("a").sort_by("b", "who").alias("f")], ["who"]), + ([nwp.col("a").sort_by("b", "who").alias("f")], ["who"]), ( [ - nwd.nth(0, 5) + nwp.nth(0, 5) .cast(nw.Int64()) .abs() .cum_sum() .over("X", "O", "h", "m", "r", "zee"), - nwd.col("d", "j"), + nwp.col("d", "j"), "n", ], ["O", "X", "zee"], @@ -525,29 +525,29 @@ def test_prepare_projection_column_not_found( [ ("a", "b", "c"), (["a", "b", "c"]), - ("a", "b", nwd.col("c")), - (nwd.col("a"), "b", "c"), - (nwd.col("a", "b"), "c"), - ("a", nwd.col("b", "c")), - ((nwd.nth(0), nwd.nth(1, 2))), + ("a", "b", nwp.col("c")), + (nwp.col("a"), "b", "c"), + (nwp.col("a", "b"), "c"), + ("a", nwp.col("b", "c")), + ((nwp.nth(0), nwp.nth(1, 2))), *MULTI_OUTPUT_EXPRS, ], ) @pytest.mark.parametrize( "function", [ - nwd.all_horizontal, - nwd.any_horizontal, - nwd.sum_horizontal, - nwd.min_horizontal, - nwd.max_horizontal, - nwd.mean_horizontal, - nwd.concat_str, + nwp.all_horizontal, + nwp.any_horizontal, + nwp.sum_horizontal, + nwp.min_horizontal, + nwp.max_horizontal, + nwp.mean_horizontal, + nwp.concat_str, ], ) def test_prepare_projection_horizontal_alias( into_exprs: IntoExpr | Iterable[IntoExpr], - function: Callable[..., nwd.Expr], + function: Callable[..., nwp.Expr], schema_1: dict[str, DType], ) -> None: # NOTE: See https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2139965411 @@ -566,7 +566,7 @@ def test_prepare_projection_horizontal_alias( @pytest.mark.parametrize( - "into_exprs", [nwd.nth(-21), nwd.nth(-1, 2, 54, 0), nwd.nth(20), nwd.nth([-10, -100])] + "into_exprs", [nwp.nth(-21), nwp.nth(-1, 2, 54, 0), nwp.nth(20), nwp.nth([-10, -100])] ) def test_prepare_projection_index_error( into_exprs: IntoExpr | Iterable[IntoExpr], schema_1: dict[str, DType] diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index ad10e3a7cb..5b525001a9 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -10,7 +10,7 @@ import pytest import narwhals as nw -from narwhals import _plan as nwd +from narwhals import _plan as nwp from narwhals._plan import expressions as ir from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.expressions import functions as F, operators as ops @@ -39,15 +39,15 @@ @pytest.mark.parametrize( ("exprs", "named_exprs"), [ - ([nwd.col("a")], {}), + ([nwp.col("a")], {}), (["a"], {}), ([], {"a": "b"}), - ([], {"a": nwd.col("b")}), - (["a", "b", nwd.col("c", "d", "e")], {"g": nwd.lit(1)}), - ([["a", "b", "c"]], {"q": nwd.lit(5, nw.Int8())}), + ([], {"a": nwp.col("b")}), + (["a", "b", nwp.col("c", "d", "e")], {"g": nwp.lit(1)}), + ([["a", "b", "c"]], {"q": nwp.lit(5, nw.Int8())}), ( - [[nwd.nth(1), nwd.nth(2, 3, 4)]], - {"n": nwd.col("p").count(), "other n": nwd.len()}, + [[nwp.nth(1), nwp.nth(2, 3, 4)]], + {"n": nwp.col("p").count(), "other n": nwp.len()}, ), ], ) @@ -63,12 +63,12 @@ def test_parsing( @pytest.mark.parametrize( ("function", "ir_node"), [ - (nwd.all_horizontal, ir.boolean.AllHorizontal), - (nwd.any_horizontal, ir.boolean.AnyHorizontal), - (nwd.sum_horizontal, F.SumHorizontal), - (nwd.min_horizontal, F.MinHorizontal), - (nwd.max_horizontal, F.MaxHorizontal), - (nwd.mean_horizontal, F.MeanHorizontal), + (nwp.all_horizontal, ir.boolean.AllHorizontal), + (nwp.any_horizontal, ir.boolean.AnyHorizontal), + (nwp.sum_horizontal, F.SumHorizontal), + (nwp.min_horizontal, F.MinHorizontal), + (nwp.max_horizontal, F.MaxHorizontal), + (nwp.mean_horizontal, F.MeanHorizontal), ], ) @pytest.mark.parametrize( @@ -76,23 +76,23 @@ def test_parsing( [ ("a", "b", "c"), (["a", "b", "c"]), - (nwd.col("d", "e", "f"), nwd.col("g"), "q", nwd.nth(9)), - ((nwd.lit(1),)), - ([nwd.lit(1), nwd.lit(2, nw.Int64), nwd.lit(3, nw.Int64())]), + (nwp.col("d", "e", "f"), nwp.col("g"), "q", nwp.nth(9)), + ((nwp.lit(1),)), + ([nwp.lit(1), nwp.lit(2, nw.Int64), nwp.lit(3, nw.Int64())]), ], ) def test_function_expr_horizontal( - function: Callable[..., nwd.Expr], + function: Callable[..., nwp.Expr], ir_node: type[Function], args: Seq[IntoExpr | Iterable[IntoExpr]], ) -> None: variadic = function(*args) sequence = function(args) - assert isinstance(variadic, nwd.Expr) - assert isinstance(sequence, nwd.Expr) + assert isinstance(variadic, nwp.Expr) + assert isinstance(sequence, nwp.Expr) variadic_node = variadic._ir sequence_node = sequence._ir - unrelated_node = nwd.lit(1)._ir + unrelated_node = nwp.lit(1)._ir assert isinstance(variadic_node, ir.FunctionExpr) assert isinstance(variadic_node.function, ir_node) assert variadic_node == sequence_node @@ -105,7 +105,7 @@ def test_valid_windows() -> None: https://github.com/narwhals-dev/narwhals/blob/63c8e4771a1df4e0bfeea5559c303a4a447d5cc2/tests/expression_parsing_test.py#L10-L45 """ ELEMENTWISE_ERR = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) # noqa: N806 - a = nwd.col("a") + a = nwp.col("a") assert a.cum_sum() assert a.cum_sum().over(order_by="id") with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): @@ -114,32 +114,32 @@ def test_valid_windows() -> None: assert (a.cum_sum() + 1).over(order_by="id") assert a.cum_sum().cum_sum().over(order_by="id") assert a.cum_sum().cum_sum() - assert nwd.sum_horizontal(a, a.cum_sum()) + assert nwp.sum_horizontal(a, a.cum_sum()) with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): - assert nwd.sum_horizontal(a, a.cum_sum()).over(order_by="a") + assert nwp.sum_horizontal(a, a.cum_sum()).over(order_by="a") - assert nwd.sum_horizontal(a, a.cum_sum().over(order_by="i")) - assert nwd.sum_horizontal(a.diff(), a.cum_sum().over(order_by="i")) + assert nwp.sum_horizontal(a, a.cum_sum().over(order_by="i")) + assert nwp.sum_horizontal(a.diff(), a.cum_sum().over(order_by="i")) with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): - assert nwd.sum_horizontal(a.diff(), a.cum_sum()).over(order_by="i") + assert nwp.sum_horizontal(a.diff(), a.cum_sum()).over(order_by="i") with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): - assert nwd.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i") + assert nwp.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i") def test_invalid_repeat_agg() -> None: with pytest.raises(InvalidOperationError): - nwd.col("a").mean().mean() + nwp.col("a").mean().mean() with pytest.raises(InvalidOperationError): - nwd.col("a").first().max() + nwp.col("a").first().max() with pytest.raises(InvalidOperationError): - nwd.col("a").any().std() + nwp.col("a").any().std() with pytest.raises(InvalidOperationError): - nwd.col("a").all().quantile(0.5, "linear") + nwp.col("a").all().quantile(0.5, "linear") with pytest.raises(InvalidOperationError): - nwd.col("a").arg_max().min() + nwp.col("a").arg_max().min() with pytest.raises(InvalidOperationError): - nwd.col("a").arg_min().arg_max() + nwp.col("a").arg_min().arg_max() # NOTE: Previously multiple different errors, but they can be reduced to the same thing @@ -147,19 +147,19 @@ def test_invalid_repeat_agg() -> None: def test_invalid_agg_non_elementwise() -> None: pattern = re.compile(r"cannot use.+rank.+aggregated.+mean", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").mean().rank() + nwp.col("a").mean().rank() pattern = re.compile(r"cannot use.+drop_nulls.+aggregated.+max", re.IGNORECASE) with pytest.raises(InvalidOperationError): - nwd.col("a").max().drop_nulls() + nwp.col("a").max().drop_nulls() pattern = re.compile(r"cannot use.+diff.+aggregated.+min", re.IGNORECASE) with pytest.raises(InvalidOperationError): - nwd.col("a").min().diff() + nwp.col("a").min().diff() def test_agg_non_elementwise_range_special() -> None: - e = nwd.int_range(0, 100) + e = nwp.int_range(0, 100) assert isinstance(e._ir, ir.RangeExpr) - e = nwd.int_range(nwd.len(), dtype=nw.UInt32).alias("index") + e = nwp.int_range(nwp.len(), dtype=nw.UInt32).alias("index") e_ir = e._ir assert isinstance(e_ir, ir.Alias) assert isinstance(e_ir.expr, ir.RangeExpr) @@ -170,28 +170,28 @@ def test_agg_non_elementwise_range_special() -> None: def test_invalid_int_range() -> None: pattern = re.compile(r"scalar.+agg", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): - nwd.int_range(nwd.col("a")) + nwp.int_range(nwp.col("a")) with pytest.raises(InvalidOperationError, match=pattern): - nwd.int_range(nwd.nth(1), 10) + nwp.int_range(nwp.nth(1), 10) with pytest.raises(InvalidOperationError, match=pattern): - nwd.int_range(0, nwd.col("a").abs()) + nwp.int_range(0, nwp.col("a").abs()) with pytest.raises(InvalidOperationError, match=pattern): - nwd.int_range(nwd.col("a") + 1) + nwp.int_range(nwp.col("a") + 1) # NOTE: Non-`polars`` rule def test_invalid_over() -> None: pattern = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").fill_null(3).over("b") + nwp.col("a").fill_null(3).over("b") def test_nested_over() -> None: pattern = re.compile(r"cannot nest.+over", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").mean().over("b").over("c") + nwp.col("a").mean().over("b").over("c") with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").mean().over("b").over("c", order_by="i") + nwp.col("a").mean().over("b").over("c", order_by="i") # NOTE: This *can* error in polars, but only if the length **actually changes** @@ -199,36 +199,36 @@ def test_nested_over() -> None: def test_filtration_over() -> None: pattern = re.compile(r"cannot use.+over.+change length", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").drop_nulls().over("b") + nwp.col("a").drop_nulls().over("b") with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").drop_nulls().over("b", order_by="i") + nwp.col("a").drop_nulls().over("b", order_by="i") with pytest.raises(InvalidOperationError, match=pattern): - nwd.col("a").diff().drop_nulls().over("b", order_by="i") + nwp.col("a").diff().drop_nulls().over("b", order_by="i") def test_invalid_binary_expr_multi() -> None: pattern = re.escape("all() + cols(['b', 'c'])\n ^^^^^^^^^^^^^^^^") with pytest.raises(MultiOutputExpressionError, match=pattern): - nwd.all() + nwd.col("b", "c") + nwp.all() + nwp.col("b", "c") pattern = re.escape( "index_columns((1, 2, 3)) * index_columns((4, 5, 6)).max()\n" " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" ) with pytest.raises(MultiOutputExpressionError, match=pattern): - nwd.nth(1, 2, 3) * nwd.nth(4, 5, 6).max() + nwp.nth(1, 2, 3) * nwp.nth(4, 5, 6).max() pattern = re.escape( "cols(['a', 'b', 'c']).abs().fill_null([lit(int: 0)]).round() * index_columns((9, 10)).cast(Int64).sort(asc)\n" " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" ) with pytest.raises(MultiOutputExpressionError, match=pattern): - nwd.col("a", "b", "c").abs().fill_null(0).round(2) * nwd.nth(9, 10).cast( + nwp.col("a", "b", "c").abs().fill_null(0).round(2) * nwp.nth(9, 10).cast( nw.Int64() ).sort() def test_invalid_binary_expr_length_changing() -> None: - a = nwd.col("a") - b = nwd.col("b") + a = nwp.col("a") + b = nwp.col("b") with pytest.raises(LengthChangingExprError): a.unique() + b.unique() @@ -246,13 +246,13 @@ def test_invalid_binary_expr_length_changing() -> None: a.map_batches(lambda x: x) / b.gather_every(1, 0) -def _is_expr_ir_binary_expr(expr: nwd.Expr) -> bool: +def _is_expr_ir_binary_expr(expr: nwp.Expr) -> bool: return isinstance(expr._ir, ir.BinaryExpr) def test_binary_expr_length_changing_agg() -> None: - a = nwd.col("a") - b = nwd.col("b") + a = nwp.col("a") + b = nwp.col("b") assert _is_expr_ir_binary_expr(a.unique().first() + b.unique()) assert _is_expr_ir_binary_expr(a.mode().last() * b.unique()) @@ -272,8 +272,8 @@ def test_invalid_binary_expr_shape() -> None: re.escape("Cannot combine length-changing expressions with length-preserving"), re.IGNORECASE, ) - a = nwd.col("a") - b = nwd.col("b") + a = nwp.col("a") + b = nwp.col("b") with pytest.raises(ShapeError, match=pattern): a.unique() + b @@ -287,7 +287,7 @@ def test_invalid_binary_expr_shape() -> None: def test_is_in_seq(into_iter: IntoIterable) -> None: expected = 1, 2, 3 other = into_iter(list(expected)) - expr = nwd.col("a").is_in(other) + expr = nwp.col("a").is_in(other) e_ir = expr._ir assert isinstance(e_ir, ir.FunctionExpr) assert isinstance(e_ir.function, ir.boolean.IsInSeq) @@ -299,8 +299,8 @@ def test_is_in_series() -> None: import pyarrow as pa native = pa.chunked_array([pa.array([1, 2, 3])]) - other = nwd.Series.from_native(native) - expr = nwd.col("a").is_in(other) + other = nwp.Series.from_native(native) + expr = nwp.col("a").is_in(other) e_ir = expr._ir assert isinstance(e_ir, ir.FunctionExpr) assert isinstance(e_ir.function, ir.boolean.IsInSeries) @@ -313,7 +313,7 @@ def test_is_in_series() -> None: ("words", pytest.raises(TypeError, match=r"str \| bytes.+str")), (b"words", pytest.raises(TypeError, match=r"str \| bytes.+bytes")), ( - nwd.col("b"), + nwp.col("b"), pytest.raises( NotImplementedError, match=re.compile(r"iterable instead", re.IGNORECASE) ), @@ -328,19 +328,19 @@ def test_is_in_series() -> None: ) def test_invalid_is_in(other: Any, context: AbstractContextManager[Any]) -> None: with context: - nwd.col("a").is_in(other) + nwp.col("a").is_in(other) def test_filter_full_spellings() -> None: - a = nwd.col("a") - b = nwd.col("b") - c = nwd.col("c") - d = nwd.col("d") - expected = a.filter(b != b.max(), c < nwd.lit(2), d == nwd.lit(5)) - expr_1 = a.filter([b != b.max(), c < nwd.lit(2), d == nwd.lit(5)]) - expr_2 = a.filter([b != b.max(), c < nwd.lit(2)], d=nwd.lit(5)) - expr_3 = a.filter([b != b.max(), c < nwd.lit(2)], d=5) - expr_4 = a.filter(b != b.max(), c < nwd.lit(2), d=5) + a = nwp.col("a") + b = nwp.col("b") + c = nwp.col("c") + d = nwp.col("d") + expected = a.filter(b != b.max(), c < nwp.lit(2), d == nwp.lit(5)) + expr_1 = a.filter([b != b.max(), c < nwp.lit(2), d == nwp.lit(5)]) + expr_2 = a.filter([b != b.max(), c < nwp.lit(2)], d=nwp.lit(5)) + expr_3 = a.filter([b != b.max(), c < nwp.lit(2)], d=5) + expr_4 = a.filter(b != b.max(), c < nwp.lit(2), d=5) expr_5 = a.filter(b != b.max(), c < 2, d=5) expr_6 = a.filter((b != b.max(), c < 2), d=5) assert_expr_ir_equal(expected, expr_1) @@ -354,9 +354,9 @@ def test_filter_full_spellings() -> None: @pytest.mark.parametrize( ("predicates", "constraints", "context"), [ - ([nwd.col("b").is_last_distinct()], {}, nullcontext()), + ([nwp.col("b").is_last_distinct()], {}, nullcontext()), ((), {"b": 10}, nullcontext()), - ((), {"b": nwd.lit(10)}, nullcontext()), + ((), {"b": nwp.lit(10)}, nullcontext()), ( (), {}, @@ -364,9 +364,9 @@ def test_filter_full_spellings() -> None: TypeError, match=re.compile(r"at least one predicate", re.IGNORECASE) ), ), - ((nwd.col("b") > 1, nwd.col("c").is_null()), {}, nullcontext()), + ((nwp.col("b") > 1, nwp.col("c").is_null()), {}, nullcontext()), ( - ([nwd.col("b") > 1], nwd.col("c").is_null()), + ([nwp.col("b") > 1], nwp.col("c").is_null()), {}, pytest.raises( InvalidIntoExprError, @@ -383,7 +383,7 @@ def test_filter_partial_spellings( context: AbstractContextManager[Any], ) -> None: with context: - assert nwd.col("a").filter(*predicates, **constraints) + assert nwp.col("a").filter(*predicates, **constraints) def test_lit_series_roundtrip() -> None: @@ -392,15 +392,15 @@ def test_lit_series_roundtrip() -> None: data = ["a", "b", "c"] native = pa.chunked_array([pa.array(data)]) - series = nwd.Series.from_native(native) - lit_series = nwd.lit(series) + series = nwp.Series.from_native(native) + lit_series = nwp.lit(series) assert lit_series.meta.is_literal() e_ir = lit_series._ir assert isinstance(e_ir, ir.Literal) assert isinstance(e_ir.dtype, nw.String) assert isinstance(e_ir.value, SeriesLiteral) unwrapped = e_ir.unwrap() - assert isinstance(unwrapped, nwd.Series) + assert isinstance(unwrapped, nwp.Series) assert isinstance(unwrapped.to_native(), pa.ChunkedArray) assert unwrapped.to_list() == data @@ -408,24 +408,24 @@ def test_lit_series_roundtrip() -> None: @pytest.mark.parametrize( ("arg_1", "arg_2", "function", "op"), [ - (nwd.col("a"), 1, operator.eq, ops.Eq), - (nwd.col("a"), "b", operator.eq, ops.Eq), - (nwd.col("a"), 1, operator.ne, ops.NotEq), - (nwd.col("a"), "b", operator.ne, ops.NotEq), - (nwd.col("a"), "b", operator.ge, ops.GtEq), - (nwd.col("a"), "b", operator.gt, ops.Gt), - (nwd.col("a"), "b", operator.le, ops.LtEq), - (nwd.col("a"), "b", operator.lt, ops.Lt), - ((nwd.col("a") != 1), False, operator.and_, ops.And), - ((nwd.col("a") != 1), False, operator.or_, ops.Or), - ((nwd.col("a")), True, operator.xor, ops.ExclusiveOr), - (nwd.col("a"), 6, operator.add, ops.Add), - (nwd.col("a"), 2.1, operator.mul, ops.Multiply), - (nwd.col("a"), nwd.col("b"), operator.sub, ops.Sub), - (nwd.col("a"), 2, operator.pow, F.Pow), - (nwd.col("a"), 2, operator.mod, ops.Modulus), - (nwd.col("a"), 2, operator.floordiv, ops.FloorDivide), - (nwd.col("a"), 4, operator.truediv, ops.TrueDivide), + (nwp.col("a"), 1, operator.eq, ops.Eq), + (nwp.col("a"), "b", operator.eq, ops.Eq), + (nwp.col("a"), 1, operator.ne, ops.NotEq), + (nwp.col("a"), "b", operator.ne, ops.NotEq), + (nwp.col("a"), "b", operator.ge, ops.GtEq), + (nwp.col("a"), "b", operator.gt, ops.Gt), + (nwp.col("a"), "b", operator.le, ops.LtEq), + (nwp.col("a"), "b", operator.lt, ops.Lt), + ((nwp.col("a") != 1), False, operator.and_, ops.And), + ((nwp.col("a") != 1), False, operator.or_, ops.Or), + ((nwp.col("a")), True, operator.xor, ops.ExclusiveOr), + (nwp.col("a"), 6, operator.add, ops.Add), + (nwp.col("a"), 2.1, operator.mul, ops.Multiply), + (nwp.col("a"), nwp.col("b"), operator.sub, ops.Sub), + (nwp.col("a"), 2, operator.pow, F.Pow), + (nwp.col("a"), 2, operator.mod, ops.Modulus), + (nwp.col("a"), 2, operator.floordiv, ops.FloorDivide), + (nwp.col("a"), 4, operator.truediv, ops.TrueDivide), ], ) def test_operators_left_right( @@ -442,8 +442,8 @@ def test_operators_left_right( } result_1 = function(arg_1, arg_2) result_2 = function(arg_2, arg_1) - assert isinstance(result_1, nwd.Expr) - assert isinstance(result_2, nwd.Expr) + assert isinstance(result_1, nwp.Expr) + assert isinstance(result_2, nwp.Expr) ir_1 = result_1._ir ir_2 = result_2._ir if op in {ops.Eq, ops.NotEq}: diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index ce78c4f72f..f4f7368c7e 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -5,7 +5,7 @@ import pytest import narwhals as nw -from narwhals import _plan as nwd +from narwhals import _plan as nwp from narwhals._plan import _parse, expressions as ir, selectors as ndcs from narwhals._plan._rewrites import ( rewrite_all, @@ -49,14 +49,14 @@ def _to_window_expr(into_expr: IntoExpr, *partition_by: IntoExpr) -> ir.WindowEx def test_rewrite_elementwise_over_simple(schema_2: dict[str, DType]) -> None: with pytest.raises(InvalidOperationError, match=r"over.+elementwise"): - nwd.col("a").sum().abs().over("b") + nwp.col("a").sum().abs().over("b") # NOTE: Since the requested "before" expression is currently an error (at definition time), # we need to manually build the IR - to sidestep the validation in `Over.to_window_expr`. # Later, that error might not be needed if we can do this rewrite. # If you're here because of a "Did not raise" - just replace everything with the (previously) erroring expr. - expected = nwd.col("a").sum().over("b").abs() - before = _to_window_expr(nwd.col("a").sum().abs(), "b").to_narwhals() + expected = nwp.col("a").sum().over("b").abs() + before = _to_window_expr(nwp.col("a").sum().abs(), "b").to_narwhals() assert_expr_ir_equal(before, "col('a').sum().abs().over([col('b')])") actual = rewrite_all(before, schema=schema_2, rewrites=[rewrite_elementwise_over]) assert len(actual) == 1 @@ -65,11 +65,11 @@ def test_rewrite_elementwise_over_simple(schema_2: dict[str, DType]) -> None: def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: expected = ( - nwd.col("b").last().over("d").replace_strict({1: 2}), - nwd.col("c").last().over("d").replace_strict({1: 2}), + nwp.col("b").last().over("d").replace_strict({1: 2}), + nwp.col("c").last().over("d").replace_strict({1: 2}), ) before = _to_window_expr( - nwd.col("b", "c").last().replace_strict({1: 2}), "d" + nwp.col("b", "c").last().replace_strict({1: 2}), "d" ).to_narwhals() assert_expr_ir_equal( before, "cols(['b', 'c']).last().replace_strict().over([col('d')])" @@ -80,36 +80,36 @@ def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: assert_expr_ir_equal(lhs, rhs) -def named_ir(name: str, expr: nwd.Expr | ir.ExprIR, /) -> NamedIR[ir.ExprIR]: +def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> NamedIR[ir.ExprIR]: """Helper constructor for test compare.""" - return NamedIR(expr=expr._ir if isinstance(expr, nwd.Expr) else expr, name=name) + return NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name) def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: expected = ( - named_ir("a", nwd.col("a")), - named_ir("b", nwd.col("b").cast(nw.String)), - named_ir("x2", nwd.col("c").max().over("a").fill_null(50)), - named_ir("d**", ~nwd.col("d").is_duplicated().over("b")), - named_ir("f_some", nwd.col("f").str.contains("some")), - named_ir("g_some", nwd.col("g").str.contains("some")), - named_ir("h_some", nwd.col("h").str.contains("some")), - named_ir("D", nwd.col("d").null_count().over("f", "g", "j").sqrt()), - named_ir("E", nwd.col("e").null_count().over("f", "g", "j").sqrt()), - named_ir("B", nwd.col("b").null_count().over("f", "g", "j").sqrt()), + named_ir("a", nwp.col("a")), + named_ir("b", nwp.col("b").cast(nw.String)), + named_ir("x2", nwp.col("c").max().over("a").fill_null(50)), + named_ir("d**", ~nwp.col("d").is_duplicated().over("b")), + named_ir("f_some", nwp.col("f").str.contains("some")), + named_ir("g_some", nwp.col("g").str.contains("some")), + named_ir("h_some", nwp.col("h").str.contains("some")), + named_ir("D", nwp.col("d").null_count().over("f", "g", "j").sqrt()), + named_ir("E", nwp.col("e").null_count().over("f", "g", "j").sqrt()), + named_ir("B", nwp.col("b").null_count().over("f", "g", "j").sqrt()), ) before = ( - nwd.col("a"), - nwd.col("b").cast(nw.String), + nwp.col("a"), + nwp.col("b").cast(nw.String), ( - _to_window_expr(nwd.col("c").max().alias("x").fill_null(50), "a") + _to_window_expr(nwp.col("c").max().alias("x").fill_null(50), "a") .to_narwhals() .alias("x2") ), - ~(nwd.col("d").is_duplicated().alias("d*")).alias("d**").over("b"), + ~(nwp.col("d").is_duplicated().alias("d*")).alias("d**").over("b"), ndcs.string().str.contains("some").name.suffix("_some"), ( - _to_window_expr(nwd.nth(3, 4, 1).null_count().sqrt(), "f", "g", "j") + _to_window_expr(nwp.nth(3, 4, 1).null_count().sqrt(), "f", "g", "j") .to_narwhals() .name.to_uppercase() ), @@ -122,12 +122,12 @@ def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: def test_rewrite_binary_agg_over_simple(schema_2: dict[str, DType]) -> None: expected = ( - nwd.col("a") - nwd.col("a").mean().over("b"), - nwd.col("c") * nwd.col("c").abs().null_count().over("d"), + nwp.col("a") - nwp.col("a").mean().over("b"), + nwp.col("c") * nwp.col("c").abs().null_count().over("d"), ) before = ( - (nwd.col("a") - nwd.col("a").mean()).over("b"), - (nwd.col("c") * nwd.col("c").abs().null_count()).over("d"), + (nwp.col("a") - nwp.col("a").mean()).over("b"), + (nwp.col("c") * nwp.col("c").abs().null_count()).over("d"), ) actual = rewrite_all(*before, schema=schema_2, rewrites=[rewrite_binary_agg_over]) assert len(actual) == 2 @@ -137,13 +137,13 @@ def test_rewrite_binary_agg_over_simple(schema_2: dict[str, DType]) -> None: def test_rewrite_binary_agg_over_multiple(schema_2: dict[str, DType]) -> None: expected = ( - named_ir("hi_a", nwd.col("a") / nwd.col("e").drop_nulls().first().over("g")), - named_ir("hi_b", nwd.col("b") / nwd.col("e").drop_nulls().first().over("g")), - named_ir("hi_c", nwd.col("c") / nwd.col("e").drop_nulls().first().over("g")), - named_ir("hi_d", nwd.col("d") / nwd.col("e").drop_nulls().first().over("g")), + named_ir("hi_a", nwp.col("a") / nwp.col("e").drop_nulls().first().over("g")), + named_ir("hi_b", nwp.col("b") / nwp.col("e").drop_nulls().first().over("g")), + named_ir("hi_c", nwp.col("c") / nwp.col("e").drop_nulls().first().over("g")), + named_ir("hi_d", nwp.col("d") / nwp.col("e").drop_nulls().first().over("g")), ) before = ( - (nwd.col("a", "b", "c", "d") / nwd.col("e").drop_nulls().first()).over("g") + (nwp.col("a", "b", "c", "d") / nwp.col("e").drop_nulls().first()).over("g") ).name.prefix("hi_") actual = rewrite_all(before, schema=schema_2, rewrites=[rewrite_binary_agg_over]) assert len(actual) == 4 diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index d360907b6d..2b5ca80c35 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -2,7 +2,7 @@ import pytest -from narwhals import _plan as nwd +from narwhals import _plan as nwp from tests.utils import POLARS_VERSION pytest.importorskip("polars") @@ -11,43 +11,43 @@ if POLARS_VERSION >= (1, 0): # https://github.com/pola-rs/polars/pull/16743 OVER_CASE = ( - nwd.col("a").last().over("b", order_by="c"), + nwp.col("a").last().over("b", order_by="c"), pl.col("a").last().over("b", order_by="c"), ["a", "b"], ) else: # pragma: no cover - OVER_CASE = (nwd.col("a").last().over("b"), pl.col("a").last().over("b"), ["a", "b"]) + OVER_CASE = (nwp.col("a").last().over("b"), pl.col("a").last().over("b"), ["a", "b"]) if POLARS_VERSION >= (0, 20, 5): - LEN_CASE = (nwd.len(), pl.len(), "len") + LEN_CASE = (nwp.len(), pl.len(), "len") else: # pragma: no cover - LEN_CASE = (nwd.len().alias("count"), pl.count(), "count") + LEN_CASE = (nwp.len().alias("count"), pl.count(), "count") @pytest.mark.parametrize( ("nw_expr", "pl_expr", "expected"), [ ( - nwd.col("a").alias("b").min().alias("c").alias("d"), + nwp.col("a").alias("b").min().alias("c").alias("d"), pl.col("a").alias("b").min().alias("c").alias("d"), ["a"], ), ( - (nwd.col("a") + (nwd.col("a") - nwd.col("b"))).alias("c"), + (nwp.col("a") + (nwp.col("a") - nwp.col("b"))).alias("c"), (pl.col("a") + (pl.col("a") - pl.col("b"))).alias("c"), ["a", "a", "b"], ), OVER_CASE, ( - (nwd.col("a", "b", "c").sort().abs() * 20).max(), + (nwp.col("a", "b", "c").sort().abs() * 20).max(), (pl.col("a", "b", "c").sort().abs() * 20).max(), [], ), - (nwd.all().mean(), pl.all().mean(), []), - (nwd.all().mean().sort_by("d"), pl.all().mean().sort_by("d"), ["d"]), + (nwp.all().mean(), pl.all().mean(), []), + (nwp.all().mean().sort_by("d"), pl.all().mean().sort_by("d"), ["d"]), ], ) def test_meta_root_names( - nw_expr: nwd.Expr, pl_expr: pl.Expr, expected: list[str] + nw_expr: nwp.Expr, pl_expr: pl.Expr, expected: list[str] ) -> None: pl_result = pl_expr.meta.root_names() nw_result = nw_expr.meta.root_names() @@ -58,17 +58,17 @@ def test_meta_root_names( @pytest.mark.parametrize( ("nw_expr", "pl_expr", "expected"), [ - (nwd.col("a"), pl.col("a"), "a"), - (nwd.lit(1), pl.lit(1), "literal"), + (nwp.col("a"), pl.col("a"), "a"), + (nwp.lit(1), pl.lit(1), "literal"), LEN_CASE, pytest.param( ( - nwd.col("a") + nwp.col("a") .alias("b") .min() .alias("c") .over("e", "f") - .sort_by(nwd.col("i"), nwd.col("g", "h")) + .sort_by(nwp.col("i"), nwp.col("g", "h")) ), ( pl.col("a") @@ -82,17 +82,17 @@ def test_meta_root_names( id="Kitchen-Sink", ), pytest.param( - nwd.col("c").alias("x").fill_null(50), + nwp.col("c").alias("x").fill_null(50), pl.col("c").alias("x").fill_null(50), "x", id="FunctionExpr-Literal", ), pytest.param( ( - nwd.col("ROOT") + nwp.col("ROOT") .alias("ROOT-ALIAS") - .filter(nwd.col("b") >= 30, nwd.col("c").alias("d") == 7) - + nwd.col("RHS").alias("RHS-ALIAS") + .filter(nwp.col("b") >= 30, nwp.col("c").alias("d") == 7) + + nwp.col("RHS").alias("RHS-ALIAS") ), ( pl.col("ROOT") @@ -104,35 +104,35 @@ def test_meta_root_names( id="BinaryExpr-Multiple", ), pytest.param( - nwd.col("ROOT").alias("ROOT-ALIAS").mean().over(nwd.col("a").alias("b")), + nwp.col("ROOT").alias("ROOT-ALIAS").mean().over(nwp.col("a").alias("b")), pl.col("ROOT").alias("ROOT-ALIAS").mean().over(pl.col("a").alias("b")), "ROOT-ALIAS", id="WindowExpr", ), pytest.param( - nwd.when(nwd.col("a").alias("a?")).then(10), + nwp.when(nwp.col("a").alias("a?")).then(10), pl.when(pl.col("a").alias("a?")).then(10), "literal", id="When-Literal", ), pytest.param( - nwd.when(nwd.col("a").alias("a?")).then(nwd.col("b")).otherwise(20), + nwp.when(nwp.col("a").alias("a?")).then(nwp.col("b")).otherwise(20), pl.when(pl.col("a").alias("a?")).then(pl.col("b")).otherwise(20), "b", id="When-Column-Literal", ), pytest.param( - nwd.when(a=1).then(10).otherwise(nwd.col("c").alias("c?")), + nwp.when(a=1).then(10).otherwise(nwp.col("c").alias("c?")), pl.when(a=1).then(10).otherwise(pl.col("c").alias("c?")), "literal", id="When-Literal-Alias", ), pytest.param( ( - nwd.when(nwd.col("a").alias("a?")) + nwp.when(nwp.col("a").alias("a?")) .then(1) - .when(nwd.col("b") == 1) - .then(nwd.col("c")) + .when(nwp.col("b") == 1) + .then(nwp.col("c")) ), ( pl.when(pl.col("a").alias("a?")) @@ -145,9 +145,9 @@ def test_meta_root_names( ), pytest.param( ( - nwd.when(nwd.col("foo") > 2, nwd.col("bar") < 3) - .then(nwd.lit("Yes")) - .otherwise(nwd.lit("No")) + nwp.when(nwp.col("foo") > 2, nwp.col("bar") < 3) + .then(nwp.lit("Yes")) + .otherwise(nwp.lit("No")) .alias("TARGET") ), ( @@ -160,23 +160,23 @@ def test_meta_root_names( id="When2-Literal-Literal-Alias", ), pytest.param( - (nwd.col("ROOT").alias("ROOT-ALIAS").filter(nwd.col("c") <= 1).mean()), + (nwp.col("ROOT").alias("ROOT-ALIAS").filter(nwp.col("c") <= 1).mean()), (pl.col("ROOT").alias("ROOT-ALIAS").filter(pl.col("c") <= 1).mean()), "ROOT-ALIAS", id="Filter", ), pytest.param( - nwd.int_range(0, 10), pl.int_range(0, 10), "literal", id="IntRange-Literal" + nwp.int_range(0, 10), pl.int_range(0, 10), "literal", id="IntRange-Literal" ), pytest.param( - nwd.int_range(nwd.col("b").first(), 10), + nwp.int_range(nwp.col("b").first(), 10), pl.int_range(pl.col("b").first(), 10), "b", id="IntRange-Column", ), ], ) -def test_meta_output_name(nw_expr: nwd.Expr, pl_expr: pl.Expr, expected: str) -> None: +def test_meta_output_name(nw_expr: nwp.Expr, pl_expr: pl.Expr, expected: str) -> None: pl_result = pl_expr.meta.output_name() nw_result = nw_expr.meta.output_name() assert nw_result == expected diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 16bd3430c0..78c8ffe75f 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from narwhals import _plan as nwd +from narwhals import _plan as nwp from narwhals._plan import expressions as ir from narwhals._plan.common import NamedIR @@ -10,8 +10,8 @@ from typing_extensions import LiteralString -def _unwrap_ir(obj: nwd.Expr | ir.ExprIR | NamedIR) -> ir.ExprIR: - if isinstance(obj, nwd.Expr): +def _unwrap_ir(obj: nwp.Expr | ir.ExprIR | NamedIR) -> ir.ExprIR: + if isinstance(obj, nwp.Expr): return obj._ir if isinstance(obj, ir.ExprIR): return obj @@ -21,8 +21,8 @@ def _unwrap_ir(obj: nwd.Expr | ir.ExprIR | NamedIR) -> ir.ExprIR: def assert_expr_ir_equal( - actual: nwd.Expr | ir.ExprIR | NamedIR, - expected: nwd.Expr | ir.ExprIR | NamedIR | LiteralString, + actual: nwp.Expr | ir.ExprIR | NamedIR, + expected: nwp.Expr | ir.ExprIR | NamedIR | LiteralString, /, ) -> None: """Assert that `actual` is equivalent to `expected`. @@ -41,5 +41,5 @@ def assert_expr_ir_equal( elif isinstance(actual, NamedIR) and isinstance(expected, NamedIR): assert actual == expected else: - rhs = expected._ir if isinstance(expected, nwd.Expr) else expected + rhs = expected._ir if isinstance(expected, nwp.Expr) else expected assert lhs == rhs From 23704fa1c7d06dd5d5097c4afd712a7d31839937 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Sep 2025 16:25:51 +0000 Subject: [PATCH 34/36] refactor: Reuse `ir.boolean` export more --- narwhals/_plan/arrow/expr.py | 3 +-- narwhals/_plan/expr.py | 29 ++++++++++++++--------------- narwhals/_plan/functions.py | 6 +++--- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index b5d3b85360..2f602587ce 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -29,7 +29,6 @@ from narwhals._plan import expressions as ir from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.expressions import boolean from narwhals._plan.expressions.aggregation import ( ArgMax, ArgMin, @@ -119,7 +118,7 @@ def all(self, node: FunctionExpr[All], frame: Frame, name: str) -> StoresNativeT return self._unary_function(fn.all_)(node, frame, name) def any( - self, node: FunctionExpr[boolean.Any], frame: Frame, name: str + self, node: FunctionExpr[ir.boolean.Any], frame: Frame, name: str ) -> StoresNativeT_co: return self._unary_function(fn.any_)(node, frame, name) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 294e5a1670..7a95936523 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -13,7 +13,6 @@ ) from narwhals._plan.expressions import ( aggregation as agg, - boolean, functions as F, operators as ops, ) @@ -398,31 +397,31 @@ def map_batches( ) def any(self) -> Self: - return self._with_unary(boolean.Any()) + return self._with_unary(ir.boolean.Any()) def all(self) -> Self: - return self._with_unary(boolean.All()) + return self._with_unary(ir.boolean.All()) def is_duplicated(self) -> Self: - return self._with_unary(boolean.IsDuplicated()) + return self._with_unary(ir.boolean.IsDuplicated()) def is_finite(self) -> Self: - return self._with_unary(boolean.IsFinite()) + return self._with_unary(ir.boolean.IsFinite()) def is_nan(self) -> Self: - return self._with_unary(boolean.IsNan()) + return self._with_unary(ir.boolean.IsNan()) def is_null(self) -> Self: - return self._with_unary(boolean.IsNull()) + return self._with_unary(ir.boolean.IsNull()) def is_first_distinct(self) -> Self: - return self._with_unary(boolean.IsFirstDistinct()) + return self._with_unary(ir.boolean.IsFirstDistinct()) def is_last_distinct(self) -> Self: - return self._with_unary(boolean.IsLastDistinct()) + return self._with_unary(ir.boolean.IsLastDistinct()) def is_unique(self) -> Self: - return self._with_unary(boolean.IsUnique()) + return self._with_unary(ir.boolean.IsUnique()) def is_between( self, @@ -432,16 +431,16 @@ def is_between( ) -> Self: it = parse_into_seq_of_expr_ir(lower_bound, upper_bound) return self._from_ir( - boolean.IsBetween(closed=closed).to_function_expr(self._ir, *it) + ir.boolean.IsBetween(closed=closed).to_function_expr(self._ir, *it) ) def is_in(self, other: Iterable[Any]) -> Self: if is_series(other): - return self._with_unary(boolean.IsInSeries.from_series(other)) + return self._with_unary(ir.boolean.IsInSeries.from_series(other)) if isinstance(other, Iterable): - return self._with_unary(boolean.IsInSeq.from_iterable(other)) + return self._with_unary(ir.boolean.IsInSeq.from_iterable(other)) if is_expr(other): - return self._with_unary(boolean.IsInExpr(other=other._ir)) + return self._with_unary(ir.boolean.IsInExpr(other=other._ir)) msg = f"`is_in` only supports iterables, got: {type(other).__name__}" raise TypeError(msg) @@ -537,7 +536,7 @@ def __rpow__(self, base: IntoExprColumn | float) -> Self: return self._from_ir(F.Pow().to_function_expr(parse_into_expr_ir(base), self._ir)) def __invert__(self) -> Self: - return self._with_unary(boolean.Not()) + return self._with_unary(ir.boolean.Not()) @property def meta(self) -> MetaNamespace: diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 1c265d9ed0..c07fe92c29 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -4,7 +4,7 @@ import typing as t from narwhals._plan import _guards, _parse, common, expressions as ir -from narwhals._plan.expressions import boolean, functions as F +from narwhals._plan.expressions import functions as F from narwhals._plan.expressions.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.expressions.ranges import IntRange from narwhals._plan.expressions.strings import ConcatStr @@ -80,12 +80,12 @@ def sum(*columns: str) -> Expr: def all_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: it = _parse.parse_into_seq_of_expr_ir(*exprs) - return boolean.AllHorizontal().to_function_expr(*it).to_narwhals() + return ir.boolean.AllHorizontal().to_function_expr(*it).to_narwhals() def any_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: it = _parse.parse_into_seq_of_expr_ir(*exprs) - return boolean.AnyHorizontal().to_function_expr(*it).to_narwhals() + return ir.boolean.AnyHorizontal().to_function_expr(*it).to_narwhals() def sum_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> Expr: From 69268ce5847022af2c1a427b98f8e09e0b0f3971 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Sep 2025 18:48:43 +0000 Subject: [PATCH 35/36] refactor: Split out `_expr_ir.py` --- narwhals/_plan/_expansion.py | 2 +- narwhals/_plan/_expr_ir.py | 292 ++++++++++++++++++++ narwhals/_plan/_function.py | 10 +- narwhals/_plan/_parse.py | 8 +- narwhals/_plan/_rewrites.py | 4 +- narwhals/_plan/arrow/dataframe.py | 2 +- narwhals/_plan/arrow/expr.py | 2 +- narwhals/_plan/common.py | 312 ++-------------------- narwhals/_plan/dataframe.py | 3 +- narwhals/_plan/expressions/__init__.py | 7 +- narwhals/_plan/expressions/aggregation.py | 3 +- narwhals/_plan/expressions/boolean.py | 2 +- narwhals/_plan/expressions/expr.py | 3 +- narwhals/_plan/expressions/functions.py | 2 +- narwhals/_plan/expressions/name.py | 10 +- narwhals/_plan/expressions/operators.py | 3 +- narwhals/_plan/expressions/ranges.py | 3 +- narwhals/_plan/expressions/window.py | 3 +- narwhals/_plan/protocols.py | 3 +- narwhals/_plan/schema.py | 2 +- narwhals/_plan/typing.py | 2 +- narwhals/_plan/when_then.py | 3 +- tests/plan/expr_rewrites_test.py | 5 +- tests/plan/utils.py | 11 +- 24 files changed, 355 insertions(+), 342 deletions(-) create mode 100644 narwhals/_plan/_expr_ir.py diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index fb5f271308..fb2dd390a8 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -45,7 +45,6 @@ from narwhals._plan import common, meta from narwhals._plan._guards import is_horizontal_reduction from narwhals._plan._immutable import Immutable -from narwhals._plan.common import NamedIR from narwhals._plan.exceptions import ( column_index_error, column_not_found_error, @@ -59,6 +58,7 @@ ExprIR, IndexColumns, KeepName, + NamedIR, Nth, RenameAlias, SelectorIR, diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py new file mode 100644 index 0000000000..0646520102 --- /dev/null +++ b/narwhals/_plan/_expr_ir.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, cast + +from narwhals._plan._guards import is_function_expr, is_literal +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import dispatch_getter, replace +from narwhals._plan.options import ExprIROptions +from narwhals._plan.typing import ExprIRT +from narwhals.utils import Version + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + from typing import Any, ClassVar + + from typing_extensions import Self, TypeAlias + + from narwhals._plan.expr import Expr, Selector + from narwhals._plan.expressions.expr import Alias, Cast, Column + from narwhals._plan.meta import MetaNamespace + from narwhals._plan.protocols import Ctx, FrameT_contra, R_co + from narwhals._plan.typing import ExprIRT2, MapIR, Seq + from narwhals.dtypes import DType + + Incomplete: TypeAlias = "Any" + + +def _dispatch_generate( + tp: type[ExprIRT], / +) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: + if not tp.__expr_ir_config__.allow_dispatch: + + def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: + msg = ( + f"{tp.__name__!r} should not appear at the compliant-level.\n\n" + f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" + ) + raise TypeError(msg) + + return _ + getter = dispatch_getter(tp) + + def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: + return getter(ctx)(node, frame, name) + + return _ + + +class ExprIR(Immutable): + """Anything that can be a node on a graph of expressions.""" + + _child: ClassVar[Seq[str]] = () + """Nested node names, in iteration order.""" + + __expr_ir_config__: ClassVar[ExprIROptions] = ExprIROptions.default() + __expr_ir_dispatch__: ClassVar[ + staticmethod[[Incomplete, Self, Incomplete, str], Incomplete] + ] + + def __init_subclass__( + cls: type[Self], + *args: Any, + child: Seq[str] = (), + config: ExprIROptions | None = None, + **kwds: Any, + ) -> None: + super().__init_subclass__(*args, **kwds) + if child: + cls._child = child + if config: + cls.__expr_ir_config__ = config + cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls)) + + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / + ) -> R_co: + """Evaluate expression in `frame`, using `ctx` for implementation(s).""" + return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return] + + def to_narwhals(self, version: Version = Version.MAIN) -> Expr: + from narwhals._plan import expr + + tp = expr.Expr if version is Version.MAIN else expr.ExprV1 + return tp._from_ir(self) + + @property + def is_scalar(self) -> bool: + return False + + def map_ir(self, function: MapIR, /) -> ExprIR: + """Apply `function` to each child node, returning a new `ExprIR`. + + See [`polars_plan::plans::iterator::Expr.map_expr`] and [`polars_plan::plans::visitor::visitors`]. + + [`polars_plan::plans::iterator::Expr.map_expr`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/iterator.rs#L152-L159 + [`polars_plan::plans::visitor::visitors`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/visitor/visitors.rs + """ + if not self._child: + return function(self) + children = ((name, getattr(self, name)) for name in self._child) + changed = {name: _map_ir_child(child, function) for name, child in children} + return function(replace(self, **changed)) + + def iter_left(self) -> Iterator[ExprIR]: + """Yield nodes root->leaf. + + Examples: + >>> from narwhals import _plan as nw + >>> + >>> a = nw.col("a") + >>> b = a.alias("b") + >>> c = b.min().alias("c") + >>> d = c.over(nw.col("e"), nw.col("f")) + >>> + >>> list(a._ir.iter_left()) + [col('a')] + >>> + >>> list(b._ir.iter_left()) + [col('a'), col('a').alias('b')] + >>> + >>> list(c._ir.iter_left()) + [col('a'), col('a').alias('b'), col('a').alias('b').min(), col('a').alias('b').min().alias('c')] + >>> + >>> list(d._ir.iter_left()) + [col('a'), col('a').alias('b'), col('a').alias('b').min(), col('a').alias('b').min().alias('c'), col('e'), col('f'), col('a').alias('b').min().alias('c').over([col('e'), col('f')])] + """ + for name in self._child: + child: ExprIR | Seq[ExprIR] = getattr(self, name) + if isinstance(child, ExprIR): + yield from child.iter_left() + else: + for node in child: + yield from node.iter_left() + yield self + + def iter_right(self) -> Iterator[ExprIR]: + """Yield nodes leaf->root. + + Note: + Identical to `iter_left` for root nodes. + + Examples: + >>> from narwhals import _plan as nw + >>> + >>> a = nw.col("a") + >>> b = a.alias("b") + >>> c = b.min().alias("c") + >>> d = c.over(nw.col("e"), nw.col("f")) + >>> + >>> list(a._ir.iter_right()) + [col('a')] + >>> + >>> list(b._ir.iter_right()) + [col('a').alias('b'), col('a')] + >>> + >>> list(c._ir.iter_right()) + [col('a').alias('b').min().alias('c'), col('a').alias('b').min(), col('a').alias('b'), col('a')] + >>> + >>> list(d._ir.iter_right()) + [col('a').alias('b').min().alias('c').over([col('e'), col('f')]), col('f'), col('e'), col('a').alias('b').min().alias('c'), col('a').alias('b').min(), col('a').alias('b'), col('a')] + """ + yield self + for name in reversed(self._child): + child: ExprIR | Seq[ExprIR] = getattr(self, name) + if isinstance(child, ExprIR): + yield from child.iter_right() + else: + for node in reversed(child): + yield from node.iter_right() + + def iter_root_names(self) -> Iterator[ExprIR]: + """Override for different iteration behavior in `ExprIR.meta.root_names`. + + Note: + Identical to `iter_left` by default. + """ + yield from self.iter_left() + + def iter_output_name(self) -> Iterator[ExprIR]: + """Override for different iteration behavior in `ExprIR.meta.output_name`. + + Note: + Identical to `iter_right` by default. + """ + yield from self.iter_right() + + @property + def meta(self) -> MetaNamespace: + from narwhals._plan.meta import MetaNamespace + + return MetaNamespace(_ir=self) + + def cast(self, dtype: DType) -> Cast: + from narwhals._plan.expressions.expr import Cast + + return Cast(expr=self, dtype=dtype) + + def alias(self, name: str) -> Alias: + from narwhals._plan.expressions.expr import Alias + + return Alias(expr=self, name=name) + + def _repr_html_(self) -> str: + return self.__repr__() + + +def _map_ir_child(obj: ExprIR | Seq[ExprIR], fn: MapIR, /) -> ExprIR | Seq[ExprIR]: + return obj.map_ir(fn) if isinstance(obj, ExprIR) else tuple(e.map_ir(fn) for e in obj) + + +class SelectorIR(ExprIR, config=ExprIROptions.no_dispatch()): + def to_narwhals(self, version: Version = Version.MAIN) -> Selector: + from narwhals._plan import expr + + if version is Version.MAIN: + return expr.Selector._from_ir(self) + return expr.SelectorV1._from_ir(self) + + def matches_column(self, name: str, dtype: DType) -> bool: + """Return True if we can select this column. + + - Thinking that we could get more cache hits on an individual column basis. + - May also be more efficient to not iterate over the schema for every selector + - Instead do one pass, evaluating every selector against a single column at a time + """ + raise NotImplementedError(type(self)) + + +class NamedIR(Immutable, Generic[ExprIRT]): + """Post-projection expansion wrapper for `ExprIR`. + + - Somewhat similar to [`polars_plan::plans::expr_ir::ExprIR`]. + - The [`polars_plan::plans::aexpr::AExpr`] stage has been skipped (*for now*) + - Parts of that will probably be in here too + - `AExpr` seems like too much duplication when we won't get the memory allocation benefits in python + + [`polars_plan::plans::expr_ir::ExprIR`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/expr_ir.rs#L63-L74 + [`polars_plan::plans::aexpr::AExpr`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/mod.rs#L145-L231 + """ + + __slots__ = ("expr", "name") + expr: ExprIRT + name: str + + @staticmethod + def from_name(name: str, /) -> NamedIR[Column]: + """Construct as a simple, unaliased `col(name)` expression. + + Intended to be used in `with_columns` from a `FrozenSchema`'s keys. + """ + from narwhals._plan.expressions.expr import col + + return NamedIR(expr=col(name), name=name) + + @staticmethod + def from_ir(expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: + """Construct from an already expanded `ExprIR`. + + Should be cheap to get the output name from cache, but will raise if used + without care. + """ + return NamedIR(expr=expr, name=expr.meta.output_name(raise_if_undetermined=True)) + + def map_ir(self, function: MapIR, /) -> Self: + """**WARNING**: don't use renaming ops here, or `self.name` is invalid.""" + return replace(self, expr=function(self.expr.map_ir(function))) + + def __repr__(self) -> str: + return f"{self.name}={self.expr!r}" + + def _repr_html_(self) -> str: + return f"{self.name}={self.expr._repr_html_()}" + + def is_elementwise_top_level(self) -> bool: + """Return True if the outermost node is elementwise. + + Based on [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`] + + This check: + - Is not recursive + - Is not valid on `ExprIR` *prior* to being expanded + + [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/properties.rs#L16-L44 + """ + from narwhals._plan.expressions import expr + + ir = self.expr + if is_function_expr(ir): + return ir.options.is_elementwise() + if is_literal(ir): + return ir.is_scalar + return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.TernaryExpr, expr.Cast)) diff --git a/narwhals/_plan/_function.py b/narwhals/_plan/_function.py index 59605d63d9..332dbfc085 100644 --- a/narwhals/_plan/_function.py +++ b/narwhals/_plan/_function.py @@ -1,13 +1,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING from narwhals._plan._immutable import Immutable -from narwhals._plan.common import _dispatch_getter, _dispatch_method_name, replace +from narwhals._plan.common import dispatch_getter, dispatch_method_name, replace from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: - from typing import Any, Callable + from typing import Any, Callable, ClassVar from typing_extensions import Self, TypeAlias @@ -22,7 +22,7 @@ def _dispatch_generate_function( tp: type[FunctionT], / ) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: - getter = _dispatch_getter(tp) + getter = dispatch_getter(tp) def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: return getter(ctx)(node, frame, name) @@ -75,7 +75,7 @@ def __init_subclass__( cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls)) def __repr__(self) -> str: - return _dispatch_method_name(type(self)) + return dispatch_method_name(type(self)) class HorizontalFunction( diff --git a/narwhals/_plan/_parse.py b/narwhals/_plan/_parse.py index ef622707e5..651166ebee 100644 --- a/narwhals/_plan/_parse.py +++ b/narwhals/_plan/_parse.py @@ -4,7 +4,7 @@ # ruff: noqa: A002 from itertools import chain -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING from narwhals._plan._guards import is_expr, is_iterable_reject from narwhals._plan.exceptions import ( @@ -16,16 +16,16 @@ if TYPE_CHECKING: from collections.abc import Iterator - from typing import Any + from typing import Any, TypeVar import polars as pl from typing_extensions import TypeAlias, TypeIs - from narwhals._plan.common import ExprIR + from narwhals._plan.expressions import ExprIR from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq from narwhals.typing import IntoDType -T = TypeVar("T") + T = TypeVar("T") _RaisesInvalidIntoExprError: TypeAlias = "Any" """ diff --git a/narwhals/_plan/_rewrites.py b/narwhals/_plan/_rewrites.py index 83e86b74c0..ae23fa4b9b 100644 --- a/narwhals/_plan/_rewrites.py +++ b/narwhals/_plan/_rewrites.py @@ -12,12 +12,12 @@ is_window_expr, ) from narwhals._plan._parse import parse_into_seq_of_expr_ir -from narwhals._plan.common import NamedIR, replace +from narwhals._plan.common import replace if TYPE_CHECKING: from collections.abc import Sequence - from narwhals._plan.common import ExprIR + from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.schema import IntoFrozenSchema from narwhals._plan.typing import IntoExpr, MapIR, NamedOrExprIRT, Seq diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index fad9a119f9..27a02bc2ed 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -20,8 +20,8 @@ from narwhals._arrow.typing import ChunkedArrayAny from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.dataframe import DataFrame as NwDataFrame + from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.options import SortMultipleOptions from narwhals._plan.typing import Seq from narwhals.dtypes import DType diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 2f602587ce..57ec5196d6 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -9,7 +9,7 @@ from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co -from narwhals._plan.common import NamedIR +from narwhals._plan.expressions import NamedIR from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch, namespace from narwhals._utils import ( Implementation, diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 65cbef4a72..0b4267f214 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -6,36 +6,27 @@ from collections.abc import Iterable from decimal import Decimal from operator import attrgetter -from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast, overload +from typing import TYPE_CHECKING, cast, overload -from narwhals._plan._guards import is_function_expr, is_iterable_reject, is_literal -from narwhals._plan._immutable import Immutable -from narwhals._plan.options import ExprIROptions -from narwhals._plan.typing import ( - DTypeT, - ExprIRT, - ExprIRT2, - FunctionT, - MapIR, - NonNestedDTypeT, - OneOrIterable, - Seq, -) +from narwhals._plan._guards import is_iterable_reject from narwhals.dtypes import DType from narwhals.utils import Version if TYPE_CHECKING: from collections.abc import Iterator - from typing import Any, Callable - - from typing_extensions import Self, TypeAlias - - from narwhals._plan.expr import Expr, Selector - from narwhals._plan.expressions.expr import Alias, Cast, Column - from narwhals._plan.meta import MetaNamespace - from narwhals._plan.protocols import Ctx, FrameT_contra, R_co + from typing import Any, Callable, TypeVar + + from narwhals._plan.typing import ( + DTypeT, + ExprIRT, + FunctionT, + NonNestedDTypeT, + OneOrIterable, + ) from narwhals.typing import NonNestedDType, NonNestedLiteral + T = TypeVar("T") + if sys.version_info >= (3, 13): from copy import replace as replace # noqa: PLC0414 @@ -50,10 +41,6 @@ def replace(obj: T, /, **changes: Any) -> T: return func(obj, **changes) # type: ignore[no-any-return] -T = TypeVar("T") -Incomplete: TypeAlias = "Any" - - def pascal_to_snake_case(s: str) -> str: """Convert a PascalCase, camelCase string to snake_case. @@ -73,286 +60,19 @@ def _re_repl_snake(match: re.Match[str], /) -> str: return f"{match.group(1)}_{match.group(2)}" -def _dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: +def dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: config = tp.__expr_ir_config__ name = config.override_name or pascal_to_snake_case(tp.__name__) return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name -def _dispatch_getter(tp: type[ExprIRT | FunctionT]) -> Callable[[Any], Any]: - getter = attrgetter(_dispatch_method_name(tp)) +def dispatch_getter(tp: type[ExprIRT | FunctionT]) -> Callable[[Any], Any]: + getter = attrgetter(dispatch_method_name(tp)) if tp.__expr_ir_config__.origin == "expr": return getter return lambda ctx: getter(ctx.__narwhals_namespace__()) -def _dispatch_generate( - tp: type[ExprIRT], / -) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: - if not tp.__expr_ir_config__.allow_dispatch: - - def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: - msg = ( - f"{tp.__name__!r} should not appear at the compliant-level.\n\n" - f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" - ) - raise TypeError(msg) - - return _ - getter = _dispatch_getter(tp) - - def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: - return getter(ctx)(node, frame, name) - - return _ - - -def _map_ir_child(obj: ExprIR | Seq[ExprIR], fn: MapIR, /) -> ExprIR | Seq[ExprIR]: - return obj.map_ir(fn) if isinstance(obj, ExprIR) else tuple(e.map_ir(fn) for e in obj) - - -class ExprIR(Immutable): - """Anything that can be a node on a graph of expressions.""" - - _child: ClassVar[Seq[str]] = () - """Nested node names, in iteration order.""" - - __expr_ir_config__: ClassVar[ExprIROptions] = ExprIROptions.default() - __expr_ir_dispatch__: ClassVar[ - staticmethod[[Incomplete, Self, Incomplete, str], Incomplete] - ] - - def __init_subclass__( - cls: type[Self], - *args: Any, - child: Seq[str] = (), - config: ExprIROptions | None = None, - **kwds: Any, - ) -> None: - super().__init_subclass__(*args, **kwds) - if child: - cls._child = child - if config: - cls.__expr_ir_config__ = config - cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls)) - - def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / - ) -> R_co: - """Evaluate expression in `frame`, using `ctx` for implementation(s).""" - return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return] - - def to_narwhals(self, version: Version = Version.MAIN) -> Expr: - from narwhals._plan import expr - - tp = expr.Expr if version is Version.MAIN else expr.ExprV1 - return tp._from_ir(self) - - @property - def is_scalar(self) -> bool: - return False - - def map_ir(self, function: MapIR, /) -> ExprIR: - """Apply `function` to each child node, returning a new `ExprIR`. - - See [`polars_plan::plans::iterator::Expr.map_expr`] and [`polars_plan::plans::visitor::visitors`]. - - [`polars_plan::plans::iterator::Expr.map_expr`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/iterator.rs#L152-L159 - [`polars_plan::plans::visitor::visitors`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/visitor/visitors.rs - """ - if not self._child: - return function(self) - children = ((name, getattr(self, name)) for name in self._child) - changed = {name: _map_ir_child(child, function) for name, child in children} - return function(replace(self, **changed)) - - def iter_left(self) -> Iterator[ExprIR]: - """Yield nodes root->leaf. - - Examples: - >>> from narwhals import _plan as nw - >>> - >>> a = nw.col("a") - >>> b = a.alias("b") - >>> c = b.min().alias("c") - >>> d = c.over(nw.col("e"), nw.col("f")) - >>> - >>> list(a._ir.iter_left()) - [col('a')] - >>> - >>> list(b._ir.iter_left()) - [col('a'), col('a').alias('b')] - >>> - >>> list(c._ir.iter_left()) - [col('a'), col('a').alias('b'), col('a').alias('b').min(), col('a').alias('b').min().alias('c')] - >>> - >>> list(d._ir.iter_left()) - [col('a'), col('a').alias('b'), col('a').alias('b').min(), col('a').alias('b').min().alias('c'), col('e'), col('f'), col('a').alias('b').min().alias('c').over([col('e'), col('f')])] - """ - for name in self._child: - child: ExprIR | Seq[ExprIR] = getattr(self, name) - if isinstance(child, ExprIR): - yield from child.iter_left() - else: - for node in child: - yield from node.iter_left() - yield self - - def iter_right(self) -> Iterator[ExprIR]: - """Yield nodes leaf->root. - - Note: - Identical to `iter_left` for root nodes. - - Examples: - >>> from narwhals import _plan as nw - >>> - >>> a = nw.col("a") - >>> b = a.alias("b") - >>> c = b.min().alias("c") - >>> d = c.over(nw.col("e"), nw.col("f")) - >>> - >>> list(a._ir.iter_right()) - [col('a')] - >>> - >>> list(b._ir.iter_right()) - [col('a').alias('b'), col('a')] - >>> - >>> list(c._ir.iter_right()) - [col('a').alias('b').min().alias('c'), col('a').alias('b').min(), col('a').alias('b'), col('a')] - >>> - >>> list(d._ir.iter_right()) - [col('a').alias('b').min().alias('c').over([col('e'), col('f')]), col('f'), col('e'), col('a').alias('b').min().alias('c'), col('a').alias('b').min(), col('a').alias('b'), col('a')] - """ - yield self - for name in reversed(self._child): - child: ExprIR | Seq[ExprIR] = getattr(self, name) - if isinstance(child, ExprIR): - yield from child.iter_right() - else: - for node in reversed(child): - yield from node.iter_right() - - def iter_root_names(self) -> Iterator[ExprIR]: - """Override for different iteration behavior in `ExprIR.meta.root_names`. - - Note: - Identical to `iter_left` by default. - """ - yield from self.iter_left() - - def iter_output_name(self) -> Iterator[ExprIR]: - """Override for different iteration behavior in `ExprIR.meta.output_name`. - - Note: - Identical to `iter_right` by default. - """ - yield from self.iter_right() - - @property - def meta(self) -> MetaNamespace: - from narwhals._plan.meta import MetaNamespace - - return MetaNamespace(_ir=self) - - def cast(self, dtype: DType) -> Cast: - from narwhals._plan.expressions.expr import Cast - - return Cast(expr=self, dtype=dtype) - - def alias(self, name: str) -> Alias: - from narwhals._plan.expressions.expr import Alias - - return Alias(expr=self, name=name) - - def _repr_html_(self) -> str: - return self.__repr__() - - -class SelectorIR(ExprIR, config=ExprIROptions.no_dispatch()): - def to_narwhals(self, version: Version = Version.MAIN) -> Selector: - from narwhals._plan import expr - - if version is Version.MAIN: - return expr.Selector._from_ir(self) - return expr.SelectorV1._from_ir(self) - - def matches_column(self, name: str, dtype: DType) -> bool: - """Return True if we can select this column. - - - Thinking that we could get more cache hits on an individual column basis. - - May also be more efficient to not iterate over the schema for every selector - - Instead do one pass, evaluating every selector against a single column at a time - """ - raise NotImplementedError(type(self)) - - -class NamedIR(Immutable, Generic[ExprIRT]): - """Post-projection expansion wrapper for `ExprIR`. - - - Somewhat similar to [`polars_plan::plans::expr_ir::ExprIR`]. - - The [`polars_plan::plans::aexpr::AExpr`] stage has been skipped (*for now*) - - Parts of that will probably be in here too - - `AExpr` seems like too much duplication when we won't get the memory allocation benefits in python - - [`polars_plan::plans::expr_ir::ExprIR`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/expr_ir.rs#L63-L74 - [`polars_plan::plans::aexpr::AExpr`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/mod.rs#L145-L231 - """ - - __slots__ = ("expr", "name") - expr: ExprIRT - name: str - - @staticmethod - def from_name(name: str, /) -> NamedIR[Column]: - """Construct as a simple, unaliased `col(name)` expression. - - Intended to be used in `with_columns` from a `FrozenSchema`'s keys. - """ - from narwhals._plan.expressions.expr import col - - return NamedIR(expr=col(name), name=name) - - @staticmethod - def from_ir(expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: - """Construct from an already expanded `ExprIR`. - - Should be cheap to get the output name from cache, but will raise if used - without care. - """ - return NamedIR(expr=expr, name=expr.meta.output_name(raise_if_undetermined=True)) - - def map_ir(self, function: MapIR, /) -> Self: - """**WARNING**: don't use renaming ops here, or `self.name` is invalid.""" - return replace(self, expr=function(self.expr.map_ir(function))) - - def __repr__(self) -> str: - return f"{self.name}={self.expr!r}" - - def _repr_html_(self) -> str: - return f"{self.name}={self.expr._repr_html_()}" - - def is_elementwise_top_level(self) -> bool: - """Return True if the outermost node is elementwise. - - Based on [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`] - - This check: - - Is not recursive - - Is not valid on `ExprIR` *prior* to being expanded - - [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`]: https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-plan/src/plans/aexpr/properties.rs#L16-L44 - """ - from narwhals._plan.expressions import expr - - ir = self.expr - if is_function_expr(ir): - return ir.options.is_elementwise() - if is_literal(ir): - return ir.is_scalar - return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.TernaryExpr, expr.Cast)) - - def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: dtypes = version.dtypes mapping: dict[type[NonNestedLiteral], type[NonNestedDType]] = { diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 040f598b35..8f06f1e5c9 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -21,8 +21,7 @@ import pyarrow as pa from typing_extensions import Self - from narwhals._plan.common import NamedIR - from narwhals._plan.expressions import ExprIR + from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.protocols import CompliantBaseFrame, CompliantDataFrame from narwhals._plan.schema import FrozenSchema from narwhals._plan.typing import Seq diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index c9a7fe8f6f..19ff195827 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -1,6 +1,10 @@ from __future__ import annotations -from narwhals._plan.common import ExprIR, SelectorIR # prob should move into package? +from narwhals._plan._expr_ir import ( # prob should move into package? + ExprIR, + NamedIR, + SelectorIR, +) from narwhals._plan.expressions import ( aggregation, boolean, @@ -61,6 +65,7 @@ "KeepName", "Len", "Literal", + "NamedIR", "Nth", "OrderableAggExpr", "OrderedWindowExpr", diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index a0ba57f7f1..263ca300e5 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan.common import ExprIR, pascal_to_snake_case +from narwhals._plan._expr_ir import ExprIR +from narwhals._plan.common import pascal_to_snake_case from narwhals._plan.exceptions import agg_scalar_error if TYPE_CHECKING: diff --git a/narwhals/_plan/expressions/boolean.py b/narwhals/_plan/expressions/boolean.py index a11ff4569e..ebc2a8643b 100644 --- a/narwhals/_plan/expressions/boolean.py +++ b/narwhals/_plan/expressions/boolean.py @@ -11,7 +11,7 @@ if t.TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.common import ExprIR + from narwhals._plan._expr_ir import ExprIR from narwhals._plan.expressions.expr import FunctionExpr, Literal # noqa: F401 from narwhals._plan.series import Series from narwhals._plan.typing import NativeSeriesT, Seq # noqa: F401 diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index d33b603ccf..a898bf879b 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -6,7 +6,8 @@ # - Literal import typing as t -from narwhals._plan.common import ExprIR, SelectorIR, flatten_hash_safe +from narwhals._plan._expr_ir import ExprIR, SelectorIR +from narwhals._plan.common import flatten_hash_safe from narwhals._plan.exceptions import function_expr_invalid_operation_error from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ( diff --git a/narwhals/_plan/expressions/functions.py b/narwhals/_plan/expressions/functions.py index 94c9cebd6c..f7ff80abe5 100644 --- a/narwhals/_plan/expressions/functions.py +++ b/narwhals/_plan/expressions/functions.py @@ -13,7 +13,7 @@ from typing_extensions import Self - from narwhals._plan.common import ExprIR + from narwhals._plan._expr_ir import ExprIR from narwhals._plan.expressions.expr import AnonymousExpr, FunctionExpr, RollingExpr from narwhals._plan.options import EWMOptions, RankOptions, RollingOptionsFixedWindow from narwhals._plan.typing import Seq, Udf diff --git a/narwhals/_plan/expressions/name.py b/narwhals/_plan/expressions/name.py index a8460cd6dd..7f1e71fb09 100644 --- a/narwhals/_plan/expressions/name.py +++ b/narwhals/_plan/expressions/name.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from narwhals._plan import common +from narwhals._plan._expr_ir import ExprIR from narwhals._plan._immutable import Immutable from narwhals._plan.expressions.namespace import ExprNamespace, IRNamespace from narwhals._plan.options import ExprIROptions @@ -12,9 +12,9 @@ from narwhals._plan.expr import Expr -class KeepName(common.ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): +class KeepName(ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr",) - expr: common.ExprIR + expr: ExprIR @property def is_scalar(self) -> bool: @@ -24,9 +24,9 @@ def __repr__(self) -> str: return f"{self.expr!r}.name.keep()" -class RenameAlias(common.ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): +class RenameAlias(ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr", "function") - expr: common.ExprIR + expr: ExprIR function: AliasName @property diff --git a/narwhals/_plan/expressions/operators.py b/narwhals/_plan/expressions/operators.py index b8e9fc2b65..9ecc45737a 100644 --- a/narwhals/_plan/expressions/operators.py +++ b/narwhals/_plan/expressions/operators.py @@ -16,8 +16,7 @@ from typing_extensions import Self - from narwhals._plan.common import ExprIR - from narwhals._plan.expressions.expr import BinaryExpr, BinarySelector + from narwhals._plan.expressions import BinaryExpr, BinarySelector, ExprIR from narwhals._plan.typing import ( LeftSelectorT, LeftT, diff --git a/narwhals/_plan/expressions/ranges.py b/narwhals/_plan/expressions/ranges.py index f73b85f7e3..6befe3fa6d 100644 --- a/narwhals/_plan/expressions/ranges.py +++ b/narwhals/_plan/expressions/ranges.py @@ -8,8 +8,7 @@ if TYPE_CHECKING: from typing_extensions import Self - from narwhals._plan.common import ExprIR - from narwhals._plan.expressions.expr import RangeExpr + from narwhals._plan.expressions import ExprIR, RangeExpr from narwhals.dtypes import IntegerType diff --git a/narwhals/_plan/expressions/window.py b/narwhals/_plan/expressions/window.py index 329b36078c..da24ce564b 100644 --- a/narwhals/_plan/expressions/window.py +++ b/narwhals/_plan/expressions/window.py @@ -11,8 +11,7 @@ ) if TYPE_CHECKING: - from narwhals._plan.common import ExprIR - from narwhals._plan.expressions.expr import OrderedWindowExpr, WindowExpr + from narwhals._plan.expressions import ExprIR, OrderedWindowExpr, WindowExpr from narwhals._plan.options import SortOptions from narwhals._plan.typing import Seq from narwhals.exceptions import InvalidOperationError diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index fec4cad043..11a17eb081 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -3,7 +3,7 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized from typing import TYPE_CHECKING, Any, Literal, Protocol, overload -from narwhals._plan.common import NamedIR, flatten_hash_safe +from narwhals._plan.common import flatten_hash_safe from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT, Seq from narwhals._typing_compat import TypeVar from narwhals._utils import Version @@ -16,6 +16,7 @@ from narwhals._plan.expressions import ( BinaryExpr, FunctionExpr, + NamedIR, aggregation as agg, boolean, functions as F, diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 69c1b5a2b3..4dbf5e6ef3 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -7,8 +7,8 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, TypeVar, overload +from narwhals._plan._expr_ir import NamedIR from narwhals._plan._immutable import _IMMUTABLE_HASH_NAME, Immutable -from narwhals._plan.common import NamedIR from narwhals.dtypes import Unknown if TYPE_CHECKING: diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 3f60a1ea98..0efb81ea81 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -8,8 +8,8 @@ from typing_extensions import TypeAlias from narwhals import dtypes + from narwhals._plan._expr_ir import ExprIR, NamedIR, SelectorIR from narwhals._plan._function import Function - from narwhals._plan.common import ExprIR, NamedIR, SelectorIR from narwhals._plan.expr import Expr from narwhals._plan.expressions import operators as ops from narwhals._plan.expressions.functions import RollingWindow diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index 781087e9fc..18ae514f0d 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -11,8 +11,7 @@ from narwhals._plan.expr import Expr if TYPE_CHECKING: - from narwhals._plan.common import ExprIR - from narwhals._plan.expressions.expr import TernaryExpr + from narwhals._plan.expressions import ExprIR, TernaryExpr from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index f4f7368c7e..bf810aa176 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -12,7 +12,6 @@ rewrite_binary_agg_over, rewrite_elementwise_over, ) -from narwhals._plan.common import NamedIR from narwhals._plan.expressions.window import Over from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_expr_ir_equal @@ -80,9 +79,9 @@ def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: assert_expr_ir_equal(lhs, rhs) -def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> NamedIR[ir.ExprIR]: +def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> ir.NamedIR[ir.ExprIR]: """Helper constructor for test compare.""" - return NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name) + return ir.NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name) def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 78c8ffe75f..bf6135ee2f 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -4,25 +4,24 @@ from narwhals import _plan as nwp from narwhals._plan import expressions as ir -from narwhals._plan.common import NamedIR if TYPE_CHECKING: from typing_extensions import LiteralString -def _unwrap_ir(obj: nwp.Expr | ir.ExprIR | NamedIR) -> ir.ExprIR: +def _unwrap_ir(obj: nwp.Expr | ir.ExprIR | ir.NamedIR) -> ir.ExprIR: if isinstance(obj, nwp.Expr): return obj._ir if isinstance(obj, ir.ExprIR): return obj - if isinstance(obj, NamedIR): + if isinstance(obj, ir.NamedIR): return obj.expr raise NotImplementedError(type(obj)) def assert_expr_ir_equal( - actual: nwp.Expr | ir.ExprIR | NamedIR, - expected: nwp.Expr | ir.ExprIR | NamedIR | LiteralString, + actual: nwp.Expr | ir.ExprIR | ir.NamedIR, + expected: nwp.Expr | ir.ExprIR | ir.NamedIR | LiteralString, /, ) -> None: """Assert that `actual` is equivalent to `expected`. @@ -38,7 +37,7 @@ def assert_expr_ir_equal( lhs = _unwrap_ir(actual) if isinstance(expected, str): assert repr(lhs) == expected - elif isinstance(actual, NamedIR) and isinstance(expected, NamedIR): + elif isinstance(actual, ir.NamedIR) and isinstance(expected, ir.NamedIR): assert actual == expected else: rhs = expected._ir if isinstance(expected, nwp.Expr) else expected From d63e4b7b02bdf1b7a8370651acc6dee041a950f7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 14 Sep 2025 20:20:08 +0000 Subject: [PATCH 36/36] refactor: Redo `window` builders Exposes them under `ir.over*` --- narwhals/_plan/expr.py | 21 ++++---- narwhals/_plan/expressions/__init__.py | 3 ++ narwhals/_plan/expressions/window.py | 68 +++++++++++++------------- 3 files changed, 45 insertions(+), 47 deletions(-) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 7a95936523..b0f369bd77 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -17,7 +17,6 @@ operators as ops, ) from narwhals._plan.expressions.selectors import by_name -from narwhals._plan.expressions.window import Over from narwhals._plan.options import ( EWMOptions, RankOptions, @@ -154,20 +153,18 @@ def over( descending: bool = False, nulls_last: bool = False, ) -> Self: - node: ir.WindowExpr | ir.OrderedWindowExpr - partition: Seq[ir.ExprIR] = () if not (partition_by) and order_by is None: msg = "At least one of `partition_by` or `order_by` must be specified." raise TypeError(msg) - if partition_by: - partition = parse_into_seq_of_expr_ir(*partition_by) - if order_by is not None: - by = parse_into_seq_of_expr_ir(order_by) - options = SortOptions(descending=descending, nulls_last=nulls_last) - node = Over().to_ordered_window_expr(self._ir, partition, by, options) - else: - node = Over().to_window_expr(self._ir, partition) - return self._from_ir(node) + parse = parse_into_seq_of_expr_ir + fn = self._ir + group = parse(*partition_by) if partition_by else () + if order_by is None: + return self._from_ir(ir.over(fn, group)) + over = ir.over_ordered + order = parse(order_by) + desc, nulls = descending, nulls_last + return self._from_ir(over(fn, group, order, descending=desc, nulls_last=nulls)) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: options = SortOptions(descending=descending, nulls_last=nulls_last) diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 19ff195827..237ee36e81 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -45,6 +45,7 @@ nth, ) from narwhals._plan.expressions.name import KeepName, RenameAlias +from narwhals._plan.expressions.window import over, over_ordered __all__ = [ "AggExpr", @@ -87,5 +88,7 @@ "index_columns", "nth", "operators", + "over", + "over_ordered", "selectors", ] diff --git a/narwhals/_plan/expressions/window.py b/narwhals/_plan/expressions/window.py index da24ce564b..772af084f0 100644 --- a/narwhals/_plan/expressions/window.py +++ b/narwhals/_plan/expressions/window.py @@ -5,16 +5,16 @@ from narwhals._plan._guards import is_function_expr, is_window_expr from narwhals._plan._immutable import Immutable from narwhals._plan.exceptions import ( - over_elementwise_error, - over_nested_error, - over_row_separable_error, + over_elementwise_error as elementwise_error, + over_nested_error as nested_error, + over_row_separable_error as row_separable_error, ) +from narwhals._plan.expressions.expr import OrderedWindowExpr, WindowExpr +from narwhals._plan.options import SortOptions if TYPE_CHECKING: - from narwhals._plan.expressions import ExprIR, OrderedWindowExpr, WindowExpr - from narwhals._plan.options import SortOptions + from narwhals._plan.expressions import ExprIR from narwhals._plan.typing import Seq - from narwhals.exceptions import InvalidOperationError class Window(Immutable): @@ -29,41 +29,39 @@ def _validate_over( order_by: Seq[ExprIR] = (), sort_options: SortOptions | None = None, /, - ) -> InvalidOperationError | None: + ) -> ValueError | None: if is_window_expr(expr): - return over_nested_error(expr, partition_by, order_by, sort_options) + return nested_error(expr, partition_by, order_by, sort_options) if is_function_expr(expr): if expr.options.is_elementwise(): - return over_elementwise_error(expr, partition_by, order_by, sort_options) + return elementwise_error(expr, partition_by, order_by, sort_options) if expr.options.is_row_separable(): - return over_row_separable_error( - expr, partition_by, order_by, sort_options - ) + return row_separable_error(expr, partition_by, order_by, sort_options) return None - def to_window_expr(self, expr: ExprIR, partition_by: Seq[ExprIR], /) -> WindowExpr: - from narwhals._plan.expressions.expr import WindowExpr - if err := self._validate_over(expr, partition_by): - raise err - return WindowExpr(expr=expr, partition_by=partition_by, options=self) +def over(expr: ExprIR, partition_by: Seq[ExprIR], /) -> WindowExpr: + if err := Over._validate_over(expr, partition_by): + raise err + return WindowExpr(expr=expr, partition_by=partition_by, options=Over()) - def to_ordered_window_expr( - self, - expr: ExprIR, - partition_by: Seq[ExprIR], - order_by: Seq[ExprIR], - sort_options: SortOptions, - /, - ) -> OrderedWindowExpr: - from narwhals._plan.expressions.expr import OrderedWindowExpr - if err := self._validate_over(expr, partition_by, order_by, sort_options): - raise err - return OrderedWindowExpr( - expr=expr, - partition_by=partition_by, - order_by=order_by, - sort_options=sort_options, - options=self, - ) +def over_ordered( + expr: ExprIR, + partition_by: Seq[ExprIR], + order_by: Seq[ExprIR], + /, + *, + descending: bool = False, + nulls_last: bool = False, +) -> OrderedWindowExpr: + sort_options = SortOptions(descending=descending, nulls_last=nulls_last) + if err := Over._validate_over(expr, partition_by, order_by, sort_options): + raise err + return OrderedWindowExpr( + expr=expr, + partition_by=partition_by, + order_by=order_by, + sort_options=sort_options, + options=Over(), + )