Skip to content
Merged
63 changes: 43 additions & 20 deletions syft/frameworks/crypten/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import os

import syft as sy
from syft.messaging.message import CryptenInit
from syft.messaging.message import CryptenInitPlan, CryptenInitJail
from syft.frameworks.crypten import jail, utils

import crypten
from syft.frameworks.crypten.hook.hook import hook_plan_building, unhook_plan_building
Expand All @@ -24,9 +25,10 @@ def _launch(func, rank, world_size, master_addr, master_port, queue, func_args,
os.environ[key] = str(val)

crypten.init()
return_value = func(*func_args, **func_kwargs).tolist()
return_value = func(*func_args, **func_kwargs)
crypten.uninit()

return_value = utils.pack_values(return_value)
queue.put(return_value)


Expand Down Expand Up @@ -69,7 +71,7 @@ def run_party(func, rank, world_size, master_addr, master_port, func_args, func_
return res


def _send_party_info(worker, rank, msg, return_values):
def _send_party_info(worker, rank, msg, return_values, model=None):
"""Send message to worker with necessary information to run a crypten party.
Add response to return_values dictionary.

Expand All @@ -78,10 +80,11 @@ def _send_party_info(worker, rank, msg, return_values):
rank (int): rank of the crypten party.
msg (CryptenInitMessage): message containing the rank, world_size, master_addr and master_port.
return_values (dict): dictionnary holding return values of workers.
model: crypten model to unpack parameters to (if received).
"""

response = worker.send_msg(msg, worker)
return_values[rank] = response.object
return_values[rank] = utils.unpack_values(response.object, model)


def run_multiworkers(workers: list, master_addr: str, master_port: int = 15463):
Expand All @@ -93,8 +96,8 @@ def run_multiworkers(workers: list, master_addr: str, master_port: int = 15463):
master_port (int, str): port of the master party (party with rank 0), default is 15987.
"""

def decorator(plan):
@functools.wraps(plan)
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# TODO:
# - check if workers are reachable / they can handle the computation
Expand All @@ -103,29 +106,42 @@ def wrapper(*args, **kwargs):
world_size = len(workers) + 1
return_values = {rank: None for rank in range(world_size)}

# This is needed because at building we use a set of methods defined in syft (ex: load)
hook_plan_building()
crypten.init()
if isinstance(func, sy.Plan):
using_plan = True
plan = func

plan.build()
# This is needed because at building we use a set of methods defined in syft (ex: load)
hook_plan_building()
crypten.init()
plan.build()
crypten.uninit()
unhook_plan_building()

# Mark the plan so the other workers will use that tag to retrieve the plan
plan.tags = ["crypten_plan"]

for worker in workers:
plan.send(worker)

crypten.uninit()
unhook_plan_building()
jail_or_plan = plan

# Mark the plan so the other workers will use that tag to retrieve the plan
plan.tags = ["crypten_plan"]
else: # func
using_plan = False
jail_runner = jail.JailRunner(func=func)
ser_jail_runner = jail.JailRunner.simplify(jail_runner)

jail_or_plan = jail_runner

rank_to_worker_id = dict(
zip(range(1, len(workers) + 1), [worker.id for worker in workers])
)

sy.local_worker._set_rank_to_worker_id(rank_to_worker_id)

for worker in workers:
plan.send(worker)

# Start local party
process, queue = _new_party(plan, 0, world_size, master_addr, master_port, (), {})
process, queue = _new_party(
jail_or_plan, 0, world_size, master_addr, master_port, (), {}
)

was_initialized = DistributedCommunicator.is_initialized()
if was_initialized:
Expand All @@ -150,7 +166,14 @@ def wrapper(*args, **kwargs):
threads = []
for i in range(len(workers)):
rank = i + 1
msg = CryptenInit((rank_to_worker_id, world_size, master_addr, master_port))
if using_plan:
msg = CryptenInitPlan((rank_to_worker_id, world_size, master_addr, master_port))
else: # jail
msg = CryptenInitJail(
(rank_to_worker_id, world_size, master_addr, master_port),
ser_jail_runner,
None,
)
thread = threading.Thread(
target=_send_party_info, args=(workers[i], rank, msg, return_values)
)
Expand All @@ -159,7 +182,7 @@ def wrapper(*args, **kwargs):

# Wait for local party and sender threads
process.join()
return_values[0] = queue.get()
return_values[0] = utils.unpack_values(queue.get(), None)
for thread in threads:
thread.join()
if was_initialized:
Expand Down
82 changes: 73 additions & 9 deletions syft/messaging/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ def detail(worker: AbstractWorker, msg_tuple: tuple) -> "WorkerCommandMessage":
)


class CryptenInit(Message):
class CryptenInitPlan(Message):
"""Initialize a Crypten party using this message.

Crypten uses processes as parties, those processes need to be initialized with information
Expand All @@ -699,34 +699,98 @@ def contents(self):
return (self.crypten_context,)

@staticmethod
def simplify(worker: AbstractWorker, ptr: "CryptenInit") -> tuple:
def simplify(worker: AbstractWorker, message: "CryptenInitPlan") -> tuple:
"""
This function takes the attributes of a CryptenInit and saves them in a tuple
This function takes the attributes of a CryptenInitPlan and saves them in a tuple

Args:
worker (AbstractWorker): a reference to the worker doing the serialization
ptr (CryptenInit): a Message
ptr (CryptenInitPlan): a Message

Returns:
tuple: a tuple holding the unique attributes of the message
"""
return (sy.serde.msgpack.serde._simplify(worker, ptr.crypten_context),)
return (sy.serde.msgpack.serde._simplify(worker, message.crypten_context),)

@staticmethod
def detail(worker: AbstractWorker, msg_tuple: tuple) -> "CryptenInit":
def detail(worker: AbstractWorker, msg_tuple: tuple) -> "CryptenInitPlan":
"""
This function takes the simplified tuple version of this message and converts
it into an CryptenInit. The simplify() method runs the inverse of this method.
it into a CryptenInitPlan. The simplify() method runs the inverse of this method.

Args:
worker (AbstractWorker): a reference to the worker necessary for detailing. Read
syft/serde/serde.py for more information on why this is necessary.
msg_tuple (Tuple): the raw information being detailed.

Returns:
CryptenInit message.
CryptenInitPlan message.

Examples:
message = detail(sy.local_worker, msg_tuple)
"""
return CryptenInit(sy.serde.msgpack.serde._detail(worker, msg_tuple[0]))
return CryptenInitPlan(sy.serde.msgpack.serde._detail(worker, msg_tuple[0]))


class CryptenInitJail(Message):
"""Initialize a Crypten party using this message.

Crypten uses processes as parties, those processes need to be initialized with information
so they can communicate and exchange tensors and shares while doing computation. This message
allows the exchange of information such as the ip and port of the master party to connect to,
as well as the rank of the party to run and the number of parties involved. Compared to
CryptenInitPlan, this message also sends two extra fields, a JailRunner and a Crypten model."""

def __init__(self, crypten_context, jail_runner, model=None):
# crypten_context = (rank_to_worker_ids, world_size, master_addr, master_port)
self.crypten_context = crypten_context
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess one example of Crypten context would be useful

self.jail_runner = jail_runner
self.model = model

def __str__(self):
"""Return a human readable version of this message"""
return f"({type(self).__name__} {self.crypten_context}, {self.jail_runner})"

@property
def contents(self):
"""Returns a tuple with the contents of the operation (backwards compatability)."""
return (self.crypten_context, self.jail_runner, self.model)

@staticmethod
def simplify(worker: AbstractWorker, message: "CryptenInitJail") -> tuple:
"""
This function takes the attributes of a CryptenInitJail and saves them in a tuple

Args:
worker (AbstractWorker): a reference to the worker doing the serialization
ptr (CryptenInitJail): a Message

Returns:
tuple: a tuple holding the unique attributes of the message
"""
return (
sy.serde.msgpack.serde._simplify(
worker, (*message.crypten_context, message.jail_runner, message.model)
),
)

@staticmethod
def detail(worker: AbstractWorker, msg_tuple: tuple) -> "CryptenInitJail":
"""
This function takes the simplified tuple version of this message and converts
it into a CryptenInitJail. The simplify() method runs the inverse of this method.

Args:
worker (AbstractWorker): a reference to the worker necessary for detailing. Read
syft/serde/serde.py for more information on why this is necessary.
msg_tuple (Tuple): the raw information being detailed.

Returns:
CryptenInitJail message.

Examples:
message = detail(sy.local_worker, msg_tuple)
"""
msg_tuple = sy.serde.msgpack.serde._detail(worker, msg_tuple[0])
*context, jail_runner, model = msg_tuple
return CryptenInitJail(tuple(context), jail_runner, model)
5 changes: 3 additions & 2 deletions syft/serde/msgpack/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from syft.messaging.message import SearchMessage
from syft.messaging.message import PlanCommandMessage
from syft.messaging.message import WorkerCommandMessage
from syft.messaging.message import CryptenInit
from syft.messaging.message import CryptenInitPlan, CryptenInitJail
from syft.serde import compression
from syft.serde import msgpack
from syft.serde.msgpack.native_serde import MAP_NATIVE_SIMPLIFIERS_AND_DETAILERS
Expand Down Expand Up @@ -138,7 +138,8 @@
SearchMessage,
PlanCommandMessage,
WorkerCommandMessage,
CryptenInit,
CryptenInitPlan,
CryptenInitJail,
GradFunc,
String,
BaseDataset,
Expand Down
39 changes: 33 additions & 6 deletions syft/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@
from syft.messaging.message import ForceObjectDeleteMessage
from syft.messaging.message import GetShapeMessage
from syft.messaging.message import IsNoneMessage
from syft.messaging.message import CryptenInit
from syft.messaging.message import CryptenInitPlan
from syft.messaging.message import CryptenInitJail
from syft.messaging.message import Message
from syft.messaging.message import ObjectMessage
from syft.messaging.message import ObjectRequestMessage
from syft.messaging.message import PlanCommandMessage
from syft.messaging.message import SearchMessage
from syft.workers.abstract import AbstractWorker
from syft.messaging.message import CryptenInit

from syft.frameworks.crypten import run_party

Expand Down Expand Up @@ -132,7 +132,8 @@ def __init__(
IsNoneMessage: self.is_object_none,
GetShapeMessage: self.handle_get_shape_message,
SearchMessage: self.respond_to_search,
CryptenInit: self.run_crypten_party,
CryptenInitPlan: self.run_crypten_party_plan,
CryptenInitJail: self.run_crypten_party_jail,
}

self._plan_command_router = {
Expand Down Expand Up @@ -430,11 +431,11 @@ def send(

return pointer

def run_crypten_party(self, message: tuple):
def run_crypten_party_plan(self, message: CryptenInitPlan):
"""Run crypten party according to the information received.

Args:
message (CryptenInit): should contain the rank, world_size, master_addr and master_port.
message (CryptenInitPlan): should contain the rank, world_size, master_addr and master_port.

Returns:
An ObjectMessage containing the return value of the crypten function computed.
Expand All @@ -453,11 +454,37 @@ def run_crypten_party(self, message: tuple):
rank = r
break

assert rank != None
assert rank is not None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for fixing this :D


return_value = run_party(plan, rank, world_size, master_addr, master_port, (), {})
return ObjectMessage(return_value)

def run_crypten_party_jail(self, message: CryptenInitJail):
"""Run crypten party according to the information received.

Args:
message (CryptenInitJail): should contain the rank, world_size, master_addr and master_port.

Returns:
An ObjectMessage containing the return value of the crypten function computed.
"""
from syft.frameworks.crypten.jail import JailRunner

self.rank_to_worker_id, world_size, master_addr, master_port = message.crypten_context
ser_func = message.jail_runner
jail_runner = JailRunner.detail(ser_func)

rank = None
for r, worker_id in self.rank_to_worker_id.items():
if worker_id == self.id:
rank = r
break

assert rank is not None

return_value = run_party(jail_runner, rank, world_size, master_addr, master_port, (), {})
return ObjectMessage(return_value)

def handle_object_msg(self, obj_msg: ObjectMessage):
# This should be a good seam for separating Workers from ObjectStorage (someday),
# so that Workers have ObjectStores instead of being ObjectStores. That would open
Expand Down
6 changes: 2 additions & 4 deletions test/crypten/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@ def plan_func(crypten=crypten):

return_values = plan_func()

# A toy function is ran at each party, and they should all decrypt
# a tensor with value [143, 85]
expected_value = [143, 85, 32, 4]
expected_value = th.tensor([143, 85, 32, 4])
for rank in range(n_workers):
assert (
assert th.all(
return_values[rank] == expected_value
), "Crypten party with rank {} don't match expected value {} != {}".format(
rank, return_values[rank], expected_value
Expand Down
3 changes: 2 additions & 1 deletion test/serde/msgpack/test_msgpack_serde_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@
samples[syft.messaging.message.SearchMessage] = make_searchmessage
samples[syft.messaging.message.PlanCommandMessage] = make_plancommandmessage
samples[syft.messaging.message.WorkerCommandMessage] = make_workercommandmessage
samples[syft.messaging.message.CryptenInit] = make_crypteninit
samples[syft.messaging.message.CryptenInitPlan] = make_crypteninitplan
samples[syft.messaging.message.CryptenInitJail] = make_crypteninitjail

samples[syft.frameworks.torch.tensors.interpreters.gradients_core.GradFunc] = make_gradfn

Expand Down
Loading