代码拉取完成,页面将自动刷新
同步操作将从 ahqzy/onnx_convert 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import onnx
import sys
import values
import numpy as np
import log
logger = log.getLogger(__name__, log.INFO)
def merge_hard_sigmiod(model):
dict_add = {}
dict_clip = {}
dict_mul = {}
dict_div = {}
got_swish = False
search = True
while search == True:
search = False
for node_id, node in enumerate(model.graph.node):
#print(node_id, ", name:", node.name, ", input:", node.input, ", output:", node.output, \
# ", op:", node.op_type, ', len(input):', len(node.input))
found_add = False
if node.op_type == 'Add':
addB = values.get_init_value(model, node.input[1])
logger.debug('addB: {}'.format(addB))
if isinstance(addB, list) and addB == []:
logger.debug('addB is not in initilizer')
#continue
addB = values.get_constant_value(model, node.input[1])
if addB == []:
logger.debug('addB is not in constant node list')
continue
else:
logger.debug('addB is {} {}'.format(addB, type(addB)))
if addB[0] != 3:
logger.debug('this is not the add-node which we wanted(value B is not 3)...')
continue
if isinstance(addB, np.ndarray) == True:
if addB.shape != (1, ):
logger.debug('this is not the add-node which we wanted(shape is wrong)...')
continue
else:
if len(addB) != 1:
logger.debug('this is not the add-node which we wanted(list len is wrong)...')
continue
dict_add['input'] = node.input
dict_add['output'] = node.output
dict_add['id'] = node_id
logger.debug('got match add node: {}'.format(node.name))
if node.op_type == 'Clip':
if dict_add and node.input[0] == dict_add['output'][0] and len(node.input) >= 3:
clip_min = values.get_init_value(model, node.input[1])
if isinstance(clip_min, list) and clip_min == []:
logger.debug('clip_min is not in initilizer')
clip_min = values.get_constant_value(model, node.input[1])
if clip_min == []:
dict_add = {}
logger.debug('clip_min is not in constant node list~')
continue
logger.debug('clip_min: {}'.format(clip_min))
clip_max = values.get_init_value(model, node.input[2])
if isinstance(clip_max, list) and clip_max == []:
logger.debug('clip_max is not in initilizer')
clip_max = values.get_constant_value(model, node.input[2])
if clip_max == []:
dict_add = {}
logger.debug('clip_max is not in constant node list~')
continue
logger.debug('clip_max: {}'.format(clip_max))
if clip_min[0] != 0:
logger.debug('this is not the clip-node which we wanted(min is not 0)...')
dict_add = {}
continue
if isinstance(clip_min, np.ndarray) == True:
if clip_min.shape != (1, ):
logger.debug('this is not the clip-node which we wanted(shape is wrong)...')
dict_add = {}
continue
else:
if len(clip_min) != 1:
logger.debug('this is not the clip-node which we wanted(list len is wrong)...')
dict_add = {}
continue
if clip_max[0] != 6:
logger.debug('this is not the clip-node which we wanted(max is not 6)...')
continue
if isinstance(clip_max, np.ndarray) == True:
if clip_max.shape != (1, ):
logger.debug('this is not the clip-node which we wanted(shape is wrong)...')
dict_add = {}
continue
else:
if len(clip_max) != 1:
logger.debug('this is not the clip-node which we wanted(list len is wrong)...')
dict_add = {}
continue
dict_clip['input'] = node.input
dict_clip['output'] = node.output
dict_clip['id'] = node_id
logger.debug('got first pair: {} {}'.format(dict_clip['input'], dict_clip['output']))
else:
logger.debug('clear dict_add: {}'.format(dict_add))
dict_add = {}
if node.op_type == 'Div':
if dict_clip and node.input[0] == dict_clip['output'][0]:
dict_div['input'] = node.input
dict_div['output'] = node.output
dict_div['id'] = node_id
divB = values.get_init_value(model, node.input[1])
if isinstance(divB, list) and divB == []:
logger.debug('divB is not in initilizer')
divB = values.get_constant_value(model, node.input[1])
if divB == []:
dict_add = {}
dict_clip = {}
dict_div = {}
logger.debug('divB is not in constant node list~')
continue
logger.debug('divB: {}'.format(divB))
if divB[0] != 6:
logger.debug('this is not the div-node which we wanted(value B is not 6)...')
dict_add = {}
dict_clip = {}
dict_div = {}
continue
if isinstance(divB, np.ndarray) == True:
if divB.shape != (1, ):
logger.debug('this is not the div-node which we wanted(shape is wrong)...')
dict_add = {}
dict_clip = {}
dict_div = {}
continue
else:
if len(divB) != 1:
logger.debug('this is not the div-node which we wanted(list len is wrong)...')
dict_add = {}
dict_clip = {}
dict_div = {}
continue
###################################
old_node = model.graph.node[dict_add['id']]
model.graph.node.remove(old_node)
swish_node = onnx.helper.make_node(
name = '',
op_type='HardSigmoid',
inputs=[dict_add['input'][0]],
outputs=dict_div['output'],
alpha=1.0/6,
beta=0.5
)
model.graph.node.insert(dict_add['id'], swish_node)
old_node = model.graph.node[dict_div['id']]
model.graph.node.remove(old_node)
old_node = model.graph.node[dict_clip['id']]
model.graph.node.remove(old_node)
dict_add = {}
dict_clip = {}
###############################
got_swish = True
search = True
break
else:
logger.debug('clear dict_add and dict_clip, ')
logger.debug('dict_add: {}'.format(dict_add))
logger.debug('dict_clip: {}'.format(dict_clip))
dict_add = {}
dict_clip = {}
return model
'''
if __name__ == "__main__":
model = onnx.load('/home/zqiu/models/ghostnet.onnx')
merge_hard_sigmiod(model)
onnx.save(model, './hs.onnx')
'''
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。