Skip to content

Commit 22cf880

Browse files
akoumpa and ashors1
authored and committed
Akoumparouli/nemo ux mixtral (#9446)
* use default collate if dataset does not have one Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * mixtral config Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * add convert_state Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix StateDictTransform for 2D layers, e.g. MoE Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * pass num_moe_experts to specs Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * udpate MixtralModel Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * Apply isort and black reformatting Signed-off-by: akoumpa <akoumpa@users.noreply.github.com> Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * mini docstring Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * Apply isort and black reformatting Signed-off-by: akoumpa <akoumpa@users.noreply.github.com> --------- Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> Signed-off-by: akoumpa <akoumpa@users.noreply.github.com> Co-authored-by: akoumpa <akoumpa@users.noreply.github.com> Signed-off-by: ashors1 <ashors@nvidia.com>
1 parent 00791a1 commit 22cf880

File tree

6 files changed

+202
-11
lines changed

6 files changed

+202
-11
lines changed

nemo/collections/llm/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
MaskedTokenLossReduction,
1919
Mistral7BConfig,
2020
Mistral7BModel,
21+
MixtralConfig,
22+
MixtralModel,
2123
gpt_data_step,
2224
gpt_forward_step,
2325
)
@@ -31,6 +33,8 @@
3133
"MaskedTokenLossReduction",
3234
"Mistral7BConfig",
3335
"Mistral7BModel",
36+
"MixtralConfig",
37+
"MixtralModel",
3438
"PreTrainingDataModule",
3539
"FineTuningDataModule",
3640
"SquadDataModule",

nemo/collections/llm/gpt/data/pre_training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytorch_lightning as pl
55
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
6+
from torch.utils import data
67
from torch.utils.data import DataLoader
78

89
from nemo.lightning.pytorch.plugins import MegatronDataSampler
@@ -121,7 +122,7 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
121122
num_workers=self.num_workers,
122123
pin_memory=self.pin_memory,
123124
persistent_workers=self.persistent_workers,
124-
collate_fn=dataset.collate_fn,
125+
collate_fn=getattr(dataset, 'collate_fn', data.dataloader.default_collate),
125126
**kwargs,
126127
)
127128

nemo/collections/llm/gpt/model/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@
66
gpt_forward_step,
77
)
88
from nemo.collections.llm.gpt.model.mistral_7b import Mistral7BConfig, Mistral7BModel
9+
from nemo.collections.llm.gpt.model.mixtral import MixtralConfig, MixtralModel
910

