Fix rldm vjp function and update rlssm tutorial#755
Conversation
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
There was a problem hiding this comment.
Pull Request Overview
This PR corrects the vector-Jacobian product function for the RLDM likelihood and updates its documentation and output slicing.
- Changed
vjp_logpsignature to variadic inputs and keyword-onlygz - Refined the returned gradient slice to exclude data and extra fields
- Updated docstring to describe new inputs format
Comments suppressed due to low confidence (3)
src/hssm/likelihoods/rldm.py:221
- The docstring still refers to
inputsas a single list, but the function now uses*inputsvarargs. Update the parameter description to match the variadic signature.
inputs : list
src/hssm/likelihoods/rldm.py:213
- [nitpick] It would be helpful to add unit tests for the updated
vjp_logpbehavior, especially to verify the new argument handling and the correct slicing of the returned gradients.
A function that computes the VJP of the log likelihood for the RLDM model.
src/hssm/likelihoods/rldm.py:216
- Changing to
*inputs, gzmakesgzkeyword-only and may break callers passinggzpositionally. Consider allowinggzas a positional argument or update all call sites and doc to reflect the new signature.
def vjp_logp(*inputs, gz):
| """ | ||
| _, vjp_logp = jax.vjp(logp, *inputs) | ||
| return vjp_logp(gz)[1:] | ||
| return vjp_logp(gz)[1:7] # Exclude the data and the extra fields |
There was a problem hiding this comment.
Using a hardcoded slice 1:7 can become brittle if the number of parameters or extra fields changes. Consider computing the slice dynamically or documenting why exactly six elements are selected.
| return vjp_logp(gz)[1:7] # Exclude the data and the extra fields | |
| n_params = 6 # Number of model parameters: rl_alpha, scaler, a, z, t, theta | |
| return vjp_logp(gz)[1 : 1 + n_params] # Exclude the first field and retain only the parameters |
digicosmos86
left a comment
There was a problem hiding this comment.
LGTM! I am happy to make n_params an optional parameter to make_vjp_logp so the function is more generalizable
Codecov ReportAttention: Patch coverage is
🚀 New features to boost your workflow:
|
This PR fixed the issue in rldm
vjp_logp()that caused VI calls (both via HSSM and PyMC) to break.The rlssm demo tutorial is also updates to showcase a working example with NUTS and VI.