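"""train_cam.py from the vegetable0511/ReID2024 repository.

Training script for the ReID baseline in which the camera ID (rather than the
person ID) is used as the classification target. Supports optional FP16/AMP
training and multi-GPU training via DistributedDataParallel, driven by a yacs
config file.
"""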
import os
import logging
import time
from torch.backends import cudnn
from utils.logger import setup_logger
from datasets import make_dataloader
from model import make_model
from solver import make_optimizer, WarmupMultiStepLR
from loss import make_loss
import random
import torch
import numpy as np
import argparse
from timm.scheduler import create_scheduler
from config import cfg
from timm.data import Mixup
from torch.nn.parallel import DistributedDataParallel
from torch.cuda import amp
import torch.distributed as dist
from utils.meter import AverageMeter


def set_seed(seed):
    # Fix all random seeds for reproducibility
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # NOTE: benchmark=True lets cuDNN auto-tune kernels, which can make runs
    # nondeterministic despite the deterministic flag above
    torch.backends.cudnn.benchmark = True


def do_train(cfg,
             model,
             center_criterion,
             train_loader,
             val_loader,
             optimizer,
             optimizer_center,
             scheduler,
             loss_fn,
             num_query,
             local_rank):
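    """Train the model using the camera ID as the classification target.

    Logs loss/accuracy every cfg.SOLVER.LOG_PERIOD iterations and saves a
    checkpoint to cfg.OUTPUT_DIR every cfg.SOLVER.CHECKPOINT_PERIOD epochs.
    val_loader, num_query and center_criterion are accepted but not used here,
    and optimizer_center is zeroed each iteration but never stepped, so no
    evaluation or center-loss update happens in this function.
    """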
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    device = "cuda"
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info('start training')
    _LOCAL_PROCESS_GROUP = None

    if device:
        model.to(local_rank)
        if torch.cuda.device_count() > 1 and cfg.MODEL.DIST_TRAIN:
            # Wrap the model for multi-GPU training with DistributedDataParallel
            print('Using {} GPUs for training'.format(torch.cuda.device_count()))
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)

    scaler = amp.GradScaler()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    # train
    for epoch in range(1, epochs + 1):
        start_time = time.time()
        loss_meter.reset()
        acc_meter.reset()
        scheduler.step(epoch)
        model.train()
        for n_iter, (img, vid, target_cam) in enumerate(train_loader):
            # vid (the person ID) is loaded but unused; the camera ID is the training target
            optimizer.zero_grad()
            optimizer_center.zero_grad()
            img = img.to(device)
            target_cam = target_cam.to(device)
            if cfg.SOLVER.FP16_ENABLED:
                # FP16 / mixed-precision training with automatic loss scaling
                with amp.autocast(enabled=True):
                    score, feat = model(img, target_cam, cam_label=None)
                    # Pass target_cam as both the classification target and the
                    # camera label, matching the FP32 branch below
                    loss = loss_fn(score, feat, target_cam, target_cam)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                score, feat = model(img, target_cam, cam_label=None)
                loss = loss_fn(score, feat, target_cam, target_cam)
                loss.backward()
                optimizer.step()
            if isinstance(score, list):
                acc = (score[0].max(1)[1] == target_cam).float().mean()
            else:
                acc = (score.max(1)[1] == target_cam).float().mean()

            loss_meter.update(loss.item(), img.shape[0])
            acc_meter.update(acc, 1)

            torch.cuda.synchronize()
            if (n_iter + 1) % log_period == 0:
                base_lr = scheduler._get_lr(epoch)[0] if cfg.SOLVER.WARMUP_METHOD == 'cosine' else scheduler.get_lr()[0]
                logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                            .format(epoch, (n_iter + 1), len(train_loader), loss_meter.avg, acc_meter.avg, base_lr))

        end_time = time.time()
        time_per_batch = (end_time - start_time) / (n_iter + 1)
        if cfg.MODEL.DIST_TRAIN:
            pass
        else:
            logger.info("Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]"
                        .format(epoch, time_per_batch, train_loader.batch_size / time_per_batch))
        if epoch % checkpoint_period == 0:
            if cfg.MODEL.DIST_TRAIN:
                # In distributed training, only rank 0 writes the checkpoint
                if dist.get_rank() == 0:
                    torch.save(model.module.state_dict(),
                               os.path.join(cfg.OUTPUT_DIR, cfg.MODEL.NAME + '_{}.pth'.format(epoch)))
            else:
                torch.save(model.state_dict(),
                           os.path.join(cfg.OUTPUT_DIR, cfg.MODEL.NAME + '_{}.pth'.format(epoch)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        "--config_file", default="", help="path to config file", type=str
    )
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)
    # Arguments consumed by timm's create_scheduler when the cosine schedule is used
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    # parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
    #                     help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')
    parser.add_argument('--epochs', default=120, type=int)
    parser.add_argument("--local_rank", default=0, type=int)
    args = parser.parse_args()

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    set_seed(cfg.SOLVER.SEED)

    if cfg.MODEL.DIST_TRAIN:
        torch.cuda.set_device(args.local_rank)
    else:
        pass
    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = setup_logger("reid_baseline", output_dir, if_train=True)
    logger.info("Saving model in the path: {}".format(cfg.OUTPUT_DIR))
    # logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            # logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    if cfg.MODEL.DIST_TRAIN:
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID

    train_loader, val_loader, num_query, num_classes = make_dataloader(cfg)
    model = make_model(cfg, num_class=num_classes)
    loss_func, center_criterion = make_loss(cfg, num_classes=num_classes)
    optimizer, optimizer_center = make_optimizer(cfg, model, center_criterion)

    # Drive the scheduler choice from the config: cosine via timm's
    # create_scheduler, otherwise the repo's WarmupMultiStepLR
    args.sched = cfg.SOLVER.WARMUP_METHOD
    args.epochs = cfg.SOLVER.MAX_EPOCHS
    args.warmup_epochs = cfg.SOLVER.WARMUP_EPOCHS
    if args.sched == 'cosine':
        print('===========using cosine learning rate=======')
        scheduler, _ = create_scheduler(args, optimizer)
    else:
        print('===========using normal learning rate=======')
        scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA,
                                      cfg.SOLVER.WARMUP_FACTOR,
                                      cfg.SOLVER.WARMUP_EPOCHS, cfg.SOLVER.WARMUP_METHOD)

    do_train(
        cfg,
        model,
        center_criterion,
        train_loader,
        val_loader,
        optimizer,
        optimizer_center,
        scheduler,  # modify for using self trained model
        loss_func,
        num_query, args.local_rank
    )
    print(cfg.OUTPUT_DIR)
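
# Example invocations (illustrative only; the config path below is a hypothetical placeholder):
#   Single GPU:
#     python train_cam.py --config_file configs/example.yml
#   Multiple GPUs with DistributedDataParallel (requires MODEL.DIST_TRAIN True in the config):
#     python -m torch.distributed.launch --nproc_per_node=2 train_cam.py --config_file configs/example.yml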