Skip to content

Commit 70f5212

Browse files
committed
ASMStarPC: group patches by mesh coloring
1 parent 2b90992 commit 70f5212

File tree

1 file changed

+89
-63
lines changed
  • firedrake/preconditioners

1 file changed

+89
-63
lines changed

firedrake/preconditioners/asm.py

Lines changed: 89 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,68 @@ def destroy(self, pc):
140140
self.asmpc.destroy()
141141

142142

143+
def get_entity_dofs(V, V_local_ises_indices, points):
144+
"""Extract degrees of freedom associated to mesh entities (points of the DMPlex)."""
145+
indices = []
146+
for (i, W) in enumerate(V):
147+
section = W.dm.getDefaultSection()
148+
for p in points:
149+
dof = section.getDof(p)
150+
if dof <= 0:
151+
continue
152+
off = section.getOffset(p)
153+
# Local indices within W
154+
W_slice = slice(off*W.block_size, W.block_size * (off + dof))
155+
indices.extend(V_local_ises_indices[i][W_slice])
156+
return indices
157+
158+
159+
def build_star_indices(V, V_local_ises_indices, mesh_dm, ordering, prefix, seed_points):
160+
"""Build index sets for star patches."""
161+
points = []
162+
for seed in seed_points:
163+
# Only build patches over owned DoFs
164+
if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1:
165+
continue
166+
# Create point list from mesh DM
167+
star, _ = mesh_dm.getTransitiveClosure(seed, useCone=False)
168+
star = order_points(mesh_dm, star, ordering, prefix)
169+
points.extend(star)
170+
171+
indices = get_entity_dofs(V, V_local_ises_indices, points)
172+
iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF)
173+
return iset
174+
175+
176+
def build_vanka_indices(Z, Z_local_ises_indices, mesh_dm, ordering, prefix, include_star, seed_points):
177+
"""Build index sets for Vanka patches."""
178+
V_points = []
179+
Q_points = []
180+
for seed in seed_points:
181+
# Only build patches over owned DoFs
182+
if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1:
183+
continue
184+
# Create point list from mesh DM
185+
star, _ = mesh_dm.getTransitiveClosure(seed, useCone=False)
186+
star = order_points(mesh_dm, star, ordering, prefix)
187+
if include_star:
188+
Q_points.extend(star)
189+
else:
190+
Q_points.append(seed)
191+
192+
closure = []
193+
for s in reversed(star):
194+
cs, _ = mesh_dm.getTransitiveClosure(s, useCone=True)
195+
closure.extend(cs)
196+
# Grab unique points with stable ordering
197+
V_points.extend(reversed(dict.fromkeys(closure)))
198+
199+
indices = get_entity_dofs(Z[0], Z_local_ises_indices[0], V_points)
200+
indices.extend(get_entity_dofs(Z[1], Z_local_ises_indices[1], Q_points))
201+
iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF)
202+
return iset
203+
204+
143205
class ASMStarPC(ASMPatchPC):
144206
'''Patch-based PC using Star of mesh entities implmented as an
145207
:class:`ASMPatchPC`.
@@ -162,42 +224,25 @@ def get_patches(self, V):
162224
warning("applying ASMStarPC on an extruded mesh")
163225

164226
# Obtain the topological entities to use to construct the stars
165-
opts = PETSc.Options(self.prefix)
227+
prefix = self.prefix
228+
opts = PETSc.Options(prefix)
166229
depth = opts.getInt("construct_dim", default=0)
230+
coloring = opts.getBool("coloring", default=False)
167231
ordering = opts.getString("mat_ordering_type", default="natural")
168232
validate_overlap(mesh_unique, depth, "star")
169233

170234
# Accessing .indices causes the allocation of a global array,
171235
# so we need to cache these for efficiency
172236
V_local_ises_indices = tuple(iset.indices for iset in V.dof_dset.local_ises)
173237

174-
# Build index sets for the patches
175-
ises = []
176238
(start, end) = mesh_dm.getDepthStratum(depth)
177-
for seed in range(start, end):
178-
# Only build patches over owned DoFs
179-
if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1:
180-
continue
181-
182-
# Create point list from mesh DM
183-
pt_array, _ = mesh_dm.getTransitiveClosure(seed, useCone=False)
184-
pt_array = order_points(mesh_dm, pt_array, ordering, self.prefix)
185-
186-
# Get DoF indices for patch
187-
indices = []
188-
for (i, W) in enumerate(V):
189-
section = W.dm.getDefaultSection()
190-
for p in pt_array.tolist():
191-
dof = section.getDof(p)
192-
if dof <= 0:
193-
continue
194-
off = section.getOffset(p)
195-
# Local indices within W
196-
W_indices = slice(off*W.block_size, W.block_size * (off + dof))
197-
indices.extend(V_local_ises_indices[i][W_indices])
198-
iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF)
199-
ises.append(iset)
200-
239+
if coloring:
240+
colors = mesh_dm.createColoring(depth=depth, distance=1)
241+
ises = [build_star_indices(V, V_local_ises_indices, mesh_dm, ordering, prefix, color.indices+start)
242+
for color in colors]
243+
else:
244+
ises = [build_star_indices(V, V_local_ises_indices, mesh_dm, ordering, prefix, (seed,))
245+
for seed in range(start, end)]
201246
return ises
202247

203248

@@ -223,25 +268,30 @@ def get_patches(self, V):
223268
warning("applying ASMVankaPC on an extruded mesh")
224269

225270
# Obtain the topological entities to use to construct the stars
271+
prefix = self.prefix
226272
opts = PETSc.Options(self.prefix)
227273
depth = opts.getInt("construct_dim", default=-1)
228274
height = opts.getInt("construct_codim", default=-1)
229275
if (depth == -1 and height == -1) or (depth != -1 and height != -1):
230276
raise ValueError(f"Must set exactly one of {self.prefix}construct_dim or {self.prefix}construct_codim")
231277

232278
exclude_subspaces = list(map(int, opts.getString("exclude_subspaces", default="-1").split(",")))
279+
include_subspaces = [i for i in range(len(V)) if i not in exclude_subspaces]
233280
include_type = opts.getString("include_type", default="star").lower()
234281
if include_type not in ["star", "entity"]:
235282
raise ValueError(f"{self.prefix}include_type must be either 'star' or 'entity', not {include_type}")
236283
include_star = include_type == "star"
237284

238-
ordering = opts.getString("mat_ordering_type", default="natural")
285+
def splitting(V):
286+
return (tuple(V[i] for i in include_subspaces), tuple(V[i] for i in exclude_subspaces))
287+
288+
Z = splitting(V)
239289
# Accessing .indices causes the allocation of a global array,
240290
# so we need to cache these for efficiency
241291
V_local_ises_indices = tuple(iset.indices for iset in V.dof_dset.local_ises)
292+
Z_local_ises_indices = splitting(V_local_ises_indices)
242293

243294
# Build index sets for the patches
244-
ises = []
245295
if depth != -1:
246296
(start, end) = mesh_dm.getDepthStratum(depth)
247297
patch_dim = depth
@@ -250,40 +300,16 @@ def get_patches(self, V):
250300
patch_dim = mesh_dm.getDimension() - height
251301
validate_overlap(mesh_unique, patch_dim, "vanka")
252302

253-
for seed in range(start, end):
254-
# Only build patches over owned DoFs
255-
if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1:
256-
continue
257-
258-
# Create point list from mesh DM
259-
star, _ = mesh_dm.getTransitiveClosure(seed, useCone=False)
260-
star = order_points(mesh_dm, star, ordering, self.prefix)
261-
pt_array = []
262-
for pt in reversed(star):
263-
closure, _ = mesh_dm.getTransitiveClosure(pt, useCone=True)
264-
pt_array.extend(closure)
265-
# Grab unique points with stable ordering
266-
pt_array = list(reversed(dict.fromkeys(pt_array)))
267-
268-
# Get DoF indices for patch
269-
indices = []
270-
for (i, W) in enumerate(V):
271-
section = W.dm.getDefaultSection()
272-
if i in exclude_subspaces:
273-
loop_list = star if include_star else [seed]
274-
else:
275-
loop_list = pt_array
276-
for p in loop_list:
277-
dof = section.getDof(p)
278-
if dof <= 0:
279-
continue
280-
off = section.getOffset(p)
281-
# Local indices within W
282-
W_indices = slice(off*W.block_size, W.block_size * (off + dof))
283-
indices.extend(V_local_ises_indices[i][W_indices])
284-
iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF)
285-
ises.append(iset)
286-
303+
coloring = opts.getBool("coloring", default=False)
304+
ordering = opts.getString("mat_ordering_type", default="natural")
305+
if coloring:
306+
colors = mesh_dm.createColoring(depth=patch_dim, distance=2)
307+
ises = [build_vanka_indices(Z, Z_local_ises_indices, mesh_dm, ordering, prefix,
308+
include_star, color.indices+start)
309+
for color in colors]
310+
else:
311+
ises = [build_vanka_indices(Z, Z_local_ises_indices, mesh_dm, ordering, prefix, include_star, (seed,))
312+
for seed in range(start, end)]
287313
return ises
288314

289315

0 commit comments

Comments
 (0)