2 changes: 1 addition & 1 deletion python/sglang/srt/layers/activation.py
@@ -62,7 +62,7 @@
 class SiluAndMul(CustomOp):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        if get_global_server_args().rl_on_policy_target == "fsdp":
+        if get_global_server_args().rl_on_policy_target is not None:
             self._forward_method = self.forward_native

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:

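For context: routing `_forward_method` to `forward_native` makes the activation run as plain PyTorch ops instead of a fused kernel whenever an on-policy target is configured. A minimal sketch of that native path, assuming the usual gate/up split convention for `SiluAndMul` rather than copying the SGLang source:

```python
import torch
import torch.nn.functional as F

# Assumed semantics of the "native" path (not the SGLang source): plain PyTorch
# SiLU-and-multiply over the concatenated gate/up projection, so its rounding
# matches what a bf16 training framework computes.
def silu_and_mul_native(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

out = silu_and_mul_native(torch.randn(2, 8, dtype=torch.bfloat16))
```
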
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/logits_processor.py
@@ -824,7 +824,7 @@ def _get_logits(
                 None,  # bias
                 True,  # is_vnni
             )
-        elif get_global_server_args().rl_on_policy_target == "fsdp":
+        elif get_global_server_args().rl_on_policy_target is not None:
             # Due to tie-weight, we may not be able to change lm_head's weight dtype
             logits = torch.matmul(
                 hidden_states.bfloat16(), lm_head.weight.T.bfloat16()

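The tie-weight comment is why the cast happens per call: when `lm_head` shares its weight tensor with an fp32 embedding table (see the `qwen2.py` change below), converting the parameter in place would also convert the embeddings. A toy illustration of that constraint, using a plain `nn.Embedding` rather than SGLang's classes:

```python
import torch

# Toy illustration (not SGLang code): with tied weights, lm_head's weight *is*
# the embedding table, so the parameter stays float32 and only the matmul
# operands are cast to bfloat16.
embed_tokens = torch.nn.Embedding(8, 4, dtype=torch.float32)
lm_head_weight = embed_tokens.weight  # tied: the same tensor object

hidden_states = torch.randn(2, 4)
logits = torch.matmul(hidden_states.bfloat16(), lm_head_weight.T.bfloat16())

assert logits.dtype == torch.bfloat16
assert embed_tokens.weight.dtype == torch.float32  # embedding table untouched
```
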
6 changes: 3 additions & 3 deletions python/sglang/srt/layers/rotary_embedding.py
@@ -127,7 +127,7 @@ def __init__(

         self._apply_rotary_emb_wrapped = _apply_rotary_emb

-        if get_global_server_args().rl_on_policy_target == "fsdp":
+        if get_global_server_args().rl_on_policy_target is not None:
             self._forward_method = self.forward_native
             self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)(
                 self._apply_rotary_emb_wrapped
@@ -140,7 +140,7 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         # create the cache on GPU for faster initialization. This may cause
         # a slight numerical difference between the HF implementation and ours.
         init_device = (
-            "cpu" if get_global_server_args().rl_on_policy_target == "fsdp" else None
+            "cpu" if get_global_server_args().rl_on_policy_target is not None else None
         )
         inv_freq = 1.0 / (
             base
@@ -151,7 +151,7 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
                 / self.rotary_dim
             )
         )
-        if get_global_server_args().rl_on_policy_target == "fsdp":
+        if get_global_server_args().rl_on_policy_target is not None:
             inv_freq = inv_freq.cuda()
         return inv_freq

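The CPU-then-`.cuda()` dance exists because the pow/div in `_compute_inv_freq` can round slightly differently on GPU, and on-policy mode wants exact agreement with the HF reference. A rough sketch of the idea, with the arange exponent assumed from the standard RoPE formula rather than copied from this file:

```python
import torch

# Sketch under assumptions: standard RoPE inverse frequencies, built on CPU for
# exact agreement with the HF implementation, then moved to the accelerator.
def compute_inv_freq(base: float, rotary_dim: int, init_device=None) -> torch.Tensor:
    exponent = torch.arange(0, rotary_dim, 2, dtype=torch.float32, device=init_device)
    return 1.0 / (base ** (exponent / rotary_dim))

inv_freq = compute_inv_freq(10000.0, 128, init_device="cpu")
if torch.cuda.is_available():
    inv_freq = inv_freq.cuda()  # mirrors the post-hoc .cuda() in the diff
```
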
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/sampler.py
@@ -102,7 +102,7 @@ def forward(
         if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
             probs_without_temp_scaling = torch.softmax(logits, dim=-1)

-        if get_global_server_args().rl_on_policy_target == "fsdp":
+        if get_global_server_args().rl_on_policy_target is not None:
             logits_div_temperature = (
                 logits.bfloat16().div(sampling_info.temperatures).bfloat16()
             )
@@ -156,7 +156,7 @@ def forward(
             )

         if return_logprob:
-            if get_global_server_args().rl_on_policy_target == "fsdp":
+            if get_global_server_args().rl_on_policy_target is not None:
                 logprobs = logprobs_via_logsoftmax_kernel
                 del logprobs_via_logsoftmax_kernel
                 # clamp to avoid -inf

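The bf16 temperature division reproduces the rounding a bf16 training stack sees, and the later clamp guards against `-inf` log-probabilities. A simplified stand-in for that flow (the real code uses a dedicated log-softmax kernel, which is not sketched here):

```python
import torch

# Simplified stand-in, not the kernel from the diff: bf16 temperature scaling
# followed by log-softmax and a clamp that removes -inf entries.
logits = torch.randn(2, 8)
temperatures = torch.tensor([[0.7], [1.0]])

logits_div_temperature = logits.bfloat16().div(temperatures).bfloat16()
logprobs = torch.log_softmax(logits_div_temperature.float(), dim=-1)
logprobs = logprobs.clamp(min=torch.finfo(torch.float32).min)  # clamp to avoid -inf
```
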
6 changes: 3 additions & 3 deletions python/sglang/srt/models/qwen2.py
@@ -90,7 +90,7 @@ def __init__(
         self.act_fn = SiluAndMul()

     def forward(self, x):
-        if get_global_server_args().rl_on_policy_target == "fsdp":
+        if get_global_server_args().rl_on_policy_target is not None:
             x = x.bfloat16()

         gate_up, _ = self.gate_up_proj(x)
@@ -281,7 +281,7 @@ def __init__(
             prefix=add_prefix("embed_tokens", prefix),
             params_dtype=(
                 torch.float32
-                if get_global_server_args().rl_on_policy_target == "fsdp"
+                if get_global_server_args().rl_on_policy_target is not None
                 else None
             ),
         )
@@ -311,7 +311,7 @@ def __init__(
                 override_orig_dtype=torch.float32,
                 fp32_residual=True,
             )
-            if get_global_server_args().rl_on_policy_target == "fsdp"
+            if get_global_server_args().rl_on_policy_target is not None
             else {}
         )
         self.norm = RMSNorm(

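These conditional RMSNorm keyword arguments (here and in `qwen3.py` below, where they are passed as `**norm_kwargs`) request fp32 weights, fp32 residuals, and an fp32 working dtype whenever an on-policy target is set. The exact option semantics live in SGLang's `RMSNorm`; as a hedged reference, the generic fp32 RMSNorm computation they point at looks like:

```python
import torch

# Generic fp32 RMSNorm reference (assumed semantics of the kwargs above, not
# the SGLang implementation): normalize and scale in float32, then cast back.
def rms_norm_fp32(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    orig_dtype = x.dtype
    x32 = x.float()
    x32 = x32 * torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * weight.float()).to(orig_dtype)

y = rms_norm_fp32(torch.randn(2, 16, dtype=torch.bfloat16), torch.ones(16))
```
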
8 changes: 4 additions & 4 deletions python/sglang/srt/models/qwen3.py
@@ -94,7 +94,7 @@ def __init__(
                 weight_dtype=torch.float32,
                 cast_x_before_out_mul=True,
             )
-            if get_global_server_args().rl_on_policy_target == "fsdp"
+            if get_global_server_args().rl_on_policy_target is not None
             else {}
         )
         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
@@ -167,15 +167,15 @@ def forward(
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        if get_global_server_args().rl_on_policy_target == "fsdp":
+        if get_global_server_args().rl_on_policy_target is not None:
             hidden_states = hidden_states.bfloat16()

         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)

-        if get_global_server_args().rl_on_policy_target == "fsdp":
+        if get_global_server_args().rl_on_policy_target is not None:
             q = q.to(torch.bfloat16)
             k = k.to(torch.bfloat16)

@@ -229,7 +229,7 @@ def __init__(
                 override_orig_dtype=torch.float32,
                 fp32_residual=True,
             )
-            if get_global_server_args().rl_on_policy_target == "fsdp"
+            if get_global_server_args().rl_on_policy_target is not None
             else {}
         )
         self.input_layernorm = RMSNorm(

8 changes: 7 additions & 1 deletion python/sglang/srt/server_args.py
@@ -152,6 +152,8 @@

 RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]

+RL_ON_POLICY_TARGET_CHOICES = ["fsdp"]
+
 MOE_RUNNER_BACKEND_CHOICES = [
     "auto",
     "deep_gemm",
@@ -204,6 +206,10 @@ def add_radix_eviction_policy_choices(choices):
     RADIX_EVICTION_POLICY_CHOICES.extend(choices)


+def add_rl_on_policy_target_choices(choices):
+    RL_ON_POLICY_TARGET_CHOICES.extend(choices)
+
+
 def add_mamba_ssm_dtype_choices(choices):
     MAMBA_SSM_DTYPE_CHOICES.extend(choices)

@@ -3429,7 +3435,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
         "--rl-on-policy-target",
         type=str,
         default=ServerArgs.rl_on_policy_target,
-        choices=["fsdp"],
+        choices=RL_ON_POLICY_TARGET_CHOICES,
         help="The training system that SGLang needs to match for true on-policy.",
     )

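The new `RL_ON_POLICY_TARGET_CHOICES` list and `add_rl_on_policy_target_choices` hook mirror the existing `add_radix_eviction_policy_choices` pattern, so out-of-tree integrations can register extra targets before the CLI parser is built. A hypothetical usage sketch; the `"megatron"` value is purely illustrative and not a target this PR adds:

```python
# Hypothetical plugin code; "megatron" is an illustrative placeholder value.
from sglang.srt import server_args

server_args.add_rl_on_policy_target_choices(["megatron"])

# --rl-on-policy-target megatron now passes the argparse choices check, and every
# `rl_on_policy_target is not None` branch touched by this PR is enabled for it.
```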