"\n",
"### What is a SplitNN?\n",
"\n",
"The training of a neural network (NN) is 'split' across one or more hosts. Each model segment is a self-contained NN that feeds into the segment in front of it. In this example, Alice has the unlabeled training data and the bottom of the network, whereas Bob has the corresponding labels and the top of the network. The image below shows this training process, where Bob has all the labels and there are multiple Alices with <i>X</i> data [[1](https://arxiv.org/abs/1810.06060)]. Once $Alice_1$ has trained, she sends a copy of her trained bottom model to the next Alice. This continues until $Alice_n$ has trained.\n",
"\n",
"<img src=\"images/P2P-DL.png\" width=\"50%\" alt=\"dominating_sets_example2\">\n",
"\n",
"In this case, both parties can train the model without knowing each other's data or the full details of the model. When Alice finishes training, she passes her model segment to the next person with data.\n",
"\n",
"### Why use a SplitNN?\n",
"\n",
"The SplitNN has been shown to dramatically reduce the computational burden of training while maintaining higher accuracy when training over a large number of clients [[2](https://arxiv.org/abs/1812.00564)]. In the figure below, the blue line denotes distributed deep learning using SplitNN, the red line indicates federated learning (FL) and the green line indicates Large Batch Stochastic Gradient Descent (LBSGD).\n",
"\n",
"<img src=\"images/AccuracyvsFlops.png\" width=\"100%\">\n",
"\n",
"- The scalability of this approach, in terms of both network and computational resources, could make it a valid alternative to FL and LBSGD, particularly on low-power devices.\n",
"- This could be an effective mechanism for both horizontal and vertical data distributions.\n",
"- As computational cost is already quite low, the cost of applying homomorphic encryption is also minimised.\n",
"- Only activation signal gradients are sent/received, meaning that malicious actors cannot use gradients of model parameters to reverse-engineer the original values.\n",
"\n",
"### Constraints\n",
"\n",
"- As a new technique with little surrounding literature, a large amount of comparison and evaluation is still to be performed.\n",
"- This approach requires all hosts to remain online during the entire learning process (less feasible for hand-held devices).\n",
"- Not as established in privacy-preserving toolkits as FL and LBSGD.\n",
"- Activation signals and their corresponding gradients still have the capacity to leak information; however, this is yet to be fully addressed in the literature.\n",
"\n",
"### Tutorial \n",
"\n",
"This tutorial demonstrates a basic example of a SplitNN which:\n",
"\n",
"- Has two participants: Alice and Bob.\n",
" - Bob has <i>labels</i>\n",
" - Alice has <i>X</i> values\n",
"- Has two model segments.\n",
"\n",
"Authors:\n",
"- Adam J Hall - Twitter: [@AJH4LL](https://twitter.com/AJH4LL) · GitHub: [@H4LL](https://github.com/H4LL)\n",
"- Théo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel) · GitHub: [@LaRiffle](https://github.com/LaRiffle)\n",
"- Haofan Wang - GitHub: [@haofanwang](https://github.com/haofanwang)"
]
},
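The peer-to-peer handoff described above, where each Alice trains the bottom segment on her own data and then passes a copy to the next Alice, can be sketched in plain PyTorch. This is a minimal local sketch without PySyft; the layer sizes and variable names are illustrative only:

```python
import torch
from torch import nn

# Two self-contained segments: Alice holds the bottom, Bob the top.
bottom = nn.Sequential(nn.Linear(784, 128), nn.ReLU())          # Alice_1
top = nn.Sequential(nn.Linear(128, 10), nn.LogSoftmax(dim=1))   # Bob

# After Alice_1 finishes training, she hands a copy of her trained
# bottom segment to Alice_2, who continues training on her own data.
next_bottom = nn.Sequential(nn.Linear(784, 128), nn.ReLU())     # Alice_2
next_bottom.load_state_dict(bottom.state_dict())
```

Only the bottom segment's weights change hands; Bob's top segment and labels never leave his host.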
{
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchvision import datasets, transforms\n",
"from torch import nn, optim\n",
"import syft as sy\n",
"hook = sy.TorchHook(torch)"
]
},
"metadata": {},
"outputs": [],
"source": [
"# Data preprocessing\n",
"transform = transforms.Compose([transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,)),\n",
" ])\n",
"source": [
"torch.manual_seed(0)\n",
"\n",
"# Define our model segments\n",
"\n",
"input_size = 784\n",
"hidden_sizes = [128, 640]\n",
"def train(x, target, models, optimizers):\n",
" # Training Logic\n",
"\n",
" # 1) erase previous gradients (if they exist)\n",
" for opt in optimizers:\n",
" opt.zero_grad()\n",
"\n",
" # 2) make a prediction\n",
" a = models[0](x)\n",
"\n",
" # 3) break the computation graph link, and send the activation signal to the next model\n",
" remote_a = a.detach().move(models[1].location).requires_grad_()\n",
"\n",
" # 4) make a prediction on the next model using the received signal\n",
" pred = models[1](remote_a)\n",
"\n",
" # 5) calculate how much we missed\n",
" criterion = nn.NLLLoss()\n",
" loss = criterion(pred, target)\n",
"\n",
" # 6) figure out which weights caused us to miss\n",
" loss.backward()\n",
"\n",
" # 7) send the gradient of the received activation signal to the model behind\n",
" grad_a = remote_a.grad.copy().move(models[0].location)\n",
"\n",
" # 8) backpropagate on the bottom model given this gradient\n",
" a.backward(grad_a)\n",
"\n",
" # 9) change the weights\n",
" for opt in optimizers:\n",
" opt.step()\n",
"\n",
" # 10) print our progress\n",
" return loss.detach().get()"
]
},
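Stripped of the PySyft plumbing, the ten steps in `train()` above reduce to the following local sketch, where a plain `detach()` stands in for `move()` between workers. The layer sizes, learning rate and batch shape are illustrative assumptions, not values from the notebook:

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
bottom = nn.Sequential(nn.Linear(784, 128), nn.ReLU())          # Alice's segment
top = nn.Sequential(nn.Linear(128, 10), nn.LogSoftmax(dim=1))   # Bob's segment
opt_bottom = optim.SGD(bottom.parameters(), lr=0.1)
opt_top = optim.SGD(top.parameters(), lr=0.1)

x = torch.randn(32, 784)
target = torch.randint(0, 10, (32,))

opt_bottom.zero_grad()                       # 1) erase previous gradients
opt_top.zero_grad()
a = bottom(x)                                # 2) bottom prediction
remote_a = a.detach().requires_grad_()       # 3) break the graph (stand-in for move())
pred = top(remote_a)                         # 4) top prediction on received signal
loss = nn.NLLLoss()(pred, target)            # 5) calculate how much we missed
loss.backward()                              # 6) gradients on the top segment
grad_a = remote_a.grad                       # 7) gradient of the received activation
a.backward(grad_a)                           # 8) backpropagate through the bottom segment
opt_bottom.step()                            # 9) update both segments
opt_top.step()
```

Because the graph is cut at the handoff, only the activation and its gradient cross the boundary; neither side ever sees the other's parameters.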
"We use the exact same model as in the previous tutorial, only this time we are splitting over three hosts rather than two. However, we see the same loss being reported, as there is <b>no reduction in accuracy</b> when training in this way. While we only use three models, this can be done for any arbitrary number of models.\n",
"\n",
"Authors:\n",
"- Adam J Hall - Twitter: [@AJH4LL](https://twitter.com/AJH4LL) · GitHub: [@H4LL](https://github.com/H4LL)\n",
"- Haofan Wang - GitHub: [@haofanwang](https://github.com/haofanwang)"
]
},
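The chain of an arbitrary number of segments can be sketched locally as follows. This is an illustrative reimplementation without PySyft, with a detached handoff standing in for each `move()` between hosts; the segment sizes are assumptions:

```python
import torch
from torch import nn

# Three segments on three (simulated) hosts; the chain generalises to any depth.
segments = [
    nn.Sequential(nn.Linear(784, 128), nn.ReLU()),
    nn.Sequential(nn.Linear(128, 64), nn.ReLU()),
    nn.Sequential(nn.Linear(64, 10), nn.LogSoftmax(dim=1)),
]

def forward_chain(x, segments):
    """Forward pass with a detached handoff at every host boundary."""
    inputs, activations = [], []
    a = x
    for seg in segments:
        a = a.detach().requires_grad_()   # stand-in for move() between hosts
        inputs.append(a)
        a = seg(a)
        activations.append(a)
    return inputs, activations

def backward_chain(inputs, activations, loss):
    """Backward pass, handing each activation gradient to the host behind."""
    loss.backward()
    for a, inp in zip(reversed(activations[:-1]), reversed(inputs[1:])):
        a.backward(inp.grad)

x = torch.randn(8, 784)
target = torch.randint(0, 10, (8,))
inputs, activations = forward_chain(x, segments)
loss = nn.NLLLoss()(activations[-1], target)
backward_chain(inputs, activations, loss)
```

Each boundary only ever exchanges an activation forwards and its gradient backwards, regardless of how many hosts are in the chain.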
{
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchvision import datasets, transforms\n",
"from torch import nn, optim\n",
"import syft as sy\n",
"hook = sy.TorchHook(torch)"
]
},
"metadata": {},
"outputs": [],
"source": [
"# Data preprocessing\n",
"transform = transforms.Compose([transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,)),\n",
" ])\n",
"\n",
"<b>Recap:</b> The previous tutorial looked at building a SplitNN. The NN was split into three segments on three separate hosts, where one host had the data and another had the labels. However, what if someone has data and labels in the same place? \n",
"\n",
"<b>Description: </b> Here we fold a multilayer SplitNN in on itself in order to accommodate the data and labels being in the same place. We demonstrate the SplitNN class with a three-segment distribution. This time,\n",
"\n",
"<img src=\"images/FoldedNN.png\" width=\"20%\">\n",
"\n",
"Again, we use the exact same model as in the previous tutorial and see the same accuracy. Neither Alice nor Bob has the full model, and Bob can't see Alice's data. \n",
"\n",
"Authors:\n",
"- Adam J Hall - Twitter: [@AJH4LL](https://twitter.com/AJH4LL) · GitHub: [@H4LL](https://github.com/H4LL)\n",
"- Haofan Wang - GitHub: [@haofanwang](https://github.com/haofanwang)"
]
},
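The folded arrangement can be sketched locally in plain PyTorch: Alice holds the first and last segments together with the data and labels, while Bob holds only the middle segment, so neither party sees the full model. This is an illustrative sketch without PySyft, with assumed layer sizes:

```python
import torch
from torch import nn

alice_bottom = nn.Sequential(nn.Linear(784, 128), nn.ReLU())        # Alice
bob_middle = nn.Sequential(nn.Linear(128, 64), nn.ReLU())           # Bob
alice_top = nn.Sequential(nn.Linear(64, 10), nn.LogSoftmax(dim=1))  # Alice

x = torch.randn(8, 784)                  # Alice's data
target = torch.randint(0, 10, (8,))      # Alice's labels

a1 = alice_bottom(x)
b_in = a1.detach().requires_grad_()      # handoff Alice -> Bob
a2 = bob_middle(b_in)
t_in = a2.detach().requires_grad_()      # handoff Bob -> Alice
loss = nn.NLLLoss()(alice_top(t_in), target)

loss.backward()                          # gradients on Alice's top segment
a2.backward(t_in.grad)                   # hand the gradient back to Bob
a1.backward(b_in.grad)                   # and back to Alice's bottom segment
```

Bob only ever receives a detached activation, and Alice's labels never leave her host.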
{
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchvision import datasets, transforms\n",
"from torch import nn, optim\n",
"import syft as sy\n",
"hook = sy.TorchHook(torch)"
]
},
"metadata": {},
"outputs": [],
"source": [
"# Data preprocessing\n",
"transform = transforms.Compose([transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,)),\n",
" ])\n",
" model.send(location)\n",
"\n",
"# Instantiate a SplitNN class with our distributed segments and their respective optimizers\n",
"splitNN = SplitNN(models, optimizers)"
]
},
{
2 changes: 2 additions & 0 deletions syft/workers/virtual.py

class VirtualWorker(BaseWorker, FederatedClient):
def _send_msg(self, message: bin, location: BaseWorker) -> bin:
"""send message to worker location"""
if self.message_pending_time > 0:
if self.verbose:
print(f"pending time of {self.message_pending_time} seconds to send message...")
return location._recv_msg(message)

def _recv_msg(self, message: bin) -> bin:
"""receieve message"""
return self.recv_msg(message)
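The pending-time behaviour above can be illustrated with a minimal stand-in worker. `LocalWorker` is hypothetical and not part of PySyft; it only mimics the delayed `_send_msg`/`_recv_msg` handshake:

```python
import time

class LocalWorker:
    """Hypothetical stand-in for a worker that can simulate network latency.

    The message_pending_time attribute is modelled on the PySyft
    VirtualWorker behaviour shown above.
    """
    def __init__(self, message_pending_time=0.0):
        self.message_pending_time = message_pending_time
        self.inbox = []

    def _send_msg(self, message, location):
        if self.message_pending_time > 0:
            time.sleep(self.message_pending_time)  # simulate network delay
        return location._recv_msg(message)

    def _recv_msg(self, message):
        self.inbox.append(message)
        return message

alice = LocalWorker(message_pending_time=0.01)
bob = LocalWorker()
alice._send_msg(b"activation", bob)
```

The delay is applied by the sender before the receiver is invoked, which is how a pending time slows down every message on the wire.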
24 changes: 16 additions & 8 deletions test/federated/test_federated_client.py
PRINT_IN_UNITTESTS = False

# To make execution deterministic to some extent
# For more information, refer to https://pytorch.org/docs/stable/notes/randomness.html
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def test_add_dataset():
# Create a client to execute federated learning
fed_client = FederatedClient()

# Create a dataset
dataset = "my_dataset"
fed_client.add_dataset(dataset, "string_dataset")
key = "string_dataset"
# Add new dataset
fed_client.add_dataset(dataset, key)

assert "string_dataset" in fed_client.datasets


def test_add_dataset_with_duplicate_key():
# Create a client to execute federated learning
fed_client = FederatedClient()

# Create a dataset
dataset = "my_dataset"
fed_client.add_dataset(dataset, "string_dataset")
key = "string_dataset"
# Add new dataset
fed_client.add_dataset(dataset, key)

assert "string_dataset" in fed_client.datasets

# Adding a dataset under an already-existing key raises an error
with pytest.raises(ValueError):
fed_client.add_dataset(dataset, "string_dataset")


def test_remove_dataset():
# Create a client to execute federated learning
fed_client = FederatedClient()

# Create a dataset
dataset = "my_dataset"
key = "string_dataset"
# Add new dataset
fed_client.add_dataset(dataset, key)

assert key in fed_client.datasets

# Remove the dataset
fed_client.remove_dataset(key)

assert key not in fed_client.datasets