Skip to content

Commit f77938a

Browse files
abogaziah and LaRiffle authored
Send & Get Datasets (#2960)
* dataset inherits abstractobject * added abstract dataset test * formatting * fix * removed owner annotation * formatting * Added simplify & detail * Edit simplify & detail * Added PointerDataset type * Added send() * Added send(), create_pointer() * Added get(), pointer.data(), pointer.target() * Added wrap(), edited base.search(), edit dataset.__init__() * Formatting * Added test file, __repr__(), removed get() override * Added tags, description to create_pointer() * formatting * changed repr of dataset ptr * replaced == with torch.equal() * adapt dataloader * formatting * added serde test! * added serde test! * test fixes * test fixes * serde test fix * formatting * hook_args fix * Update test_dataset_pointer.py * pointer location fix * formatting * formatting * formatting * formatting * federated check pass * formatting Co-authored-by: Muhammed Abogazia <abogaziah@users.noreply.github.com> Co-authored-by: Théo Ryffel <theo.leffyr@gmail.com>
1 parent 7e91927 commit f77938a

8 files changed

Lines changed: 285 additions & 35 deletions

File tree

syft/frameworks/torch/fl/dataset.py

Lines changed: 90 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import math
22
import logging
3-
3+
from syft.generic.object import AbstractObject
4+
from syft.workers.base import BaseWorker
5+
from syft.generic.pointers.pointer_dataset import PointerDataset
46
import torch
57
from torch.utils.data import Dataset
8+
import syft
69

710
logger = logging.getLogger(__name__)
811

912

10-
class BaseDataset:
13+
class BaseDataset(AbstractObject):
1114
"""
1215
This is a base class to be used for manipulating a dataset. This is composed
1316
of a .data attribute for inputs and a .targets one for labels. It is to
@@ -22,8 +25,10 @@ class BaseDataset:
2225
2326
"""
2427

25-
def __init__(self, data, targets, transform=None):
26-
28+
def __init__(self, data, targets, transform=None, owner=None, **kwargs):
29+
if owner is None:
30+
owner = syft.framework.hook.local_worker
31+
super().__init__(owner=owner, **kwargs)
2732
self.data = data
2833
self.targets = targets
2934
self.transform_ = transform
@@ -68,21 +73,9 @@ def transform(self, transform):
6873

6974
raise TypeError("Transforms can be applied only on torch tensors")
7075

71-
def send(self, worker):
72-
"""
73-
Args:
74-
75-
worker[worker class]: worker to which the data must be sent
76-
77-
Returns:
78-
79-
self: Return the object instance with data sent to corresponding worker
80-
81-
"""
82-
83-
self.data.send_(worker)
84-
self.targets.send_(worker)
85-
return self
76+
def send(self, location: BaseWorker):
    """Send this dataset to ``location`` through the owning worker.

    Args:
        location: worker that should receive the dataset.

    Returns:
        Whatever ``self.owner.send(self, workers=location)`` returns --
        presumably a pointer to the remote dataset (TODO confirm against
        the worker's ``send`` implementation).
    """
    return self.owner.send(self, workers=location)
8679

8780
def get(self):
8881
"""
@@ -93,6 +86,12 @@ def get(self):
9386
self.targets.get_()
9487
return self
9588

89+
def get_data(self):
    """Return the inputs held in ``self.data``.

    Kept as a plain method rather than a property: PointerDataset.data
    invokes it by name through a ``"get_data"`` command.
    """
    inputs = self.data
    return inputs
91+
92+
def get_targets(self):
    """Return the labels held in ``self.targets``.

    Kept as a plain method rather than a property: PointerDataset.targets
    invokes it by name through a ``"get_targets"`` command.
    """
    labels = self.targets
    return labels
94+
9695
def fix_prec(self, *args, **kwargs):
9796
"""
9897
Converts data of BaseDataset into fixed precision
@@ -121,13 +120,81 @@ def share(self, *args, **kwargs):
121120
self.targets.share_(*args, **kwargs)
122121
return self
123122

123+
def create_pointer(
    self, owner, garbage_collect_data, location=None, id_at_location=None, **kwargs
):
    """Create a PointerDataset that refers to this dataset.

    Args:
        owner: worker (or a reference ``get_worker`` can resolve) that will
            own the pointer; ``self.owner`` is used when None.
        garbage_collect_data: forwarded unchanged to PointerDataset.
        location: worker (or reference) where the dataset lives;
            ``self.owner`` is used when None.
        id_at_location: id of the dataset at ``location``; ``self.id`` is
            used when this is falsy.
        **kwargs: accepted but not used by this method.

    Returns:
        A PointerDataset carrying this dataset's tags and description.
    """
    resolve_worker = self.owner.get_worker

    # Both references default to this dataset's own worker before resolution.
    pointer_owner = resolve_worker(self.owner if owner is None else owner)
    pointer_location = resolve_worker(self.owner if location is None else location)

    return PointerDataset(
        owner=pointer_owner,
        location=pointer_location,
        id_at_location=id_at_location or self.id,
        garbage_collect_data=garbage_collect_data,
        tags=self.tags,
        description=self.description,
    )
144+
145+
def __repr__(self):
    """Build a multi-line summary: data, targets, then optional tags and description.

    Tags are listed space-separated when any are present; only the first
    line of the description is shown, followed by ``...``.
    """
    parts = [
        "BaseDataset\n",
        f"\tData: {self.data}\n",
        f"\ttargets: {self.targets}",
    ]

    tags = self.tags
    if tags is not None and len(tags):
        parts.append("\n\tTags: ")
        parts.extend(str(tag) + " " for tag in tags)

    description = self.description
    if description is not None:
        first_line = str(description).split("\n")[0]
        parts.append("\n\tDescription: " + first_line + "...")

    return "".join(parts)
160+
124161
@property
def location(self):
    """Location of the dataset, read from ``self.data.location``."""
    inputs = self.data
    return inputs.location
130167

168+
@staticmethod
def simplify(worker, dataset: "BaseDataset") -> tuple:
    """Flatten a BaseDataset into a tuple for msgpack serialization.

    Args:
        worker: worker passed through to the msgpack ``_simplify`` helper.
        dataset: the dataset to simplify.

    Returns:
        ``(data, targets, id, tags, description, chain)`` where every entry
        except ``id`` has gone through ``_simplify``; ``chain`` is the
        simplified ``dataset.child``, or None when the dataset has no
        ``child`` attribute.
    """
    _simplify = syft.serde.msgpack.serde._simplify

    chain = _simplify(worker, dataset.child) if hasattr(dataset, "child") else None

    return (
        _simplify(worker, dataset.data),
        _simplify(worker, dataset.targets),
        dataset.id,
        _simplify(worker, dataset.tags),
        _simplify(worker, dataset.description),
        chain,
    )
181+
182+
@staticmethod
def detail(worker, dataset_tuple: tuple) -> "BaseDataset":
    """Rebuild a BaseDataset from the tuple produced by ``simplify``.

    Args:
        worker: worker that becomes the owner of the rebuilt dataset and is
            passed through to the msgpack ``_detail`` helper.
        dataset_tuple: ``(data, targets, id, tags, description, chain)``.

    Returns:
        The reconstructed BaseDataset; its ``child`` is set only when the
        serialized chain is not None.
    """
    _detail = syft.serde.msgpack.serde._detail
    data, targets, dataset_id, tags, description, chain = dataset_tuple

    dataset = BaseDataset(
        _detail(worker, data),
        _detail(worker, targets),
        owner=worker,
        id=dataset_id,
        tags=_detail(worker, tags),
        description=_detail(worker, description),
    )

    if chain is not None:
        dataset.child = _detail(worker, chain)

    return dataset
197+
131198

132199
def dataset_federate(dataset, workers):
133200
"""
@@ -172,11 +239,11 @@ def __init__(self, datasets):
172239
self.datasets[worker_id] = dataset
173240

174241
# Check that data and targets for a worker are consistent
175-
for worker_id in self.workers:
242+
"""for worker_id in self.workers:
176243
dataset = self.datasets[worker_id]
177-
assert len(dataset.data) == len(
178-
dataset.targets
179-
), "On each worker, the input and target must have the same number of rows."
244+
assert (
245+
dataset.data.shape == dataset.targets.shape
246+
), "On each worker, the input and target must have the same number of rows.""" ""
180247

181248
@property
182249
def workers(self):

syft/generic/frameworks/hook/hook_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,7 @@ def register_response(
658658

659659
try:
660660
assert attr not in ambiguous_functions
661+
assert attr not in ambiguous_methods
661662

662663
# Load the utility function to register the response and transform tensors with pointers
663664
register_response_function = register_response_functions[attr_id]
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from typing import List
2+
from typing import Union
3+
4+
import syft as sy
5+
from syft.generic.pointers.object_pointer import ObjectPointer
6+
from syft.workers.abstract import AbstractWorker
7+
8+
9+
class PointerDataset(ObjectPointer):
    """Pointer to a dataset that lives on another worker.

    ``data``, ``targets``, ``len()`` and indexing are forwarded to
    ``self.location`` as commands addressed to ``self.id_at_location``.
    """

    def __init__(
        self,
        location: "AbstractWorker" = None,
        id_at_location: Union[str, int] = None,
        owner: "AbstractWorker" = None,
        garbage_collect_data: bool = True,
        id: Union[str, int] = None,
        tags: List[str] = None,
        description: str = None,
    ):
        """Initialise the pointer; ``owner`` defaults to the hook's local worker.

        All arguments are forwarded unchanged to ``ObjectPointer.__init__``.
        """
        if owner is None:
            owner = sy.framework.hook.local_worker

        super().__init__(
            location=location,
            id_at_location=id_at_location,
            owner=owner,
            garbage_collect_data=garbage_collect_data,
            id=id,
            tags=tags,
            description=description,
        )

    @property
    def data(self):
        """Run ``get_data`` on the remote dataset and return the wrapped response."""
        command = ("get_data", self.id_at_location, [], {})
        response = self.owner.send_command(message=command, recipient=self.location)
        return response.wrap()

    @property
    def targets(self):
        """Run ``get_targets`` on the remote dataset and return the wrapped response."""
        command = ("get_targets", self.id_at_location, [], {})
        response = self.owner.send_command(message=command, recipient=self.location)
        return response.wrap()

    def wrap(self):
        """Return this pointer unchanged."""
        return self

    def __repr__(self):
        """Describe the pointer: type, owner id, own id, then optional details.

        The shape line is only emitted when tags are present and the pointer
        exposes a ``shape`` attribute.
        """
        pieces = [f"[{type(self).__name__} | owner: {str(self.owner.id)}, id:{self.id}"]

        if self.point_to_attr is not None:
            pieces.append("::" + str(self.point_to_attr).replace(".", "::"))

        has_tags = self.tags is not None and len(self.tags)
        if has_tags:
            pieces.append("\n\tTags: ")
            pieces.extend(str(tag) + " " for tag in self.tags)

        if has_tags and hasattr(self, "shape"):
            pieces.append("\n\tShape: " + str(self.shape))

        if self.description is not None:
            first_line = str(self.description).split("\n")[0]
            pieces.append("\n\tDescription: " + first_line + "...")

        return "".join(pieces)

    def __len__(self):
        """Ask the remote dataset for its length via a ``__len__`` command."""
        command = ("__len__", self.id_at_location, [], {})
        length = self.owner.send_command(message=command, recipient=self.location)
        return length

    def __getitem__(self, index):
        """Fetch element ``index`` remotely and return wrapped ``(data, target)``."""
        command = ("__getitem__", self.id_at_location, [index], {})
        response = self.owner.send_command(message=command, recipient=self.location)
        data_elem, target_elem = response
        return data_elem.wrap(), target_elem.wrap()

syft/serde/msgpack/serde.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
from syft.serde.msgpack.native_serde import MAP_NATIVE_SIMPLIFIERS_AND_DETAILERS
7373
from syft.workers.abstract import AbstractWorker
7474
from syft.workers.base import BaseWorker
75+
from syft.frameworks.torch.fl import BaseDataset
7576

7677
from syft.exceptions import GetNotPermittedError
7778
from syft.exceptions import ResponseSignatureError
@@ -133,6 +134,7 @@
133134
PlanCommandMessage,
134135
GradFunc,
135136
String,
137+
BaseDataset,
136138
ExecuteWorkerFunctionMessage,
137139
]
138140

test/serde/msgpack/test_msgpack_serde_full.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
samples[syft.frameworks.torch.tensors.interpreters.autograd.AutogradTensor] = make_autogradtensor
7272
samples[syft.frameworks.torch.tensors.interpreters.private.PrivateTensor] = make_privatetensor
7373
samples[syft.frameworks.torch.tensors.interpreters.placeholder.PlaceHolder] = make_placeholder
74+
samples[syft.frameworks.torch.fl.dataset.BaseDataset] = make_basedataset
7475

7576
samples[syft.messaging.message.CommandMessage] = make_command_message
7677
samples[syft.messaging.message.ObjectMessage] = make_objectmessage

test/serde/serde_helpers.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,45 @@ def compare(detailed, original):
668668
]
669669

