1 Star 0 Fork 0

孙志强/img_classification

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
main.py 6.84 KB
一键复制 编辑 原始数据 按行查看 历史
孙志强 提交于 2021-11-25 22:14 . first commit
# -*- coding: utf-8 -*-
"""
-------------------------------------------------
# @Project :Streamlit_demo
# @File :main
# @Date :2021/9/8 15:41
# @Author :Sun
# @Software :PyCharm
-------------------------------------------------
"""
from torchvision import models, transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
import numpy as np
import torch.nn as nn
from tools.utils import validate, show_confMat
from tools.datapreparation import Laster_Welding
import time
import os
import torch
from model.all_nets import AlexNet, VGG16, SqueezeNet, ResNet, DenseNet, SE_ResNet
# Compute device: CUDA when torch reports it available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Directory where training writes 'latest.pt'; also handed to show_confMat.
weightslog = 'weights/VGG16/'
# Root directory under which new_file() creates sub-directories.
dataset = 'dataset/'
# Classification criterion shared by the training loop and evaluation.
loss_function = nn.CrossEntropyLoss()
def new_file(name, root=None):
    """Ensure the directory ``<root>/<name>`` exists (despite the name, no file is created).

    Args:
        name: sub-directory path relative to *root*; nested parts are created too.
        root: base directory; defaults to the module-level ``dataset`` path.

    Returns:
        None. Calling it again for an existing directory is a no-op.
    """
    base = dataset if root is None else root
    # exist_ok avoids the check-then-create race of an os.path.exists() guard.
    os.makedirs(os.path.join(base, name), exist_ok=True)
