-
Notifications
You must be signed in to change notification settings - Fork 19
[WIP] Initial implementation of choice-only models #903
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
16647ef
763e870
4ae091d
85993c9
8f4b681
c6f6d50
5dc7844
66ea7b7
55d1b38
fb67a71
11aaeef
3fe87a4
4155210
2f12aeb
22e1fdf
a4f2181
5e89809
8a9d3db
fb6f4e1
2a0cf63
41738c0
a402be0
131ae47
2b85c83
f02fe1a
d1dbbea
b1e60a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,8 @@ | |
| "ddm_seq2_no_bias", | ||
| "lba3", | ||
| "lba2", | ||
| "softmax_inv_temperature_2", | ||
| "softmax_inv_temperature_3", | ||
| ] | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -191,6 +191,7 @@ def make_hssm_rv( | |||||
| simulator_fun: Callable | str, | ||||||
| list_params: list[str], | ||||||
| lapse: bmb.Prior | None = None, | ||||||
| is_choice_only: bool = False, | ||||||
| ) -> type[RandomVariable]: | ||||||
| """Build a RandomVariable Op according to the list of parameters. | ||||||
|
|
||||||
|
|
@@ -202,6 +203,8 @@ def make_hssm_rv( | |||||
| A list of str of all parameters for this `RandomVariable`. | ||||||
| lapse : optional | ||||||
| A bmb.Prior object representing the lapse distribution. | ||||||
| is_choice_only : bool | ||||||
| Whether the model is a choice-only model. This parameter overrides | ||||||
|
||||||
| Whether the model is a choice-only model. This parameter overrides | |
| Whether the model is a choice-only model. |
Copilot
AI
Feb 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make_distribution now accepts is_choice_only, but the flag isn’t propagated when rv is a callable or a string (the internal make_hssm_rv(...) calls still use the default is_choice_only=False). This will build an RV with the wrong output signature for choice-only models unless the RV is provided as a class. Pass is_choice_only=is_choice_only into those make_hssm_rv calls.
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -43,6 +43,7 @@ | |||||||||||||
| assemble_callables, | ||||||||||||||
| make_distribution, | ||||||||||||||
| make_family, | ||||||||||||||
| make_hssm_rv, | ||||||||||||||
| make_likelihood_callable, | ||||||||||||||
| make_missing_data_callable, | ||||||||||||||
| ) | ||||||||||||||
|
|
@@ -396,50 +397,64 @@ def __init__( | |||||||||||||
| self.model_config.validate() | ||||||||||||||
|
|
||||||||||||||
| # Set up shortcuts so old code will work | ||||||||||||||
| self.response = self.model_config.response | ||||||||||||||
| self.response = self.model_config.response[:] | ||||||||||||||
| self.list_params = self.model_config.list_params | ||||||||||||||
| self.choices = self.model_config.choices | ||||||||||||||
| self.model_name = self.model_config.model_name | ||||||||||||||
| self.loglik = self.model_config.loglik | ||||||||||||||
| self.loglik_kind = self.model_config.loglik_kind | ||||||||||||||
| self.extra_fields = self.model_config.extra_fields | ||||||||||||||
|
|
||||||||||||||
| self.response = cast("list[str]", self.response) | ||||||||||||||
| self.is_choice_only: bool = self.model_config.is_choice_only | ||||||||||||||
|
|
||||||||||||||
| if self.choices is None: | ||||||||||||||
| raise ValueError( | ||||||||||||||
| "`choices` must be provided either in `model_config` or as an argument." | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| self.n_choices = len(self.choices) | ||||||||||||||
|
|
||||||||||||||
| self._validate_choices() | ||||||||||||||
| self._pre_check_data_sanity() | ||||||||||||||
|
|
||||||||||||||
| # Process missing data setting | ||||||||||||||
| # AF-TODO: Could be a function in data validator? | ||||||||||||||
| if isinstance(missing_data, float): | ||||||||||||||
| if not ((self.data.rt == missing_data).any()): | ||||||||||||||
| raise ValueError( | ||||||||||||||
| f"missing_data argument is provided as a float {missing_data}, " | ||||||||||||||
| f"However, you have no RTs of {missing_data} in your dataset!" | ||||||||||||||
| ) | ||||||||||||||
| if self.is_choice_only and missing_data: | ||||||||||||||
| raise ValueError("Choice-only models cannot have missing data.") | ||||||||||||||
|
|
||||||||||||||
| if not self.is_choice_only: | ||||||||||||||
| if isinstance(missing_data, float): | ||||||||||||||
| if not ((self.data.rt == missing_data).any()): | ||||||||||||||
| raise ValueError( | ||||||||||||||
| f"missing_data argument is provided as a float {missing_data}, " | ||||||||||||||
| f"However, you have no RTs of {missing_data} in your dataset!" | ||||||||||||||
| ) | ||||||||||||||
| else: | ||||||||||||||
| self.missing_data = True | ||||||||||||||
| self.missing_data_value = missing_data | ||||||||||||||
| elif isinstance(missing_data, bool): | ||||||||||||||
| if missing_data and (not (self.data.rt == -999.0).any()): | ||||||||||||||
| raise ValueError( | ||||||||||||||
| "missing_data argument is provided as True, " | ||||||||||||||
| " so RTs of -999.0 are treated as missing. \n" | ||||||||||||||
| "However, you have no RTs of -999.0 in your dataset!" | ||||||||||||||
| ) | ||||||||||||||
| elif (not missing_data) and (self.data.rt == -999.0).any(): | ||||||||||||||
| # self.missing_data = True | ||||||||||||||
| raise ValueError( | ||||||||||||||
| "Missing data provided as False. \n" | ||||||||||||||
| "However, you have RTs of -999.0 in your dataset!" | ||||||||||||||
| ) | ||||||||||||||
| else: | ||||||||||||||
| self.missing_data = missing_data | ||||||||||||||
| else: | ||||||||||||||
| self.missing_data = True | ||||||||||||||
| self.missing_data_value = missing_data | ||||||||||||||
| elif isinstance(missing_data, bool): | ||||||||||||||
| if missing_data and (not (self.data.rt == -999.0).any()): | ||||||||||||||
| raise ValueError( | ||||||||||||||
| "missing_data argument is provided as True, " | ||||||||||||||
| " so RTs of -999.0 are treated as missing. \n" | ||||||||||||||
| "However, you have no RTs of -999.0 in your dataset!" | ||||||||||||||
| "missing_data argument must be a bool or a float! \n" | ||||||||||||||
| f"You provided: {type(missing_data)}" | ||||||||||||||
| ) | ||||||||||||||
| elif (not missing_data) and (self.data.rt == -999.0).any(): | ||||||||||||||
| # self.missing_data = True | ||||||||||||||
| raise ValueError( | ||||||||||||||
| "Missing data provided as False. \n" | ||||||||||||||
| "However, you have RTs of -999.0 in your dataset!" | ||||||||||||||
| ) | ||||||||||||||
| else: | ||||||||||||||
| self.missing_data = missing_data | ||||||||||||||
| else: | ||||||||||||||
| raise ValueError( | ||||||||||||||
| "missing_data argument must be a bool or a float! \n" | ||||||||||||||
| f"You provided: {type(missing_data)}" | ||||||||||||||
| ) | ||||||||||||||
| self.missing_data = False | ||||||||||||||
|
|
||||||||||||||
| if isinstance(deadline, str): | ||||||||||||||
| self.deadline = True | ||||||||||||||
|
|
@@ -501,7 +516,8 @@ def __init__( | |||||||||||||
| self.p_outlier = self.params.get("p_outlier") | ||||||||||||||
| self.lapse = lapse if self.has_lapse else None | ||||||||||||||
|
|
||||||||||||||
| self._post_check_data_sanity() | ||||||||||||||
| if not self.is_choice_only: | ||||||||||||||
| self._post_check_data_sanity() | ||||||||||||||
|
|
||||||||||||||
| self.model_distribution = self._make_model_distribution() | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -1358,9 +1374,15 @@ def set_alias(self, aliases: dict[str, str | dict]): | |||||||||||||
|
|
||||||||||||||
| @property | ||||||||||||||
| def response_c(self) -> str: | ||||||||||||||
| """Return the response variable names in c() format.""" | ||||||||||||||
| """Return the response variable names in c() format. | ||||||||||||||
|
|
||||||||||||||
| New in 0.2.12: when model is choice-only and has deadline, the response | ||||||||||||||
| is not in the form of c(...). | ||||||||||||||
|
Comment on lines
+1379
to
+1380
|
||||||||||||||
| New in 0.2.12: when model is choice-only and has deadline, the response | |
| is not in the form of c(...). | |
| New in 0.2.12: when the model is choice-only and has no deadline, the | |
| response is returned as a single variable name (e.g., ``"choice"``) | |
| instead of in the form ``c(...)``. In all other cases, the response is | |
| returned in ``c(...)`` format. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removing
DataValidatorMixin.__init__is a breaking change for existing direct instantiations (the repository’stests/test_data_validator.pyconstructsDataValidatorMixin(...)). Either restore a minimal/backward-compatible__init__(even if only for tests) or update the tests and any downstream usage to instantiate a concrete class that sets the required attributes.