bodyless/MTM · data_validate.py (5.62 KB, committed by bodyless on 2019-10-10: "Completely revised")
"""
数据集合法性验证
"""
import os
import cv2
import math
import config
import tensorflow as tf
import math_utility as mu
import gesture_recognition_utility as gu
from CNM import point_matching
RESCALE_RATE = config.RESCALE_RATE
KERNEL_SIZE = config.KERNEL_SIZE
FAST_RADIUS = config.FAST_RADIUS
SET_MAX_LENGTH = config.SET_MAX_LENGTH
MIN_DISTANCE = config.MIN_DISTANCE
# model globals
REVISE_WIDTH = config.REVISE_WIDTH
REVISE_HEIGHT = config.REVISE_HEIGHT
MODEL_WIDTH = config.MODEL_WIDTH
MODEL_HEIGHT = config.MODEL_HEIGHT
SLOT_NUMBERS = config.SLOT_NUMBERS
# path globals
MODEL_DIR = './CNN/Model/model.meta'
VIDEO_FILE_DIR = 'E:/DataSet/'  # root directory of the recorded videos
if __name__ == '__main__':
    file_list = os.listdir(VIDEO_FILE_DIR)  # per-gesture folders
    # walk every video file
    for folder in file_list:
        video_list = os.listdir(os.path.join(VIDEO_FILE_DIR, folder))
        for video in video_list:
            video_path = os.path.join(VIDEO_FILE_DIR, folder, video)
            video_capture = cv2.VideoCapture(video_path)
            print(video_path)
            tf.reset_default_graph()  # reset the TensorFlow graph so it does not grow across videos
            with tf.Session() as sess:
                # restore the trained model
                saver = tf.train.import_meta_graph(MODEL_DIR)
                saver.restore(sess, tf.train.latest_checkpoint('./CNN/Model/'))
                # look up the input and output tensors of the restored graph
                graph = tf.get_default_graph()
                x = graph.get_tensor_by_name('input/x:0')
                logits = graph.get_tensor_by_name('logits_eval:0')
                # frame-by-frame processing
                last_points = []
                time_sum = 0  # accumulator, unused in this script
                ang_list = []
                dis_list = []
                frame_count = 0
                while True:
                    ret, frame = video_capture.read()
                    # stop when the stream has no more frames
                    if frame is None:
                        break
                    # force portrait orientation
                    if frame.shape[1] > frame.shape[0]:
                        frame = cv2.transpose(frame)
                        frame = cv2.flip(frame, 1)
                    frame = cv2.resize(frame, (REVISE_WIDTH, REVISE_HEIGHT))
                    # the resize below allocates a new array, so `copy` keeps
                    # the full-resolution frame for drawing
                    copy = frame
                    # detect at half scale first, then map back to the original scale space
                    frame = cv2.resize(frame, (int(REVISE_WIDTH / RESCALE_RATE), int(REVISE_HEIGHT / RESCALE_RATE)))
                    heat_map, point_set = gu.get_heatmap(frame)  # heat map and cluster centres
                    cv2.imshow('heat map', heat_map)
                    input_list, out_points = gu.cut_image(copy, point_set)  # cut candidate patches out of the image
                    # run the candidate patches through the network
                    feed_dict = {x: input_list}
                    classification_result = sess.run(logits, feed_dict)
                    # index of the per-row maximum of the prediction matrix
                    # (np.argmax instead of tf.argmax().eval(), which would add
                    # a new op to the graph on every frame)
                    output = np.argmax(classification_result, axis=1)
                    # keep the fingertip candidates: 0 is negative, 1 is positive
                    fingertips = []
                    for i in range(len(out_points)):
                        if int(output[i]) == 1:
                            fingertips.append(out_points[i])
                    for point in fingertips:
                        centre = (int(point[0] * RESCALE_RATE), int(point[1] * RESCALE_RATE))
                        cv2.circle(copy, centre, 3, (0, 0, 255), cv2.FILLED)
                    # draw the matched motion segments
                    if frame_count > 0:
                        match = point_matching(fingertips, last_points, slot=SLOT_NUMBERS)
                        ang_set = [-1] * SLOT_NUMBERS
                        dis_set = [-1] * SLOT_NUMBERS
                        for cur_idx, last_idx in match:
                            point_sta = fingertips[cur_idx]
                            point_end = last_points[last_idx]
                            distance = mu.get_distance(point_sta, point_end)
                            angle = mu.get_angle(point_sta, point_end)
                            if distance > MIN_DISTANCE and angle != -1:
                                ang_set[cur_idx] = round(angle / (2 * math.pi), 2)
                                dis_set[cur_idx] = round(distance, 2)
                                # record the slot vectors once at least two slots are filled
                                if ang_set.count(-1) < SLOT_NUMBERS - 1:
                                    print(ang_set)
                                    ang_list.append(ang_set)
                                    dis_list.append(dis_set)
                                p1 = (int(point_sta[0] * RESCALE_RATE), int(point_sta[1] * RESCALE_RATE))
                                p2 = (int(point_end[0] * RESCALE_RATE), int(point_end[1] * RESCALE_RATE))
                                cv2.line(copy, p1, p2, (255, 0, 0), 4)
                    last_points = fingertips
                    frame_count += 1
                    cv2.imshow('show_img', copy)
                    # keep playback running; the 'n' branch is effectively a no-op
                    if cv2.waitKey(1) & 0xff == ord('n'):
                        continue
            # ask the operator to keep ('y') or delete ('n') the clip
            judge = None
            while True:
                print('Please judge this video: ')
                judge = input()
                if judge in ('y', 'n'):
                    break
            video_capture.release()
            if judge == 'n':
                os.remove(video_path)
                print('Deleted this video!')
    cv2.destroyAllWindows()
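
The script leans on four sibling modules that this page does not show. To make the listing easier to follow, here are minimal sketches of each, reconstructed only from how data_validate.py calls them; every name and value in them is an assumption, not the repository's actual code. First, a hypothetical config.py, with placeholder values chosen only to be self-consistent with the half-scale detection pass and the portrait working resolution:

# Hypothetical config.py: every value below is an illustrative guess,
# not the repository's real configuration.
RESCALE_RATE = 2        # detect at 1/2 scale, then map points back up
KERNEL_SIZE = 3         # e.g. a blur/morphology kernel for the heat map
FAST_RADIUS = 10        # radius for a FAST-style corner test
SET_MAX_LENGTH = 30     # cap on accumulated angle/distance records
MIN_DISTANCE = 5        # ignore matches that barely moved (pixels)
# model globals
REVISE_WIDTH = 540      # portrait working resolution
REVISE_HEIGHT = 960
MODEL_WIDTH = 32        # patch size expected by the CNN
MODEL_HEIGHT = 32
SLOT_NUMBERS = 5        # one tracking slot per finger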
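gesture_recognition_utility supplies the candidate detector. Only the return contract is inferable from the calls above: get_heatmap returns a displayable heat map plus cluster centres at the downscaled frame's coordinates, and cut_image returns a batch of patches for the network plus the points it kept. The skin-colour masking below is an assumed strategy, purely illustrative:

# Hypothetical sketch of the two gesture_recognition_utility calls; the
# real module is not visible on this page. Only the return contract
# implied by data_validate.py is reproduced.
import cv2
import numpy as np

RESCALE_RATE, MODEL_W, MODEL_H = 2, 32, 32  # placeholder values

def get_heatmap(frame):
    """Return (heat_map, candidate_points) at the downscaled frame's scale."""
    ycrcb = cv2.cvtColor(frame, cv2.COLOR_BGR2YCrCb)
    mask = cv2.inRange(ycrcb, (0, 133, 77), (255, 173, 127))  # rough skin range
    heat_map = cv2.GaussianBlur(mask, (5, 5), 0)
    n_labels, _, _, centroids = cv2.connectedComponentsWithStats(mask)
    points = [(int(cx), int(cy)) for cx, cy in centroids[1:]]  # skip background
    return heat_map, points

def cut_image(full_frame, point_set):
    """Crop a MODEL_W x MODEL_H patch around each point, mapped to full scale."""
    patches, kept = [], []
    for (px, py) in point_set:
        cx, cy = px * RESCALE_RATE, py * RESCALE_RATE
        x0, y0 = cx - MODEL_W // 2, cy - MODEL_H // 2
        patch = full_frame[max(y0, 0):y0 + MODEL_H, max(x0, 0):x0 + MODEL_W]
        if patch.shape[:2] == (MODEL_H, MODEL_W):
            patches.append(patch / 255.0)  # assumed normalisation
            kept.append((px, py))
    return np.asarray(patches), kept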
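CNM.point_matching pairs the current frame's fingertips with the previous frame's, returning (current_index, previous_index) pairs with at most one pair per slot; that much follows from how the result is indexed above. The greedy nearest-neighbour strategy below is an assumption:

# Hypothetical sketch of a slot-based point matcher; CNM's real
# implementation is not visible here. Greedy nearest-neighbour matching
# is an assumed strategy, chosen only for illustration.
import math

def point_matching(current, previous, slot=5):
    """Pair current fingertips with last frame's, one pair per slot.

    Returns a list of (current_index, previous_index) tuples, which is
    the shape data_validate.py expects from CNM.point_matching.
    """
    pairs = []
    used_prev = set()
    # consider at most `slot` current points, matching each to its
    # nearest not-yet-used point from the previous frame
    for ci, cpt in enumerate(current[:slot]):
        best, best_d = None, float('inf')
        for pi, ppt in enumerate(previous):
            if pi in used_prev:
                continue
            d = math.hypot(cpt[0] - ppt[0], cpt[1] - ppt[1])
            if d < best_d:
                best, best_d = pi, d
        if best is not None:
            used_prev.add(best)
            pairs.append((ci, best))
    return pairs

Keeping at most one pair per slot is what lets the main script write results into the fixed-length ang_set and dis_set vectors without index collisions.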
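Finally, math_utility only has to supply a Euclidean distance and an angle that the script normalises by 2π and compares against a -1 sentinel; a compatible sketch:

# Hypothetical sketch of math_utility; only the behaviour implied by
# data_validate.py is reproduced (distance in pixels, angle in
# [0, 2*pi), -1 as the "undefined angle" sentinel).
import math

def get_distance(p1, p2):
    """Euclidean distance between two (x, y) points."""
    return math.hypot(p2[0] - p1[0], p2[1] - p1[1])

def get_angle(p1, p2):
    """Direction from p1 to p2 in radians, normalised to [0, 2*pi).

    Returns -1 when the points coincide, matching the sentinel the
    script checks with `angle != -1`.
    """
    dx, dy = p2[0] - p1[0], p2[1] - p1[1]
    if dx == 0 and dy == 0:
        return -1
    return math.atan2(dy, dx) % (2 * math.pi)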