# (Web-page residue, not part of the script) Code pull complete; the page will refresh automatically.
# https://github.com/dorthyluu/cs194-winograd
import math
import numpy as np
import torch
from tensorflow.quantization import fake_quant_with_min_max_vars
from tensorflow.python.framework import graph_util
import tensorflow.compat.v1 as tf
# Switch TensorFlow to TF1-style graph mode: the reference convolution further
# down is built as a graph op and evaluated through tf.Session, not eagerly.
tf.disable_eager_execution()
tf.disable_v2_behavior()
# Winograd tile parameters: each m x m output tile is produced from an
# alpha x alpha input tile convolved with an r x r kernel.
m, r = 2, 3
alpha = m + r - 1

# Problem dimensions used by the loops below.
H, W = 8, 8  # output height / width (Y is allocated as (N, K, H, W))
K = 2        # number of filters (output channels)
C = 3        # number of input channels
N = 2        # batch size

# Total number of tiles over the whole batch.
P = N * math.ceil(H / m) * math.ceil(W / m)
# Filter transform matrix (alpha x r): each r x r kernel k is mapped into the
# transform domain as G @ k @ G.T.
G = np.array(
    [
        [1.0, 0.0, 0.0],
        [0.5, 0.5, 0.5],
        [0.5, -0.5, 0.5],
        [0.0, 0.0, 1.0],
    ]
)
# Data transform matrix (alpha x alpha): each input tile d is mapped into the
# transform domain as B.T @ d @ B.
B = np.array(
    [
        [1, 0, 0, 0],
        [0, 1, -1, 1],
        [-1, 1, 1, 0],
        [0, 0, 0, -1],
    ]
)
# Output (inverse) transform matrix (alpha x m): a transform-domain tile t is
# reduced to an m x m output tile as A.T @ t @ A.
A = np.array(
    [
        [1, 0],
        [1, 1],
        [1, -1],
        [0, -1],
    ]
)
# Input batch indexed as D[image, channel, row, col] with shape (2, 3, 10, 10).
# Every row of every channel is the same ramp 1..10, so the array is generated
# by tiling that ramp instead of being spelled out element by element.
D = np.tile(np.arange(1, 11), (2, 3, 10, 1))
# Filter bank indexed as g[filter, channel, row, col] with shape (2, 3, 3, 3).
# The 3 x 3 kernel for filter k and channel c is constant, filled with
# 3 * k + c + 1 (the values 1..6 in order), so it is generated by tiling.
g = np.tile(np.arange(1, 7).reshape(2, 3, 1, 1), (1, 1, 3, 3))
# Transformed filters, laid out as U[xi, nu, k, c] so that each transform-domain
# coordinate (xi, nu) holds a K x C matrix.
U = np.empty(shape=(alpha, alpha, K, C))
for k, c in np.ndindex(K, C):
    # Scatter the alpha x alpha transformed kernel across the (xi, nu) axes.
    U[:, :, k, c] = G @ g[k, c] @ G.T
# Transformed input tiles, laid out as V[xi, nu, c, b] so that each
# transform-domain coordinate (xi, nu) holds a C x P matrix.
V = np.empty(shape=(alpha, alpha, C, P))
tiles_y = math.ceil(H / m)
tiles_x = math.ceil(W / m)
for i, c, y, x in np.ndindex(N, C, tiles_y, tiles_x):
    # Neighbouring tiles start m samples apart but span alpha samples, so
    # they overlap by alpha - m.
    tile = D[i, c, y * m:y * m + alpha, x * m:x * m + alpha]
    # Flat tile index: image-major, then tile row, then tile column.
    b = (i * tiles_y + y) * tiles_x + x
    V[:, :, c, b] = B.T @ tile @ B
# For every transform-domain coordinate, multiply the K x C filter matrix by
# the C x P tile matrix; the matrix product performs the sum over channels.
M = np.empty(shape=(alpha, alpha, K, P))
for xi, nu in np.ndindex(alpha, alpha):
    M[xi, nu] = U[xi, nu] @ V[xi, nu]
# Inverse transform: turn each transform-domain tile back into an m x m patch
# of the output Y[image, filter, row, col].
Y = np.empty(shape=(N, K, H, W))
temp_m = np.empty(shape=(alpha, alpha))
tiles_y = math.ceil(H / m)
tiles_x = math.ceil(W / m)
for i, k, y, x in np.ndindex(N, K, tiles_y, tiles_x):
    # Same flat tile index that was used when the tiles were scattered.
    b = (i * tiles_y + y) * tiles_x + x
    # Gather the alpha x alpha tile for (filter k, tile b) into a contiguous
    # scratch buffer.
    temp_m[...] = M[:, :, k, b]
    Y[i, k, y * m:y * m + m, x * m:x * m + m] = A.T @ temp_m @ A
# Dump every m x m output tile, labelled with its filter index and flat tile
# index, for manual inspection.
tiles_y = math.ceil(H / m)
tiles_x = math.ceil(W / m)
for i, k, y, x in np.ndindex(N, K, tiles_y, tiles_x):
    b = (i * tiles_y + y) * tiles_x + x
    print(k, b)
    print(Y[i, k, y * m:y * m + m, x * m:x * m + m])
################################################################
# tf ground truth data generate
# Dimensions restated for this section; the values match those set near the
# top of the script.
H = 8
W = 8
K = 2
C = 3
N = 2
# tf.nn.conv2d is used with its default NHWC input layout and
# (rows, cols, in_channels, out_channels) filter layout, so both operands are
# permuted from the (N, C, rows, cols) / (K, C, rows, cols) layouts used above.
output = tf.nn.conv2d(
    D.transpose([0, 2, 3, 1]).astype('float'),
    g.transpose([2, 3, 1, 0]),
    strides=[1, 1, 1, 1],
    padding="VALID",
)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Permute the result back to (N, K, rows, cols) so it lines up with Y.
    output_gnd = output.eval(session=sess).transpose([0, 3, 1, 2])
print("tf output shape: ", output_gnd.shape, "output data: ", output_gnd)
# Raise explicitly instead of using a bare `assert`: asserts are stripped under
# `python -O`, which would silently skip the only verification in this script.
if not np.allclose(Y, output_gnd):
    raise AssertionError(
        "Winograd output does not match tf.nn.conv2d "
        f"(max abs difference: {np.max(np.abs(Y - output_gnd))})"
    )
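# (Web-page residue, not part of the script) This content may be unsuitable for display and is hidden by the page. You can review and revise it yourself using the relevant editing features.
# If you confirm that the content does not involve inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false information / worthless content, or content that violates national laws and regulations, you can click submit to appeal and it will be handled as soon as possible.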