Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/hssm/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
"lba2",
"racing_diffusion_3",
"poisson_race",
"softmax_inv_temperature_2",
"softmax_inv_temperature_3",
]


Expand Down
48 changes: 41 additions & 7 deletions src/hssm/likelihoods/analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
https://gist.github.com/sammosummo/c1be633a74937efaca5215da776f194b.
"""

from typing import Type

import jax.numpy as jnp
import numpy as np
import pymc as pm
Expand Down Expand Up @@ -399,14 +397,14 @@ def logp_ddm_sdv(
ddm_params = ["v", "a", "z", "t"]
ddm_sdv_params = ddm_params + ["sv"]

DDM: Type[pm.Distribution] = make_distribution(
DDM: type[pm.Distribution] = make_distribution(
rv="ddm",
loglik=logp_ddm,
list_params=ddm_params,
bounds=ddm_bounds,
)

DDM_SDV: Type[pm.Distribution] = make_distribution(
DDM_SDV: type[pm.Distribution] = make_distribution(
rv="ddm_sdv",
loglik=logp_ddm_sdv,
list_params=ddm_sdv_params,
Expand Down Expand Up @@ -541,7 +539,7 @@ def _print_message(_):
"t": (0.0, inf),
}

RDM3: Type[pm.Distribution] = make_distribution(
RDM3: type[pm.Distribution] = make_distribution(
rv="racing_diffusion_3",
loglik=logp_rdm3,
list_params=rdm3_params,
Expand Down Expand Up @@ -685,14 +683,14 @@ def logp_lba3(
"v2": (0.0, inf),
}

LBA2: Type[pm.Distribution] = make_distribution(
LBA2: type[pm.Distribution] = make_distribution(
rv="lba2",
loglik=logp_lba2,
list_params=lba2_params,
bounds=lba2_bounds,
)

LBA3: Type[pm.Distribution] = make_distribution(
LBA3: type[pm.Distribution] = make_distribution(
rv="lba3",
loglik=logp_lba3,
list_params=lba3_params,
Expand Down Expand Up @@ -806,3 +804,39 @@ def logp_poisson_race(
list_params=poisson_race_params,
bounds=poisson_race_bounds,
)


def softmax_inv_temperature(data: np.ndarray, beta: np.ndarray, *logits):
    """Compute the log-likelihood of the Inverse Softmax Temperature Model.

    Parameters
    ----------
    data
        1D array of responses (choices).
    beta
        The softmax inverse temperature in (0, inf); a scalar or a
        per-trial vector (it broadcasts against the stacked logits).
    *logits
        Logits for each choice excluding logit0, which is fixed at 0.

    Returns
    -------
    pt.TensorVariable
        The log likelihood of the Inverse Softmax Temperature Model,
        one value per trial.
    """
    # Map responses onto row indices: any code below 1 (e.g. -1 in binary
    # coding) selects choice 0; other codes index their own logit row.
    idx = pt.where(data < 1, 0.0, data).astype("int32")

    # A zero row doubles as logit0 and as a broadcasting helper that
    # expands scalar logits to per-trial vectors.
    base = pt.zeros_like(data, dtype=pytensor.config.floatX)
    rows = [base]
    for logit in logits:
        rows.append(logit + base)

    # Scale by the inverse temperature, then normalize over choices.
    scaled = pt.stack(rows) * beta
    picked = scaled[idx, pt.arange(data.shape[0])]
    return picked - pt.logsumexp(scaled, axis=0, keepdims=False)
59 changes: 59 additions & 0 deletions src/hssm/modelconfig/_softmax_inv_temperature_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Default configuration factory for the Inverse Softmax Temperature Model (any number of choices)."""

import numpy as np

from .._types import DefaultConfig, ParamSpec
from ..likelihoods.analytical import softmax_inv_temperature


