Skip to content

Commit f95660d

Browse files
author
Surfacebuaa
committed
feat: add OU homomorphic encryption for FATE 1.x
1 parent 15a2f12 commit f95660d

File tree

1 file changed

+374
-0
lines changed

1 file changed

+374
-0
lines changed
Lines changed: 374 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,374 @@
1+
"""OU encryption library for partially homomorphic encryption."""
2+
3+
import numpy as np
4+
import random
5+
6+
from federatedml.secureprotol import gmpy_math
7+
from federatedml.secureprotol.fixedpoint import FixedPointNumber
8+
9+
10+
# according to this paper
11+
# << Accelerating Okamoto-Uchiyama’s Public-Key Cryptosystem >>
12+
# and NIST's recommendation:
13+
# https://www.keylength.com/en/4/
14+
# 160 bits for key size 1024
15+
# 224 bits for key size 2048
16+
# 256 bits for key size 3072
17+
kPrimeFactorSize1024 = 160
18+
kPrimeFactorSize2048 = 224
19+
kPrimeFactorSize3072 = 256
20+
21+
class OUKeypair(object):
22+
def __init__(self):
23+
pass
24+
25+
@staticmethod
26+
def random_monic_exact_bits(bits):
27+
global last_generated
28+
new_value = random.getrandbits(bits)
29+
30+
if 'last_generated' not in globals():
31+
last_generated = new_value
32+
else:
33+
if new_value <= last_generated:
34+
new_value = last_generated + 1
35+
36+
last_generated = new_value
37+
return new_value
38+
39+
def generate_keypair(self, n_length=1024):
40+
"""return a new :class:`OUPublicKey` and :class:`OUPrivateKey`.
41+
"""
42+
secret_size = (n_length + 2) // 3
43+
44+
prime_factor_size = kPrimeFactorSize1024
45+
if n_length >= 3072:
46+
prime_factor_size = kPrimeFactorSize3072
47+
elif n_length >= 2048:
48+
prime_factor_size = kPrimeFactorSize2048
49+
50+
assert prime_factor_size * 2 <= secret_size, \
51+
"Key size must be larger than {} bits".format(prime_factor_size * 2 * 3 - 2)
52+
53+
# generate p
54+
while True:
55+
prime_factor = gmpy_math.getprimeover(prime_factor_size)
56+
# bits_of(a * b) <= bits_of(a) + bits_of(b),
57+
# So we add extra two bits to u:
58+
# one bit for prime_factor * u; another one bit for p^2;
59+
# Also, make sure that u > prime_factor
60+
u = self.random_monic_exact_bits(secret_size - prime_factor_size + 2) # p - 1 has a large prime factor
61+
p = prime_factor * u + 1
62+
63+
if gmpy_math.is_prime(p):
64+
break
65+
66+
# since bits_of(a * b) <= bits_of(a) + bits_of(b)
67+
# add another 1 bit for q
68+
q = gmpy_math.getprimeover(secret_size + 1)
69+
p_square = p ** 2
70+
t = prime_factor
71+
n = p_square * q
72+
73+
# calculate g_p
74+
while True:
75+
while True:
76+
g = random.randint(1, n-1)
77+
gcd = np.gcd(g, p)
78+
if gcd == 1:
79+
break
80+
81+
gp = gmpy_math.powmod(g % p_square, p - 1, p_square)
82+
check = gmpy_math.powmod(gp, p, p_square)
83+
84+
if check == 1:
85+
break
86+
87+
# calculate G
88+
capital_g = gmpy_math.powmod(g, u, n)
89+
90+
while True:
91+
g = random.randint(1, n-1)
92+
if g % p != 0:
93+
break
94+
95+
# calculate H
96+
capital_h = gmpy_math.powmod(g, n * u, n)
97+
98+
# max_plaintext_ must be a power of 2, for ease of use
99+
max_plaintext = pow(10, prime_factor_size // 2) // 2
100+
101+
public_key = OUPublicKey(n, capital_g, capital_h, max_plaintext)
102+
private_key = OUPrivateKey(public_key, p, q, t, gp, max_plaintext)
103+
104+
return public_key, private_key
105+
106+
107+
class OUPublicKey(object):
108+
"""Contains a public key and associated encryption methods.
109+
"""
110+
111+
def __init__(self, n, capital_g, capital_h, max_plaintext):
112+
self.n = n # n = p^2 * q
113+
self.capital_g = capital_g # G = g^u mod n for some random g \in [0, n)
114+
self.capital_h = capital_h # H = g'^{n*u} mod n for some random g' \in [0, n)
115+
self.max_plaintext = max_plaintext # always power of 2, e.g. max_plaintext_ == 2^681
116+
117+
def __repr__(self):
118+
hashcode = hex(hash(self))[2:]
119+
120+
return "<OUPublicKey {}>".format(hashcode[:10])
121+
122+
def __eq__(self, other):
123+
return self.n == other.n and self.capital_g == other.capital_g and self.capital_h == other.capital_h
124+
125+
def __hash__(self):
126+
return hash(self.n)
127+
128+
# multi H^r
129+
# r is a random number < n
130+
# H and n is public key
131+
def apply_obfuscator(self, ciphertext, random_value=None):
132+
"""
133+
"""
134+
r = random_value or random.SystemRandom().randrange(1, self.n)
135+
obfuscator = gmpy_math.powmod(self.capital_h, r, self.n)
136+
137+
return (ciphertext * obfuscator) % self.n
138+
139+
def raw_encrypt(self, plaintext, random_value=None):
140+
"""
141+
"""
142+
if not isinstance(plaintext, int):
143+
raise TypeError("plaintext should be int, but got: %s" %
144+
type(plaintext))
145+
146+
if plaintext >= self.max_plaintext:
147+
plaintext -= self.max_plaintext * 2
148+
149+
gm = gmpy_math.powmod(self.capital_g, plaintext, self.n)
150+
151+
ciphertext = self.apply_obfuscator(gm, random_value)
152+
153+
return ciphertext
154+
155+
def encrypt(self, value, precision=None, random_value=None):
156+
"""Encode and OU encrypt a real number value.
157+
"""
158+
if isinstance(value, FixedPointNumber):
159+
value = value.decode()
160+
encoding = FixedPointNumber.encode(value, self.max_plaintext * 2, self.max_plaintext, precision)
161+
obfuscator = random_value or 1
162+
ciphertext = self.raw_encrypt(encoding.encoding, random_value=obfuscator)
163+
encryptednumber = OUEncryptedNumber(self, ciphertext, encoding.exponent)
164+
165+
return encryptednumber
166+
167+
168+
class OUPrivateKey(object):
169+
"""Contains a private key and associated decryption method.
170+
"""
171+
172+
def __init__(self, public_key, p, q, t, gp, max_plaintext):
173+
self.public_key = public_key
174+
self.p = p
175+
self.q = q # primes such that log2(p), log2(q) ~ n_bits / 3
176+
self.t = t # a big prime factor of p - 1, i.e., p = t * u + 1
177+
self.gp = gp
178+
self.gp_inv = gmpy_math.invert((self.gp - 1) // p, p) # L(g^{p-1} mod p^2))^{-1} mod p
179+
self.p_square = p ** 2
180+
self.max_plaintext = max_plaintext
181+
182+
def __repr__(self):
183+
hashcode = hex(hash(self))[2:]
184+
185+
return "<OUPrivateKey {}>".format(hashcode[:10])
186+
187+
def __eq__(self, other):
188+
return self.p == other.p and self.q == other.q and self.t == other.t and self.gp_inv == other.gp_inv
189+
190+
def __hash__(self):
191+
return hash((self.p, self.q))
192+
193+
def raw_decrypt(self, ciphertext):
194+
"""return raw plaintext.
195+
"""
196+
if not isinstance(ciphertext, int):
197+
raise TypeError("ciphertext should be an int, not: %s" %
198+
type(ciphertext))
199+
200+
plaintext = 0
201+
202+
ct = gmpy_math.powmod(ciphertext % self.p_square, self.t, self.p_square)
203+
204+
plaintext = ((ct // self.p) * self.gp_inv) % self.p
205+
206+
if plaintext >= self.p / 2:
207+
plaintext -= self.p
208+
if plaintext >= self.max_plaintext:
209+
plaintext = plaintext % (self.max_plaintext * 2)
210+
211+
return plaintext
212+
213+
def decrypt(self, encrypted_number):
214+
"""return the decrypted & decoded plaintext of encrypted_number.
215+
"""
216+
if not isinstance(encrypted_number, OUEncryptedNumber):
217+
raise TypeError("encrypted_number should be an OUEncryptedNumber, \
218+
not: %s" % type(encrypted_number))
219+
220+
if self.public_key != encrypted_number.public_key:
221+
raise ValueError("encrypted_number was encrypted against a different key!")
222+
223+
encoded = self.raw_decrypt(encrypted_number.ciphertext(be_secure=False))
224+
encoded = FixedPointNumber(encoded,
225+
encrypted_number.exponent,
226+
self.public_key.max_plaintext * 2,
227+
self.public_key.max_plaintext)
228+
decrypt_value = encoded.decode()
229+
230+
return decrypt_value
231+
232+
233+
class OUEncryptedNumber(object):
234+
"""Represents the OU encryption of a float or int.
235+
"""
236+
237+
def __init__(self, public_key, ciphertext, exponent=0):
238+
self.public_key = public_key
239+
self.__ciphertext = ciphertext
240+
self.exponent = exponent
241+
self.__is_obfuscator = False
242+
243+
if not isinstance(self.__ciphertext, int):
244+
raise TypeError("ciphertext should be an int, not: %s" % type(self.__ciphertext))
245+
246+
if not isinstance(self.public_key, OUPublicKey):
247+
raise TypeError("public_key should be a OUPublicKey, not: %s" % type(self.public_key))
248+
249+
def ciphertext(self, be_secure=True):
250+
"""return the ciphertext of the OUEncryptedNumber.
251+
"""
252+
if be_secure and not self.__is_obfuscator:
253+
self.apply_obfuscator()
254+
255+
return self.__ciphertext
256+
257+
def apply_obfuscator(self):
258+
"""ciphertext by multiplying by H ** r with random r
259+
"""
260+
self.__ciphertext = self.public_key.apply_obfuscator(self.__ciphertext)
261+
self.__is_obfuscator = True
262+
263+
def __add__(self, other):
264+
if isinstance(other, OUEncryptedNumber):
265+
return self.__add_encryptednumber(other)
266+
else:
267+
return self.__add_scalar(other)
268+
269+
def __radd__(self, other):
270+
return self.__add__(other)
271+
272+
def __sub__(self, other):
273+
274+
return self + (other * -1)
275+
276+
def __rsub__(self, other):
277+
return other + (self * -1)
278+
279+
def __rmul__(self, scalar):
280+
return self.__mul__(scalar)
281+
282+
def __truediv__(self, scalar):
283+
return self.__mul__(1 / scalar)
284+
285+
def __mul__(self, scalar):
286+
"""return Multiply by an scalar(such as int, float)
287+
"""
288+
if isinstance(scalar, FixedPointNumber):
289+
scalar = scalar.decode()
290+
encode = FixedPointNumber.encode(scalar, self.public_key.max_plaintext * 2, self.public_key.max_plaintext)
291+
plaintext = encode.encoding
292+
293+
if plaintext < 0 or plaintext >= (self.public_key.max_plaintext * 2):
294+
raise ValueError("Scalar out of bounds: %i" % plaintext)
295+
296+
if plaintext > self.public_key.max_plaintext:
297+
# Very large plaintext, play a sneaky trick using inverses
298+
plaintext -= self.public_key.max_plaintext * 2
299+
300+
ciphertext = gmpy_math.powmod(self.ciphertext(False), plaintext, self.public_key.n)
301+
302+
exponent = self.exponent + encode.exponent
303+
304+
return OUEncryptedNumber(self.public_key, ciphertext, exponent)
305+
306+
def increase_exponent_to(self, new_exponent):
307+
"""return OUEncryptedNumber:
308+
new OUEncryptedNumber with same value but having great exponent.
309+
"""
310+
if new_exponent < self.exponent:
311+
raise ValueError("New exponent %i should be great than old exponent %i" % (new_exponent, self.exponent))
312+
313+
factor = pow(FixedPointNumber.BASE, new_exponent - self.exponent)
314+
new_encryptednumber = self.__mul__(factor)
315+
new_encryptednumber.exponent = new_exponent
316+
317+
return new_encryptednumber
318+
319+
def __align_exponent(self, x, y):
320+
"""return x,y with same exponet
321+
"""
322+
if x.exponent < y.exponent:
323+
x = x.increase_exponent_to(y.exponent)
324+
elif x.exponent > y.exponent:
325+
y = y.increase_exponent_to(x.exponent)
326+
327+
return x, y
328+
329+
def __add_scalar(self, scalar):
330+
"""return OUEncryptedNumber: z = E(x) + y
331+
"""
332+
if isinstance(scalar, FixedPointNumber):
333+
scalar = scalar.decode()
334+
335+
encoded = FixedPointNumber.encode(scalar,
336+
self.public_key.max_plaintext * 2,
337+
self.public_key.max_plaintext,
338+
max_exponent=self.exponent)
339+
340+
return self.__add_fixpointnumber(encoded)
341+
342+
def __add_fixpointnumber(self, encoded):
343+
"""return OUEncryptedNumber: z = E(x) + FixedPointNumber(y)
344+
# """
345+
if self.public_key.max_plaintext != encoded.max_int:
346+
raise ValueError("Attempted to add numbers encoded against different public keys!")
347+
348+
# their exponents must match, and align.
349+
x, y = self.__align_exponent(self, encoded)
350+
351+
encrypted_scalar = x.public_key.raw_encrypt(y.encoding, 1)
352+
encryptednumber = self.__raw_add(x.ciphertext(False), encrypted_scalar, x.exponent)
353+
354+
return encryptednumber
355+
356+
def __add_encryptednumber(self, other):
357+
"""return OUEncryptedNumber: z = E(x) + E(y)
358+
"""
359+
if self.public_key != other.public_key:
360+
raise ValueError("add two numbers have different public key!")
361+
362+
# their exponents must match, and align.
363+
x, y = self.__align_exponent(self, other)
364+
365+
encryptednumber = self.__raw_add(x.ciphertext(False), y.ciphertext(False), x.exponent)
366+
367+
return encryptednumber
368+
369+
def __raw_add(self, e_x, e_y, exponent):
370+
"""return the integer E(x + y) given ints E(x) and E(y).
371+
"""
372+
ciphertext = gmpy_math.mpz(e_x) * gmpy_math.mpz(e_y) % self.public_key.n
373+
374+
return OUEncryptedNumber(self.public_key, int(ciphertext), exponent)

0 commit comments

Comments
 (0)