Skip to content

Commit 71f6165

Browse files
committed
MixedInterpolator
1 parent 8d631f8 commit 71f6165

File tree

3 files changed

+155
-65
lines changed

3 files changed

+155
-65
lines changed

firedrake/assemble.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import finat.ufl
1919
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,
2020
tsfc_interface, utils)
21-
from firedrake.formmanipulation import split_form
2221
from firedrake.adjoint_utils import annotate_assemble
2322
from firedrake.ufl_expr import extract_unique_domain
2423
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
@@ -570,36 +569,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
570569
rank = len(expr.arguments())
571570
if rank > 2:
572571
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
573-
# If argument numbers have been swapped => Adjoint.
574-
arg_operand = ufl.algorithms.extract_arguments(operand)
575-
is_adjoint = (arg_operand and arg_operand[0].number() == 0)
576-
577572
# Get the target space
578573
V = v.function_space().dual()
579574

580-
# Dual interpolation from mixed source
581-
if is_adjoint and len(V) > 1:
582-
cur = 0
583-
sub_operands = []
584-
components = numpy.reshape(operand, (-1,))
585-
for Vi in V:
586-
sub_operands.append(ufl.as_tensor(components[cur:cur+Vi.value_size].reshape(Vi.value_shape)))
587-
cur += Vi.value_size
588-
589-
# Component-split of the primal operands interpolated into the dual argument-split
590-
split_interp = sum(reconstruct_interp(sub_operands[i], v=vi) for (i,), vi in split_form(v))
591-
return assemble(split_interp, tensor=tensor)
592-
593-
# Dual interpolation into mixed target
594-
if is_adjoint and len(arg_operand[0].function_space()) > 1 and rank == 1:
595-
V = arg_operand[0].function_space()
596-
tensor = tensor or firedrake.Cofunction(V.dual())
597-
598-
# Argument-split of the Interpolate gets assembled into the corresponding sub-tensor
599-
for (i,), sub_interp in split_form(expr):
600-
assemble(sub_interp, tensor=tensor.subfunctions[i])
601-
return tensor
602-
603575
# Get the interpolator
604576
interp_data = expr.interp_data.copy()
605577
default_missing_val = interp_data.pop('default_missing_val', None)

firedrake/interpolation.py

Lines changed: 108 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,18 @@ class Interpolator(abc.ABC):
263263

