Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
97a2cf2
dataset inherits abstractobject
abogaziah Jan 24, 2020
bb5054a
added abstract dataset test
abogaziah Jan 24, 2020
803b37c
formatting
abogaziah Jan 24, 2020
500cf4a
fix
abogaziah Jan 24, 2020
1fabb5a
removed owner annotation
abogaziah Jan 25, 2020
ffc0e08
formatting
abogaziah Jan 25, 2020
c3da99b
Added simplify& detail
abogaziah Jan 26, 2020
8a72d64
Merge branch 'master' into master
LaRiffle Jan 26, 2020
0937d00
Edit simplify& detail
abogaziah Jan 26, 2020
7acd96c
Merge branch 'master' of https://github.com/abogaziah/PySyft
abogaziah Jan 26, 2020
871adf8
Merge branch 'master' into master
abogaziah Feb 5, 2020
06f6298
Merge branch 'master' into master
LaRiffle Feb 5, 2020
46bb642
Added PointerDataset type
abogaziah Feb 11, 2020
849ff8a
Merge branch 'master' of https://github.com/abogaziah/PySyft
abogaziah Feb 11, 2020
6472aac
Added send()
abogaziah Feb 11, 2020
fc27545
Added send(), create_pointer()
abogaziah Feb 11, 2020
017d7f7
Merge remote-tracking branch 'upstream/master'
abogaziah Feb 12, 2020
5a4a8b3
Added get(), pointer.data(), pointer.target()
abogaziah Feb 13, 2020
2b4ee06
Added wrap(), edited base.search(), edit dataset.__init__()
abogaziah Feb 15, 2020
ae5101c
Formatting
abogaziah Feb 15, 2020
880864a
Added test file, __repr__(), removed get() override
abogaziah Feb 17, 2020
fe52b27
Added tags, description to create_pointer()
abogaziah Feb 20, 2020
5e37ab4
formatting
abogaziah Feb 22, 2020
e1aa881
changed repr of dataset ptr
abogaziah Feb 22, 2020
bd860c3
replaced == with torch.equal()
abogaziah Feb 22, 2020
847f51a
adapt dataloader
abogaziah Mar 1, 2020
7e239c9
formatting
abogaziah Mar 1, 2020
a1ddd72
added serede test!
abogaziah Mar 2, 2020
6ec07fc
added serede test!
abogaziah Mar 2, 2020
4620cd9
Merge branch 'master' into master
LaRiffle Mar 2, 2020
a6491ac
test fixes
abogaziah Mar 5, 2020
7245e60
[CI debug] Rm tests
LaRiffle Mar 6, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 78 additions & 21 deletions syft/frameworks/torch/fl/dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import math
import logging

from syft.generic.object import AbstractObject
from syft.workers.base import BaseWorker
from syft.generic.pointers.pointer_dataset import PointerDataset
import torch
from torch.utils.data import Dataset
import syft

logger = logging.getLogger(__name__)


class BaseDataset:
class BaseDataset(AbstractObject):
"""
This is a base class to be used for manipulating a dataset. This is composed
of a .data attribute for inputs and a .targets one for labels. It is to
Expand All @@ -22,8 +25,10 @@ class BaseDataset:

"""

def __init__(self, data, targets, transform=None):

def __init__(self, data, targets, transform=None, owner=None, **kwargs):
if owner is None:
owner = syft.framework.hook.local_worker
super().__init__(owner=owner, **kwargs)
self.data = data
self.targets = targets
self.transform_ = transform
Expand Down Expand Up @@ -68,21 +73,9 @@ def transform(self, transform):

raise TypeError("Transforms can be applied only on torch tensors")

def send(self, worker):
"""
Args:

worker[worker class]: worker to which the data must be sent

Returns:

self: Return the object instance with data sent to corresponding worker

"""

self.data.send_(worker)
self.targets.send_(worker)
return self
def send(self, location: BaseWorker):
ptr = self.owner.send(self, workers=location)
return ptr

def get(self):
"""
Expand All @@ -93,6 +86,12 @@ def get(self):
self.targets.get_()
return self

def get_data(self):
return self.data

def get_targets(self):
return self.targets

def fix_prec(self, *args, **kwargs):
"""
Converts data of BaseDataset into fixed precision
Expand Down Expand Up @@ -121,13 +120,71 @@ def share(self, *args, **kwargs):
self.targets.share_(*args, **kwargs)
return self

def create_pointer(
self, owner, garbage_collect_data, location=None, id_at_location=None, **kwargs
):
return PointerDataset(
owner=owner,
location=location,
id_at_location=id_at_location or self.id,
garbage_collect_data=garbage_collect_data,
tags=self.tags,
description=self.description,
)

def __repr__(self):

fmt_str = "BaseDataset\n"
fmt_str += f"\tData: {self.data}\n"
fmt_str += f"\ttargets: {self.targets}"

if self.tags is not None and len(self.tags):
fmt_str += "\n\tTags: "
for tag in self.tags:
fmt_str += str(tag) + " "

