1 Star 0 Fork 1

Lengien/pytorch_jasper

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
pt2mindspore.py 2.12 KB
一键复制 编辑 原始数据 按行查看 历史
Lengien 提交于 2022-03-31 10:19 . first commit
import imp
from common.helpers import (Checkpointer, greedy_wer, num_weights, print_once,
process_evaluation_epoch)
from jasper import config
from jasper.model import Jasper
import torch
import torch.distributed as dist
def print_once(msg):
    """Print ``msg`` unless this is a non-zero rank of an initialized group.

    When ``torch.distributed`` has not been initialized the message is always
    printed; otherwise only the process whose rank is 0 prints it.

    Note: this definition shadows the ``print_once`` imported from
    ``common.helpers`` at the top of the file.
    """
    # Short-circuits exactly like the original: the rank is only queried when
    # a process group is actually initialized.
    on_secondary_rank = dist.is_initialized() and dist.get_rank() != 0
    if on_secondary_rank:
        return
    print(msg)
def add_ctc_blank(symbols):
    """Return ``symbols`` with the ``'<BLANK>'`` entry appended at the end.

    The result is built with ``+`` concatenation, so ``symbols`` itself is
    left untouched and a new sequence is returned.
    """
    blank_entry = ['<BLANK>']
    return symbols + blank_entry
def convert_v1_state_dict(state_dict):
    """Translate a v1-style checkpoint state dict to the current key layout.

    Entries whose key starts with ``acoustic_model.`` or
    ``audio_preprocessor.`` are dropped. For the remaining entries a leading
    ``jasper_encoder.encoder.`` is replaced by ``encoder.layers.`` and a
    leading ``jasper_decoder.decoder_layers.`` by ``decoder.layers.``; every
    other key is copied unchanged.

    Args:
        state_dict: Mapping from parameter name to value.

    Returns:
        A new dict with the renamed keys; the input mapping is not modified.
    """
    # Literal prefix matching replaces the earlier ``re.sub`` calls: those
    # patterns left ``.`` unescaped (so it matched any character) and needed
    # an ``re`` import that was missing, which raised NameError at runtime.
    skipped_prefixes = ('acoustic_model.', 'audio_preprocessor.')
    renamed_prefixes = (
        ('jasper_encoder.encoder.', 'encoder.layers.'),
        ('jasper_decoder.decoder_layers.', 'decoder.layers.'),
    )

    converted = {}
    for key, value in state_dict.items():
        if key.startswith(skipped_prefixes):
            continue
        for old_prefix, new_prefix in renamed_prefixes:
            if key.startswith(old_prefix):
                key = new_prefix + key[len(old_prefix):]
        converted[key] = value
    return converted
def load(self, fpath, model, ema_model, optimizer, scaler, meta):
    """Restore training state from the checkpoint file at ``fpath``.

    The checkpoint is read onto the CPU, its ``'state_dict'`` is converted
    with ``convert_v1_state_dict`` and loaded strictly into ``model``. When
    ``ema_model`` is given it is loaded from ``'ema_state_dict'`` if the
    checkpoint has one, otherwise from the regular ``'state_dict'`` (two
    warnings are printed in that case). ``optimizer`` and ``scaler`` are
    restored from their checkpoint entries, and ``meta`` is updated in place.

    Args:
        self: Unused by the body.
        fpath: Path handed to ``torch.load``.
        model: Model to restore; its ``module`` attribute is used when present.
        ema_model: Optional EMA model, handled like ``model``; may be ``None``.
        optimizer: Object whose ``load_state_dict`` gets ``checkpoint['optimizer']``.
        scaler: Object whose ``load_state_dict`` gets ``checkpoint['scaler']``.
        meta: Dict receiving ``'start_epoch'`` and ``'best_wer'``; it must
            already contain ``'best_wer'``, which is used as the fallback.
    """
    def strip_wrapper(wrapped):
        # Use the inner ``module`` when the object exposes one.
        return getattr(wrapped, 'module', wrapped)

    print_once(f'Loading model from {fpath}')
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(fpath, map_location="cpu")

    model_weights = convert_v1_state_dict(ckpt['state_dict'])
    strip_wrapper(model).load_state_dict(model_weights, strict=True)

    if ema_model is not None:
        has_ema_weights = ckpt.get('ema_state_dict') is not None
        if not has_ema_weights:
            print_once('WARNING: EMA weights not found in the checkpoint.')
            print_once('WARNING: Initializing EMA model with regular params.')
        ema_source = 'ema_state_dict' if has_ema_weights else 'state_dict'
        ema_weights = convert_v1_state_dict(ckpt[ema_source])
        strip_wrapper(ema_model).load_state_dict(ema_weights, strict=True)

    optimizer.load_state_dict(ckpt['optimizer'])
    scaler.load_state_dict(ckpt['scaler'])

    meta['start_epoch'] = ckpt.get('epoch')
    meta['best_wer'] = ckpt.get('best_wer', meta['best_wer'])
# Build the Jasper model described by the YAML config.
# The path uses a forward slash: the earlier 'configs\jasper...' form relied on
# the invalid escape sequence '\j' and only resolved on Windows.
cfg = config.load('configs/jasper10x5dr_speedp-online_speca.yaml')

# The decoder's class count is the config's label list plus the '<BLANK>' entry.
symbols = add_ctc_blank(cfg['labels'])
model = Jasper(
    encoder_kw=config.encoder(cfg),
    decoder_kw=config.decoder(cfg, n_classes=len(symbols)),
)
# model.cuda()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/legendlengien/pytorch_jasper.git
git@gitee.com:legendlengien/pytorch_jasper.git
legendlengien
pytorch_jasper
pytorch_jasper
master

搜索帮助