77#include < gsl/gsl>
88#include < memory>
99#include < vector>
10+ #include < fstream>
1011
1112#include " core/common/common.h"
1213#include " core/framework/tensorprotoutils.h"
1314#include " core/framework/tensor_type_and_shape.h"
1415#include " core/framework/onnxruntime_typeinfo.h"
1516#include " core/session/onnxruntime_cxx_api.h"
17+ #include " core/graph/ep_api_types.h"
18+ #include " core/graph/graph_proto_serializer.h"
1619
1720#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL
1821#include " core/providers/utils/ort_graph_to_proto.h"
@@ -31,6 +34,7 @@ namespace test {
3134// forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent
3235// to a graph represented by the internal ORT GraphViewer class.
3336static void CheckGraphCApi (const GraphViewer& graph_viewer, const OrtGraph& api_graph);
37+ static void Check_Graph_GetSubgraph (const OrtGraph& api_graph);
3438
3539//
3640// Tests
@@ -73,6 +77,16 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) {
7377 CheckGraphCApi (test_graph->GetGraphViewer (), test_graph->GetOrtGraph ());
7478}
7579
80+ TEST (EpGraphTest, Check3LayerNestedSubgraphV2) {
81+ // The overall structure of this model is similar to the one used in "Check3LayerNestedSubgraph" test.
82+ // The model consists of a graph with subgraphs nested across three levels.
83+ // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer).
84+ auto test_graph = TestGraph::Load (ORT_TSTR (" testdata/three_layer_nested_subgraph_v2.onnx" ));
85+ ASSERT_NE (test_graph, nullptr ) << " Failed to load test model" ;
86+
87+ CheckGraphCApi (test_graph->GetGraphViewer (), test_graph->GetOrtGraph ());
88+ }
89+
7690static void RunMNISTModel (const ORTCHAR_T* model_path, std::vector<float >& output_data) {
7791 auto memory_info = Ort::MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeCPU);
7892 Ort::SessionOptions sess_options;
@@ -474,6 +488,48 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span<const
474488 }
475489}
476490
491+ // Checks the Graph_GetSubgraph C API
492+ static void Check_Graph_GetSubgraph (const OrtGraph& api_graph) {
493+ const OrtApi& ort_api = Ort::GetApi ();
494+
495+ // Get all the nodes
496+ size_t num_nodes = 0 ;
497+ ASSERT_ORTSTATUS_OK (ort_api.Graph_GetNumNodes (&api_graph, &num_nodes));
498+
499+ std::vector<const OrtNode*> nodes (num_nodes);
500+ ASSERT_ORTSTATUS_OK (ort_api.Graph_GetNodes (&api_graph, nodes.data (), nodes.size ()));
501+
502+ // Select a half of nodes to create a OrtGraph
503+ size_t num_selected_nodes = std::max ((nodes.size () >> 1 ), (size_t )1 );
504+ std::vector<const OrtNode*> selected_nodes (num_selected_nodes);
505+
506+ for (size_t i = 0 ; i < num_selected_nodes; i++) {
507+ selected_nodes[i] = nodes[i];
508+ }
509+
510+ OrtGraph* sub_graph;
511+ ASSERT_ORTSTATUS_OK (ort_api.Graph_GetGraphView (&api_graph, selected_nodes.data (), selected_nodes.size (), &sub_graph));
512+
513+ // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk.
514+ // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw.
515+ const GraphViewer& sub_graph_viewer = EpGraph::ToInternal (sub_graph)->GetGraphViewer ();
516+ std::unique_ptr<Model> model = std::make_unique<Model>(sub_graph_viewer.Name (), true , sub_graph_viewer.GetGraph ().GetLogger ());
517+ auto model_proto = std::make_unique<ONNX_NAMESPACE::ModelProto>(model->ToProto ());
518+ GraphViewerToProto (sub_graph_viewer, *model_proto->mutable_graph (), true , true , static_cast <ExecutionOrder>(1 ));
519+ model_proto->set_ir_version (ONNX_NAMESPACE::Version::IR_VERSION);
520+
521+ const char * graph_name = nullptr ;
522+ ASSERT_ORTSTATUS_OK (ort_api.Graph_GetName (&api_graph, &graph_name));
523+ std::string name = graph_name;
524+ name += " _half.onnx" ;
525+
526+ // Dump the graph for debugging
527+ // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary);
528+ // model_proto->SerializeToOstream(&dump);
529+
530+ ort_api.ReleaseGraph (sub_graph);
531+ }
532+
477533// Checks that the contents of the original GraphViewer matches the contents of the OrtGraph.
478534// Uses the public C APIs to traverse the OrtGraph.
479535static void CheckGraphCApi (const GraphViewer& graph_viewer, const OrtGraph& api_graph) {
@@ -682,6 +738,9 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_
682738 }
683739 }
684740 }
741+
742+ // Check creating an OrtGraph from a subset of nodes in an OrtGraph
743+ Check_Graph_GetSubgraph (api_graph);
685744}
686745
687746} // namespace test
0 commit comments