From d6fd44a30975c946314ba862bc7aeb74c5568548 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 21 Mar 2025 15:41:09 +0000 Subject: [PATCH 01/20] wip --- narwhals/_duckdb/expr.py | 54 +++++++++++++++++--- narwhals/_spark_like/expr.py | 13 +++-- tests/expr_and_series/over_test.py | 12 ++--- tests/expr_and_series/reduction_test.py | 11 +--- tests/expr_and_series/sum_horizontal_test.py | 3 -- tests/frame/select_test.py | 2 - tests/stable_api_test.py | 4 +- 7 files changed, 59 insertions(+), 40 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index f450143afb..538c4aa055 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import duckdb import operator from typing import TYPE_CHECKING from typing import Any @@ -58,6 +59,7 @@ def __init__( self._alias_output_names = alias_output_names self._backend_version = backend_version self._version = version + self._window_function = None def __call__(self: Self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: return self._call(df) @@ -73,11 +75,20 @@ def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover ) def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: - if kind is ExprKind.AGGREGATION: - msg = "Broadcasting aggregations is not yet supported for DuckDB." - raise NotImplementedError(msg) - # For literals, DuckDB does its own broadcasting. - return self + if kind is ExprKind.LITERAL: + return self + def func(df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: + return [ + duckdb.SQLExpression(f'{result} over ()') for result in self(df) + ] + return self.__class__( + func, + function_name=self._function_name, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + ) @classmethod def from_column_names( @@ -423,6 +434,38 @@ def null_count(self: Self) -> Self: lambda _input: FunctionExpression("sum", _input.isnull().cast("int")), "null_count", ) + + def over( + self: Self, + partition_by: Sequence[str], + kind: ExprKind, + order_by: Sequence[str] | None, + ) -> Self: + if (window_function := self._window_function) is not None: + assert order_by is not None # noqa: S101 + + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + return [ + window_function(expr, partition_by, order_by) + for expr in self._call(df) + ] + else: + + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + return [ + duckdb.SQLExpression(f'{expr} over (partition by {",".join(partition_by)})') + for expr in self._call(df) + ] + + return self.__class__( + func, + function_name=self._function_name + "->over", + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + ) + def is_null(self: Self) -> Self: return self._from_call(lambda _input: _input.isnull(), "is_null") @@ -498,5 +541,4 @@ def struct(self: Self) -> DuckDBExprStructNamespace: cum_min = not_implemented() cum_max = not_implemented() cum_prod = not_implemented() - over = not_implemented() rolling_sum = not_implemented() diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 06d40c86b6..dd11b81dc2 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -65,14 +65,13 @@ def __call__(self: Self, df: SparkLikeLazyFrame) -> Sequence[Column]: return self._call(df) def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + if kind is ExprKind.LITERAL: + return self def func(df: SparkLikeLazyFrame) -> Sequence[Column]: - if kind is ExprKind.AGGREGATION: - return [ - result.over(df._Window().partitionBy(df._F.lit(1))) - for result in self(df) - ] - # Let PySpark do its own broadcasting for literals. - return self(df) + return [ + result.over(df._Window().partitionBy(df._F.lit(1))) + for result in self(df) + ] return self.__class__( func, diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py index 38bfdd2cc7..6fd573661c 100644 --- a/tests/expr_and_series/over_test.py +++ b/tests/expr_and_series/over_test.py @@ -31,8 +31,8 @@ def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) + # if "duckdb" in str(constructor): + # request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) expected = { @@ -85,10 +85,7 @@ def test_over_std_var(request: pytest.FixtureRequest, constructor: Constructor) assert_equal_data(result, expected) -def test_over_multiple(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_over_multiple(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) expected = { "a": ["a", "a", "b", "b", "b"], @@ -275,9 +272,6 @@ def test_over_anonymous_cumulative( def test_over_anonymous_reduction( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): - # TODO(unassigned): we should be able to support these - request.applymarker(pytest.mark.xfail) if "modin" in str(constructor): # probably bugged request.applymarker(pytest.mark.xfail) diff --git a/tests/expr_and_series/reduction_test.py b/tests/expr_and_series/reduction_test.py index b1d84e85e4..bd760d1b48 100644 --- a/tests/expr_and_series/reduction_test.py +++ b/tests/expr_and_series/reduction_test.py @@ -31,11 +31,7 @@ def test_scalar_reduction_select( constructor: Constructor, expr: list[Any], expected: dict[str, list[Any]], - request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor) and request.node.callspec.id not in {"duckdb-0"}: - request.applymarker(pytest.mark.xfail) - data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = nw.from_native(constructor(data)) result = df.select(*expr) @@ -63,10 +59,7 @@ def test_scalar_reduction_with_columns( constructor: Constructor, expr: list[Any], expected: dict[str, list[Any]], - request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = nw.from_native(constructor(data)) result = df.with_columns(*expr).select(*expected.keys()) @@ -108,10 +101,8 @@ def test_empty_scalar_reduction_select( def test_empty_scalar_reduction_with_columns( - constructor: Constructor, request: pytest.FixtureRequest + constructor: Constructor ) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) from itertools import chain data = { diff --git a/tests/expr_and_series/sum_horizontal_test.py b/tests/expr_and_series/sum_horizontal_test.py index aaef64356f..9e042fac04 100644 --- a/tests/expr_and_series/sum_horizontal_test.py +++ b/tests/expr_and_series/sum_horizontal_test.py @@ -60,9 +60,6 @@ def test_sumh_aggregations(constructor: Constructor) -> None: def test_sumh_transformations( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): - # We don't yet support broadcasting for DuckDB. - request.applymarker(pytest.mark.xfail) data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]} df = nw.from_native(constructor(data)) result = df.select(d=nw.sum_horizontal("a", nw.col("b").sum(), "c")) diff --git a/tests/frame/select_test.py b/tests/frame/select_test.py index 32ef82aef3..0110a911f9 100644 --- a/tests/frame/select_test.py +++ b/tests/frame/select_test.py @@ -121,8 +121,6 @@ def test_left_to_right_broadcasting( ) -> None: if "dask" in str(constructor) and DASK_VERSION < (2024, 10): request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 1, 2], "b": [4, 5, 6]})) result = df.select(nw.col("a") + nw.col("b").sum()) expected = {"a": [16, 16, 17]} diff --git a/tests/stable_api_test.py b/tests/stable_api_test.py index 24f8006c49..66a3ba44ff 100644 --- a/tests/stable_api_test.py +++ b/tests/stable_api_test.py @@ -19,10 +19,8 @@ def remove_docstring_examples(doc: str) -> str: def test_renamed_taxicab_norm( - constructor: Constructor, request: pytest.FixtureRequest + constructor: Constructor ) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) # Suppose we need to rename `_l1_norm` to `_taxicab_norm`. # We need `narwhals.stable.v1` to stay stable. So, we # make the change in `narwhals`, and then add the new method From 1655fbcbb1d2422d8cc47115ba8f2b46523e55b3 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 21 Mar 2025 16:11:42 +0000 Subject: [PATCH 02/20] support cum_sum too! --- narwhals/_duckdb/expr.py | 62 ++++++++++++++++++++++----- tests/expr_and_series/cum_sum_test.py | 12 ------ 2 files changed, 51 insertions(+), 23 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 538c4aa055..a1f94f471e 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -1,6 +1,5 @@ from __future__ import annotations -import duckdb import operator from typing import TYPE_CHECKING from typing import Any @@ -9,6 +8,7 @@ from typing import Sequence from typing import cast +import duckdb from duckdb import CaseExpression from duckdb import CoalesceOperator from duckdb import ColumnExpression @@ -29,11 +29,11 @@ from narwhals.utils import not_implemented if TYPE_CHECKING: - import duckdb from typing_extensions import Self from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.namespace import DuckDBNamespace + from narwhals._duckdb.typing import WindowFunction from narwhals.dtypes import DType from narwhals.utils import Version from narwhals.utils import _FullContext @@ -59,7 +59,7 @@ def __init__( self._alias_output_names = alias_output_names self._backend_version = backend_version self._version = version - self._window_function = None + self._window_function: WindowFunction | None = None def __call__(self: Self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: return self._call(df) @@ -77,14 +77,14 @@ def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: if kind is ExprKind.LITERAL: return self + def func(df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: - return [ - duckdb.SQLExpression(f'{result} over ()') for result in self(df) - ] + return [duckdb.SQLExpression(f"{result} over ()") for result in self(df)] + return self.__class__( func, function_name=self._function_name, - evaluate_output_names=self._evaluate_output_names, + evaluate_output_names=self._evaluate_output_names, alias_output_names=self._alias_output_names, backend_version=self._backend_version, version=self._version, @@ -164,6 +164,21 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: version=self._version, ) + def _with_window_function( + self: Self, + window_function: WindowFunction, + ) -> Self: + result = self.__class__( + self._call, + function_name=self._function_name, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + ) + result._window_function = window_function + return result + def __and__(self: Self, other: DuckDBExpr) -> Self: return self._from_call( lambda _input, other: _input & other, @@ -434,7 +449,7 @@ def null_count(self: Self) -> Self: lambda _input: FunctionExpression("sum", _input.isnull().cast("int")), "null_count", ) - + def over( self: Self, partition_by: Sequence[str], @@ -453,7 +468,9 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [ - duckdb.SQLExpression(f'{expr} over (partition by {",".join(partition_by)})') + duckdb.SQLExpression( + f"{expr} over (partition by {','.join(partition_by)})" + ) for expr in self._call(df) ] @@ -466,7 +483,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: version=self._version, ) - def is_null(self: Self) -> Self: return self._from_call(lambda _input: _input.isnull(), "is_null") @@ -490,6 +506,31 @@ def round(self: Self, decimals: int) -> Self: lambda _input: FunctionExpression("round", _input, lit(decimals)), "round" ) + def cum_sum(self, *, reverse: bool) -> Self: + def func( + _input: duckdb.Expression, + partition_by: Sequence[str], + order_by: Sequence[str], + ) -> duckdb.Expression: + if reverse: + order_by_sql = "order by " + ", ".join( + f'"{x}" desc nulls last' for x in order_by + ) + else: + order_by_sql = "order by " + ", ".join( + f'"{x}" asc nulls first' for x in order_by + ) + if partition_by: + partition_by_sql = "partition by " + ",".join( + f'"{x}"' for x in partition_by + ) + else: + partition_by_sql = "" + sql = f"sum ({_input}) over ({partition_by_sql} {order_by_sql} rows between unbounded preceding and current row)" + return duckdb.SQLExpression(sql) + + return self._with_window_function(func) + def fill_null( self: Self, value: Self | Any, strategy: Any, limit: int | None ) -> Self: @@ -536,7 +577,6 @@ def struct(self: Self) -> DuckDBExprStructNamespace: is_unique = not_implemented() is_first_distinct = not_implemented() is_last_distinct = not_implemented() - cum_sum = not_implemented() cum_count = not_implemented() cum_min = not_implemented() cum_max = not_implemented() diff --git a/tests/expr_and_series/cum_sum_test.py b/tests/expr_and_series/cum_sum_test.py index 884921e7b4..ecf29157e8 100644 --- a/tests/expr_and_series/cum_sum_test.py +++ b/tests/expr_and_series/cum_sum_test.py @@ -40,9 +40,6 @@ def test_lazy_cum_sum_grouped( reverse: bool, expected_a: list[int], ) -> None: - if "duckdb" in str(constructor): - # no window function support yet in duckdb - request.applymarker(pytest.mark.xfail) if "pyarrow_table" in str(constructor): # grouped window functions not yet supported request.applymarker(pytest.mark.xfail) @@ -89,9 +86,6 @@ def test_lazy_cum_sum_ordered_by_nulls( reverse: bool, expected_a: list[int], ) -> None: - if "duckdb" in str(constructor): - # no window function support yet in duckdb - request.applymarker(pytest.mark.xfail) if "pyarrow_table" in str(constructor): # grouped window functions not yet supported request.applymarker(pytest.mark.xfail) @@ -142,9 +136,6 @@ def test_lazy_cum_sum_ungrouped( reverse: bool, expected_a: list[int], ) -> None: - if "duckdb" in str(constructor): - # no window function support yet in duckdb - request.applymarker(pytest.mark.xfail) if "dask" in str(constructor) and reverse: # https://github.com/dask/dask/issues/11802 request.applymarker(pytest.mark.xfail) @@ -184,9 +175,6 @@ def test_lazy_cum_sum_ungrouped_ordered_by_nulls( reverse: bool, expected_a: list[int], ) -> None: - if "duckdb" in str(constructor): - # no window function support yet in duckdb - request.applymarker(pytest.mark.xfail) if "dask" in str(constructor): # https://github.com/dask/dask/issues/11806 request.applymarker(pytest.mark.xfail) From 8c36db7f73629bec35feb1c8b1693154baa53fef Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 21 Mar 2025 18:14:02 +0000 Subject: [PATCH 03/20] raise on lower versions --- narwhals/_duckdb/expr.py | 37 ++++++++++++++++++++++- tests/expr_and_series/rolling_sum_test.py | 5 +-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index a1f94f471e..fdaf6b4810 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -77,6 +77,9 @@ def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: if kind is ExprKind.LITERAL: return self + if self._backend_version < (1, 3): + msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns." + raise NotImplementedError(msg) def func(df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: return [duckdb.SQLExpression(f"{result} over ()") for result in self(df)] @@ -456,6 +459,9 @@ def over( kind: ExprKind, order_by: Sequence[str] | None, ) -> Self: + if self._backend_version < (1, 3): + msg = "At least version 1.3 of DuckDB is required for `over` operation." + raise NotImplementedError(msg) if (window_function := self._window_function) is not None: assert order_by is not None # noqa: S101 @@ -531,6 +537,36 @@ def func( return self._with_window_function(func) + def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self: + if center: + half = (window_size - 1) // 2 + remainder = (window_size - 1) % 2 + start = f"{half + remainder} preceding" + end = f"{half} following" + else: + start = f"{window_size - 1} preceding" + end = "current row" + + def func( + _input: duckdb.Expression, + partition_by: Sequence[str], + order_by: Sequence[str], + ) -> duckdb.Expression: + order_by_sql = "order by " + ", ".join( + f'"{x}" asc nulls first' for x in order_by + ) + if partition_by: + partition_by_sql = "partition by " + ",".join( + f'"{x}"' for x in partition_by + ) + else: + partition_by_sql = "" + window = f"({partition_by_sql} {order_by_sql} rows between {start} and {end})" + sql = f"case when count({_input}) over {window} >= {min_samples} then sum({_input}) over {window} else null end" + return duckdb.SQLExpression(sql) + + return self._with_window_function(func) + def fill_null( self: Self, value: Self | Any, strategy: Any, limit: int | None ) -> Self: @@ -581,4 +617,3 @@ def struct(self: Self) -> DuckDBExprStructNamespace: cum_min = not_implemented() cum_max = not_implemented() cum_prod = not_implemented() - rolling_sum = not_implemented() diff --git a/tests/expr_and_series/rolling_sum_test.py b/tests/expr_and_series/rolling_sum_test.py index c4e30d571c..2aa786b35c 100644 --- a/tests/expr_and_series/rolling_sum_test.py +++ b/tests/expr_and_series/rolling_sum_test.py @@ -76,14 +76,11 @@ def test_rolling_sum_expr_lazy_ungrouped( expected_a: list[float], window_size: int, min_samples: int, - request: pytest.FixtureRequest, *, center: bool, ) -> None: if "polars" in str(constructor) and POLARS_VERSION < (1, 10): pytest.skip() - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) if "modin" in str(constructor): # unreliable pytest.skip() @@ -134,7 +131,7 @@ def test_rolling_sum_expr_lazy_grouped( pytest.skip() if "pandas" in str(constructor) and PANDAS_VERSION < (1, 2): pytest.skip() - if any(x in str(constructor) for x in ("dask", "pyarrow_table", "duckdb")): + if any(x in str(constructor) for x in ("dask", "pyarrow_table")): request.applymarker(pytest.mark.xfail) if "cudf" in str(constructor) and center: # center is not implemented for offset-based windows From ba5b646c170d19acab7e37f615753489dfc56654 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 21 Mar 2025 18:44:34 +0000 Subject: [PATCH 04/20] feat: support window operations for DuckDB --- .github/workflows/extremes.yml | 5 +++-- tests/expr_and_series/cum_sum_test.py | 17 +++++++++++++---- tests/expr_and_series/over_test.py | 13 +++++++++---- tests/expr_and_series/reduction_test.py | 11 ++++++++--- tests/expr_and_series/rolling_sum_test.py | 9 +++++++-- tests/expr_and_series/sum_horizontal_test.py | 7 ++++--- tests/frame/select_test.py | 3 +++ tests/stable_api_test.py | 7 ++++--- 8 files changed, 51 insertions(+), 21 deletions(-) diff --git a/.github/workflows/extremes.yml b/.github/workflows/extremes.yml index 31061e587f..26411e79f2 100644 --- a/.github/workflows/extremes.yml +++ b/.github/workflows/extremes.yml @@ -179,9 +179,10 @@ jobs: run: | uv pip uninstall dask dask-expr --system python -m pip install git+https://github.com/dask/distributed git+https://github.com/dask/dask - - name: install duckdb + - name: install duckdb nightly run: | - python -m pip install -U --pre duckdb + uv pip uninstall duckdb --system + uv pip install -U --pre duckdb --system - name: show-deps run: uv pip freeze - name: Assert nightlies dependencies diff --git a/tests/expr_and_series/cum_sum_test.py b/tests/expr_and_series/cum_sum_test.py index ecf29157e8..e4d9de2210 100644 --- a/tests/expr_and_series/cum_sum_test.py +++ b/tests/expr_and_series/cum_sum_test.py @@ -3,6 +3,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import DUCKDB_VERSION from tests.utils import POLARS_VERSION from tests.utils import Constructor from tests.utils import ConstructorEager @@ -49,7 +50,9 @@ def test_lazy_cum_sum_grouped( if "dask" in str(constructor): # https://github.com/dask/dask/issues/11806 request.applymarker(pytest.mark.xfail) - if "polars" in str(constructor) and POLARS_VERSION < (1, 9): + if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or ( + "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3) + ): pytest.skip(reason="too old version") if "cudf" in str(constructor): # https://github.com/rapidsai/cudf/issues/18159 @@ -95,7 +98,9 @@ def test_lazy_cum_sum_ordered_by_nulls( if "dask" in str(constructor): # https://github.com/dask/dask/issues/11806 request.applymarker(pytest.mark.xfail) - if "polars" in str(constructor) and POLARS_VERSION < (1, 9): + if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or ( + "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3) + ): pytest.skip(reason="too old version") if "cudf" in str(constructor): # https://github.com/rapidsai/cudf/issues/18159 @@ -142,7 +147,9 @@ def test_lazy_cum_sum_ungrouped( if "modin" in str(constructor): # probably bugged request.applymarker(pytest.mark.xfail) - if "polars" in str(constructor) and POLARS_VERSION < (1, 9): + if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or ( + "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3) + ): pytest.skip(reason="too old version") df = nw.from_native( @@ -181,7 +188,9 @@ def test_lazy_cum_sum_ungrouped_ordered_by_nulls( if "modin" in str(constructor): # probably bugged request.applymarker(pytest.mark.xfail) - if "polars" in str(constructor) and POLARS_VERSION < (1, 9): + if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or ( + "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3) + ): pytest.skip(reason="too old version") df = nw.from_native( diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py index 6fd573661c..61c53d28a1 100644 --- a/tests/expr_and_series/over_test.py +++ b/tests/expr_and_series/over_test.py @@ -9,6 +9,7 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import LengthChangingExprError +from tests.utils import DUCKDB_VERSION from tests.utils import PANDAS_VERSION from tests.utils import POLARS_VERSION from tests.utils import Constructor @@ -30,9 +31,9 @@ } -def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) -> None: - # if "duckdb" in str(constructor): - # request.applymarker(pytest.mark.xfail) +def test_over_single(constructor: Constructor) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() df = nw.from_native(constructor(data)) expected = { @@ -86,6 +87,8 @@ def test_over_std_var(request: pytest.FixtureRequest, constructor: Constructor) def test_over_multiple(constructor: Constructor) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() df = nw.from_native(constructor(data)) expected = { "a": ["a", "a", "b", "b", "b"], @@ -272,6 +275,8 @@ def test_over_anonymous_cumulative( def test_over_anonymous_reduction( constructor: Constructor, request: pytest.FixtureRequest ) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() if "modin" in str(constructor): # probably bugged request.applymarker(pytest.mark.xfail) @@ -415,7 +420,7 @@ def test_over_without_partition_by( ) -> None: if "polars" in str(constructor) and POLARS_VERSION < (1, 10): pytest.skip() - if "duckdb" in str(constructor): + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): # windows not yet supported request.applymarker(pytest.mark.xfail) if "modin" in str(constructor): diff --git a/tests/expr_and_series/reduction_test.py b/tests/expr_and_series/reduction_test.py index bd760d1b48..47d55e11f8 100644 --- a/tests/expr_and_series/reduction_test.py +++ b/tests/expr_and_series/reduction_test.py @@ -5,6 +5,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import DUCKDB_VERSION from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data @@ -32,6 +33,8 @@ def test_scalar_reduction_select( expr: list[Any], expected: dict[str, list[Any]], ) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = nw.from_native(constructor(data)) result = df.select(*expr) @@ -60,6 +63,8 @@ def test_scalar_reduction_with_columns( expr: list[Any], expected: dict[str, list[Any]], ) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = nw.from_native(constructor(data)) result = df.with_columns(*expr).select(*expected.keys()) @@ -100,9 +105,9 @@ def test_empty_scalar_reduction_select( assert_equal_data(result, expected) -def test_empty_scalar_reduction_with_columns( - constructor: Constructor -) -> None: +def test_empty_scalar_reduction_with_columns(constructor: Constructor) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() from itertools import chain data = { diff --git a/tests/expr_and_series/rolling_sum_test.py b/tests/expr_and_series/rolling_sum_test.py index 2aa786b35c..da7ea09cf8 100644 --- a/tests/expr_and_series/rolling_sum_test.py +++ b/tests/expr_and_series/rolling_sum_test.py @@ -11,6 +11,7 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import InvalidOperationError +from tests.utils import DUCKDB_VERSION from tests.utils import PANDAS_VERSION from tests.utils import POLARS_VERSION from tests.utils import Constructor @@ -79,7 +80,9 @@ def test_rolling_sum_expr_lazy_ungrouped( *, center: bool, ) -> None: - if "polars" in str(constructor) and POLARS_VERSION < (1, 10): + if ("polars" in str(constructor) and POLARS_VERSION < (1, 10)) or ( + "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3) + ): pytest.skip() if "modin" in str(constructor): # unreliable @@ -127,7 +130,9 @@ def test_rolling_sum_expr_lazy_grouped( *, center: bool, ) -> None: - if "polars" in str(constructor) and POLARS_VERSION < (1, 10): + if ("polars" in str(constructor) and POLARS_VERSION < (1, 10)) or ( + "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3) + ): pytest.skip() if "pandas" in str(constructor) and PANDAS_VERSION < (1, 2): pytest.skip() diff --git a/tests/expr_and_series/sum_horizontal_test.py b/tests/expr_and_series/sum_horizontal_test.py index 9e042fac04..3da18660b6 100644 --- a/tests/expr_and_series/sum_horizontal_test.py +++ b/tests/expr_and_series/sum_horizontal_test.py @@ -5,6 +5,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import DUCKDB_VERSION from tests.utils import Constructor from tests.utils import assert_equal_data @@ -57,9 +58,9 @@ def test_sumh_aggregations(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_sumh_transformations( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: +def test_sumh_transformations(constructor: Constructor) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]} df = nw.from_native(constructor(data)) result = df.select(d=nw.sum_horizontal("a", nw.col("b").sum(), "c")) diff --git a/tests/frame/select_test.py b/tests/frame/select_test.py index 0110a911f9..6b29d97d3b 100644 --- a/tests/frame/select_test.py +++ b/tests/frame/select_test.py @@ -11,6 +11,7 @@ from narwhals.exceptions import InvalidIntoExprError from narwhals.exceptions import NarwhalsError from tests.utils import DASK_VERSION +from tests.utils import DUCKDB_VERSION from tests.utils import PANDAS_VERSION from tests.utils import POLARS_VERSION from tests.utils import Constructor @@ -119,6 +120,8 @@ def test_missing_columns( def test_left_to_right_broadcasting( constructor: Constructor, request: pytest.FixtureRequest ) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() if "dask" in str(constructor) and DASK_VERSION < (2024, 10): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 1, 2], "b": [4, 5, 6]})) diff --git a/tests/stable_api_test.py b/tests/stable_api_test.py index 66a3ba44ff..4d8440a260 100644 --- a/tests/stable_api_test.py +++ b/tests/stable_api_test.py @@ -8,6 +8,7 @@ import narwhals as nw import narwhals.stable.v1 as nw_v1 +from tests.utils import DUCKDB_VERSION from tests.utils import Constructor from tests.utils import assert_equal_data @@ -18,9 +19,9 @@ def remove_docstring_examples(doc: str) -> str: return doc.rstrip() -def test_renamed_taxicab_norm( - constructor: Constructor -) -> None: +def test_renamed_taxicab_norm(constructor: Constructor) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() # Suppose we need to rename `_l1_norm` to `_taxicab_norm`. # We need `narwhals.stable.v1` to stay stable. So, we # make the change in `narwhals`, and then add the new method From 8235348fd6302f4dd8276de26737fa89ee355086 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 21 Mar 2025 20:10:17 +0000 Subject: [PATCH 05/20] use duckdb nightly in full coverage job --- .github/workflows/pytest.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index b369afcbeb..ee339c7cbc 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -84,6 +84,8 @@ jobs: cache-dependency-glob: "pyproject.toml" - name: install-reqs run: uv pip install -e ".[modin, dask]" --group core-tests --group extra --system + - name: install duckdb nightly + run: uv pip install -U --pre duckdb - name: install pyspark run: uv pip install -e ".[pyspark]" --system # PySpark is not yet available on Python3.12+ From f2cc84d294e712d662c53ac40398d7096d3f9532 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 21 Mar 2025 20:21:44 +0000 Subject: [PATCH 06/20] ci --- .github/workflows/pytest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ee339c7cbc..338cc645e1 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -85,7 +85,7 @@ jobs: - name: install-reqs run: uv pip install -e ".[modin, dask]" --group core-tests --group extra --system - name: install duckdb nightly - run: uv pip install -U --pre duckdb + run: uv pip install -U --pre duckdb --system - name: install pyspark run: uv pip install -e ".[pyspark]" --system # PySpark is not yet available on Python3.12+ From 06de928849f4e12f64f0b87160191e61cdfb0ddd Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 21 Mar 2025 20:26:39 +0000 Subject: [PATCH 07/20] ci fixup --- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index 39f23d2cf4..5a0feb5f93 100644 --- a/Makefile +++ b/Makefile @@ -20,5 +20,7 @@ help: ## Display this help screen .PHONY: typing typing: ## Run typing checks + # install duckdb nightly so mypy recognises duckdb.SQLExpression + $(VENV_BIN)/uv pip install -U --pre duckdb $(VENV_BIN)/uv pip install -e . --group typing $(VENV_BIN)/mypy From 35e603cdd150e54215e8b59ad44f995f47bf925c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 21 Mar 2025 20:28:15 +0000 Subject: [PATCH 08/20] lint --- narwhals/_spark_like/expr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index dd11b81dc2..7df083d9cc 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -67,10 +67,10 @@ def __call__(self: Self, df: SparkLikeLazyFrame) -> Sequence[Column]: def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: if kind is ExprKind.LITERAL: return self + def func(df: SparkLikeLazyFrame) -> Sequence[Column]: return [ - result.over(df._Window().partitionBy(df._F.lit(1))) - for result in self(df) + result.over(df._Window().partitionBy(df._F.lit(1))) for result in self(df) ] return self.__class__( From 38a471077bca194dd2174d22bbad4a91b5b5a37e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 21 Mar 2025 20:32:00 +0000 Subject: [PATCH 09/20] missing file --- narwhals/_duckdb/typing.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 narwhals/_duckdb/typing.py diff --git a/narwhals/_duckdb/typing.py b/narwhals/_duckdb/typing.py new file mode 100644 index 0000000000..9661351fb8 --- /dev/null +++ b/narwhals/_duckdb/typing.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Protocol +from typing import Sequence + +if TYPE_CHECKING: + import duckdb + + class WindowFunction(Protocol): + def __call__( + self, + _input: duckdb.Expression, + partition_by: Sequence[str], + order_by: Sequence[str], + ) -> duckdb.Expression: ... From 1c08990fe418b5f70104cf75e44cb5c95fff0b06 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 09:31:01 +0000 Subject: [PATCH 10/20] import sqlexpression --- narwhals/_duckdb/expr.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 086d064f88..2be772dd28 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -8,11 +8,11 @@ from typing import Sequence from typing import cast -import duckdb from duckdb import CaseExpression from duckdb import CoalesceOperator from duckdb import ColumnExpression from duckdb import FunctionExpression +from duckdb import SQLExpression from duckdb.typing import DuckDBPyType from narwhals._compliant import LazyExpr @@ -29,6 +29,7 @@ from narwhals.utils import not_implemented if TYPE_CHECKING: + import duckdb from typing_extensions import Self from narwhals._duckdb.dataframe import DuckDBLazyFrame @@ -96,7 +97,7 @@ def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Se raise NotImplementedError(msg) def func(df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: - return [duckdb.SQLExpression(f"{result} over ()") for result in self(df)] + return [SQLExpression(f"{result} over ()") for result in self(df)] return self.__class__( func, @@ -470,7 +471,6 @@ def null_count(self: Self) -> Self: def over( self: Self, partition_by: Sequence[str], - kind: ExprKind, order_by: Sequence[str] | None, ) -> Self: if self._backend_version < (1, 3): @@ -488,9 +488,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [ - duckdb.SQLExpression( - f"{expr} over (partition by {','.join(partition_by)})" - ) + SQLExpression(f"{expr} over (partition by {','.join(partition_by)})") for expr in self._call(df) ] @@ -547,7 +545,7 @@ def func( else: partition_by_sql = "" sql = f"sum ({_input}) over ({partition_by_sql} {order_by_sql} rows between unbounded preceding and current row)" - return duckdb.SQLExpression(sql) + return SQLExpression(sql) return self._with_window_function(func) @@ -577,7 +575,7 @@ def func( partition_by_sql = "" window = f"({partition_by_sql} {order_by_sql} rows between {start} and {end})" sql = f"case when count({_input}) over {window} >= {min_samples} then sum({_input}) over {window} else null end" - return duckdb.SQLExpression(sql) + return SQLExpression(sql) return self._with_window_function(func) From eaf0172848778a9a7519fa14f81cc42b14378383 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 09:36:07 +0000 Subject: [PATCH 11/20] post-merge fixup --- narwhals/_duckdb/expr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 2be772dd28..845818c6f9 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -86,6 +86,8 @@ def _with_metadata(self, metadata: ExprMetadata) -> Self: backend_version=self._backend_version, version=self._version, ) + if func := self._window_function: + expr = expr._with_window_function(func) expr._metadata = metadata return expr @@ -488,7 +490,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [ - SQLExpression(f"{expr} over (partition by {','.join(partition_by)})") + SQLExpression(f"{expr} over partition by {','.join(partition_by)}") for expr in self._call(df) ] From 64853036a33bdf5c49bed6ed3af57786d3c76ac3 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 09:39:55 +0000 Subject: [PATCH 12/20] extra quotes for safety --- narwhals/_duckdb/expr.py | 4 +++- narwhals/dependencies.py | 9 +++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 845818c6f9..9958dc4e8a 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -490,7 +490,9 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [ - SQLExpression(f"{expr} over partition by {','.join(partition_by)}") + SQLExpression( + f"{expr} over (partition by {','.join(f'"{x}"' for x in partition_by)})" + ) for expr in self._call(df) ] diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 071df0a155..1a00e49d9a 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -233,10 +233,11 @@ def is_pyspark_dataframe(df: Any) -> TypeIs[pyspark_sql.DataFrame]: def is_sqlframe_dataframe(df: Any) -> TypeIs[SQLFrameDataFrame]: """Check whether `df` is a SQLFrame DataFrame without importing SQLFrame.""" - return bool( - (sqlframe := get_sqlframe()) is not None - and isinstance(df, sqlframe.base.dataframe.BaseDataFrame) - ) + if get_sqlframe() is not None: + from sqlframe.base.dataframe import BaseDataFrame + + return isinstance(df, BaseDataFrame) + return False def is_numpy_array(arr: Any | _NDArray[_ShapeT]) -> TypeIs[_NDArray[_ShapeT]]: From 5e50682ca28c7f2b2c9463e34daeaf52ec7cd580 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 09:48:16 +0000 Subject: [PATCH 13/20] extra quotes for safety --- narwhals/_duckdb/dataframe.py | 4 ++-- narwhals/_duckdb/expr.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 45504da2e3..0eb623b25e 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -366,8 +366,8 @@ def unique( query = f""" with cte as ( select *, - row_number() over (partition by {",".join(subset)}) as {idx_name}, - count(*) over (partition by {",".join(subset)}) as {count_name} + row_number() over (partition by {",".join([f'"{x}"' for x in subset])}) as {idx_name}, + count(*) over (partition by {",".join([f'"{x}"' for x in subset])}) as {count_name} from rel ) select * exclude ({idx_name}, {count_name}) from cte {keep_condition} diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 9958dc4e8a..e9bbc39c0e 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -491,7 +491,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [ SQLExpression( - f"{expr} over (partition by {','.join(f'"{x}"' for x in partition_by)})" + f"{expr} over (partition by {', '.join([f'"{x}"' for x in partition_by])})" ) for expr in self._call(df) ] From 65ae7cbaf7c2f31b90873501f51801f68896a33a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 09:57:54 +0000 Subject: [PATCH 14/20] old python compat --- narwhals/_duckdb/dataframe.py | 6 ++++-- narwhals/_duckdb/expr.py | 13 ++++--------- narwhals/_duckdb/utils.py | 7 +++++++ 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 0eb623b25e..fcddeb857e 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -12,6 +12,7 @@ from duckdb import FunctionExpression from narwhals._duckdb.utils import evaluate_exprs +from narwhals._duckdb.utils import generate_partition_by_sql from narwhals._duckdb.utils import lit from narwhals._duckdb.utils import native_to_narwhals_dtype from narwhals.dependencies import get_duckdb @@ -363,11 +364,12 @@ def unique( keep_condition = f"where {count_name}=1" else: keep_condition = f"where {idx_name}=1" + partition_by_sql = generate_partition_by_sql(*subset) query = f""" with cte as ( select *, - row_number() over (partition by {",".join([f'"{x}"' for x in subset])}) as {idx_name}, - count(*) over (partition by {",".join([f'"{x}"' for x in subset])}) as {count_name} + row_number() over {partition_by_sql} as {idx_name}, + count(*) over {partition_by_sql} as {count_name} from rel ) select * exclude ({idx_name}, {count_name}) from cte {keep_condition} diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index e9bbc39c0e..e18ac567d0 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -21,6 +21,7 @@ from narwhals._duckdb.expr_name import DuckDBExprNameNamespace from narwhals._duckdb.expr_str import DuckDBExprStringNamespace from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace +from narwhals._duckdb.utils import generate_partition_by_sql from narwhals._duckdb.utils import lit from narwhals._duckdb.utils import maybe_evaluate_expr from narwhals._duckdb.utils import narwhals_to_native_dtype @@ -487,12 +488,11 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: for expr in self._call(df) ] else: + partition_by_sql = generate_partition_by_sql(*partition_by) def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [ - SQLExpression( - f"{expr} over (partition by {', '.join([f'"{x}"' for x in partition_by])})" - ) + SQLExpression(f"{expr} over {partition_by_sql}") for expr in self._call(df) ] @@ -542,12 +542,7 @@ def func( order_by_sql = "order by " + ", ".join( f'"{x}" asc nulls first' for x in order_by ) - if partition_by: - partition_by_sql = "partition by " + ",".join( - f'"{x}"' for x in partition_by - ) - else: - partition_by_sql = "" + partition_by_sql = generate_partition_by_sql(*partition_by) sql = f"sum ({_input}) over ({partition_by_sql} {order_by_sql} rows between unbounded preceding and current row)" return SQLExpression(sql) diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 3bc4d950ae..74fc0af3b8 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -189,3 +189,10 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st return f"{duckdb_inner}{duckdb_shape_fmt}" msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) + + +def generate_partition_by_sql(*partition_by: str) -> str: + if not partition_by: + return "" + by_sql = ", ".join([f'"{x}"' for x in partition_by]) + return f"(partition by {by_sql})" From bc515afe0a0ad3f0c5bae397c2ad65c477dfcd3d Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 09:59:51 +0000 Subject: [PATCH 15/20] old python compat --- narwhals/_duckdb/dataframe.py | 4 ++-- narwhals/_duckdb/expr.py | 2 +- narwhals/_duckdb/utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index fcddeb857e..d7894c345c 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -368,8 +368,8 @@ def unique( query = f""" with cte as ( select *, - row_number() over {partition_by_sql} as {idx_name}, - count(*) over {partition_by_sql} as {count_name} + row_number() over ({partition_by_sql}) as {idx_name}, + count(*) over ({partition_by_sql}) as {count_name} from rel ) select * exclude ({idx_name}, {count_name}) from cte {keep_condition} diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index e18ac567d0..af8084e8f7 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -492,7 +492,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [ - SQLExpression(f"{expr} over {partition_by_sql}") + SQLExpression(f"{expr} over ({partition_by_sql})") for expr in self._call(df) ] diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 74fc0af3b8..14dabf0323 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -195,4 +195,4 @@ def generate_partition_by_sql(*partition_by: str) -> str: if not partition_by: return "" by_sql = ", ".join([f'"{x}"' for x in partition_by]) - return f"(partition by {by_sql})" + return f"partition by {by_sql}" From eecff037ec6abb031a0eb28caae1dcd245995eaf Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 12:36:10 +0000 Subject: [PATCH 16/20] make test have empty space --- narwhals/_duckdb/expr.py | 5 +- tests/expr_and_series/cum_sum_test.py | 68 +++++++++++++-------------- 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index af8084e8f7..bf08860266 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import operator from typing import TYPE_CHECKING from typing import Any @@ -12,7 +13,6 @@ from duckdb import CoalesceOperator from duckdb import ColumnExpression from duckdb import FunctionExpression -from duckdb import SQLExpression from duckdb.typing import DuckDBPyType from narwhals._compliant import LazyExpr @@ -41,6 +41,9 @@ from narwhals.utils import Version from narwhals.utils import _FullContext +with contextlib.suppress(ImportError): # requires duckdb>=1.3.0 + from duckdb import SQLExpression + class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]): _implementation = Implementation.DUCKDB diff --git a/tests/expr_and_series/cum_sum_test.py b/tests/expr_and_series/cum_sum_test.py index e4d9de2210..ceb674af42 100644 --- a/tests/expr_and_series/cum_sum_test.py +++ b/tests/expr_and_series/cum_sum_test.py @@ -9,7 +9,7 @@ from tests.utils import ConstructorEager from tests.utils import assert_equal_data -data = {"a": [1, 2, None, 4]} +data = {"arg entina": [1, 2, None, 4]} expected = { "cum_sum": [1, 3, None, 7], "reverse_cum_sum": [7, 6, None, 4], @@ -21,7 +21,7 @@ def test_cum_sum_expr(constructor_eager: ConstructorEager, *, reverse: bool) -> name = "reverse_cum_sum" if reverse else "cum_sum" df = nw.from_native(constructor_eager(data)) result = df.select( - nw.col("a").cum_sum(reverse=reverse).alias(name), + nw.col("arg entina").cum_sum(reverse=reverse).alias(name), ) assert_equal_data(result, {name: expected[name]}) @@ -61,17 +61,17 @@ def test_lazy_cum_sum_grouped( df = nw.from_native( constructor( { - "a": [1, 2, 3], - "b": [1, 0, 2], - "i": [0, 1, 2], + "arg entina": [1, 2, 3], + "ban gkock": [1, 0, 2], + "i ran": [0, 1, 2], "g": [1, 1, 1], } ) ) result = df.with_columns( - nw.col("a").cum_sum(reverse=reverse).over("g", _order_by="b") - ).sort("i") - expected = {"a": expected_a, "b": [1, 0, 2], "i": [0, 1, 2]} + nw.col("arg entina").cum_sum(reverse=reverse).over("g", _order_by="ban gkock") + ).sort("i ran") + expected = {"arg entina": expected_a, "ban gkock": [1, 0, 2], "i ran": [0, 1, 2]} assert_equal_data(result, expected) @@ -109,20 +109,20 @@ def test_lazy_cum_sum_ordered_by_nulls( df = nw.from_native( constructor( { - "a": [1, 2, 3, 1, 2, 3, 4], - "b": [1, -1, 3, 2, 5, 0, None], - "i": [0, 1, 2, 3, 4, 5, 6], + "arg entina": [1, 2, 3, 1, 2, 3, 4], + "ban gkock": [1, -1, 3, 2, 5, 0, None], + "i ran": [0, 1, 2, 3, 4, 5, 6], "g": [1, 1, 1, 1, 1, 1, 1], } ) ) result = df.with_columns( - nw.col("a").cum_sum(reverse=reverse).over("g", _order_by="b") - ).sort("i") + nw.col("arg entina").cum_sum(reverse=reverse).over("g", _order_by="ban gkock") + ).sort("i ran") expected = { - "a": expected_a, - "b": [1, -1, 3, 2, 5, 0, None], - "i": [0, 1, 2, 3, 4, 5, 6], + "arg entina": expected_a, + "ban gkock": [1, -1, 3, 2, 5, 0, None], + "i ran": [0, 1, 2, 3, 4, 5, 6], } assert_equal_data(result, expected) @@ -155,16 +155,16 @@ def test_lazy_cum_sum_ungrouped( df = nw.from_native( constructor( { - "a": [2, 3, 1], - "b": [0, 2, 1], - "i": [1, 2, 0], + "arg entina": [2, 3, 1], + "ban gkock": [0, 2, 1], + "i ran": [1, 2, 0], } ) - ).sort("i") + ).sort("i ran") result = df.with_columns( - nw.col("a").cum_sum(reverse=reverse).over(_order_by="b") - ).sort("i") - expected = {"a": expected_a, "b": [1, 0, 2], "i": [0, 1, 2]} + nw.col("arg entina").cum_sum(reverse=reverse).over(_order_by="ban gkock") + ).sort("i ran") + expected = {"arg entina": expected_a, "ban gkock": [1, 0, 2], "i ran": [0, 1, 2]} assert_equal_data(result, expected) @@ -196,19 +196,19 @@ def test_lazy_cum_sum_ungrouped_ordered_by_nulls( df = nw.from_native( constructor( { - "a": [1, 2, 3, 1, 2, 3, 4], - "b": [1, -1, 3, 2, 5, 0, None], - "i": [0, 1, 2, 3, 4, 5, 6], + "arg entina": [1, 2, 3, 1, 2, 3, 4], + "ban gkock": [1, -1, 3, 2, 5, 0, None], + "i ran": [0, 1, 2, 3, 4, 5, 6], } ) - ).sort("i") + ).sort("i ran") result = df.with_columns( - nw.col("a").cum_sum(reverse=reverse).over(_order_by="b") - ).sort("i") + nw.col("arg entina").cum_sum(reverse=reverse).over(_order_by="ban gkock") + ).sort("i ran") expected = { - "a": expected_a, - "b": [1, -1, 3, 2, 5, 0, None], - "i": [0, 1, 2, 3, 4, 5, 6], + "arg entina": expected_a, + "ban gkock": [1, -1, 3, 2, 5, 0, None], + "i ran": [0, 1, 2, 3, 4, 5, 6], } assert_equal_data(result, expected) @@ -216,7 +216,7 @@ def test_lazy_cum_sum_ungrouped_ordered_by_nulls( def test_cum_sum_series(constructor_eager: ConstructorEager) -> None: df = nw.from_native(constructor_eager(data), eager_only=True) result = df.select( - cum_sum=df["a"].cum_sum(), - reverse_cum_sum=df["a"].cum_sum(reverse=True), + cum_sum=df["arg entina"].cum_sum(), + reverse_cum_sum=df["arg entina"].cum_sum(reverse=True), ) assert_equal_data(result, expected) From f1a50e852a0a0457378fbcf24aee2cb376a48226 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 12:52:01 +0000 Subject: [PATCH 17/20] factor out generate_order_by_sql --- narwhals/_duckdb/expr.py | 29 +++++++++-------------------- narwhals/_duckdb/utils.py | 8 ++++++++ 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index bf08860266..bd2a53d1f2 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -21,6 +21,7 @@ from narwhals._duckdb.expr_name import DuckDBExprNameNamespace from narwhals._duckdb.expr_str import DuckDBExprStringNamespace from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace +from narwhals._duckdb.utils import generate_order_by_sql from narwhals._duckdb.utils import generate_partition_by_sql from narwhals._duckdb.utils import lit from narwhals._duckdb.utils import maybe_evaluate_expr @@ -102,8 +103,10 @@ def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Se msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns." raise NotImplementedError(msg) + template = "{expr} over ()" + def func(df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: - return [SQLExpression(f"{result} over ()") for result in self(df)] + return [SQLExpression(template.format(expr=expr)) for expr in self(df)] return self.__class__( func, @@ -492,11 +495,11 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: ] else: partition_by_sql = generate_partition_by_sql(*partition_by) + template = f"{{expr}} over ({partition_by_sql})" def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [ - SQLExpression(f"{expr} over ({partition_by_sql})") - for expr in self._call(df) + SQLExpression(template.format(expr=expr)) for expr in self._call(df) ] return self.__class__( @@ -537,14 +540,7 @@ def func( partition_by: Sequence[str], order_by: Sequence[str], ) -> duckdb.Expression: - if reverse: - order_by_sql = "order by " + ", ".join( - f'"{x}" desc nulls last' for x in order_by - ) - else: - order_by_sql = "order by " + ", ".join( - f'"{x}" asc nulls first' for x in order_by - ) + order_by_sql = generate_order_by_sql(*order_by, ascending=not reverse) partition_by_sql = generate_partition_by_sql(*partition_by) sql = f"sum ({_input}) over ({partition_by_sql} {order_by_sql} rows between unbounded preceding and current row)" return SQLExpression(sql) @@ -566,15 +562,8 @@ def func( partition_by: Sequence[str], order_by: Sequence[str], ) -> duckdb.Expression: - order_by_sql = "order by " + ", ".join( - f'"{x}" asc nulls first' for x in order_by - ) - if partition_by: - partition_by_sql = "partition by " + ",".join( - f'"{x}"' for x in partition_by - ) - else: - partition_by_sql = "" + order_by_sql = generate_order_by_sql(*order_by, ascending=True) + partition_by_sql = generate_partition_by_sql(*partition_by) window = f"({partition_by_sql} {order_by_sql} rows between {start} and {end})" sql = f"case when count({_input}) over {window} >= {min_samples} then sum({_input}) over {window} else null end" return SQLExpression(sql) diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 14dabf0323..72b080f29f 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -196,3 +196,11 @@ def generate_partition_by_sql(*partition_by: str) -> str: return "" by_sql = ", ".join([f'"{x}"' for x in partition_by]) return f"partition by {by_sql}" + + +def generate_order_by_sql(*order_by: str, ascending: bool) -> str: + if ascending: + by_sql = ", ".join([f'"{x}" asc nulls first' for x in order_by]) + else: + by_sql = ", ".join([f'"{x}" desc nulls last' for x in order_by]) + return f"order by {by_sql}" From 231f1f3f7607e4c753bdc7271d665414fd87dd0d Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 13:00:30 +0000 Subject: [PATCH 18/20] remove outdated sqlframe xfail --- tests/expr_and_series/str/to_datetime_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/expr_and_series/str/to_datetime_test.py b/tests/expr_and_series/str/to_datetime_test.py index fda87a0b97..ef076613a0 100644 --- a/tests/expr_and_series/str/to_datetime_test.py +++ b/tests/expr_and_series/str/to_datetime_test.py @@ -216,9 +216,6 @@ def test_to_datetime_tz_aware( pytest.skip() if is_pyarrow_windows_no_tzdata(constructor): pytest.skip() - if "sqlframe" in str(constructor): - # https://github.com/eakmanrq/sqlframe/issues/325 - request.applymarker(pytest.mark.xfail) if "cudf" in str(constructor): # cuDF does not yet support timezone-aware datetimes request.applymarker(pytest.mark.xfail) From 6fb1b7643810b137d6aa6c0cf34d4c9d119a0a98 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 13:02:49 +0000 Subject: [PATCH 19/20] again --- tests/expr_and_series/str/to_datetime_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/expr_and_series/str/to_datetime_test.py b/tests/expr_and_series/str/to_datetime_test.py index ef076613a0..412485d01f 100644 --- a/tests/expr_and_series/str/to_datetime_test.py +++ b/tests/expr_and_series/str/to_datetime_test.py @@ -221,7 +221,7 @@ def test_to_datetime_tz_aware( request.applymarker(pytest.mark.xfail) context = ( pytest.raises(NotImplementedError) - if any(x in str(constructor) for x in ("duckdb", "sqlframe")) and format is None + if any(x in str(constructor) for x in ("duckdb",)) and format is None else does_not_raise() ) df = nw.from_native(constructor({"a": ["2020-01-01T01:02:03+0100"]})) From de9f3756ccc921280341b5447ed9586e988ce88e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 22 Mar 2025 13:11:46 +0000 Subject: [PATCH 20/20] coverage --- narwhals/dependencies.py | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 1a00e49d9a..e32e544abc 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -237,7 +237,7 @@ def is_sqlframe_dataframe(df: Any) -> TypeIs[SQLFrameDataFrame]: from sqlframe.base.dataframe import BaseDataFrame return isinstance(df, BaseDataFrame) - return False + return False # pragma: no cover def is_numpy_array(arr: Any | _NDArray[_ShapeT]) -> TypeIs[_NDArray[_ShapeT]]: diff --git a/pyproject.toml b/pyproject.toml index 037b5d3b87..bc1721479f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -242,6 +242,7 @@ omit = [ 'narwhals/stable/v1/typing.py', 'narwhals/this.py', 'narwhals/_arrow/typing.py', + 'narwhals/_duckdb/typing.py', 'narwhals/_spark_like/typing.py', # we can't run this in every environment that we measure coverage on due to upper-bound constraits 'narwhals/_ibis/*',