Open
Labels: bug (Something isn't working)
Description
When attempting to train on an MPS device, the following error occurs due to an issue with torch.compile():
pytorch/pytorch#96976
To work around it, either remove the torch.compile() call or, as the error message suggests, suppress Dynamo errors with the following snippet:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
With this setting, the compilation failure is suppressed and execution falls back to eager mode, so training on the MPS device can proceed.
The call in question is at line 165 in e349918:
model = torch.compile(model)
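A possible alternative to removing the call outright is to skip compilation only when the model lives on an MPS device. The sketch below is an assumption about how that could look, not the project's actual fix: `maybe_compile` is a hypothetical helper, and the compile function is passed in so that the helper itself does not depend on torch.

```python
from typing import Callable, TypeVar

M = TypeVar("M")


def maybe_compile(model: M, device_type: str, compile_fn: Callable[[M], M]) -> M:
    """Return the model unchanged on MPS, otherwise return compile_fn(model).

    `maybe_compile` is a hypothetical helper written for illustration;
    it is not part of PyLate or PyTorch.
    """
    if device_type == "mps":
        # The inductor backend fails on MPS in the traceback of this issue,
        # so keep the model in eager mode there.
        return model
    return compile_fn(model)


# Assumed usage at the call site, where `device` is a torch.device:
#   model = maybe_compile(model, device.type, torch.compile)
```

Unlike setting `suppress_errors`, this keeps compilation errors visible on every other backend, at the cost of one extra device check.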
No sentence-transformers model found with name bert-base-uncased.
The checkpoint does not contain a linear projection layer. Adding one with output dimensions (768, 128).
Created a PyLate model from base encoder.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The tokenizer does not support resizing the token embeddings, the prefixes token have not been added to vocabulary.
0%| | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/__init__.py", line 2234, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 588, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base
return _fw_compiler_base(model, example_inputs, is_inference)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner
compiled_graph = FxGraphCache.load(
^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1334, in load
compiled_graph = compile_fx_fn(
^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile
compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 859, in fx_codegen_and_compile
graph.run(*example_inputs)
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 780, in run
return super().run(*args)
^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1319, in run_node
result = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1024, in call_function
raise LoweringException(e, target, args, kwargs).with_traceback(
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1021, in call_function
out = lowerings[target](*args, **kwargs) # type: ignore[index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 361, in wrapped
out = decomp_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 5264, in var_mean
return var_mean_helper_(
^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 5249, in var_mean_helper_
else var_mean_welford_(**kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 5203, in var_mean_welford_
mean, m2, _ = ir.WelfordReduction.create(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/ir.py", line 1608, in create
hint, split = Reduction.num_splits(
^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/ir.py", line 851, in num_splits
not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 465, in has_feature
return feature in self.get_backend_features(get_device_type(device))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py", line 170, in get_backend_features
return scheduling(None).get_backend_features(device)
^^^^^^^^^^^^^^^^
torch._inductor.exc.LoweringException: TypeError: 'NoneType' object is not callable
target: aten.var_mean.correction
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf0', layout=FlexibleLayout('mps', torch.float32, size=[2, 32, 768], stride=[24576, 768, 1]), data=Pointwise(
'mps',
torch.float32,
def inner_fn(index):
i0, i1, i2 = index
tmp0 = ops.load(primals_2, i1 + 32 * i0)
tmp1 = ops.load(primals_4, i2 + 768 * tmp0)
tmp2 = ops.load(primals_1, i1 + 32 * i0)
tmp3 = ops.load(primals_5, i2 + 768 * tmp2)
tmp4 = tmp1 + tmp3
tmp5 = ops.load(primals_3, i1)
tmp6 = ops.load(primals_6, i2 + 768 * tmp5)
tmp7 = tmp4 + tmp6
return tmp7
,
ranges=[2, 32, 768],
origin_node=add_1,
origins=OrderedSet([embedding_2, embedding, add, add_1, embed...
))
))
args[1]: [2]
kwargs: {'correction': 0, 'keepdim': True}
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/A200009373/Documents/Coding/pylate/test2-add-embeddings.py", line 79, in <module>
trainer.train()
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2123, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3579, in training_step
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/sentence_transformers/trainer.py", line 393, in compute_loss
loss = loss_fn(features, labels)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/pylate/losses/distillation.py", line 83, in forward
self.model(sentence_features[0])["token_embeddings"], p=2, dim=-1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
return _compile(
^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
tracer.run()
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
super().run()
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
while self.step():
^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
self.dispatch_table[inst.opcode](self, inst)
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
self._return(inst)
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
self.output.compile_subgraph(
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1142, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/A200009373/Documents/Coding/pylate/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: TypeError: 'NoneType' object is not callable
target: aten.var_mean.correction
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf0', layout=FlexibleLayout('mps', torch.float32, size=[2, 32, 768], stride=[24576, 768, 1]), data=Pointwise(
'mps',
torch.float32,
def inner_fn(index):
i0, i1, i2 = index
tmp0 = ops.load(primals_2, i1 + 32 * i0)
tmp1 = ops.load(primals_4, i2 + 768 * tmp0)
tmp2 = ops.load(primals_1, i1 + 32 * i0)
tmp3 = ops.load(primals_5, i2 + 768 * tmp2)
tmp4 = tmp1 + tmp3
tmp5 = ops.load(primals_3, i1)
tmp6 = ops.load(primals_6, i2 + 768 * tmp5)
tmp7 = tmp4 + tmp6
return tmp7
,
ranges=[2, 32, 768],
origin_node=add_1,
origins=OrderedSet([embedding_2, embedding, add, add_1, embed...
))
))
args[1]: [2]
kwargs: {'correction': 0, 'keepdim': True}
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True