11import math
22import logging
3-
3+ from syft .generic .object import AbstractObject
4+ from syft .workers .base import BaseWorker
5+ from syft .generic .pointers .pointer_dataset import PointerDataset
46import torch
57from torch .utils .data import Dataset
8+ import syft
69
710logger = logging .getLogger (__name__ )
811
912
10- class BaseDataset :
13+ class BaseDataset ( AbstractObject ) :
1114 """
1215 This is a base class to be used for manipulating a dataset. This is composed
1316 of a .data attribute for inputs and a .targets one for labels. It is to
@@ -22,8 +25,10 @@ class BaseDataset:
2225
2326 """
2427
25- def __init__ (self , data , targets , transform = None ):
26-
28+ def __init__ (self , data , targets , transform = None , owner = None , ** kwargs ):
29+ if owner is None :
30+ owner = syft .framework .hook .local_worker
31+ super ().__init__ (owner = owner , ** kwargs )
2732 self .data = data
2833 self .targets = targets
2934 self .transform_ = transform
@@ -68,21 +73,9 @@ def transform(self, transform):
6873
6974 raise TypeError ("Transforms can be applied only on torch tensors" )
7075
71- def send (self , worker ):
72- """
73- Args:
74-
75- worker[worker class]: worker to which the data must be sent
76-
77- Returns:
78-
79- self: Return the object instance with data sent to corresponding worker
80-
81- """
82-
83- self .data .send_ (worker )
84- self .targets .send_ (worker )
85- return self
76+ def send (self , location : BaseWorker ):
77+ ptr = self .owner .send (self , workers = location )
78+ return ptr
8679
8780 def get (self ):
8881 """
@@ -93,6 +86,12 @@ def get(self):
9386 self .targets .get_ ()
9487 return self
9588
89+ def get_data (self ):
90+ return self .data
91+
92+ def get_targets (self ):
93+ return self .targets
94+
9695 def fix_prec (self , * args , ** kwargs ):
9796 """
9897 Converts data of BaseDataset into fixed precision
@@ -121,13 +120,81 @@ def share(self, *args, **kwargs):
121120 self .targets .share_ (* args , ** kwargs )
122121 return self
123122
123+ def create_pointer (
124+ self , owner , garbage_collect_data , location = None , id_at_location = None , ** kwargs
125+ ):
126+ """creats a pointer to the self dataset"""
127+ if owner is None :
128+ owner = self .owner
129+
130+ if location is None :
131+ location = self .owner
132+
133+ owner = self .owner .get_worker (owner )
134+ location = self .owner .get_worker (location )
135+
136+ return PointerDataset (
137+ owner = owner ,
138+ location = location ,
139+ id_at_location = id_at_location or self .id ,
140+ garbage_collect_data = garbage_collect_data ,
141+ tags = self .tags ,
142+ description = self .description ,
143+ )
144+
145+ def __repr__ (self ):
146+
147+ fmt_str = "BaseDataset\n "
148+ fmt_str += f"\t Data: { self .data } \n "
149+ fmt_str += f"\t targets: { self .targets } "
150+
151+ if self .tags is not None and len (self .tags ):
152+ fmt_str += "\n \t Tags: "
153+ for tag in self .tags :
154+ fmt_str += str (tag ) + " "
155+
156+ if self .description is not None :
157+ fmt_str += "\n \t Description: " + str (self .description ).split ("\n " )[0 ] + "..."
158+
159+ return fmt_str
160+
124161 @property
125162 def location (self ):
126163 """
127164 Get location of the data
128165 """
129166 return self .data .location
130167
168+ @staticmethod
169+ def simplify (worker , dataset : "BaseDataset" ) -> tuple :
170+ chain = None
171+ if hasattr (dataset , "child" ):
172+ chain = syft .serde .msgpack .serde ._simplify (worker , dataset .child )
173+ return (
174+ syft .serde .msgpack .serde ._simplify (worker , dataset .data ),
175+ syft .serde .msgpack .serde ._simplify (worker , dataset .targets ),
176+ dataset .id ,
177+ syft .serde .msgpack .serde ._simplify (worker , dataset .tags ),
178+ syft .serde .msgpack .serde ._simplify (worker , dataset .description ),
179+ chain ,
180+ )
181+
182+ @staticmethod
183+ def detail (worker , dataset_tuple : tuple ) -> "BaseDataset" :
184+ data , targets , id , tags , description , chain = dataset_tuple
185+ dataset = BaseDataset (
186+ syft .serde .msgpack .serde ._detail (worker , data ),
187+ syft .serde .msgpack .serde ._detail (worker , targets ),
188+ owner = worker ,
189+ id = id ,
190+ tags = syft .serde .msgpack .serde ._detail (worker , tags ),
191+ description = syft .serde .msgpack .serde ._detail (worker , description ),
192+ )
193+ if chain is not None :
194+ chain = syft .serde .msgpack .serde ._detail (worker , chain )
195+ dataset .child = chain
196+ return dataset
197+
131198
132199def dataset_federate (dataset , workers ):
133200 """
@@ -172,11 +239,11 @@ def __init__(self, datasets):
172239 self .datasets [worker_id ] = dataset
173240
174241 # Check that data and targets for a worker are consistent
175- for worker_id in self .workers :
242+ """ for worker_id in self.workers:
176243 dataset = self.datasets[worker_id]
177- assert len ( dataset . data ) == len (
178- dataset .targets
179- ), "On each worker, the input and target must have the same number of rows."
244+ assert (
245+ dataset.data.shape == dataset. targets.shape
246+ ), "On each worker, the input and target must have the same number of rows.""" ""
180247
181248 @property
182249 def workers (self ):
0 commit comments