1 Star 0 Fork 0

许满坤/deep learning

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
mycnn.py 3.56 KB
一键复制 编辑 原始数据 按行查看 历史
from torch import nn
class MyCNN(nn.Module):
    """CNN with six parallel 18-way classification heads.

    Expects an RGB image batch of shape (batch, 3, 64, 128); the conv stack
    reduces it to a (batch, 32, 3, 7) feature map (hence the 32 * 3 * 7
    linear input size), which is flattened and fed to six independent
    linear heads.

    forward() returns a tuple of six (batch, 18) logit tensors.
    """

    def __init__(self):
        super().__init__()
        # Two stages, each shrinking the spatial size twice:
        # a stride-2 3x3 conv followed by a 2x2 max-pool.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, (3, 3), stride=(2, 2)),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, (3, 3), stride=(2, 2)),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Six independent heads over the same flattened feature vector.
        # Kept as separate named attributes (fc1..fc6) so that any
        # previously saved state_dicts keep loading unchanged.
        self.fc1 = nn.Linear(32 * 3 * 7, 18)
        self.fc2 = nn.Linear(32 * 3 * 7, 18)
        self.fc3 = nn.Linear(32 * 3 * 7, 18)
        self.fc4 = nn.Linear(32 * 3 * 7, 18)
        self.fc5 = nn.Linear(32 * 3 * 7, 18)
        self.fc6 = nn.Linear(32 * 3 * 7, 18)

    def forward(self, img):
        """Run the conv stack, flatten, and apply all six heads.

        Args:
            img: float tensor of shape (batch, 3, 64, 128).

        Returns:
            Tuple of six (batch, 18) logit tensors, one per head.
        """
        feat = self.cnn(img)
        # Flatten all non-batch dimensions for the linear heads.
        # (Debug print() calls removed from the hot path.)
        feat = feat.view(feat.shape[0], -1)
        return (
            self.fc1(feat),
            self.fc2(feat),
            self.fc3(feat),
            self.fc4(feat),
            self.fc5(feat),
            self.fc6(feat),
        )
# root_dir = '.\\hagrid_data\\images'
#
#
# transform = transforms.Compose([
# transforms.Resize((64, 128)),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])
#
# train_dataset = HagridDataset(root_dir,transform=transform)
# train_loader = torch.utils.data.DataLoader(
# train_dataset,
# batch_size=10,
# shuffle=True,
# num_workers=4,
# )
#
# val_dataset = HagridDataset(root_dir,'val',transform=transform)
# val_loader = torch.utils.data.DataLoader(
# val_dataset,
# batch_size=32,
# shuffle=True,
# num_workers=4,
# )
#
# model = MyCNN()
#
# criterion = nn.CrossEntropyLoss(reduction='mean')
# optimizer = torch.optim.Adam(model.parameters(), 0.001)
# best_loss = 1000.0
#
#
# def my_train(train_loader, model, criterion, optimizer):
# print(train_loader)
# model.train()
# for data in train_loader:
# imgs, targets = data
# print(data.shape)
# print(targets)
# c0, c1, c2, c3, c4, c5 = model(data[0])
# loss = (criterion(c0, data[1][:, 0]) +
# criterion(c1, data[1][:, 1]) +
# criterion(c2, data[1][:, 2]) +
# criterion(c3, data[1][:, 3]) +
# criterion(c4, data[1][:, 4]) +
# criterion(c5, data[1][:, 5]))
# loss /= 6
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
#
#
# def my_validate(val_loader, model, criterion):
# model.eval()
# val_loss = []
# # Disable gradient tracking during validation
# with torch.no_grad():
# for i, (data, target) in enumerate(val_loader):
# c0, c1, c2, c3, c4, c5 = model(data[0])
# loss = criterion(c0, data[1][:, 0]) + \
# criterion(c1, data[1][:, 1]) + \
# criterion(c2, data[1][:, 2]) + \
# criterion(c3, data[1][:, 3]) + \
# criterion(c4, data[1][:, 4]) + \
# criterion(c5, data[1][:, 5])
# loss /= 6
# val_loss.append(loss.item())
# return np.mean(val_loss)
#
#
# if __name__ == '__main__':
# for epoch in range(10):
# print('Epoch: ', epoch)
#
# my_train(train_loader, model, criterion, optimizer)
# val_loss = my_validate(val_loader, model, criterion)
#
# # Save a checkpoint whenever validation loss improves
# if val_loss < best_loss:
# best_loss = val_loss
# torch.save(model.state_dict(), './model.pt')
#
# torch.save(model.state_dict(), './model.pt')
# model.load_state_dict(torch.load('./model.pt'))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/xu-mankun/deep-learning.git
git@gitee.com:xu-mankun/deep-learning.git
xu-mankun
deep-learning
deep learning
master

搜索帮助