1 Star 0 Fork 0

幽鸟/EfficientNet-PyTorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
sotabench.py 2.04 KB
一键复制 编辑 原始数据 按行查看 历史
Luke Melas-Kyriazi 提交于 2020-08-26 16:59 . Update sotabench.py
import os
import numpy as np
import PIL
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import ImageNet
from efficientnet_pytorch import EfficientNet
from sotabencheval.image_classification import ImageNetEvaluator
from sotabencheval.utils import is_server
# Resolve the ImageNet root directory.
# On the sotabench server IMAGENET_DIR is optional and defaults to
# './imagenet'; for local runs it must be set explicitly.
if is_server():
    DATA_ROOT = os.environ.get('IMAGENET_DIR', './imagenet')  # alt. path: './.data/vision/imagenet'
else:  # local settings
    # .get() so that a missing variable reaches the explanatory error below
    # instead of surfacing as a bare KeyError; an explicit raise (rather than
    # assert) keeps the check active under `python -O`.
    DATA_ROOT = os.environ.get('IMAGENET_DIR', '')
    if not DATA_ROOT:
        raise RuntimeError('please set IMAGENET_DIR environment variable')
    print('Local data root: ', DATA_ROOT)
# Pretrained network: the lowercase architecture name is what
# EfficientNet.from_pretrained / get_image_size expect.
model_name = 'EfficientNet-B5'
arch_name = model_name.lower()
model = EfficientNet.from_pretrained(arch_name)
image_size = EfficientNet.get_image_size(arch_name)

# Evaluation preprocessing: bicubic resize, centre crop to image_size,
# tensor conversion, then per-channel ImageNet mean/std normalisation.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
input_transform = transforms.Compose([
    transforms.Resize(image_size, PIL.Image.BICUBIC),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Validation split, read in a fixed order (shuffle=False); the evaluation
# loop relies on this order to map each batch back to its file names.
test_dataset = ImageNet(
    DATA_ROOT,
    split="val",
    transform=input_transform,
    target_transform=None,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Put the network on the GPU in evaluation mode.
model = model.cuda()
model.eval()

# sotabench evaluator; '1905.11946' is the arXiv id of the EfficientNet paper.
evaluator = ImageNetEvaluator(
    model_name=model_name,
    paper_arxiv_id='1905.11946',
)
def get_img_id(image_name):
    """Return the image id used by the sotabench evaluator for a file path.

    The id is the file's base name with a trailing '.JPEG' extension removed,
    e.g. '.../n01440764/ILSVRC2012_val_00000293.JPEG' -> 'ILSVRC2012_val_00000293'.

    os.path.basename is used instead of splitting on '/' so that paths built
    with the platform separator (e.g. backslashes on Windows) are handled,
    and only a trailing '.JPEG' is stripped rather than every occurrence of
    that substring in the name.
    """
    base_name = os.path.basename(image_name)
    suffix = '.JPEG'
    if base_name.endswith(suffix):
        base_name = base_name[:-len(suffix)]
    return base_name
# Run the model over the validation set and pass per-image outputs to the
# sotabench evaluator, keyed by image id.
dataset_imgs = test_loader.dataset.imgs  # per-sample entries; element [0] is the file path
offset = 0  # index into dataset_imgs of the first sample in the current batch
with torch.no_grad():
    # Labels are not needed: the evaluator scores predictions by image id,
    # so they are discarded instead of being copied to the GPU every batch.
    for images, _ in test_loader:
        images = images.to(device='cuda', non_blocking=True)
        output = model(images)
        n_batch = images.size(0)
        # shuffle=False, so this batch holds the next n_batch entries of dataset_imgs.
        image_ids = [get_img_id(sample[0]) for sample in dataset_imgs[offset:offset + n_batch]]
        offset += n_batch
        evaluator.add(dict(zip(image_ids, list(output.cpu().numpy()))))
        # sotabencheval's cache check: if results are already cached for this
        # run, the remaining batches can be skipped.
        if evaluator.cache_exists:
            break

if not is_server():
    print("Results:")
    print(evaluator.get_results())
evaluator.save()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/quietbirds/EfficientNet-PyTorch.git
git@gitee.com:quietbirds/EfficientNet-PyTorch.git
quietbirds
EfficientNet-PyTorch
EfficientNet-PyTorch
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385