Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 20 additions & 32 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,10 @@ def sample(
Returns
-------
az.InferenceData | pm.Approximation
An ArviZ `InferenceData` instance if inference_method is `"mcmc"`
(default), "nuts_numpyro", "nuts_blackjax" or "laplace". An `Approximation`
object if `"vi"`.
A reference to the `model.traces` object, which stores the traces of the
last call to `model.sample()`. `model.traces` is an ArviZ `InferenceData`
instance if `sampler` is `"mcmc"` (default), `"nuts_numpyro"`,
`"nuts_blackjax"` or "`laplace"`, or an `Approximation` object if `"vi"`.
"""
# If initvals are None (default)
# we skip processing initvals here.
Expand Down Expand Up @@ -537,18 +538,16 @@ def sample(
# If sampler is finally `numpyro` make sure
# the jitter argument is set to False
if sampler == "nuts_numpyro":
if "jitter" not in kwargs.keys():
kwargs["jitter"] = False
elif kwargs["jitter"]:
if kwargs.get("jitter", None):
_logger.warning(
"The jitter argument is set to True. "
+ "This argument is not supported "
+ "by the numpyro backend. "
+ "The jitter argument will be set to False."
)
kwargs["jitter"] = False
elif sampler != "nuts_numpyro":
if "jitter" in kwargs.keys():
kwargs["jitter"] = False
else:
if "jitter" in kwargs:
_logger.warning(
"The jitter keyword argument is "
+ "supported only by the nuts_numpyro sampler. \n"
Expand All @@ -560,27 +559,21 @@ def sample(
# If not specified, include the mean prediction in
# kwargs to be passed to the model.fit() method
kwargs["include_mean"] = True
idata = self.model.fit(inference_method=sampler, init=init, **kwargs)

if self._inference_obj is None:
self._inference_obj = idata
elif isinstance(self._inference_obj, az.InferenceData):
_logger.info(
"Inference data already exsits. \n"
"Data from this run will overwrite the idata file..."
)
self._inference_obj.extend(idata, join="right")
else:
raise ValueError(
"The model has an attached inference object under"
+ " self._inference_obj, but it is not an InferenceData object."
if self._inference_obj is not None:
_logger.warning(
"The model has already been sampled. Overwriting the previous "
+ "inference object. Any previous reference to the inference object "
+ "will still point to the old object."
)
self._inference_obj = self.model.fit(
inference_method=sampler, init=init, **kwargs
)

# The parent was previously not part of deterministics --> compute it via
# posterior_predictive (works because it acts as the 'mu' parameter
# in the GLM as far as bambi is concerned)
if self._inference_obj is not None:
if self._parent not in self._inference_obj.posterior.data_vars.keys():
if self._parent not in self._inference_obj.posterior.data_vars:
# self.model.predict(self._inference_obj, kind="mean", inplace=True)
# rename 'rt,response_mean' to 'v' so in the traces everything
# looks the way it should
Expand All @@ -595,10 +588,7 @@ def sample(
# if parent already in posterior
del self._inference_obj.posterior["rt,response_mean"]

# returning copy of traces here to detach from the actual _inference_obj
# attached to the class. Otherise resampling will
# overwrite the 'returned' object leading to unexpected consequences
return deepcopy(self.traces)
return self.traces

def sample_posterior_predictive(
self,
Expand Down Expand Up @@ -1150,7 +1140,7 @@ def traces(self) -> az.InferenceData | pm.Approximation:
Returns
-------
az.InferenceData | pm.Approximation
The trace of the model after sampling.
The trace of the model after the last call to `sample()`.
"""
if not self._inference_obj:
raise ValueError("Please sample the model first.")
Expand Down Expand Up @@ -1515,9 +1505,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
bounds=self.bounds,
lapse=self.lapse,
extra_fields=(
None
if not self.extra_fields
else [deepcopy(self.data[field].values) for field in self.extra_fields]
None if not self.extra_fields else deepcopy(self.extra_fields)
),
)

Expand Down
11 changes: 11 additions & 0 deletions tests/test_hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,14 @@ def test_override_default_link(caplog, data_ddm_reg):

assert "t" in caplog.records[0].message
assert "strange" in caplog.records[0].message


def test_resampling(data_ddm):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should put a not somewhere concerning the resulting behavior.
If I understand correctly, if you change sample_2 you automatically change model.traces?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the return section of the sample() method and traces object to indicate that .sample() returns a reference to model.traces object, which stores the traces of the last call to .sample().

model = HSSM(data=data_ddm)
sample_1 = model.sample(draws=10, chains=1, tune=0)
assert sample_1 is model.traces

sample_2 = model.sample(draws=10, chains=1, tune=0)
assert sample_2 is model.traces

assert sample_1 is not sample_2