Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion hexl/eltwise/eltwise-reduce-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n,
}

#ifdef HEXL_HAS_AVX512IFMA
if (has_avx512ifma && modulus < (1ULL << 52)) {
// Modulus can be 52 bits only if input mod factors <= 4
// otherwise modulus should be 51 bits max to give correct results
if ((has_avx512ifma && modulus < (1ULL << 51)) ||
(modulus < (1ULL << 52) && input_mod_factor <= 4)) {
EltwiseReduceModAVX512<52>(result, operand, n, modulus, input_mod_factor,
output_mod_factor);
return;
Expand Down
1 change: 0 additions & 1 deletion hexl/util/avx512-util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,6 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q,

// alpha - beta == 52, so we only need high 52 bits
__m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, q_barr_64);

// Z = prod_lo - (p * q_hat)_lo
x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod);
} else {
Expand Down
142 changes: 137 additions & 5 deletions test/test-eltwise-reduce-mod-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ TEST(EltwiseReduceModMontInOut, avx512_64_mod_1) {
}

#ifdef HEXL_HAS_AVX512IFMA

TEST(EltwiseReduceMod, avx512_52_mod_1) {
if (!has_avx512dq) {
if (!has_avx512ifma) {
GTEST_SKIP();
}

Expand All @@ -82,8 +83,8 @@ TEST(EltwiseReduceMod, avx512_52_mod_1) {
CheckEqual(result, exp_out);
}

TEST(EltwiseReduceMod, avx512Big_mod_1) {
if (!has_avx512dq) {
TEST(EltwiseReduceMod, avx512_52_Big_mod_1) {
if (!has_avx512ifma) {
GTEST_SKIP();
}

Expand All @@ -101,6 +102,7 @@ TEST(EltwiseReduceMod, avx512Big_mod_1) {

EltwiseReduceModAVX512<52>(result.data(), op.data(), op.size(), modulus,
input_mod_factor, output_mod_factor);

CheckEqual(result, exp_out);
}

Expand Down Expand Up @@ -204,7 +206,7 @@ TEST(EltwiseReduceMod, AVX512Big_0_1) {
size_t num_trials = 100;
#endif
for (size_t trial = 0; trial < num_trials; ++trial) {
auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, modulus);
auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 1ULL << 63);
auto op2 = op1;

std::vector<uint64_t> result1(length, 0);
Expand Down Expand Up @@ -306,10 +308,138 @@ TEST(EltwiseReduceMod, AVX512Big_2_1) {
std::vector<uint64_t> result1(length, 0);
std::vector<uint64_t> result2(length, 0);

EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 2,
1);
EltwiseReduceModAVX512(result2.data(), op2.data(), op1.size(), modulus, 2,
1);

ASSERT_EQ(result1, result2);
ASSERT_EQ(result1, result2);
}
}
}

#ifdef HEXL_HAS_AVX512IFMA
// Checks AVX512 and native EltwiseReduceMod implementations match with randomly
// generated inputs
TEST(EltwiseReduceMod, AVX512_52_Big_0_1) {
if (!has_avx512ifma) {
GTEST_SKIP();
}

size_t length = 8;

for (size_t bits = 45; bits <= 51; ++bits) {
uint64_t modulus = GeneratePrimes(1, bits, true, length)[0];
#ifdef HEXL_DEBUG
size_t num_trials = 10;
#else
size_t num_trials = 1;
#endif
for (size_t trial = 0; trial < num_trials; ++trial) {
auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 1ULL << 63);
auto op2 = op1;

std::vector<uint64_t> result1(length, 0);
std::vector<uint64_t> result2(length, 0);

EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus,
modulus, 1);
EltwiseReduceModAVX512<52>(result2.data(), op2.data(), op1.size(),
modulus, modulus, 1);

ASSERT_EQ(result1, result2);
ASSERT_EQ(result1, result2);
}
}
}

TEST(EltwiseReduceMod, AVX512_52_Big_4_1) {
if (!has_avx512ifma) {
GTEST_SKIP();
}

size_t length = 8;

for (size_t bits = 45; bits <= 52; ++bits) {
uint64_t modulus = GeneratePrimes(1, bits, true, length)[0];
#ifdef HEXL_DEBUG
size_t num_trials = 10;
#else
size_t num_trials = 1;
#endif
for (size_t trial = 0; trial < num_trials; ++trial) {
auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 4 * modulus);
auto op2 = op1;
std::vector<uint64_t> result1(length, 0);
std::vector<uint64_t> result2(length, 0);

EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 4,
1);
EltwiseReduceModAVX512(result2.data(), op2.data(), op1.size(), modulus, 4,
EltwiseReduceModAVX512<52>(result2.data(), op2.data(), op1.size(),
modulus, 4, 1);

ASSERT_EQ(result1, result2);
ASSERT_EQ(result1, result2);
}
}
}

TEST(EltwiseReduceMod, AVX512_52_Big_4_2) {
if (!has_avx512ifma) {
GTEST_SKIP();
}

size_t length = 8;

for (size_t bits = 45; bits <= 52; ++bits) {
uint64_t modulus = GeneratePrimes(1, bits, true, length)[0];
#ifdef HEXL_DEBUG
size_t num_trials = 10;
#else
size_t num_trials = 1;
#endif
for (size_t trial = 0; trial < num_trials; ++trial) {
auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 4 * modulus);
auto op2 = op1;
std::vector<uint64_t> result1(length, 0);
std::vector<uint64_t> result2(length, 0);

EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 4,
2);
EltwiseReduceModAVX512<52>(result2.data(), op2.data(), op1.size(),
modulus, 4, 2);

ASSERT_EQ(result1, result2);
ASSERT_EQ(result1, result2);
}
}
}

TEST(EltwiseReduceMod, AVX512_52_Big_2_1) {
if (!has_avx512ifma) {
GTEST_SKIP();
}

size_t length = 8;

for (size_t bits = 45; bits <= 52; ++bits) {
uint64_t modulus = GeneratePrimes(1, bits, true, length)[0];
#ifdef HEXL_DEBUG
size_t num_trials = 10;
#else
size_t num_trials = 1;
#endif
for (size_t trial = 0; trial < num_trials; ++trial) {
auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 2 * modulus);
auto op2 = op1;
std::vector<uint64_t> result1(length, 0);
std::vector<uint64_t> result2(length, 0);

EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 2,
1);
EltwiseReduceModAVX512<52>(result2.data(), op2.data(), op1.size(),
modulus, 2, 1);

ASSERT_EQ(result1, result2);
ASSERT_EQ(result1, result2);
Expand All @@ -319,5 +449,7 @@ TEST(EltwiseReduceMod, AVX512Big_2_1) {

#endif

#endif

} // namespace hexl
} // namespace intel