diff --git a/tests/test_async_simple_storage_manager.py b/tests/test_async_simple_storage_manager.py
index bbf6d4b..4d1419a 100644
--- a/tests/test_async_simple_storage_manager.py
+++ b/tests/test_async_simple_storage_manager.py
@@ -542,7 +542,7 @@ class TestPackFieldValues:
     def test_uniform_tensors_to_stack(self):
         """Same-shape tensors → torch.stack."""
         values = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
-        result = AsyncSimpleStorageManager._pack_field_values(values)
+        result = AsyncSimpleStorageManager._pack_field_values(values)  # type: ignore[attr-defined]
         assert isinstance(result, torch.Tensor)
         assert not result.is_nested
         assert result.shape == (2, 2)
@@ -550,13 +550,37 @@ def test_uniform_tensors_to_stack(self):
     def test_variable_length_tensors_to_nested(self):
         """Different-shape tensors → nested tensor."""
         values = [torch.tensor([1.0]), torch.tensor([2.0, 3.0])]
-        result = AsyncSimpleStorageManager._pack_field_values(values)
+        result = AsyncSimpleStorageManager._pack_field_values(values)  # type: ignore[attr-defined]
         assert isinstance(result, torch.Tensor)
         assert result.is_nested
 
     def test_non_tensors_to_nontensorstack(self):
         """Non-tensor values → NonTensorStack."""
         values = ["hello", "world"]
-        result = AsyncSimpleStorageManager._pack_field_values(values)
+        result = AsyncSimpleStorageManager._pack_field_values(values)  # type: ignore[attr-defined]
         assert isinstance(result, NonTensorStack)
         assert result.tolist() == ["hello", "world"]
+
+    def test_mixed_tensors_and_none_to_nontensorstack(self):
+        """Mixed tensor + None values should stay as NonTensorStack (no stacking)."""
+        t0 = torch.tensor([1.0, 2.0])
+        t2 = torch.tensor([3.0, 4.0])
+        values = [t0, None, t2]
+
+        result = AsyncSimpleStorageManager._pack_field_values(values)  # type: ignore[attr-defined]
+
+        assert isinstance(result, NonTensorStack)
+        unpacked = result.tolist()
+        assert len(unpacked) == 3
+        assert torch.equal(unpacked[0], t0)
+        assert unpacked[1] is None
+        assert torch.equal(unpacked[2], t2)
+
+    def test_all_none_to_nontensorstack(self):
+        """All-None values should be preserved in NonTensorStack."""
+        values = [None, None]
+
+        result = AsyncSimpleStorageManager._pack_field_values(values)  # type: ignore[attr-defined]
+
+        assert isinstance(result, NonTensorStack)
+        assert result.tolist() == [None, None]
diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py
index 27e173c..00d8782 100644
--- a/transfer_queue/storage/managers/simple_backend_manager.py
+++ b/transfer_queue/storage/managers/simple_backend_manager.py
@@ -387,24 +387,36 @@ def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack:
         """
         Pack a list of per-sample values into a batched container.
 
-        For tensor values, this performs a memory copy via stacking or nested tensor creation.
-        Non-tensor values are grouped into a ``NonTensorStack`` without copying.
+        For pure tensor lists (no None), this performs a memory copy via stacking
+        or nested tensor creation. Mixed types, non-tensor values, or lists
+        containing None placeholders are grouped into a ``NonTensorStack``.
+
+        Args:
+            values: List of per-sample values to pack. May contain None for
+                unfilled batch positions.
+
+        Returns:
+            A stacked ``torch.Tensor`` (or nested tensor) when all values are
+            tensors, otherwise a ``NonTensorStack``.
+
+        Raises:
+            ValueError: If *values* is empty.
         """
         if not values:
             raise ValueError("_pack_field_values received empty values list; caller should filter empty batches")
-        if any(v is None for v in values):
-            raise ValueError("_pack_field_values received None in values list; some batch positions were not filled")
         if all(isinstance(v, torch.Tensor) for v in values):
             if all(v.shape == values[0].shape for v in values):
                 return torch.stack(values)
             try:
                 return torch.nested.as_nested_tensor(values, layout=torch.jagged)
             except (RuntimeError, TypeError) as e:
                 logger.warning(
                     f"Failed to pack nested tensor with jagged layout. "
                     f"Falling back to strided layout. Detailed error: {e}"
                 )
                 return torch.nested.as_nested_tensor(values, layout=torch.strided)
+        # isinstance(None, torch.Tensor) is False, so any None placeholder or
+        # non-tensor entry lands here and is preserved per-sample, unstacked.
         return NonTensorStack(*values)
 
     async def get_data(self, metadata: BatchMeta) -> TensorDict: