Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion syft/frameworks/torch/hook/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from syft.messaging.promise import Promise

from syft.exceptions import route_method_exception
from syft.exceptions import TensorsNotCollocatedException
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was not needed



class TorchHook(FrameworkHook):
Expand Down
32 changes: 30 additions & 2 deletions syft/messaging/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __call__(self, plan_blueprint):
state_tensors=self.state_tensors,
include_state=self.include_state,
)

# Build the plan automatically
if self.args_shape:
args = Plan._create_placeholders(self.args_shape)
Expand Down Expand Up @@ -114,13 +115,18 @@ def __init__(
self.name = name or self.__class__.__name__
self.owner = owner

# If we have plans in plans we need to keep track of the states for each plan
# because we will need to serialize and send them to the remote workers
self.nested_states = []

# Info about the plan stored via the state and the procedure
self.procedure = procedure or Procedure(readable_plan, arg_ids, result_ids)
self.state = state or State(owner=owner, plan=self, state_ids=state_ids)
if state_tensors is not None:
for tensor in state_tensors:
self.state.state_ids.append(tensor.id)
self.owner.register_obj(tensor)

self.include_state = include_state
self.is_built = is_built
self.input_shapes = None
Expand Down Expand Up @@ -240,6 +246,8 @@ def build(self, *args):
cloned_state = self.state.clone_state_dict()
self.state.send_for_build(location=self)

self.owner.init_plan = self

# We usually have include_state==True for functions converted to plan
# using @func2plan and we need therefore to add the state manually
if self.include_state:
Expand All @@ -262,6 +270,7 @@ def build(self, *args):
self.procedure.result_ids = (res_ptr.id_at_location,)

self.is_built = True
self.owner.init_plan = None

def copy(self):
"""Creates a copy of a plan."""
Expand Down Expand Up @@ -310,17 +319,30 @@ def __call__(self, *args, **kwargs):
The pointer to the result of the execution if the plan was already sent,
else the None message serialized.
"""

cloned_state = None
if self.owner.init_plan:
cloned_state = self.state.clone_state_dict()
self.owner.init_plan.nested_states.append(self.state)
self.state.send_for_build(location=self.owner.init_plan)

if len(kwargs):
raise ValueError("Kwargs are not supported for plan.")

result_ids = [sy.ID_PROVIDER.pop()]

res = None
if self.forward is not None:
if self.include_state:
args = (*args, self.state)
return self.forward(*args)
res = self.forward(*args)
else:
return self.run(args, result_ids=result_ids)
res = self.run(args, result_ids=result_ids)

if self.owner.init_plan:
self.state.set_(cloned_state)

return res

def execute_commands(self):
for message in self.procedure.operations:
Expand Down Expand Up @@ -413,10 +435,12 @@ def send(self, *locations, force=False) -> PointerPlan:
return self.pointers[location]

self.procedure.update_worker_ids(self.owner.id, location.id)

# Send the Plan
pointer = self.owner.send(self, workers=location)
# Revert ids
self.procedure.update_worker_ids(location.id, self.owner.id)

self.pointers[location] = pointer
else:
ids_at_location = []
Expand Down Expand Up @@ -520,6 +544,7 @@ def simplify(worker: AbstractWorker, plan: "Plan") -> tuple:
sy.serde.msgpack.serde._simplify(worker, plan.name),
sy.serde.msgpack.serde._simplify(worker, plan.tags),
sy.serde.msgpack.serde._simplify(worker, plan.description),
sy.serde.msgpack.serde._simplify(worker, plan.nested_states),
)

@staticmethod
Expand All @@ -543,15 +568,18 @@ def detail(worker: AbstractWorker, plan_tuple: tuple) -> "Plan":
name,
tags,
description,
nested_states,
) = plan_tuple
id = sy.serde.msgpack.serde._detail(worker, id)
procedure = sy.serde.msgpack.serde._detail(worker, procedure)
state = sy.serde.msgpack.serde._detail(worker, state)
input_shapes = sy.serde.msgpack.serde._detail(worker, input_shapes)
output_shape = sy.serde.msgpack.serde._detail(worker, output_shape)
nested_states = sy.serde.msgpack.serde._detail(worker, nested_states)

plan = sy.Plan(owner=worker, id=id, include_state=include_state, is_built=is_built)

plan.nested_states = nested_states
plan.procedure = procedure
plan.state = state
state.plan = plan
Expand Down
15 changes: 11 additions & 4 deletions syft/messaging/plan/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,17 @@ def __init__(self, owner, plan=None, state_ids=None):
self.plan = plan
self.state_ids = state_ids or []

def __str__(self):
"""Returns the string representation of the State."""
out = "<"
out += "State:"
for state_id in self.state_ids:
out += " {}".format(state_id)
out += ">"
return out

def __repr__(self):
return "State: " + ", ".join(self.state_ids)
return self.__str__()

def tensors(self) -> List:
"""
Expand Down Expand Up @@ -84,10 +93,8 @@ def create_grad_if_missing(tensor):
def send_for_build(self, location, **kwargs):
"""
Send functionality that can only be used when sending the state for
building the plan. Other than this, you shouldn't need to send the
state separately.
building the plan.
"""
assert location.id == self.plan.id # ensure this is a send for the build

for tensor in self.tensors():
self.create_grad_if_missing(tensor)
Expand Down
5 changes: 4 additions & 1 deletion syft/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class BaseWorker(AbstractWorker, ObjectStorage):
variables are instantiated or deleted as opposed to handling
tensor/variable/model lifecycle internally. Set to True if this
object is not where the objects will be stored, but is instead
a pointer to a worker that eists elsewhere.
a pointer to a worker that exists elsewhere.
log_msgs: An optional boolean parameter to indicate whether all
messages should be saved into a log for later review. This is
primarily a development/testing feature.
Expand Down Expand Up @@ -154,6 +154,9 @@ def __init__(
# self is the to-be-created local worker
self.add_worker(self)

# Used to keep track of a building plan
self.init_plan = None

if hook is None:
self.framework = None
else:
Expand Down
155 changes: 155 additions & 0 deletions test/message/test_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,3 +1046,158 @@ def plan_abs(data):
pointers = plan_abs.get_pointers()

assert len(pointers) == 2


def test_plan_nested_no_build_inner(workers):
alice = workers["alice"]
expected_res = th.tensor(200)

@sy.func2plan()
def plan_double(data):
return 2 * data

@sy.func2plan()
def plan_abs(data):
return plan_double(data).abs()

x = th.tensor(100)
plan_abs.build(x)

# Run plan locally
assert plan_abs(x) == expected_res

# Run plan remote
x_ptr = x.send(alice)
plan_abs_ptr = plan_abs.send(alice)
res = plan_abs_ptr(x_ptr)

assert res.get() == expected_res


def test_plan_nested_build_inner_plan_before(workers):
alice = workers["alice"]
expected_res = th.tensor(200)

@sy.func2plan(args_shape=[(1,)])
def plan_double(data):
return -2 * data

@sy.func2plan()
def plan_abs(data):
return plan_double(data).abs()

x = th.tensor(100)
plan_abs.build(x)

# Run plan locally
assert plan_abs(x) == expected_res

x_ptr = x.send(alice)
plan_abs_ptr = plan_abs.send(alice)
res = plan_abs_ptr(x_ptr)

assert res.get() == expected_res


def test_plan_nested_build_inner_plan_after(workers):
alice = workers["alice"]
expected_res = th.tensor(200)

@sy.func2plan()
def plan_double(data):
return -2 * data

@sy.func2plan()
def plan_abs(data):
return plan_double(data).abs()

x = th.tensor(100)
plan_abs.build(x)
plan_double.build(x)

# Test locally
assert plan_abs(x) == expected_res

# Test remote
x_ptr = x.send(alice)
plan_double_ptr = plan_abs.send(alice)
res = plan_double_ptr(x_ptr)

assert res.get() == expected_res


def test_plan_nested_build_inner_plan_state(hook, workers):
alice = workers["alice"]
expected_res = th.tensor(199)

with hook.local_worker.registration_enabled():

@sy.func2plan(args_shape=[(1,)], state=(th.tensor([1]),))
def plan_double(data, state):
(bias,) = state.read()
return -2 * data + bias

@sy.func2plan()
def plan_abs(data):
return plan_double(data).abs()

x = th.tensor(100)
plan_abs.build(x)

# Test locally
assert plan_abs(x) == expected_res

# Test remote
x_ptr = x.send(alice)
plan_abs_ptr = plan_abs.send(alice)
plan_abs_ptr(x_ptr)

res = plan_abs_ptr(x_ptr)
assert res.get() == expected_res


def test_plan_nested_build_multiple_plans_state(hook, workers):
alice = workers["alice"]
expected_res = th.tensor(1043)

with hook.local_worker.registration_enabled():

@sy.func2plan(args_shape=[(1,)], state=(th.tensor([3]),))
def plan_3(data, state):
(bias,) = state.read()
return data + bias + 42

@sy.func2plan(args_shape=[(1,)])
def plan_2_2(data):
return data + 1331

@sy.func2plan(args_shape=[(1,)], state=(th.tensor([2]),))
def plan_2_1(data, state):
(bias,) = state.read()
return -2 * plan_3(data) + bias

@sy.func2plan()
def plan_1(data):
res = plan_2_1(data)
return plan_2_2(res)

# (-2 * (x + tensor(3) + 42) + tensor(2) + 1331)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can remove this - It was only a simple way that I thought of to show what each plan/function is doing

# -------------------
# plan_3
# --------------------------------------
# plan_2_1
# -----------------------------------------------
# plan_2_2

x = th.tensor(100)
plan_1.build(x)

# Test locally
assert plan_1(x) == expected_res

# Test remote
x_ptr = x.send(alice)
plan_1_ptr = plan_1.send(alice)

res = plan_1_ptr(x_ptr)
assert res.get() == expected_res
4 changes: 4 additions & 0 deletions test/test_serde_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,8 @@ def compare(detailed, original):
assert detailed.procedure.operations == original.procedure.operations
assert detailed.procedure.arg_ids == original.procedure.arg_ids
assert detailed.procedure.result_ids == original.procedure.result_ids
# States for the nested plans
assert detailed.nested_states == original.nested_states
# State
assert detailed.state.state_ids == original.state.state_ids
assert detailed.include_state == original.include_state
Expand Down Expand Up @@ -678,6 +680,7 @@ def compare(detailed, original):
msgpack.serde._simplify(
syft.hook.local_worker, plan.description
), # (str) description
msgpack.serde._simplify(syft.hook.local_worker, []), # (list of State)
),
),
"cmp_detailed": compare,
Expand All @@ -702,6 +705,7 @@ def compare(detailed, original):
msgpack.serde._simplify(
syft.hook.local_worker, model_plan.description
), # (str) description
msgpack.serde._simplify(syft.hook.local_worker, []), # (list of State)
),
),
"cmp_detailed": compare,
Expand Down