diff --git a/mindspore/python/mindspore/rewrite/ast_transformers/__init__.py b/mindspore/python/mindspore/rewrite/ast_transformers/__init__.py index 3901e5e9e139a0e7d209cb25e61b8cbdaa785830..e53d58ea72eaa865b5421b7eafa65dd1486c99e2 100644 --- a/mindspore/python/mindspore/rewrite/ast_transformers/__init__.py +++ b/mindspore/python/mindspore/rewrite/ast_transformers/__init__.py @@ -1 +1,2 @@ from .flatten_recursive_stmt import FlattenRecursiveStmt +from .fold_if_return import FoldIfReturn diff --git a/mindspore/python/mindspore/rewrite/ast_transformers/fold_if_return.py b/mindspore/python/mindspore/rewrite/ast_transformers/fold_if_return.py new file mode 100644 index 0000000000000000000000000000000000000000..03c9482b211c938584912ebabdf9ddc7c0b508d9 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/ast_transformers/fold_if_return.py @@ -0,0 +1,101 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Fold if return.""" + +import ast +import astpretty +import copy +from ast import FunctionDef, If, Return +from typing import Any, Union, Optional +from mindspore import log as logger + + +class FoldIfReturn(ast.NodeTransformer): + @staticmethod + def _last_node_is_return(node: Optional[Union[Return, If]]) -> bool: + """Judge whether this node represents a return.""" + if isinstance(node, Return): # last node is ast.Return + return True + elif isinstance(node, If): # last node is ast.If + if node.body and FoldIfReturn._last_node_is_return(node.body[-1]) \ + and (not node.orelse or FoldIfReturn._last_node_is_return(node.orelse[-1])): + return True + return False + + @staticmethod + def _should_fold_following_nodes(if_node: If): + """ + Judge whether to fold following nodes into if node. + Return True when encounter following example: + + def func(x): + if x == 1: + return True + return False + """ + if FoldIfReturn._last_node_is_return(if_node.body[-1]) \ + and (not if_node.orelse or not FoldIfReturn._last_node_is_return(if_node.orelse[-1])): + return True + else: + return False + + @staticmethod + def _fold_return(father_node, if_node: If, if_index: int, attr: str): + """Fold nodes.""" + if not hasattr(father_node, attr): + astpretty.pprint(father_node) + raise RuntimeError('Father node has not input attr', attr) + father_node_attr = getattr(father_node, attr) + if FoldIfReturn._should_fold_following_nodes(if_node): + for index in range(if_index + 1, len(father_node_attr)): + node = copy.deepcopy(father_node_attr[index]) + if_node.orelse.append(node) + remove_num = len(father_node_attr) - if_index - 1 + for _ in range(remove_num): + father_node_attr.pop() + else: + return + + @staticmethod + def _fold(father_node, attr: str): + """Fold nodes. Iterate into body and orelse of if node.""" + if not hasattr(father_node, attr) or not getattr(father_node, attr): + return + + if isinstance(getattr(father_node, attr)[-1], If): + FoldIfReturn._fold(getattr(father_node, attr)[-1], 'body') # if.body + FoldIfReturn._fold(getattr(father_node, attr)[-1], 'orelse') # if.orelse + + cur_index = len(getattr(father_node, attr)) - 2 # no following nodes to fold when if is the last one + while cur_index >= 0: + child = getattr(father_node, attr)[cur_index] + if isinstance(child, If): + FoldIfReturn._fold(child, 'body') # if.body + FoldIfReturn._fold(child, 'orelse') # if.orelse + FoldIfReturn._fold_return(father_node, child, cur_index, attr) + cur_index -= 1 + + def visit_FunctionDef(self, node: FunctionDef) -> Any: + """Iterate construct node and fold following nodes into if node when condition is met.""" + if node.name != "construct": + return node + FoldIfReturn._fold(node, 'body') + return node + + def transform(self, ast_root): + """Transform.""" + ast_root = self.visit(ast_root) + ast_root = ast.fix_missing_locations(ast_root) + return ast_root diff --git a/mindspore/python/mindspore/rewrite/symbol_tree_builder.py b/mindspore/python/mindspore/rewrite/symbol_tree_builder.py index eb00633e37df8f5e57423837c7b0907750788167..cc5153c600719f8264638cc9ee54fdb57d680dae 100644 --- a/mindspore/python/mindspore/rewrite/symbol_tree_builder.py +++ b/mindspore/python/mindspore/rewrite/symbol_tree_builder.py @@ -21,7 +21,7 @@ from .symbol_tree import SymbolTree from .parser_register import ParserRegister from .parser import Parser from .namespace import Namespace -from .ast_transformers import FlattenRecursiveStmt +from .ast_transformers import FlattenRecursiveStmt, FoldIfReturn class SymbolTreeBuilder: @@ -36,7 +36,7 @@ class SymbolTreeBuilder: @staticmethod def _ast_transform(ast_root: ast.AST) -> ast.AST: - transform_list = [FlattenRecursiveStmt()] + transform_list = [FlattenRecursiveStmt(), FoldIfReturn()] for transformer in transform_list: ast_root = transformer.transform(ast_root) return ast_root diff --git a/mindspore/python/mindspore/rewrite/test/test.sh b/mindspore/python/mindspore/rewrite/test/test.sh index de338e814eea8296ed1a1d808a4cea69817681e8..361331ac33fe9cc7012c1198424dcd62e1a68f03 100644 --- a/mindspore/python/mindspore/rewrite/test/test.sh +++ b/mindspore/python/mindspore/rewrite/test/test.sh @@ -11,6 +11,8 @@ pytest -s ${CUR_PATH}/ut/test_symbol_tree.py check_ret "test_symbol_tree" pytest -s ${CUR_PATH}/ut/test_node.py check_ret "test_create_node" +pytest -s ${CUR_PATH}/ut/test_fold_return.py +check_ret "test_fold_return" pytest -s ${CUR_PATH}/ut/test_flatten_recursive_stmt.py check_ret "test_flatten_recursive_stmt" # st: diff --git a/mindspore/python/mindspore/rewrite/test/ut/test_fold_return.py b/mindspore/python/mindspore/rewrite/test/ut/test_fold_return.py new file mode 100644 index 0000000000000000000000000000000000000000..fa7958365032fa8c476929c79c64f48b650bb8eb --- /dev/null +++ b/mindspore/python/mindspore/rewrite/test/ut/test_fold_return.py @@ -0,0 +1,107 @@ +import inspect +import ast + +import astunparse +from mindspore.rewrite.ast_transformers import FoldIfReturn + + +class TestIf: + """Simple test.""" + def construct(self, x): + """construct""" + if x > 2: + return x - 2 + return x + + +class TestIf2: + """Test multiple if and test if in if.""" + def construct(self, x): + """construct""" + if x > 2: + return x + x += 2 + if x > 2: + if x > 2: + return x + x += 2 + return x + x *= 2 + return x + + +class TestIf3: + """Test orelse""" + def construct(self, x): + """construct""" + x += 2 + if x > 2: + if x > 2: + return x + else: + if x > 2: + return x + return x + x *= 2 + return x + + +def test_simple_if(): + ast_root: ast.Module = ast.parse(inspect.getsource(TestIf)) + folder = FoldIfReturn() + folder.transform(ast_root) + assert astunparse.unparse(ast_root) == """\n\nclass TestIf(): + 'Simple test.'\n + def construct(self, x): + 'construct' + if (x > 2): + return (x - 2) + else: + return x +""" + + +def test_multiple_if(): + ast_root: ast.Module = ast.parse(inspect.getsource(TestIf2)) + folder = FoldIfReturn() + folder.transform(ast_root) + assert astunparse.unparse(ast_root) == """\n\nclass TestIf2(): + 'Test multiple if and test if in if.'\n + def construct(self, x): + 'construct' + if (x > 2): + return x + else: + x += 2 + if (x > 2): + if (x > 2): + return x + else: + x += 2 + return x + else: + x *= 2 + return x +""" + + +def test_orelse(): + ast_root: ast.Module = ast.parse(inspect.getsource(TestIf3)) + folder = FoldIfReturn() + folder.transform(ast_root) + assert astunparse.unparse(ast_root) == """\n\nclass TestIf3(): + 'Test orelse'\n + def construct(self, x): + 'construct' + x += 2 + if (x > 2): + if (x > 2): + return x + elif (x > 2): + return x + else: + return x + else: + x *= 2 + return x +"""