diff --git a/firedrake/assemble.py b/firedrake/assemble.py index fd200dec2c..497e596edf 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -556,7 +556,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args): result = expr.assemble(assembly_opts=opts) return tensor.assign(result) if tensor else result elif isinstance(expr, ufl.Interpolate): - orig_expr = expr # Replace assembled children _, operand = expr.argument_slots() v, *assembled_operand = args @@ -568,13 +567,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args): if (v, operand) != expr.argument_slots(): expr = reconstruct_interp(operand, v=v) - # Different assembly procedures: - # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Jacobian (Interpolate matrix) - # 2) Interpolate(Coefficient(...), Argument(V2.dual(), 0)) -> Operator (or Jacobian action) - # 3) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Jacobian adjoint - # 4) Interpolate(Argument(V1, 0), Cofunction(...)) -> Action of the Jacobian adjoint - # This can be generalized to the case where the first slot is an arbitray expression. rank = len(expr.arguments()) + if rank > 2: + raise ValueError("Cannot assemble an Interpolate with more than two arguments") # If argument numbers have been swapped => Adjoint. arg_operand = ufl.algorithms.extract_arguments(operand) is_adjoint = (arg_operand and arg_operand[0].number() == 0) @@ -605,67 +600,14 @@ def base_form_assembly_visitor(self, expr, tensor, *args): assemble(sub_interp, tensor=tensor.subfunctions[i]) return tensor - # Workaround: Renumber argument when needed since Interpolator assumes it takes a zero-numbered argument. - if not is_adjoint and rank == 2: - v0, v1 = expr.arguments() - expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), - v1: v1.reconstruct(number=v0.number())}) - v, operand = expr.argument_slots() - - # Matrix-free adjoint interpolation is only implemented by SameMeshInterpolator - # so we need assemble the interpolator matrix if the meshes are different - target_mesh = V.mesh() - source_mesh = extract_unique_domain(operand) or target_mesh - if is_adjoint and rank < 2 and source_mesh is not target_mesh: - expr = reconstruct_interp(operand, v=V) - matfree = (rank == len(expr.arguments())) and (rank < 2) - # Get the interpolator interp_data = expr.interp_data.copy() default_missing_val = interp_data.pop('default_missing_val', None) - if matfree and ((is_adjoint and rank == 1) or rank == 0): - # Adjoint interpolation of a Cofunction or the action of a - # Cofunction on an interpolated Function require INC access - # on the output tensor - interp_data["access"] = op2.INC - - if rank == 1 and matfree and isinstance(tensor, firedrake.Function): + if rank == 1 and isinstance(tensor, firedrake.Function): V = tensor interpolator = firedrake.Interpolator(expr, V, **interp_data) - # Assembly - if matfree: - # Assembling the operator - return interpolator._interpolate(output=tensor, default_missing_val=default_missing_val) - elif rank == 0: - # Assembling the double action. - Iu = interpolator._interpolate(default_missing_val=default_missing_val) - return assemble(ufl.Action(v, Iu), tensor=tensor) - elif rank == 1: - # Assembling the action of the Jacobian adjoint. - if is_adjoint: - return interpolator._interpolate(v, output=tensor, adjoint=True, default_missing_val=default_missing_val) - # Assembling the Jacobian action. - else: - return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val) - elif rank == 2: - res = tensor.petscmat if tensor else PETSc.Mat() - # Get the interpolation matrix - op2_mat = interpolator.callable() - petsc_mat = op2_mat.handle - if is_adjoint: - # Out-of-place Hermitian transpose - petsc_mat.hermitianTranspose(out=res) - elif res: - # Copy the interpolation matrix into the output tensor - petsc_mat.copy(result=res) - else: - res = petsc_mat - if tensor is None: - tensor = self.assembled_matrix(orig_expr, res) - return tensor - else: - raise ValueError("Incompatible number of arguments.") + return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val) elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)): return tensor.assign(expr) elif tensor and isinstance(expr, ufl.ZeroBaseForm): diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index b990e68a8e..40f35e18a7 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -3,6 +3,8 @@ import tempfile import abc import warnings +from collections.abc import Iterable +from typing import Literal from functools import partial, singledispatch from typing import Hashable @@ -23,6 +25,7 @@ import finat import firedrake +import firedrake.bcs from firedrake import tsfc_interface, utils, functionspaceimpl from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology @@ -47,7 +50,7 @@ class Interpolate(ufl.Interpolate): def __init__(self, expr, v, subset=None, - access=op2.WRITE, + access=None, allow_missing_dofs=False, default_missing_val=None, matfree=True): @@ -122,7 +125,7 @@ def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data): @PETSc.Log.EventDecorator() -def interpolate(expr, V, subset=None, access=op2.WRITE, allow_missing_dofs=False, default_missing_val=None, matfree=True): +def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, default_missing_val=None, matfree=True): """Returns a UFL expression for the interpolation operation of ``expr`` into ``V``. :arg expr: a UFL expression. @@ -202,25 +205,34 @@ def interpolate(expr, V, subset=None, access=op2.WRITE, allow_missing_dofs=False class Interpolator(abc.ABC): """A reusable interpolation object. - :arg expr: The expression to interpolate. - :arg V: The :class:`.FunctionSpace` or :class:`.Function` to + Parameters + ---------- + expr + The underlying ufl.Interpolate or the operand to the ufl.Interpolate. + V + The :class:`.FunctionSpace` or :class:`.Function` to interpolate into. - :kwarg subset: An optional :class:`pyop2.types.set.Subset` to apply the + subset + An optional :class:`pyop2.types.set.Subset` to apply the interpolation over. Cannot, at present, be used when interpolating across meshes unless the target mesh is a :func:`.VertexOnlyMesh`. - :kwarg freeze_expr: Set to True to prevent the expression being + freeze_expr + Set to True to prevent the expression being re-evaluated on each call. Cannot, at present, be used when interpolating across meshes unless the target mesh is a :func:`.VertexOnlyMesh`. - :kwarg access: The pyop2 access descriptor for combining updates to shared - DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is - supported at present when interpolating across meshes. See note in - :func:`.interpolate` if changing this from default. - :kwarg bcs: An optional list of boundary conditions to zero-out in the + access + The pyop2 access descriptor for combining updates to shared DoFs. + Only ``op2.WRITE`` is supported at present when interpolating across meshes. + Only ``op2.INC`` is supported for the matrix-free adjoint interpolation. + See note in :func:`.interpolate` if changing this from default. + bcs + An optional list of boundary conditions to zero-out in the output function space. Interpolator rows or columns which are associated with boundary condition nodes are zeroed out when this is specified. - :kwarg allow_missing_dofs: For interpolation across meshes: allow + allow_missing_dofs + For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be defined on the source mesh. For example, where nodes are point evaluations, points in the target mesh that are not in the source mesh. @@ -232,14 +244,16 @@ class Interpolator(abc.ABC): Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a :func:`.VertexOnlyMesh` in this scenario is, at present, set when it is created). - :kwarg matfree: If ``False``, then construct the permutation matrix for interpolating + matfree + If ``False``, then construct the permutation matrix for interpolating between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast and reduce operations. This object can be used to carry out the same interpolation multiple times (for example in a timestepping loop). - .. note:: + Note + ---- The :class:`Interpolator` holds a reference to the provided arguments (such that they won't be collected until the @@ -266,34 +280,72 @@ def __new__(cls, expr, V, **kwargs): def __init__( self, - expr, - V, - subset=None, - freeze_expr=False, - access=op2.WRITE, - bcs=None, - allow_missing_dofs=False, - matfree=True + expr: ufl.Interpolate | ufl.classes.Expr, + V: ufl.FunctionSpace | firedrake.function.Function, + subset: op2.Subset | None = None, + freeze_expr: bool = False, + access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None, + bcs: Iterable[firedrake.bcs.BCBase] | None = None, + allow_missing_dofs: bool = False, + matfree: bool = True ): - if isinstance(expr, ufl.Interpolate): - expr, = expr.ufl_operands - self.expr = expr + if not isinstance(expr, ufl.Interpolate): + fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() + expr = interpolate(expr, fs) + dual_arg, operand = expr.argument_slots() + self.ufl_interpolate = expr + self.expr = operand self.V = V self.subset = subset self.freeze_expr = freeze_expr - self.access = access self.bcs = bcs self._allow_missing_dofs = allow_missing_dofs self.matfree = matfree self.callable = None - # Cope with the different convention of `Interpolate` and `Interpolator`: - # -> Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) - # -> Interpolator(Argument(V1, 0), V2) - expr_args = extract_arguments(expr) - if expr_args and expr_args[0].number() == 0: - v, = expr_args - expr = replace(expr, {v: v.reconstruct(number=1)}) - self.expr_renumbered = expr + + # TODO CrossMeshInterpolator and VomOntoVomXXX are not yet aware of + # self.ufl_interpolate (which carries the dual argument). + # See github issue https://github.com/firedrakeproject/firedrake/issues/4592 + target_mesh = as_domain(V) + source_mesh = extract_unique_domain(operand) or target_mesh + vom_onto_other_vom = ((source_mesh is not target_mesh) + and isinstance(source_mesh.topology, VertexOnlyMeshTopology) + and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) + if not isinstance(self, SameMeshInterpolator) or vom_onto_other_vom: + # For bespoke interpolation, we currently rely on different assembly procedures: + # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Forward operator (2-form) + # 2) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Adjoint operator (2-form) + # 3) Interpolate(Coefficient(V1), Argument(V2.dual(), 0)) -> Forward action (1-form) + # 4) Interpolate(Argument(V1, 0), Cofunction(V2.dual()) -> Adjoint action (1-form) + # 5) Interpolate(Coefficient(V1), Cofunction(V2.dual()) -> Double action (0-form) + + # CrossMeshInterpolator._interpolate only supports forward interpolation (cases 1 and 3). + # For case 2, we first redundantly assemble case 1 and then construct the transpose. + # For cases 4 and 5, we take the forward Interpolate that corresponds to dropping the Cofunction, + # and we separately compute the action against the dropped Cofunction within assemble(). + if not isinstance(dual_arg, ufl.Coargument): + # Drop the Cofunction + expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) + expr_args = extract_arguments(operand) + if expr_args and expr_args[0].number() == 0: + # Construct the symbolic forward Interpolate + v0, v1 = expr.arguments() + expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), + v1: v1.reconstruct(number=v0.number())}) + + dual_arg, operand = expr.argument_slots() + self.expr_renumbered = operand + self.ufl_interpolate_renumbered = expr + + if not isinstance(dual_arg, ufl.Coargument): + # Matrix-free assembly of 0-form or 1-form requires INC access + if access and access != op2.INC: + raise ValueError("Matfree adjoint interpolation requires INC access") + access = op2.INC + elif access is None: + # Default access for forward 1-form or 2-form (forward and adjoint) + access = op2.WRITE + self.access = access def interpolate(self, *function, transpose=None, adjoint=False, default_missing_val=None): """ @@ -317,6 +369,44 @@ def _interpolate(self, *args, **kwargs): """ pass + def assemble(self, tensor=None, default_missing_val=None): + """Assemble the operator (or its action).""" + from firedrake.assemble import assemble + needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate + arguments = self.ufl_interpolate.arguments() + if len(arguments) == 2: + # Assembling the operator + res = tensor.petscmat if tensor else PETSc.Mat() + # Get the interpolation matrix + op2mat = self.callable() + petsc_mat = op2mat.handle + if needs_adjoint: + # Out-of-place Hermitian transpose + petsc_mat.hermitianTranspose(out=res) + elif res: + petsc_mat.copy(res) + else: + res = petsc_mat + if tensor is None: + tensor = firedrake.AssembledMatrix(arguments, self.bcs, res) + return tensor + else: + # Assembling the action + cofunctions = () + if needs_adjoint: + # The renumbered Interpolate has dropped Cofunctions. + # We need to explicitly operate on them. + dual_arg, _ = self.ufl_interpolate.argument_slots() + if not isinstance(dual_arg, ufl.Coargument): + cofunctions = (dual_arg,) + + if needs_adjoint and len(arguments) == 0: + Iu = self._interpolate(default_missing_val=default_missing_val) + return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) + else: + return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, + default_missing_val=default_missing_val) + class DofNotDefinedError(Exception): r"""Raised when attempting to interpolate across function spaces where the @@ -361,7 +451,7 @@ def __init__( V, subset=None, freeze_expr=False, - access=op2.WRITE, + access=None, bcs=None, allow_missing_dofs=False, matfree=True @@ -372,8 +462,6 @@ def __init__( # Probably just need to pass freeze_expr to the various # interpolators for this to work. raise NotImplementedError("freeze_expr not implemented") - if access != op2.WRITE: - raise NotImplementedError("access other than op2.WRITE not implemented") if bcs: raise NotImplementedError("bcs not implemented") if V.ufl_element().mapping() != "identity": @@ -384,13 +472,12 @@ def __init__( raise NotImplementedError( "Can only interpolate into spaces with point evaluation nodes." ) - - if isinstance(expr, ufl.Interpolate): - dual_arg, expr = expr.argument_slots() - if not isinstance(dual_arg, Coargument): - raise NotImplementedError(f"{type(self).__name__} does not support matrix-free adjoint interpolation.") super().__init__(expr, V, subset, freeze_expr, access, bcs, allow_missing_dofs, matfree) + if self.access != op2.WRITE: + raise NotImplementedError("access other than op2.WRITE not implemented") + + expr = self.expr_renumbered self.arguments = extract_arguments(expr) self.nargs = len(self.arguments) @@ -689,15 +776,13 @@ class SameMeshInterpolator(Interpolator): """ @no_annotations - def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, + def __init__(self, expr, V, subset=None, freeze_expr=False, access=None, bcs=None, matfree=True, allow_missing_dofs=False, **kwargs): - if isinstance(expr, ufl.Interpolate): - operand, = expr.ufl_operands - else: - fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - operand = expr - expr = Interpolate(operand, fs) if subset is None: + if isinstance(expr, ufl.Interpolate): + operand, = expr.ufl_operands + else: + operand = expr target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh target = target_mesh.topology @@ -718,8 +803,9 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, pass super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr, access=access, bcs=bcs, matfree=matfree, allow_missing_dofs=allow_missing_dofs) + expr = self.ufl_interpolate_renumbered try: - self.callable = make_interpolator(expr, V, subset, access, bcs=bcs, matfree=matfree) + self.callable = make_interpolator(expr, V, subset, self.access, bcs=bcs, matfree=matfree) except FIAT.hdiv_trace.TraceError: raise NotImplementedError("Can't interpolate onto traces sorry") self.arguments = expr.arguments() @@ -758,8 +844,8 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, ** else: mul = assembled_interpolator.handle.mult row, col = self.arguments - V = col.function_space().dual() - assert function.function_space() == row.function_space() + V = row.function_space().dual() + assert function.function_space() == col.function_space() result = output or firedrake.Function(V) with function.dat.vec_ro as x, result.dat.vec_wo as out: @@ -793,14 +879,11 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): if not isinstance(expr, ufl.Interpolate): raise ValueError(f"Expecting to interpolate a ufl.Interpolate, got {type(expr).__name__}.") dual_arg, operand = expr.argument_slots() - target_mesh = as_domain(dual_arg) source_mesh = extract_unique_domain(operand) or target_mesh - vom_onto_other_vom = ( - isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) - and isinstance(source_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) - and target_mesh is not source_mesh - ) + vom_onto_other_vom = ((source_mesh is not target_mesh) + and isinstance(source_mesh.topology, VertexOnlyMeshTopology) + and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) arguments = expr.arguments() rank = len(arguments) @@ -812,7 +895,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): f = V V = f.function_space() else: - V_dest = arguments[-1].function_space().dual() + V_dest = arguments[0].function_space().dual() f = firedrake.Function(V_dest) if access in {firedrake.MIN, firedrake.MAX}: finfo = numpy.finfo(f.dat.dtype) @@ -825,40 +908,27 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): elif rank == 2: if isinstance(V, firedrake.Function): raise ValueError("Cannot interpolate an expression with an argument into a Function") - if len(V) > 1: + Vrow = arguments[0].function_space() + Vcol = arguments[1].function_space() + if len(Vrow) > 1 or len(Vcol) > 1: raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported") - argfs = arguments[0].function_space() - argfs_map = argfs.cell_node_map() - if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: - if not isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): + if isinstance(target_mesh.topology, VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: + if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - if argfs_map: - # Since the par_loop is over the target mesh cells we need to - # compose a map that takes us from target mesh cells to the - # function space nodes on the source mesh. NOTE: argfs_map is - # allowed to be None when interpolating from a Real space, even - # in the trans-mesh case. - if source_mesh.extruded: - # ExtrudedSet cannot be a map target so we need to build - # this ourselves - argfs_map = vom_cell_parent_node_map_extruded(target_mesh, argfs_map) - else: - argfs_map = compose_map_and_cache(target_mesh.cell_parent_cell_map, argfs_map) - elif vom_onto_other_vom: - argfs_map = argfs.cell_node_map() - else: - argfs_map = argfs.entity_node_map(target_mesh.topology, "cell", None, None) + if vom_onto_other_vom: # We make our own linear operator for this case using PETSc SFs tensor = None else: - sparsity = op2.Sparsity((V.dof_dset, argfs.dof_dset), - [(V.cell_node_map(), argfs_map, None)], # non-mixed - name="%s_%s_sparsity" % (V.name, argfs.name), + Vrow_map = get_interp_node_map(source_mesh, target_mesh, Vrow) + Vcol_map = get_interp_node_map(source_mesh, target_mesh, Vcol) + sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), + [(Vrow_map, Vcol_map, None)], # non-mixed + name="%s_%s_sparsity" % (Vrow.name, Vcol.name), nest=False, block_sparse=True) tensor = op2.Mat(sparsity) @@ -891,7 +961,7 @@ def callable(): # safely use the argument function space. NOTE: If this changes # after cofunctions are fully implemented, this will need to be # reconsidered. - temp_source_func = firedrake.Function(argfs) + temp_source_func = firedrake.Function(Vcol) wrapper.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat) # Leave wrapper inside a callable so we can access the handle @@ -966,9 +1036,9 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): # NOTE: The par_loop is always over the target mesh cells. target_mesh = as_domain(V) source_mesh = extract_unique_domain(operand) or target_mesh - if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): + if isinstance(target_mesh.topology, VertexOnlyMeshTopology): if target_mesh is not source_mesh: - if not isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): + if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") @@ -1014,15 +1084,15 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): if needs_weight: # Compute the reciprocal of the DOF multiplicity W = dual_arg.function_space() - shapes = (W.finat_element.space_dimension(), W.block_size) - domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes - instructions = """ - for i, j - w[i,j] = w[i,j] + 1 - end - """ + wsize = W.finat_element.space_dimension() * W.block_size + kernel_code = f""" + void multiplicity(PetscScalar *restrict w) {{ + for (PetscInt i=0; i<{wsize}; i++) w[i] += 1; + }}""" + kernel = op2.Kernel(kernel_code, "multiplicity", requires_zeroed_output_arguments=False) weight = firedrake.Function(W) - firedrake.par_loop((domain, instructions), ufl.dx, {"w": (weight, op2.INC)}) + m_ = get_interp_node_map(source_mesh, target_mesh, W) + op2.par_loop(kernel, cell_set, weight.dat(op2.INC, m_)) with weight.dat.vec as w: w.reciprocal() @@ -1070,33 +1140,26 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): parloop_args.append(tensor(access)) elif isinstance(tensor, op2.Dat): V_dest = arguments[-1].function_space() if isinstance(dual_arg, ufl.Cofunction) else V - parloop_args.append(tensor(access, V_dest.cell_node_map())) + m_ = get_interp_node_map(source_mesh, target_mesh, V_dest) + parloop_args.append(tensor(access, m_)) else: assert access == op2.WRITE # Other access descriptors not done for Matrices. - rows_map = V.cell_node_map() - Vcol = arguments[0].function_space() - assert tensor.handle.getSize() == (V.dim(), Vcol.dim()) - if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): - columns_map = Vcol.cell_node_map() - if target_mesh is not source_mesh: - # Since the par_loop is over the target mesh cells we need to - # compose a map that takes us from target mesh cells to the - # function space nodes on the source mesh. - if source_mesh.extruded: - # ExtrudedSet cannot be a map target so we need to build - # this ourselves - columns_map = vom_cell_parent_node_map_extruded(target_mesh, columns_map) - else: - columns_map = compose_map_and_cache(target_mesh.cell_parent_cell_map, - columns_map) - else: - columns_map = Vcol.entity_node_map(target_mesh.topology, "cell", None, None) + Vrow = arguments[0].function_space() + Vcol = arguments[1].function_space() + assert tensor.handle.getSize() == (Vrow.dim(), Vcol.dim()) + rows_map = get_interp_node_map(source_mesh, target_mesh, Vrow) + columns_map = get_interp_node_map(source_mesh, target_mesh, Vcol) + lgmaps = None if bcs: - bc_rows = [bc for bc in bcs if bc.function_space() == V] + if ufl.duals.is_dual(Vrow): + Vrow = Vrow.dual() + if ufl.duals.is_dual(Vcol): + Vcol = Vcol.dual() + bc_rows = [bc for bc in bcs if bc.function_space() == Vrow] bc_cols = [bc for bc in bcs if bc.function_space() == Vcol] - lgmaps = [(V.local_to_global_map(bc_rows), Vcol.local_to_global_map(bc_cols))] - parloop_args.append(tensor(op2.WRITE, (rows_map, columns_map), lgmaps=lgmaps)) + lgmaps = [(Vrow.local_to_global_map(bc_rows), Vcol.local_to_global_map(bc_cols))] + parloop_args.append(tensor(access, (rows_map, columns_map), lgmaps=lgmaps)) if oriented: co = target_mesh.cell_orientations() parloop_args.append(co.dat(op2.READ, co.cell_node_map())) @@ -1105,38 +1168,14 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): parloop_args.append(cs.dat(op2.READ, cs.cell_node_map())) for coefficient in coefficients: - if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): - coeff_mesh = extract_unique_domain(coefficient) - if coeff_mesh is target_mesh or not coeff_mesh: - # NOTE: coeff_mesh is None is allowed e.g. when interpolating from - # a Real space - m_ = coefficient.cell_node_map() - elif coeff_mesh is source_mesh: - if coefficient.cell_node_map(): - # Since the par_loop is over the target mesh cells we need to - # compose a map that takes us from target mesh cells to the - # function space nodes on the source mesh. - if source_mesh.extruded: - # ExtrudedSet cannot be a map target so we need to build - # this ourselves - m_ = vom_cell_parent_node_map_extruded(target_mesh, coefficient.cell_node_map()) - else: - m_ = compose_map_and_cache(target_mesh.cell_parent_cell_map, coefficient.cell_node_map()) - else: - # m_ is allowed to be None when interpolating from a Real space, - # even in the trans-mesh case. - m_ = coefficient.cell_node_map() - else: - raise ValueError("Have coefficient with unexpected mesh") - else: - m_ = coefficient.function_space().entity_node_map(target_mesh.topology, "cell", None, None) + m_ = get_interp_node_map(source_mesh, target_mesh, coefficient.function_space()) parloop_args.append(coefficient.dat(op2.READ, m_)) for const in extract_firedrake_constants(expr): parloop_args.append(const.dat(op2.READ)) # Finally, add the target mesh reference coordinates if they appear in the kernel - if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): + if isinstance(target_mesh.topology, VertexOnlyMeshTopology): if target_mesh is not source_mesh: # NOTE: TSFC will sometimes drop run-time arguments in generated # kernels if they are deemed not-necessary. @@ -1167,6 +1206,41 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): return copyin + callables + (parloop_compute_callable, ) + copyout +def get_interp_node_map(source_mesh, target_mesh, fs): + """Return the map between cells of the target mesh and nodes of the function space. + + If the function space is defined on the source mesh then the node map is composed + with a map between target and source cells. + """ + if isinstance(target_mesh.topology, VertexOnlyMeshTopology): + coeff_mesh = fs.mesh() + m_ = fs.cell_node_map() + if coeff_mesh is target_mesh or not coeff_mesh: + # NOTE: coeff_mesh is None is allowed e.g. when interpolating from + # a Real space + pass + elif coeff_mesh is source_mesh: + if m_: + # Since the par_loop is over the target mesh cells we need to + # compose a map that takes us from target mesh cells to the + # function space nodes on the source mesh. + if source_mesh.extruded: + # ExtrudedSet cannot be a map target so we need to build + # this ourselves + m_ = vom_cell_parent_node_map_extruded(target_mesh, m_) + else: + m_ = compose_map_and_cache(target_mesh.cell_parent_cell_map, m_) + else: + # m_ is allowed to be None when interpolating from a Real space, + # even in the trans-mesh case. + pass + else: + raise ValueError("Have coefficient with unexpected mesh") + else: + m_ = fs.entity_node_map(target_mesh.topology, "cell", None, None) + return m_ + + try: _expr_cachedir = os.environ["FIREDRAKE_TSFC_KERNEL_CACHE_DIR"] except KeyError: @@ -1357,7 +1431,7 @@ def vom_cell_parent_node_map_extruded(vertex_only_mesh, extruded_cell_node_map): the parent extruded mesh. """ - if not isinstance(vertex_only_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology): + if not isinstance(vertex_only_mesh.topology, VertexOnlyMeshTopology): raise TypeError("The input mesh must be a VertexOnlyMesh") cnm = extruded_cell_node_map vmx = vertex_only_mesh diff --git a/firedrake/preconditioners/hiptmair.py b/firedrake/preconditioners/hiptmair.py index 203f3baa9e..14ec77fe1a 100644 --- a/firedrake/preconditioners/hiptmair.py +++ b/firedrake/preconditioners/hiptmair.py @@ -202,7 +202,7 @@ def coarsen(self, pc): coarse_space_bcs = tuple(coarse_space_bcs) if G_callback is None: - interp_petscmat = chop(Interpolator(dminus(test), V, bcs=bcs + coarse_space_bcs).callable().handle) + interp_petscmat = chop(Interpolator(dminus(trial), V, bcs=bcs + coarse_space_bcs).callable().handle) else: interp_petscmat = G_callback(coarse_space, V, coarse_space_bcs, bcs) diff --git a/firedrake/preconditioners/pmg.py b/firedrake/preconditioners/pmg.py index 1cdea965a6..f4b45a67a5 100644 --- a/firedrake/preconditioners/pmg.py +++ b/firedrake/preconditioners/pmg.py @@ -1457,7 +1457,7 @@ def make_kernels(self, Vf, Vc): except KeyError: pass prolong_kernel, _ = prolongation_transfer_kernel_action(Vf, self.uc) - matrix_kernel, coefficients = prolongation_transfer_kernel_action(Vf, firedrake.TestFunction(Vc)) + matrix_kernel, coefficients = prolongation_transfer_kernel_action(Vf, firedrake.TrialFunction(Vc)) # The way we transpose the prolongation kernel is suboptimal. # A local matrix is generated each time the kernel is executed. @@ -1593,7 +1593,7 @@ def prolongation_matrix_aij(P1, Pk, P1_bcs=[], Pk_bcs=[]): for bc in chain(Pk_bcs_i, P1_bcs_i) if bc is not None) matarg = mat[i, i](op2.WRITE, (Pk.sub(i).cell_node_map(), P1.sub(i).cell_node_map()), lgmaps=((rlgmap, clgmap), ), unroll_map=unroll) - expr = firedrake.TestFunction(P1.sub(i)) + expr = firedrake.TrialFunction(P1.sub(i)) kernel, coefficients = prolongation_transfer_kernel_action(Pk.sub(i), expr) parloop_args = [kernel, mesh.cell_set, matarg] for coefficient in coefficients: @@ -1610,7 +1610,7 @@ def prolongation_matrix_aij(P1, Pk, P1_bcs=[], Pk_bcs=[]): for bc in chain(Pk_bcs, P1_bcs) if bc is not None) matarg = mat(op2.WRITE, (Pk.cell_node_map(), P1.cell_node_map()), lgmaps=((rlgmap, clgmap), ), unroll_map=unroll) - expr = firedrake.TestFunction(P1) + expr = firedrake.TrialFunction(P1) kernel, coefficients = prolongation_transfer_kernel_action(Pk, expr) parloop_args = [kernel, mesh.cell_set, matarg] for coefficient in coefficients: diff --git a/tests/firedrake/regression/test_interp_dual.py b/tests/firedrake/regression/test_interp_dual.py index 444f352453..50e29b05cb 100644 --- a/tests/firedrake/regression/test_interp_dual.py +++ b/tests/firedrake/regression/test_interp_dual.py @@ -106,7 +106,11 @@ def test_assemble_interp_adjoint_matrix(V1, V2): # Interpolation from V2* to V1* c1 = Cofunction(V1.dual()).interpolate(c2) # Interpolation matrix (V2* -> V1*) - a = assemble(adjoint(Iv1)) + adj_Iv1 = adjoint(Iv1) + a = assemble(adj_Iv1) + assert a.arguments() == adj_Iv1.arguments() + assert a.petscmat.getSize() == (V1.dim(), V2.dim()) + res = Cofunction(V1.dual()) with c2.dat.vec_ro as x, res.dat.vec_ro as y: a.petscmat.mult(x, y) diff --git a/tests/firedrake/submesh/test_submesh_interpolate.py b/tests/firedrake/submesh/test_submesh_interpolate.py index 0d3805974f..a26c1acb08 100644 --- a/tests/firedrake/submesh/test_submesh_interpolate.py +++ b/tests/firedrake/submesh/test_submesh_interpolate.py @@ -27,14 +27,18 @@ def _get_expr(V): return as_vector([cos(x), sin(y)]) -def _test_submesh_interpolate_cell_cell(mesh, subdomain_cond, fe_fesub): +def make_submesh(mesh, subdomain_cond, label_value): dim = mesh.topological_dimension() - (family, degree), (family_sub, degree_sub) = fe_fesub DG0 = FunctionSpace(mesh, "DG", 0) indicator_function = Function(DG0).interpolate(subdomain_cond) - label_value = 999 mesh.mark_entities(indicator_function, label_value) - subm = Submesh(mesh, dim, label_value) + return Submesh(mesh, dim, label_value) + + +def _test_submesh_interpolate_cell_cell(mesh, subdomain_cond, fe_fesub): + (family, degree), (family_sub, degree_sub) = fe_fesub + label_value = 999 + subm = make_submesh(mesh, subdomain_cond, label_value) V = FunctionSpace(mesh, family, degree) V_ = FunctionSpace(mesh, family_sub, degree_sub) Vsub = FunctionSpace(subm, family_sub, degree_sub) @@ -268,3 +272,77 @@ def expr(m): dg2d_ = Function(DG2d).interpolate(dg3d) error = assemble(inner(dg2d_ - expr(subm), dg2d_ - expr(subm)) * dx)**0.5 assert abs(error) < 1.e-14 + + +@pytest.mark.parallel(nprocs=[1, 3]) +@pytest.mark.parametrize('fe_fesub', [[("DG", 2), ("DG", 1)], + [("CG", 3), ("CG", 2)]]) +def test_submesh_interpolate_adjoint(fe_fesub): + (family, degree), (family_sub, degree_sub) = fe_fesub + + mesh = UnitSquareMesh(8, 8) + x, y = SpatialCoordinate(mesh) + subdomain_cond = conditional(And(LT(x, 0.5), LT(y, 0.5)), 1, 0) + label_value = 999 + subm = make_submesh(mesh, subdomain_cond, label_value) + + V1 = FunctionSpace(subm, family_sub, degree_sub) + V2 = FunctionSpace(mesh, family, degree) + + x, y = SpatialCoordinate(V1.mesh()) + expr = x * y + u1 = Function(V1).interpolate(expr) + ustar2 = assemble(inner(1, TestFunction(V2))*dx(label_value)) + + expected = assemble(inner(1, u1)*dx(label_value)) + + # Test forward 2-form + I = assemble(interpolate(TrialFunction(V1), TestFunction(V2.dual()), allow_missing_dofs=True)) + assert I.arguments()[0].function_space() == V2.dual() + assert I.arguments()[1].function_space() == V1 + + result_forward_2 = assemble(action(ustar2, action(I, u1))) + assert np.isclose(result_forward_2, expected) + + # Test adjoint 2-form + I_adj = assemble(interpolate(TestFunction(V1), TrialFunction(V2.dual()), allow_missing_dofs=True)) + assert I_adj.arguments()[0].function_space() == V1 + assert I_adj.arguments()[1].function_space() == V2.dual() + + result_adjoint_2 = assemble(action(action(I_adj, ustar2), u1)) + assert np.isclose(result_adjoint_2, expected) + + # Test forward 1-form (only works in serial for continuous elements) + # Matfree forward interpolation with Submesh currently fails in parallel. + # The ghost nodes of the parent mesh may be redistributed + # into different processes as non-ghost dofs of the submesh. + # The submesh kernel will write into ghost nodes of the parent mesh, + # but this will be ignored in the halo exchange if access=op2.WRITE. + + # See https://github.com/firedrakeproject/firedrake/issues/4483 + expected_to_pass = (V2.comm.size == 1 or V2.finat_element.is_dg()) + + Iu1 = assemble(interpolate(u1, TestFunction(V2.dual()), allow_missing_dofs=True)) + assert Iu1.function_space() == V2 + + expected_primal = assemble(action(I, u1)) + test1 = np.allclose(Iu1.dat.data, expected_primal.dat.data) + assert test1 or not expected_to_pass + + result_forward_1 = assemble(action(ustar2, Iu1)) + test0 = np.isclose(result_forward_1, expected) + assert test0 or not expected_to_pass + + # Test adjoint 1-form + ustar2I = assemble(interpolate(TestFunction(V1), ustar2, allow_missing_dofs=True)) + assert ustar2I.function_space() == V1.dual() + + expected_dual = assemble(action(I_adj, ustar2)) + assert np.allclose(ustar2I.dat.data, expected_dual.dat.data) + + result_adjoint_1 = assemble(action(ustar2I, u1)) + assert np.isclose(result_adjoint_1, expected) + + # Test 0-form + result_0 = assemble(interpolate(u1, ustar2, allow_missing_dofs=True)) + assert np.isclose(result_0, expected) diff --git a/tsfc/driver.py b/tsfc/driver.py index 87863dd1d6..89db890f24 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -6,7 +6,7 @@ from finat.physically_mapped import DirectlyDefinedElement, PhysicallyMappedElement import ufl -from ufl.algorithms import extract_arguments, extract_coefficients +from ufl.algorithms import extract_coefficients from ufl.algorithms.analysis import has_type from ufl.algorithms.apply_coefficient_split import CoefficientSplitter from ufl.classes import Form, GeometricQuantity @@ -211,11 +211,12 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, if isinstance(to_element, (PhysicallyMappedElement, DirectlyDefinedElement)): raise NotImplementedError("Don't know how to interpolate onto zany spaces, sorry") - orig_expression = expression + orig_coefficients = extract_coefficients(expression) if isinstance(expression, ufl.Interpolate): - operand, = expression.ufl_operands + v, operand = expression.argument_slots() else: operand = expression + v = ufl.FunctionSpace(extract_unique_domain(operand), ufl_element) # Map into reference space operand = apply_mapping(operand, ufl_element, domain) @@ -223,11 +224,8 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # Apply UFL preprocessing operand = ufl_utils.preprocess_expression(operand, complex_mode=complex_mode) - if isinstance(expression, ufl.Interpolate): - v, _ = expression.argument_slots() - expression = ufl.Interpolate(operand, v) - else: - expression = operand + # Reconstructed Interpolate with mapped operand + expression = ufl.Interpolate(operand, v) # Initialise kernel builder if interface is None: @@ -235,9 +233,10 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, from tsfc.kernel_interface.firedrake_loopy import ExpressionKernelBuilder as interface builder = interface(parameters["scalar_type"]) - arguments = extract_arguments(operand) - argument_multiindices = tuple(builder.create_element(arg.ufl_element()).get_indices() - for arg in arguments) + arguments = expression.arguments() + argument_multiindices = {arg.number(): builder.create_element(arg.ufl_element()).get_indices() + for arg in arguments} + assert len(argument_multiindices) == len(arguments) # Replace coordinates (if any) unless otherwise specified by kwarg if domain is None: @@ -246,7 +245,6 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # Collect required coefficients and determine numbering coefficients = extract_coefficients(expression) - orig_coefficients = extract_coefficients(orig_expression) coefficient_numbers = tuple(map(orig_coefficients.index, coefficients)) builder.set_coefficient_numbers(coefficient_numbers) @@ -284,11 +282,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, if isinstance(to_element, finat.QuadratureElement): kernel_cfg["quadrature_rule"] = to_element._rule - if isinstance(expression, ufl.Interpolate): - dual_arg, operand = expression.argument_slots() - else: - operand = expression - dual_arg = None + dual_arg, operand = expression.argument_slots() # Create callable for translation of UFL expression to gem fn = DualEvaluationCallable(operand, kernel_cfg) @@ -307,9 +301,13 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, evaluation = gem.MathFunction('conj', evaluation) evaluation = gem.IndexSum(evaluation * gem_dual[basis_indices], basis_indices) basis_indices = () + else: + argument_multiindices[dual_arg.number()] = basis_indices + + argument_multiindices = dict(sorted(argument_multiindices.items())) # Build kernel body - return_indices = tuple(chain(basis_indices, *argument_multiindices)) + return_indices = tuple(chain.from_iterable(argument_multiindices.values())) return_shape = tuple(i.extent for i in return_indices) return_var = gem.Variable('A', return_shape or (1,)) return_expr = gem.Indexed(return_var, return_indices or (0,)) @@ -318,7 +316,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *, # but we don't for now. evaluation, = impero_utils.preprocess_gem([evaluation]) impero_c = impero_utils.compile_gem([(return_expr, evaluation)], return_indices) - index_names = dict((idx, "p%d" % i) for (i, idx) in enumerate(basis_indices)) + index_names = {idx: f"p{i}" for (i, idx) in enumerate(basis_indices)} # Handle kernel interface requirements builder.register_requirements([evaluation]) builder.set_output(return_var) @@ -381,7 +379,7 @@ def __call__(self, ps): gem_expr, = fem.compile_ufl(self.expression, translation_context, point_sum=False) # In some cases ps.indices may be dropped from expr, but nothing # new should now appear - argument_multiindices = kernel_cfg["argument_multiindices"] + argument_multiindices = kernel_cfg["argument_multiindices"].values() assert set(gem_expr.free_indices) <= set(chain(ps.indices, *argument_multiindices)) return gem_expr