Skip to content

Commit 38775d0

Browse files
author
Wei Dai
authored
Merge pull request #148 from Alibaba-Gemini-Lab/dev/seal
Reordered invNTT powers of roots, and updated the method to divide by n in NTT.
2 parents da62e76 + 82f595a commit 38775d0

File tree

2 files changed

+108
-153
lines changed

2 files changed

+108
-153
lines changed

native/src/seal/util/smallntt.cpp

Lines changed: 108 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ namespace seal
4242
scaled_root_powers_.release();
4343
inv_root_powers_.release();
4444
scaled_inv_root_powers_.release();
45-
inv_root_powers_div_two_.release();
46-
scaled_inv_root_powers_div_two_.release();
4745
inv_degree_modulo_ = 0;
4846
coeff_count_power_ = 0;
4947
coeff_count_ = 0;
@@ -68,8 +66,6 @@ namespace seal
6866
inv_root_powers_ = allocate_uint(coeff_count_, pool_);
6967
scaled_root_powers_ = allocate_uint(coeff_count_, pool_);
7068
scaled_inv_root_powers_ = allocate_uint(coeff_count_, pool_);
71-
inv_root_powers_div_two_ = allocate_uint(coeff_count_, pool_);
72-
scaled_inv_root_powers_div_two_ = allocate_uint(coeff_count_, pool_);
7369
modulus_ = modulus;
7470

7571
// We defer parameter checking to try_minimal_primitive_root(...)
@@ -98,15 +94,21 @@ namespace seal
9894
ntt_scale_powers_of_primitive_root(inv_root_powers_.get(),
9995
scaled_inv_root_powers_.get());
10096

