Reorganized the distribution_utils submodule #741
Pull Request Overview
This PR reorganizes the distribution utilities submodule to decouple general JAX likelihood workflows from ONNX-specific functionality. The key changes include separating the JAX vectorization functions from the ONNX logic, moving JAX functions to a dedicated file (jax.py), and refining module boundaries and imports.
Reviewed Changes
Copilot reviewed 7 out of 9 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/test_onnx.py | Updated imports to reflect the new module boundaries and use of onnx_utils |
| src/hssm/distribution_utils/onnx_utils/__init__.py | Introduced the onnx_utils submodule to contain ONNX-specific utility functions |
| src/hssm/distribution_utils/onnx/__init__.py | Removed the original onnx __init__.py to streamline the submodule split |
| src/hssm/distribution_utils/onnx.py | Added general-purpose functions for JAX likelihood functions and decoupled from onnx_utils |
| src/hssm/distribution_utils/jax.py | Reorganized and simplified the JAX vectorization functions with updated docstrings and API |
| src/hssm/distribution_utils/dist.py | Updated to use the new function signatures and added type casts where needed |
| src/hssm/_types.py | Extended type definitions for clarity in function signatures |
cpaniaguam
left a comment
This looks good overall. I left some comments regarding making functions type-stable. Though Python is a dynamically typed language, I think observing the principle is quite advantageous.
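As a rough illustration of the type-stability principle mentioned here (not code from this PR; all names below are hypothetical), a function is type-stable when its return type does not depend on the values of its arguments:

```python
from typing import Callable

Logp = Callable[[float], float]


def make_logp_unstable(scale: float, with_neg: bool = False):
    """Type-unstable: returns one callable OR a tuple, depending on a flag."""

    def logp(x: float) -> float:
        return -0.5 * (x / scale) ** 2

    if with_neg:
        return logp, (lambda x: -logp(x))
    return logp


def make_logp_stable(scale: float) -> tuple[Logp, Logp]:
    """Type-stable: always returns the same (logp, neg_logp) pair."""

    def logp(x: float) -> float:
        return -0.5 * (x / scale) ** 2

    def neg_logp(x: float) -> float:
        return -logp(x)

    return logp, neg_logp


# Callers of the stable version never need to branch on the return type.
logp, neg_logp = make_logp_stable(2.0)
```

With the stable variant, a type checker can verify every call site, and callers that only need `logp` simply ignore the second element.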
Co-authored-by: Carlos Paniagua <cpaniaguam@gmail.com>
AlexanderFengler
left a comment
Some small comments.
@AlexanderFengler can you approve this PR so it can be merged? Feel free to merge it yourself.
As #737 indicates, we need a more general workflow for wrapping JAX likelihoods in pytensor ops. Currently this workflow is entangled with the conversion of ONNX likelihoods to pytensor ops, which should be treated as a special case. This PR partially addresses the entanglement; more remains to be done.
The following was done:
- Separated `make_vmap_func` from `make_jax_logp_funcs_from_onnx`, so it can be used on its own to create vectorized functions, although its utility outside of the `onnx` case remains to be seen.
- Moved the `make_jax_logp` and `make_pytensor_logp` functions out of the `distribution_utils.onnx` submodule, and renamed the submodule to `distribution_utils.onnx_utils`, so that it only contains `onnx`-related utility functions.
- Moved `make_jax_logp` into a separate `jax.py` file, making it more general for all JAX use cases.
- Added a `return_jit` toggle to a few functions, so that the non-jitted ONNX wrapper functions can be used inside other JAX functions in the RL case.

@AlexanderFengler we might also need to move `make_jax_logp_funcs_from_jax_callable` from `onnx.py` to `jax.py` as well. Since it depends on another function in `onnx.py`, I left it in place to avoid circular imports. I think this function should stand alone for more general JAX use cases.
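To make the vectorization and `return_jit` ideas concrete, here is a minimal sketch. It is not the HSSM implementation: `make_vmap_sketch`, `toy_logp`, and the `in_axes` argument are assumptions made for illustration, and only `jax.vmap` and `jax.jit` are real JAX calls.

```python
import jax
import jax.numpy as jnp


def make_vmap_sketch(logp, in_axes, return_jit=True):
    """Vectorize a per-trial logp; optionally JIT-compile the result.

    Hypothetical helper: `in_axes` follows jax.vmap semantics
    (0 = map over the leading axis, None = broadcast the argument).
    """
    vmapped = jax.vmap(logp, in_axes=in_axes)
    return jax.jit(vmapped) if return_jit else vmapped


def toy_logp(data, v, a):
    # Hypothetical stand-in likelihood: Gaussian log-density, mean v, sd a.
    return -0.5 * ((data - v) / a) ** 2 - jnp.log(a) - 0.5 * jnp.log(2 * jnp.pi)


data = jnp.array([0.1, 0.5, -0.3])
v = jnp.array([0.0, 0.2, 0.4])  # one value per trial
a = 1.0                         # scalar, broadcast to every trial

# Non-jitted version: convenient to embed in a larger JAX function.
raw = make_vmap_sketch(toy_logp, in_axes=(0, 0, None), return_jit=False)
# Jitted version: convenient to call directly.
jitted = make_vmap_sketch(toy_logp, in_axes=(0, 0, None))

# The non-jitted function composes with other JAX code and is compiled once,
# together with the surrounding computation.
total = jax.jit(lambda d, vv, aa: raw(d, vv, aa).sum())
```

Because both branches return a callable with the same signature, the toggle changes only when compilation happens, not how the result is used.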