From 39927f38bc81431bf3321d3ebb9ae8acecd87da0 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Fri, 17 Dec 2021 09:47:29 +0800 Subject: [PATCH 01/34] add mindspore rewrite --- .../python/mindspore/rewrite/__init__.py | 5 + .../python/mindspore/rewrite/ast_base.py | 9 + .../python/mindspore/rewrite/ast_rewriter.py | 68 ++ .../python/mindspore/rewrite/ast_unparser.py | 137 +++++ .../mindspore/rewrite/ast_unparser_test.py | 57 ++ .../python/mindspore/rewrite/globals.log | 1 + mindspore/python/mindspore/rewrite/graph.py | 416 +++++++++++++ .../mindspore/rewrite/legacy/ast_utils.py | 52 ++ .../mindspore/rewrite/legacy/code_analysis.py | 133 ++++ .../python/mindspore/rewrite/legacy/graph.py | 73 +++ .../mindspore/rewrite/legacy/python_node.py | 57 ++ mindspore/python/mindspore/rewrite/lenet.py | 61 ++ .../python/mindspore/rewrite/namespace.py | 126 ++++ mindspore/python/mindspore/rewrite/node.py | 173 ++++++ mindspore/python/mindspore/rewrite/parser.py | 580 ++++++++++++++++++ .../mindspore/rewrite/pattern_engine.py | 348 +++++++++++ .../rewrite/pattern_engine_match_test.py | 191 ++++++ .../mindspore/rewrite/pattern_engine_test.py | 176 ++++++ .../python/mindspore/rewrite/rewriter.py | 70 +++ mindspore/python/mindspore/rewrite/test.py | 23 + .../python/mindspore/rewrite/test_network.py | 121 ++++ mindspore/python/mindspore/rewrite/ut.sh | 34 + requirements.txt | 2 + 23 files changed, 2913 insertions(+) create mode 100644 mindspore/python/mindspore/rewrite/__init__.py create mode 100644 mindspore/python/mindspore/rewrite/ast_base.py create mode 100644 mindspore/python/mindspore/rewrite/ast_rewriter.py create mode 100644 mindspore/python/mindspore/rewrite/ast_unparser.py create mode 100644 mindspore/python/mindspore/rewrite/ast_unparser_test.py create mode 100644 mindspore/python/mindspore/rewrite/globals.log create mode 100644 mindspore/python/mindspore/rewrite/graph.py create mode 100644 mindspore/python/mindspore/rewrite/legacy/ast_utils.py create mode 100644 mindspore/python/mindspore/rewrite/legacy/code_analysis.py create mode 100644 mindspore/python/mindspore/rewrite/legacy/graph.py create mode 100644 mindspore/python/mindspore/rewrite/legacy/python_node.py create mode 100644 mindspore/python/mindspore/rewrite/lenet.py create mode 100644 mindspore/python/mindspore/rewrite/namespace.py create mode 100644 mindspore/python/mindspore/rewrite/node.py create mode 100644 mindspore/python/mindspore/rewrite/parser.py create mode 100644 mindspore/python/mindspore/rewrite/pattern_engine.py create mode 100644 mindspore/python/mindspore/rewrite/pattern_engine_match_test.py create mode 100644 mindspore/python/mindspore/rewrite/pattern_engine_test.py create mode 100644 mindspore/python/mindspore/rewrite/rewriter.py create mode 100644 mindspore/python/mindspore/rewrite/test.py create mode 100644 mindspore/python/mindspore/rewrite/test_network.py create mode 100644 mindspore/python/mindspore/rewrite/ut.sh diff --git a/mindspore/python/mindspore/rewrite/__init__.py b/mindspore/python/mindspore/rewrite/__init__.py new file mode 100644 index 00000000000..2a3665186ba --- /dev/null +++ b/mindspore/python/mindspore/rewrite/__init__.py @@ -0,0 +1,5 @@ +from .graph import Graph +from .node import Node, NodeType +from .pattern_engine import PatternEngine, PatternNode, PlaceHolderNode + +__all__ = ["Graph", "Node", "NodeType", "PatternEngine", "PatternNode", "PlaceHolderNode"] diff --git a/mindspore/python/mindspore/rewrite/ast_base.py b/mindspore/python/mindspore/rewrite/ast_base.py new file mode 100644 index 
00000000000..231ff32eb8d --- /dev/null +++ b/mindspore/python/mindspore/rewrite/ast_base.py @@ -0,0 +1,9 @@ +import ast + + +class BaseNodeTransformer(ast.NodeTransformer): + pass + + +class BaseNodeVisitor(ast.NodeVisitor): + pass diff --git a/mindspore/python/mindspore/rewrite/ast_rewriter.py b/mindspore/python/mindspore/rewrite/ast_rewriter.py new file mode 100644 index 00000000000..6fca6cc5978 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/ast_rewriter.py @@ -0,0 +1,68 @@ +import ast +from .ast_base import BaseNodeTransformer, BaseNodeVisitor + + +class NodeRemover(BaseNodeTransformer): + def __init__(self): + self.line_no = [] + self.node_name = [] + + def remove_node_by_lineno(self): + pass + + def remove_node_by_name(self): + pass + + +class NodeReplacer(BaseNodeTransformer): + def replace_if(self): + pass + + def replace_assign(self): + pass + + +class NodeParser(BaseNodeVisitor): + def visit_all_children(self, node: ast.AST): + pass + + def visit_if(self, node: ast.AST): + pass + + def visit_assign(self, node: ast.AST): + line_no = node.lineno + target = node.targets[0] + + pass + + def visit_call(self, node: ast.AST): + pass + + def visit_attribute(self, node: ast.AST): + pass + + +class CodeAnalyzer(BaseNodeVisitor): + def visit_if(self, node: ast.AST): + pass + + def visit_assign(self, node: ast.AST): + pass + + def visit_call(self, node: ast.AST): + pass + + def visit_attribute(self, node: ast.AST): + pass + + +class UpdateLineCol(BaseNodeVisitor): + pass + + +class NodeModifier(BaseNodeVisitor): + def modify_attribute(self, node: ast.AST): + pass + + def modify_assign(self, node: ast.AST): + pass diff --git a/mindspore/python/mindspore/rewrite/ast_unparser.py b/mindspore/python/mindspore/rewrite/ast_unparser.py new file mode 100644 index 00000000000..96b28d2db15 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/ast_unparser.py @@ -0,0 +1,137 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""
+Unparse AST source back into an nn.Cell class.
+"""
+import inspect
+import logging
+import copy
+
+import mindspore.nn as nn
+
+
+class ASTUnparser:
+    """
+    Unparse an AST source string back into an nn.Cell class.
+
+    This does four things:
+
+    - replace the original class name with _transformed_cell_name in the ast-unparsed source code
+    - get the original implementation source code of the class (for dependencies)
+    - insert the transformed source code into the original code -> _modified_source_code
+    - exec the _modified_source_code and get the class object
+
+    Example:
+
+    >>> from mindspore.rewrite.ast_unparser import ASTUnparser
+    >>> unparser = ASTUnparser(LeNet5, source_code)
+    >>> res_class = unparser.get_res_cell()  # then you can use res_class the same as LeNet5
+
+    """
+
+    def __init__(self, network: nn.Cell, source_code: str):
+        self._network = network
+        self._source_code = source_code  # ast unparse result
+        self._network_original_file_path = inspect.getfile(self._network)
+        self._original_source_code = None
+        self._modified_source_code = None  # resulting source code
+        self._transformed_cell_name = 'TransfromedCellThisNameCanNotBeSameAsOrigin'
+
+    def _get_original_source_code(self):
+        """
+        get the source code of the original file (only the target part is rewritten);
+        used for collecting dependency modules and functions
+        """
+        with open(self._network_original_file_path, 'r') as original_f:
+            self._original_source_code = original_f.read()
+
+    def print_original_source_code(self):
+        logging.info(self._original_source_code)
+
+    def print_res_source_code(self):
+        logging.info(self._modified_source_code)
+
+    def _replace_class_name(self):
+        """
+        replace the class name in self._source_code with self._transformed_cell_name
+        method:
+            first find the target line, which looks like 'class Net(...):',
+            then get the name 'Net' via target_line.split(' ')[1].split('(')[0]
+        """
+        class_name = None
+
+        source_code_lines = self._source_code.split('\n')
+        target_line = None
+        for line in source_code_lines:
+            if 'class' in line:
+                target_line = line
+                break
+
+        if target_line:
+            class_name = target_line.split(' ')[1].split('(')[0]
+
+        self._source_code = self._source_code.replace(class_name, self._transformed_cell_name)
+
+    def _insert_source_code(self):
+        """
+        Insert the transformed source code into the original source code and delete
+        the old part (for dependencies); the final result can then be written into a python script
+        """
+        or_code_lines = self._original_source_code.splitlines()
+
+        _modified_source_code_lines = []
+        meet_target_class_flag = False
+        # delete the old part in the original code:
+        # when first meeting 'class targetNet():', set meet_target_class_flag to True,
+        # then when meeting another func, class or other top-level code, set it back to False
+        for line in or_code_lines:
+            # TODO mind corner cases
+            if meet_target_class_flag and not line.startswith(' ') and line != '':
+                meet_target_class_flag = False
+
+            # delete the "if __name__ == '__main__'" block
+            if 'main' in line and 'if' in line and 'name' in line:
+                meet_target_class_flag = True
+
+            # delete the target class definition if it exists
+            if self._transformed_cell_name in line and 'class' in line:
+                meet_target_class_flag = True
+
+            if not meet_target_class_flag:
+                _modified_source_code_lines.append(line)
+
+        # insert source code
+        _modified_source_code_lines.extend(self._source_code.splitlines())
+        self._modified_source_code = '\n'.join(_modified_source_code_lines)
+
+    def get_res_cell(self) -> nn.Cell:
+        # replace the original class name with self._transformed_cell_name
+        self._replace_class_name()
+
+        # get original source code
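+        # (re-reading the original file keeps module-level imports and any
+        # helper functions the class depends on available for the exec step)
+        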
self._get_original_source_code() + + # insert source code + self._insert_source_code() + + # get res + # TODO note that _modified_source_code_lines must be executed one time, + # if _modified_source_code_lines includes code like main, it may go wrong + global_dict = copy.copy(self.__init__.__globals__) + code_obj = compile(self._modified_source_code, "", "exec") + exec(code_obj, global_dict) + + return global_dict[self._transformed_cell_name] diff --git a/mindspore/python/mindspore/rewrite/ast_unparser_test.py b/mindspore/python/mindspore/rewrite/ast_unparser_test.py new file mode 100644 index 00000000000..f238184c14c --- /dev/null +++ b/mindspore/python/mindspore/rewrite/ast_unparser_test.py @@ -0,0 +1,57 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +ast unparser ut +""" +import unittest +import ast +import astunparse +import inspect + +from mindspore.rewrite.ast_unparser import ASTUnparser +from mindspore.nn import Cell, Conv2d, Dense, ReLU, MaxPool2d, Flatten + + +class LeNet5(Cell): + + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = Conv2d(1, 6, 5) + self.fc1 = Dense(((16 * 5) * 5), 120) + self.relu = ReLU() + self.max_pool2d = MaxPool2d(kernel_size=2, stride=2) + self.flatten = Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + return x + + +class ASTUnparserTestCase(unittest.TestCase): + def test_unparse(self): + source_code = astunparse.unparse(ast.parse(inspect.getsource(LeNet5))) + unparser = ASTUnparser(LeNet5, source_code) + res = unparser.get_res_cell() + self.assertEqual(type(LeNet5), type(res)) + + +if __name__ == "__main__": + unittest.main() diff --git a/mindspore/python/mindspore/rewrite/globals.log b/mindspore/python/mindspore/rewrite/globals.log new file mode 100644 index 00000000000..000728d2a72 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/globals.log @@ -0,0 +1 @@ +input globals: {'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <_frozen_importlib_external.SourceFileLoader object at 0x7f6a9a4f04d0>, '__spec__': None, '__annotations__': {}, '__builtins__': , '__file__': 'test_app.py', '__cached__': None, 'Graph': , 'parse': , 'ControlSimpleIf': } diff --git a/mindspore/python/mindspore/rewrite/graph.py b/mindspore/python/mindspore/rewrite/graph.py new file mode 100644 index 00000000000..d08e8965aa4 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/graph.py @@ -0,0 +1,416 @@ +from collections import OrderedDict +import inspect +from types import FunctionType +import types +from typing import DefaultDict, Dict, List, Union +import ast +import astunparse +import astpretty + +import mindspore.nn as nn +from mindspore.ops.primitive import Primitive + +from .node import AttributeNode, ConstantNode, Node, NodeType, PlaceholderNode +from .parser import 
Parser
+
+class _node_list:
+    def __init__(self, graph) -> None:
+        self._graph = graph
+        self._nodes: List = []
+
+    def __iter__(self):
+        self._count = 0
+        self._node_num = len(self._graph._nodes)
+        self._visited = []
+        for node in self._graph._placeholders:
+            self._nodes.append(node)
+            self._visited.append(id(node))
+        return self
+
+    def __next__(self):
+        while self._nodes:
+            flag = True
+            for n in self._nodes[0].inputs:
+                if id(n) not in self._visited:
+                    flag = False
+                    break
+            if flag:
+                for n in self._nodes[0].outputs:
+                    if id(n) not in self._visited:
+                        self._nodes.append(n)
+                        self._visited.append(id(n))
+                n = self._nodes.pop(0)
+                return n
+            else:
+                n = self._nodes.pop(0)
+                self._nodes.append(n)
+                continue
+        raise StopIteration
+
+class Graph:
+    def __init__(self, network: nn.Cell):
+        self._name = network.__name__
+        self._network = network
+        self._ast_root: ast.AST = None
+        self._root = None
+        self._nodes: List = []
+        self._base_scope = network.__name__
+        self._parser = Parser(network)
+        self._placeholders: List[Node] = []
+        self._contant_nodes: List[ConstantNode] = []
+        self._param_default_value: Dict = {}
+        self._ast_function_root: Dict = {}
+        self._node_attributes: Dict = {}
+        self._subgraphs: Dict = {}
+
+    @property
+    def nodes(self) -> list:
+        """
+        Return the nodes of the graph for iteration. These nodes should also
+        include the sub-graphs built in __init__; this matters during pattern
+        matching.
+        """
+        return self._nodes
+
+    def set_root(self, root: Node):
+        self._root = root
+        self._nodes.clear()
+        queue = [root]
+        while len(queue) > 0:
+            cur_node = queue.pop(0)
+            for input in cur_node.inputs:
+                queue.append(input)
+            self._nodes.append(cur_node)
+        self._nodes.reverse()
+
+    def root(self):
+        return self._root
+
+    def create_ast(self):
+        network_str = inspect.getsource(self._network)
+        self._ast_root = ast.parse(network_str)
+
+    def print_ast(self):
+        astpretty.pprint(self._ast_root)
+
+    def bfs(self) -> _node_list:
+        return _node_list(self)
+
+    def find(self, full_name_with_scope: str) -> Node:
+        """
+        Look up a node by full_name_with_scope; return the matching node,
+        or None when nothing matches.
+        """
+        for node in self.nodes:
+            if node.name == full_name_with_scope:
+                return node
+
+        return None
+
+    def find_by_instance(self, Type) -> list:
+        """
+        Find nodes by their class.
+        """
+        pass
+
+    def visit_start_with_node(self, node: Node) -> _node_list:
+        """
+        Traverse the graph starting from the given node.
+        """
+        # TODO: _node_list expects a graph, so traversal starting from a
+        # single node is not really supported yet
+        return _node_list(node)
+
+    def get_function_root(self):
+        """
+        1. Walking all nodes with ast.walk to find FunctionDef nodes cannot tell
+           classes apart, but since only one cell is passed in there should be
+           no multi-class problem.
+        2. Iterating over the class body directly is more efficient.
+        """
+        for node in self._ast_root.body[0].body:
+            if isinstance(node, ast.FunctionDef):
+                self._ast_function_root[node.name] = node
+
+    def create_placeholder(self, ast_function: ast.FunctionDef):
+        ast_node = ast_function.args
+        args = self._parser.parse_arguments(ast_node)
+        for name, value in args.items():
+            new_node = PlaceholderNode(name, name, ast_node, default_value=value)
+            print("placeholder node:", new_node)
+            self._nodes.append(new_node)
+
+    def parse_init(self):
+        """
+        Parse the __init__ function to collect the attribute information of the
+        operators it defines.
+        """
+        if "__init__" not in self._network.__dict__.keys():
+            return
+        print("================= parse init function ========================")
+        self.create_placeholder(self._ast_function_root["__init__"])
+        self._parser.updete_closure_namespace(self._network.__init__)
+        for ast_node in self._ast_function_root["__init__"].body:
+            if isinstance(ast_node, ast.Expr):
+                continue
+
+            if isinstance(ast_node, ast.Assign):
+                new_node = self._parser.parse_init_assign(ast_node)
+                print("new node in init function", new_node)
+                self._node_attributes[new_node.name] = new_node
+
+        return
+
+    def parse_construct(self):
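+        """
+        Parse the construct function. Each top-level statement is dispatched to
+        a parse_<AstClassName> visitor on the Parser; the returned nodes are
+        bound to the attributes collected from __init__, renamed under the base
+        scope (with a numeric suffix when a name repeats), wired to their input
+        nodes, and appended to the graph.
+        """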
print("================= parse construct function start ========================") + self._parser.updete_closure_namespace(self._network.construct) + self.create_placeholder(self._ast_function_root["construct"]) + name_counts = {} #save the number of the variable, if the number is over 1,then modify the name - add a number as the name suffix + index = 0 + for ast_node in self._ast_function_root["construct"].body: + print(ast_node) + if isinstance(ast_node, ast.Expr): + continue + method = 'parse_' + ast_node.__class__.__name__ + visitor = getattr(self._parser, method, None) + + nodes, attribute_names = visitor(ast_node) + + for i in range(len(nodes)): + if attribute_names and attribute_names[i] in self._node_attributes.keys(): + print("defined in init function: ", attribute_names[i]) + nodes[i]._attribute = self._node_attributes[attribute_names[i]] + elif nodes[i].name.split(".")[-1] in dir(self._network): + nodes[i]._attribute._type = NodeType.call_method + nodes[i]._attribute._is_custom_define = True + print("self defined func: ", nodes[i].name) + elif self._parser.get_func_namesapce(nodes[i].name.split(".")[-1]): + class_, name_space_, is_custom_define_ = self._parser.get_func_namesapce(nodes[i].name.split(".")[-1]) + print("defined in other namespace") #must resolve the undefined symble + print ("class: ", class_, "name space: ", name_space_, "is custom define: ", is_custom_define_) + nodes[i]._attribute._is_custom_define = is_custom_define_ + nodes[i]._attribute._class = class_ + nodes[i]._attribute._type = NodeType.call_function + if is_custom_define_: + subgraph = self.parse_function(class_) + print("self defined subgraph: ", subgraph) + else: + print("undefined symbole ....") + + name = self._base_scope + "." + nodes[i].name.split(".")[-1] + if name in name_counts.keys(): + name_counts[name] += 1 + name = name + "_" + str(name_counts[name]) + else: + name_counts[name] = 0 + + nodes[i].name = name + nodes[i]._index = index + self._find_input_node(nodes[i]) + index += 1 + self._nodes.append(nodes[i]) + print("======================= construct nodes =========================") + for node in self._nodes: + print(node) + print("================= parse construct function end ========================") + + def parse_function(self, func: Union[ast.FunctionDef, FunctionType]): + if isinstance(func, FunctionType): + print("================= parse " + func.__name__ + " function ========================") + function_str = inspect.getsource(func) + ast_root = ast.parse(function_str) + astpretty.pprint (ast_root) + node = ast_root.body[0] + subgraph = FunctionGraph(func) #要区分类内还是类外方法 + else: + print("================= parse " + func.name + " function ========================") + node = func + subgraph = FunctionGraph(self._network.__dict__[node.name]) #要区分类内还是类外方法 + subgraph._name = node.name + subgraph._ast_root = node + subgraph.create_placeholder(node) + + index = 0 + for ast_node in subgraph._ast_root.body: + print(ast_node) + if isinstance(ast_node, ast.Expr): + continue + method = 'parse_' + ast_node.__class__.__name__ + visitor = getattr(self._parser, method, None) + + nodes, attribute_names = visitor(ast_node) + for i in range(len(nodes)): + nodes[i]._index = index + nodes[i].name = self._base_scope + "." + node.name + "." 
+ nodes[i].name + if attribute_names and attribute_names[i] in self._node_attributes.keys(): + nodes[i]._attribute = self._node_attributes[attribute_names[i]] + else: + print("can not find attribute") + + subgraph._find_input_node(nodes[i]) + subgraph._nodes.append(nodes[i]) + index += 1 + print(subgraph._name + " nodes: ") + for node in subgraph._nodes: + print (node) + print("====================== parse function end =====================") + return subgraph + + def parse_functions(self): + for name, ast_root in self._ast_function_root.items(): + if name == "__init__" or name == "construct": + continue + + subgraph = self.parse_function(ast_root) + print("name = ", name, "; subgraph = ", subgraph) + self._subgraphs["self." + name] = subgraph + + def _find_input_node(self, node: Node): + for arg in node._args: + flag = 0 + if isinstance(arg, int): + for n in self._contant_nodes: + if arg in n._targets: + node.inputs.append(n) + n.outputs.append(node) + flag = 1 + break + if flag == 0: + new_node = ConstantNode(arg) + new_node.outputs.append(node) + node.inputs.append(new_node) + self._contant_nodes.append(new_node) + continue + + for i in range(len(self._nodes) - 1, -1, -1): #需要反向查找节点 + print ("arg: ", arg, "; n.targets: ", self._nodes[i]._targets) + if arg in self._nodes[i]._targets: + node.inputs.append(self._nodes[i]) + self._nodes[i].outputs.append(node) + flag = 1 + break + + if flag == 0: + for n in self._placeholders: + if arg in n._target: + node.inputs.append(n) + n._outputs.append(node) + flag = 1 + break + + if arg in self._node_attributes.keys(): + node.inputs.append(self._node_attributes[arg]) + continue + + return + + def parse_attr_subgraph(self): + for name, node in self._node_attributes.items(): + print ("name: ", name, "; node: ", node) + if node._is_custom_define and isinstance(node._class, FunctionType): + print("is function") + subgraph = self.parse_function(node._class) + print("name = ", name, "; subgraph = ", subgraph) + self._subgraphs["self." + name] = subgraph + elif node._is_custom_define and issubclass(node._class, (nn.Cell)): + print ("this node is Cell class") + graph = CellGraph(node._class) + graph.create_ast() + graph.print_ast() + graph.get_function_root() + graph.parse_init() + graph.parse_construct() + for node in graph._nodes: + node.name = self._base_scope + "." + name.split(".")[-1] + "." 
+ node.name.split(".")[-1] + print (name + " subgraph node: ", repr(node)) + self._subgraphs[name] = graph + elif node._is_custom_define and issubclass(node._class, Primitive): + print ("is primitive") + else: + print ("other types") + + def remove_node(self, node: Node): + index = node._index + self._ast_function_root["construct"].body.pop(index) + ast.fix_missing_locations(self._ast_function_root["construct"]) + node.inputs[0].outputs = node.outputs + node.outputs[0].inputs = node.inputs + self._nodes.remove(node) + + def replace_node(self, src_nodes, dst_node): + if isinstance(src_nodes, list): + # redirect edges + appends = [] + for src_node in src_nodes: + for output in src_node.outputs: + if output not in src_nodes: + appends.append(output) + dst_node.outputs = appends + for append in appends: + new_inputs = [] + for input in append.inputs: + if input in src_nodes: + new_inputs.append(dst_node) + else: + new_inputs.append(input) + append.inputs = new_inputs + + appends = [] + for src_node in src_nodes: + for input in src_node.inputs: + if input not in src_nodes: + appends.append(input) + dst_node.inputs = appends + for append in appends: + new_inputs = [] + for output in append.outputs: + if output in src_nodes: + new_inputs.append(dst_node) + else: + new_inputs.append(output) + append.outputs = new_inputs + + # update nodes + for i in range(len(self._nodes) - 1, -1, -1): + cur_node = self._nodes[i] + for node in src_nodes: + if cur_node is node: + self._nodes.pop(i) + break + self._nodes.append(dst_node) + else: + pass + + def insert_node(self, node: Node): + """ + 需要知道插入的位置,在body中的下标, attr_name是插入的属性名称或者是已存在的属性名称 + """ + new_init_node = ast.Assign() + self._ast_function_root["__init__"].body.append(new_init_node) + + new_node = ast.Assign() + new_node.lineno = 0 + new_node.targets = None + new_node.value = None + index = node.inputs[0]._index + node.outputs[0].inputs[0] = node + node.inputs[0].outputs[0] = node + self.nodes.insert(index, node) + + self._ast_function_root["construct"].body.insert(index, new_node) + ast.fix_missing_locations(self._ast_function_root["construct"]) + + def _insert_attribute(self): + """ + 在init中插入一个属性,供插入节点时使用 + """ + pass + + def check(self): + pass + + @property + def python_code(self): + return astunparse.unparse(self._ast_root) + + @property + def convert_to_cell(self) -> nn.Cell: + pass + + def print_graph(self): + pass + + def deep_copy(self): + pass diff --git a/mindspore/python/mindspore/rewrite/legacy/ast_utils.py b/mindspore/python/mindspore/rewrite/legacy/ast_utils.py new file mode 100644 index 00000000000..6c3ef73919f --- /dev/null +++ b/mindspore/python/mindspore/rewrite/legacy/ast_utils.py @@ -0,0 +1,52 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+import ast
+
+
+def get_func_node(root_node, func_name):
+    for node in ast.walk(root_node):
+        if isinstance(node, ast.FunctionDef) and node.name == func_name:
+            return node
+    return None
+
+
+class Assign:
+    def __init__(self, func_name, targets, args):
+        self.__func_name = func_name
+        self.__targets = targets
+        self.__args = args
+
+
+def resolve_func_assign(assign: ast.Assign):
+    args = []
+    if not isinstance(assign.value, ast.Call):
+        return None
+    call = assign.value
+    for arg in call.args:
+        if isinstance(arg, ast.Name):
+            # store the argument name, mirroring how targets are resolved below
+            args.append(arg.id)
+    func_name = call.func.attr
+    outputs = []
+    for output in assign.targets:
+        if isinstance(output, ast.Name):
+            outputs.append(output.id)
+    return Assign(func_name, outputs, args)
+
+
+def resolve_all_func_assign_nodes(func_node: ast.FunctionDef):
+    for node in func_node.body:
+        if isinstance(node, ast.Assign):
+            yield resolve_func_assign(node)
diff --git a/mindspore/python/mindspore/rewrite/legacy/code_analysis.py b/mindspore/python/mindspore/rewrite/legacy/code_analysis.py
new file mode 100644
index 00000000000..c51420a069f
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite/legacy/code_analysis.py
@@ -0,0 +1,133 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ + +import ast +from enum import Enum +from typing import Any +from _ast import AST, Assign, FunctionDef, ClassDef, Attribute, Name + + +class ASTType(Enum): + TypeUnknown = 0 + TypeClassDef = 1 + TypeFunctionDef = 2 + TypeAttribute = 3 + TypeName = 4 + TypeCall = 5 + TypeAssign = 6 + + def __new__(cls, value: str): + member = object.__new__(cls) + if value == "ClassDef": + member._value_ = 1 + elif value == "FunctionDef": + member._value_ = 2 + elif value == "Attribute": + member._value_ = 3 + elif value == "Name": + member._value_ = 4 + elif value == "Call": + member._value_ = 5 + elif value == "Assign": + member._value_ = 6 + else: + member._value_ = 0 + return member + + +class Scope: + def __init__(self, scope_type: ASTType, lineno: int): + self._type = scope_type + self._lineno = lineno + self._name = str(scope_type) + "-" + str(lineno) + + def __str__(self): + return self._name + + +class FullScope: + def __init__(self): + self._scopes: [Scope] = [] + + def enter_scope(self, scope: Scope): + print("--------------------------------------Enter scope ", scope) + self._scopes.append(scope) + + def enter_ast_class_scope(self, scope: AST): + self.enter_scope(Scope(ASTType.TypeClassDef, scope.lineno)) + + def enter_ast_function_scope(self, scope: AST): + self.enter_scope(Scope(ASTType.TypeFunctionDef, scope.lineno)) + + def enter_ast_attribute_scope(self, scope: AST): + self.enter_scope(Scope(ASTType.TypeAttribute, scope.lineno)) + + def enter_ast_name_scope(self, scope: AST): + self.enter_scope(Scope(ASTType.TypeName, scope.lineno)) + + def enter_ast_call_scope(self, scope: AST): + self.enter_scope(Scope(ASTType.TypeCall, scope.lineno)) + + def enter_ast_assign_scope(self, scope: AST): + self.enter_scope(Scope(ASTType.TypeAssign, scope.lineno)) + + def enter_ast_scope(self, scope: AST): + pass + self.enter_scope(Scope(ASTType(scope.__class__.__name__), scope.lineno)) + + def exit_scope(self): + print("--------------------------------------Exit scope ", self._scopes.pop()) + + def cur_scope_type(self) -> str: + return self._scopes[len(self._scopes) - 1] + + +class CodeAnalyzer(ast.NodeVisitor): + def __init__(self): + super(CodeAnalyzer, self).__init__() + self._scope: FullScope = FullScope() + + # def visit(self, node: AST) -> Any: + # return super(CodeAnalyzer, self).visit(node) + + def visit_Assign(self, node: Assign) -> Any: + self._scope.enter_ast_scope(node) + print("visit assign : ", node.lineno) + super(CodeAnalyzer, self).generic_visit(node) + self._scope.exit_scope() + + def visit_Attribute(self, node: Attribute) -> Any: + self._scope.enter_ast_scope(node) + print("visit attribute : ", self._scope.cur_scope_type(), node.attr) + super(CodeAnalyzer, self).generic_visit(node) + self._scope.exit_scope() + + def visit_Name(self, node: Name) -> Any: + self._scope.enter_ast_scope(node) + print("visit name : ", node.__class__) + super(CodeAnalyzer, self).generic_visit(node) + self._scope.exit_scope() + + def visit_FunctionDef(self, node: FunctionDef) -> Any: + self._scope.enter_ast_scope(node) + print("visit function : ", node.name) + super(CodeAnalyzer, self).generic_visit(node) + self._scope.exit_scope() + + def visit_ClassDef(self, node: ClassDef) -> Any: + self._scope.enter_ast_scope(node) + print("visit class : ", node.name) + super(CodeAnalyzer, self).generic_visit(node) + self._scope.exit_scope() diff --git a/mindspore/python/mindspore/rewrite/legacy/graph.py 
b/mindspore/python/mindspore/rewrite/legacy/graph.py new file mode 100644 index 00000000000..b2673637602 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/legacy/graph.py @@ -0,0 +1,73 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Graph.""" + +import dill +import ast +import astpretty +from python_node import PythonNode +from mindspore.nn.cell import Cell +from code_analysis import CodeAnalyzer + + +class Graph: + def __init__(self): + self._nodes: [Cell] = [] + self._source = "" + self._ast: ast = None + + # basic api + + def find_node(self, full_name_with_scope: str) -> Cell: ... + + # return inputs of node whose full_name_with_scope is full_name_with_scope + def node_inputs(self, full_name_with_scope: str) -> [Cell]: ... + + # return outputs of node whose full_name_with_scope is full_name_with_scope + def node_outputs(self, full_name_with_scope: str) -> [Cell]: ... + + # todo + # insert a cell into graph. only support b into a and c: a--c ==> a--b--c + # return node been inserted, return None if failed + def insert_node(self, new_node: Cell, input_nodes: [Cell], output_nodes: [Cell]) -> Cell: ... + + # remove a node from graph. only support b between a and c: a--b--c ==> a--c + # return node been removed, return None if failed + def remove_node(self, full_name_with_scope: str) -> Cell: ... + + # replace node whose name is full_name_with_scope with a new cell. + # new_node should has same inputs and outputs with node. + # return node been replaced, return None if failed + def replace_node(self, full_name_with_scope: str, new_node: Cell) -> Cell: ... + + # other api + def convert_to_cell(self) -> Cell: ... + + def deep_copy_graph(self): ... + + def print(self) -> str: ... + + def nodes(self) -> [PythonNode]: + return self._nodes + + @staticmethod + def build_from_cell(cell, args=None, **kwargs): + graph = Graph() + graph._source = dill.source.getsource(cell) + graph._ast = ast.parse(graph._source) + # code_analyzer = CodeAnalyzer() + # code_analyzer.visit(graph._ast) + astpretty.pprint(graph._ast) + return graph diff --git a/mindspore/python/mindspore/rewrite/legacy/python_node.py b/mindspore/python/mindspore/rewrite/legacy/python_node.py new file mode 100644 index 00000000000..0a2c0cbbc0b --- /dev/null +++ b/mindspore/python/mindspore/rewrite/legacy/python_node.py @@ -0,0 +1,57 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""PythonNode.""" + +from enum import Enum +from mindspore.nn.cell import Cell + + +class PythonNodeType(Enum): + unknown = 0 + placeholder = 1 # input + parameter = 2 # weight + value = 3 # const + call_cell = 4 # cell object + call_method = 5 # member of cell + call_function = 6 # primitive object + graph = 7 # sub-graph + output = 8 + + +class PythonNode: + def __init__(self, cell: Cell = None, node_type=PythonNodeType.unknown, name="", scope="", ast_node=None, + inputs: [] = None, + outputs: [] = None): + if inputs is None: + inputs = [] + if outputs is None: + outputs = [] + self._type = node_type + self._name = name + self._scope = scope + self._ast_node = ast_node + self._inputs = inputs + self._outputs = outputs + self._kwargs = {} + self._cell: Cell = cell + + def cell(self) -> Cell: + return self._cell + + def type(self) -> PythonNodeType: + return self._type + + def name(self) -> str: + return self._name diff --git a/mindspore/python/mindspore/rewrite/lenet.py b/mindspore/python/mindspore/rewrite/lenet.py new file mode 100644 index 00000000000..feb63521aee --- /dev/null +++ b/mindspore/python/mindspore/rewrite/lenet.py @@ -0,0 +1,61 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""LeNet.""" +import mindspore.nn as nn +from mindspore.common.initializer import Normal + + +class LeNet5(nn.Cell): + """ + Lenet network + + Args: + num_class (int): Number of classes. Default: 10. + num_channel (int): Number of channels. Default: 1. 
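+        include_top (bool): Whether to include the flatten and dense head layers. Default: True.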
+ + Returns: + Tensor, output tensor + Examples: + >>> LeNet(num_class=10) + + """ + + def __init__(self, num_class=10, num_channel=1, include_top=True): + super(LeNet5, self).__init__() + self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.include_top = include_top + if self.include_top: + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) + self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) + self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + if not self.include_top: + return x + x = self.flatten(x) + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.fc3(x) + return x diff --git a/mindspore/python/mindspore/rewrite/namespace.py b/mindspore/python/mindspore/rewrite/namespace.py new file mode 100644 index 00000000000..e7283431432 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/namespace.py @@ -0,0 +1,126 @@ +# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). +# +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Define the namespace of parse.""" + +import builtins + +from mindspore import log as logger + + +class Namespace: + """ + Base class of namespace for resolve variables. + + Args: + name (str): The namespace's name. + dicts (dict): A list of dict containing the namespace's variable. + """ + + def __init__(self, name, *dicts): + self.name = name + self.dicts = dicts + + def __contains__(self, name): + for d in self.dicts: + if name in d: + return True + return False + + def __getitem__(self, name): + for d in self.dicts: + if name in d: + return d[name] + raise NameError(name) + + def __repr__(self): + return f'Namespace:{self.name}' + + +class CellNamespace(Namespace): + """ + Namespace for Cell object. + + Args: + name (str): Valid module name, it can be imported. + """ + + def __init__(self, name): + mod_dict = vars(__import__(name, fromlist=['_'])) + builtins_dict = vars(builtins) + super().__init__(name, mod_dict, builtins_dict) + + def __getstate__(self): + return (self.name,) + + def __setstate__(self, state): + name, = state + mod_dict = vars(__import__(name, fromlist=['_'])) + builtins_dict = vars(builtins) + super().__init__(name, mod_dict, builtins_dict) + + +class ClosureNamespace(Namespace): + """ + Namespace for function closure. + + Args: + fn (Function): A python function. 
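+
+    Example (illustrative; ``outer`` and ``inner`` are hypothetical helpers):
+        >>> def outer():
+        ...     x = 1
+        ...     def inner():
+        ...         return x
+        ...     return inner
+        >>> ClosureNamespace(outer())['x']
+        1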
+ """ + + def __init__(self, fn): + name = f'{fn.__module__}..<{fn.__name__}>' + names = fn.__code__.co_freevars + cells = fn.__closure__ + ns = dict(zip(names, cells or ())) + super().__init__(name, ns) + + def __getitem__(self, name): + d, = self.dicts + try: + return d[name].cell_contents + except ValueError: + raise UnboundLocalError(name) + + +class ClassMemberNamespace(Namespace): + """ + Namespace of a class's closure. + + Args: + obj (Object): A python class object. + """ + + def __init__(self, obj): + self.__class_member_namespace__ = True + label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>' + super().__init__(label, obj) + + def __getitem__(self, name): + d, = self.dicts + if name == "self": + return d + if name == "namespace": + return self + try: + if hasattr(d, name): + return getattr(d, name) + return d.__dict__[name] + except ValueError: + raise UnboundLocalError(name) + except KeyError: + logger.info(f"'{d.__class__.__name__}' object has no attribute or method: '{name}', so will return None.") + raise AttributeError(name) diff --git a/mindspore/python/mindspore/rewrite/node.py b/mindspore/python/mindspore/rewrite/node.py new file mode 100644 index 00000000000..57f9be0edf0 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/node.py @@ -0,0 +1,173 @@ +import ast +from typing import Dict, List, Union + +import mindspore.nn as nn +from mindspore.ops.primitive import Primitive + +class NodeType(): + placeholder = 1 # 输入 + parameter = 2 # 权重 + constant = 3 # 常量 + call_cell = 4 # 预计cell的对象 + call_method = 5 # cell内部成员 + call_function = 6 # 基于primitive对象 + output = 7 # 输出 + +class AttributeNode: + def __init__(self, name="", type=NodeType.call_cell, class_=None, is_custom_define=False, attibute=None, constant_value=None) -> None: + self._name = name + self._type: NodeType = type + self._class = class_ + self._is_custom_define = is_custom_define + self._attribute: Dict = attibute + self._constant_value = constant_value + + @property + def name(self): + return self._name + + @name.setter + def name(self, name: str): + self._name = name + + @property + def type(self): + return self._type + + @type.setter + def type(self, type: NodeType): + self._type = type + + def __repr__(self): + return f"name: {self.name}; type: {self.type}; class: {self._class}; attribute: {self._attribute}; is_custom_define: {self._is_custom_define}; constant value: {self._constant_value}" + +class BaseNode: + def __init__(self, name="", targets=None, args=None, inputs: List = None): + """ + 创建一个节点时对应的属性怎么传进来,cell应该不涉及,primitive会有这种情况 + """ + self._name: str = name + self._attribute: AttributeNode = AttributeNode() + self._outputs: List[Node] = list() + if inputs is None: + self._inputs: List[Node] = list() + else: + self._inputs = inputs + self._targets: List[str] = targets #用来保存算子输出结果的名称,用来匹配算子输入名称 + self._args: List = args + + @property + def name(self) ->str: + return self._name + + @name.setter + def name(self, name: str): + self._name = name + + @property + def inputs(self) -> list(): + return self._inputs + + @inputs.setter + def inputs(self, nodes: list): + self._inputs = nodes + + @property + def outputs(self) -> List: + return self._outputs + + @outputs.setter + def outputs(self, nodes: list): + self._outputs = nodes + + def node_type(self) -> NodeType: + return self._attribute.type + +class Node(BaseNode): + def __init__(self, name="", targets=None, args=None, ast_node=None, instance: Union[nn.Cell, Primitive] = None, inputs: List = None): + """ + 
How the corresponding attributes are passed in when creating a node;
+        cells should not need this, but it can happen for primitives.
+        """
+        super().__init__(name, targets, args, inputs)
+        #self._name: str = ""
+        self._kwargs: Dict = {}
+        self._scope: str = ""
+        self._ast_node: ast.AST = ast_node
+        self._index = 0
+        self._attribute._class = type(instance)
+
+    @property
+    def attribute(self) -> AttributeNode:
+        return self._attribute
+
+    @attribute.setter
+    def attribute(self, attribute: AttributeNode):
+        self._attribute = attribute
+
+    def add_attribute(self, attribute: Dict):
+        for key, value in attribute.items():
+            self._attribute._attribute[key] = value
+
+    @property
+    def type(self):
+        return self._attribute._class
+
+    @type.setter
+    def type(self, cell_type):
+        self._attribute._class = cell_type
+
+    def set_cell(self, cell: nn.Cell):
+        #self._attribute = None
+        # mark this node as a call on a Cell instance
+        self._attribute.type = NodeType.call_cell
+
+    def __repr__(self):
+        input_names = ""
+        input_nodes = ""
+        output_names = ""
+        output_nodes = ""
+
+        for n in self.inputs:
+            input_names += n.name + " "
+            # input_nodes += str(n)
+
+        for n in self.outputs:
+            output_names += n.name + " "
+            #output_nodes += str(n)
+
+        #attr_info = "attr name: " + self._attribute.name + "; class" + str(self.attribute._class) + "; is custom defined: " + str(self.attribute._is_custom_define)
+        return f"name: {self._name}; ast_node: {self._ast_node}; scope: {self._scope}; index: {self._index}; inputs: {len(self.inputs)}; input names: {input_names}; outputs: {len(self.outputs)}; output names: {output_names}; attr info: {self._attribute}"
+
+class ConstantNode(BaseNode):
+    def __init__(self, value):
+        # register the literal as a target so constant inputs can be matched by
+        # value, and start from an empty args list (BaseNode defaults to None)
+        super().__init__(str(value), targets=[value], args=[])
+        #self._name = str(value)
+        self._value = value
+        self._args.append(value)
+        self._attribute.type = NodeType.constant
+        # __repr__ uses these, but BaseNode does not define them
+        self._ast_node = None
+        self._index = 0
+
+    def __repr__(self) -> str:
+        output_names = ""
+        for n in self.outputs:
+            output_names += n.name + " "
+        return f"name: {self._name}; value: {self._value}; ast_node: {self._ast_node}; index: {self._index}; outputs: {len(self.outputs)}; output names: {output_names}"
+
+class PlaceholderNode(BaseNode):
+    def __init__(self, name, targets=None, ast_node=None, default_value=None):
+        super().__init__(name, targets)
+        #self._name = ""
+        #self._target = target
+        self._ast_node = ast_node
+        self._default_value = default_value
+        self._attribute.type = NodeType.placeholder
+        #self._outputs: List = []
+
+    def __repr__(self) -> str:
+        output_names = ""
+        for n in self.outputs:
+            output_names += n.name + " "
+        return f"name: {self._name}; targets: {self._targets}, outputs: {len(self.outputs)}; output names: {output_names}; attribute: {self._attribute}"
+
+class ControlFlowNode(Node):
+    def __init__(self):
+        super().__init__()
+        self.subgraph = None
\ No newline at end of file
diff --git a/mindspore/python/mindspore/rewrite/parser.py b/mindspore/python/mindspore/rewrite/parser.py
new file mode 100644
index 00000000000..135212cb1bd
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite/parser.py
@@ -0,0 +1,580 @@
+import ast
+from collections import OrderedDict
+import inspect
+from types import FunctionType
+from typing import Dict, List, Union
+
+from .namespace import CellNamespace, ClosureNamespace, Namespace
+from .node import AttributeNode, Node, NodeType
+
+from mindspore.ops.primitive import Primitive
+import mindspore.nn as nn
+
+namespace_nodetype_map = {
+    "mindspore.common": NodeType.call_cell,
+    "mindspore.nn": NodeType.call_cell,
+    "mindspore.ops": NodeType.call_method,
+    "mindspore.ops.composite": NodeType.call_method,
+    "mindspore.ops.composite.multitype_ops": NodeType.call_method,
+    "mindspore.ops.operations": 
NodeType.call_method +} + +class Parser(): + def __init__(self, network: Union[nn.Cell, Primitive, FunctionType]): + # Used to resolve mindspore builtin ops namespace. + self.ms_common_ns = CellNamespace('mindspore.common') + self.ms_nn_ns = CellNamespace('mindspore.nn') + self.ms_ops_ns = CellNamespace('mindspore.ops') + self.ms_ops_c_ns = CellNamespace('mindspore.ops.composite') + self.ms_ops_c_multitype_ns = CellNamespace('mindspore.ops.composite.multitype_ops') + self.ms_ops_p_ns = CellNamespace('mindspore.ops.operations') + # Used to resolve the function's globals namespace. + self.global_namespace: CellNamespace = CellNamespace(network.__module__) + # Used to resolve the function's nonlocals. + self.closure_namespace: ClosureNamespace = None + self._default_values = None + + def updete_closure_namespace(self, fn: FunctionType): + self.closure_namespace = ClosureNamespace(inspect.unwrap(fn)) + + def parse_function(self, ast_root: ast.AST, function_name: str): + ''' + parse function by name + ''' + for ast_node in ast_root.body: + #new_node: Node = node_parser_mapper[type(ast_node).__name__](self, ast_node) + pass + + def _get_node_visitor(self, node: ast.AST): + method = 'parse_' + node.__class__.__name__ + visitor = getattr(self, method, None) + return visitor + + def get_func_namesapce(self, func_name: str): + if func_name in self.ms_common_ns: + return self.ms_common_ns[func_name], repr(self.ms_common_ns), False + elif func_name in self.ms_nn_ns: + return self.ms_nn_ns[func_name], repr(self.ms_nn_ns), False + elif func_name in self.ms_ops_ns: + return self.ms_ops_ns[func_name], repr(self.ms_ops_ns), False + elif func_name in self.ms_ops_c_ns: + return self.ms_ops_c_ns[func_name], repr(self.ms_ops_c_ns), False + elif func_name in self.ms_ops_c_multitype_ns: + return self.ms_ops_c_multitype_ns[func_name], repr(self.ms_ops_c_multitype_ns), False + elif func_name in self.ms_ops_p_ns: + return self.ms_ops_p_ns[func_name], repr(self.ms_ops_p_ns), False + elif func_name in self.global_namespace: + return self.global_namespace[func_name], repr(self.global_namespace), True + elif func_name in self.closure_namespace: + return self.closure_namespace[func_name], repr(self.closure_namespace), True + else: + return None, None, False + + def _parse_targets(self, node: ast.AST): + visitor = self._get_node_visitor(node) + res = visitor(node) + return res + + def _parse_args(self, node_list: ast.List): + args = [] + nodes = [] + called_obj_names = [] + for node in node_list: + if isinstance(node, ast.Call): #according to the configuration of the node, create a new node and insert into nodes before it, set args and targets information + new_node = Node(targets=["tmp"], ast_node=node) + nodes_, called_obj_names_ = self.parse_Call(node, new_node) + nodes.extend(nodes_) + args.append("tmp") + called_obj_names.extend(called_obj_names_) + assert(len(nodes) == len(called_obj_names)) + print("node in args: ", new_node) + elif isinstance(node, ast.Name): + args.append(node.id) + + return args, nodes, called_obj_names + + def _calc_Add(self, left, right): + return int(left) + int(right) + + def _calc_Sub(self, left, right): + return int(left) - int(right) + + def _calc_Mult(self, left, right): + return int(left) * int(right) + + def _calc_Div(self, left, right): + return left / right + + def parse_init_assign(self, node: ast.Assign): + lineno = node.lineno + visitor = self._get_node_visitor(node.targets) + targets = visitor(node.targets) + + print ("targets:", targets) + value = node.value + new_node = 
AttributeNode(name=targets[0]) + if isinstance(value, ast.Call): + self.parse_init_Call(value, new_node) + elif isinstance(value, ast.Name): + new_node._class = ast.Constant + new_node._type = NodeType.constant + new_node._is_custom_define = True + if value.id in self._default_values.keys(): + new_node._constant_value = self._default_values[value.id] + else: + new_node._constant_value = value.id + elif isinstance(value, ast.BinOp): + print("value is BinOp") + pass + else: + print("vaule type: ", type(value), " is not supported") + return new_node + + def _parse_init_args(self, ast_nodes: ast.List): + args = [] + for node in ast_nodes: + if isinstance(node, ast.BinOp): + value = self.parse_init_BinOp(node) + elif isinstance(node, ast.Call): + value = AttributeNode() + self.parse_init_Call(node, value) + else: + visitor = self._get_node_visitor(node) + value = visitor(node) + if value in self._default_values.keys(): + value = self._default_values[value] + + args.append(value) + print("init args: ", args) + + return args + + def _parse_init_keywords(self, ast_nodes: ast.List): + keywords = {} + + for node in ast_nodes: + key = node.arg + value = node.value + + if isinstance(node.value, ast.BinOp): + value = self.parse_init_BinOp(node.value) + elif isinstance(node.value, ast.Call): + value = AttributeNode() + self.parse_init_Call(node.value, value) + else: + visitor = self._get_node_visitor(node.value) + value = visitor(node.value) + if value in self._default_values.keys(): + value = self._default_values[value] + + keywords[key] = value + + return keywords + + def parse_init_Call(self, ast_node: ast.Call, attr_node: AttributeNode): + def _update_args_value(args: List, keywords: Dict, parameters: inspect.signature): + print("defaults values:", self._default_values) + + new_dict = OrderedDict() + for name, para_ in parameters.items(): + new_dict[name] = para_.default + + keys = list(new_dict.keys()) + + if "args" in keys: + new_dict["args"] = args + else: + for i in range(len(args)): + if args[i] in self._default_values.keys(): + new_dict[keys[i + 1]] = self._default_values[args[i]] + else: + new_dict[keys[i + 1]] = args[i] + + if "kwargs" in keys: + new_dict["kwargs"] = keywords + else: + for key, value in keywords.items(): + if value in self._default_values.keys(): + new_dict[key] = self._default_values[value] + else: + new_dict[key] = value + + print ("new dict: ", new_dict) + return new_dict + + new_dict = OrderedDict() + + visitor = self._get_node_visitor(ast_node.func) + value = visitor(ast_node.func) + print ("node name: ", value.split(".")[-1]) + class_name = value.split(".")[-1] + class_, name_space, is_custom_define = self.get_func_namesapce(class_name) + + print ("class: ", class_) + parameters = inspect.signature(class_.__init__).parameters + print ("parameters: ", parameters) + if name_space in namespace_nodetype_map: + node_type = namespace_nodetype_map[name_space] + else: + node_type = NodeType.call_cell + + args = self._parse_init_args(ast_node.args) + keywords = self._parse_init_keywords(ast_node.keywords) + new_dict = _update_args_value(args, keywords, parameters) + + attr_node._class = class_ + attr_node._type = node_type + attr_node._is_custom_define = is_custom_define + attr_node._attribute = new_dict + return + + def parse_init_BinOp(self, ast_node: ast.BinOp): + def _get_value(node: ast.AST): + if isinstance(node, ast.Call): + value = AttributeNode() + self.parse_init_Call(node, value) + elif isinstance(node, ast.BinOp): + value = self.parse_init_BinOp(node) + else: + 
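                # plain literal or name: evaluate it with the matching parse_*
+                # visitor, substituting a default parameter value when the
+                # name refers to one
+                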
visitor = self._get_node_visitor(node) + value = visitor(node) + if value in self._default_values.keys(): + value = self._default_values[value] + + return value + + op = ast_node.op + left = ast_node.left + right = ast_node.right + + left_value = _get_value(left) + right_value = _get_value(right) + print("left value: ", left_value) + print("right value: ", right_value) + if (isinstance(left_value, int) and isinstance(right_value, int)) or (str.isdigit(str(left_value)) and str.isdigit(str(right_value))): + method = '_calc_' + op.__class__.__name__ + calc = getattr(self, method, None) + if calc: + result = calc(left_value, right_value) + else: + print("undefined op", method) + else: + result = str(left_value) + " op.__class__.__name__ " + repr(right_value) + + return result + + def parse_Assign(self, node: ast.Assign): + lineno = node.lineno + nodes = [] + called_obj_names = [] + visitor = self._get_node_visitor(node.targets) + targets = visitor(node.targets) + print("targets: ", targets) + + value = node.value + new_node = Node(targets=targets, ast_node=node) + visitor = self._get_node_visitor(value) + nodes_, called_obj_names_ = visitor(value, new_node) + + nodes.extend(nodes_) + called_obj_names.extend(called_obj_names_) + print("new node in assign: ", new_node) + assert(len(nodes) == len(called_obj_names)) + + return nodes, called_obj_names + + def parse_Call(self, ast_node: ast.Call, node: Node): + nodes = [] + called_obj_names = [] + visitor = self._get_node_visitor(ast_node.func) + called_obj_name = visitor(ast_node.func) + node.name = called_obj_name + + args_, nodes_, called_obj_names_ = self._parse_args(ast_node.args) + + visitor = self._get_node_visitor(ast_node.keywords) + kwargs_: Dict = visitor(ast_node.keywords) + + node._args = args_ + node._kwargs = kwargs_ + nodes.extend(nodes_) + nodes.append(node) + called_obj_names.extend(called_obj_names_) + called_obj_names.append(called_obj_name) + print ("nodes in call:", nodes) + print("called obj name in call: ", called_obj_names) + return nodes, called_obj_names + + def parse_Attribute(self, node: ast.Attribute): + visitor = self._get_node_visitor(node.value) + attribute_value = visitor(node.value) + "." + node.attr + + return attribute_value + + def parse_list(self, node: ast.List) -> list: + res = [] + for n in node: + visitor = self._get_node_visitor(n) + value = visitor(n) + if isinstance(value, list): + res += value + else: + res.append(value) + + return res + + def parse_List(self, node: ast.List) -> list: + res = [] + for n in node.elts: + visitor = self._get_node_visitor(n) + value = visitor(n) + if isinstance(value, list): + res += value + else: + res.append(value) + + return res + + def parse_Tuple(self, node: ast.List) -> list: + res = [] + for n in node.elts: + visitor = self._get_node_visitor(n) + value = visitor(n) + res.append(value) + + return res + + def parse_Expr(self, node: ast.expr): + pass + + def parse_BinOp(self, ast_node: ast.BinOp, node: Node): #如果left和right都是Call则需要分别创建节点,同时分析call的args,根据args也创建对应节点 + nodes = [] + called_obj_names = [] + #ops_info = parse_object_map. 
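+        # when an operand of the BinOp is itself a call, a separate node is
+        # created for it below and linked through the temporary "tmp" target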
+        node.name = ast_node.op.__class__.__name__
+        node._attribute._class = ast_node.op  # The ast op type must be converted to a MindSpore op type.
+        node._attribute.node_type = NodeType.call_function
+        args = []
+        left = ast_node.left
+        if isinstance(left, ast.Call):
+            new_node = Node()
+            nodes_, called_obj_names_ = self.parse_Call(left, new_node)
+            new_node._targets.append("tmp")
+            nodes.extend(nodes_)
+            called_obj_names.extend(called_obj_names_)
+            args.append("tmp")
+            print("left node in BinOp: ", new_node)
+            assert len(nodes) == len(called_obj_names)
+        else:
+            visitor = self._get_node_visitor(left)
+            args.append(visitor(left))
+
+        right = ast_node.right
+        if isinstance(right, ast.Call):
+            new_node = Node()
+            nodes_, called_obj_names_ = self.parse_Call(right, new_node)
+            new_node._targets.append("tmp")
+            nodes.extend(nodes_)
+            called_obj_names.extend(called_obj_names_)
+            args.append("tmp")
+            print("right node in BinOp: ", new_node)
+            assert len(nodes) == len(called_obj_names)
+        else:
+            visitor = self._get_node_visitor(right)
+            args.append(visitor(right))
+
+        node._args = args
+        nodes.append(node)
+        called_obj_names.append(node.name)
+        return nodes, called_obj_names
+
+    def parse_BoolOp(self, node: ast.BoolOp):
+        pass
+
+    def parse_UnaryOp(self, node: ast.UnaryOp):
+        pass
+
+    def parse_Lambda(self, node: ast.Lambda):
+        pass
+
+    def parse_IfExp(self, node: ast.IfExp):
+        pass
+
+    def parse_Dict(self, node: ast.Dict):
+        pass
+
+    def parse_Set(self, node: ast.Set):
+        pass
+
+    def parse_Slice(self, node: ast.Slice):
+        pass
+
+    def parse_Name(self, node: ast.Name) -> str:
+        return node.id
+
+    def parse_Num(self, node: ast.Num):
+        return node.n
+
+    def parse_Constant(self, node: ast.Constant):
+        return node.value
+
+    def parse_keyword(self, node: ast.keyword):
+        key = node.arg
+        visitor = self._get_node_visitor(node.value)
+        value = visitor(node.value)
+        return {key: value}
+
+    def parse_NameConstant(self, node: ast.NameConstant):
+        return node.value
+
+    def parse_Str(self, node: ast.Str):
+        return node.s
+
+    def parse_AugAssign(self, node: ast.AugAssign):
+        nodes = []
+        called_obj_names = []
+        new_node = Node()
+
+        new_node._targets.append(node.target.id)
+
+        if isinstance(node.value, ast.Call):
+            nodes_, called_obj_names_ = self.parse_Call(node.value, new_node)
+            new_node._args.insert(0, node.target.id)
+            nodes.extend(nodes_)
+            called_obj_names.extend(called_obj_names_)
+        elif isinstance(node.value, ast.Name):
+            new_node.name = node.value.id
+            new_node._args.append(node.value.id)
+            new_node._args.insert(0, node.target.id)
+            nodes.append(new_node)
+        elif isinstance(node.value, ast.Attribute):
+            new_node.name = self.parse_Attribute(node.value)
+            new_node._args.append(new_node.name)
+            new_node._args.insert(0, node.target.id)
+            nodes.append(new_node)
+        elif isinstance(node.value, ast.Num):
+            new_node.name = node.op.__class__.__name__
+            new_node._args.insert(0, node.target.id)
+            new_node._attribute.name = node.value.__class__.__name__
+            new_node._attribute.type = NodeType.constant
+            visitor = self._get_node_visitor(node.value)
+            value = visitor(node.value)
+            print("augassign value:", value)
+            new_node._attribute._attribute["value"] = value
+            nodes.append(new_node)
+        else:
+            new_node.name = node.target.id
+            new_node._args.insert(0, node.target.id)
+            new_node._attribute.name = node.value.__class__.__name__
+            new_node._attribute.type = NodeType.constant
+            visitor = self._get_node_visitor(node.value)
+            value = visitor(node.value)
+            print("augassign value:", value)
+            new_node._attribute._attribute["value"] = value
+            nodes.append(new_node)
+
+        return nodes, called_obj_names
+
+    def parse_arguments(self, node: ast.arguments):
+        class Arg:
+            def __init__(self, lineno, col_offset, name) -> None:
+                self._lineno = lineno
+                self._col_offset = col_offset
+                self._name = name
+
+        class Default:
+            def __init__(self, lineno, col_offset, value) -> None:
+                self._lineno = lineno
+                self._col_offset = col_offset
+                self._value = value
+
+        def _find_corresponding_name(defaults: List[Default], names: List[Arg]):
+            # Match each default value to the last argument that appears before
+            # it, comparing line number and column offset.
+            for d in defaults:
+                i = 0
+                while i < len(names) and names[i]._lineno == d._lineno and names[i]._col_offset < d._col_offset:
+                    i += 1
+                if 0 < i <= len(names):
+                    arg_with_default_value[names[i - 1]._name] = d._value
+
+        args_ = []
+        arg_with_default_value = {}
+        for arg in node.args:
+            if arg.arg == "self":
+                continue
+            a = Arg(arg.lineno, arg.col_offset, arg.arg)
+            args_.append(a)
+            arg_with_default_value[a._name] = None
+
+        for arg in node.kwonlyargs:
+            a = Arg(arg.lineno, arg.col_offset, arg.arg)
+            args_.append(a)
+            arg_with_default_value[a._name] = None
+
+        if node.vararg is not None:
+            a = Arg(node.vararg.lineno, node.vararg.col_offset, node.vararg.arg)
+            args_.append(a)
+            arg_with_default_value[a._name] = None
+
+        if node.kwarg is not None:
+            a = Arg(node.kwarg.lineno, node.kwarg.col_offset, node.kwarg.arg)
+            args_.append(a)
+            arg_with_default_value[a._name] = None
+
+        defaults_ = []
+        for default in node.defaults:
+            visitor = self._get_node_visitor(default)
+            value = visitor(default)
+            d = Default(default.lineno, default.col_offset, value)
+            defaults_.append(d)
+
+        _find_corresponding_name(defaults_, args_)
+        self._default_values = arg_with_default_value
+        return arg_with_default_value
+
+    def parse_Return(self, node: ast.Return):
+        nodes = []
+        called_obj_names = []
+        value = node.value
+
+        visitor = self._get_node_visitor(value)
+        value = visitor(value)
+        new_node = Node(name="return", args=[], ast_node=node)
+        print("return args: ", value)
+        if isinstance(value, list):
+            new_node._args += value
+        else:
+            new_node._args.append(value)
+        new_node._attribute.type = NodeType.output
+        nodes.append(new_node)
+        print("return node: ", new_node)
+        return nodes, called_obj_names
+
+    def parse_If(self, node: ast.If):
+        nodes = []
+        called_obj_names = []
+
+        new_node = Node("if", [], [], node)
+
+        nodes.append(new_node)
+
+        return nodes, called_obj_names
+
+    def parse_While(self, node: ast.While):
+        pass
+
+    def parse_For(self, node: ast.For):
+        pass
\ No newline at end of file
diff --git a/mindspore/python/mindspore/rewrite/pattern_engine.py b/mindspore/python/mindspore/rewrite/pattern_engine.py
new file mode 100644
index 00000000000..70dbcaaf21f
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite/pattern_engine.py
@@ -0,0 +1,348 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""PatternEngine."""
+
+from typing import Tuple, Union, List, Type
+from collections import OrderedDict
+from .graph import Graph
+from .node import Node, NodeType
+from mindspore.nn.cell import Cell
+
+
+class PatternNode:
+    """
+    PatternNode represents a node when defining a pattern.
+
+    Args:
+        node_name (str): Name of current node.
+        node_type (Type): Cell type of current node.
+        inputs (List[PatternNode]): Input nodes of current node.
+    """
+
+    def __init__(self, node_name: str, node_type: Type = Type[None], inputs: List['PatternNode'] = None):
+        self._name = node_name
+        self._type = node_type
+        if inputs is None:
+            self._inputs = []
+        else:
+            self._inputs = inputs
+
+    @staticmethod
+    def from_node(node: Node) -> 'PatternNode':
+        """
+        Create a PatternNode from a rewrite node.
+
+        Args:
+            node (Node): Input rewrite node.
+
+        Returns:
+            PatternNode created from the rewrite node.
+        """
+
+        pattern_node = PatternNode(node.name)
+        if node.node_type() is NodeType.call_cell:
+            pattern_node._type = node.type
+        return pattern_node
+
+    @staticmethod
+    def create_pattern_from_node(node: Node) -> 'PatternNode':
+        """
+        Create a PatternNode from a rewrite node together with its inputs.
+
+        Args:
+            node (Node): Input rewrite node.
+
+        Returns:
+            PatternNode created from the rewrite node.
+        """
+
+        pattern_node = PatternNode.from_node(node)
+        inputs = []
+        for node_input in node.inputs():
+            inputs.append(PatternNode.create_pattern_from_node(node_input))
+        pattern_node._inputs = inputs
+        return pattern_node
+
+    @staticmethod
+    def create_pattern_from_list(type_list: list) -> 'PatternNode':
+        """
+        Create a chain of PatternNodes from a list of cell types.
+
+        Args:
+            type_list (list): Input cell type list.
+
+        Returns:
+            The last PatternNode of the chain created from the cell type list.
+        """
+
+        last_node = None
+        for i in range(len(type_list)):
+            cell_type = type_list[i]
+            cur_node = PatternNode(str(i) + "-" + str(cell_type), cell_type, [])
+            if last_node is not None:
+                cur_node._inputs = [last_node]
+            else:
+                cur_node._inputs = []
+            last_node = cur_node
+        return last_node
+
+    def add_input(self, node_type):
+        """
+        Add an input to the current PatternNode.
+
+        Args:
+            node_type: Cell type of the new input.
+        """
+
+        self._inputs.append(node_type)
+
+    def set_inputs(self, inputs):
+        """
+        Set the inputs of the current PatternNode.
+
+        Args:
+            inputs (List): Inputs to be set as the inputs of the current PatternNode.
+        """
+
+        self._inputs = inputs
+
+    def match(self, node: Node) -> bool:
+        """
+        Check whether the current PatternNode matches a rewrite node.
+
+        Args:
+            node (Node): A rewrite node to be matched.
+        """
+
+        return self._type == node.type
+
+    def inputs(self):
+        """
+        Getter of inputs.
+        """
+
+        return self._inputs
+
+    def name(self) -> str:
+        """
+        Getter of name.
+        """
+        return self._name
+
+    def type(self):
+        """
+        Getter of type.
+        """
+        return self._type
+
+
+class PlaceHolderNode(PatternNode):
+    """
+    PlaceHolderNode is a subclass of PatternNode whose match always succeeds.
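+
+    Examples:
+        >>> # Illustrative sketch only: `node` stands for any non-None rewrite Node.
+        >>> placeholder = PlaceHolderNode()
+        >>> placeholder.match(node)
+        True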
+    """
+
+    def __init__(self):
+        super(PlaceHolderNode, self).__init__("placehold", Cell, [])
+
+    def match(self, node: Node) -> bool:
+        return node is not None
+
+
+class PatternEngine:
+    """
+    PatternEngine defines how to transform a graph according to a PatternNode.
+
+    Args:
+        pattern (Union[PatternNode, List]): An instance of PatternNode, or a cell-type list from which a
+            chain of PatternNodes is constructed.
+        replacement (callable): A callable that defines how to generate the new node.
+    """
+
+    def __init__(self, pattern: Union[PatternNode, List], replacement: callable = None):
+        if isinstance(pattern, PatternNode):
+            self._is_chain = False
+            self._replacement = None
+            self._pattern = pattern
+        elif isinstance(pattern, list):
+            self._is_chain = True
+            self._replacement = replacement
+            self._pattern = PatternNode.create_pattern_from_list(pattern)
+        else:
+            print("Unsupported pattern type: ", type(pattern))
+            self._is_chain = False
+            self._replacement = None
+            self._pattern = PlaceHolderNode()
+
+    def pattern(self) -> PatternNode:
+        """
+        Getter of pattern.
+        """
+
+        return self._pattern
+
+    def apply(self, graph: Graph) -> bool:
+        """
+        Apply the current pattern to a graph.
+
+        Args:
+            graph (Graph): Graph to be transformed.
+
+        Returns:
+            bool, whether the graph was changed.
+        """
+
+        root: Node = graph.root()
+        changed = False
+        # IR match
+        queue: List[Node] = [root]
+        while queue:
+            cur_node: Node = queue.pop(0)
+            node_inputs = cur_node.inputs
+            matched, matched_dict = self._match(self._pattern, cur_node)
+            if not matched or not PatternEngine._check_match(self._pattern, matched_dict):
+                for node_input in node_inputs:
+                    queue.append(node_input)
+                continue
+            matched_list = list(matched_dict.values())
+            if self._is_chain:
+                new_node = self._process_chain(matched_list)
+            else:
+                new_node = self._process_tree(matched_dict)
+            if new_node is None:  # return None to remove the matched nodes
+                changed = True
+                for key in matched_dict:
+                    graph.remove_node(matched_dict[key])
+            elif new_node == cur_node:  # return the origin node to do nothing
+                pass
+            else:  # return a new node to insert or replace (the new node needs no inputs and outputs set)
+                changed = True
+                graph.replace_node(matched_list, new_node)
+                node_inputs = new_node.inputs
+            for node_input in node_inputs:
+                queue.append(node_input)
+        return changed
+
+    def _process_chain(self, matched_nodes: List[Node]) -> Node:
+        """
+        Define how to generate the new node from `self._replacement` when the pattern is a chain-pattern.
+
+        Args:
+            matched_nodes (List[Node]): A list of Nodes as the matched result.
+
+        Returns:
+            New node created from the matched result.
+        """
+
+        if self._replacement is None:
+            return matched_nodes[-1]
+        instance = self._replacement(*matched_nodes)
+        if instance is None:
+            return None
+        if len(matched_nodes) == 0:
+            new_node = Node(instance=instance)
+        else:
+            new_node = Node(instance=instance, inputs=matched_nodes[0].inputs)
+        node_name = ""
+        for matched_node in matched_nodes:
+            node_name += matched_node.name + "_"
+        node_name += "fused"
+        new_node.name = node_name
+        return new_node
+
+    # matched_nodes: name_of_node_in_pattern mapped to the matched node in the network
+    def _process_tree(self, matched_nodes: OrderedDict) -> Node:
+        """
+        Define how to generate the new node when the pattern is a tree-pattern.
+        This method must be overridden by all subclasses whose pattern is a tree-pattern.
+
+        Args:
+            matched_nodes (OrderedDict): An OrderedDict of Nodes as the matched result.
+
+        Returns:
+            New node created from matched result.
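+
+        Example (illustrative sketch only; the pattern names and the fused
+        `ConvBnAdd` cell below are hypothetical):
+
+        >>> class ConvBnAddEngine(PatternEngine):
+        ...     def _process_tree(self, matched_nodes: OrderedDict) -> Node:
+        ...         conv = matched_nodes["p_conv"]
+        ...         bn = matched_nodes["p_bn"]
+        ...         add = matched_nodes["p_add"]
+        ...         return Node(instance=ConvBnAdd(conv, bn, add), inputs=conv.inputs)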
+ """ + + pass + + @staticmethod + def _merge_ordered_dict(dict1: OrderedDict, dict2: OrderedDict) -> OrderedDict: + """ + A static util method to merge two OrderedDict + + Args: + dict1 (OrderedDict): first dict to be merge. + dict2 (OrderedDict): second dict to be merge. + + Returns: + Merged OrderedDict. + """ + + merged = dict1.copy() + merged.update(dict2) + return merged + + def _match(self, pattern: PatternNode, node: Node) -> Tuple[bool, OrderedDict]: + """ + Match `pattern` with a rewrite node with all inputs of the `pattern` + + Args: + pattern (PatternNode): pattern to be match. + node (Node): node to be match. + + Returns: + A bool value to indicate if matched. + A instance of OrderedDict as match result. + """ + + # todo: Recurse into subgraph node. Depend on subgraph node definition + if node.node_type() != NodeType.call_cell: + print("Pattern match failed: node(", node.name, ") is not a cell") + return False, OrderedDict() + if not pattern.match(node): + print("Pattern match failed: node(", node.name, ")'s type is ", node.type, " while pattern type is ", + pattern.type()) + return False, OrderedDict() + if isinstance(pattern, PlaceHolderNode): + return True, OrderedDict() + pattern_inputs = pattern.inputs() + cur_inputs = node.inputs + input_num = len(pattern_inputs) + if input_num == 0: + return True, OrderedDict({pattern.name(): node}) + if input_num != len(cur_inputs): + print("Pattern match failed: node(", node.name, ")'s has ", len(node.inputs), " inputs while pattern has ", + input_num, " inputs") + return False, OrderedDict() + result = OrderedDict() + for i in range(0, input_num): + is_matched, tmp_result = self._match(pattern_inputs[i], cur_inputs[i]) + if not is_matched: + return False, OrderedDict() + else: + result = PatternEngine._merge_ordered_dict(result, tmp_result) + result[pattern.name()] = node + return True, result + + @staticmethod + def _check_match(pattern: PatternNode, match_dict: OrderedDict) -> bool: + matched_nodes = match_dict.values() + for key in match_dict: + if key == pattern.name(): + continue + node = match_dict[key] + for output in node.outputs: + if output not in matched_nodes: + print("Check match failed, pattern leaked") + return False + return True diff --git a/mindspore/python/mindspore/rewrite/pattern_engine_match_test.py b/mindspore/python/mindspore/rewrite/pattern_engine_match_test.py new file mode 100644 index 00000000000..6c9b436c560 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/pattern_engine_match_test.py @@ -0,0 +1,191 @@ +import unittest + +from collections import OrderedDict +from mindspore.rewrite import PatternEngine, Node, PlaceHolderNode, PatternNode +from mindspore.nn.layer import Pad, Conv2d, BatchNorm2d, ReLU, Softmax, MatMul + + +class PatternEngineMatchTestCase(unittest.TestCase): + def test_merge_ordered_dict(self): + dict1: OrderedDict = OrderedDict({'a': 1, 'b': 2, 'c': 3}) + dict2: OrderedDict = OrderedDict({'d': 4, 'e': 5, 'f': 6}) + ret = PatternEngine._merge_ordered_dict(dict1, dict2) + index = 1 + for key in ret: + self.assertEqual(ret.get(key), index) + index += 1 + print(ret) + + @staticmethod + def chain_network(): + pad = Node() + pad.type = Pad + pad.name = "pad" + conv = Node() + conv.type = Conv2d + conv.name = "conv" + bn = Node() + bn.type = BatchNorm2d + bn.name = "bn" + relu = Node() + relu.type = ReLU + relu.name = "relu" + softmax = Node() + softmax.type = Softmax + softmax.name = "softmax" + + pad.outputs = [conv] + conv.inputs = [pad] + conv.outputs = [bn] + bn.inputs = [conv] + 
bn.outputs = [relu] + relu.inputs = [bn] + relu.outputs = [softmax] + softmax.inputs = [relu] + return softmax, relu + + @staticmethod + def tree_network(): + pad = Node() + pad.type = Pad + pad.name = "pad" + conv1 = Node() + conv1.type = Conv2d + conv1.name = "conv1" + bn1 = Node() + bn1.type = BatchNorm2d + bn1.name = "bn1" + conv2 = Node() + conv2.type = Conv2d + conv2.name = "conv2" + bn2 = Node() + bn2.type = BatchNorm2d + bn2.name = "bn2" + matmul = Node() + matmul.type = MatMul + matmul.name = "matmul" + softmax = Node() + softmax.type = Softmax + softmax.name = "softmax" + + pad.outputs = [conv1, conv2] + conv1.inputs = [pad] + conv1.outputs = [bn1] + bn1.inputs = [conv1] + bn1.outputs = [matmul] + conv2.inputs = [pad] + conv2.outputs = [bn2] + bn2.inputs = [conv2] + bn2.outputs = [matmul] + matmul.inputs = [bn1, bn2] + matmul.outputs = [softmax] + softmax.inputs = [matmul] + return softmax, matmul + + @staticmethod + def chain_network_for_leak_pattern(): + pad = Node() + pad.type = Pad + pad.name = "pad" + conv = Node() + conv.type = Conv2d + conv.name = "conv" + bn = Node() + bn.type = BatchNorm2d + bn.name = "bn" + relu = Node() + relu.type = ReLU + relu.name = "relu" + relu2 = Node() + relu2.type = ReLU + relu2.name = "relu2" + matmul = Node() + matmul.type = MatMul + matmul.name = "matmul" + + pad.outputs = [conv] + conv.inputs = [pad] + conv.outputs = [bn, relu2] + bn.inputs = [conv] + relu2.inputs = [conv] + return bn + + def test_chain_pattern(self): + bad_root, good_root = self.chain_network() + # define pattern + p_conv = PatternNode("p_conv", Conv2d) + p_bn = PatternNode("p_bn", BatchNorm2d, [p_conv]) + p_relu = PatternNode("p_relu", ReLU, [p_bn]) + # define pattern engine + pattern_engine = PatternEngine(p_relu) + match, match_dict = pattern_engine._match(pattern_engine.pattern(), bad_root) + print("*****Chain match softmax result: ", match) + self.assertEqual(match, False) + match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) + print("*****Chain match matmul result: ", match) + self.assertEqual(match, True) + self.assertEqual(len(match_dict), 3) + + def test_chain_pattern_from_list(self): + bad_root, good_root = self.chain_network() + # define pattern engine + pattern_engine = PatternEngine([Conv2d, BatchNorm2d, ReLU]) + match, match_dict = pattern_engine._match(pattern_engine.pattern(), bad_root) + print("*****List chain match softmax result: ", match) + self.assertEqual(match, False) + match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) + print("*****List chain match matmul result: ", match) + self.assertEqual(match, True) + self.assertEqual(len(match_dict), 3) + + def test_tree_pattern(self): + bad_root, good_root = self.tree_network() + # define pattern + p_placeholder = PlaceHolderNode() + p_conv1 = PatternNode("p_conv1", Conv2d, [p_placeholder]) + p_bn1 = PatternNode("p_bn1", BatchNorm2d, [p_conv1]) + p_conv2 = PatternNode("p_conv2", Conv2d, [p_placeholder]) + p_bn2 = PatternNode("p_bn2", BatchNorm2d, [p_conv2]) + p_matmul = PatternNode("p_matmul", MatMul, [p_bn1, p_bn2]) + # define pattern engine + pattern_engine = PatternEngine(p_matmul) + match, match_dict = pattern_engine._match(pattern_engine.pattern(), bad_root) + print("*****Tree match softmax result: ", match) + self.assertEqual(match, False) + match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) + print("*****Tree match matmul result: ", match) + self.assertEqual(match, True) + self.assertEqual(len(match_dict), 5) + + def 
test_placeholder_pattern(self): + _, good_root = self.tree_network() + # define pattern + p_placeholder = PlaceHolderNode() + p_conv2 = PatternNode("p_conv2", Conv2d, [p_placeholder]) + p_bn2 = PatternNode("p_bn2", BatchNorm2d, [p_conv2]) + p_matmul = PatternNode("p_matmul", MatMul, [p_bn2]) + # define pattern engine + pattern_engine = PatternEngine(p_matmul) + match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) + print("*****Tree no-placehold match matmul result: ", match) + self.assertEqual(match, False) + + p_matmul.set_inputs([p_placeholder, p_bn2]) + pattern_engine = PatternEngine(p_matmul) + match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) + print("*****Tree placehold match matmul result: ", match) + self.assertEqual(match, True) + self.assertEqual(len(match_dict), 3) + + def test_chain_leak_pattern(self): + root = self.chain_network_for_leak_pattern() + # define pattern + pattern_engine = PatternEngine([Conv2d, BatchNorm2d]) + match, match_dict = pattern_engine._match(pattern_engine.pattern(), root) + self.assertEqual(match, True) + match = pattern_engine._check_match(pattern_engine.pattern(), match_dict) + self.assertEqual(match, False) + + +if __name__ == '__main__': + unittest.main() diff --git a/mindspore/python/mindspore/rewrite/pattern_engine_test.py b/mindspore/python/mindspore/rewrite/pattern_engine_test.py new file mode 100644 index 00000000000..2f0823e0cfb --- /dev/null +++ b/mindspore/python/mindspore/rewrite/pattern_engine_test.py @@ -0,0 +1,176 @@ +import unittest + +from mindspore.rewrite import PatternEngine, Node, Graph, NodeType +from lenet import LeNet5 +from mindspore.nn import Cell +from mindspore.nn.layer import Conv2d, BatchNorm2d, Dense, MaxPool2d, Flatten, ReLU + + +class PatternEngineTestCase(unittest.TestCase): + def lenet(self): + conv1 = Node() + conv1.type = Conv2d + conv1.name = "conv1" + + bn1 = Node() + bn1.type = BatchNorm2d + bn1.name = "bn1" + conv1.outputs = [bn1] + bn1.inputs = [conv1] + + pool1 = Node() + pool1.type = MaxPool2d + pool1.name = "pool1" + bn1.outputs = [pool1] + pool1.inputs = [bn1] + + conv2 = Node() + conv2.type = Conv2d + conv2.name = "conv2" + pool1.outputs = [conv2] + conv2.inputs = [pool1] + + bn2 = Node() + bn2.type = BatchNorm2d + bn2.name = "bn2" + conv2.outputs = [bn2] + bn2.inputs = [conv2] + + pool2 = Node() + pool2.type = MaxPool2d + pool2.name = "pool2" + bn2.outputs = [pool2] + pool2.inputs = [bn2] + + flatten = Node() + flatten.type = Flatten + flatten.name = "flatten" + pool2.outputs = [flatten] + flatten.inputs = [pool2] + + fc1 = Node() + fc1.type = Dense + fc1.name = "dense1" + flatten.outputs = [fc1] + fc1.inputs = [flatten] + + fc2 = Node() + fc2.type = Dense + fc2.name = "dense2" + fc1.outputs = [fc2] + fc2.inputs = [fc1] + + fc3 = Node() + fc3.type = Dense + fc3.name = "dense3" + fc2.outputs = [fc3] + fc3.inputs = [fc2] + graph = Graph(Cell) + graph.set_root(fc3) + return graph + + @staticmethod + def get_nodes_count(root: Node, to_print: bool = False): + count = 1 + if to_print: + print("Visit ", root.name) + for input in root.inputs: + count += PatternEngineTestCase.get_nodes_count(input, to_print) + return count + + def test_pattern(self): + class ConvBn(Cell): + def __init__(self, conv, bn): + super(ConvBn, self).__init__() + self._conv = conv + self._bn = bn + + def construct(self, x): + x = self._conv(x) + return self._bn(x) + + class ConvBnPatternEngine(PatternEngine): + def __init__(self): + super().__init__([Conv2d, BatchNorm2d], ConvBn) + + 
lenet = self.lenet() + self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet.root()), 10) + pattern_engine = ConvBnPatternEngine() + pattern_engine.apply(lenet) + self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet.root(), True), 8) + + def test_lenet(self): + lenet = LeNet5(num_class=10) + lenet_graph = Graph(LeNet5) + lenet_graph.create_ast() + lenet_graph.get_function_root() + lenet_graph.parse_init() + lenet_graph.parse_construct() + lenet_graph.set_root(lenet_graph.nodes[-1]) + origin_lenet_nn = 14 + self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet_graph.root()), origin_lenet_nn) + # test insert + pre_node = None + post_node = None + for node in lenet_graph.nodes: + if len(node.inputs) == 0: + continue + if node.type is Conv2d and len(node.outputs) == 1 and node.outputs[0].type is ReLU: + pre_node = node + post_node = node.outputs[0] + self.assertNotEqual(pre_node, None) + self.assertNotEqual(post_node, None) + conv_cell = Conv2d(16, 16, 3) + conv_node = Node(name="conv2d_3", instance=conv_cell) + conv_node.inputs = [pre_node] + conv_node.outputs = [post_node] + lenet_graph.insert_node(conv_node) + self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet_graph.root(), True), origin_lenet_nn + 1) + self.assertEqual(len(pre_node.outputs), 1) + self.assertEqual(pre_node.outputs[0], conv_node) + self.assertEqual(len(post_node.inputs), 1) + self.assertEqual(post_node.inputs[0], conv_node) + self.assertEqual(len(conv_node.inputs), 1) + self.assertEqual(conv_node.inputs[0], pre_node) + self.assertEqual(len(conv_node.outputs), 1) + self.assertEqual(conv_node.outputs[0], post_node) + + # test replace + + class ConvReLU(Cell): + def __init__(self, conv, relu): + super(ConvReLU, self).__init__() + self._conv = conv + self._relu = relu + + def construct(self, x): + x = self._conv(x) + return self._relu(x) + + class ConvReLUPatternEngine(PatternEngine): + def __init__(self): + super().__init__([Conv2d, ReLU], ConvReLU) + + pattern_engine = ConvReLUPatternEngine() + pattern_engine.apply(lenet_graph) + self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet_graph.root(), True), origin_lenet_nn + 1 - 2) + # test remove + for node in lenet_graph.nodes: + if node.node_type() == NodeType.call_cell and node.type is ReLU: + lenet_graph.remove_node(node) + self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet_graph.root(), True), origin_lenet_nn + 1 - 2 - 2) + + def remove_cell(*args, **kwargs): + return None + + class RemovePatternEngine(PatternEngine): + def __init__(self): + super().__init__([MaxPool2d], remove_cell) + + pattern_engine = RemovePatternEngine() + pattern_engine.apply(lenet_graph) + self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet_graph.root(), True), origin_lenet_nn + 1 - 2 - 2 - 2) + + +if __name__ == '__main__': + unittest.main() diff --git a/mindspore/python/mindspore/rewrite/rewriter.py b/mindspore/python/mindspore/rewrite/rewriter.py new file mode 100644 index 00000000000..97231fed102 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/rewriter.py @@ -0,0 +1,70 @@ +import ast +from typing import Dict, Union +from .graph import Graph +import mindspore.nn as nn +from mindspore.ops.primitive import Primitive + + +def parse(network: Union[nn.Cell, Primitive]) -> Graph: + graph = Graph(network) + + graph.create_ast() + + graph.get_function_root() + print(graph.python_code) + # graph.print_ast() + graph.parse_init() + graph.parse_construct() + # graph.parse_function("__init__") + + # graph.parse_function("construct") + return graph 
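+
+
+# Illustrative usage sketch (`LeNet5` stands in for any user-defined nn.Cell;
+# this module's API is still a stub, so the sketch only shows the intent):
+#
+#     from mindspore.rewrite.rewriter import parse, print_graph
+#     graph = parse(LeNet5)
+#     print_graph(graph)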
+
+
+def insert_node(graph, node_impl, input_node, output_node):
+    # `node_impl` may be an nn.Cell or a Primitive.
+    pass
+
+
+def remove_node(graph, node):
+    pass
+
+
+def replace_node(graph, node_impl, node):
+    # `node_impl` may be an nn.Cell or a Primitive.
+    pass
+
+
+# def remove_node_by_pattern(graph, pattern):
+#     '''
+#     How should `pattern` be defined? Could it be defined as a string, for example:
+#     > remove_node_by_pattern(g, "Conv + Relu")
+#     '''
+#     pass
+
+# def replace_node_by_pattern(graph, src_pattern, dst_pattern):
+#     pass
+
+def check_graph(graph):
+    return graph.check()
+
+
+# def convert_graph_to_python_code(graph):
+#     python_code = graph.convert_to_python_code()
+#     return python_code
+
+def print_graph(graph):
+    graph.print_graph()
+
+
+def copy_graph(graph):
+    return graph.deep_copy()
+
+# def get_graph_node_list(graph):
+#     node_list = graph.node_list()
diff --git a/mindspore/python/mindspore/rewrite/test.py b/mindspore/python/mindspore/rewrite/test.py
new file mode 100644
index 00000000000..d728e55d767
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite/test.py
@@ -0,0 +1,23 @@
+# from common import Graph
+
+from .rewriter import parse
+from .test_network import ControlSimpleIf
+
+# def test_parse_init():
+#     graph.
+
+
+if __name__ == "__main__":
+    graph = parse(ControlSimpleIf)
+    # graph.print_ast()
+    # graph = Graph(network)
+
+    # graph.create_ast()
+
+    # graph.get_function_root()
+    # print(graph.python_code())
+    # graph.parse_init()
+    # test_parse_init()
diff --git a/mindspore/python/mindspore/rewrite/test_network.py b/mindspore/python/mindspore/rewrite/test_network.py
new file mode 100644
index 00000000000..eba244da0ff
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite/test_network.py
@@ -0,0 +1,121 @@
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore.common import Parameter
+from mindspore.common import dtype as mstype
+from mindspore.common.initializer import initializer
+
+'''
+class test_NetWork(nn.Cell):
+    def __init__(self):
+        self.
+    for
+'''
+
+
+class ControlSimpleIf(nn.Cell):
+    sss = "sss"
+
+    def __init__(self, input_shape):
+        super().__init__()
+        self.addn = P.AddN()
+        self.assign = P.Assign()
+        self.conv2d_1 = nn.Conv2d(3, 128, kernel_size=1, has_bias=True)
+        self.flatten = P.Flatten()
+        self.fc = nn.Dense(128, 10)
+        self.input_data = Parameter(initializer(1, input_shape, mstype.float32), name="var")
+        self.block0 = ControlIfinIf()
+
+    def construct(self, x, y, input_data):
+        x += 1
+        x -= 1
+        x += x
+        x += self.input_data
+        x -= self.addn(x)
+        x = self.conv2d_1(x)
+        xx0 = self.block0(x, y)
+        xx0 = self.block0(xx0, y)
+        af, be, cd = self.conv2d_1(xx0)
+        xx0 = self.block0(xx0, y)
+        # if x0 > y:
+        out = self.addn([input_data, input_data, input_data])
+        # else:
+        out = self.assign(self.input_data, input_data)
+        out = self.flatten(out)
+        out = self.fc(out)
+        return out, af, be, cd
+
+
+class ControlSimpleIfWithAssign(nn.Cell):
+    def __init__(self, input_shape):
+        super().__init__()
+        self.addn = P.AddN()
+        self.assign = P.Assign()
+        self.input_data = Parameter(initializer(1, input_shape, mstype.float32), name="var")
+
+    def construct(self, x, y, input_data):
+        if x > y:
+            out = self.addn([input_data, input_data, input_data])
+        else:
+            out = self.assign(self.input_data, input_data)
+        return out
+
+
+class ControlIfinIf(nn.Cell):
+    """pass"""
+
+    def construct(self, x, y):
+        if x > y:
+            x = x + 1
+            x = x + 2
+            if y < 0:
+                y = y + 2
+                y = y + 1
+            else:
+                y = y + 2
+        else:
+            x = x + 2
+        x = x + y
+        return x
+
+
+class ControlMixedWhileIf(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.assign = P.Assign()
+        self.var = Parameter(initializer(1, (1), mstype.float32), name="var")
+
+    def construct(self, x, y, z, c2, c4):
+        out = c4
+        self.assign(self.var, c4)
+        while x < c2:
+            y = c4
+            self.assign(self.var, c4)
+            while y < c2 and x < c2:
+                if 2 * y < c2:
+                    y = y + 2
+                else:
+                    y = y + 1
+            out = out + y
+            z = c4
+            self.assign(self.var, c4)
+            while z < c2:
+                z = z + 1
+            out = out + z
+            x = x + 1
+        out = out + x
+        while x < 2 * c2:
+            y = c4
+            self.assign(self.var, c4)
+            x = x + 1
+            while y < c2:
+                z = c4
+                self.assign(self.var, c4)
+                while z < c2:
+                    z = z + 1
+                if x < c2:
+                    y = y - 1
+                else:
+                    y = y + 1
+                out = out + z
+            out = out + y
+        out = out + x
+        return out
diff --git a/mindspore/python/mindspore/rewrite/ut.sh b/mindspore/python/mindspore/rewrite/ut.sh
new file mode 100644
index 00000000000..637df70cf80
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite/ut.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+success=0
+failure=0
+python pattern_engine_match_test.py
+if [[ $? -ne 0 ]]; then
+    echo "---------------- pattern_engine_match_test failed"
+    ((failure=failure+1))
+else
+    echo "---------------- pattern_engine_match_test succeed"
+    ((success=success+1))
+fi
+
+python pattern_engine_test.py
+if [[ $? -ne 0 ]]; then
+    echo "---------------- pattern_engine_test failed"
+    ((failure=failure+1))
+else
+    echo "---------------- pattern_engine_test succeed"
+    ((success=success+1))
+fi
+
+cd ../golden_stick/quantization/
+python transformer_test.py
+if [[ $?
-ne 0 ]]; then + echo "---------------- transformer_test failed" + ((failure=failure+1)) +else + echo "---------------- transformer_test succeed" + ((success=success+1)) +fi + + +echo "=========== rewrite testcases finished, ${success} succeed, ${failure} failed" diff --git a/requirements.txt b/requirements.txt index 3a781b5821f..28e7297e0aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,5 @@ pycocotools >= 2.0.2 # for st test tables >= 3.6.1 # for st test easydict >= 1.9 # for st test psutil >= 5.7.0 +astunparse >= 0.0 +astpretty >= 0.0 -- Gitee From d9ad205b65dec6e89484d75faa2ddfbae0291768 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Fri, 17 Dec 2021 09:49:02 +0800 Subject: [PATCH 02/34] remove mindspore compression --- .../python/mindspore/compression/__init__.py | 17 - .../mindspore/compression/common/__init__.py | 21 - .../mindspore/compression/common/constant.py | 119 ---- .../mindspore/compression/export/__init__.py | 17 - .../compression/export/quant_export.py | 501 -------------- .../mindspore/compression/quant/__init__.py | 25 - .../python/mindspore/compression/quant/qat.py | 618 ------------------ .../compression/quant/quant_utils.py | 440 ------------- .../mindspore/compression/quant/quantizer.py | 64 -- .../python/mindspore/nn/layer/__init__.py | 4 +- .../python/mindspore/train/serialization.py | 43 -- 11 files changed, 1 insertion(+), 1868 deletions(-) delete mode 100644 mindspore/python/mindspore/compression/__init__.py delete mode 100644 mindspore/python/mindspore/compression/common/__init__.py delete mode 100644 mindspore/python/mindspore/compression/common/constant.py delete mode 100644 mindspore/python/mindspore/compression/export/__init__.py delete mode 100644 mindspore/python/mindspore/compression/export/quant_export.py delete mode 100644 mindspore/python/mindspore/compression/quant/__init__.py delete mode 100644 mindspore/python/mindspore/compression/quant/qat.py delete mode 100644 mindspore/python/mindspore/compression/quant/quant_utils.py delete mode 100644 mindspore/python/mindspore/compression/quant/quantizer.py diff --git a/mindspore/python/mindspore/compression/__init__.py b/mindspore/python/mindspore/compression/__init__.py deleted file mode 100644 index 45e57ef38bf..00000000000 --- a/mindspore/python/mindspore/compression/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -MindSpore compression module. -""" diff --git a/mindspore/python/mindspore/compression/common/__init__.py b/mindspore/python/mindspore/compression/common/__init__.py deleted file mode 100644 index c382f47e87b..00000000000 --- a/mindspore/python/mindspore/compression/common/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Common module for various compression algorithms, now only including datatype definition for quantization. -""" - -from .constant import QuantDtype - -__all__ = ["QuantDtype"] diff --git a/mindspore/python/mindspore/compression/common/constant.py b/mindspore/python/mindspore/compression/common/constant.py deleted file mode 100644 index 460940b46f9..00000000000 --- a/mindspore/python/mindspore/compression/common/constant.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Constant module for compression""" -import enum -import re -from types import DynamicClassAttribute - - -__all__ = ["QuantDtype"] - - -@enum.unique -class QuantDtype(enum.Enum): - """ - An enum for quant datatype, contains `INT2` ~ `INT8`, `UINT2` ~ `UINT8`. - """ - INT2 = "INT2" - INT3 = "INT3" - INT4 = "INT4" - INT5 = "INT5" - INT6 = "INT6" - INT7 = "INT7" - INT8 = "INT8" - - UINT2 = "UINT2" - UINT3 = "UINT3" - UINT4 = "UINT4" - UINT5 = "UINT5" - UINT6 = "UINT6" - UINT7 = "UINT7" - UINT8 = "UINT8" - - def __str__(self): - return f"{self.name}" - - @staticmethod - def is_signed(dtype): - """ - Get whether the quant datatype is signed. - - Args: - dtype (QuantDtype): quant datatype. - - Returns: - bool, whether the input quant datatype is signed. - - Examples: - >>> quant_dtype = QuantDtype.INT8 - >>> is_signed = QuantDtype.is_signed(quant_dtype) - """ - return dtype in [QuantDtype.INT2, QuantDtype.INT3, QuantDtype.INT4, QuantDtype.INT5, - QuantDtype.INT6, QuantDtype.INT7, QuantDtype.INT8] - - @staticmethod - def switch_signed(dtype): - """ - Switch the signed state of the input quant datatype. - - Args: - dtype (QuantDtype): quant datatype. - - Returns: - QuantDtype, quant datatype with opposite signed state as the input. 
- - Examples: - >>> quant_dtype = QuantDtype.INT8 - >>> quant_dtype = QuantDtype.switch_signed(quant_dtype) - """ - type_map = { - QuantDtype.INT2: QuantDtype.UINT2, - QuantDtype.INT3: QuantDtype.UINT3, - QuantDtype.INT4: QuantDtype.UINT4, - QuantDtype.INT5: QuantDtype.UINT5, - QuantDtype.INT6: QuantDtype.UINT6, - QuantDtype.INT7: QuantDtype.UINT7, - QuantDtype.INT8: QuantDtype.UINT8, - QuantDtype.UINT2: QuantDtype.INT2, - QuantDtype.UINT3: QuantDtype.INT3, - QuantDtype.UINT4: QuantDtype.INT4, - QuantDtype.UINT5: QuantDtype.INT5, - QuantDtype.UINT6: QuantDtype.INT6, - QuantDtype.UINT7: QuantDtype.INT7, - QuantDtype.UINT8: QuantDtype.INT8 - } - return type_map[dtype] - - @DynamicClassAttribute - def _value(self): - """The value of the Enum member.""" - return int(re.search(r"(\d+)", self._value_).group(1)) - - @DynamicClassAttribute - def num_bits(self): - """ - Get the num bits of the QuantDtype member. - - Returns: - int, the num bits of the QuantDtype member. - - Examples: - >>> from mindspore.compression.common import QuantDtype - >>> quant_dtype = QuantDtype.INT8 - >>> num_bits = quant_dtype.num_bits - >>> print(num_bits) - 8 - """ - return self._value diff --git a/mindspore/python/mindspore/compression/export/__init__.py b/mindspore/python/mindspore/compression/export/__init__.py deleted file mode 100644 index 48e59baa71a..00000000000 --- a/mindspore/python/mindspore/compression/export/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Compression export module. -""" diff --git a/mindspore/python/mindspore/compression/export/quant_export.py b/mindspore/python/mindspore/compression/export/quant_export.py deleted file mode 100644 index 4badbc619ce..00000000000 --- a/mindspore/python/mindspore/compression/export/quant_export.py +++ /dev/null @@ -1,501 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Export for quantization.""" - -import copy - -import numpy as np - -from ... 
import nn, ops -from ..._checkparam import Validator -from ...common import Tensor -from ...common import dtype as mstype -from ...common.api import _cell_graph_executor as _executor -from ...common.parameter import Parameter -from ...nn import Cell -from ...nn.layer import quant -from ...ops import operations as P -from ...ops import functional as F -from ...ops.operations import _inner_ops as inner -from ..quant import quant_utils -from ..quant.qat import _AddFakeQuantInput, _AddFakeQuantAfterSubCell - - -__all__ = ["ExportToQuantInferNetwork"] - - -class QuantBlock(Cell): - r""" - A quant block of Conv/Dense, activation layer for Ascend deploy. - - Calculate Conv or Dense in Int8, with Quant and DeQuant. - - Notes: - This block is only for deploy, and not trainable. - - Args: - in_channels (int): The number of channels in the input space. - out_channels (int): The number of channels in the output space. - weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype - is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. - bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is - same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. - has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. - activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None. - batchnorm (bool): Specifies to used batchnorm or not. Default: None. - activation (string): Specifies activation type. The optional values are as following: - 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', - 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. - - Inputs: - - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. - - Outputs: - Tensor of shape :math:`(N, out\_channels)`. - """ - - def __init__(self, - core_op, - weight, - quant_op, - dequant_op, - dequant_scale, - bias=None, - activation=None): - super(QuantBlock, self).__init__() - self.core_op = core_op - self.weight = weight - self.quant = quant_op - self.dequant = dequant_op - self.dequant_scale = dequant_scale - self.bias = bias - self.has_bias = bias is not None - self.activation = activation - self.has_act = activation is not None - self.bias_add = P.BiasAdd() - self.sub = P.Sub() - self.weight_offset = Parameter(np.zeros(1, dtype=np.int8), name='weight_offset') - - def construct(self, x): - x = self.quant(x) - if self.has_bias: - weight = self.sub(self.weight, self.weight_offset) - x = self.core_op(x, weight) - x = self.bias_add(x, self.bias) - else: - x = self.core_op(x, self.weight) - x = self.dequant(x, self.dequant_scale) - x = F.cast(x, mstype.float32) - if self.has_act: - x = self.activation(x) - return x - - def extend_repr(self): - s = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]' - if self.has_bias: - s += f', bias=shape[{self.bias.shape}]' - if self.has_act: - s += f', activation={self.activation}' - s += f', dequant={self.dequant}' - return s - - -class QuantMindirBlock(Cell): - """A quant binary block of Conv/Dense, activation layer for export MINDIR model. - - Args: - core_op (Cell): The operation cell. - weight (Tensor): The weight of the cell. - bias (Tensor): The bias of the cell. Default: None. - activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None. 
- param_dict (dict): The information of the cell. - """ - - def __init__(self, - core_op, - weight, - bias=None, - activation=None, - param_dict=None): - - super(QuantMindirBlock, self).__init__() - self.core_op = core_op - if activation is not None: - self.core_op.add_prim_attr("activation_name", activation.__class__.__name__) - self.core_op.add_prim_attr("filter_maxq", Tensor(param_dict["filter_maxq"])) - self.core_op.add_prim_attr("filter_minq", Tensor(param_dict["filter_minq"])) - if param_dict["output_maxq"] is not None: - self.core_op.add_prim_attr("output_maxq", Tensor(param_dict["output_maxq"])) - self.core_op.add_prim_attr("output_minq", Tensor(param_dict["output_minq"])) - self.core_op.add_prim_attr("symmetric", Tensor(param_dict["symmetric"])) - if hasattr(core_op, 'pad_mode'): - self.core_op.add_prim_attr("pad_mode", core_op.pad_mode) - self.core_op.add_prim_attr("act_num_bits", Tensor(8)) - self.core_op.add_prim_attr("weight_num_bits", Tensor(param_dict["weight_num_bits"])) - self.core_op.add_prim_attr("weight_narrow_range", Tensor(param_dict["weight_narrow_range"])) - if param_dict["input_narrow_range"] is not None: - self.core_op.add_prim_attr("input_narrow_range", Tensor(param_dict["input_narrow_range"])) - if param_dict["output_narrow_range"] is not None: - self.core_op.add_prim_attr("output_narrow_range", Tensor(param_dict["output_narrow_range"])) - if param_dict["input_maxq"] == 'None': - self.core_op.add_prim_attr("mean", Tensor(param_dict["mean"])) - self.core_op.add_prim_attr("std_dev", Tensor(param_dict["std_dev"])) - elif param_dict["input_maxq"] is not None: - self.core_op.add_prim_attr("input_maxq", Tensor(param_dict["input_maxq"])) - self.core_op.add_prim_attr("input_minq", Tensor(param_dict["input_minq"])) - - self.weight = weight - self.bias = bias - self.has_bias = bias is not None - self.activation = activation - self.has_act = activation is not None - self.bias_add = P.BiasAdd() - - def construct(self, x): - if self.has_bias: - x = self.core_op(x, self.weight) - x = self.bias_add(x, self.bias) - else: - x = self.core_op(x, self.weight) - if self.has_act: - x = self.activation(x) - return x - - def extend_repr(self): - s = f'core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]' - if self.has_bias: - s += f', bias=shape[{self.bias.shape}]' - if self.has_act: - s += f', activation={self.activation}' - return s - - -class ExportToQuantInferNetwork: - """ - Convert quantization aware network to infer network. - - Args: - network (Cell): MindSpore quantization aware training network. - inputs (Tensor): Input tensors of the `quantization aware training network`. - mean (int, float): The mean of input data after preprocessing, used for quantizing the first layer of network. - Default: 127.5. - std_dev (int, float): The variance of input data after preprocessing, used for quantizing the first layer - of network. Default: 127.5. - is_mindir (bool): Whether export MINDIR format. Default: False. - - Returns: - Cell, Infer network. 
- """ - - def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): - network = Validator.check_isinstance('network', network, (nn.Cell,)) - self.data_type = mstype.int8 - self.network = copy.deepcopy(network) - self.network_bk = copy.deepcopy(network) - self.get_inputs_table(inputs) - self.mean = mean - self.std_dev = std_dev - self.is_mindir = is_mindir - self.upcell = None - - def get_inputs_table(self, inputs): - """Get the input quantization parameters of quantization cell for quant export.""" - phase_name = 'export_quant' - graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False) - self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id) - - def run(self): - """Start to convert.""" - self.network.update_cell_prefix() - network = self.network - if isinstance(network, _AddFakeQuantInput): - network = network.network - network = self._convert_quant2deploy(network) - return network - - def _get_quant_block(self, cell_core, activation, fake_quant_a_out): - """convert network's quant subcell to deploy subcell""" - scale_a_in, zp_a_in, scale_w, zp_w, param_dict = self.__get_quant_param(cell_core, fake_quant_a_out) - - # Build the `Quant` `Dequant` op. - # Quant only support perlayer version. Need check here. - quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in)) - scale_deq = self.__get_dequant_scale(scale_a_in, scale_w) - dequant_op = inner.Dequant() - - if isinstance(activation, _AddFakeQuantAfterSubCell): - activation = activation.subcell - elif hasattr(activation, "get_origin"): - activation = activation.get_origin() - - # get op - if isinstance(cell_core, quant.DenseQuant): - op_core = P.MatMul() - else: - op_core = cell_core.conv - - # get the `weight` and `bias` - weight, bias, weight_b, bias_b = self.__get_weight_bias(cell_core, scale_a_in, scale_w, zp_w) - - if self.is_mindir: - block = QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict) - else: - block = QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation) - return block - - def _get_input_quant_param(self, minq_name, np_type, param_dict): - """get input quant parameter for quant block""" - fake_quant_a_in_prefix = minq_name[:-5] - cells = self.network_bk.cells_and_names() - for cell in cells: - if cell[0].endswith(fake_quant_a_in_prefix): - fake_quant_a_in = cell[1] - break - scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \ - quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type) - param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range - return scale_a_in, zp_a_in - - def __get_quant_param(self, cell_core, fake_quant_a_out): - """get parameter for quant block""" - w_minq_name = cell_core.fake_quant_weight.minq.name - w_maxq_name = cell_core.fake_quant_weight.maxq.name - np_type = mstype.dtype_to_nptype(self.data_type) - param_dict = dict() - param_dict["filter_maxq"] = None - param_dict["filter_minq"] = None - param_dict["output_maxq"] = None - param_dict["output_minq"] = None - param_dict["input_maxq"] = None - param_dict["input_minq"] = None - param_dict["input_narrow_range"] = None - param_dict["output_narrow_range"] = None - param_dict["weight_narrow_range"] = cell_core.fake_quant_weight.narrow_range - param_dict["mean"] = self.mean - param_dict["std_dev"] = self.std_dev - param_dict["symmetric"] = cell_core.fake_quant_weight.symmetric - param_dict["weight_num_bits"] = cell_core.fake_quant_weight.num_bits - - scale_w, zp_w, param_dict["filter_maxq"], 
param_dict["filter_minq"] = \ - quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type) - if fake_quant_a_out is not None: - _, _, param_dict["output_maxq"], param_dict["output_minq"] = \ - quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type) - param_dict["output_narrow_range"] = fake_quant_a_out.narrow_range - - info = self.quant_info_table.get(w_minq_name, None) - if not info: - info = self.quant_info_table.get(w_maxq_name, None) - if info: - _, minq_name = info - if minq_name == 'input': - scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \ - (1 / self.std_dev), round(self.mean), 'None', 'None' - else: - scale_a_in, zp_a_in = self._get_input_quant_param(minq_name, np_type, param_dict) - else: - # skip quant layer - scale_a_in, zp_a_in = 1.0, 0.0 - return scale_a_in, zp_a_in, scale_w, zp_w, param_dict - - @staticmethod - def __get_dequant_scale(scale_a_in, scale_w): - """Get dequant scale""" - scale_deq = scale_a_in * scale_w - - # fuse parameter - # |--------|47:40|--------|39:32|--------|31:0| - # offset_w [8] shift_N [8] deq_scale [32] - float32_deq_scale = scale_deq.astype(np.float32) - uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32) - scale_length = scale_deq.size # channel - dequant_param = np.zeros(scale_length, dtype=np.uint64) - for index in range(scale_length): - dequant_param[index] += uint32_deq_scale[index] - scale_deq = Tensor(dequant_param, mstype.uint64) - return scale_deq - - def __get_weight_bias(self, cell_core, scale_a_in, scale_w, zp_w): - """Get weight and bias for quantizaiton""" - np_type = mstype.dtype_to_nptype(self.data_type) - weight = cell_core.weight.data.asnumpy() - bias = None - if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)): - if cell_core.has_bias: - bias = cell_core.bias.data.asnumpy() - elif isinstance(cell_core, (quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv)): - weight, bias = quant_utils.fold_batchnorm(weight, cell_core) - elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant): - weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core) - weight_b = weight - bias_b = bias - # apply the quant - quant_min, quant_max = quant_utils.get_quant_min_max(np_type, - cell_core.fake_quant_weight.num_bits, - cell_core.fake_quant_weight.narrow_range) - weight = quant_utils.weight2int(weight, scale_w, zp_w, quant_min, quant_max) - if bias is not None: - bias = Tensor(bias / scale_a_in / scale_w, mstype.int32) - - if isinstance(cell_core, quant.DenseQuant): - weight = np.transpose(weight) - weight_b = np.transpose(weight_b) - - weight = Tensor(weight, self.data_type) - weight_b = Tensor(weight_b) - if bias_b is not None: - bias_b = Tensor(bias_b, mstype.float32) - return weight, bias, weight_b, bias_b - - def _add_output_min_max_for_op(self, origin_op, fake_quant_cell): - """add output quant info for quant op for export mindir.""" - if self.is_mindir: - if isinstance(origin_op, ops.Primitive) and not hasattr(origin_op, 'output_minq'): - np_type = mstype.dtype_to_nptype(self.data_type) - _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_cell, np_type) - origin_op.add_prim_attr('output_maxq', Tensor(maxq)) - origin_op.add_prim_attr('output_minq', Tensor(minq)) - - def _convert_subcell(self, network, change, name, subcell): - """Convert subcell to ant subcell.""" - if subcell is not None and hasattr(subcell, "fake_quant_weight"): - new_subcell = self._get_quant_block(subcell, None, None) - prefix = 
subcell.param_prefix - new_subcell.update_parameters_name(prefix + '.') - self.upcell = new_subcell - network.insert_child_to_cell(name, new_subcell) - change = True - return network, change - - def _convert_conv(self, network, change, name, subcell): - """Convert subcell to ant subcell for conv.""" - cell_core = subcell.conv - activation = subcell.activation - fake_quant_act = None - if hasattr(activation, 'fake_quant_act_before'): - fake_quant_act = activation.fake_quant_act_before - elif hasattr(activation, 'fake_quant_act'): - fake_quant_act = activation.fake_quant_act - if cell_core is not None and hasattr(cell_core, "fake_quant_weight"): - new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act) - self.upcell = None - prefix = subcell.param_prefix - new_subcell.update_parameters_name(prefix + '.') - network.insert_child_to_cell(name, new_subcell) - change = True - return network, change - - def _convert_dense(self, network, change, name, subcell): - """Convert subcell to ant subcell for dense.""" - cell_core = subcell.dense - activation = subcell.activation - fake_quant_act = None - if hasattr(activation, 'fake_quant_act_before'): - fake_quant_act = activation.fake_quant_act_before - elif hasattr(activation, 'fake_quant_act'): - fake_quant_act = activation.fake_quant_act - if cell_core is not None and hasattr(cell_core, "fake_quant_weight"): - new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act) - prefix = subcell.param_prefix - new_subcell.update_parameters_name(prefix + '.') - network.insert_child_to_cell(name, new_subcell) - self.upcell = None - change = True - return network, change - - def _convert_act(self, subcell): - """Convert subcell to ant subcell for activation.""" - activation = subcell.get_origin() - if isinstance(activation, nn.ReLU): - self._add_output_min_max_for_op(activation.relu, subcell.fake_quant_act) - elif isinstance(activation, nn.ReLU6): - self._add_output_min_max_for_op(activation.relu6, subcell.fake_quant_act) - if self.upcell: - self._add_output_min_max_for_op(self.upcell.core_op, subcell.fake_quant_act) - return activation - - def _convert_add(self, subcell): - """Convert subcell to ant subcell for add.""" - if isinstance(subcell.add, _AddFakeQuantAfterSubCell): - add_op = subcell.add.subcell - subcell.__delattr__("add") - subcell.__setattr__("add", add_op) - add_op = subcell.add - self._add_output_min_max_for_op(add_op, subcell.fake_quant_act) - subcell.__delattr__("fake_quant_act") - subcell.__setattr__("fake_quant_act", P.identity()) - - def _convert_observer(self, network, name, subcell): - """Convert subcell to ant subcell for FakeQuantWithMinMaxObserver.""" - if self.upcell: - self._add_output_min_max_for_op(self.upcell.core_op, subcell) - network.__delattr__(name) - network.__setattr__(name, P.identity()) - - def _convert_fake_quant_after_cell(self, network, name, subcell): - """Convert subcell to ant subcell for _AddFakeQuantAfterSubCell.""" - op = subcell.subcell - self._add_output_min_max_for_op(op, subcell.fake_quant_act) - network.__delattr__(name) - network.__setattr__(name, op) - - def _convert_core_quant_subcell(self, network, change, name, subcell): - """Convert subcell to ant subcell for conv and dense.""" - is_core_subcell = True - if isinstance(subcell, nn.Conv2dBnAct): - network, change = self._convert_conv(network, change, name, subcell) - elif isinstance(subcell, nn.DenseBnAct): - network, change = self._convert_dense(network, change, name, subcell) - elif isinstance(subcell, 
(quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv, - quant.Conv2dBnWithoutFoldQuant, quant.Conv2dQuant, quant.DenseQuant)): - network, change = self._convert_subcell(network, change, name, subcell) - else: - is_core_subcell = False - return is_core_subcell, network, change - - def _convert_other_quant_subcell(self, network, change, name, subcell): - """Convert subcell to ant subcell for cell except conv and dense.""" - is_other_subcell = True - if isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"): - activation = self._convert_act(subcell) - network.insert_child_to_cell(name, activation) - change = True - elif isinstance(subcell, nn.TensorAddQuant): - self._convert_add(subcell) - elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver): - self._convert_observer(network, name, subcell) - elif isinstance(subcell, _AddFakeQuantAfterSubCell): - self._convert_fake_quant_after_cell(network, name, subcell) - change = True - else: - is_other_subcell = False - return is_other_subcell, network, change - - def _convert_quant2deploy(self, network): - """Convert network's all quant subcell to deploy subcell.""" - cells = network.name_cells() - change = False - for name in cells: - subcell = cells[name] - if subcell == network: - continue - is_core_quant_subcell, network, change = self._convert_core_quant_subcell(network, change, name, subcell) - is_other_quant_subcell, network, change = self._convert_other_quant_subcell(network, change, name, subcell) - if not is_core_quant_subcell and not is_other_quant_subcell: - self.upcell = None - self._convert_quant2deploy(subcell) - if isinstance(network, nn.SequentialCell) and change: - network.cell_list = list(network.cells()) - return network diff --git a/mindspore/python/mindspore/compression/quant/__init__.py b/mindspore/python/mindspore/compression/quant/__init__.py deleted file mode 100644 index e2b8cf0f83d..00000000000 --- a/mindspore/python/mindspore/compression/quant/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Quantization module, including base class of the quantizer, the quantization aware training algorithm, -and quantization utils. 
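For reference, the fuse-parameter packing described in `__get_dequant_scale` above (offset_w in bits 47:40, shift_N in bits 39:32, deq_scale in bits 31:0) can be sketched in plain numpy. This is a minimal illustration of the deleted code's packing step only, with `shift_N` and `offset_w` left at zero:

    import numpy as np

    def pack_dequant_scale(scale_deq):
        # Reinterpret each float32 scale as a uint32 bit pattern and store
        # it in the low 32 bits of a uint64 word; bits 39:32 (shift_N) and
        # bits 47:40 (offset_w) remain zero in this sketch.
        float32_scale = scale_deq.astype(np.float32)
        uint32_scale = np.frombuffer(float32_scale.tobytes(), np.uint32)
        packed = np.zeros(scale_deq.size, dtype=np.uint64)
        packed += uint32_scale
        return packed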
-""" - -from .quantizer import OptimizeOption -from .qat import QuantizationAwareTraining, create_quant_config -from .quant_utils import load_nonquant_param_into_quant_net, query_quant_layers - -__all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers", "QuantizationAwareTraining", - "create_quant_config", "OptimizeOption"] diff --git a/mindspore/python/mindspore/compression/quant/qat.py b/mindspore/python/mindspore/compression/quant/qat.py deleted file mode 100644 index 3c8ccbcae56..00000000000 --- a/mindspore/python/mindspore/compression/quant/qat.py +++ /dev/null @@ -1,618 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Quantization aware training - -User can use quantization aware to train a model. MindSpore supports quantization aware training, -which models quantization errors in both the forward and backward passes using fake-quantization -operations. Note that the entire computation is carried out in floating point. At the end of quantization -aware training, MindSpore provides conversion functions to convert the trained model into lower precision. -""" - -import re -import mindspore.context as context -import numpy as np -from ... import nn, ops -from ..._checkparam import Validator, Rel -from ...nn.layer import quant -from ...ops import functional as F -from ..common import QuantDtype -from .quantizer import Quantizer, OptimizeOption -from .quant_utils import compute_kl_threshold - - -__all__ = ["QuantizationAwareTraining", "create_quant_config"] - - -def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver), - quant_delay=(0, 0), - quant_dtype=(QuantDtype.INT8, QuantDtype.INT8), - per_channel=(False, False), - symmetric=(False, False), - narrow_range=(False, False), - mode="DEFAULT"): - r""" - Config the observer type of weights and data flow with quant parameters. - - Args: - quant_observer (Union[Observer, list, tuple]): The types of observer for quantization. The first element - applies to weights and the second applies to data flow. Currently, only - :class:`FakeQuantWithMinMaxObserver` supported. - Default: (nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver). - quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized - during train and eval. The first element represents weights and the second element represents data flow. - Default: (0, 0). - quant_dtype (Union[QuantDtype, list, tuple]): Datatype used to quantize weights and activations. The first - element represents weights and the second element represents data flow. - Default: (QuantDtype.INT8, QuantDtype.INT8). - per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True` - then base on per channel, otherwise base on per layer. 
The first element represents weights - and the second element represents data flow, and the second element must be `False` now. - Default: (False, False). - symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric or not. If `True` then - base on symmetric, otherwise base on asymmetric. The first element represents weights and the second - element represents data flow. Default: (False, False). - narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range or not. - The first element represents weights and the second element represents data flow. - Default: (False, False). - mode (str): Optional quantization mode, currently only `DEFAULT`(QAT) and `LEARNED_SCALE` are supported. - Default: ("DEFAULT"). - - Returns: - QuantConfig, contains the observer type of weight and activation. - - Raises: - ValueError: If the second element of `per_channel` is not `False`. - """ - if per_channel[-1]: - raise ValueError("Arg 'per_channel' second element must be 'False'.") - weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0], - per_channel=per_channel[0], symmetric=symmetric[0], - narrow_range=narrow_range[0], mode=mode) - act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1], - per_channel=per_channel[-1], symmetric=symmetric[-1], - narrow_range=narrow_range[-1], mode=mode) - return quant.QuantConfig(weight=weight_observer, activation=act_observer) - - -class _AddFakeQuantInput(nn.Cell): - """ - Add FakeQuant OP at input of the network. Only support one input case. - """ - - def __init__(self, network, quant_delay=0): - super(_AddFakeQuantInput, self).__init__(auto_prefix=False) - self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6, - quant_delay=quant_delay, ema=True) - self.fake_quant_input.update_parameters_name('fake_quant_input.') - self.network = network - - def construct(self, data): - data = self.fake_quant_input(data) - output = self.network(data) - return output - - -class _AddFakeQuantAfterSubCell(nn.Cell): - """ - Add FakeQuant OP after of the sub Cell. - """ - - def __init__(self, subcell, **kwargs): - super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False) - self.subcell = subcell - self.mode = "DEFAULT" - self.max_init = 6 - self.min_init = -6 - - if OptimizeOption.LEARNED_SCALE in kwargs["optimize_option"]: - self.mode = "LEARNED_SCALE" - self.max_init = 16 - self.min_init = -16 - - self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=self.min_init, - max_init=self.max_init, - ema=True, - quant_dtype=kwargs["quant_dtype"], - quant_delay=kwargs["quant_delay"], - per_channel=kwargs["per_channel"], - symmetric=kwargs["symmetric"], - narrow_range=kwargs["narrow_range"], - mode=self.mode) - - def construct(self, *data): - output = self.subcell(*data) - output = self.fake_quant_act(output) - return output - - -class QuantizationAwareTraining(Quantizer): - r""" - Quantizer for quantization aware training. - - Args: - bn_fold (bool): Whether to use bn fold ops for simulation inference operation. Default: True. - freeze_bn (int): Number of steps after which BatchNorm OP parameters fixed to global mean and variance. - Default: 1e7. - quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized - during train and eval. The first element represents weights and the second element represents data flow. - Default: (0, 0). 
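A typical call to the `create_quant_config` helper above looks as follows (a sketch: per-channel, symmetric weights with per-layer activations; the returned `QuantConfig` carries two partially-initialized observer constructors that the quant cells instantiate later):

    qconfig = create_quant_config(per_channel=(True, False),
                                  symmetric=(True, False),
                                  narrow_range=(True, False))
    # qconfig.weight and qconfig.activation are consumed by cells such as
    # Conv2dBnFoldQuant, which supply the remaining constructor arguments.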
- quant_dtype (Union[QuantDtype, list, tuple]): Datatype used to quantize weights and activations. The first - element represents weights and the second element represents data flow. It is necessary to consider the - precision support of hardware devices in the practical quantization infer scenario. - Default: (QuantDtype.INT8, QuantDtype.INT8). - per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True` - then base on per channel, otherwise base on per layer. The first element represents weights and the - second element represents data flow, and the second element must be `False` now. Default: (False, False). - symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric or not. If `True` then - base on symmetric, otherwise base on asymmetric. The first element represents weights and the second - element represents data flow. Default: (False, False). - narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range or not. - The first element represents weights and the second element represents data flow. - Default: (False, False). - optimize_option (Union[OptimizeOption, list, tuple]): Specifies the quant algorithm and options, currently - only support `QAT` and `LEARNED_SCALE` (Note that, if both `QAT` and `LEARNED_SCALE` are configured, - `LEARNED_SCALE` has a higher priority. `LEARNED_SCALE` currently only work under some constraints, which - includes: freeze_bn=0, quant_delay=0, symmetric=True, narrow_range=True, More specifically, for operators - such as Relu and Relu6, which only have positive values, we add a negative truncation to optimize this - scenario, and narrow_range will automatically match to False). Default: OptimizeOption.QAT. - one_conv_fold (bool): Whether to use one conv bn fold ops for simulation inference operation. Default: True. - - Raises: - TypeError: If the element of `quant_delay` or `freeze_bn` is not int. - TypeError: If `bn_fold`, `one_conv_fold` or the element of `per_channel`, `symmetric`, `narrow_range` - is not bool. - TypeError: If the element of `quant_dtype` is not `QuantDtype`. - ValueError: If the length of `quant_delay`, `quant_dtype`, `per_channel`, `symmetric` or `narrow_range` is - not less than 2. - ValueError: If the `optimize_option` is `LEARNED_SCALE` and `freeze_bn` is not equal to 0. - ValueError: If the `optimize_option` is `LEARNED_SCALE` and `symmetric` is not (True, True). - ValueError: If the `optimize_option` is `LEARNED_SCALE` and `narrow_range` is not (True, True). - ValueError: If the `optimize_option` is `LEARNED_SCALE` and `quant_delay` is not (0, 0). - - Examples: - >>> from mindspore.compression.quant import QuantizationAwareTraining - >>> class LeNet5(nn.Cell): - ... def __init__(self, num_class=10, channel=1): - ... super(LeNet5, self).__init__() - ... self.type = "fusion" - ... self.num_class = num_class - ... - ... # change `nn.Conv2d` to `nn.Conv2dBnAct` - ... self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') - ... self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') - ... # change `nn.Dense` to `nn.DenseBnAct` - ... self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') - ... self.fc2 = nn.DenseBnAct(120, 84, activation='relu') - ... self.fc3 = nn.DenseBnAct(84, self.num_class) - ... - ... self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - ... self.flatten = nn.Flatten() - ... - ... def construct(self, x): - ... x = self.conv1(x) - ... 
x = self.max_pool2d(x) - ... x = self.conv2(x) - ... x = self.max_pool2d(x) - ... x = self.flatten(x) - ... x = self.fc1(x) - ... x = self.fc2(x) - ... x = self.fc3(x) - ... return x - ... - >>> net = LeNet5() - >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False]) - >>> net_qat = quantizer.quantize(net) - """ - __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv", "ReduceMean"] - - def __init__(self, - bn_fold=True, - freeze_bn=10000000, - quant_delay=(0, 0), - quant_dtype=(QuantDtype.INT8, QuantDtype.INT8), - per_channel=(False, False), - symmetric=(False, False), - narrow_range=(False, False), - optimize_option=OptimizeOption.QAT, - one_conv_fold=True): - """Init for QuantizationAwareTraining quantizer""" - super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option) - - def convert2list(name, value): - if not isinstance(value, list) and not isinstance(value, tuple): - value = [value] - elif len(value) > 2: - raise ValueError("input `{}` len should less then 2".format(name)) - return value - - quant_delay = convert2list("quant delay", quant_delay) - quant_dtype = convert2list("quant dtype", quant_dtype) - per_channel = convert2list("per channel", per_channel) - symmetric = convert2list("symmetric", symmetric) - narrow_range = convert2list("narrow range", narrow_range) - - self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay") - self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay") - self.bn_fold = Validator.check_bool(bn_fold, "bn fold") - self.freeze_bn = Validator.check_non_negative_int(freeze_bn, "freeze bn") - self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype) - self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype) - self.weight_channel = Validator.check_bool(per_channel[0], "per channel") - self.act_channel = Validator.check_bool(per_channel[-1], "per channel") - self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric") - self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric") - self.weight_range = Validator.check_bool(narrow_range[0], "narrow range") - self.act_range = Validator.check_bool(narrow_range[-1], "narrow range") - self.one_conv_fold = Validator.check_bool(one_conv_fold, "one conv fold") - self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv, - nn.DenseBnAct: self._convert_dense} - self.mode = "DEFAULT" - if OptimizeOption.LEARNED_SCALE in self.optimize_option: - self.mode = "LEARNED_SCALE" - if not self.weight_symmetric or not self.act_symmetric: - raise ValueError("OptimizeOption.LEARNED_SCALE currently only support " - "symmetric=(True, True) for quant") - if not self.weight_range or not self.act_range: - raise ValueError("OptimizeOption.LEARNED_SCALE currently only support narrow_range=(True, True) " - "for quant") - if self.freeze_bn != 0: - raise ValueError("OptimizeOption.LEARNED_SCALE currently only support freeze_bn equal to 0, " - "but get freeze_bn={}".format(self.freeze_bn)) - if self.weight_qdelay != 0 or self.act_qdelay != 0: - raise ValueError("OptimizeOption.LEARNED_SCALE currently only support quant_delay=(0, 0)") - self.quant_config = create_quant_config(quant_delay=quant_delay, - quant_dtype=quant_dtype, - per_channel=per_channel, - symmetric=symmetric, - narrow_range=narrow_range, - mode=self.mode) - self.eps = 1e-5 - - @staticmethod - def _convert_op_name(name): - pattern = re.compile(r'([A-Z]{1})') - 
name_new = re.sub(pattern, r'_\1', name).lower() - if name_new[0] == '_': - name_new = name_new[1:] - return name_new - - def quantize(self, network): - """ - Quant API to convert input network to a quantization aware training network. - - Note: - Please refer to the Examples of class: `mindspore.compression.quant.QuantizationAwareTraining`. - - Args: - network (Cell): network to be quantized. - - Returns: - Cell, a quantization aware training network. - - Raises: - KeyError: If the `device_target` set in context is not in `support_device`. - """ - support_device = ["Ascend", "GPU"] - if context.get_context('device_target') not in support_device: - raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) - - if OptimizeOption.QAT in self.optimize_option or OptimizeOption.LEARNED_SCALE in self.optimize_option: - network.update_cell_prefix() - network = self._convert_subcells2quant(network) - network.update_cell_type("quant") - return network - - def _convert_subcells2quant(self, network): - """ - convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell - """ - cells = network.name_cells() - change = False - for name in cells: - subcell = cells[name] - if subcell == network: - continue - elif isinstance(subcell, (nn.Conv2dBnAct, nn.DenseBnAct)): - prefix = subcell.param_prefix - new_subcell = self._convert_method_map[type(subcell)](subcell) - new_subcell.update_parameters_name(prefix + '.') - network.insert_child_to_cell(name, new_subcell) - change = True - else: - self._convert_subcells2quant(subcell) - if isinstance(network, nn.SequentialCell) and change: - network.cell_list = list(network.cells()) - - # add FakeQuant OP after OP in white list, but not including those wrapped in the below quantization cell. 
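The `_convert_op_name` helper above lowers CamelCase primitive names to snake_case, for example:

    >>> QuantizationAwareTraining._convert_op_name("ReduceMean")
    'reduce_mean'
    >>> QuantizationAwareTraining._convert_op_name("Add")
    'add'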
- if isinstance(network, (nn.FakeQuantWithMinMaxObserver, - nn.Conv2dBnFoldQuantOneConv, - nn.Conv2dBnFoldQuant, - nn.Conv2dBnWithoutFoldQuant, - nn.Conv2dQuant, - nn.DenseQuant, - nn.ActQuant, - nn.TensorAddQuant, - nn.MulQuant)): - return network - - add_list = [] - for name in network.__dict__: - if name[0] == '_': - continue - attr = network.__dict__[name] - if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__: - add_list.append((name, attr)) - for name, prim_op in add_list: - prefix = name - add_quant = _AddFakeQuantAfterSubCell(prim_op, - quant_dtype=self.act_dtype, - quant_delay=self.act_qdelay, - per_channel=self.act_channel, - symmetric=self.act_symmetric, - narrow_range=self.act_range, - optimize_option=self.optimize_option) - if network.param_prefix: - prefix = '.'.join([network.param_prefix, prefix]) - add_quant.update_parameters_name(prefix + '.') - del network.__dict__[name] - network.insert_child_to_cell(name, add_quant) - return network - - def _convert_conv(self, subcell): - """ - convert Conv2d cell to quant cell - """ - min_init = -6 - max_init = 6 - if OptimizeOption.LEARNED_SCALE in self.optimize_option: - subcell_weight_para = subcell.conv.weight.data.asnumpy() - if subcell.has_bn: - scale_factor = (subcell.batchnorm.gamma.data.asnumpy() / - np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps)) - subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) - min_init, max_init = self._kl_init(subcell_weight_para, self.weight_dtype) - self.quant_config = self.quant_config._replace( - weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init)) - - conv_inner = subcell.conv - if subcell.has_bn: - bn_inner = subcell.batchnorm - if self.bn_fold: - if self.one_conv_fold: - conv_inner = quant.Conv2dBnFoldQuantOneConv(conv_inner.in_channels, - conv_inner.out_channels, - kernel_size=conv_inner.kernel_size, - stride=conv_inner.stride, - pad_mode=conv_inner.pad_mode, - padding=conv_inner.padding, - dilation=conv_inner.dilation, - group=conv_inner.group, - eps=bn_inner.eps, - momentum=1 - bn_inner.momentum, - has_bias=conv_inner.has_bias, - bias_init=conv_inner.bias_init, - quant_config=self.quant_config, - quant_dtype=self.weight_dtype, - fake=True) - else: - conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels, - conv_inner.out_channels, - kernel_size=conv_inner.kernel_size, - stride=conv_inner.stride, - pad_mode=conv_inner.pad_mode, - padding=conv_inner.padding, - dilation=conv_inner.dilation, - group=conv_inner.group, - eps=bn_inner.eps, - momentum=1 - bn_inner.momentum, - has_bias=conv_inner.has_bias, - bias_init=conv_inner.bias_init, - freeze_bn=self.freeze_bn, - quant_config=self.quant_config, - quant_dtype=self.weight_dtype, - fake=True) - # change original network Batch Normalization OP parameters to quant network - conv_inner.gamma = subcell.batchnorm.gamma - conv_inner.beta = subcell.batchnorm.beta - conv_inner.moving_mean = subcell.batchnorm.moving_mean - conv_inner.moving_variance = subcell.batchnorm.moving_variance - else: - conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels, - conv_inner.out_channels, - kernel_size=conv_inner.kernel_size, - stride=conv_inner.stride, - pad_mode=conv_inner.pad_mode, - padding=conv_inner.padding, - dilation=conv_inner.dilation, - group=conv_inner.group, - eps=bn_inner.eps, - momentum=1 - bn_inner.momentum, - has_bias=conv_inner.has_bias, - bias_init=conv_inner.bias_init, - quant_config=self.quant_config, - 
quant_dtype=self.weight_dtype) - # change original network Batch Normalization OP parameters to quant network - conv_inner.batchnorm.gamma = subcell.batchnorm.gamma - conv_inner.batchnorm.beta = subcell.batchnorm.beta - conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean - conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance - del subcell.batchnorm - subcell.batchnorm = None - subcell.has_bn = False - else: - conv_inner = quant.Conv2dQuant(conv_inner.in_channels, conv_inner.out_channels, - kernel_size=conv_inner.kernel_size, stride=conv_inner.stride, - pad_mode=conv_inner.pad_mode, padding=conv_inner.padding, - dilation=conv_inner.dilation, group=conv_inner.group, - has_bias=conv_inner.has_bias, quant_config=self.quant_config, - quant_dtype=self.weight_dtype) - # change original network Conv2D OP parameters to quant network - conv_inner.weight = subcell.conv.weight - if subcell.conv.has_bias: - conv_inner.bias = subcell.conv.bias - subcell.conv = conv_inner - if subcell.has_act and subcell.activation is not None: - subcell.activation = self._convert_activation(subcell.activation) - elif subcell.after_fake: - subcell.has_act = True - subcell.activation = _AddFakeQuantAfterSubCell(F.identity, quant_dtype=self.act_dtype, - quant_delay=self.act_qdelay, per_channel=self.act_channel, - symmetric=self.act_symmetric, narrow_range=self.act_range, - optimize_option=self.optimize_option) - return subcell - - def _convert_dense(self, subcell): - """ - convert dense cell to quant cell - """ - min_init = -6 - max_init = 6 - if OptimizeOption.LEARNED_SCALE in self.optimize_option: - subcell_weight_para = subcell.dense.weight.data.asnumpy() - if subcell.has_bn: - scale_factor = (subcell.batchnorm.gamma.data.asnumpy() / - np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps)) - subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) - min_init, max_init = self._kl_init(subcell_weight_para, self.weight_dtype) - self.quant_config = self.quant_config._replace( - weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init)) - - dense_inner = subcell.dense - dense_inner = quant.DenseQuant(dense_inner.in_channels, - dense_inner.out_channels, - has_bias=dense_inner.has_bias, - quant_config=self.quant_config, - quant_dtype=self.weight_dtype) - # change original network Dense OP parameters to quant network - dense_inner.weight = subcell.dense.weight - if subcell.dense.has_bias: - dense_inner.bias = subcell.dense.bias - subcell.dense = dense_inner - if subcell.has_act and subcell.activation is not None: - subcell.activation = self._convert_activation(subcell.activation) - elif subcell.after_fake: - subcell.has_act = True - subcell.activation = _AddFakeQuantAfterSubCell(F.identity, - quant_dtype=self.act_dtype, - quant_delay=self.act_qdelay, - per_channel=self.act_channel, - symmetric=self.act_symmetric, - narrow_range=self.act_range, - optimize_option=self.optimize_option) - return subcell - - def _convert_activation(self, activation): - """ - convert activation cell to quant cell - """ - act_class = activation.__class__ - act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid] - act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish] - - if act_class in act_list: - return quant.ActQuant(activation=activation, - quant_config=self.quant_config, - quant_dtype=self.act_dtype) - if act_class in act_list_with_fake_before: - return quant.ActQuant(activation=activation, - ema=True, - fake_before=True, - quant_config=self.quant_config, 
- quant_dtype=self.act_dtype) - raise ValueError("Unsupported activation in auto quant: ", act_class) - - def _kl_init(self, subcell_weight_para, weight_dtype): - """ - Calculate the value of max_init and min_init with compute_kl_threshold. - """ - if self.weight_channel: - max_init = [compute_kl_threshold(weight_para_each, weight_dtype) - for weight_para_each in subcell_weight_para] - min_init = [-x for x in max_init] - else: - max_init = [compute_kl_threshold(subcell_weight_para, weight_dtype)] - min_init = [-x for x in max_init] - return min_init, max_init - - def _set_mixed_bits(self, network, strategy): - r""" - Set network's quantization strategy, this function is currently only valid for `LEARNED_SCALE` - optimize_option. - - Args: - network (Cell): Input network. - strategy (list): The quantization strategy for layers that need to be quantified (eg. [[8], [8], - ..., [6], [4], [8]]), currently only the quant_dtype for weights of the dense layer and the - convolution layer is supported. - - Returns: - Cell, a network with mixed bit strategy configured. - - Raises: - ValueError: If `OptimizeOption.LEARNED_SCALE` is not in `self.optimize_option`. - """ - if OptimizeOption.LEARNED_SCALE not in self.optimize_option: - raise ValueError("The `_set_mixed_bits` function is currently only valid for `LEARNED_SCALE` " - "optimize_option.") - - quantizable_idx = [] - pass_cell = None - for i, cell_and_name in enumerate(network.cells_and_names()): - cell = cell_and_name[1] - if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)) and cell is not pass_cell: - quantizable_idx.append(i) - - if len(quantizable_idx) != len(strategy): - raise ValueError("The dimension of quantifiable layers is not consistent with that of strategy.") - - quantizable_layer_bit_dict = {idx: bit for idx, bit in zip(quantizable_idx, strategy)} - type_map = { - QuantDtype.INT2.num_bits: QuantDtype.INT2, - QuantDtype.INT3.num_bits: QuantDtype.INT3, - QuantDtype.INT4.num_bits: QuantDtype.INT4, - QuantDtype.INT5.num_bits: QuantDtype.INT5, - QuantDtype.INT6.num_bits: QuantDtype.INT6, - QuantDtype.INT7.num_bits: QuantDtype.INT7, - QuantDtype.INT8.num_bits: QuantDtype.INT8 - } - for i, cell_and_name in enumerate(network.cells_and_names()): - cell = cell_and_name[1] - if i not in quantizable_idx: - continue - else: - if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)): - cell.weight_dtype = type_map[quantizable_layer_bit_dict[i][0]] - if isinstance(cell, nn.Conv2dBnAct): - subcell_weight_para = cell.conv.weight.data.asnumpy() - if hasattr(cell.conv, 'gamma'): - scale_factor = (cell.conv.gamma.data.asnumpy() / - np.sqrt(cell.conv.moving_variance.data.asnumpy() + self.eps)) - subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) - min_init, max_init = self._kl_init(subcell_weight_para, cell.weight_dtype) - cell.conv.fake_quant_weight.reset(quant_dtype=cell.weight_dtype, - min_init=min_init, - max_init=max_init) - elif isinstance(cell, nn.DenseBnAct): - subcell_weight_para = cell.dense.weight.data.asnumpy() - if hasattr(cell.dense, 'gamma'): - scale_factor = (cell.dense.gamma.data.asnumpy() / - np.sqrt(cell.dense.moving_variance.data.asnumpy() + self.eps)) - subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) - min_init, max_init = self._kl_init(subcell_weight_para, cell.weight_dtype) - cell.dense.fake_quant_weight.reset(quant_dtype=cell.weight_dtype, - min_init=min_init, - max_init=max_init) - return network diff --git a/mindspore/python/mindspore/compression/quant/quant_utils.py 
b/mindspore/python/mindspore/compression/quant/quant_utils.py deleted file mode 100644 index 9406f3f164b..00000000000 --- a/mindspore/python/mindspore/compression/quant/quant_utils.py +++ /dev/null @@ -1,440 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Quantization utils.""" - -import numpy as np -from mindspore._checkparam import Validator -from ... import nn - -__all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers"] - - -def cal_quantization_params(input_min, - input_max, - quant_min, - quant_max, - data_type, - symmetric=False): - r""" - Calculate quantization params for scale and zero point. - - Args: - input_min (numpy.ndarray): The dimension of channel or 1. - input_max (numpy.ndarray): The dimension of channel or 1. - quant_min (int): The minimum quantization integer. - quant_max (int): The maximum quantization integer. - data_type (numpy type) : Can be numpy int8, numpy uint8. - symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False. - - Returns: - scale (numpy.ndarray): quantization param. - zero point (numpy.ndarray): quantization param. - """ - input_max = np.maximum(0.0, input_max) - input_min = np.minimum(0.0, input_min) - - if input_min.shape != input_max.shape: - raise ValueError("input min shape should be equal to input max.") - if len(input_min.shape) > 1: - raise ValueError("input min and max shape should be one dim.") - if (input_min > input_max).all(): - raise ValueError("input_min min should be less than input max.") - if (input_max == input_min).all(): - return np.ones(input_min.shape), np.zeros(input_min.shape) - - # calculate scale - if symmetric: - input_max = np.maximum(-input_min, input_max) - input_min = -input_max - scale = (input_max - input_min) / (quant_max - quant_min) - - # calculate zero point - if data_type == np.int8 and symmetric: - zp = np.zeros(input_min.shape) - else: - zp_double = quant_min - input_min / scale - zp = np.floor(zp_double + 0.5) - - return scale, zp - - -def get_quant_min_max(data_type, num_bits=8, narrow_range=False): - """Calculate quantization params for minimum/maximum quantization integer""" - if data_type == np.int8: - quant_min = 0 - 2 ** (num_bits - 1) - quant_max = 2 ** (num_bits - 1) - 1 - elif data_type == np.uint8: - quant_min = 0 - quant_max = 2 ** num_bits - 1 - else: - raise ValueError("Unsupported datatype({})".format(data_type)) - if narrow_range: - quant_min = quant_min + 1 - return quant_min, quant_max - - -def weight2int(data, scale, zero_point, quant_min, quant_max): - r""" - Calculate int8/uint8 weight from fp32. the formula is defined as: - - .. math:: - int8/uint8 = round(float/scale) + offset - - Args: - data (numpy.ndarray): The dimension of channel or 1. Should be NCHW. - scale (numpy.ndarray): The dimension of channel or 1. - zero_point (numpy.ndarray): The dimension of channel or 1. 
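A worked instance of the formulas implemented in `cal_quantization_params` above, for asymmetric uint8 quantization of the range [-1.0, 3.0]:

    # quant_min, quant_max = 0, 255  (uint8, narrow_range=False)
    # scale = (3.0 - (-1.0)) / (255 - 0) = 4 / 255 ≈ 0.015686
    # zp = floor(quant_min - input_min / scale + 0.5)
    #    = floor(0 - (-1.0) / 0.015686 + 0.5) = floor(64.25) = 64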
- quant_min (int): The minimum quantization integer. - quant_max (int): The maximum quantization integer. - - Returns: - weight (numpy.ndarray): The dimension of channel or 1. - """ - if scale.shape != zero_point.shape: - raise ValueError("`scale` and `zero_point` should have the same shape.") - if scale.shape[0] < 0: - raise ValueError("`scale` and `zero_point` shape should be greater than zero.") - if len(scale.shape) >= 1 and scale.shape[0] > 1: - # for perchannel - if scale.shape[0] == data.shape[0]: - # `Conv2d` or `Dense` op weight - shape_list = [-1] + [1] * len(data.shape[1:]) - scale = scale.reshape(shape_list) - zero_point = zero_point.reshape(shape_list) - elif scale.shape[0] == data.shape[1]: - # `DepthwiseConv2d` op weight - shape_list = [1, -1] + [1] * len(data.shape[2:]) - scale = scale.reshape(shape_list) - zero_point = zero_point.reshape(shape_list) - else: - raise ValueError("Unsupported weight shape({})".format(data.shape)) - - weight_int = np.round((data / scale) + zero_point) - weight_int[weight_int > quant_max] = quant_max - weight_int[weight_int < quant_min] = quant_min - return weight_int - - -def scale_zp_max_min_from_fake_quant_cell(cell, data_type): - """Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMaxObserver`.""" - minq = cell.minq.data.asnumpy() - maxq = cell.maxq.data.asnumpy() - # make sure maxq > 0 and minq <= 0 - if cell.mode == 'LEARNED_SCALE': - maxq = np.abs(maxq) - minq = -np.abs(minq) - quant_min, quant_max = get_quant_min_max(data_type, num_bits=cell.num_bits, narrow_range=cell.narrow_range) - symmetric = cell.symmetric and not cell.neg_trunc - scale, zp = cal_quantization_params( - minq, maxq, - quant_min, quant_max, data_type, - symmetric=symmetric) - return scale, zp, maxq, minq - - -def fold_batchnorm(weight, cell_quant): - r""" - Fold the batchnorm in `Conv2dBnFoldQuant` to weight. - - Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive. - - Args: - weight (numpy.ndarray): Weight of `cell_quant`. - cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`. - - Returns: - weight (numpy.ndarray): Folded weight. - bias (numpy.ndarray): Folded bias. - """ - variance = cell_quant.moving_variance.data.asnumpy() - mean = cell_quant.moving_mean.data.asnumpy() - gamma = cell_quant.gamma.data.asnumpy() - beta = cell_quant.beta.data.asnumpy() - epsilon = cell_quant.eps - sigma = np.sqrt(variance + epsilon) - - if gamma.shape[0] == weight.shape[0]: - # `Conv2d` or `Dense` op weight - shape_list = [-1] + [1] * len(weight.shape[1:]) - _gamma = gamma.reshape(shape_list) - _sigma = sigma.reshape(shape_list) - elif gamma.shape[0] == weight.shape[1]: - # `DepthwiseConv2d` op weight - shape_list = [1, -1] + [1] * len(weight.shape[2:]) - _gamma = gamma.reshape(shape_list) - _sigma = sigma.reshape(shape_list) - else: - raise ValueError("Unsupported weight shape({})".format(weight.shape)) - - weight = weight * _gamma / _sigma - bias = beta - gamma * mean / sigma - return weight, bias - - -def without_fold_batchnorm(weight, cell_quant): - r""" - Fold the batchnorm in `Conv2dBnWithoutFoldQuant` to weight. - - Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive. - - Args: - weight (numpy.ndarray): Weight of `cell_quant`. - cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`. - - Returns: - weight (numpy.ndarray): whihout folded weight. - bias (numpy.ndarray): without folded bias. 
- """ - variance = cell_quant.batchnorm.moving_variance.data.asnumpy() - mean = cell_quant.batchnorm.moving_mean.data.asnumpy() - gamma = cell_quant.batchnorm.gamma.data.asnumpy() - beta = cell_quant.batchnorm.beta.data.asnumpy() - epsilon = cell_quant.batchnorm.eps - sigma = np.sqrt(variance + epsilon) - - if gamma.shape[0] == weight.shape[0]: - # `Conv2d` or `Dense` op weight - shape_list = [-1] + [1] * len(weight.shape[1:]) - _gamma = gamma.reshape(shape_list) - _sigma = sigma.reshape(shape_list) - elif gamma.shape[0] == weight.shape[1]: - # `DepthwiseConv2d` op weight - shape_list = [1, -1] + [1] * len(weight.shape[2:]) - _gamma = gamma.reshape(shape_list) - _sigma = sigma.reshape(shape_list) - else: - raise ValueError("Unsupported weight shape({})".format(weight.shape)) - - weight = weight * _gamma / _sigma - bias = beta - gamma * mean / sigma - return weight, bias - - -def compute_kl_threshold(data, bitwidth): - r""" - Using KL-J Distance to calculate the clip threshold. - - Args: - - **data** (NumpyArray) - Data observed to calculate the threshold for quantization, - - **bitwidth** (QuantDtype) - The datatype of quantization. - Outputs: - Tensor with Shape 1. Threshold to calculate the data. - """ - data_max = np.abs(data).max() - if data_max < 1e-5: - return 1e-5 - hist, bin_edges = np.histogram(np.abs(data), bins='sqrt', range=(0, data_max), density=True) - # For the sake of high efficiency, we limit the maximum number of bins to 1024 in `sqrt` mode, If it exceeds the - # largest size, turn to use the default bins config. - largest_bin_size = 1024 - if hist.shape[0] > largest_bin_size: - hist, bin_edges = np.histogram(np.abs(data), range=(0, data_max), density=True) - hist = hist / np.sum(hist) - cumsum = np.cumsum(hist) - bit_pow_range = pow(2, int(bitwidth.num_bits) - 1) - threshold = [] - scaling_factor = [] - kl = [] - if bit_pow_range + 1 > len(bin_edges) - 1: - th_layer_out = bin_edges[-1] - return float(th_layer_out) - for i in range(bit_pow_range + 1, len(bin_edges), 1): - threshold_tmp = (i + 0.5) * (bin_edges[1] - bin_edges[0]) - threshold = np.concatenate((threshold, [threshold_tmp])) - scaling_factor_tmp = threshold_tmp / (bit_pow_range - 1) - scaling_factor = np.concatenate((scaling_factor, [scaling_factor_tmp])) - # forward interpolation - cumsum_tmp = np.copy(cumsum) - cumsum_tmp[(i - 1):] = 1 - fwd_x = np.linspace(0.0, 1.0, bit_pow_range) - fwd_xp = np.linspace(0.0, 1.0, i) - fwd_fp = cumsum_tmp[:i] - forward_interp = np.interp(fwd_x, fwd_xp, fwd_fp) - # backward interpolation - bwd_x = np.linspace(0.0, 1.0, i) - bwd_xp = np.linspace(0.0, 1.0, bit_pow_range) - bwd_fp = forward_interp - backward_interp = np.interp(bwd_x, bwd_xp, bwd_fp) - cumsum_tmp[:i] = backward_interp - kl_tmp = np.sum((cumsum - cumsum_tmp) * np.log2(cumsum / cumsum_tmp)) # Kullback-Leibler-J - kl = np.concatenate((kl, [kl_tmp])) - th_layer_out = threshold[np.argmin(kl)] - threshold = float(th_layer_out) - if threshold < 1e-5: - threshold = 1e-5 - return threshold - - -def query_quant_layers(network): - r""" - Query the network's quantization strategy of each quantized layer and print it to the screen, note that all the - quantization layers are queried before graph compile optimization in the graph mode, thus, some redundant quantized - layers, which not exist in practical execution, may appear. 
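The batchnorm-folding arithmetic shared by `fold_batchnorm` and `without_fold_batchnorm` above reduces to a few numpy operations; a self-contained sketch for an out-channels-first conv weight:

    import numpy as np

    def fold_bn(weight, gamma, beta, mean, variance, eps=1e-5):
        sigma = np.sqrt(variance + eps)
        shape = [-1] + [1] * (weight.ndim - 1)  # broadcast per out-channel
        folded_weight = weight * (gamma / sigma).reshape(shape)
        folded_bias = beta - gamma * mean / sigma
        return folded_weight, folded_bias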
- - Args: - network (Cell): input network - - Examples: - >>> from mindspore.compression.quant import QuantizationAwareTraining - >>> from mindspore.compression.quant.quant_utils import query_quant_layers - >>> class LeNet5(nn.Cell): - ... def __init__(self, num_class=10, channel=1): - ... super(LeNet5, self).__init__() - ... self.type = "fusion" - ... self.num_class = num_class - ... - ... # change `nn.Conv2d` to `nn.Conv2dBnAct` - ... self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') - ... self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') - ... # change `nn.Dense` to `nn.DenseBnAct` - ... self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') - ... self.fc2 = nn.DenseBnAct(120, 84, activation='relu') - ... self.fc3 = nn.DenseBnAct(84, self.num_class) - ... - ... self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - ... self.flatten = nn.Flatten() - ... - ... def construct(self, x): - ... x = self.conv1(x) - ... x = self.max_pool2d(x) - ... x = self.conv2(x) - ... x = self.max_pool2d(x) - ... x = self.flatten(x) - ... x = self.fc1(x) - ... x = self.fc2(x) - ... x = self.fc3(x) - ... return x - ... - >>> net = LeNet5() - >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False]) - >>> net_qat = quantizer.quantize(net) - >>> query_quant_layers(net_qat) - conv1.conv.fake_quant_weight INT8 - conv1.activation.fake_quant_act INT8 - conv2.conv.fake_quant_weight INT8 - conv2.activation.fake_quant_act INT8 - fc1.dense.fake_quant_weight INT8 - fc1.activation.fake_quant_act INT8 - fc2.dense.fake_quant_weight INT8 - fc2.activation.fake_quant_act INT8 - fc3.dense.fake_quant_weight INT8 - fc3.activation.fake_quant_act INT8 - """ - network = Validator.check_isinstance("network", network, nn.Cell) - tplt = "{0:60}\t{1:10}" - for cell_and_name in network.cells_and_names(): - cell_name = cell_and_name[0] - cell = cell_and_name[1] - if isinstance(cell, nn.FakeQuantWithMinMaxObserver): - print(tplt.format(cell_name, cell.quant_dtype)) - - -def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None): - r""" - Load fp32 model parameters into quantization model. - - Args: - quant_model(Cell): Quantization model. - params_dict(dict): Parameter dict that stores fp32 parameters. - quant_new_params(list): Parameters that exist in quantization network but not in non-quantization - network. Default: None. - - Raises: - TypeError: If `quant_new_params` is not None and is not list. - ValueError: If there are parameters in the `quant_model` that are neither in `params_dict` - nor in `quant_new_params`. - - Examples: - >>> from mindspore import load_checkpoint - >>> from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net - >>> class LeNet5(nn.Cell): - ... def __init__(self, num_class=10, channel=1): - ... super(LeNet5, self).__init__() - ... self.type = "fusion" - ... self.num_class = num_class - ... - ... # change `nn.Conv2d` to `nn.Conv2dBnAct` - ... self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') - ... self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') - ... # change `nn.Dense` to `nn.DenseBnAct` - ... self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') - ... self.fc2 = nn.DenseBnAct(120, 84, activation='relu') - ... self.fc3 = nn.DenseBnAct(84, self.num_class) - ... - ... self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - ... self.flatten = nn.Flatten() - ... - ... 
def construct(self, x): - ... x = self.conv1(x) - ... x = self.max_pool2d(x) - ... x = self.conv2(x) - ... x = self.max_pool2d(x) - ... x = self.flatten(x) - ... x = self.fc1(x) - ... x = self.fc2(x) - ... x = self.fc3(x) - ... return x - ... - >>> net = LeNet5() - >>> ckpt_file_name = "./checkpoint/LeNet5_noquant-1_32.ckpt" - >>> param_dict = load_checkpoint(ckpt_file_name) - >>> load_nonquant_param_into_quant_net(net, param_dict) - """ - if quant_new_params is not None and not isinstance(quant_new_params, list): - raise TypeError("quant_new_params must be list or None.") - iterable_dict = { - 'minq': iter(list(filter(lambda item: item[0].endswith('minq'), params_dict.items()))), - 'maxq': iter(list(filter(lambda item: item[0].endswith('maxq'), params_dict.items()))), - 'quant_max': iter(list(filter(lambda item: item[0].endswith('quant_max'), params_dict.items()))) - } - for param in params_dict.items(): - key_name = param[0].split(".")[-1] - if key_name not in iterable_dict: - iterable_dict[key_name] = iter(list(filter(lambda item, value=key_name: item[0].endswith(value), - params_dict.items()))) - - for name, param in quant_model.parameters_and_names(): - key_name = name.split(".")[-1] - if key_name not in iterable_dict.keys(): - if key_name not in quant_new_params: - raise ValueError(f"Can't find match parameter in ckpt, param name = {name}") - continue - value_param = next(iterable_dict[key_name], None) - if value_param: - param.set_data(value_param[1].data) - print(f'init model param {name} with checkpoint param {value_param[0]}') - - - # Perform KL_init when learned scale quantization is executed. - for cell_and_name in quant_model.cells_and_names(): - cell = cell_and_name[1] - if isinstance(cell, (nn.Conv2dBnFoldQuantOneConv, nn.Conv2dBnFoldQuant, nn.Conv2dBnWithoutFoldQuant, - nn.Conv2dQuant, nn.DenseQuant)) and cell.fake_quant_weight.mode == "LEARNED_SCALE": - subcell_weight_para = cell.weight.data.asnumpy() - if hasattr(cell, 'gamma'): - scale_factor = (cell.gamma.data.asnumpy() / - np.sqrt(cell.moving_variance.data.asnumpy() + 1e-5)) - subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) - - if cell.fake_quant_weight.per_channel: - max_init = [compute_kl_threshold(weight_para_each, cell.fake_quant_weight.quant_dtype) - for weight_para_each in subcell_weight_para] - min_init = [-x for x in max_init] - else: - max_init = [compute_kl_threshold(subcell_weight_para, cell.fake_quant_weight.quant_dtype)] - min_init = [-x for x in max_init] - - cell.fake_quant_weight.reset(quant_dtype=cell.fake_quant_weight.quant_dtype, - min_init=min_init, max_init=max_init) diff --git a/mindspore/python/mindspore/compression/quant/quantizer.py b/mindspore/python/mindspore/compression/quant/quantizer.py deleted file mode 100644 index 7d0bc1096c7..00000000000 --- a/mindspore/python/mindspore/compression/quant/quantizer.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
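The checkpoint-matching strategy in `load_nonquant_param_into_quant_net` above keys fp32 parameters by the trailing segment of their dotted name; a minimal sketch of that idea:

    def bucket_by_suffix(params_dict):
        # "conv1.conv.weight" and "fc1.dense.weight" both land under "weight"
        buckets = {}
        for name, value in params_dict.items():
            buckets.setdefault(name.split(".")[-1], []).append((name, value))
        return {suffix: iter(items) for suffix, items in buckets.items()}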
-# ============================================================================ -"""Base Class of Quantizer.""" - -from abc import ABC, abstractmethod -from enum import Enum - -from ..._checkparam import Validator - -__all__ = ["OptimizeOption"] - - -class OptimizeOption(Enum): - r""" - An enum for the model quantization optimize option, currently only support `QAT` and `LEARNED_SCALE`. - """ - # using quantization aware training - QAT = "QAT" - - # using the learned scale quantization - LEARNED_SCALE = "LEARNED_SCALE" - - def __str__(self): - return self.value - - -class Quantizer(ABC): - """ - Base class of Quantizer. You can implement different kind of quantizer to get different quantization result. - - Notes: - This class is an abstract class. - - Args: - optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options. Default: - OptimizeOption.QAT. - """ - def __init__(self, - optimize_option=OptimizeOption.QAT): - if not isinstance(optimize_option, list) and not isinstance(optimize_option, tuple): - optimize_option = [optimize_option] - for option in optimize_option: - option = Validator.check_isinstance("optimize_option", option, OptimizeOption) - self.optimize_option = optimize_option - - @abstractmethod - def quantize(self, network): - """ - Quant API to convert input network to a quantization aware training network - Args: - network (Cell): network to be quantized. - """ diff --git a/mindspore/python/mindspore/nn/layer/__init__.py b/mindspore/python/mindspore/nn/layer/__init__.py index 5a0836fe035..39363759c4d 100644 --- a/mindspore/python/mindspore/nn/layer/__init__.py +++ b/mindspore/python/mindspore/nn/layer/__init__.py @@ -17,7 +17,7 @@ Layer. The high-level components(Cells) used to construct the neural network. """ -from . import activation, normalization, container, conv, basic, embedding, pooling, image, quant, math, \ +from . 
import activation, normalization, container, conv, basic, embedding, pooling, image, math, \ combined, timedistributed, thor_layer, rnns, rnn_cells from .activation import * from .normalization import * @@ -29,7 +29,6 @@ from .basic import * from .embedding import * from .pooling import * from .image import * -from .quant import * from .math import * from .combined import * from .timedistributed import * @@ -46,7 +45,6 @@ __all__.extend(basic.__all__) __all__.extend(embedding.__all__) __all__.extend(pooling.__all__) __all__.extend(image.__all__) -__all__.extend(quant.__all__) __all__.extend(math.__all__) __all__.extend(combined.__all__) __all__.extend(timedistributed.__all__) diff --git a/mindspore/python/mindspore/train/serialization.py b/mindspore/python/mindspore/train/serialization.py index 38df0fc766e..02c2e3e9f87 100644 --- a/mindspore/python/mindspore/train/serialization.py +++ b/mindspore/python/mindspore/train/serialization.py @@ -45,7 +45,6 @@ from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor from mindspore.communication.management import get_rank, get_group_size -from mindspore.compression.export import quant_export from mindspore.parallel._cell_wrapper import get_allgather_cell from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index from mindspore.parallel._tensor import _reshape_param_data @@ -788,7 +787,6 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs): check_input_data(*inputs, data_class=Tensor) Validator.check_file_name_by_regular(file_name) file_name = os.path.realpath(file_name) - net = _quant_export(net, *inputs, file_format=file_format, **kwargs) if 'enc_key' in kwargs.keys(): if file_format != 'MINDIR': raise ValueError(f"For 'export', 'enc_key' can be passed in only when 'file_format' == 'MINDIR'," @@ -1077,47 +1075,6 @@ def quant_mode_manage(func): return warpper - -@quant_mode_manage -def _quant_export(network, *inputs, file_format, **kwargs): - """ - Exports MindSpore quantization predict model to deploy with AIR and MINDIR. 
- """ - supported_device = ["Ascend", "GPU"] - supported_formats = ['AIR', 'MINDIR'] - quant_mode_formats = ['QUANT', 'NONQUANT'] - - quant_mode = kwargs['quant_mode'] - if quant_mode not in quant_mode_formats: - raise KeyError(f"For 'export', the argument 'quant_mode' must be one of {quant_mode_formats}, " - f"but got {quant_mode}.") - if quant_mode == 'NONQUANT': - return network - quant_net = copy.deepcopy(network) - quant_net._create_time = int(time.time() * 1e9) - - mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean'] - std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev'] - mean = Validator.check_value_type("mean", mean, (int, float)) - std_dev = Validator.check_value_type("std_dev", std_dev, (int, float)) - - if context.get_context('device_target') not in supported_device: - raise KeyError(f"For 'export', quant export only support {supported_device} device target now, " - f"but got {context.get_context('device_target')}") - - if file_format not in supported_formats: - raise ValueError(f"For 'export', quant export only support 'file_format' {supported_formats}, " - f"but got {file_format}.") - - quant_net.set_train(False) - if file_format == "MINDIR": - exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True) - else: - exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs) - deploy_net = exporter.run() - return deploy_net - - def parse_print(print_file_name): """ Parse saved data generated by mindspore.ops.Print. Print is used to print data to screen in graph mode. -- Gitee From 9ed8a91237d636e9cdf5c3b6f4940be481ffc077 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Fri, 17 Dec 2021 09:49:27 +0800 Subject: [PATCH 03/34] add golden_stick & build --- cmake/package.cmake | 3 +- mindspore/python/mindspore/__init__.py | 4 + .../python/mindspore/golden_stick/__init__.py | 26 + .../golden_stick/example/__init__.py | 0 .../example/default_qat_example.py | 96 +++ .../golden_stick/example/pruner_example.py | 63 ++ .../example/simple_qat_example.py | 50 ++ .../mindspore/golden_stick/golden_stick.py | 66 ++ .../mindspore/golden_stick/legacy/__init__.py | 0 .../mindspore/golden_stick/legacy/constant.py | 118 ++++ .../golden_stick/legacy/export/__init__.py | 17 + .../legacy/export/quant_export.py | 500 ++++++++++++++ .../golden_stick/legacy/quant/__init__.py | 25 + .../golden_stick/legacy/quant/qat.py | 617 ++++++++++++++++++ .../golden_stick/legacy/quant/quant_utils.py | 439 +++++++++++++ .../golden_stick/legacy/quant/quantizer.py | 65 ++ .../mindspore/golden_stick/net_transform.py | 158 +++++ .../mindspore/golden_stick/pruner/__init__.py | 0 .../golden_stick/pruner/simple_pruner.py | 97 +++ .../golden_stick/quantization/__init__.py | 28 + .../quantization/default_qat/__init__.py | 25 + .../default_qat/default_layer_policy.py | 80 +++ .../default_qat/default_net_policy.py | 43 ++ .../default_qat/default_quantize.py | 30 + .../default_qat/default_quantizer.py | 68 ++ .../quantization/hello_qat/__init__.py | 0 .../quantization/hello_qat/simple_qat.py | 86 +++ .../golden_stick/quantization/layer_policy.py | 101 +++ .../golden_stick/quantization/net_policy.py | 48 ++ .../golden_stick/quantization/quantize.py | 126 ++++ .../quantization/quantize_wrapper_act.py | 45 ++ .../quantization/quantize_wrapper_cell.py | 105 +++ .../golden_stick/quantization/quantizer.py | 49 ++ .../golden_stick/quantization/test_common.py | 51 ++ .../golden_stick/quantization/transformer.py | 100 +++ 
.../quantization/transformer_test.py | 219 +++++++ .../python/mindspore/rewrite/__init__.py | 4 +- mindspore/python/mindspore/rewrite/node.py | 35 +- .../mindspore/rewrite/pattern_engine.py | 8 +- 39 files changed, 3575 insertions(+), 20 deletions(-) create mode 100644 mindspore/python/mindspore/golden_stick/__init__.py create mode 100644 mindspore/python/mindspore/golden_stick/example/__init__.py create mode 100644 mindspore/python/mindspore/golden_stick/example/default_qat_example.py create mode 100644 mindspore/python/mindspore/golden_stick/example/pruner_example.py create mode 100644 mindspore/python/mindspore/golden_stick/example/simple_qat_example.py create mode 100644 mindspore/python/mindspore/golden_stick/golden_stick.py create mode 100644 mindspore/python/mindspore/golden_stick/legacy/__init__.py create mode 100644 mindspore/python/mindspore/golden_stick/legacy/constant.py create mode 100644 mindspore/python/mindspore/golden_stick/legacy/export/__init__.py create mode 100644 mindspore/python/mindspore/golden_stick/legacy/export/quant_export.py create mode 100644 mindspore/python/mindspore/golden_stick/legacy/quant/__init__.py create mode 100644 mindspore/python/mindspore/golden_stick/legacy/quant/qat.py create mode 100644 mindspore/python/mindspore/golden_stick/legacy/quant/quant_utils.py create mode 100644 mindspore/python/mindspore/golden_stick/legacy/quant/quantizer.py create mode 100644 mindspore/python/mindspore/golden_stick/net_transform.py create mode 100644 mindspore/python/mindspore/golden_stick/pruner/__init__.py create mode 100644 mindspore/python/mindspore/golden_stick/pruner/simple_pruner.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/__init__.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/default_qat/__init__.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantize.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantizer.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/hello_qat/__init__.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/hello_qat/simple_qat.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/layer_policy.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/net_policy.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/quantize.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_act.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/quantizer.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/test_common.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/transformer.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/transformer_test.py diff --git a/cmake/package.cmake b/cmake/package.cmake index 4ce0862a6ff..c4c4d1aa320 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -319,7 +319,8 @@ install( ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/ops ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/communication 
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/profiler - ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/compression + ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite + ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/golden_stick ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check DESTINATION ${INSTALL_PY_DIR} COMPONENT mindspore diff --git a/mindspore/python/mindspore/__init__.py b/mindspore/python/mindspore/__init__.py index 535bdfec886..ad9abce3011 100755 --- a/mindspore/python/mindspore/__init__.py +++ b/mindspore/python/mindspore/__init__.py @@ -21,9 +21,11 @@ from .common import * from .mindrecord import * from .ops import _op_impl from .train import * +from .rewrite import * from .log import * from .context import * from .version import __version__ +from .golden_stick import * __all__ = ["run_check"] @@ -32,3 +34,5 @@ __all__.extend(common.__all__) __all__.extend(train.__all__) __all__.extend(log.__all__) __all__.extend(context.__all__) +__all__.extend(rewrite.__all__) +__all__.extend(golden_stick.__all__) diff --git a/mindspore/python/mindspore/golden_stick/__init__.py b/mindspore/python/mindspore/golden_stick/__init__.py new file mode 100644 index 00000000000..d568e043a05 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +MindSpore golden stick module. +""" + +from .golden_stick import GoldenStick +from .net_transform import NetTransformer +from .quantization import LayerPolicy, NetPolicy, QuantAwareTraining, Quantizer, Transformer, AllValueQuantizer, \ + LastValueQuantizer, LSQ, DefaultLayerPolicy, DefaultNetworkPolicy, DefaultQuantAwareTraining + +__all__ = ["GoldenStick", "NetTransformer", "LayerPolicy", "NetPolicy", "QuantAwareTraining", "Quantizer" + , "Transformer", "AllValueQuantizer", "LastValueQuantizer", "LSQ", "DefaultLayerPolicy", "DefaultNetworkPolicy", \ + "DefaultQuantAwareTraining"] diff --git a/mindspore/python/mindspore/golden_stick/example/__init__.py b/mindspore/python/mindspore/golden_stick/example/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mindspore/python/mindspore/golden_stick/example/default_qat_example.py b/mindspore/python/mindspore/golden_stick/example/default_qat_example.py new file mode 100644 index 00000000000..d233b7e9ad1 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/example/default_qat_example.py @@ -0,0 +1,96 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from typing import List, Tuple
+
+from ..quantization.default_qat.default_quantize import DefaultQuantAwareTraining
+from mindspore.nn import Cell, Conv2d, BatchNorm2d, Dense, ReLU, MaxPool2d, Flatten, SoftmaxCrossEntropyWithLogits, \
+    Momentum
+from mindspore.train.model import Model
+from ..quantization.transformer import Transformer
+from ..quantization.layer_policy import LayerPolicy
+from ..quantization.quantizer import Quantizer
+
+
+class LeNet5(Cell):
+    def __init__(self):
+        super(LeNet5, self).__init__()
+        self.conv1 = Conv2d(1, 6, 5)
+        self.fc1 = Dense(16 * 5 * 5, 120)
+        self.relu = ReLU()
+        self.max_pool2d = MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = Flatten()
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+        return x
+
+
+# custom quantizer
+class AllBitQuantizer(Quantizer):
+    """
+    Derived class of Quantizer. Use min and max value of data to compute scale and zero-point.
+    """
+
+    def __init__(self, bit_num=8):
+        super().__init__()
+        self._bit_num = bit_num
+
+    def compute(self, data: [float]) -> Tuple[List[float], float, int]:
+        min_val = data[0]
+        max_val = data[1]
+        scale = (1 << self._bit_num) / (max_val - min_val)
+        zp = max_val * scale
+        return data, scale, zp
+
+
+# custom layer policy
+class ConvBNQPolicy(LayerPolicy):
+    def __init__(self):
+        super().__init__()
+        self._quantizer = AllBitQuantizer()
+
+    def get_weight_name_and_quantizers(self) -> [(str, Quantizer)]:
+        # todo how to define weight inside of a subgraph
+        return [("_old_conv.weight", self._quantizer), ("_old_bn.gamma", self._quantizer),
+                ("_old_bn.beta", self._quantizer), ("_old_bn.moving_mean", self._quantizer),
+                ("_old_bn.moving_variance", self._quantizer)]
+
+    def get_act_name_and_quantizers(self) -> [(str, (Quantizer, Quantizer))]:
+        return []
+
+    def get_output_quantizers(self) -> [Quantizer]:
+        return [self._quantizer]
+
+
+net = LeNet5()
+
+custom_pattern_engine = Transformer([Conv2d, BatchNorm2d])
+custom_conv_bn_policy = ConvBNQPolicy()
+# custom net policy
+algo = DefaultQuantAwareTraining({"custom_transforms": [custom_pattern_engine],
+                                  "custom_policies": [custom_conv_bn_policy]})
+net_opt = algo.apply(net)
+
+loss = algo.loss(SoftmaxCrossEntropyWithLogits())
+optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+model = Model(net_opt, loss_fn=loss, optimizer=optimizer, metrics=None)
+dataset = {}
+model.train(2, dataset, algo.callback())
diff --git a/mindspore/python/mindspore/golden_stick/example/pruner_example.py b/mindspore/python/mindspore/golden_stick/example/pruner_example.py
new file mode 100644
index 00000000000..54e1cbe53f8
--- /dev/null
+++ b/mindspore/python/mindspore/golden_stick/example/pruner_example.py
@@ -0,0 +1,63 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from ..pruner.simple_pruner import PrunerCompressAlgo
+from mindspore.nn import Cell, Conv2d, Dense, ReLU, MaxPool2d, Flatten, SoftmaxCrossEntropyWithLogits, Momentum
+from mindspore.train.model import Model
+from mindspore.train.callback import Callback
+
+
+class LeNet5(Cell):
+    def __init__(self):
+        super(LeNet5, self).__init__()
+        self.conv1 = Conv2d(1, 6, 5)
+        self.fc1 = Dense(16 * 5 * 5, 120)
+        self.relu = ReLU()
+        self.max_pool2d = MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = Flatten()
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+        return x
+
+
+class ProfileCallBacks(Callback):
+    def __init__(self):
+        pass
+
+    def step_end(self, run_context):
+        origin_args = run_context.original_args()
+        cur_step_num = origin_args.cur_step_num
+        print("cur step: ", cur_step_num)
+
+
+net = LeNet5()
+
+algo = PrunerCompressAlgo({"begin_step": 1, "end_step": 100, "frequency": 1, "target_sparsity": 0.8})
+net_opt = algo.apply(net)
+
+loss = algo.loss(SoftmaxCrossEntropyWithLogits())
+optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+model = Model(net_opt, loss_fn=loss, optimizer=optimizer, metrics=None)
+dataset = {}
+profiling_cbs = ProfileCallBacks()
+callbacks = [algo.callback(), profiling_cbs]
+model.train(2, dataset, callbacks)
diff --git a/mindspore/python/mindspore/golden_stick/example/simple_qat_example.py b/mindspore/python/mindspore/golden_stick/example/simple_qat_example.py
new file mode 100644
index 00000000000..a99a6410a5d
--- /dev/null
+++ b/mindspore/python/mindspore/golden_stick/example/simple_qat_example.py
@@ -0,0 +1,50 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from ..quantization.hello_qat.simple_qat import QATCompressAlgo
+from mindspore.nn import Cell, Conv2d, Dense, ReLU, MaxPool2d, Flatten, SoftmaxCrossEntropyWithLogits, Momentum
+from mindspore.train.model import Model
+
+
+class LeNet5(Cell):
+    def __init__(self):
+        super(LeNet5, self).__init__()
+        self.conv1 = Conv2d(1, 6, 5)
+        self.fc1 = Dense(16 * 5 * 5, 120)
+        self.relu = ReLU()
+        self.max_pool2d = MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = Flatten()
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+        return x
+
+
+net = LeNet5()
+
+algo = QATCompressAlgo({"bit_num": 8, "per_channel": True})
+net_opt = algo.apply(net)
+
+loss = SoftmaxCrossEntropyWithLogits()
+optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+model = Model(net_opt, loss_fn=loss, optimizer=optimizer, metrics=None)
+dataset = {}
+model.train(2, dataset, algo.callback())
diff --git a/mindspore/python/mindspore/golden_stick/golden_stick.py b/mindspore/python/mindspore/golden_stick/golden_stick.py
new file mode 100644
index 00000000000..3aa5e4f8e17
--- /dev/null
+++ b/mindspore/python/mindspore/golden_stick/golden_stick.py
@@ -0,0 +1,66 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""GoldenStick."""
+
+from mindspore.nn.cell import Cell
+from mindspore.train.callback import Callback
+
+
+class GoldenStick:
+    """
+    Base class of algorithms in GoldenStick.
+
+    Args:
+        config (Dict): User config for network compression. The config specification is defined by each derived class.
+    """
+
+    def __init__(self, config):
+        self._config = config
+
+    def apply(self, network: Cell) -> Cell:
+        """
+        Define how to compress the input `network`. This method must be overridden by all subclasses.
+
+        Args:
+            network (Cell): Network to be compressed.
+
+        Returns:
+            Cell, the compressed network.
+        """
+
+        return network
+
+    def callback(self) -> Callback:
+        """
+        Define the callback tasks to be done during training, e.g. for QAT.
+
+        Returns:
+            Callback, the callback instance of the algorithm.
+        """
+
+        return Callback()
+
+    def loss(self, loss_fn: callable) -> callable:
+        """
+        Define how to adjust the loss function for the algorithm. Subclasses need not override this method if the algorithm does not adjust the loss function.
+
+        Args:
+            loss_fn (callable): Original loss function.
+
+        Returns:
+            callable, the adjusted loss function.
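+
+        Examples:
+            A minimal usage sketch; ``DistillAlgo`` here is a hypothetical subclass that wraps
+            the original loss with an extra distillation term:
+
+            >>> algo = DistillAlgo({})
+            >>> loss_fn = algo.loss(SoftmaxCrossEntropyWithLogits())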
+ """ + + return loss_fn diff --git a/mindspore/python/mindspore/golden_stick/legacy/__init__.py b/mindspore/python/mindspore/golden_stick/legacy/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mindspore/python/mindspore/golden_stick/legacy/constant.py b/mindspore/python/mindspore/golden_stick/legacy/constant.py new file mode 100644 index 00000000000..6ef251226db --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/legacy/constant.py @@ -0,0 +1,118 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Constant module for compression""" +import enum +import re +from types import DynamicClassAttribute + +__all__ = ["QuantDtype"] + + +@enum.unique +class QuantDtype(enum.Enum): + """ + An enum for quant datatype, contains `INT2` ~ `INT8`, `UINT2` ~ `UINT8`. + """ + INT2 = "INT2" + INT3 = "INT3" + INT4 = "INT4" + INT5 = "INT5" + INT6 = "INT6" + INT7 = "INT7" + INT8 = "INT8" + + UINT2 = "UINT2" + UINT3 = "UINT3" + UINT4 = "UINT4" + UINT5 = "UINT5" + UINT6 = "UINT6" + UINT7 = "UINT7" + UINT8 = "UINT8" + + def __str__(self): + return f"{self.name}" + + @staticmethod + def is_signed(dtype): + """ + Get whether the quant datatype is signed. + + Args: + dtype (QuantDtype): quant datatype. + + Returns: + bool, whether the input quant datatype is signed. + + Examples: + >>> quant_dtype = QuantDtype.INT8 + >>> is_signed = QuantDtype.is_signed(quant_dtype) + """ + return dtype in [QuantDtype.INT2, QuantDtype.INT3, QuantDtype.INT4, QuantDtype.INT5, + QuantDtype.INT6, QuantDtype.INT7, QuantDtype.INT8] + + @staticmethod + def switch_signed(dtype): + """ + Switch the signed state of the input quant datatype. + + Args: + dtype (QuantDtype): quant datatype. + + Returns: + QuantDtype, quant datatype with opposite signed state as the input. + + Examples: + >>> quant_dtype = QuantDtype.INT8 + >>> quant_dtype = QuantDtype.switch_signed(quant_dtype) + """ + type_map = { + QuantDtype.INT2: QuantDtype.UINT2, + QuantDtype.INT3: QuantDtype.UINT3, + QuantDtype.INT4: QuantDtype.UINT4, + QuantDtype.INT5: QuantDtype.UINT5, + QuantDtype.INT6: QuantDtype.UINT6, + QuantDtype.INT7: QuantDtype.UINT7, + QuantDtype.INT8: QuantDtype.UINT8, + QuantDtype.UINT2: QuantDtype.INT2, + QuantDtype.UINT3: QuantDtype.INT3, + QuantDtype.UINT4: QuantDtype.INT4, + QuantDtype.UINT5: QuantDtype.INT5, + QuantDtype.UINT6: QuantDtype.INT6, + QuantDtype.UINT7: QuantDtype.INT7, + QuantDtype.UINT8: QuantDtype.INT8 + } + return type_map[dtype] + + @DynamicClassAttribute + def _value(self): + """The value of the Enum member.""" + return int(re.search(r"(\d+)", self._value_).group(1)) + + @DynamicClassAttribute + def num_bits(self): + """ + Get the num bits of the QuantDtype member. + + Returns: + int, the num bits of the QuantDtype member. 
+ + Examples: + >>> from mindspore.compression.common import QuantDtype + >>> quant_dtype = QuantDtype.INT8 + >>> num_bits = quant_dtype.num_bits + >>> print(num_bits) + 8 + """ + return self._value diff --git a/mindspore/python/mindspore/golden_stick/legacy/export/__init__.py b/mindspore/python/mindspore/golden_stick/legacy/export/__init__.py new file mode 100644 index 00000000000..48e59baa71a --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/legacy/export/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Compression export module. +""" diff --git a/mindspore/python/mindspore/golden_stick/legacy/export/quant_export.py b/mindspore/python/mindspore/golden_stick/legacy/export/quant_export.py new file mode 100644 index 00000000000..54105365c2a --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/legacy/export/quant_export.py @@ -0,0 +1,500 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Export for quantization.""" + +import copy + +import numpy as np + +from ... import nn, ops +from ..._checkparam import Validator +from ...common import Tensor +from ...common import dtype as mstype +from ...common.api import _cell_graph_executor as _executor +from ...common.parameter import Parameter +from ...nn import Cell +from ...nn.layer import quant +from ...ops import operations as P +from ...ops import functional as F +from ...ops.operations import _inner_ops as inner +from ..quant import quant_utils +from ..quant.qat import _AddFakeQuantInput, _AddFakeQuantAfterSubCell + +__all__ = ["ExportToQuantInferNetwork"] + + +class QuantBlock(Cell): + r""" + A quant block of Conv/Dense, activation layer for Ascend deploy. + + Calculate Conv or Dense in Int8, with Quant and DeQuant. + + Notes: + This block is only for deploy, and not trainable. + + Args: + in_channels (int): The number of channels in the input space. + out_channels (int): The number of channels in the output space. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype + is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is + same as input x. 
The values of str refer to the function `initializer`. Default: 'zeros'.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
+        batchnorm (bool): Specifies whether to use batchnorm. Default: None.
+        activation (str): Specifies the activation type applied to the output of the layer.
+            The optional values are as follows:
+            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
+            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
+
+    Outputs:
+        Tensor of shape :math:`(N, out\_channels)`.
+    """
+
+    def __init__(self,
+                 core_op,
+                 weight,
+                 quant_op,
+                 dequant_op,
+                 dequant_scale,
+                 bias=None,
+                 activation=None):
+        super(QuantBlock, self).__init__()
+        self.core_op = core_op
+        self.weight = weight
+        self.quant = quant_op
+        self.dequant = dequant_op
+        self.dequant_scale = dequant_scale
+        self.bias = bias
+        self.has_bias = bias is not None
+        self.activation = activation
+        self.has_act = activation is not None
+        self.bias_add = P.BiasAdd()
+        self.sub = P.Sub()
+        self.weight_offset = Parameter(np.zeros(1, dtype=np.int8), name='weight_offset')
+
+    def construct(self, x):
+        x = self.quant(x)
+        if self.has_bias:
+            weight = self.sub(self.weight, self.weight_offset)
+            x = self.core_op(x, weight)
+            x = self.bias_add(x, self.bias)
+        else:
+            x = self.core_op(x, self.weight)
+        x = self.dequant(x, self.dequant_scale)
+        x = F.cast(x, mstype.float32)
+        if self.has_act:
+            x = self.activation(x)
+        return x
+
+    def extend_repr(self):
+        s = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
+        if self.has_bias:
+            s += f', bias=shape[{self.bias.shape}]'
+        if self.has_act:
+            s += f', activation={self.activation}'
+        s += f', dequant={self.dequant}'
+        return s
+
+
+class QuantMindirBlock(Cell):
+    """A quant block of Conv/Dense and activation layer for exporting a MINDIR model.
+
+    Args:
+        core_op (Cell): The operation cell.
+        weight (Tensor): The weight of the cell.
+        bias (Tensor): The bias of the cell. Default: None.
+        activation (str): The activation function applied to the output of the layer, e.g. 'relu'. Default: None.
+        param_dict (dict): The information of the cell.
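+
+    Inputs:
+        - **input** (Tensor) - The input tensor of the wrapped `core_op`.
+
+    Outputs:
+        Tensor, the output of `core_op`, with bias and activation applied when configured.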
+ """ + + def __init__(self, + core_op, + weight, + bias=None, + activation=None, + param_dict=None): + + super(QuantMindirBlock, self).__init__() + self.core_op = core_op + if activation is not None: + self.core_op.add_prim_attr("activation_name", activation.__class__.__name__) + self.core_op.add_prim_attr("filter_maxq", Tensor(param_dict["filter_maxq"])) + self.core_op.add_prim_attr("filter_minq", Tensor(param_dict["filter_minq"])) + if param_dict["output_maxq"] is not None: + self.core_op.add_prim_attr("output_maxq", Tensor(param_dict["output_maxq"])) + self.core_op.add_prim_attr("output_minq", Tensor(param_dict["output_minq"])) + self.core_op.add_prim_attr("symmetric", Tensor(param_dict["symmetric"])) + if hasattr(core_op, 'pad_mode'): + self.core_op.add_prim_attr("pad_mode", core_op.pad_mode) + self.core_op.add_prim_attr("act_num_bits", Tensor(8)) + self.core_op.add_prim_attr("weight_num_bits", Tensor(param_dict["weight_num_bits"])) + self.core_op.add_prim_attr("weight_narrow_range", Tensor(param_dict["weight_narrow_range"])) + if param_dict["input_narrow_range"] is not None: + self.core_op.add_prim_attr("input_narrow_range", Tensor(param_dict["input_narrow_range"])) + if param_dict["output_narrow_range"] is not None: + self.core_op.add_prim_attr("output_narrow_range", Tensor(param_dict["output_narrow_range"])) + if param_dict["input_maxq"] == 'None': + self.core_op.add_prim_attr("mean", Tensor(param_dict["mean"])) + self.core_op.add_prim_attr("std_dev", Tensor(param_dict["std_dev"])) + elif param_dict["input_maxq"] is not None: + self.core_op.add_prim_attr("input_maxq", Tensor(param_dict["input_maxq"])) + self.core_op.add_prim_attr("input_minq", Tensor(param_dict["input_minq"])) + + self.weight = weight + self.bias = bias + self.has_bias = bias is not None + self.activation = activation + self.has_act = activation is not None + self.bias_add = P.BiasAdd() + + def construct(self, x): + if self.has_bias: + x = self.core_op(x, self.weight) + x = self.bias_add(x, self.bias) + else: + x = self.core_op(x, self.weight) + if self.has_act: + x = self.activation(x) + return x + + def extend_repr(self): + s = f'core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]' + if self.has_bias: + s += f', bias=shape[{self.bias.shape}]' + if self.has_act: + s += f', activation={self.activation}' + return s + + +class ExportToQuantInferNetwork: + """ + Convert quantization aware network to infer network. + + Args: + network (Cell): MindSpore quantization aware training network. + inputs (Tensor): Input tensors of the `quantization aware training network`. + mean (int, float): The mean of input data after preprocessing, used for quantizing the first layer of network. + Default: 127.5. + std_dev (int, float): The variance of input data after preprocessing, used for quantizing the first layer + of network. Default: 127.5. + is_mindir (bool): Whether export MINDIR format. Default: False. + + Returns: + Cell, Infer network. 
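+
+    Examples:
+        A minimal sketch, assuming `qat_net` is a quantization aware trained network and
+        `inputs` are example input tensors of the network:
+
+        >>> exporter = ExportToQuantInferNetwork(qat_net, 127.5, 127.5, *inputs, is_mindir=True)
+        >>> infer_net = exporter.run()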
+ """ + + def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): + network = Validator.check_isinstance('network', network, (nn.Cell,)) + self.data_type = mstype.int8 + self.network = copy.deepcopy(network) + self.network_bk = copy.deepcopy(network) + self.get_inputs_table(inputs) + self.mean = mean + self.std_dev = std_dev + self.is_mindir = is_mindir + self.upcell = None + + def get_inputs_table(self, inputs): + """Get the input quantization parameters of quantization cell for quant export.""" + phase_name = 'export_quant' + graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False) + self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id) + + def run(self): + """Start to convert.""" + self.network.update_cell_prefix() + network = self.network + if isinstance(network, _AddFakeQuantInput): + network = network.network + network = self._convert_quant2deploy(network) + return network + + def _get_quant_block(self, cell_core, activation, fake_quant_a_out): + """convert network's quant subcell to deploy subcell""" + scale_a_in, zp_a_in, scale_w, zp_w, param_dict = self.__get_quant_param(cell_core, fake_quant_a_out) + + # Build the `Quant` `Dequant` op. + # Quant only support perlayer version. Need check here. + quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in)) + scale_deq = self.__get_dequant_scale(scale_a_in, scale_w) + dequant_op = inner.Dequant() + + if isinstance(activation, _AddFakeQuantAfterSubCell): + activation = activation.subcell + elif hasattr(activation, "get_origin"): + activation = activation.get_origin() + + # get op + if isinstance(cell_core, quant.DenseQuant): + op_core = P.MatMul() + else: + op_core = cell_core.conv + + # get the `weight` and `bias` + weight, bias, weight_b, bias_b = self.__get_weight_bias(cell_core, scale_a_in, scale_w, zp_w) + + if self.is_mindir: + block = QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict) + else: + block = QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation) + return block + + def _get_input_quant_param(self, minq_name, np_type, param_dict): + """get input quant parameter for quant block""" + fake_quant_a_in_prefix = minq_name[:-5] + cells = self.network_bk.cells_and_names() + for cell in cells: + if cell[0].endswith(fake_quant_a_in_prefix): + fake_quant_a_in = cell[1] + break + scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \ + quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type) + param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range + return scale_a_in, zp_a_in + + def __get_quant_param(self, cell_core, fake_quant_a_out): + """get parameter for quant block""" + w_minq_name = cell_core.fake_quant_weight.minq.name + w_maxq_name = cell_core.fake_quant_weight.maxq.name + np_type = mstype.dtype_to_nptype(self.data_type) + param_dict = dict() + param_dict["filter_maxq"] = None + param_dict["filter_minq"] = None + param_dict["output_maxq"] = None + param_dict["output_minq"] = None + param_dict["input_maxq"] = None + param_dict["input_minq"] = None + param_dict["input_narrow_range"] = None + param_dict["output_narrow_range"] = None + param_dict["weight_narrow_range"] = cell_core.fake_quant_weight.narrow_range + param_dict["mean"] = self.mean + param_dict["std_dev"] = self.std_dev + param_dict["symmetric"] = cell_core.fake_quant_weight.symmetric + param_dict["weight_num_bits"] = cell_core.fake_quant_weight.num_bits + + scale_w, zp_w, param_dict["filter_maxq"], 
param_dict["filter_minq"] = \ + quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type) + if fake_quant_a_out is not None: + _, _, param_dict["output_maxq"], param_dict["output_minq"] = \ + quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type) + param_dict["output_narrow_range"] = fake_quant_a_out.narrow_range + + info = self.quant_info_table.get(w_minq_name, None) + if not info: + info = self.quant_info_table.get(w_maxq_name, None) + if info: + _, minq_name = info + if minq_name == 'input': + scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \ + (1 / self.std_dev), round(self.mean), 'None', 'None' + else: + scale_a_in, zp_a_in = self._get_input_quant_param(minq_name, np_type, param_dict) + else: + # skip quant layer + scale_a_in, zp_a_in = 1.0, 0.0 + return scale_a_in, zp_a_in, scale_w, zp_w, param_dict + + @staticmethod + def __get_dequant_scale(scale_a_in, scale_w): + """Get dequant scale""" + scale_deq = scale_a_in * scale_w + + # fuse parameter + # |--------|47:40|--------|39:32|--------|31:0| + # offset_w [8] shift_N [8] deq_scale [32] + float32_deq_scale = scale_deq.astype(np.float32) + uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32) + scale_length = scale_deq.size # channel + dequant_param = np.zeros(scale_length, dtype=np.uint64) + for index in range(scale_length): + dequant_param[index] += uint32_deq_scale[index] + scale_deq = Tensor(dequant_param, mstype.uint64) + return scale_deq + + def __get_weight_bias(self, cell_core, scale_a_in, scale_w, zp_w): + """Get weight and bias for quantizaiton""" + np_type = mstype.dtype_to_nptype(self.data_type) + weight = cell_core.weight.data.asnumpy() + bias = None + if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)): + if cell_core.has_bias: + bias = cell_core.bias.data.asnumpy() + elif isinstance(cell_core, (quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv)): + weight, bias = quant_utils.fold_batchnorm(weight, cell_core) + elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant): + weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core) + weight_b = weight + bias_b = bias + # apply the quant + quant_min, quant_max = quant_utils.get_quant_min_max(np_type, + cell_core.fake_quant_weight.num_bits, + cell_core.fake_quant_weight.narrow_range) + weight = quant_utils.weight2int(weight, scale_w, zp_w, quant_min, quant_max) + if bias is not None: + bias = Tensor(bias / scale_a_in / scale_w, mstype.int32) + + if isinstance(cell_core, quant.DenseQuant): + weight = np.transpose(weight) + weight_b = np.transpose(weight_b) + + weight = Tensor(weight, self.data_type) + weight_b = Tensor(weight_b) + if bias_b is not None: + bias_b = Tensor(bias_b, mstype.float32) + return weight, bias, weight_b, bias_b + + def _add_output_min_max_for_op(self, origin_op, fake_quant_cell): + """add output quant info for quant op for export mindir.""" + if self.is_mindir: + if isinstance(origin_op, ops.Primitive) and not hasattr(origin_op, 'output_minq'): + np_type = mstype.dtype_to_nptype(self.data_type) + _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_cell, np_type) + origin_op.add_prim_attr('output_maxq', Tensor(maxq)) + origin_op.add_prim_attr('output_minq', Tensor(minq)) + + def _convert_subcell(self, network, change, name, subcell): + """Convert subcell to ant subcell.""" + if subcell is not None and hasattr(subcell, "fake_quant_weight"): + new_subcell = self._get_quant_block(subcell, None, None) + prefix = 
subcell.param_prefix + new_subcell.update_parameters_name(prefix + '.') + self.upcell = new_subcell + network.insert_child_to_cell(name, new_subcell) + change = True + return network, change + + def _convert_conv(self, network, change, name, subcell): + """Convert subcell to ant subcell for conv.""" + cell_core = subcell.conv + activation = subcell.activation + fake_quant_act = None + if hasattr(activation, 'fake_quant_act_before'): + fake_quant_act = activation.fake_quant_act_before + elif hasattr(activation, 'fake_quant_act'): + fake_quant_act = activation.fake_quant_act + if cell_core is not None and hasattr(cell_core, "fake_quant_weight"): + new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act) + self.upcell = None + prefix = subcell.param_prefix + new_subcell.update_parameters_name(prefix + '.') + network.insert_child_to_cell(name, new_subcell) + change = True + return network, change + + def _convert_dense(self, network, change, name, subcell): + """Convert subcell to ant subcell for dense.""" + cell_core = subcell.dense + activation = subcell.activation + fake_quant_act = None + if hasattr(activation, 'fake_quant_act_before'): + fake_quant_act = activation.fake_quant_act_before + elif hasattr(activation, 'fake_quant_act'): + fake_quant_act = activation.fake_quant_act + if cell_core is not None and hasattr(cell_core, "fake_quant_weight"): + new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act) + prefix = subcell.param_prefix + new_subcell.update_parameters_name(prefix + '.') + network.insert_child_to_cell(name, new_subcell) + self.upcell = None + change = True + return network, change + + def _convert_act(self, subcell): + """Convert subcell to ant subcell for activation.""" + activation = subcell.get_origin() + if isinstance(activation, nn.ReLU): + self._add_output_min_max_for_op(activation.relu, subcell.fake_quant_act) + elif isinstance(activation, nn.ReLU6): + self._add_output_min_max_for_op(activation.relu6, subcell.fake_quant_act) + if self.upcell: + self._add_output_min_max_for_op(self.upcell.core_op, subcell.fake_quant_act) + return activation + + def _convert_add(self, subcell): + """Convert subcell to ant subcell for add.""" + if isinstance(subcell.add, _AddFakeQuantAfterSubCell): + add_op = subcell.add.subcell + subcell.__delattr__("add") + subcell.__setattr__("add", add_op) + add_op = subcell.add + self._add_output_min_max_for_op(add_op, subcell.fake_quant_act) + subcell.__delattr__("fake_quant_act") + subcell.__setattr__("fake_quant_act", P.identity()) + + def _convert_observer(self, network, name, subcell): + """Convert subcell to ant subcell for FakeQuantWithMinMaxObserver.""" + if self.upcell: + self._add_output_min_max_for_op(self.upcell.core_op, subcell) + network.__delattr__(name) + network.__setattr__(name, P.identity()) + + def _convert_fake_quant_after_cell(self, network, name, subcell): + """Convert subcell to ant subcell for _AddFakeQuantAfterSubCell.""" + op = subcell.subcell + self._add_output_min_max_for_op(op, subcell.fake_quant_act) + network.__delattr__(name) + network.__setattr__(name, op) + + def _convert_core_quant_subcell(self, network, change, name, subcell): + """Convert subcell to ant subcell for conv and dense.""" + is_core_subcell = True + if isinstance(subcell, nn.Conv2dBnAct): + network, change = self._convert_conv(network, change, name, subcell) + elif isinstance(subcell, nn.DenseBnAct): + network, change = self._convert_dense(network, change, name, subcell) + elif isinstance(subcell, 
(quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv, + quant.Conv2dBnWithoutFoldQuant, quant.Conv2dQuant, quant.DenseQuant)): + network, change = self._convert_subcell(network, change, name, subcell) + else: + is_core_subcell = False + return is_core_subcell, network, change + + def _convert_other_quant_subcell(self, network, change, name, subcell): + """Convert subcell to ant subcell for cell except conv and dense.""" + is_other_subcell = True + if isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"): + activation = self._convert_act(subcell) + network.insert_child_to_cell(name, activation) + change = True + elif isinstance(subcell, nn.TensorAddQuant): + self._convert_add(subcell) + elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver): + self._convert_observer(network, name, subcell) + elif isinstance(subcell, _AddFakeQuantAfterSubCell): + self._convert_fake_quant_after_cell(network, name, subcell) + change = True + else: + is_other_subcell = False + return is_other_subcell, network, change + + def _convert_quant2deploy(self, network): + """Convert network's all quant subcell to deploy subcell.""" + cells = network.name_cells() + change = False + for name in cells: + subcell = cells[name] + if subcell == network: + continue + is_core_quant_subcell, network, change = self._convert_core_quant_subcell(network, change, name, subcell) + is_other_quant_subcell, network, change = self._convert_other_quant_subcell(network, change, name, subcell) + if not is_core_quant_subcell and not is_other_quant_subcell: + self.upcell = None + self._convert_quant2deploy(subcell) + if isinstance(network, nn.SequentialCell) and change: + network.cell_list = list(network.cells()) + return network diff --git a/mindspore/python/mindspore/golden_stick/legacy/quant/__init__.py b/mindspore/python/mindspore/golden_stick/legacy/quant/__init__.py new file mode 100644 index 00000000000..e2b8cf0f83d --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/legacy/quant/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Quantization module, including base class of the quantizer, the quantization aware training algorithm, +and quantization utils. 
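+
+A typical flow is to construct a `QuantizationAwareTraining` quantizer and call its
+`quantize` method on a network; see the example in the `QuantizationAwareTraining` docstring.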
+""" + +from .quantizer import OptimizeOption +from .qat import QuantizationAwareTraining, create_quant_config +from .quant_utils import load_nonquant_param_into_quant_net, query_quant_layers + +__all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers", "QuantizationAwareTraining", + "create_quant_config", "OptimizeOption"] diff --git a/mindspore/python/mindspore/golden_stick/legacy/quant/qat.py b/mindspore/python/mindspore/golden_stick/legacy/quant/qat.py new file mode 100644 index 00000000000..5c8e2ecfec9 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/legacy/quant/qat.py @@ -0,0 +1,617 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Quantization aware training + +User can use quantization aware to train a model. MindSpore supports quantization aware training, +which models quantization errors in both the forward and backward passes using fake-quantization +operations. Note that the entire computation is carried out in floating point. At the end of quantization +aware training, MindSpore provides conversion functions to convert the trained model into lower precision. +""" + +import re +import mindspore.context as context +import numpy as np +from ... import nn, ops +from ..._checkparam import Validator, Rel +from ...nn.layer import quant +from ...ops import functional as F +from ..common import QuantDtype +from .quantizer import Quantizer, OptimizeOption +from .quant_utils import compute_kl_threshold + +__all__ = ["QuantizationAwareTraining", "create_quant_config"] + + +def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver), + quant_delay=(0, 0), + quant_dtype=(QuantDtype.INT8, QuantDtype.INT8), + per_channel=(False, False), + symmetric=(False, False), + narrow_range=(False, False), + mode="DEFAULT"): + r""" + Config the observer type of weights and data flow with quant parameters. + + Args: + quant_observer (Union[Observer, list, tuple]): The types of observer for quantization. The first element + applies to weights and the second applies to data flow. Currently, only + :class:`FakeQuantWithMinMaxObserver` supported. + Default: (nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver). + quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized + during train and eval. The first element represents weights and the second element represents data flow. + Default: (0, 0). + quant_dtype (Union[QuantDtype, list, tuple]): Datatype used to quantize weights and activations. The first + element represents weights and the second element represents data flow. + Default: (QuantDtype.INT8, QuantDtype.INT8). + per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True` + then base on per channel, otherwise base on per layer. 
The first element represents weights + and the second element represents data flow, and the second element must be `False` now. + Default: (False, False). + symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric or not. If `True` then + base on symmetric, otherwise base on asymmetric. The first element represents weights and the second + element represents data flow. Default: (False, False). + narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range or not. + The first element represents weights and the second element represents data flow. + Default: (False, False). + mode (str): Optional quantization mode, currently only `DEFAULT`(QAT) and `LEARNED_SCALE` are supported. + Default: ("DEFAULT"). + + Returns: + QuantConfig, contains the observer type of weight and activation. + + Raises: + ValueError: If the second element of `per_channel` is not `False`. + """ + if per_channel[-1]: + raise ValueError("Arg 'per_channel' second element must be 'False'.") + weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0], + per_channel=per_channel[0], symmetric=symmetric[0], + narrow_range=narrow_range[0], mode=mode) + act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1], + per_channel=per_channel[-1], symmetric=symmetric[-1], + narrow_range=narrow_range[-1], mode=mode) + return quant.QuantConfig(weight=weight_observer, activation=act_observer) + + +class _AddFakeQuantInput(nn.Cell): + """ + Add FakeQuant OP at input of the network. Only support one input case. + """ + + def __init__(self, network, quant_delay=0): + super(_AddFakeQuantInput, self).__init__(auto_prefix=False) + self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6, + quant_delay=quant_delay, ema=True) + self.fake_quant_input.update_parameters_name('fake_quant_input.') + self.network = network + + def construct(self, data): + data = self.fake_quant_input(data) + output = self.network(data) + return output + + +class _AddFakeQuantAfterSubCell(nn.Cell): + """ + Add FakeQuant OP after of the sub Cell. + """ + + def __init__(self, subcell, **kwargs): + super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False) + self.subcell = subcell + self.mode = "DEFAULT" + self.max_init = 6 + self.min_init = -6 + + if OptimizeOption.LEARNED_SCALE in kwargs["optimize_option"]: + self.mode = "LEARNED_SCALE" + self.max_init = 16 + self.min_init = -16 + + self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=self.min_init, + max_init=self.max_init, + ema=True, + quant_dtype=kwargs["quant_dtype"], + quant_delay=kwargs["quant_delay"], + per_channel=kwargs["per_channel"], + symmetric=kwargs["symmetric"], + narrow_range=kwargs["narrow_range"], + mode=self.mode) + + def construct(self, *data): + output = self.subcell(*data) + output = self.fake_quant_act(output) + return output + + +class QuantizationAwareTraining(Quantizer): + r""" + Quantizer for quantization aware training. + + Args: + bn_fold (bool): Whether to use bn fold ops for simulation inference operation. Default: True. + freeze_bn (int): Number of steps after which BatchNorm OP parameters fixed to global mean and variance. + Default: 1e7. + quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized + during train and eval. The first element represents weights and the second element represents data flow. + Default: (0, 0). 
+ quant_dtype (Union[QuantDtype, list, tuple]): Datatype used to quantize weights and activations. The first + element represents weights and the second element represents data flow. It is necessary to consider the + precision support of hardware devices in the practical quantization infer scenario. + Default: (QuantDtype.INT8, QuantDtype.INT8). + per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True` + then base on per channel, otherwise base on per layer. The first element represents weights and the + second element represents data flow, and the second element must be `False` now. Default: (False, False). + symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric or not. If `True` then + base on symmetric, otherwise base on asymmetric. The first element represents weights and the second + element represents data flow. Default: (False, False). + narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range or not. + The first element represents weights and the second element represents data flow. + Default: (False, False). + optimize_option (Union[OptimizeOption, list, tuple]): Specifies the quant algorithm and options, currently + only support `QAT` and `LEARNED_SCALE` (Note that, if both `QAT` and `LEARNED_SCALE` are configured, + `LEARNED_SCALE` has a higher priority. `LEARNED_SCALE` currently only work under some constraints, which + includes: freeze_bn=0, quant_delay=0, symmetric=True, narrow_range=True, More specifically, for operators + such as Relu and Relu6, which only have positive values, we add a negative truncation to optimize this + scenario, and narrow_range will automatically match to False). Default: OptimizeOption.QAT. + one_conv_fold (bool): Whether to use one conv bn fold ops for simulation inference operation. Default: True. + + Raises: + TypeError: If the element of `quant_delay` or `freeze_bn` is not int. + TypeError: If `bn_fold`, `one_conv_fold` or the element of `per_channel`, `symmetric`, `narrow_range` + is not bool. + TypeError: If the element of `quant_dtype` is not `QuantDtype`. + ValueError: If the length of `quant_delay`, `quant_dtype`, `per_channel`, `symmetric` or `narrow_range` is + not less than 2. + ValueError: If the `optimize_option` is `LEARNED_SCALE` and `freeze_bn` is not equal to 0. + ValueError: If the `optimize_option` is `LEARNED_SCALE` and `symmetric` is not (True, True). + ValueError: If the `optimize_option` is `LEARNED_SCALE` and `narrow_range` is not (True, True). + ValueError: If the `optimize_option` is `LEARNED_SCALE` and `quant_delay` is not (0, 0). + + Examples: + >>> from mindspore.compression.quant import QuantizationAwareTraining + >>> class LeNet5(nn.Cell): + ... def __init__(self, num_class=10, channel=1): + ... super(LeNet5, self).__init__() + ... self.type = "fusion" + ... self.num_class = num_class + ... + ... # change `nn.Conv2d` to `nn.Conv2dBnAct` + ... self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') + ... self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') + ... # change `nn.Dense` to `nn.DenseBnAct` + ... self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') + ... self.fc2 = nn.DenseBnAct(120, 84, activation='relu') + ... self.fc3 = nn.DenseBnAct(84, self.num_class) + ... + ... self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + ... self.flatten = nn.Flatten() + ... + ... def construct(self, x): + ... x = self.conv1(x) + ... 
x = self.max_pool2d(x) + ... x = self.conv2(x) + ... x = self.max_pool2d(x) + ... x = self.flatten(x) + ... x = self.fc1(x) + ... x = self.fc2(x) + ... x = self.fc3(x) + ... return x + ... + >>> net = LeNet5() + >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False]) + >>> net_qat = quantizer.quantize(net) + """ + __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv", "ReduceMean"] + + def __init__(self, + bn_fold=True, + freeze_bn=10000000, + quant_delay=(0, 0), + quant_dtype=(QuantDtype.INT8, QuantDtype.INT8), + per_channel=(False, False), + symmetric=(False, False), + narrow_range=(False, False), + optimize_option=OptimizeOption.QAT, + one_conv_fold=True): + """Init for QuantizationAwareTraining quantizer""" + super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option) + + def convert2list(name, value): + if not isinstance(value, list) and not isinstance(value, tuple): + value = [value] + elif len(value) > 2: + raise ValueError("input `{}` len should less then 2".format(name)) + return value + + quant_delay = convert2list("quant delay", quant_delay) + quant_dtype = convert2list("quant dtype", quant_dtype) + per_channel = convert2list("per channel", per_channel) + symmetric = convert2list("symmetric", symmetric) + narrow_range = convert2list("narrow range", narrow_range) + + self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay") + self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay") + self.bn_fold = Validator.check_bool(bn_fold, "bn fold") + self.freeze_bn = Validator.check_non_negative_int(freeze_bn, "freeze bn") + self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype) + self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype) + self.weight_channel = Validator.check_bool(per_channel[0], "per channel") + self.act_channel = Validator.check_bool(per_channel[-1], "per channel") + self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric") + self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric") + self.weight_range = Validator.check_bool(narrow_range[0], "narrow range") + self.act_range = Validator.check_bool(narrow_range[-1], "narrow range") + self.one_conv_fold = Validator.check_bool(one_conv_fold, "one conv fold") + self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv, + nn.DenseBnAct: self._convert_dense} + self.mode = "DEFAULT" + if OptimizeOption.LEARNED_SCALE in self.optimize_option: + self.mode = "LEARNED_SCALE" + if not self.weight_symmetric or not self.act_symmetric: + raise ValueError("OptimizeOption.LEARNED_SCALE currently only support " + "symmetric=(True, True) for quant") + if not self.weight_range or not self.act_range: + raise ValueError("OptimizeOption.LEARNED_SCALE currently only support narrow_range=(True, True) " + "for quant") + if self.freeze_bn != 0: + raise ValueError("OptimizeOption.LEARNED_SCALE currently only support freeze_bn equal to 0, " + "but get freeze_bn={}".format(self.freeze_bn)) + if self.weight_qdelay != 0 or self.act_qdelay != 0: + raise ValueError("OptimizeOption.LEARNED_SCALE currently only support quant_delay=(0, 0)") + self.quant_config = create_quant_config(quant_delay=quant_delay, + quant_dtype=quant_dtype, + per_channel=per_channel, + symmetric=symmetric, + narrow_range=narrow_range, + mode=self.mode) + self.eps = 1e-5 + + @staticmethod + def _convert_op_name(name): + pattern = re.compile(r'([A-Z]{1})') + 
name_new = re.sub(pattern, r'_\1', name).lower() + if name_new[0] == '_': + name_new = name_new[1:] + return name_new + + def quantize(self, network): + """ + Quant API to convert input network to a quantization aware training network. + + Note: + Please refer to the Examples of class: `mindspore.compression.quant.QuantizationAwareTraining`. + + Args: + network (Cell): network to be quantized. + + Returns: + Cell, a quantization aware training network. + + Raises: + KeyError: If the `device_target` set in context is not in `support_device`. + """ + support_device = ["Ascend", "GPU"] + if context.get_context('device_target') not in support_device: + raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) + + if OptimizeOption.QAT in self.optimize_option or OptimizeOption.LEARNED_SCALE in self.optimize_option: + network.update_cell_prefix() + network = self._convert_subcells2quant(network) + network.update_cell_type("quant") + return network + + def _convert_subcells2quant(self, network): + """ + convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell + """ + cells = network.name_cells() + change = False + for name in cells: + subcell = cells[name] + if subcell == network: + continue + elif isinstance(subcell, (nn.Conv2dBnAct, nn.DenseBnAct)): + prefix = subcell.param_prefix + new_subcell = self._convert_method_map[type(subcell)](subcell) + new_subcell.update_parameters_name(prefix + '.') + network.insert_child_to_cell(name, new_subcell) + change = True + else: + self._convert_subcells2quant(subcell) + if isinstance(network, nn.SequentialCell) and change: + network.cell_list = list(network.cells()) + + # add FakeQuant OP after OP in white list, but not including those wrapped in the below quantization cell. 
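+        # The cells listed below already insert their own fake-quant logic, so return directly
+        # instead of wrapping them again.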
+ if isinstance(network, (nn.FakeQuantWithMinMaxObserver, + nn.Conv2dBnFoldQuantOneConv, + nn.Conv2dBnFoldQuant, + nn.Conv2dBnWithoutFoldQuant, + nn.Conv2dQuant, + nn.DenseQuant, + nn.ActQuant, + nn.TensorAddQuant, + nn.MulQuant)): + return network + + add_list = [] + for name in network.__dict__: + if name[0] == '_': + continue + attr = network.__dict__[name] + if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__: + add_list.append((name, attr)) + for name, prim_op in add_list: + prefix = name + add_quant = _AddFakeQuantAfterSubCell(prim_op, + quant_dtype=self.act_dtype, + quant_delay=self.act_qdelay, + per_channel=self.act_channel, + symmetric=self.act_symmetric, + narrow_range=self.act_range, + optimize_option=self.optimize_option) + if network.param_prefix: + prefix = '.'.join([network.param_prefix, prefix]) + add_quant.update_parameters_name(prefix + '.') + del network.__dict__[name] + network.insert_child_to_cell(name, add_quant) + return network + + def _convert_conv(self, subcell): + """ + convert Conv2d cell to quant cell + """ + min_init = -6 + max_init = 6 + if OptimizeOption.LEARNED_SCALE in self.optimize_option: + subcell_weight_para = subcell.conv.weight.data.asnumpy() + if subcell.has_bn: + scale_factor = (subcell.batchnorm.gamma.data.asnumpy() / + np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps)) + subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) + min_init, max_init = self._kl_init(subcell_weight_para, self.weight_dtype) + self.quant_config = self.quant_config._replace( + weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init)) + + conv_inner = subcell.conv + if subcell.has_bn: + bn_inner = subcell.batchnorm + if self.bn_fold: + if self.one_conv_fold: + conv_inner = quant.Conv2dBnFoldQuantOneConv(conv_inner.in_channels, + conv_inner.out_channels, + kernel_size=conv_inner.kernel_size, + stride=conv_inner.stride, + pad_mode=conv_inner.pad_mode, + padding=conv_inner.padding, + dilation=conv_inner.dilation, + group=conv_inner.group, + eps=bn_inner.eps, + momentum=1 - bn_inner.momentum, + has_bias=conv_inner.has_bias, + bias_init=conv_inner.bias_init, + quant_config=self.quant_config, + quant_dtype=self.weight_dtype, + fake=True) + else: + conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels, + conv_inner.out_channels, + kernel_size=conv_inner.kernel_size, + stride=conv_inner.stride, + pad_mode=conv_inner.pad_mode, + padding=conv_inner.padding, + dilation=conv_inner.dilation, + group=conv_inner.group, + eps=bn_inner.eps, + momentum=1 - bn_inner.momentum, + has_bias=conv_inner.has_bias, + bias_init=conv_inner.bias_init, + freeze_bn=self.freeze_bn, + quant_config=self.quant_config, + quant_dtype=self.weight_dtype, + fake=True) + # change original network Batch Normalization OP parameters to quant network + conv_inner.gamma = subcell.batchnorm.gamma + conv_inner.beta = subcell.batchnorm.beta + conv_inner.moving_mean = subcell.batchnorm.moving_mean + conv_inner.moving_variance = subcell.batchnorm.moving_variance + else: + conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels, + conv_inner.out_channels, + kernel_size=conv_inner.kernel_size, + stride=conv_inner.stride, + pad_mode=conv_inner.pad_mode, + padding=conv_inner.padding, + dilation=conv_inner.dilation, + group=conv_inner.group, + eps=bn_inner.eps, + momentum=1 - bn_inner.momentum, + has_bias=conv_inner.has_bias, + bias_init=conv_inner.bias_init, + quant_config=self.quant_config, + 
quant_dtype=self.weight_dtype) + # change original network Batch Normalization OP parameters to quant network + conv_inner.batchnorm.gamma = subcell.batchnorm.gamma + conv_inner.batchnorm.beta = subcell.batchnorm.beta + conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean + conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance + del subcell.batchnorm + subcell.batchnorm = None + subcell.has_bn = False + else: + conv_inner = quant.Conv2dQuant(conv_inner.in_channels, conv_inner.out_channels, + kernel_size=conv_inner.kernel_size, stride=conv_inner.stride, + pad_mode=conv_inner.pad_mode, padding=conv_inner.padding, + dilation=conv_inner.dilation, group=conv_inner.group, + has_bias=conv_inner.has_bias, quant_config=self.quant_config, + quant_dtype=self.weight_dtype) + # change original network Conv2D OP parameters to quant network + conv_inner.weight = subcell.conv.weight + if subcell.conv.has_bias: + conv_inner.bias = subcell.conv.bias + subcell.conv = conv_inner + if subcell.has_act and subcell.activation is not None: + subcell.activation = self._convert_activation(subcell.activation) + elif subcell.after_fake: + subcell.has_act = True + subcell.activation = _AddFakeQuantAfterSubCell(F.identity, quant_dtype=self.act_dtype, + quant_delay=self.act_qdelay, per_channel=self.act_channel, + symmetric=self.act_symmetric, narrow_range=self.act_range, + optimize_option=self.optimize_option) + return subcell + + def _convert_dense(self, subcell): + """ + convert dense cell to quant cell + """ + min_init = -6 + max_init = 6 + if OptimizeOption.LEARNED_SCALE in self.optimize_option: + subcell_weight_para = subcell.dense.weight.data.asnumpy() + if subcell.has_bn: + scale_factor = (subcell.batchnorm.gamma.data.asnumpy() / + np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps)) + subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) + min_init, max_init = self._kl_init(subcell_weight_para, self.weight_dtype) + self.quant_config = self.quant_config._replace( + weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init)) + + dense_inner = subcell.dense + dense_inner = quant.DenseQuant(dense_inner.in_channels, + dense_inner.out_channels, + has_bias=dense_inner.has_bias, + quant_config=self.quant_config, + quant_dtype=self.weight_dtype) + # change original network Dense OP parameters to quant network + dense_inner.weight = subcell.dense.weight + if subcell.dense.has_bias: + dense_inner.bias = subcell.dense.bias + subcell.dense = dense_inner + if subcell.has_act and subcell.activation is not None: + subcell.activation = self._convert_activation(subcell.activation) + elif subcell.after_fake: + subcell.has_act = True + subcell.activation = _AddFakeQuantAfterSubCell(F.identity, + quant_dtype=self.act_dtype, + quant_delay=self.act_qdelay, + per_channel=self.act_channel, + symmetric=self.act_symmetric, + narrow_range=self.act_range, + optimize_option=self.optimize_option) + return subcell + + def _convert_activation(self, activation): + """ + convert activation cell to quant cell + """ + act_class = activation.__class__ + act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid] + act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish] + + if act_class in act_list: + return quant.ActQuant(activation=activation, + quant_config=self.quant_config, + quant_dtype=self.act_dtype) + if act_class in act_list_with_fake_before: + return quant.ActQuant(activation=activation, + ema=True, + fake_before=True, + quant_config=self.quant_config, 
+ quant_dtype=self.act_dtype) + raise ValueError("Unsupported activation in auto quant: ", act_class) + + def _kl_init(self, subcell_weight_para, weight_dtype): + """ + Calculate the value of max_init and min_init with compute_kl_threshold. + """ + if self.weight_channel: + max_init = [compute_kl_threshold(weight_para_each, weight_dtype) + for weight_para_each in subcell_weight_para] + min_init = [-x for x in max_init] + else: + max_init = [compute_kl_threshold(subcell_weight_para, weight_dtype)] + min_init = [-x for x in max_init] + return min_init, max_init + + def _set_mixed_bits(self, network, strategy): + r""" + Set network's quantization strategy, this function is currently only valid for `LEARNED_SCALE` + optimize_option. + + Args: + network (Cell): Input network. + strategy (list): The quantization strategy for layers that need to be quantified (eg. [[8], [8], + ..., [6], [4], [8]]), currently only the quant_dtype for weights of the dense layer and the + convolution layer is supported. + + Returns: + Cell, a network with mixed bit strategy configured. + + Raises: + ValueError: If `OptimizeOption.LEARNED_SCALE` is not in `self.optimize_option`. + """ + if OptimizeOption.LEARNED_SCALE not in self.optimize_option: + raise ValueError("The `_set_mixed_bits` function is currently only valid for `LEARNED_SCALE` " + "optimize_option.") + + quantizable_idx = [] + pass_cell = None + for i, cell_and_name in enumerate(network.cells_and_names()): + cell = cell_and_name[1] + if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)) and cell is not pass_cell: + quantizable_idx.append(i) + + if len(quantizable_idx) != len(strategy): + raise ValueError("The dimension of quantifiable layers is not consistent with that of strategy.") + + quantizable_layer_bit_dict = {idx: bit for idx, bit in zip(quantizable_idx, strategy)} + type_map = { + QuantDtype.INT2.num_bits: QuantDtype.INT2, + QuantDtype.INT3.num_bits: QuantDtype.INT3, + QuantDtype.INT4.num_bits: QuantDtype.INT4, + QuantDtype.INT5.num_bits: QuantDtype.INT5, + QuantDtype.INT6.num_bits: QuantDtype.INT6, + QuantDtype.INT7.num_bits: QuantDtype.INT7, + QuantDtype.INT8.num_bits: QuantDtype.INT8 + } + for i, cell_and_name in enumerate(network.cells_and_names()): + cell = cell_and_name[1] + if i not in quantizable_idx: + continue + else: + if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)): + cell.weight_dtype = type_map[quantizable_layer_bit_dict[i][0]] + if isinstance(cell, nn.Conv2dBnAct): + subcell_weight_para = cell.conv.weight.data.asnumpy() + if hasattr(cell.conv, 'gamma'): + scale_factor = (cell.conv.gamma.data.asnumpy() / + np.sqrt(cell.conv.moving_variance.data.asnumpy() + self.eps)) + subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) + min_init, max_init = self._kl_init(subcell_weight_para, cell.weight_dtype) + cell.conv.fake_quant_weight.reset(quant_dtype=cell.weight_dtype, + min_init=min_init, + max_init=max_init) + elif isinstance(cell, nn.DenseBnAct): + subcell_weight_para = cell.dense.weight.data.asnumpy() + if hasattr(cell.dense, 'gamma'): + scale_factor = (cell.dense.gamma.data.asnumpy() / + np.sqrt(cell.dense.moving_variance.data.asnumpy() + self.eps)) + subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) + min_init, max_init = self._kl_init(subcell_weight_para, cell.weight_dtype) + cell.dense.fake_quant_weight.reset(quant_dtype=cell.weight_dtype, + min_init=min_init, + max_init=max_init) + return network diff --git 
a/mindspore/python/mindspore/golden_stick/legacy/quant/quant_utils.py b/mindspore/python/mindspore/golden_stick/legacy/quant/quant_utils.py new file mode 100644 index 00000000000..860cf0c9128 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/legacy/quant/quant_utils.py @@ -0,0 +1,439 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Quantization utils.""" + +import numpy as np +from mindspore._checkparam import Validator +from ... import nn + +__all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers"] + + +def cal_quantization_params(input_min, + input_max, + quant_min, + quant_max, + data_type, + symmetric=False): + r""" + Calculate quantization params for scale and zero point. + + Args: + input_min (numpy.ndarray): The dimension of channel or 1. + input_max (numpy.ndarray): The dimension of channel or 1. + quant_min (int): The minimum quantization integer. + quant_max (int): The maximum quantization integer. + data_type (numpy type) : Can be numpy int8, numpy uint8. + symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False. + + Returns: + scale (numpy.ndarray): quantization param. + zero point (numpy.ndarray): quantization param. + """ + input_max = np.maximum(0.0, input_max) + input_min = np.minimum(0.0, input_min) + + if input_min.shape != input_max.shape: + raise ValueError("input min shape should be equal to input max.") + if len(input_min.shape) > 1: + raise ValueError("input min and max shape should be one dim.") + if (input_min > input_max).all(): + raise ValueError("input_min min should be less than input max.") + if (input_max == input_min).all(): + return np.ones(input_min.shape), np.zeros(input_min.shape) + + # calculate scale + if symmetric: + input_max = np.maximum(-input_min, input_max) + input_min = -input_max + scale = (input_max - input_min) / (quant_max - quant_min) + + # calculate zero point + if data_type == np.int8 and symmetric: + zp = np.zeros(input_min.shape) + else: + zp_double = quant_min - input_min / scale + zp = np.floor(zp_double + 0.5) + + return scale, zp + + +def get_quant_min_max(data_type, num_bits=8, narrow_range=False): + """Calculate quantization params for minimum/maximum quantization integer""" + if data_type == np.int8: + quant_min = 0 - 2 ** (num_bits - 1) + quant_max = 2 ** (num_bits - 1) - 1 + elif data_type == np.uint8: + quant_min = 0 + quant_max = 2 ** num_bits - 1 + else: + raise ValueError("Unsupported datatype({})".format(data_type)) + if narrow_range: + quant_min = quant_min + 1 + return quant_min, quant_max + + +def weight2int(data, scale, zero_point, quant_min, quant_max): + r""" + Calculate int8/uint8 weight from fp32. the formula is defined as: + + .. math:: + int8/uint8 = round(float/scale) + offset + + Args: + data (numpy.ndarray): The dimension of channel or 1. Should be NCHW. + scale (numpy.ndarray): The dimension of channel or 1. 
+        zero_point (numpy.ndarray): The dimension of channel or 1.
+        quant_min (int): The minimum quantization integer.
+        quant_max (int): The maximum quantization integer.
+
+    Returns:
+        weight (numpy.ndarray): The dimension of channel or 1.
+    """
+    if scale.shape != zero_point.shape:
+        raise ValueError("`scale` and `zero_point` should have the same shape.")
+    if scale.shape[0] == 0:
+        raise ValueError("`scale` and `zero_point` shape should be greater than zero.")
+    if len(scale.shape) >= 1 and scale.shape[0] > 1:
+        # for per-channel quantization, broadcast scale/zero_point to the weight layout
+        if scale.shape[0] == data.shape[0]:
+            # `Conv2d` or `Dense` op weight
+            shape_list = [-1] + [1] * len(data.shape[1:])
+            scale = scale.reshape(shape_list)
+            zero_point = zero_point.reshape(shape_list)
+        elif scale.shape[0] == data.shape[1]:
+            # `DepthwiseConv2d` op weight
+            shape_list = [1, -1] + [1] * len(data.shape[2:])
+            scale = scale.reshape(shape_list)
+            zero_point = zero_point.reshape(shape_list)
+        else:
+            raise ValueError("Unsupported weight shape({})".format(data.shape))
+
+    weight_int = np.round((data / scale) + zero_point)
+    weight_int[weight_int > quant_max] = quant_max
+    weight_int[weight_int < quant_min] = quant_min
+    return weight_int
+
+
+def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
+    """Calculate quantization params (scale, zero point, max and min) from a `FakeQuantWithMinMaxObserver` cell."""
+    minq = cell.minq.data.asnumpy()
+    maxq = cell.maxq.data.asnumpy()
+    # make sure maxq > 0 and minq <= 0
+    if cell.mode == 'LEARNED_SCALE':
+        maxq = np.abs(maxq)
+        minq = -np.abs(minq)
+    quant_min, quant_max = get_quant_min_max(data_type, num_bits=cell.num_bits, narrow_range=cell.narrow_range)
+    symmetric = cell.symmetric and not cell.neg_trunc
+    scale, zp = cal_quantization_params(
+        minq, maxq,
+        quant_min, quant_max, data_type,
+        symmetric=symmetric)
+    return scale, zp, maxq, minq
+
+
+def fold_batchnorm(weight, cell_quant):
+    r"""
+    Fold the batchnorm in `Conv2dBnFoldQuant` to weight.
+
+    Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
+
+    Args:
+        weight (numpy.ndarray): Weight of `cell_quant`.
+        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`.
+
+    Returns:
+        weight (numpy.ndarray): Folded weight.
+        bias (numpy.ndarray): Folded bias.
+    """
+    variance = cell_quant.moving_variance.data.asnumpy()
+    mean = cell_quant.moving_mean.data.asnumpy()
+    gamma = cell_quant.gamma.data.asnumpy()
+    beta = cell_quant.beta.data.asnumpy()
+    epsilon = cell_quant.eps
+    sigma = np.sqrt(variance + epsilon)
+
+    if gamma.shape[0] == weight.shape[0]:
+        # `Conv2d` or `Dense` op weight
+        shape_list = [-1] + [1] * len(weight.shape[1:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    elif gamma.shape[0] == weight.shape[1]:
+        # `DepthwiseConv2d` op weight
+        shape_list = [1, -1] + [1] * len(weight.shape[2:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    else:
+        raise ValueError("Unsupported weight shape({})".format(weight.shape))
+
+    weight = weight * _gamma / _sigma
+    bias = beta - gamma * mean / sigma
+    return weight, bias
+
+
+def without_fold_batchnorm(weight, cell_quant):
+    r"""
+    Fold the batchnorm in `Conv2dBnWithoutFoldQuant` to weight.
+
+    Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
+
+    Args:
+        weight (numpy.ndarray): Weight of `cell_quant`.
+        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`.
+
+    Returns:
+        weight (numpy.ndarray): Folded weight.
+        bias (numpy.ndarray): Folded bias.
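+
+    For reference, the computation below follows the standard batchnorm-folding
+    identities, with :math:`\sigma = \sqrt{moving\_variance + \epsilon}`:
+
+    .. math::
+        weight_{fold} = weight \cdot \frac{\gamma}{\sigma}, \qquad
+        bias_{fold} = \beta - \frac{\gamma \cdot moving\_mean}{\sigma}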
+ """ + variance = cell_quant.batchnorm.moving_variance.data.asnumpy() + mean = cell_quant.batchnorm.moving_mean.data.asnumpy() + gamma = cell_quant.batchnorm.gamma.data.asnumpy() + beta = cell_quant.batchnorm.beta.data.asnumpy() + epsilon = cell_quant.batchnorm.eps + sigma = np.sqrt(variance + epsilon) + + if gamma.shape[0] == weight.shape[0]: + # `Conv2d` or `Dense` op weight + shape_list = [-1] + [1] * len(weight.shape[1:]) + _gamma = gamma.reshape(shape_list) + _sigma = sigma.reshape(shape_list) + elif gamma.shape[0] == weight.shape[1]: + # `DepthwiseConv2d` op weight + shape_list = [1, -1] + [1] * len(weight.shape[2:]) + _gamma = gamma.reshape(shape_list) + _sigma = sigma.reshape(shape_list) + else: + raise ValueError("Unsupported weight shape({})".format(weight.shape)) + + weight = weight * _gamma / _sigma + bias = beta - gamma * mean / sigma + return weight, bias + + +def compute_kl_threshold(data, bitwidth): + r""" + Using KL-J Distance to calculate the clip threshold. + + Args: + - **data** (NumpyArray) - Data observed to calculate the threshold for quantization, + - **bitwidth** (QuantDtype) - The datatype of quantization. + Outputs: + Tensor with Shape 1. Threshold to calculate the data. + """ + data_max = np.abs(data).max() + if data_max < 1e-5: + return 1e-5 + hist, bin_edges = np.histogram(np.abs(data), bins='sqrt', range=(0, data_max), density=True) + # For the sake of high efficiency, we limit the maximum number of bins to 1024 in `sqrt` mode, If it exceeds the + # largest size, turn to use the default bins config. + largest_bin_size = 1024 + if hist.shape[0] > largest_bin_size: + hist, bin_edges = np.histogram(np.abs(data), range=(0, data_max), density=True) + hist = hist / np.sum(hist) + cumsum = np.cumsum(hist) + bit_pow_range = pow(2, int(bitwidth.num_bits) - 1) + threshold = [] + scaling_factor = [] + kl = [] + if bit_pow_range + 1 > len(bin_edges) - 1: + th_layer_out = bin_edges[-1] + return float(th_layer_out) + for i in range(bit_pow_range + 1, len(bin_edges), 1): + threshold_tmp = (i + 0.5) * (bin_edges[1] - bin_edges[0]) + threshold = np.concatenate((threshold, [threshold_tmp])) + scaling_factor_tmp = threshold_tmp / (bit_pow_range - 1) + scaling_factor = np.concatenate((scaling_factor, [scaling_factor_tmp])) + # forward interpolation + cumsum_tmp = np.copy(cumsum) + cumsum_tmp[(i - 1):] = 1 + fwd_x = np.linspace(0.0, 1.0, bit_pow_range) + fwd_xp = np.linspace(0.0, 1.0, i) + fwd_fp = cumsum_tmp[:i] + forward_interp = np.interp(fwd_x, fwd_xp, fwd_fp) + # backward interpolation + bwd_x = np.linspace(0.0, 1.0, i) + bwd_xp = np.linspace(0.0, 1.0, bit_pow_range) + bwd_fp = forward_interp + backward_interp = np.interp(bwd_x, bwd_xp, bwd_fp) + cumsum_tmp[:i] = backward_interp + kl_tmp = np.sum((cumsum - cumsum_tmp) * np.log2(cumsum / cumsum_tmp)) # Kullback-Leibler-J + kl = np.concatenate((kl, [kl_tmp])) + th_layer_out = threshold[np.argmin(kl)] + threshold = float(th_layer_out) + if threshold < 1e-5: + threshold = 1e-5 + return threshold + + +def query_quant_layers(network): + r""" + Query the network's quantization strategy of each quantized layer and print it to the screen, note that all the + quantization layers are queried before graph compile optimization in the graph mode, thus, some redundant quantized + layers, which not exist in practical execution, may appear. 
+ + Args: + network (Cell): input network + + Examples: + >>> from mindspore.compression.quant import QuantizationAwareTraining + >>> from mindspore.compression.quant.quant_utils import query_quant_layers + >>> class LeNet5(nn.Cell): + ... def __init__(self, num_class=10, channel=1): + ... super(LeNet5, self).__init__() + ... self.type = "fusion" + ... self.num_class = num_class + ... + ... # change `nn.Conv2d` to `nn.Conv2dBnAct` + ... self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') + ... self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') + ... # change `nn.Dense` to `nn.DenseBnAct` + ... self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') + ... self.fc2 = nn.DenseBnAct(120, 84, activation='relu') + ... self.fc3 = nn.DenseBnAct(84, self.num_class) + ... + ... self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + ... self.flatten = nn.Flatten() + ... + ... def construct(self, x): + ... x = self.conv1(x) + ... x = self.max_pool2d(x) + ... x = self.conv2(x) + ... x = self.max_pool2d(x) + ... x = self.flatten(x) + ... x = self.fc1(x) + ... x = self.fc2(x) + ... x = self.fc3(x) + ... return x + ... + >>> net = LeNet5() + >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False]) + >>> net_qat = quantizer.quantize(net) + >>> query_quant_layers(net_qat) + conv1.conv.fake_quant_weight INT8 + conv1.activation.fake_quant_act INT8 + conv2.conv.fake_quant_weight INT8 + conv2.activation.fake_quant_act INT8 + fc1.dense.fake_quant_weight INT8 + fc1.activation.fake_quant_act INT8 + fc2.dense.fake_quant_weight INT8 + fc2.activation.fake_quant_act INT8 + fc3.dense.fake_quant_weight INT8 + fc3.activation.fake_quant_act INT8 + """ + network = Validator.check_isinstance("network", network, nn.Cell) + tplt = "{0:60}\t{1:10}" + for cell_and_name in network.cells_and_names(): + cell_name = cell_and_name[0] + cell = cell_and_name[1] + if isinstance(cell, nn.FakeQuantWithMinMaxObserver): + print(tplt.format(cell_name, cell.quant_dtype)) + + +def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None): + r""" + Load fp32 model parameters into quantization model. + + Args: + quant_model(Cell): Quantization model. + params_dict(dict): Parameter dict that stores fp32 parameters. + quant_new_params(list): Parameters that exist in quantization network but not in non-quantization + network. Default: None. + + Raises: + TypeError: If `quant_new_params` is not None and is not list. + ValueError: If there are parameters in the `quant_model` that are neither in `params_dict` + nor in `quant_new_params`. + + Examples: + >>> from mindspore import load_checkpoint + >>> from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net + >>> class LeNet5(nn.Cell): + ... def __init__(self, num_class=10, channel=1): + ... super(LeNet5, self).__init__() + ... self.type = "fusion" + ... self.num_class = num_class + ... + ... # change `nn.Conv2d` to `nn.Conv2dBnAct` + ... self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') + ... self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') + ... # change `nn.Dense` to `nn.DenseBnAct` + ... self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') + ... self.fc2 = nn.DenseBnAct(120, 84, activation='relu') + ... self.fc3 = nn.DenseBnAct(84, self.num_class) + ... + ... self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + ... self.flatten = nn.Flatten() + ... + ... 
def construct(self, x): + ... x = self.conv1(x) + ... x = self.max_pool2d(x) + ... x = self.conv2(x) + ... x = self.max_pool2d(x) + ... x = self.flatten(x) + ... x = self.fc1(x) + ... x = self.fc2(x) + ... x = self.fc3(x) + ... return x + ... + >>> net = LeNet5() + >>> ckpt_file_name = "./checkpoint/LeNet5_noquant-1_32.ckpt" + >>> param_dict = load_checkpoint(ckpt_file_name) + >>> load_nonquant_param_into_quant_net(net, param_dict) + """ + if quant_new_params is not None and not isinstance(quant_new_params, list): + raise TypeError("quant_new_params must be list or None.") + iterable_dict = { + 'minq': iter(list(filter(lambda item: item[0].endswith('minq'), params_dict.items()))), + 'maxq': iter(list(filter(lambda item: item[0].endswith('maxq'), params_dict.items()))), + 'quant_max': iter(list(filter(lambda item: item[0].endswith('quant_max'), params_dict.items()))) + } + for param in params_dict.items(): + key_name = param[0].split(".")[-1] + if key_name not in iterable_dict: + iterable_dict[key_name] = iter(list(filter(lambda item, value=key_name: item[0].endswith(value), + params_dict.items()))) + + for name, param in quant_model.parameters_and_names(): + key_name = name.split(".")[-1] + if key_name not in iterable_dict.keys(): + if key_name not in quant_new_params: + raise ValueError(f"Can't find match parameter in ckpt, param name = {name}") + continue + value_param = next(iterable_dict[key_name], None) + if value_param: + param.set_data(value_param[1].data) + print(f'init model param {name} with checkpoint param {value_param[0]}') + + # Perform KL_init when learned scale quantization is executed. + for cell_and_name in quant_model.cells_and_names(): + cell = cell_and_name[1] + if isinstance(cell, (nn.Conv2dBnFoldQuantOneConv, nn.Conv2dBnFoldQuant, nn.Conv2dBnWithoutFoldQuant, + nn.Conv2dQuant, nn.DenseQuant)) and cell.fake_quant_weight.mode == "LEARNED_SCALE": + subcell_weight_para = cell.weight.data.asnumpy() + if hasattr(cell, 'gamma'): + scale_factor = (cell.gamma.data.asnumpy() / + np.sqrt(cell.moving_variance.data.asnumpy() + 1e-5)) + subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) + + if cell.fake_quant_weight.per_channel: + max_init = [compute_kl_threshold(weight_para_each, cell.fake_quant_weight.quant_dtype) + for weight_para_each in subcell_weight_para] + min_init = [-x for x in max_init] + else: + max_init = [compute_kl_threshold(subcell_weight_para, cell.fake_quant_weight.quant_dtype)] + min_init = [-x for x in max_init] + + cell.fake_quant_weight.reset(quant_dtype=cell.fake_quant_weight.quant_dtype, + min_init=min_init, max_init=max_init) diff --git a/mindspore/python/mindspore/golden_stick/legacy/quant/quantizer.py b/mindspore/python/mindspore/golden_stick/legacy/quant/quantizer.py new file mode 100644 index 00000000000..1a04a07b165 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/legacy/quant/quantizer.py @@ -0,0 +1,65 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Base Class of Quantizer."""
+
+from abc import ABC, abstractmethod
+from enum import Enum
+
+from ..._checkparam import Validator
+
+__all__ = ["OptimizeOption"]
+
+
+class OptimizeOption(Enum):
+    r"""
+    An enum for the model quantization optimize option, currently only supporting `QAT` and `LEARNED_SCALE`.
+    """
+    # using quantization aware training
+    QAT = "QAT"
+
+    # using the learned scale quantization
+    LEARNED_SCALE = "LEARNED_SCALE"
+
+    def __str__(self):
+        return self.value
+
+
+class Quantizer(ABC):
+    """
+    Base class of Quantizer. You can implement different kinds of quantizers to get different quantization results.
+
+    Note:
+        This class is an abstract class.
+
+    Args:
+        optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options. Default:
+            OptimizeOption.QAT.
+    """
+
+    def __init__(self,
+                 optimize_option=OptimizeOption.QAT):
+        if not isinstance(optimize_option, list) and not isinstance(optimize_option, tuple):
+            optimize_option = [optimize_option]
+        for option in optimize_option:
+            option = Validator.check_isinstance("optimize_option", option, OptimizeOption)
+        self.optimize_option = optimize_option
+
+    @abstractmethod
+    def quantize(self, network):
+        """
+        Quant API to convert input network to a quantization aware training network.
+
+        Args:
+            network (Cell): network to be quantized.
+        """
diff --git a/mindspore/python/mindspore/golden_stick/net_transform.py b/mindspore/python/mindspore/golden_stick/net_transform.py
new file mode 100644
index 00000000000..f5fb25f87c1
--- /dev/null
+++ b/mindspore/python/mindspore/golden_stick/net_transform.py
@@ -0,0 +1,158 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""NetTransform."""
+from typing import Union, Optional
+
+from mindspore.nn.cell import Cell
+from mindspore.rewrite import Graph, Node, PatternEngine
+
+
+class NetTransformer:
+    """
+    NetTransformer is defined for transforming networks in MindSpore.
+
+    Args:
+        net (Cell): Network to be transformed.
+    """
+
+    def __init__(self, net: Cell):
+        self._net = net
+        self._graph = Graph(net)
+
+    def get_transformed(self) -> Cell:
+        """
+        Returns:
+            Transformed network.
+        """
+
+        return self._graph.python_object()
+
+    def nodes(self) -> [Node]:
+        """
+        Returns:
+            A list of Node objects corresponding to all layers in the original network.
+        """
+
+        return self._graph.nodes
+
+    @staticmethod
+    def set_node_attr(node: Node, key: str, value):
+        node.attribute.set_attribute(key, value)
+
+    @staticmethod
+    def get_node_attr(node: Node, key: str):
+        return node.attribute.attribute[key]
+
+    def find_node(self, full_name_with_scope: str) -> Node:
+        """
+        Args:
+            full_name_with_scope (str): Name of the node to find.
+
+        Returns:
+            Node whose name is `full_name_with_scope`.
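+
+        Examples:
+            >>> # A usage sketch; `net` is a user-defined Cell and the node name
+            >>> # "conv1" is hypothetical.
+            >>> transformer = NetTransformer(net)
+            >>> node = transformer.find_node("conv1")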
+ """ + + return self._graph.find(full_name_with_scope) + + # return inputs of node whose full_name_with_scope is full_name_with_scope + def node_inputs(self, full_name_with_scope: str) -> [Node]: + """ + Args: + full_name_with_scope (str): Name of node to be find. + + Returns: + Input nodes of node whose name is `full_name_with_scope` + """ + + node = self._graph.find(full_name_with_scope) + if node is None: + return [] + return node.inputs + + # return outputs of node whose full_name_with_scope is full_name_with_scope + def node_outputs(self, full_name_with_scope: str) -> [Node]: + """ + Args: + full_name_with_scope (str): Name of node to be find. + + Returns: + Output nodes of node whose name is `full_name_with_scope` + """ + + node = self._graph.find(full_name_with_scope) + if node is None: + return [] + return node.outputs + + def insert_node(self, new_node: Node) -> Node: + """ + Args: + new_node (Node): New node to be inserted into original network. + New_node should contain its inputs and outputs. + + Returns: + Node has been inserted, return None if failed + """ + + return self._graph.insert_node(new_node) + + def remove_node(self, node: Union[str, Node]) -> Optional[Node]: + """ + Args: + node (Node): node to be removed from original network. + + Returns: + Node has been removed, return None if failed + """ + + if isinstance(node, str): + node = self._graph.find(node) + if node is None: + return None + return self._graph.remove_node(node) + + def replace_node(self, target: Union[str, Node], value: Union[Cell, Node]) -> Optional[Node]: + """ + Args: + target (Union[str, Node]): Name of node to be replaced. + value (Union[Cell, Node]): Node to be replaced into original network. + + Note: + new_node should has same inputs and outputs with old_node. + + Returns: + Node has been replaced, return None if failed + """ + + if isinstance(target, str): + target = self._graph.find(target) + if target is None: + return None + if isinstance(value, Cell): + value = Node(value) + return self._graph.replace_node(target, value) + + # replace src_pattern with target_nodes. + # target_nodes should has same inputs and outputs with src_pattern. + def pattern_transform(self, pattern_engine: PatternEngine) -> bool: + """ + Args: + pattern_engine (PatternEngine): Instance of PatternEngine. Apply `pattern_engine` on current network + + Returns: + a bool value indicating if transform occurred + """ + + return pattern_engine.apply(self._graph) diff --git a/mindspore/python/mindspore/golden_stick/pruner/__init__.py b/mindspore/python/mindspore/golden_stick/pruner/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mindspore/python/mindspore/golden_stick/pruner/simple_pruner.py b/mindspore/python/mindspore/golden_stick/pruner/simple_pruner.py new file mode 100644 index 00000000000..e133524ac70 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/pruner/simple_pruner.py @@ -0,0 +1,97 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""SimplePruner."""
+
+from ..golden_stick import GoldenStick
+from mindspore.train.callback import Callback
+from mindspore.nn import Cell, Conv2d, ReLU, L1Loss
+from mindspore.rewrite.pattern_engine import PatternEngine
+
+
+class PrunerConv(Cell):
+    def __init__(self, conv, relu):
+        super().__init__()
+        self._old_conv: Conv2d = conv
+        self._old_relu: ReLU = relu
+
+    def construct(self, x):
+        x = self._old_conv(x)
+        return self._old_relu(x)
+
+
+class PrunerLoss(Cell):
+    def __init__(self, ori_loss: Cell):
+        super().__init__()
+        self._old_loss = ori_loss
+        self._extra_loss = L1Loss()
+
+    def construct(self, x):
+        x = self._old_loss(x)
+        return self._extra_loss(x)
+
+
+class PrunerPatternEngine(PatternEngine):
+    def __init__(self):
+        super().__init__([Conv2d, ReLU], PrunerConv)
+
+
+class PrunerCallback(Callback):
+    def __init__(self, begin, end, frequency, target_sparsity):
+        super(PrunerCallback, self).__init__()
+        self._begin_step = begin
+        self._end_step = end
+        self._frequency = frequency
+        self._target_sparsity = target_sparsity
+        self._masks = {}
+
+    # init _masks with an all-ones mask for every Conv2d layer
+    def begin(self, run_context):
+        origin_args = run_context.original_args()
+        net: Cell = origin_args.network
+        for name, cell in net.name_cells().items():
+            if isinstance(cell, Conv2d):
+                self._masks[name] = [1, 1, 1, 1]
+
+    def step_end(self, run_context):
+        origin_args = run_context.original_args()
+        cur_step_num = origin_args.cur_step_num
+        if cur_step_num < self._begin_step or cur_step_num >= self._end_step:
+            return
+        cur_step_num_index = cur_step_num - self._begin_step
+        if cur_step_num_index % self._frequency != 0:
+            return
+        # use _mask to update weight
+        net: Cell = origin_args.network
+        # compute new _mask
+
+    def end(self, run_context):
+        # weight = weight * mask
+        pass
+
+
+# define the pruner algorithm
+class PrunerCompressAlgo(GoldenStick):
+    def __init__(self, config: {}):
+        super(PrunerCompressAlgo, self).__init__(config)
+        self._callback = PrunerCallback(config["begin_step"], config["end_step"], config["frequency"],
+                                        config["target_sparsity"])
+
+    def callbacks(self):
+        return self._callback
+
+    def loss(self, loss_fn):
+        return PrunerLoss(loss_fn)
diff --git a/mindspore/python/mindspore/golden_stick/quantization/__init__.py b/mindspore/python/mindspore/golden_stick/quantization/__init__.py
new file mode 100644
index 00000000000..fae58383ffe
--- /dev/null
+++ b/mindspore/python/mindspore/golden_stick/quantization/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+MindSpore golden stick module.
+""" + +from .layer_policy import LayerPolicy +from .net_policy import NetPolicy +from .quantize import QuantAwareTraining +from .quantizer import Quantizer +from .transformer import Transformer +from .default_qat import AllValueQuantizer, LastValueQuantizer, LSQ, DefaultLayerPolicy, DefaultNetworkPolicy, \ + DefaultQuantAwareTraining + +__all__ = ["LayerPolicy", "NetPolicy", "QuantAwareTraining", "Quantizer", "Transformer", "AllValueQuantizer", + "LastValueQuantizer", "LSQ", "DefaultLayerPolicy", "DefaultNetworkPolicy", "DefaultQuantAwareTraining"] diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/__init__.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/__init__.py new file mode 100644 index 00000000000..b2542fa2330 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +MindSpore golden stick default-qat-quantization. +""" + +from .default_quantizer import AllValueQuantizer, LastValueQuantizer, LSQ +from .default_layer_policy import DefaultLayerPolicy +from .default_net_policy import DefaultNetworkPolicy +from .default_quantize import DefaultQuantAwareTraining + +__all__ = ["AllValueQuantizer", "LastValueQuantizer", "LSQ", "DefaultLayerPolicy", "DefaultNetworkPolicy", + "DefaultQuantAwareTraining"] diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py new file mode 100644 index 00000000000..9d45a6ba72d --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py @@ -0,0 +1,80 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""DefaultLayerPolicy.""" +from typing import Optional + +from ..layer_policy import LayerPolicy +from ..quantize_wrapper_cell import QuantizeWrapperCell +from ..quantizer import Quantizer +from .default_quantizer import LastValueQuantizer +from mindspore.nn import Cell + + +class DefaultLayerPolicy(LayerPolicy): + """ + Derived class of LayerQConfig. Default layer-quant-config. 
+ + Supported Config: + ``quant_delay`` ``quant_dtype`` ``per_channel`` ``symmetric`` ``narrow_range`` ``one_conv_fold``. + """ + + def __init__(self, weight_names: [], act_names: [], config=None): + if config is None: + config = {} + self._weight_quantizer = LastValueQuantizer() + self._act_quantizer = LastValueQuantizer() + self._input_quantizer: Optional[Quantizer] = LastValueQuantizer() + self._output_quantizer: Optional[Quantizer] = LastValueQuantizer() + self._weight_names = weight_names + self._act_names = act_names + self._input_num = 0 + self._inputs_insert_fq = [] + + def get_weight_name_and_quantizers(self): + return [(name, self._weight_quantizer) for name in self._weight_names] + + def get_act_name_and_quantizers(self): + return [(name, self._act_quantizer) for name in self._act_names] + + def get_input_quantizer(self) -> Optional[Quantizer]: + return self._input_quantizer + + def get_output_quantizer(self) -> Optional[Quantizer]: + return self._output_quantizer + + def set_input_number(self, input_num: int): + self._input_num = input_num + for i in range(0, self._input_num): + self._inputs_insert_fq.append(True) + + def set_input_not_insert_fq(self, index: Optional[int] = None): + if index is None: + for i in range(0, self._input_num): + self._inputs_insert_fq[i] = False + else: + if index >= self._input_num: + raise RuntimeError("Index out of range of input number") + self._inputs_insert_fq[index] = False + + def get_input_need_insert_fq(self, index: int): + if index >= self._input_num: + raise RuntimeError("Index out of range of input number") + return self._inputs_insert_fq[index] + + def set_output_not_insert_fq(self, index: Optional[int] = None): + self._output_quantizer = None + + def wrap_cell(self, handler: Cell) -> Cell: + return QuantizeWrapperCell(handler, self) diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py new file mode 100644 index 00000000000..191c2b0bed4 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py @@ -0,0 +1,43 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""DefaultNetworkPolicy.""" + +from ..net_policy import NetPolicy +from .default_layer_policy import DefaultLayerPolicy +from ..transformer import Transformer +from mindspore.nn.layer import Conv2d, Dense, MatMul, BatchNorm2d, ReLU + + +class DefaultNetworkPolicy(NetPolicy): + """ + Derived class of NetworkQConfig. Default network-quant-config. + + Supported Config: + ``quant_delay`` ``quant_dtype`` ``per_channel`` ``symmetric`` ``narrow_range`` ``one_conv_fold``. 
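+
+    Examples:
+        >>> # A usage sketch; passing None falls back to an empty config.
+        >>> net_policy = DefaultNetworkPolicy(config=None)
+        >>> transformers = net_policy.get_transformers()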
+ """ + + def __init__(self, config=None): + super().__init__(config) + if config is None: + config = {} + self._pattern_engines: [Transformer] = [ + Transformer([Conv2d, BatchNorm2d]), + Transformer([Conv2d, ReLU]) + ] + self._support_layer_map: dict = { + Conv2d: DefaultLayerPolicy(["weight"], [], config), + Dense: DefaultLayerPolicy(["weight"], [], config), + MatMul: DefaultLayerPolicy(["weight"], [], config) + } diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantize.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantize.py new file mode 100644 index 00000000000..9a6d645646b --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantize.py @@ -0,0 +1,30 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""DefaultQuantAwareTraining.""" + +from ...golden_stick import GoldenStick +from .default_net_policy import DefaultNetworkPolicy + + +class DefaultQuantAwareTraining(GoldenStick): + """ + Derived class of GoldenStick. Default QAT-algorithm. + """ + + def __init__(self, config=None): + super(DefaultQuantAwareTraining, self).__init__(config) + self._qat_policy = DefaultNetworkPolicy(config) + self._custom_transforms = config["custom_transforms"] + self._custom_layer_policy_map = config["custom_policies"] diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantizer.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantizer.py new file mode 100644 index 00000000000..7aff644a890 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantizer.py @@ -0,0 +1,68 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""DefaultQuantizeOp.""" + +from ..quantizer import Quantizer + + +class FixQuantizer(Quantizer): + ... + + +class AllValueQuantizer(Quantizer): + ... + + +class MovingAvgQuantizer(Quantizer): + ... + + +class LastValueQuantizer(Quantizer): + """ + Derived class of QuantizeOp. Use min and max value of data to compute scale and zero-point. 
+ """ + + def __init__(self): + super().__init__() + self._bit_num = 8 + + def compute_quant_param(self, float_data: [float]) -> {}: + data_min = float_data[0] + data_max = float_data[1] + if data_max == data_min: + return 1, 0 + scale = (1 << self._bit_num) / (data_max - data_min) + zp = data_max * scale + return scale, zp + + def fake_quant(self, float_data: [float], quant_params: dict, **kwargs) -> [float]: + scale = quant_params.get("scale") + zp = quant_params.get("zp") + return float_data * scale + zp + + +class LSQ(Quantizer): + """ + Derived class of QuantizeOp. Use learning-rate from each epoch to compute scale and zero-point. + """ + + def __init__(self): + super(LSQ, self).__init__() + + def compute_quant_param(self, float_data: [float]) -> {}: + pass + + def fake_quant(self, float_data: [float], quant_params: dict, **kwargs) -> [float]: + pass diff --git a/mindspore/python/mindspore/golden_stick/quantization/hello_qat/__init__.py b/mindspore/python/mindspore/golden_stick/quantization/hello_qat/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mindspore/python/mindspore/golden_stick/quantization/hello_qat/simple_qat.py b/mindspore/python/mindspore/golden_stick/quantization/hello_qat/simple_qat.py new file mode 100644 index 00000000000..5f4a51104f1 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/hello_qat/simple_qat.py @@ -0,0 +1,86 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""SimpleQAT.""" + +from collections import OrderedDict +from ...golden_stick import GoldenStick +from ...net_transform import NetTransformer +from mindspore.nn import Conv2d, Cell, BatchNorm2d, FakeQuantWithMinMaxObserver +from mindspore.rewrite.pattern_engine import PatternEngine, PatternNode +from mindspore.rewrite import Node +from mindspore.train.callback import Callback + + +# define how to create a new Cell +class QuantCell(Cell): + def __init__(self, conv, bn, **kwargs): + super().__init__() + self._old_conv: Conv2d = conv + self._old_bn: BatchNorm2d = bn + self._fake_quant = FakeQuantWithMinMaxObserver(per_channel=kwargs["per_channel"]) + + def construct(self, x): + x = self._old_conv(x) + x = self._old_bn(x) + return self._fake_quant(x) + + +# define PatternEngine: +# pattern +# new cell class +# when and how to replace pattern with new cell +class ConvBnPatternEngine(PatternEngine): + def __init__(self, config): + self._conv_name = "_conv" + self._bn_name = "_bn" + # construct a tree-pattern + p_conv = PatternNode(self._conv_name, Conv2d) + p_bn = PatternNode(self._bn_name, BatchNorm2d, [p_conv]) + self._pattern: PatternNode = p_bn + super().__init__(self._pattern) + self._config = config + self._kernel_size_threshold = 500 + + def _process_tree(self, matched_cells: OrderedDict): + old_conv: Node = matched_cells[self._conv_name] + bn: Node = matched_cells[self._bn_name] + if int(old_conv.attribute("kernel_size")) < self._kernel_size_threshold: + return bn + return Node(QuantCell(old_conv, bn, **self._config), old_conv.inputs()) + + +class QATCallback(Callback): + def __init__(self): + super(QATCallback, self).__init__() + + def end(self, run_context): + # strip QuantCell to conv + bn + pass + + +# define qat algo +class QATCompressAlgo(GoldenStick): + def __init__(self, config: {}): + super(QATCompressAlgo, self).__init__(config) + self._pattern_helper = ConvBnPatternEngine(config) + self._cb = QATCallback() + + def apply(self, network: Cell) -> Cell: + transformer = NetTransformer(network) + transformer.pattern_transform(self._pattern_helper) + return transformer.get_transformed() + + def callback(self): + return self._cb diff --git a/mindspore/python/mindspore/golden_stick/quantization/layer_policy.py b/mindspore/python/mindspore/golden_stick/quantization/layer_policy.py new file mode 100644 index 00000000000..a6e839c8747 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/layer_policy.py @@ -0,0 +1,101 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""LayerQConfig.""" +import abc +from typing import Optional +from .quantizer import Quantizer +from mindspore.nn import Cell + +layer_policy_key = "layer_quant_policy" + + +class LayerPolicy(abc.ABC): + """ + Base class for layer quantize configure. 
+ Configuration including: + Which weights of layer to be fake-quantize and how they should be fake-quantized + If input and output of activation of layer need to be fake-quantized and how they should be fake-quantized + If output and layer need to be fake-quantized and how it should be fake-quantized + + Args: + config (int): User config for QAT. Config specification is default by derived class. + + Supported Config: + ``quant_delay`` ``quant_dtype`` ``per_channel`` ``symmetric`` ``narrow_range`` ``one_conv_fold``. + + Note: + Derived class must override `get_weight_name_and_quantizers`, `get_act_name_and_quantizers`, + `get_output_quantizers` and `wrapper_cell`. + """ + + def get_weight_name_and_quantizers(self) -> [(str, Quantizer)]: + """ + Define how to fake-quantize weight data. This method must be overridden by all subclasses. + + Returns: + Return a list of 2-tuple of weight_name and weight_quantizer. + Return empty list if no need to fake-quant weight. + """ + + return [] + + def get_act_name_and_quantizers(self) -> [(str, (Optional[Quantizer], Optional[Quantizer]))]: + return [] + + def get_input_quantizer(self) -> Optional[Quantizer]: + """ + Define how to fake-quantize input data. This method must be overridden by all subclasses. + + Returns: + Return a instance of quantizer as quantizer for inputs. + Return None if all inputs don't need to fake-quant. + """ + return None + + def get_output_quantizer(self) -> Optional[Quantizer]: + """ + Define how to fake-quantize output data. This method must be overridden by all subclasses. + + Returns: + Return a instance of quantizer as quantizer for outputs. + Return None if all outputs don't need to fake-quant. + """ + return None + + @abc.abstractmethod + def wrap_cell(self, handler: Cell) -> Cell: + """ + Define how to wrapper `handler`. This method must be overridden by all subclasses. + + Args: + handler (Cell): cell to be wrapped. + + Returns: + Wrapped cell. + """ + raise NotImplementedError + + def set_input_number(self, input_num: int): + pass + + def set_input_not_insert_fq(self, index: Optional[int] = None): + pass + + def get_input_need_insert_fq(self, index: int) -> bool: + return False + + # only support one-output-quantizer pre layer because we can not get how many outputs a cell would has + def set_output_not_insert_fq(self): + pass diff --git a/mindspore/python/mindspore/golden_stick/quantization/net_policy.py b/mindspore/python/mindspore/golden_stick/quantization/net_policy.py new file mode 100644 index 00000000000..fb0bc318ef2 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/net_policy.py @@ -0,0 +1,48 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""NetworkQConfig.""" +from typing import Optional + +from .layer_policy import LayerPolicy +from .transformer import Transformer + + +class NetPolicy: + """ + Base class for network quantize configure. 
+ + Args: + config (Dict): User config for QAT. Config specification is default by derived class. + + Note: + Derived class must define `_pattern_engines` and `_support_layer_map` in constructor. + """ + + def __init__(self, config=None): + self._pattern_engines: [Transformer] = [] + self._layer_policy_map: dict = {} + self._net_layer_policy: Optional[LayerPolicy] = None + + def get_transformers(self) -> [Transformer]: + return self._pattern_engines + + def get_layer_policy_map(self) -> {str, LayerPolicy}: + return self._layer_policy_map + + def get_layer_policy(self, layer_type) -> Optional[LayerPolicy]: + return self._layer_policy_map.get(type(layer_type)) + + def get_net_layer_policy(self) -> Optional[LayerPolicy]: + return self._net_layer_policy diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize.py b/mindspore/python/mindspore/golden_stick/quantization/quantize.py new file mode 100644 index 00000000000..b8a401a1ba1 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/quantize.py @@ -0,0 +1,126 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Quantize.""" +import copy +from typing import Optional + +from .net_policy import NetPolicy +from .layer_policy import LayerPolicy, layer_policy_key +from .transformer import Transformer +from ..golden_stick import GoldenStick +from ..net_transform import NetTransformer +from mindspore.rewrite import Node +from mindspore.nn import Cell + + +class QuantAwareTraining(GoldenStick): + """ + Derived class of GoldenStick. Default QAT-algorithm. 
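+
+    Examples:
+        >>> # A usage sketch; `MyQAT` is a hypothetical subclass that fills in
+        >>> # `_qat_policy` with a concrete NetPolicy, and `net` is a user Cell.
+        >>> qat = MyQAT(config={})
+        >>> quant_net = qat.apply(net)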
+ """ + + def __init__(self, config: {}): + super(QuantAwareTraining, self).__init__(config) + self._qat_policy = None + self._custom_transforms = None + self._custom_layer_policy_map = None + + def _propagate_layer_policy(self, nodes: [Node]): + # step1 apply net layer-policy first + net_layer_policy: Optional[LayerPolicy] = self._qat_policy.get_net_layer_policy() + if net_layer_policy: + for node in nodes: + NetTransformer.set_node_attr(node, layer_policy_key, copy.copy(net_layer_policy)) + # todo subgraph + # step2 then apply layer-policy map, override policy if need + layer_policy_map = self._qat_policy.get_layer_policy_map() + for node in nodes: + layer_policy: LayerPolicy = self._custom_layer_policy_map.get(node.type) + if layer_policy is None: + layer_policy = layer_policy_map.get(node.type) + if isinstance(layer_policy, LayerPolicy): + new_layer_policy = copy.copy(layer_policy) + new_layer_policy.set_input_number(len(node.inputs)) + NetTransformer.set_node_attr(node, layer_policy_key, new_layer_policy) + + @staticmethod + def _reduce_redundant_fake_quant(nodes: [Node]): + for node in nodes: + cur_policy: LayerPolicy = NetTransformer.get_node_attr(node, layer_policy_key) + # cur-node has no quant policy, so no fq will insert into its inputs + if cur_policy is None: + continue + input_nodes = node.inputs + for i in range(0, len(input_nodes)): + cur_in_quantizer = cur_policy.get_input_quantizer() + # cur-node's input quantizer is None, so no fq will insert into its inputs + if cur_in_quantizer is None: + continue + input_node: Node = input_nodes[i] + pre_policy: LayerPolicy = NetTransformer.get_node_attr(input_node, layer_policy_key) + # pre-node has no quant policy, so no fq will insert into its outputs + if pre_policy is None: + continue + output_nodes_of_input_node = input_node.outputs + for j in range(0, len(output_nodes_of_input_node)): + output_node_of_input_node = output_nodes_of_input_node[j] + if output_node_of_input_node is not node: + continue + pre_out_quantizer = pre_policy.get_output_quantizer() + # pre-node's output quantizer is None, so no fq will insert into its outputs + # or input fq of cur-node and output fq of pre-node are different + if type(pre_out_quantizer) is not type(cur_in_quantizer): + continue + # input fq of cur-node and output fq of pre-node are same type + # so we mark input fq of cur-node as redundant + cur_policy.set_input_not_insert_fq(i) + + def _apply_fuse_patterns(self, net_transformer: NetTransformer): + transformers = self._qat_policy.get_transformers() + if isinstance(self._custom_transforms, list): + for transform in self._custom_transforms: + if isinstance(transform, Transformer): + transformers.append(transform) + for transformer in transformers: + # Transformer always return False + # todo test overlap between transformers + net_transformer.pattern_transform(transformer) + + @staticmethod + def _apply_layer_policy(nodes: [Node], net_transformer: NetTransformer): + for node in nodes: + layer_policy = NetTransformer.get_node_attr(node, layer_policy_key) + if isinstance(layer_policy, LayerPolicy): + net_transformer.replace_node(node, layer_policy.wrap_cell(node.cell())) + + def apply(self, network: Cell) -> Cell: + """ + Apply QAT-Algorithm on `graph` + + Args: + network (Cell): Network to be quantized. + + Returns: + Quantized network. 
+ """ + + if not isinstance(self._qat_policy, NetPolicy): + raise RuntimeError("Derived class should provide net policy") + net_transformer = NetTransformer(network) + nodes = net_transformer.nodes() + self._propagate_layer_policy(nodes) + QuantAwareTraining._reduce_redundant_fake_quant(nodes) + self._apply_fuse_patterns(net_transformer) + QuantAwareTraining._apply_layer_policy(nodes, net_transformer) + return net_transformer.get_transformed() diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_act.py b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_act.py new file mode 100644 index 00000000000..e2a0e2ae1a7 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_act.py @@ -0,0 +1,45 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""QuantizeWrapperActivation.""" + +from mindspore.nn import Cell +from .quantizer import Quantizer + + +class QuantizeWrapperActivation(Cell): + """ + Derive from Cell for define how to construct a wrap quant-cell from a normal cell with fake-quant algorithm. + + Args: + act (Cell): normal cell to be wrapped. + pre_quantizer (Quantizer): Define how weight data to be fake-quant. + post_quantizer (Quantizer): Define how activation data to be fake-quant. + """ + + def __init__(self, act: Cell, pre_quantizer: Quantizer = None, post_quantizer: Quantizer = None): + super().__init__() + self._handler: callable = act + self._pre_quantizer = pre_quantizer + self._post_quantizer = post_quantizer + + def construct(self, x): + if self._pre_quantizer is not None: + quant_param = self._pre_quantizer.compute_quant_param(x) + x = self._pre_quantizer.fake_quant(x, quant_param) + x = self._handler(x) + if self._post_quantizer is not None: + quant_param = self._post_quantizer.compute_quant_param(x) + x = self._post_quantizer.fake_quant(x, quant_param) + return x diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py new file mode 100644 index 00000000000..5afede41f2a --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py @@ -0,0 +1,105 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
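QuantizeWrapperActivation simply brackets the wrapped activation with an optional fake-quant on each side. A runnable sketch with a deliberately trivial Quantizer subclass; IdentityQuantizer is illustrative only and injects no quantization error, and the import paths mirror the file paths in this patch:

    from mindspore.nn import ReLU
    from mindspore.golden_stick.quantization.quantizer import Quantizer
    from mindspore.golden_stick.quantization.quantize_wrapper_act import QuantizeWrapperActivation

    class IdentityQuantizer(Quantizer):
        def compute_quant_param(self, float_data):
            return {}  # a real quantizer returns min/max/scale/zero-point here

        def fake_quant(self, float_data, quant_params, **kwargs):
            return float_data  # pass-through instead of round/clamp

    wrapped = QuantizeWrapperActivation(ReLU(),
                                        pre_quantizer=IdentityQuantizer(),
                                        post_quantizer=IdentityQuantizer())
    # construct() now runs: pre fake-quant -> ReLU -> post fake-quant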
+# ============================================================================ +"""QuantizeWrapperCell.""" + +from mindspore.nn import Cell +from .quantizer import Quantizer +from .layer_policy import LayerPolicy +from .quantize_wrapper_act import QuantizeWrapperActivation + + +class QuantizeWrapperCell(Cell): + """ + Derive from Cell for define how to construct a wrap quant-cell from a normal cell with fake-quant algorithm. + + Args: + handler (Cell): normal cell to be wrapped. + layer_policy (Quantizer): Define how weight data to be fake-quant. + """ + + def __init__(self, handler: Cell, layer_policy: LayerPolicy): + super().__init__() + self._handler: Cell = handler + self._policy = layer_policy + self._w_scale = 1.0 + self._w_zp = 0 + self._o_scale = 1.0 + self._o_zp = 0 + + def construct(self, *inputs, **kwargs): + """ + Defines the computation of QuantizeWrapperCell to be performed. + + Returns: + Tensor, returns the computed result. + """ + + # fake-quant weight + for weight_name, quantizer in self._policy.get_weight_name_and_quantizers(): + assert weight_name is not None + assert quantizer is not None + weight = getattr(self._handler, weight_name) + quant_param = quantizer.compute_quant_param(weight) + fq_data = quantizer.fake_quant(weight, quant_param) + setattr(self._handler, weight_name, fq_data) + + # fake-quant activation + for act_name, quantizers in self._policy.get_act_name_and_quantizers(): + assert act_name is not None + if quantizers is None: + continue + pre_quantizer, post_quantizer = quantizers + activation = getattr(self._handler, act_name) + quant_act = QuantizeWrapperActivation(activation, pre_quantizer, post_quantizer) + setattr(self._handler, act_name, quant_act) + + # fake-quant input + input_quantizer = self._policy.get_input_quantizer() + if input_quantizer is None: + fq_inputs = inputs + else: + fq_inputs = [] + input_len = len(inputs) + for i in range(0, input_len): + ori_input = inputs[i] + if self._policy.get_input_need_insert_fq(i): + quant_param = input_quantizer.compute_quant_param(ori_input) + fq_inputs.append(input_quantizer.fake_quant(ori_input, quant_param)) + else: + fq_inputs.append(ori_input) + + # forward handler + outputs = self._handler(*fq_inputs, **kwargs) + + # fake-quant output + output_quantizer = self._policy.get_output_quantizer() + if output_quantizer is None: + return outputs + if isinstance(outputs, list) or isinstance(outputs, tuple): + raise RuntimeError("Only support single output tensor fake-quant now") + output_len = len(outputs) + if output_len == 0: + return outputs + elif output_len == 1: + quant_param = output_quantizer.compute_quant_param(outputs) + fq_data = output_quantizer.fake_quant(outputs, quant_param) + return fq_data + else: + fq_outputs = [] + for i in range(0, output_len): + ori_output = outputs[i] + quant_param = output_quantizer.compute_quant_param(ori_output) + fq_outputs.append(output_quantizer.fake_quant(ori_output, quant_param)) + return fq_outputs diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantizer.py b/mindspore/python/mindspore/golden_stick/quantization/quantizer.py new file mode 100644 index 00000000000..8c9e6e4202e --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/quantizer.py @@ -0,0 +1,49 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
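For reference, wrapping a single layer by hand looks like the sketch below; in practice QuantAwareTraining performs this through LayerPolicy.wrap_cell during apply(). The DefaultLayerPolicy constructor arguments follow test_common.py later in this patch, and the Conv2d parameters are arbitrary:

    from mindspore.nn import Conv2d
    from mindspore.golden_stick import DefaultLayerPolicy
    from mindspore.golden_stick.quantization.quantize_wrapper_cell import QuantizeWrapperCell

    policy = DefaultLayerPolicy(["weight"], [])  # fake-quant the "weight" parameter
    policy.set_input_number(1)                   # track one positional input
    wrapper = QuantizeWrapperCell(Conv2d(16, 16, 3), policy)
    # construct() order: weight fake-quant -> activation wrapping ->
    # input fake-quant -> conv forward -> output fake-quant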
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Quantizer.""" + + +class Quantizer: + def __init__(self): + pass + + def compute_quant_param(self, float_data: [float]) -> {}: + """ + Compute quant-params such as min/max/scale/zero-point according to input `data`. + This method must be overridden by all subclasses. + + Args: + float_data (List[float]): input data for quant-params. + + Returns: + a dictionary as quant-params + """ + + pass + + def fake_quant(self, float_data: [float], quant_params: dict, **kwargs) -> [float]: + """ + FakeQuant input `float-data` according to quant_params and other args. + This method must be overridden by all subclasses. + + Args: + float_data (List[float]): input data to be fake-quantize. + quant_params (dict): quant-params of input data. + + Returns: + FakeQuantized data. + """ + + pass diff --git a/mindspore/python/mindspore/golden_stick/quantization/test_common.py b/mindspore/python/mindspore/golden_stick/quantization/test_common.py new file mode 100644 index 00000000000..c6bb330c882 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/test_common.py @@ -0,0 +1,51 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
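A concrete subclass makes the contract above clearer. The asymmetric per-tensor uint8 scheme below is an assumption chosen for illustration, not part of this patch:

    from mindspore.golden_stick.quantization.quantizer import Quantizer

    class MinMaxQuantizer(Quantizer):
        def compute_quant_param(self, float_data):
            lo, hi = min(float_data), max(float_data)
            scale = (hi - lo) / 255.0 if hi > lo else 1.0
            zero_point = round(-lo / scale)
            return {"min": lo, "max": hi, "scale": scale, "zero_point": zero_point}

        def fake_quant(self, float_data, quant_params, **kwargs):
            scale = quant_params["scale"]
            zp = quant_params["zero_point"]
            # quantize to [0, 255], then dequantize to simulate precision loss
            return [(max(0, min(255, round(x / scale + zp))) - zp) * scale
                    for x in float_data]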
+# ============================================================================ +"""test common method.""" + +from mindspore.golden_stick import LayerPolicy, DefaultLayerPolicy +from mindspore.golden_stick.quantization.layer_policy import layer_policy_key +from mindspore.rewrite import Node, PlaceholderNode +from mindspore.nn import Conv2d, BatchNorm2d, MaxPool2d + + +class TestCommon: + @staticmethod + def create_layer_policy(input_num, weight_names: []) -> LayerPolicy: + layer_policy = DefaultLayerPolicy(weight_names, []) + layer_policy.set_input_number(input_num) + return layer_policy + + @staticmethod + def create_conv_layer(name, inputs): + conv = Node(name=name, inputs=inputs, instance=Conv2d(16, 16, 9)) + conv.set_attribute(layer_policy_key, TestCommon.create_layer_policy(1, ["weight"])) + return conv + + @staticmethod + def create_bn_layer(name, inputs): + bn = Node(name=name, inputs=inputs, instance=BatchNorm2d(16)) + bn.set_attribute(layer_policy_key, TestCommon.create_layer_policy(1, ["gamma"])) + return bn + + @staticmethod + def create_pool_layer(name, inputs): + pool = Node(name=name, inputs=inputs, instance=MaxPool2d()) + pool.set_attribute(layer_policy_key, TestCommon.create_layer_policy(1, [])) + return pool + + @staticmethod + def create_placeholder_layer(): + placeholder = PlaceholderNode("placeholder") + return placeholder diff --git a/mindspore/python/mindspore/golden_stick/quantization/transformer.py b/mindspore/python/mindspore/golden_stick/quantization/transformer.py new file mode 100644 index 00000000000..307adaedab8 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/transformer.py @@ -0,0 +1,100 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Transformer.""" +from typing import List, Union +from mindspore.rewrite import PatternEngine, PatternNode, Graph, Node +from .layer_policy import LayerPolicy, layer_policy_key + + +# Only support for fusion currently +class Transformer(PatternEngine): + def __init__(self, pattern: Union[PatternNode, List]): + super().__init__(pattern, None) + self._node_visited_key: str = "is_node_visited" + + def _is_node_visited(self, node: Node) -> bool: + assert node is not None + return node.get_attribute(self._node_visited_key) + + def _set_node_visited(self, node: Node): + assert node is not None + node.set_attribute(self._node_visited_key, True) + + def apply(self, graph: Graph) -> bool: + root: Node = graph.root() + # IR match + queue: [Node] = [root] + while len(queue) > 0: + cur_node: Node = queue.pop(0) + cur_node_inputs = cur_node.inputs + matched, matched_dict = self._match(self._pattern, cur_node) + if not matched or not PatternEngine._check_match(self._pattern, matched_dict): + for cur_node_input in cur_node_inputs: + queue.append(cur_node_input) + continue + matched_list = list(matched_dict.values()) + overlapped = False + for matched_node in matched_list: + if self._is_node_visited(matched_node): + overlapped = True + break + if overlapped: + for cur_node_input in cur_node_inputs: + queue.append(cur_node_input) + continue + for matched_node in matched_list: + self._set_node_visited(matched_node) + + # modify layer policy + # find input output + inputs_of_matched: [Node] = [] + output = matched_dict.get(self._pattern.name()) + inputs: [] = [] + for matched_node in matched_list: + node_inputs = matched_node.inputs + is_input_node = False + for node_input in node_inputs: + if node_input in matched_dict.values(): + continue + inputs_of_matched.append(node_input) + is_input_node = True + if is_input_node: + inputs.append(matched_node) + # remove inter-matched-node-policy + for matched_node in matched_list: + node_policy: LayerPolicy = matched_node.get_attribute(layer_policy_key) + if node_policy is None: + continue + is_input = matched_node in inputs + is_output = matched_node == output + if not is_input and not is_output: + node_policy.set_input_not_insert_fq() + node_policy.set_output_not_insert_fq() + continue + if is_input and not is_output: + node_policy.set_output_not_insert_fq() + continue + if is_output and not is_input: + node_policy.set_input_not_insert_fq() + continue + for i in range(0, len(matched_node.inputs)): + node_input = matched_node.inputs[i] + if node_input in inputs_of_matched: + continue + node_policy.set_input_not_insert_fq(i) + + for input_of_matched in inputs_of_matched: + queue.append(input_of_matched) + return False diff --git a/mindspore/python/mindspore/golden_stick/quantization/transformer_test.py b/mindspore/python/mindspore/golden_stick/quantization/transformer_test.py new file mode 100644 index 00000000000..ea14f184798 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/transformer_test.py @@ -0,0 +1,219 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
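The net effect of apply() on a matched pair is easiest to see on a conv->bn chain. The sketch below mirrors test_apply in the next file and assumes `graph` was built the same way as in those tests:

    from mindspore.nn import Conv2d, BatchNorm2d
    from mindspore.golden_stick import Transformer

    fusion = Transformer([Conv2d, BatchNorm2d])
    fusion.apply(graph)   # graph: a parsed mindspore.rewrite.Graph
    # after apply():
    #   conv: input fake-quant kept (pattern input), output fake-quant suppressed
    #   bn:   input fake-quant suppressed, output fake-quant kept (pattern output)

Only one fake-quant boundary therefore survives around each fused pair, which is exactly the redundancy that _reduce_redundant_fake_quant in quantize.py would otherwise have to remove.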
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test Transformer.""" + +import unittest +from test_common import TestCommon +from mindspore.golden_stick import Transformer, LayerPolicy +from mindspore.golden_stick.quantization.layer_policy import layer_policy_key +from mindspore.rewrite import Node, Graph, PlaceholderNode +from mindspore.nn import Cell, Conv2d, BatchNorm2d + + +class TransformerTestCase(unittest.TestCase): + @staticmethod + def network(): + placeholder = TestCommon.create_placeholder_layer() + + conv1 = TestCommon.create_conv_layer("conv1", [placeholder]) + placeholder.outputs = [conv1] + + bn1 = TestCommon.create_bn_layer("bn1", [conv1]) + conv1.outputs = [bn1] + + pool1 = TestCommon.create_pool_layer("pool1", [bn1]) + bn1.outputs = [pool1] + + conv2 = TestCommon.create_conv_layer("conv2", [pool1]) + pool1.outputs = [conv2] + + bn2 = TestCommon.create_bn_layer("bn2", [conv2]) + conv2.outputs = [bn2] + + pool2 = TestCommon.create_pool_layer("pool2", [bn2]) + bn2.outputs = [pool2] + + graph = Graph(Cell) + graph.set_root(pool2) + return graph + + @staticmethod + def network_intra_overlapped(): + placeholder = TestCommon.create_placeholder_layer() + + conv1 = TestCommon.create_conv_layer("conv1", [placeholder]) + placeholder.outputs = [conv1] + + conv2 = TestCommon.create_conv_layer("conv2", [conv1]) + conv1.outputs = [conv2] + + conv3 = TestCommon.create_conv_layer("conv3", [conv2]) + conv2.outputs = [conv3] + + bn = TestCommon.create_bn_layer("bn", [conv3]) + conv3.outputs = [bn] + + pool = TestCommon.create_pool_layer("pool", [bn]) + bn.outputs = [pool] + + graph = Graph(Cell) + graph.set_root(pool) + return graph + + @staticmethod + def network_inter_overlapped(): + placeholder = TestCommon.create_placeholder_layer() + + conv1 = TestCommon.create_conv_layer("conv1", [placeholder]) + placeholder.outputs = [conv1] + + conv2 = TestCommon.create_conv_layer("conv2", [conv1]) + conv1.outputs = [conv2] + + bn = TestCommon.create_bn_layer("bn", [conv2]) + conv2.outputs = [bn] + + pool = TestCommon.create_pool_layer("pool", [bn]) + bn.outputs = [pool] + + graph = Graph(Cell) + graph.set_root(pool) + return graph + + @staticmethod + def _get_node_inout_fq(node: Node, is_input: bool = True): + policy: LayerPolicy = node.get_attribute(layer_policy_key) + if policy is None: + return None + if is_input: + fq = policy.get_input_quantizer() + else: + fq = policy.get_output_quantizer() + if not is_input: + return fq + if fq is None or not policy.get_input_need_insert_fq(0): + return None + else: + return fq + + @staticmethod + def _get_node_of_graph_inout_fq(graph: Graph, node_index: int, is_input: bool = True): + if node_index >= len(graph._nodes): + return None + node = graph._nodes[node_index] + if node is None: + return None + return TransformerTestCase._get_node_inout_fq(node, is_input) + + def test_apply(self): + transformer: Transformer = Transformer([Conv2d, BatchNorm2d]) + graph: Graph = TransformerTestCase.network() + conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True) + conv1_out_fq = 
TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False) + bn1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True) + bn1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False) + conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 4, True) + conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 4, False) + bn2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 5, True) + bn2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 5, False) + self.assertNotEqual(conv1_in_fq, None) + self.assertNotEqual(conv1_out_fq, None) + self.assertNotEqual(bn1_in_fq, None) + self.assertNotEqual(bn1_out_fq, None) + self.assertNotEqual(conv2_in_fq, None) + self.assertNotEqual(conv2_out_fq, None) + self.assertNotEqual(bn2_in_fq, None) + self.assertNotEqual(bn2_out_fq, None) + transformer.apply(graph) + conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True) + conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False) + bn1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True) + bn1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False) + conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 4, True) + conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 4, False) + bn2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 5, True) + bn2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 5, False) + self.assertNotEqual(conv1_in_fq, None) + self.assertEqual(conv1_out_fq, None) + self.assertEqual(bn1_in_fq, None) + self.assertNotEqual(bn1_out_fq, None) + self.assertNotEqual(conv2_in_fq, None) + self.assertEqual(conv2_out_fq, None) + self.assertEqual(bn2_in_fq, None) + self.assertNotEqual(bn2_out_fq, None) + + def test_intra_overlap(self): + transformer: Transformer = Transformer([Conv2d, Conv2d]) + graph: Graph = TransformerTestCase.network_intra_overlapped() + conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True) + conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False) + conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True) + conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False) + conv3_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, True) + conv3_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, False) + self.assertNotEqual(conv1_in_fq, None) + self.assertNotEqual(conv1_out_fq, None) + self.assertNotEqual(conv2_in_fq, None) + self.assertNotEqual(conv2_out_fq, None) + self.assertNotEqual(conv3_in_fq, None) + self.assertNotEqual(conv3_out_fq, None) + transformer.apply(graph) + conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True) + conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False) + conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True) + conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False) + conv3_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, True) + conv3_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, False) + self.assertNotEqual(conv1_in_fq, None) + self.assertNotEqual(conv1_out_fq, None) + self.assertNotEqual(conv2_in_fq, None) + self.assertEqual(conv2_out_fq, None) + self.assertEqual(conv3_in_fq, None) + self.assertNotEqual(conv3_out_fq, None) + + def test_inter_overlap(self): + transformer1: 
Transformer = Transformer([Conv2d, Conv2d])
+        transformer2: Transformer = Transformer([Conv2d, BatchNorm2d])
+        graph: Graph = TransformerTestCase.network_inter_overlapped()
+        conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True)
+        conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False)
+        conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True)
+        conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False)
+        bn_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, True)
+        bn_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, False)
+        self.assertNotEqual(conv1_in_fq, None)
+        self.assertNotEqual(conv1_out_fq, None)
+        self.assertNotEqual(conv2_in_fq, None)
+        self.assertNotEqual(conv2_out_fq, None)
+        self.assertNotEqual(bn_in_fq, None)
+        self.assertNotEqual(bn_out_fq, None)
+        transformer1.apply(graph)
+        transformer2.apply(graph)
+        conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True)
+        conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False)
+        conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True)
+        conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False)
+        bn_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, True)
+        bn_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, False)
+        self.assertNotEqual(conv1_in_fq, None)
+        self.assertEqual(conv1_out_fq, None)
+        self.assertEqual(conv2_in_fq, None)
+        self.assertNotEqual(conv2_out_fq, None)
+        self.assertNotEqual(bn_in_fq, None)
+        self.assertNotEqual(bn_out_fq, None)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/mindspore/python/mindspore/rewrite/__init__.py b/mindspore/python/mindspore/rewrite/__init__.py
index 2a3665186ba..1ace088ec52 100644
--- a/mindspore/python/mindspore/rewrite/__init__.py
+++ b/mindspore/python/mindspore/rewrite/__init__.py
@@ -1,5 +1,5 @@
 from .graph import Graph
-from .node import Node, NodeType
+from .node import Node, NodeType, PlaceholderNode
 from .pattern_engine import PatternEngine, PatternNode, PlaceHolderNode
 
-__all__ = ["Graph", "Node", "NodeType", "PatternEngine", "PatternNode", "PlaceHolderNode"]
+__all__ = ["Graph", "Node", "NodeType", "PatternEngine", "PatternNode", "PlaceHolderNode", "PlaceholderNode"]
diff --git a/mindspore/python/mindspore/rewrite/node.py b/mindspore/python/mindspore/rewrite/node.py
index 57f9be0edf0..6ad4e667dfa 100644
--- a/mindspore/python/mindspore/rewrite/node.py
+++ b/mindspore/python/mindspore/rewrite/node.py
@@ -19,7 +19,10 @@ class AttributeNode:
         self._type: NodeType = type
         self._class = class_
         self._is_custom_define = is_custom_define
-        self._attribute: Dict = attibute
+        if isinstance(attibute, dict):
+            self._attribute: dict = attibute
+        else:
+            self._attribute: dict = dict()
         self._constant_value = constant_value
 
     @property
@@ -83,6 +86,24 @@ class BaseNode:
     def node_type(self) -> NodeType:
         return self._attribute.type
 
+    @property
+    def attribute(self) -> AttributeNode:
+        return self._attribute
+
+    @attribute.setter
+    def attribute(self, attribute: AttributeNode):
+        self._attribute = attribute
+
+    def set_attributes(self, attributes: Dict):
+        for key, value in attributes.items():
+            self._attribute._attribute[key] = value
+
+    def set_attribute(self, key: str, value):
+        self._attribute._attribute[key] = value
+
+    def get_attribute(self, key: str):
+        return self._attribute._attribute.get(key)
+
 
 class Node(BaseNode):
     def __init__(self,
name="", targets=None, args=None, ast_node=None, instance: Union[nn.Cell, Primitive] = None, inputs: List = None): """ @@ -96,18 +117,6 @@ class Node(BaseNode): self._index = 0 self._attribute._class = type(instance) - @property - def attribute(self) -> AttributeNode: - return self._attribute - - @attribute.setter - def attribute(self, attribute: AttributeNode): - self._attribute = attribute - - def add_attribute(self, attribute: Dict): - for key, value in attribute.items(): - self._attribute._attribute[key] = value - @property def type(self): return self._attribute._class diff --git a/mindspore/python/mindspore/rewrite/pattern_engine.py b/mindspore/python/mindspore/rewrite/pattern_engine.py index 70dbcaaf21f..f1db00b748a 100644 --- a/mindspore/python/mindspore/rewrite/pattern_engine.py +++ b/mindspore/python/mindspore/rewrite/pattern_engine.py @@ -246,13 +246,13 @@ class PatternEngine: if self._replacement is None: return matched_nodes[len(matched_nodes) - 1] - instance = self._replacement(*matched_nodes) - if instance is None: + replacement = self._replacement(*matched_nodes) + if replacement is None: return None if len(matched_nodes) == 0: - new_node = Node(instance=instance) + new_node = Node(instance=replacement) else: - new_node = Node(instance=instance, inputs=matched_nodes[0].inputs) + new_node = Node(instance=replacement, inputs=matched_nodes[0].inputs) node_name = "" for matched_node in matched_nodes: node_name += matched_node.name + "_" -- Gitee From c784239dc4fbf080094077840782ffac78c8e313 Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Mon, 20 Dec 2021 09:23:46 +0800 Subject: [PATCH 04/34] update rewrite --- .../python/mindspore/rewrite/globals.log | 1 - mindspore/python/mindspore/rewrite/graph.py | 22 ++- mindspore/python/mindspore/rewrite/node.py | 65 +++++-- mindspore/python/mindspore/rewrite/parser.py | 170 +++++++++--------- .../python/mindspore/rewrite/rewriter.py | 29 +-- 5 files changed, 163 insertions(+), 124 deletions(-) delete mode 100644 mindspore/python/mindspore/rewrite/globals.log diff --git a/mindspore/python/mindspore/rewrite/globals.log b/mindspore/python/mindspore/rewrite/globals.log deleted file mode 100644 index 000728d2a72..00000000000 --- a/mindspore/python/mindspore/rewrite/globals.log +++ /dev/null @@ -1 +0,0 @@ -input globals: {'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <_frozen_importlib_external.SourceFileLoader object at 0x7f6a9a4f04d0>, '__spec__': None, '__annotations__': {}, '__builtins__': , '__file__': 'test_app.py', '__cached__': None, 'Graph': , 'parse': , 'ControlSimpleIf': } diff --git a/mindspore/python/mindspore/rewrite/graph.py b/mindspore/python/mindspore/rewrite/graph.py index d08e8965aa4..9d90ffeacfb 100644 --- a/mindspore/python/mindspore/rewrite/graph.py +++ b/mindspore/python/mindspore/rewrite/graph.py @@ -8,6 +8,7 @@ import astunparse import astpretty import mindspore.nn as nn +from mindspore import log as logger from mindspore.ops.primitive import Primitive from .node import AttributeNode, ConstantNode, Node, NodeType, PlaceholderNode @@ -139,7 +140,7 @@ class Graph(): """ if "__init__" not in self._network.__dict__.keys(): return - print("================= parse init function ========================") + print("================= parse " + self._base_scope + " init function start ========================") self.create_placeholder(self._ast_function_root["__init__"]) self._parser.updete_closure_namespace(self._network.__init__) for ast_node in self._ast_function_root["__init__"].body: @@ -148,13 
+149,16 @@ class Graph(): if isinstance(ast_node, ast.Assign): new_node = self._parser.parse_init_assign(ast_node) - print("new node in init function", new_node) + #print("new node in init function", new_node) self._node_attributes[new_node.name] = new_node - + print("init function node: ") + for node in self._node_attributes: + print(node) + print("================= parse " + self._base_scope + " init function end ========================") return def parse_construct(self): - print("================= parse construct function start ========================") + print("================= parse " + self._base_scope + " construct function start ========================") self._parser.updete_closure_namespace(self._network.construct) self.create_placeholder(self._ast_function_root["construct"]) name_counts = {} #save the number of the variable, if the number is over 1,then modify the name - add a number as the name suffix @@ -201,10 +205,10 @@ class Graph(): self._find_input_node(nodes[i]) index += 1 self._nodes.append(nodes[i]) - print("======================= construct nodes =========================") + print("construct nodes: ") for node in self._nodes: print(node) - print("================= parse construct function end ========================") + print("================= parse " + self._base_scope + " construct function end ========================") def parse_function(self, func: Union[ast.FunctionDef, FunctionType]): if isinstance(func, FunctionType): @@ -213,11 +217,11 @@ class Graph(): ast_root = ast.parse(function_str) astpretty.pprint (ast_root) node = ast_root.body[0] - subgraph = FunctionGraph(func) #要区分类内还是类外方法 + subgraph = Graph(func) #要区分类内还是类外方法 else: print("================= parse " + func.name + " function ========================") node = func - subgraph = FunctionGraph(self._network.__dict__[node.name]) #要区分类内还是类外方法 + subgraph = Graph(self._network.__dict__[node.name]) #要区分类内还是类外方法 subgraph._name = node.name subgraph._ast_root = node subgraph.create_placeholder(node) @@ -306,7 +310,7 @@ class Graph(): self._subgraphs["self." + name] = subgraph elif node._is_custom_define and issubclass(node._class, (nn.Cell)): print ("this node is Cell class") - graph = CellGraph(node._class) + graph = Graph(node._class) graph.create_ast() graph.print_ast() graph.get_function_root() diff --git a/mindspore/python/mindspore/rewrite/node.py b/mindspore/python/mindspore/rewrite/node.py index 6ad4e667dfa..2402287c211 100644 --- a/mindspore/python/mindspore/rewrite/node.py +++ b/mindspore/python/mindspore/rewrite/node.py @@ -1,3 +1,17 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
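The parsing entry points above are normally driven in a fixed order (rewriter.parse later in this series follows the same sequence). A sketch of that driver loop, where LeNet stands in for any nn.Cell subclass whose source inspect can reach, and its import is hypothetical:

    from mindspore.rewrite import Graph
    from lenet import LeNet  # hypothetical import

    graph = Graph(LeNet)
    graph.create_ast()           # ast.parse over inspect.getsource(network)
    graph.get_function_root()    # index FunctionDef roots: __init__, construct, ...
    graph.parse_init()           # collect per-op attributes from __init__
    graph.parse_attr_subgraph()  # recurse into custom Cells created in __init__
    graph.parse_functions()      # parse the remaining methods
    graph.parse_construct()      # build dataflow nodes from construct
    for node in graph.bfs():     # breadth-first iteration from the root
        print(node)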
+# ============================================================================ import ast from typing import Dict, List, Union @@ -5,15 +19,27 @@ import mindspore.nn as nn from mindspore.ops.primitive import Primitive class NodeType(): - placeholder = 1 # 输入 - parameter = 2 # 权重 - constant = 3 # 常量 - call_cell = 4 # 预计cell的对象 - call_method = 5 # cell内部成员 - call_function = 6 # 基于primitive对象 - output = 7 # 输出 + placeholder = 1 # input + parameter = 2 # weight + constant = 3 + call_cell = 4 # call cell object + call_method = 5 # method in cell + call_function = 6 # subclass of primitive + output = 7 + invalid = 8 class AttributeNode: + """ + Save Node attribute + + Args: + name: the name of node, it equals to the variable name in __init__ function + type: + class_: + is_custom_define: + attribute: it is a Dict that save the attribute name and value + constant_value: if the node is constant, it save the value of the node + """ def __init__(self, name="", type=NodeType.call_cell, class_=None, is_custom_define=False, attibute=None, constant_value=None) -> None: self._name = name self._type: NodeType = type @@ -45,6 +71,15 @@ class AttributeNode: return f"name: {self.name}; type: {self.type}; class: {self._class}; attribute: {self._attribute}; is_custom_define: {self._is_custom_define}; constant value: {self._constant_value}" class BaseNode: + """ + Base class of node. + + Args: + name: the name of node. + targets: the output names of the node. + args: the input names of the node. + inputs: the input nodes of this node. + """ def __init__(self, name="", targets=None, args=None, inputs: List = None): """ 创建一个节点时对应的属性怎么传进来,cell应该不涉及,primitive会有这种情况 @@ -136,19 +171,18 @@ class Node(BaseNode): output_nodes = "" for n in self.inputs: - input_names += n.name + " " + input_names += n.name + ", " # input_nodes += str(n) for n in self.outputs: - output_names += n.name + " " + output_names += n.name + ", " #output_nodes += str(n) - #attr_info = "attr name: " + self._attribute.name + "; class" + str(self.attribute._class) + "; is coustum defined: " + str(self.attribute._is_custom_define) return f"name: {self._name}; ast_node: {self._ast_node}; scope: {self._scope}; index: {self._index}; inputs: {len(self.inputs)}; input names: {input_names}; outputs: {len(self.outputs)}; output names: {output_names}; attr info: {self._attribute}" class ConstantNode(BaseNode): - def __init__(self, value): - super().__init__(str(value)) + def __init__(self, name="constant", value=None): + super().__init__(name=name, args=[], targets=[]) #self._name = str(value) self._value = value self._args.append(value) @@ -157,18 +191,15 @@ class ConstantNode(BaseNode): def __repr__(self) -> str: output_names = "" for n in self.outputs: - output_names += n.name + " " - return f"name: {self._name}; value: {self._value}; ast_node: {self._ast_node}; index: {self._index}; outputs: {len(self.outputs)}; output names: {output_names}" + output_names += n.name + ", " + return f"name: {self._name}; value: {self._value}; outputs: {len(self.outputs)}; output names: {output_names}" class PlaceholderNode(BaseNode): def __init__(self, name, targets=None, ast_node=None, default_value=None): super().__init__(name, targets) - #self._name = "" - #self._target = target self._ast_node = ast_node self._default_value = default_value self._attribute.type = NodeType.placeholder - #self._outputs: List = [] def __repr__(self) -> str: output_names = "" diff --git a/mindspore/python/mindspore/rewrite/parser.py b/mindspore/python/mindspore/rewrite/parser.py index 
135212cb1bd..d20d7597345 100644 --- a/mindspore/python/mindspore/rewrite/parser.py +++ b/mindspore/python/mindspore/rewrite/parser.py @@ -4,11 +4,12 @@ import inspect from types import FunctionType from typing import Dict, List, Union -from .namespace import CellNamespace, ClosureNamespace, Namespace -from .node import AttributeNode, Node, NodeType - from mindspore.ops.primitive import Primitive import mindspore.nn as nn +from mindspore import log as logger + +from .namespace import CellNamespace, ClosureNamespace, Namespace +from .node import AttributeNode, ConstantNode, Node, NodeType namespace_nodetype_map = { "mindspore.common": NodeType.call_cell, @@ -72,7 +73,10 @@ class Parser(): def _parse_targets(self, node: ast.AST): visitor = self._get_node_visitor(node) - res = visitor(node) + if visitor: + res = visitor(node) + else: + logger.warning("get node visiter failed, node: %r", node) return res def _parse_args(self, node_list: ast.List): @@ -110,9 +114,10 @@ class Parser(): visitor = self._get_node_visitor(node.targets) targets = visitor(node.targets) - print ("targets:", targets) + logger.debug(f"start parse node in __init__ function: {node}") + #print ("targets:", targets) value = node.value - new_node = AttributeNode(name=targets[0]) + new_node = AttributeNode(name=targets[0]) # TODO: deal with multi outputs if isinstance(value, ast.Call): self.parse_init_Call(value, new_node) elif isinstance(value, ast.Name): @@ -124,10 +129,11 @@ class Parser(): else: new_node._constant_value = value.id elif isinstance(value, ast.BinOp): - print("value is BinOp") + logger.debug("value is BinOp") pass else: - print("vaule type: ", type(value), " is not supported") + logger.warning("vaule type: ", type(value), " is not supported") + logger.debug("parse end, result: ", new_node) return new_node def _parse_init_args(self, ast_nodes: ast.List): @@ -173,7 +179,7 @@ class Parser(): def parse_init_Call(self, ast_node: ast.Call, attr_node: AttributeNode): def _update_args_value(args: List, keywords: Dict, parameters: inspect.signature): - print("defaults values:", self._default_values) + logger.debug("defaults values:", self._default_values) new_dict = OrderedDict() for name, para_ in parameters.items(): @@ -199,20 +205,19 @@ class Parser(): else: new_dict[key] = value - print ("new dict: ", new_dict) return new_dict new_dict = OrderedDict() visitor = self._get_node_visitor(ast_node.func) value = visitor(ast_node.func) - print ("node name: ", value.split(".")[-1]) + #print ("node name: ", value.split(".")[-1]) class_name = value.split(".")[-1] class_, name_space, is_custom_define = self.get_func_namesapce(class_name) - print ("class: ", class_) + #print ("class: ", class_) parameters = inspect.signature(class_.__init__).parameters - print ("parameters: ", parameters) + #print ("parameters: ", parameters) if name_space in namespace_nodetype_map: node_type = namespace_nodetype_map[name_space] else: @@ -220,12 +225,14 @@ class Parser(): args = self._parse_init_args(ast_node.args) keywords = self._parse_init_keywords(ast_node.keywords) + logger.debug("before update parameters: ", parameters) new_dict = _update_args_value(args, keywords, parameters) - + logger.debug ("the node: ", new_dict) attr_node._class = class_ attr_node._type = node_type attr_node._is_custom_define = is_custom_define attr_node._attribute = new_dict + return def parse_init_BinOp(self, ast_node: ast.BinOp): @@ -237,39 +244,44 @@ class Parser(): value = self.parse_init_BinOp(node) else: visitor = self._get_node_visitor(node) + if not 
visitor: + logger.warning("get node visitor failed in parse_init_BinOp") + return None value = visitor(node) if value in self._default_values.keys(): value = self._default_values[value] return value + logger.debug("start parse binop: ", ast_node) op = ast_node.op left = ast_node.left right = ast_node.right left_value = _get_value(left) right_value = _get_value(right) - print("left value: ", left_value) - print("right value: ", right_value) + logger.debug("left value: ", left_value) + logger.debug("right value: ", right_value) if (isinstance(left_value, int) and isinstance(right_value, int)) or (str.isdigit(str(left_value)) and str.isdigit(str(right_value))): method = '_calc_' + op.__class__.__name__ calc = getattr(self, method, None) if calc: result = calc(left_value, right_value) else: - print("undefined op", method) + logger.warning("undefined op", method) else: result = str(left_value) + " op.__class__.__name__ " + repr(right_value) return result def parse_Assign(self, node: ast.Assign): + logger.debug("start parse assign node: ", node) lineno = node.lineno nodes = [] called_obj_names = [] visitor = self._get_node_visitor(node.targets) targets = visitor(node.targets) - print("targets: ", targets) + #print("targets: ", targets) value = node.value new_node = Node(targets=targets, ast_node=node) @@ -278,9 +290,10 @@ class Parser(): nodes.extend(nodes_) called_obj_names.extend(called_obj_names_) - print("new node in assign: ", new_node) + #print("new node in assign: ", new_node) assert(len(nodes) == len(called_obj_names)) + logger.debug("parse node end") return nodes, called_obj_names def parse_Call(self, ast_node: ast.Call, node: Node): @@ -301,14 +314,16 @@ class Parser(): nodes.append(node) called_obj_names.extend(called_obj_names_) called_obj_names.append(called_obj_name) - print ("nodes in call:", nodes) - print("called obj name in call: ", called_obj_names) + #print ("nodes in call:", nodes) + #print("called obj name in call: ", called_obj_names) return nodes, called_obj_names def parse_Attribute(self, node: ast.Attribute): visitor = self._get_node_visitor(node.value) - attribute_value = visitor(node.value) + "." + node.attr - + if visitor: + attribute_value = visitor(node.value) + "." + node.attr + else: + logger.warning("get node visitor failed in parse_Attribute") return attribute_value def parse_list(self, node: ast.List) -> list: @@ -348,6 +363,30 @@ class Parser(): pass def parse_BinOp(self, ast_node: ast.BinOp, node: Node): #如果left和right都是Call则需要分别创建节点,同时分析call的args,根据args也创建对应节点 + def _get_value(ast_node: ast.AST, args, nodes, called_obj_names, side): + visitor = self._get_node_visitor(ast_node) + if not visitor: + logger.warning("get node visitor failed in parse_BinOp._get_value, node: ", ast_node) + return + if isinstance(ast_node, ast.Call) or isinstance(ast_node, ast.BinOp): + new_node = Node(args=[], targets=[]) + nodes_, called_obj_names_ = visitor(ast_node, new_node) + new_node._targets.append(side) + nodes.extend(nodes_) + called_obj_names.extend(called_obj_names_) + args.append(side) + assert(len(nodes) == len(called_obj_names)) + elif isinstance(ast_node, ast.Num): + value = visitor(ast_node) + if value in self._default_values.keys(): + value = self._default_values[value] + new_node = ConstantNode(name="constant" + str(value), value=value) + new_node._targets.append(new_node.name) + args.append(new_node.name) + elif isinstance(ast_node, ast.Name): + value = visitor(ast_node) + args.append(value) + return nodes = [] called_obj_names = [] #ops_info = parse_object_map. 
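As a concrete illustration of how these visitors decompose one statement, consider an Assign whose value is a BinOp over a Call and a literal. The expected breakdown below is inferred from the code above, not captured from a run:

    import ast

    # the kind of construct statement parse_Assign/parse_BinOp handle
    stmt = ast.parse("x = self.conv(x) + 1").body[0]
    print(ast.dump(stmt))
    # Expected nodes (assumption): a Node for self.conv(...) bound to a
    # temporary side name such as "left", a ConstantNode named "constant1"
    # for the literal, and a call_function node for ast.Add consuming both.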
@@ -355,34 +394,7 @@ class Parser(): node._attribute._class = ast_node.op #ast node type must convert to mindspore op type node._attribute.node_type = NodeType.call_function args = [] - left = ast_node.left - if isinstance(left, ast.Call): - new_node = Node() - nodes_, called_obj_names_ = self.parse_Call(left, new_node) - new_node._targets.append("tmp") - nodes.extend(nodes_) - called_obj_names.extend(called_obj_names_) - args.append("tmp") - print("left node in BinOp: ", new_node) - assert(len(nodes) == len(called_obj_names)) - else: - visitor = self._get_node_visitor(left) - args.append(visitor(left)) - right = ast_node.right - if isinstance(right, ast.Call): - new_node = Node() - nodes_, called_obj_names_ = self.parse_Call(right, new_node) - new_node._targets.append("tmp") - nodes.extend(nodes_) - called_obj_names.extend(called_obj_names_) - args.append("tmp") - print("right node in BinOp: ", new_node) - assert(len(nodes) == len(called_obj_names)) - else: - visitor = self._get_node_visitor(right) - args.append(visitor(right)) - node._args = args nodes.append(node) called_obj_names.append(node.name) @@ -430,49 +442,39 @@ class Parser(): def parse_Str(self, node: ast.Str): return node.s - def parse_AugAssign(self, node: ast.AugAssign): + def parse_AugAssign(self, ast_node: ast.AugAssign): nodes = [] called_obj_names = [] - new_node = Node() + augAssign_node = Node(args=[], targets=[]) - new_node._targets.append(node.target.id) + augAssign_node._targets.append(ast_node.target.id) + augAssign_node._args.insert(0, ast_node.target.id) - if isinstance(node.value, ast.Call): - nodes_, called_obj_names_ = self.parse_Call(node.value, new_node) - new_node._args.insert(0, node.target.id) + if isinstance(ast_node.value, ast.Call): + new_node = Node(args=[], targets=[]) + nodes_, called_obj_names_ = self.parse_Call(ast_node.value, new_node) + new_node._targets.append("tmp") nodes.extend(nodes_) called_obj_names.extend(called_obj_names_) - elif isinstance(node.value, ast.Name): - new_node.name = node.value.id - new_node._args.append(node.value.id) - new_node._args.insert(0, node.target.id) - nodes.append(new_node) - elif isinstance(node.value, ast.Attribute): - new_node.name = self.parse_Attribute(node.value) - new_node._args.append(new_node.name) - new_node._args.insert(0, node.target.id) - nodes.append(new_node) - elif isinstance(node.value, ast.Num): - new_node.name = node.op.__class__.__name__ - new_node._args.insert(0, node.target.id) - new_node._attribute.name = node.value.__class__.__name__ - new_node._attribute.type = NodeType.constant - visitor = self._get_node_visitor(node.value) - value = visitor(node.value) - print ("augassign value:", value) - new_node._attribute._attribute["value"] = value + augAssign_node._args.append("tmp") + assert(len(nodes) == len(called_obj_names)) + elif isinstance(ast_node.value, ast.Name): + augAssign_node._args.append(ast_node.value.id) + elif isinstance(ast_node.value, ast.Attribute): + augAssign_node.name = self.parse_Attribute(ast_node.value) + augAssign_node._args.append(augAssign_node.name) + elif isinstance(ast_node.value, ast.Num): + new_node = ConstantNode(name="constant" + str(ast_node.value.n), value=ast_node.value.n) + new_node._targets.append(new_node.name) + augAssign_node._args.append(new_node.name) nodes.append(new_node) + called_obj_names.append("constant value") else: - new_node.name = node.target.id - new_node._args.insert(0, node.target.id) - new_node._attribute.name = node.value.__class__.__name__ - new_node._attribute.type = 
NodeType.constant - visitor = self._get_node_visitor(node.value) - value = visitor(node.value) - print ("augassign value:", value) - new_node._attribute._attribute["value"] = value - nodes.append(new_node) + logger.warning("unsupported type, node: ", ast_node) + augAssign_node.name = ast_node.op.__class__.__name__ + nodes.append(augAssign_node) + called_obj_names.append(augAssign_node.name) return nodes, called_obj_names def parse_arguments(self, node: ast.arguments): diff --git a/mindspore/python/mindspore/rewrite/rewriter.py b/mindspore/python/mindspore/rewrite/rewriter.py index 97231fed102..bdbc9a2609b 100644 --- a/mindspore/python/mindspore/rewrite/rewriter.py +++ b/mindspore/python/mindspore/rewrite/rewriter.py @@ -1,23 +1,26 @@ import ast +from types import FunctionType from typing import Dict, Union from .graph import Graph import mindspore.nn as nn from mindspore.ops.primitive import Primitive - def parse(network: Union[nn.Cell, Primitive]) -> Graph: - graph = Graph(network) - - graph.create_ast() - - graph.get_function_root() - print(graph.python_code) - # graph.print_ast() - graph.parse_init() - graph.parse_construct() - # graph.parse_function("__init__") - - # graph.parse_function("construct") + if issubclass(network, nn.Cell): + graph = Graph(network) + graph.create_ast() + print(graph.python_code) + graph.print_ast() + graph.get_function_root() + graph.parse_init() + graph.parse_attr_subgraph() + graph.parse_functions() + graph.parse_construct() + elif isinstance(network, FunctionType): + graph = FunctionGraph(network) + graph.create_placeholder(graph._ast_root.body[0]) + elif isinstance(network, Primitive): + graph = PrimitiveGraph(network) return graph -- Gitee From 3ce231a6acc04bf3203f18e5b91735ef166ed439 Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Mon, 20 Dec 2021 19:09:10 +0800 Subject: [PATCH 05/34] update rewrite --- mindspore/python/mindspore/rewrite/graph.py | 116 ++++++++---- mindspore/python/mindspore/rewrite/node.py | 14 +- mindspore/python/mindspore/rewrite/parser.py | 165 ++++++++++++------ .../python/mindspore/rewrite/rewriter.py | 7 +- 4 files changed, 203 insertions(+), 99 deletions(-) diff --git a/mindspore/python/mindspore/rewrite/graph.py b/mindspore/python/mindspore/rewrite/graph.py index 9d90ffeacfb..f9159a1814c 100644 --- a/mindspore/python/mindspore/rewrite/graph.py +++ b/mindspore/python/mindspore/rewrite/graph.py @@ -49,7 +49,7 @@ class _node_list: raise StopIteration class Graph(): - def __init__(self, network: nn.Cell): + def __init__(self, network: Union[nn.Cell, Primitive, FunctionType]): self._name = network.__name__ self._network = network self._ast_root: ast.AST = None @@ -86,6 +86,9 @@ class Graph(): return self._root def create_ast(self): + """ + Create the ast corresponding to the network. + """ network_str = inspect.getsource(self._network) self._ast_root = ast.parse(network_str) @@ -93,6 +96,9 @@ class Graph(): astpretty.pprint(self._ast_root) def bfs(self) -> _node_list: + """ + Return an iterator to access the node in bfs mode. + """ return _node_list(self) def find(self, full_name_with_scope: str) -> Node: @@ -119,6 +125,7 @@ class Graph(): def get_function_root(self): """ + Get the ast root node for all function in the network. 1.使用walk来遍历所有节点,找到FunctionDef的节点这种方式不能够区分出class,我们只传入一个cell应该不存在多个class的问题 2.直接遍历class的body,该方式效率会高一些 """ @@ -127,20 +134,23 @@ class Graph(): self._ast_function_root[node.name] = node def create_placeholder(self, ast_function: ast.FunctionDef): + """ + Create placeholder for function. 
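The rewritten parse_AugAssign above gives augmented assignments the same shape as ordinary calls. For a statement like `x += 1` the expected result, read off the code rather than captured from a run, is a ConstantNode plus an operator node:

    import ast

    stmt = ast.parse("x += 1").body[0]
    print(ast.dump(stmt))
    # Expected (assumption): ConstantNode "constant1" for the literal, then a
    # node named "Add" with args ["x", "constant1"] and target ["x"].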
+ """ ast_node = ast_function.args args = self._parser.parse_arguments(ast_node) for name, value in args.items(): new_node = PlaceholderNode(name, name, ast_node, default_value=value) - print ("placeholder node:", new_node) + logger.debug ("placeholder node: %r", new_node) self._nodes.append(new_node) def parse_init(self): """ - 解析init函数,获取相关算子的属性信息 + Analyze the '__init__' function in network and get the 'op' attribute. """ if "__init__" not in self._network.__dict__.keys(): return - print("================= parse " + self._base_scope + " init function start ========================") + logger.info(f"parse {self._base_scope} init function start") self.create_placeholder(self._ast_function_root["__init__"]) self._parser.updete_closure_namespace(self._network.__init__) for ast_node in self._ast_function_root["__init__"].body: @@ -149,49 +159,53 @@ class Graph(): if isinstance(ast_node, ast.Assign): new_node = self._parser.parse_init_assign(ast_node) - #print("new node in init function", new_node) self._node_attributes[new_node.name] = new_node - print("init function node: ") + logger.debug("init function node: ") for node in self._node_attributes: - print(node) - print("================= parse " + self._base_scope + " init function end ========================") + logger.debug(node) + logger.info("parse {self._base_scope} init function end") return def parse_construct(self): - print("================= parse " + self._base_scope + " construct function start ========================") + """ + Analyze the 'construct' function in network. + """ + logger.info(f"parse {self._base_scope} construct function start") self._parser.updete_closure_namespace(self._network.construct) self.create_placeholder(self._ast_function_root["construct"]) name_counts = {} #save the number of the variable, if the number is over 1,then modify the name - add a number as the name suffix index = 0 for ast_node in self._ast_function_root["construct"].body: - print(ast_node) + logger.debug(f"process ast node: {ast_node}") if isinstance(ast_node, ast.Expr): continue method = 'parse_' + ast_node.__class__.__name__ visitor = getattr(self._parser, method, None) - + if not visitor: + logger.warning("Get node visitor failed in parse_construct, node: %r", ast_node) + continue nodes, attribute_names = visitor(ast_node) for i in range(len(nodes)): if attribute_names and attribute_names[i] in self._node_attributes.keys(): - print("defined in init function: ", attribute_names[i]) + #print("defined in init function: ", attribute_names[i]) nodes[i]._attribute = self._node_attributes[attribute_names[i]] elif nodes[i].name.split(".")[-1] in dir(self._network): nodes[i]._attribute._type = NodeType.call_method nodes[i]._attribute._is_custom_define = True - print("self defined func: ", nodes[i].name) + #print("self defined func: ", nodes[i].name) elif self._parser.get_func_namesapce(nodes[i].name.split(".")[-1]): class_, name_space_, is_custom_define_ = self._parser.get_func_namesapce(nodes[i].name.split(".")[-1]) - print("defined in other namespace") #must resolve the undefined symble - print ("class: ", class_, "name space: ", name_space_, "is custom define: ", is_custom_define_) + #print("defined in other namespace") #must resolve the undefined symble + #print ("class: ", class_, "name space: ", name_space_, "is custom define: ", is_custom_define_) nodes[i]._attribute._is_custom_define = is_custom_define_ nodes[i]._attribute._class = class_ nodes[i]._attribute._type = NodeType.call_function if is_custom_define_: subgraph = 
self.parse_function(class_) - print("self defined subgraph: ", subgraph) + #print("self defined subgraph: ", subgraph) else: - print("undefined symbole ....") + logger.warning("undefined symbole ....") name = self._base_scope + "." + nodes[i].name.split(".")[-1] if name in name_counts.keys(): @@ -205,21 +219,24 @@ class Graph(): self._find_input_node(nodes[i]) index += 1 self._nodes.append(nodes[i]) - print("construct nodes: ") + logger.debug("construct nodes: ") for node in self._nodes: - print(node) - print("================= parse " + self._base_scope + " construct function end ========================") + logger.debug(node) + logger.info(f"parse {self._base_scope } construct function end") def parse_function(self, func: Union[ast.FunctionDef, FunctionType]): + """ + Analyze a function, it can be a method in network or a function. + """ if isinstance(func, FunctionType): - print("================= parse " + func.__name__ + " function ========================") + logger.info(f"parse {func.__name__} function start") function_str = inspect.getsource(func) ast_root = ast.parse(function_str) astpretty.pprint (ast_root) node = ast_root.body[0] subgraph = Graph(func) #要区分类内还是类外方法 else: - print("================= parse " + func.name + " function ========================") + logger.info(f"parse {func.name} function start") node = func subgraph = Graph(self._network.__dict__[node.name]) #要区分类内还是类外方法 subgraph._name = node.name @@ -228,7 +245,7 @@ class Graph(): index = 0 for ast_node in subgraph._ast_root.body: - print(ast_node) + logger.debug(f"process ast node: {ast_node}") if isinstance(ast_node, ast.Expr): continue method = 'parse_' + ast_node.__class__.__name__ @@ -241,30 +258,36 @@ class Graph(): if attribute_names and attribute_names[i] in self._node_attributes.keys(): nodes[i]._attribute = self._node_attributes[attribute_names[i]] else: - print("can not find attribute") + logger.warning("can not find attribute") subgraph._find_input_node(nodes[i]) subgraph._nodes.append(nodes[i]) index += 1 - print(subgraph._name + " nodes: ") + logger.debug(f"{subgraph._name} nodes: ") for node in subgraph._nodes: - print (node) - print("====================== parse function end =====================") + logger.debug (node) + logger.info("parse function end") return subgraph def parse_functions(self): + """ + Analyze methods other than init and construct in network. + """ for name, ast_root in self._ast_function_root.items(): if name == "__init__" or name == "construct": continue subgraph = self.parse_function(ast_root) - print("name = ", name, "; subgraph = ", subgraph) + logger.info("name = %r", name, "; subgraph = %r", subgraph) self._subgraphs["self." + name] = subgraph def _find_input_node(self, node: Node): + """ + Find the input nodes of node. + """ for arg in node._args: flag = 0 - if isinstance(arg, int): + if isinstance(arg, int): # TODO: for n in self._contant_nodes: if arg in n._targets: node.inputs.append(n) @@ -278,8 +301,8 @@ class Graph(): self._contant_nodes.append(new_node) continue - for i in range(len(self._nodes) - 1, -1, -1): #需要反向查找节点 - print ("arg: ", arg, "; n.targets: ", self._nodes[i]._targets) + for i in range(len(self._nodes) - 1, -1, -1): + logger.debug("node arg: %r", arg, "; current node targets: %r", self._nodes[i]._targets) if arg in self._nodes[i]._targets: node.inputs.append(self._nodes[i]) self._nodes[i].outputs.append(node) @@ -301,15 +324,18 @@ class Graph(): return def parse_attr_subgraph(self): + """ + Parse the subgraph defined in '__init__' function. 
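_find_input_node resolves each argument to its producer by scanning the node list backwards, so a reused variable name (x = f(x)) binds to its most recent definition. A standalone sketch of that lookup:

    def find_producer(nodes, arg):
        # scan backwards so the most recent producer of `arg` wins
        for candidate in reversed(nodes):
            if arg in candidate._targets:
                return candidate
        return None  # placeholder/constant handling is done by the caller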
+ """ for name, node in self._node_attributes.items(): - print ("name: ", name, "; node: ", node) + #print ("name: ", name, "; node: ", node) if node._is_custom_define and isinstance(node._class, FunctionType): - print("is function") + logger.info("The node is FunctionType, node: %r", node) subgraph = self.parse_function(node._class) - print("name = ", name, "; subgraph = ", subgraph) + logger.debug("name = %r", name, "; subgraph = %r", subgraph) self._subgraphs["self." + name] = subgraph elif node._is_custom_define and issubclass(node._class, (nn.Cell)): - print ("this node is Cell class") + logger.info("The node is subclass of Cell, node: %r", node) graph = Graph(node._class) graph.create_ast() graph.print_ast() @@ -318,14 +344,20 @@ class Graph(): graph.parse_construct() for node in graph._nodes: node.name = self._base_scope + "." + name.split(".")[-1] + "." + node.name.split(".")[-1] - print (name + " subgraph node: ", repr(node)) + logger.debug(f"{name} subgraph node: {node}") self._subgraphs[name] = graph elif node._is_custom_define and issubclass(node._class, Primitive): - print ("is primitive") + logger.info("The node is subclass of Primitive, node: %r", node) else: - print ("other types") + logger.info("The node is other types") def remove_node(self, node: Node): + """ + Remove node in 'nodes', modify ast synchronously. + + Args: + node: Node to be deleted. + """ index = node._index self._ast_function_root["construct"].body.pop(index) ast.fix_missing_locations(self._ast_function_root["construct"]) @@ -334,6 +366,13 @@ class Graph(): self._nodes.remove(node) def replace_node(self, src_nodes, dst_node): + """ + Replace src_nodes in 'nodes' by dst_node, modify ast synchronously. + + Args: + src_nodes: Nodes to be replaced. + dst_node: Node used to replace. + """ if isinstance(src_nodes, list): # redirect edges appends = [] @@ -379,6 +418,7 @@ class Graph(): def insert_node(self, node: Node): """ + Insert node into 'nodes', modify ast synchronously. 需要知道插入的位置,在body中的下标, attr_name是插入的属性名称或者是已存在的属性名称 """ new_init_node = ast.Assign() diff --git a/mindspore/python/mindspore/rewrite/node.py b/mindspore/python/mindspore/rewrite/node.py index 2402287c211..769400121be 100644 --- a/mindspore/python/mindspore/rewrite/node.py +++ b/mindspore/python/mindspore/rewrite/node.py @@ -140,12 +140,14 @@ class BaseNode: return self._attribute._attribute.get(key) class Node(BaseNode): + """ + 'Node' is the main data structure that represents individual operations within a 'Graph'. + """ def __init__(self, name="", targets=None, args=None, ast_node=None, instance: Union[nn.Cell, Primitive] = None, inputs: List = None): """ 创建一个节点时对应的属性怎么传进来,cell应该不涉及,primitive会有这种情况 """ super().__init__(name, targets, args, inputs) - #self._name: str = "" self._kwargs: Dict = {} self._scope: str = "" self._ast_node: ast.AST = ast_node @@ -161,8 +163,7 @@ class Node(BaseNode): self._attribute._class = cell_type def set_cell(self, cell: nn.Cell): - #self._attribute = None - self._attribute.instance = NodeType.call_cell + self._attribute._class = NodeType.call_cell def __repr__(self): input_names= "" @@ -181,9 +182,11 @@ class Node(BaseNode): return f"name: {self._name}; ast_node: {self._ast_node}; scope: {self._scope}; index: {self._index}; inputs: {len(self.inputs)}; input names: {input_names}; outputs: {len(self.outputs)}; output names: {output_names}; attr info: {self._attribute}" class ConstantNode(BaseNode): + """ + 'ConstantNode' is used to save constants. 
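Together with find(), the editing methods above form the public surface for graph surgery. A hedged usage sketch; MyNet and the scope string are hypothetical, and the parse entry point is the rewriter module shown earlier in this series:

    from mindspore.rewrite.rewriter import parse  # path per this patch

    graph = parse(MyNet)             # MyNet: any nn.Cell subclass
    conv = graph.find("net.conv1")   # full name with scope (format assumed)
    if conv is not None:
        graph.remove_node(conv)      # also drops the matching construct statement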
+ """ def __init__(self, name="constant", value=None): super().__init__(name=name, args=[], targets=[]) - #self._name = str(value) self._value = value self._args.append(value) self._attribute.type = NodeType.constant @@ -195,6 +198,9 @@ class ConstantNode(BaseNode): return f"name: {self._name}; value: {self._value}; outputs: {len(self.outputs)}; output names: {output_names}" class PlaceholderNode(BaseNode): + """ + 'PlaceholderNode' is used to represent inputs of Cell, method or function. + """ def __init__(self, name, targets=None, ast_node=None, default_value=None): super().__init__(name, targets) self._ast_node = ast_node diff --git a/mindspore/python/mindspore/rewrite/parser.py b/mindspore/python/mindspore/rewrite/parser.py index d20d7597345..2f9c4b823d9 100644 --- a/mindspore/python/mindspore/rewrite/parser.py +++ b/mindspore/python/mindspore/rewrite/parser.py @@ -36,22 +36,20 @@ class Parser(): self._default_values = None def updete_closure_namespace(self, fn: FunctionType): + """ + Update 'closure_namespace' of fn. + """ self.closure_namespace = ClosureNamespace(inspect.unwrap(fn)) - def parse_function(self, ast_root: ast.AST, function_name: str): - ''' - parse function by name - ''' - for ast_node in ast_root.body: - #new_node: Node = node_parser_mapper[type(ast_node).__name__](self, ast_node) - pass - def _get_node_visitor(self, node: ast.AST): method = 'parse_' + node.__class__.__name__ visitor = getattr(self, method, None) return visitor def get_func_namesapce(self, func_name: str): + """ + Get the namespace of func_name. + """ if func_name in self.ms_common_ns: return self.ms_common_ns[func_name], repr(self.ms_common_ns), False elif func_name in self.ms_nn_ns: @@ -69,17 +67,25 @@ class Parser(): elif func_name in self.closure_namespace: return self.closure_namespace[func_name], repr(self.closure_namespace), True else: + logger.warning(f"get namespace failed, func_name: {func_name}") return None, None, False def _parse_targets(self, node: ast.AST): + """ + Parse targets of ast node. + """ visitor = self._get_node_visitor(node) - if visitor: - res = visitor(node) - else: + if not visitor: logger.warning("get node visiter failed, node: %r", node) + return None + res = visitor(node) + return res def _parse_args(self, node_list: ast.List): + """ + Parse args of ast node. + """ args = [] nodes = [] called_obj_names = [] @@ -91,9 +97,10 @@ class Parser(): args.append("tmp") called_obj_names.extend(called_obj_names_) assert(len(nodes) == len(called_obj_names)) - print("node in args: ", new_node) elif isinstance(node, ast.Name): args.append(node.id) + else: + logger.warning("unsupported type in _parse_args, node: %r", node) return args, nodes, called_obj_names @@ -110,12 +117,14 @@ class Parser(): return left / right def parse_init_assign(self, node: ast.Assign): + """ + Parse Assign node in '__init__' function. 
+ """ lineno = node.lineno visitor = self._get_node_visitor(node.targets) targets = visitor(node.targets) logger.debug(f"start parse node in __init__ function: {node}") - #print ("targets:", targets) value = node.value new_node = AttributeNode(name=targets[0]) # TODO: deal with multi outputs if isinstance(value, ast.Call): @@ -132,11 +141,14 @@ class Parser(): logger.debug("value is BinOp") pass else: - logger.warning("vaule type: ", type(value), " is not supported") - logger.debug("parse end, result: ", new_node) + logger.warning(f"vaule type: {type(value)} is not supported") + logger.debug("parse init Assign end, result: %r", new_node) return new_node def _parse_init_args(self, ast_nodes: ast.List): + """ + Parse the args of the node in the '__init__' function. + """ args = [] for node in ast_nodes: if isinstance(node, ast.BinOp): @@ -151,11 +163,14 @@ class Parser(): value = self._default_values[value] args.append(value) - print("init args: ", args) + logger.debug("parse init args end, args: %r", args) return args def _parse_init_keywords(self, ast_nodes: ast.List): + """ + Parse the keywords of the node in the '__init__' function. + """ keywords = {} for node in ast_nodes: @@ -174,10 +189,13 @@ class Parser(): value = self._default_values[value] keywords[key] = value - + logger.debug("parse init keywords end, keywords: %r", keywords) return keywords def parse_init_Call(self, ast_node: ast.Call, attr_node: AttributeNode): + """ + Parse Call node in '__init__' function. + """ def _update_args_value(args: List, keywords: Dict, parameters: inspect.signature): logger.debug("defaults values:", self._default_values) @@ -211,13 +229,10 @@ class Parser(): visitor = self._get_node_visitor(ast_node.func) value = visitor(ast_node.func) - #print ("node name: ", value.split(".")[-1]) class_name = value.split(".")[-1] class_, name_space, is_custom_define = self.get_func_namesapce(class_name) - #print ("class: ", class_) parameters = inspect.signature(class_.__init__).parameters - #print ("parameters: ", parameters) if name_space in namespace_nodetype_map: node_type = namespace_nodetype_map[name_space] else: @@ -225,9 +240,9 @@ class Parser(): args = self._parse_init_args(ast_node.args) keywords = self._parse_init_keywords(ast_node.keywords) - logger.debug("before update parameters: ", parameters) + logger.debug("before update parameters: %r", parameters) new_dict = _update_args_value(args, keywords, parameters) - logger.debug ("the node: ", new_dict) + logger.debug ("updated parameters: %r", new_dict) attr_node._class = class_ attr_node._type = node_type attr_node._is_custom_define = is_custom_define @@ -236,6 +251,9 @@ class Parser(): return def parse_init_BinOp(self, ast_node: ast.BinOp): + """ + Parse BinOp node in '__init__' function. 
+ """ def _get_value(node: ast.AST): if isinstance(node, ast.Call): value = AttributeNode() @@ -253,35 +271,38 @@ class Parser(): return value - logger.debug("start parse binop: ", ast_node) + logger.debug("parse init BinOp start, node: %r", ast_node) op = ast_node.op left = ast_node.left right = ast_node.right left_value = _get_value(left) right_value = _get_value(right) - logger.debug("left value: ", left_value) - logger.debug("right value: ", right_value) + logger.debug("left value: %r", left_value) + logger.debug("right value: %r", right_value) if (isinstance(left_value, int) and isinstance(right_value, int)) or (str.isdigit(str(left_value)) and str.isdigit(str(right_value))): method = '_calc_' + op.__class__.__name__ calc = getattr(self, method, None) if calc: result = calc(left_value, right_value) else: - logger.warning("undefined op", method) + logger.warning("undefined method: %r", method) else: result = str(left_value) + " op.__class__.__name__ " + repr(right_value) + logger.debug("parse init BinOp end, result: %r", result) return result def parse_Assign(self, node: ast.Assign): - logger.debug("start parse assign node: ", node) + """ + Parse Assign node in ast. + """ + logger.debug("parse assign node start: %r", node) lineno = node.lineno nodes = [] called_obj_names = [] visitor = self._get_node_visitor(node.targets) targets = visitor(node.targets) - #print("targets: ", targets) value = node.value new_node = Node(targets=targets, ast_node=node) @@ -290,13 +311,15 @@ class Parser(): nodes.extend(nodes_) called_obj_names.extend(called_obj_names_) - #print("new node in assign: ", new_node) assert(len(nodes) == len(called_obj_names)) - logger.debug("parse node end") + logger.debug(f"parse assign node end, nodes: {nodes}; called object names: {called_obj_names}") return nodes, called_obj_names def parse_Call(self, ast_node: ast.Call, node: Node): + """ + Parse Call node in ast. + """ nodes = [] called_obj_names = [] visitor = self._get_node_visitor(ast_node.func) @@ -314,19 +337,24 @@ class Parser(): nodes.append(node) called_obj_names.extend(called_obj_names_) called_obj_names.append(called_obj_name) - #print ("nodes in call:", nodes) - #print("called obj name in call: ", called_obj_names) + return nodes, called_obj_names def parse_Attribute(self, node: ast.Attribute): + """ + Parse Attribute node in ast. + """ visitor = self._get_node_visitor(node.value) - if visitor: - attribute_value = visitor(node.value) + "." + node.attr - else: + if not visitor: logger.warning("get node visitor failed in parse_Attribute") + + attribute_value = visitor(node.value) + "." + node.attr return attribute_value def parse_list(self, node: ast.List) -> list: + """ + Parse list. + """ res = [] for n in node: visitor = self._get_node_visitor(n) @@ -339,6 +367,9 @@ class Parser(): return res def parse_List(self, node: ast.List) -> list: + """ + Parse list. + """ res = [] for n in node.elts: visitor = self._get_node_visitor(n) @@ -350,7 +381,10 @@ class Parser(): return res - def parse_Tuple(self, node: ast.List) -> list: + def parse_Tuple(self, node: ast.Tuple) -> list: + """ + Parse Tuple. + """ res = [] for n in node.elts: visitor = self._get_node_visitor(n) @@ -363,10 +397,13 @@ class Parser(): pass def parse_BinOp(self, ast_node: ast.BinOp, node: Node): #如果left和right都是Call则需要分别创建节点,同时分析call的args,根据args也创建对应节点 + """ + Parse BinOp node in ast. 
+ """ def _get_value(ast_node: ast.AST, args, nodes, called_obj_names, side): visitor = self._get_node_visitor(ast_node) if not visitor: - logger.warning("get node visitor failed in parse_BinOp._get_value, node: ", ast_node) + logger.warning("get node visitor failed in parse_BinOp._get_value, node: %r", ast_node) return if isinstance(ast_node, ast.Call) or isinstance(ast_node, ast.BinOp): new_node = Node(args=[], targets=[]) @@ -389,12 +426,13 @@ class Parser(): return nodes = [] called_obj_names = [] - #ops_info = parse_object_map. node.name = ast_node.op.__class__.__name__ node._attribute._class = ast_node.op #ast node type must convert to mindspore op type node._attribute.node_type = NodeType.call_function args = [] + _get_value(ast_node.left, args, nodes, called_obj_names, "left") + _get_value(ast_node.right, args, nodes, called_obj_names, "right") node._args = args nodes.append(node) called_obj_names.append(node.name) @@ -422,27 +460,48 @@ class Parser(): pass def parse_Name(self, node: ast.Name) -> str: + """ + Parse Name node in ast + """ return node.id def parse_Num(self, node: ast.Num) -> str: + """ + Parse Num node in ast. + """ return node.n def parse_Constant(self, node: ast.Constant) -> str: + """ + Parse Constant node in ast. + """ return node.value def parse_keyword(self, node: ast.keyword): + """ + Parse keyword node in ast + """ key = node.arg visitor = self._get_node_visitor(node.value) value = visitor(node.value) return {key: value} def parse_NameConstant(self, node: ast.NameConstant): + """ + Parse NameConstant node in ast. + """ return node.value def parse_Str(self, node: ast.Str): + """ + Parse Str node in ast. + """ return node.s def parse_AugAssign(self, ast_node: ast.AugAssign): + """ + Parse AugAssign node in ast. + """ nodes = [] called_obj_names = [] augAssign_node = Node(args=[], targets=[]) @@ -470,7 +529,7 @@ class Parser(): nodes.append(new_node) called_obj_names.append("constant value") else: - logger.warning("unsupported type, node: ", ast_node) + logger.warning("unsupported type, node: %r", ast_node) augAssign_node.name = ast_node.op.__class__.__name__ nodes.append(augAssign_node) @@ -478,6 +537,9 @@ class Parser(): return nodes, called_obj_names def parse_arguments(self, node: ast.arguments): + """ + Parse arguments. 
+ """ class Arg: def __init__(self, lineno, col_offset, name) -> None: self._lineno = lineno @@ -491,16 +553,11 @@ class Parser(): self._value = value def _find_corresponding_name(defaults: List[Default], names: List[Arg]): - #print("defaults: ", defaults, "; names: ", names) for d in defaults: - # print("dddddd = ", d) i = 0 while i < len(names) and names[i]._lineno == d._lineno and names[i]._clo_offset < d._col_offset: - # print("names[i]._lineno = ", names[i]._lineno, "; d._lineno: ", d._lineno, "; names[i]._clo_offset: ", names[i]._clo_offset, "; d._col_offset: ", d._col_offset) i += 1 - #print("i = ", i) if i <= len(names): - # print("names[i]._name: ", names[i-1]._name, "; d.value: ", d._value) arg_with_default_value[names[i-1]._name] = d._value args_ = [] @@ -509,24 +566,20 @@ class Parser(): if arg.arg == "self": continue a = Arg(arg.lineno, arg.col_offset, arg.arg) - #args_.append(arg.arg) args_.append(a) arg_with_default_value[a._name] = None for arg in node.kwonlyargs: - #args_.append(arg.arg) a = Arg(arg.lineno, arg.col_offset, arg.arg) args_.append(a) arg_with_default_value[a._name] = None if node.vararg != None: - #args_.append(node.vararg.arg) a = Arg(node.vararg.arg.lineno, arg.col_offset, arg.arg) args_.append(a) arg_with_default_value[a._name] = None if node.kwarg != None: - #args_.append(node.kwarg.arg) a = Arg(node.vararg.arg.lineno, arg.col_offset, arg.arg) args_.append(a) arg_with_default_value[a._name] = None @@ -543,6 +596,10 @@ class Parser(): return arg_with_default_value def parse_Return(self, node: ast.Return): + """ + Parse Return node in ast. + """ + logger.debug("create return node start, node: %r", node) nodes = [] called_obj_names = [] value = node.value @@ -550,26 +607,24 @@ class Parser(): visitor = self._get_node_visitor(value) value = visitor(value) new_node = Node(name="return", args=[], ast_node=node) - print ("return args: ", value) - #new_node.name = "return" if isinstance(value, list): new_node._args += value else: new_node._args.append(value) - #attribute = AttributeNode() - #new_node._ast_node = node + new_node._attribute.type = NodeType.output nodes.append(new_node) - print("return node: ", new_node) + logger.debug("create return node end: %r", new_node) return nodes, called_obj_names def parse_If(self, node: ast.If): + """ + Parse If node in ast. + """ nodes = [] called_obj_names = [] new_node = Node("if", [], [], node) - # new_node.name = "if" - nodes.append(new_node) return nodes, called_obj_names diff --git a/mindspore/python/mindspore/rewrite/rewriter.py b/mindspore/python/mindspore/rewrite/rewriter.py index bdbc9a2609b..1f4ff920d6e 100644 --- a/mindspore/python/mindspore/rewrite/rewriter.py +++ b/mindspore/python/mindspore/rewrite/rewriter.py @@ -6,6 +6,9 @@ import mindspore.nn as nn from mindspore.ops.primitive import Primitive def parse(network: Union[nn.Cell, Primitive]) -> Graph: + """ + Parse python code and return a Graph object. 
+ """ if issubclass(network, nn.Cell): graph = Graph(network) graph.create_ast() @@ -17,10 +20,10 @@ def parse(network: Union[nn.Cell, Primitive]) -> Graph: graph.parse_functions() graph.parse_construct() elif isinstance(network, FunctionType): - graph = FunctionGraph(network) + graph = Graph(network) graph.create_placeholder(graph._ast_root.body[0]) elif isinstance(network, Primitive): - graph = PrimitiveGraph(network) + graph = Graph(network) return graph -- Gitee From fe74d5a3db22134aa34e9e41d86db406fb694262 Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Mon, 20 Dec 2021 19:43:35 +0800 Subject: [PATCH 06/34] fix log print --- mindspore/python/mindspore/rewrite/graph.py | 23 ++++++++++++--------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/mindspore/python/mindspore/rewrite/graph.py b/mindspore/python/mindspore/rewrite/graph.py index f9159a1814c..55f08fa5275 100644 --- a/mindspore/python/mindspore/rewrite/graph.py +++ b/mindspore/python/mindspore/rewrite/graph.py @@ -187,27 +187,30 @@ class Graph(): nodes, attribute_names = visitor(ast_node) for i in range(len(nodes)): + n = nodes[i].name.split(".")[-1] if attribute_names and attribute_names[i] in self._node_attributes.keys(): - #print("defined in init function: ", attribute_names[i]) + logger.debug(f"defined in init function: {attribute_names[i]}") nodes[i]._attribute = self._node_attributes[attribute_names[i]] - elif nodes[i].name.split(".")[-1] in dir(self._network): + elif n in dir(self._network): + logger.debug(f"self defined func: {nodes[i].name}") nodes[i]._attribute._type = NodeType.call_method nodes[i]._attribute._is_custom_define = True - #print("self defined func: ", nodes[i].name) - elif self._parser.get_func_namesapce(nodes[i].name.split(".")[-1]): - class_, name_space_, is_custom_define_ = self._parser.get_func_namesapce(nodes[i].name.split(".")[-1]) - #print("defined in other namespace") #must resolve the undefined symble - #print ("class: ", class_, "name space: ", name_space_, "is custom define: ", is_custom_define_) + elif n == "return": + nodes[i]._attribute._type = NodeType.output + elif self._parser.get_func_namesapce(n)[0]: + class_, name_space_, is_custom_define_ = self._parser.get_func_namesapce(n) + logger.debug(f"defined in other namespace: {n}") + logger.debug("class: ", class_, "name space: ", name_space_, "is custom define: ", is_custom_define_) nodes[i]._attribute._is_custom_define = is_custom_define_ nodes[i]._attribute._class = class_ nodes[i]._attribute._type = NodeType.call_function if is_custom_define_: subgraph = self.parse_function(class_) - #print("self defined subgraph: ", subgraph) + logger.debug(f"self defined subgraph: {subgraph}") else: - logger.warning("undefined symbole ....") + logger.warning("undefined symbole {n} ... ...") - name = self._base_scope + "." + nodes[i].name.split(".")[-1] + name = self._base_scope + "." 
                if name in name_counts.keys():
                    name_counts[name] += 1
                    name = name + "_" + str(name_counts[name])
-- 
Gitee

From 6ee83fb6ecdcddd337111373bde2744753eb8047 Mon Sep 17 00:00:00 2001
From: yuzhenhua
Date: Mon, 20 Dec 2021 19:47:54 +0800
Subject: [PATCH 07/34] fix log print

---
 mindspore/python/mindspore/rewrite/graph.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mindspore/python/mindspore/rewrite/graph.py b/mindspore/python/mindspore/rewrite/graph.py
index 55f08fa5275..ac4f64eae21 100644
--- a/mindspore/python/mindspore/rewrite/graph.py
+++ b/mindspore/python/mindspore/rewrite/graph.py
@@ -208,7 +208,7 @@ class Graph():
                        subgraph = self.parse_function(class_)
                        logger.debug(f"self defined subgraph: {subgraph}")
                else:
-                    logger.warning("undefined symbol {n} ... ...")
+                    logger.warning(f"undefined symbol {n} ... ...")
                name = self._base_scope + "." + n
                if name in name_counts.keys():
                    name_counts[name] += 1
                    name = name + "_" + str(name_counts[name])
-- 
Gitee

From dd74a1f8160582acb54b49029c7f955c9f42a2bc Mon Sep 17 00:00:00 2001
From: hangangqiang
Date: Tue, 21 Dec 2021 09:30:02 +0800
Subject: [PATCH 08/34] use log to print debug-msg

---
 .../python/mindspore/golden_stick/__init__.py |   6 +-
 .../quantization/quantize_test.py             | 129 ++++++++++++++++++
 .../mindspore/rewrite/pattern_engine.py       |  15 +-
 .../rewrite/pattern_engine_match_test.py      |  18 +--
 .../mindspore/rewrite/pattern_engine_test.py  |   3 +-
 5 files changed, 151 insertions(+), 20 deletions(-)
 create mode 100644 mindspore/python/mindspore/golden_stick/quantization/quantize_test.py

diff --git a/mindspore/python/mindspore/golden_stick/__init__.py b/mindspore/python/mindspore/golden_stick/__init__.py
index d568e043a05..608a949f2ae 100644
--- a/mindspore/python/mindspore/golden_stick/__init__.py
+++ b/mindspore/python/mindspore/golden_stick/__init__.py
@@ -21,6 +21,6 @@ from .net_transform import NetTransformer
 from .quantization import LayerPolicy, NetPolicy, QuantAwareTraining, Quantizer, Transformer, AllValueQuantizer, \
     LastValueQuantizer, LSQ, DefaultLayerPolicy, DefaultNetworkPolicy, DefaultQuantAwareTraining
 
-__all__ = ["GoldenStick", "NetTransformer", "LayerPolicy", "NetPolicy", "QuantAwareTraining", "Quantizer"
-           , "Transformer", "AllValueQuantizer", "LastValueQuantizer", "LSQ", "DefaultLayerPolicy", "DefaultNetworkPolicy", \
-           "DefaultQuantAwareTraining"]
+__all__ = ["GoldenStick", "NetTransformer", "LayerPolicy", "NetPolicy", "QuantAwareTraining", "Quantizer",
+           "Transformer", "AllValueQuantizer", "LastValueQuantizer", "LSQ", "DefaultLayerPolicy",
+           "DefaultNetworkPolicy", "DefaultQuantAwareTraining"]
diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize_test.py b/mindspore/python/mindspore/golden_stick/quantization/quantize_test.py
new file mode 100644
index 00000000000..b3742fe7fce
--- /dev/null
+++ b/mindspore/python/mindspore/golden_stick/quantization/quantize_test.py
@@ -0,0 +1,129 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""test Quantize.""" + +import unittest +from test_common import TestCommon +from mindspore.golden_stick import Transformer, LayerPolicy +from mindspore.golden_stick.quantization import DefaultLayerPolicy +from mindspore.golden_stick.quantization.layer_policy import layer_policy_key +from mindspore.rewrite import Node, Graph +from mindspore.nn import Cell, Conv2d, BatchNorm2d, MaxPool2d + + +class QuantizeTestCase(unittest.TestCase): + @staticmethod + def network(): + placeholder = TestCommon.create_placeholder_layer() + + conv1 = TestCommon.create_conv_layer("conv1", [placeholder]) + placeholder.outputs = [conv1] + + bn1 = TestCommon.create_bn_layer("bn1", [conv1]) + conv1.outputs = [bn1] + + pool1 = TestCommon.create_pool_layer("pool1", [bn1]) + bn1.outputs = [pool1] + + conv2 = TestCommon.create_conv_layer("conv2", [pool1]) + pool1.outputs = [conv2] + + bn2 = TestCommon.create_bn_layer("bn2", [conv2]) + conv2.outputs = [bn2] + + pool2 = TestCommon.create_pool_layer("pool2", [bn2]) + bn2.outputs = [pool2] + + graph = Graph(Cell()) + graph.set_root(pool2) + return graph + + @staticmethod + def network_intra_overlapped(): + placeholder = TestCommon.create_placeholder_layer() + + conv1 = TestCommon.create_conv_layer("conv1", [placeholder]) + placeholder.outputs = [conv1] + + conv2 = TestCommon.create_conv_layer("conv2", [conv1]) + conv1.outputs = [conv2] + + conv3 = TestCommon.create_conv_layer("conv3", [conv2]) + conv2.outputs = [conv3] + + bn = TestCommon.create_bn_layer("bn", [conv3]) + conv3.outputs = [bn] + + pool = TestCommon.create_pool_layer("pool", [bn]) + bn.outputs = [pool] + + graph = Graph(Cell()) + graph.set_root(pool) + return graph + + @staticmethod + def network_inter_overlapped(): + placeholder = TestCommon.create_placeholder_layer() + + conv1 = TestCommon.create_conv_layer("conv1", [placeholder]) + placeholder.outputs = [conv1] + + conv2 = TestCommon.create_conv_layer("conv2", [conv1]) + conv1.outputs = [conv2] + + bn = TestCommon.create_bn_layer("bn", [conv2]) + conv2.outputs = [bn] + + pool = TestCommon.create_pool_layer("pool", [bn]) + bn.outputs = [pool] + + graph = Graph(Cell()) + graph.set_root(pool) + return graph + + def test_inter_overlap(self): + transformer1: Transformer = Transformer([Conv2d, Conv2d]) + transformer2: Transformer = Transformer([Conv2d, BatchNorm2d]) + graph: Graph = TransformerTestCase.network_inter_overlapped() + conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True) + conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False) + conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True) + conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False) + bn_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, True) + bn_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, False) + self.assertNotEqual(conv1_in_fq, None) + self.assertNotEqual(conv1_out_fq, None) + self.assertNotEqual(conv2_in_fq, None) + self.assertNotEqual(conv2_out_fq, None) + self.assertNotEqual(bn_in_fq, None) + self.assertNotEqual(bn_out_fq, None) + transformer1.apply(graph) + transformer2.apply(graph) + conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True) + conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False) + conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True) + conv2_out_fq = 
TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False) + bn_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, True) + bn_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, False) + self.assertNotEqual(conv1_in_fq, None) + self.assertEqual(conv1_out_fq, None) + self.assertEqual(conv2_in_fq, None) + self.assertNotEqual(conv2_out_fq, None) + self.assertNotEqual(bn_in_fq, None) + self.assertNotEqual(bn_out_fq, None) + + +if __name__ == '__main__': + unittest.main() diff --git a/mindspore/python/mindspore/rewrite/pattern_engine.py b/mindspore/python/mindspore/rewrite/pattern_engine.py index f1db00b748a..ba66c5c3cff 100644 --- a/mindspore/python/mindspore/rewrite/pattern_engine.py +++ b/mindspore/python/mindspore/rewrite/pattern_engine.py @@ -19,6 +19,7 @@ from collections import OrderedDict from .graph import Graph from .node import Node, NodeType from mindspore.nn.cell import Cell +from mindspore import log as logger class PatternNode: @@ -179,7 +180,7 @@ class PatternEngine: self._replacement = replacement self._pattern = PatternNode.create_pattern_from_list(pattern) else: - print("Unsupported pattern type: ", type(pattern)) + logger.debug("Unsupported pattern type: %s", type(pattern)) self._is_chain = False self._replacement = None self._pattern = PlaceHolderNode() @@ -307,11 +308,11 @@ class PatternEngine: # todo: Recurse into subgraph node. Depend on subgraph node definition if node.node_type() != NodeType.call_cell: - print("Pattern match failed: node(", node.name, ") is not a cell") + logger.debug("Pattern match failed: node(%s) is not a cell", node.name) return False, OrderedDict() if not pattern.match(node): - print("Pattern match failed: node(", node.name, ")'s type is ", node.type, " while pattern type is ", - pattern.type()) + logger.debug("Pattern match failed: node(%s)'s type is %s while pattern type is %s", node.name, node.type, + pattern.type()) return False, OrderedDict() if isinstance(pattern, PlaceHolderNode): return True, OrderedDict() @@ -321,8 +322,8 @@ class PatternEngine: if input_num == 0: return True, OrderedDict({pattern.name(): node}) if input_num != len(cur_inputs): - print("Pattern match failed: node(", node.name, ")'s has ", len(node.inputs), " inputs while pattern has ", - input_num, " inputs") + logger.debug("Pattern match failed: node(%s)'s has %d inputs while pattern has %d inputs", node.name, + len(node.inputs), input_num) return False, OrderedDict() result = OrderedDict() for i in range(0, input_num): @@ -343,6 +344,6 @@ class PatternEngine: node = match_dict[key] for output in node.outputs: if output not in matched_nodes: - print("Check match failed, pattern leaked") + logger.debug("Check match failed, pattern leaked") return False return True diff --git a/mindspore/python/mindspore/rewrite/pattern_engine_match_test.py b/mindspore/python/mindspore/rewrite/pattern_engine_match_test.py index 6c9b436c560..39fc11892fb 100644 --- a/mindspore/python/mindspore/rewrite/pattern_engine_match_test.py +++ b/mindspore/python/mindspore/rewrite/pattern_engine_match_test.py @@ -3,6 +3,7 @@ import unittest from collections import OrderedDict from mindspore.rewrite import PatternEngine, Node, PlaceHolderNode, PatternNode from mindspore.nn.layer import Pad, Conv2d, BatchNorm2d, ReLU, Softmax, MatMul +from mindspore import log as logger class PatternEngineMatchTestCase(unittest.TestCase): @@ -14,7 +15,6 @@ class PatternEngineMatchTestCase(unittest.TestCase): for key in ret: self.assertEqual(ret.get(key), index) index += 1 - print(ret) 
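+        # A sketch of the behaviour this asserts (inferred from the test
+        # itself; the implementation of _merge_ordered_dict is not shown
+        # here): merging {'a': 1, 'b': 2, 'c': 3} with
+        # {'d': 4, 'e': 5, 'f': 6} yields an OrderedDict whose keys
+        # iterate a..f with values 1..6, i.e. the second dict's entries
+        # are appended after the first's.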
@staticmethod def chain_network(): @@ -119,10 +119,10 @@ class PatternEngineMatchTestCase(unittest.TestCase): # define pattern engine pattern_engine = PatternEngine(p_relu) match, match_dict = pattern_engine._match(pattern_engine.pattern(), bad_root) - print("*****Chain match softmax result: ", match) + logger.info("*****Chain match softmax result: %s", match) self.assertEqual(match, False) match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) - print("*****Chain match matmul result: ", match) + logger.info("*****Chain match matmul result: %s", match) self.assertEqual(match, True) self.assertEqual(len(match_dict), 3) @@ -131,10 +131,10 @@ class PatternEngineMatchTestCase(unittest.TestCase): # define pattern engine pattern_engine = PatternEngine([Conv2d, BatchNorm2d, ReLU]) match, match_dict = pattern_engine._match(pattern_engine.pattern(), bad_root) - print("*****List chain match softmax result: ", match) + logger.info("*****List chain match softmax result: %s", match) self.assertEqual(match, False) match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) - print("*****List chain match matmul result: ", match) + logger.info("*****List chain match matmul result: %s", match) self.assertEqual(match, True) self.assertEqual(len(match_dict), 3) @@ -150,10 +150,10 @@ class PatternEngineMatchTestCase(unittest.TestCase): # define pattern engine pattern_engine = PatternEngine(p_matmul) match, match_dict = pattern_engine._match(pattern_engine.pattern(), bad_root) - print("*****Tree match softmax result: ", match) + logger.info("*****Tree match softmax result: %s", match) self.assertEqual(match, False) match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) - print("*****Tree match matmul result: ", match) + logger.info("*****Tree match matmul result: %s", match) self.assertEqual(match, True) self.assertEqual(len(match_dict), 5) @@ -167,13 +167,13 @@ class PatternEngineMatchTestCase(unittest.TestCase): # define pattern engine pattern_engine = PatternEngine(p_matmul) match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) - print("*****Tree no-placehold match matmul result: ", match) + logger.info("*****Tree no-placehold match matmul result: %s", match) self.assertEqual(match, False) p_matmul.set_inputs([p_placeholder, p_bn2]) pattern_engine = PatternEngine(p_matmul) match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) - print("*****Tree placehold match matmul result: ", match) + logger.info("*****Tree placehold match matmul result: %s", match) self.assertEqual(match, True) self.assertEqual(len(match_dict), 3) diff --git a/mindspore/python/mindspore/rewrite/pattern_engine_test.py b/mindspore/python/mindspore/rewrite/pattern_engine_test.py index 2f0823e0cfb..8f7cc55e759 100644 --- a/mindspore/python/mindspore/rewrite/pattern_engine_test.py +++ b/mindspore/python/mindspore/rewrite/pattern_engine_test.py @@ -4,6 +4,7 @@ from mindspore.rewrite import PatternEngine, Node, Graph, NodeType from lenet import LeNet5 from mindspore.nn import Cell from mindspore.nn.layer import Conv2d, BatchNorm2d, Dense, MaxPool2d, Flatten, ReLU +from mindspore import log as logger class PatternEngineTestCase(unittest.TestCase): @@ -73,7 +74,7 @@ class PatternEngineTestCase(unittest.TestCase): def get_nodes_count(root: Node, to_print: bool = False): count = 1 if to_print: - print("Visit ", root.name) + logger.debug("Visit %s", root.name) for input in root.inputs: count += 
PatternEngineTestCase.get_nodes_count(input, to_print) return count -- Gitee From b5c74bd86ea786c522a3a811dbdd129e2fa4448c Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Tue, 21 Dec 2021 10:45:50 +0800 Subject: [PATCH 09/34] updata rewrite --- mindspore/python/mindspore/rewrite/graph.py | 8 +- mindspore/python/mindspore/rewrite/parser.py | 97 +++++++++---------- .../python/mindspore/rewrite/rewriter.py | 2 +- 3 files changed, 53 insertions(+), 54 deletions(-) diff --git a/mindspore/python/mindspore/rewrite/graph.py b/mindspore/python/mindspore/rewrite/graph.py index ac4f64eae21..5e2098bc5b7 100644 --- a/mindspore/python/mindspore/rewrite/graph.py +++ b/mindspore/python/mindspore/rewrite/graph.py @@ -179,8 +179,9 @@ class Graph(): logger.debug(f"process ast node: {ast_node}") if isinstance(ast_node, ast.Expr): continue - method = 'parse_' + ast_node.__class__.__name__ - visitor = getattr(self._parser, method, None) + #method = 'parse_' + ast_node.__class__.__name__ + #visitor = getattr(self._parser, method, None) + visitor = self._parser._get_node_visitor(ast_node) if not visitor: logger.warning("Get node visitor failed in parse_construct, node: %r", ast_node) continue @@ -222,6 +223,7 @@ class Graph(): self._find_input_node(nodes[i]) index += 1 self._nodes.append(nodes[i]) + self._root = self._nodes[-1] logger.debug("construct nodes: ") for node in self._nodes: logger.debug(node) @@ -326,7 +328,7 @@ class Graph(): return - def parse_attr_subgraph(self): + def parse_init_subgraph(self): """ Parse the subgraph defined in '__init__' function. """ diff --git a/mindspore/python/mindspore/rewrite/parser.py b/mindspore/python/mindspore/rewrite/parser.py index 2f9c4b823d9..7ae48132b97 100644 --- a/mindspore/python/mindspore/rewrite/parser.py +++ b/mindspore/python/mindspore/rewrite/parser.py @@ -42,7 +42,7 @@ class Parser(): self.closure_namespace = ClosureNamespace(inspect.unwrap(fn)) def _get_node_visitor(self, node: ast.AST): - method = 'parse_' + node.__class__.__name__ + method = 'parse_' + node.__class__.__name__.lower() visitor = getattr(self, method, None) return visitor @@ -92,7 +92,7 @@ class Parser(): for node in node_list: if isinstance(node, ast.Call): #according to the configuration of the node, create a new node and insert into nodes before it, set args and targets information new_node = Node(targets=["tmp"], ast_node=node) - nodes_, called_obj_names_ = self.parse_Call(node, new_node) + nodes_, called_obj_names_ = self.parse_call(node, new_node) nodes.extend(nodes_) args.append("tmp") called_obj_names.extend(called_obj_names_) @@ -104,16 +104,16 @@ class Parser(): return args, nodes, called_obj_names - def _calc_Add(self, left, right): + def _calc_add(self, left, right): return int(left) + int(right) - def _calc_Sub(self, left, right): + def _calc_sub(self, left, right): return int(left) - int(right) - def _calc_Mult(self, left, right): + def _calc_mult(self, left, right): return int(left) * int(right) - def _calc_Div(self, left, right): + def _calc_div(self, left, right): return left / right def parse_init_assign(self, node: ast.Assign): @@ -128,7 +128,7 @@ class Parser(): value = node.value new_node = AttributeNode(name=targets[0]) # TODO: deal with multi outputs if isinstance(value, ast.Call): - self.parse_init_Call(value, new_node) + self._parse_init_call(value, new_node) elif isinstance(value, ast.Name): new_node._class = ast.Constant new_node._type = NodeType.constant @@ -152,10 +152,10 @@ class Parser(): args = [] for node in ast_nodes: if isinstance(node, ast.BinOp): - value 
= self.parse_init_BinOp(node) + value = self._parse_init_binop(node) elif isinstance(node, ast.Call): value = AttributeNode() - self.parse_init_Call(node, value) + self._parse_init_call(node, value) else: visitor = self._get_node_visitor(node) value = visitor(node) @@ -178,10 +178,10 @@ class Parser(): value = node.value if isinstance(node.value, ast.BinOp): - value = self.parse_init_BinOp(node.value) + value = self._parse_init_binop(node.value) elif isinstance(node.value, ast.Call): value = AttributeNode() - self.parse_init_Call(node.value, value) + self._parse_init_call(node.value, value) else: visitor = self._get_node_visitor(node.value) value = visitor(node.value) @@ -192,7 +192,7 @@ class Parser(): logger.debug("parse init keywords end, keywords: %r", keywords) return keywords - def parse_init_Call(self, ast_node: ast.Call, attr_node: AttributeNode): + def _parse_init_call(self, ast_node: ast.Call, attr_node: AttributeNode): """ Parse Call node in '__init__' function. """ @@ -250,20 +250,20 @@ class Parser(): return - def parse_init_BinOp(self, ast_node: ast.BinOp): + def _parse_init_binop(self, ast_node: ast.BinOp): """ Parse BinOp node in '__init__' function. """ def _get_value(node: ast.AST): if isinstance(node, ast.Call): value = AttributeNode() - self.parse_init_Call(node, value) + self._parse_init_call(node, value) elif isinstance(node, ast.BinOp): - value = self.parse_init_BinOp(node) + value = self._parse_init_binop(node) else: visitor = self._get_node_visitor(node) if not visitor: - logger.warning("get node visitor failed in parse_init_BinOp") + logger.warning("get node visitor failed in parse_init_binop") return None value = visitor(node) if value in self._default_values.keys(): @@ -281,7 +281,7 @@ class Parser(): logger.debug("left value: %r", left_value) logger.debug("right value: %r", right_value) if (isinstance(left_value, int) and isinstance(right_value, int)) or (str.isdigit(str(left_value)) and str.isdigit(str(right_value))): - method = '_calc_' + op.__class__.__name__ + method = '_calc_' + op.__class__.__name__.lower() calc = getattr(self, method, None) if calc: result = calc(left_value, right_value) @@ -293,7 +293,7 @@ class Parser(): logger.debug("parse init BinOp end, result: %r", result) return result - def parse_Assign(self, node: ast.Assign): + def parse_assign(self, node: ast.Assign): """ Parse Assign node in ast. """ @@ -316,7 +316,7 @@ class Parser(): logger.debug(f"parse assign node end, nodes: {nodes}; called object names: {called_obj_names}") return nodes, called_obj_names - def parse_Call(self, ast_node: ast.Call, node: Node): + def parse_call(self, ast_node: ast.Call, node: Node): """ Parse Call node in ast. """ @@ -340,21 +340,18 @@ class Parser(): return nodes, called_obj_names - def parse_Attribute(self, node: ast.Attribute): + def parse_attribute(self, node: ast.Attribute): """ Parse Attribute node in ast. """ visitor = self._get_node_visitor(node.value) if not visitor: - logger.warning("get node visitor failed in parse_Attribute") + logger.warning("get node visitor failed in parse_attribute") attribute_value = visitor(node.value) + "." + node.attr return attribute_value - def parse_list(self, node: ast.List) -> list: - """ - Parse list. - """ + """def parse_list(self, node: ast.List) -> list: res = [] for n in node: visitor = self._get_node_visitor(n) @@ -364,9 +361,9 @@ class Parser(): else: res.append(value) - return res + return res""" - def parse_List(self, node: ast.List) -> list: + def parse_list(self, node: ast.List) -> list: """ Parse list. 
""" @@ -381,7 +378,7 @@ class Parser(): return res - def parse_Tuple(self, node: ast.Tuple) -> list: + def parse_tuple(self, node: ast.Tuple) -> list: """ Parse Tuple. """ @@ -393,17 +390,17 @@ class Parser(): return res - def parse_Expr(self, node: ast.expr): + def parse_expr(self, node: ast.expr): pass - def parse_BinOp(self, ast_node: ast.BinOp, node: Node): #如果left和right都是Call则需要分别创建节点,同时分析call的args,根据args也创建对应节点 + def parse_binop(self, ast_node: ast.BinOp, node: Node): #如果left和right都是Call则需要分别创建节点,同时分析call的args,根据args也创建对应节点 """ Parse BinOp node in ast. """ def _get_value(ast_node: ast.AST, args, nodes, called_obj_names, side): visitor = self._get_node_visitor(ast_node) if not visitor: - logger.warning("get node visitor failed in parse_BinOp._get_value, node: %r", ast_node) + logger.warning("get node visitor failed in parse_binop._get_value, node: %r", ast_node) return if isinstance(ast_node, ast.Call) or isinstance(ast_node, ast.BinOp): new_node = Node(args=[], targets=[]) @@ -438,40 +435,40 @@ class Parser(): called_obj_names.append(node.name) return nodes, called_obj_names - def parse_BoolOp(self, node: ast.BoolOp): + def parse_boolop(self, node: ast.BoolOp): pass - def parse_UnaryOp(self, node: ast.UnaryOp): + def parse_unaryop(self, node: ast.UnaryOp): pass - def parse_Lambda(self, node: ast.Lambda): + def parse_lambda(self, node: ast.Lambda): pass - def parse_IfExp(self, node: ast.IfExp): + def parse_ifexp(self, node: ast.IfExp): pass - def parse_Dict(self, node: ast.Dict): + def parse_dict(self, node: ast.Dict): pass - def parse_Set(self, node: ast.Set): + def parse_set(self, node: ast.Set): pass - def parse_Slice(self, node: ast.Slice): + def parse_slice(self, node: ast.Slice): pass - def parse_Name(self, node: ast.Name) -> str: + def parse_name(self, node: ast.Name) -> str: """ Parse Name node in ast """ return node.id - def parse_Num(self, node: ast.Num) -> str: + def parse_num(self, node: ast.Num) -> str: """ Parse Num node in ast. """ return node.n - def parse_Constant(self, node: ast.Constant) -> str: + def parse_constant(self, node: ast.Constant) -> str: """ Parse Constant node in ast. """ @@ -486,19 +483,19 @@ class Parser(): value = visitor(node.value) return {key: value} - def parse_NameConstant(self, node: ast.NameConstant): + def parse_nameconstant(self, node: ast.NameConstant): """ Parse NameConstant node in ast. """ return node.value - def parse_Str(self, node: ast.Str): + def parse_str(self, node: ast.Str): """ Parse Str node in ast. """ return node.s - def parse_AugAssign(self, ast_node: ast.AugAssign): + def parse_augassign(self, ast_node: ast.AugAssign): """ Parse AugAssign node in ast. 
""" @@ -511,7 +508,7 @@ class Parser(): if isinstance(ast_node.value, ast.Call): new_node = Node(args=[], targets=[]) - nodes_, called_obj_names_ = self.parse_Call(ast_node.value, new_node) + nodes_, called_obj_names_ = self.parse_call(ast_node.value, new_node) new_node._targets.append("tmp") nodes.extend(nodes_) called_obj_names.extend(called_obj_names_) @@ -520,7 +517,7 @@ class Parser(): elif isinstance(ast_node.value, ast.Name): augAssign_node._args.append(ast_node.value.id) elif isinstance(ast_node.value, ast.Attribute): - augAssign_node.name = self.parse_Attribute(ast_node.value) + augAssign_node.name = self.parse_attribute(ast_node.value) augAssign_node._args.append(augAssign_node.name) elif isinstance(ast_node.value, ast.Num): new_node = ConstantNode(name="constant" + str(ast_node.value.n), value=ast_node.value.n) @@ -595,7 +592,7 @@ class Parser(): self._default_values = arg_with_default_value return arg_with_default_value - def parse_Return(self, node: ast.Return): + def parse_return(self, node: ast.Return): """ Parse Return node in ast. """ @@ -617,7 +614,7 @@ class Parser(): logger.debug("create return node end: %r", new_node) return nodes, called_obj_names - def parse_If(self, node: ast.If): + def parse_if(self, node: ast.If): """ Parse If node in ast. """ @@ -629,9 +626,9 @@ class Parser(): return nodes, called_obj_names - def parse_While(self, node: ast.While): + def parse_while(self, node: ast.While): pass - def parse_For(self, node: ast.For): + def parse_for(self, node: ast.For): pass \ No newline at end of file diff --git a/mindspore/python/mindspore/rewrite/rewriter.py b/mindspore/python/mindspore/rewrite/rewriter.py index 1f4ff920d6e..9bef6ac94c7 100644 --- a/mindspore/python/mindspore/rewrite/rewriter.py +++ b/mindspore/python/mindspore/rewrite/rewriter.py @@ -16,7 +16,7 @@ def parse(network: Union[nn.Cell, Primitive]) -> Graph: graph.print_ast() graph.get_function_root() graph.parse_init() - graph.parse_attr_subgraph() + graph.parse_init_subgraph() graph.parse_functions() graph.parse_construct() elif isinstance(network, FunctionType): -- Gitee From 7b276e57646365bff9440e11c5a7a37ef21c1fb8 Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Tue, 21 Dec 2021 11:12:17 +0800 Subject: [PATCH 10/34] updata rewrite for method names --- mindspore/python/mindspore/rewrite/graph.py | 14 +++-- mindspore/python/mindspore/rewrite/parser.py | 65 ++++++++++---------- 2 files changed, 42 insertions(+), 37 deletions(-) diff --git a/mindspore/python/mindspore/rewrite/graph.py b/mindspore/python/mindspore/rewrite/graph.py index 5e2098bc5b7..5067d724525 100644 --- a/mindspore/python/mindspore/rewrite/graph.py +++ b/mindspore/python/mindspore/rewrite/graph.py @@ -181,7 +181,7 @@ class Graph(): continue #method = 'parse_' + ast_node.__class__.__name__ #visitor = getattr(self._parser, method, None) - visitor = self._parser._get_node_visitor(ast_node) + visitor = self._parser.get_node_visitor(ast_node) if not visitor: logger.warning("Get node visitor failed in parse_construct, node: %r", ast_node) continue @@ -253,9 +253,11 @@ class Graph(): logger.debug(f"process ast node: {ast_node}") if isinstance(ast_node, ast.Expr): continue - method = 'parse_' + ast_node.__class__.__name__ - visitor = getattr(self._parser, method, None) - + #method = 'parse_' + ast_node.__class__.__name__ + #visitor = getattr(self._parser, method, None) + visitor = self._parser._get_node_visitor(ast_node) + if not visitor: + logger.warning(f"get node visitor failed in parse_function, node: {ast_node}") nodes, 
attribute_names = visitor(ast_node) for i in range(len(nodes)): nodes[i]._index = index @@ -375,8 +377,8 @@ class Graph(): Replace src_nodes in 'nodes' by dst_node, modify ast synchronously. Args: - src_nodes: Nodes to be replaced. - dst_node: Node used to replace. + src_nodes: Nodes to be replaced. + dst_node: Node used to replace. """ if isinstance(src_nodes, list): # redirect edges diff --git a/mindspore/python/mindspore/rewrite/parser.py b/mindspore/python/mindspore/rewrite/parser.py index 7ae48132b97..f6bfe24494e 100644 --- a/mindspore/python/mindspore/rewrite/parser.py +++ b/mindspore/python/mindspore/rewrite/parser.py @@ -41,7 +41,7 @@ class Parser(): """ self.closure_namespace = ClosureNamespace(inspect.unwrap(fn)) - def _get_node_visitor(self, node: ast.AST): + def get_node_visitor(self, node: ast.AST): method = 'parse_' + node.__class__.__name__.lower() visitor = getattr(self, method, None) return visitor @@ -74,7 +74,7 @@ class Parser(): """ Parse targets of ast node. """ - visitor = self._get_node_visitor(node) + visitor = self.get_node_visitor(node) if not visitor: logger.warning("get node visiter failed, node: %r", node) return None @@ -121,7 +121,7 @@ class Parser(): Parse Assign node in '__init__' function. """ lineno = node.lineno - visitor = self._get_node_visitor(node.targets) + visitor = self.get_node_visitor(node.targets) targets = visitor(node.targets) logger.debug(f"start parse node in __init__ function: {node}") @@ -157,7 +157,7 @@ class Parser(): value = AttributeNode() self._parse_init_call(node, value) else: - visitor = self._get_node_visitor(node) + visitor = self.get_node_visitor(node) value = visitor(node) if value in self._default_values.keys(): value = self._default_values[value] @@ -183,7 +183,7 @@ class Parser(): value = AttributeNode() self._parse_init_call(node.value, value) else: - visitor = self._get_node_visitor(node.value) + visitor = self.get_node_visitor(node.value) value = visitor(node.value) if value in self._default_values.keys(): value = self._default_values[value] @@ -227,7 +227,7 @@ class Parser(): new_dict = OrderedDict() - visitor = self._get_node_visitor(ast_node.func) + visitor = self.get_node_visitor(ast_node.func) value = visitor(ast_node.func) class_name = value.split(".")[-1] class_, name_space, is_custom_define = self.get_func_namesapce(class_name) @@ -261,7 +261,7 @@ class Parser(): elif isinstance(node, ast.BinOp): value = self._parse_init_binop(node) else: - visitor = self._get_node_visitor(node) + visitor = self.get_node_visitor(node) if not visitor: logger.warning("get node visitor failed in parse_init_binop") return None @@ -301,12 +301,12 @@ class Parser(): lineno = node.lineno nodes = [] called_obj_names = [] - visitor = self._get_node_visitor(node.targets) + visitor = self.get_node_visitor(node.targets) targets = visitor(node.targets) value = node.value new_node = Node(targets=targets, ast_node=node) - visitor = self._get_node_visitor(value) + visitor = self.get_node_visitor(value) nodes_, called_obj_names_ = visitor(value, new_node) nodes.extend(nodes_) @@ -322,13 +322,13 @@ class Parser(): """ nodes = [] called_obj_names = [] - visitor = self._get_node_visitor(ast_node.func) + visitor = self.get_node_visitor(ast_node.func) called_obj_name = visitor(ast_node.func) node.name = called_obj_name args_, nodes_, called_obj_names_ = self._parse_args(ast_node.args) - visitor = self._get_node_visitor(ast_node.keywords) + visitor = self.get_node_visitor(ast_node.keywords) kwargs_: Dict = visitor(ast_node.keywords) node._args = args_ 
@@ -344,32 +344,20 @@ class Parser(): """ Parse Attribute node in ast. """ - visitor = self._get_node_visitor(node.value) + visitor = self.get_node_visitor(node.value) if not visitor: logger.warning("get node visitor failed in parse_attribute") attribute_value = visitor(node.value) + "." + node.attr return attribute_value - """def parse_list(self, node: ast.List) -> list: - res = [] - for n in node: - visitor = self._get_node_visitor(n) - value = visitor(n) - if isinstance(value, list): - res += value - else: - res.append(value) - - return res""" - def parse_list(self, node: ast.List) -> list: """ Parse list. """ res = [] - for n in node.elts: - visitor = self._get_node_visitor(n) + for n in node: + visitor = self.get_node_visitor(n) value = visitor(n) if isinstance(value, list): res += value @@ -378,13 +366,28 @@ class Parser(): return res + #def parse_list(self, node: ast.List) -> list: + # """ + # Parse list. + # """ + # res = [] + # for n in node.elts: + # visitor = self.get_node_visitor(n) + # value = visitor(n) + # if isinstance(value, list): + # res += value + # else: + # res.append(value) + + # return res + def parse_tuple(self, node: ast.Tuple) -> list: """ Parse Tuple. """ res = [] for n in node.elts: - visitor = self._get_node_visitor(n) + visitor = self.get_node_visitor(n) value = visitor(n) res.append(value) @@ -398,7 +401,7 @@ class Parser(): Parse BinOp node in ast. """ def _get_value(ast_node: ast.AST, args, nodes, called_obj_names, side): - visitor = self._get_node_visitor(ast_node) + visitor = self.get_node_visitor(ast_node) if not visitor: logger.warning("get node visitor failed in parse_binop._get_value, node: %r", ast_node) return @@ -479,7 +482,7 @@ class Parser(): Parse keyword node in ast """ key = node.arg - visitor = self._get_node_visitor(node.value) + visitor = self.get_node_visitor(node.value) value = visitor(node.value) return {key: value} @@ -583,7 +586,7 @@ class Parser(): defaults_ = [] for default in node.defaults: - visitor = self._get_node_visitor(default) + visitor = self.get_node_visitor(default) value = visitor(default) d = Default(default.lineno, default.col_offset, value) defaults_.append(d) @@ -601,7 +604,7 @@ class Parser(): called_obj_names = [] value = node.value - visitor = self._get_node_visitor(value) + visitor = self.get_node_visitor(value) value = visitor(value) new_node = Node(name="return", args=[], ast_node=node) if isinstance(value, list): -- Gitee From 17d2021bfb4ff6484eb31d01aa0a792f89eaea7d Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Tue, 21 Dec 2021 11:52:46 +0800 Subject: [PATCH 11/34] add ut for rewrite and golden_stick --- .../quantization/transformer_test.py | 219 ---------- mindspore/python/mindspore/rewrite/lenet.py | 61 --- .../rewrite/pattern_engine_match_test.py | 191 --------- .../mindspore/rewrite/pattern_engine_test.py | 177 -------- mindspore/python/mindspore/rewrite/ut.sh | 37 +- tests/ut/python/golden_stick/__init__.py | 0 .../ut/python/golden_stick}/test_common.py | 56 ++- .../python/golden_stick/test_transformer.py | 215 ++++++++++ tests/ut/python/rewrite/__init__.py | 0 .../ut/python/rewrite/test_pattern_engine.py | 377 ++++++++++++++++++ tests/ut/python/runtest.sh | 12 + 11 files changed, 637 insertions(+), 708 deletions(-) delete mode 100644 mindspore/python/mindspore/golden_stick/quantization/transformer_test.py delete mode 100644 mindspore/python/mindspore/rewrite/lenet.py delete mode 100644 mindspore/python/mindspore/rewrite/pattern_engine_match_test.py delete mode 100644 
mindspore/python/mindspore/rewrite/pattern_engine_test.py create mode 100644 tests/ut/python/golden_stick/__init__.py rename {mindspore/python/mindspore/golden_stick/quantization => tests/ut/python/golden_stick}/test_common.py (46%) create mode 100644 tests/ut/python/golden_stick/test_transformer.py create mode 100644 tests/ut/python/rewrite/__init__.py create mode 100644 tests/ut/python/rewrite/test_pattern_engine.py diff --git a/mindspore/python/mindspore/golden_stick/quantization/transformer_test.py b/mindspore/python/mindspore/golden_stick/quantization/transformer_test.py deleted file mode 100644 index ea14f184798..00000000000 --- a/mindspore/python/mindspore/golden_stick/quantization/transformer_test.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""test Transformer.""" - -import unittest -from test_common import TestCommon -from mindspore.golden_stick import Transformer, LayerPolicy -from mindspore.golden_stick.quantization.layer_policy import layer_policy_key -from mindspore.rewrite import Node, Graph, PlaceholderNode -from mindspore.nn import Cell, Conv2d, BatchNorm2d - - -class TransformerTestCase(unittest.TestCase): - @staticmethod - def network(): - placeholder = TestCommon.create_placeholder_layer() - - conv1 = TestCommon.create_conv_layer("conv1", [placeholder]) - placeholder.outputs = [conv1] - - bn1 = TestCommon.create_bn_layer("bn1", [conv1]) - conv1.outputs = [bn1] - - pool1 = TestCommon.create_pool_layer("pool1", [bn1]) - bn1.outputs = [pool1] - - conv2 = TestCommon.create_conv_layer("conv2", [pool1]) - pool1.outputs = [conv2] - - bn2 = TestCommon.create_bn_layer("bn2", [conv2]) - conv2.outputs = [bn2] - - pool2 = TestCommon.create_pool_layer("pool2", [bn2]) - bn2.outputs = [pool2] - - graph = Graph(Cell) - graph.set_root(pool2) - return graph - - @staticmethod - def network_intra_overlapped(): - placeholder = TestCommon.create_placeholder_layer() - - conv1 = TestCommon.create_conv_layer("conv1", [placeholder]) - placeholder.outputs = [conv1] - - conv2 = TestCommon.create_conv_layer("conv2", [conv1]) - conv1.outputs = [conv2] - - conv3 = TestCommon.create_conv_layer("conv3", [conv2]) - conv2.outputs = [conv3] - - bn = TestCommon.create_bn_layer("bn", [conv3]) - conv3.outputs = [bn] - - pool = TestCommon.create_pool_layer("pool", [bn]) - bn.outputs = [pool] - - graph = Graph(Cell) - graph.set_root(pool) - return graph - - @staticmethod - def network_inter_overlapped(): - placeholder = TestCommon.create_placeholder_layer() - - conv1 = TestCommon.create_conv_layer("conv1", [placeholder]) - placeholder.outputs = [conv1] - - conv2 = TestCommon.create_conv_layer("conv2", [conv1]) - conv1.outputs = [conv2] - - bn = TestCommon.create_bn_layer("bn", [conv2]) - conv2.outputs = [bn] - - pool = TestCommon.create_pool_layer("pool", [bn]) - bn.outputs = [pool] - - graph = Graph(Cell) - graph.set_root(pool) - return graph - - 
@staticmethod - def _get_node_inout_fq(node: Node, is_input: bool = True): - policy: LayerPolicy = node.get_attribute(layer_policy_key) - if policy is None: - return None - if is_input: - fq = policy.get_input_quantizer() - else: - fq = policy.get_output_quantizer() - if not is_input: - return fq - if fq is None or not policy.get_input_need_insert_fq(0): - return None - else: - return fq - - @staticmethod - def _get_node_of_graph_inout_fq(graph: Graph, node_index: int, is_input: bool = True): - if node_index >= len(graph._nodes): - return None - node = graph._nodes[node_index] - if node is None: - return None - return TransformerTestCase._get_node_inout_fq(node, is_input) - - def test_apply(self): - transformer: Transformer = Transformer([Conv2d, BatchNorm2d]) - graph: Graph = TransformerTestCase.network() - conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True) - conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False) - bn1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True) - bn1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False) - conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 4, True) - conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 4, False) - bn2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 5, True) - bn2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 5, False) - self.assertNotEqual(conv1_in_fq, None) - self.assertNotEqual(conv1_out_fq, None) - self.assertNotEqual(bn1_in_fq, None) - self.assertNotEqual(bn1_out_fq, None) - self.assertNotEqual(conv2_in_fq, None) - self.assertNotEqual(conv2_out_fq, None) - self.assertNotEqual(bn2_in_fq, None) - self.assertNotEqual(bn2_out_fq, None) - transformer.apply(graph) - conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True) - conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False) - bn1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True) - bn1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False) - conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 4, True) - conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 4, False) - bn2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 5, True) - bn2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 5, False) - self.assertNotEqual(conv1_in_fq, None) - self.assertEqual(conv1_out_fq, None) - self.assertEqual(bn1_in_fq, None) - self.assertNotEqual(bn1_out_fq, None) - self.assertNotEqual(conv2_in_fq, None) - self.assertEqual(conv2_out_fq, None) - self.assertEqual(bn2_in_fq, None) - self.assertNotEqual(bn2_out_fq, None) - - def test_intra_overlap(self): - transformer: Transformer = Transformer([Conv2d, Conv2d]) - graph: Graph = TransformerTestCase.network_intra_overlapped() - conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True) - conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False) - conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True) - conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False) - conv3_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, True) - conv3_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, False) - self.assertNotEqual(conv1_in_fq, None) - self.assertNotEqual(conv1_out_fq, None) - self.assertNotEqual(conv2_in_fq, None) - 
self.assertNotEqual(conv2_out_fq, None) - self.assertNotEqual(conv3_in_fq, None) - self.assertNotEqual(conv3_out_fq, None) - transformer.apply(graph) - conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True) - conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False) - conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True) - conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False) - conv3_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, True) - conv3_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, False) - self.assertNotEqual(conv1_in_fq, None) - self.assertNotEqual(conv1_out_fq, None) - self.assertNotEqual(conv2_in_fq, None) - self.assertEqual(conv2_out_fq, None) - self.assertEqual(conv3_in_fq, None) - self.assertNotEqual(conv3_out_fq, None) - - def test_inter_overlap(self): - transformer1: Transformer = Transformer([Conv2d, Conv2d]) - transformer2: Transformer = Transformer([Conv2d, BatchNorm2d]) - graph: Graph = TransformerTestCase.network_inter_overlapped() - conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True) - conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False) - conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True) - conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False) - bn_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, True) - bn_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, False) - self.assertNotEqual(conv1_in_fq, None) - self.assertNotEqual(conv1_out_fq, None) - self.assertNotEqual(conv2_in_fq, None) - self.assertNotEqual(conv2_out_fq, None) - self.assertNotEqual(bn_in_fq, None) - self.assertNotEqual(bn_out_fq, None) - transformer1.apply(graph) - transformer2.apply(graph) - conv1_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, True) - conv1_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 1, False) - conv2_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, True) - conv2_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 2, False) - bn_in_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, True) - bn_out_fq = TransformerTestCase._get_node_of_graph_inout_fq(graph, 3, False) - self.assertNotEqual(conv1_in_fq, None) - self.assertEqual(conv1_out_fq, None) - self.assertEqual(conv2_in_fq, None) - self.assertNotEqual(conv2_out_fq, None) - self.assertNotEqual(bn_in_fq, None) - self.assertNotEqual(bn_out_fq, None) - - -if __name__ == '__main__': - unittest.main() diff --git a/mindspore/python/mindspore/rewrite/lenet.py b/mindspore/python/mindspore/rewrite/lenet.py deleted file mode 100644 index feb63521aee..00000000000 --- a/mindspore/python/mindspore/rewrite/lenet.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -"""LeNet.""" -import mindspore.nn as nn -from mindspore.common.initializer import Normal - - -class LeNet5(nn.Cell): - """ - Lenet network - - Args: - num_class (int): Number of classes. Default: 10. - num_channel (int): Number of channels. Default: 1. - - Returns: - Tensor, output tensor - Examples: - >>> LeNet(num_class=10) - - """ - - def __init__(self, num_class=10, num_channel=1, include_top=True): - super(LeNet5, self).__init__() - self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') - self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') - self.relu = nn.ReLU() - self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.include_top = include_top - if self.include_top: - self.flatten = nn.Flatten() - self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) - self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) - self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) - - def construct(self, x): - x = self.conv1(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.relu(x) - x = self.max_pool2d(x) - if not self.include_top: - return x - x = self.flatten(x) - x = self.relu(self.fc1(x)) - x = self.relu(self.fc2(x)) - x = self.fc3(x) - return x diff --git a/mindspore/python/mindspore/rewrite/pattern_engine_match_test.py b/mindspore/python/mindspore/rewrite/pattern_engine_match_test.py deleted file mode 100644 index 39fc11892fb..00000000000 --- a/mindspore/python/mindspore/rewrite/pattern_engine_match_test.py +++ /dev/null @@ -1,191 +0,0 @@ -import unittest - -from collections import OrderedDict -from mindspore.rewrite import PatternEngine, Node, PlaceHolderNode, PatternNode -from mindspore.nn.layer import Pad, Conv2d, BatchNorm2d, ReLU, Softmax, MatMul -from mindspore import log as logger - - -class PatternEngineMatchTestCase(unittest.TestCase): - def test_merge_ordered_dict(self): - dict1: OrderedDict = OrderedDict({'a': 1, 'b': 2, 'c': 3}) - dict2: OrderedDict = OrderedDict({'d': 4, 'e': 5, 'f': 6}) - ret = PatternEngine._merge_ordered_dict(dict1, dict2) - index = 1 - for key in ret: - self.assertEqual(ret.get(key), index) - index += 1 - - @staticmethod - def chain_network(): - pad = Node() - pad.type = Pad - pad.name = "pad" - conv = Node() - conv.type = Conv2d - conv.name = "conv" - bn = Node() - bn.type = BatchNorm2d - bn.name = "bn" - relu = Node() - relu.type = ReLU - relu.name = "relu" - softmax = Node() - softmax.type = Softmax - softmax.name = "softmax" - - pad.outputs = [conv] - conv.inputs = [pad] - conv.outputs = [bn] - bn.inputs = [conv] - bn.outputs = [relu] - relu.inputs = [bn] - relu.outputs = [softmax] - softmax.inputs = [relu] - return softmax, relu - - @staticmethod - def tree_network(): - pad = Node() - pad.type = Pad - pad.name = "pad" - conv1 = Node() - conv1.type = Conv2d - conv1.name = "conv1" - bn1 = Node() - bn1.type = BatchNorm2d - bn1.name = "bn1" - conv2 = Node() - conv2.type = Conv2d - conv2.name = "conv2" - bn2 = Node() - bn2.type = BatchNorm2d - bn2.name = "bn2" - matmul = Node() - matmul.type = MatMul - matmul.name = "matmul" - softmax = Node() - softmax.type = Softmax - softmax.name = "softmax" - - pad.outputs = [conv1, conv2] - conv1.inputs = [pad] - conv1.outputs = [bn1] - bn1.inputs = [conv1] - bn1.outputs = [matmul] - conv2.inputs = [pad] - conv2.outputs = [bn2] - bn2.inputs = [conv2] - bn2.outputs = [matmul] - matmul.inputs = [bn1, bn2] - matmul.outputs = [softmax] - softmax.inputs = [matmul] - return softmax, matmul - - 
@staticmethod - def chain_network_for_leak_pattern(): - pad = Node() - pad.type = Pad - pad.name = "pad" - conv = Node() - conv.type = Conv2d - conv.name = "conv" - bn = Node() - bn.type = BatchNorm2d - bn.name = "bn" - relu = Node() - relu.type = ReLU - relu.name = "relu" - relu2 = Node() - relu2.type = ReLU - relu2.name = "relu2" - matmul = Node() - matmul.type = MatMul - matmul.name = "matmul" - - pad.outputs = [conv] - conv.inputs = [pad] - conv.outputs = [bn, relu2] - bn.inputs = [conv] - relu2.inputs = [conv] - return bn - - def test_chain_pattern(self): - bad_root, good_root = self.chain_network() - # define pattern - p_conv = PatternNode("p_conv", Conv2d) - p_bn = PatternNode("p_bn", BatchNorm2d, [p_conv]) - p_relu = PatternNode("p_relu", ReLU, [p_bn]) - # define pattern engine - pattern_engine = PatternEngine(p_relu) - match, match_dict = pattern_engine._match(pattern_engine.pattern(), bad_root) - logger.info("*****Chain match softmax result: %s", match) - self.assertEqual(match, False) - match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) - logger.info("*****Chain match matmul result: %s", match) - self.assertEqual(match, True) - self.assertEqual(len(match_dict), 3) - - def test_chain_pattern_from_list(self): - bad_root, good_root = self.chain_network() - # define pattern engine - pattern_engine = PatternEngine([Conv2d, BatchNorm2d, ReLU]) - match, match_dict = pattern_engine._match(pattern_engine.pattern(), bad_root) - logger.info("*****List chain match softmax result: %s", match) - self.assertEqual(match, False) - match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) - logger.info("*****List chain match matmul result: %s", match) - self.assertEqual(match, True) - self.assertEqual(len(match_dict), 3) - - def test_tree_pattern(self): - bad_root, good_root = self.tree_network() - # define pattern - p_placeholder = PlaceHolderNode() - p_conv1 = PatternNode("p_conv1", Conv2d, [p_placeholder]) - p_bn1 = PatternNode("p_bn1", BatchNorm2d, [p_conv1]) - p_conv2 = PatternNode("p_conv2", Conv2d, [p_placeholder]) - p_bn2 = PatternNode("p_bn2", BatchNorm2d, [p_conv2]) - p_matmul = PatternNode("p_matmul", MatMul, [p_bn1, p_bn2]) - # define pattern engine - pattern_engine = PatternEngine(p_matmul) - match, match_dict = pattern_engine._match(pattern_engine.pattern(), bad_root) - logger.info("*****Tree match softmax result: %s", match) - self.assertEqual(match, False) - match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) - logger.info("*****Tree match matmul result: %s", match) - self.assertEqual(match, True) - self.assertEqual(len(match_dict), 5) - - def test_placeholder_pattern(self): - _, good_root = self.tree_network() - # define pattern - p_placeholder = PlaceHolderNode() - p_conv2 = PatternNode("p_conv2", Conv2d, [p_placeholder]) - p_bn2 = PatternNode("p_bn2", BatchNorm2d, [p_conv2]) - p_matmul = PatternNode("p_matmul", MatMul, [p_bn2]) - # define pattern engine - pattern_engine = PatternEngine(p_matmul) - match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) - logger.info("*****Tree no-placehold match matmul result: %s", match) - self.assertEqual(match, False) - - p_matmul.set_inputs([p_placeholder, p_bn2]) - pattern_engine = PatternEngine(p_matmul) - match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root) - logger.info("*****Tree placehold match matmul result: %s", match) - self.assertEqual(match, True) - self.assertEqual(len(match_dict), 3) - - def 
test_chain_leak_pattern(self): - root = self.chain_network_for_leak_pattern() - # define pattern - pattern_engine = PatternEngine([Conv2d, BatchNorm2d]) - match, match_dict = pattern_engine._match(pattern_engine.pattern(), root) - self.assertEqual(match, True) - match = pattern_engine._check_match(pattern_engine.pattern(), match_dict) - self.assertEqual(match, False) - - -if __name__ == '__main__': - unittest.main() diff --git a/mindspore/python/mindspore/rewrite/pattern_engine_test.py b/mindspore/python/mindspore/rewrite/pattern_engine_test.py deleted file mode 100644 index 8f7cc55e759..00000000000 --- a/mindspore/python/mindspore/rewrite/pattern_engine_test.py +++ /dev/null @@ -1,177 +0,0 @@ -import unittest - -from mindspore.rewrite import PatternEngine, Node, Graph, NodeType -from lenet import LeNet5 -from mindspore.nn import Cell -from mindspore.nn.layer import Conv2d, BatchNorm2d, Dense, MaxPool2d, Flatten, ReLU -from mindspore import log as logger - - -class PatternEngineTestCase(unittest.TestCase): - def lenet(self): - conv1 = Node() - conv1.type = Conv2d - conv1.name = "conv1" - - bn1 = Node() - bn1.type = BatchNorm2d - bn1.name = "bn1" - conv1.outputs = [bn1] - bn1.inputs = [conv1] - - pool1 = Node() - pool1.type = MaxPool2d - pool1.name = "pool1" - bn1.outputs = [pool1] - pool1.inputs = [bn1] - - conv2 = Node() - conv2.type = Conv2d - conv2.name = "conv2" - pool1.outputs = [conv2] - conv2.inputs = [pool1] - - bn2 = Node() - bn2.type = BatchNorm2d - bn2.name = "bn2" - conv2.outputs = [bn2] - bn2.inputs = [conv2] - - pool2 = Node() - pool2.type = MaxPool2d - pool2.name = "pool2" - bn2.outputs = [pool2] - pool2.inputs = [bn2] - - flatten = Node() - flatten.type = Flatten - flatten.name = "flatten" - pool2.outputs = [flatten] - flatten.inputs = [pool2] - - fc1 = Node() - fc1.type = Dense - fc1.name = "dense1" - flatten.outputs = [fc1] - fc1.inputs = [flatten] - - fc2 = Node() - fc2.type = Dense - fc2.name = "dense2" - fc1.outputs = [fc2] - fc2.inputs = [fc1] - - fc3 = Node() - fc3.type = Dense - fc3.name = "dense3" - fc2.outputs = [fc3] - fc3.inputs = [fc2] - graph = Graph(Cell) - graph.set_root(fc3) - return graph - - @staticmethod - def get_nodes_count(root: Node, to_print: bool = False): - count = 1 - if to_print: - logger.debug("Visit %s", root.name) - for input in root.inputs: - count += PatternEngineTestCase.get_nodes_count(input, to_print) - return count - - def test_pattern(self): - class ConvBn(Cell): - def __init__(self, conv, bn): - super(ConvBn, self).__init__() - self._conv = conv - self._bn = bn - - def construct(self, x): - x = self._conv(x) - return self._bn(x) - - class ConvBnPatternEngine(PatternEngine): - def __init__(self): - super().__init__([Conv2d, BatchNorm2d], ConvBn) - - lenet = self.lenet() - self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet.root()), 10) - pattern_engine = ConvBnPatternEngine() - pattern_engine.apply(lenet) - self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet.root(), True), 8) - - def test_lenet(self): - lenet = LeNet5(num_class=10) - lenet_graph = Graph(LeNet5) - lenet_graph.create_ast() - lenet_graph.get_function_root() - lenet_graph.parse_init() - lenet_graph.parse_construct() - lenet_graph.set_root(lenet_graph.nodes[-1]) - origin_lenet_nn = 14 - self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet_graph.root()), origin_lenet_nn) - # test insert - pre_node = None - post_node = None - for node in lenet_graph.nodes: - if len(node.inputs) == 0: - continue - if node.type is Conv2d and len(node.outputs) == 1 and 
node.outputs[0].type is ReLU: - pre_node = node - post_node = node.outputs[0] - self.assertNotEqual(pre_node, None) - self.assertNotEqual(post_node, None) - conv_cell = Conv2d(16, 16, 3) - conv_node = Node(name="conv2d_3", instance=conv_cell) - conv_node.inputs = [pre_node] - conv_node.outputs = [post_node] - lenet_graph.insert_node(conv_node) - self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet_graph.root(), True), origin_lenet_nn + 1) - self.assertEqual(len(pre_node.outputs), 1) - self.assertEqual(pre_node.outputs[0], conv_node) - self.assertEqual(len(post_node.inputs), 1) - self.assertEqual(post_node.inputs[0], conv_node) - self.assertEqual(len(conv_node.inputs), 1) - self.assertEqual(conv_node.inputs[0], pre_node) - self.assertEqual(len(conv_node.outputs), 1) - self.assertEqual(conv_node.outputs[0], post_node) - - # test replace - - class ConvReLU(Cell): - def __init__(self, conv, relu): - super(ConvReLU, self).__init__() - self._conv = conv - self._relu = relu - - def construct(self, x): - x = self._conv(x) - return self._relu(x) - - class ConvReLUPatternEngine(PatternEngine): - def __init__(self): - super().__init__([Conv2d, ReLU], ConvReLU) - - pattern_engine = ConvReLUPatternEngine() - pattern_engine.apply(lenet_graph) - self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet_graph.root(), True), origin_lenet_nn + 1 - 2) - # test remove - for node in lenet_graph.nodes: - if node.node_type() == NodeType.call_cell and node.type is ReLU: - lenet_graph.remove_node(node) - self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet_graph.root(), True), origin_lenet_nn + 1 - 2 - 2) - - def remove_cell(*args, **kwargs): - return None - - class RemovePatternEngine(PatternEngine): - def __init__(self): - super().__init__([MaxPool2d], remove_cell) - - pattern_engine = RemovePatternEngine() - pattern_engine.apply(lenet_graph) - self.assertEqual(PatternEngineTestCase.get_nodes_count(lenet_graph.root(), True), origin_lenet_nn + 1 - 2 - 2 - 2) - - -if __name__ == '__main__': - unittest.main() diff --git a/mindspore/python/mindspore/rewrite/ut.sh b/mindspore/python/mindspore/rewrite/ut.sh index 637df70cf80..b5ec4d66d62 100644 --- a/mindspore/python/mindspore/rewrite/ut.sh +++ b/mindspore/python/mindspore/rewrite/ut.sh @@ -1,34 +1,9 @@ #!/bin/bash -success=0 -failure=0 -python pattern_engine_match_test.py -if [[ $? -ne 0 ]]; then - echo "---------------- pattern_engine_match_test failed" - ((failure=failure+1)) -else - echo "---------------- pattern_engine_match_test succeed" - ((success=success+1)) -fi +cd ../../../../tests/ut/python/rewrite/ || exit +pytest +cd - -python pattern_engine_test.py -if [[ $? -ne 0 ]]; then - echo "---------------- pattern_engine_test failed" - ((failure=failure+1)) -else - echo "---------------- pattern_engine_test succeed" - ((success=success+1)) -fi - -cd ../golden_stick/quantization/ -python transformer_test.py -if [[ $? 
-ne 0 ]]; then - echo "---------------- transformer_test failed" - ((failure=failure+1)) -else - echo "---------------- transformer_test succeed" - ((success=success+1)) -fi - - -echo "=========== rewrite testcases finished, ${success} succeed, ${failure} failed" +cd ../../../../tests/ut/python/golden_stick/ || exit +pytest +cd - diff --git a/tests/ut/python/golden_stick/__init__.py b/tests/ut/python/golden_stick/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mindspore/python/mindspore/golden_stick/quantization/test_common.py b/tests/ut/python/golden_stick/test_common.py similarity index 46% rename from mindspore/python/mindspore/golden_stick/quantization/test_common.py rename to tests/ut/python/golden_stick/test_common.py index c6bb330c882..a4e281be2c4 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/test_common.py +++ b/tests/ut/python/golden_stick/test_common.py @@ -20,32 +20,30 @@ from mindspore.rewrite import Node, PlaceholderNode from mindspore.nn import Conv2d, BatchNorm2d, MaxPool2d -class TestCommon: - @staticmethod - def create_layer_policy(input_num, weight_names: []) -> LayerPolicy: - layer_policy = DefaultLayerPolicy(weight_names, []) - layer_policy.set_input_number(input_num) - return layer_policy - - @staticmethod - def create_conv_layer(name, inputs): - conv = Node(name=name, inputs=inputs, instance=Conv2d(16, 16, 9)) - conv.set_attribute(layer_policy_key, TestCommon.create_layer_policy(1, ["weight"])) - return conv - - @staticmethod - def create_bn_layer(name, inputs): - bn = Node(name=name, inputs=inputs, instance=BatchNorm2d(16)) - bn.set_attribute(layer_policy_key, TestCommon.create_layer_policy(1, ["gamma"])) - return bn - - @staticmethod - def create_pool_layer(name, inputs): - pool = Node(name=name, inputs=inputs, instance=MaxPool2d()) - pool.set_attribute(layer_policy_key, TestCommon.create_layer_policy(1, [])) - return pool - - @staticmethod - def create_placeholder_layer(): - placeholder = PlaceholderNode("placeholder") - return placeholder +def create_layer_policy(input_num, weight_names: []) -> LayerPolicy: + layer_policy = DefaultLayerPolicy(weight_names, []) + layer_policy.set_input_number(input_num) + return layer_policy + + +def create_conv_layer(name, inputs): + conv = Node(name=name, inputs=inputs, instance=Conv2d(16, 16, 9)) + conv.set_attribute(layer_policy_key, create_layer_policy(1, ["weight"])) + return conv + + +def create_bn_layer(name, inputs): + bn = Node(name=name, inputs=inputs, instance=BatchNorm2d(16)) + bn.set_attribute(layer_policy_key, create_layer_policy(1, ["gamma"])) + return bn + + +def create_pool_layer(name, inputs): + pool = Node(name=name, inputs=inputs, instance=MaxPool2d()) + pool.set_attribute(layer_policy_key, create_layer_policy(1, [])) + return pool + + +def create_placeholder_layer(): + placeholder = PlaceholderNode("placeholder") + return placeholder diff --git a/tests/ut/python/golden_stick/test_transformer.py b/tests/ut/python/golden_stick/test_transformer.py new file mode 100644 index 00000000000..37e479238d5 --- /dev/null +++ b/tests/ut/python/golden_stick/test_transformer.py @@ -0,0 +1,215 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test Transformer."""
+
+from test_common import *
+from mindspore.golden_stick import Transformer, LayerPolicy
+from mindspore.golden_stick.quantization.layer_policy import layer_policy_key
+from mindspore.rewrite import Node, Graph, PlaceholderNode
+from mindspore.nn import Cell, Conv2d, BatchNorm2d
+
+
+def create_simple_network():
+    placeholder = create_placeholder_layer()
+
+    conv1 = create_conv_layer("conv1", [placeholder])
+    placeholder.outputs = [conv1]
+
+    bn1 = create_bn_layer("bn1", [conv1])
+    conv1.outputs = [bn1]
+
+    pool1 = create_pool_layer("pool1", [bn1])
+    bn1.outputs = [pool1]
+
+    conv2 = create_conv_layer("conv2", [pool1])
+    pool1.outputs = [conv2]
+
+    bn2 = create_bn_layer("bn2", [conv2])
+    conv2.outputs = [bn2]
+
+    pool2 = create_pool_layer("pool2", [bn2])
+    bn2.outputs = [pool2]
+
+    graph = Graph(Cell)
+    graph.set_root(pool2)
+    return graph
+
+
+def create_network_intra_overlapped():
+    placeholder = create_placeholder_layer()
+
+    conv1 = create_conv_layer("conv1", [placeholder])
+    placeholder.outputs = [conv1]
+
+    conv2 = create_conv_layer("conv2", [conv1])
+    conv1.outputs = [conv2]
+
+    conv3 = create_conv_layer("conv3", [conv2])
+    conv2.outputs = [conv3]
+
+    bn = create_bn_layer("bn", [conv3])
+    conv3.outputs = [bn]
+
+    pool = create_pool_layer("pool", [bn])
+    bn.outputs = [pool]
+
+    graph = Graph(Cell)
+    graph.set_root(pool)
+    return graph
+
+
+def create_network_inter_overlapped():
+    placeholder = create_placeholder_layer()
+
+    conv1 = create_conv_layer("conv1", [placeholder])
+    placeholder.outputs = [conv1]
+
+    conv2 = create_conv_layer("conv2", [conv1])
+    conv1.outputs = [conv2]
+
+    bn = create_bn_layer("bn", [conv2])
+    conv2.outputs = [bn]
+
+    pool = create_pool_layer("pool", [bn])
+    bn.outputs = [pool]
+
+    graph = Graph(Cell)
+    graph.set_root(pool)
+    return graph
+
+
+def get_node_inout_fq(node: Node, is_input: bool = True):
+    policy: LayerPolicy = node.get_attribute(layer_policy_key)
+    if policy is None:
+        return None
+    if is_input:
+        fq = policy.get_input_quantizer()
+    else:
+        fq = policy.get_output_quantizer()
+    if not is_input:
+        return fq
+    if fq is None or not policy.get_input_need_insert_fq(0):
+        return None
+    else:
+        return fq
+
+
+def get_node_of_graph_inout_fq(graph: Graph, node_index: int, is_input: bool = True):
+    if node_index >= len(graph.nodes):
+        return None
+    node = graph.nodes[node_index]
+    if node is None:
+        return None
+    return get_node_inout_fq(node, is_input)
+
+
+def test_apply():
+    transformer: Transformer = Transformer([Conv2d, BatchNorm2d])
+    graph: Graph = create_simple_network()
+    conv1_in_fq = get_node_of_graph_inout_fq(graph, 1, True)
+    conv1_out_fq = get_node_of_graph_inout_fq(graph, 1, False)
+    bn1_in_fq = get_node_of_graph_inout_fq(graph, 2, True)
+    bn1_out_fq = get_node_of_graph_inout_fq(graph, 2, False)
+    conv2_in_fq = get_node_of_graph_inout_fq(graph, 4, True)
+    conv2_out_fq = get_node_of_graph_inout_fq(graph, 4, False)
+    bn2_in_fq = get_node_of_graph_inout_fq(graph, 5, True)
+    bn2_out_fq = get_node_of_graph_inout_fq(graph, 5, False)
+    assert conv1_in_fq is not None
+    assert conv1_out_fq is not None
+    assert bn1_in_fq is not None
+    assert bn1_out_fq is not None
+    assert conv2_in_fq is not None
+    assert conv2_out_fq is not None
+    assert bn2_in_fq is not None
+    assert bn2_out_fq is not None
+    transformer.apply(graph)
+    conv1_in_fq = get_node_of_graph_inout_fq(graph, 1, True)
+    conv1_out_fq = get_node_of_graph_inout_fq(graph, 1, False)
+    bn1_in_fq = get_node_of_graph_inout_fq(graph, 2, True)
+    bn1_out_fq = get_node_of_graph_inout_fq(graph, 2, False)
+    conv2_in_fq = get_node_of_graph_inout_fq(graph, 4, True)
+    conv2_out_fq = get_node_of_graph_inout_fq(graph, 4, False)
+    bn2_in_fq = get_node_of_graph_inout_fq(graph, 5, True)
+    bn2_out_fq = get_node_of_graph_inout_fq(graph, 5, False)
+    assert conv1_in_fq is not None
+    assert conv1_out_fq is None
+    assert bn1_in_fq is None
+    assert bn1_out_fq is not None
+    assert conv2_in_fq is not None
+    assert conv2_out_fq is None
+    assert bn2_in_fq is None
+    assert bn2_out_fq is not None
+
+
+def test_intra_overlap():
+    transformer: Transformer = Transformer([Conv2d, Conv2d])
+    graph: Graph = create_network_intra_overlapped()
+    conv1_in_fq = get_node_of_graph_inout_fq(graph, 1, True)
+    conv1_out_fq = get_node_of_graph_inout_fq(graph, 1, False)
+    conv2_in_fq = get_node_of_graph_inout_fq(graph, 2, True)
+    conv2_out_fq = get_node_of_graph_inout_fq(graph, 2, False)
+    conv3_in_fq = get_node_of_graph_inout_fq(graph, 3, True)
+    conv3_out_fq = get_node_of_graph_inout_fq(graph, 3, False)
+    assert conv1_in_fq is not None
+    assert conv1_out_fq is not None
+    assert conv2_in_fq is not None
+    assert conv2_out_fq is not None
+    assert conv3_in_fq is not None
+    assert conv3_out_fq is not None
+    transformer.apply(graph)
+    conv1_in_fq = get_node_of_graph_inout_fq(graph, 1, True)
+    conv1_out_fq = get_node_of_graph_inout_fq(graph, 1, False)
+    conv2_in_fq = get_node_of_graph_inout_fq(graph, 2, True)
+    conv2_out_fq = get_node_of_graph_inout_fq(graph, 2, False)
+    conv3_in_fq = get_node_of_graph_inout_fq(graph, 3, True)
+    conv3_out_fq = get_node_of_graph_inout_fq(graph, 3, False)
+    assert conv1_in_fq is not None
+    assert conv1_out_fq is not None
+    assert conv2_in_fq is not None
+    assert conv2_out_fq is None
+    assert conv3_in_fq is None
+    assert conv3_out_fq is not None
+
+
+def test_inter_overlap():
+    transformer1: Transformer = Transformer([Conv2d, Conv2d])
+    transformer2: Transformer = Transformer([Conv2d, BatchNorm2d])
+    graph: Graph = create_network_inter_overlapped()
+    conv1_in_fq = get_node_of_graph_inout_fq(graph, 1, True)
+    conv1_out_fq = get_node_of_graph_inout_fq(graph, 1, False)
+    conv2_in_fq = get_node_of_graph_inout_fq(graph, 2, True)
+    conv2_out_fq = get_node_of_graph_inout_fq(graph, 2, False)
+    bn_in_fq = get_node_of_graph_inout_fq(graph, 3, True)
+    bn_out_fq = get_node_of_graph_inout_fq(graph, 3, False)
+    assert conv1_in_fq is not None
+    assert conv1_out_fq is not None
+    assert conv2_in_fq is not None
+    assert conv2_out_fq is not None
+    assert bn_in_fq is not None
+    assert bn_out_fq is not None
+    transformer1.apply(graph)
+    transformer2.apply(graph)
+    conv1_in_fq = get_node_of_graph_inout_fq(graph, 1, True)
+    conv1_out_fq = get_node_of_graph_inout_fq(graph, 1, False)
+    conv2_in_fq = get_node_of_graph_inout_fq(graph, 2, True)
+    conv2_out_fq = get_node_of_graph_inout_fq(graph, 2, False)
+    bn_in_fq = get_node_of_graph_inout_fq(graph, 3, True)
+    bn_out_fq = get_node_of_graph_inout_fq(graph, 3, False)
+    assert conv1_in_fq is not None
+    assert conv1_out_fq is None
+    assert conv2_in_fq is None
+    assert conv2_out_fq is not None
+    assert bn_in_fq is not None
+    assert bn_out_fq is not None
diff --git a/tests/ut/python/rewrite/__init__.py b/tests/ut/python/rewrite/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/ut/python/rewrite/test_pattern_engine.py b/tests/ut/python/rewrite/test_pattern_engine.py
new file mode 100644
index 00000000000..412e7b1a9b6
--- /dev/null
+++ b/tests/ut/python/rewrite/test_pattern_engine.py
@@ -0,0 +1,377 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test_pattern_engine """
+
+from collections import OrderedDict
+from mindspore.rewrite import PatternEngine, Node, Graph, NodeType, PlaceHolderNode, PatternNode
+from lenet import LeNet5
+from mindspore.nn import Cell
+from mindspore.nn.layer import Conv2d, BatchNorm2d, Dense, MaxPool2d, Flatten, ReLU, Pad, Softmax, MatMul
+from mindspore import log as logger
+
+
+def create_lenet_graph():
+    conv1 = Node()
+    conv1.type = Conv2d
+    conv1.name = "conv1"
+
+    bn1 = Node()
+    bn1.type = BatchNorm2d
+    bn1.name = "bn1"
+    conv1.outputs = [bn1]
+    bn1.inputs = [conv1]
+
+    pool1 = Node()
+    pool1.type = MaxPool2d
+    pool1.name = "pool1"
+    bn1.outputs = [pool1]
+    pool1.inputs = [bn1]
+
+    conv2 = Node()
+    conv2.type = Conv2d
+    conv2.name = "conv2"
+    pool1.outputs = [conv2]
+    conv2.inputs = [pool1]
+
+    bn2 = Node()
+    bn2.type = BatchNorm2d
+    bn2.name = "bn2"
+    conv2.outputs = [bn2]
+    bn2.inputs = [conv2]
+
+    pool2 = Node()
+    pool2.type = MaxPool2d
+    pool2.name = "pool2"
+    bn2.outputs = [pool2]
+    pool2.inputs = [bn2]
+
+    flatten = Node()
+    flatten.type = Flatten
+    flatten.name = "flatten"
+    pool2.outputs = [flatten]
+    flatten.inputs = [pool2]
+
+    fc1 = Node()
+    fc1.type = Dense
+    fc1.name = "dense1"
+    flatten.outputs = [fc1]
+    fc1.inputs = [flatten]
+
+    fc2 = Node()
+    fc2.type = Dense
+    fc2.name = "dense2"
+    fc1.outputs = [fc2]
+    fc2.inputs = [fc1]
+
+    fc3 = Node()
+    fc3.type = Dense
+    fc3.name = "dense3"
+    fc2.outputs = [fc3]
+    fc3.inputs = [fc2]
+    graph = Graph(Cell)
+    graph.set_root(fc3)
+    return graph
+
+
+def create_chain_network():
+    pad = Node()
+    pad.type = Pad
+    pad.name = "pad"
+    conv = Node()
+    conv.type = Conv2d
+    conv.name = "conv"
+    bn = Node()
+    bn.type = BatchNorm2d
+    bn.name = "bn"
+    relu = Node()
+    relu.type = ReLU
+    relu.name = "relu"
+    softmax = Node()
+    softmax.type = Softmax
+    softmax.name = "softmax"
+
+    pad.outputs = [conv]
+    conv.inputs = [pad]
+    conv.outputs = [bn]
+    bn.inputs = [conv]
+    bn.outputs = [relu]
+    relu.inputs = [bn]
+    relu.outputs = [softmax]
+    softmax.inputs = [relu]
+    return softmax, relu
+
+
+def create_tree_network():
+    pad = Node()
+    pad.type = Pad
+    pad.name = "pad"
+    conv1 = Node()
+    conv1.type = Conv2d
+    conv1.name = "conv1"
+    bn1 = Node()
+    bn1.type = BatchNorm2d
+    bn1.name = "bn1"
+    conv2 = Node()
+    conv2.type = Conv2d
+    conv2.name = "conv2"
+    bn2 = Node()
+    bn2.type = BatchNorm2d
+    bn2.name = "bn2"
+    matmul = Node()
+    matmul.type = MatMul
+    matmul.name = "matmul"
+    softmax = Node()
+    softmax.type = Softmax
+    softmax.name = "softmax"
+
+    pad.outputs = [conv1, conv2]
+    conv1.inputs = [pad]
+    conv1.outputs = [bn1]
+    bn1.inputs = [conv1]
+    bn1.outputs = [matmul]
+    conv2.inputs = [pad]
+    conv2.outputs = [bn2]
+    bn2.inputs = [conv2]
+    bn2.outputs = [matmul]
+    matmul.inputs = [bn1, bn2]
+    matmul.outputs = [softmax]
+    softmax.inputs = [matmul]
+    return softmax, matmul
+
+
+def create_chain_network_for_leak_pattern():
+    pad = Node()
+    pad.type = Pad
+    pad.name = "pad"
+    conv = Node()
+    conv.type = Conv2d
+    conv.name = "conv"
+    bn = Node()
+    bn.type = BatchNorm2d
+    bn.name = "bn"
+    relu = Node()
+    relu.type = ReLU
+    relu.name = "relu"
+    relu2 = Node()
+    relu2.type = ReLU
+    relu2.name = "relu2"
+    matmul = Node()
+    matmul.type = MatMul
+    matmul.name = "matmul"
+
+    pad.outputs = [conv]
+    conv.inputs = [pad]
+    conv.outputs = [bn, relu2]
+    bn.inputs = [conv]
+    relu2.inputs = [conv]
+    return bn
+
+
+def get_nodes_count(root: Node, to_print: bool = False):
+    count = 1
+    if to_print:
+        logger.debug("Visit %s", root.name)
+    for node_input in root.inputs:
+        count += get_nodes_count(node_input, to_print)
+    return count
+
+
+def test_merge_ordered_dict():
+    dict1: OrderedDict = OrderedDict({'a': 1, 'b': 2, 'c': 3})
+    dict2: OrderedDict = OrderedDict({'d': 4, 'e': 5, 'f': 6})
+    ret = PatternEngine._merge_ordered_dict(dict1, dict2)
+    index = 1
+    for key in ret:
+        assert ret.get(key) == index
+        index += 1
+
+
+def test_chain_pattern():
+    bad_root, good_root = create_chain_network()
+    # define pattern
+    p_conv = PatternNode("p_conv", Conv2d)
+    p_bn = PatternNode("p_bn", BatchNorm2d, [p_conv])
+    p_relu = PatternNode("p_relu", ReLU, [p_bn])
+    # define pattern engine
+    pattern_engine = PatternEngine(p_relu)
+    match, match_dict = pattern_engine._match(pattern_engine.pattern(), bad_root)
+    logger.info("*****Chain match softmax result: %s", match)
+    assert match is False
+    match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root)
+    logger.info("*****Chain match matmul result: %s", match)
+    assert match is True
+    assert len(match_dict) == 3
+
+
+def test_chain_pattern_from_list():
+    bad_root, good_root = create_chain_network()
+    # define pattern engine
+    pattern_engine = PatternEngine([Conv2d, BatchNorm2d, ReLU])
+    match, match_dict = pattern_engine._match(pattern_engine.pattern(), bad_root)
+    logger.info("*****List chain match softmax result: %s", match)
+    assert match is False
+    match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root)
+    logger.info("*****List chain match matmul result: %s", match)
+    assert match is True
+    assert len(match_dict) == 3
+
+
+def test_tree_pattern():
+    bad_root, good_root = create_tree_network()
+    # define pattern
+    p_placeholder = PlaceHolderNode()
+    p_conv1 = PatternNode("p_conv1", Conv2d, [p_placeholder])
+    p_bn1 = PatternNode("p_bn1", BatchNorm2d, [p_conv1])
+    p_conv2 = PatternNode("p_conv2", Conv2d, [p_placeholder])
+    p_bn2 = PatternNode("p_bn2", BatchNorm2d, [p_conv2])
+    p_matmul = PatternNode("p_matmul", MatMul, [p_bn1, p_bn2])
+    # define pattern engine
+    pattern_engine = PatternEngine(p_matmul)
+    match, match_dict = pattern_engine._match(pattern_engine.pattern(), bad_root)
+    logger.info("*****Tree match softmax result: %s", match)
+    assert match is False
+    match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root)
+    logger.info("*****Tree match matmul result: %s", match)
+    assert match is True
+    assert len(match_dict) == 5
+
+
+def test_placeholder_pattern():
+    _, good_root = create_tree_network()
+    # define pattern
+    p_placeholder = PlaceHolderNode()
+    p_conv2 = PatternNode("p_conv2", Conv2d, [p_placeholder])
+    p_bn2 = PatternNode("p_bn2", BatchNorm2d, [p_conv2])
+    p_matmul = PatternNode("p_matmul", MatMul, [p_bn2])
+    # define pattern engine
+    pattern_engine = PatternEngine(p_matmul)
+    match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root)
+    logger.info("*****Tree no-placeholder match matmul result: %s", match)
+    assert match is False
+
+    p_matmul.set_inputs([p_placeholder, p_bn2])
+    pattern_engine = PatternEngine(p_matmul)
+    match, match_dict = pattern_engine._match(pattern_engine.pattern(), good_root)
+    logger.info("*****Tree placeholder match matmul result: %s", match)
+    assert match is True
+    assert len(match_dict) == 3
+
+
+def test_chain_leak_pattern():
+    root = create_chain_network_for_leak_pattern()
+    # define pattern
+    pattern_engine = PatternEngine([Conv2d, BatchNorm2d])
+    match, match_dict = pattern_engine._match(pattern_engine.pattern(), root)
+    assert match is True
+    match = pattern_engine._check_match(pattern_engine.pattern(), match_dict)
+    assert match is False
+
+
+def test_pattern():
+    class ConvBn(Cell):
+        def __init__(self, conv, bn):
+            super(ConvBn, self).__init__()
+            self._conv = conv
+            self._bn = bn
+
+        def construct(self, x):
+            x = self._conv(x)
+            return self._bn(x)
+
+    class ConvBnPatternEngine(PatternEngine):
+        def __init__(self):
+            super().__init__([Conv2d, BatchNorm2d], ConvBn)
+
+    lenet = create_lenet_graph()
+    assert get_nodes_count(lenet.root()) == 10
+    pattern_engine = ConvBnPatternEngine()
+    pattern_engine.apply(lenet)
+    assert get_nodes_count(lenet.root(), True) == 8
+
+
+def test_lenet():
+    lenet_graph = Graph(LeNet5)
+    lenet_graph.create_ast()
+    lenet_graph.get_function_root()
+    lenet_graph.parse_init()
+    lenet_graph.parse_construct()
+    lenet_graph.set_root(lenet_graph.nodes[-1])
+    origin_lenet_nn = 14
+    assert get_nodes_count(lenet_graph.root()) == origin_lenet_nn
+    # test insert
+    pre_node = None
+    post_node = None
+    for node in lenet_graph.nodes:
+        if len(node.inputs) == 0:
+            continue
+        if node.type is Conv2d and len(node.outputs) == 1 and node.outputs[0].type is ReLU:
+            pre_node = node
+            post_node = node.outputs[0]
+    assert pre_node is not None
+    assert post_node is not None
+    conv_cell = Conv2d(16, 16, 3)
+    conv_node = Node(name="conv2d_3", instance=conv_cell)
+    conv_node.inputs = [pre_node]
+    conv_node.outputs = [post_node]
+    lenet_graph.insert_node(conv_node)
+    origin_lenet_nn += 1
+    assert get_nodes_count(lenet_graph.root(), True) == origin_lenet_nn
+    assert len(pre_node.outputs) == 1
+    assert pre_node.outputs[0] == conv_node
+    assert len(post_node.inputs) == 1
+    assert post_node.inputs[0] == conv_node
+    assert len(conv_node.inputs) == 1
+    assert conv_node.inputs[0] == pre_node
+    assert len(conv_node.outputs) == 1
+    assert conv_node.outputs[0] == post_node
+
+    # test replace
+
+    class ConvReLU(Cell):
+        def __init__(self, conv, relu):
+            super(ConvReLU, self).__init__()
+            self._conv = conv
+            self._relu = relu
+
+        def construct(self, x):
+            x = self._conv(x)
+            return self._relu(x)
+
+    class ConvReLUPatternEngine(PatternEngine):
+        def __init__(self):
+            super().__init__([Conv2d, ReLU], ConvReLU)
+
+    pattern_engine = ConvReLUPatternEngine()
+    pattern_engine.apply(lenet_graph)
+    origin_lenet_nn -= 2
+    assert get_nodes_count(lenet_graph.root(), True) == origin_lenet_nn
+    # test remove
+    for node in lenet_graph.nodes:
+        if node.node_type() == NodeType.call_cell and node.type is ReLU:
+            lenet_graph.remove_node(node)
+    origin_lenet_nn -= 2
+    assert get_nodes_count(lenet_graph.root(), True) == origin_lenet_nn
+
+    def remove_cell(*args, **kwargs):
+        return None
+
+    class RemovePatternEngine(PatternEngine):
+        def __init__(self):
+            super().__init__([MaxPool2d], remove_cell)
+
+    pattern_engine = RemovePatternEngine()
+    pattern_engine.apply(lenet_graph)
+    origin_lenet_nn -= 2
+    assert get_nodes_count(lenet_graph.root(), True) == origin_lenet_nn
diff --git a/tests/ut/python/runtest.sh b/tests/ut/python/runtest.sh
index 5cf3f4481a2..24185416bd8 100755
--- a/tests/ut/python/runtest.sh
+++ b/tests/ut/python/runtest.sh
@@ -119,6 +119,18 @@ else
     if [ ${RET} -ne 0 ]; then
         exit ${RET}
     fi
+
+    pytest $CURRPATH/rewrite/*.py
+    RET=$?
+    if [ ${RET} -ne 0 ]; then
+        exit ${RET}
+    fi
+
+    pytest $CURRPATH/golden_stick/*.py
+    RET=$?
+    if [ ${RET} -ne 0 ]; then
+        exit ${RET}
+    fi
 fi
 
 RET=$?
-- 
Gitee
From 13e9bcb4182d32a06ab1909b4763ab5db82115d1 Mon Sep 17 00:00:00 2001
From: yuzhenhua
Date: Wed, 22 Dec 2021 11:11:39 +0800
Subject: [PATCH 12/34] add unparse file

---
 mindspore/python/mindspore/rewrite/unparse.py | 101 ++++++++++++++++++
 1 file changed, 101 insertions(+)
 create mode 100644 mindspore/python/mindspore/rewrite/unparse.py

diff --git a/mindspore/python/mindspore/rewrite/unparse.py b/mindspore/python/mindspore/rewrite/unparse.py
new file mode 100644
index 00000000000..da521471536
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite/unparse.py
@@ -0,0 +1,101 @@
+
+import ast
+import collections
+from typing import Dict
+
+
+class unparse:
+    @staticmethod
+    def create_name(value) -> ast.Name:
+        ast_node = ast.Name(lineno=0, col_offset=0, id=value, ctx=ast.Load())
+
+        return ast_node
+
+    @staticmethod
+    def create_assign(targets=None, value=None):
+        ast_node = ast.Assign(lineno=0, col_offset=0, ctx=ast.Load())
+        ast_node.targets = unparse._create_targets(targets)
+
+        if isinstance(value, ast.AST):
+            ast_node.value = value
+        else:
+            # raising a bare tuple is itself a TypeError; raise a real exception
+            raise TypeError("unsupported type in create_assign, value: ", value)
+
+        return ast_node
+
+    @staticmethod
+    def create_attribute(value, attr):
+        ast_node = ast.Attribute(lineno=0, col_offset=0, ctx=ast.Store())
+        if isinstance(value, str):
+            ast_node.value = unparse.create_name(value)
+        else:
+            raise TypeError("unsupported type")
+        ast_node.attr = attr
+
+        return ast_node
+
+    @staticmethod
+    def create_call(func_name, args, kwargs: Dict):
+        ast_node = ast.Call(lineno=0, col_offset=0, ctx=ast.Load())
+        ast_node.args = args
+        ast_node.keywords = unparse._create_keywords(kwargs)
+
+        if '.' in func_name:
+            value = func_name.split(".")[-2]
+            attr = func_name.split(".")[-1]
+            ast_node.func = unparse.create_attribute(value, attr)
+        else:
+            ast_node.func = unparse.create_name(func_name)
+
+        return ast_node
+
+    @staticmethod
+    def create_num(value):
+        ast_node = ast.Num(lineno=0, col_offset=0, n=value)
+        return ast_node
+
+    @staticmethod
+    def create_keyword(arg, value):
+        ast_node = ast.keyword()
+        ast_node.arg = arg
+        if isinstance(value, int):
+            ast_node.value = unparse.create_num(value)
+        elif isinstance(value, str):
+            ast_node.value = unparse.create_str(value)
+        else:
+            ast_node.value = unparse.create_nameconstant(value)
+        return ast_node
+
+    @staticmethod
+    def create_nameconstant(value):
+        ast_node = ast.NameConstant(lineno=0, col_offset=0, value=value)
+
+        return ast_node
+
+    @staticmethod
+    def _create_targets(targets):
+        targets_ = []
+        for target in targets:
+            if '.' in target:
+                value = target.split(".")[-2]
+                attr = target.split(".")[-1]
+                t = unparse.create_attribute(value, attr)
+            else:
+                t = unparse.create_name(target)
+
+            targets_.append(t)
+
+        return targets_
+
+    @staticmethod
+    def _create_keywords(keywords: Dict):
+        keywords_ = []
+        for key, value in keywords:
+            k = unparse.create_keyword(key, value)
+            keywords_.append(k)
+
+        return keywords_
+
+    @staticmethod
+    def create_str(value):
+        ast_node = ast.Str(lineno=0, col_offset=0, s=value)
+        return ast_node
\ No newline at end of file
-- 
Gitee
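
For context, a minimal usage sketch of the helpers this patch adds (illustrative only, not part of the patch series; it assumes the module is importable as plain `unparse`, the way the unit test introduced in the next patch imports it):

    import astunparse
    from unparse import unparse

    # Build the AST for `x = self.conv1(x)` and render it back to source text.
    # In this first version create_call stores the args list as given, so we
    # pass ready-made AST nodes.
    call = unparse.create_call("self.conv1", [unparse.create_name("x")], {})
    assign = unparse.create_assign(targets=["x"], value=call)
    print(astunparse.unparse(assign))  # -> x = self.conv1(x)
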
From 1f4a4ae3895bcc6c096406ebbac9fd9fce2f1442 Mon Sep 17 00:00:00 2001
From: kevin
Date: Fri, 24 Dec 2021 03:14:05 +0000
Subject: [PATCH 13/34] !17 add ast_unparser to graph

* add unparse funcs
---
 mindspore/python/mindspore/rewrite/graph.py   |   8 +-
 mindspore/python/mindspore/rewrite/unparse.py | 142 +++++++++++++++++-
 .../python/mindspore/rewrite/unparse_test.py  |  51 +++++++
 3 files changed, 192 insertions(+), 9 deletions(-)
 create mode 100644 mindspore/python/mindspore/rewrite/unparse_test.py

diff --git a/mindspore/python/mindspore/rewrite/graph.py b/mindspore/python/mindspore/rewrite/graph.py
index 5067d724525..15c7f316f30 100644
--- a/mindspore/python/mindspore/rewrite/graph.py
+++ b/mindspore/python/mindspore/rewrite/graph.py
@@ -13,6 +13,8 @@ from mindspore.ops.primitive import Primitive
 from .node import AttributeNode, ConstantNode, Node, NodeType, PlaceholderNode
 from .parser import Parser
 
+from mindspore.rewrite.ast_unparser import ASTUnparser
+
 
 class _node_list:
     def __init__(self, graph) -> None:
@@ -457,8 +459,10 @@ class Graph():
         return astunparse.unparse(self._ast_root)
 
     @property
-    def convert_to_cell(self) -> nn.Cell:
-        pass
+    def convert_to_cell(self):
+        unparser = ASTUnparser(self._network, self.python_code)
+        res_class = unparser.get_res_cell()
+        return res_class
 
     def print_graph(self):
         pass
diff --git a/mindspore/python/mindspore/rewrite/unparse.py b/mindspore/python/mindspore/rewrite/unparse.py
index da521471536..84f44e5097b 100644
--- a/mindspore/python/mindspore/rewrite/unparse.py
+++ b/mindspore/python/mindspore/rewrite/unparse.py
@@ -1,8 +1,23 @@
-
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""unparse node to ast node"""
 import ast
 import collections
 from typing import Dict
 
+
 class unparse:
     @staticmethod
     def create_name(value) -> ast.Name:
@@ -28,7 +43,7 @@ class unparse:
         if isinstance(value, str):
             ast_node.value = unparse.create_name(value)
         else:
-            raise TypeError("unsupported type")
+            raise TypeError("unsupported type in create_attribute, type: ", value)
         ast_node.attr = attr
 
         return ast_node
@@ -36,9 +51,9 @@ class unparse:
     @staticmethod
     def create_call(func_name, args, kwargs: Dict):
         ast_node = ast.Call(lineno=0, col_offset=0, ctx=ast.Load())
-        ast_node.args = args
+        ast_node.args = unparse._create_args(args)
         ast_node.keywords = unparse._create_keywords(kwargs)
-
+
         if '.' in func_name:
             value = func_name.split(".")[-2]
             attr = func_name.split(".")[-1]
@@ -55,7 +70,7 @@ class unparse:
 
     @staticmethod
     def create_keyword(arg, value):
-        ast_node = ast.keyword()
+        ast_node = ast.keyword(lineno=0, col_offset=0)
         ast_node.arg = arg
         if isinstance(value, int):
             ast_node.value = unparse.create_num(value)
@@ -89,7 +104,7 @@ class unparse:
     @staticmethod
     def _create_keywords(keywords: Dict):
         keywords_ = []
-        for key, value in keywords:
+        for key, value in keywords.items():
             k = unparse.create_keyword(key, value)
             keywords_.append(k)
 
@@ -98,4 +113,117 @@ class unparse:
     @staticmethod
     def create_str(value):
         ast_node = ast.Str(lineno=0, col_offset=0, s=value)
-        return ast_node
\ No newline at end of file
+        return ast_node
+
+    @staticmethod
+    def _create_args(args: list) -> list:
+        """
+        create args in function
+
+        :param args: input function args list, can only contain int or str or ast.AST
+        :return: args list in function
+        """
+        args_ = []
+        for arg in args:
+            if isinstance(arg, int):
+                args_.append(unparse.create_constant(arg))
+            elif isinstance(arg, str):
+                if '.' in arg:
+                    value = arg.split(".")[-2]
+                    attr = arg.split(".")[-1]
+                    args_.append(unparse.create_attribute(value, attr))
+                else:
+                    args_.append(unparse.create_name(arg))
+            elif isinstance(arg, ast.AST):
+                args_.append(arg)
+            else:
+                raise ValueError("unsupported type when creating args")
+        return args_
+
+    @staticmethod
+    def create_binop(op_name: str, left, right) -> ast.BinOp:
+        """
+        create binop node
+
+        :param op_name: operator name, support Add/Sub/Mult/Div now
+        :param left: left value of binop node, can only be int/str("x" or "self.fun(x)" etc)/ast.AST
+        :param right: right value of binop node, can only be int/str("x" or "self.fun(x)" etc)/ast.AST
+        :return: ast binop node
+        """
+        def get_single_side_value(left_or_right):
+            if isinstance(left_or_right, int):
+                node = unparse.create_constant(left_or_right)
+            elif isinstance(left_or_right, str):
+                # only strings containing '(' are call expressions; a plain
+                # dotted name such as "self.y" has no parentheses to split on
+                if '(' in left_or_right:
+                    func_name = left_or_right.split('(')[0]
+                    all_args_ = left_or_right.split('(')[1].split(')')[0].split(',')
+                    all_args_ = [_.strip() for _ in all_args_]
+
+                    args_ = []
+                    keywords_ = {}
+                    for single_arg in all_args_:
+                        if '=' not in single_arg:
+                            args_.append(single_arg)
+                        else:
+                            key, value = single_arg.split('=')
+                            keywords_[key] = value
+
+                    node = unparse.create_call(func_name, args_, keywords_)
+                elif '.' in left_or_right:
+                    value = left_or_right.split(".")[-2]
+                    attr = left_or_right.split(".")[-1]
+                    node = unparse.create_attribute(value, attr)
+                else:
+                    node = unparse.create_name(left_or_right)
+            elif isinstance(left_or_right, ast.AST):
+                node = left_or_right
+            else:
+                raise ValueError("unsupported value in create binop")
+
+            return node
+
+        ast_node = ast.BinOp(lineno=0, col_offset=0)
+        op_dict = {
+            "Add": ast.Add(),
+            "Sub": ast.Sub(),
+            "Mult": ast.Mult(),
+            "Div": ast.Div()
+        }
+        ast_node.op = op_dict[op_name]
+        ast_node.left = get_single_side_value(left)
+        ast_node.right = get_single_side_value(right)
+        return ast_node
+
+    @staticmethod
+    def create_constant(value) -> ast.Constant:
+        """
+        create ast constant node
+
+        :param value: node value, can only be int/str
+        :return: ast Constant node
+        """
+        if isinstance(value, (int, str)):
+            ast_node = ast.Constant(lineno=0, col_offset=0, s=value)
+        else:
+            raise ValueError("unsupported value in create constant")
+
+        return ast_node
+
+    @staticmethod
+    def create_arguments(arguments_input: list) -> ast.arguments:
+        """
+        create ast arguments node
+
+        :param arguments_input: arguments used when defining functions, can only contain str
+        :return: ast arguments node
+        """
+        ast_node = ast.arguments(defaults=[],)
+        ast_node.args = []
+        for argument in arguments_input:
+            if isinstance(argument, str):
+                ast_node.args.append(
+                    ast.arg(
+                        lineno=0,
+                        col_offset=0,
+                        arg=argument,
+                    ))
+            else:
+                raise ValueError("unsupported value in create_arguments")
+
+        return ast_node
diff --git a/mindspore/python/mindspore/rewrite/unparse_test.py b/mindspore/python/mindspore/rewrite/unparse_test.py
new file mode 100644
index 00000000000..ce749d9ea39
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite/unparse_test.py
@@ -0,0 +1,51 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""unparse test"""
+import unittest
+import ast
+import astunparse
+
+from unparse import unparse
+
+
+class UnparserTestCase(unittest.TestCase):
+    def test_create_binop_1(self):
+        node = unparse.create_binop("Add", "self.my(x)", "y")
+        source_code = astunparse.unparse(node)
+        self.assertEqual(source_code, "(self.my(x) + y)\n")
+
+    def test_create_binop_2(self):
+        node = unparse.create_binop("Add", "x", "y")
+        source_code = astunparse.unparse(node)
+        self.assertEqual(source_code, "(x + y)\n")
+
+    def test_create_args(self):
+        node = unparse.create_call("self.myfun", [1, "x", "self.y"], {"flag": 0})
+        source_code = astunparse.unparse(node)
+        self.assertEqual(source_code, "self.myfun(1, x, self.y, flag=0)\n")
+
+    def test_create_arguments(self):
+        node = ast.FunctionDef(
+            name='myfun',
+            args=unparse.create_arguments(["self", "x", "y", "input_data"]),
+            body=ast.Pass(),
+            decorator_list=[],
+        )
+        source_code = astunparse.unparse(node)
+        self.assertEqual(source_code, "\n\ndef myfun(self, x, y, input_data):\n    pass\n")
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
Gitee
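
A quick sketch of how the extended helpers compose (illustrative only, not part of the patch; import path as in unparse_test.py above). `create_binop` accepts ints, plain strings, or already-built AST nodes on either side, so expressions nest naturally:

    import astunparse
    from unparse import unparse

    # Build ((x + y) / 2) by feeding one BinOp back in as an operand.
    inner = unparse.create_binop("Add", "x", "y")
    outer = unparse.create_binop("Div", inner, 2)
    print(astunparse.unparse(outer))  # -> ((x + y) / 2)
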
From 2f072a72df177c4842b8538491e161647b25500e Mon Sep 17 00:00:00 2001
From: hangangqiang
Date: Thu, 23 Dec 2021 19:11:19 +0800
Subject: [PATCH 14/34] add rewrite-experiment

---
 cmake/package.cmake                           |   1 +
 mindspore/python/mindspore/__init__.py        |   4 +-
 .../mindspore/rewrite_experiment/__init__.py  |   6 +
 .../mindspore/rewrite_experiment/compiler.py  |  75 ++++++
 .../mindspore/rewrite_experiment/graph.py     |  55 +++++
 .../mindspore/rewrite_experiment/node.py      | 220 ++++++++++++++++++
 .../mindspore/rewrite_experiment/observer.py  |  22 ++
 .../rewrite_experiment/pass_register.py       |  39 ++++
 .../rewrite_experiment/passes/__init__.py     |   0
 .../passes/class_def_pass.py                  |  31 +++
 .../rewrite_experiment/passes/module_pass.py  |  43 ++++
 .../mindspore/rewrite_experiment/passs.py     |  38 +++
 .../mindspore/rewrite_experiment/subject.py   |  22 ++
 .../mindspore/rewrite_experiment/test.py      |  73 ++++++
 14 files changed, 628 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/__init__.py
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/compiler.py
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/graph.py
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/node.py
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/observer.py
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/pass_register.py
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/passes/__init__.py
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/passs.py
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/subject.py
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/test.py

diff --git a/cmake/package.cmake b/cmake/package.cmake
index c4c4d1aa320..51382f65840 100644
--- a/cmake/package.cmake
+++ b/cmake/package.cmake
@@ -320,6 +320,7 @@ install(
         ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/communication
         ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/profiler
         ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite
+        ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite_experiment
         ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/golden_stick
         ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check
         DESTINATION ${INSTALL_PY_DIR}
diff --git a/mindspore/python/mindspore/__init__.py b/mindspore/python/mindspore/__init__.py
index ad9abce3011..8ff8f866d00 100755
--- a/mindspore/python/mindspore/__init__.py
+++ b/mindspore/python/mindspore/__init__.py
@@ -21,11 +21,12 @@ from .common import *
 from .mindrecord import *
 from .ops import _op_impl
 from .train import *
-from .rewrite import *
 from .log import *
 from .context import *
 from .version import __version__
 from .golden_stick import *
+from .rewrite import *
+from .rewrite_experiment import *
 
 __all__ = ["run_check"]
 
@@ -35,4 +36,5 @@ __all__.extend(train.__all__)
 __all__.extend(log.__all__)
 __all__.extend(context.__all__)
 __all__.extend(rewrite.__all__)
+__all__.extend(rewrite_experiment.__all__)
 __all__.extend(golden_stick.__all__)
diff --git a/mindspore/python/mindspore/rewrite_experiment/__init__.py b/mindspore/python/mindspore/rewrite_experiment/__init__.py
new file mode 100644
index 00000000000..b17e28bf416
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/__init__.py
@@ -0,0 +1,6 @@
+from .graph import Graph
+from .compiler import Compiler
+from .passes.module_pass import ModulePass
+from .passes.class_def_pass import ClassDefPass
+
+__all__ = ["Graph", "Compiler"]
diff --git a/mindspore/python/mindspore/rewrite_experiment/compiler.py b/mindspore/python/mindspore/rewrite_experiment/compiler.py
new file mode 100644
index 00000000000..db43acf6485
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/compiler.py
@@ -0,0 +1,75 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from typing import Union
+
+from .graph import Graph
+from .node import Node, NodeType
+from .passs import Pass
+from .pass_register import PassRegister
+from mindspore.nn import Cell
+import mindspore.nn as nn
+from mindspore.ops.primitive import Primitive
+from types import FunctionType
+from mindspore import log as logger
+
+
+class Compiler:
+    @staticmethod
+    def _is_leaf_node(node: Node):
+        # inseparable python node:
+        if node.node_type() is NodeType.constant:
+            return False
+        # not supported yet:
+        # if node.node_type() is NodeType.graph:
+        #     return False
+        # mindspore ops:
+        ms_ops: tuple = (Cell,)
+        if node.node_type() is NodeType.call_cell and issubclass(node.class_type(), ms_ops):
+            return False
+        return True
+
+    @staticmethod
+    def compile(network: Union[nn.Cell, Primitive, FunctionType]):
+        graph = Graph(network)
+        passes: [Pass] = PassRegister.instance().get_passes()
+        logger.warning("Load passes: %d", len(passes))
+        while True:
+            changed = False
+            for node in graph.nodes:  # todo: how to replace while iterating
+                if not Compiler._is_leaf_node(node):
+                    continue
+                for key in passes:
+                    pass_ = passes[key]
+                    result: Node = pass_.process(node)
+                    if result is node:
+                        continue
+                    logger.warning("Changed: new node: %s", result.get_name())
+                    changed = True
+                    result.set_targets(node.get_targets())
+                    result.set_outputs(node.get_outputs())
+                    users = node.get_outputs()
+                    for user in users:
+                        new_inputs = []
+                        for user_input in user.get_inputs():
+                            if user_input is node:
+                                new_inputs.append(result)
+                            else:
+                                new_inputs.append(user_input)
+                        user.set_inputs(new_inputs)
+                    if graph.get_return() is node:
+                        graph.set_return(result)
+            if not changed:
+                break
+        return graph
diff --git a/mindspore/python/mindspore/rewrite_experiment/graph.py b/mindspore/python/mindspore/rewrite_experiment/graph.py
new file mode 100644
index 00000000000..d745d853d6c
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/graph.py
@@ -0,0 +1,55 @@
+import inspect
+from types import FunctionType
+from typing import Union
+import ast
+import astpretty
+
+import mindspore.nn as nn
+from mindspore import log as logger
+from mindspore.ops.primitive import Primitive
+
+from .node import Node
+from .observer import Observer
+
+
+class Graph(Observer):
+    def update(self):
+        pass
+
+    def __init__(self, network: Union[nn.Cell, Primitive, FunctionType]):
+        if not isinstance(network, nn.Cell):
+            logger.error("Only support network with Cell type now")
+            return
+
+        # self._placeholders: List[Node] = []
+        # self._contant_nodes: List[ConstantNode] = []
+        # self._param_default_value: Dict = {}
+        # self._node_attributes: Dict = {}
+        # self._symbol_table: dict = {}
+        self._net_cls = type(network)
+        self._name = self._net_cls.__name__
+        self._base_scope = self._net_cls.__name__
+        self._network = network
+        network_str = inspect.getsource(self._net_cls)
+        self._ast_root: ast.AST = ast.parse(network_str)
+
+        root_node = Node(self._ast_root, self._name, self._net_cls)
+        self._nodes: [Node] = [root_node]
+        self._return = root_node
+
+    @property
+    def nodes(self) -> list:
+        """
+        Return the nodes of the graph for iteration. These nodes should also
+        include the subgraphs built in __init__; this comes up during pattern matching.
+        """
+        return self._nodes
+
+    def get_return(self):
+        return self._return
+
+    def set_return(self, node: Node):
+        self._return = node
+
+    def print_ast(self):
+        astpretty.pprint(self._ast_root)
\ No newline at end of file
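
For orientation, a sketch of how this experimental front end is meant to be driven (illustrative only; `LeNet5` stands in for any `nn.Cell` subclass, and it presumes passes have already been registered through the `PassRegister` created later in this patch):

    from mindspore.rewrite_experiment import Compiler

    # Compile a Cell into a Graph, then inspect what the passes produced.
    graph = Compiler.compile(LeNet5(num_class=10))
    for node in graph.nodes:
        print(node.get_name(), node.node_type())
    graph.print_ast()
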
new file mode 100644
index 00000000000..8655abbf560
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/node.py
@@ -0,0 +1,220 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import ast
+from typing import Dict, List, Union, Optional
+
+import mindspore.nn as nn
+from mindspore.ops.primitive import Primitive
+from .subject import Subject
+from .observer import Observer
+
+
+class NodeType:
+    placeholder = 1  # input
+    parameter = 2  # weight
+    constant = 3
+    call_cell = 4  # call cell object
+    call_method = 5  # method in cell
+    call_function = 6  # subclass of primitive
+    output = 7
+    graph = 8
+    invalid = 100
+
+
+class Node(Subject):
+    """
+    Base class of nodes.
+
+    Args:
+        name: the name of the node.
+        targets: the output names of the node.
+        args: the input names of the node.
+        inputs: the input nodes of this node.
+    """
+
+    def notify(self):
+        # self._observer.update()
+        pass
+
+    def __init__(self, ast_node: ast.AST, name, cls: type):
+        """
+        How should the corresponding attributes be passed in when creating a node? Cells should not be involved; this case can occur with primitives.
+        """
+        self._name: str = name
+        # self._attribute: AttributeNode = AttributeNode()
+        # if outputs is None:
+        #     self._outputs: List[CellNode] = list()
+        # else:
+        #     self._outputs = outputs
+        # if inputs is None:
+        #     self._inputs: List[CellNode] = list()
+        # else:
+        #     self._inputs = inputs
+        # self._targets: List[str] = targets  # stores the names of operator outputs, used to match operator input names
+        # self._args: List = args
+        self._type: int = NodeType.invalid
+        self._ast_root: ast.AST = ast_node
+        self._ast_processing: ast.AST = ast_node
+        self._atomic: bool = False
+        self._args: [str] = []
+        self._kwargs: {} = {}
+        self._targets: [str] = []
+        self._inputs: [Node] = []
+        self._outputs: [Node] = []
+        self._cls: type = cls
+
+    def get_name(self) -> str:
+        return self._name
+
+    def set_name(self, name: str):
+        self._name = name
+
+    def get_inputs(self) -> list:
+        return self._inputs
+
+    def set_inputs(self, nodes: list):
+        self._inputs = nodes
+
+    def get_outputs(self) -> List:
+        return self._outputs
+
+    def set_outputs(self, nodes: list):
+        self._outputs = nodes
+
+    def set_targets(self, targets: [str]):
+        self._targets = targets
+
+    def get_targets(self):
+        return self._targets
+
+    def class_type(self):
+        return self._cls
+
+    def node_type(self) -> int:
+        return self._type
+
+    def get_processing_ast(self):
+        return self._ast_processing
+
+    def get_ast(self):
+        return self._ast_root
+
+    def set_ast(self, ast_node: ast.AST):
+        self._ast_root = ast_node
+
+    # def attribute(self) -> AttributeNode:
+    #     return self._attribute
+    #
+    # def attribute(self, attribute: AttributeNode):
+    #     self._attribute = attribute
+
+    # def set_attributes(self, attribute: Dict):
+    #     for key, value in attribute.items():
+    #         self._attribute.attribute[key] = value
+    #
+    # def set_attribute(self, key: str, value):
+    #     self._attribute._attribute[key] = value
+    #
+    # def get_attribute(self, key: str):
+    #     return self._attribute._attribute.get(key)
+
+
+# class CellNode(Node):
+#     """
+#     'Node' is the main data structure that represents individual operations within a 'Graph'.
+#     """
+#
+#     def __init__(self, name="", targets=None, args=None, ast_node=None, instance: Union[nn.Cell, Primitive] = None,
+#                  inputs: List = None, outputs: List = None):
+#         """
+#         How should the corresponding attributes be passed in when creating a node? Cells should not be involved; this case can occur with primitives.
+#         """
+#         super().__init__(name, targets, args, inputs, outputs)
+#         self._kwargs: Dict = {}
+#         self._scope: str = ""
+#         self._ast_node: ast.AST = ast_node
+#         self._index = 0
+#         self._attribute._class = type(instance)
+#
+#     @property
+#     def type(self):
+#         return self._attribute._class
+#
+#     @type.setter
+#     def type(self, cell_type):
+#         self._attribute._class = cell_type
+#
+#     def set_cell(self, cell: nn.Cell):
+#         self._attribute._class = NodeType.call_cell
+#
+#     def __repr__(self):
+#         input_names = ""
+#         input_nodes = ""
+#         output_names = ""
+#         output_nodes = ""
+#
+#         for n in self.inputs:
+#             input_names += n.name + ", "
+#             # input_nodes += str(n)
+#
+#         for n in self.outputs:
+#             output_names += n.name + ", "
+#             # output_nodes += str(n)
+#
+#         return f"name: {self._name}; ast_node: {self._ast_node}; scope: {self._scope}; index: {self._index}; inputs: {len(self.inputs)}; input names: {input_names}; outputs: {len(self.outputs)}; output names: {output_names}; attr info: {self._attribute}"
+#
+#
+# class ConstantNode(Node):
+#     """
+#     'ConstantNode' is used to save constants.
+#     """
+#
+#     def __init__(self, name="constant", value=None):
+#         super().__init__(name=name, args=[], targets=[])
+#         self._value = value
+#         self._args.append(value)
+#         self._attribute.type = NodeType.constant
+#
+#     def __repr__(self) -> str:
+#         output_names = ""
+#         for n in self.outputs:
+#             output_names += n.name + ", "
+#         return f"name: {self._name}; value: {self._value}; outputs: {len(self.outputs)}; output names: {output_names}"
+#
+#
+# class PlaceholderNode(Node):
+#     """
+#     'PlaceholderNode' is used to represent inputs of Cell, method or function.
+#     """
+#
+#     def __init__(self, name, targets=None, ast_node=None, default_value=None):
+#         super().__init__(name, targets)
+#         self._ast_node = ast_node
+#         self._default_value = default_value
+#         self._attribute.type = NodeType.placeholder
+#
+#     def __repr__(self) -> str:
+#         output_names = ""
+#         for n in self.outputs:
+#             output_names += n.name + " "
+#         return f"name: {self._name}; targets: {self._targets}, outputs: {len(self.outputs)}; output names: {output_names}; attribute: {self._attribute}"
+#
+#
+# class SubgraphNode(Node):
+#     ...
+#
+# class ControlFlowCellNode(Node):
+#     def __init__(self):
+#         self.subgraph = None
diff --git a/mindspore/python/mindspore/rewrite_experiment/observer.py b/mindspore/python/mindspore/rewrite_experiment/observer.py
new file mode 100644
index 00000000000..614e20a3c6f
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/observer.py
@@ -0,0 +1,22 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import abc
+
+
+class Observer(abc.ABC):
+    @abc.abstractmethod
+    def update(self):
+        ...
diff --git a/mindspore/python/mindspore/rewrite_experiment/pass_register.py b/mindspore/python/mindspore/rewrite_experiment/pass_register.py
new file mode 100644
index 00000000000..402551f8ecd
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/pass_register.py
@@ -0,0 +1,39 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from .passs import Pass
+
+
+class PassRegister:
+    def __init__(self):
+        self._passes: dict = {}
+
+    @classmethod
+    def instance(cls):
+        if not hasattr(PassRegister, "_instance"):
+            PassRegister._instance = PassRegister()
+        return PassRegister._instance
+
+    @staticmethod
+    def reg_pass(pass_cls: type):
+        if issubclass(pass_cls, Pass):
+            pass_ = pass_cls()
+            PassRegister.instance()._passes[pass_.name()] = pass_
+        return pass_cls
+
+    def get_pass(self, name: str):
+        return self._passes.get(name)
+
+    def get_passes(self):
+        return self._passes
diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/__init__.py b/mindspore/python/mindspore/rewrite_experiment/passes/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py b/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py
new file mode 100644
index 00000000000..5975c81101c
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py
@@ -0,0 +1,31 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
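The PassRegister above fills a process-wide singleton at import time through the @reg_pass decorator. A self-contained sketch of that registration pattern, stripped of MindSpore imports (all names below are illustrative):

import abc

class Pass(abc.ABC):
    @abc.abstractmethod
    def name(self) -> str:
        ...

class Registry:
    _instance = None

    def __init__(self):
        self._passes = {}

    @classmethod
    def instance(cls):
        if cls._instance is None:
            cls._instance = Registry()
        return cls._instance

    @staticmethod
    def reg(pass_cls):
        if issubclass(pass_cls, Pass):
            obj = pass_cls()
            Registry.instance()._passes[obj.name()] = obj
        return pass_cls  # return the class so the decorated name stays usable

@Registry.reg
class DemoPass(Pass):
    def name(self) -> str:
        return "_DemoPass_"

print(list(Registry.instance()._passes))  # ['_DemoPass_']

Returning pass_cls from the decorator matters: a decorator that returns None rebinds the decorated class name to None in its module.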
+# ============================================================================
+import ast
+from ..node import Node
+from ..passs import Pass
+from ..pass_register import PassRegister
+
+
+@PassRegister.reg_pass
+class ClassDefPass(Pass):
+    def name(self) -> str:
+        return "_ClassDefPass_"
+
+    def process(self, node: Node) -> Node:
+        print("In ClassDefPass, Processing node: ", node.get_name(), type(node.get_ast()))
+        if not isinstance(node.get_ast(), ast.ClassDef):
+            return node
+
+        return node
diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py b/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py
new file mode 100644
index 00000000000..085267bcda3
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py
@@ -0,0 +1,43 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import ast
+from ..node import Node
+from ..passs import Pass
+from ..pass_register import PassRegister
+from mindspore import log as logger
+
+
+@PassRegister.reg_pass
+class ModulePass(Pass):
+    def name(self) -> str:
+        return "_ModulePass_"
+
+    def process(self, node: Node) -> Node:
+        print("In ModulePass, Processing node: ", node.get_name(), type(node.get_ast()))
+        if not isinstance(node.get_ast(), ast.Module):
+            return node
+        module: ast.Module = node.get_ast()
+        bodies: list = module.body
+        for body in bodies:
+            if isinstance(body, ast.ClassDef):
+                node.set_ast(body)
+                node.set_name(body.name)
+            else:
+                if hasattr(body, "name"):
+                    logger.warning("Ignoring node(%s) in Module", body.name)
+                else:
+                    logger.warning("Ignoring node(%s) in Module", body)
+
+        return node
diff --git a/mindspore/python/mindspore/rewrite_experiment/passs.py b/mindspore/python/mindspore/rewrite_experiment/passs.py
new file mode 100644
index 00000000000..ebef3b9e15f
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/passs.py
@@ -0,0 +1,38 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import abc
+from .node import Node
+
+
+class Pass(abc.ABC):
+    """
+    A Pass consumes a node and returns a node. It processes the node by parsing one type of AST node one level further.
+    """
+
+    @abc.abstractmethod
+    def name(self) -> str:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def process(self, node: Node) -> Node:
+        """
+        Args:
+            node (Node): the node to be processed.
+        Returns:
+            the processed node. The function should preserve the inputs of the output node.
+            The returned node cannot be None.
+        """
+        raise NotImplementedError
diff --git a/mindspore/python/mindspore/rewrite_experiment/subject.py b/mindspore/python/mindspore/rewrite_experiment/subject.py
new file mode 100644
index 00000000000..cb9f992d578
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/subject.py
@@ -0,0 +1,22 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import abc
+
+
+class Subject(abc.ABC):
+    @abc.abstractmethod
+    def notify(self):
+        ...
diff --git a/mindspore/python/mindspore/rewrite_experiment/test.py b/mindspore/python/mindspore/rewrite_experiment/test.py
new file mode 100644
index 00000000000..aeac180936d
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/test.py
@@ -0,0 +1,73 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test Transformer."""
+
+from mindspore import nn
+from mindspore.common.initializer import Normal
+from mindspore.rewrite_experiment import Compiler, Graph
+from mindspore import log as logger
+
+
+class LeNet5(nn.Cell):
+    """
+    Lenet network
+
+    Args:
+        num_class (int): Number of classes. Default: 10.
+        num_channel (int): Number of channels. Default: 1.
+ + Returns: + Tensor, output tensor + Examples: + >>> LeNet(num_class=10) + + """ + def __init__(self, num_class=10, num_channel=1, include_top=True): + super(LeNet5, self).__init__() + self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.include_top = include_top + if self.include_top: + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) + self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) + self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + if not self.include_top: + return x + x = self.flatten(x) + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +def test_compile(): + lenet = LeNet5(10) + graph: Graph = Compiler.compile(lenet) + logger.warning("After compile: %d", len(graph.nodes)) + + +if __name__ == '__main__': + test_compile() -- Gitee From ec36569a3199bf22bfcafd8ed53a0b0614766af3 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Fri, 24 Dec 2021 16:18:05 +0800 Subject: [PATCH 15/34] add symbol_table symbol --- .../mindspore/rewrite/pattern_engine.py | 1 + .../mindspore/rewrite_experiment/__init__.py | 3 +- .../mindspore/rewrite_experiment/compiler.py | 99 ++++++++++++------- .../mindspore/rewrite_experiment/graph.py | 4 +- .../mutable_dict_iterator.py | 51 ++++++++++ .../mindspore/rewrite_experiment/node.py | 23 +++-- .../passes/class_def_pass.py | 25 +++-- .../rewrite_experiment/passes/module_pass.py | 22 +++-- .../mindspore/rewrite_experiment/passs.py | 7 +- .../mindspore/rewrite_experiment/symbol.py | 58 +++++++++++ .../rewrite_experiment/symbol_table.py | 55 +++++++++++ .../mindspore/rewrite_experiment/test.py | 6 +- 12 files changed, 280 insertions(+), 74 deletions(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/mutable_dict_iterator.py create mode 100644 mindspore/python/mindspore/rewrite_experiment/symbol.py create mode 100644 mindspore/python/mindspore/rewrite_experiment/symbol_table.py diff --git a/mindspore/python/mindspore/rewrite/pattern_engine.py b/mindspore/python/mindspore/rewrite/pattern_engine.py index ba66c5c3cff..f31d6644c16 100644 --- a/mindspore/python/mindspore/rewrite/pattern_engine.py +++ b/mindspore/python/mindspore/rewrite/pattern_engine.py @@ -227,6 +227,7 @@ class PatternEngine: elif new_node == cur_node: # return origin Node for do nothing pass else: # return Node to insert or replace (new Node no need to set inputs and outputs) + # todo if we need to support _process_chain or _process_tree return multi-node changed = True graph.replace_node(matched_list, new_node) node_inputs = new_node.inputs diff --git a/mindspore/python/mindspore/rewrite_experiment/__init__.py b/mindspore/python/mindspore/rewrite_experiment/__init__.py index b17e28bf416..b96cd98338b 100644 --- a/mindspore/python/mindspore/rewrite_experiment/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/__init__.py @@ -1,6 +1,7 @@ from .graph import Graph +from .symbol_table import SymbolTable from .compiler import Compiler from .passes.module_pass import ModulePass from .passes.class_def_pass import ClassDefPass -__all__ = ["Graph", "Compiler"] +__all__ = ["Graph", "Compiler", "SymbolTable"] diff --git a/mindspore/python/mindspore/rewrite_experiment/compiler.py 
b/mindspore/python/mindspore/rewrite_experiment/compiler.py index db43acf6485..62a267412cd 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compiler.py @@ -14,62 +14,85 @@ # ============================================================================ from typing import Union -from .graph import Graph -from .node import Node, NodeType +from .symbol_table import SymbolTable +from .symbol import Symbol, SymbolType from .passs import Pass from .pass_register import PassRegister -from mindspore.nn import Cell import mindspore.nn as nn from mindspore.ops.primitive import Primitive from types import FunctionType from mindspore import log as logger +from .mutable_dict_iterator import MutableDictIterator class Compiler: @staticmethod - def _is_leaf_node(node: Node): + def _is_leaf_symbol(symbol: Symbol): # inseparable python node: - if node.node_type() is NodeType.constant: + if symbol.symbol_type() is SymbolType.constant: return False # not supported yet: # if node.node_type() is NodeType.graph: # return False # mindspore ops: - ms_ops: tuple = (Cell,) - if node.node_type() is NodeType.call_cell and issubclass(node.class_type(), ms_ops): - return False + # ms_ops: tuple = (Cell,) + # if node.node_type() is NodeType.call_cell and issubclass(node.class_type(), ms_ops): + # return False return True @staticmethod - def compile(network: Union[nn.Cell, Primitive, FunctionType]): - graph = Graph(network) + def compile_symbol_by_pass(iterator: MutableDictIterator, pass_: Pass) -> MutableDictIterator: + symbol: Symbol = iterator.value() + if not Compiler._is_leaf_symbol(symbol): # not compilable, skip + logger.warning("Processing symbol(%s) by pass(%s), leaf symbol", symbol.get_symbol_name(), pass_.name()) + return iterator + results: [Symbol] = pass_.process(symbol) + if len(results) == 1 and results[0] is symbol: # no change in process + logger.warning("Processing symbol(%s) by pass(%s), not changed", symbol.get_symbol_name(), pass_.name()) + return iterator + iterator = iterator.erase() + logger.warning("Processing symbol(%s) by pass(%s), replaced by %d new symbols", symbol.get_symbol_name(), + pass_.name(), len(results)) + for result in results: + iterator = iterator.insert(result.get_symbol_name(), result) + if result.symbol_type() != SymbolType.cell: + continue + raise NotImplementedError + # result.set_targets(node.get_targets()) + # result.set_outputs(node.get_outputs()) + # users = node.get_outputs() + # for user in users: + # new_inputs = [] + # for user_input in user.get_inputs(): + # if user_input is node: + # new_inputs.append(result) + # else: + # new_inputs.append(user_input) + # user.set_inputs(new_inputs) + # if graph.get_return() is node: + # graph.set_return(result) + return iterator + + @staticmethod + def compile_stb_by_pass(stb: SymbolTable, pass_: Pass) -> bool: + changed = False + iterator = MutableDictIterator(stb.get_symbols()) + while not iterator.is_end(): + new_iterator = Compiler.compile_symbol_by_pass(iterator, pass_) + if new_iterator == iterator: + iterator = next(iterator) + else: + changed = True + return changed + + @staticmethod + def compile(network: Union[nn.Cell, Primitive, FunctionType]) -> SymbolTable: + stb = SymbolTable(network) passes: [Pass] = PassRegister.instance().get_passes() - logger.warning("Load passes: %d", len(passes)) - while True: + changed = True + while changed: changed = False - for node in graph.nodes: # todo how to replace while iterating - if not 
Compiler._is_leaf_node(node): - continue - for key in passes: - pass_ = passes[key] - result: Node = pass_.process(node) - if result is node: - continue - logger.warning("Changed: new node: %s", result.get_name()) - changed = True - result.set_targets(node.get_targets()) - result.set_outputs(node.get_outputs()) - users = node.get_outputs() - for user in users: - new_inputs = [] - for user_input in user.get_inputs(): - if user_input is node: - new_inputs.append(result) - else: - new_inputs.append(user_input) - user.set_inputs(new_inputs) - if graph.get_return() is node: - graph.set_return(result) - if not changed: - break - return graph + for key in passes: + pass_: Pass = passes[key] + changed |= Compiler.compile_stb_by_pass(stb, pass_) + return stb diff --git a/mindspore/python/mindspore/rewrite_experiment/graph.py b/mindspore/python/mindspore/rewrite_experiment/graph.py index d745d853d6c..180936b9d3e 100644 --- a/mindspore/python/mindspore/rewrite_experiment/graph.py +++ b/mindspore/python/mindspore/rewrite_experiment/graph.py @@ -8,7 +8,7 @@ import mindspore.nn as nn from mindspore import log as logger from mindspore.ops.primitive import Primitive -from .node import Node +from .node import Node, NodeType from .observer import Observer @@ -33,7 +33,7 @@ class Graph(Observer): network_str = inspect.getsource(self._net_cls) self._ast_root: ast.AST = ast.parse(network_str) - root_node = Node(self._ast_root, self._name, self._net_cls) + root_node = Node(self._ast_root, self._name, self._net_cls, NodeType.module) self._nodes: [Node] = [root_node] self._return = root_node diff --git a/mindspore/python/mindspore/rewrite_experiment/mutable_dict_iterator.py b/mindspore/python/mindspore/rewrite_experiment/mutable_dict_iterator.py new file mode 100644 index 00000000000..40f2354928c --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/mutable_dict_iterator.py @@ -0,0 +1,51 @@ +from typing import Iterator + + +class MutableDictIterator(Iterator): + def __init__(self, dict_: dict): + self._dict = dict_ + self._keys = list(self._dict.keys()) + self._limit = len(self._keys) + self._curse = 0 + + def __eq__(self, other: 'MutableDictIterator'): + return self._curse == other._curse and self._limit == other._limit and self._dict == other._dict + + def is_end(self): + return self._curse >= self._limit + + def __iter__(self): + return self + + def __next__(self) -> 'MutableDictIterator': + if not self.is_end(): + self._curse += 1 + return self + + def insert(self, key: str, value) -> 'MutableDictIterator': + if self._dict.get(key) is None: + self._keys.insert(self._curse, key) + self._curse += 1 + self._dict[key] = value + self._limit = len(self._keys) + return self + + def erase(self) -> 'MutableDictIterator': + key = self.key() + if key is not None: + self._dict.pop(key) + self._keys.pop(self._curse) + self._limit = len(self._keys) + return self + + def key(self): + if self.is_end(): + return None + else: + return self._keys[self._curse] + + def value(self): + if self.is_end(): + return None + else: + return self._dict.get(self._keys[self._curse]) diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index 8655abbf560..07cfb61af0c 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -22,14 +22,17 @@ from .observer import Observer class NodeType: - placeholder = 1 # input - parameter = 2 # weight - constant = 3 - call_cell = 4 # call cell object - 
call_method = 5  # method in cell
-    call_function = 6  # subclass of primitive
-    output = 7
-    graph = 8
+    module = 1
+    class_def = 2
+    function_def = 3
+    placeholder = 4  # input
+    parameter = 5  # weight
+    constant = 6
+    call_cell = 7  # call cell object
+    call_method = 8  # method in cell
+    call_function = 9  # subclass of primitive
+    output = 10
+    graph = 11
     invalid = 100
@@ -48,7 +51,7 @@ class Node(Subject):
         # self._observer.update()
         pass
 
-    def __init__(self, ast_node: ast.AST, name, cls: type):
+    def __init__(self, ast_node: ast.AST, name, cls: type, node_type=NodeType.invalid):
         """
         How should the corresponding attributes be passed in when creating a node? Cells should not be involved; this case can occur with primitives.
         """
@@ -64,7 +67,7 @@ class Node(Subject):
         #     self._inputs = inputs
         # self._targets: List[str] = targets  # stores the names of operator outputs, used to match operator input names
         # self._args: List = args
-        self._type: int = NodeType.invalid
+        self._type: int = node_type
         self._ast_root: ast.AST = ast_node
         self._ast_processing: ast.AST = ast_node
         self._atomic: bool = False
diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py b/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py
index 5975c81101c..b2492663232 100644
--- a/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py
+++ b/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 # ============================================================================
 import ast
-from ..node import Node
+from ..symbol import Symbol, SymbolType
 from ..passs import Pass
 from ..pass_register import PassRegister
+from mindspore import log as logger
 
 
 @PassRegister.reg_pass
@@ -23,9 +24,19 @@ class ClassDefPass(Pass):
     def name(self) -> str:
         return "_ClassDefPass_"
 
-    def process(self, node: Node) -> Node:
-        print("In ClassDefPass, Processing node: ", node.get_name(), type(node.get_ast()))
-        if not isinstance(node.get_ast(), ast.ClassDef):
-            return node
-
-        return node
+    def process(self, symbol: Symbol) -> [Symbol]:
+        if not isinstance(symbol.get_ast(), ast.ClassDef):
+            return [symbol]
+        class_def: ast.ClassDef = symbol.get_ast()
+        bodies: list = class_def.body
+        new_symbols: [Symbol] = []
+        for body in bodies:
+            if isinstance(body, ast.FunctionDef):
+                new_node = Symbol(body, body.name, SymbolType.function_def)
+                new_symbols.append(new_node)
+            else:
+                if hasattr(body, "name"):
+                    logger.warning("Ignoring node(%s) in ClassDef", body.name)
+                else:
+                    logger.warning("Ignoring node(%s) in ClassDef", body)
+        return new_symbols
diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py b/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py
index 085267bcda3..c7015f485fd 100644
--- a/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py
+++ b/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py
@@ -13,7 +13,7 @@
 # limitations under the License.
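The reworked ClassDefPass above turns one class symbol into one symbol per method by walking the body of an ast.ClassDef. The underlying ast mechanics, runnable with the standard library alone:

import ast

source = '''
class Net:
    def __init__(self):
        self.w = 1

    def construct(self, x):
        return x + self.w
'''

module = ast.parse(source)          # ast.Module
class_def = module.body[0]          # ast.ClassDef
assert isinstance(class_def, ast.ClassDef)
methods = [body for body in class_def.body if isinstance(body, ast.FunctionDef)]
print([m.name for m in methods])    # ['__init__', 'construct']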
 # ============================================================================
 import ast
-from ..node import Node
+from ..symbol import Symbol, SymbolType
 from ..passs import Pass
 from ..pass_register import PassRegister
 from mindspore import log as logger
@@ -24,20 +24,22 @@ class ModulePass(Pass):
     def name(self) -> str:
         return "_ModulePass_"
 
-    def process(self, node: Node) -> Node:
-        print("In ModulePass, Processing node: ", node.get_name(), type(node.get_ast()))
-        if not isinstance(node.get_ast(), ast.Module):
-            return node
-        module: ast.Module = node.get_ast()
+    def process(self, symbol: Symbol) -> [Symbol]:
+        if not isinstance(symbol.get_ast(), ast.Module):
+            return [symbol]
+        module: ast.Module = symbol.get_ast()
         bodies: list = module.body
+        new_symbols: [Symbol] = []
         for body in bodies:
             if isinstance(body, ast.ClassDef):
-                node.set_ast(body)
-                node.set_name(body.name)
+                new_node = Symbol(body, body.name, SymbolType.class_def)
+                new_symbols.append(new_node)
             else:
                 if hasattr(body, "name"):
                     logger.warning("Ignoring node(%s) in Module", body.name)
                 else:
                     logger.warning("Ignoring node(%s) in Module", body)
-
-        return node
+        if len(new_symbols) == 0:
+            return [symbol]
+        else:
+            return new_symbols
diff --git a/mindspore/python/mindspore/rewrite_experiment/passs.py b/mindspore/python/mindspore/rewrite_experiment/passs.py
index ebef3b9e15f..b0d4f522f99 100644
--- a/mindspore/python/mindspore/rewrite_experiment/passs.py
+++ b/mindspore/python/mindspore/rewrite_experiment/passs.py
@@ -14,7 +14,7 @@
 # ============================================================================
 
 import abc
-from .node import Node
+from .symbol import Symbol
 
 
 class Pass(abc.ABC):
@@ -27,11 +27,12 @@
         raise NotImplementedError
 
     @abc.abstractmethod
-    def process(self, node: Node) -> Node:
+    def process(self, symbol: Symbol) -> [Symbol]:
         """
         Args:
-            node (Node): the node to be processed.
+            symbol (Symbol): the symbol to be processed.
         Returns:
+            the processed symbols. The function should preserve the inputs of the output symbols.
             the processed node. The function should preserve the inputs of the output node.
             The returned node cannot be None.
         """
         raise NotImplementedError
diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py
new file mode 100644
index 00000000000..788a2eb9bd1
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py
@@ -0,0 +1,58 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
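With process() now returning a list, one symbol can be replaced by several while the driver is still walking the table. A self-contained sketch of the erase-then-insert idiom this requires, mirroring (not importing) the MutableDictIterator above on a plain dict; the pass and all names are illustrative:

def run_pass(table, process):
    keys = list(table.keys())
    cursor = 0
    while cursor < len(keys):
        key = keys[cursor]
        results = process(key, table[key])
        if results == [(key, table[key])]:  # unchanged: just advance
            cursor += 1
            continue
        table.pop(key)                      # erase the current entry
        keys.pop(cursor)
        for new_key, new_value in results:  # insert replacements at the cursor
            if new_key not in table:
                keys.insert(cursor, new_key)
                cursor += 1
            table[new_key] = new_value
    return {k: table[k] for k in keys}

def split_even(key, value):
    # Toy pass: expand an even value into two halves, keep odd values as-is.
    if value % 2:
        return [(key, value)]
    return [(key + "_lo", value // 2), (key + "_hi", value - value // 2)]

print(run_pass({"a": 2, "b": 3}, split_even))  # {'a_lo': 1, 'a_hi': 1, 'b': 3}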
+# ============================================================================
+import ast
+
+
+class SymbolType:
+    module = 1
+    class_def = 2
+    function_def = 3
+    constant = 4
+    cell = 5
+    primitive = 6
+    invalid = 100
+
+
+class Symbol:
+    def __init__(self, ast_node: ast.AST, symbol_name, symbol_type=SymbolType.invalid):
+        # self._attribute: AttributeNode = AttributeNode()
+        # if outputs is None:
+        #     self._outputs: List[CellNode] = list()
+        # else:
+        #     self._outputs = outputs
+        # if inputs is None:
+        #     self._inputs: List[CellNode] = list()
+        # else:
+        #     self._inputs = inputs
+        # self._targets: List[str] = targets  # stores the names of operator outputs, used to match operator input names
+        # self._args: List = args
+        self._symbol_name: str = symbol_name
+        self._symbol_type: int = symbol_type
+        self._ast_root: ast.AST = ast_node
+
+    def get_symbol_name(self) -> str:
+        return self._symbol_name
+
+    def set_symbol_name(self, symbol_name: str):
+        self._symbol_name = symbol_name
+
+    def symbol_type(self) -> int:
+        return self._symbol_type
+
+    def get_ast(self):
+        return self._ast_root
+
+    def set_ast(self, ast_node: ast.AST):
+        self._ast_root = ast_node
diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_table.py b/mindspore/python/mindspore/rewrite_experiment/symbol_table.py
new file mode 100644
index 00000000000..f05c0957a7a
--- /dev/null
+++ b/mindspore/python/mindspore/rewrite_experiment/symbol_table.py
@@ -0,0 +1,55 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
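The SymbolTable in the next file seeds itself with a single module-level symbol by reading the network's source through inspect and parsing it with ast. The same round trip in standalone form (Demo stands in for a Cell subclass; run it as a script, since inspect.getsource needs file-backed source):

import ast
import inspect

class Demo:
    def construct(self, x):
        return x

source = inspect.getsource(Demo)
root = ast.parse(source)
table = {Demo.__name__: root}  # name -> AST: the initial one-entry symbol table
print(list(table), type(table["Demo"]).__name__)  # ['Demo'] Module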
+# ============================================================================ +from types import FunctionType +from typing import Union + +import inspect +import ast +import astpretty + +import mindspore.nn as nn +from mindspore import log as logger +from mindspore.ops.primitive import Primitive +from .symbol import Symbol, SymbolType + + +class SymbolTable: + def __init__(self, network: Union[nn.Cell, Primitive, FunctionType]): + if not isinstance(network, nn.Cell): + logger.error("Only support network with Cell type now") + return + + self._network = network + self._net_cls = type(network) + network_str = inspect.getsource(self._net_cls) + self._ast_root: ast.AST = ast.parse(network_str) + name = self._net_cls.__name__ + root_symbol = Symbol(self._ast_root, name, SymbolType.module) + self._table = {name: root_symbol} + + def get_symbols(self): + return self._table + + def remove_symbol(self, symbol: Symbol): + print("========== Remove symbol: ", symbol) + symbol_ = self._table.get(symbol.get_symbol_name()) + if symbol_ is not None: + self._table.pop(symbol.get_symbol_name()) + + def add_symbol(self, symbol: Symbol): + print("========== Add symbol: ", symbol) + symbol_ = self._table.get(symbol.get_symbol_name()) + if symbol_ is None: + self._table[symbol.get_symbol_name()] = symbol diff --git a/mindspore/python/mindspore/rewrite_experiment/test.py b/mindspore/python/mindspore/rewrite_experiment/test.py index aeac180936d..3a158f3a28d 100644 --- a/mindspore/python/mindspore/rewrite_experiment/test.py +++ b/mindspore/python/mindspore/rewrite_experiment/test.py @@ -16,7 +16,7 @@ from mindspore import nn from mindspore.common.initializer import Normal -from mindspore.rewrite_experiment import Compiler, Graph +from mindspore.rewrite_experiment import Compiler, SymbolTable from mindspore import log as logger @@ -65,8 +65,8 @@ class LeNet5(nn.Cell): def test_compile(): lenet = LeNet5(10) - graph: Graph = Compiler.compile(lenet) - logger.warning("After compile: %d", len(graph.nodes)) + stb: SymbolTable = Compiler.compile(lenet) + logger.warning("After compile: %d", len(stb.get_symbols().values())) if __name__ == '__main__': -- Gitee From 2140f4130375bdc1b43e654cf9aa0f15ca2620b9 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Mon, 20 Dec 2021 20:38:54 +0800 Subject: [PATCH 16/34] divide fakequantize op by its feature --- .../python/mindspore/golden_stick/__init__.py | 10 +- .../example/default_qat_example.py | 12 +- .../golden_stick/quantization/__init__.py | 8 +- .../quantization/default_qat/__init__.py | 6 +- .../default_qat/default_fake_quantizer.py | 177 ++++++++++++++++++ .../default_qat/default_layer_policy.py | 26 ++- .../default_qat/default_net_policy.py | 20 +- .../default_qat/default_quantizer.py | 68 ------- .../{quantizer.py => fake_quantizer.py} | 36 +--- .../quantization/hello_qat/simple_qat.py | 3 +- .../golden_stick/quantization/layer_policy.py | 12 +- .../quantization/quantize_wrapper_act.py | 4 +- .../quantization/quantize_wrapper_cell.py | 4 +- mindspore/python/mindspore/nn/layer/quant.py | 120 +++++++++++- 14 files changed, 365 insertions(+), 141 deletions(-) create mode 100644 mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py delete mode 100644 mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantizer.py rename mindspore/python/mindspore/golden_stick/quantization/{quantizer.py => fake_quantizer.py} (42%) diff --git a/mindspore/python/mindspore/golden_stick/__init__.py 
b/mindspore/python/mindspore/golden_stick/__init__.py index 608a949f2ae..0b388adbb9a 100644 --- a/mindspore/python/mindspore/golden_stick/__init__.py +++ b/mindspore/python/mindspore/golden_stick/__init__.py @@ -18,9 +18,9 @@ MindSpore golden stick module. from .golden_stick import GoldenStick from .net_transform import NetTransformer -from .quantization import LayerPolicy, NetPolicy, QuantAwareTraining, Quantizer, Transformer, AllValueQuantizer, \ - LastValueQuantizer, LSQ, DefaultLayerPolicy, DefaultNetworkPolicy, DefaultQuantAwareTraining +from .quantization import LayerPolicy, NetPolicy, QuantAwareTraining, FakeQuantizer, \ + Transformer, AllValueFakeQuantizer, DefaultLayerPolicy, DefaultNetworkPolicy, DefaultQuantAwareTraining -__all__ = ["GoldenStick", "NetTransformer", "LayerPolicy", "NetPolicy", "QuantAwareTraining", "Quantizer", - "Transformer", "AllValueQuantizer", "LastValueQuantizer", "LSQ", "DefaultLayerPolicy", - "DefaultNetworkPolicy", "DefaultQuantAwareTraining"] +__all__ = ["GoldenStick", "NetTransformer", "LayerPolicy", "NetPolicy", "QuantAwareTraining", "FakeQuantizer", + "Transformer", "AllValueFakeQuantizer", "DefaultLayerPolicy", "DefaultNetworkPolicy", + "DefaultQuantAwareTraining"] diff --git a/mindspore/python/mindspore/golden_stick/example/default_qat_example.py b/mindspore/python/mindspore/golden_stick/example/default_qat_example.py index d233b7e9ad1..3e8f08bf760 100644 --- a/mindspore/python/mindspore/golden_stick/example/default_qat_example.py +++ b/mindspore/python/mindspore/golden_stick/example/default_qat_example.py @@ -20,7 +20,7 @@ from mindspore.nn import Cell, Conv2d, BatchNorm2d, Dense, ReLU, MaxPool2d, Flat from mindspore.train.model import Model from ..quantization.transformer import Transformer from ..quantization.layer_policy import LayerPolicy -from ..quantization.quantizer import Quantizer +from ..quantization.fake_quantizer import FakeQuantizer class LeNet5(Cell): @@ -44,7 +44,7 @@ class LeNet5(Cell): # custom quantizer -class AllBitQuantizer(Quantizer): +class AllBitFakeQuantizer(FakeQuantizer): """ Derived class of QuantizeOp. Use min and max value of data to compute scale and zero-point. """ @@ -65,18 +65,18 @@ class AllBitQuantizer(Quantizer): class ConvBNQPolicy(LayerPolicy): def __init__(self): super().__init__() - self._quantizer = AllBitQuantizer() + self._quantizer = AllBitFakeQuantizer() - def get_weight_name_and_quantizers(self) -> [(str, Quantizer)]: + def get_weight_name_and_quantizers(self) -> [(str, FakeQuantizer)]: # todo how to define weight inside of a subgraph return [("_old_conv.weight", self._quantizer), ("_old_bn.gamma", self._quantizer), ("_old_bn.beta", self._quantizer), ("_old_bn.moving_mean", self._quantizer), ("_old_bn.moving_variance", self._quantizer)] - def get_act_name_and_quantizers(self) -> [(str, (Quantizer, Quantizer))]: + def get_act_name_and_quantizers(self) -> [(str, (FakeQuantizer, FakeQuantizer))]: return [] - def get_output_quantizers(self) -> [Quantizer]: + def get_output_quantizers(self) -> [FakeQuantizer]: return [self._quantizer] diff --git a/mindspore/python/mindspore/golden_stick/quantization/__init__.py b/mindspore/python/mindspore/golden_stick/quantization/__init__.py index fae58383ffe..1b2c8d47563 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/__init__.py +++ b/mindspore/python/mindspore/golden_stick/quantization/__init__.py @@ -19,10 +19,10 @@ MindSpore golden stick module. 
from .layer_policy import LayerPolicy from .net_policy import NetPolicy from .quantize import QuantAwareTraining -from .quantizer import Quantizer +from .fake_quantizer import FakeQuantizer from .transformer import Transformer -from .default_qat import AllValueQuantizer, LastValueQuantizer, LSQ, DefaultLayerPolicy, DefaultNetworkPolicy, \ +from .default_qat import AllValueFakeQuantizer, DefaultLayerPolicy, DefaultNetworkPolicy, \ DefaultQuantAwareTraining -__all__ = ["LayerPolicy", "NetPolicy", "QuantAwareTraining", "Quantizer", "Transformer", "AllValueQuantizer", - "LastValueQuantizer", "LSQ", "DefaultLayerPolicy", "DefaultNetworkPolicy", "DefaultQuantAwareTraining"] +__all__ = ["LayerPolicy", "NetPolicy", "QuantAwareTraining", "FakeQuantizer", "Transformer", "AllValueFakeQuantizer", + "DefaultLayerPolicy", "DefaultNetworkPolicy", "DefaultQuantAwareTraining"] diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/__init__.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/__init__.py index b2542fa2330..c5914b6d3a9 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/__init__.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/__init__.py @@ -16,10 +16,10 @@ MindSpore golden stick default-qat-quantization. """ -from .default_quantizer import AllValueQuantizer, LastValueQuantizer, LSQ +from .default_fake_quantizer import AllValueFakeQuantizer, LearnedFakeQuantizerPerLayer from .default_layer_policy import DefaultLayerPolicy from .default_net_policy import DefaultNetworkPolicy from .default_quantize import DefaultQuantAwareTraining -__all__ = ["AllValueQuantizer", "LastValueQuantizer", "LSQ", "DefaultLayerPolicy", "DefaultNetworkPolicy", - "DefaultQuantAwareTraining"] +__all__ = ["AllValueFakeQuantizer", "LearnedFakeQuantizerPerLayer", "DefaultLayerPolicy", + "DefaultNetworkPolicy", "DefaultQuantAwareTraining"] diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py new file mode 100644 index 00000000000..727b3fa82ac --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py @@ -0,0 +1,177 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""DefaultQuantizeOp.""" + +from functools import partial +from ..fake_quantizer import FakeQuantizer +from mindspore.ops.operations import _quant_ops as Q +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +import mindspore.context as context +import numpy as np + + +class FixFakeQuantizer(FakeQuantizer): + ... + + +class AllValueFakeQuantizer(FakeQuantizer): + ... + + +class MovingAvgFakeQuantizer(FakeQuantizer): + ... 
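The _calculate_quant_max helper defined just below sizes the integer grid: with neg_trunc the full unsigned range of num_bits is available, otherwise one bit is spent on the sign. A quick standalone check of that arithmetic:

def calculate_quant_max(num_bits, neg_trunc=False):
    if neg_trunc:
        return (1 << num_bits) - 1    # e.g. 2**8 - 1 = 255
    return (1 << (num_bits - 1)) - 1  # e.g. 2**7 - 1 = 127

assert calculate_quant_max(8, neg_trunc=True) == 255
assert calculate_quant_max(8) == 127
assert calculate_quant_max(4) == 7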
+
+
+def _calculate_quant_max(num_bits, neg_trunc=False):
+    if neg_trunc:
+        quant_max = (1 << num_bits) - 1
+    else:
+        quant_max = (1 << (num_bits - 1)) - 1
+    return quant_max
+
+
+class DefaultFakeQuantizerPerLayer(FakeQuantizer):
+    """
+    Default implementation of a min-max fake quantizer.
+    1. track the min and max values of the data passing through this op
+    2. run a fake-quant execution to simulate the quantization loss
+    """
+
+    def __init__(self, num_bits=8, quant_delay=0, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False):
+        super(DefaultFakeQuantizerPerLayer, self).__init__()
+        self._num_bits = num_bits
+        self._quant_delay = quant_delay
+        self._ema = ema
+        self._ema_decay = ema_decay
+        self._symmetric = symmetric
+        self._narrow_range = narrow_range
+        self._min_max_update_func = Q.MinMaxUpdatePerLayer(ema=self._ema, ema_decay=self._ema_decay)
+        self._is_ascend = context.get_context("device_target") == "Ascend"
+        quant_func = Q.FakeQuantPerLayer
+        self._init_fake_quant_func(quant_func)
+        self._float_min = Parameter(Tensor(float("-inf")), name="float_min")
+        self._float_max = Parameter(Tensor(float("inf")), name="float_max")
+
+    def _init_fake_quant_func(self, quant_func):
+        if self._is_ascend:
+            self._fake_quant_train = quant_func(num_bits=self._num_bits,
+                                                symmetric=self._symmetric,
+                                                narrow_range=self._narrow_range,
+                                                quant_delay=self._quant_delay)
+            self._fake_quant_infer = self._fake_quant_train
+        else:
+            quant_func = partial(quant_func,
+                                 ema=self._ema,
+                                 ema_decay=self._ema_decay,
+                                 num_bits=self._num_bits,
+                                 symmetric=self._symmetric,
+                                 narrow_range=self._narrow_range,
+                                 quant_delay=self._quant_delay)
+            self._fake_quant_train = quant_func(training=True)
+            self._fake_quant_infer = quant_func(training=False)
+
+    def construct(self, x):
+        if self.training:
+            self._float_min, self._float_max = \
+                self._min_max_update_func(x, self._float_min, self._float_max)
+            out = self._fake_quant_train(x, self._float_min, self._float_max)
+        else:
+            out = self._fake_quant_infer(x, self._float_min, self._float_max)
+        return out
+
+
+class DefaultFakeQuantizerPerChannel(DefaultFakeQuantizerPerLayer):
+    """
+    Per-channel version of DefaultFakeQuantizerPerLayer.
+    """
+    def __init__(self, num_channels=1, channel_axis=1, ema=False, ema_decay=0.999, symmetric=False,
+                 narrow_range=False):
+        super(DefaultFakeQuantizerPerChannel, self).__init__(ema=ema, ema_decay=ema_decay, symmetric=symmetric,
+                                                             narrow_range=narrow_range)
+        self._float_min = Parameter(Tensor([float("-inf")] * num_channels), name="float_min")
+        self._float_max = Parameter(Tensor([float("inf")] * num_channels), name="float_max")
+        quant_func = partial(Q.FakeQuantPerChannel, channel_axis=channel_axis)
+        self._init_fake_quant_func(quant_func)
+
+
+class LearnedFakeQuantizerPerLayer(FakeQuantizer):
+    """
+    Derived class of FakeQuantizer. Learns the quantization range during training to compute the scale and zero point.
+ """ + + def __init__(self, num_bits=8, quant_delay=0, min_init=-6, max_init=6, neg_trunc=False): + super(LearnedFakeQuantizerPerLayer, self).__init__() + self.neg_trunc = neg_trunc + self._quant_max = _calculate_quant_max(num_bits, self.neg_trunc) + self.quant_max = Parameter(Tensor(np.array([self._quant_max]).astype(np.float32))) + quant_func = partial(Q.FakeLearnedScaleQuantPerLayer, quant_delay=quant_delay, neg_trunc=self.neg_trunc) + self.fake_quant_train = quant_func(training=True) + self.fake_quant_infer = quant_func(training=False) + self._float_min = Parameter(Tensor(min_init), name="float_min") + self._float_max = Parameter(Tensor(max_init), name="float_max") + + def update_min_max(self, new_float_min, new_float_max): + self._float_min.set_data(Tensor(new_float_min)) + self._float_max.set_data(Tensor(new_float_max)) + + def construct(self, x): + if self.training: + out = self.fake_quant_train(x, self._float_max, self.quant_max) + else: + out = self.fake_quant_infer(x, self._float_max, self.quant_max) + return out + + +class LearnedFakeQuantizePerChannel(FakeQuantizer): + """ + Derived class of FakeQuantizer. perchannel version of LearnedFakeQuantizerPerLayer. + """ + + def __init__(self, num_bits=8, num_channels=1, channel_axis=1, quant_delay=0, + float_min=-6, float_max=6, neg_trunc=False): + super(LearnedFakeQuantizePerChannel, self).__init__() + self._quant_max = _calculate_quant_max(num_bits, neg_trunc) + self.quant_max = Parameter(Tensor(np.array([self._quant_max]).astype(np.float32))) + quant_func = partial(Q.FakeLearnedScaleQuantPerChannel, quant_delay=quant_delay, neg_trunc=neg_trunc, + channel_axis=channel_axis) + self.fake_quant_train = quant_func(training=True) + self.fake_quant_infer = quant_func(training=False) + self._num_channels = num_channels + self._float_min = Parameter(Tensor(self._get_init_array(float_min)), name="float_min") + self._float_max = Parameter(Tensor(self._get_init_array(float_max)), name="float_max") + + def update_min_max(self, new_float_min, new_float_max): + self._float_min.set_data(Tensor(self._get_init_array(new_float_min))) + self._float_max.set_data(Tensor(self._get_init_array(new_float_max))) + + def _get_init_array(self, init_data): + """ + Convert the initial value to array. 
+ """ + if isinstance(init_data, list) and len(init_data) != self.num_channels: + raise ValueError(f"For '{self.cls_name}', the length of 'min_init/max_init' list should be equal to " + f"'num_channels' for perchannel quant scenario, but got 'min_init/max_init': {init_data} " + f"and num_channels: {self._num_channels}.") + + if isinstance(init_data, list): + min_max_array = np.array(init_data).astype(np.float32) + else: + min_max_array = np.array([init_data] * self._num_channels).astype(np.float32) + return min_max_array + + def construct(self, x): + if self.training: + out = self.fake_quant_train(x, self._float_max, self.quant_max) + else: + out = self.fake_quant_infer(x, self._float_max, self.quant_max) + return out diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py index 9d45a6ba72d..c32a6eaa62d 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py @@ -17,8 +17,8 @@ from typing import Optional from ..layer_policy import LayerPolicy from ..quantize_wrapper_cell import QuantizeWrapperCell -from ..quantizer import Quantizer -from .default_quantizer import LastValueQuantizer +from ..fake_quantizer import FakeQuantizer +from .default_fake_quantizer import LearnedFakeQuantizerPerLayer, LearnedFakeQuantizePerChannel from mindspore.nn import Cell @@ -33,10 +33,10 @@ class DefaultLayerPolicy(LayerPolicy): def __init__(self, weight_names: [], act_names: [], config=None): if config is None: config = {} - self._weight_quantizer = LastValueQuantizer() - self._act_quantizer = LastValueQuantizer() - self._input_quantizer: Optional[Quantizer] = LastValueQuantizer() - self._output_quantizer: Optional[Quantizer] = LastValueQuantizer() + self._weight_quantizer = LearnedFakeQuantizePerChannel() + self._act_quantizer = LearnedFakeQuantizerPerLayer() + self._input_quantizer: Optional[FakeQuantizer] = LearnedFakeQuantizerPerLayer() + self._output_quantizer: Optional[FakeQuantizer] = LearnedFakeQuantizerPerLayer() self._weight_names = weight_names self._act_names = act_names self._input_num = 0 @@ -48,10 +48,10 @@ class DefaultLayerPolicy(LayerPolicy): def get_act_name_and_quantizers(self): return [(name, self._act_quantizer) for name in self._act_names] - def get_input_quantizer(self) -> Optional[Quantizer]: + def get_input_quantizer(self) -> Optional[FakeQuantizer]: return self._input_quantizer - def get_output_quantizer(self) -> Optional[Quantizer]: + def get_output_quantizer(self) -> Optional[FakeQuantizer]: return self._output_quantizer def set_input_number(self, input_num: int): @@ -78,3 +78,13 @@ class DefaultLayerPolicy(LayerPolicy): def wrap_cell(self, handler: Cell) -> Cell: return QuantizeWrapperCell(handler, self) + +class ActivationLayerPolicy(DefaultLayerPolicy): + def __init__(self, insert_before_input=False, insert_after_output=True): + super().__init__([], []) + self._input_quantizer: Optional[FakeQuantizer] = LearnedFakeQuantizerPerLayer() + self._output_quantizer: Optional[FakeQuantizer] = LearnedFakeQuantizerPerLayer() + self._insert_before_input = insert_before_input + self._insert_after_output = insert_after_output + + diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py 
index 191c2b0bed4..a8a99ad91aa 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py @@ -14,10 +14,15 @@ # ============================================================================ """DefaultNetworkPolicy.""" +from functools import partial from ..net_policy import NetPolicy +from ..layer_policy import LayerPolicy from .default_layer_policy import DefaultLayerPolicy from ..transformer import Transformer from mindspore.nn.layer import Conv2d, Dense, MatMul, BatchNorm2d, ReLU +from mindspore.nn.layer.quant import Conv2dBnFoldQuantOneConv +from mindspore.rewrite.pattern_engine import PatternEngine +from mindspore.nn.layer.quant import QuantConfig class DefaultNetworkPolicy(NetPolicy): @@ -32,9 +37,20 @@ class DefaultNetworkPolicy(NetPolicy): super().__init__(config) if config is None: config = {} - self._pattern_engines: [Transformer] = [ + + def fetch_quant_config(layer_policy: LayerPolicy): + weight_fake_quantizer = None if len(layer_policy.get_weight_name_and_quantizers()) == 0 \ + else layer_policy.get_weight_name_and_quantizers()[0][1] + act_fake_quantizer = None if len(layer_policy.get_act_name_and_quantizers()) == 0 \ + else layer_policy.get_act_name_and_quantizers()[0][1] + return QuantConfig(weight_fake_quantizer, act_fake_quantizer) + + self._pattern_engines: [PatternEngine] = [ Transformer([Conv2d, BatchNorm2d]), - Transformer([Conv2d, ReLU]) + Transformer([Conv2d, ReLU]), + PatternEngine([Conv2d, BatchNorm2d], + partial(Conv2dBnFoldQuantOneConv.from_float, + fetch_quant_config(self.get_net_layer_policy()))) ] self._support_layer_map: dict = { Conv2d: DefaultLayerPolicy(["weight"], [], config), diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantizer.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantizer.py deleted file mode 100644 index 7aff644a890..00000000000 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantizer.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""DefaultQuantizeOp.""" - -from ..quantizer import Quantizer - - -class FixQuantizer(Quantizer): - ... - - -class AllValueQuantizer(Quantizer): - ... - - -class MovingAvgQuantizer(Quantizer): - ... - - -class LastValueQuantizer(Quantizer): - """ - Derived class of QuantizeOp. Use min and max value of data to compute scale and zero-point. 
- """ - - def __init__(self): - super().__init__() - self._bit_num = 8 - - def compute_quant_param(self, float_data: [float]) -> {}: - data_min = float_data[0] - data_max = float_data[1] - if data_max == data_min: - return 1, 0 - scale = (1 << self._bit_num) / (data_max - data_min) - zp = data_max * scale - return scale, zp - - def fake_quant(self, float_data: [float], quant_params: dict, **kwargs) -> [float]: - scale = quant_params.get("scale") - zp = quant_params.get("zp") - return float_data * scale + zp - - -class LSQ(Quantizer): - """ - Derived class of QuantizeOp. Use learning-rate from each epoch to compute scale and zero-point. - """ - - def __init__(self): - super(LSQ, self).__init__() - - def compute_quant_param(self, float_data: [float]) -> {}: - pass - - def fake_quant(self, float_data: [float], quant_params: dict, **kwargs) -> [float]: - pass diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantizer.py b/mindspore/python/mindspore/golden_stick/quantization/fake_quantizer.py similarity index 42% rename from mindspore/python/mindspore/golden_stick/quantization/quantizer.py rename to mindspore/python/mindspore/golden_stick/quantization/fake_quantizer.py index 8c9e6e4202e..e23154aac82 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/quantizer.py +++ b/mindspore/python/mindspore/golden_stick/quantization/fake_quantizer.py @@ -12,38 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Quantizer.""" +"""FakeQuantizer.""" +from mindspore.nn.cell import Cell +FakeQuantizer = Cell -class Quantizer: - def __init__(self): - pass - - def compute_quant_param(self, float_data: [float]) -> {}: - """ - Compute quant-params such as min/max/scale/zero-point according to input `data`. - This method must be overridden by all subclasses. - - Args: - float_data (List[float]): input data for quant-params. - - Returns: - a dictionary as quant-params - """ - - pass - - def fake_quant(self, float_data: [float], quant_params: dict, **kwargs) -> [float]: - """ - FakeQuant input `float-data` according to quant_params and other args. - This method must be overridden by all subclasses. - - Args: - float_data (List[float]): input data to be fake-quantize. - quant_params (dict): quant-params of input data. - - Returns: - FakeQuantized data. 
- """ - - pass diff --git a/mindspore/python/mindspore/golden_stick/quantization/hello_qat/simple_qat.py b/mindspore/python/mindspore/golden_stick/quantization/hello_qat/simple_qat.py index 5f4a51104f1..005b5443bc1 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/hello_qat/simple_qat.py +++ b/mindspore/python/mindspore/golden_stick/quantization/hello_qat/simple_qat.py @@ -17,7 +17,8 @@ from collections import OrderedDict from ...golden_stick import GoldenStick from ...net_transform import NetTransformer -from mindspore.nn import Conv2d, Cell, BatchNorm2d, FakeQuantWithMinMaxObserver +from mindspore.nn import Conv2d, Cell, BatchNorm2d +from mindspore.nn.layer.quant import FakeQuantWithMinMaxObserver from mindspore.rewrite.pattern_engine import PatternEngine, PatternNode from mindspore.rewrite import Node from mindspore.train.callback import Callback diff --git a/mindspore/python/mindspore/golden_stick/quantization/layer_policy.py b/mindspore/python/mindspore/golden_stick/quantization/layer_policy.py index a6e839c8747..b6fde62a67e 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/layer_policy.py +++ b/mindspore/python/mindspore/golden_stick/quantization/layer_policy.py @@ -15,9 +15,8 @@ """LayerQConfig.""" import abc from typing import Optional -from .quantizer import Quantizer +from .fake_quantizer import FakeQuantizer from mindspore.nn import Cell - layer_policy_key = "layer_quant_policy" @@ -40,7 +39,7 @@ class LayerPolicy(abc.ABC): `get_output_quantizers` and `wrapper_cell`. """ - def get_weight_name_and_quantizers(self) -> [(str, Quantizer)]: + def get_weight_name_and_quantizers(self) -> [(str, FakeQuantizer)]: """ Define how to fake-quantize weight data. This method must be overridden by all subclasses. @@ -51,10 +50,10 @@ class LayerPolicy(abc.ABC): return [] - def get_act_name_and_quantizers(self) -> [(str, (Optional[Quantizer], Optional[Quantizer]))]: + def get_act_name_and_quantizers(self) -> [(str, (Optional[FakeQuantizer], Optional[FakeQuantizer]))]: return [] - def get_input_quantizer(self) -> Optional[Quantizer]: + def get_input_quantizer(self) -> Optional[FakeQuantizer]: """ Define how to fake-quantize input data. This method must be overridden by all subclasses. @@ -64,7 +63,7 @@ class LayerPolicy(abc.ABC): """ return None - def get_output_quantizer(self) -> Optional[Quantizer]: + def get_output_quantizer(self) -> Optional[FakeQuantizer]: """ Define how to fake-quantize output data. This method must be overridden by all subclasses. @@ -99,3 +98,4 @@ class LayerPolicy(abc.ABC): # only support one-output-quantizer pre layer because we can not get how many outputs a cell would has def set_output_not_insert_fq(self): pass + diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_act.py b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_act.py index e2a0e2ae1a7..4e16cd353f8 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_act.py +++ b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_act.py @@ -15,7 +15,7 @@ """QuantizeWrapperActivation.""" from mindspore.nn import Cell -from .quantizer import Quantizer +from .fake_quantizer import FakeQuantizer class QuantizeWrapperActivation(Cell): @@ -28,7 +28,7 @@ class QuantizeWrapperActivation(Cell): post_quantizer (Quantizer): Define how activation data to be fake-quant. 
""" - def __init__(self, act: Cell, pre_quantizer: Quantizer = None, post_quantizer: Quantizer = None): + def __init__(self, act: Cell, pre_quantizer: FakeQuantizer = None, post_quantizer: FakeQuantizer = None): super().__init__() self._handler: callable = act self._pre_quantizer = pre_quantizer diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py index 5afede41f2a..fbebd4b08e9 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py +++ b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py @@ -15,7 +15,7 @@ """QuantizeWrapperCell.""" from mindspore.nn import Cell -from .quantizer import Quantizer +from .fake_quantizer import FakeQuantizer from .layer_policy import LayerPolicy from .quantize_wrapper_act import QuantizeWrapperActivation @@ -26,7 +26,7 @@ class QuantizeWrapperCell(Cell): Args: handler (Cell): normal cell to be wrapped. - layer_policy (Quantizer): Define how weight data to be fake-quant. + layer_policy (FakeQuantizer): Define how weight data to be fake-quant. """ def __init__(self, handler: Cell, layer_policy: LayerPolicy): diff --git a/mindspore/python/mindspore/nn/layer/quant.py b/mindspore/python/mindspore/nn/layer/quant.py index 1e68732975b..af15b23735c 100644 --- a/mindspore/python/mindspore/nn/layer/quant.py +++ b/mindspore/python/mindspore/nn/layer/quant.py @@ -17,6 +17,8 @@ from functools import partial from collections import namedtuple import numpy as np +import enum +import re import mindspore.common.dtype as mstype from mindspore.ops.primitive import Primitive from mindspore.ops import operations as P @@ -24,9 +26,10 @@ from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from mindspore.common.tensor import Tensor from mindspore._checkparam import Validator, twice -from mindspore.compression.common import QuantDtype +from types import DynamicClassAttribute import mindspore.context as context from .normalization import BatchNorm2d +from .conv import Conv2d from .activation import get_activation from ..cell import Cell from ... import nn @@ -44,6 +47,103 @@ __all__ = [ 'MulQuant', ] +@enum.unique +class QuantDtype(enum.Enum): + """ + An enum for quant datatype, contains `INT2` ~ `INT8`, `UINT2` ~ `UINT8`. + """ + INT2 = "INT2" + INT3 = "INT3" + INT4 = "INT4" + INT5 = "INT5" + INT6 = "INT6" + INT7 = "INT7" + INT8 = "INT8" + + UINT2 = "UINT2" + UINT3 = "UINT3" + UINT4 = "UINT4" + UINT5 = "UINT5" + UINT6 = "UINT6" + UINT7 = "UINT7" + UINT8 = "UINT8" + + def __str__(self): + return f"{self.name}" + + @staticmethod + def is_signed(dtype): + """ + Get whether the quant datatype is signed. + + Args: + dtype (QuantDtype): quant datatype. + + Returns: + bool, whether the input quant datatype is signed. + + Examples: + >>> quant_dtype = QuantDtype.INT8 + >>> is_signed = QuantDtype.is_signed(quant_dtype) + """ + return dtype in [QuantDtype.INT2, QuantDtype.INT3, QuantDtype.INT4, QuantDtype.INT5, + QuantDtype.INT6, QuantDtype.INT7, QuantDtype.INT8] + + @staticmethod + def switch_signed(dtype): + """ + Switch the signed state of the input quant datatype. + + Args: + dtype (QuantDtype): quant datatype. + + Returns: + QuantDtype, quant datatype with opposite signed state as the input. 
+ + Examples: + >>> quant_dtype = QuantDtype.INT8 + >>> quant_dtype = QuantDtype.switch_signed(quant_dtype) + """ + type_map = { + QuantDtype.INT2: QuantDtype.UINT2, + QuantDtype.INT3: QuantDtype.UINT3, + QuantDtype.INT4: QuantDtype.UINT4, + QuantDtype.INT5: QuantDtype.UINT5, + QuantDtype.INT6: QuantDtype.UINT6, + QuantDtype.INT7: QuantDtype.UINT7, + QuantDtype.INT8: QuantDtype.UINT8, + QuantDtype.UINT2: QuantDtype.INT2, + QuantDtype.UINT3: QuantDtype.INT3, + QuantDtype.UINT4: QuantDtype.INT4, + QuantDtype.UINT5: QuantDtype.INT5, + QuantDtype.UINT6: QuantDtype.INT6, + QuantDtype.UINT7: QuantDtype.INT7, + QuantDtype.UINT8: QuantDtype.INT8 + } + return type_map[dtype] + + @DynamicClassAttribute + def _value(self): + """The value of the Enum member.""" + return int(re.search(r"(\d+)", self._value_).group(1)) + + @DynamicClassAttribute + def num_bits(self): + """ + Get the num bits of the QuantDtype member. + + Returns: + int, the num bits of the QuantDtype member. + + Examples: + >>> from mindspore.compression.common import QuantDtype + >>> quant_dtype = QuantDtype.INT8 + >>> num_bits = quant_dtype.num_bits + >>> print(num_bits) + 8 + """ + return self._value + class BatchNormFoldCell(Cell): """ @@ -788,6 +888,24 @@ class Conv2dBnFoldQuantOneConv(Cell): self.fake_quant_weight.quant_delay) return s + @classmethod + def from_float(cls, conv: Conv2d, bn: BatchNorm2d, quant_config: QuantConfig): + convbn_quant = cls(conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + pad_mode=conv.pad_mode, + padding=conv.padding, + dilation=conv.dilation, + group=conv.group, + eps=bn.eps, + momentum=bn.momentum, + has_bias=conv.has_bias, + bias_init=conv.bias_init, + quant_config=quant_config, + fake=True) + return convbn_quant + def construct(self, x): running_std = P.Sqrt()(P.Add()(self.moving_variance, self.eps)) scale_factor = self.gamma / running_std -- Gitee From 275705288695d0cac7eec4a6cc68df9ca2b0cf3b Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Mon, 27 Dec 2021 15:05:58 +0800 Subject: [PATCH 17/34] add function_def_pass init_function_def_pass --- .../mindspore/rewrite_experiment/__init__.py | 2 + .../mindspore/rewrite_experiment/compiler.py | 19 +-- .../rewrite_experiment/function_symbol.py | 45 +++++++ .../mindspore/rewrite_experiment/graph.py | 35 +++--- .../mindspore/rewrite_experiment/node.py | 30 +++-- .../passes/class_def_pass.py | 7 +- .../passes/function_def_pass.py | 118 ++++++++++++++++++ .../passes/init_function_def_pass.py | 41 ++++++ .../rewrite_experiment/passes/module_pass.py | 7 +- .../mindspore/rewrite_experiment/passs.py | 3 +- .../mindspore/rewrite_experiment/symbol.py | 18 ++- .../rewrite_experiment/symbol_table.py | 30 +++-- .../mindspore/rewrite_experiment/test.py | 2 +- 13 files changed, 293 insertions(+), 64 deletions(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/function_symbol.py create mode 100644 mindspore/python/mindspore/rewrite_experiment/passes/function_def_pass.py create mode 100644 mindspore/python/mindspore/rewrite_experiment/passes/init_function_def_pass.py diff --git a/mindspore/python/mindspore/rewrite_experiment/__init__.py b/mindspore/python/mindspore/rewrite_experiment/__init__.py index b96cd98338b..3fc0bf97a13 100644 --- a/mindspore/python/mindspore/rewrite_experiment/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/__init__.py @@ -3,5 +3,7 @@ from .symbol_table import SymbolTable from .compiler import Compiler from .passes.module_pass import ModulePass from 
.passes.class_def_pass import ClassDefPass +from .passes.function_def_pass import FunctionDefPass +from .passes.init_function_def_pass import InitFunctionDefPass __all__ = ["Graph", "Compiler", "SymbolTable"] diff --git a/mindspore/python/mindspore/rewrite_experiment/compiler.py b/mindspore/python/mindspore/rewrite_experiment/compiler.py index 62a267412cd..6872a79db83 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compiler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -from typing import Union +from typing import Union, Tuple from .symbol_table import SymbolTable from .symbol import Symbol, SymbolType @@ -23,6 +23,7 @@ from mindspore.ops.primitive import Primitive from types import FunctionType from mindspore import log as logger from .mutable_dict_iterator import MutableDictIterator +from .graph import Graph class Compiler: @@ -41,12 +42,12 @@ class Compiler: return True @staticmethod - def compile_symbol_by_pass(iterator: MutableDictIterator, pass_: Pass) -> MutableDictIterator: + def compile_symbol_by_pass(iterator: MutableDictIterator, pass_: Pass, graph: Graph) -> MutableDictIterator: symbol: Symbol = iterator.value() if not Compiler._is_leaf_symbol(symbol): # not compilable, skip logger.warning("Processing symbol(%s) by pass(%s), leaf symbol", symbol.get_symbol_name(), pass_.name()) return iterator - results: [Symbol] = pass_.process(symbol) + results: [Symbol] = pass_.process(symbol, graph) if len(results) == 1 and results[0] is symbol: # no change in process logger.warning("Processing symbol(%s) by pass(%s), not changed", symbol.get_symbol_name(), pass_.name()) return iterator @@ -74,11 +75,11 @@ class Compiler: return iterator @staticmethod - def compile_stb_by_pass(stb: SymbolTable, pass_: Pass) -> bool: + def compile_stb_by_pass(stb: SymbolTable, pass_: Pass, graph: Graph) -> bool: changed = False iterator = MutableDictIterator(stb.get_symbols()) while not iterator.is_end(): - new_iterator = Compiler.compile_symbol_by_pass(iterator, pass_) + new_iterator = Compiler.compile_symbol_by_pass(iterator, pass_, graph) if new_iterator == iterator: iterator = next(iterator) else: @@ -86,13 +87,15 @@ class Compiler: return changed @staticmethod - def compile(network: Union[nn.Cell, Primitive, FunctionType]) -> SymbolTable: + def compile(network: Union[nn.Cell, Primitive, FunctionType]) -> Tuple[Graph, SymbolTable]: stb = SymbolTable(network) + graph = Graph() passes: [Pass] = PassRegister.instance().get_passes() + logger.warning("------------- Load %d passes", len(passes)) changed = True while changed: changed = False for key in passes: pass_: Pass = passes[key] - changed |= Compiler.compile_stb_by_pass(stb, pass_) - return stb + changed |= Compiler.compile_stb_by_pass(stb, pass_, graph) + return graph, stb diff --git a/mindspore/python/mindspore/rewrite_experiment/function_symbol.py b/mindspore/python/mindspore/rewrite_experiment/function_symbol.py new file mode 100644 index 00000000000..26870e0480b --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/function_symbol.py @@ -0,0 +1,45 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast +from .symbol import Symbol, SymbolType + + +class FunctionSymbol(Symbol): + def __init__(self, ast_node: ast.AST, symbol_name, args=None, returns=None, bodies=None): + super().__init__(ast_node, symbol_name, SymbolType.function_def) + self._args = [] + if isinstance(args, list): + for arg in args: + if not isinstance(arg, Symbol): + raise RuntimeError("Input arg is not Symbol type") + self._args.append(arg) + + if isinstance(returns, list): + self._returns = returns + else: + self._returns = [] + + self._bodies = [] + if isinstance(bodies, list): + for body in bodies: + if not isinstance(body, Symbol): + raise RuntimeError("Input bodies is not Symbol type") + self._bodies.append(body) + + def add_body(self, body: Symbol): + self._bodies.append(body) + + def get_args(self): + return self._args diff --git a/mindspore/python/mindspore/rewrite_experiment/graph.py b/mindspore/python/mindspore/rewrite_experiment/graph.py index 180936b9d3e..a718f8abb85 100644 --- a/mindspore/python/mindspore/rewrite_experiment/graph.py +++ b/mindspore/python/mindspore/rewrite_experiment/graph.py @@ -8,7 +8,7 @@ import mindspore.nn as nn from mindspore import log as logger from mindspore.ops.primitive import Primitive -from .node import Node, NodeType +from .node import Node, NodeType, PlaceholderNode from .observer import Observer @@ -16,26 +16,26 @@ class Graph(Observer): def update(self): pass - def __init__(self, network: Union[nn.Cell, Primitive, FunctionType]): - if not isinstance(network, nn.Cell): - logger.error("Only support network with Cell type now") - return + def __init__(self): + # if not isinstance(network, nn.Cell): + # logger.error("Only support network with Cell type now") + # return # self._placeholders: List[Node] = [] # self._contant_nodes: List[ConstantNode] = [] # self._param_default_value: Dict = {} # self._node_attributes: Dict = {} # self._symbol_table: dict = {} - self._net_cls = type(network) - self._name = self._net_cls.__name__ - self._base_scope = self._net_cls.__name__ - self._network = network - network_str = inspect.getsource(self._net_cls) - self._ast_root: ast.AST = ast.parse(network_str) - - root_node = Node(self._ast_root, self._name, self._net_cls, NodeType.module) - self._nodes: [Node] = [root_node] - self._return = root_node + # self._net_cls = type(network) + # self._name = self._net_cls.__name__ + # self._base_scope = self._net_cls.__name__ + # self._network = network + # network_str = inspect.getsource(self._net_cls) + # self._ast_root: ast.AST = ast.parse(network_str) + # + # root_node = Node(self._ast_root, self._name, self._net_cls, NodeType.module) + self._nodes: [Node] = [] + # self._return = root_node @property @@ -52,4 +52,7 @@ class Graph(Observer): self._return = node def print_ast(self): - astpretty.pprint(self._ast_root) \ No newline at end of file + astpretty.pprint(self._ast_root) + + def add_placeholder(self, name, targets=None, ast_node=None, default_value=None): + self._nodes.append(PlaceholderNode(name, targets, ast_node, default_value)) diff --git 
a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index 07cfb61af0c..b2757eb559c 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -197,22 +197,20 @@ class Node(Subject): # return f"name: {self._name}; value: {self._value}; outputs: {len(self.outputs)}; output names: {output_names}" # # -# class PlaceholderNode(Node): -# """ -# 'PlaceholderNode' is used to represent inputs of Cell, method or function. -# """ -# -# def __init__(self, name, targets=None, ast_node=None, default_value=None): -# super().__init__(name, targets) -# self._ast_node = ast_node -# self._default_value = default_value -# self._attribute.type = NodeType.placeholder -# -# def __repr__(self) -> str: -# output_names = "" -# for n in self.outputs: -# output_names += n.name + " " -# return f"name: {self._name}; targets: {self._targets}, outputs: {len(self.outputs)}; output names: {output_names}; attribute: {self._attribute}" +class PlaceholderNode(Node): + """ + 'PlaceholderNode' is used to represent inputs of Cell, method or function. + """ + def __init__(self, name, targets=None, ast_node=None, default_value=None): + super().__init__(ast_node, name, Primitive, NodeType.placeholder) + self._ast_node = ast_node + self._default_value = default_value + + def __repr__(self) -> str: + output_names = "" + for n in self.outputs: + output_names += n.name + " " + return f"name: {self._name}; targets: {self._targets}, outputs: {len(self.outputs)}; output names: {output_names}; attribute: {self._attribute}" # # # class SubgraphNode(Node): diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py b/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py index b2492663232..1ac81b98801 100644 --- a/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py +++ b/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py @@ -17,6 +17,7 @@ from ..symbol import Symbol, SymbolType from ..passs import Pass from ..pass_register import PassRegister from mindspore import log as logger +from ..graph import Graph @PassRegister.reg_pass @@ -24,7 +25,7 @@ class ClassDefPass(Pass): def name(self) -> str: return "_ClassDefPass_" - def process(self, symbol: Symbol) -> [Symbol]: + def process(self, symbol: Symbol, graph: Graph) -> [Symbol]: if not isinstance(symbol.get_ast(), ast.ClassDef): return [symbol] class_def: ast.ClassDef = symbol.get_ast() @@ -36,7 +37,7 @@ class ClassDefPass(Pass): new_symbols.append(new_node) else: if hasattr(body, "name"): - logger.warning("Ignoring node(%s) in Module", body.name) + logger.warning("Ignoring symbol(%s) in ClassDef", body.name) else: - logger.warning("Ignoring node(%s) in Module", body) + logger.warning("Ignoring symbol(%s) in ClassDef", body) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/function_def_pass.py b/mindspore/python/mindspore/rewrite_experiment/passes/function_def_pass.py new file mode 100644 index 00000000000..c8057e326bf --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/passes/function_def_pass.py @@ -0,0 +1,118 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast +from typing import List + +from ..symbol import Symbol, SymbolType +from ..function_symbol import FunctionSymbol +from ..passs import Pass +from ..pass_register import PassRegister +from mindspore import log as logger +from ..graph import Graph + + +@PassRegister.reg_pass +class FunctionDefPass(Pass): + def name(self) -> str: + return "_FunctionDefPass_" + + def _parse_arguments(self, arguments_node: ast.arguments): + class Arg: + def __init__(self, lineno, col_offset, name) -> None: + self._lineno = lineno + self._col_offset = col_offset + self._name = name + + class Default: + def __init__(self, lineno, col_offset, value) -> None: + self._lineno = lineno + self._col_offset = col_offset + self._value = value + + def _find_corresponding_name(defaults: List[Default], names: List[Arg]): + for d in defaults: + i = 0 + while i < len(names) and names[i]._lineno == d._lineno and names[i]._col_offset < d._col_offset: + i += 1 + if 0 < i <= len(names): + arg_with_default_value[names[i-1]._name] = d._value + + args_ = [] + arg_with_default_value = {} + for arg in arguments_node.args: + if arg.arg == "self": + continue + a = Arg(arg.lineno, arg.col_offset, arg.arg) + args_.append(a) + arg_with_default_value[a._name] = None + + for arg in arguments_node.kwonlyargs: + a = Arg(arg.lineno, arg.col_offset, arg.arg) + args_.append(a) + arg_with_default_value[a._name] = None + + if arguments_node.vararg is not None: + a = Arg(arguments_node.vararg.lineno, arguments_node.vararg.col_offset, arguments_node.vararg.arg) + args_.append(a) + arg_with_default_value[a._name] = None + + if arguments_node.kwarg is not None: + a = Arg(arguments_node.kwarg.lineno, arguments_node.kwarg.col_offset, arguments_node.kwarg.arg) + args_.append(a) + arg_with_default_value[a._name] = None + + defaults_ = [] + # todo default + # for default in arguments_node.defaults: + # visitor = self.get_node_visitor(default) + # value = visitor(default) + # d = Default(default.lineno, default.col_offset, value) + # defaults_.append(d) + + _find_corresponding_name(defaults_, args_) + self._default_values = arg_with_default_value + return arg_with_default_value + + def _parse_returns(self, returns_node): + return [] + + def process(self, symbol: Symbol, graph: Graph) -> [Symbol]: + if not isinstance(symbol.get_ast(), ast.FunctionDef): + return [symbol] + if isinstance(symbol, FunctionSymbol): + return [symbol] + function_def: ast.FunctionDef = symbol.get_ast() + # parse args + args: ast.arguments = function_def.args + args_with_value = self._parse_arguments(args) + # self._parser.updete_closure_namespace(self._network.__init__) + + returns = self._parse_returns(function_def.returns) + function_symbol = FunctionSymbol(function_def, function_def.name, args_with_value, returns) + new_symbols = [function_symbol] + bodies: list = function_def.body + index = 0 + for body in bodies: + if isinstance(body, ast.Assign): + body_symbol = Symbol(body, function_def.name + "-assign-" + str(index), SymbolType.assign) + index += 1 + function_symbol.add_body(body_symbol) + new_symbols.append(body_symbol) + else: + if hasattr(body, 
"name"): + logger.warning("Ignoring symbol(%s) in FunctionDef", body.name) + else: + logger.warning("Ignoring symbol(%s) in FunctionDef", body) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/init_function_def_pass.py b/mindspore/python/mindspore/rewrite_experiment/passes/init_function_def_pass.py new file mode 100644 index 00000000000..2cf8a20726f --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/passes/init_function_def_pass.py @@ -0,0 +1,41 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from ..symbol import Symbol +from ..function_symbol import FunctionSymbol +from ..passs import Pass +from ..pass_register import PassRegister +from ..graph import Graph + + +@PassRegister.reg_pass +class InitFunctionDefPass(Pass): + def __init__(self): + self._function_def_init_visited_attr_key = "function_def_init_visited_key" + + def name(self) -> str: + return "_InitFunctionDefPass_" + + def process(self, symbol: Symbol, graph: Graph) -> [Symbol]: + if not isinstance(symbol, FunctionSymbol): + return [symbol] + if symbol.get_attr(self._function_def_init_visited_attr_key): + return [symbol] + args = symbol.get_ast().args + # todo + # args_with_value = symbol.get_args() + # for name, value in args_with_value.items(): + # graph.add_placeholder(name, name, args, value) + symbol.set_attr(self._function_def_init_visited_attr_key, True) + return [symbol] diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py b/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py index c7015f485fd..b8cbe5926ba 100644 --- a/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py +++ b/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py @@ -17,6 +17,7 @@ from ..symbol import Symbol, SymbolType from ..passs import Pass from ..pass_register import PassRegister from mindspore import log as logger +from ..graph import Graph @PassRegister.reg_pass @@ -24,7 +25,7 @@ class ModulePass(Pass): def name(self) -> str: return "_ModulePass_" - def process(self, symbol: Symbol) -> [Symbol]: + def process(self, symbol: Symbol, graph: Graph) -> [Symbol]: if not isinstance(symbol.get_ast(), ast.Module): return [symbol] module: ast.Module = symbol.get_ast() @@ -36,9 +37,9 @@ class ModulePass(Pass): new_symbols.append(new_node) else: if hasattr(body, "name"): - logger.warning("Ignoring node(%s) in Module", body.name) + logger.warning("Ignoring symbol(%s) in Module", body.name) else: - logger.warning("Ignoring node(%s) in Module", body) + logger.warning("Ignoring symbol(%s) in Module", body) if len(new_symbols) == 0: return [symbol] else: diff --git a/mindspore/python/mindspore/rewrite_experiment/passs.py b/mindspore/python/mindspore/rewrite_experiment/passs.py index b0d4f522f99..c9104ec41de 100644 --- a/mindspore/python/mindspore/rewrite_experiment/passs.py +++ 
b/mindspore/python/mindspore/rewrite_experiment/passs.py @@ -15,6 +15,7 @@ import abc from .symbol import Symbol +from .graph import Graph class Pass(abc.ABC): @@ -27,7 +28,7 @@ class Pass(abc.ABC): raise NotImplementedError @abc.abstractmethod - def process(self, symbol: Symbol) -> [Symbol]: + def process(self, symbol: Symbol, graph: Graph) -> [Symbol]: """ Args: symbol (Symbol): node who is tried to be processed diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index 788a2eb9bd1..accb30d6653 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -19,9 +19,10 @@ class SymbolType: module = 1 class_def = 2 function_def = 3 - constant = 4 - cell = 5 - primitive = 6 + assign = 4 + constant = 5 + cell = 6 + primitive = 7 invalid = 100 @@ -41,6 +42,8 @@ class Symbol: self._symbol_name: str = symbol_name self._symbol_type: int = symbol_type self._ast_root: ast.AST = ast_node + self._attribute: dict = {} + self._sub_symbols: list = [] def get_symbol_name(self) -> str: return self._symbol_name @@ -56,3 +59,12 @@ class Symbol: def set_ast(self, ast_node: ast.AST): self._ast_root = ast_node + + def set_attr(self, key: str, value): + self._attribute[key] = value + + def get_attr(self, key: str): + return self._attribute.get(key) + + def get_sub_symbols(self): + return self._sub_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_table.py b/mindspore/python/mindspore/rewrite_experiment/symbol_table.py index f05c0957a7a..aa7352e5588 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_table.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_table.py @@ -13,7 +13,7 @@ # limitations under the License. 
# ============================================================================ from types import FunctionType -from typing import Union +from typing import Union, Optional import inspect import ast @@ -26,30 +26,34 @@ from .symbol import Symbol, SymbolType class SymbolTable: - def __init__(self, network: Union[nn.Cell, Primitive, FunctionType]): + def __init__(self, network: Optional[Union[nn.Cell, Primitive, FunctionType]] = None): + if network is None: + self._table = {} + return if not isinstance(network, nn.Cell): logger.error("Only support network with Cell type now") return - self._network = network - self._net_cls = type(network) - network_str = inspect.getsource(self._net_cls) - self._ast_root: ast.AST = ast.parse(network_str) - name = self._net_cls.__name__ - root_symbol = Symbol(self._ast_root, name, SymbolType.module) + network = network + net_cls = type(network) + network_str = inspect.getsource(net_cls) + ast_root: ast.AST = ast.parse(network_str) + name = net_cls.__name__ + root_symbol = Symbol(ast_root, name, SymbolType.module) self._table = {name: root_symbol} - def get_symbols(self): - return self._table - def remove_symbol(self, symbol: Symbol): - print("========== Remove symbol: ", symbol) symbol_ = self._table.get(symbol.get_symbol_name()) if symbol_ is not None: self._table.pop(symbol.get_symbol_name()) def add_symbol(self, symbol: Symbol): - print("========== Add symbol: ", symbol) symbol_ = self._table.get(symbol.get_symbol_name()) if symbol_ is None: self._table[symbol.get_symbol_name()] = symbol + + def get_symbols(self): + return self._table + + def get_symbol(self, key: str): + return self._table.get(key) diff --git a/mindspore/python/mindspore/rewrite_experiment/test.py b/mindspore/python/mindspore/rewrite_experiment/test.py index 3a158f3a28d..eda018b725b 100644 --- a/mindspore/python/mindspore/rewrite_experiment/test.py +++ b/mindspore/python/mindspore/rewrite_experiment/test.py @@ -65,7 +65,7 @@ class LeNet5(nn.Cell): def test_compile(): lenet = LeNet5(10) - stb: SymbolTable = Compiler.compile(lenet) + _, stb = Compiler.compile(lenet) logger.warning("After compile: %d", len(stb.get_symbols().values())) -- Gitee From a61cdc36c462fab0659b1f8d43afb29b13060dff Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Mon, 27 Dec 2021 16:15:04 +0800 Subject: [PATCH 18/34] debug compile --- .../mindspore/rewrite_experiment/compiler.py | 16 ++++++++-------- .../rewrite_experiment/function_symbol.py | 4 ++-- .../rewrite_experiment/passes/class_def_pass.py | 2 +- .../passes/function_def_pass.py | 6 ++++-- .../rewrite_experiment/passes/module_pass.py | 2 +- .../mindspore/rewrite_experiment/symbol.py | 8 ++++++-- .../mindspore/rewrite_experiment/symbol_table.py | 2 +- 7 files changed, 23 insertions(+), 17 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/compiler.py b/mindspore/python/mindspore/rewrite_experiment/compiler.py index 6872a79db83..9891a389e88 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compiler.py @@ -42,15 +42,15 @@ class Compiler: return True @staticmethod - def compile_symbol_by_pass(iterator: MutableDictIterator, pass_: Pass, graph: Graph) -> MutableDictIterator: + def compile_symbol_by_pass(iterator: MutableDictIterator, pass_: Pass, graph: Graph) -> bool: symbol: Symbol = iterator.value() if not Compiler._is_leaf_symbol(symbol): # not compilable, skip logger.warning("Processing symbol(%s) by pass(%s), leaf symbol", symbol.get_symbol_name(), pass_.name()) - 
return iterator + return False results: [Symbol] = pass_.process(symbol, graph) if len(results) == 1 and results[0] is symbol: # no change in process logger.warning("Processing symbol(%s) by pass(%s), not changed", symbol.get_symbol_name(), pass_.name()) - return iterator + return False iterator = iterator.erase() logger.warning("Processing symbol(%s) by pass(%s), replaced by %d new symbols", symbol.get_symbol_name(), pass_.name(), len(results)) @@ -72,18 +72,18 @@ class Compiler: # user.set_inputs(new_inputs) # if graph.get_return() is node: # graph.set_return(result) - return iterator + return True @staticmethod def compile_stb_by_pass(stb: SymbolTable, pass_: Pass, graph: Graph) -> bool: changed = False iterator = MutableDictIterator(stb.get_symbols()) while not iterator.is_end(): - new_iterator = Compiler.compile_symbol_by_pass(iterator, pass_, graph) - if new_iterator == iterator: - iterator = next(iterator) - else: + cur_changed = Compiler.compile_symbol_by_pass(iterator, pass_, graph) + if cur_changed: changed = True + else: + iterator = next(iterator) return changed @staticmethod diff --git a/mindspore/python/mindspore/rewrite_experiment/function_symbol.py b/mindspore/python/mindspore/rewrite_experiment/function_symbol.py index 26870e0480b..8a8d672b3c5 100644 --- a/mindspore/python/mindspore/rewrite_experiment/function_symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/function_symbol.py @@ -17,8 +17,8 @@ from .symbol import Symbol, SymbolType class FunctionSymbol(Symbol): - def __init__(self, ast_node: ast.AST, symbol_name, args=None, returns=None, bodies=None): - super().__init__(ast_node, symbol_name, SymbolType.function_def) + def __init__(self, ast_node: ast.AST, scope, symbol_name, args=None, returns=None, bodies=None): + super().__init__(ast_node, scope, symbol_name, SymbolType.function_def) self._args = [] if isinstance(args, list): for arg in args: diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py b/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py index 1ac81b98801..6724c3e0e44 100644 --- a/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py +++ b/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py @@ -33,7 +33,7 @@ class ClassDefPass(Pass): new_symbols: [Symbol] = [] for body in bodies: if isinstance(body, ast.FunctionDef): - new_node = Symbol(body, body.name, SymbolType.function_def) + new_node = Symbol(body, symbol.get_symbol_name(), "Function(" + body.name + ")", SymbolType.function_def) new_symbols.append(new_node) else: if hasattr(body, "name"): diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/function_def_pass.py b/mindspore/python/mindspore/rewrite_experiment/passes/function_def_pass.py index c8057e326bf..7bebcfd34ef 100644 --- a/mindspore/python/mindspore/rewrite_experiment/passes/function_def_pass.py +++ b/mindspore/python/mindspore/rewrite_experiment/passes/function_def_pass.py @@ -100,13 +100,15 @@ class FunctionDefPass(Pass): # self._parser.updete_closure_namespace(self._network.__init__) returns = self._parse_returns(function_def.returns) - function_symbol = FunctionSymbol(function_def, function_def.name, args_with_value, returns) + function_symbol = FunctionSymbol(function_def, symbol.get_scope(), "Function(" + function_def.name + ")", + args_with_value, returns) new_symbols = [function_symbol] bodies: list = function_def.body index = 0 for body in bodies: if isinstance(body, ast.Assign): - body_symbol = Symbol(body, function_def.name + 
"-assign-" + str(index), SymbolType.assign) + body_symbol = Symbol(body, symbol.get_symbol_name(), "Assign(assign-" + str(index) + ")", + SymbolType.assign) index += 1 function_symbol.add_body(body_symbol) new_symbols.append(body_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py b/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py index b8cbe5926ba..89ae5738f4c 100644 --- a/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py +++ b/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py @@ -33,7 +33,7 @@ class ModulePass(Pass): new_symbols: [Symbol] = [] for body in bodies: if isinstance(body, ast.ClassDef): - new_node = Symbol(body, body.name, SymbolType.class_def) + new_node = Symbol(body, symbol.get_symbol_name(), "Class(" + body.name + ")", SymbolType.class_def) new_symbols.append(new_node) else: if hasattr(body, "name"): diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index accb30d6653..c54a3c8abd0 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -27,7 +27,7 @@ class SymbolType: class Symbol: - def __init__(self, ast_node: ast.AST, symbol_name, symbol_type=SymbolType.invalid): + def __init__(self, ast_node: ast.AST, scope, symbol_name, symbol_type=SymbolType.invalid): # self._attribute: AttributeNode = AttributeNode() # if outputs is None: # self._outputs: List[CellNode] = list() @@ -39,11 +39,15 @@ class Symbol: # self._inputs = inputs # self._targets: List[str] = targets # 用来保存算子输出结果的名称,用来匹配算子输入名称 # self._args: List = args - self._symbol_name: str = symbol_name self._symbol_type: int = symbol_type self._ast_root: ast.AST = ast_node self._attribute: dict = {} self._sub_symbols: list = [] + self._scope = scope + self._symbol_name: str = scope + "." 
+ symbol_name + + def get_scope(self) -> str: + return self._scope def get_symbol_name(self) -> str: return self._symbol_name diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_table.py b/mindspore/python/mindspore/rewrite_experiment/symbol_table.py index aa7352e5588..ae6f8d7733b 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_table.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_table.py @@ -39,7 +39,7 @@ class SymbolTable: network_str = inspect.getsource(net_cls) ast_root: ast.AST = ast.parse(network_str) name = net_cls.__name__ - root_symbol = Symbol(ast_root, name, SymbolType.module) + root_symbol = Symbol(ast_root, "", "Module(" + name + ")", SymbolType.module) self._table = {name: root_symbol} def remove_symbol(self, symbol: Symbol): -- Gitee From fb54217adb8e88607afa54e6b6be8c9bef09f7ae Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Mon, 27 Dec 2021 19:01:09 +0800 Subject: [PATCH 19/34] rename compile to rewrite rename pass to compiler add linker add assign_compiler --- .../mindspore/rewrite_experiment/__init__.py | 12 +- .../mindspore/rewrite_experiment/compiler.py | 102 ++++------------ .../rewrite_experiment/compilers/__init__.py | 1 + .../compilers/assign_compiler.py | 54 +++++++++ .../class_def_compiler.py} | 16 +-- .../function_def_compiler.py} | 21 ++-- .../module_compiler.py} | 16 +-- .../{passs.py => linker.py} | 5 +- .../construct_function_def_linker.py} | 36 +++--- .../init_function_def_linker.py} | 16 +-- .../rewrite_experiment/passes/__init__.py | 0 .../{function_symbol.py => registers.py} | 80 ++++++++----- .../mindspore/rewrite_experiment/rewrite.py | 110 ++++++++++++++++++ .../mindspore/rewrite_experiment/symbol.py | 72 +++++++++++- .../rewrite_experiment/symbol_table.py | 14 ++- .../mindspore/rewrite_experiment/test.py | 7 +- 16 files changed, 374 insertions(+), 188 deletions(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/compilers/__init__.py create mode 100644 mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py rename mindspore/python/mindspore/rewrite_experiment/{passes/class_def_pass.py => compilers/class_def_compiler.py} (78%) rename mindspore/python/mindspore/rewrite_experiment/{passes/function_def_pass.py => compilers/function_def_compiler.py} (88%) rename mindspore/python/mindspore/rewrite_experiment/{passes/module_pass.py => compilers/module_compiler.py} (79%) rename mindspore/python/mindspore/rewrite_experiment/{passs.py => linker.py} (92%) rename mindspore/python/mindspore/rewrite_experiment/{pass_register.py => linkers/construct_function_def_linker.py} (47%) rename mindspore/python/mindspore/rewrite_experiment/{passes/init_function_def_pass.py => linkers/init_function_def_linker.py} (81%) delete mode 100644 mindspore/python/mindspore/rewrite_experiment/passes/__init__.py rename mindspore/python/mindspore/rewrite_experiment/{function_symbol.py => registers.py} (33%) create mode 100644 mindspore/python/mindspore/rewrite_experiment/rewrite.py diff --git a/mindspore/python/mindspore/rewrite_experiment/__init__.py b/mindspore/python/mindspore/rewrite_experiment/__init__.py index 3fc0bf97a13..175902ff997 100644 --- a/mindspore/python/mindspore/rewrite_experiment/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/__init__.py @@ -1,9 +1,9 @@ from .graph import Graph from .symbol_table import SymbolTable -from .compiler import Compiler -from .passes.module_pass import ModulePass -from .passes.class_def_pass import ClassDefPass -from .passes.function_def_pass 
import FunctionDefPass -from .passes.init_function_def_pass import InitFunctionDefPass +from .rewrite import Rewrite +from .compilers.module_compiler import ModuleCompiler +from .compilers.class_def_compiler import ClassDefCompiler +from .compilers.function_def_compiler import FunctionDefCompiler +from .compilers.assign_compiler import AssignCompiler -__all__ = ["Graph", "Compiler", "SymbolTable"] +__all__ = ["Graph", "Rewrite", "SymbolTable"] diff --git a/mindspore/python/mindspore/rewrite_experiment/compiler.py b/mindspore/python/mindspore/rewrite_experiment/compiler.py index 9891a389e88..2054ca20422 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compiler.py @@ -12,90 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -from typing import Union, Tuple -from .symbol_table import SymbolTable -from .symbol import Symbol, SymbolType -from .passs import Pass -from .pass_register import PassRegister -import mindspore.nn as nn -from mindspore.ops.primitive import Primitive -from types import FunctionType -from mindspore import log as logger -from .mutable_dict_iterator import MutableDictIterator -from .graph import Graph +import abc +from .symbol import Symbol -class Compiler: - @staticmethod - def _is_leaf_symbol(symbol: Symbol): - # inseparable python node: - if symbol.symbol_type() is SymbolType.constant: - return False - # not supported yet: - # if node.node_type() is NodeType.graph: - # return False - # mindspore ops: - # ms_ops: tuple = (Cell,) - # if node.node_type() is NodeType.call_cell and issubclass(node.class_type(), ms_ops): - # return False - return True +class Compiler(abc.ABC): + """ + A Compiler consumes a symbol and returns new symbols. It processes a symbol by parsing one type of AST node one level further. + """ + + @abc.abstractmethod + def name(self) -> str: + raise NotImplementedError + + @abc.abstractmethod + def process(self, symbol: Symbol) -> [Symbol]: + """ + Args: + symbol (Symbol): the symbol to be processed. + Returns: + Symbols after processing. The returned symbols should preserve the inputs of the symbol they replace, and the return value must not be None. + """ + raise NotImplementedError
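To make the process contract concrete, here is a minimal no-op compiler as an illustrative sketch (not part of this patch; module paths are assumed from the diffstat above). Returning [symbol] unchanged is the signal the rewrite driver uses to conclude that a compiler does not apply to a symbol:

from mindspore.rewrite_experiment.symbol import Symbol
from mindspore.rewrite_experiment.compiler import Compiler
from mindspore.rewrite_experiment.registers import CompilerRegister


@CompilerRegister.reg_compiler
class NoOpCompiler(Compiler):
    def name(self) -> str:
        return "_NoOpCompiler_"

    def process(self, symbol: Symbol) -> [Symbol]:
        # Returning the input symbol as-is means "no change", so the
        # driver leaves the symbol table alone and moves on.
        return [symbol]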
diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/__init__.py b/mindspore/python/mindspore/rewrite_experiment/compilers/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/__init__.py @@ -0,0 +1 @@ + diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py new file mode 100644 index 00000000000..6d47ef9b172 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py @@ -0,0 +1,54 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast + +from ..symbol import Symbol, SymbolType, AssignSymbol +from ..compiler import Compiler +from ..registers import CompilerRegister + + +@CompilerRegister.reg_compiler +class AssignCompiler(Compiler): + def name(self) -> str: + return "_AssignCompiler_" + + def process(self, symbol: Symbol) -> [Symbol]: + if not isinstance(symbol.get_ast(), ast.Assign): + return [symbol] + if isinstance(symbol, AssignSymbol): + return [symbol] + new_symbols = [] + ast_assign: ast.Assign = symbol.get_ast() + targets = ast_assign.targets + target_symbols = [] + target_index = 0 + for target in targets: + if not isinstance(target, ast.expr): + raise RuntimeError("Target of assign should be ast.expr") + target_symbol = Symbol(target, symbol.get_full_name_with_scope(), + "Target(target-" + str(target_index) + ")", SymbolType.expression) + target_index += 1 + target_symbols.append(target_symbol) + new_symbols.append(target_symbol) + + if not isinstance(ast_assign.value, ast.expr): + raise RuntimeError("Value of assign should be ast.expr") + value_symbol = Symbol(ast_assign.value, symbol.get_full_name_with_scope(), + "Value(value-" + str(target_index) + ")", SymbolType.expression) + new_symbols.append(value_symbol) + assign_symbol = AssignSymbol(ast_assign, symbol.get_scope(), symbol.get_symbol_name(), target_symbols, + value_symbol) + new_symbols.append(assign_symbol) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py b/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py similarity index 78% rename from mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py rename to mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py index 6724c3e0e44..107c6e1dd3f 100644 --- a/mindspore/python/mindspore/rewrite_experiment/passes/class_def_pass.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py @@ -14,18 +14,17 @@ # ============================================================================ import ast from ..symbol import Symbol, SymbolType -from ..passs import Pass -from ..pass_register import PassRegister +from ..compiler import Compiler from mindspore import log as logger -from ..graph import Graph +from ..registers import CompilerRegister -@PassRegister.reg_pass -class ClassDefPass(Pass): +@CompilerRegister.reg_compiler +class ClassDefCompiler(Compiler): def name(self) -> str: - return "_ClassDefPass_" + return "_ClassDefCompiler_" - def process(self, symbol: Symbol, graph: Graph) -> [Symbol]: + def process(self, symbol: Symbol) -> [Symbol]: if not isinstance(symbol.get_ast(), ast.ClassDef): return [symbol] class_def: ast.ClassDef = symbol.get_ast() @@ -33,7 +32,8 @@ class ClassDefPass(Pass): new_symbols: [Symbol] = [] for body in bodies: if isinstance(body, ast.FunctionDef): - new_node = Symbol(body, symbol.get_symbol_name(), "Function(" + body.name + ")", SymbolType.function_def) + new_node = Symbol(body, symbol.get_full_name_with_scope(), "Function(" + body.name + 
")", + SymbolType.function_def) new_symbols.append(new_node) else: if hasattr(body, "name"): diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/function_def_pass.py b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py similarity index 88% rename from mindspore/python/mindspore/rewrite_experiment/passes/function_def_pass.py rename to mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py index 7bebcfd34ef..ef08ec2a379 100644 --- a/mindspore/python/mindspore/rewrite_experiment/passes/function_def_pass.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py @@ -15,18 +15,16 @@ import ast from typing import List -from ..symbol import Symbol, SymbolType -from ..function_symbol import FunctionSymbol -from ..passs import Pass -from ..pass_register import PassRegister +from ..symbol import Symbol, SymbolType, FunctionSymbol +from ..compiler import Compiler from mindspore import log as logger -from ..graph import Graph +from ..registers import CompilerRegister -@PassRegister.reg_pass -class FunctionDefPass(Pass): +@CompilerRegister.reg_compiler +class FunctionDefCompiler(Compiler): def name(self) -> str: - return "_FunctionDefPass_" + return "_FunctionDefCompiler_" def _parse_arguments(self, arguments_node: ast.arguments): class Arg: @@ -88,7 +86,7 @@ class FunctionDefPass(Pass): def _parse_returns(self, returns_node): return [] - def process(self, symbol: Symbol, graph: Graph) -> [Symbol]: + def process(self, symbol: Symbol) -> [Symbol]: if not isinstance(symbol.get_ast(), ast.FunctionDef): return [symbol] if isinstance(symbol, FunctionSymbol): @@ -97,17 +95,16 @@ class FunctionDefPass(Pass): # parse args args: ast.arguments = function_def.args args_with_value = self._parse_arguments(args) - # self._parser.updete_closure_namespace(self._network.__init__) returns = self._parse_returns(function_def.returns) - function_symbol = FunctionSymbol(function_def, symbol.get_scope(), "Function(" + function_def.name + ")", + function_symbol = FunctionSymbol(function_def, symbol.get_scope(), symbol.get_symbol_name(), args_with_value, returns) new_symbols = [function_symbol] bodies: list = function_def.body index = 0 for body in bodies: if isinstance(body, ast.Assign): - body_symbol = Symbol(body, symbol.get_symbol_name(), "Assign(assign-" + str(index) + ")", + body_symbol = Symbol(body, symbol.get_full_name_with_scope(), "Assign(assign-" + str(index) + ")", SymbolType.assign) index += 1 function_symbol.add_body(body_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py b/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py similarity index 79% rename from mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py rename to mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py index 89ae5738f4c..ee241f43156 100644 --- a/mindspore/python/mindspore/rewrite_experiment/passes/module_pass.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py @@ -14,18 +14,17 @@ # ============================================================================ import ast from ..symbol import Symbol, SymbolType -from ..passs import Pass -from ..pass_register import PassRegister +from ..compiler import Compiler from mindspore import log as logger -from ..graph import Graph +from ..registers import CompilerRegister -@PassRegister.reg_pass -class ModulePass(Pass): +@CompilerRegister.reg_compiler +class 
ModuleCompiler(Compiler): def name(self) -> str: - return "_ModulePass_" + return "_ModuleCompiler_" - def process(self, symbol: Symbol, graph: Graph) -> [Symbol]: + def process(self, symbol: Symbol) -> [Symbol]: if not isinstance(symbol.get_ast(), ast.Module): return [symbol] module: ast.Module = symbol.get_ast() @@ -33,7 +32,8 @@ class ModulePass(Pass): new_symbols: [Symbol] = [] for body in bodies: if isinstance(body, ast.ClassDef): - new_node = Symbol(body, symbol.get_symbol_name(), "Class(" + body.name + ")", SymbolType.class_def) + new_node = Symbol(body, symbol.get_full_name_with_scope(), "Class(" + body.name + ")", + SymbolType.class_def) new_symbols.append(new_node) else: if hasattr(body, "name"): diff --git a/mindspore/python/mindspore/rewrite_experiment/passs.py b/mindspore/python/mindspore/rewrite_experiment/linker.py similarity index 92% rename from mindspore/python/mindspore/rewrite_experiment/passs.py rename to mindspore/python/mindspore/rewrite_experiment/linker.py index c9104ec41de..13f77b03454 100644 --- a/mindspore/python/mindspore/rewrite_experiment/passs.py +++ b/mindspore/python/mindspore/rewrite_experiment/linker.py @@ -15,10 +15,9 @@ import abc from .symbol import Symbol -from .graph import Graph -class Pass(abc.ABC): +class Linker(abc.ABC): """ Pass eat a node and return a node. It processes a node by parse one type of ast node further one more """ @@ -28,7 +27,7 @@ class Pass(abc.ABC): raise NotImplementedError @abc.abstractmethod - def process(self, symbol: Symbol, graph: Graph) -> [Symbol]: + def process(self, symbol: Symbol) -> [Symbol]: """ Args: symbol (Symbol): node who is tried to be processed diff --git a/mindspore/python/mindspore/rewrite_experiment/pass_register.py b/mindspore/python/mindspore/rewrite_experiment/linkers/construct_function_def_linker.py similarity index 47% rename from mindspore/python/mindspore/rewrite_experiment/pass_register.py rename to mindspore/python/mindspore/rewrite_experiment/linkers/construct_function_def_linker.py index 402551f8ecd..38d76474f4f 100644 --- a/mindspore/python/mindspore/rewrite_experiment/pass_register.py +++ b/mindspore/python/mindspore/rewrite_experiment/linkers/construct_function_def_linker.py @@ -12,28 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ +from ..symbol import Symbol, FunctionSymbol +from ..linker import Linker +from ..registers import LinkerRegister -from .passs import Pass - -class PassRegister: +@LinkerRegister.reg_linker +class ConstructFunctionDefLinker(Linker): def __init__(self): - self._passes: dict = {} - - @classmethod - def instance(cls): - if not hasattr(PassRegister, "_instance"): - PassRegister._instance = PassRegister() - return PassRegister._instance + self._function_def_construct_visited_attr_key = "function_def_construct_visited_key" - @staticmethod - def reg_pass(pass_cls: type): - if issubclass(pass_cls, Pass): - pass_ = pass_cls() - PassRegister.instance()._passes[pass_.name()] = pass_ + def name(self) -> str: + return "_ConstructFunctionDefPass_" - def get_pass(self, name: str): - return self._passes.get(name) + def process(self, symbol: Symbol) -> [Symbol]: + if not isinstance(symbol, FunctionSymbol): + return [symbol] + if symbol.get_ast().name != "construct": + return [symbol] + if symbol.get_attr(self._function_def_construct_visited_attr_key): + return [symbol] - def get_passes(self): - return self._passes + symbol.set_attr(self._function_def_construct_visited_attr_key, True) + return [symbol] diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/init_function_def_pass.py b/mindspore/python/mindspore/rewrite_experiment/linkers/init_function_def_linker.py similarity index 81% rename from mindspore/python/mindspore/rewrite_experiment/passes/init_function_def_pass.py rename to mindspore/python/mindspore/rewrite_experiment/linkers/init_function_def_linker.py index 2cf8a20726f..2cfbb810d29 100644 --- a/mindspore/python/mindspore/rewrite_experiment/passes/init_function_def_pass.py +++ b/mindspore/python/mindspore/rewrite_experiment/linkers/init_function_def_linker.py @@ -12,24 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
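# (Side note on the name checks in these linkers, illustrative only.) The
# comparison uses `!=` because `is`/`is not` test object identity, and identity
# between equal Python strings is an interning detail of the interpreter;
# CPython 3.8+ even emits a SyntaxWarning for `is` with a literal. A quick
# demonstration:
a = "construct"
b = "".join(["cons", "truct"])  # equal value, typically a distinct object
assert a == b                   # value equality always holds here
print(a is b)                   # identity is unspecified -- never rely on it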
# ============================================================================ -from ..symbol import Symbol -from ..function_symbol import FunctionSymbol -from ..passs import Pass -from ..pass_register import PassRegister -from ..graph import Graph +from ..symbol import Symbol, FunctionSymbol +from ..linker import Linker +from ..registers import LinkerRegister -@PassRegister.reg_pass -class InitFunctionDefPass(Pass): +@LinkerRegister.reg_linker +class InitFunctionDefLinker(Linker): def __init__(self): self._function_def_init_visited_attr_key = "function_def_init_visited_key" def name(self) -> str: return "_InitFunctionDefPass_" - def process(self, symbol: Symbol, graph: Graph) -> [Symbol]: + def process(self, symbol: Symbol) -> [Symbol]: if not isinstance(symbol, FunctionSymbol): return [symbol] + if symbol.get_ast().name != "__init__": + return [symbol] if symbol.get_attr(self._function_def_init_visited_attr_key): return [symbol] args = symbol.get_ast().args diff --git a/mindspore/python/mindspore/rewrite_experiment/passes/__init__.py b/mindspore/python/mindspore/rewrite_experiment/passes/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/mindspore/python/mindspore/rewrite_experiment/function_symbol.py b/mindspore/python/mindspore/rewrite_experiment/registers.py similarity index 33% rename from mindspore/python/mindspore/rewrite_experiment/function_symbol.py rename to mindspore/python/mindspore/rewrite_experiment/registers.py index 8a8d672b3c5..e8eca35507a 100644 --- a/mindspore/python/mindspore/rewrite_experiment/function_symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/registers.py @@ -12,34 +12,54 @@ # See the License for the specific language governing permissions and # limitations under the License. 
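# (Illustrative sketch.) registers.py below defines CompilerRegister and
# LinkerRegister, two copies of the same decorator-driven singleton registry.
# The generic shape of that pattern, with made-up names that are not part of
# the patch:
class Registry:
    _instance = None

    @classmethod
    def instance(cls) -> "Registry":
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        self._entries: dict = {}

    @staticmethod
    def register(entry_cls: type) -> type:
        Registry.instance()._entries[entry_cls.__name__] = entry_cls()
        return entry_cls  # hand the class back so the decorated name stays bound


@Registry.register
class DemoCompiler:
    pass

assert isinstance(Registry.instance()._entries["DemoCompiler"], DemoCompiler)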
# ============================================================================ -import ast -from .symbol import Symbol, SymbolType - - -class FunctionSymbol(Symbol): - def __init__(self, ast_node: ast.AST, scope, symbol_name, args=None, returns=None, bodies=None): - super().__init__(ast_node, scope, symbol_name, SymbolType.function_def) - self._args = [] - if isinstance(args, list): - for arg in args: - if not isinstance(arg, Symbol): - raise RuntimeError("Input arg is not Symbol type") - self._args.append(arg) - - if isinstance(returns, list): - self._returns = returns - else: - self._returns = [] - - self._bodies = [] - if isinstance(bodies, list): - for body in bodies: - if not isinstance(body, Symbol): - raise RuntimeError("Input bodies is not Symbol type") - self._bodies.append(body) - - def add_body(self, body: Symbol): - self._bodies.append(body) - - def get_args(self): - return self._args + +from .compiler import Compiler +from .linker import Linker + + +class CompilerRegister: + def __init__(self): + self._compilers: dict = {} + + @classmethod + def instance(cls): + if not hasattr(CompilerRegister, "_instance"): + CompilerRegister._instance = CompilerRegister() + return CompilerRegister._instance + + @staticmethod + def reg_compiler(compiler_cls: type): + if issubclass(compiler_cls, Compiler): + compiler = compiler_cls() + CompilerRegister.instance()._compilers[compiler.name()] = compiler + return compiler_cls  # a class decorator must return the class, else the name is rebound to None + + def get_compiler(self, name: str): + return self._compilers.get(name) + + def get_compilers(self): + return self._compilers + + +class LinkerRegister: + def __init__(self): + self._linkers: dict = {} + + @classmethod + def instance(cls): + if not hasattr(LinkerRegister, "_instance"): + LinkerRegister._instance = LinkerRegister() + return LinkerRegister._instance + + @staticmethod + def reg_linker(linker_cls: type): + if issubclass(linker_cls, Linker): + linker = linker_cls() + LinkerRegister.instance()._linkers[linker.name()] = linker + return linker_cls  # same as reg_compiler: keep the decorated class bound + + def get_linker(self, name: str): + return self._linkers.get(name) + + def get_linkers(self): + return self._linkers diff --git a/mindspore/python/mindspore/rewrite_experiment/rewrite.py b/mindspore/python/mindspore/rewrite_experiment/rewrite.py new file mode 100644 index 00000000000..4034304338a --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/rewrite.py @@ -0,0 +1,110 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
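# (Illustrative sketch.) Rewrite.compile() below keeps sweeping every registered
# compiler over the symbol table until one full sweep changes nothing -- a
# fixed-point / worklist iteration. The same control flow on plain lists, with
# a toy rule in place of a Compiler (names here are made up for illustration):
def fixed_point(items: list, rules: list) -> list:
    changed = True
    while changed:
        changed = False
        for rule in rules:
            new_items = []
            for item in items:
                out = rule(item)            # a rule returns 1..n replacements
                changed |= out != [item]    # any real replacement keeps iterating
                new_items.extend(out)
            items = new_items
    return items

def halve(n): return [n // 2, n // 2] if n % 2 == 0 else [n]  # "compile" evens
assert fixed_point([8, 3], [halve]) == [1] * 8 + [3]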
+# ============================================================================ +from typing import Union, Tuple + +from .symbol_table import SymbolTable +from .symbol import Symbol, SymbolType +from .compiler import Compiler +from .registers import CompilerRegister +import mindspore.nn as nn +from mindspore.ops.primitive import Primitive +from types import FunctionType +from mindspore import log as logger +from .mutable_dict_iterator import MutableDictIterator +from .graph import Graph + + +class Rewrite: + @staticmethod + def _is_leaf_symbol(symbol: Symbol): + # inseparable python node: + if symbol.symbol_type() is SymbolType.constant: + return False + # not supported yet: + # if node.node_type() is NodeType.graph: + # return False + # mindspore ops: + # ms_ops: tuple = (Cell,) + # if node.node_type() is NodeType.call_cell and issubclass(node.class_type(), ms_ops): + # return False + return True + + @staticmethod + def compile_symbol_by_compiler(iterator: MutableDictIterator, compiler: Compiler) -> bool: + symbol: Symbol = iterator.value() + if not symbol.is_compilable(): # not compilable, skip + logger.warning("Processing symbol(%s) by compiler(%s), leaf symbol", symbol.get_symbol_name(), + compiler.name()) + return False + results: [Symbol] = compiler.process(symbol) + if len(results) == 1 and results[0] is symbol: # no change in process + logger.warning("Processing symbol(%s) by compiler(%s), not changed", symbol.get_symbol_name(), + compiler.name()) + return False + iterator = iterator.erase() + logger.warning("Processing symbol(%s) by compiler(%s), replaced by %d new symbols", symbol.get_symbol_name(), + compiler.name(), len(results)) + for result in results: + iterator = iterator.insert(result.get_full_name_with_scope(), result) + if result.symbol_type() != SymbolType.cell: + continue + raise NotImplementedError + # result.set_targets(node.get_targets()) + # result.set_outputs(node.get_outputs()) + # users = node.get_outputs() + # for user in users: + # new_inputs = [] + # for user_input in user.get_inputs(): + # if user_input is node: + # new_inputs.append(result) + # else: + # new_inputs.append(user_input) + # user.set_inputs(new_inputs) + # if graph.get_return() is node: + # graph.set_return(result) + return True + + @staticmethod + def compile_stb_by_compiler(stb: SymbolTable, compiler: Compiler) -> bool: + changed = False + iterator = MutableDictIterator(stb.get_symbols()) + while not iterator.is_end(): + cur_changed = Rewrite.compile_symbol_by_compiler(iterator, compiler) + if cur_changed: + changed = True + else: + iterator = next(iterator) + return changed + + @staticmethod + def compile(network: Union[nn.Cell, Primitive, FunctionType]) -> SymbolTable: + stb = SymbolTable(network) + compilers: [Compiler] = CompilerRegister.instance().get_compilers() + logger.warning("------------- Load %d compilers", len(compilers)) + changed = True + while changed: + changed = False + for key in compilers: + compiler: Compiler = compilers[key] + changed |= Rewrite.compile_stb_by_compiler(stb, compiler) + return stb + + @staticmethod + def rewrite(network: Union[nn.Cell, Primitive, FunctionType]) -> Tuple[Graph, SymbolTable]: + # compile ast to symbols until all symbols are not compilable + stb = Rewrite.compile(network) + # link symbols, convert construct-function to graph + graph = Graph() + return graph, stb diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index c54a3c8abd0..fec3773a581 100644 --- 
a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -20,9 +20,10 @@ class SymbolType: class_def = 2 function_def = 3 assign = 4 - constant = 5 - cell = 6 - primitive = 7 + expression = 5 + constant = 6 + cell = 7 + primitive = 8 invalid = 100 @@ -44,7 +45,10 @@ class Symbol: self._attribute: dict = {} self._sub_symbols: list = [] self._scope = scope - self._symbol_name: str = scope + "." + symbol_name + self._symbol_name: str = symbol_name + + def is_compilable(self): + return True def get_scope(self) -> str: return self._scope @@ -52,6 +56,9 @@ class Symbol: def get_symbol_name(self) -> str: return self._symbol_name + def get_full_name_with_scope(self) -> str: + return self._scope + "." + self._symbol_name + def set_symbol_name(self, symbol_name: str): self._symbol_name = symbol_name @@ -72,3 +79,60 @@ class Symbol: def get_sub_symbols(self): return self._sub_symbols + + +class FunctionSymbol(Symbol): + def __init__(self, ast_node: ast.AST, scope, symbol_name, args=None, returns=None, bodies=None): + super().__init__(ast_node, scope, symbol_name, SymbolType.function_def) + self._args = [] + if isinstance(args, list): + for arg in args: + if not isinstance(arg, Symbol): + raise RuntimeError("Input arg is not Symbol type") + self._args.append(arg) + + if isinstance(returns, list): + self._returns = returns + else: + self._returns = [] + + self._bodies = [] + if isinstance(bodies, list): + for body in bodies: + if not isinstance(body, Symbol): + raise RuntimeError("Input bodies is not Symbol type") + self._bodies.append(body) + + def add_body(self, body: Symbol): + self._bodies.append(body) + + def get_args(self): + return self._args + + def is_compilable(self): + return False + + +class AssignSymbol(Symbol): + def __init__(self, ast_node: ast.AST, scope, symbol_name, targets=None, value=None): + super().__init__(ast_node, scope, symbol_name, SymbolType.function_def) + self._targets = [] + if isinstance(targets, list): + for target in targets: + if not isinstance(target, Symbol): + raise RuntimeError("Input targets is not Symbol type") + self._targets.append(target) + self._value = value + + def is_compilable(self): + return False + + +class ConstantSymbol(Symbol): + def __init__(self, ast_node: ast.AST, scope, symbol_name, value, const_type): # const_type: number, string + super().__init__(ast_node, scope, symbol_name, SymbolType.constant) + self._value = value + self._const_type = const_type + + def is_compilable(self): + return False diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_table.py b/mindspore/python/mindspore/rewrite_experiment/symbol_table.py index ae6f8d7733b..ccc2b6d8db4 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_table.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_table.py @@ -43,17 +43,23 @@ class SymbolTable: self._table = {name: root_symbol} def remove_symbol(self, symbol: Symbol): - symbol_ = self._table.get(symbol.get_symbol_name()) + symbol_ = self._table.get(symbol.get_full_name_with_scope()) if symbol_ is not None: - self._table.pop(symbol.get_symbol_name()) + self._table.pop(symbol.get_full_name_with_scope()) def add_symbol(self, symbol: Symbol): - symbol_ = self._table.get(symbol.get_symbol_name()) + symbol_ = self._table.get(symbol.get_full_name_with_scope()) if symbol_ is None: - self._table[symbol.get_symbol_name()] = symbol + self._table[symbol.get_full_name_with_scope()] = symbol def get_symbols(self): return self._table def 
get_symbol(self, key: str): return self._table.get(key) + + def print(self): + print("=================================================================================================") + for k, v in self._table.items(): + print(v.get_full_name_with_scope()) + print("=================================================================================================") diff --git a/mindspore/python/mindspore/rewrite_experiment/test.py b/mindspore/python/mindspore/rewrite_experiment/test.py index eda018b725b..255cbddd578 100644 --- a/mindspore/python/mindspore/rewrite_experiment/test.py +++ b/mindspore/python/mindspore/rewrite_experiment/test.py @@ -16,7 +16,7 @@ from mindspore import nn from mindspore.common.initializer import Normal -from mindspore.rewrite_experiment import Compiler, SymbolTable +from mindspore.rewrite_experiment import Rewrite, SymbolTable from mindspore import log as logger @@ -65,8 +65,9 @@ class LeNet5(nn.Cell): def test_compile(): lenet = LeNet5(10) - _, stb = Compiler.compile(lenet) - logger.warning("After compile: %d", len(stb.get_symbols().values())) + _, stb = Rewrite.rewrite(lenet) + logger.warning("After rewrite: %d", len(stb.get_symbols().values())) + stb.print() if __name__ == '__main__': -- Gitee From 65c275fefbbcc6e79a9ef6c7a1bc0318dbd5ac04 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Tue, 28 Dec 2021 16:16:14 +0800 Subject: [PATCH 20/34] add owner for symbol --- .../compilers/assign_compiler.py | 24 +++++++++++++++++++ .../mindspore/rewrite_experiment/rewrite.py | 1 + .../mindspore/rewrite_experiment/symbol.py | 12 ++++++++++ 3 files changed, 37 insertions(+) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py index 6d47ef9b172..20017ff140e 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py @@ -21,6 +21,30 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class AssignCompiler(Compiler): + # def parse_call(self, ast_node: ast.Call, node: Node): + # """ + # Parse Call node in ast. 
+ # """ + # nodes = [] + # called_obj_names = [] + # visitor = self.get_node_visitor(ast_node.func) + # called_obj_name = visitor(ast_node.func) + # node.name = called_obj_name + # + # args_, nodes_, called_obj_names_ = self._parse_args(ast_node.args) + # + # visitor = self.get_node_visitor(ast_node.keywords) + # kwargs_: Dict = visitor(ast_node.keywords) + # + # node._args = args_ + # node._kwargs = kwargs_ + # nodes.extend(nodes_) + # nodes.append(node) + # called_obj_names.extend(called_obj_names_) + # called_obj_names.append(called_obj_name) + # + # return nodes, called_obj_names + def name(self) -> str: return "_AssignCompiler_" diff --git a/mindspore/python/mindspore/rewrite_experiment/rewrite.py b/mindspore/python/mindspore/rewrite_experiment/rewrite.py index 4034304338a..8d431d0fa27 100644 --- a/mindspore/python/mindspore/rewrite_experiment/rewrite.py +++ b/mindspore/python/mindspore/rewrite_experiment/rewrite.py @@ -106,5 +106,6 @@ class Rewrite: # compile ast to symbols until all symbols are not compilable stb = Rewrite.compile(network) # link symbols, convert construct-function to graph + graph = Graph() return graph, stb diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index fec3773a581..ce217139d67 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -46,6 +46,7 @@ class Symbol: self._sub_symbols: list = [] self._scope = scope self._symbol_name: str = symbol_name + self._owner = None def is_compilable(self): return True @@ -128,6 +129,17 @@ class AssignSymbol(Symbol): return False +class CallSymbol(Symbol): + def __init__(self, ast_node: ast.AST, scope, symbol_name, value, const_type): # const_type: number, string + super().__init__(ast_node, scope, symbol_name, SymbolType.constant) + self._func_name = "" + self._args = [] + self._kwargs = {} + + def is_compilable(self): + return False + + class ConstantSymbol(Symbol): def __init__(self, ast_node: ast.AST, scope, symbol_name, value, const_type): # const_type: number, string super().__init__(ast_node, scope, symbol_name, SymbolType.constant) -- Gitee From e59b1fbb3b2b3ff0f2728b21519b76e8d03e1062 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 28 Dec 2021 19:15:20 +0800 Subject: [PATCH 21/34] add arguments compiler --- .../mindspore/rewrite_experiment/__init__.py | 1 + .../compilers/arguments_compiler.py | 68 +++++++++++++++++++ .../compilers/function_def_compiler.py | 65 ++---------------- .../mindspore/rewrite_experiment/symbol.py | 40 ++++++++--- 4 files changed, 105 insertions(+), 69 deletions(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py diff --git a/mindspore/python/mindspore/rewrite_experiment/__init__.py b/mindspore/python/mindspore/rewrite_experiment/__init__.py index 175902ff997..f6520d688bc 100644 --- a/mindspore/python/mindspore/rewrite_experiment/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/__init__.py @@ -5,5 +5,6 @@ from .compilers.module_compiler import ModuleCompiler from .compilers.class_def_compiler import ClassDefCompiler from .compilers.function_def_compiler import FunctionDefCompiler from .compilers.assign_compiler import AssignCompiler +from .compilers.arguments_compiler import ArgumentsCompiler __all__ = ["Graph", "Rewrite", "SymbolTable"] diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py 
b/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py new file mode 100644 index 00000000000..80916bd3bdd --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py @@ -0,0 +1,68 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast + +from ..symbol import Symbol, SymbolType, ArgumentsSymbol +from ..compiler import Compiler +from ..registers import CompilerRegister + + +@CompilerRegister.reg_compiler +class ArgumentsCompiler(Compiler): + def name(self) -> str: + return "_ArgumentsCompiler_" + + def process(self, symbol: Symbol) -> [Symbol]: + if not isinstance(symbol.get_ast(), ast.arguments): + return [symbol] + if isinstance(symbol, ArgumentsSymbol): + return [symbol] + new_symbols = [] + ast_arguments: ast.arguments = symbol.get_ast() + + args_ = ast_arguments.args + defaults_ = ast_arguments.defaults + + arg_symbols = [] + arg_index = 0 + for arg in args_: + if arg.arg == "self": + continue + + if not isinstance(arg, ast.arg): + raise RuntimeError("arg of arguments should be ast.arg") + arg_symbol = Symbol(arg, symbol.get_full_name_with_scope(), + "Arg(arg-" + str(arg_index) + ")", SymbolType.expression) + arg_index += 1 + arg_symbols.append(arg_symbol) + new_symbols.append(arg_symbol) + + default_symbols = [] + default_index = 0 + for single_default in defaults_: + if not isinstance(single_default, ast.Constant): + raise RuntimeError("only support constant default value in arguments now") + + default_symbol = Symbol(single_default, symbol.get_full_name_with_scope(), + "Default(default-" + str(default_index) + ")", SymbolType.constant) + default_index += 1 + default_symbols.append(default_symbol) + new_symbols.append(default_symbol) + + arguments_symbol = ArgumentsSymbol( + ast_arguments, symbol.get_scope(), symbol.get_symbol_name(), arg_symbols, default_symbols + ) + new_symbols.append(arguments_symbol) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py index ef08ec2a379..2645f960aa0 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py @@ -26,63 +26,6 @@ class FunctionDefCompiler(Compiler): def name(self) -> str: return "_FunctionDefCompiler_" - def _parse_arguments(self, arguments_node: ast.arguments): - class Arg: - def __init__(self, lineno, col_offset, name) -> None: - self._lineno = lineno - self._clo_offset = col_offset - self._name = name - - class Default: - def __init__(self, lineno, col_offset, value) -> None: - self._lineno = lineno - self._col_offset = col_offset - self._value = value - - def _find_corresponding_name(defaults: List[Default], names: List[Arg]): - for d in defaults: - i = 0 - while i < 
len(names) and names[i]._lineno == d._lineno and names[i]._clo_offset < d._col_offset: - i += 1 - if i <= len(names): - arg_with_default_value[names[i-1]._name] = d._value - - args_ = [] - arg_with_default_value = {} - for arg in arguments_node.args: - if arg.arg == "self": - continue - a = Arg(arg.lineno, arg.col_offset, arg.arg) - args_.append(a) - arg_with_default_value[a._name] = None - - for arg in arguments_node.kwonlyargs: - a = Arg(arg.lineno, arg.col_offset, arg.arg) - args_.append(a) - arg_with_default_value[a._name] = None - - if arguments_node.vararg != None: - a = Arg(arguments_node.vararg.arg.lineno, arg.col_offset, arg.arg) - args_.append(a) - arg_with_default_value[a._name] = None - - if arguments_node.kwarg != None: - a = Arg(arguments_node.vararg.arg.lineno, arg.col_offset, arg.arg) - args_.append(a) - arg_with_default_value[a._name] = None - - defaults_ = [] - # todo default - # for default in arguments_node.defaults: - # visitor = self.get_node_visitor(default) - # value = visitor(default) - # d = Default(default.lineno, default.col_offset, value) - # defaults_.append(d) - - _find_corresponding_name(defaults_, args_) - self._default_values = arg_with_default_value - return arg_with_default_value - def _parse_returns(self, returns_node): return [] @@ -93,13 +36,15 @@ class FunctionDefCompiler(Compiler): return [symbol] function_def: ast.FunctionDef = symbol.get_ast() # parse args - args: ast.arguments = function_def.args - args_with_value = self._parse_arguments(args) + arguments: ast.arguments = function_def.args + arguments_symbol = Symbol(arguments, symbol.get_full_name_with_scope(), "Arguments", SymbolType.arguments) returns = self._parse_returns(function_def.returns) function_symbol = FunctionSymbol(function_def, symbol.get_scope(), symbol.get_symbol_name(), - args_with_value, returns) + arguments_symbol, returns) new_symbols = [function_symbol] + new_symbols.append(arguments_symbol) + bodies: list = function_def.body index = 0 for body in bodies: diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index ce217139d67..23699aca3a4 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -24,6 +24,7 @@ class SymbolType: constant = 6 cell = 7 primitive = 8 + arguments = 9 invalid = 100 @@ -83,14 +84,13 @@ class Symbol: class FunctionSymbol(Symbol): - def __init__(self, ast_node: ast.AST, scope, symbol_name, args=None, returns=None, bodies=None): + def __init__(self, ast_node: ast.AST, scope, symbol_name, arguments=None, returns=None, bodies=None): super().__init__(ast_node, scope, symbol_name, SymbolType.function_def) - self._args = [] - if isinstance(args, list): - for arg in args: - if not isinstance(arg, Symbol): - raise RuntimeError("Input arg is not Symbol type") - self._args.append(arg) + self._arguments = None + if not isinstance(arguments, Symbol): + raise RuntimeError("Input arguments is not Symbol type") + else: + self._arguments = arguments if isinstance(returns, list): self._returns = returns @@ -107,8 +107,8 @@ class FunctionSymbol(Symbol): def add_body(self, body: Symbol): self._bodies.append(body) - def get_args(self): - return self._args + def get_arguments(self): + return self._arguments def is_compilable(self): return False @@ -129,6 +129,28 @@ class AssignSymbol(Symbol): return False +class ArgumentsSymbol(Symbol): + def __init__(self, ast_node: ast.AST, scope, symbol_name, args=None, defaults=None): + 
super().__init__(ast_node, scope, symbol_name, SymbolType.function_def) + self._args = [] + self._defaults = [] + + if isinstance(args, list): + for arg in args: + if not isinstance(arg, Symbol): + raise RuntimeError("Input arg is not Symbol type") + self._args.append(arg) + + if isinstance(defaults, list): + for default in defaults: + if not isinstance(default, Symbol): + raise RuntimeError("Input default is not Symbol type") + self._defaults.append(default) + + def is_compilable(self): + return False + + class CallSymbol(Symbol): def __init__(self, ast_node: ast.AST, scope, symbol_name, value, const_type): # const_type: number, string super().__init__(ast_node, scope, symbol_name, SymbolType.constant) -- Gitee From 861ca69109836330505ffbc82dc2174d6eca7391 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 29 Dec 2021 11:57:12 +0800 Subject: [PATCH 22/34] add call symbol add call symbol --- .../mindspore/rewrite_experiment/__init__.py | 1 + .../mindspore/rewrite_experiment/compiler.py | 3 +- .../compilers/arguments_compiler.py | 3 - .../compilers/assign_compiler.py | 27 -------- .../compilers/call_compiler.py | 69 +++++++++++++++++++ .../compilers/class_def_compiler.py | 3 - .../compilers/function_def_compiler.py | 3 - .../compilers/module_compiler.py | 3 - .../mindspore/rewrite_experiment/symbol.py | 60 +++++++++------- .../rewrite_experiment/symbol_table.py | 7 +- 10 files changed, 111 insertions(+), 68 deletions(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py diff --git a/mindspore/python/mindspore/rewrite_experiment/__init__.py b/mindspore/python/mindspore/rewrite_experiment/__init__.py index f6520d688bc..2dfc4e7c7a6 100644 --- a/mindspore/python/mindspore/rewrite_experiment/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/__init__.py @@ -6,5 +6,6 @@ from .compilers.class_def_compiler import ClassDefCompiler from .compilers.function_def_compiler import FunctionDefCompiler from .compilers.assign_compiler import AssignCompiler from .compilers.arguments_compiler import ArgumentsCompiler +from .compilers.call_compiler import CallCompiler __all__ = ["Graph", "Rewrite", "SymbolTable"] diff --git a/mindspore/python/mindspore/rewrite_experiment/compiler.py b/mindspore/python/mindspore/rewrite_experiment/compiler.py index 2054ca20422..e5943e35036 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compiler.py @@ -22,9 +22,8 @@ class Compiler(abc.ABC): Pass eat a node and return a node. 
It processes a node by parse one type of ast node further one more """ - @abc.abstractmethod def name(self) -> str: - raise NotImplementedError + return "_{}_".format(self.__class__.__name__) @abc.abstractmethod def process(self, symbol: Symbol) -> [Symbol]: diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py index 80916bd3bdd..a21d0603bf0 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py @@ -21,9 +21,6 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class ArgumentsCompiler(Compiler): - def name(self) -> str: - return "_ArgumentsCompiler_" - def process(self, symbol: Symbol) -> [Symbol]: if not isinstance(symbol.get_ast(), ast.arguments): return [symbol] diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py index 20017ff140e..3312907f602 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py @@ -21,33 +21,6 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class AssignCompiler(Compiler): - # def parse_call(self, ast_node: ast.Call, node: Node): - # """ - # Parse Call node in ast. - # """ - # nodes = [] - # called_obj_names = [] - # visitor = self.get_node_visitor(ast_node.func) - # called_obj_name = visitor(ast_node.func) - # node.name = called_obj_name - # - # args_, nodes_, called_obj_names_ = self._parse_args(ast_node.args) - # - # visitor = self.get_node_visitor(ast_node.keywords) - # kwargs_: Dict = visitor(ast_node.keywords) - # - # node._args = args_ - # node._kwargs = kwargs_ - # nodes.extend(nodes_) - # nodes.append(node) - # called_obj_names.extend(called_obj_names_) - # called_obj_names.append(called_obj_name) - # - # return nodes, called_obj_names - - def name(self) -> str: - return "_AssignCompiler_" - def process(self, symbol: Symbol) -> [Symbol]: if not isinstance(symbol.get_ast(), ast.Assign): return [symbol] diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py new file mode 100644 index 00000000000..1c7f72c67b5 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py @@ -0,0 +1,69 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
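# (Illustrative sketch of the input shape.) CallCompiler below splits an
# ast.Call into exactly the three fields the standard ast module defines:
# func (the callee), args (positional) and keywords. On a typical Cell call:
import ast

call = ast.parse("self.conv1(x, kernel_size=3)", mode="eval").body
assert isinstance(call, ast.Call)
assert isinstance(call.func, ast.Attribute)       # self.conv1
assert isinstance(call.args[0], ast.Name)         # x
assert isinstance(call.keywords[0], ast.keyword)  # kernel_size=3
print(call.func.attr, call.keywords[0].arg)       # -> conv1 kernel_size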
+# ============================================================================ +import ast + +from ..symbol import Symbol, SymbolType, CallSymbol +from ..compiler import Compiler +from ..registers import CompilerRegister + + +@CompilerRegister.reg_compiler +class CallCompiler(Compiler): + def process(self, symbol: Symbol) -> [Symbol]: + if not isinstance(symbol.get_ast(), ast.Call): + return [symbol] + if isinstance(symbol, CallSymbol): + return [symbol] + new_symbols = [] + ast_call: ast.Call = symbol.get_ast() + + func_ = ast_call.func + args_ = ast_call.args + keywords_ = ast_call.keywords + + if not isinstance(func_, ast.Attribute): + raise RuntimeError("func of call should be ast.Attribute") + func_symbol = Symbol(func_, symbol.get_full_name_with_scope(), + "Func", SymbolType.attribute) + new_symbols.append(func_symbol) + + arg_symbols = [] + arg_index = 0 + for arg in args_: + if not (isinstance(arg, ast.Name) or isinstance(arg, ast.Constant) or isinstance(arg, ast.Call)): + raise RuntimeError("arg of Call should be ast.Name or ast.Constant or ast.Call") + arg_symbol = Symbol(arg, symbol.get_full_name_with_scope(), + "Arg(arg-" + str(arg_index) + ")", SymbolType.expression) + arg_index += 1 + arg_symbols.append(arg_symbol) + new_symbols.append(arg_symbol) + + keyword_symbols = [] + keyword_index = 0 + for keyword in keywords_: + if not isinstance(keyword, ast.keyword): + raise RuntimeError("keyword of Call should be ast.keyword") + + keyword_symbol = Symbol(keyword, symbol.get_full_name_with_scope(), + "Keyword(keyword-" + str(keyword_index) + ")", SymbolType.keyword) + keyword_index += 1 + keyword_symbols.append(keyword_symbol) + new_symbols.append(keyword_symbol) + + call_symbol = CallSymbol( + ast_call, symbol.get_scope(), symbol.get_symbol_name(), func_symbol, arg_symbols, keyword_symbols + ) + new_symbols.append(call_symbol) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py index 107c6e1dd3f..afceaef8cd6 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py @@ -21,9 +21,6 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class ClassDefCompiler(Compiler): - def name(self) -> str: - return "_ClassDefCompiler_" - def process(self, symbol: Symbol) -> [Symbol]: if not isinstance(symbol.get_ast(), ast.ClassDef): return [symbol] diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py index 2645f960aa0..fdaf730dc20 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py @@ -23,9 +23,6 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class FunctionDefCompiler(Compiler): - def name(self) -> str: - return "_FunctionDefCompiler_" - def _parse_returns(self, returns_node): return [] diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py index ee241f43156..a71d37fe46c 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py 
@@ -21,9 +21,6 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class ModuleCompiler(Compiler): - def name(self) -> str: - return "_ModuleCompiler_" - def process(self, symbol: Symbol) -> [Symbol]: if not isinstance(symbol.get_ast(), ast.Module): return [symbol] diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index 23699aca3a4..aba2561177f 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -25,6 +25,9 @@ class SymbolType: cell = 7 primitive = 8 arguments = 9 + attribute = 10 + keyword = 11 + call = 12 invalid = 100 @@ -83,7 +86,15 @@ class Symbol: return self._sub_symbols -class FunctionSymbol(Symbol): +class UncompilableSymbol(Symbol): + def __init__(self, ast_node: ast.AST, scope, symbol_name, symbol_type=SymbolType.invalid): + super(UncompilableSymbol, self).__init__(ast_node, scope, symbol_name, symbol_type) + + def is_compilable(self): + return False + + +class FunctionSymbol(UncompilableSymbol): def __init__(self, ast_node: ast.AST, scope, symbol_name, arguments=None, returns=None, bodies=None): super().__init__(ast_node, scope, symbol_name, SymbolType.function_def) self._arguments = None @@ -110,13 +121,10 @@ class FunctionSymbol(Symbol): def get_arguments(self): return self._arguments - def is_compilable(self): - return False - -class AssignSymbol(Symbol): +class AssignSymbol(UncompilableSymbol): def __init__(self, ast_node: ast.AST, scope, symbol_name, targets=None, value=None): - super().__init__(ast_node, scope, symbol_name, SymbolType.function_def) + super().__init__(ast_node, scope, symbol_name, SymbolType.assign) self._targets = [] if isinstance(targets, list): for target in targets: @@ -125,13 +133,10 @@ class AssignSymbol(Symbol): self._targets.append(target) self._value = value - def is_compilable(self): - return False - -class ArgumentsSymbol(Symbol): +class ArgumentsSymbol(UncompilableSymbol): def __init__(self, ast_node: ast.AST, scope, symbol_name, args=None, defaults=None): - super().__init__(ast_node, scope, symbol_name, SymbolType.function_def) + super().__init__(ast_node, scope, symbol_name, SymbolType.arguments) self._args = [] self._defaults = [] @@ -147,26 +152,31 @@ class ArgumentsSymbol(Symbol): raise RuntimeError("Input default is not Symbol type") self._defaults.append(default) - def is_compilable(self): - return False - -class CallSymbol(Symbol): - def __init__(self, ast_node: ast.AST, scope, symbol_name, value, const_type): # const_type: number, string - super().__init__(ast_node, scope, symbol_name, SymbolType.constant) - self._func_name = "" +class CallSymbol(UncompilableSymbol): + def __init__(self, ast_node: ast.AST, scope, symbol_name, func=None, args=None, keywords=None): + super().__init__(ast_node, scope, symbol_name, SymbolType.call) + self._func = None self._args = [] - self._kwargs = {} + self._keywords = [] - def is_compilable(self): - return False + if not isinstance(func, Symbol): + raise RuntimeError("Input func is not Symbol type") + self._func = func + + for arg in args: + if not isinstance(arg, Symbol): + raise RuntimeError("Input arg is not Symbol type") + self._args.append(arg) + + for keyword in keywords: + if not isinstance(keyword, Symbol): + raise RuntimeError("Input keyword is not Symbol type") + self._keywords.append(keyword) -class ConstantSymbol(Symbol): +class ConstantSymbol(UncompilableSymbol): def __init__(self, ast_node: ast.AST, scope, 
symbol_name, value, const_type): # const_type: number, string super().__init__(ast_node, scope, symbol_name, SymbolType.constant) self._value = value self._const_type = const_type - - def is_compilable(self): - return False diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_table.py b/mindspore/python/mindspore/rewrite_experiment/symbol_table.py index ccc2b6d8db4..a5a90ea7f8c 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_table.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_table.py @@ -37,9 +37,9 @@ class SymbolTable: network = network net_cls = type(network) network_str = inspect.getsource(net_cls) - ast_root: ast.AST = ast.parse(network_str) + self._ast_root: ast.AST = ast.parse(network_str) name = net_cls.__name__ - root_symbol = Symbol(ast_root, "", "Module(" + name + ")", SymbolType.module) + root_symbol = Symbol(self._ast_root, "", "Module(" + name + ")", SymbolType.module) self._table = {name: root_symbol} def remove_symbol(self, symbol: Symbol): @@ -63,3 +63,7 @@ for k, v in self._table.items(): print(v.get_full_name_with_scope()) print("=================================================================================================") + + def print_ast(self): + import astpretty  # imported lazily: a debug-only helper, so astpretty stays an optional dependency + astpretty.pprint(self._ast_root) -- Gitee From 3d6b0426313b1a4b8c5f58b03a6f86fa0801aeea Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 29 Dec 2021 14:45:19 +0800 Subject: [PATCH 23/34] add attribute compiler --- .../mindspore/rewrite_experiment/__init__.py | 1 + .../compilers/attribute_compiler.py | 45 +++++++++++++++++++ .../mindspore/rewrite_experiment/symbol.py | 10 +++++ 3 files changed, 56 insertions(+) create mode 100644 mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py diff --git a/mindspore/python/mindspore/rewrite_experiment/__init__.py b/mindspore/python/mindspore/rewrite_experiment/__init__.py index 2dfc4e7c7a6..6786d03bf0f 100644 --- a/mindspore/python/mindspore/rewrite_experiment/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/__init__.py @@ -7,5 +7,6 @@ from .compilers.function_def_compiler import FunctionDefCompiler from .compilers.assign_compiler import AssignCompiler from .compilers.arguments_compiler import ArgumentsCompiler from .compilers.call_compiler import CallCompiler +from .compilers.attribute_compiler import AttributeCompiler __all__ = ["Graph", "Rewrite", "SymbolTable"] diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py new file mode 100644 index 00000000000..40e079d55ae --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py @@ -0,0 +1,45 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
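# (Illustrative sketch of the input shape.) AttributeCompiler below unpacks an
# ast.Attribute into value (the object being accessed, itself an AST node that
# needs further compiling) and attr (the accessed name, already a plain str,
# which is why only value gets a Symbol of its own):
import ast

attribute = ast.parse("self.fc1", mode="eval").body
assert isinstance(attribute, ast.Attribute)
assert isinstance(attribute.value, ast.Name)  # `self`
assert attribute.attr == "fc1"                # leaf string, no sub-symbol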
+# ============================================================================ +import ast + +from ..symbol import Symbol, SymbolType, AttributeSymbol +from ..compiler import Compiler +from ..registers import CompilerRegister + + +@CompilerRegister.reg_compiler +class AttributeCompiler(Compiler): + def process(self, symbol: Symbol) -> [Symbol]: + if not isinstance(symbol.get_ast(), ast.Attribute): + return [symbol] + if isinstance(symbol, AttributeSymbol): + return [symbol] + new_symbols = [] + ast_attribute: ast.Attribute = symbol.get_ast() + + attr_ = ast_attribute.attr + value_ = ast_attribute.value + + if not (isinstance(value_, ast.Name) or isinstance(value_, ast.Call)): + raise RuntimeError("value of attribute should be ast.Name or ast.Call") + value_symbol = Symbol(value_, symbol.get_full_name_with_scope(), + "Value", SymbolType.expression) + new_symbols.append(value_symbol) + + attribute_symbol = AttributeSymbol( + ast_attribute, symbol.get_scope(), symbol.get_symbol_name(), attr=attr_, value=value_symbol + ) + new_symbols.append(attribute_symbol) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index aba2561177f..601f362db2a 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -175,6 +175,16 @@ class CallSymbol(UncompilableSymbol): self._keywords.append(keyword) +class AttributeSymbol(UncompilableSymbol): + def __init__(self, ast_node: ast.AST, scope, symbol_name, attr=None, value=None): + super().__init__(ast_node, scope, symbol_name, SymbolType.attribute) + self._attr = attr + + if not isinstance(value, Symbol): + raise RuntimeError("Input value is not Symbol type") + self._value = value + + class ConstantSymbol(UncompilableSymbol): def __init__(self, ast_node: ast.AST, scope, symbol_name, value, const_type): # const_type: number, string super().__init__(ast_node, scope, symbol_name, SymbolType.constant) -- Gitee From 9f8864b002bec548839628c58d746825d5843421 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 29 Dec 2021 16:19:47 +0800 Subject: [PATCH 24/34] add if compiler add if compiler add if compiler --- .../mindspore/rewrite_experiment/__init__.py | 1 + .../compilers/arguments_compiler.py | 3 +- .../compilers/assign_compiler.py | 7 +- .../compilers/attribute_compiler.py | 5 +- .../compilers/call_compiler.py | 24 +++-- .../compilers/function_def_compiler.py | 26 +++-- .../compilers/if_compiler.py | 65 +++++++++++++ .../mindspore/rewrite_experiment/symbol.py | 96 ++++++++++++++++--- 8 files changed, 191 insertions(+), 36 deletions(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py diff --git a/mindspore/python/mindspore/rewrite_experiment/__init__.py b/mindspore/python/mindspore/rewrite_experiment/__init__.py index 6786d03bf0f..a8bb865704c 100644 --- a/mindspore/python/mindspore/rewrite_experiment/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/__init__.py @@ -5,6 +5,7 @@ from .compilers.module_compiler import ModuleCompiler from .compilers.class_def_compiler import ClassDefCompiler from .compilers.function_def_compiler import FunctionDefCompiler from .compilers.assign_compiler import AssignCompiler +from .compilers.if_compiler import IfCompiler from .compilers.arguments_compiler import ArgumentsCompiler from .compilers.call_compiler import CallCompiler from .compilers.attribute_compiler import AttributeCompiler diff --git 
a/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py index a21d0603bf0..462b77fb201 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py @@ -59,7 +59,8 @@ class ArgumentsCompiler(Compiler): new_symbols.append(default_symbol) arguments_symbol = ArgumentsSymbol( - ast_arguments, symbol.get_scope(), symbol.get_symbol_name(), arg_symbols, default_symbols + ast_arguments, symbol.get_scope(), symbol.get_symbol_name(), + arg_symbols, default_symbols ) new_symbols.append(arguments_symbol) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py index 3312907f602..5eefb209666 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py @@ -45,7 +45,10 @@ class AssignCompiler(Compiler): value_symbol = Symbol(ast_assign.value, symbol.get_full_name_with_scope(), "Value(value-" + str(target_index) + ")", SymbolType.expression) new_symbols.append(value_symbol) - assign_symbol = AssignSymbol(ast_assign, symbol.get_scope(), symbol.get_symbol_name(), target_symbols, - value_symbol) + assign_symbol = AssignSymbol( + ast_assign, symbol.get_scope(), symbol.get_symbol_name(), + target_symbols, value_symbol + ) + new_symbols.append(assign_symbol) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py index 40e079d55ae..f1879e720c0 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py @@ -35,11 +35,12 @@ class AttributeCompiler(Compiler): if not (isinstance(value_, ast.Name) or isinstance(value_, ast.Call)): raise RuntimeError("value of attribute should be ast.Name or ast.Call") value_symbol = Symbol(value_, symbol.get_full_name_with_scope(), - "Value", SymbolType.expression) + "Attribute(value)", SymbolType.expression) new_symbols.append(value_symbol) attribute_symbol = AttributeSymbol( - ast_attribute, symbol.get_scope(), symbol.get_symbol_name(), attr=attr_, value=value_symbol + ast_attribute, symbol.get_scope(), symbol.get_symbol_name(), + attr=attr_, value=value_symbol ) new_symbols.append(attribute_symbol) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py index 1c7f72c67b5..4e2d4797898 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py @@ -17,6 +17,7 @@ import ast from ..symbol import Symbol, SymbolType, CallSymbol from ..compiler import Compiler from ..registers import CompilerRegister +from mindspore import log as logger @CompilerRegister.reg_compiler @@ -36,19 +37,21 @@ class CallCompiler(Compiler): if not isinstance(func_, ast.Attribute): raise RuntimeError("func of call should be ast.Attribute") func_symbol = Symbol(func_, symbol.get_full_name_with_scope(), - "Func", SymbolType.attribute) + "Call(func)", SymbolType.attribute) 
new_symbols.append(func_symbol) arg_symbols = [] arg_index = 0 for arg in args_: - if not (isinstance(arg, ast.Name) or isinstance(arg, ast.Constant) or isinstance(arg, ast.Call)): - raise RuntimeError("arg of Call should be ast.Name or ast.Constant or ast.Call") - arg_symbol = Symbol(arg, symbol.get_full_name_with_scope(), - "Arg(arg-" + str(arg_index) + ")", SymbolType.expression) - arg_index += 1 - arg_symbols.append(arg_symbol) - new_symbols.append(arg_symbol) + + if isinstance(arg, ast.Name) or isinstance(arg, ast.Constant) or isinstance(arg, ast.Call): + arg_symbol = Symbol(arg, symbol.get_full_name_with_scope(), + "Call(arg-" + str(arg_index) + ")", SymbolType.expression) + arg_index += 1 + arg_symbols.append(arg_symbol) + new_symbols.append(arg_symbol) + else: + logger.warning("Ignoring arg (%s) in call_compiler", type(arg).__name__) keyword_symbols = [] keyword_index = 0 @@ -57,13 +60,14 @@ class CallCompiler(Compiler): raise RuntimeError("keyword of Call should be ast.keyword") keyword_symbol = Symbol(keyword, symbol.get_full_name_with_scope(), - "Keyword(keyword-" + str(keyword_index) + ")", SymbolType.keyword) + "Call(keyword-" + str(keyword_index) + ")", SymbolType.keyword) keyword_index += 1 keyword_symbols.append(keyword_symbol) new_symbols.append(keyword_symbol) call_symbol = CallSymbol( - ast_call, symbol.get_scope(), symbol.get_symbol_name(), func_symbol, arg_symbols, keyword_symbols + ast_call, symbol.get_scope(), symbol.get_symbol_name(), + func_symbol, arg_symbols, keyword_symbols ) new_symbols.append(call_symbol) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py index fdaf730dc20..6bcca6364c9 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py @@ -37,18 +37,32 @@ class FunctionDefCompiler(Compiler): arguments_symbol = Symbol(arguments, symbol.get_full_name_with_scope(), "Arguments", SymbolType.arguments) returns = self._parse_returns(function_def.returns) - function_symbol = FunctionSymbol(function_def, symbol.get_scope(), symbol.get_symbol_name(), - arguments_symbol, returns) + function_symbol = FunctionSymbol( + function_def, symbol.get_scope(), symbol.get_symbol_name(), + arguments_symbol, returns + ) new_symbols = [function_symbol] new_symbols.append(arguments_symbol) bodies: list = function_def.body - index = 0 + assign_index, if_index = 0, 0 for body in bodies: if isinstance(body, ast.Assign): - body_symbol = Symbol(body, symbol.get_full_name_with_scope(), "Assign(assign-" + str(index) + ")", - SymbolType.assign) - index += 1 + body_symbol = Symbol( + body, symbol.get_full_name_with_scope(), + "Assign(assign-" + str(assign_index) + ")", SymbolType.assign + ) + assign_index += 1 + + function_symbol.add_body(body_symbol) + new_symbols.append(body_symbol) + elif isinstance(body, ast.If): + body_symbol = Symbol( + body, symbol.get_full_name_with_scope(), + "If(if-" + str(if_index) + ")", SymbolType.If + ) + if_index += 1 + function_symbol.add_body(body_symbol) new_symbols.append(body_symbol) else: diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py new file mode 100644 index 00000000000..c854f6c0a50 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py @@ 
-0,0 +1,65 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast + +from ..symbol import Symbol, SymbolType, IfSymbol +from ..compiler import Compiler +from ..registers import CompilerRegister +from mindspore import log as logger + + +@CompilerRegister.reg_compiler +class IfCompiler(Compiler): + def process(self, symbol: Symbol) -> [Symbol]: + if not isinstance(symbol.get_ast(), ast.If): + return [symbol] + if isinstance(symbol, IfSymbol): + return [symbol] + new_symbols = [] + ast_if: ast.If = symbol.get_ast() + _test = ast_if.test + _body = ast_if.body + + if not isinstance(_test, ast.expr): + raise RuntimeError("test of if should be ast.expr") + test_symbol = Symbol( + _test, symbol.get_full_name_with_scope(), + "Test", SymbolType.expression + ) + new_symbols.append(test_symbol) + + index = 0 + body_symbols = [] + for body in _body: + if isinstance(body, ast.Assign): + new_node = Symbol( + body, symbol.get_full_name_with_scope(), + "Assign(assign-" + str(index) + ")", SymbolType.assign + ) + index += 1 + body_symbols.append(new_node) + new_symbols.append(new_node) + else: + if hasattr(body, "name"): + logger.warning("Ignoring symbol(%s) in If body", body.name) + else: + logger.warning("Ignoring symbol(%s) in If body", body) + + if_symbol = IfSymbol( + ast_if, symbol.get_scope(), symbol.get_symbol_name(), + test_symbol, body_symbols + ) + new_symbols.append(if_symbol) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index 601f362db2a..c8657a17780 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -28,11 +28,14 @@ class SymbolType: attribute = 10 keyword = 11 call = 12 + If = 13 invalid = 100 class Symbol: - def __init__(self, ast_node: ast.AST, scope, symbol_name, symbol_type=SymbolType.invalid): + def __init__( + self, ast_node: ast.AST, scope, symbol_name, symbol_type=SymbolType.invalid, + ): # self._attribute: AttributeNode = AttributeNode() # if outputs is None: # self._outputs: List[CellNode] = list() @@ -51,6 +54,14 @@ ... self._owner = None + if hasattr(self._ast_root, "lineno"): + self._lineno = self._ast_root.lineno + if hasattr(self._ast_root, "col_offset"): + self.col_offset = self._ast_root.col_offset + if hasattr(self._ast_root, "end_lineno"): + self.end_lineno = self._ast_root.end_lineno + if hasattr(self._ast_root, "end_col_offset"): + self.end_col_offset = self._ast_root.end_col_offset def is_compilable(self): return True @@ -87,16 +98,25 @@ class Symbol: class UncompilableSymbol(Symbol): - def __init__(self, ast_node: ast.AST, scope, symbol_name, symbol_type=SymbolType.invalid): - super(UncompilableSymbol, self).__init__(ast_node, scope, symbol_name, symbol_type) + def __init__( + self, 
ast_node: ast.AST, scope, symbol_name, symbol_type=SymbolType.invalid, + ): + super(UncompilableSymbol, self).__init__( + ast_node, scope, symbol_name, symbol_type, + ) def is_compilable(self): return False class FunctionSymbol(UncompilableSymbol): - def __init__(self, ast_node: ast.AST, scope, symbol_name, arguments=None, returns=None, bodies=None): - super().__init__(ast_node, scope, symbol_name, SymbolType.function_def) + def __init__( + self, ast_node: ast.AST, scope, symbol_name, + arguments=None, returns=None, bodies=None + ): + super().__init__( + ast_node, scope, symbol_name, SymbolType.function_def, + ) self._arguments = None if not isinstance(arguments, Symbol): raise RuntimeError("Input arguments is not Symbol type") @@ -123,8 +143,13 @@ class FunctionSymbol(UncompilableSymbol): class AssignSymbol(UncompilableSymbol): - def __init__(self, ast_node: ast.AST, scope, symbol_name, targets=None, value=None): - super().__init__(ast_node, scope, symbol_name, SymbolType.assign) + def __init__( + self, ast_node: ast.AST, scope, symbol_name, + targets=None, value=None + ): + super().__init__( + ast_node, scope, symbol_name, SymbolType.assign, + ) self._targets = [] if isinstance(targets, list): for target in targets: @@ -134,9 +159,35 @@ class AssignSymbol(UncompilableSymbol): self._value = value +class IfSymbol(UncompilableSymbol): + def __init__( + self, ast_node: ast.AST, scope, symbol_name, + test=None, bodies=None + ): + super().__init__( + ast_node, scope, symbol_name, SymbolType.assign, + ) + self._test = None + self._bodies = [] + + if not isinstance(test, Symbol): + raise RuntimeError('test of If is not a symbol type') + + if isinstance(bodies, list): + for body in bodies: + if not isinstance(body, Symbol): + raise RuntimeError("body of If is not a symbol type") + self._bodies.append(body) + + class ArgumentsSymbol(UncompilableSymbol): - def __init__(self, ast_node: ast.AST, scope, symbol_name, args=None, defaults=None): - super().__init__(ast_node, scope, symbol_name, SymbolType.arguments) + def __init__( + self, ast_node: ast.AST, scope, symbol_name, + args=None, defaults=None + ): + super().__init__( + ast_node, scope, symbol_name, SymbolType.arguments, + ) self._args = [] self._defaults = [] @@ -154,8 +205,13 @@ class ArgumentsSymbol(UncompilableSymbol): class CallSymbol(UncompilableSymbol): - def __init__(self, ast_node: ast.AST, scope, symbol_name, func=None, args=None, keywords=None): - super().__init__(ast_node, scope, symbol_name, SymbolType.call) + def __init__( + self, ast_node: ast.AST, scope, symbol_name, + func=None, args=None, keywords=None + ): + super().__init__( + ast_node, scope, symbol_name, SymbolType.call, + ) self._func = None self._args = [] self._keywords = [] @@ -176,8 +232,13 @@ class CallSymbol(UncompilableSymbol): class AttributeSymbol(UncompilableSymbol): - def __init__(self, ast_node: ast.AST, scope, symbol_name, attr=None, value=None): - super().__init__(ast_node, scope, symbol_name, SymbolType.attribute) + def __init__( + self, ast_node: ast.AST, scope, symbol_name, + attr=None, value=None + ): + super().__init__( + ast_node, scope, symbol_name, SymbolType.attribute, + ) self._attr = attr if not isinstance(value, Symbol): @@ -186,7 +247,12 @@ class AttributeSymbol(UncompilableSymbol): class ConstantSymbol(UncompilableSymbol): - def __init__(self, ast_node: ast.AST, scope, symbol_name, value, const_type): # const_type: number, string - super().__init__(ast_node, scope, symbol_name, SymbolType.constant) + def __init__( + self, ast_node: ast.AST, 
scope, symbol_name, + value=None, const_type=None + ): # const_type: number, string + super().__init__( + ast_node, scope, symbol_name, SymbolType.constant, + ) self._value = value self._const_type = const_type -- Gitee From 2ec2f2931e1040f8dfea80c20212c3d41ef4c82b Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 30 Dec 2021 10:10:03 +0800 Subject: [PATCH 25/34] add binop return leaf compilers add binop return leaf compilers --- .../mindspore/rewrite_experiment/__init__.py | 6 ++ .../compilers/arg_compiler.py | 39 ++++++++++ .../compilers/binop_compiler.py | 60 +++++++++++++++ .../compilers/constant_compiler.py | 39 ++++++++++ .../compilers/name_compiler.py | 39 ++++++++++ .../compilers/return_compiler.py | 48 ++++++++++++ .../mindspore/rewrite_experiment/rewrite.py | 7 +- .../mindspore/rewrite_experiment/symbol.py | 76 ++++++++++++++++--- 8 files changed, 304 insertions(+), 10 deletions(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py create mode 100644 mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py create mode 100644 mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py create mode 100644 mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py create mode 100644 mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py diff --git a/mindspore/python/mindspore/rewrite_experiment/__init__.py b/mindspore/python/mindspore/rewrite_experiment/__init__.py index a8bb865704c..cbe853d90b1 100644 --- a/mindspore/python/mindspore/rewrite_experiment/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/__init__.py @@ -9,5 +9,11 @@ from .compilers.if_compiler import IfCompiler from .compilers.arguments_compiler import ArgumentsCompiler from .compilers.call_compiler import CallCompiler from .compilers.attribute_compiler import AttributeCompiler +from .compilers.binop_compiler import BinopCompiler +from .compilers.return_compiler import ReturnCompiler +from .compilers.constant_compiler import ConstantCompiler +from .compilers.name_compiler import NameCompiler +from .compilers.arg_compiler import ArgCompiler + __all__ = ["Graph", "Rewrite", "SymbolTable"] diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py new file mode 100644 index 00000000000..2f3db930214 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py @@ -0,0 +1,39 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
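All five leaf compilers added in this patch share one shape: a class decorated with @CompilerRegister.reg_compiler whose process() either expands a matching symbol or hands back [symbol] untouched so the next compiler can try. The registry internals are not part of this series; what follows is only a minimal sketch of that pattern, assuming CompilerRegister keeps a flat list of compiler instances (NoOpCompiler and compilers() are illustrative names, not code from this repository):

    from typing import List


    class Compiler:
        # Contract used throughout this series: rewrite one symbol into a
        # list of finer-grained symbols.
        def process(self, symbol) -> List:
            raise NotImplementedError


    class CompilerRegister:
        # Hypothetical stand-in for ..registers.CompilerRegister.
        _compilers: List[Compiler] = []

        @classmethod
        def reg_compiler(cls, compiler_cls):
            # Used as a class decorator: instantiate and remember the
            # compiler, then return the class unchanged.
            cls._compilers.append(compiler_cls())
            return compiler_cls

        @classmethod
        def compilers(cls) -> List[Compiler]:
            return cls._compilers


    @CompilerRegister.reg_compiler
    class NoOpCompiler(Compiler):
        def process(self, symbol) -> List:
            # A compiler that does not recognise the symbol returns it
            # unchanged as a one-element list, exactly like the early
            # returns in the compilers below.
            return [symbol]


    assert len(CompilerRegister.compilers()) == 1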
+# ============================================================================ +import ast + +from ..symbol import Symbol, ArgSymbol +from ..compiler import Compiler +from ..registers import CompilerRegister +from mindspore import log as logger + + +@CompilerRegister.reg_compiler +class ArgCompiler(Compiler): + def process(self, symbol: Symbol) -> [Symbol]: + if not isinstance(symbol.get_ast(), ast.arg): + return [symbol] + if isinstance(symbol, ArgSymbol): + return [symbol] + new_symbols = [] + ast_arg: ast.arg = symbol.get_ast() + _arg = ast_arg.arg + + arg_symbol = ArgSymbol( + ast_arg, symbol.get_scope(), symbol.get_symbol_name(), + arg_input=_arg, arg_type=type(_arg).__name__ + ) + new_symbols.append(arg_symbol) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py new file mode 100644 index 00000000000..ace6f191182 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py @@ -0,0 +1,60 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast + +from ..symbol import Symbol, SymbolType, BinopSymbol +from ..compiler import Compiler +from ..registers import CompilerRegister +from mindspore import log as logger + + +@CompilerRegister.reg_compiler +class BinopCompiler(Compiler): + def process(self, symbol: Symbol) -> [Symbol]: + if not isinstance(symbol.get_ast(), ast.BinOp): + return [symbol] + if isinstance(symbol, BinopSymbol): + return [symbol] + new_symbols = [] + ast_binop: ast.BinOp = symbol.get_ast() + + _left = ast_binop.left + _op = ast_binop.op + _right = ast_binop.right + + left_symbol = None + if isinstance(_left, ast.BinOp) or isinstance(_left, ast.Constant): + left_symbol = Symbol( + _left, symbol.get_full_name_with_scope(), + "Binop(left)", SymbolType.expression) + new_symbols.append(left_symbol) + else: + logger.warning("Ignoring left (%s) in binop compiler", type(_left).__name__) + + right_symbol = None + if isinstance(_right, ast.BinOp) or isinstance(_right, ast.Constant): + right_symbol = Symbol( + _right, symbol.get_full_name_with_scope(), + "Binop(right)", SymbolType.expression) + new_symbols.append(right_symbol) + else: + logger.warning("Ignoring right (%s) in binop compiler", type(_right).__name__) + + binop_symbol = BinopSymbol( + ast_binop, symbol.get_scope(), symbol.get_symbol_name(), + left_symbol, _op, right_symbol + ) + new_symbols.append(binop_symbol) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py new file mode 100644 index 00000000000..d1d29b10f12 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py @@ -0,0 +1,39 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under 
the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast + +from ..symbol import Symbol, ConstantSymbol +from ..compiler import Compiler +from ..registers import CompilerRegister +from mindspore import log as logger + + +@CompilerRegister.reg_compiler +class ConstantCompiler(Compiler): + def process(self, symbol: Symbol) -> [Symbol]: + if not isinstance(symbol.get_ast(), ast.Constant): + return [symbol] + if isinstance(symbol, ConstantSymbol): + return [symbol] + new_symbols = [] + ast_constant: ast.Constant = symbol.get_ast() + _value = ast_constant.value + + constant_symbol = ConstantSymbol( + ast_constant, symbol.get_scope(), symbol.get_symbol_name(), + value=_value, const_type=type(_value).__name__ + ) + new_symbols.append(constant_symbol) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py new file mode 100644 index 00000000000..668bd31e880 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py @@ -0,0 +1,39 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast + +from ..symbol import Symbol, SymbolType, NameSymbol +from ..compiler import Compiler +from ..registers import CompilerRegister +from mindspore import log as logger + + +@CompilerRegister.reg_compiler +class NameCompiler(Compiler): + def process(self, symbol: Symbol) -> [Symbol]: + if not isinstance(symbol.get_ast(), ast.Name): + return [symbol] + if isinstance(symbol, NameSymbol): + return [symbol] + new_symbols = [] + ast_name: ast.Name = symbol.get_ast() + _id = ast_name.id + + name_symbol = NameSymbol( + ast_name, symbol.get_scope(), symbol.get_symbol_name(), + id_input=_id, name_type=type(_id).__name__ + ) + new_symbols.append(name_symbol) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py new file mode 100644 index 00000000000..6c057086b73 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py @@ -0,0 +1,48 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
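Before the compiler body, it is worth pinning down the AST shape ReturnCompiler consumes: a plain `return x` parses to an ast.Return whose value is an ast.Name, which is the only case handled below (any other value is warned about and skipped). A standard-library-only check:

    import ast

    # "return x" inside a function parses to ast.Return(value=ast.Name(id="x")).
    module = ast.parse("def f(x):\n    return x")
    func_def = module.body[0]
    ret_stmt = func_def.body[0]
    assert isinstance(ret_stmt, ast.Return)
    assert isinstance(ret_stmt.value, ast.Name)
    assert ret_stmt.value.id == "x"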
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast + +from ..symbol import Symbol, SymbolType, ReturnSymbol +from ..compiler import Compiler +from ..registers import CompilerRegister +from mindspore import log as logger + + +@CompilerRegister.reg_compiler +class ReturnCompiler(Compiler): + def process(self, symbol: Symbol) -> [Symbol]: + if not isinstance(symbol.get_ast(), ast.Return): + return [symbol] + if isinstance(symbol, ReturnSymbol): + return [symbol] + new_symbols = [] + ast_return: ast.Return = symbol.get_ast() + + _value = ast_return.value + value_symbol = None + if isinstance(_value, ast.Name): + value_symbol = Symbol( + _value, symbol.get_full_name_with_scope(), + "Value", SymbolType.name) + new_symbols.append(value_symbol) + else: + logger.warning("Ignoring value (%s) in return compiler", type(_value).__name__) + + return_symbol = ReturnSymbol( + ast_return, symbol.get_scope(), symbol.get_symbol_name(), + value_symbol + ) + new_symbols.append(return_symbol) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/rewrite.py b/mindspore/python/mindspore/rewrite_experiment/rewrite.py index 8d431d0fa27..d5e5ca3bc2d 100644 --- a/mindspore/python/mindspore/rewrite_experiment/rewrite.py +++ b/mindspore/python/mindspore/rewrite_experiment/rewrite.py @@ -29,8 +29,13 @@ from .graph import Graph class Rewrite: @staticmethod def _is_leaf_symbol(symbol: Symbol): + leaf_symbol_type_list = [ + SymbolType.constant, + SymbolType.arg, + SymbolType.name + ] # inseparable python node: - if symbol.symbol_type() is SymbolType.constant: + if symbol.symbol_type() in leaf_symbol_type_list: return False # not supported yet: # if node.node_type() is NodeType.graph: diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index c8657a17780..0c234d103fb 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -21,14 +21,18 @@ class SymbolType: function_def = 3 assign = 4 expression = 5 - constant = 6 - cell = 7 - primitive = 8 - arguments = 9 - attribute = 10 - keyword = 11 - call = 12 - If = 13 + cell = 6 + primitive = 7 + arguments = 8 + attribute = 9 + keyword = 10 + call = 11 + If = 12 + binop = 13 + return_type = 14 + name = 15 + arg = 16 + constant = 17 invalid = 100 @@ -165,7 +169,7 @@ class IfSymbol(UncompilableSymbol): test=None, bodies=None ): super().__init__( - ast_node, scope, symbol_name, SymbolType.assign, + ast_node, scope, symbol_name, SymbolType.If, ) self._test = None self._bodies = [] @@ -246,6 +250,36 @@ class AttributeSymbol(UncompilableSymbol): self._value = value +class BinopSymbol(UncompilableSymbol): + def __init__( + self, ast_node: ast.AST, scope, symbol_name, + left=None, op=None, right=None + ): + super().__init__( + ast_node, scope, symbol_name, SymbolType.binop, + ) + if not isinstance(left, Symbol): + raise RuntimeError("Input left is not Symbol type") + self._left = left + self._op = op + if not isinstance(right, Symbol): + raise RuntimeError("Input right 
is not Symbol type") + self._right = right + + +class ReturnSymbol(UncompilableSymbol): + def __init__( + self, ast_node: ast.AST, scope, symbol_name, + value=None + ): + super().__init__( + ast_node, scope, symbol_name, SymbolType.return_type, + ) + if not isinstance(value, Symbol): + raise RuntimeError("Input value is not Symbol type") + self._value = value + + class ConstantSymbol(UncompilableSymbol): def __init__( self, ast_node: ast.AST, scope, symbol_name, @@ -256,3 +290,27 @@ class ConstantSymbol(UncompilableSymbol): ) self._value = value self._const_type = const_type + + +class NameSymbol(UncompilableSymbol): + def __init__( + self, ast_node: ast.AST, scope, symbol_name, + id_input=None, name_type=None + ): # name_type: string + super().__init__( + ast_node, scope, symbol_name, SymbolType.name, + ) + self._id = id_input + self._name_type = name_type + + +class ArgSymbol(UncompilableSymbol): + def __init__( + self, ast_node: ast.AST, scope, symbol_name, + arg_input=None, arg_type=None + ): # arg_type: string + super().__init__( + ast_node, scope, symbol_name, SymbolType.arg, + ) + self._arg = arg_input + self._arg_type = arg_type -- Gitee From 985ed182cd61c3789499c55402b99a599d25f38b Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 30 Dec 2021 16:02:38 +0800 Subject: [PATCH 26/34] track return symbols track return symbols --- .../compilers/arg_compiler.py | 2 +- .../compilers/call_compiler.py | 3 +- .../compilers/constant_compiler.py | 2 +- .../compilers/function_def_compiler.py | 30 +++++++++++-------- .../compilers/if_compiler.py | 7 +++++ .../compilers/name_compiler.py | 2 +- .../mindspore/rewrite_experiment/symbol.py | 7 +---- 7 files changed, 31 insertions(+), 22 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py index 2f3db930214..c9e806d6b14 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py @@ -32,7 +32,7 @@ class ArgCompiler(Compiler): _arg = ast_arg.arg arg_symbol = ArgSymbol( - ast_arg, symbol.get_scope(), symbol.get_symbol_name(), + ast_arg, symbol.get_scope(), "{}[arg]".format(symbol.get_symbol_name()), arg_input=_arg, arg_type=type(_arg).__name__ ) new_symbols.append(arg_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py index 4e2d4797898..714f6978083 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py @@ -44,7 +44,8 @@ class CallCompiler(Compiler): arg_index = 0 for arg in args_: - if isinstance(arg, ast.Name) or isinstance(arg, ast.Constant) or isinstance(arg, ast.Call): + if isinstance(arg, ast.Name) or isinstance(arg, ast.Constant) \ + or isinstance(arg, ast.Call) or isinstance(arg, ast.BinOp): arg_symbol = Symbol(arg, symbol.get_full_name_with_scope(), "Call(arg-" + str(arg_index) + ")", SymbolType.expression) arg_index += 1 diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py index d1d29b10f12..de391f0e409 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py @@ -32,7 +32,7 
@@ class ConstantCompiler(Compiler): _value = ast_constant.value constant_symbol = ConstantSymbol( - ast_constant, symbol.get_scope(), symbol.get_symbol_name(), + ast_constant, symbol.get_scope(), "{}[constant]".format(symbol.get_symbol_name()), value=_value, const_type=type(_value).__name__ ) new_symbols.append(constant_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py index 6bcca6364c9..f5366964ffd 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py @@ -23,28 +23,21 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class FunctionDefCompiler(Compiler): - def _parse_returns(self, returns_node): - return [] - def process(self, symbol: Symbol) -> [Symbol]: if not isinstance(symbol.get_ast(), ast.FunctionDef): return [symbol] if isinstance(symbol, FunctionSymbol): return [symbol] function_def: ast.FunctionDef = symbol.get_ast() + + new_symbols = [] # parse args arguments: ast.arguments = function_def.args arguments_symbol = Symbol(arguments, symbol.get_full_name_with_scope(), "Arguments", SymbolType.arguments) - - returns = self._parse_returns(function_def.returns) - function_symbol = FunctionSymbol( - function_def, symbol.get_scope(), symbol.get_symbol_name(), - arguments_symbol, returns - ) - new_symbols = [function_symbol] new_symbols.append(arguments_symbol) bodies: list = function_def.body + body_symbols = [] assign_index, if_index = 0, 0 for body in bodies: if isinstance(body, ast.Assign): @@ -54,7 +47,7 @@ class FunctionDefCompiler(Compiler): ) assign_index += 1 - function_symbol.add_body(body_symbol) + body_symbols.append(body_symbol) new_symbols.append(body_symbol) elif isinstance(body, ast.If): body_symbol = Symbol( @@ -63,11 +56,24 @@ class FunctionDefCompiler(Compiler): ) if_index += 1 - function_symbol.add_body(body_symbol) + body_symbols.append(body_symbol) + new_symbols.append(body_symbol) + elif isinstance(body, ast.Return): + body_symbol = Symbol( + body, symbol.get_full_name_with_scope(), + "Return", SymbolType.If + ) + body_symbols.append(body_symbol) new_symbols.append(body_symbol) else: if hasattr(body, "name"): logger.warning("Ignoring symbol(%s) in FunctionDef", body.name) else: logger.warning("Ignoring symbol(%s) in FunctionDef", body) + + function_symbol = FunctionSymbol( + function_def, symbol.get_scope(), symbol.get_symbol_name(), + arguments_symbol, body_symbols + ) + new_symbols.append(function_symbol) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py index c854f6c0a50..ac1dcea7a7e 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py @@ -51,6 +51,13 @@ class IfCompiler(Compiler): index += 1 body_symbols.append(new_node) new_symbols.append(new_node) + elif isinstance(body, ast.Return): + new_node = Symbol( + body, symbol.get_full_name_with_scope(), + "Return", SymbolType.return_type + ) + body_symbols.append(new_node) + new_symbols.append(new_node) else: if hasattr(body, "name"): logger.warning("Ignoring symbol(%s) in ClassDef", body.name) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py 
b/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py index 668bd31e880..f82f018da68 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py @@ -32,7 +32,7 @@ class NameCompiler(Compiler): _id = ast_name.id name_symbol = NameSymbol( - ast_name, symbol.get_scope(), symbol.get_symbol_name(), + ast_name, symbol.get_scope(), "{}[name]".format(symbol.get_symbol_name()), id_input=_id, name_type=type(_id).__name__ ) new_symbols.append(name_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index 0c234d103fb..ca84b5b7425 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -116,7 +116,7 @@ class UncompilableSymbol(Symbol): class FunctionSymbol(UncompilableSymbol): def __init__( self, ast_node: ast.AST, scope, symbol_name, - arguments=None, returns=None, bodies=None + arguments=None, bodies=None ): super().__init__( ast_node, scope, symbol_name, SymbolType.function_def, @@ -127,11 +127,6 @@ class FunctionSymbol(UncompilableSymbol): else: self._arguments = arguments - if isinstance(returns, list): - self._returns = returns - else: - self._returns = [] - self._bodies = [] if isinstance(bodies, list): for body in bodies: -- Gitee From 94c091303482b82b732bebc9f69fa045cdc18d5d Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 31 Dec 2021 16:21:21 +0800 Subject: [PATCH 27/34] rebuild symbol name rebuild symbol name rebuild symbol name --- .../compilers/arg_compiler.py | 2 +- .../compilers/arguments_compiler.py | 14 ++-- .../compilers/assign_compiler.py | 14 ++-- .../compilers/attribute_compiler.py | 7 +- .../compilers/binop_compiler.py | 12 ++-- .../compilers/call_compiler.py | 20 ++++-- .../compilers/class_def_compiler.py | 6 +- .../compilers/constant_compiler.py | 2 +- .../compilers/function_def_compiler.py | 20 ++++-- .../compilers/if_compiler.py | 15 ++-- .../compilers/module_compiler.py | 6 +- .../compilers/name_compiler.py | 2 +- .../compilers/return_compiler.py | 5 +- .../mindspore/rewrite_experiment/namespace.py | 69 +++++++++++++++++++ .../mindspore/rewrite_experiment/symbol.py | 19 ++++- 15 files changed, 168 insertions(+), 45 deletions(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/namespace.py diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py index c9e806d6b14..2f3db930214 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py @@ -32,7 +32,7 @@ class ArgCompiler(Compiler): _arg = ast_arg.arg arg_symbol = ArgSymbol( - ast_arg, symbol.get_scope(), "{}[arg]".format(symbol.get_symbol_name()), + ast_arg, symbol.get_scope(), symbol.get_symbol_name(), arg_input=_arg, arg_type=type(_arg).__name__ ) new_symbols.append(arg_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py index 462b77fb201..3969870a94e 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py @@ -40,8 +40,11 @@ class ArgumentsCompiler(Compiler): if not isinstance(arg, 
ast.arg): raise RuntimeError("arg of arguments should be ast.arg") - arg_symbol = Symbol(arg, symbol.get_full_name_with_scope(), - "Arg(arg-" + str(arg_index) + ")", SymbolType.expression) + arg_symbol = Symbol( + ast_node=arg, + scope="{}.args".format(symbol.get_full_name_with_scope()), + symbol_type=SymbolType.arg + ) arg_index += 1 arg_symbols.append(arg_symbol) new_symbols.append(arg_symbol) @@ -52,8 +55,11 @@ class ArgumentsCompiler(Compiler): if not isinstance(single_default, ast.Constant): raise RuntimeError("only support constant default value in arguments now") - default_symbol = Symbol(single_default, symbol.get_full_name_with_scope(), - "Default(default-" + str(default_index) + ")", SymbolType.constant) + default_symbol = Symbol( + ast_node=single_default, + scope="{}.Defaults".format(symbol.get_full_name_with_scope()), + symbol_type=SymbolType.constant + ) default_index += 1 default_symbols.append(default_symbol) new_symbols.append(default_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py index 5eefb209666..86f30149dd1 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py @@ -34,16 +34,22 @@ class AssignCompiler(Compiler): for target in targets: if not isinstance(target, ast.expr): raise RuntimeError("Target of assign should be ast.expr") - target_symbol = Symbol(target, symbol.get_full_name_with_scope(), - "Target(target-" + str(target_index) + ")", SymbolType.expression) + target_symbol = Symbol( + ast_node=target, + scope="{}.Targets".format(symbol.get_full_name_with_scope()), + symbol_type=SymbolType.expression + ) target_index += 1 target_symbols.append(target_symbol) new_symbols.append(target_symbol) if not isinstance(ast_assign.value, ast.expr): raise RuntimeError("Value of assign should be ast.expr") - value_symbol = Symbol(ast_assign.value, symbol.get_full_name_with_scope(), - "Value(value-" + str(target_index) + ")", SymbolType.expression) + value_symbol = Symbol( + ast_node=ast_assign.value, + scope="{}.value".format(symbol.get_full_name_with_scope()), + symbol_type=SymbolType.expression + ) new_symbols.append(value_symbol) assign_symbol = AssignSymbol( ast_assign, symbol.get_scope(), symbol.get_symbol_name(), diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py index f1879e720c0..96a27fcae19 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py @@ -34,8 +34,11 @@ class AttributeCompiler(Compiler): if not (isinstance(value_, ast.Name) or isinstance(value_, ast.Call)): raise RuntimeError("value of attribute should be ast.Name or ast.Call") - value_symbol = Symbol(value_, symbol.get_full_name_with_scope(), - "Attribute(value)", SymbolType.expression) + value_symbol = Symbol( + ast_node=value_, + scope=symbol.get_full_name_with_scope(), + symbol_type=SymbolType.expression + ) new_symbols.append(value_symbol) attribute_symbol = AttributeSymbol( diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py index ace6f191182..3badf122339 100644 --- 
a/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py @@ -37,8 +37,10 @@ class BinopCompiler(Compiler): left_symbol = None if isinstance(_left, ast.BinOp) or isinstance(_left, ast.Constant): left_symbol = Symbol( - _left, symbol.get_full_name_with_scope(), - "Binop(left)", SymbolType.expression) + ast_node=_left, + scope="{}.left".format(symbol.get_full_name_with_scope()), + symbol_type=SymbolType.expression + ) new_symbols.append(left_symbol) else: logger.warning("Ignoring left (%s) in binop compiler", type(_left).__name__) @@ -46,8 +48,10 @@ class BinopCompiler(Compiler): right_symbol = None if isinstance(_right, ast.BinOp) or isinstance(_right, ast.Constant): right_symbol = Symbol( - _right, symbol.get_full_name_with_scope(), - "Binop(right)", SymbolType.expression) + ast_node=_right, + scope="{}.right".format(symbol.get_full_name_with_scope()), + symbol_type=SymbolType.expression + ) new_symbols.append(right_symbol) else: logger.warning("Ignoring right (%s) in binop compiler", type(_right).__name__) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py index 714f6978083..0be244c187a 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py @@ -36,8 +36,10 @@ class CallCompiler(Compiler): if not isinstance(func_, ast.Attribute): raise RuntimeError("func of call should be ast.Attribute") - func_symbol = Symbol(func_, symbol.get_full_name_with_scope(), - "Call(func)", SymbolType.attribute) + func_symbol = Symbol( + ast_node=func_, + scope="{}.func".format(symbol.get_full_name_with_scope()), + symbol_type=SymbolType.attribute) new_symbols.append(func_symbol) arg_symbols = [] @@ -46,8 +48,11 @@ class CallCompiler(Compiler): if isinstance(arg, ast.Name) or isinstance(arg, ast.Constant) \ or isinstance(arg, ast.Call) or isinstance(arg, ast.BinOp): - arg_symbol = Symbol(arg, symbol.get_full_name_with_scope(), - "Call(arg-" + str(arg_index) + ")", SymbolType.expression) + arg_symbol = Symbol( + ast_node=arg, + scope="{}.args".format(symbol.get_full_name_with_scope()), + symbol_type=SymbolType.expression + ) arg_index += 1 arg_symbols.append(arg_symbol) new_symbols.append(arg_symbol) @@ -60,8 +65,11 @@ class CallCompiler(Compiler): if not isinstance(keyword, ast.keyword): raise RuntimeError("keyword of Call should be ast.keyword") - keyword_symbol = Symbol(keyword, symbol.get_full_name_with_scope(), - "Call(keyword-" + str(keyword_index) + ")", SymbolType.keyword) + keyword_symbol = Symbol( + ast_node=keyword, + scope="{}.keywords".format(symbol.get_full_name_with_scope()), + symbol_type=SymbolType.keyword + ) keyword_index += 1 keyword_symbols.append(keyword_symbol) new_symbols.append(keyword_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py index afceaef8cd6..657722b1265 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py @@ -29,8 +29,10 @@ class ClassDefCompiler(Compiler): new_symbols: [Symbol] = [] for body in bodies: if isinstance(body, ast.FunctionDef): - new_node = Symbol(body, symbol.get_full_name_with_scope(), "Function(" + 
body.name + ")", - SymbolType.function_def) + new_node = Symbol( + ast_node=body, + scope=symbol.get_full_name_with_scope(), + symbol_type=SymbolType.function_def) new_symbols.append(new_node) else: if hasattr(body, "name"): diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py index de391f0e409..1e21e66fb60 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py @@ -32,7 +32,7 @@ class ConstantCompiler(Compiler): _value = ast_constant.value constant_symbol = ConstantSymbol( - ast_constant, symbol.get_scope(), "{}[constant]".format(symbol.get_symbol_name()), + ast_node=ast_constant, scope=symbol.get_scope(), symbol_name=symbol.get_symbol_name(), value=_value, const_type=type(_value).__name__ ) new_symbols.append(constant_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py index f5366964ffd..dc3fb923b14 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py @@ -33,7 +33,10 @@ class FunctionDefCompiler(Compiler): new_symbols = [] # parse args arguments: ast.arguments = function_def.args - arguments_symbol = Symbol(arguments, symbol.get_full_name_with_scope(), "Arguments", SymbolType.arguments) + arguments_symbol = Symbol( + ast_node=arguments, + scope=symbol.get_full_name_with_scope(), + symbol_type=SymbolType.arguments) new_symbols.append(arguments_symbol) bodies: list = function_def.body @@ -42,8 +45,9 @@ class FunctionDefCompiler(Compiler): for body in bodies: if isinstance(body, ast.Assign): body_symbol = Symbol( - body, symbol.get_full_name_with_scope(), - "Assign(assign-" + str(assign_index) + ")", SymbolType.assign + ast_node=body, + scope=symbol.get_full_name_with_scope(), + symbol_type=SymbolType.assign ) assign_index += 1 @@ -51,8 +55,9 @@ class FunctionDefCompiler(Compiler): new_symbols.append(body_symbol) elif isinstance(body, ast.If): body_symbol = Symbol( - body, symbol.get_full_name_with_scope(), - "If(if-" + str(if_index) + ")", SymbolType.If + ast_node=body, + scope=symbol.get_full_name_with_scope(), + symbol_type=SymbolType.If ) if_index += 1 @@ -60,8 +65,9 @@ class FunctionDefCompiler(Compiler): new_symbols.append(body_symbol) elif isinstance(body, ast.Return): body_symbol = Symbol( - body, symbol.get_full_name_with_scope(), - "Return", SymbolType.If + ast_node=body, + scope=symbol.get_full_name_with_scope(), + symbol_type=SymbolType.return_type ) body_symbols.append(body_symbol) new_symbols.append(body_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py index ac1dcea7a7e..b604af27835 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py @@ -35,8 +35,9 @@ class IfCompiler(Compiler): if not isinstance(_test, ast.expr): raise RuntimeError("test of if should be ast.expr") test_symbol = Symbol( - _test, symbol.get_full_name_with_scope(), - "Test", SymbolType.expression + ast_node=_test, + scope="{}.test".format(symbol.get_full_name_with_scope()), + 
symbol_type=SymbolType.expression ) new_symbols.append(test_symbol) @@ -45,16 +46,18 @@ class IfCompiler(Compiler): for body in _body: if isinstance(body, ast.Assign): new_node = Symbol( - body, symbol.get_full_name_with_scope(), - "Assign(assign-" + str(index) + ")", SymbolType.assign + ast_node=body, + scope=symbol.get_full_name_with_scope(), + symbol_type=SymbolType.assign ) index += 1 body_symbols.append(new_node) new_symbols.append(new_node) elif isinstance(body, ast.Return): new_node = Symbol( - body, symbol.get_full_name_with_scope(), - "Return", SymbolType.return_type + ast_node=body, + scope=symbol.get_full_name_with_scope(), + symbol_type=SymbolType.return_type ) body_symbols.append(new_node) new_symbols.append(new_node) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py index a71d37fe46c..207040b048f 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py @@ -29,8 +29,10 @@ class ModuleCompiler(Compiler): new_symbols: [Symbol] = [] for body in bodies: if isinstance(body, ast.ClassDef): - new_node = Symbol(body, symbol.get_full_name_with_scope(), "Class(" + body.name + ")", - SymbolType.class_def) + new_node = Symbol( + ast_node=body, + scope=symbol.get_full_name_with_scope(), + symbol_type=SymbolType.class_def) new_symbols.append(new_node) else: if hasattr(body, "name"): diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py index f82f018da68..668bd31e880 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py @@ -32,7 +32,7 @@ class NameCompiler(Compiler): _id = ast_name.id name_symbol = NameSymbol( - ast_name, symbol.get_scope(), "{}[name]".format(symbol.get_symbol_name()), + ast_name, symbol.get_scope(), symbol.get_symbol_name(), id_input=_id, name_type=type(_id).__name__ ) new_symbols.append(name_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py index 6c057086b73..ea0b71cf6d9 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py @@ -34,8 +34,9 @@ class ReturnCompiler(Compiler): value_symbol = None if isinstance(_value, ast.Name): value_symbol = Symbol( - _value, symbol.get_full_name_with_scope(), - "Value", SymbolType.name) + ast_node=_value, + scope=symbol.get_full_name_with_scope(), + symbol_type=SymbolType.name) new_symbols.append(value_symbol) else: logger.warning("Ignoring value (%s) in return compiler", type(_value).__name__) diff --git a/mindspore/python/mindspore/rewrite_experiment/namespace.py b/mindspore/python/mindspore/rewrite_experiment/namespace.py new file mode 100644 index 00000000000..0156ca84508 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/namespace.py @@ -0,0 +1,69 @@ +# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). +# +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
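The behaviour of the name depot defined below is easiest to read off from an example: the first symbol of a given AST type in a scope keeps the bare type name, later ones in the same scope get a running numeric suffix, and a fresh scope restarts the count. A usage sketch (the import path is assumed from this patch's layout; the scope strings are illustrative):

    from mindspore.rewrite_experiment.namespace import SymbolNameDepot

    depot = SymbolNameDepot()
    assert depot.set_symbol_name("MyNet.construct", "Assign") == "Assign"
    assert depot.set_symbol_name("MyNet.construct", "Assign") == "Assign_1"
    assert depot.set_symbol_name("MyNet.construct", "Assign") == "Assign_2"
    # A different scope starts counting from scratch.
    assert depot.set_symbol_name("MyNet.__init__", "Assign") == "Assign"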
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Define the name depot for symbols.""" + +import builtins + +from mindspore import log as logger + + +class NameDepot: + """ + Base name depot for symbols. + + Args: + name (str): The name depot's name. + """ + + def __init__(self, name): + self.name = name + self._name_dict = {} + + def __contains__(self, name): + if name in self._name_dict: + return True + else: + return False + + def __getitem__(self, name): + if name in self._name_dict: + return self._name_dict[name] + raise NameError(name) + + def __setitem__(self, key, value): + self._name_dict[key] = value + + def __repr__(self): + return f'Namespace:{self.name}' + + +class SymbolNameDepot: + """ + NameDepot for Symbol objects + """ + def __init__(self): + self._names = NameDepot("Symbol") + + def set_symbol_name(self, scope: str, symbol_type: str): + name_key = "{}.{}".format(scope, symbol_type) + if name_key in self._names: + self._names[name_key] += 1 + return "{}_{}".format(symbol_type, str(self._names[name_key])) + else: + self._names[name_key] = 0 + return symbol_type diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index ca84b5b7425..c6a10e67be7 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -14,6 +14,11 @@ # ============================================================================ import ast +from .namespace import SymbolNameDepot + + +sn_depot = SymbolNameDepot() + class SymbolType: module = 1 @@ -38,7 +43,7 @@ class Symbol: def __init__( - self, ast_node: ast.AST, scope, symbol_name, symbol_type=SymbolType.invalid, + self, ast_node: ast.AST, scope, symbol_name=None, symbol_type=SymbolType.invalid, ): # self._attribute: AttributeNode = AttributeNode() # if outputs is None: @@ -56,7 +61,7 @@ class Symbol: self._attribute: dict = {} self._sub_symbols: list = [] self._scope = scope - self._symbol_name: str = symbol_name + self._symbol_name: str = self.init_symbol_name() if not symbol_name else symbol_name self._owner = None if hasattr(self._ast_root, "lineno"): self._lineno = self._ast_root.lineno @@ -67,6 +72,14 @@ if hasattr(self._ast_root, "end_col_offset"): self.end_col_offset = self._ast_root.end_col_offset + def init_symbol_name(self): + global sn_depot + symbol_ast_type: str = type(self._ast_root).__name__ + if hasattr(self._ast_root, "name"): + symbol_ast_type = "{}({})".format(symbol_ast_type, self._ast_root.name) + symbol_name = sn_depot.set_symbol_name(self._scope, symbol_ast_type) + return symbol_name + def is_compilable(self): return True @@ -106,7 +119,7 @@ class UncompilableSymbol(Symbol): self, ast_node: ast.AST, scope, symbol_name, symbol_type=SymbolType.invalid, ): super(UncompilableSymbol, self).__init__( - ast_node, scope, symbol_name, symbol_type, + ast_node, scope, symbol_name=symbol_name, symbol_type=symbol_type, ) def is_compilable(self): -- Gitee From
eaaeb37ecf78b83c72a1639ff6d895a2449b094b Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 5 Jan 2022 16:51:51 +0800 Subject: [PATCH 28/34] rebuild child symbol --- .../compilers/arguments_compiler.py | 4 +- .../compilers/assign_compiler.py | 4 +- .../compilers/attribute_compiler.py | 2 +- .../compilers/binop_compiler.py | 2 +- .../compilers/call_compiler.py | 6 +- .../compilers/function_def_compiler.py | 8 +- .../compilers/if_compiler.py | 6 +- .../compilers/return_compiler.py | 2 +- .../mindspore/rewrite_experiment/symbol.py | 109 ++++++------------ 9 files changed, 53 insertions(+), 90 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py index 3969870a94e..9dd9940ec51 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py @@ -46,7 +46,7 @@ class ArgumentsCompiler(Compiler): symbol_type=SymbolType.arg ) arg_index += 1 - arg_symbols.append(arg_symbol) + arg_symbols.append(arg_symbol.symbol_name) new_symbols.append(arg_symbol) default_symbols = [] @@ -61,7 +61,7 @@ class ArgumentsCompiler(Compiler): symbol_type=SymbolType.constant ) default_index += 1 - default_symbols.append(default_symbol) + default_symbols.append(default_symbol.symbol_name) new_symbols.append(default_symbol) arguments_symbol = ArgumentsSymbol( diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py index 86f30149dd1..d2616c81e5c 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py @@ -40,7 +40,7 @@ class AssignCompiler(Compiler): symbol_type=SymbolType.expression ) target_index += 1 - target_symbols.append(target_symbol) + target_symbols.append(target_symbol.symbol_name) new_symbols.append(target_symbol) if not isinstance(ast_assign.value, ast.expr): @@ -53,7 +53,7 @@ class AssignCompiler(Compiler): new_symbols.append(value_symbol) assign_symbol = AssignSymbol( ast_assign, symbol.get_scope(), symbol.get_symbol_name(), - target_symbols, value_symbol + target_symbols, value_symbol.symbol_name ) new_symbols.append(assign_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py index 96a27fcae19..c83322312a7 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py @@ -43,7 +43,7 @@ class AttributeCompiler(Compiler): attribute_symbol = AttributeSymbol( ast_attribute, symbol.get_scope(), symbol.get_symbol_name(), - attr=attr_, value=value_symbol + attr=attr_, value=value_symbol.symbol_name ) new_symbols.append(attribute_symbol) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py index 3badf122339..b441afe7ad2 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py @@ -58,7 +58,7 @@ class BinopCompiler(Compiler): binop_symbol = BinopSymbol( ast_binop, symbol.get_scope(), 
symbol.get_symbol_name(), - left_symbol, _op, right_symbol + left_symbol.symbol_name, _op, right_symbol.symbol_name ) new_symbols.append(binop_symbol) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py index 0be244c187a..cada24b6169 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py @@ -54,7 +54,7 @@ class CallCompiler(Compiler): symbol_type=SymbolType.expression ) arg_index += 1 - arg_symbols.append(arg_symbol) + arg_symbols.append(arg_symbol.symbol_name) new_symbols.append(arg_symbol) else: logger.warning("Ignoring arg (%s) in call_compiler", type(arg).__name__) @@ -71,12 +71,12 @@ class CallCompiler(Compiler): symbol_type=SymbolType.keyword ) keyword_index += 1 - keyword_symbols.append(keyword_symbol) + keyword_symbols.append(keyword_symbol.symbol_name) new_symbols.append(keyword_symbol) call_symbol = CallSymbol( ast_call, symbol.get_scope(), symbol.get_symbol_name(), - func_symbol, arg_symbols, keyword_symbols + func_symbol.symbol_name, arg_symbols, keyword_symbols ) new_symbols.append(call_symbol) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py index dc3fb923b14..5feff8425a1 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py @@ -51,7 +51,7 @@ class FunctionDefCompiler(Compiler): ) assign_index += 1 - body_symbols.append(body_symbol) + body_symbols.append(body_symbol.symbol_name) new_symbols.append(body_symbol) elif isinstance(body, ast.If): body_symbol = Symbol( @@ -61,7 +61,7 @@ class FunctionDefCompiler(Compiler): ) if_index += 1 - body_symbols.append(body_symbol) + body_symbols.append(body_symbol.symbol_name) new_symbols.append(body_symbol) elif isinstance(body, ast.Return): body_symbol = Symbol( @@ -69,7 +69,7 @@ class FunctionDefCompiler(Compiler): scope=symbol.get_full_name_with_scope(), symbol_type=SymbolType.return_type ) - body_symbols.append(body_symbol) + body_symbols.append(body_symbol.symbol_name) new_symbols.append(body_symbol) else: if hasattr(body, "name"): @@ -79,7 +79,7 @@ class FunctionDefCompiler(Compiler): function_symbol = FunctionSymbol( function_def, symbol.get_scope(), symbol.get_symbol_name(), - arguments_symbol, body_symbols + arguments_symbol.symbol_name, body_symbols ) new_symbols.append(function_symbol) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py index b604af27835..27ab0b2023b 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py @@ -51,7 +51,7 @@ class IfCompiler(Compiler): symbol_type=SymbolType.assign ) index += 1 - body_symbols.append(new_node) + body_symbols.append(new_node.symbol_name) new_symbols.append(new_node) elif isinstance(body, ast.Return): new_node = Symbol( @@ -59,7 +59,7 @@ class IfCompiler(Compiler): scope=symbol.get_full_name_with_scope(), symbol_type=SymbolType.return_type ) - body_symbols.append(new_node) + body_symbols.append(new_node.symbol_name) new_symbols.append(new_node) else: if 
hasattr(body, "name"): @@ -69,7 +69,7 @@ class IfCompiler(Compiler): if_symbol = IfSymbol( ast_if, symbol.get_scope(), symbol.get_symbol_name(), - test_symbol, body_symbols + test_symbol.symbol_name, body_symbols ) new_symbols.append(if_symbol) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py index ea0b71cf6d9..2361410b691 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py @@ -43,7 +43,7 @@ class ReturnCompiler(Compiler): return_symbol = ReturnSymbol( ast_return, symbol.get_scope(), symbol.get_symbol_name(), - value_symbol + value_symbol.symbol_name ) new_symbols.append(return_symbol) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index c6a10e67be7..d324aa238f3 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -83,6 +83,10 @@ class Symbol: def is_compilable(self): return True + @property + def symbol_name(self): + return self._symbol_name + def get_scope(self) -> str: return self._scope @@ -113,6 +117,16 @@ class Symbol: def get_sub_symbols(self): return self._sub_symbols + @staticmethod + def check_input_is_str(input_value): + if not isinstance(input_value, str): + raise RuntimeError("input is not str") + + def copy_list(self, target_list: list, source_list: list): + for value in source_list: + self.check_input_is_str(value) + target_list.append(value) + class UncompilableSymbol(Symbol): def __init__( @@ -129,162 +143,111 @@ class UncompilableSymbol(Symbol): class FunctionSymbol(UncompilableSymbol): def __init__( self, ast_node: ast.AST, scope, symbol_name, - arguments=None, bodies=None + arguments: str = None, bodies: list = None ): super().__init__( ast_node, scope, symbol_name, SymbolType.function_def, ) - self._arguments = None - if not isinstance(arguments, Symbol): - raise RuntimeError("Input arguments is not Symbol type") - else: - self._arguments = arguments + self._arguments = arguments - self._bodies = [] - if isinstance(bodies, list): - for body in bodies: - if not isinstance(body, Symbol): - raise RuntimeError("Input bodies is not Symbol type") - self._bodies.append(body) + self._bodies_names = [] + self.copy_list(self._bodies_names, bodies) - def add_body(self, body: Symbol): - self._bodies.append(body) + def add_body(self, body: str): + self._bodies_names.append(body) - def get_arguments(self): + def get_arguments(self) -> str: return self._arguments class AssignSymbol(UncompilableSymbol): def __init__( self, ast_node: ast.AST, scope, symbol_name, - targets=None, value=None + targets: list = None, value: str = None ): super().__init__( ast_node, scope, symbol_name, SymbolType.assign, ) self._targets = [] - if isinstance(targets, list): - for target in targets: - if not isinstance(target, Symbol): - raise RuntimeError("Input targets is not Symbol type") - self._targets.append(target) + self.copy_list(self._targets, targets) self._value = value class IfSymbol(UncompilableSymbol): def __init__( self, ast_node: ast.AST, scope, symbol_name, - test=None, bodies=None + test: str = None, bodies: list = None ): super().__init__( ast_node, scope, symbol_name, SymbolType.If, ) - self._test = None + self._test = test self._bodies = [] - - if not 
isinstance(test, Symbol): - raise RuntimeError('test of If is not a symbol type') - - if isinstance(bodies, list): - for body in bodies: - if not isinstance(body, Symbol): - raise RuntimeError("body of If is not a symbol type") - self._bodies.append(body) + self.copy_list(self._bodies, bodies) class ArgumentsSymbol(UncompilableSymbol): def __init__( self, ast_node: ast.AST, scope, symbol_name, - args=None, defaults=None + args: list = None, defaults: list = None ): super().__init__( ast_node, scope, symbol_name, SymbolType.arguments, ) self._args = [] self._defaults = [] - - if isinstance(args, list): - for arg in args: - if not isinstance(arg, Symbol): - raise RuntimeError("Input arg is not Symbol type") - self._args.append(arg) - - if isinstance(defaults, list): - for default in defaults: - if not isinstance(default, Symbol): - raise RuntimeError("Input default is not Symbol type") - self._defaults.append(default) + self.copy_list(self._args, args) + self.copy_list(self._defaults, defaults) class CallSymbol(UncompilableSymbol): def __init__( self, ast_node: ast.AST, scope, symbol_name, - func=None, args=None, keywords=None + func: str = None, args: list = None, keywords: list = None ): super().__init__( ast_node, scope, symbol_name, SymbolType.call, ) - self._func = None + self._func = func self._args = [] self._keywords = [] - - if not isinstance(func, Symbol): - raise RuntimeError("Input func is not Symbol type") - self._func = func - - for arg in args: - if not isinstance(arg, Symbol): - raise RuntimeError("Input arg is not Symbol type") - self._args.append(arg) - - for keyword in keywords: - if not isinstance(keyword, Symbol): - raise RuntimeError("Input keyword is not Symbol type") - self._keywords.append(keyword) + self.copy_list(self._args, args) + self.copy_list(self._keywords, keywords) class AttributeSymbol(UncompilableSymbol): def __init__( self, ast_node: ast.AST, scope, symbol_name, - attr=None, value=None + attr: str = None, value: str = None ): super().__init__( ast_node, scope, symbol_name, SymbolType.attribute, ) self._attr = attr - - if not isinstance(value, Symbol): - raise RuntimeError("Input value is not Symbol type") self._value = value class BinopSymbol(UncompilableSymbol): def __init__( self, ast_node: ast.AST, scope, symbol_name, - left=None, op=None, right=None + left: str = None, op=None, right: str = None ): super().__init__( ast_node, scope, symbol_name, SymbolType.binop, ) - if not isinstance(left, Symbol): - raise RuntimeError("Input left is not Symbol type") self._left = left self._op = op - if not isinstance(right, Symbol): - raise RuntimeError("Input right is not Symbol type") self._right = right class ReturnSymbol(UncompilableSymbol): def __init__( self, ast_node: ast.AST, scope, symbol_name, - value=None + value: str = None ): super().__init__( ast_node, scope, symbol_name, SymbolType.return_type, ) - if not isinstance(value, Symbol): - raise RuntimeError("Input value is not Symbol type") self._value = value -- Gitee From 7e1ecced4c18bb4826cab2e70c5f72d22251350b Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Wed, 5 Jan 2022 09:45:59 +0800 Subject: [PATCH 29/34] fix document and rewrite bug --- .../mindspore/golden_stick/net_transform.py | 40 +++++++++---------- .../default_qat/default_fake_quantizer.py | 12 ------ .../default_qat/default_layer_policy.py | 1 + .../quantization/fake_quantizer.py | 3 ++ .../golden_stick/quantization/quantize.py | 38 +++++++++++++++++- .../quantization/quantize_wrapper_act.py | 2 +- 
.../quantization/quantize_wrapper_cell.py | 2 +- 7 files changed, 62 insertions(+), 36 deletions(-) diff --git a/mindspore/python/mindspore/golden_stick/net_transform.py b/mindspore/python/mindspore/golden_stick/net_transform.py index f5fb25f87c1..8c337cf5221 100644 --- a/mindspore/python/mindspore/golden_stick/net_transform.py +++ b/mindspore/python/mindspore/golden_stick/net_transform.py @@ -16,7 +16,7 @@ from typing import Union, Optional from mindspore.nn.cell import Cell -from mindspore.rewrite import Graph, Node, PatternEngine +from mindspore.rewrite import Graph, BaseNode, PatternEngine class NetTransformer: @@ -39,35 +39,35 @@ class NetTransformer: return self._graph.python_object() - def nodes(self) -> [Node]: + def nodes(self) -> [BaseNode]: """ Returns: - a list of Node corresponding to all layers in original network. + a list of BaseNode corresponding to all layers in original network. """ return self._graph.nodes @staticmethod - def set_node_attr(node: Node, key: str, value): + def set_node_attr(node: BaseNode, key: str, value): node.attribute.set_attribute(key, value) @staticmethod - def get_node_attr(node: Node, key: str): + def get_node_attr(node: BaseNode, key: str): return node.attribute.attribute[key] - def find_node(self, full_name_with_scope: str) -> Node: + def find_node(self, full_name_with_scope: str) -> BaseNode: """ Args: full_name_with_scope (str): Name of node to be find. Returns: - Node whose name is `full_name_with_scope`. + BaseNode whose name is `full_name_with_scope`. """ return self._graph.find(full_name_with_scope) # return inputs of node whose full_name_with_scope is full_name_with_scope - def node_inputs(self, full_name_with_scope: str) -> [Node]: + def node_inputs(self, full_name_with_scope: str) -> [BaseNode]: """ Args: full_name_with_scope (str): Name of node to be find. @@ -82,7 +82,7 @@ class NetTransformer: return node.inputs # return outputs of node whose full_name_with_scope is full_name_with_scope - def node_outputs(self, full_name_with_scope: str) -> [Node]: + def node_outputs(self, full_name_with_scope: str) -> [BaseNode]: """ Args: full_name_with_scope (str): Name of node to be find. @@ -96,25 +96,25 @@ class NetTransformer: return [] return node.outputs - def insert_node(self, new_node: Node) -> Node: + def insert_node(self, new_node: BaseNode) -> BaseNode: """ Args: - new_node (Node): New node to be inserted into original network. + new_node (BaseNode): New node to be inserted into original network. New_node should contain its inputs and outputs. Returns: - Node has been inserted, return None if failed + BaseNode has been inserted, return None if failed """ return self._graph.insert_node(new_node) - def remove_node(self, node: Union[str, Node]) -> Optional[Node]: + def remove_node(self, node: Union[str, BaseNode]) -> Optional[BaseNode]: """ Args: - node (Node): node to be removed from original network. + node (BaseNode): node to be removed from original network. Returns: - Node has been removed, return None if failed + BaseNode has been removed, return None if failed """ if isinstance(node, str): @@ -123,17 +123,17 @@ class NetTransformer: return None return self._graph.remove_node(node) - def replace_node(self, target: Union[str, Node], value: Union[Cell, Node]) -> Optional[Node]: + def replace_node(self, target: Union[str, BaseNode], value: Union[Cell, BaseNode]) -> Optional[BaseNode]: """ Args: - target (Union[str, Node]): Name of node to be replaced. - value (Union[Cell, Node]): Node to be replaced into original network. 
+ target (Union[str, BaseNode]): Name of node to be replaced. + value (Union[Cell, BaseNode]): BaseNode to be replaced into original network. Note: new_node should has same inputs and outputs with old_node. Returns: - Node has been replaced, return None if failed + BaseNode has been replaced, return None if failed """ if isinstance(target, str): @@ -141,7 +141,7 @@ class NetTransformer: if target is None: return None if isinstance(value, Cell): - value = Node(value) + value = BaseNode(value) return self._graph.replace_node(target, value) # replace src_pattern with target_nodes. diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py index 727b3fa82ac..a57c407a7a5 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py @@ -23,18 +23,6 @@ import mindspore.context as context import numpy as np -class FixFakeQuantizer(FakeQuantizer): - ... - - -class AllValueFakeQuantizer(FakeQuantizer): - ... - - -class MovingAvgFakeQuantizer(FakeQuantizer): - ... - - def _calculate_quant_max(num_bits, neg_trunc=False): if neg_trunc: quant_max = (1 << num_bits) - 1 diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py index c32a6eaa62d..2cb90f65915 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py @@ -79,6 +79,7 @@ class DefaultLayerPolicy(LayerPolicy): def wrap_cell(self, handler: Cell) -> Cell: return QuantizeWrapperCell(handler, self) + class ActivationLayerPolicy(DefaultLayerPolicy): def __init__(self, insert_before_input=False, insert_after_output=True): super().__init__([], []) diff --git a/mindspore/python/mindspore/golden_stick/quantization/fake_quantizer.py b/mindspore/python/mindspore/golden_stick/quantization/fake_quantizer.py index e23154aac82..7dfdb87eb1d 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/fake_quantizer.py +++ b/mindspore/python/mindspore/golden_stick/quantization/fake_quantizer.py @@ -15,5 +15,8 @@ """FakeQuantizer.""" from mindspore.nn.cell import Cell +""" + FakeQuantizer should be a Cell for automatic-differentiation +""" FakeQuantizer = Cell diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize.py b/mindspore/python/mindspore/golden_stick/quantization/quantize.py index b8a401a1ba1..1937f267aa2 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/quantize.py +++ b/mindspore/python/mindspore/golden_stick/quantization/quantize.py @@ -27,7 +27,7 @@ from mindspore.nn import Cell class QuantAwareTraining(GoldenStick): """ - Derived class of GoldenStick. Default QAT-algorithm. + Derived class of GoldenStick. Base class of QAT-algorithm. """ def __init__(self, config: {}): @@ -37,6 +37,15 @@ class QuantAwareTraining(GoldenStick): self._custom_layer_policy_map = None def _propagate_layer_policy(self, nodes: [Node]): + """ + Set layer_policy for every layer according to custom_layer_policy_map, layer_policy_map and net_layer_policy in + QuantAwareTraining. 
custom_layer_policy_map takes the highest priority, layer_policy_map the second priority and
+        net_layer_policy the lowest priority.
+
+        Args:
+            nodes (List[Node]): nodes whose layer-policy is to be set.
+        """
+
         # step1 apply net layer-policy first
         net_layer_policy: Optional[LayerPolicy] = self._qat_policy.get_net_layer_policy()
         if net_layer_policy:
@@ -56,6 +65,14 @@ class QuantAwareTraining(GoldenStick):
 
     @staticmethod
     def _reduce_redundant_fake_quant(nodes: [Node]):
+        """
+        Reduce redundant fake-quantizer nodes between nodes. Redundancy usually occurs when the previous node
+        inserts an output fake-quantizer and the next node inserts an input fake-quantizer at the same time.
+
+        Args:
+            nodes (List[Node]): nodes to be checked, between which redundant fake-quantizers may be found.
+        """
+
         for node in nodes:
             cur_policy: LayerPolicy = NetTransformer.get_node_attr(node, layer_policy_key)
             # cur-node has no quant policy, so no fq will insert into its inputs
@@ -87,6 +104,15 @@ class QuantAwareTraining(GoldenStick):
                     cur_policy.set_input_not_insert_fq(i)
 
     def _apply_fuse_patterns(self, net_transformer: NetTransformer):
+        """
+        Apply fuse-pattern transformers to the corresponding layers of the graph.
+
+        Args:
+            net_transformer (NetTransformer): net_transformer is used to apply transformers to the graph.
+        # todo we should decouple graph from net_transformer
+        """
+
         transformers = self._qat_policy.get_transformers()
         if isinstance(self._custom_transforms, list):
             for transform in self._custom_transforms:
@@ -94,11 +120,19 @@ class QuantAwareTraining(GoldenStick):
                 transformers.append(transform)
         for transformer in transformers:
             # Transformer always return False
-            # todo test overlap between transformers
             net_transformer.pattern_transform(transformer)
 
     @staticmethod
     def _apply_layer_policy(nodes: [Node], net_transformer: NetTransformer):
+        """
+        Apply each layer-policy to its corresponding layer.
+        By default, the layer is replaced with the return value of the layer-policy's wrap_cell.
+
+        Args:
+            nodes (List[Node]): nodes to which layer-policies are applied.
+            net_transformer (NetTransformer): net_transformer is used to transform nodes according to their
+                layer-policy.
+        """
+
         for node in nodes:
             layer_policy = NetTransformer.get_node_attr(node, layer_policy_key)
             if isinstance(layer_policy, LayerPolicy):
diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_act.py b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_act.py
index 4e16cd353f8..754e92b7ff2 100644
--- a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_act.py
+++ b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_act.py
@@ -20,7 +20,7 @@ from .fake_quantizer import FakeQuantizer
 
 class QuantizeWrapperActivation(Cell):
     """
-    Derive from Cell for define how to construct a wrap quant-cell from a normal cell with fake-quant algorithm.
+    Decorator of an activation cell, wrapping it into a quant-cell with a fake-quant algorithm.
 
     Args:
         act (Cell): normal cell to be wrapped.
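
A minimal, self-contained sketch of the redundancy-reduction rule described above, using plain
dicts in place of the real Node/LayerPolicy objects; every name below is illustrative and not
part of the actual golden_stick API:

    def reduce_redundant_fq(nodes):
        """Drop an input fake-quantizer when the producer already fake-quants its output."""
        for node in nodes:
            for i, prev_node in enumerate(node["inputs"]):
                if prev_node["insert_output_fq"] and node["insert_input_fq"][i]:
                    # The edge prev_node -> node would be quantized twice; keep only
                    # the producer-side fake-quantizer.
                    node["insert_input_fq"][i] = False

    conv = {"inputs": [], "insert_output_fq": True, "insert_input_fq": []}
    relu = {"inputs": [conv], "insert_output_fq": False, "insert_input_fq": [True]}
    reduce_redundant_fq([conv, relu])
    assert relu["insert_input_fq"] == [False]
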
diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py
index fbebd4b08e9..3dfa03fbcb6 100644
--- a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py
+++ b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py
@@ -22,7 +22,7 @@ from .quantize_wrapper_act import QuantizeWrapperActivation
 
 class QuantizeWrapperCell(Cell):
     """
-    Derive from Cell for define how to construct a wrap quant-cell from a normal cell with fake-quant algorithm.
+    Decorator of a normal cell, wrapping it into a quant-cell with a fake-quant algorithm.
 
     Args:
         handler (Cell): normal cell to be wrapped.
--
Gitee


From 3c2f79e8d292fa02bd73277417253457c0f2dfb7 Mon Sep 17 00:00:00 2001
From: hangangqiang
Date: Wed, 5 Jan 2022 14:31:32 +0800
Subject: [PATCH 30/34] add IR manipulate interface in graph

---
 .../default_qat/default_quantize.py           |   4 +-
 .../mindspore/rewrite_experiment/graph.py     | 140 ++++---
 .../mindspore/rewrite_experiment/node.py      | 182 ++++-----
 .../rewrite_experiment/pattern_engine.py      | 350 ++++++++++++++++++
 4 files changed, 545 insertions(+), 131 deletions(-)
 create mode 100644 mindspore/python/mindspore/rewrite_experiment/pattern_engine.py

diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantize.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantize.py
index 9a6d645646b..8162bd3ad6e 100644
--- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantize.py
+++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_quantize.py
@@ -14,11 +14,11 @@
 # ============================================================================
 """DefaultQuantAwareTraining."""
 
-from ...golden_stick import GoldenStick
+from ..quantize import QuantAwareTraining
 from .default_net_policy import DefaultNetworkPolicy
 
 
-class DefaultQuantAwareTraining(GoldenStick):
+class DefaultQuantAwareTraining(QuantAwareTraining):
     """
     Derived class of GoldenStick. Default QAT-algorithm.
     """
diff --git a/mindspore/python/mindspore/rewrite_experiment/graph.py b/mindspore/python/mindspore/rewrite_experiment/graph.py
index a718f8abb85..ae28150b871 100644
--- a/mindspore/python/mindspore/rewrite_experiment/graph.py
+++ b/mindspore/python/mindspore/rewrite_experiment/graph.py
@@ -1,58 +1,108 @@
-import inspect
-from types import FunctionType
-from typing import Union
-import ast
-import astpretty
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ -import mindspore.nn as nn -from mindspore import log as logger -from mindspore.ops.primitive import Primitive +from typing import Tuple, Optional +from .node import Node, NodeType -from .node import Node, NodeType, PlaceholderNode -from .observer import Observer +class InsertPoint: + def __init__(self, node: Optional[Node]=None, before: bool=True, index: int=0): + self._node = node + self._before = before + self._index = index -class Graph(Observer): - def update(self): - pass + def update(self, node: Optional[Node]=None, before: bool=True, index: int=0): + self._node = node + self._before = before + self._index = index - def __init__(self): - # if not isinstance(network, nn.Cell): - # logger.error("Only support network with Cell type now") - # return - - # self._placeholders: List[Node] = [] - # self._contant_nodes: List[ConstantNode] = [] - # self._param_default_value: Dict = {} - # self._node_attributes: Dict = {} - # self._symbol_table: dict = {} - # self._net_cls = type(network) - # self._name = self._net_cls.__name__ - # self._base_scope = self._net_cls.__name__ - # self._network = network - # network_str = inspect.getsource(self._net_cls) - # self._ast_root: ast.AST = ast.parse(network_str) - # - # root_node = Node(self._ast_root, self._name, self._net_cls, NodeType.module) - self._nodes: [Node] = [] - # self._return = root_node + def get_node(self): + return self._node + + def get_is_before(self): + return self._before + + def get_index(self): + return self._index +class Graph: + def __init__(self): + self._root: Node = Node(NodeType.invalid, "", "") # todo + self._insert: callable = self._root.insert_before + self._node_size = 0 + self._nodes: {str: Node} = {} + @property - def nodes(self) -> list: - """ - 返回graph的节点,可以迭代访问,这些节点中应该还要包含init中的子图,在pattern匹配的时候会出现该问题 - """ - return self._nodes + def nodes(self) -> 'todo-iter': + pass + # + # def get_return(self): + # return self._return + # + # def set_return(self, node: Node): + # self._return = node - def get_return(self): - return self._return + def insert_before(self, node_or_name: Tuple[Node, str]): + if isinstance(node_or_name, Node): + node = node_or_name + elif isinstance(node_or_name, str): + node = self._nodes.get(node_or_name) + else: + raise RuntimeError("Unsupported node_or_name: ", node_or_name) + self._insert = node.insert_before - def set_return(self, node: Node): - self._return = node + def insert_after(self, node_or_name: Tuple[Node, str]): + if isinstance(node_or_name, Node): + node = node_or_name + elif isinstance(node_or_name, str): + node = self._nodes.get(node_or_name) + else: + raise RuntimeError("Unsupported node_or_name: ", node_or_name) + self._insert = node.insert_after - def print_ast(self): - astpretty.pprint(self._ast_root) + def create_node(self, node_type: int, cell_name: str, construct_args=None, construct_kwargs=None, + targets=None, field: Optional[str] = None) -> Node: + """ + add global_vars: + globals = {k: v} : store cell_name or object used in construct_args, construct_kwargs + add field to class: + self.var_name = cell_name(construct_args, construct_kwargs) + add call to construct method of class: + targets: target_type = self.var_name(self._normalized_args) + """ + if field is None: + field = "todo" + node = Node(node_type, cell_name, field, construct_args, construct_kwargs, targets) + # todo link node-list + self._node_size += 1 + self._nodes[node.get_field()] = node + return node + + def erase_node(self, 
node_or_name: Tuple[Node, str]) -> Node: + if isinstance(node_or_name, Node): + node = node_or_name + elif isinstance(node_or_name, str): + node = self._nodes.get(node_or_name) + else: + raise RuntimeError("Unsupported node_or_name: ", node_or_name) + # todo relink node-list + self._node_size -= 1 + self._nodes.pop(node.get_field()) + return node def add_placeholder(self, name, targets=None, ast_node=None, default_value=None): - self._nodes.append(PlaceholderNode(name, targets, ast_node, default_value)) + pass + diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index b2757eb559c..255f0bfc59b 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -12,22 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -import ast +from collections import OrderedDict from typing import Dict, List, Union, Optional -import mindspore.nn as nn -from mindspore.ops.primitive import Primitive -from .subject import Subject -from .observer import Observer - class NodeType: module = 1 - class_def = 2 - function_def = 3 + class_def = 2 # useless + function_def = 3 # useless placeholder = 4 # input - parameter = 5 # weight - constant = 6 + parameter = 5 # useless + constant = 6 # useless call_cell = 7 # call cell object call_method = 8 # method in cell call_function = 9 # subclass of primitive @@ -36,7 +31,7 @@ class NodeType: invalid = 100 -class Node(Subject): +class Node: """ Base class of node. @@ -47,42 +42,85 @@ class Node(Subject): inputs: the input nodes of this node. """ - def notify(self): - # self._observer.update() - pass - - def __init__(self, ast_node: ast.AST, name, cls: type, node_type=NodeType.invalid): - """ - 创建一个节点时对应的属性怎么传进来,cell应该不涉及,primitive会有这种情况 - """ - self._name: str = name - # self._attribute: AttributeNode = AttributeNode() - # if outputs is None: - # self._outputs: List[CellNode] = list() - # else: - # self._outputs = outputs - # if inputs is None: - # self._inputs: List[CellNode] = list() - # else: - # self._inputs = inputs - # self._targets: List[str] = targets # 用来保存算子输出结果的名称,用来匹配算子输入名称 - # self._args: List = args - self._type: int = node_type - self._ast_root: ast.AST = ast_node - self._ast_processing: ast.AST = ast_node - self._atomic: bool = False + def __init__(self, node_type: int, cell_name: str, field: str, construct_args=None, construct_kwargs=None, + targets=None): + if node_type != NodeType.call_cell: + raise RuntimeError("Only support call_cell now") + self._node_type: int = node_type + # name of subclass of cell + self._symbol_name: str = cell_name + # args of constructor of subclass os cell + if construct_args is None: + self._construct_args = [] + else: + self._construct_args = construct_args + # kwargs of constructor of subclass os cell + if construct_kwargs is None: + self._construct_kwargs = {} + else: + self._construct_kwargs = construct_kwargs + # targets of call expression in construct function + if targets is None: + self._targets: [str] = [] + # class field name + if field is None: + self._field: str = "todo" + else: + self._field: str = field + # edge of tensor self._args: [str] = [] self._kwargs: {} = {} - self._targets: [str] = [] + self._normalized_args: OrderedDict = OrderedDict() + # edge of node self._inputs: [Node] = [] self._outputs: [Node] = [] - self._cls: type = cls + 
# position in graph nodes list + # it will affect code-order of python code + self._prev: Node = self + self._next: Node = self + + def isolate(self): + origin_prev = self._prev + origin_next = self._next + origin_prev._next = origin_next + origin_next._prev = origin_prev + self._prev = self + self._next = self + + def insert_before(self, node: 'Node'): + node.isolate() + origin_prev = self._prev + origin_prev._next = node + node._prev = origin_prev + node._next = self + self._prev = node + + def insert_after(self, node: 'Node'): + node.isolate() + origin_next = self._next + self._next = node + node._prev = self + node._next = origin_next + origin_next._prev = node + + def update_arg(self, index:int, arg:str): + self._args[index] = arg + + def update_args(self, args:[]): + self._args = args + + def update_kwargs(self, key:str, value): + self._kwargs[key] = value + + def normalize_args(self): + # todo merge args kwargs default_args into normalize_args + pass - def get_name(self) -> str: - return self._name + def get_field(self) -> str: + return self._field - def set_name(self, name: str): - self._name = name + def set_field(self, field: str): + self._field = field def get_inputs(self) -> list: return self._inputs @@ -102,36 +140,11 @@ class Node(Subject): def get_targets(self): return self._targets - def class_type(self): - return self._cls - - def node_type(self) -> int: - return self._type - - def get_processing_ast(self): - return self._ast_processing - - def get_ast(self): - return self._ast_root - - def set_ast(self, ast_node: ast.AST): - self._ast_root = ast_node + def get_symbol_name(self): + return self._symbol_name - # def attribute(self) -> AttributeNode: - # return self._attribute - # - # def attribute(self, attribute: AttributeNode): - # self._attribute = attribute - - # def set_attributes(self, attribute: Dict): - # for key, value in attribute.items(): - # self._attribute.attribute[key] = value - # - # def set_attribute(self, key: str, value): - # self._attribute._attribute[key] = value - # - # def get_attribute(self, key: str): - # return self._attribute._attribute.get(key) + def get_node_type(self) -> int: + return self._node_type # class CellNode(Node): @@ -197,20 +210,21 @@ class Node(Subject): # return f"name: {self._name}; value: {self._value}; outputs: {len(self.outputs)}; output names: {output_names}" # # -class PlaceholderNode(Node): - """ - 'PlaceholderNode' is used to represent inputs of Cell, method or function. - """ - def __init__(self, name, targets=None, ast_node=None, default_value=None): - super().__init__(ast_node, name, Primitive, NodeType.placeholder) - self._ast_node = ast_node - self._default_value = default_value - - def __repr__(self) -> str: - output_names = "" - for n in self.outputs: - output_names += n.name + " " - return f"name: {self._name}; targets: {self._targets}, outputs: {len(self.outputs)}; output names: {output_names}; attribute: {self._attribute}" +# class PlaceholderNode(Node): +# """ +# 'PlaceholderNode' is used to represent inputs of Cell, method or function. 
+# """ +# +# def __init__(self, name, targets=None, ast_node=None, default_value=None): +# super().__init__(ast_node, name, Primitive, NodeType.placeholder) +# self._ast_node = ast_node +# self._default_value = default_value +# +# def __repr__(self) -> str: +# output_names = "" +# for n in self.outputs: +# output_names += n.name + " " +# return f"name: {self._name}; targets: {self._targets}, outputs: {len(self.outputs)}; output names: {output_names}; attribute: {self._attribute}" # # # class SubgraphNode(Node): diff --git a/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py b/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py new file mode 100644 index 00000000000..887761a91e5 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py @@ -0,0 +1,350 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""PatternEngine.""" + +from typing import Tuple, Union, List, Type +from collections import OrderedDict +from .graph import Graph +from .node import CellNode, NodeType +from mindspore.nn.cell import Cell +from mindspore import log as logger + + +class PatternNode: + """ + PatternNode is define as a node while defining pattern. + + Args: + node_name (str): Name of current node. + node_type (str): Cell type of current node. + inputs (List[PatternNode]): Input nodes of current node. + """ + + def __init__(self, node_name: str, node_type: Type = Type[None], inputs: ['PatternNode'] = None): + self._name = node_name + self._type = node_type + if inputs is None: + self._inputs = [] + else: + self._inputs = inputs + + @staticmethod + def from_node(node: CellNode) -> 'PatternNode': + """ + Create a PatternNode from a rewrite node. + + Args: + node (CellNode): input rewrite node. + + Returns: + PatternNode created from rewrite node. + """ + + pattern_node = PatternNode(node.name) + if node.node_type() is NodeType.call_cell: + pattern_node._type = node.type + return pattern_node + + @staticmethod + def create_pattern_from_node(node: CellNode) -> 'PatternNode': + """ + Create a PatternNode from a rewrite node with its inputs. + + Args: + node (CellNode): input rewrite node. + + Returns: + PatternNode created from rewrite node. + """ + + pattern_node = PatternNode.from_node(node) + inputs = [] + for node_input in node.inputs(): + inputs.append(PatternNode.create_pattern_from_node(node_input)) + pattern_node._inputs = inputs + return pattern_node + + @staticmethod + def create_pattern_from_list(type_list: []) -> 'PatternNode': + """ + Create a PatternNode from a cell type list. + + Args: + type_list (List): input cell type list. + + Returns: + PatternNode created from cell type list. 
+ """ + + last_node = None + for i in range(0, len(type_list)): + cell_type = type_list[i] + cur_node = PatternNode(str(i) + "-" + str(cell_type), cell_type, []) + if last_node is not None: + cur_node._inputs = [last_node] + else: + cur_node._inputs = [] + last_node = cur_node + return last_node + + def add_input(self, node_type): + """ + Add a input for current PatternNode. + + Args: + node_type : cell type as an input. + """ + + self._inputs.append(node_type) + + def set_inputs(self, inputs): + """ + Set inputs for current PatternNode. + + Args: + inputs (List) : inputs to be set as inputs of current PatternNode. + """ + + self._inputs = inputs + + def match(self, node: CellNode) -> bool: + """ + Check if current PatternNode can match with a rewrite node + + Args: + node (CellNode) : a rewrite node to be match. + """ + + return self._type == node.type + + def inputs(self): + """ + Getter of inputs. + """ + + return self._inputs + + def name(self) -> str: + """ + Getter of name. + """ + return self._name + + def type(self): + """ + Getter of type. + """ + return self._type + + +class VarNode(PatternNode): + """ + VarNode is a subclass of PatternNode whose match is always True. + """ + + def __init__(self): + super(VarNode, self).__init__("placehold", Cell, []) + + def match(self, node: CellNode) -> bool: + return node is not None + + +class PatternEngine: + """ + PatternEngine is define how to transform a graph by PattenNode. + + Args: + pattern (Union[PatternNode, List]): a instance of PatternNode or a cell-type-list to construct PatternNode. + replacement (callable): a callable define how to generate new_node. + """ + + def __init__(self, pattern: Union[PatternNode, List], replacement: callable = None): + if isinstance(pattern, PatternNode): + self._is_chain = False + self._replacement = replacement + self._pattern = pattern + elif isinstance(pattern, list): + self._is_chain = True + self._replacement = replacement + self._pattern = PatternNode.create_pattern_from_list(pattern) + else: + logger.debug("Unsupported pattern type: %s", type(pattern)) + self._is_chain = False + self._replacement = None + self._pattern = VarNode() + + def pattern(self) -> PatternNode: + """ + Getter of pattern. + """ + + return self._pattern + + def apply(self, graph: Graph) -> bool: + """ + Apply current pattern to a graph. + + Args: + graph (Graph): graph to be transformed. + + Returns: + If graph been changed. 
+ """ + + root: CellNode = graph.root() + changed = False + # IR match + queue: [CellNode] = [root] + while len(queue) > 0: + cur_node: CellNode = queue.pop(0) + node_inputs = cur_node.inputs + matched, matched_dict = self._match(self._pattern, cur_node) + if not matched or not PatternEngine._check_match(self._pattern, matched_dict): + for node_input in node_inputs: + queue.append(node_input) + continue + matched_list = list(matched_dict.values()) + if self._is_chain: + new_node = self._process_chain(matched_list) + else: + new_node = self._process_tree(matched_dict) + if new_node is None: # return None to remove + changed = True + for key in matched_dict: + graph.remove_node(matched_dict[key]) + elif new_node == cur_node: # return origin Node for do nothing + pass + else: # return Node to insert or replace (new Node no need to set inputs and outputs) + # todo if we need to support _process_chain or _process_tree return multi-node + changed = True + graph.replace_node(matched_list, new_node) + node_inputs = new_node.inputs + for node_input in node_inputs: + queue.append(node_input) + return changed + + def _process_chain(self, matched_nodes: [CellNode]) -> CellNode: + """ + Define how to generate a new_node with fuse_fn when pattern is a chain-pattern. + + Args: + matched_nodes ([Node]): a list of Node as matched result. + + Returns: + New node created from matched result. + """ + + if self._replacement is None: + return matched_nodes[len(matched_nodes) - 1] + replacement = self._replacement(*matched_nodes) + if replacement is None: + return None + if len(matched_nodes) == 0: + new_node = CellNode(instance=replacement) + else: + new_node = CellNode(instance=replacement, inputs=matched_nodes[0].inputs) + node_name = "" + for matched_node in matched_nodes: + node_name += matched_node.name + "_" + node_name += "fused" + new_node.name = node_name + return new_node + + # matched_cells: name_of_cell_in_pattern map to matched cell in network + def _process_tree(self, matched_nodes: OrderedDict) -> CellNode: + """ + Define how to generate a new_node when pattern is a tree-pattern. + This method must be overridden by all subclasses whose pattern is a tree-pattern. + + Args: + matched_nodes (OrderedDict): a OrderedDict of Node as matched result. + + Returns: + New node created from matched result. + """ + + pass + + @staticmethod + def _merge_ordered_dict(dict1: OrderedDict, dict2: OrderedDict) -> OrderedDict: + """ + A static util method to merge two OrderedDict + + Args: + dict1 (OrderedDict): first dict to be merge. + dict2 (OrderedDict): second dict to be merge. + + Returns: + Merged OrderedDict. + """ + + merged = dict1.copy() + merged.update(dict2) + return merged + + def _match(self, pattern: PatternNode, node: CellNode) -> Tuple[bool, OrderedDict]: + """ + Match `pattern` with a rewrite node with all inputs of the `pattern` + + Args: + pattern (PatternNode): pattern to be match. + node (CellNode): node to be match. + + Returns: + A bool value to indicate if matched. + A instance of OrderedDict as match result. + """ + + # todo: Recurse into subgraph node. 
Depend on subgraph node definition + if node.node_type() != NodeType.call_cell: + logger.debug("Pattern match failed: node(%s) is not a cell", node.name) + return False, OrderedDict() + if not pattern.match(node): + logger.debug("Pattern match failed: node(%s)'s type is %s while pattern type is %s", node.name, node.type, + pattern.type()) + return False, OrderedDict() + if isinstance(pattern, VarNode): + return True, OrderedDict() + pattern_inputs = pattern.inputs() + cur_inputs = node.inputs + input_num = len(pattern_inputs) + if input_num == 0: + return True, OrderedDict({pattern.name(): node}) + if input_num != len(cur_inputs): + logger.debug("Pattern match failed: node(%s)'s has %d inputs while pattern has %d inputs", node.name, + len(node.inputs), input_num) + return False, OrderedDict() + result = OrderedDict() + for i in range(0, input_num): + is_matched, tmp_result = self._match(pattern_inputs[i], cur_inputs[i]) + if not is_matched: + return False, OrderedDict() + else: + result = PatternEngine._merge_ordered_dict(result, tmp_result) + result[pattern.name()] = node + return True, result + + @staticmethod + def _check_match(pattern: PatternNode, match_dict: OrderedDict) -> bool: + matched_nodes = match_dict.values() + for key in match_dict: + if key == pattern.name(): + continue + node = match_dict[key] + for output in node.outputs: + if output not in matched_nodes: + logger.debug("Check match failed, pattern leaked") + return False + return True -- Gitee From 9240d149cefa1904256496f27cf680038fcc2db8 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 6 Jan 2022 11:14:44 +0800 Subject: [PATCH 31/34] replace symbol name and symbol ast replace symbol name and symbol ast --- .../compilers/arg_compiler.py | 6 +++--- .../compilers/arguments_compiler.py | 6 +++--- .../compilers/assign_compiler.py | 6 +++--- .../compilers/attribute_compiler.py | 6 +++--- .../compilers/binop_compiler.py | 6 +++--- .../compilers/call_compiler.py | 6 +++--- .../compilers/class_def_compiler.py | 4 ++-- .../compilers/constant_compiler.py | 6 +++--- .../compilers/function_def_compiler.py | 6 +++--- .../compilers/if_compiler.py | 6 +++--- .../compilers/module_compiler.py | 4 ++-- .../compilers/name_compiler.py | 6 +++--- .../compilers/return_compiler.py | 6 +++--- .../mindspore/rewrite_experiment/graph.py | 2 +- .../linkers/construct_function_def_linker.py | 2 +- .../linkers/init_function_def_linker.py | 4 ++-- .../mindspore/rewrite_experiment/rewrite.py | 6 +++--- .../mindspore/rewrite_experiment/symbol.py | 20 +++++++++++-------- 18 files changed, 56 insertions(+), 52 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py index 2f3db930214..cee322408a2 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/arg_compiler.py @@ -23,16 +23,16 @@ from mindspore import log as logger @CompilerRegister.reg_compiler class ArgCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.arg): + if not isinstance(symbol.symbol_ast, ast.arg): return [symbol] if isinstance(symbol, ArgSymbol): return [symbol] new_symbols = [] - ast_arg: ast.arg = symbol.get_ast() + ast_arg: ast.arg = symbol.symbol_ast _arg = ast_arg.arg arg_symbol = ArgSymbol( - ast_arg, symbol.get_scope(), symbol.get_symbol_name(), + ast_arg, symbol.get_scope(), symbol.symbol_name, arg_input=_arg, 
arg_type=type(_arg).__name__ ) new_symbols.append(arg_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py index 9dd9940ec51..326a249cdc3 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/arguments_compiler.py @@ -22,12 +22,12 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class ArgumentsCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.arguments): + if not isinstance(symbol.symbol_ast, ast.arguments): return [symbol] if isinstance(symbol, ArgumentsSymbol): return [symbol] new_symbols = [] - ast_arguments: ast.arguments = symbol.get_ast() + ast_arguments: ast.arguments = symbol.symbol_ast args_ = ast_arguments.args defaults_ = ast_arguments.defaults @@ -65,7 +65,7 @@ class ArgumentsCompiler(Compiler): new_symbols.append(default_symbol) arguments_symbol = ArgumentsSymbol( - ast_arguments, symbol.get_scope(), symbol.get_symbol_name(), + ast_arguments, symbol.get_scope(), symbol.symbol_name, arg_symbols, default_symbols ) new_symbols.append(arguments_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py index d2616c81e5c..04b3d573df9 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/assign_compiler.py @@ -22,12 +22,12 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class AssignCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.Assign): + if not isinstance(symbol.symbol_ast, ast.Assign): return [symbol] if isinstance(symbol, AssignSymbol): return [symbol] new_symbols = [] - ast_assign: ast.Assign = symbol.get_ast() + ast_assign: ast.Assign = symbol.symbol_ast targets = ast_assign.targets target_symbols = [] target_index = 0 @@ -52,7 +52,7 @@ class AssignCompiler(Compiler): ) new_symbols.append(value_symbol) assign_symbol = AssignSymbol( - ast_assign, symbol.get_scope(), symbol.get_symbol_name(), + ast_assign, symbol.get_scope(), symbol.symbol_name, target_symbols, value_symbol.symbol_name ) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py index c83322312a7..afb9b1c9261 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/attribute_compiler.py @@ -22,12 +22,12 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class AttributeCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.Attribute): + if not isinstance(symbol.symbol_ast, ast.Attribute): return [symbol] if isinstance(symbol, AttributeSymbol): return [symbol] new_symbols = [] - ast_attribute: ast.Attribute = symbol.get_ast() + ast_attribute: ast.Attribute = symbol.symbol_ast attr_ = ast_attribute.attr value_ = ast_attribute.value @@ -42,7 +42,7 @@ class AttributeCompiler(Compiler): new_symbols.append(value_symbol) attribute_symbol = AttributeSymbol( - ast_attribute, symbol.get_scope(), symbol.get_symbol_name(), + ast_attribute, 
symbol.get_scope(), symbol.symbol_name, attr=attr_, value=value_symbol.symbol_name ) new_symbols.append(attribute_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py index b441afe7ad2..3f08bd8ec70 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/binop_compiler.py @@ -23,12 +23,12 @@ from mindspore import log as logger @CompilerRegister.reg_compiler class BinopCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.BinOp): + if not isinstance(symbol.symbol_ast, ast.BinOp): return [symbol] if isinstance(symbol, BinopSymbol): return [symbol] new_symbols = [] - ast_binop: ast.BinOp = symbol.get_ast() + ast_binop: ast.BinOp = symbol.symbol_ast _left = ast_binop.left _op = ast_binop.op @@ -57,7 +57,7 @@ class BinopCompiler(Compiler): logger.warning("Ignoring right (%s) in binop compiler", type(_right).__name__) binop_symbol = BinopSymbol( - ast_binop, symbol.get_scope(), symbol.get_symbol_name(), + ast_binop, symbol.get_scope(), symbol.symbol_name, left_symbol.symbol_name, _op, right_symbol.symbol_name ) new_symbols.append(binop_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py index cada24b6169..2fabb7a0d61 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/call_compiler.py @@ -23,12 +23,12 @@ from mindspore import log as logger @CompilerRegister.reg_compiler class CallCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.Call): + if not isinstance(symbol.symbol_ast, ast.Call): return [symbol] if isinstance(symbol, CallSymbol): return [symbol] new_symbols = [] - ast_call: ast.Call = symbol.get_ast() + ast_call: ast.Call = symbol.symbol_ast func_ = ast_call.func args_ = ast_call.args @@ -75,7 +75,7 @@ class CallCompiler(Compiler): new_symbols.append(keyword_symbol) call_symbol = CallSymbol( - ast_call, symbol.get_scope(), symbol.get_symbol_name(), + ast_call, symbol.get_scope(), symbol.symbol_name, func_symbol.symbol_name, arg_symbols, keyword_symbols ) new_symbols.append(call_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py index 657722b1265..63201afeeb3 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/class_def_compiler.py @@ -22,9 +22,9 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class ClassDefCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.ClassDef): + if not isinstance(symbol.symbol_ast, ast.ClassDef): return [symbol] - class_def: ast.ClassDef = symbol.get_ast() + class_def: ast.ClassDef = symbol.symbol_ast bodies: list = class_def.body new_symbols: [Symbol] = [] for body in bodies: diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py index 1e21e66fb60..283032480be 100644 --- 
a/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/constant_compiler.py @@ -23,16 +23,16 @@ from mindspore import log as logger @CompilerRegister.reg_compiler class ConstantCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.Constant): + if not isinstance(symbol.symbol_ast, ast.Constant): return [symbol] if isinstance(symbol, ConstantSymbol): return [symbol] new_symbols = [] - ast_constant: ast.Constant = symbol.get_ast() + ast_constant: ast.Constant = symbol.symbol_ast _value = ast_constant.value constant_symbol = ConstantSymbol( - ast_node=ast_constant, scope=symbol.get_scope(), symbol_name=symbol.get_symbol_name(), + ast_node=ast_constant, scope=symbol.get_scope(), symbol_name=symbol.symbol_name, value=_value, const_type=type(_value).__name__ ) new_symbols.append(constant_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py index 5feff8425a1..e7e4a4fb36f 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/function_def_compiler.py @@ -24,11 +24,11 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class FunctionDefCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.FunctionDef): + if not isinstance(symbol.symbol_ast, ast.FunctionDef): return [symbol] if isinstance(symbol, FunctionSymbol): return [symbol] - function_def: ast.FunctionDef = symbol.get_ast() + function_def: ast.FunctionDef = symbol.symbol_ast new_symbols = [] # parse args @@ -78,7 +78,7 @@ class FunctionDefCompiler(Compiler): logger.warning("Ignoring symbol(%s) in FunctionDef", body) function_symbol = FunctionSymbol( - function_def, symbol.get_scope(), symbol.get_symbol_name(), + function_def, symbol.get_scope(), symbol.symbol_name, arguments_symbol.symbol_name, body_symbols ) new_symbols.append(function_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py index 27ab0b2023b..33bbd305c87 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/if_compiler.py @@ -23,12 +23,12 @@ from mindspore import log as logger @CompilerRegister.reg_compiler class IfCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.If): + if not isinstance(symbol.symbol_ast, ast.If): return [symbol] if isinstance(symbol, IfSymbol): return [symbol] new_symbols = [] - ast_if: ast.If = symbol.get_ast() + ast_if: ast.If = symbol.symbol_ast _test = ast_if.test _body = ast_if.body @@ -68,7 +68,7 @@ class IfCompiler(Compiler): logger.warning("Ignoring symbol(%s) in ClassDef", body) if_symbol = IfSymbol( - ast_if, symbol.get_scope(), symbol.get_symbol_name(), + ast_if, symbol.get_scope(), symbol.symbol_name, test_symbol.symbol_name, body_symbols ) new_symbols.append(if_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py index 207040b048f..e881ba39e19 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py 
+++ b/mindspore/python/mindspore/rewrite_experiment/compilers/module_compiler.py @@ -22,9 +22,9 @@ from ..registers import CompilerRegister @CompilerRegister.reg_compiler class ModuleCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.Module): + if not isinstance(symbol.symbol_ast, ast.Module): return [symbol] - module: ast.Module = symbol.get_ast() + module: ast.Module = symbol.symbol_ast bodies: list = module.body new_symbols: [Symbol] = [] for body in bodies: diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py index 668bd31e880..a48fc616bee 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/name_compiler.py @@ -23,16 +23,16 @@ from mindspore import log as logger @CompilerRegister.reg_compiler class NameCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.Name): + if not isinstance(symbol.symbol_ast, ast.Name): return [symbol] if isinstance(symbol, NameSymbol): return [symbol] new_symbols = [] - ast_name: ast.Name = symbol.get_ast() + ast_name: ast.Name = symbol.symbol_ast _id = ast_name.id name_symbol = NameSymbol( - ast_name, symbol.get_scope(), symbol.get_symbol_name(), + ast_name, symbol.get_scope(), symbol.symbol_name, id_input=_id, name_type=type(_id).__name__ ) new_symbols.append(name_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py b/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py index 2361410b691..66b1fea3e1d 100644 --- a/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py +++ b/mindspore/python/mindspore/rewrite_experiment/compilers/return_compiler.py @@ -23,12 +23,12 @@ from mindspore import log as logger @CompilerRegister.reg_compiler class ReturnCompiler(Compiler): def process(self, symbol: Symbol) -> [Symbol]: - if not isinstance(symbol.get_ast(), ast.Return): + if not isinstance(symbol.symbol_ast, ast.Return): return [symbol] if isinstance(symbol, ReturnSymbol): return [symbol] new_symbols = [] - ast_return: ast.Return = symbol.get_ast() + ast_return: ast.Return = symbol.symbol_ast _value = ast_return.value value_symbol = None @@ -42,7 +42,7 @@ class ReturnCompiler(Compiler): logger.warning("Ignoring value (%s) in return compiler", type(_value).__name__) return_symbol = ReturnSymbol( - ast_return, symbol.get_scope(), symbol.get_symbol_name(), + ast_return, symbol.get_scope(), symbol.symbol_name, value_symbol.symbol_name ) new_symbols.append(return_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/graph.py b/mindspore/python/mindspore/rewrite_experiment/graph.py index ae28150b871..656b62016ec 100644 --- a/mindspore/python/mindspore/rewrite_experiment/graph.py +++ b/mindspore/python/mindspore/rewrite_experiment/graph.py @@ -40,7 +40,7 @@ class InsertPoint: class Graph: def __init__(self): - self._root: Node = Node(NodeType.invalid, "", "") # todo + self._root: Node = Node(NodeType.call_cell, "", "") # todo self._insert: callable = self._root.insert_before self._node_size = 0 self._nodes: {str: Node} = {} diff --git a/mindspore/python/mindspore/rewrite_experiment/linkers/construct_function_def_linker.py b/mindspore/python/mindspore/rewrite_experiment/linkers/construct_function_def_linker.py index 38d76474f4f..a4d638ab1de 100644 --- 
a/mindspore/python/mindspore/rewrite_experiment/linkers/construct_function_def_linker.py
+++ b/mindspore/python/mindspore/rewrite_experiment/linkers/construct_function_def_linker.py
@@ -28,7 +28,7 @@ class ConstructFunctionDefLinker(Linker):
     def process(self, symbol: Symbol) -> [Symbol]:
         if not isinstance(symbol, FunctionSymbol):
             return [symbol]
-        if symbol.get_ast().name is not "construct":
+        if symbol.symbol_ast.name != "construct":
             return [symbol]
         if symbol.get_attr(self._function_def_construct_visited_attr_key):
             return [symbol]
diff --git a/mindspore/python/mindspore/rewrite_experiment/linkers/init_function_def_linker.py b/mindspore/python/mindspore/rewrite_experiment/linkers/init_function_def_linker.py
index 2cfbb810d29..81d1dba8a6a 100644
--- a/mindspore/python/mindspore/rewrite_experiment/linkers/init_function_def_linker.py
+++ b/mindspore/python/mindspore/rewrite_experiment/linkers/init_function_def_linker.py
@@ -28,11 +28,11 @@ class InitFunctionDefLinker(Linker):
     def process(self, symbol: Symbol) -> [Symbol]:
         if not isinstance(symbol, FunctionSymbol):
             return [symbol]
-        if symbol.get_ast().name is not "__init__":
+        if symbol.symbol_ast.name != "__init__":
             return [symbol]
         if symbol.get_attr(self._function_def_init_visited_attr_key):
             return [symbol]
-        args = symbol.get_ast().args
+        args = symbol.symbol_ast.args
         # todo
         # args_with_value = symbol.get_args()
         # for name, value in args_with_value.items():
diff --git a/mindspore/python/mindspore/rewrite_experiment/rewrite.py b/mindspore/python/mindspore/rewrite_experiment/rewrite.py
index d5e5ca3bc2d..0f6bc65be0b 100644
--- a/mindspore/python/mindspore/rewrite_experiment/rewrite.py
+++ b/mindspore/python/mindspore/rewrite_experiment/rewrite.py
@@ -50,16 +50,16 @@ class Rewrite:
     def compile_symbol_by_compiler(iterator: MutableDictIterator, compiler: Compiler) -> bool:
         symbol: Symbol = iterator.value()
         if not symbol.is_compilable():  # not compilable, skip
-            logger.warning("Processing symbol(%s) by compiler(%s), leaf symbol", symbol.get_symbol_name(),
+            logger.warning("Processing symbol(%s) by compiler(%s), leaf symbol", symbol.symbol_name,
                            compiler.name())
             return False
         results: [Symbol] = compiler.process(symbol)
         if len(results) == 1 and results[0] is symbol:  # no change in process
-            logger.warning("Processing symbol(%s) by compiler(%s), not changed", symbol.get_symbol_name(),
+            logger.warning("Processing symbol(%s) by compiler(%s), not changed", symbol.symbol_name,
                            compiler.name())
             return False
         iterator = iterator.erase()
-        logger.warning("Processing symbol(%s) by compiler(%s), replaced by %d new symbols", symbol.get_symbol_name(),
+        logger.warning("Processing symbol(%s) by compiler(%s), replaced by %d new symbols", symbol.symbol_name,
                        compiler.name(), len(results))
         for result in results:
             iterator = iterator.insert(result.get_full_name_with_scope(), result)
diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py
index d324aa238f3..ad4e4656004 100644
--- a/mindspore/python/mindspore/rewrite_experiment/symbol.py
+++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py
@@ -87,25 +87,29 @@ class Symbol:
     def symbol_name(self):
         return self._symbol_name
 
+    @symbol_name.setter
+    def symbol_name(self, name: str):
+        if not isinstance(name, str):
+            raise RuntimeError("input symbol name is not str")
+        self._symbol_name = name
+
     def get_scope(self) -> str:
         return self._scope
 
-    def get_symbol_name(self) -> str:
-        return self._symbol_name
-
     def
get_full_name_with_scope(self) -> str: return self._scope + "." + self._symbol_name - def set_symbol_name(self, symbol_name: str): - self._symbol_name = symbol_name - def symbol_type(self) -> int: return self._symbol_type - def get_ast(self): + @property + def symbol_ast(self): return self._ast_root - def set_ast(self, ast_node: ast.AST): + @symbol_ast.setter + def symbol_ast(self, ast_node: ast.AST): + if not isinstance(ast_node, ast.AST): + raise RuntimeError("input ast_node is not ast.AST") self._ast_root = ast_node def set_attr(self, key: str, value): -- Gitee From b430b727c2a9e06ccd1b2f84c7c55fe7e239be4e Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Thu, 6 Jan 2022 14:42:24 +0800 Subject: [PATCH 32/34] fix build bug --- .../python/mindspore/golden_stick/__init__.py | 5 +- .../golden_stick/quantization/__init__.py | 7 +- .../quantization/default_qat/__init__.py | 5 +- .../python/mindspore/rewrite/__init__.py | 5 +- .../mindspore/rewrite_experiment/graph.py | 6 +- .../mindspore/rewrite_experiment/node.py | 17 +++-- .../rewrite_experiment/pattern_engine.py | 69 +++++++++---------- 7 files changed, 60 insertions(+), 54 deletions(-) diff --git a/mindspore/python/mindspore/golden_stick/__init__.py b/mindspore/python/mindspore/golden_stick/__init__.py index 0b388adbb9a..00da5f97d0c 100644 --- a/mindspore/python/mindspore/golden_stick/__init__.py +++ b/mindspore/python/mindspore/golden_stick/__init__.py @@ -19,8 +19,7 @@ MindSpore golden stick module. from .golden_stick import GoldenStick from .net_transform import NetTransformer from .quantization import LayerPolicy, NetPolicy, QuantAwareTraining, FakeQuantizer, \ - Transformer, AllValueFakeQuantizer, DefaultLayerPolicy, DefaultNetworkPolicy, DefaultQuantAwareTraining + Transformer, DefaultLayerPolicy, DefaultNetworkPolicy, DefaultQuantAwareTraining __all__ = ["GoldenStick", "NetTransformer", "LayerPolicy", "NetPolicy", "QuantAwareTraining", "FakeQuantizer", - "Transformer", "AllValueFakeQuantizer", "DefaultLayerPolicy", "DefaultNetworkPolicy", - "DefaultQuantAwareTraining"] + "Transformer", "DefaultLayerPolicy", "DefaultNetworkPolicy", "DefaultQuantAwareTraining"] diff --git a/mindspore/python/mindspore/golden_stick/quantization/__init__.py b/mindspore/python/mindspore/golden_stick/quantization/__init__.py index 1b2c8d47563..8b1ca4bb9a8 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/__init__.py +++ b/mindspore/python/mindspore/golden_stick/quantization/__init__.py @@ -21,8 +21,7 @@ from .net_policy import NetPolicy from .quantize import QuantAwareTraining from .fake_quantizer import FakeQuantizer from .transformer import Transformer -from .default_qat import AllValueFakeQuantizer, DefaultLayerPolicy, DefaultNetworkPolicy, \ - DefaultQuantAwareTraining +from .default_qat import DefaultLayerPolicy, DefaultNetworkPolicy, DefaultQuantAwareTraining -__all__ = ["LayerPolicy", "NetPolicy", "QuantAwareTraining", "FakeQuantizer", "Transformer", "AllValueFakeQuantizer", - "DefaultLayerPolicy", "DefaultNetworkPolicy", "DefaultQuantAwareTraining"] +__all__ = ["LayerPolicy", "NetPolicy", "QuantAwareTraining", "FakeQuantizer", "Transformer", "DefaultLayerPolicy", + "DefaultNetworkPolicy", "DefaultQuantAwareTraining"] diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/__init__.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/__init__.py index c5914b6d3a9..7bb49e41635 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/__init__.py +++ 
b/mindspore/python/mindspore/golden_stick/quantization/default_qat/__init__.py @@ -16,10 +16,9 @@ MindSpore golden stick default-qat-quantization. """ -from .default_fake_quantizer import AllValueFakeQuantizer, LearnedFakeQuantizerPerLayer +from .default_fake_quantizer import LearnedFakeQuantizerPerLayer from .default_layer_policy import DefaultLayerPolicy from .default_net_policy import DefaultNetworkPolicy from .default_quantize import DefaultQuantAwareTraining -__all__ = ["AllValueFakeQuantizer", "LearnedFakeQuantizerPerLayer", "DefaultLayerPolicy", - "DefaultNetworkPolicy", "DefaultQuantAwareTraining"] +__all__ = ["LearnedFakeQuantizerPerLayer", "DefaultLayerPolicy", "DefaultNetworkPolicy", "DefaultQuantAwareTraining"] diff --git a/mindspore/python/mindspore/rewrite/__init__.py b/mindspore/python/mindspore/rewrite/__init__.py index 1ace088ec52..73a640b7d11 100644 --- a/mindspore/python/mindspore/rewrite/__init__.py +++ b/mindspore/python/mindspore/rewrite/__init__.py @@ -1,5 +1,6 @@ from .graph import Graph -from .node import Node, NodeType, PlaceholderNode +from .node import BaseNode, Node, NodeType, PlaceholderNode from .pattern_engine import PatternEngine, PatternNode, PlaceHolderNode -__all__ = ["Graph", "Node", "NodeType", "PatternEngine", "PatternNode", "PlaceHolderNode", "PlaceholderNode"] +__all__ = ["Graph", "Node", "BaseNode", "NodeType", "PatternEngine", "PatternNode", "PlaceHolderNode", + "PlaceholderNode"] diff --git a/mindspore/python/mindspore/rewrite_experiment/graph.py b/mindspore/python/mindspore/rewrite_experiment/graph.py index 656b62016ec..3a00ab497a2 100644 --- a/mindspore/python/mindspore/rewrite_experiment/graph.py +++ b/mindspore/python/mindspore/rewrite_experiment/graph.py @@ -15,6 +15,7 @@ from typing import Tuple, Optional from .node import Node, NodeType +from mindspore.nn import Cell class InsertPoint: @@ -40,7 +41,7 @@ class InsertPoint: class Graph: def __init__(self): - self._root: Node = Node(NodeType.call_cell, "", "") # todo + self._root: Node = Node(NodeType.output, Cell, "") # todo self._insert: callable = self._root.insert_before self._node_size = 0 self._nodes: {str: Node} = {} @@ -55,6 +56,9 @@ class Graph: # def set_return(self, node: Node): # self._return = node + def get_root(self): + return self._root + def insert_before(self, node_or_name: Tuple[Node, str]): if isinstance(node_or_name, Node): node = node_or_name diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index 255f0bfc59b..4743ffb87e8 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -14,6 +14,7 @@ # ============================================================================ from collections import OrderedDict from typing import Dict, List, Union, Optional +from mindspore.nn import Cell class NodeType: @@ -42,13 +43,16 @@ class Node: inputs: the input nodes of this node. 
""" - def __init__(self, node_type: int, cell_name: str, field: str, construct_args=None, construct_kwargs=None, + def __init__(self, node_type: int, cell_cls: type, field: str, construct_args=None, construct_kwargs=None, targets=None): - if node_type != NodeType.call_cell: - raise RuntimeError("Only support call_cell now") + if node_type != NodeType.call_cell and node_type != NodeType.output: + raise RuntimeError("Only support call_cell and output now") self._node_type: int = node_type # name of subclass of cell - self._symbol_name: str = cell_name + if not issubclass(cell_cls, Cell): + raise RuntimeError("Cell_cls should be a subclass of Cell") + self._cell_cls: type = cell_cls + self._cell_name: str = cell_cls.__name__ # args of constructor of subclass os cell if construct_args is None: self._construct_args = [] @@ -141,11 +145,14 @@ class Node: return self._targets def get_symbol_name(self): - return self._symbol_name + return self._cell_name def get_node_type(self) -> int: return self._node_type + def get_cell_type(self): + return self._cell_cls + # class CellNode(Node): # """ diff --git a/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py b/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py index 887761a91e5..62323a430b2 100644 --- a/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py +++ b/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py @@ -14,10 +14,10 @@ # ============================================================================ """PatternEngine.""" -from typing import Tuple, Union, List, Type +from typing import Tuple, Union, List, Type, Optional from collections import OrderedDict from .graph import Graph -from .node import CellNode, NodeType +from .node import Node, NodeType from mindspore.nn.cell import Cell from mindspore import log as logger @@ -41,7 +41,7 @@ class PatternNode: self._inputs = inputs @staticmethod - def from_node(node: CellNode) -> 'PatternNode': + def from_node(node: Node) -> 'PatternNode': """ Create a PatternNode from a rewrite node. @@ -52,13 +52,13 @@ class PatternNode: PatternNode created from rewrite node. """ - pattern_node = PatternNode(node.name) - if node.node_type() is NodeType.call_cell: - pattern_node._type = node.type + pattern_node = PatternNode(node.get_targets()[0]) + if node.get_node_type() is NodeType.call_cell: + pattern_node._type = node.get_cell_type() return pattern_node @staticmethod - def create_pattern_from_node(node: CellNode) -> 'PatternNode': + def create_pattern_from_node(node: Node) -> 'PatternNode': """ Create a PatternNode from a rewrite node with its inputs. @@ -71,7 +71,7 @@ class PatternNode: pattern_node = PatternNode.from_node(node) inputs = [] - for node_input in node.inputs(): + for node_input in node.get_inputs(): inputs.append(PatternNode.create_pattern_from_node(node_input)) pattern_node._inputs = inputs return pattern_node @@ -119,15 +119,15 @@ class PatternNode: self._inputs = inputs - def match(self, node: CellNode) -> bool: + def match(self, node: Node) -> bool: """ Check if current PatternNode can match with a rewrite node Args: - node (CellNode) : a rewrite node to be match. + node (Node) : a rewrite node to be match. 
""" - return self._type == node.type + return self._type == node.get_cell_type() def inputs(self): """ @@ -155,9 +155,9 @@ class VarNode(PatternNode): """ def __init__(self): - super(VarNode, self).__init__("placehold", Cell, []) + super(VarNode, self).__init__("placeholder", Cell, []) - def match(self, node: CellNode) -> bool: + def match(self, node: Node) -> bool: return node is not None @@ -180,10 +180,7 @@ class PatternEngine: self._replacement = replacement self._pattern = PatternNode.create_pattern_from_list(pattern) else: - logger.debug("Unsupported pattern type: %s", type(pattern)) - self._is_chain = False - self._replacement = None - self._pattern = VarNode() + raise RuntimeError("Unsupported pattern define") def pattern(self) -> PatternNode: """ @@ -203,13 +200,13 @@ class PatternEngine: If graph been changed. """ - root: CellNode = graph.root() + root: Node = graph.get_root() changed = False # IR match - queue: [CellNode] = [root] + queue: [Node] = [root] while len(queue) > 0: - cur_node: CellNode = queue.pop(0) - node_inputs = cur_node.inputs + cur_node: Node = queue.pop(0) + node_inputs = cur_node.get_inputs() matched, matched_dict = self._match(self._pattern, cur_node) if not matched or not PatternEngine._check_match(self._pattern, matched_dict): for node_input in node_inputs: @@ -223,19 +220,19 @@ class PatternEngine: if new_node is None: # return None to remove changed = True for key in matched_dict: - graph.remove_node(matched_dict[key]) + graph.erase_node(matched_dict[key]) elif new_node == cur_node: # return origin Node for do nothing pass else: # return Node to insert or replace (new Node no need to set inputs and outputs) # todo if we need to support _process_chain or _process_tree return multi-node changed = True graph.replace_node(matched_list, new_node) - node_inputs = new_node.inputs + node_inputs = new_node.get_inputs() for node_input in node_inputs: queue.append(node_input) return changed - def _process_chain(self, matched_nodes: [CellNode]) -> CellNode: + def _process_chain(self, matched_nodes: [Node]) -> Optional[Node]: """ Define how to generate a new_node with fuse_fn when pattern is a chain-pattern. @@ -252,9 +249,9 @@ class PatternEngine: if replacement is None: return None if len(matched_nodes) == 0: - new_node = CellNode(instance=replacement) + new_node = Node(instance=replacement) else: - new_node = CellNode(instance=replacement, inputs=matched_nodes[0].inputs) + new_node = Node(instance=replacement, inputs=matched_nodes[0].inputs) node_name = "" for matched_node in matched_nodes: node_name += matched_node.name + "_" @@ -263,7 +260,7 @@ class PatternEngine: return new_node # matched_cells: name_of_cell_in_pattern map to matched cell in network - def _process_tree(self, matched_nodes: OrderedDict) -> CellNode: + def _process_tree(self, matched_nodes: OrderedDict) -> Node: """ Define how to generate a new_node when pattern is a tree-pattern. This method must be overridden by all subclasses whose pattern is a tree-pattern. @@ -294,13 +291,13 @@ class PatternEngine: merged.update(dict2) return merged - def _match(self, pattern: PatternNode, node: CellNode) -> Tuple[bool, OrderedDict]: + def _match(self, pattern: PatternNode, node: Node) -> Tuple[bool, OrderedDict]: """ Match `pattern` with a rewrite node with all inputs of the `pattern` Args: pattern (PatternNode): pattern to be match. - node (CellNode): node to be match. + node (Node): node to be match. Returns: A bool value to indicate if matched. 
@@ -308,23 +305,23 @@ class PatternEngine: """ # todo: Recurse into subgraph node. Depend on subgraph node definition - if node.node_type() != NodeType.call_cell: - logger.debug("Pattern match failed: node(%s) is not a cell", node.name) + if node.get_node_type() != NodeType.call_cell: + logger.debug("Pattern match failed: node(%s) is not a cell", node.get_field()) return False, OrderedDict() if not pattern.match(node): - logger.debug("Pattern match failed: node(%s)'s type is %s while pattern type is %s", node.name, node.type, - pattern.type()) + logger.debug("Pattern match failed: node(%s)'s type is %s while pattern type is %s", node.get_field(), + node.get_cell_type(), pattern.type()) return False, OrderedDict() if isinstance(pattern, VarNode): return True, OrderedDict() pattern_inputs = pattern.inputs() - cur_inputs = node.inputs + cur_inputs = node.get_inputs() input_num = len(pattern_inputs) if input_num == 0: return True, OrderedDict({pattern.name(): node}) if input_num != len(cur_inputs): - logger.debug("Pattern match failed: node(%s)'s has %d inputs while pattern has %d inputs", node.name, - len(node.inputs), input_num) + logger.debug("Pattern match failed: node(%s) has %d inputs while pattern has %d inputs", node.get_field(), + len(node.get_inputs()), input_num) return False, OrderedDict() result = OrderedDict() for i in range(0, input_num): -- Gitee From de396d1b0751abd503cb642b86b7066088bd946e Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Wed, 29 Dec 2021 16:20:53 +0800 Subject: [PATCH 33/34] support origin qat features --- .../mindspore/golden_stick/net_transform.py | 3 +- .../golden_stick/quantization/__init__.py | 5 +- .../default_qat/default_fake_quantizer.py | 27 +- .../default_qat/default_layer_policy.py | 2 - .../default_qat/default_net_policy.py | 31 +- .../golden_stick/quantization/quant_utils.py | 439 ++++++++++++++++++ .../{quantize.py => quante_aware_training.py} | 0 .../quantization/quantize_wrapper_cell.py | 29 +- .../python/mindspore/nn/layer/combined.py | 1 - mindspore/python/mindspore/rewrite/graph.py | 77 +-- .../mindspore/rewrite/pattern_engine.py | 45 +- .../golden_stick/test_quant_aware_training.py | 8 + 12 files changed, 577 insertions(+), 90 deletions(-) create mode 100644 mindspore/python/mindspore/golden_stick/quantization/quant_utils.py rename mindspore/python/mindspore/golden_stick/quantization/{quantize.py => quante_aware_training.py} (100%) create mode 100644 tests/ut/python/golden_stick/test_quant_aware_training.py diff --git a/mindspore/python/mindspore/golden_stick/net_transform.py b/mindspore/python/mindspore/golden_stick/net_transform.py index 8c337cf5221..897ebe14b2a 100644 --- a/mindspore/python/mindspore/golden_stick/net_transform.py +++ b/mindspore/python/mindspore/golden_stick/net_transform.py @@ -16,7 +16,8 @@ from typing import Union, Optional from mindspore.nn.cell import Cell -from mindspore.rewrite import Graph, BaseNode, PatternEngine +from mindspore.rewrite import Graph, PatternEngine +from mindspore.rewrite.node import BaseNode class NetTransformer: diff --git a/mindspore/python/mindspore/golden_stick/quantization/__init__.py b/mindspore/python/mindspore/golden_stick/quantization/__init__.py index 8b1ca4bb9a8..36caa99561d 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/__init__.py +++ b/mindspore/python/mindspore/golden_stick/quantization/__init__.py @@ -18,10 +18,11 @@ MindSpore golden stick module. 
from .layer_policy import LayerPolicy from .net_policy import NetPolicy -from .quantize import QuantAwareTraining +from .quante_aware_training import QuantAwareTraining from .fake_quantizer import FakeQuantizer from .transformer import Transformer from .default_qat import DefaultLayerPolicy, DefaultNetworkPolicy, DefaultQuantAwareTraining -__all__ = ["LayerPolicy", "NetPolicy", "QuantAwareTraining", "FakeQuantizer", "Transformer", "DefaultLayerPolicy", +__all__ = ["LayerPolicy", "NetPolicy", "QuantAwareTraining", "FakeQuantizer", "Transformer", + "DefaultLayerPolicy", "DefaultNetworkPolicy", "DefaultQuantAwareTraining"] diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py index a57c407a7a5..0e744786b3c 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py @@ -16,6 +16,7 @@ from functools import partial from ..fake_quantizer import FakeQuantizer +from ..quant_utils import compute_kl_threshold from mindspore.ops.operations import _quant_ops as Q from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor @@ -83,7 +84,8 @@ class DefaultFakeQuantizerPerChannel(DefaultFakeQuantizerPerLayer): """ Derived from DefaultFakeQuantizerPerLayer, perchannel version of default fake quantizer """ - def __init__(self, num_channels=1, channel_axis=1,ema=False, ema_decay=0.999, symmetric=False, narrow_range=False): + + def __init__(self, num_channels=1, channel_axis=1, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False): super(DefaultFakeQuantizerPerChannel, self).__init__(ema=ema, ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) self._float_min = Parameter(Tensor([float("-inf")] * num_channels), name="float_min") @@ -99,8 +101,9 @@ class LearnedFakeQuantizerPerLayer(FakeQuantizer): def __init__(self, num_bits=8, quant_delay=0, min_init=-6, max_init=6, neg_trunc=False): super(LearnedFakeQuantizerPerLayer, self).__init__() + self._num_bits = num_bits self.neg_trunc = neg_trunc - self._quant_max = _calculate_quant_max(num_bits, self.neg_trunc) + self._quant_max = _calculate_quant_max(self._num_bits, self.neg_trunc) self.quant_max = Parameter(Tensor(np.array([self._quant_max]).astype(np.float32))) quant_func = partial(Q.FakeLearnedScaleQuantPerLayer, quant_delay=quant_delay, neg_trunc=self.neg_trunc) self.fake_quant_train = quant_func(training=True) @@ -108,9 +111,11 @@ class LearnedFakeQuantizerPerLayer(FakeQuantizer): self._float_min = Parameter(Tensor(min_init), name="float_min") self._float_max = Parameter(Tensor(max_init), name="float_max") - def update_min_max(self, new_float_min, new_float_max): - self._float_min.set_data(Tensor(new_float_min)) - self._float_max.set_data(Tensor(new_float_max)) + def compute_quant_param(self, weight_param): + max_init = [compute_kl_threshold(weight_param, self._num_bits)] + min_init = [-x for x in max_init] + self._float_min.set_data(Tensor(self._get_init_array(min_init))) + self._float_max.set_data(Tensor(self._get_init_array(max_init))) def construct(self, x): if self.training: @@ -128,7 +133,8 @@ class LearnedFakeQuantizePerChannel(FakeQuantizer): def __init__(self, num_bits=8, num_channels=1, channel_axis=1, quant_delay=0, float_min=-6, float_max=6, neg_trunc=False): super(LearnedFakeQuantizePerChannel, 
self).__init__() - self._quant_max = _calculate_quant_max(num_bits, neg_trunc) + self._num_bits = num_bits + self._quant_max = _calculate_quant_max(self._num_bits, neg_trunc) self.quant_max = Parameter(Tensor(np.array([self._quant_max]).astype(np.float32))) quant_func = partial(Q.FakeLearnedScaleQuantPerChannel, quant_delay=quant_delay, neg_trunc=neg_trunc, channel_axis=channel_axis) @@ -138,9 +144,12 @@ class LearnedFakeQuantizePerChannel(FakeQuantizer): self._float_min = Parameter(Tensor(self._get_init_array(float_min)), name="float_min") self._float_max = Parameter(Tensor(self._get_init_array(float_max)), name="float_max") - def update_min_max(self, new_float_min, new_float_max): - self._float_min.set_data(Tensor(self._get_init_array(new_float_min))) - self._float_max.set_data(Tensor(self._get_init_array(new_float_max))) + def compute_quant_param(self, weight_param): + max_init = [compute_kl_threshold(weight_para_each, self._num_bits) + for weight_para_each in weight_param] + min_init = [-x for x in max_init] + self._float_min.set_data(Tensor(self._get_init_array(min_init))) + self._float_max.set_data(Tensor(self._get_init_array(max_init))) def _get_init_array(self, init_data): """ diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py index 2cb90f65915..cfb6cb2e03d 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py @@ -87,5 +87,3 @@ class ActivationLayerPolicy(DefaultLayerPolicy): self._output_quantizer: Optional[FakeQuantizer] = LearnedFakeQuantizerPerLayer() self._insert_before_input = insert_before_input self._insert_after_output = insert_after_output - - diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py index a8a99ad91aa..032a766ad03 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py @@ -19,12 +19,29 @@ from ..net_policy import NetPolicy from ..layer_policy import LayerPolicy from .default_layer_policy import DefaultLayerPolicy from ..transformer import Transformer -from mindspore.nn.layer import Conv2d, Dense, MatMul, BatchNorm2d, ReLU +from mindspore.nn.layer import Conv2d, Dense, MatMul, BatchNorm2d, ReLU, Conv2dBnAct from mindspore.nn.layer.quant import Conv2dBnFoldQuantOneConv from mindspore.rewrite.pattern_engine import PatternEngine from mindspore.nn.layer.quant import QuantConfig +def _fetch_quant_config(layer_policy: LayerPolicy): + weight_fake_quantizer = None if len(layer_policy.get_weight_name_and_quantizers()) == 0 \ + else layer_policy.get_weight_name_and_quantizers()[0][1] + act_fake_quantizer = None if len(layer_policy.get_act_name_and_quantizers()) == 0 \ + else layer_policy.get_act_name_and_quantizers()[0][1] + return QuantConfig(weight_fake_quantizer, act_fake_quantizer) + + +def _split_conv2d_bn_act(conv2d_bn_act: Conv2dBnAct): + result = [conv2d_bn_act.conv] + if conv2d_bn_act.has_bn: + result.insert(0, conv2d_bn_act.batchnorm) + if conv2d_bn_act.has_act: + result.insert(0, conv2d_bn_act.activation) + return result + + class DefaultNetworkPolicy(NetPolicy): """ Derived class of NetworkQConfig. 
Default network-quant-config. @@ -37,20 +54,14 @@ class DefaultNetworkPolicy(NetPolicy): super().__init__(config) if config is None: config = {} - - def fetch_quant_config(layer_policy: LayerPolicy): - weight_fake_quantizer = None if len(layer_policy.get_weight_name_and_quantizers()) == 0 \ - else layer_policy.get_weight_name_and_quantizers()[0][1] - act_fake_quantizer = None if len(layer_policy.get_act_name_and_quantizers()) == 0 \ - else layer_policy.get_act_name_and_quantizers()[0][1] - return QuantConfig(weight_fake_quantizer, act_fake_quantizer) - + self._net_layer_policy = DefaultLayerPolicy(["weight"], [], config) self._pattern_engines: [PatternEngine] = [ Transformer([Conv2d, BatchNorm2d]), Transformer([Conv2d, ReLU]), + PatternEngine([Conv2dBnAct], _split_conv2d_bn_act), PatternEngine([Conv2d, BatchNorm2d], partial(Conv2dBnFoldQuantOneConv.from_float, - fetch_quant_config(self.get_net_layer_policy()))) + _fetch_quant_config(self.get_net_layer_policy()))) ] self._support_layer_map: dict = { Conv2d: DefaultLayerPolicy(["weight"], [], config), diff --git a/mindspore/python/mindspore/golden_stick/quantization/quant_utils.py b/mindspore/python/mindspore/golden_stick/quantization/quant_utils.py new file mode 100644 index 00000000000..cf6cc57ae6c --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/quant_utils.py @@ -0,0 +1,439 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Quantization utils.""" + +import numpy as np +from mindspore._checkparam import Validator +from ... import nn + +__all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers", "compute_kl_threshold"] + + +def cal_quantization_params(input_min, + input_max, + quant_min, + quant_max, + data_type, + symmetric=False): + r""" + Calculate quantization params for scale and zero point. + + Args: + input_min (numpy.ndarray): The dimension of channel or 1. + input_max (numpy.ndarray): The dimension of channel or 1. + quant_min (int): The minimum quantization integer. + quant_max (int): The maximum quantization integer. + data_type (numpy type) : Can be numpy int8, numpy uint8. + symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False. + + Returns: + scale (numpy.ndarray): quantization param. + zero point (numpy.ndarray): quantization param. 
+ """ + input_max = np.maximum(0.0, input_max) + input_min = np.minimum(0.0, input_min) + + if input_min.shape != input_max.shape: + raise ValueError("input min shape should be equal to input max.") + if len(input_min.shape) > 1: + raise ValueError("input min and max shape should be one dim.") + if (input_min > input_max).all(): + raise ValueError("input_min min should be less than input max.") + if (input_max == input_min).all(): + return np.ones(input_min.shape), np.zeros(input_min.shape) + + # calculate scale + if symmetric: + input_max = np.maximum(-input_min, input_max) + input_min = -input_max + scale = (input_max - input_min) / (quant_max - quant_min) + + # calculate zero point + if data_type == np.int8 and symmetric: + zp = np.zeros(input_min.shape) + else: + zp_double = quant_min - input_min / scale + zp = np.floor(zp_double + 0.5) + + return scale, zp + + +def get_quant_min_max(data_type, num_bits=8, narrow_range=False): + """Calculate quantization params for minimum/maximum quantization integer""" + if data_type == np.int8: + quant_min = 0 - 2 ** (num_bits - 1) + quant_max = 2 ** (num_bits - 1) - 1 + elif data_type == np.uint8: + quant_min = 0 + quant_max = 2 ** num_bits - 1 + else: + raise ValueError("Unsupported datatype({})".format(data_type)) + if narrow_range: + quant_min = quant_min + 1 + return quant_min, quant_max + + +def weight2int(data, scale, zero_point, quant_min, quant_max): + r""" + Calculate int8/uint8 weight from fp32. the formula is defined as: + + .. math:: + int8/uint8 = round(float/scale) + offset + + Args: + data (numpy.ndarray): The dimension of channel or 1. Should be NCHW. + scale (numpy.ndarray): The dimension of channel or 1. + zero_point (numpy.ndarray): The dimension of channel or 1. + quant_min (int): The minimum quantization integer. + quant_max (int): The maximum quantization integer. + + Returns: + weight (numpy.ndarray): The dimension of channel or 1. 
+ """ + if scale.shape != zero_point.shape: + raise ValueError("`scale` and `zero_point` should have the same shape.") + if scale.shape[0] < 0: + raise ValueError("`scale` and `zero_point` shape should be greater than zero.") + if len(scale.shape) >= 1 and scale.shape[0] > 1: + # for perchannel + if scale.shape[0] == data.shape[0]: + # `Conv2d` or `Dense` op weight + shape_list = [-1] + [1] * len(data.shape[1:]) + scale = scale.reshape(shape_list) + zero_point = zero_point.reshape(shape_list) + elif scale.shape[0] == data.shape[1]: + # `DepthwiseConv2d` op weight + shape_list = [1, -1] + [1] * len(data.shape[2:]) + scale = scale.reshape(shape_list) + zero_point = zero_point.reshape(shape_list) + else: + raise ValueError("Unsupported weight shape({})".format(data.shape)) + + weight_int = np.round((data / scale) + zero_point) + weight_int[weight_int > quant_max] = quant_max + weight_int[weight_int < quant_min] = quant_min + return weight_int + + +def scale_zp_max_min_from_fake_quant_cell(cell, data_type): + """Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMaxObserver`.""" + minq = cell.minq.data.asnumpy() + maxq = cell.maxq.data.asnumpy() + # make sure maxq > 0 and minq <= 0 + if cell.mode == 'LEARNED_SCALE': + maxq = np.abs(maxq) + minq = -np.abs(minq) + quant_min, quant_max = get_quant_min_max(data_type, num_bits=cell.num_bits, narrow_range=cell.narrow_range) + symmetric = cell.symmetric and not cell.neg_trunc + scale, zp = cal_quantization_params( + minq, maxq, + quant_min, quant_max, data_type, + symmetric=symmetric) + return scale, zp, maxq, minq + + +def fold_batchnorm(weight, cell_quant): + r""" + Fold the batchnorm in `Conv2dBnFoldQuant` to weight. + + Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive. + + Args: + weight (numpy.ndarray): Weight of `cell_quant`. + cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`. + + Returns: + weight (numpy.ndarray): Folded weight. + bias (numpy.ndarray): Folded bias. + """ + variance = cell_quant.moving_variance.data.asnumpy() + mean = cell_quant.moving_mean.data.asnumpy() + gamma = cell_quant.gamma.data.asnumpy() + beta = cell_quant.beta.data.asnumpy() + epsilon = cell_quant.eps + sigma = np.sqrt(variance + epsilon) + + if gamma.shape[0] == weight.shape[0]: + # `Conv2d` or `Dense` op weight + shape_list = [-1] + [1] * len(weight.shape[1:]) + _gamma = gamma.reshape(shape_list) + _sigma = sigma.reshape(shape_list) + elif gamma.shape[0] == weight.shape[1]: + # `DepthwiseConv2d` op weight + shape_list = [1, -1] + [1] * len(weight.shape[2:]) + _gamma = gamma.reshape(shape_list) + _sigma = sigma.reshape(shape_list) + else: + raise ValueError("Unsupported weight shape({})".format(weight.shape)) + + weight = weight * _gamma / _sigma + bias = beta - gamma * mean / sigma + return weight, bias + + +def without_fold_batchnorm(weight, cell_quant): + r""" + Fold the batchnorm in `Conv2dBnWithoutFoldQuant` to weight. + + Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive. + + Args: + weight (numpy.ndarray): Weight of `cell_quant`. + cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`. + + Returns: + weight (numpy.ndarray): whihout folded weight. + bias (numpy.ndarray): without folded bias. 
+ """ + variance = cell_quant.batchnorm.moving_variance.data.asnumpy() + mean = cell_quant.batchnorm.moving_mean.data.asnumpy() + gamma = cell_quant.batchnorm.gamma.data.asnumpy() + beta = cell_quant.batchnorm.beta.data.asnumpy() + epsilon = cell_quant.batchnorm.eps + sigma = np.sqrt(variance + epsilon) + + if gamma.shape[0] == weight.shape[0]: + # `Conv2d` or `Dense` op weight + shape_list = [-1] + [1] * len(weight.shape[1:]) + _gamma = gamma.reshape(shape_list) + _sigma = sigma.reshape(shape_list) + elif gamma.shape[0] == weight.shape[1]: + # `DepthwiseConv2d` op weight + shape_list = [1, -1] + [1] * len(weight.shape[2:]) + _gamma = gamma.reshape(shape_list) + _sigma = sigma.reshape(shape_list) + else: + raise ValueError("Unsupported weight shape({})".format(weight.shape)) + + weight = weight * _gamma / _sigma + bias = beta - gamma * mean / sigma + return weight, bias + + +def compute_kl_threshold(data, bitwidth): + r""" + Using KL-J Distance to calculate the clip threshold. + + Args: + - **data** (NumpyArray) - Data observed to calculate the threshold for quantization, + - **bitwidth** (QuantDtype) - The datatype of quantization. + Outputs: + Tensor with Shape 1. Threshold to calculate the data. + """ + data_max = np.abs(data).max() + if data_max < 1e-5: + return 1e-5 + hist, bin_edges = np.histogram(np.abs(data), bins='sqrt', range=(0, data_max), density=True) + # For the sake of high efficiency, we limit the maximum number of bins to 1024 in `sqrt` mode, If it exceeds the + # largest size, turn to use the default bins config. + largest_bin_size = 1024 + if hist.shape[0] > largest_bin_size: + hist, bin_edges = np.histogram(np.abs(data), range=(0, data_max), density=True) + hist = hist / np.sum(hist) + cumsum = np.cumsum(hist) + bit_pow_range = pow(2, int(bitwidth.num_bits) - 1) + threshold = [] + scaling_factor = [] + kl = [] + if bit_pow_range + 1 > len(bin_edges) - 1: + th_layer_out = bin_edges[-1] + return float(th_layer_out) + for i in range(bit_pow_range + 1, len(bin_edges), 1): + threshold_tmp = (i + 0.5) * (bin_edges[1] - bin_edges[0]) + threshold = np.concatenate((threshold, [threshold_tmp])) + scaling_factor_tmp = threshold_tmp / (bit_pow_range - 1) + scaling_factor = np.concatenate((scaling_factor, [scaling_factor_tmp])) + # forward interpolation + cumsum_tmp = np.copy(cumsum) + cumsum_tmp[(i - 1):] = 1 + fwd_x = np.linspace(0.0, 1.0, bit_pow_range) + fwd_xp = np.linspace(0.0, 1.0, i) + fwd_fp = cumsum_tmp[:i] + forward_interp = np.interp(fwd_x, fwd_xp, fwd_fp) + # backward interpolation + bwd_x = np.linspace(0.0, 1.0, i) + bwd_xp = np.linspace(0.0, 1.0, bit_pow_range) + bwd_fp = forward_interp + backward_interp = np.interp(bwd_x, bwd_xp, bwd_fp) + cumsum_tmp[:i] = backward_interp + kl_tmp = np.sum((cumsum - cumsum_tmp) * np.log2(cumsum / cumsum_tmp)) # Kullback-Leibler-J + kl = np.concatenate((kl, [kl_tmp])) + th_layer_out = threshold[np.argmin(kl)] + threshold = float(th_layer_out) + if threshold < 1e-5: + threshold = 1e-5 + return threshold + + +def query_quant_layers(network): + r""" + Query the network's quantization strategy of each quantized layer and print it to the screen, note that all the + quantization layers are queried before graph compile optimization in the graph mode, thus, some redundant quantized + layers, which not exist in practical execution, may appear. 
+ + Args: + network (Cell): input network + + Examples: + >>> from mindspore.compression.quant import QuantizationAwareTraining + >>> from mindspore.compression.quant.quant_utils import query_quant_layers + >>> class LeNet5(nn.Cell): + ... def __init__(self, num_class=10, channel=1): + ... super(LeNet5, self).__init__() + ... self.type = "fusion" + ... self.num_class = num_class + ... + ... # change `nn.Conv2d` to `nn.Conv2dBnAct` + ... self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') + ... self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') + ... # change `nn.Dense` to `nn.DenseBnAct` + ... self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') + ... self.fc2 = nn.DenseBnAct(120, 84, activation='relu') + ... self.fc3 = nn.DenseBnAct(84, self.num_class) + ... + ... self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + ... self.flatten = nn.Flatten() + ... + ... def construct(self, x): + ... x = self.conv1(x) + ... x = self.max_pool2d(x) + ... x = self.conv2(x) + ... x = self.max_pool2d(x) + ... x = self.flatten(x) + ... x = self.fc1(x) + ... x = self.fc2(x) + ... x = self.fc3(x) + ... return x + ... + >>> net = LeNet5() + >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False]) + >>> net_qat = quantizer.quantize(net) + >>> query_quant_layers(net_qat) + conv1.conv.fake_quant_weight INT8 + conv1.activation.fake_quant_act INT8 + conv2.conv.fake_quant_weight INT8 + conv2.activation.fake_quant_act INT8 + fc1.dense.fake_quant_weight INT8 + fc1.activation.fake_quant_act INT8 + fc2.dense.fake_quant_weight INT8 + fc2.activation.fake_quant_act INT8 + fc3.dense.fake_quant_weight INT8 + fc3.activation.fake_quant_act INT8 + """ + network = Validator.check_isinstance("network", network, nn.Cell) + tplt = "{0:60}\t{1:10}" + for cell_and_name in network.cells_and_names(): + cell_name = cell_and_name[0] + cell = cell_and_name[1] + if isinstance(cell, nn.FakeQuantWithMinMaxObserver): + print(tplt.format(cell_name, cell.quant_dtype)) + + +def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None): + r""" + Load fp32 model parameters into quantization model. + + Args: + quant_model(Cell): Quantization model. + params_dict(dict): Parameter dict that stores fp32 parameters. + quant_new_params(list): Parameters that exist in quantization network but not in non-quantization + network. Default: None. + + Raises: + TypeError: If `quant_new_params` is not None and is not list. + ValueError: If there are parameters in the `quant_model` that are neither in `params_dict` + nor in `quant_new_params`. + + Examples: + >>> from mindspore import load_checkpoint + >>> from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net + >>> class LeNet5(nn.Cell): + ... def __init__(self, num_class=10, channel=1): + ... super(LeNet5, self).__init__() + ... self.type = "fusion" + ... self.num_class = num_class + ... + ... # change `nn.Conv2d` to `nn.Conv2dBnAct` + ... self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') + ... self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') + ... # change `nn.Dense` to `nn.DenseBnAct` + ... self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') + ... self.fc2 = nn.DenseBnAct(120, 84, activation='relu') + ... self.fc3 = nn.DenseBnAct(84, self.num_class) + ... + ... self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + ... self.flatten = nn.Flatten() + ... + ... 
def construct(self, x): + ... x = self.conv1(x) + ... x = self.max_pool2d(x) + ... x = self.conv2(x) + ... x = self.max_pool2d(x) + ... x = self.flatten(x) + ... x = self.fc1(x) + ... x = self.fc2(x) + ... x = self.fc3(x) + ... return x + ... + >>> net = LeNet5() + >>> ckpt_file_name = "./checkpoint/LeNet5_noquant-1_32.ckpt" + >>> param_dict = load_checkpoint(ckpt_file_name) + >>> load_nonquant_param_into_quant_net(net, param_dict) + """ + if quant_new_params is not None and not isinstance(quant_new_params, list): + raise TypeError("quant_new_params must be list or None.") + iterable_dict = { + 'minq': iter(list(filter(lambda item: item[0].endswith('minq'), params_dict.items()))), + 'maxq': iter(list(filter(lambda item: item[0].endswith('maxq'), params_dict.items()))), + 'quant_max': iter(list(filter(lambda item: item[0].endswith('quant_max'), params_dict.items()))) + } + for param in params_dict.items(): + key_name = param[0].split(".")[-1] + if key_name not in iterable_dict: + iterable_dict[key_name] = iter(list(filter(lambda item, value=key_name: item[0].endswith(value), + params_dict.items()))) + + for name, param in quant_model.parameters_and_names(): + key_name = name.split(".")[-1] + if key_name not in iterable_dict.keys(): + if quant_new_params is None or key_name not in quant_new_params: + raise ValueError(f"Can't find matching parameter in ckpt, param name = {name}") + continue + value_param = next(iterable_dict[key_name], None) + if value_param: + param.set_data(value_param[1].data) + print(f'init model param {name} with checkpoint param {value_param[0]}') + + # Perform KL_init when learned scale quantization is executed. + for cell_and_name in quant_model.cells_and_names(): + cell = cell_and_name[1] + if isinstance(cell, (nn.Conv2dBnFoldQuantOneConv, nn.Conv2dBnFoldQuant, nn.Conv2dBnWithoutFoldQuant, + nn.Conv2dQuant, nn.DenseQuant)) and cell.fake_quant_weight.mode == "LEARNED_SCALE": + subcell_weight_para = cell.weight.data.asnumpy() + if hasattr(cell, 'gamma'): + scale_factor = (cell.gamma.data.asnumpy() / + np.sqrt(cell.moving_variance.data.asnumpy() + 1e-5)) + subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) + + if cell.fake_quant_weight.per_channel: + max_init = [compute_kl_threshold(weight_para_each, cell.fake_quant_weight.quant_dtype) + for weight_para_each in subcell_weight_para] + min_init = [-x for x in max_init] + else: + max_init = [compute_kl_threshold(subcell_weight_para, cell.fake_quant_weight.quant_dtype)] + min_init = [-x for x in max_init] + + cell.fake_quant_weight.reset(quant_dtype=cell.fake_quant_weight.quant_dtype, + min_init=min_init, max_init=max_init) diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize.py b/mindspore/python/mindspore/golden_stick/quantization/quante_aware_training.py similarity index 100% rename from mindspore/python/mindspore/golden_stick/quantization/quantize.py rename to mindspore/python/mindspore/golden_stick/quantization/quante_aware_training.py diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py index 3dfa03fbcb6..0c61e92f2f2 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py +++ b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py @@ -37,6 +37,12 @@ class QuantizeWrapperCell(Cell): self._w_zp = 0 self._o_scale = 1.0 self._o_zp = 0 + # init quant params + for weight_name, quantizer in 
self._policy.get_weight_name_and_quantizers(): + assert weight_name is not None + assert quantizer is not None + weight = getattr(self._handler, weight_name) + quantizer.compute_quant_param(weight) def construct(self, *inputs, **kwargs): """ @@ -51,20 +57,10 @@ class QuantizeWrapperCell(Cell): assert weight_name is not None assert quantizer is not None weight = getattr(self._handler, weight_name) - quant_param = quantizer.compute_quant_param(weight) - fq_data = quantizer.fake_quant(weight, quant_param) + quantizer.compute_quant_param(weight) + fq_data = quantizer(weight) setattr(self._handler, weight_name, fq_data) - # fake-quant activation - for act_name, quantizers in self._policy.get_act_name_and_quantizers(): - assert act_name is not None - if quantizers is None: - continue - pre_quantizer, post_quantizer = quantizers - activation = getattr(self._handler, act_name) - quant_act = QuantizeWrapperActivation(activation, pre_quantizer, post_quantizer) - setattr(self._handler, act_name, quant_act) - # fake-quant input input_quantizer = self._policy.get_input_quantizer() if input_quantizer is None: @@ -75,8 +71,7 @@ class QuantizeWrapperCell(Cell): for i in range(0, input_len): ori_input = inputs[i] if self._policy.get_input_need_insert_fq(i): - quant_param = input_quantizer.compute_quant_param(ori_input) - fq_inputs.append(input_quantizer.fake_quant(ori_input, quant_param)) + fq_inputs.append(input_quantizer(ori_input)) else: fq_inputs.append(ori_input) @@ -93,13 +88,11 @@ class QuantizeWrapperCell(Cell): if output_len == 0: return outputs elif output_len == 1: - quant_param = output_quantizer.compute_quant_param(outputs) - fq_data = output_quantizer.fake_quant(outputs, quant_param) + fq_data = output_quantizer(outputs) return fq_data else: fq_outputs = [] for i in range(0, output_len): ori_output = outputs[i] - quant_param = output_quantizer.compute_quant_param(ori_output) - fq_outputs.append(output_quantizer.fake_quant(ori_output, quant_param)) + fq_outputs.append(output_quantizer(ori_output)) return fq_outputs diff --git a/mindspore/python/mindspore/nn/layer/combined.py b/mindspore/python/mindspore/nn/layer/combined.py index a094900e2fd..c8e10e86f88 100644 --- a/mindspore/python/mindspore/nn/layer/combined.py +++ b/mindspore/python/mindspore/nn/layer/combined.py @@ -21,7 +21,6 @@ from .normalization import BatchNorm2d, BatchNorm1d from .activation import get_activation, LeakyReLU from ..cell import Cell - __all__ = [ 'Conv2dBnAct', 'DenseBnAct' diff --git a/mindspore/python/mindspore/rewrite/graph.py b/mindspore/python/mindspore/rewrite/graph.py index 15c7f316f30..746109ad96d 100644 --- a/mindspore/python/mindspore/rewrite/graph.py +++ b/mindspore/python/mindspore/rewrite/graph.py @@ -16,6 +16,7 @@ from .parser import Parser from mindspore.rewrite.ast_unparser import ASTUnparser + class _node_list: def __init__(self, graph) -> None: self._graph = graph @@ -50,7 +51,8 @@ class _node_list: continue raise StopIteration -class Graph(): + +class Graph: def __init__(self, network: Union[nn.Cell, Primitive, FunctionType]): self._name = network.__name__ self._network = network @@ -143,7 +145,7 @@ class Graph(): args = self._parser.parse_arguments(ast_node) for name, value in args.items(): new_node = PlaceholderNode(name, name, ast_node, default_value=value) - logger.debug ("placeholder node: %r", new_node) + logger.debug("placeholder node: %r", new_node) self._nodes.append(new_node) def parse_init(self): @@ -175,21 +177,22 @@ class Graph(): logger.info(f"parse {self._base_scope} construct 
function start") self._parser.updete_closure_namespace(self._network.construct) self.create_placeholder(self._ast_function_root["construct"]) - name_counts = {} #save the number of the variable, if the number is over 1,then modify the name - add a number as the name suffix + name_counts = {} # save the number of the variable, if the number is over 1,then modify the name - add a + # number as the name suffix index = 0 for ast_node in self._ast_function_root["construct"].body: logger.debug(f"process ast node: {ast_node}") if isinstance(ast_node, ast.Expr): continue - #method = 'parse_' + ast_node.__class__.__name__ - #visitor = getattr(self._parser, method, None) + # method = 'parse_' + ast_node.__class__.__name__ + # visitor = getattr(self._parser, method, None) visitor = self._parser.get_node_visitor(ast_node) if not visitor: logger.warning("Get node visitor failed in parse_construct, node: %r", ast_node) continue nodes, attribute_names = visitor(ast_node) - for i in range(len(nodes)): + for i in range(len(nodes)): n = nodes[i].name.split(".")[-1] if attribute_names and attribute_names[i] in self._node_attributes.keys(): logger.debug(f"defined in init function: {attribute_names[i]}") @@ -203,7 +206,8 @@ class Graph(): elif self._parser.get_func_namesapce(n)[0]: class_, name_space_, is_custom_define_ = self._parser.get_func_namesapce(n) logger.debug(f"defined in other namespace: {n}") - logger.debug("class: ", class_, "name space: ", name_space_, "is custom define: ", is_custom_define_) + logger.debug("class: ", class_, "name space: ", name_space_, "is custom define: ", + is_custom_define_) nodes[i]._attribute._is_custom_define = is_custom_define_ nodes[i]._attribute._class = class_ nodes[i]._attribute._type = NodeType.call_function @@ -229,7 +233,7 @@ class Graph(): logger.debug("construct nodes: ") for node in self._nodes: logger.debug(node) - logger.info(f"parse {self._base_scope } construct function end") + logger.info(f"parse {self._base_scope} construct function end") def parse_function(self, func: Union[ast.FunctionDef, FunctionType]): """ @@ -239,13 +243,13 @@ class Graph(): logger.info(f"parse {func.__name__} function start") function_str = inspect.getsource(func) ast_root = ast.parse(function_str) - astpretty.pprint (ast_root) + astpretty.pprint(ast_root) node = ast_root.body[0] - subgraph = Graph(func) #要区分类内还是类外方法 + subgraph = Graph(func) # 要区分类内还是类外方法 else: logger.info(f"parse {func.name} function start") node = func - subgraph = Graph(self._network.__dict__[node.name]) #要区分类内还是类外方法 + subgraph = Graph(self._network.__dict__[node.name]) # 要区分类内还是类外方法 subgraph._name = node.name subgraph._ast_root = node subgraph.create_placeholder(node) @@ -255,15 +259,15 @@ class Graph(): logger.debug(f"process ast node: {ast_node}") if isinstance(ast_node, ast.Expr): continue - #method = 'parse_' + ast_node.__class__.__name__ - #visitor = getattr(self._parser, method, None) + # method = 'parse_' + ast_node.__class__.__name__ + # visitor = getattr(self._parser, method, None) visitor = self._parser._get_node_visitor(ast_node) if not visitor: logger.warning(f"get node visitor failed in parse_function, node: {ast_node}") nodes, attribute_names = visitor(ast_node) for i in range(len(nodes)): nodes[i]._index = index - nodes[i].name = self._base_scope + "." + node.name + "." + nodes[i].name + nodes[i].name = self._base_scope + "." + node.name + "." 
+ nodes[i].name if attribute_names and attribute_names[i] in self._node_attributes.keys(): nodes[i]._attribute = self._node_attributes[attribute_names[i]] else: @@ -274,7 +278,7 @@ index += 1 logger.debug(f"{subgraph._name} nodes: ") for node in subgraph._nodes: - logger.debug (node) + logger.debug(node) logger.info("parse function end") return subgraph @@ -327,8 +331,8 @@ break if arg in self._node_attributes.keys(): - node.inputs.append(self._node_attributes[arg]) - continue + node.inputs.append(self._node_attributes[arg]) + continue return @@ -337,9 +341,9 @@ Parse the subgraph defined in '__init__' function. """ for name, node in self._node_attributes.items(): - #print ("name: ", name, "; node: ", node) + # print ("name: ", name, "; node: ", node) if node._is_custom_define and isinstance(node._class, FunctionType): - logger.info("The node is FunctionType, node: %r", node) + logger.info("The node is FunctionType, node: %r", node) subgraph = self.parse_function(node._class) logger.debug("name = %r", name, "; subgraph = %r", subgraph) self._subgraphs["self." + name] = subgraph @@ -352,7 +356,7 @@ graph.parse_init() graph.parse_construct() for node in graph._nodes: - node.name = self._base_scope + "." + name.split(".")[-1] + "." + node.name.split(".")[-1] + node.name = self._base_scope + "." + name.split(".")[-1] + "." + node.name.split(".")[-1] logger.debug(f"{name} subgraph node: {node}") self._subgraphs[name] = graph elif node._is_custom_define and issubclass(node._class, Primitive): @@ -374,14 +378,18 @@ node.outputs[0].inputs = node.inputs self._nodes.remove(node) - def replace_node(self, src_nodes, dst_node): + def replace_node(self, src_nodes, dst_nodes): """ Replace src_nodes in 'nodes' by dst_node, modify ast synchronously. Args: src_nodes: Nodes to be replaced. - dst_node: Node used to replace. + dst_nodes: Nodes used to replace; either a single node or a chain-like node list whose first + element is the output. 
""" + if isinstance(dst_nodes, list) and len(dst_nodes) == 0: + logger.warning("replace_node doesn't support empty dst_nodes, please use remove_node") + return if isinstance(src_nodes, list): # redirect edges appends = [] @@ -389,12 +397,18 @@ class Graph(): for output in src_node.outputs: if output not in src_nodes: appends.append(output) - dst_node.outputs = appends + if isinstance(dst_nodes, list): + dst_nodes[0].outputs = appends + else: + dst_nodes.outputs = appends for append in appends: new_inputs = [] for input in append.inputs: if input in src_nodes: - new_inputs.append(dst_node) + if isinstance(dst_nodes, list): + new_inputs.append(dst_nodes[0]) + else: + new_inputs.append(dst_nodes) else: new_inputs.append(input) append.inputs = new_inputs @@ -404,12 +418,16 @@ class Graph(): for input in src_node.inputs: if input not in src_nodes: appends.append(input) - dst_node.inputs = appends + + if isinstance(dst_nodes, list): + dst_nodes[-1].inputs = appends + else: + dst_nodes.inputs = appends for append in appends: new_inputs = [] for output in append.outputs: if output in src_nodes: - new_inputs.append(dst_node) + new_inputs.append(dst_nodes) else: new_inputs.append(output) append.outputs = new_inputs @@ -421,7 +439,10 @@ class Graph(): if cur_node is node: self._nodes.pop(i) break - self._nodes.append(dst_node) + if isinstance(dst_nodes, list): + self._nodes.extend(dst_nodes) + else: + self._nodes.append(dst_nodes) else: pass @@ -453,7 +474,7 @@ class Graph(): def check(self): pass - + @property def python_code(self): return astunparse.unparse(self._ast_root) diff --git a/mindspore/python/mindspore/rewrite/pattern_engine.py b/mindspore/python/mindspore/rewrite/pattern_engine.py index f31d6644c16..b0db675a889 100644 --- a/mindspore/python/mindspore/rewrite/pattern_engine.py +++ b/mindspore/python/mindspore/rewrite/pattern_engine.py @@ -217,25 +217,25 @@ class PatternEngine: continue matched_list = list(matched_dict.values()) if self._is_chain: - new_node = self._process_chain(matched_list) + new_nodes = self._process_chain(matched_list) else: - new_node = self._process_tree(matched_dict) - if new_node is None: # return None to remove + new_nodes = self._process_tree(matched_dict) + if len(new_nodes) == 0: # return None to remove changed = True for key in matched_dict: graph.remove_node(matched_dict[key]) - elif new_node == cur_node: # return origin Node for do nothing + elif new_nodes == cur_node: # return origin Node for do nothing pass else: # return Node to insert or replace (new Node no need to set inputs and outputs) # todo if we need to support _process_chain or _process_tree return multi-node changed = True - graph.replace_node(matched_list, new_node) - node_inputs = new_node.inputs + graph.replace_node(matched_list, new_nodes) + node_inputs = new_nodes[-1].inputs for node_input in node_inputs: queue.append(node_input) return changed - def _process_chain(self, matched_nodes: [Node]) -> Node: + def _process_chain(self, matched_nodes: [Node]) -> [Node]: """ Define how to generate a new_node with fuse_fn when pattern is a chain-pattern. 
@@ -248,19 +248,26 @@ class PatternEngine: if self._replacement is None: return matched_nodes[len(matched_nodes) - 1] - replacement = self._replacement(*matched_nodes) - if replacement is None: - return None - if len(matched_nodes) == 0: - new_node = Node(instance=replacement) + replacements = self._replacement(*matched_nodes) + if replacements is None or len(matched_nodes) == 0: + return [] + if isinstance(replacements, list): + new_nodes = [Node(instance=item) for item in replacements] else: - new_node = Node(instance=replacement, inputs=matched_nodes[0].inputs) - node_name = "" - for matched_node in matched_nodes: - node_name += matched_node.name + "_" - node_name += "fused" - new_node.name = node_name - return new_node + new_nodes = [Node(instance=replacements)] + if len(new_nodes) == 1: + node_name = "" + for matched_node in matched_nodes: + node_name += matched_node.name + "_" + node_name += "fused" + new_nodes[0].name = node_name + + if len(new_nodes) > 0: + new_nodes[0].outputs = matched_nodes[0].outputs + for i in range(1, len(new_nodes)): + new_nodes[i].outputs = [new_nodes[i-1]] + new_nodes[i-1].inputs = [new_nodes[i]] + return new_nodes # matched_cells: name_of_cell_in_pattern map to matched cell in network def _process_tree(self, matched_nodes: OrderedDict) -> Node: diff --git a/tests/ut/python/golden_stick/test_quant_aware_training.py b/tests/ut/python/golden_stick/test_quant_aware_training.py new file mode 100644 index 00000000000..0a512c50f44 --- /dev/null +++ b/tests/ut/python/golden_stick/test_quant_aware_training.py @@ -0,0 +1,8 @@ +from mindspore.golden_stick.quantization.default_qat.default_quantize import DefaultQuantAwareTraining +from lenet import LeNet5 + + +# def test_default_quant_aware_train(): +# qat = DefaultQuantAwareTraining() +# net = LeNet5() +# transformed_net = qat.apply(net) -- Gitee From a6bc59f474ebdde0e5248210841a5948beb9ea17 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Sat, 8 Jan 2022 15:02:30 +0800 Subject: [PATCH 34/34] support origin qat features --- .../golden_stick/quantization/__init__.py | 2 +- .../default_qat/default_net_policy.py | 8 +- .../quantization/net_transform_sequential.py | 37 +++++++ .../quantization/pattern_engine_sequential.py | 60 +++++++++++ ...re_training.py => quant_aware_training.py} | 2 +- .../quant_aware_training_sequential.py | 102 ++++++++++++++++++ mindspore/python/mindspore/nn/layer/quant.py | 25 +++++ 7 files changed, 231 insertions(+), 5 deletions(-) create mode 100644 mindspore/python/mindspore/golden_stick/quantization/net_transform_sequential.py create mode 100644 mindspore/python/mindspore/golden_stick/quantization/pattern_engine_sequential.py rename mindspore/python/mindspore/golden_stick/quantization/{quante_aware_training.py => quant_aware_training.py} (99%) create mode 100644 mindspore/python/mindspore/golden_stick/quantization/quant_aware_training_sequential.py diff --git a/mindspore/python/mindspore/golden_stick/quantization/__init__.py b/mindspore/python/mindspore/golden_stick/quantization/__init__.py index 36caa99561d..5da50c654e3 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/__init__.py +++ b/mindspore/python/mindspore/golden_stick/quantization/__init__.py @@ -18,7 +18,7 @@ MindSpore golden stick module. 
from .layer_policy import LayerPolicy from .net_policy import NetPolicy -from .quante_aware_training import QuantAwareTraining +from .quant_aware_training import QuantAwareTraining from .fake_quantizer import FakeQuantizer from .transformer import Transformer from .default_qat import DefaultLayerPolicy, DefaultNetworkPolicy, DefaultQuantAwareTraining diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py index 032a766ad03..a7bcb341211 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py @@ -19,6 +19,7 @@ from ..net_policy import NetPolicy from ..layer_policy import LayerPolicy from .default_layer_policy import DefaultLayerPolicy from ..transformer import Transformer +from ..pattern_engine_sequential import PatternEngineSequential from mindspore.nn.layer import Conv2d, Dense, MatMul, BatchNorm2d, ReLU, Conv2dBnAct from mindspore.nn.layer.quant import Conv2dBnFoldQuantOneConv from mindspore.rewrite.pattern_engine import PatternEngine @@ -59,9 +60,10 @@ class DefaultNetworkPolicy(NetPolicy): Transformer([Conv2d, BatchNorm2d]), Transformer([Conv2d, ReLU]), PatternEngine([Conv2dBnAct], _split_conv2d_bn_act), - PatternEngine([Conv2d, BatchNorm2d], - partial(Conv2dBnFoldQuantOneConv.from_float, - _fetch_quant_config(self.get_net_layer_policy()))) + PatternEngineSequential([Conv2d, BatchNorm2d], + partial(Conv2dBnFoldQuantOneConv.from_float, + _fetch_quant_config(self.get_net_layer_policy())) + ) ] self._support_layer_map: dict = { Conv2d: DefaultLayerPolicy(["weight"], [], config), diff --git a/mindspore/python/mindspore/golden_stick/quantization/net_transform_sequential.py b/mindspore/python/mindspore/golden_stick/quantization/net_transform_sequential.py new file mode 100644 index 00000000000..588818ae374 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/net_transform_sequential.py @@ -0,0 +1,37 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""net transform sequential.""" +from ..net_transform import NetTransformer +from mindspore.nn import CellList, Cell +from .layer_policy import LayerPolicy +from .pattern_engine_sequential import PatternEngineSequential + + +class NetTransformerSequential(NetTransformer): + def __init__(self, net: CellList): + super(NetTransformerSequential, self).__init__(net) + self.layer_policies: [LayerPolicy] = [] + self.net: CellList = net + + def replace_cell(self, index, new_cell: Cell): + if index >= len(self.net): + return False + self.net.__setitem__(index, new_cell) + return True + + def apply_pattern_engine_sequential(self, pattern_engine: PatternEngineSequential): + pattern_engine.apply_sequence(self.net) + + diff --git a/mindspore/python/mindspore/golden_stick/quantization/pattern_engine_sequential.py b/mindspore/python/mindspore/golden_stick/quantization/pattern_engine_sequential.py new file mode 100644 index 00000000000..31fc4fd831e --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/quantization/pattern_engine_sequential.py @@ -0,0 +1,60 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""pattern engine without ast.""" +from mindspore.rewrite.pattern_engine import PatternEngine, PatternNode +from mindspore.nn import CellList, Cell +from typing import Union, List, Tuple +from collections import OrderedDict + + +class PatternEngineSequential(PatternEngine): + def __init__(self, pattern: Union[PatternNode, List], replacement: callable = None): + super(PatternEngineSequential, self).__init__(pattern, replacement) + + def _match_sequence(self, pattern: PatternNode, model: CellList, index=0) -> Tuple[bool, OrderedDict]: + result = OrderedDict() + if not isinstance(model[index], pattern.type()): + return False, result + result[pattern.name()] = index + if len(pattern.inputs()) == 0: + return True, result + if index >= len(model) - 1: + return False, OrderedDict() + match, next_result = self._match_sequence(pattern.inputs()[0], model, index + 1) + if match: + result.update(next_result) + return True, result + else: + return False, OrderedDict() + + def _replace_sequence(self, model: CellList, matched_dict, replacements): + indices = sorted(matched_dict.values()) + first_idx_removed = indices[0] + # delete from the back so the remaining indices stay valid + for idx in reversed(indices): + del model[idx] + if isinstance(replacements, list): + for idx in range(len(replacements)): + model.insert(first_idx_removed + idx, replacements[idx]) + elif isinstance(replacements, Cell): + model.insert(first_idx_removed, replacements) + else: + raise Exception("replace sequence only supports a list or a Cell yet.") + + def apply_sequence(self, model: CellList): + matched, matched_dict = self._match_sequence(self.pattern(), model) + if not matched: + return False + matched_nodes = [model[i] for i in matched_dict.values()] + replacements = self._replacement(*matched_nodes) + 
diff --git a/mindspore/python/mindspore/golden_stick/quantization/quante_aware_training.py b/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py
similarity index 99%
rename from mindspore/python/mindspore/golden_stick/quantization/quante_aware_training.py
rename to mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py
index 1937f267aa2..57ff1f9ea04 100644
--- a/mindspore/python/mindspore/golden_stick/quantization/quante_aware_training.py
+++ b/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Quantize."""
+"""Quantization aware training."""
 import copy
 from typing import Optional
diff --git a/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training_sequential.py b/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training_sequential.py
new file mode 100644
index 00000000000..eb2b493ff91
--- /dev/null
+++ b/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training_sequential.py
@@ -0,0 +1,102 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Quantization aware training for sequential networks."""
+import copy
+from typing import Optional
+
+from mindspore.nn import CellList
+from .quant_aware_training import QuantAwareTraining
+from .net_transform_sequential import NetTransformerSequential
+from .pattern_engine_sequential import PatternEngineSequential
+from .layer_policy import LayerPolicy, layer_policy_key
+
+
+class QuantAwareTrainingSequential(QuantAwareTraining):
+    def __init__(self, config: dict):
+        super(QuantAwareTrainingSequential, self).__init__(config)
+
+    def _propagate_layer_policy_sequential(self, network: CellList, net_transformer: NetTransformerSequential):
+        """
+        Set a layer policy for every layer according to custom_layer_policy_map, layer_policy_map and
+        net_layer_policy of QuantAwareTraining. custom_layer_policy_map has the highest priority,
+        layer_policy_map the second and net_layer_policy the lowest.
+
+        Args:
+            network (CellList): network whose layers will be assigned policies.
+            net_transformer (NetTransformerSequential): transformer that stores the per-layer policies.
+        """
+
+        # step1 apply net layer-policy first
+        net_layer_policy: Optional[LayerPolicy] = self._qat_policy.get_net_layer_policy()
+        if net_layer_policy:
+            # One independent copy per layer: policies may be mutated per index later.
+            net_transformer.layer_policies = [copy.copy(net_layer_policy) for _ in range(len(network))]
+
+        # step2 then apply layer-policy map, override policy if needed
+        layer_policy_map = self._qat_policy.get_layer_policy_map()
+        for i, cell in enumerate(network.cells()):
+            layer_policy: LayerPolicy = self._custom_layer_policy_map.get(cell.cell_type)
+            if layer_policy is None:
+                layer_policy = layer_policy_map.get(cell.cell_type)
+            if isinstance(layer_policy, LayerPolicy):
+                new_layer_policy = copy.copy(layer_policy)
+                new_layer_policy.set_input_number(1)
+                net_transformer.layer_policies[i] = new_layer_policy
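The priority resolution described in the docstring above, sketched with plain dicts (illustrative names, not this module's API):

    # custom map wins, then the per-layer map, then the net-wide fallback
    custom_layer_policy_map = {"Conv2d": "custom-conv-policy"}
    layer_policy_map = {"Conv2d": "default-conv-policy", "Dense": "default-dense-policy"}
    net_layer_policy = "net-wide-policy"

    def resolve(cell_type):
        return (custom_layer_policy_map.get(cell_type)
                or layer_policy_map.get(cell_type)
                or net_layer_policy)

    print(resolve("Conv2d"))  # custom-conv-policy
    print(resolve("Dense"))   # default-dense-policy
    print(resolve("ReLU"))    # net-wide-policy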
+
+    @staticmethod
+    def _reduce_redundant_fake_quant_sequential(net_transformer: NetTransformerSequential):
+        """
+        Reduce redundant fake-quantizer nodes between adjacent layers. Redundancy usually occurs when the
+        previous layer inserts an output fake-quantizer and the next layer inserts an input fake-quantizer
+        at the same time.
+
+        Args:
+            net_transformer (NetTransformerSequential): transformer holding the per-layer policies to check.
+        """
+        for idx in range(len(net_transformer.layer_policies) - 1):
+            cur_policy: LayerPolicy = net_transformer.layer_policies[idx]
+            next_policy: LayerPolicy = net_transformer.layer_policies[idx + 1]
+            # When this layer's output quantizer and the next layer's input quantizer are of the same
+            # type, one of them is redundant: skip inserting the output fake-quantizer.
+            if type(cur_policy.get_output_quantizer()[0]) is type(next_policy.get_input_quantizer()[0]):
+                cur_policy.set_output_not_insert_fq()
+
+    def _apply_fuse_patterns(self, net_transformer: NetTransformerSequential):
+        transformers = self._qat_policy.get_transformers()
+        if isinstance(self._custom_transforms, list):
+            for transform in self._custom_transforms:
+                if isinstance(transform, PatternEngineSequential):
+                    transformers.append(transform)
+        for transformer in transformers:
+            # Transformers always return False; the engines mutate the CellList in place.
+            net_transformer.apply_pattern_engine_sequential(transformer)
+
+    @staticmethod
+    def _apply_layer_policy(net: CellList, net_transformer: NetTransformerSequential):
+        """
+        Apply each layer policy to its corresponding layer. By default a layer is replaced with the return
+        value of the policy's wrap_cell.
+
+        Args:
+            net (CellList): network to be quantized.
+            net_transformer (NetTransformerSequential): transformer holding the per-layer policies.
+        """
+
+        for i in range(len(net)):
+            layer_policy = net_transformer.layer_policies[i]
+            net_transformer.replace_cell(i, layer_policy.wrap_cell(net[i]))
+
+    def apply(self, net: CellList) -> CellList:
+        net_transformer = NetTransformerSequential(net)
+        self._propagate_layer_policy_sequential(net, net_transformer)
+        QuantAwareTrainingSequential._reduce_redundant_fake_quant_sequential(net_transformer)
+        self._apply_fuse_patterns(net_transformer)
+        QuantAwareTrainingSequential._apply_layer_policy(net, net_transformer)
+        return net
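Putting the pieces together, a hedged end-to-end sketch of the sequential flow defined above; the empty config dict is a placeholder for whatever schema the base QuantAwareTraining expects:

    from mindspore import nn
    from mindspore.golden_stick.quantization.quant_aware_training_sequential import QuantAwareTrainingSequential

    net = nn.CellList([nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU()])
    qat = QuantAwareTrainingSequential(config={})  # config contents elided
    # apply() propagates layer policies, drops redundant fake-quantizers between
    # neighbours, runs the fuse pattern engines, then wraps each remaining cell.
    quant_net = qat.apply(net)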
diff --git a/mindspore/python/mindspore/nn/layer/quant.py b/mindspore/python/mindspore/nn/layer/quant.py
index af15b23735c..942245dcb19 100644
--- a/mindspore/python/mindspore/nn/layer/quant.py
+++ b/mindspore/python/mindspore/nn/layer/quant.py
@@ -30,6 +30,7 @@ from types import DynamicClassAttribute
 import mindspore.context as context
 from .normalization import BatchNorm2d
 from .conv import Conv2d
+from .basic import Dense
 from .activation import get_activation
 from ..cell import Cell
 from ... import nn
@@ -47,6 +48,7 @@ __all__ = [
     'MulQuant',
 ]
 
+
 @enum.unique
 class QuantDtype(enum.Enum):
     """
@@ -1455,6 +1457,22 @@ class Conv2dQuant(Cell):
                                                  num_channels=out_channels,
                                                  quant_dtype=quant_dtype)
 
+    @classmethod
+    def from_float(cls, conv: Conv2d, quant_config: QuantConfig):
+        """Create a Conv2dQuant from a float Conv2d with the same configuration."""
+        conv_quant = cls(conv.in_channels,
+                         conv.out_channels,
+                         kernel_size=conv.kernel_size,
+                         stride=conv.stride,
+                         pad_mode=conv.pad_mode,
+                         padding=conv.padding,
+                         dilation=conv.dilation,
+                         group=conv.group,
+                         has_bias=conv.has_bias,
+                         bias_init=conv.bias_init,
+                         quant_config=quant_config)
+        # Carry over the trained weight, mirroring DenseQuant.from_float below.
+        conv_quant.weight = conv.weight
+        return conv_quant
+
     def construct(self, x):
         weight = self.fake_quant_weight(self.weight)
         out = self.conv(x, weight)
@@ -1579,6 +1597,13 @@ class DenseQuant(Cell):
                                                  num_channels=out_channels,
                                                  quant_dtype=quant_dtype)
 
+    @classmethod
+    def from_float(cls, dense: Dense, quant_config: QuantConfig):
+        """Create a DenseQuant from a float Dense, carrying over the trained weight."""
+        dense_quant = cls(dense.in_channels, dense.out_channels,
+                          has_bias=dense.has_bias, activation=dense.activation, quant_config=quant_config)
+        dense_quant.weight = dense.weight
+        return dense_quant
+
     def construct(self, x):
         """Use operators to construct the Dense layer.
-- 
Gitee