diff --git a/.github/workflows/extremes.yml b/.github/workflows/extremes.yml index 31061e587f..26411e79f2 100644 --- a/.github/workflows/extremes.yml +++ b/.github/workflows/extremes.yml @@ -179,9 +179,10 @@ jobs: run: | uv pip uninstall dask dask-expr --system python -m pip install git+https://github.com/dask/distributed git+https://github.com/dask/dask - - name: install duckdb + - name: install duckdb nightly run: | - python -m pip install -U --pre duckdb + uv pip uninstall duckdb --system + uv pip install -U --pre duckdb --system - name: show-deps run: uv pip freeze - name: Assert nightlies dependencies diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index b369afcbeb..338cc645e1 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -84,6 +84,8 @@ jobs: cache-dependency-glob: "pyproject.toml" - name: install-reqs run: uv pip install -e ".[modin, dask]" --group core-tests --group extra --system + - name: install duckdb nightly + run: uv pip install -U --pre duckdb --system - name: install pyspark run: uv pip install -e ".[pyspark]" --system # PySpark is not yet available on Python3.12+ diff --git a/Makefile b/Makefile index 39f23d2cf4..5a0feb5f93 100644 --- a/Makefile +++ b/Makefile @@ -20,5 +20,7 @@ help: ## Display this help screen .PHONY: typing typing: ## Run typing checks + # install duckdb nightly so mypy recognises duckdb.SQLExpression + $(VENV_BIN)/uv pip install -U --pre duckdb $(VENV_BIN)/uv pip install -e . --group typing $(VENV_BIN)/mypy diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 45504da2e3..d7894c345c 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -12,6 +12,7 @@ from duckdb import FunctionExpression from narwhals._duckdb.utils import evaluate_exprs +from narwhals._duckdb.utils import generate_partition_by_sql from narwhals._duckdb.utils import lit from narwhals._duckdb.utils import native_to_narwhals_dtype from narwhals.dependencies import get_duckdb @@ -363,11 +364,12 @@ def unique( keep_condition = f"where {count_name}=1" else: keep_condition = f"where {idx_name}=1" + partition_by_sql = generate_partition_by_sql(*subset) query = f""" with cte as ( select *, - row_number() over (partition by {",".join(subset)}) as {idx_name}, - count(*) over (partition by {",".join(subset)}) as {count_name} + row_number() over ({partition_by_sql}) as {idx_name}, + count(*) over ({partition_by_sql}) as {count_name} from rel ) select * exclude ({idx_name}, {count_name}) from cte {keep_condition} diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 62bc942652..bd2a53d1f2 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import operator from typing import TYPE_CHECKING from typing import Any @@ -20,6 +21,8 @@ from narwhals._duckdb.expr_name import DuckDBExprNameNamespace from narwhals._duckdb.expr_str import DuckDBExprStringNamespace from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace +from narwhals._duckdb.utils import generate_order_by_sql +from narwhals._duckdb.utils import generate_partition_by_sql from narwhals._duckdb.utils import lit from narwhals._duckdb.utils import maybe_evaluate_expr from narwhals._duckdb.utils import narwhals_to_native_dtype @@ -33,11 +36,15 @@ from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.namespace import DuckDBNamespace + from narwhals._duckdb.typing import WindowFunction from narwhals._expression_parsing import ExprMetadata from narwhals.dtypes import DType from narwhals.utils import Version from narwhals.utils import _FullContext +with contextlib.suppress(ImportError): # requires duckdb>=1.3.0 + from duckdb import SQLExpression + class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]): _implementation = Implementation.DUCKDB @@ -59,6 +66,7 @@ def __init__( self._alias_output_names = alias_output_names self._backend_version = backend_version self._version = version + self._window_function: WindowFunction | None = None self._metadata: ExprMetadata | None = None def __call__(self: Self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: @@ -83,15 +91,31 @@ def _with_metadata(self, metadata: ExprMetadata) -> Self: backend_version=self._backend_version, version=self._version, ) + if func := self._window_function: + expr = expr._with_window_function(func) expr._metadata = metadata return expr def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: - if kind is ExprKind.AGGREGATION: - msg = "Broadcasting aggregations is not yet supported for DuckDB." + if kind is ExprKind.LITERAL: + return self + if self._backend_version < (1, 3): + msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns." raise NotImplementedError(msg) - # For literals, DuckDB does its own broadcasting. - return self + + template = "{expr} over ()" + + def func(df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: + return [SQLExpression(template.format(expr=expr)) for expr in self(df)] + + return self.__class__( + func, + function_name=self._function_name, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + ) @classmethod def from_column_names( @@ -167,6 +191,21 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: version=self._version, ) + def _with_window_function( + self: Self, + window_function: WindowFunction, + ) -> Self: + result = self.__class__( + self._call, + function_name=self._function_name, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + ) + result._window_function = window_function + return result + def __and__(self: Self, other: DuckDBExpr) -> Self: return self._from_call( lambda _input, other: _input & other, @@ -438,6 +477,40 @@ def null_count(self: Self) -> Self: "null_count", ) + def over( + self: Self, + partition_by: Sequence[str], + order_by: Sequence[str] | None, + ) -> Self: + if self._backend_version < (1, 3): + msg = "At least version 1.3 of DuckDB is required for `over` operation." + raise NotImplementedError(msg) + if (window_function := self._window_function) is not None: + assert order_by is not None # noqa: S101 + + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + return [ + window_function(expr, partition_by, order_by) + for expr in self._call(df) + ] + else: + partition_by_sql = generate_partition_by_sql(*partition_by) + template = f"{{expr}} over ({partition_by_sql})" + + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + return [ + SQLExpression(template.format(expr=expr)) for expr in self._call(df) + ] + + return self.__class__( + func, + function_name=self._function_name + "->over", + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + backend_version=self._backend_version, + version=self._version, + ) + def is_null(self: Self) -> Self: return self._from_call(lambda _input: _input.isnull(), "is_null") @@ -461,6 +534,42 @@ def round(self: Self, decimals: int) -> Self: lambda _input: FunctionExpression("round", _input, lit(decimals)), "round" ) + def cum_sum(self, *, reverse: bool) -> Self: + def func( + _input: duckdb.Expression, + partition_by: Sequence[str], + order_by: Sequence[str], + ) -> duckdb.Expression: + order_by_sql = generate_order_by_sql(*order_by, ascending=not reverse) + partition_by_sql = generate_partition_by_sql(*partition_by) + sql = f"sum ({_input}) over ({partition_by_sql} {order_by_sql} rows between unbounded preceding and current row)" + return SQLExpression(sql) + + return self._with_window_function(func) + + def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self: + if center: + half = (window_size - 1) // 2 + remainder = (window_size - 1) % 2 + start = f"{half + remainder} preceding" + end = f"{half} following" + else: + start = f"{window_size - 1} preceding" + end = "current row" + + def func( + _input: duckdb.Expression, + partition_by: Sequence[str], + order_by: Sequence[str], + ) -> duckdb.Expression: + order_by_sql = generate_order_by_sql(*order_by, ascending=True) + partition_by_sql = generate_partition_by_sql(*partition_by) + window = f"({partition_by_sql} {order_by_sql} rows between {start} and {end})" + sql = f"case when count({_input}) over {window} >= {min_samples} then sum({_input}) over {window} else null end" + return SQLExpression(sql) + + return self._with_window_function(func) + def fill_null( self: Self, value: Self | Any, strategy: Any, limit: int | None ) -> Self: @@ -507,10 +616,7 @@ def struct(self: Self) -> DuckDBExprStructNamespace: is_unique = not_implemented() is_first_distinct = not_implemented() is_last_distinct = not_implemented() - cum_sum = not_implemented() cum_count = not_implemented() cum_min = not_implemented() cum_max = not_implemented() cum_prod = not_implemented() - over = not_implemented() - rolling_sum = not_implemented() diff --git a/narwhals/_duckdb/typing.py b/narwhals/_duckdb/typing.py new file mode 100644 index 0000000000..9661351fb8 --- /dev/null +++ b/narwhals/_duckdb/typing.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Protocol +from typing import Sequence + +if TYPE_CHECKING: + import duckdb + + class WindowFunction(Protocol): + def __call__( + self, + _input: duckdb.Expression, + partition_by: Sequence[str], + order_by: Sequence[str], + ) -> duckdb.Expression: ... diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 3bc4d950ae..72b080f29f 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -189,3 +189,18 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st return f"{duckdb_inner}{duckdb_shape_fmt}" msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) + + +def generate_partition_by_sql(*partition_by: str) -> str: + if not partition_by: + return "" + by_sql = ", ".join([f'"{x}"' for x in partition_by]) + return f"partition by {by_sql}" + + +def generate_order_by_sql(*order_by: str, ascending: bool) -> str: + if ascending: + by_sql = ", ".join([f'"{x}" asc nulls first' for x in order_by]) + else: + by_sql = ", ".join([f'"{x}" desc nulls last' for x in order_by]) + return f"order by {by_sql}" diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index e9314cae20..22b8461d2e 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -67,14 +67,13 @@ def __call__(self: Self, df: SparkLikeLazyFrame) -> Sequence[Column]: return self._call(df) def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: + if kind is ExprKind.LITERAL: + return self + def func(df: SparkLikeLazyFrame) -> Sequence[Column]: - if kind is ExprKind.AGGREGATION: - return [ - result.over(df._Window().partitionBy(df._F.lit(1))) - for result in self(df) - ] - # Let PySpark do its own broadcasting for literals. - return self(df) + return [ + result.over(df._Window().partitionBy(df._F.lit(1))) for result in self(df) + ] return self.__class__( func, diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 071df0a155..e32e544abc 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -233,10 +233,11 @@ def is_pyspark_dataframe(df: Any) -> TypeIs[pyspark_sql.DataFrame]: def is_sqlframe_dataframe(df: Any) -> TypeIs[SQLFrameDataFrame]: """Check whether `df` is a SQLFrame DataFrame without importing SQLFrame.""" - return bool( - (sqlframe := get_sqlframe()) is not None - and isinstance(df, sqlframe.base.dataframe.BaseDataFrame) - ) + if get_sqlframe() is not None: + from sqlframe.base.dataframe import BaseDataFrame + + return isinstance(df, BaseDataFrame) + return False # pragma: no cover def is_numpy_array(arr: Any | _NDArray[_ShapeT]) -> TypeIs[_NDArray[_ShapeT]]: diff --git a/pyproject.toml b/pyproject.toml index 037b5d3b87..bc1721479f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -242,6 +242,7 @@ omit = [ 'narwhals/stable/v1/typing.py', 'narwhals/this.py', 'narwhals/_arrow/typing.py', + 'narwhals/_duckdb/typing.py', 'narwhals/_spark_like/typing.py', # we can't run this in every environment that we measure coverage on due to upper-bound constraits 'narwhals/_ibis/*', diff --git a/tests/expr_and_series/cum_sum_test.py b/tests/expr_and_series/cum_sum_test.py index 884921e7b4..ceb674af42 100644 --- a/tests/expr_and_series/cum_sum_test.py +++ b/tests/expr_and_series/cum_sum_test.py @@ -3,12 +3,13 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import DUCKDB_VERSION from tests.utils import POLARS_VERSION from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data -data = {"a": [1, 2, None, 4]} +data = {"arg entina": [1, 2, None, 4]} expected = { "cum_sum": [1, 3, None, 7], "reverse_cum_sum": [7, 6, None, 4], @@ -20,7 +21,7 @@ def test_cum_sum_expr(constructor_eager: ConstructorEager, *, reverse: bool) -> name = "reverse_cum_sum" if reverse else "cum_sum" df = nw.from_native(constructor_eager(data)) result = df.select( - nw.col("a").cum_sum(reverse=reverse).alias(name), + nw.col("arg entina").cum_sum(reverse=reverse).alias(name), ) assert_equal_data(result, {name: expected[name]}) @@ -40,9 +41,6 @@ def test_lazy_cum_sum_grouped( reverse: bool, expected_a: list[int], ) -> None: - if "duckdb" in str(constructor): - # no window function support yet in duckdb - request.applymarker(pytest.mark.xfail) if "pyarrow_table" in str(constructor): # grouped window functions not yet supported request.applymarker(pytest.mark.xfail) @@ -52,7 +50,9 @@ def test_lazy_cum_sum_grouped( if "dask" in str(constructor): # https://github.com/dask/dask/issues/11806 request.applymarker(pytest.mark.xfail) - if "polars" in str(constructor) and POLARS_VERSION < (1, 9): + if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or ( + "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3) + ): pytest.skip(reason="too old version") if "cudf" in str(constructor): # https://github.com/rapidsai/cudf/issues/18159 @@ -61,17 +61,17 @@ def test_lazy_cum_sum_grouped( df = nw.from_native( constructor( { - "a": [1, 2, 3], - "b": [1, 0, 2], - "i": [0, 1, 2], + "arg entina": [1, 2, 3], + "ban gkock": [1, 0, 2], + "i ran": [0, 1, 2], "g": [1, 1, 1], } ) ) result = df.with_columns( - nw.col("a").cum_sum(reverse=reverse).over("g", _order_by="b") - ).sort("i") - expected = {"a": expected_a, "b": [1, 0, 2], "i": [0, 1, 2]} + nw.col("arg entina").cum_sum(reverse=reverse).over("g", _order_by="ban gkock") + ).sort("i ran") + expected = {"arg entina": expected_a, "ban gkock": [1, 0, 2], "i ran": [0, 1, 2]} assert_equal_data(result, expected) @@ -89,9 +89,6 @@ def test_lazy_cum_sum_ordered_by_nulls( reverse: bool, expected_a: list[int], ) -> None: - if "duckdb" in str(constructor): - # no window function support yet in duckdb - request.applymarker(pytest.mark.xfail) if "pyarrow_table" in str(constructor): # grouped window functions not yet supported request.applymarker(pytest.mark.xfail) @@ -101,7 +98,9 @@ def test_lazy_cum_sum_ordered_by_nulls( if "dask" in str(constructor): # https://github.com/dask/dask/issues/11806 request.applymarker(pytest.mark.xfail) - if "polars" in str(constructor) and POLARS_VERSION < (1, 9): + if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or ( + "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3) + ): pytest.skip(reason="too old version") if "cudf" in str(constructor): # https://github.com/rapidsai/cudf/issues/18159 @@ -110,20 +109,20 @@ def test_lazy_cum_sum_ordered_by_nulls( df = nw.from_native( constructor( { - "a": [1, 2, 3, 1, 2, 3, 4], - "b": [1, -1, 3, 2, 5, 0, None], - "i": [0, 1, 2, 3, 4, 5, 6], + "arg entina": [1, 2, 3, 1, 2, 3, 4], + "ban gkock": [1, -1, 3, 2, 5, 0, None], + "i ran": [0, 1, 2, 3, 4, 5, 6], "g": [1, 1, 1, 1, 1, 1, 1], } ) ) result = df.with_columns( - nw.col("a").cum_sum(reverse=reverse).over("g", _order_by="b") - ).sort("i") + nw.col("arg entina").cum_sum(reverse=reverse).over("g", _order_by="ban gkock") + ).sort("i ran") expected = { - "a": expected_a, - "b": [1, -1, 3, 2, 5, 0, None], - "i": [0, 1, 2, 3, 4, 5, 6], + "arg entina": expected_a, + "ban gkock": [1, -1, 3, 2, 5, 0, None], + "i ran": [0, 1, 2, 3, 4, 5, 6], } assert_equal_data(result, expected) @@ -142,31 +141,30 @@ def test_lazy_cum_sum_ungrouped( reverse: bool, expected_a: list[int], ) -> None: - if "duckdb" in str(constructor): - # no window function support yet in duckdb - request.applymarker(pytest.mark.xfail) if "dask" in str(constructor) and reverse: # https://github.com/dask/dask/issues/11802 request.applymarker(pytest.mark.xfail) if "modin" in str(constructor): # probably bugged request.applymarker(pytest.mark.xfail) - if "polars" in str(constructor) and POLARS_VERSION < (1, 9): + if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or ( + "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3) + ): pytest.skip(reason="too old version") df = nw.from_native( constructor( { - "a": [2, 3, 1], - "b": [0, 2, 1], - "i": [1, 2, 0], + "arg entina": [2, 3, 1], + "ban gkock": [0, 2, 1], + "i ran": [1, 2, 0], } ) - ).sort("i") + ).sort("i ran") result = df.with_columns( - nw.col("a").cum_sum(reverse=reverse).over(_order_by="b") - ).sort("i") - expected = {"a": expected_a, "b": [1, 0, 2], "i": [0, 1, 2]} + nw.col("arg entina").cum_sum(reverse=reverse).over(_order_by="ban gkock") + ).sort("i ran") + expected = {"arg entina": expected_a, "ban gkock": [1, 0, 2], "i ran": [0, 1, 2]} assert_equal_data(result, expected) @@ -184,34 +182,33 @@ def test_lazy_cum_sum_ungrouped_ordered_by_nulls( reverse: bool, expected_a: list[int], ) -> None: - if "duckdb" in str(constructor): - # no window function support yet in duckdb - request.applymarker(pytest.mark.xfail) if "dask" in str(constructor): # https://github.com/dask/dask/issues/11806 request.applymarker(pytest.mark.xfail) if "modin" in str(constructor): # probably bugged request.applymarker(pytest.mark.xfail) - if "polars" in str(constructor) and POLARS_VERSION < (1, 9): + if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or ( + "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3) + ): pytest.skip(reason="too old version") df = nw.from_native( constructor( { - "a": [1, 2, 3, 1, 2, 3, 4], - "b": [1, -1, 3, 2, 5, 0, None], - "i": [0, 1, 2, 3, 4, 5, 6], + "arg entina": [1, 2, 3, 1, 2, 3, 4], + "ban gkock": [1, -1, 3, 2, 5, 0, None], + "i ran": [0, 1, 2, 3, 4, 5, 6], } ) - ).sort("i") + ).sort("i ran") result = df.with_columns( - nw.col("a").cum_sum(reverse=reverse).over(_order_by="b") - ).sort("i") + nw.col("arg entina").cum_sum(reverse=reverse).over(_order_by="ban gkock") + ).sort("i ran") expected = { - "a": expected_a, - "b": [1, -1, 3, 2, 5, 0, None], - "i": [0, 1, 2, 3, 4, 5, 6], + "arg entina": expected_a, + "ban gkock": [1, -1, 3, 2, 5, 0, None], + "i ran": [0, 1, 2, 3, 4, 5, 6], } assert_equal_data(result, expected) @@ -219,7 +216,7 @@ def test_lazy_cum_sum_ungrouped_ordered_by_nulls( def test_cum_sum_series(constructor_eager: ConstructorEager) -> None: df = nw.from_native(constructor_eager(data), eager_only=True) result = df.select( - cum_sum=df["a"].cum_sum(), - reverse_cum_sum=df["a"].cum_sum(reverse=True), + cum_sum=df["arg entina"].cum_sum(), + reverse_cum_sum=df["arg entina"].cum_sum(reverse=True), ) assert_equal_data(result, expected) diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py index 38bfdd2cc7..61c53d28a1 100644 --- a/tests/expr_and_series/over_test.py +++ b/tests/expr_and_series/over_test.py @@ -9,6 +9,7 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import LengthChangingExprError +from tests.utils import DUCKDB_VERSION from tests.utils import PANDAS_VERSION from tests.utils import POLARS_VERSION from tests.utils import Constructor @@ -30,9 +31,9 @@ } -def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_over_single(constructor: Constructor) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() df = nw.from_native(constructor(data)) expected = { @@ -85,10 +86,9 @@ def test_over_std_var(request: pytest.FixtureRequest, constructor: Constructor) assert_equal_data(result, expected) -def test_over_multiple(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_over_multiple(constructor: Constructor) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() df = nw.from_native(constructor(data)) expected = { "a": ["a", "a", "b", "b", "b"], @@ -275,9 +275,8 @@ def test_over_anonymous_cumulative( def test_over_anonymous_reduction( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): - # TODO(unassigned): we should be able to support these - request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() if "modin" in str(constructor): # probably bugged request.applymarker(pytest.mark.xfail) @@ -421,7 +420,7 @@ def test_over_without_partition_by( ) -> None: if "polars" in str(constructor) and POLARS_VERSION < (1, 10): pytest.skip() - if "duckdb" in str(constructor): + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): # windows not yet supported request.applymarker(pytest.mark.xfail) if "modin" in str(constructor): diff --git a/tests/expr_and_series/reduction_test.py b/tests/expr_and_series/reduction_test.py index b1d84e85e4..47d55e11f8 100644 --- a/tests/expr_and_series/reduction_test.py +++ b/tests/expr_and_series/reduction_test.py @@ -5,6 +5,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import DUCKDB_VERSION from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data @@ -31,11 +32,9 @@ def test_scalar_reduction_select( constructor: Constructor, expr: list[Any], expected: dict[str, list[Any]], - request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor) and request.node.callspec.id not in {"duckdb-0"}: - request.applymarker(pytest.mark.xfail) - + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = nw.from_native(constructor(data)) result = df.select(*expr) @@ -63,10 +62,9 @@ def test_scalar_reduction_with_columns( constructor: Constructor, expr: list[Any], expected: dict[str, list[Any]], - request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = nw.from_native(constructor(data)) result = df.with_columns(*expr).select(*expected.keys()) @@ -107,11 +105,9 @@ def test_empty_scalar_reduction_select( assert_equal_data(result, expected) -def test_empty_scalar_reduction_with_columns( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_empty_scalar_reduction_with_columns(constructor: Constructor) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() from itertools import chain data = { diff --git a/tests/expr_and_series/rolling_sum_test.py b/tests/expr_and_series/rolling_sum_test.py index c4e30d571c..da7ea09cf8 100644 --- a/tests/expr_and_series/rolling_sum_test.py +++ b/tests/expr_and_series/rolling_sum_test.py @@ -11,6 +11,7 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import InvalidOperationError +from tests.utils import DUCKDB_VERSION from tests.utils import PANDAS_VERSION from tests.utils import POLARS_VERSION from tests.utils import Constructor @@ -76,14 +77,13 @@ def test_rolling_sum_expr_lazy_ungrouped( expected_a: list[float], window_size: int, min_samples: int, - request: pytest.FixtureRequest, *, center: bool, ) -> None: - if "polars" in str(constructor) and POLARS_VERSION < (1, 10): + if ("polars" in str(constructor) and POLARS_VERSION < (1, 10)) or ( + "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3) + ): pytest.skip() - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) if "modin" in str(constructor): # unreliable pytest.skip() @@ -130,11 +130,13 @@ def test_rolling_sum_expr_lazy_grouped( *, center: bool, ) -> None: - if "polars" in str(constructor) and POLARS_VERSION < (1, 10): + if ("polars" in str(constructor) and POLARS_VERSION < (1, 10)) or ( + "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3) + ): pytest.skip() if "pandas" in str(constructor) and PANDAS_VERSION < (1, 2): pytest.skip() - if any(x in str(constructor) for x in ("dask", "pyarrow_table", "duckdb")): + if any(x in str(constructor) for x in ("dask", "pyarrow_table")): request.applymarker(pytest.mark.xfail) if "cudf" in str(constructor) and center: # center is not implemented for offset-based windows diff --git a/tests/expr_and_series/str/to_datetime_test.py b/tests/expr_and_series/str/to_datetime_test.py index fda87a0b97..412485d01f 100644 --- a/tests/expr_and_series/str/to_datetime_test.py +++ b/tests/expr_and_series/str/to_datetime_test.py @@ -216,15 +216,12 @@ def test_to_datetime_tz_aware( pytest.skip() if is_pyarrow_windows_no_tzdata(constructor): pytest.skip() - if "sqlframe" in str(constructor): - # https://github.com/eakmanrq/sqlframe/issues/325 - request.applymarker(pytest.mark.xfail) if "cudf" in str(constructor): # cuDF does not yet support timezone-aware datetimes request.applymarker(pytest.mark.xfail) context = ( pytest.raises(NotImplementedError) - if any(x in str(constructor) for x in ("duckdb", "sqlframe")) and format is None + if any(x in str(constructor) for x in ("duckdb",)) and format is None else does_not_raise() ) df = nw.from_native(constructor({"a": ["2020-01-01T01:02:03+0100"]})) diff --git a/tests/expr_and_series/sum_horizontal_test.py b/tests/expr_and_series/sum_horizontal_test.py index aaef64356f..3da18660b6 100644 --- a/tests/expr_and_series/sum_horizontal_test.py +++ b/tests/expr_and_series/sum_horizontal_test.py @@ -5,6 +5,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import DUCKDB_VERSION from tests.utils import Constructor from tests.utils import assert_equal_data @@ -57,12 +58,9 @@ def test_sumh_aggregations(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_sumh_transformations( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - # We don't yet support broadcasting for DuckDB. - request.applymarker(pytest.mark.xfail) +def test_sumh_transformations(constructor: Constructor) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]} df = nw.from_native(constructor(data)) result = df.select(d=nw.sum_horizontal("a", nw.col("b").sum(), "c")) diff --git a/tests/frame/select_test.py b/tests/frame/select_test.py index 32ef82aef3..6b29d97d3b 100644 --- a/tests/frame/select_test.py +++ b/tests/frame/select_test.py @@ -11,6 +11,7 @@ from narwhals.exceptions import InvalidIntoExprError from narwhals.exceptions import NarwhalsError from tests.utils import DASK_VERSION +from tests.utils import DUCKDB_VERSION from tests.utils import PANDAS_VERSION from tests.utils import POLARS_VERSION from tests.utils import Constructor @@ -119,10 +120,10 @@ def test_missing_columns( def test_left_to_right_broadcasting( constructor: Constructor, request: pytest.FixtureRequest ) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() if "dask" in str(constructor) and DASK_VERSION < (2024, 10): request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 1, 2], "b": [4, 5, 6]})) result = df.select(nw.col("a") + nw.col("b").sum()) expected = {"a": [16, 16, 17]} diff --git a/tests/stable_api_test.py b/tests/stable_api_test.py index 24f8006c49..4d8440a260 100644 --- a/tests/stable_api_test.py +++ b/tests/stable_api_test.py @@ -8,6 +8,7 @@ import narwhals as nw import narwhals.stable.v1 as nw_v1 +from tests.utils import DUCKDB_VERSION from tests.utils import Constructor from tests.utils import assert_equal_data @@ -18,11 +19,9 @@ def remove_docstring_examples(doc: str) -> str: return doc.rstrip() -def test_renamed_taxicab_norm( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_renamed_taxicab_norm(constructor: Constructor) -> None: + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3): + pytest.skip() # Suppose we need to rename `_l1_norm` to `_taxicab_norm`. # We need `narwhals.stable.v1` to stay stable. So, we # make the change in `narwhals`, and then add the new method