Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.remote_instance_weight_loader_utils import register_memory_region_v2
from sglang.srt.server_args import ServerArgs
from torch_memory_saver import torch_memory_saver
from tqdm import tqdm

from slime.utils.memory_utils import print_memory
Expand Down Expand Up @@ -326,28 +325,27 @@ def connect_rollout_engines(
# Create local model replicas and transfer engines for each target rollout shard
self.engines = {}
# Associate transfer tasks based on obtained session and weight info
with torch_memory_saver.region(tag=self.tag):
for target in targets:
session_id = targets_to_session_id[(target.engine_ind, target.engine_rank)]
remote_info = RemoteWeightInfo(session_id, self.remote_weight_infos_by_session_id[session_id][0])
parallelism_config = RankParallelismConfig.from_dict(
self.remote_weight_infos_by_session_id[session_id][1]
for target in targets:
session_id = targets_to_session_id[(target.engine_ind, target.engine_rank)]
remote_info = RemoteWeightInfo(session_id, self.remote_weight_infos_by_session_id[session_id][0])
parallelism_config = RankParallelismConfig.from_dict(
self.remote_weight_infos_by_session_id[session_id][1]
)
if target.engine_rank not in self.engines:
transfer_engine = self._create_transfer_engine()
logger.info(f"[RDMA] Creating model replica for engine rank {target.engine_rank}")
model_replica = self._create_inference_replica(
parallelism_config, self.args.hf_checkpoint, self.session_id_to_server_args[session_id]
)
print_memory(f"[RDMA] After model replica at {target.engine_rank}")
weight_memory_registry = self._register_replica_memory(
model_replica, self.remote_weight_infos_by_session_id[session_id][0], transfer_engine
)
self.engines[target.engine_rank] = TransferBundle(
model_replica, transfer_engine, weight_memory_registry, [remote_info]
)
if target.engine_rank not in self.engines:
transfer_engine = self._create_transfer_engine()
logger.info(f"[RDMA] Creating model replica for engine rank {target.engine_rank}")
model_replica = self._create_inference_replica(
parallelism_config, self.args.hf_checkpoint, self.session_id_to_server_args[session_id]
)
print_memory(f"[RDMA] After model replica at {target.engine_rank}")
weight_memory_registry = self._register_replica_memory(
model_replica, self.remote_weight_infos_by_session_id[session_id][0], transfer_engine
)
self.engines[target.engine_rank] = TransferBundle(
model_replica, transfer_engine, weight_memory_registry, [remote_info]
)
else:
self.engines[target.engine_rank].add_remote_session(remote_info)
else:
self.engines[target.engine_rank].add_remote_session(remote_info)

print_memory("[RDMA] After Local Engine Replicas and engine Creation")

Expand Down Expand Up @@ -422,10 +420,10 @@ def _update_bucket_weights_from_remote(

if not self._is_source or not converted_named_tensors:
return

if self._model_on_cpu:
torch_memory_saver.resume(self.tag)
self._model_on_cpu = False
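# TODO(letian): decide how to bring the model replica back from CPU to GPU before the transfer; the torch_memory_saver.resume(self.tag) path below is disabled for now.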
# if self._model_on_cpu:
# torch_memory_saver.resume(self.tag)
# self._model_on_cpu = False

for transfer_bundle in self.engines.values():
transfer_ready_params = transfer_bundle.get_transfer_ready_params(converted_named_tensors)
Expand Down Expand Up @@ -459,10 +457,11 @@ def finish_transfer_task(self) -> None:
transfer_bundle.reset()

# Offload model replicas from memory after transfer.
if not self._model_on_cpu:
print_memory("[RDMA] Before offloading model replica")
torch_memory_saver.pause(self.tag)
self._model_on_cpu = True
print_memory("[RDMA] After offloading model replica")
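# TODO(letian): decide how to offload the model replica from GPU to CPU after the transfer; the torch_memory_saver.pause(self.tag) path below is disabled for now.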
# if not self._model_on_cpu:
# print_memory("[RDMA] Before offloading model replica")
# torch_memory_saver.pause(self.tag)
# self._model_on_cpu = True
# print_memory("[RDMA] After offloading model replica")

return