4 Star 0 Fork 0

xinanXu/FA-CC

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
data.py 1.72 KB
一键复制 编辑 原始数据 按行查看 历史
xinanXu 提交于 2023-05-23 14:32 . template commit
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
class VOCDataset(Dataset):
    """Classification dataset stored in a VOC-like directory layout.

    Expected layout under ``root_dir``::

        root_dir/Annotations/*.xml   # one annotation file per sample
        root_dir/Images/<filename>   # image named by each XML's <filename>

    Each sample is the RGB image and the integer class of the first
    ``<object>`` in its annotation. The image has a fraction of its width
    removed from the left edge and is then optionally resized.

    Args:
        root_dir: Dataset root directory; a trailing path separator is
            optional.
        image_size: Size passed to ``PIL.Image.resize``, i.e.
            ``(width, height)``. A falsy value (e.g. ``None``) disables
            resizing.
        crop_left_ratio: Fraction of the image width cropped away from the
            left edge before resizing. Defaults to 0.33, the value that was
            previously hard-coded.
    """

    def __init__(self, root_dir, image_size=(224, 224), crop_left_ratio=0.33):
        self.root_dir = root_dir
        self.image_size = image_size
        self.crop_left_ratio = crop_left_ratio
        annotation_dir = os.path.join(root_dir, "Annotations")
        # Sorted so the index -> file mapping is reproducible across runs and
        # platforms; os.listdir returns entries in arbitrary order.
        self.xml_files = [
            os.path.join(annotation_dir, file)
            for file in sorted(os.listdir(annotation_dir))
            if file.endswith('.xml')
        ]
        self.class_dict = {'normal_driving': 0, 'phone': 1, 'yawn': 2, 'look_around': 3, 'close_eyes': 4}

    def __len__(self):
        """Return the number of ``.xml`` annotation files found."""
        return len(self.xml_files)

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            An ``(image, label)`` pair. ``image`` is a float32 tensor of
            shape ``(3, H, W)`` with values in ``[0, 1]``. ``label`` is the
            int mapped by ``class_dict`` from the first ``<object>/<name>``
            in the annotation.

        Raises:
            KeyError: If the annotation's class name is not in
                ``class_dict``.
        """
        xml_file = self.xml_files[index]
        root = ET.parse(xml_file).getroot()
        image_name = root.find('filename').text
        image_path = os.path.join(self.root_dir, "Images", image_name)
        image = self._load_image(image_path)
        # (H, W, 3) uint8 array -> (3, H, W) float tensor scaled to [0, 1].
        image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255
        label_name = root.find('object').find('name').text
        try:
            label = self.class_dict[label_name]
        except KeyError:
            raise KeyError(f"unknown class {label_name!r} in {xml_file}") from None
        return image, label

    def _load_image(self, image_path):
        """Open ``image_path`` as RGB, crop the left edge, optionally resize."""
        # The context manager closes the underlying file; convert() returns a
        # new, fully loaded image that stays valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        left = int(image.width * self.crop_left_ratio)
        image = image.crop((left, 0, image.width, image.height))
        if self.image_size:
            image = image.resize(self.image_size)
        return image
if __name__ == '__main__':
    # Smoke test: pull a single shuffled batch and show its shape and labels.
    demo_loader = DataLoader(VOCDataset('./dataset/data_5k/'), batch_size=4, shuffle=True)
    first_batch = next(iter(demo_loader), None)
    if first_batch is not None:
        batch_images, batch_labels = first_batch
        print(batch_images.shape)
        print(batch_labels)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/DearAtri/fa-cc.git
git@gitee.com:DearAtri/fa-cc.git
DearAtri
fa-cc
FA-CC
master

搜索帮助