Skip to content

Implement complete Keras-Orbax checkpoint integration#22002

Merged
hertschuh merged 19 commits intokeras-team:masterfrom
amitsrivastava78:modelload
Jan 28, 2026
Merged

Implement complete Keras-Orbax checkpoint integration#22002
hertschuh merged 19 commits intokeras-team:masterfrom
amitsrivastava78:modelload

Conversation

@amitsrivastava78
Copy link
Copy Markdown
Collaborator

  • Add compile config saving to OrbaxCheckpoint for full model restoration
  • Update saving_api.py to use compile_from_config for proper model loading
  • Replace direct Orbax API tests with Keras saving API tests
  • Remove unused Orbax imports from test file
  • Add comprehensive test coverage for all model state components
  • Ensure cross-backend compatibility (TensorFlow, JAX, PyTorch)

This enables users to save and load complete model state (weights, optimizer, metrics) using keras.saving.load_model() with Orbax checkpoints.

- Add compile config saving to OrbaxCheckpoint for full model restoration
- Update saving_api.py to use compile_from_config for proper model loading
- Replace direct Orbax API tests with Keras saving API tests
- Remove unused Orbax imports from test file
- Add comprehensive test coverage for all model state components
- Ensure cross-backend compatibility (TensorFlow, JAX, PyTorch)

This enables users to save and load complete model state (weights, optimizer,
metrics) using keras.saving.load_model() with Orbax checkpoints.
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @amitsrivastava78, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the integration of Orbax checkpointing within Keras by enabling the saving and loading of a complete model state, including its architecture, compilation settings, optimizer, and metrics. This allows users to fully restore a trained model from an Orbax checkpoint using the standard keras.saving.load_model() function, eliminating the need for manual reconstruction or separate loading steps for different model components. The changes ensure a more robust and user-friendly experience for checkpoint management.

Highlights

  • Complete Model State Saving: The OrbaxCheckpoint callback now saves the full model configuration, compile configuration, and optimizer configuration, enabling comprehensive model restoration.
  • Unified Loading API: The keras.saving.load_model() function has been extended to directly support loading models from Orbax checkpoint directories, streamlining the workflow for users.
  • Enhanced Test Coverage: Existing tests for Orbax checkpoint loading have been refactored to use the Keras saving API, and new tests have been added to ensure complete model state (weights, optimizer, metrics, compile config) is correctly saved and loaded across different backends.
  • Cross-Backend Compatibility: The loading mechanism includes logic to convert JAX arrays to NumPy arrays for non-JAX backends, ensuring seamless operation regardless of the underlying Keras backend.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request does a great job of integrating Keras with Orbax for checkpointing, enabling full model restoration via keras.saving.load_model. The changes to OrbaxCheckpoint to save model and compile configurations are correct. The new loading logic in saving_api.py is robust, handling detection of Orbax checkpoints and reconstruction of the model state. The test suite has been significantly improved by replacing direct Orbax API calls with tests against the public Keras saving API, and by adding comprehensive tests for all components of the model state.

I have a couple of suggestions to improve the code further. One is a high-severity suggestion to ensure remote path support in the new loading logic, and the other is a medium-severity suggestion to refactor duplicated code in the tests for better maintainability. Overall, this is a solid contribution.

Comment on lines +360 to +367
if os.path.exists(filepath):
subdirs = os.listdir(filepath)
for d in subdirs:
if os.path.isdir(os.path.join(filepath, d)):
try:
available_steps.append(int(d))
except ValueError:
pass
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For consistency with the rest of Keras and to support remote file systems (like GCS), you should use the wrappers from keras.src.utils.file_utils (e.g., file_utils.exists, file_utils.listdir, file_utils.isdir) instead of the os module. The detection logic in load_model already uses file_utils, and this function should too to ensure it works correctly with remote paths.

