
yangxin/SubCharTokenization

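This repository does not declare an open-source license file (LICENSE); before use, check the project description and its upstream code dependencies.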
get_test_preds.py 1.76 KB
NoviScl committed on 2021-12-22 21:11 . push
import json
import pickle
import os

from mrc.preprocess.CHID_preprocess import RawResult, logits_matrix_to_array

seeds = [2, 23, 234]


def get_final_predictions(all_results, tmp_predict_file, g=True):
    # if not os.path.exists(tmp_predict_file):
    #     pickle.dump(all_results, open(tmp_predict_file, 'wb'))

    # Group the (tag, logit) pairs of all candidates by example_id.
    raw_results = {}
    for i, elem in enumerate(all_results):
        example_id = elem.example_id
        if example_id not in raw_results:
            raw_results[example_id] = [(elem.tag, elem.logit)]
        else:
            raw_results[example_id].append((elem.tag, elem.logit))

    # Turn each example's logits into final (tag, answer) predictions.
    results = []
    for example_id, elem in raw_results.items():
        index_2_idiom = {index: tag for index, (tag, logit) in enumerate(elem)}
        logits = [logit for _, logit in elem]
        if g:
            results.extend(logits_matrix_to_array(logits, index_2_idiom))
        else:
            # NOTE: logits_matrix_max_array is not imported in this script,
            # so this branch raises NameError unless it is imported first.
            results.extend(logits_matrix_max_array(logits, index_2_idiom))
    return results


def write_predictions(results, output_prediction_file):
    # output_prediction_file = result6.csv
    # results = pd.DataFrame(results)
    # results.to_csv(output_prediction_file, header=None, index=None)

    # Each result is a pair; map its first element to its second.
    results_dict = {}
    for result in results:
        results_dict[result[0]] = result[1]
    with open(output_prediction_file, 'w') as w:
        json.dump(results_dict, w, indent=2)
    print("Writing predictions to: {}".format(output_prediction_file))


# Convert the pickled raw predictions of every seed into a JSON file.
for seed in seeds:
    tmp_predict_file = 'logs/chid/raw_zh/' + str(seed) + '/raw_predictions.pkl'
    json_file = 'logs/chid/raw_zh/' + str(seed) + '/test_predictions.json'
    with open(tmp_predict_file, 'rb') as f:
        result = pickle.load(f)
    results = get_final_predictions(result, tmp_predict_file, g=True)
    write_predictions(results, json_file)