Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pip-dep/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ torch~=1.4.0
websocket_client~=0.57.0
websockets~=8.1.0
zstd~=1.4.4.0
git+https://github.com/facebookresearch/CrypTen.git@68e0364c66df95ddbb98422fb641382c3f58734c#egg=crypten
git+https://github.com/facebookresearch/CrypTen.git@e39a7aaf65436706321fe4e3fc055308c78b6b92#egg=crypten
14 changes: 12 additions & 2 deletions syft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,24 @@
# Tensorflow / Keras dependencies
# Import Hooks

# Public API surface of the syft package. Names are appended below only when
# the optional dependency that provides them is importable (see
# syft.dependency_check), so `from syft import *` never exposes broken names.
# NOTE(review): the scraped diff showed both the old `__all__ = [...]`
# assignment and the new `.extend(...)` call; only the extend form is kept so
# each optional-framework branch accumulates into the single list defined here.
__all__ = []

if dependency_check.tfe_available:
    from syft.frameworks.keras import KerasHook
    from syft.workers.tfe import TFECluster
    from syft.workers.tfe import TFEWorker

    __all__.extend(["KerasHook", "TFECluster", "TFEWorker"])
else:
    logger.info("TF Encrypted Keras not available.")

if dependency_check.crypten_available:
    from syft.frameworks.torch.tensors.crypten.syft_crypten import SyftCrypTensor

    __all__.extend(["SyftCrypTensor"])
else:
    logger.info("CrypTen not available.")


# Pytorch dependencies
# Import Hook
Expand Down Expand Up @@ -113,6 +122,7 @@ def pool():
"AutogradTensor",
"FixedPrecisionTensor",
"LargePrecisionTensor",
"SyftCrypTensor",
"PointerTensor",
"MultiPointerTensor",
"PrivateGridNetwork",
Expand Down
5 changes: 3 additions & 2 deletions syft/dependency_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
except ImportError:
tensorflow_available = False


tfe_spec = util.find_spec("tf_encrypted")
tfe_available = tfe_spec is not None


torch_spec = util.find_spec("torch")
torch_available = torch_spec is not None

crypten_spec = util.find_spec("crypten")
crypten_available = crypten_spec is not None
6 changes: 6 additions & 0 deletions syft/execution/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import torch

import syft as sy
from syft import dependency_check
from syft.generic.frameworks.types import FrameworkTensor
from syft.generic.frameworks.types import FrameworkLayerModule

from syft.generic.object import AbstractObject
from syft.generic.object_storage import ObjectStorage
from syft.generic.pointers.pointer_plan import PointerPlan
Expand All @@ -20,6 +22,9 @@
from syft_proto.execution.v1.plan_pb2 import Plan as PlanPB
from syft_proto.messaging.v1.message_pb2 import OperationMessage as OperationMessagePB

if dependency_check.crypten_available:
import crypten