if self.description is not None:
fmt_str += "\n\tDescription: " + str(self.description).split("\n")[0] + "..."

return fmt_str

@property
def location(self):
"""
Get location of the data
"""
return self.data.location

@staticmethod
def simplify(worker, dataset: "BaseDataset") -> tuple:
chain = None
if hasattr(dataset, "child"):
chain = syft.serde.msgpack.serde._simplify(worker, dataset.child)
return (
syft.serde.msgpack.serde._simplify(worker, dataset.data),
syft.serde.msgpack.serde._simplify(worker, dataset.targets),
dataset.id,
syft.serde.msgpack.serde._simplify(worker, dataset.tags),
syft.serde.msgpack.serde._simplify(worker, dataset.description),
chain,
)

@staticmethod
def detail(worker, dataset_tuple: tuple) -> "BaseDataset":
data, targets, id, tags, description, chain = dataset_tuple
dataset = BaseDataset(
syft.serde.msgpack.serde._detail(worker, data),
syft.serde.msgpack.serde._detail(worker, targets),
id=id,
tags=syft.serde.msgpack.serde._detail(worker, tags),
description=syft.serde.msgpack.serde._detail(worker, description),
)
if chain is not None:
chain = syft.serde.msgpack.serde._detail(worker, chain)
dataset.child = chain

return dataset


def dataset_federate(dataset, workers):
"""
Expand Down Expand Up @@ -174,8 +231,8 @@ def __init__(self, datasets):
# Check that data and targets for a worker are consistent
for worker_id in self.workers:
dataset = self.datasets[worker_id]
assert len(dataset.data) == len(
dataset.targets
assert (
dataset.data.shape == dataset.targets.shape
), "On each worker, the input and target must have the same number of rows."

@property
Expand Down
79 changes: 79 additions & 0 deletions syft/generic/pointers/pointer_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import List
from typing import Union

import syft as sy
from syft.generic.pointers.object_pointer import ObjectPointer
from syft.workers.abstract import AbstractWorker


class PointerDataset(ObjectPointer):
def __init__(
self,
location: "AbstractWorker" = None,
id_at_location: Union[str, int] = None,
owner: "AbstractWorker" = None,
garbage_collect_data: bool = True,
id: Union[str, int] = None,
tags: List[str] = None,
description: str = None,
):
if owner is None:
owner = sy.framework.hook.local_worker
super().__init__(
location=location,
id_at_location=id_at_location,
owner=owner,
garbage_collect_data=garbage_collect_data,
id=id,
tags=tags,
description=description,
)

@property
def data(self):
command = ("get_data", self.id_at_location, [], {})
ptr = self.owner.send_command(message=command, recipient=self.location).wrap()
return ptr

@property
def targets(self):
command = ("get_targets", self.id_at_location, [], {})
ptr = self.owner.send_command(message=command, recipient=self.location).wrap()
return ptr

def wrap(self):
return self

def __repr__(self):
type_name = type(self).__name__
out = f"[" f"{type_name} | " f"owner: {str(self.owner.id)}, id:{self.id}"

if self.point_to_attr is not None:
out += "::" + str(self.point_to_attr).replace(".", "::")

big_str = False

if self.tags is not None and len(self.tags):
big_str = True
out += "\n\tTags: "
for tag in self.tags:
out += str(tag) + " "

if big_str and hasattr(self, "shape"):
out += "\n\tShape: " + str(self.shape)

if self.description is not None:
big_str = True
out += "\n\tDescription: " + str(self.description).split("\n")[0] + "..."

return out

def __len__(self):
command = ("__len__", self.id_at_location, [], {})
len = self.owner.send_command(message=command, recipient=self.location)
return len

def __getitem__(self, index):
command = ("__getitem__", self.id_at_location, [index], {})
data_elem, target_elem = self.owner.send_command(message=command, recipient=self.location)
return data_elem.wrap(), target_elem.wrap()
2 changes: 2 additions & 0 deletions syft/serde/msgpack/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from syft.serde.msgpack.native_serde import MAP_NATIVE_SIMPLIFIERS_AND_DETAILERS
from syft.workers.abstract import AbstractWorker
from syft.workers.base import BaseWorker
from syft.frameworks.torch.fl import BaseDataset

from syft.exceptions import GetNotPermittedError
from syft.exceptions import ResponseSignatureError
Expand Down Expand Up @@ -130,6 +131,7 @@
PlanCommandMessage,
GradFunc,
String,
BaseDataset,
]

# If an object implements its own force_simplify and force_detail functions it should be stored in this list
Expand Down
38 changes: 0 additions & 38 deletions test/common/test_util.py

This file was deleted.

Empty file removed test/efficiency/__init__.py
Empty file.
23 changes: 0 additions & 23 deletions test/efficiency/assertions.py

This file was deleted.

18 changes: 0 additions & 18 deletions test/efficiency/test_activations_time.py

This file was deleted.

20 changes: 0 additions & 20 deletions test/efficiency/test_linalg_time.py

This file was deleted.

Loading