Skip to content

Allow passing variables to a function with @custom_gradient.#22069

Merged
hertschuh merged 1 commit intokeras-team:masterfrom
hertschuh:custom_gradient
Jan 29, 2026
Merged

Allow passing variables to a function with @custom_gradient.#22069
hertschuh merged 1 commit intokeras-team:masterfrom
hertschuh:custom_gradient

Conversation

@hertschuh
Copy link
Copy Markdown
Collaborator

With the JAX backend, passing a variable to a function with a custom gradient would fail.

This adds a wrapper to unwrap variables.

Fixes #21105

With the JAX backend, passing a variable to a function with a custom gradient would fail.

This adds a wrapper to unwrap variables.

Fixes keras-team#21105
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @hertschuh, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a compatibility issue within the JAX backend where KerasVariable objects were not correctly handled by functions utilizing @custom_gradient. By introducing a wrapper that unwraps these variables to their raw values, the system now ensures seamless integration and proper gradient computation in such scenarios. The accompanying test updates validate this fix, enhancing the robustness of custom gradient definitions in Keras with JAX.

Highlights

  • JAX Custom Gradient Variable Handling: Modified the custom_gradient function in the JAX backend to automatically unwrap KerasVariable instances to their underlying .value before passing them to JAX's custom_gradient mechanism.
  • Bug Fix: This change resolves an issue where passing a KerasVariable directly to a function decorated with @custom_gradient would cause a failure in the JAX backend.
  • Test Coverage: Expanded the test_custom_gradient in keras/src/ops/core_test.py to include a parameterized test case that explicitly verifies the correct behavior when a KerasVariable is passed to a custom-gradient-decorated function.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request aims to fix an issue with @custom_gradient on the JAX backend when keras.Variable instances are passed as arguments. The approach of wrapping the function to unwrap variables is sound, and the added tests effectively validate the fix for both tensor and variable inputs. However, there's a small but critical issue in the implementation of the wrapper function that needs to be addressed.

