1 Star 0 Fork 0

seekerrc/actiondet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
ucf_thow_TSM.py 12.17 KB
一键复制 编辑 原始数据 按行查看 历史
seekerrc 提交于 2022-04-12 16:47 . reorganize structure
import numpy as np
import cv2
import os
import time
from PIL import Image
import io
import copy
import torch
from models.experimental import attempt_load
from utils.torch_utils import select_device
from utils.general import (
check_img_size, non_max_suppression, apply_classifier, scale_coords,
xyxy2xywh, xywh2xyxy, strip_optimizer)
from torchvision import transforms
# import random
import ucf_TSM_Module
import random
from torch.nn import functional as F
# --- YOLO detector configuration ---------------------------------------------
weights = 'weights/yolov5x.pt'  # detector checkpoint loaded by attempt_load below
imgsize = 640  # letterbox target size used for detection
confthres = 0.4  # confidence threshold passed to non_max_suppression
iouthres = 0.5  # IoU threshold passed to non_max_suppression
frame_number = 8  # number of sampled frames buffered per clip before action recognition
# COCO class names. NOTE(review): this list is overwritten a few lines below by
# the names stored in the loaded model, so it only documents the expected classes.
names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
         'hair drier', 'toothbrush']
# Initialize
device = select_device('0')  # presumably CUDA device 0 -- confirm select_device semantics
half = device.type != 'cpu'  # run both models in FP16 on any non-CPU device
model = attempt_load(weights, map_location=device)
# Replace the hard-coded list above with the class names stored in the checkpoint
# (unwrapping a `.module` wrapper when present).
names = model.module.names if hasattr(model, 'module') else model.names
if half:
    model.half()  # to FP16
# Image size adjusted by check_img_size against the model's largest stride.
# NOTE(review): detection letterboxes with `imgsize`, not this value.
imgsz = check_img_size(imgsize, s=model.stride.max())
# --- Sliding clip buffers: index-aligned, one entry per processed frame ---------
# image data list: untouched copies of the buffered frames
imagedatalist = []
# result list: per-frame lists of kept person boxes [x1, y1, x2, y2]
resultlist = []
resultPoseList = []  # per-frame lists of [pose_name, probability], aligned with resultlist
# detected index: per-frame flag, 1 when at least one person box was kept, else 0
detected = []
posedetected = []  # per-frame abnormal-pose flag: 2 when any crop was classified "throw", else 0
proposalFlag = False  # True while a full clip is being sent to the TSM action model
# --- Pose classifier ------------------------------------------------------------
from efficientnet_pytorch import EfficientNet
# classnames = ["abnormal"->0, "normal"->1]  (label map of the commented-out 2-class model below)
poseNames = ["bend", "fall", "jump", "lie", "run", "sit", "squat", "stand", "throw", "walk"]  # class index -> pose label
# TODO 2 classes versus 10 classes
# posemodel = EfficientNet.from_pretrained('efficientnet-b5',
#                                          weights_path='weights/pose2/pose.best.pth.tar',
#                                          num_classes=2, load_fc=True)
posemodel = EfficientNet.from_pretrained('efficientnet-b5',
                                         weights_path='weights/pose1215/pose.best.pth.tar',
                                         num_classes=10, load_fc=True)
posemodel.to(device)
posemodel.eval()  # inference mode
if half:
    posemodel.half()  # to FP16
