1212 Iterable ,
1313 Iterator ,
1414 List ,
15- Mapping ,
1615 Optional ,
1716 Protocol ,
1817 Sequence ,
@@ -151,7 +150,6 @@ def __init__(
151150 cpu : bool = False ,
152151 convert_module_fn : Optional [Callable [[ModelT ], nn .Module ]] = None ,
153152 ) -> None :
154- from apex .transformer .tensor_parallel .layers import set_defaults_if_not_set_tensor_model_parallel_attributes
155153 from megatron .core import parallel_state
156154
157155 _pipeline : List [nn .Module ]
@@ -174,67 +172,15 @@ def __init__(
174172 _model .configure_model ()
175173 _pipeline .append (_model )
176174
177- if convert_module_fn :
178- for i in range (len (_pipeline )):
179- _pipeline [i ] = convert_module_fn (_pipeline [i ])
180-
181- if isinstance (ddp_config , DistributedDataParallelConfig ):
182- for model_chunk_idx , model_chunk in enumerate (_pipeline ):
183- module = model_chunk .module
184-
185- ddp = DDP (
186- module .config ,
187- ddp_config ,
188- module ,
189- data_parallel_group = parallel_state .get_data_parallel_group (with_context_parallel = True ),
190- expert_data_parallel_group = parallel_state .get_data_modulo_expert_parallel_group (),
191- # Turn off bucketing for model_chunk 2 onwards, since communication for these
192- # model chunks is overlapped with compute anyway.
193- disable_bucketing = (model_chunk_idx > 0 ),
194- )
195- model_chunk .module = ddp
196- model_chunk .buffers = ddp .buffers # We need to do this explicitly since this is a attr pytorch uses
197- model_chunk .__class__ .__getattr__ = getattr_proxy # type: ignore
198-
199- # param_sync_func is set in nemo.lightning.pytorch.optim.megatron
200- no_sync_func , grad_sync_func = extract_ddp_funcs (ddp_config , _pipeline )
201- for module in _pipeline :
202- module .config .no_sync_func = no_sync_func
203- module .config .grad_sync_func = grad_sync_func
204-
205- for i , model_module in enumerate (_pipeline ):
206- if not cpu :
207- model_module .cuda (torch .cuda .current_device ())
208-
209- for param in model_module .parameters ():
210- set_defaults_if_not_set_tensor_model_parallel_attributes (param )
211-
212- if hasattr (model_module , "configure_model" ):
213- if not hasattr (model_module , "set_input_tensor" ):
214- if hasattr (model_module .module , "set_input_tensor" ):
215- model_module .set_input_tensor = model_module .module .set_input_tensor
216- else :
217- # TODO: What to do here?
218- pass
219-
220- # Print number of parameters.
221- if parallel_state .model_parallel_is_initialized () and parallel_state .get_data_parallel_rank () == 0 :
222- from nemo .utils import logging
223-
224- msg = (
225- f" > number of parameters on (tensor, pipeline) model parallel rank "
226- f"({ parallel_state .get_tensor_model_parallel_rank ()} , { parallel_state .get_pipeline_model_parallel_rank ()} ): "
227- f"{ _calc_number_of_params (_pipeline )} "
228- )
229- logging .info (msg )
230-
231175 super ().__init__ (_pipeline )
232176 self .precision_plugin = precision_plugin
177+ self ._cpu = cpu
233178 self .callbacks = callbacks or CallbackConnector ()
234179 self .data_step = data_step or default_data_step
235180 self .forward_step = forward_step or default_forward_step
236181 self .loss_reduction : MegatronLossReduction = loss_reduction
237182 self .ddp_config = ddp_config
183+ self .convert_module_fn = convert_module_fn
238184
239185 def forward (
240186 self ,
@@ -497,6 +443,82 @@ def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterat
497443
498444 raise ValueError ("Cannot infer `num_microbatches` from data, please specify it manually" )
499445
def init_model_parallel(self):
    """Finalize model-parallel setup for every pipeline chunk.

    Moves each chunk to the current CUDA device (unless constructed with
    ``cpu=True``), stamps tensor-model-parallel defaults on every parameter,
    wires ``set_input_tensor`` through to the wrapped module where needed,
    logs parameter counts on data-parallel rank 0, applies
    ``convert_module_fn`` if one was given, and finally sets up DDP.
    """
    from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
    from megatron.core import parallel_state

    for chunk in self:
        if not self._cpu:
            chunk.cuda(torch.cuda.current_device())

        for param in chunk.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

        # Megatron drives pipeline stages through `set_input_tensor`; if the
        # wrapper lacks it, forward the inner module's implementation.
        needs_input_tensor = hasattr(chunk, "configure_model") and not hasattr(chunk, "set_input_tensor")
        if needs_input_tensor and hasattr(chunk.module, "set_input_tensor"):
            chunk.set_input_tensor = chunk.module.set_input_tensor
        elif needs_input_tensor:
            # TODO: What to do here?
            pass

    # Print number of parameters.
    if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
        from nemo.utils import logging

        total = _calc_number_of_params(list(self))
        trainable = _calc_number_of_trainable_params(list(self))

        tp_rank = parallel_state.get_tensor_model_parallel_rank()
        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
        logging.info(
            f" > number of parameters on (tensor, pipeline) model parallel rank ({tp_rank}, {pp_rank}): {total}"
        )

        if total != trainable:
            logging.info(
                f" > number of trainable parameters: {trainable} ({trainable / total:.2%} of total)"
            )

    if self.convert_module_fn:
        self.apply_convert_module_fn()

    self.init_ddp()
488+
def apply_convert_module_fn(self):
    """Replace every held module, in place, with ``convert_module_fn(module)``."""
    convert = self.convert_module_fn
    for index in range(len(self)):
        self[index] = convert(self[index])
492+
def init_ddp(self):
    """Wrap each pipeline chunk's inner module in Megatron's DDP.

    No-op unless ``self.ddp_config`` is a ``DistributedDataParallelConfig``.
    Also mirrors DDP's ``buffers`` onto the chunk (PyTorch reads that attr
    directly) and installs the attribute proxy plus grad-sync hooks.
    """
    if not isinstance(self.ddp_config, DistributedDataParallelConfig):
        return

    from megatron.core import parallel_state

    for chunk_index, model_chunk in enumerate(self):
        inner = model_chunk.module
        wrapped = DDP(
            inner.config,
            self.ddp_config,
            inner,
            data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
            expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
            # Turn off bucketing for model_chunk 2 onwards, since communication for these
            # model chunks is overlapped with compute anyway.
            disable_bucketing=(chunk_index > 0),
        )
        model_chunk.module = wrapped
        model_chunk.buffers = wrapped.buffers  # We need to do this explicitly since this is a attr pytorch uses
        model_chunk.__class__.__getattr__ = getattr_proxy  # type: ignore

    # param_sync_func is set in nemo.lightning.pytorch.optim.megatron
    no_sync_func, grad_sync_func = extract_ddp_funcs(self.ddp_config, self)
    for module in self:
        module.config.no_sync_func = no_sync_func
        module.config.grad_sync_func = grad_sync_func
521+
500522 def _build_context (self , context : Dict [str , Any ]) -> Dict [str , Any ]:
501523 if "self" in context :
502524 del context ["self" ]
@@ -587,18 +609,21 @@ def forward_backward_func(self) -> "MegatronStepProtocol":
587609
@override
def __getattr__(self, item: Any) -> Any:
    """Resolve *item* via ``nn.ModuleList`` first, then the first held module.

    Raises ``AttributeError`` when neither lookup succeeds, with a message
    noting the empty-container case explicitly.
    """
    try:
        # nn.Module.__getattr__ is the last-resort hook; let the parent class
        # resolve registered parameters / buffers / submodules first.
        return super().__getattr__(item)
    except AttributeError:
        cls_name = self.__class__.__name__
        # Nothing to delegate to when the list holds no modules.
        if len(self) == 0:
            raise AttributeError(
                f"'{cls_name}' object has no attribute '{item}' and contains no modules"
            )

        # Delegate to the first module in the list.
        first_module = self._modules[self._get_abs_string_index(0)]
        try:
            return getattr(first_module, item)
        except AttributeError:
            raise AttributeError(f"'{cls_name}' object has no attribute '{item}'")
602627
603628
604629class _ModuleStepFunction :
@@ -937,6 +962,12 @@ def _calc_number_of_params(model: List[nn.Module]) -> int:
937962 return sum ([sum ([p .nelement () for p in model_module .parameters ()]) for model_module in model ])
938963
939964
965+ def _calc_number_of_trainable_params (model : List [nn .Module ]) -> int :
966+ assert isinstance (model , list )
967+
968+ return sum ([sum ([p .numel () for p in model_module .parameters () if p .requires_grad ]) for model_module in model ])
969+
970+
940971def is_list_of_iterators (var ) -> bool :
941972 if not isinstance (var , list ):
942973 return False
0 commit comments