"""
获取LSTM轨迹数据集
"""
import os
import cv2
import math
import config
import tensorflow as tf
import math_utility as mu
import gesture_recognition_utility as gu
from CNM import point_matching
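
# Pipeline: for every training video, detect fingertip candidates per frame
# with the CNN, match them to the previous frame's fingertips
# (CNM.point_matching), encode each matched slot's motion direction as a
# code in 1-12, and append the resulting code sequences to OUT_FILE.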
RESCALE_RATE = config.RESCALE_RATE
KERNEL_SIZE = config.KERNEL_SIZE
FAST_RADIUS = config.FAST_RADIUS
SET_MAX_LENGTH = config.SET_MAX_LENGTH
MIN_DISTANCE = config.MIN_DISTANCE
# Model globals
REVISE_WIDTH = config.REVISE_WIDTH
REVISE_HEIGHT = config.REVISE_HEIGHT
MODEL_WIDTH = config.MODEL_WIDTH
MODEL_HEIGHT = config.MODEL_HEIGHT
SLOT_NUMBERS = config.SLOT_NUMBERS
# Path globals
MODEL_DIR = './CNN/Model/model.meta'
OUT_FILE = './LSTM/Data/data.txt'
VIDEO_FILE_DIR = './LSTM/Trainsets_LSTM/'  # root directory of the training videos


# Encode a normalized trajectory angle (a fraction of 2*pi) into a discrete
# direction code 1-12; -1 means "no motion in this slot" and passes through.
def angle_to_number(angle):
    if angle == -1:
        return angle
    # Clamp to 12: the rounding in the main loop can push a normalized angle
    # to exactly 1.0, which would otherwise map to 13.
    number = min(int(angle * 12) + 1, 12)
    return number
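
# Worked examples of the encoding (angles arrive here already divided by
# 2*pi in the main loop below):
#   angle_to_number(0.00) == 1
#   angle_to_number(0.50) == 7     # half a turn away
#   angle_to_number(1.00) == 12    # clamped upper edge
#   angle_to_number(-1)   == -1    # empty-slot marker passes through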


if __name__ == '__main__':
    file_list = os.listdir(VIDEO_FILE_DIR)  # one sub-directory per gesture class
    # Iterate over every video of every class directory.
    for file in range(len(file_list)):
        video_list = os.listdir(os.path.join(VIDEO_FILE_DIR, str(file_list[file])))
        for video in range(len(video_list)):
            video_path = os.path.join(VIDEO_FILE_DIR, str(file_list[file]), str(video_list[video]))
            video_capture = cv2.VideoCapture(video_path)
            print(video_path)
            tf.reset_default_graph()  # reset the default graph so state does not leak across videos
            with tf.Session() as sess:
                # Restore the trained CNN fingertip classifier.
                saver = tf.train.import_meta_graph(MODEL_DIR)
                saver.restore(sess, tf.train.latest_checkpoint('./CNN/Model/'))
                # Look up the input tensor in the restored graph.
                graph = tf.get_default_graph()
                x = graph.get_tensor_by_name('input/x:0')
                # Per-frame processing state.
                last_points = []
                frame_count = 0
                ang_list = []
                dis_list = []  # collected alongside the angles, though only the angles are written out
                while True:
                    ret, frame = video_capture.read()
                    # Stop at the end of the video stream.
                    if frame is None:
                        break
                    # Force portrait orientation (transpose + flip = 90-degree rotation).
                    if frame.shape[1] > frame.shape[0]:
                        frame = cv2.transpose(frame)
                        frame = cv2.flip(frame, 1)
                    frame = cv2.resize(frame, (REVISE_WIDTH, REVISE_HEIGHT))
                    copy = frame  # keep the full-resolution frame; cv2.resize below returns a new array
                    # Detect on a frame downscaled by RESCALE_RATE first, then map the
                    # results back into the original scale space.
                    frame = cv2.resize(frame, (int(REVISE_WIDTH / RESCALE_RATE), int(REVISE_HEIGHT / RESCALE_RATE)))
                    heat_map, point_set = gu.get_heatmap(frame)  # heat map and cluster centres
                    input_list, out_points = gu.cut_image(copy, point_set)  # crop candidate patches
                    # Classify the candidate patches with the CNN.
                    feed_dict = {x: input_list}
                    logits = graph.get_tensor_by_name('logits_eval:0')
                    classification_result = sess.run(logits, feed_dict)
                    # Index of the maximum of each row of the prediction matrix.
                    output = tf.argmax(classification_result, 1).eval()
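                    # Note: tf.argmax(...).eval() adds a new op to the graph on every
                    # frame; tf.reset_default_graph() above keeps this from accumulating
                    # across videos. A plain numpy argmax over classification_result
                    # would avoid the graph growth entirely.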
                    # Keep points classified as fingertips (0 = negative, 1 = positive).
                    fingertips = []
                    for i in range(len(out_points)):
                        if int(output[i]) == 1:
                            fingertips.append(out_points[i])
                    '''
                    for i in range(len(fingertips)):
                        cv2.circle(copy, tuple((fingertips[i][0] * RESCALE_RATE, fingertips[i][1] * RESCALE_RATE)), 3, (0, 0, 255), cv2.FILLED)
                    '''
                    # Match fingertips against the previous frame and encode the motion
                    # (drawing of the match segments is commented out below).
                    if frame_count > 0:
                        match = point_matching(fingertips, last_points, slot=SLOT_NUMBERS)
                        ang_set = [-1] * SLOT_NUMBERS
                        dis_set = [-1] * SLOT_NUMBERS
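                        # Each slot holds the motion of the fingertip matched into it:
                        # ang_set[i] is the direction as a fraction of 2*pi, dis_set[i]
                        # the pixel distance; -1 marks an empty slot. With
                        # SLOT_NUMBERS == 5 a frame pair might yield, for example,
                        # ang_set == [0.12, -1, -1, 0.87, -1] (hypothetical values).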
                        for i in range(len(match)):
                            point_sta = fingertips[match[i][0]]
                            point_end = last_points[match[i][1]]
                            distance = mu.get_distance(point_sta, point_end)
                            angle = mu.get_angle(point_sta, point_end)
                            if distance > MIN_DISTANCE and angle != -1:
                                ang_set[match[i][0]] = round(angle / (2 * math.pi), 2)
                                dis_set[match[i][0]] = round(distance, 2)
                        # Record the frame pair only if at least two slots saw motion.
                        if ang_set.count(-1) < SLOT_NUMBERS - 1:
                            ang_list.append(ang_set)
                            dis_list.append(dis_set)
                            print(ang_set)
                        '''
                        p1 = tuple((point_sta[0] * RESCALE_RATE, point_sta[1] * RESCALE_RATE))
                        p2 = tuple((point_end[0] * RESCALE_RATE, point_end[1] * RESCALE_RATE))
                        cv2.line(copy, p1, p2, (0, 255, 255))
                        '''
                    last_points = fingertips
                    frame_count += 1
                    '''
                    cv2.imshow('show_img', max_skin)
                    if cv2.waitKey(1) & 0xff == ord('n'):
                        continue
                    '''
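                # One line of data.txt per video: SLOT_NUMBERS semicolon-terminated
                # code sequences (one per slot, a code 1-12 or -1 per frame pair)
                # followed by the class label, i.e. the index of the gesture
                # directory. A hypothetical line:
                #  3 3 4 5; -1 -1 -1 -1; -1 -1 -1 -1; -1 -1 -1 -1; 7 7 7 8;0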
                # Persist the encoded trajectories (append mode, one line per video).
                with open(OUT_FILE, 'a+') as out_file:
                    for i in range(SLOT_NUMBERS):  # one code sequence per slot
                        for j in range(len(ang_list)):
                            out_file.write(' ' + str(angle_to_number(ang_list[j][i])))
                        out_file.write(';')
                    out_file.write(str(file))  # class label = index of the gesture directory
                    out_file.write('\n')
            video_capture.release()
    cv2.destroyAllWindows()