class func2plan(object):
"""Decorator which converts a function to a plan.
Expand Down Expand Up @@ -383,6 +388,7 @@ def __call__(self, *args, **kwargs):
response = eval(cmd)(*args, **kwargs) # nosec
else:
response = getattr(_self, cmd)(*args, **kwargs)

return_placeholder.instantiate(response.child)

# This ensures that we return the output placeholder in the correct order
Expand Down
18 changes: 12 additions & 6 deletions syft/frameworks/crypten/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import torch
import syft

from syft.frameworks.crypten.context import toy_func, run_party
from syft.frameworks.crypten.context import run_party

import crypten.communicator as comm
import crypten


def load(tag: str, src: int):
if src == comm.get().get_rank():
results = syft.local_worker.search(tag)
# Means the data is on one of our local workers

worker = syft.local_worker.get_worker_from_rank(src)
results = worker.search(tag)

# Make sure there is only one result
assert len(results) == 1
Expand All @@ -22,12 +25,13 @@ def load(tag: str, src: int):
load_type = torch.tensor(0, dtype=torch.long)
comm.get().broadcast(load_type, src=src)

# Broadcast size to other parties.
# Broadcast size to other parties if it was not provided
dim = torch.tensor(result.dim(), dtype=torch.long)
size = torch.tensor(result.size(), dtype=torch.long)

comm.get().broadcast(dim, src=src)
comm.get().broadcast(size, src=src)

result = crypten.mpc.MPCTensor(result, src=src)
else:
raise TypeError("Unrecognized load type on src")
Expand All @@ -39,16 +43,18 @@ def load(tag: str, src: int):

# Load in tensor
if load_type.item() == 0:
# Receive size from source party
# Receive size from source party if it was not provided
dim = torch.empty(size=(), dtype=torch.long)
comm.get().broadcast(dim, src=src)
size = torch.empty(size=(dim.item(),), dtype=torch.long)
comm.get().broadcast(size, src=src)
result = crypten.mpc.MPCTensor(torch.empty(size=tuple(size.tolist())), src=src)
size = tuple(size.tolist())

result = crypten.mpc.MPCTensor(torch.empty(size=size), src=src)
else:
raise TypeError("Unrecognized load type on src")

return result


__all__ = ["toy_func", "run_party", "load"]
__all__ = ["run_party", "load", "get_plain_text"]
40 changes: 24 additions & 16 deletions syft/frameworks/crypten/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import syft as sy
from syft.messaging.message import CryptenInit
from syft.frameworks import crypten as syft_crypt

import crypten
from crypten.communicator import DistributedCommunicator
Expand All @@ -24,7 +23,7 @@ def _launch(func, rank, world_size, master_addr, master_port, queue, func_args,
os.environ[key] = str(val)

crypten.init()
return_value = func(*func_args, **func_kwargs)
return_value = func(*func_args, **func_kwargs).tolist()
crypten.uninit()

queue.put(return_value)
Expand Down Expand Up @@ -65,7 +64,8 @@ def run_party(func, rank, world_size, master_addr, master_port, func_args, func_
process.join()
if was_initialized:
crypten.init()
return queue.get()
res = queue.get()
return res


def _send_party_info(worker, rank, msg, return_values):
Expand All @@ -83,15 +83,7 @@ def _send_party_info(worker, rank, msg, return_values):
return_values[rank] = response.contents


def toy_func():
    """Demo computation removed by this PR: each remote party loads its tagged
    share, the shares are concatenated into one CrypTen tensor, and the
    decrypted result is returned as a plain Python list."""
    alice_tensor = syft_crypt.load("crypten_data", 1)
    bob_tensor = syft_crypt.load("crypten_data", 2)

    crypt = crypten.cat([alice_tensor, bob_tensor], dim=0)
    return crypt.get_plain_text().tolist()


def run_multiworkers(workers: list, master_addr: str, master_port: int = 15987):
def run_multiworkers(workers: list, master_addr: str, master_port: int = 15448):
"""Defines decorator to run function across multiple workers.

