13606799717/YoloV3 Object Detection

forked from cangye/YoloV3 Object Detection
yolo2.train.py 1.27 KB
YUZIYE committed on 2022-02-18 10:24: add mobilenet
import torch
from torch.utils.data import DataLoader

from utils.data import ListDataset
from utils.model2 import YoloModel
from utils.loss import compute_loss


def main():
    device = torch.device("cuda:0")
    dataset = ListDataset("data")
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8, collate_fn=dataset.collate_fn)
    model = YoloModel()
    model.train()
    model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)
    n = 0  # global step counter, shared across epochs
    # Resume from the latest checkpoint; this file must already exist.
    model.load_state_dict(torch.load("ckpt/mobilenet.pt", map_location=device))
    print("Number of samples:", len(dataset))
    for e in range(30):
        # The first element of each batch is not used during training.
        for _, imgs, targets in dataloader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            outs = model(imgs)
            loss = compute_loss(outs, targets)
            loss.backward()
            optim.step()
            optim.zero_grad()
            if n % 50 == 0:
                # Log the scalar loss and refresh the rolling checkpoint every 50 steps.
                print(e, n, loss.item())
                torch.save(model.state_dict(), "ckpt/mobilenet.pt")
            n += 1
        # Keep a separate checkpoint at the end of every epoch.
        torch.save(model.state_dict(), f"ckpt/{e}.pt")
        print(loss.item())


# nohup /home/yuzy/software/anaconda39/bin/python yolo.train.py > ckpt/yolo.log 2>&1 &
if __name__ == "__main__":
    main()
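
The loop above prints the loss of a single batch every 50 steps, which can fluctuate a lot from batch to batch. As a minimal sketch, assuming an exponential moving average is acceptable for logging, the helper below smooths the printed value. `SmoothedLoss` and its `beta` parameter are hypothetical names for illustration and are not part of the original script.

```python
class SmoothedLoss:
    """Exponential moving average of a scalar loss.

    Hypothetical logging helper; not part of the original training script.
    """

    def __init__(self, beta: float = 0.98):
        if not 0.0 <= beta < 1.0:
            raise ValueError("beta must be in [0, 1)")
        self.beta = beta   # weight given to the running history
        self.avg = 0.0     # uncorrected running average
        self.steps = 0     # number of values seen so far

    def update(self, value: float) -> float:
        """Fold one loss value in and return the bias-corrected average."""
        self.steps += 1
        self.avg = self.beta * self.avg + (1.0 - self.beta) * value
        # Dividing by (1 - beta^t) removes the bias toward the zero start value.
        return self.avg / (1.0 - self.beta ** self.steps)
```

Inside the training loop this could be used as `smooth = meter.update(loss.item())` on every step, printing `smooth` in place of the raw loss at the 50-step logging interval.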