代码拉取完成,页面将自动刷新
import numpy as np
import torch
import torchvision.models as models
import onnxruntime as ort
model_path = "../models/Q_face_recognition.onnx"
input_file = "../models/Q_face_recognition.onnx.ms.bin"
output_file = "../models/Q_face_recognition.onnx.ms.out"
N, C, H, W = 1, 3, 112, 112
input_shape = N, C, H, W
input = torch.rand(N, C, H, W)
# input = np.random.randint(0, 255, input_shape).astype(dtype=np.float32)
input = np.random.rand(*input_shape).astype(dtype=np.float32)
input_name = 'input'
with open(input_file, 'wb') as fo:
fo.write(input.transpose(0, 2, 3, 1).copy().astype(np.float32, copy=False))
torch.backends.quantized.engine = 'qnnpack'
session = ort.InferenceSession(model_path)
outputs = session.run(None, {input_name: input})
with open(output_file, 'w') as text_file:
for i in range(len(outputs)):
output_data = outputs[i]
flatten_data = np.squeeze(output_data).flatten()
size = flatten_data.shape[0]
# text = output_detail['name'] + ' ' + str(len(output_detail['shape'])) + ' '
text = 'Reshape_output' + ' ' + str(len(output_data.shape)) + ' '
for i in range(len(output_data.shape)):
text += str(output_data.shape[i]) + ' '
text_file.write(text + '\n')
for i in range(size):
text_file.write(str((flatten_data[i])) + ' ')
text_file.write('\n')
print("output_data: ", output_data)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。