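# NOTE: The next three lines are residue from the code-hosting web page, not part of the module.
# Code pull complete; the page will refresh automatically.
# The sync operation will force-sync from 丁力/AutoOptTool; it overwrites every change made since the repository was forked and cannot be undone!
# After you confirm, the sync runs in the background and the page refreshes when it finishes; please wait patiently.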
import random
import unittest
import numpy as np
import onnx
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE, NP_TYPE_TO_TENSOR_TYPE
from onnxsim import simplify
from auto_opt_tool.onnx.graph import OnnxGraph
from auto_opt_tool.onnx.node import PlaceHolder, Node
class TestGraph(unittest.TestCase):
    """Tests for building, querying and editing an ``OnnxGraph``.

    Every test starts from the fixture returned by :meth:`create_model`.

    NOTE: ``toposort_test`` and ``delete_test`` do not start with ``test``,
    so the default ``unittest`` loader does not collect them; they only run
    when invoked explicitly. Their names are kept unchanged so the set of
    automatically executed tests stays the same.
    """

    def create_model(self):
        """Build the six-node graph shared by all tests.

        Wiring (tensor names as passed to ``Node``)::

            input_0 -> Node_0 (Sub) -> 0_out_0, 0_out_1
            0_out_0 -> Node_1 (Mul) -> 1_out_0
            0_out_0, 0_out_1 -> Node_2 (Add) -> 2_out_0, 2_out_1
            1_out_0 -> Node_3 (Sub) -> 3_out_0
            1_out_0, 2_out_0 -> Node_4 (Add) -> 4_out_0
            3_out_0, 4_out_0, 2_out_1 -> Node_5 (Mul) -> output_0

        Returns:
            A new ``OnnxGraph`` with one float32 input and one float32
            output, both of shape ``[1, 3, 224, 224]``.
        """
        input_0 = PlaceHolder('input_0', np.dtype('float32'), [1, 3, 224, 224])
        output_0 = PlaceHolder('output_0', np.dtype('float32'), [1, 3, 224, 224])

        node_0 = Node('Node_0', 'Sub', inputs=['input_0'],
                      outputs=['0_out_0', '0_out_1'], attrs={})
        node_1 = Node('Node_1', 'Mul', inputs=['0_out_0'],
                      outputs=['1_out_0'], attrs={})
        node_2 = Node('Node_2', 'Add', inputs=['0_out_0', '0_out_1'],
                      outputs=['2_out_0', '2_out_1'], attrs={})
        node_3 = Node('Node_3', 'Sub', inputs=['1_out_0'],
                      outputs=['3_out_0'], attrs={})
        node_4 = Node('Node_4', 'Add', inputs=['1_out_0', '2_out_0'],
                      outputs=['4_out_0'], attrs={})
        node_5 = Node('Node_5', 'Mul', inputs=['3_out_0', '4_out_0', '2_out_1'],
                      outputs=['output_0'], attrs={})

        graph = OnnxGraph(
            [node_0, node_1, node_2, node_3, node_4, node_5],
            [input_0],
            [output_0],
        )
        # Debug aid: uncomment to dump the fixture for manual inspection.
        # graph.save('../new_graph.onnx')
        return graph

    def toposort_test(self):
        """Shuffle the node list, then save, infer shapes and re-sort it.

        This is a smoke check: it makes no assertions and only passes if
        none of the graph calls raises.
        """
        # Fix: the graph was previously never created here, so the first
        # statement raised NameError on the undefined name ``graph``.
        graph = self.create_model()
        random.shuffle(graph._nodes)
        graph.save('../new_test.onnx')
        graph.infershape()
        graph.toposort()
        graph.infershape()
        # Debug aid: round-trip through onnx-simplifier for manual inspection.
        # onnx_model = graph.to_model()
        # simplified_model, _ = simplify(onnx_model)
        # onnx.save(simplified_model, '../simplified_model.onnx')

    def delete_test(self):
        """Remove ``Node_4`` and check the remaining nodes keep their order."""
        graph = self.create_model()
        # NOTE(review): ``{0: 0}`` presumably maps an output index to the
        # input index it is reconnected to - confirm against
        # ``OnnxGraph.remove``.
        graph.remove('Node_4', {0: 0})
        self.assertEqual(
            graph.nodes,
            [graph['Node_0'], graph['Node_1'], graph['Node_2'],
             graph['Node_3'], graph['Node_5']],
        )

    def test_add_input(self):
        """A placeholder added without a kind argument is appended to ``graph.inputs``."""
        graph = self.create_model()
        test_input = graph.add_placeholder('test_input', 'float32', [1, 3, 224, 224])
        self.assertEqual(graph['test_input'], test_input)
        self.assertEqual(graph.inputs, [graph['input_0'], test_input])

    def test_add_output(self):
        """A placeholder added with ``'output'`` is appended to ``graph.outputs``."""
        graph = self.create_model()
        test_output = graph.add_placeholder(
            'test_output', 'float32', [1, 3, 224, 224], 'output')
        self.assertEqual(graph['test_output'], test_output)
        self.assertEqual(graph.outputs, [graph['output_0'], test_output])

    def test_add_initializer(self):
        """A new initializer is retrievable by name and listed in ``graph.initializers``."""
        graph = self.create_model()
        test_ini = graph.add_initializer('test_ini', 'float32')
        self.assertEqual(graph['test_ini'], test_ini)
        self.assertEqual(graph.initializers, [test_ini])

    def test_add_node(self):
        """A new node is retrievable by name and appended after the existing nodes."""
        graph = self.create_model()
        test_node = graph.add_node('test_node', 'Add')
        self.assertEqual(graph['test_node'], test_node)
        self.assertEqual(
            graph.nodes,
            [graph['Node_0'], graph['Node_1'], graph['Node_2'],
             graph['Node_3'], graph['Node_4'], graph['Node_5'], test_node],
        )

    def test_get_nodes(self):
        """``get_nodes`` returns the nodes of one op type in graph order."""
        graph = self.create_model()
        self.assertEqual(graph.get_nodes('Mul'), [graph['Node_1'], graph['Node_5']])
        self.assertEqual(graph.get_nodes('Sub'), [graph['Node_0'], graph['Node_3']])
        self.assertEqual(graph.get_nodes('Add'), [graph['Node_2'], graph['Node_4']])

    def test_insert_node_before(self):
        """Insert a node in front of input 0 of ``Node_4``.

        The new node must consume ``1_out_0`` (the tensor ``Node_4`` used to
        read) and feed ``Node_4`` through the new tensor
        ``test_node/Node_4``.
        """
        graph = self.create_model()
        test_node = graph.add_node('test_node', 'Add')
        graph.insert_node('Node_4', test_node, 0, 'before')

        self.assertEqual(test_node.inputs, ['1_out_0'])
        self.assertEqual(test_node.outputs, ['test_node/Node_4'])
        # ``1_out_0`` is still produced by Node_1 but now feeds the new node
        # in place of Node_4.
        self.assertEqual(graph.get_next_nodes('1_out_0'), [graph['Node_3'], test_node])
        self.assertEqual(graph.get_prev_node('1_out_0'), graph['Node_1'])
        self.assertEqual(graph.get_next_nodes('test_node/Node_4'), [graph['Node_4']])
        self.assertEqual(graph.get_prev_node('test_node/Node_4'), test_node)

    def test_insert_node_after(self):
        """Insert a node behind output 0 of ``Node_0``.

        The new node must take over producing ``0_out_0`` and read from the
        new tensor ``Node_0/test_node`` that ``Node_0`` now emits.
        """
        graph = self.create_model()
        test_node = graph.add_node('test_node', 'Add')
        graph.insert_node('Node_0', test_node, 0, 'after')

        self.assertEqual(test_node.inputs, ['Node_0/test_node'])
        self.assertEqual(test_node.outputs, ['0_out_0'])
        self.assertEqual(graph.get_next_nodes('Node_0/test_node'), [test_node])
        self.assertEqual(graph.get_prev_node('Node_0/test_node'), graph['Node_0'])
        # Downstream consumers of ``0_out_0`` are unchanged; only its
        # producer is.
        self.assertEqual(graph.get_next_nodes('0_out_0'), [graph['Node_1'], graph['Node_2']])
        self.assertEqual(graph.get_prev_node('0_out_0'), test_node)
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
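# NOTE: The next two lines are residue from the code-hosting web page, not part of the module.
# This section may contain content unsuitable for display, so the page does not show it. You can review and modify it with the editing functions.
# If you confirm the content involves no inappropriate language, pure advertising, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or anything violating national laws and regulations, you can submit an appeal and it will be handled as soon as possible.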