Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions tests/test_async_simple_storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,21 +542,45 @@ class TestPackFieldValues:
def test_uniform_tensors_to_stack(self):
    """Same-shape tensors → torch.stack."""
    # Two 1-D tensors with identical shapes should be densely stacked
    # into a single (2, 2) tensor, not a nested tensor.
    values = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
    result = AsyncSimpleStorageManager._pack_field_values(values)  # type: ignore[attr-defined]
    assert isinstance(result, torch.Tensor)
    assert not result.is_nested
    assert result.shape == (2, 2)

def test_variable_length_tensors_to_nested(self):
    """Different-shape tensors → nested tensor."""
    # Ragged lengths cannot be densely stacked; the packer should fall
    # back to a nested tensor (is_nested is True).
    values = [torch.tensor([1.0]), torch.tensor([2.0, 3.0])]
    result = AsyncSimpleStorageManager._pack_field_values(values)  # type: ignore[attr-defined]
    assert isinstance(result, torch.Tensor)
    assert result.is_nested

def test_non_tensors_to_nontensorstack(self):
    """Non-tensor values → NonTensorStack."""
    # Plain Python objects are grouped without copying into a
    # NonTensorStack that round-trips back to the original list.
    values = ["hello", "world"]
    result = AsyncSimpleStorageManager._pack_field_values(values)  # type: ignore[attr-defined]
    assert isinstance(result, NonTensorStack)
    assert result.tolist() == ["hello", "world"]

def test_mixed_tensors_and_none_to_nontensorstack(self):
    """Mixed tensor + None values should stay as NonTensorStack (no stacking)."""
    first = torch.tensor([1.0, 2.0])
    last = torch.tensor([3.0, 4.0])

    packed = AsyncSimpleStorageManager._pack_field_values([first, None, last])  # type: ignore[attr-defined]

    # A None placeholder must prevent tensor stacking: the result is a
    # NonTensorStack that preserves every position, including the None.
    assert isinstance(packed, NonTensorStack)
    items = packed.tolist()
    assert len(items) == 3
    assert torch.equal(items[0], first)
    assert items[1] is None
    assert torch.equal(items[2], last)

def test_all_none_to_nontensorstack(self):
    """All-None values should be preserved in NonTensorStack."""
    packed = AsyncSimpleStorageManager._pack_field_values([None, None])  # type: ignore[attr-defined]

    # No tensors at all: every None survives packing unchanged.
    assert isinstance(packed, NonTensorStack)
    assert packed.tolist() == [None, None]
44 changes: 29 additions & 15 deletions transfer_queue/storage/managers/simple_backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,24 +387,38 @@ def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack:
"""
Pack a list of per-sample values into a batched container.

For tensor values, this performs a memory copy via stacking or nested tensor creation.
Non-tensor values are grouped into a ``NonTensorStack`` without copying.
For pure tensor lists (no None), this performs a memory copy via stacking
or nested tensor creation. Mixed types, non-tensor values, or lists
containing None placeholders are grouped into a ``NonTensorStack``.

Args:
values: List of per-sample values to pack. May contain None for
unfilled batch positions.

Returns:
A stacked ``torch.Tensor`` (or nested tensor) when all values are
tensors, otherwise a ``NonTensorStack``.

Raises:
ValueError: If *values* is empty.
"""
if not values:
raise ValueError("_pack_field_values received empty values list; caller should filter empty batches")
if any(v is None for v in values):
raise ValueError("_pack_field_values received None in values list; some batch positions were not filled")
if all(isinstance(v, torch.Tensor) for v in values):
if all(v.shape == values[0].shape for v in values):
return torch.stack(values)
try:
return torch.nested.as_nested_tensor(values, layout=torch.jagged)
except (RuntimeError, TypeError) as e:
logger.warning(
f"Failed to pack nested tensor with jagged layout. "
f"Falling back to strided layout. Detailed error: {e}"
)
return torch.nested.as_nested_tensor(values, layout=torch.strided)
non_none = [v for v in values if v is not None]
if non_none and all(isinstance(v, torch.Tensor) for v in non_none):
if len(non_none) == len(values):
# Pure tensor list — try stacking / nested tensor
if all(v.shape == values[0].shape for v in values):
return torch.stack(values)
try:
return torch.nested.as_nested_tensor(values, layout=torch.jagged)
except (RuntimeError, TypeError) as e:
logger.warning(
f"Failed to pack nested tensor with jagged layout. "
f"Falling back to strided layout. Detailed error: {e}"
)
return torch.nested.as_nested_tensor(values, layout=torch.strided)
# Mixed tensor + None — cannot stack, fall through to NonTensorStack
return NonTensorStack(*values)
Comment on lines +407 to 422
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New None-tolerant behavior isn’t covered by existing tests: there are tests for uniform tensors / nested tensors / non-tensors, but no assertion that a list containing None returns a NonTensorStack (and preserves None in the correct positions). Please add a unit test for values=[tensor, None, tensor] (and optionally [None, None]).

Copilot uses AI. Check for mistakes.

async def get_data(self, metadata: BatchMeta) -> TensorDict:
Expand Down
Loading