diff --git a/rust-engine/src/ast/resugared.rs b/rust-engine/src/ast/resugared.rs index 673b74bb5..86a0e0f7c 100644 --- a/rust-engine/src/ast/resugared.rs +++ b/rust-engine/src/ast/resugared.rs @@ -13,13 +13,32 @@ use hax_rust_engine_macros::*; +use super::*; + /// Resugared variants for items. This represent extra printing-only items, see [`super::ItemKind::Resugared`]. #[derive_group_for_ast] pub enum ResugaredItemKind {} /// Resugared variants for expressions. This represent extra printing-only expressions, see [`super::ExprKind::Resugared`]. #[derive_group_for_ast] -pub enum ResugaredExprKind {} +pub enum ResugaredExprKind { + /// Binary operations (identified by resugaring) of the form `f(e1, e2)` + BinOp { + /// The identifier of the operation (`f`) + op: GlobalId, + /// The left-hand side of the operation (`e1`) + lhs: Expr, + /// The right-hand side of the operation (`e2`) + rhs: Expr, + /// The generic arguments applied to the function. + generic_args: Vec, + /// If the function requires generic bounds to be called, `bounds_impls` + /// is a vector of impl. expressions for those bounds. + bounds_impls: Vec, + /// If we apply an associated function, contains the impl. expr used. + trait_: Option<(ImplExpr, Vec)>, + }, +} /// Resugared variants for patterns. This represent extra printing-only patterns, see [`super::PatKind::Resugared`]. #[derive_group_for_ast] @@ -60,10 +79,10 @@ macro_rules! derive_from { } derive_from!( - ResugaredItemKind => super::ItemKind, - ResugaredExprKind => super::ExprKind, - ResugaredPatKind => super::PatKind, - ResugaredTyKind => super::TyKind, - ResugaredImplItemKind => super::ImplItemKind, - ResugaredTraitItemKind => super::TraitItemKind + ResugaredItemKind => ItemKind, + ResugaredExprKind => ExprKind, + ResugaredPatKind => PatKind, + ResugaredTyKind => TyKind, + ResugaredImplItemKind => ImplItemKind, + ResugaredTraitItemKind => TraitItemKind ); diff --git a/rust-engine/src/backends/lean.rs b/rust-engine/src/backends/lean.rs index 3785c3cd7..31a6d2d61 100644 --- a/rust-engine/src/backends/lean.rs +++ b/rust-engine/src/backends/lean.rs @@ -5,6 +5,12 @@ //! source maps). use super::prelude::*; +use crate::resugarings::BinOp; + +mod binops { + pub use crate::names::rust_primitives::hax::machine_int::{add, div, mul, rem, shr, sub}; + pub use crate::names::rust_primitives::hax::{logical_op_and, logical_op_or}; +} /// The Lean printer #[derive(Default)] @@ -13,7 +19,16 @@ impl_doc_allocator_for!(LeanPrinter); impl Printer for LeanPrinter { fn resugaring_phases() -> Vec> { - vec![] + vec![Box::new(BinOp::new(&[ + binops::add(), + binops::sub(), + binops::mul(), + binops::rem(), + binops::div(), + binops::shr(), + binops::logical_op_and(), + binops::logical_op_or(), + ]))] } const NAME: &str = "Lean"; @@ -230,6 +245,41 @@ set_option linter.unusedVariables false .parens() .group() .nest(INDENT), + ExprKind::Resugared(resugared_expr_kind) => match resugared_expr_kind { + ResugaredExprKind::BinOp { + op, + lhs, + rhs, + generic_args: _, + bounds_impls: _, + trait_: _, + } => { + let symbol = if op == &binops::add() { + "+?" + } else if op == &binops::sub() { + "-?" + } else if op == &binops::mul() { + "*?" + } else if op == &binops::div() { + "/?" + } else if op == &binops::rem() { + "%?" + } else if op == &binops::shr() { + ">>>?" + } else if op == &binops::logical_op_and() { + "&&" + } else if op == &binops::logical_op_or() { + "||" + } else { + unreachable!() + }; + // This monad lifting should be handled by a phase/resugaring, see + // https://github.com/cryspen/hax/issues/1620 + docs!["← ", lhs, softline!(), symbol, softline!(), rhs] + .group() + .parens() + } + }, _ => todo!(), } } diff --git a/rust-engine/src/lib.rs b/rust-engine/src/lib.rs index b4c927544..66f3af47b 100644 --- a/rust-engine/src/lib.rs +++ b/rust-engine/src/lib.rs @@ -15,4 +15,5 @@ pub mod names; pub mod ocaml_engine; pub mod phase; pub mod printer; +pub mod resugarings; pub mod symbol; diff --git a/rust-engine/src/resugarings.rs b/rust-engine/src/resugarings.rs new file mode 100644 index 000000000..fc77b903c --- /dev/null +++ b/rust-engine/src/resugarings.rs @@ -0,0 +1,67 @@ +//! The "resugaring" phases used by printers. + +//! This module defines resugarings instances (see +//! [`hax_rust_engine::ast::Resugaring`] for the definition of a +//! resugaring). Each backend defines its own set of resugaring phases. + +use crate::ast::identifiers::global_id::DefId; +use crate::ast::resugared::*; +use crate::ast::visitors::*; +use crate::ast::*; +use crate::printer::*; +use std::collections::HashSet; + +/// Binop resugaring. Used to identify expressions of the form `(f e1 e2)` where +/// `f` is a known identifier. +pub struct BinOp { + /// Stores a set of identifiers that should be resugared as binary + /// operations. Usually, those identifiers come from the hax encoding. Each + /// backend can select its own set of identifiers Typically, if the backend + /// has a special support for addition, `known_ops` will contain + /// `hax::machine::int::add` + pub known_ops: HashSet, +} + +impl BinOp { + /// Adds a new binary operation from a list of (hax-introduced) names + pub fn new(known_ops: &[DefId]) -> Self { + Self { + known_ops: HashSet::from_iter(known_ops.iter().cloned()), + } + } +} + +impl AstVisitorMut for BinOp { + fn enter_expr_kind(&mut self, x: &mut ExprKind) { + let ExprKind::App { + head, + args, + generic_args, + bounds_impls, + trait_, + }: &mut ExprKind = x + else { + return; + }; + let ExprKind::GlobalId(id) = &*head.kind else { + return; + }; + let [lhs, rhs] = &args[..] else { return }; + if self.known_ops.iter().any(|defid| id == defid) { + *x = ExprKind::Resugared(ResugaredExprKind::BinOp { + op: id.clone(), + lhs: lhs.clone(), + rhs: rhs.clone(), + generic_args: generic_args.clone(), + bounds_impls: bounds_impls.clone(), + trait_: trait_.clone(), + }); + } + } +} + +impl Resugaring for BinOp { + fn name(&self) -> String { + "binop".to_string() + } +} diff --git a/test-harness/src/snapshots/toolchain__lean-tests into-lean.snap b/test-harness/src/snapshots/toolchain__lean-tests into-lean.snap index 3d43536c2..165def3f8 100644 --- a/test-harness/src/snapshots/toolchain__lean-tests into-lean.snap +++ b/test-harness/src/snapshots/toolchain__lean-tests into-lean.snap @@ -47,28 +47,23 @@ def FORTYTWO : USize := 42 def MINUS_FORTYTWO : ISize := -42 def returns42 (_ : hax_Tuple0) : Result USize := do FORTYTWO -def add_two_numbers (x : USize) (y : USize) : Result USize := do - (← hax_machine_int_add x y) +def add_two_numbers (x : USize) (y : USize) : Result USize := do (← x +? y) def letBinding (x : USize) (y : USize) : Result USize := do let (useless : hax_Tuple0) ← pure (constr_hax_Tuple0); - let (result1 : USize) ← pure (← hax_machine_int_add x y); - let (result2 : USize) ← pure (← hax_machine_int_add result1 2); - (← hax_machine_int_add result2 1) + let (result1 : USize) ← pure (← x +? y); + let (result2 : USize) ← pure (← result1 +? 2); + (← result2 +? 1) def closure (_ : hax_Tuple0) : Result Int32 := do let (x : Int32) ← pure 41; - let (f1 : Int32 -> Result Int32) ← pure - (fun (y : Int32) => do (← hax_machine_int_add y x)); + let (f1 : Int32 -> Result Int32) ← pure (fun (y : Int32) => do (← y +? x)); let (f2 : Int32 -> Int32 -> Result Int32) ← pure - (fun (y : Int32) (z : Int32) => do (← hax_machine_int_add - (← hax_machine_int_add y x) - z)); - (← hax_machine_int_add - (← ops_function_Fn_call f1 (constr_hax_Tuple1 (hax_Tuple1_Tuple0 := 1))) - (← ops_function_Fn_call - f2 - (constr_hax_Tuple2 (hax_Tuple2_Tuple0 := 2) (hax_Tuple2_Tuple1 := 3)))) + (fun (y : Int32) (z : Int32) => do (← (← y +? x) +? z)); + (← (← ops_function_Fn_call f1 (constr_hax_Tuple1 (hax_Tuple1_Tuple0 := 1))) +? + (← ops_function_Fn_call + f2 + (constr_hax_Tuple2 (hax_Tuple2_Tuple0 := 2) (hax_Tuple2_Tuple1 := 3)))) @[spec] @@ -79,4 +74,7 @@ def test_before_verbatime_single_line (x : UInt8) : Result UInt8 := do 42 def multiline : Unit := () def test_before_verbatim_multi_line (x : UInt8) : Result UInt8 := do 32 + +def binop_resugarings (x : UInt32) : Result UInt32 := do + (← (← (← x +? 1) -? (← (← (← 2 *? 3) %? 4) /? 5)) >>>? 1) ''' diff --git a/tests/lean-tests/src/lib.rs b/tests/lean-tests/src/lib.rs index 60b994067..52380234e 100644 --- a/tests/lean-tests/src/lib.rs +++ b/tests/lean-tests/src/lib.rs @@ -39,3 +39,10 @@ def multiline : Unit := () fn test_before_verbatim_multi_line(x: u8) -> u8 { 32 } + + +// BinOp Resugarings + +fn binop_resugarings(x:u32) -> u32 { + x + 1 - 2 * 3 % 4 / 5 >> 1 +}