A spectral, Galerkin, and PINN experimentation toolkit built on top of JAX for fast, differentiable prototyping of ODEs and PDEs, variational forms, and mixed spectral bases.
- Orthogonal polynomial and Fourier bases (Chebyshev, Legendre, Jacobi, etc.)
- Tensor product and direct sum spaces with boundary conditions
- Assembly of bilinear / linear forms with symbolic (SymPy) coefficients
- A SymPy-based form language for describing PDEs
- Support for curvilinear coordinates
- JAX-backed forward/backward transforms and differentiation
- Utilities for sparse conversion, preconditioning, and projection
- A friendly interface for experimenting with PINNs
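To make the "forward/backward transforms" item above concrete, here is a minimal plain-JAX sketch of a Chebyshev transform on N Chebyshev-Gauss points. It uses a dense O(N²) cosine matrix rather than a fast DCT-based transform, and it is not jaxfun's implementation or API: all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def chebyshev_gauss_angles(N):
    """Angles theta_j such that x_j = cos(theta_j) are the N Chebyshev-Gauss points."""
    j = jnp.arange(N)
    return jnp.pi * (j + 0.5) / N


def cosine_matrix(N):
    """V[j, k] = T_k(x_j) = cos(k * theta_j) for j, k = 0..N-1."""
    theta = chebyshev_gauss_angles(N)
    k = jnp.arange(N)
    return jnp.cos(jnp.outer(theta, k))


@jax.jit
def backward(c):
    """Coefficients c_k -> point values u(x_j) = sum_k c_k T_k(x_j)."""
    return cosine_matrix(c.shape[0]) @ c


@jax.jit
def forward(u):
    """Point values u(x_j) -> coefficients c_k, via discrete orthogonality of T_k."""
    N = u.shape[0]
    c = (2.0 / N) * (cosine_matrix(N).T @ u)
    return c.at[0].multiply(0.5)  # the k = 0 mode has discrete norm N, not N/2


# Usage: x^3 = (3 T_1 + T_3) / 4, so the coefficients are [0, 0.75, 0, 0.25, 0, ...]
x = jnp.cos(chebyshev_gauss_angles(8))
c = forward(x**3)
# Both maps are ordinary JAX functions, so they compose with jax.jacfwd / jax.grad
J = jax.jacfwd(forward)(x**3)  # the (constant) 8 x 8 forward-transform matrix
```

Because the transforms are written with `jax.numpy`, they can be jitted and differentiated like any other JAX function; a production library would typically replace the dense matrix with an FFT/DCT-based fast transform.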
Using uv (recommended):

```bash
pip install uv  # if not already installed
uv add jaxfun   # when published
```

From source:

```bash
git clone https://github.com/spectralDNS/jaxfun.git
cd jaxfun
uv sync
```

Assemble the bilinear form of the Laplace operator on a 2D Chebyshev tensor-product space:

```python
from jaxfun.galerkin import Chebyshev, TensorProduct, TestFunction, TrialFunction
from jaxfun.galerkin.inner import inner
from jaxfun.operators import Div, Grad

C = Chebyshev.Chebyshev(16)  # 1D Chebyshev basis (N = 16)
T = TensorProduct((C, C))  # 2D tensor-product space C x C
v = TestFunction(T)
u = TrialFunction(T)
A = inner(Div(Grad(u)) * v)  # assemble the bilinear form (div(grad(u)), v)
```

Use a simple multilayer perceptron (MLP) neural network to solve Poisson's equation on the unit square.
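Written out, the example below solves the boundary-value problem stated first in the block that follows. The loss on the last line is a sketch that assumes `Loss` forms an unweighted mean-squared error of each expression over its point set (interior points $\mathbf{x}_j$ from `xyi`, boundary points $\mathbf{x}_k$ from `xyb`); the actual weighting used by `Loss` is an assumption here.

```math
\begin{aligned}
\nabla^2 u &= 2 && \text{in } \Omega = (0,1)^2,\\
u &= 0 && \text{on } \partial\Omega,\\
\mathcal{L}(\theta) &= \frac{1}{N_i}\sum_{j=1}^{N_i}\bigl(\nabla^2 u_\theta(\mathbf{x}_j) - 2\bigr)^2
 + \frac{1}{N_b}\sum_{k=1}^{N_b} u_\theta(\mathbf{x}_k)^2 .
\end{aligned}
```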
```python
from jaxfun.pinns import FlaxFunction, Loss, MLPSpace, Trainer, UnitSquare, adam, lbfgs
from jaxfun.operators import Div, Grad

# Create an MLP neural network space with two hidden layers
V = MLPSpace([12, 12], dims=2, rank=0, name="V")
u = FlaxFunction(V, name="u")  # The trial function, which here is a neural network

# Get mesh points on and inside the unit square
N = 50
mesh = UnitSquare()
xyi = mesh.get_points_inside_domain(N, N, "uniform")
xyb = mesh.get_points_on_domain(N, N, "uniform")

# Define Poisson's equation: residual = ∇²u - 2
residual = Div(Grad(u)) - 2

# Define loss function based on Poisson's equation, including
# homogeneous Dirichlet boundary conditions, and train model
loss_fn = Loss((residual, xyi), (u, xyb))
trainer = Trainer(loss_fn)
trainer.train(adam(u), 5000)
trainer.train(lbfgs(u), 5000)
```

See the examples directory and preliminary notebooks for more patterns.
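For intuition about what the PINN example above computes, here is a dependency-light sketch in plain JAX (no Flax, no jaxfun): a tiny tanh MLP, its Laplacian via `jax.hessian`, and an interior-plus-boundary mean-squared loss, minimized with a few plain gradient-descent steps. This is an illustration only and not how `MLPSpace`, `Loss`, or `Trainer` are implemented; the helper names, learning rate, point sampling, and use of gradient descent instead of Adam/L-BFGS are all hypothetical choices.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Random (W, b) pairs for a fully connected net with layer widths `sizes`."""
    params = []
    for m, n in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((jax.random.normal(sub, (m, n)) / jnp.sqrt(m), jnp.zeros(n)))
    return params


def mlp(params, x):
    """Scalar network output u_theta(x) for a single point x of shape (2,)."""
    h = x
    for W, b in params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = params[-1]
    return (h @ W + b)[0]


def laplacian(params, x):
    """Laplacian of u_theta at x: trace of the 2x2 Hessian with respect to x."""
    return jnp.trace(jax.hessian(mlp, argnums=1)(params, x))


def loss(params, x_in, x_bnd):
    """Mean-squared residual (lap(u) - 2) inside plus mean of u^2 on the boundary."""
    res = jax.vmap(lambda x: laplacian(params, x) - 2.0)(x_in)
    bnd = jax.vmap(lambda x: mlp(params, x))(x_bnd)
    return jnp.mean(res**2) + jnp.mean(bnd**2)


@jax.jit
def gd_step(params, x_in, x_bnd):
    """One plain gradient-descent step (a stand-in for Adam / L-BFGS)."""
    lr = 1e-3
    g = jax.grad(loss)(params, x_in, x_bnd)
    return jax.tree_util.tree_map(lambda p, dp: p - lr * dp, params, g)


key = jax.random.PRNGKey(0)
k_params, k_pts = jax.random.split(key)
params = init_mlp(k_params, [2, 12, 12, 1])  # two hidden layers of width 12
x_in = jax.random.uniform(k_pts, (256, 2))  # random points in [0, 1)^2
t = jnp.linspace(0.0, 1.0, 32)
zeros, ones = jnp.zeros_like(t), jnp.ones_like(t)
x_bnd = jnp.concatenate([  # points on the four edges of the unit square
    jnp.stack([t, zeros], axis=1), jnp.stack([t, ones], axis=1),
    jnp.stack([zeros, t], axis=1), jnp.stack([ones, t], axis=1),
])

loss0 = loss(params, x_in, x_bnd)
for _ in range(200):
    params = gd_step(params, x_in, x_bnd)
print(loss0, loss(params, x_in, x_bnd))
```

The key ingredient is that the PDE operator is applied to the network itself by automatic differentiation, so the same `jax.grad` call differentiates through the Laplacian when updating the weights.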
Run tests (excluding slow):

```bash
uv run pytest
```

Run the full suite (including slow demos):

```bash
uv run pytest -m "slow or not slow"
```

Lint & format:

```bash
uv run pre-commit run --all-files
```

See CONTRIBUTING and the Code of Conduct.
BSD 2-Clause – see LICENSE.
- Mikael Mortensen: mikaem@math.uio.no
- August Femtehjell: august.femtehjell@uio.no