Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions syft/frameworks/torch/tensors/decorators/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def simplify(worker: AbstractWorker, tensor: "LoggingTensor") -> tuple:
chain = None
if hasattr(tensor, "child"):
chain = sy.serde._simplify(worker, tensor.child)
return tensor.id, chain
return (sy.serde._simplify(worker, tensor.id), chain)

@staticmethod
def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "LoggingTensor":
Expand All @@ -161,7 +161,7 @@ def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "LoggingTensor":
"""
obj_id, chain = tensor_tuple

tensor = LoggingTensor(owner=worker, id=obj_id)
tensor = LoggingTensor(owner=worker, id=sy.serde._detail(worker, obj_id))

if chain is not None:
chain = sy.serde._detail(worker, chain)
Expand Down
10 changes: 8 additions & 2 deletions syft/frameworks/torch/tensors/interpreters/additive_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,12 @@ def simplify(worker: AbstractWorker, tensor: "AdditiveSharingTensor") -> tuple:
# Don't delete the remote values of the shares at simplification
tensor.set_garbage_collect_data(False)

return (tensor.id, tensor.field, tensor.crypto_provider.id, chain)
return (
sy.serde._simplify(worker, tensor.id),
tensor.field,
sy.serde._simplify(worker, tensor.crypto_provider.id),
chain,
)

@staticmethod
def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "AdditiveSharingTensor":
Expand All @@ -1003,10 +1008,11 @@ def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "AdditiveSharingTenso
"""

tensor_id, field, crypto_provider, chain = tensor_tuple
crypto_provider = sy.serde._detail(worker, crypto_provider)

tensor = AdditiveSharingTensor(
owner=worker,
id=tensor_id,
id=sy.serde._detail(worker, tensor_id),
field=field,
crypto_provider=worker.get_worker(crypto_provider),
)
Expand Down
2 changes: 1 addition & 1 deletion syft/frameworks/torch/tensors/interpreters/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
the tensor is located.
id: An optional string or integer id of the FixedPrecisionTensor.
"""
super().__init__(tags, description)
super().__init__(tags=tags, description=description)

self.owner = owner
self.id = id if id else syft.ID_PROVIDER.pop()
Expand Down
5 changes: 3 additions & 2 deletions syft/generic/pointers/multi_pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def simplify(worker: AbstractWorker, tensor: "MultiPointerTensor") -> tuple:
chain = None
if hasattr(tensor, "child"):
chain = sy.serde._simplify(worker, tensor.child)
return tensor.id, chain

return (sy.serde._simplify(worker, tensor.id), chain)

@staticmethod
def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "MultiPointerTensor":
Expand All @@ -250,7 +251,7 @@ def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "MultiPointerTensor":

tensor_id, chain = tensor_tuple

tensor = sy.MultiPointerTensor(owner=worker, id=tensor_id)
tensor = sy.MultiPointerTensor(owner=worker, id=sy.serde._detail(worker, tensor_id))

if chain is not None:
chain = sy.serde._detail(worker, chain)
Expand Down
18 changes: 10 additions & 8 deletions syft/generic/pointers/object_pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def get(self, user=None, reason: str = "", deregister_ptr: bool = True):

Args:
user (obj, optional) : authenticate/allow user to perform get on remote private objects.
reason (str, optional) : a description of why the data scientist wants to see it.
reason (str, optional) : a description of why the data scientist wants to see it.
deregister_ptr (bool, optional): this determines whether to
deregister this pointer from the pointer's owner during this
method. This defaults to True because the main reason people use
Expand Down Expand Up @@ -379,10 +379,10 @@ def simplify(worker: AbstractWorker, ptr: "ObjectPointer") -> tuple:
"""

return (
ptr.id,
ptr.id_at_location,
ptr.location.id,
ptr.point_to_attr,
syft.serde._simplify(worker, ptr.id),
syft.serde._simplify(worker, ptr.id_at_location),
syft.serde._simplify(worker, ptr.location.id),
syft.serde._simplify(worker, ptr.point_to_attr),
ptr.garbage_collect_data,
)

Expand All @@ -403,16 +403,18 @@ def detail(worker: "AbstractWorker", object_tuple: tuple) -> "ObjectPointer":
# TODO: fix comment for this and simplifier
obj_id, id_at_location, worker_id, point_to_attr, garbage_collect_data = object_tuple

if isinstance(worker_id, bytes):
worker_id = worker_id.decode()
obj_id = syft.serde._detail(worker, obj_id)
id_at_location = syft.serde._detail(worker, id_at_location)
worker_id = syft.serde._detail(worker, worker_id)
point_to_attr = syft.serde._detail(worker, point_to_attr)

