Fetching the repository succeeded.
import torch
from torchvision import datasets, transforms
class Dataset:
    """Build a train/test ``DataLoader`` pair for a named dataset.

    For 'mnist' and 'cifar10', construction sets two attributes:
    ``self.train_loader`` (shuffled) and ``self.test_loader`` (not
    shuffled). Both use the same batch size and transform.

    Args:
        dataset: Dataset name. One of 'mnist', 'cifar10',
            'office-caltech', 'office31'. The two office datasets are
            placeholders: no loaders are created for them.
        _batch_size: Batch size used by both loaders.
        root: Parent directory that holds the per-dataset folders
            ('mnist', 'cifar'). The folders are passed to torchvision with
            ``download=True``. Defaults to '/data', which reproduces the
            original hard-coded paths.

    Raises:
        ValueError: If ``dataset`` is not one of the names listed above.
    """

    # Names that are recognized but not implemented yet. Constructing with
    # one of these succeeds but leaves the loaders unset.
    _PLACEHOLDERS = ('office-caltech', 'office31')

    def __init__(self, dataset, _batch_size, root='/data'):
        super().__init__()
        if dataset == 'mnist':
            transform = transforms.Compose([
                transforms.ToTensor(),
                # Presumably the MNIST pixel mean / std; single channel.
                transforms.Normalize((0.1307,), (0.3081,)),
            ])
            self._build_loaders(datasets.MNIST, f'{root}/mnist',
                                transform, _batch_size)
        elif dataset == 'cifar10':
            transform = transforms.Compose([
                transforms.ToTensor(),
                # (x - 0.5) / 0.5 maps each of the 3 channels from the
                # [0, 1] range produced by ToTensor to [-1, 1].
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            self._build_loaders(datasets.CIFAR10, f'{root}/cifar',
                                transform, _batch_size)
        elif dataset in self._PLACEHOLDERS:
            # TODO: loaders for the office datasets are not implemented;
            # kept as a no-op to preserve the original behaviour.
            pass
        else:
            # Fail fast instead of returning an object with no loaders,
            # which previously surfaced later as an AttributeError.
            known = ('mnist', 'cifar10') + self._PLACEHOLDERS
            raise ValueError(
                f'Unknown dataset {dataset!r}; expected one of {known}')

    def _build_loaders(self, dataset_cls, path, transform, batch_size):
        """Create ``self.train_loader`` / ``self.test_loader`` for *dataset_cls*.

        The train split is shuffled and the test split is not. Both splits
        are requested with ``download=True`` under *path*.
        """
        train_set = dataset_cls(path, train=True, download=True,
                                transform=transform)
        test_set = dataset_cls(path, train=False, download=True,
                               transform=transform)
        self.train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=batch_size, shuffle=False)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。