fix: safe negative-RT handling and JAX backend for RDM3#926
Merged
AlexanderFengler merged 3 commits intomainfrom Mar 12, 2026
Merged
fix: safe negative-RT handling and JAX backend for RDM3#926AlexanderFengler merged 3 commits intomainfrom
AlexanderFengler merged 3 commits intomainfrom
Conversation
- Replace element-wise jax_check_parameters for rt > 0 with jnp.where + LOGP_LB floor in logp_rdm3. This clamps negative decision times to a small epsilon before passing to the PDF/CDF, preventing NaN in the likelihood and providing a finite gradient floor per trial. - Set backend to "jax" and add default HalfNormal priors for A and t in both racing_diffusion_3_config and rdm3_config. Made-with: Cursor
Contributor
There was a problem hiding this comment.
Pull request overview
This PR improves numerical stability and JAX integration for the 3-choice Racing Diffusion Model (RDM3) analytical likelihood, aiming to prevent NaNs when non-decision time exceeds observed RT and to provide sensible defaults for inference.
Changes:
- Clamp negative decision times inside
logp_rdm3and floor affected trials toLOGP_LBto avoid NaNs in the likelihood. - Set the RDM3 analytical likelihood backend to
"jax"in the default model configs. - Add default HalfNormal priors for
Aandtin both RDM3-related configs.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
src/hssm/modelconfig/rdm3_config.py |
Switch analytical backend to JAX and add default priors for A and t. |
src/hssm/modelconfig/racing_diffusion_3_config.py |
Same as above for the racing_diffusion_3 model config. |
src/hssm/likelihoods/analytical.py |
Make RDM3 logp robust to rt - t <= 0 via safe clamping and LOGP_LB flooring. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
This file is dead code — the dynamic config loader in modelconfig/__init__.py derives module names from the model name "racing_diffusion_3", so only racing_diffusion_3_config.py is ever loaded. rdm3_config.py was never imported or called. Made-with: Cursor
cpaniaguam
reviewed
Mar 10, 2026
Collaborator
cpaniaguam
left a comment
There was a problem hiding this comment.
lgtm. Just one little suggestion.
Pass logp_rdm3 epsilon through JAX RDM internals to keep clamping behavior consistent and preserve LOGP_LB dtype under JAX. Add a regression test ensuring rt<=t returns LOGP_LB and stays finite for RDM3. Made-with: Cursor
AlexanderFengler
approved these changes
Mar 12, 2026
Member
AlexanderFengler
left a comment
There was a problem hiding this comment.
lgtm,thanks @krishnbera
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Made-with: Cursor