
liq159159/Easy_Lstm_Cnn

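This repository does not declare an open-source license file (LICENSE); check the project description and the code's upstream dependencies before use.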
predict.py 1.49 KB
NLPxiaoxu committed on 2019-01-27 14:37: Add files via upload
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Saver)

from Lstm_Cnn import Lstm_CNN
from data_processing import read_category, get_wordid, get_word2vec, process, batch_iter, seq_length
from Parameters import Parameters as pm


def val():
    """Restore the latest Lstm_CNN checkpoint and predict on the validation set.

    Relies on the module-level globals `model`, `wordid` and `cat_to_id`
    that are created in the __main__ block below.
    """
    pre_label = []
    label = []
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    # Load the most recent trained weights over the freshly initialized variables.
    save_path = tf.train.latest_checkpoint('./checkpoints/Lstm_CNN')
    saver = tf.train.Saver()
    saver.restore(sess=session, save_path=save_path)

    # Convert the validation file into padded word-id sequences and labels.
    val_x, val_y = process(pm.val_filename, wordid, cat_to_id, max_length=pm.seq_length)
    batch_val = batch_iter(val_x, val_y, batch_size=64)
    for x_batch, y_batch in batch_val:
        real_seq_len = seq_length(x_batch)  # actual (unpadded) length of each sequence
        # The last argument (1.0) is presumably the dropout keep probability: no dropout at inference.
        feed_dict = model.feed_data(x_batch, y_batch, real_seq_len, 1.0)
        pre_lab = session.run(model.predict, feed_dict=feed_dict)
        pre_label.extend(pre_lab)
        label.extend(y_batch)
    session.close()
    return pre_label, label


if __name__ == '__main__':
    categories, cat_to_id = read_category()
    wordid = get_wordid(pm.vocab_filename)
    pm.vocab_size = len(wordid)
    pm.pre_trianing = get_word2vec(pm.vector_word_npz)  # pre-trained word2vec embedding matrix
    model = Lstm_CNN()
    pre_label, label = val()
    # Labels are one-hot, so argmax recovers the class id.
    correct = np.equal(pre_label, np.argmax(label, 1))
    accuracy = np.mean(correct.astype(np.float32))
    print('accuracy:', accuracy)
    print("First 10 predictions:", ' '.join(str(p) for p in pre_label[:10]))
    print("First 10 true labels:", ' '.join(str(t) for t in np.argmax(label[:10], 1)))
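The val() loop and the accuracy lines above depend on project helpers (data_processing, Lstm_CNN) and a saved checkpoint that are not shown on this page. The sketch below reproduces the same evaluation flow in plain NumPy so it can run on its own. It is an illustration under stated assumptions, not the repository's implementation: `batch_iter` and `seq_length` here are hypothetical stand-ins (fixed-size slicing, and counting non-zero ids assuming padding id 0), `evaluate` is a hypothetical helper, and `predict_fn` plays the role of `session.run(model.predict, ...)`.

```python
import numpy as np


def batch_iter(x, y, batch_size=64):
    """Hypothetical stand-in for data_processing.batch_iter:
    yields consecutive (x, y) slices of at most batch_size rows."""
    for start in range(0, len(x), batch_size):
        yield x[start:start + batch_size], y[start:start + batch_size]


def seq_length(x_batch, pad_id=0):
    """Hypothetical stand-in for data_processing.seq_length:
    counts non-padding token ids per row, ASSUMING the padding id is 0."""
    return [int(np.count_nonzero(row != pad_id)) for row in x_batch]


def evaluate(predict_fn, x, y_onehot, batch_size=64):
    """Hypothetical helper mirroring val() plus the accuracy step in __main__."""
    pre_label, label = [], []
    for x_batch, y_batch in batch_iter(x, y_onehot, batch_size):
        lengths = seq_length(x_batch)
        pre_label.extend(predict_fn(x_batch, lengths))  # predicted class ids
        label.extend(y_batch)                           # one-hot rows
    true_ids = np.argmax(np.asarray(label), axis=1)     # one-hot -> class id
    correct = np.equal(np.asarray(pre_label), true_ids)
    return float(correct.astype(np.float32).mean()), pre_label, true_ids


def toy_predict(x_batch, lengths):
    """Dummy model for the demo: predicts (sequence length mod 3)."""
    return [n % 3 for n in lengths]


# Four padded sequences and one-hot labels for classes 2, 1, 0, 0.
x = np.array([[5, 3, 0, 0], [7, 0, 0, 0], [2, 4, 6, 0], [1, 1, 1, 1]])
y = np.eye(3)[[2, 1, 0, 0]]
acc, preds, truth = evaluate(toy_predict, x, y, batch_size=3)
print('accuracy:', acc)                        # accuracy: 0.75
print(' '.join(str(p) for p in preds[:10]))    # 2 1 0 1
```

Using `batch_size=3` splits the four rows into two batches, which shows that predictions and labels are accumulated across batches before a single accuracy is computed, as in the original script. Joining `str(p)` for each element prints one id per item; calling `' '.join(str(list_))` instead would insert a space between every character of the list's string form.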
Clone URLs: https://gitee.com/liq159159/Easy_Lstm_Cnn.git (HTTPS), git@gitee.com:liq159159/Easy_Lstm_Cnn.git (SSH)