代码拉取完成,页面将自动刷新
"""
This script will use the 2D box from the label rather than from YOLO,
but will still use the neural nets to get the 3D position and plot onto the
image. In the default image-by-image mode, press any key for the next image
and Escape to quit.
"""
from torch_lib.Dataset import *
from library.Math import *
from library.Plotting import *
from torch_lib import Model, ClassAverages
import os
import cv2
import time
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.models import vgg
import numpy as np
# Display toggle read by main(): when True the stacked preview is shown (and
# waits for a key) after every object; when False it is shown once per image.
single_car = False
def plot_regressed_3d_bbox(img, truth_img, cam_to_img, box_2d, dimensions, alpha, theta_ray):
    """Draw the 2D box on ``truth_img`` and the regressed 3D box on ``img``.

    The object location comes from ``calc_location``. That helper also returns
    a second value (per the original author's note, the corners used for the
    constraint), which is not needed here.

    Args:
        img: image that receives the 3D box via ``plot_3d_box``.
        truth_img: image that receives the 2D box via ``plot_2d_box``.
        cam_to_img: calibration data forwarded to ``calc_location`` and
            ``plot_3d_box``.
        box_2d: 2D box forwarded to ``calc_location`` and ``plot_2d_box``.
        dimensions: object dimensions forwarded to the helpers.
        alpha: angle that is summed with ``theta_ray`` before drawing.
        theta_ray: angle that is summed with ``alpha`` before drawing.

    Returns:
        The location value produced by ``calc_location``.
    """
    solved = calc_location(dimensions, cam_to_img, box_2d, alpha, theta_ray)
    location, _constraint_corners = solved

    # The 3D box is drawn with the sum of both angles.
    heading = alpha + theta_ray

    plot_2d_box(truth_img, box_2d)
    plot_3d_box(img, cam_to_img, heading, dimensions, location)
    return location
_ESC_KEY = 27  # key code that main() treats as "quit"
_WINDOW_TITLE = '2D detection on top, 3D prediction on bottom'


def _show_stacked(truth_img, img):
    """Show ``truth_img`` stacked above ``img`` and block until a key is pressed.

    Returns:
        The value of ``cv2.waitKey(0)`` reduced to its low 8 bits.
    """
    stacked = np.concatenate((truth_img, img), axis=0)
    cv2.imshow(_WINDOW_TITLE, stacked)
    # Keep only the low byte so the comparison with _ESC_KEY does not depend
    # on any higher bits in the returned code.
    return cv2.waitKey(0) & 0xFF


def main():
    """Run the trained network over the evaluation set and display the results.

    Loads the last ``.pkl`` checkpoint (by sorted file name) from the
    ``weights`` directory next to this script and moves the model to CUDA.
    For every image returned by ``Dataset.all_objects()`` it predicts
    orientation and dimensions for each labelled object, draws the boxes
    through ``plot_regressed_3d_bbox`` and prints the estimated and labelled
    locations. The preview is shown per object when ``single_car`` is true,
    otherwise per image; pressing Escape in the preview returns from the
    function.

    Raises:
        SystemExit: if the ``weights`` directory contains no ``.pkl`` file.
    """
    base_dir = os.path.abspath(os.path.dirname(__file__))
    weights_path = os.path.join(base_dir, 'weights')
    model_lst = [x for x in sorted(os.listdir(weights_path)) if x.endswith('.pkl')]
    if not model_lst:
        print('No previous model found, please train first!')
        # Same effect as the interactive-only ``exit()`` helper, without
        # relying on the ``site`` module having injected it.
        raise SystemExit

    print ('Using previous model %s'%model_lst[-1])
    my_vgg = vgg.vgg19_bn(pretrained=True)
    model = Model.Model(features=my_vgg.features, bins=2).cuda()
    checkpoint = torch.load(os.path.join(weights_path, model_lst[-1]))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # defaults to /eval
    dataset = Dataset(base_dir + '/eval')
    averages = ClassAverages.ClassAverages()

    # Single reusable input buffer: each object overwrites its entire content
    # below, so there is no need to allocate a new CUDA tensor per object.
    input_tensor = torch.zeros([1, 3, 224, 224]).cuda()

    all_images = dataset.all_objects()
    for key in sorted(all_images.keys()):
        start_time = time.time()

        data = all_images[key]
        truth_img = data['Image']
        # Draw 3D boxes on a copy so truth_img only carries the 2D boxes.
        img = np.copy(truth_img)
        objects = data['Objects']
        cam_to_img = data['Calib']

        for detected_object in objects:
            label = detected_object.label
            theta_ray = detected_object.theta_ray

            input_tensor[0, :, :, :] = detected_object.img

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                orient, conf, dim = model(input_tensor)
            orient = orient.detach().cpu().numpy()[0, :, :]
            conf = conf.detach().cpu().numpy()[0, :]
            dim = dim.detach().cpu().numpy()[0, :]

            # The network output is added to the per-class average looked up
            # by the label's class.
            dim += averages.get_item(label['Class'])

            # Take the (cos, sin) pair of the most confident bin and turn it
            # into an angle, then shift by that bin's entry in angle_bins.
            argmax = np.argmax(conf)
            cos, sin = orient[argmax, 0], orient[argmax, 1]
            alpha = np.arctan2(sin, cos)
            alpha += dataset.angle_bins[argmax]
            alpha -= np.pi

            location = plot_regressed_3d_bbox(img, truth_img, cam_to_img, label['Box_2D'], dim, alpha, theta_ray)

            print('Estimated pose: %s'%location)
            print('Truth pose: %s'%label['Location'])
            print('-------------')

            # plot car by car
            if single_car and _show_stacked(truth_img, img) == _ESC_KEY:
                return

        print('Got %s poses in %.3f seconds\n'%(len(objects), time.time() - start_time))

        # plot image by image
        if not single_car and _show_stacked(truth_img, img) == _ESC_KEY:
            return
# Run the viewer only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。