2 Star 2 Fork 2

liuswot/fire-detect-yolov4

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
latest_darknet_API.py 5.99 KB
一键复制 编辑 原始数据 按行查看 历史
gengyanlei 提交于 2020-11-19 11:27 . add latest darknet python API
'''
注释:
author is leilei
由于最新版本的darknet将images和video预测分开了,并不是很符合自己的额需求,
因此参考darknet_images.py修改成属于自己的预测函数
注:
darknet.py 核心函数:load_network、detect_image draw_boxes bbox2points
darknet_images.py 核心函数: image_detection,此函数需要修改成输入图像
darknet官方写的预测图像输出依旧为正方形,而非原图!因此要转换成原图
'''
import os
import cv2
import numpy as np
import darknet
class Detect:
def __init__(self, metaPath, configPath, weightPath, gpu_id=2, batch=1):
'''
:param metaPath: ***.data 存储各种参数
:param configPath: ***.cfg 网络结构文件
:param weightPath: ***.weights yolo的权重
:param batch: ########此类只支持batch=1############
'''
assert batch == 1, "batch必须为1"
# 设置gpu_id
darknet.set_gpu(gpu_id)
# 网络
network, class_names, class_colors = darknet.load_network(
configPath,
metaPath,
weightPath,
batch_size=batch
)
self.network = network
self.class_names = class_names
self.class_colors = class_colors
def bbox2point(self, bbox):
x, y, w, h = bbox
xmin = x - (w / 2)
xmax = x + (w / 2)
ymin = y - (h / 2)
ymax = y + (h / 2)
return (xmin, ymin, xmax, ymax)
def point2bbox(self, point):
x1, y1, x2, y2 = point
x = (x1 + x2) / 2
y = (y1 + y2) / 2
w = (x2 - x1)
h = (y2 - y1)
return (x, y, w, h)
def image_detection(self, image_bgr, network, class_names, class_colors, thresh=0.25):
# 判断输入图像是否为3通道
if len(image_bgr.shape) == 2:
image_bgr = np.stack([image_bgr] * 3, axis=-1)
# 获取原始图像大小
orig_h, orig_w = image_bgr.shape[:2]
width = darknet.network_width(network)
height = darknet.network_height(network)
darknet_image = darknet.make_image(width, height, 3)
# image = cv2.imread(image_path)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
image_resized = cv2.resize(image_rgb, (width, height), interpolation=cv2.INTER_LINEAR)
darknet.copy_image_from_bytes(darknet_image, image_resized.tobytes())
detections = darknet.detect_image(network, class_names, darknet_image, thresh=thresh)
darknet.free_image(darknet_image)
'''注意:这里原始代码依旧是608*608,而不是原图大小,因此我们需要转换'''
new_detections = []
for detection in detections:
pred_label, pred_conf, (x, y, w, h) = detection
new_x = x / width * orig_w
new_y = y / height * orig_h
new_w = w / width * orig_w
new_h = h / height * orig_h
# 可以约束一下
(x1, y1, x2, y2) = self.bbox2point((new_x, new_y, new_w, new_h))
x1 = x1 if x1 > 0 else 0
x2 = x2 if x2 < orig_w else orig_w
y1 = y1 if y1 > 0 else 0
y2 = y2 if y2 < orig_h else orig_h
(new_x, new_y, new_w, new_h) = self.point2bbox((x1, y1, x2, y2))
new_detections.append((pred_label, pred_conf, (new_x, new_y, new_w, new_h)))
image = darknet.draw_boxes(new_detections, image_rgb, class_colors)
return cv2.cvtColor(image, cv2.COLOR_RGB2BGR), new_detections
def predict_image(self, image_bgr, thresh=0.25, is_show=True, save_path=''):
'''
:param image_bgr: 输入图像
:param thresh: 置信度阈值
:param is_show: 是否将画框之后的原始图像返回
:param save_path: 画框后的保存路径, eg='/home/aaa.jpg'
:return:
'''
draw_bbox_image, detections = self.image_detection(image_bgr, self.network, self.class_names, self.class_colors,
thresh)
if is_show:
if save_path:
cv2.imwrite(save_path, draw_bbox_image)
return draw_bbox_image
return detections
if __name__ == '__main__':
# gpu 通过环境变量设置
detect = Detect(metaPath=r'/home/cfg/sg.data',
configPath=r'/home/cfg/yolov4-sg.cfg',
weightPath=r'/home/yolov4-sg_best.weights',
gpu_id=1)
# 读取单张图像
# image_path = r'/home/aa.jpg'
# image = cv2.imread(image_path, -1)
# draw_bbox_image = detect.predict_image(image, save_path='./pred.jpg')
# 读取文件夹
image_root = r'/home/Datasets/image/'
save_root = r'./output'
if not os.path.exists(save_root):
os.makedirs(save_root)
for name in os.listdir(image_root):
print(name)
image = cv2.imread(os.path.join(image_root, name), -1)
draw_bbox_image = detect.predict_image(image, save_path=os.path.join(save_root, name))
# 读取视频
# video_path = r'/home/Datasets/SHIJI_Fire/20200915_2.mp4'
# video_save_path = r'/home/20200915_3_pred.mp4'
# cap = cv2.VideoCapture(video_path)
# # 获取视频的fps, width height
# fps = int(cap.get(cv2.CAP_PROP_FPS))
# width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
# height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# print(count)
# # 创建视频
# fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# # fourcc = cv2.VideoWriter_fourcc('I', '4', '2', '0')
# video_writer = cv2.VideoWriter(video_save_path, fourcc=fourcc, fps=fps, frameSize=(width,height))
# ret, frame = cap.read() # ret表示下一帧还有没有 有为True
# while ret:
# # 预测每一帧
# pred = detect.predict_image(frame)
# video_writer.write(pred)
# cv2.waitKey(fps)
# # 读取下一帧
# ret, frame = cap.read()
# print(ret)
# cap.release()
# cv2.destroyAllWindows()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/liuswot/fire-detect-yolov4.git
git@gitee.com:liuswot/fire-detect-yolov4.git
liuswot
fire-detect-yolov4
fire-detect-yolov4
master

搜索帮助