import os
from torch.backends import cudnn
from utils.logger import setup_logger
from datasets import make_dataloader, get_trainloader_uda, get_testloader_uda
from model import make_model
from solver import make_optimizer, WarmupMultiStepLR
from loss import make_loss
from processor import do_uda_train
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import numpy as np
import os
import argparse
from timm.scheduler import create_scheduler
from config import cfg
from timm.data import Mixup
from sklearn.cluster import DBSCAN, KMeans
from utils.faiss_rerank import compute_jaccard_distance
from utils.faiss_rerank import batch_cosine_dist, cosine_dist
# Seed every RNG (torch CPU/CUDA, numpy, stdlib random) so training runs are repeatable.
def set_seed(seed):
    """Seed all random number generators for reproducible training.

    Args:
        seed (int): seed applied to torch (CPU and all CUDA devices),
            numpy, and the stdlib ``random`` module.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # cuDNN must run deterministic kernels for reproducibility; benchmark
    # mode autotunes kernels per input shape and is non-deterministic, so it
    # must be OFF (the original set it to True, defeating this function's
    # stated purpose).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Initialise a model from a (possibly DataParallel-saved) pretrained state dict.
def copy_state_dict(state_dict, model, strip=None):
    """Copy matching parameters from ``state_dict`` into ``model`` in place.

    Keys are normalised by dropping a leading ``'module.'`` (DataParallel
    prefix) and, optionally, a caller-supplied ``strip`` prefix.  Entries
    whose name or shape does not match the target model are skipped (shape
    mismatches are printed); target parameters that received no value are
    reported at the end.

    Returns:
        The same ``model`` instance, updated in place.
    """
    target = model.state_dict()
    loaded = set()
    for key, value in state_dict.items():
        key = key.replace('module.', '')
        if strip is not None and key.startswith(strip):
            key = key[len(strip):]
        if key not in target:
            continue
        if isinstance(value, Parameter):
            value = value.data
        if value.size() != target[key].size():
            print('mismatch:', key, value.size(), target[key].size())
            continue
        target[key].copy_(value)
        loaded.add(key)
    missing = set(target) - loaded
    if missing:
        print("missing keys in state_dict:", missing)
    return model
# Wrapper that lets a finite data loader be drawn from indefinitely.
class IterLoader:
    """Cycling iterator over a data loader.

    ``next()`` transparently restarts the underlying loader when it is
    exhausted, so callers can draw an arbitrary number of batches.

    Args:
        loader: any re-iterable (typically a ``DataLoader``).
        length: optional value reported by ``len()``; falls back to
            ``len(loader)`` when ``None``.
    """

    def __init__(self, loader, length=None):
        self.loader = loader
        self.length = length
        self.iter = None

    def __len__(self):
        if self.length is not None:
            return self.length
        return len(self.loader)

    def new_epoch(self):
        # Reset the iterator for a fresh pass over the loader.
        self.iter = iter(self.loader)

    def next(self):
        # Lazily start the first epoch if new_epoch() was never called
        # (the original relied on a bare ``except`` to cover this case).
        if self.iter is None:
            self.iter = iter(self.loader)
        try:
            return next(self.iter)
        # Only catch exhaustion; the original bare ``except`` also masked
        # genuine errors raised inside the loader.
        except StopIteration:
            self.iter = iter(self.loader)
            return next(self.iter)
# Run the model over a loader and collect features plus per-sample metadata.
def extract_features(model, data_loader, print_freq, device='cuda'):
    """Extract features for every batch in ``data_loader``.

    Args:
        model: feature extractor; put into eval mode and run under no_grad.
        data_loader: yields ``(img, vid, camid, trkid, _)`` batches.
        print_freq (int): print a progress line every ``print_freq`` batches
            (the original accepted this argument but never used it; 0/None
            disables printing).
        device (str): device images are moved to before the forward pass.
            Defaults to ``'cuda'`` for backward compatibility.

    Returns:
        tuple: ``(feats, vids, camids, trkids)`` where ``feats`` is a
        concatenated Tensor (kept on ``device``) and the rest are numpy arrays.
    """
    model.eval()
    feats, vids, camids, trkids = [], [], [], []
    with torch.no_grad():
        for i, (img, vid, camid, trkid, _) in enumerate(data_loader):
            img = img.to(device)
            feat = model(img)
            feats.append(feat)
            vids.extend(vid)
            camids.extend(camid)
            trkids.extend(trkid)
            if print_freq and (i + 1) % print_freq == 0:
                print('Extract Features: [{}]'.format(i + 1))
    feats = torch.cat(feats, dim=0)
    vids = torch.tensor(vids).cpu().numpy()
    camids = torch.tensor(camids).cpu().numpy()
    trkids = torch.tensor(trkids).cpu().numpy()
    return feats, vids, camids, trkids
# Fuse the Jaccard re-ranking distance with the cosine distance between features.
def calc_distmat(feat, k1=30, k2=6, rerank_weight=0.9, cosine_weight=0.1):
    """Compute the blended pairwise distance matrix used for clustering.

    Args:
        feat (Tensor): feature matrix, one row per sample.
        k1, k2 (int): neighbourhood sizes for the k-reciprocal Jaccard
            re-ranking (defaults match the original hard-coded values).
        rerank_weight (float): weight of the Jaccard re-ranking term.
        cosine_weight (float): weight of the cosine-distance term.

    Returns:
        np.ndarray: fused pairwise distance matrix.
    """
    rerank_distmat = compute_jaccard_distance(feat, k1=k1, k2=k2, search_option=3)
    cosine_distmat = batch_cosine_dist(feat, feat).cpu().numpy()
    return rerank_distmat * rerank_weight + cosine_distmat * cosine_weight
def compute_P2(qf, gf, gc, la=3.0):
    """Per-camera regularised inverse covariance matrices and mean vectors.

    Args:
        qf (Tensor): query-set features. Unused; kept for interface
            compatibility with existing callers.
        gf (Tensor): gallery-set features, one row per sample.
        gc (array-like): camera id for each gallery feature.
        la (float, optional): regularisation strength added to the Gram
            matrix before inversion. Defaults to 3.0.

    Returns:
        P (dict): camera id -> inverse of ``(X^T X + n * la * I)`` where X
            are that camera's features and n their count.
        neg_vec (dict): camera id -> mean feature vector of that camera.
    """
    neg_vec = {}
    P = {}
    # The identity matrix is loop-invariant: allocate it once, on the same
    # device/dtype as the features (the original rebuilt it on CUDA every
    # iteration and therefore required a GPU even for CPU tensors).
    identity = torch.eye(gf.shape[1], device=gf.device, dtype=gf.dtype)
    for cam in np.unique(gc):
        curX = gf[gc == cam]
        neg_vec[cam] = torch.mean(curX, dim=0)
        # Regularised Gram-matrix inverse: (X^T X + n*la*I)^-1
        P[cam] = torch.inverse(curX.T.matmul(curX) + curX.shape[0] * la * identity)
    return P, neg_vec
# Camera-bias removal: whiten each feature against its own camera's statistics.
def meanfeat_sub(P, neg_vec, in_feats, in_cams):
    """Subtract the per-camera mean, project by the camera's inverse
    covariance matrix, and L2-normalise each feature.

    Args:
        P (dict): camera id -> inverse covariance matrix.
        neg_vec (dict): camera id -> mean feature vector.
        in_feats (Tensor): input features, one row per sample.
        in_cams: camera id per sample (same length as ``in_feats``).

    Returns:
        Tensor: stacked processed features, same shape as ``in_feats``.
    """
    processed = []
    for feat, cam in zip(in_feats, in_cams):
        centered = feat - neg_vec[cam]
        projected = P[cam].matmul(centered)
        processed.append(projected / torch.norm(projected, p=2))
    return torch.stack(processed)
# Entry point: unsupervised domain-adaptation (UDA) ReID training driven by
# periodic DBSCAN pseudo-labelling.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        "--config_file", default="", help="path to config file", type=str
    )
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine"')
    # parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
    #                     help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')
    parser.add_argument('--epochs', default=120, type=int)
    parser.add_argument("--local_rank", default=0, type=int)
    args = parser.parse_args()
    # Merge the YAML config file and command-line overrides into the global
    # config, then freeze it against further mutation.
    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    print('===========sep camera===========')
    set_seed(cfg.SOLVER.SEED)
    if cfg.MODEL.DIST_TRAIN:
        torch.cuda.set_device(args.local_rank)
    else:
        pass
    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    logger = setup_logger("reid_baseline", output_dir, if_train=True)
    logger.info("Saving model in the path :{}".format(cfg.OUTPUT_DIR))
    # logger.info(args)
    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            # logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))
    if cfg.MODEL.DIST_TRAIN:
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID
    # Build the target-domain test loader, an augmented twin of it, and the
    # model; then initialise the model from pretrained weights.
    val_loader, num_query, testset = get_testloader_uda(cfg)
    aug_loader, num_query, _ = get_testloader_uda(cfg, aug=True)
    num_classes = 1500
    model = make_model(cfg, num_class=num_classes)
    initial_weights = torch.load(cfg.MODEL.PRETRAIN_PATH, map_location='cpu')
    copy_state_dict(initial_weights, model)
    if True:  # always-true guard kept from the original; model is always moved to the local device
        model.to(args.local_rank)
        if torch.cuda.device_count() > 1 and cfg.MODEL.DIST_TRAIN:
            print('Using {} GPUs for training'.format(torch.cuda.device_count()))
            # NOTE(review): after DDP wrapping, the later model.classifier /
            # model.in_planes accesses would need model.module — confirm
            # DIST_TRAIN is actually used with this script.
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
    for epoch in range(cfg.SOLVER.MAX_EPOCHS):
        # Re-cluster the target domain and regenerate pseudo labels every 3 epochs.
        if epoch % 3 == 0:
            # Extract features, labels, camera ids and tracklet ids from the
            # target-domain loader.
            target_features, target_labels, target_camids, target_trkids = extract_features(model, val_loader, print_freq=100)
            # L2-normalise the features.
            target_features = F.normalize(target_features, dim=1)
            # Extract features from the augmented loader.
            aug_features, _, _, _ = extract_features(model, aug_loader, print_freq=100)
            # L2-normalise the augmented features.
            aug_features = F.normalize(aug_features, dim=1)
            # Average the plain and augmented features.
            target_features = (aug_features + target_features) / 2.0
            # Per-camera inverse covariance matrices (P) and mean vectors (neg_vec).
            P, neg_vec = compute_P2(target_features, target_features, target_camids, la=0.0005)
            # Remove each camera's mean component from the features.
            target_features = meanfeat_sub(P, neg_vec, target_features, target_camids)
            # Smooth gallery features towards their tracklet mean
            # (tracklet id -1 marks samples with no tracklet).
            gallery_trkids = target_trkids[num_query:]
            unique_trkids = sorted(list(set(gallery_trkids[gallery_trkids != -1])))
            gallery_features = target_features[num_query:]
            track_features = []
            for i, trkid in enumerate(unique_trkids):
                track_feature = torch.mean(gallery_features[gallery_trkids == trkid], dim=0, keepdim=True)
                tmp_indices = (gallery_trkids == trkid)
                # Heavily weight the tracklet mean (0.9) over the individual feature (0.1).
                gallery_features[tmp_indices] = gallery_features[tmp_indices] * 0.1 + track_feature * 0.9
            target_features[num_query:] = gallery_features
            new_dataset = []
            pids = []
            # Fused (Jaccard + cosine) distance matrix, clipped to [0, 1].
            final_dist = calc_distmat(target_features)
            final_dist[final_dist < 0.0] = 0.0
            final_dist[final_dist > 1.0] = 1.0
            # Add a large constant distance between samples whose camera ids
            # fall on opposite sides of the 40 threshold, so DBSCAN never
            # merges samples across the two camera groups.
            # NOTE(review): the original comment claimed same-camera pairs
            # are penalised, but the code penalises cross-group pairs — the
            # code is taken as authoritative here.
            cam_matches = ((target_camids>=40).astype(np.int32) != (target_camids[:, np.newaxis]>=40).astype(np.int32)).astype(np.int32)
            final_dist = final_dist + 10.0*cam_matches
            # DBSCAN over the precomputed distance matrix yields pseudo
            # labels; -1 marks noise/outlier samples.
            cluster = DBSCAN(eps=0.55, min_samples=10, metric='precomputed', n_jobs=-1)
            pseudo_labels = cluster.fit_predict(final_dist)
            # Keep only samples that were assigned to a cluster.
            labelset = list(set(pseudo_labels[pseudo_labels >= 0]))
            idxs = np.where(np.in1d(pseudo_labels, labelset))
            psolabels = pseudo_labels[idxs]
            psofeatures = target_features[idxs]
            # mean_features holds the mean feature of each cluster; used
            # below to re-initialise the classifier weights.
            mean_features = []
            for label in labelset:
                mean_indices = (psolabels == label)
                mean_features.append(torch.mean(psofeatures[mean_indices], dim=0))
            del target_features
            # Build a new training set of (image, pseudo label, camera id, tracklet id).
            for i, (item, label) in enumerate(zip(testset, pseudo_labels)):
                if label == -1 or label not in labelset:
                    continue
                pids.append(label)
                new_dataset.append((item[0], label, int(item[2]), int(item[3])))
            mean_features = torch.stack(mean_features).cuda()
            num_classes = len(mean_features)
            model.num_classes = len(mean_features)
            # Replace the final classifier with a new linear layer whose
            # weights are the cluster-mean features.
            model.classifier = nn.Linear(model.in_planes, len(mean_features), bias=False)
            model.classifier.weight = nn.Parameter(mean_features)
            print('new class are {}, length of new dataset is {}'.format(len(set(pids)), len(new_dataset)))
            # Rebuild the training loader, loss and optimisers for the new
            # pseudo-label space.
            train_loader = IterLoader(get_trainloader_uda(cfg, new_dataset, num_classes))
            train_loader.new_epoch()
            loss_func, center_criterion = make_loss(cfg, num_classes=num_classes)
            optimizer, optimizer_center = make_optimizer(cfg, model, center_criterion)
        # Train for one epoch on the current pseudo-labelled dataset.
        do_uda_train(
            epoch,
            cfg,
            model,
            center_criterion,
            train_loader,
            val_loader,
            optimizer,
            optimizer_center,
            loss_func,
            num_query, args.local_rank
        )
    print(cfg.OUTPUT_DIR)