@@ -23,7 +23,6 @@
    direct_register_custom_op,
    get_bool_env_var,
    get_device_name,
    is_cuda_available,
    is_hip,
)

156 changes: 88 additions & 68 deletions python/sglang/srt/layers/quantization/__init__.py
@@ -1,15 +1,14 @@
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
import builtins
import inspect
import re
from copy import deepcopy
from typing import Callable, Dict, Optional, Type, Union

import torch
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import (
    AWQMarlinConfig,
    AWQMoEMethod,
)
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
@@ -180,96 +179,117 @@ def gptq_get_quant_method(self, layer, prefix):
    return None


def awq_get_quant_method(self, layer, prefix):
    from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinLinearMethod,
        AWQMoEMethod,
    )
original_isinstance = builtins.isinstance

    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
    from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

    if isinstance(layer, LinearBase) or (
        isinstance(layer, ParallelLMHead) and self.lm_head_quantized
    ):
        if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
            return UnquantizedLinearMethod()
        return AWQMarlinLinearMethod(self)
    elif isinstance(layer, FusedMoE):
        return AWQMoEMethod(self)
    return None
def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
    """
    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
    can recognize sglang layers
    """

    if reverse:
        builtins.isinstance = original_isinstance
        return

original_awq_moe_method_apply = AWQMoEMethod.apply


def awq_moe_method_apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    **kwargs,
):
    return original_awq_moe_method_apply(
        self,
        layer,
        x,
        router_logits,
        top_k,
        renormalize,
        use_grouped_topk,
        topk_group,
        num_expert_group,
        custom_routing_function,
        scoring_func,
        e_score_correction_bias,
    )


def patch_vllm_linear_base_isinstance():
    import builtins

    from vllm.model_executor.layers.fused_moe import FusedMoE
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding,
    )

    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase

    original_isinstance = builtins.isinstance
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
    from sglang.srt.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
    )

    def patched_isinstance(obj, classinfo):
        if classinfo is LinearBase:
            return original_isinstance(obj, PatchedLinearBase)
        if classinfo is FusedMoE:
            return original_isinstance(obj, PatchedFusedMoE)
        if classinfo is VocabParallelEmbedding:
            return original_isinstance(obj, PatchedVocabParallelEmbedding)
        return original_isinstance(obj, classinfo)

    builtins.isinstance = patched_isinstance

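For context on the trick above: monkey_patch_isinstance_for_vllm_base_layer swaps builtins.isinstance process-wide for a shim that answers vllm's base-class checks against the sglang counterparts, and restores the saved original when called with reverse=True. A minimal self-contained sketch of the same pattern, using stand-in classes rather than the real vllm/sglang types:

import builtins

class VendorBase:  # stand-in for a vllm base class such as LinearBase
    pass

class LocalImpl:  # stand-in for the corresponding sglang layer class
    pass

_original_isinstance = builtins.isinstance

def patch_isinstance(reverse: bool = False):
    if reverse:
        builtins.isinstance = _original_isinstance  # restore the builtin
        return

    def patched_isinstance(obj, classinfo):
        # Answer checks against the vendor class using the local class.
        if classinfo is VendorBase:
            return _original_isinstance(obj, LocalImpl)
        return _original_isinstance(obj, classinfo)

    builtins.isinstance = patched_isinstance

patch_isinstance()
assert isinstance(LocalImpl(), VendorBase)  # local objects now pass vendor checks
patch_isinstance(reverse=True)
assert not isinstance(LocalImpl(), VendorBase)  # builtin behavior restored

Because the override is global to the process, every patch call needs a matching reverse call; the model_runner.py change below brackets model loading exactly that way.
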
def apply_monkey_patches():
def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
    """
    Monkey patch the apply function of vllm's FusedMoEMethodBase.
    Convert sglang arguments to vllm arguments.
    """
    original_apply = class_obj.apply
    sig = inspect.signature(original_apply)
    param_names = list(sig.parameters.keys())
    has_correction_bias = "e_score_correction_bias" in param_names

    def new_apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
    ):
        assert activation == "silu"
        assert inplace and not no_combine

        kwargs = {
            "self": self,
            "layer": layer,
            "x": x,
            "router_logits": router_logits,
            "top_k": top_k,
            "renormalize": renormalize,
            "use_grouped_topk": use_grouped_topk,
            "topk_group": topk_group,
            "num_expert_group": num_expert_group,
            "custom_routing_function": custom_routing_function,
        }
        if correction_bias is not None:
            if not has_correction_bias:
                raise ValueError(
                    "The installed vllm version does not support "
                    "`e_score_correction_bias`. Please upgrade vllm, "
                    "e.g. `pip install vllm==0.7.2`."
                )
            kwargs["e_score_correction_bias"] = correction_bias
        return original_apply(**kwargs)

    setattr(class_obj, "apply", new_apply)

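The inspect.signature probe in monkey_patch_moe_apply is a version-compatibility check: e_score_correction_bias is only accepted by newer vllm releases, so the wrapper forwards it solely when the wrapped apply declares that parameter. A stripped-down sketch of the feature-detection pattern (old_fn and new_fn are illustrative placeholders, not vllm APIs):

import inspect
from typing import Optional

def old_fn(x: float, top_k: int) -> float:  # signature before the parameter existed
    return x

def new_fn(x: float, top_k: int, e_score_correction_bias: Optional[float] = None) -> float:
    return x + (e_score_correction_bias or 0.0)

def make_adapter(fn):
    # Probe the wrapped callable once, at patch time, not per call.
    accepts_bias = "e_score_correction_bias" in inspect.signature(fn).parameters

    def adapter(x: float, top_k: int, correction_bias: Optional[float] = None) -> float:
        kwargs = {"x": x, "top_k": top_k}
        if correction_bias is not None:
            if not accepts_bias:
                raise ValueError("wrapped function predates e_score_correction_bias")
            kwargs["e_score_correction_bias"] = correction_bias
        return fn(**kwargs)

    return adapter

assert make_adapter(new_fn)(1.0, 2, correction_bias=0.5) == 1.5
assert make_adapter(old_fn)(1.0, 2) == 1.0  # old signature still works when no bias is passed
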

def monkey_patch_quant_configs():
    """Apply all monkey patches in one place."""
    from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
        CompressedTensorsW8A8Fp8MoEMethod,
        CompressedTensorsWNA16MoEMethod,
    )
    from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinMoEMethod

    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
    setattr(AWQMoEMethod, "apply", awq_moe_method_apply)

    monkey_patch_moe_apply(AWQMoEMethod)
    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


patch_vllm_linear_base_isinstance()
# Apply patches when module is imported
apply_monkey_patches()
monkey_patch_quant_configs()


__all__ = [
    "QuantizationConfig",
    "get_quantization_config",
    "QUANTIZATION_METHODS",
]
4 changes: 4 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -41,6 +41,7 @@
    initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
@@ -341,13 +342,16 @@ def load_model(self):
        # Load the model
        # Remove these monkey patches once linear.py's quantization no longer depends on vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()

        with self.memory_saver_adapter.region():
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
            )
        monkey_patch_vllm_parallel_state(reverse=True)
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
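Note how both monkey patches are applied immediately before get_model and reverted right after it, so the process-global overrides only live for the duration of model construction. If get_model raises, however, the reverse calls are skipped. A try/finally context manager would guarantee restoration; a possible sketch (an alternative, not part of this change):

from contextlib import contextmanager

from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer

@contextmanager
def isinstance_patched_for_vllm():
    monkey_patch_isinstance_for_vllm_base_layer()
    try:
        yield
    finally:
        # Restore builtins.isinstance even if model loading fails.
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
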
5 changes: 4 additions & 1 deletion python/sglang/test/test_utils.py
@@ -36,12 +36,15 @@
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_REASONING_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST = (
    "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
)
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4,hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_SMALL_VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B"

1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -22,6 +22,7 @@ class TestFile:
        TestFile("models/test_reward_models.py", 83),
        TestFile("models/test_gme_qwen_models.py", 45),
        TestFile("test_abort.py", 51),
        TestFile("test_awq.py"),
        TestFile("test_block_int8.py", 22),
        TestFile("test_chunked_prefill.py", 336),
        TestFile("test_eagle_infer.py", 447),
44 changes: 44 additions & 0 deletions test/srt/test_awq.py
@@ -0,0 +1,44 @@
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestAWQ(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--trust-remote-code"],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.65)


if __name__ == "__main__":
    unittest.main()
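Because the file ends in unittest.main(), the test also runs standalone via python3 test/srt/test_awq.py: it launches a local server for the AWQ Mixtral checkpoint with --trust-remote-code and scores a 64-example MMLU subset against the 0.65 threshold.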
1 change: 1 addition & 0 deletions test/srt/test_nightly_gsm8k_eval.py
@@ -38,6 +38,7 @@
    "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.82,
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.84,
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.83,
    "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4": 0.60,
}

