# (Web-page residue, not part of the program:) Code pull complete; the page will refresh automatically.
import numpy as np
import cv2
import os
import time
from PIL import Image
import io
import copy
import torch
from models.experimental import attempt_load
from utils.torch_utils import select_device
from utils.general import (
check_img_size, non_max_suppression, apply_classifier, scale_coords,
xyxy2xywh, xywh2xyxy, strip_optimizer)
from torchvision import transforms
# import random
import ucf_TSM_Module
import random
from torch.nn import functional as F
# ---------------------------------------------------------------------------
# Module-level configuration and model setup.
# Everything below runs at import time: it loads the detector and the pose
# classifier and creates the buffers that fallDetection() mutates.
# ---------------------------------------------------------------------------

# Detector configuration.
weights = 'weights/yolov5x.pt'  # checkpoint path handed to attempt_load()
imgsize = 640  # requested detector input size (used by fallDetection via letterbox)
confthres = 0.4  # confidence threshold passed to non_max_suppression()
iouthres = 0.5  # IoU threshold passed to non_max_suppression()
frame_number = 8  # number of sampled frames buffered before a clip is evaluated
# Class-name table. NOTE: this literal is rebound a few lines below from the
# loaded model, so the list written here is not what the rest of the file sees.
names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
         'hair drier', 'toothbrush']
# Initialize the detector.
device = select_device('0')  # presumably selects the first CUDA device -- confirm against utils.torch_utils
half = device.type != 'cpu'  # run in FP16 whenever the selected device is not the CPU
model = attempt_load(weights, map_location=device)
# Prefer the class names stored with the checkpoint (unwrapping a `.module`
# wrapper when one is present); this replaces the literal list above.
names = model.module.names if hasattr(model, 'module') else model.names
if half:
    model.half()  # to FP16
# NOTE(review): nothing in the visible code reads `imgsz`; fallDetection()
# passes `imgsize` to letterbox() instead.
imgsz = check_img_size(imgsize, s=model.stride.max())
# Rolling per-frame buffers, appended to (and cleared) by fallDetection().
# image data list: copies of the sampled frames
imagedatalist = []
# result list: per-frame person boxes and per-frame [pose name, probability] pairs
resultlist = []
resultPoseList = []
# detected index: per-frame flags (1/0 for "any person box kept", 2/0 for "throw pose seen")
detected = []
posedetected = []
proposalFlag = False  # set by fallDetection() when a full window is ready for the clip-level check
from efficientnet_pytorch import EfficientNet
# classnames = ["abnormal"->0, "normal"->1]
# Pose labels, indexed by the pose classifier's predicted class id.
poseNames = ["bend", "fall", "jump", "lie", "run", "sit", "squat", "stand", "throw", "walk"]
# TODO 2 classes versus 10 classes
# posemodel = EfficientNet.from_pretrained('efficientnet-b5',
# weights_path='weights/pose2/pose.best.pth.tar',
# num_classes=2, load_fc=True)
# Pose classifier: EfficientNet-b5 with 10 output classes (matches len(poseNames)).
posemodel = EfficientNet.from_pretrained('efficientnet-b5',
                                         weights_path='weights/pose1215/pose.best.pth.tar',
                                         num_classes=10, load_fc=True)
posemodel.to(device)
posemodel.eval()
if half:
    posemodel.half()  # to FP16
# transform after crop: resize each person crop to 224x224 and convert it to a tensor
tfms = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
# Clip counters updated by fallDetection(): clips labelled 'BaseballPitch' vs. all other clips.
detectAct = 0
detectOther = 0
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    """Resize ``img`` with its aspect ratio preserved and pad it with ``color``.

    Args:
        img: image array whose first two dimensions are (height, width).
        new_shape: target (height, width), or a single int for a square target.
        color: border value handed to ``cv2.copyMakeBorder``.
        auto: when true, only pad by the remainder modulo 64 instead of padding
            all the way out to ``new_shape``.
        scaleFill: when true (and ``auto`` is false), stretch to ``new_shape``
            with no padding.
        scaleup: when false, never enlarge the image (scale ratio capped at 1).

    Returns:
        ``(padded_image, (width_ratio, height_ratio), (pad_w, pad_h))`` where
        the padding values are per-side (already divided by two).
    """
    # See https://github.com/ultralytics/yolov3/issues/232
    src_h, src_w = img.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    dst_h, dst_w = new_shape[0], new_shape[1]

    # Scale ratio (new / old); the tighter of the two axes wins.
    scale = min(dst_h / src_h, dst_w / src_w)
    if not scaleup:
        scale = min(scale, 1.0)

    ratio = (scale, scale)  # width, height ratios
    new_unpad = (int(round(src_w * scale)), int(round(src_h * scale)))  # (width, height)
    pad_w = dst_w - new_unpad[0]
    pad_h = dst_h - new_unpad[1]
    if auto:
        # Minimum rectangle: keep only the padding remainder modulo 64.
        pad_w, pad_h = np.mod(pad_w, 64), np.mod(pad_h, 64)
    elif scaleFill:
        # Stretch: no padding, independent ratios per axis.
        pad_w, pad_h = 0.0, 0.0
        new_unpad = (dst_w, dst_h)
        ratio = (dst_w / src_w, dst_h / src_h)

    # Split the padding between the two sides.
    pad_w /= 2
    pad_h /= 2

    if (src_w, src_h) != new_unpad:
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

    # The +/-0.1 nudges send an odd pixel of padding to the bottom/right side.
    top = int(round(pad_h - 0.1))
    bottom = int(round(pad_h + 0.1))
    left = int(round(pad_w - 0.1))
    right = int(round(pad_w + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return img, ratio, (pad_w, pad_h)
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """Draw one bounding box (and an optional label tab) onto ``img`` in place.

    Args:
        x: box as an indexable of four values ``(x1, y1, x2, y2)``.
        img: image array to draw on; modified in place.
        color: drawing colour; a random 3-component colour is used when falsy.
        label: optional text drawn in a filled tab above the top-left corner.
        line_thickness: line thickness; derived from the image size when falsy.
    """
    thickness = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=thickness, lineType=cv2.LINE_AA)

    if not label:
        return

    font_thickness = max(thickness - 1, 1)
    font_scale = thickness / 3
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thickness)[0]
    # The filled tab sits directly above the box's top-left corner.
    tab_corner = top_left[0] + text_w, top_left[1] - text_h - 3
    cv2.rectangle(img, top_left, tab_corner, color, -1, cv2.LINE_AA)
    cv2.putText(img, label, (top_left[0], top_left[1] - 2), 0, font_scale,
                [225, 255, 255], thickness=font_thickness, lineType=cv2.LINE_AA)
def plot_action(img, color=None, line_thickness=None, action=None):
    """Draw the clip-level action label in a filled tab anchored at (10, 10).

    Args:
        img: image array to draw on; modified in place.
        color: tab colour; a random 3-component colour is used when falsy.
        line_thickness: base thickness; derived from the image size when falsy.
        action: label text. When ``None`` (the default) nothing is drawn;
            previously the default reached ``cv2.getTextSize(None, ...)`` and
            raised.
    """
    if action is None:
        return

    tl = line_thickness or round(
        0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    tf = max(tl - 1, 1)  # font thickness
    text_w, text_h = cv2.getTextSize(action, 0, fontScale=tl / 3, thickness=tf)[0]

    anchor = (10, 10)
    # anchor = (1, img.shape[0] - 1)
    # NOTE(review): the tab extends upward from y=10, so its top edge is at
    # 10 - text_h - 3, which is above the image whenever text_h > 7.
    tab_corner = (anchor[0] + text_w, anchor[1] - text_h - 3)
    cv2.rectangle(img, anchor, tab_corner, color, -1, cv2.LINE_AA)  # filled
    cv2.putText(img, action, (anchor[0], anchor[1] - 2), 0, tl / 3,
                [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
def weightedPose(poseList, radius=2):
    """Weight every per-frame pose score by the scores of its neighbours.

    For each position ``i`` the weight is the sum of the entries lying within
    ``radius`` positions of ``i`` (clipped at the list ends), excluding
    ``poseList[i]`` itself. The returned value at ``i`` is that weight
    multiplied by ``poseList[i]``, so an isolated non-zero score yields 0 and
    clustered non-zero scores reinforce each other.

    Args:
        poseList: sliceable sequence of numeric per-frame scores.
        radius: how many neighbours on each side contribute to the weight.
            Defaults to 2, the window the function originally hard-coded.

    Returns:
        A new list with one weighted score per input entry.

    Raises:
        ValueError: if ``radius`` is negative.
    """
    if radius < 0:
        raise ValueError("radius must be non-negative")

    weighted = []
    for i, score in enumerate(poseList):
        # Slicing clips at the right end by itself; only the left needs max().
        window = poseList[max(0, i - radius): i + radius + 1]
        neighbour_sum = sum(window) - score
        weighted.append(neighbour_sum * score)
    return weighted
def fallDetection(suffix, img0):
    """Process one sampled frame and, once a full window is buffered, run the
    clip-level action check.

    Steps, in order:
      1. Detect objects in ``img0`` and keep boxes of class 0 that are at
         least 50 px tall.
      2. Classify the pose of every kept box with the pose model.
      3. Append the frame, boxes, poses and per-frame flags to the
         module-level rolling buffers.
      4. When ``frame_number`` frames are buffered, pass the buffered frames
         to ``ucf_TSM_Module.alertAction``; if it returns ``'BaseballPitch'``
         the annotated frames are written to ``data/ucftest/baseball``.
         The buffers are cleared afterwards either way.

    Args:
        suffix: value inserted into the saved file names (``det`` passes its
            video index here).
        img0: frame as an image array; it is treated as BGR (it is converted
            with ``cv2.COLOR_BGR2RGB`` before cropping).

    Side effects:
        Mutates the module-level buffers and the ``detectAct`` /
        ``detectOther`` / ``proposalFlag`` globals; may write JPEG files.
    """
    global detectAct, detectOther, proposalFlag

    # Keep an untouched copy for the clip buffer; boxes are drawn on it later.
    img_org = copy.deepcopy(img0)

    # Detector input: letterbox, BGR -> RGB, HWC -> CHW, scale to [0, 1].
    img = letterbox(img0, new_shape=imgsize)[0]
    img = img[:, :, ::-1].transpose(2, 0, 1)
    img = np.ascontiguousarray(img)
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)

    # PIL copy of the frame, used for cropping the person boxes.
    pilimg = Image.fromarray(cv2.cvtColor(img0, cv2.COLOR_BGR2RGB))

    # Detection. Inference only, so run it without autograd tracking.
    with torch.no_grad():
        pred = model(img, augment=False)[0]
    pred = non_max_suppression(pred, confthres, iouthres, classes=None, agnostic=False)

    result = []          # kept boxes as [x1, y1, x2, y2]
    poseimglist = []     # one transformed crop per kept box
    poseOutputList = []  # one [pose name, probability] per kept box
    for detections in pred:
        if detections is None or not len(detections):
            continue
        # Map boxes from the letterboxed input back to the original frame.
        detections[:, :4] = scale_coords(img.shape[2:], detections[:, :4], img0.shape).round()
        for *xyxyTensor, conf, cls in reversed(detections):
            xyxy = torch.tensor(xyxyTensor).view(1, 4).view(-1).tolist()
            boxHeight = xyxy[3] - xyxy[1]
            # Class 0 only, and skip boxes shorter than 50 px.
            if cls == 0 and boxHeight >= 50:
                cropleft, croptop, cropright, cropbottom = xyxy
                cropped = pilimg.crop((cropleft, croptop, cropright, cropbottom))
                poseimglist.append(tfms(cropped))
                result.append(xyxy)

    # Pose classification for all kept boxes in a single batch.
    # classnames = ["abnormal"->0, "normal"->1]
    if poseimglist:
        posemodelinput = torch.stack(poseimglist, dim=0).to(device)
        posemodelinput = posemodelinput.half() if half else posemodelinput.float()
        with torch.no_grad():
            output = posemodel(posemodelinput)
            output = F.softmax(output, dim=1)
        prob, poseIdx = output.topk(1, 1, True, True)
        for idx, p in zip(poseIdx, prob):
            poseOutputList.append([poseNames[idx.item()], p.item()])

    # Keep the buffers at no more than frame_number entries.
    if len(detected) >= frame_number:
        imagedatalist.pop(0)
        resultlist.pop(0)
        resultPoseList.pop(0)
        detected.pop(0)
        posedetected.pop(0)
    imagedatalist.append(img_org)
    resultlist.append(result)
    resultPoseList.append(poseOutputList)
    detected.append(1 if result else 0)

    # Anomaly decision logic: flag the frame with 2 when any box is
    # classified as "throw", otherwise 0.
    # if "fall" in tmp or "lie" in tmp or "jump" in tmp:
    poseLabels = {entry[0] for entry in poseOutputList}
    posedetected.append(2 if "throw" in poseLabels else 0)

    if len(posedetected) == frame_number:
        # Intended proposal condition (not enforced yet, see TODO): the
        # weighted anomaly score exceeds 4, i.e. better than
        # [0, 0, 0, 0, 0, 0, 1, 2], and one typical pose is present.
        newPose = weightedPose(posedetected)  # computed for the pending filter; unused for now
        proposalFlag = True
        # TODO filter condition

    if proposalFlag:
        # TODO TSM
        tsmPred, tsmConf = ucf_TSM_Module.alertAction(imagedatalist)
        actionLabel = f'{tsmPred} {tsmConf:.2f}%'
        # or tsmPred == 'JavelinThrow'
        if tsmPred == 'BaseballPitch':
            detectAct += 1
            for index, actFrame in enumerate(imagedatalist):
                for resultBox, resultPose in zip(resultlist[index], resultPoseList[index]):
                    # Label: pose name, pose probability, box height.
                    label = f'{resultPose[0]} {resultPose[1]:.2f} {resultBox[3] - resultBox[1]}'
                    plot_one_box(resultBox, actFrame, label=label)
                plot_action(actFrame, action=actionLabel)
                savepath = os.path.join('data/ucftest/baseball', '{0}_{5}_{1}_{2}_detected{3}_pose{4}.jpg'.format(
                    tsmPred, detectAct, index, detected[index], posedetected[index], suffix))
                cv2.imwrite(savepath, actFrame)
        else:
            detectOther += 1
        # Start a fresh window after every evaluated clip.
        resultlist.clear()
        resultPoseList.clear()
        imagedatalist.clear()
        detected.clear()
        posedetected.clear()
        proposalFlag = False
def det(i, videoPath, frameStride=6):
    """Run fallDetection() on every ``frameStride``-th frame of a video.

    Args:
        i: identifier forwarded to ``fallDetection`` as its ``suffix``.
        videoPath: path handed to ``cv2.VideoCapture``.
        frameStride: process one frame out of this many. Defaults to 6, the
            stride the function originally hard-coded.

    Raises:
        ValueError: if ``frameStride`` is not positive.

    Prints the FPS reported by the capture, and at the end the elapsed time
    together with the module-level ``detectAct`` / ``detectOther`` counters
    (these are not reset here, so they accumulate across calls).
    """
    if frameStride <= 0:
        raise ValueError("frameStride must be positive")

    cap = cv2.VideoCapture(videoPath)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        print("FPS is ", fps)
        frameCount = 0
        print("Starting...")
        start = time.time()
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # End of stream or read failure.
                break
            frameCount += 1
            if frameCount % frameStride == 0:
                fallDetection(i, frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture even when fallDetection() raises.
        cap.release()
    end = time.time()
    print("all done")
    print("processing time : {0:.1f}, total : {1}, act : {2}".format(
        float(end - start), detectAct + detectOther, detectAct))
if __name__ == "__main__":
    import shutil

    # Empty the output directory before the run. Done with os/shutil rather
    # than a shell `rm -rf`, so it does not depend on a POSIX shell; the
    # directory is also created if missing so cv2.imwrite() has somewhere
    # to write.
    outputDir = 'data/ucftest/baseball'
    os.makedirs(outputDir, exist_ok=True)
    for entry in os.listdir(outputDir):
        if entry.startswith('.'):
            continue  # the shell glob `*` did not match dotfiles either
        entryPath = os.path.join(outputDir, entry)
        if os.path.isdir(entryPath) and not os.path.islink(entryPath):
            shutil.rmtree(entryPath)
        else:
            os.remove(entryPath)

    for i in [1, 2, 3, 5, 7, 9]:
        videoPath = 'data/thumos14/baseball/video_validation_000068{0}.mp4'.format(i)
        det(i, videoPath)
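# (Web-page residue, not part of the program:) This section may contain content unsuitable for display, so the page does not show it. You can review and modify it using the editing features.
# If you confirm the content involves no inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false information / worthless content, and nothing that violates national laws and regulations, you may click submit to appeal and we will handle it as soon as possible.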