Skip to content

Commit 936cd37

Browse files
github-actions[bot], marcromeyn, cuichenx, ashors1
authored and committed
[NeMo-UX] Fix when optimizers are setup for PEFT (#9619) (#9647)
* Fix when optimizers are setup for PEFT
* Apply isort and black reformatting
* Init DDP inside PEFT
* Apply isort and black reformatting
* Some fixes, loss seems to become nan with peft for some reason
* Apply isort and black reformatting
* Loss goes down on fp32
* Apply isort and black reformatting
* Simplifying FNMixin
* Apply isort and black reformatting
* Fix bug with new checkpoint-io
* Apply isort and black reformatting
* Fix failing test: test_peft_on_train_epoch_start_with_adapter
* Apply isort and black reformatting

---------

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Marc Romeyn <mromeijn@nvidia.com>
Co-authored-by: marcromeyn <marcromeyn@users.noreply.github.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: ashors1 <ashors@nvidia.com>
1 parent 8ef088a commit 936cd37

File tree

11 files changed

+177
-100
lines changed

11 files changed

+177
-100
lines changed

nemo/collections/llm/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def _setup(
279279
model_transform: Optional[Union[PEFT, ModelTransform, Callable]],
280280
) -> Any: # Return type is Any because app_state's type is not specified
281281
_log = log or NeMoLogger()
282-
if resume and resume.adapter_path and _log.ckpt:
282+
if resume and isinstance(model_transform, PEFT) and _log.ckpt:
283283
logging.info("Disabling try_restore_best_ckpt restoration for adapters")
284284
_log.ckpt.try_restore_best_ckpt = False
285285

nemo/collections/llm/fn/mixin.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing_extensions import Self
33

44
from nemo.collections.llm.fn import base as fn
5+
from nemo.utils import logging
56

67

78
class FNMixin:
@@ -114,8 +115,12 @@ def freeze(self) -> None:
114115
"""
115116
assert isinstance(self, nn.Module), "self is not a nn.Module"
116117

117-
for param in self.parameters():
118-
param.requires_grad = False
118+
params = list(self.parameters())
119+
if not params:
120+
logging.info(f"No parameters found in module {self.__class__.__name__}")
121+
else:
122+
for param in params:
123+
param.requires_grad = False
119124

120125
def unfreeze(self) -> None:
121126
"""
@@ -124,5 +129,9 @@ def unfreeze(self) -> None:
124129
"""
125130
assert isinstance(self, nn.Module), "self is not a nn.Module"
126131

127-
for param in self.parameters():
128-
param.requires_grad = True
132+
params = list(self.parameters())
133+
if not params:
134+
logging.info(f"No parameters found in module {self.__class__.__name__}")
135+
else:
136+
for param in params:
137+
param.requires_grad = True

nemo/lightning/_strategy_lib.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,4 +516,7 @@ def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], stri
516516
elif count > n_nesting:
517517
to_remove = "module." * (count - n_nesting)
518518
_state_dict[key[len(to_remove) :]] = value
519+
else:
520+
_state_dict[key] = value
521+
519522
module.load_state_dict(_state_dict, strict=strict)

nemo/lightning/io/connector.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,8 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer) -> None:
160160
output_path (Path): The path where the model checkpoint will be saved.
161161
trainer (pl.Trainer): The trainer with the strategy to save the model.
162162
"""
163-
_setup_kwargs = {}
164-
setup_signature = inspect.signature(trainer.strategy.setup)
165-
if 'setup_optimizers' in setup_signature.parameters:
166-
_setup_kwargs["setup_optimizers"] = False
167-
168-
trainer.strategy.setup(trainer, **_setup_kwargs)
163+
trainer.strategy._setup_optimizers = False
164+
trainer.strategy.setup(trainer)
169165
trainer.save_checkpoint(output_path)
170166

171167
def nemo_load(

nemo/lightning/megatron_parallel.py

Lines changed: 95 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
Iterable,
1313
Iterator,
1414
List,
15-
Mapping,
1615
Optional,
1716
Protocol,
1817
Sequence,
@@ -151,7 +150,6 @@ def __init__(
151150
cpu: bool = False,
152151
convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None,
153152
) -> None:
154-
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
155153
from megatron.core import parallel_state
156154

157155
_pipeline: List[nn.Module]
@@ -174,67 +172,15 @@ def __init__(
174172
_model.configure_model()
175173
_pipeline.append(_model)
176174

177-
if convert_module_fn:
178-
for i in range(len(_pipeline)):
179-
_pipeline[i] = convert_module_fn(_pipeline[i])
180-
181-
if isinstance(ddp_config, DistributedDataParallelConfig):
182-
for model_chunk_idx, model_chunk in enumerate(_pipeline):
183-
module = model_chunk.module
184-
185-
ddp = DDP(
186-
module.config,
187-
ddp_config,
188-
module,
189-
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
190-
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
191-
# Turn off bucketing for model_chunk 2 onwards, since communication for these
192-
# model chunks is overlapped with compute anyway.
193-
disable_bucketing=(model_chunk_idx > 0),
194-
)
195-
model_chunk.module = ddp
196-
model_chunk.buffers = ddp.buffers # We need to do this explicitly since this is a attr pytorch uses
197-
model_chunk.__class__.__getattr__ = getattr_proxy # type: ignore
198-
199-
# param_sync_func is set in nemo.lightning.pytorch.optim.megatron
200-
no_sync_func, grad_sync_func = extract_ddp_funcs(ddp_config, _pipeline)
201-
for module in _pipeline:
202-
module.config.no_sync_func = no_sync_func
203-
module.config.grad_sync_func = grad_sync_func
204-
205-
for i, model_module in enumerate(_pipeline):
206-
if not cpu:
207-
model_module.cuda(torch.cuda.current_device())
208-
209-
for param in model_module.parameters():
210-
set_defaults_if_not_set_tensor_model_parallel_attributes(param)
211-
212-
if hasattr(model_module, "configure_model"):
213-
if not hasattr(model_module, "set_input_tensor"):
214-
if hasattr(model_module.module, "set_input_tensor"):
215-
model_module.set_input_tensor = model_module.module.set_input_tensor
216-
else:
217-
# TODO: What to do here?
218-
pass
219-
220-
# Print number of parameters.
221-
if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
222-
from nemo.utils import logging
223-
224-
msg = (
225-
f" > number of parameters on (tensor, pipeline) model parallel rank "
226-
f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): "
227-
f"{_calc_number_of_params(_pipeline)}"
228-
)
229-
logging.info(msg)
230-
231175
super().__init__(_pipeline)
232176
self.precision_plugin = precision_plugin
177+
self._cpu = cpu
233178
self.callbacks = callbacks or CallbackConnector()
234179
self.data_step = data_step or default_data_step
235180
self.forward_step = forward_step or default_forward_step
236181
self.loss_reduction: MegatronLossReduction = loss_reduction
237182
self.ddp_config = ddp_config
183+
self.convert_module_fn = convert_module_fn
238184

239185
def forward(
240186
self,
@@ -497,6 +443,82 @@ def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterat
497443

498444
raise ValueError("Cannot infer `num_microbatches` from data, please specify it manually")
499445

446+
def init_model_parallel(self):
447+
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
448+
from megatron.core import parallel_state
449+
450+
for model_module in self:
451+
if not self._cpu:
452+
model_module.cuda(torch.cuda.current_device())
453+
454+
for param in model_module.parameters():
455+
set_defaults_if_not_set_tensor_model_parallel_attributes(param)
456+
457+
if hasattr(model_module, "configure_model"):
458+
if not hasattr(model_module, "set_input_tensor"):
459+
if hasattr(model_module.module, "set_input_tensor"):
460+
model_module.set_input_tensor = model_module.module.set_input_tensor
461+
else:
462+
# TODO: What to do here?
463+
pass
464+
465+
# Print number of parameters.
466+
if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
467+
from nemo.utils import logging
468+
469+
num_params = _calc_number_of_params(list(self))
470+
num_trainable_params = _calc_number_of_trainable_params(list(self))
471+
472+
msg = (
473+
f" > number of parameters on (tensor, pipeline) model parallel rank "
474+
f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): "
475+
f"{num_params}"
476+
)
477+
logging.info(msg)
478+
479+
if num_params != num_trainable_params:
480+
logging.info(
481+
f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)"
482+
)
483+
484+
if self.convert_module_fn:
485+
self.apply_convert_module_fn()
486+
487+
self.init_ddp()
488+
489+
def apply_convert_module_fn(self):
490+
for i in range(len(self)):
491+
self[i] = self.convert_module_fn(self[i])
492+
493+
def init_ddp(self):
494+
if not isinstance(self.ddp_config, DistributedDataParallelConfig):
495+
return
496+
497+
from megatron.core import parallel_state
498+
499+
for model_chunk_idx, model_chunk in enumerate(self):
500+
module = model_chunk.module
501+
502+
ddp = DDP(
503+
module.config,
504+
self.ddp_config,
505+
module,
506+
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
507+
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
508+
# Turn off bucketing for model_chunk 2 onwards, since communication for these
509+
# model chunks is overlapped with compute anyway.
510+
disable_bucketing=(model_chunk_idx > 0),
511+
)
512+
model_chunk.module = ddp
513+
model_chunk.buffers = ddp.buffers # We need to do this explicitly since this is a attr pytorch uses
514+
model_chunk.__class__.__getattr__ = getattr_proxy # type: ignore
515+
516+
# param_sync_func is set in nemo.lightning.pytorch.optim.megatron
517+
no_sync_func, grad_sync_func = extract_ddp_funcs(self.ddp_config, self)
518+
for module in self:
519+
module.config.no_sync_func = no_sync_func
520+
module.config.grad_sync_func = grad_sync_func
521+
500522
def _build_context(self, context: Dict[str, Any]) -> Dict[str, Any]:
501523
if "self" in context:
502524
del context["self"]
@@ -587,18 +609,21 @@ def forward_backward_func(self) -> "MegatronStepProtocol":
587609

588610
@override
589611
def __getattr__(self, item: Any) -> Any:
590-
if len(self) == 0:
591-
return super().__getattr__(item)
592-
593612
try:
594-
# __getattr__ gets called as a last resort if the attribute does not exist
595-
# call nn.Module's implementation first
613+
# First, try to get the attribute from the superclass (nn.ModuleList)
596614
return super().__getattr__(item)
597615
except AttributeError:
598-
# If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
599-
attr = getattr(self._modules[self._get_abs_string_index(0)], item)
616+
# If not found in superclass, check if we have any modules
617+
if len(self) == 0:
618+
raise AttributeError(
619+
f"'{self.__class__.__name__}' object has no attribute '{item}' and contains no modules"
620+
)
600621

601-
return attr
622+
# Try to get it from the first module
623+
try:
624+
return getattr(self._modules[self._get_abs_string_index(0)], item)
625+
except AttributeError:
626+
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
602627

603628

604629
class _ModuleStepFunction:
@@ -937,6 +962,12 @@ def _calc_number_of_params(model: List[nn.Module]) -> int:
937962
return sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
938963

939964

965+
def _calc_number_of_trainable_params(model: List[nn.Module]) -> int:
966+
assert isinstance(model, list)
967+
968+
return sum([sum([p.numel() for p in model_module.parameters() if p.requires_grad]) for model_module in model])
969+
970+
940971
def is_list_of_iterators(var) -> bool:
941972
if not isinstance(var, list):
942973
return False

nemo/lightning/pytorch/callbacks/model_transform.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
6565

6666
def _maybe_apply_transform(self, trainer):
6767
if self._needs_to_call:
68-
self.model_transform(trainer.model)
68+
self.apply_transform(trainer)
69+
70+
def apply_transform(self, trainer):
71+
self.model_transform(trainer.model)
6972

7073
@property
7174
def _needs_to_call(self) -> bool:

nemo/lightning/pytorch/callbacks/peft.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,19 +84,27 @@ def __call__(self, model: nn.Module) -> nn.Module:
8484
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
8585
super().setup(trainer, pl_module, stage=stage)
8686

87+
trainer.strategy.trainer = trainer
8788
self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io)
8889
trainer.strategy._checkpoint_io = self.wrapped_io
90+
trainer.strategy._init_model_parallel = False
91+
trainer.strategy._setup_optimizers = False
8992

