1 Star 0 Fork 0

sesepp/DeepID

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 3.78 KB
一键复制 编辑 原始数据 按行查看 历史
shen1994 提交于 2018-06-05 21:27 . Initial commit
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 3 21:31:56 2018
@author: shen1994
"""
import os
import tensorflow as tf
import numpy as np
from image_vector import load_vector
from image_vector import load_pair_vector
from image_vector import load_vector_from_index
from image_vector import safe_file_close
from image_split import load_train_test_number
from model import deepid_1
def run():
epocs = 100001
batch_size = 256
log_dir = "log"
if tf.gfile.Exists(log_dir):
tf.gfile.DeleteRecursively(log_dir)
tf.gfile.MakeDirs(log_dir)
model_dir = "model"
if tf.gfile.Exists(model_dir):
tf.gfile.DeleteRecursively(model_dir)
tf.gfile.MakeDirs(model_dir)
train_samples_number, _ = load_train_test_number("image/train_test_number.pkl")
class_num = train_samples_number + 1
with tf.name_scope('input'):
x = tf.placeholder(tf.float32, [None, 55, 47, 3], name='x')
y = tf.placeholder(tf.float32, [None, class_num], name='y')
merged, loss, accuracy, optimizer = deepid_1(x, y, class_num=class_num)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
train_writer = tf.summary.FileWriter(log_dir + "/train", sess.graph)
valid_writer = tf.summary.FileWriter(log_dir + "/valid", sess.graph)
t_big_number, t_size, t_index, t_pkl_file = 0, 0, 0, None
v_big_number, v_size, v_index, v_pkl_file = 0, 0, 0, None
step = 0
while step < epocs:
train_x, train_y, t_index, t_big_number, t_size, t_pkl_file = load_vector_from_index(
"image/train_vector_dataset.pkl",
batch_size,
t_index,
t_big_number,
t_size,
t_pkl_file)
train_onehot_y = (np.arange(class_num) == train_y[:, None]).astype(np.float32)
_ = sess.run(optimizer, {x: train_x, y: train_onehot_y})
if step % 1000 == 0:
# train
summary = sess.run(merged, {x: train_x, y: train_onehot_y})
train_writer.add_summary(summary, step)
# valid
valid_x, valid_y, v_index, v_big_number, v_size, v_pkl_file = load_vector_from_index(
"image/valid_vector_dataset.pkl",
batch_size,
v_index,
v_big_number,
v_size,
v_pkl_file)
valid_onehot_y = (np.arange(class_num) == valid_y[:, None]).astype(np.float32)
summary = sess.run(merged, {x: valid_x, y: valid_onehot_y})
valid_writer.add_summary(summary, step)
t_cost, t_acc = sess.run([loss, accuracy], {x: train_x, y: train_onehot_y})
v_cost, v_acc = sess.run([loss, accuracy], {x: valid_x, y: valid_onehot_y})
print(str(step) + ": train --->" + "cost:" + str(t_cost) + ", accuracy:" + str(t_acc))
print(str(step) + ": valid --->" + "cost:" + str(v_cost) + ", accuracy:" + str(v_acc))
print("----------------------------------------")
if step % 10000 == 0 and step != 0:
saver.save(sess, 'model/deepid%d.ckpt' % step)
step += 1
safe_file_close(t_pkl_file)
safe_file_close(v_pkl_file)
if __name__ == "__main__":
run()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/sesepp/DeepID.git
git@gitee.com:sesepp/DeepID.git
sesepp
DeepID
DeepID
master

搜索帮助