# mha_optimization.py (littlezheng/onnx_convert, forked from ahqzy/onnx_convert)
# last commit: ahqzy, 2023-05-05 — bugfix(multi input, one of is dynamic)
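#
# This module implements ONNX graph-rewriting passes for multi-head-attention
# style models: the get_* helpers match MatMul/Add/Reshape/Transpose chains and
# LayerNorm-style Add/ReduceMean/Sub junctions, and the handle_* passes rewrite
# the matched subgraphs, inserting Transpose/Reshape pairs and converting
# MatMul+Add pairs into 1x1 Conv nodes with relayouted weights.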
import onnx
import sys

import numpy as np

# helper modules from this repo
import values, operation
import log

logger = log.getLogger(__name__, log.DEBUG)

# module-level node caches used by the rewrite passes
transpose_node_map = {}
reshape_node_map = {}
def get_prev_node_by_input(model, input_):
    # return (producer_node, 0) if some node produces `input_`, else (first_node, -1)
    n = model.graph.node[0]
    for node in model.graph.node:
        if input_ in node.output:
            return node, 0
    return n, -1


def get_next_node_by_output(model, output):
    # return (first_consumer_node, 0) if some node consumes `output`, else (first_node, -1)
    n = model.graph.node[0]
    for node in model.graph.node:
        if output in node.input:
            return node, 0
    return n, -1


def get_all_next_node_by_output(model, output):
    # return (all_consumer_nodes, 0), or ([], -1) if `output` has no consumers
    node_list = []
    ok = -1
    for node in model.graph.node:
        if output in node.input:
            node_list.append(node)
            ok = 0
    return node_list, ok
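
# Minimal usage sketch for the three walkers above (illustrative only; this
# helper is hypothetical and not called anywhere). Each walker returns a
# (node(s), status) pair where status 0 means "found" and -1 means "not found",
# so callers must check the status before trusting the returned node.
def _demo_walk(model, tensor_name):
    producer, ok = get_prev_node_by_input(model, tensor_name)
    if ok == 0:
        logger.debug('producer of {}: {}'.format(tensor_name, producer.op_type))
    consumers, ok = get_all_next_node_by_output(model, tensor_name)
    if ok == 0:
        logger.debug('consumers of {}: {}'.format(tensor_name, [n.op_type for n in consumers]))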
def insert_node(model, insert_node, follow_up_node):
    # The caller rewires the follow-up node's input to the inserted node's output:
    #follow_up_node.input[0] = insert_node.output[0]
    # Find the follow-up node's index and insert the new node just before it.
    for follow_up_node_index, _follow_up_node in enumerate(model.graph.node):
        if _follow_up_node == follow_up_node:
            logger.debug("follow_up_node_index: {}".format(follow_up_node_index))
            model.graph.node.insert(follow_up_node_index, insert_node)
            break
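
# Note: ONNX requires graph.node to be topologically sorted. Inserting the new
# node at the index of its first consumer (as insert_node does) keeps the sort
# valid, because every input of the inserted node is produced earlier in the list.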
# MatMul --> Add --> Reshape --> Transpose (walked backwards from input_name)
def get_matmul_input_path_pattern_one(model, input_name):
    res = -1
    #node_list = []
    node_dict = {}
    logger.debug('get_matmul_input_path_pattern_one, input_name: {}'.format(input_name))

    input_pre, ok = get_prev_node_by_input(model, input_name)
    if ok == 0 and input_pre.op_type == 'Transpose':
        logger.debug('got match Transpose node: {}'.format(input_pre.name))
        attributes = input_pre.attribute
        for attr in attributes:
            if attr.name == 'perm':
                v = values.get_tensor_value(attr.t)
                logger.debug('got transpose perm {} for {}'.format(v, input_pre.name))
                break

        input_p_pre, ok = get_prev_node_by_input(model, input_pre.input[0])
        if ok == 0 and input_p_pre.op_type == 'Reshape':
            data, shape = values.get_init_value_and_shape(model, input_p_pre.input[1])
            if isinstance(data, list) and data == []:
                logger.debug('reshape_data is not in initializer')
                data = values.get_constant_value(model, input_p_pre.input[1])

            if len(data) == 4 or len(data) == 3:
                logger.debug('got match Reshape node: {}'.format(input_p_pre.name))
                input_pp_pre, ok = get_prev_node_by_input(model, input_p_pre.input[0])
                if ok == 0 and input_pp_pre.op_type == 'Add':
                    addA_name = input_pp_pre.input[0]
                    addA, shapeA = values.get_init_value_and_shape(model, input_pp_pre.input[0])
                    '''
                    if isinstance(addA, list) and addA == []:
                        print('addA is not in initializer')
                        addA = values.get_constant_value(model, input_pp_pre.input[1])
                    '''
                    add_tensor_two = True
                    if len(shapeA) == 0:
                        addA_name = input_pp_pre.input[1]
                        add_tensor_two = False
                        addA, shapeA = values.get_init_value_and_shape(model, input_pp_pre.input[1])

                    if len(shapeA) == 1:
                        logger.debug('got match Add node: {}'.format(input_pp_pre.name))
                        add_input = input_pp_pre.input[1]
                        if add_tensor_two == False:
                            add_input = input_pp_pre.input[0]

                        input_ppp_pre, ok = get_prev_node_by_input(model, add_input)
                        logger.debug('----got matmul node: {}'.format(input_ppp_pre.name))
                        if ok == 0 and input_ppp_pre.op_type == 'MatMul':
                            shapeA = values.get_tensor_shape_by_name(model, input_ppp_pre.input[0])
                            inputB, shapeB = values.get_init_value_and_shape(model, input_ppp_pre.input[1])
                            if isinstance(inputB, list) and inputB == []:
                                logger.debug('inputB is not in initializer')
                                inputB = values.get_constant_value(model, input_ppp_pre.input[1])

                            if (len(shapeA) == 3 or len(shapeA) == 2) and len(shapeB) == 2:
                                logger.debug('got match MatMul node: {}'.format(input_ppp_pre.name))
                                res = 0
                                node_list = [input_ppp_pre, input_pp_pre, input_p_pre, input_pre]
                                node_dict['node_list'] = node_list
                                node_dict['addA'] = addA_name
                                node_dict['matmul_AShape'] = shapeA
                                node_dict['inputB'] = inputB
                                node_dict['matmul_BShape'] = shapeB

                                input_pppp_pre, ok = get_prev_node_by_input(model, input_ppp_pre.input[0])
                                if ok == 0:
                                    node_dict['prev'] = input_pppp_pre.output[0]
                                    logger.debug('--- map key: {}'.format(input_pppp_pre.output[0]))
                                else:
                                    node_dict['prev'] = input_ppp_pre.input[0]
                                    logger.debug('prev node may be a graph input: {}'.format(input_ppp_pre.input[0]))
    elif ok == 0:
        res = 1
        node_list = [input_pre]
        node_dict['node_list'] = node_list

    return node_dict, res
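
# The chain matched above, walked backwards from input_name:
#
#   prev --> MatMul --> Add --> Reshape --> Transpose --> input_name
#
# res == 0 means the full chain matched (node_dict carries the MatMul weights
# and shapes); res == 1 means only a producer was found that is not a Transpose.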
# Transpose --> Reshape --> MatMul --> Add --> Add (walked forwards from input_name)
def get_matmul_input_path_pattern_two(model, input_name):
    res = -1
    node_dict = {}
    logger.debug('get_matmul_input_path_pattern_two, input_name: {}'.format(input_name))

    next_node, ok = get_next_node_by_output(model, input_name)
    if ok == 0 and next_node.op_type == 'Reshape':
        data, shape = values.get_init_value_and_shape(model, next_node.input[1])
        if isinstance(data, list) and data == []:
            logger.debug('---reshape_data is not in initializer')
            data = values.get_constant_value(model, next_node.input[1])

        if len(data) == 3 or len(data) == 2:
            logger.debug('----got match Reshape node: {}'.format(next_node.name))
            n_next_node, ok = get_next_node_by_output(model, next_node.output[0])
            if ok == 0 and n_next_node.op_type == 'MatMul':
                shapeA = values.get_tensor_shape_by_name(model, n_next_node.input[0])
                inputB, shapeB = values.get_init_value_and_shape(model, n_next_node.input[1])
                logger.debug('++++++++++++++++++shapeA, shapeB: {} {}'.format(shapeA, shapeB))
                if isinstance(inputB, list) and inputB == []:
                    logger.debug('inputB is not in initializer')
                    inputB = values.get_constant_value(model, n_next_node.input[1])

                if (len(shapeA) == 3 or len(shapeA) == 2) and len(shapeB) == 2:
                    node_dict['matmul_AShape'] = shapeA
                    node_dict['inputB'] = inputB
                    node_dict['matmul_BShape'] = shapeB
                    logger.debug('----got match MatMul node: {}'.format(n_next_node.name))

                    nn_next_node, ok = get_next_node_by_output(model, n_next_node.output[0])
                    if ok == 0 and nn_next_node.op_type == 'Add':
                        addA_name = nn_next_node.input[0]
                        node_dict['addA'] = addA_name
                        node_dict['addFirst'] = True
                        addA, shapeA = values.get_init_value_and_shape(model, nn_next_node.input[0])
                        '''
                        if isinstance(addA, list) and addA == []:
                            print('addA is not in initializer')
                            addA = values.get_constant_value(model, nn_next_node.input[1])
                        '''
                        if len(shapeA) == 0:
                            addA_name = nn_next_node.input[1]
                            node_dict['addA'] = addA_name
                            node_dict['addFirst'] = False
                            addA, shapeA = values.get_init_value_and_shape(model, nn_next_node.input[1])

                        if len(shapeA) == 1:
                            logger.debug('---got match Add node: {}'.format(nn_next_node.name))
                            nnn_next_node, ok = get_next_node_by_output(model, nn_next_node.output[0])
                            if ok == 0 and nnn_next_node.op_type == 'Add':
                                logger.debug('----got match Add node2: {}'.format(nnn_next_node.name))
                                res = 0
                                node_list = [next_node, n_next_node, nn_next_node]
                                node_dict['node_list'] = node_list

    return node_dict, res
def get_add_combination_pattern_one(model):
    rm_list = []
    sub_list = []
    add_list = []
    ars_list = []

    for node in model.graph.node:
        if node.op_type == 'ReduceMean':
            rm_list.append(node)
        if node.op_type == 'Sub':
            sub_list.append(node)
        if node.op_type == 'Add':
            add_list.append(node)

    #print('rm_input_list:', rm_input_list)
    #print('sub_input_list:', sub_input_list)
    #print('add_input_list:', add_input_list)

    for node in model.graph.node:
        if node.op_type == 'Add':
            match_rm = False
            match_sub = False
            match_add = False
            output = node.output[0]

            for rm_node in rm_list:
                if output in rm_node.input:
                    match_rm = True
                    match_rm_node = rm_node
                    break

            for sub_node in sub_list:
                if output in sub_node.input:
                    match_sub = True
                    match_sub_node = sub_node
                    break

            for add_node in add_list:
                if output in add_node.input:
                    match_add = True
                    match_add_node = add_node
                    break

            if match_rm == True and match_sub == True and match_add == True:
                logger.debug('found match add node: {}'.format(node.name))
                ars = {}
                ars['nextAdd'] = match_add_node
                ars['currentAdd'] = node
                ars['ReduceMean'] = match_rm_node
                ars['Sub'] = match_sub_node
                ars_list.append(ars)

    return ars_list
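
# The fan-out matched above is a residual/LayerNorm junction: an Add whose
# output simultaneously feeds another Add (the residual connection), a
# ReduceMean, and a Sub (the start of a decomposed LayerNorm).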
def get_add_combination_pattern_two(model):
    matmul_list = []
    add_list = []
    rm_list = []
    sub_list = []
    am_list = []
    asr_list = []

    for node in model.graph.node:
        if node.op_type == 'Add':
            add_list.append(node)
        if node.op_type == 'MatMul':
            matmul_list.append(node)
        if node.op_type == 'ReduceMean':
            rm_list.append(node)
        if node.op_type == 'Sub':
            sub_list.append(node)

    for node in model.graph.node:
        if node.op_type == 'Add':
            match_add_node = None
            match_matmul_node_list = []
            nextAddInput1 = False
            output = node.output[0]

            for add_node in add_list:
                if output in add_node.input:
                    match_add_node = add_node
                    if match_add_node.input[1] == output:
                        nextAddInput1 = True
                    break

            for mm_node in matmul_list:
                if output in mm_node.input:
                    match_matmul_node_list.append(mm_node)

            if match_add_node != None and len(match_matmul_node_list) == 3:
                logger.debug('found match add node: {}'.format(node.name))
                am = {}
                am['nextAdd'] = match_add_node
                am['nextAddInput1'] = nextAddInput1
                am['currentAdd'] = node
                am['MatMulList'] = match_matmul_node_list
                #am_list.append(am)

                match_rm_node = None
                match_sub_node = None
                next_add_output = match_add_node.output[0]

                for rm_node in rm_list:
                    if next_add_output in rm_node.input:
                        match_rm_node = rm_node
                        break

                for sub_node in sub_list:
                    if next_add_output in sub_node.input:
                        match_sub_node = sub_node
                        break

                if match_rm_node != None and match_sub_node != None:
                    am['Sub'] = match_sub_node
                    am['ReduceMean'] = match_rm_node
                    logger.debug('got sub and reducemean: {} {}'.format(match_sub_node.name, match_rm_node.name))
                    am_list.append(am)

    return am_list
def get_add_combination_pattern_four(model):
    matmul_list = []
    add_list = []
    rm_list = []
    sub_list = []
    am_list = []
    asr_list = []

    for node in model.graph.node:
        if node.op_type == 'Add':
            add_list.append(node)
        if node.op_type == 'MatMul':
            matmul_list.append(node)
        if node.op_type == 'ReduceMean':
            rm_list.append(node)
        if node.op_type == 'Sub':
            sub_list.append(node)

    for node in model.graph.node:
        if node.op_type == 'Reshape':
            match_add_node = None
            match_matmul_node_list = []
            output = node.output[0]

            for add_node in add_list:
                if output in add_node.input:
                    match_add_node = add_node
                    break

            for mm_node in matmul_list:
                if output in mm_node.input:
                    match_matmul_node_list.append(mm_node)

            if match_add_node != None and len(match_matmul_node_list) == 3:
                logger.debug('found match reshape node: {}'.format(node.name))
                am = {}
                am['nextAdd'] = match_add_node
                am['currentAdd'] = node
                am['MatMulList'] = match_matmul_node_list
                #am_list.append(am)

                match_rm_node = None
                match_sub_node = None
                next_add_output = match_add_node.output[0]

                for rm_node in rm_list:
                    if next_add_output in rm_node.input:
                        match_rm_node = rm_node
                        break

                for sub_node in sub_list:
                    if next_add_output in sub_node.input:
                        match_sub_node = sub_node
                        break

                if match_rm_node != None and match_sub_node != None:
                    am['Sub'] = match_sub_node
                    am['ReduceMean'] = match_rm_node
                    logger.debug('got sub and reducemean: {} {}'.format(match_sub_node.name, match_rm_node.name))
                    am_list.append(am)

    return am_list
def get_add_combination_pattern_five(model):
    pass
def handle_add_combination_pattern_two_three(model):
    am_list = get_add_combination_pattern_two(model)
    logger.debug('handle_add_combination_pattern_two_three------------')
    #if len(am_list):
    for am in am_list:
        #am = am_list[0]
        add_node = am['currentAdd']
        next_add_node = am['nextAdd']
        matmul_node_list = am['MatMulList']
        nextAddInput1 = am['nextAddInput1']
        logger.debug('handle_add_combination_pattern_two_three, add_node: {}, next_add_node: {}'.format(add_node.name, next_add_node.name))

        ### add Transpose
        ts_name = add_node.name + '_transpose_'
        ts_output_name = ts_name + '_output_'
        add_output_shape = values.get_tensor_shape_by_name(model, add_node.output[0])
        ts_output_shape = [add_output_shape[0], add_output_shape[2], add_output_shape[1]]
        transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, ts_output_shape)

        ts_node = onnx.helper.make_node(
            'Transpose',
            name=ts_name,
            inputs=[add_node.output[0]],
            outputs=[ts_output_name],
            perm=[0, 2, 1])

        model.graph.value_info.append(transpose_output)

        ### add Reshape-1
        rs_name = add_node.name + '_reshape_1_'
        rs_output_name = rs_name + '_output_'
        rs_output_shape = [ts_output_shape[1], ts_output_shape[2]]  # TBD
        rs_output = onnx.helper.make_tensor_value_info(rs_output_name, onnx.TensorProto.FLOAT, rs_output_shape)

        const_shape_name = add_node.name + '_reshape_data_'
        const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
                                                     data_type=onnx.TensorProto.INT64,
                                                     dims=[len(rs_output_shape)],
                                                     vals=rs_output_shape)
        model.graph.initializer.append(const_shape_tensor)

        rs_node = onnx.helper.make_node(
            'Reshape',
            name=rs_name,
            inputs=[ts_output_name, const_shape_name],
            outputs=[rs_output_name])

        model.graph.value_info.append(rs_output)

        ### add Reshape-2
        rs2_name = add_node.name + '_reshape_2_'
        rs2_output_name = rs2_name + '_output_'
        rs2_output_shape = [add_output_shape[0], add_output_shape[2], add_output_shape[1]]
        rs2_output = onnx.helper.make_tensor_value_info(rs2_output_name, onnx.TensorProto.FLOAT, rs2_output_shape)

        const_shape2_name = add_node.name + '_reshape2_data_'
        const_shape2_tensor = onnx.helper.make_tensor(name=const_shape2_name,
                                                      data_type=onnx.TensorProto.INT64,
                                                      dims=[len(rs2_output_shape)],
                                                      vals=rs2_output_shape)
        model.graph.initializer.append(const_shape2_tensor)

        rs2_node = onnx.helper.make_node(
            'Reshape',
            name=rs2_name,
            inputs=[rs_output_name, const_shape2_name],
            outputs=[rs2_output_name])

        model.graph.value_info.append(rs2_output)

        insert_node(model, rs2_node, next_add_node)
        if nextAddInput1 == True:
            next_add_node.input[1] = rs2_output_name
        else:
            next_add_node.input[0] = rs2_output_name

        matmul_node_list[0].input[0] = rs2_output_name
        matmul_node_list[1].input[0] = rs2_output_name
        matmul_node_list[2].input[0] = rs2_output_name

        insert_node(model, rs_node, rs2_node)
        insert_node(model, ts_node, rs_node)

        ### insert Transpose before ReduceMean and Sub
        sub_node = am['Sub']
        rm_node = am['ReduceMean']

        ts_name = sub_node.name + '_transpose_'
        ts_output_name = ts_name + '_output_'
        add_output_shape = values.get_tensor_shape_by_name(model, next_add_node.output[0])
        ts_output_shape = [add_output_shape[0], add_output_shape[1], add_output_shape[2]]
        transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, ts_output_shape)

        ts_node = onnx.helper.make_node(
            'Transpose',
            name=ts_name,
            inputs=[next_add_node.output[0]],
            outputs=[ts_output_name],
            perm=[0, 2, 1])

        model.graph.value_info.append(transpose_output)
        insert_node(model, ts_node, sub_node)
        sub_node.input[0] = ts_output_name
        rm_node.input[0] = ts_output_name
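
# Per match, the pass above sandwiches the attention block between a Transpose
# (perm=[0, 2, 1]) plus Reshape pair on the way in and another Transpose placed
# before the ReduceMean/Sub pair on the way out, so the three Q/K/V MatMuls and
# the residual Add consume the transposed layout while the LayerNorm
# decomposition still sees the original one.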
def handle_add_combination_pattern_four(model):
    am_list = get_add_combination_pattern_two(model)
    am_list2 = get_add_combination_pattern_four(model)
    f = False
    if len(am_list2) > 0:
        am_list.append(am_list2[0])
        f = True

    logger.debug('handle_add_combination_pattern_four------------')
    length = len(am_list)
    #shape_dim = 3
    #if len(am_list):
    for idx, am in enumerate(am_list):
        #am = am_list[0]
        if idx != length - 1:
            add_node = am['currentAdd']
            next_add_node = am['nextAdd']
            matmul_node_list = am['MatMulList']

            ### add Transpose
            ts_name = add_node.name + '_transpose_'
            ts_output_name = ts_name + '_output_'
            add_output_shape = values.get_tensor_shape_by_name(model, add_node.output[0])
            shape_dim = len(add_output_shape)
            perm_ = [0, 2, 1]
            if shape_dim == 3:
                ts_output_shape = [add_output_shape[0], add_output_shape[2], add_output_shape[1]]
            else:
                perm_ = [1, 0]
                ts_output_shape = [add_output_shape[1], add_output_shape[0]]

            transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, ts_output_shape)

            ts_node = onnx.helper.make_node(
                'Transpose',
                name=ts_name,
                inputs=[add_node.output[0]],
                outputs=[ts_output_name],
                perm=perm_)

            model.graph.value_info.append(transpose_output)

            if shape_dim == 3:
                ### add Reshape-1
                rs_name = add_node.name + '_reshape_1_'
                rs_output_name = rs_name + '_output_'
                rs_output_shape = [ts_output_shape[1], ts_output_shape[2]]  # TBD
                rs_output = onnx.helper.make_tensor_value_info(rs_output_name, onnx.TensorProto.FLOAT, rs_output_shape)

                const_shape_name = add_node.name + '_reshape_data_'
                const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
                                                             data_type=onnx.TensorProto.INT64,
                                                             dims=[len(rs_output_shape)],
                                                             vals=rs_output_shape)
                model.graph.initializer.append(const_shape_tensor)

                rs_node = onnx.helper.make_node(
                    'Reshape',
                    name=rs_name,
                    inputs=[ts_output_name, const_shape_name],
                    outputs=[rs_output_name])

                model.graph.value_info.append(rs_output)

            ### add Reshape-2
            rs2_name = add_node.name + '_reshape_2_'
            rs2_output_name = rs2_name + '_output_'
            const_shape2_name = add_node.name + '_reshape2_data_'
            if shape_dim == 3:
                inputs_ = [rs_output_name, const_shape2_name]
                rs2_output_shape = [add_output_shape[0], add_output_shape[2], add_output_shape[1]]
            else:
                inputs_ = [ts_output_name, const_shape2_name]
                rs2_output_shape = [1, add_output_shape[1], add_output_shape[0]]

            rs2_output = onnx.helper.make_tensor_value_info(rs2_output_name, onnx.TensorProto.FLOAT, rs2_output_shape)

            const_shape2_tensor = onnx.helper.make_tensor(name=const_shape2_name,
                                                          data_type=onnx.TensorProto.INT64,
                                                          dims=[len(rs2_output_shape)],
                                                          vals=rs2_output_shape)
            model.graph.initializer.append(const_shape2_tensor)

            rs2_node = onnx.helper.make_node(
                'Reshape',
                name=rs2_name,
                inputs=inputs_,
                outputs=[rs2_output_name])

            model.graph.value_info.append(rs2_output)

            insert_node(model, rs2_node, next_add_node)
            if shape_dim == 3:
                next_add_node.input[0] = rs2_output_name
            else:
                next_add_node.input[1] = rs2_output_name

            matmul_node_list[0].input[0] = rs2_output_name
            matmul_node_list[1].input[0] = rs2_output_name
            matmul_node_list[2].input[0] = rs2_output_name

            if shape_dim == 3:
                insert_node(model, rs_node, rs2_node)
                insert_node(model, ts_node, rs_node)
            else:
                insert_node(model, ts_node, rs2_node)
                update_tensor_shape(model, next_add_node.output[0], rs2_output_shape)
        else:
            next_add_node = am['nextAdd']
            rs_node = am['currentAdd']
            rs_shape = values.get_tensor_shape_by_name(model, rs_node.input[0])
            # reuse the matched Reshape node as a Transpose
            rs_node.op_type = 'Transpose'
            attr = onnx.helper.make_attribute('perm', [0, 2, 1])
            rs_node.attribute.append(attr)
            del rs_node.input[1:]
            update_tensor_shape(model, rs_node.output[0], [rs_shape[0], rs_shape[2], rs_shape[1]])
            update_tensor_shape(model, next_add_node.output[0], [rs_shape[0], rs_shape[2], rs_shape[1]])

        ### add Reshape
        logger.debug('------shape_dim: {}'.format(shape_dim))
        if shape_dim == 2:
            rs_name_ = next_add_node.name + '_reshape_'
            rs_output_name_ = rs_name_ + '_output_'
            const_shape_name_ = next_add_node.name + '_reshape_data_'
            inputs_ = [next_add_node.output[0], const_shape_name_]
            s = values.get_tensor_shape_by_name(model, next_add_node.input[1])
            rs_output_shape_ = [s[1], s[2]]
            rs_output_ = onnx.helper.make_tensor_value_info(rs_output_name_, onnx.TensorProto.FLOAT, rs_output_shape_)

            const_shape_tensor_ = onnx.helper.make_tensor(name=const_shape_name_,
                                                          data_type=onnx.TensorProto.INT64,
                                                          dims=[len(rs_output_shape_)],
                                                          vals=rs_output_shape_)
            model.graph.initializer.append(const_shape_tensor_)

            rs_node_ = onnx.helper.make_node(
                'Reshape',
                name=rs_name_,
                inputs=inputs_,
                outputs=[rs_output_name_])

            model.graph.value_info.append(rs_output_)

        ### insert Transpose before ReduceMean/Sub/Mul
        mul_node = None
        msr_node_list, ok = get_all_next_node_by_output(model, next_add_node.output[0])
        if ok == 0:
            for n in msr_node_list:
                if n.op_type == 'Mul':
                    mul_node = n
                    break

        sub_node = am['Sub']
        rm_node = am['ReduceMean']

        ts_name = sub_node.name + '_transpose_'
        ts_output_name = ts_name + '_output_'
        add_output_shape = values.get_tensor_shape_by_name(model, next_add_node.output[0])
        dims = shape_dim  # len(add_output_shape)
        if dims == 3:
            perm_ = [0, 2, 1]
            inputs_ = [next_add_node.output[0]]
            ts_output_shape = [add_output_shape[0], add_output_shape[1], add_output_shape[2]]
        else:
            perm_ = [1, 0]
            inputs_ = [rs_output_name_]
            ts_output_shape = [add_output_shape[2], add_output_shape[1]]

        transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, ts_output_shape)

        ts_node = onnx.helper.make_node(
            'Transpose',
            name=ts_name,
            inputs=inputs_,
            outputs=[ts_output_name],
            perm=perm_)

        model.graph.value_info.append(transpose_output)
        insert_node(model, ts_node, sub_node)
        if shape_dim == 2:
            logger.debug('insert_node rs_node----')
            insert_node(model, rs_node_, ts_node)

        sub_node.input[0] = ts_output_name
        rm_node.input[0] = ts_output_name
        if mul_node != None:
            mul_node.input[0] = ts_output_name
def get_matmul_block_one(model, matmul_node):
    logger.debug('into get_matmul_block_one')
    res = -1
    node_dict = {}

    #input_next, ok = get_next_node_by_output(model, input_)
    input_next = matmul_node
    if input_next.op_type == 'MatMul':
        shapeA = values.get_tensor_shape_by_name(model, input_next.input[0])
        inputB, shapeB = values.get_init_value_and_shape(model, input_next.input[1])
        if isinstance(inputB, list) and inputB == []:
            logger.debug('inputB is not in initializer')
            inputB = values.get_constant_value(model, input_next.input[1])

        if len(shapeA) == 3 and len(shapeB) == 2:
            logger.debug('--- got MatMul node: {}'.format(input_next.name))
            #node_list = [input_next, input_pp_pre, input_p_pre, input_pre]
            #node_dict['node_list'] = node_list
            node_dict['MatMul1'] = input_next
            node_dict['matmulA1_Shape'] = shapeA
            node_dict['inputB1'] = inputB
            node_dict['inputB1_name'] = input_next.input[1]
            node_dict['matmulB1_Shape'] = shapeB

            input_nnext, ok = get_next_node_by_output(model, input_next.output[0])
            if ok == 0 and input_nnext.op_type == 'Add':
                addA_name = input_nnext.input[0]
                addA, shapeA = values.get_init_value_and_shape(model, input_nnext.input[0])
                node_dict['addFirst'] = True
                if len(shapeA) == 0:
                    addA_name = input_nnext.input[1]
                    addA, shapeA = values.get_init_value_and_shape(model, input_nnext.input[1])
                    node_dict['addFirst'] = False

                if len(shapeA) == 1:
                    node_dict['Add1'] = input_nnext
                    logger.debug('--- got Add1 node: {}'.format(input_nnext.name))

                    input_nnnext, ok = get_all_next_node_by_output(model, input_nnext.output[0])
                    if len(input_nnnext) == 2:
                        if (input_nnnext[0].op_type == 'Div' and input_nnnext[1].op_type == 'Mul') or \
                                (input_nnnext[0].op_type == 'Mul' and input_nnnext[1].op_type == 'Div'):
                            mul_node = input_nnnext[0]
                            div_node = input_nnnext[1]
                            if input_nnnext[1].op_type == 'Mul':
                                mul_node = input_nnnext[1]
                                div_node = input_nnnext[0]

                            node_dict['Div'] = div_node
                            node_dict['Mul'] = mul_node
                            logger.debug('--- got Div node: {}'.format(div_node.name))

                            input_nnnnext, ok = get_next_node_by_output(model, mul_node.output[0])
                            if ok == 0 and input_nnnnext.op_type == 'Mul':
                                mulB, shapeB = values.get_init_value_and_shape(model, input_nnnnext.input[1])
                                if len(mulB) > 0:
                                    logger.debug('--- got mul2 node: {}'.format(input_nnnnext.name))
                                    input_nnnnnext, ok = get_next_node_by_output(model, input_nnnnext.output[0])
                                    if ok == 0 and input_nnnnnext.op_type == 'MatMul':
                                        shapeA = values.get_tensor_shape_by_name(model, input_nnnnnext.input[0])
                                        inputB, shapeB = values.get_init_value_and_shape(model, input_nnnnnext.input[1])
                                        if isinstance(inputB, list) and inputB == []:
                                            logger.debug('inputB is not in initializer')
                                            inputB = values.get_constant_value(model, input_nnnnnext.input[1])

                                        if len(shapeA) == 3 and len(shapeB) == 2:
                                            logger.debug('--- got MatMul2 node: {}'.format(input_nnnnnext.name))
                                            node_dict['MatMul2'] = input_nnnnnext
                                            node_dict['matmulA2_Shape'] = shapeA
                                            node_dict['inputB2'] = inputB
                                            node_dict['inputB2_name'] = input_nnnnnext.input[1]
                                            node_dict['matmulB2_Shape'] = shapeB

                                            input_nnnnnnext, ok = get_next_node_by_output(model, input_nnnnnext.output[0])
                                            if ok == 0 and input_nnnnnnext.op_type == 'Add':
                                                logger.debug('--- got Add2 node: {}'.format(input_nnnnnnext.name))
                                                addA_name = input_nnnnnnext.input[0]
                                                addA, shapeA = values.get_init_value_and_shape(model, input_nnnnnnext.input[0])
                                                node_dict['addFirst2'] = True
                                                if len(shapeA) == 0:
                                                    addA_name = input_nnnnnnext.input[1]
                                                    addA, shapeA = values.get_init_value_and_shape(model, input_nnnnnnext.input[1])
                                                    node_dict['addFirst2'] = False

                                                if len(shapeA) == 1:
                                                    node_dict['Add2'] = input_nnnnnnext
                                                    next_node, ok = get_next_node_by_output(model, input_nnnnnnext.output[0])
                                                    if ok == 0 and next_node.op_type == 'Add':
                                                        logger.debug('--- got last Add node: {}'.format(next_node.name))
                                                        res = 0
                                                        node_dict['NextAdd'] = next_node
                                                        node_dict['NextAddInput1'] = False
                                                        if next_node.input[0] == input_nnnnnnext.output[0]:
                                                            node_dict['NextAddInput1'] = True

    return node_dict, res
def get_mul_add_block(model):
    logger.debug('into get_mul_add_block')
    node_list = []

    for node in model.graph.node:
        if node.op_type == 'Mul':
            #print('got mul:', node.name)
            is_init = False
            for init in model.graph.initializer:
                if init.name == node.input[0] or init.name == node.input[1]:
                    is_init = True
                    break

            if is_init == False:
                dataA = values.get_constant_value(model, node.input[0])
                if len(dataA) == 0:
                    dataA = values.get_constant_value(model, node.input[1])

                if dataA != []:
                    is_init = True

            if is_init == True:
                #print('----got mul:', node.name)
                next_node, ok = get_next_node_by_output(model, node.output[0])
                if ok == 0 and next_node.op_type == 'Add':
                    #print('----got add:', next_node.name)
                    is_init = False
                    for init in model.graph.initializer:
                        if init.name == next_node.input[1]:
                            is_init = True
                            break

                    if is_init == False:
                        dataA = values.get_constant_value(model, next_node.input[1])
                        if dataA != []:
                            is_init = True

                    if is_init == True:
                        #print('get_all_next_node_by_output---', next_node.output, node.name)
                        next_node_list, ok = get_all_next_node_by_output(model, next_node.output[0])
                        if ok == 0:
                            #print('next_node_list:', len(next_node_list))
                            if len(next_node_list) == 2:
                                #print('got next_node_list:', next_node_list[0].op_type, next_node_list[1].op_type)
                                if (next_node_list[0].op_type == 'Add' and next_node_list[1].op_type == 'MatMul') or \
                                        (next_node_list[0].op_type == 'MatMul' and next_node_list[1].op_type == 'Add'):
                                    logger.debug('got it~')
                                    matmul_node = next_node_list[0]
                                    if next_node_list[1].op_type == 'MatMul':
                                        matmul_node = next_node_list[1]

                                    node_dict, ret = get_matmul_block_one(model, matmul_node)
                                    if ret == 0:
                                        #print('got node dict:', node_dict)
                                        node_dict['currentAdd'] = next_node
                                        node_list.append(node_dict)
                            elif len(next_node_list) == 1:  # for telecom transform model
                                if next_node_list[0].op_type == 'MatMul':
                                    logger.debug('got Add~~')
                                    matmul_node = next_node_list[0]
                                    node_dict, ret = get_matmul_block_one(model, matmul_node)
                                    if ret == 0:
                                        #print('got node dict:', node_dict)
                                        node_dict['currentAdd'] = next_node
                                        node_list.append(node_dict)

    return node_list
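
# get_mul_add_block anchors on a Mul with a constant/initializer operand
# followed by an Add with a constant operand (likely the scale-and-shift tail
# of a decomposed LayerNorm), and then requires the feed-forward block matched
# by get_matmul_block_one to hang off it: MatMul -> Add -> Div/Mul (a
# GELU-style split) -> Mul -> MatMul -> Add -> residual Add.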
def handle_mul_add_block(model, pattern):
node_list = get_mul_add_block(model)
#if len(node_list) > 0:
for node_dict in node_list:
logger.debug('++++++++++++++++++++++')
logger.debug('Add1: {}'.format(node_dict['Add1'].name))
logger.debug('Add2: {}'.format(node_dict['Add2'].name))
logger.debug('++++++++++++++++++++++')
matmul1 = node_dict['MatMul1']
add1 = node_dict['Add1']
matmul2 = node_dict['MatMul2']
add2 = node_dict['Add2']
currentAdd = node_dict['currentAdd']
nextAdd = node_dict['NextAdd']
nextAddInput1 = node_dict['NextAddInput1']
div_node = node_dict['Div']
###add transpose
ts_name = currentAdd.name + '_transpose_'
ts_output_name = ts_name + '_output_'
add_output_shape = values.get_tensor_shape_by_name(model, currentAdd.output[0])
ts_output_shape = [add_output_shape[0], add_output_shape[2], add_output_shape[1]]
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, ts_output_shape)
ts_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[currentAdd.output[0]],
outputs=[ts_output_name],
perm=[0,2,1])
model.graph.value_info.append(transpose_output)
###add reshape-1
rs_name = currentAdd.name + '_reshape_1_'
rs_output_name = rs_name + '_output_'
rs_output_shape = [ts_output_shape[0], ts_output_shape[1], 1, ts_output_shape[2]]
rs_output = onnx.helper.make_tensor_value_info(rs_output_name, onnx.TensorProto.FLOAT, rs_output_shape)
const_shape_name = currentAdd.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
rs_node = onnx.helper.make_node(
'Reshape',
name=rs_name,
inputs=[ts_output_name, const_shape_name],
outputs=[rs_output_name])
model.graph.value_info.append(rs_output)
#########################
insert_node(model, rs_node, matmul1)
matmul1.input[0] = rs_output_name
insert_node(model, ts_node, rs_node)
if pattern != 5:
if nextAddInput1 == True:
nextAdd.input[1] = ts_output_name
else:
nextAdd.input[0] = ts_output_name
#MatMul1--->Conv
matmul1.op_type = 'Conv'
logger.debug('-----reuse MatMul to Conv')
const_x_name = matmul1.name + '_to_conv_x_'
v = node_dict['inputB1']
old_dims = [node_dict['matmulB1_Shape'][0], node_dict['matmulB1_Shape'][1]]
dims_ = [node_dict['matmulB1_Shape'][1], node_dict['matmulB1_Shape'][0],1,1]
operation.remove_initializer_if_necessary_by_name(model, node_dict['inputB1_name'], matmul1)
if isinstance(v, np.ndarray) == True:
A = v.reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('+++A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('---A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=dims_,
vals=A)
model.graph.initializer.append(const_x_tensor)
matmul1.input[1] = const_x_name
attr = onnx.helper.make_attribute('dilations', [1, 1])
matmul1.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
matmul1.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1,1])
matmul1.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
matmul1.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
matmul1.attribute.append(attr)
if node_dict['addFirst'] == True:
matmul1.input.append(add1.input[0])
else:
matmul1.input.append(add1.input[1])
output_shape = values.get_tensor_shape_by_name(model, matmul1.output[0])
conv_output_shape = [output_shape[0], output_shape[2], 1, output_shape[1]]
if pattern == 5:
conv_output_shape = [output_shape[1], output_shape[2], 1, output_shape[0]]
update_tensor_shape(model, matmul1.output[0], conv_output_shape)
#Add1--->Reshape
add1.op_type = 'Reshape'
del add1.attribute[:]
rs_name = add1.name + '_reshape_1_'
rs_output_name = rs_name + '_output_'
rs_output_shape = [conv_output_shape[0], conv_output_shape[1], conv_output_shape[3]]
logger.debug('-----rs_output_shape: {}'.format(rs_output_shape))
rs_output = onnx.helper.make_tensor_value_info(rs_output_name, onnx.TensorProto.FLOAT, rs_output_shape)
const_shape_name = add1.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
if node_dict['addFirst'] == True:
add1.input[0] = add1.input[1]
add1.input[1] = const_shape_name
update_tensor_shape(model, add1.output[0], rs_output_shape)
#################################
#################################
###add reshape-1
rs2_name = matmul2.name + '_reshape_1_'
rs2_output_name = rs2_name + '_output_'
rs2_output_shape = [rs_output_shape[0], rs_output_shape[1], 1, rs_output_shape[2]]
rs_output = onnx.helper.make_tensor_value_info(rs2_output_name, onnx.TensorProto.FLOAT, rs2_output_shape)
const_shape_name = matmul2.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs2_output_shape)],
vals=rs2_output_shape)
model.graph.initializer.append(const_shape_tensor)
rs2_node = onnx.helper.make_node(
'Reshape',
name=rs2_name,
inputs=[matmul2.input[0], const_shape_name],
outputs=[rs2_output_name])
model.graph.value_info.append(rs_output)
insert_node(model, rs2_node, matmul2)
matmul2.input[0] = rs2_output_name
#MatMul2--->Conv
matmul2.op_type = 'Conv'
logger.debug('++++++reuse MatMul to Conv')
const_x_name = matmul2.name + '_to_conv_x_'
v = node_dict['inputB2']
old_dims = [node_dict['matmulB2_Shape'][0], node_dict['matmulB2_Shape'][1]]
dims_ = [node_dict['matmulB2_Shape'][1], node_dict['matmulB2_Shape'][0],1,1]
operation.remove_initializer_if_necessary_by_name(model, node_dict['inputB2_name'], matmul2)
if isinstance(v, np.ndarray) == True:
A = v.reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('+++A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('---A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=dims_,
vals=A)
model.graph.initializer.append(const_x_tensor)
matmul2.input[1] = const_x_name
attr = onnx.helper.make_attribute('dilations', [1, 1])
matmul2.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
matmul2.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1,1])
matmul2.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
matmul2.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
matmul2.attribute.append(attr)
if node_dict['addFirst2'] == True:
B = add2.input[0]
else:
B = add2.input[1]
matmul2.input.append(B)
output_shape = values.get_tensor_shape_by_name(model, matmul2.output[0])
conv_output_shape = [output_shape[0], output_shape[2], 1, output_shape[1]]
if pattern == 5:
conv_output_shape = [output_shape[1], output_shape[2], 1, output_shape[0]]
update_tensor_shape(model, matmul2.output[0], conv_output_shape)
#Add2--->Reshape
add2.op_type = 'Reshape'
del add2.attribute[:]
rs2_name = add2.name + '_reshape_1_'
rs2_output_name = rs2_name + '_output_'
rs2_output_shape = [conv_output_shape[0], conv_output_shape[1], conv_output_shape[3]]
rs_output = onnx.helper.make_tensor_value_info(rs2_output_name, onnx.TensorProto.FLOAT, rs2_output_shape)
const_shape_name = add2.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs2_output_shape)],
vals=rs2_output_shape)
model.graph.initializer.append(const_shape_tensor)
if node_dict['addFirst2'] == True:
add2.input[0] = add2.input[1]
add2.input[1] = const_shape_name
update_tensor_shape(model, add2.output[0], rs2_output_shape)
######update tensor shape
div_output_shape = values.get_tensor_shape_by_name(model, div_node.output[0])
new_shape = [div_output_shape[0], div_output_shape[2], div_output_shape[1]]
if pattern == 5:
new_shape = [div_output_shape[1], div_output_shape[2], div_output_shape[0]]
update_tensor_shape(model, div_node.output[0], new_shape)
erf_node, ok = get_next_node_by_output(model, div_node.output[0])
if ok == 0 and erf_node.op_type == 'Erf':
erf_output_shape = values.get_tensor_shape_by_name(model, erf_node.output[0])
new_shape = [erf_output_shape[0], erf_output_shape[2], erf_output_shape[1]]
if pattern == 5:
new_shape = [erf_output_shape[1], erf_output_shape[2], erf_output_shape[0]]
update_tensor_shape(model, erf_node.output[0], new_shape)
add_node_internal, ok = get_next_node_by_output(model, erf_node.output[0])
if ok == 0 and add_node_internal.op_type == 'Add':
addi_output_shape = values.get_tensor_shape_by_name(model, add_node_internal.output[0])
new_shape = [addi_output_shape[0], addi_output_shape[2], addi_output_shape[1]]
if pattern == 5:
new_shape = [addi_output_shape[1], addi_output_shape[2], addi_output_shape[0]]
update_tensor_shape(model, add_node_internal.output[0], new_shape)
mul_node1, ok = get_next_node_by_output(model, add_node_internal.output[0])
if ok == 0 and mul_node1.op_type == 'Mul':
mul1_output_shape = values.get_tensor_shape_by_name(model, mul_node1.output[0])
new_shape = [mul1_output_shape[0], mul1_output_shape[2], mul1_output_shape[1]]
if pattern == 5:
new_shape = [mul1_output_shape[1], mul1_output_shape[2], mul1_output_shape[0]]
update_tensor_shape(model, mul_node1.output[0], new_shape)
mul_node2, ok = get_next_node_by_output(model, mul_node1.output[0])
if ok == 0 and mul_node2.op_type == 'Mul':
mul2_output_shape = values.get_tensor_shape_by_name(model, mul_node2.output[0])
new_shape = [mul2_output_shape[0], mul2_output_shape[2], mul2_output_shape[1]]
if pattern == 5:
new_shape = [mul2_output_shape[1], mul2_output_shape[2], mul2_output_shape[0]]
update_tensor_shape(model, mul_node2.output[0], new_shape)
######insert Transpose before ReduceMean and Sub
if pattern == 5:
###add transpose
ts3_name = nextAdd.name + '_transpose_'
ts3_output_name = ts3_name + '_output_'
ts3_output_shape = [rs2_output_shape[0], rs2_output_shape[2], rs2_output_shape[1]]
ts3_output = onnx.helper.make_tensor_value_info(ts3_output_name, onnx.TensorProto.FLOAT, ts3_output_shape)
ts3_node = onnx.helper.make_node(
'Transpose',
name=ts3_name,
inputs=[add2.output[0]],
outputs=[ts3_output_name],
perm=[0,2,1])
model.graph.value_info.append(ts3_output)
insert_node(model, ts3_node, add2)
nextAdd.input[1] = ts3_output_name
else:
update_tensor_shape(model, nextAdd.output[0], rs2_output_shape)
rm_sub, ok = get_all_next_node_by_output(model, nextAdd.output[0])
if ok == 0 and len(rm_sub) == 2:
logger.debug('got reducemean and sub node---')
sub_node = rm_sub[0]
rm_node = rm_sub[1]
if rm_sub[0].op_type == 'ReduceMean':
sub_node = rm_sub[1]
rm_node = rm_sub[0]
###add transpose
ts3_name = nextAdd.name + '_transpose_'
ts3_output_name = ts3_name + '_output_'
add3_output_shape = values.get_tensor_shape_by_name(model, nextAdd.output[0])
ts3_output_shape = [add3_output_shape[0], add3_output_shape[2], add3_output_shape[1]]
ts3_output = onnx.helper.make_tensor_value_info(ts3_output_name, onnx.TensorProto.FLOAT, ts3_output_shape)
ts3_node = onnx.helper.make_node(
'Transpose',
name=ts3_name,
inputs=[nextAdd.output[0]],
outputs=[ts3_output_name],
perm=[0,2,1])
model.graph.value_info.append(ts3_output)
insert_node(model, ts3_node, sub_node)
sub_node.input[0] = ts3_output_name
rm_node.input[0] = ts3_output_name
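#Match the feed-forward tail starting at matmul_node:
#MatMul(3D x 2D) -> Add -> Relu -> MatMul(3D x 2D) -> Add -> residual Add.
#Returns (node_dict, 0) on a full match, (partial dict, -1) otherwise.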
def get_matmul_block_two(model, matmul_node):
logger.debug('into get_matmul_block_two')
res = -1
node_dict = {}
#input_next, ok = get_next_node_by_output(model, input_)
input_next = matmul_node
if input_next.op_type == 'MatMul':
shapeA = values.get_tensor_shape_by_name(model, input_next.input[0])
inputB, shapeB = values.get_init_value_and_shape(model, input_next.input[1])
if isinstance(inputB, list) and inputB == []:
logger.debug('inputB is not in initializer')
inputB = values.get_constant_value(model, input_next.input[1])
if len(shapeA) == 3 and len(shapeB) == 2:
logger.debug('++++ got MatMul node: {}'.format(input_next.name))
#node_list = [input_next, input_pp_pre, input_p_pre, input_pre]
#node_dict['node_list'] = node_list
node_dict['MatMul1'] = input_next
node_dict['matmulA1_Shape'] = shapeA
node_dict['inputB1'] = inputB
node_dict['matmulB1_Shape'] = shapeB
input_nnext, ok = get_next_node_by_output(model, input_next.output[0])
if ok == 0 and input_nnext.op_type == 'Add':
addA_name = input_nnext.input[0]
addA, shapeA = values.get_init_value_and_shape(model, input_nnext.input[0])
if len(shapeA) == 1:
node_dict['Add1'] = input_nnext
logger.debug('++++ got Add1 node: {}'.format(input_nnext.name))
input_nnnext, ok = get_next_node_by_output(model, input_nnext.output[0])
if ok == 0 and input_nnnext.op_type == 'Relu':
node_dict['Relu'] = input_nnnext
logger.debug('++++ got Relu node: {}'.format(input_nnnext.name))
input_nnnnext, ok = get_next_node_by_output(model, input_nnnext.output[0])
if ok == 0 and input_nnnnext.op_type == 'MatMul':
shapeA = values.get_tensor_shape_by_name(model, input_nnnnext.input[0])
inputB, shapeB = values.get_init_value_and_shape(model, input_nnnnext.input[1])
if isinstance(inputB, list) and inputB == []:
logger.debug('inputB is not in initializer')
inputB = values.get_constant_value(model, input_nnnnext.input[1])
if len(shapeA) == 3 and len(shapeB) == 2:
logger.debug('++++ got MatMul2 node: {}'.format(input_nnnnext.name))
#node_list = [input_nnnnext, input_pp_pre, input_p_pre, input_pre]
#node_dict['node_list'] = node_list
node_dict['MatMul2'] = input_nnnnext
node_dict['matmulA2_Shape'] = shapeA
node_dict['inputB2'] = inputB
node_dict['matmulB2_Shape'] = shapeB
input_nnnnnext, ok = get_next_node_by_output(model, input_nnnnext.output[0])
if ok == 0 and input_nnnnnext.op_type == 'Add':
logger.debug('++++ got Add2 node: {}'.format(input_nnnnnext.name))
#addA_name = input_nnnnnext.input[0]
#if len(shapeA) == 1:
node_dict['Add2'] = input_nnnnnext
next_node, ok = get_next_node_by_output(model, input_nnnnnext.output[0])
if ok == 0 and next_node.op_type == 'Add':
logger.debug('++++ got last Add node: {}'.format(next_node.name))
res = 0
node_dict['NextAdd'] = next_node
return node_dict, res
#Mul->Add->MatMul->Add->Relu->MatMul->Add
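#Scan for Mul->Add pairs whose second input is a constant (initializer or
#Constant node) and whose output feeds the MatMul block matched above.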
def get_mul_add_block_two(model):
node_list = []
for node in model.graph.node:
if node.op_type == 'Mul':
#print('got mul:', node.name)
is_init = False
for init in model.graph.initializer:
if init.name == node.input[1]:
is_init = True
break
if is_init == False:
dataA = values.get_constant_value(model, node.input[1])
if dataA != []:
is_init = True
if is_init == True:
#print('----got mul:', node.name)
next_node, ok = get_next_node_by_output(model, node.output[0])
if ok == 0 and next_node.op_type == 'Add':
##############
#print('----got add:', next_node.name)
is_init = False
for init in model.graph.initializer:
if init.name == next_node.input[1]:
is_init = True
break
if is_init == False:
dataA = values.get_constant_value(model, next_node.input[1])
if dataA != []:
is_init = True
if is_init == True:
#print('get_all_next_node_by_output---', next_node.output, node.name)
matmul_node, ok = get_next_node_by_output(model, next_node.output[0])
if ok == 0 and matmul_node.op_type == 'MatMul':
logger.debug('got match MatMul~')
node_dict, ret = get_matmul_block_two(model, matmul_node)
if ret == 0:
#print('got node dict:', node_dict)
node_dict['currentAdd'] = next_node
node_list.append(node_dict)
return node_list
#Mul->Add->MatMul->Add->Relu->MatMul->Add
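#Rewrite each matched block for channel-first execution: wrap the block input
#with Transpose+Reshape, turn both MatMul nodes into 1x1 Conv (bias folded in
#from the following Add), reuse each Add as the Reshape back to 3-D, and patch
#the Where/Transpose consumers after the residual Add.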
def handle_mul_add_block_two(model):
node_list = get_mul_add_block_two(model)
#if len(node_list) > 0:
for node_dict in node_list:
logger.debug('##############################')
logger.debug('Add1: {}'.format(node_dict['Add1'].name))
logger.debug('Add2: {}'.format(node_dict['Add2'].name))
logger.debug('###############################')
matmul1 = node_dict['MatMul1']
add1 = node_dict['Add1']
matmul2 = node_dict['MatMul2']
add2 = node_dict['Add2']
currentAdd = node_dict['currentAdd']
nextAdd = node_dict['NextAdd']
relu_node = node_dict['Relu']
###add transpose
ts_name = currentAdd.name + '_transpose_'
ts_output_name = ts_name + '_output_'
add_output_shape = values.get_tensor_shape_by_name(model, currentAdd.output[0])
ts_output_shape = [add_output_shape[0], add_output_shape[2], add_output_shape[1]]
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, ts_output_shape)
ts_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[currentAdd.output[0]],
outputs=[ts_output_name],
perm=[0,2,1])
model.graph.value_info.append(transpose_output)
###add reshape-1
rs_name = currentAdd.name + '_reshape_1_'
rs_output_name = rs_name + '_output_'
rs_output_shape = [ts_output_shape[0], ts_output_shape[1], 1, ts_output_shape[2]]
rs_output = onnx.helper.make_tensor_value_info(rs_output_name, onnx.TensorProto.FLOAT, rs_output_shape)
const_shape_name = currentAdd.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
rs_node = onnx.helper.make_node(
'Reshape',
name=rs_name,
inputs=[ts_output_name, const_shape_name],
outputs=[rs_output_name])
model.graph.value_info.append(rs_output)
#########################
insert_node(model, rs_node, matmul1)
matmul1.input[0] = rs_output_name
insert_node(model, ts_node, rs_node)
#nextAdd.input[0] = ts_output_name
#MatMul1--->Conv
matmul1.op_type = 'Conv'
logger.debug('+++++ reuse MatMul to Conv')
const_x_name = matmul1.name + '_to_conv_x_'
v = node_dict['inputB1']
old_dims = [node_dict['matmulB1_Shape'][0], node_dict['matmulB1_Shape'][1]]
dims_ = [node_dict['matmulB1_Shape'][1], node_dict['matmulB1_Shape'][0],1,1]
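#Conv weight layout is [out_channels, in_channels, kH, kW]: the [K, N] MatMul
#weight is transposed to [N, K] and viewed as [N, K, 1, 1] below.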
if isinstance(v, np.ndarray) == True:
A = v.reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('+++A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('---A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=dims_,
vals=A)
model.graph.initializer.append(const_x_tensor)
matmul1.input[1] = const_x_name
attr = onnx.helper.make_attribute('dilations', [1, 1])
matmul1.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
matmul1.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1,1])
matmul1.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
matmul1.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
matmul1.attribute.append(attr)
matmul1.input.append(add1.input[0])
output_shape = values.get_tensor_shape_by_name(model, matmul1.output[0])
conv_output_shape = [output_shape[0], output_shape[2], 1, output_shape[1]]
update_tensor_shape(model, matmul1.output[0], conv_output_shape)
#Add1--->Reshape
add1.op_type = 'Reshape'
del add1.attribute[:]
rs_name = add1.name + '_reshape_1_'
rs_output_name = rs_name + '_output_'
rs_output_shape = [conv_output_shape[0], conv_output_shape[1], conv_output_shape[3]]
logger.debug('-----rs_output_shape: {}'.format(rs_output_shape))
rs_output = onnx.helper.make_tensor_value_info(rs_output_name, onnx.TensorProto.FLOAT, rs_output_shape)
const_shape_name = add1.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
add1.input[0] = add1.input[1]
add1.input[1] = const_shape_name
update_tensor_shape(model, add1.output[0], rs_output_shape)
update_tensor_shape(model, relu_node.output[0], rs_output_shape)
#################################
#################################
###add reshape-1
rs2_name = matmul2.name + '_reshape_1_'
rs2_output_name = rs2_name + '_output_'
rs2_output_shape = [rs_output_shape[0], rs_output_shape[1], 1, rs_output_shape[2]]
rs_output = onnx.helper.make_tensor_value_info(rs2_output_name, onnx.TensorProto.FLOAT, rs2_output_shape)
const_shape_name = matmul2.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs2_output_shape)],
vals=rs2_output_shape)
model.graph.initializer.append(const_shape_tensor)
rs2_node = onnx.helper.make_node(
'Reshape',
name=rs2_name,
inputs=[matmul2.input[0], const_shape_name],
outputs=[rs2_output_name])
model.graph.value_info.append(rs_output)
insert_node(model, rs2_node, matmul2)
matmul2.input[0] = rs2_output_name
#MatMul2--->Conv
matmul2.op_type = 'Conv'
logger.debug('++++++reuse MatMul2 to Conv')
const_x_name = matmul2.name + '_to_conv_x_'
v = node_dict['inputB2']
old_dims = [node_dict['matmulB2_Shape'][0], node_dict['matmulB2_Shape'][1]]
dims_ = [node_dict['matmulB2_Shape'][1], node_dict['matmulB2_Shape'][0],1,1]
if isinstance(v, np.ndarray) == True:
A = v.reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('+++A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('---A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=dims_,
vals=A)
model.graph.initializer.append(const_x_tensor)
matmul2.input[1] = const_x_name
attr = onnx.helper.make_attribute('dilations', [1, 1])
matmul2.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
matmul2.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1,1])
matmul2.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
matmul2.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
matmul2.attribute.append(attr)
B = add2.input[0]
matmul2.input.append(B)
output_shape = values.get_tensor_shape_by_name(model, matmul2.output[0])
conv_output_shape = [output_shape[0], output_shape[2], 1, output_shape[1]]
update_tensor_shape(model, matmul2.output[0], conv_output_shape)
#Add2--->Reshape
add2.op_type = 'Reshape'
del add2.attribute[:]
rs2_name = add2.name + '_reshape_1_'
rs2_output_name = rs2_name + '_output_'
rs2_output_shape = [conv_output_shape[0], conv_output_shape[1], conv_output_shape[3]]
rs_output = onnx.helper.make_tensor_value_info(rs2_output_name, onnx.TensorProto.FLOAT, rs2_output_shape)
const_shape_name = add2.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs2_output_shape)],
vals=rs2_output_shape)
model.graph.initializer.append(const_shape_tensor)
add2.input[0] = add2.input[1]
add2.input[1] = const_shape_name
update_tensor_shape(model, add2.output[0], rs2_output_shape)
update_tensor_shape(model, nextAdd.output[0], rs2_output_shape)
nextAdd_next_list, ok = get_all_next_node_by_output(model, nextAdd.output[0])
if ok == 0:
where_node = None
tp_node = None
add_node = None
for node in nextAdd_next_list:
logger.debug('----nextAdd_next_list, node: {}'.format(node.name))
if node.op_type == 'Where':
where_node = node
if node.op_type == 'Transpose':
tp_node = node
if node.op_type == 'Add':
add_node = node
if where_node != None and tp_node != None and add_node != None:
where_node.input[1] = tp_node.output[0]
elif where_node != None and tp_node == None:
###add transpose
tp_name = nextAdd.name + '_transpose_'
tp_output_name = tp_name + '_output_'
add_output_shape = values.get_tensor_shape_by_name(model, nextAdd.output[0])
tp_output_shape = [add_output_shape[0], add_output_shape[2], add_output_shape[1]]
tp_output = onnx.helper.make_tensor_value_info(tp_output_name, onnx.TensorProto.FLOAT, tp_output_shape)
tp_node = onnx.helper.make_node(
'Transpose',
name=tp_name,
inputs=[nextAdd.output[0]],
outputs=[tp_output_name],
perm=[0,2,1])
model.graph.value_info.append(tp_output)
insert_node(model, tp_node, where_node)
where_node.input[1] = tp_output_name
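#Sanity-check sketch (illustration only, never called by the pass): shows the
#MatMul <-> 1x1 Conv weight equivalence this file relies on. The helper name
#and the plain-numpy check are assumptions of this sketch, not part of the
#optimization itself; it reuses the module's existing `np` import.
def _check_matmul_as_pointwise_conv(x, W):
    # x: [B, T, K], W: [K, N] -> reference MatMul output [B, T, N]
    ref = np.matmul(x, W)
    # Conv view: input [B, K, 1, T], weight [N, K, 1, 1] -> output [B, N, 1, T]
    w4 = W.transpose().reshape(W.shape[1], W.shape[0], 1, 1)
    xc = x.transpose(0, 2, 1)[:, :, None, :]
    out = np.einsum('bkht,nkhw->bnht', xc, w4)
    # transpose back to [B, T, N] and compare against the MatMul result
    return np.allclose(ref, out.squeeze(2).transpose(0, 2, 1))
#Find the graph-tail MatMul + bias Add feeding a LogSoftmax/Softmax graph
#output (plus the Add producing the MatMul input); returns the nodes in
#node_dict.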
def get_last_group(model):
graph_output = []
node_dict = {}
res = -1
for o in model.graph.output:
graph_output.append(o.name)
for node in model.graph.node:
if node.output[0] in graph_output:
#print('got mul:', node.name)
if node.op_type == 'LogSoftmax' or node.op_type == 'Softmax':
logger.debug('got LogSoftmax node: {}'.format(node.name))
node_dict['LogSoftmax'] = node
add_node, ok = get_prev_node_by_input(model, node.input[0])
if ok == 0 and add_node.op_type == 'Add':
addA_name = add_node.input[0]
addA, shapeA = values.get_init_value_and_shape(model, add_node.input[0])
if len(shapeA) == 1:
node_dict['Add'] = add_node
logger.debug('!!!!! got Add node: {}'.format(add_node.name))
matmul_node, ok = get_prev_node_by_input(model, add_node.input[1])
if ok == 0 and matmul_node.op_type == 'MatMul':
logger.debug('got MatMul node: {}'.format(matmul_node.name))
shapeA = values.get_tensor_shape_by_name(model, matmul_node.input[0])
inputB, shapeB = values.get_init_value_and_shape(model, matmul_node.input[1])
if isinstance(inputB, list) and inputB == []:
logger.debug('inputB is not in initializer')
inputB = values.get_constant_value(model, matmul_node.input[1])
if len(shapeA) == 3 and len(shapeB) == 2:
logger.debug('++++ got MatMul2 node: {}'.format(matmul_node.name))
node_dict['MatMul'] = matmul_node
node_dict['matmulA_Shape'] = shapeA
node_dict['inputB'] = inputB
node_dict['matmulB_Shape'] = shapeB
add_node2, ok = get_prev_node_by_input(model, matmul_node.input[0])
if ok == 0 and add_node2.op_type == 'Add':
logger.debug('++++ got Add2 node: {}'.format(add_node2.name))
node_dict['Add2'] = add_node2
res = 0
break
return node_dict, res
#Mul->Add->MatMul->Add->LogSoftmax
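#Rebuild the tail as Transpose->Reshape->Conv(1x1)->Reshape->Transpose so the
#final LogSoftmax still receives the original [B, T, N] layout.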
def handle_last_group(model):
node_dict, ok = get_last_group(model)
if ok == 0:
logger.debug('start handle_last_group')
matmul_node = node_dict['MatMul']
add_node = node_dict['Add']
add2_node = node_dict['Add2']
ls_node = node_dict['LogSoftmax']
###add transpose
ts_name = add2_node.name + '_transpose_'
ts_output_name = ts_name + '_output_'
add_output_shape = values.get_tensor_shape_by_name(model, add2_node.output[0])
ts_output_shape = [add_output_shape[0], add_output_shape[2], add_output_shape[1]]
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, ts_output_shape)
ts_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[add2_node.output[0]],
outputs=[ts_output_name],
perm=[0,2,1])
model.graph.value_info.append(transpose_output)
###add reshape
rs_name = add2_node.name + '_reshape_2_'
rs_output_name = rs_name + '_output_'
rs_output_shape = [ts_output_shape[0], ts_output_shape[1], 1, ts_output_shape[2]]
rs_output = onnx.helper.make_tensor_value_info(rs_output_name, onnx.TensorProto.FLOAT, rs_output_shape)
const_shape2_name = add2_node.name + '_reshape2_data_'
const_shape2_tensor = onnx.helper.make_tensor(name=const_shape2_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape2_tensor)
rs_node = onnx.helper.make_node(
'Reshape',
name=rs_name,
inputs=[ts_output_name, const_shape2_name],
outputs=[rs_output_name])
model.graph.value_info.append(rs_output)
insert_node(model, rs_node, matmul_node)
matmul_node.input[0] = rs_output_name
insert_node(model, ts_node, rs_node)
#MatMul-->Conv
matmul_node.op_type = 'Conv'
const_x_name = matmul_node.name + '_to_conv_x_'
v = node_dict['inputB']
old_dims = [node_dict['matmulB_Shape'][0], node_dict['matmulB_Shape'][1]]
dims_ = [node_dict['matmulB_Shape'][1], node_dict['matmulB_Shape'][0],1,1]
if isinstance(v, np.ndarray) == True:
A = v.reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('+++A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('---A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=dims_,
vals=A)
model.graph.initializer.append(const_x_tensor)
matmul_node.input[1] = const_x_name
attr = onnx.helper.make_attribute('dilations', [1, 1])
matmul_node.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
matmul_node.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1,1])
matmul_node.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
matmul_node.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
matmul_node.attribute.append(attr)
matmul_node.input.append(add_node.input[0])
mm_output_shape = values.get_tensor_shape_by_name(model, matmul_node.output[0])
conv_output_shape = [mm_output_shape[0], mm_output_shape[2], 1, mm_output_shape[1]]
update_tensor_shape(model, matmul_node.output[0], conv_output_shape)
###########
add_node.op_type = 'Reshape'
reshape_output = add_node.output[0]
const_shape_name = add_node.name + '_to_reshape_'
add_output_shape = values.get_tensor_shape_by_name(model, add_node.output[0])
rs2_output_shape = [add_output_shape[0], add_output_shape[2], add_output_shape[1]]
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs2_output_shape)],
vals=rs2_output_shape)
model.graph.initializer.append(const_shape_tensor)
add_node.input[0] = add_node.input[1]
add_node.input[1] = const_shape_name
update_tensor_shape(model, add_node.output[0], rs2_output_shape)
###add transpose
ts2_name = add_node.name + '_transpose_'
ts2_output_name = ts2_name + '_output_'
ts2_output_shape = add_output_shape
transpose_output = onnx.helper.make_tensor_value_info(ts2_output_name, onnx.TensorProto.FLOAT, ts2_output_shape)
ts2_node = onnx.helper.make_node(
'Transpose',
name=ts2_name,
inputs=[add_node.output[0]],
outputs=[ts2_output_name],
perm=[0,2,1])
model.graph.value_info.append(transpose_output)
insert_node(model, ts2_node, ls_node)
ls_node.input[0] = ts2_output_name
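#For the first Add of the pattern: insert Transpose->Reshape->Reshape->
#Transpose after it, feed ReduceMean/Sub from the final Transpose and the next
#Add from the second Reshape; every later Add in the list only needs the
#single Transpose before ReduceMean/Sub.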
def handle_add_combination_pattern_one(model):
ars_list = get_add_combination_pattern_one(model)
#print('handle_add_combination_pattern_one,ars_list:', ars_list)
if len(ars_list):
ars = ars_list[0]
add_node = ars['currentAdd']
next_add_node = ars['nextAdd']
sub_node = ars['Sub']
rm_node = ars['ReduceMean']
###add transpose
ts_name = add_node.name + '_transpose_'
ts_output_name = ts_name + '_output_'
add_output_shape = values.get_tensor_shape_by_name(model, add_node.output[0])
ts_output_shape = [add_output_shape[0], add_output_shape[2], add_output_shape[1]]
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, ts_output_shape)
ts_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[add_node.output[0]],
outputs=[ts_output_name],
perm=[0,2,1])
model.graph.value_info.append(transpose_output)
###add reshape-1
rs_name = add_node.name + '_reshape_1_'
rs_output_name = rs_name + '_output_'
rs_output_shape = [ts_output_shape[0], ts_output_shape[1]] #TBD
rs_output = onnx.helper.make_tensor_value_info(rs_output_name, onnx.TensorProto.FLOAT, rs_output_shape)
const_shape_name = add_node.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
rs_node = onnx.helper.make_node(
'Reshape',
name=rs_name,
inputs=[ts_output_name, const_shape_name],
outputs=[rs_output_name])
model.graph.value_info.append(rs_output)
###add reshape-2
rs2_name = add_node.name + '_reshape_2_'
rs2_output_name = rs2_name + '_output_'
rs2_output_shape = [add_output_shape[0], add_output_shape[2], add_output_shape[1]]
rs2_output = onnx.helper.make_tensor_value_info(rs2_output_name, onnx.TensorProto.FLOAT, rs2_output_shape)
const_shape2_name = add_node.name + '_reshape2_data_'
const_shape2_tensor = onnx.helper.make_tensor(name=const_shape2_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs2_output_shape)],
vals=rs2_output_shape)
model.graph.initializer.append(const_shape2_tensor)
rs2_node = onnx.helper.make_node(
'Reshape',
name=rs2_name,
inputs=[rs_output_name, const_shape2_name],
outputs=[rs2_output_name])
model.graph.value_info.append(rs2_output)
###add transpose2
ts2_name = add_node.name + '_transpose2_'
ts2_output_name = ts2_name + '_output_'
ts2_output = onnx.helper.make_tensor_value_info(ts2_output_name, onnx.TensorProto.FLOAT, [add_output_shape[0], add_output_shape[1], add_output_shape[2]])
model.graph.value_info.append(ts2_output)
ts2_node = onnx.helper.make_node(
'Transpose',
name=ts2_name,
inputs=[rs2_output_name],
outputs=[ts2_output_name],
perm=[0,2,1])
insert_node(model, ts2_node, rm_node)
rm_node.input[0] = ts2_output_name
sub_node.input[0] = ts2_output_name
insert_node(model, rs2_node, ts2_node)
insert_node(model, rs_node, rs2_node)
insert_node(model, ts_node, rs_node)
next_add_node.input[0] = rs2_output_name
ars_list2 = ars_list[1:]
for ars in ars_list2:
add_node = ars['currentAdd']
next_add_node = ars['nextAdd']
sub_node = ars['Sub']
rm_node = ars['ReduceMean']
###add transpose
add_output_shape = values.get_tensor_shape_by_name(model, add_node.output[0])
ts2_name = add_node.name + '_transpose2_'
ts2_output_name = ts2_name + '_output_'
ts2_output = onnx.helper.make_tensor_value_info(ts2_output_name, onnx.TensorProto.FLOAT, add_output_shape)
model.graph.value_info.append(ts2_output)
ts2_node = onnx.helper.make_node(
'Transpose',
name=ts2_name,
inputs=[add_node.output[0]],
outputs=[ts2_output_name],
perm=[0,2,1])
insert_node(model, ts2_node, rm_node)
rm_node.input[0] = ts2_output_name
sub_node.input[0] = ts2_output_name
def print_matmul_input_path(node_list, desp):
node_print = ''
first = True
for n in node_list:
if first == True:
node_print = n.name
first = False
else:
node_print = node_print + '-->'
node_print = node_print + n.name
logger.debug('{}:{}'.format(desp, node_print))
def update_tensor_shape(model, tensor_name, target_shape_list):
for vi in model.graph.value_info:
if vi.name == tensor_name:
# Overwrite the static dims recorded in value_info; graph inputs/outputs are
# left untouched. Reusing one `dim` message is safe because protobuf's
# repeated-field append() stores a copy.
dim = vi.type.tensor_type.shape.dim[0]
del vi.type.tensor_type.shape.dim[:]
for ss in target_shape_list:
dim.dim_value = ss
vi.type.tensor_type.shape.dim.append(dim)
break
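#Convert one MatMul+Add branch of attention pattern one into
#Transpose->Reshape->Conv(1x1): the MatMul is reused as a Transpose (or
#removed when the channels already match), the Add becomes the 4-D Reshape,
#the original Reshape becomes the Conv, and transpose_node_map/
#reshape_node_map let parallel branches off the same input share one
#Transpose/Reshape.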
def do_convert_pattern_one(model, matmul_dict, isInputA):
orig_reshape_name = ''
orig_matmul_name = ''
reshape_output = ''
bert_mode = 1
current_inputA_shape = values.get_tensor_shape_by_name(model, matmul_dict['current'].input[0])
current_inputB_shape = values.get_tensor_shape_by_name(model, matmul_dict['current'].input[1])
logger.debug('current: {}, inputA_shape: {}, inputB_shape: {}'.format(matmul_dict['current'].name, current_inputA_shape, current_inputB_shape))
map_key = ''
if isInputA == True:
inputA_shape = matmul_dict['A_matmul_AShape']
inputB_shape = matmul_dict['A_matmul_BShape']
path_node = matmul_dict['pathA']
if 'A_prev' in matmul_dict.keys():
map_key = matmul_dict['A_prev']
else:
inputA_shape = matmul_dict['B_matmul_AShape']
inputB_shape = matmul_dict['B_matmul_BShape']
path_node = matmul_dict['pathB']
if 'B_prev' in matmul_dict.keys():
map_key = matmul_dict['B_prev']
logger.debug('A inputA shape:{}, inputB shape:{}'.format(inputA_shape, inputB_shape))
logger.debug('B inputA shape:{}, inputB shape:{}'.format(matmul_dict['B_matmul_AShape'], matmul_dict['B_matmul_BShape']))
remove_matmul = False
remove_add = False
matmul_input0 = ''
matmul_output0 = ''
matmul_input0_shape = []
add_input1 = ''
add_input1_shape = []
reuse_transpose = False
reuse_reshape = False
if len(current_inputA_shape) == 3 and len(current_inputB_shape) == 3:
for node in path_node:
if node.op_type == 'MatMul':
matmul_input0 = node.input[0]
matmul_output0 = node.output[0]
matmul_input0_shape = values.get_tensor_shape_by_name(model, matmul_input0)
if inputA_shape[1] != inputB_shape[0]:
if map_key in transpose_node_map.keys():
#transpose_node_map[map_key]
logger.debug('------ found transpose_node_map, key: {}'.format(map_key))
#model.graph.node.remove(node)
operation.remove_onnx_node(model, node)
reuse_transpose = True
bert_mode = 0
else:
orig_matmul_name = node.name
logger.debug('---matmul+add-->conv, need same channel: {} {}, node.name: {}'.format(inputA_shape[1], inputB_shape[0], node.name))
node.op_type = 'Transpose'
attr = onnx.helper.make_attribute('perm', [0,2,1])
node.attribute.append(attr)
del node.input[1:]
update_tensor_shape(model, node.output[0], [inputA_shape[0], inputA_shape[2], inputA_shape[1]])
transpose_node_map[map_key] = node
logger.debug('---map_key is {}'.format(map_key))
else:
logger.debug('----Delete MatMul node: {}'.format(node.name))
#matmul_input0 = node.input[0]
#matmul_input0_shape = values.get_tensor_shape_by_name(model, matmul_input0)
#model.graph.node.remove(node)
operation.remove_onnx_node(model, node)
remove_matmul = True
bert_mode = 0
if node.op_type == 'Add':
rs_output_shape = [inputA_shape[0], inputA_shape[2], 1, inputA_shape[1]]
if remove_matmul == True:
rs_output_shape = [matmul_input0_shape[1], matmul_input0_shape[2], 1, matmul_input0_shape[0]]
if map_key in reshape_node_map.keys():
logger.debug('------ found reshape_node_map, key: {}'.format(map_key))
model.graph.node.remove(node)
reuse_reshape = True
else:
logger.debug('reuse Add to Reshape')
orig_reshape_name = node.name
node.op_type = 'Reshape'
const_shape_name = node.name + '_to_reshape_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
if reuse_transpose == False:
node.input[0] = matmul_output0
else:
node.input[0] = transpose_node_map[map_key].output[0]
node.input[1] = const_shape_name
update_tensor_shape(model, node.output[0], rs_output_shape)
reshape_node_map[map_key] = node
##################
if node.op_type == 'Reshape' and node.name != orig_reshape_name:
rs1_output_shape = values.get_tensor_shape_by_name(model, node.output[0])
logger.debug('-----reuse Reshape to Conv')
node.op_type = 'Conv'
const_x_name = node.name + '_to_conv_x_'
v = matmul_dict['A_inputB']
old_dims = [matmul_dict['A_matmul_BShape'][0], matmul_dict['A_matmul_BShape'][1]]
dims_ = [matmul_dict['A_matmul_BShape'][1], matmul_dict['A_matmul_BShape'][0],1,1]
if reuse_reshape == True:
node.input[0] = reshape_node_map[map_key].output[0]
operation.remove_initializer_if_necessary_by_name(model, node.input[1], node)
if isInputA == False:
v = matmul_dict['B_inputB']
old_dims = [matmul_dict['B_matmul_BShape'][0], matmul_dict['B_matmul_BShape'][1]]
dims_ = [matmul_dict['B_matmul_BShape'][1], matmul_dict['B_matmul_BShape'][0], 1, 1]
if isinstance(v, np.ndarray) == True:
A = v.reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('+++A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('---A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=dims_,#[matmul_dict['A_matmul_BShape'][1], matmul_dict['A_matmul_BShape'][0],1,1],
vals=A)
model.graph.initializer.append(const_x_tensor)
node.input[1] = const_x_name
del node.attribute[:]
attr = onnx.helper.make_attribute('dilations', [1, 1])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
node.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1, 1])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
node.attribute.append(attr)
if isInputA == True:
node.input.append(matmul_dict['A_addA'])
else:
node.input.append(matmul_dict['B_addA'])
output_shape = rs_output_shape #[inputA_shape[0], inputA_shape[2], 1, inputA_shape[1]]
update_tensor_shape(model, node.output[0], output_shape)
if node.op_type == 'Transpose' and node.name != orig_matmul_name:
logger.debug('-----reuse Transpose to Reshape, node.name: {}'.format(node.name))
node.op_type = 'Reshape'
del node.attribute[:]
reshape_output = node.output[0]
tp_output_shape = values.get_tensor_shape_by_name(model, reshape_output)
const_shape_name = node.name + '_to_reshape_'
#both the A and B paths currently use the same target shape
output_shape = [rs1_output_shape[1], rs1_output_shape[2], rs1_output_shape[0]]
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(output_shape)],
vals=output_shape)
model.graph.initializer.append(const_shape_tensor)
node.input.append(const_shape_name)
update_tensor_shape(model, node.output[0], output_shape)
follow_up_node = node
next_div_node, ok = get_next_node_by_output(model, reshape_output)
if ok == 0 and next_div_node.op_type == 'Div':
shape = values.get_tensor_shape_by_name(model, next_div_node.output[0])
update_tensor_shape(model, next_div_node.output[0], [shape[0], shape[2], shape[1]])
if isInputA == False:
ts_name = const_shape_name + '_transpose_'
ts_output_name = ts_name + '_output_'
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, [tp_output_shape[0], tp_output_shape[2],tp_output_shape[1]])
transpose_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[reshape_output],
outputs=[ts_output_name],
perm=[0,2,1])
insert_node(model, transpose_node, follow_up_node)
model.graph.value_info.append(transpose_output)
matmul_dict['current'].input[1] = matmul_dict['current'].input[0]
matmul_dict['current'].input[0] = ts_output_name
output_shape = values.get_tensor_shape_by_name(model, matmul_dict['current'].output[0])
update_tensor_shape(model, matmul_dict['current'].output[0], [output_shape[0], output_shape[2], output_shape[1]])
else:
for node in path_node:
if node.op_type == 'MatMul':
if inputA_shape[1] != inputB_shape[0]:
orig_matmul_name = node.name
logger.debug('matmul+add-->conv, need same channel: {} {}'.format(inputA_shape[1], inputB_shape[0]))
node.op_type = 'Transpose'
attr = onnx.helper.make_attribute('perm', [0,2,1])
node.attribute.append(attr)
del node.input[1:]
update_tensor_shape(model, node.output[0], [inputA_shape[0], inputA_shape[2], inputA_shape[1]])
transpose_node_map[map_key] = node
logger.debug('map_key is {}'.format(map_key))
else:
logger.debug('Delete MatMul node: {}'.format(node.name))
matmul_input0 = node.input[0]
matmul_input0_shape = values.get_tensor_shape_by_name(model, matmul_input0)
#model.graph.node.remove(node)
operation.remove_onnx_node(model, node)
remove_matmul = True
bert_mode = 0
if node.op_type == 'Add':
logger.debug('reuse Add to Reshape')
orig_reshape_name = node.name
node.op_type = 'Reshape'
const_shape_name = node.name + '_to_reshape_'
rs_output_shape = [inputA_shape[0], inputA_shape[2], 1, inputA_shape[1]]
if remove_matmul == True:
rs_output_shape = [matmul_input0_shape[0], matmul_input0_shape[1], 1, matmul_input0_shape[2]]
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
node.input[0] = node.input[1]
if remove_matmul == True:
node.input[0] = matmul_input0
node.input[1] = const_shape_name
update_tensor_shape(model, node.output[0], rs_output_shape)
if node.op_type == 'Reshape' and node.name != orig_reshape_name:
logger.debug('reuse Reshape to Conv')
node.op_type = 'Conv'
const_x_name = node.name + '_to_conv_x_'
v = matmul_dict['A_inputB']
old_dims = [matmul_dict['A_matmul_BShape'][0], matmul_dict['A_matmul_BShape'][1]]
dims_ = [matmul_dict['A_matmul_BShape'][1], matmul_dict['A_matmul_BShape'][0],1,1]
if isInputA == False:
v = matmul_dict['B_inputB']
old_dims = [matmul_dict['B_matmul_BShape'][0], matmul_dict['B_matmul_BShape'][1]]
dims_ = [matmul_dict['B_matmul_BShape'][1], matmul_dict['B_matmul_BShape'][0],1,1]
if isinstance(v, np.ndarray) == True:
A = v.reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('+++A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('---A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=dims_,#[matmul_dict['A_matmul_BShape'][1], matmul_dict['A_matmul_BShape'][0],1,1],
vals=A)
model.graph.initializer.append(const_x_tensor)
node.input[1] = const_x_name
attr = onnx.helper.make_attribute('dilations', [1, 1])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
node.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1,1])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
node.attribute.append(attr)
if isInputA == True:
node.input.append(matmul_dict['A_addA'])
else:
node.input.append(matmul_dict['B_addA'])
output_shape = rs_output_shape #[inputA_shape[0], inputA_shape[2], 1, inputA_shape[1]]
update_tensor_shape(model, node.output[0], output_shape)
if node.op_type == 'Transpose' and node.name != orig_matmul_name:
logger.debug('reuse Transpose to Reshape, node.name: {}'.format(node.name))
node.op_type = 'Reshape'
del node.attribute[:]
reshape_output = node.output[0]
const_shape_name = node.name + '_to_reshape_'
if isInputA == True:
output_shape = [current_inputA_shape[0], current_inputA_shape[1], current_inputA_shape[3], current_inputA_shape[2]]
else:
#output_shape = [inputA_shape[0], inputA_shape[2], 1, inputA_shape[1]]
output_shape = [current_inputB_shape[0], current_inputB_shape[1], current_inputB_shape[2], current_inputB_shape[3]]
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(output_shape)],
vals=output_shape)
model.graph.initializer.append(const_shape_tensor)
node.input.append(const_shape_name)
update_tensor_shape(model, node.output[0], output_shape)
follow_up_node = node
if isInputA == False:
ts_name = const_shape_name + '_transpose_'
ts_output_name = ts_name + '_output_'
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, [current_inputB_shape[0], current_inputB_shape[1], current_inputB_shape[3], current_inputB_shape[2]])
transpose_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[reshape_output],
outputs=[ts_output_name],
perm=[0,1,3,2])
insert_node(model, transpose_node, follow_up_node)
model.graph.value_info.append(transpose_output)
matmul_dict['current'].input[1] = matmul_dict['current'].input[0]
matmul_dict['current'].input[0] = ts_output_name
output_shape = values.get_tensor_shape_by_name(model, matmul_dict['current'].output[0])
update_tensor_shape(model, matmul_dict['current'].output[0], [output_shape[0], output_shape[1], output_shape[3], output_shape[2]])
return bert_mode
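#Pattern-four variant of the rewrite above: the same MatMul->Transpose,
#Add->Reshape, Reshape->Conv reuse, but without consulting the node maps.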
def do_convert_pattern_four(model, matmul_dict, isInputA):
orig_reshape_name = ''
orig_matmul_name = ''
reshape_output = ''
bert_mode = 1
current_inputA_shape = values.get_tensor_shape_by_name(model, matmul_dict['current'].input[0])
current_inputB_shape = values.get_tensor_shape_by_name(model, matmul_dict['current'].input[1])
logger.debug('---current_inputA_shape: {}'.format(current_inputA_shape))
logger.debug('---current_inputB_shape: {}'.format(current_inputB_shape))
map_key = ''
if isInputA == True:
inputA_shape = matmul_dict['A_matmul_AShape']
inputB_shape = matmul_dict['A_matmul_BShape']
path_node = matmul_dict['pathA']
if 'A_prev' in matmul_dict.keys():
map_key = matmul_dict['A_prev']
else:
inputA_shape = matmul_dict['B_matmul_AShape']
inputB_shape = matmul_dict['B_matmul_BShape']
path_node = matmul_dict['pathB']
if 'B_prev' in matmul_dict.keys():
map_key = matmul_dict['B_prev']
logger.debug('A inputA shape:{}, inputB shape:{}'.format(inputA_shape, inputB_shape))
logger.debug('B inputA shape:{}, inputB shape:{}'.format(matmul_dict['B_matmul_AShape'], matmul_dict['B_matmul_BShape']))
remove_matmul = False
matmul_input0 = ''
matmul_input0_shape = []
for node in path_node:
if node.op_type == 'MatMul':
if inputA_shape[1] != inputB_shape[0]:
orig_matmul_name = node.name
logger.debug('matmul+add-->conv, need same channel: {} {}'.format(inputA_shape[1], inputB_shape[0]))
node.op_type = 'Transpose'
attr = onnx.helper.make_attribute('perm', [0,2,1])
node.attribute.append(attr)
del node.input[1:]
update_tensor_shape(model, node.output[0], [inputA_shape[0], inputA_shape[2], inputA_shape[1]])
transpose_node_map[map_key] = node
logger.debug('map_key is {}'.format(map_key))
else:
logger.debug('Delete MatMul node: {}'.format(node.name))
matmul_input0 = node.input[0]
matmul_input0_shape = values.get_tensor_shape_by_name(model, matmul_input0)
#model.graph.node.remove(node)
operation.remove_onnx_node(model, node)
remove_matmul = True
bert_mode = 0
if node.op_type == 'Add':
logger.debug('reuse Add to Reshape')
orig_reshape_name = node.name
node.op_type = 'Reshape'
const_shape_name = node.name + '_to_reshape_'
rs_output_shape = [inputA_shape[0], inputA_shape[2], 1, inputA_shape[1]]
if remove_matmul == True:
rs_output_shape = [matmul_input0_shape[0], matmul_input0_shape[1], 1, matmul_input0_shape[2]]
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
node.input[0] = node.input[1]
if remove_matmul == True:
node.input[0] = matmul_input0
node.input[1] = const_shape_name
update_tensor_shape(model, node.output[0], rs_output_shape)
if node.op_type == 'Reshape' and node.name != orig_reshape_name:
logger.debug('reuse Reshape to Conv')
node.op_type = 'Conv'
const_x_name = node.name + '_to_conv_x_'
v = matmul_dict['A_inputB']
old_dims = [matmul_dict['A_matmul_BShape'][0], matmul_dict['A_matmul_BShape'][1]]
dims_ = [matmul_dict['A_matmul_BShape'][1], matmul_dict['A_matmul_BShape'][0],1,1]
if isInputA == False:
v = matmul_dict['B_inputB']
old_dims = [matmul_dict['B_matmul_BShape'][0], matmul_dict['B_matmul_BShape'][1]]
dims_ = [matmul_dict['B_matmul_BShape'][1], matmul_dict['B_matmul_BShape'][0],1,1]
if isinstance(v, np.ndarray) == True:
A = v.reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('+++A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('---A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=dims_,#[matmul_dict['A_matmul_BShape'][1], matmul_dict['A_matmul_BShape'][0],1,1],
vals=A)
model.graph.initializer.append(const_x_tensor)
node.input[1] = const_x_name
attr = onnx.helper.make_attribute('dilations', [1, 1])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
node.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1,1])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
node.attribute.append(attr)
if isInputA == True:
node.input.append(matmul_dict['A_addA'])
else:
node.input.append(matmul_dict['B_addA'])
output_shape = rs_output_shape #[inputA_shape[0], inputA_shape[2], 1, inputA_shape[1]]
update_tensor_shape(model, node.output[0], output_shape)
if node.op_type == 'Transpose' and node.name != orig_matmul_name:
logger.debug('reuse Transpose to Reshape')
node.op_type = 'Reshape'
del node.attribute[:]
reshape_output = node.output[0]
const_shape_name = node.name + '_to_reshape_'
if isInputA == True:
output_shape = [current_inputA_shape[0], current_inputA_shape[1], current_inputA_shape[3], current_inputA_shape[2]]
else:
#output_shape = [inputA_shape[0], inputA_shape[2], 1, inputA_shape[1]]
output_shape = [current_inputB_shape[0], current_inputB_shape[1], current_inputB_shape[2], current_inputB_shape[3]]
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(output_shape)],
vals=output_shape)
model.graph.initializer.append(const_shape_tensor)
node.input.append(const_shape_name)
update_tensor_shape(model, node.output[0], output_shape)
follow_up_node = node
if isInputA == False:
ts_name = const_shape_name + '_transpose_'
ts_output_name = ts_name + '_output_'
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, [current_inputB_shape[0], current_inputB_shape[1], current_inputB_shape[3], current_inputB_shape[2]])
transpose_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[reshape_output],
outputs=[ts_output_name],
perm=[0,1,3,2])
insert_node(model, transpose_node, follow_up_node)
model.graph.value_info.append(transpose_output)
matmul_dict['current'].input[1] = matmul_dict['current'].input[0]
matmul_dict['current'].input[0] = ts_output_name
output_shape = values.get_tensor_shape_by_name(model, matmul_dict['current'].output[0])
update_tensor_shape(model, matmul_dict['current'].output[0], [output_shape[0], output_shape[1], output_shape[3], output_shape[2]])
return bert_mode
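#Pattern-two rewrite for the B-input path only; reuses the Transpose/Reshape
#nodes that pattern one recorded in transpose_node_map/reshape_node_map.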
def do_convert_pattern_two(model, matmul_dict):
orig_reshape_name = ''
orig_matmul_name = ''
reshape_output = ''
current_inputA_shape = values.get_tensor_shape_by_name(model, matmul_dict['current'].input[0])
current_inputB_shape = values.get_tensor_shape_by_name(model, matmul_dict['current'].input[1])
logger.debug('do_convert_pattern_two, current_inputA_shape: {}'.format(current_inputA_shape))
logger.debug('do_convert_pattern_two, current_inputB_shape: {}'.format(current_inputB_shape))
inputA_shape = matmul_dict['B_matmul_AShape']
inputB_shape = matmul_dict['B_matmul_BShape']
path_node = matmul_dict['pathB']
logger.debug('do_convert_pattern_two, A inputA shape:{}, inputB shape:{}'.format(inputA_shape, inputB_shape))
logger.debug('do_convert_pattern_two, B inputA shape:{}, inputB shape:{}'.format(matmul_dict['B_matmul_AShape'], matmul_dict['B_matmul_BShape']))
reuse_transpose = False
remove_matmul = False
matmul_input0 = ''
matmul_input0_shape = []
add_input1 = ''
if len(current_inputA_shape) == 3 and len(current_inputB_shape) == 3:
map_key = ''
reuse_reshape = False
for node in path_node:
if node.op_type == 'MatMul':
map_key = node.input[0]
matmul_input0 = node.input[0]
matmul_input0_shape = values.get_tensor_shape_by_name(model, matmul_input0)
logger.debug('--handle MatMul: {}'.format(node.name))
if inputA_shape[1] != inputB_shape[0]:
#map_key = node.input[0]
if map_key in transpose_node_map.keys():
#transpose_node_map[map_key]
logger.debug('found transpose_node_map, key: {}'.format(map_key))
#model.graph.node.remove(node)
operation.remove_onnx_node(model, node)
reuse_transpose = True
else:
orig_matmul_name = node.name
logger.debug('#### matmul+add-->conv, need same channel: {} {}'.format(inputA_shape[1], inputB_shape[0]))
node.op_type = 'Transpose'
attr = onnx.helper.make_attribute('perm', [0,2,1])
node.attribute.append(attr)
del node.input[1:]
update_tensor_shape(model, node.output[0], [inputA_shape[0], inputA_shape[2], inputA_shape[1]])
else:
logger.debug('###### Delete MatMul node: {}'.format(node.name))
#model.graph.node.remove(node)
operation.remove_onnx_node(model, node)
remove_matmul = True
if node.op_type == 'Add':
if map_key in reshape_node_map.keys():
logger.debug('++++++ found reshape_node_map, key: {}'.format(map_key))
model.graph.node.remove(node)
reuse_reshape = True
else:
logger.debug('----delete add node: {}'.format(node.name))
#add_input1 = node.input[1]
#model.graph.node.remove(node)
#'''
orig_reshape_name = node.name
node.op_type = 'Reshape'
const_shape_name = node.name + '_to_reshape_'
rs_output_shape = [inputA_shape[0], inputA_shape[2], 1, inputA_shape[1]]
if remove_matmul == True:
rs_output_shape = [matmul_input0_shape[1], matmul_input0_shape[2], 1, matmul_input0_shape[0]]
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
if remove_matmul == True:
node.input[0] = matmul_input0
else:
if reuse_transpose == False:
node.input[0] = node.input[1]
else:
node.input[0] = transpose_node_map[map_key].output[0]
node.input[1] = const_shape_name
update_tensor_shape(model, node.output[0], rs_output_shape)
#'''
if node.op_type == 'Reshape' and node.name != orig_reshape_name:
logger.debug('reuse Reshape to Conv')
node.op_type = 'Conv'
const_x_name = node.name + '_to_conv_x_'
v = matmul_dict['B_inputB']
rs1_output_shape = values.get_tensor_shape_by_name(model, node.output[0])
if reuse_reshape == True:
node.input[0] = reshape_node_map[map_key].output[0]
if isinstance(v, np.ndarray) == True:
A = v.reshape(matmul_dict['B_matmul_BShape'][0], matmul_dict['B_matmul_BShape'][1])
A = A.transpose()
A = A.reshape(matmul_dict['B_matmul_BShape'][1], matmul_dict['B_matmul_BShape'][0], 1, 1)
logger.debug('+++A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(matmul_dict['B_matmul_BShape'][0], matmul_dict['B_matmul_BShape'][1])
A = A.transpose()
A = A.reshape(matmul_dict['B_matmul_BShape'][1], matmul_dict['B_matmul_BShape'][0], 1, 1)
logger.debug('---A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=[matmul_dict['B_matmul_BShape'][1], matmul_dict['B_matmul_BShape'][0],1, 1],
vals=A)
model.graph.initializer.append(const_x_tensor)
node.input[1] = const_x_name
del node.attribute[:]
attr = onnx.helper.make_attribute('dilations', [1, 1])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
node.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1, 1])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
node.attribute.append(attr)
node.input.append(matmul_dict['B_addA'])
output_shape = values.get_tensor_shape_by_name(model, node.input[0]) #[inputA_shape[0], inputA_shape[2], 1, inputA_shape[1]]
update_tensor_shape(model, node.output[0], output_shape)
if node.op_type == 'Transpose' and node.name != orig_matmul_name:
logger.debug('reuse Transpose to Reshape')
node.op_type = 'Reshape'
del node.attribute[:]
reshape_output = node.output[0]
const_shape_name = node.name + '_to_reshape_'
shape = values.get_tensor_shape_by_name(model, node.output[0])
output_shape = [shape[0], shape[2], shape[1]]
#when remove_matmul is True, the 3-D output_shape above is already correct
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(output_shape)],
vals=output_shape)
model.graph.initializer.append(const_shape_tensor)
node.input.append(const_shape_name)
update_tensor_shape(model, node.output[0], output_shape)
follow_up_node = node
output_shape = values.get_tensor_shape_by_name(model, matmul_dict['current'].output[0])
#update_tensor_shape(model, matmul_dict['current'].output[0], [output_shape[0], output_shape[2], output_shape[1]])
if remove_matmul == True:
tmp = matmul_dict['current'].input[1]
matmul_dict['current'].input[1] = matmul_dict['current'].input[0]
matmul_dict['current'].input[0] = tmp
if remove_matmul == False:#isInputA == False:
ts_name = const_shape_name + '_transpose_'
ts_output_name = ts_name + '_output_'
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, [current_inputB_shape[0], current_inputB_shape[1], current_inputB_shape[2]])
transpose_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[reshape_output],
outputs=[ts_output_name],
perm=[0,2,1])
insert_node(model, transpose_node, follow_up_node)
model.graph.value_info.append(transpose_output)
#matmul_dict['current'].input[1] = matmul_dict['current'].input[0]
matmul_dict['current'].input[1] = ts_output_name
else:
for node in path_node:
if node.op_type == 'MatMul':
if inputA_shape[1] != inputB_shape[0]:
map_key = node.input[0]
if map_key in transpose_node_map.keys():
#transpose_node_map[map_key]
#model.graph.node.remove(node)
operation.remove_onnx_node(model, node)
reuse_transpose = True
else:
orig_matmul_name = node.name
logger.debug('matmul+add-->conv, need same channel: {} {}'.format(inputA_shape[1], inputB_shape[0]))
node.op_type = 'Transpose'
attr = onnx.helper.make_attribute('perm', [0,2,1])
node.attribute.append(attr)
del node.input[1:]
update_tensor_shape(model, node.output[0], [inputA_shape[0], inputA_shape[2], inputA_shape[1]])
else:
logger.debug('------Delete MatMul node: {}'.format(node.name))
matmul_input0 = node.input[0]
matmul_input0_shape = values.get_tensor_shape_by_name(model, matmul_input0)
#model.graph.node.remove(node)
operation.remove_onnx_node(model, node)
remove_matmul = True
if node.op_type == 'Add':
logger.debug('----reuse Add to Reshape: {}'.format(node.name))
orig_reshape_name = node.name
node.op_type = 'Reshape'
const_shape_name = node.name + '_to_reshape_'
rs_output_shape = [inputA_shape[0], inputA_shape[2], 1, inputA_shape[1]]
if remove_matmul == True:
rs_output_shape = [matmul_input0_shape[0], matmul_input0_shape[1], 1, matmul_input0_shape[2]]
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
if remove_matmul == True:
node.input[0] = matmul_input0
else:
if reuse_transpose == False:
node.input[0] = node.input[1]
else:
node.input[0] = transpose_node_map[map_key].output[0]
node.input[1] = const_shape_name
update_tensor_shape(model, node.output[0], rs_output_shape)
if node.op_type == 'Reshape' and node.name != orig_reshape_name:
logger.debug('reuse Reshape to Conv')
node.op_type = 'Conv'
const_x_name = node.name + '_to_conv_x_'
v = matmul_dict['B_inputB']
if isinstance(v, np.ndarray) == True:
A = v.reshape(matmul_dict['B_matmul_BShape'][0], matmul_dict['B_matmul_BShape'][1])
A = A.transpose()
A = A.reshape(matmul_dict['B_matmul_BShape'][1], matmul_dict['B_matmul_BShape'][0], 1, 1)
logger.debug('+++A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(matmul_dict['B_matmul_BShape'][0], matmul_dict['B_matmul_BShape'][1])
A = A.transpose()
A = A.reshape(matmul_dict['B_matmul_BShape'][1], matmul_dict['B_matmul_BShape'][0], 1, 1)
logger.debug('---A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=[matmul_dict['B_matmul_BShape'][1], matmul_dict['B_matmul_BShape'][0],1,1],
vals=A)
model.graph.initializer.append(const_x_tensor)
node.input[1] = const_x_name
attr = onnx.helper.make_attribute('dilations', [1, 1])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
node.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1,1])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
node.attribute.append(attr)
node.input.append(matmul_dict['B_addA'])
output_shape = rs_output_shape #[inputA_shape[0], inputA_shape[2], 1, inputA_shape[1]]
update_tensor_shape(model, node.output[0], output_shape)
if node.op_type == 'Transpose' and node.name != orig_matmul_name:
logger.debug('reuse Transpose to Reshape')
node.op_type = 'Reshape'
del node.attribute[:]
reshape_output = node.output[0]
const_shape_name = node.name + '_to_reshape_'
output_shape = [current_inputB_shape[0], current_inputB_shape[1], current_inputB_shape[2], current_inputB_shape[3]]
if remove_matmul == True:
output_shape = [current_inputB_shape[0], current_inputB_shape[1], current_inputB_shape[3], current_inputB_shape[2]]
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(output_shape)],
vals=output_shape)
model.graph.initializer.append(const_shape_tensor)
node.input.append(const_shape_name)
update_tensor_shape(model, node.output[0], output_shape)
follow_up_node = node
output_shape = values.get_tensor_shape_by_name(model, matmul_dict['current'].output[0])
update_tensor_shape(model, matmul_dict['current'].output[0], [output_shape[0], output_shape[1], output_shape[3], output_shape[2]])
if remove_matmul == True:
tmp = matmul_dict['current'].input[1]
matmul_dict['current'].input[1] = matmul_dict['current'].input[0]
matmul_dict['current'].input[0] = tmp
if remove_matmul == False:#isInputA == False:
ts_name = const_shape_name + '_transpose_'
ts_output_name = ts_name + '_output_'
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, [current_inputB_shape[0], current_inputB_shape[1], current_inputB_shape[3], current_inputB_shape[2]])
transpose_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[reshape_output],
outputs=[ts_output_name],
perm=[0,1,3,2])
insert_node(model, transpose_node, follow_up_node)
model.graph.value_info.append(transpose_output)
matmul_dict['current'].input[1] = matmul_dict['current'].input[0]
matmul_dict['current'].input[0] = ts_output_name
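# --- Illustrative sketch (not called by the optimizer) -----------------------
# All of the MatMul/Gemm-to-Conv rewrites in this file lean on one identity:
# MatMul(x, W) with a constant 2-D weight W of shape [K, N] equals a 1x1 Conv
# whose kernel is W.T reshaped to [N, K, 1, 1], once x is moved from [B, T, K]
# to the NCHW-like layout [B, K, 1, T]. A minimal numpy check of the identity
# (names here are hypothetical):
def _check_matmul_as_conv1x1(x_btk, w_kn):
    b, t, k = x_btk.shape
    n = w_kn.shape[1]
    ref = x_btk @ w_kn                             # plain MatMul, [B, T, N]
    kernel = w_kn.transpose().reshape(n, k, 1, 1)  # same reshaping as above
    x_conv = x_btk.transpose(0, 2, 1).reshape(b, k, 1, t)
    out = np.einsum('nk,bkht->bnht', kernel[:, :, 0, 0], x_conv)  # 1x1 Conv
    return np.allclose(ref.transpose(0, 2, 1).reshape(b, n, 1, t), out)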
def do_convert_pattern_three(model, matmul_dict, ts_node):
orig_reshape_name = ''
orig_matmul_name = ''
reshape_output = ''
inputA_shape = matmul_dict['matmul_AShape']
inputB_shape = matmul_dict['matmul_BShape']
path_node = matmul_dict['node_list']
logger.debug('----- inputA shape:{}, inputB shape:{}'.format(inputA_shape, inputB_shape))
orig_reshape_name = ts_node.name
ts_node.op_type = 'Reshape'
del ts_node.attribute[:]
const_shape_name = ts_node.name + '_to_reshape_'
ts_input_shape = values.get_tensor_shape_by_name(model, ts_node.input[0])
rs_output_shape = [ts_input_shape[0], ts_input_shape[1]*ts_input_shape[2], ts_input_shape[3]]
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
ts_node.input.append(const_shape_name)
update_tensor_shape(model, ts_node.output[0], rs_output_shape)
for node in path_node:
if node.op_type == 'Reshape':
logger.debug('----reuse Reshape')
rs2_output_shape = [rs_output_shape[0], rs_output_shape[1], 1, rs_output_shape[2]]
const_shape_name = node.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs2_output_shape)],
vals=rs2_output_shape)
model.graph.initializer.append(const_shape_tensor)
node.input[1] = const_shape_name
update_tensor_shape(model, node.output[0], rs2_output_shape)
if node.op_type == 'MatMul':
logger.debug('---- reuse Matmul to Conv')
node.op_type = 'Conv'
const_x_name = node.name + '_to_conv_x_'
v = matmul_dict['inputB']
old_dims = [matmul_dict['matmul_BShape'][0], matmul_dict['matmul_BShape'][1]]
dims_ = [matmul_dict['matmul_BShape'][1], matmul_dict['matmul_BShape'][0],1,1]
if isinstance(v, np.ndarray) == True:
A = v.reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('+++ A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('--- A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=dims_,
vals=A)
model.graph.initializer.append(const_x_tensor)
node.input[1] = const_x_name
attr = onnx.helper.make_attribute('dilations', [1, 1])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
node.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1,1])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
node.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
node.attribute.append(attr)
node.input.append(matmul_dict['addA'])
update_tensor_shape(model, node.output[0], rs2_output_shape)
if node.op_type == 'Add':
logger.debug('----- reuse Add to Reshape')
node.op_type = 'Reshape'
add_first = matmul_dict['addFirst']
const_shape_name = node.name + '_to_reshape_'
output_shape = [rs2_output_shape[0], rs2_output_shape[1], rs2_output_shape[3]]
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(output_shape)],
vals=output_shape)
model.graph.initializer.append(const_shape_tensor)
if add_first == True:
node.input[0] = node.input[1]
node.input[1] = const_shape_name
update_tensor_shape(model, node.output[0], output_shape)
next_node, ok = get_next_node_by_output(model, node.output[0])
update_tensor_shape(model, next_node.output[0], output_shape)
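# --- Illustrative sketch (not called by the optimizer) -----------------------
# Every "reuse X to Reshape" above follows the same recipe: materialize the
# target shape as an INT64 initializer and wire it in as Reshape's second
# input. The recipe in isolation (hypothetical name):
def _add_shape_initializer(model, name, shape):
    t = onnx.helper.make_tensor(name=name,
                                data_type=onnx.TensorProto.INT64,
                                dims=[len(shape)],
                                vals=shape)
    model.graph.initializer.append(t)
    return name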
def cvt_matmul_add_to_conv(model, matmul_dict, pattern):
if matmul_dict['next'][0].op_type == 'Div' or matmul_dict['next'][0].op_type == 'Add' or matmul_dict['next'][0].op_type == 'Mul':
bert_mode = -1
logger.debug('cvt_matmul_add_to_conv, next: {}, current: {}'.format(matmul_dict['next'][0].name, matmul_dict['current'].name))
if pattern == 4:
if matmul_dict['A_MatMul_Add'] == True:
bert_mode = do_convert_pattern_four(model, matmul_dict, True)
if matmul_dict['B_MatMul_Add'] == True:
bert_mode = do_convert_pattern_four(model, matmul_dict, False)
else:
if matmul_dict['A_MatMul_Add'] == True:
bert_mode = do_convert_pattern_one(model, matmul_dict, True)
if matmul_dict['B_MatMul_Add'] == True:
bert_mode = do_convert_pattern_one(model, matmul_dict, False)
########### next node is Add
if matmul_dict['next'][0].op_type == 'Add':
current_node = matmul_dict['current']
next_node = matmul_dict['next'][0]
shape = values.get_tensor_shape_by_name(model, next_node.output[0])
update_tensor_shape(model, next_node.output[0], [shape[0], shape[1], shape[3], shape[2]])
###add transpose
ts_name = next_node.name + '_transpose_'
ts_output_name = ts_name + '_output_'
add_output_shape = values.get_tensor_shape_by_name(model, next_node.output[0])
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, shape)
ts_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[current_node.output[0]],
outputs=[ts_output_name],
perm=[0,1,3,2])
model.graph.value_info.append(transpose_output)
insert_node(model, ts_node, next_node)
next_node.input[0] = ts_output_name
else:
next_node = matmul_dict['next'][0]
shape = values.get_tensor_shape_by_name(model, next_node.output[0])
logger.debug('next_node.name: {}, shape: {}'.format(next_node.name, shape))
if len(shape) == 3:
update_tensor_shape(model, next_node.output[0], [shape[0], shape[2], shape[1]])
###add transpose
ts_name = next_node.name + '_transpose_'
ts_output_name = ts_name + '_output_'
add_output_shape = values.get_tensor_shape_by_name(model, next_node.output[0])
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, shape)
ts_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[next_node.output[0]],
outputs=[ts_output_name],
perm=[0,2,1])
model.graph.value_info.append(transpose_output)
current_node = matmul_dict['current']
logger.debug('insert ts_name: {}, current_node: {}'.format(ts_name, current_node.name))
insert_node(model, ts_node, matmul_dict['nnext'][0])
if matmul_dict['nnext'][0].op_type == 'Add':
if matmul_dict['nnext'][0].input[0] == next_node.output[0]:
matmul_dict['nnext'][0].input[0] = ts_output_name
else:
matmul_dict['nnext'][0].input[1] = ts_output_name
else:
if bert_mode == 0:
matmul_dict['nnext'][0].input[0] = ts_output_name
else:
matmul_dict['nnext'][0].input[1] = ts_output_name
else:
update_tensor_shape(model, next_node.output[0], [shape[0], shape[1], shape[3], shape[2]])
###add transpose
ts_name = next_node.name + '_transpose_'
ts_output_name = ts_name + '_output_'
add_output_shape = values.get_tensor_shape_by_name(model, next_node.output[0])
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, shape)
ts_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[next_node.output[0]],
outputs=[ts_output_name],
perm=[0,1,3,2])
model.graph.value_info.append(transpose_output)
insert_node(model, ts_node, matmul_dict['nnext'][0])
if bert_mode == 0:
matmul_dict['nnext'][0].input[0] = ts_output_name
else:
matmul_dict['nnext'][0].input[1] = ts_output_name
elif matmul_dict['next'][0].op_type == 'Transpose':
if matmul_dict['B_MatMul_Add'] == True:
logger.debug('cvt_matmul_add_to_conv, BBBBBBBBBBBBBBBBBBBB')
do_convert_pattern_two(model, matmul_dict)
if matmul_dict['A_MatMul_Add'] == False:
logger.debug('cvt_matmul_add_to_conv, CCCCCCCCCCCCCCCCCCCCC')
path_node = matmul_dict['pathA']
            node = path_node[0]
if node.op_type == 'Where' or node.op_type == 'Softmax':
shape = values.get_tensor_shape_by_name(model, node.output[0])
                if len(shape) == 3:
return
###add transpose
logger.debug('insert Transpose before: {}'.format(node.name))
ts_name = node.name + '_transpose_'
ts_output_name = ts_name + '_output_'
perm_ = [0,1,3,2]
                if len(shape) == 3:
perm_ = [0, 2, 1]
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, [shape[0], shape[2],shape[1]])
else:
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, [shape[0], shape[1],shape[3],shape[2]])
ts_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[node.output[0]],
outputs=[ts_output_name],
perm=perm_)
model.graph.value_info.append(transpose_output)
insert_node(model, ts_node, matmul_dict['current'])
matmul_dict['current'].input[1] = ts_output_name
next_node = matmul_dict['next'][0]
nnext_node = matmul_dict['nnext'][0]
logger.debug('pattern two, Matmul name: {}, next_node: {}'.format(matmul_dict['current'].name, next_node.name))
                if len(shape) == 3:
del next_node.attribute[:]
attr = onnx.helper.make_attribute('perm', [2,0,1])
next_node.attribute.append(attr)
op_dict, ok = get_matmul_input_path_pattern_two(model, next_node.output[0])
if op_dict and ok == 0:
for node in op_dict['node_list']:
logger.debug('got matmul+add path(pattern 2): {}'.format(node.name))
do_convert_pattern_three(model, op_dict, next_node)
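# --- Illustrative sketch (not called by the optimizer) -----------------------
# cvt_matmul_add_to_conv repeatedly splices a Transpose behind a rewritten
# node: declare a value_info for the transposed output, build the node with an
# explicit perm, then insert_node() it and repoint the consumer. The recipe in
# isolation (hypothetical names):
def _make_transpose_after(model, src_output, out_shape, perm):
    ts_name = src_output + '_demo_transpose_'
    ts_output_name = ts_name + '_output_'
    vi = onnx.helper.make_tensor_value_info(ts_output_name,
                                            onnx.TensorProto.FLOAT, out_shape)
    ts_node = onnx.helper.make_node('Transpose', name=ts_name,
                                    inputs=[src_output],
                                    outputs=[ts_output_name], perm=perm)
    model.graph.value_info.append(vi)
    return ts_node, ts_output_name  # caller still calls insert_node()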
def get_mul_add_transpose_matmul_block(model):
matm_list = []
for node in model.graph.node:
if node.op_type == 'Mul':
add_node, ok = get_next_node_by_output(model, node.output[0])
if ok == 0 and add_node.op_type == 'Add':
tp_node, ok = get_next_node_by_output(model, add_node.output[0])
if ok == 0 and tp_node.op_type == 'Transpose':
mm_node_list, ok = get_all_next_node_by_output(model, tp_node.output[0])
if ok == 0 and len(mm_node_list) == 3:
matm = {}
matm['Add'] = add_node
matm['Tp'] = tp_node
matm['mm1'] = mm_node_list[0]
matm['mm2'] = mm_node_list[1]
matm['mm3'] = mm_node_list[2]
matm_list.append(matm)
return matm_list
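# The subgraph matched above, roughly (one shared Transpose feeding three
# MatMuls, presumably the Q/K/V projections):
#
#   Mul --> Add --> Transpose --> MatMul
#                            \--> MatMul
#                             \-> MatMul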
def gen_mul_add_block_by_rm_transpose(model):
logger.debug('into gen_mul_add_block_by_rm_transpose')
node_list = []
for node in model.graph.node:
node_dict = {}
if node.op_type == 'Mul':
#print('got mul:', node.name)
is_init = False
for init in model.graph.initializer:
if init.name == node.input[0] or init.name == node.input[1]:
is_init = True
break
if is_init == False:
dataA = values.get_constant_value(model, node.input[0])
if len(dataA) == 0:
dataA = values.get_constant_value(model, node.input[1])
if dataA != []:
is_init = True
if is_init == True:
#print('----got mul:', node.name)
next_node, ok = get_next_node_by_output(model, node.output[0])
if ok == 0 and next_node.op_type == 'Add':
##############
#print('----got add:', next_node.name)
is_init = False
for init in model.graph.initializer:
if init.name == next_node.input[1]:
is_init = True
break
if is_init == False:
dataA = values.get_constant_value(model, next_node.input[1])
if dataA != []:
is_init = True
if is_init == True:
#print('get_all_next_node_by_output---', next_node.output, node.name)
tp_node, ok = get_next_node_by_output(model, next_node.output[0])
                        if ok == 0 and tp_node.op_type == 'Transpose':
mm_node1, ok = get_next_node_by_output(model, tp_node.output[0])
if ok == 0 and mm_node1.op_type == 'MatMul':
add_node1, ok = get_next_node_by_output(model, mm_node1.output[0])
if ok == 0 and add_node1.op_type == 'Add':
next_node_list, ok = get_all_next_node_by_output(model, add_node1.output[0])
#print('next_node_list:', len(next_node_list))
if len(next_node_list) == 2:
#print('got next_node_list:', next_node_list[0].op_type, next_node_list[1].op_type)
if (next_node_list[0].op_type == 'Div' and next_node_list[1].op_type == 'Mul') or \
(next_node_list[0].op_type == 'Mul' and next_node_list[1].op_type == 'Div'):
logger.debug('---got it~')
mul_node1 = next_node_list[0]
if next_node_list[1].op_type == 'Mul':
mul_node1 = next_node_list[1]
mul_node2, ok = get_next_node_by_output(model, mul_node1.output[0])
if ok == 0 and mul_node2.op_type == 'Mul':
mm_node2, ok = get_next_node_by_output(model, mul_node2.output[0])
if ok == 0 and mm_node2.op_type == 'MatMul':
add_node2, ok = get_next_node_by_output(model, mm_node2.output[0])
if ok == 0 and add_node2.op_type == 'Add':
tp_node2, ok = get_next_node_by_output(model, add_node2.output[0])
if ok == 0 and tp_node2.op_type == 'Transpose':
add_node3, ok = get_next_node_by_output(model, tp_node2.output[0])
if ok == 0 and add_node3.op_type == 'Add':
logger.debug('got match transpose block')
node_dict['Add'] = add_node3
node_dict['Transpose2'] = tp_node2
node_dict['MatMul'] = mm_node1
node_dict['Transpose'] = tp_node
node_list.append(node_dict)
for nd in node_list:
logger.debug('gen_mul_add_block_by_rm_transpose working...')
tp_node = nd['Transpose']
tp_node2 = nd['Transpose2']
add_node = nd['Add']
mm_node = nd['MatMul']
mm_node.input[0] = tp_node.input[0]
#model.graph.node.remove(tp_node)
logger.debug('gen_mul_add_block_by_rm_transpose, remove transpose node: {}'.format(tp_node.name))
operation.remove_onnx_node(model, tp_node)
add_node.input[1] = tp_node2.input[0]
logger.debug('gen_mul_add_block_by_rm_transpose, remove transpose node2: {}'.format(tp_node2.name))
#model.graph.node.remove(tp_node2)
operation.remove_onnx_node(model, tp_node2)
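# Both removals above are safe because the consumers were rewired first:
# mm_node now reads the first Transpose's input directly, and add_node reads
# the second Transpose's input, so both Transposes are dead when removed.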
def gen_mul_add_block_by_rm_transpose2(model):
tramt_list = []
for node in model.graph.node:
if node.op_type == 'Transpose':
rs_node, ok = get_prev_node_by_input(model, node.input[0])
if ok == 0 and rs_node.op_type == 'Reshape':
add_node, ok = get_prev_node_by_input(model, rs_node.input[0])
if ok == 0 and add_node.op_type == 'Add':
mm_node, ok = get_prev_node_by_input(model, add_node.input[1])
if ok == 0 and mm_node.op_type == 'MatMul':
tp_node2, ok = get_prev_node_by_input(model, mm_node.input[0])
if ok == 0 and tp_node2.op_type == 'Transpose':
logger.debug('got tramt {}'.format(mm_node.name))
tramt = {}
tramt['Transpose'] = tp_node2
'''
tramt['Reshape'] = rs_node
tramt['Add'] = add_node
tramt['MatMul'] = mm_node
tramt['Transpose2'] = tp_node2
'''
if tramt not in tramt_list:
tramt_list.append(tramt)
for tramt in tramt_list:
tp_node = tramt['Transpose']
tp_next_node_list, _ = get_all_next_node_by_output(model, tp_node.output[0])
for node in tp_next_node_list:
node.input[0] = tp_node.input[0]
logger.debug('gen_mul_add_block_by_rm_transpose2, remove transpose node: {}'.format(tp_node.name))
model.graph.node.remove(tp_node)
def correct_reshape_expand_reshape_pattern(model):
logger.debug('into correct_reshape_expand_reshape_pattern')
node_list = []
for node in model.graph.node:
node_dict = {}
if node.op_type == 'Reshape':
expend_node, ok = get_next_node_by_output(model, node.output[0])
if ok == 0 and expend_node.op_type == 'Expand':
reshape_node, ok = get_next_node_by_output(model, expend_node.output[0])
if ok == 0 and reshape_node.op_type == 'Reshape':
node_dict['Reshape1'] = node
node_dict['Expand'] = expend_node
node_dict['Reshape2'] = reshape_node
node_list.append(node_dict)
for index, nd in enumerate(node_list):
rs_node1 = nd['Reshape1']
rs_node2 = nd['Reshape2']
expand_node = nd['Expand']
len_reshape1_input = len(values.get_tensor_shape_by_name(model, rs_node1.input[0]))
len_reshape2_output = len(values.get_tensor_shape_by_name(model, rs_node2.output[0]))
if len_reshape1_input < 4:
diff = 4 - len_reshape1_input
            logger.info('Correct Reshape+Expand+Reshape, input rank: {}'.format(len_reshape1_input))
input_shape = values.get_tensor_shape_by_name(model, rs_node1.input[0])
new_shape = [1,1,1,1]
for idx, v in enumerate(input_shape):
new_shape[idx+diff] = v
logger.debug('got new_shape: {}'.format(new_shape))
if True:
shape_tensor_name = rs_node1.name + '_reshape_data_' + str(index)
const_shape = onnx.helper.make_tensor(shape_tensor_name, onnx.TensorProto.INT64, [4], new_shape)
model.graph.initializer.append(const_shape)
output_tensor_name = rs_node1.name + '_reshape_' + str(index)
output_tensor = onnx.helper.make_tensor_value_info(output_tensor_name, onnx.TensorProto.FLOAT, new_shape)
model.graph.value_info.append(output_tensor)
prev_node, _ = get_prev_node_by_input(model, rs_node1.input[0])
rs_node = onnx.helper.make_node(
name=rs_node1.name+'__Reshape__'+ str(index),
op_type='Reshape',
inputs=[rs_node1.input[0], shape_tensor_name],
outputs=[output_tensor_name]
)
insert_node(model, rs_node, prev_node)
rs_node1.input[0] = output_tensor_name
###############################################
shape_tensor_name = rs_node2.name + '_reshape_data_' + str(index)
new_shape = values.get_tensor_shape_by_name(model, expand_node.output[0])
const_shape = onnx.helper.make_tensor(shape_tensor_name, onnx.TensorProto.INT64, [4], new_shape)
model.graph.initializer.append(const_shape)
output_tensor_name = rs_node2.name + '_reshape_' + str(index)
output_tensor = onnx.helper.make_tensor_value_info(output_tensor_name, onnx.TensorProto.FLOAT, new_shape)
model.graph.value_info.append(output_tensor)
rs_node = onnx.helper.make_node(
name=rs_node2.name+'__Reshape__'+ str(index),
op_type='Reshape',
inputs=[expand_node.output[0], shape_tensor_name],
outputs=[output_tensor_name]
)
insert_node(model, rs_node, expand_node)
rs_node2.input[0] = output_tensor_name
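# Example of the padding above: a rank-3 Reshape input [B, S, D] becomes
# [1, B, S, D] (diff == 1), and the trailing Reshape is re-pointed at an
# explicit rank-4 copy of the Expand output so downstream shapes stay rank 4.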
#Transpose-->Reshape-->MatMul-->Add-->Reshape-->Transpose
def handle_matmul_add_child_block(model):
logger.debug('into handle_matmul_add_child_block')
node_list = []
for node in model.graph.node:
node_dict = {}
if node.op_type == 'Transpose':
logger.debug('into handle_matmul_add_child_block, step 1')
rs_node, ok = get_next_node_by_output(model, node.output[0])
if ok == 0 and rs_node.op_type == 'Reshape':
logger.debug('into handle_matmul_add_child_block, step 2')
mm_node, ok = get_next_node_by_output(model, rs_node.output[0])
if ok == 0 and mm_node.op_type == 'MatMul':
logger.debug('into handle_matmul_add_child_block, step 3')
add_node, ok = get_next_node_by_output(model, mm_node.output[0])
if ok == 0 and add_node.op_type == 'Add':
logger.debug('into handle_matmul_add_child_block, step 4')
rs_node2, ok = get_next_node_by_output(model, add_node.output[0])
if ok == 0 and rs_node2.op_type == 'Reshape':
tp_node2, ok = get_next_node_by_output(model, rs_node2.output[0])
if ok == 0 and tp_node2.op_type == 'Transpose':
logger.debug('got match matmul+add child block~~')
node_dict['Transpose'] = node
node_dict['Transpose2'] = tp_node2
node_dict['MatMul'] = mm_node
node_dict['Add'] = add_node
node_dict['Reshape'] = rs_node
node_dict['Reshape2'] = rs_node2
node_list.append(node_dict)
for nd in node_list:
tp_node = nd['Transpose']
tp_node2 = nd['Transpose2']
add_node = nd['Add']
mm_node = nd['MatMul']
rs_node = nd['Reshape']
rs_node2 = nd['Reshape2']
###add transpose
ts_name = tp_node.name + '_transpose_'
ts_output_name = ts_name + '_output_'
shape = values.get_tensor_shape_by_name(model, rs_node.output[0])
ts_output_shape = [shape[1], shape[0]]
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, ts_output_shape)
ts_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[rs_node.output[0]],
outputs=[ts_output_name],
perm=[1,0])
model.graph.value_info.append(transpose_output)
###add reshape-1
rs2_name = rs_node.name + '_reshape_1_'
rs2_output_name = rs2_name + '_output_'
shape = values.get_tensor_shape_by_name(model, rs_node.output[0])
rs2_output_shape = [1, shape[1], 1, shape[0]]
rs_output = onnx.helper.make_tensor_value_info(rs2_output_name, onnx.TensorProto.FLOAT, rs2_output_shape)
const_shape_name = rs_node.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs2_output_shape)],
vals=rs2_output_shape)
model.graph.initializer.append(const_shape_tensor)
rs_node_insert1 = onnx.helper.make_node(
'Reshape',
name=rs2_name,
inputs=[ts_output_name, const_shape_name],
outputs=[rs2_output_name])
model.graph.value_info.append(rs_output)
insert_node(model, ts_node, rs_node)
insert_node(model, rs_node_insert1, ts_node)
mm_node.input[0] = rs2_output_name
###MatMul-->Conv
inputB, shapeB = values.get_init_value_and_shape(model, mm_node.input[1])
if isinstance(inputB, list) and inputB == []:
            logger.debug('-- inputB is not in initializer')
            inputB = values.get_constant_value(model, mm_node.input[1])
mm_node.op_type = 'Conv'
logger.debug('=== reuse MatMul to Conv')
const_x_name = mm_node.name + '_to_conv_x_'
v = inputB
old_dims = [shapeB[0], shapeB[1]]
dims_ = [shapeB[1], shapeB[0],1,1]
if isinstance(v, np.ndarray) == True:
A = v.reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('+++--- A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(*old_dims)
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('---+++ A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=dims_,
vals=A)
operation.remove_initializer_if_necessary_by_name(model, mm_node.input[1], mm_node)
model.graph.initializer.append(const_x_tensor)
mm_node.input[1] = const_x_name
attr = onnx.helper.make_attribute('dilations', [1, 1])
mm_node.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
mm_node.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1,1])
mm_node.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
mm_node.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
mm_node.attribute.append(attr)
B = add_node.input[1]
mm_node.input.append(B)
conv_output_shape = rs2_output_shape
update_tensor_shape(model, mm_node.output[0], conv_output_shape)
#Add--->Reshape
add_node.op_type = 'Reshape'
del add_node.attribute[:]
rs_name = add_node.name + '_reshape_1_'
rs_output_name = rs_name + '_output_'
rs_output_shape = [conv_output_shape[1], conv_output_shape[2], conv_output_shape[3]]
logger.debug('-----+++ rs_output_shape: {}'.format(rs_output_shape))
rs_output = onnx.helper.make_tensor_value_info(rs_output_name, onnx.TensorProto.FLOAT, rs_output_shape)
const_shape_name = add_node.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
add_node.input[1] = const_shape_name
update_tensor_shape(model, add_node.output[0], rs_output_shape)
#Reshape2
operation.remove_onnx_node(model, rs_node2)
#Transpose
tp_node2.input[0] = add_node.output[0]
del tp_node2.attribute[:]
attr = onnx.helper.make_attribute('perm', [1, 2, 0])
tp_node2.attribute.append(attr)
#Transpose-->Reshape-->Gemm-->Reshape-->Transpose
def handle_matmul_add_child_block2(model):
logger.debug('into handle_matmul_add_child_block2')
node_list = []
for node in model.graph.node:
node_dict = {}
if node.op_type == 'Transpose':
logger.debug('into handle_matmul_add_child_block2, step 1')
rs_node, ok = get_next_node_by_output(model, node.output[0])
if ok == 0 and rs_node.op_type == 'Reshape':
logger.debug('into handle_matmul_add_child_block2, step 2')
gemm_node, ok = get_next_node_by_output(model, rs_node.output[0])
if ok == 0 and gemm_node.op_type == 'Gemm':
logger.debug('into handle_matmul_add_child_block2, step 3')
rs_node2, ok = get_next_node_by_output(model, gemm_node.output[0])
if ok == 0 and rs_node2.op_type == 'Reshape':
tp_node2, ok = get_next_node_by_output(model, rs_node2.output[0])
if ok == 0 and tp_node2.op_type == 'Transpose':
logger.debug('got match matmul+add child block~~')
node_dict['Transpose'] = node
node_dict['Transpose2'] = tp_node2
node_dict['Gemm'] = gemm_node
node_dict['Reshape'] = rs_node
node_dict['Reshape2'] = rs_node2
node_list.append(node_dict)
for nd in node_list:
tp_node = nd['Transpose']
tp_node2 = nd['Transpose2']
gemm_node = nd['Gemm']
rs_node = nd['Reshape']
rs_node2 = nd['Reshape2']
###add transpose
ts_name = tp_node.name + '_transpose_'
ts_output_name = ts_name + '_output_'
shape = values.get_tensor_shape_by_name(model, rs_node.output[0])
ts_output_shape = [shape[1], shape[0]]
transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, ts_output_shape)
ts_node = onnx.helper.make_node(
'Transpose',
name=ts_name,
inputs=[rs_node.output[0]],
outputs=[ts_output_name],
perm=[1,0])
model.graph.value_info.append(transpose_output)
###add reshape-1
rs2_name = rs_node.name + '_reshape_1_'
rs2_output_name = rs2_name + '_output_'
shape = values.get_tensor_shape_by_name(model, rs_node.output[0])
rs2_output_shape = [1, shape[1], 1, shape[0]]
rs_output = onnx.helper.make_tensor_value_info(rs2_output_name, onnx.TensorProto.FLOAT, rs2_output_shape)
const_shape_name = rs_node.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs2_output_shape)],
vals=rs2_output_shape)
model.graph.initializer.append(const_shape_tensor)
rs_node_insert1 = onnx.helper.make_node(
'Reshape',
name=rs2_name,
inputs=[ts_output_name, const_shape_name],
outputs=[rs2_output_name])
model.graph.value_info.append(rs_output)
insert_node(model, ts_node, rs_node)
insert_node(model, rs_node_insert1, ts_node)
gemm_node.input[0] = rs2_output_name
###Gemm-->Conv
inputB, shapeB = values.get_init_value_and_shape(model, gemm_node.input[1])
if isinstance(inputB, list) and inputB == []:
            logger.debug('-- inputB is not in initializer')
            inputB = values.get_constant_value(model, gemm_node.input[1])
gemm_node.op_type = 'Conv'
        logger.debug('=== reuse Gemm to Conv')
const_x_name = gemm_node.name + '_to_conv_x_'
transB = 0
attributes = gemm_node.attribute
for attr in attributes:
#TBD
'''
if attr.name == 'alpha':
alpha = attr.f
logger.debug('alpha: {}'.format(alpha))
if attr.name == 'beta':
beta = attr.f
logger.debug('beta: {}'.format(beta))
if attr.name == 'transA':
transA = attr.i
logger.debug('transA: {}'.format(transA))
'''
if attr.name == 'transB':
transB = attr.i
logger.debug('got transB: {}'.format(transB))
del gemm_node.attribute[:]
v = inputB
old_dims = [shapeB[0], shapeB[1]]
dims_ = [shapeB[1], shapeB[0],1,1]
if isinstance(v, np.ndarray) == True:
A = v.reshape(*old_dims)
if transB == 0:
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('+++--- A.shape: {}'.format(A.shape))
A = A.flatten()
else:
A = np.array(v).reshape(*old_dims)
if transB == 0:
A = A.transpose()
A = A.reshape(*dims_)
logger.debug('---+++ A.shape: {}'.format(A.shape))
A = A.flatten()
A = A.tolist()
const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
data_type=onnx.TensorProto.FLOAT,
dims=dims_,
vals=A)
operation.remove_initializer_if_necessary_by_name(model, gemm_node.input[1], gemm_node)
model.graph.initializer.append(const_x_tensor)
gemm_node.input[1] = const_x_name
attr = onnx.helper.make_attribute('dilations', [1, 1])
gemm_node.attribute.append(attr)
attr = onnx.helper.make_attribute('group', 1)
gemm_node.attribute.append(attr)
attr = onnx.helper.make_attribute('kernel_shape', [1,1])
gemm_node.attribute.append(attr)
attr = onnx.helper.make_attribute('pads', [0,0,0,0])
gemm_node.attribute.append(attr)
attr = onnx.helper.make_attribute('strides', [1,1])
gemm_node.attribute.append(attr)
#B = gemm_node.input[2]
#gemm_node.input.append(B)
conv_output_shape = rs2_output_shape
update_tensor_shape(model, gemm_node.output[0], conv_output_shape)
#Reshape2
rs_output_shape = [conv_output_shape[1], conv_output_shape[2], conv_output_shape[3]]
logger.debug('-----+++ rs_output_shape: {}'.format(rs_output_shape))
const_shape_name = rs_node2.name + '_reshape_data_'
const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
data_type=onnx.TensorProto.INT64,
dims=[len(rs_output_shape)],
vals=rs_output_shape)
model.graph.initializer.append(const_shape_tensor)
operation.remove_initializer_if_necessary_by_name(model, rs_node2.input[1], rs_node2)
rs_node2.input[1] = const_shape_name
update_tensor_shape(model, rs_node2.output[0], rs_output_shape)
#Transpose
del tp_node2.attribute[:]
attr = onnx.helper.make_attribute('perm', [1, 2, 0])
tp_node2.attribute.append(attr)
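# Note on the transB handling above: ONNX Gemm computes Y = alpha*A'*B' + C
# with B' = B.T when transB == 1, i.e. a transB == 1 weight is already stored
# as [N, K]. That is why the explicit transpose of the weight data is applied
# only when transB == 0.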
def mha_optimizer(model):
pattern = -1
ret1 = match_mha_block_pattern_one(model) #for decoder_model_bs10.onnx
ret2 = match_mha_block_pattern_two(model) #for bert_cls_sim1.onnx/bert_sst2_wm.onnx
ret3 = match_mha_block_pattern_three(model) #for bert_sst2_sim.onnx
ret4 = match_mha_block_pattern_four(model) #for bert_squad_v1_sim1.onnx
ret5 = match_mha_block_pattern_five(model) #for
if ret1 == 0:
pattern = 1
elif ret2 == 0:
pattern = 2
elif ret3 == 0:
pattern = 3
elif ret4 == 0:
pattern = 4
elif ret5 == 0:
pattern = 5
if pattern == -1:
        logger.debug('This is not an MHA model---')
return model
if pattern == 5:
matm_list = get_mul_add_transpose_matmul_block(model)
for matm in matm_list:
            logger.debug('got matm~')
matm['mm1'].input[0] = matm['Add'].output[0]
matm['mm2'].input[0] = matm['Add'].output[0]
matm['mm3'].input[0] = matm['Add'].output[0]
#model.graph.node.remove(matm['Tp'])
operation.remove_onnx_node(model, matm['Tp'])
del model.graph.value_info[:]
model = onnx.shape_inference.infer_shapes(model)
model = onnx.shape_inference.infer_shapes(model)
gen_mul_add_block_by_rm_transpose(model)
gen_mul_add_block_by_rm_transpose2(model)
handle_matmul_add_child_block(model)
handle_matmul_add_child_block2(model)
#onnx.save(model, './ss.onnx')
#sys.exit()
matmul_list = []
logger.debug('mha_optimizer, pattern = {}'.format(pattern))
if pattern == 1:
handle_add_combination_pattern_one(model)
if pattern == 2 or pattern == 3:
handle_add_combination_pattern_two_three(model)
if pattern != 4:
handle_mul_add_block(model, pattern)
else:
logger.debug('handle pattern 4')
handle_add_combination_pattern_four(model)
handle_mul_add_block_pattern_four(model)
for node in model.graph.node:
#print(node_id, ", name:", node.name, ", input:", node.input, ", output:", node.output, \
# ", op:", node.op_type)
if node.op_type == 'MatMul':
inputA = node.input[0]
inputB = node.input[1]
is_init = False
for init in model.graph.initializer:
if init.name == inputA or init.name == inputB:
is_init = True
break
if is_init == False:
dataA = values.get_constant_value(model, inputA)
dataB = values.get_constant_value(model, inputB)
if dataA != [] or dataB != []:
is_init = True
if is_init == True:
logger.debug('skip MatMul: {}'.format(node.name))
continue
matmul_dict = {}
mul_node, ok = get_prev_node_by_input(model, inputA)
if ok == 0 and mul_node.op_type == 'Mul':
mulB = values.get_init_value(model, mul_node.input[1])
                logger.debug('matmul input is Mul: {}'.format(mul_node.name))
                if isinstance(mulB, list) and mulB == []:
                    logger.debug('mulB is not in initializer')
mulB = values.get_constant_value(model, mul_node.input[1])
if len(mulB) > 0 and abs(mulB[0] - 0.125) < 0.00001:
logger.debug('this is the mul-node which we wanted(value B is 0.125)...')
matmul_dict['AMul'] = mul_node
inputA = mul_node.input[0]
div_node, ok = get_prev_node_by_input(model, inputA)
if ok == 0 and div_node.op_type == 'Div':
divB = values.get_init_value(model, div_node.input[1])
                logger.debug('matmul input is Div: {}'.format(div_node.name))
                if isinstance(divB, list) and divB == []:
                    logger.debug('divB is not in initializer')
divB = values.get_constant_value(model, div_node.input[1])
if len(divB) > 0 and abs(divB[0] - 8.0) < 0.00001:
logger.debug('this is the div-node which we wanted(value B is 8)...')
matmul_dict['AMul'] = div_node
inputA = div_node.input[0]
node_dictA, res1 = get_matmul_input_path_pattern_one(model, inputA)
node_dictB, res2 = get_matmul_input_path_pattern_one(model, inputB)
if res1 > -1:
print_matmul_input_path(node_dictA['node_list'], 'node_listA')
if res2 > -1:
print_matmul_input_path(node_dictB['node_list'], 'node_listB')
if res1 > -1 or res2 > -1:
next_node, _ = get_next_node_by_output(model, node.output[0])
nnext_node, _ = get_next_node_by_output(model, next_node.output[0])
#matmul_dict = {}
matmul_dict['name'] = node.name
matmul_dict['current'] = node
matmul_dict['pathA'] = node_dictA['node_list']
matmul_dict['A_MatMul_Add'] = False
if res1 == 0:
matmul_dict['A_MatMul_Add'] = True
matmul_dict['A_addA'] = node_dictA['addA']
matmul_dict['A_matmul_AShape'] = node_dictA['matmul_AShape']
matmul_dict['A_inputB'] = node_dictA['inputB']
matmul_dict['A_matmul_BShape'] = node_dictA['matmul_BShape']
if 'prev' in node_dictA.keys():
matmul_dict['A_prev'] = node_dictA['prev']
matmul_dict['pathB'] = node_dictB['node_list']
matmul_dict['B_MatMul_Add'] = False
if res2 == 0:
matmul_dict['B_MatMul_Add'] = True
matmul_dict['B_addA'] = node_dictB['addA']
matmul_dict['B_matmul_AShape'] = node_dictB['matmul_AShape']
matmul_dict['B_inputB'] = node_dictB['inputB']
matmul_dict['B_matmul_BShape'] = node_dictB['matmul_BShape']
if 'prev' in node_dictB.keys():
matmul_dict['B_prev'] = node_dictB['prev']
matmul_dict['next'] = [next_node]
matmul_dict['nnext'] = [nnext_node]
matmul_list.append(matmul_dict)
for ll in matmul_list:
logger.debug('stat MatMul: {}, next: {}, op_type: {}'.format(ll['name'], ll['next'][0].name,ll['next'][0].op_type))
logger.debug('------pathA:')
for node in ll['pathA']:
logger.debug(' {}'.format(node.name))
logger.debug('------pathB:')
for node in ll['pathB']:
logger.debug(' {}'.format(node.name))
cvt_matmul_add_to_conv(model, ll, pattern)
if pattern == 1:
handle_mul_add_block_two(model)
if pattern == 4:
pass
#handle_last_group_pattern_four(model)
else:
handle_last_group(model)
return model
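# --- Usage sketch (hypothetical file names) ----------------------------------
# model = onnx.load('bert_sst2_sim.onnx')
# model = mha_optimizer(model)
# onnx.save(model, 'bert_sst2_sim_opt.onnx')

# match_mha_block_common (below) walks backwards from the end of a transformer
# block: what appears to be a decomposed LayerNorm
# (ReduceMean/Sub/Pow/Add/Sqrt/Div/Mul/Add), an Erf-based GELU
# (Div/Erf/Add/Mul/Mul), a second LayerNorm, and finally the attention output
# projection down to Softmax. On success it returns the two anchor Add nodes
# used by match_mha_block_pattern_two/three.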
def match_mha_block_common(model):
common_dict = {}
logger.debug('into match_mha_block_common')
for node in model.graph.node:
if node.op_type == 'Add':
mul_node, ok = get_prev_node_by_input(model, node.input[0])
if ok == 0 and mul_node.op_type == 'Mul':
div_node = None
div_node_case1, ok1 = get_prev_node_by_input(model, mul_node.input[0])
div_node_case2, ok2 = get_prev_node_by_input(model, mul_node.input[1])
if ok1 == 0 and div_node_case1.op_type == 'Div':
div_node = div_node_case1
elif ok2 == 0 and div_node_case2.op_type == 'Div':
div_node = div_node_case2
#logger.debug('into match_mha_block_common, step 1')
if div_node != None:
sub_node, ok1 = get_prev_node_by_input(model, div_node.input[0])
sqrt_node, ok2 = get_prev_node_by_input(model, div_node.input[1])
if ok1 == 0 and sub_node.op_type == 'Sub' and ok2 == 0 and sqrt_node.op_type == 'Sqrt':
add_node, ok = get_prev_node_by_input(model, sqrt_node.input[0])
if ok == 0 and add_node.op_type == 'Add':
#logger.debug('into match_mha_block_common, step 2')
rm_node, ok = get_prev_node_by_input(model, add_node.input[0])
if ok == 0 and rm_node.op_type == 'ReduceMean':
pow_node, ok = get_prev_node_by_input(model, rm_node.input[0])
if ok == 0 and pow_node.op_type == 'Pow':
#logger.debug('into match_mha_block_common, step 3')
sub_node, ok = get_prev_node_by_input(model, pow_node.input[0])
if ok == 0 and sub_node.op_type == 'Sub':
add_node, ok1 = get_prev_node_by_input(model, sub_node.input[0])
rm_node, ok2 = get_prev_node_by_input(model, sub_node.input[1])
if ok1 == 0 and add_node.op_type == 'Add' and ok2 == 0 and rm_node.op_type == 'ReduceMean':
#logger.debug('into match_mha_block_common, step 4')
add_node_, ok = get_prev_node_by_input(model, rm_node.input[0])
if ok == 0 and add_node_.op_type == 'Add' and add_node_ == add_node:
add_node_1, ok1 = get_prev_node_by_input(model, add_node_.input[0])
add_node_2, ok2 = get_prev_node_by_input(model, add_node_.input[1])
#logger.debug('into match_mha_block_common, step 5')
if ok1 == 0 and add_node_1.op_type == 'Add' and ok2 == 0 and add_node_2.op_type == 'Add':
mm_node = None
mm_node_case1, ok1 = get_prev_node_by_input(model, add_node_2.input[0])
mm_node_case2, ok2 = get_prev_node_by_input(model, add_node_2.input[1])
if ok1 == 0 and mm_node_case1.op_type == 'MatMul':
mm_node = mm_node_case1
elif ok2 == 0 and mm_node_case2.op_type == 'MatMul':
mm_node = mm_node_case2
if mm_node == None:
mm_node_case1, ok1 = get_prev_node_by_input(model, add_node_1.input[0])
mm_node_case2, ok2 = get_prev_node_by_input(model, add_node_1.input[1])
if ok1 == 0 and mm_node_case1.op_type == 'MatMul':
mm_node = mm_node_case1
elif ok2 == 0 and mm_node_case2.op_type == 'MatMul':
mm_node = mm_node_case2
logger.debug('into match_mha_block_common, step 6')
if mm_node != None:
#logger.debug('into match_mha_block_common, step 6.1')
mul_node, ok = get_prev_node_by_input(model, mm_node.input[0])
if ok == 0 and mul_node.op_type == 'Mul':
#logger.debug('into match_mha_block_common, step 6.2')
mul_node_2, ok = get_prev_node_by_input(model, mul_node.input[0])
if ok == 0 and mul_node_2.op_type == 'Mul':
add_node_1, ok1 = get_prev_node_by_input(model, mul_node_2.input[0])
add_node_2, ok2 = get_prev_node_by_input(model, mul_node_2.input[1])
#logger.debug('into match_mha_block_common, step 7')
if ok1 == 0 and add_node_1.op_type == 'Add' and ok2 == 0 and add_node_2.op_type == 'Add':
erf_node, ok = get_prev_node_by_input(model, add_node_2.input[0])
if ok == 0 and erf_node.op_type == 'Erf':
#logger.debug('into match_mha_block_common, step 8')
div_node, ok = get_prev_node_by_input(model, erf_node.input[0])
if ok == 0 and div_node.op_type == 'Div':
#logger.debug('into match_mha_block_common, step 9')
add_node, ok = get_prev_node_by_input(model, div_node.input[0])
if ok == 0 and add_node.op_type == 'Add' and add_node == add_node_1:
#logger.debug('into match_mha_block_common, step 10')
mm_node = None
mm_node_case1, ok1 = get_prev_node_by_input(model, add_node.input[0])
mm_node_case2, ok2 = get_prev_node_by_input(model, add_node.input[1])
if ok1 == 0 and mm_node_case1.op_type == 'MatMul':
mm_node = mm_node_case1
elif ok2 == 0 and mm_node_case2.op_type == 'MatMul':
mm_node = mm_node_case2
if mm_node != None:
#logger.debug('into match_mha_block_common, step 11')
add_node, ok = get_prev_node_by_input(model, mm_node.input[0])
if ok == 0 and add_node.op_type == 'Add':
mul_node, ok = get_prev_node_by_input(model, add_node.input[0])
if ok == 0 and mul_node.op_type == 'Mul':
div_node = None
div_node_case1, ok1 = get_prev_node_by_input(model, mul_node.input[0])
div_node_case2, ok2 = get_prev_node_by_input(model, mul_node.input[1])
if ok1 == 0 and div_node_case1.op_type == 'Div':
div_node = div_node_case1
elif ok2 == 0 and div_node_case2.op_type == 'Div':
div_node = div_node_case2
if div_node != None:
#logger.debug('into match_mha_block_common, step 12')
sub_node, ok1 = get_prev_node_by_input(model, div_node.input[0])
sqrt_node, ok2 = get_prev_node_by_input(model, div_node.input[1])
if ok1 == 0 and sub_node.op_type == 'Sub' and ok2 == 0 and sqrt_node.op_type == 'Sqrt':
add_node, ok = get_prev_node_by_input(model, sqrt_node.input[0])
if ok == 0 and add_node.op_type == 'Add':
#logger.debug('into match_mha_block_common, step 13')
rm_node, ok = get_prev_node_by_input(model, add_node.input[0])
if ok == 0 and rm_node.op_type == 'ReduceMean':
#logger.debug('into match_mha_block_common, step 14')
pow_node, ok = get_prev_node_by_input(model, rm_node.input[0])
if ok == 0 and pow_node.op_type == 'Pow':
#logger.debug('into match_mha_block_common, step 15')
sub_node_, ok = get_prev_node_by_input(model, pow_node.input[0])
if ok == 0 and sub_node_.op_type == 'Sub' and sub_node_ == sub_node:
add_node, ok1 = get_prev_node_by_input(model, sub_node_.input[0])
rm_node, ok2 = get_prev_node_by_input(model, sub_node_.input[1])
#logger.debug('into match_mha_block_common, step 16')
if ok1 == 0 and add_node.op_type == 'Add' and ok2 == 0 and rm_node.op_type == 'ReduceMean':
add_node_, ok = get_prev_node_by_input(model, rm_node.input[0])
if ok == 0 and add_node_.op_type == 'Add' and add_node_ == add_node:
logger.debug('into match_mha_block_common, step 17')
add_node_1, ok1 = get_prev_node_by_input(model, add_node_.input[0])
add_node_2, ok2 = get_prev_node_by_input(model, add_node_.input[1])
if ok1 == 0 and add_node_1.op_type == 'Add' and ok2 == 0 and add_node_2.op_type == 'Add':
mm_node = None
mm_node_case1, ok1 = get_prev_node_by_input(model, add_node_2.input[0])
mm_node_case2, ok2 = get_prev_node_by_input(model, add_node_2.input[1])
if ok1 == 0 and mm_node_case1.op_type == 'MatMul':
mm_node = mm_node_case1
elif ok2 == 0 and mm_node_case2.op_type == 'MatMul':
mm_node = mm_node_case2
if mm_node == None:
mm_node_case1, ok1 = get_prev_node_by_input(model, add_node_1.input[0])
mm_node_case2, ok2 = get_prev_node_by_input(model, add_node_1.input[1])
if ok1 == 0 and mm_node_case1.op_type == 'MatMul':
mm_node = mm_node_case1
elif ok2 == 0 and mm_node_case2.op_type == 'MatMul':
mm_node = mm_node_case2
if mm_node != None:
logger.debug('into match_mha_block_common, step 18')
reshape_node, ok = get_prev_node_by_input(model, mm_node.input[0])
if ok == 0 and reshape_node.op_type == 'Reshape':
tp_node, ok = get_prev_node_by_input(model, reshape_node.input[0])
if ok == 0 and tp_node.op_type == 'Transpose':
mm_node, ok = get_prev_node_by_input(model, tp_node.input[0])
if ok == 0 and mm_node.op_type == 'MatMul':
softmax_node, ok1 = get_prev_node_by_input(model, mm_node.input[0])
tp_node, ok2 = get_prev_node_by_input(model, mm_node.input[1])
if ok1 == 0 and softmax_node.op_type == 'Softmax' and ok2 == 0 and tp_node.op_type == 'Transpose':
reshape_node, ok = get_prev_node_by_input(model, tp_node.input[0])
if ok == 0 and reshape_node.op_type == 'Reshape':
add_node_branch1, ok = get_prev_node_by_input(model, reshape_node.input[0])
if ok == 0 and add_node_branch1.op_type == 'Add':
mm_node = None
mm_node_case1, ok1 = get_prev_node_by_input(model, add_node_branch1.input[0])
mm_node_case2, ok2 = get_prev_node_by_input(model, add_node_branch1.input[1])
if ok1 == 0 and mm_node_case1.op_type == 'MatMul':
mm_node = mm_node_case1
elif ok2 == 0 and mm_node_case2.op_type == 'MatMul':
mm_node = mm_node_case2
if mm_node != None:
add_node_last, ok = get_prev_node_by_input(model, mm_node.input[0])
if ok == 0 and add_node_last.op_type == 'Add':
add_node_common, ok = get_prev_node_by_input(model, softmax_node.input[0])
if ok == 0 and add_node_common.op_type == 'Add':
common_dict['add_node_last'] = add_node_last
common_dict['add_node_common'] = add_node_common
logger.debug('got common mha block')
break
return common_dict
def match_mha_block_pattern_five(model):
ret = -1
logger.debug('into match_mha_block_pattern_five')
for node in model.graph.node:
if node.op_type == 'Add':
mul_node, ok = get_prev_node_by_input(model, node.input[0])
if ok == 0 and mul_node.op_type == 'Mul':
div_node = None
div_node_case1, ok1 = get_prev_node_by_input(model, mul_node.input[0])
div_node_case2, ok2 = get_prev_node_by_input(model, mul_node.input[1])
if ok1 == 0 and div_node_case1.op_type == 'Div':
div_node = div_node_case1
elif ok2 == 0 and div_node_case2.op_type == 'Div':
div_node = div_node_case2
#logger.debug('into match_mha_block_pattern_five, step 1')
if div_node != None:
sub_node, ok1 = get_prev_node_by_input(model, div_node.input[0])
sqrt_node, ok2 = get_prev_node_by_input(model, div_node.input[1])
if ok1 == 0 and sub_node.op_type == 'Sub' and ok2 == 0 and sqrt_node.op_type == 'Sqrt':
add_node, ok = get_prev_node_by_input(model, sqrt_node.input[0])
if ok == 0 and add_node.op_type == 'Add':
#logger.debug('into match_mha_block_pattern_five, step 2')
rm_node, ok = get_prev_node_by_input(model, add_node.input[0])
if ok == 0 and rm_node.op_type == 'ReduceMean':
pow_node, ok = get_prev_node_by_input(model, rm_node.input[0])
if ok == 0 and pow_node.op_type == 'Pow':
#logger.debug('into match_mha_block_pattern_five, step 3')
sub_node, ok = get_prev_node_by_input(model, pow_node.input[0])
if ok == 0 and sub_node.op_type == 'Sub':
add_node, ok1 = get_prev_node_by_input(model, sub_node.input[0])
rm_node, ok2 = get_prev_node_by_input(model, sub_node.input[1])
if ok1 == 0 and add_node.op_type == 'Add' and ok2 == 0 and rm_node.op_type == 'ReduceMean':
#logger.debug('into match_mha_block_pattern_five, step 4')
add_node_, ok = get_prev_node_by_input(model, rm_node.input[0])
if ok == 0 and add_node_.op_type == 'Add' and add_node_ == add_node:
add_node_1, ok1 = get_prev_node_by_input(model, add_node_.input[0])
tp_node, ok2 = get_prev_node_by_input(model, add_node_.input[1])
logger.debug('into match_mha_block_pattern_five, step 5')
if ok1 == 0 and add_node_1.op_type == 'Add' and ok2 == 0 and tp_node.op_type == 'Transpose':
add_node_2, ok = get_prev_node_by_input(model, tp_node.input[0])
if ok == 0 and add_node_2.op_type == 'Add':
mm_node = None
mm_node_case1, ok1 = get_prev_node_by_input(model, add_node_2.input[0])
mm_node_case2, ok2 = get_prev_node_by_input(model, add_node_2.input[1])
if ok1 == 0 and mm_node_case1.op_type == 'MatMul':
mm_node = mm_node_case1
elif ok2 == 0 and mm_node_case2.op_type == 'MatMul':
mm_node = mm_node_case2
logger.debug('into match_mha_block_pattern_five, step 6')
if mm_node != None:
#logger.debug('into match_mha_block_pattern_five, step 6.1')
mul_node, ok = get_prev_node_by_input(model, mm_node.input[0])
if ok == 0 and mul_node.op_type == 'Mul':
#logger.debug('into match_mha_block_pattern_five, step 6.2')
mul_node_2, ok = get_prev_node_by_input(model, mul_node.input[0])
if ok == 0 and mul_node_2.op_type == 'Mul':
add_node_1, ok1 = get_prev_node_by_input(model, mul_node_2.input[0])
add_node_2, ok2 = get_prev_node_by_input(model, mul_node_2.input[1])
#logger.debug('into match_mha_block_pattern_five, step 7')
if ok1 == 0 and add_node_1.op_type == 'Add' and ok2 == 0 and add_node_2.op_type == 'Add':
erf_node, ok = get_prev_node_by_input(model, add_node_2.input[0])
if ok == 0 and erf_node.op_type == 'Erf':
#logger.debug('into match_mha_block_pattern_five, step 8')
div_node, ok = get_prev_node_by_input(model, erf_node.input[0])
if ok == 0 and div_node.op_type == 'Div':
#logger.debug('into match_mha_block_pattern_five, step 9')
add_node, ok = get_prev_node_by_input(model, div_node.input[0])
if ok == 0 and add_node.op_type == 'Add' and add_node == add_node_1:
logger.debug('into match_mha_block_pattern_five, step 10')
mm_node = None
mm_node_case1, ok1 = get_prev_node_by_input(model, add_node.input[0])
mm_node_case2, ok2 = get_prev_node_by_input(model, add_node.input[1])
if ok1 == 0 and mm_node_case1.op_type == 'MatMul':
mm_node = mm_node_case1
elif ok2 == 0 and mm_node_case2.op_type == 'MatMul':
mm_node = mm_node_case2
if mm_node != None:
logger.debug('into match_mha_block_pattern_five, step 11')
tp_node, ok = get_prev_node_by_input(model, mm_node.input[0])
if ok == 0 and tp_node.op_type == 'Transpose':
#################
add_node, ok = get_prev_node_by_input(model, tp_node.input[0])
if ok == 0 and add_node.op_type == 'Add':
mul_node, ok = get_prev_node_by_input(model, add_node.input[0])
if ok == 0 and mul_node.op_type == 'Mul':
div_node = None
div_node_case1, ok1 = get_prev_node_by_input(model, mul_node.input[0])
div_node_case2, ok2 = get_prev_node_by_input(model, mul_node.input[1])
if ok1 == 0 and div_node_case1.op_type == 'Div':
div_node = div_node_case1
elif ok2 == 0 and div_node_case2.op_type == 'Div':
div_node = div_node_case2
if div_node != None:
logger.debug('into match_mha_block_pattern_five, step 12')
sub_node, ok1 = get_prev_node_by_input(model, div_node.input[0])
sqrt_node, ok2 = get_prev_node_by_input(model, div_node.input[1])
if ok1 == 0 and sub_node.op_type == 'Sub' and ok2 == 0 and sqrt_node.op_type == 'Sqrt':
add_node, ok = get_prev_node_by_input(model, sqrt_node.input[0])
if ok == 0 and add_node.op_type == 'Add':
#logger.debug('into match_mha_block_pattern_five, step 13')
rm_node, ok = get_prev_node_by_input(model, add_node.input[0])
if ok == 0 and rm_node.op_type == 'ReduceMean':
#logger.debug('into match_mha_block_pattern_five, step 14')
pow_node, ok = get_prev_node_by_input(model, rm_node.input[0])
if ok == 0 and pow_node.op_type == 'Pow':
#logger.debug('into match_mha_block_pattern_five, step 15')
sub_node_, ok = get_prev_node_by_input(model, pow_node.input[0])
if ok == 0 and sub_node_.op_type == 'Sub' and sub_node_ == sub_node:
add_node, ok1 = get_prev_node_by_input(model, sub_node_.input[0])
rm_node, ok2 = get_prev_node_by_input(model, sub_node_.input[1])
#logger.debug('into match_mha_block_pattern_five, step 16')
if ok1 == 0 and add_node.op_type == 'Add' and ok2 == 0 and rm_node.op_type == 'ReduceMean':
add_node_, ok = get_prev_node_by_input(model, rm_node.input[0])
if ok == 0 and add_node_.op_type == 'Add' and add_node_ == add_node:
logger.debug('got common mha block')
ret = 0
return ret
def get_node_group(model, input_name, num, index):
node_list = []
name = input_name
for i in range(num):
node, ok = get_prev_node_by_input(model, name)
if ok == 0 and len(node.input) > index[i]:
name = node.input[index[i]]
node_list.append(node)
else:
break
return node_list
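# get_node_group walks `num` producers upstream, at hop i following
# input[index[i]] of the node just found. E.g. with index [0, 0, 1] it follows
# input 0, input 0, then input 1. The walk stops early when a producer is
# missing or lacks the indexed input, so callers must check len(node_list)
# before comparing op_types, as the pattern matchers below do.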
def match_mha_block_pattern_two(model):
common_dict = match_mha_block_common(model)
if len(common_dict):
add_node_last = common_dict['add_node_last']
add_node_common = common_dict['add_node_common']
div_node, ok = get_prev_node_by_input(model, add_node_common.input[0])
if ok == 0 and div_node.op_type == 'Div':
mm_node, ok = get_prev_node_by_input(model, div_node.input[0])
if ok == 0 and mm_node.op_type == 'MatMul':
tp_node_1, ok1 = get_prev_node_by_input(model, mm_node.input[0])
tp_node_2, ok2 = get_prev_node_by_input(model, mm_node.input[1])
if ok1 == 0 and tp_node_1.op_type == 'Transpose' and ok2 == 0 and tp_node_2.op_type == 'Transpose':
reshape_node, ok = get_prev_node_by_input(model, tp_node_1.input[0])
if ok == 0 and reshape_node.op_type == 'Reshape':
add_node, ok = get_prev_node_by_input(model, reshape_node.input[0])
if ok == 0 and add_node.op_type == 'Add':
mm_node, ok = get_prev_node_by_input(model, add_node.input[1])
if ok == 0 and mm_node.op_type == 'MatMul':
add_node_branchA, ok = get_prev_node_by_input(model, mm_node.input[0])
if ok == 0 and add_node_branchA.op_type == 'Add' and add_node_branchA == add_node_last:
########################
reshape_node, ok = get_prev_node_by_input(model, tp_node_2.input[0])
if ok == 0 and reshape_node.op_type == 'Reshape':
add_node, ok = get_prev_node_by_input(model, reshape_node.input[0])
if ok == 0 and add_node.op_type == 'Add':
mm_node, ok = get_prev_node_by_input(model, add_node.input[1])
if ok == 0 and mm_node.op_type == 'MatMul':
add_node_branchB, ok = get_prev_node_by_input(model, mm_node.input[0])
if ok == 0 and add_node_branchB.op_type == 'Add' and add_node_branchB == add_node_last:
logger.debug('match mha block pattern two success')
return 0
return -1
def match_mha_block_pattern_three(model):
common_dict = match_mha_block_common(model)
if len(common_dict):
add_node_last = common_dict['add_node_last']
add_node_common = common_dict['add_node_common']
mm_node, ok = get_prev_node_by_input(model, add_node_common.input[0])
if ok == 0 and mm_node.op_type == 'MatMul':
mul_node, ok1 = get_prev_node_by_input(model, mm_node.input[0])
tp_node, ok2 = get_prev_node_by_input(model, mm_node.input[1])
if ok1 == 0 and mul_node.op_type == 'Mul' and ok2 == 0 and tp_node.op_type == 'Transpose':
tp_node2, ok = get_prev_node_by_input(model, mul_node.input[0])
if ok == 0 and tp_node2.op_type == 'Transpose':
reshape_node, ok = get_prev_node_by_input(model, tp_node2.input[0])
if ok == 0 and reshape_node.op_type == 'Reshape':
add_node, ok = get_prev_node_by_input(model, reshape_node.input[0])
if ok == 0 and add_node.op_type == 'Add':
mm_node, ok = get_prev_node_by_input(model, add_node.input[0])
if ok == 0 and mm_node.op_type == 'MatMul':
add_node_branchA, ok = get_prev_node_by_input(model, mm_node.input[0])
if ok == 0 and add_node_branchA.op_type == 'Add' and add_node_branchA == add_node_last:
########################
reshape_node, ok = get_prev_node_by_input(model, tp_node.input[0])
if ok == 0 and reshape_node.op_type == 'Reshape':
add_node, ok = get_prev_node_by_input(model, reshape_node.input[0])
if ok == 0 and add_node.op_type == 'Add':
mm_node, ok = get_prev_node_by_input(model, add_node.input[0])
if ok == 0 and mm_node.op_type == 'MatMul':
add_node_branchB, ok = get_prev_node_by_input(model, mm_node.input[0])
if ok == 0 and add_node_branchB.op_type == 'Add' and add_node_branchB == add_node_last:
logger.debug('match mha block pattern three success')
return 0
return -1
def match_mha_block_pattern_one(model):
res = -1
for node in model.graph.node:
if node.op_type == 'Add':
node_list = get_node_group(model, node.input[1], 6, [0,0,1,0,0,0])
if len(node_list) == 6:
expected_pattern = ['MatMul', 'Relu', 'Add', 'MatMul', 'Add', 'Mul']
for idx1, n in enumerate(node_list):
#print('node:', idx1, n.op_type, expected_pattern[idx1])
if n.op_type != expected_pattern[idx1]:
break
if idx1 == 5:
node_list2 = get_node_group(model, node_list[5].input[0], 6, [1,0,0,0,0,0])
if len(node_list2) == 6:
expected_pattern = ['Div', 'Sqrt', 'Add', 'ReduceMean', 'Pow', 'Sub']
for idx2, n in enumerate(node_list2):
if n.op_type != expected_pattern[idx2]:
break
if idx2 == 5:
node_list3 = get_node_group(model, node_list2[5].input[1], 7, [0,1,1,0,0,0,0])
if len(node_list3) == 7:
expected_pattern = ['ReduceMean', 'Add', 'Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul']
for idx3, n in enumerate(node_list3):
if n.op_type != expected_pattern[idx3]:
break
if idx3 == 6:
node_list4 = get_node_group(model, node_list3[6].input[1], 4, [0,0,1,0])
if len(node_list4) == 4:
expected_pattern = ['Transpose', 'Reshape', 'Add', 'MatMul']
for idx4, n in enumerate(node_list4):
if n.op_type != expected_pattern[idx4]:
break
if idx4 == 3:
node_list5 = get_node_group(model, node_list3[6].input[0], 12, [1,0,1,0,0,0,0,1,0,0,0,0])
if len(node_list5) == 12:
expected_pattern = ['Where', 'Softmax', 'Where', 'Div', 'MatMul', 'Transpose', 'Reshape','Add', 'MatMul', 'Add', 'Mul', 'Div']
for idx5, n in enumerate(node_list5):
if n.op_type != expected_pattern[idx5]:
break
if idx5 == 11:
node_list6 = get_node_group(model, node_list3[6].input[0], 9, [1,0,1,0,1,0,0,1,0])
if len(node_list6) == 9:
expected_pattern = ['Where', 'Softmax', 'Where', 'Div', 'MatMul', 'Transpose', 'Reshape','Add', 'MatMul']
for idx6, n in enumerate(node_list6):
if n.op_type != expected_pattern[idx6]:
break
if idx6 == 8:
res = 0
logger.debug('match_mha_block_pattern_one, success')
break
return res
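# The `if idxN == ...` checks above rely on Python keeping the loop variable
# after the for: idxN ends at the first mismatching index (break) or at the
# last index (full match), so comparing it against len(expected_pattern) - 1
# detects a complete match.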
def match_mha_block_pattern_four(model):
    res = -1
    for node in model.graph.node:
        if node.op_type == 'Add':
            node_list = get_node_group(model, node.input[1], 11, [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0])
            if len(node_list) == 11:
                # LayerNorm tail, walked backwards from the residual Add
                expected_pattern = ['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean', 'Add']
                for idx1, n in enumerate(node_list):
                    if n.op_type != expected_pattern[idx1]:
                        break
                else:  # for/else: runs only when every op matched (the old `if idx1 == 10` also passed when the last op mismatched)
                    logger.debug('match_mha_block_pattern_four, success 1')
                    node_list1 = get_node_group(model, node_list[10].input[0], 13, [0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0])
                    if len(node_list1) == 13:
                        # feed-forward block: MatMul/Add + tanh-GELU + MatMul/Add
                        expected_pattern = ['Add', 'MatMul', 'Mul', 'Mul', 'Add', 'Tanh', 'Mul', 'Add', 'Mul', 'Pow', 'Add', 'MatMul', 'Add']
                        for idx2, n in enumerate(node_list1):
                            if n.op_type != expected_pattern[idx2]:
                                break
                        else:
                            logger.debug('match_mha_block_pattern_four, success 2')
                            node_list2 = get_node_group(model, node_list1[12].input[1], 11, [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0])
                            if len(node_list2) == 11:
                                expected_pattern = ['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean', 'Add']
                                for idx3, n in enumerate(node_list2):
                                    if n.op_type != expected_pattern[idx3]:
                                        break
                                else:
                                    logger.debug('match_mha_block_pattern_four, success 3')
                                    last_add_node, ok = get_prev_node_by_input(model, node_list2[10].input[1])
                                    if ok == 0 and last_add_node.op_type == 'Add':
                                        node_list3 = get_node_group(model, node_list2[10].input[0], 5, [0, 0, 0, 0, 0])
                                        if len(node_list3) == 5:
                                            expected_pattern = ['Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul']
                                            for idx4, n in enumerate(node_list3):
                                                if n.op_type != expected_pattern[idx4]:
                                                    break
                                            else:
                                                logger.debug('match_mha_block_pattern_four, success 4')
                                                node_list4 = get_node_group(model, node_list3[4].input[1], 5, [0, 0, 0, 0, 0])
                                                if len(node_list4) == 5:
                                                    expected_pattern = ['Transpose', 'Reshape', 'Add', 'MatMul', 'Add']
                                                    for idx5, n in enumerate(node_list4):
                                                        if n.op_type != expected_pattern[idx5]:
                                                            break
                                                    else:
                                                        if node_list4[4] == last_add_node:
                                                            logger.debug('match_mha_block_pattern_four, success 5')
                                                            node_list5 = get_node_group(model, node_list3[4].input[0], 4, [0, 0, 0, 0])
                                                            if len(node_list5) == 4:
                                                                expected_pattern = ['Softmax', 'Add', 'Mul', 'MatMul']
                                                                for idx5, n in enumerate(node_list5):
                                                                    if n.op_type != expected_pattern[idx5]:
                                                                        break
                                                                else:
                                                                    logger.debug('match_mha_block_pattern_four, success 6')
                                                                    # both branches feeding the scores MatMul must end at the same upstream Add
                                                                    match_times = 0
                                                                    for i in range(2):
                                                                        node_list_ = get_node_group(model, node_list5[3].input[i], 5, [0, 0, 0, 0, 0])
                                                                        if len(node_list_) == 5:
                                                                            expected_pattern = ['Transpose', 'Reshape', 'Add', 'MatMul', 'Add']
                                                                            for idx_, n in enumerate(node_list_):
                                                                                if n.op_type != expected_pattern[idx_]:
                                                                                    break
                                                                            else:
                                                                                if node_list_[4] == last_add_node:
                                                                                    match_times += 1
                                                                                    logger.debug('match_mha_block_pattern_four, success 7, match_times: {}'.format(match_times))
                                                                    if match_times == 2:
                                                                        logger.debug('match_mha_block_pattern_four, success!!!!')
                                                                        res = 0
                                                                        break  # whole block matched; stop scanning the graph
    return res
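
# Taken together, match_mha_block_pattern_four walks one transformer block
# back-to-front: output LayerNorm -> feed-forward (MatMul / tanh-GELU / MatMul)
# -> attention-output LayerNorm -> attention projection -> scores Softmax and
# the Q/K branches, checking that all branches share one residual Add.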
def get_matmul_block_pattern_four(model, matmul_node):
    logger.debug('into get_matmul_block_pattern_four')
    res = -1
    node_dict = {}
    input_next = matmul_node
    if input_next.op_type == 'MatMul':
        shapeA = values.get_tensor_shape_by_name(model, input_next.input[0])
        inputB, shapeB = values.get_init_value_and_shape(model, input_next.input[1])
        if isinstance(inputB, list) and inputB == []:
            logger.debug('inputB is not in initializer')
            inputB = values.get_constant_value(model, input_next.input[1])
        if len(shapeA) == 2 and len(shapeB) == 2:
            logger.debug('--- got MatMul node {}'.format(input_next.name))
            node_dict['MatMul1'] = input_next
            node_dict['matmulA1_Shape'] = shapeA
            node_dict['inputB1'] = inputB
            node_dict['matmulB1_Shape'] = shapeB
            input_nnext, ok = get_next_node_by_output(model, input_next.output[0])
            if ok == 0 and input_nnext.op_type == 'Add':
                # the bias may sit in either input slot; addFirst records which one
                addA_name = input_nnext.input[0]
                addA, shapeA = values.get_init_value_and_shape(model, input_nnext.input[0])
                node_dict['addFirst'] = True
                if len(shapeA) == 0:
                    addA_name = input_nnext.input[1]
                    addA, shapeA = values.get_init_value_and_shape(model, input_nnext.input[1])
                    node_dict['addFirst'] = False
                if len(shapeA) == 1:
                    node_dict['Add1'] = input_nnext
                    logger.debug('--- got Add1 node {}'.format(input_nnext.name))
                    input_nnnext, ok = get_all_next_node_by_output(model, input_nnext.output[0])
                    if len(input_nnnext) == 3:
                        got_match_op = 0
                        for n in input_nnnext:
                            if n.op_type == 'Add':
                                node_dict['AddT'] = n
                                got_match_op += 1
                            if n.op_type == 'Pow':
                                node_dict['Pow'] = n
                                got_match_op += 1
                            if n.op_type == 'Mul':
                                node_dict['Mul'] = n
                                got_match_op += 1
                        if got_match_op == 3:
                            input_nnnnnext, ok = get_next_node_by_output(model, node_dict['Mul'].output[0])
                            if ok == 0 and input_nnnnnext.op_type == 'MatMul':
                                shapeA = values.get_tensor_shape_by_name(model, input_nnnnnext.input[0])
                                inputB, shapeB = values.get_init_value_and_shape(model, input_nnnnnext.input[1])
                                if isinstance(inputB, list) and inputB == []:
                                    logger.debug('inputB is not in initializer')
                                    inputB = values.get_constant_value(model, input_nnnnnext.input[1])
                                if len(shapeA) == 2 and len(shapeB) == 2:
                                    logger.debug('--- got MatMul2 node: {}'.format(input_nnnnnext.name))
                                    node_dict['MatMul2'] = input_nnnnnext
                                    node_dict['matmulA2_Shape'] = shapeA
                                    node_dict['inputB2'] = inputB
                                    node_dict['matmulB2_Shape'] = shapeB
                                    input_nnnnnnext, ok = get_next_node_by_output(model, input_nnnnnext.output[0])
                                    if ok == 0 and input_nnnnnnext.op_type == 'Add':
                                        logger.debug('--- got Add2 node: {}'.format(input_nnnnnnext.name))
                                        addA_name = input_nnnnnnext.input[0]
                                        addA, shapeA = values.get_init_value_and_shape(model, input_nnnnnnext.input[0])
                                        node_dict['addFirst2'] = True
                                        if len(shapeA) == 0:
                                            addA_name = input_nnnnnnext.input[1]
                                            addA, shapeA = values.get_init_value_and_shape(model, input_nnnnnnext.input[1])
                                            node_dict['addFirst2'] = False
                                        if len(shapeA) == 1:
                                            node_dict['Add2'] = input_nnnnnnext
                                            next_node, ok = get_next_node_by_output(model, input_nnnnnnext.output[0])
                                            if ok == 0 and next_node.op_type == 'Add':
                                                logger.debug('--- got last Add node: {}'.format(next_node.name))
                                                res = 0
                                                node_dict['NextAdd'] = next_node
    return node_dict, res
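
# For reference: the Pow/Mul/Add/Tanh chain that fans out of Add1 above is the
# tanh approximation of GELU used by BERT-style feed-forward blocks,
#     gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
# A minimal numpy sketch of the same computation (illustrative only, not part
# of the optimizer):
#
#     def gelu_tanh(x):
#         return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))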
def get_mul_add_block_pattern_four(model):
    logger.debug('into get_mul_add_block_pattern_four')
    node_list = []
    for node in model.graph.node:
        if node.op_type == 'Mul':
            next_node, ok = get_next_node_by_output(model, node.output[0])
            if ok == 0 and next_node.op_type == 'Add':
                next_node_list, ok = get_all_next_node_by_output(model, next_node.output[0])
                if ok == 0 and len(next_node_list) == 2:
                    if (next_node_list[0].op_type == 'Add' and next_node_list[1].op_type == 'MatMul') or \
                            (next_node_list[0].op_type == 'MatMul' and next_node_list[1].op_type == 'Add'):
                        logger.debug('got it~')
                        matmul_node = next_node_list[0]
                        if next_node_list[1].op_type == 'MatMul':
                            matmul_node = next_node_list[1]
                        node_dict, ret = get_matmul_block_pattern_four(model, matmul_node)
                        if ret == 0:
                            node_dict['currentAdd'] = next_node
                            node_list.append(node_dict)
    return node_list
def handle_mul_add_block_pattern_four(model):
    node_list = get_mul_add_block_pattern_four(model)
    for node_dict in node_list:
        logger.debug('++++++++++++++++++++++')
        logger.debug('Add1: {}'.format(node_dict['Add1'].name))
        logger.debug('Add2: {}'.format(node_dict['Add2'].name))
        logger.debug('++++++++++++++++++++++')
        matmul1 = node_dict['MatMul1']
        add1 = node_dict['Add1']
        matmul2 = node_dict['MatMul2']
        add2 = node_dict['Add2']
        currentAdd = node_dict['currentAdd']
        nextAdd = node_dict['NextAdd']
        pow_node = node_dict['Pow']

        ### add transpose
        ts_name = currentAdd.name + '_transpose_'
        ts_output_name = ts_name + '_output_'
        add_output_shape = values.get_tensor_shape_by_name(model, currentAdd.output[0])
        ts_output_shape = [add_output_shape[1], add_output_shape[0]]
        transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, ts_output_shape)
        ts_node = onnx.helper.make_node(
            'Transpose',
            name=ts_name,
            inputs=[currentAdd.output[0]],
            outputs=[ts_output_name],
            perm=[1, 0])
        model.graph.value_info.append(transpose_output)

        ### add reshape-1
        rs_name = currentAdd.name + '_reshape_1_'
        rs_output_name = rs_name + '_output_'
        rs_output_shape = [1, ts_output_shape[0], 1, ts_output_shape[1]]
        rs_output = onnx.helper.make_tensor_value_info(rs_output_name, onnx.TensorProto.FLOAT, rs_output_shape)
        const_shape_name = currentAdd.name + '_reshape_data_'
        const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
                                                     data_type=onnx.TensorProto.INT64,
                                                     dims=[len(rs_output_shape)],
                                                     vals=rs_output_shape)
        model.graph.initializer.append(const_shape_tensor)
        rs_node = onnx.helper.make_node(
            'Reshape',
            name=rs_name,
            inputs=[ts_output_name, const_shape_name],
            outputs=[rs_output_name])
        model.graph.value_info.append(rs_output)

        insert_node(model, rs_node, matmul1)
        matmul1.input[0] = rs_output_name
        insert_node(model, ts_node, rs_node)
        nextAdd.input[1] = ts_output_name

        # MatMul1 ---> Conv
        matmul1.op_type = 'Conv'
        logger.debug('-----convert MatMul to Conv: {}'.format(matmul1.name))
        const_x_name = matmul1.name + '_to_conv_x_'
        v = node_dict['inputB1']
        # [K, N] MatMul weight ---> [N, K, 1, 1] Conv weight
        old_dims = [node_dict['matmulB1_Shape'][0], node_dict['matmulB1_Shape'][1]]
        dims_ = [node_dict['matmulB1_Shape'][1], node_dict['matmulB1_Shape'][0], 1, 1]
        if isinstance(v, np.ndarray):
            A = v.reshape(*old_dims)
            A = A.transpose()
            A = A.reshape(*dims_)
            logger.debug('+++A.shape: {}'.format(A.shape))
            A = A.flatten()
        else:
            A = np.array(v).reshape(*old_dims)
            A = A.transpose()
            A = A.reshape(*dims_)
            logger.debug('---A.shape: {}'.format(A.shape))
            A = A.flatten()
            A = A.tolist()
        const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
                                                 data_type=onnx.TensorProto.FLOAT,
                                                 dims=dims_,
                                                 vals=A)
        model.graph.initializer.append(const_x_tensor)
        matmul1.input[1] = const_x_name
        attr = onnx.helper.make_attribute('dilations', [1, 1])
        matmul1.attribute.append(attr)
        attr = onnx.helper.make_attribute('group', 1)
        matmul1.attribute.append(attr)
        attr = onnx.helper.make_attribute('kernel_shape', [1, 1])
        matmul1.attribute.append(attr)
        attr = onnx.helper.make_attribute('pads', [0, 0, 0, 0])
        matmul1.attribute.append(attr)
        attr = onnx.helper.make_attribute('strides', [1, 1])
        matmul1.attribute.append(attr)
        if node_dict['addFirst']:
            matmul1.input.append(add1.input[0])  # fold the Add bias into the Conv
        else:
            matmul1.input.append(add1.input[1])
        conv_output_shape = [rs_output_shape[0], node_dict['matmulB1_Shape'][1], rs_output_shape[2], rs_output_shape[3]]
        update_tensor_shape(model, matmul1.output[0], conv_output_shape)

        # Add1 ---> Reshape (its bias now lives inside the Conv)
        add1.op_type = 'Reshape'
        del add1.attribute[:]
        rs_output_shape = [conv_output_shape[1], conv_output_shape[3]]
        logger.debug('-----rs_output_shape: {}'.format(rs_output_shape))
        const_shape_name = add1.name + '_reshape_data_'
        const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
                                                     data_type=onnx.TensorProto.INT64,
                                                     dims=[len(rs_output_shape)],
                                                     vals=rs_output_shape)
        model.graph.initializer.append(const_shape_tensor)
        if node_dict['addFirst']:
            add1.input[0] = add1.input[1]
        add1.input[1] = const_shape_name
        update_tensor_shape(model, add1.output[0], rs_output_shape)
        mul_node = node_dict['Mul']
        update_tensor_shape(model, mul_node.output[0], rs_output_shape)

        ### add reshape before MatMul2
        rs2_name = matmul2.name + '_reshape_1_'
        rs2_output_name = rs2_name + '_output_'
        rs2_output_shape = [1, rs_output_shape[0], 1, rs_output_shape[1]]
        rs_output = onnx.helper.make_tensor_value_info(rs2_output_name, onnx.TensorProto.FLOAT, rs2_output_shape)
        const_shape_name = matmul2.name + '_reshape_data_'
        const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
                                                     data_type=onnx.TensorProto.INT64,
                                                     dims=[len(rs2_output_shape)],
                                                     vals=rs2_output_shape)
        model.graph.initializer.append(const_shape_tensor)
        rs2_node = onnx.helper.make_node(
            'Reshape',
            name=rs2_name,
            inputs=[matmul2.input[0], const_shape_name],
            outputs=[rs2_output_name])
        model.graph.value_info.append(rs_output)
        insert_node(model, rs2_node, matmul2)
        matmul2.input[0] = rs2_output_name

        # MatMul2 ---> Conv
        matmul2.op_type = 'Conv'
        logger.debug('++++++convert MatMul to Conv')
        const_x_name = matmul2.name + '_to_conv_x_'
        v = node_dict['inputB2']
        old_dims = [node_dict['matmulB2_Shape'][0], node_dict['matmulB2_Shape'][1]]
        dims_ = [node_dict['matmulB2_Shape'][1], node_dict['matmulB2_Shape'][0], 1, 1]
        if isinstance(v, np.ndarray):
            A = v.reshape(*old_dims)
            A = A.transpose()
            A = A.reshape(*dims_)
            logger.debug('+++A.shape: {}'.format(A.shape))
            A = A.flatten()
        else:
            A = np.array(v).reshape(*old_dims)
            A = A.transpose()
            A = A.reshape(*dims_)
            logger.debug('---A.shape: {}'.format(A.shape))
            A = A.flatten()
            A = A.tolist()
        const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
                                                 data_type=onnx.TensorProto.FLOAT,
                                                 dims=dims_,
                                                 vals=A)
        model.graph.initializer.append(const_x_tensor)
        matmul2.input[1] = const_x_name
        attr = onnx.helper.make_attribute('dilations', [1, 1])
        matmul2.attribute.append(attr)
        attr = onnx.helper.make_attribute('group', 1)
        matmul2.attribute.append(attr)
        attr = onnx.helper.make_attribute('kernel_shape', [1, 1])
        matmul2.attribute.append(attr)
        attr = onnx.helper.make_attribute('pads', [0, 0, 0, 0])
        matmul2.attribute.append(attr)
        attr = onnx.helper.make_attribute('strides', [1, 1])
        matmul2.attribute.append(attr)
        if node_dict['addFirst2']:
            B = add2.input[0]
        else:
            B = add2.input[1]
        matmul2.input.append(B)  # fold the Add bias into the Conv
        conv_output_shape = [rs2_output_shape[0], node_dict['matmulB2_Shape'][1], rs2_output_shape[2], rs2_output_shape[3]]
        update_tensor_shape(model, matmul2.output[0], conv_output_shape)

        # Add2 ---> Reshape (its bias now lives inside the Conv)
        add2.op_type = 'Reshape'
        del add2.attribute[:]
        rs2_output_shape = [conv_output_shape[1], conv_output_shape[3]]
        const_shape_name = add2.name + '_reshape_data_'
        const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
                                                     data_type=onnx.TensorProto.INT64,
                                                     dims=[len(rs2_output_shape)],
                                                     vals=rs2_output_shape)
        model.graph.initializer.append(const_shape_tensor)
        if node_dict['addFirst2']:
            add2.input[0] = add2.input[1]
        add2.input[1] = const_shape_name
        update_tensor_shape(model, add2.output[0], rs2_output_shape)

        ###### update tensor shapes along the GELU chain (every tensor is now transposed)
        pow_output_shape = values.get_tensor_shape_by_name(model, pow_node.output[0])
        new_shape = [pow_output_shape[1], pow_output_shape[0]]
        update_tensor_shape(model, pow_node.output[0], new_shape)
        mul_node, ok = get_next_node_by_output(model, pow_node.output[0])
        if ok == 0 and mul_node.op_type == 'Mul':
            mul_output_shape = values.get_tensor_shape_by_name(model, mul_node.output[0])
            new_shape = [mul_output_shape[1], mul_output_shape[0]]
            update_tensor_shape(model, mul_node.output[0], new_shape)
            add_node_internal, ok = get_next_node_by_output(model, mul_node.output[0])
            if ok == 0 and add_node_internal.op_type == 'Add':
                addi_output_shape = values.get_tensor_shape_by_name(model, add_node_internal.output[0])
                new_shape = [addi_output_shape[1], addi_output_shape[0]]
                update_tensor_shape(model, add_node_internal.output[0], new_shape)
                mul_node1, ok = get_next_node_by_output(model, add_node_internal.output[0])
                if ok == 0 and mul_node1.op_type == 'Mul':
                    mul1_output_shape = values.get_tensor_shape_by_name(model, mul_node1.output[0])
                    new_shape = [mul1_output_shape[1], mul1_output_shape[0]]
                    update_tensor_shape(model, mul_node1.output[0], new_shape)
                    tanh_node, ok = get_next_node_by_output(model, mul_node1.output[0])
                    if ok == 0 and tanh_node.op_type == 'Tanh':
                        tanh_output_shape = values.get_tensor_shape_by_name(model, tanh_node.output[0])
                        new_shape = [tanh_output_shape[1], tanh_output_shape[0]]
                        update_tensor_shape(model, tanh_node.output[0], new_shape)
                        add_node2, ok = get_next_node_by_output(model, tanh_node.output[0])
                        if ok == 0 and add_node2.op_type == 'Add':
                            add_output_shape = values.get_tensor_shape_by_name(model, add_node2.output[0])
                            new_shape = [add_output_shape[1], add_output_shape[0]]
                            update_tensor_shape(model, add_node2.output[0], new_shape)
                            mul_node2, ok = get_next_node_by_output(model, add_node2.output[0])
                            if ok == 0 and mul_node2.op_type == 'Mul':
                                mul2_output_shape = values.get_tensor_shape_by_name(model, mul_node2.output[0])
                                new_shape = [mul2_output_shape[1], mul2_output_shape[0]]
                                update_tensor_shape(model, mul_node2.output[0], new_shape)

        ###### insert Transpose before ReduceMean and Sub
        update_tensor_shape(model, nextAdd.output[0], rs2_output_shape)
        rm_sub, ok = get_all_next_node_by_output(model, nextAdd.output[0])
        if ok == 0 and len(rm_sub) == 3:
            logger.debug('got reducemean and sub node---')
            sub_node = None
            rm_node = None
            mul_node = None
            for n in rm_sub:
                if n.op_type == 'Sub':
                    sub_node = n
                if n.op_type == 'ReduceMean':
                    rm_node = n
                if n.op_type == 'Mul':
                    mul_node = n
            if sub_node is not None and rm_node is not None and mul_node is not None:
                ### add transpose
                ts3_name = nextAdd.name + '_transpose_'
                ts3_output_name = ts3_name + '_output_'
                add3_output_shape = values.get_tensor_shape_by_name(model, nextAdd.output[0])
                ts3_output_shape = [add3_output_shape[1], add3_output_shape[0]]
                ts3_output = onnx.helper.make_tensor_value_info(ts3_output_name, onnx.TensorProto.FLOAT, ts3_output_shape)
                ts3_node = onnx.helper.make_node(
                    'Transpose',
                    name=ts3_name,
                    inputs=[nextAdd.output[0]],
                    outputs=[ts3_output_name],
                    perm=[1, 0])
                model.graph.value_info.append(ts3_output)
                insert_node(model, ts3_node, sub_node)
                sub_node.input[0] = ts3_output_name
                rm_node.input[0] = ts3_output_name
                mul_node.input[0] = ts3_output_name
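
# Why the MatMul->Conv rewrite above is numerically safe: a [S, C] x [C, N]
# MatMul equals a 1x1 Conv applied to the input laid out as [1, C, 1, S] with
# the weight transposed into [N, C, 1, 1] -- exactly the Transpose/Reshape
# pair inserted above. A self-contained numpy check (hypothetical helper,
# shapes chosen arbitrarily; not called anywhere in this file):
def _check_matmul_equals_1x1_conv(S=5, C=8, N=3):
    x = np.random.randn(S, C).astype(np.float32)
    w = np.random.randn(C, N).astype(np.float32)
    y_matmul = x @ w                      # [S, N]
    x_conv = x.T.reshape(1, C, 1, S)      # the inserted Transpose + Reshape
    w_conv = w.T.reshape(N, C, 1, 1)      # the transform applied to inputB
    # a 1x1 conv is per-position channel mixing: contract over the C axis
    y_conv = np.einsum('nchw,kc->nkhw', x_conv, w_conv.reshape(N, C))
    assert np.allclose(y_matmul, y_conv[0, :, 0, :].T, atol=1e-4)
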
def get_last_group_pattern_four(model):
    graph_output = []
    node_dict = {}
    res = -1
    for o in model.graph.output:
        graph_output.append(o.name)
    for node in model.graph.node:
        if node.output[0] in graph_output:
            if node.op_type == 'Squeeze':
                split_node, ok = get_prev_node_by_input(model, node.input[0])
                if ok == 0 and split_node.op_type == 'Split':
                    logger.debug('got Split node: {}'.format(split_node.name))
                    node_dict['Split'] = split_node
                    tp_node, ok = get_prev_node_by_input(model, split_node.input[0])
                    if ok == 0 and tp_node.op_type == 'Transpose':
                        rs_node, ok = get_prev_node_by_input(model, tp_node.input[0])
                        if ok == 0 and rs_node.op_type == 'Reshape':
                            node_dict['Reshape'] = rs_node
                            add_node, ok = get_prev_node_by_input(model, rs_node.input[0])
                            if ok == 0 and add_node.op_type == 'Add':
                                logger.debug('get_last_group_pattern_four, got Add node: {}'.format(add_node.name))
                                node_dict['Add'] = add_node
                                matmul_node, ok = get_prev_node_by_input(model, add_node.input[0])
                                if ok == 0 and matmul_node.op_type == 'MatMul':
                                    logger.debug('get_last_group_pattern_four, got MatMul node: {}'.format(matmul_node.name))
                                    shapeA = values.get_tensor_shape_by_name(model, matmul_node.input[0])
                                    inputB, shapeB = values.get_init_value_and_shape(model, matmul_node.input[1])
                                    if isinstance(inputB, list) and inputB == []:
                                        logger.debug('inputB is not in initializer')
                                        inputB = values.get_constant_value(model, matmul_node.input[1])
                                    if len(shapeA) == 2 and len(shapeB) == 2:
                                        node_dict['MatMul'] = matmul_node
                                        node_dict['matmulA_Shape'] = shapeA
                                        node_dict['inputB'] = inputB
                                        node_dict['matmulB_Shape'] = shapeB
                                        rs_node2, ok = get_prev_node_by_input(model, matmul_node.input[0])
                                        if ok == 0 and rs_node2.op_type == 'Reshape':
                                            logger.debug('get_last_group_pattern_four, got Reshape node2: {}'.format(rs_node2.name))
                                            node_dict['Reshape2'] = rs_node2
                                            res = 0
                                            break
    return node_dict, res
# Reshape->MatMul->Add->Reshape->Transpose->Split
def handle_last_group_pattern_four(model):
    node_dict, ok = get_last_group_pattern_four(model)
    if ok == 0:
        logger.debug('start handle_last_group')
        matmul_node = node_dict['MatMul']
        rs_node = node_dict['Reshape']
        add_node = node_dict['Add']
        rs_node2 = node_dict['Reshape2']

        ### add transpose
        ts_name = rs_node2.name + '_transpose_'
        ts_output_name = ts_name + '_output_'
        rs2_output_shape = values.get_tensor_shape_by_name(model, rs_node2.output[0])
        ts_output_shape = [rs2_output_shape[1], rs2_output_shape[0]]
        transpose_output = onnx.helper.make_tensor_value_info(ts_output_name, onnx.TensorProto.FLOAT, ts_output_shape)
        ts_node = onnx.helper.make_node(
            'Transpose',
            name=ts_name,
            inputs=[rs_node2.output[0]],
            outputs=[ts_output_name],
            perm=[1, 0])
        model.graph.value_info.append(transpose_output)

        ### add reshape
        rs_name = rs_node2.name + '_reshape_2_'
        rs_output_name = rs_name + '_output_'
        rs_output_shape = [1, ts_output_shape[0], 1, ts_output_shape[1]]
        rs_output = onnx.helper.make_tensor_value_info(rs_output_name, onnx.TensorProto.FLOAT, rs_output_shape)
        const_shape2_name = rs_node2.name + '_reshape2_data_'
        const_shape2_tensor = onnx.helper.make_tensor(name=const_shape2_name,
                                                      data_type=onnx.TensorProto.INT64,
                                                      dims=[len(rs_output_shape)],
                                                      vals=rs_output_shape)
        model.graph.initializer.append(const_shape2_tensor)
        rs_node_ = onnx.helper.make_node(
            'Reshape',
            name=rs_name,
            inputs=[ts_output_name, const_shape2_name],
            outputs=[rs_output_name])
        model.graph.value_info.append(rs_output)
        insert_node(model, rs_node_, matmul_node)
        matmul_node.input[0] = rs_output_name
        insert_node(model, ts_node, rs_node_)

        # MatMul --> Conv
        matmul_node.op_type = 'Conv'
        const_x_name = matmul_node.name + '_to_conv_x_'
        v = node_dict['inputB']
        # [K, N] MatMul weight ---> [N, K, 1, 1] Conv weight
        old_dims = [node_dict['matmulB_Shape'][0], node_dict['matmulB_Shape'][1]]
        dims_ = [node_dict['matmulB_Shape'][1], node_dict['matmulB_Shape'][0], 1, 1]
        if isinstance(v, np.ndarray):
            A = v.reshape(*old_dims)
            A = A.transpose()
            A = A.reshape(*dims_)
            logger.debug('+++A.shape: {}'.format(A.shape))
            A = A.flatten()
        else:
            A = np.array(v).reshape(*old_dims)
            A = A.transpose()
            A = A.reshape(*dims_)
            logger.debug('---A.shape: {}'.format(A.shape))
            A = A.flatten()
            A = A.tolist()
        const_x_tensor = onnx.helper.make_tensor(name=const_x_name,
                                                 data_type=onnx.TensorProto.FLOAT,
                                                 dims=dims_,
                                                 vals=A)
        model.graph.initializer.append(const_x_tensor)
        matmul_node.input[1] = const_x_name
        attr = onnx.helper.make_attribute('dilations', [1, 1])
        matmul_node.attribute.append(attr)
        attr = onnx.helper.make_attribute('group', 1)
        matmul_node.attribute.append(attr)
        attr = onnx.helper.make_attribute('kernel_shape', [1, 1])
        matmul_node.attribute.append(attr)
        attr = onnx.helper.make_attribute('pads', [0, 0, 0, 0])
        matmul_node.attribute.append(attr)
        attr = onnx.helper.make_attribute('strides', [1, 1])
        matmul_node.attribute.append(attr)
        matmul_node.input.append(add_node.input[1])  # fold the Add bias into the Conv
        conv_output_shape = rs_output_shape
        conv_output_shape[1] = node_dict['matmulB_Shape'][1]
        update_tensor_shape(model, matmul_node.output[0], conv_output_shape)

        # Add --> Reshape (its bias now lives inside the Conv)
        add_node.op_type = 'Reshape'
        const_shape_name = add_node.name + '_to_reshape_'
        add_output_shape = values.get_tensor_shape_by_name(model, add_node.output[0])
        rs2_output_shape = [add_output_shape[1], add_output_shape[0]]
        const_shape_tensor = onnx.helper.make_tensor(name=const_shape_name,
                                                     data_type=onnx.TensorProto.INT64,
                                                     dims=[len(rs2_output_shape)],
                                                     vals=rs2_output_shape)
        model.graph.initializer.append(const_shape_tensor)
        add_node.input[1] = const_shape_name
        update_tensor_shape(model, add_node.output[0], rs2_output_shape)

        ### add transpose
        ts2_name = add_node.name + '_transpose_'
        ts2_output_name = ts2_name + '_output_'
        ts2_output_shape = [rs2_output_shape[1], rs2_output_shape[0]]
        transpose_output = onnx.helper.make_tensor_value_info(ts2_output_name, onnx.TensorProto.FLOAT, ts2_output_shape)
        ts2_node = onnx.helper.make_node(
            'Transpose',
            name=ts2_name,
            inputs=[add_node.output[0]],
            outputs=[ts2_output_name],
            perm=[1, 0])
        model.graph.value_info.append(transpose_output)
        insert_node(model, ts2_node, rs_node)
        rs_node.input[0] = ts2_output_name
'''
if __name__ == "__main__":
    #model = onnx.load('/home/zqiu/models/bert_sst2_sim.onnx')
    #model = onnx.load('./bert_sst2_sub1.onnx')
    #model = onnx.load('./decoder_model_bs10_sim.onnx')
    model = onnx.load('./bert_squad_v1_sim1.onnx')
    #model = onnx.load('./bert_sub2.onnx')
    #model = onnx.load('/home/zqiu/models/bert_cls_sim1.onnx')
    mha_optimizer(model)
    onnx.save(model, './hs3.onnx')
'''
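
# A typical way to exercise the pattern-four path directly and sanity-check
# the mutated graph (sketch only -- the model path is a placeholder, and
# onnx.checker may be strict about the fixed shapes these rewrites bake in):
#
#     model = onnx.load('model_sim.onnx')
#     if match_mha_block_pattern_four(model) == 0:
#         handle_mul_add_block_pattern_four(model)
#         handle_last_group_pattern_four(model)
#     onnx.checker.check_model(model)
#     onnx.save(model, 'model_opt.onnx')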