1 Star 2 Fork 0

chensong/meta-learning-lstm-pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
dataloader.py 4.01 KB
一键复制 编辑 原始数据 按行查看 历史
Mark Dong 提交于 2019-02-24 03:31 . almost
from __future__ import division, print_function, absolute_import
import os
import re
import pdb
import glob
import pickle
import torch
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import PIL.Image as PILI
import numpy as np
from tqdm import tqdm
class EpisodeDataset(data.Dataset):
    """Dataset of few-shot episodes: item ``idx`` is one batch of images
    (n_shot + n_eval of them) all belonging to class ``idx``.
    """

    def __init__(self, root, phase='train', n_shot=5, n_eval=15, transform=None):
        """Args:
            root (str): path to data
            phase (str): train, val or test
            n_shot (int): how many examples per class for training (k/n_support)
            n_eval (int): how many examples per class for evaluation
                - n_shot + n_eval = batch_size for data.DataLoader of ClassDataset
            transform (torchvision.transforms): data augmentation
        """
        phase_dir = os.path.join(root, phase)
        # One subdirectory per class; directory name is the class label.
        self.labels = sorted(os.listdir(phase_dir))
        self.episode_loader = []
        for class_idx, label in enumerate(self.labels):
            image_paths = glob.glob(os.path.join(phase_dir, label, '*'))
            per_class_loader = data.DataLoader(
                ClassDataset(images=image_paths, label=class_idx, transform=transform),
                batch_size=n_shot + n_eval,
                shuffle=True,
                num_workers=0)
            self.episode_loader.append(per_class_loader)

    def __getitem__(self, idx):
        # A fresh iterator each call, so each access yields a newly
        # shuffled batch of this class's images.
        batch_iter = iter(self.episode_loader[idx])
        return next(batch_iter)

    def __len__(self):
        return len(self.labels)
class ClassDataset(data.Dataset):
    """All images sharing a single class label, loaded lazily from disk."""

    def __init__(self, images, label, transform=None):
        """Args:
            images (list of str): each item is a path to an image of the same label
            label (int): the label of all the images
            transform (callable or None): optional per-image transform
        """
        self.images = images
        self.label = label
        self.transform = transform

    def __getitem__(self, idx):
        # Decode from disk on every access; force RGB so grayscale
        # files come out with three channels too.
        img = PILI.open(self.images[idx])
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.label

    def __len__(self):
        return len(self.images)
class EpisodicSampler(data.Sampler):
    """Batch sampler: each of ``n_episode`` iterations yields ``n_class``
    distinct class indices drawn uniformly from ``total_classes``.
    """

    def __init__(self, total_classes, n_class, n_episode):
        # total_classes: size of the class pool to draw from
        # n_class: classes per episode (the "N" in N-way)
        # n_episode: episodes per epoch (this sampler's length)
        self.total_classes = total_classes
        self.n_class = n_class
        self.n_episode = n_episode

    def __iter__(self):
        for _ in range(self.n_episode):
            perm = torch.randperm(self.total_classes)
            yield perm[:self.n_class]

    def __len__(self):
        return self.n_episode
def prepare_data(args):
    """Build the episodic train/val/test DataLoaders.

    Args:
        args: namespace with data_root, image_size, n_shot, n_eval,
            n_class, episode, episode_val, n_workers, pin_mem.

    Returns:
        (train_loader, val_loader, test_loader): each DataLoader yields
        one episode per iteration via EpisodicSampler (a batch of
        n_class class-batches from EpisodeDataset).
    """
    # ImageNet channel statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Stochastic augmentation for training only.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(args.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.2),
        transforms.ToTensor(),
        normalize])

    # val and test share one deterministic pipeline (was duplicated
    # verbatim): resize to 8/7 of the crop size, then center-crop.
    eval_transform = transforms.Compose([
        transforms.Resize(args.image_size * 8 // 7),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        normalize])

    train_set = EpisodeDataset(args.data_root, 'train', args.n_shot, args.n_eval,
                               transform=train_transform)
    val_set = EpisodeDataset(args.data_root, 'val', args.n_shot, args.n_eval,
                             transform=eval_transform)
    test_set = EpisodeDataset(args.data_root, 'test', args.n_shot, args.n_eval,
                              transform=eval_transform)

    train_loader = data.DataLoader(
        train_set, num_workers=args.n_workers, pin_memory=args.pin_mem,
        batch_sampler=EpisodicSampler(len(train_set), args.n_class, args.episode))
    # Evaluation loaders use fixed, lighter worker settings.
    val_loader = data.DataLoader(
        val_set, num_workers=2, pin_memory=False,
        batch_sampler=EpisodicSampler(len(val_set), args.n_class, args.episode_val))
    test_loader = data.DataLoader(
        test_set, num_workers=2, pin_memory=False,
        batch_sampler=EpisodicSampler(len(test_set), args.n_class, args.episode_val))
    return train_loader, val_loader, test_loader
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/chensong121/meta-learning-lstm-pytorch.git
git@gitee.com:chensong121/meta-learning-lstm-pytorch.git
chensong121
meta-learning-lstm-pytorch
meta-learning-lstm-pytorch
master

搜索帮助