1 Star 0 Fork 1

ink/pytorch_ResNet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 1.44 KB
一键复制 编辑 原始数据 按行查看 历史
ink 提交于 2020-08-10 21:08 . Initial commit
import math
import numpy as np
import h5py
import matplotlib.pyplot as plt
import scipy
from PIL import Image
from scipy import ndimage
import torch
import torch.nn as nn
from cnn_utils import *
from torch import nn,optim
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
np.random.seed(1)
torch.manual_seed(1)
batch_size = 24
learning_rate = 9e-3
num_epocher = 100
X_train_orig, Y_train_orig, X_test_orig, Y_test_orig, classes = load_dataset()
X_train = X_train_orig/255.
X_test = X_test_orig/255.
class MyData(Dataset): # 继承Dataset
def __init__(self, data, y, transform=None): # __init__是初始化该类的一些基础参数
self.transform = transform # 变换
self.data = data
self.y = y
def __len__(self): # 返回整个数据集的大小
return len(self.data)
def __getitem__(self, index): # 根据索引index返回dataset[index]
sample = self.data[index]
if self.transform:
sample = self.transform(sample) # 对样本进行变换
return sample, self.y[index] # 返回该样本
train_dataset = MyData(X_train, Y_train_orig[0],
transform=transforms.ToTensor())
test_dataset = MyData(X_test, Y_test_orig[0],
transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/inkCode/pytorch_ResNet.git
git@gitee.com:inkCode/pytorch_ResNet.git
inkCode
pytorch_ResNet
pytorch_ResNet
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385