代码拉取完成,页面将自动刷新
from torchbench.image_classification import ImageNet
from pytorch.pytorchcv.models.model_store import _model_sha1
from pytorch.pytorchcv.model_provider import get_model as ptcv_get_model
import torchvision.transforms as transforms
import torch
import math
from sys import version_info
# import os
for model_name, model_metainfo in (_model_sha1.items() if version_info[0] >= 3 else _model_sha1.iteritems()):
net = ptcv_get_model(model_name, pretrained=True)
error, checksum, repo_release_tag, caption, paper, ds, img_size, scale, batch, rem = model_metainfo
if (ds != "in1k") or (img_size == 0) or ((len(rem) > 0) and (rem[-1] == "*")):
continue
paper_model_name = caption
paper_arxiv_id = paper
input_image_size = img_size
resize_inv_factor = scale
batch_size = batch
model_description = "pytorch" + (rem if rem == "" else ", " + rem)
assert (not hasattr(net, "in_size")) or (input_image_size == net.in_size[0])
ImageNet.benchmark(
model=net,
model_description=model_description,
paper_model_name=paper_model_name,
paper_arxiv_id=paper_arxiv_id,
input_transform=transforms.Compose([
transforms.Resize(int(math.ceil(float(input_image_size) / resize_inv_factor))),
transforms.CenterCrop(input_image_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
]),
batch_size=batch_size,
num_gpu=1,
# data_root=os.path.join("..", "imgclsmob_data", "imagenet")
)
torch.cuda.empty_cache()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。