@@ -63,12 +63,10 @@ def make_node(self, data, *dist_params):
6363 """
6464 inputs = [pt .as_tensor_variable (dist_param ) for dist_param in dist_params ]
6565 self .is_scalars_only = all (inp .ndim == 0 for inp in inputs )
66- # params_only means calculate gradients only with respect to the
67- # parameters, not the data.
68- self .is_params_only = data is not None
66+ self .has_data = data is not None
6967 self .n_params = n_params
7068
71- if self .is_params_only :
69+ if self .has_data :
7270 inputs = [pt .as_tensor_variable (data )] + inputs
7371
7472 outputs = [pt .vector ()]
@@ -109,7 +107,7 @@ def grad(self, inputs, output_gradients):
109107 outputs `y`, and the gradient at `y` is grad(x), the required output
110108 is y*grad(x).
111109 """
112- if self .is_params_only :
110+ if self .has_data :
113111 results = lan_logp_vjp_op (
114112 inputs [0 ], * inputs [1 :], gz = output_gradients [0 ]
115113 )
@@ -121,13 +119,13 @@ def grad(self, inputs, output_gradients):
121119
122120 output = results
123121
124- if self .is_params_only :
122+ if self .has_data :
125123 output = [
126124 pytensor .gradient .grad_not_implemented (self , 0 , inputs [0 ]),
127125 ] + output
128126
129127 if self .n_params is not None :
130- start_idx = self .n_params + 1 if self .is_params_only else 0
128+ start_idx = self .n_params + 1 if self .has_data else 0
131129 for i in range (start_idx , len (output )):
132130 output [i ] = pytensor .gradient .grad_undefined (self , i , inputs [i ])
133131
@@ -148,15 +146,15 @@ def make_node(self, data, *dist_params, gz):
148146 dist_params:
149147 A list of parameters used in the likelihood computation.
150148 """
151- self .is_params_only = data is not None
149+ self .has_data = data is not None
152150 self .is_scalars_only = gz is None
153151 inputs = [pt .as_tensor_variable (dist_param ) for dist_param in dist_params ]
154- if self .is_params_only :
152+ if self .has_data :
155153 inputs = [pt .as_tensor_variable (data )] + inputs
156154 if not self .is_scalars_only :
157155 inputs += [pt .as_tensor_variable (gz )]
158156
159- if self .is_params_only :
157+ if self .has_data :
160158 outputs = [inp .type () for inp in inputs [1 :- 1 ]]
161159 else :
162160 if self .is_scalars_only :
@@ -181,7 +179,7 @@ def perform(self, node, inputs, outputs):
181179 output_storage. There is one storage cell for each output of
182180 the Op.
183181 """
184- if self .is_params_only :
182+ if self .has_data :
185183 results = logp_vjp (* inputs [:- 1 ], gz = inputs [- 1 ])
186184 else :
187185 if self .is_scalars_only :
@@ -248,7 +246,14 @@ def make_jax_logp_funcs_from_callable(
248246 Parameters that are regressions will not be vectorized in likelihood
249247 calculations.
250248 params_only:
251- If True, the log-likelihood function will only take parameters as input.
249+ Controls the expected signature of the ``logp`` callable.
250+ If False (default), the callable signature is ``f(data, *params)``,
251+ where ``data`` is a 2-column array of [rt, choice]. This is the
252+ standard case for LANs and other likelihoods that condition on
253+ observed data.
254+ If True, the callable signature is ``f(*params)`` with no data
255+ argument. This is used for Choice Probability Networks (CPNs)
256+ and Outcome Probability Networks (OPNs).
252257 return_jit
253258 If `True`, the function will return a JIT-compiled version of the vectorized
254259 logp function, its VJP, and the non-jitted version of the logp function.
@@ -273,9 +278,6 @@ def make_jax_logp_funcs_from_callable(
273278 "parameters are regressions."
274279 )
275280
276- print ("params_only: " , params_only )
277- print ("params_is_reg: " , params_is_reg )
278-
279281 # Looks silly but is required to please mypy.
280282 if vmap and params_is_reg is not None :
281283 in_axes : list [int | None ] = [
@@ -328,12 +330,14 @@ def make_jax_single_trial_logp_from_network_forward(
328330 jax_forward_fn : Callable
329331 The JAX forward function to use for the log-likelihood computation.
330332 params_only : bool, optional
331- Whether to compute the log-likelihood for only the parameters.
332- This will not assume a data part in the input.
333- `params_only = True` is appropriate for CPNs and OPNs,
334- where the data is not used in the log-likelihood computation.
335- `params_only = False` is appropriate for LANs,
336- where the data is used in the log-likelihood computation.
333+ Controls the expected signature of the returned callable.
334+ If False (default), the returned function expects
335+ ``(data, *params)`` where ``data`` is a 2-column array of
336+ [rt, choice]. This is the standard case for LANs and other
337+ likelihoods that condition on observed data.
338+ If True, the returned function expects ``(*params)`` with no
339+ data argument. This is used for Choice Probability Networks
340+ (CPNs) and Outcome Probability Networks (OPNs).
337341
338342 Returns
339343 -------
0 commit comments