代码拉取完成,页面将自动刷新
import os
import warnings
import torch
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive
class MNISTM(VisionDataset):
    """MNIST-M Dataset.

    Loads one split of MNIST-M from a pre-processed ``.pt`` file that holds a
    ``(data, targets)`` pair, stored under ``<root>/<ClassName>/processed``.

    Args:
        root (str): Root directory. Raw archives are kept in
            ``<root>/<ClassName>/raw`` and extracted split files in
            ``<root>/<ClassName>/processed``.
        train (bool, optional): If truthy, load the training split; otherwise
            load the test split.
        transform (callable, optional): Applied to each PIL image returned by
            ``__getitem__``.
        target_transform (callable, optional): Applied to each integer target.
        download (bool, optional): If True, download and extract both split
            archives first. This is skipped when both split files already exist.

    Raises:
        RuntimeError: If the split files are missing after the optional download.
    """

    # (url, md5) pairs for the train and test archives.
    resources = [
        ('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_train.pt.tar.gz',
         '191ed53db9933bd85cc9700558847391'),
        ('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_test.pt.tar.gz',
         'e11cb4d7fff76d7ec588b1134907db59')
    ]

    # File names (inside processed_folder) produced by extracting the archives.
    training_file = "mnist_m_train.pt"
    test_file = "mnist_m_test.pt"

    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    # Deprecated aliases kept for backward compatibility. Each one warns and
    # returns the current split's ``targets`` or ``data``, regardless of
    # whether the alias name says "train" or "test".
    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Init MNIST-M dataset."""
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.train = train

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")

        data_file = self.training_file if self.train else self.test_file
        # NOTE(review): torch.load unpickles the file. The md5 is only checked at
        # download time, so only load files that came from the trusted URLs above.
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        """Get images and target for data loader.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # Return a PIL Image so this dataset is consistent with the other
        # torchvision datasets.
        # NOTE(review): assumes each image is an (H, W, 3) uint8 tensor, as
        # implied by mode="RGB" -- TODO confirm against the .pt files.
        img = Image.fromarray(img.squeeze().numpy(), mode="RGB")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return size of dataset."""
        return len(self.data)

    @property
    def raw_folder(self):
        """Directory where downloaded archives are stored."""
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        """Directory holding the extracted ``.pt`` split files."""
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self):
        """Map each class name to its integer label."""
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        """Return True only if both the train and test split files are present."""
        return (os.path.exists(os.path.join(self.processed_folder, self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder, self.test_file)))

    def download(self):
        """Download the MNIST-M data."""
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        # Archives are saved to raw_folder (md5-verified) and extracted
        # directly into processed_folder.
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder,
                                         extract_root=self.processed_folder,
                                         filename=filename, md5=md5)

        print('Done!')

    def extra_repr(self):
        # Use the same truthiness test as __init__, so the label always matches
        # the split that was actually loaded (e.g. train=1 reports "Train").
        return "Split: {}".format("Train" if self.train else "Test")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。