330 lines
11 KiB
C
330 lines
11 KiB
C
|
|
/*******************************************************************
|
|
*
|
|
* M4RI: Linear Algebra over GF(2)
|
|
*
|
|
* Copyright (C) 2014 Martin Albrecht <martinralbrecht@googlemail.com>
|
|
*
|
|
* Distributed under the terms of the GNU General Public License (GPL)
|
|
* version 2 or higher.
|
|
*
|
|
* This code is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
* General Public License for more details.
|
|
*
|
|
* The full text of the GPL is available at:
|
|
*
|
|
* http://www.gnu.org/licenses/
|
|
*
|
|
********************************************************************/
|
|
|
|
#include "m4ri_config.h"
|
|
#include "misc.h"
|
|
#include "mp.h"
|
|
#include "brilliantrussian.h"
|
|
#include "strassen.h"
|
|
|
|
#if __M4RI_HAVE_OPENMP
|
|
|
|
/* Smaller of two values; only defined here if no MIN macro exists yet.
 * NOTE: function-like macro, each argument may be evaluated twice, so do
 * not pass expressions with side effects. */
#if !defined(MIN)
#define MIN(x, y) (((x) < (y)) ? (x) : (y))
#endif
|
|
|
|
#include <omp.h>
|
|
|
|
// Returns true if a is closer to cutoff than a/2.
|
|
static inline int closer(rci_t a, int cutoff) {
|
|
return 3 * a < 4 * cutoff;
|
|
}
|
|
|
|
|
|
/**
 * \brief C += A*B, with the word-aligned bulk of the product split into the
 * four quadrants of C, one OpenMP section per quadrant.
 *
 * Rows and columns left over after aligning the split are handled
 * afterwards with mzd_addmul_m4rm().
 *
 * \param C Accumulator matrix, updated in place.
 * \param A Left factor.
 * \param B Right factor.
 * \param cutoff Threshold tested with closer() and handed on to
 *               _mzd_addmul_even().
 *
 * \return C.
 */
mzd_t *_mzd_addmul_mp4(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
  /**
   * \todo make sure not to overwrite crap after ncols and before width * m4ri_radix
   */

  /* Base case: at least one dimension is already close to the cutoff. */
  if (closer(A->nrows, cutoff) || closer(A->ncols, cutoff) || closer(B->ncols, cutoff)) {
    /* The product is formed in a scratch matrix first: this costs only
       constant extra memory and improves data locality.  Check for speed
       regressions before removing the copy. */
    /* C = _mzd_mul_m4rm(C, A, B, 0, TRUE); */
    mzd_t *scratch = mzd_init(C->nrows, C->ncols);
    scratch = _mzd_mul_m4rm(scratch, A, B, 0, FALSE);
    mzd_add(C, C, scratch);
    mzd_free(scratch);
    return C;
  }

  /* Trim every dimension down to a multiple of two words so that each
     half starts and ends on a word boundary. */
  rci_t const two_words = 2 * m4ri_radix;
  rci_t const a = A->nrows - A->nrows % two_words;
  rci_t const b = A->ncols - A->ncols % two_words;
  rci_t const c = B->ncols - B->ncols % two_words;

  /* Quadrant sizes: rows/cols of an A block, rows/cols of a B block. */
  rci_t const anr = ((a / m4ri_radix) >> 1) * m4ri_radix;
  rci_t const anc = ((b / m4ri_radix) >> 1) * m4ri_radix;
  rci_t const bnr = anc;
  rci_t const bnc = ((c / m4ri_radix) >> 1) * m4ri_radix;

  /* [i][j] is the quadrant in block-row i, block-column j. */
  mzd_t const *Aq[2][2];
  mzd_t const *Bq[2][2];
  mzd_t *Cq[2][2];

  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      Aq[i][j] = mzd_init_window_const(A, i * anr, j * anc, (i + 1) * anr, (j + 1) * anc);
    }
  }
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      Bq[i][j] = mzd_init_window_const(B, i * bnr, j * bnc, (i + 1) * bnr, (j + 1) * bnc);
    }
  }
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      Cq[i][j] = mzd_init_window(C, i * anr, j * bnc, (i + 1) * anr, (j + 1) * bnc);
    }
  }

  /* Each section owns exactly one quadrant of C:
     C[i][j] += A[i][0]*B[0][j] + A[i][1]*B[1][j]. */
