Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tqdm import tqdm

from slime.utils.distributed_utils import get_gloo_group
from slime.utils.profile_utils import FunctionStepProfiler
from slime.utils.timer import timer

from ..megatron_to_hf import convert_to_hf
Expand Down Expand Up @@ -45,6 +46,19 @@ def __init__(
self.transfer_plan = RemoteTransferPlan(args, model, weight_update_mode)
self._is_source = self.transfer_plan.is_source()
self.global_rank = dist.get_rank(group=get_gloo_group())
self.update_weight_profiler = None
self.update_weights_wrapped = None
if getattr(args, "use_pytorch_profiler_update_weight", False):
start_step = getattr(args, "profile_update_weight_start", 0)
end_step = getattr(args, "profile_update_weight_end", 1)
self.update_weight_profiler = FunctionStepProfiler(
self.args,
name="update_weights",
label="update_weights",
start=start_step,
end=end_step,
)
self.update_weights_wrapped = self.update_weight_profiler.wrap(self.update_weights_implementation)

@abstractmethod
def connect_rollout_engines(
Expand All @@ -63,7 +77,7 @@ def _update_bucket_weights_from_remote(
"""

@torch.no_grad()
def update_weights(self) -> None:
def update_weights_implementation(self) -> None:
"""
For each named parameter in the model, do bucketed weight update by all-gather EP/TP, convert and quantize,
and relies on underlying implementation to do the transfer.
Expand Down Expand Up @@ -91,6 +105,14 @@ def update_weights(self) -> None:
self.leader_post_update()
dist.barrier(group=get_gloo_group())

@torch.no_grad()
def update_weights(self) -> None:
    """Update rollout-engine weights, routing through the profiler wrapper when enabled."""
    wrapped = self.update_weights_wrapped
    if wrapped is None:
        self.update_weights_implementation()
    else:
        # The wrapper handles per-call tracing; stop() is intentionally not
        # called here so the profiler accumulates steps across multiple calls.
        wrapped()

def leader_post_update(self) -> None:
    """Tell every rollout engine to resume generation and block until all acknowledge."""
    pending = [engine.continue_generation.remote() for engine in self.rollout_engines]
    ray.get(pending)
    return
Expand Down
18 changes: 18 additions & 0 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,24 @@ def add_debug_arguments(parser):
default="torch",
)
parser.add_argument("--check-weight-update-equal", action="store_true")
parser.add_argument(
"--use-pytorch-profiler-update-weight",
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this impact performance of the run? If not, let's just set it default on

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're not targeting merging this feature into the official slime repo, I think it's OK. I'll also tag this PR with [TEMP], and we'll revert it in the end.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aight

action="store_true",
default=False,
help="Enable PyTorch profiler for weight update operations. Requires --tensorboard-dir to be set.",
)
parser.add_argument(
"--profile-update-weight-start",
type=int,
default=0,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's just have the start default to 1, and log everything afterwards? We only do 3 training steps anyways and this profiler will only be used for our profiling configs.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you mean we could skip the first step?

help="After enabling PyTorch profiler for weight update operations, start profiling from this point. Requires --tensorboard-dir to be set.",
)
parser.add_argument(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and I'd prefer we remove this

Copy link
Collaborator Author

@JensenFire JensenFire Jan 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or keep it, just like the profiler of slime itself?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh they have this also?

Copy link
Collaborator Author

@JensenFire JensenFire Jan 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the interesting part. Slime and Megatron-LM share the same argument use-pytorch-profiler (/root/Megatron-LM/megatron/training/arguments.py). We could also find the usage in TrainProfiler of /root/slime/slime/utils/profile_utils.py:

class TrainProfiler:
    def __init__(self, args):
        self.args = args
        self._torch_profiler_overall = None
        self._memory_profiler_overall = None

        if args.use_pytorch_profiler and ("train_overall" in args.profile_target):
            self._torch_profiler_overall = _create_torch_profiler(args, name="train_overall")

        if args.record_memory_history and ("train_overall" in args.profile_target):
            self._memory_profiler_overall = _BaseMemoryProfiler.create(args)
            self._memory_profiler_overall.start()

Basically, when --use-pytorch-profiler is enabled, it records all Python functions between steps, and the output is quite large (>100 MB) if we want the mapping between Python functions and CPU/GPU occupancy. Besides that, there are too many redundant parts in this profiler, since all we care about is updating weights. That's why I created a separate function profiler.

"--profile-update-weight-end",
type=int,
default=1,
help="After enabling PyTorch profiler for weight update operations, end profiling at this point. Requires --tensorboard-dir to be set.",
)
return parser

def add_network_arguments(parser):
Expand Down
76 changes: 76 additions & 0 deletions slime/utils/profile_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import functools
import gzip
import json
import logging
import tempfile
import time
import traceback
from pathlib import Path

import torch
from torch.profiler import record_function

from slime.utils.memory_utils import print_memory

Expand Down Expand Up @@ -78,6 +82,78 @@ def _create_torch_profiler(args, name):
)


class FunctionStepProfiler:
    """Profile individual invocations of a wrapped function with ``torch.profiler``.

    Each profiled call captures CPU + CUDA activity (kernel-level detail and
    Python-to-CUDA correlation) and is written as a gzipped Chrome trace under
    ``args.tensorboard_dir``.

    Calls are counted from 1; a call is profiled only when it falls inside the
    half-open window ``start <= call_count < end``.  NOTE(review): with the
    defaults (start=0, end=1) no call is ever profiled — callers must pass an
    explicit window.
    """

    def __init__(self, args, name: str, label: str = "target_fn", start: int = 0, end: int = 1):
        """Create a profiler writing traces named ``{name}_call{i}_rank_{r}``.

        Args:
            args: namespace providing ``tensorboard_dir`` (the trace output dir).
            name: filename prefix for exported traces.
            label: ``record_function`` label wrapped around the profiled call.
            start: first (1-indexed) call to profile, inclusive.
            end: call count at which profiling stops, exclusive.

        Raises:
            ValueError: if ``args.tensorboard_dir`` is not set.
        """
        # Fail early with a clear message instead of a TypeError from Path(None).
        if getattr(args, "tensorboard_dir", None) is None:
            raise ValueError("FunctionStepProfiler requires args.tensorboard_dir to be set")
        self.args = args
        self.name = name
        self.label = label
        self.call_count = 0  # incremented before the window test, so 1-indexed
        self.enabled = True
        self.output_dir = Path(args.tensorboard_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.start = start
        self.end = end

    def wrap(self, fn):
        """Return *fn* wrapped so that in-window calls run under the profiler.

        Out-of-window (or disabled) calls pass straight through to *fn*.
        """

        @functools.wraps(fn)
        def _wrapped(*args, **kwargs):
            if not self.enabled:
                return fn(*args, **kwargs)
            self.call_count += 1
            if not (self.start <= self.call_count < self.end):
                return fn(*args, **kwargs)
            logger.info(f"FunctionStepProfiler: Profiling call {self.call_count} for '{self.label}'")

            try:
                # CUDA is mandatory: the whole point of this profiler is
                # kernel-level traces with Python-to-CUDA correlation.
                assert torch.cuda.is_available(), "CUDA must be available for FunctionStepProfiler"
                activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]

                with torch.profiler.profile(
                    activities=activities,
                    record_shapes=True,
                    with_stack=True,
                    profile_memory=True,
                    with_flops=True,
                ) as prof:
                    with record_function(self.label):
                        result = fn(*args, **kwargs)
                    # Flush outstanding kernels so they land inside the trace.
                    torch.cuda.synchronize()

                rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
                trace_file = self.output_dir / f"{self.name}_call{self.call_count}_rank_{rank}.pt.trace.json.gz"
                # Export into a private temp directory, then gzip into place.
                # (Reopening an open NamedTemporaryFile by name fails on Windows.)
                with tempfile.TemporaryDirectory() as tmp_dir:
                    raw_trace = Path(tmp_dir) / "trace.json"
                    prof.export_chrome_trace(str(raw_trace))
                    with gzip.open(trace_file, "wb") as f_out:
                        f_out.write(raw_trace.read_bytes())
                logger.info(f"FunctionStepProfiler: Call {self.call_count} profiled, trace saved to {trace_file}")
                return result
            except Exception as e:
                raise ValueError(f"FunctionStepProfiler: Profiler error for '{self.label}', details: {e}") from e

        return _wrapped


def merge_traces(name="update_weights", call_end=5, rank=0, output_dir="/root/profiler_logs/"):
    """Merge per-call gzipped Chrome traces into a single trace file.

    Reads ``{name}_call{i}_rank_{rank}.pt.trace.json.gz`` for ``i`` in
    ``range(1, call_end)`` (``call_end`` is exclusive, matching the profiler's
    1-indexed call counter) from *output_dir*, concatenates their
    ``traceEvents``, and writes the merged trace next to the inputs.

    Args:
        name: filename prefix used by FunctionStepProfiler.
        call_end: exclusive upper bound on call indices to merge.
        rank: distributed rank whose traces are merged.
        output_dir: directory containing the per-call trace files.

    Returns:
        Path to the merged gzipped trace file.

    Raises:
        FileNotFoundError: if an expected per-call trace file is missing.
    """
    base = Path(output_dir)
    # NOTE: the "merged_..._merged" name is redundant but kept for
    # compatibility with tooling that looks for this exact filename.
    output_file = base / f"merged_{name}_rank_{rank}_merged.pt.trace.json.gz"
    merged_events = []
    for call_iter in range(1, call_end):
        trace_path = base / f"{name}_call{call_iter}_rank_{rank}.pt.trace.json.gz"
        with gzip.open(trace_path, "rt") as fp:
            merged_events.extend(json.load(fp).get("traceEvents", []))
    with gzip.open(output_file, "wt") as fp:
        json.dump({"traceEvents": merged_events}, fp)
    return output_file


class _BaseMemoryProfiler:
@staticmethod
def create(args):
Expand Down
12 changes: 12 additions & 0 deletions tests/test_weight_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typer

import slime.utils.external_utils.command_utils as U
from slime.utils.profile_utils import merge_traces
from slime.utils.timer import log_experiment_start

MODEL_NAME = "Qwen3-4B"
Expand All @@ -26,6 +27,8 @@ class ScriptArgs(U.ExecuteTrainConfig):
num_rollout_gpus: int = 1
# Optimizations
pipelined_transfer: bool = False
# Profiling
use_pytorch_profiler_update_weight: bool = False

def validate(self):
assert self.sglang_pp == 1, "Not supported yet for sglang pp"
Expand Down Expand Up @@ -151,6 +154,13 @@ def execute(args: ScriptArgs):
if args.mode == "rdma":
misc_args += "--update-weight-transfer-mode rdma "

profile_args = (
"--use-pytorch-profiler-update-weight "
"--profile-update-weight-start 0 "
"--profile-update-weight-end 6 "
"--tensorboard-dir /root/profiler_logs/ "
)

train_args = (
f"{ckpt_args} "
f"{rollout_args} "
Expand All @@ -161,6 +171,7 @@ def execute(args: ScriptArgs):
f"{sglang_args} "
# f"{ci_args} "
f"{misc_args} "
f"{profile_args} "
)

U.execute_train(
Expand All @@ -170,6 +181,7 @@ def execute(args: ScriptArgs):
train_script="train_async.py",
# extra_env_vars={"RAY_DEBUG": "1"},
)
merge_traces(name="update_weights", call_end=5, rank=0, output_dir="/root/profiler_logs/")


@U.dataclass_cli
Expand Down