diff --git a/engine/lib/phases/phase_functionalize_loops.ml b/engine/lib/phases/phase_functionalize_loops.ml index a583cbc6e..2c57940cf 100644 --- a/engine/lib/phases/phase_functionalize_loops.ml +++ b/engine/lib/phases/phase_functionalize_loops.ml @@ -39,40 +39,52 @@ struct include Features.SUBTYPE.Id end - type body_and_invariant = { + type loop_annotation_kind = + | LoopInvariant of { index_pat : B.pat option; invariant : B.expr } + | LoopVariant of B.expr + + type loop_annotation = { body : B.expr; - invariant : (B.pat * B.expr) option; + annotation : loop_annotation_kind option; } - let extract_loop_invariant (body : B.expr) : body_and_invariant = - match body.e with - | Let + let extract_loop_annotation (body : B.expr) : loop_annotation = + let rhs_body = + let* (e_let : UB.D.expr_Let) = UB.D.expr_Let body in + let*? _ = Option.is_none e_let.monadic in + let* _ = UB.D.pat_PWild e_let.lhs in + let* app = UB.D.expr_App e_let.rhs in + let* f = UB.D.expr_GlobalVar app.f in + Some (f, app.args, e_let.body) + in + match rhs_body with + | Some + ( f, + [ { e = Closure { params = [ pat ]; body = invariant; _ }; _ } ], + body ) + when Global_ident.eq_name Hax_lib___internal_loop_invariant f -> { - monadic = None; - lhs = { p = PWild; _ }; - rhs = - { - e = - App - { - f = { e = GlobalVar f; _ }; - args = - [ - { - e = - Closure { params = [ pat ]; body = invariant; _ }; - _; - }; - ]; - _; - }; - _; - }; body; + annotation = + Some (LoopInvariant { index_pat = Some pat; invariant }); } - when Global_ident.eq_name Hax_lib___internal_loop_invariant f -> - { body; invariant = Some (pat, invariant) } - | _ -> { body; invariant = None } + | Some (f, [ invariant ], body) + when Global_ident.eq_name Hax_lib___internal_while_loop_invariant f -> + { + body; + annotation = Some (LoopInvariant { index_pat = None; invariant }); + } + | Some (f, [ invariant ], body) + when Global_ident.eq_name Hax_lib___internal_loop_decreases f -> + { body; annotation = Some (LoopVariant invariant) } + | _ -> { body; annotation = None } + + let expect_invariant_variant (annotation1 : loop_annotation_kind option) + (annotation2 : loop_annotation_kind option) : + loop_annotation_kind option * loop_annotation_kind option = + match annotation1 with + | Some (LoopVariant _) -> (annotation2, annotation1) + | _ -> (annotation1, annotation2) type iterator = | Range of { start : B.expr; end_ : B.expr } @@ -144,6 +156,16 @@ struct | None -> Rust_primitives__hax__folds__fold_enumerated_chunked_slice in Some (fold_op, [ size; slice ], usize) + | ChunksExact { size; slice } -> + let fold_op = + match cf with + | Some BreakOrReturn -> + Rust_primitives__hax__folds__fold_chunked_slice_return + | Some BreakOnly -> + Rust_primitives__hax__folds__fold_chunked_slice_cf + | None -> Rust_primitives__hax__folds__fold_chunked_slice + in + Some (fold_op, [ size; slice ], usize) | Enumerate (Slice slice) -> let fold_op = match cf with @@ -206,7 +228,7 @@ struct (M.pat_PWild ~span ~typ:unit.typ, unit) in let body = dexpr body in - let { body; invariant } = extract_loop_invariant body in + let { body; annotation } = extract_loop_annotation body in let it = dexpr it in let pat = dpat pat in let fn : B.expr = UB.make_closure [ bpat; pat ] body body.span in @@ -220,7 +242,13 @@ struct let pat = MS.pat_PWild ~typ in (pat, MS.expr_Literal ~typ:TBool (Bool true)) in - let pat, invariant = Option.value ~default invariant in + let pat, invariant = + match annotation with + | Some (LoopInvariant { index_pat = Some pat; invariant }) + -> + (pat, invariant) + | _ -> default + in UB.make_closure [ bpat; pat ] invariant invariant.span in (f, args @ [ invariant; init; fn ]) @@ -259,6 +287,39 @@ struct (M.pat_PWild ~span ~typ:unit.typ, unit) in let body = dexpr body in + let { body; annotation = annotation1 } = + extract_loop_annotation body + in + let { body; annotation = annotation2 } = + extract_loop_annotation body + in + let invariant, variant = + expect_invariant_variant annotation1 annotation2 + in + let invariant = + match invariant with + | Some (LoopInvariant { index_pat = None; invariant }) -> invariant + | _ -> MS.expr_Literal ~typ:TBool (Bool true) + in + let variant = + match variant with + | Some (LoopVariant variant) -> variant + | _ -> + let kind = { size = S32; signedness = Unsigned } in + let e = + UB.M.expr_Literal ~typ:(TInt kind) ~span:body.span + (Int { value = "0"; negative = false; kind }) + in + UB.call Rust_primitives__hax__int__from_machine [ e ] e.span + (TApp + { + ident = + `Concrete + (Concrete_ident.of_name ~value:false + Hax_lib__int__Int); + args = []; + }) + in let condition = dexpr condition in let condition : B.expr = M.expr_Closure ~params:[ bpat ] ~body:condition ~captures:[] @@ -276,8 +337,13 @@ struct | Some (BreakOnly, _) -> Rust_primitives__hax__while_loop_cf | None -> Rust_primitives__hax__while_loop in - UB.call fold_operator [ condition; init; body ] span - (dty span expr.typ) + let invariant : B.expr = + UB.make_closure [ bpat ] invariant invariant.span + in + let variant = UB.make_closure [ bpat ] variant variant.span in + UB.call fold_operator + [ condition; invariant; variant; init; body ] + span (dty span expr.typ) | Loop { state = None; _ } -> Error.unimplemented ~issue_id:405 ~details:"Loop without mutation" span diff --git a/engine/names/src/lib.rs b/engine/names/src/lib.rs index d4ef63838..d7e7f05d2 100644 --- a/engine/names/src/lib.rs +++ b/engine/names/src/lib.rs @@ -30,6 +30,8 @@ fn dummy_hax_concrete_ident_wrapper>(x: I, mu assert_eq!(1, 1); hax_lib::assert!(true); hax_lib::_internal_loop_invariant(|_: usize| true); + hax_lib::_internal_while_loop_invariant(hax_lib::Prop::from(true)); + hax_lib::_internal_loop_decreases(hax_lib::Int::_unsafe_from_str("0")); fn props() { use hax_lib::prop::*; @@ -222,6 +224,9 @@ mod hax { fn fold_enumerated_chunked_slice() {} fn fold_enumerated_chunked_slice_cf() {} fn fold_enumerated_chunked_slice_return() {} + fn fold_chunked_slice() {} + fn fold_chunked_slice_cf() {} + fn fold_chunked_slice_return() {} fn fold_cf() {} fn fold_return() {} } diff --git a/hax-lib/macros/src/dummy.rs b/hax-lib/macros/src/dummy.rs index 470d9e9b4..dab22dc70 100644 --- a/hax-lib/macros/src/dummy.rs +++ b/hax-lib/macros/src/dummy.rs @@ -188,3 +188,8 @@ pub fn trait_fn_decoration(_attr: TokenStream, _item: TokenStream) -> TokenStrea pub fn loop_invariant(_predicate: TokenStream) -> TokenStream { quote! {}.into() } + +#[proc_macro] +pub fn loop_decreases(_predicate: TokenStream) -> TokenStream { + quote! {}.into() +} diff --git a/hax-lib/macros/src/implementation.rs b/hax-lib/macros/src/implementation.rs index 7eec0184b..0b157e31d 100644 --- a/hax-lib/macros/src/implementation.rs +++ b/hax-lib/macros/src/implementation.rs @@ -63,11 +63,20 @@ pub fn fstar_options(attr: pm::TokenStream, item: pm::TokenStream) -> pm::TokenS /// `coq`...) are in scope. #[proc_macro] pub fn loop_invariant(predicate: pm::TokenStream) -> pm::TokenStream { - let predicate: TokenStream = predicate.into(); + let predicate2: TokenStream = predicate.clone().into(); + let predicate_expr: syn::Expr = parse_macro_input!(predicate); + + let (invariant_f, predicate) = match predicate_expr { + syn::Expr::Closure(_) => (quote!(hax_lib::_internal_loop_invariant), predicate2), + _ => ( + quote!(hax_lib::_internal_while_loop_invariant), + quote!(::hax_lib::Prop::from(#predicate2)), + ), + }; let ts: pm::TokenStream = quote! { #[cfg(#HaxCfgOptionName)] { - hax_lib::_internal_loop_invariant({ + #invariant_f({ #HaxQuantifiers #predicate }) @@ -77,6 +86,28 @@ pub fn loop_invariant(predicate: pm::TokenStream) -> pm::TokenStream { ts } +/// Must be used to prove termination of while loops. This takes an +/// expression that should be a usize that decreases at every iteration +/// +/// This function must be called just after `loop_invariant`, or at the first +/// line of the loop if there is no invariant. +#[proc_macro] +pub fn loop_decreases(predicate: pm::TokenStream) -> pm::TokenStream { + let predicate: TokenStream = predicate.into(); + let ts: pm::TokenStream = quote! { + #[cfg(#HaxCfgOptionName)] + { + hax_lib::_internal_loop_decreases({ + #HaxQuantifiers + use ::hax_lib::int::ToInt; + (#predicate).to_int() + }) + } + } + .into(); + ts +} + /// When extracting to F*, inform about what is the current /// verification status for an item. It can either be `lax` or /// `panic_free`. diff --git a/hax-lib/src/dummy.rs b/hax-lib/src/dummy.rs index fec739180..27f58241d 100644 --- a/hax-lib/src/dummy.rs +++ b/hax-lib/src/dummy.rs @@ -44,7 +44,13 @@ pub fn inline_unsafe(_: &str) -> T { } #[doc(hidden)] -pub fn _internal_loop_invariant, P: FnOnce(T) -> R>(_: P) {} +pub const fn _internal_loop_invariant, P: FnOnce(T) -> R>(_: &P) {} + +#[doc(hidden)] +pub const fn _internal_while_loop_invariant(_: Prop) {} + +#[doc(hidden)] +pub const fn _internal_loop_decreases(_: int::Int) {} pub trait Refinement { type InnerType; diff --git a/hax-lib/src/implementation.rs b/hax-lib/src/implementation.rs index 17394252f..2ad484eca 100644 --- a/hax-lib/src/implementation.rs +++ b/hax-lib/src/implementation.rs @@ -143,6 +143,14 @@ pub fn any_to_unit(_: T) -> () { #[doc(hidden)] pub fn _internal_loop_invariant, P: FnOnce(T) -> R>(_: P) {} +/// A dummy function that holds a while loop invariant. +#[doc(hidden)] +pub const fn _internal_while_loop_invariant(_: Prop) {} + +/// A dummy function that holds a loop variant. +#[doc(hidden)] +pub fn _internal_loop_decreases(_: Int) {} + /// A type that implements `Refinement` should be a newtype for a /// type `T`. The field holding the value of type `T` should be /// private, and `Refinement` should be the only interface to the diff --git a/hax-lib/src/proc_macros.rs b/hax-lib/src/proc_macros.rs index 2d4cb0328..51d140602 100644 --- a/hax-lib/src/proc_macros.rs +++ b/hax-lib/src/proc_macros.rs @@ -2,8 +2,9 @@ //! proc-macro crate cannot export anything but procedural macros. pub use hax_lib_macros::{ - attributes, decreases, ensures, exclude, impl_fn_decoration, include, lemma, loop_invariant, - opaque, opaque_type, refinement_type, requires, trait_fn_decoration, transparent, + attributes, decreases, ensures, exclude, impl_fn_decoration, include, lemma, loop_decreases, + loop_invariant, opaque, opaque_type, refinement_type, requires, trait_fn_decoration, + transparent, }; pub use hax_lib_macros::{ diff --git a/hax-lib/src/prop.rs b/hax-lib/src/prop.rs index 8d3ee3d93..599b640b9 100644 --- a/hax-lib/src/prop.rs +++ b/hax-lib/src/prop.rs @@ -9,7 +9,7 @@ pub struct Prop(bool); /// Hax rewrite more elaborated versions (see `forall` or `AndBit` below) to those monomorphic constructors. pub mod constructors { use super::Prop; - pub fn from_bool(b: bool) -> Prop { + pub const fn from_bool(b: bool) -> Prop { Prop(b) } pub fn and(lhs: Prop, other: Prop) -> Prop { @@ -46,7 +46,7 @@ pub mod constructors { impl Prop { /// Lifts a boolean to a logical proposition. - pub fn from_bool(b: bool) -> Self { + pub const fn from_bool(b: bool) -> Self { constructors::from_bool(b) } /// Conjuction of two propositions. diff --git a/proof-libs/fstar/core/Core.Ops.Arith.fsti b/proof-libs/fstar/core/Core.Ops.Arith.fsti index 9d4071fa0..1feb5123e 100644 --- a/proof-libs/fstar/core/Core.Ops.Arith.fsti +++ b/proof-libs/fstar/core/Core.Ops.Arith.fsti @@ -1,60 +1,60 @@ module Core.Ops.Arith open Rust_primitives - +open Hax_lib.Prop class t_Add self rhs = { [@@@ Tactics.Typeclasses.no_method] f_Output: Type; - f_add_pre: self -> rhs -> bool; - f_add_post: self -> rhs -> f_Output -> bool; + f_add_pre: self -> rhs -> t_Prop; + f_add_post: self -> rhs -> f_Output -> t_Prop; f_add: x:self -> y:rhs -> Pure f_Output (f_add_pre x y) (fun r -> f_add_post x y r); } class t_Sub self rhs = { [@@@ Tactics.Typeclasses.no_method] f_Output: Type; - f_sub_pre: self -> rhs -> bool; - f_sub_post: self -> rhs -> f_Output -> bool; + f_sub_pre: self -> rhs -> t_Prop; + f_sub_post: self -> rhs -> f_Output -> t_Prop; f_sub: x:self -> y:rhs -> Pure f_Output (f_sub_pre x y) (fun r -> f_sub_post x y r); } class t_Mul self rhs = { [@@@ Tactics.Typeclasses.no_method] f_Output: Type; - f_mul_pre: self -> rhs -> bool; - f_mul_post: self -> rhs -> f_Output -> bool; + f_mul_pre: self -> rhs -> t_Prop; + f_mul_post: self -> rhs -> f_Output -> t_Prop; f_mul: x:self -> y:rhs -> Pure f_Output (f_mul_pre x y) (fun r -> f_mul_post x y r); } class t_Div self rhs = { [@@@ Tactics.Typeclasses.no_method] f_Output: Type; - f_div_pre: self -> rhs -> bool; - f_div_post: self -> rhs -> f_Output -> bool; + f_div_pre: self -> rhs -> t_Prop; + f_div_post: self -> rhs -> f_Output -> t_Prop; f_div: x:self -> y:rhs -> Pure f_Output (f_div_pre x y) (fun r -> f_div_post x y r); } class t_AddAssign self rhs = { - f_add_assign_pre: self -> rhs -> bool; - f_add_assign_post: self -> rhs -> self -> bool; + f_add_assign_pre: self -> rhs -> t_Prop; + f_add_assign_post: self -> rhs -> self -> t_Prop; f_add_assign: x:self -> y:rhs -> Pure self (f_add_assign_pre x y) (fun r -> f_add_assign_post x y r); } class t_SubAssign self rhs = { - f_sub_assign_pre: self -> rhs -> bool; - f_sub_assign_post: self -> rhs -> self -> bool; + f_sub_assign_pre: self -> rhs -> t_Prop; + f_sub_assign_post: self -> rhs -> self -> t_Prop; f_sub_assign: x:self -> y:rhs -> Pure self (f_sub_assign_pre x y) (fun r -> f_sub_assign_post x y r); } class t_MulAssign self rhs = { - f_mul_assign_pre: self -> rhs -> bool; - f_mul_assign_post: self -> rhs -> self -> bool; + f_mul_assign_pre: self -> rhs -> t_Prop; + f_mul_assign_post: self -> rhs -> self -> t_Prop; f_mul_assign: x:self -> y:rhs -> Pure self (f_mul_assign_pre x y) (fun r -> f_mul_assign_post x y r); } class t_DivAssign self rhs = { - f_div_assign_pre: self -> rhs -> bool; - f_div_assign_post: self -> rhs -> self -> bool; + f_div_assign_pre: self -> rhs -> t_Prop; + f_div_assign_post: self -> rhs -> self -> t_Prop; f_div_assign: x:self -> y:rhs -> Pure self (f_div_assign_pre x y) (fun r -> f_div_assign_post x y r); } diff --git a/proof-libs/fstar/rust_primitives/Rust_primitives.Hax.Folds.fsti b/proof-libs/fstar/rust_primitives/Rust_primitives.Hax.Folds.fsti index 1b2602965..e5faa46ad 100644 --- a/proof-libs/fstar/rust_primitives/Rust_primitives.Hax.Folds.fsti +++ b/proof-libs/fstar/rust_primitives/Rust_primitives.Hax.Folds.fsti @@ -37,6 +37,26 @@ val fold_enumerated_chunked_slice ) : result: acc_t {inv result (mk_int (Seq.length s / v chunk_size))} +/// Fold function that is generated for `for` loops iterating on +/// `s.chunks_exact(chunk_size)`-like iterators +val fold_chunked_slice + (#t: Type0) (#acc_t: Type0) + (chunk_size: usize {v chunk_size > 0}) + (s: t_Slice t) + (inv: acc_t -> (i:usize) -> Type0) + (init: acc_t {inv init (sz 0)}) + (f: ( acc:acc_t + -> item:(t_Slice t) { + length item == chunk_size /\ + inv acc (sz 0) + } + -> acc':acc_t { + inv acc' (sz 0) + } + ) + ) + : result: acc_t {inv result (mk_int 0)} + (**** `s.enumerate()` *) /// Fold function that is generated for `for` loops iterating on /// `s.enumerate()`-like iterators diff --git a/proof-libs/fstar/rust_primitives/Rust_primitives.Hax.fst b/proof-libs/fstar/rust_primitives/Rust_primitives.Hax.fst index cec5e6303..bdaa678d6 100644 --- a/proof-libs/fstar/rust_primitives/Rust_primitives.Hax.fst +++ b/proof-libs/fstar/rust_primitives/Rust_primitives.Hax.fst @@ -63,7 +63,22 @@ class iterator_return (self: Type u#0): Type u#1 = { parent_iterator: Core.Iter.Traits.Iterator.t_Iterator self; f_fold_return: #b:Type0 -> s:self -> b -> (b -> i:parent_iterator.f_Item{parent_iterator.f_contains s i} -> Core.Ops.Control_flow.t_ControlFlow b b) -> Core.Ops.Control_flow.t_ControlFlow b b; } -let rec while_loop #s (condition: s -> bool) (init: s) (f: (i:s -> o:s{o << i})): s - = if condition init - then while_loop #s condition (f init) f - else init +let while_loop #acc_t + (condition: acc_t -> bool) + (inv: acc_t -> Type0) + (fuel: (a:acc_t{inv a} -> nat)) + (init: acc_t {inv init}) + (f: (i:acc_t{condition i /\ inv i} -> o:acc_t{inv o /\ fuel o < fuel i})): + (res: acc_t {inv res /\ not (condition res)}) + = + let rec while_loop_internal + (current: acc_t {inv current}): + Tot (res: acc_t {inv res /\ not (condition res)}) (decreases (fuel current)) + = if condition current + then + let next = f current in + assert (fuel next < fuel current); + while_loop_internal next + else current in + while_loop_internal init + diff --git a/test-harness/src/snapshots/toolchain__loops into-fstar.snap b/test-harness/src/snapshots/toolchain__loops into-fstar.snap index 4847191ef..99ff3df1e 100644 --- a/test-harness/src/snapshots/toolchain__loops into-fstar.snap +++ b/test-harness/src/snapshots/toolchain__loops into-fstar.snap @@ -208,6 +208,12 @@ let bigger_power_2_ (x: i32) : i32 = Rust_primitives.Hax.while_loop_cf (fun pow -> let pow:i32 = pow in pow <. mk_i32 1000000 <: bool) + (fun pow -> + let pow:i32 = pow in + true) + (fun pow -> + let pow:i32 = pow in + Rust_primitives.Hax.Int.from_machine (mk_u32 0) <: Hax_lib.Int.t_Int) pow (fun pow -> let pow:i32 = pow in @@ -842,6 +848,52 @@ let f (_: Prims.unit) : u8 = Rust_primitives.Hax.while_loop (fun x -> let x:u8 = x in x <. mk_u8 10 <: bool) + (fun x -> + let x:u8 = x in + true) + (fun x -> + let x:u8 = x in + Rust_primitives.Hax.Int.from_machine (mk_u32 0) <: Hax_lib.Int.t_Int) + x + (fun x -> + let x:u8 = x in + let x:u8 = x +! mk_u8 3 in + x) + in + x +! mk_u8 12 + +let while_invariant_decr (_: Prims.unit) : u8 = + let x:u8 = mk_u8 0 in + let x:u8 = + Rust_primitives.Hax.while_loop (fun x -> + let x:u8 = x in + x <. mk_u8 10 <: bool) + (fun x -> + let x:u8 = x in + b2t (x <=. mk_u8 10 <: bool)) + (fun x -> + let x:u8 = x in + Rust_primitives.Hax.Int.from_machine (mk_u8 10 -! x <: u8) <: Hax_lib.Int.t_Int) + x + (fun x -> + let x:u8 = x in + let x:u8 = x +! mk_u8 3 in + x) + in + x +! mk_u8 12 + +let while_invariant_decr_rev (_: Prims.unit) : u8 = + let x:u8 = mk_u8 0 in + let x:u8 = + Rust_primitives.Hax.while_loop (fun x -> + let x:u8 = x in + x <. mk_u8 10 <: bool) + (fun x -> + let x:u8 = x in + b2t (x <=. mk_u8 10 <: bool)) + (fun x -> + let x:u8 = x in + Rust_primitives.Hax.Int.from_machine (mk_u8 10 -! x <: u8) <: Hax_lib.Int.t_Int) x (fun x -> let x:u8 = x in diff --git a/tests/loops/src/lib.rs b/tests/loops/src/lib.rs index 143e7b86a..53f77ce4e 100644 --- a/tests/loops/src/lib.rs +++ b/tests/loops/src/lib.rs @@ -138,6 +138,24 @@ mod while_loops { } x + 12 } + fn while_invariant_decr() -> u8 { + let mut x = 0; + while x < 10 { + hax_lib::loop_invariant!(x <= 10); + hax_lib::loop_decreases!(10 - x); + x = x + 3; + } + x + 12 + } + fn while_invariant_decr_rev() -> u8 { + let mut x = 0; + while x < 10 { + hax_lib::loop_decreases!(10 - x); + hax_lib::loop_invariant!(x <= 10); + x = x + 3; + } + x + 12 + } } mod control_flow {