Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions narwhals/_ibis/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,12 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) ->
if isinstance(descending, bool):
descending = [descending for _ in range(len(by))]

sort_cols = []
sort_cols: list[Any] = []

for i in range(len(by)):
direction_fn = ibis.desc if descending[i] else ibis.asc
col = direction_fn(by[i], nulls_first=not nulls_last)
sort_cols.append(cast("ir.Column", col))
sort_cols.append(col)

return self._with_native(self.native.order_by(*sort_cols))

Expand Down
22 changes: 10 additions & 12 deletions narwhals/_ibis/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _std(expr: ir.NumericColumn, ddof: int) -> ir.Value:
return expr.std(how="sample")
n_samples = expr.count()
std_pop = expr.std(how="pop")
ddof_lit = cast("ir.IntegerScalar", ibis.literal(ddof))
ddof_lit = lit(ddof)
return std_pop * n_samples.sqrt() / (n_samples - ddof_lit).sqrt()

return self._with_callable(lambda expr: _std(expr, ddof))
Expand All @@ -240,7 +240,7 @@ def _var(expr: ir.NumericColumn, ddof: int) -> ir.Value:
return expr.var(how="sample")
n_samples = expr.count()
var_pop = expr.var(how="pop")
ddof_lit = cast("ir.IntegerScalar", ibis.literal(ddof))
ddof_lit = lit(ddof)
return var_pop * n_samples / (n_samples - ddof_lit)

return self._with_callable(lambda expr: _var(expr, ddof))
Expand Down Expand Up @@ -290,35 +290,33 @@ def is_unique(self) -> Self:
)

def rank(self, method: RankMethod, *, descending: bool) -> Self:
def _rank(expr: ir.Column) -> ir.Column:
def _rank(expr: ir.Column) -> ir.Value:
order_by = next(self._sort(expr, descending=[descending], nulls_last=[True]))
window = ibis.window(order_by=order_by)

if method == "dense":
rank_ = order_by.dense_rank()
elif method == "ordinal":
rank_ = cast("ir.IntegerColumn", ibis.row_number().over(window))
rank_ = ibis.row_number().over(window)
else:
rank_ = order_by.rank()

# Ibis uses 0-based ranking. Add 1 to match polars 1-based rank.
rank_ = rank_ + cast("ir.IntegerValue", lit(1))
rank_ = rank_ + lit(1)

# For "max" and "average", adjust using the count of rows in the partition.
if method == "max":
# Define a window partitioned by expr (i.e. each distinct value)
partition = ibis.window(group_by=[expr])
cnt = cast("ir.IntegerValue", expr.count().over(partition))
rank_ = rank_ + cnt - cast("ir.IntegerValue", lit(1))
cnt = expr.count().over(partition)
rank_ = rank_ + cnt - lit(1)
elif method == "average":
partition = ibis.window(group_by=[expr])
cnt = cast("ir.IntegerValue", expr.count().over(partition))
avg = cast(
"ir.NumericValue", (cnt - cast("ir.IntegerScalar", lit(1))) / lit(2.0)
)
cnt = expr.count().over(partition)
avg = cast("ir.NumericValue", (cnt - lit(1)) / lit(2.0))
rank_ = rank_ + avg

return cast("ir.Column", ibis.cases((expr.notnull(), rank_)))
return ibis.cases((expr.notnull(), rank_))

return self._with_callable(_rank)

Expand Down
5 changes: 2 additions & 3 deletions narwhals/_ibis/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

import ibis
import ibis.expr.types as ir
Expand Down Expand Up @@ -88,8 +88,7 @@ def func(df: IbisLazyFrame) -> list[ir.Value]:
for col in cols_casted[1:]:
result = result + separator + col
else:
sep = cast("ir.StringValue", lit(separator))
result = sep.join(cols_casted)
result = lit(separator).join(cols_casted)

return [result]

Expand Down
31 changes: 25 additions & 6 deletions narwhals/_ibis/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import ibis
import ibis.expr.datatypes as ibis_dtypes
Expand All @@ -23,8 +23,27 @@
from narwhals.dtypes import DType
from narwhals.typing import IntoDType, PythonLiteral

lit = ibis.literal
"""Alias for `ibis.literal`."""
Incomplete: TypeAlias = Any
"""Marker for upstream issues."""


# Typed overloads for `lit` when no `dtype` is supplied: the declared return
# type follows the Python type of `value`.
# NOTE: the `bool` overload is listed before `int` because `bool` is a subclass
# of `int`; with the order reversed, `lit(True)` would match the `int` overload.
@overload
def lit(value: bool, dtype: None = ...) -> ir.BooleanScalar: ... # noqa: FBT001
@overload
def lit(value: int, dtype: None = ...) -> ir.IntegerScalar: ...
@overload
def lit(value: float, dtype: None = ...) -> ir.FloatingScalar: ...
@overload
def lit(value: str, dtype: None = ...) -> ir.StringScalar: ...
# Catch-all for values that are only known to be `PythonLiteral | ir.Value`;
# the result is typed no more precisely than `ir.Scalar`.
@overload
def lit(value: PythonLiteral | ir.Value, dtype: None = ...) -> ir.Scalar: ...
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've only included `ir.Value` here since our typing has a lot of `PythonLiteral | ir.Value` that doesn't get narrowed any further than that.

# Fallback overload: `dtype` is required here (no default), so it applies to
# calls that pass an explicit `dtype`; the result type is left as `Incomplete`.
@overload
def lit(value: Any, dtype: Any) -> Incomplete: ...
# Runtime implementation behind the `lit` overloads: forwards `value` and
# `dtype` positionally to `ibis.literal`.
def lit(value: Any, dtype: Any | None = None) -> Incomplete:
"""Alias for `ibis.literal`."""
# `Incomplete` is an alias of `Any`, so rebinding `ibis.literal` to an
# `Incomplete`-typed local means the call below is not checked against ibis's
# own annotations — presumably to sidestep the upstream typing issues that
# `Incomplete` marks.
literal: Incomplete = ibis.literal
return literal(value, dtype)


BucketUnit: TypeAlias = Literal[
"years",
Expand Down Expand Up @@ -231,11 +250,11 @@ def timedelta_to_ibis_interval(td: timedelta) -> ibis.expr.types.temporal.Interv
def function(name: str, *args: ir.Value | PythonLiteral) -> ir.Value:
# Work around SQL vs Ibis differences.
if name == "row_number":
return ibis.row_number() + 1 # pyright: ignore[reportOperatorIssue]
return ibis.row_number() + lit(1)
if name == "least":
return ibis.least(*args) # pyright: ignore[reportOperatorIssue]
return ibis.least(*args)
if name == "greatest":
return ibis.greatest(*args) # pyright: ignore[reportOperatorIssue]
return ibis.greatest(*args)
expr = args[0]
if name == "var_pop":
return cast("ir.NumericColumn", expr).var(how="pop")
Expand Down
Loading