def softmax_inv_temperature_config(n_choices: int = 2) -> DefaultConfig:
    """
    Get the default config for the Softmax Inv. Temperature Model.

    Parameters
    ----------
    n_choices : optional
        The number of choices in the model. Must be at least 2. The number of
        free logits is n_choices - 1 (logit0 is fixed). Default is 2.

    Returns
    -------
    DefaultConfig
        A dictionary containing the default configuration settings for the
        Inverse Softmax Temperature Model, including response variables, model
        parameters, choices, description, and likelihood specifications.
    """
    if n_choices < 2:
        raise ValueError("n_choices must be at least 2.")

    # Free logits are logit1 .. logit{n_choices-1}; logit0 is pinned to 0.
    logit_names = [f"logit{i}" for i in range(1, n_choices)]

    bounds = {"beta": (0.0, np.inf)}
    default_priors: dict[str, ParamSpec] = {
        "beta": {
            "name": "Gamma",
            "alpha": 2.0,
            "beta": 0.5,
        },
    }
    for name in logit_names:
        bounds[name] = (-np.inf, np.inf)
        default_priors[name] = {"name": "Normal", "mu": 0.0, "sigma": 1.0}

    # Binary models keep the conventional {-1, 1} coding; larger models
    # use 0-based integer codes.
    choice_codes = [-1, 1] if n_choices == 2 else list(range(n_choices))

    return {
        "response": ["response"],
        "list_params": ["beta"] + logit_names,
        "choices": choice_codes,
        "description": f"The Softmax Inv. Temperature Model with {n_choices} choices",
        "likelihoods": {
            "analytical": {
                "loglik": softmax_inv_temperature,
                "backend": None,
                "bounds": bounds,
                "default_priors": default_priors,
                "extra_fields": None,
            },
        },
    }
18 changes: 18 additions & 0 deletions src/hssm/modelconfig/softmax_inv_temperature_2_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Softmax Inverse Temperature Model with 2 logits configuration."""

from .._types import DefaultConfig
from ._softmax_inv_temperature_config import softmax_inv_temperature_config


def get_softmax_inv_temperature_2_config() -> DefaultConfig:
    """
    Get the default config for the Inverse Softmax Temperature Model with 2 logits.

    Returns
    -------
    DefaultConfig
        The shared softmax-inverse-temperature configuration specialized to
        two choices (a single free logit), including response variables,
        model parameters, choices, description, and likelihood specifications.
    """
    config = softmax_inv_temperature_config(n_choices=2)
    return config
18 changes: 18 additions & 0 deletions src/hssm/modelconfig/softmax_inv_temperature_3_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Softmax Inverse Temperature Model with 3 logits configuration."""

from .._types import DefaultConfig
from ._softmax_inv_temperature_config import softmax_inv_temperature_config


def get_softmax_inv_temperature_3_config() -> DefaultConfig:
    """
    Get the default config for the Inverse Softmax Temperature Model with 3 logits.

    Returns
    -------
    DefaultConfig
        The shared softmax-inverse-temperature configuration specialized to
        three choices (two free logits), including response variables,
        model parameters, choices, description, and likelihood specifications.
    """
    config = softmax_inv_temperature_config(n_choices=3)
    return config
45 changes: 45 additions & 0 deletions tests/test_likelihoods_choice_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

import numpy as np

from hssm.likelihoods.analytical import softmax_inv_temperature


# Number of trials in every synthetic dataset below.
_N = 10
# Seeded generator so the sampled responses are reproducible across runs.
_rng = np.random.default_rng(42)

# Responses: binary choices coded {-1, 1}, ternary choices coded {0, 1, 2}.
_DATA_BINARY = _rng.choice([-1, 1], size=_N).astype(np.float32)
_DATA_TERNARY = _rng.choice([0, 1, 2], size=_N).astype(np.float32)

# Inverse-temperature parameter, as a scalar and as a per-trial vector.
_SCALAR_BETA = np.float32(1.5)
_VECTOR_BETA = np.full(_N, 1.5, dtype=np.float32)

# Logit parameter, as a scalar and as a per-trial vector.
_SCALAR_LOGIT = np.float32(0.5)
_VECTOR_LOGIT = np.full(_N, 0.5, dtype=np.float32)


