diff --git a/hexl/eltwise/eltwise-reduce-mod.cpp b/hexl/eltwise/eltwise-reduce-mod.cpp index 48164cd9..b0a209dd 100644 --- a/hexl/eltwise/eltwise-reduce-mod.cpp +++ b/hexl/eltwise/eltwise-reduce-mod.cpp @@ -99,7 +99,10 @@ void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, } #ifdef HEXL_HAS_AVX512IFMA - if (has_avx512ifma && modulus < (1ULL << 52)) { + // Modulus can be 52 bits only if input_mod_factor <= 4; + // otherwise modulus should be 51 bits max to give correct results + if (has_avx512ifma && (modulus < (1ULL << 51) || + (modulus < (1ULL << 52) && input_mod_factor <= 4))) { EltwiseReduceModAVX512<52>(result, operand, n, modulus, input_mod_factor, output_mod_factor); return; diff --git a/hexl/util/avx512-util.hpp b/hexl/util/avx512-util.hpp index 29956425..df62443c 100644 --- a/hexl/util/avx512-util.hpp +++ b/hexl/util/avx512-util.hpp @@ -472,7 +472,6 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, // alpha - beta == 52, so we only need high 52 bits __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, q_barr_64); - // Z = prod_lo - (p * q_hat)_lo x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod); } else { diff --git a/test/test-eltwise-reduce-mod-avx512.cpp b/test/test-eltwise-reduce-mod-avx512.cpp index 2b4db3d4..6fc502cf 100644 --- a/test/test-eltwise-reduce-mod-avx512.cpp +++ b/test/test-eltwise-reduce-mod-avx512.cpp @@ -65,8 +65,9 @@ TEST(EltwiseReduceModMontInOut, avx512_64_mod_1) { } #ifdef HEXL_HAS_AVX512IFMA + TEST(EltwiseReduceMod, avx512_52_mod_1) { - if (!has_avx512dq) { + if (!has_avx512ifma) { GTEST_SKIP(); } @@ -82,8 +83,8 @@ TEST(EltwiseReduceMod, avx512_52_mod_1) { CheckEqual(result, exp_out); } -TEST(EltwiseReduceMod, avx512Big_mod_1) { - if (!has_avx512dq) { +TEST(EltwiseReduceMod, avx512_52_Big_mod_1) { + if (!has_avx512ifma) { GTEST_SKIP(); } @@ -101,6 +102,7 @@ TEST(EltwiseReduceMod, avx512Big_mod_1) { EltwiseReduceModAVX512<52>(result.data(), op.data(), op.size(), modulus, 
input_mod_factor, output_mod_factor); + CheckEqual(result, exp_out); } @@ -204,7 +206,7 @@ TEST(EltwiseReduceMod, AVX512Big_0_1) { size_t num_trials = 100; #endif for (size_t trial = 0; trial < num_trials; ++trial) { - auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, modulus); + auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 1ULL << 63); auto op2 = op1; std::vector result1(length, 0); @@ -306,10 +308,138 @@ TEST(EltwiseReduceMod, AVX512Big_2_1) { std::vector result1(length, 0); std::vector result2(length, 0); + EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 2, + 1); + EltwiseReduceModAVX512(result2.data(), op2.data(), op1.size(), modulus, 2, + 1); + + ASSERT_EQ(result1, result2); + ASSERT_EQ(result1, result2); + } + } +} + +#ifdef HEXL_HAS_AVX512IFMA +// Checks AVX512 and native EltwiseReduceMod implementations match with randomly +// generated inputs +TEST(EltwiseReduceMod, AVX512_52_Big_0_1) { + if (!has_avx512ifma) { + GTEST_SKIP(); + } + + size_t length = 8; + + for (size_t bits = 45; bits <= 51; ++bits) { + uint64_t modulus = GeneratePrimes(1, bits, true, length)[0]; +#ifdef HEXL_DEBUG + size_t num_trials = 10; +#else + size_t num_trials = 1; +#endif + for (size_t trial = 0; trial < num_trials; ++trial) { + auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 1ULL << 63); + auto op2 = op1; + + std::vector result1(length, 0); + std::vector result2(length, 0); + + EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, + modulus, 1); + EltwiseReduceModAVX512<52>(result2.data(), op2.data(), op1.size(), + modulus, modulus, 1); + + ASSERT_EQ(result1, result2); + ASSERT_EQ(result1, result2); + } + } +} + +TEST(EltwiseReduceMod, AVX512_52_Big_4_1) { + if (!has_avx512ifma) { + GTEST_SKIP(); + } + + size_t length = 8; + + for (size_t bits = 45; bits <= 52; ++bits) { + uint64_t modulus = GeneratePrimes(1, bits, true, length)[0]; +#ifdef HEXL_DEBUG + size_t num_trials = 10; +#else + size_t 
num_trials = 1; +#endif + for (size_t trial = 0; trial < num_trials; ++trial) { + auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 4 * modulus); + auto op2 = op1; + std::vector result1(length, 0); + std::vector result2(length, 0); + EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 4, 1); - EltwiseReduceModAVX512(result2.data(), op2.data(), op1.size(), modulus, 4, + EltwiseReduceModAVX512<52>(result2.data(), op2.data(), op1.size(), + modulus, 4, 1); + + ASSERT_EQ(result1, result2); + ASSERT_EQ(result1, result2); + } + } +} + +TEST(EltwiseReduceMod, AVX512_52_Big_4_2) { + if (!has_avx512ifma) { + GTEST_SKIP(); + } + + size_t length = 8; + + for (size_t bits = 45; bits <= 52; ++bits) { + uint64_t modulus = GeneratePrimes(1, bits, true, length)[0]; +#ifdef HEXL_DEBUG + size_t num_trials = 10; +#else + size_t num_trials = 1; +#endif + for (size_t trial = 0; trial < num_trials; ++trial) { + auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 4 * modulus); + auto op2 = op1; + std::vector result1(length, 0); + std::vector result2(length, 0); + + EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 4, + 2); + EltwiseReduceModAVX512<52>(result2.data(), op2.data(), op1.size(), + modulus, 4, 2); + + ASSERT_EQ(result1, result2); + ASSERT_EQ(result1, result2); + } + } +} + +TEST(EltwiseReduceMod, AVX512_52_Big_2_1) { + if (!has_avx512ifma) { + GTEST_SKIP(); + } + + size_t length = 8; + + for (size_t bits = 45; bits <= 52; ++bits) { + uint64_t modulus = GeneratePrimes(1, bits, true, length)[0]; +#ifdef HEXL_DEBUG + size_t num_trials = 10; +#else + size_t num_trials = 1; +#endif + for (size_t trial = 0; trial < num_trials; ++trial) { + auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 2 * modulus); + auto op2 = op1; + std::vector result1(length, 0); + std::vector result2(length, 0); + + EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 2, 1); + EltwiseReduceModAVX512<52>(result2.data(), 
op2.data(), op1.size(), + modulus, 2, 1); ASSERT_EQ(result1, result2); ASSERT_EQ(result1, result2); @@ -319,5 +449,7 @@ TEST(EltwiseReduceMod, AVX512Big_2_1) { #endif +#endif + } // namespace hexl } // namespace intel