699 lines
24 KiB
C

/*******************************************************************
*
* M4RI: Linear Algebra over GF(2)
*
* Copyright (C) 2008 Martin Albrecht <M.R.Albrecht@rhul.ac.uk>
* Copyright (C) 2008 Clement Pernet <pernet@math.washington.edu>
* Copyright (C) 2008 Marco Bodrato <bodrato@mail.dm.unipi.it>
*
* Distributed under the terms of the GNU General Public License (GPL)
* version 2 or higher.
*
* This code is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* The full text of the GPL is available at:
*
* http://www.gnu.org/licenses/
*
********************************************************************/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "graycode.h"
#include "strassen.h"
#include "parity.h"
#ifndef MIN
#define MIN(a,b) (((a)<(b))?(a):(b))
#endif
#if __M4RI_HAVE_OPENMP
#include <omp.h>
#endif
/*
 * Recursion stop test: returns non-zero when 3*a < 4*cutoff, i.e. when a is
 * nearer to cutoff than a/2 would be, so splitting once more is not worthwhile.
 */
static inline int closer(rci_t a, int cutoff) {
  return 4 * cutoff > 3 * a;
}
/**
 * Compute C = A * B with a recursive Strassen-like scheme.
 *
 * With m = A->nrows, k = A->ncols and n = B->ncols the function
 *  - returns C untouched if C has no rows or no columns,
 *  - hands the whole product to _mzd_mul_m4rm() as soon as closer() reports
 *    that one of m, k, n is near the cutoff,
 *  - otherwise splits the leading 2*mmm x 2*kkk part of A, 2*kkk x 2*nnn part
 *    of B and 2*mmm x 2*nnn part of C into four quadrants each, recurses on
 *    the quadrants and finally patches in the rows/columns that did not fit
 *    into the even split using _mzd_mul_m4rm() / mzd_addmul_m4rm().
 *
 * All arithmetic is over GF(2), so the "-" in the step comments below is the
 * same operation as "+"; every such step is carried out with _mzd_add().
 *
 * No dimension checks are done here; the public wrapper mzd_mul() validates
 * the shapes and normalises cutoff before calling this function.
 *
 * \param C      Result matrix, overwritten with the product.
 * \param A      Left factor.
 * \param B      Right factor.
 * \param cutoff Size threshold passed to closer() to stop the recursion.
 *
 * \return C
 */
