diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index 5fa857894f..5b1419a5af 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -337,12 +337,12 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> if isinstance(descending, bool): descending = [descending for _ in range(len(by))] - sort_cols = [] + sort_cols: list[Any] = [] for i in range(len(by)): direction_fn = ibis.desc if descending[i] else ibis.asc col = direction_fn(by[i], nulls_first=not nulls_last) - sort_cols.append(cast("ir.Column", col)) + sort_cols.append(col) return self._with_native(self.native.order_by(*sort_cols)) diff --git a/narwhals/_ibis/expr.py b/narwhals/_ibis/expr.py index 13c3989d8c..df98a3c8c3 100644 --- a/narwhals/_ibis/expr.py +++ b/narwhals/_ibis/expr.py @@ -227,7 +227,7 @@ def _std(expr: ir.NumericColumn, ddof: int) -> ir.Value: return expr.std(how="sample") n_samples = expr.count() std_pop = expr.std(how="pop") - ddof_lit = cast("ir.IntegerScalar", ibis.literal(ddof)) + ddof_lit = lit(ddof) return std_pop * n_samples.sqrt() / (n_samples - ddof_lit).sqrt() return self._with_callable(lambda expr: _std(expr, ddof)) @@ -240,7 +240,7 @@ def _var(expr: ir.NumericColumn, ddof: int) -> ir.Value: return expr.var(how="sample") n_samples = expr.count() var_pop = expr.var(how="pop") - ddof_lit = cast("ir.IntegerScalar", ibis.literal(ddof)) + ddof_lit = lit(ddof) return var_pop * n_samples / (n_samples - ddof_lit) return self._with_callable(lambda expr: _var(expr, ddof)) @@ -290,35 +290,33 @@ def is_unique(self) -> Self: ) def rank(self, method: RankMethod, *, descending: bool) -> Self: - def _rank(expr: ir.Column) -> ir.Column: + def _rank(expr: ir.Column) -> ir.Value: order_by = next(self._sort(expr, descending=[descending], nulls_last=[True])) window = ibis.window(order_by=order_by) if method == "dense": rank_ = order_by.dense_rank() elif method == "ordinal": - rank_ = cast("ir.IntegerColumn", ibis.row_number().over(window)) + rank_ = ibis.row_number().over(window) else: rank_ = order_by.rank() # Ibis uses 0-based ranking. Add 1 to match polars 1-based rank. - rank_ = rank_ + cast("ir.IntegerValue", lit(1)) + rank_ = rank_ + lit(1) # For "max" and "average", adjust using the count of rows in the partition. if method == "max": # Define a window partitioned by expr (i.e. each distinct value) partition = ibis.window(group_by=[expr]) - cnt = cast("ir.IntegerValue", expr.count().over(partition)) - rank_ = rank_ + cnt - cast("ir.IntegerValue", lit(1)) + cnt = expr.count().over(partition) + rank_ = rank_ + cnt - lit(1) elif method == "average": partition = ibis.window(group_by=[expr]) - cnt = cast("ir.IntegerValue", expr.count().over(partition)) - avg = cast( - "ir.NumericValue", (cnt - cast("ir.IntegerScalar", lit(1))) / lit(2.0) - ) + cnt = expr.count().over(partition) + avg = cast("ir.NumericValue", (cnt - lit(1)) / lit(2.0)) rank_ = rank_ + avg - return cast("ir.Column", ibis.cases((expr.notnull(), rank_))) + return ibis.cases((expr.notnull(), rank_)) return self._with_callable(_rank) diff --git a/narwhals/_ibis/namespace.py b/narwhals/_ibis/namespace.py index 9123be46e8..acf05aa6d0 100644 --- a/narwhals/_ibis/namespace.py +++ b/narwhals/_ibis/namespace.py @@ -3,7 +3,7 @@ import operator from functools import reduce from itertools import chain -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import ibis import ibis.expr.types as ir @@ -88,8 +88,7 @@ def func(df: IbisLazyFrame) -> list[ir.Value]: for col in cols_casted[1:]: result = result + separator + col else: - sep = cast("ir.StringValue", lit(separator)) - result = sep.join(cols_casted) + result = lit(separator).join(cols_casted) return [result] diff --git a/narwhals/_ibis/utils.py b/narwhals/_ibis/utils.py index a76d668998..a8b5bfa833 100644 --- a/narwhals/_ibis/utils.py +++ b/narwhals/_ibis/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import lru_cache -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast, overload import ibis import ibis.expr.datatypes as ibis_dtypes @@ -23,8 +23,27 @@ from narwhals.dtypes import DType from narwhals.typing import IntoDType, PythonLiteral -lit = ibis.literal -"""Alias for `ibis.literal`.""" +Incomplete: TypeAlias = Any +"""Marker for upstream issues.""" + + +@overload +def lit(value: bool, dtype: None = ...) -> ir.BooleanScalar: ... # noqa: FBT001 +@overload +def lit(value: int, dtype: None = ...) -> ir.IntegerScalar: ... +@overload +def lit(value: float, dtype: None = ...) -> ir.FloatingScalar: ... +@overload +def lit(value: str, dtype: None = ...) -> ir.StringScalar: ... +@overload +def lit(value: PythonLiteral | ir.Value, dtype: None = ...) -> ir.Scalar: ... +@overload +def lit(value: Any, dtype: Any) -> Incomplete: ... +def lit(value: Any, dtype: Any | None = None) -> Incomplete: + """Alias for `ibis.literal`.""" + literal: Incomplete = ibis.literal + return literal(value, dtype) + BucketUnit: TypeAlias = Literal[ "years", @@ -231,11 +250,11 @@ def timedelta_to_ibis_interval(td: timedelta) -> ibis.expr.types.temporal.Interv def function(name: str, *args: ir.Value | PythonLiteral) -> ir.Value: # Workaround SQL vs Ibis differences. if name == "row_number": - return ibis.row_number() + 1 # pyright: ignore[reportOperatorIssue] + return ibis.row_number() + lit(1) if name == "least": - return ibis.least(*args) # pyright: ignore[reportOperatorIssue] + return ibis.least(*args) if name == "greatest": - return ibis.greatest(*args) # pyright: ignore[reportOperatorIssue] + return ibis.greatest(*args) expr = args[0] if name == "var_pop": return cast("ir.NumericColumn", expr).var(how="pop")