Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
97a2cf2
dataset inherits abstractobject
abogaziah Jan 24, 2020
bb5054a
added abstract dataset test
abogaziah Jan 24, 2020
803b37c
formatting
abogaziah Jan 24, 2020
500cf4a
fix
abogaziah Jan 24, 2020
1fabb5a
removed owner annotation
abogaziah Jan 25, 2020
ffc0e08
formatting
abogaziah Jan 25, 2020
c3da99b
Added simplify& detail
abogaziah Jan 26, 2020
8a72d64
Merge branch 'master' into master
LaRiffle Jan 26, 2020
0937d00
Edit simplify& detail
abogaziah Jan 26, 2020
7acd96c
Merge branch 'master' of https://github.com/abogaziah/PySyft
abogaziah Jan 26, 2020
871adf8
Merge branch 'master' into master
abogaziah Feb 5, 2020
06f6298
Merge branch 'master' into master
LaRiffle Feb 5, 2020
46bb642
Added PointerDataset type
abogaziah Feb 11, 2020
849ff8a
Merge branch 'master' of https://github.com/abogaziah/PySyft
abogaziah Feb 11, 2020
6472aac
Added send()
abogaziah Feb 11, 2020
fc27545
Added send(), create_pointer()
abogaziah Feb 11, 2020
017d7f7
Merge remote-tracking branch 'upstream/master'
abogaziah Feb 12, 2020
5a4a8b3
Added get(), pointer.data(), pointer.target()
abogaziah Feb 13, 2020
2b4ee06
Added wrap(), edited base.search(), edit dataset.__init__()
abogaziah Feb 15, 2020
ae5101c
Formatting
abogaziah Feb 15, 2020
880864a
Added test file, __repr__(), removed get() override
abogaziah Feb 17, 2020
fe52b27
Added tags, description to create_pointer()
abogaziah Feb 20, 2020
5e37ab4
formatting
abogaziah Feb 22, 2020
e1aa881
changed repr of dataset ptr
abogaziah Feb 22, 2020
bd860c3
replaced == with torch.equal()
abogaziah Feb 22, 2020
847f51a
adapt dataloader
abogaziah Mar 1, 2020
7e239c9
formatting
abogaziah Mar 1, 2020
a1ddd72
added serede test!
abogaziah Mar 2, 2020
6ec07fc
added serede test!
abogaziah Mar 2, 2020
4620cd9
Merge branch 'master' into master
LaRiffle Mar 2, 2020
a6491ac
test fixes
abogaziah Mar 5, 2020
17a4251
test fixes
abogaziah Mar 6, 2020
f50cda5
serde test fix
abogaziah Mar 6, 2020
9169ad9
formatting
abogaziah Mar 6, 2020
022a0da
hook_args fix
abogaziah Mar 12, 2020
69a10fb
Merge branch 'master' into master
LaRiffle Mar 12, 2020
a3ddae8
Update test_dataset_pointer.py
LaRiffle Mar 12, 2020
c776e10
pointer location fix
abogaziah Mar 13, 2020
ebfc062
formatting
abogaziah Mar 13, 2020
1081a3c
formatting
abogaziah Mar 13, 2020
ef5204b
formatting
abogaziah Mar 13, 2020
642166c
formatting
abogaziah Mar 13, 2020
ae46ad1
fedrated check pass
abogaziah Mar 16, 2020
bfe8f2a
formatting
abogaziah Mar 16, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 90 additions & 23 deletions syft/frameworks/torch/fl/dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import math
import logging

from syft.generic.object import AbstractObject
from syft.workers.base import BaseWorker
from syft.generic.pointers.pointer_dataset import PointerDataset
import torch
from torch.utils.data import Dataset
import syft

logger = logging.getLogger(__name__)


class BaseDataset:
class BaseDataset(AbstractObject):
"""
This is a base class to be used for manipulating a dataset. This is composed
of a .data attribute for inputs and a .targets one for labels. It is to
Expand All @@ -22,8 +25,10 @@ class BaseDataset:

"""

def __init__(self, data, targets, transform=None):

def __init__(self, data, targets, transform=None, owner=None, **kwargs):
if owner is None:
owner = syft.framework.hook.local_worker
super().__init__(owner=owner, **kwargs)
self.data = data
self.targets = targets
self.transform_ = transform
Expand Down Expand Up @@ -68,21 +73,9 @@ def transform(self, transform):

