From b8832dd0c44c8412f55303421f07a1e9e860354a Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Mon, 7 Feb 2022 09:57:17 +0800 Subject: [PATCH 01/32] tmp --- .../python/mindspore/rewrite_experiment/rewrite.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/rewrite.py b/mindspore/python/mindspore/rewrite_experiment/rewrite.py index 5752766098a..98f31c15d8b 100644 --- a/mindspore/python/mindspore/rewrite_experiment/rewrite.py +++ b/mindspore/python/mindspore/rewrite_experiment/rewrite.py @@ -30,11 +30,9 @@ class Rewrite: def nodes(self) -> {}: return self._symbol_tree.nodes() - # only support node now def insert_before(self, node_or_name: Union[Node, str]): self._symbol_tree.insert_before(node_or_name) - # only support node now def insert_after(self, node_or_name: Union[Node, str]): self._symbol_tree.insert_after(node_or_name) @@ -57,18 +55,22 @@ class Rewrite: call_kwargs: {str: Argument}) -> Optional[Node]: return self._symbol_tree.add_custom_node(custom_obj, field, targets, call_args, call_kwargs, target_type) - def update_output(self, return_values: [str]) -> Optional[Node]: - return self._symbol_tree.update_output(return_values) + # def update_output(self, return_values: [str]) -> Optional[Node]: + # return self._symbol_tree.update_output(return_values) def add_output(self, return_value: str, index: Optional[int] = None) -> Optional[Node]: return self._symbol_tree.add_output(return_value, index) - def set_output(self, return_value: str, index: int) -> Optional[Node]: + def set_output(self, index: int, return_value: str) -> Optional[Node]: return self._symbol_tree.set_output(return_value, index) def add_input(self, name: str, input_type: Optional[type] = None, default: Optional[str] = None) -> bool: return self._symbol_tree.add_input_and_update_ast(name, input_type, default) + def set_input(self, index: int, name: str, input_type: Optional[type] = None, default: Optional[str] = None) -> \ + bool: + raise NotImplementedError + def update_arg(self, node: Node, index: int, arg: str): node.update_arg(index, arg) -- Gitee From 661579d32a0e385cc1a5ee3568c2930d3379478a Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Mon, 7 Feb 2022 15:14:20 +0800 Subject: [PATCH 02/32] update interface of rewrite --- .../rewrite_experiment/ast_parse/resolver.py | 3 +- .../rewrite_experiment/example/test_lenet.py | 24 ++++----- .../mindspore/rewrite_experiment/rewrite.py | 45 ++++++---------- .../rewrite_experiment/symbol_tree.py | 52 ++++++++----------- 4 files changed, 52 insertions(+), 72 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py index 29fa5787f77..c1d5dba4aea 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py @@ -148,7 +148,8 @@ class Resolver: raise RuntimeError("kwargs in construct function assign is unsupported") cell_type = Resolver._get_cell_type(func) # todo call attribute_resolver to resolve attributes - ret = symbol_tree.add_origin_field(func, cell_type, [target], call_args, {}, body.symbol_ast) + ret = symbol_tree.add_origin_field((True, symbol_tree.get_root()), func, cell_type, [target], call_args, + {}, body.symbol_ast) if ret is None: raise RuntimeError("add_origin_field failed") elif body.symbol_type() == SymbolType.return_: diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py index eb5a93bee92..79f07d09ce4 100644 --- a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py @@ -10,7 +10,7 @@ from mindspore.rewrite_experiment import Rewrite, Argument, ArgType class MyCell(nn.Cell): def __init__(self): super().__init__() - self.conv = nn.Dense(32, 16) + self.conv = nn.Dense(5, 16) def construct(self, x): x = self.conv(x) @@ -18,8 +18,6 @@ class MyCell(nn.Cell): def transform(rw: Rewrite): - rw.add_input("nx0") - rewrite.dump("after add_input") for _, node in rw.nodes().items(): targets = node.get_targets() if targets is None: @@ -27,13 +25,13 @@ def transform(rw: Rewrite): assert targets[0].type == ArgType.NamingArg target = str(targets[0]) if target == "self_flatten": - rw.insert_before(node) - custom_in_channel = 1 + position = rw.before(node) + custom_in_channel = 16 construct_args=[Argument.create_custom_arg(custom_in_channel), Argument.create_imm_arg(16), Argument.create_imm_arg(3)] - ret = rw.add_cell(cell_type=nn.Conv2d, field='conv_new', construct_args=construct_args, construct_kwargs={}, - targets=['nx1'], target_type="", call_args=[Argument.create_naming_arg('nx0')], - call_kwargs={}) + ret = rw.add_cell(position=position, cell_type=nn.Conv2d, field='conv_new', construct_args=construct_args, + construct_kwargs={}, targets=['nx1'], target_type="", + call_args=[Argument.create_naming_arg('x')], call_kwargs={}) if ret is None: raise RuntimeError("add_cell failed") break @@ -45,10 +43,10 @@ def transform(rw: Rewrite): assert targets[0].type == ArgType.NamingArg target = str(targets[0]) if target == "self_relu_3": - rw.insert_before(node) + position = rw.before(node) custom_cell = MyCell() - ret = rw.add_custom_node(custom_obj=custom_cell, field='my_cell', targets=['nx2'], target_type="", - call_args=[Argument.create_naming_arg('nx3')], call_kwargs={}) + ret = rw.add_object(position=position, custom_obj=custom_cell, field='my_cell', targets=['nx2'], + target_type="", call_args=[Argument.create_naming_arg('nx3')], call_kwargs={}) if ret is None: raise RuntimeError("add_custom_node failed") break @@ -63,7 +61,7 @@ def transform(rw: Rewrite): rw.update_arg(node, 0, "nx1") break rewrite.dump("after update_arg") - rw.add_output("nx2") + rw.set_output(0, "nx2") rewrite.dump("after add_output") for _, node in rw.nodes().items(): targets = node.get_targets() @@ -86,7 +84,7 @@ if __name__ == '__main__': transform(rewrite) lenet_opt = rewrite.get_network() context.set_context(mode=context.GRAPH_MODE, device_target="CPU", save_graphs=True, save_graphs_path='./lenet_dump') - double_inputs = True + double_inputs = False input1 = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32) input2 = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32) if double_inputs: diff --git a/mindspore/python/mindspore/rewrite_experiment/rewrite.py b/mindspore/python/mindspore/rewrite_experiment/rewrite.py index 98f31c15d8b..5ae8f701283 100644 --- a/mindspore/python/mindspore/rewrite_experiment/rewrite.py +++ b/mindspore/python/mindspore/rewrite_experiment/rewrite.py @@ -14,8 +14,6 @@ # ============================================================================ from typing import Union, Optional -from mindspore import log as logger -from .symbol_tree import SymbolTree from .node import Node from .ast_parse import AstParse from mindspore.nn import Cell @@ -30,47 +28,38 @@ class Rewrite: def nodes(self) -> {}: return self._symbol_tree.nodes() - def insert_before(self, node_or_name: Union[Node, str]): - self._symbol_tree.insert_before(node_or_name) + def before(self, node_or_name: Union[Node, str]): + return self._symbol_tree.insert_before(node_or_name) - def insert_after(self, node_or_name: Union[Node, str]): - self._symbol_tree.insert_after(node_or_name) - - # only support node now - def erase_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: - return self._symbol_tree.erase_node(node_or_name) + def after(self, node_or_name: Union[Node, str]): + return self._symbol_tree.insert_after(node_or_name) # self.'field': 'cell_type' = mindspore.nn.'cell_type'(*'construct_args', **'construct_kwargs') # 'targets': 'target_type' = self.'field'(*'call_args', **'call_kwargs') - def add_cell(self, cell_type: type, field: str = None, construct_args: [Argument] = None, + def add_cell(self, position, cell_type: type, field: str = None, construct_args: [Argument] = None, construct_kwargs: {str: Argument} = None, targets: [str] = None, target_type: str = "", call_args: [Argument] = None, call_kwargs: {str: Argument} = None) -> Optional[Node]: # todo call ast_parser and attributes_resolver to resolve attributes - return self._symbol_tree.add_cell(cell_type, field, construct_args, construct_kwargs, targets, target_type, - call_args, call_kwargs) + attribute = {} + return self._symbol_tree.add_cell(position, cell_type, field, construct_args, construct_kwargs, targets, + target_type, call_args, call_kwargs, attribute) + + def add_function(self, position, *args, **kwargs): + raise NotImplementedError # self.'field': 'custom_obj_type' = global_vars.get('field') # 'targets': 'target_type' = self.'field'(*'call_args', **'call_kwargs') - def add_custom_node(self, custom_obj: Cell, field: str, targets: [str], target_type: str, call_args: [Argument], - call_kwargs: {str: Argument}) -> Optional[Node]: - return self._symbol_tree.add_custom_node(custom_obj, field, targets, call_args, call_kwargs, target_type) - - # def update_output(self, return_values: [str]) -> Optional[Node]: - # return self._symbol_tree.update_output(return_values) + def add_object(self, position, custom_obj: Cell, field: str, targets: [str], target_type: str, + call_args: [Argument], call_kwargs: {str: Argument}) -> Optional[Node]: + return self._symbol_tree.add_object(position, custom_obj, field, targets, call_args, call_kwargs, + target_type) - def add_output(self, return_value: str, index: Optional[int] = None) -> Optional[Node]: - return self._symbol_tree.add_output(return_value, index) + def erase_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: + return self._symbol_tree.erase_node(node_or_name) def set_output(self, index: int, return_value: str) -> Optional[Node]: return self._symbol_tree.set_output(return_value, index) - def add_input(self, name: str, input_type: Optional[type] = None, default: Optional[str] = None) -> bool: - return self._symbol_tree.add_input_and_update_ast(name, input_type, default) - - def set_input(self, index: int, name: str, input_type: Optional[type] = None, default: Optional[str] = None) -> \ - bool: - raise NotImplementedError - def update_arg(self, node: Node, index: int, arg: str): node.update_arg(index, arg) diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index 15561038c01..e171344f331 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -55,8 +55,6 @@ class SymbolTree: self._construct_func_ast: ast.FunctionDef = construct_func_ast # need deep copy? # root must be output of graph self._root = self._add_output(["undefined"]) - self._insert_before: bool = True - self._insert_point: Node = self._root # head must be the first statement but must not be inputs of graph self._head = self._add_head_node() self._ori_cls_name = type(origin_network).__name__ @@ -77,8 +75,7 @@ class SymbolTree: logger.warning("%s is not found, ignored. Insert point keep unchanged.") else: raise RuntimeError("Unsupported node_or_name: ", node_or_name) - self._insert_before = True - self._insert_point = node + return True, node def insert_after(self, node_or_name: Union[Node, str]): if isinstance(node_or_name, Node): @@ -89,8 +86,7 @@ class SymbolTree: logger.warning("%s is not found, ignored. Insert point keep unchanged.") else: raise RuntimeError("Unsupported node_or_name: ", node_or_name) - self._insert_before = False - self._insert_point = node + return False, node # def create_node(self, node_type: NodeType, construct_fn: str, targets: [str], field: Optional[str], # construct_args: [], construct_kwargs: {str: object}, call_args: [], call_kwargs: {str: object}, @@ -133,7 +129,7 @@ class SymbolTree: break return node - def add_cell(self, cell_type: type, field: str = None, construct_args: [Argument] = None, + def add_cell(self, position, cell_type: type, field: str = None, construct_args: [Argument] = None, construct_kwargs: {str: Argument} = None, targets: [str] = None, target_type: str = "", call_args: [Argument] = None, call_kwargs: {str: Argument} = None, attribute: {str, object}=None) -> Optional[Node]: @@ -154,16 +150,16 @@ class SymbolTree: logger.error("insert cell into init function ast tree failed.") return None # modify construct function - index = self._find_node_index(self._insert_point) + index = self._find_node_index(position[1]) assert index is not None call_args, new_arg_nodes = self._handle_custom_object_in_args(self._construct_func_ast, call_args) for node in new_arg_nodes: - if not self._insert_node(node): + if not self._insert_node(position, node): logger.error("insert custom object node into symbol_tree failed.") return None call_kwargs, new_kwarg_nodes = self._handle_custom_object_in_kwargs(self._construct_func_ast, call_kwargs) for node in new_kwarg_nodes: - if not self._insert_node(node): + if not self._insert_node(position, node): logger.error("insert custom object node into symbol_tree failed.") return None if targets is None: @@ -171,19 +167,18 @@ class SymbolTree: targets = SymbolTree._convert_targets(targets) ast_node = AstModifier.insert_assign_to_function(self._construct_func_ast, targets, Argument(ArgType.NamingArg, "self", field), - call_args, call_kwargs, self._insert_point.get_ast(), - self._insert_before) + call_args, call_kwargs, position[1].get_ast(), position[0]) if ast_node is None: logger.error("insert cell into construct function ast tree failed.") return None # create and insert node node = Node(NodeType.CallCell, targets, call_args, call_kwargs, ast_node, attribute) - if not self._insert_node(node): + if not self._insert_node(position, node): return None return node - def add_custom_node(self, custom_obj: Cell, field: str, targets: [str], call_args: [], call_kwargs: {str: object}, - target_type: str = "") -> Optional[Node]: + def add_object(self, position, custom_obj: Cell, field: str, targets: [str], call_args: [], + call_kwargs: {str: object}, target_type: str = "") -> Optional[Node]: if field is None: field = self._generate_new_field_name() # modify init function @@ -195,16 +190,16 @@ class SymbolTree: logger.error("insert custom_node into init function ast tree failed.") return None # modify construct function - index = self._find_node_index(self._insert_point) + index = self._find_node_index(position[1]) assert index is not None call_args, new_arg_nodes = self._handle_custom_object_in_args(self._init_func_ast, call_args) for node in new_arg_nodes: - if not self._insert_node(node): + if not self._insert_node(position, node): logger.error("insert custom object node into symbol_tree failed.") return None call_kwargs, new_kwarg_nodes = self._handle_custom_object_in_kwargs(self._init_func_ast, call_kwargs) for node in new_kwarg_nodes: - if not self._insert_node(node): + if not self._insert_node(position, node): logger.error("insert custom object node into symbol_tree failed.") return None if targets is None: @@ -213,15 +208,14 @@ class SymbolTree: ast_node = AstModifier.insert_assign_to_function(self._construct_func_ast, targets=targets, expr=Argument(ArgType.NamingArg, "self", field), args=call_args, kwargs=call_kwargs, - index_ast=self._insert_point.get_ast(), - insert_before=self._insert_before) + index_ast=position[1].get_ast(), insert_before=position[0]) if ast_node is None: logger.error("insert custom_node into construct function ast tree failed.") return None # create and insert node self._global_vars[field] = custom_obj node = Node(NodeType.UserCustom, targets, call_args, call_kwargs, ast_node) - if not self._insert_node(node): + if not self._insert_node(position, node): return None return node @@ -246,7 +240,7 @@ class SymbolTree: self._inputs.append(Input(name, input_type, default)) return True - def add_origin_field(self, handler_name: str, handler_cls: type, targets: [str], args: [Argument], + def add_origin_field(self, position, handler_name: str, handler_cls: type, targets: [str], args: [Argument], kwargs: {str: Argument}, ast_node: ast.AST, target_type: str = "", attribute: {str, object}=None) -> Optional[Node]: if targets is None: @@ -254,7 +248,7 @@ class SymbolTree: targets = SymbolTree._convert_targets(targets) assert len(handler_name) > 0 node = Node(NodeType.UserCustom, targets, args, kwargs, ast_node, attribute) - if not self._insert_node(node, False): + if not self._insert_node(position, node, False): return None return node @@ -350,8 +344,6 @@ class SymbolTree: for i in range(0, len(self._construct_func_ast.body)): body = self._construct_func_ast.body[i] if node.has_same_ast(body): - if self._insert_before: - i += 1 assert i <= len(self._construct_func_ast.body) return i return None @@ -364,13 +356,13 @@ class SymbolTree: self._nodes[node_name] = node return True - def _insert_node(self, node: Node, insert_to_ast: bool = True) -> bool: + def _insert_node(self, position, node: Node, insert_to_ast: bool = True) -> bool: if not self._append_node2nodes(node): return False - if self._insert_before: - self._insert_point.insert_before(node) + if position[0]: + position[1].insert_before(node) else: - self._insert_point.insert_after(node) + position[1].insert_after(node) # if not insert_to_ast: # return True # index = self._find_node_index(self._insert_point) @@ -457,7 +449,7 @@ class SymbolTree: def _add_head_node(self) -> Optional[Node]: node = Node(NodeType.Unknown, None, None, None) - if not self._insert_node(node, False): + if not self._insert_node((True, self._root), node, False): return None return node -- Gitee From b8a51e1039059e84c3290a7c602db338dcc46fc5 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Thu, 10 Feb 2022 11:28:45 +0800 Subject: [PATCH 03/32] merge add_cell and add_object interfaces --- .../mindspore/rewrite_experiment/__init__.py | 6 +- .../rewrite_experiment/ast_parse/resolver.py | 6 +- .../rewrite_experiment/common/__init__.py | 3 +- .../rewrite_experiment/common/node_info.py | 77 +++++++++++++++++ .../rewrite_experiment/example/test_lenet.py | 20 +++-- .../mindspore/rewrite_experiment/node.py | 84 ++++++++----------- .../rewrite_experiment/pattern_engine.py | 35 ++++---- .../mindspore/rewrite_experiment/rewrite.py | 69 ++++++++++----- .../rewrite_experiment/symbol_tree.py | 81 ++++++++++++++++-- 9 files changed, 275 insertions(+), 106 deletions(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/common/node_info.py diff --git a/mindspore/python/mindspore/rewrite_experiment/__init__.py b/mindspore/python/mindspore/rewrite_experiment/__init__.py index 4ad2dff4336..50f567508b5 100644 --- a/mindspore/python/mindspore/rewrite_experiment/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/__init__.py @@ -1,6 +1,8 @@ from .symbol_tree import SymbolTree from .node import Node from .rewrite import Rewrite -from .common import Argument, ArgType +from .common import Argument, ArgType, NodeType, NodeInfo +from .pattern_engine import PatternEngine, PatternNode, VarNode -__all__ = ["SymbolTree", "Rewrite", "Node", "ArgType", "Argument"] +__all__ = ["SymbolTree", "Rewrite", "Node", "ArgType", "Argument", "PatternEngine", "PatternNode", "VarNode", + "NodeType", "NodeInfo"] diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py index c1d5dba4aea..cc136b59e80 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py @@ -135,7 +135,7 @@ class Resolver: if body is None: continue if isinstance(body, AssignSymbol): - target = body.get_assign_targets()[0].value + target = Argument.create_naming_arg(body.get_assign_targets()[0].value) call = body.get_assign_value() assert isinstance(call, CallSymbol) func = Resolver._get_real_func(call.get_call_func().value) @@ -148,10 +148,10 @@ class Resolver: raise RuntimeError("kwargs in construct function assign is unsupported") cell_type = Resolver._get_cell_type(func) # todo call attribute_resolver to resolve attributes - ret = symbol_tree.add_origin_field((True, symbol_tree.get_root()), func, cell_type, [target], call_args, + ret = symbol_tree.add_origin_field((True, symbol_tree.get_root()), func, None, [target], call_args, {}, body.symbol_ast) if ret is None: - raise RuntimeError("add_origin_field failed") + raise RuntimeError("add_origin_field failed: ", ) elif body.symbol_type() == SymbolType.return_: symbol_tree.set_output_ast(body.symbol_ast) ret = symbol_tree.update_output([body.get_return_value().value]) diff --git a/mindspore/python/mindspore/rewrite_experiment/common/__init__.py b/mindspore/python/mindspore/rewrite_experiment/common/__init__.py index e451cd0823a..9e44a294e27 100644 --- a/mindspore/python/mindspore/rewrite_experiment/common/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/common/__init__.py @@ -1,4 +1,5 @@ from .ast_modifier import AstModifier from .argument import Argument, ArgType +from .node_info import NodeInfo, NodeType -__all__ = ["AstModifier", "Argument", "ArgType"] +__all__ = ["AstModifier", "Argument", "ArgType", "NodeInfo", "NodeType"] diff --git a/mindspore/python/mindspore/rewrite_experiment/common/node_info.py b/mindspore/python/mindspore/rewrite_experiment/common/node_info.py new file mode 100644 index 00000000000..54e2cf657f8 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/common/node_info.py @@ -0,0 +1,77 @@ +from enum import Enum + +from typing import Optional +from .argument import Argument, ArgType + + +class NodeType(Enum): + Unknown = 0 + # buildin? + CallCell = 1 # call cell object + CallMethod = 2 # method in cell + CallFunction = 3 # subclass of primitive + + UserCustom = 4 + Input = 5 + Output = 6 + Graph = 7 + + +class NodeInfo: + def __init__(self, node_type: NodeType, field: str = None, field_type: Optional[type] = None, func: Argument = None, + construct_args: [Argument] = None, construct_kwargs: {str: Argument} = None, targets: [str] = None, + target_type: Optional[type] = None, call_args: [Argument] = None, call_kwargs: {str: Argument} = None, + extra: {} = None): + self.node_type = node_type + self.field: str = field + self.field_type: Optional[type] = field_type + assert func is not None + self.func = func + if construct_args is None: + self.construct_args: [Argument] = [] + else: + self.construct_args: [Argument] = construct_args + if construct_kwargs is None: + self.construct_kwargs: {str: Argument} = {} + else: + self.construct_kwargs: {str: Argument} = construct_kwargs + if targets is None: + self.targets: [str] = None + else: + self.targets: [str] = targets + self.target_type: Optional[type] = target_type + if call_args is None: + self.call_args: [Argument] = [] + else: + self.call_args: [Argument] = call_args + if call_kwargs is None: + self.call_kwargs: {str: Argument} = {} + else: + self.call_kwargs: {str: Argument} = call_kwargs + if extra is None: + self.extra = {} + else: + self.extra = extra + + @classmethod + def create_cell_info(cls, field: str = None, cell_type: Optional[type] = None, construct_args: [Argument] = None, + construct_kwargs: {str: Argument} = None, targets: [str] = None, + target_type: Optional[type] = None, call_args: [Argument] = None, + call_kwargs: {str: Argument} = None) -> 'NodeInfo': + class_name = cell_type.__name__ + return cls(NodeType.CallCell, field, cell_type, Argument(ArgType.NamingArg, "mindspore.nn", class_name), + construct_args, construct_kwargs, targets, target_type, call_args, call_kwargs) + + @classmethod + def create_function_info(cls, field: str = None, func_type: Optional[type] = None, targets: [str] = None, + target_type: Optional[type] = None, call_args: [Argument] = None, + call_kwargs: {str: Argument} = None) -> 'NodeInfo': + return cls(NodeType.CallFunction, field, func_type, None, [], {}, targets, target_type, call_args, call_kwargs) + + @classmethod + def create_object_info(cls, object, field: str = None, obj_type: Optional[type] = None, targets: [str] = None, + target_type: Optional[type] = None, call_args: [Argument] = None, + call_kwargs: {str: Argument} = None) -> 'NodeInfo': + return cls(NodeType.UserCustom, field, obj_type, Argument(ArgType.NamingArg, "global_vars", "get"), + [Argument(ArgType.StringArg, "", field)], {}, targets, target_type, call_args, call_kwargs, + {"object": object}) diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py index 79f07d09ce4..87ae09dc257 100644 --- a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py @@ -26,12 +26,10 @@ def transform(rw: Rewrite): target = str(targets[0]) if target == "self_flatten": position = rw.before(node) - custom_in_channel = 16 - construct_args=[Argument.create_custom_arg(custom_in_channel), Argument.create_imm_arg(16), - Argument.create_imm_arg(3)] - ret = rw.add_cell(position=position, cell_type=nn.Conv2d, field='conv_new', construct_args=construct_args, - construct_kwargs={}, targets=['nx1'], target_type="", - call_args=[Argument.create_naming_arg('x')], call_kwargs={}) + new_conv = nn.Conv2d(16, 16, 3) + new_conv_node = Rewrite.create_node(new_conv, targets=[Argument.create_naming_arg('nx1')], target_type="", + name='new_conv') + ret = rw.insert(position, new_conv_node, field='conv_new', args=[Argument.create_naming_arg('x')]) if ret is None: raise RuntimeError("add_cell failed") break @@ -45,8 +43,10 @@ def transform(rw: Rewrite): if target == "self_relu_3": position = rw.before(node) custom_cell = MyCell() - ret = rw.add_object(position=position, custom_obj=custom_cell, field='my_cell', targets=['nx2'], - target_type="", call_args=[Argument.create_naming_arg('nx3')], call_kwargs={}) + new_custom_node = Rewrite.create_node(custom_cell, targets=[Argument.create_naming_arg('nx2')], + target_type="", args=[Argument.create_naming_arg('nx3')], + name='my_cell') + ret = rw.insert(position, new_custom_node, field='my_cell') if ret is None: raise RuntimeError("add_custom_node failed") break @@ -58,7 +58,9 @@ def transform(rw: Rewrite): assert targets[0].type == ArgType.NamingArg target = str(targets[0]) if target == "nx2": - rw.update_arg(node, 0, "nx1") + ret = rw.update_arg(node, 0, "nx1") + if not ret: + raise RuntimeError("Update arg failed") break rewrite.dump("after update_arg") rw.set_output(0, "nx2") diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index a5b95c0ab3f..acd2d53b077 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -13,24 +13,11 @@ # limitations under the License. # ============================================================================ import ast -from enum import Enum -from typing import Optional, Union - -from .common import Argument, ArgType, AstModifier - - -class NodeType(Enum): - Unknown = 0 - # buildin? - CallCell = 1 # call cell object - CallMethod = 2 # method in cell - CallFunction = 3 # subclass of primitive - UserCustom = 4 - Input = 5 - Output = 6 - Graph = 7 +from typing import Optional, Union +from .common import Argument, ArgType, AstModifier, NodeType, NodeInfo +from ..nn import Cell global_vars_name = "global_vars" origin_network_field_name = "_handler" @@ -38,19 +25,21 @@ origin_network_key = "handler" class Node: - def __init__(self, node_type: NodeType, targets: [Argument], args: [Argument], kwargs: [Argument], - ast_node: Optional[ast.AST] = None, attribute=None, name: str = ""): + def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], attributes: {str: object}, targets: [Argument], + args: [Argument], kwargs: {str: Argument}, name: str, op): if node_type not in {NodeType.CallCell, NodeType.Output, NodeType.UserCustom, NodeType.CallMethod, NodeType.Unknown}: raise RuntimeError("Only support CallCell, UserCustom, CallMethod and Output now") + assert attributes is not None self._node_type: NodeType = node_type - if name == "": - self._name = Node._generate_node_name(targets, args, kwargs) - else: - self._name = name + self._ast_node: Optional[ast.AST] = ast_node + self._attribute: {str, object} = attributes + self._op = op + self._op_type: type = type(op) + self._name = name self._targets: [Argument] = targets self._args: [Argument] = args - self._kwargs: [Argument] = kwargs + self._kwargs: {str: Argument} = kwargs # edge of node self._inputs: [Node] = [] # position in graph nodes list @@ -58,24 +47,16 @@ class Node: self._prev: Optional[Node] = None self._next: Optional[Node] = None self._update_inputs() - self._ast_node: Optional[ast.AST] = ast_node - if attribute is None: - self._attribute: {str, object} = attribute - else: - self._attribute: {str, object} = {} - - @staticmethod - def _generate_node_name(targets: [Argument], args: [Argument], kwargs: [Argument]) ->str: - if targets is None: - if args is None and kwargs is None: - return "head-node" - else: - return "return-node" - else: - if len(targets) == 0: - return "illegal-node" - else: - return str(targets[0].name) + + @classmethod + def create_by_cell(cls, cell: Cell, ast_node: Optional[ast.AST], attributes: {str: object}, targets: [Argument], + target_type: str = "", args: [Argument] = None, kwargs: {str: Argument} = None, name: str = ""): + if args is None: + args = [] + if kwargs is None: + kwargs = {} + assert attributes is not None + return cls(NodeType.CallCell, ast_node, attributes, targets, args, kwargs, name, cell) def get_prev(self) -> 'Node': return self._prev @@ -197,14 +178,23 @@ class Node: def get_targets(self) -> [Argument]: return self._targets + def set_targets(self, targets: [Argument]): + self._targets = targets + def get_name(self) -> str: return self._name + def set_name(self, name: str): + self._name = name + def get_node_type(self) -> NodeType: return self._node_type - # def get_type(self) -> type: - # return self._stmt.cls + def get_op_type(self) -> type: + return self._op_type + + def get_op(self): + return self._op def get_args(self) -> [Argument]: return self._args @@ -223,14 +213,14 @@ class Node: assert index <= len(self._args) self._args.insert(index, arg) - def get_kwargs(self) -> [Argument]: + def get_kwargs(self) -> {str: Argument}: return self._kwargs + def set_kwargs(self, kwargs: {str: Argument}): + self._kwargs = kwargs + def set_attribute(self, key: str, value): self._attribute[key] = value def get_attribute(self, key: str): return self._attribute.get(key) - - def get_code(self): - return self._stmt.get_code() diff --git a/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py b/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py index f87cb3d5205..65389128b13 100644 --- a/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py +++ b/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py @@ -16,9 +16,10 @@ from typing import Tuple, Union, List, Type, Optional from collections import OrderedDict -from .symbol_tree import SymbolTree -from .node import Node, NodeType -from mindspore.nn.cell import Cell +from .rewrite import Rewrite +from .node import Node +from .common import NodeType, NodeInfo +from mindspore.nn import Cell from mindspore import log as logger @@ -146,7 +147,7 @@ class PatternNode: """ Getter of type. """ - return self._type + return self._typelllllll class VarNode(PatternNode): @@ -161,6 +162,12 @@ class VarNode(PatternNode): return node is not None +# support return multi-node +class Replacement: + def __call__(self, matched: OrderedDict) -> [Node]: + raise NotImplementedError + + class PatternEngine: """ PatternEngine is define how to transform a graph by PattenNode. @@ -189,7 +196,7 @@ class PatternEngine: return self._pattern - def apply(self, graph: SymbolTree) -> bool: + def apply(self, rewrite: Rewrite) -> bool: """ Apply current pattern to a graph. @@ -200,7 +207,7 @@ class PatternEngine: If graph been changed. """ - root: Node = graph.get_root() + root: Node = rewrite.get_root() changed = False # IR match queue: [Node] = [root] @@ -220,13 +227,13 @@ class PatternEngine: if new_node is None: # return None to remove changed = True for key in matched_dict: - graph.erase_node(matched_dict[key]) + rewrite.erase_node(matched_dict[key]) elif new_node == cur_node: # return origin Node for do nothing pass else: # return Node to insert or replace (new Node no need to set inputs and outputs) # todo if we need to support _process_chain or _process_tree return multi-node changed = True - graph.replace_node(matched_list, new_node) + rewrite.replace_node(matched_list, new_node) node_inputs = new_node.get_inputs() for node_input in node_inputs: queue.append(node_input) @@ -272,7 +279,7 @@ class PatternEngine: New node created from matched result. """ - pass + raise NotImplementedError @staticmethod def _merge_ordered_dict(dict1: OrderedDict, dict2: OrderedDict) -> OrderedDict: @@ -305,12 +312,12 @@ class PatternEngine: """ # todo: Recurse into subgraph node. Depend on subgraph node definition - if node.get_node_type() != NodeType.call_cell: - logger.debug("Pattern match failed: node(%s) is not a cell", node.get_field()) + if node.get_node_type() != NodeType.CallCell: + logger.debug("Pattern match failed: node(%s) is not a cell", str(node)) return False, OrderedDict() if not pattern.match(node): - logger.debug("Pattern match failed: node(%s)'s type is %s while pattern type is %s", node.get_field(), - node.get_cell_type(), pattern.type()) + logger.debug("Pattern match failed: node(%s)'s type is %s while pattern type is %s", str(node), + node.get_op_type(), pattern.type()) return False, OrderedDict() if isinstance(pattern, VarNode): return True, OrderedDict() @@ -320,7 +327,7 @@ class PatternEngine: if input_num == 0: return True, OrderedDict({pattern.name(): node}) if input_num != len(cur_inputs): - logger.debug("Pattern match failed: node(%s)'s has %d inputs while pattern has %d inputs", node.get_field(), + logger.debug("Pattern match failed: node(%s)'s has %d inputs while pattern has %d inputs", str(node), len(node.get_inputs()), input_num) return False, OrderedDict() result = OrderedDict() diff --git a/mindspore/python/mindspore/rewrite_experiment/rewrite.py b/mindspore/python/mindspore/rewrite_experiment/rewrite.py index 5ae8f701283..b7f96233324 100644 --- a/mindspore/python/mindspore/rewrite_experiment/rewrite.py +++ b/mindspore/python/mindspore/rewrite_experiment/rewrite.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import ast + from typing import Union, Optional from .node import Node from .ast_parse import AstParse from mindspore.nn import Cell -from .common.argument import Argument +from mindspore.ops import Primitive +from .common import Argument class Rewrite: @@ -34,25 +37,47 @@ class Rewrite: def after(self, node_or_name: Union[Node, str]): return self._symbol_tree.insert_after(node_or_name) - # self.'field': 'cell_type' = mindspore.nn.'cell_type'(*'construct_args', **'construct_kwargs') - # 'targets': 'target_type' = self.'field'(*'call_args', **'call_kwargs') - def add_cell(self, position, cell_type: type, field: str = None, construct_args: [Argument] = None, - construct_kwargs: {str: Argument} = None, targets: [str] = None, target_type: str = "", - call_args: [Argument] = None, call_kwargs: {str: Argument} = None) -> Optional[Node]: - # todo call ast_parser and attributes_resolver to resolve attributes - attribute = {} - return self._symbol_tree.add_cell(position, cell_type, field, construct_args, construct_kwargs, targets, - target_type, call_args, call_kwargs, attribute) - - def add_function(self, position, *args, **kwargs): - raise NotImplementedError - - # self.'field': 'custom_obj_type' = global_vars.get('field') - # 'targets': 'target_type' = self.'field'(*'call_args', **'call_kwargs') - def add_object(self, position, custom_obj: Cell, field: str, targets: [str], target_type: str, - call_args: [Argument], call_kwargs: {str: Argument}) -> Optional[Node]: - return self._symbol_tree.add_object(position, custom_obj, field, targets, call_args, call_kwargs, - target_type) + @staticmethod + def create_node(op, targets: [Argument], target_type: str = None, args: [Argument] = None, + kwargs: {str: Argument} = None, name: str = None) -> Node: + if isinstance(op, Cell): + # todo create ast from op: ast can be created when insert into symbol_tree + ast_node = None + # todo resolve attributes from op + attributes = {} + return Node.create_by_cell(op, ast_node, attributes, targets, target_type, args, kwargs, name) + elif isinstance(op, Primitive): + raise RuntimeError("Primitive op will be support in near future.") + else: + raise RuntimeError("Only support Cell op or Primitive op!") + + def insert(self, position, node: Node, field: str = None, args: [Argument] = None, kwargs: {str: Argument} = None) \ + -> Optional[Node]: + if args is not None: + node.set_args(args) + if kwargs is not None: + node.set_kwargs(kwargs) + return self._symbol_tree.insert_node(position, node, field) + + # # self.'field': 'cell_type' = mindspore.nn.'cell_type'(*'construct_args', **'construct_kwargs') + # # 'targets': 'target_type' = self.'field'(*'call_args', **'call_kwargs') + # def add_cell(self, position, cell_type: type, field: str = None, construct_args: [Argument] = None, + # construct_kwargs: {str: Argument} = None, targets: [str] = None, target_type: str = "", + # call_args: [Argument] = None, call_kwargs: {str: Argument} = None) -> Optional[Node]: + # # todo call ast_parser and attributes_resolver to resolve attributes + # attribute = {} + # return self._symbol_tree.add_cell(position, cell_type, field, construct_args, construct_kwargs, targets, + # target_type, call_args, call_kwargs, attribute) + # + # def add_function(self, position, *args, **kwargs): + # raise NotImplementedError + # + # # self.'field': 'custom_obj_type' = global_vars.get('field') + # # 'targets': 'target_type' = self.'field'(*'call_args', **'call_kwargs') + # def add_object(self, position, custom_obj: Cell, field: str, targets: [str], target_type: str, + # call_args: [Argument], call_kwargs: {str: Argument}) -> Optional[Node]: + # return self._symbol_tree.add_object(position, custom_obj, field, targets, call_args, call_kwargs, + # target_type) def erase_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: return self._symbol_tree.erase_node(node_or_name) @@ -60,8 +85,8 @@ class Rewrite: def set_output(self, index: int, return_value: str) -> Optional[Node]: return self._symbol_tree.set_output(return_value, index) - def update_arg(self, node: Node, index: int, arg: str): - node.update_arg(index, arg) + def update_arg(self, node: Node, index: int, arg: str) -> bool: + return node.update_arg(index, arg) def update_arg_by_node(self, node_to_update: Node, arg_idx: int, node_to_link: 'Node', out_idx: Optional[int] = None): diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index e171344f331..4413801d430 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -21,10 +21,10 @@ from typing import Optional, Union, Tuple import astpretty import astunparse -from .node import Node, NodeType, global_vars_name +from .node import Node, global_vars_name from mindspore.nn import Cell from mindspore import log as logger -from .common import AstModifier, Argument, ArgType +from .common import AstModifier, Argument, ArgType, NodeType class Input: @@ -110,6 +110,19 @@ class SymbolTree: # return None # return node + @staticmethod + def _generate_node_name(targets: [Argument], args: [Argument], kwargs: {str: Argument}) ->str: + if targets is None: + if args is None and kwargs is None: + return "head-node" + else: + return "return-node" + else: + if len(targets) == 0: + return "illegal-node" + else: + return str(targets[0].name) + # can only erase isolated node def erase_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: if isinstance(node_or_name, Node): @@ -129,6 +142,55 @@ class SymbolTree: break return node + def insert_node(self, position, node: Node, field: str = None) -> Optional[Node]: + if field is None: + field = self._generate_new_field_name() + # modify init function + ast_node = AstModifier.insert_assign_to_function(self._init_func_ast, + targets=[Argument(ArgType.NamingArg, "self", field)], + expr=Argument(ArgType.NamingArg, "global_vars", "get"), + args=[Argument(ArgType.StringArg, "", field)]) + if ast_node is None: + logger.error("insert custom_node into init function ast tree failed.") + return None + # check position + index = self._find_node_index(position[1]) + assert index is not None + # process node name + node.set_name(self._generate_node_name(node.get_targets(), node.get_args(), node.get_kwargs())) + # process args of node + call_args, new_arg_nodes = self._handle_custom_object_in_args(self._init_func_ast, node.get_args()) + for node in new_arg_nodes: + if not self._insert_node(position, node): + logger.error("insert custom object node into symbol_tree failed.") + return None + node.set_args(call_args) + # process kwargs of node + call_kwargs, new_kwarg_nodes = self._handle_custom_object_in_kwargs(self._init_func_ast, node.get_kwargs()) + for node in new_kwarg_nodes: + if not self._insert_node(position, node): + logger.error("insert custom object node into symbol_tree failed.") + return None + node.set_kwargs(call_kwargs) + # process targets of node + targets = node.get_targets() + targets = SymbolTree._convert_targets(targets) + node.set_targets(targets) + # modify construct function + ast_node = AstModifier.insert_assign_to_function(self._construct_func_ast, targets=targets, + expr=Argument(ArgType.NamingArg, "self", field), + args=call_args, kwargs=call_kwargs, + index_ast=position[1].get_ast(), insert_before=position[0]) + if ast_node is None: + logger.error("insert custom_node into construct function ast tree failed.") + return None + node.set_ast(ast_node) + # create and insert node + self._global_vars[field] = node.get_op() + if not self._insert_node(position, node): + return None + return node + def add_cell(self, position, cell_type: type, field: str = None, construct_args: [Argument] = None, construct_kwargs: {str: Argument} = None, targets: [str] = None, target_type: str = "", call_args: [Argument] = None, call_kwargs: {str: Argument} = None, @@ -172,7 +234,7 @@ class SymbolTree: logger.error("insert cell into construct function ast tree failed.") return None # create and insert node - node = Node(NodeType.CallCell, targets, call_args, call_kwargs, ast_node, attribute) + node = Node(NodeType.CallCell, targets, call_args, call_kwargs, ast_node, attribute, "", cell_type) if not self._insert_node(position, node): return None return node @@ -214,7 +276,7 @@ class SymbolTree: return None # create and insert node self._global_vars[field] = custom_obj - node = Node(NodeType.UserCustom, targets, call_args, call_kwargs, ast_node) + node = Node(NodeType.UserCustom, targets, call_args, call_kwargs, ast_node, {}, "", type(custom_obj)) if not self._insert_node(position, node): return None return node @@ -240,14 +302,17 @@ class SymbolTree: self._inputs.append(Input(name, input_type, default)) return True - def add_origin_field(self, position, handler_name: str, handler_cls: type, targets: [str], args: [Argument], + def add_origin_field(self, position, handler_name: str, op, targets: [Argument], args: [Argument], kwargs: {str: Argument}, ast_node: ast.AST, target_type: str = "", attribute: {str, object}=None) -> Optional[Node]: if targets is None: targets = [self._generate_new_target_name()] targets = SymbolTree._convert_targets(targets) assert len(handler_name) > 0 - node = Node(NodeType.UserCustom, targets, args, kwargs, ast_node, attribute) + if attribute is None: + attribute = {} + node_name = SymbolTree._generate_node_name(targets, args, kwargs) + node = Node(NodeType.UserCustom, ast_node, attribute, targets, args, kwargs, node_name, op) if not self._insert_node(position, node, False): return None return node @@ -443,12 +508,12 @@ class SymbolTree: def _add_output(self, return_values: [str]) -> Optional[Node]: real_return_values = self._convert_targets(return_values) - node = Node(NodeType.Output, None, real_return_values, None) + node = Node(NodeType.Output, None, {}, real_return_values, [], {}, "output", None) self._append_node2nodes(node) return node def _add_head_node(self) -> Optional[Node]: - node = Node(NodeType.Unknown, None, None, None) + node = Node(NodeType.Unknown, None, {}, None, [], {}, "head", None) if not self._insert_node((True, self._root), node, False): return None return node -- Gitee From 4c421bd515367f994de471a6b1a55f843605ced3 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Thu, 10 Feb 2022 17:12:14 +0800 Subject: [PATCH 04/32] add AstManager --- .../mindspore/rewrite_experiment/__init__.py | 6 ++--- .../ast_parse/symbol_table.py | 17 ++++-------- .../rewrite_experiment/common/__init__.py | 4 +-- .../rewrite_experiment/common/ast_manager.py | 19 ++++++++++++++ .../rewrite_experiment/common/node_info.py | 13 ---------- .../mindspore/rewrite_experiment/node.py | 16 +++++++++++- .../rewrite_experiment/pattern_engine.py | 3 +-- .../mindspore/rewrite_experiment/rewrite.py | 7 ++++- .../rewrite_experiment/symbol_tree.py | 26 ++----------------- 9 files changed, 53 insertions(+), 58 deletions(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/common/ast_manager.py diff --git a/mindspore/python/mindspore/rewrite_experiment/__init__.py b/mindspore/python/mindspore/rewrite_experiment/__init__.py index 50f567508b5..6eb78de2875 100644 --- a/mindspore/python/mindspore/rewrite_experiment/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/__init__.py @@ -1,8 +1,8 @@ from .symbol_tree import SymbolTree -from .node import Node +from .node import Node, NodeType from .rewrite import Rewrite -from .common import Argument, ArgType, NodeType, NodeInfo +from .common import Argument, ArgType from .pattern_engine import PatternEngine, PatternNode, VarNode __all__ = ["SymbolTree", "Rewrite", "Node", "ArgType", "Argument", "PatternEngine", "PatternNode", "VarNode", - "NodeType", "NodeInfo"] + "NodeType"] diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_table.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_table.py index 4263c145824..72066655a7f 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_table.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_table.py @@ -15,7 +15,6 @@ from types import FunctionType from typing import Union, Optional -import inspect import ast import astpretty @@ -23,21 +22,15 @@ import mindspore.nn as nn from mindspore.ops.primitive import Primitive from .symbol import Symbol, SymbolType from .value_node import ValueNode +from ..common import AstManager class SymbolTable: def __init__(self, network: Optional[Union[nn.Cell, Primitive, FunctionType]] = None): - if network is None: - self._table: {str: ValueNode} = {} - return - if not isinstance(network, nn.Cell): - raise RuntimeError("Only support network with Cell type now") - - network = network - net_cls = type(network) - network_str = inspect.getsource(net_cls) - self._ast_root: ast.AST = ast.parse(network_str) - name = net_cls.__name__ + assert network is not None + self._ast_root: ast.AST = AstManager.instance().get_ast() + assert self._ast_root is not None + name = type(network).__name__ root_symbol = Symbol(self._ast_root, "", "Module(" + name + ")", SymbolType.module) self._table: {str: ValueNode} = {root_symbol.get_full_name_with_scope(): ValueNode(root_symbol)} diff --git a/mindspore/python/mindspore/rewrite_experiment/common/__init__.py b/mindspore/python/mindspore/rewrite_experiment/common/__init__.py index 9e44a294e27..a67403c8621 100644 --- a/mindspore/python/mindspore/rewrite_experiment/common/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/common/__init__.py @@ -1,5 +1,5 @@ from .ast_modifier import AstModifier from .argument import Argument, ArgType -from .node_info import NodeInfo, NodeType +from .ast_manager import AstManager -__all__ = ["AstModifier", "Argument", "ArgType", "NodeInfo", "NodeType"] +__all__ = ["AstModifier", "Argument", "ArgType", "AstManager"] diff --git a/mindspore/python/mindspore/rewrite_experiment/common/ast_manager.py b/mindspore/python/mindspore/rewrite_experiment/common/ast_manager.py new file mode 100644 index 00000000000..a97ea46920c --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/common/ast_manager.py @@ -0,0 +1,19 @@ +import ast +from typing import Optional + + +class AstManager: + def __init__(self): + self._ast_root: Optional[ast.AST] = None + + @classmethod + def instance(cls): + if not hasattr(AstManager, "_instance"): + AstManager._instance = AstManager() + return AstManager._instance + + def update_ast(self, ast_root: ast.AST): + self._ast_root = ast_root + + def get_ast(self): + return self._ast_root diff --git a/mindspore/python/mindspore/rewrite_experiment/common/node_info.py b/mindspore/python/mindspore/rewrite_experiment/common/node_info.py index 54e2cf657f8..77ba72fc2e1 100644 --- a/mindspore/python/mindspore/rewrite_experiment/common/node_info.py +++ b/mindspore/python/mindspore/rewrite_experiment/common/node_info.py @@ -4,19 +4,6 @@ from typing import Optional from .argument import Argument, ArgType -class NodeType(Enum): - Unknown = 0 - # buildin? - CallCell = 1 # call cell object - CallMethod = 2 # method in cell - CallFunction = 3 # subclass of primitive - - UserCustom = 4 - Input = 5 - Output = 6 - Graph = 7 - - class NodeInfo: def __init__(self, node_type: NodeType, field: str = None, field_type: Optional[type] = None, func: Argument = None, construct_args: [Argument] = None, construct_kwargs: {str: Argument} = None, targets: [str] = None, diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index acd2d53b077..931fbf3440d 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -13,10 +13,11 @@ # limitations under the License. # ============================================================================ import ast +from enum import Enum from typing import Optional, Union -from .common import Argument, ArgType, AstModifier, NodeType, NodeInfo +from .common import Argument, ArgType, AstModifier from ..nn import Cell global_vars_name = "global_vars" @@ -24,6 +25,19 @@ origin_network_field_name = "_handler" origin_network_key = "handler" +class NodeType(Enum): + Unknown = 0 + # buildin? + CallCell = 1 # call cell object + CallMethod = 2 # method in cell + CallFunction = 3 # subclass of primitive + + UserCustom = 4 + Input = 5 + Output = 6 + Graph = 7 + + class Node: def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], attributes: {str: object}, targets: [Argument], args: [Argument], kwargs: {str: Argument}, name: str, op): diff --git a/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py b/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py index 65389128b13..b888226b379 100644 --- a/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py +++ b/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py @@ -17,8 +17,7 @@ from typing import Tuple, Union, List, Type, Optional from collections import OrderedDict from .rewrite import Rewrite -from .node import Node -from .common import NodeType, NodeInfo +from .node import Node, NodeType from mindspore.nn import Cell from mindspore import log as logger diff --git a/mindspore/python/mindspore/rewrite_experiment/rewrite.py b/mindspore/python/mindspore/rewrite_experiment/rewrite.py index b7f96233324..168e7f9fce9 100644 --- a/mindspore/python/mindspore/rewrite_experiment/rewrite.py +++ b/mindspore/python/mindspore/rewrite_experiment/rewrite.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ import ast +import inspect from typing import Union, Optional @@ -20,12 +21,16 @@ from .node import Node from .ast_parse import AstParse from mindspore.nn import Cell from mindspore.ops import Primitive -from .common import Argument +from .common import Argument, AstManager class Rewrite: def __init__(self, network: Cell): + if not isinstance(network, Cell): + raise RuntimeError("Only support network with Cell type now") self._ori_net = network + network_str = inspect.getsource(type(network)) + AstManager.instance().update_ast(ast.parse(network_str)) self._symbol_tree, self._stb = AstParse.parse(self._ori_net) def nodes(self) -> {}: diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index 4413801d430..def3c0e3b1a 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -21,10 +21,10 @@ from typing import Optional, Union, Tuple import astpretty import astunparse -from .node import Node, global_vars_name +from .node import Node, global_vars_name, NodeType from mindspore.nn import Cell from mindspore import log as logger -from .common import AstModifier, Argument, ArgType, NodeType +from .common import AstModifier, Argument, ArgType class Input: @@ -88,28 +88,6 @@ class SymbolTree: raise RuntimeError("Unsupported node_or_name: ", node_or_name) return False, node - # def create_node(self, node_type: NodeType, construct_fn: str, targets: [str], field: Optional[str], - # construct_args: [], construct_kwargs: {str: object}, call_args: [], call_kwargs: {str: object}, - # target_type: str, cls, attribute=None) -> Optional[Node]: - # """ - # recommend to use add_cell, add_custom_node, add_placeholder and update_output instead - # """ - # - # construct_args = self._convert_args(construct_args) - # construct_kwargs = self._convert_kwargs(construct_kwargs) - # call_args = self._convert_args(call_args) - # call_kwargs = self._convert_kwargs(call_kwargs) - # if targets is None: - # targets = [self._generate_new_target_name()] - # if field is None: - # field = self._generate_new_field_name() - # stmt = Statement(construct_fn, targets, field, construct_args, construct_kwargs, call_args, call_kwargs, - # target_type, cls) - # node = Node(node_type, stmt, attribute) - # if not self._insert_node(node): - # return None - # return node - @staticmethod def _generate_node_name(targets: [Argument], args: [Argument], kwargs: {str: Argument}) ->str: if targets is None: -- Gitee From bf7791037500b985e594cf10a91ccd5ec9c73e92 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Thu, 10 Feb 2022 17:12:46 +0800 Subject: [PATCH 05/32] add ast_transformers --- .../rewrite_experiment/ast_parse/ast_parse.py | 11 ++++++++++- .../ast_parse/ast_transformers/__init__.py | 3 +++ .../ast_transformers/flatten_recursive_call.py | 7 +++++++ 3 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py create mode 100644 mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py index a0b823fea4b..3d5cce293f2 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py @@ -13,8 +13,9 @@ # limitations under the License. # ============================================================================ from typing import Union, Tuple - +import ast from .symbol_table import SymbolTable +from ..common import AstManager from .processor import Processor from .processor_manager import ParserManager, ProcessorManager from .parser_register import ParserRegister @@ -30,9 +31,16 @@ from .symbol_transformers.clear_non_symbol_in_stb import ClearNonSymbolInSTB from .symbol_transformers.rename_construct_function_assign_target import RenameConstructFuncAssignTarget from .resolver import Resolver from ..symbol_tree import SymbolTree +from .ast_transformers import FlattenRecursiveCall class AstParse: + @staticmethod + def _ast_transform(ast_root: ast.AST) -> ast.AST: + flatten_recursive_call = FlattenRecursiveCall() + ast_root = flatten_recursive_call.generic_visit(ast_root) + return ast_root + @staticmethod def _parse(network: Union[nn.Cell, Primitive, FunctionType]) -> SymbolTable: stb = SymbolTable(network) @@ -66,6 +74,7 @@ class AstParse: @staticmethod def parse(network: Union[nn.Cell, Primitive, FunctionType]) -> Tuple[SymbolTree, SymbolTable]: + AstManager.instance().update_ast(AstParse._ast_transform(AstManager.instance().get_ast())) # parse ast to symbols until all symbols are not compilable stb = AstParse._parse(network) logger.warning("---------------------- After parse: %d", len(stb.get_symbols().values())) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py new file mode 100644 index 00000000000..85dce8e13eb --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py @@ -0,0 +1,3 @@ +from .flatten_recursive_call import FlattenRecursiveCall + +__all__ = ["FlattenRecursiveCall"] diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py new file mode 100644 index 00000000000..ed2ffe1a30b --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py @@ -0,0 +1,7 @@ +import ast +from ast import FunctionDef +from typing import Any + + +class FlattenRecursiveCall(ast.NodeTransformer): + pass -- Gitee From 6fc29f80577512bad535f0917a9bb29ecb70636f Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 10 Feb 2022 17:29:31 +0800 Subject: [PATCH 06/32] add dump add dump --- .../mindspore/rewrite_experiment/node.py | 10 ++ .../rewrite_experiment/symbol_tree.py | 21 ++- .../rewrite_experiment/symbol_tree_dumper.py | 131 ++++++++++++++++++ 3 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/symbol_tree_dumper.py diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index 931fbf3440d..e7cf5427494 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -51,6 +51,7 @@ class Node: self._op = op self._op_type: type = type(op) self._name = name + self._field = None self._targets: [Argument] = targets self._args: [Argument] = args self._kwargs: {str: Argument} = kwargs @@ -201,6 +202,12 @@ class Node: def set_name(self, name: str): self._name = name + def get_field(self) -> str: + return self._field + + def set_field(self, field: str): + self._field = field + def get_node_type(self) -> NodeType: return self._node_type @@ -236,5 +243,8 @@ class Node: def set_attribute(self, key: str, value): self._attribute[key] = value + def get_attributes(self): + return self._attribute + def get_attribute(self, key: str): return self._attribute.get(key) diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index def3c0e3b1a..cc03c87a542 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -25,6 +25,7 @@ from .node import Node, global_vars_name, NodeType from mindspore.nn import Cell from mindspore import log as logger from .common import AstModifier, Argument, ArgType +from .symbol_tree_dumper import SymbolTreeDumper class Input: @@ -59,6 +60,19 @@ class SymbolTree: self._head = self._add_head_node() self._ori_cls_name = type(origin_network).__name__ self._opt_cls_name = self._ori_cls_name + "Opt" + self._origin_network = origin_network + + def get_inputs(self): + return self._inputs + + def get_head_node(self): + return self._head + + def get_opt_cls_name(self): + return self._opt_cls_name + + def get_origin_network(self): + return self._origin_network def nodes(self) -> {}: return self._nodes @@ -164,6 +178,7 @@ class SymbolTree: return None node.set_ast(ast_node) # create and insert node + node.set_field(field) self._global_vars[field] = node.get_op() if not self._insert_node(position, node): return None @@ -213,6 +228,7 @@ class SymbolTree: return None # create and insert node node = Node(NodeType.CallCell, targets, call_args, call_kwargs, ast_node, attribute, "", cell_type) + node.set_field(field) if not self._insert_node(position, node): return None return node @@ -255,6 +271,7 @@ class SymbolTree: # create and insert node self._global_vars[field] = custom_obj node = Node(NodeType.UserCustom, targets, call_args, call_kwargs, ast_node, {}, "", type(custom_obj)) + node.set_field(field) if not self._insert_node(position, node): return None return node @@ -291,6 +308,7 @@ class SymbolTree: attribute = {} node_name = SymbolTree._generate_node_name(targets, args, kwargs) node = Node(NodeType.UserCustom, ast_node, attribute, targets, args, kwargs, node_name, op) + node.set_field(handler_name) if not self._insert_node(position, node, False): return None return node @@ -339,7 +357,8 @@ class SymbolTree: return self._root def dump(self): - print(self.get_code()) + dump_st = SymbolTreeDumper(self) + dump_st.dump_symbol_tree() def get_code(self) -> str: ast.fix_missing_locations(self._module_ast) diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree_dumper.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree_dumper.py new file mode 100644 index 00000000000..b7ca496abd7 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree_dumper.py @@ -0,0 +1,131 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""dump symbol tree""" +import inspect + +from mindspore import log as logger +from .node import Node +from .common import Argument + + +class SymbolTreeDumper: + def __init__(self, symbol_tree): + self._symbol_tree = symbol_tree + self._dump_buffer = "" + self._dump_key2index = {} + + def _dump2file(self): + dump_file_name = "{}.ir".format(self._symbol_tree.get_opt_cls_name()) + f = open(dump_file_name, "w+") + f.write(self._dump_buffer) + f.close() + return True + + def _dump_reset(self): + self._dump_buffer = "" + self._dump_key2index = {} + + def _dump_global_info(self): + self._dump_buffer += f"#Graph entry : @cosntruct \n" + + def _dump_inputs(self): + inputs = self._symbol_tree.get_inputs() + self._dump_buffer += f"#Inputs num : {len(inputs)}\n" + for single_input in inputs: + input_name = f"%input_{single_input.name}" + assert single_input.name not in self._dump_key2index.keys() + self._dump_key2index[single_input.name] = input_name + self._dump_buffer += f"{input_name}\n" + self._dump_buffer += f"\n" + + def _dump_nodes(self): + # todo @construct? + self._dump_buffer += f"Symbol Tree @construct {{ \n" + node_no = -1 + + node: Node = self._symbol_tree.get_head_node().get_next() + while node is not None: + if node.get_name() == "output": + self._dump_buffer += f" Return(%{node_no}) \n" + self._dump_buffer += f" : (null) \n" + self._dump_buffer += f" # In file {inspect.getfile(type(self._symbol_tree.get_origin_network()))}" + + node = node.get_next() + continue + + node_no += 1 + self._dump_key2index[node.get_name()] = f"%{node_no}" + + targets = node.get_targets() + if not targets: + targets = [None] + op_type = node.get_op_type() + if hasattr(op_type, "__name__"): + op_type_name = op_type.__name__ + else: + if hasattr(type(op_type), "__name__"): + op_type_name = type(op_type).__name__ + else: + raise RuntimeError("op has no attr __name__") + self._dump_buffer += f" %{node_no}({targets[0]}) = {op_type_name}" + + args = node.get_args() + if args: + arg_str = f"" + for arg in args: + if isinstance(arg, str): + arg_name = arg + elif isinstance(arg, Argument): + arg_name = arg.name + else: + raise RuntimeError(f"arg type {type(arg)} of {arg} is not supported now") + + if arg_name in self._dump_key2index.keys(): + arg_str += f"{self._dump_key2index[arg_name]}, " + else: + logger.warning("arg not appears before") + arg_str += f"{arg_name}, " + self._dump_buffer += f"({arg_str[:-2]})" + + self._dump_buffer += f"{{instance name: {node.get_field()}}}" + + self._dump_buffer += f" attributes {{" + # todo attrs are currently None + attrs = node.get_attributes() + if attrs: + attrs_str = f"" + for attr in attrs: + assert type(attr) == str + attrs_str += f"{attr}, " + self._dump_buffer += attrs_str[:-2] + self._dump_buffer += f"}}\n" + + self._dump_buffer += f" : (null) -> (null)\n" + cls_real_path = inspect.getfile(node.get_op_type()) if node.get_op() else None + self._dump_buffer += f" # In file {cls_real_path}\n" + self._dump_buffer += f" # In file {inspect.getfile(type(self._symbol_tree.get_origin_network()))}\n" + + node = node.get_next() + self._dump_buffer += f"}}\n" + + def dump_symbol_tree(self, output_file: bool = False): + self._dump_reset() + self._dump_global_info() + self._dump_inputs() + self._dump_nodes() + if output_file: + self._dump2file() + else: + print(self._dump_buffer) -- Gitee From e2beb9f9007e6897c6e0146c59fee0157e156cb5 Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Thu, 10 Feb 2022 19:14:41 +0800 Subject: [PATCH 07/32] code for node attribute --- .../rewrite_experiment/ast_parse/resolver.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py index cc136b59e80..f63077e7410 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py @@ -23,6 +23,7 @@ class Resolver: self._ori_cls_name = type(net).__name__ self._opt_cls_name = self._ori_cls_name + "Opt" self._module_symbol = self._get_module_symbol() + self._symbol_attribute = {} def resolve(self) -> SymbolTree: Resolver._add_import_to_module(self._module_symbol.symbol_ast) @@ -51,10 +52,11 @@ class Resolver: assert symbol_construct_fn is not None self._process_class_ast(class_symbol.symbol_ast) self._process_init_func_ast(symbol_init_fn.symbol_ast) + self._process_init_func_attribute() origin_network_key = "handler" symbol_tree = SymbolTree(net, self._module_symbol.symbol_ast, class_symbol.symbol_ast, symbol_init_fn.symbol_ast, symbol_construct_fn.symbol_ast, {origin_network_key: net}) - Resolver._process_construct_func_ast(symbol_construct_fn, symbol_tree) + self._process_construct_func_ast(symbol_construct_fn, symbol_tree) return symbol_tree def _get_module_symbol(self) -> Optional[ModuleSymbol]: @@ -115,8 +117,8 @@ class Resolver: assert ast_class_def.name == self._ori_cls_name ast_class_def.name = self._opt_cls_name - @staticmethod - def _process_construct_func_ast(symbol_construct_fn: FunctionSymbol, symbol_tree: SymbolTree): + #@staticmethod + def _process_construct_func_ast(self, symbol_construct_fn: FunctionSymbol, symbol_tree: SymbolTree): # resolve argument arguments = symbol_construct_fn.get_func_arguments() if isinstance(arguments, ArgumentsSymbol): @@ -147,9 +149,11 @@ class Resolver: if len(call.get_call_keywords()) > 0: raise RuntimeError("kwargs in construct function assign is unsupported") cell_type = Resolver._get_cell_type(func) + + attributes = self._symbol_attribute[repr(call.get_call_func())] # todo call attribute_resolver to resolve attributes ret = symbol_tree.add_origin_field((True, symbol_tree.get_root()), func, None, [target], call_args, - {}, body.symbol_ast) + attributes, body.symbol_ast) if ret is None: raise RuntimeError("add_origin_field failed: ", ) elif body.symbol_type() == SymbolType.return_: @@ -158,6 +162,23 @@ class Resolver: if ret is None: raise RuntimeError("update_output failed") + @staticmethod + def get_object_attribute(obj): + attributes = {} + for k, v in obj.__dict__.items(): + if k.startswith("_"): + continue + attributes[k] = v + return attributes + + def _process_init_func_attribute(self): + var_dict = self._origin_net.__dict__ + for key, value in var_dict["_cells"].items(): + attributes = Resolver.get_object_attribute(value) + attributes["cls"] = value.__class__ + logger.debug(f"key: {key}, attributes: {attributes}") + self._symbol_attribute["self." + key] = attributes + def _process_init_func_ast(self, ast_init_fn: ast.FunctionDef): self._modify_super_expr_of_init_func(ast_init_fn) Resolver._modify_arguments_of_init_func(ast_init_fn) -- Gitee From 825edcba1d85b3f29ff39abd72eb8927ed321f8a Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 11 Feb 2022 14:45:57 +0800 Subject: [PATCH 08/32] add ast flatten call add ast flatten call add ast flatten call add ast flatten call add ast flatten call --- .../flatten_recursive_call.py | 54 ++++++++++++++++++- .../example/test_ast_transformers.py | 23 ++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/example/test_ast_transformers.py diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py index ed2ffe1a30b..459a5d95546 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py @@ -1,7 +1,59 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""flatten recursive call""" import ast from ast import FunctionDef from typing import Any class FlattenRecursiveCall(ast.NodeTransformer): - pass + def _flatten_call_value(self, node: FunctionDef, call_value: ast.expr, new_target_names, function_index): + if isinstance(call_value, ast.Call): + args = call_value.args + for call_arg_index in range(0, len(args)): + arg = args[call_arg_index] + if not isinstance(arg, ast.Call): + continue + + if not isinstance(arg.func, ast.Attribute): + raise RuntimeError("func of call can only support ast.attribute now") + assert isinstance(arg.func.attr, str) + target_name = f"self_{arg.func.attr}" + + suffix = 0 + while target_name in new_target_names: + suffix += 1 + target_name = f"self_{arg.func.attr}_{suffix}" + new_target_names.append(target_name) + + new_assign_node = ast.Assign(targets=[ast.Name(id=target_name, ctx=ast.Store())], value=arg) + node.body.insert(function_index, new_assign_node) + args.pop(call_arg_index) + args.insert(call_arg_index, ast.Name(id=target_name, ctx=ast.Load())) + return True + return False + + def visit_FunctionDef(self, node: FunctionDef) -> Any: + changed = True + if node.name == "construct": + new_target_names = [] + while changed: + changed = False + for function_index in range(len(node.body) - 1, -1, -1): + child = node.body[function_index] + if isinstance(child, ast.Assign): + call_value = child.value + changed = changed or self._flatten_call_value(node, call_value, new_target_names, function_index) + return node diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_ast_transformers.py b/mindspore/python/mindspore/rewrite_experiment/example/test_ast_transformers.py new file mode 100644 index 00000000000..bc51edd6af0 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_ast_transformers.py @@ -0,0 +1,23 @@ +import inspect +import ast + +from mindspore.nn import Cell +from lenet import LeNet5 +from mindspore.rewrite_experiment.ast_parse.ast_parse import AstManager +from mindspore.rewrite_experiment.ast_parse import AstParse + + +class Rewrite: + def __init__(self, network: Cell): + if not isinstance(network, Cell): + raise RuntimeError("Only support network with Cell type now") + self._ori_net = network + network_str = inspect.getsource(type(network)) + AstManager.instance().update_ast(ast.parse(network_str)) + AstManager.instance().update_ast(AstParse._ast_transform(AstManager.instance().get_ast())) + print(ast.dump(AstManager.instance().get_ast(), indent=4)) + + +if __name__ == '__main__': + network = LeNet5(10) + rewrite = Rewrite(network) \ No newline at end of file -- Gitee From a64208c9a98e0a2d7a2764a94cd454d16914379a Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 11 Feb 2022 16:20:24 +0800 Subject: [PATCH 09/32] add fold binop add fold binop add fold binop add fold binop --- .../rewrite_experiment/ast_parse/ast_parse.py | 7 +- .../ast_parse/ast_transformers/__init__.py | 3 +- .../ast_parse/ast_transformers/const_fold.py | 51 +++++++++++++ .../flatten_recursive_call.py | 76 +++++++++++-------- 4 files changed, 100 insertions(+), 37 deletions(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/const_fold.py diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py index 3d5cce293f2..9cbeb21c4dd 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py @@ -31,14 +31,15 @@ from .symbol_transformers.clear_non_symbol_in_stb import ClearNonSymbolInSTB from .symbol_transformers.rename_construct_function_assign_target import RenameConstructFuncAssignTarget from .resolver import Resolver from ..symbol_tree import SymbolTree -from .ast_transformers import FlattenRecursiveCall +from .ast_transformers import FlattenRecursiveCall, FoldBinop class AstParse: @staticmethod def _ast_transform(ast_root: ast.AST) -> ast.AST: - flatten_recursive_call = FlattenRecursiveCall() - ast_root = flatten_recursive_call.generic_visit(ast_root) + transform_list = [FoldBinop(), FlattenRecursiveCall()] + for transformer in transform_list: + ast_root = transformer.transform(ast_root) return ast_root @staticmethod diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py index 85dce8e13eb..589e57574df 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py @@ -1,3 +1,4 @@ from .flatten_recursive_call import FlattenRecursiveCall +from .const_fold import FoldBinop -__all__ = ["FlattenRecursiveCall"] +__all__ = ["FlattenRecursiveCall", "FoldBinop"] diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/const_fold.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/const_fold.py new file mode 100644 index 00000000000..01b35bfd26a --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/const_fold.py @@ -0,0 +1,51 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""fold binop""" +import ast +from typing import Any +from mindspore import log as logger + + +class FoldBinop(ast.NodeTransformer): + def _replace_binop(self, node: ast.BinOp) -> Any: + if isinstance(node.left, ast.BinOp): + node.left = self._replace_binop(node.left) + + if isinstance(node.right, ast.BinOp): + node.right = self._replace_binop(node.right) + + if not isinstance(node.left, ast.Constant) or not isinstance(node.right, ast.Constant): + logger.warning("fold binop only support constant node now") + return node + + if not isinstance(node.op, (ast.Mult, ast.Add)): + logger.warning('fold binop only support add or mult now') + return node + + if isinstance(node.op, ast.Mult): + return ast.Constant(value=node.left.value * node.right.value) + + if isinstance(node.op, ast.Add): + return ast.Constant(value=node.left.value + node.right.value) + + return node + + def visit_BinOp(self, node: ast.BinOp) -> Any: + return self._replace_binop(node) + + def transform(self, ast_root): + ast_root = self.generic_visit(ast_root) + ast_root = ast.fix_missing_locations(ast_root) + return ast_root diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py index 459a5d95546..f9aad92d1df 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py @@ -19,41 +19,51 @@ from typing import Any class FlattenRecursiveCall(ast.NodeTransformer): - def _flatten_call_value(self, node: FunctionDef, call_value: ast.expr, new_target_names, function_index): - if isinstance(call_value, ast.Call): - args = call_value.args - for call_arg_index in range(0, len(args)): - arg = args[call_arg_index] - if not isinstance(arg, ast.Call): - continue - - if not isinstance(arg.func, ast.Attribute): - raise RuntimeError("func of call can only support ast.attribute now") - assert isinstance(arg.func.attr, str) - target_name = f"self_{arg.func.attr}" - - suffix = 0 - while target_name in new_target_names: - suffix += 1 - target_name = f"self_{arg.func.attr}_{suffix}" - new_target_names.append(target_name) - - new_assign_node = ast.Assign(targets=[ast.Name(id=target_name, ctx=ast.Store())], value=arg) - node.body.insert(function_index, new_assign_node) - args.pop(call_arg_index) - args.insert(call_arg_index, ast.Name(id=target_name, ctx=ast.Load())) - return True + @staticmethod + def _flatten_call_value(node: FunctionDef, call_value: ast.expr, new_target_names, function_index): + if not isinstance(call_value, ast.Call): + return False + + args = call_value.args + for call_arg_index in range(0, len(args)): + arg = args[call_arg_index] + if not isinstance(arg, ast.Call): + continue + + if not isinstance(arg.func, ast.Attribute): + raise RuntimeError("func of call can only support ast.attribute now") + assert isinstance(arg.func.attr, str) + target_name = f"self_{arg.func.attr}" + + suffix = 0 + while target_name in new_target_names: + suffix += 1 + target_name = f"self_{arg.func.attr}_{suffix}" + new_target_names.append(target_name) + + new_assign_node = ast.Assign(targets=[ast.Name(id=target_name, ctx=ast.Store())], value=arg) + node.body.insert(function_index, new_assign_node) + args.pop(call_arg_index) + args.insert(call_arg_index, ast.Name(id=target_name, ctx=ast.Load())) + return True return False def visit_FunctionDef(self, node: FunctionDef) -> Any: + if node.name != "construct": + return node + changed = True - if node.name == "construct": - new_target_names = [] - while changed: - changed = False - for function_index in range(len(node.body) - 1, -1, -1): - child = node.body[function_index] - if isinstance(child, ast.Assign): - call_value = child.value - changed = changed or self._flatten_call_value(node, call_value, new_target_names, function_index) + new_target_names = [] + while changed: + changed = False + for function_index in range(len(node.body) - 1, -1, -1): + child = node.body[function_index] + if isinstance(child, ast.Assign): + call_value = child.value + changed = changed or self._flatten_call_value(node, call_value, new_target_names, function_index) return node + + def transform(self, ast_root): + ast_root = self.generic_visit(ast_root) + ast_root = ast.fix_missing_locations(ast_root) + return ast_root -- Gitee From 095ee67cded1af5aa0fe85ad621c332c2c7b7c6a Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Thu, 10 Feb 2022 19:06:51 +0800 Subject: [PATCH 10/32] use ast_manager in resolve --- .../rewrite_experiment/ast_parse/ast_parse.py | 10 +- .../rewrite_experiment/ast_parse/resolver.py | 152 ++++------ .../rewrite_experiment/ast_parse/symbol.py | 53 +++- .../symbol_transformers/const_symbol_fold.py | 12 +- .../flatten_call_symbol.py | 2 +- ...rename_construct_function_assign_target.py | 2 +- .../rewrite_experiment/common/ast_modifier.py | 6 + .../rewrite_experiment/example/test_lenet.py | 17 +- .../mindspore/rewrite_experiment/node.py | 2 +- .../mindspore/rewrite_experiment/rewrite.py | 26 +- .../rewrite_experiment/symbol_tree.py | 274 ++++++------------ 11 files changed, 233 insertions(+), 323 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py index 3d5cce293f2..cc08d1133d8 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py @@ -68,19 +68,19 @@ class AstParse: unused_symbol_erase_mgr.add_processor(EraseUnusedSymbol(stb)) unused_symbol_erase_mgr.process(stb) - rename_construct_func_assign_target_mgr = ProcessorManager() - rename_construct_func_assign_target_mgr.add_processor(RenameConstructFuncAssignTarget(stb)) - rename_construct_func_assign_target_mgr.process(stb, True) + # rename_construct_func_assign_target_mgr = ProcessorManager() + # rename_construct_func_assign_target_mgr.add_processor(RenameConstructFuncAssignTarget(stb)) + # rename_construct_func_assign_target_mgr.process(stb, True) @staticmethod def parse(network: Union[nn.Cell, Primitive, FunctionType]) -> Tuple[SymbolTree, SymbolTable]: AstManager.instance().update_ast(AstParse._ast_transform(AstManager.instance().get_ast())) # parse ast to symbols until all symbols are not compilable stb = AstParse._parse(network) - logger.warning("---------------------- After parse: %d", len(stb.get_symbols().values())) + logger.info("---------------------- After parse: %d", len(stb.get_symbols().values())) # optimize symbols, convert construct-function to graph AstParse._optimize(stb) - logger.warning("---------------------- After optimize: %d", len(stb.get_symbols().values())) + logger.info("---------------------- After optimize: %d", len(stb.get_symbols().values())) # stb.print() symbol_tree = Resolver(network, stb).resolve() return symbol_tree, stb diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py index f63077e7410..a2ba3499ad7 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py @@ -1,16 +1,12 @@ import ast from collections import OrderedDict from typing import Optional - -import astunparse from mindspore import log as logger -from mindspore import nn -from mindspore.nn import Cell from .symbol_table import SymbolTable -from .symbol import Symbol, FunctionSymbol, SymbolType, ReturnSymbol, ClassSymbol, ArgumentsSymbol, AssignSymbol, \ - CallSymbol, ModuleSymbol -from .value import ValueType, ImmValue +from .symbol import FunctionSymbol, SymbolType, ClassSymbol, ArgumentsSymbol, AssignSymbol, CallSymbol, ModuleSymbol, \ + NameSymbol, AttributeSymbol +from .value import ValueType, ImmValue, Value from ..symbol_tree import SymbolTree from ..common import Argument, AstModifier @@ -22,44 +18,48 @@ class Resolver: self._symbol_trees = [] self._ori_cls_name = type(net).__name__ self._opt_cls_name = self._ori_cls_name + "Opt" - self._module_symbol = self._get_module_symbol() + self._module_symbol = self._find_module_symbol() self._symbol_attribute = {} def resolve(self) -> SymbolTree: + # process module ast Resolver._add_import_to_module(self._module_symbol.symbol_ast) - class_symbols = self._get_class_symbols() + # resolve class ast to symbol_tree + class_symbols = self._find_class_symbols() self._create_symbol_trees(class_symbols, self._origin_net) + if len(self._symbol_trees) != 1: + raise RuntimeError("Only support one class in stb now") return self._symbol_trees[0] @staticmethod def _add_import_to_module(module: ast.Module): assert module is not None module.body.insert(0, ast.Import([ast.alias('mindspore')])) - module.body.insert(1, ast.ImportFrom(module='mindspore.nn', names=[ast.alias('Cell')], level=0)) module.body.insert(1, ast.ImportFrom(module='mindspore', names=[ast.alias('nn')], level=0)) + module.body.insert(2, ast.ImportFrom(module='mindspore.nn', names=[ast.alias('Cell')], level=0)) ast.fix_missing_locations(module) def _create_symbol_trees(self, class_symbols: [ClassSymbol], net): - if len(class_symbols) != 1: - raise RuntimeError("Only support one class in stb now") - class_symbol = class_symbols[0] - self._symbol_trees.append(self._create_symbol_tree(class_symbol, net)) + for class_symbol in class_symbols: + self._symbol_trees.append(self._create_symbol_tree(class_symbol, net)) def _create_symbol_tree(self, class_symbol: ClassSymbol, net) -> SymbolTree: - symbol_init_fn: FunctionSymbol = self._get_init_fn(class_symbol) - assert symbol_init_fn is not None - symbol_construct_fn: FunctionSymbol = self._get_construct_fn(class_symbol) - assert symbol_construct_fn is not None + # process class ast self._process_class_ast(class_symbol.symbol_ast) + # process __init__ function in class ast + symbol_init_fn: FunctionSymbol = Resolver._find_fn_in_class(class_symbol, "__init__") + assert symbol_init_fn is not None self._process_init_func_ast(symbol_init_fn.symbol_ast) self._process_init_func_attribute() - origin_network_key = "handler" + # resolve construct function in class ast to symbol_tree + symbol_construct_fn: FunctionSymbol = Resolver._find_fn_in_class(class_symbol, "construct") + assert symbol_construct_fn is not None symbol_tree = SymbolTree(net, self._module_symbol.symbol_ast, class_symbol.symbol_ast, - symbol_init_fn.symbol_ast, symbol_construct_fn.symbol_ast, {origin_network_key: net}) + symbol_init_fn.symbol_ast, symbol_construct_fn.symbol_ast) self._process_construct_func_ast(symbol_construct_fn, symbol_tree) return symbol_tree - def _get_module_symbol(self) -> Optional[ModuleSymbol]: + def _find_module_symbol(self) -> Optional[ModuleSymbol]: for k, v in self._stb.items(): symbol = v.value() # todo check is subclass of cell @@ -67,7 +67,7 @@ class Resolver: return symbol return None - def _get_class_symbols(self) -> [ClassSymbol]: + def _find_class_symbols(self) -> [ClassSymbol]: results = [] for k, v in self._stb.items(): symbol = v.value() @@ -77,45 +77,43 @@ class Resolver: return results @staticmethod - def _get_construct_fn(class_symbol: ClassSymbol): + def _find_fn_in_class(class_symbol: ClassSymbol, fn_name: str): for body in class_symbol.get_bodies(): - if body.symbol_type() == SymbolType.function_def and body.get_func_name() == "construct": + if body.symbol_type() == SymbolType.function_def and body.get_func_name() == fn_name: return body return None - @staticmethod - def _get_init_fn(class_symbol: ClassSymbol): - for body in class_symbol.get_bodies(): - if body.symbol_type() == SymbolType.function_def and body.get_func_name() == "__init__": - return body - return None + def _process_class_ast(self, ast_class_def: ast.ClassDef): + # change class name + assert ast_class_def.name == self._ori_cls_name + ast_class_def.name = self._opt_cls_name @staticmethod - def _get_cell_type(name: str): - if name.find("conv") != -1: - return nn.Conv2d - elif name.find("pool") != -1: - return nn.MaxPool2d - elif name.find("fc") != -1: - return nn.Dense - elif name.find("relu") != -1: - return nn.ReLU - elif name.find("flatten") != -1: - return nn.Flatten - else: - return type(None) + def create_argument_from_name_symbol(symbol: NameSymbol) -> Argument: + assert isinstance(symbol, NameSymbol) + name_value = symbol.get_name_name() + assert name_value.value_type == ValueType.String + return Argument.create_naming_arg(name_value.value) @staticmethod - def _get_real_func(name: str): - pos = name.find("self.") - if pos == -1: - return name - else: - return name[pos+5:] + def create_argument_from_attribute_symbol(symbol: AttributeSymbol) -> Argument: + assert isinstance(symbol, AttributeSymbol) + attr_value_value = symbol.get_attribute_value() + assert isinstance(attr_value_value, NameSymbol) + scope = attr_value_value.get_name_name().value + attr_attr_value = symbol.get_attribute_attr() + assert attr_attr_value.value_type == ValueType.String + return Argument.create_naming_arg(attr_attr_value.value, scope) - def _process_class_ast(self, ast_class_def: ast.ClassDef): - assert ast_class_def.name == self._ori_cls_name - ast_class_def.name = self._opt_cls_name + @staticmethod + def create_argument(value: Value) -> Optional[Argument]: + if isinstance(value, NameSymbol): + return Resolver.create_argument_from_name_symbol(value) + elif isinstance(value, AttributeSymbol): + return Resolver.create_argument_from_attribute_symbol(value) + elif isinstance(value, ImmValue): + return Argument.create_imm_arg(value.value) + raise RuntimeError("Unsupported value to argument:", value) #@staticmethod def _process_construct_func_ast(self, symbol_construct_fn: FunctionSymbol, symbol_tree: SymbolTree): @@ -137,28 +135,25 @@ class Resolver: if body is None: continue if isinstance(body, AssignSymbol): - target = Argument.create_naming_arg(body.get_assign_targets()[0].value) + target = Resolver.create_argument(body.get_assign_targets()[0]) call = body.get_assign_value() assert isinstance(call, CallSymbol) - func = Resolver._get_real_func(call.get_call_func().value) + func = call.func_name() call_args: [Argument] = [] for arg in call.get_call_args(): - if not isinstance(arg, ImmValue): - raise RuntimeError("Naming symbol in arguments is unsupported") - call_args.append(Argument.create_imm_arg(arg.value)) + call_args.append(Resolver.create_argument(arg)) if len(call.get_call_keywords()) > 0: raise RuntimeError("kwargs in construct function assign is unsupported") - cell_type = Resolver._get_cell_type(func) - attributes = self._symbol_attribute[repr(call.get_call_func())] + attributes = self._symbol_attribute[f"self.{call.func_name()}"] # todo call attribute_resolver to resolve attributes ret = symbol_tree.add_origin_field((True, symbol_tree.get_root()), func, None, [target], call_args, attributes, body.symbol_ast) if ret is None: - raise RuntimeError("add_origin_field failed: ", ) + raise RuntimeError("add_origin_field failed: ") elif body.symbol_type() == SymbolType.return_: symbol_tree.set_output_ast(body.symbol_ast) - ret = symbol_tree.update_output([body.get_return_value().value]) + ret = symbol_tree.update_output([body.get_return_str()]) if ret is None: raise RuntimeError("update_output failed") @@ -237,38 +232,3 @@ class Resolver: assign.value = ast.Call(ast.Name('getattr', ast.Load()), [ast.Attribute(ast.Name('self', ast.Load()), '_handler', ast.Load()), ast.Constant(field_name)], []) - - def get_graph(self): - # add all node to graph - return_symbol: ReturnSymbol = self.find_graph_return() - if return_symbol is None: - logger.error("Construct function has no return expression") - return None - self._symbol_tree.add_output(return_symbol.return_value().value) - # compute edge - - def find_graph_return(self) -> Optional[ReturnSymbol]: - construct_func = None - for key in self._stb.keys(): - symbol: Symbol = self._stb[key] - if symbol.symbol_type() == SymbolType.function_def and isinstance(symbol, FunctionSymbol): - func_symbol: FunctionSymbol = symbol - if func_symbol.get_func_name() == "construct": - if construct_func is not None: - logger.error("Each symbol table should have only one construct function") - return None - construct_func = func_symbol - if construct_func is None: - logger.error("Can not find construct function in symbol table") - return None - bodies = construct_func.get_body_names() - return_symbol = None - for body_name in bodies: - symbol: Symbol = self._stb.get(body_name) - assert symbol is not None - if symbol.symbol_type() == SymbolType.return_ and isinstance(symbol, ReturnSymbol): - if return_symbol is not None: - logger.error("Each function should have only one return statement") - return None - return_symbol = symbol - return return_symbol diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol.py index ed1f2168508..a98749a80f7 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol.py @@ -18,7 +18,7 @@ from collections import OrderedDict from typing import Optional from .namespace import SymbolNameDepot -from .value import ValueType, Value +from .value import ValueType, Value, ImmValue from .value_node import ValueNode, make_imm_value_node sn_depot = SymbolNameDepot() @@ -562,12 +562,42 @@ class CallSymbol(Symbol): keywords.append(keyword.value()) return keywords - def generate_candidate_func_name(self) -> Optional[str]: - if self._func.value_type() is not ValueType.String: - return None - func_name: str = self._func.value().value - func_name = func_name.replace(".", "_") - return func_name + def func_scope(self) -> str: + func_symbol = self._func.value() + if isinstance(func_symbol, NameSymbol): + return "" + elif isinstance(func_symbol, AttributeSymbol): + scope_name = func_symbol.get_attribute_value() + assert isinstance(scope_name, NameSymbol) + scope_name_value = scope_name.get_name_name() + assert isinstance(scope_name_value, ImmValue) + assert scope_name_value.value_type == ValueType.String + return scope_name_value.value + else: + raise RuntimeError("FuncValue is should be Name or a Attribute:", func_symbol) + + def func_name(self) -> str: + func_symbol = self._func.value() + if isinstance(func_symbol, NameSymbol): + name = func_symbol.get_name_name() + if name.value_type is not ValueType.String: + raise RuntimeError("name of NameSymbol should be a StringValue:", name) + return name.value + elif isinstance(func_symbol, AttributeSymbol): + attr_value = func_symbol.get_attribute_attr() + assert isinstance(attr_value, ImmValue) + assert attr_value.value_type == ValueType.String + return attr_value.value + else: + raise RuntimeError("FuncValue is should be Name or a Attribute:", func_symbol) + + def get_full_func_name(self) -> str: + func_scope = self.func_scope() + func_name = self.func_name() + if len(func_scope) == 0: + return func_name + else: + return f"{func_scope}_{func_name}" def __str__(self): return f"CallSymbol({self.get_full_name_with_scope()}): {self._func}({str(self._args)})" @@ -658,6 +688,15 @@ class ReturnSymbol(Symbol): def get_return_value(self) -> Value: return self._value.value() + def get_return_str(self) -> str: + value = self._value.value() + if value.value_type == ValueType.String: + return value.value + elif isinstance(value, NameSymbol): + return value.get_name_name().value + else: + raise RuntimeError("Return value should be a StringValue or NameSymbol") + def __str__(self): return f"ReturnSymbol({self.get_full_name_with_scope()}): {self._value.value}" diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/const_symbol_fold.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/const_symbol_fold.py index 3fd5e32a329..7d93698b16b 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/const_symbol_fold.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/const_symbol_fold.py @@ -24,10 +24,10 @@ from mindspore import log as logger class ConstSymbolFold(Processor): def __init__(self): - self._fold_fn_map = {SymbolType.name: ConstSymbolFold.fold_name_symbol, + self._fold_fn_map = {SymbolType.bin_op: ConstSymbolFold.fold_binop_symbol, + # SymbolType.name: ConstSymbolFold.fold_name_symbol, SymbolType.attribute: ConstSymbolFold.fold_attribute_symbol, SymbolType.arg: ConstSymbolFold.fold_arg_symbol, - SymbolType.bin_op: ConstSymbolFold.fold_binop_symbol, } @staticmethod @@ -39,10 +39,10 @@ class ConstSymbolFold(Processor): return False return True - @staticmethod - def fold_name_symbol(name_symbol: NameSymbol) -> Optional[ConstantSymbol]: - return ConstantSymbol(ast_node=name_symbol.symbol_ast, scope=name_symbol.scope, - symbol_name=name_symbol.symbol_name, value=name_symbol.get_name_name()) + # @staticmethod + # def fold_name_symbol(name_symbol: NameSymbol) -> Optional[ConstantSymbol]: + # return ConstantSymbol(ast_node=name_symbol.symbol_ast, scope=name_symbol.scope, + # symbol_name=name_symbol.symbol_name, value=name_symbol.get_name_name()) @staticmethod def fold_arg_symbol(arg_symbol: ArgSymbol) -> Optional[ConstantSymbol]: diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/flatten_call_symbol.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/flatten_call_symbol.py index 1fb017c8051..9151fba87ed 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/flatten_call_symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/flatten_call_symbol.py @@ -44,7 +44,7 @@ class FlattenCallSymbol(Processor): arg = args[i] if not isinstance(arg, CallSymbol): continue - func_name = arg.generate_candidate_func_name() + func_name = arg.get_full_func_name() assert func_name is not None name_number = self._names.get(func_name) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/rename_construct_function_assign_target.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/rename_construct_function_assign_target.py index 4b275102711..0fae8933c41 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/rename_construct_function_assign_target.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/rename_construct_function_assign_target.py @@ -67,7 +67,7 @@ class RenameConstructFuncAssignTarget(Processor): call.get_full_name_with_scope()) arg.value = new_arg_str - target_name = self.generate_final_target(call.generate_candidate_func_name()) + target_name = self.generate_final_target(call.get_full_func_name()) targets = assign_symbol.get_assign_targets() if len(targets) == 1: if not isinstance(targets[0], Value) or targets[0].value_type != ValueType.String: diff --git a/mindspore/python/mindspore/rewrite_experiment/common/ast_modifier.py b/mindspore/python/mindspore/rewrite_experiment/common/ast_modifier.py index 0574b75bf8c..58b568c8e67 100644 --- a/mindspore/python/mindspore/rewrite_experiment/common/ast_modifier.py +++ b/mindspore/python/mindspore/rewrite_experiment/common/ast_modifier.py @@ -86,6 +86,12 @@ class AstModifier(ast.NodeTransformer): return assign return None + @staticmethod + def insert_global_vars_expr_to_init(init_func: ast.FunctionDef, targets: [Argument], args: [Argument]) -> \ + Optional[ast.AST]: + return AstModifier.insert_assign_to_function(init_func, targets=targets, args=args, + expr=Argument(ArgType.NamingArg, "global_vars", "get")) + @staticmethod def create_assign(targets: [Argument], expr: Argument, args: [Argument], kwargs: {str, Argument}): if targets is None or len(targets) != 1: diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py index 87ae09dc257..706216fd785 100644 --- a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py @@ -17,6 +17,11 @@ class MyCell(nn.Cell): return x +def print_code(title, rewrite): + print(f"=========={title}===================================================================================") + print(rewrite.get_code()) + + def transform(rw: Rewrite): for _, node in rw.nodes().items(): targets = node.get_targets() @@ -33,7 +38,7 @@ def transform(rw: Rewrite): if ret is None: raise RuntimeError("add_cell failed") break - rewrite.dump("after add_cell") + print_code("after add_cell", rw) for _, node in rw.nodes().items(): targets = node.get_targets() if targets is None: @@ -50,7 +55,7 @@ def transform(rw: Rewrite): if ret is None: raise RuntimeError("add_custom_node failed") break - rewrite.dump("after add_custom_node") + print_code("after add_custom_node", rw) for _, node in rw.nodes().items(): targets = node.get_targets() if targets is None: @@ -62,9 +67,9 @@ def transform(rw: Rewrite): if not ret: raise RuntimeError("Update arg failed") break - rewrite.dump("after update_arg") + print_code("after update_arg", rw) rw.set_output(0, "nx2") - rewrite.dump("after add_output") + print_code("after add_output", rw) for _, node in rw.nodes().items(): targets = node.get_targets() if targets is None: @@ -76,13 +81,13 @@ def transform(rw: Rewrite): if ret is None: raise RuntimeError("erase_node failed") break - rewrite.dump("after erase_node") + print_code("after erase_node", rw) if __name__ == '__main__': lenet = LeNet5(10) rewrite = Rewrite(lenet) - rewrite.dump("after resolve") + print_code("after resolve", rewrite) transform(rewrite) lenet_opt = rewrite.get_network() context.set_context(mode=context.GRAPH_MODE, device_target="CPU", save_graphs=True, save_graphs_path='./lenet_dump') diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index e7cf5427494..788d2b3082a 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -32,7 +32,7 @@ class NodeType(Enum): CallMethod = 2 # method in cell CallFunction = 3 # subclass of primitive - UserCustom = 4 + UserCustom = 4 # todo remove Input = 5 Output = 6 Graph = 7 diff --git a/mindspore/python/mindspore/rewrite_experiment/rewrite.py b/mindspore/python/mindspore/rewrite_experiment/rewrite.py index 168e7f9fce9..e9a4799e16a 100644 --- a/mindspore/python/mindspore/rewrite_experiment/rewrite.py +++ b/mindspore/python/mindspore/rewrite_experiment/rewrite.py @@ -45,6 +45,7 @@ class Rewrite: @staticmethod def create_node(op, targets: [Argument], target_type: str = None, args: [Argument] = None, kwargs: {str: Argument} = None, name: str = None) -> Node: + # 'targets': 'target_type' = self.'field'(*'args', **'kwargs') if isinstance(op, Cell): # todo create ast from op: ast can be created when insert into symbol_tree ast_node = None @@ -56,34 +57,16 @@ class Rewrite: else: raise RuntimeError("Only support Cell op or Primitive op!") + # todo define class position: use symbol_tree, node and before/after as position def insert(self, position, node: Node, field: str = None, args: [Argument] = None, kwargs: {str: Argument} = None) \ -> Optional[Node]: + # self.'field': 'custom_obj_type' = global_vars.get('field') if args is not None: node.set_args(args) if kwargs is not None: node.set_kwargs(kwargs) return self._symbol_tree.insert_node(position, node, field) - # # self.'field': 'cell_type' = mindspore.nn.'cell_type'(*'construct_args', **'construct_kwargs') - # # 'targets': 'target_type' = self.'field'(*'call_args', **'call_kwargs') - # def add_cell(self, position, cell_type: type, field: str = None, construct_args: [Argument] = None, - # construct_kwargs: {str: Argument} = None, targets: [str] = None, target_type: str = "", - # call_args: [Argument] = None, call_kwargs: {str: Argument} = None) -> Optional[Node]: - # # todo call ast_parser and attributes_resolver to resolve attributes - # attribute = {} - # return self._symbol_tree.add_cell(position, cell_type, field, construct_args, construct_kwargs, targets, - # target_type, call_args, call_kwargs, attribute) - # - # def add_function(self, position, *args, **kwargs): - # raise NotImplementedError - # - # # self.'field': 'custom_obj_type' = global_vars.get('field') - # # 'targets': 'target_type' = self.'field'(*'call_args', **'call_kwargs') - # def add_object(self, position, custom_obj: Cell, field: str, targets: [str], target_type: str, - # call_args: [Argument], call_kwargs: {str: Argument}) -> Optional[Node]: - # return self._symbol_tree.add_object(position, custom_obj, field, targets, call_args, call_kwargs, - # target_type) - def erase_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: return self._symbol_tree.erase_node(node_or_name) @@ -97,8 +80,7 @@ class Rewrite: out_idx: Optional[int] = None): node_to_update.update_arg_by_node(arg_idx, node_to_link, out_idx) - def dump(self, title: str): - print(f"=========={title}===================================================================================") + def dump(self): self._symbol_tree.dump() def get_code(self) -> str: diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index cc03c87a542..9df4245b305 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -21,13 +21,41 @@ from typing import Optional, Union, Tuple import astpretty import astunparse -from .node import Node, global_vars_name, NodeType +from .node import Node, NodeType from mindspore.nn import Cell from mindspore import log as logger from .common import AstModifier, Argument, ArgType from .symbol_tree_dumper import SymbolTreeDumper +class Namer: + def __init__(self): + self._names: {str: int} = {} + + def get_name(self, origin_name: str) -> str: + number = self._names.get(origin_name) + if number is None: + self._names[origin_name] = 1 + return origin_name + else: + self._names[origin_name] = number + 1 + return f"{origin_name}_{number}" + + +class NodeNamer(Namer): + def get_name(self, node: Node) -> str: + origin_name = node.get_name() + if len(origin_name) == 0: + targets = node.get_targets() + # return node and head node will not call this method + assert targets is not None + if len(targets) == 0: + raise RuntimeError("node should has at lease one target except return-node and head-node.", node) + else: + origin_name = str(targets[0].name) + return super(NodeNamer, self).get_name(origin_name) + + class Input: def __init__(self, name: str, arg_type: Optional[type] = None, default: Optional[str] = None): self.name: str = name @@ -46,21 +74,27 @@ class Input: # one symbol-tree has one init and one construct. (not considering functional) class SymbolTree: def __init__(self, origin_network: Cell, module_ast: ast.Module, ast_root: ast.ClassDef, - init_func_ast: ast.FunctionDef, construct_func_ast: ast.FunctionDef, global_vars: {str, object}): - self._global_vars: {str, object} = global_vars + init_func_ast: ast.FunctionDef, construct_func_ast: ast.FunctionDef): + origin_network_key = "handler" + self._global_vars: {str, object} = {origin_network_key: origin_network} self._nodes: {str, Node} = {} self._inputs: [Input] = [] - self._module_ast: ast.Module = module_ast # need deep copy? - self._ast_root: ast.ClassDef = ast_root # need deep copy? - self._init_func_ast: ast.FunctionDef = init_func_ast # need deep copy? - self._construct_func_ast: ast.FunctionDef = construct_func_ast # need deep copy? - # root must be output of graph - self._root = self._add_output(["undefined"]) - # head must be the first statement but must not be inputs of graph - self._head = self._add_head_node() + self._module_ast: ast.Module = module_ast + self._ast_root: ast.ClassDef = ast_root + self._init_func_ast: ast.FunctionDef = init_func_ast + self._construct_func_ast: ast.FunctionDef = construct_func_ast self._ori_cls_name = type(origin_network).__name__ self._opt_cls_name = self._ori_cls_name + "Opt" self._origin_network = origin_network + # root must be output of graph + self._root = self._add_output_node(["undefined"]) + # head must be the first statement but must not be inputs of graph + self._head = self._add_head_node() + # init unique-namers + # todo init namer by origin source-code + self._field_namer = Namer() + self._target_namer = Namer() + self._node_name_namer = NodeNamer() def get_inputs(self): return self._inputs @@ -135,8 +169,15 @@ class SymbolTree: return node def insert_node(self, position, node: Node, field: str = None) -> Optional[Node]: + # check position + index = self._find_node_index(position[1]) + assert index is not None + # process node name + node.set_name(self._node_name_namer.get_name(node)) + # process field if field is None: - field = self._generate_new_field_name() + field = node.get_name() + field = self._field_namer.get_name(field) # modify init function ast_node = AstModifier.insert_assign_to_function(self._init_func_ast, targets=[Argument(ArgType.NamingArg, "self", field)], @@ -145,11 +186,6 @@ class SymbolTree: if ast_node is None: logger.error("insert custom_node into init function ast tree failed.") return None - # check position - index = self._find_node_index(position[1]) - assert index is not None - # process node name - node.set_name(self._generate_node_name(node.get_targets(), node.get_args(), node.get_kwargs())) # process args of node call_args, new_arg_nodes = self._handle_custom_object_in_args(self._init_func_ast, node.get_args()) for node in new_arg_nodes: @@ -166,7 +202,7 @@ class SymbolTree: node.set_kwargs(call_kwargs) # process targets of node targets = node.get_targets() - targets = SymbolTree._convert_targets(targets) + targets = SymbolTree._convert_strs_to_name_arguments(targets) node.set_targets(targets) # modify construct function ast_node = AstModifier.insert_assign_to_function(self._construct_func_ast, targets=targets, @@ -184,98 +220,6 @@ class SymbolTree: return None return node - def add_cell(self, position, cell_type: type, field: str = None, construct_args: [Argument] = None, - construct_kwargs: {str: Argument} = None, targets: [str] = None, target_type: str = "", - call_args: [Argument] = None, call_kwargs: {str: Argument} = None, - attribute: {str, object}=None) -> Optional[Node]: - class_name = cell_type.__name__ - if not issubclass(cell_type, Cell): - logger.error("not a subclass of cell, got: %s", class_name) - return None - construct_args, _ = self._handle_custom_object_in_args(self._init_func_ast, construct_args) - construct_kwargs, _ = self._handle_custom_object_in_kwargs(self._init_func_ast, construct_kwargs) - if field is None: - field = self._generate_new_field_name() - # modify init function - ast_node = AstModifier.insert_assign_to_function(self._init_func_ast, - [Argument(ArgType.NamingArg, "self", field)], - Argument(ArgType.NamingArg, "mindspore.nn", class_name), - construct_args, construct_kwargs) - if ast_node is None: - logger.error("insert cell into init function ast tree failed.") - return None - # modify construct function - index = self._find_node_index(position[1]) - assert index is not None - call_args, new_arg_nodes = self._handle_custom_object_in_args(self._construct_func_ast, call_args) - for node in new_arg_nodes: - if not self._insert_node(position, node): - logger.error("insert custom object node into symbol_tree failed.") - return None - call_kwargs, new_kwarg_nodes = self._handle_custom_object_in_kwargs(self._construct_func_ast, call_kwargs) - for node in new_kwarg_nodes: - if not self._insert_node(position, node): - logger.error("insert custom object node into symbol_tree failed.") - return None - if targets is None: - targets = [self._generate_new_target_name()] - targets = SymbolTree._convert_targets(targets) - ast_node = AstModifier.insert_assign_to_function(self._construct_func_ast, targets, - Argument(ArgType.NamingArg, "self", field), - call_args, call_kwargs, position[1].get_ast(), position[0]) - if ast_node is None: - logger.error("insert cell into construct function ast tree failed.") - return None - # create and insert node - node = Node(NodeType.CallCell, targets, call_args, call_kwargs, ast_node, attribute, "", cell_type) - node.set_field(field) - if not self._insert_node(position, node): - return None - return node - - def add_object(self, position, custom_obj: Cell, field: str, targets: [str], call_args: [], - call_kwargs: {str: object}, target_type: str = "") -> Optional[Node]: - if field is None: - field = self._generate_new_field_name() - # modify init function - ast_node = AstModifier.insert_assign_to_function(self._init_func_ast, - targets=[Argument(ArgType.NamingArg, "self", field)], - expr=Argument(ArgType.NamingArg, "global_vars", "get"), - args=[Argument(ArgType.StringArg, "", field)]) - if ast_node is None: - logger.error("insert custom_node into init function ast tree failed.") - return None - # modify construct function - index = self._find_node_index(position[1]) - assert index is not None - call_args, new_arg_nodes = self._handle_custom_object_in_args(self._init_func_ast, call_args) - for node in new_arg_nodes: - if not self._insert_node(position, node): - logger.error("insert custom object node into symbol_tree failed.") - return None - call_kwargs, new_kwarg_nodes = self._handle_custom_object_in_kwargs(self._init_func_ast, call_kwargs) - for node in new_kwarg_nodes: - if not self._insert_node(position, node): - logger.error("insert custom object node into symbol_tree failed.") - return None - if targets is None: - targets = [self._generate_new_target_name()] - targets = SymbolTree._convert_targets(targets) - ast_node = AstModifier.insert_assign_to_function(self._construct_func_ast, targets=targets, - expr=Argument(ArgType.NamingArg, "self", field), - args=call_args, kwargs=call_kwargs, - index_ast=position[1].get_ast(), insert_before=position[0]) - if ast_node is None: - logger.error("insert custom_node into construct function ast tree failed.") - return None - # create and insert node - self._global_vars[field] = custom_obj - node = Node(NodeType.UserCustom, targets, call_args, call_kwargs, ast_node, {}, "", type(custom_obj)) - node.set_field(field) - if not self._insert_node(position, node): - return None - return node - # todo update ast def add_input(self, name: str, input_type: Optional[type] = None, default: Optional[str] = None) -> bool: for arg in self._inputs: @@ -302,14 +246,14 @@ class SymbolTree: attribute: {str, object}=None) -> Optional[Node]: if targets is None: targets = [self._generate_new_target_name()] - targets = SymbolTree._convert_targets(targets) + targets = SymbolTree._convert_strs_to_name_arguments(targets) assert len(handler_name) > 0 if attribute is None: attribute = {} node_name = SymbolTree._generate_node_name(targets, args, kwargs) node = Node(NodeType.UserCustom, ast_node, attribute, targets, args, kwargs, node_name, op) node.set_field(handler_name) - if not self._insert_node(position, node, False): + if not self._insert_node(position, node): return None return node @@ -320,7 +264,7 @@ class SymbolTree: if len(return_values) == 0: logger.error("return_values should at least has one element") return None - real_return_values = self._convert_targets(return_values) + real_return_values = self._convert_strs_to_name_arguments(return_values) self._root.set_args(real_return_values) return self._root @@ -364,37 +308,6 @@ class SymbolTree: ast.fix_missing_locations(self._module_ast) return astunparse.unparse(self._module_ast) - def get_code3(self) -> str: - indent = " " - line = "\r\n" - code = f"import mindspore{line}{line}{line}" - code += f"class {self._opt_cls_name}(mindspore.nn.Cell):{line}" - # init code: - init_code = astunparse.unparse(self._init_func_ast) - code += init_code - code += f"{line}{line}" - # code += f"{indent}def __init__(self, global_vars):{line}" - # code += f"{indent}{indent}super({self._opt_cls_name}, self).__init__(){line}" - # node = self._head - # while node is not None: - # tmp = node.get_init_code() - # if len(tmp) > 0: - # code += f"{indent}{indent}{tmp}{line}" - # node = node.get_next() - - arg_str: str = "" - for arg in self._inputs: - arg_str += f", {arg.get_code3()}" - - code += f"{line}{indent}def construct(self{arg_str}):{line}" - node = self._head.get_next() - while node is not None: - tmp = node.get_code() - if len(tmp) > 0: - code += f"{indent}{indent}{tmp}{line}" - node = node.get_next() - return code - def get_network(self): print("------------------------------------ keys of global_vars: ", self._global_vars.keys()) # cls = self._get_cls_directly() @@ -413,40 +326,30 @@ class SymbolTree: def _append_node2nodes(self, node: Node) -> bool: node_name = node.get_name() if self._nodes.get(node_name) is not None: - logger.error("generated duplicated node name: %s, %s", self._nodes.get(node_name), node) - return False + # todo wait for ast_transformer + # logger.error("generated duplicated node name(%s): %s, %s", node_name, self._nodes.get(node_name), node) + # return False + node_name = node_name + "_" + str(len(self._nodes)) + node.set_name(node_name) self._nodes[node_name] = node return True - def _insert_node(self, position, node: Node, insert_to_ast: bool = True) -> bool: + def _insert_node(self, position, node: Node) -> bool: if not self._append_node2nodes(node): return False if position[0]: position[1].insert_before(node) else: position[1].insert_after(node) - # if not insert_to_ast: - # return True - # index = self._find_node_index(self._insert_point) - # assert index is not None - # if node.get_ast() is None: - # logger.error("node has no ast while insert to construct_function_ast") - # return False - # self._construct_func_ast.body.insert(index, node.get_ast()) return True - def _generate_new_target_name(self) -> str: - pass - - def _generate_new_field_name(self) -> str: - pass - - def _generate_new_global_var_key(self, obj) -> str: + @staticmethod + def _generate_new_global_var_key(obj) -> str: key = "var_" + type(obj).__name__ return key.lower() @staticmethod - def _convert_targets(targets: [str]) -> [Argument]: + def _convert_strs_to_name_arguments(targets: [str]) -> [Argument]: result = [] for target in targets: result.append(Argument(ArgType.NamingArg, "", target)) @@ -459,6 +362,18 @@ class SymbolTree: result[arg] = Argument.create_imm_arg(value) return result + def _insert_blackbox_object_into_init(self, object, field: str = "") -> Node: + if len(field) == 0: + field = f"var_{type(object).__name__}" + field = self._field_namer.get_name(field) + self._global_vars[field] = object + targets = [Argument.create_naming_arg(field, "self")] + args = [Argument.create_imm_arg(field)] + ast_node = AstModifier.insert_global_vars_expr_to_init(self._init_func_ast, targets, args) + if ast_node is None: + raise RuntimeError("insert custom obj to init_func failed") + return Node(NodeType.UserCustom, ast_node, {}, targets, args, {}, field, None) + def _handle_custom_object_in_args(self, ast_func: ast.FunctionDef, args: list[Argument]) -> \ Tuple[list[Argument], list[Node]]: result: [Argument] = [] @@ -466,17 +381,18 @@ class SymbolTree: for arg in args: assert isinstance(arg, Argument) if arg.type == ArgType.CustomObjArg: - new_key = self._generate_new_global_var_key(arg) - self._global_vars[new_key] = arg.name - targets = [Argument(ArgType.NamingArg, "", new_key)] - args = [Argument(ArgType.StringArg, "", new_key)] - ast_node = AstModifier.insert_assign_to_function(ast_func, targets=targets, args=args, - expr=Argument(ArgType.NamingArg, "global_vars", "get")) - if ast_node is None: - raise RuntimeError("insert custom obj to ast_tree failed") - node = Node(NodeType.CallMethod, targets, args, None) + # new_key = SymbolTree._generate_new_global_var_key(arg) + # self._global_vars[new_key] = arg.name + # targets = [Argument(ArgType.NamingArg, "", new_key)] + # args = [Argument(ArgType.StringArg, "", new_key)] + # ast_node = AstModifier.insert_assign_to_function(ast_func, targets=targets, args=args, + # expr=Argument(ArgType.NamingArg, "global_vars", "get")) + node = self._insert_blackbox_object_into_init(arg.name) + # if ast_node is None: + # raise RuntimeError("insert custom obj to ast_tree failed") + # node = Node(NodeType.UserCustom, ast_node, {}, targets, args, {}, None) new_nodes.append(node) - result.append(Argument(ArgType.NamingArg, "", new_key)) + result.append(Argument(ArgType.NamingArg, "", node.get_field())) else: result.append(arg) return result, new_nodes @@ -488,7 +404,7 @@ class SymbolTree: for arg, value in kwargs: assert isinstance(value, Argument) if value.type == ArgType.CustomObjArg: - new_key = self._generate_new_global_var_key(value) + new_key = SymbolTree._generate_new_global_var_key(value) self._global_vars[new_key] = value.name targets = [Argument(ArgType.NamingArg, "", new_key)] args = [Argument(ArgType.StringArg, "", new_key)] @@ -503,15 +419,17 @@ class SymbolTree: result[arg] = value return result, new_nodes - def _add_output(self, return_values: [str]) -> Optional[Node]: - real_return_values = self._convert_targets(return_values) - node = Node(NodeType.Output, None, {}, real_return_values, [], {}, "output", None) + def _add_output_node(self, return_values: [str]) -> Optional[Node]: + real_return_values = self._convert_strs_to_name_arguments(return_values) + node = Node(NodeType.Output, None, {}, None, real_return_values, {}, "return", None) self._append_node2nodes(node) return node def _add_head_node(self) -> Optional[Node]: + # todo node_type of head_node + # should input(parameter) be a node of symbol_tree node = Node(NodeType.Unknown, None, {}, None, [], {}, "head", None) - if not self._insert_node((True, self._root), node, False): + if not self._insert_node((True, self._root), node): return None return node -- Gitee From d9677bb9f0715f007576bb341276f4714dc74a87 Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Mon, 14 Feb 2022 10:10:18 +0800 Subject: [PATCH 11/32] update for attribute --- .../rewrite_experiment/ast_parse/resolver.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py index a2ba3499ad7..98b6c674b64 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py @@ -19,7 +19,6 @@ class Resolver: self._ori_cls_name = type(net).__name__ self._opt_cls_name = self._ori_cls_name + "Opt" self._module_symbol = self._find_module_symbol() - self._symbol_attribute = {} def resolve(self) -> SymbolTree: # process module ast @@ -50,7 +49,6 @@ class Resolver: symbol_init_fn: FunctionSymbol = Resolver._find_fn_in_class(class_symbol, "__init__") assert symbol_init_fn is not None self._process_init_func_ast(symbol_init_fn.symbol_ast) - self._process_init_func_attribute() # resolve construct function in class ast to symbol_tree symbol_construct_fn: FunctionSymbol = Resolver._find_fn_in_class(class_symbol, "construct") assert symbol_construct_fn is not None @@ -145,10 +143,10 @@ class Resolver: if len(call.get_call_keywords()) > 0: raise RuntimeError("kwargs in construct function assign is unsupported") - attributes = self._symbol_attribute[f"self.{call.func_name()}"] + obj, attributes = self._get_symbol_attribute(call.func_name()) # todo call attribute_resolver to resolve attributes - ret = symbol_tree.add_origin_field((True, symbol_tree.get_root()), func, None, [target], call_args, - attributes, body.symbol_ast) + ret = symbol_tree.add_origin_field((True, symbol_tree.get_root()), func, obj, [target], call_args, + {}, body.symbol_ast, attribute=attributes) if ret is None: raise RuntimeError("add_origin_field failed: ") elif body.symbol_type() == SymbolType.return_: @@ -166,13 +164,17 @@ class Resolver: attributes[k] = v return attributes - def _process_init_func_attribute(self): + def _get_symbol_attribute(self, symbol_name): + logger.debug(f"symbol_name: {symbol_name}") + attributes = {} var_dict = self._origin_net.__dict__ for key, value in var_dict["_cells"].items(): - attributes = Resolver.get_object_attribute(value) - attributes["cls"] = value.__class__ - logger.debug(f"key: {key}, attributes: {attributes}") - self._symbol_attribute["self." + key] = attributes + if key == symbol_name: + attributes = Resolver.get_object_attribute(value) + attributes["cls"] = value.__class__ + logger.debug(f"key: {key}, attributes: {attributes}") + return value, attributes + return None, attributes def _process_init_func_ast(self, ast_init_fn: ast.FunctionDef): self._modify_super_expr_of_init_func(ast_init_fn) -- Gitee From 2ce6efb6c5503b4a2aaff11933c8ad48774ea2bb Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Mon, 14 Feb 2022 10:38:24 +0800 Subject: [PATCH 12/32] add namespace --- .../rewrite_experiment/ast_parse/namespace.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/namespace.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/namespace.py index 5a37e0a6bb9..32952847417 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/namespace.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/namespace.py @@ -16,6 +16,66 @@ # ============================================================================ """Define the name_depot of symbol.""" +import inspect +from types import FunctionType + +import mindspore.nn as nn +from mindspore import log as logger +from ..._extends.parse.namespace import CellNamespace,ClosureNamespace, ClassAttrNamespace, ClassMemberNamespace + +class Namespace(): + def __init__(self): + self.ms_common_ns = CellNamespace('mindspore.common') + self.ms_nn_ns = CellNamespace('mindspore.nn') + self.ms_ops_ns = CellNamespace('mindspore.ops') + self.ms_ops_c_ns = CellNamespace('mindspore.ops.composite') + self.ms_ops_c_multitype_ns = CellNamespace('mindspore.ops.composite.multitype_ops') + self.ms_ops_p_ns = CellNamespace('mindspore.ops.operations') + # Used to resolve the function's globals namespace. + self.global_namespace: CellNamespace = None #CellNamespace(network.__module__) if network else None + # Used to resolve the function's nonlocals. + self.closure_namespace: ClosureNamespace = None + + def update_closure_namespace(self, fn: FunctionType): + """ + Update 'closure_namespace' of fn. + """ + self.closure_namespace = ClosureNamespace(inspect.unwrap(fn)) + + def update_global_namespace(self, network): + """ + Update 'global namespace' of network. + """ + if isinstance(network, nn.Cell): + self.global_namespace = CellNamespace(network.__module__) + elif issubclass(network, nn.Cell): + self.global_namespace = CellNamespace(network) + else: + raise ValueError("unsupported type, type: ", type(network)) + + def get_symbole_namesapce(self, symbole_name: str): + """ + Get the namespace of symbole_name. + """ + if symbole_name in self.ms_common_ns: + return self.ms_common_ns[symbole_name], repr(self.ms_common_ns), False + elif symbole_name in self.ms_nn_ns: + return self.ms_nn_ns[symbole_name], repr(self.ms_nn_ns), False + elif symbole_name in self.ms_ops_ns: + return self.ms_ops_ns[symbole_name], repr(self.ms_ops_ns), False + elif symbole_name in self.ms_ops_c_ns: + return self.ms_ops_c_ns[symbole_name], repr(self.ms_ops_c_ns), False + elif symbole_name in self.ms_ops_c_multitype_ns: + return self.ms_ops_c_multitype_ns[symbole_name], repr(self.ms_ops_c_multitype_ns), False + elif symbole_name in self.ms_ops_p_ns: + return self.ms_ops_p_ns[symbole_name], repr(self.ms_ops_p_ns), False + elif symbole_name in self.global_namespace: + return self.global_namespace[symbole_name], repr(self.global_namespace), True + elif symbole_name in self.closure_namespace: + return self.closure_namespace[symbole_name], repr(self.closure_namespace), True + else: + logger.warning(f"get namespace failed, func_name: {symbole_name}") + return None, None, True class NameDepot: """ -- Gitee From 7a197d22c245414de752e5fbdf3105689b37722a Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Mon, 14 Feb 2022 10:28:43 +0800 Subject: [PATCH 13/32] add Namer for unique id --- .../rewrite_experiment/ast_parse/resolver.py | 19 +- .../rewrite_experiment/example/test_lenet.py | 1 + .../mindspore/rewrite_experiment/node.py | 9 +- .../rewrite_experiment/symbol_tree.py | 406 ++++++++---------- .../rewrite_experiment/symbol_tree_dumper.py | 4 +- 5 files changed, 191 insertions(+), 248 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py index 98b6c674b64..3bb883e4510 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py @@ -129,6 +129,15 @@ class Resolver: symbol_tree.add_input(arguments.value().value) # resolve body bodies = symbol_construct_fn.get_bodies() + # update return node first + for body in bodies: + if body is None: + continue + if body.symbol_type() == SymbolType.return_: + symbol_tree.set_output_ast(body.symbol_ast) + ret = symbol_tree.update_output([body.get_return_str()]) + if ret is None: + raise RuntimeError("update_output failed") for body in bodies: if body is None: continue @@ -145,15 +154,11 @@ class Resolver: obj, attributes = self._get_symbol_attribute(call.func_name()) # todo call attribute_resolver to resolve attributes - ret = symbol_tree.add_origin_field((True, symbol_tree.get_root()), func, obj, [target], call_args, - {}, body.symbol_ast, attribute=attributes) + # todo func should be field of corresponding init-ast + ret = symbol_tree.add_origin_field((True, symbol_tree.get_root()), func, obj, [target], call_args, {}, + body.symbol_ast, "", attributes) if ret is None: raise RuntimeError("add_origin_field failed: ") - elif body.symbol_type() == SymbolType.return_: - symbol_tree.set_output_ast(body.symbol_ast) - ret = symbol_tree.update_output([body.get_return_str()]) - if ret is None: - raise RuntimeError("update_output failed") @staticmethod def get_object_attribute(obj): diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py index 706216fd785..f3f0de00c27 100644 --- a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py @@ -20,6 +20,7 @@ class MyCell(nn.Cell): def print_code(title, rewrite): print(f"=========={title}===================================================================================") print(rewrite.get_code()) + rewrite.dump() def transform(rw: Rewrite): diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index 788d2b3082a..6517abf3e1f 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -50,8 +50,7 @@ class Node: self._attribute: {str, object} = attributes self._op = op self._op_type: type = type(op) - self._name = name - self._field = None + self._name = name # use name as field in class. be unique when insert into graph self._targets: [Argument] = targets self._args: [Argument] = args self._kwargs: {str: Argument} = kwargs @@ -202,12 +201,6 @@ class Node: def set_name(self, name: str): self._name = name - def get_field(self) -> str: - return self._field - - def set_field(self, field: str): - self._field = field - def get_node_type(self) -> NodeType: return self._node_type diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index 9df4245b305..ee61ff3d609 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -29,6 +29,8 @@ from .symbol_tree_dumper import SymbolTreeDumper class Namer: + # for unique identity in a class-scope + # current used for target of construct-function def __init__(self): self._names: {str: int} = {} @@ -41,18 +43,30 @@ class Namer: self._names[origin_name] = number + 1 return f"{origin_name}_{number}" + def exist(self, name: str) -> bool: + return self._names.get(name) is not None + class NodeNamer(Namer): - def get_name(self, node: Node) -> str: - origin_name = node.get_name() - if len(origin_name) == 0: - targets = node.get_targets() - # return node and head node will not call this method - assert targets is not None - if len(targets) == 0: - raise RuntimeError("node should has at lease one target except return-node and head-node.", node) - else: - origin_name = str(targets[0].name) + # current used for node-name which is also used as field of init-function and key of global_vars + def get_name(self, node_or_name: Union[Node, str]) -> str: + if isinstance(node_or_name, Node): + origin_name = node_or_name.get_name() + if len(origin_name) == 0: + targets = node_or_name.get_targets() + # return node and head node will not call this method + assert targets is not None + if len(targets) == 0: + raise RuntimeError("node should has at lease one target except return-node and head-node: ", + node_or_name) + else: + origin_name = str(targets[0].name) + elif isinstance(node_or_name, str): + if len(node_or_name) == 0: + raise RuntimeError("input node_name is empty.") + origin_name = node_or_name + else: + raise RuntimeError("unexpected type of node_or_name: ", type(node_or_name)) return super(NodeNamer, self).get_name(origin_name) @@ -86,15 +100,13 @@ class SymbolTree: self._ori_cls_name = type(origin_network).__name__ self._opt_cls_name = self._ori_cls_name + "Opt" self._origin_network = origin_network + # init unique-namers + self._target_namer = Namer() + self._node_name_namer = NodeNamer() # root must be output of graph self._root = self._add_output_node(["undefined"]) # head must be the first statement but must not be inputs of graph self._head = self._add_head_node() - # init unique-namers - # todo init namer by origin source-code - self._field_namer = Namer() - self._target_namer = Namer() - self._node_name_namer = NodeNamer() def get_inputs(self): return self._inputs @@ -136,165 +148,209 @@ class SymbolTree: raise RuntimeError("Unsupported node_or_name: ", node_or_name) return False, node - @staticmethod - def _generate_node_name(targets: [Argument], args: [Argument], kwargs: {str: Argument}) ->str: - if targets is None: - if args is None and kwargs is None: - return "head-node" - else: - return "return-node" + def _unique_targets(self, node: Node): + new_targets: [Argument] = [] + for target in node.get_targets(): + assert isinstance(target, Argument) + unique_target = self._target_namer.get_name(target.name) + new_targets.append(Argument.create_naming_arg(unique_target, target.scope)) + node.set_targets(new_targets) + + def _append_node2nodes(self, node: Node): + node_name = node.get_name() + if self._nodes.get(node_name) is not None: + raise RuntimeError("generated duplicated node name", node_name, self._nodes.get(node_name), + node) + self._nodes[node_name] = node + + def _insert_node(self, position, node: Node): + self._append_node2nodes(node) + if position[0]: + position[1].insert_before(node) else: - if len(targets) == 0: - return "illegal-node" + position[1].insert_after(node) + + def _insert_blackbox_object_into_init(self, object) -> Node: + field = self._node_name_namer.get_name(f"var_{type(object).__name__}") + self._global_vars[field] = object + init_targets = [Argument.create_naming_arg(field, "self")] + construct_targets = [Argument.create_naming_arg(field)] + args = [Argument.create_imm_arg(field)] + ast_node = AstModifier.insert_global_vars_expr_to_init(self._init_func_ast, init_targets, args) + if ast_node is None: + raise RuntimeError("insert custom obj to init_func failed") + return Node(NodeType.UserCustom, ast_node, {}, construct_targets, args, {}, field, None) + + def _handle_custom_object_in_args(self, position, node: Node): + result: [Argument] = [] + for arg in node.get_args(): + assert isinstance(arg, Argument) + if arg.type == ArgType.CustomObjArg: + node = self._insert_blackbox_object_into_init(arg.name) + self._insert_node(position, node) + result.append(Argument(ArgType.NamingArg, "", node.get_name())) else: - return str(targets[0].name) + result.append(arg) + node.set_args(result) - # can only erase isolated node - def erase_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: - if isinstance(node_or_name, Node): - node = node_or_name - elif isinstance(node_or_name, str): - node = self._nodes.get(node_or_name) - else: - raise RuntimeError("Unsupported node_or_name: ", node_or_name) - ret = AstModifier.erase_ast_from_function(self._construct_func_ast, node.get_ast()) - if not ret: - logger.error("node not in function ast tree.") - return None - for key, value in self._nodes.items(): - if id(value) == id(node): - self._nodes.pop(key) - value.isolate() - break - return node + def _handle_custom_object_in_kwargs(self, position, node: Node): + result: {str, Argument} = {} + for arg, value in node.get_kwargs(): + assert isinstance(value, Argument) + if value.type == ArgType.CustomObjArg: + node = self._insert_blackbox_object_into_init(value.name) + self._insert_node(position, node) + result[arg] = Argument(ArgType.NamingArg, "", node.get_name()) + else: + result[arg] = value + node.set_kwargs(result) + + @staticmethod + def _convert_strs_to_name_arguments(targets: [str]) -> [Argument]: + result = [] + for target in targets: + result.append(Argument(ArgType.NamingArg, "", target)) + return result - def insert_node(self, position, node: Node, field: str = None) -> Optional[Node]: + # todo move into ast_modifier + def _find_node_index(self, node: Node) -> Optional[int]: + for i in range(0, len(self._construct_func_ast.body)): + body = self._construct_func_ast.body[i] + if node.has_same_ast(body): + return i + return None + + def insert_node(self, position, node: Node, field: str = None) -> Node: # check position index = self._find_node_index(position[1]) - assert index is not None - # process node name - node.set_name(self._node_name_namer.get_name(node)) - # process field - if field is None: - field = node.get_name() - field = self._field_namer.get_name(field) + if index is None: + raise RuntimeError("index is not None: ", position[1].get_name()) # modify init function + if field is None: + field = self._node_name_namer.get_name(node) + else: + field = self._node_name_namer.get_name(field) + node.set_name(field) ast_node = AstModifier.insert_assign_to_function(self._init_func_ast, targets=[Argument(ArgType.NamingArg, "self", field)], expr=Argument(ArgType.NamingArg, "global_vars", "get"), args=[Argument(ArgType.StringArg, "", field)]) if ast_node is None: - logger.error("insert custom_node into init function ast tree failed.") - return None - # process args of node - call_args, new_arg_nodes = self._handle_custom_object_in_args(self._init_func_ast, node.get_args()) - for node in new_arg_nodes: - if not self._insert_node(position, node): - logger.error("insert custom object node into symbol_tree failed.") - return None - node.set_args(call_args) - # process kwargs of node - call_kwargs, new_kwarg_nodes = self._handle_custom_object_in_kwargs(self._init_func_ast, node.get_kwargs()) - for node in new_kwarg_nodes: - if not self._insert_node(position, node): - logger.error("insert custom object node into symbol_tree failed.") - return None - node.set_kwargs(call_kwargs) - # process targets of node - targets = node.get_targets() - targets = SymbolTree._convert_strs_to_name_arguments(targets) - node.set_targets(targets) + raise RuntimeError("insert custom_node into init function ast tree failed.") + self._global_vars[field] = node.get_op() # modify construct function - ast_node = AstModifier.insert_assign_to_function(self._construct_func_ast, targets=targets, + self._handle_custom_object_in_args(position, node) + self._handle_custom_object_in_kwargs(position, node) + self._unique_targets(node) + ast_node = AstModifier.insert_assign_to_function(self._construct_func_ast, targets=node.get_targets(), expr=Argument(ArgType.NamingArg, "self", field), - args=call_args, kwargs=call_kwargs, + args=node.get_args(), kwargs=node.get_kwargs(), index_ast=position[1].get_ast(), insert_before=position[0]) if ast_node is None: - logger.error("insert custom_node into construct function ast tree failed.") - return None + raise RuntimeError("insert custom_node into construct function ast tree failed.") node.set_ast(ast_node) - # create and insert node - node.set_field(field) - self._global_vars[field] = node.get_op() - if not self._insert_node(position, node): - return None + # insert node + self._insert_node(position, node) + return node + + def add_origin_field(self, position, origin_field_name: str, op, targets: [Argument], args: [Argument], + kwargs: {str: Argument}, ast_node: ast.AST, target_type: str = "", + attribute: {str, object}=None) -> Node: + if attribute is None: + attribute = {} + assert len(origin_field_name) > 0 + # wait for resolve, use exist + add rather than get_name + node_name = self._node_name_namer.get_name(origin_field_name) + node = Node(NodeType.UserCustom, ast_node, attribute, targets, args, kwargs, node_name, op) + self._unique_targets(node) + self._insert_node(position, node) + return node + + def _add_output_node(self, return_values: [str]) -> Node: + real_return_values = self._convert_strs_to_name_arguments(return_values) + node_name = self._node_name_namer.get_name("return") + node = Node(NodeType.Output, None, {}, None, real_return_values, {}, node_name, None) + self._append_node2nodes(node) + return node + + def _add_head_node(self) -> Node: + # todo node_type of head_node + # should input(parameter) be a node of symbol_tree + node_name = self._node_name_namer.get_name("head") + node = Node(NodeType.Unknown, None, {}, None, [], {}, node_name, None) + self._insert_node((True, self._root), node) + return node + + # can only erase isolated node + def erase_node(self, node_or_name: Union[Node, str]) -> Node: + if isinstance(node_or_name, Node): + node = node_or_name + elif isinstance(node_or_name, str): + node = self._nodes.get(node_or_name) + else: + raise RuntimeError("Unsupported node_or_name: ", node_or_name) + ret = AstModifier.erase_ast_from_function(self._construct_func_ast, node.get_ast()) + if not ret: + raise RuntimeError("node not in function ast tree.") + for key, value in self._nodes.items(): + if id(value) == id(node): + self._nodes.pop(key) + value.isolate() + break return node # todo update ast - def add_input(self, name: str, input_type: Optional[type] = None, default: Optional[str] = None) -> bool: + def add_input(self, name: str, input_type: Optional[type] = None, default: Optional[str] = None): for arg in self._inputs: if arg.name == name: - logger.error("input duplicated: %s", name) - return False + raise RuntimeError("input duplicated: %s", name) self._inputs.append(Input(name, input_type, default)) - return True def add_input_and_update_ast(self, name: str, input_type: Optional[type] = None, - default: Optional[str] = None) -> bool: + default: Optional[str] = None): for arg in self._inputs: if arg.name == name: - logger.error("input duplicated: %s", name) - return False + raise RuntimeError("input duplicated: %s", name) ret = AstModifier.insert_argument_to_function(self._construct_func_ast, name, default) if not ret: - return False + raise RuntimeError("insert_argument_to_function failed") self._inputs.append(Input(name, input_type, default)) - return True - - def add_origin_field(self, position, handler_name: str, op, targets: [Argument], args: [Argument], - kwargs: {str: Argument}, ast_node: ast.AST, target_type: str = "", - attribute: {str, object}=None) -> Optional[Node]: - if targets is None: - targets = [self._generate_new_target_name()] - targets = SymbolTree._convert_strs_to_name_arguments(targets) - assert len(handler_name) > 0 - if attribute is None: - attribute = {} - node_name = SymbolTree._generate_node_name(targets, args, kwargs) - node = Node(NodeType.UserCustom, ast_node, attribute, targets, args, kwargs, node_name, op) - node.set_field(handler_name) - if not self._insert_node(position, node): - return None - return node # todo update ast - def update_output(self, return_values: [str]) -> Optional[Node]: + def update_output(self, return_values: [str]) -> Node: if self._root is None: - return None + raise RuntimeError("SymbolTree not inited") if len(return_values) == 0: - logger.error("return_values should at least has one element") - return None + raise RuntimeError("return_values should at least has one element") real_return_values = self._convert_strs_to_name_arguments(return_values) self._root.set_args(real_return_values) return self._root - def set_output_ast(self, ast_node: ast.AST) -> Optional[Node]: + def set_output_ast(self, ast_node: ast.AST) -> Node: if self._root is None: - return None + raise RuntimeError("SymbolTree not inited") self._root.set_ast(ast_node) return self._root # todo update ast - def add_output(self, return_value: str, index: Optional[int] = None) -> Optional[Node]: + def add_output(self, return_value: str, index: Optional[int] = None) -> Node: if self._root is None: - return None + raise RuntimeError("SymbolTree not inited") new_return_value = Argument(ArgType.NamingArg, "", return_value) if index is None: self._root.add_arg(new_return_value) else: if index > len(self._root.get_args()): - logger.error("index(%d) out of range(%d)", index, len(self._root.get_targets())) - return None + raise RuntimeError("index(%d) out of range(%d)", index, len(self._root.get_targets())) self._root.add_arg(new_return_value, index) return self._root # todo update ast - def set_output(self, return_value: str, index: int) -> Optional[Node]: + def set_output(self, return_value: str, index: int) -> Node: if self._root is None: - return None + raise RuntimeError("SymbolTree not inited") if index >= len(self._root.get_args()): - logger.error("index(%d) out of range(%d)", index, len(self._root.get_targets())) - return None + raise RuntimeError("index(%d) out of range(%d)", index, len(self._root.get_targets())) new_arg: Argument = self._root.get_args()[index] new_arg.name = return_value self._root.set_arg(new_arg, index) @@ -314,47 +370,6 @@ class SymbolTree: cls = self._get_cls_through_file() return cls(self._global_vars) - # todo move into ast_modifier - def _find_node_index(self, node: Node) -> Optional[int]: - for i in range(0, len(self._construct_func_ast.body)): - body = self._construct_func_ast.body[i] - if node.has_same_ast(body): - assert i <= len(self._construct_func_ast.body) - return i - return None - - def _append_node2nodes(self, node: Node) -> bool: - node_name = node.get_name() - if self._nodes.get(node_name) is not None: - # todo wait for ast_transformer - # logger.error("generated duplicated node name(%s): %s, %s", node_name, self._nodes.get(node_name), node) - # return False - node_name = node_name + "_" + str(len(self._nodes)) - node.set_name(node_name) - self._nodes[node_name] = node - return True - - def _insert_node(self, position, node: Node) -> bool: - if not self._append_node2nodes(node): - return False - if position[0]: - position[1].insert_before(node) - else: - position[1].insert_after(node) - return True - - @staticmethod - def _generate_new_global_var_key(obj) -> str: - key = "var_" + type(obj).__name__ - return key.lower() - - @staticmethod - def _convert_strs_to_name_arguments(targets: [str]) -> [Argument]: - result = [] - for target in targets: - result.append(Argument(ArgType.NamingArg, "", target)) - return result - @staticmethod def _convert_kwargs(kwargs: {str, str}) -> {str, Argument}: result = {} @@ -362,77 +377,6 @@ class SymbolTree: result[arg] = Argument.create_imm_arg(value) return result - def _insert_blackbox_object_into_init(self, object, field: str = "") -> Node: - if len(field) == 0: - field = f"var_{type(object).__name__}" - field = self._field_namer.get_name(field) - self._global_vars[field] = object - targets = [Argument.create_naming_arg(field, "self")] - args = [Argument.create_imm_arg(field)] - ast_node = AstModifier.insert_global_vars_expr_to_init(self._init_func_ast, targets, args) - if ast_node is None: - raise RuntimeError("insert custom obj to init_func failed") - return Node(NodeType.UserCustom, ast_node, {}, targets, args, {}, field, None) - - def _handle_custom_object_in_args(self, ast_func: ast.FunctionDef, args: list[Argument]) -> \ - Tuple[list[Argument], list[Node]]: - result: [Argument] = [] - new_nodes: [Node] = [] - for arg in args: - assert isinstance(arg, Argument) - if arg.type == ArgType.CustomObjArg: - # new_key = SymbolTree._generate_new_global_var_key(arg) - # self._global_vars[new_key] = arg.name - # targets = [Argument(ArgType.NamingArg, "", new_key)] - # args = [Argument(ArgType.StringArg, "", new_key)] - # ast_node = AstModifier.insert_assign_to_function(ast_func, targets=targets, args=args, - # expr=Argument(ArgType.NamingArg, "global_vars", "get")) - node = self._insert_blackbox_object_into_init(arg.name) - # if ast_node is None: - # raise RuntimeError("insert custom obj to ast_tree failed") - # node = Node(NodeType.UserCustom, ast_node, {}, targets, args, {}, None) - new_nodes.append(node) - result.append(Argument(ArgType.NamingArg, "", node.get_field())) - else: - result.append(arg) - return result, new_nodes - - def _handle_custom_object_in_kwargs(self, ast_func: ast.FunctionDef, kwargs: dict[str, Argument]) -> \ - Tuple[dict[str, Argument], list[Node]]: - result: {str, Argument} = {} - new_nodes: [Node] = [] - for arg, value in kwargs: - assert isinstance(value, Argument) - if value.type == ArgType.CustomObjArg: - new_key = SymbolTree._generate_new_global_var_key(value) - self._global_vars[new_key] = value.name - targets = [Argument(ArgType.NamingArg, "", new_key)] - args = [Argument(ArgType.StringArg, "", new_key)] - ast_node = AstModifier.insert_assign_to_function(ast_func, targets=targets, args=args, - expr=Argument(ArgType.NamingArg, "global_vars", "get")) - if ast_node is None: - raise RuntimeError("insert custom obj to ast_tree failed") - node = Node(NodeType.CallMethod, targets, args, None) - new_nodes.append(node) - result[arg] = Argument(ArgType.NamingArg, "", new_key) - else: - result[arg] = value - return result, new_nodes - - def _add_output_node(self, return_values: [str]) -> Optional[Node]: - real_return_values = self._convert_strs_to_name_arguments(return_values) - node = Node(NodeType.Output, None, {}, None, real_return_values, {}, "return", None) - self._append_node2nodes(node) - return node - - def _add_head_node(self) -> Optional[Node]: - # todo node_type of head_node - # should input(parameter) be a node of symbol_tree - node = Node(NodeType.Unknown, None, {}, None, [], {}, "head", None) - if not self._insert_node((True, self._root), node): - return None - return node - def _get_cls_directly(self): code_obj = compile(self.get_code(), "", "exec") result_dict = {} diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree_dumper.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree_dumper.py index b7ca496abd7..99a44f6caa5 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree_dumper.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree_dumper.py @@ -99,7 +99,7 @@ class SymbolTreeDumper: arg_str += f"{arg_name}, " self._dump_buffer += f"({arg_str[:-2]})" - self._dump_buffer += f"{{instance name: {node.get_field()}}}" + self._dump_buffer += f"{{instance name: {node.get_name()}}}" self._dump_buffer += f" attributes {{" # todo attrs are currently None @@ -108,7 +108,7 @@ class SymbolTreeDumper: attrs_str = f"" for attr in attrs: assert type(attr) == str - attrs_str += f"{attr}, " + attrs_str += f"{attr}: {attrs[attr]}, " self._dump_buffer += attrs_str[:-2] self._dump_buffer += f"}}\n" -- Gitee From d835deaea7564ea11dc1813d49e0877112cdd2f2 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Mon, 14 Feb 2022 11:04:42 +0800 Subject: [PATCH 14/32] add sync_to_node in node --- .../rewrite_experiment/ast_parse/resolver.py | 6 +- .../rewrite_experiment/common/argument.py | 3 + .../rewrite_experiment/example/test_lenet.py | 25 +++-- .../mindspore/rewrite_experiment/node.py | 45 +++++++-- .../mindspore/rewrite_experiment/rewrite.py | 12 +-- .../rewrite_experiment/symbol_tree.py | 97 ++++++++++++++++--- 6 files changed, 152 insertions(+), 36 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py index 3bb883e4510..df77d35ff5b 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py @@ -155,10 +155,10 @@ class Resolver: obj, attributes = self._get_symbol_attribute(call.func_name()) # todo call attribute_resolver to resolve attributes # todo func should be field of corresponding init-ast - ret = symbol_tree.add_origin_field((True, symbol_tree.get_root()), func, obj, [target], call_args, {}, - body.symbol_ast, "", attributes) + ret = symbol_tree.append_origin_field(func, obj, [target], call_args, {}, body.symbol_ast, "", + attributes) if ret is None: - raise RuntimeError("add_origin_field failed: ") + raise RuntimeError("append_origin_field failed: ") @staticmethod def get_object_attribute(obj): diff --git a/mindspore/python/mindspore/rewrite_experiment/common/argument.py b/mindspore/python/mindspore/rewrite_experiment/common/argument.py index b42ba9a79a5..5b4fa417393 100644 --- a/mindspore/python/mindspore/rewrite_experiment/common/argument.py +++ b/mindspore/python/mindspore/rewrite_experiment/common/argument.py @@ -58,3 +58,6 @@ class Argument: return f"CustomObj: {str(self.name)}" else: return f"Illegal ArgType: {str(self.type)}" + + def __repr__(self): + return str(self) diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py index f3f0de00c27..1266d06078b 100644 --- a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py @@ -24,16 +24,17 @@ def print_code(title, rewrite): def transform(rw: Rewrite): + new_conv_node = None for _, node in rw.nodes().items(): targets = node.get_targets() if targets is None: continue assert targets[0].type == ArgType.NamingArg target = str(targets[0]) - if target == "self_flatten": + if target == "x_7": position = rw.before(node) new_conv = nn.Conv2d(16, 16, 3) - new_conv_node = Rewrite.create_node(new_conv, targets=[Argument.create_naming_arg('nx1')], target_type="", + new_conv_node = Rewrite.create_node(new_conv, targets=[Argument.create_naming_arg('x_1')], target_type="", name='new_conv') ret = rw.insert(position, new_conv_node, field='conv_new', args=[Argument.create_naming_arg('x')]) if ret is None: @@ -46,7 +47,7 @@ def transform(rw: Rewrite): continue assert targets[0].type == ArgType.NamingArg target = str(targets[0]) - if target == "self_relu_3": + if target == "x_9": position = rw.before(node) custom_cell = MyCell() new_custom_node = Rewrite.create_node(custom_cell, targets=[Argument.create_naming_arg('nx2')], @@ -63,11 +64,23 @@ def transform(rw: Rewrite): continue assert targets[0].type == ArgType.NamingArg target = str(targets[0]) - if target == "nx2": - ret = rw.update_arg(node, 0, "nx1") + if target == "x_11": + ret = rw.update_arg(node, 0, "x_6") if not ret: raise RuntimeError("Update arg failed") break + if new_conv_node is not None: + for _, node in rw.nodes().items(): + targets = node.get_targets() + if targets is None: + continue + assert targets[0].type == ArgType.NamingArg + target = str(targets[0]) + if target == "nx2": + ret = rw.update_arg_by_node(node, 0, new_conv_node) + if not ret: + raise RuntimeError("Update arg failed") + break print_code("after update_arg", rw) rw.set_output(0, "nx2") print_code("after add_output", rw) @@ -77,7 +90,7 @@ def transform(rw: Rewrite): continue assert targets[0].type == ArgType.NamingArg target = str(targets[0]) - if target == "self_fc3": + if target == "x_10": ret = rw.erase_node(node) if ret is None: raise RuntimeError("erase_node failed") diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index 6517abf3e1f..03dfb8ed98a 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ import ast +import astpretty from enum import Enum from typing import Optional, Union @@ -89,6 +90,44 @@ class Node: def get_ast(self) -> Optional[ast.AST]: return self._ast_node + def sync_to_ast(self): + if self._ast_node is None: + return + if self._node_type != NodeType.CallCell: + return + assign_ast = self._ast_node + assert isinstance(assign_ast, ast.Assign) + # update targets + targets_ast = assign_ast.targets + assert len(self._targets) == len(targets_ast) + for i in range(0, len(self._targets)): + target = self._targets[i] + target_ast = targets_ast[i] + assert isinstance(target_ast, ast.Name) + target_ast.id = target.name + # update args + call_ast = assign_ast.value + assert isinstance(call_ast, ast.Call) + args_ast = call_ast.args + assert len(self._args) == len(args_ast) + for i in range(0, len(self._args)): + arg = self._args[i] + assert isinstance(arg, Argument) + arg_ast = args_ast[i] + if isinstance(arg_ast, ast.Name): + assert arg.scope == "" + arg_ast.id = arg.name + elif isinstance(arg_ast, ast.Attribute): + arg_value_ast = arg_ast.value + assert isinstance(arg_value_ast, ast.Name) + arg_value_ast.id = arg.scope + arg_ast.attr = arg.name + else: + raise RuntimeError("Unsupported arg type: ", arg_ast) + # update kwargs + if len(self._kwargs) != 0: + raise RuntimeError("kwargs is not unsupported now") + def set_ast(self, ast_node: ast.AST): assert isinstance(ast_node, ast.AST) self._ast_node = ast_node @@ -183,12 +222,6 @@ class Node: def get_inputs(self) -> ['Node']: return self._inputs - # def set_inputs(self, nodes: list): - # self._inputs = nodes - - # def set_targets(self, targets: [str]): - # self._targets = targets - def get_targets(self) -> [Argument]: return self._targets diff --git a/mindspore/python/mindspore/rewrite_experiment/rewrite.py b/mindspore/python/mindspore/rewrite_experiment/rewrite.py index e9a4799e16a..2b15cb8275d 100644 --- a/mindspore/python/mindspore/rewrite_experiment/rewrite.py +++ b/mindspore/python/mindspore/rewrite_experiment/rewrite.py @@ -59,7 +59,7 @@ class Rewrite: # todo define class position: use symbol_tree, node and before/after as position def insert(self, position, node: Node, field: str = None, args: [Argument] = None, kwargs: {str: Argument} = None) \ - -> Optional[Node]: + -> Node: # self.'field': 'custom_obj_type' = global_vars.get('field') if args is not None: node.set_args(args) @@ -67,18 +67,18 @@ class Rewrite: node.set_kwargs(kwargs) return self._symbol_tree.insert_node(position, node, field) - def erase_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: + def erase_node(self, node_or_name: Union[Node, str]) -> Node: return self._symbol_tree.erase_node(node_or_name) - def set_output(self, index: int, return_value: str) -> Optional[Node]: + def set_output(self, index: int, return_value: str) -> Node: return self._symbol_tree.set_output(return_value, index) def update_arg(self, node: Node, index: int, arg: str) -> bool: return node.update_arg(index, arg) - def update_arg_by_node(self, node_to_update: Node, arg_idx: int, node_to_link: 'Node', - out_idx: Optional[int] = None): - node_to_update.update_arg_by_node(arg_idx, node_to_link, out_idx) + def update_arg_by_node(self, node_to_update: Node, arg_idx: int, node_to_link: Node, + out_idx: Optional[int] = None) -> bool: + return node_to_update.update_arg_by_node(arg_idx, node_to_link, out_idx) def dump(self): self._symbol_tree.dump() diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index ee61ff3d609..3abe727ea4f 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -34,7 +34,23 @@ class Namer: def __init__(self): self._names: {str: int} = {} + @staticmethod + def _real_name(name: str): + pos = name.rfind("_") + if pos == -1: + return name + digit = True + for i in range(pos + 1, len(name)): + if not name[i].isdigit(): + digit = False + break + if digit: + return Namer._real_name(name[:pos]) + else: + return name + def get_name(self, origin_name: str) -> str: + origin_name = Namer._real_name(origin_name) number = self._names.get(origin_name) if number is None: self._names[origin_name] = 1 @@ -46,6 +62,19 @@ class Namer: def exist(self, name: str) -> bool: return self._names.get(name) is not None + def get_real_arg(self, origin_arg: str) -> str: + num = self._names.get(origin_arg) + if num is None or num == 1: + return origin_arg + else: + return f"{origin_arg}_{num - 1}" + + def add_name(self, name: str): + number = self._names.get(name) + if number is not None: + raise RuntimeError("name duplicated: ", name) + self._names[name] = 1 + class NodeNamer(Namer): # current used for node-name which is also used as field of init-function and key of global_vars @@ -150,12 +179,43 @@ class SymbolTree: def _unique_targets(self, node: Node): new_targets: [Argument] = [] + if node.get_targets() is None: + return for target in node.get_targets(): assert isinstance(target, Argument) unique_target = self._target_namer.get_name(target.name) new_targets.append(Argument.create_naming_arg(unique_target, target.scope)) node.set_targets(new_targets) + def _update_args_for_unique(self, node: Node): + result: [Argument] = [] + if node.get_args() is None: + return + for arg in node.get_args(): + assert isinstance(arg, Argument) + assert arg.type != ArgType.CustomObjArg + if arg.type == ArgType.NamingArg: + # unique name + new_arg = Argument(arg.type, arg.scope, self._target_namer.get_real_arg(arg.name)) + result.append(new_arg) + else: + result.append(arg) + node.set_args(result) + + def _update_kwargs_for_unique(self, node: Node): + result: {str, Argument} = {} + if node.get_kwargs() is None: + return + for arg, value in node.get_kwargs(): + assert isinstance(value, Argument) + assert arg.type != ArgType.CustomObjArg + if arg.type == ArgType.NamingArg: + new_arg = Argument(value.type, value.scope, self._target_namer.get_real_arg(value.name)) + result[arg] = new_arg + else: + result[arg] = value + node.set_kwargs(result) + def _append_node2nodes(self, node: Node): node_name = node.get_name() if self._nodes.get(node_name) is not None: @@ -164,6 +224,11 @@ class SymbolTree: self._nodes[node_name] = node def _insert_node(self, position, node: Node): + # unique targets, name while insert node into symbol_tree + self._update_args_for_unique(node) + self._update_kwargs_for_unique(node) + self._unique_targets(node) + node.sync_to_ast() self._append_node2nodes(node) if position[0]: position[1].insert_before(node) @@ -181,26 +246,28 @@ class SymbolTree: raise RuntimeError("insert custom obj to init_func failed") return Node(NodeType.UserCustom, ast_node, {}, construct_targets, args, {}, field, None) - def _handle_custom_object_in_args(self, position, node: Node): + def _handle_custom_obj_in_args(self, position, node: Node): result: [Argument] = [] for arg in node.get_args(): assert isinstance(arg, Argument) if arg.type == ArgType.CustomObjArg: node = self._insert_blackbox_object_into_init(arg.name) self._insert_node(position, node) - result.append(Argument(ArgType.NamingArg, "", node.get_name())) + new_arg = self._target_namer.get_real_arg(node.get_name()) + result.append(Argument(ArgType.NamingArg, "self", new_arg)) else: result.append(arg) node.set_args(result) - def _handle_custom_object_in_kwargs(self, position, node: Node): + def _handle_custom_obj_in_kwargs(self, position, node: Node): result: {str, Argument} = {} for arg, value in node.get_kwargs(): assert isinstance(value, Argument) if value.type == ArgType.CustomObjArg: node = self._insert_blackbox_object_into_init(value.name) self._insert_node(position, node) - result[arg] = Argument(ArgType.NamingArg, "", node.get_name()) + new_arg = self._target_namer.get_real_arg(node.get_name()) + result[arg] = Argument(ArgType.NamingArg, "self", new_arg) else: result[arg] = value node.set_kwargs(result) @@ -239,9 +306,8 @@ class SymbolTree: raise RuntimeError("insert custom_node into init function ast tree failed.") self._global_vars[field] = node.get_op() # modify construct function - self._handle_custom_object_in_args(position, node) - self._handle_custom_object_in_kwargs(position, node) - self._unique_targets(node) + self._handle_custom_obj_in_args(position, node) + self._handle_custom_obj_in_kwargs(position, node) ast_node = AstModifier.insert_assign_to_function(self._construct_func_ast, targets=node.get_targets(), expr=Argument(ArgType.NamingArg, "self", field), args=node.get_args(), kwargs=node.get_kwargs(), @@ -249,21 +315,21 @@ class SymbolTree: if ast_node is None: raise RuntimeError("insert custom_node into construct function ast tree failed.") node.set_ast(ast_node) - # insert node self._insert_node(position, node) return node - def add_origin_field(self, position, origin_field_name: str, op, targets: [Argument], args: [Argument], - kwargs: {str: Argument}, ast_node: ast.AST, target_type: str = "", - attribute: {str, object}=None) -> Node: + def append_origin_field(self, origin_field_name: str, op, targets: [Argument], args: [Argument], + kwargs: {str: Argument}, ast_node: ast.AST, target_type: str = "", + attribute: {str, object}=None) -> Node: + if self._root is None: + raise RuntimeError("SymbolTree not inited") if attribute is None: attribute = {} assert len(origin_field_name) > 0 - # wait for resolve, use exist + add rather than get_name + # todo wait for resolve, use exist + add rather than get_name node_name = self._node_name_namer.get_name(origin_field_name) - node = Node(NodeType.UserCustom, ast_node, attribute, targets, args, kwargs, node_name, op) - self._unique_targets(node) - self._insert_node(position, node) + node = Node(NodeType.CallCell, ast_node, attribute, targets, args, kwargs, node_name, op) + self._insert_node((True, self._root), node) return node def _add_output_node(self, return_values: [str]) -> Node: @@ -305,6 +371,7 @@ class SymbolTree: if arg.name == name: raise RuntimeError("input duplicated: %s", name) self._inputs.append(Input(name, input_type, default)) + self._target_namer.add_name(name) def add_input_and_update_ast(self, name: str, input_type: Optional[type] = None, default: Optional[str] = None): -- Gitee From 7b89ff02fcc63307168460e07e58f22b936661c3 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 14 Feb 2022 11:09:38 +0800 Subject: [PATCH 15/32] add rename construct f assign target add rename construct f assign target add rename construct f assign target add rename construct f assign target --- .../rewrite_experiment/ast_parse/ast_parse.py | 4 +- .../ast_parse/ast_transformers/__init__.py | 3 +- .../rename_construct_f_assign_target.py | 120 ++++++++++++++++++ .../rewrite_experiment/common/ast_modifier.py | 8 ++ .../rewrite_experiment/example/test_lenet.py | 2 +- .../rewrite_experiment/symbol_tree.py | 3 + .../rewrite_experiment/symbol_tree_dumper.py | 4 +- 7 files changed, 138 insertions(+), 6 deletions(-) create mode 100644 mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/rename_construct_f_assign_target.py diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py index 227fc3dfc0e..956063b88dc 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py @@ -31,13 +31,13 @@ from .symbol_transformers.clear_non_symbol_in_stb import ClearNonSymbolInSTB from .symbol_transformers.rename_construct_function_assign_target import RenameConstructFuncAssignTarget from .resolver import Resolver from ..symbol_tree import SymbolTree -from .ast_transformers import FlattenRecursiveCall, FoldBinop +from .ast_transformers import FlattenRecursiveCall, FoldBinop, RenameConstructFAssignTarget class AstParse: @staticmethod def _ast_transform(ast_root: ast.AST) -> ast.AST: - transform_list = [FoldBinop(), FlattenRecursiveCall()] + transform_list = [FoldBinop(), FlattenRecursiveCall(), RenameConstructFAssignTarget()] for transformer in transform_list: ast_root = transformer.transform(ast_root) return ast_root diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py index 589e57574df..1bafdf38c7e 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py @@ -1,4 +1,5 @@ from .flatten_recursive_call import FlattenRecursiveCall from .const_fold import FoldBinop +from .rename_construct_f_assign_target import RenameConstructFAssignTarget -__all__ = ["FlattenRecursiveCall", "FoldBinop"] +__all__ = ["FlattenRecursiveCall", "FoldBinop", "RenameConstructFAssignTarget"] diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/rename_construct_f_assign_target.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/rename_construct_f_assign_target.py new file mode 100644 index 00000000000..da67d16e7e5 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/rename_construct_f_assign_target.py @@ -0,0 +1,120 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""rename construct fun assign target""" +import ast +from typing import Any +from mindspore import log as logger + + +class RenameConstructFAssignTarget(ast.NodeTransformer): + @staticmethod + def _init_last_target_with_arguments(node) -> {str, str}: + last_target: {str, str} = {} + arguments: ast.arguments = node.args + for arg in arguments.args: + if arg.arg == "self": + continue + if not isinstance(arg.arg, str): + raise RuntimeError("arg of arguments is not str") + last_target[arg.arg] = arg.arg + return last_target + + @staticmethod + def _generate_final_target(used_target_name: {str, int}, func_name: str): + used_number = used_target_name.get(func_name) + if used_number is None: + used_target_name[func_name] = 1 + else: + used_target_name[func_name] = used_number + 1 + func_name = func_name + "_" + str(used_number) + return func_name + + def _process_assign(self, body: ast.Assign, last_target: {str, str}, used_target_name: {str, int}): + call = body.value + if not isinstance(call, ast.Call): + return + keywords = call.keywords + if len(keywords) > 0: + raise RuntimeError("keyword in call are not supported") + args = call.args + if not isinstance(call.func, ast.Attribute): + raise RuntimeError("func of call is not ast.Attribute") + for i, arg in enumerate(args): + if not isinstance(arg, ast.Name): + raise RuntimeError("arg of call is not ast.Name") + new_arg_str = last_target.get(arg.id) + if new_arg_str is None: + raise RuntimeError(f"Undefined arg {arg.id} in args of call: {call.func.attr}") + body.value.args[i].id = new_arg_str + + if not isinstance(call.func.value, ast.Name): + raise RuntimeError("value of call.func is not ast.Name") + target_name = self._generate_final_target(used_target_name, "{}_{}".format(call.func.value.id, call.func.attr)) + targets = body.targets + if len(targets) == 1: + single_target = targets[0] + if not isinstance(single_target, ast.Name): + raise RuntimeError("target is not ast.Name") + last_target[body.targets[0].id] = target_name + body.targets[0].id = target_name + else: + for i in range(0, len(targets)): + single_target = targets[i] + if not isinstance(single_target, ast.Name): + raise RuntimeError("target is not ast.Name") + new_target_name = "{}_{}".format(target_name, i) + last_target[body.targets[i].id] = new_target_name + body.targets[i].id = new_target_name + + @staticmethod + def _process_return(body: ast.Return, last_target: {str, str}): + if isinstance(body.value, ast.Name): + if not isinstance(body.value.id, str): + raise RuntimeError("value of return is not str") + return_value = body.value.id + new_target_name = last_target.get(return_value) + if new_target_name is None: + raise RuntimeError(f"Undefined return value {return_value} in return statement") + body.value.id = new_target_name + elif isinstance(body.value, ast.Tuple): + for return_v in body.value.elts: + new_target_name = last_target.get(return_v.id) + if new_target_name is None: + raise RuntimeError(f"Undefined return value {return_v.id} in return statement") + return_v.id = new_target_name + else: + raise RuntimeError(f"return node type {type(body.value)} is not supported yet") + + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: + if node.name != "construct": + return node + + last_target: {str, str} = self._init_last_target_with_arguments(node) + used_target_name: {str, int} = {} + bodies = node.body + for body in bodies: + if isinstance(body, ast.Assign): + self._process_assign(body, last_target, used_target_name) + elif isinstance(body, ast.Return): + self._process_return(body, last_target) + else: + logger.warning(f"Ignoring {type(body)} in renameConstructFAssignTarget") + + return node + + def transform(self, ast_root): + ast_root = self.generic_visit(ast_root) + ast_root = ast.fix_missing_locations(ast_root) + return ast_root diff --git a/mindspore/python/mindspore/rewrite_experiment/common/ast_modifier.py b/mindspore/python/mindspore/rewrite_experiment/common/ast_modifier.py index 58b568c8e67..60862cc3d2b 100644 --- a/mindspore/python/mindspore/rewrite_experiment/common/ast_modifier.py +++ b/mindspore/python/mindspore/rewrite_experiment/common/ast_modifier.py @@ -153,3 +153,11 @@ class AstModifier(ast.NodeTransformer): result = ast.Call(func=ast_func, args=ast_args, keywords=keywords) ast.fix_missing_locations(result) return result + + @staticmethod + def set_output(construct_ast: ast.FunctionDef, output_value: str): + for body in construct_ast.body: + if isinstance(body, ast.Return) and isinstance(body.value, ast.Name): + body.value.id = output_value + return True + return False diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py index 706216fd785..f1840828028 100644 --- a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py @@ -34,7 +34,7 @@ def transform(rw: Rewrite): new_conv = nn.Conv2d(16, 16, 3) new_conv_node = Rewrite.create_node(new_conv, targets=[Argument.create_naming_arg('nx1')], target_type="", name='new_conv') - ret = rw.insert(position, new_conv_node, field='conv_new', args=[Argument.create_naming_arg('x')]) + ret = rw.insert(position, new_conv_node, field='conv_new', args=[Argument.create_naming_arg('self_max_pool2d_1')]) if ret is None: raise RuntimeError("add_cell failed") break diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index 9df4245b305..91c0c93c802 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -298,6 +298,9 @@ class SymbolTree: new_arg: Argument = self._root.get_args()[index] new_arg.name = return_value self._root.set_arg(new_arg, index) + + if not AstModifier.set_output(self._construct_func_ast, return_value): + raise RuntimeError("ast set output fail") return self._root def dump(self): diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree_dumper.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree_dumper.py index b7ca496abd7..a042963306f 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree_dumper.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree_dumper.py @@ -16,7 +16,7 @@ import inspect from mindspore import log as logger -from .node import Node +from .node import Node, NodeType from .common import Argument @@ -57,7 +57,7 @@ class SymbolTreeDumper: node: Node = self._symbol_tree.get_head_node().get_next() while node is not None: - if node.get_name() == "output": + if node.get_node_type() is NodeType.Output: self._dump_buffer += f" Return(%{node_no}) \n" self._dump_buffer += f" : (null) \n" self._dump_buffer += f" # In file {inspect.getfile(type(self._symbol_tree.get_origin_network()))}" -- Gitee From edfdd71618e8136014b9278d115ea4526ee5b430 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Mon, 14 Feb 2022 15:10:30 +0800 Subject: [PATCH 16/32] fix bug --- .../python/mindspore/rewrite_experiment/ast_parse/ast_parse.py | 2 +- .../ast_parse/symbol_transformers/const_symbol_fold.py | 3 ++- .../python/mindspore/rewrite_experiment/example/test_lenet.py | 2 +- mindspore/python/mindspore/rewrite_experiment/symbol_tree.py | 1 + 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py index 956063b88dc..bac806d7819 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py @@ -37,7 +37,7 @@ from .ast_transformers import FlattenRecursiveCall, FoldBinop, RenameConstructFA class AstParse: @staticmethod def _ast_transform(ast_root: ast.AST) -> ast.AST: - transform_list = [FoldBinop(), FlattenRecursiveCall(), RenameConstructFAssignTarget()] + transform_list = [FoldBinop(), FlattenRecursiveCall()] for transformer in transform_list: ast_root = transformer.transform(ast_root) return ast_root diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/const_symbol_fold.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/const_symbol_fold.py index 7d93698b16b..b2badf462a9 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/const_symbol_fold.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/const_symbol_fold.py @@ -24,7 +24,8 @@ from mindspore import log as logger class ConstSymbolFold(Processor): def __init__(self): - self._fold_fn_map = {SymbolType.bin_op: ConstSymbolFold.fold_binop_symbol, + self._fold_fn_map = { + # SymbolType.bin_op: ConstSymbolFold.fold_binop_symbol, # SymbolType.name: ConstSymbolFold.fold_name_symbol, SymbolType.attribute: ConstSymbolFold.fold_attribute_symbol, SymbolType.arg: ConstSymbolFold.fold_arg_symbol, diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py index c745eab5e80..a43eca9da85 100644 --- a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py @@ -20,7 +20,7 @@ class MyCell(nn.Cell): def print_code(title, rewrite): print(f"=========={title}===================================================================================") print(rewrite.get_code()) - rewrite.dump() + # rewrite.dump() def transform(rw: Rewrite): diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index ae18582c3c3..b5f216dfcd3 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -227,6 +227,7 @@ class SymbolTree: # unique targets, name while insert node into symbol_tree self._update_args_for_unique(node) self._update_kwargs_for_unique(node) + # _unique_targets must called after _update_args_for_unique and _update_kwargs_for_unique self._unique_targets(node) node.sync_to_ast() self._append_node2nodes(node) -- Gitee From 72e6055cf425969d65a949e4f9e21ed09aeec42d Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Mon, 14 Feb 2022 15:27:08 +0800 Subject: [PATCH 17/32] fix return value after nodes are updated --- .../rewrite_experiment/ast_parse/resolver.py | 23 +++-------- .../rewrite_experiment/ast_parse/symbol.py | 6 +-- .../rewrite_experiment/example/test_lenet.py | 4 +- .../mindspore/rewrite_experiment/node.py | 39 ++++++++++++++++--- .../rewrite_experiment/symbol_tree.py | 39 ++++++++++--------- 5 files changed, 64 insertions(+), 47 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py index df77d35ff5b..a8742548f8b 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py @@ -60,7 +60,6 @@ class Resolver: def _find_module_symbol(self) -> Optional[ModuleSymbol]: for k, v in self._stb.items(): symbol = v.value() - # todo check is subclass of cell if isinstance(symbol, ModuleSymbol): return symbol return None @@ -129,15 +128,6 @@ class Resolver: symbol_tree.add_input(arguments.value().value) # resolve body bodies = symbol_construct_fn.get_bodies() - # update return node first - for body in bodies: - if body is None: - continue - if body.symbol_type() == SymbolType.return_: - symbol_tree.set_output_ast(body.symbol_ast) - ret = symbol_tree.update_output([body.get_return_str()]) - if ret is None: - raise RuntimeError("update_output failed") for body in bodies: if body is None: continue @@ -153,12 +143,13 @@ class Resolver: raise RuntimeError("kwargs in construct function assign is unsupported") obj, attributes = self._get_symbol_attribute(call.func_name()) - # todo call attribute_resolver to resolve attributes # todo func should be field of corresponding init-ast - ret = symbol_tree.append_origin_field(func, obj, [target], call_args, {}, body.symbol_ast, "", - attributes) - if ret is None: - raise RuntimeError("append_origin_field failed: ") + symbol_tree.append_origin_field(func, obj, [target], call_args, {}, body.symbol_ast, "", attributes) + for body in bodies: + if body is None: + continue + if body.symbol_type() == SymbolType.return_: + symbol_tree.append_return(body.symbol_ast, body.get_return_strs()) @staticmethod def get_object_attribute(obj): @@ -221,7 +212,6 @@ class Resolver: @staticmethod def _replace_ori_field_of_init_func(origin_field: ast.AST): if not isinstance(origin_field, ast.Assign): - # todo remove this ast node return assign: ast.Assign = origin_field if len(assign.targets) != 1: @@ -233,7 +223,6 @@ class Resolver: raise RuntimeError("Only support target.value in ast.Name now!") target_value: ast.Name = target.value if target_value.id != "self": - # todo remove this ast node return field_name = target.attr assign.value = ast.Call(ast.Name('getattr', ast.Load()), diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol.py index a98749a80f7..e23894632c4 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol.py @@ -688,12 +688,12 @@ class ReturnSymbol(Symbol): def get_return_value(self) -> Value: return self._value.value() - def get_return_str(self) -> str: + def get_return_strs(self) -> [str]: value = self._value.value() if value.value_type == ValueType.String: - return value.value + return [value.value] elif isinstance(value, NameSymbol): - return value.get_name_name().value + return [value.get_name_name().value] else: raise RuntimeError("Return value should be a StringValue or NameSymbol") diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py index a43eca9da85..6096901d7a4 100644 --- a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py @@ -36,7 +36,7 @@ def transform(rw: Rewrite): new_conv = nn.Conv2d(16, 16, 3) new_conv_node = Rewrite.create_node(new_conv, targets=[Argument.create_naming_arg('x_1')], target_type="", name='new_conv') - ret = rw.insert(position, new_conv_node, field='conv_new', args=[Argument.create_naming_arg('self_max_pool2d_1')]) + ret = rw.insert(position, new_conv_node, field='conv_new', args=[Argument.create_naming_arg('self_max_po')]) if ret is None: raise RuntimeError("add_cell failed") break @@ -82,7 +82,7 @@ def transform(rw: Rewrite): raise RuntimeError("Update arg failed") break print_code("after update_arg", rw) - rw.set_output(0, "nx2") + rw.set_output(0, "x_9") print_code("after add_output", rw) for _, node in rw.nodes().items(): targets = node.get_targets() diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index 03dfb8ed98a..a66e4720627 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -90,11 +90,7 @@ class Node: def get_ast(self) -> Optional[ast.AST]: return self._ast_node - def sync_to_ast(self): - if self._ast_node is None: - return - if self._node_type != NodeType.CallCell: - return + def _sync_assign_node_to_ast(self): assign_ast = self._ast_node assert isinstance(assign_ast, ast.Assign) # update targets @@ -119,7 +115,8 @@ class Node: arg_ast.id = arg.name elif isinstance(arg_ast, ast.Attribute): arg_value_ast = arg_ast.value - assert isinstance(arg_value_ast, ast.Name) + if not isinstance(arg_value_ast, ast.Name): + raise RuntimeError("Only support ast.Name as argument: ", arg_value_ast) arg_value_ast.id = arg.scope arg_ast.attr = arg.name else: @@ -128,6 +125,36 @@ class Node: if len(self._kwargs) != 0: raise RuntimeError("kwargs is not unsupported now") + def _sync_return_node_to_ast(self): + return_ast = self._ast_node + assert isinstance(return_ast, ast.Return) + # update args + return_value_ast = return_ast.value + if isinstance(return_value_ast, ast.Name): + assert len(self._args) == 1 + return_value_ast.id = self._args[0].name + elif isinstance(return_value_ast, ast.Tuple): + elements = return_value_ast.elts + assert len(self._args) == len(elements) + for i in range(0, len(elements)): + ele = elements[i] + if not isinstance(ele, ast.Name): + raise RuntimeError("Only support ast.Name as return value: ", ele) + arg = self._args[i] + assert isinstance(arg, Argument) + ele.id = arg.name + else: + raise RuntimeError("Unsupported return value type: ", return_value_ast) + + def sync_to_ast(self): + if self._ast_node is None: + return + if self._node_type == NodeType.CallCell: + self._sync_assign_node_to_ast() + elif self._node_type == NodeType.Output: + self._sync_return_node_to_ast() + + def set_ast(self, ast_node: ast.AST): assert isinstance(ast_node, ast.AST) self._ast_node = ast_node diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index b5f216dfcd3..543f5d8eedc 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -59,9 +59,6 @@ class Namer: self._names[origin_name] = number + 1 return f"{origin_name}_{number}" - def exist(self, name: str) -> bool: - return self._names.get(name) is not None - def get_real_arg(self, origin_arg: str) -> str: num = self._names.get(origin_arg) if num is None or num == 1: @@ -333,6 +330,26 @@ class SymbolTree: self._insert_node((True, self._root), node) return node + def append_return(self, return_ast: ast.AST, return_values: [str]) -> Node: + if self._root is None: + raise RuntimeError("SymbolTree not inited") + self._root.set_ast(return_ast) + self.update_output(return_values) + return self._root + + def update_output(self, return_values: [str]) -> Node: + if self._root is None: + raise RuntimeError("SymbolTree not inited") + if len(return_values) == 0: + raise RuntimeError("return_values should at least has one element") + unique_return_values = [] + for return_value in return_values: + unique_return_values.append(self._target_namer.get_real_arg(return_value)) + unique_return_args = self._convert_strs_to_name_arguments(unique_return_values) + self._root.set_args(unique_return_args) + self._root.sync_to_ast() + return self._root + def _add_output_node(self, return_values: [str]) -> Node: real_return_values = self._convert_strs_to_name_arguments(return_values) node_name = self._node_name_namer.get_name("return") @@ -384,22 +401,6 @@ class SymbolTree: raise RuntimeError("insert_argument_to_function failed") self._inputs.append(Input(name, input_type, default)) - # todo update ast - def update_output(self, return_values: [str]) -> Node: - if self._root is None: - raise RuntimeError("SymbolTree not inited") - if len(return_values) == 0: - raise RuntimeError("return_values should at least has one element") - real_return_values = self._convert_strs_to_name_arguments(return_values) - self._root.set_args(real_return_values) - return self._root - - def set_output_ast(self, ast_node: ast.AST) -> Node: - if self._root is None: - raise RuntimeError("SymbolTree not inited") - self._root.set_ast(ast_node) - return self._root - # todo update ast def add_output(self, return_value: str, index: Optional[int] = None) -> Node: if self._root is None: -- Gitee From 7023a6f7e3e69bcc9bb69fc0e3771179dc182a5b Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Tue, 15 Feb 2022 09:15:08 +0800 Subject: [PATCH 18/32] update namespace --- .../mindspore/rewrite_experiment/ast_parse/namespace.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/namespace.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/namespace.py index 32952847417..6a5cefaa614 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/namespace.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/namespace.py @@ -77,6 +77,13 @@ class Namespace(): logger.warning(f"get namespace failed, func_name: {symbole_name}") return None, None, True + def is_leaf_node(self, cls_name): + if cls_name in self.ms_common_ns or cls_name in self.ms_nn_ns or cls_name in self.ms_ops_ns or \ + cls_name in self.ms_ops_c_ns or cls_name in self.ms_ops_c_multitype_ns or cls_name in self.ms_ops_p_ns: + return True + else: + return False + class NameDepot: """ Base name depot for symbol. -- Gitee From df88d99a5f354b10524329abade3f64c7e3b725a Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Tue, 15 Feb 2022 09:17:11 +0800 Subject: [PATCH 19/32] update for leaf node --- .../mindspore/rewrite_experiment/ast_parse/resolver.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py index 98b6c674b64..1c8df458e57 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py @@ -9,7 +9,7 @@ from .symbol import FunctionSymbol, SymbolType, ClassSymbol, ArgumentsSymbol, As from .value import ValueType, ImmValue, Value from ..symbol_tree import SymbolTree from ..common import Argument, AstModifier - +from .namespace import Namespace class Resolver: def __init__(self, net, stb: SymbolTable): @@ -19,6 +19,7 @@ class Resolver: self._ori_cls_name = type(net).__name__ self._opt_cls_name = self._ori_cls_name + "Opt" self._module_symbol = self._find_module_symbol() + self._namespace = Namespace() def resolve(self) -> SymbolTree: # process module ast @@ -43,6 +44,7 @@ class Resolver: self._symbol_trees.append(self._create_symbol_tree(class_symbol, net)) def _create_symbol_tree(self, class_symbol: ClassSymbol, net) -> SymbolTree: + self._namespace.update_global_namespace(self._origin_net) # process class ast self._process_class_ast(class_symbol.symbol_ast) # process __init__ function in class ast @@ -144,6 +146,11 @@ class Resolver: raise RuntimeError("kwargs in construct function assign is unsupported") obj, attributes = self._get_symbol_attribute(call.func_name()) + if not obj: + raise RuntimeError("get symbole attribute failed") + print("cls name: ", attributes["cls"].__name__) + if not self._namespace.is_leaf_node(attributes["cls"].__name__): + raise NotImplemented("maybe a subgraph") # todo call attribute_resolver to resolve attributes ret = symbol_tree.add_origin_field((True, symbol_tree.get_root()), func, obj, [target], call_args, {}, body.symbol_ast, attribute=attributes) -- Gitee From b84fd26cf935c43278ed6ccbad171c1238c2b51e Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Mon, 14 Feb 2022 17:21:21 +0800 Subject: [PATCH 20/32] sync ast from node --- .../rewrite_experiment/common/argument.py | 12 ++ .../rewrite_experiment/common/ast_modifier.py | 31 +-- .../rewrite_experiment/example/test_lenet.py | 92 ++++----- .../mindspore/rewrite_experiment/node.py | 189 +++++++++++------- .../mindspore/rewrite_experiment/rewrite.py | 22 +- .../rewrite_experiment/symbol_tree.py | 179 ++++++++--------- 6 files changed, 284 insertions(+), 241 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/common/argument.py b/mindspore/python/mindspore/rewrite_experiment/common/argument.py index 5b4fa417393..8dd675c2cc6 100644 --- a/mindspore/python/mindspore/rewrite_experiment/common/argument.py +++ b/mindspore/python/mindspore/rewrite_experiment/common/argument.py @@ -49,6 +49,13 @@ class Argument: def create_custom_arg(cls, obj) -> 'Argument': return cls(ArgType.CustomObjArg, "", obj) + @staticmethod + def create_name_arguments(strs: [str]) -> ['Argument']: + result = [] + for string in strs: + result.append(Argument.create_naming_arg(string)) + return result + def __str__(self): if self.type in (ArgType.IntArg, ArgType.FloatArg, ArgType.StringArg): return str(self.name) @@ -59,5 +66,10 @@ class Argument: else: return f"Illegal ArgType: {str(self.type)}" + def __eq__(self, other): + if id(self) == id(other): + return True + return self.type == other.type and self.scope == other.scope and self.name == other.name + def __repr__(self): return str(self) diff --git a/mindspore/python/mindspore/rewrite_experiment/common/ast_modifier.py b/mindspore/python/mindspore/rewrite_experiment/common/ast_modifier.py index 60862cc3d2b..2e785dba2e5 100644 --- a/mindspore/python/mindspore/rewrite_experiment/common/ast_modifier.py +++ b/mindspore/python/mindspore/rewrite_experiment/common/ast_modifier.py @@ -68,32 +68,35 @@ class AstModifier(ast.NodeTransformer): @staticmethod def insert_assign_to_function(ast_func: ast.FunctionDef, targets: [Argument], expr: Argument, args: [Argument] = None, kwargs: {str, Argument}=None, - index_ast: Optional[ast.AST] = None, insert_before=True) -> Optional[ast.AST]: + index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST: assign = AstModifier.create_assign(targets, expr, args, kwargs) - if assign is None: - logger.warning("create_assign failed") - return None + return AstModifier.insert_assign_ast_to_function(ast_func, assign, index_ast, insert_before) + + @staticmethod + def insert_assign_ast_to_function(ast_func: ast.FunctionDef, ast_assign: ast.Assign, + index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST: if index_ast is None: - ast_func.body.append(assign) - return assign + ast_func.body.append(ast_assign) + ast.fix_missing_locations(ast_func) + return ast_assign else: for index in range(0, len(ast_func.body)): if id(ast_func.body[index]) == id(index_ast): if insert_before: - ast_func.body.insert(index, assign) + ast_func.body.insert(index, ast_assign) else: - ast_func.body.insert(index + 1, assign) - return assign - return None + ast_func.body.insert(index + 1, ast_assign) + ast.fix_missing_locations(ast_func) + return ast_assign + raise RuntimeError("index_ast is not contained in ast_func") @staticmethod - def insert_global_vars_expr_to_init(init_func: ast.FunctionDef, targets: [Argument], args: [Argument]) -> \ - Optional[ast.AST]: + def insert_global_vars_expr_to_init(init_func: ast.FunctionDef, targets: [Argument], args: [Argument]) -> ast.AST: return AstModifier.insert_assign_to_function(init_func, targets=targets, args=args, expr=Argument(ArgType.NamingArg, "global_vars", "get")) @staticmethod - def create_assign(targets: [Argument], expr: Argument, args: [Argument], kwargs: {str, Argument}): + def create_assign(targets: [Argument], expr: Argument, args: [Argument], kwargs: {str, Argument}) -> ast.Assign: if targets is None or len(targets) != 1: raise RuntimeError("Only support one target in insert_cell_to_init now") if targets[0].type != ArgType.NamingArg: @@ -108,7 +111,7 @@ class AstModifier(ast.NodeTransformer): return result @staticmethod - def create_call(expr: Argument, args: [Argument] = None, kwargs: {str: Argument} = None) -> Optional[ast.Call]: + def create_call(expr: Argument, args: [Argument] = None, kwargs: {str: Argument} = None) -> ast.Call: if expr.type == ArgType.CustomObjArg: raise RuntimeError("Please handle custom-object first") has_arg = args is not None or kwargs is not None diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py index 6096901d7a4..b1389a645b3 100644 --- a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py @@ -12,8 +12,9 @@ class MyCell(nn.Cell): super().__init__() self.conv = nn.Dense(5, 16) - def construct(self, x): + def construct(self, x, y): x = self.conv(x) + x = mindspore.ops.Add()(x, y) return x @@ -23,78 +24,63 @@ def print_code(title, rewrite): # rewrite.dump() -def transform(rw: Rewrite): +def add_conv_before_flatten(rewrite: Rewrite): new_conv_node = None - for _, node in rw.nodes().items(): - targets = node.get_targets() - if targets is None: - continue - assert targets[0].type == ArgType.NamingArg - target = str(targets[0]) - if target == "x_7": - position = rw.before(node) + for _, node in rewrite.nodes().items(): + if node.get_op_type() == mindspore.nn.Flatten: + position = rewrite.before(node) new_conv = nn.Conv2d(16, 16, 3) new_conv_node = Rewrite.create_node(new_conv, targets=[Argument.create_naming_arg('x_1')], target_type="", - name='new_conv') - ret = rw.insert(position, new_conv_node, field='conv_new', args=[Argument.create_naming_arg('self_max_po')]) - if ret is None: - raise RuntimeError("add_cell failed") + name='new_conv', args=[Argument.create_naming_arg('self_max_po')]) + rewrite.insert(position, new_conv_node, field='conv_new') break - print_code("after add_cell", rw) - for _, node in rw.nodes().items(): + if new_conv_node is not None: + for _, node in rewrite.nodes().items(): + if node.get_op_type() == mindspore.nn.Flatten: + inputs = node.get_inputs() + assert len(inputs) == 1 + rewrite.set_arg_by_node(new_conv_node, 0, inputs[0]) + + +def add_my_cell_after_x_11(rewrite: Rewrite): + for _, node in rewrite.nodes().items(): targets = node.get_targets() if targets is None: continue assert targets[0].type == ArgType.NamingArg target = str(targets[0]) - if target == "x_9": - position = rw.before(node) + if target == "x_11": + position = rewrite.after(node) custom_cell = MyCell() + bias = Tensor(1, mindspore.int32) new_custom_node = Rewrite.create_node(custom_cell, targets=[Argument.create_naming_arg('nx2')], - target_type="", args=[Argument.create_naming_arg('nx3')], + target_type="", args=[Argument.create_naming_arg('nx3'), + Argument.create_custom_arg(bias)], name='my_cell') - ret = rw.insert(position, new_custom_node, field='my_cell') - if ret is None: - raise RuntimeError("add_custom_node failed") + rewrite.insert(position, new_custom_node, field='my_cell') + rewrite.set_arg(new_custom_node, 0, "x_11") break - print_code("after add_custom_node", rw) - for _, node in rw.nodes().items(): - targets = node.get_targets() - if targets is None: - continue - assert targets[0].type == ArgType.NamingArg - target = str(targets[0]) - if target == "x_11": - ret = rw.update_arg(node, 0, "x_6") - if not ret: - raise RuntimeError("Update arg failed") - break - if new_conv_node is not None: - for _, node in rw.nodes().items(): - targets = node.get_targets() - if targets is None: - continue - assert targets[0].type == ArgType.NamingArg - target = str(targets[0]) - if target == "nx2": - ret = rw.update_arg_by_node(node, 0, new_conv_node) - if not ret: - raise RuntimeError("Update arg failed") - break - print_code("after update_arg", rw) - rw.set_output(0, "x_9") - print_code("after add_output", rw) - for _, node in rw.nodes().items(): + + +def erase_node_x_10(rewrite: Rewrite): + for _, node in rewrite.nodes().items(): targets = node.get_targets() if targets is None: continue assert targets[0].type == ArgType.NamingArg target = str(targets[0]) if target == "x_10": - ret = rw.erase_node(node) - if ret is None: - raise RuntimeError("erase_node failed") + rewrite.set_output(0, "x_9") + rewrite.erase_node(node) break + + +def transform(rw: Rewrite): + add_conv_before_flatten(rw) + print_code("after add_conv_before_flatten", rw) + add_my_cell_after_x_11(rw) + print_code("after add_my_cell_after_x_11", rw) + erase_node_x_10(rw) print_code("after erase_node", rw) diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index a66e4720627..b605885e1cd 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -19,7 +19,8 @@ from enum import Enum from typing import Optional, Union from .common import Argument, ArgType, AstModifier -from ..nn import Cell +from mindspore.nn import Cell +from mindspore import log as logger global_vars_name = "global_vars" origin_network_field_name = "_handler" @@ -63,6 +64,29 @@ class Node: self._next: Optional[Node] = None self._update_inputs() + @staticmethod + def _handle_custom_obj_in_args(args: [Argument]) -> [Argument]: + result = [] + for arg in args: + assert isinstance(arg, Argument) + if arg.type == ArgType.CustomObjArg: + logger.warning("custom-object exist in args, should be replace before compile") + result.append(Argument.create_naming_arg("custom-object", "self")) + else: + result.append(arg) + return result + + @staticmethod + def _handle_custom_obj_in_kwargs(kwargs: {str: Argument}) -> {str: Argument}: + result: {str, Argument} = {} + for arg, value in kwargs: + assert isinstance(value, Argument) + if value.type == ArgType.CustomObjArg: + result[arg] = Argument.create_naming_arg("custom-object", "self") + else: + result[arg] = value + return result + @classmethod def create_by_cell(cls, cell: Cell, ast_node: Optional[ast.AST], attributes: {str: object}, targets: [Argument], target_type: str = "", args: [Argument] = None, kwargs: {str: Argument} = None, name: str = ""): @@ -70,9 +94,19 @@ class Node: args = [] if kwargs is None: kwargs = {} + non_custom_args = Node._handle_custom_obj_in_args(args) + non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs) assert attributes is not None + if ast_node is None: + ast_node = AstModifier.create_assign(targets, Argument.create_naming_arg(name, "self"), non_custom_args, + non_custom_kwargs) return cls(NodeType.CallCell, ast_node, attributes, targets, args, kwargs, name, cell) + @classmethod + def create_return_node(cls, return_values: [str]): + real_return_values = Argument.create_name_arguments(return_values) + return cls(NodeType.Output, None, {}, None, real_return_values, {}, "return", None) + def get_prev(self) -> 'Node': return self._prev @@ -90,7 +124,27 @@ class Node: def get_ast(self) -> Optional[ast.AST]: return self._ast_node - def _sync_assign_node_to_ast(self): + def set_ast(self, ast_node: ast.AST): + assert isinstance(ast_node, ast.AST) + self._ast_node = ast_node + + def _sync_assign_func_to_ast(self): + if self._ast_node is None: + return + assign_ast = self._ast_node + assert isinstance(assign_ast, ast.Assign) + call_ast = assign_ast.value + assert isinstance(call_ast, ast.Call) + func_ast = call_ast.func + if isinstance(func_ast, ast.Name): + func_ast.id = self._name + elif isinstance(func_ast, ast.Attribute): + func_ast.attr = self._name + ast.fix_missing_locations(assign_ast) + + def _sync_assign_targets_to_ast(self): + if self._ast_node is None: + return assign_ast = self._ast_node assert isinstance(assign_ast, ast.Assign) # update targets @@ -101,10 +155,19 @@ class Node: target_ast = targets_ast[i] assert isinstance(target_ast, ast.Name) target_ast.id = target.name + ast.fix_missing_locations(assign_ast) + + def _sync_assign_args_to_ast(self): + if self._ast_node is None: + return + assign_ast = self._ast_node + assert isinstance(assign_ast, ast.Assign) # update args call_ast = assign_ast.value assert isinstance(call_ast, ast.Call) args_ast = call_ast.args + if len(self._args) != len(args_ast): + astpretty.pprint(call_ast) assert len(self._args) == len(args_ast) for i in range(0, len(self._args)): arg = self._args[i] @@ -121,11 +184,23 @@ class Node: arg_ast.attr = arg.name else: raise RuntimeError("Unsupported arg type: ", arg_ast) - # update kwargs - if len(self._kwargs) != 0: + ast.fix_missing_locations(assign_ast) + + def _sync_assign_kwargs_to_ast(self): + if self._ast_node is None: + return + if len(self._kwargs) > 0: raise RuntimeError("kwargs is not unsupported now") + def _sync_assign_node_to_ast(self): + self._sync_assign_targets_to_ast() + self._sync_assign_func_to_ast() + self._sync_assign_args_to_ast() + self._sync_assign_kwargs_to_ast() + def _sync_return_node_to_ast(self): + if self._ast_node is None: + return return_ast = self._ast_node assert isinstance(return_ast, ast.Return) # update args @@ -145,19 +220,7 @@ class Node: ele.id = arg.name else: raise RuntimeError("Unsupported return value type: ", return_value_ast) - - def sync_to_ast(self): - if self._ast_node is None: - return - if self._node_type == NodeType.CallCell: - self._sync_assign_node_to_ast() - elif self._node_type == NodeType.Output: - self._sync_return_node_to_ast() - - - def set_ast(self, ast_node: ast.AST): - assert isinstance(ast_node, ast.AST) - self._ast_node = ast_node + ast.fix_missing_locations(return_ast) def _update_inputs(self): pass @@ -193,55 +256,6 @@ class Node: if origin_next is not None: origin_next._prev = node - def update_arg(self, index: int, arg: str) -> bool: - if self._ast_node is None or not isinstance(self._ast_node, ast.Assign): - return False - ret = AstModifier.update_argument_for_call_assign(self._ast_node, index, arg, None) - if not ret: - return False - if index >= len(self._args): - return False - self._args[index] = arg - self._update_inputs() - return True - - def update_arg_by_node(self, arg_idx: int, node: 'Node', out_idx: Optional[int] = None) -> bool: - assert isinstance(node, Node) - if self._ast_node is None or not isinstance(self._ast_node, ast.Assign): - return False - if arg_idx >= len(self._args): - return False - if out_idx is None: - assert len(node._targets) == 1 - new_arg = str(node._targets[0]) - ret = AstModifier.update_argument_for_call_assign(self._ast_node, arg_idx, new_arg, None) - if not ret: - return False - self._args[arg_idx] = new_arg - else: - assert out_idx < len(node._targets) - new_arg = str(node._targets[out_idx]) - ret = AstModifier.update_argument_for_call_assign(self._ast_node, arg_idx, new_arg, None) - if not ret: - return False - self._args[arg_idx] = new_arg - self._update_inputs() - return True - - def update_args(self, args: [str]): - self._call_args = args - self._update_inputs() - - def update_kwargs_by_name(self, key: str, value: str): - self._kwargs[key] = value - self._update_inputs() - - def update_kwargs_by_node(self, key: str, node: 'Node', out_idx: int): - assert isinstance(node, Node) - assert out_idx < len(node._targets) - self._kwargs[key] = node._targets[out_idx] - self._update_inputs() - def normalize_args(self): # todo merge args kwargs default_args into normalize_args pass @@ -249,17 +263,25 @@ class Node: def get_inputs(self) -> ['Node']: return self._inputs + def set_inputs(self, inputs: ['Node']): + self._inputs = inputs + def get_targets(self) -> [Argument]: return self._targets + # todo can only be called before node been inserted into symbol-tree def set_targets(self, targets: [Argument]): self._targets = targets + if self._node_type == NodeType.CallCell: + self._sync_assign_targets_to_ast() def get_name(self) -> str: return self._name def set_name(self, name: str): self._name = name + if self._node_type == NodeType.CallCell: + self._sync_assign_func_to_ast() def get_node_type(self) -> NodeType: return self._node_type @@ -273,12 +295,39 @@ class Node: def get_args(self) -> [Argument]: return self._args - def set_arg(self, arg: Argument, index: int): + # todo update inputs + def set_arg_by_node(self, arg_idx: int, node: 'Node', out_idx: Optional[int] = None): + assert isinstance(node, Node) + if arg_idx >= len(self._args): + raise RuntimeError("arg_idx out of range of node args: ", len(self._args)) + if out_idx is None: + if len(node._targets) != 1: + raise RuntimeError("node should has one output when out_idx is not provided") + out_idx = 0 + assert out_idx < len(node._targets) + new_arg = node._targets[out_idx] + self._args[arg_idx] = new_arg + if self._node_type == NodeType.CallCell: + self._sync_assign_args_to_ast() + elif self._node_type == NodeType.Output: + self._sync_return_node_to_ast() + + def set_arg(self, arg: Union[Argument, str], index: int): assert index < len(self._args) + if isinstance(arg, str): + arg = Argument.create_naming_arg(arg) self._args[index] = arg + if self._node_type == NodeType.CallCell: + self._sync_assign_args_to_ast() + elif self._node_type == NodeType.Output: + self._sync_return_node_to_ast() def set_args(self, args: [Argument]): self._args = args + if self._node_type == NodeType.CallCell: + self._sync_assign_args_to_ast() + elif self._node_type == NodeType.Output: + self._sync_return_node_to_ast() def add_arg(self, arg: Argument, index: Optional[int] = None): if index is None: @@ -286,12 +335,18 @@ class Node: else: assert index <= len(self._args) self._args.insert(index, arg) + if self._node_type == NodeType.CallCell: + self._sync_assign_args_to_ast() + elif self._node_type == NodeType.Output: + self._sync_return_node_to_ast() def get_kwargs(self) -> {str: Argument}: return self._kwargs def set_kwargs(self, kwargs: {str: Argument}): self._kwargs = kwargs + if self._node_type == NodeType.CallCell: + self._sync_assign_kwargs_to_ast() def set_attribute(self, key: str, value): self._attribute[key] = value diff --git a/mindspore/python/mindspore/rewrite_experiment/rewrite.py b/mindspore/python/mindspore/rewrite_experiment/rewrite.py index 2b15cb8275d..02c788c3d01 100644 --- a/mindspore/python/mindspore/rewrite_experiment/rewrite.py +++ b/mindspore/python/mindspore/rewrite_experiment/rewrite.py @@ -47,24 +47,17 @@ class Rewrite: kwargs: {str: Argument} = None, name: str = None) -> Node: # 'targets': 'target_type' = self.'field'(*'args', **'kwargs') if isinstance(op, Cell): - # todo create ast from op: ast can be created when insert into symbol_tree - ast_node = None # todo resolve attributes from op attributes = {} - return Node.create_by_cell(op, ast_node, attributes, targets, target_type, args, kwargs, name) + return Node.create_by_cell(op, None, attributes, targets, target_type, args, kwargs, name) elif isinstance(op, Primitive): raise RuntimeError("Primitive op will be support in near future.") else: raise RuntimeError("Only support Cell op or Primitive op!") # todo define class position: use symbol_tree, node and before/after as position - def insert(self, position, node: Node, field: str = None, args: [Argument] = None, kwargs: {str: Argument} = None) \ - -> Node: + def insert(self, position, node: Node, field: str = None) -> Node: # self.'field': 'custom_obj_type' = global_vars.get('field') - if args is not None: - node.set_args(args) - if kwargs is not None: - node.set_kwargs(kwargs) return self._symbol_tree.insert_node(position, node, field) def erase_node(self, node_or_name: Union[Node, str]) -> Node: @@ -73,12 +66,13 @@ class Rewrite: def set_output(self, index: int, return_value: str) -> Node: return self._symbol_tree.set_output(return_value, index) - def update_arg(self, node: Node, index: int, arg: str) -> bool: - return node.update_arg(index, arg) + @staticmethod + def set_arg(node: Node, index: int, arg: Union[Argument, str]): + node.set_arg(arg, index) - def update_arg_by_node(self, node_to_update: Node, arg_idx: int, node_to_link: Node, - out_idx: Optional[int] = None) -> bool: - return node_to_update.update_arg_by_node(arg_idx, node_to_link, out_idx) + @staticmethod + def set_arg_by_node(node_to_update: Node, arg_idx: int, node_to_link: Node, out_idx: Optional[int] = None): + node_to_update.set_arg_by_node(arg_idx, node_to_link, out_idx) def dump(self): self._symbol_tree.dump() diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index 543f5d8eedc..b36b5db1fd4 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -127,6 +127,7 @@ class SymbolTree: self._opt_cls_name = self._ori_cls_name + "Opt" self._origin_network = origin_network # init unique-namers + self._init_target_namer = Namer() self._target_namer = Namer() self._node_name_namer = NodeNamer() # root must be output of graph @@ -220,63 +221,76 @@ class SymbolTree: node) self._nodes[node_name] = node + # todo optimize + def _find_inputs_of_node(self, node: Node): + result = [] + args = node.get_args() + if args is None: + return result + for name in self._nodes: + exist_node: Node = self._nodes[name] + targets = exist_node.get_targets() + if targets is None: + continue + flag = False + for target in targets: + if flag: + break + for arg in args: + if target.__eq__(arg): + result.append(exist_node) + flag = True + break + return result + def _insert_node(self, position, node: Node): # unique targets, name while insert node into symbol_tree self._update_args_for_unique(node) self._update_kwargs_for_unique(node) # _unique_targets must called after _update_args_for_unique and _update_kwargs_for_unique self._unique_targets(node) - node.sync_to_ast() + # find inputs + inputs = self._find_inputs_of_node(node) + node.set_inputs(inputs) + self._append_node2nodes(node) if position[0]: position[1].insert_before(node) else: position[1].insert_after(node) - def _insert_blackbox_object_into_init(self, object) -> Node: + def _insert_blackbox_object_into_init(self, object) -> str: field = self._node_name_namer.get_name(f"var_{type(object).__name__}") self._global_vars[field] = object init_targets = [Argument.create_naming_arg(field, "self")] - construct_targets = [Argument.create_naming_arg(field)] + # construct_targets = [Argument.create_naming_arg(field)] args = [Argument.create_imm_arg(field)] - ast_node = AstModifier.insert_global_vars_expr_to_init(self._init_func_ast, init_targets, args) - if ast_node is None: - raise RuntimeError("insert custom obj to init_func failed") - return Node(NodeType.UserCustom, ast_node, {}, construct_targets, args, {}, field, None) + AstModifier.insert_global_vars_expr_to_init(self._init_func_ast, init_targets, args) + return field + # return Node(NodeType.UserCustom, ast_node, {}, construct_targets, args, {}, field, None) - def _handle_custom_obj_in_args(self, position, node: Node): + def _handle_custom_obj_in_args(self, node: Node): result: [Argument] = [] for arg in node.get_args(): assert isinstance(arg, Argument) if arg.type == ArgType.CustomObjArg: - node = self._insert_blackbox_object_into_init(arg.name) - self._insert_node(position, node) - new_arg = self._target_namer.get_real_arg(node.get_name()) - result.append(Argument(ArgType.NamingArg, "self", new_arg)) + field = self._insert_blackbox_object_into_init(arg.name) + result.append(Argument.create_naming_arg(field, "self")) else: result.append(arg) node.set_args(result) - def _handle_custom_obj_in_kwargs(self, position, node: Node): + def _handle_custom_obj_in_kwargs(self, node: Node): result: {str, Argument} = {} for arg, value in node.get_kwargs(): assert isinstance(value, Argument) if value.type == ArgType.CustomObjArg: - node = self._insert_blackbox_object_into_init(value.name) - self._insert_node(position, node) - new_arg = self._target_namer.get_real_arg(node.get_name()) - result[arg] = Argument(ArgType.NamingArg, "self", new_arg) + field = self._insert_blackbox_object_into_init(value.name) + result[arg] = Argument.create_naming_arg(field, "self") else: result[arg] = value node.set_kwargs(result) - @staticmethod - def _convert_strs_to_name_arguments(targets: [str]) -> [Argument]: - result = [] - for target in targets: - result.append(Argument(ArgType.NamingArg, "", target)) - return result - # todo move into ast_modifier def _find_node_index(self, node: Node) -> Optional[int]: for i in range(0, len(self._construct_func_ast.body)): @@ -291,34 +305,29 @@ class SymbolTree: if index is None: raise RuntimeError("index is not None: ", position[1].get_name()) # modify init function - if field is None: + if field is None or field == "": field = self._node_name_namer.get_name(node) else: field = self._node_name_namer.get_name(field) node.set_name(field) - ast_node = AstModifier.insert_assign_to_function(self._init_func_ast, - targets=[Argument(ArgType.NamingArg, "self", field)], - expr=Argument(ArgType.NamingArg, "global_vars", "get"), - args=[Argument(ArgType.StringArg, "", field)]) - if ast_node is None: - raise RuntimeError("insert custom_node into init function ast tree failed.") + AstModifier.insert_assign_to_function(self._init_func_ast, targets=[Argument(ArgType.NamingArg, "self", field)], + expr=Argument(ArgType.NamingArg, "global_vars", "get"), + args=[Argument(ArgType.StringArg, "", field)]) self._global_vars[field] = node.get_op() # modify construct function - self._handle_custom_obj_in_args(position, node) - self._handle_custom_obj_in_kwargs(position, node) - ast_node = AstModifier.insert_assign_to_function(self._construct_func_ast, targets=node.get_targets(), - expr=Argument(ArgType.NamingArg, "self", field), - args=node.get_args(), kwargs=node.get_kwargs(), - index_ast=position[1].get_ast(), insert_before=position[0]) - if ast_node is None: - raise RuntimeError("insert custom_node into construct function ast tree failed.") - node.set_ast(ast_node) + self._handle_custom_obj_in_args(node) + self._handle_custom_obj_in_kwargs(node) + node_ast = node.get_ast() + if not isinstance(node_ast, ast.Assign): + raise RuntimeError("Only support insert cell op now") + AstModifier.insert_assign_ast_to_function(self._construct_func_ast, node_ast, position[1].get_ast(), + position[0]) self._insert_node(position, node) return node def append_origin_field(self, origin_field_name: str, op, targets: [Argument], args: [Argument], kwargs: {str: Argument}, ast_node: ast.AST, target_type: str = "", - attribute: {str, object}=None) -> Node: + attribute: {str, object} = None) -> Node: if self._root is None: raise RuntimeError("SymbolTree not inited") if attribute is None: @@ -326,7 +335,7 @@ class SymbolTree: assert len(origin_field_name) > 0 # todo wait for resolve, use exist + add rather than get_name node_name = self._node_name_namer.get_name(origin_field_name) - node = Node(NodeType.CallCell, ast_node, attribute, targets, args, kwargs, node_name, op) + node = Node.create_by_cell(op, ast_node, attribute, targets, target_type, args, kwargs, node_name) self._insert_node((True, self._root), node) return node @@ -334,10 +343,20 @@ class SymbolTree: if self._root is None: raise RuntimeError("SymbolTree not inited") self._root.set_ast(return_ast) - self.update_output(return_values) + self.set_outputs(return_values) + return self._root + + def set_output(self, return_value: str, index: int) -> Node: + if self._root is None: + raise RuntimeError("SymbolTree not inited") + if index >= len(self._root.get_args()): + raise RuntimeError("index(%d) out of range(%d)", index, len(self._root.get_targets())) + new_arg: Argument = self._root.get_args()[index] + new_arg.name = self._target_namer.get_real_arg(return_value) + self._root.set_arg(new_arg, index) return self._root - def update_output(self, return_values: [str]) -> Node: + def set_outputs(self, return_values: [str]) -> Node: if self._root is None: raise RuntimeError("SymbolTree not inited") if len(return_values) == 0: @@ -345,15 +364,13 @@ class SymbolTree: unique_return_values = [] for return_value in return_values: unique_return_values.append(self._target_namer.get_real_arg(return_value)) - unique_return_args = self._convert_strs_to_name_arguments(unique_return_values) + unique_return_args = Argument.create_name_arguments(unique_return_values) self._root.set_args(unique_return_args) - self._root.sync_to_ast() return self._root def _add_output_node(self, return_values: [str]) -> Node: - real_return_values = self._convert_strs_to_name_arguments(return_values) - node_name = self._node_name_namer.get_name("return") - node = Node(NodeType.Output, None, {}, None, real_return_values, {}, node_name, None) + node = Node.create_return_node(return_values) + self._node_name_namer.add_name(node.get_name()) self._append_node2nodes(node) return node @@ -365,6 +382,14 @@ class SymbolTree: self._insert_node((True, self._root), node) return node + # todo update ast + def add_input(self, name: str, input_type: Optional[type] = None, default: Optional[str] = None): + for arg in self._inputs: + if arg.name == name: + raise RuntimeError("input duplicated: %s", name) + self._inputs.append(Input(name, input_type, default)) + self._target_namer.add_name(name) + # can only erase isolated node def erase_node(self, node_or_name: Union[Node, str]) -> Node: if isinstance(node_or_name, Node): @@ -383,50 +408,18 @@ class SymbolTree: break return node - # todo update ast - def add_input(self, name: str, input_type: Optional[type] = None, default: Optional[str] = None): - for arg in self._inputs: - if arg.name == name: - raise RuntimeError("input duplicated: %s", name) - self._inputs.append(Input(name, input_type, default)) - self._target_namer.add_name(name) - - def add_input_and_update_ast(self, name: str, input_type: Optional[type] = None, - default: Optional[str] = None): - for arg in self._inputs: - if arg.name == name: - raise RuntimeError("input duplicated: %s", name) - ret = AstModifier.insert_argument_to_function(self._construct_func_ast, name, default) - if not ret: - raise RuntimeError("insert_argument_to_function failed") - self._inputs.append(Input(name, input_type, default)) - - # todo update ast - def add_output(self, return_value: str, index: Optional[int] = None) -> Node: - if self._root is None: - raise RuntimeError("SymbolTree not inited") - new_return_value = Argument(ArgType.NamingArg, "", return_value) - if index is None: - self._root.add_arg(new_return_value) - else: - if index > len(self._root.get_args()): - raise RuntimeError("index(%d) out of range(%d)", index, len(self._root.get_targets())) - self._root.add_arg(new_return_value, index) - return self._root - - # todo update ast - def set_output(self, return_value: str, index: int) -> Node: - if self._root is None: - raise RuntimeError("SymbolTree not inited") - if index >= len(self._root.get_args()): - raise RuntimeError("index(%d) out of range(%d)", index, len(self._root.get_targets())) - new_arg: Argument = self._root.get_args()[index] - new_arg.name = return_value - self._root.set_arg(new_arg, index) - - if not AstModifier.set_output(self._construct_func_ast, return_value): - raise RuntimeError("ast set output fail") - return self._root + # # todo update ast + # def add_output(self, return_value: str, index: Optional[int] = None) -> Node: + # if self._root is None: + # raise RuntimeError("SymbolTree not inited") + # new_return_value = Argument(ArgType.NamingArg, "", return_value) + # if index is None: + # self._root.add_arg(new_return_value) + # else: + # if index > len(self._root.get_args()): + # raise RuntimeError("index(%d) out of range(%d)", index, len(self._root.get_targets())) + # self._root.add_arg(new_return_value, index) + # return self._root def dump(self): dump_st = SymbolTreeDumper(self) -- Gitee From 1970e5c695936fc8ff154099c61948de6d439ed8 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Tue, 15 Feb 2022 14:09:10 +0800 Subject: [PATCH 21/32] remove NodeType.UserCustom --- .../rewrite_experiment/common/node_info.py | 64 ------------------- .../mindspore/rewrite_experiment/node.py | 12 ++-- .../rewrite_experiment/symbol_tree.py | 15 +---- 3 files changed, 6 insertions(+), 85 deletions(-) delete mode 100644 mindspore/python/mindspore/rewrite_experiment/common/node_info.py diff --git a/mindspore/python/mindspore/rewrite_experiment/common/node_info.py b/mindspore/python/mindspore/rewrite_experiment/common/node_info.py deleted file mode 100644 index 77ba72fc2e1..00000000000 --- a/mindspore/python/mindspore/rewrite_experiment/common/node_info.py +++ /dev/null @@ -1,64 +0,0 @@ -from enum import Enum - -from typing import Optional -from .argument import Argument, ArgType - - -class NodeInfo: - def __init__(self, node_type: NodeType, field: str = None, field_type: Optional[type] = None, func: Argument = None, - construct_args: [Argument] = None, construct_kwargs: {str: Argument} = None, targets: [str] = None, - target_type: Optional[type] = None, call_args: [Argument] = None, call_kwargs: {str: Argument} = None, - extra: {} = None): - self.node_type = node_type - self.field: str = field - self.field_type: Optional[type] = field_type - assert func is not None - self.func = func - if construct_args is None: - self.construct_args: [Argument] = [] - else: - self.construct_args: [Argument] = construct_args - if construct_kwargs is None: - self.construct_kwargs: {str: Argument} = {} - else: - self.construct_kwargs: {str: Argument} = construct_kwargs - if targets is None: - self.targets: [str] = None - else: - self.targets: [str] = targets - self.target_type: Optional[type] = target_type - if call_args is None: - self.call_args: [Argument] = [] - else: - self.call_args: [Argument] = call_args - if call_kwargs is None: - self.call_kwargs: {str: Argument} = {} - else: - self.call_kwargs: {str: Argument} = call_kwargs - if extra is None: - self.extra = {} - else: - self.extra = extra - - @classmethod - def create_cell_info(cls, field: str = None, cell_type: Optional[type] = None, construct_args: [Argument] = None, - construct_kwargs: {str: Argument} = None, targets: [str] = None, - target_type: Optional[type] = None, call_args: [Argument] = None, - call_kwargs: {str: Argument} = None) -> 'NodeInfo': - class_name = cell_type.__name__ - return cls(NodeType.CallCell, field, cell_type, Argument(ArgType.NamingArg, "mindspore.nn", class_name), - construct_args, construct_kwargs, targets, target_type, call_args, call_kwargs) - - @classmethod - def create_function_info(cls, field: str = None, func_type: Optional[type] = None, targets: [str] = None, - target_type: Optional[type] = None, call_args: [Argument] = None, - call_kwargs: {str: Argument} = None) -> 'NodeInfo': - return cls(NodeType.CallFunction, field, func_type, None, [], {}, targets, target_type, call_args, call_kwargs) - - @classmethod - def create_object_info(cls, object, field: str = None, obj_type: Optional[type] = None, targets: [str] = None, - target_type: Optional[type] = None, call_args: [Argument] = None, - call_kwargs: {str: Argument} = None) -> 'NodeInfo': - return cls(NodeType.UserCustom, field, obj_type, Argument(ArgType.NamingArg, "global_vars", "get"), - [Argument(ArgType.StringArg, "", field)], {}, targets, target_type, call_args, call_kwargs, - {"object": object}) diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index b605885e1cd..ea8fc937ce1 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -34,18 +34,16 @@ class NodeType(Enum): CallMethod = 2 # method in cell CallFunction = 3 # subclass of primitive - UserCustom = 4 # todo remove - Input = 5 - Output = 6 - Graph = 7 + Input = 4 + Output = 5 + Graph = 6 class Node: def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], attributes: {str: object}, targets: [Argument], args: [Argument], kwargs: {str: Argument}, name: str, op): - if node_type not in {NodeType.CallCell, NodeType.Output, NodeType.UserCustom, NodeType.CallMethod, - NodeType.Unknown}: - raise RuntimeError("Only support CallCell, UserCustom, CallMethod and Output now") + if node_type not in {NodeType.CallCell, NodeType.Output, NodeType.CallMethod, NodeType.Unknown}: + raise RuntimeError("Only support CallCell, CallMethod and Output now") assert attributes is not None self._node_type: NodeType = node_type self._ast_node: Optional[ast.AST] = ast_node diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index b36b5db1fd4..4157eb0f5c8 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -127,9 +127,9 @@ class SymbolTree: self._opt_cls_name = self._ori_cls_name + "Opt" self._origin_network = origin_network # init unique-namers - self._init_target_namer = Namer() self._target_namer = Namer() self._node_name_namer = NodeNamer() + self._node_name_namer.add_name(origin_network_key) # root must be output of graph self._root = self._add_output_node(["undefined"]) # head must be the first statement but must not be inputs of graph @@ -408,19 +408,6 @@ class SymbolTree: break return node - # # todo update ast - # def add_output(self, return_value: str, index: Optional[int] = None) -> Node: - # if self._root is None: - # raise RuntimeError("SymbolTree not inited") - # new_return_value = Argument(ArgType.NamingArg, "", return_value) - # if index is None: - # self._root.add_arg(new_return_value) - # else: - # if index > len(self._root.get_args()): - # raise RuntimeError("index(%d) out of range(%d)", index, len(self._root.get_targets())) - # self._root.add_arg(new_return_value, index) - # return self._root - def dump(self): dump_st = SymbolTreeDumper(self) dump_st.dump_symbol_tree() -- Gitee From 14f88e81eda3e85ebc18ae47c3a7d09573be8af1 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Tue, 15 Feb 2022 16:40:58 +0800 Subject: [PATCH 22/32] update rewrite interface --- .../rewrite_experiment/example/test_lenet.py | 4 +-- .../mindspore/rewrite_experiment/rewrite.py | 15 ++++---- .../rewrite_experiment/symbol_tree.py | 34 +++++++++++++++---- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py index b1389a645b3..96cb2cd4ee0 100644 --- a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet.py @@ -32,7 +32,7 @@ def add_conv_before_flatten(rewrite: Rewrite): new_conv = nn.Conv2d(16, 16, 3) new_conv_node = Rewrite.create_node(new_conv, targets=[Argument.create_naming_arg('x_1')], target_type="", name='new_conv', args=[Argument.create_naming_arg('self_max_po')]) - rewrite.insert(position, new_conv_node, field='conv_new') + rewrite.insert(position, new_conv_node) break if new_conv_node is not None: for _, node in rewrite.nodes().items(): @@ -57,7 +57,7 @@ def add_my_cell_after_x_11(rewrite: Rewrite): target_type="", args=[Argument.create_naming_arg('nx3'), Argument.create_custom_arg(bias)], name='my_cell') - rewrite.insert(position, new_custom_node, field='my_cell') + rewrite.insert(position, new_custom_node) rewrite.set_arg(new_custom_node, 0, "x_11") break diff --git a/mindspore/python/mindspore/rewrite_experiment/rewrite.py b/mindspore/python/mindspore/rewrite_experiment/rewrite.py index 02c788c3d01..17d2cc596f5 100644 --- a/mindspore/python/mindspore/rewrite_experiment/rewrite.py +++ b/mindspore/python/mindspore/rewrite_experiment/rewrite.py @@ -56,9 +56,9 @@ class Rewrite: raise RuntimeError("Only support Cell op or Primitive op!") # todo define class position: use symbol_tree, node and before/after as position - def insert(self, position, node: Node, field: str = None) -> Node: + def insert(self, position, node: Node) -> Node: # self.'field': 'custom_obj_type' = global_vars.get('field') - return self._symbol_tree.insert_node(position, node, field) + return self._symbol_tree.insert_node(position, node) def erase_node(self, node_or_name: Union[Node, str]) -> Node: return self._symbol_tree.erase_node(node_or_name) @@ -66,13 +66,12 @@ class Rewrite: def set_output(self, index: int, return_value: str) -> Node: return self._symbol_tree.set_output(return_value, index) - @staticmethod - def set_arg(node: Node, index: int, arg: Union[Argument, str]): - node.set_arg(arg, index) + def set_arg(self, node: Union[Node, str], index: int, arg: Union[Argument, str]): + self._symbol_tree.set_node_arg(node, index, arg) - @staticmethod - def set_arg_by_node(node_to_update: Node, arg_idx: int, node_to_link: Node, out_idx: Optional[int] = None): - node_to_update.set_arg_by_node(arg_idx, node_to_link, out_idx) + def set_arg_by_node(self, node_to_update: Union[Node, str], arg_idx: int, node_to_link: Node, + out_idx: Optional[int] = None): + self._symbol_tree.set_node_arg_by_node(node_to_update, arg_idx, node_to_link, out_idx) def dump(self): self._symbol_tree.dump() diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index 4157eb0f5c8..0f6de155db8 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -299,16 +299,13 @@ class SymbolTree: return i return None - def insert_node(self, position, node: Node, field: str = None) -> Node: + def insert_node(self, position, node: Node) -> Node: # check position index = self._find_node_index(position[1]) if index is None: raise RuntimeError("index is not None: ", position[1].get_name()) # modify init function - if field is None or field == "": - field = self._node_name_namer.get_name(node) - else: - field = self._node_name_namer.get_name(field) + field = self._node_name_namer.get_name(node) node.set_name(field) AstModifier.insert_assign_to_function(self._init_func_ast, targets=[Argument(ArgType.NamingArg, "self", field)], expr=Argument(ArgType.NamingArg, "global_vars", "get"), @@ -390,7 +387,7 @@ class SymbolTree: self._inputs.append(Input(name, input_type, default)) self._target_namer.add_name(name) - # can only erase isolated node + # todo can only erase isolated node def erase_node(self, node_or_name: Union[Node, str]) -> Node: if isinstance(node_or_name, Node): node = node_or_name @@ -408,6 +405,31 @@ class SymbolTree: break return node + def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[Argument, str]): + if isinstance(node, str): + node = self._nodes.get(str) + if node is None: + raise RuntimeError("node is None:", node) + assert isinstance(node, Node) + node.set_arg(arg, index) + + def set_node_arg_by_node(self, node_to_update: Union[Node, str], arg_idx: int, node_to_link: Node, + out_idx: Optional[int] = None): + if isinstance(node_to_update, str): + node_to_update = self._nodes.get(str) + if node_to_update is None: + raise RuntimeError("node is None:", node_to_update) + assert isinstance(node_to_update, Node) + node_to_update.set_arg_by_node(arg_idx, node_to_link, out_idx) + + # todo + def get_node_inputs(self, node: Node) -> [Node]: + raise NotImplementedError + + # todo + def get_node_users(self, node: Union[Node, str]) -> [Node]: + raise NotImplementedError + def dump(self): dump_st = SymbolTreeDumper(self) dump_st.dump_symbol_tree() -- Gitee From 93b8f56a09e961794b02897c50d4e88115ec57b4 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 15 Feb 2022 10:23:47 +0800 Subject: [PATCH 23/32] flatten return flatten return flatten return & add return tuple parser flatten return & add return tuple parser flatten return & add return tuple parser flatten return & add return tuple parser flatten return & add return tuple parser --- .../rewrite_experiment/ast_parse/__init__.py | 1 + .../rewrite_experiment/ast_parse/ast_parse.py | 4 +- .../ast_parse/ast_transformers/__init__.py | 4 +- ...sive_call.py => flatten_recursive_stmt.py} | 56 ++++++++++++--- .../rewrite_experiment/ast_parse/namespace.py | 2 +- .../ast_parse/parsers/return_parser.py | 11 +-- .../ast_parse/parsers/tuple_parser.py | 42 +++++++++++ .../rewrite_experiment/ast_parse/symbol.py | 45 ++++++++++++ .../const_symbol_propagate.py | 3 + .../example/test_lenet_multi_output.py | 70 +++++++++++++++++++ 10 files changed, 217 insertions(+), 21 deletions(-) rename mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/{flatten_recursive_call.py => flatten_recursive_stmt.py} (49%) create mode 100644 mindspore/python/mindspore/rewrite_experiment/ast_parse/parsers/tuple_parser.py create mode 100644 mindspore/python/mindspore/rewrite_experiment/example/test_lenet_multi_output.py diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/__init__.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/__init__.py index e36d4d95653..a3021aaae10 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/__init__.py @@ -11,6 +11,7 @@ from .parsers.call_parser import CallParser from .parsers.attribute_parser import AttributeParser from .parsers.binop_parser import BinOpParser from .parsers.return_parser import ReturnParser +from .parsers.tuple_parser import TupleParser from .parsers.keyword_parser import KeywordParser from .parsers.constant_parser import ConstantParser from .parsers.name_parser import NameParser diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py index bac806d7819..9df5cd04b35 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_parse.py @@ -31,13 +31,13 @@ from .symbol_transformers.clear_non_symbol_in_stb import ClearNonSymbolInSTB from .symbol_transformers.rename_construct_function_assign_target import RenameConstructFuncAssignTarget from .resolver import Resolver from ..symbol_tree import SymbolTree -from .ast_transformers import FlattenRecursiveCall, FoldBinop, RenameConstructFAssignTarget +from .ast_transformers import FlattenRecursiveStmt, FoldBinop, RenameConstructFAssignTarget class AstParse: @staticmethod def _ast_transform(ast_root: ast.AST) -> ast.AST: - transform_list = [FoldBinop(), FlattenRecursiveCall()] + transform_list = [FoldBinop(), FlattenRecursiveStmt()] for transformer in transform_list: ast_root = transformer.transform(ast_root) return ast_root diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py index 1bafdf38c7e..881daea3d73 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/__init__.py @@ -1,5 +1,5 @@ -from .flatten_recursive_call import FlattenRecursiveCall +from .flatten_recursive_stmt import FlattenRecursiveStmt from .const_fold import FoldBinop from .rename_construct_f_assign_target import RenameConstructFAssignTarget -__all__ = ["FlattenRecursiveCall", "FoldBinop", "RenameConstructFAssignTarget"] +__all__ = ["FlattenRecursiveStmt", "FoldBinop", "RenameConstructFAssignTarget"] diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_stmt.py similarity index 49% rename from mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py rename to mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_stmt.py index f9aad92d1df..cdc71aee7c7 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_call.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/ast_transformers/flatten_recursive_stmt.py @@ -16,11 +16,21 @@ import ast from ast import FunctionDef from typing import Any +from mindspore import log as logger -class FlattenRecursiveCall(ast.NodeTransformer): +class FlattenRecursiveStmt(ast.NodeTransformer): @staticmethod - def _flatten_call_value(node: FunctionDef, call_value: ast.expr, new_target_names, function_index): + def _generate_target_name(candidate_name, new_target_names): + target_name = f"self_{candidate_name}" + suffix = 0 + while target_name in new_target_names: + suffix += 1 + target_name = f"self_{candidate_name}_{suffix}" + new_target_names.append(target_name) + return target_name + + def _flatten_call_value(self, node: FunctionDef, call_value: ast.expr, new_target_names, function_index): if not isinstance(call_value, ast.Call): return False @@ -32,15 +42,7 @@ class FlattenRecursiveCall(ast.NodeTransformer): if not isinstance(arg.func, ast.Attribute): raise RuntimeError("func of call can only support ast.attribute now") - assert isinstance(arg.func.attr, str) - target_name = f"self_{arg.func.attr}" - - suffix = 0 - while target_name in new_target_names: - suffix += 1 - target_name = f"self_{arg.func.attr}_{suffix}" - new_target_names.append(target_name) - + target_name = self._generate_target_name(arg.func.attr, new_target_names) new_assign_node = ast.Assign(targets=[ast.Name(id=target_name, ctx=ast.Store())], value=arg) node.body.insert(function_index, new_assign_node) args.pop(call_arg_index) @@ -48,12 +50,40 @@ class FlattenRecursiveCall(ast.NodeTransformer): return True return False + def _flatten_return_value(self, node: FunctionDef, return_node: ast.Return, new_target_names, function_index): + return_value = return_node.value + if not isinstance(return_value, ast.Call): + return False + + if not isinstance(return_value.func, ast.Attribute): + raise RuntimeError("func of call can only support ast.attribute now") + target_name = self._generate_target_name(return_value.func.attr, new_target_names) + new_assign_node = ast.Assign(targets=[ast.Name(id=target_name, ctx=ast.Store())], value=return_value) + node.body.insert(function_index, new_assign_node) + return_node.value = ast.Name(id=target_name, ctx=ast.Load()) + return True + + @staticmethod + def _fill_in_original_target_names(target_names, node): + for function_index in range(len(node.body)): + child = node.body[function_index] + if not isinstance(child, ast.Assign): + continue + targets = child.targets + for target in targets: + if not isinstance(target, ast.Name): + raise RuntimeError("currently only support ast.Name targets") + target_name = target.id + if target_name not in target_names: + target_names.append(target_name) + def visit_FunctionDef(self, node: FunctionDef) -> Any: if node.name != "construct": return node changed = True new_target_names = [] + self._fill_in_original_target_names(new_target_names, node) while changed: changed = False for function_index in range(len(node.body) - 1, -1, -1): @@ -61,6 +91,10 @@ class FlattenRecursiveCall(ast.NodeTransformer): if isinstance(child, ast.Assign): call_value = child.value changed = changed or self._flatten_call_value(node, call_value, new_target_names, function_index) + elif isinstance(child, ast.Return): + changed = changed or self._flatten_return_value(node, child, new_target_names, function_index) + else: + logger.warning(f"ignoring body type {type(child)} in flatten_recursive_stmt") return node def transform(self, ast_root): diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/namespace.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/namespace.py index 6a5cefaa614..80c4d576c43 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/namespace.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/namespace.py @@ -21,7 +21,7 @@ from types import FunctionType import mindspore.nn as nn from mindspore import log as logger -from ..._extends.parse.namespace import CellNamespace,ClosureNamespace, ClassAttrNamespace, ClassMemberNamespace +from ..._extends.parse.namespace import CellNamespace,ClosureNamespace class Namespace(): def __init__(self): diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/parsers/return_parser.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/parsers/return_parser.py index 7ca39926287..2826d8f59bb 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/parsers/return_parser.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/parsers/return_parser.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,9 +31,10 @@ class ReturnParser(Processor): return_symbol = ReturnSymbol(ast_return, symbol.scope, symbol.symbol_name) return_symbol.set_symbol_location(symbol.get_symbol_location()) symbol_node.set_value(return_symbol) - + new_symbols = [symbol_node] value = ast_return.value - if not isinstance(value, ast.Name): - raise RuntimeError("Only support value of ast.Return in (ast.Name,). got: ", type(value).__name__) + if not isinstance(value, (ast.Name, ast.Tuple)): + raise RuntimeError("Only support value of ast.Return in (ast.Name, ast.Tuple). got: ", type(value).__name__) value_node = return_symbol.set_return_value(Symbol(ast_node=value, symbol_type=SymbolType.name)) - return [symbol_node, value_node] + new_symbols.append(value_node) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/parsers/tuple_parser.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/parsers/tuple_parser.py new file mode 100644 index 00000000000..0a9fa9400a6 --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/parsers/tuple_parser.py @@ -0,0 +1,42 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast + +from ..value_node import ValueNode +from ..symbol import Symbol, SymbolType, TupleSymbol +from ..processor import Processor +from ..parser_register import ParserRegister + + +@ParserRegister.reg_parser +class TupleParser(Processor): + def process(self, symbol_node: ValueNode) -> [ValueNode]: + symbol = symbol_node.value() + if not isinstance(symbol.symbol_ast, ast.Tuple): + return [symbol_node] + ast_tuple: ast.Tuple = symbol.symbol_ast + + tuple_symbol = TupleSymbol(ast_tuple, symbol.scope, symbol.symbol_name) + tuple_symbol.set_symbol_location(symbol.get_symbol_location()) + symbol_node.set_value(tuple_symbol) + new_symbols = [symbol_node] + + elts = ast_tuple.elts + for elt in elts: + if not isinstance(elt, ast.Name): + raise RuntimeError("Only support elt of ast.Tuple in (ast.Name). got: ", type(elt).__name__) + elt_node = tuple_symbol.add_tuple_elts(Symbol(ast_node=elt, symbol_type=SymbolType.name)) + new_symbols.append(elt_node) + return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol.py index e23894632c4..49120c3e1e5 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol.py @@ -694,6 +694,11 @@ class ReturnSymbol(Symbol): return [value.value] elif isinstance(value, NameSymbol): return [value.get_name_name().value] + elif isinstance(value, TupleSymbol): + return_values = [] + for elt in value.get_tuple_elts(): + return_values.append(elt.get_name_name().value) + return return_values else: raise RuntimeError("Return value should be a StringValue or NameSymbol") @@ -701,6 +706,46 @@ class ReturnSymbol(Symbol): return f"ReturnSymbol({self.get_full_name_with_scope()}): {self._value.value}" +class TupleSymbol(Symbol): + def __init__(self, ast_node: Optional[ast.AST], scope, symbol_name): + super().__init__(ast_node, scope, symbol_name, SymbolType.return_) + self._elts: [ValueNode] = [] + self.finish_compile() + + def add_tuple_elts(self, symbol: Value, index: Optional[int] = None) -> ValueNode: + assert isinstance(symbol, Value) + symbol_node = ValueNode(symbol) + self._add_child(symbol_node) + if index is None: + self._elts.append(symbol_node) + else: + assert index < len(self._elts) + self._elts.insert(index, symbol_node) + return symbol_node + + def set_tuple_elt(self, index: int, symbol: Value) -> ValueNode: + assert isinstance(symbol, Value) + assert index < len(self._elts) + self._elts[index].set_value(symbol) + self._update_child_state(symbol) + return self._elts[index] + + def set_tuple_elts(self, elts: [Value]): + assert len(elts) == len(self._elts) + for i in range(0, len(elts)): + self._update_child_state(elts[i]) + self._elts[i].set_value(elts[i]) + + def get_tuple_elts(self) -> [Value]: + elts: [Value] = [] + for elt in self._elts: + elts.append(elt.value()) + return elts + + def __str__(self): + return f"TupleSymbol({self.get_full_name_with_scope()})" + + class UnaryOpSymbol(Symbol): def __init__(self, ast_node: Optional[ast.AST], scope, symbol_name, op): super().__init__(ast_node, scope, symbol_name, SymbolType.unary_op) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/const_symbol_propagate.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/const_symbol_propagate.py index c2084c85286..d4edec49d43 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/const_symbol_propagate.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/symbol_transformers/const_symbol_propagate.py @@ -39,6 +39,9 @@ class ConstSymbolPropagate(Processor): @staticmethod def symbol_propagate(symbol: Symbol, single_children: [str], list_children: [str]): for single_child in single_children: + if not hasattr(symbol, single_child): + logger.debug(f"symbol has no attr {single_child}") + continue value_node: ValueNode = getattr(symbol, single_child) value = value_node.value() if isinstance(value, ConstantSymbol): diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet_multi_output.py b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet_multi_output.py new file mode 100644 index 00000000000..22db4cabeff --- /dev/null +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet_multi_output.py @@ -0,0 +1,70 @@ +from mindspore.rewrite_experiment import Rewrite, Argument, ArgType +from mindspore import nn +from mindspore.common.initializer import Normal + + +class LeNet5(nn.Cell): + def __init__(self, num_class=10, num_channel=1): + super(LeNet5, self).__init__() + self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) + self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) + self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class LeNet2(nn.Cell): + def __init__(self, num_class=10, num_channel=1): + super(LeNet2, self).__init__() + self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) + self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) + self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +def print_code(title, rewrite): + print(f"=========={title}===================================================================================") + print(rewrite.get_code()) + + +if __name__ == '__main__': + lenet = LeNet5(10) + rewrite = Rewrite(lenet) + print_code("after resolve", rewrite) + + lenet = LeNet2(10) + rewrite = Rewrite(lenet) + print_code("after resolve", rewrite) -- Gitee From d1255c1d4533d337f1c5a78a53f26a1fcd52a6ce Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 16 Feb 2022 10:46:05 +0800 Subject: [PATCH 24/32] test_multi_output --- .../rewrite_experiment/example/test_lenet_multi_output.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet_multi_output.py b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet_multi_output.py index 22db4cabeff..430e3ea7b25 100644 --- a/mindspore/python/mindspore/rewrite_experiment/example/test_lenet_multi_output.py +++ b/mindspore/python/mindspore/rewrite_experiment/example/test_lenet_multi_output.py @@ -26,7 +26,7 @@ class LeNet5(nn.Cell): x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.fc3(x) - return x + return x, x class LeNet2(nn.Cell): @@ -52,7 +52,7 @@ class LeNet2(nn.Cell): x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.fc3(x) - return x + return self.fc3(x) def print_code(title, rewrite): -- Gitee From ba940d889e3d07666bf4f010c8b2992a4d60467d Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Wed, 16 Feb 2022 11:02:44 +0800 Subject: [PATCH 25/32] parse new node attribute --- .../python/mindspore/rewrite_experiment/ast_parse/resolver.py | 2 +- mindspore/python/mindspore/rewrite_experiment/node.py | 3 +++ mindspore/python/mindspore/rewrite_experiment/symbol_tree.py | 3 +++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py index 6cbd8d03721..ac6e0043d30 100644 --- a/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py +++ b/mindspore/python/mindspore/rewrite_experiment/ast_parse/resolver.py @@ -164,6 +164,7 @@ class Resolver: if k.startswith("_"): continue attributes[k] = v + attributes["cls"] = obj.__class__ return attributes def _get_symbol_attribute(self, symbol_name): @@ -173,7 +174,6 @@ class Resolver: for key, value in var_dict["_cells"].items(): if key == symbol_name: attributes = Resolver.get_object_attribute(value) - attributes["cls"] = value.__class__ logger.debug(f"key: {key}, attributes: {attributes}") return value, attributes return None, attributes diff --git a/mindspore/python/mindspore/rewrite_experiment/node.py b/mindspore/python/mindspore/rewrite_experiment/node.py index b605885e1cd..5154b82adce 100644 --- a/mindspore/python/mindspore/rewrite_experiment/node.py +++ b/mindspore/python/mindspore/rewrite_experiment/node.py @@ -351,6 +351,9 @@ class Node: def set_attribute(self, key: str, value): self._attribute[key] = value + def set_attributes(self, attributes): + self._attribute = attributes + def get_attributes(self): return self._attribute diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py index b36b5db1fd4..e1af0b9b27d 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_tree.py @@ -322,6 +322,9 @@ class SymbolTree: raise RuntimeError("Only support insert cell op now") AstModifier.insert_assign_ast_to_function(self._construct_func_ast, node_ast, position[1].get_ast(), position[0]) + from .ast_parse.resolver import Resolver + attributes = Resolver.get_object_attribute(node.get_op()) + node.set_attributes(attributes) self._insert_node(position, node) return node -- Gitee From 1730b69a0e2c7407814f80ad42c9f22e45bbb0dc Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Tue, 25 Jan 2022 15:24:11 +0800 Subject: [PATCH 26/32] QAT support lenet --- .../default_qat/default_fake_quantizer.py | 37 ++++++++++--------- .../default_qat/default_layer_policy.py | 16 ++++---- .../golden_stick/quantization/layer_policy.py | 4 +- .../quantization/quant_aware_training.py | 5 ++- .../golden_stick/quantization/quant_utils.py | 5 ++- .../quantization/quantize_wrapper_cell.py | 37 +++++++++---------- mindspore/python/mindspore/nn/layer/quant.py | 20 +++++++--- .../rewrite/pattern_engine_for_cell.py | 2 +- 8 files changed, 68 insertions(+), 58 deletions(-) diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py index 2fd0cacfbff..14dd88eb157 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py @@ -15,6 +15,7 @@ """DefaultQuantizeOp.""" from functools import partial +import mindspore from ..fake_quantizer import FakeQuantizer from ..quant_utils import compute_kl_threshold from mindspore.ops.operations import _quant_ops as Q @@ -39,11 +40,13 @@ class DefaultFakeQuantizerPerLayer(FakeQuantizer): 2. run fake quant execution to simulate the quantize loss """ - def __init__(self, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False): + def __init__(self, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, num_bits=8, quant_delay=0): super(DefaultFakeQuantizerPerLayer, self).__init__() self._ema = ema self._ema_decay = ema_decay self._symmetric = symmetric + self._num_bits = num_bits + self._quant_delay = quant_delay self._narraw_range = narrow_range self._min_max_update_func = partial(Q.MinMaxUpdatePerLayer, ema=self._ema, ema_decay=self._ema_decay) self._is_ascend = context.get_context("device_target") == "Ascend" @@ -54,19 +57,19 @@ class DefaultFakeQuantizerPerLayer(FakeQuantizer): def _init_fake_quant_func(self, quant_func): if self.is_ascend: - self._fake_quant_train = quant_func(num_bits=self.quant_dtype.num_bits, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - quant_delay=self.quant_delay) + self._fake_quant_train = quant_func(num_bits=self._num_bits, + symmetric=self._symmetric, + narrow_range=self._narrow_range, + quant_delay=self._quant_delay) self._fake_quant_infer = self.fake_quant_train else: quant_func = partial(quant_func, - ema=self.ema, - ema_decay=self.ema_decay, - num_bits=self.quant_dtype.num_bits, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - quant_delay=self.quant_delay) + ema=self._ema, + ema_decay=self._ema_decay, + num_bits=self._num_bits, + symmetric=self._symmetric, + narrow_range=self._narrow_range, + quant_delay=self._quant_delay) self._fake_quant_train = quant_func(training=True) self._fake_quant_infer = quant_func(training=False) @@ -74,9 +77,9 @@ class DefaultFakeQuantizerPerLayer(FakeQuantizer): if self.training: self._float_min, self._float_max = \ self._min_max_update_func(x, self._float_min, self._float_max) - out = self._fake_quant_train(x, self._float_max, self._float_max) + out = self._fake_quant_train(x, self._float_min, self._float_max) else: - out = self._fake_quant_infer(x, self._float_max, self._float_max) + out = self._fake_quant_infer(x, self._float_min, self._float_max) return out @@ -108,8 +111,8 @@ class LearnedFakeQuantizerPerLayer(FakeQuantizer): quant_func = partial(Q.FakeLearnedScaleQuantPerLayer, quant_delay=quant_delay, neg_trunc=self.neg_trunc) self.fake_quant_train = quant_func(training=True) self.fake_quant_infer = quant_func(training=False) - self._float_min = Parameter(Tensor(min_init), name="float_min") - self._float_max = Parameter(Tensor(max_init), name="float_max") + self._float_min = Parameter(Tensor([min_init], mindspore.float32), name="float_min") + self._float_max = Parameter(Tensor([max_init], mindspore.float32), name="float_max") def compute_quant_param(self, weight_param): max_init = [compute_kl_threshold(weight_param, self._num_bits)] @@ -141,8 +144,8 @@ class LearnedFakeQuantizePerChannel(FakeQuantizer): self.fake_quant_train = quant_func(training=True) self.fake_quant_infer = quant_func(training=False) self._num_channels = num_channels - self._float_min = Parameter(Tensor(self._get_init_array(float_min)), name="float_min") - self._float_max = Parameter(Tensor(self._get_init_array(float_max)), name="float_max") + self._float_min = Parameter(Tensor(self._get_init_array(float_min), mindspore.float32), name="float_min") + self._float_max = Parameter(Tensor(self._get_init_array(float_max), mindspore.float32), name="float_max") def compute_quant_param(self, weight_param): max_init = [compute_kl_threshold(weight_para_each.asnumpy(), self._num_bits) diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py index 89404dc486a..c9d9a728f68 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_layer_policy.py @@ -18,7 +18,7 @@ from typing import Optional from ..layer_policy import LayerPolicy from ..quantize_wrapper_cell import QuantizeWrapperCell from ..fake_quantizer import FakeQuantizer -from .default_fake_quantizer import LearnedFakeQuantizerPerLayer, LearnedFakeQuantizePerChannel +from .default_fake_quantizer import DefaultFakeQuantizerPerChannel, DefaultFakeQuantizerPerLayer from mindspore.nn import Cell from mindspore.nn.layer.quant import QuantConfig, Conv2dQuant, DenseQuant, Conv2dBnFoldQuantOneConv @@ -34,10 +34,10 @@ class DefaultLayerPolicy(LayerPolicy): def __init__(self, weight_names: [], act_names: [], config=None): if config is None: config = {} - self._weight_quantizer = LearnedFakeQuantizePerChannel - self._act_quantizer = LearnedFakeQuantizerPerLayer() - self._input_quantizer: Optional[FakeQuantizer] = LearnedFakeQuantizerPerLayer() - self._output_quantizer: Optional[FakeQuantizer] = LearnedFakeQuantizerPerLayer() + self._weight_quantizer = DefaultFakeQuantizerPerChannel + self._act_quantizer = DefaultFakeQuantizerPerLayer() + self._input_quantizer: Optional[FakeQuantizer] = DefaultFakeQuantizerPerLayer() + self._output_quantizer: Optional[FakeQuantizer] = DefaultFakeQuantizerPerLayer() self._weight_names = weight_names self._act_names = act_names self._input_num = 0 @@ -69,10 +69,8 @@ class DefaultLayerPolicy(LayerPolicy): raise RuntimeError("Index out of range of input number") self._inputs_insert_fq[index] = False - def get_input_need_insert_fq(self, index: int): - if index >= self._input_num: - raise RuntimeError("Index out of range of input number") - return self._inputs_insert_fq[index] + def get_input_need_insert_fq(self): + return self._inputs_insert_fq def set_output_not_insert_fq(self, index: Optional[int] = None): self._output_quantizer = None diff --git a/mindspore/python/mindspore/golden_stick/quantization/layer_policy.py b/mindspore/python/mindspore/golden_stick/quantization/layer_policy.py index b6fde62a67e..2ec8314ac22 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/layer_policy.py +++ b/mindspore/python/mindspore/golden_stick/quantization/layer_policy.py @@ -92,8 +92,8 @@ class LayerPolicy(abc.ABC): def set_input_not_insert_fq(self, index: Optional[int] = None): pass - def get_input_need_insert_fq(self, index: int) -> bool: - return False + def get_input_need_insert_fq(self) -> list: + return [] # only support one-output-quantizer pre layer because we can not get how many outputs a cell would has def set_output_not_insert_fq(self): diff --git a/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py b/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py index 90f8ac98d41..7a608da3cd3 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py +++ b/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py @@ -142,7 +142,10 @@ class QuantAwareTraining(GoldenStick): for node in net_transformer.nodes(): layer_policy = node.get_attribute(layer_policy_key) if isinstance(layer_policy, LayerPolicy): - wrapped_cell = layer_policy.wrap_cell(net_transformer.get_cell_by_node(node)) + cell = net_transformer.get_cell_by_node(node) + wrapped_cell = layer_policy.wrap_cell(cell) + prefix = node.name.split('.', 1)[1] + wrapped_cell.update_parameters_name(prefix + '.') net_transformer.replace_node(node, wrapped_cell) def apply(self, net: Cell) -> Cell: diff --git a/mindspore/python/mindspore/golden_stick/quantization/quant_utils.py b/mindspore/python/mindspore/golden_stick/quantization/quant_utils.py index cf6cc57ae6c..077cbfc91d4 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/quant_utils.py +++ b/mindspore/python/mindspore/golden_stick/quantization/quant_utils.py @@ -16,6 +16,7 @@ import numpy as np from mindspore._checkparam import Validator +from mindspore.nn.layer.quant import quant from ... import nn __all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers", "compute_kl_threshold"] @@ -419,8 +420,8 @@ def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_param # Perform KL_init when learned scale quantization is executed. for cell_and_name in quant_model.cells_and_names(): cell = cell_and_name[1] - if isinstance(cell, (nn.Conv2dBnFoldQuantOneConv, nn.Conv2dBnFoldQuant, nn.Conv2dBnWithoutFoldQuant, - nn.Conv2dQuant, nn.DenseQuant)) and cell.fake_quant_weight.mode == "LEARNED_SCALE": + if isinstance(cell, (quant.Conv2dBnFoldQuantOneConv, quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant, + quant.Conv2dQuant, quant.DenseQuant)) and cell.fake_quant_weight.mode == "LEARNED_SCALE": subcell_weight_para = cell.weight.data.asnumpy() if hasattr(cell, 'gamma'): scale_factor = (cell.gamma.data.asnumpy() / diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py index d95aa85d757..863b738fdf8 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py +++ b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py @@ -37,6 +37,16 @@ class QuantizeWrapperCell(Cell): self._w_zp = 0 self._o_scale = 1.0 self._o_zp = 0 + self._input_quantizer = self._policy.get_input_quantizer() + self._output_quantizer = self._policy.get_output_quantizer() + self._input_insert_quantizer = self._policy.get_input_need_insert_fq() + # fake-quant weight + for weight_name, quantizer in self._policy.get_weight_name_and_quantizers(): + assert weight_name is not None + assert quantizer is not None + weight = getattr(self._handler, weight_name) + fq_data = quantizer(weight) + setattr(self._handler, weight_name, fq_data) def construct(self, *inputs, **kwargs): """ @@ -46,25 +56,16 @@ class QuantizeWrapperCell(Cell): Tensor, returns the computed result. """ - # fake-quant weight - for weight_name, quantizer in self._policy.get_weight_name_and_quantizers(): - assert weight_name is not None - assert quantizer is not None - weight = getattr(self._handler, weight_name) - fq_data = quantizer(weight) - setattr(self._handler, weight_name, fq_data) - # fake-quant input - input_quantizer = self._policy.get_input_quantizer() - if input_quantizer is None: + if self._input_quantizer is None: fq_inputs = inputs else: fq_inputs = [] input_len = len(inputs) for i in range(0, input_len): ori_input = inputs[i] - if self._policy.get_input_need_insert_fq(i): - fq_inputs.append(input_quantizer(ori_input)) + if self._input_insert_quantizer[i]: + fq_inputs.append(self._input_quantizer(ori_input)) else: fq_inputs.append(ori_input) @@ -72,20 +73,16 @@ class QuantizeWrapperCell(Cell): outputs = self._handler(*fq_inputs, **kwargs) # fake-quant output - output_quantizer = self._policy.get_output_quantizer() - if output_quantizer is None: + if self._output_quantizer is None: return outputs - if isinstance(outputs, list) or isinstance(outputs, tuple): - raise RuntimeError("Only support single output tensor fake-quant now") + if not isinstance(outputs, list): + return self._output_quantizer(outputs) output_len = len(outputs) if output_len == 0: return outputs - elif output_len == 1: - fq_data = output_quantizer(outputs) - return fq_data else: fq_outputs = [] for i in range(0, output_len): ori_output = outputs[i] - fq_outputs.append(output_quantizer(ori_output)) + fq_outputs.append(self._output_quantizer(ori_output)) return fq_outputs diff --git a/mindspore/python/mindspore/nn/layer/quant.py b/mindspore/python/mindspore/nn/layer/quant.py index 20cec6ff28d..5f6fd19a318 100644 --- a/mindspore/python/mindspore/nn/layer/quant.py +++ b/mindspore/python/mindspore/nn/layer/quant.py @@ -1455,8 +1455,13 @@ class Conv2dQuant(Cell): conv_quant = cls(conv2d.in_channels, conv2d.out_channels, conv2d.kernel_size, - has_bias=conv2d.has_bias, - quant_config=quant_config) + conv2d.stride, + conv2d.pad_mode, + conv2d.padding, + conv2d.dilation, + conv2d.group, + conv2d.has_bias, + quant_config) return conv_quant def construct(self, x): @@ -1470,9 +1475,9 @@ class Conv2dQuant(Cell): """Display instance object as string.""" s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ 'pad_mode={}, padding={}, dilation={}, group={}, ' \ - 'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, - self.pad_mode, self.padding, self.dilation, self.group, - self.has_bias, self.fake_quant_weight.quant_delay) + 'has_bias={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, + self.pad_mode, self.padding, self.dilation, self.group, + self.has_bias) return s @@ -1585,7 +1590,10 @@ class DenseQuant(Cell): def from_float(cls, dense: Dense, quant_config: QuantConfig): dense_quant = cls(dense.in_channels, dense.out_channels, - has_bias=dense.has_bias, + dense.weight, + dense.bias, + dense.has_bias, + dense.activation, quant_config=quant_config) return dense_quant diff --git a/mindspore/python/mindspore/rewrite/pattern_engine_for_cell.py b/mindspore/python/mindspore/rewrite/pattern_engine_for_cell.py index bcc21b5ba1b..7450b398c39 100644 --- a/mindspore/python/mindspore/rewrite/pattern_engine_for_cell.py +++ b/mindspore/python/mindspore/rewrite/pattern_engine_for_cell.py @@ -22,7 +22,7 @@ from .rewriter import Rewriter, Graph class Identity(Cell): def construct(self, *inputs, **kwargs): - pass + return inputs[0] def convert_nodes_to_cells(net: Cell, nodes: [Node]): -- Gitee From c67441121647b9e4d4b8a2e5a87da23d94c625d2 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Thu, 10 Feb 2022 02:01:49 +0000 Subject: [PATCH 27/32] !39 use rewriter to support qat * use rewriter to support qat * QAT support lenet * update interface of rewrite * tmp --- mindspore/python/mindspore/__init__.py | 4 +- .../example/custom_qat_example.py | 96 +++++++ .../example/default_qat_example.py | 59 +--- .../mindspore/golden_stick/net_transform.py | 138 ++-------- .../default_qat/default_fake_quantizer.py | 14 +- .../quantization/quant_aware_training.py | 40 ++- .../golden_stick/quantization/quant_utils.py | 6 +- .../quantization/quantize_wrapper_cell.py | 23 +- mindspore/python/mindspore/nn/layer/quant.py | 2 + .../python/mindspore/rewrite/argument.py | 60 +++++ mindspore/python/mindspore/rewrite/graph.py | 252 ++++++++++-------- mindspore/python/mindspore/rewrite/node.py | 2 +- .../rewrite/pattern_engine_for_cell.py | 18 +- .../python/mindspore/rewrite/rewriter.py | 82 ++++-- mindspore/python/mindspore/rewrite/test.py | 2 +- .../python/mindspore/rewrite/test_app.py | 2 +- 16 files changed, 453 insertions(+), 347 deletions(-) create mode 100644 mindspore/python/mindspore/golden_stick/example/custom_qat_example.py create mode 100644 mindspore/python/mindspore/rewrite/argument.py diff --git a/mindspore/python/mindspore/__init__.py b/mindspore/python/mindspore/__init__.py index fe59e61c1f4..a9bef057822 100755 --- a/mindspore/python/mindspore/__init__.py +++ b/mindspore/python/mindspore/__init__.py @@ -28,7 +28,7 @@ from .context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_context, set_au from .version import __version__ from .golden_stick import * from .rewrite import * -from .rewrite_experiment import * +# from .rewrite_experiment import * __all__ = ["run_check"] @@ -38,5 +38,5 @@ __all__.extend(train.__all__) __all__.extend(log.__all__) __all__.extend(context.__all__) __all__.extend(rewrite.__all__) -__all__.extend(rewrite_experiment.__all__) +# __all__.extend(rewrite_experiment.__all__) __all__.extend(golden_stick.__all__) diff --git a/mindspore/python/mindspore/golden_stick/example/custom_qat_example.py b/mindspore/python/mindspore/golden_stick/example/custom_qat_example.py new file mode 100644 index 00000000000..542547b4751 --- /dev/null +++ b/mindspore/python/mindspore/golden_stick/example/custom_qat_example.py @@ -0,0 +1,96 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from typing import List, Tuple + +from ..quantization.default_qat.default_quant_aware_training import DefaultQuantAwareTraining +from mindspore.nn import Cell, Conv2d, BatchNorm2d, Dense, ReLU, MaxPool2d, Flatten, SoftmaxCrossEntropyWithLogits, \ + Momentum +from mindspore.train.model import Model +from ..quantization.transformer import Transformer +from ..quantization.layer_policy import LayerPolicy +from ..quantization.fake_quantizer import FakeQuantizer + + +class LeNet5(Cell): + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = Conv2d(1, 6, 5) + self.fc1 = Dense(16 * 5 * 5, 120) + self.relu = ReLU() + self.max_pool2d = MaxPool2d(kernel_size=2, stride=2) + self.flatten = Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + return x + + +# custom quantizer +class AllBitFakeQuantizer(FakeQuantizer): + """ + Derived class of QuantizeOp. Use min and max value of data to compute scale and zero-point. + """ + + def __init__(self, bit_num=8): + super().__init__() + self._bit_num = bit_num + + def compute(self, data: [float]) -> Tuple[List[float], float, int]: + min = data[0] + max = data[1] + scale = (1 << self._bit_num) / (max - min) + zp = max * scale + return data, scale, zp + + +# custom layer policy +class ConvBNQPolicy(LayerPolicy): + def __init__(self): + super().__init__() + self._quantizer = AllBitFakeQuantizer() + + def get_weight_name_and_quantizers(self) -> [(str, FakeQuantizer)]: + # todo how to define weight inside of a subgraph + return [("_old_conv.weight", self._quantizer), ("_old_bn.gamma", self._quantizer), + ("_old_bn.beta", self._quantizer), ("_old_bn.moving_mean", self._quantizer), + ("_old_bn.moving_variance", self._quantizer)] + + def get_act_name_and_quantizers(self) -> [(str, (FakeQuantizer, FakeQuantizer))]: + return [] + + def get_output_quantizers(self) -> [FakeQuantizer]: + return [self._quantizer] + + +net = LeNet5() + +custom_pattern_engine = Transformer([Conv2d, BatchNorm2d]) +custom_conv_bn_policy = ConvBNQPolicy() +# custom net policy +algo = DefaultQuantAwareTraining({"custom_transforms": [custom_pattern_engine], + "custom_policies": [custom_conv_bn_policy]}) +net_opt = algo.apply(net) + +loss = algo.loss(SoftmaxCrossEntropyWithLogits()) +optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) +model = Model(net_opt, loss_fn=loss, optimizer=optimizer, metrics=None) +dataset = {} +model.train(2, dataset, algo.callback()) diff --git a/mindspore/python/mindspore/golden_stick/example/default_qat_example.py b/mindspore/python/mindspore/golden_stick/example/default_qat_example.py index 542547b4751..7109c9c1eb6 100644 --- a/mindspore/python/mindspore/golden_stick/example/default_qat_example.py +++ b/mindspore/python/mindspore/golden_stick/example/default_qat_example.py @@ -15,12 +15,8 @@ from typing import List, Tuple from ..quantization.default_qat.default_quant_aware_training import DefaultQuantAwareTraining -from mindspore.nn import Cell, Conv2d, BatchNorm2d, Dense, ReLU, MaxPool2d, Flatten, SoftmaxCrossEntropyWithLogits, \ - Momentum -from mindspore.train.model import Model -from ..quantization.transformer import Transformer -from ..quantization.layer_policy import LayerPolicy -from ..quantization.fake_quantizer import FakeQuantizer +from mindspore.nn import Cell, Conv2d, Dense, ReLU, MaxPool2d, Flatten, SoftmaxCrossEntropyWithLogits, Momentum +from mindspore.train import Model class LeNet5(Cell): @@ -43,54 +39,11 @@ class LeNet5(Cell): return x -# custom quantizer -class AllBitFakeQuantizer(FakeQuantizer): - """ - Derived class of QuantizeOp. Use min and max value of data to compute scale and zero-point. - """ - - def __init__(self, bit_num=8): - super().__init__() - self._bit_num = bit_num - - def compute(self, data: [float]) -> Tuple[List[float], float, int]: - min = data[0] - max = data[1] - scale = (1 << self._bit_num) / (max - min) - zp = max * scale - return data, scale, zp - - -# custom layer policy -class ConvBNQPolicy(LayerPolicy): - def __init__(self): - super().__init__() - self._quantizer = AllBitFakeQuantizer() - - def get_weight_name_and_quantizers(self) -> [(str, FakeQuantizer)]: - # todo how to define weight inside of a subgraph - return [("_old_conv.weight", self._quantizer), ("_old_bn.gamma", self._quantizer), - ("_old_bn.beta", self._quantizer), ("_old_bn.moving_mean", self._quantizer), - ("_old_bn.moving_variance", self._quantizer)] - - def get_act_name_and_quantizers(self) -> [(str, (FakeQuantizer, FakeQuantizer))]: - return [] - - def get_output_quantizers(self) -> [FakeQuantizer]: - return [self._quantizer] - - net = LeNet5() - -custom_pattern_engine = Transformer([Conv2d, BatchNorm2d]) -custom_conv_bn_policy = ConvBNQPolicy() -# custom net policy -algo = DefaultQuantAwareTraining({"custom_transforms": [custom_pattern_engine], - "custom_policies": [custom_conv_bn_policy]}) -net_opt = algo.apply(net) - -loss = algo.loss(SoftmaxCrossEntropyWithLogits()) +default_qat = DefaultQuantAwareTraining() +net_opt = default_qat.apply(net) +loss = default_qat.loss(SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")) optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) model = Model(net_opt, loss_fn=loss, optimizer=optimizer, metrics=None) dataset = {} -model.train(2, dataset, algo.callback()) +model.train(2, dataset, default_qat.callback()) diff --git a/mindspore/python/mindspore/golden_stick/net_transform.py b/mindspore/python/mindspore/golden_stick/net_transform.py index e63b2e73bd6..f4772180165 100644 --- a/mindspore/python/mindspore/golden_stick/net_transform.py +++ b/mindspore/python/mindspore/golden_stick/net_transform.py @@ -17,8 +17,8 @@ from typing import Union, Optional from mindspore.nn.cell import Cell from mindspore.rewrite import Graph, PatternEngine -from mindspore.rewrite.rewriter import Rewriter -from mindspore.rewrite.node import BaseNode +from mindspore.rewrite.rewriter import Rewriter, Argument +from mindspore.rewrite.node import BaseNode, Node from mindspore.rewrite.pattern_engine_for_cell import PatternEngineForCell @@ -32,17 +32,9 @@ class NetTransformer: def __init__(self, net: Cell): self._net = net - self._graph = Rewriter.parse(net) + self._rewriter = Rewriter(net) - def get_transformed(self) -> Cell: - """ - Returns: - Transformed network. - """ - - return self._graph.python_object() - - def net(self) -> Cell: + def get_network(self) -> Cell: return self._net def nodes(self) -> [BaseNode]: @@ -51,70 +43,20 @@ class NetTransformer: a list of BaseNode corresponding to all layers in original network. """ - return self._graph.nodes - - @staticmethod - def set_node_attr(node: BaseNode, key: str, value): - node.set_attribute(key, value) - - @staticmethod - def get_node_attr(node: BaseNode, key: str): - return node.get_attribute(key) - - def find_node(self, full_name_with_scope: str) -> BaseNode: - """ - Args: - full_name_with_scope (str): Name of node to be find. - - Returns: - BaseNode whose name is `full_name_with_scope`. - """ - - return self._graph.find(full_name_with_scope) - - # return inputs of node whose full_name_with_scope is full_name_with_scope - def node_inputs(self, full_name_with_scope: str) -> [BaseNode]: - """ - Args: - full_name_with_scope (str): Name of node to be find. - - Returns: - Input nodes of node whose name is `full_name_with_scope` - """ - - node = self._graph.find(full_name_with_scope) - if node is None: - return [] - return node.inputs - - # return outputs of node whose full_name_with_scope is full_name_with_scope - def node_outputs(self, full_name_with_scope: str) -> [BaseNode]: - """ - Args: - full_name_with_scope (str): Name of node to be find. - - Returns: - Output nodes of node whose name is `full_name_with_scope` - """ + return self._rewriter.nodes() - node = self._graph.find(full_name_with_scope) - if node is None: - return [] - return node.outputs + def before(self, node_or_name: Union[Node, str]): + return self._rewriter.before(node_or_name) - def insert_node(self, new_node: BaseNode) -> BaseNode: - """ - Args: - new_node (BaseNode): New node to be inserted into original network. - New_node should contain its inputs and outputs. + def after(self, node_or_name: Union[Node, str]): + return self._rewriter.after(node_or_name) - Returns: - BaseNode has been inserted, return None if failed - """ - - return self._graph.insert_node(new_node) + def add_object(self, position, custom_obj: Cell, field: str, targets: [str], target_type: str, + call_args: [Argument], call_kwargs: {str: Argument}) -> Optional[Node]: + self._net.insert_child_to_cell(field, custom_obj) + return self._rewriter.add_object(position, custom_obj, field, targets, target_type, call_args, call_kwargs) - def remove_node(self, node: Union[str, BaseNode]) -> Optional[BaseNode]: + def erase_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: """ Args: node (BaseNode): node to be removed from original network. @@ -123,52 +65,14 @@ class NetTransformer: BaseNode has been removed, return None if failed """ - if isinstance(node, str): - node = self._graph.find(node) + if isinstance(node_or_name, str): + node = self._rewriter.find_node(node_or_name) + else: + node = node_or_name if node is None: return None - return self._graph.remove_node(node) - - def get_cell_by_node(self, node: Union[str, BaseNode]) -> Optional[Cell]: - """ - Args: - node (BaseNode): node to be used to find cell. + return self._rewriter.erase_node(node) - Returns: - founded cell - """ - cells = self._net.name_cells() - if isinstance(node, str): - cell_name = node - elif isinstance(node, BaseNode): - cell_name = node.name.split('.', 1)[1] - else: - raise Exception("Only support str or base node yet.") - if cell_name in cells.keys(): - return cells[cell_name] - raise Exception("{} not found in origin net.".format(cell_name)) - - def replace_node(self, target: Union[str, BaseNode], value: Union[Cell, BaseNode]) -> Optional[BaseNode]: - """ - Args: - target (Union[str, BaseNode]): Name of node to be replaced. - value (Union[Cell, BaseNode]): BaseNode to be replaced into original network. - - Note: - new_node should has same inputs and outputs with old_node. - - Returns: - BaseNode has been replaced, return None if failed - """ - - if isinstance(target, str): - target = self._graph.find(target) - if target is None: - return None - if isinstance(value, Cell): - self._net.insert_child_to_cell(target.name.split('.', 1)[1], value) - value = BaseNode(value) - return self._graph.replace_node() # replace src_pattern with target_nodes. # target_nodes should has same inputs and outputs with src_pattern. @@ -181,5 +85,5 @@ class NetTransformer: a bool value indicating if transform occurred """ if isinstance(pattern_engine, PatternEngineForCell): - return pattern_engine.apply_cell(self._net, self._graph) - return pattern_engine.apply(self._graph) + return pattern_engine.apply_cell(self._net, self._rewriter) + return pattern_engine.apply(self._rewriter) diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py index 14dd88eb157..0a3d98d9dad 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py @@ -47,16 +47,16 @@ class DefaultFakeQuantizerPerLayer(FakeQuantizer): self._symmetric = symmetric self._num_bits = num_bits self._quant_delay = quant_delay - self._narraw_range = narrow_range - self._min_max_update_func = partial(Q.MinMaxUpdatePerLayer, ema=self._ema, ema_decay=self._ema_decay) + self._narrow_range = narrow_range + self._min_max_update_func = Q.MinMaxUpdatePerLayer(ema=self._ema, ema_decay=self._ema_decay) self._is_ascend = context.get_context("device_target") == "Ascend" quant_func = Q.FakeQuantPerLayer self._init_fake_quant_func(quant_func) - self._float_min = Parameter(Tensor(float("-inf")), name="float_min") - self._float_max = Parameter(Tensor(float("inf")), name="float_max") + self._float_min = Parameter(Tensor(np.array([-6]).astype(np.float32), mindspore.float32), name="float_min") + self._float_max = Parameter(Tensor(np.array([6]).astype(np.float32), mindspore.float32), name="float_max") def _init_fake_quant_func(self, quant_func): - if self.is_ascend: + if self._is_ascend: self._fake_quant_train = quant_func(num_bits=self._num_bits, symmetric=self._symmetric, narrow_range=self._narrow_range, @@ -91,8 +91,8 @@ class DefaultFakeQuantizerPerChannel(DefaultFakeQuantizerPerLayer): def __init__(self, num_channels=1, channel_axis=1, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False): super(DefaultFakeQuantizerPerChannel, self).__init__(ema=ema, ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) - self._float_min = Parameter(Tensor([float("-inf")] * num_channels), name="float_min") - self._float_max = Parameter(Tensor([float("inf")] * num_channels), name="float_max") + self._float_min = Parameter(Tensor(np.array([-6] * num_channels).astype(np.float32), mindspore.float32), name="float_min") + self._float_max = Parameter(Tensor(np.array([6] * num_channels).astype(np.float32), mindspore.float32), name="float_max") quant_func = partial(Q.FakeQuantPerChannel, channel_axis=channel_axis) self._init_fake_quant_func(quant_func) diff --git a/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py b/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py index 7a608da3cd3..3241d03f537 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py +++ b/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py @@ -21,7 +21,7 @@ from .layer_policy import LayerPolicy, layer_policy_key from .transformer import Transformer from ..golden_stick import GoldenStick from ..net_transform import NetTransformer -from mindspore.rewrite import Node +from mindspore.rewrite import Node, BaseNode from mindspore.nn import Cell @@ -80,7 +80,7 @@ class QuantAwareTraining(GoldenStick): for node in net_transformer.nodes(): if not isinstance(node, Node): continue - cur_policy: LayerPolicy = NetTransformer.get_node_attr(node, layer_policy_key) + cur_policy: LayerPolicy = node.get_attribute(layer_policy_key) # cur-node has no quant policy, so no fq will insert into its inputs if cur_policy is None: continue @@ -91,7 +91,7 @@ class QuantAwareTraining(GoldenStick): if cur_in_quantizer is None: continue input_node: Node = input_nodes[i] - pre_policy: LayerPolicy = NetTransformer.get_node_attr(input_node, layer_policy_key) + pre_policy: LayerPolicy = input_node.get_attribute(layer_policy_key) # pre-node has no quant policy, so no fq will insert into its outputs if pre_policy is None: continue @@ -128,6 +128,34 @@ class QuantAwareTraining(GoldenStick): # Transformer always return False net_transformer.pattern_transform(transformer) + @staticmethod + def _get_cell_by_node(net_transformer: NetTransformer, node_or_name: Node): + """ + Args: + node_or_name (BaseNode): node to be used to find cell. + + Returns: + founded cell + """ + cells = net_transformer.get_network().name_cells() + if isinstance(node_or_name, str): + cell_name = node_or_name + elif isinstance(node_or_name, BaseNode): + cell_name = node_or_name.name.split('.', 1)[1] + else: + raise Exception("Only support str or base node yet.") + if cell_name in cells.keys(): + return cells[cell_name] + raise Exception("{} not found in origin net.".format(cell_name)) + + @staticmethod + def _replace_node(net_transformer: NetTransformer, target_node: Node, result_cell: Cell): + # todo erase_node have bug yet + # net_transformer.erase_node(target_node) + position = net_transformer.before(target_node.outputs[0]) + net_transformer.add_object(position, result_cell, target_node.name.split('.', 1)[1], target_node.targets, "", + target_node.targets, {}) + @staticmethod def _apply_layer_policy(net_transformer: NetTransformer): """ @@ -142,11 +170,11 @@ class QuantAwareTraining(GoldenStick): for node in net_transformer.nodes(): layer_policy = node.get_attribute(layer_policy_key) if isinstance(layer_policy, LayerPolicy): - cell = net_transformer.get_cell_by_node(node) + cell = QuantAwareTraining._get_cell_by_node(net_transformer, node) wrapped_cell = layer_policy.wrap_cell(cell) prefix = node.name.split('.', 1)[1] wrapped_cell.update_parameters_name(prefix + '.') - net_transformer.replace_node(node, wrapped_cell) + QuantAwareTraining._replace_node(net_transformer, node, wrapped_cell) def apply(self, net: Cell) -> Cell: """ @@ -166,4 +194,4 @@ class QuantAwareTraining(GoldenStick): self._propagate_layer_policy(net_transformer) QuantAwareTraining._reduce_redundant_fake_quant(net_transformer) QuantAwareTraining._apply_layer_policy(net_transformer) - return net_transformer.net() + return net_transformer.get_network() diff --git a/mindspore/python/mindspore/golden_stick/quantization/quant_utils.py b/mindspore/python/mindspore/golden_stick/quantization/quant_utils.py index 077cbfc91d4..e4a6f740c8c 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/quant_utils.py +++ b/mindspore/python/mindspore/golden_stick/quantization/quant_utils.py @@ -16,7 +16,7 @@ import numpy as np from mindspore._checkparam import Validator -from mindspore.nn.layer.quant import quant +from mindspore.nn.layer.quant import * from ... import nn __all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers", "compute_kl_threshold"] @@ -420,8 +420,8 @@ def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_param # Perform KL_init when learned scale quantization is executed. for cell_and_name in quant_model.cells_and_names(): cell = cell_and_name[1] - if isinstance(cell, (quant.Conv2dBnFoldQuantOneConv, quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant, - quant.Conv2dQuant, quant.DenseQuant)) and cell.fake_quant_weight.mode == "LEARNED_SCALE": + if isinstance(cell, (Conv2dBnFoldQuantOneConv, Conv2dBnFoldQuant, Conv2dBnWithoutFoldQuant, + Conv2dQuant, DenseQuant)) and cell.fake_quant_weight.mode == "LEARNED_SCALE": subcell_weight_para = cell.weight.data.asnumpy() if hasattr(cell, 'gamma'): scale_factor = (cell.gamma.data.asnumpy() / diff --git a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py index 863b738fdf8..7d75fdc0361 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py +++ b/mindspore/python/mindspore/golden_stick/quantization/quantize_wrapper_cell.py @@ -57,17 +57,18 @@ class QuantizeWrapperCell(Cell): """ # fake-quant input - if self._input_quantizer is None: - fq_inputs = inputs - else: - fq_inputs = [] - input_len = len(inputs) - for i in range(0, input_len): - ori_input = inputs[i] - if self._input_insert_quantizer[i]: - fq_inputs.append(self._input_quantizer(ori_input)) - else: - fq_inputs.append(ori_input) + fq_inputs = inputs + # if self._input_quantizer is None: + # fq_inputs = inputs + # else: + # fq_inputs = [] + # input_len = len(inputs) + # for i in range(0, input_len): + # ori_input = inputs[i] + # if self._input_insert_quantizer[i]: + # fq_inputs.append(self._input_quantizer(ori_input)) + # else: + # fq_inputs.append(ori_input) # forward handler outputs = self._handler(*fq_inputs, **kwargs) diff --git a/mindspore/python/mindspore/nn/layer/quant.py b/mindspore/python/mindspore/nn/layer/quant.py index 5f6fd19a318..1d665f05b2b 100644 --- a/mindspore/python/mindspore/nn/layer/quant.py +++ b/mindspore/python/mindspore/nn/layer/quant.py @@ -1461,6 +1461,8 @@ class Conv2dQuant(Cell): conv2d.dilation, conv2d.group, conv2d.has_bias, + conv2d.weight_init, + conv2d.bias_init, quant_config) return conv_quant diff --git a/mindspore/python/mindspore/rewrite/argument.py b/mindspore/python/mindspore/rewrite/argument.py new file mode 100644 index 00000000000..b42ba9a79a5 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/argument.py @@ -0,0 +1,60 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from enum import Enum +from typing import Optional + + +class ArgType(Enum): + StringArg = 0, + IntArg = 1, + FloatArg = 2, + NamingArg = 3, + CustomObjArg = 4 + + +class Argument: + def __init__(self, arg_type: ArgType, scope: str = "", name=None): + self.type = arg_type + self.scope = scope + self.name = name + + @classmethod + def create_imm_arg(cls, value) -> Optional['Argument']: + if isinstance(value, int): + return cls(ArgType.IntArg, "", value) + elif isinstance(value, float): + return cls(ArgType.FloatArg, "", value) + elif isinstance(value, str): + return cls(ArgType.StringArg, "", value) + else: + return None + + @classmethod + def create_naming_arg(cls, name: str, scope: str = "") -> 'Argument': + return cls(ArgType.NamingArg, scope, name) + + @classmethod + def create_custom_arg(cls, obj) -> 'Argument': + return cls(ArgType.CustomObjArg, "", obj) + + def __str__(self): + if self.type in (ArgType.IntArg, ArgType.FloatArg, ArgType.StringArg): + return str(self.name) + elif self.type == ArgType.NamingArg: + return f"{self.scope}.{self.name}" if (len(self.scope) > 0) else str(self.name) + elif self.type == ArgType.CustomObjArg: + return f"CustomObj: {str(self.name)}" + else: + return f"Illegal ArgType: {str(self.type)}" diff --git a/mindspore/python/mindspore/rewrite/graph.py b/mindspore/python/mindspore/rewrite/graph.py index 02f43a8eef7..714f54c836c 100644 --- a/mindspore/python/mindspore/rewrite/graph.py +++ b/mindspore/python/mindspore/rewrite/graph.py @@ -19,23 +19,27 @@ from .parser import Parser from .ast_unparser import ASTUnparser from .node_visitor import _node_list from .called_object_representation import CalledObjectRepresentation -#from _subgraph import SubGraph + + +# from _subgraph import SubGraph class _insert_point: - def __init__(self, graph, new_insert_position) -> None: + def __init__(self, graph, new_insert_position, before) -> None: self._graph = graph - self._orig_insert_point, graph._insert_pos = graph._insert_pos, new_insert_position - + self._orig_insert_point, graph._insert_pos, graph._insert_before = graph._insert_pos, \ + new_insert_position, before + def __enter__(self): pass - + def __exit__(self, type, value, trace): self._graph._insert_pos = self._orig_insert_point -class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继承Graph + +class Graph: # Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继承Graph def __init__(self, network: Union[nn.Cell, Primitive, FunctionType]): - #assert(type(network) in Union[nn.Cell, Primitive, FunctionType]) + # assert(type(network) in Union[nn.Cell, Primitive, FunctionType]) self._name = network.__class__.__name__ self._network_obj = network self._network_class = network.__class__ @@ -45,13 +49,15 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 self._root = None self._nodes: List = [] self._insert_pos: Node = None # insert node before it + self._insert_before = True self._contant_nodes: Dict = {} self._param_value: Dict = {} self._ast_function_root: Dict = {} self._symbols: Dict = {} self._subgraphs: Dict = {} - self._namespace: Dict[str, int] = {} #save the number of the variable, if the number is over 1,then modify the name - add a number as the name suffix - # self._new_obj: Dict = {} # 保存插入节点时对象名即对应的对象,供生成对象时使用,但是这样生成的代码就会与类不一致,代码是不是不需要感知这个东西 + self._namespace: Dict[ + str, int] = {} # save the number of the variable, if the number is over 1,then modify the name - add a number as the name suffix + # self._new_obj: Dict = {} # 保存插入节点时对象名即对应的对象,供生成对象时使用,但是这样生成的代码就会与类不一致,代码是不是不需要感知这个东西 self._new_import: Dict = {} @property @@ -147,25 +153,25 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 if isinstance(node, ast.FunctionDef): self._ast_function_root[node.name] = node - def create_placeholder(self, input:Union[ast.FunctionDef, list]): + def create_placeholder(self, input: Union[ast.FunctionDef, list]): """ Create placeholder for function. """ - #placeholders = [] + # placeholders = [] if isinstance(input, ast.FunctionDef): ast_node = input.args args = self._parser.parse_arguments(ast_node) keys = reversed(list(args.keys())) - #for name, value in args.items(): #反向遍历插入0的位置 + # for name, value in args.items(): #反向遍历插入0的位置 for key in keys: new_node = PlaceholderNode(key, key, ast_node, default_value=args[key]) - logger.debug (f"placeholder node, name: {new_node.name}; target: {new_node._targets}") + logger.debug(f"placeholder node, name: {new_node.name}; target: {new_node._targets}") self._nodes.insert(0, new_node) elif isinstance(input, list): args = input for name in reversed(args): new_node = PlaceholderNode(name, name) - logger.debug (f"placeholder node, name: {new_node.name}; target: {new_node._targets}") + logger.debug(f"placeholder node, name: {new_node.name}; target: {new_node._targets}") self._nodes.insert(0, new_node) else: logger.warning(f"unsupported input type, input: {input}") @@ -177,8 +183,8 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 if "__init__" not in self._network_class.__dict__.keys(): return logger.info(f"parse {self._base_scope} init function start") - #self.create_placeholder(self._ast_function_root["__init__"]) - self._parser.parse_arguments(self._ast_function_root["__init__"].args) #获取参数默认值 + # self.create_placeholder(self._ast_function_root["__init__"]) + self._parser.parse_arguments(self._ast_function_root["__init__"].args) # 获取参数默认值 self._parser.update_closure_namespace(self._network_class.__init__) for ast_node in self._ast_function_root["__init__"].body: if isinstance(ast_node, ast.Expr): @@ -199,27 +205,27 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 if node.name in self._subgraphs.keys(): subgraph = copy.deepcopy(self._subgraphs[node.name]) subgraph._name = n - #subgraph._called_obj._type = NodeType.call_method #TODO: 要根据是init中的cell还是调用的self.function来确定 + # subgraph._called_obj._type = NodeType.call_method #TODO: 要根据是init中的cell还是调用的self.function来确定 subgraph._called_obj._is_custom_define = True subgraph._args = node._args subgraph._targets = node._targets - # subgraph.update_placeholder() + # subgraph.update_placeholder() return subgraph elif attribute_name and attribute_name in self._symbols.keys(): logger.debug(f"defined in init function: {attribute_name}") node._called_obj = self._symbols[attribute_name] - return node + return node elif n == "return": - #node._called_obj._type = NodeType.output + # node._called_obj._type = NodeType.output return node elif self._parser.get_func_namesapce(n)[0]: class_, name_space_, is_custom_define_ = self._parser.get_func_namesapce(n) logger.debug(f"{n} defined in other namespace: {name_space_}") - #logger.debug(f"class: {class_}, name space: {name_space_}, is custom define: {is_custom_define_}") + # logger.debug(f"class: {class_}, name space: {name_space_}, is custom define: {is_custom_define_}") if is_custom_define_: subgraph = self.parse_function(class_) subgraph._name = n - #subgraph._called_obj._type = NodeType.call_method + # subgraph._called_obj._type = NodeType.call_method subgraph._called_obj._is_custom_define = True subgraph._called_obj._class = class_ subgraph._args = node._args @@ -228,9 +234,9 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 else: node._called_obj._is_custom_define = is_custom_define_ node._called_obj._class = class_ - #node._called_obj._type = NodeType.call_function + # node._called_obj._type = NodeType.call_function return node - else: #TODO: 常量节点会走到else,需要做处理 + else: # TODO: 常量节点会走到else,需要做处理 logger.warning(f"undefined symbole {n} ... ...") return node @@ -254,20 +260,20 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 logger.debug(f"process ast node: {ast_node}") if isinstance(ast_node, ast.Expr): continue - # self.print_ast(ast_node) + # self.print_ast(ast_node) visitor = self._parser.get_node_visitor(ast_node) if not visitor: logger.warning(f"Get node visitor failed in parse_construct, node: {ast_node}") continue nodes, attribute_names = visitor(ast_node) - # print(f"nodes in parse construct: {nodes}") + # print(f"nodes in parse construct: {nodes}") logger.debug(f"nodes in parse construct: {nodes}") logger.debug(f"attribute_names in parse construct: {attribute_names}") for i in range(len(nodes)): if isinstance(nodes[i], ConstantNode): node = nodes[i] self._contant_nodes[node._value] = node - #self._nodes.append(nodes[i]) + # self._nodes.append(nodes[i]) else: node = self._process_node(nodes[i], attribute_names[i]) node._name = self._base_scope + "." + node.name.split(".")[-1] @@ -283,7 +289,7 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 logger.debug("construct nodes: ") for node in self._nodes: logger.debug(node) - logger.info(f"parse {self._base_scope } construct function end") + logger.info(f"parse {self._base_scope} construct function end") def parse_function(self, func: Union[ast.FunctionDef, FunctionType]): """ @@ -294,11 +300,11 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 function_str = inspect.getsource(func) ast_root = ast.parse(function_str) node = ast_root.body[0] - subgraph = SubGraph(func) #要区分类内还是类外方法 + subgraph = SubGraph(func) # 要区分类内还是类外方法 elif isinstance(func, ast.FunctionDef): logger.info(f"parse {func.name} function start") node = func - subgraph = SubGraph(self._network_class.__dict__[node.name]) #要区分类内还是类外方法 + subgraph = SubGraph(self._network_class.__dict__[node.name]) # 要区分类内还是类外方法 else: logger.warning(f"unsupported function type, function: {func}") subgraph._name = node.name @@ -310,15 +316,15 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 logger.debug(f"process ast node: {ast_node}") if isinstance(ast_node, ast.Expr): continue - #method = 'parse_' + ast_node.__class__.__name__ - #visitor = getattr(self._parser, method, None) + # method = 'parse_' + ast_node.__class__.__name__ + # visitor = getattr(self._parser, method, None) visitor = self._parser.get_node_visitor(ast_node) if not visitor: logger.warning(f"get node visitor failed in parse_function, node: {ast_node}") nodes, attribute_names = visitor(ast_node) for i in range(len(nodes)): nodes[i]._index = index - nodes[i].name = self._base_scope + "." + node.name + "." + nodes[i].name + nodes[i].name = self._base_scope + "." + node.name + "." + nodes[i].name if attribute_names and attribute_names[i] in self._symbols.keys(): nodes[i]._attribute = self._symbols[attribute_names[i]] else: @@ -329,7 +335,7 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 index += 1 logger.debug(f"{subgraph._name} nodes: ") for node in subgraph._nodes: - logger.debug (node) + logger.debug(node) logger.info("parse function end") return subgraph @@ -368,15 +374,15 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 continue for i in range(len(self._nodes) - 1, -1, -1): - # logger.debug(f"node arg: {arg}, current node targets: {self._nodes[i]._targets}") + # logger.debug(f"node arg: {arg}, current node targets: {self._nodes[i]._targets}") if arg in self._nodes[i]._targets: node.inputs.append(self._nodes[i]) self._nodes[i].outputs.append(node) break if arg in self._symbols.keys(): - node.inputs.append(self._symbols[arg]) - continue + node.inputs.append(self._symbols[arg]) + continue return @@ -387,7 +393,7 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 for name, node in self._symbols.items(): logger.debug(f"name: {name}; node: {node}") if node._is_custom_define and isinstance(node._class, FunctionType): - logger.info("The node is FunctionType, node: %r", node) + logger.info("The node is FunctionType, node: %r", node) subgraph = self.parse_function(node._class) logger.debug("name = %r", name, "; subgraph = %r", subgraph) self._subgraphs["self." + name] = subgraph @@ -399,7 +405,7 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 graph.parse_init() graph.parse_construct() for n in graph._nodes: - n.name = self._base_scope + "." + name.split(".")[-1] + "." + n.name.split(".")[-1] + n.name = self._base_scope + "." + name.split(".")[-1] + "." + n.name.split(".")[-1] graph.name = name self._subgraphs[name] = graph elif node._is_custom_define and issubclass(node._class, Primitive): @@ -413,38 +419,39 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 node._index += varivable def _find_parent_node(self, child): - index = self.nodes.index(child) - for i in range(index, len(self.nodes)): - if child.targets[0] in self.nodes[i]._args: - return self.nodes[i] + index = self.nodes.index(child) + for i in range(index, len(self.nodes)): + if child.targets[0] in self.nodes[i]._args: + return self.nodes[i] def find_node_by_name(self, name: str): for n in self._nodes: if n.name == name: return n - + return None - - def find_node_by_name_and_index(self, name, start: int): #从start开始往前找targets中包含name的节点,找到后返回 + + def find_node_by_name_and_index(self, name, start: int): # 从start开始往前找targets中包含name的节点,找到后返回 for i in range(start, -1, -1): if name in self._nodes[i].targets: return self._nodes[i] - - return None - def insert_before(self, n: Optional[Node]=None): + return None + + def insert_before(self, n: Optional[Node] = None): if n: - return _insert_point(self, n) + return _insert_point(self, n, True) else: - return _insert_point(self, self._nodes[0]) - - def insert_after(self, n: Optional[Node]=None): #如果插入到嵌套节点的中间怎么处理?? + return _insert_point(self, self._nodes[0], True) + + def insert_after(self, n: Optional[Node] = None): # 如果插入到嵌套节点的中间怎么处理?? if n: - return _insert_point(self, n) #TODO: 这里后续节点的确认,按照代码行的话需要找到index为该节点index +1的节点 + return _insert_point(self, n, False) # TODO: 这里后续节点的确认,按照代码行的话需要找到index为该节点index +1的节点 else: - return _insert_point(self, self._nodes[1]) + return _insert_point(self, self._nodes[1], False) - def create_node(self, op, name: str=None, args: Optional[List]=None, kwargs: Optional[Dict]=None, target: Optional[str]=None, for_exec=True): #节点只有一个输出???? + def create_node(self, op, name: str = None, args=None, kwargs=None, + target: Optional[str] = None, for_exec=True): # 节点只有一个输出???? """ using the parameters to create a new node and insert into the graph. Args: @@ -454,10 +461,14 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 kwargs (Optional[Dict]): [description] name (Optional[str]): [description] """ + if kwargs is None: + kwargs = {} + if args is None: + args = [] local_symbols = locals() print("local_symbols: ", local_symbols) self._parser.update_closure_namespace(op.construct) - + if isinstance(op, str): class_, aa, is_custom_define = self._parser.get_func_namesapce(op) elif issubclass(op, nn.Cell) or issubclass(op, Primitive): @@ -466,33 +477,33 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 else: logger.warning(f"the op type not supported, op: {op}") - #if is_custom_define: + # if is_custom_define: # self._parser.update_global_namespace(class_.__name__) print("class is: ", class_, "; is_custom_define: ", is_custom_define, "; aa: ", aa) print("update_closure_namespace: ", self._parser.closure_namespace) print("update_global_namespace: ", self._parser.global_namespace) - #TODO: 判断args和kwargs的参数,如果参数是对象的话则将对象保存并将对应的对象名称保存,然后在最好生成的类中作为属性添加进去,证明不可行,还是会报未定义符号 - #new_arg = [] - #for arg in args: - #print("arg type: ", type(arg)) - #arg_name = arg - #if type(arg) not in [int, str, bool]: - #print("arg name: ", arg.__name__) - # for key in local_symbols: - # if local_symbols[key] == args: - #print("key: ", key) - # arg_name = "self." + arg.__class__.__name__.lower() - # self._new_obj[arg_name] = arg - # op_obj = CalledObjectRepresentation(arg_name, class_=arg.__class__, args=[], kwargs=[], is_custom_define=True) - # self._symbols[name] = op_obj - # ast_node = self._obj_to_astnode(op_obj) - # self._ast_function_root["__init__"].body.append(ast_node) - #arg = arg_name - #new_arg.append(arg_name) - #continue - #new_arg.append(arg_name) - #print("new_arg: ", new_arg) + # TODO: 判断args和kwargs的参数,如果参数是对象的话则将对象保存并将对应的对象名称保存,然后在最好生成的类中作为属性添加进去,证明不可行,还是会报未定义符号 + # new_arg = [] + # for arg in args: + # print("arg type: ", type(arg)) + # arg_name = arg + # if type(arg) not in [int, str, bool]: + # print("arg name: ", arg.__name__) + # for key in local_symbols: + # if local_symbols[key] == args: + # print("key: ", key) + # arg_name = "self." + arg.__class__.__name__.lower() + # self._new_obj[arg_name] = arg + # op_obj = CalledObjectRepresentation(arg_name, class_=arg.__class__, args=[], kwargs=[], is_custom_define=True) + # self._symbols[name] = op_obj + # ast_node = self._obj_to_astnode(op_obj) + # self._ast_function_root["__init__"].body.append(ast_node) + # arg = arg_name + # new_arg.append(arg_name) + # continue + # new_arg.append(arg_name) + # print("new_arg: ", new_arg) for i in range(len(args)): if isinstance(args[i], str): if args[i] not in self._symbols.keys(): @@ -502,11 +513,12 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 args[i] = "self." + args[i] # create a op object in init function name = name if name else class_.__name__ - #name = name if name is not None else target + # name = name if name is not None else target name = self._create_name(name) - op_obj = CalledObjectRepresentation("self." + name, class_=class_, args=args, kwargs=kwargs, is_custom_define=is_custom_define) + op_obj = CalledObjectRepresentation("self." + name, class_=class_, args=args, kwargs=kwargs, + is_custom_define=is_custom_define) self._symbols[name] = op_obj - #create ast node and insert into ast + # create ast node and insert into ast ast_node = self._obj_to_astnode(op_obj) self._ast_function_root["__init__"].body.append(ast_node) @@ -525,10 +537,12 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 n._called_obj = op_obj n.type = NodeType.call_cell - index = self._nodes.index(self._insert_pos) + index = self._nodes.index(self._insert_pos) + if not self._insert_before: + index += 1 self._nodes.insert(index, n) - #create ast node and insert into ast + # create ast node and insert into ast ast_node = self._node_to_astnode(n) index = self._insert_pos._index self._ast_function_root["construct"].body.insert(index, ast_node) @@ -538,23 +552,23 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 def replace_node(self): pass - - def remove_node(self, n: Node) -> None: #要不要把删除的节点返回?? + + def remove_node(self, n: Node) -> None: # 要不要把删除的节点返回?? """ 只删除没有后续节点的节点,在删除节点前要先修改后续节点的输入 """ - #if len(n.outputs) > 0: + # if len(n.outputs) > 0: # logger.warning("the number of outputs is not equal to 0, can not remove.") # return if isinstance(n._ast_node, ast.Assign): if hasattr(n._ast_node.value, "args") and len(n._ast_node.value.args) == 1: - for node in self._nodes: #找到该节点然后修改节点指向的ast节点 + for node in self._nodes: # 找到该节点然后修改节点指向的ast节点 if node._ast_node is n._ast_node.value.args[0]: node._ast_node = n._ast_node n._ast_node.value = n._ast_node.value.args[0] - elif hasattr(n._ast_node.value, "args") and len(n._ast_node.value.args) > 1: + elif hasattr(n._ast_node.value, "args") and len(n._ast_node.value.args) > 1: args_ = [] - for a in n._ast_node.value.args: # 将所有的args放到一个tuple里面 然后赋值给value + for a in n._ast_node.value.args: # 将所有的args放到一个tuple里面 然后赋值给value args_.append(a) n._ast_node.value = ast.Tuple(lineno=0, col_offset=0, elts=args_, ctx=ast.Load()) else: # TODO: 需要刷新节点指向的ast节点 @@ -574,7 +588,7 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 self._namespace[n.name] -= 1 self._nodes.remove(n) self._update_nodes_index(index, -1) - + def _create_name(self, name): if name in self._namespace.keys(): self._namespace[name] += 1 @@ -587,16 +601,17 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 print("called obj: ", called_obj) ast_node = unparse.create_assign(targets=[called_obj.name]) if called_obj._constant_value == None: - #value = unparse.create_call("nn." + called_obj._class.__name__, args=called_obj.args, kwargs=called_obj.kwargs) # TODO: 要根据命名空间来设置 + # value = unparse.create_call("nn." + called_obj._class.__name__, args=called_obj.args, kwargs=called_obj.kwargs) # TODO: 要根据命名空间来设置 print("called_obj._class type ====== ", type(called_obj._class)) - module_name = called_obj._class.__module__ #TODO: 通过import对应的包后,这里可以不用再拼接了 - #if not called_obj._is_custom_define: + module_name = called_obj._class.__module__ # TODO: 通过import对应的包后,这里可以不用再拼接了 + # if not called_obj._is_custom_define: # module_name = "mindspore.nn." - #else: + # else: # module_name = "" print("module_name ====== ", module_name) - value = unparse.create_call(module_name + "." + called_obj._class.__name__, args=called_obj.args, kwargs=called_obj.kwargs) # TODO: 要根据命名空间来设置 + value = unparse.create_call(module_name + "." + called_obj._class.__name__, args=called_obj.args, + kwargs=called_obj.kwargs) # TODO: 要根据命名空间来设置 ast_node.value = value return ast_node @@ -615,30 +630,30 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 def check(self): pass - + @property def python_code(self): return astunparse.unparse(self._ast_root) def convert_to_cell(self) -> nn.Cell: - #construct_func_str = astunparse.unparse(self._ast_function_root["construct"]) - #code_obj = compile(construct_func_str, filename="construct", mode="exec") - - #construct_func_obj = FunctionType(code_obj.co_consts[0], globals=globals(), name="construct") - #self._network_obj.construct = types.MethodType(construct_func_obj, self._network_obj) - #return self._network_obj - - #unparser = ASTUnparser(self._network_class, self.python_code) + # construct_func_str = astunparse.unparse(self._ast_function_root["construct"]) + # code_obj = compile(construct_func_str, filename="construct", mode="exec") + + # construct_func_obj = FunctionType(code_obj.co_consts[0], globals=globals(), name="construct") + # self._network_obj.construct = types.MethodType(construct_func_obj, self._network_obj) + # return self._network_obj + + # unparser = ASTUnparser(self._network_class, self.python_code) unparser = ASTUnparser(self._network_class, self) - new_cell = unparser.get_res_cell()#(self._param_value) - #new_cell.setattr() - # for key, value in self._new_obj.items(): - # print("key: ", key, "; value: ", value) - #setattr(new_cell, key, value) - # new_cell[key] = value - #source_code = inspect.getsource(new_cell) - #print("===================", source_code) - new_obj = new_cell(88) + new_cell = unparser.get_res_cell() # (self._param_value) + # new_cell.setattr() + # for key, value in self._new_obj.items(): + # print("key: ", key, "; value: ", value) + # setattr(new_cell, key, value) + # new_cell[key] = value + # source_code = inspect.getsource(new_cell) + # print("===================", source_code) + new_obj = new_cell(88) return new_obj def print_graph(self): @@ -647,6 +662,7 @@ class Graph: #Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 def deep_copy(self): pass + class SubGraph(Graph, Node): def __init__(self, network: Union[nn.Cell, Primitive, FunctionType]): Graph.__init__(self, network) @@ -662,7 +678,7 @@ class SubGraph(Graph, Node): self.create_placeholder(self._args) def __repr__(self) -> str: - input_names= "" + input_names = "" input_nodes = "" output_names = "" output_nodes = "" @@ -673,7 +689,7 @@ class SubGraph(Graph, Node): for n in self.outputs: output_names += n.name + ", " output = f"subgraph: {self._name}; ast_node: {self._ast_node}; args: {self._args}; targets: {self._targets}; input num: {len(self.inputs)}; input names: {input_names}; output num: {len(self.outputs)}; output name: {output_names}\n" - # for node in self._nodes: - # output += repr(node) + "\n" - + # for node in self._nodes: + # output += repr(node) + "\n" + return output diff --git a/mindspore/python/mindspore/rewrite/node.py b/mindspore/python/mindspore/rewrite/node.py index c01f5dfd646..16389a7a191 100644 --- a/mindspore/python/mindspore/rewrite/node.py +++ b/mindspore/python/mindspore/rewrite/node.py @@ -71,7 +71,7 @@ class BaseNode: self._name = name @property - def inputs(self) -> list(): + def inputs(self) -> list: return self._inputs @inputs.setter diff --git a/mindspore/python/mindspore/rewrite/pattern_engine_for_cell.py b/mindspore/python/mindspore/rewrite/pattern_engine_for_cell.py index 7450b398c39..387b3735542 100644 --- a/mindspore/python/mindspore/rewrite/pattern_engine_for_cell.py +++ b/mindspore/python/mindspore/rewrite/pattern_engine_for_cell.py @@ -39,19 +39,19 @@ class PatternEngineForCell(PatternEngine): def __init__(self, pattern: Union[PatternNode, List], replacement: callable = None): super(PatternEngineForCell, self).__init__(pattern, replacement) - def apply_cell(self, net: Cell, graph: Graph): + def apply_cell(self, net: Cell, rewriter: Rewriter): """ Apply current pattern to a cell. output graph cannot be used to compile and executed yet since graph is not allowed to add cell instance only support chain-like and multi cells to one cell pattern match yet Args: net (Cell): net to be transformed. - graph (Graph): net to be transformed. + rewriter (Rewriter): rewriter to be transformed. Returns: If graph been changed. """ - root: Node = graph.root() + root: Node = rewriter.nodes()[-1] changed = False queue: [Node] = [root] while len(queue) > 0: @@ -68,12 +68,14 @@ class PatternEngineForCell(PatternEngine): net.insert_child_to_cell(nodes[0].name.split('.', 1)[1], processed_cell) # relink nodes at origin graph for i in range(len(nodes)): - graph.remove_node(nodes[i]) + rewriter.erase_node(nodes[i]) nodes[-1].outputs[0].inputs.remove(nodes[-1]) - with graph.insert_before(nodes[-1].outputs[0]): - inserted_node: Node = graph.create_node(type(processed_cell), nodes[0].name.split('.', 1)[1], []) - nodes[-1].outputs[0].update_args(0, inserted_node.targets[0]) - new_inputs = nodes[-1].outputs + position = rewriter.before(nodes[-1].outputs[0]) + inserted_node = rewriter.add_object(position, processed_cell, + nodes[0].name.split('.', 1)[1], nodes[-1].targets, "", + nodes[0].inputs[0].targets, {}) + nodes[-1].outputs[0].update_args(0, inserted_node.targets[0]) + new_inputs = nodes[-1].outputs for i in range(1, len(nodes)): net.insert_child_to_cell(nodes[i].name.split('.', 1)[1], Identity()) queue.extend(new_inputs) diff --git a/mindspore/python/mindspore/rewrite/rewriter.py b/mindspore/python/mindspore/rewrite/rewriter.py index 56b43f60fba..837d3ffeced 100644 --- a/mindspore/python/mindspore/rewrite/rewriter.py +++ b/mindspore/python/mindspore/rewrite/rewriter.py @@ -1,31 +1,75 @@ -import ast from types import FunctionType -from typing import Dict, Union +from typing import Optional, Union import mindspore.nn as nn from mindspore.ops.primitive import Primitive +from mindspore.nn import Cell +from .argument import Argument from .graph import Graph +from .node import Node class Rewriter: - def __init__(self) -> None: - pass - - @staticmethod - def parse(network: Union[nn.Cell, Primitive, FunctionType]) -> Graph: + def __init__(self, network: Union[nn.Cell, Primitive, FunctionType]) -> None: + self._graph: Graph = Graph(network) if isinstance(network, nn.Cell): - graph = Graph(network) - graph.create_ast() - graph.get_function_root() - graph.parse_init() - graph.parse_init_subgraph() - graph.parse_functions() - graph.parse_construct() + self._graph.create_ast() + self._graph.get_function_root() + self._graph.parse_init() + self._graph.parse_init_subgraph() + self._graph.parse_functions() + self._graph.parse_construct() elif isinstance(network, FunctionType): - graph = Graph(network) - graph.create_placeholder(graph._ast_root.body[0]) - elif isinstance(network, Primitive): - graph = Graph(network) - return graph + self._graph.create_placeholder(self._graph._ast_root.body[0]) + self._insert_pos = None + + def nodes(self) -> {}: + return self._graph.nodes + + def before(self, node_or_name: Union[Node, str]): + if isinstance(node_or_name, Node): + return True, node_or_name + else: + node = self._graph.find_node_by_name(node_or_name) + return True, node + + def after(self, node_or_name: Union[Node, str]): + if isinstance(node_or_name, Node): + return False, node_or_name + else: + node = self._graph.find_node_by_name(node_or_name) + return False, node + + def add_object(self, position, custom_obj: Cell, field: str, targets: [str], target_type: str, + call_args: [Argument], call_kwargs: {str: Argument}) -> Optional[Node]: + before, node = position + # todo resolve args + if before: + if node == self._graph.nodes[0]: + with self._graph.insert_before(): + inserted_node = self._graph.create_node(type(custom_obj), field) + return inserted_node + with self._graph.insert_before(node): + inserted_node = self._graph.create_node(type(custom_obj), field) + return inserted_node + else: + with self._graph.insert_after(node): + inserted_node = self._graph.create_node(type(custom_obj), field) + return inserted_node + + def erase_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: + if isinstance(node_or_name, Node): + self._graph.remove_node(node_or_name) + return node_or_name + else: + node = self._graph.find_node_by_name(node_or_name) + self._graph.remove_node(node) + return node + + def get_network(self) -> Cell: + return self._graph.convert_to_cell() + + def find_node(self, name: str) -> Node: + return self._graph.find(name) @staticmethod def check_graph(graph): diff --git a/mindspore/python/mindspore/rewrite/test.py b/mindspore/python/mindspore/rewrite/test.py index 7800c6e8ce2..c188d8e5ecc 100644 --- a/mindspore/python/mindspore/rewrite/test.py +++ b/mindspore/python/mindspore/rewrite/test.py @@ -54,7 +54,7 @@ if __name__ == "__main__": #node_remove.args.clear() node_remove.outputs.clear() - graph.remove_node(node_remove) + graph.erase_node(node_remove) # print(graph.python_code) #graph.print_ast() #conv_cell = nn.Conv2d(16, 16, 3) diff --git a/mindspore/python/mindspore/rewrite/test_app.py b/mindspore/python/mindspore/rewrite/test_app.py index 7a76b5e7ec9..093968c7f3b 100644 --- a/mindspore/python/mindspore/rewrite/test_app.py +++ b/mindspore/python/mindspore/rewrite/test_app.py @@ -62,7 +62,7 @@ def test_ControlFlow(): #node_remove.args.clear() node_remove.outputs.clear() - graph.remove_node(node_remove) + graph.erase_node(node_remove) # print(graph.python_code) #graph.print_ast() #conv_cell = nn.Conv2d(16, 16, 3) -- Gitee From 5cf82cf122ecf0bdcb8bf83426d13d2789df0aca Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Tue, 8 Feb 2022 11:53:12 +0800 Subject: [PATCH 28/32] use rewriter to support qat --- .../default_qat/default_fake_quantizer.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py index 0a3d98d9dad..c14f2241a74 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_fake_quantizer.py @@ -52,8 +52,10 @@ class DefaultFakeQuantizerPerLayer(FakeQuantizer): self._is_ascend = context.get_context("device_target") == "Ascend" quant_func = Q.FakeQuantPerLayer self._init_fake_quant_func(quant_func) - self._float_min = Parameter(Tensor(np.array([-6]).astype(np.float32), mindspore.float32), name="float_min") - self._float_max = Parameter(Tensor(np.array([6]).astype(np.float32), mindspore.float32), name="float_max") + self._float_min = Parameter(Tensor(np.array([-6]).astype(np.float32), mindspore.float32), + name="float_min", requires_grad=False) + self._float_max = Parameter(Tensor(np.array([6]).astype(np.float32), mindspore.float32), + name="float_max", requires_grad=False) def _init_fake_quant_func(self, quant_func): if self._is_ascend: @@ -91,10 +93,13 @@ class DefaultFakeQuantizerPerChannel(DefaultFakeQuantizerPerLayer): def __init__(self, num_channels=1, channel_axis=1, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False): super(DefaultFakeQuantizerPerChannel, self).__init__(ema=ema, ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) - self._float_min = Parameter(Tensor(np.array([-6] * num_channels).astype(np.float32), mindspore.float32), name="float_min") - self._float_max = Parameter(Tensor(np.array([6] * num_channels).astype(np.float32), mindspore.float32), name="float_max") + self._float_min = Parameter(Tensor(np.array([-6] * num_channels).astype(np.float32), mindspore.float32), + name="float_min", requires_grad=False) + self._float_max = Parameter(Tensor(np.array([6] * num_channels).astype(np.float32), mindspore.float32), + name="float_max", requires_grad=False) quant_func = partial(Q.FakeQuantPerChannel, channel_axis=channel_axis) self._init_fake_quant_func(quant_func) + self._min_max_update_func = Q.MinMaxUpdatePerChannel(channel_axis=channel_axis, ema=ema, ema_decay=ema_decay) class LearnedFakeQuantizerPerLayer(FakeQuantizer): -- Gitee From a503c82943aac8ffef4fe16bd6c268719b8f7960 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Thu, 10 Feb 2022 20:44:17 +0800 Subject: [PATCH 29/32] divide create_node into 2 func --- .../mindspore/golden_stick/net_transform.py | 9 +++++ .../default_qat/default_net_policy.py | 6 +-- .../quantization/quant_aware_training.py | 4 +- mindspore/python/mindspore/rewrite/graph.py | 40 ++++++++++--------- .../rewrite/pattern_engine_for_cell.py | 9 +++-- .../python/mindspore/rewrite/rewriter.py | 31 +++++++------- 6 files changed, 57 insertions(+), 42 deletions(-) diff --git a/mindspore/python/mindspore/golden_stick/net_transform.py b/mindspore/python/mindspore/golden_stick/net_transform.py index f4772180165..f5d19fecf59 100644 --- a/mindspore/python/mindspore/golden_stick/net_transform.py +++ b/mindspore/python/mindspore/golden_stick/net_transform.py @@ -51,6 +51,15 @@ class NetTransformer: def after(self, node_or_name: Union[Node, str]): return self._rewriter.after(node_or_name) + @staticmethod + def create_node(op, targets: [Argument], target_type: str = None, args: [Argument] = None, + kwargs: {str: Argument} = None, name: str = None) -> Node: + return Rewriter.create_node(op, targets, name=name) + + def insert(self, position, node: Node, field: str = None, args: [Argument] = None, + kwargs: {str: Argument} = None) -> Optional[Node]: + return self._rewriter.insert(position, node, field, args, kwargs) + def add_object(self, position, custom_obj: Cell, field: str, targets: [str], target_type: str, call_args: [Argument], call_kwargs: {str: Argument}) -> Optional[Node]: self._net.insert_child_to_cell(field, custom_obj) diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py index 0f66d97bdd0..f6b3f0a6341 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py @@ -50,9 +50,9 @@ class DefaultNetworkPolicy(NetPolicy): # Transformer([Conv2d, ReLU]), # PatternEngine([Conv2dBnAct], _split_conv2d_bn_act), PatternEngineForCell([Conv2d, BatchNorm2d], Conv2dBnAct.fuse), - PatternEngineForCell([Dense, BatchNorm2d], DenseBnAct.fuse), - PatternEngineForCell([Conv2d, BatchNorm2d, ReLU], Conv2dBnAct.fuse), - PatternEngineForCell([Dense, BatchNorm2d, ReLU], DenseBnAct.fuse), + # PatternEngineForCell([Dense, BatchNorm2d], DenseBnAct.fuse), + # PatternEngineForCell([Conv2d, BatchNorm2d, ReLU], Conv2dBnAct.fuse), + # PatternEngineForCell([Dense, BatchNorm2d, ReLU], DenseBnAct.fuse), ] self._layer_policy_map: dict = { Conv2d: ConvLayerPolicy([], [], config), diff --git a/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py b/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py index 3241d03f537..a70f751c6b0 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py +++ b/mindspore/python/mindspore/golden_stick/quantization/quant_aware_training.py @@ -153,8 +153,8 @@ class QuantAwareTraining(GoldenStick): # todo erase_node have bug yet # net_transformer.erase_node(target_node) position = net_transformer.before(target_node.outputs[0]) - net_transformer.add_object(position, result_cell, target_node.name.split('.', 1)[1], target_node.targets, "", - target_node.targets, {}) + node = NetTransformer.create_node(type(result_cell), target_node.targets, name=target_node.name.split(".")[1]) + net_transformer.insert(position, node, node.name) @staticmethod def _apply_layer_policy(net_transformer: NetTransformer): diff --git a/mindspore/python/mindspore/rewrite/graph.py b/mindspore/python/mindspore/rewrite/graph.py index 714f54c836c..292f979442a 100644 --- a/mindspore/python/mindspore/rewrite/graph.py +++ b/mindspore/python/mindspore/rewrite/graph.py @@ -19,6 +19,7 @@ from .parser import Parser from .ast_unparser import ASTUnparser from .node_visitor import _node_list from .called_object_representation import CalledObjectRepresentation +from .argument import Argument # from _subgraph import SubGraph @@ -450,12 +451,19 @@ class Graph: # Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 else: return _insert_point(self, self._nodes[1], False) - def create_node(self, op, name: str = None, args=None, kwargs=None, + @staticmethod + def create_node(op, targets: [Argument], target_type: str = None, args: [Argument] = None, + kwargs: {str: Argument} = None, name: str = None): + n = Node(name, targets=targets, args=args, kwargs=kwargs) + n.class_ = op + return n + + def insert_node(self, node: Node, name: str = None, args=None, kwargs=None, target: Optional[str] = None, for_exec=True): # 节点只有一个输出???? """ using the parameters to create a new node and insert into the graph. Args: - op ([type]): [description] + node ([Node]): [description] target (str): [description] args (Optional[List]): [description] kwargs (Optional[Dict]): [description] @@ -467,15 +475,10 @@ class Graph: # Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 args = [] local_symbols = locals() print("local_symbols: ", local_symbols) - self._parser.update_closure_namespace(op.construct) + # self._parser.update_closure_namespace(op.construct) - if isinstance(op, str): - class_, aa, is_custom_define = self._parser.get_func_namesapce(op) - elif issubclass(op, nn.Cell) or issubclass(op, Primitive): - class_ = op - _, aa, is_custom_define = self._parser.get_func_namesapce(class_.__name__) - else: - logger.warning(f"the op type not supported, op: {op}") + class_ = node.class_ + _, aa, is_custom_define = self._parser.get_func_namesapce(class_.__name__) # if is_custom_define: # self._parser.update_global_namespace(class_.__name__) @@ -533,22 +536,22 @@ class Graph: # Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 if not for_exec: return target = target if target else name - n = Node(self._base_scope + "." + target, targets=[target], graph=self, args=[], kwargs={}) - n._called_obj = op_obj - n.type = NodeType.call_cell + node._called_obj = op_obj + node.type = NodeType.call_cell + node.name = self._base_scope + "." + node.name index = self._nodes.index(self._insert_pos) if not self._insert_before: index += 1 - self._nodes.insert(index, n) + self._nodes.insert(index, node) # create ast node and insert into ast - ast_node = self._node_to_astnode(n) + ast_node = self._node_to_astnode(node) index = self._insert_pos._index self._ast_function_root["construct"].body.insert(index, ast_node) - n._ast_node = ast_node + node._ast_node = ast_node self._update_nodes_index(index, 1) - return n + return node def replace_node(self): pass @@ -657,7 +660,8 @@ class Graph: # Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 return new_obj def print_graph(self): - pass + for node in self._nodes: + print(repr(node)) def deep_copy(self): pass diff --git a/mindspore/python/mindspore/rewrite/pattern_engine_for_cell.py b/mindspore/python/mindspore/rewrite/pattern_engine_for_cell.py index 387b3735542..af752975d15 100644 --- a/mindspore/python/mindspore/rewrite/pattern_engine_for_cell.py +++ b/mindspore/python/mindspore/rewrite/pattern_engine_for_cell.py @@ -71,11 +71,12 @@ class PatternEngineForCell(PatternEngine): rewriter.erase_node(nodes[i]) nodes[-1].outputs[0].inputs.remove(nodes[-1]) position = rewriter.before(nodes[-1].outputs[0]) - inserted_node = rewriter.add_object(position, processed_cell, - nodes[0].name.split('.', 1)[1], nodes[-1].targets, "", - nodes[0].inputs[0].targets, {}) + node = Rewriter.create_node(type(processed_cell), nodes[-1].targets, target_type=None, + args=nodes[0].inputs[0].targets, kwargs={}, name=nodes[0].name.split('.', 1)[1]) + inserted_node = rewriter.insert(position, node) + inserted_node.outputs = nodes[-1].outputs nodes[-1].outputs[0].update_args(0, inserted_node.targets[0]) - new_inputs = nodes[-1].outputs + new_inputs = nodes[0].inputs for i in range(1, len(nodes)): net.insert_child_to_cell(nodes[i].name.split('.', 1)[1], Identity()) queue.extend(new_inputs) diff --git a/mindspore/python/mindspore/rewrite/rewriter.py b/mindspore/python/mindspore/rewrite/rewriter.py index 837d3ffeced..e90ab9bb323 100644 --- a/mindspore/python/mindspore/rewrite/rewriter.py +++ b/mindspore/python/mindspore/rewrite/rewriter.py @@ -39,21 +39,22 @@ class Rewriter: node = self._graph.find_node_by_name(node_or_name) return False, node - def add_object(self, position, custom_obj: Cell, field: str, targets: [str], target_type: str, - call_args: [Argument], call_kwargs: {str: Argument}) -> Optional[Node]: - before, node = position - # todo resolve args + @staticmethod + def create_node(op, targets: [Argument], target_type: str = None, args: [Argument] = None, + kwargs: {str: Argument} = None, name: str = None) -> Node: + node = Graph.create_node(op, targets, target_type, args, kwargs, name) + return node + + def insert(self, position, node: Node, field: str = None, args: [Argument] = None, kwargs: {str: Argument} = None) \ + -> Optional[Node]: + before, target = position if before: - if node == self._graph.nodes[0]: - with self._graph.insert_before(): - inserted_node = self._graph.create_node(type(custom_obj), field) - return inserted_node - with self._graph.insert_before(node): - inserted_node = self._graph.create_node(type(custom_obj), field) + with self._graph.insert_before(target): + inserted_node = self._graph.insert_node(node, field, args, kwargs) return inserted_node else: - with self._graph.insert_after(node): - inserted_node = self._graph.create_node(type(custom_obj), field) + with self._graph.insert_after(target): + inserted_node = self._graph.insert_node(node, field, args, kwargs) return inserted_node def erase_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: @@ -75,9 +76,9 @@ class Rewriter: def check_graph(graph): return graph.check() - @staticmethod - def print_graph(graph): - graph.print_graph() + def dump(self, title: str): + print(f"=========={title}===================================================================================") + self._graph.print_graph() @staticmethod def copy_graph(graph): -- Gitee From 38ff5e777975249005e840d094fd83d25edb15ab Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Thu, 10 Feb 2022 20:44:17 +0800 Subject: [PATCH 30/32] divide create_node into 2 func --- .../python/mindspore/golden_stick/net_transform.py | 2 ++ mindspore/python/mindspore/rewrite/rewriter.py | 10 +--------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/mindspore/python/mindspore/golden_stick/net_transform.py b/mindspore/python/mindspore/golden_stick/net_transform.py index f5d19fecf59..d40ca9c9790 100644 --- a/mindspore/python/mindspore/golden_stick/net_transform.py +++ b/mindspore/python/mindspore/golden_stick/net_transform.py @@ -82,6 +82,8 @@ class NetTransformer: return None return self._rewriter.erase_node(node) + def dump(self, title: str): + self._rewriter.dump(title) # replace src_pattern with target_nodes. # target_nodes should has same inputs and outputs with src_pattern. diff --git a/mindspore/python/mindspore/rewrite/rewriter.py b/mindspore/python/mindspore/rewrite/rewriter.py index e90ab9bb323..690b014f286 100644 --- a/mindspore/python/mindspore/rewrite/rewriter.py +++ b/mindspore/python/mindspore/rewrite/rewriter.py @@ -72,14 +72,6 @@ class Rewriter: def find_node(self, name: str) -> Node: return self._graph.find(name) - @staticmethod - def check_graph(graph): - return graph.check() - def dump(self, title: str): print(f"=========={title}===================================================================================") - self._graph.print_graph() - - @staticmethod - def copy_graph(graph): - graph.deep_copy() + self._graph.print_graph() \ No newline at end of file -- Gitee From ccc6266f6bd831118cd3120b1a7e5560e1db194f Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Fri, 11 Feb 2022 10:21:52 +0800 Subject: [PATCH 31/32] node fill attribute --- .../quantization/default_qat/default_net_policy.py | 6 +++--- mindspore/python/mindspore/rewrite/graph.py | 2 ++ mindspore/python/mindspore/rewrite/node.py | 6 ++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py index f6b3f0a6341..0f66d97bdd0 100644 --- a/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py +++ b/mindspore/python/mindspore/golden_stick/quantization/default_qat/default_net_policy.py @@ -50,9 +50,9 @@ class DefaultNetworkPolicy(NetPolicy): # Transformer([Conv2d, ReLU]), # PatternEngine([Conv2dBnAct], _split_conv2d_bn_act), PatternEngineForCell([Conv2d, BatchNorm2d], Conv2dBnAct.fuse), - # PatternEngineForCell([Dense, BatchNorm2d], DenseBnAct.fuse), - # PatternEngineForCell([Conv2d, BatchNorm2d, ReLU], Conv2dBnAct.fuse), - # PatternEngineForCell([Dense, BatchNorm2d, ReLU], DenseBnAct.fuse), + PatternEngineForCell([Dense, BatchNorm2d], DenseBnAct.fuse), + PatternEngineForCell([Conv2d, BatchNorm2d, ReLU], Conv2dBnAct.fuse), + PatternEngineForCell([Dense, BatchNorm2d, ReLU], DenseBnAct.fuse), ] self._layer_policy_map: dict = { Conv2d: ConvLayerPolicy([], [], config), diff --git a/mindspore/python/mindspore/rewrite/graph.py b/mindspore/python/mindspore/rewrite/graph.py index 292f979442a..f3f153e8937 100644 --- a/mindspore/python/mindspore/rewrite/graph.py +++ b/mindspore/python/mindspore/rewrite/graph.py @@ -215,6 +215,7 @@ class Graph: # Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 elif attribute_name and attribute_name in self._symbols.keys(): logger.debug(f"defined in init function: {attribute_name}") node._called_obj = self._symbols[attribute_name] + node.attribute = node._called_obj.kwargs return node elif n == "return": # node._called_obj._type = NodeType.output @@ -235,6 +236,7 @@ class Graph: # Graph可以分为CellGraph, FunctionGraph, PrimitiveGraph都继 else: node._called_obj._is_custom_define = is_custom_define_ node._called_obj._class = class_ + node.attribute = node._called_obj.kwargs # node._called_obj._type = NodeType.call_function return node else: # TODO: 常量节点会走到else,需要做处理 diff --git a/mindspore/python/mindspore/rewrite/node.py b/mindspore/python/mindspore/rewrite/node.py index 16389a7a191..b3e63091569 100644 --- a/mindspore/python/mindspore/rewrite/node.py +++ b/mindspore/python/mindspore/rewrite/node.py @@ -177,13 +177,15 @@ class Node(BaseNode): def _update_inputs(self): for old_input in self.inputs: old_input.outputs.remove(self) + if self._graph._nodes.index(self) == 0: + return self.inputs.clear() for arg in self.args: - pre_node = self._graph.find_node_by_name_and_index(arg, self._graph._nodes.index(self)) + pre_node = self._graph.find_node_by_name_and_index(arg, self._graph._nodes.index(self)-1) pre_node.outputs.append(self) self.inputs.append(pre_node) for _, arg in self.kwargs.items(): - pre_node = self._graph.find_node_by_name_and_index(arg, self._graph._nodes.index(self)) + pre_node = self._graph.find_node_by_name_and_index(arg, self._graph._nodes.index(self)-1) pre_node.outputs.append(self) self.inputs.append(pre_node) -- Gitee From e99ec2b1f601e00d53bcfc1de796f98c9830d5b8 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Tue, 15 Feb 2022 16:20:11 +0800 Subject: [PATCH 32/32] switch to rewrite_exp --- .../mindspore/golden_stick/net_transform.py | 34 ++++++------------- .../rewrite_experiment/pattern_engine.py | 29 +++++++++------- 2 files changed, 26 insertions(+), 37 deletions(-) diff --git a/mindspore/python/mindspore/golden_stick/net_transform.py b/mindspore/python/mindspore/golden_stick/net_transform.py index d40ca9c9790..fe2508359a7 100644 --- a/mindspore/python/mindspore/golden_stick/net_transform.py +++ b/mindspore/python/mindspore/golden_stick/net_transform.py @@ -16,10 +16,10 @@ from typing import Union, Optional from mindspore.nn.cell import Cell -from mindspore.rewrite import Graph, PatternEngine -from mindspore.rewrite.rewriter import Rewriter, Argument +from mindspore.rewrite_experiment.pattern_engine import PatternEngine +from mindspore.rewrite_experiment.rewrite import Rewrite +from mindspore.rewrite_experiment.common.argument import Argument from mindspore.rewrite.node import BaseNode, Node -from mindspore.rewrite.pattern_engine_for_cell import PatternEngineForCell class NetTransformer: @@ -32,7 +32,7 @@ class NetTransformer: def __init__(self, net: Cell): self._net = net - self._rewriter = Rewriter(net) + self._rewriter = Rewrite(net) def get_network(self) -> Cell: return self._net @@ -54,36 +54,24 @@ class NetTransformer: @staticmethod def create_node(op, targets: [Argument], target_type: str = None, args: [Argument] = None, kwargs: {str: Argument} = None, name: str = None) -> Node: - return Rewriter.create_node(op, targets, name=name) + return Rewrite.create_node(op, targets, target_type, args, kwargs, name) def insert(self, position, node: Node, field: str = None, args: [Argument] = None, - kwargs: {str: Argument} = None) -> Optional[Node]: + kwargs: {str: Argument} = None) -> Optional[Node]: return self._rewriter.insert(position, node, field, args, kwargs) - def add_object(self, position, custom_obj: Cell, field: str, targets: [str], target_type: str, - call_args: [Argument], call_kwargs: {str: Argument}) -> Optional[Node]: - self._net.insert_child_to_cell(field, custom_obj) - return self._rewriter.add_object(position, custom_obj, field, targets, target_type, call_args, call_kwargs) - def erase_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: """ Args: - node (BaseNode): node to be removed from original network. + node_or_name (Node/str): node to be removed from original network. Returns: BaseNode has been removed, return None if failed """ + return self._rewriter.erase_node(node_or_name) - if isinstance(node_or_name, str): - node = self._rewriter.find_node(node_or_name) - else: - node = node_or_name - if node is None: - return None - return self._rewriter.erase_node(node) - - def dump(self, title: str): - self._rewriter.dump(title) + def dump(self): + self._rewriter.dump() # replace src_pattern with target_nodes. # target_nodes should has same inputs and outputs with src_pattern. @@ -95,6 +83,4 @@ class NetTransformer: Returns: a bool value indicating if transform occurred """ - if isinstance(pattern_engine, PatternEngineForCell): - return pattern_engine.apply_cell(self._net, self._rewriter) return pattern_engine.apply(self._rewriter) diff --git a/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py b/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py index b888226b379..244a7020e00 100644 --- a/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py +++ b/mindspore/python/mindspore/rewrite_experiment/pattern_engine.py @@ -53,8 +53,8 @@ class PatternNode: """ pattern_node = PatternNode(node.get_targets()[0]) - if node.get_node_type() is NodeType.call_cell: - pattern_node._type = node.get_cell_type() + if node.get_node_type() is NodeType.CallCell: + pattern_node._type = node.get_op_type() return pattern_node @staticmethod @@ -127,7 +127,7 @@ class PatternNode: node (Node) : a rewrite node to be match. """ - return self._type == node.get_cell_type() + return self._type == node.get_op_type() def inputs(self): """ @@ -146,7 +146,7 @@ class PatternNode: """ Getter of type. """ - return self._typelllllll + return self._type class VarNode(PatternNode): @@ -206,7 +206,7 @@ class PatternEngine: If graph been changed. """ - root: Node = rewrite.get_root() + root: Node = rewrite.nodes()[-1] changed = False # IR match queue: [Node] = [root] @@ -232,8 +232,12 @@ class PatternEngine: else: # return Node to insert or replace (new Node no need to set inputs and outputs) # todo if we need to support _process_chain or _process_tree return multi-node changed = True - rewrite.replace_node(matched_list, new_node) - node_inputs = new_node.get_inputs() + origin_next: Node = matched_list[-1].get_next() + for node in matched_list: + rewrite.erase_node(node) + position = rewrite.before(origin_next) + inserted_node = rewrite.insert(position, new_node) + node_inputs = inserted_node.get_inputs() for node_input in node_inputs: queue.append(node_input) return changed @@ -251,18 +255,17 @@ class PatternEngine: if self._replacement is None: return matched_nodes[len(matched_nodes) - 1] - replacement = self._replacement(*matched_nodes) + matched_cells = [node.get_op() for node in matched_nodes] + replacement: Cell = self._replacement(*matched_cells) if replacement is None: return None - if len(matched_nodes) == 0: - new_node = Node(instance=replacement) - else: - new_node = Node(instance=replacement, inputs=matched_nodes[0].inputs) node_name = "" for matched_node in matched_nodes: node_name += matched_node.name + "_" node_name += "fused" - new_node.name = node_name + new_node = Rewrite.create_node(replacement, matched_nodes[-1].get_targets(), + args=matched_nodes[0].get_args(), kwargs=matched_nodes[0].get_kwargs(), + name=node_name) return new_node # matched_cells: name_of_cell_in_pattern map to matched cell in network -- Gitee