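# Hosting-page residue, not part of the program: "代码拉取完成,页面将自动刷新"
# ("Code fetch complete; the page will refresh automatically.")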
import cv2
import torch
import torch.nn.functional as F
from pathlib import Path
import numpy as np
from glob import glob
from tqdm import tqdm
import os
from models.assembly.mul_loss_model import PanModel
def save_tensor(tensor, i, im, save_dir, im_name=None):
    """Paint a thresholded prediction onto a grayscale copy of an image and save it.

    Both the image and the prediction are resized to 1280x800. Every pixel whose
    prediction value is greater than 0.1 is set to 255 (white) in the grayscale
    image, which is then written to ``save_dir``.

    Args:
        tensor: prediction tensor. ``tensor[0]`` is transposed with ``(1, 2, 0)``,
            so it is presumably (N, C, H, W) with C == 1 -- TODO confirm.
        i: index of the image; used to name the output file when ``im_name``
            is not given.
        im: BGR image (3-channel, as required by ``cv2.COLOR_BGR2GRAY``).
        save_dir: directory the result is written into.
        im_name: optional path of the source image. When given, the output is
            named ``<stem>-label<suffix>`` after it; otherwise ``<i>-label.jpg``.
    """
    im = cv2.resize(im, (1280, 800))
    pred = tensor[0].detach().cpu().numpy().transpose(1, 2, 0)
    pred = cv2.resize(pred, (1280, 800))
    gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    gray[pred > 0.1] = 255

    # The original referenced an undefined ``im_name`` here (NameError on every
    # call); fall back to the index when no source name is supplied.
    if im_name is None:
        out_name = str(i) + '-label.jpg'
    else:
        src = Path(im_name)
        out_name = src.stem + '-label' + src.suffix
    cv2.imwrite(os.path.join(save_dir, out_name), gray)
def save_results(im1, tensor, im_name, save_dir):
    """Save the original image side by side with the predicted map.

    The prediction is resized back to ``im1``'s resolution, expanded to three
    channels, scaled by 255 and placed to the right of ``im1``.

    Args:
        im1: original 3-channel image (it is hstacked with a BGR map, so it
            must have 3 channels); presumably uint8 from ``cv2.imread``.
        tensor: prediction tensor; ``tensor[0].squeeze(0)`` must yield a 2-D
            map, so presumably (N, 1, H, W) -- TODO confirm.
        im_name: path of the source image. The output is named
            ``<stem>-label<suffix>`` (e.g. ``a.png`` -> ``a-label.png``).
        save_dir: directory the result is written into.
    """
    pred = tensor[0].squeeze(0).detach().cpu().numpy()
    # cv2.resize takes (width, height).
    pred = cv2.resize(pred, (im1.shape[1], im1.shape[0]))
    pred = cv2.cvtColor(pred, cv2.COLOR_GRAY2BGR)
    # Convert explicitly to uint8 so both halves share a dtype, instead of
    # handing cv2.imwrite a float array to convert implicitly.
    pred = np.clip(np.rint(pred * 255), 0, 255).astype(np.uint8)
    render_img = np.hstack((im1, pred))

    # The old ``.replace('.jpg', ...)`` did nothing for .png inputs, so the
    # "-label" suffix was silently dropped; derive the name from stem/suffix.
    src = Path(im_name)
    out_path = os.path.join(save_dir, src.stem + '-label' + src.suffix)
    cv2.imwrite(out_path, render_img)
def main_pannet(fd=r'D:\Projects_data\table\test', save_dir='pan_absnet_480_fy/',
                ckpt_path='./ckpt/pan_absnet_480.pth', width=1280, height=800):
    """Run PanModel over every ``*.png`` in ``fd`` and save visualizations.

    Args:
        fd: directory containing the input PNG images.
        save_dir: output directory; created if it does not exist.
        ckpt_path: path of the state-dict checkpoint loaded into PanModel.
        width: width the images are resized to before inference.
        height: height the images are resized to before inference.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = PanModel().to(device)
    # Weights are loaded onto the CPU; load_state_dict copies them into the
    # parameters that already live on ``device``.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu')))
    model.eval()

    os.makedirs(save_dir, exist_ok=True)
    # Sorted so the processing order is deterministic.
    images = sorted(str(p) for p in Path(fd).glob('*.png'))

    with torch.no_grad():
        for image_path in tqdm(images):
            im1 = cv2.imread(image_path)
            if im1 is None:
                # cv2.imread returns None (rather than raising) for unreadable files.
                tqdm.write('skipping unreadable image: ' + image_path)
                continue
            im = cv2.resize(im1, (width, height)) / 255
            # BGR -> RGB, HWC -> CHW, then add a leading batch axis.
            im = np.ascontiguousarray(im[:, :, ::-1].transpose(2, 0, 1))
            im = torch.from_numpy(im).float().unsqueeze(0).to(device)
            # The model returns five outputs; only the last one (FY) is saved.
            _c2, _c3, _c4, _c5, fy = model(im)
            save_results(im1, fy, image_path, save_dir)
if __name__ == '__main__':
    # Only run inference when executed as a script, not when imported.
    main_pannet()
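# Hosting-page residue, not part of the program (the page may have withheld the rest of the file):
# 此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
# ("Content here may be unsuitable for display and is not shown. You can review and edit it with the editing tools.")
# 如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。
# ("If you confirm the content involves no inappropriate language / pure ad traffic / violence / vulgar or pornographic material / infringement / piracy / false or worthless content, and does not violate national laws and regulations, click Submit to appeal and we will handle it as soon as possible.")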