1 Star 0 Fork 0

xkeys/TL_Dataset_Classification

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目，私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件（LICENSE），使用请关注具体项目描述及其代码上游依赖。
克隆/下载
dataset.py 1.60 KB
一键复制 编辑 原始数据 按行查看 历史
FangYang970206 提交于 2018-12-01 18:23 . 'update'
import torch
import cv2
import torch.utils.data as data
# Traffic-light category name -> integer class id.
# Traffic_Light.__getitem__ matches these keys as substrings of each image
# path to pick the label. For the colour/shape categories the id is
# 2 * shape_index + colour_index (Red = even, Green = odd), giving 0-7.
# Both "Negative" categories share id 8, so there are 9 distinct ids for
# 10 names.
class_light = {
    '{} {}'.format(colour, shape): 2 * shape_id + colour_id
    for shape_id, shape in enumerate(('Circle', 'Left', 'Up', 'Right'))
    for colour_id, colour in enumerate(('Red', 'Green'))
}
class_light['Red Negative'] = 8
class_light['Green Negative'] = 8
class Traffic_Light(data.Dataset):
    """Traffic-light image classification dataset.

    Each sample is read from disk with OpenCV, resized, scaled from pixel
    range [0, 255] to [-1, 1], transposed to channel-first layout and
    returned as a tensor together with an integer class id.

    The class id is inferred from the file path: every ``class_light`` key
    that occurs as a substring of the path is a match, and the last match in
    ``class_light`` order wins.

    Args:
        dataset_names: Indexable sequence of image file paths (str). Each
            path must contain one of the ``class_light`` keys, e.g. a parent
            directory named ``'Green Up'``.
        img_resize_shape: Target size passed to ``cv2.resize`` as ``dsize``,
            i.e. ``(width, height)``.
    """

    def __init__(self, dataset_names, img_resize_shape):
        super(Traffic_Light, self).__init__()
        self.dataset_names = dataset_names
        self.img_resize_shape = img_resize_shape

    @staticmethod
    def _to_normalized_chw(img):
        """Convert an HWC image array to CHW and scale [0, 255] to [-1, 1].

        Pixel value 0 maps to -1.0 and 255 maps to 1.0. For a uint8 input the
        result is a float64 array.
        """
        # The parentheses are required: ``/`` binds tighter than ``-``, so
        # ``x - 127.5 / 127.5`` would merely compute ``x - 1.0``.
        return (img.transpose(2, 0, 1) - 127.5) / 127.5

    @staticmethod
    def _label_from_path(path):
        """Return the class id whose ``class_light`` key occurs in ``path``.

        If several keys occur in ``path``, the last one in ``class_light``
        iteration order is used.

        Raises:
            ValueError: if no ``class_light`` key occurs in ``path``.
        """
        matches = [label for key, label in class_light.items() if key in path]
        if not matches:
            raise ValueError(
                'cannot infer traffic-light class from path: {!r}'.format(path))
        return matches[-1]

    def __getitem__(self, ind):
        """Return ``(image_tensor, label_tensor)`` for sample ``ind``.

        Raises:
            OSError: if OpenCV cannot read the image file.
            ValueError: if the path contains no known class name.
        """
        path = self.dataset_names[ind]
        img = cv2.imread(path)
        # cv2.imread reports failure by returning None rather than raising.
        if img is None:
            raise OSError('could not read image: {!r}'.format(path))
        img = cv2.resize(img, self.img_resize_shape)
        label = self._label_from_path(path)
        # pylint: disable=E1101,E1102
        return torch.from_numpy(self._to_normalized_chw(img)), torch.tensor(label)

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.dataset_names)
if __name__ == '__main__':
    # Smoke test: load every PNG from one class directory through a
    # DataLoader and print the shape of each (image, label) batch.
    import os
    from glob import glob
    from torch.utils.data import DataLoader

    sample_dir = 'TL_Dataset/Green Up/'
    png_paths = glob(os.path.join(sample_dir, '*.png'))
    loader = DataLoader(Traffic_Light(png_paths, (64, 64)), batch_size=1)
    for batch_idx, batch in enumerate(loader):
        images, labels = batch
        img_shape = images.numpy().shape
        lbl_shape = labels.numpy().shape
        print("{}-inp_size:{}-label_size:{}".format(batch_idx, img_shape,
                                                    lbl_shape))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/xkeys1997/TL_Dataset_Classification.git
git@gitee.com:xkeys1997/TL_Dataset_Classification.git
xkeys1997
TL_Dataset_Classification
TL_Dataset_Classification
master

搜索帮助