1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -179,6 +179,7 @@
from keras.src.ops.numpy import divide as divide
from keras.src.ops.numpy import divide_no_nan as divide_no_nan
from keras.src.ops.numpy import dot as dot
from keras.src.ops.numpy import dstack as dstack
from keras.src.ops.numpy import einsum as einsum
from keras.src.ops.numpy import empty as empty
from keras.src.ops.numpy import empty_like as empty_like
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
@@ -63,6 +63,7 @@
from keras.src.ops.numpy import divide as divide
from keras.src.ops.numpy import divide_no_nan as divide_no_nan
from keras.src.ops.numpy import dot as dot
from keras.src.ops.numpy import dstack as dstack
from keras.src.ops.numpy import einsum as einsum
from keras.src.ops.numpy import empty as empty
from keras.src.ops.numpy import empty_like as empty_like
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -179,6 +179,7 @@
from keras.src.ops.numpy import divide as divide
from keras.src.ops.numpy import divide_no_nan as divide_no_nan
from keras.src.ops.numpy import dot as dot
from keras.src.ops.numpy import dstack as dstack
from keras.src.ops.numpy import einsum as einsum
from keras.src.ops.numpy import empty as empty
from keras.src.ops.numpy import empty_like as empty_like
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
@@ -63,6 +63,7 @@
from keras.src.ops.numpy import divide as divide
from keras.src.ops.numpy import divide_no_nan as divide_no_nan
from keras.src.ops.numpy import dot as dot
from keras.src.ops.numpy import dstack as dstack
from keras.src.ops.numpy import einsum as einsum
from keras.src.ops.numpy import empty as empty
from keras.src.ops.numpy import empty_like as empty_like
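These four stub files export the new op under both public namespaces, so `keras.ops.dstack` and `keras.ops.numpy.dstack` resolve to the same function from `keras.src.ops.numpy`. A quick sanity check, assuming an install that includes this change:

import keras
from keras.ops import numpy as ops_numpy

# Both public paths re-export the same underlying function.
assert keras.ops.dstack is ops_numpy.dstack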
4 changes: 4 additions & 0 deletions keras/src/backend/jax/numpy.py
@@ -682,6 +682,10 @@ def dot(x1, x2):
    return jnp.dot(x1, x2)


def dstack(xs):
    return jnp.dstack(xs)


def empty(shape, dtype=None):
    dtype = dtype or config.floatx()
    return jnp.empty(shape, dtype=dtype)
10 changes: 10 additions & 0 deletions keras/src/backend/numpy/numpy.py
@@ -607,6 +607,16 @@ def dot(x1, x2):
    return np.dot(x1, x2)


def dstack(xs):
    dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
    if len(dtype_set) > 1:
        dtype = dtypes.result_type(*dtype_set)
        xs = tree.map_structure(
            lambda x: convert_to_tensor(x).astype(dtype), xs
        )
    return np.dstack(xs)


def empty(shape, dtype=None):
    dtype = dtype or config.floatx()
    return np.empty(shape, dtype=dtype)
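Plain `np.dstack` would let NumPy's own promotion rules decide the output dtype, so the backend above first resolves a common dtype with `dtypes.result_type`. A minimal sketch of the observable behavior, assuming an install that includes this PR (the exact output dtype follows Keras' JAX-style promotion table):

import numpy as np
from keras import ops

x = np.array([1, 2, 3], dtype="int32")
y = np.array([0.5, 1.5, 2.5], dtype="float32")

z = ops.dstack([x, y])
print(z.shape)  # (1, 3, 2): both 1-D inputs are lifted to (1, 3, 1)
print(z.dtype)  # float32: int32 and float32 promote to float32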
42 changes: 42 additions & 0 deletions keras/src/backend/openvino/numpy.py
@@ -1145,6 +1145,48 @@ def dot(x1, x2):
    return OpenVINOKerasTensor(ov_opset.matmul(x1, x2, False, False).output(0))


def dstack(xs):
    if not isinstance(xs, (list, tuple)):
        xs = (xs,)
    elems = [convert_to_tensor(elem) for elem in xs]
    element_type = elems[0].output.get_element_type()
    elems = [get_ov_output(elem, element_type) for elem in elems]

    processed_elems = []
    for elem in elems:
        shape = elem.get_partial_shape()
        rank = shape.rank
        shape_len = rank.get_length()
        if shape_len == 0:
            elem = ov_opset.unsqueeze(
                elem, ov_opset.constant(0, Type.i32)
            ).output(0)
            elem = ov_opset.unsqueeze(
                elem, ov_opset.constant(1, Type.i32)
            ).output(0)
            elem = ov_opset.unsqueeze(
                elem, ov_opset.constant(2, Type.i32)
            ).output(0)
        elif shape_len == 1:
            elem = ov_opset.unsqueeze(
                elem, ov_opset.constant(0, Type.i32)
            ).output(0)
            elem = ov_opset.unsqueeze(
                elem, ov_opset.constant(2, Type.i32)
            ).output(0)
        elif shape_len == 2:
            elem = ov_opset.unsqueeze(
                elem, ov_opset.constant(2, Type.i32)
            ).output(0)
        processed_elems.append(elem)

    for i in range(1, len(processed_elems)):
        processed_elems[0], processed_elems[i] = _align_operand_types(
            processed_elems[0], processed_elems[i], "dstack()"
        )
    return OpenVINOKerasTensor(ov_opset.concat(processed_elems, 2).output(0))


def empty(shape, dtype=None):
    dtype = standardize_dtype(dtype) or config.floatx()
    ov_type = OPENVINO_DTYPES[dtype]
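The rank promotion the OpenVINO branch builds out of repeated `unsqueeze` ops mirrors NumPy's `atleast_3d` rule. For reference, a plain-NumPy sketch of that rule (illustrative only, not part of the PR):

import numpy as np

def dstack_ref(xs):
    """Reference dstack: lift each input to rank 3, then concat on axis 2."""
    promoted = []
    for x in xs:
        x = np.asarray(x)
        if x.ndim == 0:
            x = x.reshape(1, 1, 1)   # scalar -> (1, 1, 1)
        elif x.ndim == 1:
            x = x.reshape(1, -1, 1)  # (N,) -> (1, N, 1)
        elif x.ndim == 2:
            x = x[:, :, None]        # (M, N) -> (M, N, 1)
        promoted.append(x)
    return np.concatenate(promoted, axis=2)

a, b = np.arange(3), np.arange(3, 6)
assert np.array_equal(dstack_ref([a, b]), np.dstack([a, b]))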
21 changes: 21 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
@@ -1478,6 +1478,27 @@ def dot(x1, x2):
    return tf.cast(output, result_dtype)


def dstack(xs):
    xs = [convert_to_tensor(x) for x in xs]
    if len(xs) > 1:
        unique_dtypes = {x.dtype for x in xs}
        if len(unique_dtypes) > 1:
            dtype = dtypes.result_type(*[x.dtype for x in xs])
            xs = [cast(x, dtype) for x in xs]
    xs_reshaped = []
    for x in xs:
        shape = x.shape
        if len(shape) == 0:
            x = tf.reshape(x, (1, 1, 1))
        elif len(shape) == 1:
            x = tf.expand_dims(x, axis=0)
            x = tf.expand_dims(x, axis=2)
        elif len(shape) == 2:
            x = tf.expand_dims(x, axis=2)
        xs_reshaped.append(x)
    return tf.concat(xs_reshaped, axis=2)


def empty(shape, dtype=None):
    dtype = dtype or config.floatx()
    return tf.zeros(shape, dtype=dtype)
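Core TensorFlow exposes no `tf.dstack`, so the branch above unifies dtypes and hand-rolls the lift-to-rank-3-then-concat rule. A quick cross-check of that rule against NumPy (illustrative, assuming TensorFlow is installed):

import numpy as np
import tensorflow as tf

x = tf.constant([1, 2, 3])
y = tf.constant([4, 5, 6])

# Same lifting rule as the backend function: (3,) -> (1, 3, 1), concat on axis 2.
out = tf.concat(
    [tf.reshape(x, (1, 3, 1)), tf.reshape(y, (1, 3, 1))], axis=2
)
np.testing.assert_array_equal(out.numpy(), np.dstack([[1, 2, 3], [4, 5, 6]]))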
5 changes: 5 additions & 0 deletions keras/src/backend/torch/numpy.py
@@ -765,6 +765,11 @@ def dot(x1, x2):
    return cast(torch.matmul(x1, x2), result_dtype)


def dstack(xs):
    xs = [convert_to_tensor(x) for x in xs]
    return torch.dstack(xs)


def empty(shape, dtype=None):
    dtype = to_torch_dtype(dtype or config.floatx())
    return torch.empty(size=shape, dtype=dtype, device=get_device())
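PyTorch ships a native `torch.dstack` with NumPy semantics, so this backend only needs to convert inputs first. For illustration (assuming torch is installed):

import torch

out = torch.dstack([torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])])
print(out.shape)  # torch.Size([1, 3, 2])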
74 changes: 74 additions & 0 deletions keras/src/ops/numpy.py
@@ -2837,6 +2837,80 @@ def dot(x1, x2):
    return backend.numpy.dot(x1, x2)


class Dstack(Operation):
    def call(self, xs):
        return backend.numpy.dstack(xs)

    def compute_output_spec(self, xs):
        dtypes_to_resolve = []
        out_shapes = []
        for x in xs:
            shape = list(x.shape)
            if len(shape) == 0:
                shape = [1, 1, 1]
            elif len(shape) == 1:
                shape = [1, shape[0], 1]
            elif len(shape) == 2:
                shape = shape + [1]
            out_shapes.append(shape)
            dtypes_to_resolve.append(getattr(x, "dtype", type(x)))

        first_shape = out_shapes[0]
        total_depth = 0
        for shape in out_shapes:
            if not shape_equal(shape, first_shape, axis=[2], allow_none=True):
                raise ValueError(
                    "Every value in `xs` must have the same shape except on "
                    f"the `axis` dim. But found element of shape {shape}, "
                    f"which is different from the first element's "
                    f"shape {first_shape}."
                )
            if total_depth is None or shape[2] is None:
                total_depth = None
            else:
                total_depth += shape[2]

        output_shape = list(first_shape)
        output_shape[2] = total_depth
        dtype = dtypes.result_type(*dtypes_to_resolve)
        return KerasTensor(output_shape, dtype=dtype)


@keras_export(["keras.ops.dstack", "keras.ops.numpy.dstack"])
def dstack(xs):
    """Stack tensors in sequence depth wise (along third axis).

    This is equivalent to concatenation along the third axis after 2-D
    tensors of shape `(M, N)` have been reshaped to `(M, N, 1)` and 1-D
    tensors of shape `(N,)` have been reshaped to `(1, N, 1)`.

    Args:
        xs: Sequence of tensors.

    Returns:
        The tensor formed by stacking the given tensors.

    Examples:
    >>> import keras
    >>> x = keras.ops.array([1, 2, 3])
    >>> y = keras.ops.array([4, 5, 6])
    >>> keras.ops.dstack([x, y])
    array([[[1, 4],
            [2, 5],
            [3, 6]]])

    >>> x = keras.ops.array([[1], [2], [3]])
    >>> y = keras.ops.array([[4], [5], [6]])
    >>> keras.ops.dstack([x, y])
    array([[[1, 4]],
           [[2, 5]],
           [[3, 6]]])
    """
    if any_symbolic_tensors((xs,)):
        return Dstack().symbolic_call(xs)
    return backend.numpy.dstack(xs)


class Einsum(Operation):
    def __init__(self, subscripts, *, name=None):
        super().__init__(name=name)
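Beyond validating the non-depth axes, `compute_output_spec` above propagates unknown dimensions: if any input's depth (third axis) is `None`, the stacked depth is `None`; otherwise the depths are summed. A short symbolic-path sketch, assuming an install that includes this PR:

from keras import KerasTensor, ops

# Known depths are summed: 4 + 5 -> 9.
a = ops.dstack([KerasTensor((2, 3, 4)), KerasTensor((2, 3, 5))])
print(a.shape)  # (2, 3, 9)

# Any unknown depth makes the output depth unknown.
b = ops.dstack([KerasTensor((2, 3, 4)), KerasTensor((2, 3, None))])
print(b.shape)  # (2, 3, None)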
63 changes: 63 additions & 0 deletions keras/src/ops/numpy_test.py
@@ -2002,6 +2002,19 @@ def test_vstack(self):
        y = KerasTensor((None, None))
        self.assertEqual(knp.vstack([x, y]).shape, (None, 3))

    def test_dstack(self):
        x = KerasTensor((None,))
        y = KerasTensor((None,))
        self.assertEqual(knp.dstack([x, y]).shape, (1, None, 2))

        x = KerasTensor((None, 3))
        y = KerasTensor((None, 3))
        self.assertEqual(knp.dstack([x, y]).shape, (None, 3, 2))

        x = KerasTensor((None, 3))
        y = KerasTensor((None, None))
        self.assertEqual(knp.dstack([x, y]).shape, (None, 3, 2))

    def test_argpartition(self):
        x = KerasTensor((None, 3))
        self.assertEqual(knp.argpartition(x, 3).shape, (None, 3))

@@ -2672,6 +2685,19 @@ def test_vstack(self):
        y = KerasTensor((2, 3))
        self.assertEqual(knp.vstack([x, y]).shape, (4, 3))

    def test_dstack(self):
        x = KerasTensor((3,))
        y = KerasTensor((3,))
        self.assertEqual(knp.dstack([x, y]).shape, (1, 3, 2))

        x = KerasTensor((2, 3))
        y = KerasTensor((2, 3))
        self.assertEqual(knp.dstack([x, y]).shape, (2, 3, 2))

        x = KerasTensor((2, 3, 4))
        y = KerasTensor((2, 3, 5))
        self.assertEqual(knp.dstack([x, y]).shape, (2, 3, 9))

    def test_argpartition(self):
        x = KerasTensor((2, 3))
        self.assertEqual(knp.argpartition(x, 3).shape, (2, 3))

@@ -5501,6 +5527,20 @@ def test_vstack(self):
        self.assertAllClose(knp.vstack([x, y]), np.vstack([x, y]))
        self.assertAllClose(knp.Vstack()([x, y]), np.vstack([x, y]))

    def test_dstack(self):
        x = np.array([[1, 2, 3], [3, 2, 1]])
        y = np.array([[4, 5, 6], [6, 5, 4]])
        self.assertAllClose(knp.dstack([x, y]), np.dstack([x, y]))
        self.assertAllClose(knp.Dstack()([x, y]), np.dstack([x, y]))

        x = np.array([1, 2, 3])
        y = np.array([[4, 5, 6]])
        self.assertAllClose(knp.dstack([x, y]), np.dstack([x, y]))

        x = np.ones([2, 3, 4])
        y = np.ones([2, 3, 5])
        self.assertAllClose(knp.dstack([x, y]), np.dstack([x, y]))

    def test_floor_divide(self):
        x = np.array([[1, 2, 3], [3, 2, 1]])
        y = np.array([[4, 5, 6], [3, 2, 1]])

@@ -7700,6 +7740,29 @@ def test_dot(self, dtypes):
        )
        self.assertEqual(knp.Dot().symbolic_call(x1, x2).dtype, expected_dtype)

    @parameterized.named_parameters(
        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
    )
    def test_dstack(self, dtypes):
        import jax.numpy as jnp

        dtype1, dtype2 = dtypes
        x1 = knp.ones((1, 1), dtype=dtype1)
        x2 = knp.ones((1, 1), dtype=dtype2)
        x1_jax = jnp.ones((1, 1), dtype=dtype1)
        x2_jax = jnp.ones((1, 1), dtype=dtype2)

        expected_dtype = standardize_dtype(jnp.dstack([x1_jax, x2_jax]).dtype)

        self.assertEqual(
            standardize_dtype(knp.dstack([x1, x2]).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Dstack().symbolic_call([x1, x2]).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(
        named_product(
            dtypes=list(itertools.combinations(ALL_DTYPES, 2))