Args:
Expand All @@ -100,8 +92,8 @@ def run_multiworkers(workers: list, master_addr: str, master_port: int = 15987):
master_port (int, str): port of the master party (party with rank 0), default is 15987.
"""

def decorator(func):
@functools.wraps(func)
def decorator(plan):
@functools.wraps(plan)
def wrapper(*args, **kwargs):
# TODO:
# - check if workers are reachable / they can handle the computation
Expand All @@ -110,12 +102,28 @@ def wrapper(*args, **kwargs):
world_size = len(workers) + 1
return_values = {rank: None for rank in range(world_size)}

plan.build()

# Mark the plan so the other workers will use that tag to retrieve the plan
plan.tags = ["crypten_plan"]

rank_to_worker_id = dict(
zip(range(1, len(workers) + 1), [worker.id for worker in workers])
)

sy.local_worker._set_rank_to_worker_id(rank_to_worker_id)

for worker in workers:
plan.send(worker)

# Start local party
process, queue = _new_party(toy_func, 0, world_size, master_addr, master_port, (), {})
process, queue = _new_party(plan, 0, world_size, master_addr, master_port, (), {})

was_initialized = DistributedCommunicator.is_initialized()
if was_initialized:
crypten.uninit()
process.start()

# Run TTP if required
# TODO: run ttp in a specified worker
if crypten.mpc.ttp_required():
Expand All @@ -134,7 +142,7 @@ def wrapper(*args, **kwargs):
threads = []
for i in range(len(workers)):
rank = i + 1
msg = CryptenInit((rank, world_size, master_addr, master_port))
msg = CryptenInit((rank_to_worker_id, world_size, master_addr, master_port))
thread = threading.Thread(
target=_send_party_info, args=(workers[i], rank, msg, return_values)
)
Expand Down
22 changes: 22 additions & 0 deletions syft/frameworks/crypten/hook/hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from functools import wraps

import crypten
from syft.generic.frameworks.hook.trace import tracer
from syft.frameworks.torch.tensors.crypten.syft_crypten import SyftCrypTensor
import torch as th


def get_hooked_crypten_func(func_api_name, func):
    """Build a traced replacement for a crypten module-level function.

    Args:
        func_api_name: name of the function inside the ``crypten`` namespace
            (used only to build the trace label ``crypten.<name>``).
        func: the callable to delegate to.

    Returns:
        A wrapper that forwards to ``func`` and returns the result wrapped in
        a ``SyftCrypTensor`` chain.
    """
    traced_name = "crypten." + func_api_name

    @tracer(func_name=traced_name)
    @wraps(func)
    def overloaded_func(*args, **kwargs):
        try:
            raw = func(*args, **kwargs)
        except RuntimeError:
            # Fall back to a scalar zero placeholder when the underlying call
            # raises — presumably on parties that cannot service the request
            # locally (NOTE(review): confirm the intended failure mode).
            raw = th.zeros([])
        return SyftCrypTensor(tensor=raw).wrap()

    return overloaded_func
38 changes: 34 additions & 4 deletions syft/frameworks/torch/hook/hook.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from collections import defaultdict
from functools import wraps
import logging
from math import inf
Expand All @@ -8,6 +9,7 @@
import weakref

import syft
from syft import dependency_check
from syft.generic.frameworks.hook import hook_args
from syft.generic.frameworks.hook.hook import FrameworkHook
from syft.generic.frameworks.hook.trace import Trace
Expand All @@ -34,6 +36,10 @@

from syft.exceptions import route_method_exception

if dependency_check.crypten_available:
import crypten
from syft.frameworks.torch.tensors.crypten.syft_crypten import SyftCrypTensor


class TorchHook(FrameworkHook):
"""A Hook which Overrides Methods on PyTorch Tensors.
Expand Down Expand Up @@ -130,10 +136,15 @@ def __init__(
else:
self.local_worker.hook = self

self.to_auto_overload = {}
self.to_auto_overload = defaultdict(set)

self.args_hook_for_overloaded_attr = {}

# Hook the Crypten module
# We do because SyftCrypTensor (wrapper in Syft) and MPCTensor (from Crypten)
if dependency_check.crypten_available:
self._hook_crypten()

self._hook_native_tensor(torch.Tensor, TorchTensor)

# Add all hooked tensor methods to pointer but change behaviour to have the cmd sent
Expand Down Expand Up @@ -161,6 +172,12 @@ def __init__(
# SyftTensor class file)
self._hook_syft_tensor_methods(FixedPrecisionTensor)

# Add all hooked tensor methods to SyftCrypTensor tensor but change behaviour
# to just forward the cmd to the next child (behaviour can be changed in the
# SyftTensor class file)
if dependency_check.crypten_available:
self._hook_syft_tensor_methods(SyftCrypTensor)

# Add all hooked tensor methods to AutogradTensor tensor but change behaviour
# to just forward the cmd to the next child (behaviour can be changed in the
# SyftTensor class file)
Expand Down Expand Up @@ -254,9 +271,8 @@ def _hook_native_tensor(self, tensor_type: type, syft_type: type):

# Returns a list of methods to be overloaded, stored in the dict to_auto_overload
# with tensor_type as a key
self.to_auto_overload[tensor_type] = self._which_methods_should_we_auto_overload(
tensor_type
)
to_overload = self._which_methods_should_we_auto_overload(tensor_type)
self.to_auto_overload[tensor_type].update(to_overload)

# [We don't rename native methods as torch tensors are not hooked] Rename native functions
# #self._rename_native_functions(tensor_type)
Expand Down Expand Up @@ -472,6 +488,20 @@ def _hook_torch_module(self):

self._perform_function_overloading(module_name, torch_module, func)

def _hook_crypten(self):
    """Hook the crypten module so ``crypten.load`` returns Syft-wrapped tensors.

    The native ``crypten.load`` is saved as ``crypten.native_load`` and the
    attribute is rebound to a traced wrapper built by
    ``get_hooked_crypten_func``. CrypTen-specific method names are also added
    to the ``torch.Tensor`` auto-overload set so wrapped chains forward them.
    """
    # Imported lazily: crypten is an optional dependency guarded by
    # dependency_check.crypten_available at the call site.
    from syft.frameworks.crypten import load as crypten_load
    from syft.frameworks.crypten.hook.hook import get_hooked_crypten_func

    # Keep a handle to the native implementation. Currently we do nothing
    # with it, but it allows restoring/bypassing the hook later.
    crypten.native_load = crypten.load

    # Replace crypten.load with the hooked version (direct attribute access
    # instead of getattr/setattr with constant strings — same effect, idiomatic).
    crypten.load = get_hooked_crypten_func("load", crypten_load)

    # Methods that exist only on crypten tensors but must still be
    # auto-overloaded on torch.Tensor so Syft wrappers can forward them.
    for method_name in ("get_plain_text",):
        self.to_auto_overload[torch.Tensor].add(method_name)

@classmethod
def _get_hooked_func(cls, public_module_name, func_api_name, attr):
"""Torch-specific implementation. See the subclass for more."""
Expand Down
Empty file.
45 changes: 45 additions & 0 deletions syft/frameworks/torch/tensors/crypten/syft_crypten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch

from crypten.mpc import MPCTensor

from syft.generic.frameworks.hook import hook_args
from syft.generic.tensor import AbstractTensor

from functools import wraps
from syft.generic.frameworks.hook.trace import tracer
from syft.generic.frameworks.overload import overloaded


class SyftCrypTensor(AbstractTensor):
    """Syft wrapper chain node holding a CrypTen tensor as its ``child``.

    Allows CrypTen MPC tensors to participate in the Syft tensor-chain
    machinery (hook_args wrapping/unwrapping, tracing, serialization hooks).
    """

    def __init__(
        self, owner=None, id=None, tensor=None, tags: set = None, description: str = None,
    ):
        """Initialize the wrapper.

        Args:
            owner: the worker that owns this tensor.
            id: id of the tensor.
            tensor: the underlying CrypTen tensor stored as ``self.child``.
            tags: tags used for search.
            description: human-readable description.
        """
        super().__init__(id=id, owner=owner, tags=tags, description=description)
        self.child = tensor

    def get_class_attributes(self):
        """
        Specify all the attributes need to build a wrapper correctly when returning a response,
        """
        # TODO: what we should return specific for this one?
        return {}

    @property
    def data(self):
        # Mirrors torch's `.data` access on the wrapper itself.
        return self

    @data.setter
    def data(self, new_data):
        # Fix: the original ended with `return self` — a property setter's
        # return value is always discarded by Python, so it was dead code.
        self.child = new_data.child

    def get_plain_text(self, dst=None):
        """Decrypts the tensor by delegating to the wrapped CrypTen tensor.

        Args:
            dst: optional rank of the single party that should receive the
                plaintext; forwarded unchanged to CrypTen.
        """
        return self.child.get_plain_text(dst=dst)


### Register SyftCrypTensor with hook_args.py so its arguments are ###
### correctly unwrapped/rewrapped by the generic hooking machinery ###
hook_args.default_register_tensor(SyftCrypTensor)

### This is needed to build the wrap around MPCTensor
hook_args.default_register_tensor(MPCTensor)
Loading