@pytest.mark.parametrize(
    "beta", [_SCALAR_BETA, _VECTOR_BETA], ids=["scalar_beta", "vector_beta"]
)
@pytest.mark.parametrize(
    "logit", [_SCALAR_LOGIT, _VECTOR_LOGIT], ids=["scalar_logit", "vector_logit"]
)
def test_softmax_inv_temperature_shape_2choice(beta, logit):
    """Binary-choice log-likelihood yields one value per trial."""
    loglik = softmax_inv_temperature(_DATA_BINARY, beta, logit)
    assert loglik.eval().shape == (_N,)


@pytest.mark.parametrize(
    "beta", [_SCALAR_BETA, _VECTOR_BETA], ids=["scalar_beta", "vector_beta"]
)
@pytest.mark.parametrize(
    "logit1", [_SCALAR_LOGIT, _VECTOR_LOGIT], ids=["scalar_logit1", "vector_logit1"]
)
@pytest.mark.parametrize(
    "logit2", [_SCALAR_LOGIT, _VECTOR_LOGIT], ids=["scalar_logit2", "vector_logit2"]
)
def test_softmax_inv_temperature_shape_3choice(beta, logit1, logit2):
    """Ternary-choice log-likelihood yields one value per trial."""
    loglik = softmax_inv_temperature(_DATA_TERNARY, beta, logit1, logit2)
    assert loglik.eval().shape == (_N,)
71 changes: 71 additions & 0 deletions tests/test_modelconfig.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np
import pytest

from hssm.modelconfig import get_default_model_config
from hssm.modelconfig._softmax_inv_temperature_config import (
softmax_inv_temperature_config,
)
import hssm


Expand Down Expand Up @@ -108,3 +112,70 @@ def test_load_all_supported_model_configs(model):
def test_get_default_model_config_invalid():
with pytest.raises(ValueError):
get_default_model_config("invalid_model")


def test_softmax_inv_temperature_default():
    """The default (n_choices=2) config exposes one free logit and matching metadata."""
    config = softmax_inv_temperature_config()

    expected_beta_prior = {"name": "Gamma", "alpha": 2.0, "beta": 0.5}
    expected_logit_prior = {"name": "Normal", "mu": 0.0, "sigma": 1.0}

    assert config["response"] == ["response"]
    assert config["choices"] == [-1, 1]
    assert config["list_params"] == ["beta", "logit1"]
    assert config["description"] == "The Softmax Inv. Temperature Model with 2 choices"

    analytical = config["likelihoods"]["analytical"]
    assert analytical["backend"] is None
    assert analytical["extra_fields"] is None

    # Bounds: beta is positive-only, the free logit is unbounded.
    assert analytical["bounds"]["beta"] == (0.0, np.inf)
    assert analytical["bounds"]["logit1"] == (-np.inf, np.inf)

    # Default priors.
    assert analytical["default_priors"]["beta"] == expected_beta_prior
    assert analytical["default_priors"]["logit1"] == expected_logit_prior


def test_softmax_inv_temperature_3_choices():
    """Test softmax_inv_temperature_config with n_choices=3."""
    config = softmax_inv_temperature_config(n_choices=3)

    assert config["response"] == ["response"]
    assert config["choices"] == [0, 1, 2]
    assert config["list_params"] == ["beta", "logit1", "logit2"]
    assert config["description"] == "The Softmax Inv. Temperature Model with 3 choices"

    likelihoods = config["likelihoods"]
    lk_analytical = likelihoods["analytical"]

    # Consistent with the 2-choice test: these fields should be unset here too.
    assert lk_analytical["backend"] is None
    assert lk_analytical["extra_fields"] is None

    # Test bounds
    assert lk_analytical["bounds"]["beta"] == (0.0, np.inf)
    for logit in ("logit1", "logit2"):
        assert lk_analytical["bounds"][logit] == (-np.inf, np.inf)

    # Test default priors
    assert lk_analytical["default_priors"]["beta"] == {
        "name": "Gamma",
        "alpha": 2.0,
        "beta": 0.5,
    }
    for logit in ("logit1", "logit2"):
        assert lk_analytical["default_priors"][logit] == {
            "name": "Normal",
            "mu": 0.0,
            "sigma": 1.0,
        }