Continue the work on Promises and PromiseTensors #2610
@@ -10,9 +10,11 @@ | |
| import syft | ||
| from syft.generic.frameworks.hook import hook_args | ||
| from syft.generic.frameworks.hook.hook import FrameworkHook | ||
| from syft.generic.tensor import AbstractTensor | ||
| from syft.generic.frameworks.remote import Remote | ||
| from syft.frameworks.torch.tensors.interpreters.autograd import AutogradTensor | ||
| from syft.frameworks.torch.tensors.interpreters.native import TorchTensor | ||
| from syft.frameworks.torch.tensors.interpreters.promise import PromiseTensor | ||
| from syft.frameworks.torch.tensors.interpreters.paillier import PaillierTensor | ||
| from syft.frameworks.torch.tensors.decorators.logging import LoggingTensor | ||
| from syft.frameworks.torch.tensors.interpreters.precision import FixedPrecisionTensor | ||
|
|
@@ -26,6 +28,7 @@ | |
| from syft.workers.base import BaseWorker | ||
| from syft.workers.virtual import VirtualWorker | ||
| from syft.messaging.plan import Plan | ||
| from syft.messaging.promise import Promise | ||
|
|
||
| from syft.exceptions import route_method_exception | ||
| from syft.exceptions import TensorsNotCollocatedException | ||
|
|
@@ -168,6 +171,9 @@ def __init__( | |
| # Add all hooked tensor methods to LargePrecisionTensor tensor | ||
| self._hook_syft_tensor_methods(LargePrecisionTensor) | ||
|
|
||
| # Add all hooked tensor methods to PromiseTensor | ||
| self._hook_promise_tensor() | ||
|
|
||
| # Hook the tensor constructor function | ||
| self._hook_tensor() | ||
|
|
||
|
|
@@ -504,6 +510,103 @@ def overloaded_attr(self, *args, **kwargs): | |
|
|
||
| return overloaded_attr | ||
|
|
||
| def _hook_promise_tensor(hook_self): | ||
|
|
||
| methods_to_hook = hook_self.to_auto_overload[torch.Tensor] | ||
|
|
||
| def generate_method(method_name): | ||
| def method(self, *args, **kwargs): | ||
|
|
||
| arg_shapes = list([self.shape]) | ||
| arg_ids = list([self.id]) | ||
|
|
||
| # Convert scalar arguments to tensors to be able to use them with plans | ||
| args = list(args) | ||
| for ia in range(len(args)): | ||
| if not isinstance(args[ia], (torch.Tensor, AbstractTensor)): | ||
| args[ia] = torch.tensor(args[ia]) | ||
|
Review discussion on the scalar-to-tensor conversion:

> **Contributor:** I'm surprised that this is needed: you can usually use scalars with torch operations. I understand it will make your next call to `.shape` fail, but you can test there whether the argument has a shape.
>
> **Member (author):** It was to be able to use plans with scalar arguments.
>
> **Contributor:** I think it does now :)
>
> **Contributor:** How often will we use plans with scalar arguments? If this is an edge case, we can remove this and leave it for another PR with a TODO.
>
> **Member (author):** I think it depends on the use case: for NN stuff maybe it's not needed, but for crypto protocols it might come up more often (I'm not a crypto expert, though).
>
> **Contributor:** I think you can just store the scalar value in the plan; it looks like it works from what I've seen.
>
> **Member (author):** I still think it's easier to change the scalars to tensors, because this way I can store and retrieve them from the worker when needed (when the promises are fulfilled). I can't see how to do that with scalars (but maybe it's just me :)).
```python
                for arg in args:
                    arg_shapes.append(arg.shape)

                @syft.func2plan(arg_shapes)
                def operation_method(self, *args, **kwargs):
                    return getattr(self, method_name)(*args, **kwargs)

                self.plans.add(operation_method.id)
                for arg in args:
                    if isinstance(arg, PromiseTensor):
                        arg.plans.add(operation_method.id)

                operation_method.procedure.update_args(
                    [self, *args], operation_method.procedure.result_ids
                )

                promise_out = PromiseTensor(
                    owner=self.owner,
                    shape=operation_method.output_shape,
                    tensor_type=self.obj_type,
                    plans=set(),
                )
                operation_method.procedure.promise_out_id = promise_out.id

                if operation_method.owner != self.owner:
                    operation_method.send(self.owner)
                else:  # otherwise object not registered on local worker
                    operation_method.owner.register_obj(operation_method)

                return promise_out

            return method

        for method_name in methods_to_hook:
            setattr(PromiseTensor, method_name, generate_method(method_name))
```
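The `generate_method` / `setattr` loop above is a closure-factory pattern: each hooked name gets its own generated method. A minimal pure-Python sketch of that pattern, with `PromiseLike` as a hypothetical stand-in for `PromiseTensor` (the real hook builds a plan instead of computing eagerly):

```python
class PromiseLike:
    """Toy stand-in for PromiseTensor; here the generated method computes
    eagerly, whereas the real hook records a plan and returns a new promise."""
    def __init__(self, value):
        self.value = value

def generate_method(method_name):
    # Defining the inner function inside this factory captures method_name
    # per call, avoiding the late-binding pitfall of building closures
    # directly in a for loop.
    def method(self, other):
        result = getattr(self.value, method_name)(other.value)
        return PromiseLike(result)
    return method

# Dunder methods must be set on the class (not the instance) because Python
# looks up special methods on the type.
for name in ("__add__", "__mul__"):
    setattr(PromiseLike, name, generate_method(name))

out = PromiseLike(3) + PromiseLike(4)  # dispatches to the generated __add__
```

This is why the hook iterates over `methods_to_hook` once at setup time rather than defining each method by hand.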
```python
        def FloatTensor(shape, *args, **kwargs):
            return PromiseTensor(shape, tensor_type="torch.FloatTensor", *args, **kwargs).wrap()

        setattr(Promise, "FloatTensor", FloatTensor)

        def DoubleTensor(shape, *args, **kwargs):
            return PromiseTensor(shape, tensor_type="torch.DoubleTensor", *args, **kwargs).wrap()

        setattr(Promise, "DoubleTensor", DoubleTensor)

        def HalfTensor(shape, *args, **kwargs):
            return PromiseTensor(shape, tensor_type="torch.HalfTensor", *args, **kwargs).wrap()

        setattr(Promise, "HalfTensor", HalfTensor)

        def ByteTensor(shape, *args, **kwargs):
            return PromiseTensor(shape, tensor_type="torch.ByteTensor", *args, **kwargs).wrap()

        setattr(Promise, "ByteTensor", ByteTensor)

        def CharTensor(shape, *args, **kwargs):
            return PromiseTensor(shape, tensor_type="torch.CharTensor", *args, **kwargs).wrap()

        setattr(Promise, "CharTensor", CharTensor)

        def ShortTensor(shape, *args, **kwargs):
            return PromiseTensor(shape, tensor_type="torch.ShortTensor", *args, **kwargs).wrap()

        setattr(Promise, "ShortTensor", ShortTensor)

        def IntTensor(shape, *args, **kwargs):
            return PromiseTensor(shape, tensor_type="torch.IntTensor", *args, **kwargs).wrap()

        setattr(Promise, "IntTensor", IntTensor)

        def LongTensor(shape, *args, **kwargs):
            return PromiseTensor(shape, tensor_type="torch.LongTensor", *args, **kwargs).wrap()

        setattr(Promise, "LongTensor", LongTensor)

        def BoolTensor(shape, *args, **kwargs):
            return PromiseTensor(shape, tensor_type="torch.BoolTensor", *args, **kwargs).wrap()

        setattr(Promise, "BoolTensor", BoolTensor)

    def _hook_tensor(hook_self):
        """Hooks the function torch.tensor()
        We need to do this separately from hooking the class because internally
```
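The nine `Promise.*Tensor` constructors above differ only in the `tensor_type` string, so they could equally be generated in a loop. A hypothetical sketch of that factory pattern, in plain Python and with the constructor returning a dict instead of a wrapped `PromiseTensor`, just to show the shape of the idea:

```python
class Promise:
    """Toy Promise namespace; the real constructors return wrapped PromiseTensors."""
    pass

def make_constructor(tensor_type):
    # Each generated constructor fixes its tensor_type and forwards the rest.
    def constructor(shape, **kwargs):
        return {"shape": shape, "tensor_type": tensor_type, **kwargs}
    return constructor

for dtype in ("FloatTensor", "DoubleTensor", "IntTensor", "LongTensor"):
    setattr(Promise, dtype, staticmethod(make_constructor(f"torch.{dtype}")))

t = Promise.FloatTensor((2, 3))
```

Whether the explicit definitions or a loop is clearer is a style call; the explicit form in the PR keeps each constructor greppable by name.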
```python
# @@ -0,0 +1,110 @@
import syft as sy
from syft.workers.abstract import AbstractWorker
import weakref

from syft.generic.tensor import AbstractTensor
from syft.generic.tensor import initialize_tensor
from syft.messaging.promise import Promise
from syft.generic.frameworks.hook import hook_args


class PromiseTensor(AbstractTensor, Promise):
    def __init__(
        self, shape, owner=None, id=None, tensor_type=None, plans=None, tags=None, description=None,
    ):
        """Initializes a PromiseTensor.

        Args:
            shape: the shape of the tensors that will keep the promise.
            owner: an optional BaseWorker object to specify the worker on which
                the tensor is located.
            id: an optional string or integer id of the PromiseTensor.
            tensor_type: the type of the tensors that will keep the promise.
            plans: the ids of the plans waiting for the promise to be kept. When the
                promise is kept, all the plans corresponding to these ids will be
                executed if the other promises they were waiting for are also kept.
            tags: an optional set of hashtags corresponding to this tensor
                which this tensor should be searchable for.
            description: an optional string describing the purpose of the tensor.
        """
        if owner is None:
            owner = sy.local_worker

        # constructors for AbstractTensor and Promise
        AbstractTensor.__init__(self, id=id, owner=owner, tags=tags, description=description)
        Promise.__init__(self, owner=owner, obj_type=tensor_type, plans=plans)

        self._shape = shape

        del self.child

    def torch_type(self):
        return self.obj_type

    @property
    def shape(self):
        return self._shape

    @property
    def grad(self):
        return None
        # if not hasattr(self, "_grad"):
        #     self._grad = PromiseTensor(shape=self._shape, tensor_type=self.torch_type()).wrap()
        #
        # return self._grad

    def __str__(self):
        return f"[PromiseTensor({self.owner.id}:{self.id}) -future-> {self.obj_type.split('.')[-1]} -blocking-> {len(self.plans)} plans]"

    def __repr__(self):
        return self.__str__()
```
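The constructor above initializes both base classes with explicit `AbstractTensor.__init__` / `Promise.__init__` calls rather than `super()`, since the two bases take different arguments. A minimal sketch of that dual-base pattern with toy stand-ins for the two PySyft base classes (attribute names mirror the real ones, but the bodies are invented for illustration):

```python
class AbstractTensor:
    """Toy stand-in: holds identity/ownership attributes."""
    def __init__(self, id=None, owner=None):
        self.id = id
        self.owner = owner

class Promise:
    """Toy stand-in: holds the promised type and the waiting plans."""
    def __init__(self, owner=None, obj_type=None, plans=None):
        self.obj_type = obj_type
        self.plans = plans if plans is not None else set()

class PromiseTensor(AbstractTensor, Promise):
    def __init__(self, shape, id=None, owner=None, tensor_type=None, plans=None):
        # Explicit calls to each base initializer, as in the class above;
        # cooperative super() would be awkward with mismatched signatures.
        AbstractTensor.__init__(self, id=id, owner=owner)
        Promise.__init__(self, owner=owner, obj_type=tensor_type, plans=plans)
        self._shape = shape

    @property
    def shape(self):
        return self._shape

pt = PromiseTensor((2, 2), id=1, tensor_type="torch.FloatTensor")
```

The trade-off is that explicit base calls bypass the MRO, which is acceptable here because the two bases are independent mixin-style parents.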
```python
    @staticmethod
    def simplify(worker: AbstractWorker, tensor: "PromiseTensor") -> tuple:
        """Takes the attributes of a PromiseTensor and saves them in a tuple.

        Args:
            tensor: a PromiseTensor.

        Returns:
            tuple: a tuple holding the unique attributes of the promise tensor.
        """
        return (
            sy.serde._simplify(worker, tensor.id),
            sy.serde._simplify(worker, tensor.shape),
            sy.serde._simplify(worker, tensor.obj_type),
            sy.serde._simplify(worker, tensor.plans),
        )

    @staticmethod
    def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "PromiseTensor":
        """Reconstructs a PromiseTensor given its attributes in the form of a tuple.

        Args:
            worker: the worker doing the deserialization.
            tensor_tuple: a tuple holding the attributes of the PromiseTensor.

        Returns:
            PromiseTensor: a PromiseTensor.
        """
        id, shape, tensor_type, plans = tensor_tuple

        id = sy.serde._detail(worker, id)
        shape = sy.serde._detail(worker, shape)
        tensor_type = sy.serde._detail(worker, tensor_type)
        plans = sy.serde._detail(worker, plans)

        tensor = PromiseTensor(
            owner=worker, id=id, shape=shape, tensor_type=tensor_type, plans=plans
        )

        return tensor


### Register the tensor with hook_args.py ###
hook_args.default_register_tensor(PromiseTensor)
```
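`simplify` and `detail` form a round-trip pair: one flattens the tensor's unique attributes into a tuple for serialization, the other rebuilds the object from that tuple. A self-contained sketch of that contract with a toy stub (no syft, no `sy.serde` field transforms):

```python
class PromiseTensorStub:
    """Toy version of PromiseTensor keeping only the serialized attributes."""
    def __init__(self, id=None, shape=None, tensor_type=None, plans=None):
        self.id = id
        self.shape = shape
        self.obj_type = tensor_type
        self.plans = plans

    @staticmethod
    def simplify(tensor):
        # Pack the unique attributes into a plain tuple; the real simplify
        # additionally runs each field through sy.serde._simplify.
        return (tensor.id, tensor.shape, tensor.obj_type, tensor.plans)

    @staticmethod
    def detail(tensor_tuple):
        # Unpack in the same order simplify produced, then reconstruct.
        id, shape, tensor_type, plans = tensor_tuple
        return PromiseTensorStub(id=id, shape=shape, tensor_type=tensor_type, plans=plans)

original = PromiseTensorStub(id=42, shape=(3,), tensor_type="torch.FloatTensor", plans={7})
restored = PromiseTensorStub.detail(PromiseTensorStub.simplify(original))
```

The key invariant, which the real pair must also uphold, is that `detail(simplify(t))` preserves every field; the tuple's field order is the implicit wire format.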
Review discussion on whether PromiseTensor needs to be hooked at all:

> **Reviewer:** Why do we need to hook the promise tensor? Why can't we just define all the methods directly in the PromiseTensor class? The class resides within PySyft, so why don't we define methods such as `DoubleTensor` there? It doesn't seem to rely on the DoubleTensor being hooked before defining the method.
>
> **Author:** Methods like `DoubleTensor` were in the file where the class is defined before, but I was asked in some comments to move them here ^^ For the other methods, I think this file was supposed to be where this kind of method generation happens, but maybe not?
>
> **Reviewer:** Ok, that sounds like a contradiction. I'd like to hear the opinion of @LaRiffle and @robert-wagner 😄
>
> **Reviewer:** Historically PromiseTensor has always been a little bit of an exception because of the way it works. I'm ok with this for now.