Skip to content

Commit e18467a

Browse files
committed
[RL] Allow passing tensors of different dtypes for FlattenedTensorBucket
1 parent 9509c4c commit e18467a

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

python/sglang/srt/weight_sync/tensor_bucket.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors)
4949

5050
for i, (name, tensor) in enumerate(named_tensors):
51-
flattened = tensor.flatten()
51+
flattened = tensor.flatten().view(torch.uint8)
5252
flattened_tensors[i] = flattened
5353

5454
# Store metadata
@@ -93,14 +93,12 @@ def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
9393
reconstructed = [None] * len(self.metadata)
9494

9595
for i, meta in enumerate(self.metadata):
96-
tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].reshape(
97-
meta.shape
96+
tensor = (
97+
self.flattened_tensor[meta.start_idx : meta.end_idx]
98+
.view(meta.dtype)
99+
.reshape(meta.shape)
98100
)
99101

100-
# batch dtype conversion (if needed)
101-
if tensor.dtype != meta.dtype:
102-
tensor = tensor.to(meta.dtype)
103-
104102
reconstructed[i] = (meta.name, tensor)
105103

106104
return reconstructed

0 commit comments

Comments
 (0)