diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py
index 81695ae3114f..48426a9da42f 100644
--- a/keras/api/_tf_keras/keras/ops/__init__.py
+++ b/keras/api/_tf_keras/keras/ops/__init__.py
@@ -179,6 +179,7 @@
 from keras.src.ops.numpy import divide as divide
 from keras.src.ops.numpy import divide_no_nan as divide_no_nan
 from keras.src.ops.numpy import dot as dot
+from keras.src.ops.numpy import dstack as dstack
 from keras.src.ops.numpy import einsum as einsum
 from keras.src.ops.numpy import empty as empty
 from keras.src.ops.numpy import empty_like as empty_like
diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py
index ddf9aa409abc..11114359d9cb 100644
--- a/keras/api/_tf_keras/keras/ops/numpy/__init__.py
+++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py
@@ -63,6 +63,7 @@
 from keras.src.ops.numpy import divide as divide
 from keras.src.ops.numpy import divide_no_nan as divide_no_nan
 from keras.src.ops.numpy import dot as dot
+from keras.src.ops.numpy import dstack as dstack
 from keras.src.ops.numpy import einsum as einsum
 from keras.src.ops.numpy import empty as empty
 from keras.src.ops.numpy import empty_like as empty_like
diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py
index 81695ae3114f..48426a9da42f 100644
--- a/keras/api/ops/__init__.py
+++ b/keras/api/ops/__init__.py
@@ -179,6 +179,7 @@
 from keras.src.ops.numpy import divide as divide
 from keras.src.ops.numpy import divide_no_nan as divide_no_nan
 from keras.src.ops.numpy import dot as dot
+from keras.src.ops.numpy import dstack as dstack
 from keras.src.ops.numpy import einsum as einsum
 from keras.src.ops.numpy import empty as empty
 from keras.src.ops.numpy import empty_like as empty_like
diff --git a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py
index ddf9aa409abc..11114359d9cb 100644
--- a/keras/api/ops/numpy/__init__.py
+++ b/keras/api/ops/numpy/__init__.py
@@ -63,6 +63,7 @@
 from keras.src.ops.numpy import divide as divide
 from keras.src.ops.numpy import divide_no_nan as divide_no_nan
 from keras.src.ops.numpy import dot as dot
+from keras.src.ops.numpy import dstack as dstack
 from keras.src.ops.numpy import einsum as einsum
 from keras.src.ops.numpy import empty as empty
 from keras.src.ops.numpy import empty_like as empty_like
diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py
index c2d748c7fdbe..c61350a26651 100644
--- a/keras/src/backend/jax/numpy.py
+++ b/keras/src/backend/jax/numpy.py
@@ -682,6 +682,10 @@ def dot(x1, x2):
     return jnp.dot(x1, x2)
 
 
+def dstack(xs):
+    return jnp.dstack(xs)
+
+
 def empty(shape, dtype=None):
     dtype = dtype or config.floatx()
     return jnp.empty(shape, dtype=dtype)
diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py
index 61e85a8ea31a..e82988346bd5 100644
--- a/keras/src/backend/numpy/numpy.py
+++ b/keras/src/backend/numpy/numpy.py
@@ -607,6 +607,16 @@ def dot(x1, x2):
     return np.dot(x1, x2)
 
 
+def dstack(xs):
+    dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
+    if len(dtype_set) > 1:
+        dtype = dtypes.result_type(*dtype_set)
+        xs = tree.map_structure(
+            lambda x: convert_to_tensor(x).astype(dtype), xs
+        )
+    return np.dstack(xs)
+
+
 def empty(shape, dtype=None):
     dtype = dtype or config.floatx()
     return np.empty(shape, dtype=dtype)
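Note on the JAX and NumPy backends above: JAX ships `jnp.dstack` directly, while the NumPy backend first resolves mixed input dtypes through Keras's own `dtypes.result_type`, so promotion follows Keras rules rather than NumPy's. For reference, a minimal sketch of the dstack semantics being implemented, in plain NumPy (illustrative only, not part of the patch):

    import numpy as np

    # dstack is concatenation along axis 2 after promoting inputs to 3-D:
    #   scalar () -> (1, 1, 1), 1-D (N,) -> (1, N, 1), 2-D (M, N) -> (M, N, 1)
    x = np.array([1, 2, 3])  # shape (3,)
    y = np.array([4, 5, 6])  # shape (3,)
    out = np.dstack([x, y])  # shape (1, 3, 2)
    assert out.shape == (1, 3, 2)
    assert np.array_equal(out[0], np.array([[1, 4], [2, 5], [3, 6]]))
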
diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py
index 5c5181d62dfa..53a70462d89a 100644
--- a/keras/src/backend/openvino/numpy.py
+++ b/keras/src/backend/openvino/numpy.py
@@ -1145,6 +1145,48 @@ def dot(x1, x2):
     return OpenVINOKerasTensor(ov_opset.matmul(x1, x2, False, False).output(0))
 
 
+def dstack(xs):
+    if not isinstance(xs, (list, tuple)):
+        xs = (xs,)
+    elems = [convert_to_tensor(elem) for elem in xs]
+    element_type = elems[0].output.get_element_type()
+    elems = [get_ov_output(elem, element_type) for elem in elems]
+
+    processed_elems = []
+    for elem in elems:
+        shape = elem.get_partial_shape()
+        rank = shape.rank
+        shape_len = rank.get_length()
+        if shape_len == 0:
+            elem = ov_opset.unsqueeze(
+                elem, ov_opset.constant(0, Type.i32)
+            ).output(0)
+            elem = ov_opset.unsqueeze(
+                elem, ov_opset.constant(1, Type.i32)
+            ).output(0)
+            elem = ov_opset.unsqueeze(
+                elem, ov_opset.constant(2, Type.i32)
+            ).output(0)
+        elif shape_len == 1:
+            elem = ov_opset.unsqueeze(
+                elem, ov_opset.constant(0, Type.i32)
+            ).output(0)
+            elem = ov_opset.unsqueeze(
+                elem, ov_opset.constant(2, Type.i32)
+            ).output(0)
+        elif shape_len == 2:
+            elem = ov_opset.unsqueeze(
+                elem, ov_opset.constant(2, Type.i32)
+            ).output(0)
+        processed_elems.append(elem)
+
+    for i in range(1, len(processed_elems)):
+        processed_elems[0], processed_elems[i] = _align_operand_types(
+            processed_elems[0], processed_elems[i], "dstack()"
+        )
+    return OpenVINOKerasTensor(ov_opset.concat(processed_elems, 2).output(0))
+
+
 def empty(shape, dtype=None):
     dtype = standardize_dtype(dtype) or config.floatx()
     ov_type = OPENVINO_DTYPES[dtype]
diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py
index 76f8ae3c3a79..6290cb115206 100644
--- a/keras/src/backend/tensorflow/numpy.py
+++ b/keras/src/backend/tensorflow/numpy.py
@@ -1478,6 +1478,27 @@ def dot(x1, x2):
     return tf.cast(output, result_dtype)
 
 
+def dstack(xs):
+    xs = [convert_to_tensor(x) for x in xs]
+    if len(xs) > 1:
+        unique_dtypes = {x.dtype for x in xs}
+        if len(unique_dtypes) > 1:
+            dtype = dtypes.result_type(*[x.dtype for x in xs])
+            xs = [cast(x, dtype) for x in xs]
+    xs_reshaped = []
+    for x in xs:
+        shape = x.shape
+        if len(shape) == 0:
+            x = tf.reshape(x, (1, 1, 1))
+        elif len(shape) == 1:
+            x = tf.expand_dims(x, axis=0)
+            x = tf.expand_dims(x, axis=2)
+        elif len(shape) == 2:
+            x = tf.expand_dims(x, axis=2)
+        xs_reshaped.append(x)
+    return tf.concat(xs_reshaped, axis=2)
+
+
 def empty(shape, dtype=None):
     dtype = dtype or config.floatx()
     return tf.zeros(shape, dtype=dtype)
diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py
index 265601fe0673..1c6e517a7ac9 100644
--- a/keras/src/backend/torch/numpy.py
+++ b/keras/src/backend/torch/numpy.py
@@ -765,6 +765,11 @@ def dot(x1, x2):
     return cast(torch.matmul(x1, x2), result_dtype)
 
 
+def dstack(xs):
+    xs = [convert_to_tensor(x) for x in xs]
+    return torch.dstack(xs)
+
+
 def empty(shape, dtype=None):
     dtype = to_torch_dtype(dtype or config.floatx())
     return torch.empty(size=shape, dtype=dtype, device=get_device())
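TensorFlow, unlike JAX and PyTorch, has no single dstack op, so the hunk above promotes each input to rank 3 by hand before `tf.concat`; OpenVINO does the same with `unsqueeze` ops. A quick way to sanity-check that promotion logic against NumPy (an illustrative snippet assuming a TensorFlow install; not part of the patch):

    import numpy as np
    import tensorflow as tf

    def tf_dstack(xs):
        # Mirror the backend logic: promote to rank 3, then concat on axis 2.
        xs = [tf.convert_to_tensor(x) for x in xs]
        out = []
        for x in xs:
            if len(x.shape) == 0:
                x = tf.reshape(x, (1, 1, 1))
            elif len(x.shape) == 1:
                x = tf.expand_dims(tf.expand_dims(x, 0), 2)
            elif len(x.shape) == 2:
                x = tf.expand_dims(x, 2)
            out.append(x)
        return tf.concat(out, axis=2)

    for shapes in [[(3,), (3,)], [(2, 3), (2, 3)], [(2, 3, 4), (2, 3, 5)]]:
        xs = [np.ones(s) for s in shapes]
        assert tuple(tf_dstack(xs).shape) == np.dstack(xs).shape
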
diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py
index b5ff4376f4a2..b27fda08baae 100644
--- a/keras/src/ops/numpy.py
+++ b/keras/src/ops/numpy.py
@@ -2837,6 +2837,80 @@ def dot(x1, x2):
     return backend.numpy.dot(x1, x2)
 
 
+class Dstack(Operation):
+    def call(self, xs):
+        return backend.numpy.dstack(xs)
+
+    def compute_output_spec(self, xs):
+        dtypes_to_resolve = []
+        out_shapes = []
+        for x in xs:
+            shape = list(x.shape)
+            if len(shape) == 0:
+                shape = [1, 1, 1]
+            elif len(shape) == 1:
+                shape = [1, shape[0], 1]
+            elif len(shape) == 2:
+                shape = shape + [1]
+            out_shapes.append(shape)
+            dtypes_to_resolve.append(getattr(x, "dtype", type(x)))
+
+        first_shape = out_shapes[0]
+        total_depth = 0
+        for shape in out_shapes:
+            if not shape_equal(shape, first_shape, axis=[2], allow_none=True):
+                raise ValueError(
+                    "Every value in `xs` must have the same shape except on "
+                    f"the `axis` dim. But found element of shape {shape}, "
+                    f"which is different from the first element's "
+                    f"shape {first_shape}."
+                )
+            if total_depth is None or shape[2] is None:
+                total_depth = None
+            else:
+                total_depth += shape[2]
+
+        output_shape = list(first_shape)
+        output_shape[2] = total_depth
+        dtype = dtypes.result_type(*dtypes_to_resolve)
+        return KerasTensor(output_shape, dtype=dtype)
+
+
+@keras_export(["keras.ops.dstack", "keras.ops.numpy.dstack"])
+def dstack(xs):
+    """Stack tensors in sequence depth wise (along third axis).
+
+    This is equivalent to concatenation along the third axis after 2-D
+    tensors of shape `(M, N)` have been reshaped to `(M, N, 1)` and 1-D
+    tensors of shape `(N,)` have been reshaped to `(1, N, 1)`.
+
+    Args:
+        xs: Sequence of tensors.
+
+    Returns:
+        The tensor formed by stacking the given tensors.
+
+    Examples:
+    >>> import keras
+    >>> x = keras.ops.array([1, 2, 3])
+    >>> y = keras.ops.array([4, 5, 6])
+    >>> keras.ops.dstack([x, y])
+    array([[[1, 4],
+            [2, 5],
+            [3, 6]]])
+
+    >>> x = keras.ops.array([[1], [2], [3]])
+    >>> y = keras.ops.array([[4], [5], [6]])
+    >>> keras.ops.dstack([x, y])
+    array([[[1, 4]],
+           [[2, 5]],
+           [[3, 6]]])
+    """
+    if any_symbolic_tensors((xs,)):
+        return Dstack().symbolic_call(xs)
+    return backend.numpy.dstack(xs)
+
+
 class Einsum(Operation):
     def __init__(self, subscripts, *, name=None):
         super().__init__(name=name)
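`Dstack.compute_output_spec` above applies the same rank-promotion rules to symbolic shapes: every axis except the depth axis must match (with `None` acting as a wildcard), and the output depth is the sum of per-input depths, or `None` if any depth is unknown. A hedged illustration of the intended symbolic behavior (assuming this patch is applied):

    from keras import KerasTensor, ops

    x = KerasTensor((None, 3))  # unknown leading dim
    y = KerasTensor((None, 3))
    z = ops.dstack([x, y])  # both inputs promoted to (None, 3, 1)
    assert z.shape == (None, 3, 2)  # depths 1 + 1; other axes carried through
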
diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py
index 2617835906c5..11a574b71af3 100644
--- a/keras/src/ops/numpy_test.py
+++ b/keras/src/ops/numpy_test.py
@@ -2002,6 +2002,19 @@ def test_vstack(self):
         y = KerasTensor((None, None))
         self.assertEqual(knp.vstack([x, y]).shape, (None, 3))
 
+    def test_dstack(self):
+        x = KerasTensor((None,))
+        y = KerasTensor((None,))
+        self.assertEqual(knp.dstack([x, y]).shape, (1, None, 2))
+
+        x = KerasTensor((None, 3))
+        y = KerasTensor((None, 3))
+        self.assertEqual(knp.dstack([x, y]).shape, (None, 3, 2))
+
+        x = KerasTensor((None, 3))
+        y = KerasTensor((None, None))
+        self.assertEqual(knp.dstack([x, y]).shape, (None, 3, 2))
+
     def test_argpartition(self):
         x = KerasTensor((None, 3))
         self.assertEqual(knp.argpartition(x, 3).shape, (None, 3))
@@ -2672,6 +2685,19 @@ def test_vstack(self):
         y = KerasTensor((2, 3))
         self.assertEqual(knp.vstack([x, y]).shape, (4, 3))
 
+    def test_dstack(self):
+        x = KerasTensor((3,))
+        y = KerasTensor((3,))
+        self.assertEqual(knp.dstack([x, y]).shape, (1, 3, 2))
+
+        x = KerasTensor((2, 3))
+        y = KerasTensor((2, 3))
+        self.assertEqual(knp.dstack([x, y]).shape, (2, 3, 2))
+
+        x = KerasTensor((2, 3, 4))
+        y = KerasTensor((2, 3, 5))
+        self.assertEqual(knp.dstack([x, y]).shape, (2, 3, 9))
+
     def test_argpartition(self):
         x = KerasTensor((2, 3))
         self.assertEqual(knp.argpartition(x, 3).shape, (2, 3))
@@ -5501,6 +5527,20 @@ def test_vstack(self):
         self.assertAllClose(knp.vstack([x, y]), np.vstack([x, y]))
         self.assertAllClose(knp.Vstack()([x, y]), np.vstack([x, y]))
 
+    def test_dstack(self):
+        x = np.array([[1, 2, 3], [3, 2, 1]])
+        y = np.array([[4, 5, 6], [6, 5, 4]])
+        self.assertAllClose(knp.dstack([x, y]), np.dstack([x, y]))
+        self.assertAllClose(knp.Dstack()([x, y]), np.dstack([x, y]))
+
+        x = np.array([1, 2, 3])
+        y = np.array([[4, 5, 6]])
+        self.assertAllClose(knp.dstack([x, y]), np.dstack([x, y]))
+
+        x = np.ones([2, 3, 4])
+        y = np.ones([2, 3, 5])
+        self.assertAllClose(knp.dstack([x, y]), np.dstack([x, y]))
+
     def test_floor_divide(self):
         x = np.array([[1, 2, 3], [3, 2, 1]])
         y = np.array([[4, 5, 6], [3, 2, 1]])
@@ -7700,6 +7740,29 @@ def test_dot(self, dtypes):
         )
         self.assertEqual(knp.Dot().symbolic_call(x1, x2).dtype, expected_dtype)
 
+    @parameterized.named_parameters(
+        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
+    )
+    def test_dstack(self, dtypes):
+        import jax.numpy as jnp
+
+        dtype1, dtype2 = dtypes
+        x1 = knp.ones((1, 1), dtype=dtype1)
+        x2 = knp.ones((1, 1), dtype=dtype2)
+        x1_jax = jnp.ones((1, 1), dtype=dtype1)
+        x2_jax = jnp.ones((1, 1), dtype=dtype2)
+
+        expected_dtype = standardize_dtype(jnp.dstack([x1_jax, x2_jax]).dtype)
+
+        self.assertEqual(
+            standardize_dtype(knp.dstack([x1, x2]).dtype),
+            expected_dtype,
+        )
+        self.assertEqual(
+            standardize_dtype(knp.Dstack().symbolic_call([x1, x2]).dtype),
+            expected_dtype,
+        )
+
     @parameterized.named_parameters(
         named_product(
             dtypes=list(itertools.combinations(ALL_DTYPES, 2))
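The parameterized dtype test above pins the promotion behavior of `keras.ops.dstack` to JAX's: for every dtype pair, the Keras result dtype must match `jnp.dstack`. An illustrative eager check (assuming this patch is applied; intended to be backend-independent):

    import numpy as np
    from keras import ops

    x = np.array([1, 2, 3], dtype="int32")
    y = np.array([4.0, 5.0, 6.0], dtype="float32")
    out = ops.dstack([x, y])
    # int32 and float32 promote to float32 under Keras's result_type rules.
    assert ops.convert_to_numpy(out).dtype == np.float32
    assert ops.convert_to_numpy(out).shape == (1, 3, 2)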