@@ -140,6 +140,68 @@ def destroy(self, pc):
140140 self .asmpc .destroy ()
141141
142142
143+ def get_entity_dofs (V , V_local_ises_indices , points ):
144+ """Extract degrees of freedom associated to mesh entities (points of the DMPlex)."""
145+ indices = []
146+ for (i , W ) in enumerate (V ):
147+ section = W .dm .getDefaultSection ()
148+ for p in points :
149+ dof = section .getDof (p )
150+ if dof <= 0 :
151+ continue
152+ off = section .getOffset (p )
153+ # Local indices within W
154+ W_slice = slice (off * W .block_size , W .block_size * (off + dof ))
155+ indices .extend (V_local_ises_indices [i ][W_slice ])
156+ return indices
157+
158+
159+ def build_star_indices (V , V_local_ises_indices , mesh_dm , ordering , prefix , seed_points ):
160+ """Build index sets for star patches."""
161+ points = []
162+ for seed in seed_points :
163+ # Only build patches over owned DoFs
164+ if mesh_dm .getLabelValue ("pyop2_ghost" , seed ) != - 1 :
165+ continue
166+ # Create point list from mesh DM
167+ star , _ = mesh_dm .getTransitiveClosure (seed , useCone = False )
168+ star = order_points (mesh_dm , star , ordering , prefix )
169+ points .extend (star )
170+
171+ indices = get_entity_dofs (V , V_local_ises_indices , points )
172+ iset = PETSc .IS ().createGeneral (indices , comm = PETSc .COMM_SELF )
173+ return iset
174+
175+
176+ def build_vanka_indices (Z , Z_local_ises_indices , mesh_dm , ordering , prefix , include_star , seed_points ):
177+ """Build index sets for Vanka patches."""
178+ V_points = []
179+ Q_points = []
180+ for seed in seed_points :
181+ # Only build patches over owned DoFs
182+ if mesh_dm .getLabelValue ("pyop2_ghost" , seed ) != - 1 :
183+ continue
184+ # Create point list from mesh DM
185+ star , _ = mesh_dm .getTransitiveClosure (seed , useCone = False )
186+ star = order_points (mesh_dm , star , ordering , prefix )
187+ if include_star :
188+ Q_points .extend (star )
189+ else :
190+ Q_points .append (seed )
191+
192+ closure = []
193+ for s in reversed (star ):
194+ cs , _ = mesh_dm .getTransitiveClosure (s , useCone = True )
195+ closure .extend (cs )
196+ # Grab unique points with stable ordering
197+ V_points .extend (reversed (dict .fromkeys (closure )))
198+
199+ indices = get_entity_dofs (Z [0 ], Z_local_ises_indices [0 ], V_points )
200+ indices .extend (get_entity_dofs (Z [1 ], Z_local_ises_indices [1 ], Q_points ))
201+ iset = PETSc .IS ().createGeneral (indices , comm = PETSc .COMM_SELF )
202+ return iset
203+
204+
143205class ASMStarPC (ASMPatchPC ):
144206 '''Patch-based PC using Star of mesh entities implmented as an
145207 :class:`ASMPatchPC`.
@@ -162,42 +224,25 @@ def get_patches(self, V):
162224 warning ("applying ASMStarPC on an extruded mesh" )
163225
164226 # Obtain the topological entities to use to construct the stars
165- opts = PETSc .Options (self .prefix )
227+ prefix = self .prefix
228+ opts = PETSc .Options (prefix )
166229 depth = opts .getInt ("construct_dim" , default = 0 )
230+ coloring = opts .getBool ("coloring" , default = False )
167231 ordering = opts .getString ("mat_ordering_type" , default = "natural" )
168232 validate_overlap (mesh_unique , depth , "star" )
169233
170234 # Accessing .indices causes the allocation of a global array,
171235 # so we need to cache these for efficiency
172236 V_local_ises_indices = tuple (iset .indices for iset in V .dof_dset .local_ises )
173237
174- # Build index sets for the patches
175- ises = []
176238 (start , end ) = mesh_dm .getDepthStratum (depth )
177- for seed in range (start , end ):
178- # Only build patches over owned DoFs
179- if mesh_dm .getLabelValue ("pyop2_ghost" , seed ) != - 1 :
180- continue
181-
182- # Create point list from mesh DM
183- pt_array , _ = mesh_dm .getTransitiveClosure (seed , useCone = False )
184- pt_array = order_points (mesh_dm , pt_array , ordering , self .prefix )
185-
186- # Get DoF indices for patch
187- indices = []
188- for (i , W ) in enumerate (V ):
189- section = W .dm .getDefaultSection ()
190- for p in pt_array .tolist ():
191- dof = section .getDof (p )
192- if dof <= 0 :
193- continue
194- off = section .getOffset (p )
195- # Local indices within W
196- W_indices = slice (off * W .block_size , W .block_size * (off + dof ))
197- indices .extend (V_local_ises_indices [i ][W_indices ])
198- iset = PETSc .IS ().createGeneral (indices , comm = PETSc .COMM_SELF )
199- ises .append (iset )
200-
239+ if coloring :
240+ colors = mesh_dm .createColoring (depth = depth , distance = 1 )
241+ ises = [build_star_indices (V , V_local_ises_indices , mesh_dm , ordering , prefix , color .indices + start )
242+ for color in colors ]
243+ else :
244+ ises = [build_star_indices (V , V_local_ises_indices , mesh_dm , ordering , prefix , (seed ,))
245+ for seed in range (start , end )]
201246 return ises
202247
203248
@@ -223,25 +268,30 @@ def get_patches(self, V):
223268 warning ("applying ASMVankaPC on an extruded mesh" )
224269
225270 # Obtain the topological entities to use to construct the stars
271+ prefix = self .prefix
226272 opts = PETSc .Options (self .prefix )
227273 depth = opts .getInt ("construct_dim" , default = - 1 )
228274 height = opts .getInt ("construct_codim" , default = - 1 )
229275 if (depth == - 1 and height == - 1 ) or (depth != - 1 and height != - 1 ):
230276 raise ValueError (f"Must set exactly one of { self .prefix } construct_dim or { self .prefix } construct_codim" )
231277
232278 exclude_subspaces = list (map (int , opts .getString ("exclude_subspaces" , default = "-1" ).split ("," )))
279+ include_subspaces = [i for i in range (len (V )) if i not in exclude_subspaces ]
233280 include_type = opts .getString ("include_type" , default = "star" ).lower ()
234281 if include_type not in ["star" , "entity" ]:
235282 raise ValueError (f"{ self .prefix } include_type must be either 'star' or 'entity', not { include_type } " )
236283 include_star = include_type == "star"
237284
238- ordering = opts .getString ("mat_ordering_type" , default = "natural" )
285+ def splitting (V ):
286+ return (tuple (V [i ] for i in include_subspaces ), tuple (V [i ] for i in exclude_subspaces ))
287+
288+ Z = splitting (V )
239289 # Accessing .indices causes the allocation of a global array,
240290 # so we need to cache these for efficiency
241291 V_local_ises_indices = tuple (iset .indices for iset in V .dof_dset .local_ises )
292+ Z_local_ises_indices = splitting (V_local_ises_indices )
242293
243294 # Build index sets for the patches
244- ises = []
245295 if depth != - 1 :
246296 (start , end ) = mesh_dm .getDepthStratum (depth )
247297 patch_dim = depth
@@ -250,40 +300,16 @@ def get_patches(self, V):
250300 patch_dim = mesh_dm .getDimension () - height
251301 validate_overlap (mesh_unique , patch_dim , "vanka" )
252302
253- for seed in range (start , end ):
254- # Only build patches over owned DoFs
255- if mesh_dm .getLabelValue ("pyop2_ghost" , seed ) != - 1 :
256- continue
257-
258- # Create point list from mesh DM
259- star , _ = mesh_dm .getTransitiveClosure (seed , useCone = False )
260- star = order_points (mesh_dm , star , ordering , self .prefix )
261- pt_array = []
262- for pt in reversed (star ):
263- closure , _ = mesh_dm .getTransitiveClosure (pt , useCone = True )
264- pt_array .extend (closure )
265- # Grab unique points with stable ordering
266- pt_array = list (reversed (dict .fromkeys (pt_array )))
267-
268- # Get DoF indices for patch
269- indices = []
270- for (i , W ) in enumerate (V ):
271- section = W .dm .getDefaultSection ()
272- if i in exclude_subspaces :
273- loop_list = star if include_star else [seed ]
274- else :
275- loop_list = pt_array
276- for p in loop_list :
277- dof = section .getDof (p )
278- if dof <= 0 :
279- continue
280- off = section .getOffset (p )
281- # Local indices within W
282- W_indices = slice (off * W .block_size , W .block_size * (off + dof ))
283- indices .extend (V_local_ises_indices [i ][W_indices ])
284- iset = PETSc .IS ().createGeneral (indices , comm = PETSc .COMM_SELF )
285- ises .append (iset )
286-
303+ coloring = opts .getBool ("coloring" , default = False )
304+ ordering = opts .getString ("mat_ordering_type" , default = "natural" )
305+ if coloring :
306+ colors = mesh_dm .createColoring (depth = patch_dim , distance = 2 )
307+ ises = [build_vanka_indices (Z , Z_local_ises_indices , mesh_dm , ordering , prefix ,
308+ include_star , color .indices + start )
309+ for color in colors ]
310+ else :
311+ ises = [build_vanka_indices (Z , Z_local_ises_indices , mesh_dm , ordering , prefix , include_star , (seed ,))
312+ for seed in range (start , end )]
287313 return ises
288314
289315
0 commit comments