import argparse
import logging
import os

import jieba
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from rouge import Rouge
from transformers import GPT2LMHeadModel

from Tokenize import Tokenize

# Special-token ids; these must match the entries in the vocabulary file
# passed via --vocab_path (here: [PAD]=0, [CLS]=2, [SEP]=3).
PAD = '[PAD]'
pad_id = 0
sep_id = 3
cls_id = 2

def set_interact_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0,1', type=str, required=False, help='GPU device ids used for generation')
    parser.add_argument('--temperature', default=1, type=float, required=False, help='sampling temperature')
    parser.add_argument('--topk', default=0, type=int, required=False, help='keep only the top-k highest-probability tokens (0 disables top-k filtering)')
    parser.add_argument('--topp', default=0.9, type=float, required=False, help='nucleus (top-p) cumulative-probability threshold')
    parser.add_argument('--model_config', default='summary_model/config.json', type=str, required=False,
                        help='model configuration file')
    parser.add_argument('--log_path', default='log/evaluate.log', type=str, required=False, help='where to write the log file')
    parser.add_argument('--vocab_path', default='chinese_wobert_L-12_H-768_A-12/vocab.txt', type=str, required=False, help='vocabulary file')
    parser.add_argument('--model_path', default='summary_model/', type=str, required=False, help='model directory')
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
                        help='repetition penalty; increase it if the generated text repeats itself')
    parser.add_argument('--seed', type=int, default=None, help='random seed, set for reproducible results')
    parser.add_argument('--max_len', type=int, default=50, help='maximum summary length; longer outputs are truncated')
    parser.add_argument('--no_cuda', action='store_true', default=False, help='do not use the GPU for prediction')
    parser.add_argument('--article', default='data/data_for_evaluate/article.txt', help='news articles, one per line')
    parser.add_argument('--summary', default='data/data_for_evaluate/summary.txt', help='corresponding reference titles, one per line')
    parser.add_argument('--result', default='data/data_for_evaluate/result.txt', help='file for the generated results')
    parser.add_argument('--count', default=1000, type=int, help='number of samples used to evaluate the model')
    return parser.parse_args()
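
# Example invocation (a sketch: the script name "evaluate.py" is an assumption,
# and the default paths above must exist in your checkout):
#   python evaluate.py --topk 5 --topp 0.9 --max_len 50 --count 1000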

def create_logger(args):
    """
    Write log output both to a log file and to the console.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')
    # handler that writes to the log file
    file_handler = logging.FileHandler(
        filename=args.log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)
    # handler that writes to the console
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)
    return logger
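
# With the formatter above, a typical log line looks like (timestamp illustrative):
#   2024-01-01 12:00:00,000 - INFO - using device:cuda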

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # safety check
    if top_k > 0:
        # torch.topk() returns the top_k largest elements of the last dimension
        # as a (values, indices) pair; '...' lets PyTorch infer the leading dims.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value  # set every logit outside the top-k to -inf
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)  # sort logits in descending order
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens whose cumulative probability exceeds the threshold ...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ... but shift the mask right so the first token above the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
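
# A minimal sanity check for the filter above (hypothetical helper, not called by
# the script). With logits [1, 2, 3, 4]: top_k=2 keeps only the two largest logits,
# and top_p=0.9 drops just the lowest-probability token, because the mask shift
# keeps the first token whose cumulative probability crosses the 0.9 threshold.
def _demo_top_k_top_p():
    logits = torch.tensor([1.0, 2.0, 3.0, 4.0])
    print(top_k_top_p_filtering(logits.clone(), top_k=2))    # tensor([-inf, -inf, 3., 4.])
    print(top_k_top_p_filtering(logits.clone(), top_p=0.9))  # tensor([-inf, 2., 3., 4.])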

def main():
    args = set_interact_args()
    logger = create_logger(args)
    # Restrict the visible GPUs before any CUDA call initialises the driver.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    # Use the GPU when the user asked for it and one is available.
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    tokenizer = Tokenize(
        token_dict=args.vocab_path,
        pre_tokenize=lambda s: jieba.cut(s)
    )
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()
    f_article = open(args.article, 'r', encoding='utf-8')
    f_summary = open(args.summary, 'r', encoding='utf-8')
    f_w = open(args.result, 'w', encoding='utf-8')
    count = 0
    rouge = Rouge()
    scores_f = []
    scores_p = []
    scores_r = []
    print('***********************evaluate start************************')
    oom_time = 0
    for line_article, line_summary in zip(f_article, f_summary):
        if count >= args.count:
            break
        try:
            text = str(line_article)
            if len(text):
                text = text[:1000]
            input_ids = [cls_id]  # every input starts with [CLS]
            input_ids.extend(tokenizer.tokens_to_ids(text))
            input_ids.append(sep_id)
            curr_input_tensor = torch.tensor(input_ids).long().to(device)
            generated = []
            # Generate at most max_len tokens; no gradients are needed for evaluation.
            with torch.no_grad():
                for _ in range(args.max_len):
                    outputs = model(input_ids=curr_input_tensor)
                    next_token_logits = outputs[0][-1, :]
                    # Apply a repetition penalty: lower the logit of every token
                    # that already appears in the generated output.
                    for token_id in set(generated):
                        next_token_logits[token_id] /= args.repetition_penalty
                    next_token_logits = next_token_logits / args.temperature
                    # Set the logit of [UNK] to -inf so the model can never predict it.
                    next_token_logits[tokenizer.token_to_id('[UNK]')] = -float('Inf')
                    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
                    # torch.multinomial draws num_samples elements without replacement;
                    # higher-weight tokens are more likely to be drawn. It returns indices.
                    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                    if next_token == sep_id:
                        break
                    generated.append(next_token.item())
                    curr_input_tensor = torch.cat((curr_input_tensor, next_token), dim=0)
            text = tokenizer.ids_to_tokens(generated)
            print('summary:' + ''.join(text))
            # Segment both texts with jieba and pad the shorter one so the two
            # token sequences have the same length before ROUGE scoring.
            text_seg = jieba.lcut(str(''.join(text)))
            line_summary_seg = jieba.lcut(str(line_summary))
            while len(text_seg) < len(line_summary_seg):
                text_seg.append(' ')
            while len(line_summary_seg) < len(text_seg):
                line_summary_seg.append(' ')
            print(text_seg)
            print(line_summary_seg)
            rouge_score = rouge.get_scores(' '.join(text_seg), ' '.join(line_summary_seg))
            logger.info(str(count) + ': {}'.format(rouge_score[0]["rouge-l"]['f']))
            f_w.write('summary:' + ''.join(text) + str(count) + ': {}'.format(rouge_score[0]["rouge-l"]['f']) + '\n')
            scores_f.append(rouge_score[0]["rouge-1"]['f'])
            scores_p.append(rouge_score[0]["rouge-1"]['p'])
            scores_r.append(rouge_score[0]["rouge-1"]['r'])
            count = count + 1
        except RuntimeError as exception:
            if "out of memory" in str(exception):
                oom_time += 1
                logger.info("WARNING: ran out of memory, times: {}".format(oom_time))
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
            else:
                logger.info(str(exception))
                raise exception
        except KeyboardInterrupt:
            break
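    # With the pip `rouge` package, get_scores returns a list with one dict per
    # hypothesis/reference pair, e.g.
    #   [{'rouge-1': {'r': ..., 'p': ..., 'f': ...}, 'rouge-2': {...}, 'rouge-l': {...}}],
    # which is why the loop above indexes rouge_score[0]["rouge-1"]['f'] and so on.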
    f_article.close()
    f_summary.close()
    f_w.close()
    # Size the x axis by the number of samples actually evaluated, which may be
    # fewer than args.count if the input files are short or the run is interrupted.
    x = np.arange(1, len(scores_f) + 1)
    plt.title('f1')
    plt.scatter(x, scores_f, marker='.')
    plt.show()
    plt.title('p')
    plt.scatter(x, scores_p, marker='.')
    plt.show()
    plt.hist(scores_p, bins=10, rwidth=0.8)
    plt.title('precisions')
    plt.show()


if __name__ == '__main__':
    main()