@@ -73,6 +73,13 @@ class HSSM:
7373 "ddm_seq2_no_bias". If any other string is passed, the model will be considered
7474 custom, in which case all `model_config`, `loglik`, and `loglik_kind` have to be
7575 provided by the user.
76+ choices : optional
77+ When an `int`, the number of choices that the participants can make. If `2`, the
78+ choices are [-1, 1]. If greater than `2`, the choices are
79+ [0, 1, ..., n_choices - 1]. An `int` below `2` raises a `ValueError`. When a `list`,
80+ the list of choices that the participants can make. Defaults to `2`. If the
81+ "response" column contains a value outside these choices (rows whose "rt" is
82+ -999.0 are ignored), an error is raised; unused choices only log a warning.
7683 include : optional
7784 A list of dictionaries specifying parameter specifications to include in the
7885 model. If left unspecified, defaults will be used for all parameter
@@ -225,6 +232,7 @@ def __init__(
225232 self ,
226233 data : pd .DataFrame ,
227234 model : SupportedModels | str = "ddm" ,
235+ choices : int | list [int ] = 2 ,
228236 include : list [dict | Param ] | None = None ,
229237 model_config : ModelConfig | dict | None = None ,
230238 loglik : (
@@ -282,8 +290,20 @@ def __init__(
282290 self .loglik_kind = self .model_config .loglik_kind
283291 self .extra_fields = self .model_config .extra_fields
284292
285- self .choices = self .data ["response" ].unique ().astype (int )
286- self .n_choices = len (self .choices )
293+ if isinstance (choices , int ):
294+ if choices == 2 :
295+ self .n_choices = 2
296+ self .choices = [- 1 , 1 ]
297+ elif choices > 2 :
298+ self .n_choices = choices
299+ self .choices = list (range (choices ))
300+ else :
301+ raise ValueError ("choices must be greater than 1." )
302+ elif isinstance (choices , list ):
303+ self .n_choices = len (choices )
304+ self .choices = choices
305+ else :
306+ raise ValueError ("choices must be an integer or a list of integers." )
287307
288308 self ._pre_check_data_sanity ()
289309
@@ -1393,13 +1413,6 @@ def _pre_check_data_sanity(self):
13931413 + "`participant_id` is not found in your dataset."
13941414 )
13951415
1396- if self .n_choices == 2 :
1397- if - 1 not in self .choices or 1 not in self .choices :
1398- raise ValueError (
1399- "The response column must contain only -1 and 1 when there are "
1400- + "two responses."
1401- )
1402-
14031416 def _post_check_data_sanity (self ):
14041417 """Check if the data is clean enough for the model."""
14051418 if self .deadline or self .missing_data :
@@ -1425,6 +1438,24 @@ def _post_check_data_sanity(self):
14251438 + "which is not allowed."
14261439 )
14271440
1441+ valid_responses = self .data .loc [self .data ["rt" ] != - 999.0 , "response" ]
1442+ unique_responses = valid_responses .unique ().astype (int )
1443+
1444+ if np .any (~ np .isin (unique_responses , self .choices )):
1445+ invalid_responses = sorted (
1446+ unique_responses [~ np .isin (unique_responses , self .choices )]
1447+ )
1448+ raise ValueError (
1449+ f"Invalid responses found in your dataset: { invalid_responses } "
1450+ )
1451+
1452+ if len (unique_responses ) != self .n_choices :
1453+ missing_responses = sorted (np .setdiff1d (self .choices , unique_responses ))
1454+ _logger .warning (
1455+ f"You set choices to be { self .choices } , but { missing_responses } are "
1456+ + "missing from your dataset."
1457+ )
1458+
14281459 def _postprocess_initvals_deterministic (
14291460 self , initval_settings : dict = INITVAL_SETTINGS
14301461 ) -> None :
0 commit comments