From aecf8050bc2804693674a20161f78a4df129da6b Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 17 Jan 2022 15:01:26 +0800 Subject: [PATCH 1/2] move arg normalize into cls move arg normalize into cls move arg normalize into cls move arg normalize into cls move arg normalize into cls move arg normalize into cls --- .../parsers/arguments_parser.py | 52 +------- .../mindspore/rewrite_experiment/symbol.py | 126 ++++++++++-------- .../mindspore/rewrite_experiment/test.py | 24 ---- .../{test_location.py => test_parse.py} | 35 +++++ 4 files changed, 113 insertions(+), 124 deletions(-) rename mindspore/python/mindspore/rewrite_experiment/{test_location.py => test_parse.py} (62%) diff --git a/mindspore/python/mindspore/rewrite_experiment/parsers/arguments_parser.py b/mindspore/python/mindspore/rewrite_experiment/parsers/arguments_parser.py index b332d45548b..f5f41e31784 100644 --- a/mindspore/python/mindspore/rewrite_experiment/parsers/arguments_parser.py +++ b/mindspore/python/mindspore/rewrite_experiment/parsers/arguments_parser.py @@ -24,44 +24,13 @@ from ..registers import ParserRegister @ParserRegister.reg_parser class ArgumentsParser(Parser): def process(self, symbol: Symbol) -> [Symbol]: - class ArgOrDefault: - def __init__(self, lineno, col_offset, symbol_name) -> None: - self._lineno = lineno - self._clo_offset = col_offset - self._symbol_name: Value = Value.create_symbol_value(symbol_name) - - @property - def symbol_name(self): - return self._symbol_name - - @property - def lineno(self): - return self._lineno - - @property - def col_offset(self): - return self._clo_offset - - def _create_single_arg_symbol(scope: str, arg_ast: ast.arg, a_list: list, a_with_d_value: OrderedDict, - n_symbols: list): + def _create_single_arg_symbol(scope: str, arg_ast: ast.arg, n_symbols: list): if not isinstance(arg_ast, ast.arg): raise RuntimeError("{} of arguments should be ast.arg".format(scope)) a_symbol = Symbol(ast_node=arg_ast, symbol_type=SymbolType.arg) arguments_symbol.add_arg(a_symbol) n_symbols.append(a_symbol) - single_a = ArgOrDefault(arg_ast.lineno, arg_ast.col_offset, a_symbol.get_full_name_with_scope()) - a_list.append(single_a) - a_with_d_value[single_a.symbol_name] = Value.create_empty_value() - - def _find_corresponding_name(defaults: list[ArgOrDefault], names: list[ArgOrDefault], a_with_d_value): - for d in defaults: - i = 0 - while i < len(names) and names[i].lineno == d.lineno and names[i].col_offset < d.col_offset: - i += 1 - if i <= len(names): - a_with_d_value[names[i - 1].symbol_name] = d.symbol_name - if not isinstance(symbol.symbol_ast, ast.arguments): return [symbol] new_symbols = [] @@ -70,23 +39,20 @@ class ArgumentsParser(Parser): arguments_symbol = ArgumentsSymbol(ast_arguments, symbol.scope, symbol.symbol_name) arguments_symbol.set_symbol_location(symbol.get_symbol_location()) - arg_with_default_value = OrderedDict() - args = [] for arg in ast_arguments.args: if arg.arg == "self": continue - _create_single_arg_symbol('args', arg, args, arg_with_default_value, new_symbols) + _create_single_arg_symbol('args', arg, new_symbols) for arg in ast_arguments.kwonlyargs: - _create_single_arg_symbol('kwonlyargs', arg, args, arg_with_default_value, new_symbols) + _create_single_arg_symbol('kwonlyargs', arg, new_symbols) if ast_arguments.vararg is not None: - _create_single_arg_symbol('vararg', ast_arguments.vararg, args, arg_with_default_value, new_symbols) + _create_single_arg_symbol('vararg', ast_arguments.vararg, new_symbols) if ast_arguments.kwarg is not None: - _create_single_arg_symbol('kwarg', ast_arguments.vararg, args, arg_with_default_value, new_symbols) + _create_single_arg_symbol('kwarg', ast_arguments.kwarg, new_symbols) - defaults = [] for single_default in ast_arguments.defaults: if not isinstance(single_default, ast.Constant): raise RuntimeError("only support constant default value in arguments now") @@ -95,12 +61,6 @@ class ArgumentsParser(Parser): arguments_symbol.add_defaults(default_symbol) new_symbols.append(default_symbol) - single_d = ArgOrDefault(single_default.lineno, single_default.col_offset, - default_symbol.get_full_name_with_scope()) - defaults.append(single_d) - - _find_corresponding_name(defaults, args, arg_with_default_value) - arguments_symbol.set_arg_with_default_value(arg_with_default_value) - + arguments_symbol.normalize_args() new_symbols.append(arguments_symbol) return new_symbols diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index 6728005c3fa..e2e14d8df83 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -165,18 +165,18 @@ class Symbol: # add scope / add suffix symbol.set_scope(self.get_full_name_with_scope()) symbol.update_symbol_name() - self._child_syms.append(Value.create_symbol_value(symbol.get_full_name_with_scope())) def update_symbol_name(self): global sn_depot self._symbol_name = sn_depot.set_symbol_name(self._scope, self._symbol_name) + def _update_child_syms(self): + raise RuntimeError("_update_child_syms not implemented") + def get_child_syms(self) -> [Value]: + self._update_child_syms() return self._child_syms - def set_child_syms(self, syms): - self._child_syms = syms - class SubgraphSymbol(Symbol): def __init__(self, ast_node: Optional[ast.AST], scope, symbol_name=None, symbol_type=SymbolType.invalid): @@ -208,12 +208,11 @@ class ModuleSymbol(SubgraphSymbol): assert isinstance(ast_node, ast.Module) self.finish_compile() - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [] for body in self._bodies: syms.append(body) - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"ModuleSymbol({self.get_full_name_with_scope()})" @@ -229,12 +228,11 @@ class ClassSymbol(SubgraphSymbol): def get_class_name(self) -> str: return self._class_name - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [] for body in self._bodies: syms.append(body) - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"ClassSymbol({self.get_full_name_with_scope()})" @@ -263,12 +261,11 @@ class FunctionSymbol(SubgraphSymbol): def arguments(self, arguments): self._arguments = arguments - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [self._arguments] for body in self._bodies: syms.append(body) - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"FunctionSymbol({self.get_full_name_with_scope()})" @@ -309,12 +306,11 @@ class AssignSymbol(Symbol): def value(self, value: Value): self._value = value - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [self._value] for target in self._targets: syms.append(target) - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"AssignSymbol({self.get_full_name_with_scope()}): {str(self._targets)} = {self._value}" @@ -339,39 +335,69 @@ class IfSymbol(SubgraphSymbol): def cond(self, cond: Value): self._cond = cond - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [self._cond] for body in self._bodies: syms.append(body) - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"IfSymbol({self.get_full_name_with_scope()})" class ArgumentsSymbol(Symbol): - def __init__(self, ast_node: Optional[ast.AST], scope, symbol_name, arg_with_default_value: OrderedDict = None): + def __init__(self, ast_node: Optional[ast.AST], scope, symbol_name): super().__init__(ast_node, scope, symbol_name, SymbolType.arguments) - self.args = [] - self.defaults = [] - self._arg_with_default_value = arg_with_default_value + self._arg_symbols: [Symbol] = [] # only use for normalize args + self._default_symbols: [Symbol] = [] # only use for normalize args + self._a_w_d = {} # only use for normalize args + self.arg_with_default_value: OrderedDict = OrderedDict() self.finish_compile() def add_arg(self, symbol): self.add_child_symbol(symbol) + self._arg_symbols.append(symbol) def add_defaults(self, symbol): self.add_child_symbol(symbol) - - def set_arg_with_default_value(self, arg_with_default_value: OrderedDict): - self._arg_with_default_value = arg_with_default_value + self._default_symbols.append(symbol) + + def _find_corresponding_name(self): + for d in self._default_symbols: + i = 0 + while i < len(self._arg_symbols) \ + and self._arg_symbols[i].symbol_location.lineno == d.symbol_location.lineno \ + and self._arg_symbols[i].symbol_location.col_offset < d.symbol_location.col_offset: + i += 1 + if i <= len(self._arg_symbols): + self._a_w_d[self._arg_symbols[i - 1].get_full_name_with_scope()] = d.get_full_name_with_scope() + + def normalize_args(self): + self._find_corresponding_name() + for arg in self._arg_symbols: + arg_full_name = arg.get_full_name_with_scope() + if arg_full_name in self._a_w_d.keys(): + self.arg_with_default_value[Value.create_symbol_value(arg_full_name)] = \ + Value.create_symbol_value(self._a_w_d[arg_full_name]) + else: + self.arg_with_default_value[Value.create_symbol_value(arg_full_name)] = Value.create_empty_value() + + self._arg_symbols = [] + self._default_symbols = [] + self._a_w_d = {} def get_normalized_args(self): - return self._arg_with_default_value + return self.arg_with_default_value def set_normalized_args(self, arg_with_default_value: OrderedDict): - self._arg_with_default_value = arg_with_default_value + self.arg_with_default_value = arg_with_default_value + + def _update_child_syms(self) -> [Value]: + syms = [] + for k, v in self.arg_with_default_value.items(): + syms.append(k) + syms.append(v) + self._child_syms = syms def __str__(self): return f"ArgumentsSymbol({self.get_full_name_with_scope()})" @@ -422,14 +448,13 @@ class CallSymbol(Symbol): func_name = func_name.replace(".", "-") return func_name - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [self._func] for arg in self._args: syms.append(arg) for keyword in self._keywords: syms.append(keyword) - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"CallSymbol({self.get_full_name_with_scope()}): {self._func}({str(self._args)})" @@ -458,10 +483,9 @@ class AttributeSymbol(Symbol): def set_attribute_attr(self, attr: Value): self._attr = attr - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [self._value, self._attr] - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"AttributeSymbol({self.get_full_name_with_scope()}): {self._value}.{self._attr}" @@ -486,9 +510,9 @@ class BinopSymbol(Symbol): self.add_child_symbol(symbol) self._right = Value.create_symbol_value(symbol.get_full_name_with_scope()) - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [self._left, self._right] - return syms + self._child_syms = syms def get_left(self) -> Value: return self._left @@ -525,10 +549,9 @@ class ReturnSymbol(Symbol): def set_return_value(self, value: Value): self._value = value - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [self._value] - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"ReturnSymbol({self.get_full_name_with_scope()}): {self._value.value}" @@ -548,10 +571,9 @@ class UnaryOpSymbol(Symbol): self.add_child_symbol(symbol) self._operand = Value.create_symbol_value(symbol.get_full_name_with_scope()) - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [self._operand] - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"UnaryOpSymbol({self.get_full_name_with_scope()}): {type(self._op).__name__}({self._operand.value})" @@ -568,10 +590,9 @@ class KeywordSymbol(Symbol): self.add_child_symbol(symbol) self._value = Value.create_symbol_value(symbol.get_full_name_with_scope()) - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [self._arg, self._value] - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"KeywordSymbol({self.get_full_name_with_scope()}): {self._arg}({self._value.value})" @@ -590,10 +611,9 @@ class ConstantSymbol(Symbol): def get_value_type(self) -> ValueType: return self._value.type - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [self._value] - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"ConstantSymbol({self.get_full_name_with_scope()}): {self._value}" @@ -611,10 +631,9 @@ class NameSymbol(Symbol): def set_name(self, name: Value): self._name = name - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [self._name] - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"NameSymbol({self.get_full_name_with_scope()}): {self._name}" @@ -632,10 +651,9 @@ class ArgSymbol(Symbol): def set_arg(self, arg: Value): self._arg = arg - def get_child_syms(self) -> [Value]: + def _update_child_syms(self) -> [Value]: syms = [self._arg] - self.set_child_syms(syms) - return syms + self._child_syms = syms def __str__(self): return f"ArgSymbol({self.get_full_name_with_scope()}): {self._arg}" diff --git a/mindspore/python/mindspore/rewrite_experiment/test.py b/mindspore/python/mindspore/rewrite_experiment/test.py index 45aca612225..63451095f4f 100644 --- a/mindspore/python/mindspore/rewrite_experiment/test.py +++ b/mindspore/python/mindspore/rewrite_experiment/test.py @@ -111,29 +111,5 @@ def test_compile(): return stb -def test_binop(stb): - stb_table = stb._table - for key, value in stb_table.items(): - if "BinOp" in key.split('.')[-1]: - test_symbol = stb_table[key] - logger.warning("BinOp of symbol {} is {}".format(key, test_symbol._op)) - - -def test_arguments(stb): - symbol_names = [ - '.Module(LeNet5).ClassDef(LeNet5).FunctionDef(__init__).arguments', - '.Module(LeNet5).ClassDef(LeNet5).FunctionDef(construct).arguments'] - for symbol_name in symbol_names: - if symbol_name in stb._table.keys(): - test_symbol = stb._table[symbol_name] - arguments = test_symbol._arg_with_default_value - for k in arguments.keys(): - logger.warning("{}: {}".format(k, arguments[k])) - else: - logger.warning("symbol {} not exsit in stb".format(symbol_name)) - - if __name__ == '__main__': stb = test_compile() - test_binop(stb) - test_arguments(stb) diff --git a/mindspore/python/mindspore/rewrite_experiment/test_location.py b/mindspore/python/mindspore/rewrite_experiment/test_parse.py similarity index 62% rename from mindspore/python/mindspore/rewrite_experiment/test_location.py rename to mindspore/python/mindspore/rewrite_experiment/test_parse.py index 16b76d72c72..3d46fa12372 100644 --- a/mindspore/python/mindspore/rewrite_experiment/test_location.py +++ b/mindspore/python/mindspore/rewrite_experiment/test_parse.py @@ -19,6 +19,7 @@ from mindspore.common.initializer import Normal from mindspore.rewrite_experiment import Rewrite, SymbolTable from mindspore import log as logger from mindspore.rewrite_experiment.test import LeNet5 +from mindspore.rewrite_experiment.symbol import Symbol class Rewrite_test: @@ -49,6 +50,40 @@ def test_location(stb): assert stb._table[k].symbol_location.owner_cls == "LeNet5" +def test_arguments(stb): + symbol_name_1 = '.Module(LeNet5).ClassDef(LeNet5).FunctionDef(__init__).arguments' + symbol_name_2 = '.Module(LeNet5).ClassDef(LeNet5).FunctionDef(construct).arguments' + + test_symbol = stb._table[symbol_name_1] + arguments = test_symbol.arg_with_default_value + assert len(arguments.keys()) == 5 + + test_symbol = stb._table[symbol_name_2] + arguments = test_symbol.arg_with_default_value + assert len(arguments.keys()) == 1 + + +def test_binop(stb): + op_type_list = [ + "Not", "Invert", "UAdd", "USub", + "Add", "Sub", "Mult", "Div", "MatMult", "Mod", "Pow" + ] + stb_table = stb._table + for key, value in stb_table.items(): + if "BinOp" in key.split('.')[-1]: + test_symbol = stb_table[key] + assert type(test_symbol._op).__name__ in op_type_list + + +def test_child_syms(stb): + symbol_name_1 = '.Module(LeNet5).ClassDef(LeNet5).FunctionDef(__init__).arguments' + symbol: Symbol = stb._table[symbol_name_1] + assert len(symbol.get_child_syms()) != 0 + + if __name__ == '__main__': stb = parse() test_location(stb) + test_arguments(stb) + test_binop(stb) + test_child_syms(stb) -- Gitee From 26fcf564c28c3143f9e6da76a8288a70b3084af7 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 17 Jan 2022 16:48:24 +0800 Subject: [PATCH 2/2] move arg normalize into cls --- .../parsers/attribute_parser.py | 2 +- .../parsers/binop_parser.py | 4 +- .../rewrite_experiment/parsers/call_parser.py | 2 +- .../parsers/function_def_parser.py | 2 +- .../rewrite_experiment/parsers/if_parser.py | 2 +- .../parsers/keyword_parser.py | 2 +- .../parsers/return_parser.py | 2 +- .../parsers/unaryop_parser.py | 2 +- .../mindspore/rewrite_experiment/symbol.py | 148 ++++++++++-------- .../const_symbol_propagate.py | 6 +- .../flatten_call_symbol.py | 6 +- ...rename_construct_function_assign_target.py | 6 +- .../mindspore/rewrite_experiment/test.py | 4 +- 13 files changed, 99 insertions(+), 89 deletions(-) diff --git a/mindspore/python/mindspore/rewrite_experiment/parsers/attribute_parser.py b/mindspore/python/mindspore/rewrite_experiment/parsers/attribute_parser.py index 4d0e8f542ba..3f534fa3525 100644 --- a/mindspore/python/mindspore/rewrite_experiment/parsers/attribute_parser.py +++ b/mindspore/python/mindspore/rewrite_experiment/parsers/attribute_parser.py @@ -37,7 +37,7 @@ class AttributeParser(Parser): if not (isinstance(value, ast.Name) or isinstance(value, ast.Call)): raise RuntimeError("value of attribute should be ast.Name or ast.Call") value_symbol = Symbol(ast_node=value, symbol_type=SymbolType.expression) - attribute_symbol.add_value(value_symbol) + attribute_symbol.set_attribute_value(value_symbol) new_symbols.append(value_symbol) new_symbols.append(attribute_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/parsers/binop_parser.py b/mindspore/python/mindspore/rewrite_experiment/parsers/binop_parser.py index c083b4c90ff..f80a4922c19 100644 --- a/mindspore/python/mindspore/rewrite_experiment/parsers/binop_parser.py +++ b/mindspore/python/mindspore/rewrite_experiment/parsers/binop_parser.py @@ -37,14 +37,14 @@ class BinOpParser(Parser): if isinstance(left, ast.BinOp) or isinstance(left, ast.Constant): left_symbol = Symbol(ast_node=left, symbol_type=SymbolType.expression) - binop_symbol.add_left_symbol(left_symbol) + binop_symbol.set_left(left_symbol) new_symbols.append(left_symbol) else: logger.warning("Ignoring left (%s) in binop compiler", type(left).__name__) if isinstance(right, ast.BinOp) or isinstance(right, ast.Constant): right_symbol = Symbol(ast_node=right, symbol_type=SymbolType.expression) - binop_symbol.add_right_symbol(right_symbol) + binop_symbol.set_right(right_symbol) new_symbols.append(right_symbol) else: logger.warning("Ignoring right (%s) in binop compiler", type(right).__name__) diff --git a/mindspore/python/mindspore/rewrite_experiment/parsers/call_parser.py b/mindspore/python/mindspore/rewrite_experiment/parsers/call_parser.py index b01d51767b1..b2cbda5a32a 100644 --- a/mindspore/python/mindspore/rewrite_experiment/parsers/call_parser.py +++ b/mindspore/python/mindspore/rewrite_experiment/parsers/call_parser.py @@ -37,7 +37,7 @@ class CallParser(Parser): keywords_ = ast_call.keywords func_symbol = Symbol(ast_node=func_, symbol_type=SymbolType.expression) - call_symbol.add_func_symbol(func_symbol) + call_symbol.set_func(func_symbol) new_symbols.append(func_symbol) for arg in args_: diff --git a/mindspore/python/mindspore/rewrite_experiment/parsers/function_def_parser.py b/mindspore/python/mindspore/rewrite_experiment/parsers/function_def_parser.py index 979d8c58da3..74d315192db 100644 --- a/mindspore/python/mindspore/rewrite_experiment/parsers/function_def_parser.py +++ b/mindspore/python/mindspore/rewrite_experiment/parsers/function_def_parser.py @@ -35,7 +35,7 @@ class FunctionDefParser(Parser): # parse args arguments: ast.arguments = function_def.args arguments_symbol = Symbol(ast_node=arguments, symbol_type=SymbolType.arguments) - function_symbol.add_arguments(arguments_symbol) + function_symbol.set_arguments(arguments_symbol) new_symbols.append(arguments_symbol) bodies: list = function_def.body diff --git a/mindspore/python/mindspore/rewrite_experiment/parsers/if_parser.py b/mindspore/python/mindspore/rewrite_experiment/parsers/if_parser.py index db579810ad1..e16f0c831e7 100644 --- a/mindspore/python/mindspore/rewrite_experiment/parsers/if_parser.py +++ b/mindspore/python/mindspore/rewrite_experiment/parsers/if_parser.py @@ -38,7 +38,7 @@ class IfParser(Parser): if not isinstance(cond, ast.expr): raise RuntimeError("test of if should be ast.expr") test_symbol = Symbol(ast_node=cond, symbol_type=SymbolType.expression) - if_symbol.add_cond(test_symbol) + if_symbol.set_cond(test_symbol) new_symbols.append(test_symbol) for body in bodies: diff --git a/mindspore/python/mindspore/rewrite_experiment/parsers/keyword_parser.py b/mindspore/python/mindspore/rewrite_experiment/parsers/keyword_parser.py index ec7669db1aa..16d5fbb3b0d 100644 --- a/mindspore/python/mindspore/rewrite_experiment/parsers/keyword_parser.py +++ b/mindspore/python/mindspore/rewrite_experiment/parsers/keyword_parser.py @@ -35,7 +35,7 @@ class KeywordParser(Parser): keyword_symbol.set_symbol_location(symbol.get_symbol_location()) value_symbol = Symbol(ast_node=ast_keyword.value, symbol_type=SymbolType.expression) - keyword_symbol.add_value(value_symbol) + keyword_symbol.set_value(value_symbol) new_symbols.append(value_symbol) new_symbols.append(keyword_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/parsers/return_parser.py b/mindspore/python/mindspore/rewrite_experiment/parsers/return_parser.py index 864510fba6b..1de66997548 100644 --- a/mindspore/python/mindspore/rewrite_experiment/parsers/return_parser.py +++ b/mindspore/python/mindspore/rewrite_experiment/parsers/return_parser.py @@ -35,7 +35,7 @@ class ReturnParser(Parser): value = ast_return.value if isinstance(value, ast.Name): value_symbol = Symbol(ast_node=value, symbol_type=SymbolType.name) - return_symbol.add_value(value_symbol) + return_symbol.set_value(value_symbol) new_symbols.append(value_symbol) else: logger.warning("Ignoring value (%s) in return compiler", type(value).__name__) diff --git a/mindspore/python/mindspore/rewrite_experiment/parsers/unaryop_parser.py b/mindspore/python/mindspore/rewrite_experiment/parsers/unaryop_parser.py index 2df75b1c4a4..132831bc0ea 100644 --- a/mindspore/python/mindspore/rewrite_experiment/parsers/unaryop_parser.py +++ b/mindspore/python/mindspore/rewrite_experiment/parsers/unaryop_parser.py @@ -35,7 +35,7 @@ class UnaryOpParser(Parser): operand = ast_unaryop.operand operand_symbol = Symbol(ast_node=operand, symbol_type=SymbolType.expression) - unaryop_symbol.add_operand(operand_symbol) + unaryop_symbol.set_operand(operand_symbol) new_symbols.append(operand_symbol) new_symbols.append(unaryop_symbol) diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol.py index e2e14d8df83..bb25d4f923b 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol.py @@ -190,12 +190,10 @@ class SubgraphSymbol(Symbol): else: self._bodies.insert(index, Value.create_symbol_value(body.get_full_name_with_scope())) - @property - def bodies(self) -> [Value]: + def get_bodies(self) -> [Value]: return self._bodies - @bodies.setter - def bodies(self, bodies: [Value]): + def set_bodies(self, bodies: [Value]): self._bodies = bodies def set_body(self, index: int, body: Value): @@ -246,21 +244,19 @@ class FunctionSymbol(SubgraphSymbol): self._func_name = ast_node.name self.finish_compile() - def add_arguments(self, symbol: Symbol): - self.add_child_symbol(symbol) - self._arguments = Value.create_symbol_value(symbol.get_full_name_with_scope()) + def set_arguments(self, symbol: Optional[Symbol, Value]): + if isinstance(symbol, Symbol): + self.add_child_symbol(symbol) + self._arguments = Value.create_symbol_value(symbol.get_full_name_with_scope()) + else: + self._arguments = symbol def get_func_name(self) -> str: return self._func_name - @property - def arguments(self) -> Value: + def get_arguments(self) -> Value: return self._arguments - @arguments.setter - def arguments(self, arguments): - self._arguments = arguments - def _update_child_syms(self) -> [Value]: syms = [self._arguments] for body in self._bodies: @@ -286,12 +282,10 @@ class AssignSymbol(Symbol): self.add_child_symbol(value) self._value = Value.create_symbol_value(value.get_full_name_with_scope()) - @property - def targets(self): + def get_targets(self): return self._targets - @targets.setter - def targets(self, targets: [Value]): + def set_targets(self, targets: [Value]): Symbol.check_str_value_list(targets) self._targets = targets @@ -323,18 +317,16 @@ class IfSymbol(SubgraphSymbol): self._cond: Value = Value.create_empty_value() self.finish_compile() - def add_cond(self, symbol: Symbol): - self.add_child_symbol(symbol) - self._cond = Value.create_symbol_value(symbol.get_full_name_with_scope()) + def set_cond(self, symbol: Optional[Symbol, Value]): + if isinstance(symbol, Symbol): + self.add_child_symbol(symbol) + self._cond = Value.create_symbol_value(symbol.get_full_name_with_scope()) + else: + self._cond = symbol - @property - def cond(self): + def get_cond(self): return self._cond - @cond.setter - def cond(self, cond: Value): - self._cond = cond - def _update_child_syms(self) -> [Value]: syms = [self._cond] for body in self._bodies: @@ -411,9 +403,12 @@ class CallSymbol(Symbol): self._keywords: [Value] = [] self.finish_compile() - def add_func_symbol(self, symbol: Symbol): - self.add_child_symbol(symbol) - self._func = Value.create_symbol_value(symbol.get_full_name_with_scope()) + def set_func(self, symbol: Optional[Symbol, Value]): + if isinstance(symbol, Symbol): + self.add_child_symbol(symbol) + self._func = Value.create_symbol_value(symbol.get_full_name_with_scope()) + else: + self._func = symbol def add_args_symbol(self, symbol: Symbol): self.add_child_symbol(symbol) @@ -438,9 +433,6 @@ class CallSymbol(Symbol): def get_func(self) -> Value: return self._func - def set_func(self, func: Value): - self._func = func - def generate_candidate_func_name(self) -> Optional[str]: if self._func.type is not ValueType.String: return None @@ -467,22 +459,22 @@ class AttributeSymbol(Symbol): self._value: Value = Value.create_empty_value() self.finish_compile() - def add_value(self, symbol: Symbol): - self.add_child_symbol(symbol) - self._value = Value.create_symbol_value(symbol.get_full_name_with_scope()) + def set_attribute_value(self, symbol: Optional[Symbol, Value]): + if isinstance(symbol, Symbol): + self.add_child_symbol(symbol) + self._value = Value.create_symbol_value(symbol.get_full_name_with_scope()) + else: + self._value = symbol + + def set_attribute_attr(self, attr: Value): + self._attr = attr def get_attribute_value(self) -> Value: return self._value - def set_attribute_value(self, value: Value): - self._value = value - def get_attribute_attr(self) -> Value: return self._attr - def set_attribute_attr(self, attr: Value): - self._attr = attr - def _update_child_syms(self) -> [Value]: syms = [self._value, self._attr] self._child_syms = syms @@ -502,33 +494,33 @@ class BinopSymbol(Symbol): self._right: Value = Value.create_empty_value() self.finish_compile() - def add_left_symbol(self, symbol: Symbol): - self.add_child_symbol(symbol) - self._left = Value.create_symbol_value(symbol.get_full_name_with_scope()) - - def add_right_symbol(self, symbol: Symbol): - self.add_child_symbol(symbol) - self._right = Value.create_symbol_value(symbol.get_full_name_with_scope()) + def set_left(self, symbol: Optional[Symbol, Value]): + if isinstance(symbol, Symbol): + self.add_child_symbol(symbol) + self._left = Value.create_symbol_value(symbol.get_full_name_with_scope()) + else: + self._left = symbol - def _update_child_syms(self) -> [Value]: - syms = [self._left, self._right] - self._child_syms = syms + def set_right(self, symbol: Optional[Symbol, Value]): + if isinstance(symbol, Symbol): + self.add_child_symbol(symbol) + self._right = Value.create_symbol_value(symbol.get_full_name_with_scope()) + else: + self._right = symbol def get_left(self) -> Value: return self._left - def set_left(self, left: Value): - self._left = left - def get_right(self) -> Value: return self._right - def set_right(self, right: Value): - self._right = right - def get_op(self): return self._op + def _update_child_syms(self) -> [Value]: + syms = [self._left, self._right] + self._child_syms = syms + def __str__(self): return f"BinopSymbol({self.get_full_name_with_scope()}): {self._left} {self._op} {self._right}" @@ -539,16 +531,16 @@ class ReturnSymbol(Symbol): self._value = Value.create_empty_value() self.finish_compile() - def add_value(self, symbol: Symbol): - self.add_child_symbol(symbol) - self._value = Value.create_symbol_value(symbol.get_full_name_with_scope()) + def set_value(self, symbol: Optional[Symbol, Value]): + if isinstance(symbol, Symbol): + self.add_child_symbol(symbol) + self._value = Value.create_symbol_value(symbol.get_full_name_with_scope()) + else: + self._value = symbol def return_value(self) -> Value: return self._value - def set_return_value(self, value: Value): - self._value = value - def _update_child_syms(self) -> [Value]: syms = [self._value] self._child_syms = syms @@ -567,9 +559,18 @@ class UnaryOpSymbol(Symbol): self._operand: Value = Value.create_empty_value() self.finish_compile() - def add_operand(self, symbol: Symbol): - self.add_child_symbol(symbol) - self._operand = Value.create_symbol_value(symbol.get_full_name_with_scope()) + def set_operand(self, symbol: Optional[Symbol, Value]): + if isinstance(symbol, Symbol): + self.add_child_symbol(symbol) + self._operand = Value.create_symbol_value(symbol.get_full_name_with_scope()) + else: + self._operand = symbol + + def get_operand(self): + return self._operand + + def get_op(self): + return self._op def _update_child_syms(self) -> [Value]: syms = [self._operand] @@ -586,9 +587,18 @@ class KeywordSymbol(Symbol): self._value = Value.create_empty_value() self.finish_compile() - def add_value(self, symbol: Symbol): - self.add_child_symbol(symbol) - self._value = Value.create_symbol_value(symbol.get_full_name_with_scope()) + def set_value(self, symbol: Optional[Symbol, Value]): + if isinstance(symbol, Symbol): + self.add_child_symbol(symbol) + self._value = Value.create_symbol_value(symbol.get_full_name_with_scope()) + else: + self._value = symbol + + def get_arg(self): + return self._arg + + def get_value(self): + return self._value def _update_child_syms(self) -> [Value]: syms = [self._arg, self._value] diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_transformers/const_symbol_propagate.py b/mindspore/python/mindspore/rewrite_experiment/symbol_transformers/const_symbol_propagate.py index e00dc9d5293..141125bbfd0 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_transformers/const_symbol_propagate.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_transformers/const_symbol_propagate.py @@ -86,7 +86,7 @@ class ConstSymbolPropagate(Parser): def assign_symbol_propagate(self, assign_symbol: AssignSymbol): logger.debug("======== const-propagate assign symbol: %s", assign_symbol.get_full_name_with_scope()) - targets = assign_symbol.targets + targets = assign_symbol.get_targets() new_targets = [] for target_value in targets: target = self.get_symbol_by_value(target_value) @@ -94,7 +94,7 @@ class ConstSymbolPropagate(Parser): new_targets.append(target.get_value()) else: new_targets.append(target_value) - assign_symbol.targets = new_targets + assign_symbol.set_targets(new_targets) value = self.get_symbol_by_value(assign_symbol.value) if isinstance(value, ConstantSymbol): @@ -104,7 +104,7 @@ class ConstSymbolPropagate(Parser): logger.debug("======== const-propagate return symbol: %s", return_symbol.get_full_name_with_scope()) return_value = self.get_symbol_by_value(return_symbol.return_value()) if isinstance(return_value, ConstantSymbol): - return_symbol.set_return_value(return_value.get_value()) + return_symbol.set_value(return_value.get_value()) def call_symbol_propagate(self, call_symbol: CallSymbol): logger.debug("======== const-propagate call symbol: %s", call_symbol.get_full_name_with_scope()) diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_transformers/flatten_call_symbol.py b/mindspore/python/mindspore/rewrite_experiment/symbol_transformers/flatten_call_symbol.py index 7f0f8a9cd81..14a5eaff4f9 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_transformers/flatten_call_symbol.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_transformers/flatten_call_symbol.py @@ -58,7 +58,7 @@ class FlattenCallSymbol(Parser): self._names[func_name] = name_number + 1 func_name = func_name + "_" + str(name_number) assign_symbol = AssignSymbol(None, call_symbol.scope, call_symbol.symbol_name + "." + func_name) - assign_symbol.targets = [Value.create_string_value(func_name)] + assign_symbol.set_targets([Value.create_string_value(func_name)]) assign_symbol.value = arg_value new_args.append(Value.create_string_value(func_name)) results.append(assign_symbol) @@ -72,8 +72,8 @@ class FlattenCallSymbol(Parser): new_symbols = [symbol] if not isinstance(symbol, SubgraphSymbol): return new_symbols - for i in range(len(symbol.bodies) - 1, -1, -1): - child: Symbol = self.get_symbol_by_value(symbol.bodies[i]) + for i in range(len(symbol.get_bodies()) - 1, -1, -1): + child: Symbol = self.get_symbol_by_value(symbol.get_bodies()[i]) if not isinstance(child, AssignSymbol): continue call_value = self.get_symbol_by_value(child.value) diff --git a/mindspore/python/mindspore/rewrite_experiment/symbol_transformers/rename_construct_function_assign_target.py b/mindspore/python/mindspore/rewrite_experiment/symbol_transformers/rename_construct_function_assign_target.py index b867ed512cb..af3e3d4ba3c 100644 --- a/mindspore/python/mindspore/rewrite_experiment/symbol_transformers/rename_construct_function_assign_target.py +++ b/mindspore/python/mindspore/rewrite_experiment/symbol_transformers/rename_construct_function_assign_target.py @@ -46,7 +46,7 @@ class RenameConstructFuncAssignTarget(Parser): def init_last_target_symbol_with_arguments(self): if self._construct_func is None: return - arguments_value = self._construct_func.arguments + arguments_value = self._construct_func.get_arguments() if arguments_value is None or arguments_value.type is not ValueType.Symbol: return arguments: ArgumentsSymbol = self._stb.get_symbol(arguments_value.value) @@ -76,7 +76,7 @@ class RenameConstructFuncAssignTarget(Parser): arg.value = new_arg_str target_name = self.generate_final_target(call.generate_candidate_func_name()) - targets = assign_symbol.targets + targets = assign_symbol.get_targets() if len(targets) == 1: self._last_target_symbol[targets[0].value] = target_name targets[0] = Value.create_string_value(target_name) @@ -105,7 +105,7 @@ class RenameConstructFuncAssignTarget(Parser): self._construct_func = func_symbol self.init_last_target_symbol_with_arguments() - bodies = func_symbol.bodies + bodies = func_symbol.get_bodies() for body_value in bodies: body = self.get_symbol_by_value(body_value) if isinstance(body, AssignSymbol): diff --git a/mindspore/python/mindspore/rewrite_experiment/test.py b/mindspore/python/mindspore/rewrite_experiment/test.py index 63451095f4f..1abd6c35b32 100644 --- a/mindspore/python/mindspore/rewrite_experiment/test.py +++ b/mindspore/python/mindspore/rewrite_experiment/test.py @@ -69,13 +69,13 @@ def print_construct(symbols: dict): for k, v in symbols.items(): if v.symbol_type() == SymbolType.function_def and v.get_func_name() == "construct": print("Construct function: ") - bodies = v.bodies + bodies = v.get_bodies() for body_value in bodies: body = symbols.get(body_value.value) if body is None: continue if body.symbol_type() == SymbolType.assign: - msg: str = body.targets[0].value + " = " + msg: str = body.get_targets()[0].value + " = " call_value = body.value call = symbols.get(call_value.value) if call is None: -- Gitee