diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py index 907214058a74..9f8035772bad 100644 --- a/dpnp/dpnp_iface_searching.py +++ b/dpnp/dpnp_iface_searching.py @@ -39,7 +39,6 @@ import dpctl.tensor as dpt -import numpy import dpnp @@ -47,7 +46,6 @@ # pylint: disable=no-name-in-module from .dpnp_utils import ( - call_origin, get_usm_allocations, ) @@ -298,35 +296,59 @@ def where(condition, x=None, y=None, /): For full documentation refer to :obj:`numpy.where`. + Parameters + ---------- + condition : {dpnp.ndarray, usm_ndarray} + Where True, yield `x`, otherwise yield `y`. + x, y : {dpnp.ndarray, usm_ndarray, scalar}, optional + Values from which to choose. `x`, `y` and `condition` need to be + broadcastable to some shape. + Returns ------- y : dpnp.ndarray An array with elements from `x` where `condition` is True, and elements from `y` elsewhere. - Limitations - ----------- - Parameter `condition` is supported as either :class:`dpnp.ndarray` - or :class:`dpctl.tensor.usm_ndarray`. - Parameters `x` and `y` are supported as either scalar, :class:`dpnp.ndarray` - or :class:`dpctl.tensor.usm_ndarray` - Otherwise the function will be executed sequentially on CPU. - Input array data types of `x` and `y` are limited by supported DPNP - :ref:`Data types`. - See Also -------- - :obj:`nonzero` : The function that is called when `x` and `y`are omitted. + :obj:`dpnp.choose` : Construct an array from an index array and a list of + arrays to choose from. + :obj:`dpnp.nonzero` : Return the indices of the elements that are non-zero. Examples -------- - >>> import dpnp as dp - >>> a = dp.arange(10) - >>> d + >>> import dpnp as np + >>> a = np.arange(10) + >>> a array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - >>> dp.where(a < 5, a, 10*a) + >>> np.where(a < 5, a, 10*a) array([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90]) + This can be used on multidimensional arrays too: + + >>> np.where(np.array([[True, False], [True, True]]), + ... np.array([[1, 2], [3, 4]]), + ... np.array([[9, 8], [7, 6]])) + array([[1, 8], + [3, 4]]) + + The shapes of x, y, and the condition are broadcast together: + + >>> x, y = np.ogrid[:3, :4] + >>> np.where(x < y, x, 10 + y) # both x and 10+y are broadcast + array([[10, 0, 0, 0], + [10, 11, 1, 1], + [10, 11, 12, 2]]) + + >>> a = np.array([[0, 1, 2], + ... [0, 2, 4], + ... [0, 3, 6]]) + >>> np.where(a < 4, a, -1) # -1 is broadcast + array([[ 0, 1, 2], + [ 0, 2, -1], + [ 0, 3, -1]]) + """ missing = (x is None, y is None).count(True) @@ -336,34 +358,17 @@ def where(condition, x=None, y=None, /): if missing == 2: return dpnp.nonzero(condition) - if missing == 0: - if dpnp.is_supported_array_type(condition): - if numpy.isscalar(x) or numpy.isscalar(y): - # get USM type and queue to copy scalar from the host memory - # into a USM allocation - usm_type, queue = get_usm_allocations([condition, x, y]) - x = ( - dpt.asarray(x, usm_type=usm_type, sycl_queue=queue) - if numpy.isscalar(x) - else x - ) - y = ( - dpt.asarray(y, usm_type=usm_type, sycl_queue=queue) - if numpy.isscalar(y) - else y - ) - if dpnp.is_supported_array_type(x) and dpnp.is_supported_array_type( - y - ): - dpt_condition = ( - condition.get_array() - if isinstance(condition, dpnp_array) - else condition - ) - dpt_x = x.get_array() if isinstance(x, dpnp_array) else x - dpt_y = y.get_array() if isinstance(y, dpnp_array) else y - return dpnp_array._create_from_usm_ndarray( - dpt.where(dpt_condition, dpt_x, dpt_y) - ) - - return call_origin(numpy.where, condition, x, y) + usm_x = dpnp.get_usm_ndarray_or_scalar(x) + usm_y = dpnp.get_usm_ndarray_or_scalar(y) + usm_condition = dpnp.get_usm_ndarray(condition) + + usm_type, queue = get_usm_allocations([condition, x, y]) + if dpnp.isscalar(usm_x): + usm_x = dpt.asarray(usm_x, usm_type=usm_type, sycl_queue=queue) + + if dpnp.isscalar(usm_y): + usm_y = dpt.asarray(usm_y, usm_type=usm_type, sycl_queue=queue) + + return dpnp_array._create_from_usm_ndarray( + dpt.where(usm_condition, usm_x, usm_y) + ) diff --git a/tests/test_indexing.py b/tests/test_indexing.py index 2640eb64c49a..95d5dfbe59af 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -906,22 +906,3 @@ def test_triu_indices_from(array, k): result = dpnp.triu_indices_from(ia, k) expected = numpy.triu_indices_from(a, k) assert_array_equal(expected, result) - - -@pytest.mark.parametrize("cond_dtype", get_all_dtypes()) -@pytest.mark.parametrize("scalar_dtype", get_all_dtypes(no_none=True)) -def test_where_with_scalars(cond_dtype, scalar_dtype): - a = numpy.array([-1, 0, 1, 0], dtype=cond_dtype) - ia = dpnp.array(a) - - result = dpnp.where(ia, scalar_dtype(1), scalar_dtype(0)) - expected = numpy.where(a, scalar_dtype(1), scalar_dtype(0)) - assert_array_equal(expected, result) - - result = dpnp.where(ia, ia * 2, scalar_dtype(0)) - expected = numpy.where(a, a * 2, scalar_dtype(0)) - assert_array_equal(expected, result) - - result = dpnp.where(ia, scalar_dtype(1), dpnp.array(0)) - expected = numpy.where(a, scalar_dtype(1), numpy.array(0)) - assert_array_equal(expected, result) diff --git a/tests/test_search.py b/tests/test_search.py index 56f4f23739cb..1a3313345f2c 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,7 +1,7 @@ import dpctl.tensor as dpt import numpy import pytest -from numpy.testing import assert_allclose +from numpy.testing import assert_allclose, assert_array_equal, assert_raises import dpnp @@ -92,3 +92,189 @@ def test_nanargmax_nanargmin_error(func): # All-NaN slice encountered -> ValueError with pytest.raises(ValueError): getattr(dpnp, func)(ia, axis=0) + + +class TestWhere: + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) + def test_basic(self, dtype): + a = numpy.ones(53, dtype=bool) + ia = dpnp.array(a) + + np_res = numpy.where(a, dtype(0), dtype(1)) + dpnp_res = dpnp.where(ia, dtype(0), dtype(1)) + assert_array_equal(np_res, dpnp_res) + + np_res = numpy.where(~a, dtype(0), dtype(1)) + dpnp_res = dpnp.where(~ia, dtype(0), dtype(1)) + assert_array_equal(np_res, dpnp_res) + + d = numpy.ones_like(a).astype(dtype) + e = numpy.zeros_like(d) + a[7] = False + + ia[7] = False + id = dpnp.array(d) + ie = dpnp.array(e) + + np_res = numpy.where(a, e, e) + dpnp_res = dpnp.where(ia, ie, ie) + assert_array_equal(np_res, dpnp_res) + + np_res = numpy.where(a, d, e) + dpnp_res = dpnp.where(ia, id, ie) + assert_array_equal(np_res, dpnp_res) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) + @pytest.mark.parametrize( + "slice_a, slice_d, slice_e", + [ + pytest.param( + slice(None, None, None), + slice(None, None, None), + slice(0, 1, None), + ), + pytest.param( + slice(None, None, None), + slice(0, 1, None), + slice(None, None, None), + ), + pytest.param( + slice(None, None, 2), slice(None, None, 2), slice(None, None, 2) + ), + pytest.param( + slice(1, None, 2), slice(1, None, 2), slice(1, None, 2) + ), + pytest.param( + slice(None, None, 3), slice(None, None, 3), slice(None, None, 3) + ), + pytest.param( + slice(1, None, 3), slice(1, None, 3), slice(1, None, 3) + ), + pytest.param( + slice(None, None, -2), + slice(None, None, -2), + slice(None, None, -2), + ), + pytest.param( + slice(None, None, -3), + slice(None, None, -3), + slice(None, None, -3), + ), + pytest.param( + slice(1, None, -3), slice(1, None, -3), slice(1, None, -3) + ), + ], + ) + def test_strided(self, dtype, slice_a, slice_d, slice_e): + a = numpy.ones(53, dtype=bool) + a[7] = False + d = numpy.ones_like(a).astype(dtype) + e = numpy.zeros_like(d) + + ia = dpnp.array(a) + id = dpnp.array(d) + ie = dpnp.array(e) + + np_res = numpy.where(a[slice_a], d[slice_d], e[slice_e]) + dpnp_res = dpnp.where(ia[slice_a], id[slice_d], ie[slice_e]) + assert_array_equal(np_res, dpnp_res) + + def test_zero_sized(self): + a = numpy.array([], dtype=bool).reshape(0, 3) + b = numpy.array([], dtype=numpy.float32).reshape(0, 3) + + ia = dpnp.array(a) + ib = dpnp.array(b) + + np_res = numpy.where(a, 0, b) + dpnp_res = dpnp.where(ia, 0, ib) + assert_array_equal(np_res, dpnp_res) + + def test_ndim(self): + a = numpy.zeros((2, 25)) + b = numpy.ones((2, 25)) + c = numpy.array([True, False]) + + ia = dpnp.array(a) + ib = dpnp.array(b) + ic = dpnp.array(c) + + np_res = numpy.where(c[:, numpy.newaxis], a, b) + dpnp_res = dpnp.where(ic[:, dpnp.newaxis], ia, ib) + assert_array_equal(np_res, dpnp_res) + + np_res = numpy.where(c, a.T, b.T) + dpnp_res = numpy.where(ic, ia.T, ib.T) + assert_array_equal(np_res, dpnp_res) + + def test_dtype_mix(self): + a = numpy.uint32(1) + b = numpy.array( + [5.0, 0.0, 3.0, 2.0, -1.0, -4.0, 0.0, -10.0, 10.0, 1.0, 0.0, 3.0], + dtype=numpy.float32, + ) + c = numpy.array( + [ + False, + True, + False, + False, + False, + False, + True, + False, + False, + False, + True, + False, + ] + ) + + ia = dpnp.array(a) + ib = dpnp.array(b) + ic = dpnp.array(c) + + np_res = numpy.where(c, a, b) + dpnp_res = dpnp.where(ic, ia, ib) + assert_array_equal(np_res, dpnp_res) + + b = b.astype(numpy.int64) + ib = dpnp.array(b) + + np_res = numpy.where(c, a, b) + dpnp_res = dpnp.where(ic, ia, ib) + assert_array_equal(np_res, dpnp_res) + + # non bool mask + c = c.astype(int) + c[c != 0] = 34242324 + ic = dpnp.array(c) + + np_res = numpy.where(c, a, b) + dpnp_res = dpnp.where(ic, ia, ib) + assert_array_equal(np_res, dpnp_res) + + # invert + tmpmask = c != 0 + c[c == 0] = 41247212 + c[tmpmask] = 0 + ic = dpnp.array(c) + + np_res = numpy.where(c, a, b) + dpnp_res = dpnp.where(ic, ia, ib) + assert_array_equal(np_res, dpnp_res) + + def test_error(self): + c = dpnp.array([True, True]) + a = dpnp.ones((4, 5)) + b = dpnp.ones((5, 5)) + assert_raises(ValueError, dpnp.where, c, a, a) + assert_raises(ValueError, dpnp.where, c[0], a, b) + + def test_empty_result(self): + a = numpy.zeros((1, 1)) + ia = dpnp.array(a) + + np_res = numpy.vstack(numpy.where(a == 99.0)) + dpnp_res = dpnp.vstack(dpnp.where(ia == 99.0)) + assert_array_equal(np_res, dpnp_res) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index acf2801faf9c..778547f35e97 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1770,6 +1770,24 @@ def test_grid(device, func): assert_sycl_queue_equal(x.sycl_queue, sycl_queue) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_where(device): + a = numpy.array([[0, 1, 2], [0, 2, 4], [0, 3, 6]]) + ia = dpnp.array(a, device=device) + + result = dpnp.where(ia < 4, ia, -1) + expected = numpy.where(a < 4, a, -1) + assert_allclose(expected, result) + + expected_queue = ia.get_array().sycl_queue + result_queue = result.get_array().sycl_queue + assert_sycl_queue_equal(result_queue, expected_queue) + + @pytest.mark.parametrize( "device", valid_devices, diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index eaaf734eadc5..e528185566ea 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -759,6 +759,13 @@ def test_clip(usm_type): assert x.usm_type == y.usm_type +@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) +def test_where(usm_type): + a = dp.array([[0, 1, 2], [0, 2, 4], [0, 3, 6]], usm_type=usm_type) + result = dp.where(a < 4, a, -1) + assert result.usm_type == usm_type + + @pytest.mark.parametrize( "usm_type_matrix", list_of_usm_types, ids=list_of_usm_types )