1 Star 0 Fork 1

unknow1216/CNN-text-Classifier

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
predict.py 1.48 KB
一键复制 编辑 原始数据 按行查看 历史
unknow1216 提交于 2021-01-05 17:07 . 精简代码删除无用空行
import os
import re
import jieba
from numpy import *
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.models import load_model
#预定义变量
MAX_SEQUENCE_LENGTH = 100 #最大序列长度
def readFlie(path): #读取一个样本的记录,默认一个文件一条样本
with open(path,'r',errors='ignore') as file:
content = file.read()
file.close()
return content
def getStopWord(inputFile): #获取停用词表
stopWordList = readFlie(inputFile).splitlines()
return stopWordList
def remove_punctuation(line):
line = str(line)
if line.strip()=='':
return ''
rule = re.compile(u"[^a-zA-Z0-9\u4E00-\u9FA5]")
line = rule.sub('',line)
return line
def predict(text,model,stopWordList):
stopwords = getStopWord(stopWordList)
tokenizer = Tokenizer()
txt = remove_punctuation(text)
txt = [" ".join([w for w in list(jieba.cut(txt)) if w not in stopwords])]
tokenizer.fit_on_texts(txt)
print(txt)
seq = tokenizer.texts_to_sequences(txt)
padded = pad_sequences(seq, maxlen=MAX_SEQUENCE_LENGTH)
print(seq)
pred = model.predict(padded)
print(pred)
cat_id = pred.argmax(axis=1)
return cat_id
if __name__ == '__main__':
model = load_model('cnn.h5')
stopWord_path = "./stop/stopword.txt" # 停用词路径
print(predict("人民银行长沙中心支行快速响应 做好疫情期间金融服务",model,stopWord_path))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/unknow1216/cnn-text-classifier.git
git@gitee.com:unknow1216/cnn-text-classifier.git
unknow1216
cnn-text-classifier
CNN-text-Classifier
master

搜索帮助