def predict(image):
resnet = models.resnet101(pretrained=True)
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)])
batch_t = torch.unsqueeze(transform(image), 0)
resnet.eval()
out = resnet(batch_t)
prob = torch.nn.functional.softmax(out, dim=1)[0] * 100
_, indices = torch.sort(out, descending=True)
idx = indices[0][1]
score = prob[idx].item()
classs = classes[int(idx)]
return classs,score
def test(image, model):
    """Predict the class index of one image using the checkpoint ``<weightslog>/latest.pt``.

    Args:
        image: input accepted by torchvision's Resize/CenterCrop/ToTensor
            (presumably a PIL image — confirm with callers).
        model: network whose architecture matches the saved checkpoint; it is
            moved to ``device``, loaded with the checkpoint and put in eval mode.

    Returns:
        Tensor of arg-max class indices, one per batch row (here a batch of one).
        Assumes the model returns (batch, num_classes) logits — TODO confirm.
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    model = model.to(device)
    # map_location lets a checkpoint written during a GPU run load on a CPU-only host.
    state = torch.load(os.path.join(weightslog, 'latest.pt'), map_location=device)
    model.load_state_dict(state)
    model.eval()
    inputs = torch.unsqueeze(transform(image), 0).to(device)
    with torch.no_grad():
        outputs = model(inputs)
    _, preds = torch.max(outputs, 1)
    return preds
def testdataset(model, batch_size):
    """Evaluate the checkpoint ``<weightslog>/latest.pt`` on the 'test' split of ``./new_dataset``.

    Prints predicted vs. true label for every sample and the overall accuracy,
    then computes a confusion matrix with ``validate`` and passes it to
    ``show_confMat`` together with ``weightslog``.

    Args:
        model: network whose architecture matches the saved checkpoint.
        batch_size: DataLoader batch size.
    """
    model_1 = model.to(device)
    # map_location lets a checkpoint written during a GPU run load on a CPU-only host.
    state = torch.load(os.path.join(weightslog, 'latest.pt'), map_location=device)
    model_1.load_state_dict(state)
    model_1.eval()
    Tests = Laster_Welding('./new_dataset', 224, 'test')
    testLoader = DataLoader(Tests, batch_size=batch_size, shuffle=True, num_workers=0)
    classes_name = Tests.cla_name()
    correct = 0.0
    total = 0.0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(testLoader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model_1(inputs)
            _, preds = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
            print("batch %d" % i)
            for j in range(inputs.size()[0]):
                print("{} pred label:{}, true label:{}".format(len(preds), classes_name[preds[j]],
                                                               classes_name[labels[j]]))
    # An empty test split would otherwise raise ZeroDivisionError here.
    acc = correct / total if total else 0.0
    print("Acc:{:.2%}".format(acc))
    conf_mat_valid, valid_acc = validate(model, testLoader, 'test', classes_name, device)
    show_confMat(conf_mat_valid, classes_name, 'test', weightslog)
def _run_train_epoch(model, loader, optimizer, epoch, epoch_total):
    """Run one optimisation pass over *loader*, printing running loss/accuracy every 10 iterations."""
    # Re-enter training mode every epoch: the periodic evaluation switches to eval().
    model.train()
    loss_sigma = 0.0
    correct = 0.0
    total = 0.0
    for step, (image, label) in enumerate(loader):
        image = image.to(device)
        label = label.to(device)
        pred = model(image)
        loss = loss_function(pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        _, predicted = torch.max(pred.data, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()
        loss_sigma += loss.item()
        if step % 10 == 9:
            loss_avg = loss_sigma / 10
            loss_sigma = 0.0
            print(f"Training: Epoch[{epoch+1}/{epoch_total}] iteration[{step+1}/{len(loader)}] Loss: {loss_avg} Acc: {correct/ total}")


def _confusion_matrix(model, loader, cls_num):
    """Return a (true class x predicted class) count matrix of *model* over *loader*."""
    conf_mat = np.zeros([cls_num, cls_num])
    model.eval()
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            for cate_i, pre_i in zip(labels.cpu().tolist(), predicted.cpu().tolist()):
                conf_mat[cate_i, pre_i] += 1.0
    return conf_mat


def train(model, learning_rate, epoch_total, batch_size, shuffle):
    """Train *model* on the 'train' split of ``./dataset`` with SGD and a StepLR schedule.

    After every epoch the weights are saved to ``<weightslog>/latest.pt``; after
    every even-indexed epoch (0, 2, 4, ...) accuracy on the 'val' split is
    printed. When training ends, ``validate`` produces confusion matrices for
    the train and val splits, which are passed to ``show_confMat``. The 'test'
    split is not used here (see ``testdataset``).

    Args:
        model: network to train; expected to already be on ``device``
            (``image_classification`` moves it there).
        learning_rate: initial SGD learning rate, multiplied by 0.1 every 30 epochs.
        epoch_total: number of epochs to run.
        batch_size: mini-batch size for both loaders.
        shuffle: whether the DataLoaders shuffle their samples.
    """
    # torch.save does not create missing directories.
    os.makedirs(weightslog, exist_ok=True)
    Trains = Laster_Welding('./dataset', 224, 'train')  # training split
    val = Laster_Welding('./dataset', 224, 'val')  # validation split
    trainLoader = DataLoader(Trains, batch_size=batch_size, shuffle=shuffle, num_workers=0)
    valLoader = DataLoader(val, batch_size=batch_size, shuffle=shuffle, num_workers=0)
    classes_name = Trains.cla_name()
    cls_num = len(classes_name)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    for epoch in range(epoch_total):
        _run_train_epoch(model, trainLoader, optimizer, epoch, epoch_total)
        # Step the schedule after the epoch's optimizer updates; stepping it first
        # skips the initial learning rate and decays one epoch early.
        scheduler.step()
        torch.save(model.state_dict(), os.path.join(weightslog, 'latest.pt'))
        if epoch % 2 == 0:
            conf_mat = _confusion_matrix(model, valLoader, cls_num)
            seen = conf_mat.sum()
            acc = conf_mat.trace() / seen if seen else 0.0
            print(f"Vaild set Accuracy: {acc}")
    conf_mat_train, train_acc = validate(model, trainLoader, 'train', classes_name, device)
    conf_mat_valid, valid_acc = validate(model, valLoader, 'valid', classes_name, device)
    show_confMat(conf_mat_train, classes_name, 'train', weightslog)
    show_confMat(conf_mat_valid, classes_name, 'valid', weightslog)
def image_classification():
    """Build a VGG16 on ``device`` and train it with the script's default hyper-parameters."""
    net = VGG16().to(device)
    settings = {
        "learning_rate": 0.01,
        "epoch_total": 100,
        "batch_size": 2,
        "shuffle": True,
    }
    train(net, **settings)


if __name__ == "__main__":
    # Running this module directly starts a full training run.
    image_classification()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/sunzhiq99/img_classification.git
git@gitee.com:sunzhiq99/img_classification.git
sunzhiq99
img_classification
img_classification
master

搜索帮助

23e8dbc6 1850385 7e0993f3 1850385