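# NOTE: stray web-page banner captured with the file ("code pull finished, the page will refresh automatically"); not part of the program.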
#-*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import time
import sys
import glob
import cv2
import six
import gc
import paddle
import paddle.fluid as fluid
from collections import namedtuple
import paddle.dataset as dataset
from data_augmentor import DataAugmentor
import data_augmentor
# Path setup: resolve the current working directory to an absolute path and
# append it to the module search path.
RootPath = os.path.normpath(os.path.join(os.getcwd(), "./"))
sys.path.insert(len(sys.path), RootPath)
print("项目根目录路径为: ", RootPath)
# Record type describing one annotation class.
# `id` is the raw value stored in label images and `trainId` is the class
# index used for training (see the readers' label2classes).
# The remaining fields are carried as metadata in the table below.
Label = namedtuple(
    'Label',
    'name id trainId category categoryId hasInstances ignoreInEval color',
)
# Label definitions.
# Each entry maps a raw label-image value (`id`) to a training class
# (`trainId`, 0..8 in this table); several raw ids share one trainId.
# Raw values that do not appear here have no entry in id2label.
labels = [
# name id trainId category catId hasInstances ignoreInEval color
Label( 'void' , 0 , 0, 'void' , 0 , False , False , ( 0, 0, 0) ),
Label( 's_w_d' , 200 , 1 , 'dividing' , 1 , False , False , ( 70, 130, 180) ),
Label( 's_y_d' , 204 , 1 , 'dividing' , 1 , False , False , (220, 20, 60) ),
Label( 'ds_w_dn' , 213 , 1 , 'dividing' , 1 , False , True , (128, 0, 128) ),
Label( 'ds_y_dn' , 209 , 1 , 'dividing' , 1 , False , False , (255, 0, 0) ),
Label( 'sb_w_do' , 206 , 1 , 'dividing' , 1 , False , True , ( 0, 0, 60) ),
Label( 'sb_y_do' , 207 , 1 , 'dividing' , 1 , False , True , ( 0, 60, 100) ),
Label( 'b_w_g' , 201 , 2 , 'guiding' , 2 , False , False , ( 0, 0, 142) ),
Label( 'b_y_g' , 203 , 2 , 'guiding' , 2 , False , False , (119, 11, 32) ),
Label( 'db_w_g' , 211 , 2 , 'guiding' , 2 , False , True , (244, 35, 232) ),
Label( 'db_y_g' , 208 , 2 , 'guiding' , 2 , False , True , ( 0, 0, 160) ),
Label( 'db_w_s' , 216 , 3 , 'stopping' , 3 , False , True , (153, 153, 153) ),
Label( 's_w_s' , 217 , 3 , 'stopping' , 3 , False , False , (220, 220, 0) ),
Label( 'ds_w_s' , 215 , 3 , 'stopping' , 3 , False , True , (250, 170, 30) ),
Label( 's_w_c' , 218 , 4 , 'chevron' , 4 , False , True , (102, 102, 156) ),
Label( 's_y_c' , 219 , 4 , 'chevron' , 4 , False , True , (128, 0, 0) ),
Label( 's_w_p' , 210 , 5 , 'parking' , 5 , False , False , (128, 64, 128) ),
Label( 's_n_p' , 232 , 5 , 'parking' , 5 , False , True , (238, 232, 170) ),
Label( 'c_wy_z' , 214 , 6 , 'zebra' , 6 , False , False , (190, 153, 153) ),
Label( 'a_w_u' , 202 , 7 , 'thru/turn' , 7 , False , True , ( 0, 0, 230) ),
Label( 'a_w_t' , 220 , 7 , 'thru/turn' , 7 , False , False , (128, 128, 0) ),
Label( 'a_w_tl' , 221 , 7 , 'thru/turn' , 7 , False , False , (128, 78, 160) ),
Label( 'a_w_tr' , 222 , 7 , 'thru/turn' , 7 , False , False , (150, 100, 100) ),
Label( 'a_w_tlr' , 231 , 7 , 'thru/turn' , 7 , False , True , (255, 165, 0) ),
Label( 'a_w_l' , 224 , 7 , 'thru/turn' , 7 , False , False , (180, 165, 180) ),
Label( 'a_w_r' , 225 , 7 , 'thru/turn' , 7 , False , False , (107, 142, 35) ),
Label( 'a_w_lr' , 226 , 7 , 'thru/turn' , 7 , False , False , (201, 255, 229) ),
Label( 'a_n_lu' , 230 , 7 , 'thru/turn' , 7 , False , True , (0, 191, 255) ),
Label( 'a_w_tu' , 228 , 7 , 'thru/turn' , 7 , False , True , ( 51, 255, 51) ),
Label( 'a_w_m' , 229 , 7 , 'thru/turn' , 7 , False , True , (250, 128, 114) ),
Label( 'a_y_t' , 233 , 7 , 'thru/turn' , 7 , False , True , (127, 255, 0) ),
Label( 'b_n_sr' , 205 , 8 , 'reduction' , 8 , False , False , (255, 128, 0) ),
Label( 'd_wy_za' , 212 , 8 , 'attention' , 8 , False , True , ( 0, 255, 255) ),
Label( 'r_wy_np' , 227 , 8 , 'no parking' , 8 , False , False , (178, 132, 190) ),
Label( 'vom_wy_n' , 223 , 8 , 'others' , 8 , False , True , (128, 128, 64) ),
Label( 'om_n_n' , 250 , 8 , 'others' , 8 , False , False , (102, 0, 204) ),
Label( 'noise' , 249 , 0 , 'ignored' , 0 , False , True , ( 0, 153, 153) ),
Label( 'ignored' , 255 , 0 , 'ignored' , 0 , False , True , (255, 255, 255) ),
]
# Lookup from label name to its Label entry.
name2label = dict((entry.name, entry) for entry in labels)
# Lookup from raw label id to its Label entry.
id2label = dict((entry.id, entry) for entry in labels)
# Lookup from trainId to a Label entry. The table is walked backwards so that
# the FIRST entry listed for each trainId is the one that is kept.
trainId2label = dict((entry.trainId, entry) for entry in reversed(labels))
# Sanity check of the mapping: raw id 200 is expected to map to trainId 1.
print("标准转换检测 200 ---> 1: ", id2label[200].trainId)
# Module-level augmentor instance (TrainDataReader.get_img calls its disturb()).
augmentor = DataAugmentor()
class TrainDataReader:
    """Serve (image, one-hot label, label path) training samples.

    Label files are all files below ``dataset_dir + subset`` whose path
    contains ``_bin.png``. The image path for a label is derived by
    replacing ``_bin.png`` with ``.jpg`` and ``Label`` with ``ColorImage``.
    Samples are optionally warped with a fixed perspective transform and
    then resized to ``rows`` x ``cols``.
    """

    def __init__(self, dataset_dir, subset='train', rows=2000, cols=1354, shuffle=True, birdeye=True):
        """Collect the label files and initialise the reader state.

        Args:
            dataset_dir: Dataset root. It is concatenated with ``subset``
                as a plain string, so it must end with a path separator.
            subset: Sub-directory (string suffix) to scan for label files.
            rows: Output height of images and labels.
            cols: Output width of images and labels.
            shuffle: Shuffle the file list on every reset().
            birdeye: Apply the perspective warp before resizing.

        Raises:
            IndexError: If no ``_bin.png`` file is found.
        """
        label_dirname = dataset_dir + subset
        print(label_dirname)
        # Walk the tree in-process instead of shelling out to find/grep/sort:
        # no quoting problems with unusual paths, and a missing directory can
        # no longer inject an error message into the file list.
        label_files = sorted(
            os.path.join(dirpath, filename)
            for dirpath, _, filenames in os.walk(label_dirname)
            for filename in filenames
            if '_bin.png' in os.path.join(dirpath, filename)
        )
        print('---')
        if not label_files:
            raise IndexError(
                "no '_bin.png' label files found under %s" % label_dirname)
        print(label_files[0])
        self.label_files = label_files
        self.label_dirname = label_dirname
        self.rows = rows
        self.cols = cols
        self.index = 0
        self.subset = subset
        self.dataset_dir = dataset_dir
        self.shuffle = shuffle
        self.M = 0
        self.Minv = 0
        self.reset()
        self.get_M_Minv()
        # Counter (not an augmentor object): get_batch() increments it per
        # sample and get_img() uses it to decide whether to augment.
        self.augmentor = 0
        self.birdeye = birdeye
        print("images total number", len(label_files))

    def label2classes(self, label, row, col, id_to_label=None):
        """Convert a raw label image into a ``(row, col, 9)`` one-hot array.

        Pixels whose raw value has no entry in the mapping, or that lie
        outside ``label``, keep an all-zero class vector.

        Args:
            label: 2-D array of raw label ids (or None).
            row: Number of output rows.
            col: Number of output columns.
            id_to_label: Optional mapping from raw id to an object with a
                ``trainId`` attribute; defaults to the module-level
                ``id2label`` table.

        Returns:
            ``np.int64`` array of shape ``(row, col, 9)``.
        """
        if id_to_label is None:
            id_to_label = id2label
        x = np.zeros([row, col, 9], dtype=np.int64)
        if label is None:
            return x
        label = np.asarray(label)
        if label.ndim != 2:
            return x
        height = min(row, label.shape[0])
        width = min(col, label.shape[1])
        ids = label[:height, :width].astype(np.int64)
        # One vectorised mask per distinct raw value instead of a Python
        # loop over every pixel.
        for raw_id in np.unique(ids):
            entry = id_to_label.get(int(raw_id))
            if entry is None or not 0 <= entry.trainId < 9:
                continue
            # Basic slicing returns a view, so this writes into x.
            x[:height, :width, entry.trainId][ids == raw_id] = 1
        return x

    def get_M_Minv(self):
        """Compute the perspective matrix ``self.M`` and its inverse ``self.Minv``."""
        # Point order: top-left, top-right, bottom-left, bottom-right.
        src = np.float32([[800, 730], [2583, 730], [0, 1709], [3383, 1709]])
        dst = np.float32([[0, 0], [3999, 0], [1300, 3999], [2700, 3999]])
        self.M = cv2.getPerspectiveTransform(src, dst)
        self.Minv = cv2.getPerspectiveTransform(dst, src)

    def reset(self, shuffle=False):
        """Rewind to the first file; reshuffle when ``self.shuffle`` is set.

        The ``shuffle`` argument is accepted for compatibility but is not
        consulted; only ``self.shuffle`` controls shuffling.
        """
        self.index = 0
        if self.shuffle:
            np.random.shuffle(self.label_files)

    def next_img(self):
        """Advance to the next file, wrapping around through reset()."""
        self.index += 1
        if self.index >= len(self.label_files):
            self.reset()

    def prev_img(self):
        """Step back one file, stopping at the first one."""
        if self.index >= 1:
            self.index -= 1

    def get_img(self):
        """Load the current sample.

        Files whose image cannot be read are skipped (the index advances)
        until one loads. If warping or resizing fails, an all-zero image
        and label of the requested size are used instead.

        Returns:
            Tuple ``(img, label, label_name)`` where ``img`` is channel-first
            ``(3, rows, cols)`` and ``label`` is the ``(rows, cols, 9)``
            output of label2classes().
        """
        while True:
            label_name = self.label_files[self.index]
            img_name = label_name.replace('_bin.png', '.jpg')
            img_name = img_name.replace('Label', 'ColorImage')
            label = cv2.imread(label_name, cv2.IMREAD_GRAYSCALE)
            img = cv2.imread(img_name)
            if img is not None:
                break
            print("load img failed:", img_name)
            self.next_img()
        out_size = (self.cols, self.rows)  # cv2 sizes are (width, height)
        try:
            if self.birdeye:
                img = cv2.warpPerspective(img, self.M, (4000, 4000), flags=cv2.INTER_CUBIC)
                label = cv2.warpPerspective(label, self.M, (4000, 4000), flags=cv2.INTER_NEAREST)
            # Nearest-neighbour keeps label ids intact while resizing.
            label = cv2.resize(label, out_size, interpolation=cv2.INTER_NEAREST)
            img = cv2.resize(img, out_size, interpolation=cv2.INTER_CUBIC)
        except Exception as err:
            print('warped_error: ', err)
            # Fallback arrays are (rows, cols) like every other sample, so
            # the batch stays stackable.
            img = np.zeros([self.rows, self.cols, 3], dtype=np.uint8)
            label = np.zeros([self.rows, self.cols], dtype=np.uint8)
        # Data augmentation, driven by the counter kept in self.augmentor:
        # augment while it is non-zero and below 2, otherwise restart it.
        if self.augmentor != 0:
            if self.augmentor < 2:
                img, label = augmentor.disturb(img, label)
            else:
                self.augmentor = 0
        img = img.transpose((2, 0, 1))
        label = self.label2classes(label, self.rows, self.cols)  # 9 classes
        return img, label, label_name

    def get_batch(self, batch_size=1):
        """Collect ``batch_size`` consecutive samples.

        Returns:
            Tuple ``(imgs, labels, names)``: two stacked arrays and the list
            of label file paths.
        """
        imgs = []
        labels = []
        names = []
        while len(imgs) < batch_size:
            img, label, label_name = self.get_img()
            imgs.append(img)
            labels.append(label)
            names.append(label_name)
            self.next_img()
            self.augmentor += 1
        return np.array(imgs), np.array(labels), names

    def get_batch_generator(self, batch_size, total_step):
        """Return an iterator yielding ``(step, imgs, labels, names)``.

        ``imgs`` and ``labels`` are float32 and ``imgs`` is divided by 255.
        A failing batch is retried once. When ``prefetch_generator`` is
        installed the iterator is wrapped in a ``BackgroundGenerator``.
        """
        def do_get_batch():
            for i in range(total_step):
                gc.collect()
                try:
                    imgs, labels, names = self.get_batch(batch_size)
                except Exception as err:
                    print('Generator 异常', err)
                    imgs, labels, names = self.get_batch(batch_size)
                imgs = imgs.astype(np.float32)
                labels = labels.astype(np.float32)
                imgs /= 255
                yield i, imgs, labels, names

        batches = do_get_batch()
        try:
            from prefetch_generator import BackgroundGenerator
            batches = BackgroundGenerator(batches, 10)
        except ImportError:
            print(
                "You can install 'prefetch_generator' for acceleration of data reading."
            )
        return batches
class TestDataReader:
    """Serve (image, output name) test samples.

    Images are the ``*.jpg`` files in ``<dataset_dir>/<subset>/image``.
    Each sample is optionally warped with a fixed perspective transform and
    then resized to ``rows`` x ``cols``.
    """

    def __init__(self, dataset_dir, subset='test', rows=880, cols=596, shuffle=False, birdeye=True):
        """Collect the image files and initialise the reader state.

        Args:
            dataset_dir: Dataset root directory.
            subset: Sub-directory of ``dataset_dir`` holding ``image/``.
            rows: Output image height.
            cols: Output image width.
            shuffle: Shuffle the file list on every reset().
            birdeye: Apply the perspective warp before resizing.

        Raises:
            IndexError: If no ``.jpg`` file is found.
        """
        image_dirname = os.path.join(dataset_dir, subset)
        print(image_dirname)
        image_files = sorted(glob.glob(image_dirname + "/image/*.jpg"))
        print('---')
        if not image_files:
            raise IndexError(
                "no '.jpg' images found under %s/image" % image_dirname)
        print(image_files[0])
        self.image_files = image_files
        self.image_dirname = image_dirname
        self.rows = rows
        self.cols = cols
        self.index = 0
        self.subset = subset
        self.dataset_dir = dataset_dir
        self.shuffle = shuffle
        self.M = 0
        self.Minv = 0
        self.reset()
        self.get_M_Minv()
        self.birdeye = birdeye
        print("images total number", len(image_files))

    def get_M_Minv(self):
        """Compute the perspective matrix ``self.M`` and its inverse ``self.Minv``."""
        # Point order: top-left, top-right, bottom-left, bottom-right.
        src = np.float32([[800, 730], [2583, 730], [0, 1709], [3383, 1709]])
        dst = np.float32([[0, 0], [3999, 0], [1300, 3999], [2700, 3999]])
        self.M = cv2.getPerspectiveTransform(src, dst)
        self.Minv = cv2.getPerspectiveTransform(dst, src)

    def reset(self, shuffle=False):
        """Rewind to the first file; reshuffle when ``self.shuffle`` is set.

        The ``shuffle`` argument is accepted for compatibility but is not
        consulted; only ``self.shuffle`` controls shuffling.
        """
        self.index = 0
        if self.shuffle:
            np.random.shuffle(self.image_files)

    def next_img(self):
        """Advance to the next file, wrapping around through reset()."""
        self.index += 1
        if self.index >= len(self.image_files):
            self.reset()

    def get_img(self):
        """Load the current image.

        Files that cannot be read are skipped (the index advances) until
        one loads.

        Returns:
            Tuple ``(img, label_name)`` where ``img`` is channel-first
            ``(3, rows, cols)`` and ``label_name`` is the image path with
            ``.jpg`` replaced by ``.png``.
        """
        while True:
            img_name = self.image_files[self.index]
            label_name = img_name.replace('.jpg', '.png')
            img = cv2.imread(img_name)
            if img is not None:
                break
            print("load img failed:", img_name)
            self.next_img()
        if self.birdeye:
            img = cv2.warpPerspective(img, self.M, (4000, 4000), flags=cv2.INTER_CUBIC)
        # cv2 sizes are (width, height).
        img = cv2.resize(img, (self.cols, self.rows), interpolation=cv2.INTER_CUBIC)
        img = img.transpose((2, 0, 1))
        return img, label_name

    def get_batch(self, batch_size=1):
        """Collect ``batch_size`` consecutive samples.

        Returns:
            Tuple ``(imgs, names)``: a stacked image array and the list of
            output names.
        """
        imgs = []
        names = []
        while len(imgs) < batch_size:
            img, label_name = self.get_img()
            imgs.append(img)
            names.append(label_name)
            self.next_img()
        return np.array(imgs), names

    def get_batch_generator(self, batch_size, total_step):
        """Return an iterator yielding ``(step, imgs, names)``.

        ``imgs`` is float32 divided by 255. A failing batch is retried
        once. When ``prefetch_generator`` is installed the iterator is
        wrapped in a ``BackgroundGenerator``.
        """
        def do_get_batch():
            for i in range(total_step):
                try:
                    imgs, names = self.get_batch(batch_size)
                except Exception as err:
                    print('Generator 异常', err)
                    imgs, names = self.get_batch(batch_size)
                imgs = imgs.astype(np.float32)
                imgs /= 255
                yield i, imgs, names

        batches = do_get_batch()
        try:
            from prefetch_generator import BackgroundGenerator
            batches = BackgroundGenerator(batches, 10)
        except ImportError:
            print(
                "You can install 'prefetch_generator' for acceleration of data reading."
            )
        return batches
class EvalDataReader:
    """Serve (image, one-hot label, label path) evaluation samples.

    Label files are the ``*.png`` files in ``<dataset_dir>/<subset>/label``.
    The image path for a label is derived by replacing ``label`` with
    ``image`` and ``_bin.png`` with ``.jpg``. Samples are optionally warped
    with a fixed perspective transform and then resized to
    ``rows`` x ``cols``.
    """

    def __init__(self, dataset_dir, subset='val', rows=512, cols=1024, shuffle=True, birdeye=True):
        """Collect the label files and initialise the reader state.

        Args:
            dataset_dir: Dataset root directory.
            subset: Sub-directory of ``dataset_dir`` holding ``label/``.
            rows: Output height of images and labels.
            cols: Output width of images and labels.
            shuffle: Shuffle the file list on every reset().
            birdeye: Apply the perspective warp before resizing.

        Raises:
            IndexError: If no ``.png`` label file is found.
        """
        label_dirname = os.path.join(dataset_dir, subset)
        print(label_dirname)
        label_files = sorted(glob.glob(label_dirname + "/label/*.png"))
        print('---')
        if not label_files:
            raise IndexError(
                "no '.png' label files found under %s/label" % label_dirname)
        print(label_files[0])
        self.label_files = label_files
        self.label_dirname = label_dirname
        self.rows = rows
        self.cols = cols
        self.index = 0
        self.subset = subset
        self.dataset_dir = dataset_dir
        self.shuffle = shuffle
        self.reset()
        # Counter incremented by get_batch(); get_img() does not read it.
        self.augmentor = 0
        self.M = 0
        self.Minv = 0
        self.get_M_Minv()
        self.birdeye = birdeye
        print("images total number", len(label_files))

    def label2classes(self, label, row, col, id_to_label=None):
        """Convert a raw label image into a ``(row, col, 9)`` one-hot array.

        Pixels whose raw value has no entry in the mapping, or that lie
        outside ``label``, keep an all-zero class vector. Unknown raw
        values are reported once per call.

        Args:
            label: 2-D array of raw label ids (or None).
            row: Number of output rows.
            col: Number of output columns.
            id_to_label: Optional mapping from raw id to an object with a
                ``trainId`` attribute; defaults to the module-level
                ``id2label`` table.

        Returns:
            ``np.int64`` array of shape ``(row, col, 9)``.
        """
        if id_to_label is None:
            id_to_label = id2label
        x = np.zeros([row, col, 9], dtype=np.int64)
        if label is None:
            print('像素级标签值异常', 'label is None')
            return x
        label = np.asarray(label)
        if label.ndim != 2:
            print('像素级标签值异常', 'label ndim %d' % label.ndim)
            return x
        height = min(row, label.shape[0])
        width = min(col, label.shape[1])
        ids = label[:height, :width].astype(np.int64)
        unknown = []
        # One vectorised mask per distinct raw value instead of a Python
        # loop over every pixel.
        for raw_id in np.unique(ids):
            entry = id_to_label.get(int(raw_id))
            if entry is None or not 0 <= entry.trainId < 9:
                unknown.append(int(raw_id))
                continue
            # Basic slicing returns a view, so this writes into x.
            x[:height, :width, entry.trainId][ids == raw_id] = 1
        if unknown:
            # One summary line instead of one line per offending pixel.
            print('像素级标签值异常', unknown)
        return x

    def get_M_Minv(self):
        """Compute the perspective matrix ``self.M`` and its inverse ``self.Minv``."""
        # Point order: top-left, top-right, bottom-left, bottom-right.
        src = np.float32([[800, 730], [2583, 730], [0, 1709], [3383, 1709]])
        dst = np.float32([[0, 0], [3999, 0], [1300, 3999], [2700, 3999]])
        self.M = cv2.getPerspectiveTransform(src, dst)
        self.Minv = cv2.getPerspectiveTransform(dst, src)

    def reset(self, shuffle=False):
        """Rewind to the first file; reshuffle when ``self.shuffle`` is set.

        The ``shuffle`` argument is accepted for compatibility but is not
        consulted; only ``self.shuffle`` controls shuffling.
        """
        self.index = 0
        if self.shuffle:
            np.random.shuffle(self.label_files)

    def next_img(self):
        """Advance to the next file, wrapping around through reset()."""
        self.index += 1
        if self.index >= len(self.label_files):
            self.reset()

    def prev_img(self):
        """Step back one file, stopping at the first one."""
        if self.index >= 1:
            self.index -= 1

    def get_img(self):
        """Load the current sample.

        Files whose image cannot be read are skipped (the index advances)
        until one loads.

        Returns:
            Tuple ``(img, label, label_name)`` where ``img`` is channel-first
            ``(3, rows, cols)`` and ``label`` is the ``(rows, cols, 9)``
            output of label2classes().
        """
        while True:
            label_name = self.label_files[self.index]
            img_name = label_name.replace('label', 'image')
            img_name = img_name.replace('_bin.png', '.jpg')
            label = cv2.imread(label_name, cv2.IMREAD_GRAYSCALE)
            img = cv2.imread(img_name)
            if img is not None:
                break
            print("load img failed:", img_name)
            self.next_img()
        if self.birdeye:
            # Only pay for the 4000x4000 warps when their result is used.
            img = cv2.warpPerspective(img, self.M, (4000, 4000), flags=cv2.INTER_CUBIC)
            label = cv2.warpPerspective(label, self.M, (4000, 4000), flags=cv2.INTER_NEAREST)
        out_size = (self.cols, self.rows)  # cv2 sizes are (width, height)
        # Nearest-neighbour keeps label ids intact while resizing.
        label = cv2.resize(label, out_size, interpolation=cv2.INTER_NEAREST)
        img = cv2.resize(img, out_size, interpolation=cv2.INTER_CUBIC)
        img = img.transpose((2, 0, 1))
        label = self.label2classes(label, self.rows, self.cols)  # 9 classes
        return img, label, label_name

    def get_batch(self, batch_size=1):
        """Collect ``batch_size`` consecutive samples.

        Returns:
            Tuple ``(imgs, labels, names)``: two stacked arrays and the list
            of label file paths.
        """
        imgs = []
        labels = []
        names = []
        while len(imgs) < batch_size:
            img, label, label_name = self.get_img()
            imgs.append(img)
            labels.append(label)
            names.append(label_name)
            self.next_img()
            self.augmentor += 1
        return np.array(imgs), np.array(labels), names

    def get_batch_generator(self, batch_size, total_step):
        """Return an iterator yielding ``(step, imgs, labels, names)``.

        ``imgs`` and ``labels`` are float32 and ``imgs`` is divided by 255.
        A failing batch is retried once. When ``prefetch_generator`` is
        installed the iterator is wrapped in a ``BackgroundGenerator``.
        """
        def do_get_batch():
            for i in range(total_step):
                gc.collect()
                try:
                    imgs, labels, names = self.get_batch(batch_size)
                except Exception as err:
                    print('Generator 异常', err)
                    imgs, labels, names = self.get_batch(batch_size)
                imgs = imgs.astype(np.float32)
                labels = labels.astype(np.float32)
                imgs /= 255
                yield i, imgs, labels, names

        batches = do_get_batch()
        try:
            from prefetch_generator import BackgroundGenerator
            batches = BackgroundGenerator(batches, 10)
        except ImportError:
            print(
                "You can install 'prefetch_generator' for acceleration of data reading."
            )
        return batches
# NOTE: stray hosting-site moderation notice captured with the file ("content may be unsuitable for display; submit an appeal if it is not"); not part of the program.
# End of module.