1011
__all__ = [
1112
"GPTConfig",
1213
"GPTModel",
1314
"Mistral7BConfig",
1415
"Mistral7BModel",
16+
"MixtralConfig",
17+
"MixtralModel",
1518
"MaskedTokenLossReduction",
1619
"gpt_data_step",
1720
"gpt_forward_step",

nemo/collections/llm/gpt/model/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def configure_model(self, tokenizer) -> "MCoreGPTModel":
5454

5555
return MCoreGPTModel(
5656
self,
57-
transformer_layer_spec=get_gpt_layer_spec(),
57+
transformer_layer_spec=get_gpt_layer_spec(self.num_moe_experts),
5858
vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by),
5959
max_sequence_length=self.seq_length,
6060
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from dataclasses import dataclass
2+
from pathlib import Path
3+
from typing import TYPE_CHECKING, Callable, Optional
4+
5+
import torch
6+
import torch.nn.functional as F
7+
8+
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
9+
from nemo.lightning import io, teardown
10+
from nemo.lightning.pytorch.opt import OptimizerModule
11+
12+
if TYPE_CHECKING:
13+
from transformers import MistralConfig, MistralForCausalLM
14+
15+
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
16+
17+
18+
@dataclass
class MixtralConfig(GPTConfig):
    """
    Config for the Mixtral-8x7B model.
    Official announcement: https://mistral.ai/news/mixtral-of-experts/
    """

    # Architecture choices shared with Mistral-7B.
    normalization: str = "RMSNorm"
    activation_func: Callable = F.silu
    position_embedding_type: str = "rope"
    add_bias_linear: bool = False
    gated_linear_unit: bool = True
    apply_query_key_layer_scaling: bool = False  # TODO: Should this be True?

    # Network dimensions (Mixtral-8x7B).
    num_layers: int = 32
    hidden_size: int = 4096
    num_attention_heads: int = 32
    num_query_groups: int = 8  # grouped-query attention: 8 KV groups
    ffn_hidden_size: int = 14336
    max_position_embeddings: int = 4096  # 32768
    seq_length: int = 4096  # 32768
    # MoE
    num_moe_experts: int = 8
    # Mixtral-8x7B routes each token to its top-2 experts
    # (HF MixtralConfig.num_experts_per_tok == 2); top-1 would mismatch
    # the released checkpoints.
    moe_router_topk: int = 2

    init_method_std: float = 0.02
    layernorm_epsilon: float = 1e-5
    # rotary
    # Mixtral applies RoPE to the full head dimension. 0.5 would rotate
    # only half of it and silently corrupt weights imported from HF,
    # since the HF importer does not override this field.
    rotary_percent: float = 1.0
    # NOTE(review): HF Mixtral-8x7B ships rope_theta=1e6; the importer
    # overrides rotary_base from the HF config, so this default only
    # affects from-scratch training — confirm the intended value.
    rotary_base: float = 10000
48+
49+
50+
class MixtralModel(GPTModel):
    """Lightning module wrapping the Mixtral GPT model.

    Falls back to a default ``MixtralConfig`` when no config is supplied.
    """

    def __init__(
        self,
        config: Optional[MixtralConfig] = None,
        optim: Optional[OptimizerModule] = None,
        tokenizer: Optional["TokenizerSpec"] = None,
    ):
        # NOTE(review): "TokenizerSpec" is a bare forward reference —
        # presumably nemo.collections.common.tokenizers.TokenizerSpec,
        # which is not imported here (even under TYPE_CHECKING); confirm.
        effective_config = config or MixtralConfig()
        super().__init__(effective_config, optim=optim, tokenizer=tokenizer)
58+
59+
60+
@io.model_importer(MixtralModel, ext="hf")
class HFMixtralImporter(io.ModelConnector["MixtralForCausalLM", MixtralModel]):
    """Imports a HuggingFace Mixtral checkpoint into a NeMo ``MixtralModel``."""

    def init(self) -> MixtralModel:
        """Build an empty target model from the translated config and tokenizer."""
        return MixtralModel(self.config, tokenizer=self.tokenizer)

    def apply(self, output_path: Path) -> Path:
        """Load the HF weights, convert them, and save a NeMo checkpoint.

        Returns the path the NeMo checkpoint was written to.
        """
        from transformers import MixtralForCausalLM

        hf_model = MixtralForCausalLM.from_pretrained(str(self))
        nemo_model = self.init()
        trainer = self.nemo_setup(nemo_model)
        self.convert_state(hf_model, nemo_model)
        self.nemo_save(output_path, trainer)

        # Free the temporary trainer/model before returning.
        teardown(trainer, nemo_model)
        del trainer, nemo_model

        return output_path

    def convert_state(self, source, target):
        """Copy HF parameters onto the NeMo state dict.

        Direct renames are listed in ``mapping``; QKV fusion and the
        gated-MLP w1/w3 concatenation need reshaping and are handled by
        the dedicated transforms.
        """
        mapping = {
            "model.embed_tokens.weight": "embedding.word_embeddings.weight",
            "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight",
            "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
            "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.pre_mlp_layernorm.weight",
            # MoE
            "model.layers.*.block_sparse_moe.experts.*.w2.weight": "decoder.layers.*.mlp.experts.local_experts.*.linear_fc2.weight",
            "model.layers.*.block_sparse_moe.gate.weight": "decoder.layers.*.mlp.router.weight",
            # lm-head
            "model.norm.weight": "decoder.final_layernorm.weight",
            "lm_head.weight": "output_layer.weight",
        }

        return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_moe_w1_w3])

    @property
    def tokenizer(self) -> "AutoTokenizer":
        """Tokenizer loaded from the source HF checkpoint directory."""
        from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

        return AutoTokenizer(str(self))

    @property
    def config(self) -> MixtralConfig:
        """Translate the HF ``MixtralConfig`` into a NeMo ``MixtralConfig``."""
        from transformers import MixtralConfig as HfMixtralConfig

        hf_config = HfMixtralConfig.from_pretrained(str(self))
        return MixtralConfig(
            activation_func=F.silu,
            # network
            num_layers=hf_config.num_hidden_layers,
            hidden_size=hf_config.hidden_size,
            ffn_hidden_size=hf_config.intermediate_size,
            max_position_embeddings=hf_config.max_position_embeddings,  # TODO
            seq_length=hf_config.max_position_embeddings,
            # RoPE
            position_embedding_type='rope',
            rotary_base=hf_config.rope_theta,
            # Transformer config
            num_attention_heads=hf_config.num_attention_heads,
            num_query_groups=hf_config.num_key_value_heads,
            num_moe_experts=hf_config.num_local_experts,
            moe_router_topk=hf_config.num_experts_per_tok,
            # norm
            normalization='RMSNorm',
            layernorm_epsilon=hf_config.rms_norm_eps,
            # Init
            init_method_std=hf_config.initializer_range,
            gated_linear_unit=True,
            # Vocab
            make_vocab_size_divisible_by=128,
        )
