1 Star 0 Fork 0

原水衣人/sparsity_compiler

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test_sparse_transformer.py 3.07 KB
一键复制 编辑 原始数据 按行查看 历史
import numpy as np
import split
import time
def test_matmul():
    """Compare a masked, sliced matmul against a NumPy reference.

    Builds random integer-valued float32 operands ``W`` (128x64) and ``I``
    (64x128), runs the kernel built from
    ``module.Module(module.KernelModule(block_size=16))`` on the cubes that
    ``split.split`` produces for the output mask, and prints:

    1. the total absolute deviation between the kernel output ``O`` and
       ``np.dot(W, I)`` with the masked-out entries zeroed, and
    2. the ``time`` attribute of the kernel module.

    NOTE(review): ``module`` and ``generate_mask`` are neither imported nor
    defined in the visible part of this file -- confirm where they come from.
    """
    W = np.random.randint(1, 100, size=(128, 64)).astype(np.float32)
    I = np.random.randint(1, 100, size=(64, 128)).astype(np.float32)
    O = np.zeros((128, 128), dtype=np.float32)
    # presumably a boolean mask over the (128, 128) output; it is used below
    # as a boolean index into O2 -- TODO confirm against generate_mask
    O_mask = generate_mask(8, 16, 4, 1)

    # Both operands are passed as all-True masks; only the output is masked.
    cubes_list = split.split(
        np.ones(W.shape, dtype=bool),
        np.ones(I.shape, dtype=bool),
        O_mask,
    )
    sliced_matmul = module.Module(module.KernelModule(block_size=16))
    # presumably fills O in place -- TODO confirm against module.Module
    sliced_matmul(W, I, O, cubes_list)

    # Reference: dense product with the entries outside the mask zeroed.
    O2 = np.dot(W, I)
    O2[np.logical_not(O_mask)] = 0

    # Sum absolute differences: a signed sum lets positive and negative
    # errors cancel and can print 0 for a wrong result.
    print(np.sum(np.abs(O - O2)))
    print(sliced_matmul.time)
def batch_matmul_np(W, I):
    """Multiply two stacks of operands pairwise along the leading axis.

    For every index ``b`` in ``range(W.shape[0])`` computes
    ``np.dot(W[b], I[b])`` and stacks the results into one array cast to
    ``W.dtype``.
    """
    products = [np.dot(W[b], I[b]) for b in range(W.shape[0])]
    return np.array(products, dtype=W.dtype)
def test_batch_matmul_N(N):
    """Run ``test_batch_matmul`` N times and print the mean of each timing.

    ``test_batch_matmul`` returns three timings per run (numpy, split,
    non-split, in that order); they are summed element-wise over the runs,
    divided by ``N`` and printed one per line.
    """
    totals = [0] * 3
    for _ in range(N):
        sample = test_batch_matmul()
        # Only the first len(totals) entries of the sample are accumulated.
        totals = [acc + sample[k] for k, acc in enumerate(totals)]

    numpy_avg, split_avg, dense_avg = [total / N for total in totals]

    print('--------average time for {} runs--------'.format(N))
    print('numpy time:{}ms'.format(numpy_avg))
    print('split: {}ms'.format(split_avg))
    print('non-split: {}ms'.format(dense_avg))
def test_batch_matmul():
    """Run one batched masked-matmul comparison and return its timings.

    Three results are produced for the same random operands
    (batch=64, num_head=16, seq_len=512, hidden_size=64, flattened to a
    leading axis of ``batch * num_head``):

    * ``O1`` -- kernel run on the cubes derived from the output mask
      ("split"),
    * ``O2`` -- kernel run on the cubes derived from an all-True output mask,
      with the masked-out entries zeroed afterwards ("non-split"),
    * ``O3`` -- ``np.einsum`` reference with the masked-out entries zeroed.

    The total absolute deviation of ``O1`` and of ``O2`` from ``O3`` is
    printed.

    NOTE(review): ``module`` and ``generate_mask`` are neither imported nor
    defined in the visible part of this file -- confirm where they come from.

    Returns:
        list: ``[numpy_ms, split_time, non_split_time]``. ``numpy_ms`` is the
        elapsed time of the einsum reference (including masking) in
        milliseconds; the other two are the ``time`` attributes of the two
        kernel modules (their unit is not visible here; the caller prints
        them as ms).
    """
    batch = 64
    num_head = 16
    seq_len = 512
    hidden_size = 64
    W = np.random.randint(1, 100, size=(batch, num_head, seq_len, hidden_size)).astype(np.float32)
    I = np.random.randint(1, 100, size=(batch, num_head, hidden_size, seq_len)).astype(np.float32)
    O1 = np.zeros((batch * num_head, seq_len, seq_len), dtype=np.float32)
    # presumably a boolean (seq_len, seq_len) mask; it is used below as a
    # boolean index over the last two axes -- TODO confirm
    O_mask = generate_mask(seq_len // 16, 16, 4, 1)

    # split: cubes are derived from the per-matrix (last two axes) shapes.
    cubes_list = split.split(
        np.ones(W.shape[-2:], dtype=bool),
        np.ones(I.shape[-2:], dtype=bool),
        O_mask,
    )
    print(cubes_list)
    sliced_batch_matmul = module.Module(module.KernelModule(block_size=16))
    # Merge batch and head into one leading axis. reshape is used instead of
    # assigning to .shape, which NumPy discourages.
    W = W.reshape(batch * num_head, seq_len, hidden_size)
    I = I.reshape(batch * num_head, hidden_size, seq_len)
    sliced_batch_matmul(W, I, O1, cubes_list)

    # non-split: same kernel with an all-True output mask, masked afterwards.
    O2 = np.zeros((batch * num_head, seq_len, seq_len), dtype=np.float32)
    cubes_list_2 = split.split(
        np.ones(W.shape[-2:], dtype=bool),
        np.ones(I.shape[-2:], dtype=bool),
        np.ones(O_mask.shape, dtype=bool),
    )
    sliced_batch_matmul_2 = module.Module(module.KernelModule(block_size=16))
    sliced_batch_matmul_2(W, I, O2, cubes_list_2)
    O2[:, np.logical_not(O_mask)] = 0

    # NumPy reference, timed with a monotonic high-resolution clock.
    st = time.perf_counter()
    O3 = np.einsum("bij, bjk -> bik", W, I)
    O3[:, np.logical_not(O_mask)] = 0
    ed = time.perf_counter()
    t1 = ed - st

    # Sum absolute differences: a signed sum lets positive and negative
    # errors cancel and can print 0 for a wrong result.
    print(np.sum(np.abs(O1 - O3)))
    print(np.sum(np.abs(O2 - O3)))

    return [t1 * 1000, sliced_batch_matmul.time, sliced_batch_matmul_2.time]
if __name__ == "__main__":
    # Script entry point: run the batched benchmark once and print averages.
    # The single-matrix check, test_matmul(), is left disabled.
    test_batch_matmul_N(1)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/MondayYuan/sparsity_compiler.git
git@gitee.com:MondayYuan/sparsity_compiler.git
MondayYuan
sparsity_compiler
sparsity_compiler
main

搜索帮助