diff --git a/.github/workflows/mldsa-hax.yml b/.github/workflows/mldsa-hax.yml index eba2902ba..4cfe45079 100644 --- a/.github/workflows/mldsa-hax.yml +++ b/.github/workflows/mldsa-hax.yml @@ -34,11 +34,11 @@ jobs: - uses: hacspec/hax-actions@main with: hax_reference: ${{ github.event.inputs.hax_rev || 'main' }} - fstar: v2025.01.17 + fstar: v2025.03.25 - name: 🏃 Extract ML-DSA crate working-directory: libcrux-ml-dsa - run: ./hax.py extract + run: ./hax.sh extract - name: ↑ Upload F* extraction uses: actions/upload-artifact@v4 @@ -58,7 +58,7 @@ jobs: - uses: hacspec/hax-actions@main with: hax_reference: ${{ github.event.inputs.hax_rev || 'main' }} - fstar: v2025.01.17 + fstar: v2025.03.25 - uses: actions/download-artifact@v4 with: @@ -67,7 +67,7 @@ jobs: - name: 🏃 Lax ML-DSA crate working-directory: libcrux-ml-dsa - run: ./hax.py prove --admit + run: ./hax.sh prove --admit mldsa-extract-hax-status: if: ${{ always() }} diff --git a/fstar-helpers/minicore/Cargo.toml b/fstar-helpers/minicore/Cargo.toml index cd29bd6d5..881d6c859 100644 --- a/fstar-helpers/minicore/Cargo.toml +++ b/fstar-helpers/minicore/Cargo.toml @@ -1,7 +1,14 @@ [package] name = "minicore" -edition = "2021" +version.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +edition.workspace = true +repository.workspace = true +readme.workspace = true publish = false [dependencies] rand = "0.9" +hax-lib.workspace = true diff --git a/fstar-helpers/minicore/src/abstractions/bit.rs b/fstar-helpers/minicore/src/abstractions/bit.rs index 7287c3585..66aee1aed 100644 --- a/fstar-helpers/minicore/src/abstractions/bit.rs +++ b/fstar-helpers/minicore/src/abstractions/bit.rs @@ -68,9 +68,12 @@ impl From for Bit { } /// A trait for types that represent machine integers. +#[hax_lib::attributes] pub trait MachineInteger { /// The size of this integer type in bits. 
- const BITS: u32; + #[hax_lib::requires(true)] + #[hax_lib::ensures(|bits| bits >= 8)] + fn bits() -> u32; /// The signedness of this integer type. const SIGNED: bool; @@ -78,8 +81,8 @@ pub trait MachineInteger { macro_rules! generate_machine_integer_impls { ($($ty:ident),*) => { - $(impl MachineInteger for $ty { - const BITS: u32 = $ty::BITS; + $(#[hax_lib::exclude]impl MachineInteger for $ty { + fn bits() -> u32 { $ty::BITS } #[allow(unused_comparisons)] const SIGNED: bool = $ty::MIN < 0; })* @@ -87,6 +90,18 @@ macro_rules! generate_machine_integer_impls { } generate_machine_integer_impls!(u8, u16, u32, u64, u128, i8, i16, i32, i64, i128); +#[hax_lib::fstar::replace( + r" +instance impl_MachineInteger_poly (t: inttype): t_MachineInteger (int_t t) = + { f_bits = (fun () -> mk_u32 (bits t)); + f_bits_pre = (fun () -> True); + f_bits_post = (fun () r -> r == mk_u32 (bits t)); + f_SIGNED = signed t } +" +)] +const _: () = {}; + +#[hax_lib::exclude] impl Bit { fn of_raw_int(x: u128, nth: u32) -> Self { if x / 2u128.pow(nth) % 2 == 1 { @@ -101,7 +116,7 @@ impl Bit { if x >= 0 { Self::of_raw_int(x as u128, nth) } else { - Self::of_raw_int((2i128.pow(T::BITS) + x) as u128, nth) + Self::of_raw_int((2i128.pow(T::bits()) + x) as u128, nth) } } } diff --git a/fstar-helpers/minicore/src/abstractions/bitvec.rs b/fstar-helpers/minicore/src/abstractions/bitvec.rs index ad072741a..0bbc4ba79 100644 --- a/fstar-helpers/minicore/src/abstractions/bitvec.rs +++ b/fstar-helpers/minicore/src/abstractions/bitvec.rs @@ -1,10 +1,15 @@ //! This module provides a specification-friendly bit vector type. use super::bit::{Bit, MachineInteger}; - -// TODO: this module uses `u128/i128` as mathematic integers. We should use `hax_lib::int` or bigint. +use super::funarr::*; use std::fmt::Formatter; +// This is required due to some hax-lib inconsistencies with versus without `cfg(hax)`. +#[cfg(hax)] +use hax_lib::{int, ToInt}; + +// TODO: this module uses `u128/i128` as mathematic integers. 
We should use `hax_lib::int` or bigint. + /// A fixed-size bit vector type. /// /// `BitVec` is a specification-friendly, fixed-length bit vector that internally @@ -15,12 +20,14 @@ use std::fmt::Formatter; /// The [`Debug`] implementation for `BitVec` pretty-prints the bits in groups of eight, /// making the bit pattern more human-readable. The type also implements indexing, /// allowing for easy access to individual bits. +#[hax_lib::fstar::before("noeq")] #[derive(Copy, Clone, Eq, PartialEq)] -pub struct BitVec([Bit; N]); +pub struct BitVec(FunArray); /// Pretty prints a bit slice by group of 8 +#[hax_lib::exclude] fn bit_slice_to_string(bits: &[Bit]) -> String { - bits.into_iter() + bits.iter() .map(|bit| match bit { Bit::Zero => '0', Bit::One => '1', @@ -34,33 +41,38 @@ fn bit_slice_to_string(bits: &[Bit]) -> String { .into() } -impl core::fmt::Debug for BitVec { +#[hax_lib::exclude] +impl core::fmt::Debug for BitVec { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - write!(f, "{}", bit_slice_to_string(&self.0)) + write!(f, "{}", bit_slice_to_string(&self.0.as_vec())) } } -impl core::ops::Index for BitVec { +#[hax_lib::attributes] +impl core::ops::Index for BitVec { type Output = Bit; - fn index(&self, index: usize) -> &Self::Output { - &self.0[index] + #[requires(index < N)] + fn index(&self, index: u64) -> &Self::Output { + self.0.get(index) } } /// Convert a bit slice into an unsigned number. +#[hax_lib::exclude] fn u64_int_from_bit_slice(bits: &[Bit]) -> u64 { - bits.into_iter() + bits.iter() .enumerate() .map(|(i, bit)| u64::from(bit.clone()) << i) .sum::() } /// Convert a bit slice into a machine integer of type `T`. 
+#[hax_lib::exclude] fn int_from_bit_slice + MachineInteger + Copy>(bits: &[Bit]) -> T { - debug_assert!(bits.len() <= T::BITS as usize); + debug_assert!(bits.len() <= T::bits() as usize); let result = if T::SIGNED { - let is_negative = matches!(bits[T::BITS as usize - 1], Bit::One); - let s = u64_int_from_bit_slice(&bits[0..T::BITS as usize - 1]) as i128; + let is_negative = matches!(bits[T::bits() as usize - 1], Bit::One); + let s = u64_int_from_bit_slice(&bits[0..T::bits() as usize - 1]) as i128; if is_negative { -s } else { @@ -76,31 +88,126 @@ fn int_from_bit_slice + MachineInteger + Copy>(bits: &[Bit]) -> n } -impl BitVec { - /// Constructor for BitVec. `BitVec::::from_fn` constructs a bitvector out of a function that takes usizes smaller than `N` and produces bits. - pub fn from_fn Bit>(f: F) -> Self { - Self(core::array::from_fn(f)) +/// An F* attribute that indiquates a rewritting lemma should be applied +pub const REWRITE_RULE: () = {}; + +#[hax_lib::fstar::replace( + r#" +let ${BitVec::<0>::from_fn::Bit>} + (v_N: u64) + (f: (i: u64 {v i < v v_N}) -> $:{Bit}) + : t_BitVec v_N = + ${BitVec::<0>}(${FunArray::<0,()>::from_fn::()>} v_N f) + +open FStar.FunctionalExtensionality +let ${BitVec::<0>::pointwise} + (v_N: u64) (f: t_BitVec v_N) + (#[${_pointwise_apply_mk_term} (v v_N) (fun (i:nat{i < v v_N}) -> f._0 (mk_u64 i))] def: (n: nat {n < v v_N}) -> $:{Bit}) + : t_BitVec v_N + = ${BitVec::<0>::from_fn::Bit>} v_N (on (i: u64 {v i < v v_N}) (fun i -> def (v i))) + +let extensionality' (#a: Type) (#b: Type) (f g: FStar.FunctionalExtensionality.(a ^-> b)) + : Lemma (ensures (FStar.FunctionalExtensionality.feq f g <==> f == g)) + = () + +open FStar.Tactics.V2 +#push-options "--z3rlimit 80 --admit_smt_queries true" +let ${BitVec::<128>::rewrite_pointwise} (x: $:{BitVec<128>}) +: Lemma (x == ${BitVec::<128>::pointwise} (${128u64}) x) = + let a = x._0 in + let b = (${BitVec::<128>::pointwise} (${128u64}) x)._0 in + assert_norm (FStar.FunctionalExtensionality.feq a 
b); + extensionality' a b + +let ${BitVec::<256>::rewrite_pointwise} (x: $:{BitVec<256>}) +: Lemma (x == ${BitVec::<256>::pointwise} (${256u64}) x) = + let a = x._0 in + let b = (${BitVec::<256>::pointwise} (${256u64}) x)._0 in + assert_norm (FStar.FunctionalExtensionality.feq a b); + extensionality' a b +#pop-options + +let postprocess_rewrite_helper (rw_lemma: term) (): Tac unit = with_compat_pre_core 1 (fun () -> + let debug_mode = ext_enabled "debug_bv_postprocess_rewrite" in + let crate = match cur_module () with | crate::_ -> crate | _ -> fail "Empty module name" in + // Remove indirections + norm [primops; iota; delta_namespace [crate; "Libcrux_intrinsics"]; zeta_full]; + // Rewrite call chains + let lemmas = FStar.List.Tot.map (fun f -> pack_ln (FStar.Stubs.Reflection.V2.Data.Tv_FVar f)) (lookup_attr (`${REWRITE_RULE}) (top_env ())) in + l_to_r lemmas; + /// Get rid of casts + norm [primops; iota; delta_namespace ["Rust_primitives"; "Prims.pow2"]; zeta_full]; + if debug_mode then print ("[postprocess_rewrite_helper] lemmas = " ^ term_to_string (quote lemmas)); + if debug_mode then dump "[postprocess_rewrite_helper] After applying lemmas"; + // Apply pointwise rw + let done = alloc false in + ctrl_rewrite TopDown (fun _ -> if read done then (false, Skip) else (true, Continue)) + (fun _ -> (fun () -> apply_lemma_rw rw_lemma; write done true) + `or_else` trefl); + // Normalize as much as possible + norm [primops; iota; delta_namespace ["Core"; crate; "Minicore"; "Libcrux_intrinsics"; "FStar.FunctionalExtensionality"; "Rust_primitives"]; zeta_full]; + // Compute the last bits + compute (); + // Force full normalization + norm [primops; iota; delta; zeta_full]; + if debug_mode then dump "[postprocess_rewrite_helper] after full normalization"; + // Solves the goal ` == ?u` + trefl () +) + +let ${BitVec::<256>::postprocess_rewrite} = postprocess_rewrite_helper (`${BitVec::<256>::rewrite_pointwise}) +let ${BitVec::<128>::postprocess_rewrite} = 
postprocess_rewrite_helper (`${BitVec::<128>::rewrite_pointwise}) +"# +)] +const _: () = (); + +#[hax_lib::fstar::replace( + r#" +"# +)] +pub fn postprocess_normalize_128() {} + +#[hax_lib::exclude] +impl BitVec<128> { + pub fn rewrite_pointwise(self) {} + pub fn postprocess_rewrite() {} +} +#[hax_lib::exclude] +impl BitVec<256> { + pub fn rewrite_pointwise(self) {} + pub fn postprocess_rewrite() {} +} + +#[hax_lib::exclude] +impl BitVec { + pub fn pointwise(self) -> Self { + self } + /// Constructor for BitVec. `BitVec::::from_fn` constructs a bitvector out of a function that takes usizes smaller than `N` and produces bits. + pub fn from_fn Bit>(f: F) -> Self { + Self(FunArray::from_fn(f)) + } /// Convert a slice of machine integers where only the `d` least significant bits are relevant. - pub fn from_slice + MachineInteger + Copy>(x: &[T], d: usize) -> Self { - Self::from_fn(|i| Bit::of_int(x[i / d], (i % d) as u32)) + pub fn from_slice + MachineInteger + Copy>(x: &[T], d: u64) -> Self { + Self::from_fn(|i| Bit::of_int(x[(i / d) as usize], (i % d) as u32)) } /// Construct a BitVec out of a machine integer. pub fn from_int + MachineInteger + Copy>(n: T) -> Self { - Self::from_slice(&[n.into()], T::BITS as usize) + Self::from_slice(&[n.into()], T::bits() as u64) } /// Convert a BitVec into a machine integer of type `T`. pub fn to_int + MachineInteger + Copy>(self) -> T { - int_from_bit_slice(&self.0) + int_from_bit_slice(&self.0.as_vec()) } /// Convert a BitVec into a vector of machine integers of type `T`. pub fn to_vec + MachineInteger + Copy>(&self) -> Vec { self.0 - .chunks(T::BITS as usize) + .as_vec() + .chunks(T::bits() as usize) .map(int_from_bit_slice) .collect() } @@ -108,7 +215,54 @@ impl BitVec { /// Generate a random BitVec. 
pub fn rand() -> Self { use rand::prelude::*; - let mut rng = rand::rng(); - Self::from_fn(|_| rng.random::().into()) + let random_source: Vec<_> = { + let mut rng = rand::rng(); + (0..N).map(|_| rng.random::()).collect() + }; + Self::from_fn(|i| random_source[i as usize].into()) + } +} + +#[hax_lib::attributes] +impl BitVec { + #[hax_lib::requires(CHUNK > 0 && CHUNK.to_int() * SHIFTS.to_int() == N.to_int())] + pub fn chunked_shift( + self, + shl: FunArray, + ) -> BitVec { + // TODO: this inner method is because of https://github.com/cryspen/hax-evit/issues/29 + #[hax_lib::fstar::options("--z3rlimit 50 --split_queries always")] + #[hax_lib::requires(CHUNK > 0 && CHUNK.to_int() * SHIFTS.to_int() == N.to_int())] + fn chunked_shift( + bitvec: BitVec, + shl: FunArray, + ) -> BitVec { + BitVec::from_fn(|i| { + let nth_bit = i % CHUNK; + let nth_chunk = i / CHUNK; + hax_lib::assert_prop!(nth_chunk.to_int() <= SHIFTS.to_int() - int!(1)); + hax_lib::assert_prop!( + nth_chunk.to_int() * CHUNK.to_int() + <= (SHIFTS.to_int() - int!(1)) * CHUNK.to_int() + ); + let shift: i128 = if nth_chunk < SHIFTS { + shl[nth_chunk] + } else { + 0 + }; + let local_index = (nth_bit as i128).wrapping_sub(shift); + if local_index < CHUNK as i128 && local_index >= 0 { + let local_index = local_index as u64; + hax_lib::assert_prop!( + nth_chunk.to_int() * CHUNK.to_int() + local_index.to_int() + < SHIFTS.to_int() * CHUNK.to_int() + ); + bitvec[nth_chunk * CHUNK + local_index] + } else { + Bit::Zero + } + }) + } + chunked_shift::(self, shl) } } diff --git a/fstar-helpers/minicore/src/abstractions/funarr.rs b/fstar-helpers/minicore/src/abstractions/funarr.rs new file mode 100644 index 000000000..5dd13825e --- /dev/null +++ b/fstar-helpers/minicore/src/abstractions/funarr.rs @@ -0,0 +1,149 @@ +/// A fixed-size array wrapper with functional semantics and F* integration. +/// +/// `FunArray` represents an array of `T` values of length `N`, where `N` is a compile-time constant. 
+/// Internally, it uses a fixed-length array of `Option` with a maximum capacity of 512 elements. +/// Unused elements beyond `N` are filled with `None`. +/// +/// This type is integrated with F* through various `#[hax_lib::fstar::replace]` attributes to support +/// formal verification workflows. + +/// Internal helper for generating pointwise applications in F*. +/// This replaces a functional array `arr` by `fun i -> match arr with | 0 -> arr 0 | ... | (N-1) -> arr (N-1)`. +/// Replaced by F* tactic code to generate match-based function applications over bounded natural numbers. +#[hax_lib::fstar::replace( + r#" +open FStar.Tactics + +let ${_pointwise_apply_mk_term} #t + (max: nat) + (f: (n:nat {n < max}) -> t) + : Tac unit + = let rec brs (n:int): Tac _ = + if n < 0 then [] + else + let c = C_Int n in + let p = Pat_Constant c in + (p, mk_e_app (quote f) [pack (Tv_Const c)])::brs (n - 1) + in + let bd = fresh_binder_named "i" (quote (m: nat {m < max})) in + let t = mk_abs [bd] (Tv_Match bd None (brs (max - 1))) in + exact t"# +)] +pub fn _pointwise_apply_mk_term() {} + +#[hax_lib::fstar::replace( + r#" +open FStar.FunctionalExtensionality +type t_FunArray (n: u64) (t: Type0) = i:u64 {v i < v n} ^-> t + +let pointwise_apply + (v_N: u64) (#v_T: Type0) (f: t_FunArray v_N v_T) + (#[${_pointwise_apply_mk_term} (v v_N) (fun (i:nat{i < v v_N}) -> f (mk_u64 i))] def: (n: nat {n < v v_N}) -> v_T) + : t_FunArray v_N v_T + = on (i: u64 {v i < v v_N}) (fun i -> def (v i)) + +let ${FunArray::<0, ()>::get} (v_N: u64) (#v_T: Type0) (self: t_FunArray v_N v_T) (i: u64 {v i < v v_N}) : v_T = + self i + +let ${FunArray::<0, ()>::from_fn::()>} + (v_N: u64) + (#v_T: Type0) + (f: (i: u64 {v i < v v_N}) -> v_T) + : t_FunArray v_N v_T = on (i: u64 {v i < v v_N}) f + +let ${FunArray::<0, ()>::as_vec} n #t (self: t_FunArray n t) = FStar.Seq.init (v n) (fun i -> self (mk_u64 i)) + +let rec ${FunArray::<0, ()>::fold::<()>} n #t #a (arr: t_FunArray n t) (init: a) (f: a -> t -> a): Tot a 
(decreases (v n)) = + match n with + | MkInt 0 -> init + | MkInt n -> + let acc: a = f init (arr (mk_u64 0)) in + let n = MkInt (n - 1) in + ${FunArray::<0, ()>::fold::<()>} n #t #a + (${FunArray::<0, ()>::from_fn::()>} n (fun i -> arr (i +. mk_u64 1))) + acc f +"# +)] +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct FunArray([Option; 512]); + +#[hax_lib::exclude] +impl FunArray { + /// In F*, this returns the array, applied pointwise. + /// In Rust, this is identity. + pub fn pointwise_apply(self) -> FunArray { + self + } + + /// Gets a reference to the element at index `i`. + pub fn get(&self, i: u64) -> &T { + self.0[i as usize].as_ref().unwrap() + } + /// Constructor for BitVec. `BitVec::::from_fn` constructs a bitvector out of a function that takes usizes smaller than `N` and produces bits. + pub fn from_fn T>(f: F) -> Self { + // let vec = (0..N).map(f).collect(); + let arr = core::array::from_fn(|i| { + if (i as u64) < N { + Some(f(i as u64)) + } else { + None + } + }); + Self(arr) + } + + /// Converts the `FunArray` into a `Vec`. + pub fn as_vec(&self) -> Vec + where + T: Clone, + { + self.0[0..(N as usize)] + .iter() + .cloned() + .map(|x| x.unwrap()) + .collect() + } + + /// Folds over the array, accumulating a result. + /// + /// # Arguments + /// * `init` - The initial value of the accumulator. + /// * `f` - A function combining the accumulator and each element. 
+ pub fn fold(&self, mut init: A, f: fn(A, T) -> A) -> A + where + T: Clone, + { + for i in 0..N { + init = f(init, self[i].clone()); + } + init + } +} + +#[hax_lib::exclude] +impl TryFrom> for FunArray { + type Error = (); + fn try_from(v: Vec) -> Result { + if (v.len() as u64) < N { + Err(()) + } else { + Ok(Self::from_fn(|i| v[i as usize].clone())) + } + } +} + +#[hax_lib::exclude] +impl core::fmt::Debug for FunArray { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{:?}", self.as_vec()) + } +} + +#[hax_lib::attributes] +impl core::ops::Index for FunArray { + type Output = T; + #[requires(index < N)] + fn index(&self, index: u64) -> &Self::Output { + self.get(index) + } +} diff --git a/fstar-helpers/minicore/src/abstractions/mod.rs b/fstar-helpers/minicore/src/abstractions/mod.rs index 481faf956..b9bd0b42b 100644 --- a/fstar-helpers/minicore/src/abstractions/mod.rs +++ b/fstar-helpers/minicore/src/abstractions/mod.rs @@ -22,3 +22,4 @@ pub mod bit; pub mod bitvec; +pub mod funarr; diff --git a/fstar-helpers/minicore/src/arch/x86.rs b/fstar-helpers/minicore/src/arch/x86.rs deleted file mode 100644 index be023f9f3..000000000 --- a/fstar-helpers/minicore/src/arch/x86.rs +++ /dev/null @@ -1,298 +0,0 @@ -//! A (partial) Rust-based model of [`core::arch::x86`] and [`core::arch::x86_64`]. -//! -//! This module provides a purely Rust implementation of selected operations from -//! `core::arch::x86` and `core::arch::x86_64`. - -use crate::abstractions::{bit::*, bitvec::*}; - -pub(crate) mod upstream { - #[cfg(target_arch = "x86")] - pub use core::arch::x86::*; - #[cfg(target_arch = "x86_64")] - pub use core::arch::x86_64::*; -} - -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -/// Conversions impls between `BitVec` and `__mNi` types. 
-mod conversions { - use super::upstream::{ - __m128i, __m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm_loadu_si128, - _mm_storeu_si128, - }; - use super::BitVec; - - impl From> for __m256i { - fn from(bv: BitVec<256>) -> __m256i { - let bv: &[u8] = &bv.to_vec()[..]; - unsafe { _mm256_loadu_si256(bv.as_ptr() as *const _) } - } - } - - impl From> for __m128i { - fn from(bv: BitVec<128>) -> __m128i { - let slice: &[u8] = &bv.to_vec()[..]; - unsafe { _mm_loadu_si128(slice.as_ptr() as *const __m128i) } - } - } - - impl From<__m256i> for BitVec<256> { - fn from(vec: __m256i) -> BitVec<256> { - let mut v = [0u8; 32]; - unsafe { - _mm256_storeu_si256(v.as_mut_ptr() as *mut _, vec); - } - BitVec::from_slice(&v[..], 8) - } - } - - impl From<__m128i> for BitVec<128> { - fn from(vec: __m128i) -> BitVec<128> { - let mut v = [0u8; 16]; - unsafe { - _mm_storeu_si128(v.as_mut_ptr() as *mut _, vec); - } - BitVec::from_slice(&v[..], 8) - } - } -} - -/// 256-bit wide integer vector type. -/// Models `core::arch::x86::__m256i` or `core::arch::x86_64::__m256i` (the __m256i type defined by Intel, representing a 256-bit SIMD register). -#[allow(non_camel_case_types)] -type __m256i = BitVec<256>; - -/// 128-bit wide integer vector type. -/// Models `core::arch::x86::__m128i` or `core::arch::x86_64::__m128i` (the __m128i type defined by Intel, representing a 128-bit SIMD register). 
-#[allow(non_camel_case_types)] -type __m128i = BitVec<128>; - -pub fn _mm_storeu_si128(output: *mut __m128i, a: __m128i) { - // This is equivalent to `*output = a` - let mut out = [0u8; 128]; - extra::mm_storeu_bytes_si128(&mut out, a); - unsafe { - *(output.as_mut().unwrap()) = BitVec::from_slice(&mut out, 8); - } -} - -pub fn _mm256_slli_epi16(vector: __m256i) -> __m256i { - debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16); - BitVec::from_fn(|i| { - let nth_bit = i % 16; - let shift = SHIFT_BY as usize; - if nth_bit >= shift { - vector[i - shift] - } else { - Bit::Zero - } - }) -} - -pub fn _mm256_srli_epi64(vector: __m256i) -> __m256i { - debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64); - BitVec::from_fn(|i| { - let nth_bit = i % 64; - let shift = SHIFT_BY as usize; - if nth_bit < 64 - shift { - vector[i + shift] - } else { - Bit::Zero - } - }) -} - -pub fn _mm256_sllv_epi32(vector: __m256i, counts: __m256i) -> __m256i { - extra::mm256_sllv_epi32_u32_array(vector, counts.to_vec().try_into().unwrap()) -} - -pub fn _mm256_permutevar8x32_epi32(a: __m256i, b: __m256i) -> __m256i { - extra::mm256_permutevar8x32_epi32_u32_array(a, b.to_vec().try_into().unwrap()) -} - -pub fn _mm256_castsi256_si128(vector: __m256i) -> __m128i { - BitVec::from_fn(|i| vector[i]) -} - -pub fn _mm_shuffle_epi8(vector: __m128i, indexes: __m128i) -> __m128i { - let indexes: [u8; 16] = indexes.to_vec().try_into().unwrap(); - extra::mm_shuffle_epi8_u8_array(vector, indexes) -} - -pub mod extra { - use super::*; - - pub fn mm256_sllv_epi32_u32_array(vector: BitVec<256>, counts: [u32; 8]) -> BitVec<256> { - BitVec::from_fn(|i| { - let nth_bit = i % 32; - let shift = counts[i / 32]; - if nth_bit as i128 >= shift as i128 { - vector[i - shift as usize] - } else { - Bit::Zero - } - }) - } - - pub fn mm256_sllv_epi32_u32( - vector: BitVec<256>, - b7: u32, - b6: u32, - b5: u32, - b4: u32, - b3: u32, - b2: u32, - b1: u32, - b0: u32, - ) -> BitVec<256> { - mm256_sllv_epi32_u32_array(vector, [b7, b6, b5, b4, 
b3, b2, b1, b0]) - } - - pub fn mm256_permutevar8x32_epi32_u32_array(a: BitVec<256>, b: [u32; 8]) -> BitVec<256> { - BitVec::from_fn(|i| { - let j = i / 32; - let index = ((b[j] & 0b111) as usize) * 32; - a[index + i % 32] - }) - } - - pub fn mm256_permutevar8x32_epi32_u32( - vector: BitVec<256>, - b7: u32, - b6: u32, - b5: u32, - b4: u32, - b3: u32, - b2: u32, - b1: u32, - b0: u32, - ) -> BitVec<256> { - mm256_permutevar8x32_epi32_u32_array(vector, [b7, b6, b5, b4, b3, b2, b1, b0]) - } - - pub fn mm_shuffle_epi8_u8_array(vector: BitVec<128>, indexes: [u8; 16]) -> BitVec<128> { - BitVec::from_fn(|i: usize| { - let nth = i / 8; - let index = indexes[nth]; - if index > 127 { - Bit::Zero - } else { - let index = (index & 0b1111) as usize; - vector[index * 8 + i % 8] - } - }) - } - - pub fn mm_shuffle_epi8_u8( - vector: BitVec<128>, - b15: u8, - b14: u8, - b13: u8, - b12: u8, - b11: u8, - b10: u8, - b9: u8, - b8: u8, - b7: u8, - b6: u8, - b5: u8, - b4: u8, - b3: u8, - b2: u8, - b1: u8, - b0: u8, - ) -> BitVec<128> { - let indexes = [ - b15, b14, b13, b12, b11, b10, b9, b8, b7, b6, b5, b4, b3, b2, b1, b0, - ]; - mm_shuffle_epi8_u8_array(vector, indexes) - } - - pub fn mm_storeu_bytes_si128(output: &mut [u8], vector: BitVec<128>) { - output.copy_from_slice(&vector.to_vec()[..]); - } -} - -/// Tests of equivalence between `safe::*` and `upstream::*`. -#[cfg(all(test, any(target_arch = "x86", target_arch = "x86_64")))] -mod tests { - use super::*; - - /// Number of tests to run for each function - const N: usize = 1000; - - #[test] - fn mm256_slli_epi16() { - macro_rules! mk { - ($($shift: literal)*) => { - $(for _ in 0..N { - let input = BitVec::<256>::rand(); - assert_eq!( - super::_mm256_slli_epi16::<$shift>(input), - unsafe {upstream::_mm256_slli_epi16::<$shift>(input.into()).into()} - ); - })* - }; - } - mk!(0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15); - } - - #[test] - fn mm256_srli_epi64() { - macro_rules! 
mk { - ($($shift: literal)*) => { - $(for _ in 0..N { - let input = BitVec::<256>::rand(); - assert_eq!( - super::_mm256_srli_epi64::<$shift>(input), - unsafe{upstream::_mm256_srli_epi64::<$shift>(input.into()).into()} - ); - })* - }; - } - mk!(0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63); - } - - #[test] - fn mm256_sllv_epi32() { - for _ in 0..100 { - let vector: BitVec<256> = BitVec::rand(); - let counts: BitVec<256> = BitVec::rand(); - assert_eq!(super::_mm256_sllv_epi32(vector, counts), unsafe { - upstream::_mm256_sllv_epi32(vector.into(), counts.into()).into() - }); - } - } - - #[test] - fn mm256_permutevar8x32_epi32() { - for _ in 0..N { - let vector: BitVec<256> = BitVec::rand(); - let counts: BitVec<256> = BitVec::rand(); - assert_eq!(super::_mm256_permutevar8x32_epi32(vector, counts), unsafe { - upstream::_mm256_permutevar8x32_epi32(vector.into(), counts.into()).into() - }); - } - } - - #[test] - fn mm256_castsi256_si128() { - for _ in 0..N { - let vector: BitVec<256> = BitVec::rand(); - assert_eq!(super::_mm256_castsi256_si128(vector), unsafe { - upstream::_mm256_castsi256_si128(vector.into()).into() - }); - } - } - - #[test] - fn mm_shuffle_epi8() { - for _ in 0..N { - let a: BitVec<128> = BitVec::rand(); - let _: upstream::__m128i = a.into(); - let b: BitVec<128> = BitVec::rand(); - assert_eq!(super::_mm_shuffle_epi8(a, b), unsafe { - upstream::_mm_shuffle_epi8(a.into(), b.into()).into() - }); - } - } -} diff --git a/fstar-helpers/minicore/src/arch.rs b/fstar-helpers/minicore/src/core_arch.rs similarity index 100% rename from fstar-helpers/minicore/src/arch.rs rename to fstar-helpers/minicore/src/core_arch.rs diff --git a/fstar-helpers/minicore/src/core_arch/x86.rs b/fstar-helpers/minicore/src/core_arch/x86.rs new file mode 100644 index 000000000..114c8f63e --- /dev/null +++ 
b/fstar-helpers/minicore/src/core_arch/x86.rs @@ -0,0 +1,609 @@ +//! A (partial) Rust-based model of [`core::arch::x86`] and [`core::arch::x86_64`]. +//! +//! This module provides a purely Rust implementation of selected operations from +//! `core::arch::x86` and `core::arch::x86_64`. + +use crate::abstractions::{bit::*, bitvec::*, funarr::*}; + +pub(crate) mod upstream { + #[cfg(target_arch = "x86")] + pub use core::arch::x86::*; + #[cfg(target_arch = "x86_64")] + pub use core::arch::x86_64::*; +} + +/// Conversions impls between `BitVec` and `__mNi` types. +#[hax_lib::exclude] +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod conversions { + use super::upstream::{ + __m128i, __m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm_loadu_si128, + _mm_storeu_si128, + }; + use super::BitVec; + + impl From> for __m256i { + fn from(bv: BitVec<256>) -> __m256i { + let bv: &[u8] = &bv.to_vec()[..]; + unsafe { _mm256_loadu_si256(bv.as_ptr() as *const _) } + } + } + + impl From> for __m128i { + fn from(bv: BitVec<128>) -> __m128i { + let slice: &[u8] = &bv.to_vec()[..]; + unsafe { _mm_loadu_si128(slice.as_ptr() as *const __m128i) } + } + } + + impl From<__m256i> for BitVec<256> { + fn from(vec: __m256i) -> BitVec<256> { + let mut v = [0u8; 32]; + unsafe { + _mm256_storeu_si256(v.as_mut_ptr() as *mut _, vec); + } + BitVec::from_slice(&v[..], 8) + } + } + + impl From<__m128i> for BitVec<128> { + fn from(vec: __m128i) -> BitVec<128> { + let mut v = [0u8; 16]; + unsafe { + _mm_storeu_si128(v.as_mut_ptr() as *mut _, vec); + } + BitVec::from_slice(&v[..], 8) + } + } +} + +#[hax_lib::fstar::replace( + r#" + unfold type t_e_ee_m256i = $:{__m256i} + unfold type t_e_ee_m128i = $:{__m128i} +"# +)] +const _: () = {}; + +#[allow(non_camel_case_types)] +struct __m256(()); + +/// 256-bit wide integer vector type. +/// Models `core::arch::x86::__m256i` or `core::arch::x86_64::__m256i` (the __m256i type defined by Intel, representing a 256-bit SIMD register). 
+#[allow(non_camel_case_types)] +pub type __m256i = BitVec<256>; + +/// 128-bit wide integer vector type. +/// Models `core::arch::x86::__m128i` or `core::arch::x86_64::__m128i` (the __m128i type defined by Intel, representing a 128-bit SIMD register). +#[allow(non_camel_case_types)] +pub type __m128i = BitVec<128>; + +pub use ssse3::*; +pub mod ssse3 { + use super::*; + #[hax_lib::opaque] + pub fn _mm_shuffle_epi8(vector: __m128i, indexes: __m128i) -> __m128i { + let indexes = indexes.to_vec().try_into().unwrap(); + extra::mm_shuffle_epi8_u8_array(vector, indexes) + } +} +pub use sse2::*; +pub mod sse2 { + use super::*; + #[hax_lib::opaque] + pub fn _mm_set_epi8( + _e15: i8, + _e14: i8, + _e13: i8, + _e12: i8, + _e11: i8, + _e10: i8, + _e9: i8, + _e8: i8, + _e7: i8, + _e6: i8, + _e5: i8, + _e4: i8, + _e3: i8, + _e2: i8, + _e1: i8, + _e0: i8, + ) -> __m128i { + todo!() + } +} + +pub use avx::*; +pub mod avx { + pub use super::*; + pub fn _mm256_castsi256_si128(vector: __m256i) -> __m128i { + BitVec::from_fn(|i| vector[i]) + } + + #[hax_lib::opaque] + pub fn _mm256_set_epi32( + _e0: i32, + _e1: i32, + _e2: i32, + _e3: i32, + _e4: i32, + _e5: i32, + _e6: i32, + _e7: i32, + ) -> __m256i { + todo!() + } +} +pub use avx2::*; +pub mod avx2 { + use super::*; + #[hax_lib::exclude] + pub fn _mm_storeu_si128(output: *mut __m128i, a: __m128i) { + // This is equivalent to `*output = a` + let mut out = [0u8; 128]; + extra::mm_storeu_bytes_si128(&mut out, a); + unsafe { + *(output.as_mut().unwrap()) = BitVec::from_slice(&out, 8); + } + } + + #[hax_lib::requires(SHIFT_BY >= 0 && SHIFT_BY < 16)] + pub fn _mm256_slli_epi16(vector: __m256i) -> __m256i { + vector.chunked_shift::<16, 16>(FunArray::from_fn(|_| SHIFT_BY as i128)) + } + + #[hax_lib::requires(SHIFT_BY >= 0 && SHIFT_BY < 64)] + pub fn _mm256_srli_epi64(vector: __m256i) -> __m256i { + vector.chunked_shift::<64, 4>(FunArray::from_fn(|_| -(SHIFT_BY as i128))) + } + + #[hax_lib::opaque] + pub fn _mm256_sllv_epi32(vector: 
__m256i, counts: __m256i) -> __m256i { + extra::mm256_sllv_epi32_u32_array(vector, counts.to_vec().try_into().unwrap()) + } + + #[hax_lib::opaque] + pub fn _mm256_srlv_epi32(vector: __m256i, counts: __m256i) -> __m256i { + extra::mm256_srlv_epi32_u32_array(vector, counts.to_vec().try_into().unwrap()) + } + + #[hax_lib::opaque] + pub fn _mm256_permutevar8x32_epi32(a: __m256i, b: __m256i) -> __m256i { + extra::mm256_permutevar8x32_epi32_u32_array(a, b.to_vec().try_into().unwrap()) + } + + pub fn _mm256_extracti128_si256(vector: __m256i) -> __m128i { + BitVec::from_fn(|i| vector[i + if IMM8 == 0 { 0 } else { 128 }]) + } +} + +/// Rewrite lemmas +const _: () = { + #[hax_lib::fstar::before("[@@ $REWRITE_RULE ]")] + #[hax_lib::lemma] + #[hax_lib::opaque] + fn _rw_mm256_sllv_epi32( + vector: __m256i, + b7: i32, + b6: i32, + b5: i32, + b4: i32, + b3: i32, + b2: i32, + b1: i32, + b0: i32, + ) -> Proof< + { + hax_lib::prop::eq( + _mm256_sllv_epi32(vector, _mm256_set_epi32(b7, b6, b5, b4, b3, b2, b1, b0)), + extra::mm256_sllv_epi32_u32( + vector, b7 as u32, b6 as u32, b5 as u32, b4 as u32, b3 as u32, b2 as u32, + b1 as u32, b0 as u32, + ), + ) + }, + > { + } + + #[hax_lib::fstar::before("[@@ $REWRITE_RULE ]")] + #[hax_lib::lemma] + #[hax_lib::opaque] + fn _rw_mm256_permutevar8x32_epi32( + vector: __m256i, + b7: i32, + b6: i32, + b5: i32, + b4: i32, + b3: i32, + b2: i32, + b1: i32, + b0: i32, + ) -> Proof< + { + hax_lib::prop::eq( + _mm256_permutevar8x32_epi32( + vector, + _mm256_set_epi32(b7, b6, b5, b4, b3, b2, b1, b0), + ), + extra::mm256_permutevar8x32_epi32_u32( + vector, b7 as u32, b6 as u32, b5 as u32, b4 as u32, b3 as u32, b2 as u32, + b1 as u32, b0 as u32, + ), + ) + }, + > { + } + + #[hax_lib::fstar::before("[@@ $REWRITE_RULE ]")] + #[hax_lib::lemma] + #[hax_lib::opaque] + fn _rw_mm_shuffle_epi8( + vector: __m128i, + e15: i8, + e14: i8, + e13: i8, + e12: i8, + e11: i8, + e10: i8, + e9: i8, + e8: i8, + e7: i8, + e6: i8, + e5: i8, + e4: i8, + e3: i8, + e2: i8, + e1: 
i8, + e0: i8, + ) -> Proof< + { + hax_lib::prop::eq( + _mm_shuffle_epi8( + vector, + _mm_set_epi8( + e15, e14, e13, e12, e11, e10, e9, e8, e7, e6, e5, e4, e3, e2, e1, e0, + ), + ), + extra::mm_shuffle_epi8_u8( + vector, e15 as u8, e14 as u8, e13 as u8, e12 as u8, e11 as u8, e10 as u8, + e9 as u8, e8 as u8, e7 as u8, e6 as u8, e5 as u8, e4 as u8, e3 as u8, e2 as u8, + e1 as u8, e0 as u8, + ), + ) + }, + > { + } +}; + +pub mod extra { + use super::*; + + pub fn mm256_sllv_epi32_u32_array( + vector: BitVec<256>, + counts: FunArray<8, u32>, + ) -> BitVec<256> { + vector.chunked_shift::<32, 8>(FunArray::from_fn(|i| counts[i] as i128)) + } + + pub fn mm256_sllv_epi32_u32( + vector: BitVec<256>, + b7: u32, + b6: u32, + b5: u32, + b4: u32, + b3: u32, + b2: u32, + b1: u32, + b0: u32, + ) -> BitVec<256> { + mm256_sllv_epi32_u32_array( + vector, + FunArray::from_fn(|i| match i { + 7 => b7, + 6 => b6, + 5 => b5, + 4 => b4, + 3 => b3, + 2 => b2, + 1 => b1, + 0 => b0, + _ => unreachable!(), + }), + ) + } + + pub fn mm256_srlv_epi32_u32_array( + vector: BitVec<256>, + counts: FunArray<8, u32>, + ) -> BitVec<256> { + vector.chunked_shift::<32, 8>(FunArray::from_fn(|i| -(counts[i] as i128))) + } + + pub fn mm256_srlv_epi32_u32( + vector: BitVec<256>, + b7: u32, + b6: u32, + b5: u32, + b4: u32, + b3: u32, + b2: u32, + b1: u32, + b0: u32, + ) -> BitVec<256> { + mm256_srlv_epi32_u32_array( + vector, + FunArray::from_fn(|i| match i { + 7 => b7, + 6 => b6, + 5 => b5, + 4 => b4, + 3 => b3, + 2 => b2, + 1 => b1, + 0 => b0, + _ => unreachable!(), + }), + ) + } + + pub fn mm256_permutevar8x32_epi32_u32_array( + a: BitVec<256>, + b: FunArray<8, u32>, + ) -> BitVec<256> { + BitVec::from_fn(|i| { + let j = i / 32; + let index = ((b[j] % 8) as u64) * 32; + a[index + i % 32] + }) + } + + pub fn mm256_permutevar8x32_epi32_u32( + vector: BitVec<256>, + b7: u32, + b6: u32, + b5: u32, + b4: u32, + b3: u32, + b2: u32, + b1: u32, + b0: u32, + ) -> BitVec<256> { + mm256_permutevar8x32_epi32_u32_array( 
+ vector, + FunArray::from_fn(|i| match i { + 7 => b7, + 6 => b6, + 5 => b5, + 4 => b4, + 3 => b3, + 2 => b2, + 1 => b1, + 0 => b0, + _ => unreachable!(), + }), + ) + } + + pub fn mm_shuffle_epi8_u8_array(vector: BitVec<128>, indexes: FunArray<16, u8>) -> BitVec<128> { + BitVec::from_fn(|i| { + let nth = i / 8; + let index = indexes[nth]; + if index > 127 { + Bit::Zero + } else { + let index = (index % 16) as u64; + vector[index * 8 + i % 8] + } + }) + } + + pub fn mm_shuffle_epi8_u8( + vector: BitVec<128>, + b15: u8, + b14: u8, + b13: u8, + b12: u8, + b11: u8, + b10: u8, + b9: u8, + b8: u8, + b7: u8, + b6: u8, + b5: u8, + b4: u8, + b3: u8, + b2: u8, + b1: u8, + b0: u8, + ) -> BitVec<128> { + let indexes = FunArray::from_fn(|i| match i { + 15 => b15, + 14 => b14, + 13 => b13, + 12 => b12, + 11 => b11, + 10 => b10, + 9 => b9, + 8 => b8, + 7 => b7, + 6 => b6, + 5 => b5, + 4 => b4, + 3 => b3, + 2 => b2, + 1 => b1, + 0 => b0, + _ => unreachable!(), + }); + mm_shuffle_epi8_u8_array(vector, indexes) + } + + pub fn mm256_mullo_epi16_shifts( + vector: __m256i, + s15: u8, + s14: u8, + s13: u8, + s12: u8, + s11: u8, + s10: u8, + s9: u8, + s8: u8, + s7: u8, + s6: u8, + s5: u8, + s4: u8, + s3: u8, + s2: u8, + s1: u8, + s0: u8, + ) -> __m256i { + let shifts = FunArray::<16, _>::from_fn(|i| match i { + 15 => s15, + 14 => s14, + 13 => s13, + 12 => s12, + 11 => s11, + 10 => s10, + 9 => s9, + 8 => s8, + 7 => s7, + 6 => s6, + 5 => s5, + 4 => s4, + 3 => s3, + 2 => s2, + 1 => s1, + 0 => s0, + _ => unreachable!(), + }); + mm256_mullo_epi16_shifts_array(vector, shifts) + } + pub fn mm256_mullo_epi16_shifts_array(vector: __m256i, shifts: FunArray<16, u8>) -> __m256i { + BitVec::from_fn(|i| { + let nth_bit = i % 16; + let nth_i16 = i / 16; + + let shift = shifts[nth_i16] as u64; + + if nth_bit >= shift { + vector[nth_i16 * 16 + nth_bit - shift] + } else { + Bit::Zero + } + }) + } + + #[hax_lib::exclude] + pub fn mm_storeu_bytes_si128(output: &mut [u8], vector: BitVec<128>) { + 
output.copy_from_slice(&vector.to_vec()[..]); + } +} + +/// Tests of equivalence between `safe::*` and `upstream::*`. +#[cfg(all(test, any(target_arch = "x86", target_arch = "x86_64")))] +mod tests { + use super::*; + + /// Number of tests to run for each function + const N: usize = 1000; + + #[test] + fn mm256_slli_epi16() { + macro_rules! mk { + ($($shift: literal)*) => { + $(for _ in 0..N { + let input = BitVec::<256>::rand(); + assert_eq!( + super::_mm256_slli_epi16::<$shift>(input), + unsafe {upstream::_mm256_slli_epi16::<$shift>(input.into()).into()} + ); + })* + }; + } + mk!(0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15); + } + + #[test] + fn mm256_srli_epi64() { + macro_rules! mk { + ($($shift: literal)*) => { + $(for _ in 0..N { + let input = BitVec::<256>::rand(); + assert_eq!( + super::_mm256_srli_epi64::<$shift>(input), + unsafe{upstream::_mm256_srli_epi64::<$shift>(input.into()).into()} + ); + })* + }; + } + mk!(0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63); + } + + #[test] + fn mm256_sllv_epi32() { + for _ in 0..N { + let vector: BitVec<256> = BitVec::rand(); + let counts: BitVec<256> = BitVec::rand(); + assert_eq!(super::_mm256_sllv_epi32(vector, counts), unsafe { + upstream::_mm256_sllv_epi32(vector.into(), counts.into()).into() + }); + } + } + + #[test] + fn mm256_srlv_epi32() { + for _ in 0..N { + let vector: BitVec<256> = BitVec::rand(); + let counts: BitVec<256> = BitVec::rand(); + assert_eq!(super::_mm256_srlv_epi32(vector, counts), unsafe { + upstream::_mm256_srlv_epi32(vector.into(), counts.into()).into() + }); + } + } + + #[test] + fn mm256_permutevar8x32_epi32() { + for _ in 0..N { + let vector: BitVec<256> = BitVec::rand(); + let counts: BitVec<256> = BitVec::rand(); + assert_eq!(super::_mm256_permutevar8x32_epi32(vector, counts), unsafe { + upstream::_mm256_permutevar8x32_epi32(vector.into(), 
counts.into()).into() + }); + } + } + + #[test] + fn mm256_castsi256_si128() { + for _ in 0..N { + let vector: BitVec<256> = BitVec::rand(); + assert_eq!(super::_mm256_castsi256_si128(vector), unsafe { + upstream::_mm256_castsi256_si128(vector.into()).into() + }); + } + } + + #[test] + fn mm256_extracti128_si256() { + for _ in 0..N { + let vector: BitVec<256> = BitVec::rand(); + assert_eq!(super::_mm256_extracti128_si256::<0>(vector), unsafe { + upstream::_mm256_extracti128_si256::<0>(vector.into()).into() + }); + assert_eq!(super::_mm256_extracti128_si256::<1>(vector), unsafe { + upstream::_mm256_extracti128_si256::<1>(vector.into()).into() + }); + } + } + + #[test] + fn mm_shuffle_epi8() { + for _ in 0..N { + let a: BitVec<128> = BitVec::rand(); + let b: BitVec<128> = BitVec::rand(); + + assert_eq!(super::_mm_shuffle_epi8(a, b), unsafe { + upstream::_mm_shuffle_epi8(a.into(), b.into()).into() + }); + } + } +} diff --git a/fstar-helpers/minicore/src/lib.rs b/fstar-helpers/minicore/src/lib.rs index 084d19ea2..68a6978cc 100644 --- a/fstar-helpers/minicore/src/lib.rs +++ b/fstar-helpers/minicore/src/lib.rs @@ -26,4 +26,6 @@ //! proof assistants and other verification tools. 
pub mod abstractions; -pub mod arch; +pub mod core_arch; + +pub use core_arch as arch; diff --git a/libcrux-intrinsics/Cargo.toml b/libcrux-intrinsics/Cargo.toml index 5cacc5bee..2b196d12c 100644 --- a/libcrux-intrinsics/Cargo.toml +++ b/libcrux-intrinsics/Cargo.toml @@ -20,4 +20,4 @@ simd256 = [] [dev-dependencies] [lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ['cfg(hax)', 'cfg(eurydice)'] } +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(hax)', 'cfg(pre_minicore)', 'cfg(eurydice)'] } diff --git a/libcrux-intrinsics/src/avx2.rs b/libcrux-intrinsics/src/avx2.rs index 9c419e557..a4ebf725e 100644 --- a/libcrux-intrinsics/src/avx2.rs +++ b/libcrux-intrinsics/src/avx2.rs @@ -7,6 +7,7 @@ pub type Vec256 = __m256i; pub type Vec128 = __m128i; pub type Vec256Float = __m256; +#[hax_lib::opaque] #[inline(always)] pub fn mm256_storeu_si256_u8(output: &mut [u8], vector: Vec256) { debug_assert_eq!(output.len(), 32); @@ -14,6 +15,8 @@ pub fn mm256_storeu_si256_u8(output: &mut [u8], vector: Vec256) { _mm256_storeu_si256(output.as_mut_ptr() as *mut Vec256, vector); } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_storeu_si256_i16(output: &mut [i16], vector: Vec256) { debug_assert_eq!(output.len(), 16); @@ -21,6 +24,8 @@ pub fn mm256_storeu_si256_i16(output: &mut [i16], vector: Vec256) { _mm256_storeu_si256(output.as_mut_ptr() as *mut Vec256, vector); } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_storeu_si256_i32(output: &mut [i32], vector: Vec256) { debug_assert_eq!(output.len(), 8); @@ -29,6 +34,7 @@ pub fn mm256_storeu_si256_i32(output: &mut [i32], vector: Vec256) { } } +#[hax_lib::opaque] #[inline(always)] pub fn mm_storeu_si128(output: &mut [i16], vector: Vec128) { debug_assert!(output.len() >= 8); @@ -36,6 +42,8 @@ pub fn mm_storeu_si128(output: &mut [i16], vector: Vec128) { _mm_storeu_si128(output.as_mut_ptr() as *mut Vec128, vector); } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm_storeu_si128_i32(output: &mut [i32], vector: 
Vec128) { debug_assert_eq!(output.len(), 4); @@ -44,6 +52,8 @@ pub fn mm_storeu_si128_i32(output: &mut [i32], vector: Vec128) { } } +#[hax_lib::opaque] +#[hax_lib::ensures(|_r| future(output).len() == output.len())] #[inline(always)] pub fn mm_storeu_bytes_si128(output: &mut [u8], vector: Vec128) { debug_assert_eq!(output.len(), 16); @@ -52,32 +62,41 @@ pub fn mm_storeu_bytes_si128(output: &mut [u8], vector: Vec128) { } } +#[hax_lib::opaque] #[inline(always)] pub fn mm_loadu_si128(input: &[u8]) -> Vec128 { debug_assert_eq!(input.len(), 16); unsafe { _mm_loadu_si128(input.as_ptr() as *const Vec128) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_loadu_si256_u8(input: &[u8]) -> Vec256 { debug_assert_eq!(input.len(), 32); unsafe { _mm256_loadu_si256(input.as_ptr() as *const Vec256) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_loadu_si256_i16(input: &[i16]) -> Vec256 { debug_assert_eq!(input.len(), 16); unsafe { _mm256_loadu_si256(input.as_ptr() as *const Vec256) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_loadu_si256_i32(input: &[i32]) -> Vec256 { debug_assert_eq!(input.len(), 8); unsafe { _mm256_loadu_si256(input.as_ptr() as *const Vec256) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_setzero_si256() -> Vec256 { unsafe { _mm256_setzero_si256() } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_set_m128i(hi: Vec128, lo: Vec128) -> Vec256 { unsafe { _mm256_set_m128i(hi, lo) } @@ -124,6 +143,7 @@ pub fn mm_set_epi8( } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_set_epi8( byte31: i8, @@ -168,10 +188,13 @@ pub fn mm256_set_epi8( } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_set1_epi16(constant: i16) -> Vec256 { unsafe { _mm256_set1_epi16(constant) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_set_epi16( input15: i16, @@ -199,20 +222,24 @@ pub fn mm256_set_epi16( } } +#[hax_lib::opaque] #[inline(always)] pub fn mm_set1_epi16(constant: i16) -> Vec128 { unsafe { _mm_set1_epi16(constant) } } 
+#[hax_lib::opaque] #[inline(always)] pub fn mm256_set1_epi32(constant: i32) -> Vec256 { unsafe { _mm256_set1_epi32(constant) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm_set_epi32(input3: i32, input2: i32, input1: i32, input0: i32) -> Vec128 { unsafe { _mm_set_epi32(input3, input2, input1, input0) } } + #[inline(always)] pub fn mm256_set_epi32( input7: i32, @@ -231,130 +258,207 @@ pub fn mm256_set_epi32( } } +#[hax_lib::opaque] #[inline(always)] pub fn mm_add_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 { unsafe { _mm_add_epi16(lhs, rhs) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_add_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_add_epi16(lhs, rhs) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_madd_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_madd_epi16(lhs, rhs) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_add_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_add_epi32(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_add_epi64(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_add_epi64(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_abs_epi32(a: Vec256) -> Vec256 { unsafe { _mm256_abs_epi32(a) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_sub_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_sub_epi16(lhs, rhs) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_sub_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_sub_epi32(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm_sub_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 { unsafe { _mm_sub_epi16(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_mullo_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_mullo_epi16(lhs, rhs) } } +#[hax_lib::opaque] +#[inline(always)] +pub fn mm256_mullo_epi16_shifts( + vector: Vec256, + s15: u8, + s14: u8, + s13: u8, + s12: u8, + s11: u8, + s10: u8, + s9: u8, + s8: u8, + s7: u8, + s6: u8, + s5: u8, + 
s4: u8, + s3: u8, + s2: u8, + s1: u8, + s0: u8, +) -> Vec256 { + mm256_mullo_epi16( + vector, + mm256_set_epi16( + 1i16 << s15, + 1i16 << s14, + 1i16 << s13, + 1i16 << s12, + 1i16 << s11, + 1i16 << s10, + 1i16 << s9, + 1i16 << s8, + 1i16 << s7, + 1i16 << s6, + 1i16 << s5, + 1i16 << s4, + 1i16 << s3, + 1i16 << s2, + 1i16 << s1, + 1i16 << s0, + ), + ) +} + +#[hax_lib::opaque] #[inline(always)] pub fn mm_mullo_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 { unsafe { _mm_mullo_epi16(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_cmpgt_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_cmpgt_epi16(lhs, rhs) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_cmpgt_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_cmpgt_epi32(lhs, rhs) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_cmpeq_epi32(a: Vec256, b: Vec256) -> Vec256 { unsafe { _mm256_cmpeq_epi32(a, b) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_sign_epi32(a: Vec256, b: Vec256) -> Vec256 { unsafe { _mm256_sign_epi32(a, b) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_castsi256_ps(a: Vec256) -> Vec256Float { unsafe { _mm256_castsi256_ps(a) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_castps_si256(a: Vec256Float) -> Vec256 { unsafe { _mm256_castps_si256(a) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_movemask_ps(a: Vec256Float) -> i32 { unsafe { _mm256_movemask_ps(a) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm_mulhi_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 { unsafe { _mm_mulhi_epi16(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_mullo_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_mullo_epi32(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_mulhi_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_mulhi_epi16(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_mul_epu32(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_mul_epu32(lhs, rhs) } } 
+#[hax_lib::opaque] #[inline(always)] pub fn mm256_mul_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_mul_epi32(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_and_si256(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_and_si256(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_or_si256(a: Vec256, b: Vec256) -> Vec256 { unsafe { _mm256_or_si256(a, b) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_testz_si256(lhs: Vec256, rhs: Vec256) -> i32 { unsafe { _mm256_testz_si256(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_xor_si256(lhs: Vec256, rhs: Vec256) -> Vec256 { // This floating point xor may or may not be faster than regular xor. @@ -373,45 +477,55 @@ pub fn mm256_xor_si256(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_xor_si256(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_srai_epi16(vector: Vec256) -> Vec256 { debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16); unsafe { _mm256_srai_epi16(vector, SHIFT_BY) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_srai_epi32(vector: Vec256) -> Vec256 { debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32); unsafe { _mm256_srai_epi32(vector, SHIFT_BY) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_srli_epi16(vector: Vec256) -> Vec256 { debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16); unsafe { _mm256_srli_epi16(vector, SHIFT_BY) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_srli_epi32(vector: Vec256) -> Vec256 { debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32); unsafe { _mm256_srli_epi32(vector, SHIFT_BY) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm_srli_epi64(vector: Vec128) -> Vec128 { debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64); unsafe { _mm_srli_epi64(vector, SHIFT_BY) } } + #[inline(always)] pub fn mm256_srli_epi64(vector: Vec256) -> Vec256 { debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64); unsafe { _mm256_srli_epi64(vector, SHIFT_BY) } } +#[hax_lib::opaque] #[inline(always)] pub fn 
mm256_slli_epi16(vector: Vec256) -> Vec256 { debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16); unsafe { _mm256_slli_epi16(vector, SHIFT_BY) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_slli_epi32(vector: Vec256) -> Vec256 { debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32); @@ -422,32 +536,40 @@ pub fn mm256_slli_epi32(vector: Vec256) -> Vec256 { pub fn mm_shuffle_epi8(vector: Vec128, control: Vec128) -> Vec128 { unsafe { _mm_shuffle_epi8(vector, control) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_shuffle_epi8(vector: Vec256, control: Vec256) -> Vec256 { unsafe { _mm256_shuffle_epi8(vector, control) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_shuffle_epi32(vector: Vec256) -> Vec256 { debug_assert!(CONTROL >= 0 && CONTROL < 256); unsafe { _mm256_shuffle_epi32(vector, CONTROL) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_permute4x64_epi64(vector: Vec256) -> Vec256 { debug_assert!(CONTROL >= 0 && CONTROL < 256); unsafe { _mm256_permute4x64_epi64(vector, CONTROL) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_unpackhi_epi64(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_unpackhi_epi64(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_unpacklo_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_unpacklo_epi32(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_unpackhi_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_unpackhi_epi32(lhs, rhs) } @@ -457,20 +579,26 @@ pub fn mm256_unpackhi_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { pub fn mm256_castsi256_si128(vector: Vec256) -> Vec128 { unsafe { _mm256_castsi256_si128(vector) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_castsi128_si256(vector: Vec128) -> Vec256 { unsafe { _mm256_castsi128_si256(vector) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_cvtepi16_epi32(vector: Vec128) -> Vec256 { unsafe { _mm256_cvtepi16_epi32(vector) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm_packs_epi16(lhs: 
Vec128, rhs: Vec128) -> Vec128 { unsafe { _mm_packs_epi16(lhs, rhs) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_packs_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_packs_epi32(lhs, rhs) } @@ -482,18 +610,21 @@ pub fn mm256_extracti128_si256(vector: Vec256) -> Vec128 { unsafe { _mm256_extracti128_si256(vector, CONTROL) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_inserti128_si256(vector: Vec256, vector_i128: Vec128) -> Vec256 { debug_assert!(CONTROL == 0 || CONTROL == 1); unsafe { _mm256_inserti128_si256(vector, vector_i128, CONTROL) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_blend_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 { debug_assert!(CONTROL >= 0 && CONTROL < 256); unsafe { _mm256_blend_epi16(lhs, rhs, CONTROL) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_blend_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { debug_assert!(CONTROL >= 0 && CONTROL < 256); @@ -502,6 +633,7 @@ pub fn mm256_blend_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 // This is essentially _mm256_blendv_ps adapted for use with the Vec256 type. // It is not offered by the AVX2 instruction set. 
+#[hax_lib::opaque] #[inline(always)] pub fn vec256_blendv_epi32(a: Vec256, b: Vec256, mask: Vec256) -> Vec256 { unsafe { @@ -513,6 +645,7 @@ pub fn vec256_blendv_epi32(a: Vec256, b: Vec256, mask: Vec256) -> Vec256 { } } +#[hax_lib::opaque] #[inline(always)] pub fn mm_movemask_epi8(vector: Vec128) -> i32 { unsafe { _mm_movemask_epi8(vector) } @@ -527,50 +660,62 @@ pub fn mm256_permutevar8x32_epi32(vector: Vec256, control: Vec256) -> Vec256 { pub fn mm256_srlv_epi32(vector: Vec256, counts: Vec256) -> Vec256 { unsafe { _mm256_srlv_epi32(vector, counts) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_srlv_epi64(vector: Vec256, counts: Vec256) -> Vec256 { unsafe { _mm256_srlv_epi64(vector, counts) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm_sllv_epi32(vector: Vec128, counts: Vec128) -> Vec128 { unsafe { _mm_sllv_epi32(vector, counts) } } + #[inline(always)] pub fn mm256_sllv_epi32(vector: Vec256, counts: Vec256) -> Vec256 { unsafe { _mm256_sllv_epi32(vector, counts) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_slli_epi64(x: Vec256) -> Vec256 { unsafe { _mm256_slli_epi64::(x) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_bsrli_epi128(x: Vec256) -> Vec256 { debug_assert!(SHIFT_BY > 0 && SHIFT_BY < 16); unsafe { _mm256_bsrli_epi128::(x) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_andnot_si256(a: Vec256, b: Vec256) -> Vec256 { unsafe { _mm256_andnot_si256(a, b) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_set1_epi64x(a: i64) -> Vec256 { unsafe { _mm256_set1_epi64x(a) } } + +#[hax_lib::opaque] #[inline(always)] pub fn mm256_set_epi64x(input3: i64, input2: i64, input1: i64, input0: i64) -> Vec256 { unsafe { _mm256_set_epi64x(input3, input2, input1, input0) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_unpacklo_epi64(lhs: Vec256, rhs: Vec256) -> Vec256 { unsafe { _mm256_unpacklo_epi64(lhs, rhs) } } +#[hax_lib::opaque] #[inline(always)] pub fn mm256_permute2x128_si256(a: Vec256, b: Vec256) -> Vec256 { unsafe { 
_mm256_permute2x128_si256::(a, b) } diff --git a/libcrux-intrinsics/src/lib.rs b/libcrux-intrinsics/src/lib.rs index 98d72f2a2..c8a950a96 100644 --- a/libcrux-intrinsics/src/lib.rs +++ b/libcrux-intrinsics/src/lib.rs @@ -9,7 +9,11 @@ pub mod avx2; pub mod arm64_extract; #[cfg(all(feature = "simd128", hax))] pub use arm64_extract as arm64; -#[cfg(all(feature = "simd256", hax))] + +#[cfg(all(feature = "simd256", hax, pre_minicore))] pub mod avx2_extract; -#[cfg(all(feature = "simd256", hax))] +#[cfg(all(feature = "simd256", hax, pre_minicore))] pub use avx2_extract as avx2; + +#[cfg(all(feature = "simd256", hax, not(pre_minicore)))] +pub mod avx2; diff --git a/libcrux-ml-dsa/Cargo.toml b/libcrux-ml-dsa/Cargo.toml index 6eadbaf74..1b3f3ff02 100644 --- a/libcrux-ml-dsa/Cargo.toml +++ b/libcrux-ml-dsa/Cargo.toml @@ -33,6 +33,9 @@ criterion = "0.5" [target.'cfg(not(all(target_os = "macos", target_arch = "x86_64")))'.dev-dependencies] pqcrypto-mldsa = { version = "0.1.0" } #, default-features = false +[target.'cfg(hax)'.dependencies] +minicore = { path = "../fstar-helpers/minicore" } + [features] default = ["std", "mldsa44", "mldsa65", "mldsa87"] simd128 = ["libcrux-sha3/simd128", "libcrux-intrinsics/simd128"] diff --git a/libcrux-ml-dsa/cg/code_gen.txt b/libcrux-ml-dsa/cg/code_gen.txt index d7cf61fde..962a0312c 100644 --- a/libcrux-ml-dsa/cg/code_gen.txt +++ b/libcrux-ml-dsa/cg/code_gen.txt @@ -3,4 +3,4 @@ Charon: 763350c6948d5594d3017ecb93273bc41c1a4e1d Eurydice: 36a5ed7dd6b61b5cd3d69a010859005912d21537 Karamel: bf9b89d76dd24e2ceaaca32de3535353e7b6bc01 F*: 4b3fc11774003a6ff7c09500ecb5f0145ca6d862 -Libcrux: f3f15359a852405a81628f982f2debc4d50fff30 +Libcrux: fe232ff426a7be69ae8434cdae0917405ac9d7a2 diff --git a/libcrux-ml-dsa/cg/header.txt b/libcrux-ml-dsa/cg/header.txt index 1f870d0c5..255ae0e51 100644 --- a/libcrux-ml-dsa/cg/header.txt +++ b/libcrux-ml-dsa/cg/header.txt @@ -8,5 +8,5 @@ * Eurydice: 36a5ed7dd6b61b5cd3d69a010859005912d21537 * Karamel: 
bf9b89d76dd24e2ceaaca32de3535353e7b6bc01 * F*: 4b3fc11774003a6ff7c09500ecb5f0145ca6d862 - * Libcrux: f3f15359a852405a81628f982f2debc4d50fff30 + * Libcrux: fe232ff426a7be69ae8434cdae0917405ac9d7a2 */ diff --git a/libcrux-ml-dsa/cg/libcrux_core.h b/libcrux-ml-dsa/cg/libcrux_core.h index 458ce9572..82085cc4f 100644 --- a/libcrux-ml-dsa/cg/libcrux_core.h +++ b/libcrux-ml-dsa/cg/libcrux_core.h @@ -8,7 +8,7 @@ * Eurydice: 36a5ed7dd6b61b5cd3d69a010859005912d21537 * Karamel: bf9b89d76dd24e2ceaaca32de3535353e7b6bc01 * F*: 4b3fc11774003a6ff7c09500ecb5f0145ca6d862 - * Libcrux: f3f15359a852405a81628f982f2debc4d50fff30 + * Libcrux: fe232ff426a7be69ae8434cdae0917405ac9d7a2 */ #ifndef __libcrux_core_H diff --git a/libcrux-ml-dsa/cg/libcrux_mldsa65_avx2.h b/libcrux-ml-dsa/cg/libcrux_mldsa65_avx2.h index 9c3acee02..0b950f3f8 100644 --- a/libcrux-ml-dsa/cg/libcrux_mldsa65_avx2.h +++ b/libcrux-ml-dsa/cg/libcrux_mldsa65_avx2.h @@ -8,7 +8,7 @@ * Eurydice: 36a5ed7dd6b61b5cd3d69a010859005912d21537 * Karamel: bf9b89d76dd24e2ceaaca32de3535353e7b6bc01 * F*: 4b3fc11774003a6ff7c09500ecb5f0145ca6d862 - * Libcrux: f3f15359a852405a81628f982f2debc4d50fff30 + * Libcrux: fe232ff426a7be69ae8434cdae0917405ac9d7a2 */ #ifndef __libcrux_mldsa65_avx2_H @@ -1301,6 +1301,37 @@ static KRML_MUSTINLINE void libcrux_ml_dsa_simd_avx2_gamma1_deserialize_22( gamma1_exponent); } +KRML_ATTRIBUTE_TARGET("avx2") +static KRML_MUSTINLINE __m128i +libcrux_ml_dsa_simd_avx2_encoding_commitment_serialize_4_normalized_serialize_4( + __m256i *simd_unit) { + __m256i adjacent_2_combined = libcrux_intrinsics_avx2_mm256_sllv_epi32( + simd_unit[0U], libcrux_intrinsics_avx2_mm256_set_epi32( + (int32_t)0, (int32_t)28, (int32_t)0, (int32_t)28, + (int32_t)0, (int32_t)28, (int32_t)0, (int32_t)28)); + __m256i adjacent_2_combined0 = libcrux_intrinsics_avx2_mm256_srli_epi64( + (int32_t)28, adjacent_2_combined, __m256i); + __m256i adjacent_4_combined = + libcrux_intrinsics_avx2_mm256_permutevar8x32_epi32( + adjacent_2_combined0, + 
libcrux_intrinsics_avx2_mm256_set_epi32( + (int32_t)0, (int32_t)0, (int32_t)0, (int32_t)0, (int32_t)6, + (int32_t)2, (int32_t)4, (int32_t)0)); + __m128i adjacent_4_combined0 = + libcrux_intrinsics_avx2_mm256_castsi256_si128(adjacent_4_combined); + return libcrux_intrinsics_avx2_mm_shuffle_epi8( + adjacent_4_combined0, libcrux_intrinsics_avx2_mm_set_epi8( + 240U, 240U, 240U, 240U, 240U, 240U, 240U, 240U, + 240U, 240U, 240U, 240U, 12U, 4U, 8U, 0U)); +} + +KRML_ATTRIBUTE_TARGET("avx2") +static KRML_MUSTINLINE __m128i +libcrux_ml_dsa_simd_avx2_encoding_commitment_serialize_4(__m256i *simd_unit) { + return libcrux_ml_dsa_simd_avx2_encoding_commitment_serialize_4_normalized_serialize_4( + simd_unit); +} + KRML_ATTRIBUTE_TARGET("avx2") static KRML_MUSTINLINE void libcrux_ml_dsa_simd_avx2_encoding_commitment_serialize(__m256i *simd_unit, @@ -1308,31 +1339,13 @@ libcrux_ml_dsa_simd_avx2_encoding_commitment_serialize(__m256i *simd_unit, uint8_t serialized[19U] = {0U}; switch ((uint8_t)Eurydice_slice_len(out, uint8_t)) { case 4U: { - __m256i adjacent_2_combined = libcrux_intrinsics_avx2_mm256_sllv_epi32( - simd_unit[0U], libcrux_intrinsics_avx2_mm256_set_epi32( - (int32_t)0, (int32_t)28, (int32_t)0, (int32_t)28, - (int32_t)0, (int32_t)28, (int32_t)0, (int32_t)28)); - __m256i adjacent_2_combined0 = libcrux_intrinsics_avx2_mm256_srli_epi64( - (int32_t)28, adjacent_2_combined, __m256i); - __m256i adjacent_4_combined = - libcrux_intrinsics_avx2_mm256_permutevar8x32_epi32( - adjacent_2_combined0, - libcrux_intrinsics_avx2_mm256_set_epi32( - (int32_t)0, (int32_t)0, (int32_t)0, (int32_t)0, (int32_t)6, - (int32_t)2, (int32_t)4, (int32_t)0)); - __m128i adjacent_4_combined0 = - libcrux_intrinsics_avx2_mm256_castsi256_si128(adjacent_4_combined); - __m128i adjacent_4_combined1 = libcrux_intrinsics_avx2_mm_shuffle_epi8( - adjacent_4_combined0, - libcrux_intrinsics_avx2_mm_set_epi8(240U, 240U, 240U, 240U, 240U, - 240U, 240U, 240U, 240U, 240U, - 240U, 240U, 12U, 4U, 8U, 0U)); + 
Eurydice_slice uu____0 = Eurydice_array_to_subslice2( + serialized, (size_t)0U, (size_t)16U, uint8_t); libcrux_intrinsics_avx2_mm_storeu_bytes_si128( - Eurydice_array_to_subslice2(serialized, (size_t)0U, (size_t)16U, - uint8_t), - adjacent_4_combined1); - Eurydice_slice uu____0 = out; - Eurydice_slice_copy(uu____0, + uu____0, + libcrux_ml_dsa_simd_avx2_encoding_commitment_serialize_4(simd_unit)); + Eurydice_slice uu____1 = out; + Eurydice_slice_copy(uu____1, Eurydice_array_to_subslice2(serialized, (size_t)0U, (size_t)4U, uint8_t), uint8_t); @@ -1379,8 +1392,8 @@ libcrux_ml_dsa_simd_avx2_encoding_commitment_serialize(__m256i *simd_unit, Eurydice_array_to_subslice2(serialized, (size_t)3U, (size_t)19U, uint8_t), upper_3); - Eurydice_slice uu____1 = out; - Eurydice_slice_copy(uu____1, + Eurydice_slice uu____2 = out; + Eurydice_slice_copy(uu____2, Eurydice_array_to_subslice2(serialized, (size_t)0U, (size_t)6U, uint8_t), uint8_t); diff --git a/libcrux-ml-dsa/cg/libcrux_mldsa65_portable.h b/libcrux-ml-dsa/cg/libcrux_mldsa65_portable.h index e3144029b..0022080bf 100644 --- a/libcrux-ml-dsa/cg/libcrux_mldsa65_portable.h +++ b/libcrux-ml-dsa/cg/libcrux_mldsa65_portable.h @@ -8,7 +8,7 @@ * Eurydice: 36a5ed7dd6b61b5cd3d69a010859005912d21537 * Karamel: bf9b89d76dd24e2ceaaca32de3535353e7b6bc01 * F*: 4b3fc11774003a6ff7c09500ecb5f0145ca6d862 - * Libcrux: f3f15359a852405a81628f982f2debc4d50fff30 + * Libcrux: fe232ff426a7be69ae8434cdae0917405ac9d7a2 */ #ifndef __libcrux_mldsa65_portable_H diff --git a/libcrux-ml-dsa/cg/libcrux_sha3_avx2.h b/libcrux-ml-dsa/cg/libcrux_sha3_avx2.h index 7a9eb8d8a..7e528e569 100644 --- a/libcrux-ml-dsa/cg/libcrux_sha3_avx2.h +++ b/libcrux-ml-dsa/cg/libcrux_sha3_avx2.h @@ -8,7 +8,7 @@ * Eurydice: 36a5ed7dd6b61b5cd3d69a010859005912d21537 * Karamel: bf9b89d76dd24e2ceaaca32de3535353e7b6bc01 * F*: 4b3fc11774003a6ff7c09500ecb5f0145ca6d862 - * Libcrux: f3f15359a852405a81628f982f2debc4d50fff30 + * Libcrux: fe232ff426a7be69ae8434cdae0917405ac9d7a2 */ #ifndef 
__libcrux_sha3_avx2_H diff --git a/libcrux-ml-dsa/cg/libcrux_sha3_portable.h b/libcrux-ml-dsa/cg/libcrux_sha3_portable.h index ccd69bef6..e8576ce72 100644 --- a/libcrux-ml-dsa/cg/libcrux_sha3_portable.h +++ b/libcrux-ml-dsa/cg/libcrux_sha3_portable.h @@ -8,7 +8,7 @@ * Eurydice: 36a5ed7dd6b61b5cd3d69a010859005912d21537 * Karamel: bf9b89d76dd24e2ceaaca32de3535353e7b6bc01 * F*: 4b3fc11774003a6ff7c09500ecb5f0145ca6d862 - * Libcrux: f3f15359a852405a81628f982f2debc4d50fff30 + * Libcrux: fe232ff426a7be69ae8434cdae0917405ac9d7a2 */ #ifndef __libcrux_sha3_portable_H diff --git a/libcrux-ml-dsa/hax.py b/libcrux-ml-dsa/hax.py deleted file mode 100755 index 0f1fc8ec4..000000000 --- a/libcrux-ml-dsa/hax.py +++ /dev/null @@ -1,173 +0,0 @@ -#! /usr/bin/env python3 - -import os -import argparse -import subprocess -import sys - - -def shell(command, expect=0, cwd=None, env={}): - subprocess_stdout = subprocess.DEVNULL - - print("Env:", env) - print("Command: ", end="") - for i, word in enumerate(command): - if i == 4: - print("'{}' ".format(word), end="") - else: - print("{} ".format(word), end="") - - print("\nDirectory: {}".format(cwd)) - - os_env = os.environ - os_env.update(env) - - ret = subprocess.run(command, cwd=cwd, env=os_env) - if ret.returncode != expect: - raise Exception("Error {}. 
Expected {}.".format(ret, expect)) - - -class extractAction(argparse.Action): - - def __call__(self, parser, args, values, option_string=None) -> None: - # Extract platform interfaces - include_str = "+:** -**::x86::init::cpuid -**::x86::init::cpuid_count" - interface_include = "+**" - cargo_hax_into = [ - "cargo", - "hax", - "into", - "-i", - include_str, - "fstar", - "--z3rlimit", - "80", - "--interfaces", - interface_include, - ] - hax_env = {} - shell( - cargo_hax_into, - cwd="../sys/platform", - env=hax_env, - ) - - # Extract intrinsics interfaces - include_str = "+:**" - interface_include = "+**" - cargo_hax_into = [ - "cargo", - "hax", - "-C", - "--features", - "simd128,simd256", - ";", - "into", - "-i", - include_str, - "fstar", - "--z3rlimit", - "80", - "--interfaces", - interface_include, - ] - hax_env = {} - shell( - cargo_hax_into, - cwd="../libcrux-intrinsics", - env=hax_env, - ) - - # Extract ml-dsa - includes = [ - "+**", - "-libcrux_ml_dsa::hash_functions::portable::*", - "-libcrux_ml_dsa::hash_functions::simd256::*", - "-libcrux_ml_dsa::hash_functions::neon::*", - "+:libcrux_ml_dsa::hash_functions::*::*", - "-**::types::non_hax_impls::**", - ] - include_str = " ".join(includes) - interface_include = "+** -libcrux_ml_dsa::simd::traits" - cargo_hax_into = [ - "cargo", - "hax", - "-C", - "--features", - "simd128,simd256", - ";", - "into", - "-i", - include_str, - "fstar", - "--z3rlimit", - "100", - "--interfaces", - interface_include, - ] - hax_env = {} - shell( - cargo_hax_into, - cwd=".", - env=hax_env, - ) - return None - - -class proveAction(argparse.Action): - - def __call__(self, parser, args, values, option_string=None) -> None: - admit_env = {} - if args.admit: - admit_env = {"OTHERFLAGS": "--admit_smt_queries true"} - shell(["make", "-C", "proofs/fstar/extraction/", "-j4"], env=admit_env) - return None - - -def parse_arguments(): - parser = argparse.ArgumentParser( - description="Libcrux prove script. 
" - + "Make sure to separate sub-command arguments with --." - ) - subparsers = parser.add_subparsers() - - extract_parser = subparsers.add_parser( - "extract", help="Extract the F* code for the proofs." - ) - extract_parser.add_argument("extract", nargs="*", action=extractAction) - - prover_parser = subparsers.add_parser( - "prove", - help=""" - Run F*. - - This typechecks the extracted code. - To lax-typecheck use --admit. - """, - ) - prover_parser.add_argument( - "--admit", - help="Admit all smt queries to lax typecheck.", - action="store_true", - ) - prover_parser.add_argument( - "prove", - nargs="*", - action=proveAction, - ) - - if len(sys.argv) == 1: - parser.print_help(sys.stderr) - sys.exit(1) - - return parser.parse_args() - - -def main(): - # Don't print unnecessary Python stack traces. - sys.tracebacklimit = 0 - parse_arguments() - - -if __name__ == "__main__": - main() diff --git a/libcrux-ml-dsa/hax.sh b/libcrux-ml-dsa/hax.sh new file mode 100755 index 000000000..b373b2698 --- /dev/null +++ b/libcrux-ml-dsa/hax.sh @@ -0,0 +1,157 @@ +#!/usr/bin/env bash +set -e + +function extract_all() { + extract sys/platform \ + into -i "+:** -**::x86::init::cpuid -**::x86::init::cpuid_count" \ + fstar --z3rlimit 80 --interfaces "+**" + + extract fstar-helpers/minicore into fstar + + extract libcrux-intrinsics \ + -C --features simd128,simd256 ";" \ + into --output-dir proofs/fstar/extraction/temp \ + fstar --z3rlimit 80 + + fixup-minicore libcrux-intrinsics + + extract libcrux-ml-dsa \ + -C --features simd128,simd256 ";" \ + into -i "+**" \ + -i "-libcrux_ml_dsa::hash_functions::portable::*" \ + -i "-libcrux_ml_dsa::hash_functions::simd256::*" \ + -i "-libcrux_ml_dsa::hash_functions::neon::*" \ + -i "+:libcrux_ml_dsa::hash_functions::*::*" \ + -i "-**::types::non_hax_impls::**" \ + --output-dir proofs/fstar/extraction/temp \ + fstar --z3rlimit 80 + + fixup-minicore libcrux-ml-dsa +} + +function prove() { + case "$1" in + --admit) + shift 1 + export 
OTHERFLAGS="--admit_smt_queries true";; + *);; + esac + go_to "libcrux-ml-dsa" + JOBS="${JOBS:-$(nproc --all)}" + JOBS="${JOBS:-4}" + make -C proofs/fstar/extraction -j $JOBS "$@" +} + +# `fixup-minicore CRATE` adjusts the F* extraction output of `CRATE` to use modules from `minicore` instead of `core`. +# This is necessary because our F* models of the Rust `core` library and the extracted code from `minicore` overlap, +# particularly for `core::arch::*`. The `minicore` versions offer more accurate and specialized models. +# +# This function scans all modules defined in `minicore`, and for each one, replaces every occurrence of `Core` +# in the extracted code of `CRATE` with `Minicore`, but only if that `Minicore` module actually exists in `minicore`. +function fixup-minicore() { + go_to fstar-helpers/minicore/proofs/fstar/extraction + # List all modules provided by minicore + minicore_modules=$(find . -type f -name '*Minicore*' -exec basename {} .fst ';') + + go_to "$1"/proofs/fstar/extraction/temp + + for minicore_module in $minicore_modules; do + core_module="${minicore_module//Minicore/Core}" + msg "$BLUE" "fixup-minicore '$core_module' -> '$minicore_module'" + find . -type f -exec perl -pi -e 's/\Q'"$core_module"'\E/'"$minicore_module"'/g' {} + + done + + find . -type f | while IFS= read -r file; do + if [ ! -e "../$file" ] || ! cmp -s "$file" "../$file"; then + rm -f "../$file" + cat "$file" > "../$file" + fi + done + + cd .. + + rm -f temp/* + rmdir temp +} + +function init_vars() { + SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + SCRIPT_NAME="$(basename "${BASH_SOURCE[0]}")" + SCRIPT_PATH="${SCRIPT_DIR}/${SCRIPT_NAME}" + + if [ -t 1 ]; then + BLUE='\033[34m' + GREEN='\033[32m' + BOLD='\033[1m' + RESET='\033[0m' + else + BLUE='' + GREEN='' + BOLD='' + RESET='' + fi +} + +function go_to() { + ROOT="$SCRIPT_DIR/.." 
+ cd "$ROOT" + cd "$1" +} + +function msg() { + echo -e "$1[$SCRIPT_NAME]$RESET $2" +} + +function extract() { + TARGET="$1" + shift 1 + + msg "$BLUE" "extract ${BOLD}$TARGET${RESET}" + go_to "$TARGET" + cargo hax "$@" || { + msg "$BOLD" "extraction failed for ${BOLD}$TARGET${RESET}" + exit 1 + } +} + +function help() { + echo "Libcrux script to extract Rust to F* via hax." + echo "" + echo "Usage: $0 [COMMAND]" + echo "" + echo "Commands:" + echo "" + grep '[#]>' "$SCRIPT_PATH" | sed 's/[)] #[>]/\t/g' + echo "" +} + +function cli() { + if [ -z "$1" ]; then + help + exit 1 + fi + # Check if an argument was provided + + case "$1" in + --help) #> Show help message + help;; + extract) #> Extract the F* code for the proofs. + extract_all + msg "$GREEN" "done" + ;; + prove) #> Run F*. This typechecks the extracted code. To lax-typecheck use --admit. + shift 1 + prove "$@";; + extract+prove) #> Equivalent to extracting and proving. + shift 1 + extract_all + prove "$@";; + *) + echo "Invalid option: $1" + help + exit 1;; + esac +} + +init_vars +cli "$@" diff --git a/libcrux-ml-dsa/src/simd/avx2/encoding/commitment.rs b/libcrux-ml-dsa/src/simd/avx2/encoding/commitment.rs index a373300e7..610a94a35 100644 --- a/libcrux-ml-dsa/src/simd/avx2/encoding/commitment.rs +++ b/libcrux-ml-dsa/src/simd/avx2/encoding/commitment.rs @@ -1,30 +1,52 @@ use libcrux_intrinsics::avx2::*; +#[cfg(hax)] +use minicore::abstractions::{bit::Bit, bitvec::BitVec}; + +#[hax_lib::ensures(|r| { + use hax_lib::*; + let r = BitVec::<128>::from(r); + let simd_unit = BitVec::<256>::from(*simd_unit); + forall(|i: u64| + implies(i < 32, r[i] == simd_unit[i / 4 * 32 + i % 4]) + ) +})] #[inline(always)] +fn serialize_4(simd_unit: &Vec256) -> Vec128 { + // The F* annotation normalizes the body of the function. After normalization, this function is a simple permutation of bits.
+ #[hax_lib::fstar::before( + r#"[@@(FStar.Tactics.postprocess_with ${BitVec::<128>::postprocess_rewrite})]"# + )] + #[inline(always)] + fn normalized_serialize_4(simd_unit: &Vec256) -> Vec128 { + let adjacent_2_combined = + mm256_sllv_epi32(*simd_unit, mm256_set_epi32(0, 28, 0, 28, 0, 28, 0, 28)); + let adjacent_2_combined = mm256_srli_epi64::<28>(adjacent_2_combined); + + let adjacent_4_combined = mm256_permutevar8x32_epi32( + adjacent_2_combined, + mm256_set_epi32(0, 0, 0, 0, 6, 2, 4, 0), + ); + let adjacent_4_combined = mm256_castsi256_si128(adjacent_4_combined); + let adjacent_4_combined = mm_shuffle_epi8( + adjacent_4_combined, + mm_set_epi8( + 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 12, 4, 8, 0, + ), + ); + adjacent_4_combined + } + normalized_serialize_4(simd_unit) +} + +#[inline(always)] +#[hax_lib::requires(out.len() == 4 || out.len() == 6)] pub(in crate::simd::avx2) fn serialize(simd_unit: &Vec256, out: &mut [u8]) { let mut serialized = [0u8; 19]; match out.len() as u8 { 4 => { - let adjacent_2_combined = - mm256_sllv_epi32(*simd_unit, mm256_set_epi32(0, 28, 0, 28, 0, 28, 0, 28)); - let adjacent_2_combined = mm256_srli_epi64::<28>(adjacent_2_combined); - - let adjacent_4_combined = mm256_permutevar8x32_epi32( - adjacent_2_combined, - mm256_set_epi32(0, 0, 0, 0, 6, 2, 4, 0), - ); - let adjacent_4_combined = mm256_castsi256_si128(adjacent_4_combined); - let adjacent_4_combined = mm_shuffle_epi8( - adjacent_4_combined, - mm_set_epi8( - 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 12, 4, - 8, 0, - ), - ); - - mm_storeu_bytes_si128(&mut serialized[0..16], adjacent_4_combined); - + mm_storeu_bytes_si128(&mut serialized[0..16], serialize_4(simd_unit)); out.copy_from_slice(&serialized[0..4]); } diff --git a/libcrux-ml-kem/hax.py b/libcrux-ml-kem/hax.py index 37c050bbe..4b29194d1 100755 --- a/libcrux-ml-kem/hax.py +++ b/libcrux-ml-kem/hax.py @@ -67,7 +67,9 @@ def __call__(self, parser, args, values, 
option_string=None) -> None: "--interfaces", interface_include, ] - hax_env = {} + hax_env = { + 'RUSTFLAGS': "--cfg pre_minicore" + } shell( cargo_hax_into, cwd="../libcrux-intrinsics",