Skip to content

Commit 00fc326

Browse files
authored
Implement mean compensation for key switch without using it (#214)
1 parent be14274 commit 00fc326

File tree

1 file changed

+72
-1
lines changed

1 file changed

+72
-1
lines changed

sunscreen_tfhe/src/ops/keyswitch/lwe_keyswitch.rs

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::{
2-
LweDef, OverlaySize, PolynomialDegree, RadixDecomposition, TorusOps,
2+
LweDef, OverlaySize, PolynomialDegree, RadixDecomposition, Torus, TorusOps,
33
dst::{FromMutSlice, FromSlice},
44
entities::{LweCiphertext, LweCiphertextRef, LweKeyswitchKeyRef, PolynomialRef},
55
ops::{
@@ -10,6 +10,59 @@ use crate::{
1010
scratch::allocate_scratch_ref,
1111
};
1212

13+
/// Run mean compensation before key switch, see <https://eprint.iacr.org/2025/809.pdf>
14+
///
15+
/// Basically, we use the radix definition to compute how many bits are dropped, then
16+
/// we effectively round all A's to that number of bits, for example, if we have 64
17+
/// bits total and radix definition says dropping 52 LSBs, then 0xabcd1234567890ef
18+
/// will be rounded to 0xabd0000000000000, while 0x4321098765abcdef will be rounded
19+
/// to 0x4320000000000000, this creates an error so we accumulate this error for all
20+
/// A's and then for B we subtract half of that accumulated error (half because the
21+
/// mean of secret key is 0.5). These new A's and B will be written to the output
22+
///
23+
/// Arguments:
24+
///
25+
/// * output: the output ciphertext
26+
/// * input: the input ciphertext
27+
/// * params: the LWE definition
28+
/// * radix: the key switch radix decomposition definition
29+
pub fn mean_compensate_pre_keyswitch_lwe_to_lwe<S: TorusOps>(
30+
output: &mut LweCiphertextRef<S>,
31+
input: &LweCiphertextRef<S>,
32+
params: &LweDef,
33+
radix: &RadixDecomposition,
34+
) {
35+
input.assert_is_valid(params.dim);
36+
output.assert_is_valid(params.dim);
37+
38+
let (input_a, input_b) = input.a_b(params);
39+
let (output_a, output_b) = output.a_b_mut(params);
40+
41+
let bits_to_drop = S::BITS as usize - radix.count.0 * radix.radix_log.0;
42+
let rounder = <S as sunscreen_math::One>::one() << (bits_to_drop - 1);
43+
44+
let mut cum_err = <S as sunscreen_math::Zero>::zero();
45+
46+
for (i, o) in input_a.iter().zip(output_a.iter_mut()) {
47+
// variable `rounder` defined above is a special number that has only one bit as `1` at the most
48+
// significant position in the dropped section, that position determines if rounding goes up
49+
// or down, by adding `1` at that position, an original value of `0` will not generate a carry
50+
// while an original value of `1` will generate a carry, this allows use to just "truncate" the
51+
// new value to achieve rounding
52+
*o = Torus::from(i.wrapping_add(&rounder) >> bits_to_drop << bits_to_drop);
53+
cum_err = cum_err.wrapping_add(&i.wrapping_sub(o));
54+
}
55+
56+
// multiply `cum_err` by mean of secret key which is 1/2, so we implement right shift by 1
57+
// for this purpose, and note here `cum_err` must be interpreted as signed value so we mask
58+
// its MSB to add back later
59+
let cum_err_msb = cum_err & (<S as sunscreen_math::One>::one() << (S::BITS as usize - 1));
60+
cum_err = cum_err >> 1;
61+
cum_err |= cum_err_msb;
62+
63+
*output_b = Torus::from(input_b.wrapping_sub(&cum_err));
64+
}
65+
1366
/// Switches a ciphertext under the original key to a ciphertext under the new
1467
/// key using a keyswitch key.
1568
///
@@ -37,6 +90,24 @@ pub fn keyswitch_lwe_to_lwe<S>(
3790
ciphertext_under_original_key.assert_is_valid(old_params.dim);
3891
keyswitch_key.assert_is_valid((old_params.dim, new_params.dim, radix.count));
3992

93+
// we decide not to use mean compensation, but here is the code to make use of it in case useful
94+
if false {
95+
allocate_scratch_ref!(
96+
fixed_ciphertext_under_original_key,
97+
LweCiphertextRef<S>,
98+
(old_params.dim)
99+
);
100+
101+
mean_compensate_pre_keyswitch_lwe_to_lwe(
102+
fixed_ciphertext_under_original_key,
103+
ciphertext_under_original_key,
104+
old_params,
105+
radix,
106+
);
107+
108+
let (_ciphertext_a, _ciphertext_b) = fixed_ciphertext_under_original_key.a_b(old_params);
109+
}
110+
40111
let (ciphertext_a, ciphertext_b) = ciphertext_under_original_key.a_b(old_params);
41112

42113
let keyswitch_levs = keyswitch_key.rows(new_params, radix);

0 commit comments

Comments
 (0)