#pragma omp parallel sections
  {
#pragma omp section
    {
      _mzd_addmul_even(Cq[0][0], Aq[0][0], Bq[0][0], cutoff);
      _mzd_addmul_even(Cq[0][0], Aq[0][1], Bq[1][0], cutoff);
    }
#pragma omp section
    {
      _mzd_addmul_even(Cq[0][1], Aq[0][0], Bq[0][1], cutoff);
      _mzd_addmul_even(Cq[0][1], Aq[0][1], Bq[1][1], cutoff);
    }
#pragma omp section
    {
      _mzd_addmul_even(Cq[1][0], Aq[1][0], Bq[0][0], cutoff);
      _mzd_addmul_even(Cq[1][0], Aq[1][1], Bq[1][0], cutoff);
    }
#pragma omp section
    {
      _mzd_addmul_even(Cq[1][1], Aq[1][0], Bq[0][1], cutoff);
      _mzd_addmul_even(Cq[1][1], Aq[1][1], Bq[1][1], cutoff);
    }
  }

  /* Leftovers, part 1: columns of B (and C) to the right of the quadrants. */
  if (B->ncols > 2 * bnc) {
    mzd_t const *B_tail_cols = mzd_init_window_const(B, 0, 2 * bnc, A->ncols, B->ncols);
    mzd_t *C_tail_cols = mzd_init_window(C, 0, 2 * bnc, A->nrows, C->ncols);
    mzd_addmul_m4rm(C_tail_cols, A, B_tail_cols, 0);
    mzd_free_window((mzd_t *)B_tail_cols);
    mzd_free_window(C_tail_cols);
  }

  /* Leftovers, part 2: rows of A (and C) below the quadrants, restricted
     to the columns not already covered by part 1. */
  if (A->nrows > 2 * anr) {
    mzd_t const *A_tail_rows = mzd_init_window_const(A, 2 * anr, 0, A->nrows, A->ncols);
    mzd_t const *B_left = mzd_init_window_const(B, 0, 0, B->nrows, 2 * bnc);
    mzd_t *C_tail_rows = mzd_init_window(C, 2 * anr, 0, C->nrows, 2 * bnc);
    mzd_addmul_m4rm(C_tail_rows, A_tail_rows, B_left, 0);
    mzd_free_window((mzd_t *)A_tail_rows);
    mzd_free_window((mzd_t *)B_left);
    mzd_free_window(C_tail_rows);
  }

  /* Leftovers, part 3: the tail of the inner dimension still has to be
     added onto the quadrant area of C. */
  if (A->ncols > 2 * anc) {
    mzd_t const *A_tail_cols = mzd_init_window_const(A, 0, 2 * anc, 2 * anr, A->ncols);
    mzd_t const *B_tail_rows = mzd_init_window_const(B, 2 * bnr, 0, B->nrows, 2 * bnc);
    mzd_t *C_quads = mzd_init_window(C, 0, 0, 2 * anr, 2 * bnc);
    mzd_addmul_m4rm(C_quads, A_tail_cols, B_tail_rows, 0);
    mzd_free_window((mzd_t *)A_tail_cols);
    mzd_free_window((mzd_t *)B_tail_rows);
    mzd_free_window(C_quads);
  }

  /* Release the quadrant windows. */
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      mzd_free_window((mzd_t *)Aq[i][j]);
    }
  }
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      mzd_free_window((mzd_t *)Bq[i][j]);
    }
  }
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      mzd_free_window(Cq[i][j]);
    }
  }

  __M4RI_DD_MZD(C);
  return C;
}
|
|
|
|
/**
 * \brief C = A*B, with the word-aligned bulk of the product split into the
 * four quadrants of C, one OpenMP section per quadrant.
 *
 * Every entry of C is overwritten: the quadrants by _mzd_mul_even(), the
 * trailing column and row strips by a clearing _mzd_mul_m4rm() call.
 * Previous contents of C therefore never leak into the result.
 *
 * \param C Destination matrix (A->nrows x B->ncols).
 * \param A Left factor.
 * \param B Right factor.
 * \param cutoff Threshold tested with closer() and handed on to
 *               _mzd_mul_even() / _mzd_addmul_even().
 *
 * \return C.
 */
