1 Star 0 Fork 0

wandongdong/tflite2opencl

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test_space_to_batch_nd.py 1.16 KB
一键复制 编辑 原始数据 按行查看 历史
wandongdong 提交于 2020-10-31 00:28 . init push
from tensorflow.quantization import fake_quant_with_min_max_vars
from tensorflow.python.framework import graph_util
import math
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
tf.disable_v2_behavior()
import torch
np.random.seed(0)
n, h, w, c = 1, 50, 50, 16
pad_h, pad_w = 2, 2
block_h, block_w = 2, 2
input = np.random.randint(0, 255, (n, h, w, c), dtype=np.int32)
print("np input: ", input)
output = tf.space_to_batch_nd(input, [block_h, block_w], [[pad_h, pad_h], [pad_w, pad_w]], name="test")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
output_gnd = output.eval(session=sess)
print("tf output shape: ", output_gnd.shape, "output data: ", output_gnd)
output1 = np.pad(input, [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]])
output2 = np.reshape(output1, [n, (h+pad_h*2)//block_h, block_h, (w+pad_w*2)//block_w, block_w, c])
output3 = np.transpose(output2, (2, 4, 0, 1, 3, 5))
shape = output3.shape
output4 = np.reshape(output3, (np.prod(shape[0:2]), *shape[3:]))
print("np output shape: ", output4.shape, "output data: ", output4)
print("tf.space_to_batch_nd vs np: ", np.allclose(output4, output_gnd))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ddwsky/tflite2opencl.git
git@gitee.com:ddwsky/tflite2opencl.git
ddwsky
tflite2opencl
tflite2opencl
master

搜索帮助