@@ -42,8 +42,6 @@ namespace seal
4242 scaled_root_powers_.release ();
4343 inv_root_powers_.release ();
4444 scaled_inv_root_powers_.release ();
45- inv_root_powers_div_two_.release ();
46- scaled_inv_root_powers_div_two_.release ();
4745 inv_degree_modulo_ = 0 ;
4846 coeff_count_power_ = 0 ;
4947 coeff_count_ = 0 ;
@@ -68,8 +66,6 @@ namespace seal
6866 inv_root_powers_ = allocate_uint (coeff_count_, pool_);
6967 scaled_root_powers_ = allocate_uint (coeff_count_, pool_);
7068 scaled_inv_root_powers_ = allocate_uint (coeff_count_, pool_);
71- inv_root_powers_div_two_ = allocate_uint (coeff_count_, pool_);
72- scaled_inv_root_powers_div_two_ = allocate_uint (coeff_count_, pool_);
7369 modulus_ = modulus;
7470
7571 // We defer parameter checking to try_minimal_primitive_root(...)
@@ -98,15 +94,21 @@ namespace seal
9894 ntt_scale_powers_of_primitive_root (inv_root_powers_.get (),
9995 scaled_inv_root_powers_.get ());
10096
101- // Populate the tables storing (scaled version of ) 2 times
102- // powers of roots^-1 mod q in bit-scrambled order.
103- for (size_t i = 0 ; i < coeff_count_; i++)
104- {
105- inv_root_powers_div_two_[i] =
106- div2_uint_mod (inv_root_powers_[i], modulus_);
97+ // Reordering inv_root_powers_ so that the access pattern at inverse NTT is sequential.
98+ std::vector<uint64_t > tmp (coeff_count_);
99+ uint64_t *ptr = tmp.data () + 1 ;
100+ for (size_t i = coeff_count_ / 2 ; i > 0 ; i /= 2 ) {
101+ for (size_t j = i; j < i * 2 ; ++j)
102+ *ptr++ = inv_root_powers_[j];
103+ }
104+ std::copy (tmp.cbegin (), tmp.cend (), inv_root_powers_.get ());
105+
106+ ptr = tmp.data () + 1 ;
107+ for (size_t i = coeff_count_ / 2 ; i > 0 ; i /= 2 ) {
108+ for (size_t j = i; j < i * 2 ; ++j)
109+ *ptr++ = scaled_inv_root_powers_[j];
107110 }
108- ntt_scale_powers_of_primitive_root (inv_root_powers_div_two_.get (),
109- scaled_inv_root_powers_div_two_.get ());
111+ std::copy (tmp.cbegin (), tmp.cend (), scaled_inv_root_powers_.get ());
110112
111113 // Last compute n^(-1) modulo q.
112114 uint64_t degree_uint = static_cast <uint64_t >(coeff_count_);
@@ -137,18 +139,72 @@ namespace seal
137139
138140 // compute floor ( input * beta /q ), where beta is a 64k power of 2
139141 // and 0 < q < beta.
142+ static inline uint64_t precompute_mulmod (uint64_t y, uint64_t p) {
143+ uint64_t wide_quotient[2 ]{ 0 , 0 };
144+ uint64_t wide_coeff[2 ]{ 0 , y };
145+ divide_uint128_uint64_inplace (wide_coeff, p, wide_quotient);
146+ return wide_quotient[0 ];
147+ }
148+
140149 void SmallNTTTables::ntt_scale_powers_of_primitive_root (
141150 const uint64_t *input, uint64_t *destination) const
142151 {
143152 for (size_t i = 0 ; i < coeff_count_; i++, input++, destination++)
144153 {
145- uint64_t wide_quotient[2 ]{ 0 , 0 };
146- uint64_t wide_coeff[2 ]{ 0 , *input };
147- divide_uint128_uint64_inplace (wide_coeff, modulus_.value (), wide_quotient);
148- *destination = wide_quotient[0 ];
154+ *destination = precompute_mulmod (*input, modulus_.value ());
149155 }
150156 }
151157
158+ struct ntt_body {
159+ const uint64_t modulus, two_times_modulus;
160+ ntt_body (uint64_t modulus) : modulus(modulus), two_times_modulus(modulus << 1 ) {}
161+
162+ // x0' <- x0 + w * x1
163+ // x1' <- x0 - w * x1
164+ inline void forward (uint64_t *x0, uint64_t *x1, uint64_t W, uint64_t Wprime) const {
165+ uint64_t u = *x0;
166+ uint64_t v = mulmod_lazy (*x1, W, Wprime);
167+
168+ u -= select (two_times_modulus, u < two_times_modulus);
169+ *x0 = u + v;
170+ *x1 = u - v + two_times_modulus;
171+ }
172+
173+ // x0' <- x0 + x1
174+ // x1' <- x0 - w * x1
175+ inline void backward (uint64_t *x0, uint64_t *x1, uint64_t W, uint64_t Wprime) const {
176+ uint64_t u = *x0;
177+ uint64_t v = *x1;
178+ uint64_t t = u + v;
179+ t -= select (two_times_modulus, t < two_times_modulus);
180+
181+ *x0 = t;
182+ *x1 = mulmod_lazy (u - v + two_times_modulus, W, Wprime);
183+ }
184+
185+ inline void backward_last (uint64_t *x0, uint64_t *x1, uint64_t inv_N, uint64_t inv_Nprime, uint64_t inv_N_W, uint64_t inv_N_Wprime) const {
186+ uint64_t u = *x0;
187+ uint64_t v = *x1;
188+ uint64_t t = u + v;
189+ t -= select (two_times_modulus, t < two_times_modulus);
190+
191+ *x0 = mulmod_lazy (t, inv_N, inv_Nprime);
192+ *x1 = mulmod_lazy (u - v + two_times_modulus, inv_N_W, inv_N_Wprime);
193+ }
194+
195+ // x * y mod p using Shoup's trick, i.e., yprime = floor(2^64 * y / p)
196+ inline uint64_t mulmod_lazy (uint64_t x, uint64_t y, uint64_t yprime) const {
197+ unsigned long long q;
198+ multiply_uint64_hw64 (x, yprime, &q);
199+ return x * y - q * modulus;
200+ }
201+
202+ // return 0 if cond = true, else return b if cond = false
203+ inline uint64_t select (uint64_t b, bool cond) const {
204+ return (b & -(uint64_t ) cond) ^ b;
205+ }
206+ };
207+
152208 /* *
153209 This function computes in-place the negacyclic NTT. The input is
154210 a polynomial a of degree n in R_q, where n is assumed to be a power of
@@ -162,10 +218,8 @@ namespace seal
162218 void ntt_negacyclic_harvey_lazy (uint64_t *operand,
163219 const SmallNTTTables &tables)
164220 {
165- uint64_t modulus = tables.modulus ().value ();
166- uint64_t two_times_modulus = modulus * 2 ;
221+ ntt_body ntt (tables.modulus ().value ());
167222
168- // Return the NTT in scrambled order
169223 size_t n = size_t (1 ) << tables.coeff_count_power ();
170224 size_t t = n >> 1 ;
171225 for (size_t m = 1 ; m < n; m <<= 1 )
@@ -181,33 +235,12 @@ namespace seal
181235
182236 uint64_t *X = operand + j1;
183237 uint64_t *Y = X + t;
184- uint64_t currX;
185- unsigned long long Q;
186238 for (size_t j = j1; j < j2; j += 4 )
187239 {
188- currX = *X - (two_times_modulus & static_cast <uint64_t >(-static_cast <int64_t >(*X >= two_times_modulus)));
189- multiply_uint64_hw64 (Wprime, *Y, &Q);
190- Q = *Y * W - Q * modulus;
191- *X++ = currX + Q;
192- *Y++ = currX + (two_times_modulus - Q);
193-
194- currX = *X - (two_times_modulus & static_cast <uint64_t >(-static_cast <int64_t >(*X >= two_times_modulus)));
195- multiply_uint64_hw64 (Wprime, *Y, &Q);
196- Q = *Y * W - Q * modulus;
197- *X++ = currX + Q;
198- *Y++ = currX + (two_times_modulus - Q);
199-
200- currX = *X - (two_times_modulus & static_cast <uint64_t >(-static_cast <int64_t >(*X >= two_times_modulus)));
201- multiply_uint64_hw64 (Wprime, *Y, &Q);
202- Q = *Y * W - Q * modulus;
203- *X++ = currX + Q;
204- *Y++ = currX + (two_times_modulus - Q);
205-
206- currX = *X - (two_times_modulus & static_cast <uint64_t >(-static_cast <int64_t >(*X >= two_times_modulus)));
207- multiply_uint64_hw64 (Wprime, *Y, &Q);
208- Q = *Y * W - Q * modulus;
209- *X++ = currX + Q;
210- *Y++ = currX + (two_times_modulus - Q);
240+ ntt.forward (X++, Y++, W, Wprime);
241+ ntt.forward (X++, Y++, W, Wprime);
242+ ntt.forward (X++, Y++, W, Wprime);
243+ ntt.forward (X++, Y++, W, Wprime);
211244 }
212245 }
213246 }
@@ -222,17 +255,9 @@ namespace seal
222255
223256 uint64_t *X = operand + j1;
224257 uint64_t *Y = X + t;
225- uint64_t currX;
226- unsigned long long Q;
227258 for (size_t j = j1; j < j2; j++)
228259 {
229- // The Harvey butterfly: assume X, Y in [0, 2p), and return X', Y' in [0, 4p).
230- // X', Y' = X + WY, X - WY (mod p).
231- currX = *X - (two_times_modulus & static_cast <uint64_t >(-static_cast <int64_t >(*X >= two_times_modulus)));
232- multiply_uint64_hw64 (Wprime, *Y, &Q);
233- Q = W * *Y - Q * modulus;
234- *X++ = currX + Q;
235- *Y++ = currX + (two_times_modulus - Q);
260+ ntt.forward (X++, Y++, W, Wprime);
236261 }
237262 }
238263 }
@@ -243,103 +268,71 @@ namespace seal
243268 // Inverse negacyclic NTT using Harvey's butterfly. (See Patrick Longa and Michael Naehrig).
244269 void inverse_ntt_negacyclic_harvey_lazy (uint64_t *operand, const SmallNTTTables &tables)
245270 {
246- uint64_t modulus = tables.modulus ().value ();
247- uint64_t two_times_modulus = modulus * 2 ;
271+ ntt_body ntt (tables.modulus ().value ());
248272
249- // return the bit-reversed order of NTT.
250- size_t n = size_t (1 ) << tables.coeff_count_power ();
273+ const size_t n = size_t (1 ) << tables.coeff_count_power ();
251274 size_t t = 1 ;
252-
253- for (size_t m = n; m > 1 ; m >>= 1 )
275+ size_t inv_root_index = 1 ;
276+ // m > 2 to skip the last layer
277+ for (size_t m = n; m > 2 ; m >>= 1 )
254278 {
255279 size_t j1 = 0 ;
256280 size_t h = m >> 1 ;
257281 if (t >= 4 )
258282 {
259- for (size_t i = 0 ; i < h; i++)
283+ for (size_t i = 0 ; i < h; i++, ++inv_root_index )
260284 {
261285 size_t j2 = j1 + t;
262286 // Need the powers of phi^{-1} in bit-reversed order
263- const uint64_t W = tables.get_from_inv_root_powers_div_two (h + i );
264- const uint64_t Wprime = tables.get_from_scaled_inv_root_powers_div_two (h + i );
287+ const uint64_t W = tables.get_from_inv_root_powers (inv_root_index );
288+ const uint64_t Wprime = tables.get_from_scaled_inv_root_powers (inv_root_index );
265289
266290 uint64_t *U = operand + j1;
267291 uint64_t *V = U + t;
268- uint64_t currU;
269- uint64_t T;
270- unsigned long long H;
271292 for (size_t j = j1; j < j2; j += 4 )
272293 {
273- T = two_times_modulus - *V + *U;
274- currU = *U + *V - (two_times_modulus & static_cast <uint64_t >(-static_cast <int64_t >((*U << 1 ) >= T)));
275- *U++ = (currU + (modulus & static_cast <uint64_t >(-static_cast <int64_t >(T & 1 )))) >> 1 ;
276- multiply_uint64_hw64 (Wprime, T, &H);
277- *V++ = T * W - H * modulus;
278-
279- T = two_times_modulus - *V + *U;
280- currU = *U + *V - (two_times_modulus & static_cast <uint64_t >(-static_cast <int64_t >((*U << 1 ) >= T)));
281- *U++ = (currU + (modulus & static_cast <uint64_t >(-static_cast <int64_t >(T & 1 )))) >> 1 ;
282- multiply_uint64_hw64 (Wprime, T, &H);
283- *V++ = T * W - H * modulus;
284-
285- T = two_times_modulus - *V + *U;
286- currU = *U + *V - (two_times_modulus & static_cast <uint64_t >(-static_cast <int64_t >((*U << 1 ) >= T)));
287- *U++ = (currU + (modulus & static_cast <uint64_t >(-static_cast <int64_t >(T & 1 )))) >> 1 ;
288- multiply_uint64_hw64 (Wprime, T, &H);
289- *V++ = T * W - H * modulus;
290-
291- T = two_times_modulus - *V + *U;
292- currU = *U + *V - (two_times_modulus & static_cast <uint64_t >(-static_cast <int64_t >((*U << 1 ) >= T)));
293- *U++ = (currU + (modulus & static_cast <uint64_t >(-static_cast <int64_t >(T & 1 )))) >> 1 ;
294- multiply_uint64_hw64 (Wprime, T, &H);
295- *V++ = T * W - H * modulus;
294+ ntt.backward (U++, V++, W, Wprime);
295+ ntt.backward (U++, V++, W, Wprime);
296+ ntt.backward (U++, V++, W, Wprime);
297+ ntt.backward (U++, V++, W, Wprime);
296298 }
297299 j1 += (t << 1 );
298300 }
299301 }
300302 else
301303 {
302- for (size_t i = 0 ; i < h; i++)
304+ for (size_t i = 0 ; i < h; i++, ++inv_root_index )
303305 {
304306 size_t j2 = j1 + t;
305307 // Need the powers of phi^{-1} in bit-reversed order
306- const uint64_t W = tables.get_from_inv_root_powers_div_two (h + i );
307- const uint64_t Wprime = tables.get_from_scaled_inv_root_powers_div_two (h + i );
308+ const uint64_t W = tables.get_from_inv_root_powers (inv_root_index );
309+ const uint64_t Wprime = tables.get_from_scaled_inv_root_powers (inv_root_index );
308310
309311 uint64_t *U = operand + j1;
310312 uint64_t *V = U + t;
311- uint64_t currU;
312- uint64_t T;
313- unsigned long long H;
314313 for (size_t j = j1; j < j2; j++)
315314 {
316- // U = x[i], V = x[i+m]
317-
318- // Compute U - V + 2q
319- T = two_times_modulus - *V + *U;
320-
321- // Cleverly check whether currU + currV >= two_times_modulus
322- currU = *U + *V - (two_times_modulus & static_cast <uint64_t >(-static_cast <int64_t >((*U << 1 ) >= T)));
323-
324- // Need to make it so that div2_uint_mod takes values that are > q.
325- // div2_uint_mod(U, modulusptr, coeff_uint64_count, U);
326- // We use also the fact that parity of currU is same as parity of T.
327- // Since our modulus is always so small that currU + masked_modulus < 2^64,
328- // we never need to worry about wrapping around when adding masked_modulus.
329- // uint64_t masked_modulus = modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1));
330- // uint64_t carry = add_uint64(currU, masked_modulus, 0, &currU);
331- // currU += modulus & static_cast<uint64_t>(-static_cast<int64_t>(T & 1));
332- *U++ = (currU + (modulus & static_cast <uint64_t >(-static_cast <int64_t >(T & 1 )))) >> 1 ;
333-
334- multiply_uint64_hw64 (Wprime, T, &H);
335- // effectively, the next two multiply perform multiply modulo beta = 2**wordsize.
336- *V++ = W * T - H * modulus;
315+ ntt.backward (U++, V++, W, Wprime);
337316 }
338317 j1 += (t << 1 );
339318 }
340319 }
341320 t <<= 1 ;
342321 }
322+
323+ // merge n^{-1} with the last layer of invNTT
324+ const uint64_t W = tables.get_from_inv_root_powers (inv_root_index);
325+ const uint64_t inv_N = *(tables.get_inv_degree_modulo ());
326+ const uint64_t inv_N_W = multiply_uint_uint_mod (inv_N, W, tables.modulus ());
327+ const uint64_t inv_Nprime = precompute_mulmod (inv_N, tables.modulus ().value ());
328+ const uint64_t inv_N_Wprime = precompute_mulmod (inv_N_W, tables.modulus ().value ());
329+
330+ uint64_t *U = operand;
331+ uint64_t *V = U + (n / 2 );
332+ for (size_t j = n / 2 ; j < n; j++)
333+ {
334+ ntt.backward_last (U++, V++, inv_N, inv_Nprime, inv_N_W, inv_N_Wprime);
335+ }
343336 }
344337 }
345338}
0 commit comments