代码拉取完成,页面将自动刷新
from tensorflow.quantization import fake_quant_with_min_max_vars
from tensorflow.python.framework import graph_util
import math
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
tf.disable_v2_behavior()
import torch
np.random.seed(0)
n, h, w, c = 1, 50, 50, 16
pad_h, pad_w = 2, 2
block_h, block_w = 2, 2
input = np.random.randint(0, 255, (n, h, w, c), dtype=np.int32)
print("np input: ", input)
output = tf.space_to_batch_nd(input, [block_h, block_w], [[pad_h, pad_h], [pad_w, pad_w]], name="test")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
output_gnd = output.eval(session=sess)
print("tf output shape: ", output_gnd.shape, "output data: ", output_gnd)
output1 = np.pad(input, [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]])
output2 = np.reshape(output1, [n, (h+pad_h*2)//block_h, block_h, (w+pad_w*2)//block_w, block_w, c])
output3 = np.transpose(output2, (2, 4, 0, 1, 3, 5))
shape = output3.shape
output4 = np.reshape(output3, (np.prod(shape[0:2]), *shape[3:]))
print("np output shape: ", output4.shape, "output data: ", output4)
print("tf.space_to_batch_nd vs np: ", np.allclose(output4, output_gnd))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。