[BUG] torch.compile - __setitem__ recompiles when internal dict length changes #1606

@peterdsharpe

When a compiled function is called on TensorDicts with different numbers of pre-existing keys, Dynamo guards on len(self._tensordict). Each distinct length triggers a recompile.

This also affects in-place mutation: if compiled code adds a key to a TensorDict, the guard recorded for the original (pre-mutation) length fails on the next call with that same TensorDict.

Concrete scenarios that could hit this:

  • A model that receives data with an optional field (present in some samples, absent in others)
  • Any lazy caching pattern where a TensorDict accumulates computed properties over time (this is my use case)
  • A training loop that enriches a TensorDict by adding keys

Reproducer

import torch
from tensordict import TensorDict

torch._dynamo.reset()

def fn(td):
    td["new"] = torch.ones(4)
    return td["new"]

compiled = torch.compile(fn)

for n_keys in range(10):
    td = TensorDict(
        {f"k{i}": torch.randn(4) for i in range(n_keys)},
        batch_size=[4],
    )
    compiled(td)

print(torch._dynamo.utils.counters["frames"]["total"])
# Prints 9 (many recompiles), ideal is 1

Printout:

W0224 02:00:46.201000 874296 torch/_dynamo/convert_frame.py:1676] [0/8] torch._dynamo hit config.recompile_limit (8)
W0224 02:00:46.201000 874296 torch/_dynamo/convert_frame.py:1676] [0/8]    function: 'fn' (<ipython-input-1-1cf5d9213a05>:6)
W0224 02:00:46.201000 874296 torch/_dynamo/convert_frame.py:1676] [0/8]    last reason: 0/7: len(td._tensordict) == 7                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
W0224 02:00:46.201000 874296 torch/_dynamo/convert_frame.py:1676] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0224 02:00:46.201000 874296 torch/_dynamo/convert_frame.py:1676] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html
9

Recompile log

Running the same reproducer with the environment variable TORCH_LOGS="recompiles" set produces:

V0224 02:04:01.601000 875651 torch/_dynamo/guards.py:4514] [0/1] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:01.601000 875651 torch/_dynamo/guards.py:4514] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0224 02:04:01.601000 875651 torch/_dynamo/guards.py:4514] [0/1] [__recompiles]     - 0/0: not td._tensordict                                       # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.695000 875651 torch/_dynamo/guards.py:4514] [0/2] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:01.695000 875651 torch/_dynamo/guards.py:4514] [0/2] [__recompiles]     triggered by the following guard failure(s):
V0224 02:04:01.695000 875651 torch/_dynamo/guards.py:4514] [0/2] [__recompiles]     - 0/1: len(td._tensordict) == 1                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.695000 875651 torch/_dynamo/guards.py:4514] [0/2] [__recompiles]     - 0/0: not td._tensordict                                       # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.786000 875651 torch/_dynamo/guards.py:4514] [0/3] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:01.786000 875651 torch/_dynamo/guards.py:4514] [0/3] [__recompiles]     triggered by the following guard failure(s):
V0224 02:04:01.786000 875651 torch/_dynamo/guards.py:4514] [0/3] [__recompiles]     - 0/2: len(td._tensordict) == 2                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.786000 875651 torch/_dynamo/guards.py:4514] [0/3] [__recompiles]     - 0/1: len(td._tensordict) == 1                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.786000 875651 torch/_dynamo/guards.py:4514] [0/3] [__recompiles]     - 0/0: not td._tensordict                                       # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.881000 875651 torch/_dynamo/guards.py:4514] [0/4] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:01.881000 875651 torch/_dynamo/guards.py:4514] [0/4] [__recompiles]     triggered by the following guard failure(s):
V0224 02:04:01.881000 875651 torch/_dynamo/guards.py:4514] [0/4] [__recompiles]     - 0/3: len(td._tensordict) == 3                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.881000 875651 torch/_dynamo/guards.py:4514] [0/4] [__recompiles]     - 0/2: len(td._tensordict) == 2                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.881000 875651 torch/_dynamo/guards.py:4514] [0/4] [__recompiles]     - 0/1: len(td._tensordict) == 1                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.881000 875651 torch/_dynamo/guards.py:4514] [0/4] [__recompiles]     - 0/0: not td._tensordict                                       # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles]     triggered by the following guard failure(s):
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles]     - 0/4: len(td._tensordict) == 4                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles]     - 0/3: len(td._tensordict) == 3                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles]     - 0/2: len(td._tensordict) == 2                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles]     - 0/1: len(td._tensordict) == 1                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:01.973000 875651 torch/_dynamo/guards.py:4514] [0/5] [__recompiles]     - 0/0: not td._tensordict                                       # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles]     triggered by the following guard failure(s):
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles]     - 0/5: len(td._tensordict) == 5                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles]     - 0/4: len(td._tensordict) == 4                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles]     - 0/3: len(td._tensordict) == 3                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles]     - 0/2: len(td._tensordict) == 2                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles]     - 0/1: len(td._tensordict) == 1                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.356000 875651 torch/_dynamo/guards.py:4514] [0/6] [__recompiles]     - 0/0: not td._tensordict                                       # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles]     triggered by the following guard failure(s):
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles]     - 0/6: len(td._tensordict) == 6                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles]     - 0/5: len(td._tensordict) == 5                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles]     - 0/4: len(td._tensordict) == 4                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles]     - 0/3: len(td._tensordict) == 3                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles]     - 0/2: len(td._tensordict) == 2                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles]     - 0/1: len(td._tensordict) == 1                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.448000 875651 torch/_dynamo/guards.py:4514] [0/7] [__recompiles]     - 0/0: not td._tensordict                                       # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles] Recompiling function fn in <ipython-input-1-e15185e7bc2b>:6
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles]     triggered by the following guard failure(s):
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles]     - 0/7: len(td._tensordict) == 7                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles]     - 0/6: len(td._tensordict) == 6                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles]     - 0/5: len(td._tensordict) == 5                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles]     - 0/4: len(td._tensordict) == 4                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles]     - 0/3: len(td._tensordict) == 3                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles]     - 0/2: len(td._tensordict) == 2                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles]     - 0/1: len(td._tensordict) == 1                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
V0224 02:04:02.542000 875651 torch/_dynamo/guards.py:4514] [0/8] [__recompiles]     - 0/0: not td._tensordict                                       # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
W0224 02:04:02.543000 875651 torch/_dynamo/convert_frame.py:1676] [0/8] torch._dynamo hit config.recompile_limit (8)
W0224 02:04:02.543000 875651 torch/_dynamo/convert_frame.py:1676] [0/8]    function: 'fn' (<ipython-input-1-e15185e7bc2b>:6)
W0224 02:04:02.543000 875651 torch/_dynamo/convert_frame.py:1676] [0/8]    last reason: 0/7: len(td._tensordict) == 7                                 # self._tensordict[key] = value  # tensordict/_td.py:2486 in _set_str
W0224 02:04:02.543000 875651 torch/_dynamo/convert_frame.py:1676] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0224 02:04:02.543000 875651 torch/_dynamo/convert_frame.py:1676] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html
9

Root cause

In _set_str (_td.py:2486), the line self._tensordict[key] = value is a Python dict assignment. Dynamo installs a DICT_LENGTH guard on self._tensordict. When the TensorDict has a different number of keys, the guard fails and triggers recompilation.
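
To make the mechanism concrete, here is a toy model in plain Python of a compile cache whose only guard is the length of the input dict, checked on entry. It is a sketch for illustration, not Dynamo's implementation: LengthGuardedCache and add_key are hypothetical names, and the toy has no recompile limit, so its counts will not match the 9 printed above.

```python
# Toy model of a guard-checked compile cache. This is NOT Dynamo's
# implementation; every name here is hypothetical.

class LengthGuardedCache:
    """Keeps one 'compiled' entry per distinct dict length seen at call entry."""

    def __init__(self, fn):
        self.fn = fn
        self.guarded_lengths = []  # one length guard per cache entry
        self.compile_count = 0

    def __call__(self, d):
        n = len(d)  # guards are checked against the input as it is on entry
        if n not in self.guarded_lengths:
            # No existing entry's guard passes, so "recompile".
            self.guarded_lengths.append(n)
            self.compile_count += 1
        return self.fn(d)


def add_key(d):
    d["new"] = 1.0
    return d["new"]


# Scenario 1: fresh dicts with 0..9 pre-existing keys.
cache = LengthGuardedCache(add_key)
for n_keys in range(10):
    cache({f"k{i}": 0.0 for i in range(n_keys)})
fresh_compiles = cache.compile_count  # one compile per distinct length

# Scenario 2: one dict reused; the first call grows it from 0 keys to 1.
cache = LengthGuardedCache(add_key)
shared = {}
for _ in range(3):
    cache(shared)
reuse_compiles = cache.compile_count  # lengths seen on entry: 0, 1, 1

print(fresh_compiles, reuse_compiles)  # 10 2
```

In the toy, the second scenario recompiles once even though the same function and the same dict are used, because the first call changed the length that the first guard recorded.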

Suggested fix

Under torch.compiler.is_compiling(), avoid code paths that cause Dynamo to install length-dependent guards on self._tensordict. One approach is to provide a compile-friendly _set_str that treats the internal dict as a dynamic container.
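
As a rough illustration of the goal (not of how the fix would be implemented inside tensordict or Dynamo), the plain-Python sketch below compares two guard strategies on a toy compile cache. GuardedCache, count_compiles, and the guard functions are hypothetical names, and guarding on the container type alone is an assumption chosen only to show a length-independent guard.

```python
# Toy comparison of guard strategies. Hypothetical names throughout;
# this does not reflect Dynamo or tensordict internals.

class GuardedCache:
    """Keeps one 'compiled' entry per distinct value of guard_key(input)."""

    def __init__(self, fn, guard_key):
        self.fn = fn
        self.guard_key = guard_key  # maps the input dict to a hashable guard value
        self.entries = set()
        self.compile_count = 0

    def __call__(self, d):
        key = self.guard_key(d)
        if key not in self.entries:
            # No cached entry matches this guard value, so "recompile".
            self.entries.add(key)
            self.compile_count += 1
        return self.fn(d)


def add_key(d):
    d["new"] = 1.0
    return d["new"]


def count_compiles(guard_key):
    """Call add_key on dicts with 0..9 pre-existing keys and count compiles."""
    cache = GuardedCache(add_key, guard_key)
    for n_keys in range(10):
        cache({f"k{i}": 0.0 for i in range(n_keys)})
    return cache.compile_count


length_guard_compiles = count_compiles(len)             # guard on len(d)
type_guard_compiles = count_compiles(lambda d: type(d)) # guard on container type only

print(length_guard_compiles, type_guard_compiles)  # 10 1
```

The function never reads any pre-existing key, so a guard that ignores the dict length is enough for a single cache entry to serve every call in the toy.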

System info

Describe the characteristics of your environment:

  • uv installed
  • Python 3.13.9
  • tensordict 0.11.0
  • torch 2.10
  • Linux aarch64
  • NVIDIA GPU

In [2]: import tensordict, numpy, sys, torch
   ...: print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)
0.11.0 2.2.6 3.13.9 (main, Oct 14 2025, 21:26:54) [Clang 20.1.4 ] linux 2.10.0+cpu

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
