Skip to content

Commit 19f786d

Browse files
authored
Merge pull request #430 from lnccbrown/427-add-choices-argument-to-hssm
Allow users to specify the number of choices
2 parents 107e071 + 7b934c2 commit 19f786d

6 files changed

Lines changed: 113 additions & 27 deletions

File tree

.github/workflows/build_and_publish.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,11 @@ jobs:
4949
- name: Linting
5050
run: ruff check src/hssm
5151

52-
- name: Run tests
53-
run: pytest -n auto -s
52+
- name: Run fast tests
53+
run: pytest -n auto -s --ignore=tests/slow
54+
55+
- name: Run slow tests
56+
run: pytest -n auto -s tests/slow
5457

5558
publish:
5659
name: Build wheel and publish to test-PyPI, and then PyPI, and publish docs

.github/workflows/run_tests.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,11 @@ jobs:
4747
- name: Linting
4848
run: ruff check src/hssm
4949

50-
- name: Run tests
51-
run: pytest -n auto -s
50+
- name: Run fast tests
51+
run: pytest -n auto -s --ignore=tests/slow
52+
53+
- name: Run slow tests
54+
run: pytest -n auto -s tests/slow
5255

5356
- name: build docs
5457
run: mkdocs build

src/hssm/hssm.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ class HSSM:
7373
"ddm_seq2_no_bias". If any other string is passed, the model will be considered
7474
custom, in which case all `model_config`, `loglik`, and `loglik_kind` have to be
7575
provided by the user.
76+
choices : optional
77+
When an `int`, the number of choices that the participants can make. If `2`, the
78+
choices are [-1, 1] by default. If anything greater than `2`, the choices are
79+
[0, 1, ..., n_choices - 1] by default. If a `list` is provided, it should be the
80+
list of choices that the participants can make. Defaults to `2`. If any value
81+
other than the choices provided is found in the "response" column of the data,
82+
an error will be raised.
7683
include : optional
7784
A list of dictionaries specifying parameter specifications to include in the
7885
model. If left unspecified, defaults will be used for all parameter
@@ -225,6 +232,7 @@ def __init__(
225232
self,
226233
data: pd.DataFrame,
227234
model: SupportedModels | str = "ddm",
235+
choices: int | list[int] = 2,
228236
include: list[dict | Param] | None = None,
229237
model_config: ModelConfig | dict | None = None,
230238
loglik: (
@@ -282,8 +290,20 @@ def __init__(
282290
self.loglik_kind = self.model_config.loglik_kind
283291
self.extra_fields = self.model_config.extra_fields
284292

285-
self.choices = self.data["response"].unique().astype(int)
286-
self.n_choices = len(self.choices)
293+
if isinstance(choices, int):
294+
if choices == 2:
295+
self.n_choices = 2
296+
self.choices = [-1, 1]
297+
elif choices > 2:
298+
self.n_choices = choices
299+
self.choices = list(range(choices))
300+
else:
301+
raise ValueError("choices must be greater than 1.")
302+
elif isinstance(choices, list):
303+
self.n_choices = len(choices)
304+
self.choices = choices
305+
else:
306+
raise ValueError("choices must be an integer or a list of integers.")
287307

288308
self._pre_check_data_sanity()
289309

@@ -1393,13 +1413,6 @@ def _pre_check_data_sanity(self):
13931413
+ "`participant_id` is not found in your dataset."
13941414
)
13951415

1396-
if self.n_choices == 2:
1397-
if -1 not in self.choices or 1 not in self.choices:
1398-
raise ValueError(
1399-
"The response column must contain only -1 and 1 when there are "
1400-
+ "two responses."
1401-
)
1402-
14031416
def _post_check_data_sanity(self):
14041417
"""Check if the data is clean enough for the model."""
14051418
if self.deadline or self.missing_data:
@@ -1425,6 +1438,24 @@ def _post_check_data_sanity(self):
14251438
+ "which is not allowed."
14261439
)
14271440

1441+
valid_responses = self.data.loc[self.data["rt"] != -999.0, "response"]
1442+
unique_responses = valid_responses.unique().astype(int)
1443+
1444+
if np.any(~np.isin(unique_responses, self.choices)):
1445+
invalid_responses = sorted(
1446+
unique_responses[~np.isin(unique_responses, self.choices)]
1447+
)
1448+
raise ValueError(
1449+
f"Invalid responses found in your dataset: {invalid_responses}"
1450+
)
1451+
1452+
if len(unique_responses) != self.n_choices:
1453+
missing_responses = sorted(np.setdiff1d(self.choices, unique_responses))
1454+
_logger.warning(
1455+
f"You set choices to be {self.choices}, but {missing_responses} are "
1456+
+ "missing from your dataset."
1457+
)
1458+
14281459
def _postprocess_initvals_deterministic(
14291460
self, initval_settings: dict = INITVAL_SETTINGS
14301461
) -> None:

tests/conftest.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def data_angle():
3737
def data_ddm_reg():
3838
# Generate some fake simulation data
3939
intercept = 1.5
40-
x = np.random.uniform(-5.0, 5.0, size=1000)
41-
y = np.random.uniform(-5.0, 5.0, size=1000)
40+
x = np.random.uniform(-0.5, 0.5, size=1000)
41+
y = np.random.uniform(-0.5, 0.5, size=1000)
4242

4343
v = intercept + 0.8 * x + 0.3 * y
4444
true_values = np.column_stack(
@@ -57,6 +57,35 @@ def data_ddm_reg():
5757
return dataset_reg_v
5858

5959

60+
@pytest.fixture(scope="module")
61+
def data_ddm_reg_va():
62+
# Generate some fake simulation data
63+
intercept = 1.5
64+
intercept_a = 1.0
65+
x = np.random.uniform(-0.5, 0.5, size=100)
66+
y = np.random.uniform(-0.5, 0.5, size=100)
67+
68+
m = np.random.uniform(-0.5, 0.5, size=100)
69+
n = np.random.uniform(-0.5, 0.5, size=100)
70+
71+
v = intercept + 0.8 * x + 0.3 * y
72+
a = intercept_a + 0.1 * m + 0.1 * n
73+
true_values = np.column_stack([v, a, np.repeat([[0.5, 0.5]], axis=0, repeats=100)])
74+
75+
dataset_reg_va = hssm.simulate_data(
76+
model="ddm",
77+
theta=true_values,
78+
size=1, # Generate one data point for each of the 1000 set of true values
79+
)
80+
81+
dataset_reg_va["x"] = x
82+
dataset_reg_va["y"] = y
83+
dataset_reg_va["m"] = m
84+
dataset_reg_va["n"] = n
85+
86+
return dataset_reg_va
87+
88+
6089
@pytest.fixture
6190
def cav_idata():
6291
return az.from_netcdf("tests/fixtures/cavanagh_idata.nc")

tests/slow/test_mcmc.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def test_reg_models(data_ddm_reg, loglik_kind, backend, sampler, step, expected)
184184

185185

186186
@pytest.mark.parametrize(parameter_names, parameter_grid)
187-
def test_reg_models_v_a(data_ddm_reg, loglik_kind, backend, sampler, step, expected):
187+
def test_reg_models_v_a(data_ddm_reg_va, loglik_kind, backend, sampler, step, expected):
188188
print("PYMC VERSION: ")
189189
print(pm.__version__)
190190
print("TEST INPUTS WERE: ")
@@ -199,27 +199,27 @@ def test_reg_models_v_a(data_ddm_reg, loglik_kind, backend, sampler, step, expec
199199
},
200200
)
201201
param_reg_a = dict(
202-
formula="a ~ 1 + x + y",
202+
formula="a ~ 1 + m + n",
203203
prior={
204204
"Intercept": {
205-
"name": "Uniform",
206-
"lower": 0.5,
207-
"upper": 3.0,
208-
"initval": 1.0,
205+
"name": "Normal",
206+
"mu": 1.0,
207+
"sigma": 0.5,
209208
},
210-
"x": {"name": "Uniform", "lower": -0.50, "upper": 0.50, "initval": 0.0},
211-
"y": {"name": "Uniform", "lower": -0.50, "upper": 0.50, "initval": 0.0},
209+
"m": {"name": "Uniform", "lower": 0.0, "upper": 0.2},
210+
"n": {"name": "Uniform", "lower": 0.0, "upper": 0.2},
212211
},
213212
link="identity",
214213
)
215214

216215
model = hssm.HSSM(
217-
data_ddm_reg,
216+
data_ddm_reg_va,
218217
loglik_kind=loglik_kind,
219218
model_config={"backend": backend},
220219
v=param_reg_v,
221220
a=param_reg_a,
222221
)
222+
print(model.params["a"])
223223
run_sample(model, sampler, step, expected)
224224

225225
# Only runs once

tests/test_data_sanity.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ def test_data_sanity_check(data_ddm, cpn, caplog):
4141
):
4242
hssm.HSSM(data=data_ddm, model="ddm", hierarchical=True)
4343

