Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions src/jaxfun/galerkin/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,32 +775,41 @@ def evaluate_mesh(self, x: Array | list[Array]) -> Array:


def evaluate_jaxfunction_expr(
a: Basic, xj: Array | tuple[Array, ...], jaxf: AppliedUndef | None = None
a: Basic, Xj: Array | tuple[Array, ...], jaxf: AppliedUndef | None = None
) -> Array:
"""Evaluate a symbolic JAXFunction expression on a mesh.

Input coordinates ``Xj`` are always given in the reference domain.
"""
if jaxf is None:
for p in sp.core.traversal.preorder_traversal(a):
if get_arg(p) is ArgumentTag.JAXFUNC: # JAXFunction->AppliedUndef
jaxf = cast(AppliedUndef, p)
break
assert hasattr(jaxf, "functionspace") and hasattr(jaxf, "array")
V = cast(FunctionSpaceType, jaxf.functionspace)

def evaluate_value() -> Array:
if isinstance(V, OrthogonalSpace | DirectSum):
assert isinstance(Xj, Array)
return V.evaluate(Xj, jaxf.array)
return V.evaluate_mesh_reference(Xj, jaxf.array, True)

if isinstance(a, sp.Pow):
wa = a.args[0]
variables = getattr(wa, "variables", ())
var = tuple(variables.count(s) for s in V.system.base_scalars())
var = var[0] if V.dims == 1 else var
h = V.evaluate_derivative(xj, jaxf.array, k=var)
h = h ** int(a.exp)
if isinstance(wa, sp.Derivative):
variables = getattr(wa, "variables", ())
var = tuple(variables.count(s) for s in V.system.base_scalars())
k = var[0] if V.dims == 1 else var
h = V.evaluate_derivative_reference(Xj, jaxf.array, k=k)
else:
h = evaluate_value()
return h ** int(a.exp)

elif isinstance(a, sp.Derivative):
if isinstance(a, sp.Derivative):
variables = getattr(a, "variables", ())
var = tuple(variables.count(s) for s in V.system.base_scalars())
var = var[0] if V.dims == 1 else var
h = V.evaluate_derivative(xj, jaxf.array, k=var)
k = var[0] if V.dims == 1 else var
return V.evaluate_derivative_reference(Xj, jaxf.array, k=k)

else:
if not isinstance(V, OrthogonalSpace | DirectSum):
h = V.evaluate_mesh(xj, jaxf.array, True)
else:
h = V.evaluate(xj, jaxf.array)
return h
return evaluate_value()
18 changes: 15 additions & 3 deletions src/jaxfun/galerkin/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,10 @@ def mesh(self, kind: str = "quadrature", N: int = 0) -> Array:
"""Return mesh from homogeneous Composite summand."""
return self[0].mesh(kind=kind, N=N)

def mesh_reference(self, kind: str = "quadrature", N: int = 0) -> Array:
"""Return reference-domain mesh from homogeneous Composite summand."""
return self[0].mesh_reference(kind=kind, N=N)

def bnd_vals(self) -> Array:
"""Return boundary lifting values (from BCGeneric)."""
return self[1].bnd_vals()
Expand Down Expand Up @@ -502,11 +506,19 @@ def forward(self, uj: Array) -> Array:
return jnp.linalg.solve(M, b)

@jax.jit(static_argnums=(0, 3))
def evaluate_derivative(self, X: Array, c: Array, k: int = 0) -> float:
"""Evaluate k-th derivative at X (composite + boundary)."""
def evaluate_derivative(self, x: Array, c: Array, k: int = 0) -> float:
"""Evaluate k-th derivative at true-domain points (composite + boundary)."""
X = self.map_reference_domain(x)
return self.evaluate_derivative_reference(X, c, k)

@jax.jit(static_argnums=(0, 3))
def evaluate_derivative_reference(self, X: Array, c: Array, k: int = 0) -> float:
"""Evaluate k-th derivative at reference-domain points."""
a, b = self.basespaces
bv = self.bnd_vals()
return a.evaluate_derivative(X, c, k) + b.evaluate_derivative(X, bv, k)
return a.evaluate_derivative_reference(
X, c, k
) + b.evaluate_derivative_reference(X, bv, k)


