代码拉取完成,页面将自动刷新
import cv2
import scipy.io as sio
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import os
import numpy as np
import torch
from PIL import Image
unloader = transforms.ToPILImage()
class Dataloader(Dataset):
def __init__(self, img_path, ann_path, down_sample=False):
# 图像路径文件夹 和 标签文件 文件夹 采用绝对路径
self.pre_img_path = img_path # 文件夹路径
self.pre_ann_path = ann_path
# 图像的文件名是 IMG_15.jpg 则 标签是 GT_IMG_15.mat
# 因此不需要listdir标签路径
self.img_names = os.listdir(img_path)
self.down = down_sample
# 返回tensor类的的image和对应的grund_truth
def __getitem__(self, index):
# 该函数用于返回每一个sample
img_name = self.img_names[index] # index对应的图片名
mat_name = 'GT_' + img_name.replace('jpg', 'mat')
img = cv2.imread(self.pre_img_path + img_name,0)
# img = np.array(Image.open(self.pre_img_path + img_name).convert("RGB"))
anno = sio.loadmat(self.pre_ann_path + mat_name)
xy = anno['image_info'][0][0][0][0][0] # N,2的坐标数组
density_map = self.get_density(img, xy) # 密度图
img = img.astype(np.float32)
density_map = density_map.astype(np.float32)
h = img.shape[0]
w = img.shape[1]
ht1 = h // 4 * 4
wd1 = w // 4 * 4
img = cv2.resize(img, (wd1, ht1))
if self.down:
wd1 = wd1 // 4
ht1 = ht1 // 4
density_map = cv2.resize(density_map, (wd1, ht1))
density_map = density_map * ((w * h) / (wd1 * ht1)) #
else:
density_map = cv2.resize(density_map, (wd1, ht1))
density_map = density_map * ((w * h) / (wd1 * ht1))
# img = torch.from_numpy(img.reshape(1, img.shape[0], img.shape[1]))
img_tensor = torch.tensor(img, dtype=torch.float)
img_tensor=img_tensor.resize_((1,img_tensor.shape[0], img_tensor.shape[1]))
density_map= torch.tensor(density_map,dtype=torch.float)
density_map = density_map.resize_((1, density_map.shape[0]//4, density_map.shape[1]//4))
return img_tensor, density_map
def __len__(self):
return len(self.img_names)
# 核函数
def fspecial(self, ksize_x, ksize_y, sigma):
kx = cv2.getGaussianKernel(ksize_x, sigma)
ky = cv2.getGaussianKernel(ksize_y, sigma)
return np.multiply(kx, np.transpose(ky)) # 核函数
# 通过point头部坐标返回label
def get_density(self, img, points):
h, w = img.shape[0], img.shape[1]
# 密度图 初始化全0
labels = np.zeros(shape=(h, w))
for loc in points:
f_sz = 17 # 滤波器尺寸 预设为15 也是邻域的尺寸
sigma = 4.0 # sigma参数
H = self.fspecial(f_sz, f_sz, sigma) # 高斯核矩阵
x = min(max(0, abs(int(loc[0]))), int(w)) # 头部坐标
y = min(max(0, abs(int(loc[1]))), int(h))
if x > w or y > h:
continue
x1 = x - f_sz / 2;
y1 = y - f_sz / 2
x2 = x + f_sz / 2;
y2 = y + f_sz / 2
dfx1 = 0;
dfy1 = 0;
dfx2 = 0;
dfy2 = 0
change_H = False
if x1 < 0:
dfx1 = abs(x1);
x1 = 0;
change_H = True
if y1 < 0:
dfy1 = abs(y1);
y1 = 0;
change_H = True
if x2 > w:
dfx2 = x2 - w;
x2 = w - 1;
change_H = True
if y2 > h:
dfy2 = y2 - h;
y2 = h - 1;
change_H = True
x1h = 1 + dfx1;
y1h = 1 + dfy1
x2h = f_sz - dfx2;
y2h = f_sz - dfy2
if change_H:
H = self.fspecial(int(y2h - y1h + 1), int(x2h - x1h + 1), sigma)
labels[int(y1):int(y2), int(x1):int(x2)] = labels[int(y1):int(y2), int(x1):int(x2)] + H
return labels
path1 = 'D:\\love_shangdong3\\data\\shanghai\\part_B_final\\test_data\\images'
path2 = 'D:\\love_shangdong3\\data\\shanghai\\part_B_final\\test_data\\ground_truth'
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。