Skip to content

Commit c12bc20

Browse files
improve naming to make logic clearer (#902)
1 parent 8cc9478 commit c12bc20

3 files changed

Lines changed: 42 additions & 24 deletions

File tree

src/hssm/distribution_utils/dist.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -719,8 +719,15 @@ def make_likelihood_callable(
719719
A list of boolean values indicating whether the parameters are regression
720720
parameters. Defaults to None.
721721
params_only : Optional
722-
Whether the missing data likelihood is takes its first argument as the data.
723-
Defaults to None.
722+
Controls the expected signature of the ``loglik`` callable.
723+
If False (the default when None), the callable signature is
724+
``f(data, *params)``, where ``data`` is a 2-column array of
725+
[rt, choice]. This is the standard case for LANs and other
726+
likelihoods that condition on observed data.
727+
If True, the callable signature is ``f(*params)`` with no data
728+
argument. This is used for Choice Probability Networks (CPNs)
729+
and Outcome Probability Networks (OPNs).
730+
Defaults to None (treated as False).
724731
"""
725732
if isinstance(loglik, pytensor.graph.Op):
726733
return loglik

src/hssm/distribution_utils/jax.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,10 @@ def make_node(self, data, *dist_params):
6363
"""
6464
inputs = [pt.as_tensor_variable(dist_param) for dist_param in dist_params]
6565
self.is_scalars_only = all(inp.ndim == 0 for inp in inputs)
66-
# params_only means calculate gradients only with respect to the
67-
# parameters, not the data.
68-
self.is_params_only = data is not None
66+
self.has_data = data is not None
6967
self.n_params = n_params
7068

71-
if self.is_params_only:
69+
if self.has_data:
7270
inputs = [pt.as_tensor_variable(data)] + inputs
7371

7472
outputs = [pt.vector()]
@@ -109,7 +107,7 @@ def grad(self, inputs, output_gradients):
109107
outputs `y`, and the gradient at `y` is grad(x), the required output
110108
is y*grad(x).
111109
"""
112-
if self.is_params_only:
110+
if self.has_data:
113111
results = lan_logp_vjp_op(
114112
inputs[0], *inputs[1:], gz=output_gradients[0]
115113
)
@@ -121,13 +119,13 @@ def grad(self, inputs, output_gradients):
121119

122120
output = results
123121

124-
if self.is_params_only:
122+
if self.has_data:
125123
output = [
126124
pytensor.gradient.grad_not_implemented(self, 0, inputs[0]),
127125
] + output
128126

129127
if self.n_params is not None:
130-
start_idx = self.n_params + 1 if self.is_params_only else 0
128+
start_idx = self.n_params + 1 if self.has_data else 0
131129
for i in range(start_idx, len(output)):
132130
output[i] = pytensor.gradient.grad_undefined(self, i, inputs[i])
133131

@@ -148,15 +146,15 @@ def make_node(self, data, *dist_params, gz):
148146
dist_params:
149147
A list of parameters used in the likelihood computation.
150148
"""
151-
self.is_params_only = data is not None
149+
self.has_data = data is not None
152150
self.is_scalars_only = gz is None
153151
inputs = [pt.as_tensor_variable(dist_param) for dist_param in dist_params]
154-
if self.is_params_only:
152+
if self.has_data:
155153
inputs = [pt.as_tensor_variable(data)] + inputs
156154
if not self.is_scalars_only:
157155
inputs += [pt.as_tensor_variable(gz)]
158156

159-
if self.is_params_only:
157+
if self.has_data:
160158
outputs = [inp.type() for inp in inputs[1:-1]]
161159
else:
162160
if self.is_scalars_only:
@@ -181,7 +179,7 @@ def perform(self, node, inputs, outputs):
181179
output_storage. There is one storage cell for each output of
182180
the Op.
183181
"""
184-
if self.is_params_only:
182+
if self.has_data:
185183
results = logp_vjp(*inputs[:-1], gz=inputs[-1])
186184
else:
187185
if self.is_scalars_only:
@@ -248,7 +246,14 @@ def make_jax_logp_funcs_from_callable(
248246
Parameters that are regressions will not be vectorized in likelihood
249247
calculations.
250248
params_only:
251-
If True, the log-likelihood function will only take parameters as input.
249+
Controls the expected signature of the ``logp`` callable.
250+
If False (default), the callable signature is ``f(data, *params)``,
251+
where ``data`` is a 2-column array of [rt, choice]. This is the
252+
standard case for LANs and other likelihoods that condition on
253+
observed data.
254+
If True, the callable signature is ``f(*params)`` with no data
255+
argument. This is used for Choice Probability Networks (CPNs)
256+
and Outcome Probability Networks (OPNs).
252257
return_jit
253258
If `True`, the function will return a JIT-compiled version of the vectorized
254259
logp function, its VJP, and the non-jitted version of the logp function.
@@ -273,9 +278,6 @@ def make_jax_logp_funcs_from_callable(
273278
"parameters are regressions."
274279
)
275280

276-
print("params_only: ", params_only)
277-
print("params_is_reg: ", params_is_reg)
278-
279281
# Looks silly but is required to please mypy.
280282
if vmap and params_is_reg is not None:
281283
in_axes: list[int | None] = [
@@ -328,12 +330,14 @@ def make_jax_single_trial_logp_from_network_forward(
328330
jax_forward_fn : Callable
329331
The JAX forward function to use for the log-likelihood computation.
330332
params_only : bool, optional
331-
Whether to compute the log-likelihood for only the parameters.
332-
This will not assume a data part in the input.
333-
`params_only = True` is appropriate for CPNs and OPNs,
334-
where the data is not used in the log-likelihood computation.
335-
`params_only = False` is appropriate for LANs,
336-
where the data is used in the log-likelihood computation.
333+
Controls the expected signature of the returned callable.
334+
If False (default), the returned function expects
335+
``(data, *params)`` where ``data`` is a 2-column array of
336+
[rt, choice]. This is the standard case for LANs and other
337+
likelihoods that condition on observed data.
338+
If True, the returned function expects ``(*params)`` with no
339+
data argument. This is used for Choice Probability Networks
340+
(CPNs) and Outcome Probability Networks (OPNs).
337341
338342
Returns
339343
-------

src/hssm/distribution_utils/onnx.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,14 @@ def make_jax_logp_funcs_from_onnx(
5353
Parameters that are regressions will not be vectorized in likelihood
5454
calculations.
5555
params_only:
56-
If True, the log-likelihood function will only take parameters as input.
56+
Controls the expected signature of the ``logp`` callable.
57+
If False (default), the callable signature is ``f(data, *params)``,
58+
where ``data`` is a 2-column array of [rt, choice]. This is the
59+
standard case for LANs and other likelihoods that condition on
60+
observed data.
61+
If True, the callable signature is ``f(*params)`` with no data
62+
argument. This is used for Choice Probability Networks (CPNs)
63+
and Outcome Probability Networks (OPNs).
5764
return_jit
5865
If `True`, the function will return a JIT-compiled version of the vectorized
5966
logp function, its VJP, and the non-jitted version of the logp function.

0 commit comments

Comments
 (0)