Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 45 additions & 22 deletions docs/tutorials/main_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@
"source": [
"If we wish to simulate from another model, we can do so by changing the `model` string.\n",
"\n",
"The number of models we can simulate differs from the number of models for which we have likelihoods available (both will increase over time). To get the models for which likelihood functions are supplied out of the box, we should check the `SupportedModels` under `hssm.defaults`."
"The number of models we can simulate differs from the number of models for which we have likelihoods available (both will increase over time). To get the models for which likelihood functions are supplied out of the box, we should inspect `hssm.HSSM.supported_models`."
]
},
{
Expand All @@ -427,59 +427,82 @@
{
"data": {
"text/plain": [
"typing.Literal['ddm', 'ddm_sdv', 'full_ddm', 'angle', 'levy', 'ornstein', 'weibull', 'race_no_bias_angle_4', 'ddm_seq2_no_bias', 'lba3', 'lba2']"
"('ddm',\n",
" 'ddm_sdv',\n",
" 'full_ddm',\n",
" 'angle',\n",
" 'levy',\n",
" 'ornstein',\n",
" 'weibull',\n",
" 'race_no_bias_angle_4',\n",
" 'ddm_seq2_no_bias',\n",
" 'lba3',\n",
" 'lba2')"
]
},
"execution_count": 30,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hssm.defaults.SupportedModels"
"hssm.HSSM.supported_models"
]
},
{
"cell_type": "markdown",
"id": "43703bd7-f50c-40eb-a130-2070b160a8ec",
"metadata": {},
"source": [
"If we wish to check more detailed information about a given model, we can use the `default_model_config` under `hssm.default`.\n",
"\n",
"Let's look at the `ddm`:"
"If we wish to check more detailed information about a given supported model, we can use accessors `get_<model_name>_config` under `hssm.default`. For example, we inspect `ddm` model metada below."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "1a526455-37e4-428c-a478-43ef387e496e",
"execution_count": 6,
"id": "04fd49b7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'response': ['rt', 'response'],\n",
" 'list_params': ['v', 'a', 'z', 't', 'theta'],\n",
" 'list_params': ['v', 'a', 'z', 't'],\n",
" 'choices': [-1, 1],\n",
" 'description': None,\n",
" 'likelihoods': {'approx_differentiable': {'loglik': 'angle.onnx',\n",
" 'description': 'The Drift Diffusion Model (DDM)',\n",
" 'likelihoods': {'analytical': {'loglik': <function hssm.likelihoods.analytical.logp_ddm(data: numpy.ndarray, v: float, a: float, z: float, t: float, err: float = 1e-15, k_terms: int = 20, epsilon: float = 1e-15) -> numpy.ndarray>,\n",
" 'backend': None,\n",
" 'bounds': {'v': (-inf, inf),\n",
" 'a': (0.0, inf),\n",
" 'z': (0.0, 1.0),\n",
" 't': (0.0, inf)},\n",
" 'default_priors': {'t': {'name': 'HalfNormal', 'sigma': 2.0}},\n",
" 'extra_fields': None},\n",
" 'approx_differentiable': {'loglik': 'ddm.onnx',\n",
" 'backend': 'jax',\n",
" 'default_priors': {},\n",
" 'default_priors': {'t': {'name': 'HalfNormal', 'sigma': 2.0}},\n",
" 'bounds': {'v': (-3.0, 3.0),\n",
" 'a': (0.3, 3.0),\n",
" 'z': (0.1, 0.9),\n",
" 't': (0.001, 2.0),\n",
" 'theta': (-0.1, 1.3)},\n",
" 'a': (0.3, 2.5),\n",
" 'z': (0.0, 1.0),\n",
" 't': (0.0, 2.0)},\n",
" 'extra_fields': None},\n",
" 'blackbox': {'loglik': <function hssm.likelihoods.blackbox.hddm_to_hssm.<locals>.outer(data: numpy.ndarray, *args, **kwargs)>,\n",
" 'backend': None,\n",
" 'bounds': {'v': (-inf, inf),\n",
" 'a': (0.0, inf),\n",
" 'z': (0.0, 1.0),\n",
" 't': (0.0, inf)},\n",
" 'default_priors': {'t': {'name': 'HalfNormal', 'sigma': 2.0}},\n",
" 'extra_fields': None}}}"
]
},
"execution_count": 31,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hssm.defaults.default_model_config[\"angle\"]"
"hssm.defaults.get_ddm_config()"
]
},
{
Expand Down Expand Up @@ -6671,7 +6694,7 @@
"array([<Axes: title={'center': '94.0% HDI'}>], dtype=object)"
]
},
"execution_count": 49,
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
},
Expand Down Expand Up @@ -6713,7 +6736,7 @@
"array([<Axes: title={'center': '94.0% HDI'}>], dtype=object)"
]
},
"execution_count": 50,
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
},
Expand Down Expand Up @@ -7145,7 +7168,7 @@
"Lapse distribution: Uniform(lower: 0.0, upper: 20.0)"
]
},
"execution_count": 58,
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
Expand Down
42 changes: 41 additions & 1 deletion src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from inspect import isclass, signature
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union, cast
from typing import Any, Callable, Literal, Optional, Union, cast, get_args

import arviz as az
import bambi as bmb
Expand Down Expand Up @@ -61,6 +61,34 @@
_logger = logging.getLogger("hssm")


class classproperty:
"""A decorator that combines the behavior of @property and @classmethod.

This decorator allows you to define a property that can be accessed on the class
itself, rather than on instances of the class. It is useful for defining class-level
properties that need to perform some computation or access class-level data.

This implementation is provided for compatibility with Python versions 3.10 through
3.12, as one cannot combine the @property and @classmethod decorators is across all
these versions.

Example
-------
class MyClass:
@classproperty
def my_class_property(cls):
return "This is a class property"

print(MyClass.my_class_property) # Output: This is a class property
"""

def __init__(self, fget):
self.fget = fget

def __get__(self, instance, owner): # noqa: D105
return self.fget(owner)


class HSSM:
"""The basic Hierarchical Sequential Sampling Model (HSSM) class.

Expand Down Expand Up @@ -327,6 +355,7 @@ def __init__(
# Model config is not provided, but at this point was constructed from
# defaults.
if model not in typing.get_args(SupportedModels):
# TODO: ideally use self.supported_models above but mypy doesn't like it
if choices is not None:
self.model_config.update_choices(choices)
elif model in ssms_model_config:
Expand Down Expand Up @@ -480,6 +509,17 @@ def __init__(
)
_logger.info("Model initialized successfully.")

@classproperty
def supported_models(cls) -> tuple[SupportedModels, ...]:
"""Get a tuple of all supported models.

Returns
-------
tuple[SupportedModels, ...]
A tuple containing all supported model names.
"""
return get_args(SupportedModels)

@classmethod
def _store_init_args(cls, *args, **kwargs):
"""Store initialization arguments using signature binding."""
Expand Down
Loading