-
-
Notifications
You must be signed in to change notification settings - Fork 2k
Closed
Labels
Type: Bug 🐛 — Some functionality not working in the codebase as intended
Description
Describe the bug
I'm trying to finetune a alexnet model and i've set the parameters except for the final layer of the model to requires_grad=False and have created a new classification layer with the desired outputs i want. However the .send() function keeps throwing a runtime error `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
# Reproduction script for the reported PySyft bug: hook PyTorch, freeze a
# pretrained AlexNet's parameters, replace the final classifier layer with a
# fresh head, then send the model to a virtual worker. The final .send() call
# raises:
#   RuntimeError: element 0 of tensors does not require grad and does not
#   have a grad_fn
# because syft's create_grad_objects() calls p.sum().backward() on every
# parameter, including the frozen ones (requires_grad=False).
import syft
import torch
from torchvision import models
import torch.nn as nn

# Hook torch so tensors/modules gain the PySyft .send() API.
hook = syft.TorchHook(torch)
# In-process virtual worker that will receive the model.
worker = syft.VirtualWorker(hook, id="worker")

model = models.alexnet(pretrained=True)
# Freeze every pretrained parameter (feature-extraction style fine-tuning).
# NOTE: indentation of the loop body was lost in the original paste; restored
# here so the snippet actually runs.
for param in model.parameters():
    param.requires_grad = False
# Replace the last classifier layer with a new 3-class head. The new layer's
# parameters default to requires_grad=True, so the model is mixed
# frozen/trainable at this point.
model.classifier[6] = nn.Linear(model.classifier[6].in_features, 3)
# Fails here: syft.hook.module_send_ -> create_grad_objects -> o.backward()
# on a frozen parameter triggers the RuntimeError.
model.send(worker)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-15-a250859d9a13> in <module>
----> 1 model.send(worker)
~/implementation/PyGrid/gateway/src/syft/syft/frameworks/torch/hook/hook.py in module_send_(nn_self, force_send, *dest, **kwargs)
608
609 if module_is_missing_grad(nn_self):
--> 610 create_grad_objects(nn_self)
611
612 for p in nn_self.parameters():
~/implementation/PyGrid/gateway/src/syft/syft/frameworks/torch/hook/hook.py in create_grad_objects(model)
600 for p in model.parameters():
601 o = p.sum()
--> 602 o.backward()
603 if p.grad is not None:
604 p.grad -= p.grad
~/implementation/PyGrid/gateway/src/syft/syft/generic/frameworks/hook/trace.py in trace_wrapper(*args, **kwargs)
81 syft.hook.trace.logs.append((command, response))
82 else:
---> 83 response = func(*args, **kwargs)
84
85 return response
~/implementation/PyGrid/gateway/src/syft/syft/generic/frameworks/hook/hook.py in overloaded_native_method(self, *args, **kwargs)
436 except BaseException as e:
437 # we can make some errors more descriptive with this method
--> 438 raise route_method_exception(e, self, args, kwargs)
439
440 else: # means that there is a wrapper to remove
~/implementation/PyGrid/gateway/src/syft/syft/generic/frameworks/hook/hook.py in overloaded_native_method(self, *args, **kwargs)
432
433 try:
--> 434 response = method(*args, **kwargs)
435
436 except BaseException as e:
~/anaconda3/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
193 products. Defaults to ``False``.
194 """
--> 195 torch.autograd.backward(self, gradient, retain_graph, create_graph)
196
197 def register_hook(self, hook):
~/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
97 Variable._execution_engine.run_backward(
98 tensors, grad_tensors, retain_graph, create_graph,
---> 99 allow_unreachable=True) # allow_unreachable flag
100
101
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
Type: Bug 🐛 — Some functionality not working in the codebase as intended