1 Star 0 Fork 1

jinxuan9712/ChineseNumberIdentify

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
ChineseNumberIdentify.py 4.97 KB
一键复制 编辑 原始数据 按行查看 历史
提交于 2018-10-14 18:32 . 遵循google规范
#-*- coding:utf-8 -*-
from pybrain.tools.shortcuts import buildNetwork
from pybrain.datasets import SupervisedDataSet
from pybrain.supervised.trainers import BackpropTrainer
from pybrain.structure import TanhLayer
import numpy as np
#import os,image as Image
import os
from PIL import Image
import cPickle
#import copy_reg, copy, pickle
import datetime
def get_train_samples(input_num, output_num, sam_path='./new_samples', noise_copies=100):
    '''
    Build a noise-augmented training set from the images in ``sam_path``.

    Each file must be named ``<label>.<ext>``, where ``<label>`` is the integer
    class index (0 <= label < output_num). Every image is converted to
    grayscale, resized to a sqrt(input_num) x sqrt(input_num) square, and
    binarised: pixels with a gray value below 200 become True. The clean sample
    is added once, followed by ``noise_copies`` noisy copies. In each copy,
    len(pixels) // 20 random draws (with replacement) each overwrite one pixel
    with a random boolean.

    Args:
        input_num: number of network inputs; must be a perfect square.
        output_num: number of network outputs (number of classes).
        sam_path: directory holding the source images.
        noise_copies: number of noisy variants generated per source image.

    Returns:
        A pybrain SupervisedDataSet with one-hot targets.
    '''
    print('getsample start.')
    samples = SupervisedDataSet(input_num, output_num)
    side = int(np.sqrt(input_num))
    # sorted() makes the sample order reproducible across file systems.
    # Dot-files (e.g. .DS_Store) are not samples and would break int() below.
    names = sorted(n for n in os.listdir(sam_path) if not n.startswith('.'))
    for name in names:
        im = Image.open(os.path.join(sam_path, name))
        im = im.convert('L').resize((side, side), Image.BILINEAR)
        # Flatten to 1-D so each input element is a scalar boolean, not a
        # 1-element array; the noisy copies below assign plain bools into it.
        pixels = tuple(np.array(im).reshape(input_num) < 200)
        label = int(name.split('.')[0])
        # A list literal instead of mutating range(): range() is immutable on Python 3.
        target = [0] * output_num
        target[label] = 1
        target = tuple(target)
        samples.addSample(pixels, target)
        n_flips = len(pixels) // 20   # floor division on both Python 2 and 3
        for _copy in range(noise_copies):
            noisy = list(pixels)
            for _flip in range(n_flips):
                noisy[np.random.randint(len(pixels))] = bool(np.random.randint(2))
            samples.addSample(tuple(noisy), target)
    return samples
def get_test_samples(input_num, test_path='./new_test'):
    '''
    Load the images in ``test_path`` as an unlabelled test set.

    Images are preprocessed the same way as in get_train_samples: grayscale,
    resized to a sqrt(input_num) square, and binarised with gray value < 200
    -> True. Files are read in sorted filename order. os.listdir order is
    arbitrary, so sorting is what lets the i-th prediction be matched back to
    the i-th file name. The target column is a dummy value of 1.

    Args:
        input_num: number of network inputs; must be a perfect square.
        test_path: directory holding the test images.

    Returns:
        A pybrain SupervisedDataSet with one input row per image.
    '''
    print('Get test samples start.')
    samples = SupervisedDataSet(input_num, 1)
    side = int(np.sqrt(input_num))
    # Skip dot-files such as .DS_Store, which are not images.
    names = sorted(n for n in os.listdir(test_path) if not n.startswith('.'))
    for name in names:
        im = Image.open(os.path.join(test_path, name))
        im = im.convert('L').resize((side, side), Image.BILINEAR)
        # 1-D so each input element is a scalar boolean.
        pixels = tuple(np.array(im).reshape(input_num) < 200)
        samples.addSample(pixels, 1)
    return samples
