2 Star 1 Fork 0

CJLU2021/table-seg

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
my_table_predict.py 5.25 KB
一键复制 编辑 原始数据 按行查看 历史
syshensyshen 提交于 2021-07-27 21:30 . table segment
import cv2
import torch
import torch.nn.functional as F
from pathlib import Path
import numpy as np
from glob import glob
from tqdm import tqdm
import os
import sys
from models.assembly.segmentation_table import Segmentation_Model
# from models.assembly.segmentation_table import Segmentation_Model
from models.assembly.deeplab import DeepLabV3
from models.assembly.my_pan_model import PanModel
def save_tensor(tensor, i, im, save_dir):
    """Paint a predicted mask onto a grayscale copy of ``im`` and save it.

    Args:
        tensor: Batched model output; only ``tensor[0]`` is used. It is
            transposed with ``(1, 2, 0)``, so it must be 3-D
            (presumably ``(1, H, W)`` -- TODO confirm against the model).
        i: Identifier used to name the output file. May be a loop index or
            an image path; only the base name without extension is kept.
        im: Source image as loaded by OpenCV (converted with
            ``COLOR_BGR2GRAY``, so a 3-channel BGR array is expected).
        save_dir: Directory the ``<name>-label.jpg`` file is written to.
    """
    im = cv2.resize(im, (1280, 800))
    # Channel-first -> channel-last so OpenCV can resize the mask.
    mask = tensor[0].cpu().numpy().transpose(1, 2, 0)
    mask = cv2.resize(mask, (1280, 800))
    im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    # Every pixel whose score exceeds 0.1 is painted white.
    im[mask > 0.1] = 255
    # The original referenced an undefined ``im_name`` here (NameError);
    # the file name is now derived from the ``i`` argument instead.
    stem = os.path.splitext(os.path.basename(str(i)))[0]
    cv2.imwrite(os.path.join(save_dir, stem + '-label.jpg'), im)
def save_results(im1, tensor, im_name, save_dir):
    """Save the source image and its predicted mask side by side.

    Args:
        im1: Original image as returned by ``cv2.imread``; its height and
            width define the size the mask is resized to.
        tensor: Batched model output; ``tensor[0].squeeze(0)`` is used as a
            2-D score map (presumably values in ``[0, 1]`` after a sigmoid
            -- TODO confirm for every caller).
        im_name: Path of the source image; only its base name is used.
        save_dir: Directory the rendered image is written to.

    The output file is named ``<stem>-label<ext>``, keeping the source
    extension. The previous ``.replace('.jpg', ...)`` silently did nothing
    for ``.png`` / ``.jpeg`` inputs, so those lost the ``-label`` suffix.
    """
    mask = tensor[0].squeeze(0).cpu().numpy()
    # cv2.resize takes (width, height), i.e. (cols, rows).
    mask = cv2.resize(mask, (im1.shape[1], im1.shape[0]))
    mask_bgr = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
    # Convert explicitly to 8-bit so the stacked image has a single, well
    # defined dtype instead of relying on imwrite's handling of float data.
    mask_u8 = np.clip(np.rint(mask_bgr * 255), 0, 255).astype(np.uint8)
    render_img = np.hstack((im1, mask_u8))
    stem, ext = os.path.splitext(os.path.basename(im_name))
    cv2.imwrite(os.path.join(save_dir, stem + '-label' + ext), render_img)
