1 Star 2 Fork 0

xinanXu/DScode

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
strassen.cpp 7.03 KB
一键复制 编辑 原始数据 按行查看 历史
xinanXu 提交于 2023-01-30 01:59 . early stage code
#include <iostream>
#include <cmath>
using namespace std;
// Classic O(n^3) matrix multiplication, intended for small orders.
// Returns a newly allocated n x n matrix holding A * B (A and B are n x n);
// the caller owns it and must delete[] each row and then the row array.
int** MatrixMul(int** A, int** B, int n) {
    int** ret = new int* [n];
    for (int i = 0; i < n; i++) {
        ret[i] = new int[n];
        for (int j = 0; j < n; j++) {
            // C[i][j] is the dot product of row i of A with column j of B.
            int sum = 0;
            for (int k = 0; k < n; k++)
                sum += A[i][k] * B[k][j];
            ret[i][j] = sum;
        }
    }
    return ret;
}
// Zero-padding: returns a newly allocated m x m matrix whose top-left n x n
// corner is a copy of A and whose remaining entries (row >= n or column >= n)
// are 0. Expects n <= m; the caller owns the result.
int** Padding(int** A, int n, int m) {
    int** padded = new int* [m];
    for (int row = 0; row < m; ++row) {
        padded[row] = new int[m];
        for (int col = 0; col < m; ++col)
            padded[row][col] = (row < n && col < n) ? A[row][col] : 0;
    }
    return padded;
}
// Returns the smallest exponent e >= 0 with 2^e >= n, i.e. the exponent of
// the next power of two at or above n.
// Examples: 4 -> 2, 5 -> 3; any n <= 1 -> 0.
int GetSquare(int n) {
    int exponent = 0;
    for (int power = 1; power < n; power *= 2)
        ++exponent;
    return exponent;
}
// Copies the n x n matrix B into A so that B[0][0] lands on A[x][y]; this
// overwrites rows x..x+n-1 and columns y..y+n-1 of A and leaves the rest alone.
void change(int** A, int** B, int x, int y, int n) {
    for (int r = 0; r < n; ++r) {
        int* dst = A[x + r] + y;
        const int* src = B[r];
        for (int c = 0; c < n; ++c)
            dst[c] = src[c];
    }
}
// Element-wise difference of two n x n sub-matrices: returns a newly allocated
// n x n matrix D with D[r][c] = A[xA + r][yA + c] - B[xB + r][yB + c]
// (x* are row offsets, y* are column offsets). The caller owns the result.
int** Minus(int** A, int** B, int xA, int yA, int xB, int yB, int n) {
    int** diff = new int* [n];
    for (int r = 0; r < n; ++r) {
        diff[r] = new int[n];
        const int* rowA = A[xA + r] + yA;
        const int* rowB = B[xB + r] + yB;
        for (int c = 0; c < n; ++c)
            diff[r][c] = rowA[c] - rowB[c];
    }
    return diff;
}
// Element-wise sum of two n x n sub-matrices: returns a newly allocated
// n x n matrix S with S[r][c] = A[xA + r][yA + c] + B[xB + r][yB + c]
// (x* are row offsets, y* are column offsets). The caller owns the result.
int** Plus(int** A, int** B, int xA, int yA, int xB, int yB, int n) {
    int** sum = new int* [n];
    for (int r = 0; r < n; ++r) {
        sum[r] = new int[n];
        const int* rowA = A[xA + r] + yA;
        const int* rowB = B[xB + r] + yB;
        for (int c = 0; c < n; ++c)
            sum[r][c] = rowA[c] + rowB[c];
    }
    return sum;
}
// Strassen multiplication of two n x n sub-matrices.
//
// Multiplies the n x n block of A whose top-left element is A[x][y] by the
// n x n block of B whose top-left element is B[xB][yB] (x / xB are row
// offsets, y / yB are column offsets). Returns a newly allocated n x n
// product; the caller owns it (delete[] each row, then the row array).
//
// n should be a power of two: each level splits the block into quadrants of
// order h = n / 2, so for other n the trailing row/column would be ignored.
// m is the original order of A and B (default -1 means "same as n"). When
// m != n, the m x m inputs are first zero-padded to n x n; this requires
// m <= n and is meant for the top-level call with all offsets 0.
int** Strassen(int** A, int** B, int x, int y, int xB, int yB, int n, int m = -1) {
    // Allocates a zero-filled k x k matrix.
    auto alloc = [](int k) {
        int** M = new int* [k];
        for (int i = 0; i < k; i++)
            M[i] = new int[k]();
        return M;
    };
    // Frees a k x k matrix allocated with new[].
    auto release = [](int** M, int k) {
        for (int i = 0; i < k; i++)
            delete[] M[i];
        delete[] M;
    };

    if (m == -1) m = n;
    if (m != n) {
        // Multiply zero-padded copies, then free them so the padding does not leak.
        int** PA = alloc(n);
        int** PB = alloc(n);
        for (int i = 0; i < m; i++)
            for (int j = 0; j < m; j++) {
                PA[i][j] = A[i][j];
                PB[i][j] = B[i][j];
            }
        int** ret = Strassen(PA, PB, x, y, xB, yB, n);
        release(PA, n);
        release(PB, n);
        return ret;
    }

    int** ret = alloc(n);
    if (n == 0) return ret;  // empty product; also stops infinite recursion
    if (n == 1) {
        ret[0][0] = A[x][y] * B[xB][yB];
        return ret;
    }

    const int h = n / 2;
    // Returns a new h x h matrix P[pr+i][pc+j] + Q[qr+i][qc+j], or the
    // difference when subtract is true.
    auto combine = [&](int** P, int pr, int pc, int** Q, int qr, int qc, bool subtract) {
        int** M = alloc(h);
        for (int i = 0; i < h; i++)
            for (int j = 0; j < h; j++)
                M[i][j] = subtract ? P[pr + i][pc + j] - Q[qr + i][qc + j]
                                   : P[pr + i][pc + j] + Q[qr + i][qc + j];
        return M;
    };

    // Quadrant naming follows CLRS: A11 = A[x][y], A12 = A[x][y+h],
    // A21 = A[x+h][y], A22 = A[x+h][y+h]; B likewise relative to (xB, yB).
    int** S1 = combine(B, xB, yB + h, B, xB + h, yB + h, true);   // B12 - B22
    int** S2 = combine(A, x, y, A, x, y + h, false);              // A11 + A12
    int** S3 = combine(A, x + h, y, A, x + h, y + h, false);      // A21 + A22
    int** S4 = combine(B, xB + h, yB, B, xB, yB, true);           // B21 - B11
    int** S5 = combine(A, x, y, A, x + h, y + h, false);          // A11 + A22
    int** S6 = combine(B, xB, yB, B, xB + h, yB + h, false);      // B11 + B22
    int** S7 = combine(A, x, y + h, A, x + h, y + h, true);       // A12 - A22
    int** S8 = combine(B, xB + h, yB, B, xB + h, yB + h, false);  // B21 + B22
    int** S9 = combine(A, x, y, A, x + h, y, true);               // A11 - A21
    int** S10 = combine(B, xB, yB, B, xB, yB + h, false);         // B11 + B12

    // The seven recursive half-size products.
    int** P1 = Strassen(A, S1, x, y, 0, 0, h);            // A11 * S1
    int** P2 = Strassen(S2, B, 0, 0, xB + h, yB + h, h);  // S2 * B22
    int** P3 = Strassen(S3, B, 0, 0, xB, yB, h);          // S3 * B11
    int** P4 = Strassen(A, S4, x + h, y + h, 0, 0, h);    // A22 * S4
    int** P5 = Strassen(S5, S6, 0, 0, 0, 0, h);           // S5 * S6
    int** P6 = Strassen(S7, S8, 0, 0, 0, 0, h);           // S7 * S8
    int** P7 = Strassen(S9, S10, 0, 0, 0, 0, h);          // S9 * S10

    // Assemble the four quadrants of the result directly into ret.
    for (int i = 0; i < h; i++)
        for (int j = 0; j < h; j++) {
            ret[i][j] = P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j];          // C11
            ret[i][j + h] = P1[i][j] + P2[i][j];                            // C12
            ret[i + h][j] = P3[i][j] + P4[i][j];                            // C21
            ret[i + h][j + h] = P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j];  // C22
        }

    // Every intermediate is h x h and owned here; free them all.
    int** temps[] = {S1, S2, S3, S4, S5, S6, S7, S8, S9, S10,
                     P1, P2, P3, P4, P5, P6, P7};
    for (int** T : temps)
        release(T, h);
    return ret;
}
// Multiplies two n x n matrices and returns a newly allocated n x n product
// (the caller owns it: delete[] each row, then the row array).
// If useMatrixMul is set and n <= 32, the plain O(n^3) MatrixMul is used.
// Otherwise the inputs are zero-padded to the next power of two and
// multiplied with Strassen.
int** StrassenMul(int** A, int** B, int n, bool useMatrixMul = false) {
    if (useMatrixMul && n <= 32)
        return MatrixMul(A, B, n);

    // Next power of two >= n, computed exactly with an integer shift rather
    // than the floating-point pow().
    const int m = 1 << GetSquare(n);
    int** temp = Strassen(A, B, 0, 0, 0, 0, m, n);
    if (m == n)
        return temp;  // no padding was needed, so temp is already n x n

    // Copy the meaningful n x n corner out of the padded m x m product.
    int** ret = new int* [n];
    for (int i = 0; i < n; i++) {
        ret[i] = new int[n];
        for (int j = 0; j < n; j++)
            ret[i][j] = temp[i][j];
    }
    // Free the padded product now that its corner has been copied.
    for (int i = 0; i < m; i++)
        delete[] temp[i];
    delete[] temp;
    return ret;
}
// Textbook exercise: multiplies two fixed 2 x 2 matrices and prints A, B and
// the product, then frees everything it allocated.
void test1() {
    const int n = 2;
    // Frees a k x k matrix allocated with new[].
    auto release = [](int** M, int k) {
        for (int i = 0; i < k; i++)
            delete[] M[i];
        delete[] M;
    };
    // Prints a k x k matrix, one row per line, each entry followed by a space.
    auto print = [](int** M, int k) {
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < k; j++)
                cout << M[i][j] << " ";
            cout << endl;
        }
    };

    int** A = new int* [n];
    int** B = new int* [n];
    for (int i = 0; i < n; i++) {
        A[i] = new int[n];
        B[i] = new int[n];
    }
    A[0][0] = 1, A[0][1] = 3, A[1][0] = 7, A[1][1] = 5;
    B[0][0] = 6, B[0][1] = 8, B[1][0] = 4, B[1][1] = 2;
    cout << "test1: A: \n";
    print(A, n);
    cout << "B: \n";
    print(B, n);
    cout << endl << "result:\n";
    int** ret = StrassenMul(A, B, n);
    print(ret, n);
    cout << endl;

    release(A, n);
    release(B, n);
    release(ret, n);
}
// Exercises an order that is not a power of two: multiplies two 3 x 3
// matrices filled with 2 (every product entry should be 3 * 2 * 2 = 12),
// prints them, then frees everything it allocated.
void test2() {
    const int n = 3;
    // Frees a k x k matrix allocated with new[].
    auto release = [](int** M, int k) {
        for (int i = 0; i < k; i++)
            delete[] M[i];
        delete[] M;
    };
    // Prints a k x k matrix, one row per line, each entry followed by a space.
    auto print = [](int** M, int k) {
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < k; j++)
                cout << M[i][j] << " ";
            cout << endl;
        }
    };

    int** A = new int* [n];
    int** B = new int* [n];
    for (int i = 0; i < n; i++) {
        A[i] = new int[n];
        B[i] = new int[n];
        for (int j = 0; j < n; j++) {
            A[i][j] = B[i][j] = 2;
        }
    }
    // A and B hold the same values, so A is printed once for both.
    cout << "test2: A B:\n";
    print(A, n);
    cout << endl << "result\n";
    int** ret = StrassenMul(A, B, n);
    print(ret, n);
    cout << endl;

    release(A, n);
    release(B, n);
    release(ret, n);
}
// Multiplies two identical 4 x 4 matrices whose top-left 2 x 2 quadrant is 1
// and whose other entries are 2, prints them and the product, then frees
// everything it allocated.
void test3() {
    const int n = 4;
    // Frees a k x k matrix allocated with new[].
    auto release = [](int** M, int k) {
        for (int i = 0; i < k; i++)
            delete[] M[i];
        delete[] M;
    };
    // Prints a k x k matrix, one row per line, each entry followed by a space.
    auto print = [](int** M, int k) {
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < k; j++)
                cout << M[i][j] << " ";
            cout << endl;
        }
    };

    int** A = new int* [n];
    int** B = new int* [n];
    for (int i = 0; i < n; i++) {
        A[i] = new int[n];
        B[i] = new int[n];
        for (int j = 0; j < n; j++) {
            A[i][j] = B[i][j] = 2;
        }
    }
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++)
            A[i][j] = B[i][j] = 1;
    // A and B hold the same values, so B is printed once for both.
    cout << "test3: A B:\n";
    print(B, n);
    cout << endl << "result:\n";
    int** ret = StrassenMul(A, B, n);
    print(ret, n);
    cout << endl;

    release(A, n);
    release(B, n);
    release(ret, n);
}
int main() {
test1();
test2();
test3();
return 0;
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/DearAtri/dscode.git
git@gitee.com:DearAtri/dscode.git
DearAtri
dscode
DScode
master

搜索帮助