diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 45a6d6b3d2..650217bd3d 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -451,37 +451,11 @@ def _gather_slice(self, rows: _SliceIndex | range) -> Self: raise NotImplementedError(msg) return self._with_native(self.native.slice(start, stop - start)) - def scatter(self, indices: int | Sequence[int], values: Any) -> Self: - import numpy as np # ignore-banned-import - - values_native: ArrayAny - if isinstance(indices, int): - indices_native = pa.array([indices]) - values_native = pa.array([values]) - else: - # TODO(unassigned): we may also want to let `indices` be a Series. - # https://github.com/narwhals-dev/narwhals/issues/2155 - indices_native = pa.array(indices) - if isinstance(values, self.__class__): - values_native = values.native.combine_chunks() - else: - # NOTE: Requires fixes in https://github.com/zen-xu/pyarrow-stubs/pull/209 - pa_array: Incomplete = pa.array - values_native = pa_array(values) - - sorting_indices = pc.sort_indices(indices_native) - indices_native = indices_native.take(sorting_indices) - values_native = values_native.take(sorting_indices) - - mask: _1DArray = np.zeros(self.len(), dtype=bool) - mask[indices_native] = True - # NOTE: Multiple issues - # - Missing `values` type - # - `mask` accepts a `np.ndarray`, but not mentioned in stubs - # - Missing `replacements` type - # - Missing return type - pc_replace_with_mask: Incomplete = pc.replace_with_mask - return self._with_native(pc_replace_with_mask(self.native, mask, values_native)) + def scatter(self, indices: Self, values: Self) -> Self: + mask = pc.is_in(arange(start=0, end=len(self), step=1), indices.native) + sorted_indices = pc.sort_indices(indices.native) + replacements = values.native.take(sorted_indices).combine_chunks() + return self._with_native(pc.replace_with_mask(self.native, mask, replacements)) def to_list(self) -> list[Any]: return self.native.to_pylist() diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index eb7f3ae6ef..077e6a86ad 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -162,7 +162,7 @@ def sample( with_replacement: bool, seed: int | None, ) -> Self: ... - def scatter(self, indices: int | Sequence[int], values: Any) -> Self: ... + def scatter(self, indices: Self, values: Self) -> Self: ... def shift(self, n: int) -> Self: ... def skew(self) -> float | None: ... def sort(self, *, descending: bool, nulls_last: bool) -> Self: ... diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 40f7560552..944a136b11 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -238,7 +238,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: sorting_indices = df.get_column(token) for s in results: - s._scatter_in_place(sorting_indices, s) + s.scatter(sorting_indices, s, in_place=True) return results return self.__class__( @@ -384,7 +384,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, # Ignore settingwithcopy warnings/errors, they're false-positives here. warnings.filterwarnings("ignore", message="\n.*copy of a slice") for s in results: - s._scatter_in_place(sorting_indices, s) + s.scatter(sorting_indices, s, in_place=True) return results if reverse: return [s._gather_slice(slice(None, None, -1)) for s in results] diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 206f581c73..f0dc6c4e48 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -2,9 +2,7 @@ import operator import warnings -from typing import TYPE_CHECKING, Any, Callable, Literal - -import numpy as np +from typing import TYPE_CHECKING, Any, Callable, Literal, overload from narwhals._compliant import EagerSeries, EagerSeriesHist from narwhals._pandas_like.series_cat import PandasLikeSeriesCatNamespace @@ -13,6 +11,7 @@ from narwhals._pandas_like.series_str import PandasLikeSeriesStringNamespace from narwhals._pandas_like.series_struct import PandasLikeSeriesStructNamespace from narwhals._pandas_like.utils import ( + NUMPY_VERSION, align_and_extract_native, broadcast_series_to_index, get_dtype_backend, @@ -25,7 +24,7 @@ set_index, ) from narwhals._typing_compat import assert_never -from narwhals._utils import Implementation, is_list_of, no_default, parse_version +from narwhals._utils import Implementation, is_list_of, no_default from narwhals.dependencies import is_numpy_array_1d, is_pandas_like_series from narwhals.exceptions import InvalidOperationError @@ -280,34 +279,36 @@ def ewm_mean( result[mask_na] = None return self._with_native(result) - def scatter(self, indices: int | Sequence[int], values: Any) -> Self: - if isinstance(values, self.__class__): - values = set_index( - values.native, - self.native.index[indices], - implementation=self._implementation, - ) - s = self.native.copy(deep=True) - s.iloc[indices] = values - s.name = self.name - return self._with_native(s) - - def _scatter_in_place(self, indices: Self, values: Self) -> None: - # Scatter, modifying original Series. Use with care! - implementation = self._implementation - backend_version = self._backend_version + @overload + def scatter( + self, indices: Self, values: Self, *, in_place: Literal[True] + ) -> None: ... + @overload + def scatter( + self, indices: Self, values: Self, *, in_place: Literal[False] = False + ) -> Self: ... + + def scatter( + self, indices: Self, values: Self, *, in_place: bool = False + ) -> Self | None: + # !NOTE: See conversation at https://github.com/narwhals-dev/narwhals/pull/3444#discussion_r2787546529 + # to understand why signature differs from `CompliantSeries` + impl = self._implementation + native_series, indices_native = self.native, indices.native values_native = set_index( - values.native, - self.native.index[indices.native], - implementation=implementation, + values.native, native_series.index[indices_native], implementation=impl ) - if implementation is Implementation.PANDAS and parse_version(np) < (2,): - values_native = values_native.copy() # pragma: no cover - min_pd_version = (1, 2) - if implementation is Implementation.PANDAS and backend_version < min_pd_version: - self.native.iloc[indices.native.values] = values_native # noqa: PD011 - else: - self.native.iloc[indices.native] = values_native + series = native_series if in_place else native_series.copy(deep=True) + + if impl.is_pandas(): + if in_place and NUMPY_VERSION < (2,): # pragma: no cover + values_native = values_native.copy() + if self._backend_version < (1, 2): + indices_native = indices_native.to_numpy() + + series.iloc[indices_native] = values_native + + return None if in_place else self._with_native(series) def cast(self, dtype: IntoDType) -> Self: if self.dtype == dtype and self.native.dtype != "object": diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index eee0833763..8d8476704a 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -5,6 +5,7 @@ import re from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast +import numpy as np import pandas as pd from narwhals._compliant import EagerSeriesNamespace @@ -23,6 +24,7 @@ _DeferredIterable, check_columns_exist, isinstance_or_issubclass, + parse_version, requires, ) from narwhals.exceptions import ShapeError @@ -114,6 +116,13 @@ Always available if we reached here, due to a module-level import. """ +NUMPY_VERSION = parse_version(np) +"""Static version for `numpy`. + +Always available if we reached here, as imported in both _pandas_like/dataframe.py and +_pandas_like/series.py. +""" + def is_pandas_or_modin(implementation: Implementation) -> bool: return implementation in {Implementation.PANDAS, Implementation.MODIN} diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 2ff271fc66..b77bba4eb9 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -504,8 +504,8 @@ def sort(self, *, descending: bool, nulls_last: bool) -> Self: return self._with_native(result) - def scatter(self, indices: int | Sequence[int], values: Any) -> Self: - s = self.native.clone().scatter(indices, extract_native(values)) + def scatter(self, indices: Self, values: Self) -> Self: + s = self.native.clone().scatter(indices.native, values.native) return self._with_native(s) def value_counts( diff --git a/narwhals/series.py b/narwhals/series.py index 0ce582e496..6acc476245 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -2,6 +2,7 @@ import math from collections.abc import Iterable, Iterator, Mapping, Sequence +from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -394,12 +395,16 @@ def to_native(self) -> IntoSeriesT: """ return self._compliant_series.native - def scatter(self, indices: int | Sequence[int], values: Any) -> Self: - """Set value(s) at given position(s). + def scatter( + self, + indices: Self | Iterable[int] | int, + values: Self | Iterable[PythonLiteral] | PythonLiteral, + ) -> Self: + """Set value(s) at the given index location(s). Arguments: - indices: Position(s) to set items at. - values: Values to set. + indices: Integer(s) representing the index location(s). + values: Replacement values. Note: This method always returns a new Series, without modifying the original one. @@ -436,10 +441,27 @@ def scatter(self, indices: int | Sequence[int], values: Any) -> Self: a: [[999,888,3]] b: [[4,5,6]] """ - return self._with_compliant( - self._compliant_series.scatter(indices, self._extract_native(values)) + into_series = partial( + type(self).from_iterable, name="", backend=self.implementation ) + if not isinstance(indices, Series): + if not isinstance(indices, Iterable): + indices = [indices] + dtypes = self._version.dtypes + indices = into_series(values=indices, dtype=dtypes.Int64) + + if indices.is_empty(): + return self + + if not isinstance(values, Series): + if not isinstance(values, Iterable) or isinstance(values, str): + values = [values] + values = into_series(values=values) + + result = self._compliant.scatter(indices._compliant, values._compliant) + return self._with_compliant(result) + @property def shape(self) -> tuple[int]: """Get the shape of the Series. diff --git a/tests/series_only/scatter_test.py b/tests/series_only/scatter_test.py index 4b827276b9..d608f78f1a 100644 --- a/tests/series_only/scatter_test.py +++ b/tests/series_only/scatter_test.py @@ -1,23 +1,59 @@ from __future__ import annotations +from functools import partial +from typing import TYPE_CHECKING, Any + import pytest import narwhals as nw -from tests.utils import ConstructorEager, assert_equal_data - - -def test_scatter(constructor_eager: ConstructorEager) -> None: - df = nw.from_native( - constructor_eager({"a": [1, 2, 3], "b": [142, 124, 132]}), eager_only=True - ) - result = df.with_columns( - df["a"].scatter([0, 1], [999, 888]), df["b"].scatter([0, 2, 1], df["b"]) - ) - expected = {"a": [999, 888, 3], "b": [142, 132, 124]} - assert_equal_data(result, expected) - - -def test_scatter_indices() -> None: +from tests.utils import ConstructorEager, assert_equal_data, assert_equal_series + +if TYPE_CHECKING: + from collections.abc import Collection + + +def series(frame: ConstructorEager, name: str, values: Collection[Any]) -> nw.Series[Any]: + return nw.from_native(frame({name: values})).get_column(name) + + +@pytest.mark.filterwarnings( + "ignore:.*all arguments of to_dict except for the argument:FutureWarning" +) +@pytest.mark.parametrize( + ("data", "indices", "values", "expected"), + [ + ([142, 124, 13], [0, 2, 1], (142, 124, 13), [142, 13, 124]), + ([1, 2, 3], 0, 999, [999, 2, 3]), + ( + [16, 12, 10, 9, 6, 5, 2], + (6, 1, 0, 5, 3, 2, 4), + [16, 12, 10, 9, 6, 5, 2], + [10, 12, 5, 6, 2, 9, 16], + ), + ([5.5, 9.2, 1.0], (), (), [5.5, 9.2, 1.0]), + ], + ids=["single-series", "integer", "unordered-indices", "empty-indices"], +) +def test_scatter( + data: list[Any], + indices: int | Collection[int], + values: int | Collection[int], + expected: list[Any], + constructor_eager: ConstructorEager, +) -> None: + constructor = partial(series, constructor_eager) + s = constructor("s", data) + df = s.to_frame().with_row_index("dont change me") + unchanged_indexed = df.to_dict(as_series=False) + assert_equal_series(s.scatter(indices, values), expected, "s") + if not isinstance(indices, int): + assert_equal_series(s.scatter(constructor("i", indices), values), expected, "s") + if not isinstance(values, int): + assert_equal_series(s.scatter(indices, constructor("v", values)), expected, "s") + assert_equal_data(df, unchanged_indexed) + + +def test_scatter_pandas_index() -> None: pytest.importorskip("pandas") import pandas as pd @@ -27,56 +63,8 @@ def test_scatter_indices() -> None: pd.testing.assert_series_equal(result.to_native(), expected) -def test_scatter_unchanged(constructor_eager: ConstructorEager) -> None: - df = nw.from_native( - constructor_eager({"a": [1, 2, 3], "b": [142, 124, 132]}), eager_only=True - ) - df.with_columns( - df["a"].scatter([0, 1], [999, 888]), df["b"].scatter([0, 2, 1], [142, 124, 132]) - ) - expected = {"a": [1, 2, 3], "b": [142, 124, 132]} - assert_equal_data(df, expected) - - -def test_single_series(constructor_eager: ConstructorEager) -> None: - df = nw.from_native( - constructor_eager({"a": [1, 2, 3], "b": [142, 124, 132]}), eager_only=True - ) - s = df["a"] - s.scatter([0, 1], [999, 888]) - expected = {"a": [1, 2, 3]} - assert_equal_data({"a": s}, expected) - - -def test_scatter_integer(constructor_eager: ConstructorEager) -> None: - df = nw.from_native( - constructor_eager({"a": [1, 2, 3], "b": [142, 124, 132]}), eager_only=True - ) - s = df["a"] - result = s.scatter(0, 999) - expected = {"a": [999, 2, 3]} - assert_equal_data({"a": result}, expected) - - -def test_scatter_unordered_indices(constructor_eager: ConstructorEager) -> None: - data = {"a": [16, 12, 10, 9, 6, 5, 2]} - indices = [6, 1, 0, 5, 3, 2, 4] - df = nw.from_native(constructor_eager(data)) - result = df["a"].scatter(indices, df["a"]) - assert_equal_data({"a": result}, {"a": [10, 12, 5, 6, 2, 9, 16]}) - - def test_scatter_2862(constructor_eager: ConstructorEager) -> None: - df = nw.from_native( - constructor_eager({"a": [1, 2, 3], "b": [142, 124, 132]}), eager_only=True - ) - ser = df["a"] - result = ser.scatter(1, 999) - expected = {"a": [1, 999, 3]} - assert_equal_data({"a": result}, expected) - result = ser.scatter([0, 2], [999, 888]) - expected = {"a": [999, 2, 888]} - assert_equal_data({"a": result}, expected) - result = ser.scatter([2, 0], [999, 888]) - expected = {"a": [888, 2, 999]} - assert_equal_data({"a": result}, expected) + s = series(constructor_eager, "a", [1, 2, 3]) + assert_equal_series(s.scatter(1, 999), [1, 999, 3], "a") + assert_equal_series(s.scatter([0, 2], [999, 888]), [999, 2, 888], "a") + assert_equal_series(s.scatter([2, 0], [999, 888]), [888, 2, 999], "a")