Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
16647ef
feat: added is_choice_only property to Config
digicosmos86 Jan 29, 2026
763e870
feat: added a is_choice_only to the attributes of Model
digicosmos86 Jan 29, 2026
4ae091d
feat: update mechanism through which choice-only models are identified
digicosmos86 Jan 29, 2026
85993c9
tests: update tests for config
digicosmos86 Jan 29, 2026
8f4b681
feat: use a dummy simulator function to get around the lack of simula…
digicosmos86 Jan 29, 2026
c6f6d50
fix: added a general analytical likelihood for softmax_inv_temperatur…
digicosmos86 Feb 13, 2026
5dc7844
feat: added inv_softmax_temperature default config
digicosmos86 Feb 13, 2026
66ea7b7
feat: added specific defaults for softmax_inv_temperature models with…
digicosmos86 Feb 13, 2026
55d1b38
chore: prevent uploading AI slops
digicosmos86 Feb 13, 2026
fb67a71
feat: added two softmax_inv_temperature models to SupportedModels
digicosmos86 Feb 13, 2026
11aaeef
fix: slight update to inv_softmax_temperature function
digicosmos86 Feb 13, 2026
3fe87a4
tests: added unit tests for `softmax_inv_temperature`
digicosmos86 Feb 13, 2026
4155210
tests: added unit tests for `softmax_inv_temperature`
digicosmos86 Feb 13, 2026
2f12aeb
Merge branch '886-implement-the-first-choice-only-model' of https://g…
digicosmos86 Feb 13, 2026
22e1fdf
feat: implement softmax_inv_temperature_config for model configurations
digicosmos86 Feb 13, 2026
a4f2181
refactor: rename inv_softmax_temperature tests to softmax_inv_tempera…
digicosmos86 Feb 13, 2026
5e89809
feat: update response handling to support choice-only model configura…
digicosmos86 Feb 13, 2026
8a9d3db
Merge branch 'main' into 886-implement-the-first-choice-only-model
digicosmos86 Feb 13, 2026
fb6f4e1
fix: remove internal states from DataValidatorMixin
digicosmos86 Feb 24, 2026
2a0cf63
fix: remove "mu" parameter from softmax_inv_temperature_config
digicosmos86 Feb 24, 2026
41738c0
fix: chose what to pass to lapse_func based on the dimensions of data
digicosmos86 Feb 24, 2026
a402be0
tests: basic tests for choice-only models
digicosmos86 Feb 24, 2026
131ae47
fix: get around some data validation logic for choice only models
digicosmos86 Feb 24, 2026
2b85c83
fix: finalize softmax_inv_temperature function
digicosmos86 Feb 24, 2026
f02fe1a
feat: added a is_choice_only parameter to make_distribution to get ar…
digicosmos86 Feb 24, 2026
d1dbbea
feat: added tests to choice only models
digicosmos86 Feb 24, 2026
b1e60a6
fix: separated tests for choice_only likelihoods and actual tests for…
digicosmos86 Feb 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,8 @@ explorations/
*.nbconvert.ipynb
pypi-token
pre-commit

# AI slops
.github/copilot-instructions.md
CLAUDE.md
.claude/
2 changes: 2 additions & 0 deletions src/hssm/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
"ddm_seq2_no_bias",
"lba3",
"lba2",
"softmax_inv_temperature_2",
"softmax_inv_temperature_3",
]


