Skip to content
Closed
76 changes: 76 additions & 0 deletions syft/frameworks/torch/tensors/interpreters/additive_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from syft.generic.frameworks.overload import overloaded
from syft.workers.abstract import AbstractWorker

from syft_proto.frameworks.torch.tensors.interpreters.v1.additive_shared_pb2 import (
AdditiveSharingTensor as AdditiveSharingTensorPB,
)
from syft_proto.types.syft.v1.id_pb2 import Id as IdPB

no_wrap = {"no_wrap": True}


Expand Down Expand Up @@ -1023,6 +1028,77 @@ def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "AdditiveSharingTenso

return tensor

@staticmethod
def bufferize(
worker: AbstractWorker, tensor: "AdditiveSharingTensor"
) -> "AdditiveSharingTensorPB":
"""
This function takes the attributes of a AdditiveSharingTensor and saves them in a protobuf object
Args:
tensor (AdditiveSharingTensor): a AdditiveSharingTensor
Returns:
protobuf: a protobuf object holding the unique attributes of the additive shared tensor
Examples:
data = protobuf(tensor)
"""

location_ids = []
shares = []
if hasattr(tensor, "child"):
for key, value in tensor.child.items():
location_ids.append(sy.serde.protobuf.serde.create_protobuf_id(key))
shares.append(value)

# Don't delete the remote values of the shares at simplification
tensor.set_garbage_collect_data(False)

protobuf_tensor = AdditiveSharingTensorPB()
protobuf_tensor.id.CopyFrom(sy.serde.protobuf.serde.create_protobuf_id(tensor.id))
protobuf_tensor.field_size = tensor.field
protobuf_tensor.crypto_provider_id.CopyFrom(
sy.serde.protobuf.serde.create_protobuf_id(tensor.crypto_provider.id)
)
protobuf_tensor.location_ids.extend(location_ids)
protobuf_tensor.shares.extend(shares)

return protobuf_tensor

@staticmethod
def unbufferize(
worker: AbstractWorker, protobuf_tensor: "AdditiveSharingTensorPB"
) -> "AdditiveSharingTensor":
"""
This function reconstructs a AdditiveSharingTensor given its' attributes in form of a protobuf object.
Args:
worker: the worker doing the deserialization
protobuf_tensor: a protobuf object holding the attributes of the AdditiveSharingTensor
Returns:
AdditiveSharingTensor: a AdditiveSharingTensor
Examples:
shared_tensor = unprotobuf(data)
"""

tensor_id = getattr(protobuf_tensor, protobuf_tensor.WhichOneof("id"))
field = protobuf_tensor.field_size
crypto_provider_id = getattr(
protobuf_tensor, protobuf_tensor.WhichOneof("crypto_provider_id")
)

tensor = AdditiveSharingTensor(
owner=worker,
id=tensor_id,
field=field,
crypto_provider=worker.get_worker(crypto_provider_id),
)

if protobuf_tensor.location_ids is not None:
chain = {}
for location_id, share in zip(protobuf_tensor.location_ids, protobuf_tensor.shares):
chain[location_id] = share
tensor.child = chain

return tensor


### Register the tensor with hook_args.py ###
hook_args.default_register_tensor(AdditiveSharingTensor)
2 changes: 1 addition & 1 deletion syft/generic/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def serialize(self): # check serde.py to see how to provide compression schemes
x = torch.Tensor([1,2,3,4,5])
x.serialize() # returns a serialized object
"""
return sy.serde.serialize(self)
return sy.serde.protobuf.serde.serialize(self)

def ser(self, *args, **kwargs):
return self.serialize(*args, **kwargs)
Expand Down
89 changes: 89 additions & 0 deletions syft/messaging/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from syft.workers.abstract import AbstractWorker


from syft_proto.messaging.v1.message_pb2 import ObjectMessage as ObjectMessagePB
from syft_proto.messaging.v1.message_pb2 import OperationMessage as OperationMessagePB
from syft_proto.types.syft.v1.operation_pb2 import Operation as OperationPB


class Message:
"""All syft message types extend this class

Expand Down Expand Up @@ -195,6 +200,65 @@ def detail(worker: AbstractWorker, msg_tuple: tuple) -> "Operation":

return Operation(cmd_name, cmd_owner, cmd_args, cmd_kwargs, detailed_ids)

@staticmethod
def bufferize(worker: AbstractWorker, operation: "Operation") -> "OperationMessagePB":
"""
This function takes the attributes of a Operation and saves them in Protobuf
Args:
worker (AbstractWorker): a reference to the worker doing the serialization
ptr (Operation): a Message
Returns:
protobuf_obj: a Protobuf message holding the unique attributes of the message
Examples:
data = bufferize(message)
"""

protobuf_op_msg = OperationMessagePB()
protobuf_op = OperationPB()
protobuf_op.command = operation.cmd_name
# protobuf_op.owner = sy.serde.protobuf.serde._bufferize(worker, operation.cmd_owner)
if operation.cmd_args:
protobuf_op.args = sy.serde.protobuf.serde._bufferize(worker, operation.cmd_args)
if operation.cmd_kwargs:
protobuf_op.kwargs = sy.serde.protobuf.serde._bufferize(worker, operation.cmd_kwargs)

return_ids = []
for return_id in operation.return_ids:
return_ids.append(sy.serde.protobuf.serde.create_protobuf_id(return_id))

protobuf_op.return_ids.extend(return_ids)

protobuf_op_msg.operation.CopyFrom(protobuf_op)
return protobuf_op_msg