mzd_t *_mzd_mul_even(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
rci_t mmm, kkk, nnn;
/* an empty result needs no work */
if(C->nrows == 0 || C->ncols == 0)
return C;
rci_t m = A->nrows;
rci_t k = A->ncols;
rci_t n = B->ncols;
/* handle case first, where the input matrices are too small already */
if (closer(m, cutoff) || closer(k, cutoff) || closer(n, cutoff)) {
/* we copy the matrices first since it is only constant memory overhead and improves data
locality */
if(mzd_is_windowed(A)|mzd_is_windowed(B)|mzd_is_windowed(C)) {
mzd_t *Abar = mzd_copy(NULL, A);
mzd_t *Bbar = mzd_copy(NULL, B);
mzd_t *Cbar = mzd_init(m, n);
/* NOTE(review): the last argument is presumably the "clear C first" flag;
 * Cbar comes fresh from mzd_init() here - confirm against _mzd_mul_m4rm(). */
_mzd_mul_m4rm(Cbar, Abar, Bbar, 0, FALSE);
mzd_copy(C, Cbar);
mzd_free(Cbar);
mzd_free(Bbar);
mzd_free(Abar);
} else {
_mzd_mul_m4rm(C, A, B, 0, TRUE);
}
return C;
}
/* adjust cutting numbers to work on words */
/* mult doubles for every halving of the smallest dimension that is still
 * above cutoff; each dimension is then rounded down to a multiple of mult
 * and halved, which yields quadrant sizes that are multiples of m4ri_radix. */
rci_t mult = m4ri_radix;
rci_t width = MIN(MIN(m, n), k) / 2;
while (width > cutoff) {
width /= 2;
mult *= 2;
}
mmm = (((m - m % mult) / m4ri_radix) >> 1) * m4ri_radix;
kkk = (((k - k % mult) / m4ri_radix) >> 1) * m4ri_radix;
nnn = (((n - n % mult) / m4ri_radix) >> 1) * m4ri_radix;
/* |A | |B | |C |
* Compute | | x | | = | | */
{
/* quadrant views; window arguments are used here as
 * (first row, first column, end row, end column) */
mzd_t const *A11 = mzd_init_window_const(A, 0, 0, mmm, kkk);
mzd_t const *A12 = mzd_init_window_const(A, 0, kkk, mmm, 2*kkk);
mzd_t const *A21 = mzd_init_window_const(A, mmm, 0, 2*mmm, kkk);
mzd_t const *A22 = mzd_init_window_const(A, mmm, kkk, 2*mmm, 2*kkk);
mzd_t const *B11 = mzd_init_window_const(B, 0, 0, kkk, nnn);
mzd_t const *B12 = mzd_init_window_const(B, 0, nnn, kkk, 2*nnn);
mzd_t const *B21 = mzd_init_window_const(B, kkk, 0, 2*kkk, nnn);
mzd_t const *B22 = mzd_init_window_const(B, kkk, nnn, 2*kkk, 2*nnn);
mzd_t *C11 = mzd_init_window(C, 0, 0, mmm, nnn);
mzd_t *C12 = mzd_init_window(C, 0, nnn, mmm, 2*nnn);
mzd_t *C21 = mzd_init_window(C, mmm, 0, 2*mmm, nnn);
mzd_t *C22 = mzd_init_window(C, mmm, nnn, 2*mmm, 2*nnn);
/**
* \note See Marco Bodrato; "A Strassen-like Matrix Multiplication
* Suited for Squaring and Highest Power Computation";
* http://bodrato.it/papers/#CIVV2008 for reference on the used
* sequence of operations.
*/
/* change this to mzd_init(mmm, MAX(nnn,kkk)) to fix the todo below */
/* Wmk (mmm x kkk) and Wkn (kkk x nnn) are scratch matrices; the quadrants
 * of C double as temporaries, so the order of the steps below matters. */
mzd_t *Wmk = mzd_init(mmm, kkk);
mzd_t *Wkn = mzd_init(kkk, nnn);
_mzd_add(Wkn, B22, B12); /* Wkn = B22 + B12 */
_mzd_add(Wmk, A22, A12); /* Wmk = A22 + A12 */
_mzd_mul_even(C21, Wmk, Wkn, cutoff);/* C21 = Wmk * Wkn */
_mzd_add(Wmk, A22, A21); /* Wmk = A22 - A21 */
_mzd_add(Wkn, B22, B21); /* Wkn = B22 - B21 */
_mzd_mul_even(C22, Wmk, Wkn, cutoff);/* C22 = Wmk * Wkn */
_mzd_add(Wkn, Wkn, B12); /* Wkn = Wkn + B12 */
_mzd_add(Wmk, Wmk, A12); /* Wmk = Wmk + A12 */
_mzd_mul_even(C11, Wmk, Wkn, cutoff);/* C11 = Wmk * Wkn */
_mzd_add(Wmk, Wmk, A11); /* Wmk = Wmk - A11 */
_mzd_mul_even(C12, Wmk, B12, cutoff);/* C12 = Wmk * B12 */
_mzd_add(C12, C12, C22); /* C12 = C12 + C22 */
/**
* \todo Ideally we would use the same Wmk throughout the function,
* but some called function doesn't like that and we end up with a
* wrong result if we use virtual Wmk matrices. Ideally, this should
* be fixed, not worked around. To check whether the bug has been
* fixed, use only one Wmk and check if mzd_mul(4096, 3528,
* 4096, 2124) still returns the correct answer.
*/
/* the mmm x kkk scratch is released; Wmk is re-created by mzd_mul() as the
 * freshly allocated product A12 * B21 and is freed again in the clean up */
mzd_free(Wmk);
Wmk = mzd_mul(NULL, A12, B21, cutoff);/*Wmk = A12 * B21 */
_mzd_add(C11, C11, Wmk); /* C11 = C11 + Wmk */
_mzd_add(C12, C11, C12); /* C12 = C11 - C12 */
_mzd_add(C11, C21, C11); /* C11 = C21 - C11 */
_mzd_add(Wkn, Wkn, B11); /* Wkn = Wkn - B11 */
_mzd_mul_even(C21, A21, Wkn, cutoff); /* C21 = A21 * Wkn */
mzd_free(Wkn);
_mzd_add(C21, C11, C21); /* C21 = C11 - C21 */
_mzd_add(C22, C22, C11); /* C22 = C22 + C11 */
_mzd_mul_even(C11, A11, B11, cutoff); /* C11 = A11 * B11 */
_mzd_add(C11, C11, Wmk); /* C11 = C11 + Wmk */
/* clean up */
mzd_free_window((mzd_t*)A11); mzd_free_window((mzd_t*)A12);
mzd_free_window((mzd_t*)A21); mzd_free_window((mzd_t*)A22);
mzd_free_window((mzd_t*)B11); mzd_free_window((mzd_t*)B12);
mzd_free_window((mzd_t*)B21); mzd_free_window((mzd_t*)B22);
mzd_free_window(C11); mzd_free_window(C12);
mzd_free_window(C21); mzd_free_window(C22);
mzd_free(Wmk);
}
/* deal with rest */
/* from here on nnn, mmm and kkk are doubled in turn so that they denote the
 * extent of the part already covered by the quadrant computation */
nnn *= 2;
if (n > nnn) {
/* |AA| | B| | C|
* Compute |AA| x | B| = | C| */
/* columns nnn..n of C = (all of A) * (columns nnn..n of B) */
mzd_t const *B_last_col = mzd_init_window_const(B, 0, nnn, k, n);
mzd_t *C_last_col = mzd_init_window(C, 0, nnn, m, n);
_mzd_mul_m4rm(C_last_col, A, B_last_col, 0, TRUE);
mzd_free_window((mzd_t*)B_last_col);
mzd_free_window(C_last_col);
}
mmm *= 2;
if (m > mmm) {
/* | | |B | | |
* Compute |AA| x |B | = |C | */
/* rows mmm..m, columns 0..nnn of C = (rows mmm..m of A) * (columns 0..nnn of B) */
mzd_t const *A_last_row = mzd_init_window_const(A, mmm, 0, m, k);
mzd_t const *B_first_col= mzd_init_window_const(B, 0, 0, k, nnn);
mzd_t *C_last_row = mzd_init_window(C, mmm, 0, m, nnn);
_mzd_mul_m4rm(C_last_row, A_last_row, B_first_col, 0, TRUE);
mzd_free_window((mzd_t*)A_last_row);
mzd_free_window((mzd_t*)B_first_col);
mzd_free_window(C_last_row);
}
kkk *= 2;
if (k > kkk) {
/* Add to | | | B| |C |
* result |A | x | | = | | */
/* the leading mmm x nnn part of C still lacks the contribution of the inner
 * indices kkk..k: add (columns kkk..k of A) * (rows kkk..k of B) to it */
mzd_t const *A_last_col = mzd_init_window_const(A, 0, kkk, mmm, k);
mzd_t const *B_last_row = mzd_init_window_const(B, kkk, 0, k, nnn);
mzd_t *C_bulk = mzd_init_window(C, 0, 0, mmm, nnn);
mzd_addmul_m4rm(C_bulk, A_last_col, B_last_row, 0);
mzd_free_window((mzd_t*)A_last_col);
mzd_free_window((mzd_t*)B_last_row);
mzd_free_window(C_bulk);
}
__M4RI_DD_MZD(C);
return C;
}
/**
 * Compute C = A * A with a recursive Strassen-like scheme.
 *
 * With m = A->nrows the function
 *  - hands the whole product to _mzd_mul_m4rm() as soon as closer(m, cutoff)
 *    holds,
 *  - otherwise splits the leading 2*mmm x 2*mmm part of A and C into four
 *    quadrants each, recurses (squaring where both factors coincide,
 *    _mzd_mul_even() / mzd_mul() otherwise) and finally patches in the
 *    rows/columns that did not fit into the even split.
 *
 * All arithmetic is over GF(2), so the "-" in the step comments below is the
 * same operation as "+"; every such step is carried out with _mzd_add().
 *
 * No dimension checks are done here.
 * NOTE(review): the code indexes A as m x m throughout, i.e. it assumes a
 * square A and an m x m C - confirm that callers guarantee this.
 *
 * \param C      Result matrix, overwritten with the square of A.
 * \param A      Matrix to square.
 * \param cutoff Size threshold passed to closer() to stop the recursion.
 *
 * \return C
 */
