diff --git a/mindspore/__init__.py b/mindspore/__init__.py index 55fc10eff20836c8f8446bff86641805d8bc52fe..1236e488ca7f1618e6a444b15d5cb698ef540c91 100755 --- a/mindspore/__init__.py +++ b/mindspore/__init__.py @@ -20,6 +20,7 @@ from .common import * from .mindrecord import * from .ops import _op_impl from .train import * +from .rewrite import * from .log import * from .version import __version__ @@ -29,3 +30,4 @@ __all__.extend(__version__) __all__.extend(common.__all__) __all__.extend(train.__all__) __all__.extend(log.__all__) +__all__.extend(rewrite.__all__) diff --git a/mindspore/compression/common/ast_utils.py b/mindspore/compression/common/ast_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8c639f9906988f3d85687c2a1b747f84eb6043 --- /dev/null +++ b/mindspore/compression/common/ast_utils.py @@ -0,0 +1,58 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import ast + +def get_func_node(root_node, func_name): + for node in ast.walk(root_node): + if isinstance(node, ast.FunctionDef) and node.name == func_name: + return node + return None + +class Assign: + def __init__(self, func_name, targets, args, line_num): + self.func_name = func_name + self.targets = targets + self.args = args + self.line_num = line_num + +def resolve_func_assign(assign: ast.Assign): + args = [] + if not isinstance(assign.value, ast.Call): + return None + call = assign.value + for arg in call.args: + if isinstance(arg, ast.Name): + args.append(arg) + func_name = call.func.attr + outputs = [] + for output in assign.targets: + if isinstance(output, ast.Name): + outputs.append(output.id) + return Assign(func_name, outputs, args, assign.lineno) + +def resolve_all_func_assgin_nodes(func_node: ast.FunctionDef): + result = [] + for node in ast.walk(func_node): + if isinstance(node, ast.Assign): + result.append(resolve_func_assign(node)) + return result + +def resolve_func_args(func_node: ast.FunctionDef): + args = [] + for arg in func_node.args.args: + if arg.arg != "self": + args.append(arg.arg) + return args \ No newline at end of file diff --git a/mindspore/compression/common/graph.py b/mindspore/compression/common/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..25d35fc177eb20f5f0cfa698be3862b4fce5803b --- /dev/null +++ b/mindspore/compression/common/graph.py @@ -0,0 +1,117 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Graph.""" +import ast +import dill +import astpretty +from ast_utils import * +from python_node import PythonNode, PythonNodeType + +class Graph: + def __init__(self): + self._nodes = [] + self._ast = None + + # find node with filter. support fuzzy search? + # { + # _name: name + # _type: type + # _shape: shape + # } + def find_node(self, **filter): + return None + + # return inputs of node whose full_name_with_scope is full_name_with_scope + def node_inputs(self, full_name_with_scope): + return [] + + # return outputs of node whose full_name_with_scope is full_name_with_scope + def node_outputs(self, full_name_with_scope): + return [] + + # insert a cell into graph. only support b into a and c: a--c ==> a--b--c + def insert_node(self, cell, input_nodes, output_nodes): + return True + + # remove a node from graph. only support b between a and c: a--b--c ==> a--c + def remove_node(self, full_name_with_scope): + return True + + # remove a pattern from graph. only support b--c between a and d: a--b--c--d ==> a--d + def remove_pattern(self, pattern): + return True + + # replace node whose name is full_name_with_scope with a new cell. + # cell should has same inputs and outputs with node. + def replace_node(self, full_name_with_scope, cell): + return True + + # replace pattern with new cell. + # cell should has same inputs and outputs with pattern. + def replace_pattern(self, src_pattern, cell): + return True + + def convert_to_python_code(self): + return "" + + def deep_copy_graph(self): + return None + + def print(self): + return "" + + @staticmethod + def build_from_cell(cell, args, **kwargs): + str_cell = dill.source.getsource(cell) + ast_root = ast.parse(str_cell) + astpretty.pprint(ast_root) + + ### get init and construct nodes + init_node = get_func_node(ast_root, "__init__") + if init_node is None: + return None + + construct_node = get_func_node(ast_root, "construct") + if construct_node is None: + return None + + ### resolve construct node + func_assign_nodes = resolve_all_func_assgin_nodes(construct_node) + tensors = resolve_func_args(construct_node) + + def build_graph(): + graph = Graph() + tensors_map = {} + tensors_count = 0 + for tensor in tensors: + tensors_map[tensor] = tensors_count + tensors_count += 1 + for node in func_assign_nodes: + inputs = [] + outputs = [] + for input in node.args: + if input.id in tensors_map.keys(): + inputs.append(tensors_map[input.id]) + else: + return None + for output in node.targets: + tensors_map[output] = tensors_count + outputs.append(tensors_count) + tensors_count += 1 + + python_node = PythonNode(PythonNodeType.call_cell, str(node.line_num), "", None, inputs, outputs) + graph._nodes.append(python_node) + return graph + return build_graph() diff --git a/mindspore/compression/common/graph_test.py b/mindspore/compression/common/graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c85855dc2751acf357358c5406ba1431b712dfd0 --- /dev/null +++ b/mindspore/compression/common/graph_test.py @@ -0,0 +1,49 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from graph import Graph +import mindspore.nn as nn + +class LeNet5(nn.Cell): + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Dense(16 * 5 * 5, 120) + self.fc2 = nn.Dense(120, 84) + self.fc3 = nn.Dense(84, 10) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + +def main(): + graph = Graph.build_from_cell(LeNet5, None) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/mindspore/compression/common/python_node.py b/mindspore/compression/common/python_node.py new file mode 100644 index 0000000000000000000000000000000000000000..a14f432b0fef66f47b38e07f5799c47e5745fad3 --- /dev/null +++ b/mindspore/compression/common/python_node.py @@ -0,0 +1,43 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""PythonNode.""" + +from enum import Enum + + +class PythonNodeType(Enum): + unknown = 0 + placeholder = 1 # input + parameter = 2 # weight + value = 3 # const + call_cell = 4 # cell object + call_method = 5 # member of cell + call_function = 6 # Primitive object + output = 7 + + +class PythonNode: + def __init__(self, type=PythonNodeType.unknown, name="", scope="", ast_node=None, inputs=None, outputs=None): + if inputs is None: + inputs = [] + if outputs is None: + outputs = [] + self._type = type + self._name = name + self._scope = scope + self._ast_node = ast_node + self._inputs = inputs + self._outputs = outputs + self._kwargs = {} diff --git a/mindspore/compression/compress_algo.py b/mindspore/compression/compress_algo.py new file mode 100644 index 0000000000000000000000000000000000000000..fe28372a5c4bc3231e4a2edd77e08f9450152dc9 --- /dev/null +++ b/mindspore/compression/compress_algo.py @@ -0,0 +1,25 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.train.model import Model + + +class CompressAlgo: + def __init__(self, config): + self._config = config + + # use ModelTransformer to transform model + # add callback to Model + def apply(self, src_model: Model) -> Model: ... diff --git a/mindspore/compression/model_transform/__init__.py b/mindspore/compression/model_transform/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mindspore/compression/model_transform/ast_utils.py b/mindspore/compression/model_transform/ast_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af498161f586e2c9c2a702590e56c861c5de877a --- /dev/null +++ b/mindspore/compression/model_transform/ast_utils.py @@ -0,0 +1,49 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import ast + +def get_func_node(root_node, func_name): + for node in ast.walk(root_node): + if isinstanceof(node, ast.FunctionDef) and node.name == func_name: + return node + return None + +class Assign: + def __init__(self, func_name, targets, args): + self.__func_name = func_name + self.__targets = targets + self.__args = args + + +def resolve_func_assign(assign: ast.Assign): + args = [] + if not isinstance(assign.value, ast.Call): + return None + call = assign.value + for arg in call.args: + if isinstance(arg, ast.Name): + args.append(arg) + func_name = call.func.attr + outputs = [] + for output in assign.targets: + if isinstance(output, ast.Name): + outputs.append(output.id) + return Assign(func_name, outputs, args) + +def resolve_all_func_assgin_nodes(func_node: ast.FunctionDef): + for node in func_node: + if isinstance(node, ast.Assign): + yield resolve_func_assign(node) \ No newline at end of file diff --git a/mindspore/compression/model_transform/cell_transform.py b/mindspore/compression/model_transform/cell_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..3d68f4aad1990d3a581a09a2c2efd97449041ca4 --- /dev/null +++ b/mindspore/compression/model_transform/cell_transform.py @@ -0,0 +1,115 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.nn.cell import Cell +from mindspore.rewrite.graph import Graph +from mindspore.rewrite.graph import PythonNode +from collections import OrderedDict + + +# use pythonnode to support tree and chain +# target_cls: (matched_nodes: OrderedDict(pattern_name:PythonNode)) -> PythonNode +class PatternHelper(callable): + def __init__(self, pattern: PythonNode = None, target_fn: callable = None): + self._pattern: PythonNode = pattern + self._target_fn = target_fn + + def pattern(self) -> PythonNode: + return self._pattern + + # matched_cells: name_of_cell_in_pattern map to matched cell in network + # by default, we just replace + def __call__(self, matched_nodes: OrderedDict) -> PythonNode: + if self._pattern is None: + return None + return self._target_fn(matched_nodes) + + +# todo +# how to link code and node : mind-converter pytorch-fx : use name +# pattern : distiller mind-converter tf-optimization +# 成员变量定义是一个,但是被多次调用,修改时仅修改一处调用: distiller/distiller/quantization/quantizer.py +211 +# sub-graph flatten : mind-converter +# Model add set_network and add_callback interface +class CellTransformer: + def __init__(self, net: Cell): + self._net = net + self._graph = Graph.build_from_cell(net) + + def get_transformed(self) -> Cell: + return self._graph.convert_to_cell() + + # basic api + def find_node(self, full_name_with_scope: str) -> PythonNode: + return self._graph.find_node(full_name_with_scope) + + # return inputs of node whose full_name_with_scope is full_name_with_scope + def node_inputs(self, full_name_with_scope: str) -> [PythonNode]: + return self._graph.node_inputs(full_name_with_scope) + + # return outputs of node whose full_name_with_scope is full_name_with_scope + def node_outputs(self, full_name_with_scope: str) -> [PythonNode]: + return self._graph.node_outputs(full_name_with_scope) + + # insert a cell into graph. only support b into a and c: a--c ==> a--b--c + # PythonNode include cell, inputs, outputs + # return node been inserted, return None if failed + def insert_node(self, new_node: PythonNode) -> PythonNode: + return self._graph.insert_node(new_node) + + # remove a node from graph. only support b between a and c: a--b--c ==> a--c + # return node been removed, return None if failed + def remove_node(self, full_name_with_scope: str) -> PythonNode: + return self._graph.remove_node(full_name_with_scope) + + def remove_node(self, node: PythonNode) -> PythonNode: + return self._graph.remove_node(node.name()) + + # replace node whose name is full_name_with_scope with a new cell. + # new_node should has same inputs and outputs with node. + # return node been replaced, return None if failed + def replace_node(self, full_name_with_scope: str, new_node: PythonNode) -> PythonNode: + return self._graph.replace_node(full_name_with_scope, new_node) + + def replace_node(self, node: PythonNode, new_node: PythonNode) -> PythonNode: + return self._graph.replace_node(node.name(), new_node) + + # pattern api + + # remove a pattern from graph. only support b--c between a and d: a--b--c--d ==> a--d + # todo pattern use cell_type : distiller/tests/test_quantizer.py +141 + def remove_pattern(self, remove_helper: PatternHelper) -> bool: + pattern = remove_helper.pattern() + # IR match + for node in self._graph.nodes(): + if node.type() == pattern.type(): + matched_dict = OrderedDict({pattern.name(): node}) + new_node = remove_helper(matched_dict) + if new_node is None: + self._graph.remove_node([node, node, node]) # multi or single + return True + + # replace src_pattern with target_nodes. + # target_nodes should has same inputs and outputs with src_pattern. + def replace_pattern(self, replace_helper: PatternHelper) -> bool: + pattern = replace_helper.pattern() + # IR match + for node in self._graph.nodes(): + if node.type() == pattern.type(): + matched_dict = OrderedDict({pattern.name(): node}) + new_node = replace_helper(matched_dict) + if new_node is not None: + self._graph.replace_node(node, new_node) # multi-single or single-single + return True diff --git a/mindspore/compression/common/constant.py b/mindspore/compression/model_transform/constant.py similarity index 100% rename from mindspore/compression/common/constant.py rename to mindspore/compression/model_transform/constant.py diff --git a/mindspore/compression/common/__init__.py b/mindspore/compression/model_transform/pattern.py similarity index 73% rename from mindspore/compression/common/__init__.py rename to mindspore/compression/model_transform/pattern.py index c382f47e87b64a1dcf42b8f1b7cc157d01228b07..de5d4d29699b5f6aa7e5bc06225b7269ac9973dd 100644 --- a/mindspore/compression/common/__init__.py +++ b/mindspore/compression/model_transform/pattern.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -""" -Common module for various compression algorithms, now only including datatype definition for quantization. -""" +"""Pattern.""" + -from .constant import QuantDtype -__all__ = ["QuantDtype"] diff --git a/mindspore/compression/model_transform/pattern_visitor.py b/mindspore/compression/model_transform/pattern_visitor.py new file mode 100644 index 0000000000000000000000000000000000000000..4d950b1086d7155637bcdaef293a4063f187718d --- /dev/null +++ b/mindspore/compression/model_transform/pattern_visitor.py @@ -0,0 +1,41 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""PatternVisitor.""" + +from mindspore.nn.cell import Cell +from mindspore.rewrite.graph import Graph +from mindspore.rewrite.python_node import PythonNode, PythonNodeType + + +class PatternVisitor: + def __init__(self, pattern: Cell): + self._pattern = pattern + + # visit all nodes + def visit(self, graph: Graph): + for node in graph.nodes(): + if node.type() != PythonNodeType.call_cell: + return node + match_nodes = self.match(node) + if len(match_nodes) == 0: + return node + new_nodes = self.process(match_nodes) + graph.replace_node(node.name(), new_nodes[0]) + + + def match(self, cell_node: PythonNode) -> dict: + return {} + + def process(self, matched_nodes: {}) -> [PythonNode]: ... diff --git a/mindspore/compression/pruner_example.py b/mindspore/compression/pruner_example.py new file mode 100644 index 0000000000000000000000000000000000000000..a0fdc2d3d46383d6f96f48f381bdb7652484a2c7 --- /dev/null +++ b/mindspore/compression/pruner_example.py @@ -0,0 +1,92 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from compress_algo import CompressAlgo +from mindspore.train.callback import Callback +from mindspore.nn import * +from mindspore.train.model import Model + + +class PrunerCallback(Callback) : + def __init__(self, begin, end, frequency, target_sparsity): + super(PrunerCallback, self).__init__() + self._begin_step = begin + self._end_step = end + self._frequency = frequency + self._target_sparsity = target_sparsity + self._masks = {} + + # init _masks + def begin(self, run_context): + origin_args = run_context.original_args() + net: Cell = origin_args.network + cells = net.cells() + for cell in cells: + if cell.cell_type is Conv2d: + self._masks[cell.name] = [1, 1, 1 ,1] + + def step_end(self, run_context): + origin_args = run_context.original_args() + cur_step_num = origin_args.cur_step_num + if cur_step_num < self._begin_step or cur_step_num >= self._end_step: + return + cur_step_num_index = cur_step_num - self._begin_step + if cur_step_num_index % self._frequency != 0: + return + # use _mask to update weight + net: Cell = origin_args.network + # compute new _mask + + +# define pruner algo +class PrunerCompressAlgo(CompressAlgo): + def __init__(self, config: {}): + super(PrunerCompressAlgo, self).__init__(config) + self._callback = PrunerCallback(1, 100, 2, 0.8) + + def apply(self, model: Model) -> Model: + model.add_callback(self._callback) + return model + + +class LeNet5(Cell): + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = Conv2d(1, 6, 5) + self.fc1 = Dense(16 * 5 * 5, 120) + self.relu = ReLU() + self.max_pool2d = MaxPool2d(kernel_size=2, stride=2) + self.flatten = Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + return x + + +net = LeNet5() +print("conv type: ", net.conv1.cell_type) +loss = SoftmaxCrossEntropyWithLogits() +optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) +model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=None) +algo = QATCompressAlgo({"bit_num": 8, }) +model_opt = algo.apply(model) +dataset = {} +model_opt.train(2, dataset) diff --git a/mindspore/compression/qat2_example.py b/mindspore/compression/qat2_example.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a4611246abdc052b50034abdea0b0f132d7c30 --- /dev/null +++ b/mindspore/compression/qat2_example.py @@ -0,0 +1,93 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore.nn +from collections import OrderedDict +from compress_algo import CompressAlgo +from .model_transform.cell_transform import CellTransformer +from mindspore.nn import * +from mindspore.train.model import Model +from .model_transform.cell_transform import PatternHelper +from mindspore.rewrite.python_node import PythonNode + + +# define how to create a new Cell +class SimpleQuantCell(mindspore.nn.Cell): + def __init__(self, matched_cells: OrderedDict): + super().__init__() + conv1: Conv2d = matched_cells["conv1"] + conv2: Conv2d = matched_cells["conv2"] + bn: BatchNorm2d = matched_cells["bn"] + self._conv = mindspore.nn.Conv2d(conv1.in_channels, bn.num_features, conv2.out_channels, self._per_channel) + self._fake_quant = mindspore.nn.FakeQuantWithMinMaxObserver() + + def construct(self, x): + x = self._conv(x) + return self._fake_quant(x) + +def CreateSimpleQuantNode(matched_cells: OrderedDict): + cell = SimpleQuantCell(matched_cells) + return PythonNode(cell, inputs, outputs) + + +# define PatternHelper, include pattern, when and how to use pattern +class ConvBnPatternHelper(PatternHelper): + def __init__(self, per_channel): + pattern = PythonNode({'conv1':Conv2d, 'conv2':Conv2d, 'bn':BatchNorm2d}) + super().__init__(pattern, SimpleQuantCell) + self._per_channel = per_channel + + +# define qat algo +class QATCompressAlgo(CompressAlgo): + def __init__(self, config: {}): + super(QATCompressAlgo, self).__init__(config) + self._pattern_helper = ConvBnPatternHelper(config["per_channel"]) + + def apply(self, model: Model) -> Model: + transformer = CellTransformer(model.predict_network) + transformer.replace_pattern(self._pattern_helper) + new_net = transformer.get_transformed() + model.update_network(new_net) + return model + + +class LeNet5(Cell): + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = Conv2d(1, 6, 5) + self.fc1 = Dense(16 * 5 * 5, 120) + self.relu = ReLU() + self.max_pool2d = MaxPool2d(kernel_size=2, stride=2) + self.flatten = Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + return x + + +net = LeNet5() +loss = SoftmaxCrossEntropyWithLogits() +optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) +model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=None) +algo = QATCompressAlgo({"bit_num": 8, "per_channel": True}) +model_opt = algo.apply(model) +dataset = {} +model_opt.train(2, dataset) diff --git a/mindspore/compression/qat_example.py b/mindspore/compression/qat_example.py new file mode 100644 index 0000000000000000000000000000000000000000..e0761919b66145bd49af37fb4ce2d262f275976a --- /dev/null +++ b/mindspore/compression/qat_example.py @@ -0,0 +1,106 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore.nn +from collections import OrderedDict +from compress_algo import CompressAlgo +from .model_transform.cell_transform import CellTransformer +from mindspore.nn import * +from mindspore.train.model import Model +from .model_transform.cell_transform import PatternHelper +from mindspore.rewrite.python_node import PythonNode + +# define how to create a new Cell +class QuantCell(mindspore.nn.Cell): + def __init__(self, ic, oc, ks, per_channel_arg): + super().__init__() + self._conv = mindspore.nn.Conv2d(ic, oc, ks) + self._fake_quant = mindspore.nn.FakeQuantWithMinMaxObserver(per_channel=per_channel_arg) + + def construct(self, x): + x = self._conv(x) + return self._fake_quant(x) + + +# define PatternHelper, include pattern, when and how to use pattern +class ConvBnPatternHelper(PatternHelper): + def __init__(self, per_channel): + # todo pythonnode construct + self._pattern: PythonNode = PythonNode({name:'_conv', + type:Conv2d, + inputs: [ + { + name:'_bn', + type:BatchNorm2d, + } + ], + args:{}}) + self._pattern: PythonNode = PythonNode({'_conv': Conv2d, '_bn': BatchNorm2d}) + super().__init__() + self._per_channel = per_channel + + def __call__(self, matched_cells: OrderedDict): + if self._pattern is None: + return None + old_conv: Conv2d = matched_cells["_conv"] + bn: BatchNorm2d = matched_cells["_bn"] + if old_conv.kernel_size < 500: + return None + quant_cell = QuantCell(old_conv.in_channels, bn.num_features, old_conv.out_channels, self._per_channel) + return PythonNode(quant_cell, inputs, outputs) + + +# define qat algo +class QATCompressAlgo(CompressAlgo): + def __init__(self, config: {}): + super(QATCompressAlgo, self).__init__(config) + self._pattern_helper = ConvBnPatternHelper(config["per_channel"]) + + def apply(self, model: Model) -> Model: + transformer = CellTransformer(model.predict_network) + transformer.replace_pattern(self._pattern_helper) + new_net = transformer.get_transformed() + model.update_network(new_net) + return model + + +class LeNet5(Cell): + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = Conv2d(1, 6, 5) + self.fc1 = Dense(16 * 5 * 5, 120) + self.relu = ReLU() + self.max_pool2d = MaxPool2d(kernel_size=2, stride=2) + self.flatten = Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + return x + + +net = LeNet5() +loss = SoftmaxCrossEntropyWithLogits() +optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) +model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=None) +# todo how to update network to model +algo = QATCompressAlgo({"bit_num": 8, "per_channel": True}) +model_opt = algo.apply(model) +dataset = {} +model_opt.train(2, dataset) diff --git a/mindspore/rewrite/__init__.py b/mindspore/rewrite/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce51dcbeee861ef1467027a5353660f1fd7e1ff --- /dev/null +++ b/mindspore/rewrite/__init__.py @@ -0,0 +1,4 @@ +from mindspore.rewrite.graph import Graph +from mindspore.rewrite.python_node import PythonNode + +__all__ = ["Graph", "PythonNode"] \ No newline at end of file diff --git a/mindspore/rewrite/code_analysis.py b/mindspore/rewrite/code_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..c51420a069f51ca3b6986996c60d9398db31c570 --- /dev/null +++ b/mindspore/rewrite/code_analysis.py @@ -0,0 +1,133 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import ast +from enum import Enum +from typing import Any +from _ast import AST, Assign, FunctionDef, ClassDef, Attribute, Name + + +class ASTType(Enum): + TypeUnknown = 0 + TypeClassDef = 1 + TypeFunctionDef = 2 + TypeAttribute = 3 + TypeName = 4 + TypeCall = 5 + TypeAssign = 6 + + def __new__(cls, value: str): + member = object.__new__(cls) + if value == "ClassDef": + member._value_ = 1 + elif value == "FunctionDef": + member._value_ = 2 + elif value == "Attribute": + member._value_ = 3 + elif value == "Name": + member._value_ = 4 + elif value == "Call": + member._value_ = 5 + elif value == "Assign": + member._value_ = 6 + else: + member._value_ = 0 + return member + + +class Scope: + def __init__(self, scope_type: ASTType, lineno: int): + self._type = scope_type + self._lineno = lineno + self._name = str(scope_type) + "-" + str(lineno) + + def __str__(self): + return self._name + + +class FullScope: + def __init__(self): + self._scopes: [Scope] = [] + + def enter_scope(self, scope: Scope): + print("--------------------------------------Enter scope ", scope) + self._scopes.append(scope) + + def enter_ast_class_scope(self, scope: AST): + self.enter_scope(Scope(ASTType.TypeClassDef, scope.lineno)) + + def enter_ast_function_scope(self, scope: AST): + self.enter_scope(Scope(ASTType.TypeFunctionDef, scope.lineno)) + + def enter_ast_attribute_scope(self, scope: AST): + self.enter_scope(Scope(ASTType.TypeAttribute, scope.lineno)) + + def enter_ast_name_scope(self, scope: AST): + self.enter_scope(Scope(ASTType.TypeName, scope.lineno)) + + def enter_ast_call_scope(self, scope: AST): + self.enter_scope(Scope(ASTType.TypeCall, scope.lineno)) + + def enter_ast_assign_scope(self, scope: AST): + self.enter_scope(Scope(ASTType.TypeAssign, scope.lineno)) + + def enter_ast_scope(self, scope: AST): + pass + self.enter_scope(Scope(ASTType(scope.__class__.__name__), scope.lineno)) + + def exit_scope(self): + print("--------------------------------------Exit scope ", self._scopes.pop()) + + def cur_scope_type(self) -> str: + return self._scopes[len(self._scopes) - 1] + + +class CodeAnalyzer(ast.NodeVisitor): + def __init__(self): + super(CodeAnalyzer, self).__init__() + self._scope: FullScope = FullScope() + + # def visit(self, node: AST) -> Any: + # return super(CodeAnalyzer, self).visit(node) + + def visit_Assign(self, node: Assign) -> Any: + self._scope.enter_ast_scope(node) + print("visit assign : ", node.lineno) + super(CodeAnalyzer, self).generic_visit(node) + self._scope.exit_scope() + + def visit_Attribute(self, node: Attribute) -> Any: + self._scope.enter_ast_scope(node) + print("visit attribute : ", self._scope.cur_scope_type(), node.attr) + super(CodeAnalyzer, self).generic_visit(node) + self._scope.exit_scope() + + def visit_Name(self, node: Name) -> Any: + self._scope.enter_ast_scope(node) + print("visit name : ", node.__class__) + super(CodeAnalyzer, self).generic_visit(node) + self._scope.exit_scope() + + def visit_FunctionDef(self, node: FunctionDef) -> Any: + self._scope.enter_ast_scope(node) + print("visit function : ", node.name) + super(CodeAnalyzer, self).generic_visit(node) + self._scope.exit_scope() + + def visit_ClassDef(self, node: ClassDef) -> Any: + self._scope.enter_ast_scope(node) + print("visit class : ", node.name) + super(CodeAnalyzer, self).generic_visit(node) + self._scope.exit_scope() diff --git a/mindspore/rewrite/graph.py b/mindspore/rewrite/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..b26736376025f6381891bacbed41aa7958536265 --- /dev/null +++ b/mindspore/rewrite/graph.py @@ -0,0 +1,73 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Graph.""" + +import dill +import ast +import astpretty +from python_node import PythonNode +from mindspore.nn.cell import Cell +from code_analysis import CodeAnalyzer + + +class Graph: + def __init__(self): + self._nodes: [Cell] = [] + self._source = "" + self._ast: ast = None + + # basic api + + def find_node(self, full_name_with_scope: str) -> Cell: ... + + # return inputs of node whose full_name_with_scope is full_name_with_scope + def node_inputs(self, full_name_with_scope: str) -> [Cell]: ... + + # return outputs of node whose full_name_with_scope is full_name_with_scope + def node_outputs(self, full_name_with_scope: str) -> [Cell]: ... + + # todo + # insert a cell into graph. only support b into a and c: a--c ==> a--b--c + # return node been inserted, return None if failed + def insert_node(self, new_node: Cell, input_nodes: [Cell], output_nodes: [Cell]) -> Cell: ... + + # remove a node from graph. only support b between a and c: a--b--c ==> a--c + # return node been removed, return None if failed + def remove_node(self, full_name_with_scope: str) -> Cell: ... + + # replace node whose name is full_name_with_scope with a new cell. + # new_node should has same inputs and outputs with node. + # return node been replaced, return None if failed + def replace_node(self, full_name_with_scope: str, new_node: Cell) -> Cell: ... + + # other api + def convert_to_cell(self) -> Cell: ... + + def deep_copy_graph(self): ... + + def print(self) -> str: ... + + def nodes(self) -> [PythonNode]: + return self._nodes + + @staticmethod + def build_from_cell(cell, args=None, **kwargs): + graph = Graph() + graph._source = dill.source.getsource(cell) + graph._ast = ast.parse(graph._source) + # code_analyzer = CodeAnalyzer() + # code_analyzer.visit(graph._ast) + astpretty.pprint(graph._ast) + return graph diff --git a/mindspore/rewrite/python_node.py b/mindspore/rewrite/python_node.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2c0cbbc0ba74a67fa3827c633d7d22b3c6d8e1 --- /dev/null +++ b/mindspore/rewrite/python_node.py @@ -0,0 +1,57 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""PythonNode.""" + +from enum import Enum +from mindspore.nn.cell import Cell + + +class PythonNodeType(Enum): + unknown = 0 + placeholder = 1 # input + parameter = 2 # weight + value = 3 # const + call_cell = 4 # cell object + call_method = 5 # member of cell + call_function = 6 # primitive object + graph = 7 # sub-graph + output = 8 + + +class PythonNode: + def __init__(self, cell: Cell = None, node_type=PythonNodeType.unknown, name="", scope="", ast_node=None, + inputs: [] = None, + outputs: [] = None): + if inputs is None: + inputs = [] + if outputs is None: + outputs = [] + self._type = node_type + self._name = name + self._scope = scope + self._ast_node = ast_node + self._inputs = inputs + self._outputs = outputs + self._kwargs = {} + self._cell: Cell = cell + + def cell(self) -> Cell: + return self._cell + + def type(self) -> PythonNodeType: + return self._type + + def name(self) -> str: + return self._name