1 Star 0 Fork 0

seekerrc/actiondet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
ucf_TSM_Module.py 6.00 KB
一键复制 编辑 原始数据 按行查看 历史
seekerrc 提交于 2022-04-12 16:47 . reorganize structure
import numpy as np
import cv2
import os
import time
from ops.models import TSN
import torch
import torchvision
from PIL import Image
from torch.nn import functional as F
from ops.transforms import *
import threading
def alertAction(image_list):
    """Classify the action shown in a sequence of BGR frames.

    Every frame in ``image_list`` is converted from BGR to RGB and wrapped in
    a PIL image. The frames then go through the module-level ``cropping``
    transform followed by ``Stack``, ``ToTorchFormatTensor`` and
    ``GroupNormalize``, and the result is scored by ``eval_video`` using the
    module-level ``net``. Timing information and the winning class are
    printed along the way.

    Args:
        image_list: Sequence of frames accepted by ``cv2.cvtColor`` with
            ``cv2.COLOR_BGR2RGB``.

    Returns:
        A ``(class_name, percentage)`` tuple: the entry of ``classnames``
        with the highest score, and that score multiplied by 100.
    """
    started = time.time()

    # OpenCV frames are BGR; the transforms operate on RGB PIL images.
    rgb_frames = [
        Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        for frame in image_list
    ]

    preprocess = torchvision.transforms.Compose([
        cropping,
        Stack(roll=False),
        ToTorchFormatTensor(div=True),
        GroupNormalize(input_mean, input_std),
    ])
    clip_tensor = preprocess(rgb_frames)

    crop_elapsed = time.time() - started
    print('crop time: {:.3f}'.format(float(crop_elapsed)))

    net.eval()
    started = time.time()
    batch_scores = eval_video(clip_tensor, net, test_segments, modality)
    infer_elapsed = time.time() - started
    print('{:.3f} sec/video'.format(float(infer_elapsed)))

    # Only the first (and only) row of the batch is of interest.
    scores = batch_scores[0].tolist()
    # First index holding the maximum score.
    best = max(range(len(scores)), key=scores.__getitem__)
    label = classnames[best]
    print("The return result: " + label)
    print('probability:', scores[best])
    return label, scores[best] * 100
def eval_video(video_data, net, this_test_segments, modality):
    """Run ``net`` on one preprocessed clip and return softmax class scores.

    The crop count is taken from the module-level ``test_crops`` and scaled
    by 10 when ``dense_sample`` is set and by 2 when ``twice_sample`` is set.
    The whole forward pass runs under ``torch.no_grad()``.

    Args:
        video_data: Tensor from the preprocessing pipeline. It is viewed as
            ``(-1, length, video_data.size(1), video_data.size(2))`` where
            ``length`` depends on ``modality``.
        net: Model wrapper exposing ``net.module.is_shift`` (the script wraps
            the model in ``torch.nn.DataParallel``).
        this_test_segments: Segment count used for the 5-D view that is built
            when the module-level ``is_shift`` flag is true.
        modality: One of ``'RGB'``, ``'Flow'`` or ``'RGBDiff'``.

    Returns:
        NumPy array of shape ``(1, num_class)`` holding softmax scores.

    Raises:
        ValueError: If ``modality`` is not one of the supported names.
    """
    # Value of `length` used in the input views, per modality.
    length_by_modality = {'RGB': 3, 'Flow': 10, 'RGBDiff': 18}
    clips_in_batch = 1  # one video per call

    crops = test_crops
    if dense_sample:
        crops *= 10  # 10 clips for testing when using dense sample
    if twice_sample:
        crops *= 2

    if modality not in length_by_modality:
        raise ValueError("Unknown modality " + modality)
    length = length_by_modality[modality]

    with torch.no_grad():
        print("data_in shape:")
        print(video_data.shape)

        data_in = video_data.view(
            -1, length, video_data.size(1), video_data.size(2))
        if is_shift:
            data_in = data_in.view(
                clips_in_batch * crops, this_test_segments, length,
                data_in.size(2), data_in.size(3))

        logits = net(data_in)
        # Average the predictions over all crops of the clip.
        logits = logits.reshape(clips_in_batch, crops, -1).mean(1)
        # Softmax is applied unconditionally to turn scores into probabilities.
        probs = F.softmax(logits, dim=1)
        scores = probs.data.cpu().numpy().copy()

    if net.module.is_shift:
        return scores.reshape(clips_in_batch, num_class)
    averaged = scores.reshape((clips_in_batch, -1, num_class)).mean(axis=1)
    return averaged.reshape((clips_in_batch, num_class))
def parse_shift_option_from_log_name(log_name):
    """Extract the temporal-shift settings encoded in a log/checkpoint name.

    The name is split on ``'_'`` and scanned for the first token that
    contains ``'shift'`` and becomes an integer once ``'shift'`` is removed
    (for example ``'shift8'``). The token that follows it is returned as the
    shift placement (for example ``'blockres'``).

    Tokens that contain ``'shift'`` but do not parse as a number are skipped
    rather than raising. Callers pass whole file paths, so an unrelated path
    component such as ``./shift_weights`` or ``makeshift`` may precede the
    real ``shiftN`` token.

    Args:
        log_name: Log or checkpoint name, possibly a full path.

    Returns:
        ``(True, shift_div, shift_place)`` when a shift token is found,
        otherwise ``(False, None, None)``.

    Raises:
        IndexError: If a valid shift token is the last token, so no
            placement token follows it.
    """
    tokens = log_name.split('_')
    for idx, token in enumerate(tokens):
        if 'shift' not in token:
            continue
        try:
            shift_div = int(token.replace('shift', ''))
        except ValueError:
            # Contains "shift" but is not a "shift<N>" token; keep looking.
            continue
        if idx + 1 >= len(tokens):
            raise IndexError(
                "no shift place follows {!r} in {!r}".format(token, log_name))
        return True, shift_div, tokens[idx + 1]
    return False, None, None
