414 long running time of sample posterior predictive and eventual death by oom#436
Conversation
…-hddm Merging main.
…d to _mean prediction consistently
|
I will try to add a few more tests to this before merging. |
digicosmos86
left a comment
There was a problem hiding this comment.
Looks good! Two higher-level comments:
-
Since the only call to
simulatoris done here:, maybe we can use a for-loop here over n_samples to make the sampling safe, instead of patching the higher-level functions themselves? This way we can avoid running many intermediate-level code multiple times.HSSM/src/hssm/distribution_utils/dist.py
Lines 309 to 315 in 19f786d
-
InferenceDataobject does not come with attributes likeposterior, orposterior_predictiveby default, so type checker complains. The use of the square bracket notation is preferred. Or if this is too annoying we can disable this check (attr-defined) globally inpyproject.tomlmypysection, but that can be a bit risky
|
|
||
| if "posterior_predictive" in idata.groups(): | ||
| del idata.posterior_predictive | ||
| print("pre-existing posterior_predictive group deleted from idata. \n") |
There was a problem hiding this comment.
This should be a warning
digicosmos86
left a comment
There was a problem hiding this comment.
Looks awesome! Just some style suggestions at this point. Feel free to merge after the fixes :)
| if "posterior_predictive" in idata.groups(): | ||
| if idata is not None: |
There was a problem hiding this comment.
Should the order be the other way around?
There was a problem hiding this comment.
@digicosmos86 changed. This was useless to begin with, just an artifact appeasing mypy...
| from inspect import isclass | ||
| from os import PathLike | ||
| from typing import Any, Callable, Literal | ||
| from typing import Any, Callable, Literal, Union |
There was a problem hiding this comment.
We don't use Union any more. Now that we have Python 3.10, we use the | operator instead
| self.model, self._parent_param, self.response_c, self.response_str | ||
| ) | ||
| self.set_alias(self._aliases) | ||
| # _logger.info(self.pymc_model.initial_point()) |
There was a problem hiding this comment.
Should we remove debug comments?
There was a problem hiding this comment.
Stylistically eventually yes, but rn, I think it can sometimes still help future PRs that interact with this code. Here I literally have the next PR that I need to work on in mind. So in general agree, but let's skip here :)
| if self._inference_obj is not None: | ||
| if self._parent not in self._inference_obj.posterior.data_vars.keys(): | ||
| self.model.predict(self._inference_obj, kind="mean", inplace=True) | ||
| # self.model.predict(self._inference_obj, kind="mean", inplace=True) |
There was a problem hiding this comment.
Should we remove debug comments?
| self._parent in self._inference_obj.posterior.data_vars.keys() | ||
| and "rt,response_mean" in self._inference_obj.posterior.data_vars.keys() |
There was a problem hiding this comment.
data_vars are dicts, so the Python 3 style is to not use keys()
| and not np.allclose(draws, idata["posterior"].draw.values) | ||
| ): | ||
| # Reassign posterior to sub-sampled version | ||
| setattr(idata_copy, "posterior", idata["posterior"].isel(draw=draws)) |
There was a problem hiding this comment.
Are there any differences between setattr() and idata.add_groups()?
There was a problem hiding this comment.
to be honest I don't know... let me look into that independently to understand it properly.
There was a problem hiding this comment.
Actually, at least used somewhat semantically here, add_groups is about new groups, setattr is about reassigning to pre-existing group.
| if safe_mode: | ||
| # safe mode splits the draws into chunks of 10 to avoid | ||
| # memory issues (TODO: Figure out the source of memory issues) | ||
| split_draws = _split_array( | ||
| idata_copy["posterior"].draw.values, divisor=10 | ||
| ) | ||
|
|
||
| posterior_predictive_list = [] | ||
| for samples_tmp in split_draws: | ||
| tmp_posterior = idata["posterior"].sel(draw=samples_tmp) | ||
| setattr(idata_copy, "posterior", tmp_posterior) | ||
| self.model.predict( | ||
| idata_copy, kind, data, True, include_group_specific | ||
| ) | ||
| posterior_predictive_list.append(idata_copy["posterior_predictive"]) | ||
|
|
||
| if inplace: | ||
| idata.add_groups( | ||
| posterior_predictive=xr.concat( | ||
| posterior_predictive_list, dim="draw" | ||
| ) | ||
| ) | ||
| # for inplace, we don't return anything | ||
| return None | ||
| else: | ||
| # Reassign original posterior to idata_copy | ||
| setattr(idata_copy, "posterior", idata["posterior"]) | ||
| # Add new posterior predictive group to idata_copy | ||
| del idata_copy["posterior_predictive"] | ||
| idata_copy.add_groups( | ||
| posterior_predictive=xr.concat( | ||
| posterior_predictive_list, dim="draw" | ||
| ) | ||
| ) | ||
| return idata_copy | ||
| elif inplace: | ||
| # If not safe-mode | ||
| # We call .predict() directly without any | ||
| # chunking of data. | ||
|
|
||
| # .predict() is called on the copy of idata | ||
| # since we still subsampled (or assigned) the draws | ||
| self.model.predict(idata_copy, kind, data, True, include_group_specific) | ||
|
|
||
| # posterior predictive group added to idata | ||
| idata.add_groups( | ||
| posterior_predictive=idata_copy["posterior_predictive"] | ||
| ) | ||
|
|
||
| # don't return anything if inplace | ||
| return None | ||
|
|
||
| else: | ||
| # Not safe mode and not inplace | ||
| # Function acts as very thin wrapper around | ||
| # .predict(). It just operates on the | ||
| # idata_copy object | ||
| return self.model.predict( | ||
| idata_copy, kind, data, inplace, include_group_specific | ||
| ) |
There was a problem hiding this comment.
This if block looks slightly confusing. I think I understand what you mean, but would
if safe_mode:
if inplace:
...
else:
...
else:
if inplace:
...
else:
...be more readable?
| idata_copy, kind, data, False, include_group_specific | ||
| idata, kind, data, inplace, include_group_specific | ||
| ) | ||
|
|
There was a problem hiding this comment.
Add an else clause here to throw an error whenever other values are specified?
| return var_names | ||
|
|
||
| def _drop_parent_str_from_idata( | ||
| self, idata: Union[az.InferenceData, None] |
There was a problem hiding this comment.
| self, idata: Union[az.InferenceData, None] | |
| self, idata: az.InferenceData | None |
safe_modethat chunks computationsn_samplesargument was renamed todraws, and one can pass None | int | list | np.ndarraykind='mean', the posterior naming cleans uprt,response_mean-->v.tracesnow, and naming is also cleaned upsample_prior_predictive()will include the parent parameter as well now via an internal call to.predict()