# NOTE(review): removed scraped-page banner text ("code fetch complete, page will auto-refresh") — not part of the source.
from __future__ import division, print_function, absolute_import
import os
import re
import pdb
import glob
import pickle
import torch
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import PIL.Image as PILI
import numpy as np
from tqdm import tqdm
class EpisodeDataset(data.Dataset):
    """Dataset of per-class episodes for few-shot learning.

    Indexing by a class id draws a fresh, randomly shuffled batch of
    n_shot + n_eval images belonging to that single class.
    """

    def __init__(self, root, phase='train', n_shot=5, n_eval=15, transform=None):
        """Args:
            root (str): path to data
            phase (str): train, val or test
            n_shot (int): how many examples per class for training (k/n_support)
            n_eval (int): how many examples per class for evaluation
                - n_shot + n_eval = batch_size for data.DataLoader of ClassDataset
            transform (torchvision.transforms): data augmentation
        """
        phase_dir = os.path.join(root, phase)
        # One subdirectory per class; sorted so class indices are stable.
        self.labels = sorted(os.listdir(phase_dir))
        # One single-class DataLoader per label: each draw produces a
        # shuffled batch of n_shot + n_eval images from that class.
        self.episode_loader = []
        for class_idx, label in enumerate(self.labels):
            paths = glob.glob(os.path.join(phase_dir, label, '*'))
            loader = data.DataLoader(
                ClassDataset(images=paths, label=class_idx, transform=transform),
                batch_size=n_shot + n_eval, shuffle=True, num_workers=0)
            self.episode_loader.append(loader)

    def __getitem__(self, idx):
        # A fresh iterator each access, so every lookup resamples the batch.
        return next(iter(self.episode_loader[idx]))

    def __len__(self):
        return len(self.labels)
class ClassDataset(data.Dataset):
    """All images of one class, served as (image, label) pairs."""

    def __init__(self, images, label, transform=None):
        """Args:
            images (list of str): each item is a path to an image of the same label
            label (int): the label of all the images
            transform: optional callable applied to each loaded image
        """
        self.images = images
        self.label = label
        self.transform = transform

    def __getitem__(self, idx):
        # Force RGB so grayscale/palette files come out with 3 channels.
        sample = PILI.open(self.images[idx]).convert('RGB')
        return (sample if self.transform is None else self.transform(sample),
                self.label)

    def __len__(self):
        return len(self.images)
class EpisodicSampler(data.Sampler):
    """Yields n_episode batches, each a random subset of n_class class indices."""

    def __init__(self, total_classes, n_class, n_episode):
        """Args:
            total_classes (int): number of classes available to sample from
            n_class (int): classes drawn per episode
            n_episode (int): episodes per epoch (length of the sampler)
        """
        self.total_classes = total_classes
        self.n_class = n_class
        self.n_episode = n_episode

    def __iter__(self):
        # Each episode: a uniform random choice of n_class distinct classes.
        for _ in range(self.n_episode):
            perm = torch.randperm(self.total_classes)
            yield perm[:self.n_class]

    def __len__(self):
        return self.n_episode
def _eval_transform(image_size, normalize):
    """Deterministic eval pipeline shared by val and test: resize, center-crop,
    to-tensor, normalize. Extracted because the original duplicated it verbatim."""
    return transforms.Compose([
        # Resize the short side to 8/7 of the crop so CenterCrop keeps ~7/8.
        transforms.Resize(image_size * 8 // 7),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize])


def prepare_data(args):
    """Build episodic train/val/test DataLoaders.

    Args:
        args: namespace providing data_root, image_size, n_shot, n_eval,
            n_class, episode, episode_val, n_workers, pin_mem.

    Returns:
        (train_loader, val_loader, test_loader): each loader yields episodes
        of args.n_class classes, with n_shot + n_eval images per class.
    """
    # ImageNet channel statistics (standard torchvision normalization).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training uses random augmentation; val/test share one deterministic pipeline.
    train_set = EpisodeDataset(args.data_root, 'train', args.n_shot, args.n_eval,
                               transform=transforms.Compose([
                                   transforms.RandomResizedCrop(args.image_size),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.ColorJitter(
                                       brightness=0.4,
                                       contrast=0.4,
                                       saturation=0.4,
                                       hue=0.2),
                                   transforms.ToTensor(),
                                   normalize]))
    eval_tf = _eval_transform(args.image_size, normalize)
    val_set = EpisodeDataset(args.data_root, 'val', args.n_shot, args.n_eval,
                             transform=eval_tf)
    test_set = EpisodeDataset(args.data_root, 'test', args.n_shot, args.n_eval,
                              transform=eval_tf)

    train_loader = data.DataLoader(
        train_set, num_workers=args.n_workers, pin_memory=args.pin_mem,
        batch_sampler=EpisodicSampler(len(train_set), args.n_class, args.episode))
    val_loader = data.DataLoader(
        val_set, num_workers=2, pin_memory=False,
        batch_sampler=EpisodicSampler(len(val_set), args.n_class, args.episode_val))
    test_loader = data.DataLoader(
        test_set, num_workers=2, pin_memory=False,
        batch_sampler=EpisodicSampler(len(test_set), args.n_class, args.episode_val))
    return train_loader, val_loader, test_loader
# NOTE(review): removed scraped hosting-site moderation notice — not part of the source.