1 Star 0 Fork 2

13606799717/YoloV3物体检测

forked from cangye/YoloV3物体检测 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
yolo2.valid.py 2.20 KB
一键复制 编辑 原始数据 按行查看 历史
YUZIYE 提交于 2022-02-18 10:24 . add mobilenet
import torch
import os
import matplotlib.pyplot as plt
from utils.model2 import YoloModel
import torchvision.transforms as transforms
import cv2
from utils.transforms import Resize, DEFAULT_TRANSFORMS
import numpy as np
from utils.nms import rescale_boxes, non_max_suppression
import matplotlib.pyplot as plt
def _load_class_names(path):
    """Read a class-names file (one name per line) into ``{line_index: name}``.

    Surrounding whitespace, including the newline, is stripped from every name.
    """
    with open(path, "r", encoding="utf-8") as f:
        return {i: line.strip() for i, line in enumerate(f)}


def _inside(x, y, width, height):
    """Return True if the point (x, y) lies strictly inside a width x height image.

    Points exactly on the border (x == 0, y == 0, x == width, y == height)
    count as outside, matching the original filter.
    """
    return 0 < x < width and 0 < y < height


def main(weights_path="ckpt/23.pt", image_path="data/input.jpg",
         names_path="ckpt/coco.names", output_path="out2.jpg",
         img_size=416, conf_thres=0.1, nms_thres=0.1,
         device="cpu", show=True):
    """Run the YOLO model on one image, draw the detections, then save/show them.

    Every argument defaults to the value that used to be hard-coded, so
    ``main()`` with no arguments behaves as before.

    Args:
        weights_path: state_dict checkpoint loaded into ``YoloModel``.
        image_path: image read with ``cv2.imread``.
        names_path: class-names file, one name per line; line i names class i.
        output_path: where the annotated image is written.
        img_size: square network input size passed to ``Resize`` and
            ``rescale_boxes``.
        conf_thres: confidence threshold passed to ``non_max_suppression``.
        nms_thres: IoU threshold passed to ``non_max_suppression``.
        device: torch device string the model and input are placed on.
        show: if True, display the result in an OpenCV window and wait for a key.

    Raises:
        FileNotFoundError: if ``image_path`` cannot be read by OpenCV.
        OSError: if the annotated image cannot be written to ``output_path``.
    """
    device = torch.device(device)
    model = YoloModel()
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {image_path}")
    H, W, _ = image.shape
    print(image.shape, image.dtype)

    # NOTE(review): cv2.imread yields BGR channel order. Confirm the training
    # pipeline fed BGR as well; if it used RGB, convert with cv2.cvtColor
    # before this transform.
    # The dummy (1, 5) target only satisfies the transforms' (image, boxes)
    # input contract.
    input_img = transforms.Compose([
        DEFAULT_TRANSFORMS,
        Resize(img_size)])(
        (image, np.zeros((1, 5))))[0].unsqueeze(0).to(device)
    print(input_img.max(), input_img.min(), input_img.shape)

    namedict = _load_class_names(names_path)

    with torch.no_grad():
        detections = model(input_img)
        # The model output is concatenated along dim 1; presumably one tensor
        # per detection scale -- TODO confirm against YoloModel.forward.
        detections = torch.cat(detections, 1)
        detections = non_max_suppression(detections, conf_thres, nms_thres)
        # Only one image is in the batch, so keep its results and map boxes
        # from img_size coordinates back to the original (H, W).
        detections = rescale_boxes(detections[0], img_size, image.shape[:2])

    # Draw on a copy so the array handed to the transforms is never mutated.
    canvas = image.copy()
    for x1, y1, x2, y2, conf, cls_pred in detections.cpu().numpy():
        if not (_inside(x1, y1, W, H) and _inside(x2, y2, W, H)):
            # Boxes with a corner outside the image are reported but not drawn.
            print(x1, y1, x2, y2)
            continue
        print("OUTPUT", H, W, x1, y1, x2, y2, canvas.dtype, canvas.shape)
        cv2.rectangle(canvas, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0))
        # Fall back to the numeric id if the names file is shorter than the
        # model's class count.
        predtype = namedict.get(int(cls_pred), str(int(cls_pred)))
        cv2.putText(canvas, f"Conf:{conf:.2f},{predtype}", (int(x1), int(y1)),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 0, 0))

    # cv2.imwrite returns False instead of raising when it cannot write.
    if not cv2.imwrite(output_path, canvas):
        raise OSError(f"could not write image: {output_path}")
    if show:
        cv2.imshow("www", canvas)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
# Script entry point. Training is a separate script; it was launched in the
# background with:
#   nohup /home/yuzy/software/anaconda39/bin/python yolo.train.py > ckpt/yolo.log 2>&1 &
if __name__ == "__main__":  # run the demo only when executed directly, not on import
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhaoshifu111/yolo-v3-object-detection.git
git@gitee.com:zhaoshifu111/yolo-v3-object-detection.git
zhaoshifu111
yolo-v3-object-detection
YoloV3物体检测
master

搜索帮助