1 Star 0 Fork 1

CSDN-AI工程师直通车第二期项目/showAndTell

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
show_and_tell_model.py 14.81 KB
一键复制 编辑 原始数据 按行查看 历史
zhangyachen 提交于 2018-06-02 13:57 . no commit message
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image-to-text implementation based on http://arxiv.org/abs/1411.4555.
"Show and Tell: A Neural Image Caption Generator"
Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from ops import image_embedding
from ops import image_processing
from ops import inputs as input_ops
class ShowAndTellModel(object):
    """Image-to-text implementation based on http://arxiv.org/abs/1411.4555.

    "Show and Tell: A Neural Image Caption Generator"
    Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan

    Typical use is to construct the model and then call `build()`, which
    runs the individual `build_*` / `setup_*` steps in the required order.
    """

    def __init__(self, config, mode, train_inception=False):
        """Basic setup.

        Args:
          config: Object containing configuration parameters.
          mode: "train", "eval" or "inference".
          train_inception: Whether the inception submodel variables are
            trainable.

        Raises:
          AssertionError: If `mode` is not one of the three accepted values.
        """
        # NOTE: `assert` is skipped under `python -O`; it is kept as-is so
        # that the exception type seen by existing callers does not change.
        assert mode in ["train", "eval", "inference"]
        self.config = config
        self.mode = mode
        self.train_inception = train_inception

        # TFRecord reader used for the input data.
        self.reader = tf.TFRecordReader()

        # To match the "Show and Tell" paper we initialize all variables with
        # a random uniform initializer.
        self.initializer = tf.random_uniform_initializer(
            minval=-self.config.initializer_scale,
            maxval=self.config.initializer_scale)

        # A float32 Tensor with shape [batch_size, height, width, channels].
        self.images = None

        # An int32 Tensor with shape [batch_size, padded_length].
        self.input_seqs = None

        # An int32 Tensor with shape [batch_size, padded_length].
        self.target_seqs = None

        # An int32 0/1 Tensor with shape [batch_size, padded_length].
        self.input_mask = None

        # A float32 Tensor with shape [batch_size, embedding_size].
        self.image_embeddings = None

        # A float32 Tensor with shape
        # [batch_size, padded_length, embedding_size].
        self.seq_embeddings = None

        # A float32 scalar Tensor; the total loss for the trainer to optimize.
        self.total_loss = None

        # A float32 Tensor with shape [batch_size * padded_length].
        self.target_cross_entropy_losses = None

        # A float32 Tensor with shape [batch_size * padded_length].
        self.target_cross_entropy_loss_weights = None

        # Collection of variables from the inception submodel.
        self.inception_variables = []

        # Function to restore the inception submodel from checkpoint.
        self.init_fn = None

        # Global step Tensor.
        self.global_step = None

    def is_training(self):
        """Returns true if the model is built for training mode."""
        return self.mode == "train"

    def process_image(self, encoded_image, thread_id=0):
        """Decodes and processes an image string.

        Args:
          encoded_image: A scalar string Tensor; the encoded (raw file) image.
          thread_id: Preprocessing thread id used to select the ordering of
            color distortions.

        Returns:
          A float32 Tensor of shape [height, width, 3]; the processed image.
        """
        return image_processing.process_image(
            encoded_image,
            is_training=self.is_training(),
            height=self.config.image_height,
            width=self.config.image_width,
            thread_id=thread_id,
            image_format=self.config.image_format)

    def build_inputs(self):
        """Input prefetching, preprocessing and batching.

        Outputs:
          self.images
          self.input_seqs
          self.target_seqs (training and eval only)
          self.input_mask (training and eval only)
        """
        if self.mode == "inference":
            # In inference mode, images and inputs are fed via placeholders.
            image_feed = tf.placeholder(
                dtype=tf.string, shape=[], name="image_feed")
            input_feed = tf.placeholder(
                dtype=tf.int64,
                shape=[None],  # batch_size
                name="input_feed")

            # Process image and insert batch dimensions.
            images = tf.expand_dims(self.process_image(image_feed), 0)
            input_seqs = tf.expand_dims(input_feed, 1)

            # No target sequences or input mask in inference mode.
            target_seqs = None
            input_mask = None
        else:
            # Prefetch serialized SequenceExample protos into an input queue.
            input_queue = input_ops.prefetch_input_data(
                self.reader,
                self.config.input_file_pattern,
                is_training=self.is_training(),
                batch_size=self.config.batch_size,
                values_per_shard=self.config.values_per_input_shard,
                input_queue_capacity_factor=(
                    self.config.input_queue_capacity_factor),
                num_reader_threads=self.config.num_input_reader_threads)

            # Image processing and random distortion. Split across multiple
            # threads with each thread applying a slightly different
            # distortion. The thread count is required to be even.
            # NOTE(review): presumably so that the thread_id-dependent
            # distortion orderings are used equally often -- confirm against
            # image_processing.process_image.
            assert self.config.num_preprocess_threads % 2 == 0
            images_and_captions = []
            for thread_id in range(self.config.num_preprocess_threads):
                # Take the serialized example for one image-caption pair.
                serialized_sequence_example = input_queue.dequeue()
                # Parse out the raw image data and the caption.
                encoded_image, caption = input_ops.parse_sequence_example(
                    serialized_sequence_example,
                    image_feature=self.config.image_feature_name,
                    caption_feature=self.config.caption_feature_name)
                # Process the image (with this thread's distortion setting).
                image = self.process_image(encoded_image, thread_id=thread_id)
                images_and_captions.append([image, caption])

            # Batch inputs.
            queue_capacity = (2 * self.config.num_preprocess_threads *
                              self.config.batch_size)
            images, input_seqs, target_seqs, input_mask = (
                input_ops.batch_with_dynamic_pad(
                    images_and_captions,
                    batch_size=self.config.batch_size,
                    queue_capacity=queue_capacity))

        # Publish the final inputs; both branches above define all four names.
        self.images = images
        self.input_seqs = input_seqs
        self.target_seqs = target_seqs
        self.input_mask = input_mask

    def build_image_embeddings(self):
        """Builds the image model subgraph and generates image embeddings.

        The image encoder is the InceptionV3 subnetwork; its output is
        projected into the embedding space by a linear layer without bias.

        Inputs:
          self.images

        Outputs:
          self.image_embeddings
        """
        inception_output = image_embedding.inception_v3(
            self.images,
            trainable=self.train_inception,
            is_training=self.is_training())
        # Remember the InceptionV3 variables so that they can be restored
        # from a checkpoint by setup_inception_initializer().
        self.inception_variables = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope="InceptionV3")

        # Map inception output into embedding space.
        with tf.variable_scope("image_embedding") as scope:
            image_embeddings = tf.contrib.layers.fully_connected(
                inputs=inception_output,
                num_outputs=self.config.embedding_size,
                activation_fn=None,
                weights_initializer=self.initializer,
                biases_initializer=None,
                scope=scope)

        # Save the embedding size in the graph.
        tf.constant(self.config.embedding_size, name="embedding_size")

        self.image_embeddings = image_embeddings

    def build_seq_embeddings(self):
        """Builds the input sequence embeddings.

        Generates embedding features for the word ids of the caption by a
        table lookup.

        Inputs:
          self.input_seqs

        Outputs:
          self.seq_embeddings
        """
        # The embedding table and its lookup are pinned to the CPU.
        with tf.variable_scope("seq_embedding"), tf.device("/cpu:0"):
            embedding_map = tf.get_variable(
                name="map",
                shape=[self.config.vocab_size, self.config.embedding_size],
                initializer=self.initializer)
            seq_embeddings = tf.nn.embedding_lookup(
                embedding_map, self.input_seqs)

        self.seq_embeddings = seq_embeddings

    def build_model(self):
        """Builds the model.

        Inputs:
          self.image_embeddings
          self.seq_embeddings
          self.target_seqs (training and eval only)
          self.input_mask (training and eval only)

        Outputs:
          self.total_loss (training and eval only)
          self.target_cross_entropy_losses (training and eval only)
          self.target_cross_entropy_loss_weights (training and eval only)
        """
        # This LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but
        # the modified LSTM in the "Show and Tell" paper has no biases and
        # outputs new_c * sigmoid(o).
        lstm_cell = tf.contrib.rnn.BasicLSTMCell(
            num_units=self.config.num_lstm_units, state_is_tuple=True)
        # Dropout is applied in training mode only.
        if self.mode == "train":
            lstm_cell = tf.contrib.rnn.DropoutWrapper(
                lstm_cell,
                input_keep_prob=self.config.lstm_dropout_keep_prob,
                output_keep_prob=self.config.lstm_dropout_keep_prob)

        with tf.variable_scope(
                "lstm", initializer=self.initializer) as lstm_scope:
            # Feed the image embeddings to set the initial LSTM state: start
            # from the all-zero state and run one step on the image embedding.
            zero_state = lstm_cell.zero_state(
                batch_size=self.image_embeddings.get_shape()[0],
                dtype=tf.float32)
            _, initial_state = lstm_cell(self.image_embeddings, zero_state)

            # Allow the LSTM variables to be reused.
            lstm_scope.reuse_variables()

            if self.mode == "inference":
                # In inference mode, use concatenated states for convenient
                # feeding and fetching.
                tf.concat(axis=1, values=initial_state, name="initial_state")

                # Placeholder for feeding a batch of concatenated states.
                state_feed = tf.placeholder(
                    dtype=tf.float32,
                    shape=[None, sum(lstm_cell.state_size)],
                    name="state_feed")
                state_tuple = tf.split(
                    value=state_feed, num_or_size_splits=2, axis=1)

                # Run a single LSTM step.
                lstm_outputs, state_tuple = lstm_cell(
                    inputs=tf.squeeze(self.seq_embeddings, axis=[1]),
                    state=state_tuple)

                # Concatenate the resulting state.
                tf.concat(axis=1, values=state_tuple, name="state")
            else:
                # Run the batch of sequence embeddings through the LSTM.
                sequence_length = tf.reduce_sum(self.input_mask, 1)
                lstm_outputs, _ = tf.nn.dynamic_rnn(
                    cell=lstm_cell,
                    inputs=self.seq_embeddings,
                    sequence_length=sequence_length,
                    initial_state=initial_state,
                    dtype=tf.float32,
                    scope=lstm_scope)

        # Stack batches vertically.
        lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])

        # Fully connected layer mapping LSTM outputs to vocabulary logits.
        with tf.variable_scope("logits") as logits_scope:
            logits = tf.contrib.layers.fully_connected(
                inputs=lstm_outputs,
                num_outputs=self.config.vocab_size,
                activation_fn=None,
                weights_initializer=self.initializer,
                scope=logits_scope)

        if self.mode == "inference":
            tf.nn.softmax(logits, name="softmax")
        else:
            # Flatten the targets and the mask to line up with the logits.
            targets = tf.reshape(self.target_seqs, [-1])
            weights = tf.to_float(tf.reshape(self.input_mask, [-1]))

            # Compute losses: mask-weighted mean of the per-position
            # cross entropy.
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=targets, logits=logits)
            batch_loss = tf.div(
                tf.reduce_sum(tf.multiply(losses, weights)),
                tf.reduce_sum(weights),
                name="batch_loss")
            tf.losses.add_loss(batch_loss)
            total_loss = tf.losses.get_total_loss()

            # Add summaries.
            tf.summary.scalar("losses/batch_loss", batch_loss)
            tf.summary.scalar("losses/total_loss", total_loss)
            for var in tf.trainable_variables():
                tf.summary.histogram("parameters/" + var.op.name, var)

            self.total_loss = total_loss
            # Used in evaluation.
            self.target_cross_entropy_losses = losses
            # Used in evaluation.
            self.target_cross_entropy_loss_weights = weights

    def setup_inception_initializer(self):
        """Sets up the function to restore inception variables from checkpoint.

        In inference mode this is a no-op and `self.init_fn` is left
        unchanged.
        """
        if self.mode != "inference":
            # Restore inception variables only.
            saver = tf.train.Saver(self.inception_variables)

            def restore_fn(sess):
                tf.logging.info(
                    "Restoring Inception variables from checkpoint file %s",
                    self.config.inception_checkpoint_file)
                saver.restore(sess, self.config.inception_checkpoint_file)

            self.init_fn = restore_fn

    def setup_global_step(self):
        """Sets up the global step Tensor."""
        global_step = tf.Variable(
            initial_value=0,
            name="global_step",
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_STEP,
                         tf.GraphKeys.GLOBAL_VARIABLES])

        self.global_step = global_step

    def build(self):
        """Creates all ops for training and evaluation."""
        # Fetch, preprocess and batch the input data.
        self.build_inputs()
        # InceptionV3 image encoder and image embeddings.
        self.build_image_embeddings()
        # Embeddings for the caption word sequence.
        self.build_seq_embeddings()
        # Caption (LSTM) model, logits and losses.
        self.build_model()
        # Restore function for the pretrained inception variables.
        self.setup_inception_initializer()
        # Global step variable.
        self.setup_global_step()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/csdn_ai_project2/showAndTell.git
git@gitee.com:csdn_ai_project2/showAndTell.git
csdn_ai_project2
showAndTell
showAndTell
master

搜索帮助