代码拉取完成,页面将自动刷新
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
class VOCDataset(Dataset):
def __init__(self, root_dir, image_size=(224, 224)):
self.root_dir = root_dir
self.image_size = image_size
root_dir = root_dir + "Annotations/"
self.xml_files = [os.path.join(root_dir, file) for file in os.listdir(root_dir) if file.endswith('.xml')]
self.class_dict = {'normal_driving': 0, 'phone': 1, 'yawn': 2, 'look_around': 3, 'close_eyes': 4}
def __len__(self):
return len(self.xml_files)
def __getitem__(self, index):
xml_file = self.xml_files[index]
tree = ET.parse(xml_file)
root = tree.getroot()
image_dir = self.root_dir + "Images/"
image_name = root.find('filename').text
image_path = os.path.join(image_dir, image_name)
image = Image.open(image_path).convert('RGB')
image = np.array(image)
width = int(image.shape[1] * 0.33)
image = Image.fromarray(image[:, width:, :])
# resize image
if self.image_size:
image = image.resize(self.image_size)
# convert image to tensor
image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255
# get label
label = root.find('object').find('name').text
label = self.class_dict[label]
return image, label
if __name__ == '__main__':
dataset = VOCDataset('./dataset/data_5k/')
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
for images, labels in dataloader:
print(images.shape)
print(labels)
break
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。