
wandongdong/tflite2opencl

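This repository does not declare an open-source license (LICENSE) file; before using it, check the project description and the licenses of its upstream code dependencies.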
run_onnx_model.py 1.38 KB
wandongdong committed on 2021-01-20 23:09: add run_onnx_model
import numpy as np
import onnxruntime as ort

# Paths to the ONNX model and to the two calibration files this script generates.
model_path = "../models/Q_face_recognition.onnx"
input_file = "../models/Q_face_recognition.onnx.ms.bin"
output_file = "../models/Q_face_recognition.onnx.ms.out"

# Random float32 input in [0, 1) with the model's NCHW shape.
N, C, H, W = 1, 3, 112, 112
input_shape = (N, C, H, W)
# input_data = np.random.randint(0, 255, input_shape).astype(np.float32)
input_data = np.random.rand(*input_shape).astype(np.float32)
input_name = 'input'  # must match the input name declared in the ONNX graph

# Save the same input as raw float32 bytes, transposed to NHWC layout.
with open(input_file, 'wb') as fo:
    fo.write(input_data.transpose(0, 2, 3, 1).tobytes())

# Run the model on the CPU with ONNX Runtime; passing None returns every output.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
outputs = session.run(None, {input_name: input_data})

# For each output, write a header line "<name> <ndim> <dim0> <dim1> ..."
# followed by one line of space-separated values.
with open(output_file, 'w') as text_file:
    for output_data in outputs:
        flatten_data = np.squeeze(output_data).flatten()
        # text = output_detail['name'] + ' ' + str(len(output_detail['shape'])) + ' '
        text = 'Reshape_output' + ' ' + str(len(output_data.shape)) + ' '  # output name is hard-coded
        for dim in output_data.shape:
            text += str(dim) + ' '
        text_file.write(text + '\n')
        for value in flatten_data:
            text_file.write(str(value) + ' ')
        text_file.write('\n')
        print("output_data: ", output_data)
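The script writes two files: the input as raw float32 bytes in NHWC order, and each output as a header line (`<name> <ndim> <dims...>`) followed by a line of space-separated values. A minimal sketch of reading both back, for example to compare another runtime's result against the ONNX Runtime reference, could look like the following. The helper names `load_input_bin`, `load_output_txt`, and `max_abs_diff` are hypothetical and not part of this repository, and the default NHWC shape assumes the 1x3x112x112 input used above.

```python
import numpy as np


def load_input_bin(path, shape_nhwc=(1, 112, 112, 3)):
    """Read the raw float32 NHWC input file and return it in NCHW order."""
    data = np.fromfile(path, dtype=np.float32).reshape(shape_nhwc)
    return data.transpose(0, 3, 1, 2)


def load_output_txt(path):
    """Parse line pairs: '<name> <ndim> <d0> <d1> ...' then the values.

    Returns a list of (name, array) tuples; a list is used because the
    script writes the same hard-coded name for every output.
    """
    with open(path) as f:
        lines = f.read().splitlines()
    tensors = []
    for header, values in zip(lines[0::2], lines[1::2]):
        fields = header.split()
        name, ndim = fields[0], int(fields[1])
        shape = tuple(int(d) for d in fields[2:2 + ndim])
        data = np.array([float(v) for v in values.split()], dtype=np.float32)
        tensors.append((name, data.reshape(shape)))
    return tensors


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two same-sized tensors."""
    a = np.asarray(a, dtype=np.float32).ravel()
    b = np.asarray(b, dtype=np.float32).ravel()
    if a.size != b.size:
        raise ValueError(f"size mismatch: {a.size} vs {b.size}")
    return float(np.max(np.abs(a - b)))
```

With the script's default paths, `load_output_txt("../models/Q_face_recognition.onnx.ms.out")` would return the reference outputs. Parsing the values through `float` and casting back to float32 should recover the written float32 values exactly, since NumPy prints float32 scalars with enough digits to round-trip.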
Clone (HTTPS): https://gitee.com/ddwsky/tflite2opencl.git
Clone (SSH): git@gitee.com:ddwsky/tflite2opencl.git
Branch: master