class net(object):
    '''
    A pybrain feed-forward network (input -> hidden -> output) with helpers to
    train it, classify a test set, and persist it to disk.

    Files are named ``<input>-<hidden>-<output>new``. The pickled network goes
    to ``./save/`` with a ``.cPickle`` suffix, and predictions go to
    ``./results/`` with a ``.txt`` suffix.
    '''

    def __init__(self, input_num, hide_node_num, output_num):
        '''
        Build the network.

        Args:
            input_num: number of input neurons.
            hide_node_num: number of neurons in the hidden layer.
            output_num: number of output neurons.
        '''
        self.input_num = input_num
        self.hide_node_num = hide_node_num
        self.output_num = output_num
        # Activation functions and layer count can be tuned here.
        self.network = buildNetwork(input_num, hide_node_num, output_num, bias=True)

    def _file_stem(self):
        '''Return the base file name shared by this network's save and result files.'''
        return '-'.join(str(x) for x in (self.input_num, self.hide_node_num, self.output_num)) + 'new'

    def train(self, samples, epsilon, max_epochs=100, save_every=10):
        '''
        Train with backpropagation until an epoch's error is <= ``epsilon``.

        Training also stops after ``max_epochs`` epochs. The network is saved
        every ``save_every`` epochs (0 disables periodic saves) and always once
        at the end.

        Args:
            samples: pybrain SupervisedDataSet to train on.
            epsilon: target training error.
            max_epochs: maximum number of epochs (default 100).
            save_every: checkpoint interval in epochs (default 10).

        Returns:
            The error reported by the last training epoch.
        '''
        print('Train start.')
        trainer = BackpropTrainer(self.network, samples)  # learning rate can be tuned here
        # Infinity guarantees at least one epoch whatever epsilon is.
        e = float('inf')
        n = 0
        while e > epsilon:
            e = trainer.train()
            n += 1
            print(n, ' done,e=', e)
            if save_every and not n % save_every:
                self.save()
            if n >= max_epochs:
                break
        self.save()
        print('Train end.')
        return e

    def run(self, samples):
        '''
        Classify every input in ``samples`` and write the predictions to disk.

        The prediction for each input is the index of the largest network
        output. The list of predictions is printed and written in its str()
        form to ``./results/<stem>.txt``. The directory is created if missing.

        Args:
            samples: dataset indexable by 'input' (e.g. a SupervisedDataSet).

        Returns:
            The list of predicted class indices.
        '''
        print('Test start.')
        result = []
        for sample in samples['input']:
            outputs = list(self.network.activate(sample))
            result.append(outputs.index(max(outputs)))
        print('Result ', result)
        result_path = './results/'
        if not os.path.isdir(result_path):
            os.makedirs(result_path)
        with open(os.path.join(result_path, self._file_stem() + '.txt'), 'w') as f:
            f.write(str(result))
        return result

    def save(self):
        '''
        Pickle the current network to ``./save/<stem>.cPickle``.

        The directory is created if missing.
        '''
        print('saving')
        save_path = './save/'
        if not os.path.isdir(save_path):
            os.makedirs(save_path)
        with open(os.path.join(save_path, self._file_stem() + '.cPickle'), 'wb') as f:
            cPickle.dump(self.network, f)
        print('done')

    def load(self):
        '''
        Replace the network with the one pickled by save(), if such a file exists.

        Returns:
            True if a saved network was loaded, False if no file was found.
        '''
        print('loading')
        path = os.path.join('./save/', self._file_stem() + '.cPickle')
        if not os.path.isfile(path):
            print('no saved network at ' + path)
            return False
        # Binary mode to match save(). Only load pickles you wrote yourself:
        # unpickling an untrusted file can execute arbitrary code.
        with open(path, 'rb') as f:
            self.network = cPickle.load(f)
        print('done')
        return True
def main():
    '''
    Program entry point.

    Loads the saved 400-200-10 network if a save file exists. Otherwise it
    builds the noisy training set and trains the network. It then classifies
    the test images and prints the elapsed time.
    '''
    start = datetime.datetime.now()
    output_num = 10
    epsilon = 0.01
    input_num = 20 * 20
    hide_node_num = 200
    filename = str(input_num) + '-' + str(hide_node_num) + '-' + str(output_num) + 'new.cPickle'
    net1 = net(input_num, hide_node_num, output_num)
    # os.path.isfile also copes with a missing ./save/ directory, unlike os.listdir.
    if os.path.isfile(os.path.join('./save/', filename)):
        net1.load()
    else:
        samples = get_train_samples(input_num, output_num)
        # train() already saves the network when it finishes; a second save is redundant.
        net1.train(samples, epsilon)
    net1.run(get_test_samples(input_num))
    end = datetime.datetime.now()
    print('Time ', end - start)
# Run only when executed as a script, so importing this module does not start training.
if __name__ == '__main__':
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/jinxuan9712/ChineseNumberIdentify.git
git@gitee.com:jinxuan9712/ChineseNumberIdentify.git
jinxuan9712
ChineseNumberIdentify
ChineseNumberIdentify
master

搜索帮助

23e8dbc6 1850385 7e0993f3 1850385