131+
132+
133+
@io.state_transform(
    source_key=(
        "model.layers.*.self_attn.q_proj.weight",
        "model.layers.*.self_attn.k_proj.weight",
        "model.layers.*.self_attn.v_proj.weight",
    ),
    target_key="decoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv(ctx: io.TransformCTX, q, k, v):
    """Fuse separate HF q/k/v projection weights into megatron's interleaved
    ``linear_qkv`` layout: for each query group, ``heads_per_group`` query
    heads are followed by that group's single key head and value head.

    Returns a 2D weight of shape
    ``[head_size * (head_num + 2 * num_query_groups), hidden_size]``.
    """
    megatron_config = ctx.target.config

    head_num = megatron_config.num_attention_heads
    num_query_groups = megatron_config.num_query_groups
    heads_per_group = head_num // num_query_groups
    hidden_size = megatron_config.hidden_size
    head_size = hidden_size // head_num

    # Split the leading (out-features) dim into (heads, head_size).
    old_tensor_shape = q.size()
    new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
    new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]

    q = q.view(*new_q_tensor_shape)
    k = k.view(*new_kv_tensor_shape)
    v = v.view(*new_kv_tensor_shape)

    # Interleave per query group: [q_heads_of_group_i ..., k_i, v_i].
    qkv_weights_l = []
    for i in range(num_query_groups):
        qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
        qkv_weights_l.append(k[i : i + 1, :, :])
        qkv_weights_l.append(v[i : i + 1, :, :])
    qkv_weights = torch.cat(qkv_weights_l)
    assert qkv_weights.ndim == 3, qkv_weights.shape
    assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
    assert qkv_weights.shape[1] == head_size, qkv_weights.shape
    assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape

    # Collapse (heads, head_size, hidden) back to the fused 2D weight.
    qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])

    return qkv_weights
173+
174+
175+
@io.state_transform(
    source_key=(
        "model.layers.*.block_sparse_moe.experts.*.w1.weight",
        "model.layers.*.block_sparse_moe.experts.*.w3.weight",
    ),
    target_key="decoder.layers.*.mlp.experts.local_experts.*.linear_fc1.weight",
)
def _import_moe_w1_w3(gate_proj, up_proj):
    """Stack an expert's HF gate (w1) and up (w3) projections into the single
    fused ``linear_fc1`` weight megatron's gated-linear-unit MLP expects,
    gate rows first.
    """
    # Use the torch-native `dim` keyword (`axis` is only a NumPy-compat alias).
    return torch.cat((gate_proj, up_proj), dim=0)

nemo/lightning/io/state.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,15 @@ def __call__(self, ctx: TransformCTX) -> TransformCTX:
217217
source_key_dict = source_key
218218
source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()}
219219
target_matches = _match_keys(list(target_dict.keys()), target_key)
220-
221-
for target_index, target_match in np.ndenumerate(target_matches):
222-
kwargs = {}
223-
for param in fn_params:
224-
if param in source_matches_dict:
225-
source_match = source_matches_dict[param][target_index[:-1]]
226-
kwargs[param] = source_dict[source_match[target_index]]
227-
228-
target_dict[target_match] = self.call_transform(ctx, **kwargs)
220+
param_names = list(filter(lambda x: x in source_matches_dict, fn_params))
221+
for layer_names_group in zip(*([source_matches_dict[v] for v in param_names] + [target_matches])):
222+
# Wrap in a list if it's a single layer (ie non-expert)
223+
if isinstance(layer_names_group[0], str):
224+
layer_names_group = [[x] for x in layer_names_group]
225+
for layer_names in zip(*layer_names_group):
226+
target_dict[layer_names[-1]] = self.call_transform(
227+
ctx, **dict(zip(param_names, [source_dict[x] for x in layer_names[:-1]]))
228+
)
229229
else:
230230
source_keys = list(source_dict.keys())
231231
target_keys = list(target_dict.keys())

0 commit comments

Comments
 (0)