Skip to content

Commit 490ade4

Browse files
cuichenx and marcromeyn authored
[NeMo-UX] Llama and Gemma (#9528)
* add llama Signed-off-by: Chen Cui <chcui@nvidia.com> * Apply isort and black reformatting Signed-off-by: cuichenx <cuichenx@users.noreply.github.com> * add llama Signed-off-by: Chen Cui <chcui@nvidia.com> * Apply isort and black reformatting Signed-off-by: cuichenx <cuichenx@users.noreply.github.com> * add llama3 Signed-off-by: Chen Cui <chcui@nvidia.com> * Apply isort and black reformatting Signed-off-by: cuichenx <cuichenx@users.noreply.github.com> * fix typo Signed-off-by: Chen Cui <chcui@nvidia.com> * enable importers with multiple models Signed-off-by: Chen Cui <chcui@nvidia.com> * Apply isort and black reformatting Signed-off-by: cuichenx <cuichenx@users.noreply.github.com> * add gemma Signed-off-by: Chen Cui <chcui@nvidia.com> * Apply isort and black reformatting Signed-off-by: cuichenx <cuichenx@users.noreply.github.com> * checks Signed-off-by: Chen Cui <chcui@nvidia.com> * Apply isort and black reformatting Signed-off-by: cuichenx <cuichenx@users.noreply.github.com> --------- Signed-off-by: Chen Cui <chcui@nvidia.com> Signed-off-by: cuichenx <cuichenx@users.noreply.github.com> Co-authored-by: cuichenx <cuichenx@users.noreply.github.com> Co-authored-by: Marc Romeyn <mromeijn@nvidia.com>
1 parent 6ad3615 commit 490ade4

File tree

7 files changed

+699
-7
lines changed

7 files changed

+699
-7
lines changed

nemo/collections/llm/__init__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,25 @@
1313
SquadDataModule,
1414
)
1515
from nemo.collections.llm.gpt.model import (
16+
CodeGemmaConfig2B,
17+
CodeGemmaConfig7B,
18+
CodeLlamaConfig7B,
19+
CodeLlamaConfig13B,
20+
CodeLlamaConfig34B,
21+
CodeLlamaConfig70B,
22+
GemmaConfig,
23+
GemmaConfig2B,
24+
GemmaConfig7B,
25+
GemmaModel,
1626
GPTConfig,
1727
GPTModel,
28+
Llama2Config7B,
29+
Llama2Config13B,
30+
Llama2Config70B,
31+
Llama3Config8B,
32+
Llama3Config70B,
33+
LlamaConfig,
34+
LlamaModel,
1835
MaskedTokenLossReduction,
1936
Mistral7BConfig,
2037
Mistral7BModel,
@@ -35,6 +52,23 @@
3552
"Mistral7BModel",
3653
"MixtralConfig",
3754
"MixtralModel",
55+
"LlamaConfig",
56+
"Llama2Config7B",
57+
"Llama2Config13B",
58+
"Llama2Config70B",
59+
"Llama3Config8B",
60+
"Llama3Config70B",
61+
"CodeLlamaConfig7B",
62+
"CodeLlamaConfig13B",
63+
"CodeLlamaConfig34B",
64+
"CodeLlamaConfig70B",
65+
"LlamaModel",
66+
"GemmaConfig",
67+
"GemmaConfig2B",
68+
"GemmaConfig7B",
69+
"CodeGemmaConfig2B",
70+
"CodeGemmaConfig7B",
71+
"GemmaModel",
3872
"PreTrainingDataModule",
3973
"FineTuningDataModule",
4074
"SquadDataModule",

nemo/collections/llm/gpt/model/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
gpt_data_step,
66
gpt_forward_step,
77
)
8+
from nemo.collections.llm.gpt.model.gemma import *
9+
from nemo.collections.llm.gpt.model.llama import *
810
from nemo.collections.llm.gpt.model.mistral_7b import Mistral7BConfig, Mistral7BModel
911
from nemo.collections.llm.gpt.model.mixtral import MixtralConfig, MixtralModel
1012

@@ -15,6 +17,23 @@
1517
"Mistral7BModel",
1618
"MixtralConfig",
1719
"MixtralModel",
20+
"LlamaConfig",
21+
"Llama2Config7B",
22+
"Llama2Config13B",
23+
"Llama2Config70B",
24+
"Llama3Config8B",
25+
"Llama3Config70B",
26+
"CodeLlamaConfig7B",
27+
"CodeLlamaConfig13B",
28+
"CodeLlamaConfig34B",
29+
"CodeLlamaConfig70B",
30+
"GemmaConfig",
31+
"GemmaConfig2B",
32+
"GemmaConfig7B",
33+
"CodeGemmaConfig2B",
34+
"CodeGemmaConfig7B",
35+
"GemmaModel",
36+
"LlamaModel",
1837
"MaskedTokenLossReduction",
1938
"gpt_data_step",
2039
"gpt_forward_step",
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
from dataclasses import dataclass
2+
from pathlib import Path
3+
from typing import TYPE_CHECKING, Annotated, Callable, Optional
4+
5+
import torch
6+
7+
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
8+
from nemo.collections.llm.utils import Config
9+
from nemo.collections.nlp.modules.common.megatron.utils import openai_gelu
10+
from nemo.lightning import OptimizerModule, io, teardown
11+
12+
if TYPE_CHECKING:
13+
from transformers import GemmaForCausalLM
14+
15+
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
16+
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
17+
18+
19+
# Note: Gemma requires huggingface transformers >= 4.38
# Note: these Gemma configs are copied from the corresponding HF model. You may need to modify the parameter for
# your own needs, in particular: seq_length and rotary_base.
@dataclass
class GemmaConfig(GPTConfig):
    """Base Megatron config shared by all Gemma model sizes.

    Field defaults mirror the corresponding Hugging Face Gemma configs; size-specific
    dimensions live in the GemmaConfig2B / GemmaConfig7B subclasses.
    """

    # configs that are common across model sizes
    normalization: str = "RMSNorm"
    # Gemma uses the (exact) GELU activation rather than SiLU/SwiGLU.
    activation_func: Callable = openai_gelu
    # Gated MLP: linear_fc1 holds gate_proj and up_proj stacked (see _import_linear_fc1).
    gated_linear_unit: bool = True
    position_embedding_type: str = "rope"
    add_bias_linear: bool = False
    seq_length: int = 8192
    # Fixed per-head dim of 256, independent of hidden_size // num_attention_heads.
    kv_channels: int = 256
    # Gemma ties the input embedding and output (LM head) weights.
    share_embeddings_and_output_weights: bool = True
    # Note: different behavior compared to Legacy NeMo
    # Legacy NeMo does not set layernorm_zero_centered_gamma and instead adds 1 in the HF -> NeMo conversion script
    # The present implementation is more in line with the official implementation
    layernorm_zero_centered_gamma: bool = True
37+
38+
39+
@dataclass
class GemmaConfig2B(GemmaConfig):
    """Gemma 2B dimensions (copied from the HF ``google/gemma-2b`` config)."""

    num_layers: int = 18
    hidden_size: int = 2048
    num_attention_heads: int = 8
    # Multi-query attention: a single shared KV head.
    num_query_groups: int = 1
    ffn_hidden_size: int = 16384
46+
47+
48+
@dataclass
class GemmaConfig7B(GemmaConfig):
    """Gemma 7B dimensions (copied from the HF ``google/gemma-7b`` config)."""

    num_layers: int = 28
    hidden_size: int = 3072
    num_attention_heads: int = 16
    # Full multi-head attention: one KV head per query head.
    num_query_groups: int = 16
    ffn_hidden_size: int = 24576
55+
56+
57+
class CodeGemmaConfig2B(GemmaConfig2B):
    """CodeGemma 2B — architecturally identical to Gemma 2B; kept as a distinct type for registry/naming purposes."""

    pass
59+
60+
61+
class CodeGemmaConfig7B(GemmaConfig7B):
    """CodeGemma 7B — architecturally identical to Gemma 7B; kept as a distinct type for registry/naming purposes."""

    pass
63+
64+
65+
class GemmaModel(GPTModel):
    """GPTModel specialization for Gemma: supplies a GemmaConfig default when none is given."""

    def __init__(
        self,
        config: Annotated[Optional[GemmaConfig], Config[GemmaConfig]] = None,
        optim: Optional[OptimizerModule] = None,
        tokenizer: Optional["TokenizerSpec"] = None,
    ):
        # Fall back to the base Gemma configuration when the caller provides none.
        if config is None:
            config = GemmaConfig()
        super().__init__(config, optim=optim, tokenizer=tokenizer)
73+
74+
75+
@io.model_importer(GemmaModel, "hf")
class HFGemmaImporter(io.ModelConnector["GemmaForCausalLM", GemmaModel]):
    """Converts a Hugging Face Gemma checkpoint into a NeMo GemmaModel checkpoint.

    ``str(self)`` is the HF model path/name (ModelConnector convention); ``apply``
    drives the full load -> convert -> save pipeline.
    """

    def init(self) -> GemmaModel:
        # Build the target NeMo model from the HF-derived config and tokenizer.
        return GemmaModel(self.config, tokenizer=self.tokenizer)

    def apply(self, output_path: Path) -> Path:
        """Load the HF model, convert its weights, and save a NeMo checkpoint at *output_path*."""
        from transformers import GemmaForCausalLM

        source = GemmaForCausalLM.from_pretrained(str(self))
        target = self.init()
        trainer = self.nemo_setup(target)
        self.convert_state(source, target)
        self.nemo_save(output_path, trainer)

        print(f"Converted Gemma model to Nemo, model saved to {output_path}")

        # Free model/trainer resources eagerly — both models may be large.
        teardown(trainer, target)
        del trainer, target

        return output_path

    def convert_state(self, source, target):
        """Copy weights HF -> NeMo; direct renames via *mapping*, fused tensors via transforms."""
        # QKV and gated-MLP weights are fused on the Megatron side, so they are
        # handled by _import_qkv / _import_linear_fc1 rather than a 1:1 mapping.
        mapping = {
            "model.embed_tokens.weight": "embedding.word_embeddings.weight",
            "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight",
            "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight",
            "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
            "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight",
            "model.norm.weight": "decoder.final_layernorm.weight",
        }

        return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1])

    @property
    def tokenizer(self) -> "AutoTokenizer":
        # NeMo wrapper around the HF tokenizer found at the source path.
        from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

        return AutoTokenizer(str(self))

    @property
    def config(self) -> GemmaConfig:
        """Derive a NeMo GemmaConfig from the source HF config."""
        from transformers import GemmaConfig as HFGemmaConfig

        source = HFGemmaConfig.from_pretrained(str(self))

        def make_vocab_size_divisible_by(vocab_size):
            # Largest power of two (<= 128) that divides the HF vocab size, so
            # NeMo's padded vocab exactly matches the checkpoint's.
            base = 128
            while vocab_size % base != 0:
                base //= 2
            return base

        output = GemmaConfig(
            num_layers=source.num_hidden_layers,
            hidden_size=source.hidden_size,
            ffn_hidden_size=source.intermediate_size,
            num_attention_heads=source.num_attention_heads,
            init_method_std=source.initializer_range,
            layernorm_epsilon=source.rms_norm_eps,
            num_query_groups=source.num_key_value_heads,
            rotary_base=source.rope_theta,
            # Redundant: GemmaConfig already defaults gated_linear_unit to True.
            gated_linear_unit=True,
            make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
            # NOTE(review): this contradicts GemmaConfig's default of True, and HF
            # Gemma ties input/output embeddings — looks copied from a Llama
            # importer; confirm whether False is intended here.
            share_embeddings_and_output_weights=False,
        )

        return output
141+
142+
143+
@io.model_exporter(GemmaModel, "hf")
class HFGemmaExporter(io.ModelConnector[GemmaModel, "GemmaForCausalLM"]):
    """Converts a NeMo GemmaModel checkpoint back into a Hugging Face checkpoint.

    ``str(self)`` is the NeMo checkpoint path (ModelConnector convention).
    """

    def init(self) -> "GemmaForCausalLM":
        # Instantiate an (untrained) HF model skeleton from the derived config.
        from transformers import AutoModelForCausalLM

        return AutoModelForCausalLM.from_config(self.config)

    def apply(self, output_path: Path) -> Path:
        """Load the NeMo checkpoint, convert its weights, and save an HF checkpoint at *output_path*."""
        target = self.init()
        source, _ = self.nemo_load(str(self))
        target = self.convert_state(source, target)

        target = target.cpu()
        target.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

        return output_path

    def convert_state(self, source, target):
        """Copy weights NeMo -> HF; inverse of HFGemmaImporter.convert_state."""
        # Fused QKV / gated-MLP tensors are split back out by _export_qkv /
        # _export_linear_fc1 rather than a 1:1 mapping.
        mapping = {
            "embedding.word_embeddings.weight": "model.embed_tokens.weight",
            "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
            "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
            "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight",
            "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight",
            "decoder.final_layernorm.weight": "model.norm.weight",
        }

        return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_linear_fc1])

    @property
    def tokenizer(self):
        # The raw HF tokenizer stored inside the NeMo checkpoint's tokenizer wrapper.
        return io.load_ckpt(str(self)).model.tokenizer.tokenizer

    @property
    def config(self) -> "GemmaConfig":
        """Derive an HF GemmaConfig from the NeMo checkpoint's config."""
        source: GemmaConfig = io.load_ckpt(str(self)).model.config

        from transformers import GemmaConfig as HFGemmaConfig

        # NOTE(review): fields without a NeMo counterpart here (e.g. rope_theta,
        # head_dim) fall back to HF defaults — verify those match the model.
        return HFGemmaConfig(
            num_hidden_layers=source.num_layers,
            hidden_size=source.hidden_size,
            intermediate_size=source.ffn_hidden_size,
            num_attention_heads=source.num_attention_heads,
            max_position_embeddings=source.seq_length,
            initializer_range=source.init_method_std,
            rms_norm_eps=source.layernorm_epsilon,
            num_key_value_heads=source.num_query_groups,
            vocab_size=self.tokenizer.vocab_size,
        )
194+
195+
196+
@io.state_transform(
    source_key=(
        "model.layers.*.self_attn.q_proj.weight",
        "model.layers.*.self_attn.k_proj.weight",
        "model.layers.*.self_attn.v_proj.weight",
    ),
    target_key="decoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv(ctx: io.TransformCTX, q, k, v):
    """Fuse HF's separate q/k/v projection weights into Megatron's interleaved linear_qkv weight.

    Per query group, the fused layout is [heads_per_group q-heads, 1 k-head, 1 v-head],
    flattened to shape ``[head_size * (head_num + 2 * num_query_groups), hidden_size]``.
    """
    megatron_config = ctx.target.config

    head_num = megatron_config.num_attention_heads
    num_query_groups = megatron_config.num_query_groups
    heads_per_group = head_num // num_query_groups
    hidden_size = megatron_config.hidden_size
    # Fix: the original assigned head_num from num_attention_heads twice; the
    # duplicate line is removed.
    # NOTE(review): head_size is derived as hidden_size // head_num, but
    # GemmaConfig pins kv_channels = 256 — confirm these agree for all sizes.
    head_size = hidden_size // head_num

    # Reshape so dim 0 indexes heads (q) / groups (k, v).
    old_tensor_shape = q.size()
    new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
    new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]

    q = q.view(*new_q_tensor_shape)
    k = k.view(*new_kv_tensor_shape)
    v = v.view(*new_kv_tensor_shape)

    # Interleave per group: this group's q heads, then its single k and v head.
    qkv_weights_l = []
    for i in range(num_query_groups):
        qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
        qkv_weights_l.append(k[i : i + 1, :, :])
        qkv_weights_l.append(v[i : i + 1, :, :])
    qkv_weights = torch.cat(qkv_weights_l)
    assert qkv_weights.ndim == 3, qkv_weights.shape
    assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
    assert qkv_weights.shape[1] == head_size, qkv_weights.shape
    assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape

    # Collapse (heads, head_size) back into a single output dimension.
    qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])

    return qkv_weights
236+
237+
238+
@io.state_transform(
    source_key="decoder.layers.*.self_attention.linear_qkv.weight",
    target_key=(
        "model.layers.*.self_attn.q_proj.weight",
        "model.layers.*.self_attn.k_proj.weight",
        "model.layers.*.self_attn.v_proj.weight",
    ),
)
def _export_qkv(ctx: io.TransformCTX, linear_qkv):
    """Split Megatron's interleaved linear_qkv weight back into HF q/k/v projection weights.

    Inverse of ``_import_qkv``: within each group of ``heads_per_group + 2`` rows,
    the first ``heads_per_group`` are q heads, followed by one k and one v head.
    """
    megatron_config = ctx.source.config

    head_num = megatron_config.num_attention_heads
    num_query_groups = megatron_config.num_query_groups
    heads_per_group = head_num // num_query_groups
    hidden_size = megatron_config.hidden_size
    # Fix: the original assigned head_num from num_attention_heads twice; the
    # duplicate line is removed.
    head_size = hidden_size // head_num
    qkv_total_dim = head_num + 2 * num_query_groups

    linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size])
    # Index sets selecting q / k / v rows out of the interleaved per-group layout.
    q_slice = torch.cat(
        [
            torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
            for i in range(num_query_groups)
        ]
    )
    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

    q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu()
    k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu()
    v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu()

    return q_proj, k_proj, v_proj
272+
273+
274+
@io.state_transform(
    source_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"),
    target_key="decoder.layers.*.mlp.linear_fc1.weight",
)
def _import_linear_fc1(gate, up):
    """Stack HF's gate_proj and up_proj weights into Megatron's fused linear_fc1 weight.

    Fix: parameters renamed from the misleading ``(down, gate)`` — per ``source_key``
    the first tensor is gate_proj and the second is up_proj (they are passed
    positionally by the transform framework, so behavior is unchanged).
    """
    # NOTE(review): .float() upcasts this weight to fp32 during conversion —
    # confirm this is intended rather than preserving the source dtype.
    return torch.cat((gate, up), axis=0).float()
280+
281+
282+
@io.state_transform(
    source_key="decoder.layers.*.mlp.linear_fc1.weight",
    target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"),
)
def _export_linear_fc1(linear_fc1):
    """Split Megatron's fused linear_fc1 weight back into HF gate_proj / up_proj halves."""
    # The fused tensor stacks gate_proj on top of up_proj along dim 0; split it evenly.
    halves = torch.chunk(linear_fc1, 2, dim=0)
    return halves[0], halves[1]
290+
291+
292+
# Public API of this module (re-exported via the star-import in
# nemo/collections/llm/gpt/model/__init__.py).
__all__ = [
    "GemmaConfig",
    "GemmaConfig2B",
    "GemmaConfig7B",
    "CodeGemmaConfig2B",
    "CodeGemmaConfig7B",
    "GemmaModel",
]

0 commit comments

Comments
 (0)