Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/hssm/distribution_utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,15 @@ def make_likelihood_callable(
A list of boolean values indicating whether the parameters are regression
parameters. Defaults to None.
params_only : Optional
Whether the missing data likelihood is takes its first argument as the data.
Defaults to None.
Controls the expected signature of the ``loglik`` callable.
If False (the default when None), the callable signature is
``f(data, *params)``, where ``data`` is a 2-column array of
[rt, choice]. This is the standard case for LANs and other
likelihoods that condition on observed data.
If True, the callable signature is ``f(*params)`` with no data
argument. This is used for Choice Probability Networks (CPNs)
and Outcome Probability Networks (OPNs).
Defaults to None (treated as False).
"""
if isinstance(loglik, pytensor.graph.Op):
return loglik
Expand Down
46 changes: 25 additions & 21 deletions src/hssm/distribution_utils/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,10 @@ def make_node(self, data, *dist_params):
"""
inputs = [pt.as_tensor_variable(dist_param) for dist_param in dist_params]
self.is_scalars_only = all(inp.ndim == 0 for inp in inputs)
# params_only means calculate gradients only with respect to the
# parameters, not the data.
self.is_params_only = data is not None
self.has_data = data is not None
self.n_params = n_params

if self.is_params_only:
if self.has_data:
inputs = [pt.as_tensor_variable(data)] + inputs

outputs = [pt.vector()]
Expand Down Expand Up @@ -109,7 +107,7 @@ def grad(self, inputs, output_gradients):
outputs `y`, and the gradient at `y` is grad(x), the required output
is y*grad(x).
"""
if self.is_params_only:
if self.has_data:
results = lan_logp_vjp_op(
inputs[0], *inputs[1:], gz=output_gradients[0]
)
Expand All @@ -121,13 +119,13 @@ def grad(self, inputs, output_gradients):

output = results

if self.is_params_only:
if self.has_data:
output = [
pytensor.gradient.grad_not_implemented(self, 0, inputs[0]),
] + output

if self.n_params is not None:
start_idx = self.n_params + 1 if self.is_params_only else 0
start_idx = self.n_params + 1 if self.has_data else 0
for i in range(start_idx, len(output)):
output[i] = pytensor.gradient.grad_undefined(self, i, inputs[i])

Expand All @@ -148,15 +146,15 @@ def make_node(self, data, *dist_params, gz):
dist_params:
A list of parameters used in the likelihood computation.
"""
self.is_params_only = data is not None
self.has_data = data is not None
self.is_scalars_only = gz is None
inputs = [pt.as_tensor_variable(dist_param) for dist_param in dist_params]
if self.is_params_only:
if self.has_data:
inputs = [pt.as_tensor_variable(data)] + inputs
if not self.is_scalars_only:
inputs += [pt.as_tensor_variable(gz)]

if self.is_params_only:
if self.has_data:
outputs = [inp.type() for inp in inputs[1:-1]]
else:
if self.is_scalars_only:
Expand All @@ -181,7 +179,7 @@ def perform(self, node, inputs, outputs):
output_storage. There is one storage cell for each output of
the Op.
"""
if self.is_params_only:
if self.has_data:
results = logp_vjp(*inputs[:-1], gz=inputs[-1])
else:
if self.is_scalars_only:
Expand Down Expand Up @@ -248,7 +246,14 @@ def make_jax_logp_funcs_from_callable(
Parameters that are regressions will not be vectorized in likelihood
calculations.
params_only:
If True, the log-likelihood function will only take parameters as input.
Controls the expected signature of the ``logp`` callable.
If False (default), the callable signature is ``f(data, *params)``,
where ``data`` is a 2-column array of [rt, choice]. This is the
standard case for LANs and other likelihoods that condition on
observed data.
If True, the callable signature is ``f(*params)`` with no data
argument. This is used for Choice Probability Networks (CPNs)
and Outcome Probability Networks (OPNs).
return_jit
If `True`, the function will return a JIT-compiled version of the vectorized
logp function, its VJP, and the non-jitted version of the logp function.
Expand All @@ -273,9 +278,6 @@ def make_jax_logp_funcs_from_callable(
"parameters are regressions."
)

print("params_only: ", params_only)
print("params_is_reg: ", params_is_reg)

# Looks silly but is required to please mypy.
if vmap and params_is_reg is not None:
in_axes: list[int | None] = [
Expand Down Expand Up @@ -328,12 +330,14 @@ def make_jax_single_trial_logp_from_network_forward(
jax_forward_fn : Callable
The JAX forward function to use for the log-likelihood computation.
params_only : bool, optional
Whether to compute the log-likelihood for only the parameters.
This will not assume a data part in the input.
`params_only = True` is appropriate for CPNs and OPNs,
where the data is not used in the log-likelihood computation.
`params_only = False` is appropriate for LANs,
where the data is used in the log-likelihood computation.
Controls the expected signature of the returned callable.
If False (default), the returned function expects
``(data, *params)`` where ``data`` is a 2-column array of
[rt, choice]. This is the standard case for LANs and other
likelihoods that condition on observed data.
If True, the returned function expects ``(*params)`` with no
data argument. This is used for Choice Probability Networks
(CPNs) and Outcome Probability Networks (OPNs).

Returns
-------
Expand Down
9 changes: 8 additions & 1 deletion src/hssm/distribution_utils/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,14 @@ def make_jax_logp_funcs_from_onnx(
Parameters that are regressions will not be vectorized in likelihood
calculations.
params_only:
If True, the log-likelihood function will only take parameters as input.
Controls the expected signature of the ``logp`` callable.
If False (default), the callable signature is ``f(data, *params)``,
where ``data`` is a 2-column array of [rt, choice]. This is the
standard case for LANs and other likelihoods that condition on
observed data.
If True, the callable signature is ``f(*params)`` with no data
argument. This is used for Choice Probability Networks (CPNs)
and Outcome Probability Networks (OPNs).
return_jit
If `True`, the function will return a JIT-compiled version of the vectorized
logp function, its VJP, and the non-jitted version of the logp function.
Expand Down