raise TypeError("Transforms can be applied only on torch tensors")

def send(self, worker):
"""
Args:

worker[worker class]: worker to which the data must be sent

Returns:

self: Return the object instance with data sent to corresponding worker

"""

self.data.send_(worker)
self.targets.send_(worker)
return self
def send(self, location: BaseWorker):
    """Send this dataset to a remote worker.

    Args:
        location: the worker the dataset is sent to.

    Returns:
        A pointer to the remote dataset, as produced by ``owner.send``.
    """
    return self.owner.send(self, workers=location)

def get(self):
"""
Expand All @@ -93,6 +86,12 @@ def get(self):
self.targets.get_()
return self

def get_data(self):
    """Accessor for the dataset inputs; used by PointerDataset's remote
    ``get_data`` command dispatch."""
    return self.data

def get_targets(self):
    """Accessor for the dataset labels; used by PointerDataset's remote
    ``get_targets`` command dispatch."""
    return self.targets

def fix_prec(self, *args, **kwargs):
"""
Converts data of BaseDataset into fixed precision
Expand Down Expand Up @@ -121,13 +120,81 @@ def share(self, *args, **kwargs):
self.targets.share_(*args, **kwargs)
return self

def create_pointer(
    self, owner, garbage_collect_data=True, location=None, id_at_location=None, **kwargs
):
    """Creates a pointer to this dataset.

    Args:
        owner: worker (or worker id) that will own the pointer; defaults
            to this dataset's owner when None.
        garbage_collect_data: if True, the remote data is deleted when the
            pointer is garbage collected. Defaults to True, matching
            ``PointerDataset.__init__``.
        location: worker (or worker id) where the data lives; defaults to
            this dataset's owner when None.
        id_at_location: id of the dataset on the remote worker; defaults
            to this dataset's own id.
        **kwargs: accepted for signature compatibility with other
            ``create_pointer`` implementations; unused here.

    Returns:
        A PointerDataset pointing at this dataset.
    """
    if owner is None:
        owner = self.owner

    if location is None:
        location = self.owner

    # get_worker resolves ids/strings to actual worker objects.
    owner = self.owner.get_worker(owner)
    location = self.owner.get_worker(location)

    return PointerDataset(
        owner=owner,
        location=location,
        id_at_location=id_at_location or self.id,
        garbage_collect_data=garbage_collect_data,
        tags=self.tags,
        description=self.description,
    )

def __repr__(self):

fmt_str = "BaseDataset\n"
fmt_str += f"\tData: {self.data}\n"
fmt_str += f"\ttargets: {self.targets}"

if self.tags is not None and len(self.tags):
fmt_str += "\n\tTags: "
for tag in self.tags:
fmt_str += str(tag) + " "

if self.description is not None:
fmt_str += "\n\tDescription: " + str(self.description).split("\n")[0] + "..."

return fmt_str

@property
def location(self):
    """Worker where the data currently lives (delegates to ``data.location``)."""
    return self.data.location

@staticmethod
def simplify(worker, dataset: "BaseDataset") -> tuple:
    """Flatten a BaseDataset into a msgpack-simplifiable tuple.

    Tuple layout: (data, targets, id, tags, description, chain), where
    chain is the simplified ``.child`` when present, else None.
    """
    simplify_ = syft.serde.msgpack.serde._simplify
    chain = simplify_(worker, dataset.child) if hasattr(dataset, "child") else None
    return (
        simplify_(worker, dataset.data),
        simplify_(worker, dataset.targets),
        dataset.id,
        simplify_(worker, dataset.tags),
        simplify_(worker, dataset.description),
        chain,
    )

@staticmethod
def detail(worker, dataset_tuple: tuple) -> "BaseDataset":
    """Rebuild a BaseDataset from the tuple produced by ``simplify``."""
    detail_ = syft.serde.msgpack.serde._detail
    data, targets, id, tags, description, chain = dataset_tuple
    dataset = BaseDataset(
        detail_(worker, data),
        detail_(worker, targets),
        owner=worker,
        id=id,
        tags=detail_(worker, tags),
        description=detail_(worker, description),
    )
    # Re-attach the simplified child chain, if one was serialized.
    if chain is not None:
        dataset.child = detail_(worker, chain)
    return dataset


