Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion examples/tutorials/Part 02 - Intro to Federated Learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,65 @@
"train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we will do the same training with a federated optimizer, which maintains a separate optimizer for each worker. This lets us use stateful optimizers such as Adam, whose internal caches must be kept per worker."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from syft.federated.federated_optimizer import FLOptimier\n",
"model = nn.Linear(2,1)\n",
"\n",
"def train_with_adam():\n",
" # Training Logic\n",
" fl_opt = FLOptimier(optim.Adam, lr=0.1)\n",
" for iter in range(10): \n",
"\n",
" # NEW) iterate through each worker's dataset\n",
" for data,target in datasets:\n",
" \n",
" # NEW) send model to correct worker\n",
" model.send(data.location)\n",
"        # get the optimizer relevant for the current location of the model \n",
" opt = fl_opt.get_optimizer(model)\n",
" # 1) erase previous gradients (if they exist)\n",
" opt.zero_grad()\n",
"\n",
" # 2) make a prediction\n",
" pred = model(data)\n",
"\n",
" # 3) calculate how much we missed\n",
" loss = ((pred - target)**2).sum()\n",
"\n",
" # 4) figure out which weights caused us to miss\n",
" loss.backward()\n",
"\n",
" # 5) change those weights\n",
" opt.step()\n",
" \n",
" model.get()\n",
"\n",
" # 6) print our progress\n",
" print(loss.get()) # NEW) slight edit... need to call .get() on loss\\\n",
"# federated averaging"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_with_adam()"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -298,4 +357,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
34 changes: 34 additions & 0 deletions syft/federated/federated_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""maintains an optimizer for each worker"""

from collections import defaultdict


class FLOptimier:
    """Maintains one optimizer instance per worker.

    Stateful optimizers (e.g. Adam) keep internal per-parameter caches, so in a
    federated setting each worker needs its own optimizer instance. This class
    lazily creates and caches one optimizer per model location.

    NOTE(review): the class name is misspelled ("FLOptimier" instead of
    "FLOptimizer") but is kept as-is because it is the public, imported name.
    """

    def __init__(self, optimizer_class, **kwargs):
        """
        Args:
            optimizer_class: class of the pytorch optimizer
            kwargs: arguments to be forwarded to the optimizer class
        """
        self.optimizer_class = optimizer_class
        # Plain dict keyed by worker location (or "central" for a local model).
        # The original defaultdict() had no default_factory, so it behaved
        # exactly like a dict anyway.
        self.opt_dict = {}
        self.kwargs = kwargs

    def get_optimizer(self, model):
        """Return the cached optimizer for the model's current worker,
        creating it on first use.

        Args:
            model: model belonging to a worker; if it has a ``location``
                attribute (a remote pointer), that location is used as the
                cache key, otherwise the model is treated as local.

        Returns:
            The optimizer instance associated with the model's location.
        """
        # "central" marks a model that has not been sent to any worker.
        key = model.location if hasattr(model, "location") else "central"
        if key not in self.opt_dict:
            # Construct the optimizer only on a cache miss. The original code
            # passed the constructor call as setdefault's default argument,
            # which built (and usually discarded) a fresh optimizer on EVERY
            # call.
            # NOTE(review): the cached optimizer holds references to the
            # parameters captured at first use; if the model is re-sent and
            # its remote parameter pointers change, the cache may be stale —
            # confirm against syft's send/get semantics.
            self.opt_dict[key] = self.optimizer_class(
                model.parameters(), **self.kwargs
            )
        return self.opt_dict[key]