1 Star 0 Fork 0

xxy/efficientDet-d5

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
gwd_test.py 4.21 KB
一键复制 编辑 原始数据 按行查看 历史
xxy 提交于 2020-07-11 00:30 . init
import sys
from ensemble_boxes import *
import torch
import numpy as np
import pandas as pd
from glob import glob
from torch.utils.data import Dataset,DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import gc
from matplotlib import pyplot as plt
from effdet import get_efficientdet_config, EfficientDet, DetBenchEval
from effdet.efficientdet import HeadNet
import os
def get_valid_transforms():
    """Return the inference-time albumentations pipeline.

    Every image is resized to 512x512 and then converted to a torch tensor
    by ``ToTensorV2``. Both steps always run (``p=1.0``).
    """
    steps = [
        A.Resize(height=512, width=512, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps, p=1.0)
# Directory holding the test-set .jpg files (Windows path, backslashes escaped).
DATA_ROOT_PATH = 'D:\\Workspace\\Python\\GWD\\data\\test'
class DatasetRetriever(Dataset):
    """Dataset of test images stored as ``<DATA_ROOT_PATH>/<image_id>.jpg``.

    Each item is ``(image, image_id)``. ``image`` is the RGB image as float32
    scaled to [0, 1]. When ``transforms`` is given, ``image`` is its
    ``'image'`` output instead.
    """

    def __init__(self, image_ids, transforms=None):
        """
        Args:
            image_ids: sequence (numpy array or list) of file stems, i.e. file
                names without directory and without the ``.jpg`` suffix.
            transforms: optional callable invoked as ``transforms(image=...)``
                that returns a dict containing ``'image'``.
        """
        super().__init__()
        self.image_ids = image_ids
        self.transforms = transforms

    def __getitem__(self, index: int):
        image_id = self.image_ids[index]
        path = os.path.join(DATA_ROOT_PATH, image_id + ".jpg")
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread reports a missing/unreadable file by returning None rather
        # than raising; fail here with the offending path instead of letting
        # cvtColor raise an opaque assertion error.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV loads BGR; the rest of the pipeline works in RGB [0, 1].
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0
        if self.transforms:
            sample = self.transforms(image=image)
            image = sample['image']
        return image, image_id

    def __len__(self) -> int:
        # len() covers numpy arrays (first axis) as well as plain lists.
        return len(self.image_ids)
# Test set: every .jpg directly under DATA_ROOT_PATH. The image id is the bare
# file stem. os.path.basename/splitext handle both '/' and '\\' separators,
# whereas splitting on '/' left the whole directory in the id for the
# backslash-separated paths glob returns on Windows. Sorting makes the
# iteration order (and thus the numbered output files) deterministic.
dataset = DatasetRetriever(
    image_ids=np.array([
        os.path.splitext(os.path.basename(path))[0]
        for path in sorted(glob(os.path.join(DATA_ROOT_PATH, '*.jpg')))
    ]),
    transforms=get_valid_transforms(),
)
def collate_fn(batch):
    """Transpose a list of samples into one tuple per field.

    ``[(img0, id0), (img1, id1)]`` becomes ``((img0, img1), (id0, id1))``.
    The images stay a tuple of tensors instead of being stacked here.
    """
    columns = zip(*batch)
    return tuple(columns)
# Inference loader: one image per batch, in dataset order, with nothing
# dropped; collate_fn keeps images and ids as separate tuples.
data_loader = DataLoader(
    dataset=dataset,
    collate_fn=collate_fn,
    batch_size=1,
    num_workers=4,
    shuffle=False,
    drop_last=False,
)
def load_net(checkpoint_path):
    """Build a single-class EfficientDet-D5 and load fine-tuned weights into it.

    Args:
        checkpoint_path: path of a ``torch.save`` file containing a
            ``'model_state_dict'`` entry.

    Returns:
        The network wrapped in ``DetBenchEval``, switched to eval mode and
        moved to the GPU.
    """
    config = get_efficientdet_config('tf_efficientdet_d5')
    net = EfficientDet(config, pretrained_backbone=False)
    # Re-head the model for one class at 512x512. The config is edited after
    # EfficientDet is built, so only the replaced class_net uses num_classes=1;
    # it must match the checkpoint for the (strict) load_state_dict below.
    config.num_classes = 1
    config.image_size = 512
    net.class_net = HeadNet(config, num_outputs=config.num_classes, norm_kwargs=dict(eps=.001, momentum=.01))
    # Load onto the CPU: load_state_dict copies the tensors into the model's
    # own parameters anyway, so this avoids a transient GPU copy of the whole
    # checkpoint and a failure if it was saved from a device not present now.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    del checkpoint
    gc.collect()
    net = DetBenchEval(net, config)
    net.eval()
    return net.cuda()
# Module-level detector used by make_predictions(). With the "spawn" start
# method (the only one on Windows), every DataLoader worker (num_workers=4)
# re-executes this file under the name '__mp_main__'. Skip the load there, so
# the workers do not each put a copy of the D5 checkpoint on the GPU. Running
# the script directly or importing it by name behaves exactly as before.
if __name__ != '__mp_main__':
    net = load_net(r'D:\Workspace\efficientdet-pytorch-master\effdet5-cutmix-augmix1\last-checkpoint.bin')
def make_predictions(images, score_threshold=0.22, model=None, device='cuda'):
    """Run the detector on a batch and keep detections above a score threshold.

    Args:
        images: sequence of equally-shaped image tensors; they are stacked
            into one batch.
        score_threshold: detections whose score is not strictly greater than
            this value are dropped.
        model: detector to call as ``model(batch, scales)``. Defaults to the
            module-level ``net``.
        device: device the batch (and the scale tensor) is moved to.

    Returns:
        A one-element list wrapping a list with one dict per image. Each dict
        has ``'boxes'`` (N x 4 numpy array, x1/y1/x2/y2) and ``'scores'``
        (N numpy array). The outer list is the "list of prediction sources"
        layout that ``run_wbf`` indexes into.
    """
    if model is None:
        model = net
    batch = torch.stack(images).to(device).float()
    predictions = []
    with torch.no_grad():
        # Second argument: one factor of 1 per image (presumably the image
        # scales DetBenchEval expects -- TODO confirm against effdet).
        det = model(batch, torch.tensor([1] * batch.shape[0]).float().to(device))
        for i in range(batch.shape[0]):
            detections = det[i].detach().cpu().numpy()
            scores = detections[:, 4]
            keep = np.where(scores > score_threshold)[0]
            # Filter exactly once. The old code indexed the already-filtered
            # boxes with `keep` a second time, which was only correct when the
            # kept rows happened to be a sorted prefix. Fancy indexing returns
            # a copy, so the in-place conversion below never touches `det`.
            boxes = detections[keep, :4]
            # Columns 0-3 are turned from (x, y, w, h) into (x1, y1, x2, y2).
            boxes[:, 2] += boxes[:, 0]
            boxes[:, 3] += boxes[:, 1]
            predictions.append({
                'boxes': boxes,
                'scores': scores[keep],
            })
    return [predictions]
def run_wbf(predictions, image_index, image_size=512, iou_thr=0.44, skip_box_thr=0.43, weights=None):
    """Fuse one image's boxes from several prediction sources with WBF.

    Args:
        predictions: list of prediction sources. Each is a per-image list of
            dicts with ``'boxes'`` (pixel x1/y1/x2/y2 numpy array) and
            ``'scores'``.
        image_index: which image of each source to fuse.
        image_size: side length of the (square) image the boxes refer to.
        iou_thr, skip_box_thr: forwarded to ``weighted_boxes_fusion``.
        weights: optional per-source weights forwarded to
            ``weighted_boxes_fusion``. ``None`` weights all sources equally.
            This was previously ignored.

    Returns:
        ``(boxes, scores, labels)`` from ``weighted_boxes_fusion``, with the
        boxes scaled back to pixel coordinates.
    """
    # ensemble_boxes works on normalised coordinates, so divide by
    # (image_size - 1) here and multiply back after fusion.
    scale = image_size - 1
    boxes = [(prediction[image_index]['boxes'] / scale).tolist() for prediction in predictions]
    scores = [prediction[image_index]['scores'].tolist() for prediction in predictions]
    # Single-class problem: every box gets label 1.
    labels = [np.ones(prediction[image_index]['scores'].shape[0]).tolist() for prediction in predictions]
    boxes, scores, labels = weighted_boxes_fusion(
        boxes, scores, labels, weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
    boxes = boxes * scale
    return boxes, scores, labels
if __name__ == '__main__':
    # Visual sanity check: draw the fused boxes on every test image and save
    # the result as '<loader index>.jpg' in the working directory.
    for j, (images, image_ids) in enumerate(data_loader):
        predictions = make_predictions(images)
        i = 0  # batch_size is 1, so each batch holds a single image
        boxes, _, _ = run_wbf(predictions, image_index=i)
        # 511 is the last valid pixel of the 512x512 resized image.
        boxes = boxes.astype(np.int32).clip(min=0, max=511)
        # The sample is CHW RGB in [0, 1]. cv2 needs an 8-bit BGR image, and
        # cvtColor's freshly allocated output is contiguous, which
        # cv2.rectangle requires for in-place drawing.
        sample = images[i].permute(1, 2, 0).cpu().numpy()
        sample = np.clip(np.rint(sample * 255), 0, 255).astype(np.uint8)
        sample = cv2.cvtColor(sample, cv2.COLOR_RGB2BGR)
        for x1, y1, x2, y2 in boxes:
            # (0, 0, 255) is red in BGR, the colour the original (1, 0, 0) RGB
            # rectangle intended.
            cv2.rectangle(sample, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 1)
        # No matplotlib figure is created: the old per-image plt.subplots()
        # figure was never used or closed and accumulated across iterations.
        cv2.imwrite(str(j) + '.jpg', sample)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/xxyim/efficientDet-d5.git
git@gitee.com:xxyim/efficientDet-d5.git
xxyim
efficientDet-d5
efficientDet-d5
master

搜索帮助