2 Star 0 Fork 0

alibaba/catex

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
train.py 6.33 KB
一键复制 编辑 原始数据 按行查看 历史
Zhihang-Fu 提交于 2024-11-13 17:34 . Add files via upload
import argparse
import torch
from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.engine import build_trainer
# custom
import datasets.oxford_pets
import datasets.oxford_flowers
import datasets.fgvc_aircraft
import datasets.dtd
import datasets.eurosat
import datasets.stanford_cars
import datasets.food101
import datasets.sun397
import datasets.caltech101
import datasets.ucf101
import datasets.imagenet
import datasets.cifar_
import datasets.imagenet_sketch
import datasets.imagenetv2
import datasets.imagenet_a
# import datasets.imagenet_r
import trainers.catex
def print_args(args, cfg):
    """Print the command-line arguments and the config to stdout.

    Args:
        args: Namespace-like object. Every instance attribute is printed
            on its own line as ``key: value``, ordered by key.
        cfg: Configuration object, printed as-is with ``print(cfg)``.
    """
    print("***************")
    print("** Arguments **")
    print("***************")
    arg_values = vars(args)
    # sorted() over the mapping yields its keys in ascending order.
    for key in sorted(arg_values):
        print("{}: {}".format(key, arg_values[key]))
    print("************")
    print("** Config **")
    print("************")
    print(cfg)
def reset_cfg(cfg, args):
    """Copy command-line overrides from ``args`` into ``cfg`` in place.

    Each option is written to its config entry only when it was actually
    supplied (i.e. it is truthy); otherwise the existing config value is
    left untouched. The seed is the one exception, see below.

    Args:
        cfg: Mutable config object exposing the ``DATASET``, ``TRAINER``,
            ``INPUT`` and ``MODEL`` sub-nodes written below.
        args: Parsed command-line arguments.
    """
    if args.root:
        cfg.DATASET.ROOT = args.root

    if args.ood_test:
        cfg.TRAINER.OOD_TEST = args.ood_test

    if args.ood_train:
        cfg.TRAINER.OOD_TRAIN = args.ood_train

    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir

    if args.resume:
        cfg.RESUME = args.resume

    # Test against None instead of truthiness: 0 is falsy, so a plain
    # ``if args.seed:`` would silently discard an explicit seed of 0.
    # Negative values (such as the parser default of -1) are still copied.
    if args.seed is not None:
        cfg.SEED = args.seed

    if args.source_domains:
        cfg.DATASET.SOURCE_DOMAINS = args.source_domains

    if args.target_domains:
        cfg.DATASET.TARGET_DOMAINS = args.target_domains

    if args.transforms:
        cfg.INPUT.TRANSFORMS = args.transforms

    if args.trainer:
        cfg.TRAINER.NAME = args.trainer

    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone

    if args.head:
        cfg.MODEL.HEAD.NAME = args.head
def extend_cfg(cfg):
    """Register the extra config variables this script relies on.

    Adds a ``TRAINER.CATEX`` node with its default values and the
    ``DATASET.SUBSAMPLE_CLASSES`` entry to ``cfg`` (modified in place).

    New variables are declared like this::

        from yacs.config import CfgNode as CN
        cfg.TRAINER.MY_MODEL = CN()
        cfg.TRAINER.MY_MODEL.PARAM_A = 1.
        cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
        cfg.TRAINER.MY_MODEL.PARAM_C = False
    """
    from yacs.config import CfgNode as CN

    # Populate the CATEX node first, then attach it to the trainer section.
    catex = CN()
    catex.N_CTX = 16  # number of context vectors
    catex.CSC = False  # class-specific context
    catex.CTX_INIT = ""  # initialization words: "a photo of a" / ensemble / ensemble with learned
    catex.PREC = "fp16"  # fp16, fp32, amp
    catex.CLASS_TOKEN_POSITION = "end"  # 'middle' or 'end' or 'front'
    cfg.TRAINER.CATEX = catex

    cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new
def setup_cfg(args):
    """Assemble the frozen run configuration from defaults and ``args``.

    Sources are applied in increasing order of precedence:

    1. the default config, extended by ``extend_cfg``;
    2. the dataset config file (``args.dataset_config_file``), if given;
    3. the method config file (``args.config_file``), if given;
    4. the named command-line options handled by ``reset_cfg``;
    5. the free-form ``KEY VALUE`` pairs collected in ``args.opts``.

    Returns:
        The resulting config, frozen against further modification.
    """
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # Dataset file first, method file second, so the latter wins on overlap.
    for config_path in (args.dataset_config_file, args.config_file):
        if config_path:
            cfg.merge_from_file(config_path)

    reset_cfg(cfg, args)
    cfg.merge_from_list(args.opts)

    cfg.freeze()
    return cfg
