diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..052c1de65 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,92 @@ +{ + "files.associations": { + "stdexcept": "cpp", + "__bit_reference": "cpp", + "__config": "cpp", + "__debug": "cpp", + "__errc": "cpp", + "__functional_base": "cpp", + "__hash_table": "cpp", + "__locale": "cpp", + "__mutex_base": "cpp", + "__node_handle": "cpp", + "__nullptr": "cpp", + "__split_buffer": "cpp", + "__string": "cpp", + "__threading_support": "cpp", + "__tree": "cpp", + "__tuple": "cpp", + "algorithm": "cpp", + "array": "cpp", + "atomic": "cpp", + "bit": "cpp", + "bitset": "cpp", + "cctype": "cpp", + "chrono": "cpp", + "cinttypes": "cpp", + "cmath": "cpp", + "complex": "cpp", + "condition_variable": "cpp", + "cstdarg": "cpp", + "cstddef": "cpp", + "cstdint": "cpp", + "cstdio": "cpp", + "cstdlib": "cpp", + "cstring": "cpp", + "ctime": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "deque": "cpp", + "exception": "cpp", + "forward_list": "cpp", + "fstream": "cpp", + "functional": "cpp", + "initializer_list": "cpp", + "iomanip": "cpp", + "ios": "cpp", + "iosfwd": "cpp", + "iostream": "cpp", + "istream": "cpp", + "iterator": "cpp", + "limits": "cpp", + "list": "cpp", + "locale": "cpp", + "map": "cpp", + "memory": "cpp", + "mutex": "cpp", + "new": "cpp", + "numeric": "cpp", + "optional": "cpp", + "ostream": "cpp", + "queue": "cpp", + "random": "cpp", + "ratio": "cpp", + "regex": "cpp", + "set": "cpp", + "shared_mutex": "cpp", + "sstream": "cpp", + "stack": "cpp", + "streambuf": "cpp", + "string": "cpp", + "string_view": "cpp", + "system_error": "cpp", + "thread": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "typeinfo": "cpp", + "unordered_map": "cpp", + "unordered_set": "cpp", + "utility": "cpp", + "vector": "cpp", + "pointers": "cpp", + "__functional_03": "cpp", + "multi_span": "cpp", + "*.tcc": "cpp", + "clocale": "cpp", + "compare": "cpp", + "concepts": "cpp", + "memory_resource": "cpp", + "ranges": "cpp", + "stop_token": "cpp" + } +} \ No newline at end of file diff --git a/dotnet/src/BatchEncoder.cs b/dotnet/src/BatchEncoder.cs index 34927e738..5fbbdfdbf 100644 --- a/dotnet/src/BatchEncoder.cs +++ b/dotnet/src/BatchEncoder.cs @@ -59,7 +59,7 @@ public class BatchEncoder : NativeObject /// @param[in] context /// if context is null. /// if the encryption parameters are not valid for batching - /// if scheme is not SchemeType.BFV + /// if scheme is not SchemeType.BFV or SchemeType.BGV public BatchEncoder(SEALContext context) { if (null == context) @@ -68,7 +68,7 @@ public BatchEncoder(SEALContext context) throw new ArgumentException("Encryption parameters are not set correctly"); SEALContext.ContextData contextData = context.FirstContextData; - if (contextData.Parms.Scheme != SchemeType.BFV) + if (contextData.Parms.Scheme != SchemeType.BFV && contextData.Parms.Scheme != SchemeType.BGV) throw new ArgumentException("Unsupported scheme"); if (!contextData.Qualifiers.UsingBatching) throw new ArgumentException("Encryption parameters are not valid for batching"); diff --git a/dotnet/src/Decryptor.cs b/dotnet/src/Decryptor.cs index 1e28f1b32..7ab3fa5f5 100644 --- a/dotnet/src/Decryptor.cs +++ b/dotnet/src/Decryptor.cs @@ -103,7 +103,7 @@ public void Decrypt(Ciphertext encrypted, Plaintext destination) /// /// The ciphertext /// if encrypted is null - /// if the scheme is not BFV + /// if the scheme is not BFV or BGV /// if encrypted is not valid for the encryption parameters /// if encrypted is in NTT form /// if pool is uninitialized diff --git a/dotnet/src/EncryptionParameters.cs b/dotnet/src/EncryptionParameters.cs index 1e55bd940..8f299dd53 100644 --- a/dotnet/src/EncryptionParameters.cs +++ b/dotnet/src/EncryptionParameters.cs @@ -28,7 +28,14 @@ public enum SchemeType : byte /// /// Cheon-Kim-Kim-Song scheme /// - CKKS = 0x2 + CKKS = 0x2, + + /// + /// Brakerski-Gentry-Vaikuntanathan + /// + BGV = 0x3 + + } /// @@ -223,7 +230,7 @@ public IEnumerable CoeffModulus /// of a particular form. /// /// if the value being set is null - /// if scheme is not SchemeType.BFV + /// if scheme is not SchemeType.BFV or SchemeType.BGV public Modulus PlainModulus { get diff --git a/dotnet/src/Encryptor.cs b/dotnet/src/Encryptor.cs index 12138926e..7038940c1 100644 --- a/dotnet/src/Encryptor.cs +++ b/dotnet/src/Encryptor.cs @@ -129,7 +129,7 @@ public void SetSecretKey(SecretKey secretKey) /// /// /// The encryption parameters for the resulting ciphertext correspond to: - /// 1) in BFV, the highest (data) level in the modulus switching chain, + /// 1) in BFV or BGV, the highest (data) level in the modulus switching chain, /// 2) in CKKS, the encryption parameters of the plaintext. /// Dynamic memory allocations in the process are allocated from the memory /// pool pointed to by the given MemoryPoolHandle. @@ -173,7 +173,7 @@ public void Encrypt( /// /// /// The encryption parameters for the resulting ciphertext correspond to: - /// 1) in BFV, the highest (data) level in the modulus switching chain, + /// 1) in BFV or BGV, the highest (data) level in the modulus switching chain, /// 2) in CKKS, the encryption parameters of the plaintext. /// Dynamic memory allocations in the process are allocated from the memory /// pool pointed to by the given MemoryPoolHandle. @@ -337,7 +337,7 @@ public Serializable EncryptZero(MemoryPoolHandle pool = null) /// /// /// The encryption parameters for the resulting ciphertext correspond to: - /// 1) in BFV, the highest (data) level in the modulus switching chain, + /// 1) in BFV or BGV, the highest (data) level in the modulus switching chain, /// 2) in CKKS, the encryption parameters of the plaintext. /// Dynamic memory allocations in the process are allocated from the memory /// pool pointed to by the given MemoryPoolHandle. @@ -385,7 +385,7 @@ public void EncryptSymmetric( /// /// /// The encryption parameters for the resulting ciphertext correspond to: - /// 1) in BFV, the highest (data) level in the modulus switching chain, + /// 1) in BFV or BGV, the highest (data) level in the modulus switching chain, /// 2) in CKKS, the encryption parameters of the plaintext. /// Dynamic memory allocations in the process are allocated from the memory /// pool pointed to by the given MemoryPoolHandle. diff --git a/dotnet/src/Evaluator.cs b/dotnet/src/Evaluator.cs index 67d76a302..6efd561b9 100644 --- a/dotnet/src/Evaluator.cs +++ b/dotnet/src/Evaluator.cs @@ -58,7 +58,7 @@ namespace Microsoft.Research.SEAL /// /// /// NTT form - /// When using the BFV scheme (SchemeType.BFV), all plaintexts and ciphertexts should remain by default in the usual + /// When using the BFV or BGV scheme (SchemeType.BFV or SchemeType.BGV), all plaintexts and ciphertexts should remain by default in the usual /// coefficient representation, i.e., not in NTT form. When using the CKKS scheme (SchemeType.CKKS), all plaintexts /// and ciphertexts should remain by default in NTT form. We call these scheme-specific NTT states the "default NTT /// form". Some functions, such as add, work even if the inputs are not in the default state, but others, such as @@ -740,7 +740,7 @@ public void RescaleTo(Ciphertext encrypted, ParmsId parmsId, Ciphertext destinat /// The ciphertext to overwrite with the multiplication result /// The MemoryPoolHandle pointing to a valid memory pool /// if encrypteds, relinKeys, or destination is null - /// if scheme is not SchemeType.BFV + /// if scheme is not SchemeType.BFV or SchemeType.BGV /// if encrypteds is empty /// if encrypteds or relinKeys are not valid for the encryption /// parameters @@ -784,7 +784,7 @@ public void MultiplyMany(IEnumerable encrypteds, RelinKeys relinKeys /// The relinearization keys /// The MemoryPoolHandle pointing to a valid memory pool /// if encrypted or relinKeys is null - /// if scheme is not SchemeType.BFV + /// if scheme is not SchemeType.BFV or SchemeType.BGV /// if encrypted or relinKeys is not valid for the encryption /// parameters /// if encrypted is not in the default NTT form @@ -816,7 +816,7 @@ public void ExponentiateInplace(Ciphertext encrypted, ulong exponent, /// The relinearization keys /// The ciphertext to overwrite with the power /// The MemoryPoolHandle pointing to a valid memory pool - /// if scheme is not SchemeType.BFV + /// if scheme is not SchemeType.BFV or SchemeType.BGV /// if encrypted or relinKeys is not valid for the encryption /// parameters /// if encrypted is not in the default NTT form @@ -1148,7 +1148,7 @@ public void TransformFromNTT(Ciphertext encryptedNTT, Ciphertext destination) /// The desired Galois automorphism is given as a Galois element, and must be an odd integer in the interval /// [1, M-1], where M = 2*N, and N = PolyModulusDegree. Used with batching, a Galois element 3^i % M corresponds /// to a cyclic row rotation i steps to the left, and a Galois element 3^(N/2-i) % M corresponds to a cyclic row - /// rotation i steps to the right. The Galois element M-1 corresponds to a column rotation (row swap) in BFV, + /// rotation i steps to the right. The Galois element M-1 corresponds to a column rotation (row swap) in BFV or BGV, /// and complex conjugation in CKKS. In the polynomial view (not batching), a Galois automorphism by a Galois /// element p changes Enc(plain(x)) to Enc(plain(x^p)). /// @@ -1188,7 +1188,7 @@ public void ApplyGaloisInplace(Ciphertext encrypted, uint galoisElt, /// The desired Galois automorphism is given as a Galois element, and must be an odd integer in the interval /// [1, M-1], where M = 2*N, and N = PolyModulusDegree. Used with batching, a Galois element 3^i % M corresponds /// to a cyclic row rotation i steps to the left, and a Galois element 3^(N/2-i) % M corresponds to a cyclic row - /// rotation i steps to the right. The Galois element M-1 corresponds to a column rotation (row swap) in BFV, + /// rotation i steps to the right. The Galois element M-1 corresponds to a column rotation (row swap) in BFV or BGV, /// and complex conjugation in CKKS. In the polynomial view (not batching), a Galois automorphism by a Galois /// element p changes Enc(plain(x)) to Enc(plain(x^p)). /// @@ -1232,7 +1232,7 @@ public void ApplyGalois(Ciphertext encrypted, uint galoisElt, GaloisKeys galoisK /// Rotates plaintext matrix rows cyclically. /// /// - /// When batching is used with the BFV scheme, this function rotates the encrypted plaintext matrix rows + /// When batching is used with the BFV or BGV scheme, this function rotates the encrypted plaintext matrix rows /// cyclically to the left (steps > 0) or to the right (steps < 0). Since the size of the batched matrix /// is 2-by-(N/2), where N is the degree of the polynomial modulus, the number of steps to rotate must have /// absolute value at most N/2-1. Dynamic memory allocations in the process are allocated from the memory pool @@ -1244,7 +1244,7 @@ public void ApplyGalois(Ciphertext encrypted, uint galoisElt, GaloisKeys galoisK /// The MemoryPoolHandle pointing to a valid memory pool /// if encrypted or galoisKeys is null /// if the encryption parameters do not support batching - /// if scheme is not SchemeType.BFV + /// if scheme is not SchemeType.BFV or SchemeType.BGV /// if encrypted or galoisKeys is not valid for the encryption /// parameters /// if galoisKeys do not correspond to the top level parameters in the @@ -1266,7 +1266,7 @@ public void RotateRowsInplace(Ciphertext encrypted, /// Rotates plaintext matrix rows cyclically. /// /// - /// When batching is used with the BFV scheme, this function rotates the encrypted plaintext matrix rows + /// When batching is used with the BFV or BGV scheme, this function rotates the encrypted plaintext matrix rows /// cyclically to the left (steps > 0) or to the right (steps < 0) and writes the result to the /// destination parameter. Since the size of the batched matrix is 2-by-(N/2), where N is the degree of the /// polynomial modulus, the number of steps to rotate must have absolute value at most N/2-1. Dynamic memory @@ -1279,7 +1279,7 @@ public void RotateRowsInplace(Ciphertext encrypted, /// The MemoryPoolHandle pointing to a valid memory pool /// if encrypted, galoisKeys, or destination is null /// if the encryption parameters do not support batching - /// if scheme is not SchemeType.BFV + /// if scheme is not SchemeType.BFV or SchemeType.BGV /// if encrypted or galoisKeys is not valid for the encryption /// parameters /// if galoisKeys do not correspond to the top level parameters in the @@ -1313,7 +1313,7 @@ public void RotateRows(Ciphertext encrypted, int steps, GaloisKeys galoisKeys, /// Rotates plaintext matrix columns cyclically. /// /// - /// When batching is used with the BFV scheme, this function rotates the encrypted plaintext matrix columns + /// When batching is used with the BFV or BGV scheme, this function rotates the encrypted plaintext matrix columns /// cyclically. Since the size of the batched matrix is 2-by-(N/2), where N is the degree of the polynomial /// modulus, this means simply swapping the two rows. Dynamic memory allocations in the process are allocated /// from the memory pool pointed to by the given MemoryPoolHandle. @@ -1323,7 +1323,7 @@ public void RotateRows(Ciphertext encrypted, int steps, GaloisKeys galoisKeys, /// The MemoryPoolHandle pointing to a valid memory pool /// if encrypted or galoisKeys is null /// if the encryption parameters do not support batching - /// if scheme is not SchemeType.BFV + /// if scheme is not SchemeType.BFV or SchemeType.BGV /// if encrypted or galoisKeys is not valid for the encryption /// parameters /// if galoisKeys do not correspond to the top level parameters in the @@ -1343,7 +1343,7 @@ public void RotateColumnsInplace(Ciphertext encrypted, GaloisKeys galoisKeys, Me /// Rotates plaintext matrix columns cyclically. /// /// - /// When batching is used with the BFV scheme, this function rotates the encrypted plaintext matrix columns + /// When batching is used with the BFV or BGV scheme, this function rotates the encrypted plaintext matrix columns /// cyclically, and writes the result to the destination parameter. Since the size of the batched matrix is /// 2-by-(N/2), where N is the degree of the polynomial modulus, this means simply swapping the two rows. /// Dynamic memory allocations in the process are allocated from the memory pool pointed to by the given @@ -1355,7 +1355,7 @@ public void RotateColumnsInplace(Ciphertext encrypted, GaloisKeys galoisKeys, Me /// The MemoryPoolHandle pointing to a valid memory pool /// if encrypted, galoisKeys, or destination is null /// if the encryption parameters do not support batching - /// if scheme is not SchemeType.BFV + /// if scheme is not SchemeType.BFV or SchemeType.BGV /// if encrypted or galoisKeys is not valid for the encryption /// parameters /// if galoisKeys do not correspond to the top level parameters in the diff --git a/dotnet/src/KeyGenerator.cs b/dotnet/src/KeyGenerator.cs index d56a7a95b..d9770a1ef 100644 --- a/dotnet/src/KeyGenerator.cs +++ b/dotnet/src/KeyGenerator.cs @@ -200,7 +200,7 @@ public Serializable CreateRelinKeys() /// The Galois keys to overwrite with the generated /// Galois keys /// if the encryption parameters - /// do not support batching and scheme is SchemeType.BFV + /// do not support batching and scheme is SchemeType.BFV or SchemeType.BGV /// if the encryption /// parameters do not support keyswitching /// if the Galois elements are not valid @@ -250,7 +250,7 @@ public void CreateGaloisKeys(IEnumerable galoisElts, out GaloisKeys destin /// The Galois elements for which to generate keys /// if galoisElts is null /// if the encryption parameters - /// do not support batching and scheme is SchemeType.BFV + /// do not support batching and scheme is SchemeType.BFV or SchemeType.BGV /// if the encryption /// parameters do not support keyswitching /// if the Galois elements are not valid @@ -279,7 +279,7 @@ public Serializable CreateGaloisKeys(IEnumerable galoisElts) /// The user needs to give as input a vector of desired Galois rotation step /// counts, where negative step counts correspond to rotations to the right /// and positive step counts correspond to rotations to the left. A step - /// count of zero can be used to indicate a column rotation in the BFV scheme + /// count of zero can be used to indicate a column rotation in the BFV/BGV scheme /// and complex conjugation in the CKKS scheme. /// /// @@ -288,7 +288,7 @@ public Serializable CreateGaloisKeys(IEnumerable galoisElts) /// Galois keys /// if steps is null /// if the encryption parameters - /// do not support batching and scheme is SchemeType.BFV + /// do not support batching and scheme is SchemeType.BFV or SchemeType.BGV /// if the encryption /// parameters do not support keyswitching /// if the step counts are not valid @@ -323,14 +323,14 @@ public void CreateGaloisKeys(IEnumerable steps, out GaloisKeys destination) /// The user needs to give as input a vector of desired Galois rotation step /// counts, where negative step counts correspond to rotations to the right /// and positive step counts correspond to rotations to the left. A step - /// count of zero can be used to indicate a column rotation in the BFV scheme + /// count of zero can be used to indicate a column rotation in the BFV or BGV scheme /// and complex conjugation in the CKKS scheme. /// /// /// The rotation step counts for which to generate keys /// if steps is null /// if the encryption parameters - /// do not support batching and scheme is SchemeType.BFV + /// do not support batching and scheme is SchemeType.BFV or SchemeType.BGV /// if the encryption /// parameters do not support keyswitching /// if the step counts are not valid @@ -365,7 +365,7 @@ public Serializable CreateGaloisKeys(IEnumerable steps) /// The Galois keys to overwrite with the generated /// Galois keys /// if the encryption parameters - /// do not support batching and scheme is SchemeType.BFV + /// do not support batching and scheme is SchemeType.BFV or SchemeType.BGV /// if the encryption /// parameters do not support keyswitching public void CreateGaloisKeys(out GaloisKeys destination) @@ -399,7 +399,7 @@ public void CreateGaloisKeys(out GaloisKeys destination) /// /// /// if the encryption parameters - /// do not support batching and scheme is SchemeType.BFV + /// do not support batching and scheme is SchemeType.BFV or SchemeType.BGV /// if the encryption /// parameters do not support keyswitching public Serializable CreateGaloisKeys() diff --git a/dotnet/src/Modulus.cs b/dotnet/src/Modulus.cs index b8310a289..b78eca324 100644 --- a/dotnet/src/Modulus.cs +++ b/dotnet/src/Modulus.cs @@ -469,6 +469,52 @@ static public IEnumerable BFVDefault( return result; } + /// + /// Returns a default coefficient modulus for the BGV scheme that guarantees + /// a given security level when using a given PolyModulusDegree, according + /// to the HomomorphicEncryption.org security standard. + /// + /// + /// + /// Returns a default coefficient modulus for the BGV scheme that guarantees + /// a given security level when using a given PolyModulusDegree, according + /// to the HomomorphicEncryption.org security standard. Note that all security + /// guarantees are lost if the output is used with encryption parameters with + /// a mismatching value for the PolyModulusDegree. Currently, we just use the + /// parameters of BFV. + /// + /// + /// The coefficient modulus returned by this function will not perform well + /// if used with the CKKS scheme. + /// + /// + /// The value of the PolyModulusDegree + /// encryption parameter + /// The desired standard security level + /// if polyModulusDegree is not + /// a power-of-two or is too large + /// if secLevel is SecLevelType.None + static public IEnumerable BGVDefault( + ulong polyModulusDegree, SecLevelType secLevel = SecLevelType.TC128) + { + List result = null; + + ulong length = 0; + NativeMethods.CoeffModulus_BFVDefault(polyModulusDegree, (int)secLevel, ref length, null); + + IntPtr[] coeffArray = new IntPtr[length]; + NativeMethods.CoeffModulus_BFVDefault(polyModulusDegree, (int)secLevel, ref length, coeffArray); + + result = new List(checked((int)length)); + foreach (IntPtr sm in coeffArray) + { + result.Add(new Modulus(sm)); + } + + return result; + } + + /// /// Returns a custom coefficient modulus suitable for use with the specified /// PolyModulusDegree. diff --git a/dotnet/src/NativeMethods.cs b/dotnet/src/NativeMethods.cs index 3502705a1..a5466781a 100644 --- a/dotnet/src/NativeMethods.cs +++ b/dotnet/src/NativeMethods.cs @@ -196,7 +196,7 @@ internal static extern void CoeffModulus_BFVDefault( int secLevel, ref ulong length, [MarshalAs(UnmanagedType.LPArray)] IntPtr[] coeffArray); - + [DllImport(sealc, PreserveSig = false)] internal static extern void CoeffModulus_Create( ulong polyModulusDegree, diff --git a/dotnet/src/Plaintext.cs b/dotnet/src/Plaintext.cs index c89ac7921..7c5a94309 100644 --- a/dotnet/src/Plaintext.cs +++ b/dotnet/src/Plaintext.cs @@ -32,7 +32,7 @@ namespace Microsoft.Research.SEAL /// providing the desired capacity to the constructor as an extra argument, or /// by calling the reserve function at any time. /// - /// When the scheme is SchemeType.BFV each coefficient of a plaintext is + /// When the scheme is SchemeType.BFV or SchemeType.BGV, each coefficient of a plaintext is /// a 64-bit word, but when the scheme is SchemeType.CKKS the plaintext is /// by default stored in an NTT transformed form with respect to each of the /// primes in the coefficient modulus. Thus, the size of the allocation that diff --git a/dotnet/tests/CiphertextTests.cs b/dotnet/tests/CiphertextTests.cs index 7d7d07ce9..de143d0ea 100644 --- a/dotnet/tests/CiphertextTests.cs +++ b/dotnet/tests/CiphertextTests.cs @@ -62,7 +62,39 @@ public void Create3Test() } [TestMethod] - public void ResizeTest() + public void Create4Test() + { + SEALContext context = GlobalContext.BGVContext; + ParmsId parms = context.FirstParmsId; + + Assert.AreNotEqual(0ul, parms.Block[0]); + Assert.AreNotEqual(0ul, parms.Block[1]); + Assert.AreNotEqual(0ul, parms.Block[2]); + Assert.AreNotEqual(0ul, parms.Block[3]); + + Ciphertext cipher = new Ciphertext(context, parms); + + Assert.AreEqual(parms, cipher.ParmsId); + } + + [TestMethod] + public void Create5Test() + { + SEALContext context = GlobalContext.BGVContext; + ParmsId parms = context.FirstParmsId; + + Assert.AreNotEqual(0ul, parms.Block[0]); + Assert.AreNotEqual(0ul, parms.Block[1]); + Assert.AreNotEqual(0ul, parms.Block[2]); + Assert.AreNotEqual(0ul, parms.Block[3]); + + Ciphertext cipher = new Ciphertext(context, parms, sizeCapacity: 5); + + Assert.AreEqual(5ul, cipher.SizeCapacity); + } + + [TestMethod] + public void BFVResizeTest() { SEALContext context = GlobalContext.BFVContext; ParmsId parms = context.FirstParmsId; @@ -97,6 +129,42 @@ public void ResizeTest() Assert.AreEqual(6ul, cipher5.SizeCapacity); } + [TestMethod] + public void BGVResizeTest() + { + SEALContext context = GlobalContext.BGVContext; + ParmsId parms = context.FirstParmsId; + + Ciphertext cipher = new Ciphertext(context, parms); + + Assert.AreEqual(2ul, cipher.SizeCapacity); + + cipher.Reserve(context, parms, sizeCapacity: 10); + Assert.AreEqual(10ul, cipher.SizeCapacity); + + Ciphertext cipher2 = new Ciphertext(); + + Assert.AreEqual(0ul, cipher2.SizeCapacity); + + cipher2.Reserve(context, 5); + Assert.AreEqual(5ul, cipher2.SizeCapacity); + + Ciphertext cipher3 = new Ciphertext(); + + Assert.AreEqual(0ul, cipher3.SizeCapacity); + + cipher3.Reserve(4); + Assert.AreEqual(0ul, cipher3.SizeCapacity); + + Ciphertext cipher4 = new Ciphertext(context); + cipher4.Resize(context, context.GetContextData(context.FirstParmsId).NextContextData.ParmsId, 4); + Assert.AreEqual(10ul, cipher.SizeCapacity); + + Ciphertext cipher5 = new Ciphertext(context); + cipher5.Resize(context, 6ul); + Assert.AreEqual(6ul, cipher5.SizeCapacity); + } + [TestMethod] public void ReleaseTest() { @@ -110,7 +178,7 @@ public void ReleaseTest() } [TestMethod] - public void SaveLoadTest() + public void BFVSaveLoadTest() { SEALContext context = GlobalContext.BFVContext; KeyGenerator keygen = new KeyGenerator(context); @@ -154,7 +222,51 @@ public void SaveLoadTest() } [TestMethod] - public void IndexTest() + public void BGVSaveLoadTest() + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Plaintext plain = new Plaintext("2x^3 + 4x^2 + 5x^1 + 6"); + Ciphertext cipher = new Ciphertext(); + + encryptor.Encrypt(plain, cipher); + + Assert.AreEqual(2ul, cipher.Size); + Assert.AreEqual(8192ul, cipher.PolyModulusDegree); + Assert.AreEqual(4ul, cipher.CoeffModulusSize); + + Ciphertext loaded = new Ciphertext(); + + Assert.AreEqual(0ul, loaded.Size); + Assert.AreEqual(0ul, loaded.PolyModulusDegree); + Assert.AreEqual(0ul, loaded.CoeffModulusSize); + + using (MemoryStream mem = new MemoryStream()) + { + cipher.Save(mem); + + mem.Seek(offset: 0, loc: SeekOrigin.Begin); + + loaded.Load(context, mem); + } + + Assert.AreEqual(2ul, loaded.Size); + Assert.AreEqual(8192ul, loaded.PolyModulusDegree); + Assert.AreEqual(4ul, loaded.CoeffModulusSize); + Assert.IsTrue(ValCheck.IsValidFor(loaded, context)); + + ulong ulongCount = cipher.Size * cipher.PolyModulusDegree * cipher.CoeffModulusSize; + for (ulong i = 0; i < ulongCount; i++) + { + Assert.AreEqual(cipher[i], loaded[i]); + } + } + + [TestMethod] + public void BFVIndexTest() { SEALContext context = GlobalContext.BFVContext; KeyGenerator keygen = new KeyGenerator(context); @@ -176,7 +288,29 @@ public void IndexTest() } [TestMethod] - public void IndexRangeFail1Test() + public void BGVIndexTest() + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Plaintext plain = new Plaintext("1"); + Ciphertext cipher = new Ciphertext(); + + encryptor.Encrypt(plain, cipher); + + Assert.AreEqual(2ul, cipher.Size); + Assert.AreNotEqual(0ul, cipher[0, 0]); + Assert.AreNotEqual(0ul, cipher[0, 1]); + Assert.AreNotEqual(0ul, cipher[0, 2]); + Assert.AreNotEqual(0ul, cipher[1, 0]); + Assert.AreNotEqual(0ul, cipher[1, 1]); + Assert.AreNotEqual(0ul, cipher[1, 2]); + } + + [TestMethod] + public void BFVIndexRangeFail1Test() { SEALContext context = GlobalContext.BFVContext; KeyGenerator keygen = new KeyGenerator(context); @@ -196,7 +330,27 @@ public void IndexRangeFail1Test() } [TestMethod] - public void IndexRangeFail2Test() + public void BGVIndexRangeFail1Test() + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Plaintext plain = new Plaintext("1"); + Ciphertext cipher = new Ciphertext(context); + + encryptor.Encrypt(plain, cipher); + + Utilities.AssertThrows(() => + { + // We only have 2 polynomials + ulong data = cipher[2, 0]; + }); + } + + [TestMethod] + public void BFVIndexRangeFail2Test() { SEALContext context = GlobalContext.BFVContext; KeyGenerator keygen = new KeyGenerator(context); @@ -221,7 +375,32 @@ public void IndexRangeFail2Test() } [TestMethod] - public void IndexRangeFail3Test() + public void BGVIndexRangeFail2Test() + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Plaintext plain = new Plaintext("1"); + Ciphertext cipher = new Ciphertext(); + + encryptor.Encrypt(plain, cipher); + + // We only have 2 polynomials + ulong data = cipher[1, 0]; + + // We should have 8192 coefficients + data = cipher[0, 32767]; // This will succeed + + Utilities.AssertThrows(() => + { + data = cipher[0, 32768]; // This will fail + }); + } + + [TestMethod] + public void BFVIndexRangeFail3Test() { SEALContext context = GlobalContext.BFVContext; KeyGenerator keygen = new KeyGenerator(context); @@ -238,6 +417,24 @@ public void IndexRangeFail3Test() Utilities.AssertThrows(() => cipher[65536] = 10ul); } + [TestMethod] + public void BGVIndexRangeFail3Test() + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Plaintext plain = new Plaintext("1"); + Ciphertext cipher = new Ciphertext(); + + encryptor.Encrypt(plain, cipher); + ulong data = 0; + + Utilities.AssertThrows(() => data = cipher[65536]); + Utilities.AssertThrows(() => cipher[65536] = 10ul); + } + [TestMethod] public void ScaleTest() { @@ -295,7 +492,7 @@ public void ScaleTest() } [TestMethod] - public void ExceptionsTest() + public void BFVExceptionsTest() { SEALContext context = GlobalContext.BFVContext; MemoryPoolHandle pool = MemoryManager.GetPool(MMProfOpt.ForceGlobal); @@ -341,5 +538,53 @@ public void ExceptionsTest() Utilities.AssertThrows(() => cipher.Load(null, new MemoryStream())); Utilities.AssertThrows(() => cipher.Load(context, null)); } + + [TestMethod] + public void BGVExceptionsTest() + { + SEALContext context = GlobalContext.BGVContext; + MemoryPoolHandle pool = MemoryManager.GetPool(MMProfOpt.ForceGlobal); + MemoryPoolHandle poolu = new MemoryPoolHandle(); + Ciphertext cipher = new Ciphertext(); + Ciphertext copy = null; + + Utilities.AssertThrows(() => copy = new Ciphertext((Ciphertext)null)); + + Utilities.AssertThrows(() => cipher = new Ciphertext(context, null, pool)); + Utilities.AssertThrows(() => cipher = new Ciphertext(null, context.FirstParmsId, pool)); + Utilities.AssertThrows(() => cipher = new Ciphertext(context, ParmsId.Zero, pool)); + + Utilities.AssertThrows(() => cipher = new Ciphertext((SEALContext)null, poolu)); + Utilities.AssertThrows(() => cipher = new Ciphertext(context, poolu)); + + Utilities.AssertThrows(() => cipher = new Ciphertext(context, null, 6ul)); + Utilities.AssertThrows(() => cipher = new Ciphertext(null, context.FirstParmsId, 6ul, poolu)); + Utilities.AssertThrows(() => cipher = new Ciphertext(context, ParmsId.Zero, 6ul, poolu)); + + Utilities.AssertThrows(() => cipher.Reserve(context, null, 10ul)); + Utilities.AssertThrows(() => cipher.Reserve(null, ParmsId.Zero, 10ul)); + Utilities.AssertThrows(() => cipher.Reserve(context, ParmsId.Zero, 10ul)); + + Utilities.AssertThrows(() => cipher.Reserve(null, 10ul)); + + Utilities.AssertThrows(() => cipher.Resize(context, null, 10ul)); + Utilities.AssertThrows(() => cipher.Resize(null, ParmsId.Zero, 10ul)); + Utilities.AssertThrows(() => cipher.Resize(context, ParmsId.Zero, 10ul)); + + Utilities.AssertThrows(() => cipher.Resize(null, 10ul)); + + Utilities.AssertThrows(() => cipher.Set(null)); + + Utilities.AssertThrows(() => ValCheck.IsValidFor(cipher, null)); + + Utilities.AssertThrows(() => cipher.Save(null)); + + Utilities.AssertThrows(() => cipher.UnsafeLoad(context, null)); + Utilities.AssertThrows(() => cipher.UnsafeLoad(null, new MemoryStream())); + Utilities.AssertThrows(() => cipher.UnsafeLoad(context, new MemoryStream())); + + Utilities.AssertThrows(() => cipher.Load(null, new MemoryStream())); + Utilities.AssertThrows(() => cipher.Load(context, null)); + } } } diff --git a/dotnet/tests/DecryptorTests.cs b/dotnet/tests/DecryptorTests.cs index dcf3e1302..41b72db71 100644 --- a/dotnet/tests/DecryptorTests.cs +++ b/dotnet/tests/DecryptorTests.cs @@ -8,7 +8,7 @@ namespace SEALNetTest { [TestClass] - public class DecryptorTests + public class BFVDecryptorTests { SEALContext context_; KeyGenerator keyGen_; @@ -91,4 +91,89 @@ public void ExceptionsTest() Utilities.AssertThrows(() => decryptor.InvariantNoiseBudget(null)); } } + + [TestClass] + public class BGVDecryptorTests + { + SEALContext context_; + KeyGenerator keyGen_; + SecretKey secretKey_; + PublicKey publicKey_; + + [TestInitialize] + public void TestInit() + { + context_ = GlobalContext.BGVContext; + keyGen_ = new KeyGenerator(context_); + secretKey_ = keyGen_.SecretKey; + keyGen_.CreatePublicKey(out publicKey_); + } + + [TestMethod] + public void CreateTest() + { + Decryptor decryptor = new Decryptor(context_, secretKey_); + + Assert.IsNotNull(decryptor); + } + + [TestMethod] + public void DecryptTest() + { + Encryptor encryptor = new Encryptor(context_, publicKey_); + Decryptor decryptor = new Decryptor(context_, secretKey_); + + Plaintext plain = new Plaintext("1x^1 + 2"); + Ciphertext cipher = new Ciphertext(); + + Assert.AreEqual(0ul, cipher.Size); + + encryptor.Encrypt(plain, cipher); + + Assert.AreEqual(2ul, cipher.Size); + + Plaintext decrypted = new Plaintext(); + Assert.AreEqual(0ul, decrypted.CoeffCount); + + decryptor.Decrypt(cipher, decrypted); + + Assert.AreEqual(2ul, decrypted.CoeffCount); + Assert.AreEqual(2ul, decrypted[0]); + Assert.AreEqual(1ul, decrypted[1]); + } + + [TestMethod] + public void InvariantNoiseBudgetTest() + { + Encryptor encryptor = new Encryptor(context_, publicKey_); + Decryptor decryptor = new Decryptor(context_, secretKey_); + + Plaintext plain = new Plaintext("1"); + Ciphertext cipher = new Ciphertext(); + + encryptor.Encrypt(plain, cipher); + + int budget = decryptor.InvariantNoiseBudget(cipher); + Assert.IsTrue(budget > 80); + } + + [TestMethod] + public void ExceptionsTest() + { + Decryptor decryptor = new Decryptor(context_, secretKey_); + SecretKey secret = new SecretKey(); + Ciphertext cipher = new Ciphertext(); + Plaintext plain = new Plaintext(); + + Utilities.AssertThrows(() => decryptor = new Decryptor(context_, null)); + Utilities.AssertThrows(() => decryptor = new Decryptor(null, secretKey_)); + Utilities.AssertThrows(() => decryptor = new Decryptor(context_, secret)); + + Utilities.AssertThrows(() => decryptor.Decrypt(cipher, null)); + Utilities.AssertThrows(() => decryptor.Decrypt(null, plain)); + Utilities.AssertThrows(() => decryptor.Decrypt(cipher, plain)); + + Utilities.AssertThrows(() => decryptor.InvariantNoiseBudget(null)); + } + } } diff --git a/dotnet/tests/EncryptionParameterQualifiersTests.cs b/dotnet/tests/EncryptionParameterQualifiersTests.cs index a026da1be..8c3ade75f 100644 --- a/dotnet/tests/EncryptionParameterQualifiersTests.cs +++ b/dotnet/tests/EncryptionParameterQualifiersTests.cs @@ -50,10 +50,21 @@ public void PropertiesTest() Assert.AreEqual(SecLevelType.TC128, qualifiers.SecLevel); Assert.IsTrue(qualifiers.UsingDescendingModulusChain); Assert.IsTrue(qualifiers.UsingNTT); + + SEALContext context3 = GlobalContext.BGVContext; + + Assert.IsTrue(context.FirstContextData.Qualifiers.ParametersSet); + Assert.IsTrue(context.FirstContextData.Qualifiers.UsingBatching); + Assert.IsTrue(context.FirstContextData.Qualifiers.UsingFastPlainLift); + Assert.IsTrue(context.FirstContextData.Qualifiers.UsingFFT); + Assert.AreEqual(SecLevelType.TC128, context.FirstContextData.Qualifiers.SecLevel); + Assert.IsFalse(context.FirstContextData.Qualifiers.UsingDescendingModulusChain); + Assert.IsTrue(context.FirstContextData.Qualifiers.UsingNTT); + Assert.IsTrue(context.UsingKeyswitching); } [TestMethod] - public void ParameterErrorTest() + public void BFVParameterErrorTest() { SEALContext context = GlobalContext.BFVContext; EncryptionParameterQualifiers qualifiers = context.FirstContextData.Qualifiers; @@ -74,12 +85,42 @@ public void ParameterErrorTest() } [TestMethod] - public void ExceptionsTest() + public void BGVParameterErrorTest() + { + SEALContext context = GlobalContext.BGVContext; + EncryptionParameterQualifiers qualifiers = context.FirstContextData.Qualifiers; + + Assert.AreEqual(qualifiers.ParametersErrorName(), "success"); + Assert.AreEqual(qualifiers.ParametersErrorMessage(), "valid"); + + EncryptionParameters encParam = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 127, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30 }) + }; + context = new SEALContext(encParam, expandModChain: true, secLevel: SecLevelType.None); + qualifiers = context.FirstContextData.Qualifiers; + Assert.AreEqual(qualifiers.ParametersErrorName(), "invalid_poly_modulus_degree_non_power_of_two"); + Assert.AreEqual(qualifiers.ParametersErrorMessage(), "poly_modulus_degree is not a power of two"); + } + + [TestMethod] + public void BFVExceptionsTest() { EncryptionParameterQualifiers epq1 = GlobalContext.BFVContext.FirstContextData.Qualifiers; EncryptionParameterQualifiers epq2 = null; Utilities.AssertThrows(() => epq2 = new EncryptionParameterQualifiers(null)); } + + [TestMethod] + public void BGVExceptionsTest() + { + EncryptionParameterQualifiers epq1 = GlobalContext.BGVContext.FirstContextData.Qualifiers; + EncryptionParameterQualifiers epq2 = null; + + Utilities.AssertThrows(() => epq2 = new EncryptionParameterQualifiers(null)); + } } } diff --git a/dotnet/tests/EncryptionParametersTests.cs b/dotnet/tests/EncryptionParametersTests.cs index ee4778a22..5d37dcd9d 100644 --- a/dotnet/tests/EncryptionParametersTests.cs +++ b/dotnet/tests/EncryptionParametersTests.cs @@ -32,18 +32,36 @@ public void CreateTest() Assert.IsNotNull(encParams3); Assert.AreEqual(SchemeType.CKKS, encParams3.Scheme); + EncryptionParameters encParams4 = new EncryptionParameters(SchemeType.BGV); + + Assert.IsNotNull(encParams4); + Assert.AreEqual(SchemeType.BGV, encParams4.Scheme); + EncryptionParameters copy = new EncryptionParameters(encParams); Assert.AreEqual(SchemeType.BFV, copy.Scheme); Assert.AreEqual(encParams, copy); Assert.AreEqual(encParams.GetHashCode(), copy.GetHashCode()); + EncryptionParameters copy4 = new EncryptionParameters(encParams4); + + Assert.AreEqual(SchemeType.BGV, copy4.Scheme); + Assert.AreEqual(encParams4, copy4); + Assert.AreEqual(encParams4.GetHashCode(), copy4.GetHashCode()); + EncryptionParameters third = new EncryptionParameters(SchemeType.CKKS); third.Set(copy); Assert.AreEqual(SchemeType.BFV, third.Scheme); Assert.AreEqual(encParams, third); Assert.AreEqual(encParams.GetHashCode(), third.GetHashCode()); + + EncryptionParameters forth = new EncryptionParameters(SchemeType.CKKS); + forth.Set(copy4); + + Assert.AreEqual(SchemeType.BGV, forth.Scheme); + Assert.AreEqual(encParams4, forth); + Assert.AreEqual(encParams4.GetHashCode(), forth.GetHashCode()); } [TestMethod] @@ -63,7 +81,7 @@ public void SetPlainModulusCKKSTest() } [TestMethod] - public void CoeffModulusTest() + public void BFVCoeffModulusTest() { EncryptionParameters encParams = new EncryptionParameters(SchemeType.BFV); @@ -83,6 +101,27 @@ public void CoeffModulusTest() Assert.AreEqual(0x1ffffe0001ul, newCoeffs[2].Value); } + [TestMethod] + public void BGVCoeffModulusTest() + { + EncryptionParameters encParams = new EncryptionParameters(SchemeType.BGV); + + Assert.IsNotNull(encParams); + + List coeffs = new List(encParams.CoeffModulus); + Assert.IsNotNull(coeffs); + Assert.AreEqual(0, coeffs.Count); + + encParams.CoeffModulus = CoeffModulus.BGVDefault(4096); + + List newCoeffs = new List(encParams.CoeffModulus); + Assert.IsNotNull(newCoeffs); + Assert.AreEqual(3, newCoeffs.Count); + Assert.AreEqual(0xffffee001ul, newCoeffs[0].Value); + Assert.AreEqual(0xffffc4001ul, newCoeffs[1].Value); + Assert.AreEqual(0x1ffffe0001ul, newCoeffs[2].Value); + } + [TestMethod] public void SaveLoadTest() { @@ -121,6 +160,7 @@ public void SaveLoadTest() Assert.AreEqual(coeffModulus[1], loadedCoeffModulus[1]); }; save_load_test(SchemeType.BFV); + save_load_test(SchemeType.BGV); save_load_test(SchemeType.CKKS); } @@ -136,7 +176,16 @@ public void EqualsTest() EncryptionParameters parms2 = new EncryptionParameters(SchemeType.CKKS); + EncryptionParameters parms3 = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 8, + PlainModulus = new Modulus(257), + CoeffModulus = CoeffModulus.Create(8, new int[] { 40, 40 }) + }; + Assert.AreNotEqual(parms, parms2); + Assert.AreNotEqual(parms, parms3); + Assert.AreNotEqual(parms2, parms3); Assert.IsFalse(parms.Equals(null)); } @@ -150,6 +199,14 @@ public void ExceptionsTest() Utilities.AssertThrows(() => parms.Save(null)); Utilities.AssertThrows(() => parms.Load(null)); Utilities.AssertThrows(() => parms.Load(new MemoryStream())); + + EncryptionParameters parms2 = new EncryptionParameters(SchemeType.BGV); + Utilities.AssertThrows(() => parms2 = new EncryptionParameters(null)); + Utilities.AssertThrows(() => parms2.Set(null)); + Utilities.AssertThrows(() => parms2.CoeffModulus = null); + Utilities.AssertThrows(() => parms2.Save(null)); + Utilities.AssertThrows(() => parms2.Load(null)); + Utilities.AssertThrows(() => parms2.Load(new MemoryStream())); } } } diff --git a/dotnet/tests/EncryptorTests.cs b/dotnet/tests/EncryptorTests.cs index e1bf01090..d6d5690fb 100644 --- a/dotnet/tests/EncryptorTests.cs +++ b/dotnet/tests/EncryptorTests.cs @@ -73,6 +73,71 @@ public void EncryptTest() Assert.IsNotNull(encryptor); + Plaintext plain = new Plaintext("1x^1 + 1"); + encryptor.EncryptSymmetric(plain).Save(stream); + stream.Seek(0, SeekOrigin.Begin); + Ciphertext cipher = new Ciphertext(); + cipher.Load(context, stream); + Assert.IsNotNull(cipher); + Assert.AreEqual(2ul, cipher.Size); + } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + SecretKey secretKey = keygen.SecretKey; + Encryptor encryptor = new Encryptor(context, publicKey, secretKey); + + Assert.IsNotNull(encryptor); + + Plaintext plain = new Plaintext("1x^1 + 1"); + Ciphertext cipher = new Ciphertext(); + Assert.AreEqual(0ul, cipher.Size); + encryptor.Encrypt(plain, cipher); + Assert.IsNotNull(cipher); + Assert.AreEqual(2ul, cipher.Size); + } + using (MemoryStream stream = new MemoryStream()) + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + Encryptor encryptor = new Encryptor(context, publicKey); + + Assert.IsNotNull(encryptor); + + Plaintext plain = new Plaintext("1x^1 + 1"); + encryptor.Encrypt(plain).Save(stream); + stream.Seek(0, SeekOrigin.Begin); + Ciphertext cipher = new Ciphertext(); + cipher.Load(context, stream); + Assert.IsNotNull(cipher); + Assert.AreEqual(2ul, cipher.Size); + } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + SecretKey secretKey = keygen.SecretKey; + Encryptor encryptor = new Encryptor(context, secretKey); + + Assert.IsNotNull(encryptor); + + Plaintext plain = new Plaintext("1x^1 + 1"); + Ciphertext cipher = new Ciphertext(); + Assert.AreEqual(0ul, cipher.Size); + encryptor.EncryptSymmetric(plain, cipher); + Assert.IsNotNull(cipher); + Assert.AreEqual(2ul, cipher.Size); + } + using (MemoryStream stream = new MemoryStream()) + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + SecretKey secretKey = keygen.SecretKey; + Encryptor encryptor = new Encryptor(context, secretKey); + + Assert.IsNotNull(encryptor); + Plaintext plain = new Plaintext("1x^1 + 1"); encryptor.EncryptSymmetric(plain).Save(stream); stream.Seek(0, SeekOrigin.Begin); @@ -191,6 +256,111 @@ public void EncryptZeroTest() Assert.IsTrue(plain.IsZero); } } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + SecretKey secretKey = keygen.SecretKey; + + Decryptor decryptor = new Decryptor(context, secretKey); + Assert.IsNotNull(decryptor); + + Ciphertext cipher = new Ciphertext(); + Plaintext plain = new Plaintext(); + ParmsId nextParms = context.FirstContextData.NextContextData.ParmsId; + + { + Encryptor encryptor = new Encryptor(context, publicKey); + Assert.IsNotNull(encryptor); + + encryptor.EncryptZero(cipher); + Assert.IsFalse(cipher.IsNTTForm); + Assert.IsFalse(cipher.IsTransparent); + Assert.AreEqual(cipher.Scale, 1.0, double.Epsilon); + decryptor.Decrypt(cipher, plain); + Assert.IsTrue(plain.IsZero); + + encryptor.EncryptZero(nextParms, cipher); + Assert.IsFalse(cipher.IsNTTForm); + Assert.IsFalse(cipher.IsTransparent); + Assert.AreEqual(cipher.Scale, 1.0, double.Epsilon); + Assert.AreEqual(cipher.ParmsId, nextParms); + decryptor.Decrypt(cipher, plain); + Assert.IsTrue(plain.IsZero); + } + { + Encryptor encryptor = new Encryptor(context, secretKey); + + encryptor.EncryptZeroSymmetric(cipher); + Assert.IsFalse(cipher.IsNTTForm); + Assert.IsFalse(cipher.IsTransparent); + Assert.AreEqual(cipher.Scale, 1.0, double.Epsilon); + decryptor.Decrypt(cipher, plain); + Assert.IsTrue(plain.IsZero); + + encryptor.EncryptZeroSymmetric(nextParms, cipher); + Assert.IsFalse(cipher.IsNTTForm); + Assert.IsFalse(cipher.IsTransparent); + Assert.AreEqual(cipher.Scale, 1.0, double.Epsilon); + Assert.AreEqual(cipher.ParmsId, nextParms); + decryptor.Decrypt(cipher, plain); + Assert.IsTrue(plain.IsZero); + } + using (MemoryStream stream = new MemoryStream()) + { + Encryptor encryptor = new Encryptor(context, publicKey); + + encryptor.EncryptZero().Save(stream); + stream.Seek(0, SeekOrigin.Begin); + cipher.Load(context, stream); + Assert.IsFalse(cipher.IsNTTForm); + Assert.IsFalse(cipher.IsTransparent); + Assert.AreEqual(cipher.Scale, 1.0, double.Epsilon); + decryptor.Decrypt(cipher, plain); + Assert.IsTrue(plain.IsZero); + } + using (MemoryStream stream = new MemoryStream()) + { + Encryptor encryptor = new Encryptor(context, publicKey); + + encryptor.EncryptZero(nextParms).Save(stream); + stream.Seek(0, SeekOrigin.Begin); + cipher.Load(context, stream); + Assert.IsFalse(cipher.IsNTTForm); + Assert.IsFalse(cipher.IsTransparent); + Assert.AreEqual(cipher.Scale, 1.0, double.Epsilon); + Assert.AreEqual(cipher.ParmsId, nextParms); + decryptor.Decrypt(cipher, plain); + Assert.IsTrue(plain.IsZero); + } + using (MemoryStream stream = new MemoryStream()) + { + Encryptor encryptor = new Encryptor(context, secretKey); + + encryptor.EncryptZeroSymmetric().Save(stream); + stream.Seek(0, SeekOrigin.Begin); + cipher.Load(context, stream); + Assert.IsFalse(cipher.IsNTTForm); + Assert.IsFalse(cipher.IsTransparent); + Assert.AreEqual(cipher.Scale, 1.0, double.Epsilon); + decryptor.Decrypt(cipher, plain); + Assert.IsTrue(plain.IsZero); + } + using (MemoryStream stream = new MemoryStream()) + { + Encryptor encryptor = new Encryptor(context, secretKey); + + encryptor.EncryptZeroSymmetric(nextParms).Save(stream); + stream.Seek(0, SeekOrigin.Begin); + cipher.Load(context, stream); + Assert.IsFalse(cipher.IsNTTForm); + Assert.IsFalse(cipher.IsTransparent); + Assert.AreEqual(cipher.Scale, 1.0, double.Epsilon); + Assert.AreEqual(cipher.ParmsId, nextParms); + decryptor.Decrypt(cipher, plain); + Assert.IsTrue(plain.IsZero); + } + } { SEALContext context = GlobalContext.CKKSContext; KeyGenerator keygen = new KeyGenerator(context); @@ -363,39 +533,76 @@ public void EncryptZeroTest() [TestMethod] public void ExceptionsTest() { - SEALContext context = GlobalContext.BFVContext; - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey pubKey); - PublicKey pubKey_invalid = new PublicKey(); - SecretKey secKey = keygen.SecretKey; - SecretKey secKey_invalid = new SecretKey(); - Encryptor encryptor = new Encryptor(context, pubKey); - Plaintext plain = new Plaintext(); - Ciphertext cipher = new Ciphertext(); - MemoryPoolHandle pool_invalid = new MemoryPoolHandle(); - ParmsId parmsId_invalid = new ParmsId(); - - Utilities.AssertThrows(() => encryptor = new Encryptor(context, null)); - Utilities.AssertThrows(() => encryptor = new Encryptor(null, pubKey)); - Utilities.AssertThrows(() => encryptor = new Encryptor(context, pubKey_invalid)); - Utilities.AssertThrows(() => encryptor = new Encryptor(context, pubKey_invalid, secKey)); - encryptor = new Encryptor(context, pubKey, secKey); - Utilities.AssertThrows(() => encryptor.SetPublicKey(pubKey_invalid)); - Utilities.AssertThrows(() => encryptor.SetSecretKey(secKey_invalid)); - - Utilities.AssertThrows(() => encryptor.Encrypt(null, cipher)); - Utilities.AssertThrows(() => encryptor.Encrypt(plain, cipher, pool_invalid)); - Utilities.AssertThrows(() => encryptor.EncryptZero(cipher, pool_invalid)); - Utilities.AssertThrows(() => encryptor.EncryptZero(parmsId_invalid, cipher)); - - Utilities.AssertThrows(() => encryptor.EncryptSymmetric(plain, destination: null)); - Utilities.AssertThrows(() => encryptor.EncryptSymmetric(null, cipher)); - Utilities.AssertThrows(() => encryptor.EncryptSymmetric(plain, cipher, pool_invalid)); - Utilities.AssertThrows(() => encryptor.EncryptZeroSymmetric(cipher, pool_invalid)); - Utilities.AssertThrows(() => encryptor.EncryptZeroSymmetric(parmsId_invalid, cipher)); - - Utilities.AssertThrows(() => encryptor.EncryptSymmetric(plain).Save(null)); - Utilities.AssertThrows(() => encryptor.EncryptZeroSymmetric().Save(null)); + { + SEALContext context = GlobalContext.BFVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey pubKey); + PublicKey pubKey_invalid = new PublicKey(); + SecretKey secKey = keygen.SecretKey; + SecretKey secKey_invalid = new SecretKey(); + Encryptor encryptor = new Encryptor(context, pubKey); + Plaintext plain = new Plaintext(); + Ciphertext cipher = new Ciphertext(); + MemoryPoolHandle pool_invalid = new MemoryPoolHandle(); + ParmsId parmsId_invalid = new ParmsId(); + + Utilities.AssertThrows(() => encryptor = new Encryptor(context, null)); + Utilities.AssertThrows(() => encryptor = new Encryptor(null, pubKey)); + Utilities.AssertThrows(() => encryptor = new Encryptor(context, pubKey_invalid)); + Utilities.AssertThrows(() => encryptor = new Encryptor(context, pubKey_invalid, secKey)); + encryptor = new Encryptor(context, pubKey, secKey); + Utilities.AssertThrows(() => encryptor.SetPublicKey(pubKey_invalid)); + Utilities.AssertThrows(() => encryptor.SetSecretKey(secKey_invalid)); + + Utilities.AssertThrows(() => encryptor.Encrypt(null, cipher)); + Utilities.AssertThrows(() => encryptor.Encrypt(plain, cipher, pool_invalid)); + Utilities.AssertThrows(() => encryptor.EncryptZero(cipher, pool_invalid)); + Utilities.AssertThrows(() => encryptor.EncryptZero(parmsId_invalid, cipher)); + + Utilities.AssertThrows(() => encryptor.EncryptSymmetric(plain, destination: null)); + Utilities.AssertThrows(() => encryptor.EncryptSymmetric(null, cipher)); + Utilities.AssertThrows(() => encryptor.EncryptSymmetric(plain, cipher, pool_invalid)); + Utilities.AssertThrows(() => encryptor.EncryptZeroSymmetric(cipher, pool_invalid)); + Utilities.AssertThrows(() => encryptor.EncryptZeroSymmetric(parmsId_invalid, cipher)); + + Utilities.AssertThrows(() => encryptor.EncryptSymmetric(plain).Save(null)); + Utilities.AssertThrows(() => encryptor.EncryptZeroSymmetric().Save(null)); + } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey pubKey); + PublicKey pubKey_invalid = new PublicKey(); + SecretKey secKey = keygen.SecretKey; + SecretKey secKey_invalid = new SecretKey(); + Encryptor encryptor = new Encryptor(context, pubKey); + Plaintext plain = new Plaintext(); + Ciphertext cipher = new Ciphertext(); + MemoryPoolHandle pool_invalid = new MemoryPoolHandle(); + ParmsId parmsId_invalid = new ParmsId(); + + Utilities.AssertThrows(() => encryptor = new Encryptor(context, null)); + Utilities.AssertThrows(() => encryptor = new Encryptor(null, pubKey)); + Utilities.AssertThrows(() => encryptor = new Encryptor(context, pubKey_invalid)); + Utilities.AssertThrows(() => encryptor = new Encryptor(context, pubKey_invalid, secKey)); + encryptor = new Encryptor(context, pubKey, secKey); + Utilities.AssertThrows(() => encryptor.SetPublicKey(pubKey_invalid)); + Utilities.AssertThrows(() => encryptor.SetSecretKey(secKey_invalid)); + + Utilities.AssertThrows(() => encryptor.Encrypt(null, cipher)); + Utilities.AssertThrows(() => encryptor.Encrypt(plain, cipher, pool_invalid)); + Utilities.AssertThrows(() => encryptor.EncryptZero(cipher, pool_invalid)); + Utilities.AssertThrows(() => encryptor.EncryptZero(parmsId_invalid, cipher)); + + Utilities.AssertThrows(() => encryptor.EncryptSymmetric(plain, destination: null)); + Utilities.AssertThrows(() => encryptor.EncryptSymmetric(null, cipher)); + Utilities.AssertThrows(() => encryptor.EncryptSymmetric(plain, cipher, pool_invalid)); + Utilities.AssertThrows(() => encryptor.EncryptZeroSymmetric(cipher, pool_invalid)); + Utilities.AssertThrows(() => encryptor.EncryptZeroSymmetric(parmsId_invalid, cipher)); + + Utilities.AssertThrows(() => encryptor.EncryptSymmetric(plain).Save(null)); + Utilities.AssertThrows(() => encryptor.EncryptZeroSymmetric().Save(null)); + } } } } diff --git a/dotnet/tests/EvaluatorTests.cs b/dotnet/tests/EvaluatorTests.cs index 347833d5c..c6cbef757 100644 --- a/dotnet/tests/EvaluatorTests.cs +++ b/dotnet/tests/EvaluatorTests.cs @@ -16,854 +16,1720 @@ public class EvaluatorTests [TestMethod] public void CreateTest() { - Evaluator evaluator = new Evaluator(GlobalContext.BFVContext); - Assert.IsNotNull(evaluator); + { + Evaluator evaluator = new Evaluator(GlobalContext.BFVContext); + Assert.IsNotNull(evaluator); + } + { + Evaluator evaluator = new Evaluator(GlobalContext.BGVContext); + Assert.IsNotNull(evaluator); + } } [TestMethod] public void NegateTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 64, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - Assert.IsTrue(context.ParametersSet); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted = new Ciphertext(); - Ciphertext encdestination = new Ciphertext(); - Plaintext plain = new Plaintext("3x^2 + 2x^1 + 1"); - Plaintext plaindest = new Plaintext(); - encryptor.Encrypt(plain, encrypted); - evaluator.Negate(encrypted, encdestination); - decryptor.Decrypt(encdestination, plaindest); - - // coefficients are negated (modulo 64) - Assert.AreEqual(0x3Ful, plaindest[0]); - Assert.AreEqual(0x3Eul, plaindest[1]); - Assert.AreEqual(0x3Dul, plaindest[2]); - - plain = new Plaintext("6x^3 + 7x^2 + 8x^1 + 9"); - encryptor.Encrypt(plain, encrypted); - evaluator.NegateInplace(encrypted); - decryptor.Decrypt(encrypted, plain); - - // coefficients are negated (modulo 64) - Assert.AreEqual(0x37ul, plain[0]); - Assert.AreEqual(0x38ul, plain[1]); - Assert.AreEqual(0x39ul, plain[2]); - Assert.AreEqual(0x3Aul, plain[3]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + Assert.IsTrue(context.ParametersSet); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdestination = new Ciphertext(); + Plaintext plain = new Plaintext("3x^2 + 2x^1 + 1"); + Plaintext plaindest = new Plaintext(); + encryptor.Encrypt(plain, encrypted); + evaluator.Negate(encrypted, encdestination); + decryptor.Decrypt(encdestination, plaindest); + + // coefficients are negated (modulo 64) + Assert.AreEqual(0x3Ful, plaindest[0]); + Assert.AreEqual(0x3Eul, plaindest[1]); + Assert.AreEqual(0x3Dul, plaindest[2]); + + plain = new Plaintext("6x^3 + 7x^2 + 8x^1 + 9"); + encryptor.Encrypt(plain, encrypted); + evaluator.NegateInplace(encrypted); + decryptor.Decrypt(encrypted, plain); + + // coefficients are negated (modulo 64) + Assert.AreEqual(0x37ul, plain[0]); + Assert.AreEqual(0x38ul, plain[1]); + Assert.AreEqual(0x39ul, plain[2]); + Assert.AreEqual(0x3Aul, plain[3]); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + Assert.IsTrue(context.ParametersSet); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdestination = new Ciphertext(); + Plaintext plain = new Plaintext("3x^2 + 2x^1 + 1"); + Plaintext plaindest = new Plaintext(); + encryptor.Encrypt(plain, encrypted); + evaluator.Negate(encrypted, encdestination); + decryptor.Decrypt(encdestination, plaindest); + + // coefficients are negated (modulo 64) + Assert.AreEqual(0x3Ful, plaindest[0]); + Assert.AreEqual(0x3Eul, plaindest[1]); + Assert.AreEqual(0x3Dul, plaindest[2]); + + plain = new Plaintext("6x^3 + 7x^2 + 8x^1 + 9"); + encryptor.Encrypt(plain, encrypted); + evaluator.NegateInplace(encrypted); + decryptor.Decrypt(encrypted, plain); + + // coefficients are negated (modulo 64) + Assert.AreEqual(0x37ul, plain[0]); + Assert.AreEqual(0x38ul, plain[1]); + Assert.AreEqual(0x39ul, plain[2]); + Assert.AreEqual(0x3Aul, plain[3]); + } } [TestMethod] public void AddTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 64, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted1 = new Ciphertext(); - Ciphertext encrypted2 = new Ciphertext(); - Ciphertext encdestination = new Ciphertext(); - - Plaintext plain1 = new Plaintext("5x^4 + 4x^3 + 3x^2 + 2x^1 + 1"); - Plaintext plain2 = new Plaintext("4x^7 + 5x^6 + 6x^5 + 7x^4 + 8x^3 + 9x^2 + Ax^1 + B"); - Plaintext plaindest = new Plaintext(); - - encryptor.Encrypt(plain1, encrypted1); - encryptor.Encrypt(plain2, encrypted2); - evaluator.Add(encrypted1, encrypted2, encdestination); - decryptor.Decrypt(encdestination, plaindest); - - Assert.AreEqual(12ul, plaindest[0]); - Assert.AreEqual(12ul, plaindest[1]); - Assert.AreEqual(12ul, plaindest[2]); - Assert.AreEqual(12ul, plaindest[3]); - Assert.AreEqual(12ul, plaindest[4]); - Assert.AreEqual(6ul, plaindest[5]); - Assert.AreEqual(5ul, plaindest[6]); - Assert.AreEqual(4ul, plaindest[7]); - - plain1 = new Plaintext("1x^2 + 2x^1 + 3"); - plain2 = new Plaintext("2x^3 + 2x^2 + 2x^1 + 2"); - - encryptor.Encrypt(plain1, encrypted1); - encryptor.Encrypt(plain2, encrypted2); - evaluator.AddInplace(encrypted1, encrypted2); - decryptor.Decrypt(encrypted1, plaindest); - - Assert.AreEqual(5ul, plaindest[0]); - Assert.AreEqual(4ul, plaindest[1]); - Assert.AreEqual(3ul, plaindest[2]); - Assert.AreEqual(2ul, plaindest[3]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted1 = new Ciphertext(); + Ciphertext encrypted2 = new Ciphertext(); + Ciphertext encdestination = new Ciphertext(); + + Plaintext plain1 = new Plaintext("5x^4 + 4x^3 + 3x^2 + 2x^1 + 1"); + Plaintext plain2 = new Plaintext("4x^7 + 5x^6 + 6x^5 + 7x^4 + 8x^3 + 9x^2 + Ax^1 + B"); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(plain1, encrypted1); + encryptor.Encrypt(plain2, encrypted2); + evaluator.Add(encrypted1, encrypted2, encdestination); + decryptor.Decrypt(encdestination, plaindest); + + Assert.AreEqual(12ul, plaindest[0]); + Assert.AreEqual(12ul, plaindest[1]); + Assert.AreEqual(12ul, plaindest[2]); + Assert.AreEqual(12ul, plaindest[3]); + Assert.AreEqual(12ul, plaindest[4]); + Assert.AreEqual(6ul, plaindest[5]); + Assert.AreEqual(5ul, plaindest[6]); + Assert.AreEqual(4ul, plaindest[7]); + + plain1 = new Plaintext("1x^2 + 2x^1 + 3"); + plain2 = new Plaintext("2x^3 + 2x^2 + 2x^1 + 2"); + + encryptor.Encrypt(plain1, encrypted1); + encryptor.Encrypt(plain2, encrypted2); + evaluator.AddInplace(encrypted1, encrypted2); + decryptor.Decrypt(encrypted1, plaindest); + + Assert.AreEqual(5ul, plaindest[0]); + Assert.AreEqual(4ul, plaindest[1]); + Assert.AreEqual(3ul, plaindest[2]); + Assert.AreEqual(2ul, plaindest[3]); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted1 = new Ciphertext(); + Ciphertext encrypted2 = new Ciphertext(); + Ciphertext encdestination = new Ciphertext(); + + Plaintext plain1 = new Plaintext("5x^4 + 4x^3 + 3x^2 + 2x^1 + 1"); + Plaintext plain2 = new Plaintext("4x^7 + 5x^6 + 6x^5 + 7x^4 + 8x^3 + 9x^2 + Ax^1 + B"); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(plain1, encrypted1); + encryptor.Encrypt(plain2, encrypted2); + evaluator.Add(encrypted1, encrypted2, encdestination); + decryptor.Decrypt(encdestination, plaindest); + + Assert.AreEqual(12ul, plaindest[0]); + Assert.AreEqual(12ul, plaindest[1]); + Assert.AreEqual(12ul, plaindest[2]); + Assert.AreEqual(12ul, plaindest[3]); + Assert.AreEqual(12ul, plaindest[4]); + Assert.AreEqual(6ul, plaindest[5]); + Assert.AreEqual(5ul, plaindest[6]); + Assert.AreEqual(4ul, plaindest[7]); + + plain1 = new Plaintext("1x^2 + 2x^1 + 3"); + plain2 = new Plaintext("2x^3 + 2x^2 + 2x^1 + 2"); + + encryptor.Encrypt(plain1, encrypted1); + encryptor.Encrypt(plain2, encrypted2); + evaluator.AddInplace(encrypted1, encrypted2); + decryptor.Decrypt(encrypted1, plaindest); + + Assert.AreEqual(5ul, plaindest[0]); + Assert.AreEqual(4ul, plaindest[1]); + Assert.AreEqual(3ul, plaindest[2]); + Assert.AreEqual(2ul, plaindest[3]); + } } [TestMethod] public void AddPlainTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 64, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted = new Ciphertext(); - Ciphertext encdest = new Ciphertext(); - Plaintext plain = new Plaintext("3x^2 + 2x^1 + 1"); - Plaintext plaindest = new Plaintext(); - - encryptor.Encrypt(new Plaintext("2x^2 + 2x^1 + 2"), encrypted); - evaluator.AddPlain(encrypted, plain, encdest); - decryptor.Decrypt(encdest, plaindest); - - Assert.AreEqual(3ul, plaindest[0]); - Assert.AreEqual(4ul, plaindest[1]); - Assert.AreEqual(5ul, plaindest[2]); - - plain.Set("1x^2 + 1x^1 + 1"); - encryptor.Encrypt(new Plaintext("2x^3 + 2x^2 + 2x^1 + 2"), encrypted); - evaluator.AddPlainInplace(encrypted, plain); - decryptor.Decrypt(encrypted, plaindest); - - Assert.AreEqual(4ul, plaindest.CoeffCount); - Assert.AreEqual(3ul, plaindest[0]); - Assert.AreEqual(3ul, plaindest[1]); - Assert.AreEqual(3ul, plaindest[2]); - Assert.AreEqual(2ul, plaindest[3]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plain = new Plaintext("3x^2 + 2x^1 + 1"); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(new Plaintext("2x^2 + 2x^1 + 2"), encrypted); + evaluator.AddPlain(encrypted, plain, encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.AreEqual(3ul, plaindest[0]); + Assert.AreEqual(4ul, plaindest[1]); + Assert.AreEqual(5ul, plaindest[2]); + + plain.Set("1x^2 + 1x^1 + 1"); + encryptor.Encrypt(new Plaintext("2x^3 + 2x^2 + 2x^1 + 2"), encrypted); + evaluator.AddPlainInplace(encrypted, plain); + decryptor.Decrypt(encrypted, plaindest); + + Assert.AreEqual(4ul, plaindest.CoeffCount); + Assert.AreEqual(3ul, plaindest[0]); + Assert.AreEqual(3ul, plaindest[1]); + Assert.AreEqual(3ul, plaindest[2]); + Assert.AreEqual(2ul, plaindest[3]); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plain = new Plaintext("3x^2 + 2x^1 + 1"); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(new Plaintext("2x^2 + 2x^1 + 2"), encrypted); + evaluator.AddPlain(encrypted, plain, encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.AreEqual(3ul, plaindest[0]); + Assert.AreEqual(4ul, plaindest[1]); + Assert.AreEqual(5ul, plaindest[2]); + + plain.Set("1x^2 + 1x^1 + 1"); + encryptor.Encrypt(new Plaintext("2x^3 + 2x^2 + 2x^1 + 2"), encrypted); + evaluator.AddPlainInplace(encrypted, plain); + decryptor.Decrypt(encrypted, plaindest); + + Assert.AreEqual(4ul, plaindest.CoeffCount); + Assert.AreEqual(3ul, plaindest[0]); + Assert.AreEqual(3ul, plaindest[1]); + Assert.AreEqual(3ul, plaindest[2]); + Assert.AreEqual(2ul, plaindest[3]); + } } [TestMethod] public void AddManyTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 64, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext[] encrypteds = new Ciphertext[6]; - - for(int i = 0; i < encrypteds.Length; i++) + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext[] encrypteds = new Ciphertext[6]; + + for(int i = 0; i < encrypteds.Length; i++) + { + encrypteds[i] = new Ciphertext(); + encryptor.Encrypt(new Plaintext((i + 1).ToString()), encrypteds[i]); + } + + Ciphertext encdest = new Ciphertext(); + Plaintext plaindest = new Plaintext(); + evaluator.AddMany(encrypteds, encdest); + decryptor.Decrypt(encdest, plaindest); + + // 1+2+3+4+5+6 + Assert.AreEqual(21ul, plaindest[0]); + } { - encrypteds[i] = new Ciphertext(); - encryptor.Encrypt(new Plaintext((i + 1).ToString()), encrypteds[i]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext[] encrypteds = new Ciphertext[6]; + + for(int i = 0; i < encrypteds.Length; i++) + { + encrypteds[i] = new Ciphertext(); + encryptor.Encrypt(new Plaintext((i + 1).ToString()), encrypteds[i]); + } + + Ciphertext encdest = new Ciphertext(); + Plaintext plaindest = new Plaintext(); + evaluator.AddMany(encrypteds, encdest); + decryptor.Decrypt(encdest, plaindest); + + // 1+2+3+4+5+6 + Assert.AreEqual(21ul, plaindest[0]); } - - Ciphertext encdest = new Ciphertext(); - Plaintext plaindest = new Plaintext(); - evaluator.AddMany(encrypteds, encdest); - decryptor.Decrypt(encdest, plaindest); - - // 1+2+3+4+5+6 - Assert.AreEqual(21ul, plaindest[0]); } [TestMethod] public void SubTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 64, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted1 = new Ciphertext(); - Ciphertext encrypted2 = new Ciphertext(); - Ciphertext encdest = new Ciphertext(); - Plaintext plain1 = new Plaintext("Ax^2 + Bx^1 + C"); - Plaintext plain2 = new Plaintext("5x^3 + 5x^2 + 5x^1 + 5"); - Plaintext plaindest = new Plaintext(); - - encryptor.Encrypt(plain1, encrypted1); - encryptor.Encrypt(plain2, encrypted2); - evaluator.Sub(encrypted1, encrypted2, encdest); - decryptor.Decrypt(encdest, plaindest); - - Assert.AreEqual(7ul, plaindest[0]); - Assert.AreEqual(6ul, plaindest[1]); - Assert.AreEqual(5ul, plaindest[2]); - Assert.AreEqual(0x3Bul, plaindest[3]); - - plain1.Set("Ax^3 + Bx^2 + Cx^1 + D"); - plain2.Set("5x^2 + 5x^1 + 5"); - - encryptor.Encrypt(plain1, encrypted1); - encryptor.Encrypt(plain2, encrypted2); - evaluator.SubInplace(encrypted1, encrypted2); - decryptor.Decrypt(encrypted1, plaindest); - - Assert.AreEqual(8ul, plaindest[0]); - Assert.AreEqual(7ul, plaindest[1]); - Assert.AreEqual(6ul, plaindest[2]); - Assert.AreEqual(10ul, plaindest[3]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted1 = new Ciphertext(); + Ciphertext encrypted2 = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plain1 = new Plaintext("Ax^2 + Bx^1 + C"); + Plaintext plain2 = new Plaintext("5x^3 + 5x^2 + 5x^1 + 5"); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(plain1, encrypted1); + encryptor.Encrypt(plain2, encrypted2); + evaluator.Sub(encrypted1, encrypted2, encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.AreEqual(7ul, plaindest[0]); + Assert.AreEqual(6ul, plaindest[1]); + Assert.AreEqual(5ul, plaindest[2]); + Assert.AreEqual(0x3Bul, plaindest[3]); + + plain1.Set("Ax^3 + Bx^2 + Cx^1 + D"); + plain2.Set("5x^2 + 5x^1 + 5"); + + encryptor.Encrypt(plain1, encrypted1); + encryptor.Encrypt(plain2, encrypted2); + evaluator.SubInplace(encrypted1, encrypted2); + decryptor.Decrypt(encrypted1, plaindest); + + Assert.AreEqual(8ul, plaindest[0]); + Assert.AreEqual(7ul, plaindest[1]); + Assert.AreEqual(6ul, plaindest[2]); + Assert.AreEqual(10ul, plaindest[3]); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted1 = new Ciphertext(); + Ciphertext encrypted2 = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plain1 = new Plaintext("Ax^2 + Bx^1 + C"); + Plaintext plain2 = new Plaintext("5x^3 + 5x^2 + 5x^1 + 5"); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(plain1, encrypted1); + encryptor.Encrypt(plain2, encrypted2); + evaluator.Sub(encrypted1, encrypted2, encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.AreEqual(7ul, plaindest[0]); + Assert.AreEqual(6ul, plaindest[1]); + Assert.AreEqual(5ul, plaindest[2]); + Assert.AreEqual(0x3Bul, plaindest[3]); + + plain1.Set("Ax^3 + Bx^2 + Cx^1 + D"); + plain2.Set("5x^2 + 5x^1 + 5"); + + encryptor.Encrypt(plain1, encrypted1); + encryptor.Encrypt(plain2, encrypted2); + evaluator.SubInplace(encrypted1, encrypted2); + decryptor.Decrypt(encrypted1, plaindest); + + Assert.AreEqual(8ul, plaindest[0]); + Assert.AreEqual(7ul, plaindest[1]); + Assert.AreEqual(6ul, plaindest[2]); + Assert.AreEqual(10ul, plaindest[3]); + } } [TestMethod] public void SubPlainTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 64, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted = new Ciphertext(); - Ciphertext encdest = new Ciphertext(); - Plaintext plain = new Plaintext("5x^2 + 4x^1 + 3"); - Plaintext plaindest = new Plaintext(); - - encryptor.Encrypt(new Plaintext("3x^1 + 4"), encrypted); - evaluator.SubPlain(encrypted, plain, encdest); - decryptor.Decrypt(encdest, plaindest); - - Assert.AreEqual(3ul, plaindest.CoeffCount); - Assert.AreEqual(1ul, plaindest[0]); - Assert.AreEqual(0x3Ful, plaindest[1]); // -1 - Assert.AreEqual(0x3Bul, plaindest[2]); // -5 - - plain.Set("6x^3 + 1x^2 + 7x^1 + 2"); - encryptor.Encrypt(new Plaintext("Ax^2 + Bx^1 + C"), encrypted); - evaluator.SubPlainInplace(encrypted, plain); - decryptor.Decrypt(encrypted, plaindest); - - Assert.AreEqual(4ul, plaindest.CoeffCount); - Assert.AreEqual(10ul, plaindest[0]); - Assert.AreEqual(4ul, plaindest[1]); - Assert.AreEqual(9ul, plaindest[2]); - Assert.AreEqual(0x3Aul, plaindest[3]); // -6 + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plain = new Plaintext("5x^2 + 4x^1 + 3"); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(new Plaintext("3x^1 + 4"), encrypted); + evaluator.SubPlain(encrypted, plain, encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.AreEqual(3ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + Assert.AreEqual(0x3Ful, plaindest[1]); // -1 + Assert.AreEqual(0x3Bul, plaindest[2]); // -5 + + plain.Set("6x^3 + 1x^2 + 7x^1 + 2"); + encryptor.Encrypt(new Plaintext("Ax^2 + Bx^1 + C"), encrypted); + evaluator.SubPlainInplace(encrypted, plain); + decryptor.Decrypt(encrypted, plaindest); + + Assert.AreEqual(4ul, plaindest.CoeffCount); + Assert.AreEqual(10ul, plaindest[0]); + Assert.AreEqual(4ul, plaindest[1]); + Assert.AreEqual(9ul, plaindest[2]); + Assert.AreEqual(0x3Aul, plaindest[3]); // -6 + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plain = new Plaintext("5x^2 + 4x^1 + 3"); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(new Plaintext("3x^1 + 4"), encrypted); + evaluator.SubPlain(encrypted, plain, encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.AreEqual(3ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + Assert.AreEqual(0x3Ful, plaindest[1]); // -1 + Assert.AreEqual(0x3Bul, plaindest[2]); // -5 + + plain.Set("6x^3 + 1x^2 + 7x^1 + 2"); + encryptor.Encrypt(new Plaintext("Ax^2 + Bx^1 + C"), encrypted); + evaluator.SubPlainInplace(encrypted, plain); + decryptor.Decrypt(encrypted, plaindest); + + Assert.AreEqual(4ul, plaindest.CoeffCount); + Assert.AreEqual(10ul, plaindest[0]); + Assert.AreEqual(4ul, plaindest[1]); + Assert.AreEqual(9ul, plaindest[2]); + Assert.AreEqual(0x3Aul, plaindest[3]); // -6 + } } [TestMethod] public void MultiplyTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 64, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted1 = new Ciphertext(); - Ciphertext encrypted2 = new Ciphertext(); - Ciphertext encdest = new Ciphertext(); - Plaintext plaindest = new Plaintext(); - - encryptor.Encrypt(new Plaintext("1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"), encrypted1); - encryptor.Encrypt(new Plaintext("3x^2 + 2x^1 + 1"), encrypted2); - evaluator.Multiply(encrypted1, encrypted2, encdest); - decryptor.Decrypt(encdest, plaindest); - - // {3x^6 + 8x^5 + Ex^4 + 14x^3 + 1Ax^2 + Ex^1 + 5} - Assert.AreEqual(7ul, plaindest.CoeffCount); - Assert.AreEqual(5ul, plaindest[0]); - Assert.AreEqual(14ul, plaindest[1]); - Assert.AreEqual(26ul, plaindest[2]); - Assert.AreEqual(20ul, plaindest[3]); - Assert.AreEqual(14ul, plaindest[4]); - Assert.AreEqual(8ul, plaindest[5]); - Assert.AreEqual(3ul, plaindest[6]); - - encryptor.Encrypt(new Plaintext("2x^2 + 3x^1 + 4"), encrypted1); - encryptor.Encrypt(new Plaintext("4x^1 + 5"), encrypted2); - evaluator.MultiplyInplace(encrypted1, encrypted2); - decryptor.Decrypt(encrypted1, plaindest); - - // {8x^3 + 16x^2 + 1Fx^1 + 14} - Assert.AreEqual(4ul, plaindest.CoeffCount); - Assert.AreEqual(20ul, plaindest[0]); - Assert.AreEqual(31ul, plaindest[1]); - Assert.AreEqual(22ul, plaindest[2]); - Assert.AreEqual(8ul, plaindest[3]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted1 = new Ciphertext(); + Ciphertext encrypted2 = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(new Plaintext("1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"), encrypted1); + encryptor.Encrypt(new Plaintext("3x^2 + 2x^1 + 1"), encrypted2); + evaluator.Multiply(encrypted1, encrypted2, encdest); + decryptor.Decrypt(encdest, plaindest); + + // {3x^6 + 8x^5 + Ex^4 + 14x^3 + 1Ax^2 + Ex^1 + 5} + Assert.AreEqual(7ul, plaindest.CoeffCount); + Assert.AreEqual(5ul, plaindest[0]); + Assert.AreEqual(14ul, plaindest[1]); + Assert.AreEqual(26ul, plaindest[2]); + Assert.AreEqual(20ul, plaindest[3]); + Assert.AreEqual(14ul, plaindest[4]); + Assert.AreEqual(8ul, plaindest[5]); + Assert.AreEqual(3ul, plaindest[6]); + + encryptor.Encrypt(new Plaintext("2x^2 + 3x^1 + 4"), encrypted1); + encryptor.Encrypt(new Plaintext("4x^1 + 5"), encrypted2); + evaluator.MultiplyInplace(encrypted1, encrypted2); + decryptor.Decrypt(encrypted1, plaindest); + + // {8x^3 + 16x^2 + 1Fx^1 + 14} + Assert.AreEqual(4ul, plaindest.CoeffCount); + Assert.AreEqual(20ul, plaindest[0]); + Assert.AreEqual(31ul, plaindest[1]); + Assert.AreEqual(22ul, plaindest[2]); + Assert.AreEqual(8ul, plaindest[3]); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted1 = new Ciphertext(); + Ciphertext encrypted2 = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(new Plaintext("1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"), encrypted1); + encryptor.Encrypt(new Plaintext("3x^2 + 2x^1 + 1"), encrypted2); + evaluator.Multiply(encrypted1, encrypted2, encdest); + decryptor.Decrypt(encdest, plaindest); + + // {3x^6 + 8x^5 + Ex^4 + 14x^3 + 1Ax^2 + Ex^1 + 5} + Assert.AreEqual(7ul, plaindest.CoeffCount); + Assert.AreEqual(5ul, plaindest[0]); + Assert.AreEqual(14ul, plaindest[1]); + Assert.AreEqual(26ul, plaindest[2]); + Assert.AreEqual(20ul, plaindest[3]); + Assert.AreEqual(14ul, plaindest[4]); + Assert.AreEqual(8ul, plaindest[5]); + Assert.AreEqual(3ul, plaindest[6]); + + encryptor.Encrypt(new Plaintext("2x^2 + 3x^1 + 4"), encrypted1); + encryptor.Encrypt(new Plaintext("4x^1 + 5"), encrypted2); + evaluator.MultiplyInplace(encrypted1, encrypted2); + decryptor.Decrypt(encrypted1, plaindest); + + // {8x^3 + 16x^2 + 1Fx^1 + 14} + Assert.AreEqual(4ul, plaindest.CoeffCount); + Assert.AreEqual(20ul, plaindest[0]); + Assert.AreEqual(31ul, plaindest[1]); + Assert.AreEqual(22ul, plaindest[2]); + Assert.AreEqual(8ul, plaindest[3]); + } } [TestMethod] public void MultiplyManyTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) - { - PolyModulusDegree = 128, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - keygen.CreateRelinKeys(out RelinKeys relinKeys); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext[] encrypteds = new Ciphertext[4]; - Ciphertext encdest = new Ciphertext(); - Plaintext plaindest = new Plaintext(); - - for (int i = 0; i < encrypteds.Length; i++) { - encrypteds[i] = new Ciphertext(); - encryptor.Encrypt(new Plaintext((i + 1).ToString()), encrypteds[i]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + keygen.CreateRelinKeys(out RelinKeys relinKeys); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext[] encrypteds = new Ciphertext[4]; + Ciphertext encdest = new Ciphertext(); + Plaintext plaindest = new Plaintext(); + + for (int i = 0; i < encrypteds.Length; i++) + { + encrypteds[i] = new Ciphertext(); + encryptor.Encrypt(new Plaintext((i + 1).ToString()), encrypteds[i]); + } + + evaluator.MultiplyMany(encrypteds, relinKeys, encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.AreEqual(1ul, plaindest.CoeffCount); + Assert.AreEqual(24ul, plaindest[0]); + + Utilities.AssertThrows(() => + { + // Uninitialized memory pool handle + MemoryPoolHandle pool = new MemoryPoolHandle(); + evaluator.MultiplyMany(encrypteds, relinKeys, encdest, pool); + }); } - - evaluator.MultiplyMany(encrypteds, relinKeys, encdest); - decryptor.Decrypt(encdest, plaindest); - - Assert.AreEqual(1ul, plaindest.CoeffCount); - Assert.AreEqual(24ul, plaindest[0]); - - Utilities.AssertThrows(() => { - // Uninitialized memory pool handle - MemoryPoolHandle pool = new MemoryPoolHandle(); - evaluator.MultiplyMany(encrypteds, relinKeys, encdest, pool); - }); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + keygen.CreateRelinKeys(out RelinKeys relinKeys); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext[] encrypteds = new Ciphertext[4]; + Ciphertext encdest = new Ciphertext(); + Plaintext plaindest = new Plaintext(); + + for (int i = 0; i < encrypteds.Length; i++) + { + encrypteds[i] = new Ciphertext(); + encryptor.Encrypt(new Plaintext((i + 1).ToString()), encrypteds[i]); + } + + evaluator.MultiplyMany(encrypteds, relinKeys, encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.AreEqual(1ul, plaindest.CoeffCount); + Assert.AreEqual(24ul, plaindest[0]); + + Utilities.AssertThrows(() => + { + // Uninitialized memory pool handle + MemoryPoolHandle pool = new MemoryPoolHandle(); + evaluator.MultiplyMany(encrypteds, relinKeys, encdest, pool); + }); + } } [TestMethod] public void MultiplyPlainTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 128, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - keygen.CreateRelinKeys(out RelinKeys relinKeys); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted = new Ciphertext(); - Ciphertext encdest = new Ciphertext(); - Plaintext plain = new Plaintext("2x^2 + 1"); - Plaintext plaindest = new Plaintext(); - - encryptor.Encrypt(new Plaintext("3x^2 + 2"), encrypted); - evaluator.MultiplyPlain(encrypted, plain, encdest); - decryptor.Decrypt(encdest, plaindest); - - // {6x^4 + 7x^2 + 2} - Assert.AreEqual(5ul, plaindest.CoeffCount); - Assert.AreEqual(2ul, plaindest[0]); - Assert.AreEqual(0ul, plaindest[1]); - Assert.AreEqual(7ul, plaindest[2]); - Assert.AreEqual(0ul, plaindest[3]); - Assert.AreEqual(6ul, plaindest[4]); - - encryptor.Encrypt(new Plaintext("4x^1 + 3"), encrypted); - plain.Set("2x^2 + 1"); - evaluator.MultiplyPlainInplace(encrypted, plain); - decryptor.Decrypt(encrypted, plaindest); - - // {8x^3 + 6x^2 + 4x^1 + 3} - Assert.AreEqual(4ul, plaindest.CoeffCount); - Assert.AreEqual(3ul, plaindest[0]); - Assert.AreEqual(4ul, plaindest[1]); - Assert.AreEqual(6ul, plaindest[2]); - Assert.AreEqual(8ul, plaindest[3]); - - encryptor.Encrypt(new Plaintext("4x^1 + 3"), encrypted); - plain.Set("3x^5"); - evaluator.MultiplyPlainInplace(encrypted, plain); - decryptor.Decrypt(encrypted, plaindest); - - // {Cx^6 + 9x^5} - Assert.AreEqual(7ul, plaindest.CoeffCount); - Assert.AreEqual(2ul, plaindest.NonZeroCoeffCount); - Assert.AreEqual(0ul, plaindest[0]); - Assert.AreEqual(0ul, plaindest[1]); - Assert.AreEqual(0ul, plaindest[2]); - Assert.AreEqual(0ul, plaindest[3]); - Assert.AreEqual(0ul, plaindest[4]); - Assert.AreEqual(9ul, plaindest[5]); - Assert.AreEqual(12ul, plaindest[6]); - - Utilities.AssertThrows(() => + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + keygen.CreateRelinKeys(out RelinKeys relinKeys); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plain = new Plaintext("2x^2 + 1"); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(new Plaintext("3x^2 + 2"), encrypted); + evaluator.MultiplyPlain(encrypted, plain, encdest); + decryptor.Decrypt(encdest, plaindest); + + // {6x^4 + 7x^2 + 2} + Assert.AreEqual(5ul, plaindest.CoeffCount); + Assert.AreEqual(2ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(7ul, plaindest[2]); + Assert.AreEqual(0ul, plaindest[3]); + Assert.AreEqual(6ul, plaindest[4]); + + encryptor.Encrypt(new Plaintext("4x^1 + 3"), encrypted); + plain.Set("2x^2 + 1"); + evaluator.MultiplyPlainInplace(encrypted, plain); + decryptor.Decrypt(encrypted, plaindest); + + // {8x^3 + 6x^2 + 4x^1 + 3} + Assert.AreEqual(4ul, plaindest.CoeffCount); + Assert.AreEqual(3ul, plaindest[0]); + Assert.AreEqual(4ul, plaindest[1]); + Assert.AreEqual(6ul, plaindest[2]); + Assert.AreEqual(8ul, plaindest[3]); + + encryptor.Encrypt(new Plaintext("4x^1 + 3"), encrypted); + plain.Set("3x^5"); + evaluator.MultiplyPlainInplace(encrypted, plain); + decryptor.Decrypt(encrypted, plaindest); + + // {Cx^6 + 9x^5} + Assert.AreEqual(7ul, plaindest.CoeffCount); + Assert.AreEqual(2ul, plaindest.NonZeroCoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(0ul, plaindest[2]); + Assert.AreEqual(0ul, plaindest[3]); + Assert.AreEqual(0ul, plaindest[4]); + Assert.AreEqual(9ul, plaindest[5]); + Assert.AreEqual(12ul, plaindest[6]); + + Utilities.AssertThrows(() => + { + // Uninitialized pool + MemoryPoolHandle pool = new MemoryPoolHandle(); + evaluator.MultiplyPlain(encrypted, plain, encdest, pool); + }); + } { - // Uninitialized pool - MemoryPoolHandle pool = new MemoryPoolHandle(); - evaluator.MultiplyPlain(encrypted, plain, encdest, pool); - }); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + keygen.CreateRelinKeys(out RelinKeys relinKeys); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plain = new Plaintext("2x^2 + 1"); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(new Plaintext("3x^2 + 2"), encrypted); + evaluator.MultiplyPlain(encrypted, plain, encdest); + decryptor.Decrypt(encdest, plaindest); + + // {6x^4 + 7x^2 + 2} + Assert.AreEqual(5ul, plaindest.CoeffCount); + Assert.AreEqual(2ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(7ul, plaindest[2]); + Assert.AreEqual(0ul, plaindest[3]); + Assert.AreEqual(6ul, plaindest[4]); + + encryptor.Encrypt(new Plaintext("4x^1 + 3"), encrypted); + plain.Set("2x^2 + 1"); + evaluator.MultiplyPlainInplace(encrypted, plain); + decryptor.Decrypt(encrypted, plaindest); + + // {8x^3 + 6x^2 + 4x^1 + 3} + Assert.AreEqual(4ul, plaindest.CoeffCount); + Assert.AreEqual(3ul, plaindest[0]); + Assert.AreEqual(4ul, plaindest[1]); + Assert.AreEqual(6ul, plaindest[2]); + Assert.AreEqual(8ul, plaindest[3]); + + encryptor.Encrypt(new Plaintext("4x^1 + 3"), encrypted); + plain.Set("3x^5"); + evaluator.MultiplyPlainInplace(encrypted, plain); + decryptor.Decrypt(encrypted, plaindest); + + // {Cx^6 + 9x^5} + Assert.AreEqual(7ul, plaindest.CoeffCount); + Assert.AreEqual(2ul, plaindest.NonZeroCoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(0ul, plaindest[2]); + Assert.AreEqual(0ul, plaindest[3]); + Assert.AreEqual(0ul, plaindest[4]); + Assert.AreEqual(9ul, plaindest[5]); + Assert.AreEqual(12ul, plaindest[6]); + + Utilities.AssertThrows(() => + { + // Uninitialized pool + MemoryPoolHandle pool = new MemoryPoolHandle(); + evaluator.MultiplyPlain(encrypted, plain, encdest, pool); + }); + } } [TestMethod] public void SquareTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 128, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted = new Ciphertext(); - Ciphertext encdest = new Ciphertext(); - Plaintext plain = new Plaintext("2x^2 + 3x^1 + 4"); - Plaintext plaindest = new Plaintext(); - - encryptor.Encrypt(plain, encrypted); - evaluator.Square(encrypted, encdest); - decryptor.Decrypt(encdest, plaindest); - - // {4x^4 + Cx^3 + 19x^2 + 18x^1 + 10} - Assert.AreEqual(5ul, plaindest.CoeffCount); - Assert.AreEqual(16ul, plaindest[0]); - Assert.AreEqual(24ul, plaindest[1]); - Assert.AreEqual(25ul, plaindest[2]); - Assert.AreEqual(12ul, plaindest[3]); - Assert.AreEqual(4ul, plaindest[4]); - - encryptor.Encrypt(new Plaintext("3x^1 + 2"), encrypted); - evaluator.SquareInplace(encrypted); - decryptor.Decrypt(encrypted, plaindest); - - // {9x^2 + Cx^1 + 4} - Assert.AreEqual(3ul, plaindest.CoeffCount); - Assert.AreEqual(4ul, plaindest[0]); - Assert.AreEqual(12ul, plaindest[1]); - Assert.AreEqual(9ul, plaindest[2]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plain = new Plaintext("2x^2 + 3x^1 + 4"); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(plain, encrypted); + evaluator.Square(encrypted, encdest); + decryptor.Decrypt(encdest, plaindest); + + // {4x^4 + Cx^3 + 19x^2 + 18x^1 + 10} + Assert.AreEqual(5ul, plaindest.CoeffCount); + Assert.AreEqual(16ul, plaindest[0]); + Assert.AreEqual(24ul, plaindest[1]); + Assert.AreEqual(25ul, plaindest[2]); + Assert.AreEqual(12ul, plaindest[3]); + Assert.AreEqual(4ul, plaindest[4]); + + encryptor.Encrypt(new Plaintext("3x^1 + 2"), encrypted); + evaluator.SquareInplace(encrypted); + decryptor.Decrypt(encrypted, plaindest); + + // {9x^2 + Cx^1 + 4} + Assert.AreEqual(3ul, plaindest.CoeffCount); + Assert.AreEqual(4ul, plaindest[0]); + Assert.AreEqual(12ul, plaindest[1]); + Assert.AreEqual(9ul, plaindest[2]); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plain = new Plaintext("2x^2 + 3x^1 + 4"); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(plain, encrypted); + evaluator.Square(encrypted, encdest); + decryptor.Decrypt(encdest, plaindest); + + // {4x^4 + Cx^3 + 19x^2 + 18x^1 + 10} + Assert.AreEqual(5ul, plaindest.CoeffCount); + Assert.AreEqual(16ul, plaindest[0]); + Assert.AreEqual(24ul, plaindest[1]); + Assert.AreEqual(25ul, plaindest[2]); + Assert.AreEqual(12ul, plaindest[3]); + Assert.AreEqual(4ul, plaindest[4]); + + encryptor.Encrypt(new Plaintext("3x^1 + 2"), encrypted); + evaluator.SquareInplace(encrypted); + decryptor.Decrypt(encrypted, plaindest); + + // {9x^2 + Cx^1 + 4} + Assert.AreEqual(3ul, plaindest.CoeffCount); + Assert.AreEqual(4ul, plaindest[0]); + Assert.AreEqual(12ul, plaindest[1]); + Assert.AreEqual(9ul, plaindest[2]); + } } [TestMethod] public void ExponentiateTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 128, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - keygen.CreateRelinKeys(out RelinKeys relinKeys); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted = new Ciphertext(); - Ciphertext encdest = new Ciphertext(); - Plaintext plain = new Plaintext(); - - encryptor.Encrypt(new Plaintext("2x^2 + 1"), encrypted); - evaluator.Exponentiate(encrypted, 3, relinKeys, encdest); - decryptor.Decrypt(encdest, plain); - - // {8x^6 + Cx^4 + 6x^2 + 1} - Assert.AreEqual(7ul, plain.CoeffCount); - Assert.AreEqual(1ul, plain[0]); - Assert.AreEqual(0ul, plain[1]); - Assert.AreEqual(6ul, plain[2]); - Assert.AreEqual(0ul, plain[3]); - Assert.AreEqual(12ul, plain[4]); - Assert.AreEqual(0ul, plain[5]); - Assert.AreEqual(8ul, plain[6]); - - encryptor.Encrypt(new Plaintext("3x^3 + 2"), encrypted); - evaluator.ExponentiateInplace(encrypted, 4, relinKeys); - decryptor.Decrypt(encrypted, plain); - - // {11x^12 + 18x^9 + 18x^6 + 20x^3 + 10} - Assert.AreEqual(13ul, plain.CoeffCount); - Assert.AreEqual(16ul, plain[0]); - Assert.AreEqual(0ul, plain[1]); - Assert.AreEqual(0ul, plain[2]); - Assert.AreEqual(32ul, plain[3]); - Assert.AreEqual(0ul, plain[4]); - Assert.AreEqual(0ul, plain[5]); - Assert.AreEqual(24ul, plain[6]); - Assert.AreEqual(0ul, plain[7]); - Assert.AreEqual(0ul, plain[8]); - Assert.AreEqual(24ul, plain[9]); - Assert.AreEqual(0ul, plain[10]); - Assert.AreEqual(0ul, plain[11]); - Assert.AreEqual(17ul, plain[12]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + keygen.CreateRelinKeys(out RelinKeys relinKeys); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plain = new Plaintext(); + + encryptor.Encrypt(new Plaintext("2x^2 + 1"), encrypted); + evaluator.Exponentiate(encrypted, 3, relinKeys, encdest); + decryptor.Decrypt(encdest, plain); + + // {8x^6 + Cx^4 + 6x^2 + 1} + Assert.AreEqual(7ul, plain.CoeffCount); + Assert.AreEqual(1ul, plain[0]); + Assert.AreEqual(0ul, plain[1]); + Assert.AreEqual(6ul, plain[2]); + Assert.AreEqual(0ul, plain[3]); + Assert.AreEqual(12ul, plain[4]); + Assert.AreEqual(0ul, plain[5]); + Assert.AreEqual(8ul, plain[6]); + + encryptor.Encrypt(new Plaintext("3x^3 + 2"), encrypted); + evaluator.ExponentiateInplace(encrypted, 4, relinKeys); + decryptor.Decrypt(encrypted, plain); + + // {11x^12 + 18x^9 + 18x^6 + 20x^3 + 10} + Assert.AreEqual(13ul, plain.CoeffCount); + Assert.AreEqual(16ul, plain[0]); + Assert.AreEqual(0ul, plain[1]); + Assert.AreEqual(0ul, plain[2]); + Assert.AreEqual(32ul, plain[3]); + Assert.AreEqual(0ul, plain[4]); + Assert.AreEqual(0ul, plain[5]); + Assert.AreEqual(24ul, plain[6]); + Assert.AreEqual(0ul, plain[7]); + Assert.AreEqual(0ul, plain[8]); + Assert.AreEqual(24ul, plain[9]); + Assert.AreEqual(0ul, plain[10]); + Assert.AreEqual(0ul, plain[11]); + Assert.AreEqual(17ul, plain[12]); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + keygen.CreateRelinKeys(out RelinKeys relinKeys); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plain = new Plaintext(); + + encryptor.Encrypt(new Plaintext("2x^2 + 1"), encrypted); + evaluator.Exponentiate(encrypted, 3, relinKeys, encdest); + decryptor.Decrypt(encdest, plain); + + // {8x^6 + Cx^4 + 6x^2 + 1} + Assert.AreEqual(7ul, plain.CoeffCount); + Assert.AreEqual(1ul, plain[0]); + Assert.AreEqual(0ul, plain[1]); + Assert.AreEqual(6ul, plain[2]); + Assert.AreEqual(0ul, plain[3]); + Assert.AreEqual(12ul, plain[4]); + Assert.AreEqual(0ul, plain[5]); + Assert.AreEqual(8ul, plain[6]); + + encryptor.Encrypt(new Plaintext("3x^3 + 2"), encrypted); + evaluator.ExponentiateInplace(encrypted, 4, relinKeys); + decryptor.Decrypt(encrypted, plain); + + // {11x^12 + 18x^9 + 18x^6 + 20x^3 + 10} + Assert.AreEqual(13ul, plain.CoeffCount); + Assert.AreEqual(16ul, plain[0]); + Assert.AreEqual(0ul, plain[1]); + Assert.AreEqual(0ul, plain[2]); + Assert.AreEqual(32ul, plain[3]); + Assert.AreEqual(0ul, plain[4]); + Assert.AreEqual(0ul, plain[5]); + Assert.AreEqual(24ul, plain[6]); + Assert.AreEqual(0ul, plain[7]); + Assert.AreEqual(0ul, plain[8]); + Assert.AreEqual(24ul, plain[9]); + Assert.AreEqual(0ul, plain[10]); + Assert.AreEqual(0ul, plain[11]); + Assert.AreEqual(17ul, plain[12]); + } } [TestMethod] public void ApplyGaloisTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 8, - PlainModulus = new Modulus(257), - CoeffModulus = CoeffModulus.Create(8, new int[] { 40, 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - keygen.CreateGaloisKeys(galoisElts: new uint[] { 1u, 3u, 5u, 15u }, out GaloisKeys galoisKeys); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Plaintext plain = new Plaintext("1"); - Plaintext plaindest = new Plaintext(); - Ciphertext encrypted = new Ciphertext(); - Ciphertext encdest = new Ciphertext(); - - encryptor.Encrypt(plain, encrypted); - evaluator.ApplyGalois(encrypted, galoisElt: 1, galoisKeys: galoisKeys, destination: encdest); - decryptor.Decrypt(encdest, plaindest); - - Assert.AreEqual(1ul, plaindest.CoeffCount); - Assert.AreEqual(1ul, plaindest[0]); - - plain.Set("1x^1"); - encryptor.Encrypt(plain, encrypted); - evaluator.ApplyGalois(encrypted, galoisElt: 1, galoisKeys: galoisKeys, destination: encdest); - decryptor.Decrypt(encdest, plaindest); - - // {1x^1} - Assert.AreEqual(2ul, plaindest.CoeffCount); - Assert.AreEqual(0ul, plaindest[0]); - Assert.AreEqual(1ul, plaindest[1]); - - evaluator.ApplyGalois(encdest, galoisElt: 3, galoisKeys: galoisKeys, destination: encrypted); - decryptor.Decrypt(encrypted, plaindest); - - // {1x^3} - Assert.AreEqual(4ul, plaindest.CoeffCount); - Assert.AreEqual(0ul, plaindest[0]); - Assert.AreEqual(0ul, plaindest[1]); - Assert.AreEqual(0ul, plaindest[2]); - Assert.AreEqual(1ul, plaindest[3]); - - evaluator.ApplyGalois(encrypted, galoisElt: 5, galoisKeys: galoisKeys, destination: encdest); - decryptor.Decrypt(encdest, plaindest); - - // {100x^7} - Assert.AreEqual(8ul, plaindest.CoeffCount); - Assert.AreEqual(0ul, plaindest[0]); - Assert.AreEqual(0ul, plaindest[1]); - Assert.AreEqual(0ul, plaindest[2]); - Assert.AreEqual(0ul, plaindest[3]); - Assert.AreEqual(0ul, plaindest[4]); - Assert.AreEqual(0ul, plaindest[5]); - Assert.AreEqual(0ul, plaindest[6]); - Assert.AreEqual(256ul, plaindest[7]); - - plain.Set("1x^2"); - encryptor.Encrypt(plain, encrypted); - evaluator.ApplyGaloisInplace(encrypted, 1, galoisKeys); - decryptor.Decrypt(encrypted, plaindest); - - // {1x^2} - Assert.AreEqual(3ul, plaindest.CoeffCount); - Assert.AreEqual(0ul, plaindest[0]); - Assert.AreEqual(0ul, plaindest[1]); - Assert.AreEqual(1ul, plaindest[2]); - - evaluator.ApplyGaloisInplace(encrypted, 3, galoisKeys); - decryptor.Decrypt(encrypted, plaindest); - - // {1x^6} - Assert.AreEqual(7ul, plaindest.CoeffCount); - Assert.AreEqual(0ul, plaindest[0]); - Assert.AreEqual(0ul, plaindest[1]); - Assert.AreEqual(0ul, plaindest[2]); - Assert.AreEqual(0ul, plaindest[3]); - Assert.AreEqual(0ul, plaindest[4]); - Assert.AreEqual(0ul, plaindest[5]); - Assert.AreEqual(1ul, plaindest[6]); - - evaluator.ApplyGaloisInplace(encrypted, 5, galoisKeys); - decryptor.Decrypt(encrypted, plaindest); - - // {100x^6} - Assert.AreEqual(7ul, plaindest.CoeffCount); - Assert.AreEqual(0ul, plaindest[0]); - Assert.AreEqual(0ul, plaindest[1]); - Assert.AreEqual(0ul, plaindest[2]); - Assert.AreEqual(0ul, plaindest[3]); - Assert.AreEqual(0ul, plaindest[4]); - Assert.AreEqual(0ul, plaindest[5]); - Assert.AreEqual(256ul, plaindest[6]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 8, + PlainModulus = new Modulus(257), + CoeffModulus = CoeffModulus.Create(8, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + keygen.CreateGaloisKeys(galoisElts: new uint[] { 1u, 3u, 5u, 15u }, out GaloisKeys galoisKeys); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Plaintext plain = new Plaintext("1"); + Plaintext plaindest = new Plaintext(); + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + + encryptor.Encrypt(plain, encrypted); + evaluator.ApplyGalois(encrypted, galoisElt: 1, galoisKeys: galoisKeys, destination: encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.AreEqual(1ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + + plain.Set("1x^1"); + encryptor.Encrypt(plain, encrypted); + evaluator.ApplyGalois(encrypted, galoisElt: 1, galoisKeys: galoisKeys, destination: encdest); + decryptor.Decrypt(encdest, plaindest); + + // {1x^1} + Assert.AreEqual(2ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(1ul, plaindest[1]); + + evaluator.ApplyGalois(encdest, galoisElt: 3, galoisKeys: galoisKeys, destination: encrypted); + decryptor.Decrypt(encrypted, plaindest); + + // {1x^3} + Assert.AreEqual(4ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(0ul, plaindest[2]); + Assert.AreEqual(1ul, plaindest[3]); + + evaluator.ApplyGalois(encrypted, galoisElt: 5, galoisKeys: galoisKeys, destination: encdest); + decryptor.Decrypt(encdest, plaindest); + + // {100x^7} + Assert.AreEqual(8ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(0ul, plaindest[2]); + Assert.AreEqual(0ul, plaindest[3]); + Assert.AreEqual(0ul, plaindest[4]); + Assert.AreEqual(0ul, plaindest[5]); + Assert.AreEqual(0ul, plaindest[6]); + Assert.AreEqual(256ul, plaindest[7]); + + plain.Set("1x^2"); + encryptor.Encrypt(plain, encrypted); + evaluator.ApplyGaloisInplace(encrypted, 1, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + + // {1x^2} + Assert.AreEqual(3ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(1ul, plaindest[2]); + + evaluator.ApplyGaloisInplace(encrypted, 3, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + + // {1x^6} + Assert.AreEqual(7ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(0ul, plaindest[2]); + Assert.AreEqual(0ul, plaindest[3]); + Assert.AreEqual(0ul, plaindest[4]); + Assert.AreEqual(0ul, plaindest[5]); + Assert.AreEqual(1ul, plaindest[6]); + + evaluator.ApplyGaloisInplace(encrypted, 5, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + + // {100x^6} + Assert.AreEqual(7ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(0ul, plaindest[2]); + Assert.AreEqual(0ul, plaindest[3]); + Assert.AreEqual(0ul, plaindest[4]); + Assert.AreEqual(0ul, plaindest[5]); + Assert.AreEqual(256ul, plaindest[6]); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 8, + PlainModulus = new Modulus(257), + CoeffModulus = CoeffModulus.Create(8, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + keygen.CreateGaloisKeys(galoisElts: new uint[] { 1u, 3u, 5u, 15u }, out GaloisKeys galoisKeys); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Plaintext plain = new Plaintext("1"); + Plaintext plaindest = new Plaintext(); + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + + encryptor.Encrypt(plain, encrypted); + evaluator.ApplyGalois(encrypted, galoisElt: 1, galoisKeys: galoisKeys, destination: encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.AreEqual(1ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + + plain.Set("1x^1"); + encryptor.Encrypt(plain, encrypted); + evaluator.ApplyGalois(encrypted, galoisElt: 1, galoisKeys: galoisKeys, destination: encdest); + decryptor.Decrypt(encdest, plaindest); + + // {1x^1} + Assert.AreEqual(2ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(1ul, plaindest[1]); + + evaluator.ApplyGalois(encdest, galoisElt: 3, galoisKeys: galoisKeys, destination: encrypted); + decryptor.Decrypt(encrypted, plaindest); + + // {1x^3} + Assert.AreEqual(4ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(0ul, plaindest[2]); + Assert.AreEqual(1ul, plaindest[3]); + + evaluator.ApplyGalois(encrypted, galoisElt: 5, galoisKeys: galoisKeys, destination: encdest); + decryptor.Decrypt(encdest, plaindest); + + // {100x^7} + Assert.AreEqual(8ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(0ul, plaindest[2]); + Assert.AreEqual(0ul, plaindest[3]); + Assert.AreEqual(0ul, plaindest[4]); + Assert.AreEqual(0ul, plaindest[5]); + Assert.AreEqual(0ul, plaindest[6]); + Assert.AreEqual(256ul, plaindest[7]); + + plain.Set("1x^2"); + encryptor.Encrypt(plain, encrypted); + evaluator.ApplyGaloisInplace(encrypted, 1, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + + // {1x^2} + Assert.AreEqual(3ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(1ul, plaindest[2]); + + evaluator.ApplyGaloisInplace(encrypted, 3, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + + // {1x^6} + Assert.AreEqual(7ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(0ul, plaindest[2]); + Assert.AreEqual(0ul, plaindest[3]); + Assert.AreEqual(0ul, plaindest[4]); + Assert.AreEqual(0ul, plaindest[5]); + Assert.AreEqual(1ul, plaindest[6]); + + evaluator.ApplyGaloisInplace(encrypted, 5, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + + // {100x^6} + Assert.AreEqual(7ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(0ul, plaindest[1]); + Assert.AreEqual(0ul, plaindest[2]); + Assert.AreEqual(0ul, plaindest[3]); + Assert.AreEqual(0ul, plaindest[4]); + Assert.AreEqual(0ul, plaindest[5]); + Assert.AreEqual(256ul, plaindest[6]); + } } [TestMethod] public void TransformPlainToNTTTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 128, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - Evaluator evaluator = new Evaluator(context); - - Plaintext plain = new Plaintext("0"); - Plaintext plaindest = new Plaintext(); - Assert.IsFalse(plain.IsNTTForm); - - evaluator.TransformToNTT(plain, context.FirstParmsId, plaindest); - Assert.IsTrue(plaindest.IsZero); - Assert.IsTrue(plaindest.IsNTTForm); - Assert.IsTrue(plaindest.ParmsId == context.FirstParmsId); - - plain = new Plaintext("1"); - Assert.IsFalse(plain.IsNTTForm); - - evaluator.TransformToNTTInplace(plain, context.FirstParmsId); - Assert.IsTrue(plain.IsNTTForm); - - for (ulong i = 0; i < 256; i++) + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + Evaluator evaluator = new Evaluator(context); + + Plaintext plain = new Plaintext("0"); + Plaintext plaindest = new Plaintext(); + Assert.IsFalse(plain.IsNTTForm); + + evaluator.TransformToNTT(plain, context.FirstParmsId, plaindest); + Assert.IsTrue(plaindest.IsZero); + Assert.IsTrue(plaindest.IsNTTForm); + Assert.IsTrue(plaindest.ParmsId == context.FirstParmsId); + + plain = new Plaintext("1"); + Assert.IsFalse(plain.IsNTTForm); + + evaluator.TransformToNTTInplace(plain, context.FirstParmsId); + Assert.IsTrue(plain.IsNTTForm); + + for (ulong i = 0; i < 256; i++) + { + Assert.AreEqual(1ul, plain[i]); + } + } { - Assert.AreEqual(1ul, plain[i]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + Evaluator evaluator = new Evaluator(context); + + Plaintext plain = new Plaintext("0"); + Plaintext plaindest = new Plaintext(); + Assert.IsFalse(plain.IsNTTForm); + + evaluator.TransformToNTT(plain, context.FirstParmsId, plaindest); + Assert.IsTrue(plaindest.IsZero); + Assert.IsTrue(plaindest.IsNTTForm); + Assert.IsTrue(plaindest.ParmsId == context.FirstParmsId); + + plain = new Plaintext("1"); + Assert.IsFalse(plain.IsNTTForm); + + evaluator.TransformToNTTInplace(plain, context.FirstParmsId); + Assert.IsTrue(plain.IsNTTForm); + + for (ulong i = 0; i < 256; i++) + { + Assert.AreEqual(1ul, plain[i]); + } } } [TestMethod] public void TransformEncryptedToNTTTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 128, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted = new Ciphertext(); - Ciphertext encdest = new Ciphertext(); - Ciphertext encdest2 = new Ciphertext(); - Plaintext plaindest = new Plaintext(); - - encryptor.Encrypt(new Plaintext("0"), encrypted); - Assert.IsFalse(encrypted.IsNTTForm); - - evaluator.TransformToNTT(encrypted, encdest); - Assert.IsTrue(encdest.IsNTTForm); - - evaluator.TransformFromNTT(encdest, encdest2); - Assert.IsFalse(encdest2.IsNTTForm); - - decryptor.Decrypt(encdest2, plaindest); - Assert.AreEqual(1ul, plaindest.CoeffCount); - Assert.AreEqual(0ul, plaindest[0]); - Assert.AreEqual(context.FirstParmsId, encdest2.ParmsId); - - encryptor.Encrypt(new Plaintext("1"), encrypted); - Assert.IsFalse(encrypted.IsNTTForm); - - evaluator.TransformToNTTInplace(encrypted); - Assert.IsTrue(encrypted.IsNTTForm); - - evaluator.TransformFromNTTInplace(encrypted); - Assert.IsFalse(encrypted.IsNTTForm); - - decryptor.Decrypt(encrypted, plaindest); - - Assert.AreEqual(1ul, plaindest.CoeffCount); - Assert.AreEqual(1ul, plaindest[0]); - Assert.AreEqual(context.FirstParmsId, encrypted.ParmsId); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Ciphertext encdest2 = new Ciphertext(); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(new Plaintext("0"), encrypted); + Assert.IsFalse(encrypted.IsNTTForm); + + evaluator.TransformToNTT(encrypted, encdest); + Assert.IsTrue(encdest.IsNTTForm); + + evaluator.TransformFromNTT(encdest, encdest2); + Assert.IsFalse(encdest2.IsNTTForm); + + decryptor.Decrypt(encdest2, plaindest); + Assert.AreEqual(1ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(context.FirstParmsId, encdest2.ParmsId); + + encryptor.Encrypt(new Plaintext("1"), encrypted); + Assert.IsFalse(encrypted.IsNTTForm); + + evaluator.TransformToNTTInplace(encrypted); + Assert.IsTrue(encrypted.IsNTTForm); + + evaluator.TransformFromNTTInplace(encrypted); + Assert.IsFalse(encrypted.IsNTTForm); + + decryptor.Decrypt(encrypted, plaindest); + + Assert.AreEqual(1ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + Assert.AreEqual(context.FirstParmsId, encrypted.ParmsId); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Ciphertext encdest2 = new Ciphertext(); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(new Plaintext("0"), encrypted); + Assert.IsFalse(encrypted.IsNTTForm); + + evaluator.TransformToNTT(encrypted, encdest); + Assert.IsTrue(encdest.IsNTTForm); + + evaluator.TransformFromNTT(encdest, encdest2); + Assert.IsFalse(encdest2.IsNTTForm); + + decryptor.Decrypt(encdest2, plaindest); + Assert.AreEqual(1ul, plaindest.CoeffCount); + Assert.AreEqual(0ul, plaindest[0]); + Assert.AreEqual(context.FirstParmsId, encdest2.ParmsId); + + encryptor.Encrypt(new Plaintext("1"), encrypted); + Assert.IsFalse(encrypted.IsNTTForm); + + evaluator.TransformToNTTInplace(encrypted); + Assert.IsTrue(encrypted.IsNTTForm); + + evaluator.TransformFromNTTInplace(encrypted); + Assert.IsFalse(encrypted.IsNTTForm); + + decryptor.Decrypt(encrypted, plaindest); + + Assert.AreEqual(1ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + Assert.AreEqual(context.FirstParmsId, encrypted.ParmsId); + } } [TestMethod] public void ModSwitchToNextTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 128, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: true, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted = new Ciphertext(context); - Ciphertext encdest = new Ciphertext(); - Plaintext plain = new Plaintext(); - - plain.Set("0"); - encryptor.Encrypt(plain, encrypted); - evaluator.ModSwitchToNext(encrypted, encdest); - decryptor.Decrypt(encdest, plain); - - Assert.AreEqual(1ul, plain.CoeffCount); - Assert.AreEqual(0ul, plain[0]); - - plain.Set("1"); - encryptor.Encrypt(plain, encrypted); - evaluator.ModSwitchToNextInplace(encrypted); - decryptor.Decrypt(encrypted, plain); - - Assert.AreEqual(1ul, plain.CoeffCount); - Assert.AreEqual(1ul, plain[0]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: true, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(context); + Ciphertext encdest = new Ciphertext(); + Plaintext plain = new Plaintext(); + + plain.Set("0"); + encryptor.Encrypt(plain, encrypted); + evaluator.ModSwitchToNext(encrypted, encdest); + decryptor.Decrypt(encdest, plain); + + Assert.AreEqual(1ul, plain.CoeffCount); + Assert.AreEqual(0ul, plain[0]); + + plain.Set("1"); + encryptor.Encrypt(plain, encrypted); + evaluator.ModSwitchToNextInplace(encrypted); + decryptor.Decrypt(encrypted, plain); + + Assert.AreEqual(1ul, plain.CoeffCount); + Assert.AreEqual(1ul, plain[0]); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: true, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(context); + Ciphertext encdest = new Ciphertext(); + Plaintext plain = new Plaintext(); + + plain.Set("0"); + encryptor.Encrypt(plain, encrypted); + evaluator.ModSwitchToNext(encrypted, encdest); + decryptor.Decrypt(encdest, plain); + + Assert.AreEqual(1ul, plain.CoeffCount); + Assert.AreEqual(0ul, plain[0]); + + plain.Set("1"); + encryptor.Encrypt(plain, encrypted); + evaluator.ModSwitchToNextInplace(encrypted); + decryptor.Decrypt(encrypted, plain); + + Assert.AreEqual(1ul, plain.CoeffCount); + Assert.AreEqual(1ul, plain[0]); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 8192, + PlainModulus = new Modulus(786433), + CoeffModulus = CoeffModulus.BGVDefault(8192) + }; + SEALContext context = new SEALContext(parms, + expandModChain: true, + secLevel: SecLevelType.TC128); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(context); + Ciphertext encdest_1 = new Ciphertext(); + Ciphertext encdest_2 = new Ciphertext(); + Plaintext plain = new Plaintext(); + + plain.Set("1"); + encryptor.Encrypt(plain, encrypted); + evaluator.ModSwitchToNext(encrypted, encdest_1); + evaluator.ModSwitchToNext(encdest_1, encdest_2); + decryptor.Decrypt(encdest_2, plain); + Assert.AreEqual(1ul, plain.CoeffCount); + Assert.AreEqual(1ul, plain[0]); + } } [TestMethod] public void ModSwitchToTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 128, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30, 30 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: true, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted = new Ciphertext(context); - Ciphertext encdest = new Ciphertext(context); - Plaintext plaindest = new Plaintext(); - - encryptor.Encrypt(new Plaintext("1"), encrypted); - ParmsId destParmsId = context.FirstContextData.NextContextData - .NextContextData.ParmsId; - - evaluator.ModSwitchTo(encrypted, context.FirstParmsId, encdest); - decryptor.Decrypt(encdest, plaindest); - - Assert.IsTrue(encrypted.ParmsId == context.FirstParmsId); - Assert.IsTrue(encdest.ParmsId == context.FirstParmsId); - Assert.AreEqual(1ul, plaindest.CoeffCount); - Assert.AreEqual(1ul, plaindest[0]); - - evaluator.ModSwitchTo(encrypted, destParmsId, encdest); - decryptor.Decrypt(encdest, plaindest); - - Assert.IsTrue(encrypted.ParmsId == context.FirstParmsId); - Assert.IsTrue(encdest.ParmsId == destParmsId); - Assert.AreEqual(1ul, plaindest.CoeffCount); - Assert.AreEqual(1ul, plaindest[0]); - - encryptor.Encrypt(new Plaintext("3x^2 + 2x^1 + 1"), encrypted); - evaluator.ModSwitchToInplace(encrypted, context.FirstParmsId); - decryptor.Decrypt(encrypted, plaindest); - - Assert.IsTrue(encrypted.ParmsId == context.FirstParmsId); - Assert.AreEqual(3ul, plaindest.CoeffCount); - Assert.AreEqual(1ul, plaindest[0]); - Assert.AreEqual(2ul, plaindest[1]); - Assert.AreEqual(3ul, plaindest[2]); - - evaluator.ModSwitchToInplace(encrypted, destParmsId); - decryptor.Decrypt(encrypted, plaindest); - - Assert.IsTrue(encrypted.ParmsId == destParmsId); - Assert.AreEqual(3ul, plaindest.CoeffCount); - Assert.AreEqual(1ul, plaindest[0]); - Assert.AreEqual(2ul, plaindest[1]); - Assert.AreEqual(3ul, plaindest[2]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30, 30 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: true, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(context); + Ciphertext encdest = new Ciphertext(context); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(new Plaintext("1"), encrypted); + ParmsId destParmsId = context.FirstContextData.NextContextData + .NextContextData.ParmsId; + + evaluator.ModSwitchTo(encrypted, context.FirstParmsId, encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.IsTrue(encrypted.ParmsId == context.FirstParmsId); + Assert.IsTrue(encdest.ParmsId == context.FirstParmsId); + Assert.AreEqual(1ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + + evaluator.ModSwitchTo(encrypted, destParmsId, encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.IsTrue(encrypted.ParmsId == context.FirstParmsId); + Assert.IsTrue(encdest.ParmsId == destParmsId); + Assert.AreEqual(1ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + + encryptor.Encrypt(new Plaintext("3x^2 + 2x^1 + 1"), encrypted); + evaluator.ModSwitchToInplace(encrypted, context.FirstParmsId); + decryptor.Decrypt(encrypted, plaindest); + + Assert.IsTrue(encrypted.ParmsId == context.FirstParmsId); + Assert.AreEqual(3ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + Assert.AreEqual(2ul, plaindest[1]); + Assert.AreEqual(3ul, plaindest[2]); + + evaluator.ModSwitchToInplace(encrypted, destParmsId); + decryptor.Decrypt(encrypted, plaindest); + + Assert.IsTrue(encrypted.ParmsId == destParmsId); + Assert.AreEqual(3ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + Assert.AreEqual(2ul, plaindest[1]); + Assert.AreEqual(3ul, plaindest[2]); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30, 30 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: true, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted = new Ciphertext(context); + Ciphertext encdest = new Ciphertext(context); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(new Plaintext("1"), encrypted); + ParmsId destParmsId = context.FirstContextData.NextContextData + .NextContextData.ParmsId; + + evaluator.ModSwitchTo(encrypted, context.FirstParmsId, encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.IsTrue(encrypted.ParmsId == context.FirstParmsId); + Assert.IsTrue(encdest.ParmsId == context.FirstParmsId); + Assert.AreEqual(1ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + + evaluator.ModSwitchTo(encrypted, destParmsId, encdest); + decryptor.Decrypt(encdest, plaindest); + + Assert.IsTrue(encrypted.ParmsId == context.FirstParmsId); + Assert.IsTrue(encdest.ParmsId == destParmsId); + Assert.AreEqual(1ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + + encryptor.Encrypt(new Plaintext("3x^2 + 2x^1 + 1"), encrypted); + evaluator.ModSwitchToInplace(encrypted, context.FirstParmsId); + decryptor.Decrypt(encrypted, plaindest); + + Assert.IsTrue(encrypted.ParmsId == context.FirstParmsId); + Assert.AreEqual(3ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + Assert.AreEqual(2ul, plaindest[1]); + Assert.AreEqual(3ul, plaindest[2]); + + evaluator.ModSwitchToInplace(encrypted, destParmsId); + decryptor.Decrypt(encrypted, plaindest); + + Assert.IsTrue(encrypted.ParmsId == destParmsId); + Assert.AreEqual(3ul, plaindest.CoeffCount); + Assert.AreEqual(1ul, plaindest[0]); + Assert.AreEqual(2ul, plaindest[1]); + Assert.AreEqual(3ul, plaindest[2]); + } } [TestMethod] @@ -944,128 +1810,253 @@ public void ModSwitchToPlainTest() [TestMethod] public void RotateMatrixTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 8, - PlainModulus = new Modulus(257), - CoeffModulus = CoeffModulus.Create(8, new int[] { 40, 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreateGaloisKeys(out GaloisKeys galoisKeys); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - BatchEncoder encoder = new BatchEncoder(context); - - Plaintext plain = new Plaintext(); - List vec = new List - { - 1, 2, 3, 4, - 5, 6, 7, 8 - }; - - encoder.Encode(vec, plain); - - Ciphertext encrypted = new Ciphertext(); - Ciphertext encdest = new Ciphertext(); - Plaintext plaindest = new Plaintext(); - - encryptor.Encrypt(plain, encrypted); - evaluator.RotateColumns(encrypted, galoisKeys, encdest); - decryptor.Decrypt(encdest, plaindest); - encoder.Decode(plaindest, vec); - - Assert.IsTrue(AreCollectionsEqual(vec, new List - { - 5, 6, 7, 8, - 1, 2, 3, 4 - })); - - evaluator.RotateRows(encdest, -1, galoisKeys, encrypted); - decryptor.Decrypt(encrypted, plaindest); - encoder.Decode(plaindest, vec); - - Assert.IsTrue(AreCollectionsEqual(vec, new List - { - 8, 5, 6, 7, - 4, 1, 2, 3 - })); - - evaluator.RotateRowsInplace(encrypted, 2, galoisKeys); - decryptor.Decrypt(encrypted, plaindest); - encoder.Decode(plaindest, vec); - - Assert.IsTrue(AreCollectionsEqual(vec, new List - { - 6, 7, 8, 5, - 2, 3, 4, 1 - })); - - evaluator.RotateColumnsInplace(encrypted, galoisKeys); - decryptor.Decrypt(encrypted, plaindest); - encoder.Decode(plaindest, vec); - - Assert.IsTrue(AreCollectionsEqual(vec, new List + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 8, + PlainModulus = new Modulus(257), + CoeffModulus = CoeffModulus.Create(8, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(out GaloisKeys galoisKeys); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + BatchEncoder encoder = new BatchEncoder(context); + + Plaintext plain = new Plaintext(); + List vec = new List + { + 1, 2, 3, 4, + 5, 6, 7, 8 + }; + + encoder.Encode(vec, plain); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(plain, encrypted); + evaluator.RotateColumns(encrypted, galoisKeys, encdest); + decryptor.Decrypt(encdest, plaindest); + encoder.Decode(plaindest, vec); + + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 5, 6, 7, 8, + 1, 2, 3, 4 + })); + + evaluator.RotateRows(encdest, -1, galoisKeys, encrypted); + decryptor.Decrypt(encrypted, plaindest); + encoder.Decode(plaindest, vec); + + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 8, 5, 6, 7, + 4, 1, 2, 3 + })); + + evaluator.RotateRowsInplace(encrypted, 2, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + encoder.Decode(plaindest, vec); + + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 6, 7, 8, 5, + 2, 3, 4, 1 + })); + + evaluator.RotateColumnsInplace(encrypted, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + encoder.Decode(plaindest, vec); + + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 2, 3, 4, 1, + 6, 7, 8, 5 + })); + } { - 2, 3, 4, 1, - 6, 7, 8, 5 - })); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 8, + PlainModulus = new Modulus(257), + CoeffModulus = CoeffModulus.Create(8, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(out GaloisKeys galoisKeys); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + BatchEncoder encoder = new BatchEncoder(context); + + Plaintext plain = new Plaintext(); + List vec = new List + { + 1, 2, 3, 4, + 5, 6, 7, 8 + }; + + encoder.Encode(vec, plain); + + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(plain, encrypted); + evaluator.RotateColumns(encrypted, galoisKeys, encdest); + decryptor.Decrypt(encdest, plaindest); + encoder.Decode(plaindest, vec); + + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 5, 6, 7, 8, + 1, 2, 3, 4 + })); + + evaluator.RotateRows(encdest, -1, galoisKeys, encrypted); + decryptor.Decrypt(encrypted, plaindest); + encoder.Decode(plaindest, vec); + + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 8, 5, 6, 7, + 4, 1, 2, 3 + })); + + evaluator.RotateRowsInplace(encrypted, 2, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + encoder.Decode(plaindest, vec); + + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 6, 7, 8, 5, + 2, 3, 4, 1 + })); + + evaluator.RotateColumnsInplace(encrypted, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + encoder.Decode(plaindest, vec); + + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 2, 3, 4, 1, + 6, 7, 8, 5 + })); + } } [TestMethod] public void RelinearizeTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 128, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreateRelinKeys(out RelinKeys relinKeys); - keygen.CreatePublicKey(out PublicKey publicKey); - - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted1 = new Ciphertext(context); - Ciphertext encrypted2 = new Ciphertext(context); - Plaintext plain1 = new Plaintext(); - Plaintext plain2 = new Plaintext(); - - plain1.Set(0); - encryptor.Encrypt(plain1, encrypted1); - evaluator.SquareInplace(encrypted1); - evaluator.RelinearizeInplace(encrypted1, relinKeys); - decryptor.Decrypt(encrypted1, plain2); - - Assert.AreEqual(1ul, plain2.CoeffCount); - Assert.AreEqual(0ul, plain2[0]); - - plain1.Set("1x^10 + 2"); - encryptor.Encrypt(plain1, encrypted1); - evaluator.SquareInplace(encrypted1); - evaluator.RelinearizeInplace(encrypted1, relinKeys); - evaluator.SquareInplace(encrypted1); - evaluator.Relinearize(encrypted1, relinKeys, encrypted2); - decryptor.Decrypt(encrypted2, plain2); - - // {1x^40 + 8x^30 + 18x^20 + 20x^10 + 10} - Assert.AreEqual(41ul, plain2.CoeffCount); - Assert.AreEqual(16ul, plain2[0]); - Assert.AreEqual(32ul, plain2[10]); - Assert.AreEqual(24ul, plain2[20]); - Assert.AreEqual(8ul, plain2[30]); - Assert.AreEqual(1ul, plain2[40]); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateRelinKeys(out RelinKeys relinKeys); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted1 = new Ciphertext(context); + Ciphertext encrypted2 = new Ciphertext(context); + Plaintext plain1 = new Plaintext(); + Plaintext plain2 = new Plaintext(); + + plain1.Set(0); + encryptor.Encrypt(plain1, encrypted1); + evaluator.SquareInplace(encrypted1); + evaluator.RelinearizeInplace(encrypted1, relinKeys); + decryptor.Decrypt(encrypted1, plain2); + + Assert.AreEqual(1ul, plain2.CoeffCount); + Assert.AreEqual(0ul, plain2[0]); + + plain1.Set("1x^10 + 2"); + encryptor.Encrypt(plain1, encrypted1); + evaluator.SquareInplace(encrypted1); + evaluator.RelinearizeInplace(encrypted1, relinKeys); + evaluator.SquareInplace(encrypted1); + evaluator.Relinearize(encrypted1, relinKeys, encrypted2); + decryptor.Decrypt(encrypted2, plain2); + + // {1x^40 + 8x^30 + 18x^20 + 20x^10 + 10} + Assert.AreEqual(41ul, plain2.CoeffCount); + Assert.AreEqual(16ul, plain2[0]); + Assert.AreEqual(32ul, plain2[10]); + Assert.AreEqual(24ul, plain2[20]); + Assert.AreEqual(8ul, plain2[30]); + Assert.AreEqual(1ul, plain2[40]); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateRelinKeys(out RelinKeys relinKeys); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted1 = new Ciphertext(context); + Ciphertext encrypted2 = new Ciphertext(context); + Plaintext plain1 = new Plaintext(); + Plaintext plain2 = new Plaintext(); + + plain1.Set(0); + encryptor.Encrypt(plain1, encrypted1); + evaluator.SquareInplace(encrypted1); + evaluator.RelinearizeInplace(encrypted1, relinKeys); + decryptor.Decrypt(encrypted1, plain2); + + Assert.AreEqual(1ul, plain2.CoeffCount); + Assert.AreEqual(0ul, plain2[0]); + + plain1.Set("1x^10 + 2"); + encryptor.Encrypt(plain1, encrypted1); + evaluator.SquareInplace(encrypted1); + evaluator.RelinearizeInplace(encrypted1, relinKeys); + evaluator.SquareInplace(encrypted1); + evaluator.Relinearize(encrypted1, relinKeys, encrypted2); + decryptor.Decrypt(encrypted2, plain2); + + // {1x^40 + 8x^30 + 18x^20 + 20x^10 + 10} + Assert.AreEqual(41ul, plain2.CoeffCount); + Assert.AreEqual(16ul, plain2[0]); + Assert.AreEqual(32ul, plain2[10]); + Assert.AreEqual(24ul, plain2[20]); + Assert.AreEqual(8ul, plain2[30]); + Assert.AreEqual(1ul, plain2[40]); + } } [TestMethod] @@ -1184,220 +2175,438 @@ public void ComplexConjugateTest() [TestMethod] public void ExceptionsTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 64, - PlainModulus = new Modulus(65537ul), - CoeffModulus = CoeffModulus.Create(64, new int[] { 40, 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(65537ul), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); - Evaluator evaluator = null; - Utilities.AssertThrows(() => evaluator = new Evaluator(null)); - evaluator = new Evaluator(context); + Evaluator evaluator = null; + Utilities.AssertThrows(() => evaluator = new Evaluator(null)); + evaluator = new Evaluator(context); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreateGaloisKeys(out GaloisKeys galoisKeys); - keygen.CreateRelinKeys(out RelinKeys relinKeys); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(out GaloisKeys galoisKeys); + keygen.CreateRelinKeys(out RelinKeys relinKeys); + + Ciphertext encrypted1 = new Ciphertext(); + Ciphertext encrypted2 = new Ciphertext(); + Ciphertext encrypted3 = new Ciphertext(); + Plaintext plain1 = new Plaintext(); + Plaintext plain2 = new Plaintext(); + List encrypteds = new List(); - Ciphertext encrypted1 = new Ciphertext(); - Ciphertext encrypted2 = new Ciphertext(); - Ciphertext encrypted3 = new Ciphertext(); - Plaintext plain1 = new Plaintext(); - Plaintext plain2 = new Plaintext(); - List encrypteds = new List(); + MemoryPoolHandle pool = MemoryManager.GetPool(MMProfOpt.ForceGlobal); - MemoryPoolHandle pool = MemoryManager.GetPool(MMProfOpt.ForceGlobal); + Utilities.AssertThrows(() => evaluator.Add(null, encrypted2, encrypted3)); + Utilities.AssertThrows(() => evaluator.Add(encrypted1, null, encrypted3)); + Utilities.AssertThrows(() => evaluator.Add(encrypted1, encrypted2, null)); + Utilities.AssertThrows(() => evaluator.Add(encrypted1, encrypted2, encrypted3)); - Utilities.AssertThrows(() => evaluator.Add(null, encrypted2, encrypted3)); - Utilities.AssertThrows(() => evaluator.Add(encrypted1, null, encrypted3)); - Utilities.AssertThrows(() => evaluator.Add(encrypted1, encrypted2, null)); - Utilities.AssertThrows(() => evaluator.Add(encrypted1, encrypted2, encrypted3)); + Utilities.AssertThrows(() => evaluator.AddInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.AddInplace(null, encrypted2)); - Utilities.AssertThrows(() => evaluator.AddInplace(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.AddInplace(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.AddMany(encrypteds, null)); + Utilities.AssertThrows(() => evaluator.AddMany(null, encrypted2)); - Utilities.AssertThrows(() => evaluator.AddMany(encrypteds, null)); - Utilities.AssertThrows(() => evaluator.AddMany(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.AddPlain(encrypted1, plain1, null)); + Utilities.AssertThrows(() => evaluator.AddPlain(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.AddPlain(null, plain1, encrypted2)); + Utilities.AssertThrows(() => evaluator.AddPlain(encrypted1, plain1, encrypted2)); - Utilities.AssertThrows(() => evaluator.AddPlain(encrypted1, plain1, null)); - Utilities.AssertThrows(() => evaluator.AddPlain(encrypted1, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.AddPlain(null, plain1, encrypted2)); - Utilities.AssertThrows(() => evaluator.AddPlain(encrypted1, plain1, encrypted2)); + Utilities.AssertThrows(() => evaluator.AddPlainInplace(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.AddPlainInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.ApplyGalois(encrypted1, 1, galoisKeys, null)); + Utilities.AssertThrows(() => evaluator.ApplyGalois(encrypted1, 1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.ApplyGalois(null, 1, galoisKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.ApplyGalois(encrypted1, 1, galoisKeys, encrypted2, pool)); - Utilities.AssertThrows(() => evaluator.ApplyGalois(encrypted1, 1, galoisKeys, null)); - Utilities.AssertThrows(() => evaluator.ApplyGalois(encrypted1, 1, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.ApplyGalois(null, 1, galoisKeys, encrypted2)); - Utilities.AssertThrows(() => evaluator.ApplyGalois(encrypted1, 1, galoisKeys, encrypted2, pool)); + Utilities.AssertThrows(() => evaluator.ApplyGaloisInplace(encrypted1, 1, null)); + Utilities.AssertThrows(() => evaluator.ApplyGaloisInplace(null, 1, galoisKeys)); - Utilities.AssertThrows(() => evaluator.ApplyGaloisInplace(encrypted1, 1, null)); - Utilities.AssertThrows(() => evaluator.ApplyGaloisInplace(null, 1, galoisKeys)); + Utilities.AssertThrows(() => evaluator.ComplexConjugate(encrypted1, galoisKeys, null)); + Utilities.AssertThrows(() => evaluator.ComplexConjugate(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.ComplexConjugate(null, galoisKeys, encrypted2)); - Utilities.AssertThrows(() => evaluator.ComplexConjugate(encrypted1, galoisKeys, null)); - Utilities.AssertThrows(() => evaluator.ComplexConjugate(encrypted1, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.ComplexConjugate(null, galoisKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.ComplexConjugateInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.ComplexConjugateInplace(null, galoisKeys)); + + Utilities.AssertThrows(() => evaluator.Exponentiate(encrypted1, 2, relinKeys, null)); + Utilities.AssertThrows(() => evaluator.Exponentiate(encrypted1, 2, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.Exponentiate(null, 2, relinKeys, encrypted2)); + + Utilities.AssertThrows(() => evaluator.ExponentiateInplace(encrypted1, 2, null)); + Utilities.AssertThrows(() => evaluator.ExponentiateInplace(null, 2, relinKeys)); - Utilities.AssertThrows(() => evaluator.ComplexConjugateInplace(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.ComplexConjugateInplace(null, galoisKeys)); + Utilities.AssertThrows(() => evaluator.ModSwitchTo(plain1, ParmsId.Zero, null)); + Utilities.AssertThrows(() => evaluator.ModSwitchTo(plain1, null, plain2)); + Utilities.AssertThrows(() => evaluator.ModSwitchTo(null, ParmsId.Zero, plain2)); - Utilities.AssertThrows(() => evaluator.Exponentiate(encrypted1, 2, relinKeys, null)); - Utilities.AssertThrows(() => evaluator.Exponentiate(encrypted1, 2, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.Exponentiate(null, 2, relinKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.ModSwitchTo(encrypted1, ParmsId.Zero, null)); + Utilities.AssertThrows(() => evaluator.ModSwitchTo(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.ModSwitchTo(null, ParmsId.Zero, encrypted2)); + Utilities.AssertThrows(() => evaluator.ModSwitchTo(encrypted1, ParmsId.Zero, encrypted2, pool)); + + Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(encrypted: null, parmsId: ParmsId.Zero)); + Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(encrypted1, ParmsId.Zero, pool)); - Utilities.AssertThrows(() => evaluator.ExponentiateInplace(encrypted1, 2, null)); - Utilities.AssertThrows(() => evaluator.ExponentiateInplace(null, 2, relinKeys)); + Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(plain1, null)); + Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(plain: null, parmsId: ParmsId.Zero)); - Utilities.AssertThrows(() => evaluator.ModSwitchTo(plain1, ParmsId.Zero, null)); - Utilities.AssertThrows(() => evaluator.ModSwitchTo(plain1, null, plain2)); - Utilities.AssertThrows(() => evaluator.ModSwitchTo(null, ParmsId.Zero, plain2)); + Utilities.AssertThrows(() => evaluator.ModSwitchToNext(plain1, null)); + Utilities.AssertThrows(() => evaluator.ModSwitchToNext(null, plain2)); - Utilities.AssertThrows(() => evaluator.ModSwitchTo(encrypted1, ParmsId.Zero, null)); - Utilities.AssertThrows(() => evaluator.ModSwitchTo(encrypted1, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.ModSwitchTo(null, ParmsId.Zero, encrypted2)); - Utilities.AssertThrows(() => evaluator.ModSwitchTo(encrypted1, ParmsId.Zero, encrypted2, pool)); + Utilities.AssertThrows(() => evaluator.ModSwitchToNextInplace(null)); - Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(encrypted: null, parmsId: ParmsId.Zero)); - Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(encrypted1, ParmsId.Zero, pool)); + Utilities.AssertThrows(() => evaluator.ModSwitchToNext(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.ModSwitchToNext(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.ModSwitchToNext(encrypted1, encrypted2, pool)); + + Utilities.AssertThrows(() => evaluator.ModSwitchToNextInplace(encrypted: null)); + Utilities.AssertThrows(() => evaluator.ModSwitchToNextInplace(encrypted1, pool)); - Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(plain1, null)); - Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(plain: null, parmsId: ParmsId.Zero)); + Utilities.AssertThrows(() => evaluator.Multiply(encrypted1, encrypted2, null)); + Utilities.AssertThrows(() => evaluator.Multiply(encrypted1, null, encrypted3)); + Utilities.AssertThrows(() => evaluator.Multiply(null, encrypted2, encrypted3)); + Utilities.AssertThrows(() => evaluator.Multiply(encrypted1, encrypted2, encrypted3, pool)); - Utilities.AssertThrows(() => evaluator.ModSwitchToNext(plain1, null)); - Utilities.AssertThrows(() => evaluator.ModSwitchToNext(null, plain2)); + Utilities.AssertThrows(() => evaluator.MultiplyInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.MultiplyInplace(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.MultiplyInplace(encrypted1, encrypted2, pool)); - Utilities.AssertThrows(() => evaluator.ModSwitchToNextInplace(null)); + Utilities.AssertThrows(() => evaluator.MultiplyMany(encrypteds, relinKeys, null)); + Utilities.AssertThrows(() => evaluator.MultiplyMany(encrypteds, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.MultiplyMany(null, relinKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.MultiplyMany(encrypteds, relinKeys, encrypted2, pool)); - Utilities.AssertThrows(() => evaluator.ModSwitchToNext(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.ModSwitchToNext(null, encrypted2)); - Utilities.AssertThrows(() => evaluator.ModSwitchToNext(encrypted1, encrypted2, pool)); + Utilities.AssertThrows(() => evaluator.MultiplyPlain(encrypted1, plain1, null)); + Utilities.AssertThrows(() => evaluator.MultiplyPlain(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.MultiplyPlain(null, plain1, encrypted2)); + Utilities.AssertThrows(() => evaluator.MultiplyPlain(encrypted1, plain1, encrypted2, pool)); - Utilities.AssertThrows(() => evaluator.ModSwitchToNextInplace(encrypted: null)); - Utilities.AssertThrows(() => evaluator.ModSwitchToNextInplace(encrypted1, pool)); + Utilities.AssertThrows(() => evaluator.MultiplyPlainInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.MultiplyPlainInplace(null, plain1)); - Utilities.AssertThrows(() => evaluator.Multiply(encrypted1, encrypted2, null)); - Utilities.AssertThrows(() => evaluator.Multiply(encrypted1, null, encrypted3)); - Utilities.AssertThrows(() => evaluator.Multiply(null, encrypted2, encrypted3)); - Utilities.AssertThrows(() => evaluator.Multiply(encrypted1, encrypted2, encrypted3, pool)); + Utilities.AssertThrows(() => evaluator.Negate(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.Negate(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.Negate(encrypted1, encrypted2)); - Utilities.AssertThrows(() => evaluator.MultiplyInplace(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.MultiplyInplace(null, encrypted2)); - Utilities.AssertThrows(() => evaluator.MultiplyInplace(encrypted1, encrypted2, pool)); + Utilities.AssertThrows(() => evaluator.NegateInplace(null)); - Utilities.AssertThrows(() => evaluator.MultiplyMany(encrypteds, relinKeys, null)); - Utilities.AssertThrows(() => evaluator.MultiplyMany(encrypteds, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.MultiplyMany(null, relinKeys, encrypted2)); - Utilities.AssertThrows(() => evaluator.MultiplyMany(encrypteds, relinKeys, encrypted2, pool)); + Utilities.AssertThrows(() => evaluator.Relinearize(encrypted1, relinKeys, null)); + Utilities.AssertThrows(() => evaluator.Relinearize(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.Relinearize(null, relinKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.Relinearize(encrypted1, relinKeys, encrypted2, pool)); - Utilities.AssertThrows(() => evaluator.MultiplyPlain(encrypted1, plain1, null)); - Utilities.AssertThrows(() => evaluator.MultiplyPlain(encrypted1, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.MultiplyPlain(null, plain1, encrypted2)); - Utilities.AssertThrows(() => evaluator.MultiplyPlain(encrypted1, plain1, encrypted2, pool)); + Utilities.AssertThrows(() => evaluator.RelinearizeInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.RelinearizeInplace(null, relinKeys)); + Utilities.AssertThrows(() => evaluator.RelinearizeInplace(encrypted1, relinKeys, pool)); - Utilities.AssertThrows(() => evaluator.MultiplyPlainInplace(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.MultiplyPlainInplace(null, plain1)); + Utilities.AssertThrows(() => evaluator.RescaleTo(encrypted1, ParmsId.Zero, null)); + Utilities.AssertThrows(() => evaluator.RescaleTo(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.RescaleTo(null, ParmsId.Zero, encrypted2)); + Utilities.AssertThrows(() => evaluator.RescaleTo(encrypted1, ParmsId.Zero, encrypted2, pool)); - Utilities.AssertThrows(() => evaluator.Negate(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.Negate(null, encrypted2)); - Utilities.AssertThrows(() => evaluator.Negate(encrypted1, encrypted2)); + Utilities.AssertThrows(() => evaluator.RescaleToInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.RescaleToInplace(null, ParmsId.Zero)); + Utilities.AssertThrows(() => evaluator.RescaleToInplace(encrypted1, ParmsId.Zero, pool)); - Utilities.AssertThrows(() => evaluator.NegateInplace(null)); + Utilities.AssertThrows(() => evaluator.RescaleToNext(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.RescaleToNext(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.RescaleToNext(encrypted1, encrypted2, pool)); - Utilities.AssertThrows(() => evaluator.Relinearize(encrypted1, relinKeys, null)); - Utilities.AssertThrows(() => evaluator.Relinearize(encrypted1, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.Relinearize(null, relinKeys, encrypted2)); - Utilities.AssertThrows(() => evaluator.Relinearize(encrypted1, relinKeys, encrypted2, pool)); + Utilities.AssertThrows(() => evaluator.RescaleToNextInplace(null)); + Utilities.AssertThrows(() => evaluator.RescaleToNextInplace(encrypted1, pool)); - Utilities.AssertThrows(() => evaluator.RelinearizeInplace(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.RelinearizeInplace(null, relinKeys)); - Utilities.AssertThrows(() => evaluator.RelinearizeInplace(encrypted1, relinKeys, pool)); + Utilities.AssertThrows(() => evaluator.RotateColumns(encrypted1, galoisKeys, null)); + Utilities.AssertThrows(() => evaluator.RotateColumns(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.RotateColumns(null, galoisKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.RotateColumns(encrypted1, galoisKeys, encrypted2, pool)); - Utilities.AssertThrows(() => evaluator.RescaleTo(encrypted1, ParmsId.Zero, null)); - Utilities.AssertThrows(() => evaluator.RescaleTo(encrypted1, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.RescaleTo(null, ParmsId.Zero, encrypted2)); - Utilities.AssertThrows(() => evaluator.RescaleTo(encrypted1, ParmsId.Zero, encrypted2, pool)); + Utilities.AssertThrows(() => evaluator.RotateColumnsInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.RotateColumnsInplace(null, galoisKeys)); + Utilities.AssertThrows(() => evaluator.RotateColumnsInplace(encrypted1, galoisKeys, pool)); - Utilities.AssertThrows(() => evaluator.RescaleToInplace(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.RescaleToInplace(null, ParmsId.Zero)); - Utilities.AssertThrows(() => evaluator.RescaleToInplace(encrypted1, ParmsId.Zero, pool)); + Utilities.AssertThrows(() => evaluator.RotateRows(encrypted1, 1, galoisKeys, null)); + Utilities.AssertThrows(() => evaluator.RotateRows(encrypted1, 1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.RotateRows(null, 1, galoisKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.RotateRows(encrypted1, 1, galoisKeys, encrypted2, pool)); - Utilities.AssertThrows(() => evaluator.RescaleToNext(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.RescaleToNext(null, encrypted2)); - Utilities.AssertThrows(() => evaluator.RescaleToNext(encrypted1, encrypted2, pool)); + Utilities.AssertThrows(() => evaluator.RotateRowsInplace(encrypted1, 1, null)); + Utilities.AssertThrows(() => evaluator.RotateRowsInplace(null, 1, galoisKeys)); + Utilities.AssertThrows(() => evaluator.RotateRowsInplace(encrypted1, 1, galoisKeys, pool)); - Utilities.AssertThrows(() => evaluator.RescaleToNextInplace(null)); - Utilities.AssertThrows(() => evaluator.RescaleToNextInplace(encrypted1, pool)); + Utilities.AssertThrows(() => evaluator.RotateVector(encrypted1, 1, galoisKeys, null)); + Utilities.AssertThrows(() => evaluator.RotateVector(encrypted1, 1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.RotateVector(null, 1, galoisKeys, encrypted2)); - Utilities.AssertThrows(() => evaluator.RotateColumns(encrypted1, galoisKeys, null)); - Utilities.AssertThrows(() => evaluator.RotateColumns(encrypted1, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.RotateColumns(null, galoisKeys, encrypted2)); - Utilities.AssertThrows(() => evaluator.RotateColumns(encrypted1, galoisKeys, encrypted2, pool)); + Utilities.AssertThrows(() => evaluator.RotateVectorInplace(encrypted1, 1, null)); + Utilities.AssertThrows(() => evaluator.RotateVectorInplace(null, 1, galoisKeys)); - Utilities.AssertThrows(() => evaluator.RotateColumnsInplace(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.RotateColumnsInplace(null, galoisKeys)); - Utilities.AssertThrows(() => evaluator.RotateColumnsInplace(encrypted1, galoisKeys, pool)); + Utilities.AssertThrows(() => evaluator.Square(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.Square(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.Square(encrypted1, encrypted2, pool)); - Utilities.AssertThrows(() => evaluator.RotateRows(encrypted1, 1, galoisKeys, null)); - Utilities.AssertThrows(() => evaluator.RotateRows(encrypted1, 1, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.RotateRows(null, 1, galoisKeys, encrypted2)); - Utilities.AssertThrows(() => evaluator.RotateRows(encrypted1, 1, galoisKeys, encrypted2, pool)); + Utilities.AssertThrows(() => evaluator.SquareInplace(null)); + Utilities.AssertThrows(() => evaluator.SquareInplace(encrypted1, pool)); - Utilities.AssertThrows(() => evaluator.RotateRowsInplace(encrypted1, 1, null)); - Utilities.AssertThrows(() => evaluator.RotateRowsInplace(null, 1, galoisKeys)); - Utilities.AssertThrows(() => evaluator.RotateRowsInplace(encrypted1, 1, galoisKeys, pool)); + Utilities.AssertThrows(() => evaluator.Sub(encrypted1, encrypted2, null)); + Utilities.AssertThrows(() => evaluator.Sub(encrypted1, null, encrypted3)); + Utilities.AssertThrows(() => evaluator.Sub(null, encrypted2, encrypted3)); + Utilities.AssertThrows(() => evaluator.Sub(encrypted1, encrypted2, encrypted3)); - Utilities.AssertThrows(() => evaluator.RotateVector(encrypted1, 1, galoisKeys, null)); - Utilities.AssertThrows(() => evaluator.RotateVector(encrypted1, 1, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.RotateVector(null, 1, galoisKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.SubInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.SubInplace(null, encrypted2)); - Utilities.AssertThrows(() => evaluator.RotateVectorInplace(encrypted1, 1, null)); - Utilities.AssertThrows(() => evaluator.RotateVectorInplace(null, 1, galoisKeys)); + Utilities.AssertThrows(() => evaluator.SubPlain(encrypted1, plain1, null)); + Utilities.AssertThrows(() => evaluator.SubPlain(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.SubPlain(null, plain1, encrypted2)); + Utilities.AssertThrows(() => evaluator.SubPlain(encrypted1, plain1, encrypted2)); - Utilities.AssertThrows(() => evaluator.Square(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.Square(null, encrypted2)); - Utilities.AssertThrows(() => evaluator.Square(encrypted1, encrypted2, pool)); + Utilities.AssertThrows(() => evaluator.SubPlainInplace(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.SquareInplace(null)); - Utilities.AssertThrows(() => evaluator.SquareInplace(encrypted1, pool)); + Utilities.AssertThrows(() => evaluator.TransformFromNTT(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.TransformFromNTT(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.TransformFromNTT(encrypted1, encrypted2)); - Utilities.AssertThrows(() => evaluator.Sub(encrypted1, encrypted2, null)); - Utilities.AssertThrows(() => evaluator.Sub(encrypted1, null, encrypted3)); - Utilities.AssertThrows(() => evaluator.Sub(null, encrypted2, encrypted3)); - Utilities.AssertThrows(() => evaluator.Sub(encrypted1, encrypted2, encrypted3)); + Utilities.AssertThrows(() => evaluator.TransformFromNTTInplace(null)); - Utilities.AssertThrows(() => evaluator.SubInplace(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.SubInplace(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.TransformToNTT(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.TransformToNTT(null, encrypted2)); - Utilities.AssertThrows(() => evaluator.SubPlain(encrypted1, plain1, null)); - Utilities.AssertThrows(() => evaluator.SubPlain(encrypted1, null, encrypted2)); - Utilities.AssertThrows(() => evaluator.SubPlain(null, plain1, encrypted2)); - Utilities.AssertThrows(() => evaluator.SubPlain(encrypted1, plain1, encrypted2)); + Utilities.AssertThrows(() => evaluator.TransformToNTTInplace(null)); - Utilities.AssertThrows(() => evaluator.SubPlainInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.TransformToNTT(plain1, ParmsId.Zero, null)); + Utilities.AssertThrows(() => evaluator.TransformToNTT(plain1, null, plain2)); + Utilities.AssertThrows(() => evaluator.TransformToNTT(null, ParmsId.Zero, plain2)); + Utilities.AssertThrows(() => evaluator.TransformToNTT(plain1, ParmsId.Zero, plain2, pool)); - Utilities.AssertThrows(() => evaluator.TransformFromNTT(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.TransformFromNTT(null, encrypted2)); - Utilities.AssertThrows(() => evaluator.TransformFromNTT(encrypted1, encrypted2)); + Utilities.AssertThrows(() => evaluator.TransformToNTTInplace(plain1, null)); + Utilities.AssertThrows(() => evaluator.TransformToNTTInplace(null, ParmsId.Zero)); + Utilities.AssertThrows(() => evaluator.TransformToNTTInplace(plain1, ParmsId.Zero, pool)); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(65537ul), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + + Evaluator evaluator = null; + Utilities.AssertThrows(() => evaluator = new Evaluator(null)); + evaluator = new Evaluator(context); + + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(out GaloisKeys galoisKeys); + keygen.CreateRelinKeys(out RelinKeys relinKeys); + + Ciphertext encrypted1 = new Ciphertext(); + Ciphertext encrypted2 = new Ciphertext(); + Ciphertext encrypted3 = new Ciphertext(); + Plaintext plain1 = new Plaintext(); + Plaintext plain2 = new Plaintext(); + List encrypteds = new List(); + + MemoryPoolHandle pool = MemoryManager.GetPool(MMProfOpt.ForceGlobal); + + Utilities.AssertThrows(() => evaluator.Add(null, encrypted2, encrypted3)); + Utilities.AssertThrows(() => evaluator.Add(encrypted1, null, encrypted3)); + Utilities.AssertThrows(() => evaluator.Add(encrypted1, encrypted2, null)); + Utilities.AssertThrows(() => evaluator.Add(encrypted1, encrypted2, encrypted3)); + + Utilities.AssertThrows(() => evaluator.AddInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.AddInplace(null, encrypted2)); + + Utilities.AssertThrows(() => evaluator.AddMany(encrypteds, null)); + Utilities.AssertThrows(() => evaluator.AddMany(null, encrypted2)); + + Utilities.AssertThrows(() => evaluator.AddPlain(encrypted1, plain1, null)); + Utilities.AssertThrows(() => evaluator.AddPlain(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.AddPlain(null, plain1, encrypted2)); + Utilities.AssertThrows(() => evaluator.AddPlain(encrypted1, plain1, encrypted2)); + + Utilities.AssertThrows(() => evaluator.AddPlainInplace(encrypted1, null)); + + Utilities.AssertThrows(() => evaluator.ApplyGalois(encrypted1, 1, galoisKeys, null)); + Utilities.AssertThrows(() => evaluator.ApplyGalois(encrypted1, 1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.ApplyGalois(null, 1, galoisKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.ApplyGalois(encrypted1, 1, galoisKeys, encrypted2, pool)); + + Utilities.AssertThrows(() => evaluator.ApplyGaloisInplace(encrypted1, 1, null)); + Utilities.AssertThrows(() => evaluator.ApplyGaloisInplace(null, 1, galoisKeys)); + + Utilities.AssertThrows(() => evaluator.ComplexConjugate(encrypted1, galoisKeys, null)); + Utilities.AssertThrows(() => evaluator.ComplexConjugate(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.ComplexConjugate(null, galoisKeys, encrypted2)); + + Utilities.AssertThrows(() => evaluator.ComplexConjugateInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.ComplexConjugateInplace(null, galoisKeys)); + + Utilities.AssertThrows(() => evaluator.Exponentiate(encrypted1, 2, relinKeys, null)); + Utilities.AssertThrows(() => evaluator.Exponentiate(encrypted1, 2, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.Exponentiate(null, 2, relinKeys, encrypted2)); + + Utilities.AssertThrows(() => evaluator.ExponentiateInplace(encrypted1, 2, null)); + Utilities.AssertThrows(() => evaluator.ExponentiateInplace(null, 2, relinKeys)); + + Utilities.AssertThrows(() => evaluator.ModSwitchTo(plain1, ParmsId.Zero, null)); + Utilities.AssertThrows(() => evaluator.ModSwitchTo(plain1, null, plain2)); + Utilities.AssertThrows(() => evaluator.ModSwitchTo(null, ParmsId.Zero, plain2)); + + Utilities.AssertThrows(() => evaluator.ModSwitchTo(encrypted1, ParmsId.Zero, null)); + Utilities.AssertThrows(() => evaluator.ModSwitchTo(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.ModSwitchTo(null, ParmsId.Zero, encrypted2)); + Utilities.AssertThrows(() => evaluator.ModSwitchTo(encrypted1, ParmsId.Zero, encrypted2, pool)); + + Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(encrypted: null, parmsId: ParmsId.Zero)); + Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(encrypted1, ParmsId.Zero, pool)); + + Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(plain1, null)); + Utilities.AssertThrows(() => evaluator.ModSwitchToInplace(plain: null, parmsId: ParmsId.Zero)); + + Utilities.AssertThrows(() => evaluator.ModSwitchToNext(plain1, null)); + Utilities.AssertThrows(() => evaluator.ModSwitchToNext(null, plain2)); + + Utilities.AssertThrows(() => evaluator.ModSwitchToNextInplace(null)); + + Utilities.AssertThrows(() => evaluator.ModSwitchToNext(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.ModSwitchToNext(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.ModSwitchToNext(encrypted1, encrypted2, pool)); + + Utilities.AssertThrows(() => evaluator.ModSwitchToNextInplace(encrypted: null)); + Utilities.AssertThrows(() => evaluator.ModSwitchToNextInplace(encrypted1, pool)); + + Utilities.AssertThrows(() => evaluator.Multiply(encrypted1, encrypted2, null)); + Utilities.AssertThrows(() => evaluator.Multiply(encrypted1, null, encrypted3)); + Utilities.AssertThrows(() => evaluator.Multiply(null, encrypted2, encrypted3)); + Utilities.AssertThrows(() => evaluator.Multiply(encrypted1, encrypted2, encrypted3, pool)); + + Utilities.AssertThrows(() => evaluator.MultiplyInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.MultiplyInplace(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.MultiplyInplace(encrypted1, encrypted2, pool)); + + Utilities.AssertThrows(() => evaluator.MultiplyMany(encrypteds, relinKeys, null)); + Utilities.AssertThrows(() => evaluator.MultiplyMany(encrypteds, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.MultiplyMany(null, relinKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.MultiplyMany(encrypteds, relinKeys, encrypted2, pool)); - Utilities.AssertThrows(() => evaluator.TransformFromNTTInplace(null)); + Utilities.AssertThrows(() => evaluator.MultiplyPlain(encrypted1, plain1, null)); + Utilities.AssertThrows(() => evaluator.MultiplyPlain(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.MultiplyPlain(null, plain1, encrypted2)); + Utilities.AssertThrows(() => evaluator.MultiplyPlain(encrypted1, plain1, encrypted2, pool)); - Utilities.AssertThrows(() => evaluator.TransformToNTT(encrypted1, null)); - Utilities.AssertThrows(() => evaluator.TransformToNTT(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.MultiplyPlainInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.MultiplyPlainInplace(null, plain1)); - Utilities.AssertThrows(() => evaluator.TransformToNTTInplace(null)); + Utilities.AssertThrows(() => evaluator.Negate(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.Negate(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.Negate(encrypted1, encrypted2)); - Utilities.AssertThrows(() => evaluator.TransformToNTT(plain1, ParmsId.Zero, null)); - Utilities.AssertThrows(() => evaluator.TransformToNTT(plain1, null, plain2)); - Utilities.AssertThrows(() => evaluator.TransformToNTT(null, ParmsId.Zero, plain2)); - Utilities.AssertThrows(() => evaluator.TransformToNTT(plain1, ParmsId.Zero, plain2, pool)); + Utilities.AssertThrows(() => evaluator.NegateInplace(null)); - Utilities.AssertThrows(() => evaluator.TransformToNTTInplace(plain1, null)); - Utilities.AssertThrows(() => evaluator.TransformToNTTInplace(null, ParmsId.Zero)); - Utilities.AssertThrows(() => evaluator.TransformToNTTInplace(plain1, ParmsId.Zero, pool)); + Utilities.AssertThrows(() => evaluator.Relinearize(encrypted1, relinKeys, null)); + Utilities.AssertThrows(() => evaluator.Relinearize(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.Relinearize(null, relinKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.Relinearize(encrypted1, relinKeys, encrypted2, pool)); + + Utilities.AssertThrows(() => evaluator.RelinearizeInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.RelinearizeInplace(null, relinKeys)); + Utilities.AssertThrows(() => evaluator.RelinearizeInplace(encrypted1, relinKeys, pool)); + + Utilities.AssertThrows(() => evaluator.RescaleTo(encrypted1, ParmsId.Zero, null)); + Utilities.AssertThrows(() => evaluator.RescaleTo(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.RescaleTo(null, ParmsId.Zero, encrypted2)); + Utilities.AssertThrows(() => evaluator.RescaleTo(encrypted1, ParmsId.Zero, encrypted2, pool)); + + Utilities.AssertThrows(() => evaluator.RescaleToInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.RescaleToInplace(null, ParmsId.Zero)); + Utilities.AssertThrows(() => evaluator.RescaleToInplace(encrypted1, ParmsId.Zero, pool)); + + Utilities.AssertThrows(() => evaluator.RescaleToNext(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.RescaleToNext(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.RescaleToNext(encrypted1, encrypted2, pool)); + + Utilities.AssertThrows(() => evaluator.RescaleToNextInplace(null)); + Utilities.AssertThrows(() => evaluator.RescaleToNextInplace(encrypted1, pool)); + + Utilities.AssertThrows(() => evaluator.RotateColumns(encrypted1, galoisKeys, null)); + Utilities.AssertThrows(() => evaluator.RotateColumns(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.RotateColumns(null, galoisKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.RotateColumns(encrypted1, galoisKeys, encrypted2, pool)); + + Utilities.AssertThrows(() => evaluator.RotateColumnsInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.RotateColumnsInplace(null, galoisKeys)); + Utilities.AssertThrows(() => evaluator.RotateColumnsInplace(encrypted1, galoisKeys, pool)); + + Utilities.AssertThrows(() => evaluator.RotateRows(encrypted1, 1, galoisKeys, null)); + Utilities.AssertThrows(() => evaluator.RotateRows(encrypted1, 1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.RotateRows(null, 1, galoisKeys, encrypted2)); + Utilities.AssertThrows(() => evaluator.RotateRows(encrypted1, 1, galoisKeys, encrypted2, pool)); + + Utilities.AssertThrows(() => evaluator.RotateRowsInplace(encrypted1, 1, null)); + Utilities.AssertThrows(() => evaluator.RotateRowsInplace(null, 1, galoisKeys)); + Utilities.AssertThrows(() => evaluator.RotateRowsInplace(encrypted1, 1, galoisKeys, pool)); + + Utilities.AssertThrows(() => evaluator.RotateVector(encrypted1, 1, galoisKeys, null)); + Utilities.AssertThrows(() => evaluator.RotateVector(encrypted1, 1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.RotateVector(null, 1, galoisKeys, encrypted2)); + + Utilities.AssertThrows(() => evaluator.RotateVectorInplace(encrypted1, 1, null)); + Utilities.AssertThrows(() => evaluator.RotateVectorInplace(null, 1, galoisKeys)); + + Utilities.AssertThrows(() => evaluator.Square(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.Square(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.Square(encrypted1, encrypted2, pool)); + + Utilities.AssertThrows(() => evaluator.SquareInplace(null)); + Utilities.AssertThrows(() => evaluator.SquareInplace(encrypted1, pool)); + + Utilities.AssertThrows(() => evaluator.Sub(encrypted1, encrypted2, null)); + Utilities.AssertThrows(() => evaluator.Sub(encrypted1, null, encrypted3)); + Utilities.AssertThrows(() => evaluator.Sub(null, encrypted2, encrypted3)); + Utilities.AssertThrows(() => evaluator.Sub(encrypted1, encrypted2, encrypted3)); + + Utilities.AssertThrows(() => evaluator.SubInplace(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.SubInplace(null, encrypted2)); + + Utilities.AssertThrows(() => evaluator.SubPlain(encrypted1, plain1, null)); + Utilities.AssertThrows(() => evaluator.SubPlain(encrypted1, null, encrypted2)); + Utilities.AssertThrows(() => evaluator.SubPlain(null, plain1, encrypted2)); + Utilities.AssertThrows(() => evaluator.SubPlain(encrypted1, plain1, encrypted2)); + + Utilities.AssertThrows(() => evaluator.SubPlainInplace(encrypted1, null)); + + Utilities.AssertThrows(() => evaluator.TransformFromNTT(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.TransformFromNTT(null, encrypted2)); + Utilities.AssertThrows(() => evaluator.TransformFromNTT(encrypted1, encrypted2)); + + Utilities.AssertThrows(() => evaluator.TransformFromNTTInplace(null)); + + Utilities.AssertThrows(() => evaluator.TransformToNTT(encrypted1, null)); + Utilities.AssertThrows(() => evaluator.TransformToNTT(null, encrypted2)); + + Utilities.AssertThrows(() => evaluator.TransformToNTTInplace(null)); + + Utilities.AssertThrows(() => evaluator.TransformToNTT(plain1, ParmsId.Zero, null)); + Utilities.AssertThrows(() => evaluator.TransformToNTT(plain1, null, plain2)); + Utilities.AssertThrows(() => evaluator.TransformToNTT(null, ParmsId.Zero, plain2)); + Utilities.AssertThrows(() => evaluator.TransformToNTT(plain1, ParmsId.Zero, plain2, pool)); + + Utilities.AssertThrows(() => evaluator.TransformToNTTInplace(plain1, null)); + Utilities.AssertThrows(() => evaluator.TransformToNTTInplace(null, ParmsId.Zero)); + Utilities.AssertThrows(() => evaluator.TransformToNTTInplace(plain1, ParmsId.Zero, pool)); + } } /// diff --git a/dotnet/tests/GaloisKeysTests.cs b/dotnet/tests/GaloisKeysTests.cs index 7f250f48d..368248d4e 100644 --- a/dotnet/tests/GaloisKeysTests.cs +++ b/dotnet/tests/GaloisKeysTests.cs @@ -25,64 +25,128 @@ public void CreateTest() [TestMethod] public void CreateNonEmptyTest() { - SEALContext context = GlobalContext.BFVContext; - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreateGaloisKeys(out GaloisKeys keys); + { + SEALContext context = GlobalContext.BFVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(out GaloisKeys keys); - Assert.IsNotNull(keys); - Assert.AreEqual(24ul, keys.Size); + Assert.IsNotNull(keys); + Assert.AreEqual(24ul, keys.Size); + + GaloisKeys copy = new GaloisKeys(keys); + + Assert.IsNotNull(copy); + Assert.AreEqual(24ul, copy.Size); + } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(out GaloisKeys keys); - GaloisKeys copy = new GaloisKeys(keys); + Assert.IsNotNull(keys); + Assert.AreEqual(24ul, keys.Size); - Assert.IsNotNull(copy); - Assert.AreEqual(24ul, copy.Size); + GaloisKeys copy = new GaloisKeys(keys); + + Assert.IsNotNull(copy); + Assert.AreEqual(24ul, copy.Size); + } } [TestMethod] public void SaveLoadTest() { - SEALContext context = GlobalContext.BFVContext; - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreateGaloisKeys(out GaloisKeys keys); - GaloisKeys other = new GaloisKeys(); + { + SEALContext context = GlobalContext.BFVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(out GaloisKeys keys); + GaloisKeys other = new GaloisKeys(); - Assert.IsNotNull(keys); - Assert.AreEqual(24ul, keys.Size); + Assert.IsNotNull(keys); + Assert.AreEqual(24ul, keys.Size); - using (MemoryStream ms = new MemoryStream()) - { - keys.Save(ms); - ms.Seek(offset: 0, loc: SeekOrigin.Begin); - other.Load(context, ms); - } + using (MemoryStream ms = new MemoryStream()) + { + keys.Save(ms); + ms.Seek(offset: 0, loc: SeekOrigin.Begin); + other.Load(context, ms); + } - Assert.AreEqual(24ul, other.Size); - Assert.IsTrue(ValCheck.IsValidFor(other, context)); + Assert.AreEqual(24ul, other.Size); + Assert.IsTrue(ValCheck.IsValidFor(other, context)); - List> keysData = new List>(keys.Data); - List> otherData = new List>(other.Data); + List> keysData = new List>(keys.Data); + List> otherData = new List>(other.Data); + + Assert.AreEqual(keysData.Count, otherData.Count); + for (int i = 0; i < keysData.Count; i++) + { + List keysCiphers = new List(keysData[i]); + List otherCiphers = new List(otherData[i]); - Assert.AreEqual(keysData.Count, otherData.Count); - for (int i = 0; i < keysData.Count; i++) + Assert.AreEqual(keysCiphers.Count, otherCiphers.Count); + + for (int j = 0; j < keysCiphers.Count; j++) + { + PublicKey keysCipher = keysCiphers[j]; + PublicKey otherCipher = otherCiphers[j]; + + Assert.AreEqual(keysCipher.Data.Size, otherCipher.Data.Size); + Assert.AreEqual(keysCipher.Data.PolyModulusDegree, otherCipher.Data.PolyModulusDegree); + Assert.AreEqual(keysCipher.Data.CoeffModulusSize, otherCipher.Data.CoeffModulusSize); + + ulong coeffCount = keysCipher.Data.Size * keysCipher.Data.PolyModulusDegree * keysCipher.Data.CoeffModulusSize; + for (ulong k = 0; k < coeffCount; k++) + { + Assert.AreEqual(keysCipher.Data[k], otherCipher.Data[k]); + } + } + } + } { - List keysCiphers = new List(keysData[i]); - List otherCiphers = new List(otherData[i]); + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(out GaloisKeys keys); + GaloisKeys other = new GaloisKeys(); + + Assert.IsNotNull(keys); + Assert.AreEqual(24ul, keys.Size); + + using (MemoryStream ms = new MemoryStream()) + { + keys.Save(ms); + ms.Seek(offset: 0, loc: SeekOrigin.Begin); + other.Load(context, ms); + } + + Assert.AreEqual(24ul, other.Size); + Assert.IsTrue(ValCheck.IsValidFor(other, context)); - Assert.AreEqual(keysCiphers.Count, otherCiphers.Count); + List> keysData = new List>(keys.Data); + List> otherData = new List>(other.Data); - for (int j = 0; j < keysCiphers.Count; j++) + Assert.AreEqual(keysData.Count, otherData.Count); + for (int i = 0; i < keysData.Count; i++) { - PublicKey keysCipher = keysCiphers[j]; - PublicKey otherCipher = otherCiphers[j]; + List keysCiphers = new List(keysData[i]); + List otherCiphers = new List(otherData[i]); - Assert.AreEqual(keysCipher.Data.Size, otherCipher.Data.Size); - Assert.AreEqual(keysCipher.Data.PolyModulusDegree, otherCipher.Data.PolyModulusDegree); - Assert.AreEqual(keysCipher.Data.CoeffModulusSize, otherCipher.Data.CoeffModulusSize); + Assert.AreEqual(keysCiphers.Count, otherCiphers.Count); - ulong coeffCount = keysCipher.Data.Size * keysCipher.Data.PolyModulusDegree * keysCipher.Data.CoeffModulusSize; - for (ulong k = 0; k < coeffCount; k++) + for (int j = 0; j < keysCiphers.Count; j++) { - Assert.AreEqual(keysCipher.Data[k], otherCipher.Data[k]); + PublicKey keysCipher = keysCiphers[j]; + PublicKey otherCipher = otherCiphers[j]; + + Assert.AreEqual(keysCipher.Data.Size, otherCipher.Data.Size); + Assert.AreEqual(keysCipher.Data.PolyModulusDegree, otherCipher.Data.PolyModulusDegree); + Assert.AreEqual(keysCipher.Data.CoeffModulusSize, otherCipher.Data.CoeffModulusSize); + + ulong coeffCount = keysCipher.Data.Size * keysCipher.Data.PolyModulusDegree * keysCipher.Data.CoeffModulusSize; + for (ulong k = 0; k < coeffCount; k++) + { + Assert.AreEqual(keysCipher.Data[k], otherCipher.Data[k]); + } } } } @@ -91,147 +155,289 @@ public void SaveLoadTest() [TestMethod] public void SeededKeyTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 8, - PlainModulus = new Modulus(257), - CoeffModulus = CoeffModulus.Create(8, new int[] { 40, 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 8, + PlainModulus = new Modulus(257), + CoeffModulus = CoeffModulus.Create(8, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + BatchEncoder encoder = new BatchEncoder(context); + + GaloisKeys galoisKeys = new GaloisKeys(); + using (MemoryStream stream = new MemoryStream()) + { + keygen.CreateGaloisKeys().Save(stream); + stream.Seek(0, SeekOrigin.Begin); + galoisKeys.Load(context, stream); + } - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - BatchEncoder encoder = new BatchEncoder(context); + Plaintext plain = new Plaintext(); + List vec = new List + { + 1, 2, 3, 4, + 5, 6, 7, 8 + }; - GaloisKeys galoisKeys = new GaloisKeys(); - using (MemoryStream stream = new MemoryStream()) - { - keygen.CreateGaloisKeys().Save(stream); - stream.Seek(0, SeekOrigin.Begin); - galoisKeys.Load(context, stream); - } + encoder.Encode(vec, plain); - Plaintext plain = new Plaintext(); - List vec = new List - { - 1, 2, 3, 4, - 5, 6, 7, 8 - }; + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plaindest = new Plaintext(); - encoder.Encode(vec, plain); + encryptor.Encrypt(plain, encrypted); + evaluator.RotateColumns(encrypted, galoisKeys, encdest); + decryptor.Decrypt(encdest, plaindest); + encoder.Decode(plaindest, vec); - Ciphertext encrypted = new Ciphertext(); - Ciphertext encdest = new Ciphertext(); - Plaintext plaindest = new Plaintext(); + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 5, 6, 7, 8, + 1, 2, 3, 4 + })); - encryptor.Encrypt(plain, encrypted); - evaluator.RotateColumns(encrypted, galoisKeys, encdest); - decryptor.Decrypt(encdest, plaindest); - encoder.Decode(plaindest, vec); + evaluator.RotateRows(encdest, -1, galoisKeys, encrypted); + decryptor.Decrypt(encrypted, plaindest); + encoder.Decode(plaindest, vec); - Assert.IsTrue(AreCollectionsEqual(vec, new List - { - 5, 6, 7, 8, - 1, 2, 3, 4 - })); + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 8, 5, 6, 7, + 4, 1, 2, 3 + })); - evaluator.RotateRows(encdest, -1, galoisKeys, encrypted); - decryptor.Decrypt(encrypted, plaindest); - encoder.Decode(plaindest, vec); + evaluator.RotateRowsInplace(encrypted, 2, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + encoder.Decode(plaindest, vec); - Assert.IsTrue(AreCollectionsEqual(vec, new List - { - 8, 5, 6, 7, - 4, 1, 2, 3 - })); + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 6, 7, 8, 5, + 2, 3, 4, 1 + })); - evaluator.RotateRowsInplace(encrypted, 2, galoisKeys); - decryptor.Decrypt(encrypted, plaindest); - encoder.Decode(plaindest, vec); + evaluator.RotateColumnsInplace(encrypted, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + encoder.Decode(plaindest, vec); - Assert.IsTrue(AreCollectionsEqual(vec, new List + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 2, 3, 4, 1, + 6, 7, 8, 5 + })); + } { - 6, 7, 8, 5, - 2, 3, 4, 1 - })); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 8, + PlainModulus = new Modulus(257), + CoeffModulus = CoeffModulus.Create(8, new int[] { 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + BatchEncoder encoder = new BatchEncoder(context); + + GaloisKeys galoisKeys = new GaloisKeys(); + using (MemoryStream stream = new MemoryStream()) + { + keygen.CreateGaloisKeys().Save(stream); + stream.Seek(0, SeekOrigin.Begin); + galoisKeys.Load(context, stream); + } + + Plaintext plain = new Plaintext(); + List vec = new List + { + 1, 2, 3, 4, + 5, 6, 7, 8 + }; - evaluator.RotateColumnsInplace(encrypted, galoisKeys); - decryptor.Decrypt(encrypted, plaindest); - encoder.Decode(plaindest, vec); + encoder.Encode(vec, plain); - Assert.IsTrue(AreCollectionsEqual(vec, new List - { - 2, 3, 4, 1, - 6, 7, 8, 5 - })); + Ciphertext encrypted = new Ciphertext(); + Ciphertext encdest = new Ciphertext(); + Plaintext plaindest = new Plaintext(); + + encryptor.Encrypt(plain, encrypted); + evaluator.RotateColumns(encrypted, galoisKeys, encdest); + decryptor.Decrypt(encdest, plaindest); + encoder.Decode(plaindest, vec); + + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 5, 6, 7, 8, + 1, 2, 3, 4 + })); + + evaluator.RotateRows(encdest, -1, galoisKeys, encrypted); + decryptor.Decrypt(encrypted, plaindest); + encoder.Decode(plaindest, vec); + + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 8, 5, 6, 7, + 4, 1, 2, 3 + })); + + evaluator.RotateRowsInplace(encrypted, 2, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + encoder.Decode(plaindest, vec); + + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 6, 7, 8, 5, + 2, 3, 4, 1 + })); + + evaluator.RotateColumnsInplace(encrypted, galoisKeys); + decryptor.Decrypt(encrypted, plaindest); + encoder.Decode(plaindest, vec); + + Assert.IsTrue(AreCollectionsEqual(vec, new List + { + 2, 3, 4, 1, + 6, 7, 8, 5 + })); + } } [TestMethod] public void SetTest() { - SEALContext context = GlobalContext.BFVContext; - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreateGaloisKeys(out GaloisKeys keys); + { + SEALContext context = GlobalContext.BFVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(out GaloisKeys keys); - Assert.IsNotNull(keys); - Assert.AreEqual(24ul, keys.Size); + Assert.IsNotNull(keys); + Assert.AreEqual(24ul, keys.Size); + + GaloisKeys keys2 = new GaloisKeys(); + + Assert.IsNotNull(keys2); + Assert.AreEqual(0ul, keys2.Size); + + keys2.Set(keys); + + Assert.AreNotSame(keys, keys2); + Assert.AreEqual(24ul, keys2.Size); + } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(out GaloisKeys keys); - GaloisKeys keys2 = new GaloisKeys(); + Assert.IsNotNull(keys); + Assert.AreEqual(24ul, keys.Size); - Assert.IsNotNull(keys2); - Assert.AreEqual(0ul, keys2.Size); + GaloisKeys keys2 = new GaloisKeys(); - keys2.Set(keys); + Assert.IsNotNull(keys2); + Assert.AreEqual(0ul, keys2.Size); - Assert.AreNotSame(keys, keys2); - Assert.AreEqual(24ul, keys2.Size); + keys2.Set(keys); + + Assert.AreNotSame(keys, keys2); + Assert.AreEqual(24ul, keys2.Size); + } } [TestMethod] public void KeyTest() { - SEALContext context = GlobalContext.BFVContext; - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreateGaloisKeys(out GaloisKeys keys); - MemoryPoolHandle handle = keys.Pool; + { + SEALContext context = GlobalContext.BFVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(out GaloisKeys keys); + MemoryPoolHandle handle = keys.Pool; - Assert.IsNotNull(keys); - Assert.AreEqual(24ul, keys.Size); + Assert.IsNotNull(keys); + Assert.AreEqual(24ul, keys.Size); - Assert.IsFalse(keys.HasKey(galoisElt: 1)); - Assert.IsTrue(keys.HasKey(galoisElt: 3)); - Assert.IsFalse(keys.HasKey(galoisElt: 5)); - Assert.IsFalse(keys.HasKey(galoisElt: 7)); - Assert.IsTrue(keys.HasKey(galoisElt: 9)); - Assert.IsFalse(keys.HasKey(galoisElt: 11)); + Assert.IsFalse(keys.HasKey(galoisElt: 1)); + Assert.IsTrue(keys.HasKey(galoisElt: 3)); + Assert.IsFalse(keys.HasKey(galoisElt: 5)); + Assert.IsFalse(keys.HasKey(galoisElt: 7)); + Assert.IsTrue(keys.HasKey(galoisElt: 9)); + Assert.IsFalse(keys.HasKey(galoisElt: 11)); - IEnumerable key = keys.Key(3); - Assert.AreEqual(4, key.Count()); + IEnumerable key = keys.Key(3); + Assert.AreEqual(4, key.Count()); - IEnumerable key2 = keys.Key(9); - Assert.AreEqual(4, key2.Count()); + IEnumerable key2 = keys.Key(9); + Assert.AreEqual(4, key2.Count()); - Assert.IsTrue(handle.AllocByteCount > 0ul); + Assert.IsTrue(handle.AllocByteCount > 0ul); + } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(out GaloisKeys keys); + MemoryPoolHandle handle = keys.Pool; + + Assert.IsNotNull(keys); + Assert.AreEqual(24ul, keys.Size); + + Assert.IsFalse(keys.HasKey(galoisElt: 1)); + Assert.IsTrue(keys.HasKey(galoisElt: 3)); + Assert.IsFalse(keys.HasKey(galoisElt: 5)); + Assert.IsFalse(keys.HasKey(galoisElt: 7)); + Assert.IsTrue(keys.HasKey(galoisElt: 9)); + Assert.IsFalse(keys.HasKey(galoisElt: 11)); + + IEnumerable key = keys.Key(3); + Assert.AreEqual(4, key.Count()); + + IEnumerable key2 = keys.Key(9); + Assert.AreEqual(4, key2.Count()); + + Assert.IsTrue(handle.AllocByteCount > 0ul); + } } [TestMethod] public void KeyEltTest() { - SEALContext context = GlobalContext.BFVContext; - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreateGaloisKeys(galoisElts: new uint[] { 1, 3 }, out GaloisKeys keys); - Assert.IsNotNull(keys); + { + SEALContext context = GlobalContext.BFVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(galoisElts: new uint[] { 1, 3 }, out GaloisKeys keys); + Assert.IsNotNull(keys); - Assert.AreEqual(2ul, keys.Size); + Assert.AreEqual(2ul, keys.Size); - Assert.IsTrue(keys.HasKey(1)); - Assert.IsTrue(keys.HasKey(3)); - Assert.IsFalse(keys.HasKey(5)); + Assert.IsTrue(keys.HasKey(1)); + Assert.IsTrue(keys.HasKey(3)); + Assert.IsFalse(keys.HasKey(5)); + } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateGaloisKeys(galoisElts: new uint[] { 1, 3 }, out GaloisKeys keys); + Assert.IsNotNull(keys); + + Assert.AreEqual(2ul, keys.Size); + + Assert.IsTrue(keys.HasKey(1)); + Assert.IsTrue(keys.HasKey(3)); + Assert.IsFalse(keys.HasKey(5)); + } } [TestMethod] @@ -270,23 +476,44 @@ public void KeyStepTest() [TestMethod] public void ExceptionsTest() { - SEALContext context = GlobalContext.BFVContext; - GaloisKeys keys = new GaloisKeys(); + { + SEALContext context = GlobalContext.BFVContext; + GaloisKeys keys = new GaloisKeys(); - Utilities.AssertThrows(() => keys = new GaloisKeys(null)); + Utilities.AssertThrows(() => keys = new GaloisKeys(null)); - Utilities.AssertThrows(() => keys.Set(null)); + Utilities.AssertThrows(() => keys.Set(null)); - Utilities.AssertThrows(() => ValCheck.IsValidFor(keys, null)); + Utilities.AssertThrows(() => ValCheck.IsValidFor(keys, null)); - Utilities.AssertThrows(() => keys.Save(null)); + Utilities.AssertThrows(() => keys.Save(null)); - Utilities.AssertThrows(() => keys.UnsafeLoad(context, null)); - Utilities.AssertThrows(() => keys.UnsafeLoad(context, new MemoryStream())); - Utilities.AssertThrows(() => keys.UnsafeLoad(null, new MemoryStream())); + Utilities.AssertThrows(() => keys.UnsafeLoad(context, null)); + Utilities.AssertThrows(() => keys.UnsafeLoad(context, new MemoryStream())); + Utilities.AssertThrows(() => keys.UnsafeLoad(null, new MemoryStream())); + + Utilities.AssertThrows(() => keys.Load(context, null)); + Utilities.AssertThrows(() => keys.Load(null, new MemoryStream())); + } + { + SEALContext context = GlobalContext.BGVContext; + GaloisKeys keys = new GaloisKeys(); - Utilities.AssertThrows(() => keys.Load(context, null)); - Utilities.AssertThrows(() => keys.Load(null, new MemoryStream())); + Utilities.AssertThrows(() => keys = new GaloisKeys(null)); + + Utilities.AssertThrows(() => keys.Set(null)); + + Utilities.AssertThrows(() => ValCheck.IsValidFor(keys, null)); + + Utilities.AssertThrows(() => keys.Save(null)); + + Utilities.AssertThrows(() => keys.UnsafeLoad(context, null)); + Utilities.AssertThrows(() => keys.UnsafeLoad(context, new MemoryStream())); + Utilities.AssertThrows(() => keys.UnsafeLoad(null, new MemoryStream())); + + Utilities.AssertThrows(() => keys.Load(context, null)); + Utilities.AssertThrows(() => keys.Load(null, new MemoryStream())); + } } /// diff --git a/dotnet/tests/GlobalContext.cs b/dotnet/tests/GlobalContext.cs index 81294d0e5..b8aea08e2 100644 --- a/dotnet/tests/GlobalContext.cs +++ b/dotnet/tests/GlobalContext.cs @@ -28,9 +28,18 @@ static GlobalContext() CoeffModulus = CoeffModulus.BFVDefault(polyModulusDegree: 8192) }; CKKSContext = new SEALContext(encParams); + + encParams = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 8192, + CoeffModulus = CoeffModulus.BGVDefault(polyModulusDegree: 8192) + }; + encParams.SetPlainModulus(65537ul); + BGVContext = new SEALContext(encParams); } public static SEALContext BFVContext { get; private set; } = null; public static SEALContext CKKSContext { get; private set; } = null; + public static SEALContext BGVContext { get; private set; } = null; } } diff --git a/dotnet/tests/KeyGeneratorTests.cs b/dotnet/tests/KeyGeneratorTests.cs index 5d2c38035..727f6884a 100644 --- a/dotnet/tests/KeyGeneratorTests.cs +++ b/dotnet/tests/KeyGeneratorTests.cs @@ -15,106 +15,207 @@ public class KeyGeneratorTests [TestMethod] public void CreateTest() { - SEALContext context = GlobalContext.BFVContext; - KeyGenerator keygen = new KeyGenerator(context); + { + SEALContext context = GlobalContext.BFVContext; + KeyGenerator keygen = new KeyGenerator(context); + + Assert.IsNotNull(keygen); + + keygen.CreatePublicKey(out PublicKey pubKey); + SecretKey secKey = keygen.SecretKey; + + Assert.IsNotNull(pubKey); + Assert.IsNotNull(secKey); + + Ciphertext cipher = pubKey.Data; + Assert.IsNotNull(cipher); + + Plaintext plain = secKey.Data; + Assert.IsNotNull(plain); + Assert.AreEqual(40960ul, plain.CoeffCount); + } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); - Assert.IsNotNull(keygen); + Assert.IsNotNull(keygen); - keygen.CreatePublicKey(out PublicKey pubKey); - SecretKey secKey = keygen.SecretKey; + keygen.CreatePublicKey(out PublicKey pubKey); + SecretKey secKey = keygen.SecretKey; - Assert.IsNotNull(pubKey); - Assert.IsNotNull(secKey); + Assert.IsNotNull(pubKey); + Assert.IsNotNull(secKey); - Ciphertext cipher = pubKey.Data; - Assert.IsNotNull(cipher); + Ciphertext cipher = pubKey.Data; + Assert.IsNotNull(cipher); - Plaintext plain = secKey.Data; - Assert.IsNotNull(plain); - Assert.AreEqual(40960ul, plain.CoeffCount); + Plaintext plain = secKey.Data; + Assert.IsNotNull(plain); + Assert.AreEqual(40960ul, plain.CoeffCount); + } } [TestMethod] public void Create2Test() { - SEALContext context = GlobalContext.BFVContext; - KeyGenerator keygen1 = new KeyGenerator(context); - keygen1.CreatePublicKey(out PublicKey publicKey); + { + SEALContext context = GlobalContext.BFVContext; + KeyGenerator keygen1 = new KeyGenerator(context); + keygen1.CreatePublicKey(out PublicKey publicKey); + + Encryptor encryptor1 = new Encryptor(context, publicKey); + Decryptor decryptor1 = new Decryptor(context, keygen1.SecretKey); + + Ciphertext cipher = new Ciphertext(); + Plaintext plain = new Plaintext("2x^1 + 5"); + Plaintext plain2 = new Plaintext(); - Encryptor encryptor1 = new Encryptor(context, publicKey); - Decryptor decryptor1 = new Decryptor(context, keygen1.SecretKey); + encryptor1.Encrypt(plain, cipher); + decryptor1.Decrypt(cipher, plain2); - Ciphertext cipher = new Ciphertext(); - Plaintext plain = new Plaintext("2x^1 + 5"); - Plaintext plain2 = new Plaintext(); + Assert.AreNotSame(plain, plain2); + Assert.AreEqual(plain, plain2); - encryptor1.Encrypt(plain, cipher); - decryptor1.Decrypt(cipher, plain2); + KeyGenerator keygen2 = new KeyGenerator(context, keygen1.SecretKey); - Assert.AreNotSame(plain, plain2); - Assert.AreEqual(plain, plain2); + keygen2.CreatePublicKey(out publicKey); + Encryptor encryptor2 = new Encryptor(context, publicKey); + Decryptor decryptor2 = new Decryptor(context, keygen2.SecretKey); - KeyGenerator keygen2 = new KeyGenerator(context, keygen1.SecretKey); + Plaintext plain3 = new Plaintext(); + decryptor2.Decrypt(cipher, plain3); + + Assert.AreNotSame(plain, plain3); + Assert.AreEqual(plain, plain3); + } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen1 = new KeyGenerator(context); + keygen1.CreatePublicKey(out PublicKey publicKey); - keygen2.CreatePublicKey(out publicKey); - Encryptor encryptor2 = new Encryptor(context, publicKey); - Decryptor decryptor2 = new Decryptor(context, keygen2.SecretKey); + Encryptor encryptor1 = new Encryptor(context, publicKey); + Decryptor decryptor1 = new Decryptor(context, keygen1.SecretKey); - Plaintext plain3 = new Plaintext(); - decryptor2.Decrypt(cipher, plain3); + Ciphertext cipher = new Ciphertext(); + Plaintext plain = new Plaintext("2x^1 + 5"); + Plaintext plain2 = new Plaintext(); - Assert.AreNotSame(plain, plain3); - Assert.AreEqual(plain, plain3); + encryptor1.Encrypt(plain, cipher); + decryptor1.Decrypt(cipher, plain2); + + Assert.AreNotSame(plain, plain2); + Assert.AreEqual(plain, plain2); + + KeyGenerator keygen2 = new KeyGenerator(context, keygen1.SecretKey); + + keygen2.CreatePublicKey(out publicKey); + Encryptor encryptor2 = new Encryptor(context, publicKey); + Decryptor decryptor2 = new Decryptor(context, keygen2.SecretKey); + + Plaintext plain3 = new Plaintext(); + decryptor2.Decrypt(cipher, plain3); + + Assert.AreNotSame(plain, plain3); + Assert.AreEqual(plain, plain3); + } } [TestMethod] public void KeyCopyTest() { - SEALContext context = GlobalContext.BFVContext; - PublicKey pk; - SecretKey sk = null; - - using (KeyGenerator keygen = new KeyGenerator(context)) { - keygen.CreatePublicKey(out pk); - sk = keygen.SecretKey; + SEALContext context = GlobalContext.BFVContext; + PublicKey pk; + SecretKey sk = null; + + using (KeyGenerator keygen = new KeyGenerator(context)) + { + keygen.CreatePublicKey(out pk); + sk = keygen.SecretKey; + } + + ParmsId parmsIdPK = pk.ParmsId; + ParmsId parmsIdSK = sk.ParmsId; + Assert.AreEqual(parmsIdPK, parmsIdSK); + Assert.AreEqual(parmsIdPK, context.KeyParmsId); + } + { + SEALContext context = GlobalContext.BGVContext; + PublicKey pk; + SecretKey sk = null; + + using (KeyGenerator keygen = new KeyGenerator(context)) + { + keygen.CreatePublicKey(out pk); + sk = keygen.SecretKey; + } + + ParmsId parmsIdPK = pk.ParmsId; + ParmsId parmsIdSK = sk.ParmsId; + Assert.AreEqual(parmsIdPK, parmsIdSK); + Assert.AreEqual(parmsIdPK, context.KeyParmsId); } - - ParmsId parmsIdPK = pk.ParmsId; - ParmsId parmsIdSK = sk.ParmsId; - Assert.AreEqual(parmsIdPK, parmsIdSK); - Assert.AreEqual(parmsIdPK, context.KeyParmsId); } [TestMethod] public void ExceptionsTest() { - SEALContext context = GlobalContext.BFVContext; - KeyGenerator keygen = new KeyGenerator(context); - SecretKey secret = new SecretKey(); - List elts = new List { 16385 }; - List elts_null = null; - List steps = new List { 4096 }; - List steps_null = null; - - Utilities.AssertThrows(() => keygen = new KeyGenerator(null)); - - Utilities.AssertThrows(() => keygen = new KeyGenerator(context, null)); - Utilities.AssertThrows(() => keygen = new KeyGenerator(null, keygen.SecretKey)); - Utilities.AssertThrows(() => keygen = new KeyGenerator(context, secret)); - - Utilities.AssertThrows(() => keygen.CreateGaloisKeys(elts_null)); - Utilities.AssertThrows(() => keygen.CreateGaloisKeys(elts)); - Utilities.AssertThrows(() => keygen.CreateGaloisKeys(steps_null)); - Utilities.AssertThrows(() => keygen.CreateGaloisKeys(steps)); - - EncryptionParameters smallParms = new EncryptionParameters(SchemeType.CKKS); - smallParms.PolyModulusDegree = 128; - smallParms.CoeffModulus = CoeffModulus.Create(smallParms.PolyModulusDegree, new int[] { 60 }); - context = new SEALContext(smallParms, true, SecLevelType.None); - keygen = new KeyGenerator(context); - Utilities.AssertThrows(() => keygen.CreateRelinKeys()); - Utilities.AssertThrows(() => keygen.CreateGaloisKeys()); + { + SEALContext context = GlobalContext.BFVContext; + KeyGenerator keygen = new KeyGenerator(context); + SecretKey secret = new SecretKey(); + List elts = new List { 16385 }; + List elts_null = null; + List steps = new List { 4096 }; + List steps_null = null; + + Utilities.AssertThrows(() => keygen = new KeyGenerator(null)); + + Utilities.AssertThrows(() => keygen = new KeyGenerator(context, null)); + Utilities.AssertThrows(() => keygen = new KeyGenerator(null, keygen.SecretKey)); + Utilities.AssertThrows(() => keygen = new KeyGenerator(context, secret)); + + Utilities.AssertThrows(() => keygen.CreateGaloisKeys(elts_null)); + Utilities.AssertThrows(() => keygen.CreateGaloisKeys(elts)); + Utilities.AssertThrows(() => keygen.CreateGaloisKeys(steps_null)); + Utilities.AssertThrows(() => keygen.CreateGaloisKeys(steps)); + + EncryptionParameters smallParms = new EncryptionParameters(SchemeType.CKKS); + smallParms.PolyModulusDegree = 128; + smallParms.CoeffModulus = CoeffModulus.Create(smallParms.PolyModulusDegree, new int[] { 60 }); + context = new SEALContext(smallParms, true, SecLevelType.None); + keygen = new KeyGenerator(context); + Utilities.AssertThrows(() => keygen.CreateRelinKeys()); + Utilities.AssertThrows(() => keygen.CreateGaloisKeys()); + } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + SecretKey secret = new SecretKey(); + List elts = new List { 16385 }; + List elts_null = null; + List steps = new List { 4096 }; + List steps_null = null; + + Utilities.AssertThrows(() => keygen = new KeyGenerator(null)); + + Utilities.AssertThrows(() => keygen = new KeyGenerator(context, null)); + Utilities.AssertThrows(() => keygen = new KeyGenerator(null, keygen.SecretKey)); + Utilities.AssertThrows(() => keygen = new KeyGenerator(context, secret)); + + Utilities.AssertThrows(() => keygen.CreateGaloisKeys(elts_null)); + Utilities.AssertThrows(() => keygen.CreateGaloisKeys(elts)); + Utilities.AssertThrows(() => keygen.CreateGaloisKeys(steps_null)); + Utilities.AssertThrows(() => keygen.CreateGaloisKeys(steps)); + + EncryptionParameters smallParms = new EncryptionParameters(SchemeType.CKKS); + smallParms.PolyModulusDegree = 128; + smallParms.CoeffModulus = CoeffModulus.Create(smallParms.PolyModulusDegree, new int[] { 60 }); + context = new SEALContext(smallParms, true, SecLevelType.None); + keygen = new KeyGenerator(context); + Utilities.AssertThrows(() => keygen.CreateRelinKeys()); + Utilities.AssertThrows(() => keygen.CreateGaloisKeys()); + } } } } diff --git a/dotnet/tests/PlaintextTests.cs b/dotnet/tests/PlaintextTests.cs index bf7dd8f43..9a4f1d65e 100644 --- a/dotnet/tests/PlaintextTests.cs +++ b/dotnet/tests/PlaintextTests.cs @@ -61,25 +61,48 @@ public void CreateWithHexTest() [TestMethod] public void CopyTest() { - Plaintext plain = new Plaintext("6x^5 + 5x^4 + 3x^2 + 2x^1 + 1"); - Assert.IsFalse(plain.IsNTTForm); - Plaintext plain2 = new Plaintext(plain); - Assert.AreEqual(plain, plain2); - Assert.IsFalse(plain2.IsNTTForm); - Assert.AreEqual(plain.ParmsId, plain2.ParmsId); - - SEALContext context = GlobalContext.BFVContext; - Evaluator evaluator = new Evaluator(context); - evaluator.TransformToNTTInplace(plain, context.FirstParmsId); - Assert.IsTrue(plain.IsNTTForm); - Assert.IsFalse(plain2.IsNTTForm); - Assert.AreNotEqual(plain.ParmsId, plain2.ParmsId); - Assert.AreEqual(plain.ParmsId, context.FirstParmsId); - - Plaintext plain3 = new Plaintext(plain); - Assert.AreEqual(plain3, plain); - Assert.IsTrue(plain3.IsNTTForm); - Assert.AreEqual(plain3.ParmsId, context.FirstParmsId); + { + Plaintext plain = new Plaintext("6x^5 + 5x^4 + 3x^2 + 2x^1 + 1"); + Assert.IsFalse(plain.IsNTTForm); + Plaintext plain2 = new Plaintext(plain); + Assert.AreEqual(plain, plain2); + Assert.IsFalse(plain2.IsNTTForm); + Assert.AreEqual(plain.ParmsId, plain2.ParmsId); + + SEALContext context = GlobalContext.BFVContext; + Evaluator evaluator = new Evaluator(context); + evaluator.TransformToNTTInplace(plain, context.FirstParmsId); + Assert.IsTrue(plain.IsNTTForm); + Assert.IsFalse(plain2.IsNTTForm); + Assert.AreNotEqual(plain.ParmsId, plain2.ParmsId); + Assert.AreEqual(plain.ParmsId, context.FirstParmsId); + + Plaintext plain3 = new Plaintext(plain); + Assert.AreEqual(plain3, plain); + Assert.IsTrue(plain3.IsNTTForm); + Assert.AreEqual(plain3.ParmsId, context.FirstParmsId); + } + { + Plaintext plain = new Plaintext("6x^5 + 5x^4 + 3x^2 + 2x^1 + 1"); + Assert.IsFalse(plain.IsNTTForm); + Plaintext plain2 = new Plaintext(plain); + Assert.AreEqual(plain, plain2); + Assert.IsFalse(plain2.IsNTTForm); + Assert.AreEqual(plain.ParmsId, plain2.ParmsId); + + SEALContext context = GlobalContext.BGVContext; + Evaluator evaluator = new Evaluator(context); + evaluator.TransformToNTTInplace(plain, context.FirstParmsId); + Assert.IsTrue(plain.IsNTTForm); + Assert.IsFalse(plain2.IsNTTForm); + Assert.AreNotEqual(plain.ParmsId, plain2.ParmsId); + Assert.AreEqual(plain.ParmsId, context.FirstParmsId); + + Plaintext plain3 = new Plaintext(plain); + Assert.AreEqual(plain3, plain); + Assert.IsTrue(plain3.IsNTTForm); + Assert.AreEqual(plain3.ParmsId, context.FirstParmsId); + } } [TestMethod] @@ -327,25 +350,48 @@ public void EqualsTest() [TestMethod] public void SaveLoadTest() { - SEALContext context = GlobalContext.BFVContext; - Plaintext plain = new Plaintext("6x^5 + 5x^4 + 4x^3 + 3x^2 + 2x^1 + 5"); - Plaintext other = new Plaintext(); + { + SEALContext context = GlobalContext.BFVContext; + Plaintext plain = new Plaintext("6x^5 + 5x^4 + 4x^3 + 3x^2 + 2x^1 + 5"); + Plaintext other = new Plaintext(); - Assert.AreNotSame(plain, other); - Assert.AreNotEqual(plain, other); + Assert.AreNotSame(plain, other); + Assert.AreNotEqual(plain, other); - using (MemoryStream stream = new MemoryStream()) - { - plain.Save(stream); + using (MemoryStream stream = new MemoryStream()) + { + plain.Save(stream); - stream.Seek(offset: 0, loc: SeekOrigin.Begin); + stream.Seek(offset: 0, loc: SeekOrigin.Begin); - other.Load(context, stream); + other.Load(context, stream); + } + + Assert.AreNotSame(plain, other); + Assert.AreEqual(plain, other); + Assert.IsTrue(ValCheck.IsValidFor(other, context)); } + { + SEALContext context = GlobalContext.BGVContext; + Plaintext plain = new Plaintext("6x^5 + 5x^4 + 4x^3 + 3x^2 + 2x^1 + 5"); + Plaintext other = new Plaintext(); + + Assert.AreNotSame(plain, other); + Assert.AreNotEqual(plain, other); + + using (MemoryStream stream = new MemoryStream()) + { + plain.Save(stream); + + stream.Seek(offset: 0, loc: SeekOrigin.Begin); - Assert.AreNotSame(plain, other); - Assert.AreEqual(plain, other); - Assert.IsTrue(ValCheck.IsValidFor(other, context)); + other.Load(context, stream); + } + + Assert.AreNotSame(plain, other); + Assert.AreEqual(plain, other); + Assert.IsTrue(ValCheck.IsValidFor(other, context)); + } } [TestMethod] @@ -368,31 +414,60 @@ public void HashCodeTest() [TestMethod] public void ExceptionsTest() { - SEALContext context = GlobalContext.BFVContext; - Plaintext plain = new Plaintext(); - MemoryPoolHandle pool = MemoryManager.GetPool(MMProfOpt.ForceGlobal); - MemoryPoolHandle pool_uninit = new MemoryPoolHandle(); + { + SEALContext context = GlobalContext.BFVContext; + Plaintext plain = new Plaintext(); + MemoryPoolHandle pool = MemoryManager.GetPool(MMProfOpt.ForceGlobal); + MemoryPoolHandle pool_uninit = new MemoryPoolHandle(); + + Utilities.AssertThrows(() => plain = new Plaintext(pool_uninit)); + Utilities.AssertThrows(() => plain = new Plaintext((string)null, pool)); - Utilities.AssertThrows(() => plain = new Plaintext(pool_uninit)); - Utilities.AssertThrows(() => plain = new Plaintext((string)null, pool)); + Utilities.AssertThrows(() => plain.Set((Plaintext)null)); + Utilities.AssertThrows(() => plain.Set((string)null)); - Utilities.AssertThrows(() => plain.Set((Plaintext)null)); - Utilities.AssertThrows(() => plain.Set((string)null)); + Utilities.AssertThrows(() => plain.SetZero(100000)); + Utilities.AssertThrows(() => plain.SetZero(1, 100000)); + Utilities.AssertThrows(() => plain.SetZero(100000, 1)); - Utilities.AssertThrows(() => plain.SetZero(100000)); - Utilities.AssertThrows(() => plain.SetZero(1, 100000)); - Utilities.AssertThrows(() => plain.SetZero(100000, 1)); + Utilities.AssertThrows(() => ValCheck.IsValidFor(plain, null)); - Utilities.AssertThrows(() => ValCheck.IsValidFor(plain, null)); + Utilities.AssertThrows(() => plain.Save(null)); - Utilities.AssertThrows(() => plain.Save(null)); + Utilities.AssertThrows(() => plain.UnsafeLoad(null, new MemoryStream())); + Utilities.AssertThrows(() => plain.UnsafeLoad(context, null)); + Utilities.AssertThrows(() => plain.UnsafeLoad(context, new MemoryStream())); - Utilities.AssertThrows(() => plain.UnsafeLoad(null, new MemoryStream())); - Utilities.AssertThrows(() => plain.UnsafeLoad(context, null)); - Utilities.AssertThrows(() => plain.UnsafeLoad(context, new MemoryStream())); + Utilities.AssertThrows(() => plain.Load(context, null)); + Utilities.AssertThrows(() => plain.Load(null, new MemoryStream())); + } + { + SEALContext context = GlobalContext.BGVContext; + Plaintext plain = new Plaintext(); + MemoryPoolHandle pool = MemoryManager.GetPool(MMProfOpt.ForceGlobal); + MemoryPoolHandle pool_uninit = new MemoryPoolHandle(); - Utilities.AssertThrows(() => plain.Load(context, null)); - Utilities.AssertThrows(() => plain.Load(null, new MemoryStream())); + Utilities.AssertThrows(() => plain = new Plaintext(pool_uninit)); + Utilities.AssertThrows(() => plain = new Plaintext((string)null, pool)); + + Utilities.AssertThrows(() => plain.Set((Plaintext)null)); + Utilities.AssertThrows(() => plain.Set((string)null)); + + Utilities.AssertThrows(() => plain.SetZero(100000)); + Utilities.AssertThrows(() => plain.SetZero(1, 100000)); + Utilities.AssertThrows(() => plain.SetZero(100000, 1)); + + Utilities.AssertThrows(() => ValCheck.IsValidFor(plain, null)); + + Utilities.AssertThrows(() => plain.Save(null)); + + Utilities.AssertThrows(() => plain.UnsafeLoad(null, new MemoryStream())); + Utilities.AssertThrows(() => plain.UnsafeLoad(context, null)); + Utilities.AssertThrows(() => plain.UnsafeLoad(context, new MemoryStream())); + + Utilities.AssertThrows(() => plain.Load(context, null)); + Utilities.AssertThrows(() => plain.Load(null, new MemoryStream())); + } } } } diff --git a/dotnet/tests/PublicKeyTests.cs b/dotnet/tests/PublicKeyTests.cs index 98a156183..418667d04 100644 --- a/dotnet/tests/PublicKeyTests.cs +++ b/dotnet/tests/PublicKeyTests.cs @@ -15,92 +15,180 @@ public class PublicKeyTests [TestMethod] public void CreateTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 64, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey pub); - PublicKey copy = new PublicKey(pub); - - Assert.IsNotNull(copy); - Assert.AreEqual(2ul, copy.Data.Size); - Assert.IsTrue(copy.Data.IsNTTForm); - - PublicKey copy2 = new PublicKey(); - copy2.Set(copy); - - Assert.AreEqual(2ul, copy2.Data.Size); - Assert.IsTrue(copy2.Data.IsNTTForm); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey pub); + PublicKey copy = new PublicKey(pub); + + Assert.IsNotNull(copy); + Assert.AreEqual(2ul, copy.Data.Size); + Assert.IsTrue(copy.Data.IsNTTForm); + + PublicKey copy2 = new PublicKey(); + copy2.Set(copy); + + Assert.AreEqual(2ul, copy2.Data.Size); + Assert.IsTrue(copy2.Data.IsNTTForm); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey pub); + PublicKey copy = new PublicKey(pub); + + Assert.IsNotNull(copy); + Assert.AreEqual(2ul, copy.Data.Size); + Assert.IsTrue(copy.Data.IsNTTForm); + + PublicKey copy2 = new PublicKey(); + copy2.Set(copy); + + Assert.AreEqual(2ul, copy2.Data.Size); + Assert.IsTrue(copy2.Data.IsNTTForm); + } } [TestMethod] public void SaveLoadTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 64, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey pub); - - Assert.IsNotNull(pub); - Assert.AreEqual(2ul, pub.Data.Size); - Assert.IsTrue(pub.Data.IsNTTForm); - - PublicKey pub2 = new PublicKey(); - MemoryPoolHandle handle = pub2.Pool; - - Assert.AreEqual(0ul, pub2.Data.Size); - Assert.IsFalse(pub2.Data.IsNTTForm); - Assert.AreEqual(ParmsId.Zero, pub2.ParmsId); - - using (MemoryStream stream = new MemoryStream()) + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey pub); + + Assert.IsNotNull(pub); + Assert.AreEqual(2ul, pub.Data.Size); + Assert.IsTrue(pub.Data.IsNTTForm); + + PublicKey pub2 = new PublicKey(); + MemoryPoolHandle handle = pub2.Pool; + + Assert.AreEqual(0ul, pub2.Data.Size); + Assert.IsFalse(pub2.Data.IsNTTForm); + Assert.AreEqual(ParmsId.Zero, pub2.ParmsId); + + using (MemoryStream stream = new MemoryStream()) + { + pub.Save(stream); + + stream.Seek(offset: 0, loc: SeekOrigin.Begin); + + pub2.Load(context, stream); + } + + Assert.AreNotSame(pub, pub2); + Assert.AreEqual(2ul, pub2.Data.Size); + Assert.IsTrue(pub2.Data.IsNTTForm); + Assert.AreEqual(pub.ParmsId, pub2.ParmsId); + Assert.AreNotEqual(ParmsId.Zero, pub2.ParmsId); + Assert.IsTrue(handle.AllocByteCount != 0ul); + } { - pub.Save(stream); - - stream.Seek(offset: 0, loc: SeekOrigin.Begin); - - pub2.Load(context, stream); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey pub); + + Assert.IsNotNull(pub); + Assert.AreEqual(2ul, pub.Data.Size); + Assert.IsTrue(pub.Data.IsNTTForm); + + PublicKey pub2 = new PublicKey(); + MemoryPoolHandle handle = pub2.Pool; + + Assert.AreEqual(0ul, pub2.Data.Size); + Assert.IsFalse(pub2.Data.IsNTTForm); + Assert.AreEqual(ParmsId.Zero, pub2.ParmsId); + + using (MemoryStream stream = new MemoryStream()) + { + pub.Save(stream); + + stream.Seek(offset: 0, loc: SeekOrigin.Begin); + + pub2.Load(context, stream); + } + + Assert.AreNotSame(pub, pub2); + Assert.AreEqual(2ul, pub2.Data.Size); + Assert.IsTrue(pub2.Data.IsNTTForm); + Assert.AreEqual(pub.ParmsId, pub2.ParmsId); + Assert.AreNotEqual(ParmsId.Zero, pub2.ParmsId); + Assert.IsTrue(handle.AllocByteCount != 0ul); } - - Assert.AreNotSame(pub, pub2); - Assert.AreEqual(2ul, pub2.Data.Size); - Assert.IsTrue(pub2.Data.IsNTTForm); - Assert.AreEqual(pub.ParmsId, pub2.ParmsId); - Assert.AreNotEqual(ParmsId.Zero, pub2.ParmsId); - Assert.IsTrue(handle.AllocByteCount != 0ul); } [TestMethod] public void ExceptionsTest() { - SEALContext context = GlobalContext.BFVContext; - PublicKey key = new PublicKey(); + { + SEALContext context = GlobalContext.BFVContext; + PublicKey key = new PublicKey(); - Utilities.AssertThrows(() => key = new PublicKey(null)); + Utilities.AssertThrows(() => key = new PublicKey(null)); - Utilities.AssertThrows(() => key.Set(null)); + Utilities.AssertThrows(() => key.Set(null)); - Utilities.AssertThrows(() => key.Save(null)); - Utilities.AssertThrows(() => key.UnsafeLoad(context, null)); - Utilities.AssertThrows(() => key.UnsafeLoad(null, new MemoryStream())); + Utilities.AssertThrows(() => key.Save(null)); + Utilities.AssertThrows(() => key.UnsafeLoad(context, null)); + Utilities.AssertThrows(() => key.UnsafeLoad(null, new MemoryStream())); - Utilities.AssertThrows(() => key.Load(context, null)); - Utilities.AssertThrows(() => key.Load(null, new MemoryStream())); - Utilities.AssertThrows(() => key.Load(context, new MemoryStream())); + Utilities.AssertThrows(() => key.Load(context, null)); + Utilities.AssertThrows(() => key.Load(null, new MemoryStream())); + Utilities.AssertThrows(() => key.Load(context, new MemoryStream())); - Utilities.AssertThrows(() => ValCheck.IsValidFor(key, null)); + Utilities.AssertThrows(() => ValCheck.IsValidFor(key, null)); + } + { + SEALContext context = GlobalContext.BGVContext; + PublicKey key = new PublicKey(); + + Utilities.AssertThrows(() => key = new PublicKey(null)); + + Utilities.AssertThrows(() => key.Set(null)); + + Utilities.AssertThrows(() => key.Save(null)); + Utilities.AssertThrows(() => key.UnsafeLoad(context, null)); + Utilities.AssertThrows(() => key.UnsafeLoad(null, new MemoryStream())); + + Utilities.AssertThrows(() => key.Load(context, null)); + Utilities.AssertThrows(() => key.Load(null, new MemoryStream())); + Utilities.AssertThrows(() => key.Load(context, new MemoryStream())); + + Utilities.AssertThrows(() => ValCheck.IsValidFor(key, null)); + } } } } diff --git a/dotnet/tests/RelinKeysTests.cs b/dotnet/tests/RelinKeysTests.cs index c38cd72ce..c1d588d9c 100644 --- a/dotnet/tests/RelinKeysTests.cs +++ b/dotnet/tests/RelinKeysTests.cs @@ -24,78 +24,156 @@ public void CreateRelinKeysTest() [TestMethod] public void CreateNonEmptyRelinKeysTest() { - SEALContext context = GlobalContext.BFVContext; - KeyGenerator keygen = new KeyGenerator(context); + { + SEALContext context = GlobalContext.BFVContext; + KeyGenerator keygen = new KeyGenerator(context); - keygen.CreateRelinKeys(out RelinKeys keys); + keygen.CreateRelinKeys(out RelinKeys keys); - Assert.IsNotNull(keys); - Assert.AreEqual(1ul, keys.Size); + Assert.IsNotNull(keys); + Assert.AreEqual(1ul, keys.Size); + + RelinKeys copy = new RelinKeys(keys); + + Assert.IsNotNull(copy); + Assert.AreEqual(1ul, copy.Size); + + RelinKeys copy2 = new RelinKeys(); + + copy2.Set(keys); + Assert.IsNotNull(copy2); + Assert.AreEqual(1ul, copy2.Size); + } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); - RelinKeys copy = new RelinKeys(keys); + keygen.CreateRelinKeys(out RelinKeys keys); - Assert.IsNotNull(copy); - Assert.AreEqual(1ul, copy.Size); + Assert.IsNotNull(keys); + Assert.AreEqual(1ul, keys.Size); - RelinKeys copy2 = new RelinKeys(); + RelinKeys copy = new RelinKeys(keys); - copy2.Set(keys); - Assert.IsNotNull(copy2); - Assert.AreEqual(1ul, copy2.Size); + Assert.IsNotNull(copy); + Assert.AreEqual(1ul, copy.Size); + + RelinKeys copy2 = new RelinKeys(); + + copy2.Set(keys); + Assert.IsNotNull(copy2); + Assert.AreEqual(1ul, copy2.Size); + } } [TestMethod] public void SaveLoadTest() { - SEALContext context = GlobalContext.BFVContext; - KeyGenerator keygen = new KeyGenerator(context); + { + SEALContext context = GlobalContext.BFVContext; + KeyGenerator keygen = new KeyGenerator(context); - keygen.CreateRelinKeys(out RelinKeys keys); + keygen.CreateRelinKeys(out RelinKeys keys); - Assert.IsNotNull(keys); - Assert.AreEqual(1ul, keys.Size); + Assert.IsNotNull(keys); + Assert.AreEqual(1ul, keys.Size); - RelinKeys other = new RelinKeys(); - MemoryPoolHandle handle = other.Pool; + RelinKeys other = new RelinKeys(); + MemoryPoolHandle handle = other.Pool; - Assert.AreEqual(0ul, other.Size); - ulong alloced = handle.AllocByteCount; + Assert.AreEqual(0ul, other.Size); + ulong alloced = handle.AllocByteCount; - using (MemoryStream ms = new MemoryStream()) - { - keys.Save(ms); - ms.Seek(offset: 0, loc: SeekOrigin.Begin); - other.Load(context, ms); - } + using (MemoryStream ms = new MemoryStream()) + { + keys.Save(ms); + ms.Seek(offset: 0, loc: SeekOrigin.Begin); + other.Load(context, ms); + } - Assert.AreEqual(1ul, other.Size); - Assert.IsTrue(ValCheck.IsValidFor(other, context)); - Assert.IsTrue(handle.AllocByteCount > 0ul); + Assert.AreEqual(1ul, other.Size); + Assert.IsTrue(ValCheck.IsValidFor(other, context)); + Assert.IsTrue(handle.AllocByteCount > 0ul); - List> keysData = new List>(keys.Data); - List> otherData = new List>(other.Data); + List> keysData = new List>(keys.Data); + List> otherData = new List>(other.Data); + + Assert.AreEqual(keysData.Count, otherData.Count); + for (int i = 0; i < keysData.Count; i++) + { + List keysCiphers = new List(keysData[i]); + List otherCiphers = new List(otherData[i]); - Assert.AreEqual(keysData.Count, otherData.Count); - for (int i = 0; i < keysData.Count; i++) + Assert.AreEqual(keysCiphers.Count, otherCiphers.Count); + + for (int j = 0; j < keysCiphers.Count; j++) + { + PublicKey keysCipher = keysCiphers[j]; + PublicKey otherCipher = otherCiphers[j]; + + Assert.AreEqual(keysCipher.Data.Size, otherCipher.Data.Size); + Assert.AreEqual(keysCipher.Data.PolyModulusDegree, otherCipher.Data.PolyModulusDegree); + Assert.AreEqual(keysCipher.Data.CoeffModulusSize, otherCipher.Data.CoeffModulusSize); + + ulong coeffCount = keysCipher.Data.Size * keysCipher.Data.PolyModulusDegree * keysCipher.Data.CoeffModulusSize; + for (ulong k = 0; k < coeffCount; k++) + { + Assert.AreEqual(keysCipher.Data[k], otherCipher.Data[k]); + } + } + } + } { - List keysCiphers = new List(keysData[i]); - List otherCiphers = new List(otherData[i]); + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + + keygen.CreateRelinKeys(out RelinKeys keys); + + Assert.IsNotNull(keys); + Assert.AreEqual(1ul, keys.Size); + + RelinKeys other = new RelinKeys(); + MemoryPoolHandle handle = other.Pool; + + Assert.AreEqual(0ul, other.Size); + ulong alloced = handle.AllocByteCount; + + using (MemoryStream ms = new MemoryStream()) + { + keys.Save(ms); + ms.Seek(offset: 0, loc: SeekOrigin.Begin); + other.Load(context, ms); + } - Assert.AreEqual(keysCiphers.Count, otherCiphers.Count); + Assert.AreEqual(1ul, other.Size); + Assert.IsTrue(ValCheck.IsValidFor(other, context)); + Assert.IsTrue(handle.AllocByteCount > 0ul); - for (int j = 0; j < keysCiphers.Count; j++) + List> keysData = new List>(keys.Data); + List> otherData = new List>(other.Data); + + Assert.AreEqual(keysData.Count, otherData.Count); + for (int i = 0; i < keysData.Count; i++) { - PublicKey keysCipher = keysCiphers[j]; - PublicKey otherCipher = otherCiphers[j]; + List keysCiphers = new List(keysData[i]); + List otherCiphers = new List(otherData[i]); - Assert.AreEqual(keysCipher.Data.Size, otherCipher.Data.Size); - Assert.AreEqual(keysCipher.Data.PolyModulusDegree, otherCipher.Data.PolyModulusDegree); - Assert.AreEqual(keysCipher.Data.CoeffModulusSize, otherCipher.Data.CoeffModulusSize); + Assert.AreEqual(keysCiphers.Count, otherCiphers.Count); - ulong coeffCount = keysCipher.Data.Size * keysCipher.Data.PolyModulusDegree * keysCipher.Data.CoeffModulusSize; - for (ulong k = 0; k < coeffCount; k++) + for (int j = 0; j < keysCiphers.Count; j++) { - Assert.AreEqual(keysCipher.Data[k], otherCipher.Data[k]); + PublicKey keysCipher = keysCiphers[j]; + PublicKey otherCipher = otherCiphers[j]; + + Assert.AreEqual(keysCipher.Data.Size, otherCipher.Data.Size); + Assert.AreEqual(keysCipher.Data.PolyModulusDegree, otherCipher.Data.PolyModulusDegree); + Assert.AreEqual(keysCipher.Data.CoeffModulusSize, otherCipher.Data.CoeffModulusSize); + + ulong coeffCount = keysCipher.Data.Size * keysCipher.Data.PolyModulusDegree * keysCipher.Data.CoeffModulusSize; + for (ulong k = 0; k < coeffCount; k++) + { + Assert.AreEqual(keysCipher.Data[k], otherCipher.Data[k]); + } } } } @@ -104,97 +182,190 @@ public void SaveLoadTest() [TestMethod] public void SeededKeyTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) - { - PolyModulusDegree = 128, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - - RelinKeys relinKeys = new RelinKeys(); - using (MemoryStream stream = new MemoryStream()) { - keygen.CreateRelinKeys().Save(stream); - stream.Seek(0, SeekOrigin.Begin); - relinKeys.Load(context, stream); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + + RelinKeys relinKeys = new RelinKeys(); + using (MemoryStream stream = new MemoryStream()) + { + keygen.CreateRelinKeys().Save(stream); + stream.Seek(0, SeekOrigin.Begin); + relinKeys.Load(context, stream); + } + + keygen.CreatePublicKey(out PublicKey publicKey); + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted1 = new Ciphertext(context); + Ciphertext encrypted2 = new Ciphertext(context); + Plaintext plain1 = new Plaintext(); + Plaintext plain2 = new Plaintext(); + + plain1.Set(0); + encryptor.Encrypt(plain1, encrypted1); + evaluator.SquareInplace(encrypted1); + evaluator.RelinearizeInplace(encrypted1, relinKeys); + decryptor.Decrypt(encrypted1, plain2); + + Assert.AreEqual(1ul, plain2.CoeffCount); + Assert.AreEqual(0ul, plain2[0]); + + plain1.Set("1x^10 + 2"); + encryptor.Encrypt(plain1, encrypted1); + evaluator.SquareInplace(encrypted1); + evaluator.RelinearizeInplace(encrypted1, relinKeys); + evaluator.SquareInplace(encrypted1); + evaluator.Relinearize(encrypted1, relinKeys, encrypted2); + decryptor.Decrypt(encrypted2, plain2); + + // {1x^40 + 8x^30 + 18x^20 + 20x^10 + 10} + Assert.AreEqual(41ul, plain2.CoeffCount); + Assert.AreEqual(16ul, plain2[0]); + Assert.AreEqual(32ul, plain2[10]); + Assert.AreEqual(24ul, plain2[20]); + Assert.AreEqual(8ul, plain2[30]); + Assert.AreEqual(1ul, plain2[40]); } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 40, 40, 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + + RelinKeys relinKeys = new RelinKeys(); + using (MemoryStream stream = new MemoryStream()) + { + keygen.CreateRelinKeys().Save(stream); + stream.Seek(0, SeekOrigin.Begin); + relinKeys.Load(context, stream); + } - keygen.CreatePublicKey(out PublicKey publicKey); - Encryptor encryptor = new Encryptor(context, publicKey); - Decryptor decryptor = new Decryptor(context, keygen.SecretKey); - Evaluator evaluator = new Evaluator(context); - - Ciphertext encrypted1 = new Ciphertext(context); - Ciphertext encrypted2 = new Ciphertext(context); - Plaintext plain1 = new Plaintext(); - Plaintext plain2 = new Plaintext(); - - plain1.Set(0); - encryptor.Encrypt(plain1, encrypted1); - evaluator.SquareInplace(encrypted1); - evaluator.RelinearizeInplace(encrypted1, relinKeys); - decryptor.Decrypt(encrypted1, plain2); - - Assert.AreEqual(1ul, plain2.CoeffCount); - Assert.AreEqual(0ul, plain2[0]); - - plain1.Set("1x^10 + 2"); - encryptor.Encrypt(plain1, encrypted1); - evaluator.SquareInplace(encrypted1); - evaluator.RelinearizeInplace(encrypted1, relinKeys); - evaluator.SquareInplace(encrypted1); - evaluator.Relinearize(encrypted1, relinKeys, encrypted2); - decryptor.Decrypt(encrypted2, plain2); - - // {1x^40 + 8x^30 + 18x^20 + 20x^10 + 10} - Assert.AreEqual(41ul, plain2.CoeffCount); - Assert.AreEqual(16ul, plain2[0]); - Assert.AreEqual(32ul, plain2[10]); - Assert.AreEqual(24ul, plain2[20]); - Assert.AreEqual(8ul, plain2[30]); - Assert.AreEqual(1ul, plain2[40]); + keygen.CreatePublicKey(out PublicKey publicKey); + Encryptor encryptor = new Encryptor(context, publicKey); + Decryptor decryptor = new Decryptor(context, keygen.SecretKey); + Evaluator evaluator = new Evaluator(context); + + Ciphertext encrypted1 = new Ciphertext(context); + Ciphertext encrypted2 = new Ciphertext(context); + Plaintext plain1 = new Plaintext(); + Plaintext plain2 = new Plaintext(); + + plain1.Set(0); + encryptor.Encrypt(plain1, encrypted1); + evaluator.SquareInplace(encrypted1); + evaluator.RelinearizeInplace(encrypted1, relinKeys); + decryptor.Decrypt(encrypted1, plain2); + + Assert.AreEqual(1ul, plain2.CoeffCount); + Assert.AreEqual(0ul, plain2[0]); + + plain1.Set("1x^10 + 2"); + encryptor.Encrypt(plain1, encrypted1); + evaluator.SquareInplace(encrypted1); + evaluator.RelinearizeInplace(encrypted1, relinKeys); + evaluator.SquareInplace(encrypted1); + evaluator.Relinearize(encrypted1, relinKeys, encrypted2); + decryptor.Decrypt(encrypted2, plain2); + + // {1x^40 + 8x^30 + 18x^20 + 20x^10 + 10} + Assert.AreEqual(41ul, plain2.CoeffCount); + Assert.AreEqual(16ul, plain2[0]); + Assert.AreEqual(32ul, plain2[10]); + Assert.AreEqual(24ul, plain2[20]); + Assert.AreEqual(8ul, plain2[30]); + Assert.AreEqual(1ul, plain2[40]); + } } [TestMethod] public void GetKeyTest() { - SEALContext context = GlobalContext.BFVContext; - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreateRelinKeys(out RelinKeys relinKeys); + { + SEALContext context = GlobalContext.BFVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateRelinKeys(out RelinKeys relinKeys); - Assert.IsTrue(relinKeys.HasKey(2)); - Assert.IsFalse(relinKeys.HasKey(3)); + Assert.IsTrue(relinKeys.HasKey(2)); + Assert.IsFalse(relinKeys.HasKey(3)); - Utilities.AssertThrows(() => relinKeys.Key(0)); - Utilities.AssertThrows(() => relinKeys.Key(1)); + Utilities.AssertThrows(() => relinKeys.Key(0)); + Utilities.AssertThrows(() => relinKeys.Key(1)); - List key1 = new List(relinKeys.Key(2)); - Assert.AreEqual(4, key1.Count); - Assert.AreEqual(5ul, key1[0].Data.CoeffModulusSize); + List key1 = new List(relinKeys.Key(2)); + Assert.AreEqual(4, key1.Count); + Assert.AreEqual(5ul, key1[0].Data.CoeffModulusSize); + } + { + SEALContext context = GlobalContext.BGVContext; + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreateRelinKeys(out RelinKeys relinKeys); + + Assert.IsTrue(relinKeys.HasKey(2)); + Assert.IsFalse(relinKeys.HasKey(3)); + + Utilities.AssertThrows(() => relinKeys.Key(0)); + Utilities.AssertThrows(() => relinKeys.Key(1)); + + List key1 = new List(relinKeys.Key(2)); + Assert.AreEqual(4, key1.Count); + Assert.AreEqual(5ul, key1[0].Data.CoeffModulusSize); + } } [TestMethod] public void ExceptionsTest() { - RelinKeys keys = new RelinKeys(); - SEALContext context = GlobalContext.BFVContext; + { + RelinKeys keys = new RelinKeys(); + SEALContext context = GlobalContext.BFVContext; - Utilities.AssertThrows(() => keys = new RelinKeys(null)); - Utilities.AssertThrows(() => keys.Set(null)); + Utilities.AssertThrows(() => keys = new RelinKeys(null)); + Utilities.AssertThrows(() => keys.Set(null)); - Utilities.AssertThrows(() => ValCheck.IsValidFor(keys, null)); + Utilities.AssertThrows(() => ValCheck.IsValidFor(keys, null)); - Utilities.AssertThrows(() => keys.Save(null)); + Utilities.AssertThrows(() => keys.Save(null)); - Utilities.AssertThrows(() => keys.Load(context, null)); - Utilities.AssertThrows(() => keys.Load(null, new MemoryStream())); - Utilities.AssertThrows(() => keys.Load(context, new MemoryStream())); - Utilities.AssertThrows(() => keys.UnsafeLoad(null, new MemoryStream())); - Utilities.AssertThrows(() => keys.UnsafeLoad(context, null)); + Utilities.AssertThrows(() => keys.Load(context, null)); + Utilities.AssertThrows(() => keys.Load(null, new MemoryStream())); + Utilities.AssertThrows(() => keys.Load(context, new MemoryStream())); + Utilities.AssertThrows(() => keys.UnsafeLoad(null, new MemoryStream())); + Utilities.AssertThrows(() => keys.UnsafeLoad(context, null)); + } + { + RelinKeys keys = new RelinKeys(); + SEALContext context = GlobalContext.BGVContext; + + Utilities.AssertThrows(() => keys = new RelinKeys(null)); + Utilities.AssertThrows(() => keys.Set(null)); + + Utilities.AssertThrows(() => ValCheck.IsValidFor(keys, null)); + + Utilities.AssertThrows(() => keys.Save(null)); + + Utilities.AssertThrows(() => keys.Load(context, null)); + Utilities.AssertThrows(() => keys.Load(null, new MemoryStream())); + Utilities.AssertThrows(() => keys.Load(context, new MemoryStream())); + Utilities.AssertThrows(() => keys.UnsafeLoad(null, new MemoryStream())); + Utilities.AssertThrows(() => keys.UnsafeLoad(context, null)); + } } } } diff --git a/dotnet/tests/SEALContextTests.cs b/dotnet/tests/SEALContextTests.cs index 2438ace76..d218503c8 100644 --- a/dotnet/tests/SEALContextTests.cs +++ b/dotnet/tests/SEALContextTests.cs @@ -15,19 +15,26 @@ public void SEALContextCreateTest() { EncryptionParameters encParams1 = new EncryptionParameters(SchemeType.BFV); EncryptionParameters encParams2 = new EncryptionParameters(SchemeType.CKKS); + EncryptionParameters encParams3 = new EncryptionParameters(SchemeType.BGV); SEALContext context1 = new SEALContext(encParams1); SEALContext context2 = new SEALContext(encParams2); + SEALContext context3 = new SEALContext(encParams3); Assert.IsNotNull(context1); Assert.IsNotNull(context2); + Assert.IsNotNull(context3); Assert.IsFalse(context1.ParametersSet); Assert.IsFalse(context2.ParametersSet); + Assert.IsFalse(context3.ParametersSet); Assert.AreNotSame(context1.FirstParmsId, context1.LastParmsId); Assert.AreEqual(context1.FirstParmsId, context1.LastParmsId); + Assert.AreNotSame(context3.FirstParmsId, context3.LastParmsId); + Assert.AreEqual(context3.FirstParmsId, context3.LastParmsId); + SEALContext.ContextData data1 = context2.FirstContextData; SEALContext.ContextData data2 = context2.GetContextData(context2.FirstParmsId); @@ -50,64 +57,126 @@ public void SEALContextCreateTest() [TestMethod] public void SEALContextParamsTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 128, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30 }) - }; - SEALContext context = new SEALContext(parms, expandModChain: true, secLevel: SecLevelType.None); - - SEALContext.ContextData data = context.KeyContextData; - Assert.IsNotNull(data); - - EncryptionParameters parms2 = data.Parms; - Assert.AreEqual(parms.PolyModulusDegree, parms2.PolyModulusDegree); - - EncryptionParameterQualifiers qualifiers = data.Qualifiers; - Assert.IsNotNull(qualifiers); - - Assert.IsTrue(qualifiers.ParametersSet); - Assert.IsFalse(qualifiers.UsingBatching); - Assert.IsTrue(qualifiers.UsingFastPlainLift); - Assert.IsTrue(qualifiers.UsingFFT); - Assert.IsTrue(qualifiers.UsingNTT); - Assert.AreEqual(SecLevelType.None, qualifiers.SecLevel); - Assert.IsFalse(qualifiers.UsingDescendingModulusChain); - Assert.IsTrue(context.UsingKeyswitching); - - ulong[] cdpm = data.CoeffDivPlainModulus; - Assert.AreEqual(3, cdpm.Length); - - Assert.AreEqual(32ul, data.PlainUpperHalfThreshold); - - Assert.AreEqual(3, data.PlainUpperHalfIncrement.Length); - Assert.IsNull(data.UpperHalfThreshold); - Assert.IsNotNull(data.UpperHalfIncrement); - Assert.AreEqual(3, data.UpperHalfIncrement.Length); - Assert.AreEqual(2ul, data.ChainIndex); - - Assert.IsNull(data.PrevContextData); - SEALContext.ContextData data2 = data.NextContextData; - Assert.IsNotNull(data2); - Assert.AreEqual(1ul, data2.ChainIndex); - Assert.AreEqual(2ul, data2.PrevContextData.ChainIndex); - - SEALContext.ContextData data3 = data2.NextContextData; - Assert.IsNotNull(data3); - Assert.AreEqual(0ul, data3.ChainIndex); - Assert.AreEqual(1ul, data3.PrevContextData.ChainIndex); - Assert.IsNull(data3.NextContextData); - - parms = new EncryptionParameters(SchemeType.BFV) + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30 }) + }; + SEALContext context = new SEALContext(parms, expandModChain: true, secLevel: SecLevelType.None); + + SEALContext.ContextData data = context.KeyContextData; + Assert.IsNotNull(data); + + EncryptionParameters parms2 = data.Parms; + Assert.AreEqual(parms.PolyModulusDegree, parms2.PolyModulusDegree); + + EncryptionParameterQualifiers qualifiers = data.Qualifiers; + Assert.IsNotNull(qualifiers); + + Assert.IsTrue(qualifiers.ParametersSet); + Assert.IsFalse(qualifiers.UsingBatching); + Assert.IsTrue(qualifiers.UsingFastPlainLift); + Assert.IsTrue(qualifiers.UsingFFT); + Assert.IsTrue(qualifiers.UsingNTT); + Assert.AreEqual(SecLevelType.None, qualifiers.SecLevel); + Assert.IsFalse(qualifiers.UsingDescendingModulusChain); + Assert.IsTrue(context.UsingKeyswitching); + + ulong[] cdpm = data.CoeffDivPlainModulus; + Assert.AreEqual(3, cdpm.Length); + + Assert.AreEqual(32ul, data.PlainUpperHalfThreshold); + + Assert.AreEqual(3, data.PlainUpperHalfIncrement.Length); + Assert.IsNull(data.UpperHalfThreshold); + Assert.IsNotNull(data.UpperHalfIncrement); + Assert.AreEqual(3, data.UpperHalfIncrement.Length); + Assert.AreEqual(2ul, data.ChainIndex); + + Assert.IsNull(data.PrevContextData); + SEALContext.ContextData data2 = data.NextContextData; + Assert.IsNotNull(data2); + Assert.AreEqual(1ul, data2.ChainIndex); + Assert.AreEqual(2ul, data2.PrevContextData.ChainIndex); + + SEALContext.ContextData data3 = data2.NextContextData; + Assert.IsNotNull(data3); + Assert.AreEqual(0ul, data3.ChainIndex); + Assert.AreEqual(1ul, data3.PrevContextData.ChainIndex); + Assert.IsNull(data3.NextContextData); + + parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 127, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30 }) + }; + context = new SEALContext(parms, expandModChain: true, secLevel: SecLevelType.None); + Assert.AreEqual(context.ParameterErrorName(), "invalid_poly_modulus_degree_non_power_of_two"); + Assert.AreEqual(context.ParameterErrorMessage(), "poly_modulus_degree is not a power of two"); + } { - PolyModulusDegree = 127, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30 }) - }; - context = new SEALContext(parms, expandModChain: true, secLevel: SecLevelType.None); - Assert.AreEqual(context.ParameterErrorName(), "invalid_poly_modulus_degree_non_power_of_two"); - Assert.AreEqual(context.ParameterErrorMessage(), "poly_modulus_degree is not a power of two"); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 128, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30 }) + }; + SEALContext context = new SEALContext(parms, expandModChain: true, secLevel: SecLevelType.None); + + SEALContext.ContextData data = context.KeyContextData; + Assert.IsNotNull(data); + + EncryptionParameters parms2 = data.Parms; + Assert.AreEqual(parms.PolyModulusDegree, parms2.PolyModulusDegree); + + EncryptionParameterQualifiers qualifiers = data.Qualifiers; + Assert.IsNotNull(qualifiers); + + Assert.IsTrue(qualifiers.ParametersSet); + Assert.IsFalse(qualifiers.UsingBatching); + Assert.IsTrue(qualifiers.UsingFastPlainLift); + Assert.IsTrue(qualifiers.UsingFFT); + Assert.IsTrue(qualifiers.UsingNTT); + Assert.AreEqual(SecLevelType.None, qualifiers.SecLevel); + Assert.IsFalse(qualifiers.UsingDescendingModulusChain); + Assert.IsTrue(context.UsingKeyswitching); + + ulong[] cdpm = data.CoeffDivPlainModulus; + Assert.AreEqual(3, cdpm.Length); + + Assert.AreEqual(32ul, data.PlainUpperHalfThreshold); + + Assert.AreEqual(3, data.PlainUpperHalfIncrement.Length); + Assert.IsNull(data.UpperHalfThreshold); + Assert.IsNotNull(data.UpperHalfIncrement); + Assert.AreEqual(3, data.UpperHalfIncrement.Length); + Assert.AreEqual(2ul, data.ChainIndex); + + Assert.IsNull(data.PrevContextData); + SEALContext.ContextData data2 = data.NextContextData; + Assert.IsNotNull(data2); + Assert.AreEqual(1ul, data2.ChainIndex); + Assert.AreEqual(2ul, data2.PrevContextData.ChainIndex); + + SEALContext.ContextData data3 = data2.NextContextData; + Assert.IsNotNull(data3); + Assert.AreEqual(0ul, data3.ChainIndex); + Assert.AreEqual(1ul, data3.PrevContextData.ChainIndex); + Assert.IsNull(data3.NextContextData); + + parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 127, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(128, new int[] { 30, 30, 30 }) + }; + context = new SEALContext(parms, expandModChain: true, secLevel: SecLevelType.None); + Assert.AreEqual(context.ParameterErrorName(), "invalid_poly_modulus_degree_non_power_of_two"); + Assert.AreEqual(context.ParameterErrorMessage(), "poly_modulus_degree is not a power of two"); + } } [TestMethod] @@ -154,37 +223,72 @@ public void SEALContextCKKSParamsTest() [TestMethod] public void ExpandModChainTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 4096, - CoeffModulus = CoeffModulus.BFVDefault(polyModulusDegree: 4096), - PlainModulus = new Modulus(1 << 20) - }; - - SEALContext context1 = new SEALContext(parms, - expandModChain: true, - secLevel: SecLevelType.None); - - // By default there is a chain - SEALContext.ContextData contextData = context1.KeyContextData; - Assert.IsNotNull(contextData); - Assert.IsNull(contextData.PrevContextData); - Assert.IsNotNull(contextData.NextContextData); - contextData = context1.FirstContextData; - Assert.IsNotNull(contextData); - Assert.IsNotNull(contextData.PrevContextData); - Assert.IsNotNull(contextData.NextContextData); - - // This should not create a chain - SEALContext context2 = new SEALContext(parms, expandModChain: false); - contextData = context2.KeyContextData; - Assert.IsNotNull(contextData); - Assert.IsNull(contextData.PrevContextData); - Assert.IsNotNull(contextData.NextContextData); - contextData = context2.FirstContextData; - Assert.IsNotNull(contextData); - Assert.IsNotNull(contextData.PrevContextData); - Assert.IsNull(contextData.NextContextData); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 4096, + CoeffModulus = CoeffModulus.BFVDefault(polyModulusDegree: 4096), + PlainModulus = new Modulus(1 << 20) + }; + + SEALContext context1 = new SEALContext(parms, + expandModChain: true, + secLevel: SecLevelType.None); + + // By default there is a chain + SEALContext.ContextData contextData = context1.KeyContextData; + Assert.IsNotNull(contextData); + Assert.IsNull(contextData.PrevContextData); + Assert.IsNotNull(contextData.NextContextData); + contextData = context1.FirstContextData; + Assert.IsNotNull(contextData); + Assert.IsNotNull(contextData.PrevContextData); + Assert.IsNotNull(contextData.NextContextData); + + // This should not create a chain + SEALContext context2 = new SEALContext(parms, expandModChain: false); + contextData = context2.KeyContextData; + Assert.IsNotNull(contextData); + Assert.IsNull(contextData.PrevContextData); + Assert.IsNotNull(contextData.NextContextData); + contextData = context2.FirstContextData; + Assert.IsNotNull(contextData); + Assert.IsNotNull(contextData.PrevContextData); + Assert.IsNull(contextData.NextContextData); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 4096, + CoeffModulus = CoeffModulus.BGVDefault(polyModulusDegree: 4096), + PlainModulus = new Modulus(1 << 20) + }; + + SEALContext context1 = new SEALContext(parms, + expandModChain: true, + secLevel: SecLevelType.None); + + // By default there is a chain + SEALContext.ContextData contextData = context1.KeyContextData; + Assert.IsNotNull(contextData); + Assert.IsNull(contextData.PrevContextData); + Assert.IsNotNull(contextData.NextContextData); + contextData = context1.FirstContextData; + Assert.IsNotNull(contextData); + Assert.IsNotNull(contextData.PrevContextData); + Assert.IsNotNull(contextData.NextContextData); + + // This should not create a chain + SEALContext context2 = new SEALContext(parms, expandModChain: false); + contextData = context2.KeyContextData; + Assert.IsNotNull(contextData); + Assert.IsNull(contextData.PrevContextData); + Assert.IsNotNull(contextData.NextContextData); + contextData = context2.FirstContextData; + Assert.IsNotNull(contextData); + Assert.IsNotNull(contextData.PrevContextData); + Assert.IsNull(contextData.NextContextData); + } } } } diff --git a/dotnet/tests/SecretKeyTests.cs b/dotnet/tests/SecretKeyTests.cs index 968b47f2a..fb2e66e4b 100644 --- a/dotnet/tests/SecretKeyTests.cs +++ b/dotnet/tests/SecretKeyTests.cs @@ -15,88 +15,172 @@ public class SecretKeyTests [TestMethod] public void CreateTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 64, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - - SecretKey secret = keygen.SecretKey; - SecretKey copy = new SecretKey(secret); - - Assert.AreEqual(64ul, copy.Data.CoeffCount); - Assert.IsTrue(copy.Data.IsNTTForm); - - SecretKey copy2 = new SecretKey(); - copy2.Set(copy); - - Assert.AreEqual(64ul, copy2.Data.CoeffCount); - Assert.IsTrue(copy2.Data.IsNTTForm); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + + SecretKey secret = keygen.SecretKey; + SecretKey copy = new SecretKey(secret); + + Assert.AreEqual(64ul, copy.Data.CoeffCount); + Assert.IsTrue(copy.Data.IsNTTForm); + + SecretKey copy2 = new SecretKey(); + copy2.Set(copy); + + Assert.AreEqual(64ul, copy2.Data.CoeffCount); + Assert.IsTrue(copy2.Data.IsNTTForm); + } + { + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + + SecretKey secret = keygen.SecretKey; + SecretKey copy = new SecretKey(secret); + + Assert.AreEqual(64ul, copy.Data.CoeffCount); + Assert.IsTrue(copy.Data.IsNTTForm); + + SecretKey copy2 = new SecretKey(); + copy2.Set(copy); + + Assert.AreEqual(64ul, copy2.Data.CoeffCount); + Assert.IsTrue(copy2.Data.IsNTTForm); + } } [TestMethod] public void SaveLoadTest() { - EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) { - PolyModulusDegree = 64, - PlainModulus = new Modulus(1 << 6), - CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) - }; - SEALContext context = new SEALContext(parms, - expandModChain: false, - secLevel: SecLevelType.None); - KeyGenerator keygen = new KeyGenerator(context); - - SecretKey secret = keygen.SecretKey; - - Assert.AreEqual(64ul, secret.Data.CoeffCount); - Assert.IsTrue(secret.Data.IsNTTForm); - Assert.AreNotEqual(ParmsId.Zero, secret.ParmsId); - - SecretKey secret2 = new SecretKey(); - Assert.IsNotNull(secret2); - Assert.AreEqual(0ul, secret2.Data.CoeffCount); - Assert.IsFalse(secret2.Data.IsNTTForm); - - using (MemoryStream stream = new MemoryStream()) + EncryptionParameters parms = new EncryptionParameters(SchemeType.BFV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + + SecretKey secret = keygen.SecretKey; + + Assert.AreEqual(64ul, secret.Data.CoeffCount); + Assert.IsTrue(secret.Data.IsNTTForm); + Assert.AreNotEqual(ParmsId.Zero, secret.ParmsId); + + SecretKey secret2 = new SecretKey(); + Assert.IsNotNull(secret2); + Assert.AreEqual(0ul, secret2.Data.CoeffCount); + Assert.IsFalse(secret2.Data.IsNTTForm); + + using (MemoryStream stream = new MemoryStream()) + { + secret.Save(stream); + stream.Seek(offset: 0, loc: SeekOrigin.Begin); + secret2.Load(context, stream); + } + + Assert.AreNotSame(secret, secret2); + Assert.AreEqual(64ul, secret2.Data.CoeffCount); + Assert.IsTrue(secret2.Data.IsNTTForm); + Assert.AreNotEqual(ParmsId.Zero, secret2.ParmsId); + Assert.AreEqual(secret.ParmsId, secret2.ParmsId); + } { - secret.Save(stream); - stream.Seek(offset: 0, loc: SeekOrigin.Begin); - secret2.Load(context, stream); + EncryptionParameters parms = new EncryptionParameters(SchemeType.BGV) + { + PolyModulusDegree = 64, + PlainModulus = new Modulus(1 << 6), + CoeffModulus = CoeffModulus.Create(64, new int[] { 40 }) + }; + SEALContext context = new SEALContext(parms, + expandModChain: false, + secLevel: SecLevelType.None); + KeyGenerator keygen = new KeyGenerator(context); + + SecretKey secret = keygen.SecretKey; + + Assert.AreEqual(64ul, secret.Data.CoeffCount); + Assert.IsTrue(secret.Data.IsNTTForm); + Assert.AreNotEqual(ParmsId.Zero, secret.ParmsId); + + SecretKey secret2 = new SecretKey(); + Assert.IsNotNull(secret2); + Assert.AreEqual(0ul, secret2.Data.CoeffCount); + Assert.IsFalse(secret2.Data.IsNTTForm); + + using (MemoryStream stream = new MemoryStream()) + { + secret.Save(stream); + stream.Seek(offset: 0, loc: SeekOrigin.Begin); + secret2.Load(context, stream); + } + + Assert.AreNotSame(secret, secret2); + Assert.AreEqual(64ul, secret2.Data.CoeffCount); + Assert.IsTrue(secret2.Data.IsNTTForm); + Assert.AreNotEqual(ParmsId.Zero, secret2.ParmsId); + Assert.AreEqual(secret.ParmsId, secret2.ParmsId); } - - Assert.AreNotSame(secret, secret2); - Assert.AreEqual(64ul, secret2.Data.CoeffCount); - Assert.IsTrue(secret2.Data.IsNTTForm); - Assert.AreNotEqual(ParmsId.Zero, secret2.ParmsId); - Assert.AreEqual(secret.ParmsId, secret2.ParmsId); } [TestMethod] public void ExceptionsTest() { - SEALContext context = GlobalContext.BFVContext; - SecretKey key = new SecretKey(); + { + SEALContext context = GlobalContext.BFVContext; + SecretKey key = new SecretKey(); + + Utilities.AssertThrows(() => key = new SecretKey(null)); - Utilities.AssertThrows(() => key = new SecretKey(null)); + Utilities.AssertThrows(() => key.Set(null)); - Utilities.AssertThrows(() => key.Set(null)); + Utilities.AssertThrows(() => ValCheck.IsValidFor(key, null)); - Utilities.AssertThrows(() => ValCheck.IsValidFor(key, null)); + Utilities.AssertThrows(() => key.Save(null)); - Utilities.AssertThrows(() => key.Save(null)); + Utilities.AssertThrows(() => key.UnsafeLoad(null, new MemoryStream())); + Utilities.AssertThrows(() => key.UnsafeLoad(context, null)); + Utilities.AssertThrows(() => key.Load(context, null)); + Utilities.AssertThrows(() => key.Load(null, new MemoryStream())); + Utilities.AssertThrows(() => key.Load(context, new MemoryStream())); + } + { + SEALContext context = GlobalContext.BGVContext; + SecretKey key = new SecretKey(); + + Utilities.AssertThrows(() => key = new SecretKey(null)); + + Utilities.AssertThrows(() => key.Set(null)); - Utilities.AssertThrows(() => key.UnsafeLoad(null, new MemoryStream())); - Utilities.AssertThrows(() => key.UnsafeLoad(context, null)); - Utilities.AssertThrows(() => key.Load(context, null)); - Utilities.AssertThrows(() => key.Load(null, new MemoryStream())); - Utilities.AssertThrows(() => key.Load(context, new MemoryStream())); + Utilities.AssertThrows(() => ValCheck.IsValidFor(key, null)); + + Utilities.AssertThrows(() => key.Save(null)); + + Utilities.AssertThrows(() => key.UnsafeLoad(null, new MemoryStream())); + Utilities.AssertThrows(() => key.UnsafeLoad(context, null)); + Utilities.AssertThrows(() => key.Load(context, null)); + Utilities.AssertThrows(() => key.Load(null, new MemoryStream())); + Utilities.AssertThrows(() => key.Load(context, new MemoryStream())); + } } } } diff --git a/dotnet/tests/SerializationTests.cs b/dotnet/tests/SerializationTests.cs index f9d621a80..c59a7a6de 100644 --- a/dotnet/tests/SerializationTests.cs +++ b/dotnet/tests/SerializationTests.cs @@ -90,24 +90,47 @@ public void SEALHeaderUpgrade() [TestMethod] public void ExceptionsTest() { - SEALContext context = GlobalContext.BFVContext; - Ciphertext cipher = new Ciphertext(); + { + SEALContext context = GlobalContext.BFVContext; + Ciphertext cipher = new Ciphertext(); - using (MemoryStream mem = new MemoryStream()) + using (MemoryStream mem = new MemoryStream()) + { + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + Encryptor encryptor = new Encryptor(context, publicKey); + + Plaintext plain = new Plaintext("2x^3 + 4x^2 + 5x^1 + 6"); + encryptor.Encrypt(plain, cipher); + cipher.Save(mem); + mem.Seek(offset: 8, loc: SeekOrigin.Begin); + BinaryWriter writer = new BinaryWriter(mem, Encoding.UTF8, true); + writer.Write((ulong)0x80000000); + + mem.Seek(offset: 0, loc: SeekOrigin.Begin); + Utilities.AssertThrows(() => cipher.Load(context, mem)); + } + } { - KeyGenerator keygen = new KeyGenerator(context); - keygen.CreatePublicKey(out PublicKey publicKey); - Encryptor encryptor = new Encryptor(context, publicKey); + SEALContext context = GlobalContext.BGVContext; + Ciphertext cipher = new Ciphertext(); - Plaintext plain = new Plaintext("2x^3 + 4x^2 + 5x^1 + 6"); - encryptor.Encrypt(plain, cipher); - cipher.Save(mem); - mem.Seek(offset: 8, loc: SeekOrigin.Begin); - BinaryWriter writer = new BinaryWriter(mem, Encoding.UTF8, true); - writer.Write((ulong)0x80000000); + using (MemoryStream mem = new MemoryStream()) + { + KeyGenerator keygen = new KeyGenerator(context); + keygen.CreatePublicKey(out PublicKey publicKey); + Encryptor encryptor = new Encryptor(context, publicKey); - mem.Seek(offset: 0, loc: SeekOrigin.Begin); - Utilities.AssertThrows(() => cipher.Load(context, mem)); + Plaintext plain = new Plaintext("2x^3 + 4x^2 + 5x^1 + 6"); + encryptor.Encrypt(plain, cipher); + cipher.Save(mem); + mem.Seek(offset: 8, loc: SeekOrigin.Begin); + BinaryWriter writer = new BinaryWriter(mem, Encoding.UTF8, true); + writer.Write((ulong)0x80000000); + + mem.Seek(offset: 0, loc: SeekOrigin.Begin); + Utilities.AssertThrows(() => cipher.Load(context, mem)); + } } } } diff --git a/native/src/seal/batchencoder.cpp b/native/src/seal/batchencoder.cpp index bd553bf3e..1c8779d9b 100644 --- a/native/src/seal/batchencoder.cpp +++ b/native/src/seal/batchencoder.cpp @@ -23,7 +23,7 @@ namespace seal } auto &context_data = *context_.first_context_data(); - if (context_data.parms().scheme() != scheme_type::bfv) + if (context_data.parms().scheme() != scheme_type::bfv && context_data.parms().scheme() != scheme_type::bgv) { throw invalid_argument("unsupported scheme"); } diff --git a/native/src/seal/ciphertext.cpp b/native/src/seal/ciphertext.cpp index 7541f0d78..393dac56d 100644 --- a/native/src/seal/ciphertext.cpp +++ b/native/src/seal/ciphertext.cpp @@ -322,6 +322,18 @@ namespace seal // Set up a UniformRandomGenerator and expand new_data.data_.resize(total_uint64_count); new_data.expand_seed(context, prng_info, version); + + // In BGV, c1 = -A + auto parms = context.get_context_data(parms_id)->parms(); + if(parms.scheme() == scheme_type::bgv){ + uint64_t *c1 = new_data.data(1); + auto coeff_count = parms.poly_modulus_degree(); + auto coeff_modulus = parms.coeff_modulus(); + size_t coeff_modulus_size = parms.coeff_modulus().size(); + for (size_t i = 0; i < coeff_modulus_size; i++){ + negate_poly_coeffmod(c1 + i * coeff_count, coeff_count, coeff_modulus[i], c1 + i * coeff_count); + } + } } // Verify that the buffer is correct diff --git a/native/src/seal/context.cpp b/native/src/seal/context.cpp index c14b2b9c7..887a1312c 100644 --- a/native/src/seal/context.cpp +++ b/native/src/seal/context.cpp @@ -86,7 +86,7 @@ namespace seal return "valid"; case error_type::invalid_scheme: - return "scheme must be BFV or CKKS"; + return "scheme must be BFV or CKKS or BGV"; case error_type::invalid_coeff_modulus_size: return "coeff_modulus's primes' count is not bounded by SEAL_COEFF_MOD_COUNT_MIN(MAX)"; @@ -250,7 +250,7 @@ namespace seal return context_data; } - if (parms.scheme() == scheme_type::bfv) + if (parms.scheme() == scheme_type::bfv || parms.scheme() == scheme_type::bgv) { // Plain modulus must be at least 2 and at most 60 bits if (plain_modulus.value() >> SEAL_PLAIN_MOD_BIT_COUNT_MAX || diff --git a/native/src/seal/context.h b/native/src/seal/context.h index df0056467..4f971f76d 100644 --- a/native/src/seal/context.h +++ b/native/src/seal/context.h @@ -44,7 +44,7 @@ namespace seal success = 0, /** - scheme must be BFV or CKKS + scheme must be BFV or CKKS or BGV */ invalid_scheme = 1, diff --git a/native/src/seal/decryptor.cpp b/native/src/seal/decryptor.cpp index 585b721d1..ad7776b86 100644 --- a/native/src/seal/decryptor.cpp +++ b/native/src/seal/decryptor.cpp @@ -101,10 +101,15 @@ namespace seal case scheme_type::ckks: ckks_decrypt(encrypted, destination, pool_); return; + + case scheme_type::bgv: + bgv_decrypt(encrypted, destination, pool_); + return; default: throw invalid_argument("unsupported scheme"); } + } void Decryptor::bfv_decrypt(const Ciphertext &encrypted, Plaintext &destination, MemoryPoolHandle pool) @@ -182,6 +187,51 @@ namespace seal destination.scale() = encrypted.scale(); } + void Decryptor::bgv_decrypt(const Ciphertext &encrypted, Plaintext &destination, MemoryPoolHandle pool) + { + if (encrypted.is_ntt_form()) + { + throw invalid_argument("encrypted cannot be in NTT form"); + } + + auto &context_data = *context_.get_context_data(encrypted.parms_id()); + auto &first_context_data = *context_.first_context_data(); + auto &parms = context_data.parms(); + auto &first_parms = first_context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + auto &first_coeff_modulus = first_parms.coeff_modulus(); + auto &plain_modulus = parms.plain_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_modulus_size = coeff_modulus.size(); + + SEAL_ALLOCATE_ZERO_GET_RNS_ITER(tmp_dest_modq, coeff_count, coeff_modulus_size, pool); + + dot_product_ct_sk_array(encrypted, tmp_dest_modq, pool_); + + destination.parms_id() = parms_id_zero; + destination.resize(coeff_count); + + context_data.rns_tool()->decrypt_modt(tmp_dest_modq, destination.data(), pool); + + //Fix the plaintext after mod-switch operations. + uint64_t fix = 1; + for(size_t i = context_data.chain_index(); i < first_context_data.chain_index(); i++) + { + auto scalar = barrett_reduce_64(first_coeff_modulus[i+1].value(), plain_modulus); + fix = multiply_uint_mod(fix, scalar, plain_modulus); + } + if(fix != 1) + { + multiply_poly_scalar_coeffmod(CoeffIter(destination.data()), coeff_count, fix, plain_modulus, CoeffIter(destination.data())); + } + + // How many non-zero coefficients do we really have in the result? + size_t plain_coeff_count = get_significant_uint64_count_uint(destination.data(), coeff_count); + + // Resize destination to appropriate size + destination.resize(max(plain_coeff_count, size_t(1))); + } + void Decryptor::compute_secret_key_array(size_t max_power) { #ifdef SEAL_DEBUG @@ -313,7 +363,8 @@ namespace seal throw invalid_argument("encrypted is empty"); } - if (context_.key_context_data()->parms().scheme() != scheme_type::bfv) + auto scheme = context_.key_context_data()->parms().scheme(); + if (scheme!= scheme_type::bfv && scheme != scheme_type::bgv) { throw logic_error("unsupported scheme"); } @@ -346,7 +397,10 @@ namespace seal // Multiply by plain_modulus and reduce mod coeff_modulus to get // coeff_modulus()*noise. - multiply_poly_scalar_coeffmod(noise_poly, coeff_modulus_size, plain_modulus.value(), coeff_modulus, noise_poly); + if(scheme == scheme_type::bfv) + { + multiply_poly_scalar_coeffmod(noise_poly, coeff_modulus_size, plain_modulus.value(), coeff_modulus, noise_poly); + } // CRT-compose the noise context_data.rns_tool()->base_q()->compose_array(noise_poly, coeff_count, pool_); diff --git a/native/src/seal/decryptor.h b/native/src/seal/decryptor.h index 0384a10c8..302a812f6 100644 --- a/native/src/seal/decryptor.h +++ b/native/src/seal/decryptor.h @@ -104,6 +104,8 @@ namespace seal void ckks_decrypt(const Ciphertext &encrypted, Plaintext &destination, MemoryPoolHandle pool); + void bgv_decrypt(const Ciphertext &encrypted, Plaintext &destination, MemoryPoolHandle pool); + Decryptor(const Decryptor ©) = delete; Decryptor(Decryptor &&source) = delete; diff --git a/native/src/seal/encryptionparams.cpp b/native/src/seal/encryptionparams.cpp index 923765b00..31e074411 100644 --- a/native/src/seal/encryptionparams.cpp +++ b/native/src/seal/encryptionparams.cpp @@ -32,7 +32,7 @@ namespace seal mod.save(stream, compr_mode_type::none); } - // Only BFV uses plain_modulus but save it in any case for simplicity + // Only BFV and BGV uses plain_modulus but save it in any case for simplicity plain_modulus_.save(stream, compr_mode_type::none); } catch (const ios_base::failure &) @@ -99,7 +99,7 @@ namespace seal parms.set_poly_modulus_degree(safe_cast(poly_modulus_degree64)); parms.set_coeff_modulus(coeff_modulus); - // Only BFV uses plain_modulus; set_plain_modulus checks that for + // Only BFV and BGV uses plain_modulus; set_plain_modulus checks that for // other schemes it is zero parms.set_plain_modulus(plain_modulus); diff --git a/native/src/seal/encryptionparams.h b/native/src/seal/encryptionparams.h index 09bbdbf78..23dd2d5f8 100644 --- a/native/src/seal/encryptionparams.h +++ b/native/src/seal/encryptionparams.h @@ -31,7 +31,10 @@ namespace seal bfv = 0x1, // Cheon-Kim-Kim-Song scheme - ckks = 0x2 + ckks = 0x2, + + //Brakerski-Gentry-Vaikuntanathan scheme + bgv = 0x3 }; /** @@ -222,7 +225,7 @@ namespace seal inline void set_plain_modulus(const Modulus &plain_modulus) { // Check that scheme is BFV - if (scheme_ != scheme_type::bfv && !plain_modulus.is_zero()) + if (scheme_ != scheme_type::bfv && scheme_ != scheme_type::bgv && !plain_modulus.is_zero()) { throw std::logic_error("plain_modulus is not supported for this scheme"); } @@ -466,7 +469,11 @@ namespace seal /* fall through */ case static_cast(scheme_type::ckks): + /* fall through */ + + case static_cast(scheme_type::bgv): return true; + } return false; } diff --git a/native/src/seal/encryptor.cpp b/native/src/seal/encryptor.cpp index e826c5c06..d69d5af7f 100644 --- a/native/src/seal/encryptor.cpp +++ b/native/src/seal/encryptor.cpp @@ -111,7 +111,7 @@ namespace seal { is_ntt_form = true; } - else if (parms.scheme() != scheme_type::bfv) + else if (parms.scheme() != scheme_type::bfv && parms.scheme() != scheme_type::bgv) { throw invalid_argument("unsupported scheme"); } @@ -141,10 +141,16 @@ namespace seal rns_tool->divide_and_round_q_last_ntt_inplace( get<0>(I), prev_context_data.small_ntt_tables(), pool); } - else + // bfv switch-to-next + else if (parms.scheme() != scheme_type::bgv) { rns_tool->divide_and_round_q_last_inplace(get<0>(I), pool); } + // bgv switch-to-next + else + { + rns_tool->mod_t_and_divide_q_last_inplace(get<0>(I), pool); + } set_poly(get<0>(I), coeff_count, coeff_modulus_size, get<1>(I)); }); @@ -231,6 +237,20 @@ namespace seal destination.scale() = plain.scale(); } + else if (scheme == scheme_type::bgv) + { + if (plain.is_ntt_form()) + { + throw invalid_argument("plain cannot be in NTT form"); + } + encrypt_zero_internal(context_.first_parms_id(), is_asymmetric, save_seed, destination, pool); + auto context_data_ptr = context_.first_context_data(); + auto &parms = context_data_ptr->parms(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_modulus_size = parms.coeff_modulus().size(); + // c_{0} = pk_{0}*u + p*e_{0} + M + add_plain_without_scaling_variant(plain, *context_data_ptr, RNSIter(destination.data(0), coeff_count)); + } else { throw invalid_argument("unsupported scheme"); diff --git a/native/src/seal/encryptor.h b/native/src/seal/encryptor.h index 589ccedd3..f911db5b5 100644 --- a/native/src/seal/encryptor.h +++ b/native/src/seal/encryptor.h @@ -37,7 +37,7 @@ namespace seal performance bottlenecks. @par NTT form - When using the BFV scheme (scheme_type::bfv), all plaintext and ciphertexts should + When using the BFV/BGV scheme (scheme_type::bfv/bgv), all plaintext and ciphertexts should remain by default in the usual coefficient representation, i.e. not in NTT form. When using the CKKS scheme (scheme_type::ckks), all plaintexts and ciphertexts should remain by default in NTT form. We call these scheme-specific NTT states @@ -116,7 +116,7 @@ namespace seal destination. The encryption parameters for the resulting ciphertext correspond to: - 1) in BFV, the highest (data) level in the modulus switching chain, + 1) in BFV/BGV, the highest (data) level in the modulus switching chain, 2) in CKKS, the encryption parameters of the plaintext. Dynamic memory allocations in the process are allocated from the memory pool pointed to by the given MemoryPoolHandle. @@ -142,7 +142,7 @@ namespace seal a serializable object. The encryption parameters for the resulting ciphertext correspond to: - 1) in BFV, the highest (data) level in the modulus switching chain, + 1) in BFV/BGV, the highest (data) level in the modulus switching chain, 2) in CKKS, the encryption parameters of the plaintext. Dynamic memory allocations in the process are allocated from the memory pool pointed to by the given MemoryPoolHandle. @@ -253,7 +253,7 @@ namespace seal destination. The encryption parameters for the resulting ciphertext correspond to: - 1) in BFV, the highest (data) level in the modulus switching chain, + 1) in BFV/BGV, the highest (data) level in the modulus switching chain, 2) in CKKS, the encryption parameters of the plaintext. Dynamic memory allocations in the process are allocated from the memory pool pointed to by the given MemoryPoolHandle. @@ -284,7 +284,7 @@ namespace seal impact. The encryption parameters for the resulting ciphertext correspond to: - 1) in BFV, the highest (data) level in the modulus switching chain, + 1) in BFV/BGV, the highest (data) level in the modulus switching chain, 2) in CKKS, the encryption parameters of the plaintext. Dynamic memory allocations in the process are allocated from the memory pool pointed to by the given MemoryPoolHandle. diff --git a/native/src/seal/evaluator.cpp b/native/src/seal/evaluator.cpp index 8f4a4dfdb..2808a672f 100644 --- a/native/src/seal/evaluator.cpp +++ b/native/src/seal/evaluator.cpp @@ -33,6 +33,7 @@ namespace seal switch (context_data.parms().scheme()) { case scheme_type::bfv: + case scheme_type::bgv: scale_bit_count_bound = context_data.parms().plain_modulus().bit_count(); break; case scheme_type::ckks: @@ -272,6 +273,10 @@ namespace seal ckks_multiply(encrypted1, encrypted2, pool); break; + case scheme_type::bgv: + bgv_multiply(encrypted1, const_cast(encrypted2), pool); + break; + default: throw invalid_argument("unsupported scheme"); } @@ -549,6 +554,86 @@ namespace seal encrypted1.scale() = new_scale; } + void Evaluator::bgv_multiply(Ciphertext &encrypted1, Ciphertext &encrypted2, MemoryPoolHandle pool) + { + if (encrypted1.is_ntt_form() || encrypted2.is_ntt_form()) + { + throw invalid_argument("encryped1 or encrypted2 must be not in NTT form"); + } + + // Extract encryption parameters. + auto &context_data = *context_.get_context_data(encrypted1.parms_id()); + auto &parms = context_data.parms(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_modulus_size = parms.coeff_modulus().size(); + size_t encrypted1_size = encrypted1.size(); + size_t encrypted2_size = encrypted2.size(); + auto ntt_table = context_data.small_ntt_tables(); + + // Determine destination.size() + // Default is 3 (c_0, c_1, c_2) + size_t dest_size = sub_safe(add_safe(encrypted1_size, encrypted2_size), size_t(1)); + + // Set up iterator for the base + auto coeff_modulus = iter(parms.coeff_modulus()); + + // Prepare destination + encrypted1.resize(context_, context_data.parms_id(), dest_size); + + // convert c0 and c1 to ntt + ntt_negacyclic_harvey(encrypted1, encrypted1_size, ntt_table); + if (&encrypted1 != &encrypted2) + { + ntt_negacyclic_harvey(encrypted2, encrypted2_size, ntt_table); + } + + // Set up iterators for input ciphertexts + auto encrypted1_iter = iter(encrypted1); + auto encrypted2_iter = iter(encrypted2); + + // Allocate temporary space for the result + SEAL_ALLOCATE_ZERO_GET_POLY_ITER(temp, dest_size, coeff_count, coeff_modulus_size, pool); + + SEAL_ITERATE(iter(size_t(0)), dest_size, [&](auto I) { + // We iterate over relevant components of encrypted1 and encrypted2 in increasing order for + // encrypted1 and reversed (decreasing) order for encrypted2. The bounds for the indices of + // the relevant terms are obtained as follows. + size_t curr_encrypted1_last = min(I, encrypted1_size - 1); + size_t curr_encrypted2_first = min(I, encrypted2_size - 1); + size_t curr_encrypted1_first = I - curr_encrypted2_first; + // size_t curr_encrypted2_last = secret_power_index - curr_encrypted1_last; + + // The total number of dyadic products is now easy to compute + size_t steps = curr_encrypted1_last - curr_encrypted1_first + 1; + + // Create a shifted iterator for the first input + auto shifted_encrypted1_iter = encrypted1_iter + curr_encrypted1_first; + + // Create a shifted reverse iterator for the second input + auto shifted_reversed_encrypted2_iter = reverse_iter(encrypted2_iter + curr_encrypted2_first); + + SEAL_ITERATE(iter(shifted_encrypted1_iter, shifted_reversed_encrypted2_iter), steps, [&](auto J) { + // Extra care needed here: + // temp_iter must be dereferenced once to produce an appropriate RNSIter + SEAL_ITERATE(iter(J, coeff_modulus, temp[I]), coeff_modulus_size, [&](auto K) { + SEAL_ALLOCATE_GET_COEFF_ITER(prod, coeff_count, pool); + dyadic_product_coeffmod(get<0, 0>(K), get<0, 1>(K), coeff_count, get<1>(K), prod); + add_poly_coeffmod(prod, get<2>(K), coeff_count, get<1>(K), get<2>(K)); + }); + }); + }); + + // Set the final result + set_poly_array(temp, dest_size, coeff_count, coeff_modulus_size, encrypted1.data()); + + // Convert the result (and the original ciphertext) back to non-NTT + inverse_ntt_negacyclic_harvey(encrypted1, encrypted1.size(), ntt_table); + if (&encrypted1 != &encrypted2) + { + inverse_ntt_negacyclic_harvey(encrypted2, encrypted2.size(), ntt_table); + } + } + void Evaluator::square_inplace(Ciphertext &encrypted, MemoryPoolHandle pool) { // Verify parameters. @@ -568,6 +653,10 @@ namespace seal ckks_square(encrypted, move(pool)); break; + case scheme_type::bgv: + bgv_square(encrypted, move(pool)); + break; + default: throw invalid_argument("unsupported scheme"); } @@ -799,6 +888,70 @@ namespace seal encrypted.scale() = new_scale; } + void Evaluator::bgv_square(Ciphertext &encrypted, MemoryPoolHandle pool) + { + if (encrypted.is_ntt_form()) + { + throw invalid_argument("encrypted cannot be in NTT form"); + } + + // Extract encryption parameters. + auto &context_data = *context_.get_context_data(encrypted.parms_id()); + auto &parms = context_data.parms(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_modulus_size = parms.coeff_modulus().size(); + size_t encrypted_size = encrypted.size(); + auto ntt_table = context_data.small_ntt_tables(); + + // Optimization implemented currently only for size 2 ciphertexts + if (encrypted_size != 2) + { + bgv_multiply(encrypted, encrypted, move(pool)); + return; + } + + // Determine destination.size() + // Default is 3 (c_0, c_1, c_2) + size_t dest_size = sub_safe(add_safe(encrypted_size, encrypted_size), size_t(1)); + + // Size check + if (!product_fits_in(dest_size, coeff_count, coeff_modulus_size)) + { + throw logic_error("invalid parameters"); + } + + // Set up iterator for the base + auto coeff_modulus = iter(parms.coeff_modulus()); + + // Prepare destination + encrypted.resize(context_, context_data.parms_id(), dest_size); + + // Convert c0 and c1 to ntt + ntt_negacyclic_harvey(encrypted, encrypted_size, ntt_table); + + // Set up iterators for input ciphertext + auto encrypted_iter = iter(encrypted); + + // Allocate temporary space for the result + SEAL_ALLOCATE_ZERO_GET_POLY_ITER(temp, dest_size, coeff_count, coeff_modulus_size, pool); + + // Compute c0^2 + dyadic_product_coeffmod(encrypted_iter[0], encrypted_iter[0], coeff_modulus_size, coeff_modulus, temp[0]); + + // Compute 2*c0*c1 + dyadic_product_coeffmod(encrypted_iter[0], encrypted_iter[1], coeff_modulus_size, coeff_modulus, temp[1]); + add_poly_coeffmod(temp[1], temp[1], coeff_modulus_size, coeff_modulus, temp[1]); + + // Compute c1^2 + dyadic_product_coeffmod(encrypted_iter[1], encrypted_iter[1], coeff_modulus_size, coeff_modulus, temp[2]); + + // Set the final result + set_poly_array(temp, dest_size, coeff_count, coeff_modulus_size, encrypted.data()); + + // Convert the final output to Non-NTT form + inverse_ntt_negacyclic_harvey(encrypted, dest_size, ntt_table); + } + void Evaluator::relinearize_internal( Ciphertext &encrypted, const RelinKeys &relin_keys, size_t destination_size, MemoryPoolHandle pool) { @@ -869,6 +1022,10 @@ namespace seal { throw invalid_argument("CKKS encrypted must be in NTT form"); } + if (context_data_ptr->parms().scheme() == scheme_type::bgv && encrypted.is_ntt_form()) + { + throw invalid_argument("BGV encrypted cannot be in NTT form"); + } if (!pool) { throw invalid_argument("pool is uninitialized"); @@ -901,6 +1058,12 @@ namespace seal }); break; + case scheme_type::bgv: + SEAL_ITERATE(iter(encrypted_copy), encrypted_size, [&](auto I) { + rns_tool->mod_t_and_divide_q_last_inplace(I, pool); + }); + break; + default: throw invalid_argument("unsupported scheme"); } @@ -1050,6 +1213,10 @@ namespace seal mod_switch_drop_to_next(encrypted, destination, move(pool)); break; + case scheme_type::bgv: + mod_switch_scale_to_next(encrypted, destination, move(pool)); + break; + default: throw invalid_argument("unsupported scheme"); } @@ -1133,6 +1300,8 @@ namespace seal switch (context_.first_context_data()->parms().scheme()) { case scheme_type::bfv: + /* Fall through */ + case scheme_type::bgv: throw invalid_argument("unsupported operation for scheme type"); case scheme_type::ckks: @@ -1182,6 +1351,8 @@ namespace seal switch (context_data_ptr->parms().scheme()) { case scheme_type::bfv: + /* Fall through */ + case scheme_type::bgv: throw invalid_argument("unsupported operation for scheme type"); case scheme_type::ckks: @@ -1236,7 +1407,7 @@ namespace seal auto &context_data = *context_data_ptr; auto &parms = context_data.parms(); - if (parms.scheme() != scheme_type::bfv) + if (parms.scheme() != scheme_type::bfv && parms.scheme() != scheme_type::bgv) { throw logic_error("unsupported scheme"); } @@ -1336,6 +1507,10 @@ namespace seal { throw invalid_argument("CKKS encrypted must be in NTT form"); } + if (parms.scheme() == scheme_type::bgv && encrypted.is_ntt_form()) + { + throw invalid_argument("BGV encrypted cannot be in NTT form"); + } if (plain.is_ntt_form() != encrypted.is_ntt_form()) { throw invalid_argument("NTT form mismatch"); @@ -1376,6 +1551,12 @@ namespace seal break; } + case scheme_type::bgv: + { + add_plain_without_scaling_variant(plain, context_data, *iter(encrypted)); + break; + } + default: throw invalid_argument("unsupported scheme"); } @@ -1406,6 +1587,10 @@ namespace seal { throw invalid_argument("BFV encrypted cannot be in NTT form"); } + if (parms.scheme() == scheme_type::bgv && encrypted.is_ntt_form()) + { + throw invalid_argument("BGV encrypted cannot be in NTT form"); + } if (parms.scheme() == scheme_type::ckks && !encrypted.is_ntt_form()) { throw invalid_argument("CKKS encrypted must be in NTT form"); @@ -1450,6 +1635,12 @@ namespace seal break; } + case scheme_type::bgv: + { + sub_plain_without_scaling_variant(plain, context_data, *iter(encrypted)); + break; + } + default: throw invalid_argument("unsupported scheme"); } @@ -1907,7 +2098,7 @@ namespace seal // DO NOT CHANGE EXECUTION ORDER OF FOLLOWING SECTION // BEGIN: Apply Galois for each ciphertext // Execution order is sensitive, since apply_galois is not inplace! - if (parms.scheme() == scheme_type::bfv) + if (parms.scheme() == scheme_type::bfv || parms.scheme() == scheme_type::bgv) { // !!! DO NOT CHANGE EXECUTION ORDER!!! @@ -2062,6 +2253,10 @@ namespace seal { throw invalid_argument("CKKS encrypted must be in NTT form"); } + if (scheme == scheme_type::bgv && encrypted.is_ntt_form()) + { + throw invalid_argument("BGV encrypted cannot be in NTT form"); + } // Extract encryption parameters. size_t coeff_count = parms.poly_modulus_degree(); @@ -2203,63 +2398,153 @@ namespace seal // Perform modulus switching with scaling PolyIter t_poly_prod_iter(t_poly_prod.get(), coeff_count, rns_modulus_size); SEAL_ITERATE(iter(encrypted, t_poly_prod_iter), key_component_count, [&](auto I) { - // Lazy reduction; this needs to be then reduced mod qi - CoeffIter t_last(get<1>(I)[decomp_modulus_size]); - inverse_ntt_negacyclic_harvey_lazy(t_last, key_ntt_tables[key_modulus_size - 1]); - - // Add (p-1)/2 to change from flooring to rounding. - uint64_t qk = key_modulus[key_modulus_size - 1].value(); - uint64_t qk_half = qk >> 1; - SEAL_ITERATE(t_last, coeff_count, [&](auto &J) { - J = barrett_reduce_64(J + qk_half, key_modulus[key_modulus_size - 1]); - }); - - SEAL_ITERATE(iter(I, key_modulus, key_ntt_tables, modswitch_factors), decomp_modulus_size, [&](auto J) { - SEAL_ALLOCATE_GET_COEFF_ITER(t_ntt, coeff_count, pool); - - // (ct mod 4qk) mod qi - uint64_t qi = get<1>(J).value(); - if (qk > qi) + if (scheme == scheme_type::bgv) + { + const Modulus &plain_modulus = parms.plain_modulus(); + // qk is the special prime + uint64_t qk = key_modulus[key_modulus_size - 1].value(); + uint64_t qp = plain_modulus.value(); + uint64_t qk_inv_qp = context_.key_context_data()->rns_tool()->inv_q_last_mod_p(); + + // Lazy reduction; this needs to be then reduced mod qi + CoeffIter t_last(get<1>(I)[decomp_modulus_size]); + inverse_ntt_negacyclic_harvey(t_last, key_ntt_tables[key_modulus_size - 1]); + + // if q_k mod q_p == 1, we can use k = -c mod p + if (qk_inv_qp == 1) { - // This cannot be spared. NTT only tolerates input that is less than 4*modulus (i.e. qk <=4*qi). - modulo_poly_coeffs(t_last, coeff_count, get<1>(J), t_ntt); + // ct mod qi + qk * (t - ct mod qi mod t) + SEAL_ITERATE(t_last, coeff_count, [&](auto &J) { + uint64_t c = barrett_reduce_64(J, plain_modulus); + J += qk * (qp - c); + }); + + SEAL_ITERATE( + iter(I, key_modulus, modswitch_factors, key_ntt_tables), decomp_modulus_size, [&](auto J) { + auto barrett_reduce_64_lazy = [](uint64_t input, const Modulus &modulus) { + // Reduces input using base 2^64 Barrett reduction + // floor(2^64 / mod) == floor( floor(2^128 / mod) ) + unsigned long long tmp[2]; + const uint64_t *const_ratio = modulus.const_ratio().data(); + multiply_uint64_hw64(input, const_ratio[1], tmp + 1); + + // Barrett subtraction + return static_cast(input - tmp[1] * modulus.value()); + }; + + const Modulus &modulus = get<1>(J); + /// result at [0, 2qi) + inverse_ntt_negacyclic_harvey_lazy(get<0, 1>(J), get<3>(J)); + + const uint64_t Lqi = modulus.value() << 2; + std::transform( + get<0, 1>(J), get<0, 1>(J) + coeff_count, t_last, get<0, 1>(J), + [&](uint64_t u, uint64_t v) { + /// lazy substraction, barrett_reduce_64_lazy result at [0, 2qi) + /// so that we use Lqi = 4*qi here + return u + Lqi - barrett_reduce_64_lazy(v, modulus); + }); + + multiply_poly_scalar_coeffmod( + get<0, 1>(J), coeff_count, get<2>(J), get<1>(J), get<0, 1>(J)); + add_poly_coeffmod(get<0, 1>(J), get<0, 0>(J), coeff_count, get<1>(J), get<0, 0>(J)); + }); + // when q_k mod q_p != 1, k = -c * q_k mod p } else { - set_uint(t_last, coeff_count, t_ntt); + SEAL_ALLOCATE_ZERO_GET_COEFF_ITER(k, coeff_count, pool); + modulo_poly_coeffs(t_last, coeff_count, plain_modulus, k); + negate_poly_coeffmod(k, coeff_count, plain_modulus, k); + multiply_poly_scalar_coeffmod(k, coeff_count, qk_inv_qp, plain_modulus, k); + + SEAL_ALLOCATE_ZERO_GET_COEFF_ITER(delta, coeff_count, pool); + SEAL_ALLOCATE_ZERO_GET_COEFF_ITER(c_mod_qi, coeff_count, pool); + SEAL_ITERATE( + iter(I, key_modulus, modswitch_factors, key_ntt_tables), decomp_modulus_size, [&](auto J) { + inverse_ntt_negacyclic_harvey(get<0, 1>(J), get<3>(J)); + // delta = k mod q_i + modulo_poly_coeffs(k, coeff_count, get<1>(J), delta); + // delta = k * q_k mod q_i + multiply_poly_scalar_coeffmod(delta, coeff_count, qk, get<1>(J), delta); + + // c mod q_i + modulo_poly_coeffs(t_last, coeff_count, get<1>(J), c_mod_qi); + // delta = c + k * q_k mod q_i + // c_{i} = c_{i} - delta mod q_i + const uint64_t Lqi = get<1>(J).value() * 2; + SEAL_ITERATE(iter(delta, c_mod_qi, get<0, 1>(J)), coeff_count, [Lqi](auto K) { + get<2>(K) = get<2>(K) + Lqi - (get<0>(K) + get<1>(K)); + }); + + multiply_poly_scalar_coeffmod( + get<0, 1>(J), coeff_count, get<2>(J), get<1>(J), get<0, 1>(J)); + + add_poly_coeffmod(get<0, 1>(J), get<0, 0>(J), coeff_count, get<1>(J), get<0, 0>(J)); + }); } + } + else + { + // Lazy reduction; this needs to be then reduced mod qi + CoeffIter t_last(get<1>(I)[decomp_modulus_size]); + inverse_ntt_negacyclic_harvey_lazy(t_last, key_ntt_tables[key_modulus_size - 1]); + + // Add (p-1)/2 to change from flooring to rounding. + uint64_t qk = key_modulus[key_modulus_size - 1].value(); + uint64_t qk_half = qk >> 1; + SEAL_ITERATE(t_last, coeff_count, [&](auto &J) { + J = barrett_reduce_64(J + qk_half, key_modulus[key_modulus_size - 1]); + }); - // Lazy substraction, results in [0, 2*qi), since fix is in [0, qi]. - uint64_t fix = qi - barrett_reduce_64(qk_half, get<1>(J)); - SEAL_ITERATE(t_ntt, coeff_count, [fix](auto &K) { K += fix; }); + SEAL_ITERATE(iter(I, key_modulus, key_ntt_tables, modswitch_factors), decomp_modulus_size, [&](auto J) { + SEAL_ALLOCATE_GET_COEFF_ITER(t_ntt, coeff_count, pool); - uint64_t qi_lazy = qi << 1; // some multiples of qi - if (scheme == scheme_type::ckks) - { - // This ntt_negacyclic_harvey_lazy results in [0, 4*qi). - ntt_negacyclic_harvey_lazy(t_ntt, get<2>(J)); + // (ct mod 4qk) mod qi + uint64_t qi = get<1>(J).value(); + if (qk > qi) + { + // This cannot be spared. NTT only tolerates input that is less than 4*modulus (i.e. qk <=4*qi). + modulo_poly_coeffs(t_last, coeff_count, get<1>(J), t_ntt); + } + else + { + set_uint(t_last, coeff_count, t_ntt); + } + + // Lazy substraction, results in [0, 2*qi), since fix is in [0, qi]. + uint64_t fix = qi - barrett_reduce_64(qk_half, get<1>(J)); + SEAL_ITERATE(t_ntt, coeff_count, [fix](auto &K) { K += fix; }); + + uint64_t qi_lazy = qi << 1; // some multiples of qi + if (scheme == scheme_type::ckks) + { + // This ntt_negacyclic_harvey_lazy results in [0, 4*qi). + ntt_negacyclic_harvey_lazy(t_ntt, get<2>(J)); #if SEAL_USER_MOD_BIT_COUNT_MAX > 60 // Reduce from [0, 4qi) to [0, 2qi) SEAL_ITERATE(t_ntt, coeff_count, [&](auto &K) { K -= SEAL_COND_SELECT(K >= qi_lazy, qi_lazy, 0); }); #else - // Since SEAL uses at most 60bit moduli, 8*qi < 2^63. - qi_lazy = qi << 2; + // Since SEAL uses at most 60bit moduli, 8*qi < 2^63. + qi_lazy = qi << 2; #endif - } - else if (scheme == scheme_type::bfv) - { - inverse_ntt_negacyclic_harvey_lazy(get<0, 1>(J), get<2>(J)); - } + } + else if (scheme == scheme_type::bfv) + { + inverse_ntt_negacyclic_harvey_lazy(get<0, 1>(J), get<2>(J)); + } - // ((ct mod qi) - (ct mod qk)) mod qi - SEAL_ITERATE(iter(get<0, 1>(J), t_ntt), coeff_count, [&](auto K) { get<0>(K) += qi_lazy - get<1>(K); }); + // ((ct mod qi) - (ct mod qk)) mod qi + SEAL_ITERATE( + iter(get<0, 1>(J), t_ntt), coeff_count, [&](auto K) { get<0>(K) += qi_lazy - get<1>(K); }); - // qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi - multiply_poly_scalar_coeffmod(get<0, 1>(J), coeff_count, get<3>(J), get<1>(J), get<0, 1>(J)); - add_poly_coeffmod(get<0, 1>(J), get<0, 0>(J), coeff_count, get<1>(J), get<0, 0>(J)); - }); + // qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi + multiply_poly_scalar_coeffmod(get<0, 1>(J), coeff_count, get<3>(J), get<1>(J), get<0, 1>(J)); + add_poly_coeffmod(get<0, 1>(J), get<0, 0>(J), coeff_count, get<1>(J), get<0, 0>(J)); + }); + } }); } } // namespace seal diff --git a/native/src/seal/evaluator.h b/native/src/seal/evaluator.h index c9380eda0..4fcf5042f 100644 --- a/native/src/seal/evaluator.h +++ b/native/src/seal/evaluator.h @@ -58,7 +58,7 @@ namespace seal sense. @par NTT form - When using the BFV scheme (scheme_type::bfv), all plaintexts and ciphertexts should remain by default in the usual + When using the BFV/BGV scheme (scheme_type::bfv/bgv), all plaintexts and ciphertexts should remain by default in the usual coefficient representation, i.e., not in NTT form. When using the CKKS scheme (scheme_type::ckks), all plaintexts and ciphertexts should remain by default in NTT form. We call these scheme-specific NTT states the "default NTT form". Some functions, such as add, work even if the inputs are not in the default state, but others, such as @@ -578,7 +578,7 @@ namespace seal @param[in] relin_keys The relinearization keys @param[out] destination The ciphertext to overwrite with the multiplication result @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::bfv + @throws std::logic_error if scheme is not scheme_type::bfv or scheme_type::bgv @throws std::invalid_argument if encrypteds is empty @throws std::invalid_argument if ciphertexts or relin_keys are not valid for the encryption parameters @throws std::invalid_argument if encrypteds are not in the default NTT form @@ -602,7 +602,7 @@ namespace seal @param[in] exponent The power to raise the ciphertext to @param[in] relin_keys The relinearization keys @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::bfv + @throws std::logic_error if scheme is not scheme_type::bfv or scheme_type::bgv @throws std::invalid_argument if encrypted or relin_keys is not valid for the encryption parameters @throws std::invalid_argument if encrypted is not in the default NTT form @throws std::invalid_argument if the output scale is too large for the encryption parameters @@ -628,7 +628,7 @@ namespace seal @param[in] relin_keys The relinearization keys @param[out] destination The ciphertext to overwrite with the power @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::bfv + @throws std::logic_error if scheme is not scheme_type::bfv or scheme_type::bgv @throws std::invalid_argument if encrypted or relin_keys is not valid for the encryption parameters @throws std::invalid_argument if encrypted is not in the default NTT form @throws std::invalid_argument if the output scale is too large for the encryption parameters @@ -860,7 +860,7 @@ namespace seal The desired Galois automorphism is given as a Galois element, and must be an odd integer in the interval [1, M-1], where M = 2*N, and N = poly_modulus_degree. Used with batching, a Galois element 3^i % M corresponds to a cyclic row rotation i steps to the left, and a Galois element 3^(N/2-i) % M corresponds to a cyclic row - rotation i steps to the right. The Galois element M-1 corresponds to a column rotation (row swap) in BFV, and + rotation i steps to the right. The Galois element M-1 corresponds to a column rotation (row swap) in BFV/BGV, and complex conjugation in CKKS. In the polynomial view (not batching), a Galois automorphism by a Galois element p changes Enc(plain(x)) to Enc(plain(x^p)). @@ -892,7 +892,7 @@ namespace seal The desired Galois automorphism is given as a Galois element, and must be an odd integer in the interval [1, M-1], where M = 2*N, and N = poly_modulus_degree. Used with batching, a Galois element 3^i % M corresponds to a cyclic row rotation i steps to the left, and a Galois element 3^(N/2-i) % M corresponds to a cyclic row - rotation i steps to the right. The Galois element M-1 corresponds to a column rotation (row swap) in BFV, and + rotation i steps to the right. The Galois element M-1 corresponds to a column rotation (row swap) in BFV/BGV, and complex conjugation in CKKS. In the polynomial view (not batching), a Galois automorphism by a Galois element p changes Enc(plain(x)) to Enc(plain(x^p)). @@ -922,7 +922,7 @@ namespace seal } /** - Rotates plaintext matrix rows cyclically. When batching is used with the BFV scheme, this function rotates the + Rotates plaintext matrix rows cyclically. When batching is used with the BFV/BGV scheme, this function rotates the encrypted plaintext matrix rows cyclically to the left (steps > 0) or to the right (steps < 0). Since the size of the batched matrix is 2-by-(N/2), where N is the degree of the polynomial modulus, the number of steps to rotate must have absolute value at most N/2-1. Dynamic memory allocations in the process are allocated from the @@ -932,7 +932,7 @@ namespace seal @param[in] steps The number of steps to rotate (negative left, positive right) @param[in] galois_keys The Galois keys @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::bfv + @throws std::logic_error if scheme is not scheme_type::bfv or scheme_type::bgv @throws std::logic_error if the encryption parameters do not support batching @throws std::invalid_argument if encrypted or galois_keys is not valid for the encryption parameters @@ -950,7 +950,8 @@ namespace seal Ciphertext &encrypted, int steps, const GaloisKeys &galois_keys, MemoryPoolHandle pool = MemoryManager::GetPool()) { - if (context_.key_context_data()->parms().scheme() != scheme_type::bfv) + auto scheme = context_.key_context_data()->parms().scheme(); + if (scheme != scheme_type::bfv && scheme != scheme_type::bgv) { throw std::logic_error("unsupported scheme"); } @@ -958,7 +959,7 @@ namespace seal } /** - Rotates plaintext matrix rows cyclically. When batching is used with the BFV scheme, this function rotates the + Rotates plaintext matrix rows cyclically. When batching is used with the BFV/BGV scheme, this function rotates the encrypted plaintext matrix rows cyclically to the left (steps > 0) or to the right (steps < 0) and writes the result to the destination parameter. Since the size of the batched matrix is 2-by-(N/2), where N is the degree of the polynomial modulus, the number of steps to rotate must have absolute value at most N/2-1. Dynamic memory @@ -969,7 +970,7 @@ namespace seal @param[in] galois_keys The Galois keys @param[out] destination The ciphertext to overwrite with the rotated result @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::bfv + @throws std::logic_error if scheme is not scheme_type::bfv or scheme_type::bgv @throws std::logic_error if the encryption parameters do not support batching @throws std::invalid_argument if encrypted or galois_keys is not valid for the encryption parameters @@ -1001,7 +1002,7 @@ namespace seal @param[in] galois_keys The Galois keys @param[out] destination The ciphertext to overwrite with the rotated result @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::bfv + @throws std::logic_error if scheme is not scheme_type::bfv or scheme_type::bgv @throws std::logic_error if the encryption parameters do not support batching @throws std::invalid_argument if encrypted or galois_keys is not valid for the encryption parameters @@ -1016,8 +1017,9 @@ namespace seal */ inline void rotate_columns_inplace( Ciphertext &encrypted, const GaloisKeys &galois_keys, MemoryPoolHandle pool = MemoryManager::GetPool()) - { - if (context_.key_context_data()->parms().scheme() != scheme_type::bfv) + { + auto scheme = context_.key_context_data()->parms().scheme(); + if (scheme != scheme_type::bfv && scheme != scheme_type::bgv) { throw std::logic_error("unsupported scheme"); } @@ -1025,7 +1027,7 @@ namespace seal } /** - Rotates plaintext matrix columns cyclically. When batching is used with the BFV scheme, this function rotates + Rotates plaintext matrix columns cyclically. When batching is used with the BFV/BGV scheme, this function rotates the encrypted plaintext matrix columns cyclically, and writes the result to the destination parameter. Since the size of the batched matrix is 2-by-(N/2), where N is the degree of the polynomial modulus, this means simply swapping the two rows. Dynamic memory allocations in the process are allocated from the memory pool pointed to @@ -1035,7 +1037,7 @@ namespace seal @param[in] galois_keys The Galois keys @param[out] destination The ciphertext to overwrite with the rotated result @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::bfv + @throws std::logic_error if scheme is not scheme_type::bfv or scheme_type::bgv @throws std::logic_error if the encryption parameters do not support batching @throws std::invalid_argument if encrypted or galois_keys is not valid for the encryption parameters @@ -1202,10 +1204,14 @@ namespace seal void ckks_multiply(Ciphertext &encrypted1, const Ciphertext &encrypted2, MemoryPoolHandle pool); + void bgv_multiply(Ciphertext &encrypted1, Ciphertext &encrypted2, MemoryPoolHandle pool); + void bfv_square(Ciphertext &encrypted, MemoryPoolHandle pool); void ckks_square(Ciphertext &encrypted, MemoryPoolHandle pool); + void bgv_square(Ciphertext &encrypted, MemoryPoolHandle pool); + void relinearize_internal( Ciphertext &encrypted, const RelinKeys &relin_keys, std::size_t destination_size, MemoryPoolHandle pool); diff --git a/native/src/seal/modulus.cpp b/native/src/seal/modulus.cpp index 3161b8531..9a1c9d73d 100644 --- a/native/src/seal/modulus.cpp +++ b/native/src/seal/modulus.cpp @@ -140,6 +140,11 @@ namespace seal } } + vector CoeffModulus::BGVDefault(size_t poly_modulus_degree, sec_level_type sec_level) + { + return BFVDefault(poly_modulus_degree, sec_level); + } + vector CoeffModulus::Create(size_t poly_modulus_degree, vector bit_sizes) { if (poly_modulus_degree > SEAL_POLY_MOD_DEGREE_MAX || poly_modulus_degree < SEAL_POLY_MOD_DEGREE_MIN || diff --git a/native/src/seal/modulus.h b/native/src/seal/modulus.h index 8c3e905bf..c2f5c6656 100644 --- a/native/src/seal/modulus.h +++ b/native/src/seal/modulus.h @@ -478,6 +478,20 @@ namespace seal SEAL_NODISCARD static std::vector BFVDefault( std::size_t poly_modulus_degree, sec_level_type sec_level = sec_level_type::tc128); + /** + Returns a default coefficient modulus for the BFV scheme that guarantees + a given security level when using a given poly_modulus_degree, which currently + uses the same modulus as BFV. + @param[in] poly_modulus_degree The value of the poly_modulus_degree + encryption parameter + @param[in] sec_level The desired standard security level + @throws std::invalid_argument if poly_modulus_degree is not a power-of-two + or is too large + @throws std::invalid_argument if sec_level is sec_level_type::none + */ + SEAL_NODISCARD static std::vector BGVDefault( + std::size_t poly_modulus_degree, sec_level_type sec_level = sec_level_type::tc128); + /** Returns a custom coefficient modulus suitable for use with the specified poly_modulus_degree. The return value will be a vector consisting of diff --git a/native/src/seal/util/defines.h b/native/src/seal/util/defines.h index b14eb410f..307ea63e1 100644 --- a/native/src/seal/util/defines.h +++ b/native/src/seal/util/defines.h @@ -171,6 +171,13 @@ namespace seal #define SEAL_NOISE_SAMPLER sample_poly_cbd #endif +#ifdef SEAL_USE_GAUSSIAN_NOISE +#define SEAL_NOISE_SAMPLER_T sample_poly_normal_t +#else +#define SEAL_NOISE_SAMPLER_T sample_poly_cbd_t +#endif + + // Use generic functions as (slower) fallback #ifndef SEAL_ADD_CARRY_UINT64 #define SEAL_ADD_CARRY_UINT64(operand1, operand2, carry, result) add_uint64_generic(operand1, operand2, carry, result) diff --git a/native/src/seal/util/rlwe.cpp b/native/src/seal/util/rlwe.cpp index 4f64daa0f..c64399ae8 100644 --- a/native/src/seal/util/rlwe.cpp +++ b/native/src/seal/util/rlwe.cpp @@ -63,6 +63,34 @@ namespace seal }); } + void sample_poly_normal_t( + shared_ptr prng, const EncryptionParameters &parms, uint64_t *destination) + { + auto coeff_modulus = parms.coeff_modulus(); + size_t coeff_modulus_size = coeff_modulus.size(); + size_t coeff_count = parms.poly_modulus_degree(); + uint64_t plain_modulus = parms.plain_modulus().value(); + + if (are_close(global_variables::noise_max_deviation, 0.0)) + { + set_zero_poly(coeff_count, coeff_modulus_size, destination); + return; + } + + RandomToStandardAdapter engine(prng); + ClippedNormalDistribution dist( + 0, global_variables::noise_standard_deviation, global_variables::noise_max_deviation); + + SEAL_ITERATE(iter(destination), coeff_count, [&](auto &I) { + int64_t noise = static_cast(dist(engine)); + uint64_t flag = static_cast(-static_cast(noise < 0)); + SEAL_ITERATE( + iter(StrideIter(&I, coeff_count), coeff_modulus), coeff_modulus_size, [&](auto J) { + *get<0>(J) = static_cast(noise) * plain_modulus + (flag & get<1>(J).value()); + }); + }); + } + void sample_poly_cbd( shared_ptr prng, const EncryptionParameters &parms, uint64_t *destination) { @@ -100,6 +128,45 @@ namespace seal }); } + void sample_poly_cbd_t( + shared_ptr prng, const EncryptionParameters &parms, uint64_t *destination) + { + auto coeff_modulus = parms.coeff_modulus(); + size_t coeff_modulus_size = coeff_modulus.size(); + size_t coeff_count = parms.poly_modulus_degree(); + uint64_t plain_modulus = parms.plain_modulus().value(); + + if (are_close(global_variables::noise_max_deviation, 0.0)) + { + set_zero_poly(coeff_count, coeff_modulus_size, destination); + return; + } + + if (!are_close(global_variables::noise_standard_deviation, 3.2)) + { + throw logic_error("centered binomial distribution only supports standard deviation 3.2; use rounded " + "Gaussian instead"); + } + + auto cbd = [&]() { + unsigned char x[6]; + prng->generate(6, reinterpret_cast(x)); + x[2] &= 0x1F; + x[5] &= 0x1F; + return hamming_weight(x[0]) + hamming_weight(x[1]) + hamming_weight(x[2]) - hamming_weight(x[3]) - + hamming_weight(x[4]) - hamming_weight(x[5]); + }; + + SEAL_ITERATE(iter(destination), coeff_count, [&](auto &I) { + int32_t noise = cbd(); + uint64_t flag = static_cast(-static_cast(noise < 0)); + SEAL_ITERATE( + iter(StrideIter(&I, coeff_count), coeff_modulus), coeff_modulus_size, [&](auto J) { + *get<0>(J) = static_cast(noise) * plain_modulus + (flag & get<1>(J).value()); + }); + }); + } + void sample_poly_uniform( shared_ptr prng, const EncryptionParameters &parms, uint64_t *destination) { @@ -174,10 +241,12 @@ namespace seal auto &context_data = *context.get_context_data(parms_id); auto &parms = context_data.parms(); auto &coeff_modulus = parms.coeff_modulus(); + auto &plain_modulus = parms.plain_modulus(); size_t coeff_modulus_size = coeff_modulus.size(); size_t coeff_count = parms.poly_modulus_degree(); auto ntt_tables = context_data.small_ntt_tables(); size_t encrypted_size = public_key.data().size(); + scheme_type type = parms.scheme(); // Make destination have right size and parms_id // Ciphertext (c_0,c_1, ...) @@ -216,17 +285,55 @@ namespace seal // c[j] = public_key[j] * u + e[j] for (size_t j = 0; j < encrypted_size; j++) { - SEAL_NOISE_SAMPLER(prng, parms, u.get()); - for (size_t i = 0; i < coeff_modulus_size; i++) + // In BGV, p * e is used + if (type == scheme_type::bgv) { - // Addition with e_0, e_1 is in NTT form - if (is_ntt_form) + const int min_bit_size = + std::min_element(coeff_modulus.cbegin(), coeff_modulus.cend(), [](auto a, auto b) { + return a.bit_count() < b.bit_count(); + })->bit_count(); + + // e * p may exceed the limit of modulus, +5 as noise_max_deviation ~ 19.2 + if (plain_modulus.bit_count() + 5 < min_bit_size) { - ntt_negacyclic_harvey(u.get() + i * coeff_count, ntt_tables[i]); + // Optimzation: we sample scaled Gaussian noise with SEAL_NOISE_SAMPLER_T + SEAL_NOISE_SAMPLER_T(prng, parms, u.get()); + RNSIter gaussian_iter(u.get(), coeff_count); + if (is_ntt_form) + { + ntt_negacyclic_harvey(gaussian_iter, coeff_modulus_size, ntt_tables); + } + RNSIter dst_iter(destination.data(j), coeff_count); + add_poly_coeffmod(gaussian_iter, dst_iter, coeff_modulus_size, coeff_modulus, dst_iter); + } + else + { + SEAL_NOISE_SAMPLER(prng, parms, u.get()); + RNSIter gaussian_iter(u.get(), coeff_count); + if (is_ntt_form) + { + ntt_negacyclic_harvey_lazy(gaussian_iter, coeff_modulus_size, ntt_tables); + } + RNSIter dst_iter(destination.data(j), coeff_count); + multiply_poly_scalar_coeffmod( + gaussian_iter, coeff_modulus_size, plain_modulus.value(), coeff_modulus, gaussian_iter); + add_poly_coeffmod(gaussian_iter, dst_iter, coeff_modulus_size, coeff_modulus, dst_iter); + } + } + else + { + SEAL_NOISE_SAMPLER(prng, parms, u.get()); + for (size_t i = 0; i < coeff_modulus_size; i++) + { + // Addition with e_0, e_1 is in NTT form + if (is_ntt_form) + { + ntt_negacyclic_harvey(u.get() + i * coeff_count, ntt_tables[i]); + } + add_poly_coeffmod( + u.get() + i * coeff_count, destination.data(j) + i * coeff_count, coeff_count, + coeff_modulus[i], destination.data(j) + i * coeff_count); } - add_poly_coeffmod( - u.get() + i * coeff_count, destination.data(j) + i * coeff_count, coeff_count, coeff_modulus[i], - destination.data(j) + i * coeff_count); } } } @@ -247,10 +354,12 @@ namespace seal auto &context_data = *context.get_context_data(parms_id); auto &parms = context_data.parms(); auto &coeff_modulus = parms.coeff_modulus(); + auto &plain_modulus = parms.plain_modulus(); size_t coeff_modulus_size = coeff_modulus.size(); size_t coeff_count = parms.poly_modulus_degree(); auto ntt_tables = context_data.small_ntt_tables(); size_t encrypted_size = 2; + scheme_type type = parms.scheme(); // If a polynomial is too small to store UniformRandomGeneratorInfo, // it is best to just disable save_seed. Note that the size needed is @@ -322,10 +431,27 @@ namespace seal { inverse_ntt_negacyclic_harvey(c0 + i * coeff_count, ntt_tables[i]); } - add_poly_coeffmod( - noise.get() + i * coeff_count, c0 + i * coeff_count, coeff_count, coeff_modulus[i], - c0 + i * coeff_count); - negate_poly_coeffmod(c0 + i * coeff_count, coeff_count, coeff_modulus[i], c0 + i * coeff_count); + // bgv: (c0,c1) = ((as+pe), -a) + if (type == scheme_type::bgv) + { + // noise = pe + multiply_poly_scalar_coeffmod( + noise.get() + i * coeff_count, coeff_count, plain_modulus.value(), coeff_modulus[i], + noise.get() + i * coeff_count); + // c0 = as + pe + add_poly_coeffmod( + noise.get() + i * coeff_count, c0 + i * coeff_count, coeff_count, coeff_modulus[i], + c0 + i * coeff_count); + // (as + pe, a) -> (as + pe, -a), + negate_poly_coeffmod(c1 + i * coeff_count, coeff_count, coeff_modulus[i], c1 + i * coeff_count); + } + else + { + add_poly_coeffmod( + noise.get() + i * coeff_count, c0 + i * coeff_count, coeff_count, coeff_modulus[i], + c0 + i * coeff_count); + negate_poly_coeffmod(c0 + i * coeff_count, coeff_count, coeff_modulus[i], c0 + i * coeff_count); + } } if (!is_ntt_form && !save_seed) diff --git a/native/src/seal/util/rlwe.h b/native/src/seal/util/rlwe.h index af12b5b66..31cebd46c 100644 --- a/native/src/seal/util/rlwe.h +++ b/native/src/seal/util/rlwe.h @@ -48,6 +48,28 @@ namespace seal std::shared_ptr prng, const EncryptionParameters &parms, std::uint64_t *destination); + /** + Generate a polynomial from a normal distribution, mulitply it with the plaintext moudlus (denoted by t), and store in RNS representation. + + @param[in] prng A uniform random generator + @param[in] parms EncryptionParameters used to parameterize an RNS polynomial + @param[out] destination Allocated space to store a random polynomial + */ + void sample_poly_normal_t( + std::shared_ptr prng, const EncryptionParameters &parms, + std::uint64_t *destination); + + /** + Generate a polynomial from a centered binomial distribution, mulitply it with the plaintext moudlus, and store in RNS representation. + + @param[in] prng A uniform random generator. + @param[in] parms EncryptionParameters used to parameterize an RNS polynomial + @param[out] destination Allocated space to store a random polynomial + */ + void sample_poly_cbd_t( + std::shared_ptr prng, const EncryptionParameters &parms, + std::uint64_t *destination); + /** Generate a uniformly random polynomial and store in RNS representation. diff --git a/native/src/seal/util/rns.cpp b/native/src/seal/util/rns.cpp index 1d6d1ef85..a1ad46d41 100644 --- a/native/src/seal/util/rns.cpp +++ b/native/src/seal/util/rns.cpp @@ -462,6 +462,79 @@ namespace seal }); } + // See "An Improved RNS Variant of the BFV Homomorphic Encryption Scheme" (CT-RSA 2019) for details + void BaseConverter::exact_convert_array(ConstRNSIter in, CoeffIter out, MemoryPoolHandle pool) const { + size_t ibase_size = ibase_.size(); + size_t obase_size = obase_.size(); + size_t count = in.poly_modulus_degree(); + + if(obase_size != 1){ + throw invalid_argument("out base in exact_convert_array must be one."); + } + + // Note that the stride size is ibase_size + SEAL_ALLOCATE_GET_STRIDE_ITER(temp, uint64_t, count, ibase_size, pool); + + // The iterator storing v + SEAL_ALLOCATE_GET_STRIDE_ITER(v, double_t, count, ibase_size, pool); + + // Aggregated rounded v + SEAL_ALLOCATE_GET_PTR_ITER(aggregated_rounded_v, uint64_t, count, pool); + + // Calculate [x_{i} * \hat{q_{i}}]_{q_{i}} + SEAL_ITERATE( + iter(in, ibase_.inv_punctured_prod_mod_base_array(), ibase_.base(), size_t(0)), ibase_size, + [&](auto I) { + // The current ibase index + size_t ibase_index = get<3>(I); + double_t divisor = static_cast(get<2>(I).value()); + + if (get<1>(I).operand == 1) + { + // No multiplication needed + SEAL_ITERATE(iter(get<0>(I), temp, v), count, [&](auto J) { + // Reduce modulo ibase element + get<1>(J)[ibase_index] = barrett_reduce_64(get<0>(J), get<2>(I)); + double_t dividend = static_cast(get<1>(J)[ibase_index]); + get<2>(J)[ibase_index] = dividend/divisor; + }); + } + else + { + // Multiplication needed + SEAL_ITERATE(iter(get<0>(I), temp, v), count, [&](auto J) { + // Multiply coefficient of in with ibase_.inv_punctured_prod_mod_base_array_ element + get<1>(J)[ibase_index] = multiply_uint_mod(get<0>(J), get<1>(I), get<2>(I)); + double_t dividend = static_cast(get<1>(J)[ibase_index]); + get<2>(J)[ibase_index] = dividend/divisor; + }); + } + }); + + //Aggrate v and rounding + SEAL_ITERATE(iter(v, aggregated_rounded_v), count, [&](auto I){ + //Otherwise a memory space of the last execution will be used. + double_t aggregated_v = 0.0; + for(size_t i = 0; i < ibase_size; ++i){ + aggregated_v += get<0>(I)[i]; + } + aggregated_v += 0.5; + get<1>(I) = static_cast(aggregated_v); + }); + + auto p = obase_.base()[0]; + auto q_mod_p = modulo_uint(ibase_.base_prod(), ibase_size, p); + auto base_change_matrix_first = base_change_matrix_[0].get(); + //Final multiplication + SEAL_ITERATE(iter(out, temp, aggregated_rounded_v), count, [&](auto J) { + // Compute the base conversion sum modulo obase element + auto sum_mod_obase = dot_product_mod(get<1>(J), base_change_matrix_first, ibase_size, p); + // Minus v*[q]_{p} mod p + auto v_q_mod_p = multiply_uint_mod(get<2>(J), q_mod_p, p); + get<0>(J) = sub_uint_mod(sum_mod_obase, v_q_mod_p, p); + }); + } + void BaseConverter::initialize() { // Verify that the size is not too large @@ -582,6 +655,11 @@ namespace seal throw logic_error("invalid rns bases"); } + if (!t_.is_zero()){ + // Set up BaseConvTool for q --> {t} + base_q_to_t_conv_ = allocate(pool_, *base_q_, RNSBase({ t_ },pool_) ,pool_); + } + // Set up BaseConverter for q --> Bsk base_q_to_Bsk_conv_ = allocate(pool_, *base_q_, *base_Bsk_, pool_); @@ -691,6 +769,11 @@ namespace seal } get<0>(I).set(temp, get<1>(I)); }); + + inv_q_last_mod_p_ = 1; + if(t_.value() != 0){ + try_invert_uint_mod(base_q_->base()[base_q_size - 1].value(), t_, inv_q_last_mod_p_); + } } void RNSTool::divide_and_round_q_last_inplace(RNSIter input, MemoryPoolHandle pool) const @@ -1106,5 +1189,78 @@ namespace seal } }); } + + void RNSTool::mod_t_and_divide_q_last_inplace(RNSIter input, MemoryPoolHandle pool) const + { + size_t modulus_size = base_q_->size(); + const Modulus* curr_modulus = base_q_->base(); + const Modulus plain_modulus = t_; + const Modulus last_modulus = curr_modulus[modulus_size - 1]; + uint64_t last_modulus_value = curr_modulus[modulus_size - 1].value(); + uint64_t plain_modulus_value = plain_modulus.value(); + CoeffIter last = input[modulus_size - 1]; + + SEAL_ALLOCATE_ZERO_GET_COEFF_ITER(delta, coeff_count_, pool); + + // last_q^(-1) mod t = 1, k mod t = -c mod t. + if(inv_q_last_mod_p_ == 1){ + SEAL_ITERATE(iter(last, delta), coeff_count_, [&](auto I){ + uint64_t coeff = barrett_reduce_64(get<0>(I), plain_modulus); + int64_t non_zero = (coeff != 0); + coeff = (plain_modulus_value - coeff) & static_cast(-non_zero); + get<1>(I) = get<0>(I) + last_modulus_value * coeff; + }); + + SEAL_ITERATE(iter(input, curr_modulus, inv_q_last_mod_q_), modulus_size - 1, [&](auto I){ + SEAL_ITERATE(iter(get<0>(I), delta), coeff_count_, [&](auto J){ + // \delta mod the other modulus + uint64_t delta_mod = barrett_reduce_64(get<1>(J), get<1>(I)); + // c = c - \delta + get<0>(J) = sub_uint_mod(get<0>(J), delta_mod, get<1>(I)); + }); + + // c = c/q_t + multiply_poly_scalar_coeffmod(get<0>(I), coeff_count_, get<2>(I), get<1>(I), get<0>(I)); + }); + } + else // last_q^(-1) mod t != 1, k mod t = -c * last_q^(-1) mod t. + { + //delta = -c mod t + modulo_poly_coeffs(CoeffIter(input[modulus_size - 1]), coeff_count_, plain_modulus, delta); + negate_poly_coeffmod(delta, coeff_count_, plain_modulus, delta); + //delta = (-c mod t) * last_q^(-1) mod t, which is k mod t + multiply_poly_scalar_coeffmod(delta, coeff_count_, inv_q_last_mod_p_, plain_modulus, delta); + + //RNS format as delta may be larger than 64 bytes. + SEAL_ALLOCATE_ZERO_GET_RNS_ITER(delta_rns, coeff_count_, modulus_size - 1, pool); + + SEAL_ITERATE(iter(input, delta_rns, curr_modulus, inv_q_last_mod_q_), modulus_size - 1, [&](auto I){ + // delta_rns is (k mod t) mod q_{i} + modulo_poly_coeffs(delta, coeff_count_, get<2>(I), get<1>(I)); + // delta_rns = (k mod t) * last_q mod q_{i} + multiply_poly_scalar_coeffmod(get<1>(I), coeff_count_, last_modulus_value, get<2>(I), get<1>(I)); + // 2 * qi + const uint64_t Lqi = get<2>(I).value() << 1; + // delta_rns = (c mod q_i + k * last_q mod q_i) mod q_i + SEAL_ITERATE(iter(get<0>(I), get<1>(I), input[modulus_size - 1]), coeff_count_, [&](auto J){ + // c mod q_i + auto last_mod_q_i = barrett_reduce_64(get<2>(J), get<2>(I)); + // k * last_q mod q_i + auto k_last_q = barrett_reduce_64(get<1>(J), get<2>(I)); + // c = c - delta_rns + get<0>(J) = get<0>(J) + Lqi - (last_mod_q_i + k_last_q); + }); + + // c = c/q_t + multiply_poly_scalar_coeffmod(get<0>(I), coeff_count_, get<3>(I), get<2>(I), get<0>(I)); + }); + } + } + + void RNSTool::decrypt_modt(RNSIter phase, CoeffIter destination, MemoryPoolHandle pool) const + { + //Use exact base convension rather than convert the base through the compose API + base_q_to_t_conv_->exact_convert_array(phase, destination, pool); + } } // namespace util } // namespace seal diff --git a/native/src/seal/util/rns.h b/native/src/seal/util/rns.h index 87b7caeec..f434564d2 100644 --- a/native/src/seal/util/rns.h +++ b/native/src/seal/util/rns.h @@ -164,6 +164,9 @@ namespace seal void fast_convert_array(ConstRNSIter in, RNSIter out, MemoryPoolHandle pool) const; + //The exact base convertion function, only supports obase size of 1. + void exact_convert_array(ConstRNSIter in, CoeffIter out, MemoryPoolHandle) const; + private: BaseConverter(const BaseConverter ©) = delete; @@ -229,6 +232,16 @@ namespace seal */ void decrypt_scale_and_round(ConstRNSIter phase, CoeffIter destination, MemoryPoolHandle pool) const; + /** + Remove the last q for bgv ciphertext + */ + void mod_t_and_divide_q_last_inplace(RNSIter input, MemoryPoolHandle pool) const; + + /** + Compute mod t + */ + void decrypt_modt(RNSIter phase, CoeffIter destination, MemoryPoolHandle pool) const; + SEAL_NODISCARD inline auto inv_q_last_mod_q() const noexcept { return inv_q_last_mod_q_.get(); @@ -284,6 +297,11 @@ namespace seal return gamma_; } + SEAL_NODISCARD inline auto &inv_q_last_mod_p() const noexcept + { + return inv_q_last_mod_p_; + } + private: RNSTool(const RNSTool ©) = delete; @@ -327,6 +345,9 @@ namespace seal // Base converter: q --> {t, gamma} Pointer base_q_to_t_gamma_conv_; + // Base converter: q --> t + Pointer base_q_to_t_conv_; + // prod(q)^(-1) mod Bsk Pointer inv_prod_q_mod_Bsk_; @@ -367,6 +388,8 @@ namespace seal Modulus t_; Modulus gamma_; + + uint64_t inv_q_last_mod_p_; }; } // namespace util } // namespace seal diff --git a/native/src/seal/util/scalingvariant.cpp b/native/src/seal/util/scalingvariant.cpp index 1b0d71d0b..9c05a3aa4 100644 --- a/native/src/seal/util/scalingvariant.cpp +++ b/native/src/seal/util/scalingvariant.cpp @@ -12,6 +12,65 @@ namespace seal { namespace util { + void add_plain_without_scaling_variant( + const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination) + { + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + auto &plain_modulus = parms.plain_modulus(); + const size_t coeff_count = plain.coeff_count(); + const size_t coeff_modulus_size = coeff_modulus.size(); + if (coeff_count > parms.poly_modulus_degree()) + { + throw std::invalid_argument("add_plain_without_scaling_variant: invalid plaintext"); + } + + if (destination.poly_modulus_degree() != parms.poly_modulus_degree()) + { + throw std::invalid_argument("add_plain_without_scaling_variant: invalid destination iter"); + } + + SEAL_ITERATE(iter(destination, coeff_modulus), coeff_modulus_size, [&](auto I) { + const Modulus &cipher_modulus = get<1>(I); + std::transform( + plain.data(), plain.data() + coeff_count, get<0>(I), get<0>(I), + [&](uint64_t m, uint64_t c) -> uint64_t { + m = barrett_reduce_64(m, plain_modulus); + return add_uint_mod(c, m, cipher_modulus); + }); + }); + } + + void sub_plain_without_scaling_variant( + const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination) + { + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + auto &plain_modulus = parms.plain_modulus(); + const size_t coeff_count = plain.coeff_count(); + const size_t coeff_modulus_size = coeff_modulus.size(); + + if (coeff_count > parms.poly_modulus_degree()) + { + throw std::invalid_argument("sub_plain_without_scaling_variant: invalid plaintext"); + } + + if (destination.poly_modulus_degree() != parms.poly_modulus_degree()) + { + throw std::invalid_argument("sub_plain_without_scaling_variant: invalid destination iter"); + } + + SEAL_ITERATE(iter(destination, coeff_modulus), coeff_modulus_size, [&](auto I) { + const Modulus &cipher_modulus = get<1>(I); + std::transform( + plain.data(), plain.data() + coeff_count, get<0>(I), get<0>(I), + [&](uint64_t m, uint64_t c) -> uint64_t { + m = barrett_reduce_64(m, plain_modulus); + return sub_uint_mod(c, m, cipher_modulus); + }); + }); + } + void multiply_add_plain_with_scaling_variant( const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination) { diff --git a/native/src/seal/util/scalingvariant.h b/native/src/seal/util/scalingvariant.h index 29970efae..8a7add2e7 100644 --- a/native/src/seal/util/scalingvariant.h +++ b/native/src/seal/util/scalingvariant.h @@ -13,6 +13,12 @@ namespace seal { namespace util { + void add_plain_without_scaling_variant( + const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination); + + void sub_plain_without_scaling_variant( + const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination); + void multiply_add_plain_with_scaling_variant( const Plaintext &plain, const SEALContext::ContextData &context_data, RNSIter destination); diff --git a/native/tests/seal/ciphertext.cpp b/native/tests/seal/ciphertext.cpp index 4b2ebbbfd..7666195dc 100644 --- a/native/tests/seal/ciphertext.cpp +++ b/native/tests/seal/ciphertext.cpp @@ -15,7 +15,7 @@ using namespace std; namespace sealtest { - TEST(CiphertextTest, CiphertextBasics) + TEST(CiphertextTest, BFVCiphertextBasics) { EncryptionParameters parms(scheme_type::bfv); @@ -90,7 +90,7 @@ namespace sealtest ASSERT_EQ(ctxt.size(), ctxt3.size()); } - TEST(CiphertextTest, SaveLoadCiphertext) + TEST(CiphertextTest, BFVSaveLoadCiphertext) { stringstream stream; EncryptionParameters parms(scheme_type::bfv); @@ -126,4 +126,118 @@ namespace sealtest is_equal_uint(ctxt.data(), ctxt2.data(), parms.poly_modulus_degree() * parms.coeff_modulus().size() * 2)); ASSERT_TRUE(ctxt.data() != ctxt2.data()); } + + TEST(CiphertextTest, BGVCiphertextBasics) + { + EncryptionParameters parms(scheme_type::bgv); + + parms.set_poly_modulus_degree(2); + parms.set_coeff_modulus(CoeffModulus::Create(2, { 30 })); + parms.set_plain_modulus(2); + //auto context = SEALContext::Create(parms, false, sec_level_type::none); + SEALContext context(parms, false, sec_level_type::none); + + Ciphertext ctxt(context); + ctxt.reserve(10); + ASSERT_EQ(0ULL, ctxt.size()); + ASSERT_EQ(0ULL, ctxt.dyn_array().size()); + ASSERT_EQ(10ULL * 2 * 1, ctxt.dyn_array().capacity()); + ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); + ASSERT_TRUE(ctxt.parms_id() == context.first_parms_id()); + ASSERT_FALSE(ctxt.is_ntt_form()); + const uint64_t *ptr = ctxt.data(); + + ctxt.reserve(5); + ASSERT_EQ(0ULL, ctxt.size()); + ASSERT_EQ(0ULL, ctxt.dyn_array().size()); + ASSERT_EQ(5ULL * 2 * 1, ctxt.dyn_array().capacity()); + ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); + ASSERT_TRUE(ptr != ctxt.data()); + ASSERT_TRUE(ctxt.parms_id() == context.first_parms_id()); + ptr = ctxt.data(); + + ctxt.reserve(10); + ASSERT_EQ(0ULL, ctxt.size()); + ASSERT_EQ(0ULL, ctxt.dyn_array().size()); + ASSERT_EQ(10ULL * 2 * 1, ctxt.dyn_array().capacity()); + ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); + ASSERT_TRUE(ptr != ctxt.data()); + ASSERT_TRUE(ctxt.parms_id() == context.first_parms_id()); + ASSERT_FALSE(ctxt.is_ntt_form()); + ptr = ctxt.data(); + + ctxt.reserve(2); + ASSERT_EQ(0ULL, ctxt.size()); + ASSERT_EQ(2ULL * 2 * 1, ctxt.dyn_array().capacity()); + ASSERT_EQ(0ULL, ctxt.dyn_array().size()); + ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); + ASSERT_TRUE(ptr != ctxt.data()); + ASSERT_TRUE(ctxt.parms_id() == context.first_parms_id()); + ASSERT_FALSE(ctxt.is_ntt_form()); + ptr = ctxt.data(); + + ctxt.reserve(5); + ASSERT_EQ(0ULL, ctxt.size()); + ASSERT_EQ(5ULL * 2 * 1, ctxt.dyn_array().capacity()); + ASSERT_EQ(0ULL, ctxt.dyn_array().size()); + ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); + ASSERT_TRUE(ptr != ctxt.data()); + ASSERT_TRUE(ctxt.parms_id() == context.first_parms_id()); + ASSERT_FALSE(ctxt.is_ntt_form()); + + Ciphertext ctxt2{ ctxt }; + ASSERT_EQ(ctxt.coeff_modulus_size(), ctxt2.coeff_modulus_size()); + ASSERT_EQ(ctxt.is_ntt_form(), ctxt2.is_ntt_form()); + ASSERT_EQ(ctxt.poly_modulus_degree(), ctxt2.poly_modulus_degree()); + ASSERT_TRUE(ctxt.parms_id() == ctxt2.parms_id()); + ASSERT_EQ(ctxt.poly_modulus_degree(), ctxt2.poly_modulus_degree()); + ASSERT_EQ(ctxt.size(), ctxt2.size()); + + Ciphertext ctxt3; + ctxt3 = ctxt; + ASSERT_EQ(ctxt.coeff_modulus_size(), ctxt3.coeff_modulus_size()); + ASSERT_EQ(ctxt.poly_modulus_degree(), ctxt3.poly_modulus_degree()); + ASSERT_EQ(ctxt.is_ntt_form(), ctxt3.is_ntt_form()); + ASSERT_TRUE(ctxt.parms_id() == ctxt3.parms_id()); + ASSERT_EQ(ctxt.poly_modulus_degree(), ctxt3.poly_modulus_degree()); + ASSERT_EQ(ctxt.size(), ctxt3.size()); + } + + TEST(CiphertextTest, BGVSaveLoadCiphertext) + { + stringstream stream; + EncryptionParameters parms(scheme_type::bgv); + parms.set_poly_modulus_degree(2); + parms.set_coeff_modulus(CoeffModulus::Create(2, { 30 })); + parms.set_plain_modulus(2); + + SEALContext context(parms, false, sec_level_type::none); + + Ciphertext ctxt(context); + Ciphertext ctxt2; + ctxt.save(stream); + ctxt2.load(context, stream); + ASSERT_TRUE(ctxt.parms_id() == ctxt2.parms_id()); + ASSERT_FALSE(ctxt.is_ntt_form()); + ASSERT_FALSE(ctxt2.is_ntt_form()); + + parms.set_poly_modulus_degree(1024); + parms.set_coeff_modulus(CoeffModulus::BGVDefault(1024)); + parms.set_plain_modulus(0xF0F0); + context = SEALContext(parms, false); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + Encryptor encryptor(context, pk); + encryptor.encrypt(Plaintext("Ax^10 + 9x^9 + 8x^8 + 7x^7 + 6x^6 + 5x^5 + 4x^4 + 3x^3 + 2x^2 + 1"), ctxt); + ctxt.save(stream); + ctxt2.load(context, stream); + ASSERT_TRUE(ctxt.parms_id() == ctxt2.parms_id()); + ASSERT_FALSE(ctxt.is_ntt_form()); + ASSERT_FALSE(ctxt2.is_ntt_form()); + ASSERT_TRUE( + is_equal_uint(ctxt.data(), ctxt2.data(), parms.poly_modulus_degree() * parms.coeff_modulus().size() * 2)); + ASSERT_TRUE(ctxt.data() != ctxt2.data()); + } + } // namespace sealtest diff --git a/native/tests/seal/context.cpp b/native/tests/seal/context.cpp index 392277098..acd143528 100644 --- a/native/tests/seal/context.cpp +++ b/native/tests/seal/context.cpp @@ -12,7 +12,7 @@ using error_type = EncryptionParameterQualifiers::error_type; namespace sealtest { - TEST(ContextTest, ContextConstructor) + TEST(ContextTest, BFVContextConstructor) { // Nothing set auto scheme = scheme_type::bfv; @@ -344,6 +344,38 @@ namespace sealtest ASSERT_FALSE(!!context.first_context_data()->next_context_data()); ASSERT_TRUE(!!context.first_context_data()->prev_context_data()); } + { + EncryptionParameters parms(scheme_type::bgv); + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 41, 137, 193, 65537 }); + parms.set_plain_modulus(73); + SEALContext context(parms, true, sec_level_type::none); + auto context_data = context.key_context_data(); + ASSERT_EQ(size_t(2), context_data->chain_index()); + ASSERT_EQ(71047416497ULL, *context_data->total_coeff_modulus()); + ASSERT_FALSE(!!context_data->prev_context_data()); + ASSERT_EQ(context_data->parms_id(), context.key_parms_id()); + auto prev_context_data = context_data; + context_data = context_data->next_context_data(); + ASSERT_EQ(size_t(1), context_data->chain_index()); + ASSERT_EQ(1084081ULL, *context_data->total_coeff_modulus()); + ASSERT_EQ(context_data->prev_context_data()->parms_id(), prev_context_data->parms_id()); + prev_context_data = context_data; + context_data = context_data->next_context_data(); + ASSERT_EQ(size_t(0), context_data->chain_index()); + ASSERT_EQ(5617ULL, *context_data->total_coeff_modulus()); + ASSERT_EQ(context_data->prev_context_data()->parms_id(), prev_context_data->parms_id()); + ASSERT_FALSE(!!context_data->next_context_data()); + ASSERT_EQ(context_data->parms_id(), context.last_parms_id()); + + context = SEALContext(parms, false, sec_level_type::none); + ASSERT_EQ(size_t(1), context.key_context_data()->chain_index()); + ASSERT_EQ(size_t(0), context.first_context_data()->chain_index()); + ASSERT_EQ(71047416497ULL, *context.key_context_data()->total_coeff_modulus()); + ASSERT_EQ(1084081ULL, *context.first_context_data()->total_coeff_modulus()); + ASSERT_FALSE(!!context.first_context_data()->next_context_data()); + ASSERT_TRUE(!!context.first_context_data()->prev_context_data()); + } { EncryptionParameters parms(scheme_type::ckks); parms.set_poly_modulus_degree(4); @@ -382,7 +414,7 @@ namespace sealtest } } - TEST(EncryptionParameterQualifiersTest, ParameterError) + TEST(EncryptionParameterQualifiersTest, BFVParameterError) { auto scheme = scheme_type::bfv; EncryptionParameters parms(scheme); @@ -412,4 +444,333 @@ namespace sealtest ASSERT_STREQ(context.parameter_error_name(), "invalid_poly_modulus_degree_non_power_of_two"); ASSERT_STREQ(context.parameter_error_message(), "poly_modulus_degree is not a power of two"); } + + TEST(ContextTest, BGVContextConstructor) + { + // Nothing set + auto scheme = scheme_type::bgv; + EncryptionParameters parms(scheme); + { + SEALContext context(parms, false, sec_level_type::none); + auto qualifiers = context.first_context_data()->qualifiers(); + ASSERT_FALSE(qualifiers.parameters_set()); + ASSERT_EQ(qualifiers.parameter_error, error_type::invalid_coeff_modulus_size); + ASSERT_FALSE(qualifiers.using_fft); + ASSERT_FALSE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + ASSERT_FALSE(qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_FALSE(context.using_keyswitching()); + } + + // Not relatively prime coeff moduli + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 2, 30 }); + parms.set_plain_modulus(2); + parms.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); + { + SEALContext context(parms, false, sec_level_type::none); + auto qualifiers = context.first_context_data()->qualifiers(); + ASSERT_FALSE(qualifiers.parameters_set()); + ASSERT_EQ(qualifiers.parameter_error, error_type::failed_creating_rns_base); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_FALSE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + ASSERT_FALSE(qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_FALSE(context.using_keyswitching()); + } + + // Plain modulus not relatively prime to coeff moduli + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 17, 41 }); + parms.set_plain_modulus(34); + parms.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); + { + SEALContext context(parms, false, sec_level_type::none); + auto qualifiers = context.first_context_data()->qualifiers(); + ASSERT_FALSE(qualifiers.parameters_set()); + ASSERT_EQ(qualifiers.parameter_error, error_type::invalid_plain_modulus_coprimality); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + ASSERT_FALSE(qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_FALSE(context.using_keyswitching()); + } + + // Plain modulus not smaller than product of coeff moduli + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 17 }); + parms.set_plain_modulus(41); + parms.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); + { + SEALContext context(parms, false, sec_level_type::none); + ASSERT_EQ(17ULL, *context.first_context_data()->total_coeff_modulus()); + auto qualifiers = context.first_context_data()->qualifiers(); + ASSERT_FALSE(qualifiers.parameters_set()); + ASSERT_EQ(qualifiers.parameter_error, error_type::invalid_plain_modulus_too_large); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + ASSERT_FALSE(qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_FALSE(context.using_keyswitching()); + } + + // FFT poly but not NTT modulus + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 3 }); + parms.set_plain_modulus(2); + parms.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); + { + SEALContext context(parms, false, sec_level_type::none); + ASSERT_EQ(3ULL, *context.first_context_data()->total_coeff_modulus()); + auto qualifiers = context.first_context_data()->qualifiers(); + ASSERT_FALSE(qualifiers.parameters_set()); + ASSERT_EQ(qualifiers.parameter_error, error_type::invalid_coeff_modulus_no_ntt); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_FALSE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + ASSERT_FALSE(qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_FALSE(context.using_keyswitching()); + } + + // Parameters OK; no fast plain lift + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 17, 41 }); + parms.set_plain_modulus(18); + parms.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); + { + SEALContext context(parms, false, sec_level_type::none); + ASSERT_EQ(697ULL, *context.first_context_data()->total_coeff_modulus()); + auto qualifiers = context.first_context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set()); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + ASSERT_FALSE(qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_FALSE(context.using_keyswitching()); + } + + // Parameters OK; fast plain lift + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 17, 41 }); + parms.set_plain_modulus(16); + parms.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); + { + SEALContext context(parms, false, sec_level_type::none); + ASSERT_EQ(17ULL, *context.first_context_data()->total_coeff_modulus()); + ASSERT_EQ(697ULL, *context.key_context_data()->total_coeff_modulus()); + auto qualifiers = context.first_context_data()->qualifiers(); + auto key_qualifiers = context.key_context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set()); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_TRUE(qualifiers.using_fast_plain_lift); + ASSERT_FALSE(key_qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_TRUE(context.using_keyswitching()); + } + + // Parameters OK; no batching due to non-prime plain modulus + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 17, 41 }); + parms.set_plain_modulus(49); + parms.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); + { + SEALContext context(parms, false, sec_level_type::none); + ASSERT_EQ(697ULL, *context.first_context_data()->total_coeff_modulus()); + auto qualifiers = context.first_context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set()); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + ASSERT_FALSE(qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_FALSE(context.using_keyswitching()); + } + + // Parameters OK; batching enabled + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 17, 41 }); + parms.set_plain_modulus(73); + parms.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); + { + SEALContext context(parms, false, sec_level_type::none); + ASSERT_EQ(697ULL, *context.first_context_data()->total_coeff_modulus()); + auto qualifiers = context.first_context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set()); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_TRUE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + ASSERT_FALSE(qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_FALSE(context.using_keyswitching()); + } + + // Parameters OK; batching and fast plain lift enabled + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 137, 193 }); + parms.set_plain_modulus(73); + parms.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); + { + SEALContext context(parms, false, sec_level_type::none); + ASSERT_EQ(137ULL, *context.first_context_data()->total_coeff_modulus()); + ASSERT_EQ(26441ULL, *context.key_context_data()->total_coeff_modulus()); + auto qualifiers = context.first_context_data()->qualifiers(); + auto key_qualifiers = context.key_context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set()); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_TRUE(qualifiers.using_batching); + ASSERT_TRUE(qualifiers.using_fast_plain_lift); + ASSERT_FALSE(key_qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_TRUE(context.using_keyswitching()); + } + + // Parameters OK; batching and fast plain lift enabled; nullptr RNG + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 137, 193 }); + parms.set_plain_modulus(73); + parms.set_random_generator(nullptr); + { + SEALContext context(parms, false, sec_level_type::none); + ASSERT_EQ(137ULL, *context.first_context_data()->total_coeff_modulus()); + ASSERT_EQ(26441ULL, *context.key_context_data()->total_coeff_modulus()); + auto qualifiers = context.first_context_data()->qualifiers(); + auto key_qualifiers = context.key_context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set()); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_TRUE(qualifiers.using_batching); + ASSERT_TRUE(qualifiers.using_fast_plain_lift); + ASSERT_FALSE(key_qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_TRUE(context.using_keyswitching()); + } + + // Parameters not OK due to too small poly_modulus_degree and enforce_hes + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 137, 193 }); + parms.set_plain_modulus(73); + parms.set_random_generator(nullptr); + { + SEALContext context(parms, false, sec_level_type::tc128); + auto qualifiers = context.first_context_data()->qualifiers(); + ASSERT_FALSE(qualifiers.parameters_set()); + ASSERT_EQ(qualifiers.parameter_error, error_type::invalid_parameters_insecure); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_FALSE(context.using_keyswitching()); + } + + // Parameters not OK due to too large coeff_modulus and enforce_hes + parms.set_poly_modulus_degree(2048); + parms.set_coeff_modulus(CoeffModulus::BGVDefault(4096, sec_level_type::tc128)); + parms.set_plain_modulus(73); + parms.set_random_generator(nullptr); + { + SEALContext context(parms, false, sec_level_type::tc128); + auto qualifiers = context.first_context_data()->qualifiers(); + ASSERT_FALSE(qualifiers.parameters_set()); + ASSERT_EQ(qualifiers.parameter_error, error_type::invalid_parameters_insecure); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_FALSE(context.using_keyswitching()); + } + + // Parameters OK; descending modulus chain + parms.set_poly_modulus_degree(4096); + parms.set_coeff_modulus({ 0xffffee001, 0xffffc4001 }); + parms.set_plain_modulus(73); + { + SEALContext context(parms, false, sec_level_type::tc128); + auto qualifiers = context.first_context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set()); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_TRUE(qualifiers.using_fast_plain_lift); + ASSERT_TRUE(qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::tc128, qualifiers.sec_level); + ASSERT_TRUE(context.using_keyswitching()); + } + + // Parameters OK; no standard security + parms.set_poly_modulus_degree(2048); + parms.set_coeff_modulus({ 0x1ffffe0001, 0xffffee001, 0xffffc4001 }); + parms.set_plain_modulus(73); + { + SEALContext context(parms, false, sec_level_type::none); + auto qualifiers = context.first_context_data()->qualifiers(); + auto key_qualifiers = context.key_context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set()); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_TRUE(qualifiers.using_fast_plain_lift); + ASSERT_TRUE(key_qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_TRUE(context.using_keyswitching()); + } + + // Parameters OK; using batching; no keyswitching + parms.set_poly_modulus_degree(2048); + parms.set_coeff_modulus(CoeffModulus::Create(2048, { 40 })); + parms.set_plain_modulus(65537); + { + SEALContext context(parms, false, sec_level_type::none); + auto qualifiers = context.first_context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set()); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_TRUE(qualifiers.using_batching); + ASSERT_TRUE(qualifiers.using_fast_plain_lift); + ASSERT_TRUE(qualifiers.using_descending_modulus_chain); + ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); + ASSERT_FALSE(context.using_keyswitching()); + } + } + + TEST(EncryptionParameterQualifiersTest, BGVParameterError) + { + auto scheme = scheme_type::bgv; + EncryptionParameters parms(scheme); + SEALContext context(parms, false, sec_level_type::none); + auto qualifiers = context.first_context_data()->qualifiers(); + + qualifiers.parameter_error = error_type::none; + ASSERT_STREQ(qualifiers.parameter_error_name(), "none"); + ASSERT_STREQ(qualifiers.parameter_error_message(), "constructed but not yet validated"); + + qualifiers.parameter_error = error_type::success; + ASSERT_STREQ(qualifiers.parameter_error_name(), "success"); + ASSERT_STREQ(qualifiers.parameter_error_message(), "valid"); + + qualifiers.parameter_error = error_type::invalid_coeff_modulus_bit_count; + ASSERT_STREQ(qualifiers.parameter_error_name(), "invalid_coeff_modulus_bit_count"); + ASSERT_STREQ( + qualifiers.parameter_error_message(), + "coeff_modulus's primes' bit counts are not bounded by SEAL_USER_MOD_BIT_COUNT_MIN(MAX)"); + + parms.set_poly_modulus_degree(127); + parms.set_coeff_modulus({ 17, 73 }); + parms.set_plain_modulus(41); + parms.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); + context = SEALContext(parms, false, sec_level_type::none); + ASSERT_FALSE(context.parameters_set()); + ASSERT_STREQ(context.parameter_error_name(), "invalid_poly_modulus_degree_non_power_of_two"); + ASSERT_STREQ(context.parameter_error_message(), "poly_modulus_degree is not a power of two"); + } } // namespace sealtest diff --git a/native/tests/seal/encryptionparams.cpp b/native/tests/seal/encryptionparams.cpp index 8b32a041d..eb3089675 100644 --- a/native/tests/seal/encryptionparams.cpp +++ b/native/tests/seal/encryptionparams.cpp @@ -16,7 +16,7 @@ namespace sealtest auto encryption_parameters_test = [](scheme_type scheme) { EncryptionParameters parms(scheme); parms.set_coeff_modulus({ 2, 3 }); - if (scheme == scheme_type::bfv) + if (scheme == scheme_type::bfv || scheme == scheme_type::bgv) parms.set_plain_modulus(2); parms.set_poly_modulus_degree(2); parms.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); @@ -24,7 +24,7 @@ namespace sealtest ASSERT_TRUE(scheme == parms.scheme()); ASSERT_TRUE(parms.coeff_modulus()[0] == 2); ASSERT_TRUE(parms.coeff_modulus()[1] == 3); - if (scheme == scheme_type::bfv) + if (scheme == scheme_type::bfv || scheme == scheme_type::bgv) { ASSERT_TRUE(parms.plain_modulus().value() == 2); } @@ -36,7 +36,7 @@ namespace sealtest ASSERT_TRUE(parms.random_generator() == UniformRandomGeneratorFactory::DefaultFactory()); parms.set_coeff_modulus(CoeffModulus::Create(2, { 30, 40, 50 })); - if (scheme == scheme_type::bfv) + if (scheme == scheme_type::bfv || scheme == scheme_type::bgv) parms.set_plain_modulus(2); parms.set_poly_modulus_degree(128); parms.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); @@ -45,7 +45,7 @@ namespace sealtest ASSERT_TRUE(util::is_prime(parms.coeff_modulus()[1])); ASSERT_TRUE(util::is_prime(parms.coeff_modulus()[2])); - if (scheme == scheme_type::bfv) + if (scheme == scheme_type::bfv || scheme == scheme_type::bgv) { ASSERT_TRUE(parms.plain_modulus().value() == 2); } @@ -58,91 +58,97 @@ namespace sealtest }; encryption_parameters_test(scheme_type::bfv); encryption_parameters_test(scheme_type::ckks); + encryption_parameters_test(scheme_type::bgv); } TEST(EncryptionParametersTest, EncryptionParametersCompare) { - auto scheme = scheme_type::bfv; - EncryptionParameters parms1(scheme); - parms1.set_coeff_modulus(CoeffModulus::Create(64, { 30 })); - if (scheme == scheme_type::bfv) - parms1.set_plain_modulus(1 << 6); - parms1.set_poly_modulus_degree(64); - parms1.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); - - EncryptionParameters parms2(parms1); - ASSERT_TRUE(parms1 == parms2); - - EncryptionParameters parms3(scheme); - parms3 = parms2; - ASSERT_TRUE(parms3 == parms2); - parms3.set_coeff_modulus(CoeffModulus::Create(64, { 32 })); - ASSERT_FALSE(parms3 == parms2); - - parms3 = parms2; - ASSERT_TRUE(parms3 == parms2); - parms3.set_coeff_modulus(CoeffModulus::Create(64, { 30, 30 })); - ASSERT_FALSE(parms3 == parms2); - - parms3 = parms2; - parms3.set_poly_modulus_degree(128); - ASSERT_FALSE(parms3 == parms1); - - parms3 = parms2; - if (scheme == scheme_type::bfv) - parms3.set_plain_modulus((1 << 6) + 1); - ASSERT_FALSE(parms3 == parms2); - - parms3 = parms2; - ASSERT_TRUE(parms3 == parms2); - - parms3 = parms2; - parms3.set_random_generator(nullptr); - ASSERT_TRUE(parms3 == parms2); - - parms3 = parms2; - parms3.set_poly_modulus_degree(128); - parms3.set_poly_modulus_degree(64); - ASSERT_TRUE(parms3 == parms1); - - parms3 = parms2; - parms3.set_coeff_modulus({ 2 }); - parms3.set_coeff_modulus(CoeffModulus::Create(64, { 50 })); - parms3.set_coeff_modulus(parms2.coeff_modulus()); - ASSERT_TRUE(parms3 == parms2); + auto encryption_parameters_compare = [](scheme_type scheme){ + EncryptionParameters parms1(scheme); + parms1.set_coeff_modulus(CoeffModulus::Create(64, { 30 })); + if (scheme == scheme_type::bfv || scheme == scheme_type::bgv) + parms1.set_plain_modulus(1 << 6); + parms1.set_poly_modulus_degree(64); + parms1.set_random_generator(UniformRandomGeneratorFactory::DefaultFactory()); + + EncryptionParameters parms2(parms1); + ASSERT_TRUE(parms1 == parms2); + + EncryptionParameters parms3(scheme); + parms3 = parms2; + ASSERT_TRUE(parms3 == parms2); + parms3.set_coeff_modulus(CoeffModulus::Create(64, { 32 })); + ASSERT_FALSE(parms3 == parms2); + + parms3 = parms2; + ASSERT_TRUE(parms3 == parms2); + parms3.set_coeff_modulus(CoeffModulus::Create(64, { 30, 30 })); + ASSERT_FALSE(parms3 == parms2); + + parms3 = parms2; + parms3.set_poly_modulus_degree(128); + ASSERT_FALSE(parms3 == parms1); + + parms3 = parms2; + if (scheme == scheme_type::bfv || scheme == scheme_type::bgv) + parms3.set_plain_modulus((1 << 6) + 1); + ASSERT_FALSE(parms3 == parms2); + + parms3 = parms2; + ASSERT_TRUE(parms3 == parms2); + + parms3 = parms2; + parms3.set_random_generator(nullptr); + ASSERT_TRUE(parms3 == parms2); + + parms3 = parms2; + parms3.set_poly_modulus_degree(128); + parms3.set_poly_modulus_degree(64); + ASSERT_TRUE(parms3 == parms1); + + parms3 = parms2; + parms3.set_coeff_modulus({ 2 }); + parms3.set_coeff_modulus(CoeffModulus::Create(64, { 50 })); + parms3.set_coeff_modulus(parms2.coeff_modulus()); + ASSERT_TRUE(parms3 == parms2); + }; + encryption_parameters_compare(scheme_type::bfv); + encryption_parameters_compare(scheme_type::bgv); } TEST(EncryptionParametersTest, EncryptionParametersSaveLoad) { - stringstream stream; - - auto scheme = scheme_type::bfv; - EncryptionParameters parms(scheme); - EncryptionParameters parms2(scheme); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 30 })); - if (scheme == scheme_type::bfv) - parms.set_plain_modulus(1 << 6); - parms.set_poly_modulus_degree(64); - parms.save(stream); - parms2.load(stream); - ASSERT_TRUE(parms.scheme() == parms2.scheme()); - ASSERT_TRUE(parms.coeff_modulus() == parms2.coeff_modulus()); - ASSERT_TRUE(parms.plain_modulus() == parms2.plain_modulus()); - ASSERT_TRUE(parms.poly_modulus_degree() == parms2.poly_modulus_degree()); - ASSERT_TRUE(parms == parms2); - - parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 60, 60 })); - - if (scheme == scheme_type::bfv) - parms.set_plain_modulus(1 << 30); - parms.set_poly_modulus_degree(256); - - parms.save(stream); - parms2.load(stream); - ASSERT_TRUE(parms.scheme() == parms2.scheme()); - ASSERT_TRUE(parms.coeff_modulus() == parms2.coeff_modulus()); - ASSERT_TRUE(parms.plain_modulus() == parms2.plain_modulus()); - ASSERT_TRUE(parms.poly_modulus_degree() == parms2.poly_modulus_degree()); - ASSERT_TRUE(parms == parms2); + auto encryption_parameters_save_load = [](scheme_type scheme){ + stringstream stream; + EncryptionParameters parms(scheme); + EncryptionParameters parms2(scheme); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 30 })); + if (scheme == scheme_type::bfv || scheme == scheme_type::bgv) + parms.set_plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.save(stream); + parms2.load(stream); + ASSERT_TRUE(parms.scheme() == parms2.scheme()); + ASSERT_TRUE(parms.coeff_modulus() == parms2.coeff_modulus()); + ASSERT_TRUE(parms.plain_modulus() == parms2.plain_modulus()); + ASSERT_TRUE(parms.poly_modulus_degree() == parms2.poly_modulus_degree()); + ASSERT_TRUE(parms == parms2); + + parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 60, 60 })); + + if (scheme == scheme_type::bfv || scheme == scheme_type::bgv) + parms.set_plain_modulus(1 << 30); + parms.set_poly_modulus_degree(256); + + parms.save(stream); + parms2.load(stream); + ASSERT_TRUE(parms.scheme() == parms2.scheme()); + ASSERT_TRUE(parms.coeff_modulus() == parms2.coeff_modulus()); + ASSERT_TRUE(parms.plain_modulus() == parms2.plain_modulus()); + ASSERT_TRUE(parms.poly_modulus_degree() == parms2.poly_modulus_degree()); + ASSERT_TRUE(parms == parms2); + }; + encryption_parameters_save_load(scheme_type::bfv); + encryption_parameters_save_load(scheme_type::bgv); } } // namespace sealtest diff --git a/native/tests/seal/encryptor.cpp b/native/tests/seal/encryptor.cpp index f0df72a73..0f92b4325 100644 --- a/native/tests/seal/encryptor.cpp +++ b/native/tests/seal/encryptor.cpp @@ -850,4 +850,192 @@ namespace sealtest } } } + + TEST(EncryptorTest, BGVEncryptDecrypt) + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); + parms.set_plain_modulus(plain_modulus); + std::string max_possible_plain; + { + std::stringstream ss; + ss << std::hex << (plain_modulus.value() - 1); + max_possible_plain = ss.str(); + } + for (char &c : max_possible_plain) + { + c = std::toupper(c); + } + { + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + Encryptor encryptor(context, pk); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + string hex_poly; + + hex_poly = + "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"; + encryptor.encrypt(Plaintext(hex_poly), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(hex_poly, plain.to_string()); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + hex_poly = "0"; + encryptor.encrypt(Plaintext(hex_poly), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(hex_poly, plain.to_string()); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + hex_poly = "1"; + encryptor.encrypt(Plaintext(hex_poly), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(hex_poly, plain.to_string()); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + hex_poly = "1x^1"; + encryptor.encrypt(Plaintext(hex_poly), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(hex_poly, plain.to_string()); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + hex_poly = + "1x^62 + 1x^61 + 1x^60 + 1x^59 + 1x^58 + 1x^57 + 1x^56 + 1x^55 + 1x^54 + 1x^53 + 1x^52 + 1x^51 + 1x^50 " + "+ 1x^49 + 1x^48 + 1x^47 + 1x^46 + 1x^45 + 1x^44 + 1x^43 + 1x^42 + 1x^41 + 1x^40 + 1x^39 + 1x^38 + " + "1x^37 + 1x^36 + 1x^35 + 1x^34 + 1x^33 + 1x^32 + 1x^31 + 1x^30 + 1x^29 + 1x^28 + 1x^27 + 1x^26 + 1x^25 " + "+ 1x^24 + 1x^23 + 1x^22 + 1x^21 + 1x^20 + 1x^19 + 1x^18 + 1x^17 + 1x^16 + 1x^15 + 1x^14 + 1x^13 + " + "1x^12 + 1x^11 + 1x^10 + 1x^9 + 1x^8 + 1x^7 + 1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^2 + " + max_possible_plain; + encryptor.encrypt(Plaintext(hex_poly), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(hex_poly, plain.to_string()); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + hex_poly = + "1x^62 + 1x^61 + 1x^60 + 1x^59 + 1x^58 + 1x^57 + 1x^56 + 1x^55 + 1x^54 + 1x^53 + 1x^52 + 1x^51 + 1x^50 " + "+ 1x^49 + 1x^48 + 1x^47 + 1x^46 + 1x^45 + 1x^44 + 1x^43 + 1x^42 + 1x^41 + 1x^40 + 1x^39 + 1x^38 + " + "1x^37 + 1x^36 + 1x^35 + 1x^34 + 1x^33 + 1x^32 + 1x^31 + 1x^30 + 1x^29 + 1x^28 + 1x^27 + 1x^26 + 1x^25 " + "+ 1x^24 + 1x^23 + 1x^22 + 1x^21 + 1x^20 + 1x^19 + 1x^18 + 1x^17 + 1x^16 + 1x^15 + 1x^14 + 1x^13 + " + "1x^12 + 1x^11 + 1x^10 + 1x^9 + 1x^8 + 1x^7 + 1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^2 + 1x^1"; + encryptor.encrypt(Plaintext(hex_poly), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(hex_poly, plain.to_string()); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + hex_poly = + "3Fx^62 + 1x^61 + 1x^60 + 1x^59 + 1x^58 + 1x^57 + 1x^56 + 1x^55 + 1x^54 + 1x^53 + 1x^52 + 1x^51 + " + "1x^50 " + "+ 1x^49 + 1x^48 + 1x^47 + 1x^46 + 1x^45 + 1x^44 + 1x^43 + 1x^42 + 1x^41 + 1x^40 + 1x^39 + 1x^38 + " + "1x^37 + 1x^36 + 1x^35 + 1x^34 + 1x^33 + 1x^32 + 1x^31 + 1x^30 + 1x^29 + 1x^28 + 1x^27 + 1x^26 + 1x^25 " + "+ 1x^24 + 1x^23 + 1x^22 + 1x^21 + 1x^20 + 1x^19 + 1x^18 + 1x^17 + 1x^16 + 1x^15 + 1x^14 + 1x^13 + " + "1x^12 + 1x^11 + 1x^10 + 1x^9 + 1x^8 + 1x^7 + 1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^2 + 1x^1 + " + max_possible_plain; + encryptor.encrypt(Plaintext(hex_poly), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(hex_poly, plain.to_string()); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + hex_poly = + "1x^28 + 1x^25 + 1x^23 + 1x^21 + 1x^20 + 1x^19 + 1x^16 + 1x^15 + 1x^13 + 1x^12 + 1x^7 + 1x^5 + " + max_possible_plain; + encryptor.encrypt(Plaintext(hex_poly), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(hex_poly, plain.to_string()); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + } + } + + TEST(EncryptorTest, BGVEncryptZeroDecrypt) + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); + parms.set_plain_modulus(plain_modulus); + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 40, 40, 40 })); + SEALContext context(parms, true, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + Encryptor encryptor(context, pk, keygen.secret_key()); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext ct; + Plaintext pt; + parms_id_type next_parms = context.first_context_data()->next_context_data()->parms_id(); + { + encryptor.encrypt_zero(ct); + ASSERT_FALSE(ct.is_ntt_form()); + ASSERT_FALSE(ct.is_transparent()); + ASSERT_DOUBLE_EQ(ct.scale(), 1.0); + decryptor.decrypt(ct, pt); + ASSERT_TRUE(pt.is_zero()); + + encryptor.encrypt_zero(next_parms, ct); + ASSERT_FALSE(ct.is_ntt_form()); + ASSERT_FALSE(ct.is_transparent()); + ASSERT_DOUBLE_EQ(ct.scale(), 1.0); + ASSERT_EQ(ct.parms_id(), next_parms); + decryptor.decrypt(ct, pt); + ASSERT_TRUE(pt.is_zero()); + } + { + stringstream stream; + encryptor.encrypt_zero().save(stream); + ct.load(context, stream); + ASSERT_FALSE(ct.is_ntt_form()); + ASSERT_FALSE(ct.is_transparent()); + ASSERT_DOUBLE_EQ(ct.scale(), 1.0); + decryptor.decrypt(ct, pt); + ASSERT_TRUE(pt.is_zero()); + + encryptor.encrypt_zero(next_parms).save(stream); + ct.load(context, stream); + ASSERT_FALSE(ct.is_ntt_form()); + ASSERT_FALSE(ct.is_transparent()); + ASSERT_DOUBLE_EQ(ct.scale(), 1.0); + ASSERT_EQ(ct.parms_id(), next_parms); + decryptor.decrypt(ct, pt); + ASSERT_TRUE(pt.is_zero()); + } + { + encryptor.encrypt_zero_symmetric(ct); + ASSERT_FALSE(ct.is_ntt_form()); + ASSERT_FALSE(ct.is_transparent()); + ASSERT_DOUBLE_EQ(ct.scale(), 1.0); + decryptor.decrypt(ct, pt); + ASSERT_TRUE(pt.is_zero()); + + encryptor.encrypt_zero_symmetric(next_parms, ct); + ASSERT_FALSE(ct.is_ntt_form()); + ASSERT_FALSE(ct.is_transparent()); + ASSERT_DOUBLE_EQ(ct.scale(), 1.0); + ASSERT_EQ(ct.parms_id(), next_parms); + decryptor.decrypt(ct, pt); + ASSERT_TRUE(pt.is_zero()); + } + { + stringstream stream; + encryptor.encrypt_zero_symmetric().save(stream); + ct.load(context, stream); + ASSERT_FALSE(ct.is_ntt_form()); + ASSERT_FALSE(ct.is_transparent()); + ASSERT_DOUBLE_EQ(ct.scale(), 1.0); + decryptor.decrypt(ct, pt); + ASSERT_TRUE(pt.is_zero()); + + encryptor.encrypt_zero_symmetric(next_parms).save(stream); + ct.load(context, stream); + ASSERT_FALSE(ct.is_ntt_form()); + ASSERT_FALSE(ct.is_transparent()); + ASSERT_DOUBLE_EQ(ct.scale(), 1.0); + ASSERT_EQ(ct.parms_id(), next_parms); + decryptor.decrypt(ct, pt); + ASSERT_TRUE(pt.is_zero()); + } + } } // namespace sealtest diff --git a/native/tests/seal/evaluator.cpp b/native/tests/seal/evaluator.cpp index e48d42797..e1a6f8c61 100644 --- a/native/tests/seal/evaluator.cpp +++ b/native/tests/seal/evaluator.cpp @@ -179,6 +179,165 @@ namespace sealtest ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); } + TEST(EvaluatorTest, BGVEncryptNegateDecrypt) + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + + plain = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"; + encryptor.encrypt(plain, encrypted); + evaluator.negate_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ( + plain.to_string(), "3Fx^28 + 3Fx^25 + 3Fx^21 + 3Fx^20 + 3Fx^18 + 3Fx^14 + 3Fx^12 + 3Fx^10 + 3Fx^9 + 3Fx^6 " + "+ 3Fx^5 + 3Fx^4 + 3Fx^3"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "0"; + encryptor.encrypt(plain, encrypted); + evaluator.negate_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "1"; + encryptor.encrypt(plain, encrypted); + evaluator.negate_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "3F"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "3F"; + encryptor.encrypt(plain, encrypted); + evaluator.negate_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "1x^1"; + encryptor.encrypt(plain, encrypted); + evaluator.negate_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "3Fx^1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "3Fx^2 + 3F"; + encryptor.encrypt(plain, encrypted); + evaluator.negate_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + } + + TEST(EvaluatorTest, BGVEncryptAddDecrypt) + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Plaintext plain, plain1, plain2; + + plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"; + plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ( + plain.to_string(), "1x^28 + 1x^25 + 1x^21 + 1x^20 + 2x^18 + 1x^16 + 2x^14 + 1x^12 + 1x^10 + 2x^9 + 1x^8 + " + "1x^6 + 2x^5 + 1x^4 + 1x^3 + 1"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "0"; + plain2 = "0"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ("0", plain.to_string()); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "0"; + plain2 = "1x^2 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 1"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "1x^2 + 1"; + plain2 = "3Fx^1 + 3F"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 3Fx^1"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "3Fx^2 + 3Fx^1 + 3F"; + plain2 = "1x^1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "3Fx^2 + 3F"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "2x^2 + 1x^1 + 3"; + plain2 = "3x^3 + 4x^2 + 5x^1 + 6"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_TRUE(plain.to_string() == "3x^3 + 6x^2 + 6x^1 + 9"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "3x^5 + 1x^4 + 4x^3 + 1"; + plain2 = "5x^2 + 9x^1 + 2"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_TRUE(plain.to_string() == "3x^5 + 1x^4 + 4x^3 + 5x^2 + 9x^1 + 3"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + } + TEST(EvaluatorTest, CKKSEncryptAddDecrypt) { EncryptionParameters parms(scheme_type::ckks); @@ -342,6 +501,7 @@ namespace sealtest } } } + TEST(EvaluatorTest, CKKSEncryptAddPlainDecrypt) { EncryptionParameters parms(scheme_type::ckks); @@ -1463,480 +1623,944 @@ namespace sealtest } } -#include "seal/randomgen.h" - TEST(EvaluatorTest, BFVRelinearize) + TEST(EvaluatorTest, BGVEncryptSubDecrypt) { - EncryptionParameters parms(scheme_type::bfv); + EncryptionParameters parms(scheme_type::bgv); Modulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(128); + parms.set_poly_modulus_degree(64); parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40, 40 })); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); - SEALContext context(parms, true, sec_level_type::none); + SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - RelinKeys rlk; - keygen.create_relin_keys(rlk); Encryptor encryptor(context, pk); Evaluator evaluator(context); Decryptor decryptor(context, keygen.secret_key()); - Ciphertext encrypted(context); - Ciphertext encrypted2(context); - - Plaintext plain; - Plaintext plain2; + Ciphertext encrypted1; + Ciphertext encrypted2; + Plaintext plain, plain1, plain2; - plain = 0; - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - decryptor.decrypt(encrypted, plain2); - ASSERT_TRUE(plain == plain2); + plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"; + plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.sub_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ( + plain.to_string(), + "1x^28 + 1x^25 + 1x^21 + 1x^20 + 3Fx^16 + 1x^12 + 1x^10 + 3Fx^8 + 1x^6 + 1x^4 + 1x^3 + 3F"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - decryptor.decrypt(encrypted, plain2); - ASSERT_TRUE(plain == plain2); + plain1 = "0"; + plain2 = "0"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.sub_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - plain = "1x^10 + 2"; - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - decryptor.decrypt(encrypted, plain2); - ASSERT_TRUE(plain2.to_string() == "1x^20 + 4x^10 + 4"); + plain1 = "0"; + plain2 = "1x^2 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.sub_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "3Fx^2 + 3F"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - decryptor.decrypt(encrypted, plain2); - ASSERT_TRUE(plain2.to_string() == "1x^40 + 8x^30 + 18x^20 + 20x^10 + 10"); + plain1 = "1x^2 + 1"; + plain2 = "3Fx^1 + 3F"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.sub_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 1x^1 + 2"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - // Relinearization with modulus switching - plain = "1x^10 + 2"; - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - evaluator.mod_switch_to_next_inplace(encrypted); - decryptor.decrypt(encrypted, plain2); - ASSERT_TRUE(plain2.to_string() == "1x^20 + 4x^10 + 4"); + plain1 = "3Fx^2 + 3Fx^1 + 3F"; + plain2 = "1x^1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.sub_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "3Fx^2 + 3Ex^1 + 3F"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + } - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - evaluator.mod_switch_to_next_inplace(encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - evaluator.mod_switch_to_next_inplace(encrypted); - decryptor.decrypt(encrypted, plain2); - ASSERT_TRUE(plain2.to_string() == "1x^40 + 8x^30 + 18x^20 + 20x^10 + 10"); + TEST(EvaluatorTest, BGVEncryptAddPlainDecrypt) + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Plaintext plain, plain1, plain2; + + plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"; + plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1"; + encryptor.encrypt(plain1, encrypted1); + evaluator.add_plain_inplace(encrypted1, plain2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ( + plain.to_string(), "1x^28 + 1x^25 + 1x^21 + 1x^20 + 2x^18 + 1x^16 + 2x^14 + 1x^12 + 1x^10 + 2x^9 + 1x^8 + " + "1x^6 + 2x^5 + 1x^4 + 1x^3 + 1"); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "0"; + plain2 = "0"; + encryptor.encrypt(plain1, encrypted1); + evaluator.add_plain_inplace(encrypted1, plain2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "0"; + plain2 = "1x^2 + 1"; + encryptor.encrypt(plain1, encrypted1); + evaluator.add_plain_inplace(encrypted1, plain2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 1"); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "1x^2 + 1"; + plain2 = "3Fx^1 + 3F"; + encryptor.encrypt(plain1, encrypted1); + evaluator.add_plain_inplace(encrypted1, plain2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 3Fx^1"); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "3Fx^2 + 3Fx^1 + 3F"; + plain2 = "1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + evaluator.add_plain_inplace(encrypted1, plain2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); } - TEST(EvaluatorTest, CKKSEncryptNaiveMultiplyDecrypt) + TEST(EvaluatorTest, BGVEncryptSubPlainDecrypt) + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1; + Plaintext plain, plain1, plain2; + + plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"; + plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1"; + encryptor.encrypt(plain1, encrypted1); + evaluator.sub_plain_inplace(encrypted1, plain2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ( + plain.to_string(), + "1x^28 + 1x^25 + 1x^21 + 1x^20 + 3Fx^16 + 1x^12 + 1x^10 + 3Fx^8 + 1x^6 + 1x^4 + 1x^3 + 3F"); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "0"; + plain2 = "0"; + encryptor.encrypt(plain1, encrypted1); + evaluator.sub_plain_inplace(encrypted1, plain2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "0"; + plain2 = "1x^2 + 1"; + encryptor.encrypt(plain1, encrypted1); + evaluator.sub_plain_inplace(encrypted1, plain2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "3Fx^2 + 3F"); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "1x^2 + 1"; + plain2 = "3Fx^1 + 3F"; + encryptor.encrypt(plain1, encrypted1); + evaluator.sub_plain_inplace(encrypted1, plain2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 1x^1 + 2"); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + + plain1 = "3Fx^2 + 3Fx^1 + 3F"; + plain2 = "1x^1"; + encryptor.encrypt(plain1, encrypted1); + evaluator.sub_plain_inplace(encrypted1, plain2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "3Fx^2 + 3Ex^1 + 3F"); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + } + + TEST(EvaluatorTest, BGVEncryptMultiplyPlainDecrypt) { - EncryptionParameters parms(scheme_type::ckks); { - // Multiplying two zero vectors - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30 })); + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - CKKSEncoder encoder(context); Encryptor encryptor(context, pk); - Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; + Plaintext plain, plain1, plain2; - vector> input(slot_size, 0.0); - vector> output(slot_size); - const double delta = static_cast(1 << 30); - encoder.encode(input, context.first_parms_id(), delta, plain); + plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"; + plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1"; + encryptor.encrypt(plain1, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain2); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ( + plain.to_string(), "1x^46 + 1x^44 + 1x^43 + 1x^42 + 1x^41 + 2x^39 + 1x^38 + 2x^37 + 3x^36 + 1x^35 + " + "3x^34 + 2x^33 + 2x^32 + 4x^30 + 2x^29 + 5x^28 + 2x^27 + 4x^26 + 3x^25 + 2x^24 + " + "4x^23 + 3x^22 + 4x^21 + 4x^20 + 4x^19 + 4x^18 + 3x^17 + 2x^15 + 4x^14 + 2x^13 + " + "3x^12 + 2x^11 + 2x^10 + 2x^9 + 1x^8 + 1x^6 + 1x^5 + 1x^4 + 1x^3"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - encryptor.encrypt(plain, encrypted); - evaluator.multiply_inplace(encrypted, encrypted); + plain1 = "0"; + plain2 = "1x^2 + 1"; + encryptor.encrypt(plain1, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain2); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - // Check correctness of encryption + plain1 = "1x^2 + 1x^1 + 1"; + plain2 = "1x^2"; + encryptor.encrypt(plain1, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain2); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^4 + 1x^3 + 1x^2"); ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } + plain1 = "1x^2 + 1x^1 + 1"; + plain2 = "1x^1"; + encryptor.encrypt(plain1, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain2); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^3 + 1x^2 + 1x^1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain1 = "1x^2 + 1x^1 + 1"; + plain2 = "1"; + encryptor.encrypt(plain1, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain2); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 1x^1 + 1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain1 = "1x^2 + 1"; + plain2 = "3Fx^1 + 3F"; + encryptor.encrypt(plain1, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain2); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "3Fx^3 + 3Fx^2 + 3Fx^1 + 3F"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain1 = "3Fx^2 + 3Fx^1 + 3F"; + plain2 = "1x^1"; + encryptor.encrypt(plain1, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain2); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "3Fx^3 + 3Fx^2 + 3Fx^1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); } { - // Multiplying two random vectors - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus((1ULL << 20) - 1); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 60, 60 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - CKKSEncoder encoder(context); Encryptor encryptor(context, pk); - Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); - Ciphertext encrypted1; - Ciphertext encrypted2; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - vector> input1(slot_size, 0.0); - vector> input2(slot_size, 0.0); - vector> expected(slot_size, 0.0); - vector> output(slot_size); - const double delta = static_cast(1ULL << 40); + Ciphertext encrypted; + Plaintext plain, plain1, plain2; - int data_bound = (1 << 10); - srand(static_cast(time(NULL))); + plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"; + plain2 = "1"; + encryptor.encrypt(plain1, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain2); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ( + plain.to_string(), + "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - for (int round = 0; round < 100; round++) - { - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i]; - } - encoder.encode(input1, context.first_parms_id(), delta, plain1); - encoder.encode(input2, context.first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - - // Check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } + plain2 = "5"; + evaluator.multiply_plain_inplace(encrypted, plain2); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ( + plain.to_string(), + "5x^28 + 5x^25 + 5x^21 + 5x^20 + 5x^18 + 5x^14 + 5x^12 + 5x^10 + 5x^9 + 5x^6 + 5x^5 + 5x^4 + 5x^3"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); } { - // Multiplying two random vectors - size_t slot_size = 16; + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus((1ULL << 40) - 1); parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 })); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 60, 60 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - CKKSEncoder encoder(context); Encryptor encryptor(context, pk); - Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); - Ciphertext encrypted1; - Ciphertext encrypted2; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; + Ciphertext encrypted; + Plaintext plain, plain1, plain2; - vector> input1(slot_size, 0.0); - vector> input2(slot_size, 0.0); - vector> expected(slot_size, 0.0); - vector> output(slot_size); - const double delta = static_cast(1ULL << 40); + plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"; + plain2 = "1"; + encryptor.encrypt(plain1, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain2); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ( + plain.to_string(), + "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - int data_bound = (1 << 10); - srand(static_cast(time(NULL))); + plain2 = "5"; + evaluator.multiply_plain_inplace(encrypted, plain2); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ( + plain.to_string(), + "5x^28 + 5x^25 + 5x^21 + 5x^20 + 5x^18 + 5x^14 + 5x^12 + 5x^10 + 5x^9 + 5x^6 + 5x^5 + 5x^4 + 5x^3"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + } + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(PlainModulus::Batching(64, 20)); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 30, 30 })); - for (int round = 0; round < 100; round++) - { - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i]; - } - encoder.encode(input1, context.first_parms_id(), delta, plain1); - encoder.encode(input2, context.first_parms_id(), delta, plain2); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); + BatchEncoder batch_encoder(context); + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); - // Check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + Ciphertext encrypted; + Plaintext plain; + vector result; - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - } + batch_encoder.encode(vector(batch_encoder.slot_count(), 7), plain); + encryptor.encrypt(plain, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + batch_encoder.decode(plain, result); + ASSERT_TRUE(vector(batch_encoder.slot_count(), 49) == result); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - TEST(EvaluatorTest, CKKSEncryptMultiplyByNumberDecrypt) - { - EncryptionParameters parms(scheme_type::ckks); + batch_encoder.encode(vector(batch_encoder.slot_count(), -7), plain); + encryptor.encrypt(plain, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + batch_encoder.decode(plain, result); + ASSERT_TRUE(vector(batch_encoder.slot_count(), 49) == result); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + } { - // Multiplying two random vectors by an integer - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 40 })); + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(PlainModulus::Batching(64, 40)); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 30, 30, 30, 30 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - CKKSEncoder encoder(context); + BatchEncoder batch_encoder(context); Encryptor encryptor(context, pk); - Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); - Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - vector> input1(slot_size, 0.0); - int64_t input2; - vector> expected(slot_size, 0.0); - - int data_bound = (1 << 10); - srand(static_cast(time(NULL))); - - for (int iExp = 0; iExp < 50; iExp++) - { - input2 = max(rand() % data_bound, 1); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * static_cast(input2); - } + Ciphertext encrypted; + Plaintext plain; + vector result; - vector> output(slot_size); - const double delta = static_cast(1ULL << 40); - encoder.encode(input1, context.first_parms_id(), delta, plain1); - encoder.encode(input2, context.first_parms_id(), plain2); + // First test with constant plaintext + batch_encoder.encode(vector(batch_encoder.slot_count(), 7), plain); + encryptor.encrypt(plain, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + batch_encoder.decode(plain, result); + ASSERT_TRUE(vector(batch_encoder.slot_count(), 49) == result); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - encryptor.encrypt(plain1, encrypted1); - evaluator.multiply_plain_inplace(encrypted1, plain2); + batch_encoder.encode(vector(batch_encoder.slot_count(), -7), plain); + encryptor.encrypt(plain, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + batch_encoder.decode(plain, result); + ASSERT_TRUE(vector(batch_encoder.slot_count(), 49) == result); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - // Check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + // Now test a non-constant plaintext + vector input(batch_encoder.slot_count() - 1, 7); + input.push_back(1); + vector true_result(batch_encoder.slot_count() - 1, 49); + true_result.push_back(1); + batch_encoder.encode(input, plain); + encryptor.encrypt(plain, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + batch_encoder.decode(plain, result); + ASSERT_TRUE(true_result == result); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } + input = vector(batch_encoder.slot_count() - 1, -7); + input.push_back(1); + batch_encoder.encode(input, plain); + encryptor.encrypt(plain, encrypted); + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + batch_encoder.decode(plain, result); + ASSERT_TRUE(true_result == result); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); } + } + + TEST(EvaluatorTest, BGVEncryptMultiplyDecrypt) + { { - // Multiplying two random vectors by an integer - size_t slot_size = 8; + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60 })); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - CKKSEncoder encoder(context); Encryptor encryptor(context, pk); - Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - vector> input1(slot_size, 0.0); - int64_t input2; - vector> expected(slot_size, 0.0); + Ciphertext encrypted2; + Plaintext plain, plain1, plain2; - int data_bound = (1 << 10); - srand(static_cast(time(NULL))); + plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"; + plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ( + plain.to_string(), "1x^46 + 1x^44 + 1x^43 + 1x^42 + 1x^41 + 2x^39 + 1x^38 + 2x^37 + 3x^36 + 1x^35 + " + "3x^34 + 2x^33 + 2x^32 + 4x^30 + 2x^29 + 5x^28 + 2x^27 + 4x^26 + 3x^25 + 2x^24 + " + "4x^23 + 3x^22 + 4x^21 + 4x^20 + 4x^19 + 4x^18 + 3x^17 + 2x^15 + 4x^14 + 2x^13 + " + "3x^12 + 2x^11 + 2x^10 + 2x^9 + 1x^8 + 1x^6 + 1x^5 + 1x^4 + 1x^3"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - for (int iExp = 0; iExp < 50; iExp++) - { - input2 = max(rand() % data_bound, 1); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * static_cast(input2); - } + plain1 = "0"; + plain2 = "0"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - vector> output(slot_size); - const double delta = static_cast(1ULL << 40); - encoder.encode(input1, context.first_parms_id(), delta, plain1); - encoder.encode(input2, context.first_parms_id(), plain2); + plain1 = "0"; + plain2 = "1x^2 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - encryptor.encrypt(plain1, encrypted1); - evaluator.multiply_plain_inplace(encrypted1, plain2); + plain1 = "1x^2 + 1x^1 + 1"; + plain2 = "1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 1x^1 + 1"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - // Check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + plain1 = "1x^2 + 1"; + plain2 = "3Fx^1 + 3F"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "3Fx^3 + 3Fx^2 + 3Fx^1 + 3F"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } + plain1 = "1x^16"; + plain2 = "1x^8"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "1x^24"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); } { - // Multiplying two random vectors by a double - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus((1ULL << 60) - 1); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60, 60 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - CKKSEncoder encoder(context); Encryptor encryptor(context, pk); - Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - vector> input1(slot_size, 0.0); - double input2; - vector> expected(slot_size, 0.0); - vector> output(slot_size); + Ciphertext encrypted2; + Plaintext plain, plain1, plain2; - int data_bound = (1 << 10); - srand(static_cast(time(NULL))); + plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"; + plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ( + plain.to_string(), "1x^46 + 1x^44 + 1x^43 + 1x^42 + 1x^41 + 2x^39 + 1x^38 + 2x^37 + 3x^36 + 1x^35 + " + "3x^34 + 2x^33 + 2x^32 + 4x^30 + 2x^29 + 5x^28 + 2x^27 + 4x^26 + 3x^25 + 2x^24 + " + "4x^23 + 3x^22 + 4x^21 + 4x^20 + 4x^19 + 4x^18 + 3x^17 + 2x^15 + 4x^14 + 2x^13 + " + "3x^12 + 2x^11 + 2x^10 + 2x^9 + 1x^8 + 1x^6 + 1x^5 + 1x^4 + 1x^3"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - for (int iExp = 0; iExp < 50; iExp++) - { - input2 = static_cast(rand() % (data_bound * data_bound)) / static_cast(data_bound); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2; - } + plain1 = "0"; + plain2 = "0"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - const double delta = static_cast(1ULL << 40); - encoder.encode(input1, context.first_parms_id(), delta, plain1); - encoder.encode(input2, context.first_parms_id(), delta, plain2); + plain1 = "0"; + plain2 = "1x^2 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - encryptor.encrypt(plain1, encrypted1); - evaluator.multiply_plain_inplace(encrypted1, plain2); + plain1 = "1x^2 + 1x^1 + 1"; + plain2 = "1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 1x^1 + 1"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - // Check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + plain1 = "1x^2 + 1"; + plain2 = "FFFFFFFFFFFFFFEx^1 + FFFFFFFFFFFFFFE"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ( + plain.to_string(), "FFFFFFFFFFFFFFEx^3 + FFFFFFFFFFFFFFEx^2 + FFFFFFFFFFFFFFEx^1 + FFFFFFFFFFFFFFE"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } + plain1 = "1x^16"; + plain2 = "1x^8"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "1x^24"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); } { - // Multiplying two random vectors by a double - size_t slot_size = 16; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 })); + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - CKKSEncoder encoder(context); Encryptor encryptor(context, pk); - Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - vector> input1(slot_size, 2.1); - double input2; - vector> expected(slot_size, 2.1); - vector> output(slot_size); + Ciphertext encrypted2; + Plaintext plain, plain1, plain2; - int data_bound = (1 << 10); - srand(static_cast(time(NULL))); + plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3"; + plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ( + plain.to_string(), "1x^46 + 1x^44 + 1x^43 + 1x^42 + 1x^41 + 2x^39 + 1x^38 + 2x^37 + 3x^36 + 1x^35 + " + "3x^34 + 2x^33 + 2x^32 + 4x^30 + 2x^29 + 5x^28 + 2x^27 + 4x^26 + 3x^25 + 2x^24 + " + "4x^23 + 3x^22 + 4x^21 + 4x^20 + 4x^19 + 4x^18 + 3x^17 + 2x^15 + 4x^14 + 2x^13 + " + "3x^12 + 2x^11 + 2x^10 + 2x^9 + 1x^8 + 1x^6 + 1x^5 + 1x^4 + 1x^3"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - for (int iExp = 0; iExp < 50; iExp++) - { - input2 = static_cast(rand() % (data_bound * data_bound)) / static_cast(data_bound); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2; - } + plain1 = "0"; + plain2 = "0"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - const double delta = static_cast(1ULL << 40); - encoder.encode(input1, context.first_parms_id(), delta, plain1); - encoder.encode(input2, context.first_parms_id(), delta, plain2); + plain1 = "0"; + plain2 = "1x^2 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - encryptor.encrypt(plain1, encrypted1); - evaluator.multiply_plain_inplace(encrypted1, plain2); + plain1 = "1x^2 + 1x^1 + 1"; + plain2 = "1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 1x^1 + 1"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - // Check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + plain1 = "1x^2 + 1"; + plain2 = "3Fx^1 + 3F"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "3Fx^3 + 3Fx^2 + 3Fx^1 + 3F"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } + plain1 = "1x^16"; + plain2 = "1x^8"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(plain.to_string(), "1x^24"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + } + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 8); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1; + Plaintext plain, plain1; + + plain1 = "1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + evaluator.multiply(encrypted1, encrypted1, encrypted1); + evaluator.multiply(encrypted1, encrypted1, encrypted1); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ( + plain.to_string(), "1x^24 + 4x^23 + Ax^22 + 14x^21 + 1Fx^20 + 2Cx^19 + 3Cx^18 + 4Cx^17 + 5Fx^16 + " + "6Cx^15 + 70x^14 + 74x^13 + 71x^12 + 6Cx^11 + 64x^10 + 50x^9 + 40x^8 + 34x^7 + " + "26x^6 + 1Cx^5 + 11x^4 + 8x^3 + 6x^2 + 4x^1 + 1"); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); } } - TEST(EvaluatorTest, CKKSEncryptMultiplyRelinDecrypt) +#include "seal/randomgen.h" + TEST(EvaluatorTest, BFVRelinearize) + { + EncryptionParameters parms(scheme_type::bfv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40, 40 })); + + SEALContext context(parms, true, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted(context); + Ciphertext encrypted2(context); + + Plaintext plain; + Plaintext plain2; + + plain = 0; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain == plain2); + + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain == plain2); + + plain = "1x^10 + 2"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain2.to_string() == "1x^20 + 4x^10 + 4"); + + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain2.to_string() == "1x^40 + 8x^30 + 18x^20 + 20x^10 + 10"); + + // Relinearization with modulus switching + plain = "1x^10 + 2"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.mod_switch_to_next_inplace(encrypted); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain2.to_string() == "1x^20 + 4x^10 + 4"); + + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.mod_switch_to_next_inplace(encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.mod_switch_to_next_inplace(encrypted); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain2.to_string() == "1x^40 + 8x^30 + 18x^20 + 20x^10 + 10"); + } + + TEST(EvaluatorTest, BGVRelinearize) + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40, 40 })); + + SEALContext context(parms, true, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted(context); + Ciphertext encrypted2(context); + + Plaintext plain; + Plaintext plain2; + + plain = 0; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain == plain2); + + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain == plain2); + + plain = "1x^10 + 2"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain2.to_string() == "1x^20 + 4x^10 + 4"); + + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain2.to_string() == "1x^40 + 8x^30 + 18x^20 + 20x^10 + 10"); + + // Relinearization with modulus switching + plain = "1x^10 + 2"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.mod_switch_to_next_inplace(encrypted); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain2.to_string() == "1x^20 + 4x^10 + 4"); + + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.mod_switch_to_next_inplace(encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.mod_switch_to_next_inplace(encrypted); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain2.to_string() == "1x^40 + 8x^30 + 18x^20 + 20x^10 + 10"); + } + + TEST(EvaluatorTest, CKKSEncryptNaiveMultiplyDecrypt) { EncryptionParameters parms(scheme_type::ckks); { - // Multiplying two random vectors 50 times + // Multiplying two zero vectors + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, pk); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + vector> input(slot_size, 0.0); + vector> output(slot_size); + const double delta = static_cast(1 << 30); + encoder.encode(input, context.first_parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + evaluator.multiply_inplace(encrypted, encrypted); + + // Check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + // Multiplying two random vectors size_t slot_size = 32; parms.set_poly_modulus_degree(slot_size * 2); parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); @@ -1945,8 +2569,6 @@ namespace sealtest KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - RelinKeys rlk; - keygen.create_relin_keys(rlk); CKKSEncoder encoder(context); Encryptor encryptor(context, pk); @@ -1955,7 +2577,6 @@ namespace sealtest Ciphertext encrypted1; Ciphertext encrypted2; - Ciphertext encryptedRes; Plaintext plain1; Plaintext plain2; Plaintext plainRes; @@ -1963,33 +2584,29 @@ namespace sealtest vector> input1(slot_size, 0.0); vector> input2(slot_size, 0.0); vector> expected(slot_size, 0.0); - int data_bound = 1 << 10; + vector> output(slot_size); + const double delta = static_cast(1ULL << 40); - for (int round = 0; round < 50; round++) + int data_bound = (1 << 10); + srand(static_cast(time(NULL))); + + for (int round = 0; round < 100; round++) { - srand(static_cast(time(NULL))); for (size_t i = 0; i < slot_size; i++) { input1[i] = static_cast(rand() % data_bound); input2[i] = static_cast(rand() % data_bound); expected[i] = input1[i] * input2[i]; } - - vector> output(slot_size); - const double delta = static_cast(1ULL << 40); encoder.encode(input1, context.first_parms_id(), delta, plain1); encoder.encode(input2, context.first_parms_id(), delta, plain2); encryptor.encrypt(plain1, encrypted1); encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); // Check correctness of encryption ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - // Check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id()); - - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); @@ -2001,17 +2618,15 @@ namespace sealtest } } { - // Multiplying two random vectors 50 times - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 30, 30, 30 })); + // Multiplying two random vectors + size_t slot_size = 16; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - RelinKeys rlk; - keygen.create_relin_keys(rlk); CKKSEncoder encoder(context); Encryptor encryptor(context, pk); @@ -2020,7 +2635,6 @@ namespace sealtest Ciphertext encrypted1; Ciphertext encrypted2; - Ciphertext encryptedRes; Plaintext plain1; Plaintext plain2; Plaintext plainRes; @@ -2028,33 +2642,29 @@ namespace sealtest vector> input1(slot_size, 0.0); vector> input2(slot_size, 0.0); vector> expected(slot_size, 0.0); - int data_bound = 1 << 10; + vector> output(slot_size); + const double delta = static_cast(1ULL << 40); - for (int round = 0; round < 50; round++) + int data_bound = (1 << 10); + srand(static_cast(time(NULL))); + + for (int round = 0; round < 100; round++) { - srand(static_cast(time(NULL))); for (size_t i = 0; i < slot_size; i++) { input1[i] = static_cast(rand() % data_bound); input2[i] = static_cast(rand() % data_bound); expected[i] = input1[i] * input2[i]; } - - vector> output(slot_size); - const double delta = static_cast(1ULL << 40); encoder.encode(input1, context.first_parms_id(), delta, plain1); encoder.encode(input2, context.first_parms_id(), delta, plain2); encryptor.encrypt(plain1, encrypted1); encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); // Check correctness of encryption ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - // Check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id()); - - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); @@ -2065,18 +2675,21 @@ namespace sealtest } } } + } + + TEST(EvaluatorTest, CKKSEncryptMultiplyByNumberDecrypt) + { + EncryptionParameters parms(scheme_type::ckks); { - // Multiplying two random vectors 50 times - size_t slot_size = 2; - parms.set_poly_modulus_degree(8); - parms.set_coeff_modulus(CoeffModulus::Create(8, { 60, 30, 30, 30 })); + // Multiplying two random vectors by an integer + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 40 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - RelinKeys rlk; - keygen.create_relin_keys(rlk); CKKSEncoder encoder(context); Encryptor encryptor(context, pk); @@ -2084,42 +2697,36 @@ namespace sealtest Evaluator evaluator(context); Ciphertext encrypted1; - Ciphertext encrypted2; - Ciphertext encryptedRes; Plaintext plain1; Plaintext plain2; Plaintext plainRes; vector> input1(slot_size, 0.0); - vector> input2(slot_size, 0.0); + int64_t input2; vector> expected(slot_size, 0.0); - vector> output(slot_size); - int data_bound = 1 << 10; - const double delta = static_cast(1ULL << 40); - for (int round = 0; round < 50; round++) + int data_bound = (1 << 10); + srand(static_cast(time(NULL))); + + for (int iExp = 0; iExp < 50; iExp++) { - srand(static_cast(time(NULL))); + input2 = max(rand() % data_bound, 1); for (size_t i = 0; i < slot_size; i++) { input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i]; + expected[i] = input1[i] * static_cast(input2); } + vector> output(slot_size); + const double delta = static_cast(1ULL << 40); encoder.encode(input1, context.first_parms_id(), delta, plain1); - encoder.encode(input2, context.first_parms_id(), delta, plain2); + encoder.encode(input2, context.first_parms_id(), plain2); encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_plain_inplace(encrypted1, plain2); // Check correctness of encryption ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - // Check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id()); - - evaluator.multiply_inplace(encrypted1, encrypted2); - // Evaluator.relinearize_inplace(encrypted1, rlk); decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); @@ -2130,61 +2737,55 @@ namespace sealtest } } } - } - - TEST(EvaluatorTest, CKKSEncryptSquareRelinDecrypt) - { - EncryptionParameters parms(scheme_type::ckks); { - // Squaring two random vectors 100 times - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); + // Multiplying two random vectors by an integer + size_t slot_size = 8; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - RelinKeys rlk; - keygen.create_relin_keys(rlk); CKKSEncoder encoder(context); Encryptor encryptor(context, pk); Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); - Ciphertext encrypted; - Plaintext plain; + Ciphertext encrypted1; + Plaintext plain1; + Plaintext plain2; Plaintext plainRes; - vector> input(slot_size, 0.0); + vector> input1(slot_size, 0.0); + int64_t input2; vector> expected(slot_size, 0.0); - int data_bound = 1 << 7; + int data_bound = (1 << 10); srand(static_cast(time(NULL))); - for (int round = 0; round < 100; round++) + for (int iExp = 0; iExp < 50; iExp++) { + input2 = max(rand() % data_bound, 1); for (size_t i = 0; i < slot_size; i++) { - input[i] = static_cast(rand() % data_bound); - expected[i] = input[i] * input[i]; + input1[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * static_cast(input2); } vector> output(slot_size); const double delta = static_cast(1ULL << 40); - encoder.encode(input, context.first_parms_id(), delta, plain); + encoder.encode(input1, context.first_parms_id(), delta, plain1); + encoder.encode(input2, context.first_parms_id(), plain2); - encryptor.encrypt(plain, encrypted); + encryptor.encrypt(plain1, encrypted1); + evaluator.multiply_plain_inplace(encrypted1, plain2); // Check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - - // Evaluator.square_inplace(encrypted); - evaluator.multiply_inplace(encrypted, encrypted); - evaluator.relinearize_inplace(encrypted, rlk); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - decryptor.decrypt(encrypted, plainRes); + decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); for (size_t i = 0; i < slot_size; i++) { @@ -2194,55 +2795,54 @@ namespace sealtest } } { - // Squaring two random vectors 100 times + // Multiplying two random vectors by a double size_t slot_size = 32; parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 30, 30, 30 })); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - RelinKeys rlk; - keygen.create_relin_keys(rlk); CKKSEncoder encoder(context); Encryptor encryptor(context, pk); Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); - Ciphertext encrypted; - Plaintext plain; + Ciphertext encrypted1; + Plaintext plain1; + Plaintext plain2; Plaintext plainRes; - vector> input(slot_size, 0.0); + vector> input1(slot_size, 0.0); + double input2; vector> expected(slot_size, 0.0); + vector> output(slot_size); - int data_bound = 1 << 7; + int data_bound = (1 << 10); srand(static_cast(time(NULL))); - for (int round = 0; round < 100; round++) + for (int iExp = 0; iExp < 50; iExp++) { + input2 = static_cast(rand() % (data_bound * data_bound)) / static_cast(data_bound); for (size_t i = 0; i < slot_size; i++) { - input[i] = static_cast(rand() % data_bound); - expected[i] = input[i] * input[i]; + input1[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2; } - vector> output(slot_size); const double delta = static_cast(1ULL << 40); - encoder.encode(input, context.first_parms_id(), delta, plain); + encoder.encode(input1, context.first_parms_id(), delta, plain1); + encoder.encode(input2, context.first_parms_id(), delta, plain2); - encryptor.encrypt(plain, encrypted); + encryptor.encrypt(plain1, encrypted1); + evaluator.multiply_plain_inplace(encrypted1, plain2); // Check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - - // Evaluator.square_inplace(encrypted); - evaluator.multiply_inplace(encrypted, encrypted); - evaluator.relinearize_inplace(encrypted, rlk); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - decryptor.decrypt(encrypted, plainRes); + decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); for (size_t i = 0; i < slot_size; i++) { @@ -2252,55 +2852,54 @@ namespace sealtest } } { - // Squaring two random vectors 100 times + // Multiplying two random vectors by a double size_t slot_size = 16; parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 30, 30, 30 })); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - RelinKeys rlk; - keygen.create_relin_keys(rlk); CKKSEncoder encoder(context); Encryptor encryptor(context, pk); Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); - Ciphertext encrypted; - Plaintext plain; + Ciphertext encrypted1; + Plaintext plain1; + Plaintext plain2; Plaintext plainRes; - vector> input(slot_size, 0.0); - vector> expected(slot_size, 0.0); + vector> input1(slot_size, 2.1); + double input2; + vector> expected(slot_size, 2.1); + vector> output(slot_size); - int data_bound = 1 << 7; + int data_bound = (1 << 10); srand(static_cast(time(NULL))); - for (int round = 0; round < 100; round++) + for (int iExp = 0; iExp < 50; iExp++) { + input2 = static_cast(rand() % (data_bound * data_bound)) / static_cast(data_bound); for (size_t i = 0; i < slot_size; i++) { - input[i] = static_cast(rand() % data_bound); - expected[i] = input[i] * input[i]; + input1[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2; } - vector> output(slot_size); const double delta = static_cast(1ULL << 40); - encoder.encode(input, context.first_parms_id(), delta, plain); + encoder.encode(input1, context.first_parms_id(), delta, plain1); + encoder.encode(input2, context.first_parms_id(), delta, plain2); - encryptor.encrypt(plain, encrypted); + encryptor.encrypt(plain1, encrypted1); + evaluator.multiply_plain_inplace(encrypted1, plain2); // Check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - - // Evaluator.square_inplace(encrypted); - evaluator.multiply_inplace(encrypted, encrypted); - evaluator.relinearize_inplace(encrypted, rlk); + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - decryptor.decrypt(encrypted, plainRes); + decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); for (size_t i = 0; i < slot_size; i++) { @@ -2311,17 +2910,16 @@ namespace sealtest } } - TEST(EvaluatorTest, CKKSEncryptMultiplyRelinRescaleDecrypt) + TEST(EvaluatorTest, CKKSEncryptMultiplyRelinDecrypt) { EncryptionParameters parms(scheme_type::ckks); { - // Multiplying two random vectors 100 times - size_t slot_size = 64; + // Multiplying two random vectors 50 times + size_t slot_size = 32; parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30, 30, 30 })); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); - SEALContext context(parms, true, sec_level_type::none); - auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); + SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); @@ -2343,10 +2941,10 @@ namespace sealtest vector> input1(slot_size, 0.0); vector> input2(slot_size, 0.0); vector> expected(slot_size, 0.0); + int data_bound = 1 << 10; - for (int round = 0; round < 100; round++) + for (int round = 0; round < 50; round++) { - int data_bound = 1 << 7; srand(static_cast(time(NULL))); for (size_t i = 0; i < slot_size; i++) { @@ -2356,7 +2954,7 @@ namespace sealtest } vector> output(slot_size); - double delta = static_cast(1ULL << 40); + const double delta = static_cast(1ULL << 40); encoder.encode(input1, context.first_parms_id(), delta, plain1); encoder.encode(input2, context.first_parms_id(), delta, plain2); @@ -2370,10 +2968,6 @@ namespace sealtest evaluator.multiply_inplace(encrypted1, encrypted2); evaluator.relinearize_inplace(encrypted1, rlk); - evaluator.rescale_to_next_inplace(encrypted1); - - // Check correctness of modulus switching - ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); @@ -2385,13 +2979,12 @@ namespace sealtest } } { - // Multiplying two random vectors 100 times - size_t slot_size = 16; - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30, 30 })); + // Multiplying two random vectors 50 times + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 30, 30, 30 })); - SEALContext context(parms, true, sec_level_type::none); - auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); + SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); @@ -2413,10 +3006,10 @@ namespace sealtest vector> input1(slot_size, 0.0); vector> input2(slot_size, 0.0); vector> expected(slot_size, 0.0); + int data_bound = 1 << 10; - for (int round = 0; round < 100; round++) + for (int round = 0; round < 50; round++) { - int data_bound = 1 << 7; srand(static_cast(time(NULL))); for (size_t i = 0; i < slot_size; i++) { @@ -2426,7 +3019,7 @@ namespace sealtest } vector> output(slot_size); - double delta = static_cast(1ULL << 40); + const double delta = static_cast(1ULL << 40); encoder.encode(input1, context.first_parms_id(), delta, plain1); encoder.encode(input2, context.first_parms_id(), delta, plain2); @@ -2440,10 +3033,6 @@ namespace sealtest evaluator.multiply_inplace(encrypted1, encrypted2); evaluator.relinearize_inplace(encrypted1, rlk); - evaluator.rescale_to_next_inplace(encrypted1); - - // Check correctness of modulus switching - ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); @@ -2455,12 +3044,12 @@ namespace sealtest } } { - // Multiplying two random vectors 100 times - size_t slot_size = 16; - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 60, 60, 60, 60, 60 })); + // Multiplying two random vectors 50 times + size_t slot_size = 2; + parms.set_poly_modulus_degree(8); + parms.set_coeff_modulus(CoeffModulus::Create(8, { 60, 30, 30, 30 })); - SEALContext context(parms, true, sec_level_type::none); + SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); @@ -2482,66 +3071,20 @@ namespace sealtest vector> input1(slot_size, 0.0); vector> input2(slot_size, 0.0); vector> expected(slot_size, 0.0); + vector> output(slot_size); + int data_bound = 1 << 10; + const double delta = static_cast(1ULL << 40); - for (int round = 0; round < 100; round++) + for (int round = 0; round < 50; round++) { - int data_bound = 1 << 7; - srand(static_cast(time(NULL))); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i] * input2[i]; - } - - vector> output(slot_size); - double delta = static_cast(1ULL << 60); - encoder.encode(input1, context.first_parms_id(), delta, plain1); - encoder.encode(input2, context.first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - - // Check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); - // Check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id()); - - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); - - // Scale down by two levels - auto target_parms = context.first_context_data()->next_context_data()->next_context_data()->parms_id(); - evaluator.rescale_to_inplace(encrypted1, target_parms); - - // Check correctness of modulus switching - ASSERT_TRUE(encrypted1.parms_id() == target_parms); - - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - - // Test with inverted order: rescale then relin - for (int round = 0; round < 100; round++) - { - int data_bound = 1 << 7; srand(static_cast(time(NULL))); for (size_t i = 0; i < slot_size; i++) { input1[i] = static_cast(rand() % data_bound); input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i] * input2[i]; + expected[i] = input1[i] * input2[i]; } - vector> output(slot_size); - double delta = static_cast(1ULL << 50); encoder.encode(input1, context.first_parms_id(), delta, plain1); encoder.encode(input2, context.first_parms_id(), delta, plain2); @@ -2554,18 +3097,7 @@ namespace sealtest ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id()); evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); - evaluator.multiply_inplace(encrypted1, encrypted2); - - // Scale down by two levels - auto target_parms = context.first_context_data()->next_context_data()->next_context_data()->parms_id(); - evaluator.rescale_to_inplace(encrypted1, target_parms); - - // Relinearize now - evaluator.relinearize_inplace(encrypted1, rlk); - - // Check correctness of modulus switching - ASSERT_TRUE(encrypted1.parms_id() == target_parms); + // Evaluator.relinearize_inplace(encrypted1, rlk); decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); @@ -2578,17 +3110,16 @@ namespace sealtest } } - TEST(EvaluatorTest, CKKSEncryptSquareRelinRescaleDecrypt) + TEST(EvaluatorTest, CKKSEncryptSquareRelinDecrypt) { EncryptionParameters parms(scheme_type::ckks); { // Squaring two random vectors 100 times - size_t slot_size = 64; + size_t slot_size = 32; parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 50, 50, 50 })); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); - SEALContext context(parms, true, sec_level_type::none); - auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); + SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); @@ -2605,20 +3136,21 @@ namespace sealtest Plaintext plainRes; vector> input(slot_size, 0.0); - vector> output(slot_size); vector> expected(slot_size, 0.0); - int data_bound = 1 << 8; + + int data_bound = 1 << 7; + srand(static_cast(time(NULL))); for (int round = 0; round < 100; round++) { - srand(static_cast(time(NULL))); for (size_t i = 0; i < slot_size; i++) { input[i] = static_cast(rand() % data_bound); expected[i] = input[i] * input[i]; } - double delta = static_cast(1ULL << 40); + vector> output(slot_size); + const double delta = static_cast(1ULL << 40); encoder.encode(input, context.first_parms_id(), delta, plain); encryptor.encrypt(plain, encrypted); @@ -2626,12 +3158,9 @@ namespace sealtest // Check correctness of encryption ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - evaluator.square_inplace(encrypted); + // Evaluator.square_inplace(encrypted); + evaluator.multiply_inplace(encrypted, encrypted); evaluator.relinearize_inplace(encrypted, rlk); - evaluator.rescale_to_next_inplace(encrypted); - - // Check correctness of modulus switching - ASSERT_TRUE(encrypted.parms_id() == next_parms_id); decryptor.decrypt(encrypted, plainRes); encoder.decode(plainRes, output); @@ -2644,12 +3173,11 @@ namespace sealtest } { // Squaring two random vectors 100 times - size_t slot_size = 16; - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 50, 50, 50 })); + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 30, 30, 30 })); - SEALContext context(parms, true, sec_level_type::none); - auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); + SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); @@ -2666,20 +3194,21 @@ namespace sealtest Plaintext plainRes; vector> input(slot_size, 0.0); - vector> output(slot_size); vector> expected(slot_size, 0.0); - int data_bound = 1 << 8; + + int data_bound = 1 << 7; + srand(static_cast(time(NULL))); for (int round = 0; round < 100; round++) { - srand(static_cast(time(NULL))); for (size_t i = 0; i < slot_size; i++) { input[i] = static_cast(rand() % data_bound); expected[i] = input[i] * input[i]; } - double delta = static_cast(1ULL << 40); + vector> output(slot_size); + const double delta = static_cast(1ULL << 40); encoder.encode(input, context.first_parms_id(), delta, plain); encryptor.encrypt(plain, encrypted); @@ -2687,12 +3216,9 @@ namespace sealtest // Check correctness of encryption ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - evaluator.square_inplace(encrypted); + // Evaluator.square_inplace(encrypted); + evaluator.multiply_inplace(encrypted, encrypted); evaluator.relinearize_inplace(encrypted, rlk); - evaluator.rescale_to_next_inplace(encrypted); - - // Check correctness of modulus switching - ASSERT_TRUE(encrypted.parms_id() == next_parms_id); decryptor.decrypt(encrypted, plainRes); encoder.decode(plainRes, output); @@ -2703,45 +3229,44 @@ namespace sealtest } } } - } - TEST(EvaluatorTest, CKKSEncryptModSwitchDecrypt) - { - EncryptionParameters parms(scheme_type::ckks); { - // Modulus switching without rescaling for random vectors - size_t slot_size = 64; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60, 60, 60 })); + // Squaring two random vectors 100 times + size_t slot_size = 16; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 30, 30, 30 })); - SEALContext context(parms, true, sec_level_type::none); - auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); + SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); CKKSEncoder encoder(context); Encryptor encryptor(context, pk); Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); - int data_bound = 1 << 30; - srand(static_cast(time(NULL))); - - vector> input(slot_size, 0.0); - vector> output(slot_size); - Ciphertext encrypted; Plaintext plain; Plaintext plainRes; + vector> input(slot_size, 0.0); + vector> expected(slot_size, 0.0); + + int data_bound = 1 << 7; + srand(static_cast(time(NULL))); + for (int round = 0; round < 100; round++) { for (size_t i = 0; i < slot_size; i++) { input[i] = static_cast(rand() % data_bound); + expected[i] = input[i] * input[i]; } - double delta = static_cast(1ULL << 40); + vector> output(slot_size); + const double delta = static_cast(1ULL << 40); encoder.encode(input, context.first_parms_id(), delta, plain); encryptor.encrypt(plain, encrypted); @@ -2749,195 +3274,171 @@ namespace sealtest // Check correctness of encryption ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - // Not inplace - Ciphertext destination; - evaluator.mod_switch_to_next(encrypted, destination); - - // Check correctness of modulus switching - ASSERT_TRUE(destination.parms_id() == next_parms_id); - - decryptor.decrypt(destination, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - - // Inplace - evaluator.mod_switch_to_next_inplace(encrypted); - - // Check correctness of modulus switching - ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + // Evaluator.square_inplace(encrypted); + evaluator.multiply_inplace(encrypted, encrypted); + evaluator.relinearize_inplace(encrypted, rlk); decryptor.decrypt(encrypted, plainRes); encoder.decode(plainRes, output); for (size_t i = 0; i < slot_size; i++) { - auto tmp = abs(input[i].real() - output[i].real()); + auto tmp = abs(expected[i].real() - output[i].real()); ASSERT_TRUE(tmp < 0.5); } } } + } + + TEST(EvaluatorTest, CKKSEncryptMultiplyRelinRescaleDecrypt) + { + EncryptionParameters parms(scheme_type::ckks); { - // Modulus switching without rescaling for random vectors - size_t slot_size = 32; + // Multiplying two random vectors 100 times + size_t slot_size = 64; parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 40, 40, 40, 40, 40 })); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30, 30, 30 })); SEALContext context(parms, true, sec_level_type::none); auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); CKKSEncoder encoder(context); Encryptor encryptor(context, pk); Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); - int data_bound = 1 << 30; - srand(static_cast(time(NULL))); - - vector> input(slot_size, 0.0); - vector> output(slot_size); - - Ciphertext encrypted; - Plaintext plain; + Ciphertext encrypted1; + Ciphertext encrypted2; + Ciphertext encryptedRes; + Plaintext plain1; + Plaintext plain2; Plaintext plainRes; + vector> input1(slot_size, 0.0); + vector> input2(slot_size, 0.0); + vector> expected(slot_size, 0.0); + for (int round = 0; round < 100; round++) { + int data_bound = 1 << 7; + srand(static_cast(time(NULL))); for (size_t i = 0; i < slot_size; i++) { - input[i] = static_cast(rand() % data_bound); + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i]; } + vector> output(slot_size); double delta = static_cast(1ULL << 40); - encoder.encode(input, context.first_parms_id(), delta, plain); + encoder.encode(input1, context.first_parms_id(), delta, plain1); + encoder.encode(input2, context.first_parms_id(), delta, plain2); - encryptor.encrypt(plain, encrypted); + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); // Check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - - // Not inplace - Ciphertext destination; - evaluator.mod_switch_to_next(encrypted, destination); - - // Check correctness of modulus switching - ASSERT_TRUE(destination.parms_id() == next_parms_id); - - decryptor.decrypt(destination, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + // Check correctness of encryption + ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id()); - // Inplace - evaluator.mod_switch_to_next_inplace(encrypted); + evaluator.multiply_inplace(encrypted1, encrypted2); + evaluator.relinearize_inplace(encrypted1, rlk); + evaluator.rescale_to_next_inplace(encrypted1); // Check correctness of modulus switching - ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); - decryptor.decrypt(encrypted, plainRes); + decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); for (size_t i = 0; i < slot_size; i++) { - auto tmp = abs(input[i].real() - output[i].real()); + auto tmp = abs(expected[i].real() - output[i].real()); ASSERT_TRUE(tmp < 0.5); } } } { - // Modulus switching without rescaling for random vectors - size_t slot_size = 32; + // Multiplying two random vectors 100 times + size_t slot_size = 16; parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40, 40, 40 })); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30, 30 })); SEALContext context(parms, true, sec_level_type::none); auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); CKKSEncoder encoder(context); Encryptor encryptor(context, pk); Decryptor decryptor(context, keygen.secret_key()); Evaluator evaluator(context); - int data_bound = 1 << 30; - srand(static_cast(time(NULL))); - - vector> input(slot_size, 0.0); - vector> output(slot_size); - - Ciphertext encrypted; - Plaintext plain; + Ciphertext encrypted1; + Ciphertext encrypted2; + Ciphertext encryptedRes; + Plaintext plain1; + Plaintext plain2; Plaintext plainRes; + vector> input1(slot_size, 0.0); + vector> input2(slot_size, 0.0); + vector> expected(slot_size, 0.0); + for (int round = 0; round < 100; round++) { + int data_bound = 1 << 7; + srand(static_cast(time(NULL))); for (size_t i = 0; i < slot_size; i++) { - input[i] = static_cast(rand() % data_bound); + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i]; } + vector> output(slot_size); double delta = static_cast(1ULL << 40); - encoder.encode(input, context.first_parms_id(), delta, plain); + encoder.encode(input1, context.first_parms_id(), delta, plain1); + encoder.encode(input2, context.first_parms_id(), delta, plain2); - encryptor.encrypt(plain, encrypted); + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); // Check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - - // Not inplace - Ciphertext destination; - evaluator.mod_switch_to_next(encrypted, destination); - - // Check correctness of modulus switching - ASSERT_TRUE(destination.parms_id() == next_parms_id); - - decryptor.decrypt(destination, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + // Check correctness of encryption + ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id()); - // Inplace - evaluator.mod_switch_to_next_inplace(encrypted); + evaluator.multiply_inplace(encrypted1, encrypted2); + evaluator.relinearize_inplace(encrypted1, rlk); + evaluator.rescale_to_next_inplace(encrypted1); // Check correctness of modulus switching - ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); - decryptor.decrypt(encrypted, plainRes); + decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); for (size_t i = 0; i < slot_size; i++) { - auto tmp = abs(input[i].real() - output[i].real()); + auto tmp = abs(expected[i].real() - output[i].real()); ASSERT_TRUE(tmp < 0.5); } } } - } - TEST(EvaluatorTest, CKKSEncryptMultiplyRelinRescaleModSwitchAddDecrypt) - { - EncryptionParameters parms(scheme_type::ckks); { - // Multiplication and addition without rescaling for random vectors - size_t slot_size = 64; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 50, 50, 50 })); + // Multiplying two random vectors 100 times + size_t slot_size = 16; + parms.set_poly_modulus_degree(128); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 60, 60, 60, 60, 60 })); SEALContext context(parms, true, sec_level_type::none); - auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); @@ -2951,58 +3452,50 @@ namespace sealtest Ciphertext encrypted1; Ciphertext encrypted2; - Ciphertext encrypted3; + Ciphertext encryptedRes; Plaintext plain1; Plaintext plain2; - Plaintext plain3; Plaintext plainRes; vector> input1(slot_size, 0.0); vector> input2(slot_size, 0.0); - vector> input3(slot_size, 0.0); vector> expected(slot_size, 0.0); for (int round = 0; round < 100; round++) { - int data_bound = 1 << 8; + int data_bound = 1 << 7; srand(static_cast(time(NULL))); for (size_t i = 0; i < slot_size; i++) { input1[i] = static_cast(rand() % data_bound); input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i] + input3[i]; + expected[i] = input1[i] * input2[i] * input2[i]; } vector> output(slot_size); - double delta = static_cast(1ULL << 40); + double delta = static_cast(1ULL << 60); encoder.encode(input1, context.first_parms_id(), delta, plain1); encoder.encode(input2, context.first_parms_id(), delta, plain2); - encoder.encode(input3, context.first_parms_id(), delta * delta, plain3); encryptor.encrypt(plain1, encrypted1); encryptor.encrypt(plain2, encrypted2); - encryptor.encrypt(plain3, encrypted3); // Check correctness of encryption ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); // Check correctness of encryption ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id()); - // Check correctness of encryption - ASSERT_TRUE(encrypted3.parms_id() == context.first_parms_id()); - // Enc1*enc2 evaluator.multiply_inplace(encrypted1, encrypted2); evaluator.relinearize_inplace(encrypted1, rlk); - evaluator.rescale_to_next_inplace(encrypted1); - - // Check correctness of modulus switching with rescaling - ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); + evaluator.multiply_inplace(encrypted1, encrypted2); + evaluator.relinearize_inplace(encrypted1, rlk); - // Move enc3 to the level of enc1 * enc2 - evaluator.rescale_to_inplace(encrypted3, next_parms_id); + // Scale down by two levels + auto target_parms = context.first_context_data()->next_context_data()->next_context_data()->parms_id(); + evaluator.rescale_to_inplace(encrypted1, target_parms); - // Enc1*enc2 + enc3 - evaluator.add_inplace(encrypted1, encrypted3); + // Check correctness of modulus switching + ASSERT_TRUE(encrypted1.parms_id() == target_parms); decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); @@ -3012,80 +3505,45 @@ namespace sealtest ASSERT_TRUE(tmp < 0.5); } } - } - { - // Multiplication and addition without rescaling for random vectors - size_t slot_size = 16; - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 50, 50, 50 })); - - SEALContext context(parms, true, sec_level_type::none); - auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); - KeyGenerator keygen(context); - PublicKey pk; - keygen.create_public_key(pk); - RelinKeys rlk; - keygen.create_relin_keys(rlk); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, pk); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Ciphertext encrypted3; - Plaintext plain1; - Plaintext plain2; - Plaintext plain3; - Plaintext plainRes; - - vector> input1(slot_size, 0.0); - vector> input2(slot_size, 0.0); - vector> input3(slot_size, 0.0); - vector> expected(slot_size, 0.0); - vector> output(slot_size); + // Test with inverted order: rescale then relin for (int round = 0; round < 100; round++) { - int data_bound = 1 << 8; + int data_bound = 1 << 7; srand(static_cast(time(NULL))); for (size_t i = 0; i < slot_size; i++) { input1[i] = static_cast(rand() % data_bound); input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i] + input3[i]; + expected[i] = input1[i] * input2[i] * input2[i]; } - double delta = static_cast(1ULL << 40); + vector> output(slot_size); + double delta = static_cast(1ULL << 50); encoder.encode(input1, context.first_parms_id(), delta, plain1); encoder.encode(input2, context.first_parms_id(), delta, plain2); - encoder.encode(input3, context.first_parms_id(), delta * delta, plain3); encryptor.encrypt(plain1, encrypted1); encryptor.encrypt(plain2, encrypted2); - encryptor.encrypt(plain3, encrypted3); // Check correctness of encryption ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); // Check correctness of encryption ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id()); - // Check correctness of encryption - ASSERT_TRUE(encrypted3.parms_id() == context.first_parms_id()); - // Enc1*enc2 evaluator.multiply_inplace(encrypted1, encrypted2); evaluator.relinearize_inplace(encrypted1, rlk); - evaluator.rescale_to_next_inplace(encrypted1); + evaluator.multiply_inplace(encrypted1, encrypted2); - // Check correctness of modulus switching with rescaling - ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); + // Scale down by two levels + auto target_parms = context.first_context_data()->next_context_data()->next_context_data()->parms_id(); + evaluator.rescale_to_inplace(encrypted1, target_parms); - // Move enc3 to the level of enc1 * enc2 - evaluator.rescale_to_inplace(encrypted3, next_parms_id); + // Relinearize now + evaluator.relinearize_inplace(encrypted1, rlk); - // Enc1*enc2 + enc3 - evaluator.add_inplace(encrypted1, encrypted3); + // Check correctness of modulus switching + ASSERT_TRUE(encrypted1.parms_id() == target_parms); decryptor.decrypt(encrypted1, plainRes); encoder.decode(plainRes, output); @@ -3097,329 +3555,1648 @@ namespace sealtest } } } - TEST(EvaluatorTest, CKKSEncryptRotateDecrypt) + + TEST(EvaluatorTest, CKKSEncryptSquareRelinRescaleDecrypt) { EncryptionParameters parms(scheme_type::ckks); { - // Maximal number of slots - size_t slot_size = 4; + // Squaring two random vectors 100 times + size_t slot_size = 64; parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 40, 40, 40, 40 })); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 50, 50, 50 })); - SEALContext context(parms, false, sec_level_type::none); + SEALContext context(parms, true, sec_level_type::none); + auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - GaloisKeys glk; - keygen.create_galois_keys(glk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); + CKKSEncoder encoder(context); Encryptor encryptor(context, pk); - Evaluator evaluator(context); Decryptor decryptor(context, keygen.secret_key()); - CKKSEncoder encoder(context); - const double delta = static_cast(1ULL << 30); + Evaluator evaluator(context); Ciphertext encrypted; Plaintext plain; + Plaintext plainRes; - vector> input{ complex(1, 1), complex(2, 2), complex(3, 3), - complex(4, 4) }; - input.resize(slot_size); - - vector> output(slot_size, 0); + vector> input(slot_size, 0.0); + vector> output(slot_size); + vector> expected(slot_size, 0.0); + int data_bound = 1 << 8; - encoder.encode(input, context.first_parms_id(), delta, plain); - int shift = 1; - encryptor.encrypt(plain, encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) + for (int round = 0; round < 100; round++) { - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); - } + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + expected[i] = input[i] * input[i]; + } - encoder.encode(input, context.first_parms_id(), delta, plain); - shift = 2; - encryptor.encrypt(plain, encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); - } + double delta = static_cast(1ULL << 40); + encoder.encode(input, context.first_parms_id(), delta, plain); - encoder.encode(input, context.first_parms_id(), delta, plain); - shift = 3; - encryptor.encrypt(plain, encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); - } + encryptor.encrypt(plain, encrypted); - encoder.encode(input, context.first_parms_id(), delta, plain); - encryptor.encrypt(plain, encrypted); - evaluator.complex_conjugate_inplace(encrypted, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[i].real(), round(output[i].real())); - ASSERT_EQ(-input[i].imag(), round(output[i].imag())); - } - } - { - size_t slot_size = 32; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40, 40, 40, 40 })); + // Check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.rescale_to_next_inplace(encrypted); + + // Check correctness of modulus switching + ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + // Squaring two random vectors 100 times + size_t slot_size = 16; + parms.set_poly_modulus_degree(128); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 50, 50, 50 })); + + SEALContext context(parms, true, sec_level_type::none); + auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); + KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - GaloisKeys glk; - keygen.create_galois_keys(glk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); + CKKSEncoder encoder(context); Encryptor encryptor(context, pk); - Evaluator evaluator(context); Decryptor decryptor(context, keygen.secret_key()); - CKKSEncoder encoder(context); - const double delta = static_cast(1ULL << 30); + Evaluator evaluator(context); Ciphertext encrypted; Plaintext plain; + Plaintext plainRes; - vector> input{ complex(1, 1), complex(2, 2), complex(3, 3), - complex(4, 4) }; - input.resize(slot_size); - - vector> output(slot_size, 0); + vector> input(slot_size, 0.0); + vector> output(slot_size); + vector> expected(slot_size, 0.0); + int data_bound = 1 << 8; - encoder.encode(input, context.first_parms_id(), delta, plain); - int shift = 1; - encryptor.encrypt(plain, encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < input.size(); i++) + for (int round = 0; round < 100; round++) { - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); - } + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + expected[i] = input[i] * input[i]; + } - encoder.encode(input, context.first_parms_id(), delta, plain); - shift = 2; - encryptor.encrypt(plain, encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); - } + double delta = static_cast(1ULL << 40); + encoder.encode(input, context.first_parms_id(), delta, plain); - encoder.encode(input, context.first_parms_id(), delta, plain); - shift = 3; - encryptor.encrypt(plain, encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); - } + encryptor.encrypt(plain, encrypted); - encoder.encode(input, context.first_parms_id(), delta, plain); - encryptor.encrypt(plain, encrypted); - evaluator.complex_conjugate_inplace(encrypted, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[i].real()), round(output[i].real())); - ASSERT_EQ(round(-input[i].imag()), round(output[i].imag())); + // Check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.rescale_to_next_inplace(encrypted); + + // Check correctness of modulus switching + ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } } } } - - TEST(EvaluatorTest, CKKSEncryptRescaleRotateDecrypt) + + TEST(EvaluatorTest, CKKSEncryptModSwitchDecrypt) { EncryptionParameters parms(scheme_type::ckks); { - // Maximal number of slots - size_t slot_size = 4; + // Modulus switching without rescaling for random vectors + size_t slot_size = 64; parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 40, 40, 40, 40 })); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60, 60, 60 })); SEALContext context(parms, true, sec_level_type::none); + auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - GaloisKeys glk; - keygen.create_galois_keys(glk); + CKKSEncoder encoder(context); Encryptor encryptor(context, pk); - Evaluator evaluator(context); Decryptor decryptor(context, keygen.secret_key()); - CKKSEncoder encoder(context); - const double delta = pow(2.0, 70); + Evaluator evaluator(context); + + int data_bound = 1 << 30; + srand(static_cast(time(NULL))); + + vector> input(slot_size, 0.0); + vector> output(slot_size); Ciphertext encrypted; Plaintext plain; + Plaintext plainRes; - vector> input{ complex(1, 1), complex(2, 2), complex(3, 3), - complex(4, 4) }; - input.resize(slot_size); + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + } - vector> output(slot_size, 0); + double delta = static_cast(1ULL << 40); + encoder.encode(input, context.first_parms_id(), delta, plain); - encoder.encode(input, context.first_parms_id(), delta, plain); - int shift = 1; - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); - } + encryptor.encrypt(plain, encrypted); - encoder.encode(input, context.first_parms_id(), delta, plain); - shift = 2; - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); - } + // Check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - encoder.encode(input, context.first_parms_id(), delta, plain); - shift = 3; - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); - } + // Not inplace + Ciphertext destination; + evaluator.mod_switch_to_next(encrypted, destination); - encoder.encode(input, context.first_parms_id(), delta, plain); - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.complex_conjugate_inplace(encrypted, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[i].real(), round(output[i].real())); - ASSERT_EQ(-input[i].imag(), round(output[i].imag())); + // Check correctness of modulus switching + ASSERT_TRUE(destination.parms_id() == next_parms_id); + + decryptor.decrypt(destination, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + + // Inplace + evaluator.mod_switch_to_next_inplace(encrypted); + + // Check correctness of modulus switching + ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } } } { + // Modulus switching without rescaling for random vectors size_t slot_size = 32; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40, 40, 40, 40 })); + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 40, 40, 40, 40, 40 })); + + SEALContext context(parms, true, sec_level_type::none); + auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, pk); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + int data_bound = 1 << 30; + srand(static_cast(time(NULL))); + + vector> input(slot_size, 0.0); + vector> output(slot_size); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + } + + double delta = static_cast(1ULL << 40); + encoder.encode(input, context.first_parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + + // Check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + // Not inplace + Ciphertext destination; + evaluator.mod_switch_to_next(encrypted, destination); + + // Check correctness of modulus switching + ASSERT_TRUE(destination.parms_id() == next_parms_id); + + decryptor.decrypt(destination, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + + // Inplace + evaluator.mod_switch_to_next_inplace(encrypted); + + // Check correctness of modulus switching + ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + // Modulus switching without rescaling for random vectors + size_t slot_size = 32; + parms.set_poly_modulus_degree(128); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40, 40, 40 })); + + SEALContext context(parms, true, sec_level_type::none); + auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, pk); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + int data_bound = 1 << 30; + srand(static_cast(time(NULL))); + + vector> input(slot_size, 0.0); + vector> output(slot_size); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + } + + double delta = static_cast(1ULL << 40); + encoder.encode(input, context.first_parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + + // Check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + // Not inplace + Ciphertext destination; + evaluator.mod_switch_to_next(encrypted, destination); + + // Check correctness of modulus switching + ASSERT_TRUE(destination.parms_id() == next_parms_id); + + decryptor.decrypt(destination, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + + // Inplace + evaluator.mod_switch_to_next_inplace(encrypted); + + // Check correctness of modulus switching + ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + } + + TEST(EvaluatorTest, CKKSEncryptMultiplyRelinRescaleModSwitchAddDecrypt) + { + EncryptionParameters parms(scheme_type::ckks); + { + // Multiplication and addition without rescaling for random vectors + size_t slot_size = 64; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 50, 50, 50 })); + + SEALContext context(parms, true, sec_level_type::none); + auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, pk); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Ciphertext encrypted3; + Plaintext plain1; + Plaintext plain2; + Plaintext plain3; + Plaintext plainRes; + + vector> input1(slot_size, 0.0); + vector> input2(slot_size, 0.0); + vector> input3(slot_size, 0.0); + vector> expected(slot_size, 0.0); + + for (int round = 0; round < 100; round++) + { + int data_bound = 1 << 8; + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i] + input3[i]; + } + + vector> output(slot_size); + double delta = static_cast(1ULL << 40); + encoder.encode(input1, context.first_parms_id(), delta, plain1); + encoder.encode(input2, context.first_parms_id(), delta, plain2); + encoder.encode(input3, context.first_parms_id(), delta * delta, plain3); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + + // Check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + // Check correctness of encryption + ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id()); + // Check correctness of encryption + ASSERT_TRUE(encrypted3.parms_id() == context.first_parms_id()); + + // Enc1*enc2 + evaluator.multiply_inplace(encrypted1, encrypted2); + evaluator.relinearize_inplace(encrypted1, rlk); + evaluator.rescale_to_next_inplace(encrypted1); + + // Check correctness of modulus switching with rescaling + ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); + + // Move enc3 to the level of enc1 * enc2 + evaluator.rescale_to_inplace(encrypted3, next_parms_id); + + // Enc1*enc2 + enc3 + evaluator.add_inplace(encrypted1, encrypted3); + + decryptor.decrypt(encrypted1, plainRes); + encoder.decode(plainRes, output); + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + // Multiplication and addition without rescaling for random vectors + size_t slot_size = 16; + parms.set_poly_modulus_degree(128); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 50, 50, 50 })); + + SEALContext context(parms, true, sec_level_type::none); + auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, pk); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Ciphertext encrypted3; + Plaintext plain1; + Plaintext plain2; + Plaintext plain3; + Plaintext plainRes; + + vector> input1(slot_size, 0.0); + vector> input2(slot_size, 0.0); + vector> input3(slot_size, 0.0); + vector> expected(slot_size, 0.0); + vector> output(slot_size); + + for (int round = 0; round < 100; round++) + { + int data_bound = 1 << 8; + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i] + input3[i]; + } + + double delta = static_cast(1ULL << 40); + encoder.encode(input1, context.first_parms_id(), delta, plain1); + encoder.encode(input2, context.first_parms_id(), delta, plain2); + encoder.encode(input3, context.first_parms_id(), delta * delta, plain3); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + + // Check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id()); + // Check correctness of encryption + ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id()); + // Check correctness of encryption + ASSERT_TRUE(encrypted3.parms_id() == context.first_parms_id()); + + // Enc1*enc2 + evaluator.multiply_inplace(encrypted1, encrypted2); + evaluator.relinearize_inplace(encrypted1, rlk); + evaluator.rescale_to_next_inplace(encrypted1); + + // Check correctness of modulus switching with rescaling + ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); + + // Move enc3 to the level of enc1 * enc2 + evaluator.rescale_to_inplace(encrypted3, next_parms_id); + + // Enc1*enc2 + enc3 + evaluator.add_inplace(encrypted1, encrypted3); + + decryptor.decrypt(encrypted1, plainRes); + encoder.decode(plainRes, output); + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + } + + TEST(EvaluatorTest, CKKSEncryptRotateDecrypt) + { + EncryptionParameters parms(scheme_type::ckks); + { + // Maximal number of slots + size_t slot_size = 4; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 40, 40, 40, 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + GaloisKeys glk; + keygen.create_galois_keys(glk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + CKKSEncoder encoder(context); + const double delta = static_cast(1ULL << 30); + + Ciphertext encrypted; + Plaintext plain; + + vector> input{ complex(1, 1), complex(2, 2), complex(3, 3), + complex(4, 4) }; + input.resize(slot_size); + + vector> output(slot_size, 0); + + encoder.encode(input, context.first_parms_id(), delta, plain); + int shift = 1; + encryptor.encrypt(plain, encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); + } + + encoder.encode(input, context.first_parms_id(), delta, plain); + shift = 2; + encryptor.encrypt(plain, encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); + } + + encoder.encode(input, context.first_parms_id(), delta, plain); + shift = 3; + encryptor.encrypt(plain, encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); + } + + encoder.encode(input, context.first_parms_id(), delta, plain); + encryptor.encrypt(plain, encrypted); + evaluator.complex_conjugate_inplace(encrypted, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[i].real(), round(output[i].real())); + ASSERT_EQ(-input[i].imag(), round(output[i].imag())); + } + } + { + size_t slot_size = 32; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 40, 40, 40, 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + GaloisKeys glk; + keygen.create_galois_keys(glk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + CKKSEncoder encoder(context); + const double delta = static_cast(1ULL << 30); + + Ciphertext encrypted; + Plaintext plain; + + vector> input{ complex(1, 1), complex(2, 2), complex(3, 3), + complex(4, 4) }; + input.resize(slot_size); + + vector> output(slot_size, 0); + + encoder.encode(input, context.first_parms_id(), delta, plain); + int shift = 1; + encryptor.encrypt(plain, encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < input.size(); i++) + { + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); + } + + encoder.encode(input, context.first_parms_id(), delta, plain); + shift = 2; + encryptor.encrypt(plain, encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); + } + + encoder.encode(input, context.first_parms_id(), delta, plain); + shift = 3; + encryptor.encrypt(plain, encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); + } + + encoder.encode(input, context.first_parms_id(), delta, plain); + encryptor.encrypt(plain, encrypted); + evaluator.complex_conjugate_inplace(encrypted, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[i].real()), round(output[i].real())); + ASSERT_EQ(round(-input[i].imag()), round(output[i].imag())); + } + } + } + + TEST(EvaluatorTest, CKKSEncryptRescaleRotateDecrypt) + { + EncryptionParameters parms(scheme_type::ckks); + { + // Maximal number of slots + size_t slot_size = 4; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 40, 40, 40, 40 })); + + SEALContext context(parms, true, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + GaloisKeys glk; + keygen.create_galois_keys(glk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + CKKSEncoder encoder(context); + const double delta = pow(2.0, 70); + + Ciphertext encrypted; + Plaintext plain; + + vector> input{ complex(1, 1), complex(2, 2), complex(3, 3), + complex(4, 4) }; + input.resize(slot_size); + + vector> output(slot_size, 0); + + encoder.encode(input, context.first_parms_id(), delta, plain); + int shift = 1; + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); + } + + encoder.encode(input, context.first_parms_id(), delta, plain); + shift = 2; + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); + } + + encoder.encode(input, context.first_parms_id(), delta, plain); + shift = 3; + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); + } + + encoder.encode(input, context.first_parms_id(), delta, plain); + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.complex_conjugate_inplace(encrypted, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[i].real(), round(output[i].real())); + ASSERT_EQ(-input[i].imag(), round(output[i].imag())); + } + } + { + size_t slot_size = 32; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 40, 40, 40, 40 })); + + SEALContext context(parms, true, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + GaloisKeys glk; + keygen.create_galois_keys(glk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + CKKSEncoder encoder(context); + const double delta = pow(2, 70); + + Ciphertext encrypted; + Plaintext plain; + + vector> input{ complex(1, 1), complex(2, 2), complex(3, 3), + complex(4, 4) }; + input.resize(slot_size); + + vector> output(slot_size, 0); + + encoder.encode(input, context.first_parms_id(), delta, plain); + int shift = 1; + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); + } + + encoder.encode(input, context.first_parms_id(), delta, plain); + shift = 2; + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); + } + + encoder.encode(input, context.first_parms_id(), delta, plain); + shift = 3; + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); + } + + encoder.encode(input, context.first_parms_id(), delta, plain); + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.complex_conjugate_inplace(encrypted, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[i].real()), round(output[i].real())); + ASSERT_EQ(round(-input[i].imag()), round(output[i].imag())); + } + } + } + + TEST(EvaluatorTest, BFVEncryptSquareDecrypt) + { + EncryptionParameters parms(scheme_type::bfv); + Modulus plain_modulus(1 << 8); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + + plain = "1"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "0"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "FFx^2 + FF"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^4 + 2x^2 + 1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "FF"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^1 + 1"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ( + plain.to_string(), + "1x^12 + 2x^11 + 3x^10 + 4x^9 + 3x^8 + 4x^7 + 5x^6 + 4x^5 + 4x^4 + 2x^3 + 1x^2 + 2x^1 + 1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "1x^16"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^32"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^1 + 1"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ( + plain.to_string(), + "1x^24 + 4x^23 + Ax^22 + 14x^21 + 1Fx^20 + 2Cx^19 + 3Cx^18 + 4Cx^17 + 5Fx^16 + 6Cx^15 + 70x^14 + 74x^13 + " + "71x^12 + 6Cx^11 + 64x^10 + 50x^9 + 40x^8 + 34x^7 + 26x^6 + 1Cx^5 + 11x^4 + 8x^3 + 6x^2 + 4x^1 + 1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + } + + TEST(EvaluatorTest, BFVEncryptMultiplyManyDecrypt) + { + EncryptionParameters parms(scheme_type::bfv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1, encrypted2, encrypted3, encrypted4, product; + Plaintext plain, plain1, plain2, plain3, plain4; + + plain1 = "1x^2 + 1"; + plain2 = "1x^2 + 1x^1"; + plain3 = "1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + vector encrypteds{ encrypted1, encrypted2, encrypted3 }; + evaluator.multiply_many(encrypteds, rlk, product); + ASSERT_EQ(3, encrypteds.size()); + decryptor.decrypt(product, plain); + ASSERT_EQ(plain.to_string(), "1x^6 + 2x^5 + 3x^4 + 3x^3 + 2x^2 + 1x^1"); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + + plain1 = "3Fx^3 + 3F"; + plain2 = "3Fx^4 + 3F"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encrypteds = { encrypted1, encrypted2 }; + evaluator.multiply_many(encrypteds, rlk, product); + ASSERT_EQ(2, encrypteds.size()); + decryptor.decrypt(product, plain); + ASSERT_EQ(plain.to_string(), "1x^7 + 1x^4 + 1x^3 + 1"); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + + plain1 = "1x^1"; + plain2 = "3Fx^4 + 3Fx^3 + 3Fx^2 + 3Fx^1 + 3F"; + plain3 = "1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + encrypteds = { encrypted1, encrypted2, encrypted3 }; + evaluator.multiply_many(encrypteds, rlk, product); + ASSERT_EQ(3, encrypteds.size()); + decryptor.decrypt(product, plain); + ASSERT_EQ(plain.to_string(), "3Fx^7 + 3Ex^6 + 3Dx^5 + 3Dx^4 + 3Dx^3 + 3Ex^2 + 3Fx^1"); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + + plain1 = "1"; + plain2 = "3F"; + plain3 = "1"; + plain4 = "3F"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + encryptor.encrypt(plain4, encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.multiply_many(encrypteds, rlk, product); + ASSERT_EQ(4, encrypteds.size()); + decryptor.decrypt(product, plain); + ASSERT_EQ(plain.to_string(), "1"); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + + plain1 = "1x^16 + 1x^15 + 1x^8 + 1x^7 + 1x^6 + 1x^3 + 1x^2 + 1"; + plain2 = "0"; + plain3 = "1x^13 + 1x^12 + 1x^5 + 1x^4 + 1x^3 + 1"; + plain4 = "1x^15 + 1x^10 + 1x^9 + 1x^8 + 1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + encryptor.encrypt(plain4, encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.multiply_many(encrypteds, rlk, product); + ASSERT_EQ(4, encrypteds.size()); + decryptor.decrypt(product, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + } + + TEST(EvaluatorTest, BFVEncryptExponentiateDecrypt) + { + EncryptionParameters parms(scheme_type::bfv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + + plain = "1x^2 + 1"; + encryptor.encrypt(plain, encrypted); + evaluator.exponentiate_inplace(encrypted, 1, rlk); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain, encrypted); + evaluator.exponentiate_inplace(encrypted, 2, rlk); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^4 + 2x^3 + 3x^2 + 2x^1 + 1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "3Fx^2 + 3Fx^1 + 3F"; + encryptor.encrypt(plain, encrypted); + evaluator.exponentiate_inplace(encrypted, 3, rlk); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "3Fx^6 + 3Dx^5 + 3Ax^4 + 39x^3 + 3Ax^2 + 3Dx^1 + 3F"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "1x^8"; + encryptor.encrypt(plain, encrypted); + evaluator.exponentiate_inplace(encrypted, 4, rlk); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^32"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + } + + TEST(EvaluatorTest, BFVEncryptAddManyDecrypt) + { + EncryptionParameters parms(scheme_type::bfv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1, encrypted2, encrypted3, encrypted4, sum; + Plaintext plain, plain1, plain2, plain3, plain4; + + plain1 = "1x^2 + 1"; + plain2 = "1x^2 + 1x^1"; + plain3 = "1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + vector encrypteds = { encrypted1, encrypted2, encrypted3 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ(plain.to_string(), "3x^2 + 2x^1 + 2"); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + + plain1 = "3Fx^3 + 3F"; + plain2 = "3Fx^4 + 3F"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encrypteds = { + encrypted1, + encrypted2, + }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ(plain.to_string(), "3Fx^4 + 3Fx^3 + 3E"); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + + plain1 = "1x^1"; + plain2 = "3Fx^4 + 3Fx^3 + 3Fx^2 + 3Fx^1 + 3F"; + plain3 = "1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + encrypteds = { encrypted1, encrypted2, encrypted3 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ(plain.to_string(), "3Fx^4 + 3Fx^3 + 1x^1"); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + + plain1 = "1"; + plain2 = "3F"; + plain3 = "1"; + plain4 = "3F"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + encryptor.encrypt(plain4, encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + + plain1 = "1x^16 + 1x^15 + 1x^8 + 1x^7 + 1x^6 + 1x^3 + 1x^2 + 1"; + plain2 = "0"; + plain3 = "1x^13 + 1x^12 + 1x^5 + 1x^4 + 1x^3 + 1"; + plain4 = "1x^15 + 1x^10 + 1x^9 + 1x^8 + 1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + encryptor.encrypt(plain4, encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ( + plain.to_string(), + "1x^16 + 2x^15 + 1x^13 + 1x^12 + 1x^10 + 1x^9 + 2x^8 + 1x^7 + 1x^6 + 1x^5 + 1x^4 + 2x^3 + 2x^2 + 1x^1 + 3"); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + } + + TEST(EvaluatorTest, BGVEncryptSquareDecrypt) + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 8); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + + plain = "1"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "0"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "FFx^2 + FF"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^4 + 2x^2 + 1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "FF"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^1 + 1"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ( + plain.to_string(), + "1x^12 + 2x^11 + 3x^10 + 4x^9 + 3x^8 + 4x^7 + 5x^6 + 4x^5 + 4x^4 + 2x^3 + 1x^2 + 2x^1 + 1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "1x^16"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^32"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^1 + 1"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ( + plain.to_string(), + "1x^24 + 4x^23 + Ax^22 + 14x^21 + 1Fx^20 + 2Cx^19 + 3Cx^18 + 4Cx^17 + 5Fx^16 + 6Cx^15 + 70x^14 + 74x^13 + " + "71x^12 + 6Cx^11 + 64x^10 + 50x^9 + 40x^8 + 34x^7 + 26x^6 + 1Cx^5 + 11x^4 + 8x^3 + 6x^2 + 4x^1 + 1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + } + + TEST(EvaluatorTest, BGVEncryptMultiplyManyDecrypt) + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1, encrypted2, encrypted3, encrypted4, product; + Plaintext plain, plain1, plain2, plain3, plain4; + + plain1 = "1x^2 + 1"; + plain2 = "1x^2 + 1x^1"; + plain3 = "1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + vector encrypteds{ encrypted1, encrypted2, encrypted3 }; + evaluator.multiply_many(encrypteds, rlk, product); + ASSERT_EQ(3, encrypteds.size()); + decryptor.decrypt(product, plain); + ASSERT_EQ(plain.to_string(), "1x^6 + 2x^5 + 3x^4 + 3x^3 + 2x^2 + 1x^1"); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + + plain1 = "3Fx^3 + 3F"; + plain2 = "3Fx^4 + 3F"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encrypteds = { encrypted1, encrypted2 }; + evaluator.multiply_many(encrypteds, rlk, product); + ASSERT_EQ(2, encrypteds.size()); + decryptor.decrypt(product, plain); + ASSERT_EQ(plain.to_string(), "1x^7 + 1x^4 + 1x^3 + 1"); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + + plain1 = "1x^1"; + plain2 = "3Fx^4 + 3Fx^3 + 3Fx^2 + 3Fx^1 + 3F"; + plain3 = "1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + encrypteds = { encrypted1, encrypted2, encrypted3 }; + evaluator.multiply_many(encrypteds, rlk, product); + ASSERT_EQ(3, encrypteds.size()); + decryptor.decrypt(product, plain); + ASSERT_EQ(plain.to_string(), "3Fx^7 + 3Ex^6 + 3Dx^5 + 3Dx^4 + 3Dx^3 + 3Ex^2 + 3Fx^1"); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + + plain1 = "1"; + plain2 = "3F"; + plain3 = "1"; + plain4 = "3F"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + encryptor.encrypt(plain4, encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.multiply_many(encrypteds, rlk, product); + ASSERT_EQ(4, encrypteds.size()); + decryptor.decrypt(product, plain); + ASSERT_EQ(plain.to_string(), "1"); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + + plain1 = "1x^16 + 1x^15 + 1x^8 + 1x^7 + 1x^6 + 1x^3 + 1x^2 + 1"; + plain2 = "0"; + plain3 = "1x^13 + 1x^12 + 1x^5 + 1x^4 + 1x^3 + 1"; + plain4 = "1x^15 + 1x^10 + 1x^9 + 1x^8 + 1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + encryptor.encrypt(plain4, encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.multiply_many(encrypteds, rlk, product); + ASSERT_EQ(4, encrypteds.size()); + decryptor.decrypt(product, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + } + + TEST(EvaluatorTest, BGVEncryptExponentiateDecrypt) + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + RelinKeys rlk; + keygen.create_relin_keys(rlk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + + plain = "1x^2 + 1"; + encryptor.encrypt(plain, encrypted); + evaluator.exponentiate_inplace(encrypted, 1, rlk); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^2 + 1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain, encrypted); + evaluator.exponentiate_inplace(encrypted, 2, rlk); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^4 + 2x^3 + 3x^2 + 2x^1 + 1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "3Fx^2 + 3Fx^1 + 3F"; + encryptor.encrypt(plain, encrypted); + evaluator.exponentiate_inplace(encrypted, 3, rlk); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "3Fx^6 + 3Dx^5 + 3Ax^4 + 39x^3 + 3Ax^2 + 3Dx^1 + 3F"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + + plain = "1x^8"; + encryptor.encrypt(plain, encrypted); + evaluator.exponentiate_inplace(encrypted, 4, rlk); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(plain.to_string(), "1x^32"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + } + + TEST(EvaluatorTest, BGVEncryptAddManyDecrypt) + { + EncryptionParameters parms(scheme_type::bgv); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1, encrypted2, encrypted3, encrypted4, sum; + Plaintext plain, plain1, plain2, plain3, plain4; + + plain1 = "1x^2 + 1"; + plain2 = "1x^2 + 1x^1"; + plain3 = "1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + vector encrypteds = { encrypted1, encrypted2, encrypted3 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ(plain.to_string(), "3x^2 + 2x^1 + 2"); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + + plain1 = "3Fx^3 + 3F"; + plain2 = "3Fx^4 + 3F"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encrypteds = { + encrypted1, + encrypted2, + }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ(plain.to_string(), "3Fx^4 + 3Fx^3 + 3E"); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + + plain1 = "1x^1"; + plain2 = "3Fx^4 + 3Fx^3 + 3Fx^2 + 3Fx^1 + 3F"; + plain3 = "1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + encrypteds = { encrypted1, encrypted2, encrypted3 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ(plain.to_string(), "3Fx^4 + 3Fx^3 + 1x^1"); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + + plain1 = "1"; + plain2 = "3F"; + plain3 = "1"; + plain4 = "3F"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + encryptor.encrypt(plain4, encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ(plain.to_string(), "0"); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + + plain1 = "1x^16 + 1x^15 + 1x^8 + 1x^7 + 1x^6 + 1x^3 + 1x^2 + 1"; + plain2 = "0"; + plain3 = "1x^13 + 1x^12 + 1x^5 + 1x^4 + 1x^3 + 1"; + plain4 = "1x^15 + 1x^10 + 1x^9 + 1x^8 + 1x^2 + 1x^1 + 1"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + encryptor.encrypt(plain4, encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ( + plain.to_string(), + "1x^16 + 2x^15 + 1x^13 + 1x^12 + 1x^10 + 1x^9 + 2x^8 + 1x^7 + 1x^6 + 1x^5 + 1x^4 + 2x^3 + 2x^2 + 1x^1 + 3"); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + } + + TEST(EvaluatorTest, TransformPlainToNTT) + { + auto evaluator_transform_plain_to_ntt = [](scheme_type scheme){ + EncryptionParameters parms(scheme); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); + SEALContext context(parms, true, sec_level_type::none); + + Evaluator evaluator(context); + Plaintext plain("0"); + ASSERT_FALSE(plain.is_ntt_form()); + evaluator.transform_to_ntt_inplace(plain, context.first_parms_id()); + ASSERT_TRUE(plain.is_zero()); + ASSERT_TRUE(plain.is_ntt_form()); + ASSERT_TRUE(plain.parms_id() == context.first_parms_id()); + + plain.release(); + plain = "0"; + ASSERT_FALSE(plain.is_ntt_form()); + auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); + evaluator.transform_to_ntt_inplace(plain, next_parms_id); + ASSERT_TRUE(plain.is_zero()); + ASSERT_TRUE(plain.is_ntt_form()); + ASSERT_TRUE(plain.parms_id() == next_parms_id); + + plain.release(); + plain = "1"; + ASSERT_FALSE(plain.is_ntt_form()); + evaluator.transform_to_ntt_inplace(plain, context.first_parms_id()); + for (size_t i = 0; i < 256; i++) + { + ASSERT_TRUE(plain[i] == uint64_t(1)); + } + ASSERT_TRUE(plain.is_ntt_form()); + ASSERT_TRUE(plain.parms_id() == context.first_parms_id()); + + plain.release(); + plain = "1"; + ASSERT_FALSE(plain.is_ntt_form()); + evaluator.transform_to_ntt_inplace(plain, next_parms_id); + for (size_t i = 0; i < 128; i++) + { + ASSERT_TRUE(plain[i] == uint64_t(1)); + } + ASSERT_TRUE(plain.is_ntt_form()); + ASSERT_TRUE(plain.parms_id() == next_parms_id); + + plain.release(); + plain = "2"; + ASSERT_FALSE(plain.is_ntt_form()); + evaluator.transform_to_ntt_inplace(plain, context.first_parms_id()); + for (size_t i = 0; i < 256; i++) + { + ASSERT_TRUE(plain[i] == uint64_t(2)); + } + ASSERT_TRUE(plain.is_ntt_form()); + ASSERT_TRUE(plain.parms_id() == context.first_parms_id()); + + plain.release(); + plain = "2"; + evaluator.transform_to_ntt_inplace(plain, next_parms_id); + for (size_t i = 0; i < 128; i++) + { + ASSERT_TRUE(plain[i] == uint64_t(2)); + } + ASSERT_TRUE(plain.is_ntt_form()); + ASSERT_TRUE(plain.parms_id() == next_parms_id); + }; + evaluator_transform_plain_to_ntt(scheme_type::bfv); + evaluator_transform_plain_to_ntt(scheme_type::bgv); + } + + TEST(EvaluatorTest, TransformEncryptedToFromNTT) + { + auto evaluator_transform_encrypted_to_from_ntt = [](scheme_type scheme){ + EncryptionParameters parms(scheme); + Modulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 })); - SEALContext context(parms, true, sec_level_type::none); + SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - GaloisKeys glk; - keygen.create_galois_keys(glk); Encryptor encryptor(context, pk); Evaluator evaluator(context); Decryptor decryptor(context, keygen.secret_key()); - CKKSEncoder encoder(context); - const double delta = pow(2, 70); - Ciphertext encrypted; Plaintext plain; - - vector> input{ complex(1, 1), complex(2, 2), complex(3, 3), - complex(4, 4) }; - input.resize(slot_size); - - vector> output(slot_size, 0); - - encoder.encode(input, context.first_parms_id(), delta, plain); - int shift = 1; - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); - } - - encoder.encode(input, context.first_parms_id(), delta, plain); - shift = 2; + Ciphertext encrypted; + plain = "0"; encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); + evaluator.transform_to_ntt_inplace(encrypted); + evaluator.transform_from_ntt_inplace(encrypted); decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); - } + ASSERT_TRUE(plain.to_string() == "0"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - encoder.encode(input, context.first_parms_id(), delta, plain); - shift = 3; + plain = "1"; encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); + evaluator.transform_to_ntt_inplace(encrypted); + evaluator.transform_from_ntt_inplace(encrypted); decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); - } + ASSERT_TRUE(plain.to_string() == "1"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - encoder.encode(input, context.first_parms_id(), delta, plain); + plain = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"; encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.complex_conjugate_inplace(encrypted, glk); + evaluator.transform_to_ntt_inplace(encrypted); + evaluator.transform_from_ntt_inplace(encrypted); decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[i].real()), round(output[i].real())); - ASSERT_EQ(round(-input[i].imag()), round(output[i].imag())); - } - } + ASSERT_TRUE(plain.to_string() == "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"); + ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + }; + evaluator_transform_encrypted_to_from_ntt(scheme_type::bfv); + evaluator_transform_encrypted_to_from_ntt(scheme_type::bgv); } - TEST(EvaluatorTest, BFVEncryptSquareDecrypt) + TEST(EvaluatorTest, BFVEncryptMultiplyPlainNTTDecrypt) { EncryptionParameters parms(scheme_type::bfv); - Modulus plain_modulus(1 << 8); + Modulus plain_modulus(1 << 6); parms.set_poly_modulus_degree(128); parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); @@ -3430,443 +5207,427 @@ namespace sealtest Evaluator evaluator(context); Decryptor decryptor(context, keygen.secret_key()); - Ciphertext encrypted; Plaintext plain; + Plaintext plain_multiplier; + Ciphertext encrypted; - plain = "1"; - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(plain.to_string(), "1"); - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - - plain = "0"; - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(plain.to_string(), "0"); - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - - plain = "FFx^2 + FF"; - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(plain.to_string(), "1x^4 + 2x^2 + 1"); - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - - plain = "FF"; + plain = 0; encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); + evaluator.transform_to_ntt_inplace(encrypted); + plain_multiplier = 1; + evaluator.transform_to_ntt_inplace(plain_multiplier, context.first_parms_id()); + evaluator.multiply_plain_inplace(encrypted, plain_multiplier); + evaluator.transform_from_ntt_inplace(encrypted); decryptor.decrypt(encrypted, plain); - ASSERT_EQ(plain.to_string(), "1"); + ASSERT_TRUE(plain.to_string() == "0"); ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - plain = "1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^1 + 1"; + plain = 2; encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); + evaluator.transform_to_ntt_inplace(encrypted); + plain_multiplier.release(); + plain_multiplier = 3; + evaluator.transform_to_ntt_inplace(plain_multiplier, context.first_parms_id()); + evaluator.multiply_plain_inplace(encrypted, plain_multiplier); + evaluator.transform_from_ntt_inplace(encrypted); decryptor.decrypt(encrypted, plain); - ASSERT_EQ( - plain.to_string(), - "1x^12 + 2x^11 + 3x^10 + 4x^9 + 3x^8 + 4x^7 + 5x^6 + 4x^5 + 4x^4 + 2x^3 + 1x^2 + 2x^1 + 1"); + ASSERT_TRUE(plain.to_string() == "6"); ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - plain = "1x^16"; + plain = 1; encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); + evaluator.transform_to_ntt_inplace(encrypted); + plain_multiplier.release(); + plain_multiplier = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"; + evaluator.transform_to_ntt_inplace(plain_multiplier, context.first_parms_id()); + evaluator.multiply_plain_inplace(encrypted, plain_multiplier); + evaluator.transform_from_ntt_inplace(encrypted); decryptor.decrypt(encrypted, plain); - ASSERT_EQ(plain.to_string(), "1x^32"); + ASSERT_TRUE(plain.to_string() == "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"); ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - plain = "1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^1 + 1"; + plain = "1x^20"; encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.square_inplace(encrypted); + evaluator.transform_to_ntt_inplace(encrypted); + plain_multiplier.release(); + plain_multiplier = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"; + evaluator.transform_to_ntt_inplace(plain_multiplier, context.first_parms_id()); + evaluator.multiply_plain_inplace(encrypted, plain_multiplier); + evaluator.transform_from_ntt_inplace(encrypted); decryptor.decrypt(encrypted, plain); - ASSERT_EQ( - plain.to_string(), - "1x^24 + 4x^23 + Ax^22 + 14x^21 + 1Fx^20 + 2Cx^19 + 3Cx^18 + 4Cx^17 + 5Fx^16 + 6Cx^15 + 70x^14 + 74x^13 + " - "71x^12 + 6Cx^11 + 64x^10 + 50x^9 + 40x^8 + 34x^7 + 26x^6 + 1Cx^5 + 11x^4 + 8x^3 + 6x^2 + 4x^1 + 1"); + ASSERT_TRUE( + plain.to_string() == + "Fx^30 + Ex^29 + Dx^28 + Cx^27 + Bx^26 + Ax^25 + 1x^24 + 2x^23 + 3x^22 + 4x^21 + 5x^20"); ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); } - TEST(EvaluatorTest, BFVEncryptMultiplyManyDecrypt) + TEST(EvaluatorTest, BFVEncryptApplyGaloisDecrypt) { EncryptionParameters parms(scheme_type::bfv); - Modulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(128); + Modulus plain_modulus(257); + parms.set_poly_modulus_degree(8); parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); + parms.set_coeff_modulus(CoeffModulus::Create(8, { 40, 40 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - RelinKeys rlk; - keygen.create_relin_keys(rlk); + GaloisKeys glk; + keygen.create_galois_keys(vector{ 1, 3, 5, 15 }, glk); Encryptor encryptor(context, pk); Evaluator evaluator(context); Decryptor decryptor(context, keygen.secret_key()); - Ciphertext encrypted1, encrypted2, encrypted3, encrypted4, product; - Plaintext plain, plain1, plain2, plain3, plain4; - - plain1 = "1x^2 + 1"; - plain2 = "1x^2 + 1x^1"; - plain3 = "1x^2 + 1x^1 + 1"; - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - encryptor.encrypt(plain3, encrypted3); - vector encrypteds{ encrypted1, encrypted2, encrypted3 }; - evaluator.multiply_many(encrypteds, rlk, product); - ASSERT_EQ(3, encrypteds.size()); - decryptor.decrypt(product, plain); - ASSERT_EQ(plain.to_string(), "1x^6 + 2x^5 + 3x^4 + 3x^3 + 2x^2 + 1x^1"); - ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); - ASSERT_TRUE(product.parms_id() == context.first_parms_id()); - - plain1 = "3Fx^3 + 3F"; - plain2 = "3Fx^4 + 3F"; - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - encrypteds = { encrypted1, encrypted2 }; - evaluator.multiply_many(encrypteds, rlk, product); - ASSERT_EQ(2, encrypteds.size()); - decryptor.decrypt(product, plain); - ASSERT_EQ(plain.to_string(), "1x^7 + 1x^4 + 1x^3 + 1"); - ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); - ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + Plaintext plain("1"); + Ciphertext encrypted; + encryptor.encrypt(plain, encrypted); + evaluator.apply_galois_inplace(encrypted, 1, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 3, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 5, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 15, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1" == plain.to_string()); - plain1 = "1x^1"; - plain2 = "3Fx^4 + 3Fx^3 + 3Fx^2 + 3Fx^1 + 3F"; - plain3 = "1x^2 + 1x^1 + 1"; - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - encryptor.encrypt(plain3, encrypted3); - encrypteds = { encrypted1, encrypted2, encrypted3 }; - evaluator.multiply_many(encrypteds, rlk, product); - ASSERT_EQ(3, encrypteds.size()); - decryptor.decrypt(product, plain); - ASSERT_EQ(plain.to_string(), "3Fx^7 + 3Ex^6 + 3Dx^5 + 3Dx^4 + 3Dx^3 + 3Ex^2 + 3Fx^1"); - ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); - ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + plain = "1x^1"; + encryptor.encrypt(plain, encrypted); + evaluator.apply_galois_inplace(encrypted, 1, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 3, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^3" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 5, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("100x^7" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 15, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^1" == plain.to_string()); - plain1 = "1"; - plain2 = "3F"; - plain3 = "1"; - plain4 = "3F"; - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - encryptor.encrypt(plain3, encrypted3); - encryptor.encrypt(plain4, encrypted4); - encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; - evaluator.multiply_many(encrypteds, rlk, product); - ASSERT_EQ(4, encrypteds.size()); - decryptor.decrypt(product, plain); - ASSERT_EQ(plain.to_string(), "1"); - ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted4.parms_id() == product.parms_id()); - ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + plain = "1x^2"; + encryptor.encrypt(plain, encrypted); + evaluator.apply_galois_inplace(encrypted, 1, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^2" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 3, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^6" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 5, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("100x^6" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 15, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^2" == plain.to_string()); - plain1 = "1x^16 + 1x^15 + 1x^8 + 1x^7 + 1x^6 + 1x^3 + 1x^2 + 1"; - plain2 = "0"; - plain3 = "1x^13 + 1x^12 + 1x^5 + 1x^4 + 1x^3 + 1"; - plain4 = "1x^15 + 1x^10 + 1x^9 + 1x^8 + 1x^2 + 1x^1 + 1"; - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - encryptor.encrypt(plain3, encrypted3); - encryptor.encrypt(plain4, encrypted4); - encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; - evaluator.multiply_many(encrypteds, rlk, product); - ASSERT_EQ(4, encrypteds.size()); - decryptor.decrypt(product, plain); - ASSERT_EQ(plain.to_string(), "0"); - ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted4.parms_id() == product.parms_id()); - ASSERT_TRUE(product.parms_id() == context.first_parms_id()); + plain = "1x^3 + 2x^2 + 1x^1 + 1"; + encryptor.encrypt(plain, encrypted); + evaluator.apply_galois_inplace(encrypted, 1, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^3 + 2x^2 + 1x^1 + 1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 3, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("2x^6 + 1x^3 + 100x^1 + 1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 5, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("100x^7 + FFx^6 + 100x^5 + 1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 15, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^3 + 2x^2 + 1x^1 + 1" == plain.to_string()); } - TEST(EvaluatorTest, BFVEncryptExponentiateDecrypt) + TEST(EvaluatorTest, BFVEncryptRotateMatrixDecrypt) { EncryptionParameters parms(scheme_type::bfv); - Modulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(128); + Modulus plain_modulus(257); + parms.set_poly_modulus_degree(8); parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); + parms.set_coeff_modulus(CoeffModulus::Create(8, { 40, 40 })); SEALContext context(parms, false, sec_level_type::none); KeyGenerator keygen(context); PublicKey pk; keygen.create_public_key(pk); - RelinKeys rlk; - keygen.create_relin_keys(rlk); + GaloisKeys glk; + keygen.create_galois_keys(glk); Encryptor encryptor(context, pk); Evaluator evaluator(context); Decryptor decryptor(context, keygen.secret_key()); + BatchEncoder batch_encoder(context); - Ciphertext encrypted; Plaintext plain; - - plain = "1x^2 + 1"; + vector plain_vec{ 1, 2, 3, 4, 5, 6, 7, 8 }; + batch_encoder.encode(plain_vec, plain); + Ciphertext encrypted; encryptor.encrypt(plain, encrypted); - evaluator.exponentiate_inplace(encrypted, 1, rlk); + + evaluator.rotate_columns_inplace(encrypted, glk); decryptor.decrypt(encrypted, plain); - ASSERT_EQ(plain.to_string(), "1x^2 + 1"); - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + batch_encoder.decode(plain, plain_vec); + ASSERT_TRUE((plain_vec == vector{ 5, 6, 7, 8, 1, 2, 3, 4 })); - plain = "1x^2 + 1x^1 + 1"; - encryptor.encrypt(plain, encrypted); - evaluator.exponentiate_inplace(encrypted, 2, rlk); + evaluator.rotate_rows_inplace(encrypted, -1, glk); decryptor.decrypt(encrypted, plain); - ASSERT_EQ(plain.to_string(), "1x^4 + 2x^3 + 3x^2 + 2x^1 + 1"); - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + batch_encoder.decode(plain, plain_vec); + ASSERT_TRUE((plain_vec == vector{ 8, 5, 6, 7, 4, 1, 2, 3 })); - plain = "3Fx^2 + 3Fx^1 + 3F"; - encryptor.encrypt(plain, encrypted); - evaluator.exponentiate_inplace(encrypted, 3, rlk); + evaluator.rotate_rows_inplace(encrypted, 2, glk); decryptor.decrypt(encrypted, plain); - ASSERT_EQ(plain.to_string(), "3Fx^6 + 3Dx^5 + 3Ax^4 + 39x^3 + 3Ax^2 + 3Dx^1 + 3F"); - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + batch_encoder.decode(plain, plain_vec); + ASSERT_TRUE((plain_vec == vector{ 6, 7, 8, 5, 2, 3, 4, 1 })); - plain = "1x^8"; - encryptor.encrypt(plain, encrypted); - evaluator.exponentiate_inplace(encrypted, 4, rlk); + evaluator.rotate_columns_inplace(encrypted, glk); decryptor.decrypt(encrypted, plain); - ASSERT_EQ(plain.to_string(), "1x^32"); - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - } + batch_encoder.decode(plain, plain_vec); + ASSERT_TRUE((plain_vec == vector{ 2, 3, 4, 1, 6, 7, 8, 5 })); - TEST(EvaluatorTest, BFVEncryptAddManyDecrypt) + evaluator.rotate_rows_inplace(encrypted, 0, glk); + decryptor.decrypt(encrypted, plain); + batch_encoder.decode(plain, plain_vec); + ASSERT_TRUE((plain_vec == vector{ 2, 3, 4, 1, 6, 7, 8, 5 })); + } + + TEST(EvaluatorTest, BFVEncryptModSwitchToNextDecrypt) { - EncryptionParameters parms(scheme_type::bfv); + // The common parameters: the plaintext and the polynomial moduli Modulus plain_modulus(1 << 6); + + // The parameters and the context of the higher level + EncryptionParameters parms(scheme_type::bfv); parms.set_poly_modulus_degree(128); parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 })); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30 })); - SEALContext context(parms, false, sec_level_type::none); + SEALContext context(parms, true, sec_level_type::none); KeyGenerator keygen(context); + SecretKey secret_key = keygen.secret_key(); PublicKey pk; keygen.create_public_key(pk); Encryptor encryptor(context, pk); Evaluator evaluator(context); Decryptor decryptor(context, keygen.secret_key()); + auto parms_id = context.first_parms_id(); - Ciphertext encrypted1, encrypted2, encrypted3, encrypted4, sum; - Plaintext plain, plain1, plain2, plain3, plain4; - - plain1 = "1x^2 + 1"; - plain2 = "1x^2 + 1x^1"; - plain3 = "1x^2 + 1x^1 + 1"; - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - encryptor.encrypt(plain3, encrypted3); - vector encrypteds = { encrypted1, encrypted2, encrypted3 }; - evaluator.add_many(encrypteds, sum); - decryptor.decrypt(sum, plain); - ASSERT_EQ(plain.to_string(), "3x^2 + 2x^1 + 2"); - ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); - ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + Ciphertext encrypted(context); + Ciphertext encryptedRes; + Plaintext plain; - plain1 = "3Fx^3 + 3F"; - plain2 = "3Fx^4 + 3F"; - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - encrypteds = { - encrypted1, - encrypted2, - }; - evaluator.add_many(encrypteds, sum); - decryptor.decrypt(sum, plain); - ASSERT_EQ(plain.to_string(), "3Fx^4 + 3Fx^3 + 3E"); - ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); - ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + plain = 0; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next(encrypted, encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); - plain1 = "1x^1"; - plain2 = "3Fx^4 + 3Fx^3 + 3Fx^2 + 3Fx^1 + 3F"; - plain3 = "1x^2 + 1x^1 + 1"; - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - encryptor.encrypt(plain3, encrypted3); - encrypteds = { encrypted1, encrypted2, encrypted3 }; - evaluator.add_many(encrypteds, sum); - decryptor.decrypt(sum, plain); - ASSERT_EQ(plain.to_string(), "3Fx^4 + 3Fx^3 + 1x^1"); - ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); - ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + evaluator.mod_switch_to_next_inplace(encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); - plain1 = "1"; - plain2 = "3F"; - plain3 = "1"; - plain4 = "3F"; - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - encryptor.encrypt(plain3, encrypted3); - encryptor.encrypt(plain4, encrypted4); - encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; - evaluator.add_many(encrypteds, sum); - decryptor.decrypt(sum, plain); - ASSERT_EQ(plain.to_string(), "0"); - ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id()); - ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); + parms_id = context.first_parms_id(); + plain = 1; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next(encrypted, encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); - plain1 = "1x^16 + 1x^15 + 1x^8 + 1x^7 + 1x^6 + 1x^3 + 1x^2 + 1"; - plain2 = "0"; - plain3 = "1x^13 + 1x^12 + 1x^5 + 1x^4 + 1x^3 + 1"; - plain4 = "1x^15 + 1x^10 + 1x^9 + 1x^8 + 1x^2 + 1x^1 + 1"; - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - encryptor.encrypt(plain3, encrypted3); - encryptor.encrypt(plain4, encrypted4); - encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; - evaluator.add_many(encrypteds, sum); - decryptor.decrypt(sum, plain); - ASSERT_EQ( - plain.to_string(), - "1x^16 + 2x^15 + 1x^13 + 1x^12 + 1x^10 + 1x^9 + 2x^8 + 1x^7 + 1x^6 + 1x^5 + 1x^4 + 2x^3 + 2x^2 + 1x^1 + 3"); - ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id()); - ASSERT_TRUE(sum.parms_id() == context.first_parms_id()); - } + evaluator.mod_switch_to_next_inplace(encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); - TEST(EvaluatorTest, TransformPlainToNTT) - { - EncryptionParameters parms(scheme_type::bfv); - Modulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); - SEALContext context(parms, true, sec_level_type::none); + parms_id = context.first_parms_id(); + plain = "1x^127"; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next(encrypted, encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); - Evaluator evaluator(context); - Plaintext plain("0"); - ASSERT_FALSE(plain.is_ntt_form()); - evaluator.transform_to_ntt_inplace(plain, context.first_parms_id()); - ASSERT_TRUE(plain.is_zero()); - ASSERT_TRUE(plain.is_ntt_form()); - ASSERT_TRUE(plain.parms_id() == context.first_parms_id()); - - plain.release(); - plain = "0"; - ASSERT_FALSE(plain.is_ntt_form()); - auto next_parms_id = context.first_context_data()->next_context_data()->parms_id(); - evaluator.transform_to_ntt_inplace(plain, next_parms_id); - ASSERT_TRUE(plain.is_zero()); - ASSERT_TRUE(plain.is_ntt_form()); - ASSERT_TRUE(plain.parms_id() == next_parms_id); - - plain.release(); - plain = "1"; - ASSERT_FALSE(plain.is_ntt_form()); - evaluator.transform_to_ntt_inplace(plain, context.first_parms_id()); - for (size_t i = 0; i < 256; i++) - { - ASSERT_TRUE(plain[i] == uint64_t(1)); - } - ASSERT_TRUE(plain.is_ntt_form()); - ASSERT_TRUE(plain.parms_id() == context.first_parms_id()); + evaluator.mod_switch_to_next_inplace(encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); - plain.release(); - plain = "1"; - ASSERT_FALSE(plain.is_ntt_form()); - evaluator.transform_to_ntt_inplace(plain, next_parms_id); - for (size_t i = 0; i < 128; i++) - { - ASSERT_TRUE(plain[i] == uint64_t(1)); - } - ASSERT_TRUE(plain.is_ntt_form()); - ASSERT_TRUE(plain.parms_id() == next_parms_id); - - plain.release(); - plain = "2"; - ASSERT_FALSE(plain.is_ntt_form()); - evaluator.transform_to_ntt_inplace(plain, context.first_parms_id()); - for (size_t i = 0; i < 256; i++) - { - ASSERT_TRUE(plain[i] == uint64_t(2)); - } - ASSERT_TRUE(plain.is_ntt_form()); - ASSERT_TRUE(plain.parms_id() == context.first_parms_id()); + parms_id = context.first_parms_id(); + plain = "5x^64 + Ax^5"; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next(encrypted, encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); - plain.release(); - plain = "2"; - evaluator.transform_to_ntt_inplace(plain, next_parms_id); - for (size_t i = 0; i < 128; i++) - { - ASSERT_TRUE(plain[i] == uint64_t(2)); - } - ASSERT_TRUE(plain.is_ntt_form()); - ASSERT_TRUE(plain.parms_id() == next_parms_id); + evaluator.mod_switch_to_next_inplace(encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); } - TEST(EvaluatorTest, TransformEncryptedToFromNTT) + TEST(EvaluatorTest, BFVEncryptModSwitchToDecrypt) { - EncryptionParameters parms(scheme_type::bfv); + // The common parameters: the plaintext and the polynomial moduli Modulus plain_modulus(1 << 6); + + // The parameters and the context of the higher level + EncryptionParameters parms(scheme_type::bfv); parms.set_poly_modulus_degree(128); parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 })); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30 })); - SEALContext context(parms, false, sec_level_type::none); + SEALContext context(parms, true, sec_level_type::none); KeyGenerator keygen(context); + SecretKey secret_key = keygen.secret_key(); PublicKey pk; keygen.create_public_key(pk); Encryptor encryptor(context, pk); Evaluator evaluator(context); Decryptor decryptor(context, keygen.secret_key()); + auto parms_id = context.first_parms_id(); + Ciphertext encrypted(context); Plaintext plain; - Ciphertext encrypted; - plain = "0"; + + plain = 0; encryptor.encrypt(plain, encrypted); - evaluator.transform_to_ntt_inplace(encrypted); - evaluator.transform_from_ntt_inplace(encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); ASSERT_TRUE(plain.to_string() == "0"); - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - plain = "1"; + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); encryptor.encrypt(plain, encrypted); - evaluator.transform_to_ntt_inplace(encrypted); - evaluator.transform_from_ntt_inplace(encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); + + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); + + parms_id = context.first_parms_id(); + encryptor.encrypt(plain, encrypted); + parms_id = context.get_context_data(parms_id)->next_context_data()->next_context_data()->parms_id(); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); + + parms_id = context.first_parms_id(); + plain = 1; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); ASSERT_TRUE(plain.to_string() == "1"); - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); - plain = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"; + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); encryptor.encrypt(plain, encrypted); - evaluator.transform_to_ntt_inplace(encrypted); - evaluator.transform_from_ntt_inplace(encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(plain.to_string() == "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"); - ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); + + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); + + parms_id = context.first_parms_id(); + encryptor.encrypt(plain, encrypted); + parms_id = context.get_context_data(parms_id)->next_context_data()->next_context_data()->parms_id(); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); + + parms_id = context.first_parms_id(); + plain = "1x^127"; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); + + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); + + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); + + parms_id = context.first_parms_id(); + encryptor.encrypt(plain, encrypted); + parms_id = context.get_context_data(parms_id)->next_context_data()->next_context_data()->parms_id(); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); + + parms_id = context.first_parms_id(); + plain = "5x^64 + Ax^5"; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + + parms_id = context.first_parms_id(); + encryptor.encrypt(plain, encrypted); + parms_id = context.get_context_data(parms_id)->next_context_data()->next_context_data()->parms_id(); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); } - TEST(EvaluatorTest, BFVEncryptMultiplyPlainNTTDecrypt) + TEST(EvaluatorTest, BGVEncryptMultiplyPlainNTTDecrypt) { - EncryptionParameters parms(scheme_type::bfv); + EncryptionParameters parms(scheme_type::bgv); Modulus plain_modulus(1 << 6); parms.set_poly_modulus_degree(128); parms.set_plain_modulus(plain_modulus); @@ -3935,9 +5696,9 @@ namespace sealtest ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id()); } - TEST(EvaluatorTest, BFVEncryptApplyGaloisDecrypt) + TEST(EvaluatorTest, BGVEncryptApplyGaloisDecrypt) { - EncryptionParameters parms(scheme_type::bfv); + EncryptionParameters parms(scheme_type::bgv); Modulus plain_modulus(257); parms.set_poly_modulus_degree(8); parms.set_plain_modulus(plain_modulus); @@ -4016,9 +5777,9 @@ namespace sealtest ASSERT_TRUE("1x^3 + 2x^2 + 1x^1 + 1" == plain.to_string()); } - TEST(EvaluatorTest, BFVEncryptRotateMatrixDecrypt) + TEST(EvaluatorTest, BGVEncryptRotateMatrixDecrypt) { - EncryptionParameters parms(scheme_type::bfv); + EncryptionParameters parms(scheme_type::bgv); Modulus plain_modulus(257); parms.set_poly_modulus_degree(8); parms.set_plain_modulus(plain_modulus); @@ -4067,99 +5828,131 @@ namespace sealtest batch_encoder.decode(plain, plain_vec); ASSERT_TRUE((plain_vec == vector{ 2, 3, 4, 1, 6, 7, 8, 5 })); } - TEST(EvaluatorTest, BFVEncryptModSwitchToNextDecrypt) - { - // The common parameters: the plaintext and the polynomial moduli - Modulus plain_modulus(1 << 6); - - // The parameters and the context of the higher level - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30 })); - SEALContext context(parms, true, sec_level_type::none); - KeyGenerator keygen(context); - SecretKey secret_key = keygen.secret_key(); - PublicKey pk; - keygen.create_public_key(pk); + TEST(EvaluatorTest, BGVEncryptModSwitchToNextDecrypt) + { + { + // The common parameters: the plaintext and the polynomial moduli + Modulus plain_modulus(1 << 6); - Encryptor encryptor(context, pk); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - auto parms_id = context.first_parms_id(); + // The parameters and the context of the higher level + EncryptionParameters parms(scheme_type::bgv); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30 })); - Ciphertext encrypted(context); - Ciphertext encryptedRes; - Plaintext plain; + SEALContext context(parms, true, sec_level_type::none); + KeyGenerator keygen(context); + SecretKey secret_key = keygen.secret_key(); + PublicKey pk; + keygen.create_public_key(pk); - plain = 0; - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_next(encrypted, encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "0"); + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + auto parms_id = context.first_parms_id(); - evaluator.mod_switch_to_next_inplace(encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "0"); + Ciphertext encrypted(context); + Ciphertext encryptedRes; + Plaintext plain; - parms_id = context.first_parms_id(); - plain = 1; - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_next(encrypted, encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1"); + plain = 0; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next(encrypted, encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); + + evaluator.mod_switch_to_next_inplace(encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); + + parms_id = context.first_parms_id(); + plain = 1; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next(encrypted, encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); + + evaluator.mod_switch_to_next_inplace(encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); + + parms_id = context.first_parms_id(); + plain = "1x^127"; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next(encrypted, encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); + + evaluator.mod_switch_to_next_inplace(encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); + + parms_id = context.first_parms_id(); + plain = "5x^64 + Ax^5"; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next(encrypted, encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + + evaluator.mod_switch_to_next_inplace(encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + } + { + //Consider the case of qi mod p != 1 + Modulus plain_modulus(786433); - evaluator.mod_switch_to_next_inplace(encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1"); + EncryptionParameters parms(scheme_type::bgv); + parms.set_poly_modulus_degree(8192); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus(CoeffModulus::BGVDefault(8192)); + SEALContext context(parms, true, sec_level_type::tc128); - parms_id = context.first_parms_id(); - plain = "1x^127"; - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_next(encrypted, encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1x^127"); + KeyGenerator keygen(context); + SecretKey secret_key = keygen.secret_key(); + PublicKey pk; + keygen.create_public_key(pk); - evaluator.mod_switch_to_next_inplace(encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1x^127"); + Encryptor encryptor(context, pk); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); - parms_id = context.first_parms_id(); - plain = "5x^64 + Ax^5"; - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_next(encrypted, encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + Ciphertext encrypted(context); + Plaintext plain; - evaluator.mod_switch_to_next_inplace(encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + plain = "1"; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next_inplace(encrypted); + evaluator.mod_switch_to_next_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(plain.to_string() == "1"); + } } - TEST(EvaluatorTest, BFVEncryptModSwitchToDecrypt) + TEST(EvaluatorTest, BGVEncryptModSwitchToDecrypt) { // The common parameters: the plaintext and the polynomial moduli Modulus plain_modulus(1 << 6); // The parameters and the context of the higher level - EncryptionParameters parms(scheme_type::bfv); + EncryptionParameters parms(scheme_type::bgv); parms.set_poly_modulus_degree(128); parms.set_plain_modulus(plain_modulus); parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30 })); diff --git a/native/tests/seal/galoiskeys.cpp b/native/tests/seal/galoiskeys.cpp index 6fa25d724..c0dfc019d 100644 --- a/native/tests/seal/galoiskeys.cpp +++ b/native/tests/seal/galoiskeys.cpp @@ -18,189 +18,200 @@ namespace sealtest { TEST(GaloisKeysTest, GaloisKeysSaveLoad) { - stringstream stream; - { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(65537); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60 })); - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - GaloisKeys keys; - GaloisKeys test_keys; - keys.save(stream); - test_keys.unsafe_load(context, stream); - ASSERT_EQ(keys.data().size(), test_keys.data().size()); - ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); - ASSERT_EQ(0ULL, keys.data().size()); - - keygen.create_galois_keys(keys); - keys.save(stream); - test_keys.load(context, stream); - ASSERT_EQ(keys.data().size(), test_keys.data().size()); - ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); - for (size_t j = 0; j < test_keys.data().size(); j++) + auto galoiskey_save_load = [](scheme_type scheme){ + stringstream stream; { - for (size_t i = 0; i < test_keys.data()[j].size(); i++) + + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(65537); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60 })); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + + GaloisKeys keys; + GaloisKeys test_keys; + keys.save(stream); + test_keys.unsafe_load(context, stream); + ASSERT_EQ(keys.data().size(), test_keys.data().size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(0ULL, keys.data().size()); + + keygen.create_galois_keys(keys); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.data().size(), test_keys.data().size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + for (size_t j = 0; j < test_keys.data().size(); j++) { - ASSERT_EQ(keys.data()[j][i].data().size(), test_keys.data()[j][i].data().size()); - ASSERT_EQ( - keys.data()[j][i].data().dyn_array().size(), test_keys.data()[j][i].data().dyn_array().size()); - ASSERT_TRUE(is_equal_uint( - keys.data()[j][i].data().data(), test_keys.data()[j][i].data().data(), - keys.data()[j][i].data().dyn_array().size())); + for (size_t i = 0; i < test_keys.data()[j].size(); i++) + { + ASSERT_EQ(keys.data()[j][i].data().size(), test_keys.data()[j][i].data().size()); + ASSERT_EQ( + keys.data()[j][i].data().dyn_array().size(), test_keys.data()[j][i].data().dyn_array().size()); + ASSERT_TRUE(is_equal_uint( + keys.data()[j][i].data().data(), test_keys.data()[j][i].data().data(), + keys.data()[j][i].data().dyn_array().size())); + } } + ASSERT_EQ(64ULL, keys.data().size()); } - ASSERT_EQ(64ULL, keys.data().size()); - } - { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(256); - parms.set_plain_modulus(65537); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 50 })); - - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - GaloisKeys keys; - GaloisKeys test_keys; - keys.save(stream); - test_keys.unsafe_load(context, stream); - ASSERT_EQ(keys.data().size(), test_keys.data().size()); - ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); - ASSERT_EQ(0ULL, keys.data().size()); - - keygen.create_galois_keys(keys); - keys.save(stream); - test_keys.load(context, stream); - ASSERT_EQ(keys.data().size(), test_keys.data().size()); - ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); - for (size_t j = 0; j < test_keys.data().size(); j++) { - for (size_t i = 0; i < test_keys.data()[j].size(); i++) + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(256); + parms.set_plain_modulus(65537); + parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 50 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + + GaloisKeys keys; + GaloisKeys test_keys; + keys.save(stream); + test_keys.unsafe_load(context, stream); + ASSERT_EQ(keys.data().size(), test_keys.data().size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(0ULL, keys.data().size()); + + keygen.create_galois_keys(keys); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.data().size(), test_keys.data().size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + for (size_t j = 0; j < test_keys.data().size(); j++) { - ASSERT_EQ(keys.data()[j][i].data().size(), test_keys.data()[j][i].data().size()); - ASSERT_EQ( - keys.data()[j][i].data().dyn_array().size(), test_keys.data()[j][i].data().dyn_array().size()); - ASSERT_TRUE(is_equal_uint( - keys.data()[j][i].data().data(), test_keys.data()[j][i].data().data(), - keys.data()[j][i].data().dyn_array().size())); + for (size_t i = 0; i < test_keys.data()[j].size(); i++) + { + ASSERT_EQ(keys.data()[j][i].data().size(), test_keys.data()[j][i].data().size()); + ASSERT_EQ( + keys.data()[j][i].data().dyn_array().size(), test_keys.data()[j][i].data().dyn_array().size()); + ASSERT_TRUE(is_equal_uint( + keys.data()[j][i].data().data(), test_keys.data()[j][i].data().data(), + keys.data()[j][i].data().dyn_array().size())); + } } + ASSERT_EQ(256ULL, keys.data().size()); } - ASSERT_EQ(256ULL, keys.data().size()); - } + }; + galoiskey_save_load(scheme_type::bfv); + galoiskey_save_load(scheme_type::bgv); } + TEST(GaloisKeysTest, GaloisKeysSeededSaveLoad) { - // Returns true if a, b contains the same error. - auto compare_kswitchkeys = [](const KSwitchKeys &a, const KSwitchKeys &b, const SecretKey &sk, - const SEALContext &context) { - auto compare_error = [](const Ciphertext &a_ct, const Ciphertext &b_ct, const SecretKey &sk1, - const SEALContext &ctx) { - auto get_error = [](const Ciphertext &encrypted, const SecretKey &sk2, const SEALContext &ctx2) { - auto pool = MemoryManager::GetPool(); - auto &ctx2_data = *ctx2.get_context_data(encrypted.parms_id()); - auto &parms = ctx2_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_modulus_size = coeff_modulus.size(); - size_t rns_poly_uint64_count = util::mul_safe(coeff_count, coeff_modulus_size); - - DynArray error; - error.resize(rns_poly_uint64_count); - auto destination = error.begin(); - - auto copy_operand1(util::allocate_uint(coeff_count, pool)); - for (size_t i = 0; i < coeff_modulus_size; i++) - { - // Initialize pointers for multiplication - const uint64_t *encrypted_ptr = encrypted.data(1) + (i * coeff_count); - const uint64_t *secret_key_ptr = sk2.data().data() + (i * coeff_count); - uint64_t *destination_ptr = destination + (i * coeff_count); - util::set_zero_uint(coeff_count, destination_ptr); - util::set_uint(encrypted_ptr, coeff_count, copy_operand1.get()); - // compute c_{j+1} * s^{j+1} - util::dyadic_product_coeffmod( - copy_operand1.get(), secret_key_ptr, coeff_count, coeff_modulus[i], copy_operand1.get()); - // add c_{j+1} * s^{j+1} to destination - util::add_poly_coeffmod( - destination_ptr, copy_operand1.get(), coeff_count, coeff_modulus[i], destination_ptr); - // add c_0 into destination - util::add_poly_coeffmod( - destination_ptr, encrypted.data() + (i * coeff_count), coeff_count, coeff_modulus[i], - destination_ptr); - } - return error; + auto galoiskey_seeded_save_load = [](scheme_type scheme){ + // Returns true if a, b contains the same error. + auto compare_kswitchkeys = [](const KSwitchKeys &a, const KSwitchKeys &b, const SecretKey &sk, + const SEALContext &context) { + auto compare_error = [](const Ciphertext &a_ct, const Ciphertext &b_ct, const SecretKey &sk1, + const SEALContext &ctx) { + auto get_error = [](const Ciphertext &encrypted, const SecretKey &sk2, const SEALContext &ctx2) { + auto pool = MemoryManager::GetPool(); + auto &ctx2_data = *ctx2.get_context_data(encrypted.parms_id()); + auto &parms = ctx2_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_modulus_size = coeff_modulus.size(); + size_t rns_poly_uint64_count = util::mul_safe(coeff_count, coeff_modulus_size); + + DynArray error; + error.resize(rns_poly_uint64_count); + auto destination = error.begin(); + + auto copy_operand1(util::allocate_uint(coeff_count, pool)); + for (size_t i = 0; i < coeff_modulus_size; i++) + { + // Initialize pointers for multiplication + const uint64_t *encrypted_ptr = encrypted.data(1) + (i * coeff_count); + const uint64_t *secret_key_ptr = sk2.data().data() + (i * coeff_count); + uint64_t *destination_ptr = destination + (i * coeff_count); + util::set_zero_uint(coeff_count, destination_ptr); + util::set_uint(encrypted_ptr, coeff_count, copy_operand1.get()); + // compute c_{j+1} * s^{j+1} + util::dyadic_product_coeffmod( + copy_operand1.get(), secret_key_ptr, coeff_count, coeff_modulus[i], copy_operand1.get()); + // add c_{j+1} * s^{j+1} to destination + util::add_poly_coeffmod( + destination_ptr, copy_operand1.get(), coeff_count, coeff_modulus[i], destination_ptr); + // add c_0 into destination + util::add_poly_coeffmod( + destination_ptr, encrypted.data() + (i * coeff_count), coeff_count, coeff_modulus[i], + destination_ptr); + } + return error; + }; + + auto error_a = get_error(a_ct, sk1, ctx); + auto error_b = get_error(b_ct, sk1, ctx); + ASSERT_EQ(error_a.size(), error_b.size()); + ASSERT_TRUE(is_equal_uint(error_a.cbegin(), error_b.cbegin(), error_a.size())); }; - auto error_a = get_error(a_ct, sk1, ctx); - auto error_b = get_error(b_ct, sk1, ctx); - ASSERT_EQ(error_a.size(), error_b.size()); - ASSERT_TRUE(is_equal_uint(error_a.cbegin(), error_b.cbegin(), error_a.size())); + ASSERT_EQ(a.size(), b.size()); + auto iter_a = a.data().begin(); + auto iter_b = b.data().begin(); + for (; iter_a != a.data().end(); iter_a++, iter_b++) + { + ASSERT_EQ(iter_a->size(), iter_b->size()); + auto pk_a = iter_a->begin(); + auto pk_b = iter_b->begin(); + for (; pk_a != iter_a->end(); pk_a++, pk_b++) + { + compare_error(pk_a->data(), pk_b->data(), sk, context); + } + } }; - ASSERT_EQ(a.size(), b.size()); - auto iter_a = a.data().begin(); - auto iter_b = b.data().begin(); - for (; iter_a != a.data().end(); iter_a++, iter_b++) + stringstream stream; { - ASSERT_EQ(iter_a->size(), iter_b->size()); - auto pk_a = iter_a->begin(); - auto pk_b = iter_b->begin(); - for (; pk_a != iter_a->end(); pk_a++, pk_b++) + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(8); + parms.set_plain_modulus(65537); + parms.set_coeff_modulus(CoeffModulus::Create(8, { 60, 60 })); + prng_seed_type seed; + for (auto &i : seed) { - compare_error(pk_a->data(), pk_b->data(), sk, context); + i = random_uint64(); } - } - }; + auto rng = make_shared(Blake2xbPRNGFactory(seed)); + parms.set_random_generator(rng); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + SecretKey secret_key = keygen.secret_key(); - stringstream stream; - { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(8); - parms.set_plain_modulus(65537); - parms.set_coeff_modulus(CoeffModulus::Create(8, { 60, 60 })); - prng_seed_type seed; - for (auto &i : seed) - { - i = random_uint64(); + keygen.create_galois_keys().save(stream); + GaloisKeys test_keys; + test_keys.load(context, stream); + GaloisKeys keys; + keygen.create_galois_keys(keys); + compare_kswitchkeys(keys, test_keys, secret_key, context); } - auto rng = make_shared(Blake2xbPRNGFactory(seed)); - parms.set_random_generator(rng); - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - SecretKey secret_key = keygen.secret_key(); - - keygen.create_galois_keys().save(stream); - GaloisKeys test_keys; - test_keys.load(context, stream); - GaloisKeys keys; - keygen.create_galois_keys(keys); - compare_kswitchkeys(keys, test_keys, secret_key, context); - } - { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(256); - parms.set_plain_modulus(65537); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 50 })); - prng_seed_type seed; - for (auto &i : seed) { - i = random_uint64(); + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(256); + parms.set_plain_modulus(65537); + parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 50 })); + prng_seed_type seed; + for (auto &i : seed) + { + i = random_uint64(); + } + auto rng = make_shared(Blake2xbPRNGFactory(seed)); + parms.set_random_generator(rng); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + SecretKey secret_key = keygen.secret_key(); + + keygen.create_galois_keys().save(stream); + GaloisKeys test_keys; + test_keys.load(context, stream); + GaloisKeys keys; + keygen.create_galois_keys(keys); + compare_kswitchkeys(keys, test_keys, secret_key, context); } - auto rng = make_shared(Blake2xbPRNGFactory(seed)); - parms.set_random_generator(rng); - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - SecretKey secret_key = keygen.secret_key(); - - keygen.create_galois_keys().save(stream); - GaloisKeys test_keys; - test_keys.load(context, stream); - GaloisKeys keys; - keygen.create_galois_keys(keys); - compare_kswitchkeys(keys, test_keys, secret_key, context); - } + }; + + galoiskey_seeded_save_load(scheme_type::bfv); + galoiskey_seeded_save_load(scheme_type::bgv); } } // namespace sealtest diff --git a/native/tests/seal/keygenerator.cpp b/native/tests/seal/keygenerator.cpp index 9695f04a0..df023be3c 100644 --- a/native/tests/seal/keygenerator.cpp +++ b/native/tests/seal/keygenerator.cpp @@ -160,6 +160,151 @@ namespace sealtest } } + TEST(KeyGeneratorTest, BGVKeyGeneration) + { + EncryptionParameters parms(scheme_type::bgv); + { + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(65537); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + + ASSERT_THROW(auto evk = keygen.create_relin_keys(), logic_error); + ASSERT_THROW(auto galk = keygen.create_galois_keys(), logic_error); + } + { + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(65537); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60 })); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + + RelinKeys evk; + keygen.create_relin_keys(evk); + ASSERT_TRUE(evk.parms_id() == context.key_parms_id()); + ASSERT_EQ(1ULL, evk.key(2).size()); + for (auto &a : evk.data()) + { + for (auto &b : a) + { + ASSERT_FALSE(b.data().is_transparent()); + } + } + ASSERT_TRUE(is_valid_for(evk, context)); + + GaloisKeys galks; + keygen.create_galois_keys(galks); + for (auto &a : galks.data()) + { + for (auto &b : a) + { + ASSERT_FALSE(b.data().is_transparent()); + } + } + ASSERT_TRUE(is_valid_for(galks, context)); + + ASSERT_TRUE(galks.parms_id() == context.key_parms_id()); + ASSERT_EQ(1ULL, galks.key(3).size()); + ASSERT_EQ(10ULL, galks.size()); + + keygen.create_galois_keys(vector{ 1, 3, 5, 7 }, galks); + ASSERT_TRUE(galks.parms_id() == context.key_parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(127)); + ASSERT_EQ(1ULL, galks.key(1).size()); + ASSERT_EQ(1ULL, galks.key(3).size()); + ASSERT_EQ(1ULL, galks.key(5).size()); + ASSERT_EQ(1ULL, galks.key(7).size()); + ASSERT_EQ(4ULL, galks.size()); + + keygen.create_galois_keys(vector{ 1 }, galks); + ASSERT_TRUE(galks.parms_id() == context.key_parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_FALSE(galks.has_key(3)); + ASSERT_FALSE(galks.has_key(127)); + ASSERT_EQ(1ULL, galks.key(1).size()); + ASSERT_EQ(1ULL, galks.size()); + + keygen.create_galois_keys(vector{ 127 }, galks); + ASSERT_TRUE(galks.parms_id() == context.key_parms_id()); + ASSERT_FALSE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(127)); + ASSERT_EQ(1ULL, galks.key(127).size()); + ASSERT_EQ(1ULL, galks.size()); + } + { + parms.set_poly_modulus_degree(256); + parms.set_plain_modulus(65537); + parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 30, 30 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + + RelinKeys evk; + keygen.create_relin_keys(evk); + ASSERT_TRUE(evk.parms_id() == context.key_parms_id()); + ASSERT_EQ(2ULL, evk.key(2).size()); + + for (auto &a : evk.data()) + { + for (auto &b : a) + { + ASSERT_FALSE(b.data().is_transparent()); + } + } + ASSERT_TRUE(is_valid_for(evk, context)); + + GaloisKeys galks; + keygen.create_galois_keys(galks); + for (auto &a : galks.data()) + { + for (auto &b : a) + { + ASSERT_FALSE(b.data().is_transparent()); + } + } + ASSERT_TRUE(is_valid_for(galks, context)); + + ASSERT_TRUE(galks.parms_id() == context.key_parms_id()); + ASSERT_EQ(2ULL, galks.key(3).size()); + ASSERT_EQ(14ULL, galks.size()); + + keygen.create_galois_keys(vector{ 1, 3, 5, 7 }, galks); + ASSERT_TRUE(galks.parms_id() == context.key_parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(511)); + ASSERT_EQ(2ULL, galks.key(1).size()); + ASSERT_EQ(2ULL, galks.key(3).size()); + ASSERT_EQ(2ULL, galks.key(5).size()); + ASSERT_EQ(2ULL, galks.key(7).size()); + ASSERT_EQ(4ULL, galks.size()); + + keygen.create_galois_keys(vector{ 1 }, galks); + ASSERT_TRUE(galks.parms_id() == context.key_parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_FALSE(galks.has_key(3)); + ASSERT_FALSE(galks.has_key(511)); + ASSERT_EQ(2ULL, galks.key(1).size()); + ASSERT_EQ(1ULL, galks.size()); + + keygen.create_galois_keys(vector{ 511 }, galks); + ASSERT_TRUE(galks.parms_id() == context.key_parms_id()); + ASSERT_FALSE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(511)); + ASSERT_EQ(2ULL, galks.key(511).size()); + ASSERT_EQ(1ULL, galks.size()); + } + } + TEST(KeyGeneratorTest, CKKSKeyGeneration) { EncryptionParameters parms(scheme_type::ckks); @@ -304,66 +449,71 @@ namespace sealtest TEST(KeyGeneratorTest, Constructors) { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(65537); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 60, 50, 40 })); - SEALContext context(parms, false, sec_level_type::none); - Evaluator evaluator(context); - - KeyGenerator keygen(context); - PublicKey pk; - keygen.create_public_key(pk); - auto sk = keygen.secret_key(); - RelinKeys rlk; - keygen.create_relin_keys(rlk); - GaloisKeys galk; - keygen.create_galois_keys(galk); - - ASSERT_TRUE(is_valid_for(rlk, context)); - ASSERT_TRUE(is_valid_for(galk, context)); - - Encryptor encryptor(context, pk); - Decryptor decryptor(context, sk); - Plaintext pt("1x^2 + 2"), ptres; - Ciphertext ct; - encryptor.encrypt(pt, ct); - evaluator.square_inplace(ct); - evaluator.relinearize_inplace(ct, rlk); - decryptor.decrypt(ct, ptres); - ASSERT_EQ("1x^4 + 4x^2 + 4", ptres.to_string()); - - KeyGenerator keygen2(context, sk); - auto sk2 = keygen.secret_key(); - PublicKey pk2; - keygen2.create_public_key(pk2); - ASSERT_EQ(sk2.data(), sk.data()); - - RelinKeys rlk2; - keygen2.create_relin_keys(rlk2); - GaloisKeys galk2; - keygen2.create_galois_keys(galk2); - - ASSERT_TRUE(is_valid_for(rlk2, context)); - ASSERT_TRUE(is_valid_for(galk2, context)); - - Encryptor encryptor2(context, pk2); - Decryptor decryptor2(context, sk2); - pt = "1x^2 + 2"; - ptres.set_zero(); - encryptor.encrypt(pt, ct); - evaluator.square_inplace(ct); - evaluator.relinearize_inplace(ct, rlk2); - decryptor.decrypt(ct, ptres); - ASSERT_EQ("1x^4 + 4x^2 + 4", ptres.to_string()); - - PublicKey pk3; - keygen2.create_public_key(pk3); - - // There is a small random chance for this to fail - for (size_t i = 0; i < pk3.data().dyn_array().size(); i++) - { - ASSERT_NE(pk3.data().data()[i], pk2.data().data()[i]); - } + auto constructors = [](scheme_type scheme){ + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(65537); + parms.set_coeff_modulus(CoeffModulus::Create(128, { 60, 50, 40 })); + SEALContext context(parms, false, sec_level_type::none); + Evaluator evaluator(context); + + KeyGenerator keygen(context); + PublicKey pk; + keygen.create_public_key(pk); + auto sk = keygen.secret_key(); + RelinKeys rlk; + keygen.create_relin_keys(rlk); + GaloisKeys galk; + keygen.create_galois_keys(galk); + + ASSERT_TRUE(is_valid_for(rlk, context)); + ASSERT_TRUE(is_valid_for(galk, context)); + + Encryptor encryptor(context, pk); + Decryptor decryptor(context, sk); + Plaintext pt("1x^2 + 2"), ptres; + Ciphertext ct; + encryptor.encrypt(pt, ct); + evaluator.square_inplace(ct); + evaluator.relinearize_inplace(ct, rlk); + decryptor.decrypt(ct, ptres); + ASSERT_EQ("1x^4 + 4x^2 + 4", ptres.to_string()); + + KeyGenerator keygen2(context, sk); + auto sk2 = keygen.secret_key(); + PublicKey pk2; + keygen2.create_public_key(pk2); + ASSERT_EQ(sk2.data(), sk.data()); + + RelinKeys rlk2; + keygen2.create_relin_keys(rlk2); + GaloisKeys galk2; + keygen2.create_galois_keys(galk2); + + ASSERT_TRUE(is_valid_for(rlk2, context)); + ASSERT_TRUE(is_valid_for(galk2, context)); + + Encryptor encryptor2(context, pk2); + Decryptor decryptor2(context, sk2); + pt = "1x^2 + 2"; + ptres.set_zero(); + encryptor.encrypt(pt, ct); + evaluator.square_inplace(ct); + evaluator.relinearize_inplace(ct, rlk2); + decryptor.decrypt(ct, ptres); + ASSERT_EQ("1x^4 + 4x^2 + 4", ptres.to_string()); + + PublicKey pk3; + keygen2.create_public_key(pk3); + + // There is a small random chance for this to fail + for (size_t i = 0; i < pk3.data().dyn_array().size(); i++) + { + ASSERT_NE(pk3.data().data()[i], pk2.data().data()[i]); + } + }; + + constructors(scheme_type::bfv); + constructors(scheme_type::bgv); } } // namespace sealtest diff --git a/native/tests/seal/plaintext.cpp b/native/tests/seal/plaintext.cpp index a1a5d7ec6..7d08d7ddc 100644 --- a/native/tests/seal/plaintext.cpp +++ b/native/tests/seal/plaintext.cpp @@ -190,6 +190,28 @@ namespace sealtest ASSERT_TRUE(plain.data() != plain2.data()); ASSERT_TRUE(plain2.is_ntt_form()); } + { + EncryptionParameters parms(scheme_type::bgv); + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 30 })); + parms.set_plain_modulus(65537); + + SEALContext context(parms, false, sec_level_type::none); + + plain.parms_id() = parms_id_zero; + plain = "1x^63 + 2x^62 + Fx^32 + Ax^9 + 1x^1 + 1"; + plain.save(stream); + plain2.load(context, stream); + ASSERT_TRUE(plain.data() != plain2.data()); + ASSERT_FALSE(plain2.is_ntt_form()); + + Evaluator evaluator(context); + evaluator.transform_to_ntt_inplace(plain, context.first_parms_id()); + plain.save(stream); + plain2.load(context, stream); + ASSERT_TRUE(plain.data() != plain2.data()); + ASSERT_TRUE(plain2.is_ntt_form()); + } { EncryptionParameters parms(scheme_type::ckks); parms.set_poly_modulus_degree(64); diff --git a/native/tests/seal/publickey.cpp b/native/tests/seal/publickey.cpp index 859bb2215..ac07b8629 100644 --- a/native/tests/seal/publickey.cpp +++ b/native/tests/seal/publickey.cpp @@ -14,54 +14,59 @@ namespace sealtest { TEST(PublicKeyTest, SaveLoadPublicKey) { - stringstream stream; - { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(1 << 6); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); + auto save_load_public_key = [](scheme_type scheme){ + stringstream stream; + { + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(1 << 6); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); - PublicKey pk; - keygen.create_public_key(pk); - ASSERT_TRUE(pk.parms_id() == context.key_parms_id()); - pk.save(stream); + PublicKey pk; + keygen.create_public_key(pk); + ASSERT_TRUE(pk.parms_id() == context.key_parms_id()); + pk.save(stream); - PublicKey pk2; - pk2.load(context, stream); + PublicKey pk2; + pk2.load(context, stream); - ASSERT_EQ(pk.data().dyn_array().size(), pk2.data().dyn_array().size()); - for (size_t i = 0; i < pk.data().dyn_array().size(); i++) - { - ASSERT_EQ(pk.data().data()[i], pk2.data().data()[i]); + ASSERT_EQ(pk.data().dyn_array().size(), pk2.data().dyn_array().size()); + for (size_t i = 0; i < pk.data().dyn_array().size(); i++) + { + ASSERT_EQ(pk.data().data()[i], pk2.data().data()[i]); + } + ASSERT_TRUE(pk.parms_id() == pk2.parms_id()); } - ASSERT_TRUE(pk.parms_id() == pk2.parms_id()); - } - { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(256); - parms.set_plain_modulus(1 << 20); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 30, 40 })); + { + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(256); + parms.set_plain_modulus(1 << 20); + parms.set_coeff_modulus(CoeffModulus::Create(256, { 30, 40 })); - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); - PublicKey pk; - keygen.create_public_key(pk); - ASSERT_TRUE(pk.parms_id() == context.key_parms_id()); - pk.save(stream); + PublicKey pk; + keygen.create_public_key(pk); + ASSERT_TRUE(pk.parms_id() == context.key_parms_id()); + pk.save(stream); - PublicKey pk2; - pk2.load(context, stream); + PublicKey pk2; + pk2.load(context, stream); - ASSERT_EQ(pk.data().dyn_array().size(), pk2.data().dyn_array().size()); - for (size_t i = 0; i < pk.data().dyn_array().size(); i++) - { - ASSERT_EQ(pk.data().data()[i], pk2.data().data()[i]); + ASSERT_EQ(pk.data().dyn_array().size(), pk2.data().dyn_array().size()); + for (size_t i = 0; i < pk.data().dyn_array().size(); i++) + { + ASSERT_EQ(pk.data().data()[i], pk2.data().data()[i]); + } + ASSERT_TRUE(pk.parms_id() == pk2.parms_id()); } - ASSERT_TRUE(pk.parms_id() == pk2.parms_id()); - } + }; + + save_load_public_key(scheme_type::bfv); + save_load_public_key(scheme_type::bgv); } } // namespace sealtest diff --git a/native/tests/seal/relinkeys.cpp b/native/tests/seal/relinkeys.cpp index c939eefa4..42105f95b 100644 --- a/native/tests/seal/relinkeys.cpp +++ b/native/tests/seal/relinkeys.cpp @@ -17,177 +17,186 @@ namespace sealtest { TEST(RelinKeysTest, RelinKeysSaveLoad) { - stringstream stream; - { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(1 << 6); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60 })); - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - RelinKeys keys; - RelinKeys test_keys; - keygen.create_relin_keys(keys); - keys.save(stream); - test_keys.load(context, stream); - ASSERT_EQ(keys.size(), test_keys.size()); - ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); - for (size_t j = 0; j < test_keys.size(); j++) + auto relin_keys_save_load = [](scheme_type scheme){ + stringstream stream; { - for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(1 << 6); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60 })); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + + RelinKeys keys; + RelinKeys test_keys; + keygen.create_relin_keys(keys); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + for (size_t j = 0; j < test_keys.size(); j++) { - ASSERT_EQ(keys.key(j + 2)[i].data().size(), test_keys.key(j + 2)[i].data().size()); - ASSERT_EQ( - keys.key(j + 2)[i].data().dyn_array().size(), - test_keys.key(j + 2)[i].data().dyn_array().size()); - ASSERT_TRUE(is_equal_uint( - keys.key(j + 2)[i].data().data(), test_keys.key(j + 2)[i].data().data(), - keys.key(j + 2)[i].data().dyn_array().size())); + for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) + { + ASSERT_EQ(keys.key(j + 2)[i].data().size(), test_keys.key(j + 2)[i].data().size()); + ASSERT_EQ( + keys.key(j + 2)[i].data().dyn_array().size(), + test_keys.key(j + 2)[i].data().dyn_array().size()); + ASSERT_TRUE(is_equal_uint( + keys.key(j + 2)[i].data().data(), test_keys.key(j + 2)[i].data().data(), + keys.key(j + 2)[i].data().dyn_array().size())); + } } } - } - { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(256); - parms.set_plain_modulus(1 << 6); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 50 })); + { + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(256); + parms.set_plain_modulus(1 << 6); + parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 50 })); - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); - RelinKeys keys; - RelinKeys test_keys; - keygen.create_relin_keys(keys); - keys.save(stream); - test_keys.load(context, stream); - ASSERT_EQ(keys.size(), test_keys.size()); - ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); - for (size_t j = 0; j < test_keys.size(); j++) - { - for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) + RelinKeys keys; + RelinKeys test_keys; + keygen.create_relin_keys(keys); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + for (size_t j = 0; j < test_keys.size(); j++) { - ASSERT_EQ(keys.key(j + 2)[i].data().size(), test_keys.key(j + 2)[i].data().size()); - ASSERT_EQ( - keys.key(j + 2)[i].data().dyn_array().size(), - test_keys.key(j + 2)[i].data().dyn_array().size()); - ASSERT_TRUE(is_equal_uint( - keys.key(j + 2)[i].data().data(), test_keys.key(j + 2)[i].data().data(), - keys.key(j + 2)[i].data().dyn_array().size())); + for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) + { + ASSERT_EQ(keys.key(j + 2)[i].data().size(), test_keys.key(j + 2)[i].data().size()); + ASSERT_EQ( + keys.key(j + 2)[i].data().dyn_array().size(), + test_keys.key(j + 2)[i].data().dyn_array().size()); + ASSERT_TRUE(is_equal_uint( + keys.key(j + 2)[i].data().data(), test_keys.key(j + 2)[i].data().data(), + keys.key(j + 2)[i].data().dyn_array().size())); + } } } - } + }; + + relin_keys_save_load(scheme_type::bfv); + relin_keys_save_load(scheme_type::bgv); } TEST(RelinKeysTest, RelinKeysSeededSaveLoad) { - // Returns true if a, b contains the same error. - auto compare_kswitchkeys = [](const KSwitchKeys &a, const KSwitchKeys &b, const SecretKey &sk, - const SEALContext &context) { - auto compare_error = [](const Ciphertext &a_ct, const Ciphertext &b_ct, const SecretKey &sk1, - const SEALContext &context1) { - auto get_error = [](const Ciphertext &encrypted, const SecretKey &sk2, const SEALContext &context2) { - auto pool = MemoryManager::GetPool(); - auto &context_data = *context2.get_context_data(encrypted.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_modulus_size = coeff_modulus.size(); - size_t rns_poly_uint64_count = util::mul_safe(coeff_count, coeff_modulus_size); + auto relin_keys_seeded_save_load = [](scheme_type scheme){ + // Returns true if a, b contains the same error. + auto compare_kswitchkeys = [](const KSwitchKeys &a, const KSwitchKeys &b, const SecretKey &sk, + const SEALContext &context) { + auto compare_error = [](const Ciphertext &a_ct, const Ciphertext &b_ct, const SecretKey &sk1, + const SEALContext &context1) { + auto get_error = [](const Ciphertext &encrypted, const SecretKey &sk2, const SEALContext &context2) { + auto pool = MemoryManager::GetPool(); + auto &context_data = *context2.get_context_data(encrypted.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_modulus_size = coeff_modulus.size(); + size_t rns_poly_uint64_count = util::mul_safe(coeff_count, coeff_modulus_size); - DynArray error; - error.resize(rns_poly_uint64_count); - auto destination = error.begin(); + DynArray error; + error.resize(rns_poly_uint64_count); + auto destination = error.begin(); - auto copy_operand1(util::allocate_uint(coeff_count, pool)); - for (size_t i = 0; i < coeff_modulus_size; i++) - { - // Initialize pointers for multiplication - const uint64_t *encrypted_ptr = encrypted.data(1) + (i * coeff_count); - const uint64_t *secret_key_ptr = sk2.data().data() + (i * coeff_count); - uint64_t *destination_ptr = destination + (i * coeff_count); - util::set_zero_uint(coeff_count, destination_ptr); - util::set_uint(encrypted_ptr, coeff_count, copy_operand1.get()); - // compute c_{j+1} * s^{j+1} - util::dyadic_product_coeffmod( - copy_operand1.get(), secret_key_ptr, coeff_count, coeff_modulus[i], copy_operand1.get()); - // add c_{j+1} * s^{j+1} to destination - util::add_poly_coeffmod( - destination_ptr, copy_operand1.get(), coeff_count, coeff_modulus[i], destination_ptr); - // add c_0 into destination - util::add_poly_coeffmod( - destination_ptr, encrypted.data() + (i * coeff_count), coeff_count, coeff_modulus[i], - destination_ptr); - } - return error; + auto copy_operand1(util::allocate_uint(coeff_count, pool)); + for (size_t i = 0; i < coeff_modulus_size; i++) + { + // Initialize pointers for multiplication + const uint64_t *encrypted_ptr = encrypted.data(1) + (i * coeff_count); + const uint64_t *secret_key_ptr = sk2.data().data() + (i * coeff_count); + uint64_t *destination_ptr = destination + (i * coeff_count); + util::set_zero_uint(coeff_count, destination_ptr); + util::set_uint(encrypted_ptr, coeff_count, copy_operand1.get()); + // compute c_{j+1} * s^{j+1} + util::dyadic_product_coeffmod( + copy_operand1.get(), secret_key_ptr, coeff_count, coeff_modulus[i], copy_operand1.get()); + // add c_{j+1} * s^{j+1} to destination + util::add_poly_coeffmod( + destination_ptr, copy_operand1.get(), coeff_count, coeff_modulus[i], destination_ptr); + // add c_0 into destination + util::add_poly_coeffmod( + destination_ptr, encrypted.data() + (i * coeff_count), coeff_count, coeff_modulus[i], + destination_ptr); + } + return error; + }; + + auto error_a = get_error(a_ct, sk1, context1); + auto error_b = get_error(b_ct, sk1, context1); + ASSERT_EQ(error_a.size(), error_b.size()); + ASSERT_TRUE(is_equal_uint(error_a.cbegin(), error_b.cbegin(), error_a.size())); }; - auto error_a = get_error(a_ct, sk1, context1); - auto error_b = get_error(b_ct, sk1, context1); - ASSERT_EQ(error_a.size(), error_b.size()); - ASSERT_TRUE(is_equal_uint(error_a.cbegin(), error_b.cbegin(), error_a.size())); + ASSERT_EQ(a.size(), b.size()); + auto iter_a = a.data().begin(); + auto iter_b = b.data().begin(); + for (; iter_a != a.data().end(); iter_a++, iter_b++) + { + ASSERT_EQ(iter_a->size(), iter_b->size()); + auto pk_a = iter_a->begin(); + auto pk_b = iter_b->begin(); + for (; pk_a != iter_a->end(); pk_a++, pk_b++) + { + compare_error(pk_a->data(), pk_b->data(), sk, context); + } + } }; - ASSERT_EQ(a.size(), b.size()); - auto iter_a = a.data().begin(); - auto iter_b = b.data().begin(); - for (; iter_a != a.data().end(); iter_a++, iter_b++) + stringstream stream; { - ASSERT_EQ(iter_a->size(), iter_b->size()); - auto pk_a = iter_a->begin(); - auto pk_b = iter_b->begin(); - for (; pk_a != iter_a->end(); pk_a++, pk_b++) + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(8); + parms.set_plain_modulus(65537); + parms.set_coeff_modulus(CoeffModulus::Create(8, { 60, 60 })); + prng_seed_type seed; + for (auto &i : seed) { - compare_error(pk_a->data(), pk_b->data(), sk, context); + i = random_uint64(); } - } - }; + auto rng = make_shared(Blake2xbPRNGFactory(seed)); + parms.set_random_generator(rng); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + SecretKey secret_key = keygen.secret_key(); - stringstream stream; - { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(8); - parms.set_plain_modulus(65537); - parms.set_coeff_modulus(CoeffModulus::Create(8, { 60, 60 })); - prng_seed_type seed; - for (auto &i : seed) - { - i = random_uint64(); + keygen.create_relin_keys().save(stream); + RelinKeys test_keys; + test_keys.load(context, stream); + RelinKeys keys; + keygen.create_relin_keys(keys); + compare_kswitchkeys(keys, test_keys, secret_key, context); } - auto rng = make_shared(Blake2xbPRNGFactory(seed)); - parms.set_random_generator(rng); - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - SecretKey secret_key = keygen.secret_key(); - - keygen.create_relin_keys().save(stream); - RelinKeys test_keys; - test_keys.load(context, stream); - RelinKeys keys; - keygen.create_relin_keys(keys); - compare_kswitchkeys(keys, test_keys, secret_key, context); - } - { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(256); - parms.set_plain_modulus(65537); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 50 })); - prng_seed_type seed; - for (auto &i : seed) { - i = random_uint64(); - } - auto rng = make_shared(Blake2xbPRNGFactory(seed)); - parms.set_random_generator(rng); - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - SecretKey secret_key = keygen.secret_key(); + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(256); + parms.set_plain_modulus(65537); + parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 50 })); + prng_seed_type seed; + for (auto &i : seed) + { + i = random_uint64(); + } + auto rng = make_shared(Blake2xbPRNGFactory(seed)); + parms.set_random_generator(rng); + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + SecretKey secret_key = keygen.secret_key(); - keygen.create_relin_keys().save(stream); - RelinKeys test_keys; - test_keys.load(context, stream); - RelinKeys keys; - keygen.create_relin_keys(keys); - compare_kswitchkeys(keys, test_keys, secret_key, context); - } + keygen.create_relin_keys().save(stream); + RelinKeys test_keys; + test_keys.load(context, stream); + RelinKeys keys; + keygen.create_relin_keys(keys); + compare_kswitchkeys(keys, test_keys, secret_key, context); + } + }; + relin_keys_seeded_save_load(scheme_type::bfv); + relin_keys_seeded_save_load(scheme_type::bgv); } } // namespace sealtest diff --git a/native/tests/seal/secretkey.cpp b/native/tests/seal/secretkey.cpp index 8b8b653d1..c58f90ac6 100644 --- a/native/tests/seal/secretkey.cpp +++ b/native/tests/seal/secretkey.cpp @@ -14,44 +14,49 @@ namespace sealtest { TEST(SecretKeyTest, SaveLoadSecretKey) { - stringstream stream; - { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(1 << 6); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); - - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - SecretKey sk = keygen.secret_key(); - ASSERT_TRUE(sk.parms_id() == context.key_parms_id()); - sk.save(stream); - - SecretKey sk2; - sk2.load(context, stream); - - ASSERT_TRUE(sk.data() == sk2.data()); - ASSERT_TRUE(sk.parms_id() == sk2.parms_id()); - } - { - EncryptionParameters parms(scheme_type::bfv); - parms.set_poly_modulus_degree(256); - parms.set_plain_modulus(1 << 20); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 30, 40 })); - - SEALContext context(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - SecretKey sk = keygen.secret_key(); - ASSERT_TRUE(sk.parms_id() == context.key_parms_id()); - sk.save(stream); - - SecretKey sk2; - sk2.load(context, stream); - - ASSERT_TRUE(sk.data() == sk2.data()); - ASSERT_TRUE(sk.parms_id() == sk2.parms_id()); - } + auto save_load_secret_key = [](scheme_type scheme){ + stringstream stream; + { + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(1 << 6); + parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + + SecretKey sk = keygen.secret_key(); + ASSERT_TRUE(sk.parms_id() == context.key_parms_id()); + sk.save(stream); + + SecretKey sk2; + sk2.load(context, stream); + + ASSERT_TRUE(sk.data() == sk2.data()); + ASSERT_TRUE(sk.parms_id() == sk2.parms_id()); + } + { + EncryptionParameters parms(scheme); + parms.set_poly_modulus_degree(256); + parms.set_plain_modulus(1 << 20); + parms.set_coeff_modulus(CoeffModulus::Create(256, { 30, 40 })); + + SEALContext context(parms, false, sec_level_type::none); + KeyGenerator keygen(context); + + SecretKey sk = keygen.secret_key(); + ASSERT_TRUE(sk.parms_id() == context.key_parms_id()); + sk.save(stream); + + SecretKey sk2; + sk2.load(context, stream); + + ASSERT_TRUE(sk.data() == sk2.data()); + ASSERT_TRUE(sk.parms_id() == sk2.parms_id()); + } + }; + + save_load_secret_key(scheme_type::bfv); + save_load_secret_key(scheme_type::bgv); } } // namespace sealtest