Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 5 additions & 31 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,37 +451,11 @@ def _gather_slice(self, rows: _SliceIndex | range) -> Self:
raise NotImplementedError(msg)
return self._with_native(self.native.slice(start, stop - start))

def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
import numpy as np # ignore-banned-import

values_native: ArrayAny
if isinstance(indices, int):
indices_native = pa.array([indices])
values_native = pa.array([values])
else:
# TODO(unassigned): we may also want to let `indices` be a Series.
# https://github.com/narwhals-dev/narwhals/issues/2155
indices_native = pa.array(indices)
if isinstance(values, self.__class__):
values_native = values.native.combine_chunks()
else:
# NOTE: Requires fixes in https://github.com/zen-xu/pyarrow-stubs/pull/209
pa_array: Incomplete = pa.array
values_native = pa_array(values)

sorting_indices = pc.sort_indices(indices_native)
indices_native = indices_native.take(sorting_indices)
values_native = values_native.take(sorting_indices)

mask: _1DArray = np.zeros(self.len(), dtype=bool)
mask[indices_native] = True
# NOTE: Multiple issues
# - Missing `values` type
# - `mask` accepts a `np.ndarray`, but not mentioned in stubs
# - Missing `replacements` type
# - Missing return type
pc_replace_with_mask: Incomplete = pc.replace_with_mask
return self._with_native(pc_replace_with_mask(self.native, mask, values_native))
def scatter(self, indices: Self, values: Self) -> Self:
Copy link
Copy Markdown
Member

