Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import builtins
import math

import jax
import jax.experimental.sparse as jax_sparse
import jax.numpy as jnp
from jax import export as jax_export
Expand All @@ -16,6 +17,18 @@
from keras.src.backend.jax.core import convert_to_tensor


def _uses_cpu(x):
if hasattr(x, "device"):
device = x.device
if not isinstance(device, jax.Device):
# Array is sharded.
return False
return device.platform == "cpu"
else:
# This is a Tracer, not a concrete Array.
return jax.default_backend() == "cpu"


def rot90(array, k=1, axes=(0, 1)):
"""Rotate an array by 90 degrees in the specified plane."""
if array.ndim < 2:
Expand Down Expand Up @@ -402,11 +415,9 @@ def arctanh(x):


def argmax(x, axis=None, keepdims=False):
from keras.src.testing.test_case import uses_cpu

x = convert_to_tensor(x)
dtype = standardize_dtype(x.dtype)
if "float" not in dtype or not uses_cpu() or x.ndim == 0:
if "float" not in dtype or x.ndim == 0 or not _uses_cpu(x):
return jnp.argmax(x, axis=axis, keepdims=keepdims)

# Fix the flush-to-zero (FTZ) issue based on this issue:
Expand All @@ -419,11 +430,9 @@ def argmax(x, axis=None, keepdims=False):


def argmin(x, axis=None, keepdims=False):
from keras.src.testing.test_case import uses_cpu

x = convert_to_tensor(x)
dtype = standardize_dtype(x.dtype)
if "float" not in dtype or not uses_cpu() or x.ndim == 0:
if "float" not in dtype or x.ndim == 0 or not _uses_cpu(x):
return jnp.argmin(x, axis=axis, keepdims=keepdims)

# Fix the flush-to-zero (FTZ) issue based on this issue:
Expand Down
4 changes: 1 addition & 3 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,11 +949,9 @@ def argmax(x, axis=None, keepdims=False):


def argmin(x, axis=None, keepdims=False):
from keras.src.testing.test_case import uses_cpu

x = convert_to_tensor(x)
dtype = standardize_dtype(x.dtype)
if "float" not in dtype or not uses_cpu() or x.ndim == 0:
if "float" not in dtype or x.ndim == 0:
_x = x
if axis is None:
x = tf.reshape(x, [-1])
Expand Down
14 changes: 4 additions & 10 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,8 +1277,8 @@ def test_argmax(self):
self.assertEqual(knp.argmax(x, keepdims=True).shape, (None, 3, 3))

@pytest.mark.skipif(
keras.config.backend() == "openvino" or testing.uses_tpu(),
reason="OpenVINO doesn't support this change",
keras.config.backend() == "openvino" or testing.jax_uses_tpu(),
reason="OpenVINO and JAX TPU don't support this",
)
def test_argmax_negative_zero(self):
input_data = np.array(
Expand All @@ -1287,14 +1287,8 @@ def test_argmax_negative_zero(self):
self.assertEqual(knp.argmax(input_data), 2)

@pytest.mark.skipif(
keras.config.backend() == "openvino"
or keras.config.backend() == "tensorflow"
or testing.uses_tpu(),
reason="""
OpenVINO and TensorFlow don't support this
change, TensorFlow behavior for this case is under
evaluation and may change within this PR
""",
keras.config.backend() == "openvino" or testing.jax_uses_tpu(),
reason="OpenVINO and JAX TPU don't support this",
)
def test_argmin_negative_zero(self):
input_data = np.array(
Expand Down
1 change: 1 addition & 0 deletions keras/src/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from keras.src.testing.test_case import TestCase
from keras.src.testing.test_case import jax_uses_gpu
from keras.src.testing.test_case import jax_uses_tpu
from keras.src.testing.test_case import tensorflow_uses_gpu
from keras.src.testing.test_case import torch_uses_gpu
from keras.src.testing.test_case import uses_gpu
Expand Down
11 changes: 4 additions & 7 deletions keras/src/testing/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,10 @@ def uses_gpu():
return False


def jax_uses_tpu():
return backend.backend() == "jax" and uses_tpu()


def uses_tpu():
# Condition used to skip tests when using the TPU
try:
Expand All @@ -661,13 +665,6 @@ def uses_tpu():
return False


def uses_cpu():
devices = distribution.list_devices()
if any(d.startswith("cpu") for d in devices):
return True
return False


def create_keras_tensors(input_shape, dtype, sparse, ragged):
if isinstance(input_shape, dict):
return {
Expand Down