Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pip-dep/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ phe>=1.4.0
Pillow<7
requests==2.22.0
scipy>=1.4.1
syft-proto>=0.1.1.a1.post20
syft-proto>=0.2.1.a1.post2
tblib>=1.4.0
torch==1.4
torchvision==0.5.0
Expand Down
6 changes: 3 additions & 3 deletions syft/execution/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from syft.generic.object import AbstractObject
from syft.generic.object_storage import ObjectStorage
from syft.generic.pointers.pointer_plan import PointerPlan
from syft.messaging.message import Operation
from syft.messaging.message import OperationMessage
from syft.execution.state import State
from syft.workers.abstract import AbstractWorker
from syft.frameworks.torch.tensors.interpreters.placeholder import PlaceHolder
Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(
state: State = None,
include_state: bool = False,
is_built: bool = False,
operations: List[Operation] = None,
operations: List[OperationMessage] = None,
placeholders: Dict[Union[str, int], PlaceHolder] = None,
forward_func=None,
state_tensors=None,
Expand Down Expand Up @@ -302,7 +302,7 @@ def build(self, *args):
self.replace_with_placeholders(response, node_type="output"),
)
# We're cheating a bit here because we put placeholders instead of return_ids
operation = Operation(*command_placeholders, return_ids=return_placeholders)
operation = OperationMessage(*command_placeholders, return_ids=return_placeholders)
self.operations.append(operation)

sy.hook.trace.clear()
Expand Down
52 changes: 29 additions & 23 deletions syft/messaging/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __repr__(self):
return self.__str__()


class Operation(Message):
class OperationMessage(Message):
"""All syft operations use this message type

In Syft, an operation is when one worker wishes to tell another worker to do something with
Expand Down Expand Up @@ -154,12 +154,12 @@ def contents(self):
return (message, self.return_ids)

@staticmethod
def simplify(worker: AbstractWorker, ptr: "Operation") -> tuple:
def simplify(worker: AbstractWorker, ptr: "OperationMessage") -> tuple:
"""
This function takes the attributes of a Operation and saves them in a tuple
This function takes the attributes of a OperationMessage and saves them in a tuple
Args:
worker (AbstractWorker): a reference to the worker doing the serialization
ptr (Operation): a Message
ptr (OperationMessage): a Message
Returns:
tuple: a tuple holding the unique attributes of the message
Examples:
Expand All @@ -175,7 +175,7 @@ def simplify(worker: AbstractWorker, ptr: "Operation") -> tuple:
)

@staticmethod
def detail(worker: AbstractWorker, msg_tuple: tuple) -> "Operation":
def detail(worker: AbstractWorker, msg_tuple: tuple) -> "OperationMessage":
"""
This function takes the simplified tuple version of this message and converts
it into a Operation. The simplify() method runs the inverse of this method.
Expand All @@ -185,7 +185,7 @@ def detail(worker: AbstractWorker, msg_tuple: tuple) -> "Operation":
syft/serde/serde.py for more information on why this is necessary.
msg_tuple (Tuple): the raw information being detailed.
Returns:
ptr (Operation): a Operation.
ptr (OperationMessage): an OperationMessage.
Examples:
message = detail(sy.local_worker, msg_tuple)
"""
Expand All @@ -200,15 +200,15 @@ def detail(worker: AbstractWorker, msg_tuple: tuple) -> "Operation":
cmd_args = detailed_msg[2]
cmd_kwargs = detailed_msg[3]

return Operation(cmd_name, cmd_owner, cmd_args, cmd_kwargs, detailed_ids)
return OperationMessage(cmd_name, cmd_owner, cmd_args, cmd_kwargs, detailed_ids)

@staticmethod
def bufferize(worker: AbstractWorker, operation: "Operation") -> "OperationMessagePB":
def bufferize(worker: AbstractWorker, operation: "OperationMessage") -> "OperationMessagePB":
"""
This function takes the attributes of a Operation and saves them in Protobuf
This function takes the attributes of a OperationMessage and saves them in Protobuf
Args:
worker (AbstractWorker): a reference to the worker doing the serialization
ptr (Operation): a Message
ptr (OperationMessage): an OperationMessage
Returns:
protobuf_obj: a Protobuf message holding the unique attributes of the message
Examples:
Expand All @@ -232,12 +232,12 @@ def bufferize(worker: AbstractWorker, operation: "Operation") -> "OperationMessa
protobuf_owner.CopyFrom(sy.serde.protobuf.serde._bufferize(worker, operation.cmd_owner))

if operation.cmd_args:
protobuf_op.args.extend(Operation._bufferize_args(worker, operation.cmd_args))
protobuf_op.args.extend(OperationMessage._bufferize_args(worker, operation.cmd_args))

if operation.cmd_kwargs:
for key, value in operation.cmd_kwargs.items():
protobuf_op.kwargs.get_or_create(key).CopyFrom(
Operation._bufferize_arg(worker, value)
OperationMessage._bufferize_arg(worker, value)
)

if operation.return_ids is not None:
Expand All @@ -258,18 +258,20 @@ def bufferize(worker: AbstractWorker, operation: "Operation") -> "OperationMessa
return protobuf_op_msg

@staticmethod
def unbufferize(worker: AbstractWorker, protobuf_obj: "OperationMessagePB") -> "Operation":
def unbufferize(
worker: AbstractWorker, protobuf_obj: "OperationMessagePB"
) -> "OperationMessage":
"""
This function takes the Protobuf version of this message and converts
it into an Operation. The bufferize() method runs the inverse of this method.
it into an OperationMessage. The bufferize() method runs the inverse of this method.