mzd_t *_mzd_mul_mp4(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
  /**
   * \todo make sure not to overwrite crap after ncols and before width * m4ri_radix
   */
  rci_t a = A->nrows;
  rci_t b = A->ncols;
  rci_t c = B->ncols;

  /* handle case first, where the input matrices are too small already */
  if (closer(A->nrows, cutoff) || closer(A->ncols, cutoff) || closer(B->ncols, cutoff)) {
    /* we copy the matrix first since it is only constant memory
       overhead and improves data locality, if you remove it make sure
       there are no speed regressions */
    /* C = _mzd_mul_m4rm(C, A, B, 0, TRUE); */
    mzd_t *Cbar = mzd_init(C->nrows, C->ncols);
    Cbar = _mzd_mul_m4rm(Cbar, A, B, 0, FALSE);
    mzd_copy(C, Cbar);
    mzd_free(Cbar);
    return C;
  }

  /* adjust cutting numbers to work on words */
  {
    rci_t mult = 2 * m4ri_radix;
    a -= a % mult;
    b -= b % mult;
    c -= c % mult;
  }

  /* quadrant sizes: rows/cols of an A block, rows/cols of a B block */
  rci_t anr = ((a / m4ri_radix) >> 1) * m4ri_radix;
  rci_t anc = ((b / m4ri_radix) >> 1) * m4ri_radix;
  rci_t bnr = anc;
  rci_t bnc = ((c / m4ri_radix) >> 1) * m4ri_radix;

  mzd_t const *A00 = mzd_init_window_const(A,   0,   0,   anr,   anc);
  mzd_t const *A01 = mzd_init_window_const(A,   0, anc,   anr, 2*anc);
  mzd_t const *A10 = mzd_init_window_const(A, anr,   0, 2*anr,   anc);
  mzd_t const *A11 = mzd_init_window_const(A, anr, anc, 2*anr, 2*anc);

  mzd_t const *B00 = mzd_init_window_const(B,   0,   0,   bnr,   bnc);
  mzd_t const *B01 = mzd_init_window_const(B,   0, bnc,   bnr, 2*bnc);
  mzd_t const *B10 = mzd_init_window_const(B, bnr,   0, 2*bnr,   bnc);
  mzd_t const *B11 = mzd_init_window_const(B, bnr, bnc, 2*bnr, 2*bnc);

  mzd_t *C00 = mzd_init_window(C,   0,   0,   anr,   bnc);
  mzd_t *C01 = mzd_init_window(C,   0, bnc,   anr, 2*bnc);
  mzd_t *C10 = mzd_init_window(C, anr,   0, 2*anr,   bnc);
  mzd_t *C11 = mzd_init_window(C, anr, bnc, 2*anr, 2*bnc);

  /* Each section owns one quadrant of C: the first product overwrites
     it, the second one is accumulated on top. */
#pragma omp parallel sections
  {
#pragma omp section
    {
      _mzd_mul_even(C00, A00, B00, cutoff);
      _mzd_addmul_even(C00, A01, B10, cutoff);
    }
#pragma omp section
    {
      _mzd_mul_even(C01, A00, B01, cutoff);
      _mzd_addmul_even(C01, A01, B11, cutoff);
    }
#pragma omp section
    {
      _mzd_mul_even(C10, A10, B00, cutoff);
      _mzd_addmul_even(C10, A11, B10, cutoff);
    }
#pragma omp section
    {
      _mzd_mul_even(C11, A10, B01, cutoff);
      _mzd_addmul_even(C11, A11, B11, cutoff);
    }
  }

  /* deal with rest */

  /* Columns right of the quadrants.  Nothing above has written this
     strip of C, so it must be overwritten, not accumulated into: the
     last argument (1, i.e. TRUE) asks _mzd_mul_m4rm to clear the
     destination before multiplying. */
  if (B->ncols > 2 * bnc) {
    mzd_t const *B_last_col = mzd_init_window_const(B, 0, 2*bnc, A->ncols, B->ncols);
    mzd_t *C_last_col = mzd_init_window(C, 0, 2*bnc, A->nrows, C->ncols);
    _mzd_mul_m4rm(C_last_col, A, B_last_col, 0, 1);
    mzd_free_window((mzd_t*)B_last_col);
    mzd_free_window(C_last_col);
  }

  /* Rows below the quadrants (columns 0 .. 2*bnc only; the rest was
     handled just above).  Also untouched so far, hence overwritten as
     well.  When bnc == 0 the strip has no columns and is skipped. */
  if (A->nrows > 2 * anr && bnc > 0) {
    mzd_t const *A_last_row = mzd_init_window_const(A, 2*anr, 0, A->nrows, A->ncols);
    mzd_t const *B_bulk = mzd_init_window_const(B, 0, 0, B->nrows, 2*bnc);
    mzd_t *C_last_row = mzd_init_window(C, 2*anr, 0, C->nrows, 2*bnc);
    _mzd_mul_m4rm(C_last_row, A_last_row, B_bulk, 0, 1);
    mzd_free_window((mzd_t*)A_last_row);
    mzd_free_window((mzd_t*)B_bulk);
    mzd_free_window(C_last_row);
  }

  /* Tail of the inner dimension: this contribution belongs on top of the
     quadrant area, which already holds a partial product, so here
     accumulation is the right operation. */
  if (A->ncols > 2 * anc) {
    mzd_t const *A_last_col = mzd_init_window_const(A, 0, 2*anc, 2*anr, A->ncols);
    mzd_t const *B_last_row = mzd_init_window_const(B, 2*bnr, 0, B->nrows, 2*bnc);
    mzd_t *C_bulk = mzd_init_window(C, 0, 0, 2*anr, 2*bnc);
    mzd_addmul_m4rm(C_bulk, A_last_col, B_last_row, 0);
    mzd_free_window((mzd_t*)A_last_col);
    mzd_free_window((mzd_t*)B_last_row);
    mzd_free_window(C_bulk);
  }

  /* clean up */
  mzd_free_window((mzd_t*)A00); mzd_free_window((mzd_t*)A01);
  mzd_free_window((mzd_t*)A10); mzd_free_window((mzd_t*)A11);

  mzd_free_window((mzd_t*)B00); mzd_free_window((mzd_t*)B01);
  mzd_free_window((mzd_t*)B10); mzd_free_window((mzd_t*)B11);

  mzd_free_window(C00); mzd_free_window(C01);
  mzd_free_window(C10); mzd_free_window(C11);

  __M4RI_DD_MZD(C);
  return C;
}
|
|
|
|
/**
 * \brief Compute C = A*B through _mzd_mul_mp4() after validating the
 * arguments and normalising the cutoff.
 *
 * \param C Destination matrix, or NULL to have an A->nrows x B->ncols
 *          matrix created with mzd_init().
 * \param A Left factor.
 * \param B Right factor; B->nrows must equal A->ncols.
 * \param cutoff Non-negative threshold; 0 selects
 *               __M4RI_STRASSEN_MUL_CUTOFF.  The value is rounded down to
 *               a multiple of m4ri_radix and raised to at least m4ri_radix.
 *
 * \return C (the newly created matrix if C was NULL).
 *
 * Dimension mismatches and a negative cutoff are reported via m4ri_die().
 */