Comment on lines +610 to +613
args, kwargs = tree.map_shape_structure(
lambda x: x.value if isinstance(x, KerasVariable) else x,
(args, kwargs),
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using tree.map_shape_structure here is incorrect. This function is designed to treat shape-like tuples as leaves and will attempt to traverse into KerasVariable objects since they are not shape-like tuples and are registered as PyTrees. This will apply the lambda function to the children of the KerasVariable instead of the variable itself, defeating the purpose of this wrapper.

You should use tree.map_structure instead. It will correctly treat KerasVariable instances as leaves and apply the unwrapping function to them.

Suggested change
args, kwargs = tree.map_shape_structure(
lambda x: x.value if isinstance(x, KerasVariable) else x,
(args, kwargs),
)
args, kwargs = tree.map_structure(
lambda x: x.value if isinstance(x, KerasVariable) else x,
(args, kwargs),
)

@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented Jan 28, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.80%. Comparing base (8a37f9d) to head (8768da4).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #22069   +/-   ##
=======================================
  Coverage   82.80%   82.80%           
=======================================
  Files         592      592           
  Lines       62501    62505    +4     
  Branches     9789     9789           
=======================================
+ Hits        51751    51757    +6     
+ Misses       8215     8213    -2     
  Partials     2535     2535           
Flag Coverage Δ
keras 82.63% <100.00%> (+<0.01%) ⬆️
keras-jax 62.38% <100.00%> (+<0.01%) ⬆️
keras-numpy 56.47% <0.00%> (-0.01%) ⬇️
keras-openvino 37.64% <0.00%> (+<0.01%) ⬆️
keras-tensorflow 63.63% <0.00%> (-0.01%) ⬇️
keras-torch 62.41% <0.00%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown
Collaborator

@JyotinderSingh JyotinderSingh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Jan 29, 2026
@hertschuh hertschuh merged commit 38687a6 into keras-team:master Jan 29, 2026
12 of 13 checks passed
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase labels Jan 29, 2026
@hertschuh hertschuh deleted the custom_gradient branch January 29, 2026 02:33
jerryxyj added a commit to jerryxyj/keras that referenced this pull request Feb 14, 2026
* Implement logaddexp2 function in keras.ops (keras-team#21691)

* [Keras 3 OpenVINO Backend]: Support numpy.sort (keras-team#21687)

* [Keras 3 OpenVINO Backend]: Support numpy.median operation (keras-team#21667)

* Fix deadlock in `CallbackList`. (keras-team#21701)

* [OpenVINO backend] solve randomuniform issue (keras-team#21670)

* Bug fixes with variable handling in `LossScaleOptimizer`. (keras-team#21706)

* Do not use backend ops in `ProgBar`. (keras-team#21709)

* Fix the Doc of the combination relation in func `keras.layers.Normali…

* Remove reliance on `__jax_array__` to unwrap variables. (keras-team#21719)

* Bump the github-actions group with 6 updates (keras-team#21705)

* Add linspace and logspace implementations in OpenVINO NumPy backend (…

* Add jvp op (keras-team#21720)

* Add unfold op (keras-team#21685)

* Add the description that `0` should not in the arg `axes` in `keras.l…

* Add daily Python 3.13 CPU-only tests to nightly workflow (keras-team#21566)

* Fix histogram op for symbolic inputs (keras-team#21729)

* Relax tolerance for svd test (keras-team#21731)

* Use jax.enable_x64 in place of jax.experimental.disable_x64 (keras-team#21734)

* Refactor variable serialization. (keras-team#21713)

* Ensure keras.ops.eye behavior is consistent across backends. (keras-team#21738)

* Add `eye` support for OpenVINO backend (keras-team#21739)

* Update Torch and Tensorflow versions in cuda requirements files. (keras-team#21…

* Implement isreal function in keras.ops (keras-team#21740)

* Remove the unused jax `enable_x64`. (keras-team#21737)

* Correct implementation for several OpenVINO operations (keras-team#21746)

* Sets `is_gptq_calibrated` flag when deserializing GPTQ models (keras-team#21748)

* Correct implementation for several OpenVINO operations (keras-team#21752)

* Fix the Bug in func `preprocess_input` when `x` in 3D and `data_forma…

* Update Torch to 2.9.0 on GPU. (keras-team#21756)

* `StringLookup` & `IntegerLookup` now save vocabulary loaded from file…

* Implement trapezoid function in keras.ops (keras-team#21757)

* Upstream `ReversibleEmbedding` from KerasHub. (keras-team#21753)

* Raise exception on batch_size mismatch for stateful RNNs (keras-team#21742)

* Propose a method for handling datasets which doesn't explicitly requi…

* Use `filter="data"` option of `TarFile.extractall`. (keras-team#21760)

* Add Distillation API to Keras (keras-team#21572)

* removes unnecessary try-catch blocks and guard conditions (keras-team#21767)

* cleanup distillation  loss names (keras-team#21766)

* Document that `set_backend` requires re-importing keras. (keras-team#21764)

* Fix discretization discrepancy (keras-team#21769)

* fix sas metrics in jax `fit` (keras-team#21765)

* Support for extracting volume patches (keras-team#21759)

* Fix negative index handling in MultiHeadAttention attention_axes (keras-team#21…

* Make confusion metrics compilable. (keras-team#21775)

* Suport keras.op.view() to view the same data bitwise at a new dtype  …

* Fix: `keras.ops.quantile` works with tf graph execution (keras-team#21782)

* Fix typo in Distiller docstring

* Add warning to `set_backend` and more detailed example. (keras-team#21787)

* Don't fail `Variable.__repr__` if the value cannot be retrieved. (keras-team#21…

* Update Keras backend installation instructions

* Fix: Support 'jpg' format in keras.utils.save_img() (keras-team#21683)

* Fix tf dataset detection logic. (keras-team#21794)

* update test after jax.config.jax_vjp3 is enabled (keras-team#21776)

* Add keras.ops.array_split for Tensor Parallelism Support (keras-team#21697)

* Adding get_device_count function to the distribution_lib (keras-team#21791)

* Fix: use raw string for CALIBRATION_TEXT (keras-team#21790)

* Add backend compatibility table to documentation (keras-team#21733)

* More OpenVINO Operations (keras-team#21774)

* Support scalar view for tf backend. (keras-team#21802)

* Address bug with convolution using Tensorflow, Numpy, Jax backends (#…

* Fix bug with correlate for tensorflow (keras-team#21778)

* Pass optional field in a few places to fix None input error. (keras-team#21818)

* Fix(backend/torch): Resolve MPS broadcast crash in binary_crossentrop…

* Fix broken example indentation in Keras io (keras-team#21807)

* Add missing `convert_to_tensor` to `take_along_axis` on JAX. (keras-team#21825)

* Added  numpy.digitize support for OPENVINO backend  (keras-team#21824)

* Bump the github-actions group with 4 updates (keras-team#21809)

* Fix typo in CONTRIBUTING.md (keras-team#21812)

* Fix `Progbar.update` when receiving list, np arrays, and tensors. (#2…

* Fix CosineDecay documentation to clarify alpha is a multiplier (keras-team#21827)

* Fix noise_shape validation in keras.layers.Dropout (keras-team#21819)

* Fix typos in some files (keras-team#21830)

* Fix failing sklearn tests following release of pytest 9.0. (keras-team#21843)

* Implement empty_like function in keras.ops (keras-team#21840)

* Run tests on TPU (keras-team#21425)

* Fix typo in variable name 'embeding' to 'embedding' (keras-team#21845)

* Fix name_scope_stack AttributeError and IndexError in __exit__ (keras-team#21834)

* Update keras3 Softmax mask handling to be more numerically robust. (#…

* Support jax2tf in JaxLayer for tf backend (keras-team#21842)

* Fix assigning a value to a variable within an autocast scope. (keras-team#21864)

* Add note about label noise in CIFAR-10 dataset documentation (keras-team#21855)

* Allow None inputs in `Layer.build`. (keras-team#21866)

* `standardize_shape` normalizes the dimensions and tuple. (keras-team#21867)

* Improve error message when layer/model input validation fails. (keras-team#21869)

* Add verbose logging when ModelCheckpoint callback is done saving ... …

* [OpenVINO backend] Remove deprecated openvino.runtime import (keras-team#21826)

* Fix Torch output_padding constraint for ConvTranspose layers (keras-team#21852)

* Support PyDataset in Normalization layer `adapt` methods (keras-team#21817)

* Fix test failures when nnx is enabled (keras-team#21875)

* Implement ldexp function in keras.ops (keras-team#21863)

* Added OrbaxCheckpoint for keras 3.0 for Data centric saving and resto…

* Add raise_error option to TerminateOnNaN for immediate termination on…

* Fix NNX tests (keras-team#21884)

* `keras.utils.set_random_seed` clear the global `SeedGenerator`. (keras-team#21874)

* fix tpu test (keras-team#21893)

* Model Export to liteRT (keras-team#21674)

* Fix: torch layer losses keyword arguments in rematscope (keras-team#21865)

* Add label to trigger TPU tests manually. (keras-team#21897)

* Support tpu tests allowing tpu precision for matmul (keras-team#21887)

* remove log (keras-team#21901)

* Introduces layer filtering for quantization and fixes GPTQ dependency…

* Replace `np.reshape(x, newshape=y)` with `np.reshape(x, y)`. (keras-team#21899)

* Modified Dense layer documentation for use_bias with batch normalizat…

* [OpenVINO Backend] Support np.diag (keras-team#20967)

* Modify Muon optimizer (keras-team#21885)

* Disables implicit GPTQ quantization using dtype_policy setter (keras-team#21895)

* Dense: validate units argument (keras-team#21902)

* Pin `ai-edge-litert` version to fix CI (keras-team#21912)

* Increase JAX GPU tests timeout to 2 hours (keras-team#21915)

* Fix TPU tests - for splash attention (keras-team#21891)

* Support various filtering functions in OpenVINO (keras-team#21836)

* OpenVINO NN Module Functions (keras-team#21803)

* fix XLA dynamic shape output of ops.diag (keras-team#21906)

* Fix: Remove redundant epsilon in loss mask weight calculation (keras-team#21908)

* Implement vander function in keras.ops (keras-team#21882)

* Fix Muon optimizer with TensorFlow backend. (keras-team#21924)

* OpenVino `device_scope` and data adapters tests (keras-team#21922)

* Fix fake quant gradient output shape and use `jax.grad` for tests. (#…

* Introduces QuantizationConfig for fine-grained quantization control (…

* Extended fix OOM Issue keras-team#21634 on Keras side (keras-team#21755)

* Fix ops.tile shape inference issue on TensorFlow backend (keras-team#21860)

* Add adaptive pooling (1D, 2D, 3D) support across JAX, NumPy, TensorFl…

* More OpenVINO Numpy Operations (keras-team#21925)

* Adds Serialization Support for QuantizationConfig based quantized mod…

* Refactors AbsMaxQuantizer to accept axis in __call__ (keras-team#21931)

* Speed up unit tests on JAX and TensorFlow. (keras-team#21933)

* update dev version number (keras-team#21921)

* Always use `run_tpu_tests` label to run the TPU tests. (keras-team#21900)

* Revert "Always use `run_tpu_tests` label to run the TPU tests. (keras-team#2190…

* Forward-fix for JAX API changes (keras-team#21938)

* Remove nightly tests with Python 3.13. (keras-team#21943)

* Do no always make batch size dynamic during export. (keras-team#21944)

* Fix `numpy.mean` with dynamic shape on OpenVino. (keras-team#21947)

* Remove NumPy warning with NumPy >= 2. (keras-team#21949)

* Always use `run_tpu_tests` label to run the TPU tests. (keras-team#21950)

* [OpenVINO backend] Support np.vander, np.trapezoid, np.corrcoef, np.c…

* Fixed a bug in _keras_mask (keras-team#21946)

* Fix handling of symbolic Tensor in RNN (keras-team#21945)

* Add example for arctanh (keras-team#21951)

* Fix DoS via malicious HDF5 dataset metadata in KerasFileEditor (keras-team#21880)

* Implement nextafter function in keras.ops (keras-team#21960)

* fix image.extract_patches strides handling (keras-team#21959)

* [OpenVINO backend] Support numpy.flip (keras-team#21963)

* Bump the github-actions group with 4 updates (keras-team#21968)

* Fix CUDNN flash attention for JAX > 0.6.2. (keras-team#21970)

* Skip `PyDataset` tests on TPU. (keras-team#21964)

* Add missing `name` to `SeedGenerator.get_config`. (keras-team#21975)

* Use `subprocess.run` in `pip_build.py` to escape wheel path. (keras-team#21976)

* Update dependencies and `dependabot.yml`. (keras-team#21974)

* Use `kokoro:force-run` label for TPU tests too. (keras-team#21956)

* Add simple example for keras.layers.Resizing (keras-team#21966)

* [OpenVINO backend] Support numpy.diagonal (keras-team#21965)

* Bump actions/checkout from 5.0.1 to 6.0.1 in the github-actions group…

* Fix ReversibleEmbedding mask error when using reverse=True (keras-team#21961)

* Update feature_space.py (keras-team#21935)

* Clarify Tracker docstring wording (keras-team#21985)

* Remove semi-colon after email in SECURITY.md (keras-team#21993)

* Implement cbrt function for OpenVINO backend (keras-team#21987)

* Fix config keys for chain depth and num chains (keras-team#21979)

* Implement hypot and trace function for OpenVINO backend (keras-team#21991)

* Implement ptp function in keras.ops (keras-team#21990)

* Orbax Loading and Sharding Support feature (keras-team#21903)

* Add usage examples to loss docstrings (keras-team#21989)

* Unify extract_patches to support both 2D and 3D patches (keras-team#21980)

* Fix ndim to support tf.RaggedTensor by using shape.rank (keras-team#21999)

* Implement size and swapaxes function for OpenVINO backend.  (keras-team#21995)

* Implement kron function for OpenVINO backend (keras-team#22000)

* Adds support for AWQ (keras-team#21992)

* Trigger TPU tests on kokoro label removal rather than addition. (keras-team#22001)

* Document complex dtype limitation in ops.correlate (keras-team#21984)

* [OpenVINO backend] Fix and enable numpy.rot90 (keras-team#21967)

* Only skip TPU excluded tests on TPU. (keras-team#22008)

* Improvements to `JaxLayer` and `FlaxLayer` related to RNG handling an…

* Fix typo in contrast adjustment method (keras-team#22012)

* Fix typo and improve docstring formatting (keras-team#22017)

* Implement nansum function in keras.ops (keras-team#21996)

* Fix unreliable Orbax checkpoint detection with custom implementation …

* Unpin as many Python packages versions as possible. (keras-team#22023)

* Allow `CenterCrop` layer to handle dynamic image sizes. (keras-team#22020)

* TPU tests now verify that we can detect TPUs and fails it not. (keras-team#22019)

* Refactor ExtractPatches to handle both 2D and 3D (keras-team#22013)

* Implement  argpartition function for OpenVINO backend (keras-team#22025)

* Implement logaddexp2 function for OpenVINO backend (keras-team#22026)

* Implement nanmin function in keras.ops (keras-team#22040)

* Increase test coverage for IntegerLookup layer (keras-team#22022)

* feat: Add documentation examples for image preprocessing augmentation…

* Fix: activity regularizer not normalized by batch size (keras-team#22021)

* Implement ldexp and select ops for OpenVINO backend (keras-team#22042)

* Fix: convert deque to list before tf.transpose in keras.ops.quantile …

* Fix timedistributed mask validation (keras-team#22039)

* Torch backend: allow explicit device selection and guard DirectML usa…

* Implement nanmax function in keras.ops (keras-team#22043)

* Add bias support for torch's `dot_product_attention`. (keras-team#22045)

* Fix incorrect example in `ops.associative_scan` docstring (keras-team#22051)

* Add Batch Renormalisation (keras-team#22047)

* Implement round and divide_no_nan ops for OpenVINO backend (keras-team#22052)

* Add dynamic shape support for torch backend export (keras-team#22041)

* Implement vstack func for OpenVINO backend (keras-team#22059)

* Implement ptp function for OpenVINO backend (keras-team#22060)

* Implement nanmean function in keras.ops (keras-team#22055)

* Do not allow external links in HDF5 files. (keras-team#22057)

* Fix discretization symbolic one hot (keras-team#22048)

* Implement complete Keras-Orbax checkpoint integration (keras-team#22002)

* Increase test coverage for StringLookup preprocessing layer (keras-team#22056)

* Set mutable to True by default in nnx_metadata (keras-team#22074)

* Adds Asymmetric INT4 Sub-Channel Quantization Support (keras-team#22007)

* Allow passing variables to a function with `@custom_gradient`. (keras-team#22069)

* Disallow TFSMLayer deserialization in safe_mode to prevent external S…

* Remove redundant global seed initialization code. (keras-team#22084)

* Add `Muon` to the list of all optimizer classes. (keras-team#22083)

* Implement tile function for openvino backend (keras-team#22071)

* implement nansum ops for openvino backend (keras-team#22078)

* Remove `testing.uses_cpu()` and re-implement for JAX. (keras-team#22087)

* benchmarks: add RandomRotation tf.data performance benchmark (keras-team#21986)

* Fix arctan2 NaN propagation in OpenVINO backend (keras-team#22064)

* Validate positive height and width in image resize (keras-team#22079)

* Don't skip some JAX linalg tests on JAX. (keras-team#22091)

* Implement nanprod function in keras.ops (keras-team#22089)

* Increase test coverage for TextVectorization layer (keras-team#22066)

* Bump the github-actions group with 2 updates (keras-team#22093)

* fix: pytorch onnx export symbolic test (keras-team#22086)

* Improvements to `*_uses_gpu` and `*_uses_tpu`. (keras-team#22088)

* Implement cross product operation for OpenVINO backend (keras-team#22096)

* Fail fast on invalid convolution output shapes during symbolic execut…

* Fix Normalization broadcasting for scalar and multidim mean and varia…

* Standardize the way tests are skipped based on backend and accelerato…

* Don't call `pythonify_logs` within `get_metrics_result`. (keras-team#22107)

* Fix gaussian_blur padding calculation for even kernel sizes (keras-team#22054)

* Adjust JAX variable initializer jitting criteria. (keras-team#22116)

* Exclude conv transpose tests on TPU. (keras-team#22117)

* Remove incorrect but dead code in `BaseOptimizer.stateless_apply`. (#…

* Implement tensordot operation for OpenVINO backend (keras-team#22098)

* Fix bounding box docstring references (keras-team#22110)

* feat: add depth_to_space and space_to_depth ops (keras-team#22112)

* Fix sparse reshape test with Numpy 2.4. (keras-team#22141)

* Fix vocabulary reload corruption caused by trailing newline handling …

* Add support for dynamic dimensions in `ops.slice.compute_output_spec`…

* Revamp graph validation in `Function.__init__`. (keras-team#22153)

* Fix: draw_bounding_boxes float32 to uint8 conversion (keras-team#22129)

* Implement dstack function across all backends (keras-team#22120)

* Add exp2 operation to OpenVINO backend (keras-team#22131)

* Add trunc operation to OpenVINO backend (keras-team#22134)

* Fix: add missing validation for output padding < strides (keras-team#22130)

* docs: Add guide on resuming training from weight-only checkpoints (#2…

* feat(openvino): upgrade opset to opset15 (keras-team#22159)

* Fix order-dependent float16/bfloat16 promotion in cast_to_common_dtyp…

* Fix TrackedDict constructor to support iterable (key, value) inputs (…

* Implement numpy.gcd using Euclidean algorithm for OpenVINO backend (#…

* [Keras 3] Refactor ExportArchive to be a dispatcher for different exp…

* [Keras 3] Refactor ExportArchive to be a dispatcher for different exp…
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

custom_gradient not working with JAX backend

5 participants