90-
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
91-
needs_to_call = self._needs_to_call
92-
self._maybe_apply_transform(trainer)
93+
def apply_transform(self, trainer):
94+
super().apply_transform(trainer)
9395

94-
# Check if we need to load the adapters
95-
if needs_to_call and self.wrapped_io.adapter_ckpt_path is not None:
96+
if self.wrapped_io.adapter_ckpt_path is not None:
9697
logging.info(f"Loading adapters from {self.wrapped_io.adapter_ckpt_path}")
9798
adapter_state = self.wrapped_io.load_checkpoint(self.wrapped_io.adapter_ckpt_path)
9899
trainer.strategy.load_model_state_dict(adapter_state, strict=False)
99100

101+
if hasattr(trainer.strategy, "init_model_parallel"):
102+
logging.info("Initializing model parallel")
103+
trainer.strategy.init_model_parallel()
104+
105+
logging.info("Setting up optimizers")
106+
trainer.strategy.setup_optimizers(trainer)
107+
100108
def on_load_checkpoint(
101109
self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
102110
) -> None:

nemo/lightning/pytorch/optim/lr_scheduler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,6 @@ def scheduler(self, model, optimizer):
445445

446446
return {
447447
"optimizer": optimizer,
448-
"scheduler": lr_scheduler,
449448
"lr_scheduler": {
450449
# REQUIRED: The scheduler instance
451450
"scheduler": lr_scheduler,

nemo/lightning/pytorch/plugins/mixed_precision.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,17 @@ def convert_module(self, module: Module) -> Module:
6161
This is optional and depends on the precision limitations during optimization.
6262
6363
"""
64-
from megatron.core.distributed import DistributedDataParallel
6564
from megatron.core.transformer.module import Float16Module
6665
from megatron.core.utils import get_model_config
6766

6867
if self.precision in ["16-mixed", "bf16-mixed"]:
6968
config = get_model_config(module.module)
7069
config.fp16 = self.precision == "16-mixed"
7170
config.bf16 = self.precision == "bf16-mixed"
72-
if not isinstance(module.module, Float16Module):
71+
if isinstance(module.module, Float16Module):
72+
new_float16_module = Float16Module(config, module.module.module)
73+
module.module = new_float16_module
74+
else:
7375
module.module = Float16Module(config, module.module)
7476

7577
return module

0 commit comments

Comments (0)