diff --git a/syft/frameworks/torch/hook/hook.py b/syft/frameworks/torch/hook/hook.py index 24ae65eaa3e..11e76f94c2a 100644 --- a/syft/frameworks/torch/hook/hook.py +++ b/syft/frameworks/torch/hook/hook.py @@ -32,7 +32,6 @@ from syft.messaging.promise import Promise from syft.exceptions import route_method_exception -from syft.exceptions import TensorsNotCollocatedException class TorchHook(FrameworkHook): diff --git a/syft/messaging/plan/plan.py b/syft/messaging/plan/plan.py index 449ebf227b4..3e519f74842 100644 --- a/syft/messaging/plan/plan.py +++ b/syft/messaging/plan/plan.py @@ -42,6 +42,7 @@ def __call__(self, plan_blueprint): state_tensors=self.state_tensors, include_state=self.include_state, ) + # Build the plan automatically if self.args_shape: args = Plan._create_placeholders(self.args_shape) @@ -114,6 +115,10 @@ def __init__( self.name = name or self.__class__.__name__ self.owner = owner + # If we have plans in plans we need to keep track of the states for each plan + # because we will need to serialize and send them to the remote workers + self.nested_states = [] + # Info about the plan stored via the state and the procedure self.procedure = procedure or Procedure(readable_plan, arg_ids, result_ids) self.state = state or State(owner=owner, plan=self, state_ids=state_ids) @@ -121,6 +126,7 @@ def __init__( for tensor in state_tensors: self.state.state_ids.append(tensor.id) self.owner.register_obj(tensor) + self.include_state = include_state self.is_built = is_built self.input_shapes = None @@ -240,6 +246,8 @@ def build(self, *args): cloned_state = self.state.clone_state_dict() self.state.send_for_build(location=self) + self.owner.init_plan = self + # We usually have include_state==True for functions converted to plan # using @func2plan and we need therefore to add the state manually if self.include_state: @@ -262,6 +270,7 @@ def build(self, *args): self.procedure.result_ids = (res_ptr.id_at_location,) self.is_built = True + self.owner.init_plan = None def copy(self): """Creates a copy of a plan.""" @@ -310,17 +319,30 @@ def __call__(self, *args, **kwargs): The pointer to the result of the execution if the plan was already sent, else the None message serialized. """ + + cloned_state = None + if self.owner.init_plan: + cloned_state = self.state.clone_state_dict() + self.owner.init_plan.nested_states.append(self.state) + self.state.send_for_build(location=self.owner.init_plan) + if len(kwargs): raise ValueError("Kwargs are not supported for plan.") result_ids = [sy.ID_PROVIDER.pop()] + res = None if self.forward is not None: if self.include_state: args = (*args, self.state) - return self.forward(*args) + res = self.forward(*args) else: - return self.run(args, result_ids=result_ids) + res = self.run(args, result_ids=result_ids) + + if self.owner.init_plan: + self.state.set_(cloned_state) + + return res def execute_commands(self): for message in self.procedure.operations: @@ -413,10 +435,12 @@ def send(self, *locations, force=False) -> PointerPlan: return self.pointers[location] self.procedure.update_worker_ids(self.owner.id, location.id) + # Send the Plan pointer = self.owner.send(self, workers=location) # Revert ids self.procedure.update_worker_ids(location.id, self.owner.id) + self.pointers[location] = pointer else: ids_at_location = [] @@ -520,6 +544,7 @@ def simplify(worker: AbstractWorker, plan: "Plan") -> tuple: sy.serde.msgpack.serde._simplify(worker, plan.name), sy.serde.msgpack.serde._simplify(worker, plan.tags), sy.serde.msgpack.serde._simplify(worker, plan.description), + sy.serde.msgpack.serde._simplify(worker, plan.nested_states), ) @staticmethod @@ -543,15 +568,18 @@ def detail(worker: AbstractWorker, plan_tuple: tuple) -> "Plan": name, tags, description, + nested_states, ) = plan_tuple id = sy.serde.msgpack.serde._detail(worker, id) procedure = sy.serde.msgpack.serde._detail(worker, procedure) state = sy.serde.msgpack.serde._detail(worker, state) input_shapes = sy.serde.msgpack.serde._detail(worker, input_shapes) output_shape = sy.serde.msgpack.serde._detail(worker, output_shape) + nested_states = sy.serde.msgpack.serde._detail(worker, nested_states) plan = sy.Plan(owner=worker, id=id, include_state=include_state, is_built=is_built) + plan.nested_states = nested_states plan.procedure = procedure plan.state = state state.plan = plan diff --git a/syft/messaging/plan/state.py b/syft/messaging/plan/state.py index 3d46da021f1..ba42441e3b9 100644 --- a/syft/messaging/plan/state.py +++ b/syft/messaging/plan/state.py @@ -21,8 +21,17 @@ def __init__(self, owner, plan=None, state_ids=None): self.plan = plan self.state_ids = state_ids or [] + def __str__(self): + """Returns the string representation of the State.""" + out = "<" + out += "State:" + for state_id in self.state_ids: + out += " {}".format(state_id) + out += ">" + return out + def __repr__(self): - return "State: " + ", ".join(self.state_ids) + return self.__str__() def tensors(self) -> List: """ @@ -84,10 +93,8 @@ def create_grad_if_missing(tensor): def send_for_build(self, location, **kwargs): """ Send functionality that can only be used when sending the state for - building the plan. Other than this, you shouldn't need to send the - state separately. + building the plan. """ - assert location.id == self.plan.id # ensure this is a send for the build for tensor in self.tensors(): self.create_grad_if_missing(tensor) diff --git a/syft/workers/base.py b/syft/workers/base.py index 8e653f118f2..84a780bd489 100644 --- a/syft/workers/base.py +++ b/syft/workers/base.py @@ -78,7 +78,7 @@ class BaseWorker(AbstractWorker, ObjectStorage): variables are instantiated or deleted as opposed to handling tensor/variable/model lifecycle internally. Set to True if this object is not where the objects will be stored, but is instead - a pointer to a worker that eists elsewhere. + a pointer to a worker that exists elsewhere. log_msgs: An optional boolean parameter to indicate whether all messages should be saved into a log for later review. This is primarily a development/testing feature. @@ -154,6 +154,9 @@ def __init__( # self is the to-be-created local worker self.add_worker(self) + # Used to keep track of a building plan + self.init_plan = None + if hook is None: self.framework = None else: diff --git a/test/message/test_plan.py b/test/message/test_plan.py index 21cb8b68e14..ae18eae67b7 100644 --- a/test/message/test_plan.py +++ b/test/message/test_plan.py @@ -1046,3 +1046,158 @@ def plan_abs(data): pointers = plan_abs.get_pointers() assert len(pointers) == 2 + + +def test_plan_nested_no_build_inner(workers): + alice = workers["alice"] + expected_res = th.tensor(200) + + @sy.func2plan() + def plan_double(data): + return 2 * data + + @sy.func2plan() + def plan_abs(data): + return plan_double(data).abs() + + x = th.tensor(100) + plan_abs.build(x) + + # Run plan locally + assert plan_abs(x) == expected_res + + # Run plan remote + x_ptr = x.send(alice) + plan_abs_ptr = plan_abs.send(alice) + res = plan_abs_ptr(x_ptr) + + assert res.get() == expected_res + + +def test_plan_nested_build_inner_plan_before(workers): + alice = workers["alice"] + expected_res = th.tensor(200) + + @sy.func2plan(args_shape=[(1,)]) + def plan_double(data): + return -2 * data + + @sy.func2plan() + def plan_abs(data): + return plan_double(data).abs() + + x = th.tensor(100) + plan_abs.build(x) + + # Run plan locally + assert plan_abs(x) == expected_res + + x_ptr = x.send(alice) + plan_abs_ptr = plan_abs.send(alice) + res = plan_abs_ptr(x_ptr) + + assert res.get() == expected_res + + +def test_plan_nested_build_inner_plan_after(workers): + alice = workers["alice"] + expected_res = th.tensor(200) + + @sy.func2plan() + def plan_double(data): + return -2 * data + + @sy.func2plan() + def plan_abs(data): + return plan_double(data).abs() + + x = th.tensor(100) + plan_abs.build(x) + plan_double.build(x) + + # Test locally + assert plan_abs(x) == expected_res + + # Test remote + x_ptr = x.send(alice) + plan_double_ptr = plan_abs.send(alice) + res = plan_double_ptr(x_ptr) + + assert res.get() == expected_res + + +def test_plan_nested_build_inner_plan_state(hook, workers): + alice = workers["alice"] + expected_res = th.tensor(199) + + with hook.local_worker.registration_enabled(): + + @sy.func2plan(args_shape=[(1,)], state=(th.tensor([1]),)) + def plan_double(data, state): + (bias,) = state.read() + return -2 * data + bias + + @sy.func2plan() + def plan_abs(data): + return plan_double(data).abs() + + x = th.tensor(100) + plan_abs.build(x) + + # Test locally + assert plan_abs(x) == expected_res + + # Test remote + x_ptr = x.send(alice) + plan_abs_ptr = plan_abs.send(alice) + plan_abs_ptr(x_ptr) + + res = plan_abs_ptr(x_ptr) + assert res.get() == expected_res + + +def test_plan_nested_build_multiple_plans_state(hook, workers): + alice = workers["alice"] + expected_res = th.tensor(1043) + + with hook.local_worker.registration_enabled(): + + @sy.func2plan(args_shape=[(1,)], state=(th.tensor([3]),)) + def plan_3(data, state): + (bias,) = state.read() + return data + bias + 42 + + @sy.func2plan(args_shape=[(1,)]) + def plan_2_2(data): + return data + 1331 + + @sy.func2plan(args_shape=[(1,)], state=(th.tensor([2]),)) + def plan_2_1(data, state): + (bias,) = state.read() + return -2 * plan_3(data) + bias + + @sy.func2plan() + def plan_1(data): + res = plan_2_1(data) + return plan_2_2(res) + + # (-2 * (x + tensor(3) + 42) + tensor(2) + 1331) + # ------------------- + # plan_3 + # -------------------------------------- + # plan_2_1 + # ----------------------------------------------- + # plan_2_2 + + x = th.tensor(100) + plan_1.build(x) + + # Test locally + assert plan_1(x) == expected_res + + # Test remote + x_ptr = x.send(alice) + plan_1_ptr = plan_1.send(alice) + + res = plan_1_ptr(x_ptr) + assert res.get() == expected_res diff --git a/test/test_serde_full.py b/test/test_serde_full.py index c07c059140e..e513db7d4de 100644 --- a/test/test_serde_full.py +++ b/test/test_serde_full.py @@ -643,6 +643,8 @@ def compare(detailed, original): assert detailed.procedure.operations == original.procedure.operations assert detailed.procedure.arg_ids == original.procedure.arg_ids assert detailed.procedure.result_ids == original.procedure.result_ids + # States for the nested plans + assert detailed.nested_states == original.nested_states # State assert detailed.state.state_ids == original.state.state_ids assert detailed.include_state == original.include_state @@ -678,6 +680,7 @@ def compare(detailed, original): msgpack.serde._simplify( syft.hook.local_worker, plan.description ), # (str) description + msgpack.serde._simplify(syft.hook.local_worker, []), # (list of State) ), ), "cmp_detailed": compare, @@ -702,6 +705,7 @@ def compare(detailed, original): msgpack.serde._simplify( syft.hook.local_worker, model_plan.description ), # (str) description + msgpack.serde._simplify(syft.hook.local_worker, []), # (list of State) ), ), "cmp_detailed": compare,