diff --git a/graph/utils/graph_utils.cc b/graph/utils/graph_utils.cc index 879db5386c3ced6041c1a2288c49f7e53ce5a5cc..d6979c17376af2c1dccbb6b0c14798583a4d0ac3 100644 --- a/graph/utils/graph_utils.cc +++ b/graph/utils/graph_utils.cc @@ -1775,6 +1775,67 @@ graphStatus ReplaceControlAnchors(const NodePtr &new_node, const NodePtr &old_no return GRAPH_SUCCESS; } + +// check refdata in subgraph is ref from inner data +graphStatus CheckIsRefFromInnerData(const OutDataAnchorPtr &out_data_anchor, NodePtr &inner_data, + bool &is_ref_from_innerdata) { + is_ref_from_innerdata = false; + const auto owner_node = out_data_anchor->GetOwnerNode(); + if (owner_node->GetType() != REFDATA) { + return GRAPH_SUCCESS; + } + GE_ASSERT_NOTNULL(owner_node->GetOwnerComputeGraph()); + if (owner_node->GetOwnerComputeGraph()->GetParentNode() == nullptr) { + return GRAPH_SUCCESS; + } + const auto &peer_out_control_anchor = owner_node->GetInControlAnchor()->GetPeerOutControlAnchors(); + if (peer_out_control_anchor.empty() || peer_out_control_anchor.size() > 1u) { + GELOGE(GRAPH_FAILED, "Invalid graph. Refdata[%s] in subgraph[%s] should has one control edge from inner data.", + owner_node->GetNamePtr(), owner_node->GetOwnerComputeGraph()->GetName().c_str()); + return GRAPH_FAILED; + } + const auto peer_in_ctrl_node = peer_out_control_anchor.at(0U)->GetOwnerNode(); + GE_ASSERT_NOTNULL(peer_in_ctrl_node); + + if (!OpTypeUtils::IsSubgraphInnerData(peer_in_ctrl_node->GetOpDesc())) { + GELOGE(GRAPH_FAILED, + "Invalid graph. Refdata[%s] in subgraph[%s] should has one control edge from inner data, but current is " + "[%s][%s].", + owner_node->GetNamePtr(), owner_node->GetOwnerComputeGraph()->GetName().c_str(), + peer_in_ctrl_node->GetNamePtr(), peer_in_ctrl_node->GetTypePtr()); + return GRAPH_FAILED; + } + inner_data = peer_in_ctrl_node; + is_ref_from_innerdata = true; + return GRAPH_SUCCESS; +} + +graphStatus CheckIsRefFromRefData(const OutDataAnchorPtr &out_data_anchor, NodePtr &ref_data, + bool &is_ref_from_refdata) { + is_ref_from_refdata = false; + const auto owner_node = out_data_anchor->GetOwnerNode(); + const auto out_desc = owner_node->GetOpDesc()->GetOutputDescPtr(static_cast(out_data_anchor->GetIdx())); + GE_ASSERT_NOTNULL(out_desc); + std::string ref_var_src_var_name; + bool has_ref_attr = ge::AttrUtils::GetStr(out_desc, REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); + if (!has_ref_attr) { + return GRAPH_SUCCESS; + } + // find src ref_data_node + const auto &ower_graph = owner_node->GetOwnerComputeGraph(); + GE_ASSERT_NOTNULL(ower_graph); + const auto ref_data_node = ower_graph->FindNode(ref_var_src_var_name); + if (ref_data_node == nullptr) { + GELOGW("Can not find refdata named %s. Please check ref relation on graph.", ref_var_src_var_name.c_str()); + return GRAPH_SUCCESS; + } + if (ref_data_node->GetType() != REFDATA) { + return GRAPH_SUCCESS; + } + ref_data = ref_data_node; + is_ref_from_refdata = true; + return GRAPH_SUCCESS;; +} } // namespace GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::IsolateNode(const NodePtr &node, @@ -2685,10 +2746,11 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, } NodePtr ref_node; - const bool is_ref_from_refdata = IsRefFromRefData(out_data_anchor, ref_node); + bool is_ref_from_refdata = false; + GE_ASSERT_GRAPH_SUCCESS(CheckIsRefFromOther(out_data_anchor, ref_node, is_ref_from_refdata)); NodeIndexIO exist_ref_data_info(ref_node, 0U, kOut); if (is_ref_from_refdata && (anchor_to_symbol.find(exist_ref_data_info.ToString()) != anchor_to_symbol.end())) { - GELOGD("Node %s output:%d is ref form refdata: %s.", node->GetName().c_str(), out_data_anchor->GetIdx(), + GELOGD("Node %s output:%d is ref form node: %s.", node->GetName().c_str(), out_data_anchor->GetIdx(), exist_ref_data_info.ToString().c_str()); GE_ASSERT_GRAPH_SUCCESS( UpdateRefMapping(cur_node_info, exist_ref_data_info, symbol_to_anchors, anchor_to_symbol)); @@ -3086,30 +3148,16 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t return IsNoPaddingRefFromInput(out_data_anchor, reuse_in_index); } -bool GraphUtils::IsRefFromRefData(const OutDataAnchorPtr &out_data_anchor, NodePtr &ref_data) { +graphStatus GraphUtils::CheckIsRefFromOther(const OutDataAnchorPtr &out_data_anchor, NodePtr &ref_data, bool &is_ref_from_other) { GE_ASSERT_NOTNULL(out_data_anchor); const auto owner_node = out_data_anchor->GetOwnerNode(); GE_ASSERT_NOTNULL(owner_node); - const auto out_desc = owner_node->GetOpDesc()->GetOutputDescPtr(static_cast(out_data_anchor->GetIdx())); - GE_ASSERT_NOTNULL(out_desc); - std::string ref_var_src_var_name; - bool has_ref_attr = ge::AttrUtils::GetStr(out_desc, REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); - if (!has_ref_attr) { - return false; - } - // find src ref_data_node - const auto &ower_graph = owner_node->GetOwnerComputeGraph(); - GE_ASSERT_NOTNULL(ower_graph); - const auto ref_data_node = ower_graph->FindNode(ref_var_src_var_name); - if (ref_data_node == nullptr) { - GELOGW("Can not find refdata named %s. Please check ref relation on graph.", ref_var_src_var_name.c_str()); - return false; - } - if (ref_data_node->GetType() != REFDATA) { - return false; - } - ref_data = ref_data_node; - return true; + bool is_ref_from_refdata = false; + bool is_ref_from_innerdata = false; + GE_ASSERT_GRAPH_SUCCESS(CheckIsRefFromRefData(out_data_anchor, ref_data, is_ref_from_refdata)); + GE_ASSERT_GRAPH_SUCCESS(CheckIsRefFromInnerData(out_data_anchor, ref_data, is_ref_from_innerdata)); + is_ref_from_other = (is_ref_from_refdata || is_ref_from_innerdata); + return GRAPH_SUCCESS; } bool GraphUtils::IsNoPaddingRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index) { diff --git a/graph/utils/op_type_utils.cc b/graph/utils/op_type_utils.cc index b3e4e259563a683b7788c7f70a5b6ad8a0222b23..6a718075c477754e5737297a4f09d0cefb122195 100644 --- a/graph/utils/op_type_utils.cc +++ b/graph/utils/op_type_utils.cc @@ -92,4 +92,8 @@ graphStatus OpTypeUtils::GetOriginalType(const ge::OpDescPtr &op_desc, std::stri GELOGD("Get FrameWorkOp original type [%s]", type.c_str()); return GRAPH_SUCCESS; } + +bool OpTypeUtils::IsSubgraphInnerData(const ge::OpDescPtr &op_desc) { + return ((op_desc->GetType() == DATA) && op_desc->HasAttr(ATTR_NAME_PARENT_NODE_INDEX)); +} } // namespace ge diff --git a/inc/graph/utils/graph_utils.h b/inc/graph/utils/graph_utils.h index fd90789c611e0a9a581ba2edd8dd17bd860fdf39..1fdecfd8e27c8ba706874487afce4b5f36ff250e 100644 --- a/inc/graph/utils/graph_utils.h +++ b/inc/graph/utils/graph_utils.h @@ -637,9 +637,11 @@ class GraphUtils { * 判断当前`out_data_anchor`是否引用RefData的输出 * @param out_data_anchor * @param ref_data 复用的RefData节点 - * @return 如果存在复用关系,返回true, 否则返回false + * @param 如果存在复用关系,返回true, 否则返回false + * @return status */ - static bool IsRefFromRefData(const OutDataAnchorPtr &out_data_anchor, NodePtr &ref_data); + static graphStatus CheckIsRefFromOther(const OutDataAnchorPtr &out_data_anchor, NodePtr &ref_data, + bool &is_ref_from_other); /** * 针对含有`ATTR_NAME_NOPADDING_CONTINUOUS_INPUT`和`ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT`类型的节点 * 单独封装的复用接口 diff --git a/inc/graph/utils/op_type_utils.h b/inc/graph/utils/op_type_utils.h index fb2c03c289dc4b2fae8d9524ae6958d20f39fa18..8344d8ff514d5f2812ae03178d0627d14f4b2fd3 100644 --- a/inc/graph/utils/op_type_utils.h +++ b/inc/graph/utils/op_type_utils.h @@ -23,6 +23,7 @@ class OpTypeUtils { static bool IsIdentityLikeNode(const std::string &type); static bool IsConstPlaceHolderNode(const std::string &type); static graphStatus GetOriginalType(const ge::OpDescPtr &op_desc, std::string &type); + static bool IsSubgraphInnerData(const ge::OpDescPtr &op_desc); }; } // namespace ge #endif // __INC_METADEF_OP_TYPE_UTILS_H diff --git a/tests/ut/graph/testcase/graph_unittest.cc b/tests/ut/graph/testcase/graph_unittest.cc index 0052124069a878b6b5cf15dfcd506773cba15afd..cf6ca68f8db446c02292f12ed6a990b8e3af43f7 100644 --- a/tests/ut/graph/testcase/graph_unittest.cc +++ b/tests/ut/graph/testcase/graph_unittest.cc @@ -33,6 +33,8 @@ #include "graph/ge_attr_value.h" #include "ge_ir.pb.h" #include "inc/common/ge_common/ge_inner_error_codes.h" + +#include #undef private #undef protected @@ -277,7 +279,9 @@ TEST_F(UtestGraph, IsRefFromRefData_HasNoAttr_ReturnFalse) { ASSERT_NE(out_data_anchor, nullptr); NodePtr node = nullptr; - EXPECT_FALSE(GraphUtils::IsRefFromRefData(out_data_anchor, node)); + bool is_ref_from_other = true; + EXPECT_EQ(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); + EXPECT_FALSE(is_ref_from_other); } TEST_F(UtestGraph, IsRefFromRefData_VarNameNotExist_ReturnFalse) { @@ -293,7 +297,9 @@ TEST_F(UtestGraph, IsRefFromRefData_VarNameNotExist_ReturnFalse) { ASSERT_NE(out_data_anchor, nullptr); NodePtr node = nullptr; - EXPECT_FALSE(GraphUtils::IsRefFromRefData(out_data_anchor, node)); + bool is_ref_from_other = true; + EXPECT_EQ(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); + EXPECT_FALSE(is_ref_from_other); } TEST_F(UtestGraph, IsRefFromRefData_VarNameNodeIsNotRefData_ReturnFalse) { @@ -309,7 +315,9 @@ TEST_F(UtestGraph, IsRefFromRefData_VarNameNodeIsNotRefData_ReturnFalse) { ASSERT_NE(out_data_anchor, nullptr); NodePtr node = nullptr; - EXPECT_FALSE(GraphUtils::IsRefFromRefData(out_data_anchor, node)); + bool is_ref_from_other = true; + EXPECT_EQ(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); + EXPECT_FALSE(is_ref_from_other); } TEST_F(UtestGraph, IsRefFromRefData_ReturnTrue) { @@ -325,7 +333,101 @@ TEST_F(UtestGraph, IsRefFromRefData_ReturnTrue) { ASSERT_NE(out_data_anchor, nullptr); NodePtr node = nullptr; - EXPECT_TRUE(GraphUtils::IsRefFromRefData(out_data_anchor, node)); + bool is_ref_from_other = false; + EXPECT_EQ(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); + EXPECT_TRUE(is_ref_from_other); +} + +TEST_F(UtestGraph, RefDataInSubgraph_IsRefFromInnerData_ReturnTrue) { + ut::GraphBuilder builder = ut::GraphBuilder("graph"); + auto ref_data = builder.AddNode("ref_data", "RefData", 0, 1); + auto partitioned_call = builder.AddNode("partitionedcall", "PartitionedCall", 1, 1); + auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); + builder.AddDataEdge(ref_data, 0, partitioned_call, 0); + builder.AddDataEdge(partitioned_call, 0, netoutput, 0); + auto graph = builder.GetGraph(); + + ut::GraphBuilder sub_builder = ut::GraphBuilder("subgraph"); + auto sub_data = sub_builder.AddNode("sub_Data", "Data", 0, 1); + AttrUtils::SetInt(sub_data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); + auto sub_refdata = sub_builder.AddNode("sub_RefData", "RefData", 0, 1); + auto sub_netoutput = sub_builder.AddNode("sub_Netoutput", "NetOutput", 1, 0); + builder.AddControlEdge(sub_data, sub_refdata); + builder.AddDataEdge(sub_refdata, 0, sub_netoutput, 0); + auto sub_graph = sub_builder.GetGraph(); + + sub_graph->SetParentGraph(graph); + sub_graph->SetParentNode(partitioned_call); + graph->AddSubgraph("subgraph", sub_graph); + + auto out_data_anchor = sub_refdata->GetOutDataAnchor(0); + ASSERT_NE(out_data_anchor, nullptr); + + NodePtr node = nullptr; + bool is_ref_from_other = false; + EXPECT_EQ(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); + EXPECT_TRUE(is_ref_from_other); +} + +TEST_F(UtestGraph, RefDataInSubgraph_IsRefFromInnerData_PeerInCtrolNotData_InvalidGraph_ReturnFalse) { + ut::GraphBuilder builder = ut::GraphBuilder("graph"); + auto ref_data = builder.AddNode("ref_data", "RefData", 0, 1); + auto partitioned_call = builder.AddNode("partitionedcall", "PartitionedCall", 1, 1); + auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); + builder.AddDataEdge(ref_data, 0, partitioned_call, 0); + builder.AddDataEdge(partitioned_call, 0, netoutput, 0); + auto graph = builder.GetGraph(); + + ut::GraphBuilder sub_builder = ut::GraphBuilder("subgraph"); + auto sub_cast = sub_builder.AddNode("sub_Data", "Cast", 0, 1); + auto sub_refdata = sub_builder.AddNode("sub_RefData", "RefData", 0, 1); + auto sub_netoutput = sub_builder.AddNode("sub_Netoutput", "NetOutput", 1, 0); + builder.AddControlEdge(sub_cast, sub_refdata); + builder.AddDataEdge(sub_refdata, 0, sub_netoutput, 0); + auto sub_graph = sub_builder.GetGraph(); + + sub_graph->SetParentGraph(graph); + sub_graph->SetParentNode(partitioned_call); + graph->AddSubgraph("subgraph", sub_graph); + + auto out_data_anchor = sub_refdata->GetOutDataAnchor(0); + ASSERT_NE(out_data_anchor, nullptr); + + NodePtr node = nullptr; + bool is_ref_from_other = false; + EXPECT_NE(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); +} + +TEST_F(UtestGraph, RefDataInSubgraph_IsRefFromInnerData_MultiPeerInCtrl_InvalidGraph_ReturnFalse) { + ut::GraphBuilder builder = ut::GraphBuilder("graph"); + auto ref_data = builder.AddNode("ref_data", "RefData", 0, 1); + auto partitioned_call = builder.AddNode("partitionedcall", "PartitionedCall", 1, 1); + auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); + builder.AddDataEdge(ref_data, 0, partitioned_call, 0); + builder.AddDataEdge(partitioned_call, 0, netoutput, 0); + auto graph = builder.GetGraph(); + + ut::GraphBuilder sub_builder = ut::GraphBuilder("subgraph"); + auto sub_data = sub_builder.AddNode("sub_Data", "Data", 0, 1); + AttrUtils::SetInt(sub_data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); + auto sub_cast = sub_builder.AddNode("sub_cast", "Cast", 0, 1); + auto sub_refdata = sub_builder.AddNode("sub_RefData", "RefData", 0, 1); + auto sub_netoutput = sub_builder.AddNode("sub_Netoutput", "NetOutput", 1, 0); + builder.AddControlEdge(sub_cast, sub_refdata); + builder.AddControlEdge(sub_data, sub_refdata); + builder.AddDataEdge(sub_refdata, 0, sub_netoutput, 0); + auto sub_graph = sub_builder.GetGraph(); + + sub_graph->SetParentGraph(graph); + sub_graph->SetParentNode(partitioned_call); + graph->AddSubgraph("subgraph", sub_graph); + + auto out_data_anchor = sub_refdata->GetOutDataAnchor(0); + ASSERT_NE(out_data_anchor, nullptr); + + NodePtr node = nullptr; + bool is_ref_from_other = false; + EXPECT_NE(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); } REG_OP(Shape)