From c635940155f931e3339c17e9746b35f324e30c75 Mon Sep 17 00:00:00 2001
From: hangangqiang
Date: Tue, 21 Feb 2023 16:15:10 +0800
Subject: [PATCH] fix bug

---
 .../extendrt/backend/lite/compile_result.h    |  83 ++++++------
 .../backend/lite/compile_result_builder.cc    | 126 +++++++++---------
 .../backend/lite/compile_result_builder.h     |  16 +--
 3 files changed, 112 insertions(+), 113 deletions(-)

diff --git a/mindspore/lite/src/extendrt/backend/lite/compile_result.h b/mindspore/lite/src/extendrt/backend/lite/compile_result.h
index 08d149d6781..e7d491001ca 100644
--- a/mindspore/lite/src/extendrt/backend/lite/compile_result.h
+++ b/mindspore/lite/src/extendrt/backend/lite/compile_result.h
@@ -31,7 +31,7 @@ class CompileResult {};
 using CompileResultPtr = std::shared_ptr<CompileResult>;
 class LiteNode {
-public:
+ public:
   void AppendInputTensor(const LiteTensorDelegatePtr &tensor) {
     if (tensor->Name().empty()) {
       tensor->SetName(this->name_ + "_in_" + std::to_string(this->inputs_.size()));
     }
@@ -55,10 +55,10 @@ public:
 using LiteNodePtr = std::shared_ptr<LiteNode>;
 class LiteCompileResult : CompileResult {
-public:
+ public:
   LiteNodePtr GetNode(const std::string &name) {
-    auto iter = nodes_.find(name);
-    if (iter == nodes_.end()) {
+    auto iter = node_map_.find(name);
+    if (iter == node_map_.end()) {
       return nullptr;
     } else {
       return iter->second;
     }
@@ -68,7 +68,7 @@ public:
   bool AppendNode(LiteNodePtr node) {
     if (node == nullptr) {
       MS_LOG(ERROR) << "Input node is nullptr";
-      return false;
+      return false;
     }
     const std::string &node_name = node->name_;
     auto iter = node_map_.find(node_name);
@@ -84,7 +84,7 @@ public:
   bool AppendTensor(LiteTensorDelegatePtr tensor) {
     if (tensor == nullptr) {
       MS_LOG(ERROR) << "Input tensor is nullptr";
-      return false;
+      return false;
     }
     const auto &tensor_name = tensor->Name();
     auto iter = tensor_map_.find(tensor_name);
@@ -93,33 +93,33 @@ public:
       MS_LOG(ERROR) << "Duplicated tensor name : " << tensor_name;
       return false;
     }
-    result_->tensors_.emplace_back(tensor);
+    tensors_.emplace_back(tensor);
     tensor_map_[tensor_name] = tensor;
     return true;
   }
 
-  bool AppendInputTensor(LiteTensorDelegatePtr tensor) {
+  bool AppendInputTensor(LiteTensorDelegatePtr tensor) {
     if (tensor == nullptr) {
       MS_LOG(ERROR) << "Input tensor is nullptr";
-      return false;
+      return false;
     }
     if (tensor->Name().empty()) {
       MS_LOG(ERROR) << "Input tensor has no name";
-      return false;
+      return false;
     }
     auto ret = AppendTensor(tensor);
     if (!ret) {
       return ret;
     }
-    result_->inputs_.emplace_back(tensor);
+    inputs_.emplace_back(tensor);
     return true;
   }
 
-  bool AppendOutputTensor(LiteTensorDelegatePtr tensor) {
+  bool AppendOutputTensor(LiteTensorDelegatePtr tensor) {
     if (tensor == nullptr) {
       MS_LOG(ERROR) << "Input tensor is nullptr";
-      return false;
+      return false;
     }
     if (tensor->Name().empty()) {
       tensor->SetName("graph_out_" + std::to_string(this->outputs_.size()));
@@ -128,34 +128,30 @@ public:
     if (!ret) {
       return ret;
     }
-    auto ret = AppendTensor(tensor);
-    if (!ret) {
-      return ret;
-    }
-    result_->outputs_.emplace_back(tensor);
+    outputs_.emplace_back(tensor);
     return true;
   }
 
   bool AppendNodeInputTensor(LiteNodePtr lite_node, LiteTensorDelegatePtr tensor) {
+    if (lite_node == nullptr) {
+      MS_LOG(ERROR) << "Input lite_node is nullptr";
+      return false;
+    }
     return AppendNodeInputTensor(lite_node->name_, tensor);
   }
 
   bool AppendNodeInputTensor(const std::string &node_name, LiteTensorDelegatePtr tensor) {
-    if (lite_node == nullptr) {
-      MS_LOG(ERROR) << "Input lite_node is nullptr";
-      return false;
-    }
     if (tensor == nullptr) {
       MS_LOG(ERROR) << "Input tensor is nullptr";
-      return false;
+      return false;
    }
-
-    auto iter = nodes_.find(node_name);
-    if (iter == nodes_.end()) {
-      MS_LOG(ERROR) << "LiteNode not belong to this graph, node: " << lite_node->name_;
-      return false;
+
+    auto iter = node_map_.find(node_name);
+    if (iter == node_map_.end()) {
+      MS_LOG(ERROR) << "LiteNode not belong to this graph, node: " << node_name;
+      return false;
    }
-    lite_node->AppendInputTensor(tensor);
+    iter->second->AppendInputTensor(tensor);
     auto ret = AppendTensor(tensor);
     if (!ret) {
       return ret;
@@ -164,31 +160,31 @@ public:
   }
 
   bool AppendNodeOutputTensor(LiteNodePtr lite_node, LiteTensorDelegatePtr tensor) {
+    if (lite_node == nullptr) {
+      MS_LOG(ERROR) << "Input lite_node is nullptr";
+      return false;
+    }
     return AppendNodeOutputTensor(lite_node->name_, tensor);
   }
 
   bool AppendNodeOutputTensor(const std::string &node_name, LiteTensorDelegatePtr tensor) {
-    if (lite_node == nullptr) {
-      MS_LOG(ERROR) << "Input lite_node is nullptr";
-      return false;
-    }
     if (tensor == nullptr) {
       MS_LOG(ERROR) << "Input tensor is nullptr";
-      return false;
+      return false;
    }
-
-    auto iter = nodes_.find(node_name);
-    if (iter == nodes_.end()) {
-      MS_LOG(ERROR) << "LiteNode not belong to this graph, node: " << lite_node->name_;
-      return false;
+
+    auto iter = node_map_.find(node_name);
+    if (iter == node_map_.end()) {
+      MS_LOG(ERROR) << "LiteNode not belong to this graph, node: " << node_name;
+      return false;
    }
-    lite_node->AppendOutputTensor(tensor);
+    iter->second->AppendOutputTensor(tensor);
     auto ret = AppendTensor(tensor);
     if (!ret) {
       return ret;
     }
     return true;
-  }
+  }
 
   std::vector<LiteNodePtr> nodes_;
   std::vector<LiteTensorDelegatePtr> tensors_;
@@ -198,8 +194,7 @@ public:
   HashMap<std::string, LiteTensorDelegatePtr> tensor_map_{};
 };
 using LiteCompileResultPtr = std::shared_ptr<LiteCompileResult>;
-}
-}
+}  // namespace infer
+}  // namespace mindspore
 
 #endif
-
diff --git a/mindspore/lite/src/extendrt/backend/lite/compile_result_builder.cc b/mindspore/lite/src/extendrt/backend/lite/compile_result_builder.cc
index 4cb1e61482c..0a9de868cab 100644
--- a/mindspore/lite/src/extendrt/backend/lite/compile_result_builder.cc
+++ b/mindspore/lite/src/extendrt/backend/lite/compile_result_builder.cc
@@ -22,15 +22,10 @@
 #include "ops/core_ops.h"
 #include "ops/op_name.h"
 #include "utils/ms_utils_secure.h"
-#include "include/api/status.h"
 
 namespace mindspore {
 namespace infer {
 namespace {
-static std::string GetTensorName(const std::string &node_name, const size_t index = 0) {
-  return node_name + "_o" + std::to_string(index);
-}
-
 template <typename T>
 bool IsContain(const std::vector<T> &vec, T element) {
   for (auto iter = vec.begin(); iter != vec.end(); iter++) {
@@ -40,13 +35,13 @@ bool IsContain(const std::vector<T> &vec, T element) {
   }
   return false;
 }
-}
+}  // namespace
 
 class TupleNodeSkiper {};
 void LiteCompileResultBuilder::RearrangeInputTensor(LiteTensorDelegatePtr dst_idx, LiteTensorDelegatePtr src_idx) {
   for (auto &lite_node : graph_->nodes_) {
     if (lite_node == nullptr) {
-      continue;
+      continue;
     }
     for (size_t i = 0; i < lite_node->inputs_.size(); i++) {
       if (graph_->inputs_[i] == src_idx) {
@@ -59,33 +54,32 @@ void LiteCompileResultBuilder::RearrangeInputTensor(LiteTensorDelegatePtr dst_id
 bool LiteCompileResultBuilder::IsIsolateNode(size_t index) {
   if (index >= graph_->nodes_.size()) {
     MS_LOG(ERROR) << "`index` out of range. `index`:" << index << ", range:" << graph_->nodes_.size();
-    return false;
+    return false;
   }
   auto &lite_node = graph_->nodes_[index];
   if (lite_node == nullptr) {
     MS_LOG(ERROR) << "Node at `index` is nullptr";
-    return false;
+    return false;
   }
   return false;
 }
 
 bool LiteCompileResultBuilder::RemoveIsolateNode() {
-
+  return false;
 }
 
 bool LiteCompileResultBuilder::RemoveTensor(size_t index) {
-
+  return false;
 }
 
 bool LiteCompileResultBuilder::RemoveIsolateTensor() {
   for (auto iter = graph_->tensors_.begin(); iter != graph_->tensors_.end();) {
-
   }
   return false;
 }
 
 bool LiteCompileResultBuilder::HandleMakeSeqNode() {
-  for (auto iter = graph_->nodes_.begin(); iter != graph_->nodes_.end(); ) {
+  for (auto iter = graph_->nodes_.begin(); iter != graph_->nodes_.end();) {
     auto &node = *iter;
     if (node->type_ != prim::kMakeTuple && node->type_ != prim::kMakeList) {
       iter++;
@@ -94,7 +88,8 @@ bool LiteCompileResultBuilder::HandleMakeSeqNode() {
     MS_LOG(INFO) << "Handleing make sequence node: " << node->name_;
     auto tensor_number = node->inputs_.size();
     if (tensor_number != node->outputs_.size()) {
-      MS_LOG(ERROR) << "MakeSequence node should has same number of inputs and outputs, but got " << tensor_number << " inputs and " << node->outputs_.size() << " outputs.";
+      MS_LOG(ERROR) << "MakeSequence node should has same number of inputs and outputs, but got " << tensor_number
+                    << " inputs and " << node->outputs_.size() << " outputs.";
       return false;
     }
     for (size_t i = 0; i < tensor_number; i++) {
@@ -106,7 +101,7 @@ bool LiteCompileResultBuilder::HandleMakeSeqNode() {
 }
 
 bool LiteCompileResultBuilder::HandleDependNode() {
-  for (auto iter = graph_->nodes_.begin(); iter != graph_->nodes_.end(); ) {
+  for (auto iter = graph_->nodes_.begin(); iter != graph_->nodes_.end();) {
     auto &node = *iter;
     if (node->type_ != prim::kDepend) {
       iter++;
@@ -115,11 +110,11 @@ bool LiteCompileResultBuilder::HandleDependNode() {
     MS_LOG(INFO) << "Handleing Depend node: " << node->name_;
     if (node->inputs_.size() != 2) {
       MS_LOG(ERROR) << "Depend node should has 2 inputs, but got " << node->inputs_.size();
-      return false;
+      return false;
     }
     if (node->outputs_.size() != 1) {
       MS_LOG(ERROR) << "Depend node should has 1 outputs, but got " << node->outputs_.size();
-      return false;
+      return false;
     }
     this->RearrangeInputTensor(node->inputs_[0], node->outputs_[0]);
     iter = graph_->nodes_.erase(iter);
@@ -128,11 +123,9 @@ bool LiteCompileResultBuilder::HandleDependNode() {
 }
 
 bool LiteCompileResultBuilder::HandleSeqGetItemNode() {
-  for (auto iter = graph_->nodes_.begin(); iter != graph_->nodes_.end(); ) {
+  for (auto iter = graph_->nodes_.begin(); iter != graph_->nodes_.end();) {
     auto &node = *iter;
-    if (node->type_ != prim::kTupleGetItem &&
-        node->type_ != prim::kListGetItem &&
-        node->type_ != prim::kArrayGetItem &&
+    if (node->type_ != prim::kTupleGetItem && node->type_ != prim::kListGetItem && node->type_ != prim::kArrayGetItem &&
         node->type_ != prim::kSliceGetItem) {
       iter++;
       continue;
    }
@@ -140,11 +133,11 @@ bool LiteCompileResultBuilder::HandleSeqGetItemNode() {
     MS_LOG(INFO) << "Handleing GetItem node: " << node->name_;
     if (node->inputs_.size() != 2) {
       MS_LOG(ERROR) << "GetItem node should has 2 inputs, but got " << node->inputs_.size();
-      return false;
+      return false;
     }
     if (node->outputs_.size() != 1) {
       MS_LOG(ERROR) << "GetItem node should has 1 outputs, but got " << node->outputs_.size();
-      return false;
+      return false;
     }
     this->RearrangeInputTensor(node->inputs_[0], node->outputs_[0]);
     iter = graph_->nodes_.erase(iter);
@@ -154,6 +147,8 @@ bool LiteCompileResultBuilder::HandleSeqGetItemNode() {
 
 LiteCompileResultPtr LiteCompileResultBuilder::Build(const GraphSegmentPtr graph_segment) {
   graph_ = std::make_shared<LiteCompileResult>();
+  // build graph inputs
+  // convert nodes
   for (auto &node : graph_segment->nodes_) {
     if (!utils::isa<CNodePtr>(node)) {
       continue;
@@ -161,10 +156,11 @@ LiteCompileResultPtr LiteCompileResultBuilder::Build(const GraphSegmentPtr graph
     auto lite_node = CreateNode(utils::cast<CNodePtr>(node));
     if (lite_node == nullptr) {
       MS_LOG(ERROR) << "Create lite node from cnode failed : " << node;
-      return nullptr;
+      return nullptr;
     }
     graph_->nodes_.emplace_back(lite_node);
   }
+  // build graph outputs
   return graph_;
 }
 
@@ -172,7 +168,7 @@
 namespace {
 class DataInfo;
 using DataInfoPtr = std::shared_ptr<DataInfo>;
 class DataInfo {
-public:
+ public:
   DataInfo() = default;
   ~DataInfo() = default;
@@ -199,7 +195,7 @@ public:
     auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
     if (tensor_info == nullptr) {
       MS_LOG(ERROR) << "Cast default-param to tensor failed.";
-      return nullptr;
+      return nullptr;
     }
     auto data_info = std::make_shared<DataInfo>();
     data_info->data_type_ = data_type;
@@ -226,7 +222,8 @@ public:
     }
     auto abstract_tensor = value_abstract->cast<abstract::AbstractTensorPtr>();
     if (abstract_tensor == nullptr) {
-      MS_LOG(ERROR) << "Abstract of tensor type value-node should be abstract tensor, " << value_node->fullname_with_scope();
+      MS_LOG(ERROR) << "Abstract of tensor type value-node should be abstract tensor, "
+                    << value_node->fullname_with_scope();
       return nullptr;
     }
     ShapeVector shape_vector;
@@ -234,7 +231,7 @@ public:
     auto ret = GetDTAndShapeFromAbTensor(abstract_tensor, &data_type, &shape_vector);
     if (ret != kSuccess) {
       MS_LOG(ERROR) << "Get data type and shape from value node failed.";
-      return nullptr;
+      return nullptr;
     }
     auto value = value_node->value();
     if (value == nullptr) {
@@ -245,7 +242,7 @@ public:
     if (value == nullptr) {
       MS_LOG(ERROR) << "Value of tensor-type value-node is not a Tensor, " << value_node->fullname_with_scope();
       return nullptr;
-    }
+    }
     auto data_info = std::make_shared<DataInfo>();
     data_info->is_const_ = true;
     data_info->data_type_ = data_type;
@@ -254,7 +251,8 @@ public:
     // process weight tensor
     data_info->data_.resize(data->Size());
     if (data->Size() > 0) {
-      if (EOK != common::huge_memcpy(data_info->data_.data(), data_info->data_.size(), reinterpret_cast<uint8_t *>(data->data_c()), data->Size())) {
+      if (EOK != common::huge_memcpy(data_info->data_.data(), data_info->data_.size(),
+                                     reinterpret_cast<uint8_t *>(data->data_c()), data->Size())) {
         MS_LOG(ERROR) << "memcpy_s failed.";
         return nullptr;
       }
@@ -344,10 +342,8 @@ public:
       return nullptr;
     }
     TypeId number_type = data->number_type();
-    static const std::unordered_map<TypeId, TypeId> TypeToTypeMap =
-      {{kNumberTypeInt, kNumberTypeInt32},
-       {kNumberTypeUInt, kNumberTypeUInt32},
-       {kNumberTypeFloat, kNumberTypeFloat32}};
+    static const std::unordered_map<TypeId, TypeId> TypeToTypeMap = {
+      {kNumberTypeInt, kNumberTypeInt32}, {kNumberTypeUInt, kNumberTypeUInt32}, {kNumberTypeFloat, kNumberTypeFloat32}};
     if (TypeToTypeMap.find(number_type) != TypeToTypeMap.end()) {
       number_type = TypeToTypeMap.at(number_type);
     }
@@ -363,7 +359,8 @@ public:
     MS_ASSERT(value_node != nullptr);
     auto value_seq = utils::cast<ValueSequencePtr>(value_node->value());
     if (value_seq == nullptr) {
-      MS_LOG(ERROR) << "Value of Sequence type value-node is not a ValueSequencePtr, " << value_node->fullname_with_scope();
+      MS_LOG(ERROR) << "Value of Sequence type value-node is not a ValueSequencePtr, "
+                    << value_node->fullname_with_scope();
       return nullptr;
     }
     auto data_info = std::make_shared<DataInfo>();
@@ -375,7 +372,7 @@ public:
       auto data = GetValue<std::vector<int32_t>>(value_seq);
       data_info->shape_ = {static_cast<int64_t>(data.size())};
       data_info->data_.resize(data.size() * sizeof(int32_t));
-      if (memcpy_s(data_info->data_.data(), data.size() * sizeof(int32_t), data.data(),
+      if (memcpy_s(data_info->data_.data(), data.size() * sizeof(int32_t), data.data(),
                    data.size() * sizeof(int32_t)) != EOK) {
         MS_LOG(ERROR) << "memcpy_s failed";
         return nullptr;
       }
@@ -385,11 +382,11 @@ public:
       auto data = GetValue<std::vector<int64_t>>(value_seq);
       data_info->shape_ = {static_cast<int64_t>(data.size())};
       data_info->data_.resize(data.size() * sizeof(int64_t));
-      if (memcpy_s(data_info->data_.data(), data.size() * sizeof(int64_t), data.data(),
+      if (memcpy_s(data_info->data_.data(), data.size() * sizeof(int64_t), data.data(),
                    data.size() * sizeof(int64_t)) != EOK) {
         MS_LOG(ERROR) << "memcpy_s failed";
         return nullptr;
-      }
+      }
     } else {
       MS_LOG(ERROR) << "only support integer value ValueSequence.";
       return nullptr;
    }
@@ -410,7 +407,7 @@ public:
     } else if (value->isa<Int32Imm>()) {
       return CreateFromInt32ImmValue(value_node);
     } else if (value->isa<Int64Imm>()) {
-      return CreateFromInt64ImmValue(value_node);
+      return CreateFromInt64ImmValue(value_node);
     } else if (value->isa<BoolImm>()) {
       return CreateFromBoolImmValue(value_node);
     } else if (value->isa()) {
@@ -423,17 +420,18 @@ public:
     }
   }
 
-private:
-  static StatusCode GetDTAndShapeFromAbTensor(const abstract::AbstractTensorPtr &abstract, TypeId *data_type, ShapeVector *shape_vector) {
+ private:
+  static StatusCode GetDTAndShapeFromAbTensor(const abstract::AbstractTensorPtr &abstract, TypeId *data_type,
+                                              ShapeVector *shape_vector) {
     MS_ASSERT(abstract != nullptr && data_type != nullptr && shape_vector != nullptr);
     if (abstract->element() == nullptr) {
       MS_LOG(ERROR) << "`element` of abstract is nullptr";
-      return kLiteError;
+      return kLiteError;
     }
     auto type_ptr = abstract->element()->GetTypeTrack();
     if (type_ptr == nullptr) {
       MS_LOG(ERROR) << "Type of abstract is nullptr";
-      return kLiteError;
+      return kLiteError;
     }
     *data_type = type_ptr->type_id();
     if (!utils::isa<abstract::ShapePtr>(abstract->BuildShape())) {
@@ -444,7 +442,8 @@ private:
     return kSuccess;
   }
 
-  static StatusCode GetDTAndShapeFromParameter(const ParameterPtr &param_node, TypeId *data_type, ShapeVector *shape_vector) {
+  static StatusCode GetDTAndShapeFromParameter(const ParameterPtr &param_node, TypeId *data_type,
+                                               ShapeVector *shape_vector) {
     MS_ASSERT(param_node != nullptr && data_type != nullptr && shape_vector != nullptr);
     auto abstract_base = param_node->abstract();
     if (abstract_base == nullptr) {
@@ -459,7 +458,7 @@ private:
     return GetDTAndShapeFromAbTensor(abstract_tensor, data_type, shape_vector);
   }
 
-public:
+ public:
   TensorCompressionType compress_type_{kNoCompression};
   int format_{0};
   TypeId data_type_{kTypeUnknown};
@@ -468,7 +467,7 @@ public:
   std::vector<uint8_t> data_{};
   void *data_ptr_{nullptr};
 };
-}
+}  // namespace
 
 StatusCode LiteCompileResultBuilder::AppendInputCNodeToInputs(const CNodePtr &cnode, LiteNodePtr lite_node) {
   if (cnode != nullptr) {
@@ -488,13 +487,14 @@ StatusCode LiteCompileResultBuilder::AppendInputCNodeToInputs(const CNodePtr &cn
     auto ret = graph_->AppendNodeInputTensor(lite_node, input_node_output);
     if (!ret) {
       MS_LOG(ERROR) << "Append input tensor for node failed, node: " << lite_node->name_;
-      return kLiteError;
+      return kLiteError;
     }
   }
   return kSuccess;
 }
 
-StatusCode LiteCompileResultBuilder::AppendInputParameterToInputs(const ParameterPtr &param_node, LiteNodePtr lite_node) {
+StatusCode LiteCompileResultBuilder::AppendInputParameterToInputs(const ParameterPtr &param_node,
+                                                                  LiteNodePtr lite_node) {
   if (param_node != nullptr) {
     MS_LOG(ERROR) << "Input param_node is nullptr.";
     return kLiteParamInvalid;
   }
@@ -535,12 +535,13 @@ StatusCode LiteCompileResultBuilder::AppendInputParameterToInputs(const Paramete
   auto ret = graph_->AppendNodeInputTensor(lite_node, tensor);
   if (!ret) {
     MS_LOG(ERROR) << "Append input tensor for node failed, node: " << lite_node->name_;
-    return kLiteError;
+    return kLiteError;
   }
   return kSuccess;
 }
 
-StatusCode LiteCompileResultBuilder::AppendInputValueNodeToInputs(const ValueNodePtr &value_node, LiteNodePtr lite_node) {
+StatusCode LiteCompileResultBuilder::AppendInputValueNodeToInputs(const ValueNodePtr &value_node,
+                                                                  LiteNodePtr lite_node) {
   if (value_node != nullptr) {
     MS_LOG(ERROR) << "Input value_node is nullptr.";
     return kLiteParamInvalid;
   }
@@ -566,7 +567,7 @@ StatusCode LiteCompileResultBuilder::AppendInputValueNodeToInputs(const ValueNod
   auto ret = graph_->AppendNodeInputTensor(lite_node, tensor);
   if (!ret) {
     MS_LOG(ERROR) << "Append input tensor for node failed, node: " << lite_node->name_;
-    return kLiteError;
+    return kLiteError;
   }
   return kSuccess;
 }
@@ -578,9 +579,10 @@ LiteNodePtr LiteCompileResultBuilder::CreateNode(const CNodePtr cnode) {
   // attrs
   auto primitive = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
   if (primitive == nullptr) {
-    MS_LOG(ERROR) << "Node has no primitive, first input of cnode(" << cnode->fullname_with_scope() << ") is : " << cnode->input(0);
+    MS_LOG(ERROR) << "Node has no primitive, first input of cnode(" << cnode->fullname_with_scope()
+                  << ") is : " << cnode->input(0);
     return nullptr;
-  }
+  }
   static auto baseops_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
   auto baseops_creator_iter = baseops_fns.find(lite_node->type_);
   if (baseops_creator_iter == baseops_fns.end()) {
@@ -591,7 +593,7 @@ LiteNodePtr LiteCompileResultBuilder::CreateNode(const CNodePtr cnode) {
   lite_node->base_operator_ = baseops_creator(primitive);
   if (lite_node->base_operator_ == nullptr) {
     MS_LOG(ERROR) << "Create base-operator failed: " << cnode->fullname_with_scope();
-    return nullptr;
+    return nullptr;
   }
   // inputs
   for (size_t i = 1; i < cnode->size(); i++) {
@@ -604,7 +606,8 @@ LiteNodePtr LiteCompileResultBuilder::CreateNode(const CNodePtr cnode) {
     } else if (utils::isa<ValueNodePtr>(input)) {
       ret = this->AppendInputValueNodeToInputs(utils::cast<ValueNodePtr>(input), lite_node);
     } else {
-      MS_LOG(ERROR) << "Unsupported input node of cnode: " << input << ", current cnode: " << cnode->fullname_with_scope();
+      MS_LOG(ERROR) << "Unsupported input node of cnode: " << input
+                    << ", current cnode: " << cnode->fullname_with_scope();
       ret = false;
     }
     if (ret != kSuccess) {
@@ -616,16 +619,17 @@ LiteNodePtr LiteCompileResultBuilder::CreateNode(const CNodePtr cnode) {
   auto ret = BuildNodeOutputTensor(cnode, lite_node);
   if (!ret) {
     MS_LOG(ERROR) << "Create output tensors of cnode failed, cnode: " << cnode;
-    return nullptr;
+    return nullptr;
   }
   return lite_node;
 }
 
-LiteTensorDelegatePtr LiteCompileResultBuilder::CreateTensorFromAbstractTensor(const abstract::AbstractBasePtr &abstract, const std::string &tensor_name) {
-  auto abs_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
+LiteTensorDelegatePtr LiteCompileResultBuilder::CreateTensorFromAbstractTensor(
+  const abstract::AbstractBasePtr &abstract, const std::string &tensor_name) {
+  auto abs_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
   if (abs_tensor == nullptr) {
     MS_LOG(ERROR) << "Input abstract is not a AbstractTensor.";
-    return nullptr;
+    return nullptr;
   }
   auto typePtr = abs_tensor->element()->GetTypeTrack();
   auto type_id = typePtr->type_id();
@@ -674,5 +678,5 @@ bool LiteCompileResultBuilder::BuildNodeOutputTensor(const CNodePtr &cnode, Lite
   }
   return true;
 }
-}
-}
+}  // namespace infer
+}  // namespace mindspore
diff --git a/mindspore/lite/src/extendrt/backend/lite/compile_result_builder.h b/mindspore/lite/src/extendrt/backend/lite/compile_result_builder.h
index 0fdbc272892..cdb1251a131 100644
--- a/mindspore/lite/src/extendrt/backend/lite/compile_result_builder.h
+++ b/mindspore/lite/src/extendrt/backend/lite/compile_result_builder.h
@@ -24,21 +24,23 @@
 #include "src/extendrt/backend/lite/tensor_delegate.h"
 #include "abstract/abstract_value.h"
 #include "ir/anf.h"
+#include "include/api/status.h"
 
 namespace mindspore {
 namespace infer {
 class LiteCompileResultBuilder {
-public:
+ public:
   LiteCompileResultBuilder() = default;
   ~LiteCompileResultBuilder() = default;
 
   LiteCompileResultPtr Build(const GraphSegmentPtr graph_segment);
 
-private:
+ private:
   LiteNodePtr CreateNode(const CNodePtr cnode);
   StatusCode AppendInputCNodeToInputs(const CNodePtr &cnode, LiteNodePtr lite_node);
   StatusCode AppendInputParameterToInputs(const ParameterPtr &param_node, LiteNodePtr lite_node);
   StatusCode AppendInputValueNodeToInputs(const ValueNodePtr &value_node, LiteNodePtr lite_node);
-  LiteTensorDelegatePtr CreateTensorFromAbstractTensor(const abstract::AbstractBasePtr &abstract, const std::string &tensor_name = "");
+  LiteTensorDelegatePtr CreateTensorFromAbstractTensor(const abstract::AbstractBasePtr &abstract,
+                                                       const std::string &tensor_name = "");
   bool BuildNodeOutputTensor(const CNodePtr &cnode, LiteNodePtr lite_node);
 
   std::vector<LiteTensorDelegatePtr> FindAllInputTensors();
@@ -55,13 +57,11 @@ private:
   bool HandleMakeSeqNode();
   bool HandleDependNode();
 
-private:
+ private:
   LiteCompileResultPtr graph_ = nullptr;
   std::set<std::string> input_names_{};
 };
-}
-}
+}  // namespace infer
+}  // namespace mindspore
 
 #endif
-
-
--
Gitee