670670

671+
# syft.frameworks.torch.fl.dataset
672+
def make_basedataset(**kwargs):
    """Build the serde sample for ``BaseDataset``.

    Returns a one-element list holding the dataset, its expected simplified
    form and a comparison callback for the detailed result.
    """
    workers = kwargs["workers"]
    # Not used below; the lookups are kept so a missing worker still raises.
    alice, bob, james = workers["alice"], workers["bob"], workers["james"]

    dataset = syft.BaseDataset(torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6, 7, 8]))
    dataset.tag("#tag1").describe("desc")

    def compare(detailed, original):
        assert type(detailed) == syft.BaseDataset
        assert (detailed.data == original.data).all()
        assert (detailed.targets == original.targets).all()
        for attr in ("id", "tags", "description"):
            assert getattr(detailed, attr) == getattr(original, attr)
        return True

    me = syft.hook.local_worker
    simplify = msgpack.serde._simplify

    sample = {
        "value": dataset,
        "simplified": (
            CODE[syft.frameworks.torch.fl.dataset.BaseDataset],
            (
                simplify(me, dataset.data),
                simplify(me, dataset.targets),
                dataset.id,
                simplify(me, dataset.tags),  # (set of str) tags
                simplify(me, dataset.description),  # (str) description
                simplify(me, dataset.child),
            ),
        ),
        "cmp_detailed": compare,
    }
    return [sample]
