代码拉取完成,页面将自动刷新
import common
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import cv2
from model.DBFace import DBFace
# True when torch reports a usable CUDA device; the demos below use it to decide
# whether to move the model and input tensors to the GPU.
HAS_CUDA = torch.cuda.is_available()
print("HAS_CUDA = " + str(HAS_CUDA))
def nms(objs, iou=0.5):
    """Greedy non-maximum suppression.

    Each object must expose a ``score`` attribute and an ``iou(other)`` method.
    Candidates are visited from highest to lowest score. A candidate that has
    not been suppressed is kept, and it then suppresses every later candidate
    whose IoU with it is strictly greater than ``iou``.

    Args:
        objs: Sequence of detections, or ``None``.
        iou: Overlap threshold above which a lower-scoring object is dropped.

    Returns:
        ``objs`` itself when it is ``None`` or has fewer than two elements.
        Otherwise, a new list of the kept objects in descending score order.
    """
    if objs is None or len(objs) <= 1:
        return objs

    ranked = sorted(objs, key=lambda candidate: candidate.score, reverse=True)
    suppressed = [False] * len(ranked)
    survivors = []
    for i, anchor in enumerate(ranked):
        if suppressed[i]:
            continue
        survivors.append(anchor)
        # Only later (lower-scoring) candidates can be suppressed by `anchor`.
        for k in range(i + 1, len(ranked)):
            if not suppressed[k] and anchor.iou(ranked[k]) > iou:
                suppressed[k] = True
    return survivors
def detect(model, image, threshold=0.4, nms_iou=0.5):
    """Run the detector on one image and decode its outputs into face boxes.

    Args:
        model: Network called as ``model(tensor)`` that returns a
            ``(heatmap, box, landmark)`` triple for a batch of one image.
        image: H x W x 3 image array, e.g. from ``common.imread`` or a
            ``cv2.VideoCapture`` frame. It is padded with ``common.pad`` and
            mean/std-normalized before inference.
        threshold: Minimum heatmap peak score for a detection to be kept.
        nms_iou: IoU threshold passed to ``nms`` for suppressing overlaps.

    Returns:
        List of ``common.BBox`` objects as returned by ``nms``. Box and
        landmark coordinates are feature-map positions scaled by the output
        stride (4). They are presumably in the padded image's pixel space;
        confirm against ``common.pad``.
    """
    mean = [0.408, 0.447, 0.47]
    std = [0.289, 0.274, 0.278]
    # Maximum number of heatmap peaks considered before thresholding.
    max_candidates = 1000
    # Decoded coordinates are multiplied by this to map back to input pixels.
    stride = 4

    image = common.pad(image)
    image = ((image / 255.0 - mean) / std).astype(np.float32)
    # HWC -> CHW, then add a leading batch dimension of size 1.
    image = image.transpose(2, 0, 1)
    torch_image = torch.from_numpy(image)[None]
    if HAS_CUDA:
        torch_image = torch_image.cuda()

    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        hm, box, landmark = model(torch_image)
        # A cell is a peak when it equals the max of its 3x3 neighbourhood.
        hm_pool = F.max_pool2d(hm, 3, 1, 1)
        peaks = ((hm == hm_pool).float() * hm).view(1, -1).cpu()
        # topk raises if k exceeds the number of heatmap cells (small inputs).
        k = min(max_candidates, peaks.shape[1])
        scores, indices = peaks.topk(k)

        # Take batch element 0 explicitly instead of squeeze(): squeeze() would
        # also drop size-1 spatial dims and turn k == 1 into a 0-d array.
        scores = scores[0].numpy()
        indices = indices[0].numpy()
        box = box[0].cpu().numpy()
        landmark = landmark[0].cpu().numpy()

    hm_width = hm.shape[3]
    # Integer floor division in numpy, independent of torch's integer `/`.
    ys = indices // hm_width
    xs = indices % hm_width

    objs = []
    for cx, cy, score in zip(xs, ys, scores):
        # topk returns scores in descending order, so the rest are lower too.
        if score < threshold:
            break
        x, y, r, b = box[:, cy, cx]
        xyrb = (np.array([cx, cy, cx, cy]) + [-x, -y, r, b]) * stride
        x5y5 = landmark[:, cy, cx]
        # First five values are x offsets, last five are y offsets.
        x5y5 = (common.exp(x5y5 * 4) + ([cx] * 5 + [cy] * 5)) * stride
        box_landmark = list(zip(x5y5[:5], x5y5[5:]))
        objs.append(common.BBox(0, xyrb=xyrb, score=score, landmark=box_landmark))
    return nms(objs, iou=nms_iou)
def detect_image(model, file):
    """Detect faces in the image at *file* and save an annotated copy.

    Every detection from ``detect`` is drawn onto the image. The result is
    written to ``detect_result/<name>.draw.jpg``, where ``<name>`` is
    ``common.file_name_no_suffix(file)``.
    """
    img = common.imread(file)
    for face in detect(model, img):
        common.drawbbox(img, face)
    out_path = "detect_result/" + common.file_name_no_suffix(file) + ".draw.jpg"
    common.imwrite(out_path, img)
def image_demo():
    """Load the DBFace weights and annotate the two sample images in ``datas/``.

    The model is put in eval mode, moved to the GPU when ``HAS_CUDA`` is set,
    and loaded from ``model/dbface.pth``. Each image is then processed by
    ``detect_image``.
    """
    model = DBFace()
    model.eval()
    if HAS_CUDA:
        model.cuda()
    model.load("model/dbface.pth")

    sample_paths = (
        "datas/selfie.jpg",
        "datas/12_Group_Group_12_Group_Group_12_728.jpg",
    )
    for path in sample_paths:
        detect_image(model, path)
def camera_demo():
    """Run live face detection on camera index 0 and display the results.

    Frames are requested at 640x480. Detections are drawn on each frame and
    shown in a window titled "demo DBFace". The loop ends when 'q' is pressed
    or a frame cannot be read. The capture device and OpenCV windows are
    released even if detection or display raises (including Ctrl+C).
    """
    dbface = DBFace()
    dbface.eval()
    if HAS_CUDA:
        dbface.cuda()
    dbface.load("model/dbface.pth")

    cap = cv2.VideoCapture(0)
    try:
        # Requested capture size.
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        ok, frame = cap.read()
        while ok:
            objs = detect(dbface, frame)
            for obj in objs:
                common.drawbbox(frame, obj)
            cv2.imshow("demo DBFace", frame)
            # Mask to the low byte so the key code compares with ord().
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            ok, frame = cap.read()
    finally:
        # Always free the camera and windows, even on an exception.
        cap.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Run the still-image demo first, then the interactive camera demo.
    for demo in (image_demo, camera_demo):
        demo()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。