def dataset_federate(dataset, workers):
"""
Expand Down Expand Up @@ -172,11 +239,11 @@ def __init__(self, datasets):
self.datasets[worker_id] = dataset

# Check that data and targets for a worker are consistent
for worker_id in self.workers:
"""for worker_id in self.workers:
dataset = self.datasets[worker_id]
assert len(dataset.data) == len(
dataset.targets
), "On each worker, the input and target must have the same number of rows."
assert (
dataset.data.shape == dataset.targets.shape
), "On each worker, the input and target must have the same number of rows.""" ""

@property
def workers(self):
Expand Down
1 change: 1 addition & 0 deletions syft/generic/frameworks/hook/hook_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@ def register_response(

try:
assert attr not in ambiguous_functions
assert attr not in ambiguous_methods

# Load the utility function to register the response and transform tensors with pointers
register_response_function = register_response_functions[attr_id]
Expand Down
79 changes: 79 additions & 0 deletions syft/generic/pointers/pointer_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import List
from typing import Union

import syft as sy
from syft.generic.pointers.object_pointer import ObjectPointer
from syft.workers.abstract import AbstractWorker


class PointerDataset(ObjectPointer):
def __init__(
    self,
    location: "AbstractWorker" = None,
    id_at_location: Union[str, int] = None,
    owner: "AbstractWorker" = None,
    garbage_collect_data: bool = True,
    id: Union[str, int] = None,
    tags: List[str] = None,
    description: str = None,
):
    """Pointer to a BaseDataset living on a (possibly remote) worker.

    Falls back to the local worker as owner when none is supplied, then
    defers all remaining setup to ObjectPointer.
    """
    resolved_owner = sy.framework.hook.local_worker if owner is None else owner
    super().__init__(
        location=location,
        id_at_location=id_at_location,
        owner=resolved_owner,
        garbage_collect_data=garbage_collect_data,
        id=id,
        tags=tags,
        description=description,
    )

@property
def data(self):
    """Remotely fetch a pointer to the dataset's inputs.

    Sends a ``get_data`` command to the location worker (which invokes
    ``BaseDataset.get_data`` there) and wraps the returned pointer.
    NOTE(review): the scraped page had review-comment UI text embedded
    mid-method; this restores the clean implementation.
    """
    command = ("get_data", self.id_at_location, [], {})
    ptr = self.owner.send_command(message=command, recipient=self.location).wrap()
    return ptr

@property
def targets(self):
    """Remotely fetch a pointer to the dataset's labels via a
    ``get_targets`` command sent to the location worker."""
    remote_call = ("get_targets", self.id_at_location, [], {})
    response = self.owner.send_command(message=remote_call, recipient=self.location)
    return response.wrap()

def wrap(self):
    """A PointerDataset is already its own wrapper, so return self unchanged."""
    return self

def __repr__(self):
type_name = type(self).__name__
out = f"[" f"{type_name} | " f"owner: {str(self.owner.id)}, id:{self.id}"

if self.point_to_attr is not None:
out += "::" + str(self.point_to_attr).replace(".", "::")

big_str = False

if self.tags is not None and len(self.tags):
big_str = True
out += "\n\tTags: "
for tag in self.tags:
out += str(tag) + " "

if big_str and hasattr(self, "shape"):
out += "\n\tShape: " + str(self.shape)

if self.description is not None:
big_str = True
out += "\n\tDescription: " + str(self.description).split("\n")[0] + "..."

return out

def __len__(self):
command = ("__len__", self.id_at_location, [], {})
len = self.owner.send_command(message=command, recipient=self.location)
return len

def __getitem__(self, index):
command = ("__getitem__", self.id_at_location, [index], {})
data_elem, target_elem = self.owner.send_command(message=command, recipient=self.location)
return data_elem.wrap(), target_elem.wrap()
2 changes: 2 additions & 0 deletions syft/serde/msgpack/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from syft.serde.msgpack.native_serde import MAP_NATIVE_SIMPLIFIERS_AND_DETAILERS
from syft.workers.abstract import AbstractWorker
from syft.workers.base import BaseWorker
from syft.frameworks.torch.fl import BaseDataset

from syft.exceptions import GetNotPermittedError
from syft.exceptions import ResponseSignatureError
Expand Down Expand Up @@ -131,6 +132,7 @@
PlanCommandMessage,
GradFunc,
String,
BaseDataset,
ExecuteWorkerFunctionMessage,
]

