diff --git a/CHANGELOG.md b/CHANGELOG.md index 81c3e2b34..b677f0720 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,7 +28,7 @@ Changes to hax-lib: - Lean lib: add new setup for `bv_decide` (#1828) - Lean lib: base specs on mathematical integers (#1829) - Lean lib: represent `usize` via a copy of `UInt64` (#1829) - - Lean lib: Add support for while loops (#1857) + - Lean lib: Add support for while loops (#1857, #1863) Changes to the Lean backend: - Support for constants with arbitrary computation (#1738) diff --git a/hax-lib/proof-libs/lean/Hax/Lib.lean b/hax-lib/proof-libs/lean/Hax/Lib.lean index 99ff6533c..ec28b95fe 100644 --- a/hax-lib/proof-libs/lean/Hax/Lib.lean +++ b/hax-lib/proof-libs/lean/Hax/Lib.lean @@ -333,19 +333,19 @@ macro "declare_comparison_specs" s:(&"signed" <|> &"unsigned") typeName:ident wi mvcgen [ne]; rw [← @Bool.coe_iff_coe]; simp [x.toInt_inj] @[spec] - def lt_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ lt x y ⦃ ⇓ r => ⌜ r = (x.toInt < y.toInt) ⌝ ⦄ := by + def lt_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ lt x y ⦃ ⇓ r => ⌜ r = decide (x.toInt < y.toInt) ⌝ ⦄ := by mvcgen [lt]; simp [x.lt_iff_toInt_lt] @[spec] - def le_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ le x y ⦃ ⇓ r => ⌜ r = (x.toInt ≤ y.toInt) ⌝ ⦄ := by + def le_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ le x y ⦃ ⇓ r => ⌜ r = decide (x.toInt ≤ y.toInt) ⌝ ⦄ := by mvcgen [le]; simp [x.le_iff_toInt_le] @[spec] - def gt_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ gt x y ⦃ ⇓ r => ⌜ r = (x.toInt > y.toInt ) ⌝ ⦄ := by + def gt_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ gt x y ⦃ ⇓ r => ⌜ r = decide (x.toInt > y.toInt ) ⌝ ⦄ := by mvcgen [gt]; simp [y.lt_iff_toInt_lt] @[spec] - def ge_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ ge x y ⦃ ⇓ r => ⌜ r = (x.toInt ≥ y.toInt) ⌝ ⦄ := by + def ge_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ ge x y ⦃ ⇓ r => ⌜ r = decide (x.toInt ≥ y.toInt) ⌝ ⦄ := by mvcgen [ge]; simp [y.le_iff_toInt_le] end $typeName @@ -362,20 +362,20 @@ macro "declare_comparison_specs" s:(&"signed" <|> &"unsigned") typeName:ident wi mvcgen [ne]; rw [← @Bool.coe_iff_coe]; simp [x.toNat_inj] @[spec] - def lt_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ lt x y ⦃ ⇓ r => ⌜ r = (x.toNat < y.toNat) ⌝ ⦄ := by - mvcgen [lt]; simp [x.lt_iff_toNat_lt] + def lt_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ lt x y ⦃ ⇓ r => ⌜ r = decide (x.toNat < y.toNat) ⌝ ⦄ := by + mvcgen [lt] @[spec] - def le_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ le x y ⦃ ⇓ r => ⌜ r = (x.toNat ≤ y.toNat) ⌝ ⦄ := by - mvcgen [le]; simp [x.le_iff_toNat_le] + def le_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ le x y ⦃ ⇓ r => ⌜ r = decide (x.toNat ≤ y.toNat) ⌝ ⦄ := by + mvcgen [le] @[spec] - def gt_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ gt x y ⦃ ⇓ r => ⌜ r = (x.toNat > y.toNat ) ⌝ ⦄ := by - mvcgen [gt]; simp [y.lt_iff_toNat_lt] + def gt_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ gt x y ⦃ ⇓ r => ⌜ r = decide (x.toNat > y.toNat ) ⌝ ⦄ := by + mvcgen [gt] @[spec] - def ge_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ ge x y ⦃ ⇓ r => ⌜ r = (x.toNat ≥ y.toNat) ⌝ ⦄ := by - mvcgen [ge]; simp [y.le_iff_toNat_le] + def ge_spec (x y : $typeName) : ⦃ ⌜ True ⌝ ⦄ ge x y ⦃ ⇓ r => ⌜ r = decide (x.toNat ≥ y.toNat) ⌝ ⦄ := by + mvcgen [ge] end $typeName ) @@ -747,12 +747,15 @@ def Rust_primitives.Hax.while_loop {β : Type} (init : β) (body : β -> RustM β) (pureInv: - {i : β -> Prop // ∀ b, ⦃⌜ True ⌝⦄ inv b ⦃⇓ r => ⌜ (i b) = r ⌝⦄} := by + {i : β -> Prop // ∀ b, ⦃⌜ True ⌝⦄ inv b ⦃⇓ r => ⌜ r = (i b) ⌝⦄} := by constructor; intro; mvcgen) (pureTermination : - {t : β -> Nat // ∀ b, ⦃⌜ True ⌝⦄ termination b ⦃⇓ r => ⌜ Int.ofNat (t b) = r ⌝⦄} := by + {t : β -> Nat // ∀ b, ⦃⌜ True ⌝⦄ termination b ⦃⇓ r => ⌜ r = Int.ofNat (t b) ⌝⦄} := by + constructor; intro; mvcgen) + (pureCond : + {c : β -> Bool // ∀ b, ⦃⌜ pureInv.val b ⌝⦄ cond b ⦃⇓ r => ⌜ r = c b ⌝⦄} := by constructor; intro; mvcgen) : RustM β := - Loop.MonoLoopCombinator.while_loop Loop.mk cond init body + Loop.MonoLoopCombinator.while_loop Loop.mk pureCond.val init body @[spec] theorem Rust_primitives.Hax.while_loop.spec {β : Type} @@ -761,25 +764,19 @@ theorem Rust_primitives.Hax.while_loop.spec {β : Type} (termination: β → RustM Hax_lib.Int.Int) (init : β) (body : β -> RustM β) - (pureInv: {i : β -> Prop // ∀ b, ⦃⌜ True ⌝⦄ inv b ⦃⇓ r => ⌜ (i b) = r ⌝⦄}) + (pureInv: {i : β -> Prop // ∀ b, ⦃⌜ True ⌝⦄ inv b ⦃⇓ r => ⌜ r = (i b) ⌝⦄}) (pureTermination : - {t : β -> Nat // ∀ b, ⦃⌜ True ⌝⦄ termination b ⦃⇓ r => ⌜ Int.ofNat (t b) = r ⌝⦄}) - (step : ∀ (b : β), - ⦃⌜pureInv.val b⌝⦄ - do - if ← cond b - then ForInStep.yield (← body b) - else ForInStep.done b - ⦃⇓ r => - match r with - | ForInStep.yield b' => - spred(⌜ pureTermination.val b' < pureTermination.val b ⌝ ∧ ⌜ pureInv.val b' ⌝) - | ForInStep.done b' => - ⌜ pureInv.val b' ⌝⦄) : + {t : β -> Nat // ∀ b, ⦃⌜ True ⌝⦄ termination b ⦃⇓ r => ⌜ r = Int.ofNat (t b) ⌝⦄}) + (pureCond : {c : β -> Bool // ∀ b, ⦃⌜ pureInv.val b ⌝⦄ cond b ⦃⇓ r => ⌜ r = c b ⌝⦄}) + (step : + ∀ (b : β), pureCond.val b → + ⦃⌜ pureInv.val b ⌝⦄ + body b + ⦃⇓ b' => spred(⌜ pureTermination.val b' < pureTermination.val b ⌝ ∧ ⌜ pureInv.val b' ⌝)⦄ ) : ⦃⌜ pureInv.val init ⌝⦄ - while_loop inv cond termination init body pureInv pureTermination - ⦃⇓ r => ⌜ pureInv.val r ⌝⦄ := - Spec.MonoLoopCombinator.while_loop init Loop.mk cond body pureInv pureTermination step + while_loop inv cond termination init body pureInv pureTermination pureCond + ⦃⇓ r => ⌜ pureInv.val r ∧ ¬ pureCond.val r ⌝⦄ := + Spec.MonoLoopCombinator.while_loop init Loop.mk pureCond.val body pureInv pureTermination step end Loop /- diff --git a/hax-lib/proof-libs/lean/Hax/MissingLean/Init/While.lean b/hax-lib/proof-libs/lean/Hax/MissingLean/Init/While.lean index e13f951b8..dadd508eb 100644 --- a/hax-lib/proof-libs/lean/Hax/MissingLean/Init/While.lean +++ b/hax-lib/proof-libs/lean/Hax/MissingLean/Init/While.lean @@ -36,12 +36,12 @@ def Loop.MonoLoopCombinator.forIn {β : Type u} {m : Type u → Type v} [Monad m def Loop.MonoLoopCombinator.while_loop {m} {ps : PostShape} {β: Type} [Monad m] [∀ α, Order.CCPO (m α)] [WPMonad m ps] (loop : Loop) - (cond: β → m Bool) + (cond: β → Bool) (init : β) (body : β -> m β) [∀ f : Unit → β → m (ForInStep β), Loop.MonoLoopCombinator f] : m β := Loop.MonoLoopCombinator.forIn loop init fun () s => do - if ← cond s then + if cond s then let s ← body s pure (.yield s) else diff --git a/hax-lib/proof-libs/lean/Hax/MissingLean/Std/Do/Triple/Basic.lean b/hax-lib/proof-libs/lean/Hax/MissingLean/Std/Do/Triple/Basic.lean new file mode 100644 index 000000000..11d6e887b --- /dev/null +++ b/hax-lib/proof-libs/lean/Hax/MissingLean/Std/Do/Triple/Basic.lean @@ -0,0 +1,17 @@ +import Std.Do.Triple.Basic + +namespace Std.Do + +theorem Triple.of_entails_left {m} {ps : PostShape} {β: Type} [Monad m] [WPMonad m ps] + (P Q : Assertion ps) (R : PostCond β ps) (x : m β) (hPR : ⦃P⦄ x ⦃R⦄) (hPQ : Q ⊢ₛ P) : ⦃Q⦄ x ⦃R⦄ := + SPred.entails.trans hPQ hPR + +theorem Triple.of_entails_right {m} {ps : PostShape} {β: Type} [Monad m] [WPMonad m ps] + (P : Assertion ps) (Q R : PostCond β ps) (x : m β) (hPR : ⦃P⦄ x ⦃Q⦄) (hPQ : Q ⊢ₚ R) : ⦃P⦄ x ⦃R⦄ := + SPred.entails.trans hPR (PredTrans.mono _ _ _ hPQ) + +theorem Triple.map {m} {ps : PostShape} {α β} [Monad m] [WPMonad m ps] (f : α → β) + (x : m α) (P : Assertion ps) (Q : PostCond β ps) : + ⦃P⦄ (f <$> x) ⦃Q⦄ ↔ ⦃P⦄ x ⦃(fun a => Q.fst (f a), Q.snd)⦄ := by rw [Triple, WP.map]; rfl + +end Std.Do diff --git a/hax-lib/proof-libs/lean/Hax/MissingLean/Std/Do/Triple/SpecLemmas.lean b/hax-lib/proof-libs/lean/Hax/MissingLean/Std/Do/Triple/SpecLemmas.lean index 80874ab8c..80696f57c 100644 --- a/hax-lib/proof-libs/lean/Hax/MissingLean/Std/Do/Triple/SpecLemmas.lean +++ b/hax-lib/proof-libs/lean/Hax/MissingLean/Std/Do/Triple/SpecLemmas.lean @@ -1,4 +1,5 @@ import Std.Do.Triple.Basic +import Hax.MissingLean.Std.Do.Triple.Basic import Hax.MissingLean.Init.While import Hax.MissingLean.Std.Do.PostCond @@ -11,26 +12,27 @@ theorem Spec.forIn_monoLoopCombinator {m} {ps : PostShape} {β: Type} (loop : Loop) (init : β) (f : Unit → β → m (ForInStep β)) [Loop.MonoLoopCombinator f] - (inv : PostCond β ps) + (inv : β → Prop) (termination : β -> Nat) + (post : β → Prop) (step : ∀ b, - ⦃inv.1 b⦄ + ⦃⌜ inv b ⌝⦄ f () b - ⦃(fun r => match r with - | .yield b' => spred(⌜ termination b' < termination b ⌝ ∧ inv.1 b') - | .done b' => inv.1 b', inv.2)⦄) : - ⦃inv.1 init⦄ Loop.MonoLoopCombinator.forIn loop init f ⦃(fun b => inv.1 b, inv.2)⦄ := by + ⦃⇓ r => match r with + | .yield b' => spred(⌜ termination b' < termination b ⌝ ∧ ⌜ inv b' ⌝) + | .done b' => ⌜ post b' ⌝⦄) : + ⦃⌜ inv init ⌝⦄ Loop.MonoLoopCombinator.forIn loop init f ⦃⇓ b => ⌜ post b ⌝⦄ := by unfold Loop.MonoLoopCombinator.forIn Loop.MonoLoopCombinator.forIn.loop Loop.loopCombinator apply Triple.bind · apply step · rintro (b | b) · refine Triple.pure b ?_ - exact SPred.entails.refl (inv.fst b) + exact SPred.entails.refl _ · apply SPred.imp_elim apply SPred.pure_elim' intro h rw [SPred.entails_true_intro] - apply Spec.forIn_monoLoopCombinator loop _ f inv termination step + apply Spec.forIn_monoLoopCombinator loop _ f inv termination post step termination_by termination init decreasing_by exact h @@ -40,25 +42,20 @@ theorem Spec.MonoLoopCombinator.while_loop {m} {ps : PostShape} {β: Type} [∀ f : Unit → β → m (ForInStep β), Loop.MonoLoopCombinator f] (init : β) (loop : Loop) - (cond: β → m Bool) + (cond: β → Bool) (body : β → m β) (inv: β → Prop) (termination : β → Nat) (step : - ∀ (b : β), + ∀ (b : β), cond b → ⦃⌜ inv b ⌝⦄ - do - if ← cond b - then return ForInStep.yield (← body b) - else return ForInStep.done b - ⦃(fun r => - match r with - | ForInStep.yield b' => spred(⌜ termination b' < termination b ⌝ ∧ ⌜ inv b' ⌝) - | ForInStep.done b' => ⌜ inv b' ⌝, - ExceptConds.false)⦄ ) : + body b + ⦃⇓ b' => spred(⌜ termination b' < termination b ⌝ ∧ ⌜ inv b' ⌝)⦄ ) : ⦃⌜ inv init ⌝⦄ Loop.MonoLoopCombinator.while_loop loop cond init body - ⦃⇓ b => ⌜ inv b ⌝⦄ := by + ⦃⇓ b => ⌜ inv b ∧ ¬ cond b ⌝⦄ := by apply Spec.forIn_monoLoopCombinator - (inv := (fun b => ⌜ inv b ⌝ , ExceptConds.false)) - (step := step) + intro b + by_cases hb : cond b + · simpa [hb, Triple.map] using step b hb + · simp [hb, Triple.pure]