1 Star 0 Fork 0

wandongdong/tflite2opencl

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test_batch_to_space_nd.py 1.22 KB
一键复制 编辑 原始数据 按行查看 历史
wandongdong 提交于 2020-11-06 00:56 . fix batch_to_space bug
from tensorflow.quantization import fake_quant_with_min_max_vars
from tensorflow.python.framework import graph_util
import math
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
tf.disable_v2_behavior()
import torch
np.random.seed(0)
block_h, block_w = 2, 2
batch = 1
n, h, w, c = block_h*block_w*batch, 5, 5, 4
crops = [[2, 0], [2, 0]]
h_crop = block_h*h-crops[0][0]-crops[0][1]
w_crop = block_w*w-crops[1][0]-crops[1][1]
input = np.random.randint(0, 255, (n, h, w, c), dtype=np.int32)
print("np input: ", input.shape, input)
output = tf.batch_to_space_nd(input, [block_h, block_w], crops, name="test")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
output_gnd = output.eval(session=sess)
print("tf output shape: ", output_gnd.shape, "output data: ", output_gnd)
output1 = np.reshape(input, (block_h, block_w, batch, *input.shape[1:]))
output2 = np.transpose(output1, (2, 3, 0, 4, 1, 5))
output3 = np.reshape(output2, [batch, h*block_h, w*block_w, c])
output4 = output3[:,crops[0][0]:h_crop+crops[0][0],crops[1][0]:w_crop+crops[1][0],:]
print("np output shape: ", output4.shape, "output *data: ", output4)
print("tf.batch_to_space_nd vs np: ", np.allclose(output4, output_gnd))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ddwsky/tflite2opencl.git
git@gitee.com:ddwsky/tflite2opencl.git
ddwsky
tflite2opencl
tflite2opencl
master

搜索帮助