@staticmethod
def unbufferize(worker: AbstractWorker, protobuf_obj: "OperationMessagePB") -> "Operation":
"""
This function takes the Protobuf version of this message and converts
it into an Operation. The bufferize() method runs the inverse of this method.

Args:
worker (AbstractWorker): a reference to the worker necessary for detailing. Read
syft/serde/serde.py for more information on why this is necessary.
protobuf_obj (OperationPB): the Protobuf message

Returns:
obj (Operation): an Operation

Examples:
message = unbufferize(sy.local_worker, protobuf_msg)
"""

command = protobuf_obj.operation.command
# owner
# Args
# kwargs
return_ids = protobuf_obj.operation.return_ids

operation_msg = Operation(command, [], [], [], return_ids)

return operation_msg


class ObjectMessage(Message):
"""Send an object to another worker using this message type.
Expand Down Expand Up @@ -227,6 +291,31 @@ def detail(worker: AbstractWorker, msg_tuple: tuple) -> "ObjectMessage":
"""
return ObjectMessage(sy.serde.msgpack.serde._detail(worker, msg_tuple[0]))

@staticmethod
def bufferize(worker: AbstractWorker, message: "ObjectMessage") -> "ObjectMessagePB":
"""
This function takes the attributes of an Object Message and saves them in a protobuf object
Args:
message (ObjectMessage): an ObjectMessage
Returns:
protobuf: a protobuf object holding the unique attributes of the object message
Examples:
data = bufferize(object_message)
"""

protobuf_obj_msg = ObjectMessagePB()
bufferized_contents = sy.serde.protobuf.serde._bufferize(worker, message.contents)
protobuf_obj_msg.tensor.CopyFrom(bufferized_contents)
return protobuf_obj_msg

@staticmethod
def unbufferize(worker: AbstractWorker, protobuf_obj: "ObjectMessagePB") -> "ObjectMessage":
protobuf_contents = protobuf_obj.tensor
contents = sy.serde.protobuf.serde._unbufferize(worker, protobuf_contents)
object_msg = ObjectMessage(contents)

return object_msg


class ObjectRequestMessage(Message):
"""Request another worker to send one of its objects
Expand Down
4 changes: 2 additions & 2 deletions syft/messaging/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _recv_msg(self, bin_message: bin):

self.procedure.operations.append((msg_type, contents))

return sy.serde.serialize(None)
return sy.serde.protobuf.serde.serialize(None)

def build(self, *args):
"""Builds the plan.
Expand Down Expand Up @@ -324,7 +324,7 @@ def __call__(self, *args, **kwargs):

def execute_commands(self):
for message in self.procedure.operations:
bin_message = sy.serde.serialize(message, simplified=True)
bin_message = sy.serde.protobuf.serde.serialize(message, simplified=True)
_ = self.owner.recv_msg(bin_message)

def run(self, args: Tuple, result_ids: List[Union[str, int]]):
Expand Down
40 changes: 40 additions & 0 deletions syft/serde/protobuf/native_serde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
This file exists to provide a common place for all Protobuf
serialisation for native Python objects. If you're adding
something here that isn't for `None`, think twice and either
use an existing sub-class of Message or add a new one.
"""

from collections import OrderedDict
from google.protobuf.empty_pb2 import Empty
from syft.workers.abstract import AbstractWorker


def _bufferize_none(worker: AbstractWorker, obj: "type(None)") -> "Empty":
"""
This function converts None into an empty Protobuf message.

Args:
obj (None): makes signature match other bufferize methods

Returns:
protobuf_obj: Empty Protobuf message
"""
return Empty()


def _unbufferize_none(worker: AbstractWorker, obj: "Empty") -> "type(None)":
"""
This function converts an empty Protobuf message back into None.

Args:
obj (Empty): Empty Protobuf message

Returns:
obj: None
"""
return None


# Maps a type to its bufferizer and unbufferizer functions
MAP_NATIVE_PROTOBUF_TRANSLATORS = OrderedDict({type(None): (_bufferize_none, _unbufferize_none)})
36 changes: 36 additions & 0 deletions syft/serde/protobuf/proto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
This file exists to translate python classes to and from Protobuf messages.
The reason for this is to have stable serialization protocol that can be used
not only by PySyft but also in other languages.

https://github.com/OpenMined/syft-proto (`syft_proto` module) is included as
a dependency in setup.py.
"""
import torch

from google.protobuf.empty_pb2 import Empty

from syft.messaging.message import ObjectMessage
from syft.messaging.message import Operation
from syft.frameworks.torch.tensors.interpreters.additive_shared import AdditiveSharingTensor

from syft_proto.frameworks.torch.tensors.interpreters.v1.additive_shared_pb2 import (
AdditiveSharingTensor as AdditiveSharingTensorPB,
)
from syft_proto.messaging.v1.message_pb2 import ObjectMessage as ObjectMessagePB
from syft_proto.messaging.v1.message_pb2 import OperationMessage as OperationMessagePB
from syft_proto.generic.v1.tensor_pb2 import Tensor as TensorPB


MAP_PYTHON_TO_PROTOBUF_CLASSES = {
ObjectMessage: ObjectMessagePB,
Operation: OperationMessagePB,
torch.Tensor: TensorPB,
AdditiveSharingTensor: AdditiveSharingTensorPB,
type(None): Empty,
}

MAP_PROTOBUF_TO_PYTHON_CLASSES = {}

for key, value in MAP_PYTHON_TO_PROTOBUF_CLASSES.items():
MAP_PROTOBUF_TO_PYTHON_CLASSES[value] = key
Loading