Skip to content
Closed
15 changes: 6 additions & 9 deletions syft/messaging/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class Plan(AbstractObject, ObjectStorage):
state_ids: ids of the state elements
arg_ids: ids of the last argument ids (present in the procedure commands)
result_ids: ids of the last result ids (present in the procedure commands)
readable_plan: list of commands
operations: list of commands
blueprint: the function to be transformed into a plan
state_tensors: a tuple of state elements. It can be used to populate a state
id: state id
Expand All @@ -98,7 +98,7 @@ def __init__(
state_ids: List[Union[str, int]] = None,
arg_ids: List[Union[str, int]] = None,
result_ids: List[Union[str, int]] = None,
readable_plan: List = None,
operations: List = None,
blueprint=None,
state_tensors=None,
# General kwargs
Expand All @@ -120,7 +120,7 @@ def __init__(
self.nested_states = []

# Info about the plan stored via the state and the procedure
self.procedure = procedure or Procedure(readable_plan, arg_ids, result_ids)
self.procedure = procedure or Procedure(operations, arg_ids, result_ids)
self.state = state or State(owner=owner, plan=self, state_ids=state_ids)
if state_tensors is not None:
for tensor in state_tensors:
Expand Down Expand Up @@ -163,7 +163,7 @@ def location(self):

# For backward compatibility
@property
def readable_plan(self):
def operations(self):
return self.procedure.operations

def parameters(self):
Expand Down Expand Up @@ -214,11 +214,8 @@ def _recv_msg(self, bin_message: bin):

if isinstance(msg, Operation):
# Re-deserialize without detailing
(msg_type, contents) = sy.serde.deserialize(
bin_message, strategy=sy.serde.msgpack.serde._deserialize_msgpack_binary
)

self.procedure.operations.append((msg_type, contents))
self.procedure.operations.append(msg)

return sy.serde.serialize(None)

Expand Down Expand Up @@ -346,7 +343,7 @@ def __call__(self, *args, **kwargs):

def execute_commands(self):
for message in self.procedure.operations:
bin_message = sy.serde.serialize(message, simplified=True)
bin_message = sy.serde.serialize(message)
_ = self.owner.recv_msg(bin_message)

def run(self, args: Tuple, result_ids: List[Union[str, int]]):
Expand Down
110 changes: 47 additions & 63 deletions syft/messaging/plan/procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Procedure(object):
on different workers.

Args:
operations: the list of (serialized) operations
operations: the list of operations
arg_ids: the argument ids present in the operations
result_ids: the result ids present in the operations
"""
Expand Down Expand Up @@ -63,8 +63,8 @@ def update_ids(
self,
from_ids: Tuple[Union[str, int]] = [],
to_ids: Tuple[Union[str, int]] = [],
from_worker: Union[str, int] = None,
to_worker: Union[str, int] = None,
from_worker_id: Union[str, int] = None,
to_worker_id: Union[str, int] = None,
):
"""Replace ids and worker ids in the list of operations stored

Expand All @@ -74,74 +74,58 @@ def update_ids(
from_worker: The previous worker that built the plan.
to_worker: The new worker that is running the plan.
"""

# We operate on simplified content of Operation, hence all values should be simplified
from_workers_simplified = None
to_workers_simplified = None
if from_worker and to_worker:
from_workers_simplified = [sy.serde.msgpack.serde._simplify(None, from_worker)]
to_workers_simplified = [sy.serde.msgpack.serde._simplify(None, to_worker)]

from_ids_simplified = None
to_ids_simplified = None
if len(from_ids) and len(to_ids):
from_ids_simplified = [sy.serde.msgpack.serde._simplify(None, id) for id in from_ids]
to_ids_simplified = [sy.serde.msgpack.serde._simplify(None, id) for id in to_ids]

for idx, operation in enumerate(self.operations):
if from_workers_simplified and to_workers_simplified:
operation = Procedure.replace_operation_ids(
operation, from_workers_simplified, to_workers_simplified
)

if from_ids_simplified and to_ids_simplified:
operation = Procedure.replace_operation_ids(
operation, from_ids_simplified, to_ids_simplified
)
# replace ids in the owner
owner = operation.cmd_owner
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the type of operation now? I can't find the definition here it must be in the unchanged files

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an Operation object.

if owner is not None:
if owner.id in from_ids:
owner.id = to_ids[from_ids.index(owner.id)]

if owner.id_at_location in from_ids:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think owner have no id_at_location

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The owner can be a PointerTensor, so it can have an id_at_location property, but may not always. As written, this code may fail with non-pointer tensors. 🤔

owner.id_at_location = to_ids[from_ids.index(owner.id_at_location)]

# replace ids in the args
for arg in operation.cmd_args:
try:
if arg.id in from_ids:
arg.id = to_ids[from_ids.index(arg.id)]

if arg.id_at_location in from_ids:
arg.id_at_location = to_ids[from_ids.index(arg.id_at_location)]
except:
pass

# replace ids in the returns
return_ids = list(operation.return_ids)
for idx, return_id in enumerate(return_ids):
if return_id in from_ids:
return_ids[idx] = to_ids[from_ids.index(return_id)]
operation.return_ids = return_ids

# create a dummy worker
to_worker = sy.workers.virtual.VirtualWorker(None, to_worker_id)

# replace worker in the owner
if owner is not None and owner.location is not None:
if owner.location.id == from_worker_id:
owner.location = to_worker

# replace worker in the args
for arg in operation.cmd_args:
try:
if arg.location.id == from_worker_id:
arg.location = to_worker
except:
pass

self.operations[idx] = operation

return self

@staticmethod
def replace_operation_ids(operation, from_ids, to_ids):
"""
Replace ids in a single operation

Args:
operation: the operation to update
from_ids: the ids to replace
to_ids: the new ids to put inplace
"""

assert isinstance(from_ids, (list, tuple))
assert isinstance(to_ids, (list, tuple))

type_obj = type(operation)
operation = list(operation)
for i, item in enumerate(operation):
# Since this is simplified content, id can be int or simplified str (tuple)
if isinstance(item, (int, tuple)) and item in from_ids:
operation[i] = to_ids[from_ids.index(item)]
elif isinstance(item, (list, tuple)):
operation[i] = Procedure.replace_operation_ids(
operation=item, from_ids=from_ids, to_ids=to_ids
)
return type_obj(operation)

def copy(self) -> "Procedure":
procedure = Procedure(
operations=copy.deepcopy(self.operations),
arg_ids=self.arg_ids,
result_ids=self.result_ids,
)
return procedure

@staticmethod
def simplify(worker: AbstractWorker, procedure: "Procedure") -> tuple:
return (
# We're not simplifying fully because operations are already simplified
sy.serde.msgpack.serde._simplify(worker, procedure.operations, shallow=True),
sy.serde.msgpack.serde._simplify(worker, procedure.operations),
sy.serde.msgpack.serde._simplify(worker, procedure.arg_ids),
sy.serde.msgpack.serde._simplify(worker, procedure.result_ids),
sy.serde.msgpack.serde._simplify(worker, procedure.promise_out_id),
Expand All @@ -151,7 +135,7 @@ def simplify(worker: AbstractWorker, procedure: "Procedure") -> tuple:
def detail(worker: AbstractWorker, procedure_tuple: tuple) -> "State":
operations, arg_ids, result_ids, promise_out_id = procedure_tuple

operations = sy.serde.msgpack.serde._detail(worker, operations, shallow=True)
operations = sy.serde.msgpack.serde._detail(worker, operations)
arg_ids = sy.serde.msgpack.serde._detail(worker, arg_ids)
result_ids = sy.serde.msgpack.serde._detail(worker, result_ids)

Expand Down
Loading