Skip to content
Closed
36 changes: 21 additions & 15 deletions syft/serde/msgpack/torch_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
for all tensors (Torch and Numpy).
"""
from collections import OrderedDict
import torch
import io
from tempfile import TemporaryFile
from typing import Tuple, List
Expand Down Expand Up @@ -93,21 +94,36 @@ def _deserialize_tensor(worker: AbstractWorker, serializer: str, tensor_bin) ->
def numpy_tensor_serializer(worker: AbstractWorker, tensor: torch.Tensor) -> bin:
    """Strategy to serialize a tensor using numpy npy format.

    If the tensor requires gradients, a warning is emitted and the tensor
    is detached before it is converted to a numpy array.

    Args:
        worker: Unused by this strategy.
        tensor (torch.Tensor): an input tensor to be serialized.

    Returns:
        The npy-formatted bytes of the input tensor.
    """
    if tensor.requires_grad:
        warnings.warn(
            "Torch to Numpy serializer can only be used with tensors that do not require grad. "
            "Detaching tensor to continue"
        )
        # detach() is a method of the tensor instance; calling the
        # module-level torch.detach() with no argument raises TypeError.
        tensor = tensor.detach()

    np_tensor = tensor.numpy()

    # An in-memory buffer avoids creating (and leaking) a temporary file.
    outfile = io.BytesIO()
    numpy.save(outfile, np_tensor)
    # getvalue() returns the whole buffer regardless of the stream position,
    # so no seek(0) is required before reading.
    return outfile.getvalue()

def numpy_tensor_deserializer(tensor_bin) -> torch.Tensor:
    """Strategy to rebuild a Torch tensor from npy-formatted bytes.

    Args:
        tensor_bin: A binary (npy format) representation of a tensor.

    Returns:
        A Torch tensor created from the decoded numpy array.
    """
    # Wrap the raw bytes in a file-like object so numpy.load can read them.
    npy_buffer = io.BytesIO(tensor_bin)
    decoded_array = numpy.load(npy_buffer)
    return torch.from_numpy(decoded_array)

def generic_tensor_serializer(worker: AbstractWorker, tensor: torch.Tensor) -> tuple:
"""Strategy to serialize a tensor to native python types.
Expand All @@ -131,16 +147,6 @@ def generic_tensor_deserializer(worker: AbstractWorker, tensor_tuple: tuple) ->
tensor = torch.tensor(data_arr, dtype=TORCH_STR_DTYPE[dtype]).reshape(size)
return tensor


def numpy_tensor_deserializer(worker: AbstractWorker, tensor_bin) -> torch.Tensor:
    """Strategy to deserialize a binary input in npy format into a Torch tensor.

    Args:
        worker: Unused by this strategy.
        tensor_bin: A binary (npy format) representation of a tensor.

    Returns:
        A Torch tensor created from the decoded numpy array.
    """
    # Read straight from memory: a TemporaryFile that is never closed leaks
    # a file handle and adds a needless disk round-trip.
    bin_tensor_stream = io.BytesIO(tensor_bin)
    return torch.from_numpy(numpy.load(bin_tensor_stream))


def torch_tensor_serializer(worker: AbstractWorker, tensor) -> bin:
"""Strategy to serialize a tensor using Torch saver"""
binary_stream = io.BytesIO()
Expand Down
21 changes: 13 additions & 8 deletions test/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,17 +676,22 @@ def test_pointer_tensor_detail(id):
assert (x_back == 2 * x).all()


def test_numpy_tensor_serde():
compression._apply_compress_scheme = compression.apply_lz4_compression

@pytest.mark.parametrize(
"tensor",
[
(torch.tensor(numpy.ones((10, 10)), requires_grad=False)),
(torch.tensor([[0.25, 1.5], [0.15, 0.25], [1.25, 0.5]], requires_grad=True)),
(torch.randint(low=0, high=10, size=[3, 7], requires_grad=False)),
],
)
def test_numpy_tensor_serde(tensor):
serde._apply_compress_scheme = serde.apply_lz4_compression
serde._serialize_tensor = syft.serde.msgpack.torch_serde.numpy_tensor_serializer
serde._deserialize_tensor = syft.serde.msgpack.torch_serde.numpy_tensor_deserializer

tensor = torch.tensor(numpy.ones((10, 10)), requires_grad=False)

tensor_serialized = syft.serde.serialize(tensor)
assert tensor_serialized[0] != compression.NO_COMPRESSION
tensor_deserialized = syft.serde.deserialize(tensor_serialized)
tensor_serialized = serde.serialize(tensor)
assert tensor_serialized[0] != serde.NO_COMPRESSION
tensor_deserialized = serde.deserialize(tensor_serialized)

# Back to Pytorch serializer
serde._serialize_tensor = syft.serde.msgpack.torch_serde.torch_tensor_serializer
Expand Down