|
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Optional

import torch
import torch.nn.functional as F

from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.lightning import io, teardown

if TYPE_CHECKING:
    # MixtralForCausalLM is referenced by the string annotation on
    # HFMixtralImporter's io.ModelConnector generic; import it for type checkers.
    from transformers import MistralConfig, MistralForCausalLM, MixtralForCausalLM

    from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
| 15 | + |
| 16 | + |
@dataclass
class MixtralConfig(GPTConfig):
    """Megatron-style configuration for a Mixtral 8x7B MoE transformer.

    Defaults mirror the Mixtral-8x7B architecture; field values are
    overridden by HFMixtralImporter.config when importing a HF checkpoint.
    """

    # Llama/Mistral-family architecture choices.
    normalization: str = "RMSNorm"
    activation_func: Callable = F.silu
    position_embedding_type: str = "rope"
    add_bias_linear: bool = False
    gated_linear_unit: bool = True  # SwiGLU MLP (silu + gating)
    apply_query_key_layer_scaling: bool = False  # TODO: Should this be True?

    # Model dimensions (Mixtral-8x7B defaults).
    num_layers: int = 32
    hidden_size: int = 4096
    num_attention_heads: int = 32
    num_query_groups: int = 8  # grouped-query attention: 8 KV groups
    ffn_hidden_size: int = 14336
    max_position_embeddings: int = 4096  # 32768
    seq_length: int = 4096  # 32768
    # MoE
    num_moe_experts: int = 8
    # NOTE(review): HF Mixtral-8x7B ships num_experts_per_tok=2; default here
    # is 1 — confirm intended, the importer overwrites it from the HF config.
    moe_router_topk: int = 1

    init_method_std: float = 0.02
    layernorm_epsilon: float = 1e-5
    # rotary
    rotary_percent: float = 0.5
    rotary_base: float = 10000
| 42 | + |
| 43 | + |
class MixtralModel(GPTModel):
    """GPTModel specialization for Mixtral checkpoints.

    When no tokenizer is supplied, the tokenizer is pulled from the HF
    importer; when no config is supplied, MixtralConfig defaults are used.
    """

    def __init__(self, config: Optional[MixtralConfig] = None, optim_config=None, tokenizer=None):
        # Fall back to the importer's tokenizer / default config (same
        # truthiness test as `tokenizer or ...`).
        if not tokenizer:
            tokenizer = HFMixtralImporter().tokenizer
        if not config:
            config = MixtralConfig()
        super().__init__(config, optim_config, tokenizer)
| 49 | + |
| 50 | + |
@io.model_importer(MixtralModel, ext="hf")
class HFMixtralImporter(io.ModelConnector["MixtralForCausalLM", MixtralModel]):
    """Imports a HuggingFace Mixtral checkpoint into a NeMo MixtralModel.

    `str(self)` is the HF model path/name (ModelConnector contract).
    """

    def init(self) -> MixtralModel:
        """Build an empty target model from the translated config."""
        return MixtralModel(self.config, tokenizer=self.tokenizer)

    def apply(self, output_path: Path) -> Path:
        """Load the HF weights, convert them, and save a NeMo checkpoint.

        Returns `output_path` for chaining.
        """
        from transformers import MixtralForCausalLM

        source = MixtralForCausalLM.from_pretrained(str(self))
        target = self.init()
        trainer = self.nemo_setup(target)
        self.convert_state(source, target)
        self.nemo_save(output_path, trainer)

        # Release the trainer/model promptly — these hold the full weights.
        teardown(trainer, target)
        del trainer, target

        return output_path

    @property
    def tokenizer(self) -> "AutoTokenizer":
        """HF tokenizer wrapped in NeMo's AutoTokenizer (lazy import)."""
        from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

        return AutoTokenizer(str(self))

    @property
    def config(self) -> MixtralConfig:
        """Translate the HF MixtralConfig into a NeMo MixtralConfig."""
        from transformers import MixtralConfig as HfMixtralConfig

        config = HfMixtralConfig.from_pretrained(str(self))
        return MixtralConfig(
            activation_func=F.silu,
            # network
            num_layers=config.num_hidden_layers,
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            max_position_embeddings=config.max_position_embeddings,  # TODO
            seq_length=config.max_position_embeddings,
            # RoPE
            position_embedding_type='rope',
            # BUG FIX: the following three fields previously read from `source`,
            # a name only bound inside `apply` — NameError here. The HF config
            # object in this scope is `config`.
            rotary_base=config.rope_theta,
            # Transformer config
            num_attention_heads=config.num_attention_heads,
            num_query_groups=config.num_key_value_heads,
            num_moe_experts=config.num_local_experts,
            moe_router_topk=config.num_experts_per_tok,
            # norm — Megatron expects the exact spelling "RMSNorm"
            # (was 'rmsnorm', inconsistent with the dataclass default).
            normalization="RMSNorm",
            layernorm_epsilon=config.rms_norm_eps,
            # Init
            init_method_std=config.initializer_range,
            gated_linear_unit=True,
            # Vocab
            make_vocab_size_divisible_by=128,
        )