From 055075b423620dcc55e0be2aec3472c726ca0ac9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E4=BA=91=E9=BE=99?= Date: Sat, 25 Dec 2021 11:09:11 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=B4=E6=94=B9=E9=97=AE=E9=A2=98=EF=BC=8C?= =?UTF-8?q?=E6=96=B0=E5=A2=9EDynamicMul=20=E7=9A=84mapping=20=E6=96=B9?= =?UTF-8?q?=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- inc/external/register/register.h | 3 + register/register.cpp | 107 +++++++++++++++++++++++++++++-- 2 files changed, 106 insertions(+), 4 deletions(-) diff --git a/inc/external/register/register.h b/inc/external/register/register.h index d361e3948c..52b1084d3f 100644 --- a/inc/external/register/register.h +++ b/inc/external/register/register.h @@ -86,6 +86,9 @@ ATTRIBUTED_DEPRECATED(Status AutoMappingByOpFnDynamic(const ge::Operator &, ge:: Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, std::map> dynamic_name_attr_value, int32_t in_pos = -1, int32_t out_pos = -1); +Status AutoMappingFnDynamicMul(const google::protobuf::Message *op_src, ge::Operator &op, + std::map> dynamic_name_attr_value, + int32_t in_pos = -1, int32_t out_pos = -1); Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function &input, const std::function &output); diff --git a/register/register.cpp b/register/register.cpp index 11ea15eb38..1b585dc424 100644 --- a/register/register.cpp +++ b/register/register.cpp @@ -16,12 +16,12 @@ #include "external/register/register.h" #include -#include "graph/debug/ge_util.h" -#include "graph/debug/ge_op_types.h" #include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_log.h" +#include "graph/debug/ge_op_types.h" #include "graph/debug/ge_util.h" -#include "graph/debug/ge_attr_define.h" +#include "graph/graph.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" @@ -30,7 +30,6 @@ #include "register/auto_mapping_util.h" #include "register/op_registry.h" #include "register/register_utils.h" -#include "graph/graph.h" namespace domi { using namespace domi::tensorflow; @@ -359,6 +358,106 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status AutoMappingFnDynamic( return SUCCESS; } +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status AutoMappingFnDynamicMul( + const google::protobuf::Message *op_src, ge::Operator &op, + std::map> dynamic_name_attr_value, + int32_t in_pos, int32_t out_pos) { + // 1. automapping for parser + const std::shared_ptr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); + GE_CHECK_NOTNULL(op_desc); + GE_CHECK_NOTNULL(op_src); + const Status ret = AutoMappingFn(op_src, op); + if (ret != SUCCESS) { + GE_LOGE("Op: %s call auto mapping function failed.", op_desc->GetName().c_str()); + return FAILED; + } + + GELOGI("op[%s] call auto mapping function success.", op_desc->GetName().c_str()); + + // add dynamic input and output + const NodeDef *const node = reinterpret_cast(op_src); + int32_t in_mul = 1; + int32_t out_mul = 1; + + for (const auto &it : dynamic_name_attr_value) { + const std::string flag = it.first; + const std::pair name_value = it.second; + const std::string attr_name = name_value.second; + + tensorflow::AttrValue attr_num; + int32_t dynamic_tensor_num = 0; + if (!(ge::AutoMappingUtil::FindAttrValue(node, attr_name, attr_num))) { + GELOGW("[AutoMappingFn][GetAttr] Dynamic attr %s in node %s not exist.", attr_name.c_str(), node->name().c_str()); + } + + if (attr_num.has_list()) { + dynamic_tensor_num = attr_num.list().type_size(); + } else { + dynamic_tensor_num = static_cast(attr_num.i()); + } + + if (dynamic_tensor_num <= 0) { + GELOGW("[AutoMappingFn][Check] Dynamic num %d in node %s is less than 0.", dynamic_tensor_num, + node->name().c_str()); + continue; + } + + GELOGI("In NodeDef %s dynamic attr [%s] is exist: %d.", node->name().c_str(), attr_name.c_str(), + dynamic_tensor_num); + + if (flag == "in_mul") { + in_mul = dynamic_tensor_num; + GELOGI("In NodeDef %s add dynamic input multiple[%d]", node->name().c_str(), dynamic_tensor_num); + } else if (flag == "out_mul") { + out_mul = dynamic_tensor_num; + GELOGI("In NodeDef %s add dynamic output multiple[%d]", node->name().c_str(), dynamic_tensor_num); + } + } + + for (const auto &it : dynamic_name_attr_value) { + const std::string flag = it.first; + const std::pair name_value = it.second; + const std::string dynamic_name = name_value.first; + const std::string attr_name = name_value.second; + + tensorflow::AttrValue attr_num; + int32_t dynamic_tensor_num = 0; + if (!(ge::AutoMappingUtil::FindAttrValue(node, attr_name, attr_num))) { + GELOGW("[AutoMappingFn][GetAttr] Dynamic attr %s in node %s not exist.", attr_name.c_str(), node->name().c_str()); + } + + if (attr_num.has_list()) { + dynamic_tensor_num = attr_num.list().type_size(); + } else { + dynamic_tensor_num = static_cast(attr_num.i()); + } + + if (dynamic_tensor_num <= 0) { + GELOGW("[AutoMappingFn][Check] Dynamic num %d in node %s is less than 0.", dynamic_tensor_num, + node->name().c_str()); + continue; + } + + GELOGI("In NodeDef %s dynamic attr [%s] is exist: %d.", node->name().c_str(), attr_name.c_str(), + dynamic_tensor_num); + + if (flag == "in") { + const bool is_pushback = (in_pos == -1); + dynamic_tensor_num *= in_mul; + (void)op_desc->AddDynamicInputDesc(dynamic_name, static_cast(dynamic_tensor_num), is_pushback); + (void)ge::AttrUtils::SetInt(op_desc, DYNAMIC_INPUT_TD_NUM(dynamic_name), dynamic_tensor_num); + GELOGI("In NodeDef %s add dynamic input[%d]", node->name().c_str(), dynamic_tensor_num); + } else if (flag == "out") { + const bool is_pushback = (out_pos == -1); + dynamic_tensor_num *= out_mul; + (void)op_desc->AddDynamicOutputDesc(dynamic_name, static_cast(dynamic_tensor_num), is_pushback); + (void)ge::AttrUtils::SetInt(op_desc, DYNAMIC_OUTPUT_TD_NUM(dynamic_name), dynamic_tensor_num); + GELOGI("In NodeDef %s add dynamic output[%d]", node->name().c_str(), dynamic_tensor_num); + } + } + return SUCCESS; +} + FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, ge::Operator &op, const vector &dynamic_name_attr_value) { // 1. auto mapping for parser -- Gitee