
Error When Using Training Example on MPS Device #82

@sam-hey

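When training on an MPS (Apple GPU) device, the training example fails inside torch.compile() (see pytorch/pytorch#96976). The Inductor backend crashes while lowering aten.var_mean for a tensor on the 'mps' device. The underlying TypeError ('NoneType' object is not callable) suggests that the installed PyTorch version has no Inductor code generator registered for MPS.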

There are two workarounds: remove the torch.compile() call, or, as the error message itself suggests, suppress the error so that Dynamo falls back to eager execution:

import torch._dynamo
torch._dynamo.config.suppress_errors = True

With this setting the error is suppressed and training proceeds on the MPS device, with the model running in eager mode instead of as a compiled graph.
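The first workaround can also be made conditional, so the same script still compiles on CUDA or CPU and only skips compilation on MPS. A minimal sketch; `pick_device` and `maybe_compile` are hypothetical helpers, not part of PyLate or PyTorch:

```python
import torch
import torch.nn as nn


def pick_device() -> torch.device:
    """Prefer CUDA, then Apple's MPS backend, then fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def maybe_compile(model: nn.Module, device: torch.device) -> nn.Module:
    """Skip torch.compile() on MPS, where Inductor failed in this report.

    On other devices the model is compiled as in the original example.
    torch.compile() is lazy: the actual compilation happens on the first
    forward pass, not here.
    """
    if device.type == "mps":
        return model
    return torch.compile(model)


# Hypothetical usage in the training script, replacing
# `model = torch.compile(model)`:
# model = maybe_compile(model, pick_device())
```

Unlike `suppress_errors = True`, this does not hide compilation failures on other devices; they still surface as errors.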

The failing call in the training example is:

model = torch.compile(model)

Full output:

No sentence-transformers model found with name bert-base-uncased.
The checkpoint does not contain a linear projection layer. Adding one with output dimensions (768, 128).
Created a PyLate model from base encoder.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The tokenizer does not support resizing the token embeddings, the prefixes token have not been added to vocabulary.
  0%|          | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/__init__.py", line 2234, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 588, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
                     ^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1334, in load
    compiled_graph = compile_fx_fn(
                     ^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 859, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 780, in run
    return super().run(*args)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1319, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1024, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1021, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 361, in wrapped
    out = decomp_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 5264, in var_mean
    return var_mean_helper_(
           ^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 5249, in var_mean_helper_
    else var_mean_welford_(**kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 5203, in var_mean_welford_
    mean, m2, _ = ir.WelfordReduction.create(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/ir.py", line 1608, in create
    hint, split = Reduction.num_splits(
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/ir.py", line 851, in num_splits
    not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 465, in has_feature
    return feature in self.get_backend_features(get_device_type(device))
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py", line 170, in get_backend_features
    return scheduling(None).get_backend_features(device)
           ^^^^^^^^^^^^^^^^
torch._inductor.exc.LoweringException: TypeError: 'NoneType' object is not callable
  target: aten.var_mean.correction
  args[0]: TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FlexibleLayout('mps', torch.float32, size=[2, 32, 768], stride=[24576, 768, 1]), data=Pointwise(
      'mps',
      torch.float32,
      def inner_fn(index):
          i0, i1, i2 = index
          tmp0 = ops.load(primals_2, i1 + 32 * i0)
          tmp1 = ops.load(primals_4, i2 + 768 * tmp0)
          tmp2 = ops.load(primals_1, i1 + 32 * i0)
          tmp3 = ops.load(primals_5, i2 + 768 * tmp2)
          tmp4 = tmp1 + tmp3
          tmp5 = ops.load(primals_3, i1)
          tmp6 = ops.load(primals_6, i2 + 768 * tmp5)
          tmp7 = tmp4 + tmp6
          return tmp7
      ,
      ranges=[2, 32, 768],
      origin_node=add_1,
      origins=OrderedSet([embedding_2, embedding, add, add_1, embed...
    ))
  ))
  args[1]: [2]
  kwargs: {'correction': 0, 'keepdim': True}

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/A200009373/Documents/Coding/pylate/test2-add-embeddings.py", line 79, in <module>
    trainer.train()
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3579, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/sentence_transformers/trainer.py", line 393, in compute_loss
    loss = loss_fn(features, labels)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/pylate/losses/distillation.py", line 83, in forward
    self.model(sentence_features[0])["token_embeddings"], p=2, dim=-1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
           ^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
    self._return(inst)
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
    self.output.compile_subgraph(
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1142, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: TypeError: 'NoneType' object is not callable
  target: aten.var_mean.correction
  args[0]: TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FlexibleLayout('mps', torch.float32, size=[2, 32, 768], stride=[24576, 768, 1]), data=Pointwise(
      'mps',
      torch.float32,
      def inner_fn(index):
          i0, i1, i2 = index
          tmp0 = ops.load(primals_2, i1 + 32 * i0)
          tmp1 = ops.load(primals_4, i2 + 768 * tmp0)
          tmp2 = ops.load(primals_1, i1 + 32 * i0)
          tmp3 = ops.load(primals_5, i2 + 768 * tmp2)
          tmp4 = tmp1 + tmp3
          tmp5 = ops.load(primals_3, i1)
          tmp6 = ops.load(primals_6, i2 + 768 * tmp5)
          tmp7 = tmp4 + tmp6
          return tmp7
      ,
      ranges=[2, 32, 768],
      origin_node=add_1,
      origins=OrderedSet([embedding_2, embedding, add, add_1, embed...
    ))
  ))
  args[1]: [2]
  kwargs: {'correction': 0, 'keepdim': True}

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
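To follow the hint in the message and check whether the failure is specific to PyLate, a standalone sketch could compile a single LayerNorm on MPS with verbose Dynamo logging enabled. This rests on an assumption not confirmed in this report: that the failing aten.var_mean node comes from the LayerNorm applied to BERT's summed embeddings, since the traced input is a sum of three embedding lookups with shape [2, 32, 768]. The function name is hypothetical.

```python
import os

# These are read when torch is imported, so set them before `import torch`.
os.environ["TORCH_LOGS"] = "+dynamo"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

from typing import Optional

import torch
import torch._dynamo
import torch.nn as nn


def compiled_layernorm_works_on_mps() -> Optional[bool]:
    """Hypothetical minimal repro, independent of PyLate.

    Returns None if MPS is unavailable, True if the compiled LayerNorm runs,
    and False if compilation fails with BackendCompilerFailed.
    """
    if not torch.backends.mps.is_available():
        return None
    layer = nn.LayerNorm(768).to("mps")
    compiled = torch.compile(layer)
    # Same shape and dtype as the failing var_mean input: [2, 32, 768] float32.
    x = torch.randn(2, 32, 768, device="mps")
    try:
        compiled(x)  # compilation is triggered by this first call
    except torch._dynamo.exc.BackendCompilerFailed:
        return False
    return True


if __name__ == "__main__":
    print(compiled_layernorm_works_on_mps())
```

If this also returns False on the affected machine, the problem lies in PyTorch's Inductor support for MPS rather than in PyLate.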

Labels: bug (Something isn't working)