44-
# Case 5: raise error if there are missing fields in data
44+
# Case 5: raise error if there are invalid responses in data
4545
with pytest.raises(
4646
ValueError,
47-
match="The response column must contain only -1 and 1 when there are two responses.",
47+
match=r"Invalid responses found in your dataset: \[0\]",
4848
):
4949
data_ddm_miscoded = data_ddm.copy()
5050
data_ddm_miscoded["response"] = data_ddm_miscoded["response"].replace(
@@ -53,7 +53,27 @@ def test_data_sanity_check(data_ddm, cpn, caplog):
5353

5454
hssm.HSSM(data=data_ddm_miscoded, model="ddm")
5555

56-
# Case 6: if deadline or missing_data is True, data should contain missing values
56+
with pytest.raises(
57+
ValueError,
58+
match=r"Invalid responses found in your dataset: \[0\]",
59+
):
60+
data_ddm_miscoded = data_ddm.copy()
61+
data_ddm_miscoded["response"] = np.random.choice([0, 1, 2], data_ddm.shape[0])
62+
63+
hssm.HSSM(data=data_ddm_miscoded, model="ddm", choices=[1, 2, 3])
64+
65+
# Case 6: raise warning if there are missing responses in data
66+
data_ddm_miscoded = data_ddm.copy()
67+
data_ddm_miscoded["response"] = np.random.choice([1, 2], data_ddm.shape[0])
68+
69+
hssm.HSSM(data=data_ddm_miscoded, model="ddm", choices=[1, 2, 3])
70+
71+
assert (
72+
caplog.records[-1].msg
73+
== "You set choices to be [1, 2, 3], but [3] are missing from your dataset."
74+
)
75+
76+
# Case 7: if deadline or missing_data is True, data should contain missing values
5777
with pytest.raises(
5878
ValueError,
5979
match="You have no missing data in your dataset, "

0 commit comments

Comments
 (0)