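# (Web-page residue from the scraped source, translated: "Code pull complete; the page will refresh automatically.")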
# https://github.com/MartinThoma/matrix-multiplication/tree/master/Python
# Core Library modules
from math import ceil, log
import numpy as np
import time
def read(filename):
    """Read two integer matrices A and B from a tab-separated text file.

    Rows of A come first, one per line, with integer entries separated by
    tab characters. The first blank line switches to collecting rows of B;
    every later non-blank line is appended to B.

    :param filename: path of the file to read.
    :return: tuple ``(A, B)``, each a list of integer rows.
    :raises ValueError: if a non-blank line contains a non-integer field.
    """
    A = []
    B = []
    matrix = A
    # The with-block closes the file even if parsing raises; the previous
    # open(...).read() left closing to the garbage collector.
    with open(filename, encoding="utf-8") as handle:
        for line in handle.read().splitlines():
            if line.strip():
                matrix.append([int(el) for el in line.split("\t")])
            else:
                # Whitespace-only lines act as separators too, instead of
                # crashing on int("") / int("  ").
                matrix = B
    return A, B
def print_matrix(matrix):
    """Print ``matrix`` to stdout, one row per line, entries separated by tabs."""
    for row in matrix:
        cells = [str(value) for value in row]
        print("\t".join(cells))
def ikj_matrix_product(A, B):
    """Return the product of square matrices A and B (size n = len(A)).

    Uses the i-k-j loop order: for each output row i, every A[i][k] is
    multiplied with the whole row B[k] and accumulated into row i.
    """
    size = len(A)
    product = [[0] * size for _ in range(size)]
    for i, out_row in enumerate(product):
        a_row = A[i]
        for k in range(size):
            a_ik = a_row[k]
            b_row = B[k]
            for j in range(size):
                out_row[j] += a_ik * b_row[j]
    return product
def add(A, B):
    """Return the element-wise sum A + B of two n x n matrices (n = len(A))."""
    n = len(A)
    return [[A[i][j] + B[i][j] for j in range(n)] for i in range(n)]
def subtract(A, B):
    """Return the element-wise difference A - B of two n x n matrices (n = len(A))."""
    n = len(A)
    return [[A[i][j] - B[i][j] for j in range(n)] for i in range(n)]
def strassenR(A, B, leaf_size=None):
    """
    Recursively multiply two square matrices with the Strassen algorithm.

    Implementation of the strassen algorithm, similar to
    http://en.wikipedia.org/w/index.php?title=Strassen_algorithm&oldid=498910018#Source_code_of_the_Strassen_algorithm_in_C_language

    Each level splits A and B into four quadrants of size n/2 and combines
    seven recursive products p1..p7 into the four quadrants of the result.

    :param A: n x n matrix (list of rows).
    :param B: n x n matrix (list of rows).
    :param leaf_size: matrices of this size or smaller, and any odd size,
        are multiplied directly with ``ikj_matrix_product``. When omitted,
        the module-level ``LEAF_SIZE`` is used if it has been assigned,
        otherwise 32.
    :return: the n x n product ``A * B`` as a list of lists.
    """
    if leaf_size is None:
        # LEAF_SIZE is only assigned in the ``__main__`` section, so calling
        # this after a plain import used to raise NameError.
        leaf_size = globals().get("LEAF_SIZE", 32)
    n = len(A)
    # An odd n cannot be split into equal quadrants: n // 2 would silently
    # drop the last row and column, so multiply such matrices directly.
    # max(..., 1) also guarantees an empty matrix never recurses.
    if n <= max(leaf_size, 1) or n % 2 == 1:
        return ikj_matrix_product(A, B)

    half = n // 2
    # Dividing the matrices in 4 sub-matrices (copied row slices):
    a11 = [row[:half] for row in A[:half]]  # top left
    a12 = [row[half:n] for row in A[:half]]  # top right
    a21 = [row[:half] for row in A[half:]]  # bottom left
    a22 = [row[half:n] for row in A[half:]]  # bottom right
    b11 = [row[:half] for row in B[:half]]  # top left
    b12 = [row[half:n] for row in B[:half]]  # top right
    b21 = [row[:half] for row in B[half:]]  # bottom left
    b22 = [row[half:n] for row in B[half:]]  # bottom right

    # Calculating p1 to p7:
    p1 = strassenR(add(a11, a22), add(b11, b22), leaf_size)  # (a11+a22) * (b11+b22)
    p2 = strassenR(add(a21, a22), b11, leaf_size)  # (a21+a22) * b11
    p3 = strassenR(a11, subtract(b12, b22), leaf_size)  # a11 * (b12-b22)
    p4 = strassenR(a22, subtract(b21, b11), leaf_size)  # a22 * (b21-b11)
    p5 = strassenR(add(a11, a12), b22, leaf_size)  # (a11+a12) * b22
    p6 = strassenR(subtract(a21, a11), add(b11, b12), leaf_size)  # (a21-a11) * (b11+b12)
    p7 = strassenR(subtract(a12, a22), add(b21, b22), leaf_size)  # (a12-a22) * (b21+b22)

    # Calculating c11, c12, c21 and c22:
    c11 = subtract(add(add(p1, p4), p7), p5)  # p1 + p4 - p5 + p7
    c12 = add(p3, p5)  # p3 + p5
    c21 = add(p2, p4)  # p2 + p4
    c22 = subtract(add(add(p1, p3), p6), p2)  # p1 + p3 - p2 + p6

    # Grouping the four quadrants into a single n x n matrix:
    top = [left + right for left, right in zip(c11, c12)]
    bottom = [left + right for left, right in zip(c21, c22)]
    return top + bottom
def strassen(A, B):
    """Multiply two square matrices of equal size with the Strassen algorithm.

    The inputs are zero-padded to the next power-of-two size so that
    ``strassenR`` can halve them at every recursion level without dealing
    with odd sizes; the padding is stripped from the result.

    :param A: n x n matrix (anything indexable as ``A[i][j]``).
    :param B: n x n matrix of the same size.
    :return: the n x n product ``A * B`` as a list of lists (``[]`` for
        empty input).
    :raises AssertionError: if the inputs are not square matrices of the
        same size.
    """
    n = len(A)
    if n == 0:
        # Empty product; the shape check below would index A[0].
        assert len(B) == 0
        return []
    assert len(A) == len(A[0]) == len(B) == len(B[0])
    # Exact integer next power of two (1, 2, 4, 8, ...). The former float
    # expression 2 ** int(ceil(log(n, 2))) is subject to rounding; a result
    # just above an exact exponent would double the padded size.
    m = 1 << (n - 1).bit_length()
    APrep = [[0] * m for _ in range(m)]
    BPrep = [[0] * m for _ in range(m)]
    for i in range(n):
        for j in range(n):
            APrep[i][j] = A[i][j]
            BPrep[i][j] = B[i][j]
    CPrep = strassenR(APrep, BPrep)
    # Strip the padding rows and columns again.
    return [row[:n] for row in CPrep[:n]]
def direct_multiply(A, B):
    """Return A * B for square matrices using the textbook triple loop.

    Each entry is accumulated from an integer 0 by adding the products
    A[row][k] * B[k][col] for k = 0 .. size-1, in that order.
    """
    size = len(A)
    result = []
    for row in range(size):
        out_row = []
        for col in range(size):
            acc = 0
            for k in range(size):
                acc += A[row][k] * B[k][col]
            out_row.append(acc)
        result.append(out_row)
    return result
if __name__ == "__main__":
    # Assigned at module level (not inside a function) so that strassenR
    # can read it as a global.
    LEAF_SIZE = 32
    for exponent in range(5, 10):
        size = 2 ** exponent
        A = np.random.rand(size, size)
        B = np.random.rand(size, size)

        # Time the naive triple-loop reference product.
        t0 = time.perf_counter()
        C_gnd = direct_multiply(A, B)
        elapsed_ms = (time.perf_counter() - t0) * 1000
        print("direct time:%sms" % elapsed_ms)

        # Time Strassen, then check it against the reference.
        t0 = time.perf_counter()
        C = strassen(A, B)
        elapsed_ms = (time.perf_counter() - t0) * 1000
        assert np.allclose(C, C_gnd)
        print("strassen time:%sms" % elapsed_ms)
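# (Web-page residue from the scraped source, translated: "Content here may be unsuitable for display and has been hidden by the page.
# If you confirm it contains no inappropriate language, pure advertising, violence, vulgar/pornographic, infringing, pirated, false, worthless, or unlawful content, you may submit an appeal.")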