@@ -263,7 +263,18 @@ class Interpolator(abc.ABC):
263263
264264 def __new__ (cls , expr , V , ** kwargs ):
265265 if isinstance (expr , ufl .Interpolate ):
266+ # Mixed spaces are handled well only by the primal 1-form.
267+ # Are we a 2-form or a dual 1-form?
268+ arguments = expr .arguments ()
269+ if any (not isinstance (a , Coargument ) for a in arguments ):
270+ # Do we have mixed source or target spaces?
271+ spaces = [a .function_space () for a in arguments ]
272+ if len (spaces ) < 2 :
273+ spaces .append (expr .function_space ())
274+ if any (len (space ) > 1 for space in spaces ):
275+ return object .__new__ (MixedInterpolator )
266276 expr , = expr .ufl_operands
277+
267278 target_mesh = as_domain (V )
268279 source_mesh = extract_unique_domain (expr ) or target_mesh
269280 submesh_interp_implemented = \
@@ -369,7 +380,7 @@ def _interpolate(self, *args, **kwargs):
369380 """
370381 pass
371382
372- def assemble (self , tensor = None , default_missing_val = None ):
383+ def assemble (self , tensor = None , ** kwargs ):
373384 """Assemble the operator (or its action)."""
374385 from firedrake .assemble import assemble
375386 needs_adjoint = self .ufl_interpolate_renumbered != self .ufl_interpolate
@@ -383,13 +394,11 @@ def assemble(self, tensor=None, default_missing_val=None):
383394 if needs_adjoint :
384395 # Out-of-place Hermitian transpose
385396 petsc_mat .hermitianTranspose (out = res )
386- elif res :
387- petsc_mat .copy (res )
397+ elif tensor :
398+ petsc_mat .copy (tensor . petscmat )
388399 else :
389400 res = petsc_mat
390- if tensor is None :
391- tensor = firedrake .AssembledMatrix (arguments , self .bcs , res )
392- return tensor
401+ return tensor or firedrake .AssembledMatrix (arguments , self .bcs , res )
393402 else :
394403 # Assembling the action
395404 cofunctions = ()
@@ -401,11 +410,11 @@ def assemble(self, tensor=None, default_missing_val=None):
401410 cofunctions = (dual_arg ,)
402411
403412 if needs_adjoint and len (arguments ) == 0 :
404- Iu = self ._interpolate (default_missing_val = default_missing_val )
413+ Iu = self ._interpolate (** kwargs )
405414 return assemble (ufl .Action (* cofunctions , Iu ), tensor = tensor )
406415 else :
407416 return self ._interpolate (* cofunctions , output = tensor , adjoint = needs_adjoint ,
408- default_missing_val = default_missing_val )
417+ ** kwargs )
409418
410419
411420class DofNotDefinedError (Exception ):
@@ -975,33 +984,10 @@ def callable():
975984 return callable
976985 else :
977986 loops = []
978- if len (V ) == 1 :
979- expressions = (expr ,)
980- else :
981- if (hasattr (operand , "subfunctions" ) and len (operand .subfunctions ) == len (V )
982- and all (sub_op .ufl_shape == Vsub .value_shape for Vsub , sub_op in zip (V , operand .subfunctions ))):
983- # Use subfunctions if they match the target shapes
984- operands = operand .subfunctions
985- else :
986- # Unflatten the expression into the shapes of the mixed components
987- offset = 0
988- operands = []
989- for Vsub in V :
990- if len (Vsub .value_shape ) == 0 :
991- operands .append (operand [offset ])
992- else :
993- components = [operand [offset + j ] for j in range (Vsub .value_size )]
994- operands .append (ufl .as_tensor (numpy .reshape (components , Vsub .value_shape )))
995- offset += Vsub .value_size
996-
997- # Split the dual argument
998- if isinstance (dual_arg , Cofunction ):
999- duals = dual_arg .subfunctions
1000- elif isinstance (dual_arg , Coargument ):
1001- duals = [Coargument (Vsub , number = dual_arg .number ()) for Vsub in dual_arg .function_space ()]
1002- else :
1003- duals = [v for _ , v in sorted (firedrake .formmanipulation .split_form (dual_arg ))]
1004- expressions = map (expr ._ufl_expr_reconstruct_ , operands , duals )
987+ expressions = split_interpolate_target (expr )
988+
989+ if access == op2 .INC :
990+ loops .append (tensor .zero )
1005991
1006992 # Interpolate each sub expression into each function space
1007993 for Vsub , sub_tensor , sub_expr in zip (V , tensor , expressions ):
@@ -1074,8 +1060,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10741060 parameters ['scalar_type' ] = utils .ScalarType
10751061
10761062 callables = ()
1077- if access == op2 .INC :
1078- callables += (tensor .zero ,)
10791063
10801064 # For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple
10811065 # contributions from the facet DOFs of the dual argument.
@@ -1720,3 +1704,90 @@ def _wrap_dummy_mat(self):
17201704
17211705 def duplicate (self , mat = None , op = None ):
17221706 return self ._wrap_dummy_mat ()
1707+
1708+
1709+ def split_interpolate_target (expr : ufl .Interpolate ):
1710+ """Split an Interpolate into the components (subfunctions) of the target space."""
1711+ dual_arg , operand = expr .argument_slots ()
1712+ V = dual_arg .function_space ().dual ()
1713+ if len (V ) == 1 :
1714+ return (expr ,)
1715+ # Split the target (dual) argument
1716+ if isinstance (dual_arg , Cofunction ):
1717+ duals = dual_arg .subfunctions
1718+ elif isinstance (dual_arg , ufl .Coargument ):
1719+ duals = [Coargument (Vsub , dual_arg .number ()) for Vsub in dual_arg .function_space ()]
1720+ else :
1721+ duals = [vi for _ , vi in sorted (firedrake .formmanipulation .split_form (dual_arg ))]
1722+ # Split the operand into the target shapes
1723+ if (isinstance (operand , firedrake .Function ) and len (operand .subfunctions ) == len (V )
1724+ and all (fsub .ufl_shape == Vsub .value_shape for Vsub , fsub in zip (V , operand .subfunctions ))):
1725+ # Use subfunctions if they match the target shapes
1726+ operands = operand .subfunctions
1727+ else :
1728+ # Unflatten the expression into the target shapes
1729+ cur = 0
1730+ operands = []
1731+ components = numpy .reshape (operand , (- 1 ,))
1732+ for Vi in V :
1733+ operands .append (ufl .as_tensor (components [cur :cur + Vi .value_size ].reshape (Vi .value_shape )))
1734+ cur += Vi .value_size
1735+ expressions = tuple (map (expr ._ufl_expr_reconstruct_ , operands , duals ))
1736+ return expressions
1737+
1738+
1739+ class MixedInterpolator (Interpolator ):
1740+
1741+ def __init__ (self , expr , V , bcs = None , ** kwargs ):
1742+ if bcs is None :
1743+ bcs = ()
1744+ self .expr = expr
1745+ self .V = V
1746+ self .bcs = bcs
1747+ self .arguments = expr .arguments ()
1748+ # Split the target (dual) argument
1749+ dual_split = split_interpolate_target (expr )
1750+ self .sub_interpolators = {}
1751+ for i , form in enumerate (dual_split ):
1752+ Vtarget = V .sub (i )
1753+ target_bcs = [bc for bc in bcs if bc .function_space () == Vtarget ]
1754+ # Split the source (primal) argument
1755+ for j , sub_interp in firedrake .formmanipulation .split_form (form ):
1756+ j = max (j ) if j else 0
1757+ dual_arg , operand = sub_interp .argument_slots ()
1758+ adjoint = dual_arg .number () == 1 if isinstance (dual_arg , Coargument ) else True
1759+ # Ensure block sparsity
1760+ if not isinstance (operand , ufl .classes .Zero ):
1761+ args = sub_interp .arguments ()
1762+ Vsource = args [0 if adjoint else 1 ].function_space ()
1763+ source_bcs = [bc for bc in bcs if bc .function_space () == Vsource ]
1764+ sub_bcs = source_bcs + target_bcs
1765+ indices = (j , i ) if adjoint else (i , j )
1766+ Isub = Interpolator (sub_interp , Vtarget , bcs = sub_bcs , ** kwargs )
1767+ self .sub_interpolators [indices ] = Isub
1768+
1769+ def assemble (self , tensor = None , ** kwargs ):
1770+ """Assemble the operator (or its action)."""
1771+ rank = len (self .arguments )
1772+ if rank == 2 :
1773+ sub_tensors = {}
1774+ for ij , Isub in self .sub_interpolators .items ():
1775+ block = tensor .petscmat .getNestSubMatrix (* ij ) if tensor else PETSc .Mat ()
1776+ sub_tensors [ij ] = firedrake .AssembledMatrix (Isub .arguments , Isub .bcs , block )
1777+ Isub .assemble (tensor = sub_tensors [ij ], ** kwargs )
1778+ if tensor is None :
1779+ shape = tuple (len (a .function_space ()) for a in self .arguments )
1780+ blocks = numpy .reshape ([sub_tensors [ij ].petscmat if ij in sub_tensors else PETSc .Mat ()
1781+ for ij in numpy .ndindex (shape )], shape )
1782+ petscmat = PETSc .Mat ().createNest (blocks )
1783+ tensor = firedrake .AssembledMatrix (self .arguments , self .bcs , petscmat )
1784+ else :
1785+ tensor = self ._interpolate (output = tensor , ** kwargs )
1786+ return tensor
1787+
1788+ def _interpolate (self , output = None , ** kwargs ):
1789+ """Assemble the action."""
1790+ tensor = output or firedrake .Function (self .expr .function_space ())
1791+ for k , fsub in enumerate (tensor .subfunctions ):
1792+ fsub .assign (sum (Isub ._interpolate (** kwargs ) for (i , j ), Isub in self .sub_interpolators .items () if i == k ))
1793+ return tensor
0 commit comments