diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index 26df6185d5c9d37cf0e7666e2f613002c245ea3f..2541c065f1e1ee8b4475f9666ffef81748081eeb 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -279,17 +279,14 @@ void Executor::ClearDoneTasks() { void Executor::RunTask(const std::shared_ptr &task, bool sync, bool long_run) { { - std::lock_guard lock(task_mutex_); - ready_tasks_.push(task); - } - sync_run_task_finished_ = false; - task_cond_var_.notify_all(); - if (sync && !sync_run_task_finished_) { std::unique_lock lock(task_mutex_); - if (long_run) { + ready_tasks_.push(task); + sync_run_task_finished_ = false; + task_cond_var_.notify_all(); + if (sync && long_run) { mindspore::ScopedLongRunning long_running; sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; }); - } else { + } else if (sync) { sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; }); } } diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index b68e7c267d2700241dbc140d6832ba8d2eb67c3e..b5975efc0b1945f5ec0374fec0fd264c0d8c3d36 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -729,7 +729,7 @@ AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) { auto value_node = node->cast(); MS_EXCEPTION_IF_NULL(value_node); auto make_tuple = TransValueNodeTuple(value_node->abstract(), value_node->value()); - if (RemoveValueNodeFromGraph(value_node)) { + if (!RemoveValueNodeFromGraph(value_node)) { MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString(); } return make_tuple; diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index ca825519f97c83d4a782699cf6bba3fab42c72d0..b0d20af99c3b1fa5958a911b21c244ce64cf71e0 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -358,7 +358,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor auto func_graph = GetValueNode(input_fg); MS_EXCEPTION_IF_NULL(func_graph); auto manager = Manage({fg, func_graph}, false); - auto need_replace_forward = pynative::PynativeExecutor::GetInstance()->need_replace_forward(); + auto need_replace_forward = true; auto forward_value = GenNewTensor(manager, equivdout, forward, need_replace_forward); if (!need_replace_forward) { cnode_morph->clear_inputs_value(); @@ -818,7 +818,8 @@ static std::pair FindPrimalJPair(const FuncGraphManagerPtr & auto &node_user_map = manager->node_users(); // Search primal graph user cnodes. for (auto &entry : primal_graph->func_graph_cnodes_index()) { - auto cnode = entry.first->first->cast(); + auto anfnode = entry.first->first.lock(); + auto cnode = anfnode->cast(); auto index = entry.first->second; if (index == 0) { // To find real calling. 
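Note on the executor.cc hunk above: the rewrite moves the queue push, the reset of sync_run_task_finished_, the notify_all and the conditional wait under a single std::unique_lock, so the worker can no longer finish and signal completion before the caller starts waiting. Below is a minimal standalone sketch of that pattern, assuming a trivial Task type and a hypothetical MarkFinished worker-side helper (both illustrative, not MindSpore code); the ScopedLongRunning branch is elided.

#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>

struct Task {};  // stand-in for the real task type

class Runner {
 public:
  void RunTask(const std::shared_ptr<Task> &task, bool sync) {
    std::unique_lock<std::mutex> lock(task_mutex_);
    ready_tasks_.push(task);
    sync_run_task_finished_ = false;
    task_cond_var_.notify_all();  // wake the worker that pops ready_tasks_
    if (sync) {
      // The predicate is re-checked after every wakeup, so spurious wakeups
      // and early completions are both handled while the mutex is held.
      sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
    }
  }

  // Hypothetical worker-side hook: mark the synchronous task as done.
  void MarkFinished() {
    {
      std::lock_guard<std::mutex> lock(task_mutex_);
      sync_run_task_finished_ = true;
    }
    sync_cond_var_.notify_all();
  }

 private:
  std::mutex task_mutex_;
  std::condition_variable task_cond_var_;
  std::condition_variable sync_cond_var_;
  std::queue<std::shared_ptr<Task>> ready_tasks_;
  bool sync_run_task_finished_{false};
};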
diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index 47497b7bea9aef8dd21d2bbeb2078b91e96f22e2..51e838db5b598f5a0c6d2ac3abb0f47bb3e81265 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -138,6 +138,8 @@ class KPrim { FuncGraphPtr KPrimitive(const CNodePtr &primal_user, const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); + FuncGraphPtr KPrimitiveForPrimBpOpt(const CNodePtr &primal_user, const ValueNodePtr &value_node, + const pipeline::ResourceBasePtr &resources); MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim); // bprop_fg and primal_fg in bprop_fg's transforms are FuncGraph just after convert. // current_primal_fg is the specialized and AutoMonaded primal_fg. @@ -147,6 +149,7 @@ class KPrim { bprop_registry_meta_.clear(); bprop_registry_.clear(); } + FuncGraphPtr GetPossibleBprop(const PrimitivePtr &prim); private: FuncGraphPtr GetBprop(const PrimitivePtr &prim); diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index 47713ff6dece0d44a276b673b707dd36d3236d9d..0c35874b3777041d31e0d6dc51ef6e1addc05b90 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -37,7 +37,6 @@ namespace mindspore { namespace ad { -using PatternListType = std::initializer_list; KPrim g_k_prims; FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { @@ -75,6 +74,23 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { return func_graph; } +FuncGraphPtr KPrim::GetPossibleBprop(const PrimitivePtr &prim) { + FuncGraphPtr bprop_fg = nullptr; + auto iter = bprop_registry_.find(prim); + if (iter != bprop_registry_.end()) { + bprop_fg = iter->second; + } + + if (bprop_fg == nullptr) { + bprop_fg = GetBprop(prim); + if (bprop_fg != nullptr) { + // Set bprop_g graph cache + bprop_registry_[prim] = bprop_fg; + } + } + return bprop_fg; +} + FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) { static const std::string ad_module = "mindspore.ops._grad.grad_implementations"; std::string func_name = "_fprop_" + prim->name(); @@ -189,6 +205,56 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_ return expanded_fg; } +FuncGraphPtr KPrim::KPrimitiveForPrimBpOpt(const CNodePtr &cnode, const ValueNodePtr &value_node, + const pipeline::ResourceBasePtr &resources) { + if (!IsValueNode(value_node)) { + MS_LOG(EXCEPTION) << "Primitive node is not valid."; + } + + auto prim = GetValueNode(value_node); + if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) { + auto fprop = GetFprop(prim); + fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer)); + return fprop; + } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { + return nullptr; + } else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) { + return nullptr; + } + + FuncGraphPtr bprop_fg = nullptr; + if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) { + if (MsContext::GetInstance()->get_param(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) { + MS_LOG(EXCEPTION) << "HookBackward is not supported in graph mode."; + } + bprop_fg = BpropCut(value_node, resources); + } else { + auto iter = bprop_registry_.find(prim); + if (iter != 
bprop_registry_.end()) { + bprop_fg = iter->second; + } + + if (bprop_fg == nullptr) { + bprop_fg = GetBprop(prim); + if (bprop_fg != nullptr) { + // Set bprop_g graph cache + bprop_registry_[prim] = bprop_fg; + } else { + bprop_fg = FakeBprop(value_node, resources); + } + } + } + AdjustForAutoMonad(prim, bprop_fg); + auto cloned_bprop_fg = BasicClone(bprop_fg); + MS_EXCEPTION_IF_NULL(cloned_bprop_fg); + auto debug_info = std::make_shared(); + debug_info->set_name(prim->ToString()); + cloned_bprop_fg->debug_info()->set_name(prim->ToString()); + cloned_bprop_fg->debug_info()->set_trace_info(std::make_shared(debug_info)); + + return cloned_bprop_fg; +} + AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg) { // current_primal_fg may have extra parameters like u_monad, io_monad std::vector extra_args; diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc b/mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc new file mode 100644 index 0000000000000000000000000000000000000000..f2bc8d6acab7653577a1f3b3ebf6f0076a17d60a --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc @@ -0,0 +1,854 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include "ir/anf.h" +#include "pipeline/jit/prim_bprop_optimizer.h" +#include "frontend/optimizer/ad/adjoint.h" +#include "frontend/optimizer/ad/dfunctor.h" +#include "frontend/optimizer/ad/kpynative.h" +#include "frontend/operator/ops.h" +#include "utils/info.h" +#include "debug/anf_ir_dump.h" +#include "debug/trace.h" + +namespace mindspore { +namespace ad { +extern KPrim g_k_prims; + +namespace { +FuncGraphPtr ZerosLikePrimOptPass(const pipeline::ResourcePtr &res) { + static const opt::irpass::OptimizeIRPassLib irpass; + opt::OptPassConfig eliminate_zeros_like_prim_pass = opt::OptPassConfig({ + irpass.zero_like_fill_zero_, + }); + + opt::OptPassGroupMap map({{"eliminate_zeros_like_prim_", eliminate_zeros_like_prim_pass}}); + + auto eliminate_zeros_like_prim = opt::Optimizer::MakeOptimizer("eliminate_zeros_like_prim", res, map); + FuncGraphPtr func_graph = res->func_graph(); + WITH(MsProfile::GetProfile()->Step("eliminate_zeros_like_prim"))[&eliminate_zeros_like_prim, &func_graph]() { + func_graph = eliminate_zeros_like_prim->step(func_graph, true); + }; + return func_graph; +} + +FuncGraphPtr GetZerosLike(const abstract::AbstractBasePtrList &args_spec) { + static ValuePtr zeros_like_ops = prim::GetPythonOps("zeros_like"); + static std::unordered_map + zeros_like_funcgraph_cache; + auto iter = zeros_like_funcgraph_cache.find(args_spec); + if (iter != zeros_like_funcgraph_cache.end()) { + MS_LOG(DEBUG) << "Cache hit for zeros_like: " << mindspore::ToString(args_spec); + return BasicClone(iter->second); + } + if (!zeros_like_ops->isa()) { + MS_LOG(EXCEPTION) << "zeros_like is not a MetaFuncGraph"; + } + auto zeros_like = zeros_like_ops->cast(); + auto zeros_like_fg = zeros_like->GenerateFuncGraph(args_spec); + MS_EXCEPTION_IF_NULL(zeros_like_fg); + pipeline::ResourcePtr resource = std::make_shared(); + auto specialized_zeros_like_fg = pipeline::Renormalize(resource, zeros_like_fg, args_spec); + MS_EXCEPTION_IF_NULL(specialized_zeros_like_fg); + auto opted_zeros_like_fg = ZerosLikePrimOptPass(resource); + MS_EXCEPTION_IF_NULL(opted_zeros_like_fg); + zeros_like_funcgraph_cache[args_spec] = opted_zeros_like_fg; + return BasicClone(opted_zeros_like_fg); +} + +FuncGraphPtr GetHyperAdd(const abstract::AbstractBasePtrList &args_spec) { + static ValuePtr add_ops = prim::GetPythonOps("hyper_add"); + static std::unordered_map + add_backward_funcgraph_cache; + auto iter = add_backward_funcgraph_cache.find(args_spec); + if (iter != add_backward_funcgraph_cache.end()) { + MS_LOG(DEBUG) << "Cache hit for hyper_add: " << mindspore::ToString(args_spec); + return BasicClone(iter->second); + } + if (!add_ops->isa()) { + MS_LOG(EXCEPTION) << "add is not a MetaFuncGraph"; + } + auto add = add_ops->cast(); + auto add_fg = add->GenerateFuncGraph(args_spec); + MS_EXCEPTION_IF_NULL(add_fg); + pipeline::ResourcePtr resource = std::make_shared(); + auto specialized_add_fg = pipeline::Renormalize(resource, add_fg, args_spec); + MS_EXCEPTION_IF_NULL(specialized_add_fg); + add_backward_funcgraph_cache[args_spec] = specialized_add_fg; + return BasicClone(specialized_add_fg); +} + +AnfNodePtr BuildZerosLikeNode(const FuncGraphPtr &tape, const AnfNodePtr &node) { + // Build zeros_like(node) as dout + abstract::AbstractBasePtrList args_spec{node->abstract()->Broaden()}; + auto zeros_like_fg = GetZerosLike(args_spec); + auto zeros_like_node = tape->NewCNode({NewValueNode(zeros_like_fg), node}); + zeros_like_node->set_abstract(zeros_like_fg->output()->abstract()); + 
return zeros_like_node; +} + +AnfNodePtr BuildZerosLikeValue(const FuncGraphPtr &tape, const ValuePtr &out) { + // Build zeros_like(out) as dout + abstract::AbstractBasePtrList args_spec{out->ToAbstract()->Broaden()}; + auto zeros_like_fg = GetZerosLike(args_spec); + auto zeros_like_value = tape->NewCNode({NewValueNode(zeros_like_fg), NewValueNode(out)}); + zeros_like_value->set_abstract(zeros_like_fg->output()->abstract()); + return zeros_like_value; +} + +FuncGraphPtr GetOnesLike(const abstract::AbstractBasePtrList &args_spec) { + static ValuePtr ones_like_ops = prim::GetPythonOps("ones_like"); + static std::unordered_map + ones_like_funcgraph_cache; + auto iter = ones_like_funcgraph_cache.find(args_spec); + if (iter != ones_like_funcgraph_cache.end()) { + MS_LOG(DEBUG) << "Cache hit for ones_like: " << mindspore::ToString(args_spec); + return BasicClone(iter->second); + } + if (!ones_like_ops->isa()) { + MS_LOG(EXCEPTION) << "ones_like is not a MetaFuncGraph"; + } + auto ones_like = ones_like_ops->cast(); + auto ones_like_fg = ones_like->GenerateFuncGraph(args_spec); + MS_EXCEPTION_IF_NULL(ones_like_fg); + pipeline::ResourcePtr resource = std::make_shared(); + auto specialized_ones_like_fg = pipeline::Renormalize(resource, ones_like_fg, args_spec); + MS_EXCEPTION_IF_NULL(specialized_ones_like_fg); + ones_like_funcgraph_cache[args_spec] = specialized_ones_like_fg; + return BasicClone(specialized_ones_like_fg); +} + +AnfNodePtr BuildOnesLikeValue(const FuncGraphPtr &tape, const ValuePtr &out) { + // Build ones_like(out) as dout + abstract::AbstractBasePtrList args_spec{out->ToAbstract()->Broaden()}; + auto ones_like_fg = GetOnesLike(args_spec); + auto ones_like_value = tape->NewCNode({NewValueNode(ones_like_fg), NewValueNode(out)}); + ones_like_value->set_abstract(ones_like_fg->output()->abstract()); + return ones_like_value; +} + +// This Faked BProp func_graph should not be present in the final top bprop func_graph. +// Its output is faked but not a tuple. 
+FuncGraphPtr BuildFakeBProp(const PrimitivePtr &prim, size_t inputs_num) { + auto func_graph = std::make_shared(); + std::vector outputs; + + auto fake_bprop = std::make_shared("fake_bprop"); + (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined.")); + outputs.push_back(NewValueNode(fake_bprop)); + outputs.push_back(NewValueNode(true)); + + for (size_t i = 0; i < inputs_num; ++i) { + // Mock params for inputs + auto param = func_graph->add_parameter(); + } + // mock params for out and dout + (void)func_graph->add_parameter(); + (void)func_graph->add_parameter(); + func_graph->set_output(func_graph->NewCNode(outputs)); + return func_graph; +} +} // namespace + +class PynativeAdjoint { + public: + PynativeAdjoint(const FuncGraphPtr &tape, const ValuePtrList &op_args, const ValuePtr &out, + const FuncGraphPtr &bprop_fg) + : tape_(tape), op_args_(op_args), out_(out), bprop_fg_(bprop_fg) {} + + AnfNodePtrList &users() { return users_; } + const ValuePtrList &op_args() { return op_args_; } + const ValuePtr &out() { return out_; } + const FuncGraphPtr &bprop_fg() { return bprop_fg_; } + AnfNodePtr RealDout() { + if (dout_ != nullptr) { + return dout_; + } + return BuildZerosLikeValue(tape_, out_); + } + + void AccumulateDout(const AnfNodePtr &dout_factor) { + if (dout_ != nullptr) { + MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString(); + auto arg = out_->ToAbstract()->Broaden(); + abstract::AbstractBasePtrList args_spec{arg, arg}; + auto add_fg = GetHyperAdd(args_spec); + MS_EXCEPTION_IF_NULL(add_fg); + dout_ = tape_->NewCNode({NewValueNode(add_fg), dout_, dout_factor}); + dout_->set_abstract(add_fg->output()->abstract()); + MS_LOG(DEBUG) << "New dout_ " << dout_->DebugString(); + return; + } + dout_ = dout_factor; + } + + private: + const FuncGraphPtr tape_; + AnfNodePtr dout_{nullptr}; + // Used by whoes + AnfNodePtrList users_; + // cache these arguments from ad caller. + const ValuePtrList op_args_; + // For CNode , it's output of cnode. For Parameter or ValueNode, it's its value. + const ValuePtr out_; + // bprop_fg passed from ad caller, it may be user defined back propagate funcgragh. + const FuncGraphPtr bprop_fg_; +}; +using PynativeAdjointPtr = std::shared_ptr; + +class KPynativeCellImpl : public KPynativeCell { + public: + explicit KPynativeCellImpl(const AnfNodePtrList &cell_inputs) : cell_inputs_(cell_inputs) { + tape_ = std::make_shared(); + tape_->debug_info()->set_name("grad_top"); + for (size_t i = 0; i < cell_inputs.size(); ++i) { + TraceGuard trace_guard(std::make_shared(cell_inputs[i]->debug_info())); + tape_->add_parameter(); + } + } + ~KPynativeCellImpl() override = default; + bool KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, + FuncGraphPtr bprop_fg = nullptr); + bool KPynativeWithBProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, + const FuncGraphPtr &bprop_fg); + FuncGraphPtr Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights, bool has_sens_arg); + + private: + FuncGraphPtr tape_; + OrderedMap anfnode_to_adjoin_; + AnfNodePtrList cell_inputs_; + // Last cnode of this Cell, may be a primitive op or cell with user defined bprop. + AnfNodePtr last_node_{nullptr}; + bool need_propagate_stop_gradient_{false}; + + // For CNode like TupleGetItem, ListGetItem, MakeTuple, MakeList, it's bypassed by caller so + // no KPynativeOp is called for these CNode. Here we forge Adjoint for these CNode. 
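+  // These forged adjoints rebuild op_args and out from the adjoints already recorded for the node's inputs,
+  // then go through KPynativeOp, so back propagation handles the forged node like any other recorded op.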
+ PynativeAdjointPtr ForgeCNodeAdjoint(const CNodePtr &cnode); + PynativeAdjointPtr ForgeGetItemAdjoint(const CNodePtr &cnode); + PynativeAdjointPtr ForgeMakeSequenceAdjoint(const CNodePtr &cnode); + bool BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, + const FuncGraphPtr &bprop_fg); + bool BuildHighOrderFuncGraphAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, + const FuncGraphPtr &bprop_fg); + void PropagateStopGradient(); + bool AllReferencesStopped(const CNodePtr &curr_cnode); + // Back propagate for all node; + bool BackPropagate(); + bool BackPropagate(const CNodePtr &cnode_primal, const CNodePtr &bprop_app); + FuncGraphPtr BuildBpropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode); + // Back propagate for MakeList or MakeTuple is generated from MetaFuncGraph. + FuncGraphPtr BuildMakeSequenceBprop(const PrimitivePtr &prim, const CNodePtr &cnode); + // Set return node according to grad flag + void SetReturnNodeByGradFlag(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights); +}; +using KPynativeCellImplPtr = std::shared_ptr; + +KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs) { + auto abstract_are_set = std::all_of(cell_inputs.cbegin(), cell_inputs.cend(), + [](const AnfNodePtr &node) { return node->abstract() != nullptr; }); + if (!abstract_are_set) { + MS_LOG(EXCEPTION) << "Not all abstract_value in cell_inputs are set"; + } + return std::make_shared(cell_inputs); +} + +FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs, + bool grad_weights, bool has_sens_arg) { + auto k_cell_impl = std::dynamic_pointer_cast(k_cell); + return k_cell_impl->Finish(weights, grad_inputs, grad_weights, has_sens_arg); +} + +FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights, + bool has_sens_arg) { + // propagate stop_gradient flag to cnode before back propagate; + PropagateStopGradient(); + + auto last_node_adjoint_iter = anfnode_to_adjoin_.find(last_node_); + if (last_node_adjoint_iter == anfnode_to_adjoin_.end()) { + MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->ToString(); + } + if (has_sens_arg) { + // sens parameter; + auto sens_param = tape_->add_parameter(); + sens_param->debug_info()->set_name("sens"); + // Set dout of last node to sens; + last_node_adjoint_iter->second->AccumulateDout(sens_param); + } else { + auto sens_node = BuildOnesLikeValue(tape_, last_node_adjoint_iter->second->out()); + last_node_adjoint_iter->second->AccumulateDout(sens_node); + } + // Add weights parameter + for (const auto &weight : weights) { + TraceGuard trace_guard(std::make_shared(weight->debug_info())); + auto p = tape_->add_parameter(); + auto input_w = weight->cast(); + MS_EXCEPTION_IF_NULL(input_w); + p->set_default_param(input_w->default_param()); + } + + // BackPropagate sensitivity; + BackPropagate(); + // Return the gradient; + SetReturnNodeByGradFlag(weights, grad_inputs, grad_weights); + // Replace AnfNode with parameter of tape_; + auto mng = MakeManager({tape_}, false); + auto tr = mng->Transact(); + const auto ¶meters = tape_->parameters(); + auto cell_inputs_size = cell_inputs_.size(); + for (size_t i = 0; i < cell_inputs_size; ++i) { + tr.Replace(cell_inputs_[i], parameters[i]); + } + // (Inputs, sens, weights) or (Inputs, weights) + size_t weight_offset = cell_inputs_size; + if (has_sens_arg) { + weight_offset = 
weight_offset + 1; + } + for (size_t i = 0; i < weights.size(); ++i) { + tr.Replace(weights[i], parameters[weight_offset + i]); + } + tr.Commit(); + + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { + DumpIR("before_final_opt.ir", tape_); + } + return tape_; +} + +bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args, + const ValuePtr &out, FuncGraphPtr bprop_fg) { + auto k_cell_impl = std::dynamic_pointer_cast(k_cell); + return k_cell_impl->KPynativeOp(cnode, op_args, out, bprop_fg); +} + +bool KPynativeCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, + FuncGraphPtr bprop_fg) { + MS_EXCEPTION_IF_NULL(cnode); + auto prim = GetCNodePrimitive(cnode); + if (prim == nullptr) { + if (IsValueNode(cnode->input(0))) { + MS_EXCEPTION_IF_NULL(bprop_fg); + MS_LOG(DEBUG) << "Do nested adjont"; + BuildHighOrderFuncGraphAdjoint(cnode, op_args, out, bprop_fg); + return true; + } else { + MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString(); + } + } + if (IsPrimitiveEquals(prim, prim::kPrimStopGradient) || IsPrimitiveEquals(prim, prim::kPrimUpdateState)) { + need_propagate_stop_gradient_ = true; + } + + if (IsPrimitiveEquals(prim, prim::kPrimHookBackward)) { + bprop_fg = BuildBpropCutFuncGraph(prim, cnode); + } else if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) { + bprop_fg = BuildMakeSequenceBprop(prim, cnode); + } else { + bprop_fg = g_k_prims.GetPossibleBprop(prim); + if (bprop_fg == nullptr) { + MS_LOG(DEBUG) << "Cannot find defined bprop for cnode prim: " << cnode->DebugString(); + bprop_fg = BuildFakeBProp(prim, cnode->size() - 1); + } + } + MS_EXCEPTION_IF_NULL(bprop_fg); + BuildAdjoint(cnode, op_args, out, bprop_fg); + + return true; +} + +bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args, + const ValuePtr &out, const FuncGraphPtr &bprop_fg) { + auto k_cell_impl = std::dynamic_pointer_cast(k_cell); + return k_cell_impl->KPynativeWithBProp(cnode, op_args, out, bprop_fg); +} + +bool KPynativeCellImpl::KPynativeWithBProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, + const FuncGraphPtr &bprop_fg) { + MS_EXCEPTION_IF_NULL(cnode); + auto primal_fg = GetCNodeFuncGraph(cnode); + if (primal_fg == nullptr) { + MS_LOG(EXCEPTION) << "Should be func graph, but: " << cnode->DebugString(); + } + MS_EXCEPTION_IF_NULL(bprop_fg); + BuildAdjoint(cnode, op_args, out, bprop_fg); + + return true; +} + +namespace { +ValuePtr ShallowCopyValue(const ValuePtr &value) { + if (value->isa()) { + auto tensor_value = value->cast(); + return std::make_shared(*tensor_value); + } else if (value->isa()) { + std::vector values; + auto value_tuple = value->cast(); + std::transform(value_tuple->value().begin(), value_tuple->value().end(), std::back_inserter(values), + [](const ValuePtr &elem) { return ShallowCopyValue(elem); }); + return std::make_shared(values); + } else { + return value; + } +} +} // namespace + +PynativeAdjointPtr KPynativeCellImpl::ForgeGetItemAdjoint(const CNodePtr &cnode) { + if (cnode->size() != 3) { + MS_LOG(EXCEPTION) << "TupleGetItem/ListGetItem CNode should have 3 inputs, but CNode: " << cnode->DebugString(); + } + // Input 1 of CNode; + PynativeAdjointPtr inp_1_adjoint = nullptr; + auto inp_1 = cnode->input(1); + auto inp_1_adjoint_iter = anfnode_to_adjoin_.find(inp_1); + if (inp_1_adjoint_iter == anfnode_to_adjoin_.end()) { + if 
(!inp_1->isa()) { + MS_LOG(EXCEPTION) << "Input 1 of CNode should be a CNode, CNode: " << cnode->DebugString(); + } + inp_1_adjoint = ForgeCNodeAdjoint(inp_1->cast()); + if (inp_1_adjoint == nullptr) { + MS_LOG(EXCEPTION) << "Build adjoint for input 1 of CNode failed, CNode: " << cnode->DebugString(); + } + inp_1_adjoint->users().push_back(cnode); + } else { + inp_1_adjoint = inp_1_adjoint_iter->second; + } + if (!inp_1_adjoint->out()->isa()) { + MS_LOG(EXCEPTION) << "Input of CNode should be evaluated to a ValueSequence. CNode: " << cnode->DebugString() + << ", out of input1: " << inp_1_adjoint->out(); + } + auto inp_1_out = inp_1_adjoint->out()->cast(); + + // Input 2 of CNode; + auto index_value = GetValueNode(cnode->input(2)); + if (index_value == nullptr) { + MS_LOG(EXCEPTION) << "CNode input 2 should be a Int64Imm, CNode: " << cnode->DebugString(); + } + if (index_value->value() < 0) { + MS_LOG(EXCEPTION) << "CNode input 2 should not less than 0, CNode: " << cnode->DebugString(); + } + size_t index_value_imm = index_value->value(); + if (index_value_imm < 0 || index_value_imm >= inp_1_out->size()) { + MS_LOG(EXCEPTION) << "CNode input 2 should be index between [0, " << inp_1_out->size() + << ", but: " << index_value->ToString(); + } + auto cnode_out = (*inp_1_out)[index_value_imm]; + ValuePtrList op_args{inp_1_out, index_value}; + // cnode is TupleGetItem/ListGetItem, op_args is inputs and find by prev cnode inputs + auto built = KPynativeOp(cnode, op_args, cnode_out); + if (!built) { + MS_LOG(EXCEPTION) << "Build Adjoint for GetItem node failed, CNode: " << cnode->DebugString(); + } + auto cnode_adjoint_iter = anfnode_to_adjoin_.find(cnode); + if (cnode_adjoint_iter == anfnode_to_adjoin_.end()) { + MS_LOG(EXCEPTION) << "Build Adjoint for GetItem node failed, CNode: " << cnode->DebugString(); + } + return cnode_adjoint_iter->second; +} + +PynativeAdjointPtr KPynativeCellImpl::ForgeMakeSequenceAdjoint(const CNodePtr &cnode) { + // () or [] is not supported yet. + if (cnode->size() <= 1) { + MS_LOG(DEBUG) << "MakeTuple/MakeList CNode is empty Tuple/List, CNode: " << cnode->DebugString(); + static auto dummy_adjoint = std::make_shared(nullptr, ValuePtrList{}, nullptr, nullptr); + anfnode_to_adjoin_[cnode] = dummy_adjoint; + cnode->set_stop_gradient(true); + return dummy_adjoint; + } + ValuePtrList op_args; + for (size_t i = 1; i < cnode->size(); ++i) { + const auto &inp = cnode->input(i); + auto inp_adjoint_iter = anfnode_to_adjoin_.find(inp); + if (inp_adjoint_iter == anfnode_to_adjoin_.end()) { + MS_LOG(DEBUG) << "Item in CNode cannot found in cache. 
Inp is: " << inp->DebugString(); + if (inp->isa()) { + const auto inp_cnode = inp->cast(); + MS_EXCEPTION_IF_NULL(inp_cnode); + auto forged_inp_adjoint = ForgeCNodeAdjoint(inp->cast()); + op_args.push_back(forged_inp_adjoint->out()); + } else if (inp->isa()) { + const auto &inp_value = GetValueNode(inp); + op_args.push_back(inp_value); + } else { + MS_LOG(EXCEPTION) << "Input of MakeTuple/MakeLis is not a CNode or ValueNode, but: " << inp->DebugString(); + } + } else { + op_args.push_back(inp_adjoint_iter->second->out()); + } + } + ValuePtr cnode_out = nullptr; + if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { + cnode_out = MakeValue(op_args); + } + if (IsPrimitiveCNode(cnode, prim::kPrimMakeList)) { + cnode_out = std::make_shared(op_args); + } + // op_args is real inputs find by prev cnode outputs + auto built = KPynativeOp(cnode, op_args, cnode_out); + if (!built) { + MS_LOG(EXCEPTION) << "Build Adjoint for MakeTuple/MakeList node failed, CNode: " << cnode->DebugString(); + } + auto cnode_adjoint_iter = anfnode_to_adjoin_.find(cnode); + if (cnode_adjoint_iter == anfnode_to_adjoin_.end()) { + MS_LOG(EXCEPTION) << "Build Adjoint for MakeTuple/MakeList node failed, CNode: " << cnode->DebugString(); + } + return cnode_adjoint_iter->second; +} + +PynativeAdjointPtr KPynativeCellImpl::ForgeCNodeAdjoint(const CNodePtr &cnode) { + if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimListGetItem)) { + MS_LOG(DEBUG) << "Build cnode adjoint for anfnode: " << cnode->DebugString(); + return ForgeGetItemAdjoint(cnode); + } + + if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) { + MS_LOG(DEBUG) << "Build cnode adjoint for anfnode: " << cnode->DebugString(); + return ForgeMakeSequenceAdjoint(cnode); + } + MS_LOG(EXCEPTION) << "Unknown cnode: " << cnode->DebugString(); +} + +bool KPynativeCellImpl::BuildHighOrderFuncGraphAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, + const ValuePtr &out, const FuncGraphPtr &bprop_fg) { + auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(cnode); + if (anfnode_adjoint_iter != anfnode_to_adjoin_.end()) { + MS_LOG(EXCEPTION) << "CNode should be unique, but: " << cnode->DebugString(); + } + last_node_ = cnode; + + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + auto inp_i = cnode->input(i); + auto input_anfnode_adjoint_iter = anfnode_to_adjoin_.find(inp_i); + if (input_anfnode_adjoint_iter == anfnode_to_adjoin_.end()) { + if (inp_i->isa()) { + auto cnode_inp_i = inp_i->cast(); + auto forged_adjoint = ForgeCNodeAdjoint(cnode_inp_i); + if (forged_adjoint == nullptr) { + MS_LOG(EXCEPTION) << "Cannot forge adjoint for anfnode: " << inp_i->DebugString(); + } + forged_adjoint->users().push_back(cnode); + } else { + auto inp_i_pynative_adjoint = std::make_shared(tape_, ValuePtrList{}, op_args[i - 1], nullptr); + anfnode_to_adjoin_.insert(std::make_pair(inp_i, inp_i_pynative_adjoint)); + inp_i_pynative_adjoint->users().push_back(cnode); + } + } else { + input_anfnode_adjoint_iter->second->users().push_back(cnode); + } + } + + auto cnode_pynative_adjoint = std::make_shared(tape_, op_args, out, bprop_fg); + anfnode_to_adjoin_.insert(std::make_pair(cnode, cnode_pynative_adjoint)); + return true; +} + +bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, + const FuncGraphPtr &bprop_fg) { + // Optimize the bprop_fg based on value. 
+ // Clone op_args and out, so the address of tensor data can be reset to nullptr if the value of tensor + // is not used in bprop_fg; + ValuePtrList cloned_op_args; + std::transform(op_args.begin(), op_args.end(), std::back_inserter(cloned_op_args), + [](const ValuePtr &value) { return ShallowCopyValue(value); }); + ValuePtr cloned_out = ShallowCopyValue(out); + auto optimized_bprop_fg = OptimizeBPropFuncGraph(bprop_fg, cnode, cloned_op_args, cloned_out); + + auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(cnode); + if (anfnode_adjoint_iter != anfnode_to_adjoin_.end()) { + MS_LOG(EXCEPTION) << "CNode should be unique, but: " << cnode->DebugString(); + } + // Book-keeping last cnode, as dout of this node will be given from outside; + last_node_ = cnode; + + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + auto inp_i = cnode->input(i); + auto input_anfnode_adjoint_iter = anfnode_to_adjoin_.find(inp_i); + if (input_anfnode_adjoint_iter == anfnode_to_adjoin_.end()) { + if (inp_i->isa()) { + auto cnode_inp_i = inp_i->cast(); + auto forged_adjoint = ForgeCNodeAdjoint(cnode_inp_i); + if (forged_adjoint == nullptr) { + MS_LOG(EXCEPTION) << "Cannot forge adjoint for anfnode: " << inp_i->DebugString(); + } + forged_adjoint->users().push_back(cnode); + } else { + auto inp_i_pynative_adjoint = std::make_shared(tape_, ValuePtrList{}, op_args[i - 1], nullptr); + anfnode_to_adjoin_.insert(std::make_pair(inp_i, inp_i_pynative_adjoint)); + inp_i_pynative_adjoint->users().push_back(cnode); + } + } else { + input_anfnode_adjoint_iter->second->users().push_back(cnode); + } + } + + auto cnode_pynative_adjoint = + std::make_shared(tape_, cloned_op_args, cloned_out, optimized_bprop_fg); + anfnode_to_adjoin_.insert(std::make_pair(cnode, cnode_pynative_adjoint)); + + return true; +} + +FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &cnode, const ValuePtrList &op_args, + const ValuePtr &out) { + auto optimized_bprop_fg = + pipeline::PrimBpropOptimizer::GetPrimBpropOptimizerInst().OptimizeBPropFuncGraph(bprop_fg, cnode, op_args, out); + return optimized_bprop_fg; +} + +bool KPynativeCellImpl::BackPropagate(const CNodePtr &cnode_primal, const CNodePtr &bprop_app) { + for (size_t i = 1; i < cnode_primal->size(); i++) { + auto input = cnode_primal->input(i); + // Useless to accumulate sens for ValueNode, the sens for ValueNode should be zeros_like; + if (input->isa()) { + continue; + } + auto cnode_input = input->cast(); + if (cnode_input != nullptr && cnode_input->stop_gradient()) { + MS_LOG(DEBUG) << "Bypass accumulate dout to cnode with stop_gradient flag, cnode: " << input->ToString(); + continue; + } + // Backprop sens wrt inputs. 
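+    // bprop_app evaluates to a tuple of input sensitivities; element (i - 1) of that tuple is the dout
+    // contribution for primal input i and is accumulated into that input's adjoint below.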
+ auto input_adjoint_iter = anfnode_to_adjoin_.find(input); + if (input_adjoint_iter == anfnode_to_adjoin_.end()) { + MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString(); + } + auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i - 1))}); + input_adjoint_iter->second->AccumulateDout(din); + } + return true; +} + +bool KPynativeCellImpl::BackPropagate() { + for (auto iter = anfnode_to_adjoin_.rbegin(); iter != anfnode_to_adjoin_.rend(); ++iter) { + if (!iter->first->isa()) { + continue; + } + auto cnode = iter->first->cast(); + if (cnode->stop_gradient()) { + MS_LOG(DEBUG) << "Bypass backpropagate for cnode with stop_gradient flag: " << cnode->ToString(); + continue; + } + MS_LOG(DEBUG) << "BackPropagate for CNode: " << cnode->ToString(); + auto bprop_fg = iter->second->bprop_fg(); + MS_EXCEPTION_IF_NULL(bprop_fg); + + AnfNodePtrList node_list{NewValueNode(bprop_fg)}; + auto &args = iter->second->op_args(); + for (size_t idx = 0; idx < args.size(); ++idx) { + auto cur_arg = args[idx]; + auto anf_node = cnode->input(idx + 1); + if (!anf_node->isa()) { + anf_node = NewValueNode(cur_arg); + } + anf_node->set_abstract(cur_arg->ToAbstract()->Broaden()); + node_list.push_back(anf_node); + } + + auto out_node = NewValueNode(iter->second->out()); + out_node->set_abstract(iter->second->out()->ToAbstract()->Broaden()); + node_list.push_back(out_node); + node_list.push_back(iter->second->RealDout()); + + // Back propagate process + auto bprop_app = tape_->NewCNode(node_list); + BackPropagate(cnode, bprop_app); + } + return true; +} + +bool KPynativeCellImpl::AllReferencesStopped(const CNodePtr &curr_cnode) { + // If all CNode use curr_cnode has stop_gradient_ flag, then curr_cnode also can set that flag. 
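+  // An empty users list means no recorded CNode consumes curr_cnode, so its gradient path is kept alive
+  // (return false) instead of being marked stopped.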
+ auto iter = anfnode_to_adjoin_.find(curr_cnode); + if (iter == anfnode_to_adjoin_.end()) { + MS_LOG(EXCEPTION) << "Cannot find adjoint for cnode: " << curr_cnode->DebugString(); + } + auto users = iter->second->users(); + if (users.empty()) { + return false; + } + auto all_users_have_stopped = std::all_of(users.cbegin(), users.cend(), [](const AnfNodePtr &user) { + if (!user->isa() || !user->cast()->stop_gradient()) { + return false; + } + return true; + }); + return all_users_have_stopped; +} + +void KPynativeCellImpl::PropagateStopGradient() { + // propagate need_stop_gradient_ to cnode before back propagate; + if (need_propagate_stop_gradient_) { + for (auto iter = anfnode_to_adjoin_.rbegin(); iter != anfnode_to_adjoin_.rend(); ++iter) { + const auto &node = iter->first; + if (node->isa()) { + auto cnode = node->cast(); + if (!cnode->stop_gradient()) { + // Cut off the cnode only when it's not referred any more + if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) || + AllReferencesStopped(cnode)) { + MS_LOG(DEBUG) << "Set stop_gradient flag for " << cnode->ToString(); + cnode->set_stop_gradient(true); + } + } + } + } + } +} + +FuncGraphPtr KPynativeCellImpl::BuildBpropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode) { + auto inputs_num = cnode->size() - 1; + + auto func_graph = std::make_shared(); + std::vector outputs; + + auto bprop_cut = std::make_shared("bprop_cut", py::object()); + bprop_cut->CopyHookFunction(prim); + + auto cell_id = GetValue(prim->GetAttr("cell_id")); + if (cell_id != "") { + (void)bprop_cut->AddAttr("cell_hook", MakeValue(true)); + (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id)); + } + + outputs.push_back(NewValueNode(bprop_cut)); + for (size_t i = 0; i < inputs_num; ++i) { + auto param = func_graph->add_parameter(); + outputs.push_back(param); + } + // out, dout + auto p1 = func_graph->add_parameter(); + auto p2 = func_graph->add_parameter(); + outputs.push_back(p1); + outputs.push_back(p2); + + func_graph->set_output(func_graph->NewCNode(outputs)); + return func_graph; +} + +FuncGraphPtr KPynativeCellImpl::BuildMakeSequenceBprop(const PrimitivePtr &prim, const CNodePtr &cnode) { + using KeyPair = std::pair; + static std::map bprop_func_graph_cache; + auto inputs_num = cnode->size() - 1; + KeyPair key{prim->name(), inputs_num}; + auto bprop_func_graph_iter = bprop_func_graph_cache.find(key); + if (bprop_func_graph_iter != bprop_func_graph_cache.end()) { + return bprop_func_graph_iter->second; + } + + FuncGraphPtr b = std::make_shared(); + + std::ostringstream ss; + ss << "◀" << prim->ToString() << inputs_num; + b->debug_info()->set_name(ss.str()); + for (size_t i = 0; i < inputs_num; ++i) { + auto param = b->add_parameter(); + } + // out, dout + auto p1 = b->add_parameter(); + AnfNodePtr dout = b->add_parameter(); + + std::vector grads; + PrimitivePtr getitem_prim; + + if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) { + getitem_prim = prim::kPrimTupleGetItem; + } else if (IsPrimitiveEquals(prim, prim::kPrimMakeList)) { + getitem_prim = prim::kPrimListGetItem; + } else { + MS_LOG(EXCEPTION) << "Prim should be MakeTuple or MakeList, Invalid prim: " << prim->ToString(); + } + + grads.push_back(NewValueNode(prim)); + for (size_t i = 0; i < inputs_num; ++i) { + grads.push_back(b->NewCNode({NewValueNode(getitem_prim), dout, NewValueNode(SizeToLong(i))})); + } + + b->set_flag(FUNC_GRAPH_FLAG_CORE, true); + b->set_output(b->NewCNode(grads)); + + bprop_func_graph_cache[key] = b; + return b; 
+} + +void KPynativeCellImpl::SetReturnNodeByGradFlag(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights) { + AnfNodePtrList grad_inputs_list{NewValueNode(prim::kPrimMakeTuple)}; + if (grad_inputs) { + for (const auto &input : cell_inputs_) { + MS_EXCEPTION_IF_NULL(input); + auto input_adjoint_iter = anfnode_to_adjoin_.find(input); + if (input_adjoint_iter == anfnode_to_adjoin_.end()) { + // If input is not used in the network, just return zeros_like() as dout; + MS_LOG(WARNING) << "Input is not used in network, input: " << input->ToString(); + auto dout = BuildZerosLikeNode(tape_, input); + grad_inputs_list.push_back(dout); + } else { + grad_inputs_list.push_back(input_adjoint_iter->second->RealDout()); + } + } + } + + AnfNodePtrList grad_weights_list{NewValueNode(prim::kPrimMakeTuple)}; + if (grad_weights) { + for (const auto &weight : weights) { + MS_EXCEPTION_IF_NULL(weight); + auto input_adjoint_iter = anfnode_to_adjoin_.find(weight); + if (input_adjoint_iter == anfnode_to_adjoin_.end()) { + // If weight is not used in the network, just return zeros_like() as dout; + MS_LOG(WARNING) << "Weight is not used in network, weight: " << weight->ToString(); + auto input_w = weight->cast(); + MS_EXCEPTION_IF_NULL(input_w); + auto default_param = input_w->default_param(); + MS_EXCEPTION_IF_NULL(default_param); + auto dout = BuildZerosLikeValue(tape_, default_param); + grad_weights_list.push_back(dout); + } else { + grad_weights_list.push_back(input_adjoint_iter->second->RealDout()); + } + } + } + + AnfNodePtr tape_output; + if (grad_inputs && grad_weights) { + tape_output = tape_->NewCNode( + {NewValueNode(prim::kPrimMakeTuple), tape_->NewCNode(grad_inputs_list), tape_->NewCNode(grad_weights_list)}); + } else if (grad_inputs) { + tape_output = tape_->NewCNode(grad_inputs_list); + } else if (grad_weights) { + tape_output = tape_->NewCNode(grad_weights_list); + } else if (cell_inputs_.empty()) { + tape_output = tape_->NewCNode(grad_inputs_list); + } else { + auto input_adjoint_iter = anfnode_to_adjoin_.find(cell_inputs_[0]); + if (input_adjoint_iter == anfnode_to_adjoin_.end()) { + // If input is not used in the network, just return zeros_like() as dout; + MS_LOG(WARNING) << "Input is not used in network, input: " << cell_inputs_[0]->ToString(); + tape_output = BuildZerosLikeNode(tape_, cell_inputs_[0]); + } else { + tape_output = input_adjoint_iter->second->RealDout(); + } + } + tape_->set_output(tape_output); +} +} // namespace ad +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kpynative.h b/mindspore/ccsrc/frontend/optimizer/ad/kpynative.h new file mode 100644 index 0000000000000000000000000000000000000000..e0ec4bb656afedd747cd4b70561d38f574d52229 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/kpynative.h @@ -0,0 +1,80 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_ + +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" + +namespace mindspore { +namespace ad { +class KPynativeCell { + public: + virtual ~KPynativeCell() = default; +}; + +using KPynativeCellPtr = std::shared_ptr; + +// bprop_fg: user defined back propagate funcgraph or back propagate funcgraph of primitive, it will be passed after +// just parsed. will have prototype: +// (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout) +// c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim. +// op_args: the arguments list of each input parameters. +// out: the op result. +// return: the returned funcgraph should have the same prototype. +FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args, + const ValuePtr &out); + +// Start building back propagate funcgraph for this cell. +// cell_inputs: the input parameter list of this cell except the weights; +KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs); + +// Return the back propagate funcgraph for this cell. +// weights: weights parameters used in this cell. +// grad_inputs: return sensitivity for input parameters; +// grad_weights: return sensitivity for weights; +// has_sens_arg: caller will pass sens args; +// return: the returned funcgraph will have prototype: +// if has_sens_arg is true +// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ..., +// sens_out) +// else: +// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...) +FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs, + bool grad_weights, bool has_sens_arg = false); + +// Grad for each operation. +// c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim. +// op_args: the arguments list of each input parameters. +// out: the op result. +bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args, + const ValuePtr &out, FuncGraphPtr bprop_fg = nullptr); + +// Grad for cell which may have user defined back propagate function. +// c_node: CNode with contains the construct function graph of cell (index 0) and the formal input parameters of that +// cell. op_args: the arguments list of each input parameters. +// out: the op result. +// bprop_fg: user defined back propagate funcgraph, it should be passed after just parsed. +// Should have prototype: (sens_input1, sens_input2, ...) 
bprop_fg(input1, input2, ..., out, dout) +bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args, + const ValuePtr &out, const FuncGraphPtr &bprop_fg); +} // namespace ad +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_GRAD_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc index ce3a5d434edef058af294e495773fa8c5c573f06..59a161a8963ab82748bbaebc7831fe09ead5e933 100644 --- a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc +++ b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc @@ -114,9 +114,10 @@ bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) { if (new_fg != nullptr) { // Replace current fg with existing fg - auto users = fg->func_graph_cnodes_index(); - for (auto &iter : users) { - auto cnode = iter.first->first->cast(); + const auto &users = fg->func_graph_cnodes_index(); + for (auto &user : users) { + auto anf_node = user.first->first.lock(); + auto cnode = anf_node->cast(); auto new_input = cnode->inputs(); auto main_graph = cnode->func_graph(); MS_EXCEPTION_IF_NULL(main_graph); @@ -126,7 +127,7 @@ bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) { new_input[0] = NewValueNode(new_fg); } auto new_cnode = main_graph->NewCNode(new_input); - manager->Replace(iter.first->first, new_cnode); + manager->Replace(anf_node, new_cnode); changed = true; } } else { diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index c866d83f256dfe84f3f105a5b74f8bfd62edb7ae..078a201f23aaa3f8083a8079a46ae676134ec857 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -49,6 +49,7 @@ #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" #include "frontend/optimizer/irpass/switch_layer_defer_inline.h" #include "frontend/optimizer/irpass/call_graph_tuple_transform.h" +#include "frontend/optimizer/irpass/bool_scalar_eliminate.h" namespace mindspore { namespace opt { @@ -223,6 +224,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { // switch_layer defer inline switch_layer_defer_inline_ = MakeSubstitution(std::make_shared(), "switch_layer_defer_inline", prim::kPrimSwitchLayer); + + bool_scalar_eliminate = MakeSubstitution(std::make_shared(), "bool_scalar_eliminate", IsCNode); } ResolveIRPassLib::ResolveIRPassLib() { diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 9ad4ab9ec54904c49e3d3318972cc9689974abc6..ed184247c4d606fe3783af239d45ce87da581a18 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -144,6 +144,9 @@ class OptimizeIRPassLib { // Pynative Eliminate SubstitutionPtr pynative_eliminate_; + + // Eliminate getattr bool scalar + SubstitutionPtr bool_scalar_eliminate; }; // the collection of irpass for resolve action diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/bool_scalar_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/bool_scalar_eliminate.cc new file mode 100644 index 0000000000000000000000000000000000000000..15ccd270cc838b34924c688f3c9b491e19a091b0 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/bool_scalar_eliminate.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "bool_scalar_eliminate.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { + +AnfNodePtr BoolScalarEliminate::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + auto cnode = node->cast(); + if (cnode == nullptr) { + return nullptr; + } + + if (!cnode->IsApply(prim::kPrimGetAttr)) { + return nullptr; + } + + auto vnode = cnode->input(1)->cast(); + if (vnode == nullptr) { + return nullptr; + } + + if (!vnode->value()->isa()) { + return nullptr; + } + + auto res = optimizer->resource(); + auto manager = res->manager(); + auto &node_users = manager->node_users(); + auto iter = node_users.find(node); + if (iter == node_users.end()) { + return nullptr; + } + + AnfNodeIndexSet node_idx_set = iter->second; + for (auto &item : node_idx_set) { + manager->Replace(item.first, vnode); + } + return nullptr; +} +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/bool_scalar_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/bool_scalar_eliminate.h new file mode 100644 index 0000000000000000000000000000000000000000..6313ec60ac547aab1035651955a55326ee14407f --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/bool_scalar_eliminate.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BOOL_SCALAR_ELIMINATE_H +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BOOL_SCALAR_ELIMINATE_H + +#include "ir/func_graph.h" +#include "frontend/optimizer/optimizer_caller.h" +#include "ir/pattern_matcher.h" + +namespace mindspore { +namespace opt { +namespace irpass { + +class BoolScalarEliminate : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_BOOL_SCALAR_ELIMINATE_H \ No newline at end of file diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h index 1362818c93f71ca3b780ec2b1a987368baab63de..919c927e339204a0a35ebe6efdc5a4645523570f 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h @@ -295,9 +295,9 @@ class InlinerBase : public AnfVisitor { }; bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { - auto &cnodes = fg->func_graph_cnodes_index(); + const auto &users = fg->func_graph_cnodes_index(); int64_t n_use = std::accumulate( - cnodes.begin(), cnodes.end(), 0, + users.begin(), users.end(), 0, [](int64_t sum, const std::pair &item) { return sum + item.second; }); return n_use == 1; } diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h index ad6657218d75b6e2c0456dcb7ad0742d3628b993..565f4b095514663a9529615e06895bcfd1104aba 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h @@ -154,6 +154,7 @@ class GetitemConstEliminator : public AnfVisitor { if (is_match_) { auto out = NewValueNode((*tuple_)[id_]); out->set_has_new_value(has_new_value_); + out->set_abstract((*tuple_)[id_]->ToAbstract()); return out; } return nullptr; diff --git a/mindspore/ccsrc/frontend/optimizer/opt.cc b/mindspore/ccsrc/frontend/optimizer/opt.cc index 5474bb5c1e1df5a35defa43eac530f44a23a44a3..b903a6a01132d24681380ccd3ee201c9d0657041 100644 --- a/mindspore/ccsrc/frontend/optimizer/opt.cc +++ b/mindspore/ccsrc/frontend/optimizer/opt.cc @@ -127,13 +127,13 @@ static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNod } static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, - std::deque *todo, bool change, size_t seen) { + std::vector &todo, bool change, size_t seen) { if (IsValueNode(node)) { - (*todo).emplace_back(GetValueNode(node)->output()); + todo.emplace_back(GetValueNode(node)->output()); } if (node->isa()) { auto &inputs = node->cast()->inputs(); - (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo)); + (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); } if (!change) { @@ -151,7 +151,7 @@ static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const A if (use_node == nullptr) { continue; } - (*todo).emplace_back(use_node); + todo.emplace_back(use_node); if (use_node->seen_ == seen) { use_node->seen_--; } @@ -164,16 +164,17 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con #endif FuncGraphManagerPtr manager = optimizer->manager(); auto seen = NewSeenGeneration(); - // 1024 is for the initial capacity of deque - std::deque 
todo(1024); - todo.clear(); + // 1024 is for the initial capacity of vector + std::vector todo; + todo.reserve(1024); todo.emplace_back(func_graph->output()); bool changes = false; auto &all_nodes = manager->all_nodes(); - while (!todo.empty()) { - AnfNodePtr node = todo.front(); - todo.pop_front(); + size_t node_idx = 0; + while (node_idx < todo.size()) { + AnfNodePtr node = todo[node_idx]; + node_idx++; if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { continue; @@ -191,7 +192,7 @@ bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, con break; } } - UpdateTransformingList(optimizer, node, &todo, change, seen); + UpdateTransformingList(optimizer, node, todo, change, seen); } #ifdef ENABLE_PROFILE MsProfile::StatTime("opt.transforms." + optimizer->name(), GetTime() - start); @@ -206,16 +207,17 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons #endif FuncGraphManagerPtr manager = optimizer->manager(); auto seen = NewSeenGeneration(); - // 1024 is for the initial capacity of deque - std::deque todo(1024); - todo.clear(); + // 1024 is for the initial capacity of vector + std::vector todo(0); + todo.reserve(1024); todo.emplace_back(root_node); bool changes = false; auto &all_nodes = manager->all_nodes(); - while (!todo.empty()) { - AnfNodePtr node = todo.front(); - todo.pop_front(); + size_t node_idx = 0; + while (node_idx < todo.size()) { + AnfNodePtr node = todo[node_idx]; + node_idx++; if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { continue; @@ -229,7 +231,7 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons changes = true; node = res; } - UpdateTransformingList(optimizer, node, &todo, change, seen); + UpdateTransformingList(optimizer, node, todo, change, seen); } #ifdef ENABLE_PROFILE diff --git a/mindspore/ccsrc/pipeline/jit/CMakeLists.txt b/mindspore/ccsrc/pipeline/jit/CMakeLists.txt index ad553dd004ffde0f158dd25a2fde0aa4f72b10a7..f587bbc0398ee4f7fbdab1fb58197220ab2dbfd0 100644 --- a/mindspore/ccsrc/pipeline/jit/CMakeLists.txt +++ b/mindspore/ccsrc/pipeline/jit/CMakeLists.txt @@ -8,6 +8,7 @@ file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pipeline_split.cc" "parse/*.cc" "static_analysis/*.cc" + "prim_bprop_optimizer.cc" ) diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index fcad535cc989b6870e4904147045cdcfcca8bcac..4812219d20e01523d78f51dc08c61954a37633a7 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -472,7 +472,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python if (value && value->isa()) { MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id; func_graph = value->cast(); - return func_graph; + return BasicClone(func_graph); } } @@ -489,7 +489,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python data_converter::SetObjGraphValue(obj_key, func_graph); } - return func_graph; + return BasicClone(func_graph); } namespace data_converter { static std::unordered_map object_map_; diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h index 835b212dcdd1e1bcb0669083edbee9876a052630..4f9e70b4b82de651447e9400ea3640cb8034c051 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.h +++ 
b/mindspore/ccsrc/pipeline/jit/parse/parse.h @@ -311,12 +311,12 @@ class ParseAst { AstSubType GetOpType(const py::object &node); template - py::object CallParserObjMethod(const std::string &method, const T &... args) { + py::object CallParserObjMethod(const std::string &method, const T &...args) { return python_adapter::CallPyObjMethod(parser_, method, args...); } template - py::object CallParseModFunction(const std::string &function, const T &... args) { + py::object CallParseModFunction(const std::string &function, const T &...args) { return python_adapter::CallPyModFn(module_, function, args...); } diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_dynamic.cc b/mindspore/ccsrc/pipeline/jit/parse/parse_dynamic.cc new file mode 100644 index 0000000000000000000000000000000000000000..8de2c6e1b6f151c212a75c3e7e9903b13ba4ddab --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_dynamic.cc @@ -0,0 +1,249 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/parse/parse_dynamic.h" +#include "mindspore/core/ir/cell.h" + +namespace mindspore::parse { +static std::unordered_set cell_input_args_; +static const std::set ignore_judge_dynamic_cell = { + "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal", + "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"}; +static const std::set unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE, + parse::NAMED_PRIMITIVE_NAMECONSTANT, + parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR}; + +std::string DynamicAnalysis::ParseNodeName(const std::shared_ptr &ast, const py::object &node, + parse::AstMainType type) { + MS_EXCEPTION_IF_NULL(ast); + if (py::isinstance(node)) { + MS_LOG(DEBUG) << "Get none type node!"; + return ""; + } + auto node_type = ast->GetNodeType(node); + MS_EXCEPTION_IF_NULL(node_type); + // Check node type + parse::AstMainType node_main_type = node_type->main_type(); + if (node_main_type != type) { + MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type; + return ""; + } + std::string node_name = node_type->node_name(); + MS_LOG(DEBUG) << "Ast node is " << node_name; + return node_name; +} + +void DynamicAnalysis::ParseInputArgs(const std::shared_ptr &ast, const py::object &fn_node) { + MS_EXCEPTION_IF_NULL(ast); + py::list args = ast->GetArgs(fn_node); + for (size_t i = 1; i < args.size(); i++) { + std::string arg_name = py::cast(args[i].attr("arg")); + MS_LOG(DEBUG) << "Input arg name: " << arg_name; + cell_input_args_.emplace(arg_name); + } +} + +bool DynamicAnalysis::ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object &node) { + MS_LOG(DEBUG) << "Parse if/while expr"; + py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST); + const auto 
&node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR); + if (node_name == parse::NAMED_PRIMITIVE_COMPARE) { + py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT); + py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS); + if (comparators_node.empty()) { + MS_LOG(DEBUG) << "Get comparators node failed!"; + return false; + } + auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR); + auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR); + // while self.a > self.b and changed self.a or self.b + if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) { + auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE); + std::string left_variable; + if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) { + left_variable = py::cast(left_value.attr("id")) + py::cast(left_node.attr("attr")); + } + auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE); + std::string right_variable; + if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) { + right_variable = + py::cast(right_value.attr("id")) + py::cast(comparators_node[0].attr("attr")); + } + return ParseBodyContext(ast, node, {left_variable, right_variable}); + } + // if a[0] + if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) { + py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE); + left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR); + } + MS_LOG(DEBUG) << "Left is " << left << " Right is " << right; + if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() || + unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) { + return true; + } + } + // if flag: + if (node_name == parse::NAMED_PRIMITIVE_NAME) { + std::string id = py::cast(test_node.attr("id")); + if (cell_input_args_.find(id) != cell_input_args_.end()) { + return true; + } + } + return false; +} + +bool DynamicAnalysis::ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node) { + MS_LOG(DEBUG) << "Parse assign expr"; + py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE); + const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR); + if (node_name == parse::NAMED_PRIMITIVE_CALL) { + py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC); + const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR); + if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) { + py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE); + py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE); + if (py::isinstance(value_in_slice_node)) { + MS_LOG(DEBUG) << "Parse value node is none!"; + return false; + } + const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR); + std::string id; + if (py::hasattr(value_in_slice_node, "id")) { + id = py::cast(value_in_slice_node.attr("id")); + } + if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end() || + (!id.empty() && cell_input_args_.find(id) != cell_input_args_.end())) { + return true; + } + } + } + return false; +} + +bool 
DynamicAnalysis::ParseAugAssignExprNode(const std::shared_ptr &ast, const py::object &node, + const std::vector &compare_prim) { + MS_LOG(DEBUG) << "Parse augassign expr"; + bool ret = false; + if (compare_prim.empty()) { + return ret; + } + py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET); + if (py::isinstance(target_node)) { + MS_LOG(DEBUG) << "Parse target node is none!"; + return ret; + } + py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE); + if (py::isinstance(value_node)) { + MS_LOG(DEBUG) << "Parse value node is none!"; + return ret; + } + std::string assign_prim; + if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) { + assign_prim = py::cast(value_node.attr("id")) + py::cast(target_node.attr("attr")); + } + auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim); + if (iter != compare_prim.end()) { + ret = true; + } + return ret; +} + +bool DynamicAnalysis::ParseForExprNode(const std::shared_ptr &ast, const py::object &node) { + MS_LOG(DEBUG) << "Parse for expr"; + py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY); + if (py::isinstance(body_node)) { + MS_LOG(DEBUG) << "Parse body of for expression is none!"; + return false; + } + py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN); + size_t count = LongToSize(pcount); + MS_LOG(DEBUG) << "The for nodes count in body is " << count; + for (size_t i = 0; i < count; ++i) { + auto it = py::cast(body_node)[i]; + const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT); + if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) { + return true; + } + } + return false; +} + +bool DynamicAnalysis::ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node, + const std::vector &compare_prim) { + MS_EXCEPTION_IF_NULL(ast); + py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY); + if (py::isinstance(func_obj)) { + MS_LOG(DEBUG) << "Parse body of cell is none!"; + return false; + } + py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN); + size_t count = IntToSize(pcount); + MS_LOG(DEBUG) << "The nodes count in body is " << count; + bool ret = false; + for (size_t i = 0; i < count; ++i) { + auto node = py::cast(func_obj)[i]; + const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT); + if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) { + ret = ParseAssignExprNode(ast, node); + } else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) { + ret = ParseAugAssignExprNode(ast, node, compare_prim); + } else if (node_name == parse::NAMED_PRIMITIVE_FOR) { + ret = ParseForExprNode(ast, node); + } else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) { + ret = ParseIfWhileExprNode(ast, node); + } + if (ret) { + MS_LOG(INFO) << "Current cell is dynamic!"; + break; + } + } + return ret; +} + +std::string DynamicAnalysis::GetCellInfo(const py::object &cell) { + if (py::isinstance(cell)) { + auto c_cell = py::cast(cell); + MS_EXCEPTION_IF_NULL(c_cell); + auto cell_info = c_cell->ToString(); + return cell_info; + } + return ""; +} + +bool DynamicAnalysis::IsDynamicCell(const py::object &cell) { + std::string cell_info = GetCellInfo(cell); + if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) { + return false; + 
} + // Using ast parse to check whether the construct of cell will be changed + auto ast = std::make_shared(cell); + bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD); + if (!success) { + MS_LOG(ERROR) << "Parse code to ast tree failed"; + return false; + } + py::object fn_node = ast->GetAstNode(); + // get the name of input args as the initialize of dynamic_variables + ParseInputArgs(ast, fn_node); + // parse body context + bool ret = false; + ret = ParseBodyContext(ast, fn_node); + cell_input_args_.clear(); + return ret; +} +} // namespace mindspore::parse diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_dynamic.h b/mindspore/ccsrc/pipeline/jit/parse/parse_dynamic.h new file mode 100644 index 0000000000000000000000000000000000000000..4871df88fcb0317908ee9828508f41b27c4eb2f7 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_dynamic.h @@ -0,0 +1,49 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_DYNAMIC_H_ +#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_DYNAMIC_H_ + +#include "pipeline/jit/parse/parse.h" + +namespace mindspore::parse { + +class DynamicAnalysis { + public: + DynamicAnalysis() = default; + ~DynamicAnalysis() = default; + + // Check cell struct + static bool IsDynamicCell(const py::object &cell); + + private: + static std::string GetCellInfo(const py::object &cell); + static void ParseInputArgs(const std::shared_ptr &ast, const py::object &fn_node); + static bool ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node, + const std::vector &compare_prim = {}); + static bool ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object &node); + static bool ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node); + static bool ParseAugAssignExprNode(const std::shared_ptr &ast, const py::object &node, + const std::vector &compare_prim = {}); + static bool ParseForExprNode(const std::shared_ptr &ast, const py::object &node); + static std::string ParseNodeName(const std::shared_ptr &ast, const py::object &node, + parse::AstMainType type); +}; +} // namespace mindspore::parse + +#endif diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 3af722277c9acae3e85520e58bf0275841c28598..81aabb394f798fdd4de5aa8e8ddd774c6b833985 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -75,6 +75,24 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) { return true; } +bool TransformTopGraphPass(const ResourcePtr &res) { + if (res->func_graph() == nullptr) { + MS_LOG(EXCEPTION) << "Transform top graph error."; + } + FuncGraphPtr func_graph = res->func_graph(); + if (opt::FuncGraphHasTupleInput(func_graph)) { + opt::GraphTupleParamTransform graph_trans; + func_graph = graph_trans(func_graph, res->manager()); + 
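The new parse_dynamic.cc/.h above decide whether a cell's construct is "dynamic" by walking its Python AST and dispatching on statement kind (Assign, AugAssign, For, If/While), stopping at the first statement that looks data-dependent. The sketch below reduces that control flow to plain standard-library types; Stmt, its kind strings and looks_dynamic are placeholders for the pybind11 AST checks, not MindSpore definitions.

  #include <iostream>
  #include <string>
  #include <vector>

  // Placeholder for one parsed statement of construct(); the real code walks Python AST
  // objects through pybind11 and parse::python_adapter.
  struct Stmt {
    std::string kind;    // e.g. "Assign", "AugAssign", "For", "If", "While"
    bool looks_dynamic;  // stands in for the per-kind checks (ParseAssignExprNode, ...)
  };

  // Mirrors the control flow of DynamicAnalysis::ParseBodyContext: apply a kind-specific
  // predicate to each statement and stop at the first one that marks the cell dynamic.
  bool BodyLooksDynamic(const std::vector<Stmt> &body) {
    for (const Stmt &stmt : body) {
      bool ret = false;
      if (stmt.kind == "Assign" || stmt.kind == "AugAssign" || stmt.kind == "For" ||
          stmt.kind == "If" || stmt.kind == "While") {
        ret = stmt.looks_dynamic;
      }
      if (ret) {
        std::cout << "dynamic construct detected at a " << stmt.kind << " statement\n";
        return true;
      }
    }
    return false;
  }

  int main() {
    std::vector<Stmt> body = {{"Assign", false}, {"While", true}, {"For", false}};
    std::cout << std::boolalpha << BodyLooksDynamic(body) << "\n";
    return 0;
  }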
res->set_func_graph(func_graph); + AbstractBasePtrList abs_spec_list; + auto ¶ms = func_graph->parameters(); + std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list), + [](AnfNodePtr node) { return node->abstract(); }); + res->set_args_spec(abs_spec_list); + } + return true; +} + bool CleanAfterOptAPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); @@ -93,6 +111,97 @@ bool CleanAfterOptAPass(const ResourcePtr &res) { return true; } +FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) { + opt::OptPassConfig pynative_eliminate_ = opt::OptPassConfig({ + irpass.pynative_eliminate_, + }); + opt::irpass::ResolveIRPassLib resolve_irpass; + opt::OptPassConfig resolver_prim = opt::OptPassConfig({ + resolve_irpass.resolver_resolve_and_getattr_, + resolve_irpass.resolver_resolve_, + resolve_irpass.resolver_getattr_, + }); + + opt::OptPassConfig switch_simplify_ = opt::OptPassConfig({ + irpass.switch_simplify_, + }); + + opt::OptPassConfig inline_ = opt::OptPassConfig({ + irpass.inline_, + }); + + opt::OptPassConfig bool_scalar_eliminate = opt::OptPassConfig({ + irpass.bool_scalar_eliminate, + }); + + OptPassGroupMap map({{"ad_eliminate_", pynative_eliminate_}, + {"ad_resolver_prim", resolver_prim}, + {"ad_inline_", inline_}, + {"bool_scalar_eliminate", bool_scalar_eliminate}, + {"ad_switch_simplify_", switch_simplify_}}); + + auto prim_bprop_opt_step_1 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_1", res, map); + FuncGraphPtr func_graph = res->func_graph(); + WITH(MsProfile::GetProfile()->Step("prim_bprop_opt_step_1"))[&prim_bprop_opt_step_1, &func_graph]() { + func_graph = prim_bprop_opt_step_1->step(func_graph, true); + }; + return func_graph; +} + +FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) { + opt::OptPassConfig switch_simplify_ = opt::OptPassConfig({ + irpass.switch_simplify_, + }); + + opt::OptPassConfig inline_ = opt::OptPassConfig({ + irpass.inline_, + }); + + opt::OptPassConfig zero_like_fill_zero_ = opt::OptPassConfig({ + irpass.zero_like_fill_zero_, + }); + + auto re_auto_monadwrapper = [](const FuncGraphPtr &root, const opt::OptimizerPtr &) -> bool { + return ReAutoMonad(root); + }; + OptPassGroupMap map({{"ad_renormalize", opt::OptPassConfig::Renormalize()}, + {"ad_inline_", inline_}, + {"ad_switch_simplify_", switch_simplify_}, + {"ad_zero_like_fill_zero_", zero_like_fill_zero_}, + {"auto_monad_grad", opt::OptPassConfig(re_auto_monadwrapper)}}); + + auto prim_bprop_opt_step_2 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_2", res, map); + FuncGraphPtr func_graph = res->func_graph(); + WITH(MsProfile::GetProfile()->Step("prim_bprop_opt_step_2"))[&prim_bprop_opt_step_2, &func_graph]() { + func_graph = prim_bprop_opt_step_2->step(func_graph, true); + }; + return func_graph; +} + +FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res); + if (!TransformTopGraphPass(res)) { + MS_LOG(EXCEPTION) << "Run TransformTopGraphPass failed"; + } + + opt::irpass::OptimizeIRPassLib irpass; + opt::OptPassConfig bg_final_opt_ = opt::OptPassConfig({ + irpass.inline_, + irpass.item_tuple_or_list_eliminate_, + irpass.depend_value_elim_, + irpass.reshape_eliminate_, + }); + OptPassGroupMap map({{"ad_final_opt_", bg_final_opt_}}); + + auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", res, map); + FuncGraphPtr func_graph = res->func_graph(); + 
WITH(MsProfile::GetProfile()->Step("bprop_graph_final_opt"))[&bprop_graph_final_opt, &func_graph]() { + func_graph = bprop_graph_final_opt->step(func_graph, true); + }; + + return func_graph; +} + namespace { bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); } @@ -455,24 +564,6 @@ bool CconvPass(const ResourcePtr &res) { return true; } -bool TransformTopGraphPass(const ResourcePtr &res) { - if (res->func_graph() == nullptr) { - MS_LOG(EXCEPTION) << "Transform top graph error."; - } - FuncGraphPtr func_graph = res->func_graph(); - if (opt::FuncGraphHasTupleInput(func_graph)) { - opt::GraphTupleParamTransform graph_trans; - func_graph = graph_trans(func_graph, res->manager()); - res->set_func_graph(func_graph); - AbstractBasePtrList abs_spec_list; - auto ¶ms = func_graph->parameters(); - std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list), - [](AnfNodePtr node) { return node->abstract(); }); - res->set_args_spec(abs_spec_list); - } - return true; -} - bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); } void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) { diff --git a/mindspore/ccsrc/pipeline/jit/pass.h b/mindspore/ccsrc/pipeline/jit/pass.h index 50abeea7bfba8df9d2b1b5fb3b70d689a8d0eb01..f65ac4e50d84ec20c3c75f74306769e64b910595 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.h +++ b/mindspore/ccsrc/pipeline/jit/pass.h @@ -24,6 +24,12 @@ #include "pipeline/jit/resource.h" namespace mindspore { +namespace opt { +namespace irpass { +class OptimizeIRPassLib; +} // namespace irpass +} // namespace opt + namespace pipeline { using PassItem = std::pair>; @@ -40,6 +46,9 @@ bool AddCacheEmbeddingPass(const ResourcePtr &res); bool InferenceOptPreparePass(const ResourcePtr &res); void ReclaimOptimizer(); bool PynativeOptPass(const ResourcePtr &res); +FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res); +FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res); +FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index a17d01b58c5adfb5928e3f54ed949036ec2583e7..630255537e0feeab950a42fcacc59af9c42b2ae1 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -49,6 +49,7 @@ #include "utils/shape_utils.h" #include "utils/info.h" #include "load_mindir/load_model.h" +#include "pipeline/jit/prim_bprop_optimizer.h" #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #include "ps/constants.h" #include "ps/util.h" @@ -1199,6 +1200,7 @@ void ClearResAtexit() { } #endif ad::g_k_prims.clear(); + PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear(); abstract::ClearPrimEvaluatorMap(); compile::ClearConvertCache(); diff --git a/mindspore/ccsrc/pipeline/jit/prim_bprop_optimizer.cc b/mindspore/ccsrc/pipeline/jit/prim_bprop_optimizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4449edd626f4e50edd4a0e535a3b4d9c7521897 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/prim_bprop_optimizer.cc @@ -0,0 +1,356 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
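The pass.cc additions above express each optimization stage as an OptPassGroupMap, i.e. an ordered list of named pass configurations handed to opt::Optimizer::MakeOptimizer. As a rough, self-contained analogue of how such a named pass group behaves (Graph, Pass and RunPassGroup are illustrative stand-ins, and whether the real Optimizer::step re-sweeps to a fixpoint depends on its configuration):

  #include <functional>
  #include <iostream>
  #include <string>
  #include <utility>
  #include <vector>

  // Simplified analogue of a named pass group: an ordered list of (name, pass) pairs where
  // each pass may rewrite the graph and reports whether it changed anything. The real
  // OptPassGroupMap / opt::Optimizer work on FuncGraphPtr and OptPassConfig; Graph here is
  // just an int payload used to make the example runnable.
  using Graph = int;
  using Pass = std::function<bool(Graph *)>;
  using PassGroup = std::vector<std::pair<std::string, Pass>>;

  Graph RunPassGroup(const PassGroup &group, Graph g) {
    bool changed = true;
    while (changed) {  // sweep the whole group until nothing changes any more
      changed = false;
      for (const auto &named_pass : group) {
        if (named_pass.second(&g)) {
          std::cout << named_pass.first << " changed the graph\n";
          changed = true;
        }
      }
    }
    return g;
  }

  int main() {
    PassGroup map = {
        {"ad_inline_", [](Graph *g) { if (*g % 2 != 0) { ++*g; return true; } return false; }},
        {"ad_switch_simplify_", [](Graph *g) { if (*g > 4) { *g /= 2; return true; } return false; }},
    };
    std::cout << "result: " << RunPassGroup(map, 7) << "\n";
    return 0;
  }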
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "prim_bprop_optimizer.h" +#include "pass.h" + +namespace mindspore { +namespace pipeline { + +void PrimBpropOptGraphLevel2Info::TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out) { + // args_value_using_info_ contains out + if (args_value_using_info_.size() != op_args.size() + 1) { + MS_LOG(EXCEPTION) << "param size :" << args_value_using_info_.size() + << " of bp_graph:" << opt_func_graph_->ToString() + << " not match input arguments num:" << op_args.size(); + } + + ValuePtrList new_args(op_args); + new_args.emplace_back(out); + TryFreeOneValue(new_args, args_value_using_info_); +} + +void PrimBpropOptGraphLevel2Info::TryFreeOneValue(const ValuePtrList &op_args, + const std::vector ¶m_info_vec) { + if (param_info_vec.size() != op_args.size()) { + MS_LOG(EXCEPTION) << "param size :" << param_info_vec.size() << " of bp_graph:" << opt_func_graph_->ToString() + << " not match input arguments num:" << op_args.size(); + } + + for (size_t i = 0; i < op_args.size(); ++i) { + if (!param_info_vec[i].using_flg_ && !param_info_vec[i].tuple_flg_ && op_args[i]->isa()) { + auto value = op_args[i]->cast(); + value->set_device_address(nullptr); + } else if (param_info_vec[i].tuple_flg_ && op_args[i]->isa()) { + auto value = op_args[i]->cast(); + MS_EXCEPTION_IF_NULL(value); + TryFreeOneValue(value->value(), param_info_vec[i].sub_using_info_); + } + } +} + +void PrimBpropOptGraphLevel2Info::AnalysisArgUsingInfo(FuncGraphManagerPtr &manager) { + if (analysis_finish_flg_) { + return; + } + MS_EXCEPTION_IF_NULL(opt_func_graph_); + auto ¶ms = opt_func_graph_->parameters(); + auto &node_users = manager->node_users(); + args_value_using_info_.resize(params.size() - 1); + // analysis value using flg except dout + for (size_t i = 0; i < params.size() - 1; ++i) { + auto ¶m = params[i]; + auto &arg_info = args_value_using_info_[i]; + ArgInfoRefresh(param, arg_info); + AnalysisNodeUsingInfo(node_users, param, arg_info); + } + analysis_finish_flg_ = true; +} + +void PrimBpropOptGraphLevel2Info::AnalysisNodeUsingInfo(const NodeUsersMap &node_users, + const std::shared_ptr ¶m, + ParamUsingInfo &arg_info) const { + auto iter = node_users.find(param); + + if (iter == node_users.end()) { + arg_info.using_flg_ = false; + return; + } + + // tensor return directly + if (!arg_info.tuple_flg_) { + arg_info.using_flg_ = true; + return; + } + + // specific process for tuple parameter, may only partial items used + // map(A, (B, i)) + auto &users_info = iter->second; + for (auto &user_info : users_info) { + auto user_node = user_info.first; + arg_info.using_flg_ = true; + MS_LOG(DEBUG) << "param:" << param->ToString() << " used by node:" << user_node->ToString(); + if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) { + for (auto &sub_info : arg_info.sub_using_info_) { + sub_info.using_flg_ = true; + } + } else { + AalysisForTupleGetItem(node_users, param, arg_info, user_node); + } + } +} +void PrimBpropOptGraphLevel2Info::AalysisForTupleGetItem(const NodeUsersMap &node_users, + const std::shared_ptr ¶m, + ParamUsingInfo &arg_info, const AnfNodePtr &user_node) 
const { + auto cnode = user_node->cast(); + if (cnode->size() != 3) { + MS_LOG(EXCEPTION) << "TupleGetItem Node:" << user_node->ToString() << " of bp_graph:" << opt_func_graph_->ToString() + << "input size is:" << cnode->size(); + } + auto idx_node = cnode->input(2); + if (!idx_node->isa()) { + MS_LOG(EXCEPTION) << "tuple :" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString() + << " unexpected used by node:" << user_node->ToString() + << " TupleGetItem idx node:" << idx_node->ToString(); + } + + auto vnode = idx_node->cast(); + auto value_ptr = vnode->value(); + if (value_ptr == nullptr || !value_ptr->isa()) { + MS_LOG(EXCEPTION) << "tuple :" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString() + << " unexpected used by node:" << user_node->ToString() + << " TupleGetItem idx node:" << idx_node->ToString() << " idx Value :" << value_ptr; + } + + auto idx = value_ptr->cast()->value(); + arg_info.sub_using_info_[idx].using_flg_ = true; + ArgInfoRefresh(cnode, arg_info.sub_using_info_[idx]); + + if (arg_info.tuple_flg_) { + AnalysisNodeUsingInfo(node_users, cnode, arg_info.sub_using_info_[idx]); + } +} + +void PrimBpropOptGraphLevel2Info::ArgInfoRefresh(const std::shared_ptr ¶m, + ParamUsingInfo &arg_info) const { + auto abs = param->abstract(); + if (abs->isa()) { + arg_info.tuple_flg_ = false; + MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is a AbstractTensor"; + } else if (abs->isa()) { + auto abs_tuple = abs->cast(); + MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is a AbstractTuple"; + arg_info.tuple_flg_ = true; + arg_info.tuple_size_ = abs_tuple->size(); + arg_info.sub_using_info_.resize(abs_tuple->size()); + } else { + arg_info.tuple_flg_ = false; + } +} + +PrimBpropOptimizer &PrimBpropOptimizer::GetPrimBpropOptimizerInst() { + static PrimBpropOptimizer g_prim_bprop_opt; + return g_prim_bprop_opt; +} + +PrimBpropOptimizer::PrimBpropOptimizer() {} + +PrimBpropOptimizer::~PrimBpropOptimizer() {} + +void PrimBpropOptimizer::Clear() { prim_bprop_cache_.clear(); } + +// bprop_fg has the signature: +// (sens_input1, sens_input2,...)bprop_fg(input1, input2, ..., out, d_out) +// c_node contains the prim(input 0) and the input parameters of that prim; +// op_args contains the arguments list of each input parameters, it maybe tensor or tuple +// out contains the out of c_node; +FuncGraphPtr PrimBpropOptimizer::OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, + const ValuePtrList &op_args, const ValuePtr &out) { + MS_EXCEPTION_IF_NULL(bprop_fg); + MS_EXCEPTION_IF_NULL(c_node); + MS_EXCEPTION_IF_NULL(out); + auto &inputs = c_node->inputs(); + if (inputs.size() < 1 || inputs.size() - 1 != op_args.size()) { + MS_LOG(EXCEPTION) << "The parameters num " << inputs.size() - 1 << " not match arguments num " << op_args.size() + << ", CNode:" << c_node->ToString() << " grap:" << bprop_fg->ToString(); + } + + if (!IsValueNode(inputs[0])) { + MS_LOG(EXCEPTION) << "CNode:" << c_node->ToString() + << " not a primitive node, input_0 is:" << inputs[0]->ToString(); + } + + PrimitivePtr prim = GetValueNode(inputs[0]); + MS_LOG(DEBUG) << "Hash of prim " << prim->ToString() << " is:" << prim->hash(); + + // kPrimHookBackward + bool hookback_flg = IsPrimitiveEquals(prim, prim::kPrimHookBackward); + if (hookback_flg || IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) { + return GenSpecOptBprop(bprop_fg, op_args, out, prim, hookback_flg); + } + + return 
GetOptBpropFromCache(bprop_fg, op_args, out, prim); +} + +FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, + const ValuePtr &out, PrimitivePtr &prim) { + abstract::AbstractBasePtrList abs_list; + ArgsToAbs(prim, op_args, abs_list); + + PrimBpropOptGraphLevel2InfoPtr level_2_graph_info; + PrimBpropOptGraphInfoPtr level_1_graph_info; + ECacheQrtRes cache_res = GetOptBpfgFromCache(prim, abs_list, level_2_graph_info, level_1_graph_info); + + MS_LOG(DEBUG) << "Cache match result " << cache_res << ", prim: " << prim->ToString(); + if (cache_res == E_LEVEL_2) { + MS_LOG(DEBUG) << "Level 2 cache matched, prim: " << prim->ToString(); + level_2_graph_info->TryFreeArgsValue(op_args, out); + return BasicClone(level_2_graph_info->opt_func_graph()); + } + + // do step1 opt + if (cache_res == E_NOT_FOUND) { + bprop_fg->debug_info()->set_name(prim->ToString()); + level_1_graph_info = PrimBpropOptStep1(bprop_fg); + prim_bprop_cache_[prim] = level_1_graph_info; + } + FuncGraphPtr level_1_graph = BasicClone(level_1_graph_info->opt_func_graph_); + + // do step2 opt + auto new_abs_list = AddOutToAbsList(out, abs_list); + level_2_graph_info = PrimBpropOptStep2(level_1_graph, new_abs_list); + level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info; + level_2_graph_info->TryFreeArgsValue(op_args, out); + return BasicClone(level_2_graph_info->opt_func_graph()); +} + +FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, + const ValuePtr &out, PrimitivePtr &prim, bool hook_flg) { + abstract::AbstractBasePtrList abs_list; + ArgsToAbs(prim, op_args, abs_list); + if (!hook_flg) { + auto iter = tuple_list_bprop_cache_.find(std::pair(prim, abs_list)); + if (iter != tuple_list_bprop_cache_.end()) { + return BasicClone(iter->second); + } + } + + // do step1 opt + bprop_fg->debug_info()->set_name(prim->ToString()); + auto level_1_graph_info = PrimBpropOptStep1(bprop_fg); + + // do step2 opt + auto new_abs_list = AddOutToAbsList(out, abs_list); + auto level_2_graph_info = PrimBpropOptStep2(level_1_graph_info->opt_func_graph_, new_abs_list); + level_2_graph_info->TryFreeArgsValue(op_args, out); + + if (!hook_flg) { + tuple_list_bprop_cache_[std::pair(prim, abs_list)] = BasicClone(level_2_graph_info->opt_func_graph()); + } + return level_2_graph_info->opt_func_graph(); +} + +PrimBpropOptGraphInfoPtr PrimBpropOptimizer::PrimBpropOptStep1(const FuncGraphPtr &bprop_fg) { + auto level_1_graph_info = std::make_shared(); + auto prim_bprop_opt_res = std::make_shared(); + auto prim_bprop_opt_manage = prim_bprop_opt_res->manager(); + prim_bprop_opt_res->set_func_graph(bprop_fg); + prim_bprop_opt_manage->AddFuncGraph(bprop_fg); + auto opt_bprop_fg = PrimBpOptPassStep1(irpass_, prim_bprop_opt_res); + level_1_graph_info->opt_func_graph_ = opt_bprop_fg; + return level_1_graph_info; +} + +void PrimBpropOptimizer::BindAbsToParameters(const FuncGraphPtr &bprop_fg, + abstract::AbstractBasePtrList &abs_list_input) { + auto ¶ms = bprop_fg->parameters(); + if (abs_list_input.size() != params.size()) { + MS_LOG(EXCEPTION) << "Param num:" << params.size() << " not match inputs num " << abs_list_input.size(); + } + + for (size_t i = 0; i < abs_list_input.size(); i++) { + params[i]->set_abstract(abs_list_input[i]); + } +} + +PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(const FuncGraphPtr &bprop_fg, + abstract::AbstractBasePtrList &abs_list_input) { + BindAbsToParameters(bprop_fg, 
abs_list_input); + pipeline::ResourcePtr resource = std::make_shared(); + auto manager = resource->manager(); + resource->set_func_graph(bprop_fg); + manager->AddFuncGraph(bprop_fg); + auto opt_bprop_fg = PrimBpOptPassStep2(irpass_, resource); + auto level_2_graph_info = std::make_shared(opt_bprop_fg); + level_2_graph_info->AnalysisArgUsingInfo(manager); + return level_2_graph_info; +} + +FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res); + auto after_opt_bg = BpropGraphFinalOptPass(res); + return after_opt_bg; +} + +ECacheQrtRes PrimBpropOptimizer::GetOptBpfgFromCache(const PrimitivePtr &prim, + const abstract::AbstractBasePtrList &abs_list, + PrimBpropOptGraphLevel2InfoPtr &level_2_graph_info, + PrimBpropOptGraphInfoPtr &level_1_graph_info) { + auto attrs_ = prim->attrs(); + for (auto &item : attrs_) { + MS_LOG(DEBUG) << "prim:" << prim->ToString() << " attr: " << item.first << " value:" << item.second->ToString(); + } + + auto iter = prim_bprop_cache_.find(prim); + if (iter == prim_bprop_cache_.end()) { + return E_NOT_FOUND; + } + + level_1_graph_info = iter->second; + auto second_iter = level_1_graph_info->graph_level_2_cache_.find(abs_list); + if (second_iter == level_1_graph_info->graph_level_2_cache_.end()) { + return E_LEVEL_1; + } + level_2_graph_info = second_iter->second; + return E_LEVEL_2; +} + +void PrimBpropOptimizer::ArgsToAbs(PrimitivePtr &prim, const ValuePtrList &op_args, + abstract::AbstractBasePtrList &abs_list) { + auto const_input_index = prim->get_const_input_indexes(); + bool have_const_input = !const_input_index.empty(); + bool is_const_prim = prim->is_const_prim(); + for (size_t i = 0; i < op_args.size(); ++i) { + bool is_const_input = + have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end(); + auto &arg_value = op_args[i]; + auto arg_abs = arg_value->ToAbstract(); + if (!is_const_prim && !is_const_input) { + auto config = abstract::AbstractBase::kBroadenTensorOnly; + arg_abs = arg_abs->Broaden(config); + MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config; + } + abs_list.emplace_back(arg_abs); + } +} + +abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr &out, + const abstract::AbstractBasePtrList &abs_list) { + if (!out->isa() && !out->isa()) { + MS_LOG(EXCEPTION) << "Out value not Tensor or Tuple, please check the input arguments."; + } + abstract::AbstractBasePtrList new_abs_list(abs_list); + auto out_abs = out->ToAbstract(); + auto config = abstract::AbstractBase::kBroadenTensorOnly; + out_abs = out_abs->Broaden(config); + new_abs_list.emplace_back(out_abs); + new_abs_list.emplace_back(out_abs); + return new_abs_list; +} + +} // namespace pipeline +} // namespace mindspore \ No newline at end of file diff --git a/mindspore/ccsrc/pipeline/jit/prim_bprop_optimizer.h b/mindspore/ccsrc/pipeline/jit/prim_bprop_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..b833f68211c2eff3cf75195198c963e385440ec2 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/prim_bprop_optimizer.h @@ -0,0 +1,185 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
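GetOptBpfgFromCache above answers one of three states: no entry for the primitive at all (E_NOT_FOUND), a level-1 entry whose per-signature sub-cache misses (E_LEVEL_1), or a full level-2 hit (E_LEVEL_2) that can be cloned and returned directly. The self-contained sketch below shows the same two-level lookup with plain std::string / std::vector keys in place of PrimitivePtr and AbstractBasePtrList; the real maps use the custom PrimitiveTotalEqual and abstract-list hash/equality declared in the header.

  #include <iostream>
  #include <map>
  #include <memory>
  #include <string>
  #include <unordered_map>
  #include <vector>

  enum ECacheQueryRes { E_NOT_FOUND, E_LEVEL_1, E_LEVEL_2 };

  // Level-1 entry: the shape-agnostic (un-inferred) graph plus a per-signature sub-cache.
  struct Level1Info {
    std::string opt_graph_no_infer;                        // stands in for opt_func_graph_
    std::map<std::vector<int>, std::string> level2_cache;  // keyed by the argument "abstracts"
  };

  using Level1Cache = std::unordered_map<std::string, std::shared_ptr<Level1Info>>;

  ECacheQueryRes Lookup(const Level1Cache &cache, const std::string &prim,
                        const std::vector<int> &arg_sig, std::string *hit) {
    auto it = cache.find(prim);
    if (it == cache.end()) {
      return E_NOT_FOUND;  // must build the level-1 graph first
    }
    auto sub = it->second->level2_cache.find(arg_sig);
    if (sub == it->second->level2_cache.end()) {
      return E_LEVEL_1;    // reuse the level-1 graph, then specialize for this signature
    }
    *hit = sub->second;
    return E_LEVEL_2;      // fully specialized graph, clone and return directly
  }

  int main() {
    Level1Cache cache;
    auto info = std::make_shared<Level1Info>();
    info->opt_graph_no_infer = "bprop(Mul) step1";
    info->level2_cache[{2, 2}] = "bprop(Mul) step2 for 2x2";
    cache["Mul"] = info;

    std::string hit;
    std::cout << Lookup(cache, "Add", {2, 2}, &hit) << "\n";  // 0: E_NOT_FOUND
    std::cout << Lookup(cache, "Mul", {4, 4}, &hit) << "\n";  // 1: E_LEVEL_1
    std::cout << Lookup(cache, "Mul", {2, 2}, &hit) << "\n";  // 2: E_LEVEL_2
    return 0;
  }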
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H +#define MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H + +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "ir/func_graph.h" +#include "pipeline/jit/resource.h" + +namespace mindspore { +namespace pipeline { +struct PrimBpropOptGraphInfo; + +class PrimBpropOptGraphLevel2Info; + +struct PrimitiveTotalEqual; + +struct PrimitiveTupleListHasher; + +struct PrimitiveTupleListEqual; + +using PrimBpropOptGraphInfoPtr = std::shared_ptr; + +using PrimBpropOptGraphLevel2InfoPtr = std::shared_ptr; + +using PrimBpropCache = std::unordered_map; + +using TupleListKey = std::pair; + +using PrimBpropLevel2Cache = + std::unordered_map; + +using PrimTupleListCache = + std::unordered_map; + +struct PrimitiveTupleListHasher { + bool operator()(const TupleListKey &key) const { + abstract::AbstractBasePtrListHasher hasher; + return hasher(key.second); + } +}; + +struct PrimitiveTupleListEqual { + bool operator()(TupleListKey const &t1, TupleListKey const &t2) const { + MS_EXCEPTION_IF_NULL(t1.first); + MS_EXCEPTION_IF_NULL(t2.first); + + if (!(*t1.first == *t2.first)) { + return false; + } + abstract::AbstractBasePtrListEqual cmp; + return cmp(t1.second, t2.second); + } +}; + +struct PrimitiveTotalEqual { + bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { + MS_EXCEPTION_IF_NULL(t1); + MS_EXCEPTION_IF_NULL(t2); + return *t1 == *t2; + } +}; + +enum ECacheQrtRes { E_NOT_FOUND, E_LEVEL_1, E_LEVEL_2 }; + +struct PrimBpropOptGraphInfo { + // the level1 opt func_graph without infer, no shape/type info provide + FuncGraphPtr opt_func_graph_; + // the opt func_graph after infer, func_graph level2 cache + PrimBpropLevel2Cache graph_level_2_cache_; +}; + +struct ParamUsingInfo { + bool using_flg_{false}; + bool tuple_flg_{false}; + size_t tuple_size_; + std::vector sub_using_info_; +}; + +class PrimBpropOptGraphLevel2Info { + public: + explicit PrimBpropOptGraphLevel2Info(const FuncGraphPtr &func_graph) : opt_func_graph_(func_graph) {} + + const FuncGraphPtr &opt_func_graph() const { return opt_func_graph_; } + + void TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out); + + void AnalysisArgUsingInfo(FuncGraphManagerPtr &manager); + + private: + void ArgInfoRefresh(const std::shared_ptr ¶m, ParamUsingInfo &arg_info) const; + + void AnalysisNodeUsingInfo(const NodeUsersMap &node_users, const std::shared_ptr ¶m, + ParamUsingInfo &arg_info) const; + + void TryFreeOneValue(const ValuePtrList &op_args, const std::vector ¶m_info_vec); + + void AalysisForTupleGetItem(const NodeUsersMap &node_users, const std::shared_ptr ¶m, + ParamUsingInfo &arg_info, const AnfNodePtr &user_node) const; + + private: + // the level2 opt func_graph + FuncGraphPtr opt_func_graph_; + // to indicate arguments value using or not, if not using should free device memory + std::vector args_value_using_info_; + bool analysis_finish_flg_{false}; +}; + +class PrimBpropOptimizer { + public: + ~PrimBpropOptimizer(); + + void Clear(); + + static PrimBpropOptimizer &GetPrimBpropOptimizerInst(); + + // 
bprop_fg has the signature: + // (sens_input1, sens_input2,...)bprop_fg(input1, input2, ..., out, d_out) + // c_node contains the prim(input 0) and the input parameters of that prim; + // op_args contains the arguments list of each input parameters, it maybe tensor or tuple + // out contains the out of c_node; + FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args, + const ValuePtr &out); + + // do inline opt for final bprop graph + FuncGraphPtr BpropGraphFinalOpt(const ResourcePtr &res); + + private: + PrimBpropOptimizer(); + + ECacheQrtRes GetOptBpfgFromCache(const PrimitivePtr &prim, const abstract::AbstractBasePtrList &abs_list, + PrimBpropOptGraphLevel2InfoPtr &level_2_graph_info, + PrimBpropOptGraphInfoPtr &level_1_graph_info); + + // converter tensor args to abs value; + void ArgsToAbs(PrimitivePtr &prim, const ValuePtrList &op_args, abstract::AbstractBasePtrList &abs_list); + + // add out && dout to abs list + abstract::AbstractBasePtrList AddOutToAbsList(const ValuePtr &out, const abstract::AbstractBasePtrList &abs_list); + + // do opt without input info, no infer + PrimBpropOptGraphInfoPtr PrimBpropOptStep1(const FuncGraphPtr &bprop_fg); + + // do opt with input info + PrimBpropOptGraphLevel2InfoPtr PrimBpropOptStep2(const FuncGraphPtr &bprop_fg, + abstract::AbstractBasePtrList &abs_list_input); + + void BindAbsToParameters(const FuncGraphPtr &bprop_fg, abstract::AbstractBasePtrList &abs_list_input); + + FuncGraphPtr GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out, + PrimitivePtr &prim); + + FuncGraphPtr GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out, + PrimitivePtr &prim, bool hook_flg); + + private: + // cache optimized bprop graph + PrimBpropCache prim_bprop_cache_; + opt::irpass::OptimizeIRPassLib irpass_; + PrimTupleListCache tuple_list_bprop_cache_; +}; + +} // namespace pipeline +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H \ No newline at end of file diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index ed111d7bec0c25393e77a5554280d85c6760fae8..29107ab9266a339d01436897c7837d3fb3a4758f 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -336,7 +336,7 @@ void MemoryCleaner::ReleasePrimitivePyObj(PrimitivePy *prim) { return; } all_primitives_[prim] = false; - prim->SetPyObj(py::none()); + // prim->SetPyObj(py::none()); } void MemoryCleaner::ClearPrimitivePyPythonObj() { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc index 95cb26460602091675090cc428f95b56f4afc5d5..d83fb899d10d82ed1491dc2979172786b84b82bf 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc @@ -734,7 +734,7 @@ class SideEffectFinder { const int para_index = GetParameterIndex(func_graph, para); const size_t input_index = static_cast(para_index) + 1; // Search user cnodes of the func graph. - auto &users = func_graph->func_graph_cnodes_index(); + const auto &users = func_graph->func_graph_cnodes_index(); if (users.empty()) { MS_LOG(WARNING) << "Unused graph for parameter " << para->DebugString(); } @@ -745,7 +745,7 @@ class SideEffectFinder { continue; } // Caller cnode. 
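In the auto_monad.cc hunk that follows, the caller lookup through func_graph_cnodes_index() now goes through .lock(), i.e. the index apparently holds weak references to the user nodes and they must be promoted before the cast. A minimal, standalone sketch of that defensive pattern (Node/CNode/HandleUser are placeholders; the sketch adds an explicit null check for clarity, while the real code relies on dyn_cast tolerating a null input):

  #include <iostream>
  #include <memory>

  struct Node {
    virtual ~Node() = default;
  };
  struct CNode : Node {
    int arity = 0;
  };

  // When a user index holds only a weak_ptr to the caller node, it must be locked
  // (promoted to shared_ptr) before any downcast or member access; the owner may be gone.
  void HandleUser(const std::weak_ptr<Node> &user) {
    auto node = user.lock();
    if (node == nullptr) {
      std::cout << "user node already released, skip\n";
      return;
    }
    auto cnode = std::dynamic_pointer_cast<CNode>(node);
    if (cnode != nullptr) {
      std::cout << "caller cnode with arity " << cnode->arity << "\n";
    }
  }

  int main() {
    auto keep = std::make_shared<CNode>();
    keep->arity = 3;
    std::weak_ptr<Node> alive = keep;
    std::weak_ptr<Node> dead;
    {
      auto temp = std::make_shared<CNode>();
      dead = temp;
    }  // temp destroyed here, so dead is now expired
    HandleUser(alive);
    HandleUser(dead);
    return 0;
  }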
- auto cnode = dyn_cast(user.first->first); + auto cnode = dyn_cast(user.first->first.lock()); if (cnode && input_index < cnode->size()) { handler(cnode->input(input_index)); } diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h index ad07595887c17990cc1e27e36cc9c1b204163a7d..f544ca6986c268757575508c9398d875f152922e 100644 --- a/mindspore/ccsrc/pipeline/pynative/base.h +++ b/mindspore/ccsrc/pipeline/pynative/base.h @@ -50,29 +50,23 @@ enum PynativeStatusCode { enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM }; struct OpExecInfo { + bool is_dynamic_shape = false; + bool is_mixed_precision_cast = false; + size_t next_input_index = 0; std::string op_name; - std::string op_index; + std::string op_info; + std::string next_op_name = ""; PrimitivePyPtr py_primitive; AbstractBasePtr abstract; py::list op_inputs; py::dict op_attrs; std::vector inputs_mask; - bool is_dynamic_shape = false; - std::string next_op_name = ""; - bool is_mixed_precision_cast = false; - size_t next_input_index = 0; }; using OpExecInfoPtr = std::shared_ptr; const std::set ignore_infer_prim = {"mixed_precision_cast"}; const std::set force_infer_prim = {"TopK", "DropoutGenMask"}; -const std::set ignore_judge_dynamic_cell = { - "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal", - "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"}; -const std::set unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE, - parse::NAMED_PRIMITIVE_NAMECONSTANT, - parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR}; const std::set dynamic_shape_const_input_to_attr = {"Cast", "ExpandDims", "Reshape", "EmbeddingLookup", "Transpose"}; } // namespace pynative diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index d60db1afe4a7115b174372be2ec1c6cb6c039b7e..4eb8eb9dcc28645a9e35c80488585e1e122b1347 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -40,6 +40,7 @@ #include "frontend/operator/ops.h" #include "frontend/operator/composite/do_signature.h" #include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/parse_dynamic.h" #include "pipeline/jit/parse/resolve.h" #include "pipeline/jit/static_analysis/prim.h" #include "pipeline/jit/static_analysis/auto_monad.h" @@ -58,6 +59,8 @@ #include "pipeline/jit/pipeline.h" #include "pipeline/jit/pass.h" #include "frontend/parallel/context.h" +#include "pipeline/jit/prim_bprop_optimizer.h" +#include "frontend/optimizer/ad/dfunctor.h" #ifdef ENABLE_GE #include "pipeline/pynative/pynative_execute_ge.h" @@ -81,7 +84,7 @@ std::mutex PynativeExecutor::instance_lock_; constexpr auto implcast = "implcast"; template -void PynativeExecutorTry(std::function method, T *ret, const Args &... 
args) { +void PynativeExecutorTry(std::function method, T *ret, const Args &...args) { const auto inst = PynativeExecutor::GetInstance(); MS_EXCEPTION_IF_NULL(inst); MS_EXCEPTION_IF_NULL(method); @@ -571,6 +574,60 @@ py::tuple ConvertArgs(const py::tuple &args) { return res; } +void ResetTopCellInfo(const TopCellInfoPtr &top_cell, const py::args &args) { + MS_EXCEPTION_IF_NULL(top_cell); + top_cell->set_op_num(0); + top_cell->set_all_op_info(""); + top_cell->set_forward_already_run(true); + std::string input_args_id; + for (size_t i = 0; i < args.size(); ++i) { + input_args_id = input_args_id + GetId(args[i]) + "_"; + } + top_cell->set_input_args_id(input_args_id); +} + +void SaveOpInfo(const TopCellInfoPtr &top_cell, const std::string &op_info, + const std::vector &op_out_tensors) { + MS_EXCEPTION_IF_NULL(top_cell); + auto &op_info_with_tensor_id = top_cell->op_info_with_tensor_id(); + if (op_info_with_tensor_id.find(op_info) != op_info_with_tensor_id.end()) { + MS_LOG(EXCEPTION) << "Top cell: " << top_cell.get() << " records op info with tensor id, but get op info " + << op_info << " in op_info_with_tensor_id map"; + } + // Record the relationship between the forward op and its output tensor id + std::for_each(op_out_tensors.begin(), op_out_tensors.end(), + [&](const tensor::TensorPtr &tensor) { op_info_with_tensor_id[op_info].emplace_back(tensor->id()); }); +} + +void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector &pre_tensors) { + MS_EXCEPTION_IF_NULL(new_tensor); + auto device_target = MsContext::GetInstance()->get_param(MS_CTX_DEVICE_TARGET); + for (auto &pre_tensor : pre_tensors) { + MS_EXCEPTION_IF_NULL(pre_tensor); + pre_tensor->set_shape(new_tensor->shape()); + pre_tensor->set_data_type(new_tensor->data_type()); + if (device_target != kCPUDevice) { + pre_tensor->set_device_address(new_tensor->device_address()); + } else { + auto old_device_address = std::dynamic_pointer_cast(pre_tensor->device_address()); + auto new_device_address = std::dynamic_pointer_cast(new_tensor->device_address()); + auto old_ptr = old_device_address->GetMutablePtr(); + MS_EXCEPTION_IF_NULL(old_ptr); + auto new_ptr = new_device_address->GetPtr(); + MS_EXCEPTION_IF_NULL(new_ptr); + auto ret = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize()); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "Memory copy failed. ret: " << ret; + } + } + MS_LOG(DEBUG) << "Replace Old tensor " << pre_tensor.get() << " id " << pre_tensor->id() + << " device_address: " << pre_tensor->device_address()->GetMutablePtr() << " shape and type " + << pre_tensor->GetShapeAndDataTypeInfo() << " with New tensor " << new_tensor.get() << " id " + << new_tensor->id() << " device_address " << new_tensor->device_address()->GetMutablePtr() + << " shape and dtype " << new_tensor->GetShapeAndDataTypeInfo(); + } +} + void ClearPyNativeSession() { session = nullptr; } void CheckPyNativeContext() { @@ -603,6 +660,40 @@ GradExecutorPtr ForwardExecutor::grad() const { return grad_executor; } +bool TopCellInfo::IsSubCell(const std::string &cell_id) const { + if (sub_cell_list_.empty()) { + MS_LOG(DEBUG) << "The sub cell list is empty, there is no sub cell"; + return false; + } + if (sub_cell_list_.find(cell_id) != sub_cell_list_.end()) { + return true; + } + return false; +} + +void TopCellInfo::clear() { + MS_LOG(DEBUG) << "Clear top cell info. 
Cell id " << cell_id_; + op_num_ = 0; + is_dynamic_ = false; + vm_compiled_ = false; + is_init_kpynative_ = false; + need_compile_graph_ = false; + forward_already_run_ = false; + input_args_id_.clear(); + all_op_info_.clear(); + + if (resource_ != nullptr) { + resource_->Clean(); + resource_ = nullptr; + } + df_builder_ = nullptr; + k_pynative_cell_ptr_ = nullptr; + graph_info_map_.clear(); + sub_cell_list_.clear(); + op_info_with_tensor_id_.clear(); + tensor_id_with_tensor_object_.clear(); +} + void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_info) { MS_EXCEPTION_IF_NULL(ret); MS_EXCEPTION_IF_NULL(op_exec_info); @@ -659,13 +750,12 @@ void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_i } std::string obj_id = GetId(out_real); node_abs_map_[obj_id] = op_exec_info->abstract; - // Save info for building grad graph - if (grad()->grad_flag() && grad()->in_grad_process()) { + // Save cnode info and build grad graph + if (grad()->need_construct_graph() && !grad()->in_cell_with_custom_bprop_()) { grad()->SaveOutputNodeMap(obj_id, out_real, cnode); - grad()->SaveAllResult(op_exec_info, cnode, out_real); - // Update the abstract and device address of value node with tensor in grad graph - UpdateAbstractAndDeviceAddress(op_exec_info, out_real); + grad()->DoOpGrad(op_exec_info, cnode, out_real); } + grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real); *ret = out_real; } @@ -677,16 +767,6 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) { auto op_exec_info = std::make_shared(); auto op_name = py::cast(args[PY_NAME]); op_exec_info->op_name = op_name; - // Need const grad graph - if (grad()->grad_flag()) { - // Get forward op index - op_exec_info->op_index = op_name + "_" + std::to_string(grad()->op_index_map()[op_name]); - if (!grad()->cell_op_info_stack().empty()) { - std::string &cell_op_info = grad()->cell_op_info_stack().top(); - cell_op_info += op_exec_info->op_index; - } - grad()->op_index_map()[op_name]++; - } auto prim = py::cast(args[PY_PRIM]); MS_EXCEPTION_IF_NULL(prim); if (!prim->HasPyObj()) { @@ -695,6 +775,14 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) { op_exec_info->py_primitive = prim; op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); op_exec_info->op_inputs = args[PY_INPUTS]; + // Record op info for judge whether the construct of cell has been changed + if (grad()->grad_flag()) { + size_t curr_op_num = grad()->top_cell()->op_num(); + op_exec_info->op_info = op_name + "-" + std::to_string(curr_op_num); + std::string curr_op_info = grad()->top_cell()->all_op_info() + "_" + op_exec_info->op_info; + grad()->top_cell()->set_all_op_info(curr_op_info); + grad()->top_cell()->set_op_num(curr_op_num + 1); + } return op_exec_info; } @@ -736,7 +824,12 @@ void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector if (grad()->need_construct_graph()) { AnfNodePtr input_node = nullptr; if (!grad()->top_cell_list().empty()) { - input_node = grad()->GetInput(obj, op_mask); + bool requires_grad = true; + if (op_mask) { + auto requires_grad_attr = parse::python_adapter::GetPyObjAttr(obj, "requires_grad"); + requires_grad = requires_grad_attr.cast(); + } + input_node = grad()->GetInput(obj, op_mask & requires_grad); } // update abstract if (input_node != nullptr) { @@ -757,8 +850,6 @@ AnfNodePtr ForwardExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::ve MS_EXCEPTION_IF_NULL(op_exec_info); auto prim = op_exec_info->py_primitive; - 
std::vector inputs; - inputs.emplace_back(NewValueNode(prim)); const auto &signature = prim->signatures(); auto sig_size = signature.size(); @@ -777,7 +868,10 @@ AnfNodePtr ForwardExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::ve if (op_exec_info->op_name != prim::kPrimCast->name()) { RunParameterAutoMixPrecisionCast(op_exec_info); } - MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad()->grad_flag(); + + std::vector inputs; + inputs.emplace_back(NewValueNode(prim)); + MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag " << grad()->grad_flag(); GetArgsSpec(op_exec_info, op_masks, &inputs, args_spec_list); CNodePtr cnode = nullptr; @@ -900,12 +994,11 @@ py::object ForwardExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::obj auto grad = this->grad(); MS_EXCEPTION_IF_NULL(grad); if (grad->grad_flag()) { - // Get forward op index - if (!grad->cell_op_info_stack().empty()) { - std::string &cell_op_info = grad->cell_op_info_stack().top(); - cell_op_info += cast_struct->op_index; - } - grad->op_index_map()[cast_struct->op_name]++; + size_t curr_op_num = grad->top_cell()->op_num(); + cast_struct->op_info = cast_struct->op_name + "-" + std::to_string(curr_op_num); + std::string curr_op_info = grad->top_cell()->all_op_info() + "_" + cast_struct->op_info; + grad->top_cell()->set_all_op_info(curr_op_info); + grad->top_cell()->set_op_num(curr_op_num + 1); } py::object ret = py::none(); RunOpInner(&ret, cast_struct); @@ -1056,7 +1149,7 @@ AnfNodePtr GradExecutor::GetInput(const py::object &obj, bool op_mask) { MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; } auto param_name = py::cast(name_attr); - auto df_builder = GetDfbuilder(top_cell_id()); + auto df_builder = GetDfbuilder(top_cell()->cell_id()); MS_EXCEPTION_IF_NULL(df_builder); auto graph_info = top_cell()->graph_info_map().at(df_builder); MS_EXCEPTION_IF_NULL(graph_info); @@ -1111,128 +1204,6 @@ AnfNodePtr GradExecutor::GetInput(const py::object &obj, bool op_mask) { : MS_LOG(DEBUG) << "Get input node " << node->ToString() << " " << obj_id; return node; } - -void ForwardExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real) { - MS_EXCEPTION_IF_NULL(op_exec_info); - auto op_index = op_exec_info->op_index; - auto output_value = PyAttrValue(out_real); - MS_EXCEPTION_IF_NULL(output_value); - std::vector output_tensors; - TensorValueToTensor(output_value, &output_tensors); - if (cell_op_index_with_tensor_id()[grad()->top_cell_id()].find(op_index) == - cell_op_index_with_tensor_id()[grad()->top_cell_id()].end()) { - // first step - std::for_each(output_tensors.begin(), output_tensors.end(), [&](const tensor::TensorPtr &tensor) { - cell_op_index_with_tensor_id()[grad()->top_cell_id()][op_index].emplace_back(tensor->id()); - }); - return; - } - auto ms_context = MsContext::GetInstance(); - auto target = ms_context->get_param(MS_CTX_DEVICE_TARGET); - const auto &tensor_id_list = cell_op_index_with_tensor_id()[grad()->top_cell_id()][op_index]; - for (size_t i = 0; i < tensor_id_list.size(); ++i) { - auto tensor_id = tensor_id_list[i]; - if (cell_tensor_id_with_tensor()[grad()->top_cell_id()].find(tensor_id) != - cell_tensor_id_with_tensor()[grad()->top_cell_id()].end()) { - auto &new_tensor = output_tensors[i]; - auto &tensors_in_value_node = cell_tensor_id_with_tensor()[grad()->top_cell_id()][tensor_id]; - std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) { - 
MS_LOG(DEBUG) << "Debug address: Replace forward old tensor obj " << tensor.get() << ", tensor id " - << tensor->id() << ", device address " << tensor->device_address().get() - << " with New tensor obj " << new_tensor.get() << ", tensor id " << new_tensor->id() - << ", device address " << new_tensor->device_address().get(); - tensor->set_shape(new_tensor->shape()); - tensor->set_data_type(new_tensor->data_type()); - if (target != kCPUDevice) { - tensor->set_device_address(new_tensor->device_address()); - } else { - auto old_device_address = std::dynamic_pointer_cast(tensor->device_address()); - auto new_device_address = std::dynamic_pointer_cast(new_tensor->device_address()); - auto old_ptr = old_device_address->GetMutablePtr(); - auto new_ptr = new_device_address->GetPtr(); - MS_EXCEPTION_IF_NULL(old_ptr); - MS_EXCEPTION_IF_NULL(new_ptr); - auto ret = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize()); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "Memory copy failed. ret: " << ret; - } - } - }); - } - } -} - -void GradExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { - MS_EXCEPTION_IF_NULL(resource); - std::set forward_op_tensor_id; - auto it = forward()->cell_op_index_with_tensor_id().find(top_cell_id()); - if (it != forward()->cell_op_index_with_tensor_id().end()) { - for (const auto &elem : it->second) { - const auto &tensor_id_list = elem.second; - for (const auto &tensor_id : tensor_id_list) { - forward_op_tensor_id.emplace(tensor_id); - } - } - } - - forward()->cell_tensor_id_with_tensor()[top_cell_id()].clear(); - const auto &func_graph = resource->func_graph(); - const auto &value_node_list = func_graph->value_nodes(); - for (const auto &elem : value_node_list) { - auto value_node = elem.first->cast(); - MS_EXCEPTION_IF_NULL(value_node); - std::vector tensors; - TensorValueToTensor(value_node->value(), &tensors); - for (const auto &tensor : tensors) { - if (tensor->device_address() != nullptr && - forward_op_tensor_id.find(tensor->id()) != forward_op_tensor_id.end()) { - forward()->cell_tensor_id_with_tensor()[top_cell_id()][tensor->id()].emplace_back(tensor); - MS_LOG(DEBUG) << "Debug address: Save forward tensor obj " << tensor.get() << ", tensor id " << tensor->id() - << ", device address " << tensor->device_address().get(); - } - } - } -} - -void GradExecutor::CleanPreMemoryInValueNode() { - auto ms_context = MsContext::GetInstance(); - std::string device_target = ms_context->get_param(MS_CTX_DEVICE_TARGET); - if (device_target == "CPU" || pre_top_cell_ == nullptr) { - return; - } - if (pre_top_cell_->has_dynamic_cell()) { - std::set forward_op_tensor_id; - for (const auto &elem : forward()->cell_op_index_with_tensor_id().at(pre_top_cell_->cell_id())) { - const auto &tensor_id_list = elem.second; - for (const auto &tensor_id : tensor_id_list) { - forward_op_tensor_id.emplace(tensor_id); - } - } - for (auto &tensor : all_value_node_tensors_) { - if (tensor->device_address() != nullptr && - forward_op_tensor_id.find(tensor->id()) != forward_op_tensor_id.end()) { - tensor->device_address()->ClearDeviceMemory(); - tensor->set_device_address(nullptr); - } - } - all_value_node_tensors_.clear(); - } - auto it = forward()->cell_tensor_id_with_tensor().find(pre_top_cell_->cell_id()); - if (it == forward()->cell_tensor_id_with_tensor().end()) { - pre_top_cell_ = nullptr; - return; - } - const auto &tensor_id_with_tensor = it->second; - for (const auto &elem : tensor_id_with_tensor) { - const auto &tensors_in_value_node = elem.second; - 
for (const auto &tensor : tensors_in_value_node) { - MS_EXCEPTION_IF_NULL(tensor); - tensor->set_device_address(nullptr); - } - } - pre_top_cell_ = nullptr; -} - AnfNodePtr GradExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) { auto graph_info = top_cell()->graph_info_map().at(curr_g_); MS_EXCEPTION_IF_NULL(graph_info); @@ -1293,7 +1264,7 @@ AnfNodePtr GradExecutor::MakeValueNode(const py::object &obj, const std::string } void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode) { - if (graph_stack_.empty()) { + if (cell_stack_.empty()) { MS_LOG(DEBUG) << "No need save output"; return; } @@ -1310,43 +1281,116 @@ void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object } } SetNodeMapInGraphInfoMap(curr_g_, obj_id, cnode); - SetPyObjInGraphInfoMap(curr_g_, obj_id); } -void GradExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, - const py::object &out_real) { - if (node == nullptr) { - return; +// Run ad grad for curr op and connect grad graph with previous op +void GradExecutor::DoOpGrad(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &op_out) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(EXCEPTION) << "The node should be a cnode to run ad grad, but got node: " << node->ToString(); } + auto c_node = node->cast(); + MS_EXCEPTION_IF_NULL(c_node); + // Get input values MS_EXCEPTION_IF_NULL(op_exec_info); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - // save input object - size_t size = op_exec_info->op_inputs.size(); - for (size_t i = 0; i < size; i++) { - auto obj = op_exec_info->op_inputs[i]; - auto obj_id = GetId(obj); - auto it = obj_to_forward_id_.find(obj_id); - if (it != obj_to_forward_id_.end()) { - cnode->add_input_value(PyAttrValue(obj), it->second); - } else { - cnode->add_input_value(nullptr, ""); + ValuePtrList input_args; + for (size_t i = 0; i < op_exec_info->op_inputs.size(); ++i) { + auto arg = parse::data_converter::PyDataToValue(op_exec_info->op_inputs[i]); + MS_EXCEPTION_IF_NULL(arg); + input_args.emplace_back(arg); + } + // get output value + auto out_value = parse::data_converter::PyDataToValue(op_out); + MS_EXCEPTION_IF_NULL(out_value); + + if (!ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), c_node, input_args, out_value)) { + MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_exec_info->op_name; + } +} + +void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const py::object &out_real) { + if (!grad_flag()) { + MS_LOG(DEBUG) << "The grad flag is false, no need to update forward op info in bprop graph"; + return; + } + MS_EXCEPTION_IF_NULL(top_cell_); + MS_EXCEPTION_IF_NULL(op_exec_info); + auto op_info = op_exec_info->op_info; + MS_LOG(DEBUG) << "Current op info: " << op_info; + // Get output tensors + std::vector all_op_tensors; + TensorValueToTensor(parse::data_converter::PyDataToValue(out_real), &all_op_tensors); + // Save all tensors info of current op + if (need_construct_graph()) { + SaveOpInfo(top_cell_, op_info, all_op_tensors); + } + // First run top cell + if (already_run_top_cell_.find(top_cell_->cell_id()) == already_run_top_cell_.end()) { + MS_LOG(DEBUG) << "Top cell " << top_cell_->cell_id() << " run firstly"; + if (!need_construct_graph()) { + MS_LOG(EXCEPTION) << "The cell stack is empty"; } + return; } - // save output object - auto output_value = PyAttrValue(out_real); - MS_EXCEPTION_IF_NULL(output_value); - 
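// Illustrative sketch (not part of the patch) of the forward-tensor bookkeeping
// around SaveOpInfo, SaveForwardTensorInfoInBpropGraph and
// UpdateForwardTensorInfoInBpropGraph, using plain standard containers and a
// fake tensor type instead of the real TopCellInfo and Tensor classes.
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

struct FakeTensor {
  std::string id;
  int payload = 0;  // stands in for shape/dtype/device address
};
using FakeTensorPtr = std::shared_ptr<FakeTensor>;

struct TopCellRecord {
  // op_info -> ids of the tensors that op produced on the recorded run.
  std::map<std::string, std::vector<std::string>> op_info_with_tensor_id;
  // tensor id -> tensor objects captured as value nodes in the bprop graph.
  std::map<std::string, std::vector<FakeTensorPtr>> tensor_id_with_tensor_object;
};

// First run: remember which tensor ids each forward op produced.
void SaveOpInfo(TopCellRecord *record, const std::string &op_info,
                const std::vector<FakeTensorPtr> &out_tensors) {
  for (const auto &tensor : out_tensors) {
    record->op_info_with_tensor_id[op_info].push_back(tensor->id);
  }
}

// After the bprop graph is built: keep only tensors that a forward op produced.
void SaveForwardTensors(TopCellRecord *record, const std::vector<FakeTensorPtr> &bprop_graph_tensors) {
  std::unordered_set<std::string> forward_ids;
  for (const auto &entry : record->op_info_with_tensor_id) {
    forward_ids.insert(entry.second.begin(), entry.second.end());
  }
  for (const auto &tensor : bprop_graph_tensors) {
    if (forward_ids.count(tensor->id) != 0) {
      record->tensor_id_with_tensor_object[tensor->id].push_back(tensor);
    }
  }
}

// Later runs: rebind the tensors held by the bprop graph to the fresh outputs
// of the same op, matched by position (stands in for UpdateTensorInfo).
void RebindForwardTensors(const TopCellRecord &prev, const std::string &op_info,
                          const std::vector<FakeTensorPtr> &new_tensors) {
  auto ids = prev.op_info_with_tensor_id.find(op_info);
  if (ids == prev.op_info_with_tensor_id.end() || ids->second.size() != new_tensors.size()) {
    return;  // unknown op or mismatched output count
  }
  for (std::size_t i = 0; i < new_tensors.size(); ++i) {
    auto objects = prev.tensor_id_with_tensor_object.find(ids->second[i]);
    if (objects == prev.tensor_id_with_tensor_object.end()) {
      continue;
    }
    for (const auto &old_tensor : objects->second) {
      old_tensor->payload = new_tensors[i]->payload;
    }
  }
}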
cnode->set_forward(output_value, op_exec_info->op_index); - auto out_id = GetId(out_real); - if (py::isinstance(out_real)) { - auto tuple_item = py::cast(out_real); - for (size_t i = 0; i < tuple_item.size(); i++) { - auto tuple_item_id = GetId(tuple_item[i]); - obj_to_forward_id_[tuple_item_id] = op_exec_info->op_index; + // Non-first run + const auto &pre_top_cell = already_run_top_cell_.at(top_cell_->cell_id()); + MS_EXCEPTION_IF_NULL(pre_top_cell); + if (pre_top_cell->op_info_with_tensor_id().find(op_info) == pre_top_cell->op_info_with_tensor_id().end()) { + MS_LOG(DEBUG) << "Can not find op info " << op_info << " in op info with tensor id map. Top cell " + << top_cell_->cell_id(); + return; + } + const auto &pre_op_tensor_id = pre_top_cell->op_info_with_tensor_id().at(op_info); + if (pre_op_tensor_id.size() != all_op_tensors.size()) { + MS_LOG(EXCEPTION) << "The size of pre op tensor id: " << pre_op_tensor_id.size() + << " is not equal to the size of all tensors of current op " << all_op_tensors.size(); + } + // Update new output tensor info in bprop graph + const auto &pre_tensor_id_with_tensor_object = pre_top_cell->tensor_id_with_tensor_object(); + for (size_t i = 0; i < pre_op_tensor_id.size(); ++i) { + auto pre_id = pre_op_tensor_id[i]; + if (pre_tensor_id_with_tensor_object.find(pre_id) == pre_tensor_id_with_tensor_object.end()) { + continue; + } + const auto &new_tensor = all_op_tensors[i]; + const auto &pre_tensor_object = pre_tensor_id_with_tensor_object.at(pre_id); + UpdateTensorInfo(new_tensor, pre_tensor_object); + } +} + +void GradExecutor::SaveForwardTensorInfoInBpropGraph(const ResourcePtr &resource) { + MS_EXCEPTION_IF_NULL(resource); + // Get all tensors id belong to forward op + std::unordered_set forward_op_tensor_id; + const auto &op_info_with_tensor_id = top_cell()->op_info_with_tensor_id(); + for (const auto &elem : op_info_with_tensor_id) { + std::for_each(elem.second.begin(), elem.second.end(), + [&](const std::string &tensor_id) { forward_op_tensor_id.emplace(tensor_id); }); + } + auto &tensor_id_with_tensor_object_ = top_cell()->tensor_id_with_tensor_object(); + if (!tensor_id_with_tensor_object_.empty()) { + MS_LOG(EXCEPTION) << "When compile a new graph, the map tensor_id_with_tensor_object should be empty. 
Top cell " + << top_cell()->cell_id(); + } + const auto &bprop_graph = resource->func_graph(); + const auto &value_node_list = bprop_graph->value_nodes(); + std::vector tensors_in_bprop_graph; + for (const auto &elem : value_node_list) { + auto value_node = elem.first->cast(); + MS_EXCEPTION_IF_NULL(value_node); + TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph); + } + // Save tensor info in bprop graph + for (const auto &tensor : tensors_in_bprop_graph) { + if (tensor->device_address() == nullptr || forward_op_tensor_id.find(tensor->id()) == forward_op_tensor_id.end()) { + continue; } + tensor_id_with_tensor_object_[tensor->id()].emplace_back(tensor); + MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id() + << " device address: " << tensor->device_address()->GetMutablePtr() << " shape and dtype " + << tensor->GetShapeAndDataTypeInfo(); } - obj_to_forward_id_[out_id] = op_exec_info->op_index; } py::tuple ForwardExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info) { @@ -1442,8 +1486,7 @@ py::object ForwardExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, Pynativ auto input_obj_id = GetId(input); auto tensor = py::cast(input); MS_EXCEPTION_IF_NULL(tensor); - if (grad()->obj_to_forward_id().find(input_obj_id) == grad()->obj_to_forward_id().end() && - op_exec_info->op_name == "HookBackward") { + if (op_exec_info->op_name == "HookBackward") { // the input object is not a output of forward cnode, eg: parameter result[i] = tensor; } else { @@ -1532,8 +1575,6 @@ void ForwardExecutor::ClearRes() { node_abs_map_.clear(); cast_struct_map_.clear(); op_mask_map_.clear(); - cell_op_index_with_tensor_id_.clear(); - cell_tensor_id_with_tensor_.clear(); } ForwardExecutorPtr GradExecutor::forward() const { @@ -1542,11 +1583,6 @@ ForwardExecutorPtr GradExecutor::forward() const { return forward_executor; } -DynamicAnalysisPtr GradExecutor::dynamic_analysis() const { - MS_EXCEPTION_IF_NULL(dynamic_analysis_); - return dynamic_analysis_; -} - TopCellInfoPtr GradExecutor::top_cell() const { MS_EXCEPTION_IF_NULL(top_cell_); return top_cell_; @@ -1557,28 +1593,25 @@ FuncGraphPtr GradExecutor::curr_g() const { return curr_g_; } -void GradExecutor::PushCurrentGraphToStack() { graph_stack_.push(curr_g_); } - -void GradExecutor::PushCurrentCellOpInfoToStack() { - std::string cell_op_info = "Cell ops: "; - cell_op_info_stack_.push(cell_op_info); -} +void GradExecutor::PushCellStack(const std::string &cell_id) { cell_stack_.push(cell_id); } -void GradExecutor::PopGraphStack() { - if (graph_stack_.empty()) { - MS_LOG(EXCEPTION) << "Stack graph_stack_ is empty"; - } - graph_stack_.pop(); - if (!graph_stack_.empty()) { - curr_g_ = graph_stack_.top(); +void GradExecutor::PopCellStack() { + if (cell_stack_.empty()) { + MS_LOG(EXCEPTION) << "Stack cell_statck_ is empty"; } + cell_stack_.pop(); } -void GradExecutor::PopCurrentCellOpInfoFromStack() { - if (cell_op_info_stack_.empty()) { - MS_LOG(EXCEPTION) << "The cell op info stack is empty"; +void GradExecutor::PushHighOrderGraphStack() { high_order_stack_.push(curr_g_); } + +void GradExecutor::PopHighOrderGraphStack() { + if (high_order_stack_.empty()) { + MS_LOG(EXCEPTION) << "Stack high_order_stack_ is empty"; + } + high_order_stack_.pop(); + if (!high_order_stack_.empty()) { + curr_g_ = high_order_stack_.top(); } - cell_op_info_stack_.pop(); } std::string GradExecutor::GetCellId(const py::object &cell, const py::args &args) { @@ -1601,44 +1634,6 @@ std::string GradExecutor::GetCellId(const py::object 
&cell, const py::args &args return cell_id; } -void GradExecutor::SetTopCellTensorId(const std::string &cell_id) { - // Get top cell id - if (top_cell()->cell_graph_list().empty()) { - return; - } - auto top_cell_id = top_cell()->cell_graph_list().front()->cell_id(); - if (top_cell_id.find("NoShape") == std::string::npos) { - return; - } - std::string key = top_cell_id.substr(0, PTR_LEN); - auto fn = [](const std::string &str, std::vector &value) { - size_t pos = 0; - size_t pre_pos = 0; - while ((pos = str.find_first_of('_', pre_pos)) != std::string::npos) { - value.emplace_back(str.substr(pre_pos, pos - pre_pos + 1)); - pre_pos = pos + 1; - } - value.emplace_back(str.substr(pre_pos)); - }; - std::vector pre_cell_id; - std::vector cur_cell_id; - fn(cell_id, cur_cell_id); - fn(top_cell_id, pre_cell_id); - auto pre_tensor_size = pre_cell_id.size(); - if (pre_tensor_size == cur_cell_id.size()) { - size_t same_tensor_count = 0; - for (size_t i = 0; i < pre_tensor_size; ++i) { - if (pre_cell_id[i].find("NoShape") != std::string::npos || cur_cell_id[i] == pre_cell_id[i]) { - ++same_tensor_count; - } - } - if (same_tensor_count == pre_tensor_size) { - MS_LOG(DEBUG) << "Changed cell id from " << top_cell_id << " to " << cell_id; - top_cell()->cell_graph_list().front()->set_cell_id(cell_id); - } - } -} - void GradExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) { if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { DumpIR(filename, graph); @@ -1650,62 +1645,28 @@ bool GradExecutor::IsNestedGrad() const { return grad_order_ > 1; } -bool GradExecutor::IsTopGraph(const std::string &cell_id) { - return std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) { - return value->cell_id().find(cell_id) != std::string::npos; - }); -} - -bool GradExecutor::IsTopestGraph(const std::string &cell_id) { - return std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) { - return (value->cell_id() == cell_id || cell_id.find(value->cell_id()) != std::string::npos) && value->is_topest(); - }); -} - -bool GradExecutor::TopCellIsDynamic() { - if (top_cell_ == nullptr) { - return false; +bool GradExecutor::IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) { + // get end pos of obj_id + size_t obj_id_end_idx = l_cell_id.find('_'); + if (obj_id_end_idx == std::string::npos) { + obj_id_end_idx = l_cell_id.length(); } - return CheckRealDynamicCell(top_cell_id()); + // just compare obj_id, ignore args id + int cmp_ret = l_cell_id.compare(0, obj_id_end_idx, r_cell_id, 0, obj_id_end_idx); + return cmp_ret == 0; } -TopCellInfoPtr GradExecutor::GetTopCell(const string &cell_id, bool find_nearest) { - auto find_top_cell = [&](const string &cell_id) -> TopCellInfoPtr { - auto iter = std::find_if( - top_cell_list_.rbegin(), top_cell_list_.rend(), - [&cell_id](const TopCellInfoPtr &top_cell) { return cell_id == top_cell->cell_id() && top_cell->is_topest(); }); - if (iter != top_cell_list_.rend()) { - return *iter; - } - return nullptr; - }; - TopCellInfoPtr top_cell = find_top_cell(cell_id); - // find nearest top cell - if (top_cell == nullptr && find_nearest) { - for (auto it = top_cell_list_.begin(); it != top_cell_list_.end(); ++it) { - MS_EXCEPTION_IF_NULL(*it); - for (const auto &cell_info : (*it)->cell_graph_list()) { - MS_EXCEPTION_IF_NULL(cell_info); - if (cell_id == cell_info->cell_id()) { - return *it; - } - } - } - } - return top_cell; +bool 
GradExecutor::IsTopGraph(const std::string &cell_id) { + auto &top_cell_id = top_cell()->cell_id(); + return IsCellObjIdEq(cell_id, top_cell_id); } -void GradExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled) { - auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), - [&cell_id](const TopCellInfoPtr &value) { return value->cell_id() == cell_id; }); - if (it != top_cell_list_.end()) { - (*it)->set_vm_compiled(vm_compiled); - (*it)->set_forward_already_run(false); - (*it)->set_need_grad(true); - (*it)->set_is_grad(true); - if ((*it)->is_topest()) { - in_grad_process_ = false; - } +void GradExecutor::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled) { + top_cell()->set_vm_compiled(vm_compiled); + top_cell()->set_need_compile_graph(need_compile_graph); + top_cell()->set_forward_already_run(forward_already_run); + if (top_cell()->is_topest()) { + in_grad_process_ = false; } } @@ -1713,18 +1674,7 @@ bool GradExecutor::IsBpropGraph(const std::string &cell_id) { if (top_cell_ == nullptr) { return false; } - return std::any_of( - top_cell_->cell_graph_list().begin(), top_cell_->cell_graph_list().end(), [&cell_id](const CellInfoPtr &value) { - return !value->bprop_cell_id().empty() && cell_id.find(value->bprop_cell_id()) != std::string::npos; - }); -} - -bool GradExecutor::IsFirstGradStep() { return !top_cell()->is_grad(); } - -bool GradExecutor::IsGradBefore(const std::string &cell_id) { - return std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) { - return value->cell_id() == cell_id && value->is_grad(); - }); + return false; } void GradExecutor::SubNestedGradOrder() { @@ -1733,88 +1683,29 @@ void GradExecutor::SubNestedGradOrder() { } } -bool GradExecutor::CheckCellGraph(const std::string &cell_id) { - if (top_cell_ == nullptr) { - for (const auto &it : top_cell_list_) { - MS_EXCEPTION_IF_NULL(it); - if (it->cell_id() == cell_id) { - set_top_cell(it); - return true; - } - } - return false; - } else { - return std::any_of(top_cell_->cell_graph_list().begin(), top_cell_->cell_graph_list().end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); - } -} - -bool GradExecutor::CheckDynamicCell(const std::string &cell_id) { - if (top_cell_ == nullptr) { - return false; - } - return std::any_of( - top_cell_->cell_graph_list().begin(), top_cell_->cell_graph_list().end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id && value->is_dynamic(); }); -} - -bool GradExecutor::CheckRealDynamicCell(const std::string &cell_id) { - if (top_cell_ == nullptr) { - return false; - } - return top_cell_->is_real_dynamic(); -} - -void GradExecutor::ClearResidualRes(const std::string &cell_id) { - // Abnormal case - if (top_cell_list_.empty() && !graph_stack_.empty()) { - ClearCellRes(); - std::stack().swap(graph_stack_); - } - if (pre_top_cell_ == nullptr || !graph_stack_.empty() || !IsTopGraph(cell_id) || IsBpropGraph(cell_id)) { - return; - } - auto is_real_dynamic = pre_top_cell_->is_real_dynamic(); - if (is_real_dynamic && cell_id == pre_top_cell_->cell_id()) { - // Clear previous step resource - auto resource = GetResource(cell_id); - if (resource != nullptr && resource->results().find(pipeline::kBackend) != resource->results().end()) { - compile::BackendPtr backend = resource->results()[pipeline::kBackend].cast(); - auto ms_backend = std::dynamic_pointer_cast(backend); - ms_backend->ClearSessionGraphs(); - } - } -} - void 
GradExecutor::ClearCellRes(const std::string &cell_id) { // Grad clean if (cell_id.empty()) { for (const auto &it : top_cell_list_) { it->clear(); } + top_cell_list_.clear(); + already_run_top_cell_.clear(); + MS_LOG(DEBUG) << "Clear all cell resources"; return; } - if (IsTopGraph(cell_id)) { - for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) { - if ((*it)->cell_id().find(cell_id) != std::string::npos) { - (*it)->clear(); - it = top_cell_list_.erase(it); - } else { - it++; - } - } - } else { - // Clear common cell id - for (const auto &it : top_cell_list_) { - MS_EXCEPTION_IF_NULL(it); - for (auto ic = it->cell_graph_list().begin(); ic != it->cell_graph_list().end();) { - if ((*ic)->cell_id().find(cell_id) != std::string::npos) { - ic = it->cell_graph_list().erase(ic); - } else { - ++ic; - } + for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) { + auto top_cell_id = (*it)->cell_id(); + if (IsCellObjIdEq(cell_id, top_cell_id)) { + (*it)->clear(); + it = top_cell_list_.erase(it); + if (already_run_top_cell_.find(top_cell_id) != already_run_top_cell_.end()) { + (void)already_run_top_cell_.erase(top_cell_id); } + MS_LOG(DEBUG) << "Clear top cell resource. Top cell id " << top_cell_id; + continue; } + it++; } } @@ -1833,6 +1724,7 @@ FuncGraphPtr GradExecutor::GetDfbuilder(const std::string &cell_id) { } ResourcePtr GradExecutor::GetResource(const std::string &cell_id) { + // If top graph hold for (auto it = top_cell_list_.rbegin(); it != top_cell_list_.rend(); ++it) { if (cell_id.find((*it)->cell_id()) != std::string::npos) { return (*it)->resource(); @@ -1845,292 +1737,86 @@ ResourcePtr GradExecutor::GetResource(const std::string &cell_id) { return nullptr; } -std::string DynamicAnalysis::ParseNodeName(const std::shared_ptr &ast, const py::object &node, - parse::AstMainType type) { - MS_EXCEPTION_IF_NULL(ast); - if (py::isinstance(node)) { - MS_LOG(DEBUG) << "Get none type node!"; - return ""; - } - auto node_type = ast->GetNodeType(node); - MS_EXCEPTION_IF_NULL(node_type); - // Check node type - parse::AstMainType node_main_type = node_type->main_type(); - if (node_main_type != type) { - MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type; - return ""; - } - std::string node_name = node_type->node_name(); - MS_LOG(DEBUG) << "Ast node is " << node_name; - return node_name; -} - -void DynamicAnalysis::ParseInputArgs(const std::shared_ptr &ast, const py::object &fn_node) { - MS_EXCEPTION_IF_NULL(ast); - py::list args = ast->GetArgs(fn_node); - for (size_t i = 1; i < args.size(); i++) { - std::string arg_name = py::cast(args[i].attr("arg")); - MS_LOG(DEBUG) << "Input arg name: " << arg_name; - cell_input_args_.emplace(arg_name); - } -} - -bool DynamicAnalysis::ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object &node) { - MS_LOG(DEBUG) << "Parse if/while expr"; - py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST); - const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR); - if (node_name == parse::NAMED_PRIMITIVE_COMPARE) { - py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT); - py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS); - if (comparators_node.empty()) { - MS_LOG(DEBUG) << "Get comparators node failed!"; - return false; - } - auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR); - auto right = ParseNodeName(ast, 
comparators_node[0], parse::AST_MAIN_TYPE_EXPR); - // while self.a > self.b and changed self.a or self.b - if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) { - auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE); - std::string left_variable; - if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) { - left_variable = py::cast(left_value.attr("id")) + py::cast(left_node.attr("attr")); - } - auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE); - std::string right_variable; - if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) { - right_variable = - py::cast(right_value.attr("id")) + py::cast(comparators_node[0].attr("attr")); - } - return ParseBodyContext(ast, node, {left_variable, right_variable}); - } - // if a[0] - if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) { - py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE); - left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR); - } - MS_LOG(DEBUG) << "Left is " << left << " Right is " << right; - if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() || - unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) { - return true; - } - } - // if flag: - if (node_name == parse::NAMED_PRIMITIVE_NAME) { - std::string id = py::cast(test_node.attr("id")); - if (cell_input_args_.find(id) != cell_input_args_.end()) { - return true; - } - } - return false; -} - -bool DynamicAnalysis::ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node) { - MS_LOG(DEBUG) << "Parse assign expr"; - py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE); - const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR); - if (node_name == parse::NAMED_PRIMITIVE_CALL) { - py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC); - const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR); - if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) { - py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE); - py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE); - if (py::isinstance(value_in_slice_node)) { - MS_LOG(DEBUG) << "Parse value node is none!"; - return false; - } - const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR); - std::string id; - if (py::hasattr(value_in_slice_node, "id")) { - id = py::cast(value_in_slice_node.attr("id")); - } - if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end() || - (!id.empty() && cell_input_args_.find(id) != cell_input_args_.end())) { - return true; - } - } - } - return false; -} - -bool DynamicAnalysis::ParseAugAssignExprNode(const std::shared_ptr &ast, const py::object &node, - const std::vector &compare_prim) { - MS_LOG(DEBUG) << "Parse augassign expr"; - bool ret = false; - if (compare_prim.empty()) { - return ret; - } - py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET); - if (py::isinstance(target_node)) { - MS_LOG(DEBUG) << "Parse target node is none!"; - return ret; - } - py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE); - if 
(py::isinstance(value_node)) { - MS_LOG(DEBUG) << "Parse value node is none!"; - return ret; - } - std::string assign_prim; - if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) { - assign_prim = py::cast(value_node.attr("id")) + py::cast(target_node.attr("attr")); - } - auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim); - if (iter != compare_prim.end()) { - ret = true; - } - return ret; -} - -bool DynamicAnalysis::ParseForExprNode(const std::shared_ptr &ast, const py::object &node) { - MS_LOG(DEBUG) << "Parse for expr"; - py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY); - if (py::isinstance(body_node)) { - MS_LOG(DEBUG) << "Parse body of for expression is none!"; - return false; - } - py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN); - size_t count = LongToSize(pcount); - MS_LOG(DEBUG) << "The for nodes count in body is " << count; - for (size_t i = 0; i < count; ++i) { - auto it = py::cast(body_node)[i]; - const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT); - if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) { - return true; - } - } - return false; -} - -bool DynamicAnalysis::ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node, - const std::vector &compare_prim) { - MS_EXCEPTION_IF_NULL(ast); - py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY); - if (py::isinstance(func_obj)) { - MS_LOG(DEBUG) << "Parse body of cell is none!"; - return false; - } - py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN); - size_t count = IntToSize(pcount); - MS_LOG(DEBUG) << "The nodes count in body is " << count; - bool ret = false; - for (size_t i = 0; i < count; ++i) { - auto node = py::cast(func_obj)[i]; - const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT); - if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) { - ret = ParseAssignExprNode(ast, node); - } else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) { - ret = ParseAugAssignExprNode(ast, node, compare_prim); - } else if (node_name == parse::NAMED_PRIMITIVE_FOR) { - ret = ParseForExprNode(ast, node); - } else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) { - ret = ParseIfWhileExprNode(ast, node); - } - if (ret) { - MS_LOG(INFO) << "Current cell is dynamic!"; - break; +void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py::args &args) { + if (cell_stack_.empty()) { + if (IsBpropGraph(cell_id)) { + in_grad_process_ = true; + in_bprop_process_ = true; + } else { + MakeNewTopGraph(cell_id, args, true); } + } else { + // High order + if (IsNestedGrad()) { + MS_LOG(DEBUG) << "Enter nested graph"; + set_cell_nums(cell_stack_.size()); + MakeNewTopGraph(cell_id, args, false); + } + } + PushCellStack(cell_id); + // Init kPynativeCellPtr with input parameters of top cell + if (!top_cell()->is_init_kpynative()) { + auto graph_info_cg = std::make_shared(cell_id); + top_cell()->graph_info_map()[curr_g_] = graph_info_cg; + auto df_builder = GetDfbuilder(cell_id); + auto graph_info_df = std::make_shared(cell_id); + top_cell()->graph_info_map()[df_builder] = graph_info_df; + // Init parameter info for make cnode and curr_g + for (size_t i = 0; i < args.size(); ++i) { + auto param = args[i]; + auto new_param = curr_g_->add_parameter(); + ValuePtr param_value = 
PyAttrValue(param); + MS_EXCEPTION_IF_NULL(param_value); + new_param->set_abstract(param_value->ToAbstract()->Broaden()); + std::string param_id = GetId(param); + SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true); + SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param); + SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param); + } + top_cell()->set_k_pynative_cell_ptr(ad::GradPynativeCellBegin(curr_g_->parameters())); + top_cell()->set_need_compile_graph(true); + top_cell()->set_init_kpynative(true); + } else { + // Non-top cell + top_cell()->sub_cell_list().emplace(cell_id); } - return ret; -} - -std::string DynamicAnalysis::GetCellInfo(const py::object &cell) { - if (py::isinstance(cell)) { - auto c_cell = py::cast(cell); - MS_EXCEPTION_IF_NULL(c_cell); - auto cell_info = c_cell->ToString(); - return cell_info; - } - return ""; -} - -bool DynamicAnalysis::IsDynamicCell(const py::object &cell) { - std::string cell_info = GetCellInfo(cell); - if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) { - return false; - } - // Using ast parse to check whether the construct of cell will be changed - auto ast = std::make_shared(cell); - bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD); - if (!success) { - MS_LOG(ERROR) << "Parse code to ast tree failed"; - return false; - } - py::object fn_node = ast->GetAstNode(); - // get the name of input args as the initialize of dynamic_variables - ParseInputArgs(ast, fn_node); - // parse body context - bool ret = false; - ret = ParseBodyContext(ast, fn_node); - cell_input_args_.clear(); - return ret; } void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) { auto cell_id = GetCellId(cell, args); MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id; - // check whether cell needed to construct grad graph - if (graph_stack_.empty() && !top_cell_list_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) { - // Clear previous step resource - auto init_fn = [&](bool flag) { - CleanPreMemoryInValueNode(); - op_index_map_.clear(); - in_grad_process_ = true; - auto top_cell = GetTopCell(cell_id, flag); - MS_EXCEPTION_IF_NULL(top_cell); - top_cell->set_forward_already_run(true); - set_top_cell(top_cell); - MS_LOG(DEBUG) << "Top cell id " << top_cell->cell_id(); - }; - if (IsTopestGraph(cell_id) && cell_op_info_stack_.empty()) { - init_fn(false); - } - if (!in_grad_process_ && cell_op_info_stack_.empty()) { - init_fn(true); - } - PushCurrentCellOpInfoToStack(); - MS_LOG(INFO) << "NewGraph already compiled"; - return; - } - // Init resource for constructing forward graph and grad graph - curr_g_ = std::make_shared(); - ClearResidualRes(cell_id); - if (graph_stack_.empty()) { - if (IsBpropGraph(cell_id)) { - in_grad_process_ = true; - in_bprop_process_ = true; - } else { - MakeNewTopGraph(cell_id, args); - } - } - PushCurrentGraphToStack(); - PushCurrentCellOpInfoToStack(); - if (top_cell()->graph_info_map().find(curr_g_) == top_cell()->graph_info_map().end()) { - auto graph_info = std::make_shared(cell_id); - top_cell()->graph_info_map()[curr_g_] = graph_info; - } - for (size_t i = 0; i < args.size(); ++i) { - auto param = args[i]; - auto new_param = curr_g_->add_parameter(); - std::string param_id = GetId(param); - SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true); - SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param); - SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param); - } - // Check whether the 
construct of cell is dynamic
- if (!has_dynamic_cell_) {
- has_dynamic_cell_ = dynamic_analysis()->IsDynamicCell(cell);
- MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << has_dynamic_cell_;
- if (has_dynamic_cell_ && IsBpropGraph(cell_id)) {
- auto it = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(),
- [this](const CellInfoPtr &value) { return value->cell_id() == top_cell_id(); });
- while (it != top_cell()->cell_graph_list().end()) {
- (*it)->set_is_dynamic(true);
- ++it;
+ // When the cell has custom bprop, in_custom_bprop_cell is larger than 0
+ if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
+ custom_bprop_cell_count_ += 1;
+ }
+ if (cell_stack_.empty() && top_cell_ != nullptr) {
+ // non-first step
+ if (!top_cell()->IsSubCell(cell_id) && already_run_top_cell_.find(cell_id) != already_run_top_cell_.end()) {
+ // top cell
+ const auto &pre_top_cell = already_run_top_cell_.at(cell_id);
+ if (!pre_top_cell->is_dynamic()) {
+ MS_LOG(DEBUG) << "Top cell " << cell_id << " is not dynamic, no need to run NewGraphInner again";
+ ResetTopCellInfo(pre_top_cell, args);
+ set_top_cell(pre_top_cell);
+ return;
}
+ } else if (top_cell()->IsSubCell(cell_id) && !top_cell()->is_dynamic()) {
+ // non-top cell
+ MS_LOG(DEBUG) << "No need to run NewGraphInner again";
+ return;
}
}
+ // Init resource and df_builder
+ InitResourceAndDfBuilder(cell_id, args);
+ // Check whether cell has dynamic construct
+ bool is_dynamic = parse::DynamicAnalysis::IsDynamicCell(cell);
+ if (is_dynamic) {
+ MS_LOG(DEBUG) << "Current cell " << py::cast(cell)->ToString() << " has dynamic construct";
+ top_cell()->set_is_dynamic(is_dynamic);
+ }
}
-void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args) {
+void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, bool is_topest) {
for (const auto &arg : args) {
if (py::isinstance(arg)) {
auto tensor = arg.cast();
@@ -2140,74 +1826,28 @@ void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args)
}
}
- CleanPreMemoryInValueNode();
std::string input_args_id;
for (size_t i = 0; i < args.size(); ++i) {
input_args_id = input_args_id + GetId(args[i]) + "_";
}
- auto pre_dynamic_top_cell = GetTopCell(cell_id);
- bool is_real_dynamic = false;
- // Dynamic top cell is not nullptr
- if (pre_dynamic_top_cell != nullptr) {
- has_dynamic_cell_ = true;
- // Clear top cell
- if (pre_dynamic_top_cell->is_real_dynamic()) {
- ClearCellRes(cell_id);
- is_real_dynamic = true;
- pre_dynamic_top_cell = nullptr;
- } else {
- pre_dynamic_top_cell->set_forward_already_run(true);
- pre_dynamic_top_cell->set_input_args_id(input_args_id);
- }
- } else {
- has_dynamic_cell_ = false;
+
+ if (top_cell_list_.empty() && grad_order_ == 0) {
+ AddNestedGradOrder();
}
- op_index_map_.clear();
in_grad_process_ = true;
-
+ curr_g_ = std::make_shared();
+ PushHighOrderGraphStack();
// Init resource for new top cell
auto df_builder = std::make_shared();
- auto graph_info = std::make_shared(cell_id);
auto resource = std::make_shared();
- auto new_top_cell = std::make_shared(true, resource, df_builder, cell_id);
- new_top_cell->graph_info_map()[df_builder] = graph_info;
- new_top_cell->set_forward_already_run(true);
- new_top_cell->set_input_args_id(input_args_id);
- if (pre_dynamic_top_cell != nullptr) {
- MS_LOG(DEBUG) << "Get dynamic top cell";
- if (pre_dynamic_top_cell->is_grad()) {
- new_top_cell->set_is_grad(true);
- }
-
new_top_cell->set_cell_graph_list(pre_dynamic_top_cell->cell_graph_list()); - new_top_cell->set_graph_info_map(pre_dynamic_top_cell->graph_info_map()); - } - if (is_real_dynamic) { - MS_LOG(DEBUG) << "Get real dynamic"; - new_top_cell->set_is_real_dynamic(true); - } - set_top_cell(new_top_cell); - top_cell_list_.emplace_back(new_top_cell); + auto top_cell = std::make_shared(is_topest, resource, df_builder, cell_id); + top_cell->set_forward_already_run(true); + top_cell->set_input_args_id(input_args_id); + top_cell_list_.emplace_back(top_cell); + set_top_cell(top_cell); MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get(); } -std::string GradExecutor::GetCellOpInfo() { - if (cell_op_info_stack_.empty()) { - MS_LOG(EXCEPTION) << "The cell op info stack is empty"; - } - return cell_op_info_stack_.top(); -} - -void GradExecutor::ReplaceCellOpInfoByCellId(const std::string &cell_id) { - if (cell_id.empty()) { - MS_LOG(EXCEPTION) << "The cell id is empty"; - } - if (cell_op_info_stack_.empty()) { - MS_LOG(DEBUG) << "The cell op info stack is empty, No need replace"; - return; - } - cell_op_info_stack_.top() = cell_op_info_stack_.top() + cell_id; -} - void GradExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, bool is_param) { if (!py::isinstance(args) && !py::isinstance(args)) { @@ -2248,607 +1888,199 @@ void GradExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const p } } +void GradExecutor::SetMakeTupleAsOutputNode(const TopCellInfoPtr &top_cell, const FuncGraphPtr &curr_g, + const py::object &out) { + MS_EXCEPTION_IF_NULL(top_cell); + MS_EXCEPTION_IF_NULL(curr_g); + if (!(py::isinstance(out) || py::isinstance(out))) { + MS_LOG(EXCEPTION) << "The out of top cell should be tuple or list when set maketuple as output node"; + } + auto tuple = out.cast(); + auto tuple_size = static_cast(tuple.size()); + + // get input node and value + ValuePtrList input_args; + std::vector inputs; + inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + for (int64_t i = 0; i < tuple_size; i++) { + inputs.emplace_back(GetInput(tuple[i], false)); + input_args.emplace_back(parse::data_converter::PyDataToValue(tuple[i])); + } + auto cnode = curr_g_->NewCNode(inputs); + // record node info in graph map + auto out_id = GetId(out); + SetTupleArgsToGraphInfoMap(curr_g_, out, cnode); + SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode); + // run ad for maketuple node + ValuePtr out_value = parse::data_converter::PyDataToValue(out); + ad::GradPynativeOp(top_cell->k_pynative_cell_ptr(), cnode, input_args, out_value); + // record op info + size_t curr_op_num = top_cell->op_num(); + std::string op_info = "MakeTuple-" + std::to_string(curr_op_num); + std::string curr_op_info = top_cell->all_op_info() + "_" + op_info; + top_cell->set_all_op_info(curr_op_info); + top_cell->set_op_num(curr_op_num + 1); + MS_LOG(DEBUG) << "Tuple output node info " << cnode->DebugString(); +} + void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args) { const auto &cell_id = GetCellId(cell, args); MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id; - if (graph_stack_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) { - PopCurrentCellOpInfoFromStack(); - MS_LOG(INFO) << "Endgraph already compiled"; + if (!need_construct_graph()) { + MS_LOG(DEBUG) << "Current cell " << cell_id << " no need to run EndGraphInner again"; + if 
(cell_id == top_cell()->cell_id()) { + set_grad_flag(false); + } return; } - auto out_id = GetId(out); // x =op1, y =op2, return (x, y) + auto out_id = GetId(out); auto graph_info = top_cell()->graph_info_map().at(curr_g_); MS_EXCEPTION_IF_NULL(graph_info); if (graph_info->node_map.find(out_id) == graph_info->node_map.end()) { if (py::isinstance(out) || py::isinstance(out)) { - auto tuple = out.cast(); - auto tuple_size = static_cast(tuple.size()); - - std::vector inputs; - inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - for (int64_t i = 0; i < tuple_size; i++) { - inputs.emplace_back(GetInput(tuple[i], false)); - } - auto cnode = curr_g_->NewCNode(inputs); - SetTupleArgsToGraphInfoMap(curr_g_, out, cnode); - SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode); + SetMakeTupleAsOutputNode(top_cell(), curr_g_, out); } else { MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id; MakeValueNode(out, out_id); } } - EndGraphByOutId(cell, cell_id, out, out_id, args); -} -void GradExecutor::EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out, - const std::string &out_id, const py::args &args) { - AnfNodePtr output_node = GetObjNode(out, out_id); - curr_g_->set_output(output_node); - MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString(); - if (EndBpropGraph(cell_id)) { - MS_LOG(DEBUG) << "Get bprop function cell"; + if (IsBpropGraph(cell_id)) { + MS_LOG(DEBUG) << "Brop cell no need construct graph"; return; } - auto resource = GetResource(top_cell_id()); - MS_EXCEPTION_IF_NULL(resource); - resource->manager()->AddFuncGraph(curr_g_); - UpdateCellGraph(cell, curr_g_, cell_id, true, false); - FuncGraphPtr newfg = nullptr; - // Cell no Change - if (CheckDynamicCell(cell_id) && !CheckCellChanged(cell_id)) { - MS_LOG(DEBUG) << "Cell is fake dynamic, no need make ad grad"; - top_cell()->set_need_grad(false); - ClearCnodeRes(curr_g_->output()); - } else { - MS_LOG(DEBUG) << "Need make ad grad"; - if (!top_cell()->need_grad()) { - ClearCnodeRes(curr_g_->output()); - } - newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args); - } - - if (graph_stack_.size() > 1) { - std::vector inputs; - inputs.emplace_back(NewValueNode(curr_g_)); - - PopGraphStack(); - PopCurrentCellOpInfoFromStack(); - ReplaceCellOpInfoByCellId(cell_id); - // connect the previous graph to the inside graph - auto graph_prev = graph_stack_.top(); - for (size_t i = 0; i < args.size(); i++) { - auto input = GetInput(args[i], false); - inputs.emplace_back(input); - } - auto out_cnode = graph_prev->NewCNode(inputs); - SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args)); - SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode); - SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode); - } else { - if (newfg != nullptr) { - DumpGraphIR("before_resolve.ir", newfg); - parse::ResolveFuncGraph(newfg, resource); - DumpGraphIR("after_resolve.ir", newfg); - resource->set_func_graph(newfg); - } - PopGraphStack(); - PopCurrentCellOpInfoFromStack(); - ClearDynamicTopRes(cell_id); + DoGradForCustomBprop(cell, out, args); + PopCellStack(); + // Not first call + bool is_nested_grad = IsNestedGrad(); + if ((cell_stack_.size() > 1 && !is_nested_grad) || (is_nested_grad && cell_stack_.size() != cell_nums())) { + MS_LOG(DEBUG) << "Sub cell no need construct graph"; + return; } -} -bool GradExecutor::EndBpropGraph(const string &cell_id) { - auto is_bprop_graph = IsBpropGraph(cell_id); - if (is_bprop_graph) { - if (!IsNestedGrad()) { - PopGraphStack(); - 
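// Standalone sketch of the cell-id matching used in this patch: EndGraphInner
// above compares ids exactly, while IsCellObjIdEq earlier compares only the
// object-id prefix, assuming ids are laid out as "<obj_id>_<argument ids...>".
#include <cassert>
#include <string>

// Compare two cell ids by their object-id prefix only, ignoring the argument
// part after the first '_'. Mirrors IsCellObjIdEq in the diff.
bool CellObjIdEqual(const std::string &l_cell_id, const std::string &r_cell_id) {
  size_t obj_id_end_idx = l_cell_id.find('_');
  if (obj_id_end_idx == std::string::npos) {
    obj_id_end_idx = l_cell_id.length();
  }
  return l_cell_id.compare(0, obj_id_end_idx, r_cell_id, 0, obj_id_end_idx) == 0;
}

int main() {
  // Hypothetical ids: the same cell object called with different argument shapes.
  assert(CellObjIdEqual("cellA_2x3F32", "cellA_4x4F16"));
  assert(!CellObjIdEqual("cellA_2x3F32", "cellB_2x3F32"));
  return 0;
}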
PopCurrentCellOpInfoFromStack(); - ReplaceCellOpInfoByCellId(cell_id); - } - return true; + // Reset grad flag and checkout whether need to compile graph when top cell has ran finished + if (cell_stack_.empty() && cell_id == top_cell()->cell_id()) { + set_grad_flag(false); + (void)CheckNeedCompileGraph(); } - return false; } -bool GradExecutor::CheckCellChanged(const std::string &cell_id) { - bool res = false; - if (top_cell()->is_real_dynamic()) { - MS_LOG(DEBUG) << "Cur cell " << cell_id << " is dynamic, no need check"; - return true; - } - if (GetCellOpInfo().empty()) { - MS_LOG(DEBUG) << "Cell op info is empty"; - return true; - } - - auto it = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); - if (it == top_cell()->cell_graph_list().end() || IsFirstGradStep()) { - return true; - } - MS_LOG(DEBUG) << "Cell op info " << GetCellOpInfo() << ", old " << (*it)->cell_ops_info().at((*it)->call_times()); - if ((*it)->cell_ops_info().at((*it)->call_times()) != GetCellOpInfo()) { - res = true; - top_cell()->set_is_real_dynamic(true); - MS_LOG(DEBUG) << "Cell self changed"; - } - if ((*it)->call_times() < ((*it)->cell_ops_info().size() - 1)) { - (*it)->set_call_times((*it)->call_times() + 1); - } else { - (*it)->set_call_times(0); - } - return res; -} - -bool GradExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, - bool need_cloned, bool is_grad) { +void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args) { if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { - return false; - } - auto update_in_endgraph = need_cloned && !is_grad; - // Bprop just save backward graph - auto it = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); - if (it != top_cell()->cell_graph_list().end()) { - if (top_cell_id() == cell_id) { - top_cell()->set_is_grad(is_grad); - } - if (g != (*it)->fg()) { - top_cell()->graph_info_map().update((*it)->fg(), g); - (*it)->set_fg(g); - } - if (update_in_endgraph && IsFirstGradStep()) { - (*it)->cell_ops_info().emplace_back(GetCellOpInfo()); - } - MS_LOG(DEBUG) << "Update bprop bg cell id " << cell_id; - } else { - py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME); - auto bprop_func_cell_id = GetId(bprop_func); - MS_LOG(DEBUG) << "Add new bprop cell_id " << cell_id << " bprop func cell id " << bprop_func_cell_id - << " cell ops info " << GetCellOpInfo(); - auto cell_info = std::make_shared(true, has_dynamic_cell_, g, cell_id, bprop_func_cell_id); - cell_info->cell_ops_info().emplace_back(GetCellOpInfo()); - if (in_bprop_process_) { - top_cell()->cell_graph_list().emplace_back(cell_info); - } else { - top_cell()->cell_graph_list().insert(top_cell()->cell_graph_list().begin(), cell_info); - } - } - return true; -} - -void GradExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, - bool need_cloned, bool is_grad) { - auto update_in_endgraph = need_cloned && !is_grad; - if (UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad)) { return; } - FuncGraphPtr tmp = g; - if (!IsFirstGradStep() && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) { - MS_LOG(DEBUG) << "No need cloned"; - need_cloned = false; - } - auto clone_fn = [&g, &tmp, need_cloned, this]() { - if 
(!need_cloned) { - return; - } - tmp = BasicClone(g); - top_cell()->graph_info_map().update(g, tmp); - ClearCnodeRes(tmp->output()); - }; - // First call or cell id not exist - if (update_in_endgraph && (IsFirstGradStep() || !CheckCellGraph(cell_id))) { - if (!CheckCellGraph(cell_id)) { - clone_fn(); - MS_LOG(DEBUG) << "Add new cell with cloned graph " << cell_id << " cell ops info " << GetCellOpInfo(); - auto cell_info = std::make_shared(true, has_dynamic_cell_, tmp, cell_id, ""); - cell_info->cell_ops_info().emplace_back(GetCellOpInfo()); - if (in_bprop_process_) { - top_cell()->cell_graph_list().emplace_back(cell_info); - } else { - top_cell()->cell_graph_list().insert(top_cell()->cell_graph_list().begin(), cell_info); - } - } else { - auto it = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); - if (it != top_cell()->cell_graph_list().end()) { - (*it)->cell_ops_info().emplace_back(GetCellOpInfo()); - } - MS_LOG(DEBUG) << "Add another same cell ops info"; - } + custom_bprop_cell_count_ -= 1; + if (custom_bprop_cell_count_ != 0) { return; } + py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME); + auto fake_prim = std::make_shared(prim::kPrimHookBackward->name(), py::object()); + fake_prim->set_hook(bprop_func); + const auto &cell_id = GetCellId(cell, args); + (void)fake_prim->AddAttr("cell_id", MakeValue(cell_id)); + (void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true)); - for (auto &it : top_cell()->cell_graph_list()) { - if (it->cell_id() != cell_id) { - continue; - } - it->set_is_dynamic(has_dynamic_cell_); - if (need_cloned) { - clone_fn(); - if (it->fg() != nullptr) { - top_cell()->graph_info_map().erase(it->fg()); - } - MS_LOG(DEBUG) << "Update cur graph " << it->fg().get() << " with cloned new " << tmp.get(); - it->set_fg(tmp); - } - if (!need_cloned && !is_grad) { - top_cell()->graph_info_map().erase(it->fg()); - MS_LOG(DEBUG) << "Update cur graph " << it->fg().get() << " with new " << tmp.get(); - it->set_fg(tmp); - } - break; + py::object code_obj = py::getattr(bprop_func, "__code__"); + // Three parameters self, out and dout need to be excluded + const size_t inputs_num = py::cast(py::getattr(code_obj, "co_argcount")) - 3; + if (inputs_num > args.size()) { + MS_LOG(EXCEPTION) << "Size of bprop func inputs[" << inputs_num << "] is larger than size of cell inputs[" + << args.size() << "]"; } + + py::list cell_inputs; + for (size_t i = 0; i < inputs_num; i += 1) { + cell_inputs.append(args[i]); + } + OpExecInfoPtr op_exec_info = std::make_shared(); + op_exec_info->op_name = fake_prim->name(); + op_exec_info->py_primitive = fake_prim; + op_exec_info->op_inputs = cell_inputs; + + abstract::AbstractBasePtrList args_spec_list; + std::vector op_masks; + auto cnode = forward()->MakeCNode(op_exec_info, &op_masks, &args_spec_list); + DoOpGrad(op_exec_info, cnode, out); + const std::string out_obj_id = GetId(out); + SaveOutputNodeMap(out_obj_id, out, cnode); } -void GradExecutor::ClearCnodeRes(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - std::unordered_set node_set; - std::function fn; - fn = [&fn, &node_set](const AnfNodePtr &node) { - if (!node->isa() || node_set.find(node) != node_set.end()) { - return; - } - node_set.insert(node); - auto cnode = node->cast(); - cnode->clear_inputs_value(); - cnode->set_forward(nullptr, ""); - for (size_t i = 0; i < cnode->size(); ++i) { - auto n = cnode->input(i); - fn(n); - } - }; - 
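// Sketch of the co_argcount arithmetic used for the custom bprop above, assuming
// a standalone program with pybind11's embedded interpreter; self, out and dout
// account for the three excluded parameter slots.
#include <pybind11/embed.h>
#include <iostream>

namespace py = pybind11;

int main() {
  py::scoped_interpreter guard{};
  // A toy cell with a custom bprop: bprop(self, x, y, out, dout).
  py::exec(R"(
class Net:
    def bprop(self, x, y, out, dout):
        return dout, dout
)");
  py::object bprop_func = py::eval("Net.bprop");
  py::object code_obj = py::getattr(bprop_func, "__code__");
  // self, out and dout are excluded, leaving the number of forward inputs.
  size_t inputs_num = py::cast<size_t>(py::getattr(code_obj, "co_argcount")) - 3;
  std::cout << "custom bprop takes " << inputs_num << " forward inputs" << std::endl;  // prints 2
  return 0;
}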
fn(node); - node_set.clear(); -} - -FuncGraphPtr GradExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r, - const std::string &cell_id, const py::args &args) { - bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME); - if (is_custom_bprop) { - size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size(); - if (par_number > 0) { - MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number - << " parameters that is not supported in the net."; - } - MS_LOG(INFO) << "Use cell custom bprop function."; - FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell); - if (bprop_graph != nullptr) { - (void)g->transforms().emplace(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); - (void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g))); - } - } - auto is_top = IsTopGraph(cell_id); - MS_LOG(DEBUG) << "Grad top cell " << is_top; - DumpGraphIR("fg.ir", g); - // Before make grad graph, we need to run auto-monad on forward graph, - // so that side effects in forward graph can be handled in grad graph. - (void)pipeline::AutoMonad(g); - set_need_replace_forward(!IsNestedGrad()); - // Obtain grad graph - auto newfg = ad::Grad(g, r, is_top); - - if (is_custom_bprop) { - auto params = newfg->parameters(); - auto manager = Manage({newfg}, false); - if (args.size() > params.size()) { - MS_EXCEPTION(TypeError) << "The number of arguments " << args.size() - << " is more than the number of parameters required, which is " << params.size(); - } - for (size_t i = 0; i < args.size(); i++) { - ValuePtr value = PyAttrValue(args[i]); - auto v_node = NewValueNode(value); - manager->Replace(params[i], v_node); - } - UpdateCellGraph(cell, newfg, cell_id, false, false); - } - return newfg; -} - -std::string GradExecutor::GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, - py::object *forward_args, py::object *sens) { - auto size = args.size(); - size_t forward_args_size = size; +std::string GradExecutor::GetGradCellId(bool has_sens, const py::object &cell, const py::args &args) { + py::args forward_args = args; + size_t forward_args_size = args.size(); if (has_sens) { - if (size >= 1) { - --forward_args_size; - if (sens != nullptr) { - *sens = args[forward_args_size]; - } - } + forward_args_size--; py::tuple f_args(forward_args_size); for (size_t i = 0; i < forward_args_size; ++i) { f_args[i] = args[i]; } - *forward_args = f_args; + forward_args = f_args; } - const auto &cell_id = GetCellId(cell, *forward_args); + const auto &cell_id = GetCellId(cell, forward_args); return cell_id; } -void GradExecutor::SaveAllValueNodeTensors(const FuncGraphPtr &graph) { - std::unordered_set all_value_node_tensors; - auto trace_function = [&all_value_node_tensors](const AnfNodePtr &anf_node) { - auto value = GetValueNode(anf_node); - if (value) { - if (value->isa()) { - auto tensor = value->cast(); - MS_EXCEPTION_IF_NULL(tensor); - if (tensor->device_address()) { - all_value_node_tensors.emplace(tensor); - } - } else if (value->isa()) { - auto tuple = value->cast(); - MS_EXCEPTION_IF_NULL(tuple); - for (size_t i = 0; i < tuple->size(); i++) { - if ((*tuple)[i]->isa()) { - auto tensor = (*tuple)[i]->cast(); - MS_EXCEPTION_IF_NULL(tensor); - if (tensor->device_address()) { - all_value_node_tensors.emplace(tensor); - } - } - } - } - } - return FOLLOW; - }; - (void)TopoSort(graph->get_return(), SuccDeeperSimple, trace_function); - 
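// Sketch of the sens handling in GetGradCellId above: when sens_param is set,
// the trailing sens argument is dropped so it does not become part of the cell id.
#include <string>
#include <vector>

// Keep only the forward arguments (their ids here) when a sens value is appended.
std::vector<std::string> ForwardArgsOnly(bool has_sens, std::vector<std::string> arg_ids) {
  if (has_sens && !arg_ids.empty()) {
    arg_ids.pop_back();  // the trailing sens argument does not identify the cell
  }
  return arg_ids;
}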
all_value_node_tensors_ = all_value_node_tensors; -} - void GradExecutor::GradNetInner(py::object *ret, const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args) { + MS_EXCEPTION_IF_NULL(grad); auto size = args.size(); - py::object sens = py::none(); - py::object forward_args = args; - const auto &cell_id = GetGradCellId(grad->sens_param(), cell, args, &forward_args, &sens); - SetTopCellTensorId(cell_id); + const auto &cell_id = GetGradCellId(grad->sens_param(), cell, args); MS_LOG(DEBUG) << "GradNet start " << size << " " << cell_id; - const auto ¶ms_changed = CheckGradParamsChanged(cell_id, weights, sens); - if (!params_changed && IsGradBefore(cell_id) && !CheckRealDynamicCell(cell_id)) { - UpdateTopCellInfo(cell_id, false); - op_index_map_.clear(); - MS_LOG(INFO) << "Gradgraph already compiled"; + if (!top_cell()->need_compile_graph()) { + MS_LOG(DEBUG) << "No need compile graph"; + UpdateTopCellInfo(false, false, true); return; } - // Nested graph - if (CheckCellGraph(cell_id) && !graph_stack_.empty()) { - MS_LOG(DEBUG) << "Set nested top graph"; - SetNestedTopGraph(cell, forward_args, cell_id); - } - auto df_builder = GetDfbuilder(cell_id); MS_EXCEPTION_IF_NULL(df_builder); auto resource = GetResource(cell_id); MS_EXCEPTION_IF_NULL(resource); MS_LOG(DEBUG) << "df_builder ptr " << df_builder.get() << " resource ptr " << resource.get(); - // Set all params(input+weights) - SetGradGraphParams(df_builder, resource, size); // Get params(weights) require derivative - auto w_args = GetWeightsArgs(weights, df_builder); - // Get the parameters items and add the value to args_spec - auto args_spec = GetArgsSpec(args, df_builder); - resource->set_args_spec(args_spec); - // Get real grad graph - DumpGraphIR("before_grad.ir", resource->func_graph()); - SetGradGraph(resource->func_graph(), grad, w_args, size, cell_id); - DumpGraphIR("after_grad.ir", df_builder); - resource->set_func_graph(df_builder); - resource->manager()->KeepRoots({df_builder}); + auto w_args = GetWeightsArgs(grad, weights, df_builder); + // Get bprop graph of top cell + auto bprop_graph = GetBpropGraph(grad, w_args, size, args); + resource->set_func_graph(bprop_graph); + auto manager = resource->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(bprop_graph, true); + DumpGraphIR("launch_bprop_graph.ir", bprop_graph); + // Launch bprop graph to backend + SaveForwardTensorInfoInBpropGraph(resource); resource->results()[pipeline::kBackend] = compile::CreateBackend(); - - MS_LOG(INFO) << "Start opt"; - if (has_dynamic_cell_) { - SaveAllValueNodeTensors(resource->func_graph()); - } - PynativeOptimizeAction(resource); - DumpGraphIR("after_opt.ir", resource->func_graph()); - SaveTensorsInValueNode(resource); + MS_LOG(DEBUG) << "Start task emit action"; TaskEmitAction(resource); + MS_LOG(DEBUG) << "Start execute action"; ExecuteAction(resource); - ClearUselessRes(df_builder, cell, cell_id); - UpdateTopCellInfo(cell_id, true); + MS_LOG(DEBUG) << "Start update top cell info when run finish"; + UpdateTopCellInfo(false, false, true); resource->Clean(); } -void GradExecutor::ClearDynamicTopRes(const std::string &cell_id) { - // Delete unused top cell resource - if (!CheckDynamicCell(cell_id)) { - return; - } - auto count = std::count_if(top_cell_list_.begin(), top_cell_list_.end(), - [&cell_id](const TopCellInfoPtr &value) { return value->cell_id() == cell_id; }); - if (count < 2) { - return; - } - // Keep only one dynamic top cell - bool is_sedond_find = false; - for (auto it = 
top_cell_list_.begin(); it != top_cell_list_.end(); ++it) { - if ((*it)->cell_id() != cell_id) { - continue; - } - - if (top_cell()->is_real_dynamic()) { - MS_LOG(DEBUG) << "Real dynamic, delete first dynamic top cell"; - (*it)->clear(); - it = top_cell_list_.erase(it); - break; - } else { - if (is_sedond_find) { - MS_LOG(DEBUG) << "Fake dynamic, delete second dynamic top cell"; - (*it)->clear(); - it = top_cell_list_.erase(it); - break; - } - is_sedond_find = true; - } - } -} - -bool GradExecutor::CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, - const py::object &sens) { - bool res = false; - auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), - [&cell_id](const TopCellInfoPtr &value) { return value->cell_id() == cell_id; }); - if (it == top_cell_list_.end()) { - return res; - } - - auto fn = [](const py::object &arg) { - std::string arg_id; - if (py::isinstance(arg)) { - auto tensor_ptr = py::cast(arg); - auto dtype = tensor_ptr->data_type(); - auto shape = tensor_ptr->shape(); - std::stringstream ss; - std::for_each(shape.begin(), shape.end(), [&ss](int i) { ss << i; }); - arg_id = ss.str() + std::to_string(dtype); - } else { - arg_id = std::string(py::str(arg)); - } - return arg_id; - }; - - std::string sens_id = "sens"; - if (!py::isinstance(sens)) { - sens_id = fn(sens); - } - - if (!(*it)->sens_id().empty() && (*it)->sens_id() != sens_id) { - (*it)->set_sens_id(sens_id); - } - std::string weights_id = fn(weights); - if (!(*it)->weights_id().empty() && (*it)->weights_id() != weights_id) { - (*it)->set_weights_id(weights_id); - res = true; - } - return res; -} - -void GradExecutor::SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id) { - if (IsTopGraph(cell_id)) { - ClearCellRes(cell_id); - } - ResourcePtr resource = nullptr; - auto ia = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), - [&cell_id](const TopCellInfoPtr &value) { return value->cell_id() == cell_id; }); - if (ia != top_cell_list_.end()) { - resource = GetResource((*ia)->cell_id()); - MS_EXCEPTION_IF_NULL(resource); - MS_LOG(DEBUG) << "Find old resource " << resource.get(); - } - if (resource == nullptr) { - resource = std::make_shared(); - MS_LOG(DEBUG) << "Make new resource " << resource.get(); - } - MS_EXCEPTION_IF_NULL(resource); - FuncGraphPtr df_builder = std::make_shared(); - auto graph_info = std::make_shared(cell_id); - auto top_cell_info = std::make_shared(false, resource, df_builder, cell_id); - top_cell()->graph_info_map()[df_builder] = graph_info; - top_cell_list_.emplace_back(top_cell_info); - FuncGraphPtr forward_graph = nullptr; - auto ib = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); - if (ib != top_cell()->cell_graph_list().end()) { - forward_graph = (*ib)->fg(); - } - MS_EXCEPTION_IF_NULL(forward_graph); - if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { - DumpGraphIR("nested_bprop.ir", forward_graph); - // Custom bprop get backward graph(before opt), which use like other forward graph - curr_g_ = forward_graph; - resource->set_func_graph(forward_graph); - return; - } - - // Copy weights parameters - ReplaceGraphParams(df_builder, forward_graph, cell_id); - resource->manager()->AddFuncGraph(forward_graph); - DumpGraphIR("nested_fg.ir", forward_graph); - set_need_replace_forward(false); - auto newfg = MakeGradGraph(cell, forward_graph, resource, cell_id, args); - 
resource->set_func_graph(newfg); -} - -void GradExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph, - const std::string &cell_id) { - std::vector graph_before{}; - bool index_find = false; - for (const auto &it : top_cell()->cell_graph_list()) { - if (IsBpropGraph(it->cell_id()) || it->fg() == nullptr) { - continue; - } - if (index_find) { - graph_before.emplace_back(it->fg()); - continue; - } - if (it->cell_id() == cell_id) { - index_find = true; - graph_before.emplace_back(it->fg()); - } - } - - auto manager = Manage({forward_graph}, false); - for (const auto &f : graph_before) { - auto graph_info = top_cell()->graph_info_map().at(f); - MS_EXCEPTION_IF_NULL(graph_info); - for (const auto &it : graph_info->params) { - if (!it.second->has_default()) { - continue; - } - auto new_param = df_builder->add_parameter(); - new_param->set_abstract(it.second->abstract()); - new_param->set_name(it.second->name()); - new_param->set_default_param(it.second->default_param()); - ScopePtr scope = (it.second->scope() != kDefaultScope) ? it.second->scope() : kDefaultScope; - new_param->set_scope(scope); - manager->Replace(it.second, new_param); - replace_weights_map_[forward_graph].emplace_back(std::make_pair(it.second, new_param)); - MS_LOG(DEBUG) << "Param name " << new_param->name() << " ptr " << new_param.get(); - - auto graph_info_of_df_builder = top_cell()->graph_info_map().at(df_builder); - MS_EXCEPTION_IF_NULL(graph_info_of_df_builder); - graph_info_of_df_builder->params[it.first] = new_param; - SetParamNodeMapInGraphInfoMap(df_builder, it.first, new_param); - SetNodeMapInGraphInfoMap(df_builder, it.first, new_param); - } - } - graph_before.clear(); -} - -void GradExecutor::SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size) { - std::vector new_params; - for (size_t i = 0; i < size; i++) { - ParameterPtr p = std::make_shared(df_builder); - new_params.emplace_back(p); - } - MS_LOG(DEBUG) << "GradNet weight param size " << df_builder->parameters().size(); - // df_builder_->parameters() set in GetInput, which are weights params - new_params.insert(new_params.end(), df_builder->parameters().begin(), df_builder->parameters().end()); - df_builder->set_parameters(new_params); - resource->manager()->SetParameters(df_builder, new_params); -} - -std::vector GradExecutor::GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder) { - std::vector w_args; - if (!py::hasattr(weights, "__parameter_tuple__")) { +std::vector GradExecutor::GetWeightsArgs(const GradOperationPtr &grad, const py::object &weights, + const FuncGraphPtr &df_builder) { + MS_EXCEPTION_IF_NULL(grad); + MS_EXCEPTION_IF_NULL(df_builder); + if (!grad->get_by_list_ && py::isinstance(weights)) { + MS_LOG(DEBUG) << "The input weight is None when run GradNetInner. 
Return parameters of df_builder directly"; + return df_builder->parameters(); + } else if (!py::hasattr(weights, "__parameter_tuple__")) { MS_LOG(DEBUG) << "No paramter_tuple get"; return {}; } + auto tuple = weights.cast(); MS_LOG(DEBUG) << "Get weights tuple size " << tuple.size(); - w_args.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + std::vector w_args; for (size_t it = 0; it < tuple.size(); ++it) { auto param = tuple[it]; auto param_id = GetId(param); - AnfNodePtr para_node = nullptr; auto graph_info = top_cell()->graph_info_map().at(df_builder); MS_EXCEPTION_IF_NULL(graph_info); + AnfNodePtr para_node = nullptr; if (graph_info->params.find(param_id) != graph_info->params.end() && graph_info->node_map.find(param_id) != graph_info->node_map.end()) { para_node = graph_info->node_map[param_id].first; @@ -2870,114 +2102,70 @@ std::vector GradExecutor::GetWeightsArgs(const py::object &weights, return w_args; } -abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder) { - abstract::AbstractBasePtrList args_spec; +abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph) { + MS_EXCEPTION_IF_NULL(bprop_graph); std::size_t size = args.size(); - auto df_params = df_builder->parameters(); - if (df_params.size() < size) { - MS_LOG(EXCEPTION) << "Df parameters size " << df_params.size() << " less than " << size; - } - // input params - for (std::size_t i = 0; i < size; i++) { - ValuePtr converted = nullptr; - bool succ = parse::ConvertData(args[i], &converted); - if (!succ) { - MS_LOG(EXCEPTION) << "Args convert error"; - } - bool broaden = true; - auto abs = abstract::FromValue(converted, broaden); - args_spec.emplace_back(abs); - auto param_node = std::static_pointer_cast(df_params[i]); - param_node->set_abstract(abs); - } - // weights params - for (const auto ¶m : df_params) { + abstract::AbstractBasePtrList args_spec; + auto bprop_params = bprop_graph->parameters(); + if (bprop_params.size() < size) { + MS_LOG(EXCEPTION) << "Df parameters size " << bprop_params.size() << " less than " << size; + } + // Update abstract info for parameters in bprop graph + size_t index = 0; + for (const auto ¶m : bprop_params) { auto param_node = std::static_pointer_cast(param); + MS_EXCEPTION_IF_NULL(param_node); if (param_node->has_default()) { + // update abstract info for weights ValuePtr value = param_node->default_param(); auto ptr = value->ToAbstract(); MS_EXCEPTION_IF_NULL(ptr); args_spec.emplace_back(ptr); - param_node->set_abstract(ptr); + param_node->set_abstract(ptr->Broaden()); + } else { + // update abstract info for input params + ValuePtr input_value = parse::data_converter::PyDataToValue(args[index]); + MS_EXCEPTION_IF_NULL(input_value); + auto abs = abstract::FromValue(input_value, true); + args_spec.emplace_back(abs); + param_node->set_abstract(abs->Broaden()); + index++; } } MS_LOG(DEBUG) << "Args_spec size " << args_spec.size(); return args_spec; } -void GradExecutor::SetGradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, - const std::vector &weights, size_t arg_size, const std::string &cell_id) { - FuncGraphPtr top_g = nullptr; - auto it = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); - if (it != top_cell()->cell_graph_list().end()) { - top_g = (*it)->fg(); - } - MS_EXCEPTION_IF_NULL(top_g); - auto nparam = top_g->parameters().size(); - 
MS_LOG(DEBUG) << "Top graph input params size " << nparam; +FuncGraphPtr GradExecutor::GetBpropGraph(const GradOperationPtr &grad, const std::vector &weights, + size_t arg_size, const py::args &args) { + MS_EXCEPTION_IF_NULL(grad); + auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr(); + MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr); + auto bprop_graph = + ad::GradPynativeCellEnd(k_pynative_cell_ptr, weights, grad->get_all_, grad->get_by_list_, grad->sens_param_); + MS_EXCEPTION_IF_NULL(bprop_graph); + + MS_LOG(DEBUG) << "Top graph input params size " << arg_size; std::ostringstream ss; - ss << "grad{" << nparam << "}"; - auto df_builder = GetDfbuilder(cell_id); - MS_EXCEPTION_IF_NULL(df_builder); - auto resource = GetResource(cell_id); - MS_EXCEPTION_IF_NULL(resource); - df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true); - df_builder->debug_info()->set_name(ss.str()); - - auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g->parameters(), weights); - std::vector inputs = {NewValueNode(df)}; - auto df_params = df_builder->parameters(); - if (df_params.size() < arg_size) { - MS_LOG(EXCEPTION) << "Df parameters size " << df_params.size() << " less than " << arg_size; - } - for (size_t i = 0; i < arg_size; ++i) { - inputs.emplace_back(df_params[i]); - } - auto out = df_builder->NewCNode(inputs); - df_builder->set_output(out); - resource->manager()->AddFuncGraph(df); - resource->manager()->AddFuncGraph(df_builder); -} - -void GradExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id) { - top_cell()->graph_info_map().erase(df_builder); - bool has_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME); - bool is_dynamic_top_fist_grad = CheckDynamicCell(cell_id) && IsFirstGradStep(); - bool is_topmost = IsTopestGraph(cell_id); - if (has_custom_bprop || is_dynamic_top_fist_grad || !is_topmost) { - return; - } + ss << "grad{" << arg_size << "}"; + bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + bprop_graph->debug_info()->set_name(ss.str()); + // Get the parameters items and add the value to args_spec + auto args_spec = GetArgsSpec(args, bprop_graph); - MS_LOG(DEBUG) << "Update topmost cell graph list and graph info map"; - // Clear grad()->top_cell()->graph_info_map() - std::vector l{}; - bool index_find = false; - for (auto &it : top_cell()->cell_graph_list()) { - if (it->fg() != nullptr) { - ClearCnodeRes(it->fg()->output()); - it->set_fg(nullptr); - } - if (index_find) { - it->set_fg(nullptr); - l.emplace_back(it->cell_id()); - continue; - } - if (it->cell_id() == cell_id) { - index_find = true; - it->set_fg(nullptr); - l.emplace_back(it->cell_id()); - } - } - for (const auto &it : l) { - for (auto ic = top_cell()->graph_info_map().begin(); ic != top_cell()->graph_info_map().end();) { - if (ic->second->cell_id.find(it) != std::string::npos) { - ic = top_cell()->graph_info_map().erase(ic); - } else { - ++ic; - } - } + // Do opt for final bprop graph + ResourcePtr resource = std::make_shared(); + resource->set_func_graph(bprop_graph); + auto manager = resource->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(bprop_graph); + auto optimized_bg = pipeline::PrimBpropOptimizer::GetPrimBpropOptimizerInst().BpropGraphFinalOpt(resource); + + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { + DumpIR("after_final_opt.ir", optimized_bg); } + optimized_bg->ClearAllManagerInfo(); + return optimized_bg; } py::object GradExecutor::CheckGraph(const py::object &cell, const py::args &args) { @@ -2987,28 
+2175,22 @@ py::object GradExecutor::CheckGraph(const py::object &cell, const py::args &args MS_LOG(DEBUG) << "Grad not running yet"; return BaseRefToPyData(ret); } - const auto &cell_id = GetCellId(cell, args); - std::string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size())); - MS_LOG(DEBUG) << "Key is " << key; - for (auto it = top_cell()->cell_graph_list().begin(); it != top_cell()->cell_graph_list().end(); ++it) { - MS_LOG(DEBUG) << "Cur cell id " << (*it)->cell_id(); - if (key != (*it)->cell_id().substr(0, std::min(PTR_LEN, (*it)->cell_id().size()))) { - continue; - } - MS_LOG(DEBUG) << "Delete cellid from cell graph list"; - top_cell()->graph_info_map().erase((*it)->fg()); - top_cell()->cell_graph_list().erase(it); - if (IsTopestGraph(cell_id)) { - ClearCellRes(cell_id); - } - ret = true; - break; - } return BaseRefToPyData(ret); } +TopCellInfoPtr GradExecutor::GetTopCell(const string &cell_id) { + auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) { + return cell_id == value->cell_id() && value->is_topest(); + }); + if (it != top_cell_list_.end()) { + return *it; + } + return nullptr; +} + py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) { bool forward_run = false; + // Get cell id and input args info const auto &cell_id = grad_executor()->GetCellId(cell, args); std::string input_args_id; for (size_t i = 0; i < args.size(); ++i) { @@ -3016,24 +2198,49 @@ py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::a } auto top_cell = grad_executor()->GetTopCell(cell_id); if (top_cell != nullptr) { - if (!top_cell->input_args_id().empty() && top_cell->input_args_id() != input_args_id && - top_cell->forward_already_run() && top_cell->has_dynamic_cell()) { + forward_run = top_cell->forward_already_run(); + bool input_args_changed = !top_cell->input_args_id().empty() && top_cell->input_args_id() != input_args_id; + if (forward_run && input_args_changed && top_cell->is_dynamic()) { MS_LOG(WARNING) << "The construct of running cell is dynamic and the input info of this cell has changed, " "forward process will run again"; - top_cell->set_forward_already_run(false); - top_cell->set_input_args_id(input_args_id); - } else { - forward_run = top_cell->forward_already_run() && !top_cell->is_real_dynamic(); - } - if (forward_run) { - grad_executor()->set_top_cell(top_cell); + forward_run = false; } - MS_LOG(DEBUG) << " Top cell id " << top_cell->cell_id(); } - MS_LOG(DEBUG) << "Graph have already run " << forward_run << " cell id " << cell_id; + MS_LOG(DEBUG) << "Graph has already run " << forward_run << " top cell id " << cell_id; return BaseRefToPyData(forward_run); } +bool GradExecutor::CheckNeedCompileGraph() { + auto new_top_cell = top_cell(); + bool ret = true; + std::string top_cell_id = new_top_cell->cell_id(); + if (already_run_top_cell_.find(top_cell_id) == already_run_top_cell_.end()) { + MS_LOG(DEBUG) << "Top cell " << top_cell_id << " has never been run, need to compile graph"; + already_run_top_cell_[top_cell_id] = new_top_cell; + } else { + MS_LOG(DEBUG) << "Top cell " << top_cell_id << " has been run"; + auto pre_top_cell = already_run_top_cell_.at(top_cell_id); + auto pre_all_op_info = pre_top_cell->all_op_info(); + auto new_all_op_info = new_top_cell->all_op_info(); + MS_LOG(DEBUG) << "Pre all op info " << pre_all_op_info; + MS_LOG(DEBUG) << "New all op info " << new_all_op_info; + if (pre_all_op_info != new_all_op_info) { + MS_LOG(DEBUG) << "The op info 
has been changed, need to compile graph again"; + EraseTopCellFromTopCellList(pre_top_cell); + pre_top_cell->clear(); + already_run_top_cell_[top_cell_id] = new_top_cell; + } else { + EraseTopCellFromTopCellList(new_top_cell); + new_top_cell->clear(); + set_top_cell(already_run_top_cell_.at(top_cell_id)); + MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again"; + ret = pre_top_cell->need_compile_graph(); + } + } + top_cell()->set_need_compile_graph(ret); + return ret; +} + void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args, const py::object &phase) { MS_EXCEPTION_IF_NULL(ret); @@ -3042,8 +2249,7 @@ void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const p auto has_sens = std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) { return cell_id.find(value->cell_id()) != std::string::npos && cell_id != value->cell_id(); }); - py::object forward_args = args; - cell_id = GetGradCellId(has_sens, cell, args, &forward_args); + cell_id = GetGradCellId(has_sens, cell, args); MS_LOG(DEBUG) << "Run has sens " << has_sens << " forward cell id " << cell_id; auto resource = GetResource(cell_id); MS_EXCEPTION_IF_NULL(resource); @@ -3068,10 +2274,7 @@ void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const p set_grad_runing(false); MS_LOG(DEBUG) << "Eval run end " << value.ToString(); *ret = BaseRefToPyData(value); - auto do_vm_compiled = std::any_of( - top_cell_list_.begin(), top_cell_list_.end(), - [&cell_id](const TopCellInfoPtr &value) { return value->cell_id() == cell_id && value->vm_compiled(); }); - if (do_vm_compiled) { + if (top_cell()->vm_compiled()) { if (MakeBpropNestedCnode(cell, *ret, cell_id)) { return; } @@ -3080,14 +2283,13 @@ void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const p } bool GradExecutor::MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id) { - if (graph_stack_.empty() || !py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { + if (cell_stack_.empty() || !py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { MS_LOG(DEBUG) << "No nested bprop grad find"; return false; } auto out_id = GetId(out); std::vector inputs; inputs.emplace_back(NewValueNode(curr_g_)); - PopGraphStack(); auto graph_info = top_cell()->graph_info_map().at(curr_g_); MS_EXCEPTION_IF_NULL(graph_info); for (const auto &ig : graph_info->params) { @@ -3104,75 +2306,99 @@ bool GradExecutor::MakeBpropNestedCnode(const py::object &cell, const py::object void GradExecutor::MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource, const py::object &out, bool has_sens) { - if (graph_stack_.empty()) { + if (cell_stack_.empty()) { MS_LOG(DEBUG) << "No nested grad find"; return; } - auto graph_prev = graph_stack_.top(); - MS_EXCEPTION_IF_NULL(graph_prev); - MS_LOG(DEBUG) << "Get pre graph ptr " << graph_prev.get(); - auto newfg = resource->func_graph(); - MS_EXCEPTION_IF_NULL(newfg); auto inputs_size = args.size(); if (has_sens) { inputs_size -= 1; } - std::vector inputs; - inputs.emplace_back(NewValueNode(newfg)); + py::tuple f_args(inputs_size); for (size_t i = 0; i < inputs_size; ++i) { - inputs.emplace_back(GetInput(args[i], false)); + f_args[i] = args[i]; } - if (newfg->parameters().size() > args.size()) { - RecoverGraphParams(newfg, cell_id, &inputs); + PopHighOrderGraphStack(); + auto graph_prev = high_order_stack_.top(); + MS_EXCEPTION_IF_NULL(graph_prev); + 
MS_LOG(DEBUG) << "Get pre graph ptr " << graph_prev.get(); + top_cell_list_.pop_back(); + set_top_cell(top_cell_list_.back()); + auto first_grad_fg = resource->func_graph(); + MS_EXCEPTION_IF_NULL(first_grad_fg); + + std::vector inputs{NewValueNode(first_grad_fg)}; + // new_fg used for grad + FuncGraphPtr new_fg = std::make_shared(); + for (size_t i = 0; i < inputs_size; ++i) { + auto new_param = new_fg->add_parameter(); + inputs.emplace_back(new_param); + } + auto cnode = new_fg->NewCNode(inputs); + new_fg->set_output(cnode); + DumpIR("new_fg.ir", new_fg); + // Get high order grad graph + ResourcePtr r = std::make_shared(); + r->manager()->AddFuncGraph(new_fg); + auto second_grad_fg = ad::Grad(new_fg, r, false); + r->manager()->AddFuncGraph(second_grad_fg); + parse::ResolveFuncGraph(second_grad_fg, r); + r->set_func_graph(second_grad_fg); + r->set_args_spec(GetArgsSpec(f_args, second_grad_fg)); + PynativeOptimizeAction(r); + DumpIR("nested.ir", r->func_graph()); + r->func_graph()->add_parameter(); + r->func_graph()->add_parameter(); + + // Set next forward + inputs.clear(); + inputs.emplace_back(NewValueNode(first_grad_fg)); + for (size_t i = 0; i < inputs_size; ++i) { + inputs.emplace_back(GetInput(args[i], false)); } + cnode = graph_prev->NewCNode(inputs); auto out_id = GetId(out); - auto cnode = graph_prev->NewCNode(inputs); SetTupleArgsToGraphInfoMap(graph_prev, out, cnode); SetNodeMapInGraphInfoMap(graph_prev, out_id, cnode); MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4); + + // New process + ValuePtrList input_args; + for (size_t i = 0; i < inputs_size; ++i) { + auto arg = parse::data_converter::PyDataToValue(args[i]); + MS_EXCEPTION_IF_NULL(arg); + input_args.emplace_back(arg); + } + // get output value + auto out_value = parse::data_converter::PyDataToValue(out); + MS_EXCEPTION_IF_NULL(out_value); + // Add out and dout + r->func_graph()->add_parameter(); + r->func_graph()->add_parameter(); + if (!ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args, out_value, r->func_graph())) { + MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << cnode->ToString(); + } } -void GradExecutor::RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, - std::vector *inputs) { - FuncGraphPtr forward_graph = nullptr; - auto ic = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); - if (ic != top_cell()->cell_graph_list().end()) { - forward_graph = (*ic)->fg(); - } - MS_EXCEPTION_IF_NULL(forward_graph); - auto param_list = replace_weights_map_.at(forward_graph); - auto params = newfg->parameters(); - auto manage = Manage({newfg}, false); - for (const auto &it : params) { - auto param = it->cast(); - if (!param->has_default()) { - continue; - } - for (auto p = param_list.begin(); p != param_list.end();) { - MS_LOG(DEBUG) << "Param name " << param->name() << " ptr " << param.get(); - if (p->second->name() == param->name()) { - manage->Replace(param, p->first); - inputs->emplace_back(p->first); - param_list.erase(p); - break; - } - } +void GradExecutor::EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell) { + MS_EXCEPTION_IF_NULL(top_cell); + auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), + [&](const TopCellInfoPtr &elem) { return elem.get() == top_cell.get(); }); + if (iter == top_cell_list_.end()) { + MS_LOG(EXCEPTION) << "Can not find top cell " << top_cell.get() << " cell id " << 
top_cell->cell_id() + << " from top cell list"; } - replace_weights_map_.erase(forward_graph); + (void)top_cell_list_.erase(iter); } void GradExecutor::ClearGrad(const py::object &cell, const py::args &args) { const auto &cell_id = GetCellId(cell, args); - if (IsTopestGraph(cell_id)) { - pre_top_cell_ = top_cell(); - set_top_cell(nullptr); + if (IsTopGraph(cell_id)) { in_grad_process_ = false; } in_bprop_process_ = false; SubNestedGradOrder(); forward()->node_abs_map().clear(); - obj_to_forward_id_.clear(); ad::CleanRes(); pipeline::ReclaimOptimizer(); } @@ -3182,20 +2408,12 @@ void GradExecutor::ClearRes() { grad_order_ = 0; grad_flag_ = false; in_grad_process_ = false; - has_dynamic_cell_ = false; - need_replace_forward_ = true; grad_is_running_ = false; - pre_top_cell_ = nullptr; top_cell_ = nullptr; curr_g_ = nullptr; - op_index_map_.clear(); - replace_weights_map_.clear(); - all_value_node_tensors_.clear(); - obj_to_forward_id_.clear(); ClearCellRes(); - top_cell_list_.clear(); - std::stack().swap(graph_stack_); - std::stack().swap(cell_op_info_stack_); + std::stack().swap(cell_stack_); + std::stack().swap(high_order_stack_); } GradExecutorPtr PynativeExecutor::grad_executor() { @@ -3213,7 +2431,7 @@ bool PynativeExecutor::GetIsDynamicCell() { if (grad_executor_ == nullptr) { return false; } - return grad_executor_->TopCellIsDynamic(); + return true; } py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &args) { @@ -3255,11 +2473,19 @@ void PynativeExecutor::ClearRes() { } void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { + if (!grad_executor()->grad_flag()) { + MS_LOG(DEBUG) << "Grad flag is false"; + return; + } py::object *ret = nullptr; PynativeExecutorTry(grad_executor()->InitGraph, ret, cell, args); } void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { + if (!grad_executor()->grad_flag()) { + MS_LOG(DEBUG) << "Grad flag is false"; + return; + } MS_LOG(DEBUG) << "Enter end graph process."; py::object *ret = nullptr; auto &mem_cleaner = pipeline::Resource::mem_cleaner(); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 9c1a54d8f0cba735866bdd641419e41197771def..5853aff735a571ae8bd6f3c9b5dc848d25d82717 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -36,11 +36,12 @@ #include "utils/ms_context.h" #include "ir/anf.h" #include "pipeline/jit/resource.h" +#include "frontend/optimizer/ad/kpynative.h" #include "frontend/operator/composite/composite.h" -namespace mindspore { -namespace pynative { +namespace mindspore::pynative { namespace py = pybind11; +using cell_id = std::string; using ResourcePtr = std::shared_ptr; using GradOperationPtr = std::shared_ptr; @@ -52,8 +53,8 @@ struct PrimAbsInfo { using AbstractListMap = std::unordered_map; -using OpIndexWithTensorId = std::unordered_map>; -using TensorIdWithTensor = std::unordered_map>; +using OpInfoWithTensorId = std::unordered_map>; +using TensorIdWithTensorObject = std::unordered_map>; py::object RunOp(const py::args &args); @@ -64,47 +65,11 @@ struct GraphInfo { AnfNodePtr output; OrderedMap params; // hold input parameters and cell weights std::unordered_map>> node_map; - std::vector objects; GraphInfo() = default; explicit GraphInfo(std::string id) : cell_id(std::move((id))) {} }; using GraphInfoPtr = std::shared_ptr; -class CellInfo { - public: - CellInfo() = default; - 
~CellInfo() = default; - CellInfo(bool custom_bprop, bool has_dynamic, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id) - : is_custom_bprop_(custom_bprop), - is_dynamic_(has_dynamic), - fg_(std::move(foward_graph)), - cell_id_(std::move(cellid)), - bprop_cell_id_(std::move(bprop_id)) {} - - bool is_custom_bprop() const { return is_custom_bprop_; } - void set_is_custom_bprop(bool is_custom_bprop) { is_custom_bprop_ = is_custom_bprop; } - bool is_dynamic() const { return is_dynamic_; } - void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; } - size_t call_times() const { return call_times_; } - void set_call_times(size_t call_times) { call_times_ = call_times; } - FuncGraphPtr fg() const { return fg_; } - void set_fg(FuncGraphPtr fg) { fg_ = std::move(fg); } - std::string &cell_id() { return cell_id_; } - void set_cell_id(std::string cell_id) { cell_id_ = std::move(cell_id); } - std::string &bprop_cell_id() { return bprop_cell_id_; } - std::vector &cell_ops_info() { return cell_ops_info_; } - - private: - bool is_custom_bprop_{false}; // Custom bprop - bool is_dynamic_{false}; // Set by has_dynamic_cell - size_t call_times_{0}; - FuncGraphPtr fg_{nullptr}; // Forward graph - std::string cell_id_; - std::string bprop_cell_id_; - std::vector cell_ops_info_; // All ops info -}; -using CellInfoPtr = std::shared_ptr; - class TopCellInfo { public: TopCellInfo() = default; @@ -112,60 +77,58 @@ class TopCellInfo { TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid) : is_topest_(topest), resource_(std::move(r)), df_builder_(std::move(df)), cell_id_(std::move(cellid)) {} - bool is_grad() const { return is_grad_; } - void set_is_grad(bool is_grad) { is_grad_ = is_grad; } + bool is_init_kpynative() const { return is_init_kpynative_; } + void set_init_kpynative(bool init) { is_init_kpynative_ = init; } bool is_topest() const { return is_topest_; } + bool is_dynamic() const { return is_dynamic_; } + void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; } bool vm_compiled() const { return vm_compiled_; } void set_vm_compiled(bool vm_compiled) { vm_compiled_ = vm_compiled; } - bool need_grad() const { return need_grad_; } - void set_need_grad(bool need_grad) { need_grad_ = need_grad; } - bool has_dynamic_cell() const { return has_dynamic_cell_; } - bool is_real_dynamic() const { return is_real_dynamic_; } - void set_is_real_dynamic(bool is_real_dynamic) { is_real_dynamic_ = is_real_dynamic; } + bool need_compile_graph() const { return need_compile_graph_; } + void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; } bool forward_already_run() const { return forward_already_run_; } void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; } ResourcePtr resource() { return resource_; } FuncGraphPtr df_builder() { return df_builder_; } + size_t op_num() const { return op_num_; } + void set_op_num(size_t op_num) { op_num_ = op_num; } std::string &cell_id() { return cell_id_; } - std::string &sens_id() { return sens_id_; } - void set_sens_id(std::string sens_id) { sens_id_ = std::move(sens_id); } - std::string &weights_id() { return weights_id_; } - void set_weights_id(std::string weights_id) { weights_id_ = std::move(weights_id); } std::string &input_args_id() { return input_args_id_; } + std::string all_op_info() const { return all_op_info_; } + void set_all_op_info(std::string all_op_info) { all_op_info_ = all_op_info; } void 
set_input_args_id(std::string input_args_id) { input_args_id_ = std::move(input_args_id); } - std::vector &cell_graph_list() { return cell_graph_list_; } - void set_cell_graph_list(const std::vector &cell_graph_list) { cell_graph_list_ = cell_graph_list; } + std::unordered_set &sub_cell_list() { return sub_cell_list_; } + bool IsSubCell(const std::string &cell_id) const; OrderedMap &graph_info_map() { return graph_info_map_; } - void set_graph_info_map(const OrderedMap &graph_info_map) { - graph_info_map_ = graph_info_map; - } - void clear() { - cell_graph_list_.clear(); - graph_info_map_.clear(); + OpInfoWithTensorId &op_info_with_tensor_id() { return op_info_with_tensor_id_; } + TensorIdWithTensorObject &tensor_id_with_tensor_object() { return tensor_id_with_tensor_object_; } + ad::KPynativeCellPtr k_pynative_cell_ptr() const { return k_pynative_cell_ptr_; } + void set_k_pynative_cell_ptr(const ad::KPynativeCellPtr &k_pynative_cell_ptr) { + k_pynative_cell_ptr_ = k_pynative_cell_ptr; } + void clear(); private: - bool is_grad_{false}; // Derivative is calculated bool is_topest_{false}; + bool is_dynamic_{false}; bool vm_compiled_{false}; - bool need_grad_{true}; - bool has_dynamic_cell_{false}; - bool is_real_dynamic_{false}; + bool is_init_kpynative_{false}; bool forward_already_run_{false}; + bool need_compile_graph_{false}; + size_t op_num_{0}; ResourcePtr resource_{nullptr}; FuncGraphPtr df_builder_{nullptr}; + ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr}; std::string cell_id_; - std::string sens_id_; - std::string weights_id_; std::string input_args_id_; - std::vector cell_graph_list_; + std::string all_op_info_; OrderedMap graph_info_map_; + std::unordered_set sub_cell_list_; + OpInfoWithTensorId op_info_with_tensor_id_; + TensorIdWithTensorObject tensor_id_with_tensor_object_; }; using TopCellInfoPtr = std::shared_ptr; -class DynamicAnalysis; -using DynamicAnalysisPtr = std::shared_ptr; - class ForwardExecutor; using ForwardExecutorPtr = std::shared_ptr; using ForwardExecutorWeakPtr = std::weak_ptr; @@ -174,30 +137,6 @@ class GradExecutor; using GradExecutorPtr = std::shared_ptr; using GradExecutorWeakPtr = std::weak_ptr; -class DynamicAnalysis { - public: - DynamicAnalysis() = default; - ~DynamicAnalysis() = default; - - // Check cell struct - bool IsDynamicCell(const py::object &cell); - - private: - std::string GetCellInfo(const py::object &cell); - void ParseInputArgs(const std::shared_ptr &ast, const py::object &fn_node); - bool ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node, - const std::vector &compare_prim = {}); - bool ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object &node); - bool ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node); - bool ParseAugAssignExprNode(const std::shared_ptr &ast, const py::object &node, - const std::vector &compare_prim = {}); - bool ParseForExprNode(const std::shared_ptr &ast, const py::object &node); - std::string ParseNodeName(const std::shared_ptr &ast, const py::object &node, - parse::AstMainType type); - - std::unordered_set cell_input_args_; -}; - class GradExecutor { public: GradExecutor() = default; @@ -227,112 +166,73 @@ class GradExecutor { FuncGraphPtr curr_g() const; TopCellInfoPtr top_cell() const; - bool TopCellIsDynamic(); + TopCellInfoPtr top_cell_direct() const { return top_cell_; } + bool CheckNeedCompileGraph(); + TopCellInfoPtr GetTopCell(const string &cell_id); void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); } bool grad_flag() 
const { return grad_flag_; } void set_grad_flag(bool flag) { grad_flag_ = flag; } bool in_grad_process() const { return in_grad_process_; } - std::string top_cell_id() { return top_cell()->cell_id(); } + bool in_cell_with_custom_bprop_() const { return custom_bprop_cell_count_ > 0; } AnfNodePtr GetInput(const py::object &obj, bool op_mask); std::string GetCellId(const py::object &obj, const py::args &args); - TopCellInfoPtr GetTopCell(const string &cell_id, bool find_nearest = false); + std::stack &cell_stack() { return cell_stack_; } + std::vector &top_cell_list() { return top_cell_list_; } + bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; } void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode); - void SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &out_real); + void DoOpGrad(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &op_out); + void UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const py::object &out_real); + void SaveForwardTensorInfoInBpropGraph(const ResourcePtr &resource); py::object CheckGraph(const py::object &cell, const py::args &args); void RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args, const py::object &phase); - bool need_construct_graph() const { return !graph_stack_.empty() && grad_flag_; } - void set_dynamic_analysis(DynamicAnalysisPtr dynamic_analysis) { dynamic_analysis_ = std::move(dynamic_analysis); } - std::stack &graph_stack() { return graph_stack_; } - std::vector &top_cell_list() { return top_cell_list_; } - bool need_replace_forward() const { return need_replace_forward_; } - std::stack &cell_op_info_stack() { return cell_op_info_stack_; } - std::unordered_map &op_index_map() { return op_index_map_; } - std::unordered_map &obj_to_forward_id() { return obj_to_forward_id_; } + void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell); void ClearGrad(const py::object &cell, const py::args &args); void ClearRes(); void ClearCellRes(const std::string &cell_id = ""); private: ForwardExecutorPtr forward() const; - DynamicAnalysisPtr dynamic_analysis() const; bool grad_running() const { return grad_is_running_; } void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; } - void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; } // Higher derivative bool IsNestedGrad() const; + size_t cell_nums() const { return cell_nums_; } + void set_cell_nums(size_t cell_nums) { cell_nums_ = cell_nums; } void AddNestedGradOrder() { ++grad_order_; } void SubNestedGradOrder(); - void ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph, - const std::string &cell_id); - void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id); void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource, const py::object &out, bool has_sens); - void RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector *inputs); bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id); - - // Dynamic - bool CheckDynamicCell(const std::string &cell_id); - bool CheckRealDynamicCell(const std::string &cell_id); - void ClearDynamicTopRes(const std::string &cell_id); - - void PushCurrentGraphToStack(); - void PopGraphStack(); - void PushCurrentCellOpInfoToStack(); - 
void PopCurrentCellOpInfoFromStack(); - std::string GetCellOpInfo(); - void ReplaceCellOpInfoByCellId(const std::string &cell_id); + void PushCellStack(const std::string &cell_id); + void PopCellStack(); + void PushHighOrderGraphStack(); + void PopHighOrderGraphStack(); FuncGraphPtr GetDfbuilder(const std::string &cell_id = ""); ResourcePtr GetResource(const std::string &cell_id = ""); - bool IsFirstGradStep(); + bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id); bool IsTopGraph(const std::string &cell_id); - bool IsTopestGraph(const std::string &cell_id); bool IsBpropGraph(const std::string &cell_id); - bool IsGradBefore(const std::string &cell_id); - bool CheckCellGraph(const std::string &cell_id); - bool UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned, - bool is_grad); - void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, - bool need_cloned = false, bool is_grad = false); - bool CheckCellChanged(const std::string &cell_id); - void UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled); + void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled); void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph); + void InitResourceAndDfBuilder(const std::string &cell_id, const py::args &args); void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args); - void MakeNewTopGraph(const string &cell_id, const py::args &args); + void MakeNewTopGraph(const string &cell_id, const py::args &args, bool is_topest); void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args); - void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out, - const std::string &out_id, const py::args &args); - bool EndBpropGraph(const string &cell_id); - FuncGraphPtr MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r, - const std::string &cell_id, const py::args &args); - std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args, - py::object *sens = nullptr); + std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args); void GradNetInner(py::object *ret, const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); - void SetTopCellTensorId(const std::string &cell_id); - bool CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, const py::object &sens); - void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size); - void SetGradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector &weights, - size_t arg_size, const std::string &cell_id); - std::vector GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder); - abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder); - void ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id); + FuncGraphPtr GetBpropGraph(const GradOperationPtr &grad, const std::vector &weights, size_t arg_size, + const py::args &args); + std::vector GetWeightsArgs(const GradOperationPtr &grad, const py::object &weights, + const FuncGraphPtr &df_builder); + abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph); void 
SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node, const std::vector &index_sequence, bool is_param = false); AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id); AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); - // Memory clean between steps - void ClearResidualRes(const std::string &cell_id); - void ClearCnodeRes(const AnfNodePtr &node); - void CleanPreMemoryInValueNode(); - void SaveTensorsInValueNode(const ResourcePtr &resource); - void SaveAllValueNodeTensors(const FuncGraphPtr &graph); - - void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) { - top_cell()->graph_info_map()[g]->objects.push_back(obj); - } void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, bool is_param = false); void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) { @@ -346,33 +246,31 @@ class GradExecutor { const std::vector &index) { top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index); } + void SetMakeTupleAsOutputNode(const TopCellInfoPtr &top_cell, const FuncGraphPtr &curr_g, const py::object &out); + void DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args); private: - size_t grad_order_{0}; bool grad_flag_{false}; bool in_bprop_process_{false}; bool in_grad_process_{false}; - bool has_dynamic_cell_{false}; - bool need_replace_forward_{true}; bool grad_is_running_{false}; + int custom_bprop_cell_count_{0}; + size_t grad_order_{0}; + size_t cell_nums_{0}; + FuncGraphPtr curr_g_{nullptr}; // For clear pre top res - TopCellInfoPtr pre_top_cell_{nullptr}; TopCellInfoPtr top_cell_{nullptr}; - std::unordered_map op_index_map_; - std::unordered_map>> replace_weights_map_; - std::unordered_set all_value_node_tensors_; - std::unordered_map obj_to_forward_id_; - - // Records forwrad graph, the bottom is top graph - std::stack graph_stack_; - // Records op info of every cell, the bottom is op info of top cell - std::stack cell_op_info_stack_; - + // Records forward cell, the bottom is top cell + std::stack cell_stack_; + // For high grad order + std::stack high_order_stack_; // Use vector for keep order std::vector top_cell_list_; + // Record all top cells which have been run + std::map already_run_top_cell_; + // Use vector for keep order ForwardExecutorWeakPtr forward_executor_; - DynamicAnalysisPtr dynamic_analysis_; }; class ForwardExecutor { public: @@ -388,13 +286,9 @@ class ForwardExecutor { OpExecInfoPtr GenerateOpExecInfo(const py::args &args); void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); } std::unordered_map &node_abs_map() { return node_abs_map_; } - std::unordered_map &cell_op_index_with_tensor_id() { - return cell_op_index_with_tensor_id_; - } - std::unordered_map &cell_tensor_id_with_tensor() { - return cell_tensor_id_with_tensor_; - } void ClearRes(); + AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, + abstract::AbstractBasePtrList *args_spec_list); private: GradExecutorPtr grad() const; @@ -404,8 +298,6 @@ class ForwardExecutor { py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); - AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, - 
abstract::AbstractBasePtrList *args_spec_list); bool FindOpMask(py::object obj, std::vector *op_masks, std::string id); void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, std::vector *inputs, abstract::AbstractBasePtrList *args_spec_list); @@ -413,9 +305,6 @@ class ForwardExecutor { const abstract::AbstractBasePtr &abs, const std::string &id, size_t index); void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list, bool *is_find); - // Update the abstract and device address info of value node and tensors in bprop graph - void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real); - // Mix precision void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info); py::object DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name, size_t index); @@ -430,9 +319,6 @@ class ForwardExecutor { GradExecutorWeakPtr grad_executor_; std::unordered_map prim_abs_list_; std::unordered_map node_abs_map_; - // Used for runop and replace forward result of grad graph - std::unordered_map cell_op_index_with_tensor_id_; - std::unordered_map cell_tensor_id_with_tensor_; // Used to cache cast struct std::unordered_map cast_struct_map_; // Used to cache op_mask @@ -447,7 +333,6 @@ class PynativeExecutor : public std::enable_shared_from_this { executor_ = std::shared_ptr(new (std::nothrow) PynativeExecutor()); forward_executor_ = std::make_shared(); grad_executor_ = std::make_shared(forward_executor_); - grad_executor_->set_dynamic_analysis(std::make_shared()); forward_executor_->set_grad_executor(grad_executor_); } return executor_; @@ -471,9 +356,8 @@ class PynativeExecutor : public std::enable_shared_from_this { // Used by graph clean bool GetIsDynamicCell(); - bool need_replace_forward() { return grad_executor()->need_replace_forward(); } // Cell destruct will call - void ClearCell(const std::string &flag = ""); + void ClearCell(const std::string &cell_id); void ClearGrad(const py::object &cell, const py::args &args); // Abnormal existed void ClearRes(); @@ -496,7 +380,6 @@ class PynativeExecutor : public std::enable_shared_from_this { }; using PynativeExecutorPtr = std::shared_ptr; -} // namespace pynative -} // namespace mindspore +} // namespace mindspore::pynative #endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index 5a5d67f3515a006d0cb5e478346760c7129fd9db..aeb464c9034349f6a1e419aae9e761e3db603c66 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -311,6 +311,9 @@ void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) { auto primitive_py = primitive->cast(); MS_EXCEPTION_IF_NULL(primitive_py); this->set_hook(primitive_py->hook()); + if (primitive_py->HasAttr(kBpropAttrName)) { + this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName)); + } } BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const { diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index d77c0949da5f0b3979dc85104ffe390c626ac1f0..c58fd398f31cd032186144f12e69bc8a5633146f 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -300,6 +300,8 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, 
const std::v } // Clear the output address of graph. ClearOutputAddress(inputs, value_nodes, execution_order); + + graph_output_map_.erase(graph_id); } void GPUKernelRuntime::AllocInplaceNodeMemory(const session::KernelGraph *graph) { diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index d1d5eacf84890982a2db45eeb1cbdec0ba5e9ae9..e58459cb3616e738f2635dcfe26f532318c9ce8c 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -286,13 +286,13 @@ void TensorValueToTensor(const ValuePtr &value, std::vector * if (element->isa()) { auto tensor = element->cast(); MS_EXCEPTION_IF_NULL(tensor); - tensors->push_back(tensor); + tensors->emplace_back(tensor); } } } else if (value->isa()) { auto tensor = value->cast(); MS_EXCEPTION_IF_NULL(tensor); - tensors->push_back(tensor); + tensors->emplace_back(tensor); } } } // namespace mindspore diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 7a15d3a0674a949217e19b9dca6a7fc6206c1720..b6118fb142d8f9597674efecc8661761c28045ad 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -515,7 +515,13 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri auto input_x = CheckArg(op_name, args_spec_list, 0); MS_EXCEPTION_IF_NULL(input_x); MS_EXCEPTION_IF_NULL(input_x->shape()); - auto input_type = primitive->GetAttr("dst_type")->cast(); + auto attr = primitive->GetAttr("dst_type"); + if (attr == nullptr) { + attr = args_spec_list[1]->BuildValue(); + MS_EXCEPTION_IF_NULL(attr); + primitive->set_attr("dst_type", attr); + } + auto input_type = attr->cast(); auto ret = std::make_shared(input_type, input_x->shape()->shape()); return ret; } diff --git a/mindspore/core/abstract/prim_structures.cc b/mindspore/core/abstract/prim_structures.cc index 859a6d1f9e4462bd0b13a9abe1ae02e20e240d5e..7711b74bbff51707798767196793ec11234b64a3 100644 --- a/mindspore/core/abstract/prim_structures.cc +++ b/mindspore/core/abstract/prim_structures.cc @@ -165,8 +165,7 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra if (dyn_cast(queue->elements()[0]) != nullptr) { return std::make_shared(queue->elements()[0]->BuildType()); } - MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " - << index_value->ToString(); + MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " << index->ToString(); } auto idx_v = GetValue(index_value); std::size_t nelems = queue->elements().size(); diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index aae51466d1f806682201b4a1d011e44f56ed2141..23e795aead5eaf27936cc1b1db7d365e9aec5ec7 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -356,13 +356,14 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() { const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; } void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) { - auto &others = source->func_graph_cnodes_index(); - for (auto it = others.begin(); it != others.end(); it++) { + const auto &users = source->func_graph_cnodes_index(); + for (auto &user : users) { // Ignore the user graph who may own itself. 
- auto fg = it->first->first->func_graph(); + auto anfnode = user.first->first.lock(); + auto fg = anfnode->func_graph(); MS_EXCEPTION_IF_NULL(fg); if (fg.get() != this) { - AddFuncGraphCNodeIndex(it->first, it->second); + AddFuncGraphCNodeIndex(user.first, user.second); } } } @@ -384,7 +385,7 @@ void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) { } else { func_graph_cnodes_index_[pair]--; if (func_graph_cnodes_index_[pair] < 0) { - MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second + MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first.lock() << "/" << pair->second << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); } } diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index 2bf50017301d5f404dcfebbd9b0be049f30f849a..54a0ec5c36cdbb791d49ae939314e0cf4675a9b7 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -42,12 +42,15 @@ namespace mindspore { using BaseRefCounterMap = OrderedMap; using FuncGraphCounterMap = OrderedMap; +using CNodeIndexPair = std::pair; +using CNodeIndexPairPtr = std::shared_ptr; struct CNodeIndexHasher { std::size_t operator()(const CNodeIndexPairPtr pair) const { MS_EXCEPTION_IF_NULL(pair); - MS_EXCEPTION_IF_NULL(pair->first); - return hash_combine(pair->first->hash(), std::hash()(pair->second)); + AnfNodePtr node_ptr = pair->first.lock(); + MS_EXCEPTION_IF_NULL(node_ptr); + return hash_combine(node_ptr->hash(), std::hash()(pair->second)); } }; @@ -59,7 +62,7 @@ struct CNodeIndexEqual { if (lhs == rhs) { return true; } - if (lhs->first != rhs->first) { + if (lhs->first.lock() != rhs->first.lock()) { return false; } if (lhs->second != rhs->second) { diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc index c9cc36ec04234f76f94558cefa14127c61178222..c810bd238c26765a11b7967b7c7fef33985d44bc 100644 --- a/mindspore/core/ir/func_graph_cloner.cc +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -199,10 +199,11 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func target_func_graph->set_return(return_node); } - auto &cnodes = func_graph->func_graph_cnodes_index(); - for (auto &cnode : cnodes) { - auto parent = cnode.first->first->cast(); - auto valuenode = parent->input(cnode.first->second); + const auto &users = func_graph->func_graph_cnodes_index(); + for (auto &user : users) { + auto anf_node = user.first->first.lock(); + auto parent = anf_node->cast(); + auto valuenode = parent->input(user.first->second); CloneValueNode(valuenode, target_func_graph); } } @@ -414,8 +415,9 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph if (lift_params.empty()) { return; } - for (auto &cnode : func_graph_user->func_graph_cnodes_index()) { - LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params); + for (auto &user : func_graph_user->func_graph_cnodes_index()) { + auto anf_node = user.first->first.lock(); + LiftParameters(anf_node->func_graph(), func_graph_user, lift_params); } } @@ -427,8 +429,9 @@ void Cloner::Lift() { auto iter = repl_func_graph_params_.find(func_graph); if (iter != repl_func_graph_params_.end()) { auto ¶ms = iter->second; - for (auto &cnode : func_graph->func_graph_cnodes_index()) { - LiftParameters(cnode.first->first->func_graph(), func_graph, params); + for (auto &user : func_graph->func_graph_cnodes_index()) { + auto anf_node = user.first->first.lock(); + LiftParameters(anf_node->func_graph(), func_graph, params); 
} } } diff --git a/mindspore/core/ir/graph_utils.cc b/mindspore/core/ir/graph_utils.cc index 67b5050b32ceded781b7fe4494a227f6e28a2519..6c82f3bde7edf4adf116a8da6bf2b73db07662ba 100644 --- a/mindspore/core/ir/graph_utils.cc +++ b/mindspore/core/ir/graph_utils.cc @@ -41,8 +41,8 @@ std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c return res; } size_t seen = NewSeenGeneration(); - std::deque todo(1024); - todo.clear(); + std::vector todo; + todo.reserve(1024); todo.push_back(root); while (!todo.empty()) { @@ -95,14 +95,15 @@ std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c // search the cnodes inside this graph only std::vector BroadFirstSearchGraphCNodes(const std::vector &starts) { - std::deque todo(1024); - todo.clear(); + std::vector todo; + todo.reserve(1024); todo.insert(todo.end(), starts.begin(), starts.end()); std::vector sorted_nodes; auto seen = NewSeenGeneration(); - while (!todo.empty()) { - CNodePtr top = todo.front(); - todo.pop_front(); + size_t top_idx = 0; + while (top_idx < todo.size()) { + CNodePtr top = todo[top_idx]; + top_idx++; sorted_nodes.push_back(top); auto inputs = top->inputs(); for (auto &item : inputs) { @@ -121,13 +122,14 @@ std::vector BroadFirstSearchGraphCNodes(const std::vector &s // search the cnode match the predicate inside this graph only CNodePtr BroadFirstSearchFirstOf(const std::vector &starts, const MatchFunc &match_predicate) { - std::deque todo(1024); - todo.clear(); + std::vector todo; + todo.reserve(1024); todo.insert(todo.end(), starts.begin(), starts.end()); auto seen = NewSeenGeneration(); - while (!todo.empty()) { - CNodePtr top = todo.front(); - todo.pop_front(); + size_t top_idx = 0; + while (top_idx < todo.size()) { + CNodePtr top = todo[top_idx]; + top_idx++; if (match_predicate(top)) { return top; } @@ -147,13 +149,16 @@ CNodePtr BroadFirstSearchFirstOf(const std::vector &starts, const Matc } std::vector BroadFirstSearchGraphUsed(FuncGraphPtr root) { - std::deque todo; + std::vector todo; + todo.reserve(128); todo.push_back(root); std::vector sorted; + sorted.reserve(128); auto seen = NewSeenGeneration(); - while (!todo.empty()) { - FuncGraphPtr top = todo.front(); - todo.pop_front(); + size_t top_idx = 0; + while (top_idx < todo.size()) { + FuncGraphPtr top = todo[top_idx]; + top_idx++; sorted.push_back(top); auto used = top->func_graphs_used(); for (auto &item : used) { diff --git a/mindspore/core/ir/manager.h b/mindspore/core/ir/manager.h index 394f40a867e4f95f3380dc58dad4cdb775b00b98..1052be99ea165eefc9e636d354e6bd9bba9c93f9 100644 --- a/mindspore/core/ir/manager.h +++ b/mindspore/core/ir/manager.h @@ -93,8 +93,6 @@ struct Signals { enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 }; -using CNodeIndexPair = std::pair; -using CNodeIndexPairPtr = std::shared_ptr; using FuncGraphToFuncGraphSetMap = OrderedMap; // analysis base class, graphs analysis which need dynamic compute by DepCollector in each read diff --git a/mindspore/core/ir/primitive.h b/mindspore/core/ir/primitive.h index b57e88357968d606b59cdd2695dce29921c082b5..ba7a71175ac61988a750d6808e16fcf6676b4a86 100644 --- a/mindspore/core/ir/primitive.h +++ b/mindspore/core/ir/primitive.h @@ -45,7 +45,7 @@ class Primitive : public Named { Primitive(const std::string &name, const std::unordered_map &attrs); Primitive(const Primitive &prim); MS_DECLARE_PARENT(Primitive, Named); - abstract::AbstractBasePtr ToAbstract(); + abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); 
std::string ToString() const override { return name(); } void BeginRecordAddAttr() { diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 55559f63a5decbf21fc94c270e774fc5c5303ee8..4fe7c8a38c0a1077ee18c857e0213eae4d12501e 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -345,15 +345,9 @@ class Cell(Cell_): for item in inputs: if isinstance(item, numpy.ndarray): raise TypeError("cell inputs should not be numpy array.") - origin_grad = [] if self.requires_grad is True: _pynative_exec.set_grad_flag(True) - _pynative_exec.new_graph(self, *inputs, **kwargs) - for cell in self.cells(): - origin_grad.append(cell.requires_grad) - cell.set_grad(True) - else: - _pynative_exec.set_grad_flag(False) + _pynative_exec.new_graph(self, *inputs, **kwargs) cast_inputs = list() if hasattr(self, "_mindspore_flags"): if self._mindspore_flags.get('fp16'): @@ -365,10 +359,7 @@ class Cell(Cell_): output = self.run_construct(cast_inputs, kwargs) if isinstance(output, Parameter): output = output.data - if self.requires_grad is True: - _pynative_exec.end_graph(self, output, *inputs, **kwargs) - for i, cell in enumerate(self.cells()): - cell.set_grad(origin_grad[i]) + _pynative_exec.end_graph(self, output, *inputs, **kwargs) return output def _add_attr(self, name, value): diff --git a/tests/ut/cpp/optimizer/ad/kpynative_test.cc b/tests/ut/cpp/optimizer/ad/kpynative_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..97cc0d65201714d7d9b708213ae9ecfa22a0286a --- /dev/null +++ b/tests/ut/cpp/optimizer/ad/kpynative_test.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include
+#include
+
+#include "frontend/optimizer/ad/kpynative.h"
+#include "common/common_test.h"
+#include "common/py_func_graph_fetcher.h"
+#include "ir/manager.h"
+#include "ir/value.h"
+#include "ir/func_graph_cloner.h"
+#include "utils/log_adapter.h"
+#include "ir/graph_utils.h"
+#include "pipeline/jit/resource.h"
+#include "pipeline/jit/parse/parse.h"
+#include "debug/anf_ir_utils.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace ad {
+class TestKPynative : public UT::Common {
+ public:
+  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
+
+ protected:
+  AbstractBasePtr BuildArg() {
+    std::vector<int64_t> shp = {2, 2};
+    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
+    auto abstract = tensor->ToAbstract();
+    return abstract;
+  }
+
+  FuncGraphPtr BuildPrimalFuncGraph(const std::string& testCase) {
+    auto g = std::make_shared<FuncGraph>();
+    auto x = g->add_parameter();
+    auto y = g->add_parameter();
+    x->set_abstract(BuildArg());
+    y->set_abstract(BuildArg());
+    auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
+    c_node->set_abstract(BuildArg());
+    g->set_output(c_node);
+    return g;
+  }
+
+  // a = x * y
+  // b = stop_gradient(a)
+  // c = b * y
+  // return a * c
+  FuncGraphPtr BuildStopGradient(const std::string &testCase) {
+    auto g = std::make_shared<FuncGraph>();
+    auto x = g->add_parameter();
+    x->debug_info()->set_name("x");
+    auto y = g->add_parameter();
+    y->debug_info()->set_name("y");
+    x->set_abstract(BuildArg());
+    y->set_abstract(BuildArg());
+    auto a_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
+    a_node->set_abstract(BuildArg());
+    auto b_node = g->NewCNode({NewValueNode(prim::kPrimStopGradient), a_node});
+    b_node->set_abstract(BuildArg());
+    auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), b_node, y});
+    c_node->set_abstract(BuildArg());
+    auto d_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node});
+    d_node->set_abstract(BuildArg());
+    g->set_output(d_node);
+    return g;
+  }
+
+  FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) {
+    auto k_pynative_cell = GradPynativeCellBegin(primal_fg->parameters());
+    auto node_list = TopoSort(primal_fg->output());
+    for (auto node : node_list) {
+      if (node->isa<CNode>()) {
+        auto c_node = node->cast<CNodePtr>();
+        auto out = c_node->abstract()->GetValueTrack();
+        ValuePtrList args;
+        for (size_t i = 1; i < c_node->inputs().size(); ++i) {
+          args.push_back(c_node->input(i)->abstract()->GetValueTrack());
+        }
+        GradPynativeOp(k_pynative_cell, c_node, args, out);
+      }
+    }
+    auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, true, false);
+    return bprop_fg;
+  }
+};
+
+TEST_F(TestKPynative, test_simple_add) {
+  auto primal_fg = BuildPrimalFuncGraph("test_simple_add");
+  resource->manager()->KeepRoots({primal_fg});
+  ExportIR(primal_fg->ToString() + ".dat", "", primal_fg);
+
+  auto bprop_fg = BuildBpropFuncGraph(primal_fg);
+  resource->manager()->KeepRoots({bprop_fg});
+
+  ExportIR(bprop_fg->ToString() + ".dat", "", bprop_fg);
+}
+
+TEST_F(TestKPynative, test_stop_gradient) {
+  auto primal_fg = BuildStopGradient("test_stop_gradient");
+  resource->manager()->KeepRoots({primal_fg});
+  ExportIR(primal_fg->ToString() + ".dat", "", primal_fg);
+
+  auto bprop_fg = BuildBpropFuncGraph(primal_fg);
+  resource->manager()->KeepRoots({bprop_fg});
+
+  ExportIR(bprop_fg->ToString() + ".dat", "", bprop_fg);
+}
+}  // namespace ad
+}  // namespace mindspore