import re

import torch
import torch.distributed as dist

from common.helpers import (Checkpointer, greedy_wer, num_weights,
                            process_evaluation_epoch)
from jasper import config
from jasper.model import Jasper

def print_once(msg):
    """Print only on rank 0 when running under torch.distributed."""
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(msg)


def add_ctc_blank(symbols):
    """Append the CTC blank token to the label set."""
    return symbols + ['<BLANK>']
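
# Note: appending the blank last means its class index is len(symbols) - 1,
# and the decoder's n_classes (len(symbols) below) counts it.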

def convert_v1_state_dict(state_dict):
    """Map checkpoint keys from the v1 Jasper layout to the current one."""
    rules = [
        ('^jasper_encoder.encoder.', 'encoder.layers.'),
        ('^jasper_decoder.decoder_layers.', 'decoder.layers.'),
    ]
    ret = {}
    for k, v in state_dict.items():
        # Drop entries whose prefixes no longer exist in the current model
        if k.startswith('acoustic_model.'):
            continue
        if k.startswith('audio_preprocessor.'):
            continue
        for pattern, to in rules:
            k = re.sub(pattern, to, k)
        ret[k] = v
    return ret
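
# A minimal sketch of the conversion on an illustrative v1 key (the key name
# below is hypothetical, not taken from a real checkpoint):
#
#   convert_v1_state_dict({'jasper_encoder.encoder.0.conv.weight': w})
#   # -> {'encoder.layers.0.conv.weight': w}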

# Appears to be Checkpointer.load lifted out of its class (see common.helpers);
# `self` is unused here and kept only to preserve the original signature.
def load(self, fpath, model, ema_model, optimizer, scaler, meta):
    print_once(f'Loading model from {fpath}')
    checkpoint = torch.load(fpath, map_location="cpu")

    # DistributedDataParallel wraps the model in a `.module` attribute
    unwrap_ddp = lambda model: getattr(model, 'module', model)

    state_dict = convert_v1_state_dict(checkpoint['state_dict'])
    unwrap_ddp(model).load_state_dict(state_dict, strict=True)

    if ema_model is not None:
        if checkpoint.get('ema_state_dict') is not None:
            key = 'ema_state_dict'
        else:
            key = 'state_dict'
            print_once('WARNING: EMA weights not found in the checkpoint.')
            print_once('WARNING: Initializing EMA model with regular params.')
        state_dict = convert_v1_state_dict(checkpoint[key])
        unwrap_ddp(ema_model).load_state_dict(state_dict, strict=True)

    optimizer.load_state_dict(checkpoint['optimizer'])
    scaler.load_state_dict(checkpoint['scaler'])

    meta['start_epoch'] = checkpoint.get('epoch')
    meta['best_wer'] = checkpoint.get('best_wer', meta['best_wer'])

cfg = config.load('configs/jasper10x5dr_speedp-online_speca.yaml')
symbols = add_ctc_blank(cfg['labels'])
model = Jasper(encoder_kw=config.encoder(cfg),
decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
# Move to GPU before creating the optimizer or starting training
if torch.cuda.is_available():
    model.cuda()
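
# A minimal end-to-end sketch, assuming a hypothetical v1 checkpoint path
# ('checkpoints/jasper_v1.pt' is illustrative) and no EMA model. `load`
# above keeps an unused `self`, so None is passed for it.
if __name__ == '__main__':
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    meta = {'best_wer': float('inf'), 'start_epoch': 0}
    load(None, 'checkpoints/jasper_v1.pt', model, ema_model=None,
         optimizer=optimizer, scaler=scaler, meta=meta)
    print_once(f"Resumed from epoch {meta['start_epoch']}, "
               f"best WER: {meta['best_wer']}")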