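# NOTE: residue from the code-hosting web page, not part of the program:
# "Code pull complete; the page will refresh automatically."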
import sys
from ensemble_boxes import *
import torch
import numpy as np
import pandas as pd
from glob import glob
from torch.utils.data import Dataset,DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import gc
from matplotlib import pyplot as plt
from effdet import get_efficientdet_config, EfficientDet, DetBenchEval
from effdet.efficientdet import HeadNet
import os
def get_valid_transforms():
    """Build the augmentation pipeline used at inference time.

    Returns:
        An albumentations ``Compose`` that resizes the image to 512x512 and
        then applies ``ToTensorV2``; every step (and the pipeline itself) is
        applied with probability 1.0.
    """
    resize_step = A.Resize(height=512, width=512, p=1.0)
    to_tensor_step = ToTensorV2(p=1.0)
    pipeline = [resize_step, to_tensor_step]
    return A.Compose(pipeline, p=1.0)
# Directory holding the test images. DatasetRetriever reads "<image_id>.jpg"
# from it, and the module-level dataset globs it for "*.jpg".
# The value is a machine-specific absolute path.
DATA_ROOT_PATH = r'D:\Workspace\Python\GWD\data\test'
class DatasetRetriever(Dataset):
    """Dataset that loads images by id from ``DATA_ROOT_PATH``.

    Each item is an ``(image, image_id)`` pair. The image is read from
    ``<DATA_ROOT_PATH>/<image_id>.jpg``, converted from BGR to RGB, cast to
    float32 and scaled by 1/255, then passed through ``transforms`` if given.

    Args:
        image_ids: Indexable, sized collection of image ids (a numpy array
            or a plain list both work).
        transforms: Optional callable invoked as ``transforms(image=image)``
            that returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, image_ids, transforms=None):
        super().__init__()
        self.image_ids = image_ids
        self.transforms = transforms

    def __getitem__(self, index: int):
        """Return ``(image, image_id)`` for the id stored at ``index``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be
                decoded by OpenCV.
        """
        image_id = self.image_ids[index]
        image_path = os.path.join(DATA_ROOT_PATH, image_id + ".jpg")
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0
        if self.transforms:
            image = self.transforms(image=image)['image']
        return image, image_id

    def __len__(self) -> int:
        """Return the number of image ids."""
        # len() works for numpy arrays and for plain sequences alike.
        return len(self.image_ids)
# Image ids are the file names without directory or extension.
# os.path.basename/splitext handle both "/" and "\\" separators, whereas
# splitting on "/" leaves the full path in the id when glob returns
# backslash-separated paths.
dataset = DatasetRetriever(
    image_ids=np.array([
        os.path.splitext(os.path.basename(path))[0]
        for path in glob(os.path.join(DATA_ROOT_PATH, '*.jpg'))
    ]),
    transforms=get_valid_transforms(),
)
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    Args:
        batch: Sequence of equally structured samples, e.g.
            ``[(image0, id0), (image1, id1)]``.

    Returns:
        A tuple with one tuple per field, e.g.
        ``((image0, image1), (id0, id1))``. Nothing is stacked or converted.
    """
    columns = zip(*batch)
    return tuple(columns)
# Loader settings for inference: one image per batch, fixed order, no batch
# dropped, four worker processes, and samples grouped by collate_fn.
_loader_options = dict(
    batch_size=1,
    shuffle=False,
    num_workers=4,
    drop_last=False,
    collate_fn=collate_fn,
)
data_loader = DataLoader(dataset, **_loader_options)
def load_net(checkpoint_path):
    """Build an EfficientDet-D5 evaluation model and load trained weights.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            contains a ``'model_state_dict'`` entry.

    Returns:
        The network wrapped in ``DetBenchEval``, switched to eval mode and
        moved to the GPU.
    """
    config = get_efficientdet_config('tf_efficientdet_d5')
    net = EfficientDet(config, pretrained_backbone=False)
    # The classification head is rebuilt for a single class before the
    # checkpoint weights are loaded.
    config.num_classes = 1
    config.image_size = 512
    net.class_net = HeadNet(config, num_outputs=config.num_classes, norm_kwargs=dict(eps=.001, momentum=.01))
    # Deserialize onto the CPU: the model is still on the CPU at this point,
    # and this avoids failures when the checkpoint was saved from a CUDA
    # device that does not exist here. The model moves to the GPU below.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    # Release the checkpoint before the model is wrapped and moved.
    del checkpoint
    gc.collect()
    net = DetBenchEval(net, config)
    net.eval()
    return net.cuda()
# Load the trained detector once at import time; make_predictions reads this
# module-level ``net``. The checkpoint path is machine-specific.
net = load_net(r'D:\Workspace\efficientdet-pytorch-master\effdet5-cutmix-augmix1\last-checkpoint.bin')
def make_predictions(images, score_threshold=0.22):
    """Run the module-level ``net`` on a batch and keep confident detections.

    Args:
        images: Sequence of image tensors of identical shape; they are
            stacked into one batch, moved to the GPU and cast to float.
        score_threshold: Detections are kept only when their score is
            strictly greater than this value.

    Returns:
        A one-element list wrapping a list with one dict per image. Each
        dict has ``'boxes'`` (N x 4 array) and ``'scores'`` (length-N array)
        for the kept detections. The outer list is the collection of
        prediction sets that ``run_wbf`` iterates over.
    """
    batch = torch.stack(images).cuda().float()
    batch_size = batch.shape[0]
    predictions = []
    with torch.no_grad():
        # The second argument is a tensor holding 1.0 for every image
        # (presumably per-image scale factors -- confirm against effdet's
        # DetBenchEval).
        det = net(batch, torch.tensor([1] * batch_size).float().cuda())
    for i in range(batch_size):
        # Convert once; columns 0-3 are box values, column 4 is the score.
        detections = det[i].detach().cpu().numpy()
        scores = detections[:, 4]
        keep = np.where(scores > score_threshold)[0]
        # Fancy indexing copies, so the in-place edits below do not touch
        # ``detections``.
        boxes = detections[:, :4][keep]
        # Columns 2 and 3 are offsets from columns 0 and 1 (presumably
        # width/height); turn them into absolute corner coordinates.
        boxes[:, 2] = boxes[:, 2] + boxes[:, 0]
        boxes[:, 3] = boxes[:, 3] + boxes[:, 1]
        predictions.append({
            # ``boxes`` is already filtered; indexing it with ``keep`` a
            # second time would pick wrong rows or raise IndexError whenever
            # the kept indices are not exactly 0..k-1.
            'boxes': boxes,
            'scores': scores[keep],
        })
    return [predictions]
def run_wbf(predictions, image_index, image_size=512, iou_thr=0.44, skip_box_thr=0.43, weights=None):
    """Fuse one image's boxes across prediction sets with weighted boxes fusion.

    Args:
        predictions: List of prediction sets; each set is indexable by
            ``image_index`` and yields a dict with ``'boxes'`` (N x 4 array)
            and ``'scores'`` (length-N array).
        image_index: Index of the image inside every prediction set.
        image_size: Boxes are divided by ``image_size - 1`` before fusion and
            multiplied by it afterwards.
        iou_thr: Forwarded to ``weighted_boxes_fusion``.
        skip_box_thr: Forwarded to ``weighted_boxes_fusion``.
        weights: Optional per-prediction-set weights forwarded to
            ``weighted_boxes_fusion``; ``None`` uses its default.

    Returns:
        ``(boxes, scores, labels)`` as returned by ``weighted_boxes_fusion``,
        with ``boxes`` scaled back by ``image_size - 1``.
    """
    scale = image_size - 1
    boxes, scores, labels = [], [], []
    for prediction in predictions:
        entry = prediction[image_index]
        boxes.append((entry['boxes'] / scale).tolist())
        scores.append(entry['scores'].tolist())
        # Single-class problem: every detection gets label 1.0.
        labels.append(np.ones(entry['scores'].shape[0]).tolist())
    # Pass the caller's ``weights`` through; the previous hard-coded
    # ``weights=None`` silently ignored the parameter.
    fused_boxes, fused_scores, fused_labels = weighted_boxes_fusion(
        boxes, scores, labels,
        weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr,
    )
    return fused_boxes * scale, fused_scores, fused_labels
if __name__ == '__main__':
    # Run detection on every test image and save an annotated copy named
    # "<batch index>.jpg" in the current working directory.
    for j, (images, image_ids) in enumerate(data_loader):
        predictions = make_predictions(images)
        # The loader uses batch_size=1, so only the first image is drawn.
        i = 0
        # CHW tensor -> HWC array. OpenCV drawing needs a C-contiguous
        # buffer, which permute() does not guarantee.
        sample = np.ascontiguousarray(images[i].permute(1, 2, 0).cpu().numpy())
        boxes, scores, labels = run_wbf(predictions, image_index=i)
        boxes = boxes.astype(np.int32).clip(min=0, max=511)
        for box in boxes:
            # Plain Python ints are accepted by every OpenCV build.
            x1, y1, x2, y2 = (int(v) for v in box)
            cv2.rectangle(sample, (x1, y1), (x2, y2), (1, 0, 0), 1)
        # The dataset converts images to RGB, while cv2.imwrite expects BGR;
        # convert back so the saved file has correct colours.
        cv2.imwrite(str(j) + '.jpg', cv2.cvtColor(sample * 255, cv2.COLOR_RGB2BGR))
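# NOTE: residue from the code-hosting web page, not part of the program:
# "This section may contain content unsuitable for display, so the page does
# not show it. You can review and modify it with the editing tools."
# "If you confirm the content involves no inappropriate language, pure
# advertising, violence, vulgar or pornographic material, infringement,
# piracy, false or worthless content, or anything that violates national laws
# and regulations, you can click Submit to appeal and we will handle it as
# soon as possible."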