# NOTE: removed non-Python Gitee page artifact ("代码拉取完成,页面将自动刷新" / "code pull complete, page will refresh") that broke parsing.
from torch import nn
class MyCNN(nn.Module):
    """CNN mapping a 3-channel image to six independent 18-way logit heads.

    The fully-connected layers assume the flattened CNN feature size is
    32 * 3 * 7, which corresponds to a 64x128 input image (matches the
    transforms.Resize((64, 128)) in the commented-out training code below
    — TODO confirm callers actually resize to that shape).
    """

    # Flattened feature size produced by `self.cnn` for a 64x128 input:
    # 64x128 -> conv s2 -> 31x63 -> pool -> 15x31 -> conv s2 -> 7x15 -> pool -> 3x7.
    _FEAT_DIM = 32 * 3 * 7

    def __init__(self):
        super().__init__()  # zero-arg super is the modern Python 3 idiom
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, (3, 3), stride=(2, 2)),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, (3, 3), stride=(2, 2)),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Six parallel classification heads. Kept as separate fc1..fc6
        # attributes (not an nn.ModuleList) so existing state_dict
        # checkpoints keyed "fc1.weight" etc. still load.
        self.fc1 = nn.Linear(self._FEAT_DIM, 18)
        self.fc2 = nn.Linear(self._FEAT_DIM, 18)
        self.fc3 = nn.Linear(self._FEAT_DIM, 18)
        self.fc4 = nn.Linear(self._FEAT_DIM, 18)
        self.fc5 = nn.Linear(self._FEAT_DIM, 18)
        self.fc6 = nn.Linear(self._FEAT_DIM, 18)

    def forward(self, img):
        """Run the shared CNN trunk, then the six heads.

        Args:
            img: float tensor of shape (batch, 3, H, W); H=64, W=128 is
                required by the hard-coded Linear input dimension.

        Returns:
            A 6-tuple of (batch, 18) logit tensors, one per head.
        """
        feat = self.cnn(img)
        # Flatten all non-batch dims. (Removed the leftover debug print()
        # calls that spammed stdout on every forward pass.)
        feat = feat.view(feat.shape[0], -1)
        return (
            self.fc1(feat),
            self.fc2(feat),
            self.fc3(feat),
            self.fc4(feat),
            self.fc5(feat),
            self.fc6(feat),
        )
# root_dir = '.\\hagrid_data\\images'
#
#
# transform = transforms.Compose([
# transforms.Resize((64, 128)),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])
#
# train_dataset = HagridDataset(root_dir,transform=transform)
# train_loader = torch.utils.data.DataLoader(
# train_dataset,
# batch_size=10,
# shuffle=True,
# num_workers=4,
# )
#
# val_dataset = HagridDataset(root_dir,'val',transform=transform)
# val_loader = torch.utils.data.DataLoader(
# val_dataset,
# batch_size=32,
# shuffle=True,
# num_workers=4,
# )
#
# model = MyCNN()
#
# criterion = nn.CrossEntropyLoss(reduction='mean')
# optimizer = torch.optim.Adam(model.parameters(), 0.001)
# best_loss = 1000.0
#
#
# def my_train(train_loader, model, criterion, optimizer):
# print(train_loader)
# model.train()
# for data in train_loader:
# imgs, targets = data
# print(data.shape)
# print(targets)
# c0, c1, c2, c3, c4, c5 = model(data[0])
# loss = (criterion(c0, data[1][:, 0]) +
# criterion(c1, data[1][:, 1]) +
# criterion(c2, data[1][:, 2]) +
# criterion(c3, data[1][:, 3]) +
# criterion(c4, data[1][:, 4]) +
# criterion(c5, data[1][:, 5]))
# loss /= 6
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
#
#
# def my_validate(val_loader, model, criterion):
# model.eval()
# val_loss = []
# # 不记录模型梯度信息
# with torch.no_grad():
# for i, (data, target) in enumerate(val_loader):
# c0, c1, c2, c3, c4, c5 = model(data[0])
# loss = criterion(c0, data[1][:, 0]) + \
# criterion(c1, data[1][:, 1]) + \
# criterion(c2, data[1][:, 2]) + \
# criterion(c3, data[1][:, 3]) + \
# criterion(c4, data[1][:, 4]) + \
# criterion(c5, data[1][:, 5])
# loss /= 6
# val_loss.append(loss.item())
# return np.mean(val_loss)
#
#
# if __name__ == '__main__':
# for epoch in range(10):
# print('Epoch: ', epoch)
#
# my_train(train_loader, model, criterion, optimizer)
# val_loss = my_validate(val_loader, model, criterion)
#
# # 记录下验证集精度
# if val_loss < best_loss:
# best_loss = val_loss
# torch.save(model.state_dict(), './model.pt')
#
# torch.save(model.state_dict(), './model.pt')
# model.load_state_dict(torch.load('./model.pt'))
# NOTE: removed Gitee content-moderation page text ("此处可能存在不合适展示的内容...")
# that was accidentally captured into this file and is not valid Python.