1 Star 0 Fork 1

suifengpiao/onnx-matting-on-loongson

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
main.cpp 15.35 KB
一键复制 编辑 原始数据 按行查看 历史
#include <QCoreApplication>
#include <onnxruntime//onnxruntime_session_options_config_keys.h>
#include <onnxruntime/onnxruntime_cxx_api.h>
#include <QImage>
#include <QFile>
#include <QDebug>
#include <QDir>
#include <QElapsedTimer>
#include <QPainter>
#include <QProcess>
// Process-wide ONNX Runtime environment used by every Ort::Session created in this file.
// main() replaces it with an instance that has an explicit log level and log id
// before any of the test functions run.
Ort::Env env;
int cutoutTest(const QString& modPath, const QImage& img, const QString& outPath)
{
QFile file(modPath);
file.open(QFile::ReadOnly);
QByteArray model = file.readAll();
Ort::SessionOptions session_options(nullptr);
Ort::Session session(env, model.data(), model.count(), session_options);
QElapsedTimer time;
time.start();
Ort::RunOptions runOpt(nullptr);
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
size_t num_input_nodes = session.GetInputCount();
size_t num_output_nodes = session.GetOutputCount();
Ort::AllocatorWithDefaultOptions allocator;
auto inputName = session.GetInputNameAllocated(0, allocator);
auto outputName = session.GetOutputNameAllocated(0, allocator);
//qInfo() << "Input node count:" << num_input_nodes << ">" << inputName.get() << ", output node count:" << num_output_nodes << ">" << outputName.get();
auto input_dims = session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
auto output_dims = session.GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
input_dims[0] = img.height();
input_dims[1] = img.width();
std::vector<const char*> input_node_names = { inputName.get() };
std::vector<const char*> output_node_names = { outputName.get() };
int pixCount = img.width() * img.height();
float* pixData = new float[pixCount * 3];
float* pixArry = pixData;
for (int y = 0; y < img.height(); ++y)
{
const uint8_t* pix = img.scanLine(y);
for (int x = 0; x < img.width(); ++x)
{
*pixArry++ = pix[0];
*pixArry++ = pix[1];
*pixArry++ = pix[2];
pix += 3;
}
}
std::vector<Ort::Value> inputData;
inputData.emplace_back( Ort::Value::CreateTensor<float>(memory_info, pixData, pixCount * 3, input_dims.data(), input_dims.size()));
auto output_tensors = session.Run(runOpt, input_node_names.data(), inputData.data(), inputData.size(), output_node_names.data(), output_node_names.size());
auto output_fmt = output_tensors[0].GetTensorTypeAndShapeInfo().GetShape();
delete []pixData;
auto outPix = output_tensors[0].GetTensorMutableData<uint8_t>();
int ms = time.elapsed();
QImage(outPix, output_fmt[1], output_fmt[0], QImage::Format_RGBA8888).save(outPath, nullptr, 100);
// QImage imgOut(output_fmt[1], output_fmt[0], QImage::Format_RGBA8888);
// pixCount = output_fmt[0] * output_fmt[1];
// uint32_t* pngPix = (uint32_t*)imgOut.bits();
// for (int i = 0; i < pixCount; ++i)
// {
// pngPix[i] = ((outPix[i] & 0xFF000000) == 0) ? 0 : outPix[i];
// }
// imgOut.save(outPath);
// auto outPix = output_tensors[0].GetTensorMutableData<float>();
// img = QImage(output_fmt[1], output_fmt[0], QImage::Format_Grayscale8);
// // pixCount = output_fmt[0] * output_fmt[1];
// for (int y = 0; y < img.height(); ++y)
// {
// uint8_t* pix = img.scanLine(y);
// float* png = (outPix + img.width() * y);
// for (int x = 0; x < img.width(); ++x)
// {
// pix[x] = png[x];
// }
// }
return ms;
}
int detectionTest(const QString& modPath, const QImage& img, const QString& outPath, std::vector<QRect>& outFaces)
{
QFile file(modPath);
file.open(QFile::ReadOnly);
QByteArray model = file.readAll();
Ort::SessionOptions session_options(nullptr);
Ort::Session session(env, model.data(), model.count(), session_options);
QElapsedTimer time;
time.start();
Ort::RunOptions runOpt(nullptr);
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
std::vector<const char*> input_node_names;
std::vector<const char*> output_node_names;
Ort::AllocatorWithDefaultOptions allocator;
size_t num_input_nodes = session.GetInputCount();
for (size_t i = 0; i < num_input_nodes; ++i )
{
auto inputName = session.GetInputNameAllocated(i, allocator);
input_node_names.push_back(strdup(inputName.get()));
}
size_t num_output_nodes = session.GetOutputCount();
for (size_t i = 0; i < num_output_nodes; ++i )
{
auto output_dims = session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
auto outputName = session.GetOutputNameAllocated(i, allocator);
output_node_names.push_back(strdup(outputName.get()));
}
//qInfo() << "Input node count:" << num_input_nodes << ">" << inputName.get() << ", output node count:" << num_output_nodes << ">" << outputName.get();
auto input_dims = session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
QImage detImg;
if (img.width() > 1920 || img.height() > 1920 || img.width() % 4 != 0 )
{
QSize newSize = (img.width() <= 1920 && img.height() <= 1920) ? img.size() : img.size().scaled(1920, 1920, Qt::KeepAspectRatio);
if (newSize.width() % 4 != 0)
{
newSize = img.size().scaled(newSize.width() / 4 * 4, newSize.height(), Qt::KeepAspectRatio);
}
detImg = img.scaled(newSize);
}
else
{
detImg = img;
}
float scale = img.width() / float(detImg.width());
QPainter pnt(&detImg);
input_dims[0] = detImg.height();
input_dims[1] = detImg.width();
std::vector<Ort::Value> inputData;
inputData.emplace_back( Ort::Value::CreateTensor<uint8_t>(memory_info, detImg.bits(), detImg.sizeInBytes(), input_dims.data(), input_dims.size()));
auto output_tensors = session.Run(runOpt, input_node_names.data(), inputData.data(), inputData.size(), output_node_names.data(), output_node_names.size());
for (auto m : input_node_names)
{
if (m) free((void*)m);
}
for (auto m : output_node_names)
{
if (m) free((void*)m);
}
auto sco_shape = output_tensors[0].GetTensorTypeAndShapeInfo().GetShape();
auto box_shape = output_tensors[1].GetTensorTypeAndShapeInfo().GetShape();
auto box_data = output_tensors[1].GetTensorMutableData<int64_t>();
auto kps_shape = output_tensors[2].GetTensorTypeAndShapeInfo().GetShape();
auto kps_data = output_tensors[2].GetTensorMutableData<int64_t>();
auto img_shape = output_tensors[3].GetTensorTypeAndShapeInfo().GetShape();
auto img_data = output_tensors[3].GetTensorMutableData<uint8_t>();
auto lmk_shape = output_tensors[4].GetTensorTypeAndShapeInfo().GetShape();
auto lmk_data = output_tensors[4].GetTensorMutableData<int64_t>();
for (int i = 0; i < sco_shape[0]; ++i)
{
pnt.setPen(QPen(QBrush(QColor(250, 30, 75)), 3));
QRect rt(box_data[0], box_data[1], box_data[2] - box_data[0], box_data[3] - box_data[1]);
pnt.drawRect(rt);
rt.setRect(rt.x() * scale, rt.y() * scale, rt.width() * scale, rt.height() * scale);
outFaces.push_back(rt);
box_data += 4;
pnt.setPen(QPen(QBrush(QColor(240, 230, 33)), 2));
for (int j = 0; j < 106; ++j)
{
pnt.drawEllipse(lmk_data[0], lmk_data[1], 2, 2);
pnt.drawEllipse(lmk_data[0], lmk_data[1], 3, 3);
lmk_data += 2;
}
pnt.setPen(QPen(QBrush(QColor(40, 230, 50)), 2));
for (int j = 0; j < 5; ++j)
{
pnt.drawEllipse(kps_data[0], kps_data[1], 2, 2);
kps_data += 2;
}
// QImage(img_data, img_shape[1], img_shape[2], img_shape[1] * img_shape[3], QImage::Format_RGB888).save(QString("%1-face-%2.jpg").arg(outPath).arg(i));
// img_data += img_shape[1] * img_shape[2] * img_shape[3];
}
int ms = time.elapsed();
if (sco_shape[0])
detImg.save(outPath, nullptr, 100);
return ms;
}
int parsingTest(const QString& modPath, const QImage& img, const std::vector<QRect>& faces, const QString& outPath)
{
QColor partColors[] = {{255, 255, 255, 255}, {255, 85, 0, 255}, {255, 170, 0, 255},
{255, 0, 85, 255}, {255, 0, 170, 255},
{0, 255, 0, 255}, {85, 255, 0, 255}, {170, 255, 0, 255},
{0, 255, 85, 255}, {0, 255, 170, 255},
{0, 0, 255, 255}, {85, 0, 255, 255}, {170, 0, 255, 255},
{0, 85, 255, 255}, {0, 170, 255, 255},
{255, 255, 0, 255}, {255, 255, 85, 255}, {255, 255, 170, 255},
{255, 0, 255, 255}, {255, 85, 255, 255}, {255, 170, 255, 255},
{0, 255, 255, 255}, {85, 255, 255, 255}, {170, 255, 255, 255}};
QFile file(modPath);
file.open(QFile::ReadOnly);
QByteArray model = file.readAll();
Ort::SessionOptions session_options(nullptr);
Ort::Session session(env, model.data(), model.count(), session_options);
QElapsedTimer time;
time.start();
Ort::RunOptions runOpt(nullptr);
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
size_t num_input_nodes = session.GetInputCount();
size_t num_output_nodes = session.GetOutputCount();
Ort::AllocatorWithDefaultOptions allocator;
auto inputName = session.GetInputNameAllocated(0, allocator);
auto outputName = session.GetOutputNameAllocated(0, allocator);
//qInfo() << "Input node count:" << num_input_nodes << ">" << inputName.get() << ", output node count:" << num_output_nodes << ">" << outputName.get();
auto input_dims = session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
auto output_dims = session.GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
std::vector<const char*> input_node_names = { inputName.get() };
std::vector<const char*> output_node_names = { outputName.get() };
int ms = time.elapsed();
for (int i = 0; i < faces.size(); ++i)
{
time.restart();
QPoint center = faces[i].center();
int radius = std::max(faces[i].width(), faces[i].height());
QRect rtFace;
rtFace.setLeft(std::max(center.x() - radius, 0));
rtFace.setRight(std::min(center.x() + radius - 1, img.width() - 1));
rtFace.setTop(std::max(center.y() - radius, 0));
rtFace.setBottom(std::min(center.y() + radius - 1, img.height() - 1));
center = rtFace.center();
radius = std::min(rtFace.width(), rtFace.height()) / 2;
rtFace.setLeft(std::max(center.x() - radius, 0));
rtFace.setRight(std::min(center.x() + radius - 1, img.width() - 1));
rtFace.setTop(std::max(center.y() - radius, 0));
rtFace.setBottom(std::min(center.y() + radius - 1, img.height() - 1));
QImage img512(input_dims[3], input_dims[2], QImage::Format_RGB888);
QPainter pnt(&img512);
pnt.drawImage(QRect(0,0,input_dims[3], input_dims[2]), img, rtFace);
int pixCount = img512.width() * img512.height();
float* pixData = new float[pixCount * 3];
float* pixArrR = pixData;
float* pixArrG = pixData + pixCount;
float* pixArrB = pixData + pixCount * 2;
for (int y = 0; y < img512.height(); ++y)
{
const uint8_t* pix = img512.scanLine(y);
for (int x = 0; x < img512.width(); ++x)
{
*pixArrR++ = (pix[0] / 255.0f - 0.485f) / 0.229;
*pixArrG++ = (pix[1] / 255.0f - 0.456f) / 0.224;
*pixArrB++ = (pix[2] / 255.0f - 0.406f) / 0.225;
pix += 3;
}
}
std::vector<Ort::Value> inputData;
inputData.emplace_back( Ort::Value::CreateTensor<float>(memory_info, pixData, pixCount * 3, input_dims.data(), input_dims.size()));
delete []pixData;
auto output_tensors = session.Run(runOpt, input_node_names.data(), inputData.data(), inputData.size(), output_node_names.data(), output_node_names.size());
auto outPix = output_tensors[0].GetTensorMutableData<float>();
for (int y = 0; y < output_dims[2]; ++y)
{
uint8_t* pix = img512.scanLine(y);
for (int x = 0; x < output_dims[3] ; ++x)
{
float maxv = -99;
int imax = 0;
for (int p = 0; p < output_dims[1] ; ++p)
{
float v = outPix[x + 512 * 512 * p];
if (v > maxv)
{
maxv = v;
imax = p;
}
}
pix[2] = partColors[imax].red();
pix[1] = partColors[imax].green();
pix[0] = partColors[imax].blue();
pix += 3;
}
outPix += 512;
}
ms += time.elapsed();
img512.save(outPath + QString("-face-%1.jpg").arg(i + 1), nullptr, 100);
}
return ms;
}
/// Returns the CPU model name reported by the `lscpu` command.
///
/// The value is the text after the first ':' on the first output line that
/// contains "Model name", with surrounding whitespace removed.
///
/// @return The model name, or "Unknow CPU" when `lscpu` produces no such line or
///         the line carries no value.
QString getCpuName()
{
    QString cpuName = "Unknow CPU";
    // On UOS/Deepin, /proc/cpuinfo can be displayed with cat, but it shows up as an
    // empty file when opened directly, so the information is taken from lscpu.
    QProcess proc;
    proc.setProgram("lscpu");
    proc.start(QProcess::ReadOnly);
    proc.waitForFinished(60000);
    const QByteArray cpuInfo = proc.readAll();

    const QByteArray key("Model name");
    const int keyPos = cpuInfo.indexOf(key);
    if (keyPos < 0)
    {
        return cpuName;
    }

    // The value runs to the end of the line, or to the end of the output when the
    // last line has no trailing newline.
    int lineEnd = cpuInfo.indexOf('\n', keyPos);
    if (lineEnd < 0)
    {
        lineEnd = cpuInfo.size();
    }

    // Split at the first ':' only, so a model name that itself contains ':' is kept whole.
    const int colonPos = cpuInfo.indexOf(':', keyPos + key.size());
    if (colonPos < 0 || colonPos >= lineEnd)
    {
        return cpuName;
    }

    const QByteArray value = cpuInfo.mid(colonPos + 1, lineEnd - colonPos - 1).trimmed();
    if (!value.isEmpty())
    {
        cpuName = QString::fromUtf8(value);
    }
    return cpuName;
}
int main(int argc, char *argv[])
{
QCoreApplication a(argc, argv);
if (argc == 2)
{
QElapsedTimer time;
time.start();
QString filePath = QDir::fromNativeSeparators(argv[1]);
QImage img(filePath);
if (img.isNull())
{
qCritical() << "加载图像文件失败!";
return 1;
}
if (img.format() != QImage::Format_RGB888)
{
img = img.convertToFormat(QImage::Format_RGB888);
}
env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "OnnxModel");
QString modPath_matting = QCoreApplication::applicationDirPath() + "/matting-tensorflow.onnx";
QString modPath_detection = QCoreApplication::applicationDirPath() + "/detection-insightface.onnx";
QString modPath_parsing = QCoreApplication::applicationDirPath() + "/parsing-pytorch.onnx";
int pos = filePath.lastIndexOf('.');
filePath = (pos >= 0) ? filePath.left(pos) : filePath;
QString outPath_matting = filePath + "-matting.png";
QString outPath_detection = filePath + "-detection.jpg";
QString outPath_parsing = filePath + "-parsing";
std::vector<QRect> faces;
qInfo().noquote() << "CPU:" << getCpuName();
int ms_detection = detectionTest(modPath_detection, img, outPath_detection, faces);
qInfo() << "人脸检测和关键点标注耗时:" << ms_detection << "ms," << faces.size() << "张脸。";
int ms_parsing = parsingTest(modPath_parsing, img, faces, outPath_parsing);
qInfo() << "人脸区域分割和标注耗时:" << ms_parsing << "ms," << faces.size() << "张脸。";
int ms_matting = cutoutTest(modPath_matting, img, outPath_matting);
qInfo() << "半透明分割人像与背景耗时:" << ms_matting << "ms。" ;
qInfo().noquote() << "总耗时(含文件读写):" << time.elapsed() << QString("ms,图像分辨率:%1x%2").arg(img.width()).arg(img.height());
}
return 0;//a.exec();
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
C++
1
https://gitee.com/piaoddang/onnx-matting-on-loongson.git
git@gitee.com:piaoddang/onnx-matting-on-loongson.git
piaoddang
onnx-matting-on-loongson
onnx-matting-on-loongson
master

搜索帮助