Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions src/hssm/likelihoods/analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
https://gist.github.com/sammosummo/c1be633a74937efaca5215da776f194b.
"""

from typing import Type

import numpy as np
import pymc as pm
import pytensor
Expand Down Expand Up @@ -395,14 +393,14 @@ def logp_ddm_sdv(
ddm_params = ["v", "a", "z", "t"]
ddm_sdv_params = ddm_params + ["sv"]

DDM: Type[pm.Distribution] = make_distribution(
DDM: type[pm.Distribution] = make_distribution(
rv="ddm",
loglik=logp_ddm,
list_params=ddm_params,
bounds=ddm_bounds,
)

DDM_SDV: Type[pm.Distribution] = make_distribution(
DDM_SDV: type[pm.Distribution] = make_distribution(
rv="ddm_sdv",
loglik=logp_ddm_sdv,
list_params=ddm_sdv_params,
Expand Down Expand Up @@ -546,16 +544,52 @@ def logp_lba3(
"v2": (0.0, inf),
}

LBA2: Type[pm.Distribution] = make_distribution(
LBA2: type[pm.Distribution] = make_distribution(
rv="lba2",
loglik=logp_lba2,
list_params=lba2_params,
bounds=lba2_bounds,
)

LBA3: Type[pm.Distribution] = make_distribution(
LBA3: type[pm.Distribution] = make_distribution(
rv="lba3",
loglik=logp_lba3,
list_params=lba3_params,
bounds=lba3_bounds,
)


def softmax_inv_temperature(data: np.ndarray, beta: np.ndarray, *logits):
    """Compute the log-likelihood of the Inverse Softmax Temperature Model.

    Parameters
    ----------
    data
        1D array of responses (choices). Values below 1 (e.g. -1 or 0) are
        treated as choice 0; other values are used as the choice index.
    beta
        The inverse softmax temperature, (0, inf). A scalar or an array
        broadcastable against the trial dimension of ``data``.
    *logits
        Logits for each choice excluding logit0 (the reference choice,
        whose logit is fixed at 0). Each may be a scalar or broadcastable
        against ``data``.

    Returns
    -------
    pt.TensorVariable
        The log likelihood of the Inverse Softmax Temperature Model,
        one value per trial.
    """
    # Map responses onto row indices: anything below 1 becomes choice 0.
    idx = pt.where(data < 1, 0.0, data).astype("int32")

    # A zero vector over trials; adding it to each logit broadcasts scalar
    # logits up to the trial dimension so all rows stack cleanly.
    base = pt.zeros_like(data, dtype=pytensor.config.floatX)
    rows = [base]  # row 0: the reference choice, logit fixed at 0
    for logit in logits:
        rows.append(logit + base)

    # Shape (n_choices, n_trials), scaled by the inverse temperature.
    scaled = pt.stack(rows) * beta

    # Pick each trial's chosen-row logit, then normalize over choices.
    trial_idx = pt.arange(data.shape[0])
    picked = scaled[idx, trial_idx]
    return picked - pt.logsumexp(scaled, axis=0, keepdims=False)
45 changes: 45 additions & 0 deletions tests/test_likelihoods_choice_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

import numpy as np

from hssm.likelihoods.analytical import softmax_inv_temperature


# Number of trials per synthetic dataset.
_N = 10
# Fixed seed so the generated choices are reproducible across test runs.
_rng = np.random.default_rng(42)

# Binary responses coded -1/1 and ternary responses coded 0/1/2, stored as
# float32 to mirror how response data arrives in the likelihood.
_DATA_BINARY = _rng.choice([-1, 1], size=_N).astype(np.float32)
_DATA_TERNARY = _rng.choice([0, 1, 2], size=_N).astype(np.float32)

# Inverse-temperature parameter: exercised both as a scalar and as a
# per-trial vector to check broadcasting in the likelihood.
_SCALAR_BETA = np.float32(1.5)
_VECTOR_BETA = np.full(_N, 1.5, dtype=np.float32)

# Choice logits, likewise in scalar and per-trial vector form.
_SCALAR_LOGIT = np.float32(0.5)
_VECTOR_LOGIT = np.full(_N, 0.5, dtype=np.float32)


@pytest.mark.parametrize(
    "beta", [_SCALAR_BETA, _VECTOR_BETA], ids=["scalar_beta", "vector_beta"]
)
@pytest.mark.parametrize(
    "logit", [_SCALAR_LOGIT, _VECTOR_LOGIT], ids=["scalar_logit", "vector_logit"]
)
def test_softmax_inv_temperature_shape_2choice(beta, logit):
    """A two-choice likelihood yields one log-probability per trial."""
    loglik = softmax_inv_temperature(_DATA_BINARY, beta, logit)
    assert loglik.eval().shape == (_N,)


@pytest.mark.parametrize(
    "beta", [_SCALAR_BETA, _VECTOR_BETA], ids=["scalar_beta", "vector_beta"]
)
@pytest.mark.parametrize(
    "logit1", [_SCALAR_LOGIT, _VECTOR_LOGIT], ids=["scalar_logit1", "vector_logit1"]
)
@pytest.mark.parametrize(
    "logit2", [_SCALAR_LOGIT, _VECTOR_LOGIT], ids=["scalar_logit2", "vector_logit2"]
)
def test_softmax_inv_temperature_shape_3choice(beta, logit1, logit2):
    """A three-choice likelihood yields one log-probability per trial."""
    loglik = softmax_inv_temperature(_DATA_TERNARY, beta, logit1, logit2)
    assert loglik.eval().shape == (_N,)