代码拉取完成,页面将自动刷新
from torch.utils.data import DataLoader
from utils.dataset import TestDataset
import albumentations as A
from albumentations import pytorch as AT
import numpy as np
import pandas as pd
from utils.model import *
import ttach as tta
from tqdm import tqdm
import os
import json
import time
# ---- Inference configuration ----
input_size = 224  # side (px) of the square crop fed to the models
nc = 137  # number of output classes of each model head
batch_size = 64
nw = 8  # DataLoader num_workers
# NOTE: this CSV lacks 'a2411.jpg'; the output step appends a placeholder row for it.
test_csv = '../Dataset/test_clean.csv'
test_path = '../Dataset/test/'
model1_name = 'resnest50d'
model2_name = 'tf_efficientnetv2_s_in21ft1k'
save_dir = 'pred-csv'
save_csv = os.path.join(save_dir, 'ress50_effv2s') + os.sep
saveFileName = f'{save_csv}resnest50-effv2s.csv'
# makedirs creates save_dir and its sub-directory in one call; exist_ok=True makes
# it a no-op when they already exist (no check-then-create race).
os.makedirs(save_csv, exist_ok=True)
# Test-time preprocessing: resize both sides to input_size * 256/224, take a
# centered input_size x input_size crop, normalize with albumentations' default
# statistics, then convert to a torch tensor.
_resize_side = int(input_size * (256 / 224))
albu_transform = dict(
    test=A.Compose(
        [
            A.Resize(_resize_side, _resize_side),
            A.CenterCrop(input_size, input_size),
            A.Normalize(),
            AT.ToTensorV2(p=1.0),
        ]
    ),
)
# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Querying device properties fails without CUDA, so only do it on the GPU path.
    cuda_info = torch.cuda.get_device_properties(0)
    print("using {} {} {}MB.".format(device, cuda_info.name, cuda_info.total_memory / 1024 ** 2))
else:
    print("using {}.".format(device))
test_dataset = TestDataset(test_csv, test_path, transform=albu_transform['test'])
# No shuffling (DataLoader default): the output step pairs predictions with the
# rows of test_csv by position, so order must be preserved
# (presumably TestDataset yields rows in CSV order — confirm).
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=nw)
print('Test total {} images'.format(len(test_dataset)))
# Load the mapping from model output index (as a string key) to the label that
# is written to the submission, e.g. {'0': 0, '1': 1, ...}.
json_path = './class_indices.json'
if not os.path.exists(json_path):
    # Explicit raise instead of assert so the check survives `python -O`.
    raise FileNotFoundError("file: '{}' does not exist.".format(json_path))
# `with` guarantees the file handle is closed after parsing.
with open(json_path, "r", encoding="utf-8") as json_file:
    num_to_class = json.load(json_file)
# ---- Build both ensemble members and load their trained weights ----
model1 = timm_model(model1_name, pretrained=False, num_classes=nc).to(device)
model2 = timm_model(model2_name, pretrained=False, num_classes=nc).to(device)
model_path_list = [
    'resnest50d-seed233/epoch28-val_acc-0.9006.pth',
    'tf_efficientnetv2_s_in21ft1k/epoch28-val_acc-0.912.pth',
]
print('--------------------------------------')
print('predict-starting')
print('--------------------------------------')
# map_location remaps checkpoint tensors onto the selected device, so weights
# saved from a GPU also load when the script fell back to the CPU.
model1.load_state_dict(torch.load(model_path_list[0], map_location=device))
model2.load_state_dict(torch.load(model_path_list[1], map_location=device))
print(f'load weight {model_path_list}')
print(f'save name: {saveFileName}')
time.sleep(0.1)  # NOTE(review): presumably lets the prints above flush before tqdm draws
# Eval mode: Dropout/BatchNorm behave differently in training mode.
model1.eval()
model2.eval()
# Wrap each model with ttach's d4 test-time-augmentation set; the wrapper merges
# the predictions over all augmented views.
tta_model1 = tta.ClassificationTTAWrapper(model1, tta.aliases.d4_transform())
tta_model2 = tta.ClassificationTTAWrapper(model2, tta.aliases.d4_transform())
# Predicted class index for every test image, in loader order.
predictions = []
# Inference only: no autograd graph is needed for any batch.
with torch.no_grad():
    for imgs in tqdm(test_loader):
        # Move the batch once and reuse it for both models.
        imgs = imgs.to(device)
        # Weighted ensemble of the two TTA-wrapped models (weights 0.4 / 0.6).
        logits = 0.4 * tta_model1(imgs) + 0.6 * tta_model2(imgs)
        # The class with the greatest ensembled logit is the prediction.
        predictions.extend(logits.argmax(dim=-1).cpu().tolist())
# Translate model indices to submission labels (the JSON mapping has string keys).
preds = [num_to_class[str(i)] for i in predictions]
test_data = pd.read_csv(test_csv)
if len(preds) != len(test_data):
    raise ValueError(
        'got {} predictions for {} rows in {}'.format(len(preds), len(test_data), test_csv))
# Positional pairing: predictions follow the row order of test_csv.
test_data['category_id'] = preds
pred_csv = test_data[['image_id', 'category_id']]
# test_csv lacks 'a2411.jpg' (presumably a GIF the loader cannot read), so a
# placeholder row is appended. DataFrame.append was removed in pandas 2.0,
# hence pd.concat.
missing_row = pd.DataFrame([{'image_id': 'a2411.jpg', 'category_id': 1}])
pred_csv = pd.concat([pred_csv, missing_row], ignore_index=True)
pred_csv.to_csv(saveFileName, index=False)
print(f"{saveFileName}predict done!")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。