mzd_t *mzd_mul_mp(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
  if (A->ncols != B->nrows) {
    m4ri_die("mzd_mul_mp: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows);
  }
  if (cutoff < 0) {
    m4ri_die("mzd_mul_mp: cutoff must be >= 0.\n");
  }

  /* 0 means "use the library default" */
  if (!cutoff) {
    cutoff = __M4RI_STRASSEN_MUL_CUTOFF;
  }

  /* whole number of words, and at least one word */
  cutoff = (cutoff / m4ri_radix) * m4ri_radix;
  cutoff = (cutoff < m4ri_radix) ? m4ri_radix : cutoff;

  if (C != NULL) {
    if (C->nrows != A->nrows || C->ncols != B->ncols) {
      m4ri_die("mzd_mul_mp: C (%d x %d) has wrong dimensions, expected (%d x %d)\n",
               C->nrows, C->ncols, A->nrows, B->ncols);
    }
  } else {
    C = mzd_init(A->nrows, B->ncols);
  }

  _mzd_mul_mp4(C, A, B, cutoff);
  return C;
}
|
|
|
|
|
|
/**
 * \brief Compute C += A*B through _mzd_addmul_mp4() after validating the
 * arguments and normalising the cutoff.
 *
 * \param C Accumulator matrix, or NULL to have an A->nrows x B->ncols
 *          matrix created with mzd_init().
 * \param A Left factor.
 * \param B Right factor; B->nrows must equal A->ncols.
 * \param cutoff Non-negative threshold; 0 selects
 *               __M4RI_STRASSEN_MUL_CUTOFF.  The value is rounded down to
 *               a multiple of m4ri_radix and raised to at least m4ri_radix.
 *
 * \return C (the newly created matrix if C was NULL).  If any of
 *         A->nrows, A->ncols or B->ncols is 0, C is returned without
 *         calling _mzd_addmul_mp4().
 *
 * Dimension mismatches and a negative cutoff are reported via m4ri_die().
 */
mzd_t *mzd_addmul_mp(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
  if (A->ncols != B->nrows) {
    m4ri_die("mzd_addmul_mp: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows);
  }
  if (cutoff < 0) {
    m4ri_die("mzd_addmul_mp: cutoff must be >= 0.\n");
  }

  /* 0 means "use the library default" */
  if (!cutoff) {
    cutoff = __M4RI_STRASSEN_MUL_CUTOFF;
  }

  /* whole number of words, and at least one word */
  cutoff = (cutoff / m4ri_radix) * m4ri_radix;
  cutoff = (cutoff < m4ri_radix) ? m4ri_radix : cutoff;

  if (C != NULL) {
    if (C->nrows != A->nrows || C->ncols != B->ncols) {
      m4ri_die("mzd_addmul_mp: C (%d x %d) has wrong dimensions, expected (%d x %d)\n",
               C->nrows, C->ncols, A->nrows, B->ncols);
    }
  } else {
    C = mzd_init(A->nrows, B->ncols);
  }

  /* only multiply when no dimension is empty */
  if (A->nrows != 0 && A->ncols != 0 && B->ncols != 0) {
    C = _mzd_addmul_mp4(C, A, B, cutoff);
  }

  __M4RI_DD_MZD(C);
  return C;
}
|
|
|
|
|
|
#endif //__M4RI_HAVE_OPENMP
|