From 7d1e688ec5d93a0e433e5831ed2617b3b0a23d45 Mon Sep 17 00:00:00 2001 From: George Muraru Date: Mon, 3 Feb 2020 22:58:45 +0200 Subject: [PATCH 1/2] Add tanh Chebyshev approx --- syft/__init__.py | 3 ++ syft/common/__init__.py | 0 syft/common/util.py | 52 +++++++++++++++++++ .../tensors/interpreters/additive_shared.py | 4 +- .../torch/tensors/interpreters/precision.py | 29 +++++++++-- syft/generic/frameworks/overload.py | 2 +- 6 files changed, 83 insertions(+), 7 deletions(-) create mode 100644 syft/common/__init__.py create mode 100644 syft/common/util.py diff --git a/syft/__init__.py b/syft/__init__.py index e56cce5bef5..079070309b4 100644 --- a/syft/__init__.py +++ b/syft/__init__.py @@ -83,6 +83,9 @@ from syft.frameworks.torch.functions import combine_pointers from syft.frameworks.torch.he.paillier import keygen +# import common +import syft.common.util + def pool(): if not hasattr(syft, "_pool"): diff --git a/syft/common/__init__.py b/syft/common/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/syft/common/util.py b/syft/common/util.py new file mode 100644 index 00000000000..3bbb89b03f0 --- /dev/null +++ b/syft/common/util.py @@ -0,0 +1,52 @@ +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np + + +def chebyshev_series(func, width, terms): + """ + Computes Chebyshev coefficients + For n = terms, the ith Chebyshev series coefficient is + .. math:: + c_i = 2/n \sum_{k=1}^n \cos(j(2k-1)\pi / 4n) f(w\cos((2k-1)\pi / 4n)) + Args: + func (function): function to be approximated + width (int): approximation will support inputs in range [-width, width] + terms (int): number of Chebyshev terms used in approximation + Returns: + Chebyshev coefficients with shape equal to num of terms. + """ + n_range = torch.arange(start=0, end=terms).float() + x = width * torch.cos((n_range + 0.5) * np.pi / terms) + y = func(x) + cos_term = torch.cos(torch.ger(n_range, n_range + 0.5) * np.pi / terms) + coeffs = (2 / terms) * torch.sum(y * cos_term, axis=1) + return coeffs + + +def chebyshev_polynomials(tensor, terms): + """ + Evaluates odd degree Chebyshev polynomials at x + Chebyshev Polynomials of the first kind are defined as + .. math:: + P_0(x) = 1, P_1(x) = x, P_n(x) = 2 P_{n - 1}(x) - P_{n-2}(x) + Args: + self (MPCTensor): input at which polynomials are evaluated + terms (int): highest degree of Chebyshev polynomials. + Must be even and at least 6. + """ + if terms % 2 != 0 or terms < 6: + raise ValueError("Chebyshev terms must be even and >= 6") + + polynomials = [tensor.clone()] + y = 4 * tensor.square() - 2 + z = y - 1 + polynomials.append(z.mul(tensor)) + + for k in range(2, terms // 2): + next_polynomial = y * polynomials[k - 1] - polynomials[k - 2] + polynomials.append(next_polynomial) + + return torch.stack(polynomials) diff --git a/syft/frameworks/torch/tensors/interpreters/additive_shared.py b/syft/frameworks/torch/tensors/interpreters/additive_shared.py index f882125226f..37ccb9a10e9 100644 --- a/syft/frameworks/torch/tensors/interpreters/additive_shared.py +++ b/syft/frameworks/torch/tensors/interpreters/additive_shared.py @@ -511,6 +511,9 @@ def __imul__(self, other): self = self.mul(other) return self + def square(self): + return self.mul(self) + def pow(self, power): """ Compute integer power of a number by recursion using mul @@ -699,7 +702,6 @@ def unbind(tensor_shares, **kwargs): @overloaded.function def stack(tensors_shares, **kwargs): - results = {} workers = tensors_shares[0].keys() diff --git a/syft/frameworks/torch/tensors/interpreters/precision.py b/syft/frameworks/torch/tensors/interpreters/precision.py index 7111d70a08d..bb6be5804eb 100644 --- a/syft/frameworks/torch/tensors/interpreters/precision.py +++ b/syft/frameworks/torch/tensors/interpreters/precision.py @@ -362,6 +362,9 @@ def __imul__(self, other): mul_ = __imul__ + def square(self): + return self.mul_and_div(self, "mul") + def div(self, other): return self.mul_and_div(other, "div") @@ -636,14 +639,30 @@ def log(tensor): module.log = log - def tanh(tensor): + def tanh(tensor, maxval=6, terms=32): """ Overloads torch.tanh to be able to use MPC - """ - - result = 2 * sigmoid(2 * tensor) - 1 - return result + Implementation taken from FacebookResearch - CrypTen project + Computes tanh via Chebyshev approximation with truncation. + .. math:: + tanh(x) = \sum_{j=1}^terms c_{2j - 1} P_{2j - 1} (x / maxval) + where c_i is the ith Chebyshev series coefficient and P_i is ith polynomial. + The approximation is truncated to +/-1 outside [-maxval, maxval]. + Args: + maxval (int): interval width used for computing chebyshev polynomials + terms (int): highest degree of Chebyshev polynomials. + Must be even and at least 6. + """ + coeffs = syft.common.util.chebyshev_series(torch.tanh, maxval, terms)[1::2] + tanh_polys = syft.common.util.chebyshev_polynomials(tensor.div(maxval), terms) + tanh_polys_flipped = tanh_polys.unsqueeze(dim=-1).transpose(0, -1).squeeze(dim=0) + + out = tanh_polys_flipped.matmul(coeffs) + # truncate outside [-maxval, maxval] + out = torch.where(tensor > maxval, 1.0, out) + out = torch.where(tensor < -maxval, -1.0, out) + return out module.tanh = tanh diff --git a/syft/generic/frameworks/overload.py b/syft/generic/frameworks/overload.py index 1ac679bfa99..95a81a2ca8d 100644 --- a/syft/generic/frameworks/overload.py +++ b/syft/generic/frameworks/overload.py @@ -46,7 +46,7 @@ def _hook_function_args(*args, **kwargs): # TODO have a better way to infer the type of tensor -> this is implies # that the first argument is a tensor (even if this is the case > 99%) - tensor = args[0] if not isinstance(args[0], tuple) else args[0][0] + tensor = args[0] if not isinstance(args[0], (tuple, list)) else args[0][0] cls = type(tensor) # Replace all syft tensor with their child attribute From b5d90c270db76dd8bff0a2ebf2ef2524860aa440 Mon Sep 17 00:00:00 2001 From: George Muraru Date: Thu, 6 Feb 2020 02:19:59 +0200 Subject: [PATCH 2/2] Add tests --- syft/common/util.py | 6 +-- .../torch/tensors/interpreters/precision.py | 3 -- test/common/test_util.py | 38 +++++++++++++++++++ 3 files changed, 41 insertions(+), 6 deletions(-) create mode 100644 test/common/test_util.py diff --git a/syft/common/util.py b/syft/common/util.py index 3bbb89b03f0..365718244a3 100644 --- a/syft/common/util.py +++ b/syft/common/util.py @@ -26,14 +26,14 @@ def chebyshev_series(func, width, terms): return coeffs -def chebyshev_polynomials(tensor, terms): +def chebyshev_polynomials(tensor, terms=32): """ Evaluates odd degree Chebyshev polynomials at x Chebyshev Polynomials of the first kind are defined as .. math:: P_0(x) = 1, P_1(x) = x, P_n(x) = 2 P_{n - 1}(x) - P_{n-2}(x) Args: - self (MPCTensor): input at which polynomials are evaluated + tensor (torch.tensor): input at which polynomials are evaluated terms (int): highest degree of Chebyshev polynomials. Must be even and at least 6. """ @@ -41,7 +41,7 @@ def chebyshev_polynomials(tensor, terms): raise ValueError("Chebyshev terms must be even and >= 6") polynomials = [tensor.clone()] - y = 4 * tensor.square() - 2 + y = 4 * tensor ** 2 - 2 z = y - 1 polynomials.append(z.mul(tensor)) diff --git a/syft/frameworks/torch/tensors/interpreters/precision.py b/syft/frameworks/torch/tensors/interpreters/precision.py index bb6be5804eb..c5f76aa7dfd 100644 --- a/syft/frameworks/torch/tensors/interpreters/precision.py +++ b/syft/frameworks/torch/tensors/interpreters/precision.py @@ -362,9 +362,6 @@ def __imul__(self, other): mul_ = __imul__ - def square(self): - return self.mul_and_div(self, "mul") - def div(self, other): return self.mul_and_div(other, "div") diff --git a/test/common/test_util.py b/test/common/test_util.py new file mode 100644 index 00000000000..e825389c55b --- /dev/null +++ b/test/common/test_util.py @@ -0,0 +1,38 @@ +import torch +import itertools + +from syft.common.util import chebyshev_series, chebyshev_polynomials + + +def test_chebyshev_polynomials(): + """Tests evaluation of chebyshev polynomials""" + sizes = [(1, 10), (3, 5), (3, 5, 10)] + possible_terms = [6, 40] + tolerance = 0.05 + + for size, terms in itertools.product(sizes, possible_terms): + tensor = torch.rand(torch.Size(size)) * 42 - 42 + result = chebyshev_polynomials(tensor, terms) + + # check number of polynomials + assert result.shape[0] == terms // 2 + + assert torch.all(result[0] == tensor), "first term is incorrect" + + second_term = 4 * tensor ** 3 - 3 * tensor + diff = (result[1] - second_term).abs() + norm_diff = diff.div(result[1].abs() + second_term.abs()) + assert torch.all(norm_diff <= tolerance), "second term is incorrect" + + +def test_chebyshev_series(): + """Checks coefficients returned by chebyshev_series are correct""" + for width, terms in [(6, 10), (6, 20)]: + result = chebyshev_series(torch.tanh, width, terms) + + # check shape + assert result.shape == torch.Size([terms]) + + # check terms + assert result[0] < 1e-4 + assert torch.isclose(result[-1], torch.tensor(3.5e-2), atol=1e-1)