Skip to content

Commit 8dbe1da

Browse files
authored
nemo-ux: Use kv_channels to enable cases where head_dim != hidden_size // head_num (#9994)
* Use kv_channels to enable cases where head_dim != hidden_size // head_num

  Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Add head_dim to exporter

  Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Drop default values for kv_channels

  Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

---------

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent d0efff0 commit 8dbe1da

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

nemo/collections/llm/gpt/model/mistral.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def make_vocab_size_divisible_by(mistral_vocab_size):
111111
num_layers=source.num_hidden_layers,
112112
hidden_size=source.hidden_size,
113113
ffn_hidden_size=source.intermediate_size,
114+
kv_channels=source.get('head_dim', source.hidden_size // source.num_attention_heads),
114115
num_attention_heads=source.num_attention_heads,
115116
# max_position_embeddings=source.max_position_embeddings,
116117
init_method_std=source.initializer_range,
@@ -183,6 +184,7 @@ def config(self) -> "MistralConfig":
183184
num_key_value_heads=source.num_query_groups,
184185
rope_theta=source.rotary_base,
185186
vocab_size=self.tokenizer.vocab_size,
187+
head_dim=source.kv_channels,
186188
)
187189

188190

@@ -202,7 +204,7 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v):
202204
heads_per_group = head_num // num_query_groups
203205
hidden_size = megatron_config.hidden_size
204206
head_num = megatron_config.num_attention_heads
205-
head_size = hidden_size // head_num
207+
head_size = megatron_config.kv_channels
206208

207209
old_tensor_shape = q.size()
208210
new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
@@ -244,7 +246,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv):
244246
heads_per_group = head_num // num_query_groups
245247
hidden_size = megatron_config.hidden_size
246248
head_num = megatron_config.num_attention_heads
247-
head_size = hidden_size // head_num
249+
head_size = megatron_config.kv_channels
248250
qkv_total_dim = head_num + 2 * num_query_groups
249251

250252
linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size])

nemo/collections/llm/gpt/model/mixtral.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def config(self) -> MixtralConfig8x7B | MixtralConfig8x22B:
155155
num_layers=config.num_hidden_layers,
156156
hidden_size=config.hidden_size,
157157
ffn_hidden_size=config.intermediate_size,
158+
kv_channels=config.get('head_dim', config.hidden_size // config.num_attention_heads),
158159
max_position_embeddings=config.max_position_embeddings, # TODO
159160
seq_length=config.max_position_embeddings,
160161
# RoPE
@@ -197,7 +198,7 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v):
197198
heads_per_group = head_num // num_query_groups
198199
hidden_size = megatron_config.hidden_size
199200
head_num = megatron_config.num_attention_heads
200-
head_size = hidden_size // head_num
201+
head_size = megatron_config.kv_channels
201202

202203
old_tensor_shape = q.size()
203204
new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
@@ -302,6 +303,7 @@ def config(self) -> "MixtralConfig":
302303
initializer_range=source.init_method_std,
303304
# vocab
304305
vocab_size=self.tokenizer.vocab_size,
306+
head_dim=source.kv_channels,
305307
)
306308

307309

@@ -321,7 +323,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv):
321323
heads_per_group = head_num // num_query_groups
322324
hidden_size = megatron_config.hidden_size
323325
head_num = megatron_config.num_attention_heads
324-
head_size = hidden_size // head_num
326+
head_size = megatron_config.kv_channels
325327
qkv_total_dim = head_num + 2 * num_query_groups
326328

327329
linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size])

0 commit comments

Comments (0)