1 Star 4 Fork 0

Dominic23331/yolox_tensorrt

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
yolo.py 5.63 KB
一键复制 编辑 原始数据 按行查看 历史
Dominic23331 提交于 2022-09-30 10:09 . 增加了注释
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
import cv2
from PIL import Image
import yaml
from prettytable import PrettyTable
from tqdm import tqdm
import common
from utils.image import letterbox_image, get_n_hls_colors_v2
from utils.decode import decode_outputs, non_max_suppression, get_real_boxes
class Yolo(object):
    """TensorRT inference wrapper for a YOLO-style detector.

    Loads a serialized engine, allocates I/O buffers through the project's
    ``common`` helpers, reads a YAML config, and exposes preprocessing,
    inference, and visualisation helpers.

    The config is read for the keys ``classes``, ``input_shape``, ``mean``,
    ``std``, ``stage``, ``conf_thres`` and ``iou_thres``.
    """

    def __init__(self, engine_path, config, logger_severity, warmup_epoch=10):
        """
        Build the detector.

        :param engine_path: path to the serialized TensorRT engine file
        :param config: path to the YAML config file
        :param logger_severity: severity passed to ``trt.Logger``
        :param warmup_epoch: number of warm-up inference runs
        """
        self.logger = trt.Logger(logger_severity)  # TensorRT logger

        # Deserialize the engine and allocate the host/device buffers.
        self.engine = self.get_engine(engine_path)
        self.context = self.engine.create_execution_context()
        self.inputs, self.outputs, self.binding, self.stream = common.allocate_buffers(self.engine)

        # Read the config. YAML text is opened explicitly as UTF-8 so that
        # non-ASCII class names do not depend on the platform default encoding.
        with open(config, "r", encoding="utf-8") as fp:
            self.config = yaml.safe_load(fp)

        self.num_classes = len(self.config["classes"])
        self.colors = get_n_hls_colors_v2(self.num_classes)  # one box colour per class

        # Warm the network up before real use.
        self.warmup(warmup_epoch)

    def get_engine(self, engine_path):
        """
        Deserialize a TensorRT engine from disk.

        :param engine_path: path to the serialized engine file
        :return: the deserialized engine
        :raises RuntimeError: if TensorRT could not deserialize the file
        """
        print("Reading engine from file {}".format(engine_path))
        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
            engine = runtime.deserialize_cuda_engine(f.read())
        # Deserialization reports failure by returning None; fail here with a
        # clear message instead of an AttributeError later in __init__.
        if engine is None:
            raise RuntimeError("Failed to deserialize TensorRT engine from {}".format(engine_path))
        return engine

    def print_engine(self):
        """Print a table with name, direction, size, shape and dtype of every binding."""
        # NOTE(review): relies on the binding-based engine API; confirm it is
        # still available in the installed TensorRT version.
        table = PrettyTable(["binding name", "is input", "binding size", "binding shape", "dtype"])
        for binding in self.engine:
            dims = self.engine.get_binding_shape(binding)
            size = trt.volume(dims)
            is_input = self.engine.binding_is_input(binding)
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
            table.add_row([binding, is_input, size, str(dims), dtype])
        print(table)

    def warmup(self, epoch):
        """
        Run inference ``epoch`` times on random data.

        :param epoch: number of warm-up runs
        """
        print("start warm up!")
        # Random tensor shaped like the configured network input.
        t = np.random.random(self.config["input_shape"])
        np.copyto(self.inputs[0].host, t.reshape(-1))
        for _ in tqdm(range(epoch)):
            common.do_inference_v2(self.context, self.binding, self.inputs, self.outputs, self.stream)

    def transform(self, img):
        """
        Preprocess an image for the network.

        :param img: input image (converted from BGR to RGB here)
        :return: normalized float array with a leading batch axis
        """
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Letterbox to the configured spatial size and scale to [0, 1].
        img = letterbox_image(img, self.config["input_shape"][2:]).astype(np.float32) / 255.
        # Normalize with the configured mean and std.
        img -= np.array(self.config["mean"])
        img /= np.array(self.config["std"])
        # Move the last axis to the front, then add a leading batch axis
        # (presumably HxWxC -> 1xCxHxW; depends on letterbox_image's output).
        img = np.transpose(img, [2, 0, 1])
        img = np.expand_dims(img, axis=0)
        return img

    def predict(self, img):
        """
        Run detection on one image.

        :param img: input image; its first two dimensions are taken as height and width
        :return: the output of ``get_real_boxes`` for this image
        """
        h, w, _ = img.shape
        img = self.transform(img)
        np.copyto(self.inputs[0].host, img.reshape(-1))  # copy the image into the input buffer
        result = common.do_inference_v2(self.context, self.binding, self.inputs, self.outputs, self.stream)

        # Reshape each flat output to [1, 5 + num_classes, *stage]. The number
        # of heads follows the config instead of being fixed at three. The
        # shape is passed positionally because the ``newshape`` keyword is
        # deprecated in recent NumPy releases.
        for i, stage in enumerate(self.config["stage"]):
            result[i] = np.reshape(result[i], [1, 5 + self.num_classes] + list(stage))

        result = decode_outputs(result, self.config["input_shape"][2:])  # decode raw outputs
        result = non_max_suppression(result, self.num_classes, self.config["conf_thres"], self.config["iou_thres"])  # NMS
        result = get_real_boxes(result, (w, h))  # map boxes using the original (w, h)
        return result

    def draw_bboxes(self, img, bboxes, thickness=1, print_info=True):
        """
        Draw detection boxes and labels on an image.

        :param img: image to draw on
        :param bboxes: iterable of boxes; indices 0-3 are used as corner
            coordinates, 4 and 5 are multiplied into the displayed score, and
            the last element is used as the class index. ``None`` is accepted
            and treated as "no detections".
        :param thickness: line thickness for boxes and text
        :param print_info: whether to print each detection
        :return: the image with boxes drawn
        """
        # Tolerate "no detections" being passed as None (assumed possible from
        # the upstream pipeline — not verifiable from this file).
        if bboxes is None:
            return img

        for box in bboxes:
            color = self.colors[int(box[-1])]  # colour for this class
            img = cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, thickness)
            conf = box[4] * box[5]  # displayed score
            classes = self.config["classes"][int(box[-1])]  # class label
            img = cv2.putText(img, "{} {:.2f}".format(classes, conf), (int(box[0]), int(box[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, thickness)
            if print_info:
                print("{} {:.2f} {} {} {} {}".format(classes, conf, int(box[0]), int(box[1]), int(box[2]), int(box[3])))
        return img

    def show_result(self, img, bboxes, thickness=1, print_info=True, show_type="PIL", title="result"):
        """
        Draw the detections and display the image.

        :param img: image to draw on
        :param bboxes: detection results, as accepted by ``draw_bboxes``
        :param thickness: line thickness for boxes and text
        :param print_info: whether to print each detection
        :param show_type: display backend, either ``"PIL"`` or ``"cv2"``
        :param title: window/image title
        :return: None
        :raises KeyError: if ``show_type`` is neither ``"cv2"`` nor ``"PIL"``
        """
        # Reject an unknown backend before drawing, so a bad argument does not
        # modify the image or print anything first.
        if show_type not in ("cv2", "PIL"):
            raise KeyError("Please use cv2 or PIL.")

        img = self.draw_bboxes(img, bboxes, thickness, print_info)

        if show_type == "cv2":
            cv2.imshow(title, img)
            cv2.waitKey(0)
        else:
            # PIL expects RGB channel order.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img)
            img.show(title=title)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/dominic23331/yolox_tensorrt.git
git@gitee.com:dominic23331/yolox_tensorrt.git
dominic23331
yolox_tensorrt
yolox_tensorrt
master

搜索帮助