# transform after crop: resize each person crop to 224x224 and convert it to a tensor
tfms = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
detectAct = 0  # number of proposed clips the action model labelled 'BaseballPitch'
detectOther = 0  # number of proposed clips given any other label
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True,
              stride=64):
    """Resize ``img`` to fit ``new_shape`` keeping its aspect ratio, then pad the borders.

    Adapted from the YOLOv5 preprocessing helper
    (see https://github.com/ultralytics/yolov3/issues/232).

    Args:
        img: image array whose first two dimensions are (height, width).
        new_shape: target (height, width), or a single int for a square target.
        color: border fill value passed to ``cv2.copyMakeBorder``.
        auto: if True, pad only up to the next multiple of ``stride`` rather than
            to the full ``new_shape`` (minimum rectangle).
        scaleFill: if True and ``auto`` is False, stretch to exactly ``new_shape``
            with no padding and independent width/height ratios.
        scaleup: if False, never enlarge the image (only shrink it).
        stride: padding multiple used when ``auto`` is True. The default 64 is the
            value this helper has always used; 32 matches a stride-32 detector.

    Returns:
        ``(padded_img, (ratio_w, ratio_h), (pad_w, pad_h))`` where the pad values
        are the per-side padding (half of the total padding), possibly fractional.
    """
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle: pad only up to the next multiple of `stride`
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    # The -/+0.1 nudge sends the extra pixel of an odd total padding to bottom/right.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """Draw one bounding box, plus an optional filled label tag, on ``img`` in place.

    Args:
        x: box corners (x1, y1, x2, y2); each value is truncated with ``int``.
        img: image array drawn on with OpenCV; modified in place.
        color: box/tag colour; a random colour is chosen when falsy.
        label: text drawn in a filled tag just above the top-left corner, if truthy.
        line_thickness: line thickness; derived from the image size when falsy.
    """
    if line_thickness:
        thick = line_thickness
    else:
        thick = round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=thick, lineType=cv2.LINE_AA)
    if not label:
        return

    font_thick = max(thick - 1, 1)  # font thickness
    font_scale = thick / 3
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thick)[0]
    tag_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
    cv2.rectangle(img, top_left, tag_corner, color, -1, cv2.LINE_AA)  # filled
    cv2.putText(img, label, (top_left[0], top_left[1] - 2), 0, font_scale,
                [225, 255, 255], thickness=font_thick, lineType=cv2.LINE_AA)
def plot_action(img, color=None, line_thickness=None, action=None):
    """Draw the clip-level action label as a filled tag in the top-left corner of ``img``.

    Args:
        img: image array drawn on with OpenCV; modified in place.
        color: tag colour; a random colour is chosen when falsy.
        line_thickness: base thickness; derived from the image size when falsy.
        action: label text. Nothing is drawn when it is empty or None
            (``cv2.getTextSize`` cannot measure None).
    """
    if not action:
        return
    tl = line_thickness or round(
        0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    tf = max(tl - 1, 1)  # font thickness
    text_w, text_h = cv2.getTextSize(action, 0, fontScale=tl / 3, thickness=tf)[0]
    margin = 10
    # cv2.putText's origin is the text's bottom-left corner, so anchor the baseline one
    # text height (+3 px tag padding) below the margin. The previous fixed origin
    # (10, 10) put most of the tag above row 0, clipping the label.
    c3 = (margin, margin + text_h + 3)
    c4 = (c3[0] + text_w, c3[1] - text_h - 3)  # top-right of the tag, at y == margin
    cv2.rectangle(img, c3, c4, color, -1, cv2.LINE_AA)  # filled
    cv2.putText(img, action, (c3[0], c3[1] - 2), 0, tl / 3,
                [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
def weightedPose(poseList, radius=2):
    """Weight each per-frame pose score by the scores of its temporal neighbours.

    For each index ``i`` the neighbour weight is the sum of the scores in the window
    ``[i - radius, i + radius]`` (clipped to the list bounds) minus the score itself.
    The result is ``weight_i * poseList[i]``, so an isolated non-zero score becomes
    0 while runs of non-zero scores are amplified.

    Args:
        poseList: iterable of numeric per-frame scores.
        radius: half-width of the neighbour window; the default 2 is the original
            fixed 5-frame window.

    Returns:
        A new list with one weighted score per input element.

    Raises:
        ValueError: if ``radius`` is negative.

    >>> weightedPose([0, 0, 0, 0, 0, 0, 1, 2])
    [0, 0, 0, 0, 0, 0, 2, 2]
    """
    if radius < 0:
        raise ValueError(f'radius must be non-negative, got {radius}')
    scores = list(poseList)
    last = len(scores) - 1
    weighted = []
    for i, score in enumerate(scores):
        lo = max(0, i - radius)
        hi = min(last, i + radius)
        neighbours = sum(scores[lo:hi + 1]) - score  # window total excluding frame i
        weighted.append(neighbours * score)
    return weighted
def _detect_people(img0):
    """Run the YOLO detector on one BGR frame; return (boxes, crops) for tall class-0 detections."""
    img = letterbox(img0, new_shape=imgsize)[0]
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
    img = np.ascontiguousarray(img)
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)
    # Person crops are cut from an RGB PIL copy of the full-resolution frame.
    pilimg = Image.fromarray(cv2.cvtColor(img0, cv2.COLOR_BGR2RGB))
    # Inference only: without no_grad autograd would record a graph for every frame.
    with torch.no_grad():
        raw = model(img, augment=False)[0]
    detections = non_max_suppression(raw, confthres, iouthres, classes=None, agnostic=False)

    boxes, crops = [], []
    for dets in detections:
        if dets is None or not len(dets):
            continue
        # Map boxes from letterboxed coordinates back to the original frame.
        dets[:, :4] = scale_coords(img.shape[2:], dets[:, :4], img0.shape).round()
        for *xyxy_t, _conf, cls in reversed(dets):
            xyxy = [float(v) for v in xyxy_t]
            # Keep class-0 boxes ('person' for COCO weights) at least 50 px tall.
            if cls == 0 and xyxy[3] - xyxy[1] >= 50:
                crops.append(tfms(pilimg.crop(tuple(xyxy))))
                boxes.append(xyxy)
    return boxes, crops


def _classify_poses(crops):
    """Classify each crop with posemodel; return [[pose_name, probability], ...] in crop order."""
    if not crops:
        return []
    batch = torch.stack(crops, dim=0).to(device)
    batch = batch.half() if half else batch.float()
    with torch.no_grad():
        probs = F.softmax(posemodel(batch), dim=1)
    top_prob, top_idx = probs.topk(1, 1, True, True)  # best class per crop, shape (N, 1)
    return [[poseNames[j], p]
            for p, j in zip(top_prob.view(-1).tolist(), top_idx.view(-1).tolist())]


def _save_annotated_clip(suffix, tsm_pred, action_label, clip_index, save_dir='data/ucftest/baseball'):
    """Draw boxes, pose labels and the action label on every buffered frame and save them as JPEGs."""
    os.makedirs(save_dir, exist_ok=True)  # cv2.imwrite returns False silently if the folder is missing
    for index, act_frame in enumerate(imagedatalist):
        for box, pose in zip(resultlist[index], resultPoseList[index]):
            label = f'{pose[0]} {pose[1]:.2f} {box[3] - box[1]}'
            plot_one_box(box, act_frame, label=label)
        plot_action(act_frame, action=action_label)
        savepath = os.path.join(save_dir, '{0}_{5}_{1}_{2}_detected{3}_pose{4}.jpg'.format(
            tsm_pred, clip_index, index, detected[index], posedetected[index], suffix))
        cv2.imwrite(savepath, act_frame)


def fallDetection(suffix, img0):
    """Push one sampled frame through the detection -> pose -> action pipeline.

    Steps:
      1. Detect class-0 boxes at least 50 px tall with the YOLO ``model``.
      2. Classify each person crop with ``posemodel`` into one of ``poseNames``.
      3. Append the frame and its results to the module-level clip buffers
         (``imagedatalist``, ``resultlist``, ``resultPoseList``, ``detected``,
         ``posedetected``), which hold at most ``frame_number`` frames.
      4. Once ``frame_number`` frames are buffered, send them to
         ``ucf_TSM_Module.alertAction``. If it predicts ``'BaseballPitch'``, every
         buffered frame is annotated and written under ``data/ucftest/baseball``.
         Otherwise ``detectOther`` is incremented. The buffers are then cleared,
         so consecutive clips do not overlap.

    Args:
        suffix: value embedded in saved filenames (``det`` passes its video index).
        img0: frame as returned by ``cv2.VideoCapture.read`` (BGR, H x W x 3).

    Side effects:
        Mutates the module-level buffers and the ``detectAct``, ``detectOther``
        and ``proposalFlag`` globals; may write JPEG files to disk.
    """
    global detectAct, detectOther, proposalFlag
    img_org = copy.deepcopy(img0)  # pristine copy: fed to TSM and annotated later

    result, crops = _detect_people(img0)
    poseOutputList = _classify_poses(crops)

    # Slide the clip buffer; all five lists stay index-aligned.
    if len(detected) >= frame_number:
        for buf in (imagedatalist, resultlist, resultPoseList, detected, posedetected):
            buf.pop(0)
    imagedatalist.append(img_org)
    resultlist.append(result)
    resultPoseList.append(poseOutputList)
    detected.append(1 if result else 0)
    # Per-frame abnormality flag: 2 if any person is classified as "throw", else 0.
    posedetected.append(2 if any(pose[0] == 'throw' for pose in poseOutputList) else 0)

    if len(posedetected) == frame_number:
        # TODO filter condition. Intended rule (translated from the original note):
        # propose only when the weighted abnormal score weightedPose(posedetected)
        # exceeds 4 (i.e. beats [0, 0, 0, 0, 0, 0, 1, 2]) and a typical pose occurred.
        # It is not enforced yet, so every full clip becomes a proposal.
        proposalFlag = True
    if proposalFlag:
        tsmPred, tsmConf = ucf_TSM_Module.alertAction(imagedatalist)
        actionLabel = f'{tsmPred} {tsmConf:.2f}%'
        if tsmPred == 'BaseballPitch':
            detectAct += 1
            _save_annotated_clip(suffix, tsmPred, actionLabel, detectAct)
        else:
            detectOther += 1
        # Clips do not overlap: the next clip starts from empty buffers.
        for buf in (resultlist, resultPoseList, imagedatalist, detected, posedetected):
            buf.clear()
        proposalFlag = False
def det(i, videoPath, frame_interval=6):
    """Run fallDetection on every ``frame_interval``-th frame of a video file.

    Args:
        i: identifier forwarded to fallDetection (embedded in saved filenames).
        videoPath: path of the video opened with ``cv2.VideoCapture``.
        frame_interval: process frames whose 1-based index is a multiple of this
            value; the default 6 keeps the original sampling rate.

    Raises:
        ValueError: if ``frame_interval`` is smaller than 1.

    Notes:
        'q' stops early only if an OpenCV window has keyboard focus; this function
        creates none. ``detectAct``/``detectOther`` are module-level totals, so the
        printed counts accumulate across calls.
    """
    if frame_interval < 1:
        raise ValueError(f'frame_interval must be >= 1, got {frame_interval}')
    cap = cv2.VideoCapture(videoPath)
    if not cap.isOpened():
        print(f'Cannot open video: {videoPath}')
    fps = cap.get(cv2.CAP_PROP_FPS)
    print("FPS is ", fps)
    frameCount = 0
    print("Starting...")
    start = time.time()
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frameCount += 1
            if frameCount % frame_interval == 0:
                fallDetection(i, frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture even if fallDetection raises.
        cap.release()
    end = time.time()
    print("all done")
    print("processing time : {0:.1f}, total : {1}, act : {2}".format(
        float(end - start), detectAct + detectOther, detectAct))
if __name__ == "__main__":
    import glob
    import shutil

    out_dir = 'data/ucftest/baseball'
    # Portable replacement for `rm -rf data/ucftest/baseball/*`: empty the output
    # folder without spawning a shell, creating it first so cv2.imwrite can save there.
    os.makedirs(out_dir, exist_ok=True)
    for stale in glob.glob(os.path.join(out_dir, '*')):
        if os.path.isdir(stale) and not os.path.islink(stale):
            shutil.rmtree(stale)
        else:
            os.remove(stale)
    # Each index is appended to the stem 'video_validation_000068'.
    for i in [1, 2, 3, 5, 7, 9]:
        videoPath = 'data/thumos14/baseball/video_validation_000068{0}.mp4'.format(i)
        det(i, videoPath)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/seekerrc/actiondet.git
git@gitee.com:seekerrc/actiondet.git
seekerrc
actiondet
actiondet
master

搜索帮助