代码拉取完成,页面将自动刷新
#----------------------------------------------------#
# 获取测试集的detection-result和images-optional
# 具体视频教程可查看
# https://www.bilibili.com/video/BV1zE411u7Vw
#----------------------------------------------------#
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Lambda
from nets.yolo4 import yolo_body,yolo_eval
from utils.utils import letterbox_image
from tqdm import tqdm
from yolo import YOLO
from PIL import Image
import numpy as np
import tensorflow as tf
import colorsys
import os
# Switch every visible GPU to on-demand memory allocation, so TensorFlow does
# not reserve all device memory up front. Done at import time, before the model
# is built below.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for physical_gpu in gpus:
    tf.config.experimental.set_memory_growth(physical_gpu, True)
class mAP_YOLO(YOLO):
    """YOLO variant that dumps per-image detections to text files for mAP scoring.

    ``generate`` builds the model with a very low score threshold (0.01) so
    that low-confidence boxes are kept. ``detect_image`` writes the boxes to
    ``./input/detection-results/<image_id>.txt`` instead of drawing them.
    """

    #---------------------------------------------------#
    #   Build the model and the box-decoding graph
    #---------------------------------------------------#
    def generate(self):
        """Build the YOLOv4 network, load its weights and set up box decoding.

        Reads ``self.model_path``, ``self.anchors``, ``self.class_names``,
        ``self.model_image_size`` and ``self.eager``. These are presumably set
        by the parent ``YOLO`` constructor; confirm in ``yolo.py``.

        Sets ``self.score``, ``self.iou``, ``self.yolo_model``, ``self.colors``
        and ``self.input_image_shape``. In non-eager mode it also sets
        ``self.boxes``, ``self.scores`` and ``self.classes``.

        Raises:
            AssertionError: if the model path does not end in ``.h5``.
        """
        # Keep almost every candidate box: mAP needs low-confidence detections
        # to trace the full precision/recall curve.
        self.score = 0.01
        self.iou = 0.5
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        num_anchors = len(self.anchors)
        num_classes = len(self.class_names)

        # Build the network structure first, then load the weights into it.
        # Load from the *expanded* path so a "~" in model_path works.
        self.yolo_model = yolo_body(Input(shape=(None, None, 3)), num_anchors // 3, num_classes)
        self.yolo_model.load_weights(model_path)
        print('{} model, anchors, and classes loaded.'.format(model_path))

        # One colour per class: evenly spaced hues converted to 0-255 RGB tuples.
        hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
        self.colors = [
            tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv))
            for hsv in hsv_tuples
        ]
        # Shuffle deterministically (same order as seeding the global RNG with
        # 10101) without reseeding NumPy's global random state.
        np.random.RandomState(10101).shuffle(self.colors)

        if self.eager:
            # Fold decoding into the Keras model itself. The original image
            # shape is fed as an extra (1, 2) input.
            self.input_image_shape = Input([2,], batch_size=1)
            inputs = [*self.yolo_model.output, self.input_image_shape]
            outputs = Lambda(yolo_eval, output_shape=(1,), name='yolo_eval',
                             arguments={'anchors': self.anchors,
                                        'num_classes': num_classes,
                                        'image_shape': self.model_image_size,
                                        'score_threshold': self.score,
                                        # Same NMS IoU as the non-eager branch.
                                        'iou_threshold': self.iou,
                                        'eager': True})(inputs)
            self.yolo_model = Model([self.yolo_model.input, self.input_image_shape], outputs)
        else:
            # Graph mode: decoded tensors are evaluated later via self.sess.run.
            self.input_image_shape = K.placeholder(shape=(2, ))
            self.boxes, self.scores, self.classes = yolo_eval(
                self.yolo_model.output, self.anchors, num_classes,
                self.input_image_shape,
                score_threshold=self.score, iou_threshold=self.iou)

    #---------------------------------------------------#
    #   Detect one image and write its result file
    #---------------------------------------------------#
    def detect_image(self, image_id, image):
        """Run detection on ``image`` and write the boxes to a text file.

        The output file is ``./input/detection-results/<image_id>.txt``. It is
        created even when nothing is detected and holds one line per box:
        ``<class> <score> <left> <top> <right> <bottom>``. The score string is
        truncated to 6 characters and the coordinates are truncated to int.

        Args:
            image_id: Identifier used as the output file name stem.
            image: PIL image to run detection on.

        Returns:
            None.
        """
        # Letterbox to the model input size, scale to [0, 1], add a batch axis.
        boxed_image = letterbox_image(image, self.model_image_size)
        image_data = np.array(boxed_image, dtype='float32')
        image_data /= 255.
        image_data = np.expand_dims(image_data, 0)

        # PIL's image.size is (width, height); the model expects (height, width).
        if self.eager:
            input_image_shape = np.expand_dims(
                np.array([image.size[1], image.size[0]], dtype='float32'), 0)
            out_boxes, out_scores, out_classes = self.yolo_model.predict(
                [image_data, input_image_shape])
        else:
            out_boxes, out_scores, out_classes = self.sess.run(
                [self.boxes, self.scores, self.classes],
                feed_dict={
                    self.yolo_model.input: image_data,
                    self.input_image_shape: [image.size[1], image.size[0]],
                    K.learning_phase(): 0
                })

        # Open only after inference succeeded; `with` guarantees the file is
        # closed even if a write fails.
        with open("./input/detection-results/" + image_id + ".txt", "w") as f:
            for box, raw_score, c in zip(out_boxes, out_scores, out_classes):
                predicted_class = self.class_names[int(c)]
                try:
                    # Tensor outputs expose .numpy().
                    score_value = raw_score.numpy()
                except AttributeError:
                    # Already a NumPy scalar (predict / sess.run output).
                    score_value = raw_score
                score = str(score_value)
                top, left, bottom, right = box
                f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)),
                                                 str(int(top)), str(int(right)), str(int(bottom))))
# Driver: run the detector over every VOC2007 test image and write one
# detection-results file per image.
yolo = mAP_YOLO()
with open('VOCdevkit/VOC2007/ImageSets/Main/test.txt') as id_file:
    image_ids = id_file.read().strip().split()

# makedirs creates the "./input" parent as needed; exist_ok makes reruns safe.
os.makedirs("./input", exist_ok=True)
os.makedirs("./input/detection-results", exist_ok=True)
os.makedirs("./input/images-optional", exist_ok=True)

for image_id in tqdm(image_ids):
    image_path = "./VOCdevkit/VOC2007/JPEGImages/" + image_id + ".jpg"
    # Context manager releases the image's file handle after each iteration
    # instead of accumulating open handles across the whole test set.
    with Image.open(image_path) as image:
        # Uncomment to save a copy, enabling visualisation in the later mAP step.
        # image.save("./input/images-optional/"+image_id+".jpg")
        yolo.detect_image(image_id, image)
print("Conversion completed!")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。