11use crate :: {
2- LweDef , OverlaySize , PolynomialDegree , RadixDecomposition , TorusOps ,
2+ LweDef , OverlaySize , PolynomialDegree , RadixDecomposition , Torus , TorusOps ,
33 dst:: { FromMutSlice , FromSlice } ,
44 entities:: { LweCiphertext , LweCiphertextRef , LweKeyswitchKeyRef , PolynomialRef } ,
55 ops:: {
@@ -10,6 +10,59 @@ use crate::{
1010 scratch:: allocate_scratch_ref,
1111} ;
1212
13+ /// Run mean compensation before key switch, see <https://eprint.iacr.org/2025/809.pdf>
14+ ///
15+ /// Basically, we use the radix definition to compute how many bits are dropped, then
16+ /// we effectively round all A's to that number of bits, for example, if we have 64
17+ /// bits total and radix definition says dropping 52 LSBs, then 0xabcd1234567890ef
18+ /// will be rounded to 0xabd0000000000000, while 0x4321098765abcdef will be rounded
19+ /// to 0x4320000000000000, this creates an error so we accumulate this error for all
20+ /// A's and then for B we subtract half of that accumulated error (half because the
21+ /// mean of secret key is 0.5). These new A's and B will be written to the output
22+ ///
23+ /// Arguments:
24+ ///
25+ /// * output: the output ciphertext
26+ /// * input: the input ciphertext
27+ /// * params: the LWE definition
28+ /// * radix: the key switch radix decomposition definition
29+ pub fn mean_compensate_pre_keyswitch_lwe_to_lwe < S : TorusOps > (
30+ output : & mut LweCiphertextRef < S > ,
31+ input : & LweCiphertextRef < S > ,
32+ params : & LweDef ,
33+ radix : & RadixDecomposition ,
34+ ) {
35+ input. assert_is_valid ( params. dim ) ;
36+ output. assert_is_valid ( params. dim ) ;
37+
38+ let ( input_a, input_b) = input. a_b ( params) ;
39+ let ( output_a, output_b) = output. a_b_mut ( params) ;
40+
41+ let bits_to_drop = S :: BITS as usize - radix. count . 0 * radix. radix_log . 0 ;
42+ let rounder = <S as sunscreen_math:: One >:: one ( ) << ( bits_to_drop - 1 ) ;
43+
44+ let mut cum_err = <S as sunscreen_math:: Zero >:: zero ( ) ;
45+
46+ for ( i, o) in input_a. iter ( ) . zip ( output_a. iter_mut ( ) ) {
47+ // variable `rounder` defined above is a special number that has only one bit as `1` at the most
48+ // significant position in the dropped section, that position determines if rounding goes up
49+ // or down, by adding `1` at that position, an original value of `0` will not generate a carry
50+ // while an original value of `1` will generate a carry, this allows use to just "truncate" the
51+ // new value to achieve rounding
52+ * o = Torus :: from ( i. wrapping_add ( & rounder) >> bits_to_drop << bits_to_drop) ;
53+ cum_err = cum_err. wrapping_add ( & i. wrapping_sub ( o) ) ;
54+ }
55+
56+ // multiply `cum_err` by mean of secret key which is 1/2, so we implement right shift by 1
57+ // for this purpose, and note here `cum_err` must be interpreted as signed value so we mask
58+ // its MSB to add back later
59+ let cum_err_msb = cum_err & ( <S as sunscreen_math:: One >:: one ( ) << ( S :: BITS as usize - 1 ) ) ;
60+ cum_err = cum_err >> 1 ;
61+ cum_err |= cum_err_msb;
62+
63+ * output_b = Torus :: from ( input_b. wrapping_sub ( & cum_err) ) ;
64+ }
65+
1366/// Switches a ciphertext under the original key to a ciphertext under the new
1467/// key using a keyswitch key.
1568///
@@ -37,6 +90,24 @@ pub fn keyswitch_lwe_to_lwe<S>(
3790 ciphertext_under_original_key. assert_is_valid ( old_params. dim ) ;
3891 keyswitch_key. assert_is_valid ( ( old_params. dim , new_params. dim , radix. count ) ) ;
3992
93+ // we decide not to use mean compensation, but here is the code to make use of it in case useful
94+ if false {
95+ allocate_scratch_ref ! (
96+ fixed_ciphertext_under_original_key,
97+ LweCiphertextRef <S >,
98+ ( old_params. dim)
99+ ) ;
100+
101+ mean_compensate_pre_keyswitch_lwe_to_lwe (
102+ fixed_ciphertext_under_original_key,
103+ ciphertext_under_original_key,
104+ old_params,
105+ radix,
106+ ) ;
107+
108+ let ( _ciphertext_a, _ciphertext_b) = fixed_ciphertext_under_original_key. a_b ( old_params) ;
109+ }
110+
40111 let ( ciphertext_a, ciphertext_b) = ciphertext_under_original_key. a_b ( old_params) ;
41112
42113 let keyswitch_levs = keyswitch_key. rows ( new_params, radix) ;
0 commit comments