Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions syft/frameworks/torch/tensors/interpreters/additive_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,12 +243,14 @@ def reconstruct(self):

return sy.MultiPointerTensor(children=pointers)

def zero(self):
def zero(self, shape=None):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed this in case I have a matrix multiplication - in that case the dimension for the zero shares might be different.

"""
Build an additive shared tensor of value zero with the same
properties as self
"""
shape = self.shape if self.shape else [1]

if shape == None or len(shape) == 0:
shape = self.shape if self.shape else [1]
zero = (
torch.zeros(*shape)
.long()
Expand Down Expand Up @@ -482,11 +484,17 @@ def _public_mul(self, shares, other, equation):
other_is_zero = True

if other_is_zero:
zero_shares = self.zero().child
return {
worker: ((cmd(share, other) + zero_shares[worker]) % self.field)
for worker, share in shares.items()
}
res = {}
first_it = True

for worker, share in shares.items():
cmd_res = cmd(share, other)
if first_it:
first_it = False
zero_shares = self.zero(cmd_res.shape).child

res[worker] = (cmd(share, other) + zero_shares[worker]) % self.field
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you put all in the same loop (and add a flag to notice when it's the first iteration) ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

return res
else:
return {
worker: (cmd(share, other) % self.field) for worker, share in shares.items()
Expand Down
82 changes: 56 additions & 26 deletions syft/frameworks/torch/tensors/interpreters/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def matmul(self, *args, **kwargs):
), "In matmul, all args should have the same precision_fractional"

if isinstance(self.child, AdditiveSharingTensor) and isinstance(other.child, torch.Tensor):
# If we try to matmul a FPT>(wrap)>AST with a FPT>torch.tensor,
# If we try to matmul a FPT>AST with a FPT>torch.tensor,
# we want to perform AST @ torch.tensor
new_self = self.child
new_args = (other,)
Expand All @@ -422,7 +422,7 @@ def matmul(self, *args, **kwargs):
elif isinstance(other.child, AdditiveSharingTensor) and isinstance(
self.child, torch.Tensor
):
# If we try to matmul a FPT>torch.tensor with a FPT>(wrap)>AST,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LaRiffle removed the wrap

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!

# If we try to matmul a FPT>torch.tensor with a FPT>AST,
# we swap operators so that we do the same operation as above
new_self = other.child
new_args = (self,)
Expand Down Expand Up @@ -551,6 +551,58 @@ def log(self, iterations=2, exp_iterations=8):

return y

@staticmethod
def _tanh_chebyshev(tensor, maxval: int = 6, terms: int = 32):
    r"""
    Implementation taken from FacebookResearch - CrypTen project.

    Computes tanh via Chebyshev approximation with truncation:

        tanh(x) = \sum_{j=1}^{terms} c_{2j - 1} P_{2j - 1}(x / maxval)

    where c_i is the ith Chebyshev series coefficient and P_i is the ith
    Chebyshev polynomial. The approximation is truncated to +/-1 outside
    [-maxval, maxval].

    Args:
        tensor: the fixed-precision tensor to apply tanh to
        maxval (int): interval width used for computing chebyshev polynomials
        terms (int): highest degree of Chebyshev polynomials.
            Must be even and at least 6.

    More details can be found in the paper:
        Guo, Chuan and Hannun, Awni and Knott, Brian and van der Maaten,
        Laurens and Tygert, Mark and Zhu, Ruiyu,
        "Secure multiparty computations in floating-point arithmetic", Jan-2020
        Link: http://tygert.com/realcrypt.pdf
    """

    # Keep only the odd-indexed series coefficients ([1::2]); the even-degree
    # terms are dropped, matching the odd-degree-only sum in the formula above.
    coeffs = syft.common.util.chebyshev_series(torch.tanh, maxval, terms)[1::2]
    # Re-encode the coefficients with the same fixed-precision attributes as
    # the input so the arithmetic below stays in one representation.
    coeffs = coeffs.fix_precision(**tensor.get_class_attributes()) % tensor.field
    coeffs = coeffs.unsqueeze(1)  # shape the coefficients as a column for matmul

    # maxval as a fixed-precision scalar: used both to rescale the input into
    # the polynomials' domain and as the saturation threshold below.
    value = torch.tensor(maxval).fix_precision(**tensor.get_class_attributes()) % tensor.field
    tanh_polys = syft.common.util.chebyshev_polynomials(tensor.div(value.child), terms)
    # Move the polynomial-degree axis last so each row can be dotted with the
    # coefficient column. NOTE(review): relies on chebyshev_polynomials
    # stacking the polynomials along dim 0 — confirm against its definition.
    tanh_polys_flipped = tanh_polys.unsqueeze(dim=-1).transpose(0, -1).squeeze(dim=0)

    # Weighted sum of the polynomial evaluations = Chebyshev approximation.
    out = tanh_polys_flipped.matmul(coeffs.child)

    # truncate outside [-maxval, maxval]
    gate_up = tensor > value  # 1 where input > maxval
    gate_down = -tensor > value  # 1 where input < -maxval
    res = gate_up - gate_down  # saturated value: +1, -1, or 0 inside the interval
    # Zero the approximation wherever the input is saturated, then add the
    # +/-1 saturation values back in.
    out = out.squeeze(1) * (1 - gate_up - gate_down)
    out = res + out

    return out

@staticmethod
def _tanh_sigmoid(tensor):
    """
    Approximate tanh through the sigmoid identity:
    tanh(x) = 2 * sigmoid(2 * x) - 1.

    Args:
        tensor: the tensor to apply tanh to

    Returns:
        A tensor of the same kind holding the tanh approximation.
    """
    # Evaluate the sigmoid at the doubled input, then rescale and shift
    # the result from (0, 1) into (-1, 1).
    sigmoid_of_double = torch.sigmoid(2 * tensor)
    return 2 * sigmoid_of_double - 1

def tanh(tensor, method="chebyshev"):
    """
    Compute tanh on this tensor with the selected approximation backend.

    Args:
        method: approximation to use; resolved to the attribute named
            ``_tanh_<method>`` on the tensor (e.g. "chebyshev" or "sigmoid").

    Returns:
        The result of the chosen approximation applied to the tensor.
    """
    approximation = getattr(tensor, "_tanh_" + method)
    return approximation(tensor)

# Binary ops
@overloaded.method
def __gt__(self, _self, other):
Expand Down Expand Up @@ -635,30 +687,8 @@ def log(tensor):

module.log = log

def tanh(tensor, maxval=6, terms=32):
"""
Overloads torch.tanh to be able to use MPC

Implementation taken from FacebookResearch - CrypTen project
Computes tanh via Chebyshev approximation with truncation.
.. math::
tanh(x) = \sum_{j=1}^terms c_{2j - 1} P_{2j - 1} (x / maxval)
where c_i is the ith Chebyshev series coefficient and P_i is ith polynomial.
The approximation is truncated to +/-1 outside [-maxval, maxval].
Args:
maxval (int): interval width used for computing chebyshev polynomials
terms (int): highest degree of Chebyshev polynomials.
Must be even and at least 6.
"""
coeffs = syft.common.util.chebyshev_series(torch.tanh, maxval, terms)[1::2]
tanh_polys = syft.common.util.chebyshev_polynomials(tensor.div(maxval), terms)
tanh_polys_flipped = tanh_polys.unsqueeze(dim=-1).transpose(0, -1).squeeze(dim=0)

out = tanh_polys_flipped.matmul(coeffs)
# truncate outside [-maxval, maxval]
out = torch.where(tensor > maxval, 1.0, out)
out = torch.where(tensor < -maxval, -1.0, out)
return out
def tanh(tensor):
    """
    Overloaded torch.tanh: defer to the tensor's own tanh method so the
    appropriate approximation for its type is used.
    """
    result = tensor.tanh()
    return result

module.tanh = tanh

Expand Down
12 changes: 6 additions & 6 deletions syft/generic/frameworks/hook/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,25 +689,25 @@ def overloaded_attr(self, *args, **kwargs):
def _string_input_args_adaptor(cls, args: Tuple[object]):
"""
This method is used when hooking String methods.

Some 'String' methods which are overriden from 'str'
such as the magic '__add__' method
expects an object of type 'str' as its first
argument. However, since the '__add__' method
here is hooked to a String type, it will receive
arguments of type 'String' not 'str' in some cases.
    This won't work for the underlying hooked method
'__add__' of the 'str' type.
'__add__' of the 'str' type.
That is why the 'String' argument to '__add__' should
be peeled down to 'str'

Args:
args: A tuple or positional arguments of the method
being hooked to the String class.

Return:
A list of adapted positional arguments.

"""

new_args = []
Expand Down Expand Up @@ -739,7 +739,7 @@ def _wrap_str_return_value(cls, _self, attr: str, value: object):
@classmethod
def _get_hooked_string_method(cls, attr):
"""
Hook a `str` method to a corresponding method of
Hook a `str` method to a corresponding method of
`String` with the same name.

Args:
Expand Down Expand Up @@ -772,7 +772,7 @@ def overloaded_attr(_self, *args, **kwargs):
@classmethod
def _get_hooked_string_pointer_method(cls, attr):
"""
Hook a `String` method to a corresponding method of
Hook a `String` method to a corresponding method of
`StringPointer` with the same name.

Args:
Expand Down
2 changes: 1 addition & 1 deletion syft/generic/frameworks/overload.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _hook_function_args(*args, **kwargs):

# TODO have a better way to infer the type of tensor -> this is implies
# that the first argument is a tensor (even if this is the case > 99%)
tensor = args[0] if not isinstance(args[0], (tuple, list)) else args[0][0]
tensor = args[0] if not isinstance(args[0], (tuple)) else args[0][0]
cls = type(tensor)

# Replace all syft tensor with their child attribute
Expand Down
4 changes: 2 additions & 2 deletions syft/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class BaseWorker(AbstractWorker, ObjectStorage):
auto_add: Determines whether to automatically add this worker to the
list of known workers.
message_pending_time (optional): A number of seconds to delay the messages to be sent.
The argument may be a floating point number for subsecond
The argument may be a floating point number for subsecond
precision.
"""

Expand Down Expand Up @@ -995,7 +995,7 @@ def message_pending_time(self, seconds: Union[int, float]) -> None:

Args:
seconds: A number of seconds to delay the messages to be sent.
The argument may be a floating point number for subsecond
The argument may be a floating point number for subsecond
precision.

"""
Expand Down
30 changes: 30 additions & 0 deletions test/torch/tensors/test_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,36 @@ def test_torch_sigmoid_approx(workers):
assert (diff / (tolerance * norm)) < 1


def test_torch_tanh_approx(workers):
    """
    Test the approximate tanh with different tolerance depending on
    the precision_fractional considered.

    Args:
        workers: fixture providing the named syft workers used for sharing.
    """
    alice, bob, james = workers["alice"], workers["bob"], workers["james"]

    # Per-method relative-error tolerances, keyed by precision_fractional.
    # The chebyshev approximation is tighter than the sigmoid-based one.
    fix_prec_tolerance_by_method = {
        "chebyshev": {3: 3 / 100, 4: 3 / 100, 5: 3 / 100},
        "sigmoid": {3: 10 / 100, 4: 15 / 100, 5: 15 / 100},
    }

    for method, fix_prec_tolerance in fix_prec_tolerance_by_method.items():
        for prec_frac, tolerance in fix_prec_tolerance.items():
            # Sample points in [-3, 2.5], covering both saturated and
            # central regions of tanh.
            t = torch.tensor(range(-6, 6)) * 0.5
            t_sh = t.fix_precision(precision_fractional=prec_frac).share(
                alice, bob, crypto_provider=james
            )
            r_sh = t_sh.tanh(method)
            r = r_sh.get().float_prec()
            t = t.tanh()
            # Relative error: max absolute difference over the mean magnitude
            # must stay within the method's tolerance.
            diff = (r - t).abs().max()
            norm = (r + t).abs().max() / 2

            assert (diff / (tolerance * norm)) < 1


def test_torch_log_approx(workers):
"""
Test the approximate logarithm with different tolerance depending on
Expand Down