Suggested change
if os.path.exists(filepath):
subdirs = os.listdir(filepath)
for d in subdirs:
if os.path.isdir(os.path.join(filepath, d)):
try:
available_steps.append(int(d))
except ValueError:
pass
if file_utils.exists(filepath):
subdirs = file_utils.listdir(filepath)
for d in subdirs:
if file_utils.isdir(file_utils.join(filepath, d)):
try:
available_steps.append(int(d))
except ValueError:
pass

- Accept upstream changes for various backend and layer updates
- Re-apply Orbax checkpoint modifications
- Add comprehensive model state restoration test with JAX compatibility
- Ensure cross-backend compatibility for checkpoint loading
@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented Jan 13, 2026

Codecov Report

❌ Patch coverage is 90.62500% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.80%. Comparing base (fdc5543) to head (62b8cb6).
⚠️ Report is 254 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/callbacks/orbax_checkpoint.py 80.00% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22002      +/-   ##
==========================================
+ Coverage   82.69%   82.80%   +0.10%     
==========================================
  Files         589      592       +3     
  Lines       61632    62543     +911     
  Branches     9650     9797     +147     
==========================================
+ Hits        50967    51787     +820     
- Misses       8165     8220      +55     
- Partials     2500     2536      +36     
Flag Coverage Δ
keras 82.62% <90.62%> (+0.10%) ⬆️
keras-jax 62.39% <90.62%> (+0.92%) ⬆️
keras-numpy 56.45% <12.50%> (-0.27%) ⬇️
keras-openvino 37.61% <12.50%> (+0.13%) ⬆️
keras-tensorflow 63.66% <90.62%> (+0.02%) ⬆️
keras-torch 62.42% <90.62%> (+0.03%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

- Fix metrics initialization during model loading to ensure all metric variables are created before state restoration
- Add evaluation step during loading to initialize missing metrics like mean_absolute_error
- Update tests to verify exact metric value matching instead of just structure validation
- Fix async save exact weight matching using final checkpoint strategy
- Simplify JAX backend code by removing unnecessary numpy conversions
- Ensure compatibility across JAX, PyTorch, and TensorFlow backends
- All tests now pass with exact metrics comparison for proper checkpoint fidelity
- Fix line length compliance to stay within 80 columns
- Implement proper exact weight matching for async saves by forcing
  final sync checkpoint in on_train_end with max_to_keep=1
- Remove inappropriate epoch-based sync forcing that hurt performance
- Simplify weight getter API: remove redundant get_final_saved_weights
  and get_last_saved_weights_exact functions
- Fix line length compliance to stay under 80 columns
- All tests pass with exact weight matching for both sync and async saves
- Remove redundant numpy import (already imported at top)
- Use backend-aware state handling: JAX preserves native arrays,
  other backends convert to numpy for exact matching
- Simplify state copying logic using tree.map_structure
- Maintain 80-column line length compliance
- Preserve performance for JAX while ensuring exact matching works
- All tests pass across backends
- Consolidated 7 redundant tests into 3 optimized comprehensive tests
- Reduced test file size by 34% (1174 -> 817 lines) while maintaining coverage
- Fixed nested dictionary comparison for JAX optimizer state variables
- Enhanced cross-backend compatibility with graceful error handling
- Ensured all lines comply with 80-column limit for better readability
- All tests pass across JAX, TensorFlow, and PyTorch backends

Tests consolidated:
- test_checkpoint_loading_via_saving_api: Basic loading + weights-only error handling
- test_checkpoint_loading_full_state_via_saving_api: Optimizer/metrics state loading
- test_comprehensive_model_state_restoration: Advanced state restoration with custom layers
- test_exact_weight_matching_with_sync_save: Sync vs async weight matching verification
Performance improvements:
- Use os.scandir() for 2-3x faster step detection vs file_utils calls
- Consolidate imports to reduce repeated import overhead
- Streamline state tree preparation with dictionary comprehension

Code simplification:
- Simplified model building and compilation logic
- Reduced nested conditions for better readability
- Optimized metrics initialization with cleaner logic
- Enhanced error handling without losing functionality

Results:
- 30-line reduction (521 → 493 lines) - 6% file size reduction
- Improved performance with faster directory operations
- Maintained cross-backend compatibility (JAX, TensorFlow, PyTorch)
- All lines comply with 80-column limit
- All tests passing with optimized implementation
Problem:
- test_save_on_background_async failing with 'Too many open files' error
- Manual orbax checkpoint detection in load_model() caused fd leaks
- Redundant code duplicated checkpoint detection logic

Solution:
- Replace manual detection with imported is_orbax_checkpoint() utility
- Eliminate file_utils.listdir() calls that leaked file descriptors
- Use existing optimized checkpoint detection logic

Results:
- Fixed OSError: [Errno 24] Too many open files in async tests
- Removed code duplication and improved maintainability
- All orbax checkpoint tests now passing consistently
- Better performance with optimized checkpoint detection
Problem:
- 'Too many open files' errors in async/sync checkpoint tests
- Orbax checkpointer file descriptors not properly cleaned up
- Tests failing in CI environment due to accumulated open file handles

Solution:
- Added __del__ method to OrbaxCheckpoint for automatic cleanup
- Added try/finally blocks in tests for explicit cleanup
- Ensures checkpointer.close() is called in all scenarios

Root Cause Analysis:
- Orbax checkpointer maintains file descriptors for checkpoint operations
- Without proper cleanup via checkpointer.close(), these accumulate
- In test environments with multiple runs, this hits system limits

Fixes:
1. Automatic cleanup: __del__ ensures cleanup during garbage collection
2. Explicit cleanup: try/finally blocks in async/sync tests
3. Defense in depth: Both normal and abnormal termination scenarios covered

Results:
- Resolves OSError: [Errno 24] Too many open files
- All async/sync checkpoint tests now pass consistently
- Proper resource management prevents file descriptor accumulation
Test file descriptor leak fixes in OrbaxCheckpoint:
- Enhanced resource management with retry logic for RESOURCE_EXHAUSTED errors
- Added garbage collection and explicit cleanup in sync save operations
- Improved test cleanup patterns to prevent file descriptor accumulation
Copy link
Copy Markdown
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole asset saving / loading part of it is missing.

- Remove redundant numpy conversion in OrbaxCheckpoint._save_checkpoint since _get_state_tree already handles format conversion
- Add future tracking for async saves to avoid memory issues with buffer donation
- Remove unused _training_ending flag and defensive self.model check
- Simplify redundant conditional logic in on_train_end fallback
- Remove manual model building in saving_api.py since build_config handles it
- Remove backwards compatibility optimizer_config fallback
- Remove unnecessary cross-backend numpy conversion assuming same-backend save/load
- Remove hacky metrics initialization via dummy evaluation
- Clean up and optimize checkpoint loading flow
@amitsrivastava78
Copy link
Copy Markdown
Collaborator Author

The whole asset saving / loading part of it is missing.

yes will raise a separate PR for that

- Add save_decision_policy=FixedIntervalPolicy(1) to fix race condition with rapid async saves
- Remove unnecessary on_train_end workaround (no longer needed with save_decision_policy)
- Remove get_last_saved_weights() method (tests use model.get_weights() directly)
- All tests pass with the cleaner implementation
- Consolidate checkpoint step detection to use find_latest_orbax_checkpoint() utility
- Remove duplicate os.scandir logic in _load_model_from_orbax_checkpoint
- Simplify state_tree filtering to only include keys that exist in composite_state
- More maintainable and DRY code
The overwrite parameter was never used (always defaulted to False) and is
unnecessary with our preservation policy and save_decision_policy handling
checkpoint management automatically.
The force_sync parameter was never used (always defaulted to None) and added
unnecessary complexity. Sync vs async behavior is already controlled by the
save_on_background constructor parameter, making this override unnecessary.
Orbax's checkpointer.close() already waits for pending async operations to
complete before closing (per its API contract). The explicit wait_until_finished()
call was redundant and added in later commits unnecessarily.

Reverting to the simpler original pattern where close() handles the wait.
Copy link
Copy Markdown
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized that the config saving and loading should reuse some utilities that already exist and take care of some non-trivial stuff.

- Remove redundant wait_until_finished() calls from tests since on_train_end already waits
- Remove unnecessary try/finally cleanup blocks in tests
- Remove unused _last_checkpoint_path variable in callback
- Use saving_lib._serialize_model_as_json for model config serialization (consistency and proper object sharing)
- Use saving_lib._model_from_config for model loading (handles shared objects and compile_config)
- Replace np.testing.assert_array_almost_equal with self.assertAllClose for better cross-backend compatibility
- Consolidate tests using parameterized tests (batch/epoch freq, sync/async, save_best_only modes)
- Remove redundant test_directory_creation test
- All tests pass on JAX, PyTorch, and TensorFlow backends
@amitsrivastava78
Copy link
Copy Markdown
Collaborator Author

amitsrivastava78 commented Jan 27, 2026

I realized that the config saving and loading should reuse some utilities that already exist and take care of some non-trivial stuff.

yes, the implementation now is perfectly aligned with the recommendation to reuse existing utilities

…test structure

- Move imports to top of file (saving, utils, register_keras_serializable)
- Remove import keras statements, use specific imports instead
- Replace keras.utils and keras.backend with direct module usage
- Simplify optimizer variable comparison using direct iteration
- Remove complex nested state tree comparison, use simple weight/optimizer variable loops
- Parametrize test_comprehensive_model_state_restoration with sync/async modes
- Remove redundant test_exact_weight_matching_with_sync_save (covered by parametrized test)
- Remove redundant test_checkpoint_loading_full_state_via_saving_api (covered by comprehensive test)
- Simplify test logic leveraging assertAllClose's built-in numpy conversion
Copy link
Copy Markdown
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Jan 28, 2026
@hertschuh hertschuh merged commit 0737462 into keras-team:master Jan 28, 2026
11 of 12 checks passed
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase labels Jan 28, 2026
jerryxyj added a commit to jerryxyj/keras that referenced this pull request Feb 14, 2026
* Implement logaddexp2 function in keras.ops (keras-team#21691)

* [Keras 3 OpenVINO Backend]: Support numpy.sort (keras-team#21687)

* [Keras 3 OpenVINO Backend]: Support numpy.median operation (keras-team#21667)

* Fix deadlock in `CallbackList`. (keras-team#21701)

* [OpenVINO backend] solve randomuniform issue (keras-team#21670)

* Bug fixes with variable handling in `LossScaleOptimizer`. (keras-team#21706)

* Do not use backend ops in `ProgBar`. (keras-team#21709)

* Fix the Doc of the combination relation in func `keras.layers.Normali…

* Remove reliance on `__jax_array__` to unwrap variables. (keras-team#21719)

* Bump the github-actions group with 6 updates (keras-team#21705)

* Add linspace and logspace implementations in OpenVINO NumPy backend (…

* Add jvp op (keras-team#21720)

* Add unfold op (keras-team#21685)

* Add the description that `0` should not in the arg `axes` in `keras.l…

* Add daily Python 3.13 CPU-only tests to nightly workflow (keras-team#21566)

* Fix histogram op for symbolic inputs (keras-team#21729)

* Relax tolerance for svd test (keras-team#21731)

* Use jax.enable_x64 in place of jax.experimental.disable_x64 (keras-team#21734)

* Refactor variable serialization. (keras-team#21713)

* Ensure keras.ops.eye behavior is consistent across backends. (keras-team#21738)

* Add `eye` support for OpenVINO backend (keras-team#21739)

* Update Torch and Tensorflow versions in cuda requirements files. (keras-team#21…

* Implement isreal function in keras.ops (keras-team#21740)

* Remove the unused jax `enable_x64`. (keras-team#21737)

* Correct implementation for several OpenVINO operations (keras-team#21746)

* Sets `is_gptq_calibrated` flag when deserializing GPTQ models (keras-team#21748)

* Correct implementation for several OpenVINO operations (keras-team#21752)

* Fix the Bug in func `preprocess_input` when `x` in 3D and `data_forma…

* Update Torch to 2.9.0 on GPU. (keras-team#21756)

* `StringLookup` & `IntegerLookup` now save vocabulary loaded from file…

* Implement trapezoid function in keras.ops (keras-team#21757)

* Upstream `ReversibleEmbedding` from KerasHub. (keras-team#21753)

* Raise exception on batch_size mismatch for stateful RNNs (keras-team#21742)

* Propose a method for handling datasets which doesn't explicitly requi…

* Use `filter="data"` option of `TarFile.extractall`. (keras-team#21760)

* Add Distillation API to Keras (keras-team#21572)

* removes unnecessary try-catch blocks and guard conditions (keras-team#21767)

* cleanup distillation  loss names (keras-team#21766)

* Document that `set_backend` requires re-importing keras. (keras-team#21764)

* Fix discretization discrepancy (keras-team#21769)

* fix sas metrics in jax `fit` (keras-team#21765)

* Support for extracting volume patches (keras-team#21759)

* Fix negative index handling in MultiHeadAttention attention_axes (keras-team#21…

* Make confusion metrics compilable. (keras-team#21775)

* Suport keras.op.view() to view the same data bitwise at a new dtype  …

* Fix: `keras.ops.quantile` works with tf graph execution (keras-team#21782)

* Fix typo in Distiller docstring

* Add warning to `set_backend` and more detailed example. (keras-team#21787)

* Don't fail `Variable.__repr__` if the value cannot be retrieved. (keras-team#21…

* Update Keras backend installation instructions

* Fix: Support 'jpg' format in keras.utils.save_img() (keras-team#21683)

* Fix tf dataset detection logic. (keras-team#21794)

* update test after jax.config.jax_vjp3 is enabled (keras-team#21776)

* Add keras.ops.array_split for Tensor Parallelism Support (keras-team#21697)

* Adding get_device_count function to the distribution_lib (keras-team#21791)

* Fix: use raw string for CALIBRATION_TEXT (keras-team#21790)

* Add backend compatibility table to documentation (keras-team#21733)

* More OpenVINO Operations (keras-team#21774)

* Support scalar view for tf backend. (keras-team#21802)

* Address bug with convolution using Tensorflow, Numpy, Jax backends (#…

* Fix bug with correlate for tensorflow (keras-team#21778)

* Pass optional field in a few places to fix None input error. (keras-team#21818)

* Fix(backend/torch): Resolve MPS broadcast crash in binary_crossentrop…

* Fix broken example indentation in Keras io (keras-team#21807)

* Add missing `convert_to_tensor` to `take_along_axis` on JAX. (keras-team#21825)

* Added  numpy.digitize support for OPENVINO backend  (keras-team#21824)

* Bump the github-actions group with 4 updates (keras-team#21809)

* Fix typo in CONTRIBUTING.md (keras-team#21812)

* Fix `Progbar.update` when receiving list, np arrays, and tensors. (#2…

* Fix CosineDecay documentation to clarify alpha is a multiplier (keras-team#21827)

* Fix noise_shape validation in keras.layers.Dropout (keras-team#21819)

* Fix typos in some files (keras-team#21830)

* Fix failing sklearn tests following release of pytest 9.0. (keras-team#21843)

* Implement empty_like function in keras.ops (keras-team#21840)

* Run tests on TPU (keras-team#21425)

* Fix typo in variable name 'embeding' to 'embedding' (keras-team#21845)

* Fix name_scope_stack AttributeError and IndexError in __exit__ (keras-team#21834)

* Update keras3 Softmax mask handling to be more numerically robust. (#…

* Support jax2tf in JaxLayer for tf backend (keras-team#21842)

* Fix assigning a value to a variable within an autocast scope. (keras-team#21864)

* Add note about label noise in CIFAR-10 dataset documentation (keras-team#21855)

* Allow None inputs in `Layer.build`. (keras-team#21866)

* `standardize_shape` normalizes the dimensions and tuple. (keras-team#21867)

* Improve error message when layer/model input validation fails. (keras-team#21869)

* Add verbose logging when ModelCheckpoint callback is done saving ... …

* [OpenVINO backend] Remove deprecated openvino.runtime import (keras-team#21826)

* Fix Torch output_padding constraint for ConvTranspose layers (keras-team#21852)

* Support PyDataset in Normalization layer `adapt` methods (keras-team#21817)

* Fix test failures when nnx is enabled (keras-team#21875)

* Implement ldexp function in keras.ops (keras-team#21863)

* Added OrbaxCheckpoint for keras 3.0 for Data centric saving and resto…

* Add raise_error option to TerminateOnNaN for immediate termination on…

* Fix NNX tests (keras-team#21884)

* `keras.utils.set_random_seed` clear the global `SeedGenerator`. (keras-team#21874)

* fix tpu test (keras-team#21893)

* Model Export to liteRT (keras-team#21674)

* Fix: torch layer losses keyword arguments in rematscope (keras-team#21865)

* Add label to trigger TPU tests manually. (keras-team#21897)

* Support tpu tests allowing tpu precision for matmul (keras-team#21887)

* remove log (keras-team#21901)

* Introduces layer filtering for quantization and fixes GPTQ dependency…

* Replace `np.reshape(x, newshape=y)` with `np.reshape(x, y)`. (keras-team#21899)

* Modified Dense layer documentation for use_bias with batch normalizat…

* [OpenVINO Backend] Support np.diag (keras-team#20967)

* Modify Muon optimizer (keras-team#21885)

* Disables implicit GPTQ quantization using dtype_policy setter (keras-team#21895)

* Dense: validate units argument (keras-team#21902)

* Pin `ai-edge-litert` version to fix CI (keras-team#21912)

* Increase JAX GPU tests timeout to 2 hours (keras-team#21915)

* Fix TPU tests - for splash attention (keras-team#21891)

* Support various filtering functions in OpenVINO (keras-team#21836)

* OpenVINO NN Module Functions (keras-team#21803)

* fix XLA dynamic shape output of ops.diag (keras-team#21906)

* Fix: Remove redundant epsilon in loss mask weight calculation (keras-team#21908)

* Implement vander function in keras.ops (keras-team#21882)

* Fix Muon optimizer with TensorFlow backend. (keras-team#21924)

* OpenVino `device_scope` and data adapters tests (keras-team#21922)

* Fix fake quant gradient output shape and use `jax.grad` for tests. (#…

* Introduces QuantizationConfig for fine-grained quantization control (…

* Extended fix OOM Issue keras-team#21634 on Keras side (keras-team#21755)

* Fix ops.tile shape inference issue on TensorFlow backend (keras-team#21860)

* Add adaptive pooling (1D, 2D, 3D) support across JAX, NumPy, TensorFl…

* More OpenVINO Numpy Operations (keras-team#21925)

* Adds Serialization Support for QuantizationConfig based quantized mod…

* Refactors AbsMaxQuantizer to accept axis in __call__ (keras-team#21931)

* Speed up unit tests on JAX and TensorFlow. (keras-team#21933)

* update dev version number (keras-team#21921)

* Always use `run_tpu_tests` label to run the TPU tests. (keras-team#21900)

* Revert "Always use `run_tpu_tests` label to run the TPU tests. (keras-team#2190…

* Forward-fix for JAX API changes (keras-team#21938)

* Remove nightly tests with Python 3.13. (keras-team#21943)

* Do no always make batch size dynamic during export. (keras-team#21944)

* Fix `numpy.mean` with dynamic shape on OpenVino. (keras-team#21947)

* Remove NumPy warning with NumPy >= 2. (keras-team#21949)

* Always use `run_tpu_tests` label to run the TPU tests. (keras-team#21950)

* [OpenVINO backend] Support np.vander, np.trapezoid, np.corrcoef, np.c…

* Fixed a bug in _keras_mask (keras-team#21946)

* Fix handling of symbolic Tensor in RNN (keras-team#21945)

* Add example for arctanh (keras-team#21951)

* Fix DoS via malicious HDF5 dataset metadata in KerasFileEditor (keras-team#21880)

* Implement nextafter function in keras.ops (keras-team#21960)

* fix image.extract_patches strides handling (keras-team#21959)

* [OpenVINO backend] Support numpy.flip (keras-team#21963)

* Bump the github-actions group with 4 updates (keras-team#21968)

* Fix CUDNN flash attention for JAX > 0.6.2. (keras-team#21970)

* Skip `PyDataset` tests on TPU. (keras-team#21964)

* Add missing `name` to `SeedGenerator.get_config`. (keras-team#21975)

* Use `subprocess.run` in `pip_build.py` to escape wheel path. (keras-team#21976)

* Update dependencies and `dependabot.yml`. (keras-team#21974)

* Use `kokoro:force-run` label for TPU tests too. (keras-team#21956)

* Add simple example for keras.layers.Resizing (keras-team#21966)

* [OpenVINO backend] Support numpy.diagonal (keras-team#21965)

* Bump actions/checkout from 5.0.1 to 6.0.1 in the github-actions group…

* Fix ReversibleEmbedding mask error when using reverse=True (keras-team#21961)

* Update feature_space.py (keras-team#21935)

* Clarify Tracker docstring wording (keras-team#21985)

* Remove semi-colon after email in SECURITY.md (keras-team#21993)

* Implement cbrt function for OpenVINO backend (keras-team#21987)

* Fix config keys for chain depth and num chains (keras-team#21979)

* Implement hypot and trace function for OpenVINO backend (keras-team#21991)

* Implement ptp function in keras.ops (keras-team#21990)

* Orbax Loading and Sharding Support feature (keras-team#21903)

* Add usage examples to loss docstrings (keras-team#21989)

* Unify extract_patches to support both 2D and 3D patches (keras-team#21980)

* Fix ndim to support tf.RaggedTensor by using shape.rank (keras-team#21999)

* Implement size and swapaxes function for OpenVINO backend.  (keras-team#21995)

* Implement kron function for OpenVINO backend (keras-team#22000)

* Adds support for AWQ (keras-team#21992)

* Trigger TPU tests on kokoro label removal rather than addition. (keras-team#22001)

* Document complex dtype limitation in ops.correlate (keras-team#21984)

* [OpenVINO backend] Fix and enable numpy.rot90 (keras-team#21967)

* Only skip TPU excluded tests on TPU. (keras-team#22008)

* Improvements to `JaxLayer` and `FlaxLayer` related to RNG handling an…

* Fix typo in contrast adjustment method (keras-team#22012)

* Fix typo and improve docstring formatting (keras-team#22017)

* Implement nansum function in keras.ops (keras-team#21996)

* Fix unreliable Orbax checkpoint detection with custom implementation …

* Unpin as many Python packages versions as possible. (keras-team#22023)

* Allow `CenterCrop` layer to handle dynamic image sizes. (keras-team#22020)

* TPU tests now verify that we can detect TPUs and fails it not. (keras-team#22019)

* Refactor ExtractPatches to handle both 2D and 3D (keras-team#22013)

* Implement  argpartition function for OpenVINO backend (keras-team#22025)

* Implement logaddexp2 function for OpenVINO backend (keras-team#22026)

* Implement nanmin function in keras.ops (keras-team#22040)

* Increase test coverage for IntegerLookup layer (keras-team#22022)

* feat: Add documentation examples for image preprocessing augmentation…

* Fix: activity regularizer not normalized by batch size (keras-team#22021)

* Implement ldexp and select ops for OpenVINO backend (keras-team#22042)

* Fix: convert deque to list before tf.transpose in keras.ops.quantile …

* Fix timedistributed mask validation (keras-team#22039)

* Torch backend: allow explicit device selection and guard DirectML usa…

* Implement nanmax function in keras.ops (keras-team#22043)

* Add bias support for torch's `dot_product_attention`. (keras-team#22045)

* Fix incorrect example in `ops.associative_scan` docstring (keras-team#22051)

* Add Batch Renormalisation (keras-team#22047)

* Implement round and divide_no_nan ops for OpenVINO backend (keras-team#22052)

* Add dynamic shape support for torch backend export (keras-team#22041)

* Implement vstack func for OpenVINO backend (keras-team#22059)

* Implement ptp function for OpenVINO backend (keras-team#22060)

* Implement nanmean function in keras.ops (keras-team#22055)

* Do not allow external links in HDF5 files. (keras-team#22057)

* Fix discretization symbolic one hot (keras-team#22048)

* Implement complete Keras-Orbax checkpoint integration (keras-team#22002)

* Increase test coverage for StringLookup preprocessing layer (keras-team#22056)

* Set mutable to True by default in nnx_metadata (keras-team#22074)

* Adds Asymmetric INT4 Sub-Channel Quantization Support (keras-team#22007)

* Allow passing variables to a function with `@custom_gradient`. (keras-team#22069)

* Disallow TFSMLayer deserialization in safe_mode to prevent external S…

* Remove redundant global seed initialization code. (keras-team#22084)

* Add `Muon` to the list of all optimizer classes. (keras-team#22083)

* Implement tile function for openvino backend (keras-team#22071)

* implement nansum ops for openvino backend (keras-team#22078)

* Remove `testing.uses_cpu()` and re-implement for JAX. (keras-team#22087)

* benchmarks: add RandomRotation tf.data performance benchmark (keras-team#21986)

* Fix arctan2 NaN propagation in OpenVINO backend (keras-team#22064)

* Validate positive height and width in image resize (keras-team#22079)

* Don't skip some JAX linalg tests on JAX. (keras-team#22091)

* Implement nanprod function in keras.ops (keras-team#22089)

* Increase test coverage for TextVectorization layer (keras-team#22066)

* Bump the github-actions group with 2 updates (keras-team#22093)

* fix: pytorch onnx export symbolic test (keras-team#22086)

* Improvements to `*_uses_gpu` and `*_uses_tpu`. (keras-team#22088)

* Implement cross product operation for OpenVINO backend (keras-team#22096)

* Fail fast on invalid convolution output shapes during symbolic execut…

* Fix Normalization broadcasting for scalar and multidim mean and varia…

* Standardize the way tests are skipped based on backend and accelerato…

* Don't call `pythonify_logs` within `get_metrics_result`. (keras-team#22107)

* Fix gaussian_blur padding calculation for even kernel sizes (keras-team#22054)

* Adjust JAX variable initializer jitting criteria. (keras-team#22116)

* Exclude conv transpose tests on TPU. (keras-team#22117)

* Remove incorrect but dead code in `BaseOptimizer.stateless_apply`. (#…

* Implement tensordot operation for OpenVINO backend (keras-team#22098)

* Fix bounding box docstring references (keras-team#22110)

* feat: add depth_to_space and space_to_depth ops (keras-team#22112)

* Fix sparse reshape test with Numpy 2.4. (keras-team#22141)

* Fix vocabulary reload corruption caused by trailing newline handling …

* Add support for dynamic dimensions in `ops.slice.compute_output_spec`…

* Revamp graph validation in `Function.__init__`. (keras-team#22153)

* Fix: draw_bounding_boxes float32 to uint8 conversion (keras-team#22129)

* Implement dstack function across all backends (keras-team#22120)

* Add exp2 operation to OpenVINO backend (keras-team#22131)

* Add trunc operation to OpenVINO backend (keras-team#22134)

* Fix: add missing validation for output padding < strides (keras-team#22130)

* docs: Add guide on resuming training from weight-only checkpoints (#2…

* feat(openvino): upgrade opset to opset15 (keras-team#22159)

* Fix order-dependent float16/bfloat16 promotion in cast_to_common_dtyp…

* Fix TrackedDict constructor to support iterable (key, value) inputs (…

* Implement numpy.gcd using Euclidean algorithm for OpenVINO backend (#…

* [Keras 3] Refactor ExportArchive to be a dispatcher for different exp…

* [Keras 3] Refactor ExportArchive to be a dispatcher for different exp…
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants