from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Callable, Optional

import torch

from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.utils import Config
from nemo.collections.nlp.modules.common.megatron.utils import openai_gelu
from nemo.lightning import OptimizerModule, io, teardown

if TYPE_CHECKING:
    from transformers import GemmaForCausalLM

    from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
    from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec


# Note: Gemma requires huggingface transformers >= 4.38
# Note: these Gemma configs are copied from the corresponding HF model. You may need to modify the parameters
# to suit your own needs, in particular: seq_length and rotary_base.
@dataclass
class GemmaConfig(GPTConfig):
    # configs that are common across model sizes
    normalization: str = "RMSNorm"
    activation_func: Callable = openai_gelu
    gated_linear_unit: bool = True
    position_embedding_type: str = "rope"
    add_bias_linear: bool = False
    seq_length: int = 8192
    kv_channels: int = 256
    share_embeddings_and_output_weights: bool = True
    # Note: different behavior compared to Legacy NeMo.
    # Legacy NeMo does not set layernorm_zero_centered_gamma and instead adds 1 in the HF -> NeMo conversion script.
    # The present implementation is more in line with the official implementation.
    layernorm_zero_centered_gamma: bool = True

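# Illustrative sketch (not part of the original file): HF Gemma's RMSNorm scales the
# normalized activations by (1 + weight), and Megatron applies the same "+1" internally
# when layernorm_zero_centered_gamma=True, so HF norm weights can be copied verbatim
# during conversion instead of being shifted by 1 in the import script, e.g.:
#
#   gamma_hf = torch.zeros(2048)        # freshly initialized HF RMSNorm weight
#   gamma_megatron = gamma_hf.clone()   # no "+1" offset needed at import time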

@dataclass
class GemmaConfig2B(GemmaConfig):
    num_layers: int = 18
    hidden_size: int = 2048
    num_attention_heads: int = 8
    num_query_groups: int = 1
    ffn_hidden_size: int = 16384


@dataclass
class GemmaConfig7B(GemmaConfig):
    num_layers: int = 28
    hidden_size: int = 3072
    num_attention_heads: int = 16
    num_query_groups: int = 16
    ffn_hidden_size: int = 24576


class CodeGemmaConfig2B(GemmaConfig2B):
    pass


class CodeGemmaConfig7B(GemmaConfig7B):
    pass


class GemmaModel(GPTModel):
    def __init__(
        self,
        config: Annotated[Optional[GemmaConfig], Config[GemmaConfig]] = None,
        optim: Optional[OptimizerModule] = None,
        tokenizer: Optional["TokenizerSpec"] = None,
    ):
        super().__init__(config or GemmaConfig(), optim=optim, tokenizer=tokenizer)

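# Minimal construction sketch (illustrative, using only the classes defined in this file):
#
#   config = GemmaConfig2B(seq_length=4096)   # override defaults as needed
#   model = GemmaModel(config)                # tokenizer and optimizer are optional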

@io.model_importer(GemmaModel, "hf")
class HFGemmaImporter(io.ModelConnector["GemmaForCausalLM", GemmaModel]):
    def init(self) -> GemmaModel:
        return GemmaModel(self.config, tokenizer=self.tokenizer)

    def apply(self, output_path: Path) -> Path:
        from transformers import GemmaForCausalLM

        source = GemmaForCausalLM.from_pretrained(str(self))
        target = self.init()
        trainer = self.nemo_setup(target)
        self.convert_state(source, target)
        self.nemo_save(output_path, trainer)

        print(f"Converted Gemma model to NeMo, model saved to {output_path}")

        teardown(trainer, target)
        del trainer, target

        return output_path

    def convert_state(self, source, target):
        mapping = {
            "model.embed_tokens.weight": "embedding.word_embeddings.weight",
            "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight",
            "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight",
            "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
            "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight",
            "model.norm.weight": "decoder.final_layernorm.weight",
        }

        return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1])

    @property
    def tokenizer(self) -> "AutoTokenizer":
        from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

        return AutoTokenizer(str(self))

    @property
    def config(self) -> GemmaConfig:
        from transformers import GemmaConfig as HFGemmaConfig

        source = HFGemmaConfig.from_pretrained(str(self))

        def make_vocab_size_divisible_by(vocab_size):
            # Largest power of two (up to 128) that evenly divides the vocabulary size.
            base = 128
            while vocab_size % base != 0:
                base //= 2
            return base

        output = GemmaConfig(
            num_layers=source.num_hidden_layers,
            hidden_size=source.hidden_size,
            ffn_hidden_size=source.intermediate_size,
            num_attention_heads=source.num_attention_heads,
            init_method_std=source.initializer_range,
            layernorm_epsilon=source.rms_norm_eps,
            num_query_groups=source.num_key_value_heads,
            rotary_base=source.rope_theta,
            gated_linear_unit=True,
            make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
            # Gemma ties the input embeddings and the output projection weights,
            # matching the GemmaConfig default and the state mapping above.
            share_embeddings_and_output_weights=True,
        )

        return output


@io.model_exporter(GemmaModel, "hf")
class HFGemmaExporter(io.ModelConnector[GemmaModel, "GemmaForCausalLM"]):
    def init(self) -> "GemmaForCausalLM":
        from transformers import AutoModelForCausalLM

        return AutoModelForCausalLM.from_config(self.config)

    def apply(self, output_path: Path) -> Path:
        target = self.init()
        source, _ = self.nemo_load(str(self))
        target = self.convert_state(source, target)

        target = target.cpu()
        target.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

        return output_path

    def convert_state(self, source, target):
        mapping = {
            "embedding.word_embeddings.weight": "model.embed_tokens.weight",
            "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
            "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
            "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight",
            "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight",
            "decoder.final_layernorm.weight": "model.norm.weight",
        }

        return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_linear_fc1])

    @property
    def tokenizer(self):
        return io.load_ckpt(str(self)).model.tokenizer.tokenizer

    @property
    def config(self) -> "HFGemmaConfig":
        source: GemmaConfig = io.load_ckpt(str(self)).model.config

        from transformers import GemmaConfig as HFGemmaConfig

        return HFGemmaConfig(
            num_hidden_layers=source.num_layers,
            hidden_size=source.hidden_size,
            intermediate_size=source.ffn_hidden_size,
            num_attention_heads=source.num_attention_heads,
            max_position_embeddings=source.seq_length,
            initializer_range=source.init_method_std,
            rms_norm_eps=source.layernorm_epsilon,
            num_key_value_heads=source.num_query_groups,
            vocab_size=self.tokenizer.vocab_size,
        )


@io.state_transform(
    source_key=(
        "model.layers.*.self_attn.q_proj.weight",
        "model.layers.*.self_attn.k_proj.weight",
        "model.layers.*.self_attn.v_proj.weight",
    ),
    target_key="decoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv(ctx: io.TransformCTX, q, k, v):
    megatron_config = ctx.target.config

    head_num = megatron_config.num_attention_heads
    num_query_groups = megatron_config.num_query_groups
    heads_per_group = head_num // num_query_groups
    hidden_size = megatron_config.hidden_size
    # Gemma uses a fixed head dimension (kv_channels=256) that is not hidden_size // num_attention_heads
    # for the 7B model, so derive the head size from kv_channels rather than from the hidden size.
    head_size = megatron_config.kv_channels

    old_tensor_shape = q.size()
    new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
    new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]

    q = q.view(*new_q_tensor_shape)
    k = k.view(*new_kv_tensor_shape)
    v = v.view(*new_kv_tensor_shape)

    # Interleave the heads group by group: [q_0 ... q_{heads_per_group-1}, k_g, v_g] for each
    # query group g, which is the layout Megatron expects for the fused linear_qkv weight.
    qkv_weights_l = []
    for i in range(num_query_groups):
        qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
        qkv_weights_l.append(k[i : i + 1, :, :])
        qkv_weights_l.append(v[i : i + 1, :, :])
    qkv_weights = torch.cat(qkv_weights_l)
    assert qkv_weights.ndim == 3, qkv_weights.shape
    assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
    assert qkv_weights.shape[1] == head_size, qkv_weights.shape
    assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape

    qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])

    return qkv_weights

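# Shape sanity check for the fused QKV import (illustrative sketch, not part of the
# original file), using the GemmaConfig2B geometry defined above: hidden_size=2048,
# 8 attention heads, 1 KV group, head_size (kv_channels) = 256.
#
#   q = torch.randn(8 * 256, 2048)   # model.layers.N.self_attn.q_proj.weight
#   k = torch.randn(1 * 256, 2048)   # model.layers.N.self_attn.k_proj.weight
#   v = torch.randn(1 * 256, 2048)   # model.layers.N.self_attn.v_proj.weight
#   # The fused Megatron weight ends up with
#   # (heads_per_group + 2) * num_query_groups * head_size = (8 + 2) * 1 * 256 = 2560 rows,
#   # i.e. a [2560, 2048] tensor laid out as [q_0 ... q_7, k_0, v_0] per group.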

@io.state_transform(
    source_key="decoder.layers.*.self_attention.linear_qkv.weight",
    target_key=(
        "model.layers.*.self_attn.q_proj.weight",
        "model.layers.*.self_attn.k_proj.weight",
        "model.layers.*.self_attn.v_proj.weight",
    ),
)
def _export_qkv(ctx: io.TransformCTX, linear_qkv):
    megatron_config = ctx.source.config

    head_num = megatron_config.num_attention_heads
    num_query_groups = megatron_config.num_query_groups
    heads_per_group = head_num // num_query_groups
    hidden_size = megatron_config.hidden_size
    # As in _import_qkv, the head size comes from kv_channels (256 for Gemma), not hidden_size // head_num.
    head_size = megatron_config.kv_channels
    qkv_total_dim = head_num + 2 * num_query_groups

    linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size])
    # Undo the per-group interleaving: the first heads_per_group rows of each group are Q,
    # followed by one K row and one V row.
    q_slice = torch.cat(
        [
            torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
            for i in range(num_query_groups)
        ]
    )
    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

    q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu()
    k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu()
    v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu()

    return q_proj, k_proj, v_proj

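# Worked example of the slices above for the 2B geometry (illustrative, not part of the
# original file): with heads_per_group=8 and num_query_groups=1,
#
#   qkv_total_dim = 8 + 2 * 1          # 10 rows of shape [head_size, hidden_size]
#   q_slice -> tensor([0, 1, 2, 3, 4, 5, 6, 7])
#   k_slice -> tensor([8])
#   v_slice -> tensor([9])
#
# so _export_qkv is the exact inverse of the interleaving done in _import_qkv.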

@io.state_transform(
    source_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"),
    target_key="decoder.layers.*.mlp.linear_fc1.weight",
)
def _import_linear_fc1(gate, up):
    # Megatron's gated-MLP linear_fc1 stacks the gate and up projections along dim 0.
    return torch.cat((gate, up), dim=0)


@io.state_transform(
    source_key="decoder.layers.*.mlp.linear_fc1.weight",
    target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"),
)
def _export_linear_fc1(linear_fc1):
    gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)

    return gate_proj, up_proj

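# Shape round trip for the MLP transforms above (illustrative sketch, not part of the
# original file), using the 2B geometry (ffn_hidden_size=16384, hidden_size=2048):
#
#   gate = torch.randn(16384, 2048)                    # mlp.gate_proj.weight
#   up = torch.randn(16384, 2048)                      # mlp.up_proj.weight
#   fc1 = torch.cat((gate, up), dim=0)                 # fused linear_fc1: [32768, 2048]
#   gate_back, up_back = torch.chunk(fc1, 2, dim=0)    # each back to [16384, 2048]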

__all__ = [
    "GemmaConfig",
    "GemmaConfig2B",
    "GemmaConfig7B",
    "CodeGemmaConfig2B",
    "CodeGemmaConfig7B",
    "GemmaModel",
]
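
# End-to-end usage sketch (not part of the original file). The driver functions and the
# "hf://" source syntax below are assumptions based on the NeMo 2.0 checkpoint-conversion
# API and may differ between releases:
#
#   from nemo.collections import llm
#
#   model = GemmaModel(GemmaConfig7B())
#   llm.import_ckpt(model, source="hf://google/gemma-7b")   # drives HFGemmaImporter
#   llm.export_ckpt(nemo_ckpt_path, target="hf")            # drives HFGemmaExporter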