Description
`@tensorclass(tensor_only=True)` rejects field annotations that use parameterized generics of tensor-compatible types (e.g., `TensorDict[str, torch.Tensor]`), even though the underlying origin type (`TensorDict`) is a valid tensor collection.
Bare `TensorDict` works. `TensorDict[str, torch.Tensor]` does not, even though it differs only by its static type parameters.
Minimum reproducible example
```python
from tensordict import TensorDict, tensorclass
import torch


# Works: bare TensorDict
@tensorclass(tensor_only=True)
class WorksOK:
    data: TensorDict


# Fails: parameterized TensorDict
@tensorclass(tensor_only=True)
class Fails:
    data: TensorDict[str, torch.Tensor]
```

Output:

```
Traceback (most recent call last):
  File "repro.py", line 12, in <module>
    @tensorclass(tensor_only=True)
  File ".../tensordict/tensorclass.py", line 914, in wrap
  ...
  File ".../tensordict/tensorclass.py", line 1449, in _get_type_hints
    raise _TENSOR_ONLY_TYPE_ERR
TypeError: tensor_only requires types to be Tensor, Tensor-subtrypes or None.
```
Root cause
`is_tensor_or_optional_tensor` in `_get_type_hints` (`tensorclass.py`, ~line 1420) handles `Union` types but not other parameterized generics. `TensorDict[str, torch.Tensor]` evaluates to a `types.GenericAlias`, which fails the `isinstance(type_hint, type)` check and falls through to an unconditional `return False`:
```python
def is_tensor_or_optional_tensor(type_hint):
    if isinstance(type_hint, type):
        return issubclass(type_hint, _TensorTypes) or _is_tensor_collection(type_hint)
    ...
    origin = get_origin(type_hint)
    if origin is Union:
        ...
    return False  # <-- TensorDict[str, Tensor] lands here
```

The function already computes `origin = get_origin(type_hint)` but only checks `origin is Union`. For `TensorDict[str, torch.Tensor]`, `get_origin()` returns `TensorDict` itself, a valid tensor collection, but this is never tested:

```python
>>> type(TensorDict[str, torch.Tensor])
<class 'types.GenericAlias'>
>>> isinstance(TensorDict[str, torch.Tensor], type)
False
>>> get_origin(TensorDict[str, torch.Tensor]) is TensorDict
True
```

Proposed fix
Add a check for parameterized generics after the `Union` handling, before the final `return False`:
```python
def is_tensor_or_optional_tensor(type_hint):
    ...  # earlier checks unchanged
    origin = get_origin(type_hint)
    if origin is Union:
        args = get_args(type_hint)
        return all(
            t is None or t is NoneType or is_tensor_or_optional_tensor(t)
            for t in args
        )
    # Handle parameterized generics (e.g., TensorDict[str, Tensor])
    if origin is not None and isinstance(origin, type):
        return issubclass(origin, _TensorTypes) or _is_tensor_collection(origin)
    return False
```

This correctly accepts `TensorDict[K, V]` (the origin is `TensorDict`, a tensor collection) while still rejecting non-tensor generics like `list[int]` or `dict[str, float]` (the origin is `list`/`dict`, which is neither a tensor type nor a tensor collection).
I validated the fix against these cases:
| Annotation | Expected | Result |
|---|---|---|
| `torch.Tensor` | True | True |
| `TensorDict` | True | True |
| `TensorDict[str, TensorDict]` | True | True |
| `torch.Tensor \| None` | True | True |
| `TensorDict \| None` | True | True |
| `@tensorclass` subclass | True | True |
| `int` | False | False |
| `str` | False | False |
| `list[int]` | False | False |
| `dict[str, int]` | False | False |
| `Any` | False | False |
Environment
- tensordict 0.11.0
- torch 2.10.0+cu128
- Python 3.13.8
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)