Commit 6915e85

mixtral config
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent: 49fdbe8

3 files changed (+112, -0)

nemo/collections/llm/__init__.py (4 additions, 0 deletions)

@@ -18,6 +18,8 @@
     MaskedTokenLossReduction,
     Mistral7BConfig,
     Mistral7BModel,
+    MixtralConfig,
+    MixtralModel,
     gpt_data_step,
     gpt_forward_step,
 )
@@ -31,6 +33,8 @@
     "MaskedTokenLossReduction",
     "Mistral7BConfig",
     "Mistral7BModel",
+    "MixtralConfig",
+    "MixtralModel",
     "PreTrainingDataModule",
     "FineTuningDataModule",
     "SquadDataModule",

nemo/collections/llm/gpt/model/__init__.py (3 additions, 0 deletions)

@@ -6,12 +6,15 @@
     gpt_forward_step,
 )
 from nemo.collections.llm.gpt.model.mistral_7b import Mistral7BConfig, Mistral7BModel
+from nemo.collections.llm.gpt.model.mixtral import MixtralConfig, MixtralModel

 __all__ = [
     "GPTConfig",
     "GPTModel",
     "Mistral7BConfig",
     "Mistral7BModel",
+    "MixtralConfig",
+    "MixtralModel",
     "MaskedTokenLossReduction",
     "gpt_data_step",
     "gpt_forward_step",
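
Taken together, the two re-exports above surface the new classes at the collection root. A minimal sketch of the resulting import surface, assuming a NeMo installation at this commit:

# After this commit the classes are importable from the collection root:
from nemo.collections.llm import MixtralConfig, MixtralModel

# ...and they are the same objects as in the defining module:
import nemo.collections.llm.gpt.model as gpt_model
assert MixtralConfig is gpt_model.MixtralConfig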
nemo/collections/llm/gpt/model/mixtral.py (new file, 105 additions, 0 deletions)
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Optional

import torch
import torch.nn.functional as F

from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.lightning import io, teardown

if TYPE_CHECKING:
    # Imported only for type annotations; this module targets Mixtral,
    # not the Mistral classes.
    from transformers import MixtralForCausalLM

    from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

@dataclass
class MixtralConfig(GPTConfig):
    """GPTConfig with Mixtral 8x7B defaults: SwiGLU MLPs, RoPE,
    grouped-query attention, and an 8-expert MoE block."""

    normalization: str = "RMSNorm"
    activation_func: Callable = F.silu
    position_embedding_type: str = "rope"
    add_bias_linear: bool = False
    gated_linear_unit: bool = True
    apply_query_key_layer_scaling: bool = False  # TODO: Should this be True?

    num_layers: int = 32
    hidden_size: int = 4096
    num_attention_heads: int = 32
    num_query_groups: int = 8
    ffn_hidden_size: int = 14336
    max_position_embeddings: int = 4096  # 32768
    seq_length: int = 4096  # 32768
    # MoE
    num_moe_experts: int = 8
    moe_router_topk: int = 1

    init_method_std: float = 0.02
    layernorm_epsilon: float = 1e-5
    # rotary
    rotary_percent: float = 0.5
    rotary_base: float = 10000

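Since MixtralConfig is a plain dataclass, individual fields can be overridden at construction. A minimal sketch (the 32768 values are the commented-out alternatives above, not the defaults):

from nemo.collections.llm import MixtralConfig

# Same 8x7B architecture, but with the full 32k context window noted
# in the field comments above.
config = MixtralConfig(max_position_embeddings=32768, seq_length=32768)
assert config.num_moe_experts == 8 and config.moe_router_topk == 1
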
class MixtralModel(GPTModel):
    def __init__(self, config: Optional[MixtralConfig] = None, optim_config=None, tokenizer=None):
        # Fall back to the tokenizer bundled with the Hugging Face checkpoint.
        _tokenizer = tokenizer or HFMixtralImporter().tokenizer

        super().__init__(config or MixtralConfig(), optim_config, _tokenizer)

@io.model_importer(MixtralModel, ext="hf")
class HFMixtralImporter(io.ModelConnector["MixtralForCausalLM", MixtralModel]):
    def init(self) -> MixtralModel:
        return MixtralModel(self.config, tokenizer=self.tokenizer)

    def apply(self, output_path: Path) -> Path:
        from transformers import MixtralForCausalLM

        source = MixtralForCausalLM.from_pretrained(str(self))
        target = self.init()
        trainer = self.nemo_setup(target)
        # convert_state maps the Hugging Face weights onto the NeMo model;
        # it is not defined in this file.
        self.convert_state(source, target)
        self.nemo_save(output_path, trainer)

        teardown(trainer, target)
        del trainer, target

        return output_path

    @property
    def tokenizer(self) -> "AutoTokenizer":
        from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

        return AutoTokenizer(str(self))

    @property
    def config(self) -> MixtralConfig:
        from transformers import MixtralConfig as HfMixtralConfig

        config = HfMixtralConfig.from_pretrained(str(self))
        return MixtralConfig(
            activation_func=F.silu,
            # network
            num_layers=config.num_hidden_layers,
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            max_position_embeddings=config.max_position_embeddings,  # TODO
            seq_length=config.max_position_embeddings,
            # RoPE
            position_embedding_type='rope',
            rotary_base=config.rope_theta,
            # Transformer config
            num_attention_heads=config.num_attention_heads,
            num_query_groups=config.num_key_value_heads,
            num_moe_experts=config.num_local_experts,
            moe_router_topk=config.num_experts_per_tok,
            # norm
            normalization="RMSNorm",
            layernorm_epsilon=config.rms_norm_eps,
            # Init
            init_method_std=config.initializer_range,
            gated_linear_unit=True,
            # Vocab
            make_vocab_size_divisible_by=128,
        )
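
Because the importer is registered with @io.model_importer(MixtralModel, ext="hf") and io.ModelConnector behaves like a Path (the code above calls from_pretrained(str(self))), a conversion run can be sketched as follows. The checkpoint paths are hypothetical, and this assumes a convert_state implementation is available on the connector:

from pathlib import Path

from nemo.collections.llm.gpt.model.mixtral import HFMixtralImporter

# Hypothetical locations: the connector is constructed from the source
# Hugging Face directory and writes a NeMo checkpoint to output_path.
importer = HFMixtralImporter("/checkpoints/Mixtral-8x7B-v0.1")
importer.apply(Path("/checkpoints/mixtral_nemo"))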
