1 Star 0 Fork 0

wandongdong/tflite2opencl

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test_depthwise.py 3.15 KB
一键复制 编辑 原始数据 按行查看 历史
wandongdong 提交于 2021-01-16 02:59 . update depthwise
from tensorflow.quantization import fake_quant_with_min_max_vars
from tensorflow.python.framework import graph_util
import math
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.keras.layers import dense_attention
# Seed NumPy's global RNG so the random input/filter below are reproducible.
np.random.seed(0)

# Test-case tables: position k in every tuple below describes case k.

# Input tensor shapes; the input is later fed to TensorFlow with
# data_format="NHWC", i.e. [batch, height, width, channels].
input_shape = (
    [1, 5, 5, 6],
    [1, 4, 4, 4],
    [1, 4, 4, 4],
    [1, 4, 4, 4],
    [1, 3, 3, 5],
    [1, 112, 112, 24],
    [1, 112, 112, 96],
    [1, 56, 56, 144],
)

# Depthwise filter shapes as [kernel_h, kernel_w, in_channels, multiplier].
# in_channels (dim 2) must equal the channel count (dim 3) of the matching
# input_shape entry, otherwise the depthwise convolution cannot be built.
filter_shape = (
    [3, 3, 6, 1],
    [1, 1, 4, 1],
    [2, 2, 4, 1],
    [3, 3, 4, 1],
    [3, 3, 5, 1],
    [3, 3, 24, 1],
    [3, 3, 96, 1],
    # Fixed: was [3, 3, 24, 1], which did not match the 144-channel input.
    [3, 3, 144, 1],
)

# Strides in NHWC order. The depthwise op only supports striding over the
# spatial dimensions, so strides[0] and strides[3] must both be 1.
stride_shape = (
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    # Fixed: was [2, 2, 2, 2], which strides over batch and channel too.
    [1, 2, 2, 1],
    [1, 1, 1, 1],
)

# Padding mode passed to TensorFlow for each case.
pad_mode = ('SAME', 'SAME', 'VALID', 'VALID', 'SAME', 'SAME', 'SAME', 'SAME',)

# Index of the case that the rest of the script runs.
case_idx = 0

# NOTE: `input` and `filter` shadow Python builtins; the names are kept
# because the rest of the script refers to them.
# Random float32 input with the selected case's shape.
input = np.random.rand(*input_shape[case_idx]).astype(dtype=np.float32)
# Random filter created with the axes of filter_shape in reverse order
# (the layout the print below labels "ms layout"); stays float64.
filter = np.random.rand(*list(reversed(filter_shape[case_idx])))
print("np input: ", input.shape, input)
print("np filter ms layout: ", filter.shape, filter)
def Depthwise():
    """Compute a naive per-channel convolution of the module-level data.

    Reads the globals ``input``, ``filter``, ``input_shape``,
    ``filter_shape`` and ``case_idx``. ``filter`` is indexed as
    ``filter[kh, kw, c]`` and ``input`` as ``input[0, h, w, c]``.

    For every (row, column, channel) position of the input, the kernel
    window is anchored with its top-left tap on that position; taps that
    would fall past the bottom or right edge of the input are skipped.

    NOTE(review): ``stride_shape`` and ``pad_mode`` are not consulted here,
    so the result may not line up with the TensorFlow output for every
    case -- confirm whether that is intended.

    Returns:
        A flat list with one accumulated value per (row, column, channel),
        appended in that nesting order.
    """
    in_dims = input_shape[case_idx]
    height, width, channels = in_dims[1], in_dims[2], in_dims[3]
    kernel_dims = filter_shape[case_idx]
    kernel_h, kernel_w = kernel_dims[0], kernel_dims[1]

    results = []
    for row in range(height):
        # Only the kernel rows that still land inside the input.
        rows_in_bounds = min(kernel_h, height - row)
        for col in range(width):
            # Only the kernel columns that still land inside the input.
            cols_in_bounds = min(kernel_w, width - col)
            for ch in range(channels):
                acc = 0
                for dy in range(rows_in_bounds):
                    for dx in range(cols_in_bounds):
                        acc += filter[dy, dx, ch] * input[0, row + dy, col + dx, ch]
                results.append(acc)
    return results
# Move axis 1 of the filter to the end, then reinterpret the data with the
# shape listed in filter_shape (the layout the print below labels "tf layout").
filter = np.transpose(filter, (0, 2, 3, 1)).reshape(filter_shape[case_idx])
print("np filter tf layout: ", filter.shape, filter)

# Build the TensorFlow depthwise convolution for the selected case.
output = tf.nn.depthwise_conv2d_native(
    input,
    filter,
    strides=stride_shape[case_idx],
    padding=pad_mode[case_idx],
    data_format="NHWC",
    dilations=[1, 1, 1, 1],
)

# Naive Python result, printed for comparison with the TensorFlow output.
output_np = Depthwise()
print("numpy output: ", output_np)

# Evaluate the graph to obtain the TensorFlow result as a NumPy array.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output_gnd = sess.run(output)
print("tf output shape: ", output_gnd.shape, "output data: ", output_gnd)

input_file = "../run_lite/models/caffe_depthwise_conv2d_001.ms.bin"
output_file = "../run_lite/models/caffe_depthwise_conv2d_001.ms.out"

# Dump the input tensor as raw float32 bytes.
with open(input_file, 'wb') as bin_file:
    bin_file.write(input.astype(np.float32, copy=False))

# Dump the TensorFlow result as text: a header line holding a tag, the rank
# and each dimension, followed by one line of space-separated values.
with open(output_file, 'w') as text_file:
    # Reorder the NHWC result to NCHW before flattening.
    output_np = np.transpose(output_gnd, (0, 3, 1, 2))
    flatten_data = output_np.squeeze().flatten()

    dims = output_np.shape
    header_fields = ['nhwc2nchw_depthwise_conv2d_post2', str(len(dims))]
    header_fields.extend(str(dim) for dim in dims)
    # Every field, including the last one, is followed by a single space.
    text_file.write(' '.join(header_fields) + ' ' + '\n')

    text_file.write(''.join(str(value) + ' ' for value in flatten_data))
    text_file.write('\n')
print("output_data: ", output_np)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ddwsky/tflite2opencl.git
git@gitee.com:ddwsky/tflite2opencl.git
ddwsky
tflite2opencl
tflite2opencl
master

搜索帮助