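# NOTE: stray web-page text captured with this file, not part of the program:
# "Code pull completed; the page will refresh automatically."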
from __future__ import division
from models import *
from utils.logger import *
from utils.utils import *
from utils.datasets import *
from utils.augmentations import *
from utils.transforms import *
from utils.parse_config import *
from test import evaluate
from terminaltables import AsciiTable
import os
import sys
import time
import datetime
import argparse
import tqdm
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim
def str2bool(value):
    """Parse a command-line boolean value.

    argparse passes option values as strings, and every non-empty string
    (including "False") is truthy, so boolean options need explicit parsing.

    Args:
        value: A bool (returned unchanged) or text such as "true"/"false",
            "yes"/"no", "on"/"off", "1"/"0" (case-insensitive).

    Returns:
        The corresponding bool.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Training entry point: parses options, builds the model and the data loader,
# then alternates training epochs with periodic evaluation and checkpointing.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=8, help="size of each image batch")
    parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
    parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
    parser.add_argument("--data_config", type=str, default="config/coco.data", help="path to data config file")
    parser.add_argument("--pretrained_weights", type=str, help="if specified starts from checkpoint model")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")
    parser.add_argument("--evaluation_interval", type=int, default=1, help="interval evaluations on validation set")
    # Boolean options go through str2bool so that "--compute_map False" really
    # yields False; a bare "--compute_map" (no value) means True.
    parser.add_argument(
        "--compute_map",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="if True computes mAP every tenth batch",
    )
    parser.add_argument(
        "--multiscale_training",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="allow for multi-scale training",
    )
    parser.add_argument("--verbose", "-v", default=False, action="store_true", help="Makes the training more verbose")
    parser.add_argument("--logdir", type=str, default="logs", help="Defines the directory where the training log files are stored")
    opt = parser.parse_args()
    print(opt)

    logger = Logger(opt.logdir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("output", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # Get data configuration
    data_config = parse_data_config(opt.data_config)
    train_path = data_config["train"]
    valid_path = data_config["valid"]
    class_names = load_classes(data_config["names"])

    # Initiate model
    model = Darknet(opt.model_def).to(device)
    model.apply(weights_init_normal)

    # If specified we start from checkpoint
    if opt.pretrained_weights:
        if opt.pretrained_weights.endswith(".pth"):
            # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
            model.load_state_dict(torch.load(opt.pretrained_weights, map_location=device))
        else:
            model.load_darknet_weights(opt.pretrained_weights)

    # Get dataloader
    dataset = ListDataset(
        train_path,
        multiscale=opt.multiscale_training,
        img_size=opt.img_size,
        transform=AUGMENTATION_TRANSFORMS,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )

    optimizer = torch.optim.Adam(model.parameters())

    # Per-layer metrics shown in the verbose console table, in display order.
    metrics = [
        "grid_size",
        "loss",
        "x",
        "y",
        "w",
        "h",
        "conf",
        "cls",
        "cls_acc",
        "recall50",
        "recall75",
        "precision",
        "conf_obj",
        "conf_noobj",
    ]

    # Format string per metric; built once instead of on every batch.
    formats = {m: "%.6f" for m in metrics}
    formats["grid_size"] = "%2d"
    formats["cls_acc"] = "%.2f%%"

    num_batches = len(dataloader)

    for epoch in range(opt.epochs):
        model.train()
        start_time = time.time()
        for batch_i, (_, imgs, targets) in enumerate(tqdm.tqdm(dataloader, desc=f"Training Epoch {epoch}")):
            batches_done = num_batches * epoch + batch_i

            imgs = imgs.to(device)
            targets = targets.to(device)

            loss, outputs = model(imgs, targets)
            loss.backward()

            # Accumulates gradient before each step.
            # NOTE: batches_done starts at 0, so the very first step is taken
            # after a single batch.
            if batches_done % opt.gradient_accumulations == 0:
                optimizer.step()
                optimizer.zero_grad()

            # ----------------
            #   Log progress
            # ----------------

            loss_value = to_cpu(loss).item()

            # Tensorboard logging
            tensorboard_log = []
            for j, yolo in enumerate(model.yolo_layers):
                for name, metric in yolo.metrics.items():
                    if name != "grid_size":
                        tensorboard_log += [(f"train/{name}_{j+1}", metric)]
            tensorboard_log += [("train/loss", loss_value)]
            logger.list_of_scalars_summary(tensorboard_log, batches_done)

            # The console report is only printed in verbose mode, so it is
            # only built in verbose mode.
            if opt.verbose:
                log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (epoch, opt.epochs, batch_i, num_batches)

                metric_table = [["Metrics", *[f"YOLO Layer {i}" for i in range(len(model.yolo_layers))]]]

                # Log metrics at each YOLO layer
                for metric in metrics:
                    row_metrics = [formats[metric] % yolo.metrics.get(metric, 0) for yolo in model.yolo_layers]
                    metric_table += [[metric, *row_metrics]]

                log_str += AsciiTable(metric_table).table
                log_str += f"\nTotal loss {loss_value}"

                # Determine approximate time left for epoch
                epoch_batches_left = num_batches - (batch_i + 1)
                time_left = datetime.timedelta(seconds=epoch_batches_left * (time.time() - start_time) / (batch_i + 1))
                log_str += f"\n---- ETA {time_left}"

                print(log_str)

            model.seen += imgs.size(0)

        if epoch % opt.evaluation_interval == 0:
            print("\n---- Evaluating Model ----")
            # Evaluate the model on the validation set
            metrics_output = evaluate(
                model,
                path=valid_path,
                iou_thres=0.5,
                conf_thres=0.5,
                nms_thres=0.5,
                img_size=opt.img_size,
                batch_size=8,
            )

            if metrics_output is not None:
                precision, recall, AP, f1, ap_class = metrics_output
                evaluation_metrics = [
                    ("validation/precision", precision.mean()),
                    ("validation/recall", recall.mean()),
                    ("validation/mAP", AP.mean()),
                    ("validation/f1", f1.mean()),
                ]
                logger.list_of_scalars_summary(evaluation_metrics, epoch)

                # Print class APs and mAP
                ap_table = [["Index", "Class name", "AP"]]
                for i, c in enumerate(ap_class):
                    ap_table += [[c, class_names[c], "%.5f" % AP[i]]]
                print(AsciiTable(ap_table).table)
                print(f"---- mAP {AP.mean()}")
            else:
                print("---- mAP not measured (no detections found by model)")

        if epoch % opt.checkpoint_interval == 0:
            torch.save(model.state_dict(), f"checkpoints/yolov3_ckpt_{epoch}.pth")
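# NOTE: stray web-page text captured with this file, not part of the program:
# "This section may contain content unsuitable for display, so the page does not show it. You can review and modify it with the relevant editing functions."
# "If you confirm the content involves no inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false or worthless content, and nothing that violates national laws and regulations, you can click submit to appeal and we will handle it as soon as possible."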