Args:
worker (AbstractWorker): a reference to the worker necessary for detailing. Read
syft/serde/serde.py for more information on why this is necessary.
protobuf_obj (OperationPB): the Protobuf message
protobuf_obj (OperationMessagePB): the Protobuf message

Returns:
obj (Operation): an Operation
obj (OperationMessage): an OperationMessage

Examples:
message = unbufferize(sy.local_worker, protobuf_msg)
Expand All @@ -283,11 +285,13 @@ def unbufferize(worker: AbstractWorker, protobuf_obj: "OperationMessagePB") -> "
)
else:
owner = None
args = Operation._unbufferize_args(worker, protobuf_obj.operation.args)
args = OperationMessage._unbufferize_args(worker, protobuf_obj.operation.args)

kwargs = {}
for key in protobuf_obj.operation.kwargs:
kwargs[key] = Operation._unbufferize_arg(worker, protobuf_obj.operation.kwargs[key])
kwargs[key] = OperationMessage._unbufferize_arg(
worker, protobuf_obj.operation.kwargs[key]
)

return_ids = [
sy.serde.protobuf.proto.get_protobuf_id(pb_id)
Expand All @@ -301,21 +305,23 @@ def unbufferize(worker: AbstractWorker, protobuf_obj: "OperationMessagePB") -> "

if return_placeholders:
if len(return_placeholders) == 1:
operation_msg = Operation(
operation_msg = OperationMessage(
command, owner, tuple(args), kwargs, return_placeholders[0]
)
else:
operation_msg = Operation(command, owner, tuple(args), kwargs, return_placeholders)
operation_msg = OperationMessage(
command, owner, tuple(args), kwargs, return_placeholders
)
else:
operation_msg = Operation(command, owner, tuple(args), kwargs, tuple(return_ids))
operation_msg = OperationMessage(command, owner, tuple(args), kwargs, tuple(return_ids))

return operation_msg

@staticmethod
def _bufferize_args(worker: AbstractWorker, args: list) -> list:
protobuf_args = []
for arg in args:
protobuf_args.append(Operation._bufferize_arg(worker, arg))
protobuf_args.append(OperationMessage._bufferize_arg(worker, arg))
return protobuf_args

@staticmethod
Expand All @@ -333,7 +339,7 @@ def _bufferize_arg(worker: AbstractWorker, arg: object) -> ArgPB:
def _unbufferize_args(worker: AbstractWorker, protobuf_args: list) -> list:
args = []
for protobuf_arg in protobuf_args:
args.append(Operation._unbufferize_arg(worker, protobuf_arg))
args.append(OperationMessage._unbufferize_arg(worker, protobuf_arg))
return args

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions syft/serde/msgpack/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from syft.execution.state import State
from syft.execution.protocol import Protocol
from syft.messaging.message import Message
from syft.messaging.message import Operation
from syft.messaging.message import OperationMessage
from syft.messaging.message import ObjectMessage
from syft.messaging.message import ObjectRequestMessage
from syft.messaging.message import IsNoneMessage
Expand Down Expand Up @@ -120,7 +120,7 @@
BaseWorker,
AutogradTensor,
Message,
Operation,
OperationMessage,
ObjectMessage,
ObjectRequestMessage,
IsNoneMessage,
Expand Down
4 changes: 2 additions & 2 deletions syft/serde/protobuf/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from syft.frameworks.torch.tensors.interpreters.placeholder import PlaceHolder
from syft.generic.pointers.pointer_tensor import PointerTensor
from syft.messaging.message import ObjectMessage
from syft.messaging.message import Operation
from syft.messaging.message import OperationMessage
from syft.execution.plan import Plan
from syft.execution.protocol import Protocol
from syft.execution.state import State
Expand Down Expand Up @@ -54,7 +54,7 @@
# Syft types
AdditiveSharingTensor: AdditiveSharingTensorPB,
ObjectMessage: ObjectMessagePB,
Operation: OperationMessagePB,
OperationMessage: OperationMessagePB,
PlaceHolder: PlaceholderPB,
Plan: PlanPB,
PointerTensor: PointerTensorPB,
Expand Down
6 changes: 3 additions & 3 deletions syft/serde/protobuf/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from syft.frameworks.torch.tensors.interpreters.placeholder import PlaceHolder
from syft.generic.pointers.pointer_tensor import PointerTensor
from syft.messaging.message import ObjectMessage
from syft.messaging.message import Operation
from syft.messaging.message import OperationMessage
from syft.execution.plan import Plan
from syft.execution.protocol import Protocol
from syft.execution.state import State
Expand Down Expand Up @@ -41,7 +41,7 @@
OBJ_PROTOBUF_TRANSLATORS = [
AdditiveSharingTensor,
ObjectMessage,
Operation,
OperationMessage,
PlaceHolder,
Plan,
PointerTensor,
Expand Down Expand Up @@ -237,7 +237,7 @@ def serialize(
msg_wrapper.contents_empty_msg.CopyFrom(protobuf_obj)
elif obj_type == ObjectMessage:
msg_wrapper.contents_object_msg.CopyFrom(protobuf_obj)
elif obj_type == Operation:
elif obj_type == OperationMessage:
msg_wrapper.contents_operation_msg.CopyFrom(protobuf_obj)

# 2) Serialize
Expand Down
9 changes: 5 additions & 4 deletions syft/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from syft.generic.pointers.pointer_tensor import PointerTensor
from syft.messaging.message import Message
from syft.messaging.message import ForceObjectDeleteMessage
from syft.messaging.message import Operation
from syft.messaging.message import OperationMessage
from syft.messaging.message import ObjectMessage
from syft.messaging.message import ObjectRequestMessage
from syft.messaging.message import IsNoneMessage
Expand Down Expand Up @@ -114,7 +114,7 @@ def __init__(

# For performance, we cache all possible message types
self._message_router = {
Operation: self.execute_command,
OperationMessage: self.execute_command,
PlanCommandMessage: self.execute_plan_command,
ObjectMessage: self.set_obj,
ObjectRequestMessage: self.respond_to_obj_req,
Expand Down Expand Up @@ -514,7 +514,8 @@ def send_command(

try:
ret_val = self.send_msg(
Operation(cmd_name, cmd_owner, cmd_args, cmd_kwargs, return_ids), location=recipient
OperationMessage(cmd_name, cmd_owner, cmd_args, cmd_kwargs, return_ids),
location=recipient,
)
except ResponseSignatureError as e:
ret_val = None
Expand Down Expand Up @@ -1023,7 +1024,7 @@ def create_message_execute_command(
"""
if return_ids is None:
return_ids = []
return Operation(command_name, command_owner, args, kwargs, return_ids)
return OperationMessage(command_name, command_owner, args, kwargs, return_ids)

@property
def serializer(self, workers=None) -> codes.TENSOR_SERIALIZATION:
Expand Down
2 changes: 1 addition & 1 deletion test/message/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_cmd_message(workers):
x = th.tensor([1, 2, 3, 4]).send(bob)

y = x + x # this is the test
assert isinstance(bob._get_msg(-1), message.Operation)
assert isinstance(bob._get_msg(-1), message.OperationMessage)

y = y.get()

Expand Down
2 changes: 1 addition & 1 deletion test/serde/msgpack/test_msgpack_serde_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
samples[syft.frameworks.torch.tensors.interpreters.placeholder.PlaceHolder] = make_placeholder

samples[syft.messaging.message.Message] = make_message
samples[syft.messaging.message.Operation] = make_operation
samples[syft.messaging.message.OperationMessage] = make_operation
samples[syft.messaging.message.ObjectMessage] = make_objectmessage
samples[syft.messaging.message.ObjectRequestMessage] = make_objectrequestmessage
samples[syft.messaging.message.IsNoneMessage] = make_isnonemessage
Expand Down
2 changes: 1 addition & 1 deletion test/serde/protobuf/test_protobuf_serde_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

# Syft Messages
samples[syft.messaging.message.ObjectMessage] = make_objectmessage
samples[syft.messaging.message.Operation] = make_operation
samples[syft.messaging.message.OperationMessage] = make_operation


def test_serde_coverage():
Expand Down
8 changes: 4 additions & 4 deletions test/serde/serde_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,7 +1306,7 @@ def compare(detailed, original):
]


# syft.messaging.message.Operation
# syft.messaging.message.OperationMessage
def make_operation(**kwargs):
bob = kwargs["workers"]["bob"]
bob.log_msgs = True
Expand Down Expand Up @@ -1334,7 +1334,7 @@ def compare(detailed, original):
original.cmd_args,
original.cmd_kwargs,
)
assert type(detailed) == syft.messaging.message.Operation
assert type(detailed) == syft.messaging.message.OperationMessage
for i in range(len(original_msg)):
if type(original_msg[i]) != torch.Tensor:
assert detailed_msg[i] == original_msg[i]
Expand All @@ -1350,7 +1350,7 @@ def compare(detailed, original):
{
"value": op1,
"simplified": (
CODE[syft.messaging.message.Operation],
CODE[syft.messaging.message.OperationMessage],
(
msgpack.serde._simplify(syft.hook.local_worker, message1), # (Any) message
(CODE[tuple], (op1.return_ids[0],)), # (tuple) return_ids
Expand All @@ -1361,7 +1361,7 @@ def compare(detailed, original):
{
"value": op2,
"simplified": (
CODE[syft.messaging.message.Operation],
CODE[syft.messaging.message.OperationMessage],
(
msgpack.serde._simplify(syft.hook.local_worker, message2), # (Any) message
(CODE[tuple], (op2.return_ids[0],)), # (tuple) return_ids
Expand Down