[BUG] tensor_only=True rejects parameterized generic annotations of tensor-compatible types #1658

@peterdsharpe

Description

@tensorclass(tensor_only=True) rejects field annotations that use parameterized generics of tensor-compatible types (e.g., TensorDict[str, torch.Tensor]), even though the underlying origin type (TensorDict) is a valid tensor collection.

Bare TensorDict works; TensorDict[str, torch.Tensor] does not, even though it is the same type with additional static type information.

Minimal reproducible example

from tensordict import TensorDict, tensorclass
import torch

# Works: bare TensorDict
@tensorclass(tensor_only=True)
class WorksOK:
    data: TensorDict

# Fails: parameterized TensorDict
@tensorclass(tensor_only=True)
class Fails:
    data: TensorDict[str, torch.Tensor]

Output:

Traceback (most recent call last):
  File "repro.py", line 12, in <module>
    @tensorclass(tensor_only=True)
  File ".../tensordict/tensorclass.py", line 914, in wrap
    ...
  File ".../tensordict/tensorclass.py", line 1449, in _get_type_hints
    raise _TENSOR_ONLY_TYPE_ERR
TypeError: tensor_only requires types to be Tensor, Tensor-subtrypes or None.

Root cause

is_tensor_or_optional_tensor in _get_type_hints (tensorclass.py, ~line 1420) handles Union types but not other parameterized generics. TensorDict[str, torch.Tensor] evaluates to a types.GenericAlias, which fails the isinstance(type_hint, type) check and falls through to an unconditional return False:

def is_tensor_or_optional_tensor(type_hint):
    if isinstance(type_hint, type):
        return issubclass(type_hint, _TensorTypes) or _is_tensor_collection(type_hint)
    ...
    origin = get_origin(type_hint)
    if origin is Union:
        ...
    return False  # <-- TensorDict[str, Tensor] lands here

The function already computes origin = get_origin(type_hint) but only checks origin is Union. For TensorDict[str, torch.Tensor], get_origin() returns TensorDict itself, a valid tensor collection, but that origin is never checked:

>>> type(TensorDict[str, torch.Tensor])
<class 'types.GenericAlias'>
>>> isinstance(TensorDict[str, torch.Tensor], type)
False
>>> get_origin(TensorDict[str, torch.Tensor]) is TensorDict
True

Proposed fix

Add a check for parameterized generics after the Union handling, before the final return False:

    origin = get_origin(type_hint)
    if origin is Union:
        args = get_args(type_hint)
        return all(
            t is None or t is NoneType or is_tensor_or_optional_tensor(t)
            for t in args
        )
    # Handle parameterized generics (e.g., TensorDict[str, Tensor])
    if origin is not None and isinstance(origin, type):
        return issubclass(origin, _TensorTypes) or _is_tensor_collection(origin)
    return False

This correctly accepts TensorDict[K, V] (origin is TensorDict, a tensor collection) while still rejecting non-tensor generics like list[int] or dict[str, float] (origin is list/dict, not a tensor type or collection).

I validated the fix against these cases:

Annotation                    Expected   Result
torch.Tensor                  True       True
TensorDict                    True       True
TensorDict[str, TensorDict]   True       True
torch.Tensor | None           True       True
TensorDict | None             True       True
@tensorclass subclass         True       True
int                           False      False
str                           False      False
list[int]                     False      False
dict[str, int]                False      False
Any                           False      False

Environment

  • tensordict 0.11.0
  • torch 2.10.0+cu128
  • Python 3.13.8

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
