Skip to content

Reorganized the distribution_utils submodule#741

Merged
digicosmos86 merged 13 commits into
mainfrom
738-refactor-separate-onnx-jax-conversion-from-making-its-gradient-function
Jul 2, 2025
Merged

Reorganized the distribution_utils submodule#741
digicosmos86 merged 13 commits into
mainfrom
738-refactor-separate-onnx-jax-conversion-from-making-its-gradient-function

Conversation

@digicosmos86
Copy link
Copy Markdown
Collaborator

As #737 indicates, we need a more general workflow for wrapping JAX likelihoods in pytensor ops. Currently this workflow is entangled with the conversion of ONNX likelihoods to pytensor ops, which should be treated as a special case. This PR addresses this entanglement a little bit. There is still more to be done.

The following was done:

  1. Separated make_vmap_func from make_jax_logp_funcs_from_onnx, so it can be used separately to create vectorized functions, even though its utility outside of the onnx case still remains to be seen.
  2. Elevated make_jax_logp and make_pytensor_logp functions from distribution_utils.onnx submodule, and renamed the submodule to distribution_utils.onnx_utils, so that it only contains utility functions for onnx-related stuff.
  3. Separated make_jax_logp into a separate jax.py file, making it more general for all jax use cases.
  4. Added a return_jit toggle for a few functions, so that the non-jitted onnx wrapper functions can be used in other JAX functions in the RL case.

@AlexanderFengler we might also need to move make_jax_logp_funcs_from_jax_callable from onnx.py to jax.py as well, but since it depends on another function in onnx.py, I didn't do anything to avoid circular imports. I think this function should stand alone for more general JAX use cases.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR reorganizes the distribution utilities submodule to decouple general JAX likelihood workflows from ONNX-specific functionality. The key changes include separating the JAX vectorization functions from the ONNX logic, moving JAX functions to a dedicated file (jax.py), and refining module boundaries and imports.

Reviewed Changes

Copilot reviewed 7 out of 9 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
tests/test_onnx.py Updated imports to reflect the new module boundaries and use of onnx_utils
src/hssm/distribution_utils/onnx_utils/init.py Introduced the onnx_utils submodule to contain onnx-specific utility functions
src/hssm/distribution_utils/onnx/init.py Removed the original onnx init.py to streamline the submodule split
src/hssm/distribution_utils/onnx.py Added general-purpose functions for JAX likelihood functions and decoupled from onnx_utils
src/hssm/distribution_utils/jax.py Reorganized and simplified the JAX vectorization functions with updated docstrings and API
src/hssm/distribution_utils/dist.py Updated to use the new function signatures and added type casts where needed
src/hssm/_types.py Extended type definitions for clarity in function signatures

Comment thread src/hssm/distribution_utils/onnx.py Outdated
Comment thread src/hssm/distribution_utils/jax.py Outdated
Comment thread src/hssm/distribution_utils/jax.py
Copy link
Copy Markdown
Collaborator

@cpaniaguam cpaniaguam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good overall. I left some comments regarding making functions type-stable. Though Python is a dynamically typed language, I think observing the principle is quite advantageous.

Comment thread src/hssm/distribution_utils/dist.py Outdated
Comment thread src/hssm/distribution_utils/onnx.py Outdated
Comment thread src/hssm/distribution_utils/onnx.py
Co-authored-by: Carlos Paniagua <cpaniaguam@gmail.com>
@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 25, 2025

Codecov Report

Attention: Patch coverage is 97.36842% with 4 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/hssm/distribution_utils/onnx.py 96.96% 2 Missing ⚠️
tests/distribution_utils/test_jax.py 97.91% 1 Missing ⚠️
tests/distribution_utils/test_onnx.py 94.73% 1 Missing ⚠️
Files with missing lines Coverage Δ
src/hssm/_types.py 100.00% <100.00%> (ø)
src/hssm/distribution_utils/dist.py 85.88% <100.00%> (-0.06%) ⬇️
src/hssm/distribution_utils/jax.py 100.00% <100.00%> (ø)
src/hssm/distribution_utils/onnx_utils/__init__.py 100.00% <100.00%> (ø)
src/hssm/distribution_utils/onnx_utils/onnx2pt.py 100.00% <ø> (ø)
src/hssm/distribution_utils/onnx_utils/onnx2xla.py 63.41% <ø> (ø)
...ests/distribution_utils/test_distribution_utils.py 100.00% <ø> (ø)
tests/distribution_utils/test_jax.py 97.91% <97.91%> (ø)
tests/distribution_utils/test_onnx.py 98.24% <94.73%> (ø)
src/hssm/distribution_utils/onnx.py 96.96% <96.96%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown
Member

@AlexanderFengler AlexanderFengler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some small comments.

Comment thread pyproject.toml
Comment thread src/hssm/distribution_utils/onnx.py
Comment thread src/hssm/distribution_utils/onnx.py Outdated
Comment thread src/hssm/distribution_utils/onnx.py
Comment thread src/hssm/distribution_utils/onnx.py Outdated
@digicosmos86
Copy link
Copy Markdown
Collaborator Author

@AlexanderFengler can you approve this PR so it can be merged? Feel free to merge it yourself

Copy link
Copy Markdown
Collaborator

@cpaniaguam cpaniaguam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@digicosmos86 digicosmos86 merged commit 40b9561 into main Jul 2, 2025
4 of 5 checks passed
@digicosmos86 digicosmos86 deleted the 738-refactor-separate-onnx-jax-conversion-from-making-its-gradient-function branch July 2, 2025 15:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Refactor: Separate ONNX-JAX conversion from making its gradient function

4 participants