diff --git a/src/hssm/likelihoods/analytical.py b/src/hssm/likelihoods/analytical.py index f4010303..49e510af 100644 --- a/src/hssm/likelihoods/analytical.py +++ b/src/hssm/likelihoods/analytical.py @@ -464,11 +464,11 @@ def _jax_rdm_tcdf(t, b, v, A, eps_t: float = 1e-8, eps_v: float = 1e-6): return F -def _jax_rdm3_ll(t, ch, A, b, v0, v1, v2): +def _jax_rdm3_ll(t, ch, A, b, v0, v1, v2, eps_t: float = 1e-8): """Log-likelihood for 3-choice RDM at JAX level.""" v_vector = jnp.stack(jnp.broadcast_arrays(v0, v1, v2)) - all_pdfs = vmap(lambda v: _jax_rdm_tpdf(t, b, v, A))(v_vector) - all_cdfs = vmap(lambda v: _jax_rdm_tcdf(t, b, v, A))(v_vector) + all_pdfs = vmap(lambda v: _jax_rdm_tpdf(t, b, v, A, eps_t=eps_t))(v_vector) + all_cdfs = vmap(lambda v: _jax_rdm_tcdf(t, b, v, A, eps_t=eps_t))(v_vector) idx = jnp.arange(ch.shape[0]) pdf_winner = all_pdfs[ch, idx] @@ -489,6 +489,7 @@ def logp_rdm3( v1: float, v2: float, t: float, + epsilon: float = 1e-10, ) -> np.ndarray: """Compute the log-likelihood of the RDM model with 3 drift rates.""" data_reshaped = jnp.reshape(data, (-1, 2)).astype(pytensor.config.floatX) @@ -497,9 +498,13 @@ def logp_rdm3( response = data_reshaped[:, 1] response_int = response.astype(jnp.int_) - logp = _jax_rdm3_ll(rt, response_int, A, b, v0, v1, v2).squeeze() + is_negative_rt = rt <= 0.0 + rt_safe = jnp.where(is_negative_rt, jnp.asarray(epsilon, dtype=rt.dtype), rt) + + logp = _jax_rdm3_ll(rt_safe, response_int, A, b, v0, v1, v2, eps_t=epsilon) + logp = jnp.where(is_negative_rt, jnp.asarray(LOGP_LB, dtype=logp.dtype), logp) + logp = logp.squeeze() logp = jax_check_parameters(logp, b > A, msg="b > A") - logp = jax_check_parameters(logp, rt > 0.0, msg="rt > 0 after non-decision time") return logp diff --git a/src/hssm/modelconfig/racing_diffusion_3_config.py b/src/hssm/modelconfig/racing_diffusion_3_config.py index 3d6aac5e..0167fd11 100644 --- a/src/hssm/modelconfig/racing_diffusion_3_config.py +++ b/src/hssm/modelconfig/racing_diffusion_3_config.py @@ -25,8 +25,17 @@ def get_racing_diffusion_3_config() -> DefaultConfig: "likelihoods": { "analytical": { "loglik": logp_rdm3, - "backend": None, - "default_priors": {}, + "backend": "jax", + "default_priors": { + "A": { + "name": "HalfNormal", + "sigma": 0.5, + }, + "t": { + "name": "HalfNormal", + "sigma": 0.3, + }, + }, "bounds": rdm3_bounds, "extra_fields": None, } diff --git a/src/hssm/modelconfig/rdm3_config.py b/src/hssm/modelconfig/rdm3_config.py deleted file mode 100644 index 46aafb5e..00000000 --- a/src/hssm/modelconfig/rdm3_config.py +++ /dev/null @@ -1,34 +0,0 @@ -from .._types import DefaultConfig # noqa: D100 -from ..likelihoods.analytical import ( - logp_rdm3, - rdm3_bounds, - rdm3_params, -) - - -def get_rdm3_config() -> DefaultConfig: - """ - Get the default configuration for the Racing Diffusion Model 3 (RDM3). - - Returns - ------- - DefaultConfig - A dictionary containing the default configuration settings for the RDM3, - including response variables, model parameters, choices, description, - and likelihood specifications. - """ - return { - "response": ["rt", "response"], - "list_params": rdm3_params, - "choices": [0, 1, 2], - "description": "Racing Diffusion Model 3 Choices (RDM3)", - "likelihoods": { - "analytical": { - "loglik": logp_rdm3, - "backend": None, - "default_priors": {}, - "bounds": rdm3_bounds, - "extra_fields": None, - } - }, - } diff --git a/tests/test_likelihoods_rdm.py b/tests/test_likelihoods_rdm.py index ea0f638f..72990f5b 100644 --- a/tests/test_likelihoods_rdm.py +++ b/tests/test_likelihoods_rdm.py @@ -9,7 +9,7 @@ import hssm # pylint: disable=C0413 -from hssm.likelihoods.analytical import logp_rdm3 +from hssm.likelihoods.analytical import LOGP_LB, logp_rdm3 hssm.set_floatX("float32") @@ -82,3 +82,21 @@ def test_racing_diffusion(logp_func, model, theta): assert np.isneginf(res).all() except pm.logprob.utils.ParameterValueError: pass + + +def test_rdm3_negative_rt_returns_logp_lb_and_is_finite(): + """Trials with rt <= t should return LOGP_LB and remain finite.""" + data = np.array( + [ + (0.02, 0.0), + (0.60, 1.0), + ], + dtype="float32", + ) + theta = dict(v0=1.0, v1=1.2, v2=1.4, b=2.0, A=1.0, t=0.05) + + logp = np.asarray(logp_rdm3(data, **theta)) + + assert np.isclose(logp[0], LOGP_LB) + assert np.all(np.isfinite(logp)) + assert not np.any(np.isnan(logp))