# (removed web-page scrape artifact: "代码拉取完成,页面将自动刷新" — not part of the source)
import jieba
import nltk
import torch
import string
import time
import numpy as np
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
from nltk.tokenize.treebank import TreebankWordTokenizer, TreebankWordDetokenizer
from nltk.parse import CoreNLPDependencyParser
from utils.bert import perturb
from utils.youdaotranslator import youdaotranslate
from utils.distance import depDistance
# initializing
# Chinese dependency parser backed by a CoreNLP server that must already be
# running locally on port 9000 (started separately).
chi_parser = CoreNLPDependencyParser(url='http://localhost:9000')
# nltk.download('averaged_perceptron_tagger')
# NOTE(review): tokenizer/detokenizer are built but never used in the visible
# code — presumably leftovers or used elsewhere; confirm before removing.
tokenizer = TreebankWordTokenizer()
detokenizer = TreebankWordDetokenizer()
# test filename
# testType: name of the input corpus file under ./dataset/
# threshold: dependency-tree distance above which a derived translation is
# flagged as a suspected translation error.
testType = 'test'
threshold = 3.0
# for counting
total = 0  # source sentences processed
valid = 0  # source sentences that yielded at least one derived sentence
derivated = 0  # total derived (perturbed) sentences generated
valid_error = 0  # suspected translation errors (distance > threshold)
print(f'initialize finished')
def main_loop(L):
    """Run one metamorphic-testing round over a batch of English sentences.

    For each sentence in ``L``: translate it to Chinese, word-segment the
    translation, and build its dependency tree; then generate BERT-perturbed
    variants of the English sentence, translate and parse those the same way,
    and flag a suspected translation error whenever the dependency-tree
    distance between the source translation and a derived translation exceeds
    the module-level ``threshold``. Flagged cases are appended to
    ``error.txt``; the module-level counters ``total``, ``valid``,
    ``derivated`` and ``valid_error`` are updated in place.

    Args:
        L: list of untranslated (English) source sentences.
    """
    global total, valid, derivated, valid_error
    # BUG FIX: the original ignored the parameter — it assigned L to an
    # unused variable and read the module global `untranslated_source_sentencesL`
    # instead (which merely happened to alias L at call time). Use L directly.
    translated_source_sentenceL = youdaotranslate(L, 'en', 'zh-CHS')
    print(f'source translation finished')
    # Word-segment each Chinese translation so CoreNLP receives
    # space-delimited tokens.
    translated_source_sentenceL_seg = []
    for translated_source_sentence in translated_source_sentenceL:
        total += 1
        valid += 1
        translated_source_sentenceL_seg.append(
            ' '.join(jieba.cut(translated_source_sentence)))
    print(f'source segmentation finished')
    # One dependency tree per source translation; ssplit.eolonly keeps
    # CoreNLP from re-splitting each line into multiple sentences.
    source_TreeL = [i for (i,) in chi_parser.raw_parse_sents(
        translated_source_sentenceL_seg,
        properties={'ssplit.eolonly': 'true'})]
    print(f'source dependency tree built')
    # MetaMorph for all input
    for idx, origin_source_sent in enumerate(L):
        print(f'processing sentence {idx} {origin_source_sent}')
        org_sentence = origin_source_sent
        translated_sentence = translated_source_sentenceL[idx]
        # Use BERT to generate perturbed (derived) sentences; cap at 10 to
        # bound the per-sentence translation/parsing cost.
        derived_sentencesL = perturb(org_sentence, 5)
        if len(derived_sentencesL) > 10:
            derived_sentencesL = derived_sentencesL[:10]
        print(f'bert perturbation {idx} finished got {len(derived_sentencesL)} derived sentences')
        derivated += len(derived_sentencesL)
        if len(derived_sentencesL) == 0:
            print('no derived sentences')
            valid -= 1
            continue
        # Translate and segment the derived sentences the same way as the source.
        translated_derived_sentencesL = youdaotranslate(derived_sentencesL, 'en', 'zh-CHS')
        translated_derived_sentencesL_seg = [
            ' '.join(jieba.cut(sent)) for sent in translated_derived_sentencesL]
        # Build dependency trees for the translated derived sentences.
        derived_TreeL = [i for (i,) in chi_parser.raw_parse_sents(
            translated_derived_sentencesL_seg,
            properties={'ssplit.eolonly': 'true'})]
        for i in range(len(derived_TreeL)):
            # A large tree distance between the source translation and a
            # derived translation signals a suspected translation error.
            distance = depDistance(source_TreeL[idx].triples(), derived_TreeL[i].triples())
            print(f'distance of source{idx} and derived{i}: {distance}')
            if distance > threshold:
                valid_error += 1
                with open('error.txt', 'a') as f:
                    f.write(f'{valid_error}------------------------------------------\n')
                    f.write(f'origin sentence: {org_sentence}\n')
                    f.write(f'translated sentence: {translated_sentence}\n')
                    f.write(f'derived sentence: {derived_sentencesL[i]}\n')
                    f.write(f'translated derived sentence: {translated_derived_sentencesL[i]}\n')
                    f.write(f'distance: {distance}\n')
                    f.write('----------------------------------------------\n')
    # Running totals for this batch.
    print(f'total: {total}')
    print(f'valid: {valid}')
    print(f'derivated: {derivated}')
    print(f'valid_error: {valid_error}')
    return
# Load source input: stream the corpus line by line, batching sentences
# into groups of 10 for main_loop, and stop after 100 sentences.
# untranslated source input
untranslated_source_sentencesL = []
# FIX: counter renamed from `sum`, which shadowed the builtin.
sentence_count = 0
with open('./dataset/' + testType, 'r', encoding='utf-8') as f:
    for line in f:
        stripped = line.strip()
        # Skip blank lines and SGML-style markup lines in the corpus.
        if stripped in ("", "<P>", "<HEADLINE>"):
            continue
        sentence_count += 1
        untranslated_source_sentencesL.append(stripped)
        # Process in batches of 10 to bound per-call translation size.
        if len(untranslated_source_sentencesL) == 10:
            main_loop(untranslated_source_sentencesL)
            untranslated_source_sentencesL = []
        if sentence_count >= 100:
            break
# FIX: the original silently dropped a trailing partial batch (< 10
# sentences); flush it here so every collected sentence is processed.
if untranslated_source_sentencesL:
    main_loop(untranslated_source_sentencesL)
print(f'total_all: {total}')
print(f'valid_all: {valid}')
print(f'derivated_all: {derivated}')
print(f'valid_error_all: {valid_error}')
# (removed web-page moderation-notice artifact appended by the hosting site — not part of the source)