diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 4a3ae137e8..34cd4dac7f 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -386,6 +386,16 @@ def allocation_integral_types(self): else: return self._allocation_integral_types + @staticmethod + def _as_pyop2_type(tensor, indices=None): + if isinstance(tensor, (firedrake.Cofunction, firedrake.Function)): + return OneFormAssembler._as_pyop2_type(tensor, indices=indices) + elif isinstance(tensor, ufl.Matrix): + return ExplicitMatrixAssembler._as_pyop2_type(tensor, indices=indices) + else: + assert indices is None + return tensor + def assemble(self, tensor=None, current_state=None): """Assemble the form. @@ -499,9 +509,6 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args): raise TypeError("Mismatching weights and operands in FormSum") if len(args) == 0: raise TypeError("Empty FormSum") - if tensor: - tensor.zero() - # Assemble weights weights = [] for w in expr.weights(): @@ -519,24 +526,51 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args): raise ValueError("Expecting a scalar weight expression") weights.append(w) + # Scalar FormSum if all(isinstance(op, numbers.Complex) for op in args): - result = sum(weight * arg for weight, arg in zip(weights, args)) - return tensor.assign(result) if tensor else result - elif (all(isinstance(op, firedrake.Cofunction) for op in args) + result = numpy.dot(weights, args) + return tensor.assign(result) if tensor else result.item() + + # Accumulate coefficients in a dictionary for each unique Dat/Mat + terms = defaultdict(PETSc.ScalarType) + for arg, weight in zip(args, weights): + t = self._as_pyop2_type(arg) + terms[t] += weight + + # Zero the output tensor, or rescale it if it appears in the sum + tensor_scale = terms.pop(self._as_pyop2_type(tensor), 0) + if tensor is None or tensor_scale == 1: + pass + elif tensor_scale == 0: + tensor.zero() + elif isinstance(tensor, (firedrake.Cofunction, firedrake.Function)): + with tensor.dat.vec as v: + v.scale(tensor_scale) + elif isinstance(tensor, ufl.Matrix): + tensor.petscmat.scale(tensor_scale) + else: + raise ValueError("Expecting tensor to be None, Function, Cofunction, or Matrix") + + # Compute the linear combination + if (all(isinstance(op, firedrake.Cofunction) for op in args) or all(isinstance(op, firedrake.Function) for op in args)): + # Vector FormSum V, = set(a.function_space() for a in args) result = tensor if tensor else firedrake.Function(V) - result.dat.maxpy(weights, [a.dat for a in args]) + weights = terms.values() + dats = terms.keys() + result.dat.maxpy(weights, dats) return result elif all(isinstance(op, ufl.Matrix) for op in args): + # Matrix FormSum result = tensor.petscmat if tensor else PETSc.Mat() - for (op, w) in zip(args, weights): + for (op, w) in terms.items(): if result: # If result is not void, then accumulate on it - result.axpy(w, op.petscmat) + result.axpy(w, op.handle) else: # If result is void, then allocate it with first term - op.petscmat.copy(result=result) + op.handle.copy(result=result) result.scale(w) if tensor is None: tensor = self.assembled_matrix(expr, bcs, result) diff --git a/firedrake/matrix.py b/firedrake/matrix.py index 670b38c202..2f33841289 100644 --- a/firedrake/matrix.py +++ b/firedrake/matrix.py @@ -5,7 +5,12 @@ from pyop2.mpi import internal_comm from pyop2.utils import as_tuple from firedrake.petsc import PETSc -from types import SimpleNamespace + + +class DummyOP2Mat: + """A hashable implementation of M.handle""" + def __init__(self, handle): + self.handle = handle class MatrixBase(ufl.Matrix): @@ -240,8 +245,8 @@ def __init__(self, a, bcs, petscmat, *args, **kwargs): if options_prefix is not None: self.petscmat.setOptionsPrefix(options_prefix) - # this allows call to self.M.handle without a new class - self.M = SimpleNamespace(handle=self.mat()) + # this mimics op2.Mat.handle + self.M = DummyOP2Mat(self.mat()) def mat(self): return self.petscmat diff --git a/tests/firedrake/regression/test_assemble_baseform.py b/tests/firedrake/regression/test_assemble_baseform.py index 64093ce525..9644cbe9d7 100644 --- a/tests/firedrake/regression/test_assemble_baseform.py +++ b/tests/firedrake/regression/test_assemble_baseform.py @@ -157,7 +157,7 @@ def test_vector_formsum(a): formsum = res + a res2 = assemble(formsum) - assert isinstance(formsum, ufl.form.FormSum) + assert isinstance(formsum, ufl.FormSum) assert isinstance(res2, Cofunction) assert isinstance(preassemble, Cofunction) for f, f2 in zip(preassemble.subfunctions, res2.subfunctions): @@ -183,6 +183,37 @@ def test_matrix_formsum(M): assert np.allclose(sumfirst.petscmat[:, :], res2.petscmat[:, :], rtol=1E-14) +def test_formsum_vector_self(a): + operand = assemble(a) + tensor = assemble(a) + + w = (42, 3.1416, 666) + formsum = w[0] * tensor + w[1] * operand + w[2] * tensor + assert isinstance(formsum, ufl.FormSum) + + result = assemble(formsum, tensor=tensor) + assert result is tensor + + expected = assemble(Constant(sum(w)) * a) + for f, f2 in zip(expected.subfunctions, result.subfunctions): + assert np.allclose(f.dat.data, f2.dat.data, atol=1e-12) + + +def test_formsum_matrix_self(M): + operand = assemble(M) + tensor = assemble(M) + + w = (42, 3.1416, 666) + formsum = w[0] * tensor + w[1] * operand + w[2] * tensor + assert isinstance(formsum, ufl.FormSum) + + result = assemble(formsum, tensor=tensor) + assert result is tensor + + expected = assemble(Constant(sum(w)) * M) + assert np.allclose(expected.petscmat[:, :], result.petscmat[:, :], rtol=1E-14) + + def test_zero_form(M, f, one): zero_form = assemble(action(action(M, f), one)) assert isinstance(zero_form, ScalarType.type)