
Commit d1d1f7c

maanug-nv and ananthsub authored
LLM Forward Step (#12673)
* pretrain loss func
* get batch and forward
* add rerun functionality to loss
* formatting
* injection of state
* remove globalstate singleton functionality
* update example
* missing copyright
* fix for latest mcore
* syntax
* move assertion
* refactor for eval
* move to avoid circular import
* fix
* unused
* cache num fw args in train and eval
* docstring fix
* remove duplicate

---------

Signed-off-by: Maanu Grover <[email protected]>
Signed-off-by: Maanu Grover <[email protected]>
Co-authored-by: Ananth Subramaniam <[email protected]>
1 parent 752edac · commit d1d1f7c

File tree

11 files changed: +345 -88 lines changed


nemo/tron/checkpointing.py

Lines changed: 1 addition & 1 deletion
@@ -40,9 +40,9 @@
     FullyParallelLoadStrategyWrapper,
     FullyParallelSaveStrategyWrapper,
 )
+from megatron.core.fp8_utils import is_float8tensor
 from megatron.core.num_microbatches_calculator import update_num_microbatches
 from megatron.core.rerun_state_machine import get_rerun_state_machine
-from megatron.core.utils import is_float8tensor

 from nemo.tron import fault_tolerance
 from nemo.tron.config import ConfigContainer
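
The only change here tracks the upstream move of is_float8tensor from megatron.core.utils to megatron.core.fp8_utils (the "fix for latest mcore" item in the commit message). If a codebase had to straddle Megatron-Core versions from before and after that move, a small import fallback is one way to bridge it; this is an illustrative sketch, not part of the commit:

# Hypothetical compatibility shim (not in this commit): prefer the new
# fp8_utils location, fall back to the old one on older Megatron-Core.
try:
    from megatron.core.fp8_utils import is_float8tensor
except ImportError:
    from megatron.core.utils import is_float8tensor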

nemo/tron/eval.py

Lines changed: 7 additions & 3 deletions
@@ -25,6 +25,7 @@
 from nemo.tron import fault_tolerance
 from nemo.tron.state import GlobalState
 from nemo.tron.utils.common_utils import is_last_rank, print_rank_0, print_rank_last
+from nemo.tron.utils.train_utils import check_forward_step_func_num_args, maybe_inject_state


 def evaluate(
@@ -38,8 +39,10 @@ def evaluate(
     non_loss_data_func=None,
 ):
     """Evaluation."""
-    timers = state.timers
+    # Check num args to forward_step_func
+    num_fw_args = check_forward_step_func_num_args(forward_step_func)

+    timers = state.timers
     timers("evaluate", log_level=0).start(barrier=True)

     # Turn on evaluation mode which disables dropout.
@@ -66,12 +69,13 @@
         if verbose:
             print_rank_0(f"Evaluating iter {iteration}/{state.cfg.train_config.eval_iters}")

+        wrapped_forward_step = maybe_inject_state(forward_step_func, state, num_fw_args=num_fw_args)
         forward_backward_func = get_forward_backward_func()
         # Don't care about timing during evaluation
         config.timers = None
         fault_tolerance.on_eval_step_start(state)
         loss_dicts = forward_backward_func(
-            forward_step_func=forward_step_func,
+            forward_step_func=wrapped_forward_step,
             data_iterator=data_iterator,
             model=model,
             num_microbatches=eval_num_microbatches,
@@ -119,7 +123,7 @@
             collected_non_loss_data = non_loss_data_func(model)
         elif process_non_loss_data_func is not None and is_last_rank():
             collected_non_loss_data = forward_backward_func(
-                forward_step_func=forward_step_func,
+                forward_step_func=wrapped_forward_step,
                 data_iterator=data_iterator,
                 model=model,
                 num_microbatches=get_num_microbatches(),
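
The evaluation path now mirrors training: it inspects the user-supplied forward_step_func once up front and, if that function expects an extra state argument, wraps it so the GlobalState is bound before Megatron's forward_backward_func calls it with (data_iterator, model). The helpers check_forward_step_func_num_args and maybe_inject_state live in nemo.tron.utils.train_utils and are not shown in this diff; the sketch below is an assumption of how such a wrapper could behave, not the actual implementation:

import inspect
from functools import partial


def check_forward_step_func_num_args(forward_step_func):
    # Assumed behavior: count the parameters of the user-supplied forward step.
    return len(inspect.signature(forward_step_func).parameters)


def maybe_inject_state(forward_step_func, state, num_fw_args=None):
    # Assumed behavior: a three-argument forward step gets GlobalState bound as its
    # first argument; a two-argument one is returned unchanged, so the pipeline
    # schedule can always call it as f(data_iterator, model).
    if num_fw_args is None:
        num_fw_args = check_forward_step_func_num_args(forward_step_func)
    if num_fw_args == 3:
        return partial(forward_step_func, state)
    return forward_step_func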

nemo/tron/examples/lingua-1b_dclm.py

Lines changed: 1 addition & 71 deletions
@@ -13,16 +13,13 @@
 # limitations under the License.

 import math
-from functools import partial

 import torch
 import torch.distributed
-from megatron.core import mpu
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.optimizer import OptimizerConfig

 from nemo.collections import llm
-from nemo.collections.llm.gpt.model.base import gpt_data_step
 from nemo.tron.api import megatron_pretrain
 from nemo.tron.config import (
     CheckpointConfig,
@@ -35,74 +32,7 @@
     TrainingConfig,
 )
 from nemo.tron.data.dataset import get_blend_and_blend_per_split
-from nemo.tron.state import GlobalState
-
-# define spiky loss as a variation of 20% or more
-SPIKY_LOSS_PERC = 0.2
-
-
-def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
-    """Loss function.
-
-    Args:
-        loss_mask (torch.Tensor): Used to mask out some portions of the loss
-        output_tensor (torch.Tensor): The tensor with the losses
-
-    Returns:
-        the loss scalar for this micro-batch
-        the number of non-padded tokens in this microbatch
-        a dict containing reporting metrics on the loss and number of tokens across
-            the data parallel ranks
-    """
-    state = GlobalState()
-    losses = output_tensor.float()
-    loss_mask = loss_mask.view(-1).float()
-    total_tokens = loss_mask.sum()
-    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
-
-    if state.cfg.model_config.context_parallel_size > 1:
-        torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
-
-    # Reduce loss for logging.
-    reporting_loss = loss.clone().detach()
-    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
-
-    local_num_tokens = loss[1].clone().detach().to(torch.int)
-    return (
-        loss[0] * state.cfg.model_config.context_parallel_size,
-        local_num_tokens,
-        {"lm loss": (reporting_loss[0], reporting_loss[1])},
-    )
-
-
-def forward_step(data_iterator, model):
-    """Forward training step.
-
-    Args:
-        data_iterator : Input data iterator
-        model (GPTModel): The GPT Model
-    """
-    timers = GlobalState().timers
-
-    # Get the batch.
-    timers("batch-generator", log_level=2).start()
-    batch = gpt_data_step(data_iterator)
-    if "attention_mask" not in batch:
-        batch["attention_mask"] = None
-
-    tokens, labels, loss_mask, attention_mask, position_ids = (
-        batch["tokens"],
-        batch["labels"],
-        batch["loss_mask"],
-        batch["attention_mask"],
-        batch["position_ids"],
-    )
-    timers("batch-generator").stop()
-
-    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
-
-    return output_tensor, partial(loss_func, loss_mask)
-
+from nemo.tron.llm.gpt import forward_step

 if __name__ == "__main__":
     global_batch_size = 256
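
With the local loss and forward-step code gone, the example only builds configuration and hands the shared forward_step to the pretraining entry point; the spiky-loss constant and the loss itself now live with the library loss (presumably nemo.tron.losses, not shown in this section). megatron_pretrain's exact signature is not shown in this diff, so the wiring below is a hedged sketch and the keyword name forward_step_func is an assumption:

from nemo.tron.api import megatron_pretrain
from nemo.tron.llm.gpt import forward_step

# cfg = ConfigContainer(...)  # assembled from the Training/Optimizer/Checkpoint configs above
# megatron_pretrain(cfg, forward_step_func=forward_step)  # hypothetical call shape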

nemo/tron/llm/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

nemo/tron/llm/gpt.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Iterable
+
+from megatron.core import parallel_state
+from megatron.core.models.gpt import GPTModel
+from megatron.core.utils import get_batch_on_this_cp_rank
+
+from nemo.tron.config import ConfigContainer
+from nemo.tron.llm.utils import get_batch_on_this_tp_rank
+from nemo.tron.losses import masked_next_token_loss
+from nemo.tron.state import GlobalState
+
+
+def get_batch(data_iterator, cfg: ConfigContainer):
+    """Generate a batch."""
+
+    if (not parallel_state.is_pipeline_first_stage()) and (not parallel_state.is_pipeline_last_stage()):
+        return None, None, None, None, None
+
+    # get batches based on the TP rank you are on
+    batch = get_batch_on_this_tp_rank(data_iterator, cfg)
+
+    # slice batch along sequence dimension for context parallelism
+    batch = get_batch_on_this_cp_rank(batch)
+
+    return batch.values()
+
+
+def forward_step(state: GlobalState, data_iterator: Iterable, model: GPTModel):
+    """Forward training step.
+
+    Args:
+        state (GlobalState): Global state for the run
+        data_iterator : Input data iterator
+        model (GPTModel): The GPT Model
+    """
+
+    timers = state.timers
+    straggler_timer = state.straggler_timer
+
+    timers("batch-generator", log_level=2).start()
+    with straggler_timer(bdata=True):
+        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, state.cfg)
+    timers("batch-generator").stop()
+
+    with straggler_timer:
+        output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
+
+    return output_tensor, partial(masked_next_token_loss, loss_mask)
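
Compared with the per-example version it replaces, this shared forward_step takes the run's GlobalState as an explicit first argument instead of reaching for a singleton, and delegates the loss to masked_next_token_loss from nemo.tron.losses (not shown in this section). Below is a hedged sketch of how the extra argument is expected to be satisfied before Megatron's schedule calls the function with (data_iterator, model); the names state, data_iterator, and model stand in for objects built elsewhere in the run:

from functools import partial

from nemo.tron.llm.gpt import forward_step


def bind_state(state):
    # Same idea as maybe_inject_state(forward_step, state): pre-bind GlobalState so the
    # pipeline schedule sees an ordinary two-argument forward step.
    return partial(forward_step, state)

# two_arg_step = bind_state(state)
# output_tensor, loss_fn = two_arg_step(data_iterator, model)
# loss_fn is partial(masked_next_token_loss, loss_mask), applied later to output_tensor.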

nemo/tron/llm/utils.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Iterable
+import torch
+from megatron.core import parallel_state
+from nemo.tron.config import ConfigContainer
+
+
+def get_batch_on_this_tp_rank(data_iterator: Iterable, cfg: ConfigContainer) -> Dict[str, torch.Tensor]:
+    def _broadcast(item):
+        if item is not None:
+            torch.distributed.broadcast(
+                item,
+                parallel_state.get_tensor_model_parallel_src_rank(),
+                group=parallel_state.get_tensor_model_parallel_group(),
+            )
+
+    if parallel_state.get_tensor_model_parallel_rank() == 0:
+        if data_iterator is not None:
+            data = next(data_iterator)
+        else:
+            data = None
+
+        batch = {
+            "tokens": data["tokens"].cuda(non_blocking=True),
+            "labels": data["labels"].cuda(non_blocking=True),
+            "loss_mask": data["loss_mask"].cuda(non_blocking=True),
+            "attention_mask": None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
+            "position_ids": data["position_ids"].cuda(non_blocking=True),
+        }
+
+        if cfg.model_config.pipeline_model_parallel_size == 1:
+            _broadcast(batch["tokens"])
+            _broadcast(batch["labels"])
+            _broadcast(batch["loss_mask"])
+            _broadcast(batch["attention_mask"])
+            _broadcast(batch["position_ids"])
+
+        elif parallel_state.is_pipeline_first_stage():
+            _broadcast(batch["tokens"])
+            _broadcast(batch["attention_mask"])
+            _broadcast(batch["position_ids"])
+
+        elif parallel_state.is_pipeline_last_stage():
+            _broadcast(batch["labels"])
+            _broadcast(batch["loss_mask"])
+            _broadcast(batch["attention_mask"])
+
+    else:
+        mbs = cfg.train_config.micro_batch_size
+        seq_length = cfg.model_config.seq_length
+        tokens = torch.empty(
+            (mbs, seq_length),
+            dtype=torch.int64,
+            device=torch.cuda.current_device(),
+        )
+        labels = torch.empty(
+            (mbs, seq_length),
+            dtype=torch.int64,
+            device=torch.cuda.current_device(),
+        )
+        loss_mask = torch.empty(
+            (mbs, seq_length),
+            dtype=torch.float32,
+            device=torch.cuda.current_device(),
+        )
+        if cfg.dataset_config.create_attention_mask:
+            attention_mask = torch.empty(
+                (
+                    mbs,
+                    1,
+                    seq_length,
+                    seq_length,
+                ),
+                dtype=torch.bool,
+                device=torch.cuda.current_device(),
+            )
+        else:
+            attention_mask = None
+        position_ids = torch.empty(
+            (mbs, seq_length),
+            dtype=torch.int64,
+            device=torch.cuda.current_device(),
+        )
+
+        if cfg.model_config.pipeline_model_parallel_size == 1:
+            _broadcast(tokens)
+            _broadcast(labels)
+            _broadcast(loss_mask)
+            _broadcast(attention_mask)
+            _broadcast(position_ids)
+
+        elif parallel_state.is_pipeline_first_stage():
+            labels = None
+            loss_mask = None
+
+            _broadcast(tokens)
+            _broadcast(attention_mask)
+            _broadcast(position_ids)
+
+        elif parallel_state.is_pipeline_last_stage():
+            tokens = None
+            position_ids = None
+
+            _broadcast(labels)
+            _broadcast(loss_mask)
+            _broadcast(attention_mask)
+
+        batch = {
+            "tokens": tokens,
+            "labels": labels,
+            "loss_mask": loss_mask,
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+        }
+
+    return batch
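
get_batch_on_this_tp_rank follows the usual Megatron recipe: only tensor-parallel rank 0 pulls a sample from the data iterator, every other rank in the TP group pre-allocates tensors of the agreed shape and dtype, and torch.distributed.broadcast from the TP source rank fills them in; the pipeline-stage branches simply skip tensors a given stage never needs. The snippet below is a stripped-down, single-group illustration of that allocate-then-broadcast pattern using plain torch.distributed; it assumes an initialized process group (e.g. via torchrun) with one CUDA device per rank, and is not code from this commit:

import torch
import torch.distributed as dist


def broadcast_tokens(mbs, seq_len, src=0):
    # Rank `src` plays the role of TP rank 0 and "reads" the data (random stand-in here);
    # every other rank allocates an empty buffer with matching shape and dtype.
    if dist.get_rank() == src:
        tokens = torch.randint(0, 32000, (mbs, seq_len), device="cuda")
    else:
        tokens = torch.empty((mbs, seq_len), dtype=torch.int64, device="cuda")
    # After the collective, all ranks hold identical token tensors.
    dist.broadcast(tokens, src=src)
    return tokens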
