diff --git a/pip-dep/requirements.txt b/pip-dep/requirements.txt index 67b22a5f7e5..88e354ab154 100644 --- a/pip-dep/requirements.txt +++ b/pip-dep/requirements.txt @@ -14,4 +14,4 @@ torch~=1.4.0 websocket_client~=0.57.0 websockets~=8.1.0 zstd~=1.4.4.0 -git+https://github.com/facebookresearch/CrypTen.git@68e0364c66df95ddbb98422fb641382c3f58734c#egg=crypten +git+https://github.com/facebookresearch/CrypTen.git@e39a7aaf65436706321fe4e3fc055308c78b6b92#egg=crypten diff --git a/syft/__init__.py b/syft/__init__.py index 1e680101b21..bc7fb758611 100644 --- a/syft/__init__.py +++ b/syft/__init__.py @@ -25,15 +25,24 @@ # Tensorflow / Keras dependencies # Import Hooks +__all__ = [] + if dependency_check.tfe_available: from syft.frameworks.keras import KerasHook from syft.workers.tfe import TFECluster from syft.workers.tfe import TFEWorker - __all__ = ["KerasHook", "TFECluster", "TFEWorker"] + __all__.extend(["KerasHook", "TFECluster", "TFEWorker"]) else: logger.info("TF Encrypted Keras not available.") - __all__ = [] + +if dependency_check.crypten_available: + from syft.frameworks.torch.tensors.crypten.syft_crypten import SyftCrypTensor + + __all__.extend(["SyftCrypTensor"]) +else: + logger.info("CrypTen not available.") + # Pytorch dependencies # Import Hook @@ -113,6 +122,7 @@ def pool(): "AutogradTensor", "FixedPrecisionTensor", "LargePrecisionTensor", + "SyftCrypTensor", "PointerTensor", "MultiPointerTensor", "PrivateGridNetwork", diff --git a/syft/dependency_check.py b/syft/dependency_check.py index cf250ee514b..5c89aecd567 100644 --- a/syft/dependency_check.py +++ b/syft/dependency_check.py @@ -14,10 +14,11 @@ except ImportError: tensorflow_available = False - tfe_spec = util.find_spec("tf_encrypted") tfe_available = tfe_spec is not None - torch_spec = util.find_spec("torch") torch_available = torch_spec is not None + +crypten_spec = util.find_spec("crypten") +crypten_available = crypten_spec is not None diff --git a/syft/execution/plan.py b/syft/execution/plan.py index 5c803f130bc..e990c8646f6 100644 --- a/syft/execution/plan.py +++ b/syft/execution/plan.py @@ -7,8 +7,10 @@ import torch import syft as sy +from syft import dependency_check from syft.generic.frameworks.types import FrameworkTensor from syft.generic.frameworks.types import FrameworkLayerModule + from syft.generic.object import AbstractObject from syft.generic.object_storage import ObjectStorage from syft.generic.pointers.pointer_plan import PointerPlan @@ -20,6 +22,9 @@ from syft_proto.execution.v1.plan_pb2 import Plan as PlanPB from syft_proto.messaging.v1.message_pb2 import OperationMessage as OperationMessagePB +if dependency_check.crypten_available: + import crypten + class func2plan(object): """Decorator which converts a function to a plan. @@ -383,6 +388,7 @@ def __call__(self, *args, **kwargs): response = eval(cmd)(*args, **kwargs) # nosec else: response = getattr(_self, cmd)(*args, **kwargs) + return_placeholder.instantiate(response.child) # This ensures that we return the output placeholder in the correct order diff --git a/syft/frameworks/crypten/__init__.py b/syft/frameworks/crypten/__init__.py index 0ec2d7b7b08..43e7cca8a71 100644 --- a/syft/frameworks/crypten/__init__.py +++ b/syft/frameworks/crypten/__init__.py @@ -1,7 +1,7 @@ import torch import syft -from syft.frameworks.crypten.context import toy_func, run_party +from syft.frameworks.crypten.context import run_party import crypten.communicator as comm import crypten @@ -9,7 +9,10 @@ def load(tag: str, src: int): if src == comm.get().get_rank(): - results = syft.local_worker.search(tag) + # Means the data is on one of our local workers + + worker = syft.local_worker.get_worker_from_rank(src) + results = worker.search(tag) # Make sure there is only one result assert len(results) == 1 @@ -22,12 +25,13 @@ def load(tag: str, src: int): load_type = torch.tensor(0, dtype=torch.long) comm.get().broadcast(load_type, src=src) - # Broadcast size to other parties. + # Broadcast size to other parties if it was not provided dim = torch.tensor(result.dim(), dtype=torch.long) size = torch.tensor(result.size(), dtype=torch.long) comm.get().broadcast(dim, src=src) comm.get().broadcast(size, src=src) + result = crypten.mpc.MPCTensor(result, src=src) else: raise TypeError("Unrecognized load type on src") @@ -39,16 +43,18 @@ def load(tag: str, src: int): # Load in tensor if load_type.item() == 0: - # Receive size from source party + # Receive size from source party if it was not provided dim = torch.empty(size=(), dtype=torch.long) comm.get().broadcast(dim, src=src) size = torch.empty(size=(dim.item(),), dtype=torch.long) comm.get().broadcast(size, src=src) - result = crypten.mpc.MPCTensor(torch.empty(size=tuple(size.tolist())), src=src) + size = tuple(size.tolist()) + + result = crypten.mpc.MPCTensor(torch.empty(size=size), src=src) else: raise TypeError("Unrecognized load type on src") return result -__all__ = ["toy_func", "run_party", "load"] +__all__ = ["run_party", "load", "get_plain_text"] diff --git a/syft/frameworks/crypten/context.py b/syft/frameworks/crypten/context.py index f295779b630..197f042cc3f 100644 --- a/syft/frameworks/crypten/context.py +++ b/syft/frameworks/crypten/context.py @@ -5,7 +5,6 @@ import syft as sy from syft.messaging.message import CryptenInit -from syft.frameworks import crypten as syft_crypt import crypten from crypten.communicator import DistributedCommunicator @@ -24,7 +23,7 @@ def _launch(func, rank, world_size, master_addr, master_port, queue, func_args, os.environ[key] = str(val) crypten.init() - return_value = func(*func_args, **func_kwargs) + return_value = func(*func_args, **func_kwargs).tolist() crypten.uninit() queue.put(return_value) @@ -65,7 +64,8 @@ def run_party(func, rank, world_size, master_addr, master_port, func_args, func_ process.join() if was_initialized: crypten.init() - return queue.get() + res = queue.get() + return res def _send_party_info(worker, rank, msg, return_values): @@ -83,15 +83,7 @@ def _send_party_info(worker, rank, msg, return_values): return_values[rank] = response.contents -def toy_func(): - alice_tensor = syft_crypt.load("crypten_data", 1) - bob_tensor = syft_crypt.load("crypten_data", 2) - - crypt = crypten.cat([alice_tensor, bob_tensor], dim=0) - return crypt.get_plain_text().tolist() - - -def run_multiworkers(workers: list, master_addr: str, master_port: int = 15987): +def run_multiworkers(workers: list, master_addr: str, master_port: int = 15448): """Defines decorator to run function across multiple workers. Args: @@ -100,8 +92,8 @@ def run_multiworkers(workers: list, master_addr: str, master_port: int = 15987): master_port (int, str): port of the master party (party with rank 0), default is 15987. """ - def decorator(func): - @functools.wraps(func) + def decorator(plan): + @functools.wraps(plan) def wrapper(*args, **kwargs): # TODO: # - check if workers are reachable / they can handle the computation @@ -110,12 +102,28 @@ def wrapper(*args, **kwargs): world_size = len(workers) + 1 return_values = {rank: None for rank in range(world_size)} + plan.build() + + # Mark the plan so the other workers will use that tag to retrieve the plan + plan.tags = ["crypten_plan"] + + rank_to_worker_id = dict( + zip(range(1, len(workers) + 1), [worker.id for worker in workers]) + ) + + sy.local_worker._set_rank_to_worker_id(rank_to_worker_id) + + for worker in workers: + plan.send(worker) + # Start local party - process, queue = _new_party(toy_func, 0, world_size, master_addr, master_port, (), {}) + process, queue = _new_party(plan, 0, world_size, master_addr, master_port, (), {}) + was_initialized = DistributedCommunicator.is_initialized() if was_initialized: crypten.uninit() process.start() + # Run TTP if required # TODO: run ttp in a specified worker if crypten.mpc.ttp_required(): @@ -134,7 +142,7 @@ def wrapper(*args, **kwargs): threads = [] for i in range(len(workers)): rank = i + 1 - msg = CryptenInit((rank, world_size, master_addr, master_port)) + msg = CryptenInit((rank_to_worker_id, world_size, master_addr, master_port)) thread = threading.Thread( target=_send_party_info, args=(workers[i], rank, msg, return_values) ) diff --git a/syft/frameworks/crypten/hook/hook.py b/syft/frameworks/crypten/hook/hook.py new file mode 100644 index 00000000000..413c54e5748 --- /dev/null +++ b/syft/frameworks/crypten/hook/hook.py @@ -0,0 +1,22 @@ +from functools import wraps + +import crypten +from syft.generic.frameworks.hook.trace import tracer +from syft.frameworks.torch.tensors.crypten.syft_crypten import SyftCrypTensor +import torch as th + + +def get_hooked_crypten_func(func_api_name, func): + cmd_name = f"crypten.{func_api_name}" + + @tracer(func_name=cmd_name) + @wraps(func) + def overloaded_func(*args, **kwargs): + try: + response = SyftCrypTensor(tensor=func(*args, **kwargs)).wrap() + except RuntimeError: + response = SyftCrypTensor(tensor=th.zeros([])).wrap() + + return response + + return overloaded_func diff --git a/syft/frameworks/torch/hook/hook.py b/syft/frameworks/torch/hook/hook.py index 9ac865f15dd..45645c8ad02 100644 --- a/syft/frameworks/torch/hook/hook.py +++ b/syft/frameworks/torch/hook/hook.py @@ -1,4 +1,5 @@ import copy +from collections import defaultdict from functools import wraps import logging from math import inf @@ -8,6 +9,7 @@ import weakref import syft +from syft import dependency_check from syft.generic.frameworks.hook import hook_args from syft.generic.frameworks.hook.hook import FrameworkHook from syft.generic.frameworks.hook.trace import Trace @@ -34,6 +36,10 @@ from syft.exceptions import route_method_exception +if dependency_check.crypten_available: + import crypten + from syft.frameworks.torch.tensors.crypten.syft_crypten import SyftCrypTensor + class TorchHook(FrameworkHook): """A Hook which Overrides Methods on PyTorch Tensors. @@ -130,10 +136,15 @@ def __init__( else: self.local_worker.hook = self - self.to_auto_overload = {} + self.to_auto_overload = defaultdict(set) self.args_hook_for_overloaded_attr = {} + # Hook the Crypten module + # We do because SyftCrypTensor (wrapper in Syft) and MPCTensor (from Crypten) + if dependency_check.crypten_available: + self._hook_crypten() + self._hook_native_tensor(torch.Tensor, TorchTensor) # Add all hooked tensor methods to pointer but change behaviour to have the cmd sent @@ -161,6 +172,12 @@ def __init__( # SyftTensor class file) self._hook_syft_tensor_methods(FixedPrecisionTensor) + # Add all hooked tensor methods to SyftCrypTensor tensor but change behaviour + # to just forward the cmd to the next child (behaviour can be changed in the + # SyftTensor class file) + if dependency_check.crypten_available: + self._hook_syft_tensor_methods(SyftCrypTensor) + # Add all hooked tensor methods to AutogradTensor tensor but change behaviour # to just forward the cmd to the next child (behaviour can be changed in the # SyftTensor class file) @@ -254,9 +271,8 @@ def _hook_native_tensor(self, tensor_type: type, syft_type: type): # Returns a list of methods to be overloaded, stored in the dict to_auto_overload # with tensor_type as a key - self.to_auto_overload[tensor_type] = self._which_methods_should_we_auto_overload( - tensor_type - ) + to_overload = self._which_methods_should_we_auto_overload(tensor_type) + self.to_auto_overload[tensor_type].update(to_overload) # [We don't rename native methods as torch tensors are not hooked] Rename native functions # #self._rename_native_functions(tensor_type) @@ -472,6 +488,20 @@ def _hook_torch_module(self): self._perform_function_overloading(module_name, torch_module, func) + def _hook_crypten(self): + from syft.frameworks.crypten import load as crypten_load + from syft.frameworks.crypten.hook.hook import get_hooked_crypten_func + + native_func = getattr(crypten, "load") + setattr(crypten, "native_load", native_func) # Currenty we do nothing with the native load + + new_func = get_hooked_crypten_func("load", crypten_load) + setattr(crypten, "load", new_func) + + crypten_specific_methods = ["get_plain_text"] + for method in crypten_specific_methods: + self.to_auto_overload[torch.Tensor].add(method) + @classmethod def _get_hooked_func(cls, public_module_name, func_api_name, attr): """Torch-specific implementation. See the subclass for more.""" diff --git a/syft/frameworks/torch/tensors/crypten/__init__.py b/syft/frameworks/torch/tensors/crypten/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/syft/frameworks/torch/tensors/crypten/syft_crypten.py b/syft/frameworks/torch/tensors/crypten/syft_crypten.py new file mode 100644 index 00000000000..e893cdaca85 --- /dev/null +++ b/syft/frameworks/torch/tensors/crypten/syft_crypten.py @@ -0,0 +1,45 @@ +import torch + +from crypten.mpc import MPCTensor + +from syft.generic.frameworks.hook import hook_args +from syft.generic.tensor import AbstractTensor + +from functools import wraps +from syft.generic.frameworks.hook.trace import tracer +from syft.generic.frameworks.overload import overloaded + + +class SyftCrypTensor(AbstractTensor): + def __init__( + self, owner=None, id=None, tensor=None, tags: set = None, description: str = None, + ): + super().__init__(id=id, owner=owner, tags=tags, description=description) + self.child = tensor + + def get_class_attributes(self): + """ + Specify all the attributes need to build a wrapper correctly when returning a response, + """ + # TODO: what we should return specific for this one? + return {} + + @property + def data(self): + return self + + @data.setter + def data(self, new_data): + self.child = new_data.child + return self + + def get_plain_text(self, dst=None): + """Decrypts the tensor.""" + return self.child.get_plain_text(dst=dst) + + +### Register the tensor with hook_args.py ### +hook_args.default_register_tensor(SyftCrypTensor) + +### This is needed to build the wrap around MPCTensor +hook_args.default_register_tensor(MPCTensor) diff --git a/syft/frameworks/torch/tensors/interpreters/native.py b/syft/frameworks/torch/tensors/interpreters/native.py index 253a9024acb..d6c2ff2ba55 100644 --- a/syft/frameworks/torch/tensors/interpreters/native.py +++ b/syft/frameworks/torch/tensors/interpreters/native.py @@ -830,6 +830,11 @@ def fix_prec_(self, *args, **kwargs): fix_precision_ = fix_prec_ + def get_plain_text(self, *args, **kwargs): + """Required for CrypTen -- for SyftCrypTensor + In case we reach this the simply return a copy of the tensor""" + return self.copy() + def _requires_large_precision(self, max_precision, base, precision_fractional): """Check if any of the elements in the tensor would require large precision. """ diff --git a/syft/generic/frameworks/hook/hook.py b/syft/generic/frameworks/hook/hook.py index c03b34a0ac0..ac814098884 100644 --- a/syft/generic/frameworks/hook/hook.py +++ b/syft/generic/frameworks/hook/hook.py @@ -398,7 +398,7 @@ def _get_hooked_method(cls, tensor_type, method_name): their child attribute if they exist If so, forward this method with the new args and new self, get response and "rebuild" the torch tensor wrapper upon all tensors found - If not, just execute the native torch methodn + If not, just execute the native torch method Args: attr (str): the method to hook @@ -501,7 +501,7 @@ def _get_hooked_private_method(cls, method_name): their child attribute if they exist If so, forward this method with the new args and new self, get response and "rebuild" the torch tensor wrapper upon all tensors found - If not, just execute the native torch methodn + If not, just execute the native torch method Args: attr (str): the method to hook @@ -539,7 +539,6 @@ def overloaded_native_method(self, *args, **kwargs): response = method(*new_args, **new_kwargs) response.parents = (self.id, new_self.id) - # For inplace methods, just directly return self if syft.framework.is_inplace_method(method_name): return self @@ -689,7 +688,7 @@ def overloaded_attr(self, *args, **kwargs): def _string_input_args_adaptor(cls, args: Tuple[object]): """ This method is used when hooking String methods. - + Some 'String' methods which are overriden from 'str' such as the magic '__add__' method expects an object of type 'str' as its first @@ -697,17 +696,17 @@ def _string_input_args_adaptor(cls, args: Tuple[object]): here is hooked to a String type, it will receive arguments of type 'String' not 'str' in some cases. This won't worker for the underlying hooked method - '__add__' of the 'str' type. + '__add__' of the 'str' type. That is why the 'String' argument to '__add__' should be peeled down to 'str' - + Args: args: A tuple or positional arguments of the method being hooked to the String class. Return: A list of adapted positional arguments. - + """ new_args = [] @@ -739,7 +738,7 @@ def _wrap_str_return_value(cls, _self, attr: str, value: object): @classmethod def _get_hooked_string_method(cls, attr): """ - Hook a `str` method to a corresponding method of + Hook a `str` method to a corresponding method of `String` with the same name. Args: @@ -772,7 +771,7 @@ def overloaded_attr(_self, *args, **kwargs): @classmethod def _get_hooked_string_pointer_method(cls, attr): """ - Hook a `String` method to a corresponding method of + Hook a `String` method to a corresponding method of `StringPointer` with the same name. Args: diff --git a/syft/generic/frameworks/hook/hook_args.py b/syft/generic/frameworks/hook/hook_args.py index 0bee3187423..aaa75f681cb 100644 --- a/syft/generic/frameworks/hook/hook_args.py +++ b/syft/generic/frameworks/hook/hook_args.py @@ -228,6 +228,7 @@ def hook_response(attr, response, wrap_type, wrap_args={}, new_self=None): hash_wrap_args = hash(frozenset(wrap_args.items())) attr_id = f"{attr}@{wrap_type.__name__}.{response_is_tuple}.{hash_wrap_args}" + # import pdb; pdb.set_trace() try: assert attr not in ambiguous_functions @@ -263,6 +264,7 @@ def build_wrap_reponse_from_function(response, wrap_type, wrap_args): # Inspect the call to find tensor arguments and return a rule whose # structure is the same as the response object, with 1 where there was # (framework or syft) tensors and 0 when not (ex: number, str, ...) + rule = build_rule(response) # Build a function with this rule to efficiently replace syft tensors # (but not pointer) with their child in the args objects diff --git a/syft/generic/frameworks/hook/trace.py b/syft/generic/frameworks/hook/trace.py index 027f8105d54..244d97c8b68 100644 --- a/syft/generic/frameworks/hook/trace.py +++ b/syft/generic/frameworks/hook/trace.py @@ -75,6 +75,7 @@ def trace_wrapper(*args, **kwargs): syft.hook.trace.out_of_operation = False response = func(*args, **kwargs) + print(f"This is it {response}") syft.hook.trace.out_of_operation = True diff --git a/syft/generic/frameworks/types.py b/syft/generic/frameworks/types.py index d80cc68d3f6..9ce6a06115e 100644 --- a/syft/generic/frameworks/types.py +++ b/syft/generic/frameworks/types.py @@ -28,6 +28,11 @@ framework_layer_module.named_tensors = torch.nn.Module.named_parameters framework_layer_modules.append(framework_layer_module) +if dependency_check.crypten_available: + import crypten + + framework_tensors.append(crypten.mpc.MPCTensor) + framework_tensors = tuple(framework_tensors) FrameworkTensorType = Union[framework_tensors] FrameworkTensor = framework_tensors diff --git a/syft/generic/tensor.py b/syft/generic/tensor.py index 3bc97e005a9..b2b2ad8bc17 100644 --- a/syft/generic/tensor.py +++ b/syft/generic/tensor.py @@ -6,6 +6,11 @@ from syft.generic.object import _apply_args from syft.generic.object import AbstractObject from syft.generic.object import initialize_object +from syft.generic.frameworks.overload import overloaded + +from syft.generic.frameworks.hook.trace import tracer +import crypten +import torch as th class AbstractTensor(AbstractObject): diff --git a/syft/workers/base.py b/syft/workers/base.py index b73f0c4133f..95f855034ee 100644 --- a/syft/workers/base.py +++ b/syft/workers/base.py @@ -32,7 +32,8 @@ from syft.execution.plan import Plan from syft.messaging.message import CryptenInit from syft.workers.abstract import AbstractWorker -from syft.frameworks.crypten import toy_func, run_party +from syft.frameworks.crypten import run_party + from syft.exceptions import GetNotPermittedError from syft.exceptions import WorkerNotFoundException @@ -179,6 +180,8 @@ def __init__( self.tensorflow = self.framework self.remote = Remote(self, "tensorflow") + self.rank_to_worker_id = None + # SECTION: Methods which MUST be overridden by subclasses @abstractmethod def _send_msg(self, message: bin, location: "BaseWorker"): @@ -414,10 +417,22 @@ def run_crypten_party(self, message: tuple): An ObjectMessage containing the return value of the crypten function computed. """ - rank, world_size, master_addr, master_port = message + self.rank_to_worker_id, world_size, master_addr, master_port = message + + plans = self.search("crypten_plan") + assert len(plans) == 1 + + plan = plans[0].get() - return_value = run_party(toy_func, rank, world_size, master_addr, master_port, (), {}) + rank = None + for r, worker_id in self.rank_to_worker_id.items(): + if worker_id == self.id: + rank = r + break + assert rank != None + + return_value = run_party(plan, rank, world_size, master_addr, master_port, (), {}) return ObjectMessage(return_value) def execute_command(self, message: tuple) -> PointerTensor: @@ -701,6 +716,12 @@ def get_worker( else: return self._get_worker(id_or_worker) + def get_worker_from_rank(self, rank: int): + return self._get_worker_based_on_id(self.rank_to_worker_id[rank]) + + def _set_rank_to_worker_id(self, rank_to_worker_id): + self.rank_to_worker_id = rank_to_worker_id + def _get_worker(self, worker: AbstractWorker): if worker.id not in self._known_workers: self.add_worker(worker) diff --git a/test/crypten/test_context.py b/test/crypten/test_context.py index 389856d610c..b5523d2a181 100644 --- a/test/crypten/test_context.py +++ b/test/crypten/test_context.py @@ -1,6 +1,8 @@ import pytest +import syft as sy from syft.frameworks.crypten.context import run_multiworkers import torch as th +import crypten def test_context(workers): @@ -10,18 +12,24 @@ def test_context(workers): alice = workers["alice"] bob = workers["bob"] - alice_tensor_ptr = th.tensor([42, 53]).tag("crypten_data").send(alice) - bob_tensor_ptr = th.tensor([101, 32]).tag("crypten_data").send(bob) + alice_tensor_ptr = th.tensor([42, 53, 3, 2]).tag("crypten_data").send(alice) + bob_tensor_ptr = th.tensor([101, 32, 29, 2]).tag("crypten_data").send(bob) @run_multiworkers([alice, bob], master_addr="127.0.0.1") - def test_three_parties(): - pass # pragma: no cover + @sy.func2plan() + def plan_func(): + alice_tensor = crypten.load("crypten_data", 1) + bob_tensor = crypten.load("crypten_data", 2) - return_values = test_three_parties() + crypt = alice_tensor + bob_tensor + result = crypt.get_plain_text() + return result + + return_values = plan_func() # A toy function is ran at each party, and they should all decrypt - # a tensor with value [42, 53, 101, 32] - expected_value = [42, 53, 101, 32] + # a tensor with value [143, 85] + expected_value = [143, 85, 32, 4] for rank in range(n_workers): assert ( return_values[rank] == expected_value