Description
When a compiled function is called on TensorDicts with different numbers of pre-existing keys, Dynamo guards on len(self._tensordict). Each distinct length triggers a recompile.
This also affects in-place mutation: the guard is recorded against the pre-mutation length, so if compiled code adds a key to a TensorDict, the next call on that same (now larger) TensorDict fails the guard.
Concrete scenarios that could hit this:
- A model that receives data with an optional field (present in some samples, absent in others)
- Any lazy caching pattern where a TensorDict accumulates computed properties over time (this is my use case)
- A training loop that enriches a TensorDict by adding keys
Reproducer
import torch
from tensordict import TensorDict

torch._dynamo.reset()

def fn(td):
    td["new"] = torch.ones(4)
    return td["new"]

compiled = torch.compile(fn)
for n_keys in range(10):
    td = TensorDict(
        {f"k{i}": torch.randn(4) for i in range(n_keys)},
        batch_size=[4],
    )
    compiled(td)
print(torch._dynamo.utils.counters["frames"]["total"])
# Prints 9 (many recompiles), ideal is 1
Printout:
W0224 02:00:46.201000 874296 torch/_dynamo/convert_frame.py:1676] [0/8] torch._dynamo hit config.recompile_limit (8)
W0224 02:00:46.201000 874296 torch/_dynamo/convert_frame.py:1676] [0/8] function: 'fn' (<ipython-input-1-1cf5d9213a05>:6)
W0224 02:00:46.201000 874296 torch/_dynamo/convert_frame.py:1676] [0/8] last reason: 0/7: len(td._tensordict) == 7 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
W0224 02:00:46.201000 874296 torch/_dynamo/convert_frame.py:1676] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0224 02:00:46.201000 874296 torch/_dynamo/convert_frame.py:1676] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html
9
Recompile log
Same reproducer as above, run with the environment variable TORCH_LOGS="recompiles" set:
V0224 02:04:01.601000 875651 torch/_dynamo/guards.py:4514] [0/1] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:01.601000 875651 torch/_dynamo/guards.py:4514] [0/1] [__recompiles] triggered by the following guard failure(s):
V0224 02:04:01.601000 875651 torch/_dynamo/guards.py:4514] [0/1] [__recompiles] - 0/0: not td._tensordict # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.695000 875651 torch/_dynamo/guards.py:4514] [0/2] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:01.695000 875651 torch/_dynamo/guards.py:4514] [0/2] [__recompiles] triggered by the following guard failure(s):
V0224 02:04:01.695000 875651 torch/_dynamo/guards.py:4514] [0/2] [__recompiles] - 0/1: len(td._tensordict) == 1 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.695000 875651 torch/_dynamo/guards.py:4514] [0/2] [__recompiles] - 0/0: not td._tensordict # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.786000 875651 torch/_dynamo/guards.py:4514] [0/3] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:01.786000 875651 torch/_dynamo/guards.py:4514] [0/3] [__recompiles] triggered by the following guard failure(s):
V0224 02:04:01.786000 875651 torch/_dynamo/guards.py:4514] [0/3] [__recompiles] - 0/2: len(td._tensordict) == 2 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.786000 875651 torch/_dynamo/guards.py:4514] [0/3] [__recompiles] - 0/1: len(td._tensordict) == 1 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.786000 875651 torch/_dynamo/guards.py:4514] [0/3] [__recompiles] - 0/0: not td._tensordict # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.881000 875651 torch/_dynamo/guards.py:4514] [0/4] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:01.881000 875651 torch/_dynamo/guards.py:4514] [0/4] [__recompiles] triggered by the following guard failure(s):
V0224 02:04:01.881000 875651 torch/_dynamo/guards.py:4514] [0/4] [__recompiles] - 0/3: len(td._tensordict) == 3 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.881000 875651 torch/_dynamo/guards.py:4514] [0/4] [__recompiles] - 0/2: len(td._tensordict) == 2 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.881000 875651 torch/_dynamo/guards.py:4514] [0/4] [__recompiles] - 0/1: len(td._tensordict) == 1 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.881000 875651 torch/_dynamo/guards.py:4514] [0/4] [__recompiles] - 0/0: not td._tensordict # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles] triggered by the following guard failure(s):
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles] - 0/4: len(td._tensordict) == 4 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles] - 0/3: len(td._tensordict) == 3 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles] - 0/2: len(td._tensordict) == 2 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles] - 0/1: len(td._tensordict) == 1 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles] - 0/0: not td._tensordict # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles] triggered by the following guard failure(s):
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles] - 0/5: len(td._tensordict) == 5 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles] - 0/4: len(td._tensordict) == 4 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles] - 0/3: len(td._tensordict) == 3 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles] - 0/2: len(td._tensordict) == 2 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles] - 0/1: len(td._tensordict) == 1 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles] - 0/0: not td._tensordict # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles] triggered by the following guard failure(s):
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles] - 0/6: len(td._tensordict) == 6 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles] - 0/5: len(td._tensordict) == 5 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles] - 0/4: len(td._tensordict) == 4 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles] - 0/3: len(td._tensordict) == 3 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles] - 0/2: len(td._tensordict) == 2 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles] - 0/1: len(td._tensordict) == 1 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles] - 0/0: not td._tensordict # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles] triggered by the following guard failure(s):
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles] - 0/7: len(td._tensordict) == 7 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles] - 0/6: len(td._tensordict) == 6 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles] - 0/5: len(td._tensordict) == 5 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles] - 0/4: len(td._tensordict) == 4 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles] - 0/3: len(td._tensordict) == 3 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles] - 0/2: len(td._tensordict) == 2 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles] - 0/1: len(td._tensordict) == 1 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles] - 0/0: not td._tensordict # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
W0224 02:04:02.543000 875651 torch/_dynamo/convert_frame.py:1676] [0/8] torch._dynamo hit config.recompile_limit (8)
W0224 02:04:02.543000 875651 torch/_dynamo/convert_frame.py:1676] [0/8] function: 'fn' (<ipython-input-1-e15185e7bc2b>:6)
W0224 02:04:02.543000 875651 torch/_dynamo/convert_frame.py:1676] [0/8] last reason: 0/7: len(td._tensordict) == 7 # self._tensordict[key] = value # tensordict/_td.py:2486 in _set_str
W0224 02:04:02.543000 875651 torch/_dynamo/convert_frame.py:1676] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0224 02:04:02.543000 875651 torch/_dynamo/convert_frame.py:1676] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html
9
Root cause
In _set_str (_td.py:2486), the line self._tensordict[key] = value is a Python dict assignment. Dynamo installs a DICT_LENGTH guard on self._tensordict. When the TensorDict has a different number of keys, the guard fails and triggers recompilation.
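To make the mechanism concrete, here is a toy model of a compile cache keyed on the length of the dict seen at function entry. It is only an illustration of the behaviour described above, not Dynamo's implementation: `ToyLengthGuardCache`, `add_key`, and the counters are hypothetical names, and plain Python dicts stand in for `self._tensordict`.

```python
class ToyLengthGuardCache:
    """Toy stand-in for a compile cache guarded on dict length.

    Hypothetical illustration only: it mimics "one compiled variant per
    observed length" and nothing else about Dynamo.
    """

    def __init__(self, fn):
        self.fn = fn
        self.guarded_lengths = set()  # lengths for which a variant "exists"
        self.compile_count = 0

    def __call__(self, d):
        entry_len = len(d)  # guards are checked on the state at entry
        if entry_len not in self.guarded_lengths:
            # No cached variant matches this length: "recompile".
            self.guarded_lengths.add(entry_len)
            self.compile_count += 1
        return self.fn(d)


def add_key(d):
    # Mirrors the reproducer: the function writes one new key.
    d["new"] = 1.0
    return d["new"]


# Case 1: inputs with different numbers of pre-existing keys.
cached = ToyLengthGuardCache(add_key)
for n_keys in range(4):
    cached({f"k{i}": 0.0 for i in range(n_keys)})
compiles_for_distinct_lengths = cached.compile_count

# Case 2: the same dict is reused, so the mutation changes its length.
mutating = ToyLengthGuardCache(add_key)
shared = {"k0": 0.0}
mutating(shared)  # entry length 1, then "new" is added
mutating(shared)  # entry length 2, which no variant covers yet
mutating(shared)  # entry length 2 again, cache hit
compiles_for_mutation = mutating.compile_count
```

In this model every distinct entry length costs one compile, and reusing a dict that the function has grown costs a second compile even though the function itself did not change.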
Suggested fix
Under torch.compiler.is_compiling(), avoid code paths that cause Dynamo to install length-dependent guards on self._tensordict. One approach is to provide a compile-friendly _set_str that treats the internal dict as a dynamic container.
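Separately from the library-side change proposed above, a caller-side workaround might be to keep the key count constant before entering compiled code, by pre-populating every optional or lazily computed key with a placeholder. The sketch below shows the idea on plain dicts; `pad_to_schema` and `DEFAULTS` are hypothetical names, and whether this removes every recompile for a real TensorDict has not been verified here.

```python
def pad_to_schema(sample, defaults):
    """Return a copy of `sample` that contains every key in `defaults`.

    Keys missing from `sample` get the placeholder from `defaults`;
    keys present in `sample` keep their own value.
    """
    padded = dict(defaults)  # start from the full key set, in a fixed order
    padded.update(sample)    # overwrite placeholders with real values
    return padded


# Hypothetical schema: "mask" is optional, "new" is filled in later.
DEFAULTS = {"obs": 0.0, "mask": 0.0, "new": 0.0}

samples = [
    {"obs": 1.0},               # optional field absent
    {"obs": 2.0, "mask": 1.0},  # optional field present
]
padded_samples = [pad_to_schema(s, DEFAULTS) for s in samples]
lengths = {len(p) for p in padded_samples}
```

With a fixed key set, every input has the same length at entry, and writing "new" inside the function overwrites an existing key instead of growing the container.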
System info
Environment:
- uv installed
- Python 3.13.9
- tensordict 0.11.0
- torch 2.10
- Linux aarch64
- NVIDIA GPU
In [2]: import tensordict, numpy, sys, torch
...: print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)
0.11.0 2.2.6 3.13.9 (main, Oct 14 2025, 21:26:54) [Clang 20.1.4 ] linux 2.10.0+cpu
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)