Skip to content

Fix float32 vs float64 issue#568

Merged
digicosmos86 merged 3 commits intomainfrom
156-fix-float32-vs-float64-issue
Aug 22, 2024
Merged

Fix float32 vs float64 issue#568
digicosmos86 merged 3 commits intomainfrom
156-fix-float32-vs-float64-issue

Conversation

@digicosmos86
Copy link
Copy Markdown
Collaborator

closes #468
closes #156

@AlexanderFengler @cpaniaguam After working on this PR, I realized that this is not very different from #468, but I have only incorporated changes that we agreed on. @cpaniaguam it seems that the code did not quite work after merging your suggestions in #468, so I didn't incorporate all of them. Can you take a look again, and if there's anything missing, let me know?

@digicosmos86 digicosmos86 linked an issue Aug 21, 2024 that may be closed by this pull request
grad_func = pytensor.function(
[v, a, z, t],
grad,
mode=NanGuardMode(nan_is_error=True, inf_is_error=True, big_is_error=False),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Define mode once and reuse it twice?

    mode = NanGuardMode(nan_is_error=True, inf_is_error=True, big_is_error=False)

Copy link
Copy Markdown
Member

@AlexanderFengler AlexanderFengler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the behavior actually a bit more robust now?
Otherwise looks good, left only a minor comments.

ks = 2 + pt.sqrt(-2 * rt * pt.log(2 * np.sqrt(2 * np.pi * rt) * err))
ks = pt.max(pt.stack([ks, pt.sqrt(rt) + 1]), axis=0)
ks = pt.switch(2 * pt.sqrt(2 * np.pi * rt) * err < 1, ks, 2)
_a = 2 * pt.sqrt(2 * np.pi * rt) * err < 1
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess if you want to go really crazy, you can define __a=2 * np.pi * rt only once :)

kl = pt.sqrt(-2 * pt.log(np.pi * rt * err) / (np.pi**2 * rt))
kl = pt.max(pt.stack([kl, 1.0 / (np.pi * pt.sqrt(rt))]), axis=0)
kl = pt.switch(np.pi * rt * err < 1, kl, 1.0 / (np.pi * pt.sqrt(rt)))
_a = np.pi * rt * err < 1
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here with np.pi * rt * err

@digicosmos86 digicosmos86 merged commit 83bf5a5 into main Aug 22, 2024
@digicosmos86 digicosmos86 deleted the 156-fix-float32-vs-float64-issue branch August 22, 2024 19:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Fix float32 vs float64 issue

3 participants