def get_stencil_matrix(bcs: BoundaryConditions, orthogonal: Jacobi) -> dict:
Expand Down
30 changes: 18 additions & 12 deletions src/jaxfun/galerkin/inner.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,9 @@ def inner(
scales.append(a0["multivar"])
if "jaxfunction" in a0:
scales.append(
evaluate_jaxfunction_expr(a0["jaxfunction"], test_space.mesh())
evaluate_jaxfunction_expr(
a0["jaxfunction"], test_space.mesh_reference()
)
)

Am = assemble_multivar(mats_, scales, test_space)
Expand Down Expand Up @@ -323,7 +325,9 @@ def inner(
s = test_space.system.base_scalars()
uj = lambdify(s, b0["multivar"], modules="jax")(*test_space.mesh())
elif "jaxfunction" in b0:
uj = evaluate_jaxfunction_expr(b0["jaxfunction"], test_space.mesh())
uj = evaluate_jaxfunction_expr(
b0["jaxfunction"], test_space.mesh_reference()
)
else:
raise ValueError("Expected multivar or jaxfunction key in b0")
res = bs[0][0].T @ uj @ bs[1][0]
Expand All @@ -343,7 +347,9 @@ def inner(
s = test_space.system.base_scalars()
uj = lambdify(s, b0["multivar"], modules="jax")(*test_space.mesh())
elif "jaxfunction" in b0:
uj = evaluate_jaxfunction_expr(b0["jaxfunction"], test_space.mesh())
uj = evaluate_jaxfunction_expr(
b0["jaxfunction"], test_space.mesh_reference()
)
else:
raise ValueError("Expected multivar or jaxfunction key in b0")
res = jnp.einsum("il,jm,kn,ijk->lmn", bs[0][0], bs[1][0], bs[2][0], uj)
Expand Down Expand Up @@ -475,7 +481,7 @@ def inner_bilinear(
"""
vo = v.orthogonal
uo = u.orthogonal
xj, wj = vo.quad_points_and_weights()
Xj, wj = vo.quad_points_and_weights()
df = float(vo.domain_factor)
i, j = 0, 0
scale = jnp.array([sc])
Expand All @@ -501,11 +507,11 @@ def inner_bilinear(
jaxfunction = cast(AppliedUndef, p)
break
if jaxfunction:
scale *= evaluate_jaxfunction_expr(aii, xj, jaxfunction)
scale *= evaluate_jaxfunction_expr(aii, Xj, jaxfunction)
continue
if len(aii.free_symbols) > 0:
s = aii.free_symbols.pop()
scale *= lambdify(s, uo.map_expr_true_domain(aii), modules="jax")(xj)
scale *= lambdify(s, uo.map_expr_true_domain(aii), modules="jax")(Xj)
else:
scale *= float(aii) # ty:ignore[invalid-argument-type]

Expand All @@ -524,8 +530,8 @@ def inner_bilinear(

if z is None:
w = wj * df ** (i + j - 1) * scale
Pi = vo.evaluate_basis_derivative(xj, k=i)
Pj = uo.evaluate_basis_derivative(xj, k=j)
Pi = vo.evaluate_basis_derivative(Xj, k=i)
Pj = uo.evaluate_basis_derivative(Xj, k=j)

if multivar:
multi: Array = v.apply_stencil_right(w[:, None] * Pi)
Expand Down Expand Up @@ -565,7 +571,7 @@ def inner_linear(
Vector (1D), tuple (Pi,) for multivar, or projected result.
"""
vo = v.orthogonal
xj, wj = vo.quad_points_and_weights()
Xj, wj = vo.quad_points_and_weights()
df = float(vo.domain_factor)
i = 0
uj = jnp.array([sc]) # incorporate scalar coefficient into first vector
Expand All @@ -592,17 +598,17 @@ def inner_linear(
jaxfunction = cast(AppliedUndef, p)
break
if jaxfunction:
uj *= evaluate_jaxfunction_expr(bii, xj, jaxfunction)
uj *= evaluate_jaxfunction_expr(bii, Xj, jaxfunction)
continue
# bii contains coordinates in the domain of v, e.g., (r, theta) for polar
# Need to compute bii as bii(x(X)), since we use quadrature points
if len(bii.free_symbols) > 0:
s = bii.free_symbols.pop()
uj *= lambdify(s, vo.map_expr_true_domain(bii), modules="jax")(xj)
uj *= lambdify(s, vo.map_expr_true_domain(bii), modules="jax")(Xj)
else:
uj *= float(bii) # ty:ignore[invalid-argument-type]

Pi = vo.evaluate_basis_derivative(xj, k=i)
Pi = vo.evaluate_basis_derivative(Xj, k=i)
w = wj * df ** (i - 1) # Account for domain different from reference
if multivar:
return (v.apply_stencil_right((uj * w)[:, None] * jnp.conj(Pi)),)
Expand Down
29 changes: 22 additions & 7 deletions src/jaxfun/galerkin/orthogonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,22 @@ def evaluate(self, X: float | Array, c: Array) -> Array:

@jax.jit(static_argnums=(0, 3))
def evaluate_derivative(self, x: Array, c: Array, k: int = 0) -> Array:
"""Evaluate truncated series sum_k c_k psi_k(X).
"""Evaluate k-th physical derivative at true-domain points.

Args:
x: Evaluation point(s) in real coordinates.
x: Evaluation point(s) in true-domain coordinates.
c: Coefficient vector ( <= self.N).
k: Derivative order (default 0 -> function value).
Returns:
Array of shape like x containing series evaluation.
"""
X = self.map_reference_domain(x)
df = self.domain_factor**k
return self.evaluate_derivative_reference(X, c, k)

@jax.jit(static_argnums=(0, 3))
def evaluate_derivative_reference(self, X: Array, c: Array, k: int = 0) -> Array:
"""Evaluate k-th physical derivative at reference-domain points."""
df = float(self.domain_factor**k)
return df * self.evaluate_basis_derivative(X, k)[..., : len(c)] @ c

@jax.jit(static_argnums=0)
Expand Down Expand Up @@ -151,8 +156,8 @@ def evaluate_basis_derivative(self, X: Array, k: int = 0) -> Array:
@jax.jit(static_argnums=(0, 2, 3))
def backward(self, c: Array, kind: str = "quadrature", N: int = 0) -> Array:
"""Implementation of backward (allows subclass override)."""
xj = self.mesh(kind=kind, N=N)
return self.evaluate(self.map_reference_domain(xj), c)
Xj = self.mesh_reference(kind=kind, N=N)
return self.evaluate(Xj, c)

@jax.jit(static_argnums=0)
def mass_matrix(self) -> BCOO:
Expand Down Expand Up @@ -320,14 +325,24 @@ def map_true_domain(self, X: sp.Symbol | Array) -> sp.Expr | Array:
def mesh(self, kind: str = "quadrature", N: int = 0) -> Array:
"""Return sampling mesh in true domain.

Args:
kind: 'quadrature' (default) or 'uniform'.
N: Number of uniform points (0 -> num_quad_points).
"""
return self.map_true_domain(self.mesh_reference(kind, N))

@jax.jit(static_argnums=(0, 1, 2))
def mesh_reference(self, kind: str = "quadrature", N: int = 0) -> Array:
"""Return sampling mesh in reference domain.

Args:
kind: 'quadrature' (default) or 'uniform'.
N: Number of uniform points (0 -> num_quad_points).
"""
if kind == "quadrature":
return self.map_true_domain(self.quad_points_and_weights(N)[0])
return self.quad_points_and_weights(N)[0]
assert kind == "uniform"
a, b = self.domain
a, b = self.reference_domain
M = N if N != 0 else self.num_quad_points
return jnp.linspace(float(a), float(b), M)

Expand Down
Loading