diff --git a/syft/messaging/plan/plan.py b/syft/messaging/plan/plan.py index 3e519f74842..f330da588c2 100644 --- a/syft/messaging/plan/plan.py +++ b/syft/messaging/plan/plan.py @@ -78,7 +78,7 @@ class Plan(AbstractObject, ObjectStorage): state_ids: ids of the state elements arg_ids: ids of the last argument ids (present in the procedure commands) result_ids: ids of the last result ids (present in the procedure commands) - readable_plan: list of commands + operations: list of commands blueprint: the function to be transformed into a plan state_tensors: a tuple of state elements. It can be used to populate a state id: state id @@ -98,7 +98,7 @@ def __init__( state_ids: List[Union[str, int]] = None, arg_ids: List[Union[str, int]] = None, result_ids: List[Union[str, int]] = None, - readable_plan: List = None, + operations: List = None, blueprint=None, state_tensors=None, # General kwargs @@ -120,7 +120,7 @@ def __init__( self.nested_states = [] # Info about the plan stored via the state and the procedure - self.procedure = procedure or Procedure(readable_plan, arg_ids, result_ids) + self.procedure = procedure or Procedure(operations, arg_ids, result_ids) self.state = state or State(owner=owner, plan=self, state_ids=state_ids) if state_tensors is not None: for tensor in state_tensors: @@ -163,7 +163,7 @@ def location(self): # For backward compatibility @property - def readable_plan(self): + def operations(self): return self.procedure.operations def parameters(self): @@ -214,11 +214,8 @@ def _recv_msg(self, bin_message: bin): if isinstance(msg, Operation): # Re-deserialize without detailing - (msg_type, contents) = sy.serde.deserialize( - bin_message, strategy=sy.serde.msgpack.serde._deserialize_msgpack_binary - ) - self.procedure.operations.append((msg_type, contents)) + self.procedure.operations.append(msg) return sy.serde.serialize(None) @@ -346,7 +343,7 @@ def __call__(self, *args, **kwargs): def execute_commands(self): for message in self.procedure.operations: - bin_message = sy.serde.serialize(message, simplified=True) + bin_message = sy.serde.serialize(message) _ = self.owner.recv_msg(bin_message) def run(self, args: Tuple, result_ids: List[Union[str, int]]): diff --git a/syft/messaging/plan/procedure.py b/syft/messaging/plan/procedure.py index 3e498043b20..b6d33315fe5 100644 --- a/syft/messaging/plan/procedure.py +++ b/syft/messaging/plan/procedure.py @@ -18,7 +18,7 @@ class Procedure(object): on different workers. Args: - operations: the list of (serialized) operations + operations: the list of operations arg_ids: the argument ids present in the operations result_ids: the result ids present in the operations """ @@ -63,8 +63,8 @@ def update_ids( self, from_ids: Tuple[Union[str, int]] = [], to_ids: Tuple[Union[str, int]] = [], - from_worker: Union[str, int] = None, - to_worker: Union[str, int] = None, + from_worker_id: Union[str, int] = None, + to_worker_id: Union[str, int] = None, ): """Replace ids and worker ids in the list of operations stored @@ -74,74 +74,58 @@ def update_ids( from_worker: The previous worker that built the plan. to_worker: The new worker that is running the plan. """ - - # We operate on simplified content of Operation, hence all values should be simplified - from_workers_simplified = None - to_workers_simplified = None - if from_worker and to_worker: - from_workers_simplified = [sy.serde.msgpack.serde._simplify(None, from_worker)] - to_workers_simplified = [sy.serde.msgpack.serde._simplify(None, to_worker)] - - from_ids_simplified = None - to_ids_simplified = None - if len(from_ids) and len(to_ids): - from_ids_simplified = [sy.serde.msgpack.serde._simplify(None, id) for id in from_ids] - to_ids_simplified = [sy.serde.msgpack.serde._simplify(None, id) for id in to_ids] - for idx, operation in enumerate(self.operations): - if from_workers_simplified and to_workers_simplified: - operation = Procedure.replace_operation_ids( - operation, from_workers_simplified, to_workers_simplified - ) - - if from_ids_simplified and to_ids_simplified: - operation = Procedure.replace_operation_ids( - operation, from_ids_simplified, to_ids_simplified - ) + # replace ids in the owner + owner = operation.cmd_owner + if owner is not None: + if owner.id in from_ids: + owner.id = to_ids[from_ids.index(owner.id)] + + if owner.id_at_location in from_ids: + owner.id_at_location = to_ids[from_ids.index(owner.id_at_location)] + + # replace ids in the args + for arg in operation.cmd_args: + try: + if arg.id in from_ids: + arg.id = to_ids[from_ids.index(arg.id)] + + if arg.id_at_location in from_ids: + arg.id_at_location = to_ids[from_ids.index(arg.id_at_location)] + except: + pass + + # replace ids in the returns + return_ids = list(operation.return_ids) + for idx, return_id in enumerate(return_ids): + if return_id in from_ids: + return_ids[idx] = to_ids[from_ids.index(return_id)] + operation.return_ids = return_ids + + # create a dummy worker + to_worker = sy.workers.virtual.VirtualWorker(None, to_worker_id) + + # replace worker in the owner + if owner is not None and owner.location is not None: + if owner.location.id == from_worker_id: + owner.location = to_worker + + # replace worker in the args + for arg in operation.cmd_args: + try: + if arg.location.id == from_worker_id: + arg.location = to_worker + except: + pass self.operations[idx] = operation return self - @staticmethod - def replace_operation_ids(operation, from_ids, to_ids): - """ - Replace ids in a single operation - - Args: - operation: the operation to update - from_ids: the ids to replace - to_ids: the new ids to put inplace - """ - - assert isinstance(from_ids, (list, tuple)) - assert isinstance(to_ids, (list, tuple)) - - type_obj = type(operation) - operation = list(operation) - for i, item in enumerate(operation): - # Since this is simplified content, id can be int or simplified str (tuple) - if isinstance(item, (int, tuple)) and item in from_ids: - operation[i] = to_ids[from_ids.index(item)] - elif isinstance(item, (list, tuple)): - operation[i] = Procedure.replace_operation_ids( - operation=item, from_ids=from_ids, to_ids=to_ids - ) - return type_obj(operation) - - def copy(self) -> "Procedure": - procedure = Procedure( - operations=copy.deepcopy(self.operations), - arg_ids=self.arg_ids, - result_ids=self.result_ids, - ) - return procedure - @staticmethod def simplify(worker: AbstractWorker, procedure: "Procedure") -> tuple: return ( - # We're not simplifying fully because operations are already simplified - sy.serde.msgpack.serde._simplify(worker, procedure.operations, shallow=True), + sy.serde.msgpack.serde._simplify(worker, procedure.operations), sy.serde.msgpack.serde._simplify(worker, procedure.arg_ids), sy.serde.msgpack.serde._simplify(worker, procedure.result_ids), sy.serde.msgpack.serde._simplify(worker, procedure.promise_out_id), @@ -151,7 +135,7 @@ def simplify(worker: AbstractWorker, procedure: "Procedure") -> tuple: def detail(worker: AbstractWorker, procedure_tuple: tuple) -> "State": operations, arg_ids, result_ids, promise_out_id = procedure_tuple - operations = sy.serde.msgpack.serde._detail(worker, operations, shallow=True) + operations = sy.serde.msgpack.serde._detail(worker, operations) arg_ids = sy.serde.msgpack.serde._detail(worker, arg_ids) result_ids = sy.serde.msgpack.serde._detail(worker, result_ids) diff --git a/test/message/test_plan.py b/test/message/test_plan.py index ae18eae67b7..337a5fb8cb0 100644 --- a/test/message/test_plan.py +++ b/test/message/test_plan.py @@ -21,7 +21,7 @@ def plan_abs(data): return data.abs() assert isinstance(plan_abs.__str__(), str) - assert len(plan_abs.readable_plan) > 0 + assert len(plan_abs.operations) > 0 assert plan_abs.is_built @@ -36,7 +36,7 @@ def foo(x, state): return x + bias assert isinstance(foo.__str__(), str) - assert len(foo.readable_plan) > 0 + assert len(foo.operations) > 0 assert foo.is_built t = th.tensor([1.0, 2]) @@ -51,11 +51,11 @@ def plan_abs(data): return data.abs() assert not plan_abs.is_built - assert not len(plan_abs.readable_plan) + assert not len(plan_abs.operations) plan_abs.build(th.tensor([-1])) - assert len(plan_abs.readable_plan) + assert len(plan_abs.operations) assert plan_abs.is_built @@ -80,7 +80,7 @@ def plan_abs(data): return data.abs() assert isinstance(plan_abs.__str__(), str) - assert len(plan_abs.readable_plan) > 0 + assert len(plan_abs.operations) > 0 def test_raise_exception_for_invalid_shape(): @@ -190,17 +190,17 @@ def test_plan_multiple_send(workers): def plan_abs(data): return data.abs() - plan_ptr = plan_abs.send(bob) x_ptr = th.tensor([-1, 7, 3]).send(bob) + plan_ptr = plan_abs.send(bob) p = plan_ptr(x_ptr) x_abs = p.get() assert (x_abs == th.tensor([1, 7, 3])).all() # Test get / send plan + x_ptr = th.tensor([-1, 2, 3]).send(alice) plan_ptr = plan_abs.send(alice) - x_ptr = th.tensor([-1, 2, 3]).send(alice) p = plan_ptr(x_ptr) x_abs = p.get() assert (x_abs == th.tensor([1, 2, 3])).all() @@ -381,7 +381,7 @@ def forward(self, x): assert th.all(th.eq(fetched_plan(x), plan(x))) # assert fetched_plan.state.state_ids != plan.state.state_ids #TODO - # Make sure fetched_plan is using the readable_plan + # Make sure fetched_plan is using the operations assert fetched_plan.forward is None assert fetched_plan.is_built @@ -428,7 +428,7 @@ def forward(self, x): assert th.all(th.eq(fetched_plan(x), expected)) # assert fetched_plan.state.state_ids != plan.state.state_ids #TODO - # Make sure fetched_plan is using the readable_plan + # Make sure fetched_plan is using the operations assert fetched_plan.forward is None assert fetched_plan.is_built @@ -545,7 +545,7 @@ def forward(self, x): assert th.all(decrypted - expected.detach() < 1e-2) # assert fetched_plan.state.state_ids != plan.state.state_ids #TODO - # Make sure fetched_plan is using the readable_plan + # Make sure fetched_plan is using the operations assert fetched_plan.forward is None assert fetched_plan.is_built @@ -840,167 +840,6 @@ def forward(self, x): # Plan._replace_message_ids = _replace_message_ids_orig -def test_procedure_update_ids(): - commands = [ - ( - 31, - ( - 1, - ( - ( - 6, - ( - (5, (b"__add__",)), - (23, (27674294093, 68519530406, (5, (b"me",)), None, (10, (1,)), True)), - ( - 6, - ( - ( - 23, - ( - 2843683950, - 91383408771, - (5, (b"me",)), - None, - (10, (1,)), - True, - ), - ), - ), - ), - (0, ()), - ), - ), - (75165665059,), - ), - ), - ) - ] - - procedure = Procedure(operations=commands, arg_ids=[68519530406], result_ids=(75165665059,)) - - procedure.update_ids( - from_ids=[27674294093], to_ids=[73570994542], from_worker="me", to_worker="alice" - ) - - assert procedure.operations == [ - ( - 31, - ( - 1, - ( - ( - 6, - ( - (5, (b"__add__",)), - ( - 23, - ( - 73570994542, - 68519530406, - (5, (b"alice",)), - None, - (10, (1,)), - True, - ), - ), - ( - 6, - ( - ( - 23, - ( - 2843683950, - 91383408771, - (5, (b"alice",)), - None, - (10, (1,)), - True, - ), - ), - ), - ), - (0, ()), - ), - ), - (75165665059,), - ), - ), - ) - ] - - tensor = th.tensor([1.0]) - tensor_id = tensor.id - procedure.update_args(args=(tensor,), result_ids=[8730174527]) - - assert procedure.operations == [ - ( - 31, - ( - 1, - ( - ( - 6, - ( - (5, (b"__add__",)), - ( - 23, - (73570994542, tensor_id, (5, (b"alice",)), None, (10, (1,)), True), - ), - ( - 6, - ( - ( - 23, - ( - 2843683950, - 91383408771, - (5, (b"alice",)), - None, - (10, (1,)), - True, - ), - ), - ), - ), - (0, ()), - ), - ), - (8730174527,), - ), - ), - ) - ] - - procedure.operations = [ - (73570994542, 8730174527, (5, (b"alice",)), None, (10, (1,)), True), - (2843683950, 91383408771, (5, (b"alice",)), None, (10, (1,)), True), - ] - - procedure.update_worker_ids(from_worker_id="alice", to_worker_id="me") - - assert procedure.operations == [ - (73570994542, 8730174527, (5, (b"me",)), None, (10, (1,)), True), - (2843683950, 91383408771, (5, (b"me",)), None, (10, (1,)), True), - ] - - # From int worker_id to str - procedure.operations = [(73570994542, 8730174527, 1234567890, None, (10, (1,)), True)] - - procedure.update_worker_ids(from_worker_id=1234567890, to_worker_id="me") - - assert procedure.operations == [ - (73570994542, 8730174527, (5, (b"me",)), None, (10, (1,)), True) - ] - - # From str worker_id to int - procedure.operations = [(73570994542, 8730174527, (5, (b"alice",)), None, (10, (1,)), True)] - - procedure.update_worker_ids(from_worker_id="alice", to_worker_id=1234567890) - - assert procedure.operations == [(73570994542, 8730174527, 1234567890, None, (10, (1,)), True)] - - def test_send_with_plan(workers): bob = workers["bob"] diff --git a/test/message/test_procedure.py b/test/message/test_procedure.py new file mode 100644 index 00000000000..c362e0a88e5 --- /dev/null +++ b/test/message/test_procedure.py @@ -0,0 +1,35 @@ +import syft +import torch + +from syft.generic.pointers.pointer_tensor import PointerTensor +from syft.messaging.message import Operation +from syft.messaging.plan.procedure import Procedure + + +def test_procedure_update_ids_and_args(workers): + pointer1 = PointerTensor( + workers["bob"], 68519530406, workers["me"], 27674294093, True, torch.Size([1]) + ) + + pointer2 = PointerTensor( + workers["bob"], 91383408771, workers["me"], 2843683950, False, torch.Size([1]) + ) + + operation = Operation("__add__", pointer1, [pointer2], {}, (75165665059,)) + + procedure = Procedure(operations=[operation], arg_ids=[68519530406], result_ids=(75165665059,)) + + procedure.update_ids( + from_ids=[27674294093], to_ids=[73570994542], from_worker_id="bob", to_worker_id="alice" + ) + + assert procedure.operations[0].cmd_owner.id == 73570994542 + assert procedure.operations[0].cmd_owner.location.id == "alice" + assert procedure.operations[0].cmd_args[0].location.id == "alice" + + tensor = torch.tensor([1.0]) + tensor_id = tensor.id + procedure.update_args(args=(tensor,), result_ids=[8730174527]) + + assert procedure.operations[0].cmd_owner.id_at_location == tensor_id + assert procedure.operations[0].return_ids[0] == 8730174527