Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/extremes.yml
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ jobs:
- name: install dask
run: |
uv pip uninstall dask dask-expr --system
python -m pip install git+https://github.com/dask/distributed git+https://github.com/dask/dask
uv pip install "git+https://github.com/dask/dask[dataframe]" --system
- name: install duckdb nightly
run: |
uv pip uninstall duckdb --system
Expand Down
21 changes: 11 additions & 10 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import (
arange,
concat_tables,
narwhals_to_native_dtype,
native_to_narwhals_dtype,
Expand Down Expand Up @@ -531,19 +532,19 @@ def to_dict(
return {ser.name: ser.to_list() for ser in it}

def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
import numpy as np # ignore-banned-import

plx = self.__narwhals_namespace__()
data = pa.array(np.arange(len(self)))
row_index_s = plx._series.from_iterable(data, context=self, name=name)
row_index = plx._expr._from_series(row_index_s)
if order_by:
data = arange(0, len(self), 1)
if order_by is None:
row_index = plx._expr._from_series(
self.select(row_index, *(plx.col(x) for x in order_by))
.sort(*order_by, descending=False, nulls_last=False)
.get_column(name)
plx._series.from_iterable(data, context=self, name=name)
)
return self.select(row_index, plx.all())
return self.select(row_index, plx.all())
indices = pc.sort_indices(self.native, [(by, "ascending") for by in order_by])
if self._backend_version < (20,):
new_col = data.take(pc.sort_indices(indices))
else:
new_col = pc.scatter(data, indices.cast(pa.int64())) # type: ignore[attr-defined]
return self._with_native(self.native.add_column(0, name, new_col))

def filter(self, predicate: ArrowExpr) -> Self:
mask_native = self._evaluate_single_output_expr(predicate).native
Expand Down
9 changes: 9 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,3 +485,12 @@ def concat_tables(


class ArrowSeriesNamespace(EagerSeriesNamespace["ArrowSeries", "ChunkedArrayAny"]): ...


def arange(start: int, end: int, step: int) -> ArrayAny:
    """Return a pyarrow array holding the integer range described by the arguments.

    Arguments:
        start: First value of the range.
        end: Stop value, forwarded unchanged to the underlying `arange` call.
        step: Spacing between consecutive values.

    Returns:
        A native pyarrow array built by `pa.arange` when the installed pyarrow
        provides it, otherwise converted from `numpy.arange`.
    """
    if BACKEND_VERSION >= (21,):
        # NOTE: Added in https://github.com/apache/arrow/pull/46778
        return pa.arange(start, end, step)  # type: ignore[attr-defined]

    # Older pyarrow has no native `arange`, so build the range with numpy.
    import numpy as np  # ignore-banned-import

    return pa.array(np.arange(start, end, step))
24 changes: 15 additions & 9 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from narwhals.dependencies import is_pandas_like_dataframe
from narwhals.exceptions import InvalidOperationError, ShapeError
from narwhals.functions import col as nw_col

if TYPE_CHECKING:
from io import BytesIO
Expand Down Expand Up @@ -465,17 +466,22 @@ def estimated_size(self, unit: SizeUnit) -> int | float:

def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
plx = self.__narwhals_namespace__()
data = self._array_funcs.arange(len(self))
row_index_s = plx._series.from_iterable(
data, context=self, index=self.native.index, name=name
)
row_index = plx._expr._from_series(row_index_s)
if order_by:
if order_by is None:
data = self._array_funcs.arange(len(self))
row_index = plx._expr._from_series(
self.select(row_index, *(plx.col(x) for x in order_by))
.sort(*order_by, descending=False, nulls_last=False)
.get_column(name)
plx._series.from_iterable(
data, context=self, index=self.native.index, name=name
)
)
else:
rank = cast(
"PandasLikeExpr",
nw_col(order_by[0]).rank(method="ordinal")._to_compliant_expr(plx),
)
row_index = (
rank.over(partition_by=[], order_by=order_by)
- plx.lit(1, None).broadcast()
).alias(name)
return self.select(row_index, plx.all())

def row(self, index: int) -> tuple[Any, ...]:
Expand Down
3 changes: 1 addition & 2 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,8 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
if order_by is None:
result = frame.with_row_index(name)
else:
end = pl.count() if self._backend_version < (0, 20, 5) else pl.len()
result = frame.select(
pl.int_range(start=0, end=end).sort_by(order_by).alias(name), pl.all()
pl.int_range(pl.len()).over(order_by=order_by).alias(name), pl.all()
)

return self._with_native(result)
Expand Down
30 changes: 30 additions & 0 deletions tests/frame/with_row_index_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tests.utils import (
DUCKDB_VERSION,
PANDAS_VERSION,
POLARS_VERSION,
Constructor,
ConstructorEager,
assert_equal_data,
Expand Down Expand Up @@ -44,6 +45,8 @@ def test_with_row_index_lazy(
pytest.skip(reason=reason)
if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3):
pytest.skip()
if "polars" in str(constructor) and POLARS_VERSION < (1, 10):
pytest.skip()

result = (
nw.from_native(constructor(data))
Expand All @@ -63,3 +66,30 @@ def test_with_row_index_lazy_exception(constructor: Constructor) -> None:
else:
result = frame.with_row_index()
assert_equal_data(result, {"index": [0, 1], **data})


@pytest.mark.parametrize(
    ("order_by", "expected_index"),
    [
        (["a"], [0, 2, 1]),
        (["c"], [2, 0, 1]),
        (["a", "c"], [1, 2, 0]),
        (["c", "a"], [2, 0, 1]),
    ],
)
def test_with_row_index_lazy_meaner_examples(
    constructor: Constructor, order_by: list[str], expected_index: list[int]
) -> None:
    """Check `with_row_index(order_by=...)` on data that is not already sorted.

    The frame is re-sorted by column "b" (its original row order) before the
    comparison, so `expected_index` lists the index assigned to each original row.
    """
    # https://github.com/narwhals-dev/narwhals/issues/3289
    backend = str(constructor)
    if "polars" in backend and POLARS_VERSION < (1, 10):
        pytest.skip()
    if "pandas" in backend and PANDAS_VERSION < (1, 3):  # pragma: no cover
        pytest.skip(reason="ValueError: first not supported for non-numeric data.")
    if "duckdb" in backend and DUCKDB_VERSION < (1, 3):
        pytest.skip()

    data = {"a": ["A", "B", "A"], "b": [1, 2, 3], "c": [9, 2, 4]}
    frame = nw.from_native(constructor(data))
    indexed = frame.with_row_index(name="index", order_by=order_by)
    assert_equal_data(indexed.sort("b"), {"index": expected_index, **data})
Loading