zhoub86/DeepDIG

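This repository does not declare an open-source license file (LICENSE). Before using it, check the project description and the licenses of its upstream code dependencies.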
config.py
Hamid Karimi committed on 2020-04-28 15:47: "Update config.py"
import argparse
import os


def str2bool(v):
    """Convert a command-line string such as 'yes' or '0' to a bool."""
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def path(p):
    """Expand a leading '~' in a filesystem path."""
    return os.path.expanduser(p)


# Set this to the project directory before running.
PATH = 'CHANGE ME PLEASE'

parser = argparse.ArgumentParser(description='Arguments of the DeepDIG project')
parser.add_argument("--project-dir", default=PATH)
parser.add_argument("--dataset", default='MNIST')
parser.add_argument("--pre-trained-model", default='CNN')
parser.add_argument("--dropout", type=float, required=False, default=0.0, help="Dropout ratio")
parser.add_argument("--lr", type=float, required=False, default=0.01, help="Learning rate")
parser.add_argument("--step-size-scheduler", type=int, required=False, default=1000,
                    help="Step size of the learning-rate scheduler")
parser.add_argument("--gamma-scheduler", type=float, required=False, default=0.95,
                    help="Gamma of the learning-rate scheduler")
parser.add_argument('--cuda', type=str2bool, default=True, help='enables CUDA training')
parser.add_argument('--steps', type=int, default=5000,
                    help='number of steps to train (default: 5000)')
parser.add_argument('--batch_size', type=int, default=128,
                    help='input batch size for training (default: 128)')
parser.add_argument("--middle-point-threshold", type=float, required=False, default=0.0001,
                    help="Parameter beta in Algorithm 1")
parser.add_argument("--alpha", type=float, required=False, default=0.8,
                    help="Coefficient of the target loss")
parser.add_argument("--classes", type=str, default="1;2",
                    help="The investigated classes, separated by ';'")
parser.add_argument('--save_samples', type=str2bool, default=True,
                    help='enables saving generated samples')
parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
parser.add_argument('--pre-trained-model-input-shape', type=str, default="1;28;28",
                    help="shape of the input data to the pre-trained model, separated by ';'")
parser.add_argument("--num-samples-trajectory", type=int, required=False, default=50,
                    help="Number of samples generated along the trajectory x(t)=t*x0+(1-t)*x1")
args = parser.parse_args()
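`--classes` and `--pre-trained-model-input-shape` are stored as raw ';'-separated strings (defaults "1;2" and "1;28;28"), so some other part of the project has to decode them. That code is not in this file. The sketch below assumes the fields are plain integers split on ';'. The helper name `parse_int_list` is hypothetical and not part of DeepDIG.

```python
from typing import List, Tuple


def parse_int_list(value: str, sep: str = ";") -> List[int]:
    """Split a separator-delimited string such as "1;28;28" into ints.

    Hypothetical helper: config.py keeps these options as strings, so
    downstream code must turn them into numbers before use.
    """
    parts = [p.strip() for p in value.split(sep)]
    # Reject empty fields, e.g. the trailing one in "1;2;" or the gap in "1;;2".
    if any(p == "" for p in parts):
        raise ValueError(f"empty field in {value!r}")
    return [int(p) for p in parts]


if __name__ == "__main__":
    classes: Tuple[int, ...] = tuple(parse_int_list("1;2"))
    input_shape: Tuple[int, ...] = tuple(parse_int_list("1;28;28"))
    print(classes, input_shape)  # (1, 2) (1, 28, 28)
```

With the parsed arguments this would be called as `parse_int_list(args.classes)`. Passing the helper as `type=` to `add_argument` is another option. Because argparse also applies `type` to string defaults, "1;2" would then arrive already converted, and malformed values would be reported as command-line errors.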
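The `--num-samples-trajectory` help text names the line x(t)=t*x0+(1-t)*x1 but not how t is chosen. The sketch below assumes t is spaced evenly over [0, 1] with both endpoints included, and that x0 and x1 are flattened inputs of equal length. The function name `trajectory` is hypothetical, and this is not necessarily how DeepDIG samples the line.

```python
from typing import List, Sequence


def trajectory(x0: Sequence[float], x1: Sequence[float],
               num_samples: int = 50) -> List[List[float]]:
    """Return num_samples points on x(t) = t*x0 + (1-t)*x1.

    Assumption: t is evenly spaced in [0, 1], endpoints included,
    so the first point is x1 (t = 0) and the last is x0 (t = 1).
    """
    if len(x0) != len(x1):
        raise ValueError("x0 and x1 must have the same length")
    if num_samples < 2:
        raise ValueError("num_samples must be at least 2 to include both endpoints")
    points = []
    for i in range(num_samples):
        t = i / (num_samples - 1)
        # Element-wise convex combination of the two inputs.
        points.append([t * a + (1 - t) * b for a, b in zip(x0, x1)])
    return points


if __name__ == "__main__":
    print(trajectory([0.0, 0.0], [1.0, 2.0], num_samples=3))
    # [[1.0, 2.0], [0.5, 1.0], [0.0, 0.0]]
```

With the default `--num-samples-trajectory 50`, this yields 50 points between the two inputs. For image tensors, the same formula applies element-wise.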