1 Star 0 Fork 0

123/MAML-Pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
omniglot.py 3.54 KB
一键复制 编辑 原始数据 按行查看 历史
import torch.utils.data as data
import os
import os.path
import errno
class Omniglot(data.Dataset):
    """Omniglot character images exposed as an indexable dataset of file paths.

    ``self.all_items`` holds one ``(filename, category, directory)`` tuple per
    image found under ``<root>/processed``. ``self.idx_classes`` maps each
    category string to an integer class index.

    Args:
    - root: the directory where the dataset will be stored
    - transform: how to transform the input
    - target_transform: how to transform the target
    - download: need to download the dataset
    """

    urls = [
        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, transform=None, target_transform=None, download=False):
        """Locate (optionally download) the dataset and index its images.

        Raises:
            RuntimeError: if the extracted folders are missing and
                ``download`` is false.
        """
        super(Omniglot, self).__init__()
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if not self._check_exists():
            if download:
                self.download()
            else:
                raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')

        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
        self.idx_classes = index_classes(self.all_items)

    def __getitem__(self, index):
        """Return ``(img, target)`` for the item at ``index``.

        ``img`` starts out as the path string ``"<directory>/<filename>"``;
        nothing is loaded here, so ``transform`` receives that path.
        ``target`` is the integer class index, optionally passed through
        ``target_transform``.
        """
        filename, category, directory = self.all_items[index]
        img = '/'.join([directory, filename])
        target = self.idx_classes[category]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.all_items)

    def _check_exists(self):
        """Return True when both extracted image folders are present."""
        processed_dir = os.path.join(self.root, self.processed_folder)
        return os.path.exists(os.path.join(processed_dir, "images_evaluation")) and \
            os.path.exists(os.path.join(processed_dir, "images_background"))

    def download(self):
        """Download every archive in ``urls`` and extract it under ``processed``.

        Does nothing when the extracted folders already exist. Archives are
        saved under ``<root>/raw`` and unpacked into ``<root>/processed``.

        Raises:
            OSError: if a target folder cannot be created for a reason other
                than already existing.
        """
        if self._check_exists():
            return

        # Imported lazily so merely using an already-downloaded dataset
        # does not require these modules.
        import shutil
        import zipfile
        from contextlib import closing
        try:
            from urllib.request import urlopen
        except ImportError:  # Python 2 fallback
            from urllib2 import urlopen

        raw_dir = os.path.join(self.root, self.raw_folder)
        processed_dir = os.path.join(self.root, self.processed_folder)

        # Create each folder independently: a pre-existing raw folder must
        # not prevent the processed folder from being created.
        for folder in (raw_dir, processed_dir):
            try:
                os.makedirs(folder)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        for url in self.urls:
            print('== Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # Stream to disk instead of buffering the whole archive in
            # memory; both handles are closed even if the copy fails.
            with closing(urlopen(url)) as response, open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)

            print("== Unzip from " + file_path + " to " + processed_dir)
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                zip_ref.extractall(processed_dir)
        print("Download finished.")
def find_classes(root_dir):
    """Walk ``root_dir`` and list every file whose name ends with ``png``.

    Args:
        root_dir: directory to search recursively.

    Returns:
        A list of ``(filename, category, directory)`` tuples. ``category`` is
        the last two components of the containing directory joined with
        ``"/"``, and ``directory`` is the path reported by ``os.walk``.
    """
    retour = []
    for (root, dirs, files) in os.walk(root_dir):
        # Split with os.path rather than on a literal '/', so the category
        # is also correct where os.walk yields backslash-separated paths.
        parent_dir, leaf = os.path.split(root)
        category = os.path.basename(parent_dir) + "/" + leaf
        for f in files:
            if f.endswith("png"):
                retour.append((f, category, root))
    print("== Found %d items " % len(retour))
    return retour
def index_classes(items):
    """Assign a consecutive integer index to each distinct category.

    Args:
        items: iterable of tuples whose second element is the category.

    Returns:
        A dict mapping each category to its index, numbered from 0 in order
        of first appearance.
    """
    class_to_index = {}
    for entry in items:
        category = entry[1]
        # setdefault only inserts for unseen categories, so the index is the
        # number of categories registered before this one.
        class_to_index.setdefault(category, len(class_to_index))
    print("== Found %d classes" % len(class_to_index))
    return class_to_index
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/changqixin/MAML-Pytorch.git
git@gitee.com:changqixin/MAML-Pytorch.git
changqixin
MAML-Pytorch
MAML-Pytorch
master

搜索帮助