1 Star 0 Fork 2

13606799717/YoloV3物体检测

forked from cangye/YoloV3物体检测 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
yolo2.onnx.py 3.38 KB
一键复制 编辑 原始数据 按行查看 历史
cangye 提交于 2022-03-03 02:42 . update yolo2.onnx.py.
import cv2
import onnxruntime as rt
import numpy as np
def box_iou(boxes1, boxes2):
    """Compute the IoU between every box in ``boxes1`` and the single box ``boxes2``.

    Boxes are in corner format ``(x1, y1, x2, y2)``; any extra columns/entries
    (e.g. confidence, class id) are ignored. Widths and heights use the
    inclusive-pixel convention ``x2 - x1 + 1``, so a box whose corners coincide
    has area 1.

    Args:
        boxes1: array-like of shape ``(N, >=4)``.
        boxes2: sequence/array with at least 4 entries describing one box.

    Returns:
        ``np.ndarray`` of shape ``(N,)`` holding the IoU of each row of
        ``boxes1`` against ``boxes2``.
    """
    boxes1 = np.asarray(boxes1)
    b1_x1, b1_y1, b1_x2, b1_y2 = boxes1[:, 0], boxes1[:, 1], boxes1[:, 2], boxes1[:, 3]
    b2_x1, b2_y1, b2_x2, b2_y2 = boxes2[0], boxes2[1], boxes2[2], boxes2[3]

    # Element-wise max/min against the reference box. This never casts boxes2
    # to boxes1's dtype, so integer boxes1 cannot truncate fractional boxes2
    # coordinates (which would make the intersection disagree with b2_area).
    inter_w = np.clip(np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1) + 1, 0, np.inf)
    inter_h = np.clip(np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1) + 1, 0, np.inf)
    inter_area = inter_w * inter_h

    # Union area
    b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
    # The tiny epsilon avoids division by zero for degenerate boxes.
    return inter_area / (b1_area + b2_area - inter_area + 1e-16)
def non_max_supression(detec, c=0.3, i=0.1):
    """Filter raw predictions by confidence and apply greedy non-maximum suppression.

    Args:
        detec: 2-D array, one prediction per row laid out as
            ``(cx, cy, w, h, conf, class_score_0, class_score_1, ...)``.
        c: confidence threshold; rows with ``conf <= c`` are discarded.
        i: IoU threshold; a remaining box whose IoU with an already-kept box
            is ``>= i`` is suppressed.

    Returns:
        ``np.ndarray`` of shape ``(K, 6)`` with rows
        ``(x1, y1, x2, y2, conf, class_id)`` in decreasing confidence order.
        ``K`` may be 0 when nothing passes the threshold.

    Note:
        Suppression is class-agnostic: boxes of different classes can suppress
        each other.
    """
    detec = detec[detec[:, 4] > c]
    cx, cy, w, h, conf = detec[:, 0], detec[:, 1], detec[:, 2], detec[:, 3], detec[:, 4]
    class_id = np.argmax(detec[:, 5:], axis=1)
    half_w, half_h = w / 2, h / 2
    # Convert centre/size to corner format.
    out = np.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h, conf, class_id]).T
    # Greedy NMS must visit boxes in *decreasing* confidence so that the most
    # confident box of each overlapping cluster is the one that survives.
    out = out[np.argsort(-conf, kind="stable")]
    outputs = []
    while len(out) > 0:
        best = out[0]
        outputs.append(best)
        # Drop ``best`` explicitly and compare it only with the remaining boxes,
        # so the loop always terminates (even when i > 1 or a degenerate box has
        # a self-IoU below the threshold).
        rest = out[1:]
        out = rest[box_iou(rest, best) < i]
    if not outputs:
        # np.stack([]) would raise; report "no detections" as an empty array.
        return np.zeros((0, 6), dtype=out.dtype)
    return np.stack(outputs)
def main():
    """Run the ONNX YOLO detector on ``data/input.jpg`` and visualise the result.

    Pipeline:
    1. Letterbox the image onto a square canvas and resize it to 416x416.
    2. Run ``ckpt/mobilenet.yolo.onnx`` and apply confidence filtering + NMS.
    3. Map the boxes back to original-image pixels and draw them.
    4. Save ``out2.jpg`` and show the image until a key is pressed.

    Raises:
        FileNotFoundError: if ``data/input.jpg`` cannot be read.
    """
    input_size = 416
    image = cv2.imread("data/input.jpg")
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError("cannot read image: data/input.jpg")
    H, W, C = image.shape
    S = max(H, W)

    # Letterbox: centre the image on a zero-filled S x S canvas so that the
    # resize to input_size preserves the aspect ratio.
    img32 = image.astype(np.float32) / 255
    img = np.zeros([S, S, 3], dtype=np.float32)
    # Both offsets start at 0 and are always subtracted below; at most one is
    # non-zero. A -1 sentinel would misplace boxes when |H - W| == 1, because
    # the real pad is then 0.
    padH, padW = 0, 0
    if S == H:
        padW = (S - W) // 2
        img[:, padW:padW + W] = img32
    else:
        padH = (S - H) // 2
        img[padH:padH + H, :] = img32
    img = cv2.resize(img, (input_size, input_size))
    rate = S / input_size
    # HWC -> NCHW with a batch dimension of 1.
    # NOTE(review): cv2.imread yields BGR; confirm the model expects BGR input.
    img = np.transpose(img[np.newaxis, ...], axes=[0, 3, 1, 2])

    # Class index -> class name, one name per line.
    with open("ckpt/coco.names", "r", encoding="utf-8") as f:
        namedict = {idx: line.strip() for idx, line in enumerate(f)}

    sess = rt.InferenceSession("ckpt/mobilenet.yolo.onnx")
    detections = sess.run(["output"], {"image": img})[0]
    detections = detections[0]  # drop the batch dimension
    print(detections.shape)

    def between(a, b):
        """Return True if point (a, b) lies strictly inside the original image."""
        return 0 < a < W and 0 < b < H

    detections = non_max_supression(detections, c=0.1)
    for x1, y1, x2, y2, conf, cls_pred in detections:
        # Undo the resize (rate) and then the letterbox offset.
        x1 = x1 * rate - padW
        x2 = x2 * rate - padW
        y1 = y1 * rate - padH
        y2 = y2 * rate - padH
        if between(x1, y1) and between(x2, y2):
            cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0))
            predtype = namedict[int(cls_pred)]
            cv2.putText(image, f"Conf:{conf:.2f},{predtype}", (int(x1), int(y1)), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 0, 0))
        else:
            # Boxes not strictly inside the image are skipped; log their corners.
            print(x1, y1, x2, y2)
    cv2.imwrite("out2.jpg", image)
    cv2.imshow("www", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
if __name__ == "__main__":
    # Execute the detection demo only when run as a script, not on import.
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhaoshifu111/yolo-v3-object-detection.git
git@gitee.com:zhaoshifu111/yolo-v3-object-detection.git
zhaoshifu111
yolo-v3-object-detection
YoloV3物体检测
master

搜索帮助