1 Star 1 Fork 1

SocietyQiang/Pytorch-CapsuleNet

Create your Gitee Account
Explore and code with more than 12 million developers. Free private repositories! :)
Sign up
Clone or Download
data_loader.py 1.74 KB
Copy Edit Raw Blame History
jindongwang authored 2018-04-10 14:34 . add: cifar-10
import torch
from torchvision import datasets, transforms
class Dataset:
    """Bundle the train and test ``DataLoader`` objects for a named dataset.

    Supported names:

    * ``'mnist'``   -- torchvision MNIST, stored under ``/data/mnist``.
    * ``'cifar10'`` -- torchvision CIFAR-10, stored under ``/data/cifar``.
    * ``'office-caltech'`` / ``'office31'`` -- recognised but not implemented
      yet; both loaders are left as ``None``.

    Attributes:
        train_loader: Shuffled ``DataLoader`` over the training split, or
            ``None`` for the not-yet-implemented datasets.
        test_loader: Unshuffled ``DataLoader`` over the test split, or
            ``None`` for the not-yet-implemented datasets.
    """

    def __init__(self, dataset, _batch_size):
        """Create the loaders for ``dataset``.

        Args:
            dataset: Name of the dataset to load (see the class docstring).
            _batch_size: Batch size passed to both data loaders.

        Raises:
            ValueError: If ``dataset`` is not one of the recognised names.
        """
        super(Dataset, self).__init__()

        # Always define both attributes so that callers can test for ``None``
        # instead of hitting an AttributeError on the placeholder datasets.
        self.train_loader = None
        self.test_loader = None

        if dataset == 'mnist':
            # Single-channel images: one mean / std value.
            dataset_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ])
            train_dataset = datasets.MNIST(
                '/data/mnist', train=True, download=True,
                transform=dataset_transform)
            test_dataset = datasets.MNIST(
                '/data/mnist', train=False, download=True,
                transform=dataset_transform)
        elif dataset == 'cifar10':
            # Three-channel images: one mean / std value per channel.
            data_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            train_dataset = datasets.CIFAR10(
                '/data/cifar', train=True, download=True,
                transform=data_transform)
            test_dataset = datasets.CIFAR10(
                '/data/cifar', train=False, download=True,
                transform=data_transform)
        elif dataset in ('office-caltech', 'office31'):
            # Placeholders kept from the original code: nothing to load yet,
            # so the loaders stay ``None``.
            return
        else:
            # Fail early on typos instead of returning a half-built object.
            raise ValueError('Unknown dataset: {!r}'.format(dataset))

        # Shared by every implemented dataset: shuffle only the training data.
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=_batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=_batch_size, shuffle=False)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/societyqiang/Pytorch-CapsuleNet.git
git@gitee.com:societyqiang/Pytorch-CapsuleNet.git
societyqiang
Pytorch-CapsuleNet
Pytorch-CapsuleNet
master

Search

D67c1975 1850385 1daf7b77 1850385