101-
// Populate the tables storing (scaled version of ) 2 times
102-
// powers of roots^-1 mod q in bit-scrambled order.
103-
for (size_t i = 0; i < coeff_count_; i++)
104-
{
105-
inv_root_powers_div_two_[i] =
106-
div2_uint_mod(inv_root_powers_[i], modulus_);
97+
// Reordering inv_root_powers_ so that the access pattern in the inverse NTT is sequential.
98+
std::vector<uint64_t> tmp(coeff_count_);
99+
uint64_t *ptr = tmp.data() + 1;
100+
for (size_t i = coeff_count_ / 2; i > 0; i /= 2) {
101+
for (size_t j = i; j < i * 2; ++j)
102+
*ptr++ = inv_root_powers_[j];
103+
}
104+
std::copy(tmp.cbegin(), tmp.cend(), inv_root_powers_.get());
105+
106+
ptr = tmp.data() + 1;
107+
for (size_t i = coeff_count_ / 2; i > 0; i /= 2) {
108+
for (size_t j = i; j < i * 2; ++j)
109+
*ptr++ = scaled_inv_root_powers_[j];
107110
}
108-
ntt_scale_powers_of_primitive_root(inv_root_powers_div_two_.get(),
109-
scaled_inv_root_powers_div_two_.get());
111+
std::copy(tmp.cbegin(), tmp.cend(), scaled_inv_root_powers_.get());
110112

111113
// Last compute n^(-1) modulo q.
112114
uint64_t degree_uint = static_cast<uint64_t>(coeff_count_);
@@ -137,18 +139,72 @@ namespace seal
137139

138140
// compute floor(input * beta / q), where beta = 2^64
139141
// and 0 < q < beta.
142+
static inline uint64_t precompute_mulmod(uint64_t y, uint64_t p) {
143+
uint64_t wide_quotient[2]{ 0, 0 };
144+
uint64_t wide_coeff[2]{ 0, y };
145+
divide_uint128_uint64_inplace(wide_coeff, p, wide_quotient);
146+
return wide_quotient[0];
147+
}
148+
140149
void SmallNTTTables::ntt_scale_powers_of_primitive_root(
141150
const uint64_t *input, uint64_t *destination) const
142151
{
143152
for (size_t i = 0; i < coeff_count_; i++, input++, destination++)
144153
{
145-
uint64_t wide_quotient[2]{ 0, 0 };
146-
uint64_t wide_coeff[2]{ 0, *input };
147-
divide_uint128_uint64_inplace(wide_coeff, modulus_.value(), wide_quotient);
148-
*destination = wide_quotient[0];
154+
*destination = precompute_mulmod(*input, modulus_.value());
149155
}
150156
}
151157

158+
struct ntt_body {
159+
const uint64_t modulus, two_times_modulus;
160+
ntt_body(uint64_t modulus) : modulus(modulus), two_times_modulus(modulus << 1) {}
161+
162+
// x0' <- x0 + w * x1
163+
// x1' <- x0 - w * x1
164+
inline void forward(uint64_t *x0, uint64_t *x1, uint64_t W, uint64_t Wprime) const {
165+
uint64_t u = *x0;
166+
uint64_t v = mulmod_lazy(*x1, W, Wprime);
167+
168+
u -= select(two_times_modulus, u < two_times_modulus);
169+
*x0 = u + v;
170+
*x1 = u - v + two_times_modulus;
171+
}
172+
173+
// x0' <- x0 + x1
174+
// x1' <- w * (x0 - x1)
175+
inline void backward(uint64_t *x0, uint64_t *x1, uint64_t W, uint64_t Wprime) const {
176+
uint64_t u = *x0;
177+
uint64_t v = *x1;
178+
uint64_t t = u + v;
179+
t -= select(two_times_modulus, t < two_times_modulus);
180+
181+
*x0 = t;
182+
*x1 = mulmod_lazy(u - v + two_times_modulus, W, Wprime);
183+
}
184+
185+
inline void backward_last(uint64_t *x0, uint64_t *x1, uint64_t inv_N, uint64_t inv_Nprime, uint64_t inv_N_W, uint64_t inv_N_Wprime) const {
186+
uint64_t u = *x0;
187+
uint64_t v = *x1;
188+
uint64_t t = u + v;
189+
t -= select(two_times_modulus, t < two_times_modulus);
190+
191+
*x0 = mulmod_lazy(t, inv_N, inv_Nprime);
192+
*x1 = mulmod_lazy(u - v + two_times_modulus, inv_N_W, inv_N_Wprime);
193+
}
194+
195+
// x * y mod p using Shoup's trick, i.e., yprime = floor(2^64 * y / p)
196+
inline uint64_t mulmod_lazy(uint64_t x, uint64_t y, uint64_t yprime) const {
197+
unsigned long long q;
198+
multiply_uint64_hw64(x, yprime, &q);
199+
return x * y - q * modulus;
200+
}
201+
202+
// return 0 if cond is true, else return b
203+
inline uint64_t select(uint64_t b, bool cond) const {
204+
return (b & -(uint64_t) cond) ^ b;
205+
}
206+
};
207+
152208
/**
153209
This function computes in-place the negacyclic NTT. The input is
154210
a polynomial a of degree n in R_q, where n is assumed to be a power of
@@ -162,10 +218,8 @@ namespace seal
162218
void ntt_negacyclic_harvey_lazy(uint64_t *operand,
163219
const SmallNTTTables &tables)
164220
{
165-
uint64_t modulus = tables.modulus().value();
166-
uint64_t two_times_modulus = modulus * 2;
221+
ntt_body ntt(tables.modulus().value());
167222

168-
// Return the NTT in scrambled order
169223
size_t n = size_t(1) << tables.coeff_count_power();
170224
size_t t = n >> 1;
171225
for (size_t m = 1; m < n; m <<= 1)
@@ -181,33 +235,12 @@ namespace seal
181235

182236
uint64_t *X = operand + j1;
183237
uint64_t *Y = X + t;
184-
uint64_t currX;
185-
unsigned long long Q;
186238
for (size_t j = j1; j < j2; j += 4)
187239
{
188-
currX = *X - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>(*X >= two_times_modulus)));
189-
multiply_uint64_hw64(Wprime, *Y, &Q);
190-
Q = *Y * W - Q * modulus;
191-
*X++ = currX + Q;
192-
*Y++ = currX + (two_times_modulus - Q);
193-
194-
currX = *X - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>(*X >= two_times_modulus)));
195-
multiply_uint64_hw64(Wprime, *Y, &Q);
196-
Q = *Y * W - Q * modulus;
197-
*X++ = currX + Q;
198-
*Y++ = currX + (two_times_modulus - Q);
199-
200-
currX = *X - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>(*X >= two_times_modulus)));
201-
multiply_uint64_hw64(Wprime, *Y, &Q);
202-
Q = *Y * W - Q * modulus;
203-
*X++ = currX + Q;
204-
*Y++ = currX + (two_times_modulus - Q);
205-
206-
currX = *X - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>(*X >= two_times_modulus)));
207-
multiply_uint64_hw64(Wprime, *Y, &Q);
208-
Q = *Y * W - Q * modulus;
209-
*X++ = currX + Q;
210-
*Y++ = currX + (two_times_modulus - Q);
240+
ntt.forward(X++, Y++, W, Wprime);
241+
ntt.forward(X++, Y++, W, Wprime);
242+
ntt.forward(X++, Y++, W, Wprime);
243+
ntt.forward(X++, Y++, W, Wprime);
211244
}
212245
}
213246
}
@@ -222,17 +255,9 @@ namespace seal
222255

223256
uint64_t *X = operand + j1;
224257
uint64_t *Y = X + t;
225-
uint64_t currX;
226-
unsigned long long Q;
227258
for (size_t j = j1; j < j2; j++)
228259
{
229-
// The Harvey butterfly: assume X, Y in [0, 2p), and return X', Y' in [0, 4p).
230-
// X', Y' = X + WY, X - WY (mod p).
231-
currX = *X - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>(*X >= two_times_modulus)));
232-
multiply_uint64_hw64(Wprime, *Y, &Q);
233-
Q = W * *Y - Q * modulus;
234-
*X++ = currX + Q;
235-
*Y++ = currX + (two_times_modulus - Q);
260+
ntt.forward(X++, Y++, W, Wprime);
236261
}
237262
}
238263
}
@@ -243,103 +268,71 @@ namespace seal
243268
// Inverse negacyclic NTT using Harvey's butterfly. (See Patrick Longa and Michael Naehrig).
244269
void inverse_ntt_negacyclic_harvey_lazy(uint64_t *operand, const SmallNTTTables &tables)
245270
{
246-
uint64_t modulus = tables.modulus().value();
247-
uint64_t two_times_modulus = modulus * 2;
271+
ntt_body ntt(tables.modulus().value());
248272

249-
// return the bit-reversed order of NTT.
250-
size_t n = size_t(1) << tables.coeff_count_power();
273+
const size_t n = size_t(1) << tables.coeff_count_power();
251274
size_t t = 1;
252-
253-
for (size_t m = n; m > 1; m >>= 1)
275+
size_t inv_root_index = 1;
276+
// m > 2 to skip the last layer
277+
for (size_t m = n; m > 2; m >>= 1)
254278
{
255279
size_t j1 = 0;
256280
size_t h = m >> 1;
257281
if (t >= 4)
258282
{
259-
for (size_t i = 0; i < h; i++)
283+
for (size_t i = 0; i < h; i++, ++inv_root_index)
260284
{
261285
size_t j2 = j1 + t;
262286
// Powers of phi^{-1}: the table was reordered at generation time so that inv_root_index walks it sequentially
263-
const uint64_t W = tables.get_from_inv_root_powers_div_two(h + i);
264-
const uint64_t Wprime = tables.get_from_scaled_inv_root_powers_div_two(h + i);
287+
const uint64_t W = tables.get_from_inv_root_powers(inv_root_index);
288+
const uint64_t Wprime = tables.get_from_scaled_inv_root_powers(inv_root_index);
265289

266290
uint64_t *U = operand + j1;
267291
uint64_t *V = U + t;
268-
uint64_t currU;
269-
uint64_t T;
270-
unsigned long long H;
271292
for (size_t j = j1; j < j2; j += 4)
272293
{
273-
T = two_times_modulus - *V + *U;
274-
currU = *U + *V - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>((*U << 1) >= T)));
275-
*U++ = (currU + (modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1)))) >> 1;
276-
multiply_uint64_hw64(Wprime, T, &H);
277-
*V++ = T * W - H * modulus;
278-
279-
T = two_times_modulus - *V + *U;
280-
currU = *U + *V - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>((*U << 1) >= T)));
281-
*U++ = (currU + (modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1)))) >> 1;
282-
multiply_uint64_hw64(Wprime, T, &H);
283-
*V++ = T * W - H * modulus;
284-
285-
T = two_times_modulus - *V + *U;
286-
currU = *U + *V - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>((*U << 1) >= T)));
287-
*U++ = (currU + (modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1)))) >> 1;
288-
multiply_uint64_hw64(Wprime, T, &H);
289-
*V++ = T * W - H * modulus;
290-
291-
T = two_times_modulus - *V + *U;
292-
currU = *U + *V - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>((*U << 1) >= T)));
293-
*U++ = (currU + (modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1)))) >> 1;
294-
multiply_uint64_hw64(Wprime, T, &H);
295-
*V++ = T * W - H * modulus;
294+
ntt.backward(U++, V++, W, Wprime);
295+
ntt.backward(U++, V++, W, Wprime);
296+
ntt.backward(U++, V++, W, Wprime);
297+
ntt.backward(U++, V++, W, Wprime);
296298
}
297299
j1 += (t << 1);
298300
}
299301
}
300302
else
301303
{
302-
for (size_t i = 0; i < h; i++)
304+
for (size_t i = 0; i < h; i++, ++inv_root_index)
303305
{
304306
size_t j2 = j1 + t;
305307
// Need the powers of phi^{-1} in bit-reversed order
306-
const uint64_t W = tables.get_from_inv_root_powers_div_two(h + i);
307-
const uint64_t Wprime = tables.get_from_scaled_inv_root_powers_div_two(h + i);
308+
const uint64_t W = tables.get_from_inv_root_powers(inv_root_index);
309+
const uint64_t Wprime = tables.get_from_scaled_inv_root_powers(inv_root_index);
308310

309311
uint64_t *U = operand + j1;
310312
uint64_t *V = U + t;
311-
uint64_t currU;
312-
uint64_t T;
313-
unsigned long long H;
314313
for (size_t j = j1; j < j2; j++)
315314
{
316-
// U = x[i], V = x[i+m]
317-
318-
// Compute U - V + 2q
319-
T = two_times_modulus - *V + *U;
320-
321-
// Cleverly check whether currU + currV >= two_times_modulus
322-
currU = *U + *V - (two_times_modulus & static_cast<uint64_t>(-static_cast<int64_t>((*U << 1) >= T)));
323-
324-
// Need to make it so that div2_uint_mod takes values that are > q.
325-
//div2_uint_mod(U, modulusptr, coeff_uint64_count, U);
326-
// We use also the fact that parity of currU is same as parity of T.
327-
// Since our modulus is always so small that currU + masked_modulus < 2^64,
328-
// we never need to worry about wrapping around when adding masked_modulus.
329-
//uint64_t masked_modulus = modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1));
330-
//uint64_t carry = add_uint64(currU, masked_modulus, 0, &currU);
331-
//currU += modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1));
332-
*U++ = (currU + (modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1)))) >> 1;
333-
334-
multiply_uint64_hw64(Wprime, T, &H);
335-
// effectively, the next two multiply perform multiply modulo beta = 2**wordsize.
336-
*V++ = W * T - H * modulus;
315+
ntt.backward(U++, V++, W, Wprime);
337316
}
338317
j1 += (t << 1);
339318
}
340319
}
341320
t <<= 1;
342321
}
322+
323+
// merge n^{-1} with the last layer of invNTT
324+
const uint64_t W = tables.get_from_inv_root_powers(inv_root_index);
325+
const uint64_t inv_N = *(tables.get_inv_degree_modulo());
326+
const uint64_t inv_N_W = multiply_uint_uint_mod(inv_N, W, tables.modulus());
327+
const uint64_t inv_Nprime = precompute_mulmod(inv_N, tables.modulus().value());
328+
const uint64_t inv_N_Wprime = precompute_mulmod(inv_N_W, tables.modulus().value());
329+
330+
uint64_t *U = operand;
331+
uint64_t *V = U + (n / 2);
332+
for (size_t j = n / 2; j < n; j++)
333+
{
334+
ntt.backward_last(U++, V++, inv_N, inv_Nprime, inv_N_W, inv_N_Wprime);
335+
}
343336
}
344337
}
345338
}

native/src/seal/util/smallntt.h

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -113,38 +113,6 @@ namespace seal
113113
return scaled_inv_root_powers_[index];
114114
}
115115

116-
SEAL_NODISCARD inline auto get_from_inv_root_powers_div_two(
117-
std::size_t index) const -> std::uint64_t
118-
{
119-
#ifdef SEAL_DEBUG
120-
if (index >= coeff_count_)
121-
{
122-
throw std::out_of_range("index");
123-
}
124-
if (!generated_)
125-
{
126-
throw std::logic_error("tables are not generated");
127-
}
128-
#endif
129-
return inv_root_powers_div_two_[index];
130-
}
131-
132-
SEAL_NODISCARD inline auto get_from_scaled_inv_root_powers_div_two(
133-
std::size_t index) const -> std::uint64_t
134-
{
135-
#ifdef SEAL_DEBUG
136-
if (index >= coeff_count_)
137-
{
138-
throw std::out_of_range("index");
139-
}
140-
if (!generated_)
141-
{
142-
throw std::logic_error("tables are not generated");
143-
}
144-
#endif
145-
return scaled_inv_root_powers_div_two_[index];
146-
}
147-
148116
SEAL_NODISCARD inline auto get_inv_degree_modulo() const
149117
-> const std::uint64_t*
150118
{
@@ -203,12 +171,6 @@ namespace seal
203171
// Size coeff_count_
204172
Pointer<decltype(root_)> scaled_root_powers_;
205173

206-
// Size coeff_count_
207-
Pointer<decltype(root_)> inv_root_powers_div_two_;
208-
209-
// Size coeff_count_
210-
Pointer<decltype(root_)> scaled_inv_root_powers_div_two_;
211-
212174
int coeff_count_power_ = 0;
213175

214176
std::size_t coeff_count_ = 0;

0 commit comments

Comments
 (0)