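# NOTE(review): stray web-page residue from the source host, kept as a comment:
# "Code pull finished; the page will refresh automatically."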
from tensorflow.quantization import fake_quant_with_min_max_vars
from tensorflow.python.framework import graph_util
import math
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.keras.layers import dense_attention
np.random.seed(0)

# Test-case table: entry k of every tuple below describes case k, and
# `case_idx` selects the case that is generated.
#   input_shape  - NHWC input shape
#   filter_shape - TF depthwise kernel layout [kH, kW, in_channels, multiplier];
#                  in_channels (index 2) must equal the input's C (index 3)
#   stride_shape - NHWC strides; depthwise_conv2d_native requires
#                  strides[0] == strides[3] == 1
#   pad_mode     - 'SAME' or 'VALID'
input_shape = (
    [1, 5, 5, 6], [1, 4, 4, 4], [1, 4, 4, 4], [1, 4, 4, 4], [1, 3, 3, 5], [1, 112, 112, 24], [1, 112, 112, 96],
    [1, 56, 56, 144])
filter_shape = (
    [3, 3, 6, 1], [1, 1, 4, 1], [2, 2, 4, 1], [3, 3, 4, 1], [3, 3, 5, 1], [3, 3, 24, 1], [3, 3, 96, 1],
    [3, 3, 144, 1])
stride_shape = (
    [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 2, 2, 1],
    [1, 1, 1, 1])
pad_mode = ('SAME', 'SAME', 'VALID', 'VALID', 'SAME', 'SAME', 'SAME', 'SAME')
case_idx = 0

# Reject table entries TF would refuse, before any data is generated.
for _idx, (_in_s, _k_s, _st) in enumerate(zip(input_shape, filter_shape, stride_shape)):
    if _k_s[2] != _in_s[3]:
        raise ValueError("case {}: filter in_channels {} != input channels {}".format(_idx, _k_s[2], _in_s[3]))
    if _st[0] != 1 or _st[3] != 1:
        raise ValueError("case {}: strides must have strides[0] == strides[3] == 1, got {}".format(_idx, _st))

input = np.random.rand(*input_shape[case_idx]).astype(dtype=np.float32)
# The kernel is drawn in the reversed ("ms") layout [multiplier, in_channels, kH, kW]
# and cast to float32 so it has the same dtype as the input fed to TF.
filter = np.random.rand(*reversed(filter_shape[case_idx])).astype(np.float32)
print("np input: ", input.shape, input)
print("np filter ms layout: ", filter.shape, filter)
def Depthwise(inp=None, kernel=None, strides=None, padding=None):
    """Compute a reference depthwise 2-D convolution with NumPy.

    Mirrors ``tf.nn.depthwise_conv2d_native`` for NHWC data with dilation 1.
    It applies TF's 'SAME'/'VALID' output-size and padding rules and the
    height/width strides, so the result can be compared with the TF output.

    Args:
        inp: NHWC input array. Defaults to the module-level ``input``.
        kernel: depthwise kernel in TF layout [kH, kW, C, multiplier]. A 3-D
            [kH, kW, C] kernel is treated as multiplier 1. Defaults to the
            module-level ``filter``.
        strides: NHWC stride list; only strides[1] and strides[2] are used.
            Defaults to ``stride_shape[case_idx]``.
        padding: 'SAME' or 'VALID'. Defaults to ``pad_mode[case_idx]``.

    Returns:
        list of float: the output flattened in N, H, W, C*multiplier order,
        where output channel ``c * multiplier + m`` comes from input channel
        ``c`` and kernel slice ``m`` (TF's channel ordering).

    Raises:
        ValueError: if the kernel's channel count differs from the input's,
            or if ``padding`` is neither 'SAME' nor 'VALID'.
    """
    # Fall back to the script's globals so the original zero-argument call works.
    if inp is None:
        inp = input
    if kernel is None:
        kernel = filter
    if strides is None:
        strides = stride_shape[case_idx]
    if padding is None:
        padding = pad_mode[case_idx]

    inp = np.asarray(inp, dtype=np.float64)
    kernel = np.asarray(kernel, dtype=np.float64)
    if kernel.ndim == 3:
        kernel = kernel[..., np.newaxis]
    batch, in_h, in_w, channels = inp.shape
    k_h, k_w, k_c, multiplier = kernel.shape
    if k_c != channels:
        raise ValueError("kernel has {} channels but input has {}".format(k_c, channels))
    s_h, s_w = strides[1], strides[2]

    if padding == 'SAME':
        # TF: out = ceil(in / stride); pad just enough to cover the last window.
        out_h = -(-in_h // s_h)
        out_w = -(-in_w // s_w)
        pad_h = max((out_h - 1) * s_h + k_h - in_h, 0)
        pad_w = max((out_w - 1) * s_w + k_w - in_w, 0)
    elif padding == 'VALID':
        out_h = max((in_h - k_h) // s_h + 1, 0)
        out_w = max((in_w - k_w) // s_w + 1, 0)
        pad_h = pad_w = 0
    else:
        raise ValueError("unsupported padding mode: {!r}".format(padding))

    if out_h == 0 or out_w == 0:
        return []

    # TF places the smaller half of the padding on the top/left edge.
    padded = np.pad(inp,
                    ((0, 0),
                     (pad_h // 2, pad_h - pad_h // 2),
                     (pad_w // 2, pad_w - pad_w // 2),
                     (0, 0)),
                    mode='constant')
    acc = np.zeros((batch, out_h, out_w, channels, multiplier))
    h_span = (out_h - 1) * s_h + 1
    w_span = (out_w - 1) * s_w + 1
    for kh in range(k_h):
        for kw in range(k_w):
            # Every output position's input sample for kernel tap (kh, kw).
            window = padded[:, kh:kh + h_span:s_h, kw:kw + w_span:s_w, :]
            acc += window[..., np.newaxis] * kernel[kh, kw]
    return acc.reshape(-1).tolist()
# Re-layout the kernel from the generated "ms" layout [multiplier, C, kH, kW]
# (the reverse of filter_shape) into TF's depthwise layout [kH, kW, C, multiplier].
# A plain axis permutation is correct for any channel multiplier. The result is
# made contiguous float32 so its dtype matches the float32 input.
filter = np.ascontiguousarray(filter.transpose(2, 3, 1, 0), dtype=np.float32)
print("np filter tf layout: ", filter.shape, filter)
output = tf.nn.depthwise_conv2d_native(input,
                                       filter,
                                       strides=stride_shape[case_idx],
                                       padding=pad_mode[case_idx],
                                       data_format="NHWC",
                                       dilations=[1, 1, 1, 1])
output_np = Depthwise()
print("numpy output: ", output_np)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output_gnd = output.eval(session=sess)
print("tf output shape: ", output_gnd.shape, "output data: ", output_gnd)

input_file = "../run_lite/models/caffe_depthwise_conv2d_001.ms.bin"
output_file = "../run_lite/models/caffe_depthwise_conv2d_001.ms.out"

# Model input: raw float32 bytes in C order (the NHWC layout of `input`).
with open(input_file, 'wb') as fo:
    fo.write(input.astype(np.float32, copy=False).tobytes())

# Expected output, converted from TF's NHWC result to NCHW and written as text:
#   line 1: "<name> <rank> <dim0> <dim1> ... " (each field followed by a space)
#   line 2: every value, each followed by a space
# A separate name is used so `output_np` keeps the NumPy reference result.
output_nchw = output_gnd.transpose(0, 3, 1, 2)
flatten_data = output_nchw.flatten()
header = ('nhwc2nchw_depthwise_conv2d_post2' + ' ' + str(output_nchw.ndim) + ' '
          + ''.join(str(dim) + ' ' for dim in output_nchw.shape))
with open(output_file, 'w', encoding='utf-8') as text_file:
    text_file.write(header + '\n')
    text_file.write(''.join(str(value) + ' ' for value in flatten_data) + '\n')
print("output_data: ", output_nchw)
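# NOTE(review): stray web-page residue from the source host; the remainder of the
# original file was not shown. Translated notice, kept as comments:
# "Content here may be inappropriate to display, so the page does not show it.
# You can review and modify it with the editing functions.
# If you confirm the content contains no inappropriate language / pure advertising /
# violence / vulgar or pornographic material / infringement / piracy / false or
# worthless content, or content violating national laws and regulations, you can
# click Submit to appeal and we will handle it as soon as possible."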