Skip to content

Commit da256f3

Browse files
committed
add a test
1 parent fa2f43c commit da256f3

File tree

3 files changed

+35
-19
lines changed

3 files changed

+35
-19
lines changed

firedrake/ufl_expr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def action(form, coefficient, derivatives_expanded=None):
301301
if isinstance(form, firedrake.slate.TensorBase):
302302
if form.rank == 0:
303303
raise ValueError("Can't take action of rank-0 tensor")
304-
return form * firedrake.AssembledVector(coefficient)
304+
return form * coefficient
305305
else:
306306
return ufl.action(form, coefficient, derivatives_expanded=derivatives_expanded)
307307

tests/firedrake/slate/test_assemble_tensors.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -130,24 +130,6 @@ def test_assemble_matrix(rank_two_tensor):
130130
assert np.allclose(M.M.values, assemble(rank_two_tensor.form).M.values, rtol=1e-14)
131131

132132

133-
def test_assemble_solve(mesh):
134-
V = FunctionSpace(mesh, "DG", 0)
135-
u = TrialFunction(V)
136-
v = TestFunction(V)
137-
138-
M = inner(u, v)*dx
139-
f = Cofunction(V.dual())
140-
f.dat.data[...] = 1
141-
142-
u1 = Function(V)
143-
u2 = Function(V)
144-
# Assemble a SLATE tensor into u1
145-
assemble(Inverse(Tensor(M)) * AssembledVector(f), tensor=u1)
146-
# Solve in the usual way
147-
solve(M == f, u2)
148-
assert np.allclose(u1.dat.data, u2.dat.data, rtol=1e-14)
149-
150-
151133
def test_assemble_vector_into_tensor(mesh):
152134
V = FunctionSpace(mesh, "DG", 1)
153135
v = TestFunction(V)

tests/firedrake/slate/test_linear_algebra.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,37 @@ def test_local_solve(decomp):
116116
x = assemble(A.solve(b, decomposition=decomp))
117117

118118
assert np.allclose(x.dat.data, f.dat.data, rtol=1.e-13)
119+
120+
121+
@pytest.mark.parametrize("mat_type, rhs_type", [
122+
("slate", "slate"), ("slate", "form"), ("slate", "cofunction"),
123+
("aij", "cofunction"), ("aij", "form"),
124+
("matfree", "cofunction"), ("matfree", "form")])
125+
def test_inverse_action(mat_type, rhs_type):
126+
"""Test combined UFL/SLATE expressions
127+
"""
128+
mesh = UnitSquareMesh(3, 3)
129+
V = FunctionSpace(mesh, "DG", 1)
130+
u = TrialFunction(V)
131+
v = TestFunction(V)
132+
133+
A = Tensor(inner(u, v)*dx)
134+
if mat_type == "slate":
135+
Ainv = A.inv
136+
else:
137+
Ainv = assemble(A.inv, mat_type=mat_type)
138+
139+
f = Function(V).assign(1.0)
140+
L = inner(f, v)*dx
141+
if rhs_type == "form":
142+
b = L
143+
elif rhs_type == "cofunction":
144+
b = assemble(L)
145+
elif rhs_type == "slate":
146+
b = Tensor(L)
147+
else:
148+
raise ValueError("Invalid rhs type")
149+
150+
x = Function(V)
151+
assemble(action(Ainv, b), tensor=x)
152+
assert np.allclose(x.dat.data, f.dat.data, rtol=1.e-13)

0 commit comments

Comments
 (0)