mzd_t *_mzd_sqr_even(mzd_t *C, mzd_t const *A, int cutoff) {
rci_t m;
m = A->nrows;
/* handle case first, where the input matrices are too small already */
if (closer(m, cutoff)) {
/* we copy the matrices first since it is only constant memory overhead and improves data
locality */
if(mzd_is_windowed(A)|mzd_is_windowed(C)) {
mzd_t *Abar = mzd_copy(NULL, A);
mzd_t *Cbar = mzd_init(m, m);
_mzd_mul_m4rm(Cbar, Abar, Abar, 0, FALSE);
mzd_copy(C, Cbar);
mzd_free(Cbar);
mzd_free(Abar);
} else {
_mzd_mul_m4rm(C, A, A, 0, TRUE);
}
return C;
}
/* adjust cutting numbers to work on words */
rci_t mmm;
{
/* mult doubles for every halving of m that is still above cutoff; m is
 * then rounded down to a multiple of mult and halved, which yields a
 * quadrant size that is a multiple of m4ri_radix */
rci_t mult = m4ri_radix;
rci_t width = m / 2;
while (width > cutoff) {
width /= 2;
mult *= 2;
}
mmm = (((m - m % mult) / m4ri_radix) >> 1) * m4ri_radix;
}
/* |A | |A | |C |
* Compute | | x | | = | | */
{
/* quadrant views; window arguments are used here as
 * (first row, first column, end row, end column) */
mzd_t const *A11 = mzd_init_window_const(A, 0, 0, mmm, mmm);
mzd_t const *A12 = mzd_init_window_const(A, 0, mmm, mmm, 2*mmm);
mzd_t const *A21 = mzd_init_window_const(A, mmm, 0, 2*mmm, mmm);
mzd_t const *A22 = mzd_init_window_const(A, mmm, mmm, 2*mmm, 2*mmm);
mzd_t *C11 = mzd_init_window(C, 0, 0, mmm, mmm);
mzd_t *C12 = mzd_init_window(C, 0, mmm, mmm, 2*mmm);
mzd_t *C21 = mzd_init_window(C, mmm, 0, 2*mmm, mmm);
mzd_t *C22 = mzd_init_window(C, mmm, mmm, 2*mmm, 2*mmm);
/**
* \note See Marco Bodrato; "A Strassen-like Matrix Multiplication
* Suited for Squaring and Highest Power Computation";
* http://bodrato.it/papers/#CIVV2008 for reference on the used
* sequence of operations.
*/
/* Wkn is an mmm x mmm scratch matrix; Wmk is only assigned later, by
 * mzd_mul(). The quadrants of C double as temporaries, so the order of the
 * steps below matters. */
mzd_t *Wmk;
mzd_t *Wkn = mzd_init(mmm, mmm);
_mzd_add(Wkn, A22, A12); /* Wkn = A22 + A12 */
_mzd_sqr_even(C21, Wkn, cutoff); /* C21 = Wkn^2 */
_mzd_add(Wkn, A22, A21); /* Wkn = A22 - A21 */
_mzd_sqr_even(C22, Wkn, cutoff); /* C22 = Wkn^2 */
_mzd_add(Wkn, Wkn, A12); /* Wkn = Wkn + A12 */
_mzd_sqr_even(C11, Wkn, cutoff); /* C11 = Wkn^2 */
_mzd_add(Wkn, Wkn, A11); /* Wkn = Wkn - A11 */
_mzd_mul_even(C12, Wkn, A12, cutoff);/* C12 = Wkn * A12 */
_mzd_add(C12, C12, C22); /* C12 = C12 + C22 */
/* Wmk is a freshly allocated product; it is freed in the clean up below */
Wmk = mzd_mul(NULL, A12, A21, cutoff);/*Wmk = A12 * A21 */
_mzd_add(C11, C11, Wmk); /* C11 = C11 + Wmk */
_mzd_add(C12, C11, C12); /* C12 = C11 - C12 */
_mzd_add(C11, C21, C11); /* C11 = C21 - C11 */
_mzd_mul_even(C21, A21, Wkn, cutoff);/* C21 = A21 * Wkn */
mzd_free(Wkn);
_mzd_add(C21, C11, C21); /* C21 = C11 - C21 */
_mzd_add(C22, C22, C11); /* C22 = C22 + C11 */
_mzd_sqr_even(C11, A11, cutoff); /* C11 = A11^2 */
_mzd_add(C11, C11, Wmk); /* C11 = C11 + Wmk */
/* clean up */
mzd_free_window((mzd_t*)A11); mzd_free_window((mzd_t*)A12);
mzd_free_window((mzd_t*)A21); mzd_free_window((mzd_t*)A22);
mzd_free_window(C11); mzd_free_window(C12);
mzd_free_window(C21); mzd_free_window(C22);
mzd_free(Wmk);
}
/* deal with rest */
/* mmm now denotes the extent of the part covered by the quadrant computation */
mmm *= 2;
if (m > mmm) {
/* |AA| | A| | C|
* Compute |AA| x | A| = | C| */
{
/* columns mmm..m of C = (all of A) * (columns mmm..m of A) */
mzd_t const *A_last_col = mzd_init_window_const(A, 0, mmm, m, m);
mzd_t *C_last_col = mzd_init_window(C, 0, mmm, m, m);
_mzd_mul_m4rm(C_last_col, A, A_last_col, 0, TRUE);
mzd_free_window((mzd_t*)A_last_col);
mzd_free_window(C_last_col);
}
/* | | |A | | |
* Compute |AA| x |A | = |C | */
{
/* rows mmm..m, columns 0..mmm of C = (rows mmm..m of A) * (columns 0..mmm of A) */
mzd_t const *A_last_row = mzd_init_window_const(A, mmm, 0, m, m);
mzd_t const *A_first_col= mzd_init_window_const(A, 0, 0, m, mmm);
mzd_t *C_last_row = mzd_init_window(C, mmm, 0, m, mmm);
_mzd_mul_m4rm(C_last_row, A_last_row, A_first_col, 0, TRUE);
mzd_free_window((mzd_t*)A_last_row);
mzd_free_window((mzd_t*)A_first_col);
mzd_free_window(C_last_row);
}
/* Add to | | | A| |C |
* result |A | x | | = | | */
{
/* the leading mmm x mmm part of C still lacks the contribution of the inner
 * indices mmm..m: add (columns mmm..m of A) * (rows mmm..m of A) to it */
mzd_t const *A_last_col = mzd_init_window_const(A, 0, mmm, mmm, m);
mzd_t const *A_last_row = mzd_init_window_const(A, mmm, 0, m, mmm);
mzd_t *C_bulk = mzd_init_window(C, 0, 0, mmm, mmm);
mzd_addmul_m4rm(C_bulk, A_last_col, A_last_row, 0);
mzd_free_window((mzd_t*)A_last_col);
mzd_free_window((mzd_t*)A_last_row);
mzd_free_window(C_bulk);
}
}
__M4RI_DD_MZD(C);
return C;
}
/**
 * Public entry point: compute C = A * B using the Strassen-like routines.
 *
 * Validates the operand shapes, normalises cutoff and then dispatches to the
 * squaring routine when both factors are the very same matrix object, or to
 * the general multiplication routine otherwise.
 *
 * \param C      Result matrix, or NULL to have an A->nrows x B->ncols matrix
 *               allocated with mzd_init().
 * \param A      Left factor.
 * \param B      Right factor; B->nrows must equal A->ncols.
 * \param cutoff Recursion threshold, must be >= 0; 0 selects
 *               __M4RI_STRASSEN_MUL_CUTOFF.
 *
 * \return The matrix holding the product (C, or the newly allocated matrix).
 */
