1 Star 0 Fork 0

seekerrc/actiondet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
timing_Yolo_TSM.py 7.66 KB
一键复制 编辑 原始数据 按行查看 历史
seekerrc 提交于 2022-04-12 16:47 . reorganize structure
import numpy as np
import cv2
import os
import time
from PIL import Image
import io
import copy
import torch
from models.experimental import attempt_load
from utils.torch_utils import select_device
from utils.general import (
check_img_size, non_max_suppression, apply_classifier, scale_coords,
xyxy2xywh, xywh2xyxy, strip_optimizer)
from torchvision import transforms
# import random
import fall_TSM_Module
import random
from torch.nn import functional as F
# --- Detection configuration --------------------------------------------------
weights = 'weights/yolov5x.pt'  # YOLOv5 checkpoint used for person detection
imgsize = 640  # requested inference size in pixels (stride-checked into imgsz below)
confthres = 0.4  # confidence threshold passed to non_max_suppression
iouthres = 0.5  # IoU threshold passed to non_max_suppression
frame_number = 8  # number of buffered frames handed to the TSM classifier per call
# Fallback class names (COCO order). Replaced below by the names stored in the
# checkpoint when it provides them.
names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
         'hair drier', 'toothbrush']

# --- Initialize ---------------------------------------------------------------
# Requesting GPU '0' on a machine without CUDA fails inside select_device, so
# fall back to the CPU to keep the script runnable without a GPU.
device = select_device('0' if torch.cuda.is_available() else 'cpu')
half = device.type != 'cpu'  # run the model in FP16 only on a non-CPU device
model = attempt_load(weights, map_location=device)
# Prefer the class names stored in the checkpoint (unwrapping DataParallel-style
# `.module` wrappers); keep the hard-coded list only as a fallback.
names = getattr(getattr(model, 'module', model), 'names', names)
if half:
    model.half()  # to FP16
# Inference size validated against the model's largest stride.
imgsz = check_img_size(imgsize, s=model.stride.max())