# If the pointer received is pointing at the current worker, we load the tensor instead
if worker_id == worker.id:
obj = worker.get_obj(id_at_location)

if point_to_attr is not None and obj is not None:

point_to_attrs = point_to_attr.decode("utf-8").split(".")
point_to_attrs = point_to_attr.split(".")
for attr in point_to_attrs:
if len(attr) > 0:
obj = getattr(obj, attr)
Expand Down
12 changes: 9 additions & 3 deletions syft/generic/pointers/pointer_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,21 @@ def get(self, deregister_ptr: bool = True):
@staticmethod
def simplify(worker: AbstractWorker, ptr: "PointerPlan") -> tuple:

return (ptr.id, ptr.id_at_location, ptr.location.id, ptr.garbage_collect_data)
return (
sy.serde._simplify(worker, ptr.id),
sy.serde._simplify(worker, ptr.id_at_location),
sy.serde._simplify(worker, ptr.location.id),
ptr.garbage_collect_data,
)

@staticmethod
def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "PointerPlan":
# TODO: fix comment for this and simplifier
obj_id, id_at_location, worker_id, garbage_collect_data = tensor_tuple

if isinstance(worker_id, bytes):
worker_id = worker_id.decode()
obj_id = sy.serde._detail(worker, obj_id)
id_at_location = sy.serde._detail(worker, id_at_location)
worker_id = sy.serde._detail(worker, worker_id)

# If the pointer received is pointing at the current worker, we load the tensor instead
if worker_id == worker.id:
Expand Down
11 changes: 8 additions & 3 deletions syft/generic/pointers/pointer_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,20 @@ def get(self, deregister_ptr: bool = True):
@staticmethod
def simplify(worker: AbstractWorker, ptr: "PointerPlan") -> tuple:

return (ptr.id, ptr.id_at_location, ptr.location.id, ptr.garbage_collect_data)
return (
sy.serde._simplify(worker, ptr.id),
ptr.id_at_location,
sy.serde._simplify(worker, ptr.location.id),
ptr.garbage_collect_data,
)

@staticmethod
def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "PointerPlan":
# TODO: fix comment for this and simplifier
obj_id, id_at_location, worker_id, garbage_collect_data = tensor_tuple

if isinstance(worker_id, bytes):
worker_id = worker_id.decode()
obj_id = sy.serde._detail(worker, obj_id)
worker_id = sy.serde._detail(worker, worker_id)

# If the pointer received is pointing at the current worker, we load the tensor instead
if worker_id == worker.id:
Expand Down
19 changes: 11 additions & 8 deletions syft/generic/pointers/pointer_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def remote_send(self, destination, change_location=False):
C will hold a pointer to a pointer on A which points to the tensor on B.
If change_location is set to True, the original pointer will point to the moved object.
Considering the same example as before with ptr.remote_send(B, change_location=True):
C will hold a pointer to the tensor on B. We may need to be careful here because this pointer
C will hold a pointer to the tensor on B. We may need to be careful here because this pointer
will have 2 references pointing to it.
"""
args = (destination,)
Expand Down Expand Up @@ -458,10 +458,11 @@ def simplify(worker: AbstractWorker, ptr: "PointerTensor") -> tuple:
"""

return (
ptr.id,
ptr.id_at_location,
ptr.location.id,
ptr.point_to_attr,
# ptr.id,
syft.serde._simplify(worker, ptr.id),
syft.serde._simplify(worker, ptr.id_at_location),
syft.serde._simplify(worker, ptr.location.id),
syft.serde._simplify(worker, ptr.point_to_attr),
syft.serde._simplify(worker, ptr._shape),
ptr.garbage_collect_data,
)
Expand Down Expand Up @@ -491,8 +492,10 @@ def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "PointerTensor":
# TODO: fix comment for this and simplifier
obj_id, id_at_location, worker_id, point_to_attr, shape, garbage_collect_data = tensor_tuple

if isinstance(worker_id, bytes):
worker_id = worker_id.decode()
obj_id = syft.serde._detail(worker, obj_id)
id_at_location = syft.serde._detail(worker, id_at_location)
worker_id = syft.serde._detail(worker, worker_id)
point_to_attr = syft.serde._detail(worker, point_to_attr)

if shape is not None:
shape = syft.hook.create_shape(syft.serde._detail(worker, shape))
Expand All @@ -503,7 +506,7 @@ def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "PointerTensor":

if point_to_attr is not None and tensor is not None:

