From 038966898c0f7d0ccf0351b26d010d1e1a16cda8 Mon Sep 17 00:00:00 2001 From: VectorSL <864733542@qq.com> Date: Tue, 6 Aug 2024 19:52:32 +0800 Subject: [PATCH] support swith inline cache --- .../common/session/kernel_graph_mgr.cc | 110 +++++++++++++----- mindspore/ccsrc/include/common/utils/utils.h | 2 + 2 files changed, 82 insertions(+), 30 deletions(-) diff --git a/mindspore/ccsrc/backend/common/session/kernel_graph_mgr.cc b/mindspore/ccsrc/backend/common/session/kernel_graph_mgr.cc index 5400d512763..a6681e23acb 100644 --- a/mindspore/ccsrc/backend/common/session/kernel_graph_mgr.cc +++ b/mindspore/ccsrc/backend/common/session/kernel_graph_mgr.cc @@ -343,6 +343,18 @@ nlohmann::json SaveAnfToAnfMap(const HashMap &save_map) return ret; } +nlohmann::json SaveAnfToStringMap(const HashMap &save_map) { + nlohmann::json ret; + for (const auto &i : save_map) { + const auto &first_name = GetAnfUniqueCacheName(i.first, false); + if (first_name.empty()) { + continue; + } + ret[first_name] = i.second; + } + return ret; +} + std::vector SaveAnfToAnfIndexMap(const HashMap &save_map) { std::vector ret_json; for (const auto &i : save_map) { @@ -607,7 +619,11 @@ void SaveNodesKernelInfoAndParamsName(const KernelGraphPtr &kg, const std::vecto param_unique_name_to_name[name] = param->name(); } if (node->kernel_info() == nullptr) { - MS_LOG(WARNING) << "The node " << node->DebugString() << " has not kernel_info."; + MS_LOG(INFO) << "The node " << node->DebugString() << " has not kernel_info."; + continue; + } + if (dynamic_cast(node->kernel_info())->select_kernel_build_info() == nullptr) { + MS_LOG(INFO) << "The node " << node->DebugString() << " has kernel_info but build info is null."; continue; } const auto &kernel_info_json = SaveAnfKernelInfo(node); @@ -734,6 +750,14 @@ nlohmann::json GenKernelGraphJson(const KernelGraphPtr &kg, const std::vectorinline_sub_graph_kernels()); + if (!inline_sub_graph_kernels_json.empty()) { + kg_json[kInlineSubGraphKernelsMap] = inline_sub_graph_kernels_json; + } + const auto &condition_gather_to_switch_json = SaveAnfToAnfMap(kg->condition_gather_to_switch()); + if (!condition_gather_to_switch_json.empty()) { + kg_json[kConditionGatherToSwitchMap] = condition_gather_to_switch_json; + } return kg_json; } @@ -2268,6 +2292,59 @@ void HandleAttrAboutOtherGraph(const mindspore::HashMapAddOutFrontPairs(AnfWithOutIndex(first_node, first_index), AnfWithOutIndex(second_node, second_index)); + } + } + if (graph_json.contains(kFrontNodeToGraphOutputMap)) { + const auto &front_out_map_json = graph_json[kFrontNodeToGraphOutputMap]; + for (const auto &iter : front_out_map_json) { + const auto &first_name = iter.at(0); + const auto &first_index = iter.at(kIndexOne); + const auto &second_name = iter.at(kIndexTwo); + const auto &second_index = iter.at(kIndexThree); + auto first_node = context.FindFrontNodeByFrontName(first_name); + MS_EXCEPTION_IF_NULL(first_node); + auto second_node = context.FindBackNodeByBackName(second_name); + MS_EXCEPTION_IF_NULL(second_node); + graph->AddFrontOutPairs(AnfWithOutIndex(first_node, first_index), AnfWithOutIndex(second_node, second_index)); + } + } + if (graph_json.contains(kConditionGatherToSwitchMap)) { + const auto &condition_gather_switch_json = graph_json[kConditionGatherToSwitchMap]; + for (const auto &[gather_name, switch_name] : condition_gather_switch_json.items()) { + auto gather_node = context.FindBackNodeByBackName(gather_name); + auto switch_node = context.FindBackNodeByBackName(switch_name); + if (!gather_node || !switch_node) { + MS_LOG(EXCEPTION) << "The backend node of " << gather_name << " or " << switch_name << " is nullptr."; + } + graph->AddConditionGatherSwitchPair(gather_node, switch_node); + } + } + if (graph_json.contains(kInlineSubGraphKernelsMap)) { + const auto &inline_subgraph_kernels_json = graph_json[kInlineSubGraphKernelsMap]; + for (const auto &[node_name, graph_name] : inline_subgraph_kernels_json.items()) { + auto node = context.FindBackNodeByBackName(node_name); + if (!node) { + MS_LOG(EXCEPTION) << "The backend node of " << node_name << " is nullptr."; + } + graph->AddInlineSubgraphKernel(node, graph_name); + } + } +} + void HandleGraphComplexAttr(const mindspore::HashMap> &graphs, const nlohmann::json &graph_json, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); @@ -2318,34 +2395,6 @@ void HandleGraphComplexAttr(const mindspore::HashMapAddOutFrontPairs(AnfWithOutIndex(first_node, first_index), AnfWithOutIndex(second_node, second_index)); - } - } - if (graph_json.contains(kFrontNodeToGraphOutputMap)) { - const auto &front_out_map_json = graph_json[kFrontNodeToGraphOutputMap]; - for (const auto &iter : front_out_map_json) { - const auto &first_name = iter.at(0); - const auto &first_index = iter.at(kIndexOne); - const auto &second_name = iter.at(kIndexTwo); - const auto &second_index = iter.at(kIndexThree); - auto first_node = context.FindFrontNodeByFrontName(first_name); - MS_EXCEPTION_IF_NULL(first_node); - auto second_node = context.FindBackNodeByBackName(second_name); - MS_EXCEPTION_IF_NULL(second_node); - graph->AddFrontOutPairs(AnfWithOutIndex(first_node, first_index), AnfWithOutIndex(second_node, second_index)); - } - } if (graph_json.contains(kNodesKernelInfo)) { const auto &kernel_infos_json = graph_json[kNodesKernelInfo]; LoadAnfKernelInfoFromJson(kernel_infos_json); @@ -2375,6 +2424,7 @@ void HandleGraphComplexAttr(const mindspore::HashMapset_summary_nodes(summary_nodes); } #endif + HandleSwitchInlineMaps(graph_json, graph); MS_LOG(INFO) << "Handle graph " << graph->ToString() << " complex attr success."; } @@ -2869,7 +2919,7 @@ void CopyCNodeInfo(const FuncGraphPtr &func_graph, const uint32_t &target_graph_ common::AnfAlgo::SetNodeAttr(kAttrOriFusionName, MakeValue(ori_full_name), new_node); } common::AnfAlgo::SetNodeAttr(kAttrNeedInline, MakeValue(ori_node->fullname_with_scope()), new_node); - common::AnfAlgo::SetNodeAttr(kAttrPreKernelGraph, MakeValue(func_graph), new_node); + common::AnfAlgo::SetNodeAttr(kAttrPreKernelGraph, MakeValue(func_graph->ToString()), new_node); } } diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h index bf8e9add711..1b6f79a591c 100644 --- a/mindspore/ccsrc/include/common/utils/utils.h +++ b/mindspore/ccsrc/include/common/utils/utils.h @@ -703,6 +703,8 @@ constexpr auto kIsRefGraph = "is_ref_graph"; constexpr auto kFromRefGraph = "from_ref_graph"; constexpr auto kGraphOutputToFrontNodeMap = "graph_output_to_front_node_map"; constexpr auto kFrontNodeToGraphOutputMap = "front_node_to_graph_output_map"; +constexpr auto kInlineSubGraphKernelsMap = "inline_sub_graph_kernels"; +constexpr auto kConditionGatherToSwitchMap = "condition_gather_to_switch"; // recompute and parallel constexpr auto kRecomputeInsert = "recompute_insert"; -- Gitee