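# NOTE: web-page residue captured with this file, not part of the program:
# "Code pull complete; the page will refresh automatically."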
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
import cv2
from PIL import Image
import yaml
from prettytable import PrettyTable
from tqdm import tqdm
import common
from utils.image import letterbox_image, get_n_hls_colors_v2
from utils.decode import decode_outputs, non_max_suppression, get_real_boxes
class Yolo(object):
    """
    YOLO detector running on a serialized TensorRT engine.

    The YAML config is read for these keys: "classes", "input_shape",
    "mean", "std", "stage", "conf_thres" and "iou_thres".
    """

    def __init__(self, engine_path, config, logger_severity, warmup_epoch=10):
        """
        Load the engine, allocate I/O buffers, read the config and warm up.

        :param engine_path: path to the serialized engine file
        :param config: path to the YAML config file
        :param logger_severity: severity level handed to trt.Logger
        :param warmup_epoch: number of warm-up inference passes
        """
        self.logger = trt.Logger(logger_severity)  # TensorRT logger

        # Deserialize the engine and allocate host/device buffers for it.
        self.engine = self.get_engine(engine_path)
        self.context = self.engine.create_execution_context()
        self.inputs, self.outputs, self.binding, self.stream = common.allocate_buffers(self.engine)

        # Read the config file.
        with open(config, "r") as fp:
            self.config = yaml.safe_load(fp)
        self.num_classes = len(self.config["classes"])
        self.colors = get_n_hls_colors_v2(self.num_classes)  # one box color per class

        # Run a few throw-away inferences before real use.
        self.warmup(warmup_epoch)

    def get_engine(self, engine_path):
        """
        Deserialize a TensorRT engine from a file.

        :param engine_path: path to the serialized engine file
        :return: the deserialized engine
        :raises RuntimeError: if the runtime could not deserialize the file
        """
        print("Reading engine from file {}".format(engine_path))
        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
            engine = runtime.deserialize_cuda_engine(f.read())
        # Fail here with a clear message instead of letting a None engine
        # surface later as an AttributeError in __init__.
        if engine is None:
            raise RuntimeError("Failed to deserialize engine from file {}".format(engine_path))
        return engine

    def print_engine(self):
        """Print a table with name, direction, size, shape and dtype of every engine binding."""
        # NOTE(review): this relies on the binding-name based engine API
        # (get_binding_shape / binding_is_input / get_binding_dtype) - confirm
        # it is available in the installed TensorRT version.
        table = PrettyTable(["binding name", "is input", "binding size", "binding shape", "dtype"])
        for binding in self.engine:
            dims = self.engine.get_binding_shape(binding)
            size = trt.volume(dims)
            is_input = self.engine.binding_is_input(binding)
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
            table.add_row([binding, is_input, size, str(dims), dtype])
        print(table)

    def warmup(self, epoch):
        """
        Run `epoch` inference passes on random data.

        :param epoch: number of warm-up passes
        """
        print("start warm up!")
        # Random tensor with the configured input shape, flattened into the
        # first input's host buffer.
        t = np.random.random(self.config["input_shape"])
        np.copyto(self.inputs[0].host, t.reshape(-1))
        for _ in tqdm(range(epoch)):
            common.do_inference_v2(self.context, self.binding, self.inputs, self.outputs, self.stream)

    def transform(self, img):
        """
        Preprocess an image for the network.

        :param img: input image; converted with COLOR_BGR2RGB, so it is
            presumably a BGR array as produced by cv2
        :return: normalized float32 image with a leading batch axis, channels first
        """
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Letterbox to the size given by input_shape[2:], then scale to [0, 1].
        img = letterbox_image(img, self.config["input_shape"][2:]).astype(np.float32) / 255.
        # Normalize with the configured mean and std.
        img -= np.array(self.config["mean"])
        img /= np.array(self.config["std"])
        # HWC -> CHW, then add the batch axis (e.g. 640x640x3 -> 1x3x640x640).
        img = np.transpose(img, [2, 0, 1])
        img = np.expand_dims(img, axis=0)
        return img

    def predict(self, img):
        """
        Run detection on one image.

        :param img: input image as an H x W x C array
        :return: the output of get_real_boxes for the original image size
        """
        h, w, _ = img.shape
        img = self.transform(img)
        np.copyto(self.inputs[0].host, img.reshape(-1))  # copy the image into the input host buffer
        result = common.do_inference_v2(self.context, self.binding, self.inputs, self.outputs, self.stream)

        # Reshape the three flat outputs into per-stage feature maps. The
        # shape is passed positionally because the `newshape=` keyword of
        # np.reshape is deprecated in recent NumPy releases.
        for i in range(3):
            result[i] = np.reshape(result[i], [1, 5 + self.num_classes] + self.config["stage"][i])

        result = decode_outputs(result, self.config["input_shape"][2:])  # decode raw outputs
        result = non_max_suppression(result, self.num_classes, self.config["conf_thres"], self.config["iou_thres"])  # NMS
        result = get_real_boxes(result, (w, h))  # map boxes back to the original image
        return result

    def draw_bboxes(self, img, bboxes, thickness=1, print_info=True):
        """
        Draw detection boxes and labels on an image.

        Each box is read as: box[0..3] corner coordinates, box[4] * box[5]
        the displayed confidence, box[-1] the class index.

        :param img: image to draw on
        :param bboxes: detections to draw; None is treated as "no detections"
        :param thickness: line thickness for boxes and text
        :param print_info: whether to print each detection
        :return: the image with boxes drawn
        """
        # Defensive guard: return the image untouched rather than fail while
        # iterating None (whether upstream can return None is not visible here).
        if bboxes is None:
            return img

        for box in bboxes:
            class_index = int(box[-1])
            color = self.colors[class_index]  # color assigned to this class
            x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
            img = cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)  # draw the box
            conf = box[4] * box[5]  # displayed confidence
            classes = self.config["classes"][class_index]  # class label
            img = cv2.putText(img, "{} {:.2f}".format(classes, conf), (x1, y1),
                              cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, thickness)  # draw the label
            if print_info:
                print("{} {:.2f} {} {} {} {}".format(classes, conf, x1, y1, x2, y2))
        return img

    def show_result(self, img, bboxes, thickness=1, print_info=True, show_type="PIL", title="result"):
        """
        Draw the detections and display the image.

        :param img: original image
        :param bboxes: detections to draw
        :param thickness: line thickness for boxes and text
        :param print_info: whether to print each detection
        :param show_type: display backend, "PIL" or "cv2"
        :param title: window / image title
        :return: None
        :raises KeyError: if show_type is neither "cv2" nor "PIL"
        """
        img = self.draw_bboxes(img, bboxes, thickness, print_info)  # draw the results

        # Display the image with the requested backend.
        if show_type == "cv2":
            cv2.imshow(title, img)
            cv2.waitKey(0)
        elif show_type == "PIL":
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img)
            img.show(title=title)
        else:
            raise KeyError("Please use cv2 or PIL.")
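# NOTE: web-page residue captured with this file, not part of the program:
# "This section may contain content unsuitable for display, so the page does not show it. You can review and modify it yourself using the editing functions."
# "If you confirm the content involves no inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false or worthless content, and nothing that violates applicable national laws and regulations, you can click submit to appeal and we will handle it as soon as possible."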