point_to_attrs = point_to_attr.decode("utf-8").split(".")
point_to_attrs = point_to_attr.split(".")
for attr in point_to_attrs:
if len(attr) > 0:
tensor = getattr(tensor, attr)
Expand Down
6 changes: 4 additions & 2 deletions syft/messaging/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def simplify(worker: AbstractWorker, ptr: "Operation") -> tuple:
"""
# NOTE: we can skip calling _simplify on return_ids because they should already be
# a list of simple types.
return (sy.serde._simplify(worker, ptr.message), ptr.return_ids)
return (sy.serde._simplify(worker, ptr.message), sy.serde._simplify(worker, ptr.return_ids))

@staticmethod
def detail(worker: AbstractWorker, msg_tuple: tuple) -> "Operation":
Expand All @@ -172,7 +172,9 @@ def detail(worker: AbstractWorker, msg_tuple: tuple) -> "Operation":
Examples:
message = detail(sy.local_worker, msg_tuple)
"""
return Operation(sy.serde._detail(worker, msg_tuple[0]), msg_tuple[1])
return Operation(
sy.serde._detail(worker, msg_tuple[0]), sy.serde._detail(worker, msg_tuple[1])
)


class ObjectMessage(Message):
Expand Down
41 changes: 28 additions & 13 deletions syft/messaging/plan/procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,30 @@ def update_ids(
from_worker: The previous worker that built the plan.
to_worker: The new worker that is running the plan.
"""

# We operate on simplified content of Operation, hence all values should be simplified
from_workers_simplified = None
to_workers_simplified = None
if from_worker and to_worker:
Copy link
Member

@gmuraru gmuraru Dec 11, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Will we always have from_worker and to_worker. Is there any case only one is != None?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think replacement doesn't make sense when any of them are None.
This was existing code, I just moved it out of the loop to avoid making simplification each time.

from_workers_simplified = [sy.serde._simplify(None, from_worker)]
to_workers_simplified = [sy.serde._simplify(None, to_worker)]

from_ids_simplified = None
to_ids_simplified = None
if len(from_ids) and len(to_ids):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: The same case as above for here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above :)

from_ids_simplified = [sy.serde._simplify(None, id) for id in from_ids]
to_ids_simplified = [sy.serde._simplify(None, id) for id in to_ids]

for idx, operation in enumerate(self.operations):
if from_worker and to_worker:
from_workers, to_workers = [from_worker], [to_worker]
if isinstance(from_worker, str):
from_workers.append(from_worker.encode("utf-8"))
to_workers.append(to_worker)
operation = Procedure.replace_operation_ids(operation, from_workers, to_workers)
if from_workers_simplified and to_workers_simplified:
operation = Procedure.replace_operation_ids(
operation, from_workers_simplified, to_workers_simplified
)

if len(from_ids) and len(to_ids):
operation = Procedure.replace_operation_ids(operation, from_ids, to_ids)
if from_ids_simplified and to_ids_simplified:
operation = Procedure.replace_operation_ids(
operation, from_ids_simplified, to_ids_simplified
)

self.operations[idx] = operation

Expand All @@ -106,7 +120,8 @@ def replace_operation_ids(operation, from_ids, to_ids):
type_obj = type(operation)
operation = list(operation)
for i, item in enumerate(operation):
if isinstance(item, (int, str, bytes)) and item in from_ids:
# Since this is simplified content, id can be int or simplified str (tuple)
if isinstance(item, (int, tuple)) and item in from_ids:
operation[i] = to_ids[from_ids.index(item)]
elif isinstance(item, (list, tuple)):
operation[i] = Procedure.replace_operation_ids(
Expand All @@ -125,9 +140,8 @@ def copy(self) -> "Procedure":
@staticmethod
def simplify(worker: AbstractWorker, procedure: "Procedure") -> tuple:
return (
tuple(
procedure.operations
), # We're not simplifying because operations are already simplified
# We're not simplifying fully because operations are already simplified
sy.serde._simplify(worker, procedure.operations, shallow=True),
sy.serde._simplify(worker, procedure.arg_ids),
sy.serde._simplify(worker, procedure.result_ids),
sy.serde._simplify(worker, procedure.promise_out_id),
Expand All @@ -136,7 +150,8 @@ def simplify(worker: AbstractWorker, procedure: "Procedure") -> tuple:
@staticmethod
def detail(worker: AbstractWorker, procedure_tuple: tuple) -> "State":
operations, arg_ids, result_ids, promise_out_id = procedure_tuple
operations = list(operations)

operations = sy.serde._detail(worker, operations, shallow=True)
arg_ids = sy.serde._detail(worker, arg_ids)
result_ids = sy.serde._detail(worker, result_ids)

Expand Down
Loading