mzd_t *mzd_mul(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
  /* reject incompatible inner dimensions and negative cutoffs */
  if (A->ncols != B->nrows) {
    m4ri_die("mzd_mul: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows);
  }
  if (cutoff < 0) {
    m4ri_die("mzd_mul: cutoff must be >= 0.\n");
  }

  /* 0 means "use the library default" */
  if (cutoff == 0) {
    cutoff = __M4RI_STRASSEN_MUL_CUTOFF;
  }

  /* round down to a multiple of m4ri_radix, but never below one radix */
  cutoff -= cutoff % m4ri_radix;
  if (cutoff < m4ri_radix) {
    cutoff = m4ri_radix;
  }

  /* allocate the result on demand, otherwise insist on matching dimensions */
  if (C == NULL) {
    C = mzd_init(A->nrows, B->ncols);
  } else if (C->nrows != A->nrows || C->ncols != B->ncols) {
    m4ri_die("mzd_mul: C (%d x %d) has wrong dimensions, expected (%d x %d)\n",
             C->nrows, C->ncols, A->nrows, B->ncols);
  }

  /* identical operand pointers: squaring is cheaper than a general product */
  if (A == B) {
    return _mzd_sqr_even(C, A, cutoff);
  }
  return _mzd_mul_even(C, A, B, cutoff);
}
/**
 * Compute C = C + A * B with a recursive Strassen-like scheme.
 *
 * With m = A->nrows, k = A->ncols and n = B->ncols the function
 *  - returns C untouched if C has no rows or no columns,
 *  - hands the whole update to mzd_addmul_m4rm() as soon as closer() reports
 *    that one of m, k, n is near the cutoff,
 *  - otherwise splits the leading 2*mmm x 2*kkk part of A, 2*kkk x 2*nnn part
 *    of B and 2*mmm x 2*nnn part of C into four quadrants each, recurses on
 *    the quadrants and finally adds the contributions of the rows/columns
 *    that did not fit into the even split using mzd_addmul_m4rm().
 *
 * All arithmetic is over GF(2), so the "-" in the step comments below is the
 * same operation as "+"; every such step is carried out with _mzd_add().
 * The leading numbers in the step comments are labels of the operation
 * sequence; the steps are executed in the order written, not in label order.
 *
 * No dimension checks are done here; the public wrapper mzd_addmul()
 * validates the shapes and normalises cutoff first.
 *
 * \param C      Accumulator, updated in place.
 * \param A      Left factor.
 * \param B      Right factor.
 * \param cutoff Size threshold passed to closer() to stop the recursion.
 *
 * \return C
 */
mzd_t *_mzd_addmul_even(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
/**
* \todo make sure not to overwrite the excess bits after ncols and before width * m4ri_radix
*/
if(C->nrows == 0 || C->ncols == 0)
return C;
rci_t m = A->nrows;
rci_t k = A->ncols;
rci_t n = B->ncols;
/* handle case first, where the input matrices are too small already */
if (closer(m, cutoff) || closer(k, cutoff) || closer(n, cutoff)) {
/* we copy the matrices first since it is only constant memory overhead and improves data
locality */
if(mzd_is_windowed(A)|mzd_is_windowed(B)|mzd_is_windowed(C)) {
mzd_t *Abar = mzd_copy(NULL, A);
mzd_t *Bbar = mzd_copy(NULL, B);
/* C is copied as well because its current content is part of the result */
mzd_t *Cbar = mzd_copy(NULL, C);
mzd_addmul_m4rm(Cbar, Abar, Bbar, 0);
mzd_copy(C, Cbar);
mzd_free(Cbar);
mzd_free(Bbar);
mzd_free(Abar);
} else {
mzd_addmul_m4rm(C, A, B, 0);
}
return C;
}
/* adjust cutting numbers to work on words */
rci_t mmm, kkk, nnn;
{
/* mult doubles for every halving of the smallest dimension that is still
 * above cutoff; each dimension is then rounded down to a multiple of mult
 * and halved, which yields quadrant sizes that are multiples of m4ri_radix */
rci_t mult = m4ri_radix;
rci_t width = MIN(MIN(m, n), k) / 2;
while (width > cutoff) {
width /= 2;
mult *= 2;
}
mmm = (((m - m % mult) / m4ri_radix) >> 1) * m4ri_radix;
kkk = (((k - k % mult) / m4ri_radix) >> 1) * m4ri_radix;
nnn = (((n - n % mult) / m4ri_radix) >> 1) * m4ri_radix;
}
/* |C | |A | |B |
* Compute | | += | | x | | */
{
/* quadrant views; window arguments are used here as
 * (first row, first column, end row, end column) */
mzd_t const *A11 = mzd_init_window_const(A, 0, 0, mmm, kkk);
mzd_t const *A12 = mzd_init_window_const(A, 0, kkk, mmm, 2*kkk);
mzd_t const *A21 = mzd_init_window_const(A, mmm, 0, 2*mmm, kkk);
mzd_t const *A22 = mzd_init_window_const(A, mmm, kkk, 2*mmm, 2*kkk);
mzd_t const *B11 = mzd_init_window_const(B, 0, 0, kkk, nnn);
mzd_t const *B12 = mzd_init_window_const(B, 0, nnn, kkk, 2*nnn);
mzd_t const *B21 = mzd_init_window_const(B, kkk, 0, 2*kkk, nnn);
mzd_t const *B22 = mzd_init_window_const(B, kkk, nnn, 2*kkk, 2*nnn);
mzd_t *C11 = mzd_init_window(C, 0, 0, mmm, nnn);
mzd_t *C12 = mzd_init_window(C, 0, nnn, mmm, 2*nnn);
mzd_t *C21 = mzd_init_window(C, mmm, 0, 2*mmm, nnn);
mzd_t *C22 = mzd_init_window(C, mmm, nnn, 2*mmm, 2*nnn);
/**
* \note See Marco Bodrato; "A Strassen-like Matrix Multiplication
* Suited for Squaring and Highest Power Computation";
* http://bodrato.it/papers/#CIVV2008 for reference on the used
* sequence of operations.
*/
/* scratch matrices: S is mmm x kkk, T is kkk x nnn, U is mmm x nnn; unlike
 * in the plain product, the quadrants of C hold live data and are only
 * ever added to */
mzd_t *S = mzd_init(mmm, kkk);
mzd_t *T = mzd_init(kkk, nnn);
mzd_t *U = mzd_init(mmm, nnn);
_mzd_add(S, A22, A21); /* 1 S = A22 - A21 */
_mzd_add(T, B22, B21); /* 2 T = B22 - B21 */
_mzd_mul_even(U, S, T, cutoff); /* 3 U = S*T */
_mzd_add(C22, U, C22); /* 4 C22 = U + C22 */
_mzd_add(C12, U, C12); /* 5 C12 = U + C12 */
_mzd_mul_even(U, A12, B21, cutoff); /* 8 U = A12*B21 */
_mzd_add(C11, U, C11); /* 9 C11 = U + C11 */
_mzd_addmul_even(C11, A11, B11, cutoff); /* 11 C11 = A11*B11 + C11 */
_mzd_add(S, S, A12); /* 6 S = S - A12 */
_mzd_add(T, T, B12); /* 7 T = T - B12 */
_mzd_addmul_even(U, S, T, cutoff); /* 10 U = S*T + U */
_mzd_add(C12, C12, U); /* 15 C12 = U + C12 */
_mzd_add(S, A11, S); /* 12 S = A11 - S */
_mzd_addmul_even(C12, S, B12, cutoff); /* 14 C12 = S*B12 + C12 */
_mzd_add(T, B11, T); /* 13 T = B11 - T */
_mzd_addmul_even(C21, A21, T, cutoff); /* 16 C21 = A21*T + C21 */
_mzd_add(S, A22, A12); /* 17 S = A22 + A12 */
_mzd_add(T, B22, B12); /* 18 T = B22 + B12 */
_mzd_addmul_even(U, S, T, cutoff); /* 19 U = U - S*T */
_mzd_add(C21, C21, U); /* 20 C21 = C21 - U */
_mzd_add(C22, C22, U); /* 21 C22 = C22 - U */
/* clean up */
mzd_free_window((mzd_t*)A11); mzd_free_window((mzd_t*)A12);
mzd_free_window((mzd_t*)A21); mzd_free_window((mzd_t*)A22);
mzd_free_window((mzd_t*)B11); mzd_free_window((mzd_t*)B12);
mzd_free_window((mzd_t*)B21); mzd_free_window((mzd_t*)B22);
mzd_free_window(C11); mzd_free_window(C12);
mzd_free_window(C21); mzd_free_window(C22);
mzd_free(S);
mzd_free(T);
mzd_free(U);
}
/* deal with rest */
/* from here on nnn, mmm and kkk are doubled in turn so that they denote the
 * extent of the part already covered by the quadrant computation */
nnn *= 2;
if (n > nnn) {
/* | C| |AA| | B|
* Compute | C| += |AA| x | B| */
/* columns nnn..n of C += (all of A) * (columns nnn..n of B) */
mzd_t const *B_last_col = mzd_init_window_const(B, 0, nnn, k, n);
mzd_t *C_last_col = mzd_init_window(C, 0, nnn, m, n);
mzd_addmul_m4rm(C_last_col, A, B_last_col, 0);
mzd_free_window((mzd_t*)B_last_col);
mzd_free_window(C_last_col);
}
mmm *= 2;
if (m > mmm) {
/* | | | | |B |
* Compute |C | += |AA| x |B | */
/* rows mmm..m, columns 0..nnn of C += (rows mmm..m of A) * (columns 0..nnn of B) */
mzd_t const *A_last_row = mzd_init_window_const(A, mmm, 0, m, k);
mzd_t const *B_first_col= mzd_init_window_const(B, 0, 0, k, nnn);
mzd_t *C_last_row = mzd_init_window(C, mmm, 0, m, nnn);
mzd_addmul_m4rm(C_last_row, A_last_row, B_first_col, 0);
mzd_free_window((mzd_t*)A_last_row);
mzd_free_window((mzd_t*)B_first_col);
mzd_free_window(C_last_row);
}
kkk *= 2;
if (k > kkk) {
/* Add to | | | B| |C |
* result |A | x | | = | | */
/* the leading mmm x nnn part of C still lacks the contribution of the inner
 * indices kkk..k: add (columns kkk..k of A) * (rows kkk..k of B) to it */
mzd_t const *A_last_col = mzd_init_window_const(A, 0, kkk, mmm, k);
mzd_t const *B_last_row = mzd_init_window_const(B, kkk, 0, k, nnn);
mzd_t *C_bulk = mzd_init_window(C, 0, 0, mmm, nnn);
mzd_addmul_m4rm(C_bulk, A_last_col, B_last_row, 0);
mzd_free_window((mzd_t*)A_last_col);
mzd_free_window((mzd_t*)B_last_row);
mzd_free_window(C_bulk);
}
__M4RI_DD_MZD(C);
return C;
}
/**
 * Compute C = C + A * A with a recursive Strassen-like scheme.
 *
 * With m = A->nrows the function
 *  - returns C untouched if C has no rows,
 *  - hands the whole update to mzd_addmul_m4rm() as soon as
 *    closer(m, cutoff) holds,
 *  - otherwise splits the leading 2*mmm x 2*mmm part of A and C into four
 *    quadrants each, recurses on the quadrants and finally adds the
 *    contributions of the rows/columns that did not fit into the even split.
 *
 * All arithmetic is over GF(2), so the "-" in the step comments below is the
 * same operation as "+"; every such step is carried out with _mzd_add().
 * The leading numbers in the step comments are labels of the operation
 * sequence; the steps are executed in the order written, not in label order.
 * Compared with _mzd_addmul_even() only one scratch matrix S is used for both
 * factors, since both factors are A.
 *
 * No dimension checks are done here.
 * NOTE(review): the code indexes A as m x m throughout, i.e. it assumes a
 * square A and an m x m C - confirm that callers guarantee this.
 *
 * \param C      Accumulator, updated in place.
 * \param A      Matrix whose square is added to C.
 * \param cutoff Size threshold passed to closer() to stop the recursion.
 *
 * \return C
 */
mzd_t *_mzd_addsqr_even(mzd_t *C, mzd_t const *A, int cutoff) {
/**
* \todo make sure not to overwrite the excess bits after ncols and before width * m4ri_radix
*/
if(C->nrows == 0)
return C;
rci_t m = A->nrows;
/* handle case first, where the input matrices are too small already */
if (closer(m, cutoff)) {
/* we copy the matrices first since it is only constant memory overhead and improves data
locality */
if(mzd_is_windowed(A)|mzd_is_windowed(C)) {
/* C is copied as well because its current content is part of the result */
mzd_t *Cbar = mzd_copy(NULL, C);
mzd_t *Abar = mzd_copy(NULL, A);
mzd_addmul_m4rm(Cbar, Abar, Abar, 0);
mzd_copy(C, Cbar);
mzd_free(Cbar);
mzd_free(Abar);
} else {
mzd_addmul_m4rm(C, A, A, 0);
}
return C;
}
/* adjust cutting numbers to work on words */
rci_t mmm;
{
/* mult doubles for every halving of m that is still above cutoff; m is
 * then rounded down to a multiple of mult and halved, which yields a
 * quadrant size that is a multiple of m4ri_radix */
rci_t mult = m4ri_radix;
rci_t width = m / 2;
while (width > cutoff) {
width /= 2;
mult *= 2;
}
mmm = (((m - m % mult) / m4ri_radix) >> 1) * m4ri_radix;
}
/* |C | |A | |A |
* Compute | | += | | x | | */
{
/* quadrant views; window arguments are used here as
 * (first row, first column, end row, end column) */
mzd_t const *A11 = mzd_init_window_const(A, 0, 0, mmm, mmm);
mzd_t const *A12 = mzd_init_window_const(A, 0, mmm, mmm, 2*mmm);
mzd_t const *A21 = mzd_init_window_const(A, mmm, 0, 2*mmm, mmm);
mzd_t const *A22 = mzd_init_window_const(A, mmm, mmm, 2*mmm, 2*mmm);
mzd_t *C11 = mzd_init_window(C, 0, 0, mmm, mmm);
mzd_t *C12 = mzd_init_window(C, 0, mmm, mmm, 2*mmm);
mzd_t *C21 = mzd_init_window(C, mmm, 0, 2*mmm, mmm);
mzd_t *C22 = mzd_init_window(C, mmm, mmm, 2*mmm, 2*mmm);
/**
* \note See Marco Bodrato; "A Strassen-like Matrix Multiplication
* Suited for Squaring and Highest Power Computation"; on-line v.
* http://bodrato.it/papers/#CIVV2008 for reference on the used
* sequence of operations.
*/
/* S and U are mmm x mmm scratch matrices; the quadrants of C hold live
 * data and are only ever added to */
mzd_t *S = mzd_init(mmm, mmm);
mzd_t *U = mzd_init(mmm, mmm);
_mzd_add(S, A22, A21); /* 1 S = A22 - A21 */
_mzd_sqr_even(U, S, cutoff); /* 3 U = S^2 */
_mzd_add(C22, U, C22); /* 4 C22 = U + C22 */
_mzd_add(C12, U, C12); /* 5 C12 = U + C12 */
_mzd_mul_even(U, A12, A21, cutoff); /* 8 U = A12*A21 */
_mzd_add(C11, U, C11); /* 9 C11 = U + C11 */
_mzd_addsqr_even(C11, A11, cutoff); /* 11 C11 = A11^2 + C11 */
_mzd_add(S, S, A12); /* 6 S = S + A12 */
_mzd_addsqr_even(U, S, cutoff); /* 10 U = S^2 + U */
_mzd_add(C12, C12, U); /* 15 C12 = U + C12 */
_mzd_add(S, A11, S); /* 12 S = A11 - S */
_mzd_addmul_even(C12, S, A12, cutoff); /* 14 C12 = S*A12 + C12 */
_mzd_addmul_even(C21, A21, S, cutoff); /* 16 C21 = A21*S + C21 */
_mzd_add(S, A22, A12); /* 17 S = A22 + A12 */
_mzd_addsqr_even(U, S, cutoff); /* 19 U = U - S^2 */
_mzd_add(C21, C21, U); /* 20 C21 = C21 - U */
_mzd_add(C22, C22, U); /* 21 C22 = C22 - U */
/* clean up */
mzd_free_window((mzd_t*)A11); mzd_free_window((mzd_t*)A12);
mzd_free_window((mzd_t*)A21); mzd_free_window((mzd_t*)A22);
mzd_free_window(C11); mzd_free_window(C12);
mzd_free_window(C21); mzd_free_window(C22);
mzd_free(S);
mzd_free(U);
}
/* deal with rest */
/* mmm now denotes the extent of the part covered by the quadrant computation */
mmm *= 2;
if (m > mmm) {
/* | C| |AA| | A|
* Compute | C| += |AA| x | A| */
{
/* columns mmm..m of C += (all of A) * (columns mmm..m of A) */
mzd_t const *A_last_col = mzd_init_window_const(A, 0, mmm, m, m);
mzd_t *C_last_col = mzd_init_window(C, 0, mmm, m, m);
mzd_addmul_m4rm(C_last_col, A, A_last_col, 0);
mzd_free_window((mzd_t*)A_last_col);
mzd_free_window(C_last_col);
}
/* | | | | |A |
* Compute |C | += |AA| x |A | */
{
/* rows mmm..m, columns 0..mmm of C += (rows mmm..m of A) * (columns 0..mmm of A) */
mzd_t const *A_last_row = mzd_init_window_const(A, mmm, 0, m, m);
mzd_t const *A_first_col= mzd_init_window_const(A, 0, 0, m, mmm);
mzd_t *C_last_row = mzd_init_window(C, mmm, 0, m, mmm);
mzd_addmul_m4rm(C_last_row, A_last_row, A_first_col, 0);
mzd_free_window((mzd_t*)A_last_row);
mzd_free_window((mzd_t*)A_first_col);
mzd_free_window(C_last_row);
}
/* Add to | | | A| |C |
* result |A | x | | = | | */
{
/* the leading mmm x mmm part of C still lacks the contribution of the inner
 * indices mmm..m: add (columns mmm..m of A) * (rows mmm..m of A) to it */
mzd_t const *A_last_col = mzd_init_window_const(A, 0, mmm, mmm, m);
mzd_t const *A_last_row = mzd_init_window_const(A, mmm, 0, m, mmm);
mzd_t *C_bulk = mzd_init_window(C, 0, 0, mmm, mmm);
mzd_addmul_m4rm(C_bulk, A_last_col, A_last_row, 0);
mzd_free_window((mzd_t*)A_last_col);
mzd_free_window((mzd_t*)A_last_row);
mzd_free_window(C_bulk);
}
}
__M4RI_DD_MZD(C);
return C;
}
/**
 * Compute C = C + A * B without any argument validation.
 *
 * Picks the squaring variant when A and B are the very same matrix object and
 * the general variant otherwise.
 *
 * Assumes that B and C are aligned in the same manner (as in a Schur complement).
 *
 * \param C      Accumulator, updated in place.
 * \param A      Left factor.
 * \param B      Right factor.
 * \param cutoff Recursion threshold, forwarded unchanged.
 *
 * \return C
 */
mzd_t *_mzd_addmul(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
  if (A == B) {
    return _mzd_addsqr_even(C, A, cutoff);
  }
  return _mzd_addmul_even(C, A, B, cutoff);
}
/**
 * Public entry point: compute C = C + A * B using the Strassen-like routines.
 *
 * Validates the operand shapes and normalises cutoff. When any of A->nrows,
 * A->ncols or B->ncols is zero the product contributes nothing and C is
 * returned without calling the multiplication routines.
 *
 * \param C      Accumulator, or NULL to have an A->nrows x B->ncols matrix
 *               allocated with mzd_init().
 * \param A      Left factor.
 * \param B      Right factor; B->nrows must equal A->ncols.
 * \param cutoff Recursion threshold, must be >= 0; 0 selects
 *               __M4RI_STRASSEN_MUL_CUTOFF.
 *
 * \return The matrix holding the result (C, or the newly allocated matrix).
 */
mzd_t *mzd_addmul(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
  /* reject incompatible inner dimensions and negative cutoffs */
  if (A->ncols != B->nrows) {
    m4ri_die("mzd_addmul: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows);
  }
  if (cutoff < 0) {
    m4ri_die("mzd_addmul: cutoff must be >= 0.\n");
  }

  /* 0 means "use the library default" */
  if (cutoff == 0) {
    cutoff = __M4RI_STRASSEN_MUL_CUTOFF;
  }

  /* round down to a multiple of m4ri_radix, but never below one radix */
  cutoff -= cutoff % m4ri_radix;
  if (cutoff < m4ri_radix) {
    cutoff = m4ri_radix;
  }

  /* allocate the accumulator on demand, otherwise insist on matching dimensions */
  if (C == NULL) {
    C = mzd_init(A->nrows, B->ncols);
  } else if (C->nrows != A->nrows || C->ncols != B->ncols) {
    m4ri_die("mzd_addmul: C (%d x %d) has wrong dimensions, expected (%d x %d)\n",
             C->nrows, C->ncols, A->nrows, B->ncols);
  }

  /* only multiply when every dimension involved is non-empty */
  if (A->nrows != 0 && A->ncols != 0 && B->ncols != 0) {
    C = _mzd_addmul(C, A, B, cutoff);
  }

  __M4RI_DD_MZD(C);
  return C;
}