
刘世杰/dl_final_assignment

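This repository does not declare an open-source license file (LICENSE); before using it, check the project description and the licenses of its upstream code dependencies.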
splitdata.py
Committed by 刘世杰 on 2024-06-01 11:16: Data processing
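import torch
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Subset

# Per-channel mean and std commonly used to normalize CIFAR-100
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)


def get_data_loaders(batch_size=64, num_workers=8, val_split=0.1, seed=42):
    """Return (trainloader, valloader, testloader) for CIFAR-100."""
    # Training augmentation: pad by 4 px, random 32x32 crop, random horizontal flip
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ])
    # Evaluation transform: no augmentation, only tensor conversion and normalization
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ])

    # Two views of the training images: augmented for training, un-augmented for
    # validation. Calling random_split on a single dataset would make the
    # validation subset inherit transform_train.
    train_full = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    val_full = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_test)
    testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

    # Calculate validation set size
    val_size = int(len(train_full) * val_split)
    train_size = len(train_full) - val_size

    # Shuffle indices once with a seeded generator (reproducible split),
    # then apply the same partition to both views
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(train_full), generator=generator).tolist()
    trainset = Subset(train_full, perm[:train_size])
    valset = Subset(val_full, perm[train_size:])

    # Create data loaders; only the training loader shuffles
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, valloader, testloader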
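The split and batch arithmetic in `get_data_loaders` can be checked without downloading CIFAR-100. Below is a minimal pure-Python sketch, assuming the standard CIFAR-100 sizes (50,000 training images, 10,000 test images). `split_indices` and `batches_per_epoch` are hypothetical helpers for illustration, and `random.Random` stands in for the torch shuffling, so the sizes match the original code but the exact indices do not.

```python
import math
import random


def split_indices(n, val_split=0.1, seed=0):
    """Hypothetical helper: reproducible (train, val) partition of range(n).

    Uses the same size arithmetic as get_data_loaders; random.Random is a
    stand-in for torch's generator, so only the sizes match, not the indices.
    """
    val_size = int(n * val_split)  # same truncation as the original code
    train_size = n - val_size
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded, so the split is repeatable
    return indices[:train_size], indices[train_size:]


def batches_per_epoch(num_samples, batch_size):
    # DataLoader's default drop_last=False keeps the final partial batch
    return math.ceil(num_samples / batch_size)


train_idx, val_idx = split_indices(50_000, val_split=0.1, seed=0)
print(len(train_idx), len(val_idx))           # 45000 5000
print(batches_per_epoch(len(train_idx), 64))  # 704
print(batches_per_epoch(len(val_idx), 64))    # 79
print(batches_per_epoch(10_000, 64))          # 157
```

With the defaults (`batch_size=64`, `val_split=0.1`), one training epoch is therefore 704 iterations, and the last training batch holds 45000 - 703 × 64 = 8 images.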
Clone (HTTPS): https://gitee.com/liushijie618/dl_final_assignment.git
Clone (SSH): git@gitee.com:liushijie618/dl_final_assignment.git
Branch: master
