Conversation
There was a problem hiding this comment.
Pull request overview
This PR fixes issue #26 by making JAXFunction-based coefficients evaluate correctly when assembling forms on mapped (non-reference) domains, and extends the test suite to cover both reference and mapped domains.
Changes:
- Parameterize
tests/galerkin/test_jaxfunction.pyoverdomain=Noneanddomain=Domain(-2, 2)to reproduce/guard mapped-domain behavior. - Ensure
JAXFunctionexpressions ininner()assembly are evaluated using physical-domain coordinates (mapping quadrature points to the true domain before evaluation). - Normalize derivative scaling factors (
domain_factor**k) to plain floats for consistent numeric behavior in derivative evaluation.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
tests/galerkin/test_jaxfunction.py |
Adds mapped-domain coverage via fixtures and updates space construction to pass domain=. |
src/jaxfun/galerkin/tensorproductspace.py |
Casts derivative scaling factors to float during tensor-product derivative evaluation. |
src/jaxfun/galerkin/orthogonal.py |
Casts domain_factor**k to float in OrthogonalSpace.evaluate_derivative. |
src/jaxfun/galerkin/inner.py |
Evaluates JAXFunction coefficients at physical coordinates by mapping quadrature points to the true domain. |
src/jaxfun/galerkin/arguments.py |
Clarifies/standardizes evaluate_jaxfunction_expr to take physical-domain coordinates and maps internally as needed for orthogonal/direct-sum spaces. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
Just stupid floating point accuracy that is sometimes and randomly failing! |
Might be related this time to |
I was thinking it is JAX that is choosing how to compute stuff somewhat randomly. I guess we have not been great at requesting max precision where appropriate? |
Running with |
|
@mikaem does it make sense to just do all of the computation in reference space, and then just map the result back to the true domain in the end? |
As much as possible of 1D stuff should be in reference space. I believe it's only evaluate_derivative that is not, but I don't remember exactly why. Probably because of the domain_factor that needs to be accounted for. |
Made variants of the functions in #28 which directly exposes the reference domain functions, so that the existing functions become wrappers which simply map the input from true to reference, then call the reference functions. In this way, we can to a greater degree avoid mapping back and forth between the domains. This eliminated the floating point issue which popped up here. |
|
Bumped the b.__class__.__str__ = __str__By the way, why are we setting Full diagnostics: |
Because it works? Does setting |
Ah I see, tried this, didn't work. class C:
def __str__(self) -> str: return "class"
def c__str__(self) -> str: return "c instance"
def d__str__(self) -> str: return "d instance"
c = C()
c.__str__ = c__str__
print(f"{str(c)=}")
d = C()
d.__class__.__str__ = d__str__
print(f"{str(d)=}")
print(f"{str(c)=}") |
|
Squash and merge? |
fixes #26
There was also an issue when the
domain_factorwas a sympy expression, as is typically the case when mapping Fourier domains.