94 changes: 89 additions & 5 deletions python/sglang/srt/models/glm4_moe.py
@@ -75,6 +75,11 @@
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_moe import compute_yarn_parameters
from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
add_prefix,
@@ -97,6 +102,9 @@
_is_cpu = is_cpu()
_device_sm = get_device_sm()

if _is_cuda:
from sgl_kernel import fused_qk_norm_rope

logger = logging.getLogger(__name__)


@@ -176,6 +184,7 @@ def __init__(
use_qk_norm: bool = False,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
config: Optional[PretrainedConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -204,6 +213,7 @@ def __init__(
self.use_qk_norm = use_qk_norm
self.max_position_embeddings = max_position_embeddings
self.tp_rank = get_tensor_model_parallel_rank()
self.config = config

self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -249,6 +259,65 @@ def __init__(
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.alt_stream = alt_stream
self.use_fused_qk_norm_rope = (
get_global_server_args().enable_fused_qk_norm_rope
and self.use_qk_norm
and self.head_dim in (64, 128, 256)
)
self._used_fused_qk_norm_rope_last_call = False
self.compatible_with_fused_kv_buffer = True

if self.use_fused_qk_norm_rope:
self.factor, self.low, self.high, self.attention_factor = (
compute_yarn_parameters(self.config)
)

def apply_qk_norm_rope(self, qkv, positions, forward_batch):
use_fused = self.use_fused_qk_norm_rope and qkv.dtype == torch.bfloat16
if use_fused:
positions = (
positions.view(-1).to(dtype=torch.int32, device=qkv.device).contiguous()
)
fused_qk_norm_rope(
qkv,
self.num_heads,
self.num_kv_heads,
self.num_kv_heads,
self.head_dim,
self.q_norm.variance_epsilon,
self.q_norm.weight,
self.k_norm.weight,
self.rope_theta,
self.rotary_emb.is_neox_style,
positions,
self.factor,
self.low,
self.high,
self.attention_factor,
self.rotary_emb.rotary_dim,
)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
self._used_fused_qk_norm_rope_last_call = True
else:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(
positions,
q,
k,
fused_set_kv_buffer_arg=(
create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if enable_fused_set_kv_buffer(forward_batch)
and self.compatible_with_fused_kv_buffer
else None
),
)
self._used_fused_qk_norm_rope_last_call = False
return q, k, v

def _apply_qk_norm(
self, q: torch.Tensor, k: torch.Tensor
@@ -293,18 +362,32 @@ def forward_prepare(
if hidden_states.shape[0] == 0:
return hidden_states, forward_batch, None
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)

q, k, v = self.apply_qk_norm_rope(qkv, positions, forward_batch)

inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state

def forward_core(self, intermediate_state):
hidden_states, forward_batch, inner_state = intermediate_state
if inner_state is None:
return hidden_states
attn_output = self.attn(*inner_state)

q, k, v, fb = inner_state

must_save_kv = self._used_fused_qk_norm_rope_last_call
save_kv_cache = must_save_kv or not (
enable_fused_set_kv_buffer(forward_batch)
and self.compatible_with_fused_kv_buffer
)

attn_output = self.attn(
q,
k,
v,
fb,
save_kv_cache=save_kv_cache,
)
output, _ = self.o_proj(attn_output)
return output

@@ -710,6 +793,7 @@ def __init__(
prefix=add_prefix("self_attn", prefix),
use_qk_norm=config.use_qk_norm,
alt_stream=alt_stream,
config=config,
)

self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
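For context, the `else` branch of `apply_qk_norm_rope` above is the unfused fallback: per-head RMSNorm on q and k (the `q_norm`/`k_norm` weights are `head_dim`-sized) followed by the rotary embedding. Below is a minimal eager-mode sketch of that path; the helper names and the per-head normalization layout are illustrative assumptions, not code from this PR.

```python
# Hedged sketch of the unfused fallback path (illustrative only).
import torch

def per_head_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # x: [num_tokens, num_heads * head_dim]; normalize each head over head_dim.
    head_dim = weight.numel()
    xh = x.reshape(-1, head_dim).float()
    xh = xh * torch.rsqrt(xh.pow(2).mean(-1, keepdim=True) + eps)
    return (xh.to(x.dtype) * weight).view(x.shape)

def unfused_qk_norm_rope(qkv, q_size, kv_size, q_weight, k_weight, eps, rotary_emb, positions):
    # Split the packed projection, normalize q/k per head, then rotate.
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    q = per_head_rms_norm(q, q_weight, eps)
    k = per_head_rms_norm(k, k_weight, eps)
    q, k = rotary_emb(positions, q, k)
    return q, k, v
```

The fused kernel is expected to produce the same q/k in place inside the packed `qkv` buffer without going through `rotary_emb`, which is why `forward_core` forces `save_kv_cache=True` whenever the fused path ran: the fused set-KV-buffer write that `rotary_emb` would otherwise perform never happened.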
3 changes: 2 additions & 1 deletion sgl-kernel/csrc/common_extension.cc
@@ -278,7 +278,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
"int num_heads_k, int num_heads_v, int head_dim, float eps, "
"Tensor q_weight, Tensor k_weight, float base, "
"bool is_neox, Tensor position_ids, float factor, float low, float high, float attention_factor) -> ()");
"bool is_neox, Tensor position_ids, float factor, float low, float high, float attention_factor, int rotary_dim) "
"-> ()");
m.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);

/*
79 changes: 48 additions & 31 deletions sgl-kernel/csrc/moe/fused_qknorm_rope_kernel.cu
@@ -120,8 +120,8 @@ __global__ void fusedQKNormRopeKernel(
float factor, // factor in rope_scaling in config.json. When it is not 1.0, it means the model is using yarn.
float low, // threshold for high frequency
float high, // threshold for low frequency
float attention_factor // attention_factor applied on cos and sin
) {
float attention_factor, // attention_factor applied on cos and sin
int const rotary_dim) {
int const warpsPerBlock = blockDim.x / 32;
int const warpId = threadIdx.x / 32;
int const laneId = threadIdx.x % 32;
@@ -195,43 +195,54 @@
float cos_vals[numElemsPerThread];
float sin_vals[numElemsPerThread];
float pos_id = static_cast<float>(position_ids[tokenIdx]);
int const rotary_lanes = rotary_dim / numElemsPerThread;  // number of lanes that cover the rotary dims
bool const in_rotary = (laneId < rotary_lanes);

if constexpr (interleave) {
// Perform interleaving. Fill cos_vals and sin_vals.
for (int i = 0; i < numElemsPerThread; i++) {
if (i % 2 == 0) {
elements2[i] = -elements[i + 1];
} else {
elements2[i] = elements[i - 1];
if (in_rotary) {
for (int i = 0; i < numElemsPerThread; i++) {
elements2[i] = (i % 2 == 0) ? -elements[i + 1] : elements[i - 1];

int dim_idx = laneId * numElemsPerThread + i;
int half_dim = dim_idx / 2;
float freq = compute_freq_yarn(base, rotary_dim, half_dim, factor, low, high);
float theta = pos_id * freq;
__sincosf(theta, &sin_vals[i], &cos_vals[i]);
}
int dim_idx = laneId * numElemsPerThread + i;
int half_dim = dim_idx / 2;
float freq = compute_freq_yarn(base, head_dim, half_dim, factor, low, high);
float theta = pos_id * freq;
__sincosf(theta, &sin_vals[i], &cos_vals[i]);
}

} else {
// Neox style
// Before exchanging data within the warp, we need to sync.
__syncwarp();
// Get the data from the other half of the warp. Fill cos_vals and sin_vals.
for (int i = 0; i < numElemsPerThread; i++) {
elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
if (laneId < 16) {
elements2[i] = -elements2[i];
if (in_rotary) {
int const half_rotary_lanes = rotary_lanes / 2;
unsigned int active_mask = (1u << rotary_lanes) - 1;  // lowest rotary_lanes bits set to 1 (0x0000FFFF when rotary_lanes == 16)
for (int i = 0; i < numElemsPerThread; i++) {
elements2[i] = __shfl_xor_sync(active_mask, elements[i], half_rotary_lanes);
if (laneId < half_rotary_lanes) {
elements2[i] = -elements2[i];
}

int dim_idx = laneId * numElemsPerThread + i;
dim_idx = (dim_idx * 2) % rotary_dim;
int half_dim = dim_idx / 2;
float freq = compute_freq_yarn(base, rotary_dim, half_dim, factor, low, high);
float theta = pos_id * freq;
__sincosf(theta, &sin_vals[i], &cos_vals[i]);
}
int dim_idx = laneId * numElemsPerThread + i;
dim_idx = (dim_idx * 2) % head_dim;
int half_dim = dim_idx / 2;
float freq = compute_freq_yarn(base, head_dim, half_dim, factor, low, high);
float theta = pos_id * freq;
__sincosf(theta, &sin_vals[i], &cos_vals[i]);
}
// __shfl_xor_sync does not provide memfence. Need to sync again.
__syncwarp();
}
for (int i = 0; i < numElemsPerThread; i++) {
elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;

if (in_rotary) {
for (int i = 0; i < numElemsPerThread; i++) {
elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
}
}

// Store.
{
vec_T vec;
@@ -270,6 +281,7 @@ void launchFusedQKNormRope(
float low,
float high,
float attention_factor,
int const rotary_dim,
cudaStream_t stream) {
constexpr int blockSize = 256;
int const warpsPerBlock = blockSize / 32;
@@ -297,7 +309,8 @@
factor,
low,
high,
attention_factor);
attention_factor,
rotary_dim);
});
break;
case 128:
@@ -316,7 +329,8 @@
factor,
low,
high,
attention_factor);
attention_factor,
rotary_dim);
});
break;
case 256:
@@ -335,7 +349,8 @@
factor,
low,
high,
attention_factor);
attention_factor,
rotary_dim);
});
break;
default:
@@ -363,16 +378,17 @@ void fused_qk_norm_rope(
double factor, // factor in rope_scaling in config.json. When it is not 1.0, it means the model is using yarn.
double low, // threshold for high frequency
double high, // threshold for low frequency
double attention_factor // attention_factor applied on cos and sin
) {
double attention_factor, // attention_factor applied on cos and sin
int64_t rotary_dim) {
// Input validation
TORCH_CHECK(qkv.dim() == 2, "QKV tensor must be 2D: [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim]");
TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]");
TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
TORCH_CHECK(q_weight.size(0) == head_dim, "Query weights size must match head dimension");
TORCH_CHECK(k_weight.size(0) == head_dim, "Key weights size must match head dimension");

TORCH_CHECK(rotary_dim % (head_dim / 32) == 0, "rotary_dim must be divisible by numElemsPerThread");
TORCH_CHECK((rotary_dim / (head_dim / 32)) % 2 == 0, "rotary_lanes must be even for neox style");
CHECK_INPUT(qkv, torch::kBFloat16);
CHECK_INPUT(position_ids, torch::kInt32);
CHECK_INPUT(q_weight, torch::kBFloat16);
Expand Down Expand Up @@ -404,5 +420,6 @@ void fused_qk_norm_rope(
static_cast<float>(low),
static_cast<float>(high),
static_cast<float>(attention_factor),
static_cast<int>(rotary_dim),
stream);
}
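For reference, here is a hedged eager-mode sketch of what the rotary part computes per head now that `rotary_dim < head_dim` is supported: only the leading `rotary_dim` dimensions are rotated (the `in_rotary` lane mask) and the tail passes through unchanged. It assumes `compute_freq_yarn` follows the usual YaRN linear-ramp blend between the original and `factor`-scaled inverse frequencies; the helper names and the exact blend are assumptions, not the kernel source.

```python
# Hedged reference for neox-style partial RoPE with YaRN-blended frequencies.
import torch

def yarn_inv_freq(base: float, rotary_dim: int, factor: float,
                  low: float, high: float) -> torch.Tensor:
    j = torch.arange(rotary_dim // 2, dtype=torch.float32)
    inv_freq = base ** (-2.0 * j / rotary_dim)              # unscaled (extrapolation)
    ramp = ((j - low) / max(high - low, 1e-3)).clamp(0.0, 1.0)
    # ramp == 0: high-frequency dims keep the original frequency;
    # ramp == 1: low-frequency dims are interpolated by `factor`.
    return inv_freq * (1.0 - ramp) + (inv_freq / factor) * ramp

def neox_partial_rope(x: torch.Tensor, positions: torch.Tensor,
                      inv_freq: torch.Tensor, attention_factor: float) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_dim]; only the first rotary_dim dims rotate.
    rotary_dim = inv_freq.numel() * 2
    x_rot, x_pass = x[..., :rotary_dim].float(), x[..., rotary_dim:]
    theta = positions[:, None].float() * inv_freq[None, :]  # [num_tokens, rotary_dim // 2]
    cos = (theta.cos() * attention_factor)[:, None, :]
    sin = (theta.sin() * attention_factor)[:, None, :]
    x1, x2 = x_rot[..., : rotary_dim // 2], x_rot[..., rotary_dim // 2 :]
    rotated = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return torch.cat([rotated.to(x.dtype), x_pass], dim=-1)
```

The two new `TORCH_CHECK`s encode the kernel's layout constraints: each of a warp's 32 lanes owns `head_dim / 32` contiguous elements, so `rotary_dim` must cover a whole number of lanes, and the neox path shuffles data across half of those lanes, so the lane count must be even.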
3 changes: 2 additions & 1 deletion sgl-kernel/include/sgl_kernel_ops.h
@@ -401,7 +401,8 @@ void fused_qk_norm_rope(
double factor,
double low,
double high,
double attention_factor);
double attention_factor,
int64_t rotary_dim);

void cutlass_fp4_group_mm(
torch::Tensor& output,
2 changes: 2 additions & 0 deletions sgl-kernel/python/sgl_kernel/moe.py
@@ -267,6 +267,7 @@ def fused_qk_norm_rope(
low: float,
high: float,
attention_factor: float,
rotary_dim: int,
) -> None:
torch.ops.sgl_kernel.fused_qk_norm_rope(
qkv,
@@ -284,6 +285,7 @@
low,
high,
attention_factor,
rotary_dim,
)


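A hedged usage sketch of the updated Python wrapper, following the argument order of the schema registered in `common_extension.cc`. The dtype and shape constraints (bfloat16 `qkv`, 1-D int32 `position_ids`, `head_dim`-sized norm weights) mirror the kernel's checks; the concrete sizes and the factor/low/high/attention_factor values below are purely illustrative.

```python
# Illustrative call only; qkv is normalized and rotated in place, nothing is returned.
import torch
from sgl_kernel import fused_qk_norm_rope

num_tokens, num_heads, num_kv_heads, head_dim = 8, 32, 4, 128
rotary_dim = 64  # partial rotary: only the first 64 dims of each head rotate

qkv = torch.randn(num_tokens, (num_heads + 2 * num_kv_heads) * head_dim,
                  dtype=torch.bfloat16, device="cuda")
positions = torch.arange(num_tokens, dtype=torch.int32, device="cuda")
q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")

fused_qk_norm_rope(
    qkv,
    num_heads, num_kv_heads, num_kv_heads,  # heads for q / k / v
    head_dim,
    1e-6,                                   # RMSNorm eps
    q_weight, k_weight,
    10000.0,                                # rope base
    True,                                   # neox style
    positions,
    1.0, 1.0, 32.0, 1.0,                    # factor / low / high / attention_factor (assumed)
    rotary_dim,
)

q, k, v = qkv.split([num_heads * head_dim,
                     num_kv_heads * head_dim,
                     num_kv_heads * head_dim], dim=-1)
```

Note that `rotary_dim` must satisfy the new checks: with `head_dim = 128` each lane holds 4 elements, so `rotary_dim = 64` spans 16 lanes, an even number.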