1 Star 0 Fork 0

gvraky/DBFace

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
main.py 3.29 KB
一键复制 编辑 原始数据 按行查看 历史
chenj 提交于 2020-05-27 19:34 . finish ncnn dbface
import common
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import cv2
from model.DBFace import DBFace
# Decide once, at import time, whether a CUDA device is usable; detect() and
# the demos below consult this flag before calling .cuda().
HAS_CUDA = torch.cuda.is_available()
print("HAS_CUDA =", HAS_CUDA)
def nms(objs, iou=0.5):
    """Greedy non-maximum suppression.

    Candidates are visited from highest to lowest ``score``. Each one that has
    not already been suppressed is kept, and every lower-scoring candidate
    whose overlap with it (``kept.iou(other)``) exceeds ``iou`` is suppressed.

    Args:
        objs: sequence of objects exposing a ``score`` attribute and an
            ``iou(other)`` method, or ``None``.
        iou: overlap threshold; strictly greater values are suppressed.

    Returns:
        ``objs`` itself when it is ``None`` or has at most one element;
        otherwise a new list of survivors ordered by descending score.
    """
    if objs is None or len(objs) <= 1:
        return objs

    ranked = sorted(objs, key=lambda candidate: candidate.score, reverse=True)
    suppressed = [False] * len(ranked)
    survivors = []
    for i, kept in enumerate(ranked):
        if suppressed[i]:
            continue
        survivors.append(kept)
        for j, other in enumerate(ranked[i + 1:], start=i + 1):
            if not suppressed[j] and kept.iou(other) > iou:
                suppressed[j] = True
    return survivors
def detect(model, image, threshold=0.4, nms_iou=0.5):
    """Run the DBFace network on one image and decode face detections.

    Args:
        model: callable network; ``model(tensor)`` must return a
            ``(heatmap, box, landmark)`` triple of 4-D tensors with batch
            size 1. The demos put it in eval mode and on the GPU when
            ``HAS_CUDA`` is true before calling this.
        image: H x W x 3 image array (presumably BGR uint8 from
            ``common.imread`` / ``cv2`` -- TODO confirm).
        threshold: minimum heatmap peak score for a candidate to be decoded.
        nms_iou: IoU above which a lower-scoring overlapping box is dropped.

    Returns:
        List of ``common.BBox`` detections after non-maximum suppression.
    """
    mean = [0.408, 0.447, 0.47]
    std = [0.289, 0.274, 0.278]
    max_candidates = 1000
    stride = 4  # scale factor from output-map coordinates to image pixels

    # Preprocess: pad, normalise per channel, HWC -> NCHW with batch size 1.
    image = common.pad(image)
    image = ((image / 255.0 - mean) / std).astype(np.float32)
    image = image.transpose(2, 0, 1)
    torch_image = torch.from_numpy(image)[None]
    if HAS_CUDA:
        torch_image = torch_image.cuda()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        hm, box, landmark = model(torch_image)
        # A cell is a peak when it equals the max of its 3x3 neighbourhood;
        # non-peaks are zeroed before ranking.
        hm_pool = F.max_pool2d(hm, 3, 1, 1)
        peaks = ((hm == hm_pool).float() * hm).view(-1).cpu()
        # topk raises if k exceeds the number of cells (small inputs).
        k = min(max_candidates, peaks.numel())
        scores, indices = peaks.topk(k)

    hm_width = hm.shape[3]
    # Integer row/column recovery, independent of torch's division semantics.
    indices = indices.numpy()
    ys = indices // hm_width
    xs = indices % hm_width
    scores = scores.numpy()
    # Index the batch dim explicitly; squeeze() would also drop spatial dims of 1.
    box = box[0].cpu().numpy()
    landmark = landmark[0].cpu().numpy()

    objs = []
    for cx, cy, score in zip(xs, ys, scores):
        if score < threshold:
            break  # topk returns scores in descending order
        # Box channels are distances from the peak to the left/top/right/bottom edges.
        x, y, r, b = box[:, cy, cx]
        xyrb = (np.array([cx, cy, cx, cy]) + [-x, -y, r, b]) * stride
        # First 5 channels are landmark x offsets, last 5 are y offsets; they
        # are decoded through common.exp (opaque helper -- TODO confirm formula).
        x5y5 = landmark[:, cy, cx]
        x5y5 = (common.exp(x5y5 * 4) + ([cx] * 5 + [cy] * 5)) * stride
        box_landmark = list(zip(x5y5[:5], x5y5[5:]))
        objs.append(common.BBox(0, xyrb=xyrb, score=score, landmark=box_landmark))
    return nms(objs, iou=nms_iou)
def detect_image(model, file, out_dir="detect_result"):
    """Detect faces in an image file and save an annotated copy.

    The annotated image is written to
    ``<out_dir>/<file name without suffix>.draw.jpg``.

    Args:
        model: network accepted by ``detect``.
        file: path of the input image.
        out_dir: directory for the annotated output; created if missing.
    """
    import os

    image = common.imread(file)
    objs = detect(model, image)
    for obj in objs:
        common.drawbbox(image, obj)
    # cv2.imwrite returns False instead of raising when the target directory
    # does not exist; common.imwrite presumably wraps it (TODO confirm), so
    # make sure the directory is there.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, common.file_name_no_suffix(file) + ".draw.jpg")
    common.imwrite(out_path, image)
def _load_model(weights="model/dbface.pth"):
    """Build DBFace in eval mode, move it to the GPU if available, load weights."""
    dbface = DBFace()
    dbface.eval()
    if HAS_CUDA:
        dbface.cuda()
    dbface.load(weights)
    return dbface


def image_demo():
    """Run detection on the bundled sample images and save annotated copies."""
    dbface = _load_model()
    detect_image(dbface, "datas/selfie.jpg")
    detect_image(dbface, "datas/12_Group_Group_12_Group_Group_12_728.jpg")


def camera_demo():
    """Run live detection on camera 0 until 'q' is pressed or a read fails."""
    dbface = _load_model()
    cap = cv2.VideoCapture(0)
    try:
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        ok, frame = cap.read()
        while ok:
            objs = detect(dbface, frame)
            for obj in objs:
                common.drawbbox(frame, obj)
            cv2.imshow("demo DBFace", frame)
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            ok, frame = cap.read()
    finally:
        # Release the camera and close windows even if detection raises.
        cap.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Offline sample images first, then the interactive webcam loop.
    for demo in (image_demo, camera_demo):
        demo()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/gvraky/DBFace.git
git@gitee.com:gvraky/DBFace.git
gvraky
DBFace
DBFace
master

搜索帮助