import cv2
import numpy as np
import torch
from torchvision.transforms import transforms
from models.hrnet import HRNet
from models.poseresnet import PoseResNet
from models.detectors.YOLOv3 import YOLOv3
class SimpleHRNet:
    """
    SimpleHRNet class.

    The class provides a simple and customizable method to load the HRNet network, load the official pre-trained
    weights, and predict the human pose on single images.
    Multi-person support with the YOLOv3 detector is also included (and enabled by default).
    """

    # NOTE(review): the default values in the signature below were reconstructed from the parameter documentation
    # (the signature lines were lost) - confirm them against the callers before relying on them.
    def __init__(self,
                 c,
                 nof_joints,
                 checkpoint_path,
                 model_name='HRNet',
                 resolution=(384, 288),
                 interpolation=cv2.INTER_CUBIC,
                 multiperson=True,
                 return_heatmaps=False,
                 return_bounding_boxes=False,
                 max_batch_size=16,
                 yolo_model_def="./models/detectors/yolo/config/yolov3.cfg",
                 yolo_class_path="./models/detectors/yolo/data/coco.names",
                 yolo_weights_path="./models/detectors/yolo/weights/yolov3.weights",
                 device=torch.device("cpu")):
        """
        Initializes a new SimpleHRNet object.
        HRNet (and YOLOv3) are initialized on the torch.device("device") and
        its (their) pre-trained weights will be loaded from disk.

        Args:
            c (int): number of channels (when using HRNet model) or resnet size (when using PoseResNet model).
            nof_joints (int): number of joints.
            checkpoint_path (str): path to an official hrnet checkpoint or a checkpoint obtained with `train_coco.py`.
            model_name (str): model name (HRNet or PoseResNet).
                Valid names for HRNet are: `HRNet`, `hrnet`
                Valid names for PoseResNet are: `PoseResNet`, `poseresnet`, `ResNet`, `resnet`
                Default: "HRNet"
            resolution (tuple): hrnet input resolution - format: (height, width).
                Default: (384, 288)
            interpolation (int): opencv interpolation algorithm.
                Default: cv2.INTER_CUBIC
            multiperson (bool): if True, multiperson detection will be enabled.
                This requires the use of a people detector (like YOLOv3).
                Default: True
            return_heatmaps (bool): if True, heatmaps will be returned along with poses by self.predict.
                Default: False
            return_bounding_boxes (bool): if True, bounding boxes will be returned along with poses by self.predict.
                Default: False
            max_batch_size (int): maximum batch size used in hrnet inference.
                Useless without multiperson=True.
                Default: 16
            yolo_model_def (str): path to yolo model definition file.
                Default: "./models/detectors/yolo/config/yolov3.cfg"
            yolo_class_path (str): path to yolo class definition file.
                Default: "./models/detectors/yolo/data/coco.names"
            yolo_weights_path (str): path to yolo pretrained weights file.
                Default: "./models/detectors/yolo/weights/yolov3.weights"
            device (:class:`torch.device`): the hrnet (and yolo) inference will be run on this device.
                Default: torch.device("cpu")

        Raises:
            ValueError: if `model_name` or `device` is not one of the supported values.
        """

        self.c = c
        self.nof_joints = nof_joints
        self.checkpoint_path = checkpoint_path
        self.model_name = model_name
        self.resolution = resolution  # in the form (height, width) as in the original implementation
        self.interpolation = interpolation
        self.multiperson = multiperson
        self.return_heatmaps = return_heatmaps
        self.return_bounding_boxes = return_bounding_boxes
        self.max_batch_size = max_batch_size
        self.yolo_model_def = yolo_model_def
        self.yolo_class_path = yolo_class_path
        self.yolo_weights_path = yolo_weights_path
        self.device = device

        if model_name in ('HRNet', 'hrnet'):
            self.model = HRNet(c=c, nof_joints=nof_joints)
        elif model_name in ('PoseResNet', 'poseresnet', 'ResNet', 'resnet'):
            self.model = PoseResNet(resnet_size=c, nof_joints=nof_joints)
        else:
            raise ValueError('Wrong model name.')

        # Checkpoints either wrap the weights under the 'model' key or are the bare state dict.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        if 'model' in checkpoint:
            self.model.load_state_dict(checkpoint['model'])
        else:
            self.model.load_state_dict(checkpoint)

        if 'cuda' in str(self.device):
            print("device: 'cuda' - ", end="")

            if 'cuda' == str(self.device):
                # if device is set to 'cuda', all available GPUs will be used
                print("%d GPU(s) will be used" % torch.cuda.device_count())
                device_ids = None
            else:
                # if device is set to 'cuda:IDS', only that/those device(s) will be used
                print("GPU(s) '%s' will be used" % str(self.device))
                device_ids = [int(x) for x in str(self.device)[5:].split(',')]

            self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
        elif 'cpu' == str(self.device):
            print("device: 'cpu'")
        else:
            raise ValueError('Wrong device name.')

        self.model = self.model.to(device)
        self.model.eval()

        if not self.multiperson:
            # Whole frames are resized with OpenCV before the transform, so no Resize step is needed here.
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        else:
            # NOTE(review): keyword arguments after `model_def` were reconstructed - confirm against the
            # YOLOv3 wrapper signature.
            self.detector = YOLOv3(model_def=yolo_model_def,
                                   class_path=yolo_class_path,
                                   weights_path=yolo_weights_path,
                                   classes=('person',),
                                   max_batch_size=self.max_batch_size,
                                   device=device)
            # Person crops have arbitrary sizes, so they are resized to the network input by the transform.
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((self.resolution[0], self.resolution[1])),  # (height, width)
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

    def predict(self, image):
        """
        Predicts the human pose on a single image or a stack of n images.

        Args:
            image (:class:`np.ndarray`):
                the image(s) on which the human pose will be estimated.

                image is expected to be in the opencv format.
                image can be:
                    - a single image with shape=(height, width, BGR color channel)
                    - a stack of n images with shape=(n, height, width, BGR color channel)

        Returns:
            :class:`np.ndarray` or list:
                a numpy array containing human joints for each (detected) person.

                if image is a single image:
                    shape=(# of people, # of joints (nof_joints), 3);  dtype=(np.float32).

                if image is a stack of n images:
                    list of n np.ndarrays with
                    shape=(# of people, # of joints (nof_joints), 3);  dtype=(np.float32).

                Each joint has 3 values: (y position, x position, joint confidence).

                If self.return_heatmaps, the class returns a list with (heatmaps, human joints)
                If self.return_bounding_boxes, the class returns a list with (bounding boxes, human joints)
                If self.return_heatmaps and self.return_bounding_boxes, the class returns a list with
                    (heatmaps, bounding boxes, human joints)

        Raises:
            ValueError: if `image` has neither 3 nor 4 dimensions.
        """
        if len(image.shape) == 3:
            return self._predict_single(image)
        elif len(image.shape) == 4:
            return self._predict_batch(image)
        else:
            raise ValueError('Wrong image format.')

    def _adapt_box(self, x1, y1, x2, y2, image_shape):
        """
        Clamps a detection box to the image and stretches it to the HRNet input aspect ratio.

        The box is first forced to lie inside the image and to be at least one pixel wide and high, so that
        degenerate detections can neither divide by zero nor produce an empty crop.

        Args:
            x1, y1, x2, y2 (int): box corners in pixels.
            image_shape (tuple): shape of the image the box refers to - format: (height, width, ...).

        Returns:
            tuple: the adapted (x1, y1, x2, y2) box as ints.
        """
        img_h, img_w = image_shape[0], image_shape[1]

        x1 = min(max(x1, 0), img_w - 1)
        y1 = min(max(y1, 0), img_h - 1)
        x2 = min(max(x2, x1 + 1), img_w)
        y2 = min(max(y2, y1 + 1), img_h)

        # Adapt detections to match HRNet input aspect ratio (as suggested by xtyDoge in issue #14)
        correction_factor = self.resolution[0] / self.resolution[1] * (x2 - x1) / (y2 - y1)
        if correction_factor > 1:
            # increase y side
            center = y1 + (y2 - y1) // 2
            length = int(round((y2 - y1) * correction_factor))
            y1 = max(0, center - length // 2)
            y2 = min(img_h, center + length // 2)
        elif correction_factor < 1:
            # increase x side
            center = x1 + (x2 - x1) // 2
            length = int(round((x2 - x1) * 1 / correction_factor))
            x1 = max(0, center - length // 2)
            x2 = min(img_w, center + length // 2)

        # `length // 2` can be 0 for one-pixel boxes, which would collapse the side that was just resized.
        if x2 <= x1:
            x2 = x1 + 1
        if y2 <= y1:
            y2 = y1 + 1

        return x1, y1, x2, y2

    def _crop_people(self, image, detections, boxes, crops, offset=0):
        """
        Fills `boxes` and `crops` (starting at row `offset`) with one entry per detection of `image`.

        Args:
            image (:class:`np.ndarray`): BGR image the detections refer to.
            detections: iterable of 7-value detections whose first four values are x1, y1, x2, y2.
            boxes (:class:`np.ndarray`): output array receiving the adapted [x1, y1, x2, y2] boxes.
            crops (:class:`torch.Tensor`): output tensor receiving the transformed person crops.
            offset (int): index of the first row to write.
        """
        for i, (x1, y1, x2, y2, _conf, _cls_conf, _cls_pred) in enumerate(detections):
            x1, y1, x2, y2 = self._adapt_box(
                int(round(x1.item())), int(round(y1.item())), int(round(x2.item())), int(round(y2.item())),
                image.shape
            )
            boxes[offset + i] = [x1, y1, x2, y2]
            # `::-1` flips the channel axis (BGR -> RGB) before the crop is handed to the transform.
            crops[offset + i] = self.transform(image[y1:y2, x1:x2, ::-1])

    def _forward(self, images):
        """
        Runs the pose network on `images` in chunks of at most `max_batch_size`.

        Args:
            images (:class:`torch.Tensor`): non-empty network input already placed on `self.device`.

        Returns:
            :class:`np.ndarray`: the network output moved to the cpu.
        """
        with torch.no_grad():
            if len(images) <= self.max_batch_size:
                out = self.model(images)
            else:
                out = torch.empty(
                    (images.shape[0], self.nof_joints, self.resolution[0] // 4, self.resolution[1] // 4),
                    device=self.device
                )
                for i in range(0, len(images), self.max_batch_size):
                    out[i:i + self.max_batch_size] = self.model(images[i:i + self.max_batch_size])

        return out.detach().cpu().numpy()

    def _decode_keypoints(self, out, boxes):
        """
        Converts heatmaps into joints expressed in the coordinates of the original image.

        Args:
            out (:class:`np.ndarray`): heatmaps indexed as (person, joint, y, x).
            boxes (:class:`np.ndarray`): one [x1, y1, x2, y2] box per person.

        Returns:
            :class:`np.ndarray`: array with shape=(# of people, # of joints, 3); dtype=(np.float32).
                Each joint has 3 values: (y position, x position, joint confidence).
        """
        hm_h = self.resolution[0] // 4
        hm_w = self.resolution[1] // 4
        nof_people, nof_joints = out.shape[0], out.shape[1]

        # The location of each joint is the position of the maximum of its heatmap.
        flat = out.reshape(nof_people, nof_joints, -1)
        ys, xs = np.unravel_index(flat.argmax(axis=2), (hm_h, hm_w))

        boxes = boxes[:nof_people]
        box_h = (boxes[:, 3] - boxes[:, 1])[:, None]
        box_w = (boxes[:, 2] - boxes[:, 0])[:, None]

        pts = np.empty((nof_people, nof_joints, 3), dtype=np.float32)
        # 0: pt_y / (height // 4) * (bb_y2 - bb_y1) + bb_y1
        # 1: pt_x / (width // 4) * (bb_x2 - bb_x1) + bb_x1
        # 2: confidences
        pts[:, :, 0] = ys * 1. / hm_h * box_h + boxes[:, 1][:, None]
        pts[:, :, 1] = xs * 1. / hm_w * box_w + boxes[:, 0][:, None]
        pts[:, :, 2] = flat.max(axis=2)
        return pts

    def _pack_results(self, heatmaps, boxes, pts):
        """
        Assembles the return value of the predict methods according to the `return_*` flags.

        Returns:
            `pts` alone when no extra output is requested, otherwise a list ordered as
            (heatmaps, bounding boxes, human joints) that contains only the requested entries.
        """
        res = list()
        if self.return_heatmaps:
            res.append(heatmaps)
        if self.return_bounding_boxes:
            res.append(boxes)
        res.append(pts)

        if len(res) > 1:
            return res
        else:
            return res[0]

    def _predict_single(self, image):
        """
        Predicts the human pose on a single image with shape=(height, width, BGR color channel).

        Without multiperson support the whole frame is treated as the only person; otherwise one crop
        per detection is fed to the network. See `predict` for the format of the returned value.
        """
        if not self.multiperson:
            old_res = image.shape
            if self.resolution is not None:
                image = cv2.resize(
                    image,
                    (self.resolution[1], self.resolution[0]),  # (width, height)
                    interpolation=self.interpolation
                )

            images = self.transform(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).unsqueeze(dim=0)
            boxes = np.asarray([[0, 0, old_res[1], old_res[0]]], dtype=np.float32)  # [x1, y1, x2, y2]
            heatmaps = np.zeros((1, self.nof_joints, self.resolution[0] // 4, self.resolution[1] // 4),
                                dtype=np.float32)

        else:
            detections = self.detector.predict_single(image)

            nof_people = len(detections) if detections is not None else 0
            boxes = np.empty((nof_people, 4), dtype=np.int32)
            images = torch.empty((nof_people, 3, self.resolution[0], self.resolution[1]))  # (height, width)
            heatmaps = np.zeros((nof_people, self.nof_joints, self.resolution[0] // 4, self.resolution[1] // 4),
                                dtype=np.float32)

            if detections is not None:
                self._crop_people(image, detections, boxes, images)

        if images.shape[0] > 0:
            out = self._forward(images.to(self.device))
            heatmaps[:out.shape[0]] = out
            pts = self._decode_keypoints(out, boxes)
        else:
            pts = np.empty((0, 0, 3), dtype=np.float32)

        return self._pack_results(heatmaps, boxes, pts)

    def _predict_batch(self, images):
        """
        Predicts the human pose on a stack of images with shape=(n, height, width, BGR color channel).

        With multiperson support the people of all the images are processed together and then split
        again per image. See `predict` for the format of the returned value.

        Raises:
            ValueError: if the network input is empty while multiperson support is disabled.
        """
        image_detections = None

        if not self.multiperson:
            old_res = images[0].shape

            if self.resolution is not None:
                images_tensor = torch.empty(images.shape[0], 3, self.resolution[0], self.resolution[1])
            else:
                images_tensor = torch.empty(images.shape[0], 3, images.shape[1], images.shape[2])

            for i, image in enumerate(images):
                if self.resolution is not None:
                    image = cv2.resize(
                        image,
                        (self.resolution[1], self.resolution[0]),  # (width, height)
                        interpolation=self.interpolation
                    )

                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                images_tensor[i] = self.transform(image)

            boxes = np.repeat(
                np.asarray([[0, 0, old_res[1], old_res[0]]], dtype=np.float32), len(images_tensor), axis=0
            )  # [x1, y1, x2, y2]
            heatmaps = np.zeros(
                (len(images_tensor), self.nof_joints, self.resolution[0] // 4, self.resolution[1] // 4),
                dtype=np.float32
            )

        else:
            image_detections = self.detector.predict(images)

            nof_people = int(np.sum([len(d) for d in image_detections if d is not None]))
            boxes = np.empty((nof_people, 4), dtype=np.int32)
            images_tensor = torch.empty((nof_people, 3, self.resolution[0], self.resolution[1]))  # (height, width)
            heatmaps = np.zeros((nof_people, self.nof_joints, self.resolution[0] // 4, self.resolution[1] // 4),
                                dtype=np.float32)

            # The people of all the images are stored one after the other; base_index is the first free row.
            base_index = 0
            for d, detections in enumerate(image_detections):
                if detections is not None and len(detections) > 0:
                    self._crop_people(images[d], detections, boxes, images_tensor, offset=base_index)
                    base_index += len(detections)

        images = images_tensor.to(self.device)

        if images.shape[0] == 0:
            boxes = np.asarray([], dtype=np.int32)
            if not self.multiperson:
                raise ValueError  # should never happen
            pts = [np.zeros((0, self.nof_joints, 3), dtype=np.float32) for _ in image_detections]
            return self._pack_results(heatmaps, boxes, pts)

        out = self._forward(images)
        heatmaps[:out.shape[0]] = out
        pts = self._decode_keypoints(out, boxes)

        if not self.multiperson:
            return self._pack_results(heatmaps, boxes, np.expand_dims(pts, axis=1))

        # re-add the removed batch axis (n)
        heatmaps_batch = []
        boxes_batch = []
        pts_batch = []
        index = 0
        for detections in image_detections:
            if detections is not None:
                stop = index + len(detections)
                pts_batch.append(pts[index:stop])
                heatmaps_batch.append(heatmaps[index:stop])
                boxes_batch.append(boxes[index:stop])
                index = stop
            else:
                pts_batch.append(np.zeros((0, self.nof_joints, 3), dtype=np.float32))
                heatmaps_batch.append(np.zeros((0, self.nof_joints, self.resolution[0] // 4,
                                                self.resolution[1] // 4), dtype=np.float32))
                boxes_batch.append(np.zeros((0, 4), dtype=np.float32))

        return self._pack_results(heatmaps_batch, boxes_batch, pts_batch)
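# NOTE: the text below is web-page boilerplate that is not part of this module (translated from Chinese):
# "If you confirm that the content involves no inappropriate language / pure advertising / violence /
# vulgar or pornographic material / infringement / piracy / false or worthless content, nor anything that
# violates national laws and regulations, click submit to appeal and we will handle it as soon as possible."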