当前仓库属于暂停状态,部分功能使用受限,详情请查阅 仓库状态说明
2 Star 3 Fork 1

bodyless/MTM
暂停

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
dexter_test.py 3.65 KB
一键复制 编辑 原始数据 按行查看 历史
bodyless 提交于 2019-10-10 17:43 . Completely revised
"""
在Dexter数据集上的指尖点检测测试实验,我懒得写代码直接手动统计了
"""
import os
import cv2
import time
import config
import tensorflow as tf
import gesture_recognition_utility as gu
os.environ["CUDA_VISIBLE_DEVICES"] = '-1' # 禁用GPU
RESCALE_RATE = config.RESCALE_RATE
REVISE_WIDTH = config.REVISE_WIDTH
REVISE_HEIGHT = config.REVISE_HEIGHT
MODEL_WIDTH = config.MODEL_WIDTH
MODEL_HEIGHT = config.MODEL_HEIGHT
SLOT_NUMBERS = config.SLOT_NUMBERS
IMAGE_FILE = 'E:/Lijunjie_Papers/Paper-MTM/dexter/data/fingercount/multicam/cam0/' # 获取视频流
CNN_MODEL_DIR = './CNN/Model/model.meta'
if __name__ == '__main__':
corrected_sum = 0
fingertip_sum = 0
file_list = os.listdir(IMAGE_FILE)
# 打开session
with tf.Session() as sess:
# 导入模型
saver_cnn = tf.train.import_meta_graph(CNN_MODEL_DIR)
saver_cnn.restore(sess, tf.train.latest_checkpoint('./CNN/Model/')) # 导入CNN模型
# 导入计算图
graph = tf.get_default_graph()
x = graph.get_tensor_by_name('input/x:0')
# 逐帧计算
last_points = []
frame_count = 0
for index in file_list:
frame = cv2.imread(IMAGE_FILE + index)
# <--------------------计时开始-------------------->
time_sta = time.time()
frame = cv2.resize(frame, (REVISE_WIDTH, REVISE_HEIGHT)) # 将图像压缩至480*640
copy = frame
# 先缩小二分之一进行检测,再映射到原来的尺度空间中
frame = cv2.resize(frame, (int(REVISE_WIDTH / RESCALE_RATE), int(REVISE_HEIGHT / RESCALE_RATE)))
heat_map, point_set = gu.get_heatmap(frame) # 获取热力图和聚类中心
heat_map = cv2.resize(heat_map, (REVISE_WIDTH, REVISE_HEIGHT))
cv2.imshow('heat map', heat_map)
input_list, out_points = gu.cut_image(copy, point_set, True) # 对图像进行切片
# 导入数据集进行测试
feed_dict = {x: input_list}
logits = graph.get_tensor_by_name('logits_eval:0')
time_network_sta = time.time()
classification_result = sess.run(logits, feed_dict)
time_network_end = time.time()
# 输出预测矩阵每一行最大值的索引
output = tf.argmax(classification_result, 1).eval()
# 判断是否为指尖点,0是negative,1是positive
fingertips = []
for i in range(len(out_points)):
if int(output[i]) == 1:
fingertips.append(out_points[i])
for i in range(len(fingertips)):
cv2.circle(copy, tuple((fingertips[i][0] * RESCALE_RATE, fingertips[i][1] * RESCALE_RATE)), 3, (0, 0, 255), cv2.FILLED)
last_points = fingertips
time_end = time.time()
# <--------------------计时结束-------------------->
cv2.putText(copy, 'FPS: ' + str(round(1 / (time_end - time_sta), 2)),
(int(REVISE_WIDTH / 20), int(REVISE_HEIGHT / 20)),
cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255), 2)
cv2.imshow('corner', copy)
if cv2.waitKey(1):
corrected = int(input())
corrected_sum += corrected
fingertip_sum += len(fingertips)
if corrected_sum < fingertip_sum:
print(corrected_sum / fingertip_sum)
else:
print(fingertip_sum / corrected_sum)
continue
frame_count += 1
cv2.destroyAllWindows()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/bodyless/MTM.git
git@gitee.com:bodyless/MTM.git
bodyless
MTM
MTM
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385