Expand Down
7 changes: 7 additions & 0 deletions src/hssm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,13 @@ def get_defaults(
"""
return self.default_priors.get(param), self.bounds.get(param)

@property
def is_choice_only(self) -> bool:
"""Check if the model is a choice-only model."""
if self.response is None:
raise ValueError("Response is not defined in the configuration.")
return len(self.response) == 1


@dataclass
class RLSSMConfig(BaseModelConfig):
Expand Down
45 changes: 17 additions & 28 deletions src/hssm/data_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import logging
import warnings

import numpy as np # noqa: F401
import pandas as pd # noqa: F401
import numpy as np
import pandas as pd

from hssm.defaults import MissingDataNetwork # noqa: F401
from hssm.defaults import MissingDataNetwork

_logger = logging.getLogger("hssm")

Expand All @@ -26,31 +26,16 @@ class DataValidatorMixin:
- missing_data_value: float
"""

def __init__(
self,
data: pd.DataFrame,
response: list[str] | None = ["rt", "response"],
choices: list[int] | None = [0, 1],
n_choices: int = 2,
extra_fields: list[str] | None = None,
deadline: bool = False,
deadline_name: str = "deadline",
missing_data: bool = False,
missing_data_value: float = -999.0,
):
"""Initialize the DataValidatorMixin.

Init method kept for testing purposes.
"""
self.data = data
self.response = response
self.choices = choices
self.n_choices = n_choices
self.extra_fields = extra_fields
self.deadline = deadline
self.deadline_name = deadline_name
self.missing_data = missing_data
self.missing_data_value = missing_data_value
data: pd.DataFrame
response: list[str]
choices: list[int]
n_choices: int
extra_fields: list[str] | None
deadline: bool
deadline_name: str
missing_data: bool
missing_data_value: float
is_choice_only: bool

Copy link

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing DataValidatorMixin.__init__ is a breaking change for existing direct instantiations (the repository’s tests/test_data_validator.py constructs DataValidatorMixin(...)). Either restore a minimal/backward-compatible __init__ (even if only for tests) or update the tests and any downstream usage to instantiate a concrete class that sets the required attributes.

Suggested change
def __init__(
self,
data: pd.DataFrame | None = None,
response: list[str] | None = None,
choices: list[int] | None = None,
n_choices: int | None = None,
extra_fields: list[str] | None = None,
deadline: bool = False,
deadline_name: str | None = None,
missing_data: bool = False,
missing_data_value: float = -999.0,
is_choice_only: bool = False,
) -> None:
"""Initialize DataValidatorMixin with optional arguments.
This minimal initializer is provided for backward compatibility with
direct instantiation (e.g., in tests). Subclasses are free to override
this method and set these attributes themselves.
"""
self.data = data if data is not None else pd.DataFrame()
self.response = response if response is not None else []
self.choices = choices if choices is not None else []
# If n_choices is not provided, infer from choices if available.
if n_choices is not None:
self.n_choices = n_choices
else:
self.n_choices = len(self.choices) if self.choices is not None else 0
self.extra_fields = extra_fields
self.deadline = deadline
self.deadline_name = deadline_name if deadline_name is not None else "deadline"
self.missing_data = missing_data
self.missing_data_value = missing_data_value
self.is_choice_only = is_choice_only

Copilot uses AI. Check for mistakes.
@staticmethod
def check_fields(a, b):
Expand Down Expand Up @@ -131,6 +116,10 @@ def _post_check_data_sanity(self):
def _handle_missing_data_and_deadline(self):
"""Handle missing data and deadline."""
if not self.missing_data and not self.deadline:
# In the case of choice only model, we don't need to do anything with the
# data.
if self.is_choice_only:
return
# In the case where missing_data is set to False, we need to drop the
# cases where rt = na_value
if pd.isna(self.missing_data_value):
Expand Down
27 changes: 23 additions & 4 deletions src/hssm/distribution_utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def make_hssm_rv(
simulator_fun: Callable | str,
list_params: list[str],
lapse: bmb.Prior | None = None,
is_choice_only: bool = False,
) -> type[RandomVariable]:
"""Build a RandomVariable Op according to the list of parameters.

Expand All @@ -202,6 +203,8 @@ def make_hssm_rv(
A list of str of all parameters for this `RandomVariable`.
lapse : optional
A bmb.Prior object representing the lapse distribution.
is_choice_only : bool
Whether the model is a choice-only model. This parameter overrides
Copy link

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring for is_choice_only is incomplete (ends with “This parameter overrides” without saying what it overrides). Please complete or remove the fragment so the generated docs/readers aren’t left with a partial sentence.

Suggested change
Whether the model is a choice-only model. This parameter overrides
Whether the model is a choice-only model.

Copilot uses AI. Check for mistakes.

Returns
-------
Expand All @@ -225,7 +228,12 @@ class HSSMRV(RandomVariable):
# parameter is a scalar. The string to the right of the
# `->` sign describes the output signature, which is `(2)`, which means the
# random variable is a length-2 array.
signature: str = f"{','.join(['()'] * len(list_params))}->({obs_dim_int})"

# Override the output from ssm_simulator based on whether the model is
# choice-only.
output = "()" if is_choice_only else f"({obs_dim_int})"
signature: str = f"{','.join(['()'] * len(list_params))}->{output}"

dtype: str = "floatX"
_print_name: tuple[str, str] = ("SSM", "\\operatorname{SSM}")
_list_params = list_params
Expand Down Expand Up @@ -389,6 +397,7 @@ def make_distribution(
extra_fields: list[np.ndarray] | None = None,
fixed_vector_params: dict[str, np.ndarray] | None = None,
params_is_trialwise: list[bool] | None = None,
is_choice_only: bool = False,
) -> type[pm.Distribution]:
"""Make a `pymc.Distribution`.

Comment on lines 397 to 403
Copy link

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make_distribution now accepts is_choice_only, but the flag isn’t propagated when rv is a callable or a string (the internal make_hssm_rv(...) calls still use the default is_choice_only=False). This will build an RV with the wrong output signature for choice-only models unless the RV is provided as a class. Pass is_choice_only=is_choice_only into those make_hssm_rv calls.

Copilot uses AI. Check for mistakes.
Expand Down Expand Up @@ -437,6 +446,8 @@ def make_distribution(
that vmapped JAX log-likelihoods receive consistently shaped inputs,
regardless of whether Bambi produces ``(1,)`` or ``(n_obs,)`` tensors.
When ``None``, no graph-level broadcasting is applied.
is_choice_only : optional
Whether the model is a choice-only model.

Returns
-------
Expand Down Expand Up @@ -561,12 +572,15 @@ def logp(data, *dist_params): # pylint: disable=E0213
"lapse_func is not defined. "
"Make sure lapse is properly initialized."
)
lapse_logp = lapse_func(data[:, 0].eval())
data_for_lapse = data if is_choice_only else data[:, 0]
lapse_logp = lapse_func(data_for_lapse.eval())

# AF-TODO potentially apply clipping here
logp = loglik(data, *dist_params, *extra_fields)
# Ensure that non-decision time is always smaller than rt.
# Assuming that the non-decision time parameter is always named "t".
logp = ensure_positive_ndt(data, logp, list_params, dist_params)
if not is_choice_only:
logp = ensure_positive_ndt(data, logp, list_params, dist_params)
logp = pt.log(
(1.0 - p_outlier) * pt.exp(logp)
+ p_outlier * pt.exp(lapse_logp)
Expand All @@ -575,7 +589,8 @@ def logp(data, *dist_params): # pylint: disable=E0213
else:
logp = loglik(data, *dist_params, *extra_fields)
# Ensure that non-decision time is always smaller than rt.
logp = ensure_positive_ndt(data, logp, list_params, dist_params)
if not is_choice_only:
logp = ensure_positive_ndt(data, logp, list_params, dist_params)

if bounds is not None:
logp = apply_param_bounds_to_loglik(
Expand All @@ -593,6 +608,7 @@ def make_distribution_for_supported_model(
backend: Literal["pytensor", "jax", "other"] = "pytensor",
reg_params: list[str] | None = None,
lapse: bmb.Prior | None = None,
is_choice_only: bool = False,
) -> type[pm.Distribution]:
"""Make a pm.Distribution class for a supported model.

Expand All @@ -614,6 +630,8 @@ class that can be used for PyMC modeling.
parameters are assumed.
lapse : optional
A bmb.Prior object representing the lapse distribution.
is_choice_only : optional
Whether the model is a choice-only model.
"""
supported_models = get_args(SupportedModels)
if model not in supported_models:
Expand Down Expand Up @@ -643,6 +661,7 @@ class that can be used for PyMC modeling.
list_params=config.list_params,
bounds=config.bounds,
lapse=lapse,
is_choice_only=is_choice_only,
)


Expand Down
106 changes: 77 additions & 29 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
assemble_callables,
make_distribution,
make_family,
make_hssm_rv,
make_likelihood_callable,
make_missing_data_callable,
)
Expand Down Expand Up @@ -396,50 +397,64 @@ def __init__(
self.model_config.validate()

# Set up shortcuts so old code will work
self.response = self.model_config.response
self.response = self.model_config.response[:]
self.list_params = self.model_config.list_params
self.choices = self.model_config.choices
self.model_name = self.model_config.model_name
self.loglik = self.model_config.loglik
self.loglik_kind = self.model_config.loglik_kind
self.extra_fields = self.model_config.extra_fields

self.response = cast("list[str]", self.response)
self.is_choice_only: bool = self.model_config.is_choice_only

if self.choices is None:
raise ValueError(
"`choices` must be provided either in `model_config` or as an argument."
)

self.n_choices = len(self.choices)

self._validate_choices()
self._pre_check_data_sanity()

# Process missing data setting
# AF-TODO: Could be a function in data validator?
if isinstance(missing_data, float):
if not ((self.data.rt == missing_data).any()):
raise ValueError(
f"missing_data argument is provided as a float {missing_data}, "
f"However, you have no RTs of {missing_data} in your dataset!"
)
if self.is_choice_only and missing_data:
raise ValueError("Choice-only models cannot have missing data.")

if not self.is_choice_only:
if isinstance(missing_data, float):
if not ((self.data.rt == missing_data).any()):
raise ValueError(
f"missing_data argument is provided as a float {missing_data}, "
f"However, you have no RTs of {missing_data} in your dataset!"
)
else:
self.missing_data = True
self.missing_data_value = missing_data
elif isinstance(missing_data, bool):
if missing_data and (not (self.data.rt == -999.0).any()):
raise ValueError(
"missing_data argument is provided as True, "
" so RTs of -999.0 are treated as missing. \n"
"However, you have no RTs of -999.0 in your dataset!"
)
elif (not missing_data) and (self.data.rt == -999.0).any():
# self.missing_data = True
raise ValueError(
"Missing data provided as False. \n"
"However, you have RTs of -999.0 in your dataset!"
)
else:
self.missing_data = missing_data
else:
self.missing_data = True
self.missing_data_value = missing_data
elif isinstance(missing_data, bool):
if missing_data and (not (self.data.rt == -999.0).any()):
raise ValueError(
"missing_data argument is provided as True, "
" so RTs of -999.0 are treated as missing. \n"
"However, you have no RTs of -999.0 in your dataset!"
"missing_data argument must be a bool or a float! \n"
f"You provided: {type(missing_data)}"
)
elif (not missing_data) and (self.data.rt == -999.0).any():
# self.missing_data = True
raise ValueError(
"Missing data provided as False. \n"
"However, you have RTs of -999.0 in your dataset!"
)
else:
self.missing_data = missing_data
else:
raise ValueError(
"missing_data argument must be a bool or a float! \n"
f"You provided: {type(missing_data)}"
)
self.missing_data = False

if isinstance(deadline, str):
self.deadline = True
Expand Down Expand Up @@ -501,7 +516,8 @@ def __init__(
self.p_outlier = self.params.get("p_outlier")
self.lapse = lapse if self.has_lapse else None

self._post_check_data_sanity()
if not self.is_choice_only:
self._post_check_data_sanity()

self.model_distribution = self._make_model_distribution()

Expand Down Expand Up @@ -1358,9 +1374,15 @@ def set_alias(self, aliases: dict[str, str | dict]):

@property
def response_c(self) -> str:
"""Return the response variable names in c() format."""
"""Return the response variable names in c() format.

New in 0.2.12: when model is choice-only and has deadline, the response
is not in the form of c(...).
Comment on lines +1379 to +1380
Copy link

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring says that for choice-only models with a deadline the response is “not in the form of c(...)”, but the implementation returns c(response, deadline) in that case. Please update the docstring to match the actual behavior (or adjust the behavior if the doc is correct).

Suggested change
New in 0.2.12: when model is choice-only and has deadline, the response
is not in the form of c(...).
New in 0.2.12: when the model is choice-only and has no deadline, the
response is returned as a single variable name (e.g., ``"choice"``)
instead of in the form ``c(...)``. In all other cases, the response is
returned in ``c(...)`` format.

Copilot uses AI. Check for mistakes.
"""
if self.response is None:
return "c()"
raise ValueError("Response is not defined.")
if self.is_choice_only and not self.deadline:
return self.response[0]
return f"c({', '.join(self.response)})"

@property
Expand Down Expand Up @@ -2096,6 +2118,31 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
self.missing_data_value,
)

if self.is_choice_only and self.model_config.rv is None:
_logger.warning(
"You are building a choice-only model without specifying "
"a RandomVariable class. Using a dummy simulator function. "
"Simulating data from this model will result in an error."
)

def dummy_simulator_func(*args, **kwargs):
raise NotImplementedError(
"You are trying to simulate data from a choice-only model "
"without specifying a RandomVariable class. Please specify "
"a RandomVariable class via the `model_config.rv` argument."
)

setattr(dummy_simulator_func, "model_name", self.model_name)
setattr(dummy_simulator_func, "choices", self.choices)
setattr(dummy_simulator_func, "obs_dim", 1)

self.model_config.rv = make_hssm_rv(
dummy_simulator_func,
list_params=self.list_params,
lapse=self.lapse,
is_choice_only=True,
)

self.data = _rearrange_data(self.data)

# Collect fixed-vector params to substitute in the distribution logp
Expand All @@ -2118,6 +2165,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
),
fixed_vector_params=fixed_vector_params if fixed_vector_params else None,
params_is_trialwise=params_is_trialwise_base,
is_choice_only=self.is_choice_only,
)

def _get_deterministic_var_names(self, idata) -> list[str]:
Expand Down
Loading
Loading