def main(args):
    """Build the config and trainer, then run one of three modes.

    * ``args.ood_test``: optionally load a model from ``args.model_dir``,
      call ``trainer.test_ood`` and return.
    * ``args.eval_only``: load a model when ``TRAINER.CATEX.CTX_INIT`` is
      empty, call ``trainer.test`` and return.
    * otherwise, unless ``args.no_train`` is set: optionally load a model
      from ``args.model_dir`` and call ``trainer.train``.

    Args:
        args: Parsed command-line arguments (see the parser in the
            ``__main__`` section).
    """
    # NOTE(review): the nesting of the branches below was reconstructed from
    # control flow; confirm it against the upstream file.
    cfg = setup_cfg(args)
    if cfg.SEED >= 0:
        print("Setting fixed seed: {}".format(cfg.SEED))
        set_random_seed(cfg.SEED)
    setup_logger(cfg.OUTPUT_DIR)

    if torch.cuda.is_available() and cfg.USE_CUDA:
        torch.backends.cudnn.benchmark = True

    # print_args(args, cfg)
    # print("Collecting env info ...")
    # print("** System info **\n{}\n".format(collect_env_info()))

    trainer = build_trainer(cfg)

    if args.ood_test:
        if args.model_dir != '':
            assert cfg.TRAINER.CATEX.CTX_INIT in ['', None, 'ensemble', 'ensemble_learned']
            trainer.load_model(args.model_dir, epoch=args.load_epoch)
        trainer.test_ood(model_directory=args.model_dir)
        return

    if args.eval_only:
        if cfg.TRAINER.CATEX.CTX_INIT == '':
            trainer.load_model(args.model_dir, epoch=args.load_epoch)
        trainer.test()
        return

    if not args.no_train:
        if args.model_dir != '':
            trainer.load_model(args.model_dir, epoch=args.load_epoch)
        # In this file TRAINER.OOD_TRAIN is only assigned by reset_cfg when
        # --ood-train is passed, so it may be absent here; treat a missing
        # key as "disabled" instead of failing with AttributeError.
        if getattr(cfg.TRAINER, "OOD_TRAIN", False):
            trainer.forward_backward = trainer.forward_backward_ood
        trainer.train()
if __name__ == "__main__":
    # Script entry point: declare the command-line interface, parse it and
    # hand the result to main(). Declaration order is kept stable because it
    # determines the order of the generated --help output.
    parser = argparse.ArgumentParser()

    # Paths and reproducibility.
    parser.add_argument("--root", type=str, default="", help="path to dataset")
    parser.add_argument("--output-dir", type=str, default="", help="output directory")
    parser.add_argument(
        "--resume", type=str, default="",
        help="checkpoint directory (from which the training resumes)",
    )
    parser.add_argument(
        "--seed", type=int, default=-1,
        help="only positive value enables a fixed seed",
    )

    # Options that take one or more string values.
    for flag, text in (
        ("--source-domains", "source domains for DA/DG"),
        ("--target-domains", "target domains for DA/DG"),
        ("--transforms", "data augmentation methods"),
    ):
        parser.add_argument(flag, type=str, nargs="+", help=text)

    # Config files.
    parser.add_argument("--config-file", type=str, default="", help="path to config file")
    parser.add_argument(
        "--dataset-config-file", type=str, default="",
        help="path to config file for dataset setup",
    )

    # OOD switches.
    for flag, text in (
        ("--ood-test", "flag for ood test"),
        ("--ood-train", "flag for ood train"),
    ):
        parser.add_argument(flag, action="store_true", help=text)

    # Model component names.
    for flag, text in (
        ("--trainer", "name of trainer"),
        ("--backbone", "name of CNN backbone"),
        ("--head", "name of head"),
    ):
        parser.add_argument(flag, type=str, default="", help=text)

    # Evaluation and checkpoint loading.
    parser.add_argument("--eval-only", action="store_true", help="evaluation only")
    parser.add_argument(
        "--model-dir", type=str, default="",
        help="load model from this directory for eval-only mode",
    )
    parser.add_argument(
        "--load-epoch", type=int,
        help="load model weights at this epoch for evaluation",
    )
    parser.add_argument(
        "--no-train", action="store_true", help="do not call trainer.train()"
    )

    # Everything left on the command line is collected into ``opts``.
    parser.add_argument(
        "opts", default=None, nargs=argparse.REMAINDER,
        help="modify config options using the command-line",
    )

    args = parser.parse_args()
    main(args)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mirrors_alibaba/catex.git
git@gitee.com:mirrors_alibaba/catex.git
mirrors_alibaba
catex
catex
main

搜索帮助