/******************************************************************* * * M4RI: Linear Algebra over GF(2) * * Copyright (C) 2008 Martin Albrecht * Copyright (C) 2008 Clement Pernet * Copyright (C) 2008 Marco Bodrato * * Distributed under the terms of the GNU General Public License (GPL) * version 2 or higher. * * This code is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * General Public License for more details. * * The full text of the GPL is available at: * * http://www.gnu.org/licenses/ * ********************************************************************/ #ifdef HAVE_CONFIG_H #include "config.h" #endif #include "graycode.h" #include "strassen.h" #include "parity.h" #ifndef MIN #define MIN(a,b) (((a)<(b))?(a):(b)) #endif #if __M4RI_HAVE_OPENMP #include #endif // Returns true if a is closer to cutoff than a/2. static inline int closer(rci_t a, int cutoff) { return 3 * a < 4 * cutoff; } mzd_t *_mzd_mul_even(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) { rci_t mmm, kkk, nnn; if(C->nrows == 0 || C->ncols == 0) return C; rci_t m = A->nrows; rci_t k = A->ncols; rci_t n = B->ncols; /* handle case first, where the input matrices are too small already */ if (closer(m, cutoff) || closer(k, cutoff) || closer(n, cutoff)) { /* we copy the matrices first since it is only constant memory overhead and improves data locality */ if(mzd_is_windowed(A)|mzd_is_windowed(B)|mzd_is_windowed(C)) { mzd_t *Abar = mzd_copy(NULL, A); mzd_t *Bbar = mzd_copy(NULL, B); mzd_t *Cbar = mzd_init(m, n); _mzd_mul_m4rm(Cbar, Abar, Bbar, 0, FALSE); mzd_copy(C, Cbar); mzd_free(Cbar); mzd_free(Bbar); mzd_free(Abar); } else { _mzd_mul_m4rm(C, A, B, 0, TRUE); } return C; } /* adjust cutting numbers to work on words */ rci_t mult = m4ri_radix; rci_t width = MIN(MIN(m, n), k) / 2; while (width > cutoff) { width /= 2; mult *= 2; } mmm = (((m - m % mult) / m4ri_radix) >> 1) * m4ri_radix; kkk = (((k - k % mult) / m4ri_radix) >> 1) * m4ri_radix; nnn = (((n - n % mult) / m4ri_radix) >> 1) * m4ri_radix; /* |A | |B | |C | * Compute | | x | | = | | */ { mzd_t const *A11 = mzd_init_window_const(A, 0, 0, mmm, kkk); mzd_t const *A12 = mzd_init_window_const(A, 0, kkk, mmm, 2*kkk); mzd_t const *A21 = mzd_init_window_const(A, mmm, 0, 2*mmm, kkk); mzd_t const *A22 = mzd_init_window_const(A, mmm, kkk, 2*mmm, 2*kkk); mzd_t const *B11 = mzd_init_window_const(B, 0, 0, kkk, nnn); mzd_t const *B12 = mzd_init_window_const(B, 0, nnn, kkk, 2*nnn); mzd_t const *B21 = mzd_init_window_const(B, kkk, 0, 2*kkk, nnn); mzd_t const *B22 = mzd_init_window_const(B, kkk, nnn, 2*kkk, 2*nnn); mzd_t *C11 = mzd_init_window(C, 0, 0, mmm, nnn); mzd_t *C12 = mzd_init_window(C, 0, nnn, mmm, 2*nnn); mzd_t *C21 = mzd_init_window(C, mmm, 0, 2*mmm, nnn); mzd_t *C22 = mzd_init_window(C, mmm, nnn, 2*mmm, 2*nnn); /** * \note See Marco Bodrato; "A Strassen-like Matrix Multiplication * Suited for Squaring and Highest Power Computation"; * http://bodrato.it/papres/#CIVV2008 for reference on the used * sequence of operations. */ /* change this to mzd_init(mmm, MAX(nnn,kkk)) to fix the todo below */ mzd_t *Wmk = mzd_init(mmm, kkk); mzd_t *Wkn = mzd_init(kkk, nnn); _mzd_add(Wkn, B22, B12); /* Wkn = B22 + B12 */ _mzd_add(Wmk, A22, A12); /* Wmk = A22 + A12 */ _mzd_mul_even(C21, Wmk, Wkn, cutoff);/* C21 = Wmk * Wkn */ _mzd_add(Wmk, A22, A21); /* Wmk = A22 - A21 */ _mzd_add(Wkn, B22, B21); /* Wkn = B22 - B21 */ _mzd_mul_even(C22, Wmk, Wkn, cutoff);/* C22 = Wmk * Wkn */ _mzd_add(Wkn, Wkn, B12); /* Wkn = Wkn + B12 */ _mzd_add(Wmk, Wmk, A12); /* Wmk = Wmk + A12 */ _mzd_mul_even(C11, Wmk, Wkn, cutoff);/* C11 = Wmk * Wkn */ _mzd_add(Wmk, Wmk, A11); /* Wmk = Wmk - A11 */ _mzd_mul_even(C12, Wmk, B12, cutoff);/* C12 = Wmk * B12 */ _mzd_add(C12, C12, C22); /* C12 = C12 + C22 */ /** * \todo ideally we would use the same Wmk throughout the function * but some called function doesn't like that and we end up with a * wrong result if we use virtual Wmk matrices. Ideally, this should * be fixed not worked around. The check whether the bug has been * fixed, use only one Wmk and check if mzd_mul(4096, 3528, * 4096, 2124) still returns the correct answer. */ mzd_free(Wmk); Wmk = mzd_mul(NULL, A12, B21, cutoff);/*Wmk = A12 * B21 */ _mzd_add(C11, C11, Wmk); /* C11 = C11 + Wmk */ _mzd_add(C12, C11, C12); /* C12 = C11 - C12 */ _mzd_add(C11, C21, C11); /* C11 = C21 - C11 */ _mzd_add(Wkn, Wkn, B11); /* Wkn = Wkn - B11 */ _mzd_mul_even(C21, A21, Wkn, cutoff); /* C21 = A21 * Wkn */ mzd_free(Wkn); _mzd_add(C21, C11, C21); /* C21 = C11 - C21 */ _mzd_add(C22, C22, C11); /* C22 = C22 + C11 */ _mzd_mul_even(C11, A11, B11, cutoff); /* C11 = A11 * B11 */ _mzd_add(C11, C11, Wmk); /* C11 = C11 + Wmk */ /* clean up */ mzd_free_window((mzd_t*)A11); mzd_free_window((mzd_t*)A12); mzd_free_window((mzd_t*)A21); mzd_free_window((mzd_t*)A22); mzd_free_window((mzd_t*)B11); mzd_free_window((mzd_t*)B12); mzd_free_window((mzd_t*)B21); mzd_free_window((mzd_t*)B22); mzd_free_window(C11); mzd_free_window(C12); mzd_free_window(C21); mzd_free_window(C22); mzd_free(Wmk); } /* deal with rest */ nnn *= 2; if (n > nnn) { /* |AA| | B| | C| * Compute |AA| x | B| = | C| */ mzd_t const *B_last_col = mzd_init_window_const(B, 0, nnn, k, n); mzd_t *C_last_col = mzd_init_window(C, 0, nnn, m, n); _mzd_mul_m4rm(C_last_col, A, B_last_col, 0, TRUE); mzd_free_window((mzd_t*)B_last_col); mzd_free_window(C_last_col); } mmm *= 2; if (m > mmm) { /* | | |B | | | * Compute |AA| x |B | = |C | */ mzd_t const *A_last_row = mzd_init_window_const(A, mmm, 0, m, k); mzd_t const *B_first_col= mzd_init_window_const(B, 0, 0, k, nnn); mzd_t *C_last_row = mzd_init_window(C, mmm, 0, m, nnn); _mzd_mul_m4rm(C_last_row, A_last_row, B_first_col, 0, TRUE); mzd_free_window((mzd_t*)A_last_row); mzd_free_window((mzd_t*)B_first_col); mzd_free_window(C_last_row); } kkk *= 2; if (k > kkk) { /* Add to | | | B| |C | * result |A | x | | = | | */ mzd_t const *A_last_col = mzd_init_window_const(A, 0, kkk, mmm, k); mzd_t const *B_last_row = mzd_init_window_const(B, kkk, 0, k, nnn); mzd_t *C_bulk = mzd_init_window(C, 0, 0, mmm, nnn); mzd_addmul_m4rm(C_bulk, A_last_col, B_last_row, 0); mzd_free_window((mzd_t*)A_last_col); mzd_free_window((mzd_t*)B_last_row); mzd_free_window(C_bulk); } __M4RI_DD_MZD(C); return C; } mzd_t *_mzd_sqr_even(mzd_t *C, mzd_t const *A, int cutoff) { rci_t m; m = A->nrows; /* handle case first, where the input matrices are too small already */ if (closer(m, cutoff)) { /* we copy the matrices first since it is only constant memory overhead and improves data locality */ if(mzd_is_windowed(A)|mzd_is_windowed(C)) { mzd_t *Abar = mzd_copy(NULL, A); mzd_t *Cbar = mzd_init(m, m); _mzd_mul_m4rm(Cbar, Abar, Abar, 0, FALSE); mzd_copy(C, Cbar); mzd_free(Cbar); mzd_free(Abar); } else { _mzd_mul_m4rm(C, A, A, 0, TRUE); } return C; } /* adjust cutting numbers to work on words */ rci_t mmm; { rci_t mult = m4ri_radix; rci_t width = m / 2; while (width > cutoff) { width /= 2; mult *= 2; } mmm = (((m - m % mult) / m4ri_radix) >> 1) * m4ri_radix; } /* |A | |A | |C | * Compute | | x | | = | | */ { mzd_t const *A11 = mzd_init_window_const(A, 0, 0, mmm, mmm); mzd_t const *A12 = mzd_init_window_const(A, 0, mmm, mmm, 2*mmm); mzd_t const *A21 = mzd_init_window_const(A, mmm, 0, 2*mmm, mmm); mzd_t const *A22 = mzd_init_window_const(A, mmm, mmm, 2*mmm, 2*mmm); mzd_t *C11 = mzd_init_window(C, 0, 0, mmm, mmm); mzd_t *C12 = mzd_init_window(C, 0, mmm, mmm, 2*mmm); mzd_t *C21 = mzd_init_window(C, mmm, 0, 2*mmm, mmm); mzd_t *C22 = mzd_init_window(C, mmm, mmm, 2*mmm, 2*mmm); /** * \note See Marco Bodrato; "A Strassen-like Matrix Multiplication * Suited for Squaring and Highest Power Computation"; * http://bodrato.it/papres/#CIVV2008 for reference on the used * sequence of operations. */ mzd_t *Wmk; mzd_t *Wkn = mzd_init(mmm, mmm); _mzd_add(Wkn, A22, A12); /* Wkn = A22 + A12 */ _mzd_sqr_even(C21, Wkn, cutoff); /* C21 = Wkn^2 */ _mzd_add(Wkn, A22, A21); /* Wkn = A22 - A21 */ _mzd_sqr_even(C22, Wkn, cutoff); /* C22 = Wkn^2 */ _mzd_add(Wkn, Wkn, A12); /* Wkn = Wkn + A12 */ _mzd_sqr_even(C11, Wkn, cutoff); /* C11 = Wkn^2 */ _mzd_add(Wkn, Wkn, A11); /* Wkn = Wkn - A11 */ _mzd_mul_even(C12, Wkn, A12, cutoff);/* C12 = Wkn * A12 */ _mzd_add(C12, C12, C22); /* C12 = C12 + C22 */ Wmk = mzd_mul(NULL, A12, A21, cutoff);/*Wmk = A12 * A21 */ _mzd_add(C11, C11, Wmk); /* C11 = C11 + Wmk */ _mzd_add(C12, C11, C12); /* C12 = C11 - C12 */ _mzd_add(C11, C21, C11); /* C11 = C21 - C11 */ _mzd_mul_even(C21, A21, Wkn, cutoff);/* C21 = A21 * Wkn */ mzd_free(Wkn); _mzd_add(C21, C11, C21); /* C21 = C11 - C21 */ _mzd_add(C22, C22, C11); /* C22 = C22 + C11 */ _mzd_sqr_even(C11, A11, cutoff); /* C11 = A11^2 */ _mzd_add(C11, C11, Wmk); /* C11 = C11 + Wmk */ /* clean up */ mzd_free_window((mzd_t*)A11); mzd_free_window((mzd_t*)A12); mzd_free_window((mzd_t*)A21); mzd_free_window((mzd_t*)A22); mzd_free_window(C11); mzd_free_window(C12); mzd_free_window(C21); mzd_free_window(C22); mzd_free(Wmk); } /* deal with rest */ mmm *= 2; if (m > mmm) { /* |AA| | A| | C| * Compute |AA| x | A| = | C| */ { mzd_t const *A_last_col = mzd_init_window_const(A, 0, mmm, m, m); mzd_t *C_last_col = mzd_init_window(C, 0, mmm, m, m); _mzd_mul_m4rm(C_last_col, A, A_last_col, 0, TRUE); mzd_free_window((mzd_t*)A_last_col); mzd_free_window(C_last_col); } /* | | |A | | | * Compute |AA| x |A | = |C | */ { mzd_t const *A_last_row = mzd_init_window_const(A, mmm, 0, m, m); mzd_t const *A_first_col= mzd_init_window_const(A, 0, 0, m, mmm); mzd_t *C_last_row = mzd_init_window(C, mmm, 0, m, mmm); _mzd_mul_m4rm(C_last_row, A_last_row, A_first_col, 0, TRUE); mzd_free_window((mzd_t*)A_last_row); mzd_free_window((mzd_t*)A_first_col); mzd_free_window(C_last_row); } /* Add to | | | A| |C | * result |A | x | | = | | */ { mzd_t const *A_last_col = mzd_init_window_const(A, 0, mmm, mmm, m); mzd_t const *A_last_row = mzd_init_window_const(A, mmm, 0, m, mmm); mzd_t *C_bulk = mzd_init_window(C, 0, 0, mmm, mmm); mzd_addmul_m4rm(C_bulk, A_last_col, A_last_row, 0); mzd_free_window((mzd_t*)A_last_col); mzd_free_window((mzd_t*)A_last_row); mzd_free_window(C_bulk); } } __M4RI_DD_MZD(C); return C; } mzd_t *mzd_mul(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) { if(A->ncols != B->nrows) m4ri_die("mzd_mul: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows); if (cutoff < 0) m4ri_die("mzd_mul: cutoff must be >= 0.\n"); if(cutoff == 0) { cutoff = __M4RI_STRASSEN_MUL_CUTOFF; } cutoff = cutoff / m4ri_radix * m4ri_radix; if (cutoff < m4ri_radix) { cutoff = m4ri_radix; }; if (C == NULL) { C = mzd_init(A->nrows, B->ncols); } else if (C->nrows != A->nrows || C->ncols != B->ncols){ m4ri_die("mzd_mul: C (%d x %d) has wrong dimensions, expected (%d x %d)\n", C->nrows, C->ncols, A->nrows, B->ncols); } C = (A == B) ? _mzd_sqr_even(C, A, cutoff) : _mzd_mul_even(C, A, B, cutoff); return C; } mzd_t *_mzd_addmul_even(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) { /** * \todo make sure not to overwrite crap after ncols and before width * m4ri_radix */ if(C->nrows == 0 || C->ncols == 0) return C; rci_t m = A->nrows; rci_t k = A->ncols; rci_t n = B->ncols; /* handle case first, where the input matrices are too small already */ if (closer(m, cutoff) || closer(k, cutoff) || closer(n, cutoff)) { /* we copy the matrices first since it is only constant memory overhead and improves data locality */ if(mzd_is_windowed(A)|mzd_is_windowed(B)|mzd_is_windowed(C)) { mzd_t *Abar = mzd_copy(NULL, A); mzd_t *Bbar = mzd_copy(NULL, B); mzd_t *Cbar = mzd_copy(NULL, C); mzd_addmul_m4rm(Cbar, Abar, Bbar, 0); mzd_copy(C, Cbar); mzd_free(Cbar); mzd_free(Bbar); mzd_free(Abar); } else { mzd_addmul_m4rm(C, A, B, 0); } return C; } /* adjust cutting numbers to work on words */ rci_t mmm, kkk, nnn; { rci_t mult = m4ri_radix; rci_t width = MIN(MIN(m, n), k) / 2; while (width > cutoff) { width /= 2; mult *= 2; } mmm = (((m - m % mult) / m4ri_radix) >> 1) * m4ri_radix; kkk = (((k - k % mult) / m4ri_radix) >> 1) * m4ri_radix; nnn = (((n - n % mult) / m4ri_radix) >> 1) * m4ri_radix; } /* |C | |A | |B | * Compute | | += | | x | | */ { mzd_t const *A11 = mzd_init_window_const(A, 0, 0, mmm, kkk); mzd_t const *A12 = mzd_init_window_const(A, 0, kkk, mmm, 2*kkk); mzd_t const *A21 = mzd_init_window_const(A, mmm, 0, 2*mmm, kkk); mzd_t const *A22 = mzd_init_window_const(A, mmm, kkk, 2*mmm, 2*kkk); mzd_t const *B11 = mzd_init_window_const(B, 0, 0, kkk, nnn); mzd_t const *B12 = mzd_init_window_const(B, 0, nnn, kkk, 2*nnn); mzd_t const *B21 = mzd_init_window_const(B, kkk, 0, 2*kkk, nnn); mzd_t const *B22 = mzd_init_window_const(B, kkk, nnn, 2*kkk, 2*nnn); mzd_t *C11 = mzd_init_window(C, 0, 0, mmm, nnn); mzd_t *C12 = mzd_init_window(C, 0, nnn, mmm, 2*nnn); mzd_t *C21 = mzd_init_window(C, mmm, 0, 2*mmm, nnn); mzd_t *C22 = mzd_init_window(C, mmm, nnn, 2*mmm, 2*nnn); /** * \note See Marco Bodrato; "A Strassen-like Matrix Multiplication * Suited for Squaring and Highest Power Computation"; * http://bodrato.it/papres/#CIVV2008 for reference on the used * sequence of operations. */ mzd_t *S = mzd_init(mmm, kkk); mzd_t *T = mzd_init(kkk, nnn); mzd_t *U = mzd_init(mmm, nnn); _mzd_add(S, A22, A21); /* 1 S = A22 - A21 */ _mzd_add(T, B22, B21); /* 2 T = B22 - B21 */ _mzd_mul_even(U, S, T, cutoff); /* 3 U = S*T */ _mzd_add(C22, U, C22); /* 4 C22 = U + C22 */ _mzd_add(C12, U, C12); /* 5 C12 = U + C12 */ _mzd_mul_even(U, A12, B21, cutoff); /* 8 U = A12*B21 */ _mzd_add(C11, U, C11); /* 9 C11 = U + C11 */ _mzd_addmul_even(C11, A11, B11, cutoff); /* 11 C11 = A11*B11 + C11 */ _mzd_add(S, S, A12); /* 6 S = S - A12 */ _mzd_add(T, T, B12); /* 7 T = T - B12 */ _mzd_addmul_even(U, S, T, cutoff); /* 10 U = S*T + U */ _mzd_add(C12, C12, U); /* 15 C12 = U + C12 */ _mzd_add(S, A11, S); /* 12 S = A11 - S */ _mzd_addmul_even(C12, S, B12, cutoff); /* 14 C12 = S*B12 + C12 */ _mzd_add(T, B11, T); /* 13 T = B11 - T */ _mzd_addmul_even(C21, A21, T, cutoff); /* 16 C21 = A21*T + C21 */ _mzd_add(S, A22, A12); /* 17 S = A22 + A21 */ _mzd_add(T, B22, B12); /* 18 T = B22 + B21 */ _mzd_addmul_even(U, S, T, cutoff); /* 19 U = U - S*T */ _mzd_add(C21, C21, U); /* 20 C21 = C21 - U */ _mzd_add(C22, C22, U); /* 21 C22 = C22 - U */ /* clean up */ mzd_free_window((mzd_t*)A11); mzd_free_window((mzd_t*)A12); mzd_free_window((mzd_t*)A21); mzd_free_window((mzd_t*)A22); mzd_free_window((mzd_t*)B11); mzd_free_window((mzd_t*)B12); mzd_free_window((mzd_t*)B21); mzd_free_window((mzd_t*)B22); mzd_free_window(C11); mzd_free_window(C12); mzd_free_window(C21); mzd_free_window(C22); mzd_free(S); mzd_free(T); mzd_free(U); } /* deal with rest */ nnn *= 2; if (n > nnn) { /* | C| |AA| | B| * Compute | C| += |AA| x | B| */ mzd_t const *B_last_col = mzd_init_window_const(B, 0, nnn, k, n); mzd_t *C_last_col = mzd_init_window(C, 0, nnn, m, n); mzd_addmul_m4rm(C_last_col, A, B_last_col, 0); mzd_free_window((mzd_t*)B_last_col); mzd_free_window(C_last_col); } mmm *= 2; if (m > mmm) { /* | | | | |B | * Compute |C | += |AA| x |B | */ mzd_t const *A_last_row = mzd_init_window_const(A, mmm, 0, m, k); mzd_t const *B_first_col= mzd_init_window_const(B, 0, 0, k, nnn); mzd_t *C_last_row = mzd_init_window(C, mmm, 0, m, nnn); mzd_addmul_m4rm(C_last_row, A_last_row, B_first_col, 0); mzd_free_window((mzd_t*)A_last_row); mzd_free_window((mzd_t*)B_first_col); mzd_free_window(C_last_row); } kkk *= 2; if (k > kkk) { /* Add to | | | B| |C | * result |A | x | | = | | */ mzd_t const *A_last_col = mzd_init_window_const(A, 0, kkk, mmm, k); mzd_t const *B_last_row = mzd_init_window_const(B, kkk, 0, k, nnn); mzd_t *C_bulk = mzd_init_window(C, 0, 0, mmm, nnn); mzd_addmul_m4rm(C_bulk, A_last_col, B_last_row, 0); mzd_free_window((mzd_t*)A_last_col); mzd_free_window((mzd_t*)B_last_row); mzd_free_window(C_bulk); } __M4RI_DD_MZD(C); return C; } mzd_t *_mzd_addsqr_even(mzd_t *C, mzd_t const *A, int cutoff) { /** * \todo make sure not to overwrite crap after ncols and before width * m4ri_radix */ if(C->nrows == 0) return C; rci_t m = A->nrows; /* handle case first, where the input matrices are too small already */ if (closer(m, cutoff)) { /* we copy the matrices first since it is only constant memory overhead and improves data locality */ if(mzd_is_windowed(A)|mzd_is_windowed(C)) { mzd_t *Cbar = mzd_copy(NULL, C); mzd_t *Abar = mzd_copy(NULL, A); mzd_addmul_m4rm(Cbar, Abar, Abar, 0); mzd_copy(C, Cbar); mzd_free(Cbar); mzd_free(Abar); } else { mzd_addmul_m4rm(C, A, A, 0); } return C; } /* adjust cutting numbers to work on words */ rci_t mmm; { rci_t mult = m4ri_radix; rci_t width = m / 2; while (width > cutoff) { width /= 2; mult *= 2; } mmm = (((m - m % mult) / m4ri_radix) >> 1) * m4ri_radix; } /* |C | |A | |B | * Compute | | += | | x | | */ { mzd_t const *A11 = mzd_init_window_const(A, 0, 0, mmm, mmm); mzd_t const *A12 = mzd_init_window_const(A, 0, mmm, mmm, 2*mmm); mzd_t const *A21 = mzd_init_window_const(A, mmm, 0, 2*mmm, mmm); mzd_t const *A22 = mzd_init_window_const(A, mmm, mmm, 2*mmm, 2*mmm); mzd_t *C11 = mzd_init_window(C, 0, 0, mmm, mmm); mzd_t *C12 = mzd_init_window(C, 0, mmm, mmm, 2*mmm); mzd_t *C21 = mzd_init_window(C, mmm, 0, 2*mmm, mmm); mzd_t *C22 = mzd_init_window(C, mmm, mmm, 2*mmm, 2*mmm); /** * \note See Marco Bodrato; "A Strassen-like Matrix Multiplication * Suited for Squaring and Highest Power Computation"; on-line v. * http://bodrato.it/papres/#CIVV2008 for reference on the used * sequence of operations. */ mzd_t *S = mzd_init(mmm, mmm); mzd_t *U = mzd_init(mmm, mmm); _mzd_add(S, A22, A21); /* 1 S = A22 - A21 */ _mzd_sqr_even(U, S, cutoff); /* 3 U = S^2 */ _mzd_add(C22, U, C22); /* 4 C22 = U + C22 */ _mzd_add(C12, U, C12); /* 5 C12 = U + C12 */ _mzd_mul_even(U, A12, A21, cutoff); /* 8 U = A12*A21 */ _mzd_add(C11, U, C11); /* 9 C11 = U + C11 */ _mzd_addsqr_even(C11, A11, cutoff); /* 11 C11 = A11^2 + C11 */ _mzd_add(S, S, A12); /* 6 S = S + A12 */ _mzd_addsqr_even(U, S, cutoff); /* 10 U = S^2 + U */ _mzd_add(C12, C12, U); /* 15 C12 = U + C12 */ _mzd_add(S, A11, S); /* 12 S = A11 - S */ _mzd_addmul_even(C12, S, A12, cutoff); /* 14 C12 = S*B12 + C12 */ _mzd_addmul_even(C21, A21, S, cutoff); /* 16 C21 = A21*T + C21 */ _mzd_add(S, A22, A12); /* 17 S = A22 + A21 */ _mzd_addsqr_even(U, S, cutoff); /* 19 U = U - S^2 */ _mzd_add(C21, C21, U); /* 20 C21 = C21 - U3 */ _mzd_add(C22, C22, U); /* 21 C22 = C22 - U3 */ /* clean up */ mzd_free_window((mzd_t*)A11); mzd_free_window((mzd_t*)A12); mzd_free_window((mzd_t*)A21); mzd_free_window((mzd_t*)A22); mzd_free_window(C11); mzd_free_window(C12); mzd_free_window(C21); mzd_free_window(C22); mzd_free(S); mzd_free(U); } /* deal with rest */ mmm *= 2; if (m > mmm) { /* | C| |AA| | B| * Compute | C| += |AA| x | B| */ { mzd_t const *A_last_col = mzd_init_window_const(A, 0, mmm, m, m); mzd_t *C_last_col = mzd_init_window(C, 0, mmm, m, m); mzd_addmul_m4rm(C_last_col, A, A_last_col, 0); mzd_free_window((mzd_t*)A_last_col); mzd_free_window(C_last_col); } /* | | | | |B | * Compute |C | += |AA| x |B | */ { mzd_t const *A_last_row = mzd_init_window_const(A, mmm, 0, m, m); mzd_t const *A_first_col= mzd_init_window_const(A, 0, 0, m, mmm); mzd_t *C_last_row = mzd_init_window(C, mmm, 0, m, mmm); mzd_addmul_m4rm(C_last_row, A_last_row, A_first_col, 0); mzd_free_window((mzd_t*)A_last_row); mzd_free_window((mzd_t*)A_first_col); mzd_free_window(C_last_row); } /* Add to | | | B| |C | * result |A | x | | = | | */ { mzd_t const *A_last_col = mzd_init_window_const(A, 0, mmm, mmm, m); mzd_t const *A_last_row = mzd_init_window_const(A, mmm, 0, m, mmm); mzd_t *C_bulk = mzd_init_window(C, 0, 0, mmm, mmm); mzd_addmul_m4rm(C_bulk, A_last_col, A_last_row, 0); mzd_free_window((mzd_t*)A_last_col); mzd_free_window((mzd_t*)A_last_row); mzd_free_window(C_bulk); } } __M4RI_DD_MZD(C); return C; } mzd_t *_mzd_addmul(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) { /** * Assumes that B and C are aligned in the same manner (as in a Schur complement) */ return (A == B) ? _mzd_addsqr_even(C, A, cutoff) : _mzd_addmul_even(C, A, B, cutoff); } mzd_t *mzd_addmul(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) { if(A->ncols != B->nrows) m4ri_die("mzd_addmul: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows); if (cutoff < 0) m4ri_die("mzd_addmul: cutoff must be >= 0.\n"); if(cutoff == 0) { cutoff = __M4RI_STRASSEN_MUL_CUTOFF; } cutoff = cutoff / m4ri_radix * m4ri_radix; if (cutoff < m4ri_radix) { cutoff = m4ri_radix; }; if (C == NULL) { C = mzd_init(A->nrows, B->ncols); } else if (C->nrows != A->nrows || C->ncols != B->ncols){ m4ri_die("mzd_addmul: C (%d x %d) has wrong dimensions, expected (%d x %d)\n", C->nrows, C->ncols, A->nrows, B->ncols); } if(A->nrows == 0 || A->ncols == 0 || B->ncols == 0) { __M4RI_DD_MZD(C); return C; } C = _mzd_addmul(C, A, B, cutoff); __M4RI_DD_MZD(C); return C; }