@dangotbanned dangotbanned Feb 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very satisfying for (#3305) to finally pay off πŸ™

Thought it might be worth splitting into a separate PR, since it would be easy to upstream to main and (possibly) avoid a numpy import

mask = pc.is_in(arange(start=0, end=len(self), step=1), indices.native)
sorted_indices = pc.sort_indices(indices.native)
replacements = values.native.take(sorted_indices).combine_chunks()
return self._with_native(pc.replace_with_mask(self.native, mask, replacements))

def to_list(self) -> list[Any]:
return self.native.to_pylist()
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_compliant/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def sample(
with_replacement: bool,
seed: int | None,
) -> Self: ...
def scatter(self, indices: int | Sequence[int], values: Any) -> Self: ...
def scatter(self, indices: Self, values: Self) -> Self: ...
def shift(self, n: int) -> Self: ...
def skew(self) -> float | None: ...
def sort(self, *, descending: bool, nulls_last: bool) -> Self: ...
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:

sorting_indices = df.get_column(token)
for s in results:
s._scatter_in_place(sorting_indices, s)
s.scatter(sorting_indices, s, in_place=True)
return results

return self.__class__(
Expand Down Expand Up @@ -384,7 +384,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901,
# Ignore settingwithcopy warnings/errors, they're false-positives here.
warnings.filterwarnings("ignore", message="\n.*copy of a slice")
for s in results:
s._scatter_in_place(sorting_indices, s)
s.scatter(sorting_indices, s, in_place=True)
return results
if reverse:
return [s._gather_slice(slice(None, None, -1)) for s in results]
Expand Down
61 changes: 31 additions & 30 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import operator
import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal

import numpy as np
from typing import TYPE_CHECKING, Any, Callable, Literal, overload

from narwhals._compliant import EagerSeries, EagerSeriesHist
from narwhals._pandas_like.series_cat import PandasLikeSeriesCatNamespace
Expand All @@ -13,6 +11,7 @@
from narwhals._pandas_like.series_str import PandasLikeSeriesStringNamespace
from narwhals._pandas_like.series_struct import PandasLikeSeriesStructNamespace
from narwhals._pandas_like.utils import (
NUMPY_VERSION,
align_and_extract_native,
broadcast_series_to_index,
get_dtype_backend,
Expand All @@ -25,7 +24,7 @@
set_index,
)
from narwhals._typing_compat import assert_never
from narwhals._utils import Implementation, is_list_of, no_default, parse_version
from narwhals._utils import Implementation, is_list_of, no_default
from narwhals.dependencies import is_numpy_array_1d, is_pandas_like_series
from narwhals.exceptions import InvalidOperationError

Expand Down Expand Up @@ -280,34 +279,36 @@ def ewm_mean(
result[mask_na] = None
return self._with_native(result)

def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
if isinstance(values, self.__class__):
values = set_index(
values.native,
self.native.index[indices],
implementation=self._implementation,
)
s = self.native.copy(deep=True)
s.iloc[indices] = values
s.name = self.name
return self._with_native(s)

def _scatter_in_place(self, indices: Self, values: Self) -> None:
# Scatter, modifying original Series. Use with care!
implementation = self._implementation
backend_version = self._backend_version
@overload
def scatter(
self, indices: Self, values: Self, *, in_place: Literal[True]
) -> None: ...
@overload
def scatter(
self, indices: Self, values: Self, *, in_place: Literal[False] = False
) -> Self: ...

def scatter(
self, indices: Self, values: Self, *, in_place: bool = False
) -> Self | None:
# !NOTE: See conversation at https://github.com/narwhals-dev/narwhals/pull/3444#discussion_r2787546529
# to understand why signature differs from `CompliantSeries`
impl = self._implementation
native_series, indices_native = self.native, indices.native
values_native = set_index(
values.native,
self.native.index[indices.native],
implementation=implementation,
values.native, native_series.index[indices_native], implementation=impl
)
if implementation is Implementation.PANDAS and parse_version(np) < (2,):
values_native = values_native.copy() # pragma: no cover
min_pd_version = (1, 2)
if implementation is Implementation.PANDAS and backend_version < min_pd_version:
self.native.iloc[indices.native.values] = values_native # noqa: PD011
else:
self.native.iloc[indices.native] = values_native
series = native_series if in_place else native_series.copy(deep=True)

if impl.is_pandas():
if in_place and NUMPY_VERSION < (2,): # pragma: no cover
values_native = values_native.copy()
if self._backend_version < (1, 2):
indices_native = indices_native.to_numpy()

series.iloc[indices_native] = values_native

return None if in_place else self._with_native(series)

def cast(self, dtype: IntoDType) -> Self:
if self.dtype == dtype and self.native.dtype != "object":
Expand Down
9 changes: 9 additions & 0 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast

import numpy as np
import pandas as pd

from narwhals._compliant import EagerSeriesNamespace
Expand All @@ -23,6 +24,7 @@
_DeferredIterable,
check_columns_exist,
isinstance_or_issubclass,
parse_version,
requires,
)
from narwhals.exceptions import ShapeError
Expand Down Expand Up @@ -114,6 +116,13 @@
Always available if we reached here, due to a module-level import.
"""

NUMPY_VERSION = parse_version(np)
"""Static version for `numpy`.

Always available if we reached here, as imported in both _pandas_like/dataframe.py and
_pandas_like/series.py.
"""


def is_pandas_or_modin(implementation: Implementation) -> bool:
return implementation in {Implementation.PANDAS, Implementation.MODIN}
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,8 @@ def sort(self, *, descending: bool, nulls_last: bool) -> Self:

return self._with_native(result)

def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
s = self.native.clone().scatter(indices, extract_native(values))
def scatter(self, indices: Self, values: Self) -> Self:
s = self.native.clone().scatter(indices.native, values.native)
return self._with_native(s)

def value_counts(
Expand Down
34 changes: 28 additions & 6 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import math
from collections.abc import Iterable, Iterator, Mapping, Sequence
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -394,12 +395,16 @@ def to_native(self) -> IntoSeriesT:
"""
return self._compliant_series.native

def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
"""Set value(s) at given position(s).
def scatter(
self,
indices: Self | Iterable[int] | int,
values: Self | Iterable[PythonLiteral] | PythonLiteral,
) -> Self:
"""Set value(s) at the given index location(s).
Arguments:
indices: Position(s) to set items at.
values: Values to set.
indices: Integer(s) representing the index location(s).
values: Replacement values.
Note:
This method always returns a new Series, without modifying the original one.
Expand Down Expand Up @@ -436,10 +441,27 @@ def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
a: [[999,888,3]]
b: [[4,5,6]]
"""
return self._with_compliant(
self._compliant_series.scatter(indices, self._extract_native(values))
into_series = partial(
type(self).from_iterable, name="", backend=self.implementation
)

if not isinstance(indices, Series):
if not isinstance(indices, Iterable):
indices = [indices]
dtypes = self._version.dtypes
indices = into_series(values=indices, dtype=dtypes.Int64)

if indices.is_empty():
return self

if not isinstance(values, Series):
if not isinstance(values, Iterable) or isinstance(values, str):
values = [values]
values = into_series(values=values)

result = self._compliant.scatter(indices._compliant, values._compliant)
return self._with_compliant(result)

@property
def shape(self) -> tuple[int]:
"""Get the shape of the Series.
Expand Down
122 changes: 55 additions & 67 deletions tests/series_only/scatter_test.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,59 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any

import pytest

import narwhals as nw
from tests.utils import ConstructorEager, assert_equal_data


def test_scatter(constructor_eager: ConstructorEager) -> None:
df = nw.from_native(
constructor_eager({"a": [1, 2, 3], "b": [142, 124, 132]}), eager_only=True
)
result = df.with_columns(
df["a"].scatter([0, 1], [999, 888]), df["b"].scatter([0, 2, 1], df["b"])
)
expected = {"a": [999, 888, 3], "b": [142, 132, 124]}
assert_equal_data(result, expected)


def test_scatter_indices() -> None:
from tests.utils import ConstructorEager, assert_equal_data, assert_equal_series

if TYPE_CHECKING:
from collections.abc import Collection


def series(frame: ConstructorEager, name: str, values: Collection[Any]) -> nw.Series[Any]:
return nw.from_native(frame({name: values})).get_column(name)


@pytest.mark.filterwarnings(
"ignore:.*all arguments of to_dict except for the argument:FutureWarning"
)
Comment on lines +19 to +21
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

modin triggers this warning because it passes orient=list as positional.

So I copied over the filter from test_to_dict:

@pytest.mark.filterwarnings(
"ignore:.*all arguments of to_dict except for the argument:FutureWarning"
)
def test_to_dict(constructor_eager: ConstructorEager) -> None:

But I think we should either:

  1. Handle this in narwhals
  2. Add it to tool.pytest.ini_options.filterwarnings

@pytest.mark.parametrize(
("data", "indices", "values", "expected"),
[
([142, 124, 13], [0, 2, 1], (142, 124, 13), [142, 13, 124]),
([1, 2, 3], 0, 999, [999, 2, 3]),
(
[16, 12, 10, 9, 6, 5, 2],
(6, 1, 0, 5, 3, 2, 4),
[16, 12, 10, 9, 6, 5, 2],
[10, 12, 5, 6, 2, 9, 16],
),
([5.5, 9.2, 1.0], (), (), [5.5, 9.2, 1.0]),
],
ids=["single-series", "integer", "unordered-indices", "empty-indices"],
)
def test_scatter(
data: list[Any],
indices: int | Collection[int],
values: int | Collection[int],
expected: list[Any],
constructor_eager: ConstructorEager,
) -> None:
constructor = partial(series, constructor_eager)
s = constructor("s", data)
df = s.to_frame().with_row_index("dont change me")
unchanged_indexed = df.to_dict(as_series=False)
assert_equal_series(s.scatter(indices, values), expected, "s")
if not isinstance(indices, int):
assert_equal_series(s.scatter(constructor("i", indices), values), expected, "s")
if not isinstance(values, int):
assert_equal_series(s.scatter(indices, constructor("v", values)), expected, "s")
assert_equal_data(df, unchanged_indexed)


def test_scatter_pandas_index() -> None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed this to make it clearer, since indices is a parameter name for scatter.

Although, I'd prefer if we could use a less direct way of testing index preservation.
E.g. if this is important, why only pandas and not all the derivatives?

pytest.importorskip("pandas")
import pandas as pd

Expand All @@ -27,56 +63,8 @@ def test_scatter_indices() -> None:
pd.testing.assert_series_equal(result.to_native(), expected)


def test_scatter_unchanged(constructor_eager: ConstructorEager) -> None:
df = nw.from_native(
constructor_eager({"a": [1, 2, 3], "b": [142, 124, 132]}), eager_only=True
)
df.with_columns(
df["a"].scatter([0, 1], [999, 888]), df["b"].scatter([0, 2, 1], [142, 124, 132])
)
expected = {"a": [1, 2, 3], "b": [142, 124, 132]}
assert_equal_data(df, expected)


def test_single_series(constructor_eager: ConstructorEager) -> None:
df = nw.from_native(
constructor_eager({"a": [1, 2, 3], "b": [142, 124, 132]}), eager_only=True
)
s = df["a"]
s.scatter([0, 1], [999, 888])
expected = {"a": [1, 2, 3]}
assert_equal_data({"a": s}, expected)


def test_scatter_integer(constructor_eager: ConstructorEager) -> None:
df = nw.from_native(
constructor_eager({"a": [1, 2, 3], "b": [142, 124, 132]}), eager_only=True
)
s = df["a"]
result = s.scatter(0, 999)
expected = {"a": [999, 2, 3]}
assert_equal_data({"a": result}, expected)


def test_scatter_unordered_indices(constructor_eager: ConstructorEager) -> None:
data = {"a": [16, 12, 10, 9, 6, 5, 2]}
indices = [6, 1, 0, 5, 3, 2, 4]
df = nw.from_native(constructor_eager(data))
result = df["a"].scatter(indices, df["a"])
assert_equal_data({"a": result}, {"a": [10, 12, 5, 6, 2, 9, 16]})


def test_scatter_2862(constructor_eager: ConstructorEager) -> None:
df = nw.from_native(
constructor_eager({"a": [1, 2, 3], "b": [142, 124, 132]}), eager_only=True
)
ser = df["a"]
result = ser.scatter(1, 999)
expected = {"a": [1, 999, 3]}
assert_equal_data({"a": result}, expected)
result = ser.scatter([0, 2], [999, 888])
expected = {"a": [999, 2, 888]}
assert_equal_data({"a": result}, expected)
result = ser.scatter([2, 0], [999, 888])
expected = {"a": [888, 2, 999]}
assert_equal_data({"a": result}, expected)
s = series(constructor_eager, "a", [1, 2, 3])
assert_equal_series(s.scatter(1, 999), [1, 999, 3], "a")
assert_equal_series(s.scatter([0, 2], [999, 888]), [999, 2, 888], "a")
assert_equal_series(s.scatter([2, 0], [999, 888]), [888, 2, 999], "a")
Loading