1 Star 0 Fork 6

littlezheng/onnx_convert

forked from ahqzy/onnx_convert 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
ln_convert.py 46.85 KB
一键复制 编辑 原始数据 按行查看 历史
ahqzy 提交于 2023-04-03 16:33 . add log
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925
import onnx
import sys
import argparse
import values, sys
import numpy as np
from onnx import onnx_pb as onnx_proto
import log
logger = log.getLogger(__name__, log.INFO)
def make_fp32_tensor_from_fp16(fp16_tensor, fp32_tensor_name):
float_list = []
if fp16_tensor.int32_data:
logger.debug('make_fp32_tensor_from_fp16, int32_data')
num = np.array(fp16_tensor.int32_data)
float_list = num.astype(np.float_).tolist()
if fp16_tensor.raw_data:
logger.debug('make_fp32_tensor_from_fp16, raw_data')
float_list = np.fromstring(fp16_tensor.raw_data, dtype='float16')
num = np.array(float_list)
float_list = num.astype(np.float_).tolist()
t = onnx.helper.make_tensor(name=fp32_tensor_name,
data_type=onnx_proto.TensorProto.FLOAT,
dims=fp16_tensor.dims,
vals=float_list)
return t
def is_unused_init(model, init):
for node in model.graph.node:
if init.name in node.input:
return False
return True
def remove_unused_initializer(model, unused_init_list):
for init in unused_init_list:
if is_unused_init(model, init):
logger.debug('remove unused init: {}'.format(init.name))
model.graph.initializer.remove(init)
def remove_invalid_sub_node(model):
invalid_sub_node_list = []
for node in model.graph.node:
if node.op_type == 'Sub':
used = False
sub_output = node.output[0]
for n in model.graph.node:
if sub_output in n.input:
used = True
break
if used == False:
invalid_sub_node_list.append(node)
for node in invalid_sub_node_list:
model.graph.node.remove(node)
# pattern 1:
# --- --- --- --- --- --- --- --
# | |
# ***(0) --- ReduceMean(1) --- Sub(2) --- Pow(3) --- ReduceMean(4) --- Add(5) --- Sqrt(6) --- Div(7) --- Mul(8) --- (Add)(9)
# | |
# --- --- --- ---
class MergeLNPattern1():
def __init__(self, model):
logger.debug('MergeLNPattern1 Init')
self.model = model
self.got_ln = False
self.search = True
self.unused_init_list = []
self.loop = 0
self.dict_rm = {}
self.dict_sub = {}
self.dict_pow = {}
self.dict_rm2 = {}
self.dict_add = {}
self.dict_sqrt = {}
self.dict_div = {}
self.dict_mul = {}
self.dict_add2 = {}
self.rm1_axes = [-1]
self.rm2_axes = [-1]
self.ready = False
self.ready2 = False
def clear(self):
self.dict_rm = {}
self.dict_sub = {}
self.dict_pow = {}
self.dict_rm2 = {}
self.dict_add = {}
self.dict_sqrt = {}
self.dict_div = {}
self.dict_mul = {}
self.dict_add2 = {}
self.rm1_axes = [-1]
self.rm2_axes = [-1]
self.ready = False
self.ready2 = False
self.search = False
def merge(self):
while self.search == True:
self.clear()
self.loop = self.loop + 1
for node_id, node in enumerate(self.model.graph.node):
self.loop = self.loop + 1
#print(node_id, ", name:", node.name, ", input:", node.input, ", output:", node.output, \
# ", op:", node.op_type, ', len(input):', len(node.input))
if node.op_type == 'ReduceMean':
if self.ready == True:
if node.input == self.dict_pow['output']:
self.dict_rm2['input'] = node.input
self.dict_rm2['output'] = node.output
self.dict_rm2['id'] = node_id
attributes = node.attribute
for attr in attributes:
if attr.name == 'axes':
self.rm2_axes = attr.ints
logger.debug('self.rm2_axes: {}'.format(self.rm2_axes))
if len(self.rm2_axes) != 1 or self.rm2_axes != self.rm1_axes:
logger.debug('--This ReduceMean IsNot we are looking for...')
self.clear()
else:
logger.debug('self.clear ReduceMean Sub Pow')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
logger.debug('self.dict_pow: {}'.format(self.dict_pow))
self.clear()
else:
self.dict_rm['input'] = node.input
self.dict_rm['output'] = node.output
self.dict_rm['id'] = node_id
attributes = node.attribute
for attr in attributes:
if attr.name == 'axes':
self.rm1_axes = attr.ints
logger.debug('self.rm1_axes: {}'.format(self.rm1_axes))
if len(self.rm1_axes) != 1:
logger.debug('This ReduceMean IsNot we are looking for...')
self.clear()
break
if node.op_type == 'Sub':
if self.dict_rm and node.input[0] == self.dict_rm['input'][0] and node.input[1] == self.dict_rm['output'][0]:
self.dict_sub['input'] = node.input
self.dict_sub['output'] = node.output
self.dict_sub['id'] = node_id
logger.debug('got first pair: {} {}'.format(self.dict_sub['input'], self.dict_sub['output']))
else:
logger.debug('self.clear ReduceMean, self.dict_rm: {}'.format(self.dict_rm))
self.clear()
if node.op_type == 'Pow':
if self.dict_rm and self.dict_sub and node.input[0] == self.dict_sub['output'][0]:
v = values.get_init_value(self.model, node.input[1])
logger.debug('pow exp: {} {}'.format(v, type(v)))
if v == 2:
self.ready = True
self.dict_pow['input'] = node.input
self.dict_pow['output'] = node.output
self.dict_pow['id'] = node_id
logger.debug('got second pair: {} {}'.format(self.dict_pow['input'], self.dict_pow['output']))
else:
logger.debug('--self.clear ReduceMean and Sub')
logger.debug('--self.dict_rm: {}'.format(self.dict_rm))
logger.debug('--self.dict_sub: {}'.format(self.dict_sub))
self.clear()
else:
logger.debug('self.clear ReduceMean and Sub')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
self.clear()
if node.op_type == 'Add':
if self.ready == True and self.ready2 == True:
if node.input[0] == self.dict_mul['output'][0]:
self.dict_add2['input'] = node.input
self.dict_add2['output'] = node.output
self.dict_add2['id'] = node_id
self.search = True
self.got_ln = True
logger.debug('Got a LayerNorm op')
###
rm_node = self.model.graph.node[self.dict_rm['id']]
sub_node = self.model.graph.node[self.dict_sub['id']]
pow_node = self.model.graph.node[self.dict_pow['id']]
rm2_node = self.model.graph.node[self.dict_rm2['id']]
add_node = self.model.graph.node[self.dict_add['id']]
sqrt_node = self.model.graph.node[self.dict_sqrt['id']]
div_node = self.model.graph.node[self.dict_div['id']]
mul_node = self.model.graph.node[self.dict_mul['id']]
add2_node = self.model.graph.node[self.dict_add2['id']]
self.model.graph.node.remove(rm_node)
scale_name = self.dict_mul['input'][1]
beta_name = self.dict_add2['input'][1]
'''
for init in self.model.graph.initializer:
if init.name == self.dict_mul['input'][1]:
if init.data_type == onnx_proto.TensorProto.FLOAT16:
scale_name = scale_name + '_to_fp32__'
t = make_fp32_tensor_from_fp16(init, scale_name)
self.model.graph.initializer.append(t)
if init not in self.unused_init_list:
self.unused_init_list.append(init)
elif init.name == self.dict_add2['input'][1]:
if init.data_type == onnx_proto.TensorProto.FLOAT16:
beta_name = beta_name + '_to_fp32__'
t = make_fp32_tensor_from_fp16(init, beta_name)
self.model.graph.initializer.append(t)
if init not in self.unused_init_list:
self.unused_init_list.append(init)
'''
ln_node = onnx.helper.make_node(
name = node.name + '_to_layernorm_' + str(self.loop),
op_type='LayerNorm',
inputs=[self.dict_rm['input'][0], scale_name, beta_name],
outputs=self.dict_add2['output'],
axis=self.rm1_axes[0],
epsilon=1e-5,
stash_type=0,
domain='com.metax-tech'
)
self.model.graph.node.insert(self.dict_rm['id'], ln_node)
self.model.graph.node.remove(sub_node)
self.model.graph.node.remove(pow_node)
self.model.graph.node.remove(rm2_node)
self.model.graph.node.remove(add_node)
self.model.graph.node.remove(sqrt_node)
self.model.graph.node.remove(div_node)
self.model.graph.node.remove(mul_node)
self.model.graph.node.remove(add2_node)
break
else:
logger.debug('--self.clear ReduceMean Sub Pow ReduceMean2')
logger.debug('--self.dict_rm: {}'.format(self.dict_rm))
logger.debug('--self.dict_sub: {}'.format(self.dict_sub))
logger.debug('--self.dict_pow: {}'.format(self.dict_pow))
logger.debug('--self.dict_rm2: {}'.format(self.dict_rm2))
self.clear()
else:
if self.dict_rm and self.dict_sub and self.dict_pow and self.dict_rm2 and node.input[0] == self.dict_rm2['output'][0]:
self.dict_add['input'] = node.input
self.dict_add['output'] = node.output
self.dict_add['id'] = node_id
logger.debug('got third pair: {} {}'.format(self.dict_add['input'], self.dict_add['output']))
else:
logger.debug('self.clear ReduceMean Sub Pow ReduceMean2')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
logger.debug('self.dict_pow: {}'.format(self.dict_pow))
logger.debug('self.dict_rm2: {}'.format(self.dict_rm2))
self.clear()
if node.op_type == 'Sqrt':
if self.dict_rm and self.dict_sub and self.dict_pow and self.dict_rm2 and \
self.dict_add and node.input == self.dict_add['output']:
self.dict_sqrt['input'] = node.input
self.dict_sqrt['output'] = node.output
self.dict_sqrt['id'] = node_id
logger.debug('got forth pair: {} {}'.format(self.dict_sqrt['input'], self.dict_sqrt['output']))
else:
logger.debug('self.clear ReduceMean Sub Pow ReduceMean2 Add')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
logger.debug('self.dict_pow: {}'.format(self.dict_pow))
logger.debug('self.dict_rm2: {}'.format(self.dict_rm2))
logger.debug('self.dict_add: {}'.format(self.dict_add))
self.clear()
if node.op_type == 'Div':
if self.dict_rm and self.dict_sub and self.dict_pow and self.dict_rm2 and \
self.dict_add and self.dict_sqrt and node.input[0] == self.dict_sub['output'][0] \
and node.input[1] == self.dict_sqrt['output'][0]:
self.dict_div['input'] = node.input
self.dict_div['output'] = node.output
self.dict_div['id'] = node_id
logger.debug('got fifth pair: {} {}'.format(self.dict_div['input'], self.dict_div['output']))
else:
logger.debug('-self.clear ReduceMean Sub Pow ReduceMean2 Add Sqrt')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
logger.debug('self.dict_pow: {}'.format(self.dict_pow))
logger.debug('self.dict_rm2: {}'.format(self.dict_rm2))
logger.debug('self.dict_add: {}'.format(self.dict_add))
logger.debug('self.dict_sqrt: {}'.format(self.dict_sqrt))
self.clear()
if node.op_type == 'Mul':
if self.dict_rm and self.dict_sub and self.dict_pow and self.dict_rm2 and \
self.dict_add and self.dict_sqrt and self.dict_div and (node.input[0] == self.dict_div['output'][0] or node.input[1] == self.dict_div['output'][0]):
self.dict_mul['input'] = node.input
self.dict_mul['output'] = node.output
self.dict_mul['id'] = node_id
self.ready2 = True
logger.debug('got sixth pair: {} {}'.format(self.dict_mul['input'], self.dict_mul['output']))
#print('got scale:', scale)
else:
logger.debug('--self.clear ReduceMean Sub Pow ReduceMean2 Add Sqrt Div')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
logger.debug('self.dict_pow: {}'.format(self.dict_pow))
logger.debug('self.dict_rm2: {}'.format(self.dict_rm2))
logger.debug('self.dict_add: {}'.format(self.dict_add))
logger.debug('self.dict_sqrt: {}'.format(self.dict_sqrt))
logger.debug('self.dict_mul: {}'.format(self.dict_mul))
self.clear()
if self.got_ln == True:
op_set = self.model.opset_import.add()
op_set.domain = 'com.metax-tech'
op_set.version = 1
#onnx.save(model, export_onnx)
remove_unused_initializer(self.model, self.unused_init_list)
remove_invalid_sub_node(self.model)
return self.model
# pattern 2:
# --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- -- Mul(11) --- --- Add(12) ---
# ^ | ^ ^
# | v | |
# ***(0) --- ReduceMean(1) --- Sub(2) --- Mul(3) --- ReduceMean(4) --- Add(5) --- Sqrt(6) --- Reciprocal(7) --- Mul(8) --- Mul(9) --- Sub(10)
# | |
# v ^
# -- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --
class MergeLNPattern2():
def __init__(self, model):
logger.debug('MergeLNPattern2 Init--------------------------')
self.model = model
self.got_ln = False
self.search = True
self.loop = 0
self.unused_init_list = []
self.dict_rm = {}
self.dict_sub = {}
self.dict_sub2 = {}
self.dict_mul = {}
self.dict_rm2 = {}
self.dict_add = {}
self.dict_sqrt = {}
self.dict_div = {}
self.dict_mul2 = {}
self.dict_mul3 = {}
self.dict_mul4 = {}
self.dict_add2 = {}
self.rm1_axes = [-1]
self.rm2_axes = [-1]
self.ready_for_reducemean = False
self.ready_for_mul_first = False
self.ready_for_mul_second = False
self.ready_for_mul_third = False
self.ready_for_sub = False
self.got_rm_axes = False
self.ready2 = False
def clear(self):
logger.debug('MergeLNPattern2 clear--------------------------')
self.unused_init_list = []
self.dict_rm = {}
self.dict_sub = {}
self.dict_sub2 = {}
self.dict_mul = {}
self.dict_rm2 = {}
self.dict_add = {}
self.dict_sqrt = {}
self.dict_div = {}
self.dict_mul2 = {}
self.dict_mul3 = {}
self.dict_mul4 = {}
self.dict_add2 = {}
self.rm1_axes = [-1]
self.rm2_axes = [-1]
self.ready_for_reducemean = False
self.ready_for_mul_first = False
self.ready_for_mul_second = False
self.ready_for_mul_third = False
self.ready_for_sub = False
self.got_rm_axes = False
self.ready2 = False
self.search = False
def merge(self):
while self.search == True:
self.loop = self.loop + 1
self.clear()
for node_id, node in enumerate(self.model.graph.node):
self.loop = self.loop + 1
#print(node_id, ", name:", node.name, ", input:", node.input, ", output:", node.output, \
# ", op:", node.op_type, ', len(input):', len(node.input))
if node.op_type == 'ReduceMean' or node.op_type == 'GlobalAveragePool':
if self.ready_for_reducemean == True:
if node.input == self.dict_mul['output']:
self.dict_rm2['input'] = node.input
self.dict_rm2['output'] = node.output
self.dict_rm2['id'] = node_id
if node.op_type == 'ReduceMean':
attributes = node.attribute
for attr in attributes:
if attr.name == 'axes':
self.rm2_axes = attr.ints
logger.debug('self.rm2_axes: {}'.format(self.rm2_axes))
if len(self.rm2_axes) != 1 or self.rm2_axes != self.rm1_axes:
logger.debug('--This ReduceMean IsNot we are looking for...')
self.clear()
else:
self.got_rm_axes = True
else:
logger.debug('self.clear ReduceMean Sub Pow')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
logger.debug('self.dict_mul: {}'.format(self.dict_mul))
self.clear()
else:
self.dict_rm['input'] = node.input
self.dict_rm['output'] = node.output
self.dict_rm['id'] = node_id
if node.op_type == 'ReduceMean':
attributes = node.attribute
for attr in attributes:
if attr.name == 'axes':
self.rm1_axes = attr.ints
logger.debug('self.rm1_axes: {}'.format(self.rm1_axes))
if len(self.rm1_axes) != 1:
logger.debug('This ReduceMean IsNot we are looking for...')
self.clear()
break
if node.op_type == 'Sub':
if self.ready_for_sub == True:
if node.input[1] == self.dict_mul3['output'][0]:
self.dict_sub2['input'] = node.input
self.dict_sub2['output'] = node.output
self.dict_sub2['id'] = node_id
self.ready2 = True
logger.debug('got eighth pair: {} {}'.format(self.dict_sub2['input'], self.dict_sub2['output']))
else:
logger.debug('---self.clear ReduceMean Sub Pow ReduceMean2 Add Sqrt')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
logger.debug('self.dict_mul: {}'.format(self.dict_mul))
logger.debug('self.dict_rm2: {}'.format(self.dict_rm2))
logger.debug('self.dict_add: {}'.format(self.dict_add))
logger.debug('self.dict_sqrt: {}'.format(self.dict_sqrt))
self.clear()
else:
if self.dict_rm and node.input[0] == self.dict_rm['input'][0] and node.input[1] == self.dict_rm['output'][0]:
self.dict_sub['input'] = node.input
self.dict_sub['output'] = node.output
self.dict_sub['id'] = node_id
logger.debug('got first pair: {} {}'.format(self.dict_sub['input'], self.dict_sub['output']))
else:
logger.debug('---self.clear ReduceMean, self.dict_rm: {} {}'.format(self.dict_rm, node.name))
self.clear()
if node.op_type == 'Mul':
if self.ready_for_mul_first == True and self.ready_for_mul_second == True:
if node.input[0] == self.dict_rm['output'][0] and node.input[1] == self.dict_mul2['output'][0]:
self.dict_mul3['input'] = node.input
self.dict_mul3['output'] = node.output
self.dict_mul3['id'] = node_id
self.ready_for_sub = True
logger.debug('got seventh pair: {} {}'.format(self.dict_mul3['input'], self.dict_mul3['output']))
elif node.input[0] == self.dict_rm['input'][0] and node.input[1] == self.dict_mul2['output'][0]:
self.dict_mul4['input'] = node.input
self.dict_mul4['output'] = node.output
self.dict_mul4['id'] = node_id
self.ready_for_mul_third = True
logger.debug('got nineth pair: {} {}'.format(self.dict_mul4['input'], self.dict_mul4['output']))
else:
logger.debug('---self.clear ReduceMean and Sub')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
self.clear()
elif self.ready_for_mul_first == True and self.ready_for_mul_second == False and self.ready_for_mul_third == False:
if node.input[0] == self.dict_div['output'][0]:
self.dict_mul2['input'] = node.input
self.dict_mul2['output'] = node.output
self.dict_mul2['id'] = node_id
self.ready_for_mul_second = True
logger.debug('got sixth pair: {} {}'.format(self.dict_mul2['input'], self.dict_mul2['output']))
else:
logger.debug('---self.clear ReduceMean and Sub')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
self.clear()
else:
if self.dict_rm and self.dict_sub and node.input[0] == self.dict_sub['output'][0] and node.input[1] == self.dict_sub['output'][0]:
self.ready_for_reducemean = True
self.dict_mul['input'] = node.input
self.dict_mul['output'] = node.output
self.dict_mul['id'] = node_id
logger.debug('got second pair: {} {}'.format(self.dict_mul['input'], self.dict_mul['output']))
else:
logger.debug('self.clear ReduceMean and Sub')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
self.clear()
if node.op_type == 'Add':
if self.ready2 == True:
if node.input[0] == self.dict_mul4['output'][0] and node.input[1] == self.dict_sub2['output'][0]:
self.dict_add2['input'] = node.input
self.dict_add2['output'] = node.output
self.dict_add2['id'] = node_id
self.search = True
self.got_ln = True
logger.debug('Got a LayerNorm op')
###
rm_node = self.model.graph.node[self.dict_rm['id']]
sub_node = self.model.graph.node[self.dict_sub['id']]
sub2_node = self.model.graph.node[self.dict_sub2['id']]
mul_node = self.model.graph.node[self.dict_mul['id']]
mul2_node = self.model.graph.node[self.dict_mul2['id']]
mul3_node = self.model.graph.node[self.dict_mul3['id']]
mul4_node = self.model.graph.node[self.dict_mul4['id']]
rm2_node = self.model.graph.node[self.dict_rm2['id']]
add_node = self.model.graph.node[self.dict_add['id']]
sqrt_node = self.model.graph.node[self.dict_sqrt['id']]
div_node = self.model.graph.node[self.dict_div['id']]
add2_node = self.model.graph.node[self.dict_add2['id']]
self.model.graph.node.remove(rm_node)
if self.got_rm_axes == True:
axis_ = self.rm1_axes[0]
else:
axis_ = -1
logger.debug('get axis_ = {}'.format(axis_))
scale_name = self.dict_mul2['input'][1]
beta_name = self.dict_sub2['input'][0]
'''
for init in self.model.graph.initializer:
if init.name == self.dict_mul2['input'][1]:
if init.data_type == onnx_proto.TensorProto.FLOAT16:
scale_name = scale_name + '_to_fp32__'
t = make_fp32_tensor_from_fp16(init, scale_name)
self.model.graph.initializer.append(t)
if init not in self.unused_init_list:
self.unused_init_list.append(init)
elif init.name == self.dict_sub2['input'][0]:
if init.data_type == onnx_proto.TensorProto.FLOAT16:
beta_name = beta_name + '_to_fp32__'
t = make_fp32_tensor_from_fp16(init, beta_name)
self.model.graph.initializer.append(t)
if init not in self.unused_init_list:
self.unused_init_list.append(init)
'''
ln_node = onnx.helper.make_node(
name = node.name + '_to_layernorm_' + str(self.loop),
op_type='LayerNorm',
inputs=[self.dict_rm['input'][0], scale_name, beta_name],
outputs=self.dict_add2['output'],
axis=axis_,
epsilon=1e-5,
stash_type=0,
domain='com.metax-tech'
)
self.model.graph.node.insert(self.dict_rm['id'], ln_node)
self.model.graph.node.remove(sub_node)
self.model.graph.node.remove(sub2_node)
self.model.graph.node.remove(rm2_node)
self.model.graph.node.remove(add_node)
self.model.graph.node.remove(sqrt_node)
self.model.graph.node.remove(div_node)
self.model.graph.node.remove(mul_node)
self.model.graph.node.remove(mul2_node)
self.model.graph.node.remove(mul3_node)
self.model.graph.node.remove(mul4_node)
self.model.graph.node.remove(add2_node)
break
else:
logger.debug('--self.clear ReduceMean Sub Pow ReduceMean2')
logger.debug('--self.dict_rm: {}'.format(self.dict_rm))
logger.debug('--self.dict_sub: {}'.format(self.dict_sub))
logger.debug('--self.dict_pow: {}'.format(self.dict_pow))
logger.debug('--self.dict_rm2: {}'.format(self.dict_rm2))
self.clear()
else:
if self.dict_rm and self.dict_sub and self.dict_mul and self.dict_rm2 and node.input[0] == self.dict_rm2['output'][0]:
self.dict_add['input'] = node.input
self.dict_add['output'] = node.output
self.dict_add['id'] = node_id
logger.debug('got third pair: {} {}'.format(self.dict_add['input'], self.dict_add['output']))
else:
logger.debug('self.clear ReduceMean Sub Mul ReduceMean2')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
logger.debug('self.dict_mul: {}'.format(self.dict_mul))
logger.debug('self.dict_rm2: {}'.format(self.dict_rm2))
self.clear()
if node.op_type == 'Sqrt':
if self.dict_rm and self.dict_sub and self.dict_mul and self.dict_rm2 and \
self.dict_add and node.input == self.dict_add['output']:
self.dict_sqrt['input'] = node.input
self.dict_sqrt['output'] = node.output
self.dict_sqrt['id'] = node_id
logger.debug('got forth pair: {} {}'.format(self.dict_sqrt['input'], self.dict_sqrt['output']))
else:
logger.debug('self.clear ReduceMean Sub Pow ReduceMean2 Add')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
logger.debug('self.dict_mul: {}'.format(self.dict_mul))
logger.debug('self.dict_rm2: {}'.format(self.dict_rm2))
logger.debug('self.dict_add: {}'.format(self.dict_add))
self.clear()
if node.op_type == 'Reciprocal':
if self.dict_rm and self.dict_sub and self.dict_mul and self.dict_rm2 and \
self.dict_add and self.dict_sqrt and node.input[0] == self.dict_sqrt['output'][0]:
self.dict_div['input'] = node.input
self.dict_div['output'] = node.output
self.dict_div['id'] = node_id
self.ready_for_mul_first = True
logger.debug('got fifth pair: {} {}'.format(self.dict_div['input'], self.dict_div['output']))
else:
logger.debug('----self.clear ReduceMean Sub Pow ReduceMean2 Add Sqrt')
logger.debug('self.dict_rm: {}'.format(self.dict_rm))
logger.debug('self.dict_sub: {}'.format(self.dict_sub))
logger.debug('self.dict_mul: {}'.format(self.dict_mul))
logger.debug('self.dict_rm2: {}'.format(self.dict_rm2))
logger.debug('self.dict_add: {}'.format(self.dict_add))
logger.debug('self.dict_sqrt: {}'.format(self.dict_sqrt))
self.clear()
if self.got_ln == True:
op_set = self.model.opset_import.add()
op_set.domain = 'com.metax-tech'
op_set.version = 1
#onnx.save(model, export_onnx)
remove_unused_initializer(self.model, self.unused_init_list)
return self.model
# rms layernorm
class MergeRMSLn():
def __init__(self, model):
logger.debug('MergeRMSLn Init--------------------------')
self.model = model
self.got_ln = False
self.search = True
self.unused_init_list = []
self.loop = 0
self.ready = False
self.dict_rm = {}
self.dict_pow = {}
self.dict_add = {}
self.dict_sqrt = {}
self.dict_div = {}
self.dict_mul = {}
self.dict_mul2 = {}
self.input_type = onnx.TensorProto.FLOAT
self.input_shape = [1]
self.scale = [1]
self.rm1_axes = [-1]
def clear(self):
self.dict_rm = {}
self.dict_pow = {}
self.dict_add = {}
self.dict_sqrt = {}
self.dict_div = {}
self.dict_mul = {}
self.dict_mul2 = {}
self.rm1_axes = [-1]
self.search = False
self.ready = False
def merge(self):
while self.search == True:
self.clear()
self.loop = self.loop + 1
for node_id, node in enumerate(self.model.graph.node):
self.loop = self.loop + 1
#print(node_id, ", name:", node.name, ", input:", node.input, ", output:", node.output, \
# ", op:", node.op_type, ', len(input):', len(node.input))
if node.op_type == 'Pow':
self.dict_pow['input'] = node.input
self.dict_pow['output'] = node.output
self.dict_pow['id'] = node_id
for vi in self.model.graph.value_info:
if vi.name == node.input[0]:
self.input_type = vi.type.tensor_type.elem_type
self.input_shape = [d.dim_value for d in vi.type.tensor_type.shape.dim]
print('got input shape: ', self.input_shape)
print('got input type: ', self.input_type)
if node.op_type == 'ReduceMean':
if self.dict_pow and node.input == self.dict_pow['output']:
self.dict_rm['input'] = node.input
self.dict_rm['output'] = node.output
self.dict_rm['id'] = node_id
print('got first pair:', self.dict_rm['input'], self.dict_rm['output'])
attributes = node.attribute
for attr in attributes:
if attr.name == 'axes':
self.rm1_axes = attr.ints
print('self.rm1_axes: ', self.rm1_axes)
if len(self.rm1_axes) != 1:
print('This ReduceMean IsNot we are looking for...')
self.clear()
break
else:
self.clear()
if node.op_type == 'Add':
if self.dict_pow and self.dict_rm and node.input[0] == self.dict_rm['output'][0]:
self.dict_add['input'] = node.input
self.dict_add['output'] = node.output
self.dict_add['id'] = node_id
print('got second pair:', self.dict_rm['input'], self.dict_rm['output'])
else:
print('self.clear ReduceMean Sub Pow ReduceMean2')
print('self.dict_rm:', self.dict_rm)
print('self.dict_pow:', self.dict_pow)
self.clear()
if node.op_type == 'Sqrt':
if self.dict_pow and self.dict_rm and \
self.dict_add and node.input == self.dict_add['output']:
self.dict_sqrt['input'] = node.input
self.dict_sqrt['output'] = node.output
self.dict_sqrt['id'] = node_id
print('got third pair:', self.dict_sqrt['input'], self.dict_sqrt['output'])
else:
print('self.clear ReduceMean Sub Pow ReduceMean2 Add')
print('self.dict_rm:', self.dict_rm)
print('self.dict_pow:', self.dict_pow)
print('self.dict_add:', self.dict_add)
self.clear()
if node.op_type == 'Div':
self.scale = values.get_init_value(self.model, node.input[0])
if self.scale and self.dict_pow and self.dict_rm and \
self.dict_add and self.dict_sqrt and node.input[1] == self.dict_sqrt['output'][0] :
self.dict_div['input'] = node.input
self.dict_div['output'] = node.output
self.dict_div['id'] = node_id
if isinstance(self.scale, np.ndarray):
self.scale = self.scale.tolist()
print('got fourth pair, scale', self.scale, type(self.scale))
else:
print('-self.clear ReduceMean Sub Pow ReduceMean2 Add Sqrt')
print('self.dict_rm:', self.dict_rm)
print('self.dict_pow:', self.dict_pow)
print('self.dict_add:', self.dict_add)
print('self.dict_sqrt:', self.dict_sqrt)
self.clear()
if node.op_type == 'Mul':
if self.ready == True and node.input[0] == self.dict_mul['output'][0]:
self.dict_mul2['input'] = node.input
self.dict_mul2['output'] = node.output
self.dict_mul2['id'] = node_id
self.search = True
self.got_ln = True
print('Got a LayerNorm op')
###
pow_node = self.model.graph.node[self.dict_pow['id']]
rm_node = self.model.graph.node[self.dict_rm['id']]
add_node = self.model.graph.node[self.dict_add['id']]
sqrt_node = self.model.graph.node[self.dict_sqrt['id']]
div_node = self.model.graph.node[self.dict_div['id']]
mul_node = self.model.graph.node[self.dict_mul['id']]
mul2_node = self.model.graph.node[self.dict_mul2['id']]
self.model.graph.node.remove(pow_node)
scale_name = self.dict_mul2['input'][1] #node.name + '_scale_' + str(self.loop)
beta_name = node.name + '_beta_' + str(self.loop)
'''
scale_tensor = onnx.helper.make_tensor(name=scale_name,
data_type=self.input_type,
dims=[self.input_shape[-1]],
vals=self.scale*self.input_shape[-1])
self.model.graph.initializer.append(scale_tensor)
'''
beta_tensor = onnx.helper.make_tensor(name=beta_name,
data_type=self.input_type,
dims=[self.input_shape[-1]],
vals=[0]*self.input_shape[-1])
self.model.graph.initializer.append(beta_tensor)
ln_node = onnx.helper.make_node(
name = node.name + '_to_layernorm_' + str(self.loop),
op_type='LayerNorm',
inputs=[self.dict_pow['input'][0], scale_name, beta_name],
outputs=self.dict_mul2['output'],
axis=self.rm1_axes[0],
epsilon=1e-5,
stash_type=0,
domain='com.metax-tech'
)
self.model.graph.node.insert(self.dict_pow['id'], ln_node)
self.model.graph.node.remove(rm_node)
self.model.graph.node.remove(add_node)
self.model.graph.node.remove(sqrt_node)
self.model.graph.node.remove(div_node)
self.model.graph.node.remove(mul_node)
self.model.graph.node.remove(mul2_node)
break
elif self.dict_pow and self.dict_rm and \
self.dict_add and self.dict_sqrt and self.dict_div and node.input[1] == self.dict_div['output'][0] \
and node.input[0] == self.dict_pow['input'][0] :
self.dict_mul['input'] = node.input
self.dict_mul['output'] = node.output
self.dict_mul['id'] = node_id
self.ready = True
print('got fifth pair:', self.dict_mul['input'], self.dict_mul['output'])
else:
print('--self.clear ReduceMean Sub Pow ReduceMean2 Add Sqrt Div')
print('self.dict_rm:', self.dict_rm)
print('self.dict_pow:', self.dict_pow)
print('self.dict_add:', self.dict_add)
print('self.dict_sqrt:', self.dict_sqrt)
print('self.dict_mul:', self.dict_mul)
self.clear()
if self.got_ln == True:
op_set = self.model.opset_import.add()
op_set.domain = 'com.metax-tech'
op_set.version = 1
#onnx.save(model, export_onnx)
remove_unused_initializer(self.model, self.unused_init_list)
return self.model
def merge_layernorm(model):
mlp1 = MergeLNPattern1(model)
model = mlp1.merge()
mlp2 = MergeLNPattern2(model)
model = mlp2.merge()
#mlp3 = MergeRMSLn(model)
#model = mlp3.merge()
return model
'''
model = onnx.load('./bert_fp16_test.onnx')
#model = onnx.load('./bert_sub.onnx')
m = merge_layernorm(model)
onnx.save(m, 'ln.onnx')
'''
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/littezheng/onnx_convert.git
git@gitee.com:littezheng/onnx_convert.git
littezheng
onnx_convert
onnx_convert
master

搜索帮助