264264
def __new__(cls, expr, V, **kwargs):
265265
if isinstance(expr, ufl.Interpolate):
266+
# Mixed spaces are handled well only by the primal 1-form.
267+
# Are we a 2-form or a dual 1-form?
268+
arguments = expr.arguments()
269+
if any(not isinstance(a, Coargument) for a in arguments):
270+
# Do we have mixed source or target spaces?
271+
spaces = [a.function_space() for a in arguments]
272+
if len(spaces) < 2:
273+
spaces.append(expr.function_space())
274+
if any(len(space) > 1 for space in spaces):
275+
return object.__new__(MixedInterpolator)
266276
expr, = expr.ufl_operands
277+
267278
target_mesh = as_domain(V)
268279
source_mesh = extract_unique_domain(expr) or target_mesh
269280
submesh_interp_implemented = \
@@ -369,7 +380,7 @@ def _interpolate(self, *args, **kwargs):
369380
"""
370381
pass
371382

372-
def assemble(self, tensor=None, default_missing_val=None):
383+
def assemble(self, tensor=None, **kwargs):
373384
"""Assemble the operator (or its action)."""
374385
from firedrake.assemble import assemble
375386
needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate
@@ -383,13 +394,11 @@ def assemble(self, tensor=None, default_missing_val=None):
383394
if needs_adjoint:
384395
# Out-of-place Hermitian transpose
385396
petsc_mat.hermitianTranspose(out=res)
386-
elif res:
387-
petsc_mat.copy(res)
397+
elif tensor:
398+
petsc_mat.copy(tensor.petscmat)
388399
else:
389400
res = petsc_mat
390-
if tensor is None:
391-
tensor = firedrake.AssembledMatrix(arguments, self.bcs, res)
392-
return tensor
401+
return tensor or firedrake.AssembledMatrix(arguments, self.bcs, res)
393402
else:
394403
# Assembling the action
395404
cofunctions = ()
@@ -401,11 +410,11 @@ def assemble(self, tensor=None, default_missing_val=None):
401410
cofunctions = (dual_arg,)
402411

403412
if needs_adjoint and len(arguments) == 0:
404-
Iu = self._interpolate(default_missing_val=default_missing_val)
413+
Iu = self._interpolate(**kwargs)
405414
return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor)
406415
else:
407416
return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint,
408-
default_missing_val=default_missing_val)
417+
**kwargs)
409418

410419

411420
class DofNotDefinedError(Exception):
@@ -975,33 +984,10 @@ def callable():
975984
return callable
976985
else:
977986
loops = []
978-
if len(V) == 1:
979-
expressions = (expr,)
980-
else:
981-
if (hasattr(operand, "subfunctions") and len(operand.subfunctions) == len(V)
982-
and all(sub_op.ufl_shape == Vsub.value_shape for Vsub, sub_op in zip(V, operand.subfunctions))):
983-
# Use subfunctions if they match the target shapes
984-
operands = operand.subfunctions
985-
else:
986-
# Unflatten the expression into the shapes of the mixed components
987-
offset = 0
988-
operands = []
989-
for Vsub in V:
990-
if len(Vsub.value_shape) == 0:
991-
operands.append(operand[offset])
992-
else:
993-
components = [operand[offset + j] for j in range(Vsub.value_size)]
994-
operands.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape)))
995-
offset += Vsub.value_size
996-
997-
# Split the dual argument
998-
if isinstance(dual_arg, Cofunction):
999-
duals = dual_arg.subfunctions
1000-
elif isinstance(dual_arg, Coargument):
1001-
duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()]
1002-
else:
1003-
duals = [v for _, v in sorted(firedrake.formmanipulation.split_form(dual_arg))]
1004-
expressions = map(expr._ufl_expr_reconstruct_, operands, duals)
987+
expressions = split_interpolate_target(expr)
988+
989+
if access == op2.INC:
990+
loops.append(tensor.zero)
1005991

1006992
# Interpolate each sub expression into each function space
1007993
for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions):
@@ -1074,8 +1060,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10741060
parameters['scalar_type'] = utils.ScalarType
10751061

10761062
callables = ()
1077-
if access == op2.INC:
1078-
callables += (tensor.zero,)
10791063

10801064
# For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple
10811065
# contributions from the facet DOFs of the dual argument.
@@ -1720,3 +1704,90 @@ def _wrap_dummy_mat(self):
17201704

17211705
def duplicate(self, mat=None, op=None):
17221706
return self._wrap_dummy_mat()
1707+
1708+
1709+
def split_interpolate_target(expr: ufl.Interpolate):
1710+
"""Split an Interpolate into the components (subfunctions) of the target space."""
1711+
dual_arg, operand = expr.argument_slots()
1712+
V = dual_arg.function_space().dual()
1713+
if len(V) == 1:
1714+
return (expr,)
1715+
# Split the target (dual) argument
1716+
if isinstance(dual_arg, Cofunction):
1717+
duals = dual_arg.subfunctions
1718+
elif isinstance(dual_arg, ufl.Coargument):
1719+
duals = [Coargument(Vsub, dual_arg.number()) for Vsub in dual_arg.function_space()]
1720+
else:
1721+
duals = [vi for _, vi in sorted(firedrake.formmanipulation.split_form(dual_arg))]
1722+
# Split the operand into the target shapes
1723+
if (isinstance(operand, firedrake.Function) and len(operand.subfunctions) == len(V)
1724+
and all(fsub.ufl_shape == Vsub.value_shape for Vsub, fsub in zip(V, operand.subfunctions))):
1725+
# Use subfunctions if they match the target shapes
1726+
operands = operand.subfunctions
1727+
else:
1728+
# Unflatten the expression into the target shapes
1729+
cur = 0
1730+
operands = []
1731+
components = numpy.reshape(operand, (-1,))
1732+
for Vi in V:
1733+
operands.append(ufl.as_tensor(components[cur:cur+Vi.value_size].reshape(Vi.value_shape)))
1734+
cur += Vi.value_size
1735+
expressions = tuple(map(expr._ufl_expr_reconstruct_, operands, duals))
1736+
return expressions
1737+
1738+
1739+
class MixedInterpolator(Interpolator):
1740+
1741+
def __init__(self, expr, V, bcs=None, **kwargs):
1742+
if bcs is None:
1743+
bcs = ()
1744+
self.expr = expr
1745+
self.V = V
1746+
self.bcs = bcs
1747+
self.arguments = expr.arguments()
1748+
# Split the target (dual) argument
1749+
dual_split = split_interpolate_target(expr)
1750+
self.sub_interpolators = {}
1751+
for i, form in enumerate(dual_split):
1752+
Vtarget = V.sub(i)
1753+
target_bcs = [bc for bc in bcs if bc.function_space() == Vtarget]
1754+
# Split the source (primal) argument
1755+
for j, sub_interp in firedrake.formmanipulation.split_form(form):
1756+
j = max(j) if j else 0
1757+
dual_arg, operand = sub_interp.argument_slots()
1758+
adjoint = dual_arg.number() == 1 if isinstance(dual_arg, Coargument) else True
1759+
# Ensure block sparsity
1760+
if not isinstance(operand, ufl.classes.Zero):
1761+
args = sub_interp.arguments()
1762+
Vsource = args[0 if adjoint else 1].function_space()
1763+
source_bcs = [bc for bc in bcs if bc.function_space() == Vsource]
1764+
sub_bcs = source_bcs + target_bcs
1765+
indices = (j, i) if adjoint else (i, j)
1766+
Isub = Interpolator(sub_interp, Vtarget, bcs=sub_bcs, **kwargs)
1767+
self.sub_interpolators[indices] = Isub
1768+
1769+
def assemble(self, tensor=None, **kwargs):
1770+
"""Assemble the operator (or its action)."""
1771+
rank = len(self.arguments)
1772+
if rank == 2:
1773+
sub_tensors = {}
1774+
for ij, Isub in self.sub_interpolators.items():
1775+
block = tensor.petscmat.getNestSubMatrix(*ij) if tensor else PETSc.Mat()
1776+
sub_tensors[ij] = firedrake.AssembledMatrix(Isub.arguments, Isub.bcs, block)
1777+
Isub.assemble(tensor=sub_tensors[ij], **kwargs)
1778+
if tensor is None:
1779+
shape = tuple(len(a.function_space()) for a in self.arguments)
1780+
blocks = numpy.reshape([sub_tensors[ij].petscmat if ij in sub_tensors else PETSc.Mat()
1781+
for ij in numpy.ndindex(shape)], shape)
1782+
petscmat = PETSc.Mat().createNest(blocks)
1783+
tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat)
1784+
else:
1785+
tensor = self._interpolate(output=tensor, **kwargs)
1786+
return tensor
1787+
1788+
def _interpolate(self, output=None, **kwargs):
1789+
"""Assemble the action."""
1790+
tensor = output or firedrake.Function(self.expr.function_space())
1791+
for k, fsub in enumerate(tensor.subfunctions):
1792+
fsub.assign(sum(Isub._interpolate(**kwargs) for (i, j), Isub in self.sub_interpolators.items() if i == k))
1793+
return tensor

tests/firedrake/regression/test_interpolate.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,3 +517,50 @@ def test_interpolate_logical_not():
517517
a = assemble(interpolate(conditional(Not(x < .2), 1, 0), V))
518518
b = assemble(interpolate(conditional(x >= .2, 1, 0), V))
519519
assert np.allclose(a.dat.data, b.dat.data)
520+
521+
522+
@pytest.mark.parametrize("mode", ("forward", "adjoint"))
523+
def test_mixed_matrix(mode):
524+
nx = 3
525+
mesh = UnitSquareMesh(nx, nx)
526+
527+
V1 = VectorFunctionSpace(mesh, "CG", 2)
528+
V2 = FunctionSpace(mesh, "CG", 1)
529+
V3 = VectorFunctionSpace(mesh, "CG", 1)
530+
V4 = FunctionSpace(mesh, "DG", 1)
531+
532+
Z = V1 * V2
533+
W = V3 * V4
534+
535+
if mode == "forward":
536+
I = Interpolate(TrialFunction(Z), TestFunction(W.dual()))
537+
a = assemble(I)
538+
assert a.arguments()[0].function_space() == W.dual()
539+
assert a.arguments()[1].function_space() == Z
540+
assert a.petscmat.getSize() == (W.dim(), Z.dim())
541+
assert a.petscmat.getType() == "nest"
542+
543+
u = Function(Z)
544+
u.subfunctions[0].sub(0).assign(1)
545+
u.subfunctions[0].sub(1).assign(2)
546+
u.subfunctions[1].assign(3)
547+
result_matfree = assemble(Interpolate(u, TestFunction(W.dual())))
548+
elif mode == "adjoint":
549+
I = Interpolate(TestFunction(Z), TrialFunction(W.dual()))
550+
a = assemble(I)
551+
assert a.arguments()[1].function_space() == W.dual()
552+
assert a.arguments()[0].function_space() == Z
553+
assert a.petscmat.getSize() == (Z.dim(), W.dim())
554+
assert a.petscmat.getType() == "nest"
555+
556+
u = Function(W.dual())
557+
u.subfunctions[0].sub(0).assign(1)
558+
u.subfunctions[0].sub(1).assign(2)
559+
u.subfunctions[1].assign(3)
560+
result_matfree = assemble(Interpolate(TestFunction(Z), u))
561+
else:
562+
raise ValueError(f"Unrecognized mode {mode}")
563+
564+
result_explicit = assemble(action(a, u))
565+
for x, y in zip(result_explicit.subfunctions, result_matfree.subfunctions):
566+
assert np.allclose(x.dat.data, y.dat.data)

0 commit comments

Comments
 (0)