1 Star 0 Fork 0

rjyrjy/face-of-art

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
train_heatmaps_network.py 6.62 KB
一键复制 编辑 原始数据 按行查看 历史
import tensorflow as tf
from deep_heatmaps_model_fusion_net import DeepHeatmapsModel
import os
flags = tf.app.flags


def _define_path_flags():
    """Register flags that locate input data and output artifacts."""
    flags.DEFINE_string('output_dir', 'output', "directory for saving models, logs and samples")
    flags.DEFINE_string('save_model_path', 'model', "directory for saving the model")
    flags.DEFINE_string('save_sample_path', 'sample',
                        "directory for saving the sampled images, relevant if sample_to_log is False")
    flags.DEFINE_string('save_log_path', 'logs', "directory for saving the log file")
    flags.DEFINE_string('img_path', '~/landmark_detection_datasets', "data directory")
    flags.DEFINE_string('valid_data', 'full', 'validation set to use: full/common/challenging/test')
    flags.DEFINE_string('train_crop_dir', 'crop_gt_margin_0.25',
                        "directory of train images cropped to bb (+margin)")
    flags.DEFINE_string('img_dir_ns', 'crop_gt_margin_0.25_ns',
                        "directory of train imgs cropped to bb + style transfer")
    flags.DEFINE_string('epoch_data_dir', 'epoch_data',
                        "directory containing pre-augmented data for each epoch")
    flags.DEFINE_bool('use_epoch_data', False, "use pre-augmented data")


def _define_logging_flags():
    """Register flags controlling printing, checkpointing, sampling and validation."""
    flags.DEFINE_integer('print_every', 100, "print losses to screen + log every X steps")
    flags.DEFINE_integer('save_every', 20000, "save model every X steps")
    flags.DEFINE_integer('sample_every', 5000, "sample heatmaps + landmark predictions every X steps")
    flags.DEFINE_integer('sample_grid', 4, 'number of training images in sample')
    flags.DEFINE_bool('sample_to_log', True, 'samples will be saved to tensorboard log')
    flags.DEFINE_integer('valid_size', 20, 'number of validation images to run')
    flags.DEFINE_integer('log_valid_every', 10, 'evaluate on valid set every X epochs')
    flags.DEFINE_integer('debug_data_size', 20, 'subset data size to test in debug mode')
    flags.DEFINE_bool('debug', False, 'run in debug mode - use subset of the data')


def _define_pretrain_flags():
    """Register flags for fine-tuning / resuming from a saved model."""
    flags.DEFINE_string('pre_train_path', 'model/deep_heatmaps-40000', 'pretrained model path')
    flags.DEFINE_bool('load_pretrain', False, "load pretrained weight?")
    flags.DEFINE_bool('load_primary_only', False, 'fine-tuning using only primary network weights')


def _define_input_flags():
    """Register flags describing the input images and heatmap targets."""
    flags.DEFINE_integer('image_size', 256, "image size")
    flags.DEFINE_integer('c_dim', 3, "color channels")
    flags.DEFINE_integer('num_landmarks', 68, "number of face landmarks")
    flags.DEFINE_float('sigma', 6, "std for heatmap generation gaussian")
    flags.DEFINE_integer('scale', 1, 'scale for image normalization 255/1/0')
    flags.DEFINE_float('margin', 0.25, 'margin for face crops - % of bb size')
    flags.DEFINE_string('bb_type', 'gt',
                        "bb to use - 'gt':for ground truth / 'init':for face detector output")
    flags.DEFINE_float('win_mult', 3.33335,
                       'gaussian filter size for approx maps: 2 * sigma * win_mult + 1')


def _define_optimization_flags():
    """Register loss-weight, optimizer, learning-rate and initializer flags."""
    flags.DEFINE_float('l_weight_primary', 1., 'primary loss weight')
    flags.DEFINE_float('l_weight_fusion', 0., 'fusion loss weight')
    flags.DEFINE_float('l_weight_upsample', 3., 'upsample loss weight')
    flags.DEFINE_integer('train_iter', 60000, 'maximum training iterations')
    flags.DEFINE_integer('batch_size', 6, "batch_size")
    flags.DEFINE_float('learning_rate', 1e-4, "initial learning rate")
    flags.DEFINE_bool('adam_optimizer', True,
                      "use adam optimizer (if False momentum optimizer is used)")
    flags.DEFINE_float('momentum', 0.95, "optimizer momentum (if adam_optimizer==False)")
    flags.DEFINE_integer('step', 100000, 'step for lr decay')
    flags.DEFINE_float('gamma', 0.1, 'exponential base for lr decay')
    flags.DEFINE_float('reg', 1e-5, 'scalar multiplier for weight decay (0 to disable)')
    flags.DEFINE_string('weight_initializer', 'xavier', 'weight initializer: random_normal / xavier')
    flags.DEFINE_float('weight_initializer_std', 0.01, 'std for random_normal weight initializer')
    flags.DEFINE_float('bias_initializer', 0.0, 'constant value for bias initializer')


