@@ -188,17 +188,17 @@ def release_indexes(self, partition_id: str, indexes_to_release: list[int]):
188188 if not partition_indexes :
189189 self .partition_to_indexes .pop (partition_id , None )
190190
def get_indexes_for_partition(self, partition_id: str) -> list[int]:
    """
    Get all global_indexes for the specified partition.

    Args:
        partition_id: Partition ID

    Returns:
        list: A new list with the global_indexes recorded for this partition,
            or an empty list when the partition has no entry. The list is
            built from an unordered collection, so callers must not rely on
            element order.
    """
    # list() already creates an independent container, so copying the stored
    # collection first would only duplicate work. The empty-tuple default
    # avoids allocating a throwaway set for unknown partitions.
    return list(self.partition_to_indexes.get(partition_id, ()))
202202
203203
204204@dataclass
@@ -216,7 +216,7 @@ class DataPartitionStatus:
216216
# Production status tensor - dynamically expandable
# Values: 0 = not produced, 1 = ready for consumption
# Built through default_factory so that every instance owns its own tensor.
# A plain `= torch.zeros(...)` default is evaluated once at class creation and
# would be shared by all instances, so in-place writes on one partition would
# show up in every other partition.
production_status: Tensor = field(
    default_factory=lambda: torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)
)
220220
221221 # Consumption status per task - task_name -> consumption_tensor
222222 # Each tensor tracks which samples have been consumed by that task
@@ -251,16 +251,16 @@ def total_fields_num(self) -> int:
@property
def allocated_fields_num(self) -> int:
    """Number of columns (dimension 1) of the production status tensor."""
    status = self.production_status
    return status.shape[1]
255255
@property
def allocated_samples_num(self) -> int:
    """Number of rows (dimension 0) of the production status tensor."""
    status = self.production_status
    return status.shape[0]
260260
261261 # ==================== Dynamic Expansion Methods ====================
262262
263- def ensure_samples_capacity (self , required_samples : int ) -> bool :
263+ def ensure_samples_capacity (self , required_samples : int ) -> None :
264264 """
265265 Ensure the production status tensor has enough rows for the required samples.
266266 Dynamically expands if needed using unified minimum expansion size.
@@ -291,17 +291,14 @@ def ensure_samples_capacity(self, required_samples: int) -> bool:
291291 f"to { new_samples } samples (added { min_expansion } samples)"
292292 )
293293
294- def ensure_fields_capacity (self , required_fields : int ):
294+ def ensure_fields_capacity (self , required_fields : int ) -> None :
295295 """
296296 Ensure the production status tensor has enough columns for the required fields.
297297 Dynamically expands if needed using unified minimum expansion size.
298298
299299 Args:
300300 required_fields: Minimum number of fields needed
301301 """
302- if self .production_status is None :
303- # Will be initialized when samples are added
304- return
305302
306303 current_fields = self .production_status .shape [1 ]
307304 if required_fields > current_fields :
@@ -498,7 +495,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
498495 return partition_global_index , consumption_status
499496
500497 # ==================== Production Status Interface ====================
501- def get_production_status_for_fields (self , field_names : list [str ], mask : bool = False ) -> tuple [Tensor , Tensor ]:
498+ def get_production_status_for_fields (
499+ self , field_names : list [str ], mask : bool = False
500+ ) -> tuple [Optional [Tensor ], Optional [Tensor ]]:
502501 """
503502 Check if all samples for specified fields are fully produced and ready.
504503
@@ -511,13 +510,13 @@ def get_production_status_for_fields(self, field_names: list[str], mask: bool =
511510 - Partition global index tensor
512511 - Production status tensor for the specified task. 1 for ready, 0 for not ready.
513512 """
514- if self . production_status is None or field_names is None or len (field_names ) == 0 :
515- return False
513+ if field_names is None or len (field_names ) == 0 :
514+ return None , None
516515
517516 # Check if all requested fields are registered
518517 for field_name in field_names :
519518 if field_name not in self .field_name_mapping :
520- return False
519+ return None , None
521520
522521 # Create column mask for requested fields
523522 col_mask = torch .zeros (self .allocated_fields_num , dtype = torch .bool )
@@ -548,8 +547,6 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]:
548547 Returns:
549548 List of sample indices that are ready for consumption
550549 """
551- if self .production_status is None :
552- return []
553550
554551 # Check if all requested fields are registered
555552 for field_name in field_names :
@@ -837,15 +834,15 @@ def list_partitions(self) -> list[str]:
837834
838835 # ==================== Partition Index Management API ====================
839836
def get_partition_index_range(self, partition: str) -> list[int]:
    """
    Get all indexes for a specific partition.

    Args:
        partition: Partition identifier. It is forwarded unchanged to the
            index manager as the partition ID, so it must be the ID itself
            rather than a DataPartitionStatus object.

    Returns:
        List of indexes allocated to the partition
    """
    return self.index_manager.get_indexes_for_partition(partition)
851848
@@ -980,6 +977,9 @@ def get_metadata(
980977 if mode == "fetch" :
981978 # Find ready samples within current data partition and package into BatchMeta when reading
982979
980+ if batch_size is None :
981+ raise ValueError ("must provide batch_size in fetch mode" )
982+
983983 start_time = time .time ()
984984 while True :
985985 # ready_for_consume_indexes: samples where all required fields are produced
0 commit comments