diff --git a/regression/cbmc-java/virtual10/A.class b/regression/cbmc-java/virtual10/A.class new file mode 100644 index 00000000000..a5a66b250dc Binary files /dev/null and b/regression/cbmc-java/virtual10/A.class differ diff --git a/regression/cbmc-java/virtual10/B.class b/regression/cbmc-java/virtual10/B.class new file mode 100644 index 00000000000..7102b9334fd Binary files /dev/null and b/regression/cbmc-java/virtual10/B.class differ diff --git a/regression/cbmc-java/virtual10/C.class b/regression/cbmc-java/virtual10/C.class new file mode 100644 index 00000000000..6ee4ed02d0b Binary files /dev/null and b/regression/cbmc-java/virtual10/C.class differ diff --git a/regression/cbmc-java/virtual10/D.class b/regression/cbmc-java/virtual10/D.class new file mode 100644 index 00000000000..8d91a3e0dba Binary files /dev/null and b/regression/cbmc-java/virtual10/D.class differ diff --git a/regression/cbmc-java/virtual10/E.class b/regression/cbmc-java/virtual10/E.class new file mode 100644 index 00000000000..3f6d0cf934e Binary files /dev/null and b/regression/cbmc-java/virtual10/E.class differ diff --git a/regression/cbmc-java/virtual10/O.class b/regression/cbmc-java/virtual10/O.class new file mode 100644 index 00000000000..3cdd96b6a51 Binary files /dev/null and b/regression/cbmc-java/virtual10/O.class differ diff --git a/regression/cbmc-java/virtual10/test.desc b/regression/cbmc-java/virtual10/test.desc new file mode 100644 index 00000000000..df1d033f0a4 --- /dev/null +++ b/regression/cbmc-java/virtual10/test.desc @@ -0,0 +1,6 @@ +CORE +E.class +--show-goto-functions +IF.*"java::D" +IF.*"java::O" +IF.*"java::C" diff --git a/regression/cbmc-java/virtual10/test.java b/regression/cbmc-java/virtual10/test.java new file mode 100644 index 00000000000..cfca1db92ea --- /dev/null +++ b/regression/cbmc-java/virtual10/test.java @@ -0,0 +1,35 @@ +interface A { + public int f(); +} +interface B { + public int g(); +} + +class O { + public String toString() { + return "O"; + } +} + +class D extends O implements A, B { + public int f() { + return 0; + } + public int g() { + return 1; + } +} + +class C extends D { + public String toString() { + return "C"; + } +} + +class E { + C c; + D d; + String f(Object o) { + return o.toString(); + } +} diff --git a/regression/cbmc-java/virtual6/test.desc b/regression/cbmc-java/virtual6/test.desc index 5a5a7c7bf3b..9ee1b819301 100644 --- a/regression/cbmc-java/virtual6/test.desc +++ b/regression/cbmc-java/virtual6/test.desc @@ -3,5 +3,4 @@ A.class --function A.main --show-goto-functions ^EXIT=0$ ^SIGNAL=0$ -IF "java::C".*THEN GOTO -IF "java::B".*THEN GOTO +IF "java::A".*THEN GOTO diff --git a/regression/cbmc-java/virtual7/test.desc b/regression/cbmc-java/virtual7/test.desc index b35404ac657..d873bbcc8f9 100644 --- a/regression/cbmc-java/virtual7/test.desc +++ b/regression/cbmc-java/virtual7/test.desc @@ -3,9 +3,6 @@ test.class --show-goto-functions ^EXIT=0$ ^SIGNAL=0$ -IF "java::E".*THEN GOTO [12] -IF "java::B".*THEN GOTO [12] -IF "java::D".*THEN GOTO [12] -IF "java::C".*THEN GOTO [12] --- -IF "java::A".*THEN GOTO +IF.*"java::C".*THEN GOTO [12] +IF.*"java::D".*THEN GOTO [12] +IF.*"java::A".*THEN GOTO [12] diff --git a/src/goto-programs/remove_virtual_functions.cpp b/src/goto-programs/remove_virtual_functions.cpp index ba97916ff78..00a9c459b0d 100644 --- a/src/goto-programs/remove_virtual_functions.cpp +++ b/src/goto-programs/remove_virtual_functions.cpp @@ -57,7 +57,7 @@ class remove_virtual_functionst const symbol_exprt &, const irep_idt &, dispatch_table_entriest &, - std::set &visited, + dispatch_table_entries_mapt &, const function_call_resolvert &) const; exprt get_method(const irep_idt &class_id, const irep_idt &component_name) const; @@ -163,11 +163,18 @@ void remove_virtual_functionst::remove_virtual_function( newinst->source_location=vcall_source_loc; } + // get initial identifier for grouping + INVARIANT(!functions.empty(), "Function dispatch table cannot be empty."); + auto last_id = functions.back().symbol_expr.get_identifier(); + // record class_ids for disjunction + std::set class_ids; + std::map calls; // Note backwards iteration, to get the fallback candidate first. for(auto it=functions.crbegin(), itend=functions.crend(); it!=itend; ++it) { const auto &fun=*it; + class_ids.insert(fun.class_id); auto insertit=calls.insert( {fun.symbol_expr.get_identifier(), goto_programt::targett()}); @@ -209,15 +216,50 @@ void remove_virtual_functionst::remove_virtual_function( t3->make_goto(t_final, true_exprt()); } + // Emit target if end of dispatch table is reached or if the next element is + // dispatched to another function call. Assumes entries in the functions + // variable to be sorted for the identifier of the function to be called. + auto l_it = std::next(it); + bool next_emit_target = + (l_it == functions.crend()) || + l_it->symbol_expr.get_identifier() != fun.symbol_expr.get_identifier(); + + // The root function call is done via fall-through, so nothing to emit + // explicitly for this. + if(next_emit_target && fun.symbol_expr == last_function_symbol) + { + class_ids.clear(); + } + // If this calls the fallback function we just fall through. // Otherwise branch to the right call: if(fallback_action!=virtual_dispatch_fallback_actiont::CALL_LAST_FUNCTION || fun.symbol_expr!=last_function_symbol) { - exprt c_id1=constant_exprt(fun.class_id, string_typet()); - goto_programt::targett t4=new_code_gotos.add_instruction(); - t4->source_location=vcall_source_loc; - t4->make_goto(insertit.first->second, equal_exprt(c_id1, c_id2)); + // create a disjunction of class_ids to test + if(next_emit_target && fun.symbol_expr != last_function_symbol) + { + exprt::operandst or_ops; + for(const auto &id : class_ids) + { + const constant_exprt c_id1(id, string_typet()); + const equal_exprt class_id_test(c_id1, c_id2); + or_ops.push_back(class_id_test); + } + + goto_programt::targett t4 = new_code_gotos.add_instruction(); + t4->source_location = vcall_source_loc; + t4->make_goto(insertit.first->second, disjunction(or_ops)); + + last_id = fun.symbol_expr.get_identifier(); + class_ids.clear(); + } + // record class_id + else if(next_emit_target) + { + last_id = fun.symbol_expr.get_identifier(); + class_ids.clear(); + } } } @@ -252,11 +294,12 @@ void remove_virtual_functionst::remove_virtual_function( /// Used by get_functions to track the most-derived parent that provides an /// override of a given function. -/// \par parameters: `this_id`: class name -/// `last_method_defn`: the most-derived parent of `this_id` to define the -/// requested function -/// `component_name`: name of the function searched for -/// `resolve_function_call`: function to resolve abstract method call +/// \param parameters: `this_id`: class name +/// \param `last_method_defn`: the most-derived parent of `this_id` to define +/// the requested function +/// \param `component_name`: name of the function searched for +/// \param `entry_map`: map of class identifiers to dispatch table entries +/// \param `resolve_function_call`: function to resolve abstract method call /// \return `functions` is assigned a list of {class name, function symbol} /// pairs indicating that if `this` is of the given class, then the call will /// target the given function. Thus if A <: B <: C and A and C provide @@ -267,7 +310,7 @@ void remove_virtual_functionst::get_child_functions_rec( const symbol_exprt &last_method_defn, const irep_idt &component_name, dispatch_table_entriest &functions, - std::set &visited, + dispatch_table_entries_mapt &entry_map, const function_call_resolvert &resolve_function_call) const { auto findit=class_hierarchy.class_map.find(this_id); @@ -276,9 +319,18 @@ void remove_virtual_functionst::get_child_functions_rec( for(const auto &child : findit->second.children) { - if(!visited.insert(child).second) + // Skip if we have already visited this and we found a function call that + // did not resolve to non java.lang.Object. + auto it = entry_map.find(child); + if( + it != entry_map.end() && + !has_prefix( + id2string(it->second.symbol_expr.get_identifier()), + "java::java.lang.Object")) + { continue; - exprt method=get_method(child, component_name); + } + exprt method = get_method(child, component_name); dispatch_table_entryt function(child); if(method.is_not_nil()) { @@ -305,37 +357,43 @@ void remove_virtual_functionst::get_child_functions_rec( } } functions.push_back(function); + entry_map.insert({child, function}); get_child_functions_rec( child, function.symbol_expr, component_name, functions, - visited, + entry_map, resolve_function_call); } } +/// Used to get dispatch entries to call for the given function +/// \param function: function that should be called +/// \param[out] functions: is assigned a list of dispatch entries, i.e., pairs +/// of class names and function symbol to call when encountering the class. void remove_virtual_functionst::get_functions( const exprt &function, dispatch_table_entriest &functions) { + // class part of function to call const irep_idt class_id=function.get(ID_C_class); const std::string class_id_string(id2string(class_id)); - const irep_idt component_name=function.get(ID_component_name); - const std::string component_name_string(id2string(component_name)); + const irep_idt function_name = function.get(ID_component_name); + const std::string function_name_string(id2string(function_name)); INVARIANT(!class_id.empty(), "All virtual functions must have a class"); resolve_concrete_function_callt get_virtual_call_target( symbol_table, class_hierarchy); const function_call_resolvert resolve_function_call = [&get_virtual_call_target]( - const irep_idt &class_id, const irep_idt &component_name) { - return get_virtual_call_target(class_id, component_name); + const irep_idt &class_id, const irep_idt &function_name) { + return get_virtual_call_target(class_id, function_name); }; const resolve_concrete_function_callt::concrete_function_callt - &resolved_call = get_virtual_call_target(class_id, component_name); + &resolved_call = get_virtual_call_target(class_id, function_name); dispatch_table_entryt root_function; @@ -357,17 +415,37 @@ void remove_virtual_functionst::get_functions( } // iterate over all children, transitively - std::set visited; + dispatch_table_entries_mapt entry_map; get_child_functions_rec( class_id, root_function.symbol_expr, - component_name, + function_name, functions, - visited, + entry_map, resolve_function_call); if(root_function.symbol_expr!=symbol_exprt()) functions.push_back(root_function); + + // Sort for the identifier of the function call symbol expression, grouping + // together calls to the same function. Keep java.lang.Object entries at the + // end for fall through. The reasoning is that this is the case with most + // entries in realistic cases. + std::sort( + functions.begin(), + functions.end(), + [&root_function](const dispatch_table_entryt &a, dispatch_table_entryt &b) { + if( + has_prefix( + id2string(a.symbol_expr.get_identifier()), "java::java.lang.Object")) + return false; + else if( + has_prefix( + id2string(b.symbol_expr.get_identifier()), "java::java.lang.Object")) + return true; + else + return a.symbol_expr.get_identifier() < b.symbol_expr.get_identifier(); + }); } exprt remove_virtual_functionst::get_method( diff --git a/src/goto-programs/remove_virtual_functions.h b/src/goto-programs/remove_virtual_functions.h index bd5b396341a..fc06c3e8c78 100644 --- a/src/goto-programs/remove_virtual_functions.h +++ b/src/goto-programs/remove_virtual_functions.h @@ -56,6 +56,7 @@ class dispatch_table_entryt }; typedef std::vector dispatch_table_entriest; +typedef std::map dispatch_table_entries_mapt; void remove_virtual_function( goto_modelt &goto_model, diff --git a/unit/Makefile b/unit/Makefile index 0be65a71b80..91dc62345cf 100644 --- a/unit/Makefile +++ b/unit/Makefile @@ -43,6 +43,7 @@ SRC += unit_tests.cpp \ util/simplify_expr.cpp \ util/symbol_table.cpp \ catch_example.cpp \ + java_bytecode/java_virtual_functions/virtual_functions.cpp \ # Empty last line INCLUDES= -I ../src/ -I. diff --git a/unit/java_bytecode/java_virtual_functions/A.class b/unit/java_bytecode/java_virtual_functions/A.class new file mode 100644 index 00000000000..a5a66b250dc Binary files /dev/null and b/unit/java_bytecode/java_virtual_functions/A.class differ diff --git a/unit/java_bytecode/java_virtual_functions/B.class b/unit/java_bytecode/java_virtual_functions/B.class new file mode 100644 index 00000000000..7102b9334fd Binary files /dev/null and b/unit/java_bytecode/java_virtual_functions/B.class differ diff --git a/unit/java_bytecode/java_virtual_functions/C.class b/unit/java_bytecode/java_virtual_functions/C.class new file mode 100644 index 00000000000..6ee4ed02d0b Binary files /dev/null and b/unit/java_bytecode/java_virtual_functions/C.class differ diff --git a/unit/java_bytecode/java_virtual_functions/D.class b/unit/java_bytecode/java_virtual_functions/D.class new file mode 100644 index 00000000000..8d91a3e0dba Binary files /dev/null and b/unit/java_bytecode/java_virtual_functions/D.class differ diff --git a/unit/java_bytecode/java_virtual_functions/E.class b/unit/java_bytecode/java_virtual_functions/E.class new file mode 100644 index 00000000000..3f6d0cf934e Binary files /dev/null and b/unit/java_bytecode/java_virtual_functions/E.class differ diff --git a/unit/java_bytecode/java_virtual_functions/O.class b/unit/java_bytecode/java_virtual_functions/O.class new file mode 100644 index 00000000000..3cdd96b6a51 Binary files /dev/null and b/unit/java_bytecode/java_virtual_functions/O.class differ diff --git a/unit/java_bytecode/java_virtual_functions/test.java b/unit/java_bytecode/java_virtual_functions/test.java new file mode 100644 index 00000000000..cfca1db92ea --- /dev/null +++ b/unit/java_bytecode/java_virtual_functions/test.java @@ -0,0 +1,35 @@ +interface A { + public int f(); +} +interface B { + public int g(); +} + +class O { + public String toString() { + return "O"; + } +} + +class D extends O implements A, B { + public int f() { + return 0; + } + public int g() { + return 1; + } +} + +class C extends D { + public String toString() { + return "C"; + } +} + +class E { + C c; + D d; + String f(Object o) { + return o.toString(); + } +} diff --git a/unit/java_bytecode/java_virtual_functions/virtual_functions.cpp b/unit/java_bytecode/java_virtual_functions/virtual_functions.cpp new file mode 100644 index 00000000000..128c19b2e5b --- /dev/null +++ b/unit/java_bytecode/java_virtual_functions/virtual_functions.cpp @@ -0,0 +1,120 @@ +/*******************************************************************\ + + Module: Unit tests for java_types + + Author: DiffBlue Limited. All rights reserved. + +\*******************************************************************/ + +#include +#include +#include + +#include +#include +#include + +#include + +void check_function_call( + const equal_exprt &eq_expr, + const irep_idt &class_name, + const irep_idt &function_name, + const goto_programt::targetst &targets) +{ + REQUIRE(eq_expr.op0().id() == ID_constant); + REQUIRE(eq_expr.op0().type().id() == ID_string); + REQUIRE(to_constant_expr(eq_expr.op0()).get_value() == class_name); + + REQUIRE(targets.size() == 1); + + for(const auto &target : targets) + { + REQUIRE(target->type == goto_program_instruction_typet::FUNCTION_CALL); + const code_function_callt call = to_code_function_call(target->code); + REQUIRE(call.function().id() == ID_symbol); + REQUIRE(to_symbol_expr(call.function()).get_identifier() == function_name); + } +} + +SCENARIO( + "load class with virtual method call, resolve to all valid calls", + "[core][java_bytecode][virtual_functions]") +{ + config.set_arch("none"); + GIVEN("A class with a call to java.lang.Object.toString()") + { + const symbol_tablet &symbol_table = + load_java_class("E", "./java_bytecode/java_virtual_functions", "E.f"); + + const std::string function_name = + "java::E.f:(Ljava/lang/Object;)Ljava/lang/String;"; + + WHEN("The entry point function is generated") + { + symbol_tablet new_table(symbol_table); + null_message_handlert null_output; + goto_functionst new_goto_functions; + goto_convert(new_table, new_goto_functions, null_output); + remove_virtual_functions(new_table, new_goto_functions); + + bool found_function = false; + for(const auto &fun : new_goto_functions.function_map) + { + if(fun.first == function_name) + { + const goto_programt &goto_program = fun.second.body; + found_function = true; + for(const auto &instruction : goto_program.instructions) + { + // There should be two guarded GOTOs with non-constant guards. One + // branching for class C and one for class D or O. + if(instruction.type == goto_program_instruction_typet::GOTO) + { + if(instruction.guard.id() == ID_equal) + { + THEN("Class C should call its specific method") + { + const equal_exprt &eq_expr = to_equal_expr(instruction.guard); + check_function_call( + eq_expr, + "java::C", + "java::C.toString:()Ljava/lang/String;", + instruction.targets); + } + } + + else if(instruction.guard.id() == ID_or) + { + THEN("Classes D and O should both call O.toString()") + { + const or_exprt &disjunction = to_or_expr(instruction.guard); + REQUIRE( + (disjunction.op0().id() == ID_equal && + disjunction.op1().id() == ID_equal)); + const equal_exprt &eq_expr0 = + to_equal_expr(disjunction.op0()); + const equal_exprt &eq_expr1 = + to_equal_expr(disjunction.op1()); + + check_function_call( + eq_expr0, + "java::D", + "java::O.toString:()Ljava/lang/String;", + instruction.targets); + check_function_call( + eq_expr1, + "java::O", + "java::O.toString:()Ljava/lang/String;", + instruction.targets); + } + } + } + } + } + } + + REQUIRE(found_function); + } + } +} diff --git a/unit/pointer-analysis/custom_value_set_analysis.cpp b/unit/pointer-analysis/custom_value_set_analysis.cpp index db657133a51..8d632ff6a60 100644 --- a/unit/pointer-analysis/custom_value_set_analysis.cpp +++ b/unit/pointer-analysis/custom_value_set_analysis.cpp @@ -168,12 +168,14 @@ SCENARIO("test_value_set_analysis", { GIVEN("Normal and custom value-set analysis of CustomVSATest::test") { + config.set_arch("none"); + config.main = ""; null_message_handlert null_output; cmdlinet command_line; // This classpath is the default, but the config object // is global and previous unit tests may have altered it - command_line.set("java-cp-include-files", "."); + command_line.set("java-cp-include-files", "CustomVSATest.class"); config.java.classpath={"."}; command_line.args.push_back("pointer-analysis/CustomVSATest.jar");