1 Star 0 Fork 0

Whiskey_Priest/Mask-Predict

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
score.py 3.08 KB
一键复制 编辑 原始数据 按行查看 历史
Yinhan Liu 提交于 2019-09-04 23:12 . Initial commit
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
BLEU scoring of generated translations against reference translations.
"""
import argparse
import os
import sys
from fairseq import bleu
from fairseq.data import dictionary
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the BLEU scoring script.

    Options defined:
        -s / --sys        system output to score; defaults to '-'
        -r / --ref        reference file (required)
        -o / --order      maximum n-gram order to consider (int, default 4)
        --ignore-case     flag requesting case-insensitive scoring
        --sacrebleu       flag requesting scoring with sacrebleu
        --sentence-bleu   flag requesting per-sentence BLEU output

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.')
    # fmt: off
    parser.add_argument('-s', '--sys', default='-', help='system output')
    parser.add_argument('-r', '--ref', required=True, help='references')
    parser.add_argument('-o', '--order', default=4, metavar='N',
                        type=int, help='consider ngrams up to this order')
    parser.add_argument('--ignore-case', action='store_true',
                        help='case-insensitive scoring')
    parser.add_argument('--sacrebleu', action='store_true',
                        help='score with sacrebleu')
    parser.add_argument('--sentence-bleu', action='store_true',
                        help='report sentence-level BLEUs (i.e., with +1 smoothing)')
    # fmt: on
    return parser
def main():
    """Score system output against a reference file and print the BLEU result.

    Parses the command line with ``get_parser()``, echoes the parsed
    arguments, checks that the input files exist, then picks one of three
    scoring modes:

    * ``--sacrebleu``: hands both open files to ``sacrebleu.corpus_bleu``
      and prints whatever it returns.
    * ``--sentence-bleu``: prints one ``<index> <result string>`` line per
      sentence pair, resetting the scorer with ``one_init=True`` each time.
    * default: accumulates every sentence pair into a single scorer and
      prints one result string.

    System output is read from stdin when ``--sys`` is ``'-'``, otherwise
    from the named file.

    Raises:
        FileNotFoundError: if the system output file (when not ``'-'``) or
            the reference file does not exist.
    """
    parser = get_parser()
    args = parser.parse_args()
    print(args)

    # Explicit checks rather than ``assert`` so that validation still runs
    # when Python is started with -O (which strips assert statements).
    if args.sys != '-' and not os.path.exists(args.sys):
        raise FileNotFoundError(
            "System output file {} does not exist".format(args.sys))
    if not os.path.exists(args.ref):
        raise FileNotFoundError(
            "Reference file {} does not exist".format(args.ref))

    # Named ``vocab`` (not ``dict``) so the builtin is not shadowed.
    vocab = dictionary.Dictionary()

    def readlines(fd):
        """Yield lines of *fd*, lower-cased when --ignore-case is set."""
        # Iterate the file object directly so lines are streamed instead of
        # being loaded into memory all at once.
        for line in fd:
            yield line.lower() if args.ignore_case else line

    if args.sacrebleu:
        # Imported lazily so sacrebleu is only needed when requested.
        import sacrebleu

        def score(fdsys):
            # NOTE(review): the raw file objects are passed straight through
            # here, so --ignore-case is not applied on this path.
            with open(args.ref) as fdref:
                print(sacrebleu.corpus_bleu(fdsys, [fdref]))
    elif args.sentence_bleu:
        def score(fdsys):
            with open(args.ref) as fdref:
                scorer = bleu.Scorer(vocab.pad(), vocab.eos(), vocab.unk())
                pairs = zip(readlines(fdsys), readlines(fdref))
                for i, (sys_tok, ref_tok) in enumerate(pairs):
                    # Fresh statistics for every sentence; one_init=True is
                    # the "+1 smoothing" mentioned in the option's help text.
                    scorer.reset(one_init=True)
                    sys_tok = vocab.encode_line(sys_tok)
                    ref_tok = vocab.encode_line(ref_tok)
                    scorer.add(ref_tok, sys_tok)
                    print(i, scorer.result_string(args.order))
    else:
        def score(fdsys):
            with open(args.ref) as fdref:
                scorer = bleu.Scorer(vocab.pad(), vocab.eos(), vocab.unk())
                # zip stops at the shorter input, so surplus lines in either
                # file are ignored.
                for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
                    sys_tok = vocab.encode_line(sys_tok)
                    ref_tok = vocab.encode_line(ref_tok)
                    scorer.add(ref_tok, sys_tok)
                print(scorer.result_string(args.order))

    if args.sys == '-':
        score(sys.stdin)
    else:
        with open(args.sys, 'r') as f:
            score(f)
# Run the scorer only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/dasdsadasdas/Mask-Predict.git
git@gitee.com:dasdsadasdas/Mask-Predict.git
dasdsadasdas
Mask-Predict
Mask-Predict
main

搜索帮助