Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 4 additions & 62 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
result = expr.assemble(assembly_opts=opts)
return tensor.assign(result) if tensor else result
elif isinstance(expr, ufl.Interpolate):
orig_expr = expr
# Replace assembled children
_, operand = expr.argument_slots()
v, *assembled_operand = args
Expand All @@ -568,13 +567,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
if (v, operand) != expr.argument_slots():
expr = reconstruct_interp(operand, v=v)

# Different assembly procedures:
# 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Jacobian (Interpolate matrix)
# 2) Interpolate(Coefficient(...), Argument(V2.dual(), 0)) -> Operator (or Jacobian action)
# 3) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Jacobian adjoint
# 4) Interpolate(Argument(V1, 0), Cofunction(...)) -> Action of the Jacobian adjoint
# This can be generalized to the case where the first slot is an arbitray expression.
rank = len(expr.arguments())
if rank > 2:
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
# If argument numbers have been swapped => Adjoint.
arg_operand = ufl.algorithms.extract_arguments(operand)
is_adjoint = (arg_operand and arg_operand[0].number() == 0)
Expand Down Expand Up @@ -605,67 +600,14 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
assemble(sub_interp, tensor=tensor.subfunctions[i])
return tensor

# Workaround: Renumber argument when needed since Interpolator assumes it takes a zero-numbered argument.
if not is_adjoint and rank == 2:
v0, v1 = expr.arguments()
expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()),
v1: v1.reconstruct(number=v0.number())})
v, operand = expr.argument_slots()

# Matrix-free adjoint interpolation is only implemented by SameMeshInterpolator
# so we need assemble the interpolator matrix if the meshes are different
target_mesh = V.mesh()
source_mesh = extract_unique_domain(operand) or target_mesh
if is_adjoint and rank < 2 and source_mesh is not target_mesh:
expr = reconstruct_interp(operand, v=V)
matfree = (rank == len(expr.arguments())) and (rank < 2)

# Get the interpolator
interp_data = expr.interp_data.copy()
default_missing_val = interp_data.pop('default_missing_val', None)
if matfree and ((is_adjoint and rank == 1) or rank == 0):
# Adjoint interpolation of a Cofunction or the action of a
# Cofunction on an interpolated Function require INC access
# on the output tensor
interp_data["access"] = op2.INC

if rank == 1 and matfree and isinstance(tensor, firedrake.Function):
if rank == 1 and isinstance(tensor, firedrake.Function):
V = tensor
interpolator = firedrake.Interpolator(expr, V, **interp_data)

# Assembly
if matfree:
# Assembling the operator
return interpolator._interpolate(output=tensor, default_missing_val=default_missing_val)
elif rank == 0:
# Assembling the double action.
Iu = interpolator._interpolate(default_missing_val=default_missing_val)
return assemble(ufl.Action(v, Iu), tensor=tensor)
elif rank == 1:
# Assembling the action of the Jacobian adjoint.
if is_adjoint:
return interpolator._interpolate(v, output=tensor, adjoint=True, default_missing_val=default_missing_val)
# Assembling the Jacobian action.
else:
return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val)
elif rank == 2:
res = tensor.petscmat if tensor else PETSc.Mat()
# Get the interpolation matrix
op2_mat = interpolator.callable()
petsc_mat = op2_mat.handle
if is_adjoint:
# Out-of-place Hermitian transpose
petsc_mat.hermitianTranspose(out=res)
elif res:
# Copy the interpolation matrix into the output tensor
petsc_mat.copy(result=res)
else:
res = petsc_mat
if tensor is None:
tensor = self.assembled_matrix(orig_expr, res)
return tensor
else:
raise ValueError("Incompatible number of arguments.")
return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val)
elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)):
return tensor.assign(expr)
elif tensor and isinstance(expr, ufl.ZeroBaseForm):
Expand Down
Loading
Loading