8 changes: 5 additions & 3 deletions nemo/collections/llm/gpt/model/mistral_7b.py
@@ -2,6 +2,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Optional

+ import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from typing_extensions import Annotated
@@ -46,9 +47,7 @@ def __init__(
optim: Optional[OptimizerModule] = None,
tokenizer: Optional["TokenizerSpec"] = None,
):
- _tokenizer = tokenizer or HFMistral7BImporter("mistralai/Mistral-7B-v0.1").tokenizer

- super().__init__(config or Mistral7BConfig(), optim=optim, tokenizer=_tokenizer)
+ super().__init__(config or Mistral7BConfig(), optim=optim, tokenizer=tokenizer)


@io.model_importer(Mistral7BModel, "hf")
@@ -72,6 +71,9 @@ def apply(self, output_path: Path) -> Path:

return output_path

+ def on_import_ckpt(self, model: pl.LightningModule):
+     model.tokenizer = self.tokenizer

def convert_state(self, source, target):
mapping = {
"model.embed_tokens.weight": "embedding.word_embeddings.weight",
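
Taken together, the mistral_7b.py changes move tokenizer wiring out of the model constructor (which previously resolved a tokenizer via HFMistral7BImporter at construction time) and into the import path: the importer's new on_import_ckpt hook assigns its tokenizer onto the model after the HF checkpoint has been converted. A minimal sketch of that flow; the classes below are simplified stand-ins, not the real Mistral7BModel / HFMistral7BImporter:

```python
class StubTokenizer:
    """Stand-in for the tokenizer the HF importer builds."""


class StubModel:
    """Stand-in for Mistral7BModel: no longer builds a tokenizer eagerly."""

    def __init__(self, tokenizer=None):
        self.tokenizer = tokenizer  # may stay None until an import happens


class StubImporter:
    """Stand-in for HFMistral7BImporter with the new hook."""

    def __init__(self):
        self.tokenizer = StubTokenizer()

    def on_import_ckpt(self, model):
        # Mirrors the added hook: hand the importer's tokenizer to the model.
        model.tokenizer = self.tokenizer


model = StubModel()
StubImporter().on_import_ckpt(model)
assert isinstance(model.tokenizer, StubTokenizer)
```

The upshot is that constructing the model no longer forces a tokenizer to be resolved; one only appears once a checkpoint import actually runs.
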
122 changes: 0 additions & 122 deletions nemo/lightning/experiment.py

This file was deleted.

16 changes: 13 additions & 3 deletions nemo/lightning/io/connector.py
@@ -1,3 +1,4 @@
+ import inspect
import logging
import os
import shutil
@@ -138,7 +139,7 @@ def nemo_setup(self, model: pl.LightningModule, trainer: Optional[pl.Trainer] =
from nemo.lightning import MegatronStrategy, Trainer

_trainer = trainer or Trainer(
- devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False)
+ devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False, ddp="pytorch")
)

_trainer.strategy.connect(model)
@@ -159,7 +160,12 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer) -> None:
output_path (Path): The path where the model checkpoint will be saved.
trainer (pl.Trainer): The trainer with the strategy to save the model.
"""
- trainer.strategy.setup(trainer)
+ _setup_kwargs = {}
+ setup_signature = inspect.signature(trainer.strategy.setup)
+ if 'setup_optimizers' in setup_signature.parameters:
+     _setup_kwargs["setup_optimizers"] = False

+ trainer.strategy.setup(trainer, **_setup_kwargs)
trainer.save_checkpoint(output_path)

def nemo_load(
@@ -181,7 +187,9 @@ def nemo_load(
from nemo.lightning.io.api import load_ckpt

model = load_ckpt(path).model
- _trainer = trainer or Trainer(devices=1, accelerator="cpu" if cpu else "gpu", strategy=MegatronStrategy())
+ _trainer = trainer or Trainer(
+     devices=1, accelerator="cpu" if cpu else "gpu", strategy=MegatronStrategy(ddp="pytorch")
+ )

_trainer.strategy.connect(model)
_trainer.strategy.setup_environment()
@@ -208,3 +216,5 @@ def local_path(self, base_path: Optional[Path] = None) -> Path:
_base = Path(NEMO_MODELS_CACHE)

return _base / str(self).replace("://", "/")

+ def on_import_ckpt(self, model: pl.LightningModule): ...

Check notice (Code scanning / CodeQL): Statement has no effect.
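
The nemo_save change uses inspect.signature to pass setup_optimizers=False only when the strategy's setup() actually accepts that keyword, which keeps the connector working against strategies with the older signature. A small sketch of that pattern in isolation; old_setup, new_setup and call_setup are hypothetical names used purely for illustration:

```python
import inspect


def old_setup(trainer):
    """Stand-in for a strategy setup() without the new keyword."""
    return "old signature, no kwargs passed"


def new_setup(trainer, setup_optimizers: bool = True):
    """Stand-in for the updated setup() that accepts setup_optimizers."""
    return f"new signature, setup_optimizers={setup_optimizers}"


def call_setup(setup_fn, trainer):
    kwargs = {}
    # Only forward the flag if the callee can accept it.
    if "setup_optimizers" in inspect.signature(setup_fn).parameters:
        kwargs["setup_optimizers"] = False  # skip optimizer state for checkpoint export
    return setup_fn(trainer, **kwargs)


print(call_setup(old_setup, trainer=None))  # old signature, no kwargs passed
print(call_setup(new_setup, trainer=None))  # new signature, setup_optimizers=False
```
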
2 changes: 2 additions & 0 deletions nemo/lightning/io/mixin.py
@@ -280,6 +280,8 @@ def import_ckpt(self, path: str, overwrite: bool = False, base_path: Optional[Pa
ckpt_path: Path = connector.local_path(base_path=base_path)
ckpt_path = connector(ckpt_path, overwrite=overwrite)

+ connector.on_import_ckpt(self)

return ckpt_path

@classmethod
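
The io/mixin.py change gives the connector a post-import callback: after converting and writing the checkpoint, import_ckpt now calls connector.on_import_ckpt(self), so the connector can mutate the importing model (for example, attach its tokenizer). A simplified sketch of that call order, not the real NeMo classes:

```python
from types import SimpleNamespace


class SketchConnector:
    """Stand-in connector: conversion elided, only the new hook is shown."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, ckpt_path, overwrite=False):
        return ckpt_path  # pretend the checkpoint was converted and written here

    def on_import_ckpt(self, model):
        # Give the connector a chance to mutate the freshly imported model.
        model.tokenizer = self.tokenizer


def import_ckpt(model, connector, ckpt_path, overwrite=False):
    ckpt_path = connector(ckpt_path, overwrite=overwrite)
    connector.on_import_ckpt(model)  # the newly added hook call
    return ckpt_path


model = SimpleNamespace(tokenizer=None)
import_ckpt(model, SketchConnector("tok"), "path/to/ckpt")
assert model.tokenizer == "tok"
```
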
22 changes: 11 additions & 11 deletions nemo/lightning/pytorch/strategies.py
@@ -126,7 +126,7 @@ def connect(self, model: pl.LightningModule) -> None:
self._mcore_config = config

@override
- def setup(self, trainer: pl.Trainer) -> None:
+ def setup(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None:
assert self.accelerator is not None
self.accelerator.setup(trainer)
self.trainer = trainer
@@ -150,7 +150,7 @@ def setup(self, trainer: pl.Trainer) -> None:
self.data_sampler.connect(trainer)

self._fix_progress_bar(trainer)
- self.setup_megatron_parallel(trainer)
+ self.setup_megatron_parallel(trainer, setup_optimizers=setup_optimizers)
self.setup_precision_plugin()

if trainer.num_sanity_val_steps > 1 and self.pipeline_model_parallel_size > 1:
@@ -205,7 +205,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:

return dataloader

- def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
+ def setup_megatron_parallel(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None:
assert self.model is not None, "Model is not set"

self.megatron_parallel = MegatronParallel(
@@ -224,16 +224,16 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
self.model.configure_optimizers, megatron_parallel=self.megatron_parallel
)

- self.setup_optimizers(trainer)
+ if setup_optimizers:
+     self.setup_optimizers(trainer)

- # TODO: Throw an execption if we have a mcore optimizer and no ddp_config
+     # TODO: Throw an execption if we have a mcore optimizer and no ddp_config
- if hasattr(self.precision_plugin, "convert_optimizer"):
-     _optimizers = [*self.optimizers]
-     _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0])
-     self.optimizers = _optimizers

+     if hasattr(self.precision_plugin, "convert_optimizer"):
+         _optimizers = [*self.optimizers]
+         _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0])
+         self.optimizers = _optimizers

- _optimizers_to_device(self.optimizers, self.root_device)
+     _optimizers_to_device(self.optimizers, self.root_device)

self.model = self.megatron_parallel

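
The strategies.py change threads a setup_optimizers flag from MegatronStrategy.setup() into setup_megatron_parallel(), so that callers such as the export path in nemo_save can skip building optimizer state entirely. A minimal sketch of the control flow; SketchStrategy is a stand-in, not the real MegatronStrategy:

```python
class SketchStrategy:
    """Stand-in for MegatronStrategy, showing only the setup_optimizers plumbing."""

    def __init__(self):
        self.optimizers = []

    def setup(self, trainer, setup_optimizers: bool = True):
        # The flag is simply threaded one level down.
        self.setup_megatron_parallel(trainer, setup_optimizers=setup_optimizers)

    def setup_megatron_parallel(self, trainer, setup_optimizers: bool = True):
        if setup_optimizers:
            self.setup_optimizers(trainer)
            # ...optimizer conversion and device placement would also live in this branch

    def setup_optimizers(self, trainer):
        self.optimizers = ["fake-optimizer"]


export_strategy = SketchStrategy()
export_strategy.setup(trainer=None, setup_optimizers=False)
assert export_strategy.optimizers == []  # nothing built when only exporting weights
```

With setup_optimizers=False, no optimizer is constructed, converted, or moved to device, which matches what nemo_save requests when it only needs to write a checkpoint.
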