Skip to content

Jaxfunction mapped domain#27

Merged
augustfe merged 5 commits intomainfrom
jaxfunction-mapped-domain
Mar 3, 2026
Merged

Jaxfunction mapped domain#27
augustfe merged 5 commits intomainfrom
jaxfunction-mapped-domain

Conversation

@augustfe
Copy link
Collaborator

@augustfe augustfe commented Feb 27, 2026

fixes #26
There was also an issue when the domain_factor was a sympy expression, as is typically the case when mapping Fourier domains.

Copy link
Member

@mikaem mikaem left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent!

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes issue #26 by making JAXFunction-based coefficients evaluate correctly when assembling forms on mapped (non-reference) domains, and extends the test suite to cover both reference and mapped domains.

Changes:

  • Parameterize tests/galerkin/test_jaxfunction.py over domain=None and domain=Domain(-2, 2) to reproduce/guard mapped-domain behavior.
  • Ensure JAXFunction expressions in inner() assembly are evaluated using physical-domain coordinates (mapping quadrature points to the true domain before evaluation).
  • Normalize derivative scaling factors (domain_factor**k) to plain floats for consistent numeric behavior in derivative evaluation.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

Show a summary per file
File Description
tests/galerkin/test_jaxfunction.py Adds mapped-domain coverage via fixtures and updates space construction to pass domain=.
src/jaxfun/galerkin/tensorproductspace.py Casts derivative scaling factors to float during tensor-product derivative evaluation.
src/jaxfun/galerkin/orthogonal.py Casts domain_factor**k to float in OrthogonalSpace.evaluate_derivative.
src/jaxfun/galerkin/inner.py Evaluates JAXFunction coefficients at physical coordinates by mapping quadrature points to the true domain.
src/jaxfun/galerkin/arguments.py Clarifies/standardizes evaluate_jaxfunction_expr to take physical-domain coordinates and maps internally as needed for orthogonal/direct-sum spaces.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@mikaem
Copy link
Member

mikaem commented Feb 27, 2026

Just stupid floating point accuracy that is sometimes and randomly failing!

@augustfe
Copy link
Collaborator Author

Just stupid floating point accuracy that is sometimes and randomly failing!

Might be related this time to inner mapping to true domain, and then passing to evaluate_jaxfunction_expr which maps directly back to reference domain before actually doing anything

@mikaem
Copy link
Member

mikaem commented Feb 27, 2026

Just stupid floating point accuracy that is sometimes and randomly failing!

Might be related this time to inner mapping to true domain, and then passing to evaluate_jaxfunction_expr which maps directly back to reference domain before actually doing anything

I was thinking it is JAX that is choosing how to compute stuff somewhat randomly. I guess we have not been great at requesting max precision where appropriate?

@augustfe
Copy link
Collaborator Author

Just stupid floating point accuracy that is sometimes and randomly failing!

Might be related this time to inner mapping to true domain, and then passing to evaluate_jaxfunction_expr which maps directly back to reference domain before actually doing anything

I was thinking it is JAX that is choosing how to compute stuff somewhat randomly. I guess we have not been great at requesting max precision where appropriate?

Running with JAX_DEFAULT_MATMUL_PRECISION='highest' JAX_ENABLE_X64=True uv run pytest --float64, the test still fails...

@augustfe
Copy link
Collaborator Author

augustfe commented Feb 27, 2026

@mikaem does it make sense to just do all of the computation in reference space, and then just map the result back to the true domain in the end?

@mikaem
Copy link
Member

mikaem commented Feb 27, 2026

@mikaem does it make sense to just do all of the computation in reference space, and then just map the result back to the true domain in the end?

As much as possible of 1D stuff should be in reference space. I believe it's only evaluate_derivative that is not, but I don't remember exactly why. Probably because of the domain_factor that needs to be accounted for.

@augustfe
Copy link
Collaborator Author

@mikaem does it make sense to just do all of the computation in reference space, and then just map the result back to the true domain in the end?

As much as possible of 1D stuff should be in reference space. I believe it's only evaluate_derivative that is not, but I don't remember exactly why. Probably because of the domain_factor that needs to be accounted for.

Made variants of the functions in #28 which directly exposes the reference domain functions, so that the existing functions become wrappers which simply map the input from true to reference, then call the reference functions. In this way, we can to a greater degree avoid mapping back and forth between the domains. This eliminated the floating point issue which popped up here.

@augustfe
Copy link
Collaborator Author

augustfe commented Mar 3, 2026

Bumped the ty version, which introduced complaints of implicit shadowing in the get_BasisFunction and get_JAXFunction when setting

b.__class__.__str__ = __str__

By the way, why are we setting b.__class__.__str__ instead of b.__str__?

Full diagnostics:

error[invalid-assignment]: Object of type `def __str__(self) -> str` is not assignable to attribute `__str__` of type `def __str__(self) -> str`
   --> src/jaxfun/galerkin/arguments.py:133:5
    |
131 |     # b.__class__.__str__
132 |
133 |     b.__class__.__str__ = __str__
    |     ^^^^^^^^^^^^^^^^^^^
134 |     b.__class__._pretty = _pretty
135 |     b.__class__._sympystr = _sympystr
    |
info: Implicit shadowing of function `__str__`, add an annotation to make it explicit if this is intentional
info: rule `invalid-assignment` is enabled by default

@mikaem
Copy link
Member

mikaem commented Mar 3, 2026

Bumped the ty version, which introduced complaints of implicit shadowing in the get_BasisFunction and get_JAXFunction when setting

b.__class__.__str__ = __str__

By the way, why are we setting b.__class__.__str__ instead of b.__str__?

Because it works? Does setting b.__str__ work?

@augustfe
Copy link
Collaborator Author

augustfe commented Mar 3, 2026

Bumped the ty version, which introduced complaints of implicit shadowing in the get_BasisFunction and get_JAXFunction when setting

b.__class__.__str__ = __str__

By the way, why are we setting b.__class__.__str__ instead of b.__str__?

Because it works? Does setting b.__str__ work?

Ah I see, tried this, didn't work.

class C:
    def __str__(self) -> str: return "class"

def c__str__(self) -> str: return "c instance"

def d__str__(self) -> str: return "d instance"

c = C()
c.__str__ = c__str__
print(f"{str(c)=}")

d = C()
d.__class__.__str__ = d__str__
print(f"{str(d)=}")
print(f"{str(c)=}")
str(c)='class'
str(d)='d instance'
str(c)='d instance'

@mikaem
Copy link
Member

mikaem commented Mar 3, 2026

Squash and merge?

@augustfe augustfe merged commit 2805444 into main Mar 3, 2026
2 of 3 checks passed
@augustfe augustfe deleted the jaxfunction-mapped-domain branch March 3, 2026 12:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Mapped domains not handled correctly with jaxfunction?

3 participants