代码拉取完成,页面将自动刷新
from tensorflow.quantization import fake_quant_with_min_max_vars
from tensorflow.python.framework import graph_util
import math
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
tf.disable_v2_behavior()
import torch
np.random.seed(0)
block_h, block_w = 2, 2
batch = 1
n, h, w, c = block_h*block_w*batch, 5, 5, 4
crops = [[2, 0], [2, 0]]
h_crop = block_h*h-crops[0][0]-crops[0][1]
w_crop = block_w*w-crops[1][0]-crops[1][1]
input = np.random.randint(0, 255, (n, h, w, c), dtype=np.int32)
print("np input: ", input.shape, input)
output = tf.batch_to_space_nd(input, [block_h, block_w], crops, name="test")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
output_gnd = output.eval(session=sess)
print("tf output shape: ", output_gnd.shape, "output data: ", output_gnd)
output1 = np.reshape(input, (block_h, block_w, batch, *input.shape[1:]))
output2 = np.transpose(output1, (2, 3, 0, 4, 1, 5))
output3 = np.reshape(output2, [batch, h*block_h, w*block_w, c])
output4 = output3[:,crops[0][0]:h_crop+crops[0][0],crops[1][0]:w_crop+crops[1][0],:]
print("np output shape: ", output4.shape, "output *data: ", output4)
print("tf.batch_to_space_nd vs np: ", np.allclose(output4, output_gnd))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。