代码拉取完成,页面将自动刷新
import numpy as np
import split
import time
def test_matmul():
    """Compare the sliced kernel against a masked ``np.dot`` on one matrix pair.

    Builds random integer-valued float32 operands of shape (128, 64) and
    (64, 128), invokes the kernel module with the cube list returned by
    ``split.split``, then prints the summed difference between the kernel's
    output buffer and the masked NumPy product, followed by the kernel
    module's ``time`` attribute.

    NOTE(review): ``generate_mask`` and ``module`` are neither defined nor
    imported in the visible part of this file -- confirm where they come from.
    """
    rows, inner, cols = 128, 64, 128
    lhs = np.random.randint(1, 100, size=(rows, inner)).astype(np.float32)
    rhs = np.random.randint(1, 100, size=(inner, cols)).astype(np.float32)
    kernel_out = np.zeros((rows, cols), dtype=np.float32)
    out_mask = generate_mask(8, 16, 4, 1)

    # Operand masks are all-True; only the output mask comes from generate_mask.
    lhs_mask = np.ones(lhs.shape, dtype=bool)
    rhs_mask = np.ones(rhs.shape, dtype=bool)
    cubes = split.split(lhs_mask, rhs_mask, out_mask)

    kernel = module.Module(module.KernelModule(block_size=16))
    kernel(lhs, rhs, kernel_out, cubes)

    # Reference: dense product with every position outside the mask zeroed.
    reference = np.dot(lhs, rhs)
    reference[np.logical_not(out_mask)] = 0

    print(np.sum(kernel_out - reference))
    print(kernel.time)
def batch_matmul_np(W, I):
    """Multiply matching slices of two stacked arrays with ``np.dot``.

    Args:
        W: Array whose first axis indexes the batch.
        I: Array indexed along its first axis in step with ``W``; it must
            have at least ``W.shape[0]`` entries along that axis.

    Returns:
        An array of dtype ``W.dtype`` holding ``np.dot(W[i], I[i])`` for every
        ``i`` in ``range(W.shape[0])``, stacked along a new leading axis.
    """
    # Index both operands explicitly (rather than zip) so a too-short ``I``
    # still raises IndexError instead of being silently truncated.
    return np.array(
        [np.dot(W[i], I[i]) for i in range(W.shape[0])],
        dtype=W.dtype,
    )
def test_batch_matmul_N(N):
    """Run ``test_batch_matmul`` ``N`` times and print the mean of each timing.

    Args:
        N: Number of benchmark repetitions. ``N == 0`` raises
            ``ZeroDivisionError`` when the accumulated totals are averaged.
    """
    # One format string per entry of the list returned by test_batch_matmul.
    labels = ('numpy time:{}ms', 'split: {}ms', 'non-split: {}ms')

    totals = [0] * len(labels)
    for _ in range(N):
        run_times = test_batch_matmul()
        totals = [acc + run_times[k] for k, acc in enumerate(totals)]

    averages = [total / N for total in totals]

    print('--------average time for {} runs--------'.format(N))
    for label, mean_time in zip(labels, averages):
        print(label.format(mean_time))
def test_batch_matmul():
    """Benchmark one batched masked matmul three ways and return the timings.

    Three results are produced from the same random operands:

    * "split": the kernel module is called with a cube list built from the
      output mask returned by ``generate_mask``.
    * "non-split": the kernel module is called with a cube list built from an
      all-True output mask, and positions outside the real mask are zeroed
      afterwards.
    * NumPy reference: ``np.einsum`` followed by the same zeroing.

    The summed difference of each kernel result to the reference is printed.

    Returns:
        ``[numpy_ms, split_time, non_split_time]`` where ``numpy_ms`` is the
        elapsed time of the einsum-plus-masking step in milliseconds and the
        other two entries are the ``time`` attributes of the two kernel
        modules.

    NOTE(review): ``generate_mask`` and ``module`` are neither defined nor
    imported in the visible part of this file -- confirm where they come
    from. The unit of the kernel modules' ``time`` attribute is not visible
    here either.
    """
    batch = 64
    num_head = 16
    seq_len = 512
    hidden_size = 64
    flat_batch = batch * num_head

    W = np.random.randint(1, 100, size=(batch, num_head, seq_len, hidden_size)).astype(np.float32)
    I = np.random.randint(1, 100, size=(batch, num_head, hidden_size, seq_len)).astype(np.float32)
    O1 = np.zeros((flat_batch, seq_len, seq_len), dtype=np.float32)
    O_mask = generate_mask(seq_len // 16, 16, 4, 1)

    # Operand masks are all-True and cover one (rows, cols) slice of W and I.
    w_mask = np.ones(W.shape[-2:], dtype=bool)
    i_mask = np.ones(I.shape[-2:], dtype=bool)

    # split: cube list derived from the real output mask.
    cubes_list = split.split(w_mask, i_mask, O_mask)
    print(cubes_list)
    sliced_batch_matmul = module.Module(module.KernelModule(block_size=16))
    # Merge the batch and head axes. reshape is used instead of assigning to
    # ``.shape`` because NumPy discourages in-place shape mutation.
    W = W.reshape(flat_batch, seq_len, hidden_size)
    I = I.reshape(flat_batch, hidden_size, seq_len)
    sliced_batch_matmul(W, I, O1, cubes_list)

    # non-split: all-True output mask, masked-out positions zeroed afterwards.
    O2 = np.zeros((flat_batch, seq_len, seq_len), dtype=np.float32)
    # Bound to its own name only; the original chained assignment also
    # overwrote ``cubes_list`` by accident.
    cubes_list_2 = split.split(w_mask, i_mask, np.ones(O_mask.shape, dtype=bool))
    sliced_batch_matmul_2 = module.Module(module.KernelModule(block_size=16))
    sliced_batch_matmul_2(W, I, O2, cubes_list_2)
    O2[:, np.logical_not(O_mask)] = 0

    # NumPy reference. perf_counter is the clock intended for measuring
    # short intervals; the masking step is deliberately inside the timed span.
    st = time.perf_counter()
    O3 = np.einsum("bij, bjk -> bik", W, I)
    O3[:, np.logical_not(O_mask)] = 0
    ed = time.perf_counter()
    t1 = ed - st

    # NOTE(review): a signed sum lets opposite-sign errors cancel; consider a
    # max-absolute-difference check if this is meant to validate the kernel.
    print(np.sum(O1 - O3))
    print(np.sum(O2 - O3))

    return [t1 * 1000, sliced_batch_matmul.time, sliced_batch_matmul_2.time]
if __name__ == "__main__":
    # Script entry point: one pass of the batched benchmark.
    # test_matmul() is the single-matrix variant and is not run here.
    test_batch_matmul_N(1)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。