Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions syft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@
from syft.frameworks.torch.functions import combine_pointers
from syft.frameworks.torch.he.paillier import keygen

# import common
import syft.common.util


def pool():
if not hasattr(syft, "_pool"):
Expand Down
Empty file added syft/common/__init__.py
Empty file.
52 changes: 52 additions & 0 deletions syft/common/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np


def chebyshev_series(func, width, terms):
    r"""
    Computes Chebyshev series coefficients of ``func`` over ``[-width, width]``.

    For n = terms, the ith Chebyshev series coefficient is

    .. math::
        c_i = 2/n \sum_{k=1}^n \cos(i(2k-1)\pi / 4n) f(w\cos((2k-1)\pi / 4n))

    Args:
        func (function): function to be approximated
        width (int): approximation will support inputs in range [-width, width]
        terms (int): number of Chebyshev terms used in approximation

    Returns:
        torch.Tensor: Chebyshev coefficients with shape equal to num of terms.
    """
    n_range = torch.arange(start=0, end=terms).float()
    # Chebyshev nodes of the first kind, scaled to [-width, width]
    x = width * torch.cos((n_range + 0.5) * np.pi / terms)
    y = func(x)
    # outer product gives the cos(i * theta_k) matrix for all (i, k) pairs
    cos_term = torch.cos(torch.ger(n_range, n_range + 0.5) * np.pi / terms)
    coeffs = (2 / terms) * torch.sum(y * cos_term, dim=1)
    return coeffs


