ministral3 #14251
Merged: Fridge003 merged 13 commits from JustinTong0323:add-new-model-ministral-3 into sgl-project:main on Dec 4, 2025.
Changes from all 13 commits:
e6c9514  perpare for ministral 3 (JustinTong0323)
1a14eb0  fix rope config (JustinTong0323)
793ccb4  remap fp8 weights (JustinTong0323)
ad2c34e  lint (JustinTong0323)
d4e5f29  fix ministral fp8 vision model (yueming-yuan)
14799d1  lint (yueming-yuan)
40e3baf  add unit test (yueming-yuan)
7f884ec  Merge branch 'main' into add-new-model-ministral-3 (JustinTong0323)
93a13ed  remove submodule FlashMLA (JustinTong0323)
fc69f4a  lint (JustinTong0323)
1e8fd64  update (JustinTong0323)
bc5ab15  Merge branch 'main' into add-new-model-ministral-3 (JustinTong0323)
44beef7  fix processor conflict (JustinTong0323)
New file (157 lines added): the Ministral3 model implementation, built on the existing SGLang Llama modules.
```python
from typing import Any, Dict, Optional

import torch
from transformers import PretrainedConfig

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
)
from sglang.srt.utils import add_prefix, make_layers


def _get_llama_4_attn_scale(
    positions_ids: torch.Tensor, beta: float, max_position_embeddings: int
) -> torch.Tensor:
    scaling = 1 + beta * torch.log(
        1 + torch.floor(positions_ids / max_position_embeddings)
    )
    return scaling.unsqueeze(-1)


class Ministral3Attention(LlamaAttention):
    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 1000000.0,
        rope_scaling: Optional[Dict[str, Any]] = {},
        rope_is_neox_style: bool = True,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        bias: bool = False,
    ) -> None:
        super().__init__(
            config,
            hidden_size,
            num_heads,
            num_kv_heads,
            layer_id,
            rope_theta,
            rope_scaling,
            rope_is_neox_style,
            max_position_embeddings,
            quant_config,
            prefix,
            bias,
        )
        # Ministral3 specific: llama 4 style scaling beta
        self.llama_4_scaling_beta = None
        if hasattr(config, "rope_parameters") and config.rope_parameters:
            self.llama_4_scaling_beta = config.rope_parameters.get(
                "llama_4_scaling_beta"
            )

        # sliding window
        self.sliding_window = getattr(config, "sliding_window", None)
        if self.sliding_window is not None:
            # Update RadixAttention with sliding window if needed
            # currently RadixAttention in sglang handles this mostly via logic in forward/flashinfer
            pass

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Apply RoPE
        q, k = self.rotary_emb(positions, q, k)

        # Ministral3 / Llama 4 scaling
        if self.llama_4_scaling_beta is not None:
            scale = _get_llama_4_attn_scale(
                positions, self.llama_4_scaling_beta, self.max_position_embeddings
            ).to(q.dtype)
            # q shape is [batch_size * seq_len, num_heads * head_dim] or [batch_size * seq_len, num_heads, head_dim]
            # positions is [batch_size * seq_len]
            # scale is [batch_size * seq_len, 1]
            # We need to reshape q to apply scale correctly if it's flattened
            # Assuming q is (total_tokens, num_heads * head_dim)
            q = q.view(-1, self.num_heads, self.head_dim)
            q = q * scale.unsqueeze(1)  # Broadcast over heads
            q = q.view(-1, self.num_heads * self.head_dim)

        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class Ministral3DecoderLayer(LlamaDecoderLayer):
    def __init__(self, config, layer_id=0, quant_config=None, prefix=""):
        super().__init__(config, layer_id, quant_config, prefix)
        self.self_attn = Ministral3Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=getattr(config, "rope_parameters", {}).get(
                "rope_theta", 1000000.0
            ),
            rope_scaling=getattr(
                config, "rope_parameters", {}
            ),  # rope_scaling is rope_parameters in Ministral3Config
            max_position_embeddings=getattr(
                config, "original_max_position_embeddings", 16384
            ),
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
            bias=getattr(config, "attention_bias", False)
            or getattr(config, "bias", False),
        )


class Ministral3Model(LlamaModel):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        # Override layer creation to use Ministral3Attention
        super().__init__(config, quant_config, prefix)

        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Ministral3DecoderLayer(
                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix="model.layers",
        )


class Ministral3ForCausalLM(LlamaForCausalLM):
    def _init_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        return Ministral3Model(config, quant_config, prefix=prefix)


EntryClass = [Ministral3ForCausalLM]
```
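For intuition, the snippet below is a standalone sketch (not part of the PR) of the Llama-4-style query scaling computed by `_get_llama_4_attn_scale`: the scale stays at 1.0 inside the original context window and grows logarithmically once positions exceed `max_position_embeddings`. The `beta` of 0.1 and the 16384-token window are hypothetical values chosen only for illustration; the real values come from `config.rope_parameters` in the checkpoint.

```python
# Standalone illustration of the Llama-4-style attention scaling used above.
# beta and max_pos are made-up example values, not the model's real config.
import torch


def llama_4_attn_scale(positions: torch.Tensor, beta: float, max_pos: int) -> torch.Tensor:
    # scale = 1 + beta * log(1 + floor(position / max_position_embeddings))
    return (1 + beta * torch.log(1 + torch.floor(positions / max_pos))).unsqueeze(-1)


positions = torch.tensor([0, 8_000, 16_383, 16_384, 65_536], dtype=torch.float32)
print(llama_4_attn_scale(positions, beta=0.1, max_pos=16_384).squeeze(-1))
# Inside the original window the scale is exactly 1.0; past it, queries are
# scaled up logarithmically, e.g. position 16384 gets 1 + 0.1 * ln(2) ≈ 1.069.
```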
New file (31 lines added): unit tests for Ministral3 covering a GSM8K text-only run and an MMMU VLM run.
```python
import unittest
from types import SimpleNamespace

from sglang.test.gsm8k_mixin import GSM8KMixin
from sglang.test.mmmu_vlm_mixin import MMMUVLMMixin
from sglang.test.test_utils import CustomTestCase

MODEL = "mistralai/Ministral-3-3B-Instruct-2512"


class TestMinistral3TextOnly(GSM8KMixin, CustomTestCase):
    accuracy = 0.6
    model = MODEL
    other_args = ["--trust-remote-code"]


class TestMinistral3MMMU(MMMUVLMMixin, CustomTestCase):
    accuracy = 0.3
    model = MODEL
    other_args = ["--trust-remote-code"]
    mmmu_args = ["--limit=0.1"]
    """`--limit=0.1`: 10 percent of each task - this is fine for testing since the nominal result isn't interesting - this run is just to prevent relative regressions."""

    def test_vlm_mmmu_benchmark(self):
        self._run_vlm_mmmu_test(
            SimpleNamespace(model=self.model, mmmu_accuracy=self.accuracy), "./logs"
        )


if __name__ == "__main__":
    unittest.main()
```
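Outside of these CI mixins, a quick local smoke test of the new model could go through SGLang's offline engine. This is a minimal sketch, not taken from the PR; it assumes `sgl.Engine` accepts standard server arguments such as `trust_remote_code` as keyword arguments and that the checkpoint above is reachable, and the sampling parameters are arbitrary.

```python
# Local smoke test for the new model (sketch; not part of the PR).
# Assumes SGLang's offline Engine API; adjust model_path and params as needed.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="mistralai/Ministral-3-3B-Instruct-2512",
        trust_remote_code=True,
    )
    outputs = llm.generate(
        ["What is 13 * 7?"],
        {"temperature": 0.0, "max_new_tokens": 32},
    )
    print(outputs[0]["text"])
    llm.shutdown()
```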
Review comment: the `replace(".weight_scale_inv", ".weight_scale")` remap is incompatible with llama-fp8 (fine-grained FP8); loading fails with `Parameter model.layers.0.mlp.down_proj.weight_scale not found in params_dict`.
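To illustrate the reviewer's point, the sketch below is a hypothetical reduction of what a blanket suffix remap does during weight loading; it is not SGLang's actual loader code. Fine-grained (block-wise) FP8 checkpoints ship inverse scales named `*.weight_scale_inv`, and the error message above indicates the registered parameters keep that name, so renaming every incoming tensor to `*.weight_scale` makes the loader look up a key that only exists for per-tensor FP8 layouts.

```python
# Hypothetical illustration of the reviewer's concern (not the real loader).
# params_dict mimics the registered parameters of one fine-grained FP8 linear
# layer, where the scale parameter is named weight_scale_inv.
params_dict = {
    "model.layers.0.mlp.down_proj.weight": "...",
    "model.layers.0.mlp.down_proj.weight_scale_inv": "...",
}

checkpoint_names = [
    "model.layers.0.mlp.down_proj.weight",
    "model.layers.0.mlp.down_proj.weight_scale_inv",
]

for name in checkpoint_names:
    # The blanket remap quoted in the review comment:
    remapped = name.replace(".weight_scale_inv", ".weight_scale")
    if remapped not in params_dict:
        # Reproduces the reported failure for block-wise FP8 checkpoints:
        # "Parameter model.layers.0.mlp.down_proj.weight_scale not found in params_dict"
        raise KeyError(f"Parameter {remapped} not found in params_dict")
```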