Fix nan gradients in analytical likelihood #468
AlexanderFengler
left a comment
There was a problem hiding this comment.
Looks good. My comments are mostly about iterating conceptually, not about code quality.
| LOGP_LB, | ||
| tt = negative_rt * epsilon + (1 - negative_rt) * rt | ||
|
|
||
| p = pt.maximum(ftt01w(tt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB)) |
There was a problem hiding this comment.
Quick note: it seems we are only passing k_terms here, not actually computing it.
I think we agreed to do it that way in an earlier attempt at fixing issues with this likelihood, and I think that's fine, but in that case the default should be a bit higher than 7.
There was a problem hiding this comment.
Just playing around here. Not actually changing
| - (v_flipped**2 * rt / 2.0) | ||
| - 2.0 * pt.log(a), | ||
| - (v_flipped**2 * tt / 2.0) | ||
| - 2.0 * pt.log(pt.maximum(epsilon, a)) |
There was a problem hiding this comment.
reflecting on this a bit,
I think this maximum business is actually corrupting the gradients, so we should just a priori restrict a > epsilon (via prior essentially?).
There was a problem hiding this comment.
on the other hand, apart from initialization (which 1. our strategies should already avoid, 2. we generally can impact) a should basically never come close to 0, so this should basically never be the culprit...
There was a problem hiding this comment.
But this did help a bit, for some reason...
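To illustrate the concern in this thread, here is a minimal NumPy sketch (not the PR's PyTensor code) of how clamping with a maximum flattens the gradient: below the threshold the clamped log no longer depends on a, so its derivative is zero instead of 1/a. The threshold value and the helper names `clamped_log` and `central_diff` are assumptions made up for this example.

```python
import numpy as np

EPSILON = 1e-3  # assumed threshold, chosen only for this example


def clamped_log(a, epsilon=EPSILON):
    """log(max(epsilon, a)), the clamped form tried in the diff above."""
    return np.log(np.maximum(epsilon, a))


def central_diff(f, x, h=1e-6):
    """Central finite-difference estimate of df/dx at x."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


# Below the threshold the clamp wins, so the function is flat in `a`.
grad_below = central_diff(clamped_log, 1e-4)
# Above the threshold the derivative is the usual 1 / a.
grad_above = central_diff(clamped_log, 1.0)

print(grad_below, grad_above)
```

So the clamp avoids log(0), but any sample with a below the threshold contributes no gradient signal for a from this term, which is one way to read "corrupting the gradients".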
| - 2.0 * pt.log(pt.maximum(epsilon, a)) | ||
| ) | ||
|
|
||
| checked_logp = check_parameters(logp, a >= 0, msg="a >= 0") |
There was a problem hiding this comment.
in the spirit of above, this check could be a>0 but honestly we shouldn't really ever get there.
There was a problem hiding this comment.
Same as above
| err: float = 1e-15, | ||
| k_terms: int = 20, | ||
| k_terms: int = 7, | ||
| epsilon: float = 1e-15, |
There was a problem hiding this comment.
I don't know what was used for testing / is used as actual value for inference, but I guess it is this default?
The epsilon for the rt part should rather be on the order of 1e-3, or even 1e-2.
If we are reusing the same epsilon in multiple places, we should probably separate it out.
There was a problem hiding this comment.
Was playing around. It seems that changing k_terms to 7 did not improve speed or computation
src/hssm/likelihoods/analytical.py
Outdated
| rt = rt - t | ||
|
|
||
| p = pt.maximum(ftt01w(rt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB)) | ||
| negative_rt = rt <= epsilon |
There was a problem hiding this comment.
Ok reflecting on this a bit, the logic that we want should probably look something like:
- flag all rts lower than epsilon
- go through with ftt01w
- then set all flagged rts to LOGP_LB
This should actually cut the gradient for problematic rts.
Potentially we put this as a logp_ddm_2 and compare results / gradients.
Alternatively, if any rt breaches epsilon, directly send logp to -infty (this is probably not preferable).
There was a problem hiding this comment.
We were doing this. I think the problem is that the gradient is computed anyway and the over/underflow was still happening
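The point in the reply above can be sketched with plain NumPy and a hand-written chain rule (an assumption standing in for what reverse-mode autodiff does; this is not the PR's PyTensor code). Masking the output leaves the forward value fine, but the flagged entry's local derivative is still evaluated and multiplied by a zero weight, and 0 * inf is nan. Replacing the flagged inputs before the computation keeps everything finite. `toy_logp`, `toy_dlogp`, the threshold, and the `LOGP_LB` value are all hypothetical stand-ins.

```python
import numpy as np

EPSILON = 1e-3   # assumed threshold for "too small" rts
LOGP_LB = -66.1  # assumed lower bound on the log-density


def toy_logp(rt):
    """Stand-in for a log-density with rt in a denominator."""
    return -1.0 / rt


def toy_dlogp(rt):
    """Analytic derivative of toy_logp with respect to rt."""
    return 1.0 / rt**2


def mask_output_only(rt):
    """Flag small rts, compute on the raw rts, then overwrite flagged outputs."""
    flagged = rt <= EPSILON
    keep = 1.0 - flagged.astype(float)  # 0 for flagged entries, 1 otherwise
    with np.errstate(divide="ignore", invalid="ignore"):
        value = np.where(flagged, LOGP_LB, toy_logp(rt))
        # Chain rule: zero weight times the local derivative at the raw rt.
        grad = keep * toy_dlogp(rt)
    return value, grad


def mask_input_first(rt):
    """Replace flagged rts with EPSILON before computing anything."""
    flagged = rt <= EPSILON
    keep = 1.0 - flagged.astype(float)
    tt = np.where(flagged, EPSILON, rt)
    value = np.where(flagged, LOGP_LB, toy_logp(tt))
    grad = keep * toy_dlogp(tt)
    return value, grad


rt = np.array([0.0, 0.5, 1.2])
print(mask_output_only(rt))  # same values, but nan gradient at rt = 0
print(mask_input_first(rt))  # same values, finite gradient everywhere
```

Both versions return the same forward values; only the gradient differs. The `tt = negative_rt * epsilon + (1 - negative_rt) * rt` line in this PR plays the role of `mask_input_first`.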
|
@digicosmos86 is this stale for now? |
There doesn't seem to be a solution for really small RTs in the denominator, which can blow up |
|
Ah well, maybe this is related to the "RT hack" I proposed as an interim solution for cases where *t* is low. (I had attributed the problem to the sampler proposing values of t that hit the lower bound and lead to unstable gradients, but I know that pymc is supposed to deal with such boundaries smoothly under the hood, so maybe the issue is just the small RTs in the denominator.) In that case the RT hack would still work: under the hood, before fitting, add a constant value (say 0.5) to all RTs, which should only shift the t parameter, and then report t_new = t - 0.5.
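A minimal sketch of the bookkeeping behind the RT hack, assuming (as in the `rt = rt - t` line of the diff) that the likelihood only sees the difference between RT and t. The helper names `shift_rts` and `unshift_t` are hypothetical and not part of HSSM. The sketch only shows that the shift cancels in rt - t; it does not by itself explain why the shift stabilizes sampling.

```python
import numpy as np

SHIFT = 0.5  # the constant suggested in the comment above


def shift_rts(rt, shift=SHIFT):
    """Add a constant to every RT before fitting."""
    return np.asarray(rt, dtype=float) + shift


def unshift_t(t_fitted, shift=SHIFT):
    """Map the non-decision time fitted on shifted RTs back to the original scale."""
    return t_fitted - shift


rt = np.array([0.31, 0.42, 0.90])
t_true = 0.30

rt_shifted = shift_rts(rt)
# What a fit on the shifted data would ideally recover for t.
t_fitted = t_true + SHIFT

# The quantity the likelihood sees is the same on both scales.
print(rt - t_true)
print(rt_shifted - t_fitted)
print(unshift_t(t_fitted))  # reported value, back on the original scale
```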
|
|
@frankmj I ran a few more tests and the RT-hack did do the trick. It might be hard for us to implement this trick in our code though, mostly because people use arviz functions instead of the convenience functions that we provide, which could give us some control over the output. We could note this trick somewhere in our documentation so that users can implement it themselves and keep full control.
|
Great. Maybe one simple solution would be to add a link function with t = t' - const?
|
|
@frankmj That's a great idea! I also noticed that the RT-hack only worked when |
cpaniaguam
left a comment
There was a problem hiding this comment.
A few things here I might end up picking up myself.
| _a = 2 * pt.sqrt(2 * np.pi * rt) * err < 1 | ||
| _b = 2 + pt.sqrt(-2 * rt * pt.log(2 * pt.sqrt(2 * np.pi * rt) * err)) | ||
| _c = pt.sqrt(rt) + 1 |
There was a problem hiding this comment.
The fundamental operation is pt.sqrt(rt). It's better to do this first and reuse the result to avoid computing it again.
There was a problem hiding this comment.
For numerical stability, it's better to group the constant factor C = 2 * pt.sqrt(2 * np.pi) * err and compare each member of sqrt_rt = pt.sqrt(rt) against 1/C.
There was a problem hiding this comment.
Sure! Feel free to change this
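A NumPy sketch of the two suggestions in this thread: compute the square root once, and fold the constants into a single factor C so that the comparison becomes sqrt(rt) < 1/C. The function names are made up for the example, and `err` is set to 0.1 (rather than the 1e-15 default in the diff) purely so that the mask is not trivially all True.

```python
import numpy as np


def mask_as_written(rt, err):
    """The condition as it appears in the diff."""
    return 2 * np.sqrt(2 * np.pi * rt) * err < 1


def mask_grouped(rt, err):
    """Same condition with sqrt(rt) computed once and the constants grouped."""
    sqrt_rt = np.sqrt(rt)                # reusable in the later expressions
    c = 2 * np.sqrt(2 * np.pi) * err     # constant factor C
    return sqrt_rt < 1.0 / c


rt = np.array([0.1, 1.0, 3.0, 5.0, 10.0])
err = 0.1  # illustrative value only

print(mask_as_written(rt, err))
print(mask_grouped(rt, err))
```

The two masks agree away from the boundary; exactly at the boundary they could differ by floating-point rounding, which is worth keeping in mind when testing.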
| ks = 2 + pt.sqrt(-2 * rt * pt.log(2 * np.sqrt(2 * np.pi * rt) * err)) | ||
| ks = pt.max(pt.stack([ks, pt.sqrt(rt) + 1]), axis=0) | ||
| ks = pt.switch(2 * pt.sqrt(2 * np.pi * rt) * err < 1, ks, 2) | ||
| _a = 2 * pt.sqrt(2 * np.pi * rt) * err < 1 |
There was a problem hiding this comment.
What would a better name for this boolean array be, maybe mask or sieve?
There was a problem hiding this comment.
Should pt.lt be used here as done elsewhere in this PR?
There was a problem hiding this comment.
It's actually equivalent but I was just playing around
| _b = 2 + pt.sqrt(-2 * rt * pt.log(2 * pt.sqrt(2 * np.pi * rt) * err)) | ||
| _c = pt.sqrt(rt) + 1 | ||
| _d = pt.max(pt.stack([_b, _c]), axis=0) | ||
| ks = _a * _d + (1 - _a) * 2 |
There was a problem hiding this comment.
Because _a is boolean, I think it's better to treat it as such and use pt.switch.
| ks = _a * _d + (1 - _a) * 2 | |
| ks = pt.switch(mask, _d, 2) # having renamed `_a` to `mask`, for example |
There was a problem hiding this comment.
Please see comment below
| _b = 1.0 / (np.pi * pt.sqrt(rt)) | ||
| _c = pt.sqrt(-2 * pt.log(np.pi * rt * err) / (np.pi**2 * rt)) | ||
| _d = pt.max(pt.stack([_b, _c]), axis=0) | ||
| kl = _a * _b + (1 - _a) * _b |
There was a problem hiding this comment.
_c and _d are not used. Should _d be used in the second term instead of _b? Otherwise kl will be _b.
| kl = _a * _b + (1 - _a) * _b | |
| kl = pt.switch(mask, _b, _d) |
There was a problem hiding this comment.
Please see comment below
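The observation about `kl` can be checked with a small NumPy sketch (NumPy's `np.where` is used here as a stand-in for `pt.switch`; the array values are arbitrary). With the same array in both terms, the arithmetic blend collapses to `_b` regardless of the mask, whereas the suggested form actually selects between `_b` and `_d`.

```python
import numpy as np

mask = np.array([True, False, True, False])
m = mask.astype(float)  # 1.0 where the mask is True, 0.0 elsewhere

b = np.array([0.2, 0.4, 0.6, 0.8])
d = np.array([1.0, 2.0, 3.0, 4.0])

# As written in the diff: both terms use b, so the mask has no effect.
kl_as_written = m * b + (1 - m) * b

# Arithmetic blend with d in the second term, and the select-based equivalent.
kl_blend = m * b + (1 - m) * d
kl_select = np.where(mask, b, d)

print(kl_as_written)  # identical to b
print(kl_blend)
print(kl_select)
```

For finite inputs the blend and the select give the same values. They differ when the unselected branch holds inf or nan, since 0 * inf is nan in the blend while a forward select simply ignores that entry.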
| logp = pt.where( | ||
| rt <= epsilon, | ||
| LOGP_LB, | ||
| tt = negative_rt * epsilon + (1 - negative_rt) * rt |
There was a problem hiding this comment.
| tt = negative_rt * epsilon + (1 - negative_rt) * rt | |
| tt = pt.switch(negative_rt, epsilon, rt) |
There was a problem hiding this comment.
This actually is done on purpose. pt.switch can cause some weird errors
| + ( | ||
| (a * z_flipped * sv) ** 2 | ||
| - 2 * a * v_flipped * z_flipped | ||
| - (v_flipped**2) * rt | ||
| - (v_flipped**2) * tt | ||
| ) | ||
| / (2 * (sv**2) * rt + 2) | ||
| - 0.5 * pt.log(sv**2 * rt + 1) | ||
| - 2 * pt.log(a), | ||
| / (2 * (sv**2) * tt + 2) | ||
| - 0.5 * pt.log(sv**2 * tt + 1) | ||
| - 2 * pt.log(pt.maximum(epsilon, a)), |
There was a problem hiding this comment.
Evaluate separately providing a meaningful name.
There was a problem hiding this comment.
We are probably not going to keep this one. I just tried it to see whether keeping the argument of the log positive would get us somewhere. It seems to help a bit, but this is not the culprit.
Co-authored-by: Carlos Paniagua <cpaniaguam@gmail.com>
|
@cpaniaguam Thanks for the suggestions! I committed all excluding those involving Please feel free to take this further. This PR wasn't final - was just a placeholder for some of my experiments |
|
@digicosmos86 let's use this PR to switch to float64 overall? Also, the latest state of affairs with all changes in this PR is that it's still breaking, right? |
|
You are correct. It is still broken. This PR is kind of my mess though. I'd rather start a new one and just switch out all the |
|
@digicosmos86 I am good with that approach. |
|
Since this is still in the works, I am going to convert it to a draft PR |
|
@digicosmos86 to be closed now that the other PR is up? |
No description provided.