1+ import json
12import logging
23from typing import List , Optional
34
67logger = logging .getLogger (__name__ )
78
89
def get_ib_devices_for_gpu(ib_device_str: Optional[str], gpu_id: int) -> Optional[str]:
    """Return the IB device list that applies to ``gpu_id``.

    Two input formats are supported:

    1. Legacy format: a plain device list such as ``"ib0, ib1, ib2"``.
       The same (stripped) string is returned for every GPU.
    2. Per-GPU mapping: a JSON object whose keys are GPU ids and whose
       values are device lists, e.g.
       ``{"0": "ib0, ib1", "1": "ib2, ib3", "2": "ib4"}``.
       Keys must be quoted, as JSON requires; a string with unquoted keys
       does not parse as JSON and is therefore handled as legacy format.

    Args:
        ib_device_str: The original IB device string, or None.
        gpu_id: The GPU ID to get devices for.

    Returns:
        The IB device string for the GPU, or None if ``ib_device_str`` is
        None, empty, or whitespace only.

    Raises:
        ValueError: If the JSON object is empty, has a key that is not a
            non-negative decimal integer, has a non-string value, or has no
            entry for ``gpu_id``.
    """
    if ib_device_str is None:
        return None

    ib_device_str = ib_device_str.strip()
    if not ib_device_str:
        return None

    # Only the parse itself is guarded, so validation errors raised below can
    # never be confused with a decoding failure.
    try:
        parsed_json = json.loads(ib_device_str)
    except json.JSONDecodeError:
        # Not JSON: legacy format, the same devices apply to all GPUs.
        return ib_device_str

    if not isinstance(parsed_json, dict):
        # Valid JSON that is not a mapping (a bare number, a quoted name, a
        # list, ...). Handle it as legacy format instead of silently dropping
        # the configuration by returning None.
        return ib_device_str

    gpu_mapping = {}
    for gpu_key, ib_devices in parsed_json.items():
        # json.loads always produces str keys for objects, so only the
        # string form of the GPU id has to be validated. isdecimal() accepts
        # exactly the digit characters that int() can convert.
        if not (gpu_key.isdecimal() and isinstance(ib_devices, str)):
            raise ValueError(
                "Invalid format: keys must be integers (or string representations of integers) and values must be strings"
            )
        gpu_mapping[int(gpu_key)] = ib_devices.strip()

    if not gpu_mapping:
        raise ValueError("No valid GPU mappings found in JSON")

    if gpu_id not in gpu_mapping:
        raise ValueError(
            f"No IB devices configured for GPU {gpu_id}. Available GPUs: {list(gpu_mapping.keys())}"
        )
    return gpu_mapping[gpu_id]
64+
65+
966class MooncakeTransferEngine :
1067
1168 def __init__ (self , hostname : str , gpu_id : int , ib_device : Optional [str ] = None ):
@@ -21,7 +78,7 @@ def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
2178 self .engine = TransferEngine ()
2279 self .hostname = hostname
2380 self .gpu_id = gpu_id
24- self .ib_device = ib_device
81+ self .ib_device = get_ib_devices_for_gpu ( ib_device , gpu_id )
2582
2683 self .initialize (
2784 hostname = self .hostname ,
0 commit comments