
Commit 45d4487

Fix Minimax M2 loading issue (sgl-project#13956)
1 parent 02cea58 commit 45d4487

File tree

2 files changed: +73 −0 lines changed

python/sglang/srt/layers/quantization/fp8_utils.py
test/nightly/test_minimax_m2_perf.py


python/sglang/srt/layers/quantization/fp8_utils.py

Lines changed: 68 additions & 0 deletions
@@ -244,6 +244,13 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(

     if not (shape_supported and dtype_supported):
         # fall back to triton
+        # If weight_scale is in UE8M0 packed format (int32), convert it back to
+        # float32: UE8M0 has shape (N, K//block_k//4) with dtype int32, while
+        # triton expects shape (N//block_n, K//block_k) with dtype float32.
+        if weight_scale.dtype == torch.int32:
+            weight_scale = _unpack_ue8m0_scale_for_triton(
+                weight_scale, weight.shape, block_size
+            )
         return triton_w8a8_block_fp8_linear(
             input, weight, block_size, weight_scale, input_scale, bias
         )
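The int32 dtype check above is the whole dispatch: the UE8M0 packing used on the DeepGEMM path stores scales as int32 words of four 8-bit exponents, while the triton kernel consumes plain float32 scales, so the dtype alone distinguishes the two layouts. As a quick illustration of the bit trick the new helper (next hunk) relies on, here is a standalone sketch with made-up values, not code from this commit:

import torch

# A UE8M0 scale byte is just the biased exponent of a float32. Shifting it
# into bits 23..30 of an int32 and reinterpreting the bits as float32
# recovers the scale 2**(e - 127).
e = torch.tensor([127, 130, 120], dtype=torch.int32)  # biased exponents
s = (e << 23).view(torch.float32)
print(s)  # -> 2**0, 2**3, 2**-7, i.e. 1.0, 8.0, 0.0078125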
@@ -267,6 +274,67 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
     return output.to(dtype=output_dtype).view(*output_shape)


+def _unpack_ue8m0_scale_for_triton(
+    sf_packed: torch.Tensor,
+    weight_shape: Tuple[int, int],
+    block_size: List[int],
+) -> torch.Tensor:
+    """
+    Unpack a UE8M0-packed scale tensor back to float32 for the triton kernel.
+
+    The UE8M0 format packs scales as:
+      - Shape: (N, K//block_k//4) with dtype int32
+      - Each int32 holds 4 uint8 scale values
+
+    Triton expects:
+      - Shape: (N//block_n, K//block_k) with dtype float32
+
+    Args:
+        sf_packed: Packed scale tensor with shape (N, packed_k_groups) and dtype int32
+        weight_shape: (N, K) shape of the weight tensor
+        block_size: [block_n, block_k] quantization block size
+
+    Returns:
+        Unpacked scale tensor with shape (n_groups, k_groups) and dtype float32
+    """
+    assert sf_packed.dtype == torch.int32
+    assert len(sf_packed.shape) == 2
+
+    N, K = weight_shape
+    block_n, block_k = block_size
+    n_groups = ceil_div(N, block_n)
+    k_groups = ceil_div(K, block_k)
+
+    mn_repeat, k_div_4 = sf_packed.shape
+    k_packed = k_div_4 * 4
+
+    # Unpack int32 -> 4x uint8 -> float32. Each uint8 is a biased exponent in
+    # UE8M0 format, so shifting it into the float32 exponent field (bits 23-30)
+    # reconstructs the scale 2**(e - 127).
+    sf_u8 = sf_packed.contiguous().view(torch.uint8).view(mn_repeat, k_packed)
+    sf_fp32 = (sf_u8.to(torch.int32) << 23).view(torch.float32)
+
+    # Handle the row dimension: it may be replicated per weight row or already
+    # mapped one row per block.
+    if mn_repeat == N:
+        # Rows are replicated block_n times (typically 128); keep one
+        # representative row per block at indices 0, block_n, 2 * block_n, ...
+        # sf_fp32 shape: (N, k_packed) -> (n_groups, k_packed)
+        indices = torch.arange(0, N, block_n, device=sf_packed.device)
+        sf_fp32 = sf_fp32.index_select(0, indices)
+    elif mn_repeat == n_groups:
+        # Already in the expected n_groups layout.
+        pass
+    else:
+        raise ValueError(
+            f"Unexpected scale shape: sf_packed.shape={sf_packed.shape}, "
+            f"weight_shape={weight_shape}, block_size={block_size}"
+        )
+
+    # Crop the k dimension to the expected size (drops padding, if any).
+    sf_fp32 = sf_fp32[:, :k_groups].contiguous()
+
+    return sf_fp32
+
+
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
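To sanity-check the shape handling end to end, here is a self-contained round trip under assumed toy dimensions (N=256, K=384, 128x128 blocks): it packs power-of-two scales into the replicated UE8M0 int32 layout described in the docstring, then unpacks them with the same steps as _unpack_ue8m0_scale_for_triton. Everything below is an illustrative sketch, not part of the commit.

import torch
import torch.nn.functional as F

# Assumed toy dimensions: a 2 x 3 grid of 128 x 128 scale blocks.
N, K = 256, 384
block_n, block_k = 128, 128
n_groups, k_groups = N // block_n, K // block_k  # 2, 3

# Ground-truth float32 scales: one power of two per (block_n, block_k) tile.
exponents = torch.randint(-8, 8, (n_groups, k_groups), dtype=torch.int32)
sf_ref = torch.pow(2.0, exponents.to(torch.float32))

# Pack into the UE8M0 layout the helper expects: biased uint8 exponents,
# each scale row replicated block_n times, k padded to a multiple of 4,
# and every 4 consecutive bytes viewed as one int32.
biased = (exponents + 127).to(torch.uint8)
rows = biased.repeat_interleave(block_n, dim=0)  # (N, k_groups)
rows = F.pad(rows, (0, (-k_groups) % 4))         # (N, 4): pad k to a 4-multiple
sf_packed = rows.contiguous().view(torch.int32)  # (N, 1), dtype int32

# Unpack with the same steps as _unpack_ue8m0_scale_for_triton.
u8 = sf_packed.view(torch.uint8).view(N, -1)
fp32 = (u8.to(torch.int32) << 23).view(torch.float32)
fp32 = fp32.index_select(0, torch.arange(0, N, block_n))[:, :k_groups]

assert torch.equal(fp32, sf_ref)  # exact: powers of two survive losslessly
print("UE8M0 round trip OK:", fp32.shape)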

test/nightly/test_minimax_m2_perf.py

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,10 @@ def setUpClass(cls):
         cls.other_args = [
             "--tp",
             "8",
+            "--ep",
+            "8",
+            "--model-loader-extra-config",
+            '{"enable_multithread_load": true}',
             "--trust-remote-code",
         ]


@@ -34,6 +38,7 @@ def test_bench_one_batch(self):
             input_lens=self.input_lens,
             output_lens=self.output_lens,
             other_args=self.other_args,
+            extra_bench_args=["--trust-remote-code"],
         )

         self.runner.add_report(results)
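For reference, the complete argument set the nightly test now passes, assembled as a standalone list: the flag values are copied verbatim from the diff, while the per-flag comments are assumptions based on sglang conventions rather than anything this commit states.

# Hypothetical reconstruction of the test's benchmark arguments.
other_args = [
    "--tp", "8",                          # tensor parallelism over 8 GPUs
    "--ep", "8",                          # expert parallelism for MoE layers (assumed)
    "--model-loader-extra-config",
    '{"enable_multithread_load": true}',  # parallel checkpoint reading (assumed)
    "--trust-remote-code",                # allow the model's remote HF code
]
extra_bench_args = ["--trust-remote-code"]  # also passed to the bench client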
