1 Star 0 Fork 0

qiqiqi777/mnist-m

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
mnist_m.py 3.95 KB
一键复制 编辑 原始数据 按行查看 历史
liyxi 提交于 2020-10-08 18:57 . Update mnist_m.py
import os
import warnings
import torch
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive
class MNISTM(VisionDataset):
    """MNIST-M Dataset.

    Loads a pre-processed ``(data, targets)`` pair stored with ``torch.save``
    from ``<root>/MNISTM/processed``, optionally downloading and extracting
    the release archives first.

    Args:
        root (str): Root directory. Files are kept under ``<root>/MNISTM/``
            (``raw/`` for downloaded archives, ``processed/`` for the
            extracted ``.pt`` files).
        train (bool): If truthy, use the training split
            (``mnist_m_train.pt``); otherwise use the test split
            (``mnist_m_test.pt``).
        transform (callable, optional): Applied to each PIL image returned by
            ``__getitem__``.
        target_transform (callable, optional): Applied to each integer target.
        download (bool): If True, download and extract both splits unless they
            are already present.

    Raises:
        RuntimeError: If either processed ``.pt`` file is missing after the
            optional download step.
    """

    # (url, md5) pairs; the md5 is passed to the downloader for verification.
    resources = [
        ('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_train.pt.tar.gz',
         '191ed53db9933bd85cc9700558847391'),
        ('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_test.pt.tar.gz',
         'e11cb4d7fff76d7ec588b1134907db59'),
    ]

    # File names expected inside ``processed_folder`` after extraction.
    training_file = "mnist_m_train.pt"
    test_file = "mnist_m_test.pt"

    # Class names; the list position is the integer label (see class_to_idx).
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    # Deprecated aliases kept for backward compatibility. Each one returns the
    # attribute for whichever split this instance loaded. stacklevel=2 makes
    # the warning point at the caller rather than at this property.
    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets", stacklevel=2)
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets", stacklevel=2)
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data", stacklevel=2)
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data", stacklevel=2)
        return self.data

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Init MNIST-M dataset (see the class docstring for arguments)."""
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.train = train

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")

        data_file = self.training_file if self.train else self.test_file
        # NOTE: torch.load unpickles the file. It is trusted only because the
        # archive was md5-verified when it was downloaded.
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        """Get image and target for the data loader.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where image is a PIL RGB image (after
            ``transform``, if any) and target is the class index as an int
            (after ``target_transform``, if any).
        """
        img, target = self.data[index], int(self.targets[index])

        # Return a PIL Image so the sample is consistent with other torchvision
        # datasets. Each sample is presumably an (H, W, 3) uint8 tensor, which
        # mode="RGB" requires; confirm against the packaged .pt files.
        img = Image.fromarray(img.squeeze().numpy(), mode="RGB")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return size of dataset."""
        return len(self.data)

    @property
    def raw_folder(self):
        """Directory where downloaded archives are stored."""
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        """Directory holding the extracted ``.pt`` files."""
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self):
        """Map each class name to its integer label."""
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        """Return True only if both the train and test files are present."""
        return (os.path.exists(os.path.join(self.processed_folder, self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder, self.test_file)))

    def download(self):
        """Download the MNIST-M data (no-op if both splits already exist)."""
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        # Archives go to raw_folder and are extracted straight into
        # processed_folder, where _check_exists looks for the .pt files.
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder,
                                         extract_root=self.processed_folder,
                                         filename=filename, md5=md5)

        print('Done!')

    def extra_repr(self):
        # Use truthiness, matching how __init__ chooses the split, so that
        # e.g. train=1 is reported as "Train".
        return "Split: {}".format("Train" if self.train else "Test")
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/qiqiqi777/mnist-m.git
git@gitee.com:qiqiqi777/mnist-m.git
qiqiqi777
mnist-m
mnist-m
main

搜索帮助