Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,78 +51,70 @@
"metadata": {},
"outputs": [],
"source": [
"class SplitNN:\n",
"import numpy as np\n",
"import torch\n",
"import torchvision\n",
"import matplotlib.pyplot as plt\n",
"from time import time\n",
"from torchvision import datasets, transforms\n",
"from torch import nn, optim\n",
"import syft as sy\n",
"import time\n",
"hook = sy.TorchHook(torch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SplitNN(torch.nn.Module):\n",
"    def __init__(self, models, optimizers):\n",
"        \"\"\"Wrap an ordered chain of model segments and their optimizers.\n",
"\n",
"        Args:\n",
"            models: ordered sequence of model segments; ``forward`` feeds the\n",
"                output of segment ``i - 1`` into segment ``i``.\n",
"            optimizers: optimizers iterated by ``zero_grads`` and ``step``.\n",
"        \"\"\"\n",
"        # nn.Module must be initialised before any attribute is assigned on it.\n",
"        super().__init__()\n",
"        self.models = models\n",
"        self.optimizers = optimizers\n",
"        # Per-segment activations: written by forward(), read by backward().\n",
"        n_segments = len(self.models)\n",
"        self.inputs = [None] * n_segments\n",
"        self.outputs = [None] * n_segments\n",
" \n",
" def forward(self, x):\n",
" a = []\n",
" remote_a = []\n",
" \n",
" a.append(models[0](x))\n",
" if a[-1].location == models[1].location:\n",
" remote_a.append(a[-1].detach().requires_grad_())\n",
" else:\n",
" remote_a.append(a[-1].detach().move(models[1].location).requires_grad_())\n",
"\n",
" i=1 \n",
" while i < (len(models)-1):\n",
" \n",
" a.append(models[i](remote_a[-1]))\n",
" if a[-1].location == models[i+1].location:\n",
" remote_a.append(a[-1].detach().requires_grad_())\n",
" else:\n",
" remote_a.append(a[-1].detach().move(models[i+1].location).requires_grad_())\n",
" \n",
" i+=1\n",
" self.inputs[0] = x\n",
" self.outputs[0] = self.models[0](self.inputs[0])\n",
" \n",
" a.append(models[i](remote_a[-1]))\n",
" self.a = a\n",
" self.remote_a = remote_a\n",
" for i in range(1, len(self.models)):\n",
" self.inputs[i] = self.outputs[i-1].detach().requires_grad_()\n",
" if self.outputs[i-1].location != self.models[i].location:\n",
" self.inputs[i] = self.inputs[i].move(self.models[i].location).requires_grad_() \n",
" self.outputs[i] = self.models[i](self.inputs[i])\n",
" \n",
" return a[-1]\n",
" return self.outputs[-1]\n",
" \n",
" def backward(self):\n",
" a=self.a\n",
" remote_a=self.remote_a\n",
" optimizers = self.optimizers\n",
" \n",
" i= len(models)-2 \n",
" while i > -1:\n",
" if remote_a[i].location == a[i].location:\n",
" grad_a = remote_a[i].grad.copy()\n",
" else:\n",
" grad_a = remote_a[i].grad.copy().move(a[i].location)\n",
" a[i].backward(grad_a)\n",
" i-=1\n",
"\n",
" for i in range(len(self.models)-2, -1, -1):\n",
" grad_in = self.inputs[i+1].grad.copy()\n",
" if self.outputs[i].location != self.inputs[i+1].location:\n",
" grad_in = grad_in.move(self.outputs[i].location)\n",
" self.outputs[i].backward(grad_in)\n",
" \n",
" def zero_grads(self):\n",
" for opt in optimizers:\n",
" for opt in self.optimizers:\n",
" opt.zero_grad()\n",
" \n",
" def step(self):\n",
" for opt in optimizers:\n",
" opt.step()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torchvision\n",
"import matplotlib.pyplot as plt\n",
"from time import time\n",
"from torchvision import datasets, transforms\n",
"from torch import nn, optim\n",
"import syft as sy\n",
"import time\n",
"hook = sy.TorchHook(torch)"
" for opt in self.optimizers:\n",
" opt.step()\n",
" \n",
"    def train(self, mode=True):\n",
"        \"\"\"Switch every model segment to training (or evaluation) mode.\n",
"\n",
"        Args:\n",
"            mode: ``True`` for training mode, ``False`` for evaluation mode;\n",
"                mirrors the signature of ``torch.nn.Module.train``.\n",
"\n",
"        Returns:\n",
"            This ``SplitNN`` instance, so calls can be chained.\n",
"        \"\"\"\n",
"        # The segments live in a plain list, so the flag is propagated by hand.\n",
"        for model in self.models:\n",
"            model.train(mode)\n",
"        return self\n",
" \n",
"    def eval(self):\n",
"        \"\"\"Switch every model segment to evaluation mode.\n",
"\n",
"        Returns:\n",
"            This ``SplitNN`` instance, matching ``torch.nn.Module.eval``.\n",
"        \"\"\"\n",
"        for model in self.models:\n",
"            model.eval()\n",
"        return self\n",
" \n",
"    @property\n",
"    def location(self):\n",
"        \"\"\"Location of the first model segment, or ``None`` when there are no segments.\"\"\"\n",
"        # An empty (or missing) model list has no location to report.\n",
"        if not self.models:\n",
"            return None\n",
"        return self.models[0].location"
]
},
{
Expand Down Expand Up @@ -200,7 +192,7 @@
" \n",
" #2) Make a prediction\n",
" pred = splitNN.forward(x)\n",
" \n",
" \n",
" #3) Figure out how much we missed by\n",
" criterion = nn.NLLLoss()\n",
" loss = criterion(pred, target)\n",
Expand All @@ -225,6 +217,7 @@
"source": [
"for i in range(epochs):\n",
" running_loss = 0\n",
" splitNN.train()\n",
" for images, labels in trainloader:\n",
" images = images.send(models[0].location)\n",
" images = images.view(images.shape[0], -1)\n",
Expand All @@ -235,6 +228,40 @@
" else:\n",
" print(\"Epoch {} - Training loss: {}\".format(i, running_loss/len(trainloader)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def test(model, dataloader, dataset_name):\n",
"    \"\"\"Print the classification accuracy of ``model`` over ``dataloader``.\n",
"\n",
"    Args:\n",
"        model: network exposing ``eval()``, a ``location`` attribute and a call interface.\n",
"        dataloader: iterable of ``(data, target)`` batches that also has a ``dataset`` attribute.\n",
"        dataset_name: label placed at the start of the printed line.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    hits = 0\n",
"    with torch.no_grad():\n",
"        for batch, labels in dataloader:\n",
"            # Flatten each sample, then send the batch to model.location.\n",
"            flat = batch.view(batch.shape[0], -1)\n",
"            remote_batch = flat.send(model.location)\n",
"            scores = model(remote_batch).get()\n",
"            # max(...)[1] holds the index of the highest score in each row.\n",
"            top = scores.data.max(1, keepdim=True)\n",
"            guesses = top[1]\n",
"            hits += guesses.eq(labels.data.view_as(guesses)).sum()\n",
"\n",
"    total = len(dataloader.dataset)\n",
"    accuracy_pct = 100. * hits / total\n",
"    print(\"{}: Accuracy {}/{} ({:.0f}%)\".format(dataset_name, hits, total, accuracy_pct))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build the MNIST test split, then report accuracy on it and on the training loader.\n",
"testset = datasets.MNIST(root='mnist', train=False, download=True, transform=transform)\n",
"testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=64, shuffle=False)\n",
"test(model=splitNN, dataloader=testloader, dataset_name=\"Test set\")\n",
"test(model=splitNN, dataloader=trainloader, dataset_name=\"Train set\")"
]
}
],
"metadata": {
Expand Down