1 Star 0 Fork 0

colinchern/deeplabv3-plus-keras

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
deeplab.py 7.17 KB
一键复制 编辑 原始数据 按行查看 历史
Bubbliiiing 提交于 2021-06-30 10:15 . Add files via upload
import colorsys
import copy
import time
import numpy as np
from PIL import Image
from nets.deeplab import Deeplabv3
#--------------------------------------------#
# 使用自己训练好的模型预测需要修改3个参数
# model_path、backbone和num_classes都需要修改!
# 如果出现shape不匹配
# 一定要注意训练时的model_path、
# backbone和num_classes数的修改
#--------------------------------------------#
class Deeplab(object):
_defaults = {
"model_path" : 'model_data/deeplabv3_mobilenetv2.h5',
"model_image_size" : (512, 512, 3),
"backbone" : "mobilenet",
"downsample_factor" : 16,
"num_classes" : 21,
#--------------------------------#
# blend参数用于控制是否
# 让识别结果和原图混合
#--------------------------------#
"blend" : False,
}
#---------------------------------------------------#
# 初始化Deeplab
#---------------------------------------------------#
def __init__(self, **kwargs):
self.__dict__.update(self._defaults)
self.generate()
#---------------------------------------------------#
# 获得所有的分类
#---------------------------------------------------#
def generate(self):
#-------------------------------#
# 载入模型与权值
#-------------------------------#
self.model = Deeplabv3(self.num_classes,self.model_image_size,backbone=self.backbone,downsample_factor=self.downsample_factor)
self.model.load_weights(self.model_path)
print('{} model loaded.'.format(self.model_path))
if self.num_classes <= 21:
self.colors = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
(128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
(64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128), (128, 64, 12)]
else:
# 画框设置不同的颜色
hsv_tuples = [(x / len(self.class_names), 1., 1.)
for x in range(len(self.class_names))]
self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
self.colors = list(
map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
self.colors))
def letterbox_image(self ,image, size):
'''resize image with unchanged aspect ratio using padding'''
iw, ih = image.size
w, h = size
scale = min(w/iw, h/ih)
nw = int(iw*scale)
nh = int(ih*scale)
image = image.resize((nw,nh), Image.BICUBIC)
new_image = Image.new('RGB', size, (128,128,128))
new_image.paste(image, ((w-nw)//2, (h-nh)//2))
return new_image,nw,nh
#---------------------------------------------------#
# 检测图片
#---------------------------------------------------#
def detect_image(self, image):
#---------------------------------------------------------#
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
#---------------------------------------------------------#
image = image.convert('RGB')
#---------------------------------------------------#
# 对输入图像进行一个备份,后面用于绘图
#---------------------------------------------------#
old_img = copy.deepcopy(image)
orininal_h = np.array(image).shape[0]
orininal_w = np.array(image).shape[1]
#---------------------------------------------------#
# 进行不失真的resize,添加灰条,进行图像归一化
#---------------------------------------------------#
img, nw, nh = self.letterbox_image(image,(self.model_image_size[1],self.model_image_size[0]))
img = [np.array(img)/127.5-1]
img = np.asarray(img)
#---------------------------------------------------#
# 图片传入网络进行预测
#---------------------------------------------------#
pr = self.model.predict(img)[0]
#---------------------------------------------------#
# 取出每一个像素点的种类
#---------------------------------------------------#
pr = pr.argmax(axis=-1).reshape([self.model_image_size[0],self.model_image_size[1]])
pr = pr[int((self.model_image_size[0]-nh)//2):int((self.model_image_size[0]-nh)//2+nh), int((self.model_image_size[1]-nw)//2):int((self.model_image_size[1]-nw)//2+nw)]
#------------------------------------------------#
# 创建一副新图,并根据每个像素点的种类赋予颜色
#------------------------------------------------#
seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3))
for c in range(self.num_classes):
seg_img[:,:,0] += ((pr[:,: ] == c )*( self.colors[c][0] )).astype('uint8')
seg_img[:,:,1] += ((pr[:,: ] == c )*( self.colors[c][1] )).astype('uint8')
seg_img[:,:,2] += ((pr[:,: ] == c )*( self.colors[c][2] )).astype('uint8')
#------------------------------------------------#
# 将新图片转换成Image的形式
#------------------------------------------------#
image = Image.fromarray(np.uint8(seg_img)).resize((orininal_w,orininal_h), Image.NEAREST)
#------------------------------------------------#
# 将新图片和原图片混合
#------------------------------------------------#
if self.blend:
image = Image.blend(old_img,image,0.7)
return image
def get_FPS(self, image, test_interval):
orininal_h = np.array(image).shape[0]
orininal_w = np.array(image).shape[1]
img, nw, nh = self.letterbox_image(image,(self.model_image_size[1],self.model_image_size[0]))
img = [np.array(img)/127.5-1]
img = np.asarray(img)
pr = self.model.predict(img)[0]
pr = pr.argmax(axis=-1).reshape([self.model_image_size[0],self.model_image_size[1]])
pr = pr[int((self.model_image_size[0]-nh)//2):int((self.model_image_size[0]-nh)//2+nh), int((self.model_image_size[1]-nw)//2):int((self.model_image_size[1]-nw)//2+nw)]
image = Image.fromarray(np.uint8(pr)).resize((orininal_w,orininal_h), Image.NEAREST)
t1 = time.time()
for _ in range(test_interval):
pr = self.model.predict(img)[0]
pr = pr.argmax(axis=-1).reshape([self.model_image_size[0],self.model_image_size[1]])
pr = pr[int((self.model_image_size[0]-nh)//2):int((self.model_image_size[0]-nh)//2+nh), int((self.model_image_size[1]-nw)//2):int((self.model_image_size[1]-nw)//2+nw)]
image = Image.fromarray(np.uint8(pr)).resize((orininal_w,orininal_h), Image.NEAREST)
t2 = time.time()
tact_time = (t2 - t1) / test_interval
return tact_time
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/colinchern/deeplabv3-plus-keras.git
git@gitee.com:colinchern/deeplabv3-plus-keras.git
colinchern
deeplabv3-plus-keras
deeplabv3-plus-keras
main

搜索帮助