diff --git a/mindspore/python/mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py b/mindspore/python/mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py index 1329c59ae185b49a1d2cfaf5ba3a5179a13cb36d..919c451faea4a3373e805ad0a15d2beb442f4ef8 100644 --- a/mindspore/python/mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +++ b/mindspore/python/mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py @@ -23,6 +23,12 @@ class FlattenRecursiveStmt(ast.NodeTransformer): """Ast optimizer for flatten recursive call.""" def __init__(self): + """ + Constructor of FlattenRecursiveStmt. + + Returns: + An instance of ast optimizer for flatten recursive call. + """ self._flatten_table: dict = { ast.Return: ["value"], ast.Call: ["args"], @@ -33,6 +39,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer): @staticmethod def _generate_target_name(node: ast.AST, target_names): + """Generate unique target name.""" if isinstance(node, ast.Call): func = node.func if isinstance(func, ast.Name): @@ -59,6 +66,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer): @staticmethod def _fill_in_original_target_names(target_names, node): + """Fill in original target names before getting unique names.""" for function_index in range(len(node.body)): child = node.body[function_index] if not isinstance(child, ast.Assign): @@ -73,12 +81,14 @@ class FlattenRecursiveStmt(ast.NodeTransformer): @staticmethod def _create_new_assign_node(node: ast.AST, target_names) -> Tuple[str, ast.AST]: + """Create new assign node to be inserted into ast.FunctionDef.""" if isinstance(node, (ast.Name, ast.Constant, ast.Num, ast.Str, ast.NameConstant, ast.Bytes, ast.Ellipsis)): return "", node new_target_name = FlattenRecursiveStmt._generate_target_name(node, target_names) return new_target_name, ast.Assign(targets=[ast.Name(id=new_target_name, ctx=ast.Store())], value=node) def _flatten_statement(self, node: ast.AST, target_names) -> [ast.AST]: + """Flatten recursive statement according to different node type.""" flatten_config = self._flatten_table.get(type(node)) if flatten_config is None: return [] @@ -113,6 +123,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer): return results def visit_FunctionDef(self, node: FunctionDef) -> Any: + """Traverse construct node and flatten recursive nodes.""" if node.name != "construct": return node @@ -137,6 +148,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer): return node def transform(self, ast_root): + """Interface of FlattenRecursiveStmt.""" ast_root = self.visit(ast_root) ast_root = ast.fix_missing_locations(ast_root) return ast_root diff --git a/mindspore/python/mindspore/rewrite/node.py b/mindspore/python/mindspore/rewrite/node.py index 5d3b350f5b813d3cacce7e3cf2cb03dab91a500e..fa91c83eac0f7e191f85033e4864505acee157ed 100644 --- a/mindspore/python/mindspore/rewrite/node.py +++ b/mindspore/python/mindspore/rewrite/node.py @@ -216,7 +216,18 @@ class Node: @staticmethod def _get_construct_arg_names(parameters): """ - todo xiongkun + Static method of Node. Get parameters' names of the construct function. + + Args: + parameters (mappingProxy): An ordered mapping of parameters' names to the corresponding Parameter objects. + + Raises: + RuntimeError: Invalid parameter kind. + + Returns: + arg_names: Parameters' names, contain parameters of types in [POSITIONAL_ONLY, POSITIONAL_OR_KEYWORD]. + var_positional_name: Name of VAR_POSITIONAL parameters. + var_keyword_name: Name of VAR_KEYWORD parameters. """ position_only_names: [str] = [] positional_or_keyword_names: [str] = [] @@ -244,7 +255,19 @@ class Node: def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict: """ - todo xiongkun + Merge args and kwargs to normalized args. + The keys of args are obtained from the construct function of type(self._instance). + + Args: + args ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + kwargs ({str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. + + Raises: + RuntimeError: Input args are invalid. + RuntimeError: Arg name already exist in kwargs. + + Returns: + The normalized Args. """ if not args: args = [] @@ -793,10 +816,9 @@ class Node: arg (ScopedValue): An instance of ScopedValue represents argument. Raises: - RuntimeError: If 'key' is a valid key of argument. + AssertError: If 'key' is not in original kwargs' keys. """ - # todo xiongkun check range of kwargs - assert key in self._normalized_args_keys and key in self._normalized_args.keys() + assert key in self._normalized_args_keys[self._args_num:] and key in self._normalized_args.keys() self._normalized_args[key] = arg self._sync_arg() diff --git a/mindspore/python/mindspore/rewrite/parsers/arguments_parser.py b/mindspore/python/mindspore/rewrite/parsers/arguments_parser.py index d691bf1b2de7c9fea0ac0c9507e334cbae9c06f5..67347a1f1eb071f8c45c6caa7a1054d60a880e87 100644 --- a/mindspore/python/mindspore/rewrite/parsers/arguments_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/arguments_parser.py @@ -24,9 +24,20 @@ class ArgumentsParser(Parser): """Parse ast.arguments to input-node of SymbolTree.""" def target(self): + """Parse target type""" return ast.arguments def process(self, stree: SymbolTree, node: ast.arguments): + """ + Parse ast.argumnets and create input-node to stree. + + Args: + stree (SymbolTree): symbol tree under parsing. + node (ast.arguments): argument node in construct. + + Raises: + RuntimeError: Types of node.args elements are not ast.arg. + """ if hasattr(node, "posonlyargs"): stree.try_append_python_node(node, node.posonlyargs) diff --git a/mindspore/python/mindspore/rewrite/parsers/assign_parser.py b/mindspore/python/mindspore/rewrite/parsers/assign_parser.py index c7c5b0191bc1f4efc57eac5f4ff593dc8003b793..42d0cc449b580837b7b18c614d5339c0618c9b2b 100644 --- a/mindspore/python/mindspore/rewrite/parsers/assign_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/assign_parser.py @@ -29,26 +29,52 @@ class AssignParser(Parser): """Parse ast.Assign in construct function to node of SymbolTree.""" def target(self): + """Parse target type.""" return ast.Assign @staticmethod def _create_scopedvalue_from_tuple_ast(node: ast.Tuple) -> ScopedValue: + """ + Create ScopedValue from a tuple ast node. + + Args: + node (ast.Tuple): A tuple node. + + Returns: + An instance of ScopedValue. + + Raises: + RuntimeError: Only support ast.Constant as elts of ast.Tuple. + """ tuple_elts = node.elts tuple_values = [] for tuple_elt in tuple_elts: if not isinstance(tuple_elt, ast.Constant): - raise RuntimeError("Only support ast.Constant as elts of ast.Tuple") + raise RuntimeError("Only support ast.Constant as elts of ast.Tuple.") tuple_values.append(tuple_elt.value) return ScopedValue.create_variable_value(tuple(tuple_values)) @staticmethod def _create_scopedvalue(node: ast.expr) -> ScopedValue: + """ + Create ScopedValue from an ast node. + + Args: + node (ast.expr): An ast node. + + Returns: + An instance of ScopedValue. + + Raises: + RuntimeError: Value of target of ast.Assign should be an ast.Name when target is an ast.Attribute. + RuntimeError: Type of input node is unsupported. + """ if isinstance(node, ast.Name): return ScopedValue.create_naming_value(node.id) if isinstance(node, ast.Attribute): scope = node.value if not isinstance(scope, ast.Name): - raise RuntimeError("value of target of ast.Assign should be a ast.Name when target is a ast.Attribute") + raise RuntimeError("value of target of ast.Assign should be a ast.Name when target is a ast.Attribute.") return ScopedValue.create_naming_value(node.attr, scope.id) if isinstance(node, ast.Tuple): return AssignParser._create_scopedvalue_from_tuple_ast(node) @@ -58,6 +84,18 @@ class AssignParser(Parser): @staticmethod def _get_func_name(ast_node: ast.Call) -> str: + """ + Get the func name from ast.Call. + + Args: + ast_node (ast.Call): Input ast.Call node. + + Returns: + Func name. + + Raises: + RuntimeError: Func of input ast node is not ast.Name or ast.Attribute. + """ func = ast_node.func if isinstance(func, ast.Name): return func.id @@ -68,6 +106,19 @@ class AssignParser(Parser): @staticmethod def _get_func_scope(ast_node: ast.Call) -> str: + """ + Get the func scope from ast.Call. + + Args: + ast_node (ast.Call): Input ast.Call node. + + Returns: + Func scope. + + Raises: + RuntimeError: FuncValue is not an ast.Name when func is an ast.Attribute. + RuntimeError: Func of input ast node is not ast.Name or ast.Attribute. + """ func = ast_node.func if isinstance(func, ast.Name): return "" @@ -80,6 +131,16 @@ class AssignParser(Parser): @staticmethod def _get_symbol_object(symbol_name, origin_net): + """ + Get the func scope from ast.Call. + + Args: + symbol_name (str): Func name. + origin_net ([nn.Cell]): Network instance. + + Returns: + Symbol Object. + """ var_dict = origin_net.__dict__ for key, value in var_dict["_cells"].items(): if key == symbol_name: @@ -92,6 +153,18 @@ class AssignParser(Parser): @staticmethod def _create_kwargs(keywords: [ast.keyword]) -> {str, ScopedValue}: + """ + Transfer ast.Call keywords to a dict of ScopedValue when creating a symbol tree node. + + Args: + keywords ([ast.keyword]): Keywords of ast.Call node. + + Returns: + A dict of ScopedValue. + + Raises: + AssertError: Type of keyword is not ast.keyword. + """ results = {} for keyword in keywords: assert isinstance(keyword, ast.keyword) @@ -100,6 +173,20 @@ class AssignParser(Parser): @staticmethod def _convert_ast_call_to_node(ast_node: ast.Call, father_ast_node: ast.Assign, stree: SymbolTree) -> Node: + """ + Convert ast.Call to a symbol tree node. + + Args: + ast_node ([ast.Call]): An ast.Call of assign node in construct. + father_ast_node ([ast.Assign]): Assign node in construct. + stree ([SymbolTree]): Symbol Tree under parsing. + + Returns: + An instance of Node in Symbol Tree. + + Raises: + RuntimeError: kwargs in construct function assign is unsupported. + """ target = AssignParser._create_scopedvalue(father_ast_node.targets[0]) func_name = AssignParser._get_func_name(ast_node) func_scope = AssignParser._get_func_scope(ast_node) @@ -108,7 +195,7 @@ class AssignParser(Parser): call_args = [AssignParser._create_scopedvalue(arg) for arg in ast_node.args] call_kwargs = AssignParser._create_kwargs(ast_node.keywords) if len(ast_node.keywords) > 0: - raise RuntimeError("kwargs in construct function assign is unsupported") + raise RuntimeError("kwargs in construct function assign is unsupported.") obj = AssignParser._get_symbol_object(func_name, stree.get_origin_network()) # need check if node is a callmethod, like: x = len(x) @@ -122,6 +209,21 @@ class AssignParser(Parser): return Node.create_call_cell(obj, father_ast_node, [target], func, call_args, call_kwargs, func_name) def process(self, stree: SymbolTree, node: ast.Assign): + """ + Parse ast.Assign and create a node in symbol tree. + Will create node when value of ast.Assign is in [ast.Call, ast.Name, ast.Constant, ast.Attribute]. + Will create python node when value of ast.Assign is in + [ast.BinOp, ast.BoolOp, ast.Subscript, ast.List, ast.Tuple, ast.Dict]. + Other value types are not supported. + + Args: + stree ([SymbolTree]): Symbol Tree under parsing. + node ([ast.Assign]): An ast.Assign node. + + Raises: + RuntimeError: Only support one target in assign now. + RuntimeError: Unsupported node type in construct function. + """ targets = node.targets if len(targets) != 1: raise RuntimeError("Only support one target in assign now") diff --git a/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py b/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py index 219d9d5aeab8676e1f8ceca67083370cff37fa12..b025f9d504d15657b90be0a3e46374f62d4e46e8 100644 --- a/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py @@ -27,10 +27,12 @@ class ClassDefParser(Parser): """Parse ast.ClassDef which is subclass of Cell to SymbolTree.""" def target(self): + """Parse target type""" return ast.ClassDef @staticmethod def _process_init_func_ast(init_ast: ast.FunctionDef, ori_cls_name: str, opt_cls_name: str): + """Process init func""" super_index = ClassDefParser._modify_super_expr_of_init_func(init_ast, ori_cls_name, opt_cls_name) ClassDefParser._modify_arguments_of_init_func(init_ast) ClassDefParser._replace_ori_field_of_init_func(init_ast.body, super_index) @@ -38,6 +40,7 @@ class ClassDefParser(Parser): @staticmethod def _modify_super_expr_of_init_func(ast_init_fn: ast.FunctionDef, ori_cls_name: str, opt_cls_name: str) -> int: + """Modify network name in super(XXnet).__init__()""" if not ast_init_fn.body: return -1 super_index = -1 @@ -71,6 +74,7 @@ class ClassDefParser(Parser): @staticmethod def _modify_arguments_of_init_func(ast_init_fn: ast.FunctionDef): + """Replace init function input parameters to self and global_vars.""" arg_self = ast.arg(arg="self") arg_global_vars = ast.arg(arg="global_vars") ast_init_fn.args = ast.arguments(args=[arg_self, arg_global_vars], posonlyargs=[], kwonlyargs=[], @@ -78,6 +82,20 @@ class ClassDefParser(Parser): @staticmethod def _replace_ori_field_of_init_func(bodies: [], super_index: int): + """ + Replace original field in init func to self.XX = getattr(self._handler, "XX"). + Only keep following two kinds of ast nodes in bodies right now: + 1. Ast.If and test is self.XX. + 2. Ast.Assign and target is self.XX. + + Args: + bodies ([]): bodied of init ast.FunctionDef. + super_index (int): index of super().__init__() in bodies. + + Raises: + RuntimeError: Not support multi-targets in assign. + RuntimeError: Only support target.value in [ast.Name] in assign node. + """ body_index_to_be_deleted = [] for body_index, body in enumerate(bodies): if body_index == super_index: @@ -111,6 +129,7 @@ class ClassDefParser(Parser): @staticmethod def _insert_handler_to_init_func(ast_init_fn: ast.FunctionDef, super_index): + """Insert 'self._handler = global_vars.get('handler')' to init ast.FunctionDef.body""" if super_index == -1: super_index = 0 AstModifier.insert_assign_to_function(ast_init_fn, [ScopedValue.create_naming_value("_handler", "self")], @@ -119,6 +138,13 @@ class ClassDefParser(Parser): ast_init_fn.body[super_index], False) def process(self, stree: SymbolTree, node: ast.ClassDef): + """ + Parse init and construct in ast.ClassDef. + + Args: + stree ([SymbolTree]): Symbol Tree under parsing. + node ([ast.ClassDef]): An ast.ClassDef node. + """ # change class name assert node.name == stree.get_ori_cls_name() node.name = stree.get_opt_cls_name() diff --git a/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py b/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py index ef607e794bfbebb2815da4b2fe9d3d2f4d506616..1d8c83a232c9abe79e16312e1e0f0d3f77347a24 100644 --- a/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py @@ -24,9 +24,11 @@ class FunctionDefParser(Parser): """Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree.""" def target(self): + """Parse target type""" return ast.FunctionDef def process(self, stree: SymbolTree, node: ast.FunctionDef): + """Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree.""" stree.set_ast_root(node) # parse args as inputs of stree arguments: ast.arguments = node.args diff --git a/mindspore/python/mindspore/rewrite/parsers/module_parser.py b/mindspore/python/mindspore/rewrite/parsers/module_parser.py index f4e31e50cdc127eb2afc3b9607917531b74442dd..8574db9f62c22a3acfac992bcbc3575d86fb33bd 100644 --- a/mindspore/python/mindspore/rewrite/parsers/module_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/module_parser.py @@ -30,12 +30,15 @@ class ClassFinder(ast.NodeVisitor): """Find all ast.ClassDef in input ast node.""" def __init__(self): + """Keep all found ast.ClassDef in self._classes""" self._classes: [ast.ClassDef] = [] def visit_ClassDef(self, node: ast.ClassDef) -> Any: + """Iterate over all nodes and save ast.ClassDef nodes.""" self._classes.append(node) def find_all_classes(self, node: ast.AST) -> [ast.ClassDef]: + """Interface of ClassFinder.""" self.visit(node) return self._classes @@ -44,10 +47,12 @@ class ModuleParser(Parser): """Parse ast.Module to SymbolTrees.""" def target(self): + """Parse target type""" return ast.Module @staticmethod def _find_class(ast_node: ast.Module) -> ast.ClassDef: + """Find all ast.ClassDef in ast.Module, only support one ast.ClassDef in ast.Module now.""" visitor = ClassFinder() classes = visitor.find_all_classes(ast_node) if not classes: @@ -58,18 +63,23 @@ class ModuleParser(Parser): @staticmethod def get_import_node(ast_root): + """Iterate over ast_root and return all ast.Import nodes or ast.ImportFrom nodes in ast_root.""" import_nodes = [] class GetImportNode(ast.NodeVisitor): + """Find all import nodes from input ast node.""" def visit_Import(self, node: ast.Import) -> Any: + """Iterate over all nodes and save ast.Import nodes.""" import_nodes.append(copy.deepcopy(node)) return node def visit_ImportFrom(self, node: ast.ImportFrom) -> Any: + """Iterate over all nodes and save ast.ImportFrom nodes.""" import_nodes.append(copy.deepcopy(node)) return node def get_node(self, input_ast): + """Interface of GetImportNode.""" self.generic_visit(input_ast) return True @@ -79,6 +89,7 @@ class ModuleParser(Parser): @staticmethod def _add_import_to_module(module: ast.Module, origin_net): + """Insert two groups of import nodes to ast.Module, common ones and those from class definition file.""" assert module is not None module.body.insert(0, ast.Import([ast.alias('mindspore')])) module.body.insert(1, ast.ImportFrom(module='mindspore', names=[ast.alias('nn')], level=0)) @@ -97,6 +108,7 @@ class ModuleParser(Parser): ast.fix_missing_locations(module) def process(self, stree: SymbolTree, node: ast.Module): + """Process ast.ClassDef nodes in ast.Module.""" ModuleParser._add_import_to_module(node, stree.get_origin_network()) class_ast = ModuleParser._find_class(node) stree.set_class_ast(class_ast) diff --git a/mindspore/python/mindspore/rewrite/parsers/return_parser.py b/mindspore/python/mindspore/rewrite/parsers/return_parser.py index 4464506a5318136b47c99f90c41ee3706e986d0d..d70c707f32106c479d74ff03cf7029e29a2e2588 100644 --- a/mindspore/python/mindspore/rewrite/parsers/return_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/return_parser.py @@ -25,9 +25,11 @@ class ReturnParser(Parser): """Parse ast.Return output-node of SymbolTree.""" def target(self): + """Parse target type""" return ast.Return def process(self, stree: SymbolTree, node: ast.Return): + """Parse ast.Return to output-node of SymbolTree.""" return_value = node.value if not isinstance(return_value, ast.Name): raise RuntimeError("Only ast.Name as return value")