def chebyshev_polynomials(tensor, terms=32):
    r"""
    Evaluates odd-degree Chebyshev polynomials at ``tensor``.

    Chebyshev polynomials of the first kind satisfy the recurrence

    .. math::
        P_0(x) = 1,\quad P_1(x) = x,\quad P_n(x) = 2 x P_{n-1}(x) - P_{n-2}(x)

    Only the odd-degree polynomials :math:`P_1, P_3, \dots, P_{terms-1}` are
    produced; they are computed pairwise via the equivalent recurrence
    :math:`P_{2k+1}(x) = (4x^2 - 2) P_{2k-1}(x) - P_{2k-3}(x)`.

    Args:
        tensor (torch.Tensor): input at which polynomials are evaluated
        terms (int): highest degree of Chebyshev polynomials.
            Must be even and at least 6.

    Returns:
        torch.Tensor: tensor of shape ``(terms // 2, *tensor.shape)`` holding
        the odd polynomials evaluated at ``tensor``.

    Raises:
        ValueError: if ``terms`` is odd or less than 6.
    """
    if terms % 2 != 0 or terms < 6:
        raise ValueError("Chebyshev terms must be even and >= 6")

    polynomials = [tensor.clone()]  # P_1(x) = x
    y = 4 * tensor ** 2 - 2
    z = y - 1  # 4x^2 - 3
    polynomials.append(z.mul(tensor))  # P_3(x) = 4x^3 - 3x

    for k in range(2, terms // 2):
        # P_{2k+1} = (4x^2 - 2) * P_{2k-1} - P_{2k-3}
        next_polynomial = y * polynomials[k - 1] - polynomials[k - 2]
        polynomials.append(next_polynomial)

    return torch.stack(polynomials)
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,9 @@ def __imul__(self, other):
self = self.mul(other)
return self

def square(self):
    """Return the element-wise square of this tensor, computed via self.mul(self)."""
    return self.mul(self)

def pow(self, power):
"""
Compute integer power of a number by recursion using mul
Expand Down Expand Up @@ -699,7 +702,6 @@ def unbind(tensor_shares, **kwargs):

@overloaded.function
def stack(tensors_shares, **kwargs):

results = {}

workers = tensors_shares[0].keys()
Expand Down
26 changes: 21 additions & 5 deletions syft/frameworks/torch/tensors/interpreters/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,14 +635,30 @@ def log(tensor):

module.log = log

def tanh(tensor, maxval=6, terms=32):
    r"""
    Overloads torch.tanh to be able to use MPC.

    Implementation taken from FacebookResearch - CrypTen project.
    Computes tanh via Chebyshev approximation with truncation.

    .. math::
        tanh(x) = \sum_{j=1}^{terms} c_{2j-1} P_{2j-1}(x / maxval)

    where c_i is the ith Chebyshev series coefficient and P_i is ith polynomial.
    The approximation is truncated to +/-1 outside [-maxval, maxval].

    Args:
        tensor: input at which tanh is evaluated
        maxval (int): interval width used for computing chebyshev polynomials
        terms (int): highest degree of Chebyshev polynomials.
            Must be even and at least 6.
    """
    # keep only the odd-degree coefficients: tanh is odd, even terms vanish
    coeffs = syft.common.util.chebyshev_series(torch.tanh, maxval, terms)[1::2]
    tanh_polys = syft.common.util.chebyshev_polynomials(tensor.div(maxval), terms)
    # move the polynomial-index axis last so matmul contracts it with coeffs
    tanh_polys_flipped = tanh_polys.unsqueeze(dim=-1).transpose(0, -1).squeeze(dim=0)

    out = tanh_polys_flipped.matmul(coeffs)
    # truncate outside [-maxval, maxval]
    out = torch.where(tensor > maxval, 1.0, out)
    out = torch.where(tensor < -maxval, -1.0, out)
    return out
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious, 2 * sigmoid(2 * tensor) - 1 wasn't precise enough or too slow?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the PR, they say this:

Implemented improved tanh approximation that's ~43.5% more accurate (measured by total relative error) and ~33% faster (see n196073).

Before this, they also had the implementation using the sigmoid

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok cool!


module.tanh = tanh

Expand Down
2 changes: 1 addition & 1 deletion syft/generic/frameworks/overload.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _hook_function_args(*args, **kwargs):

# TODO have a better way to infer the type of tensor -> this implies
# that the first argument is a tensor (even if this is the case > 99% of the time)
tensor = args[0] if not isinstance(args[0], tuple) else args[0][0]
tensor = args[0] if not isinstance(args[0], (tuple, list)) else args[0][0]
cls = type(tensor)

# Replace all syft tensor with their child attribute
Expand Down
38 changes: 38 additions & 0 deletions test/common/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch
import itertools

from syft.common.util import chebyshev_series, chebyshev_polynomials


def test_chebyshev_polynomials():
    """Tests evaluation of chebyshev polynomials"""
    tolerance = 0.05

    for shape in [(1, 10), (3, 5), (3, 5, 10)]:
        for n_terms in (6, 40):
            # random inputs in [-42, 0)
            points = torch.rand(torch.Size(shape)) * 42 - 42
            polys = chebyshev_polynomials(points, n_terms)

            # one polynomial is returned per odd degree up to n_terms
            assert polys.shape[0] == n_terms // 2

            assert torch.all(polys[0] == points), "first term is incorrect"

            # second returned polynomial should match 4x^3 - 3x
            expected = 4 * points ** 3 - 3 * points
            abs_diff = (polys[1] - expected).abs()
            relative_err = abs_diff / (polys[1].abs() + expected.abs())
            assert torch.all(relative_err <= tolerance), "second term is incorrect"


def test_chebyshev_series():
    """Checks coefficients returned by chebyshev_series are correct"""
    width = 6
    for n_terms in (10, 20):
        coeffs = chebyshev_series(torch.tanh, width, n_terms)

        # check shape
        assert coeffs.shape == torch.Size([n_terms])

        # check terms: tanh is odd, so the first (even) coefficient is ~0
        assert coeffs[0] < 1e-4
        # NOTE(review): atol=1e-1 is large relative to 3.5e-2, so this
        # assertion is very loose — consider tightening it
        assert torch.isclose(coeffs[-1], torch.tensor(3.5e-2), atol=1e-1)