def main():
    """Run ``Segmentation_Model`` over a folder of JPEG images.

    Loads the ``hard_mining_v1.3_190`` checkpoint, resizes every image to
    1280x800, runs the model, upsamples the output by a factor of 4 and
    writes the rendered result into ``results/`` via ``save_results``.

    NOTE(review): unlike the other entry points no ``sigmoid`` is applied to
    the model output here -- confirm ``Segmentation_Model`` already returns
    probabilities.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = Segmentation_Model().to(device)
    model.load_state_dict(torch.load('./ckpt/hard_mining_v1.3_190.pth', map_location=torch.device('cpu')))
    model.eval()
    # The original literal started with ':\' (drive letter missing), so the
    # glob below could never match; the adjacent stray literal showed 'D:'.
    fd = r'D:\PycharmProjects\table\CV-all-in-one\mask_label\aug\images'
    save_dir = 'results/'
    # cv2.imwrite does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)
    images = [str(i) for i in Path(fd).glob('*.jp*')]
    width, height = 1280, 800
    with torch.no_grad():
        for image in tqdm(images):
            im1 = cv2.imread(image)
            if im1 is None:
                # imread returns None for unreadable files; skip, don't crash.
                print('skip unreadable image:', image)
                continue
            im = cv2.resize(im1, (width, height)) / 255
            # Reverse the channel axis (OpenCV loads BGR) and go HWC -> CHW.
            im = im[:, :, ::-1].transpose(2, 0, 1)
            im = np.ascontiguousarray(im)
            im = torch.from_numpy(im).float().to(device)
            if im.ndimension() == 3:
                im = im.unsqueeze(0)
            outputs = model(im)
            pred = F.interpolate(outputs, scale_factor=4)
            # save_results takes the source image first; the original call
            # omitted it and raised TypeError on the first image.
            save_results(im1, pred, image, save_dir)
def main_deeplabv3():
    """Run the DeepLabV3 checkpoint over a folder of ``*.pn*`` images.

    Each image is resized to 1280x800, scaled to ``[0, 1]``, passed through
    the model, and the sigmoid of the output is upsampled by 4 before being
    rendered into ``results_v2.0/`` by ``save_results``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda') if use_cuda else torch.device('cpu')
    net = DeepLabV3(1).to(device)
    # Weights are always deserialised on the CPU, then copied into the model.
    weights = torch.load('./ckpt/hard_mining_v2.0_20.pth', map_location=torch.device('cpu'))
    net.load_state_dict(weights)
    net.eval()

    fd = r'D:\Projects_data\table\test'
    save_dir = 'results_v2.0/'
    if not os.path.isdir(save_dir):
        os.mkdir(save_dir)
    images = [str(path) for path in Path(fd).glob('*.pn*')]
    width, height = 1280, 800

    with torch.no_grad():
        for _idx, image in tqdm(enumerate(images)):
            im1 = cv2.imread(image)
            scaled = cv2.resize(im1, (width, height)) / 255
            # Reverse the channel axis and move it first (HWC -> CHW).
            chw = np.ascontiguousarray(scaled[:, :, ::-1].transpose(2, 0, 1))
            batch = torch.from_numpy(chw).float()
            if use_cuda:
                batch = batch.cuda()
            if batch.ndimension() == 3:
                # Add the batch dimension expected by the model.
                batch = batch.unsqueeze(0)
            probs = net(batch).sigmoid()
            pred = F.interpolate(probs, scale_factor=4)
            save_results(im1, pred, image, save_dir)
def main_pannet():
    """Run the PAN-Net checkpoint over a folder of ``*.png`` images.

    Images whose longer side is below 1280 px are fed at native size.
    Larger images are shrunk by an integer divisor per axis
    (``side // 1280 + 1``); the divisors are independent, so the aspect
    ratio may change. The sigmoid of the model output is rendered into
    ``pan_net_resnet50_gc_v1.6_655_2/`` by ``save_results``.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = PanModel().to(device)
    model.load_state_dict(torch.load('./ckpt/pan_net_resnet50_gc_v1.6_655.pth', map_location=torch.device('cpu')))
    model.eval()
    fd = r'D:\Projects_data\table\test'
    save_dir = 'pan_net_resnet50_gc_v1.6_655_2/'
    os.makedirs(save_dir, exist_ok=True)
    images = [str(i) for i in Path(fd).glob('*.png')]
    with torch.no_grad():
        for image in tqdm(images):
            im1 = cv2.imread(image)
            if im1 is None:
                # imread returns None for unreadable files; skip, don't crash.
                print('skip unreadable image:', image)
                continue
            src_h, src_w = im1.shape[:2]
            if max(src_h, src_w) < 1280:
                im = im1 / 255
            else:
                # shape is (rows, cols) but cv2.resize wants (width, height).
                # The original derived "width" from shape[0] and "height"
                # from shape[1], which swapped the axes of large images.
                width = src_w // (src_w // 1280 + 1)
                height = src_h // (src_h // 1280 + 1)
                print(width, height)
                im = cv2.resize(im1, (width, height)) / 255
            # Reverse the channel axis (OpenCV loads BGR) and go HWC -> CHW.
            im = im[:, :, ::-1].transpose(2, 0, 1)
            im = np.ascontiguousarray(im)
            im = torch.from_numpy(im).float().to(device)
            if im.ndimension() == 3:
                im = im.unsqueeze(0)
            outputs = model(im).sigmoid()
            # No upsampling here: save_results resizes the mask to im1's size.
            save_results(im1, outputs, image, save_dir)
# Script entry point: only the PAN-Net pipeline runs by default; call
# main() or main_deeplabv3() here instead to use the other models.
if __name__ == '__main__':
    main_pannet()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/cjlu2021/table-seg.git
git@gitee.com:cjlu2021/table-seg.git
cjlu2021
table-seg
table-seg
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385