forked from OpenMined/PySyft
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_virtual.py
More file actions
279 lines (196 loc) · 7.58 KB
/
test_virtual.py
File metadata and controls
279 lines (196 loc) · 7.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
from time import time
from unittest.mock import patch
import pytest
import torch
import syft as sy
from syft import serde
from syft.generic.pointers.object_wrapper import ObjectWrapper
from syft.messaging.message import ObjectMessage
from syft.messaging.message import ObjectRequestMessage
from syft.workers.virtual import VirtualWorker
from syft.exceptions import GetNotPermittedError
from syft.exceptions import ObjectNotFoundError
def test_send_msg():
    """Tests sending a message with a specific ID.

    This is a simple test to ensure that the BaseWorker interface
    can properly send/receive a message containing a tensor, and that the
    worker's ``message_pending_time`` (used here to simulate latency) is
    applied to the send. The local worker is shared by other tests, so its
    previous pending time is restored even if the send fails.
    """
    from time import perf_counter  # monotonic clock for measuring intervals

    # get pointer to local worker (shared across tests)
    me = sy.torch.hook.local_worker

    # pending time to simulate latency; remember the old value to restore it
    pending_time = 0.1
    previous_pending_time = getattr(me, "message_pending_time", 0)
    me.message_pending_time = pending_time
    try:
        # create a new worker (to send the object to)
        worker_id = sy.ID_PROVIDER.pop()
        bob = VirtualWorker(sy.torch.hook, id=f"bob{worker_id}")

        # initialize the object and save its id
        obj = torch.Tensor([100, 100])
        obj_id = obj.id

        # Send data to bob and time how long the send takes
        start_time = perf_counter()
        me.send_msg(ObjectMessage(obj), bob)
        elapsed_time = perf_counter() - start_time
    finally:
        # never leak the artificial delay into other tests
        me.message_pending_time = previous_pending_time

    # ensure that object is now on bob's machine
    assert obj_id in bob._objects
    # ensure that the send took roughly the configured pending time
    assert abs(elapsed_time - pending_time) < 0.1
def test_send_msg_using_tensor_api():
    """Tests sending a message with a specific ID.

    Checks that the high-level ``Tensor.send`` method, called on a tensor
    created on the local worker, delivers that tensor (under its original
    id) into the receiving worker's object store.
    """
    # destination worker with a unique id
    bob = VirtualWorker(sy.torch.hook, id=f"bob{sy.ID_PROVIDER.pop()}")

    # tensor lives on the local worker until it is sent
    tensor = torch.Tensor([100, 100])
    tensor_id = tensor.id

    # keep the returned pointer bound for the rest of the test
    # (NOTE(review): presumably dropping it could trigger remote cleanup)
    pointer = tensor.send(bob)

    # the tensor must now be registered on bob
    assert tensor_id in bob._objects
def test_recv_msg():
    """Tests the recv_msg command with 2 tests.

    The first test uses recv_msg to send an object to alice.
    The second test uses recv_msg to request the object previously sent
    to alice, and checks that the reply is a serialized (bytes) payload
    that deserializes back to the same tensor.
    """
    # TEST 1: send tensor to alice
    # create a worker to send data to
    worker_id = sy.ID_PROVIDER.pop()
    alice = VirtualWorker(sy.torch.hook, id=f"alice{worker_id}")

    # create object to send
    obj = torch.Tensor([100, 100])

    # create/serialize message and have alice receive it directly
    bin_msg = serde.serialize(ObjectMessage(obj))
    alice.recv_msg(bin_msg)

    # ensure that object is now in alice's registry
    assert obj.id in alice._objects

    # TEST 2: get tensor back from alice
    # NOTE(review): tuple is presumably (obj_id, user, reason) — confirm
    # against ObjectRequestMessage
    bin_msg = serde.serialize(ObjectRequestMessage((obj.id, None, "")))
    resp = alice.recv_msg(bin_msg)

    # the reply must be raw bytes; check this before trying to decode it
    assert isinstance(resp, bytes)
    obj_2 = serde.deserialize(resp)

    # ensure that the object we receive is correct
    assert obj_2.id == obj.id
    assert (obj_2 == obj).all()
def tests_worker_convenience_methods():
    """Tests the send/request/set/get object helpers on BaseWorker.

    Part one moves a tensor from the local worker to Alice with
    ``send_obj`` and retrieves that tensor with ``request_obj``.
    Part two stores a tensor on Bob with ``set_obj``, reads it back on Bob
    with ``get_obj``, and then has Alice fetch it from Bob with
    ``request_obj``, showing the same helpers work directly between two
    virtual workers.
    """
    me = sy.torch.hook.local_worker
    bob = VirtualWorker(sy.torch.hook, id=f"bob{sy.ID_PROVIDER.pop()}")
    alice = VirtualWorker(sy.torch.hook, id=f"alice{sy.ID_PROVIDER.pop()}")

    # Part 1: local worker -> alice, then back again
    first = torch.Tensor([100, 100])
    me.send_obj(first, alice)
    assert (me.request_obj(first.id, alice) == first).all()

    # Part 2: bob stores a tensor, bob reads it, alice requests it from bob
    second = torch.Tensor([200, 200])
    bob.set_obj(second)
    assert (bob.get_obj(second.id) == second).all()
    assert (alice.request_obj(second.id, bob) == second).all()
