-
-
Notifications
You must be signed in to change notification settings - Fork 2k
Chebyshev or Sigmoid for tanh #3113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9ade20c
ec282f1
43163c5
e9e0f6f
c8acd50
5ff1168
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -243,12 +243,14 @@ def reconstruct(self): | |
|
|
||
| return sy.MultiPointerTensor(children=pointers) | ||
|
|
||
| def zero(self): | ||
| def zero(self, shape=None): | ||
| """ | ||
| Build an additive shared tensor of value zero with the same | ||
| properties as self | ||
| """ | ||
| shape = self.shape if self.shape else [1] | ||
|
|
||
| if shape == None or len(shape) == 0: | ||
| shape = self.shape if self.shape else [1] | ||
| zero = ( | ||
| torch.zeros(*shape) | ||
| .long() | ||
|
|
@@ -482,11 +484,17 @@ def _public_mul(self, shares, other, equation): | |
| other_is_zero = True | ||
|
|
||
| if other_is_zero: | ||
| zero_shares = self.zero().child | ||
| return { | ||
| worker: ((cmd(share, other) + zero_shares[worker]) % self.field) | ||
| for worker, share in shares.items() | ||
| } | ||
| res = {} | ||
| first_it = True | ||
|
|
||
| for worker, share in shares.items(): | ||
| cmd_res = cmd(share, other) | ||
| if first_it: | ||
| first_it = False | ||
| zero_shares = self.zero(cmd_res.shape).child | ||
|
|
||
| res[worker] = (cmd(share, other) + zero_shares[worker]) % self.field | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you put all in the same loop (and add a flag to notice when it's the first iteration)?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure |
||
| return res | ||
| else: | ||
| return { | ||
| worker: (cmd(share, other) % self.field) for worker, share in shares.items() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -413,7 +413,7 @@ def matmul(self, *args, **kwargs): | |
| ), "In matmul, all args should have the same precision_fractional" | ||
|
|
||
| if isinstance(self.child, AdditiveSharingTensor) and isinstance(other.child, torch.Tensor): | ||
| # If we try to matmul a FPT>(wrap)>AST with a FPT>torch.tensor, | ||
| # If we try to matmul a FPT>AST with a FPT>torch.tensor, | ||
| # we want to perform AST @ torch.tensor | ||
| new_self = self.child | ||
| new_args = (other,) | ||
|
|
@@ -422,7 +422,7 @@ def matmul(self, *args, **kwargs): | |
| elif isinstance(other.child, AdditiveSharingTensor) and isinstance( | ||
| self.child, torch.Tensor | ||
| ): | ||
| # If we try to matmul a FPT>torch.tensor with a FPT>(wrap)>AST, | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @LaRiffle removed the wrap
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Great! |
||
| # If we try to matmul a FPT>torch.tensor with a FPT>AST, | ||
| # we swap operators so that we do the same operation as above | ||
| new_self = other.child | ||
| new_args = (self,) | ||
|
|
@@ -551,6 +551,58 @@ def log(self, iterations=2, exp_iterations=8): | |
|
|
||
| return y | ||
|
|
||
| @staticmethod | ||
| def _tanh_chebyshev(tensor, maxval: int = 6, terms: int = 32): | ||
| """ | ||
| Implementation taken from FacebookResearch - CrypTen project | ||
| Computes tanh via Chebyshev approximation with truncation. | ||
| tanh(x) = \sum_{j=1}^terms c_{2j - 1} P_{2j - 1} (x / maxval) | ||
| where c_i is the ith Chebyshev series coefficient and P_i is ith polynomial. | ||
| The approximation is truncated to +/-1 outside [-maxval, maxval]. | ||
| Args: | ||
| maxval (int): interval width used for computing chebyshev polynomials | ||
| terms (int): highest degree of Chebyshev polynomials. | ||
| Must be even and at least 6. | ||
| More details can be found in the paper: | ||
| Guo, Chuan and Hannun, Awni and Knott, Brian and van der Maaten, | ||
| Laurens and Tygert, Mark and Zhu, Ruiyu, | ||
| "Secure multiparty computations in floating-point arithmetic", Jan-2020 | ||
| Link: http://tygert.com/realcrypt.pdf | ||
|
|
||
| """ | ||
|
|
||
| coeffs = syft.common.util.chebyshev_series(torch.tanh, maxval, terms)[1::2] | ||
| coeffs = coeffs.fix_precision(**tensor.get_class_attributes()) % tensor.field | ||
| coeffs = coeffs.unsqueeze(1) | ||
|
|
||
| value = torch.tensor(maxval).fix_precision(**tensor.get_class_attributes()) % tensor.field | ||
| tanh_polys = syft.common.util.chebyshev_polynomials(tensor.div(value.child), terms) | ||
| tanh_polys_flipped = tanh_polys.unsqueeze(dim=-1).transpose(0, -1).squeeze(dim=0) | ||
|
|
||
| out = tanh_polys_flipped.matmul(coeffs.child) | ||
|
|
||
| # truncate outside [-maxval, maxval] | ||
| gate_up = tensor > value | ||
| gate_down = -tensor > value | ||
| res = gate_up - gate_down | ||
| out = out.squeeze(1) * (1 - gate_up - gate_down) | ||
| out = res + out | ||
|
|
||
| return out | ||
|
|
||
| @staticmethod | ||
| def _tanh_sigmoid(tensor): | ||
| """ | ||
| Compute the tanh using the sigmoid | ||
|
|
||
| """ | ||
| return 2 * torch.sigmoid(2 * tensor) - 1 | ||
|
|
||
| def tanh(tensor, method="chebyshev"): | ||
| tanh_f = getattr(tensor, f"_tanh_{method}") | ||
|
|
||
| return tanh_f(tensor) | ||
|
|
||
| # Binary ops | ||
| @overloaded.method | ||
| def __gt__(self, _self, other): | ||
|
|
@@ -635,30 +687,8 @@ def log(tensor): | |
|
|
||
| module.log = log | ||
|
|
||
| def tanh(tensor, maxval=6, terms=32): | ||
| """ | ||
| Overloads torch.tanh to be able to use MPC | ||
|
|
||
| Implementation taken from FacebookResearch - CrypTen project | ||
| Computes tanh via Chebyshev approximation with truncation. | ||
| .. math:: | ||
| tanh(x) = \sum_{j=1}^terms c_{2j - 1} P_{2j - 1} (x / maxval) | ||
| where c_i is the ith Chebyshev series coefficient and P_i is ith polynomial. | ||
| The approximation is truncated to +/-1 outside [-maxval, maxval]. | ||
| Args: | ||
| maxval (int): interval width used for computing chebyshev polynomials | ||
| terms (int): highest degree of Chebyshev polynomials. | ||
| Must be even and at least 6. | ||
| """ | ||
| coeffs = syft.common.util.chebyshev_series(torch.tanh, maxval, terms)[1::2] | ||
| tanh_polys = syft.common.util.chebyshev_polynomials(tensor.div(maxval), terms) | ||
| tanh_polys_flipped = tanh_polys.unsqueeze(dim=-1).transpose(0, -1).squeeze(dim=0) | ||
|
|
||
| out = tanh_polys_flipped.matmul(coeffs) | ||
| # truncate outside [-maxval, maxval] | ||
| out = torch.where(tensor > maxval, 1.0, out) | ||
| out = torch.where(tensor < -maxval, -1.0, out) | ||
| return out | ||
| def tanh(tensor): | ||
| return tensor.tanh() | ||
|
|
||
| module.tanh = tanh | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -468,6 +468,36 @@ def test_torch_sigmoid_approx(workers): | |
| assert (diff / (tolerance * norm)) < 1 | ||
|
|
||
|
|
||
| def test_torch_tanh_approx(workers): | ||
| """ | ||
| Test the approximate tanh with different tolerance depending on | ||
| the precision_fractional considered | ||
| """ | ||
| alice, bob, james = workers["alice"], workers["bob"], workers["james"] | ||
|
|
||
| fix_prec_tolerance_by_method = { | ||
| "chebyshev": {3: 3 / 100, 4: 3 / 100, 5: 3 / 100}, | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Those might need more tweaking — currently, it works, but we might want a lower bound
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's ok, I guess! |
||
| "sigmoid": {3: 10 / 100, 4: 15 / 100, 5: 15 / 100}, | ||
| } | ||
|
|
||
| for method, fix_prec_tolerance in fix_prec_tolerance_by_method.items(): | ||
| for prec_frac, tolerance in fix_prec_tolerance.items(): | ||
| t = torch.tensor(range(-6, 6)) * 0.5 | ||
| t_sh = t.fix_precision(precision_fractional=prec_frac).share( | ||
| alice, bob, crypto_provider=james | ||
| ) | ||
| r_sh = t_sh.tanh(method) | ||
| r = r_sh.get().float_prec() | ||
| t = t.tanh() | ||
| print(method, prec_frac, tolerance) | ||
| print(r) | ||
| print(t) | ||
| diff = (r - t).abs().max() | ||
| norm = (r + t).abs().max() / 2 | ||
|
|
||
| assert (diff / (tolerance * norm)) < 1 | ||
|
|
||
|
|
||
| def test_torch_log_approx(workers): | ||
| """ | ||
| Test the approximate logarithm with different tolerance depending on | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Needed this in case I have a matrix multiplication - in that case the dimension for the zero shares might be different.