708+
709+
671710
# syft.execution.plan.Plan
672711
def make_plan(**kwargs):
673712
# Function to plan

test/torch/federated/test_dataset.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,11 @@ def test_base_dataset(workers):
1515
assert len(dataset) == 4
1616
assert dataset[2] == (3, 3)
1717

18-
dataset.send(bob)
18+
dataset = dataset.send(bob)
1919
assert dataset.data.location.id == "bob"
2020
assert dataset.targets.location.id == "bob"
2121
assert dataset.location.id == "bob"
2222

23-
dataset.get()
24-
with pytest.raises(AttributeError):
25-
assert dataset.data.location.id == 0
26-
with pytest.raises(AttributeError):
27-
assert dataset.targets.location.id == 0
28-
2923

3024
def test_base_dataset_transform():
3125

@@ -61,11 +55,12 @@ def test_federated_dataset(workers):
6155
assert fed_dataset.workers == ["bob", "alice"]
6256
assert len(fed_dataset) == 6
6357

64-
fed_dataset["alice"].get()
65-
assert (fed_dataset["alice"].data == alice_base_dataset.data).all()
66-
assert fed_dataset["alice"][2] == (5, 5)
67-
assert len(fed_dataset["alice"]) == 4
68-
assert len(fed_dataset) == 6
58+
alice_remote_data = fed_dataset["alice"].get()
59+
del fed_dataset.datasets["alice"]
60+
assert (alice_remote_data.data == alice_base_dataset.data).all()
61+
assert alice_remote_data[2] == (5, 5)
62+
assert len(alice_remote_data) == 4
63+
assert len(fed_dataset) == 2
6964

7065
assert isinstance(fed_dataset.__str__(), str)
7166

@@ -114,3 +109,12 @@ def test_federated_dataset_search(workers):
114109
counter += 1
115110

116111
assert counter == len(train_loader), f"{counter} == {len(fed_dataset)}"
112+
113+
114+
def test_abstract_dataset():
    """BaseDataset keeps an explicit id and has no description unless one is given."""
    inputs = th.tensor([1, 2, 3, 4.0])
    targets = th.tensor([1, 2, 3, 4.0])
    dataset = BaseDataset(inputs, targets, id=1)

    assert dataset.id == 1
    # None is a singleton: compare by identity, not equality (PEP 8).
    assert dataset.description is None

0 commit comments

Comments
 (0)