# --- Sliding-window state shared with fallDetection() ------------------------
imagedatalist = []  # buffered raw frames, oldest first
resultlist = []  # person boxes found in each buffered frame
detected = []  # 1 if the matching buffered frame contained a person, else 0
detectFall = 0  # number of TSM calls that predicted 'fall'
detectNum = 0  # total number of TSM calls
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False,
              scaleup=True, stride=64):
    """Resize ``img`` to fit ``new_shape`` keeping its aspect ratio, then pad it.

    Adapted from https://github.com/ultralytics/yolov3/issues/232.

    Args:
        img: image array whose first two dimensions are (height, width).
        new_shape: target ``(height, width)``, or a single int for a square target.
        color: border fill value passed to ``cv2.copyMakeBorder``.
        auto: if True, pad only up to the next multiple of ``stride`` instead of
            all the way to ``new_shape`` (minimum rectangle).
        scaleFill: if True and ``auto`` is False, stretch to exactly
            ``new_shape`` with no padding, ignoring the aspect ratio.
        scaleup: if False, never enlarge the image (only shrink it).
        stride: padding multiple used when ``auto`` is True. Defaults to 64,
            the value this function previously hard-coded.

    Returns:
        ``(img, ratio, (dw, dh))``. ``ratio`` is the ``(width, height)`` scale
        factor. ``dw``/``dh`` are the horizontal/vertical padding per side,
        before the rounding used for the actual border.
    """
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old); a single ratio for both axes preserves aspect.
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (width, height)
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # total wh padding
    if auto:  # minimum rectangle: pad only to the next multiple of `stride`
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    # The -/+0.1 offsets split an odd total padding as floor/ceil between the
    # two sides instead of relying on round-half-to-even.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """Draw one bounding box, and optionally a text label, onto ``img`` in place.

    Args:
        x: box corners ``(x1, y1, x2, y2)``; each value is cast to ``int``.
        img: image array that OpenCV draws on directly.
        color: colour triple for the box and label tag; random when falsy.
        label: optional text shown in a filled tag above the top-left corner.
        line_thickness: stroke width; derived from the image size when falsy.
    """
    thickness = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=thickness, lineType=cv2.LINE_AA)
    if not label:
        return

    font_thickness = max(thickness - 1, 1)
    font_scale = thickness / 3
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thickness)[0]
    # Filled background tag sitting just above the box's top-left corner.
    tag_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
    cv2.rectangle(img, top_left, tag_corner, color, -1, cv2.LINE_AA)
    cv2.putText(img, label, (top_left[0], top_left[1] - 2), 0, font_scale,
                [225, 255, 255], thickness=font_thickness, lineType=cv2.LINE_AA)
def plot_action(img, color=None, line_thickness=None, action=None, origin=(100, 100)):
    """Draw an action caption on a filled background tag onto ``img`` in place.

    Args:
        img: image array that OpenCV draws on directly.
        color: colour triple of the background tag; random when falsy.
        line_thickness: base thickness; derived from the image size when falsy.
        action: caption text. When ``None`` or empty, nothing is drawn. The
            previous version crashed in ``cv2.getTextSize`` for the default ``None``.
        origin: ``(x, y)`` of the caption's bottom-left corner. Defaults to the
            previously hard-coded ``(100, 100)``; ``(1, img.shape[0] - 1)`` places
            it at the image's bottom-left instead.
    """
    if not action:
        return
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    tf = max(tl - 1, 1)  # font thickness
    t_size = cv2.getTextSize(action, 0, fontScale=tl / 3, thickness=tf)[0]
    c3 = (int(origin[0]), int(origin[1]))  # bottom-left corner of the caption
    c4 = (c3[0] + t_size[0], c3[1] - t_size[1] - 3)  # opposite corner of the filled tag
    cv2.rectangle(img, c3, c4, color, -1, cv2.LINE_AA)  # filled
    cv2.putText(img, action, (c3[0], c3[1] - 2), 0, tl / 3,
                [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
def fallDetection(img0):
    """Detect people in one frame and run the TSM action model on the frame window.

    Updates the module-level sliding window (``imagedatalist``, ``resultlist``,
    ``detected``), which holds at most ``frame_number`` entries. When the window
    is full and at least 4 of its frames contain a person, it does three things:
    calls ``fall_TSM_Module.alertAction`` on the buffered frames, increments
    ``detectNum`` (and ``detectFall`` when the prediction is ``'fall'``), and
    drops the 4 oldest entries so the next clip overlaps this one.

    NOTE(review): the source this was restored from had lost its indentation;
    this version assumes ``detectNum`` counts every TSM call and that the 4-entry
    trim follows every call. Confirm against the original author's intent.

    Args:
        img0: a frame from ``cv2.VideoCapture.read``; treated as a BGR image
            (the pipeline converts BGR to RGB). TODO confirm for other sources.
    """
    global detectFall, detectNum
    img_org = copy.deepcopy(img0)  # untouched copy buffered for the TSM clip
    boxes = _detect_person_boxes(img0)

    # Slide the window: drop the oldest entry once it holds frame_number frames.
    if len(detected) >= frame_number:
        imagedatalist.pop(0)
        resultlist.pop(0)
        detected.pop(0)
    imagedatalist.append(img_org)
    resultlist.append(boxes)
    detected.append(1 if boxes else 0)

    # TODO: refine the trigger condition (currently >= 4 person frames in a full window).
    if sum(detected) >= 4 and len(imagedatalist) == frame_number:
        tsm_pred, _tsm_conf = fall_TSM_Module.alertAction(imagedatalist)
        if tsm_pred == 'fall':
            detectFall += 1
        detectNum += 1
        # Keep only the newest frames so consecutive clips overlap.
        del imagedatalist[0:4]
        del resultlist[0:4]
        del detected[0:4]


def _detect_person_boxes(img0, min_height=20):
    """Return ``[x1, y1, x2, y2]`` person boxes, in ``img0`` pixel coordinates,
    for YOLO detections of class 0 that are at least ``min_height`` pixels tall."""
    img = letterbox(img0, new_shape=imgsz)[0]  # stride-checked inference size
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, HWC -> CHW
    img = np.ascontiguousarray(img)
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)  # add batch dimension

    # Inference only: no_grad avoids building an autograd graph for every frame.
    with torch.no_grad():
        pred = model(img, augment=False)[0]
    pred = non_max_suppression(pred, confthres, iouthres, classes=None, agnostic=False)

    boxes = []
    for det in pred:
        if det is None or not len(det):
            continue
        # Map boxes from the letterboxed input back to the original frame.
        det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img0.shape).round()
        for *xyxy, conf, cls in reversed(det):
            x1, y1, x2, y2 = (float(v) for v in xyxy)
            if cls == 0 and y2 - y1 >= min_height:
                boxes.append([x1, y1, x2, y2])
    return boxes
def main(videoPath='data/background/person.mp4', sample_every=6):
    """Time the YOLO + TSM pipeline over a video file and print a summary.

    Every ``sample_every``-th frame is passed to ``fallDetection``. Pressing
    ``q`` (checked via ``cv2.waitKey``) stops early. Progress is printed every
    1000 frames, but only when the video reports a frame count.

    Args:
        videoPath: video to process. Other clips used during development are
            ``data/background/none.mp4`` and ``data/background/normal.mp4``.
        sample_every: run detection on one frame out of this many (default 6).
    """
    cap = cv2.VideoCapture(videoPath)
    if not cap.isOpened():
        print("Cannot open video:", videoPath)
        cap.release()
        return
    fNUMS = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    fps = cap.get(cv2.CAP_PROP_FPS)
    print("FPS is ", fps)
    frameCount = 0
    print("Starting...")
    start = time.perf_counter()  # monotonic clock for elapsed-time measurement
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frameCount += 1
            if frameCount % sample_every == 0:
                fallDetection(frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            # The frame count can be 0 when the backend cannot determine it;
            # skip the percentage instead of dividing by zero.
            if fNUMS > 0 and frameCount % 1000 == 0:
                print("Current :", round(frameCount / float(fNUMS) * 100, 1), "%")
    finally:
        cap.release()  # release the capture even if detection raises
    end = time.perf_counter()
    print("all done")
    print("processing time : {0:.1f}, total : {1}, fall : {2}".format(
        float(end - start), detectNum, detectFall))


if __name__ == "__main__":
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/seekerrc/actiondet.git
git@gitee.com:seekerrc/actiondet.git
seekerrc
actiondet
actiondet
master

搜索帮助