Skip to content

fix: safe negative-RT handling and JAX backend for RDM3#926

Merged
AlexanderFengler merged 3 commits intomainfrom
RDM_improvement
Mar 12, 2026
Merged

fix: safe negative-RT handling and JAX backend for RDM3#926
AlexanderFengler merged 3 commits intomainfrom
RDM_improvement

Conversation

@krishnbera
Copy link
Copy Markdown
Collaborator

  • Replace element-wise jax_check_parameters for rt > 0 with jnp.where + LOGP_LB floor in logp_rdm3. This clamps negative decision times to a small epsilon before passing to the PDF/CDF, preventing NaN in the likelihood and providing a finite gradient floor per trial.
  • Set backend to "jax" and add default HalfNormal priors for A and t in both racing_diffusion_3_config and rdm3_config.

Made-with: Cursor

- Replace element-wise jax_check_parameters for rt > 0 with jnp.where +
  LOGP_LB floor in logp_rdm3. This clamps negative decision times to a
  small epsilon before passing to the PDF/CDF, preventing NaN in the
  likelihood and providing a finite gradient floor per trial.
- Set backend to "jax" and add default HalfNormal priors for A and t in
  both racing_diffusion_3_config and rdm3_config.

Made-with: Cursor
Copilot AI review requested due to automatic review settings March 10, 2026 19:59
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR improves numerical stability and JAX integration for the 3-choice Racing Diffusion Model (RDM3) analytical likelihood, aiming to prevent NaNs when non-decision time exceeds observed RT and to provide sensible defaults for inference.

Changes:

  • Clamp negative decision times inside logp_rdm3 and floor affected trials to LOGP_LB to avoid NaNs in the likelihood.
  • Set the RDM3 analytical likelihood backend to "jax" in the default model configs.
  • Add default HalfNormal priors for A and t in both RDM3-related configs.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File Description
src/hssm/modelconfig/rdm3_config.py Switch analytical backend to JAX and add default priors for A and t.
src/hssm/modelconfig/racing_diffusion_3_config.py Same as above for the racing_diffusion_3 model config.
src/hssm/likelihoods/analytical.py Make RDM3 logp robust to rt - t <= 0 via safe clamping and LOGP_LB flooring.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

This file is dead code — the dynamic config loader in
modelconfig/__init__.py derives module names from the model name
"racing_diffusion_3", so only racing_diffusion_3_config.py is ever
loaded. rdm3_config.py was never imported or called.

Made-with: Cursor
Copy link
Copy Markdown
Collaborator

@cpaniaguam cpaniaguam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm. Just one little suggestion.

Pass logp_rdm3 epsilon through JAX RDM internals to keep clamping behavior consistent and preserve LOGP_LB dtype under JAX. Add a regression test ensuring rt<=t returns LOGP_LB and stays finite for RDM3.

Made-with: Cursor
Copy link
Copy Markdown
Member

@AlexanderFengler AlexanderFengler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm,thanks @krishnbera

@AlexanderFengler AlexanderFengler merged commit b9650f5 into main Mar 12, 2026
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants