[fix] Allow None values in _pack_field_values and fallback to NonTensorStack#75
[fix] Allow None values in _pack_field_values and fallback to NonTensorStack#750oshowero0 merged 2 commits intoAscend:mainfrom
_pack_field_values and fallback to NonTensorStack#75Conversation
NINGBENZHE
commented
Apr 7, 2026
- Modify _pack_field_values to tolerate None placeholders in the values list, falling back to NonTensorStack instead of raising ValueError.
- Pure tensor lists (no None) still use torch.stack or nested tensor.
- Update docstring to reflect the new None-tolerant behavior.
CLA Signature PassNINGBENZHE, thanks for your pull request. All authors of the commits have signed the CLA. 👍 |
…orStack
- Modify _pack_field_values to tolerate None placeholders in the values
list, falling back to NonTensorStack instead of raising ValueError.
- Pure tensor lists (no None) still use torch.stack or nested tensor.
- Update docstring to reflect the new None-tolerant behavior.
Signed-off-by: 宁本哲 <ningbenzhe@xiaohongshu.com>
CLA Signature PassNINGBENZHE, thanks for your pull request. All authors of the commits have signed the CLA. 👍 |
| return torch.nested.as_nested_tensor(values, layout=torch.strided) | ||
| non_none = [v for v in values if v is not None] | ||
| if non_none and all(isinstance(v, torch.Tensor) for v in non_none): | ||
| if not any(v is None for v in values): |
There was a problem hiding this comment.
Maybe consider
| if not any(v is None for v in values): | |
| if len(non_none) == len(values): |
to avoid redundant O(n) scanning of the values list.
There was a problem hiding this comment.
Pull request overview
This PR updates AsyncSimpleStorageManager._pack_field_values to tolerate None placeholders in per-sample value lists (common when scattering results back into original batch order), falling back to NonTensorStack instead of raising.
Changes:
- Allow
Nonevalues in_pack_field_valuesand fall back toNonTensorStackfor mixed/partial lists. - Preserve the existing fast-path for pure tensor lists (stack or nested tensor).
- Expand the docstring to describe the new behavior and contract.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| non_none = [v for v in values if v is not None] | ||
| if non_none and all(isinstance(v, torch.Tensor) for v in non_none): | ||
| if not any(v is None for v in values): | ||
| # Pure tensor list — try stacking / nested tensor | ||
| if all(v.shape == values[0].shape for v in values): | ||
| return torch.stack(values) | ||
| try: | ||
| return torch.nested.as_nested_tensor(values, layout=torch.jagged) | ||
| except (RuntimeError, TypeError) as e: | ||
| logger.warning( | ||
| f"Failed to pack nested tensor with jagged layout. " | ||
| f"Falling back to strided layout. Detailed error: {e}" | ||
| ) | ||
| return torch.nested.as_nested_tensor(values, layout=torch.strided) | ||
| # Mixed tensor + None — cannot stack, fall through to NonTensorStack | ||
| return NonTensorStack(*values) |
There was a problem hiding this comment.
New None-tolerant behavior isn’t covered by existing tests: there are tests for uniform tensors / nested tensors / non-tensors, but no assertion that a list containing None returns a NonTensorStack (and preserves None in the correct positions). Please add a unit test for values=[tensor, None, tensor] (and optionally [None, None]).
| if non_none and all(isinstance(v, torch.Tensor) for v in non_none): | ||
| if not any(v is None for v in values): |
There was a problem hiding this comment.
non_none is built and then any(v is None for v in values) scans the list again. You can avoid the extra pass by deriving has_none = len(non_none) != len(values) (or similar) and reusing it for the pure-tensor branch check.
| if non_none and all(isinstance(v, torch.Tensor) for v in non_none): | |
| if not any(v is None for v in values): | |
| has_none = len(non_none) != len(values) | |
| if non_none and all(isinstance(v, torch.Tensor) for v in non_none): | |
| if not has_none: |
| non_none = [v for v in values if v is not None] | ||
| if non_none and all(isinstance(v, torch.Tensor) for v in non_none): | ||
| if not any(v is None for v in values): |
There was a problem hiding this comment.
Can we merge some of these check logics? non_none filters None elements, then all(isinstance(v, torch.Tensor) for v in non_none) checks all the elements in non_none is Tensor, then if not any(v is None for v in values): checks whether any elements in original input is None again.
CLA Signature PassNINGBENZHE, thanks for your pull request. All authors of the commits have signed the CLA. 👍 |
_pack_field_values and fallback to NonTensorStack