Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions nemo/collections/llm/gpt/model/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def make_vocab_size_divisible_by(mistral_vocab_size):
num_layers=source.num_hidden_layers,
hidden_size=source.hidden_size,
ffn_hidden_size=source.intermediate_size,
kv_channels=source.get('head_dim', source.hidden_size // source.num_attention_heads),
num_attention_heads=source.num_attention_heads,
# max_position_embeddings=source.max_position_embeddings,
init_method_std=source.initializer_range,
Expand Down Expand Up @@ -183,6 +184,7 @@ def config(self) -> "MistralConfig":
num_key_value_heads=source.num_query_groups,
rope_theta=source.rotary_base,
vocab_size=self.tokenizer.vocab_size,
head_dim=source.kv_channels,
)


Expand All @@ -202,7 +204,7 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v):
heads_per_group = head_num // num_query_groups
hidden_size = megatron_config.hidden_size
head_num = megatron_config.num_attention_heads
head_size = hidden_size // head_num
head_size = megatron_config.kv_channels

old_tensor_shape = q.size()
new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
Expand Down Expand Up @@ -244,7 +246,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv):
heads_per_group = head_num // num_query_groups
hidden_size = megatron_config.hidden_size
head_num = megatron_config.num_attention_heads
head_size = hidden_size // head_num
head_size = megatron_config.kv_channels
qkv_total_dim = head_num + 2 * num_query_groups

linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size])
Expand Down
6 changes: 4 additions & 2 deletions nemo/collections/llm/gpt/model/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def config(self) -> MixtralConfig8x7B:
num_layers=config.num_hidden_layers,
hidden_size=config.hidden_size,
ffn_hidden_size=config.intermediate_size,
kv_channels=config.get('head_dim', config.hidden_size // config.num_attention_heads),
max_position_embeddings=config.max_position_embeddings, # TODO
seq_length=config.max_position_embeddings,
# RoPE
Expand Down Expand Up @@ -158,7 +159,7 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v):
heads_per_group = head_num // num_query_groups
hidden_size = megatron_config.hidden_size
head_num = megatron_config.num_attention_heads
head_size = hidden_size // head_num
head_size = megatron_config.kv_channels

old_tensor_shape = q.size()
new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
Expand Down Expand Up @@ -262,6 +263,7 @@ def config(self) -> "MixtralConfig":
initializer_range=source.init_method_std,
# vocab
vocab_size=self.tokenizer.vocab_size,
head_dim=source.kv_channels,
)


Expand All @@ -281,7 +283,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv):
heads_per_group = head_num // num_query_groups
hidden_size = megatron_config.hidden_size
head_num = megatron_config.num_attention_heads
head_size = hidden_size // head_num
head_size = megatron_config.kv_channels
qkv_total_dim = head_num + 2 * num_query_groups

linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size])
Expand Down