diff --git a/syft/__init__.py b/syft/__init__.py index 30cfbb71f44..092e2bc96dc 100644 --- a/syft/__init__.py +++ b/syft/__init__.py @@ -54,6 +54,7 @@ from syft.messaging.plan import Plan from syft.messaging.plan import func2plan from syft.messaging.plan import method2plan +from syft.messaging.promise import Promise # Import Worker Types from syft.workers.virtual import VirtualWorker @@ -67,6 +68,7 @@ from syft.frameworks.torch.tensors.interpreters.autograd import AutogradTensor from syft.frameworks.torch.tensors.interpreters.precision import FixedPrecisionTensor from syft.frameworks.torch.tensors.interpreters.large_precision import LargePrecisionTensor +from syft.frameworks.torch.tensors.interpreters.promise import PromiseTensor from syft.generic.pointers.pointer_plan import PointerPlan from syft.generic.pointers.pointer_protocol import PointerProtocol from syft.generic.pointers.pointer_tensor import PointerTensor diff --git a/syft/frameworks/torch/hook/hook.py b/syft/frameworks/torch/hook/hook.py index f8062e62761..2b90178ba8c 100644 --- a/syft/frameworks/torch/hook/hook.py +++ b/syft/frameworks/torch/hook/hook.py @@ -10,9 +10,11 @@ import syft from syft.generic.frameworks.hook import hook_args from syft.generic.frameworks.hook.hook import FrameworkHook +from syft.generic.tensor import AbstractTensor from syft.generic.frameworks.remote import Remote from syft.frameworks.torch.tensors.interpreters.autograd import AutogradTensor from syft.frameworks.torch.tensors.interpreters.native import TorchTensor +from syft.frameworks.torch.tensors.interpreters.promise import PromiseTensor from syft.frameworks.torch.tensors.interpreters.paillier import PaillierTensor from syft.frameworks.torch.tensors.decorators.logging import LoggingTensor from syft.frameworks.torch.tensors.interpreters.precision import FixedPrecisionTensor @@ -26,6 +28,7 @@ from syft.workers.base import BaseWorker from syft.workers.virtual import VirtualWorker from syft.messaging.plan import Plan +from 
syft.messaging.promise import Promise from syft.exceptions import route_method_exception from syft.exceptions import TensorsNotCollocatedException @@ -168,6 +171,9 @@ def __init__( # Add all hooked tensor methods to LargePrecisionTensor tensor self._hook_syft_tensor_methods(LargePrecisionTensor) + # Add all hooked tensor methods to PromiseTensor + self._hook_promise_tensor() + # Hook the tensor constructor function self._hook_tensor() @@ -504,6 +510,103 @@ def overloaded_attr(self, *args, **kwargs): return overloaded_attr + def _hook_promise_tensor(hook_self): + + methods_to_hook = hook_self.to_auto_overload[torch.Tensor] + + def generate_method(method_name): + def method(self, *args, **kwargs): + + arg_shapes = list([self.shape]) + arg_ids = list([self.id]) + + # Convert scalar arguments to tensors to be able to use them with plans + args = list(args) + for ia in range(len(args)): + if not isinstance(args[ia], (torch.Tensor, AbstractTensor)): + args[ia] = torch.tensor(args[ia]) + + for arg in args: + arg_shapes.append(arg.shape) + + @syft.func2plan(arg_shapes) + def operation_method(self, *args, **kwargs): + return getattr(self, method_name)(*args, **kwargs) + + self.plans.add(operation_method.id) + for arg in args: + if isinstance(arg, PromiseTensor): + arg.plans.add(operation_method.id) + + operation_method.procedure.update_args( + [self, *args], operation_method.procedure.result_ids + ) + + promise_out = PromiseTensor( + owner=self.owner, + shape=operation_method.output_shape, + tensor_type=self.obj_type, + plans=set(), + ) + operation_method.procedure.promise_out_id = promise_out.id + + if operation_method.owner != self.owner: + operation_method.send(self.owner) + else: # otherwise object not registered on local worker + operation_method.owner.register_obj(operation_method) + + return promise_out + + return method + + for method_name in methods_to_hook: + setattr(PromiseTensor, method_name, generate_method(method_name)) + + def FloatTensor(shape, *args, 
**kwargs): + return PromiseTensor(shape, tensor_type="torch.FloatTensor", *args, **kwargs).wrap() + + setattr(Promise, "FloatTensor", FloatTensor) + + def DoubleTensor(shape, *args, **kwargs): + return PromiseTensor(shape, tensor_type="torch.DoubleTensor", *args, **kwargs).wrap() + + setattr(Promise, "DoubleTensor", DoubleTensor) + + def HalfTensor(shape, *args, **kwargs): + return PromiseTensor(shape, tensor_type="torch.HalfTensor", *args, **kwargs).wrap() + + setattr(Promise, "HalfTensor", HalfTensor) + + def ByteTensor(shape, *args, **kwargs): + return PromiseTensor(shape, tensor_type="torch.ByteTensor", *args, **kwargs).wrap() + + setattr(Promise, "ByteTensor", ByteTensor) + + def CharTensor(shape, *args, **kwargs): + return PromiseTensor(shape, tensor_type="torch.CharTensor", *args, **kwargs).wrap() + + setattr(Promise, "CharTensor", CharTensor) + + def ShortTensor(shape, *args, **kwargs): + return PromiseTensor(shape, tensor_type="torch.ShortTensor", *args, **kwargs).wrap() + + setattr(Promise, "ShortTensor", ShortTensor) + + def IntTensor(shape, *args, **kwargs): + return PromiseTensor(shape, tensor_type="torch.IntTensor", *args, **kwargs).wrap() + + setattr(Promise, "IntTensor", IntTensor) + + def LongTensor(shape, *args, **kwargs): + return PromiseTensor(shape, tensor_type="torch.LongTensor", *args, **kwargs).wrap() + + setattr(Promise, "LongTensor", LongTensor) + + def BoolTensor(shape, *args, **kwargs): + return PromiseTensor(shape, tensor_type="torch.BoolTensor", *args, **kwargs).wrap() + + setattr(Promise, "BoolTensor", BoolTensor) + def _hook_tensor(hook_self): """Hooks the function torch.tensor() We need to do this separately from hooking the class because internally diff --git a/syft/frameworks/torch/tensors/interpreters/__init__.py b/syft/frameworks/torch/tensors/interpreters/__init__.py index e69de29bb2d..8b137891791 100644 --- a/syft/frameworks/torch/tensors/interpreters/__init__.py +++ b/syft/frameworks/torch/tensors/interpreters/__init__.py @@ 
-0,0 +1 @@ + diff --git a/syft/frameworks/torch/tensors/interpreters/native.py b/syft/frameworks/torch/tensors/interpreters/native.py index 08f2aa29323..57a8f9ce780 100644 --- a/syft/frameworks/torch/tensors/interpreters/native.py +++ b/syft/frameworks/torch/tensors/interpreters/native.py @@ -11,6 +11,7 @@ from syft.generic.frameworks.hook import hook_args from syft.generic.frameworks.overload import overloaded from syft.frameworks.torch.tensors.interpreters.crt_precision import _moduli_for_fields +from syft.frameworks.torch.tensors.interpreters.promise import PromiseTensor from syft.frameworks.torch.tensors.interpreters.paillier import PaillierTensor from syft.generic.frameworks.types import FrameworkTensor from syft.generic.tensor import AbstractTensor @@ -583,6 +584,10 @@ def move(self, location): self.child.owner.register_obj(self) return self + def remote_send(self, location, change_location=False): + self.child.remote_send(location, change_location) + return self + def attr(self, attr_name): """""" @@ -814,6 +819,25 @@ def combine(self, *pointers): return syft.combine_pointers(*ps) + def keep(self, obj): + """ Call .keep() on self's child if the child is a Promise (otherwise an error is raised). + .keep() is used to fulfill a promise with a value. + """ + return self.child.keep(obj) + + def value(self): + """ Call .value() on self's child if the child is a Promise (otherwise an error is raised). + .value() is used to retrieve the oldest unused value the promise was kept with. + """ + return self.child.value() + + def torch_type(self): + + if isinstance(self, torch.Tensor) and not self.is_wrapper: + return self.type() + else: + return self.child.torch_type() + def encrypt(self, public_key): """This method will encrypt each value in the tensor using Paillier homomorphic encryption. 
diff --git a/syft/frameworks/torch/tensors/interpreters/promise.py b/syft/frameworks/torch/tensors/interpreters/promise.py new file mode 100644 index 00000000000..d7f3983e975 --- /dev/null +++ b/syft/frameworks/torch/tensors/interpreters/promise.py @@ -0,0 +1,110 @@ +import syft as sy +from syft.workers.abstract import AbstractWorker +import weakref + +from syft.generic.tensor import AbstractTensor +from syft.generic.tensor import initialize_tensor +from syft.messaging.promise import Promise +from syft.generic.frameworks.hook import hook_args + + +class PromiseTensor(AbstractTensor, Promise): + def __init__( + self, shape, owner=None, id=None, tensor_type=None, plans=None, tags=None, description=None, + ): + """Initializes a PromiseTensor + + Args: + shape: the shape that should have the tensors keeping the promise. + owner: an optional BaseWorker object to specify the worker on which + the tensor is located. + id: an optional string or integer id of the PromiseTensor. + tensor_type: the type that should have the tensors keeping the promise. + plans: the ids of the plans waiting for the promise to be kept. When the promise is + kept, all the plans corresponding to these ids will be executed if the other + promises they were waiting for are also kept. + tags: an optional set of hashtags corresponding to this tensor + which this tensor should be searchable for. + description: an optional string describing the purpose of the + tensor. 
+ """ + + if owner is None: + owner = sy.local_worker + + # constructors for AbstractTensor and Promise + AbstractTensor.__init__(self, id=id, owner=owner, tags=tags, description=description) + Promise.__init__(self, owner=owner, obj_type=tensor_type, plans=plans) + + self._shape = shape + + del self.child + + def torch_type(self): + return self.obj_type + + @property + def shape(self): + return self._shape + + @property + def grad(self): + return None + # if not hasattr(self, "_grad"): + # self._grad = PromiseTensor(shape=self._shape, tensor_type=self.torch_type()).wrap() + # + # return self._grad + + def __str__(self): + return f"[PromiseTensor({self.owner.id}:{self.id}) -future-> {self.obj_type.split('.')[-1]} -blocking-> {len(self.plans)} plans]" + + def __repr__(self): + return self.__str__() + + @staticmethod + def simplify(worker: AbstractWorker, tensor: "PromiseTensor") -> tuple: + """Takes the attributes of a PromiseTensor and saves them in a tuple. + + Args: + tensor: a PromiseTensor. + + Returns: + tuple: a tuple holding the unique attributes of the promise tensor. + """ + + return ( + sy.serde._simplify(worker, tensor.id), + sy.serde._simplify(worker, tensor.shape), + sy.serde._simplify(worker, tensor.obj_type), + sy.serde._simplify(worker, tensor.plans), + ) + + @staticmethod + def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "PromiseTensor": + """ + This function reconstructs a PromiseTensor given its attributes in form of a tuple. 
+ Args: + worker: the worker doing the deserialization + tensor_tuple: a tuple holding the attributes of the FixedPrecisionTensor + Returns: + FixedPrecisionTensor: a FixedPrecisionTensor + Examples: + shared_tensor = detail(data) + """ + + id, shape, tensor_type, plans = tensor_tuple + + id = sy.serde._detail(worker, id) + shape = sy.serde._detail(worker, shape) + tensor_type = sy.serde._detail(worker, tensor_type) + plans = sy.serde._detail(worker, plans) + + tensor = PromiseTensor( + owner=worker, id=id, shape=shape, tensor_type=tensor_type, plans=plans + ) + + return tensor + + +### Register the tensor with hook_args.py ### +hook_args.default_register_tensor(PromiseTensor) diff --git a/syft/generic/object_storage.py b/syft/generic/object_storage.py index 620cb25f06a..45f73644c69 100644 --- a/syft/generic/object_storage.py +++ b/syft/generic/object_storage.py @@ -15,6 +15,7 @@ class ObjectStorage: """ def __init__(self): + # This is the collection of objects being stored. self._objects = {} def register_obj(self, obj: object, obj_id: Union[str, int] = None): @@ -103,7 +104,7 @@ def force_rm_obj(self, remote_key: Union[str, int]): """ if remote_key in self._objects: obj = self._objects[remote_key] - if hasattr(obj, "child") and obj.child is not None: + if hasattr(obj, "child") and hasattr(obj.child, "garbage_collect_data"): obj.child.garbage_collect_data = True del self._objects[remote_key] diff --git a/syft/generic/pointers/pointer_tensor.py b/syft/generic/pointers/pointer_tensor.py index cdbaa61bdaa..1ab8bfd9efd 100644 --- a/syft/generic/pointers/pointer_tensor.py +++ b/syft/generic/pointers/pointer_tensor.py @@ -265,6 +265,22 @@ def move(self, location): ptr.garbage_collect_data = False return ptr + def remote_send(self, destination, change_location=False): + """ Request the worker where the tensor being pointed to belongs to send it to destination. 
+ For instance, if C holds a pointer, ptr, to a tensor on A and calls ptr.remote_send(B), + C will hold a pointer to a pointer on A which points to the tensor on B. + If change_location is set to True, the original pointer will point to the moved object. + Considering the same example as before with ptr.remote_send(B, change_location=True): + C will hold a pointer to the tensor on B. We may need to be careful here because this pointer + will have 2 references pointing to it. + """ + args = (destination,) + kwargs = {"inplace": True} + self.owner.send_command(message=("send", self, args, kwargs), recipient=self.location) + if change_location: + self.location = destination + return self + def remote_get(self): self.owner.send_command(message=("mid_get", self, (), {}), recipient=self.location) return self @@ -369,6 +385,34 @@ def share(self, *args, **kwargs): return response + def keep(self, *args, **kwargs): + """ + Send a command to remote worker to keep a promise + + Returns: + A pointer to a Tensor + """ + + # Send the command + command = ("keep", self, args, kwargs) + + response = self.owner.send_command(self.location, command) + + return response + + def value(self, *args, **kwargs): + """ + Send a command to remote worker to get the result generated by a promise. 
+ + Returns: + A pointer to a Tensor + """ + command = ("value", self, args, kwargs) + + response = self.owner.send_command(self.location, command) + + return response + def share_(self, *args, **kwargs): """ Send a command to remote worker to additively share inplace a tensor diff --git a/syft/messaging/__init__.py b/syft/messaging/__init__.py index e69de29bb2d..8b137891791 100644 --- a/syft/messaging/__init__.py +++ b/syft/messaging/__init__.py @@ -0,0 +1 @@ + diff --git a/syft/messaging/plan/plan.py b/syft/messaging/plan/plan.py index 8be9f757f73..23bc6788fff 100644 --- a/syft/messaging/plan/plan.py +++ b/syft/messaging/plan/plan.py @@ -12,6 +12,7 @@ from syft.messaging.plan.procedure import Procedure from syft.messaging.plan.state import State from syft.workers.abstract import AbstractWorker +from syft.frameworks.torch.tensors.interpreters.promise import PromiseTensor class func2plan(object): @@ -122,6 +123,8 @@ def __init__( self.owner.register_obj(tensor) self.include_state = include_state self.is_built = is_built + self.input_shapes = None + self._output_shape = None # The plan has not been sent self.pointers = dict() @@ -163,6 +166,20 @@ def parameters(self): """ return self.state.tensors() + @property + def output_shape(self): + if self._output_shape is None: + args = self._create_placeholders(self.input_shapes) + # NOTE I currently need to regiser and then remove objects to use the method + # but a better syntax is being worked on + for arg in args: + self.owner.register_obj(arg) + output = self(*args) + for arg in args: + self.owner.rm_obj(arg) + self._output_shape = output.shape + return self._output_shape + def send_msg(self, *args, **kwargs): return self.owner.send_msg(*args, **kwargs) @@ -211,6 +228,9 @@ def build(self, *args): args: Input data. 
""" + self.input_shapes = [x.shape for x in args] + self._output_shape = None + # Move the arguments of the first call to the plan build_args = [arg.send(self) for arg in args] @@ -316,6 +336,10 @@ def run(self, args: Tuple, result_ids: List[Union[str, int]]): args: Arguments used to run plan. result_ids: List of ids where the results will be stored. """ + # If promises are given to the plan, prepare it to receive values from these promises + if self.has_promises_args(args): + return self.setup_plan_with_promises(*args) + # We build the plan only if needed if not self.is_built: self.build(args) @@ -329,6 +353,43 @@ def run(self, args: Tuple, result_ids: List[Union[str, int]]): return responses[0] return responses + def has_args_fulfilled(self): + """ Check if all the arguments of the plan are ready or not. + It might be the case that we still need to wait for some arguments in + case some of them are Promises. + """ + for arg_id in self.procedure.arg_ids: + arg = self.owner.get_obj(arg_id) + if hasattr(arg, "child") and isinstance(arg.child, PromiseTensor): + if not arg.child.is_kept(): + return False + return True + + def has_promises_args(self, args): + return any([hasattr(arg, "child") and isinstance(arg.child, PromiseTensor) for arg in args]) + + def setup_plan_with_promises(self, *args): + """ Slightly modifies a plan so that it can work with promises. + The plan will also be sent to location with this method. + """ + for arg in args: + if hasattr(arg, "child") and isinstance(arg.child, PromiseTensor): + arg.child.plans.add(self.id) + prom_owner = arg.owner + + # As we cannot perform operation between different type of tensors with torch, all the + # input tensors should have the same type and the result should also have this same type. 
+ result_type = args[0].torch_type() + + res = PromiseTensor( + owner=prom_owner, shape=self.output_shape, tensor_type=result_type, plans=set(), + ) + + self.procedure.update_args(args, self.procedure.result_ids) + self.procedure.promise_out_id = res.id + + return res.wrap() + def send(self, *locations, force=False) -> PointerPlan: """Send plan to locations. @@ -452,6 +513,8 @@ def simplify(worker: AbstractWorker, plan: "Plan") -> tuple: sy.serde._simplify(worker, plan.state), sy.serde._simplify(worker, plan.include_state), sy.serde._simplify(worker, plan.is_built), + sy.serde._simplify(worker, plan.input_shapes), + sy.serde._simplify(worker, plan._output_shape), sy.serde._simplify(worker, plan.name), sy.serde._simplify(worker, plan.tags), sy.serde._simplify(worker, plan.description), @@ -467,16 +530,31 @@ def detail(worker: AbstractWorker, plan_tuple: tuple) -> "Plan": plan: a Plan object """ - id, procedure, state, include_state, is_built, name, tags, description = plan_tuple + ( + id, + procedure, + state, + include_state, + is_built, + input_shapes, + output_shape, + name, + tags, + description, + ) = plan_tuple id = sy.serde._detail(worker, id) procedure = sy.serde._detail(worker, procedure) state = sy.serde._detail(worker, state) + input_shapes = sy.serde._detail(worker, input_shapes) + output_shape = sy.serde._detail(worker, output_shape) plan = sy.Plan(owner=worker, id=id, include_state=include_state, is_built=is_built) plan.procedure = procedure plan.state = state state.plan = plan + plan.input_shapes = input_shapes + plan._output_shape = output_shape plan.name = sy.serde._detail(worker, name) plan.tags = sy.serde._detail(worker, tags) diff --git a/syft/messaging/plan/procedure.py b/syft/messaging/plan/procedure.py index e9a4585d3a8..800d0c8d00b 100644 --- a/syft/messaging/plan/procedure.py +++ b/syft/messaging/plan/procedure.py @@ -27,6 +27,8 @@ def __init__(self, operations=None, arg_ids=None, result_ids=None): self.operations = operations or [] 
self.arg_ids = arg_ids or [] self.result_ids = result_ids or [] + # promise_out_id id used for plan augmented to be used with promises + self.promise_out_id = None def __str__(self): return f"" @@ -128,14 +130,16 @@ def simplify(worker: AbstractWorker, procedure: "Procedure") -> tuple: ), # We're not simplifying because operations are already simplified sy.serde._simplify(worker, procedure.arg_ids), sy.serde._simplify(worker, procedure.result_ids), + sy.serde._simplify(worker, procedure.promise_out_id), ) @staticmethod def detail(worker: AbstractWorker, procedure_tuple: tuple) -> "State": - operations, arg_ids, result_ids = procedure_tuple + operations, arg_ids, result_ids, promise_out_id = procedure_tuple operations = list(operations) arg_ids = sy.serde._detail(worker, arg_ids) result_ids = sy.serde._detail(worker, result_ids) procedure = Procedure(operations, arg_ids, result_ids) + procedure.promise_out_id = promise_out_id return procedure diff --git a/syft/messaging/promise.py b/syft/messaging/promise.py new file mode 100644 index 00000000000..6163eddead6 --- /dev/null +++ b/syft/messaging/promise.py @@ -0,0 +1,127 @@ +"""This file contains the object we use to tell another worker that they will +receive an object with a certain ID. This object will also contain a set of +Plan objects which will be executed when the promise is kept (assuming +the plan has all of the inputs it requires). """ + +import syft as sy + +from abc import ABC +from abc import abstractmethod +from syft.workers.abstract import AbstractWorker + + +class Promise(ABC): + def __init__(self, owner=None, obj_id=None, obj_type=None, plans=None): + """Initialize a Promise with a unique ID and a set of (possibly empty) plans + + A Promise is a data-structure which indicates that "there will be an object + with this ID at some point, and when you get it, use it to execute these plans". 
+ + As such, the promise has an ID itself and it also has an queue of object ids + with which the promise has been kept. + + However, it's important to know that some Plans are actually waiting on multiple + objects before they can be executed. Thus, it's possible that you might call + .keep() on a promise and nothing will happen because all of the plans are also + waiting on other promises to be kept before they execute. + + Args: + id (int): the id of the promise + plans (set): a set of the plans waiting on the promise to be kept + Example: + future_x = Promise() + """ + + self.owner = owner + + self.obj_type = obj_type + self.queue_obj_ids = [] + + if plans is None: + plans = set() + self.plans = plans + + def keep(self, obj): + """ This method is used to keep a promise. + This will register the object on the worker, add its id to the queue of the promise, + and every plan waiting for this promise will try to execute if it can. + """ + if obj.type() != self.obj_type: + raise TypeError( + "keep() was called with an object of incorrect type (not the type that was promised)" + ) + + if self.id in self.owner._objects: + self.owner.register_obj(obj) + + self.queue_obj_ids.append(obj.id) + + # If some plans were waiting for this promise... + for plan_id in self.plans: + plan = self.owner.get_obj(plan_id) + + # ... execute them if it was the last argument they were waiting for. + if plan.has_args_fulfilled(): + # Collect args + orig_ids = plan.procedure.arg_ids + args = [] + ids_to_rm = [] + for i, arg_id in enumerate(plan.procedure.arg_ids): + arg = self.owner.get_obj(arg_id) + if hasattr(arg, "child") and isinstance(arg.child, Promise): + id_to_add = arg.child.queue_obj_ids.pop(0) + ids_to_rm.append(id_to_add) + else: + id_to_add = arg_id + args.append(self.owner.get_obj(id_to_add)) + # FIXME Ugly fix because I had id_to_add != self.owner.get_obj(id_to_add).id... 
+ args[-1].id = id_to_add + result = plan(*args) + + # ids of promises are changed automatically otherwise + plan.procedure.update_ids(plan.procedure.arg_ids, orig_ids) + plan.procedure.arg_ids = orig_ids + + # Remove objects from queues: + for to_rm in ids_to_rm: + self.owner.rm_obj(to_rm) + + self.owner.get_obj(plan.procedure.promise_out_id).keep(result) + + return obj + + def is_kept(self): + """ Check if promise has objects waiting to be used. + This returns False if the queue of objects is empty. + """ + return self.queue_obj_ids != [] + + def value(self): + """ Returns the next object in the queue of results. + """ + if not self.queue_obj_ids: + # If the promise has still not been kept + # or if the queue of results has been emptied + # the user should check if queue_obj_ids is empty on his own when this + # method is called on a PointerTensor because None cannot be "pointed to" + return None + ret_id = self.queue_obj_ids.pop(0) + ret = self.owner.get_obj(ret_id) + self.owner.rm_obj(ret_id) + return ret + + def __repr__(self): + return self.__str__() + + def __str__(self): + return ( + f"" + ) + + @abstractmethod + def simplify(self): + pass + + @abstractmethod + def detail(self): + pass diff --git a/syft/messaging/protocol.py b/syft/messaging/protocol.py index 02cbe82e1f3..a432c9eba09 100644 --- a/syft/messaging/protocol.py +++ b/syft/messaging/protocol.py @@ -8,8 +8,10 @@ from syft.generic.pointers.pointer_protocol import PointerProtocol from syft.workers.abstract import AbstractWorker from syft.workers.base import BaseWorker +from syft.messaging.promise import Promise from typing import List, Union +import warnings class Protocol(AbstractObject): @@ -84,6 +86,55 @@ def deploy(self, *workers): return self + def __call__(self, *args, **kwargs): + has_promised_inputs = any( + [hasattr(arg, "child") and isinstance(arg.child, Promise) for arg in args] + ) + if has_promised_inputs: + return self.build_with_promises(*args, **kwargs) + else: + return 
self.run(*args, **kwargs) + + def build_with_promises(self, *args, **kwargs): + """ This method is used to build the graph of computation distributed across the different workers, + meaning that output promises are built for plans and these output promises are used as inputs + for the next worker. + + The input args (with at least one promise) provided are sent to the first plan location. + The output promise(s) is created and linked to this plan and send to the second plan location, and so on. + Pointer(s) to the final result(s) on the last worker as well as to the input promise(s) + that have to be kept are returned. + """ + self._assert_is_resolved() + + # TODO if self.location is not None: + + # Local and sequential coordination of the plan execution + previous_worker_id = None + response = None + for worker, plan in self.plans: + # Transmit the args to the next worker if it's a different one % the previous + if previous_worker_id is None: + args = [arg.send(worker).child for arg in args] + elif previous_worker_id != worker.id: + args = [arg.remote_send(worker).child for arg in args] + for arg in args: + # Not clean but need to point to promise on next worker from protocol owner + # TODO see if a better solution exists + arg.location = worker + else: + args = [arg.child for arg in args] + + if previous_worker_id is None: + in_promise_ptrs = args[0] if len(args) == 1 else args + previous_worker_id = worker.id + + response = plan(*args) + + args = response if isinstance(response, tuple) else (response,) + + return in_promise_ptrs, response + def run(self, *args, **kwargs): """ Run the protocol by executing the plans sequentially @@ -281,9 +332,26 @@ def _assert_is_resolved(self): def _resolve_workers(self, workers): """Map the abstract workers (named by strings) to the provided workers and update the plans accordingly""" + dict_workers = {w.id: w for w in workers} + set_fake_ids = set(worker for worker, _ in self.plans) + set_real_ids = set(dict_workers.keys()) + 
+ if 0 < len(set_fake_ids.intersection(set_real_ids)) < len(set_real_ids): + # The user chose fake ids that correspond to real ids but not all of them match. + # Maybe it's a mistake so we warn the user. + warnings.warn( + "You are deploying a protocol with workers for which only a subpart" + "have ids that match an id chosen for the protocol." + ) + + # If the "fake" ids manually set by the user when writing the protocol exactly match the ids + # of the workers, these fake ids in self.plans are replaced with the real workers. + if set_fake_ids == set_real_ids: + self.plans = [(dict_workers[w], p) for w, p in self.plans] + # If there is an exact one-to-one mapping, just iterate and keep the order # provided when assigning the workers - if len(workers) == len(self.plans): + elif len(workers) == len(self.plans): self.plans = [(worker, plan) for (_, plan), worker in zip(self.plans, workers)] # Else, there are duplicates in the self.plans keys and we need to build diff --git a/syft/serde/serde.py b/syft/serde/serde.py index 2ac49f9754e..e0924500296 100644 --- a/syft/serde/serde.py +++ b/syft/serde/serde.py @@ -49,6 +49,7 @@ from syft.frameworks.torch.tensors.interpreters.additive_shared import AdditiveSharingTensor from syft.frameworks.torch.tensors.interpreters.crt_precision import CRTPrecisionTensor from syft.frameworks.torch.tensors.interpreters.autograd import AutogradTensor +from syft.frameworks.torch.tensors.interpreters.promise import PromiseTensor from syft.generic.pointers.multi_pointer import MultiPointerTensor from syft.generic.pointers.object_pointer import ObjectPointer from syft.generic.pointers.pointer_tensor import PointerTensor @@ -106,6 +107,7 @@ CRTPrecisionTensor, LoggingTensor, MultiPointerTensor, + PromiseTensor, ObjectPointer, Plan, State, diff --git a/test/run_websocket_server.py b/test/run_websocket_server.py new file mode 100644 index 00000000000..8b323fa3e04 --- /dev/null +++ b/test/run_websocket_server.py @@ -0,0 +1,134 @@ +import logging 
+import syft as sy +from syft.workers import WebsocketServerWorker +import torch +import argparse +from torchvision import datasets +from torchvision import transforms +import numpy as np +from syft.frameworks.torch.federated import utils + +KEEP_LABELS_DICT = { + "alice": [0, 1, 2, 3], + "bob": [4, 5, 6], + "charlie": [7, 8, 9], + "testing": list(range(10)), +} # pragma: no cover + + +def start_websocket_server_worker( + id, host, port, hook, verbose, keep_labels=None, training=True +): # pragma: no cover + """Helper function for spinning up a websocket server and setting up the local datasets.""" + + server = WebsocketServerWorker(id=id, host=host, port=port, hook=hook, verbose=verbose) + + # Setup toy data (mnist example) + mnist_dataset = datasets.MNIST( + root="./data", + train=training, + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ) + + if training: + indices = np.isin(mnist_dataset.targets, keep_labels).astype("uint8") + logger.info("number of true indices: %s", indices.sum()) + selected_data = ( + torch.native_masked_select(mnist_dataset.data.transpose(0, 2), torch.tensor(indices)) + .view(28, 28, -1) + .transpose(2, 0) + ) + logger.info("after selection: %s", selected_data.shape) + selected_targets = torch.native_masked_select(mnist_dataset.targets, torch.tensor(indices)) + + dataset = sy.BaseDataset( + data=selected_data, targets=selected_targets, transform=mnist_dataset.transform + ) + key = "mnist" + else: + dataset = sy.BaseDataset( + data=mnist_dataset.data, + targets=mnist_dataset.targets, + transform=mnist_dataset.transform, + ) + key = "mnist_testing" + + server.add_dataset(dataset, key=key) + + # Setup toy data (vectors example) + data_vectors = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]], requires_grad=True) + target_vectors = torch.tensor([[1], [0], [1], [0]]) + + server.add_dataset(sy.BaseDataset(data_vectors, target_vectors), key="vectors") + + # Setup 
toy data (xor example) + data_xor = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True) + target_xor = torch.tensor([1.0, 1.0, 0.0, 0.0], requires_grad=False) + + server.add_dataset(sy.BaseDataset(data_xor, target_xor), key="xor") + + # Setup gaussian mixture dataset + data, target = utils.create_gaussian_mixture_toy_data(nr_samples=100) + server.add_dataset(sy.BaseDataset(data, target), key="gaussian_mixture") + + # Setup partial iris dataset + data, target = utils.iris_data_partial() + dataset = sy.BaseDataset(data, target) + dataset_key = "iris" + server.add_dataset(dataset, key=dataset_key) + + logger.info("datasets: %s", server.datasets) + if training: + logger.info("len(datasets[mnist]): %s", len(server.datasets["mnist"])) + + server.start() + return server + + +if __name__ == "__main__": # pragma: no cover + # Logging setup + logger = logging.getLogger("run_websocket_server") + FORMAT = "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d, p:%(process)d) - %(message)s" + logging.basicConfig(format=FORMAT) + logger.setLevel(level=logging.DEBUG) + + # Parse args + parser = argparse.ArgumentParser(description="Run websocket server worker.") + parser.add_argument( + "--port", + "-p", + type=int, + help="port number of the websocket server worker, e.g. --port 8777", + ) + parser.add_argument("--host", type=str, default="localhost", help="host for the connection") + parser.add_argument( + "--id", type=str, help="name (id) of the websocket server worker, e.g. 
--id alice" + ) + parser.add_argument( + "--testing", + action="store_true", + help="if set, websocket server worker will load the test dataset instead of the training dataset", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="if set, websocket server worker will be started in verbose mode", + ) + + args = parser.parse_args() + + # Hook and start server + hook = sy.TorchHook(torch) + server = start_websocket_server_worker( + id=args.id, + host=args.host, + port=args.port, + hook=hook, + verbose=args.verbose, + keep_labels=KEEP_LABELS_DICT[args.id] if args.id in KEEP_LABELS_DICT else list(range(10)), + training=not args.testing, + ) diff --git a/test/torch/tensors/test_native.py b/test/torch/tensors/test_native.py index 4fd35d7d91d..c8bc2b0010e 100644 --- a/test/torch/tensors/test_native.py +++ b/test/torch/tensors/test_native.py @@ -110,6 +110,26 @@ def test_remote_get(hook, workers): assert len(alice._objects) == 1 +def test_remote_send(hook, workers): + me = workers["me"] + bob = workers["bob"] + alice = workers["alice"] + + x = torch.tensor([1, 2, 3, 4, 5]) + ptr_ptr_x = x.send(bob).remote_send(alice) + + assert ptr_ptr_x.owner == me + assert ptr_ptr_x.location == bob + assert x.id in alice._objects + + y = torch.tensor([1, 2, 3, 4, 5]) + ptr_y = y.send(bob).remote_send(alice, change_location=True) + + assert ptr_y.owner == me + assert ptr_y.location == alice + assert y.id in alice._objects + + def test_copy(): tensor = torch.rand(5, 3) coppied_tensor = tensor.copy() diff --git a/test/torch/tensors/test_promise.py b/test/torch/tensors/test_promise.py new file mode 100644 index 00000000000..0a1767b5fa4 --- /dev/null +++ b/test/torch/tensors/test_promise.py @@ -0,0 +1,198 @@ +import pytest +import torch +import syft + + +def test__str__(): + a = syft.Promise.FloatTensor(shape=torch.Size((3, 3))) + assert isinstance(a.__str__(), str) + + +@pytest.mark.parametrize("cmd", ["__add__", "__sub__", "__mul__"]) +def 
test_operations_between_promises(hook, cmd): + hook.local_worker.is_client_worker = False + + a = syft.Promise.FloatTensor(shape=torch.Size((2, 2))) + b = syft.Promise.FloatTensor(shape=torch.Size((2, 2))) + + actual = getattr(a, cmd)(b) + + ta = torch.tensor([[1.0, 2], [3, 4]]) + tb = torch.tensor([[-8.0, -7], [6, 5]]) + a.keep(ta) + b.keep(tb) + + expected = getattr(ta, cmd)(tb) + + assert (actual.value() == expected).all() + + hook.local_worker.is_client_worker = True + + +@pytest.mark.parametrize("cmd", ["__add__", "__sub__", "__mul__"]) +def test_operations_with_concrete(hook, cmd): + hook.local_worker.is_client_worker = False + + a = syft.Promise.FloatTensor(shape=torch.Size((2, 2))) + b = torch.tensor([[-8.0, -7], [6, 5]]).wrap() # TODO fix need to wrap + + actual = getattr(a, cmd)(b) + + ta = torch.tensor([[1.0, 2], [3, 4]]).wrap() + a.keep(ta) + + expected = getattr(ta, cmd)(b) + + assert (actual.value() == expected).all() + + hook.local_worker.is_client_worker = True + + +def test_send(workers): + bob = workers["bob"] + + a = syft.Promise.FloatTensor(shape=torch.Size((2, 2))) + + x = a.send(bob) + + x.keep(torch.ones((2, 2))) + + assert (x.value().get() == torch.ones((2, 2))).all() + + +@pytest.mark.parametrize("cmd", ["__add__", "__sub__", "__mul__"]) +def test_remote_operations(workers, cmd): + bob = workers["bob"] + + a = syft.Promise.FloatTensor(shape=torch.Size((3, 3))) + b = syft.Promise.FloatTensor(shape=torch.Size((3, 3))) + + x = a.send(bob) + y = b.send(bob) + + actual = getattr(x, cmd)(y) + + tx = torch.tensor([[1.0, 2], [3, 4]]) + ty = torch.tensor([[-8.0, -7], [6, 5]]) + x.keep(tx) + y.keep(ty) + + expected = getattr(tx, cmd)(ty) + + assert (actual.value().get() == expected).all() + + +def test_bufferized_results(hook): + hook.local_worker.is_client_worker = False + + a = syft.Promise.FloatTensor(shape=torch.Size((3, 3))) + + a.keep(torch.ones(3, 3)) + a.keep(2 * torch.ones(3, 3)) + a.keep(3 * torch.ones(3, 3)) + + assert (a.value() == 
torch.ones(3, 3)).all() + assert (a.value() == 2 * torch.ones(3, 3)).all() + assert (a.value() == 3 * torch.ones(3, 3)).all() + + hook.local_worker.is_client_worker = True + + +def test_plan_waiting_promise(hook, workers): + hook.local_worker.is_client_worker = False + + @syft.func2plan(args_shape=[(3, 3)]) + def plan_test(data): + return 2 * data + 1 + + # Hack otherwise plan not found on local worker... + hook.local_worker.register_obj(plan_test) + + a = syft.Promise.FloatTensor(shape=torch.Size((3, 3))) + + res = plan_test(a) + + a.keep(torch.ones(3, 3)) + + assert (res.value() == 3 * torch.ones(3, 3)).all() + + # With non promises + @syft.func2plan(args_shape=[(3, 3), (3, 3)]) + def plan_test(prom, tens): + return prom + tens + + # Hack otherwise plan not found on local worker... + hook.local_worker.register_obj(plan_test) + + a = syft.Promise.FloatTensor(shape=torch.Size((3, 3))) + b = 2 * torch.ones(3, 3).wrap() + + res = plan_test(a, b) + + a.keep(torch.ones(3, 3).wrap()) + + assert (res.value().child == 3 * torch.ones(3, 3)).all() + + # With several arguments and remote + bob = workers["bob"] + + @syft.func2plan(args_shape=[(3, 3), (3, 3)]) + def plan_test_remote(in_a, in_b): + return in_a + in_b + + a = syft.Promise.FloatTensor(shape=torch.Size((3, 3))) + b = syft.Promise.FloatTensor(shape=torch.Size((3, 3))) + + x = b.send(bob) + y = a.send(bob) + ptr_plan = plan_test_remote.send(bob) + + res_ptr = ptr_plan(x, y) + + x.keep(torch.ones(3, 3)) + x.keep(3 * torch.ones(3, 3)) + + y.keep(2 * torch.ones(3, 3)) + y.keep(4 * torch.ones(3, 3)) + + assert (res_ptr.value().get() == 3 * torch.ones(3, 3)).all() + assert (res_ptr.value().get() == 7 * torch.ones(3, 3)).all() + + hook.local_worker.is_client_worker = True + + +def test_protocol_waiting_promise(hook, workers): + hook.local_worker.is_client_worker = False + + alice = workers["alice"] + bob = workers["bob"] + + @syft.func2plan(args_shape=[(1,)]) + def plan_alice1(in_a): + return in_a + 1 + + 
@syft.func2plan(args_shape=[(1,)]) + def plan_bob1(in_b): + return in_b + 2 + + @syft.func2plan(args_shape=[(1,)]) + def plan_bob2(in_b): + return in_b + 3 + + @syft.func2plan(args_shape=[(1,)]) + def plan_alice2(in_a): + return in_a + 4 + + protocol = syft.Protocol( + [("alice", plan_alice1), ("bob", plan_bob1), ("bob", plan_bob2), ("alice", plan_alice2)] + ) + protocol.deploy(alice, bob) + + x = syft.Promise.FloatTensor(shape=torch.Size((1,))) + in_ptr, res_ptr = protocol(x) + + in_ptr.keep(torch.tensor([1.0])) + + assert res_ptr.value().get() == 11 + + hook.local_worker.is_client_worker = True