def test_search():
    """Tests tag-based search on a VirtualWorker.

    Four tagged tensors are sent to Bob; ``bob.search`` must return the
    expected number of matches for single tags and for a list of tags
    (where only tensors carrying every listed tag match).
    """
    bob = VirtualWorker(sy.torch.hook, id=f"bob{sy.ID_PROVIDER.pop()}")

    tag_sets = [
        ("#fun", "#mnist"),
        ("#not_fun", "#cifar"),
        ("#fun", "#boston_housing"),
        ("#not_fun", "#boston_housing"),
    ]
    # hold every pointer for the whole test so the remote tensors stay put
    # (NOTE(review): presumably collected pointers may delete remote data)
    pointers = [
        torch.tensor([1, 2, 3, 4, 5])
        .tag(*tags)
        .describe("The images in the MNIST training dataset.")
        .send(bob)
        for tags in tag_sets
    ]

    expected_counts = [
        ("#fun", 2),
        ("#mnist", 1),
        ("#cifar", 1),
        ("#not_fun", 2),
        (["#not_fun", "#boston_housing"], 1),
    ]
    for query, count in expected_counts:
        assert len(bob.search(query)) == count
def test_obj_not_found(workers):
    """Tests that operating on a pointer whose remote tensor no longer
    exists on the worker raises ObjectNotFoundError."""
    bob = workers["bob"]
    pointer = torch.tensor([1, 2, 3, 4, 5]).send(bob)

    # wipe bob's object store so the pointer now refers to nothing
    bob._objects = {}

    with pytest.raises(ObjectNotFoundError):
        _ = pointer + pointer
def test_get_not_permitted(workers):
    """Tests that ``.get()`` raises GetNotPermittedError when
    ``torch.Tensor.allow`` is patched to return False, and that the
    permission check is consulted exactly once."""
    bob = workers["bob"]
    pointer = torch.tensor([1, 2, 3, 4, 5]).send(bob)

    # extra kwargs to patch.object configure the created MagicMock
    with patch.object(torch.Tensor, "allow", return_value=False) as mocked_allow:
        with pytest.raises(GetNotPermittedError):
            pointer.get()
        mocked_allow.assert_called_once()
def test_spinup_time(hook):
    """Tests that a virtual worker initialized with 10000 data points loads
    in under 0.1 seconds.

    This is needed to ensure that virtual workers spun up inside web
    frameworks are created quickly enough to not cause timeout errors.
    """
    from time import perf_counter  # monotonic clock for measuring intervals

    data = [torch.Tensor(5, 5).random_(100) for _ in range(10000)]

    start_time = perf_counter()
    # keep the worker bound so any teardown does not land inside the timing
    dummy = sy.VirtualWorker(hook, id="dummy", data=data)
    elapsed_time = perf_counter() - start_time

    assert elapsed_time < 0.1
def test_send_jit_scriptmodule(hook, workers):  # pragma: no cover
    """Tests sending a TorchScript function to Bob wrapped in an
    ObjectWrapper and calling it through the returned pointer."""
    bob = workers["bob"]

    @torch.jit.script
    def foo(x):
        return x + 2

    wrapper_id = 99
    foo_ptr = hook.local_worker.send(ObjectWrapper(obj=foo, id=wrapper_id), bob)

    # scripted foo adds 2, so 4 -> 6
    result = foo_ptr(torch.tensor(4))
    assert result == torch.tensor(6)
def test_send_command_whitelist(hook, workers):
    """Tests that whitelisted torch functions can be invoked via ``bob.remote``.

    Every remote result is retrieved with ``.get()``. Deterministic
    functions are compared element-wise against the same local call; random
    functions ("rand", "randn") cannot be compared by value, so only the
    shape of the retrieved result is checked against a local call.
    """
    bob = workers["bob"]
    whitelisted_methods = {
        "torch": {"tensor": [1, 2, 3], "rand": (2, 3), "randn": (2, 3), "zeros": (2, 3)}
    }
    for framework, methods in whitelisted_methods.items():
        attr = getattr(bob.remote, framework)
        for method, inp in methods.items():
            remote_result = getattr(attr, method)(inp).get()
            local_result = getattr(torch, method)(inp)
            if "rand" in method:
                # random outputs differ between calls; compare shapes only
                assert remote_result.shape == local_result.shape
            else:
                assert (remote_result == local_result).all()
def test_send_command_not_whitelisted(hook, workers):
    """Tests that unknown methods are not reachable through ``bob.remote``.

    For every framework listed in ``bob.remote.frameworks`` that also
    appears in ``dir(bob.remote)``, looking up a non-existent method on it
    must raise AttributeError.
    """
    bob = workers["bob"]
    missing_method = "openmind"
    for framework in bob.remote.frameworks:
        # skip frameworks that are not exposed as attributes of bob.remote
        if framework not in dir(bob.remote):
            continue
        framework_attr = getattr(bob.remote, framework)
        with pytest.raises(AttributeError):
            getattr(framework_attr, missing_method)