Expand Down
1 change: 1 addition & 0 deletions test/serde/msgpack/test_msgpack_serde_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
samples[syft.frameworks.torch.tensors.interpreters.autograd.AutogradTensor] = make_autogradtensor
samples[syft.frameworks.torch.tensors.interpreters.private.PrivateTensor] = make_privatetensor
samples[syft.frameworks.torch.tensors.interpreters.placeholder.PlaceHolder] = make_placeholder
samples[syft.frameworks.torch.fl.dataset.BaseDataset] = make_basedataset

samples[syft.messaging.message.CommandMessage] = make_command_message
samples[syft.messaging.message.ObjectMessage] = make_objectmessage
Expand Down
39 changes: 39 additions & 0 deletions test/serde/serde_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,45 @@ def compare(detailed, original):
]


# syft.frameworks.torch.fl.dataset
def make_basedataset(**kwargs):
    """Build a tagged/described BaseDataset sample plus its expected
    msgpack-simplified form, for the serde round-trip tests."""
    workers = kwargs["workers"]
    alice, bob, james = (workers[name] for name in ("alice", "bob", "james"))
    dataset = syft.BaseDataset(torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6, 7, 8]))
    dataset.tag("#tag1").describe("desc")

    def compare(detailed, original):
        # Round-tripped dataset must match the original field by field.
        assert type(detailed) == syft.BaseDataset
        assert (detailed.data == original.data).all()
        assert (detailed.targets == original.targets).all()
        assert detailed.id == original.id
        assert detailed.tags == original.tags
        assert detailed.description == original.description
        return True

    simplify_ = msgpack.serde._simplify
    me = syft.hook.local_worker
    expected = (
        CODE[syft.frameworks.torch.fl.dataset.BaseDataset],
        (
            simplify_(me, dataset.data),
            simplify_(me, dataset.targets),
            dataset.id,
            simplify_(me, dataset.tags),  # (set of str) tags
            simplify_(me, dataset.description),  # (str) description
            simplify_(me, dataset.child),
        ),
    )

    return [{"value": dataset, "simplified": expected, "cmp_detailed": compare}]


# syft.execution.plan.Plan
def make_plan(**kwargs):
# Function to plan
Expand Down
28 changes: 16 additions & 12 deletions test/torch/federated/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,11 @@ def test_base_dataset(workers):
assert len(dataset) == 4
assert dataset[2] == (3, 3)

dataset.send(bob)
dataset = dataset.send(bob)
assert dataset.data.location.id == "bob"
assert dataset.targets.location.id == "bob"
assert dataset.location.id == "bob"

dataset.get()
with pytest.raises(AttributeError):
assert dataset.data.location.id == 0
with pytest.raises(AttributeError):
assert dataset.targets.location.id == 0


def test_base_dataset_transform():

Expand Down Expand Up @@ -61,11 +55,12 @@ def test_federated_dataset(workers):
assert fed_dataset.workers == ["bob", "alice"]
assert len(fed_dataset) == 6

fed_dataset["alice"].get()
assert (fed_dataset["alice"].data == alice_base_dataset.data).all()
assert fed_dataset["alice"][2] == (5, 5)
assert len(fed_dataset["alice"]) == 4
assert len(fed_dataset) == 6
alice_remote_data = fed_dataset["alice"].get()
del fed_dataset.datasets["alice"]
assert (alice_remote_data.data == alice_base_dataset.data).all()
assert alice_remote_data[2] == (5, 5)
assert len(alice_remote_data) == 4
assert len(fed_dataset) == 2

assert isinstance(fed_dataset.__str__(), str)

Expand Down Expand Up @@ -114,3 +109,12 @@ def test_federated_dataset_search(workers):
counter += 1

assert counter == len(train_loader), f"{counter} == {len(fed_dataset)}"


def test_abstract_dataset():
    """BaseDataset accepts AbstractObject kwargs (id) and defaults description to None."""
    inputs = th.tensor([1, 2, 3, 4.0])
    targets = th.tensor([1, 2, 3, 4.0])
    dataset = BaseDataset(inputs, targets, id=1)

    assert dataset.id == 1
    # `is None` rather than `== None`: identity is the correct idiom and
    # avoids __eq__ overloads.
    assert dataset.description is None
Loading