Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions syft/frameworks/torch/hook/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,10 +589,11 @@ def module_is_missing_grad(model):
def create_grad_objects(model):
    """Assign a zeroed gradient to every model parameter that needs one.

    Parameters with ``requires_grad == False`` (e.g. frozen layers) are
    skipped: autograd would refuse to backpropagate through them, and they
    legitimately carry no ``.grad`` object.

    Args:
        model: a ``torch.nn.Module`` whose parameters should receive
            zero-initialized ``.grad`` tensors.
    """
    for p in model.parameters():
        if p.requires_grad:  # check if the object requires a grad object
            # Run a trivial backward pass so autograd allocates p.grad ...
            o = p.sum()
            o.backward()
            # ... then zero it in place so training starts from a clean slate.
            if p.grad is not None:
                p.grad -= p.grad

def module_send_(nn_self, *dest, force_send=False, **kwargs):
"""Overloads torch.nn instances so that they could be sent to other workers"""
Expand Down
42 changes: 39 additions & 3 deletions test/torch/hook/test_hook.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pytest
import torch


# import syft
import syft


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not available")
Expand Down Expand Up @@ -60,3 +58,41 @@ def test_param_data(): # pragma: no cover
param.data = data2
assert (param.data == data2).all()
assert param.is_cuda


def test_send_frozen():
    """A fully frozen model (no parameter requires grad) can still be sent."""
    hook = syft.TorchHook(torch)
    worker = syft.VirtualWorker(hook, id="worker")

    in_dim, hidden_dim, out_dim = 1000, 100, 10

    net = torch.nn.Sequential(
        torch.nn.Linear(in_dim, hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_dim, out_dim),
    )

    # Freeze every parameter so none of them carries a grad object.
    for weight in net.parameters():
        weight.requires_grad = False

    # Must not raise even though no parameter has (or can build) a gradient.
    net.send(worker)


def test_send_partially_frozen():
    """A model with only SOME parameters frozen can be sent to a worker.

    The original loop had a stray ``pass`` as the body of the ``if`` and
    then froze every parameter unconditionally, which made this test an
    exact duplicate of ``test_send_frozen``. The assignment now lives
    under the condition so the model is genuinely partially frozen.
    """
    hook = syft.TorchHook(torch)
    worker = syft.VirtualWorker(hook, id="worker")

    d_in, h1, h2, d_out = 1000, 1000, 100, 10

    model = torch.nn.Sequential(
        torch.nn.Linear(d_in, h1),
        torch.nn.ReLU(),
        torch.nn.Linear(h1, h2),
        torch.nn.ReLU(),
        torch.nn.Linear(h2, d_out),
    )

    # Freeze only the later parameters (indices 3+), leaving the earlier
    # ones trainable. NOTE(review): the original comment said "freezing the
    # first two layers" while the condition was `> 2` — the condition is
    # kept; confirm which layers were meant to be frozen.
    for param_idx, param in enumerate(model.parameters()):
        if param_idx > 2:
            param.requires_grad = False

    model.send(worker)