def _define_augmentation_flags():
    """Register basic and artistic (texture / geometric) augmentation flags."""
    flags.DEFINE_bool('augment_basic', True, "use basic augmentation?")
    flags.DEFINE_bool('augment_texture', False, "use artistic texture augmentation?")
    flags.DEFINE_float('p_texture', 0., 'probability of artistic texture augmentation')
    flags.DEFINE_bool('augment_geom', False, "use artistic geometric augmentation?")
    flags.DEFINE_float('p_geom', 0., 'probability of artistic geometric augmentation')


# Register every flag, section by section, in the same order as before.
# Arguments are passed positionally on purpose: the keyword names of the
# DEFINE_* functions differ between TF flag implementations.
_define_path_flags()
_define_logging_flags()
_define_pretrain_flags()
_define_input_flags()
_define_optimization_flags()
_define_augmentation_flags()

FLAGS = flags.FLAGS
# Create the top-level output directory as soon as the module is loaded.
# os.makedirs (rather than os.mkdir) also creates any missing parent
# directories, so nested values such as --output_dir=runs/exp1 work.
# NOTE(review): FLAGS is read here, before tf.app.run() is called; this relies
# on the flags object parsing sys.argv on first attribute access - confirm
# for the TF version in use.
if not os.path.exists(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)
def _ensure_dir(path):
    """Create ``path``, including any missing parent directories, if absent.

    ``os.makedirs`` is used instead of ``os.mkdir`` so that nested output
    locations work even when intermediate directories do not exist yet.
    """
    if not os.path.exists(path):
        os.makedirs(path)


def main(_):
    """Build a DeepHeatmapsModel from the command-line flags and train it.

    Invoked by ``tf.app.run()``; the single positional argument is ignored.

    Side effects: creates the model and log directories under
    ``FLAGS.output_dir``. The sample directory is created only when
    ``FLAGS.sample_to_log`` is False, i.e. when samples go to disk instead of
    the TensorBoard log. Then runs ``model.train()``.
    """
    save_model_path = os.path.join(FLAGS.output_dir, FLAGS.save_model_path)
    save_sample_path = os.path.join(FLAGS.output_dir, FLAGS.save_sample_path)
    save_log_path = os.path.join(FLAGS.output_dir, FLAGS.save_log_path)

    # create directories if they do not exist (parents included)
    _ensure_dir(save_model_path)
    _ensure_dir(save_log_path)
    if not FLAGS.sample_to_log:
        # sampled images are written to disk only when not sent to the log
        _ensure_dir(save_sample_path)

    # The default --img_path begins with '~'. The shell never expands a flag's
    # default value, so expand it here. expanduser leaves paths without a
    # leading '~' unchanged, so explicit paths are passed through as before.
    img_path = os.path.expanduser(FLAGS.img_path)

    model = DeepHeatmapsModel(
        mode='TRAIN', train_iter=FLAGS.train_iter, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate,
        l_weight_primary=FLAGS.l_weight_primary, l_weight_fusion=FLAGS.l_weight_fusion,
        l_weight_upsample=FLAGS.l_weight_upsample, reg=FLAGS.reg, adam_optimizer=FLAGS.adam_optimizer,
        momentum=FLAGS.momentum, step=FLAGS.step, gamma=FLAGS.gamma,
        weight_initializer=FLAGS.weight_initializer, weight_initializer_std=FLAGS.weight_initializer_std,
        bias_initializer=FLAGS.bias_initializer, image_size=FLAGS.image_size, c_dim=FLAGS.c_dim,
        num_landmarks=FLAGS.num_landmarks, sigma=FLAGS.sigma, scale=FLAGS.scale, margin=FLAGS.margin,
        bb_type=FLAGS.bb_type, win_mult=FLAGS.win_mult, augment_basic=FLAGS.augment_basic,
        augment_texture=FLAGS.augment_texture, p_texture=FLAGS.p_texture, augment_geom=FLAGS.augment_geom,
        p_geom=FLAGS.p_geom, output_dir=FLAGS.output_dir, save_model_path=save_model_path,
        save_sample_path=save_sample_path, save_log_path=save_log_path, pre_train_path=FLAGS.pre_train_path,
        load_pretrain=FLAGS.load_pretrain, load_primary_only=FLAGS.load_primary_only,
        img_path=img_path, valid_data=FLAGS.valid_data, valid_size=FLAGS.valid_size,
        log_valid_every=FLAGS.log_valid_every, train_crop_dir=FLAGS.train_crop_dir, img_dir_ns=FLAGS.img_dir_ns,
        print_every=FLAGS.print_every, save_every=FLAGS.save_every, sample_every=FLAGS.sample_every,
        sample_grid=FLAGS.sample_grid, sample_to_log=FLAGS.sample_to_log, debug_data_size=FLAGS.debug_data_size,
        debug=FLAGS.debug, use_epoch_data=FLAGS.use_epoch_data, epoch_data_dir=FLAGS.epoch_data_dir)
    model.train()
if __name__ == '__main__':
    # Hand control to tf.app.run, which parses the command-line flags and
    # then calls this module's main(); main is passed explicitly here rather
    # than being looked up on the __main__ module.
    tf.app.run(main=main)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/Ruanjiyang/face-of-art.git
git@gitee.com:Ruanjiyang/face-of-art.git
Ruanjiyang
face-of-art
face-of-art
master

搜索帮助