1 Star 0 Fork 1

杨文琦/AutoOptTool

forked from 丁力/AutoOptTool 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
test_graph.py 4.69 KB
一键复制 编辑 原始数据 按行查看 历史
丁力 提交于 2022-08-15 00:04 . add node ut
import random
import unittest
import numpy as np
import onnx
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE, NP_TYPE_TO_TENSOR_TYPE
from onnxsim import simplify
from auto_opt_tool.onnx.graph import OnnxGraph
from auto_opt_tool.onnx.node import PlaceHolder, Node
class TestGraph(unittest.TestCase):
    """Unit tests for ``OnnxGraph`` construction and editing helpers.

    Every test builds a fresh graph via :meth:`create_model`, so tests do not
    share mutable graph state.
    """

    def create_model(self):
        """Build the small six-node graph shared by the tests.

        Topology (consumed tensors -> node (op type) -> produced tensors)::

            input_0                   -> Node_0 (Sub) -> 0_out_0, 0_out_1
            0_out_0                   -> Node_1 (Mul) -> 1_out_0
            0_out_0, 0_out_1          -> Node_2 (Add) -> 2_out_0, 2_out_1
            1_out_0                   -> Node_3 (Sub) -> 3_out_0
            1_out_0, 2_out_0          -> Node_4 (Add) -> 4_out_0
            3_out_0, 4_out_0, 2_out_1 -> Node_5 (Mul) -> output_0

        Returns:
            OnnxGraph: nodes listed in the order Node_0..Node_5, with one
            ``float32`` input ``input_0`` and one ``float32`` output
            ``output_0``, both declared with shape ``[1, 3, 224, 224]``.
        """
        input_0 = PlaceHolder('input_0', np.dtype('float32'), [1, 3, 224, 224])
        output_0 = PlaceHolder('output_0', np.dtype('float32'), [1, 3, 224, 224])
        node_0 = Node('Node_0', 'Sub', inputs=['input_0'], outputs=['0_out_0', '0_out_1'], attrs={})
        node_1 = Node('Node_1', 'Mul', inputs=['0_out_0'], outputs=['1_out_0'], attrs={})
        node_2 = Node('Node_2', 'Add', inputs=['0_out_0', '0_out_1'], outputs=['2_out_0', '2_out_1'], attrs={})
        node_3 = Node('Node_3', 'Sub', inputs=['1_out_0'], outputs=['3_out_0'], attrs={})
        node_4 = Node('Node_4', 'Add', inputs=['1_out_0', '2_out_0'], outputs=['4_out_0'], attrs={})
        node_5 = Node('Node_5', 'Mul', inputs=['3_out_0', '4_out_0', '2_out_1'], outputs=['output_0'], attrs={})
        graph = OnnxGraph(
            [node_0, node_1, node_2, node_3, node_4, node_5],
            [input_0],
            [output_0],
        )
        return graph

    def toposort_test(self):
        """Check that ``toposort`` restores a dependency-respecting node order.

        The node list is shuffled (with a fixed seed, for reproducibility),
        then saved, shape-inferred, topologically sorted and shape-inferred
        again. The test then asserts that the same nodes are still present.
        It also asserts that every tensor produced by some node is produced
        before any node consumes it.

        NOTE: the name does not start with ``test``, so the default unittest
        loader does not collect this method; it only runs when called
        explicitly.
        """
        import os
        import tempfile

        # The original referenced an undefined ``graph`` here (NameError).
        graph = self.create_model()
        original_nodes = list(graph.nodes)

        # Seeded local RNG: deterministic order, global ``random`` untouched.
        random.Random(0).shuffle(graph._nodes)

        # Save into a throwaway directory instead of writing '../new_test.onnx'
        # outside the working tree.
        with tempfile.TemporaryDirectory() as tmp_dir:
            graph.save(os.path.join(tmp_dir, 'new_test.onnx'))

        graph.infershape()
        graph.toposort()
        graph.infershape()

        self.assertCountEqual(graph.nodes, original_nodes)

        produced_by_nodes = {name for node in graph.nodes for name in node.outputs}
        available = set()
        for node in graph.nodes:
            for name in node.inputs:
                # Tensors no node produces (e.g. graph inputs) need no producer.
                if name in produced_by_nodes:
                    self.assertIn(name, available, f'{name!r} consumed before it is produced')
            available.update(node.outputs)

    def delete_test(self):
        """Removing ``Node_4`` leaves the other five nodes in their original order.

        NOTE: the name does not start with ``test``, so the default unittest
        loader does not collect this method; it only runs when called
        explicitly.
        """
        graph = self.create_model()
        graph.remove('Node_4', {0: 0})
        self.assertEqual(
            graph.nodes,
            [graph['Node_0'], graph['Node_1'], graph['Node_2'], graph['Node_3'], graph['Node_5']],
        )

    def test_add_input(self):
        """A placeholder added with the default kind is appended to ``graph.inputs``."""
        graph = self.create_model()
        test_input = graph.add_placeholder('test_input', 'float32', [1, 3, 224, 224])
        self.assertEqual(graph['test_input'], test_input)
        self.assertEqual(graph.inputs, [graph['input_0'], test_input])

    def test_add_output(self):
        """A placeholder added with kind ``'output'`` is appended to ``graph.outputs``."""
        graph = self.create_model()
        test_output = graph.add_placeholder('test_output', 'float32', [1, 3, 224, 224], 'output')
        self.assertEqual(graph['test_output'], test_output)
        self.assertEqual(graph.outputs, [graph['output_0'], test_output])

    def test_add_initializer(self):
        """An added initializer is retrievable by name and is the only initializer."""
        graph = self.create_model()
        test_ini = graph.add_initializer('test_ini', 'float32')
        self.assertEqual(graph['test_ini'], test_ini)
        self.assertEqual(graph.initializers, [test_ini])

    def test_add_node(self):
        """An added node is retrievable by name and appended after the existing nodes."""
        graph = self.create_model()
        test_node = graph.add_node('test_node', 'Add')
        self.assertEqual(graph['test_node'], test_node)
        self.assertEqual(
            graph.nodes,
            [
                graph['Node_0'], graph['Node_1'], graph['Node_2'],
                graph['Node_3'], graph['Node_4'], graph['Node_5'], test_node,
            ],
        )

    def test_get_nodes(self):
        """``get_nodes(op_type)`` returns the nodes of that op type in graph order."""
        graph = self.create_model()
        self.assertEqual(graph.get_nodes('Mul'), [graph['Node_1'], graph['Node_5']])
        self.assertEqual(graph.get_nodes('Sub'), [graph['Node_0'], graph['Node_3']])
        self.assertEqual(graph.get_nodes('Add'), [graph['Node_2'], graph['Node_4']])

    def test_insert_node_before(self):
        """Inserting before input 0 of ``Node_4`` splices a node onto edge ``1_out_0``.

        The inserted node consumes ``1_out_0`` and produces a new tensor
        ``test_node/Node_4``, which ``Node_4`` then consumes.
        """
        graph = self.create_model()
        test_node = graph.add_node('test_node', 'Add')
        graph.insert_node('Node_4', test_node, 0, 'before')
        self.assertEqual(test_node.inputs, ['1_out_0'])
        self.assertEqual(test_node.outputs, ['test_node/Node_4'])
        self.assertEqual(graph.get_next_nodes('1_out_0'), [graph['Node_3'], test_node])
        self.assertEqual(graph.get_prev_node('1_out_0'), graph['Node_1'])
        self.assertEqual(graph.get_next_nodes('test_node/Node_4'), [graph['Node_4']])
        self.assertEqual(graph.get_prev_node('test_node/Node_4'), test_node)

    def test_insert_node_after(self):
        """Inserting after output 0 of ``Node_0`` splices a node onto edge ``0_out_0``.

        ``Node_0`` now produces a new tensor ``Node_0/test_node``. The inserted
        node consumes it and takes over producing ``0_out_0`` for the existing
        consumers.
        """
        graph = self.create_model()
        test_node = graph.add_node('test_node', 'Add')
        graph.insert_node('Node_0', test_node, 0, 'after')
        self.assertEqual(test_node.inputs, ['Node_0/test_node'])
        self.assertEqual(test_node.outputs, ['0_out_0'])
        self.assertEqual(graph.get_next_nodes('Node_0/test_node'), [test_node])
        self.assertEqual(graph.get_prev_node('Node_0/test_node'), graph['Node_0'])
        self.assertEqual(graph.get_next_nodes('0_out_0'), [graph['Node_1'], graph['Node_2']])
        self.assertEqual(graph.get_prev_node('0_out_0'), test_node)
if __name__ == "__main__":
    # Run this module's test_* methods through the unittest command-line runner.
    unittest.main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/wenqi_yang/auto-opt-tool.git
git@gitee.com:wenqi_yang/auto-opt-tool.git
wenqi_yang
auto-opt-tool
AutoOptTool
master

搜索帮助