1 Star 0 Fork 0

东方佑/AMGAN

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
main.py 2.24 KB
一键复制 编辑 原始数据 按行查看 历史
GuangyuanHao 提交于 2018-01-08 22:22 . update
import argparse
import os
import scipy.misc
import numpy as np
import tensorflow as tf
from model import amgan
parser = argparse.ArgumentParser(description='')
parser.add_argument('--z_dim', dest='z_dim',type= int, default=100,help='z_dimention')
parser.add_argument('--epoch', dest='epoch',type= int, default=1000,help='epoch')
parser.add_argument('--batch_size',dest='batch_size',type= int, default=64, help = 'bactch_size')
parser.add_argument('--train_size',dest='train_size',type= int, default=1e8, help = 'train_size')
parser.add_argument('--fine_size',dest='fine_size',type =int, default= 32, help ='fine_size')
parser.add_argument('--ngf', dest='ngf', type=int, default=64, help = 'first num_channel of generator layer')
parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discri filters in first conv layer')
parser.add_argument('--inputA_nc', dest='inputA_nc', type=int, default=3, help='# of input image A channels')
parser.add_argument('--inputB_nc', dest='inputB_nc', type=int, default=1, help='# of input image B channels')
parser.add_argument('--lr', dest='lr', type=float, default=0.00002, help='initial learning rate for adam')
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--dataset_name', dest='dataset_name', default='svhn_mnist',help='name of dataset')
parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint1', help='models are saved here')
parser.add_argument('--sample_dir', dest='sample_dir', default='./sample1', help='samples are saved here')
parser.add_argument('--log_dir', dest='log_dir', default='./log1', help='logs are saved here')
parser.add_argument('--phase', dest='phase', default='train', help='train, test')
args=parser.parse_args()
def main(_):
if not os.path.exists(args.checkpoint_dir):
os.makedirs(args.checkpoint_dir)
if not os.path.exists(args.sample_dir):
os.makedirs(args.sample_dir)
with tf.Session() as sess:
model = amgan(sess,args)
model.train(args) if args.phase == 'train' \
else model.test(args)
if __name__ == '__main__':
tf.app.run()
# CUDA_VISIBLE_DEVICES=1 python main.py --phase=test L zeus 3 L1 grus 1
# tensorboard --port=6010 --logdir=./log
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/chenyang918/AMGAN.git
git@gitee.com:chenyang918/AMGAN.git
chenyang918
AMGAN
AMGAN
master

搜索帮助

D67c1975 1850385 1daf7b77 1850385