Skip to content

Fix rldm vjp function and update rlssm tutorial#755

Merged
krishnbera merged 2 commits intomainfrom
fix-rldm-vjp-function
Jul 9, 2025
Merged

Fix rldm vjp function and update rlssm tutorial#755
krishnbera merged 2 commits intomainfrom
fix-rldm-vjp-function

Conversation

@krishnbera
Copy link
Copy Markdown
Collaborator

This PR fixed the issue in rldm vjp_logp() that caused VI calls (both via HSSM and PyMC) to break.
The rlssm demo tutorial is also updates to showcase a working example with NUTS and VI.

@krishnbera krishnbera requested a review from Copilot July 8, 2025 22:31
@review-notebook-app
Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR corrects the vector-Jacobian product function for the RLDM likelihood and updates its documentation and output slicing.

  • Changed vjp_logp signature to variadic inputs and keyword-only gz
  • Refined the returned gradient slice to exclude data and extra fields
  • Updated docstring to describe new inputs format
Comments suppressed due to low confidence (3)

src/hssm/likelihoods/rldm.py:221

  • The docstring still refers to inputs as a single list, but the function now uses *inputs varargs. Update the parameter description to match the variadic signature.
        inputs : list

src/hssm/likelihoods/rldm.py:213

  • [nitpick] It would be helpful to add unit tests for the updated vjp_logp behavior, especially to verify the new argument handling and the correct slicing of the returned gradients.
        A function that computes the VJP of the log likelihood for the RLDM model.

src/hssm/likelihoods/rldm.py:216

  • Changing to *inputs, gz makes gz keyword-only and may break callers passing gz positionally. Consider allowing gz as a positional argument or update all call sites and doc to reflect the new signature.
    def vjp_logp(*inputs, gz):

"""
_, vjp_logp = jax.vjp(logp, *inputs)
return vjp_logp(gz)[1:]
return vjp_logp(gz)[1:7] # Exclude the data and the extra fields
Copy link

Copilot AI Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a hardcoded slice 1:7 can become brittle if the number of parameters or extra fields changes. Consider computing the slice dynamically or documenting why exactly six elements are selected.

Suggested change
return vjp_logp(gz)[1:7] # Exclude the data and the extra fields
n_params = 6 # Number of model parameters: rl_alpha, scaler, a, z, t, theta
return vjp_logp(gz)[1 : 1 + n_params] # Exclude the first field and retain only the parameters

Copilot uses AI. Check for mistakes.
@krishnbera krishnbera added bug Something isn't working documentation Improvements or additions to documentation labels Jul 8, 2025
Copy link
Copy Markdown
Collaborator

@digicosmos86 digicosmos86 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! I am happy to make n_params an optional parameter to make_vjp_logp so the function is more generalizable

@codecov
Copy link
Copy Markdown

codecov bot commented Jul 8, 2025

Codecov Report

Attention: Patch coverage is 50.00000% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/hssm/likelihoods/rldm.py 50.00% 1 Missing ⚠️
Files with missing lines Coverage Δ
src/hssm/likelihoods/rldm.py 96.96% <50.00%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@krishnbera krishnbera merged commit 0c3fcc2 into main Jul 9, 2025
4 of 5 checks passed
@krishnbera krishnbera deleted the fix-rldm-vjp-function branch July 9, 2025 17:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working documentation Improvements or additions to documentation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants