16 changes: 8 additions & 8 deletions benchmark/mmmu/bench_sglang.py
@@ -83,9 +83,9 @@ async def process_sample(
assert image is not None
image_path = sample["image_path"]
extra_body = None if lora_path is None else {"lora_path": lora_path}
response = await client.chat.completions.create(
model="default",
messages=[
payload = {
"model": "default",
"messages": [
{
"role": "user",
"content": [
@@ -95,11 +95,11 @@ async def process_sample(
],
}
],
temperature=0,
max_completion_tokens=sampling_params["max_new_tokens"],
max_tokens=sampling_params["max_new_tokens"],
extra_body=extra_body,
)
"extra_body": extra_body,
}
if sampling_params:
payload.update(sampling_params)
response = await client.chat.completions.create(**payload)
return sample, response.choices[0].message.content


25 changes: 17 additions & 8 deletions benchmark/mmmu/eval_utils.py
@@ -36,7 +36,8 @@ class EvalArgs:
profile: bool = False
profile_number: int = 5
concurrency: int = 1
max_new_tokens: int = 30
max_new_tokens: Optional[int] = None
temperature: Optional[float] = None
response_answer_regex: str = "(.*)"
lora_path: Optional[str] = None

@@ -101,6 +102,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=EvalArgs.max_new_tokens,
help="Maximum number of new tokens to generate per sample.",
)
parser.add_argument(
"--temperature",
type=float,
default=EvalArgs.temperature,
help="Sampling temperature for generation.",
)
parser.add_argument(
"--response-answer-regex",
type=str,
@@ -241,19 +248,21 @@ def process_sample(i, sample):


def get_sampling_params(eval_args):
max_new_tokens = eval_args.max_new_tokens
temperature = 0.001

extra_request_body = {}
if eval_args.extra_request_body:
extra_request_body = json.loads(eval_args.extra_request_body)

return {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
sampling_params = {
**extra_request_body,
}

if eval_args.max_new_tokens is not None and eval_args.max_new_tokens > 0:
sampling_params.update({"max_completion_tokens": eval_args.max_new_tokens})

if eval_args.temperature is not None:
sampling_params.update({"temperature": eval_args.temperature})

return sampling_params


# ----------- Process Multi-choice -------------
def parse_multi_choice_response(response, all_choices, index2ans):
2 changes: 0 additions & 2 deletions python/sglang/srt/configs/olmo3.py
@@ -17,7 +17,6 @@
import enum

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging

logger = logging.get_logger(__name__)
@@ -90,7 +89,6 @@ def __init__(
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
rope_config_validation(self)
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout

2 changes: 0 additions & 2 deletions python/sglang/srt/configs/qwen3_next.py
@@ -17,7 +17,6 @@
import enum

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging

from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
@@ -226,7 +225,6 @@ def __init__(
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.head_dim = head_dim
rope_config_validation(self)

# linear attention (gdn now part)
self.linear_conv_kernel_dim = linear_conv_kernel_dim
7 changes: 6 additions & 1 deletion python/sglang/srt/layers/quantization/fp8.py
@@ -156,7 +156,12 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config:
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = "fp8" in quant_method
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
ignored_layers = cls.get_from_keys_or(
config, ["ignored_layers", "modules_to_not_convert"], None
)
if ignored_layers:
# Hack for Ministral: drop the "model." prefix so entries match layer prefixes
ignored_layers = [layer.replace("model.", "") for layer in ignored_layers]
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
return cls(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
6 changes: 4 additions & 2 deletions python/sglang/srt/layers/quantization/utils.py
@@ -64,7 +64,9 @@ def is_layer_skipped(

is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = shard_prefix in ignored_layers
is_shard_skipped = any(
ignored in shard_prefix for ignored in ignored_layers
)

if is_skipped is None:
is_skipped = is_shard_skipped
@@ -75,7 +77,7 @@
"to have the same precision."
)
else:
is_skipped = prefix in ignored_layers
is_skipped = any(ignored in prefix for ignored in ignored_layers)
if "gate_up_proj" in prefix:
prefix_gate = prefix.replace("gate_up_proj", "gate_proj")
prefix_up = prefix.replace("gate_up_proj", "up_proj")
6 changes: 4 additions & 2 deletions python/sglang/srt/model_executor/model_runner.py
@@ -583,8 +583,10 @@ def check_quantized_moe_compatibility(self):
quantization_config := getattr(
self.model_config.hf_config, "quantization_config", None
)
) is not None and "weight_block_size" in quantization_config:
weight_block_size_n = quantization_config["weight_block_size"][0]
) is not None and (
weight_block_size := quantization_config.get("weight_block_size", None)
) is not None:
weight_block_size_n = weight_block_size[0]

if self.tp_size % self.moe_ep_size != 0:
raise ValueError(
8 changes: 8 additions & 0 deletions python/sglang/srt/model_loader/weight_utils.py
@@ -529,6 +529,14 @@ def filter_duplicate_safetensors_files(
# torch state_dict to safetensors file holding that weight.
index_file_name = os.path.join(hf_folder, index_file)
if not os.path.isfile(index_file_name):
# NOTE: trick for handling Mistral checkpoints:
# skip the unsupported consolidated.safetensors file
if len(hf_weights_files) == 2:
hf_weights_files.sort()
if hf_weights_files[0].endswith(
"consolidated.safetensors"
) and hf_weights_files[1].endswith("model.safetensors"):
return [hf_weights_files[1]]
return hf_weights_files

# Iterate through the weight_map (weight_name: safetensors files)
5 changes: 5 additions & 0 deletions python/sglang/srt/models/llama.py

replace(".weight_scale_inv", ".weight_scale")
Incompatible with llama-fp8(Fine-grained FP8):
Parameter model.layers.0.mlp.down_proj.weight_scale not found in params_dict
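
A minimal sketch of a guarded remap (a hypothetical helper, not part of this PR; `params_dict` is the dict built from `self.named_parameters()` in the hunk below): only rename a legacy scale suffix when the renamed parameter is actually registered, so fine-grained FP8 checkpoints that keep `.weight_scale_inv` as a real parameter are left untouched.

```python
def _maybe_remap_scale_name(name: str, params_dict: dict) -> str:
    # Hypothetical guard, not the PR's code: remap legacy FP8 scale names
    # only when the remapped parameter exists, so fine-grained FP8
    # checkpoints that register ".weight_scale_inv" keep their names.
    for old, new in (
        (".activation_scale", ".input_scale"),
        (".weight_scale_inv", ".weight_scale"),
    ):
        if name.endswith(old):
            candidate = name.replace(old, new)
            if candidate in params_dict and name not in params_dict:
                return candidate
    return name
```

With a helper like this, the unconditional `name.replace(...)` calls in `load_weights` would become `name = _maybe_remap_scale_name(name, params_dict)`, keeping the Ministral remap while avoiding the missing-parameter error above.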

@@ -570,6 +570,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())

for name, loaded_weight in weights:
if name.endswith(".activation_scale"):
name = name.replace(".activation_scale", ".input_scale")
if name.endswith(".weight_scale_inv"):
name = name.replace(".weight_scale_inv", ".weight_scale")

layer_id = get_layer_id(name)
if (
layer_id is not None
157 changes: 157 additions & 0 deletions python/sglang/srt/models/ministral3.py
@@ -0,0 +1,157 @@
from typing import Any, Dict, Optional

import torch
from transformers import PretrainedConfig

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
)
from sglang.srt.utils import add_prefix, make_layers


def _get_llama_4_attn_scale(
positions_ids: torch.Tensor, beta: float, max_position_embeddings: int
) -> torch.Tensor:
scaling = 1 + beta * torch.log(
1 + torch.floor(positions_ids / max_position_embeddings)
)
return scaling.unsqueeze(-1)


class Ministral3Attention(LlamaAttention):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
rope_theta: float = 1000000.0,
rope_scaling: Optional[Dict[str, Any]] = {},
rope_is_neox_style: bool = True,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
bias: bool = False,
) -> None:
super().__init__(
config,
hidden_size,
num_heads,
num_kv_heads,
layer_id,
rope_theta,
rope_scaling,
rope_is_neox_style,
max_position_embeddings,
quant_config,
prefix,
bias,
)
# Ministral3 specific: llama 4 style scaling beta
self.llama_4_scaling_beta = None
if hasattr(config, "rope_parameters") and config.rope_parameters:
self.llama_4_scaling_beta = config.rope_parameters.get(
"llama_4_scaling_beta"
)

# sliding window
self.sliding_window = getattr(config, "sliding_window", None)
if self.sliding_window is not None:
# Update RadixAttention with sliding window if needed;
# currently SGLang's RadixAttention handles this in its forward / FlashInfer logic.
pass

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

# Apply RoPE
q, k = self.rotary_emb(positions, q, k)

# Ministral3 / Llama 4 scaling
if self.llama_4_scaling_beta is not None:
scale = _get_llama_4_attn_scale(
positions, self.llama_4_scaling_beta, self.max_position_embeddings
).to(q.dtype)
# q shape is [batch_size * seq_len, num_heads * head_dim] or [batch_size * seq_len, num_heads, head_dim]
# positions is [batch_size * seq_len]
# scale is [batch_size * seq_len, 1]
# We need to reshape q to apply scale correctly if it's flattened
# Assuming q is (total_tokens, num_heads * head_dim)
q = q.view(-1, self.num_heads, self.head_dim)
q = q * scale.unsqueeze(1) # Broadcast over heads
q = q.view(-1, self.num_heads * self.head_dim)

attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output


class Ministral3DecoderLayer(LlamaDecoderLayer):
def __init__(self, config, layer_id=0, quant_config=None, prefix=""):
super().__init__(config, layer_id, quant_config, prefix)
self.self_attn = Ministral3Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=getattr(config, "rope_parameters", {}).get(
"rope_theta", 1000000.0
),
rope_scaling=getattr(
config, "rope_parameters", {}
), # rope_scaling is rope_parameters in Ministral3Config
max_position_embeddings=getattr(
config, "original_max_position_embeddings", 16384
),
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
bias=getattr(config, "attention_bias", False)
or getattr(config, "bias", False),
)


class Ministral3Model(LlamaModel):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
# Override layer creation to use Ministral3Attention
super().__init__(config, quant_config, prefix)

self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
lambda idx, prefix: Ministral3DecoderLayer(
config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix="model.layers",
)


class Ministral3ForCausalLM(LlamaForCausalLM):
def _init_model(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
return Ministral3Model(config, quant_config, prefix=prefix)


EntryClass = [Ministral3ForCausalLM]
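
For reference, a small standalone check (not part of the PR) of what `_get_llama_4_attn_scale` above produces: the scale stays at 1.0 until a position reaches `max_position_embeddings` and then grows logarithmically with each additional wrap. The `beta` and `max_position_embeddings` values below are illustrative only.

```python
import torch

# Reproduce the scaling from _get_llama_4_attn_scale for a few positions,
# assuming beta = 0.1 and max_position_embeddings = 16384 (illustrative).
beta, max_pos = 0.1, 16384
positions = torch.tensor([0, 16383, 16384, 65536], dtype=torch.float32)
scale = 1 + beta * torch.log(1 + torch.floor(positions / max_pos))
print(scale)  # tensor([1.0000, 1.0000, 1.0693, 1.1609])
```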
2 changes: 2 additions & 0 deletions python/sglang/srt/models/pixtral.py
@@ -514,6 +514,8 @@ def __init__(
dropout=0.0,
use_context_forward=False,
flatten_batch=False,
qkv_bias=False,
proj_bias=False,
prefix=f"{prefix}.attention",
)

3 changes: 2 additions & 1 deletion python/sglang/srt/multimodal/processors/pixtral.py
@@ -56,7 +56,8 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
self.patch_size = self.vision_config.patch_size

self._processor.patch_size = self.patch_size
self._processor.spatial_merge_size = self.vision_config.spatial_merge_size
if hasattr(self.vision_config, "spatial_merge_size"):
self._processor.spatial_merge_size = self.vision_config.spatial_merge_size

self.mm_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token,
31 changes: 31 additions & 0 deletions test/srt/models/test_ministral3_models.py
@@ -0,0 +1,31 @@
import unittest
from types import SimpleNamespace

from sglang.test.gsm8k_mixin import GSM8KMixin
from sglang.test.mmmu_vlm_mixin import MMMUVLMMixin
from sglang.test.test_utils import CustomTestCase

MODEL = "mistralai/Ministral-3-3B-Instruct-2512"


class TestMinistral3TextOnly(GSM8KMixin, CustomTestCase):
accuracy = 0.6
model = MODEL
other_args = ["--trust-remote-code"]


class TestMinistral3MMMU(MMMUVLMMixin, CustomTestCase):
accuracy = 0.3
model = MODEL
other_args = ["--trust-remote-code"]
mmmu_args = ["--limit=0.1"]
"""`--limit=0.1`: 10 percent of each task - this is fine for testing since the nominal result isn't interesting - this run is just to prevent relative regressions."""

def test_vlm_mmmu_benchmark(self):
self._run_vlm_mmmu_test(
SimpleNamespace(model=self.model, mmmu_accuracy=self.accuracy), "./logs"
)


if __name__ == "__main__":
unittest.main()