# NOTE: these indices differ by 1 from the ones THUMOS provides; remember to
# subtract 1 when working with THUMOS indices.
# Each line of classInd.txt is split on a single space and the second field
# is kept as the class name.
with open('./data/ucf101/classInd.txt') as f:
    lines = list(f)
categories = [entry.rstrip().split(' ')[1] for entry in lines]
# --- Checkpoint and model configuration --------------------------------
weight_file = './weights/TSM_ucf101_RGB_resnet50_shift8_blockres_avg_segment8_e25/ckpt.best.pth.tar'
pretrain = 'imagenet'
num_class = 101
img_feature_dim = 256
classnames = categories

# --- Test-time sampling ------------------------------------------------
test_segments = 8
test_crops = 1
full_res = False
dense_sample = False
twice_sample = False
softmax = False  # not read anywhere in the visible part of this script
scale_size = 720  # not read in the visible code, which uses net.scale_size

# --- Input normalisation -----------------------------------------------
input_mean = [0.485, 0.456, 0.406]
input_std = [0.229, 0.224, 0.225]

# --- Devices -----------------------------------------------------------
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
gpus = ['0']
workers = 1

# --- Options derived from the checkpoint path --------------------------
is_shift, shift_div, shift_place = parse_shift_option_from_log_name(weight_file)
# Any path without "RGB" in it is treated as optical flow.
modality = 'RGB' if 'RGB' in weight_file else 'Flow'
# The third "_"-separated field after "TSM_" is used as the backbone name.
this_arch = weight_file.split('TSM_')[1].split('_')[2]
print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
# Build the network described by the checkpoint name. Without temporal shift
# the model is constructed with a single segment.
net = TSN(
    num_class,
    test_segments if is_shift else 1,
    modality,
    base_model=this_arch,
    consensus_type='avg',
    img_feature_dim=img_feature_dim,
    pretrain=pretrain,
    is_shift=is_shift,
    shift_div=shift_div,
    shift_place=shift_place,
    non_local='_nl' in weight_file,
)

if 'tpool' in weight_file:
    from ops.temporal_shift import make_temporal_pool
    make_temporal_pool(net.base_model, test_segments)

# SECURITY NOTE: torch.load unpickles the file, so only load checkpoints from
# a trusted source.
# map_location='cpu' keeps the load independent of the device the checkpoint
# was saved from. The model is still on the CPU at this point and is moved to
# the GPU later by net.cuda().
checkpoint = torch.load(weight_file, map_location='cpu')
checkpoint = checkpoint['state_dict']

# Drop the first dotted component of every key (presumably the "module."
# prefix left by a DataParallel wrapper when the checkpoint was saved;
# TODO confirm against the training script).
base_dict = {k.partition('.')[2]: v for k, v in checkpoint.items()}

# Rename classifier parameters to the names used for the final layer here.
replace_dict = {
    'base_model.classifier.weight': 'new_fc.weight',
    'base_model.classifier.bias': 'new_fc.bias',
}
for k, v in replace_dict.items():
    if k in base_dict:
        base_dict[v] = base_dict.pop(k)

net.load_state_dict(base_dict)
# Size handed to the multi-crop samplers below.
input_size = net.scale_size if full_res else net.input_size

# Choose the test-time cropping transform from the requested crop count.
if test_crops == 1:
    # Single view: frames are only rescaled. The center crop
    # (GroupCenterCrop(input_size)) is disabled in this script.
    cropping = torchvision.transforms.Compose([
        GroupScale(net.scale_size),
    ])
elif test_crops == 3:  # GroupFullResSample with flipping turned off
    cropping = torchvision.transforms.Compose([
        GroupFullResSample(input_size, net.scale_size, flip=False)
    ])
elif test_crops == 5:  # GroupOverSample with flipping turned off
    cropping = torchvision.transforms.Compose([
        GroupOverSample(input_size, net.scale_size, flip=False)
    ])
elif test_crops == 10:  # GroupOverSample with its default flip setting
    cropping = torchvision.transforms.Compose([
        GroupOverSample(input_size, net.scale_size)
    ])
else:
    # The message lists every crop count handled above, including 3.
    raise ValueError("Only 1, 3, 5, 10 crops are supported while we got {}".format(test_crops))
# One device id per worker: taken from `gpus` when it is configured,
# otherwise 0..workers-1. `devices` is not read in the visible code.
devices = (
    [gpus[idx] for idx in range(workers)]
    if gpus is not None
    else list(range(workers))
)

# Move the model to the GPU, wrap it for data-parallel execution and switch
# it to inference mode.
net = net.cuda()
net = torch.nn.DataParallel(net)
net.eval()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/seekerrc/actiondet.git
git@gitee.com:seekerrc/actiondet.git
seekerrc
actiondet
actiondet
master

搜索帮助