Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,16 @@ def allocation_integral_types(self):
else:
return self._allocation_integral_types

@staticmethod
def _as_pyop2_type(tensor, indices=None):
if isinstance(tensor, (firedrake.Cofunction, firedrake.Function)):
return OneFormAssembler._as_pyop2_type(tensor, indices=indices)
elif isinstance(tensor, ufl.Matrix):
return ExplicitMatrixAssembler._as_pyop2_type(tensor, indices=indices)
else:
assert indices is None
return tensor

def assemble(self, tensor=None, current_state=None):
"""Assemble the form.

Expand Down Expand Up @@ -499,9 +509,6 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
raise TypeError("Mismatching weights and operands in FormSum")
if len(args) == 0:
raise TypeError("Empty FormSum")
if tensor:
tensor.zero()

# Assemble weights
weights = []
for w in expr.weights():
Expand All @@ -519,24 +526,51 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
raise ValueError("Expecting a scalar weight expression")
weights.append(w)

# Scalar FormSum
if all(isinstance(op, numbers.Complex) for op in args):
result = sum(weight * arg for weight, arg in zip(weights, args))
return tensor.assign(result) if tensor else result
elif (all(isinstance(op, firedrake.Cofunction) for op in args)
result = numpy.dot(weights, args)
return tensor.assign(result) if tensor else result.item()

# Accumulate coefficients in a dictionary for each unique Dat/Mat
terms = defaultdict(PETSc.ScalarType)
for arg, weight in zip(args, weights):
t = self._as_pyop2_type(arg)
terms[t] += weight

# Zero the output tensor, or rescale it if it appears in the sum
tensor_scale = terms.pop(self._as_pyop2_type(tensor), 0)
if tensor is None or tensor_scale == 1:
pass
elif tensor_scale == 0:
tensor.zero()
elif isinstance(tensor, (firedrake.Cofunction, firedrake.Function)):
with tensor.dat.vec as v:
v.scale(tensor_scale)
elif isinstance(tensor, ufl.Matrix):
tensor.petscmat.scale(tensor_scale)
else:
raise ValueError("Expecting tensor to be None, Function, Cofunction, or Matrix")

# Compute the linear combination
if (all(isinstance(op, firedrake.Cofunction) for op in args)
or all(isinstance(op, firedrake.Function) for op in args)):
# Vector FormSum
V, = set(a.function_space() for a in args)
result = tensor if tensor else firedrake.Function(V)
result.dat.maxpy(weights, [a.dat for a in args])
weights = terms.values()
dats = terms.keys()
result.dat.maxpy(weights, dats)
return result
elif all(isinstance(op, ufl.Matrix) for op in args):
# Matrix FormSum
result = tensor.petscmat if tensor else PETSc.Mat()
for (op, w) in zip(args, weights):
for (op, w) in terms.items():
if result:
# If result is not void, then accumulate on it
result.axpy(w, op.petscmat)
result.axpy(w, op.handle)
else:
# If result is void, then allocate it with first term
op.petscmat.copy(result=result)
op.handle.copy(result=result)
result.scale(w)
if tensor is None:
tensor = self.assembled_matrix(expr, bcs, result)
Expand Down
11 changes: 8 additions & 3 deletions firedrake/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from pyop2.mpi import internal_comm
from pyop2.utils import as_tuple
from firedrake.petsc import PETSc
from types import SimpleNamespace


class DummyOP2Mat:
"""A hashable implementation of M.handle"""
def __init__(self, handle):
self.handle = handle


class MatrixBase(ufl.Matrix):
Expand Down Expand Up @@ -240,8 +245,8 @@ def __init__(self, a, bcs, petscmat, *args, **kwargs):
if options_prefix is not None:
self.petscmat.setOptionsPrefix(options_prefix)

# this allows call to self.M.handle without a new class
self.M = SimpleNamespace(handle=self.mat())
# this mimics op2.Mat.handle
self.M = DummyOP2Mat(self.mat())

def mat(self):
return self.petscmat
33 changes: 32 additions & 1 deletion tests/firedrake/regression/test_assemble_baseform.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_vector_formsum(a):
formsum = res + a
res2 = assemble(formsum)

assert isinstance(formsum, ufl.form.FormSum)
assert isinstance(formsum, ufl.FormSum)
assert isinstance(res2, Cofunction)
assert isinstance(preassemble, Cofunction)
for f, f2 in zip(preassemble.subfunctions, res2.subfunctions):
Expand All @@ -183,6 +183,37 @@ def test_matrix_formsum(M):
assert np.allclose(sumfirst.petscmat[:, :], res2.petscmat[:, :], rtol=1E-14)


def test_formsum_vector_self(a):
operand = assemble(a)
tensor = assemble(a)

w = (42, 3.1416, 666)
formsum = w[0] * tensor + w[1] * operand + w[2] * tensor
assert isinstance(formsum, ufl.FormSum)

result = assemble(formsum, tensor=tensor)
assert result is tensor

expected = assemble(Constant(sum(w)) * a)
for f, f2 in zip(expected.subfunctions, result.subfunctions):
assert np.allclose(f.dat.data, f2.dat.data, atol=1e-12)


def test_formsum_matrix_self(M):
operand = assemble(M)
tensor = assemble(M)

w = (42, 3.1416, 666)
formsum = w[0] * tensor + w[1] * operand + w[2] * tensor
assert isinstance(formsum, ufl.FormSum)

result = assemble(formsum, tensor=tensor)
assert result is tensor

expected = assemble(Constant(sum(w)) * M)
assert np.allclose(expected.petscmat[:, :], result.petscmat[:, :], rtol=1E-14)


def test_zero_form(M, f, one):
zero_form = assemble(action(action(M, f), one))
assert isinstance(zero_form, ScalarType.type)
Expand Down
Loading