代码拉取完成,页面将自动刷新
同步操作将从 cangye/YoloV3物体检测 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
from utils.data import ListDataset
from torch.utils.data import DataLoader
import torch
import os
import matplotlib.pyplot as plt
from utils.model2 import YoloModel
from utils.loss import compute_loss
def main():
device = torch.device("cuda:0")
dataset = ListDataset("data")
dataloader = DataLoader(dataset, 32, shuffle=True, num_workers=8, collate_fn=dataset.collate_fn)
model = YoloModel()
model.train()
model.to(device)
optim = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-3)
n = 0
model.load_state_dict(torch.load("ckpt/mobilenet.pt"))
print("样本数量", len(dataset))
for e in range(30):
for temp, imgs, targets in dataloader:
imgs = imgs.to(device)
targets = targets.to(device)
outs = model(imgs)
loss = compute_loss(outs, targets)
loss.backward()
optim.step()
optim.zero_grad()
if n % 50 == 0:
print(e, n, loss)
torch.save(model.state_dict(), f"ckpt/mobilenet.pt")
n += 1
torch.save(model.state_dict(), f"ckpt/{e}.pt")
print(loss)
#nohup /home/yuzy/software/anaconda39/bin/python yolo.train.py > ckpt/yolo.log 2>&1 &
if __name__ == "__main__":
main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。