1 Star 0 Fork 0

LZY/Char-RNN-TensorFlow

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
train.py 2.84 KB
一键复制 编辑 原始数据 按行查看 历史
LZY2006 提交于 2021-01-10 12:34 . Signed-off-by: LZY2006
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from read_utils import TextConverter, batch_generator
from model import CharRNN
import os
import codecs
# Command-line flags for the training script; values are read through FLAGS
# inside main().
FLAGS = tf.flags.FLAGS

# Run identity: checkpoints and the converter are stored under model/<name>.
tf.flags.DEFINE_string('name', 'default', 'name of the model')

# Batch geometry.
tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')
tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')

# Network architecture.
tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')
tf.flags.DEFINE_integer('num_layers', 2, 'number of lstm layers')
tf.flags.DEFINE_boolean('use_embedding', False, 'whether to use embedding')
tf.flags.DEFINE_integer('embedding_size', 128, 'size of embedding')

# Optimisation.
tf.flags.DEFINE_float('learning_rate', 0.001, 'learning_rate')
# The value is handed to CharRNN as `train_keep_prob`, i.e. a keep probability,
# so the help text must not call it a dropout (drop) rate.
# NOTE(review): CharRNN is defined in model.py (not visible here) — confirm it
# treats this value as a keep probability.
tf.flags.DEFINE_float('train_keep_prob', 0.5,
                      'keep probability (1 - dropout rate) during training')

# Data and schedule.
tf.flags.DEFINE_string('input_file', '', 'utf8 encoded text file')
tf.flags.DEFINE_integer('max_steps', 100000, 'max steps to train')
tf.flags.DEFINE_integer('save_every_n', 1000, 'save the model every n steps')
tf.flags.DEFINE_integer('log_every_n', 10, 'log to the screen every n steps')
tf.flags.DEFINE_integer('max_vocab', 3500, 'max char number')
def _checkpoint_step(checkpoint_path):
    """Return the integer after the last '-' in a checkpoint file name.

    For a path such as ``model/default/model-1000`` this returns ``1000``.
    Returns ``None`` when the file name does not end in ``-<integer>``.
    """
    suffix = os.path.basename(checkpoint_path).rpartition('-')[2]
    try:
        return int(suffix)
    except ValueError:
        return None


def main(_):
    """Train a CharRNN on the text in FLAGS.input_file.

    The model directory is ``model/<FLAGS.name>``. The text converter is
    rebuilt from the input text and saved there as ``converter.pkl`` on every
    run. If the directory already holds a checkpoint, the model is restored
    from it and its step counter is set to the step encoded in the checkpoint
    name before training continues.

    Args:
        _: unused positional argument supplied by ``tf.app.run``.

    Raises:
        ValueError: if ``--input_file`` was not given.
    """
    if not FLAGS.input_file:
        # The flag defaults to ''; fail with a clear message instead of an
        # opaque file-open error.
        raise ValueError(
            '--input_file is required: path to a utf8 encoded text file')

    model_path = os.path.join('model', FLAGS.name)
    print(model_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    with codecs.open(FLAGS.input_file, encoding='utf-8') as f:
        text = f.read()
    converter = TextConverter(text, FLAGS.max_vocab)
    converter.save_to_file(os.path.join(model_path, 'converter.pkl'))

    arr = converter.text_to_arr(text)
    g = batch_generator(arr, FLAGS.num_seqs, FLAGS.num_steps)
    print(converter.vocab_size)
    model = CharRNN(converter.vocab_size,
                    num_seqs=FLAGS.num_seqs,
                    num_steps=FLAGS.num_steps,
                    lstm_size=FLAGS.lstm_size,
                    num_layers=FLAGS.num_layers,
                    learning_rate=FLAGS.learning_rate,
                    train_keep_prob=FLAGS.train_keep_prob,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size
                    )

    # Resume only when a checkpoint really exists. The directory alone is not
    # enough: a previous run may have stopped before its first save, in which
    # case latest_checkpoint() returns None and there is nothing to load.
    model_file_path = tf.train.latest_checkpoint(model_path)
    if model_file_path is not None:
        model.load(model_file_path)
        # Take the step from the checkpoint that was actually loaded rather
        # than from whichever *.index file in the tree has the largest number.
        last_step = _checkpoint_step(model_file_path)
        if last_step is not None:
            # NOTE(review): presumably CharRNN.train continues counting from
            # model.step — defined in model.py, not visible here; confirm.
            model.step = last_step

    model.train(g,
                FLAGS.max_steps,
                model_path,
                FLAGS.save_every_n,
                FLAGS.log_every_n,
                )
# Script entry point: tf.app.run parses the command-line flags and then calls
# main with the remaining arguments.
if __name__ == "__main__":
    tf.app.run(main=main)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/LZY2006/Char-RNN-TensorFlow.git
git@gitee.com:LZY2006/Char-RNN-TensorFlow.git
LZY2006
Char-RNN-TensorFlow
Char-RNN-TensorFlow
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385