/*******************************************************************
 *
 *                 M4RI: Linear Algebra over GF(2)
 *
 *    Copyright (C) 2014 Martin Albrecht
 *
 *  Distributed under the terms of the GNU General Public License (GPL)
 *  version 2 or higher.
 *
 *    This code is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *    General Public License for more details.
 *
 *  The full text of the GPL is available at:
 *
 *                  http://www.gnu.org/licenses/
 *
 ********************************************************************/

#include "m4ri_config.h"
#include "misc.h"
#include "mp.h"
#include "brilliantrussian.h"
#include "strassen.h"

#if __M4RI_HAVE_OPENMP

#ifndef MIN
#define MIN(a,b) (((a)<(b))?(a):(b))
#endif

#include <omp.h>

/**
 * Returns true if a is closer to cutoff than a/2, i.e. if 3*a < 4*cutoff.
 *
 * Used to decide that a dimension is already too small to be split in two.
 */
static inline int closer(rci_t a, int cutoff) {
  return 3 * a < 4 * cutoff;
}

/**
 * Computes one quadrant of the 2x2 block product:
 *
 *   Cq  = Aleft * Btop + Aright * Bbottom   (accumulate == 0)
 *   Cq += Aleft * Btop + Aright * Bbottom   (accumulate != 0)
 *
 * Only the first product differs between the two modes; the second one
 * always adds onto the result of the first.
 */
static void mp4_quadrant(mzd_t *Cq,
                         mzd_t const *Aleft, mzd_t const *Btop,
                         mzd_t const *Aright, mzd_t const *Bbottom,
                         int cutoff, int accumulate) {
  if (accumulate) {
    _mzd_addmul_even(Cq, Aleft, Btop, cutoff);
  } else {
    _mzd_mul_even(Cq, Aleft, Btop, cutoff);
  }
  _mzd_addmul_even(Cq, Aright, Bbottom, cutoff);
}

/**
 * Handles a remainder strip of C that no quadrant has written to.
 *
 * When accumulating, the product is added onto the strip.  Otherwise the
 * strip still holds whatever the caller passed in, so it must be
 * overwritten (clear == TRUE) rather than added to.
 */
static void mp4_remainder(mzd_t *Cw, mzd_t const *Aw, mzd_t const *Bw, int accumulate) {
  if (accumulate) {
    mzd_addmul_m4rm(Cw, Aw, Bw, 0);
  } else if (Cw->nrows != 0 && Cw->ncols != 0) {
    /* An empty window has nothing to overwrite, so the kernel is skipped. */
    _mzd_mul_m4rm(Cw, Aw, Bw, 0, TRUE);
  }
}

/**
 * Shared implementation of _mzd_mul_mp4() and _mzd_addmul_mp4().
 *
 * Splits A, B and C into 2x2 blocks aligned to multiples of m4ri_radix,
 * computes the four quadrants of C in four OpenMP sections (each section
 * writes to its own quadrant window), and then processes the rows/columns
 * that did not fit into the aligned blocks sequentially.
 *
 * \param C          Result matrix, A->nrows x B->ncols.
 * \param A          Left factor.
 * \param B          Right factor.
 * \param cutoff     Dimension below which no further splitting is done.
 * \param accumulate Non-zero: C += A*B.  Zero: C = A*B.
 *
 * \return C
 */
static mzd_t *mp4_product(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff, int accumulate) {
  /**
   * \todo make sure not to overwrite crap after ncols and before width * m4ri_radix
   */
  rci_t a = A->nrows;
  rci_t b = A->ncols;
  rci_t c = B->ncols;

  /* handle case first, where the input matrices are too small already */
  if (closer(A->nrows, cutoff) || closer(A->ncols, cutoff) || closer(B->ncols, cutoff)) {
    /* we copy the matrix first since it is only constant memory
       overhead and improves data locality, if you remove it make sure
       there are no speed regressions */
    mzd_t *Cbar = mzd_init(C->nrows, C->ncols);
    Cbar = _mzd_mul_m4rm(Cbar, A, B, 0, FALSE);
    if (accumulate) {
      mzd_add(C, C, Cbar);
    } else {
      mzd_copy(C, Cbar);
    }
    mzd_free(Cbar);
    return C;
  }

  /* adjust cutting numbers to work on words */
  {
    rci_t mult = 2 * m4ri_radix;
    a -= a % mult;
    b -= b % mult;
    c -= c % mult;
  }

  rci_t anr = ((a / m4ri_radix) >> 1) * m4ri_radix;
  rci_t anc = ((b / m4ri_radix) >> 1) * m4ri_radix;
  rci_t bnr = anc;
  rci_t bnc = ((c / m4ri_radix) >> 1) * m4ri_radix;

  mzd_t const *A00 = mzd_init_window_const(A,   0,   0,   anr,   anc);
  mzd_t const *A01 = mzd_init_window_const(A,   0, anc,   anr, 2*anc);
  mzd_t const *A10 = mzd_init_window_const(A, anr,   0, 2*anr,   anc);
  mzd_t const *A11 = mzd_init_window_const(A, anr, anc, 2*anr, 2*anc);

  mzd_t const *B00 = mzd_init_window_const(B,   0,   0,   bnr,   bnc);
  mzd_t const *B01 = mzd_init_window_const(B,   0, bnc,   bnr, 2*bnc);
  mzd_t const *B10 = mzd_init_window_const(B, bnr,   0, 2*bnr,   bnc);
  mzd_t const *B11 = mzd_init_window_const(B, bnr, bnc, 2*bnr, 2*bnc);

  mzd_t *C00 = mzd_init_window(C,   0,   0,   anr,   bnc);
  mzd_t *C01 = mzd_init_window(C,   0, bnc,   anr, 2*bnc);
  mzd_t *C10 = mzd_init_window(C, anr,   0, 2*anr,   bnc);
  mzd_t *C11 = mzd_init_window(C, anr, bnc, 2*anr, 2*bnc);

  /* one section per quadrant of C */
#pragma omp parallel sections
  {
#pragma omp section
    {
      mp4_quadrant(C00, A00, B00, A01, B10, cutoff, accumulate);
    }
#pragma omp section
    {
      mp4_quadrant(C01, A00, B01, A01, B11, cutoff, accumulate);
    }
#pragma omp section
    {
      mp4_quadrant(C10, A10, B00, A11, B10, cutoff, accumulate);
    }
#pragma omp section
    {
      mp4_quadrant(C11, A10, B01, A11, B11, cutoff, accumulate);
    }
  }

  /* deal with rest */

  /* columns of C to the right of the quadrants, over all rows */
  if (B->ncols > 2 * bnc) {
    mzd_t const *B_last_col = mzd_init_window_const(B, 0, 2*bnc, A->ncols, B->ncols);
    mzd_t *C_last_col = mzd_init_window(C, 0, 2*bnc, A->nrows, C->ncols);
    mp4_remainder(C_last_col, A, B_last_col, accumulate);
    mzd_free_window((mzd_t*)B_last_col);
    mzd_free_window(C_last_col);
  }

  /* rows of C below the quadrants, restricted to the quadrant columns */
  if (A->nrows > 2 * anr) {
    mzd_t const *A_last_row = mzd_init_window_const(A, 2*anr, 0, A->nrows, A->ncols);
    mzd_t const *B_bulk = mzd_init_window_const(B, 0, 0, B->nrows, 2*bnc);
    mzd_t *C_last_row = mzd_init_window(C, 2*anr, 0, C->nrows, 2*bnc);
    mp4_remainder(C_last_row, A_last_row, B_bulk, accumulate);
    mzd_free_window((mzd_t*)A_last_row);
    mzd_free_window((mzd_t*)B_bulk);
    mzd_free_window(C_last_row);
  }

  /* leftover of the inner dimension: always added onto the quadrants,
     which already hold the contribution of the aligned part */
  if (A->ncols > 2 * anc) {
    mzd_t const *A_last_col = mzd_init_window_const(A, 0, 2*anc, 2*anr, A->ncols);
    mzd_t const *B_last_row = mzd_init_window_const(B, 2*bnr, 0, B->nrows, 2*bnc);
    mzd_t *C_bulk = mzd_init_window(C, 0, 0, 2*anr, 2*bnc);
    mzd_addmul_m4rm(C_bulk, A_last_col, B_last_row, 0);
    mzd_free_window((mzd_t*)A_last_col);
    mzd_free_window((mzd_t*)B_last_row);
    mzd_free_window(C_bulk);
  }

  /* clean up */
  mzd_free_window((mzd_t*)A00); mzd_free_window((mzd_t*)A01);
  mzd_free_window((mzd_t*)A10); mzd_free_window((mzd_t*)A11);

  mzd_free_window((mzd_t*)B00); mzd_free_window((mzd_t*)B01);
  mzd_free_window((mzd_t*)B10); mzd_free_window((mzd_t*)B11);

  mzd_free_window(C00); mzd_free_window(C01);
  mzd_free_window(C10); mzd_free_window(C11);

  __M4RI_DD_MZD(C);
  return C;
}

/**
 * C += A*B using four OpenMP sections, one per quadrant of C.
 *
 * No argument checking is done here; see mzd_addmul_mp() for the checked
 * entry point.
 *
 * \return C
 */
mzd_t *_mzd_addmul_mp4(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
  return mp4_product(C, A, B, cutoff, 1);
}

/**
 * C = A*B using four OpenMP sections, one per quadrant of C.
 *
 * The previous contents of C are overwritten everywhere, including the
 * rows and columns outside the word-aligned quadrants.
 *
 * No argument checking is done here; see mzd_mul_mp() for the checked
 * entry point.
 *
 * \return C
 */
mzd_t *_mzd_mul_mp4(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
  return mp4_product(C, A, B, cutoff, 0);
}

/**
 * Checked entry point for C = A*B.
 *
 * \param C      Result matrix of dimension A->nrows x B->ncols, or NULL
 *               to have it allocated here.
 * \param A      Left factor.
 * \param B      Right factor; B->nrows must equal A->ncols.
 * \param cutoff Must be >= 0.  0 selects __M4RI_STRASSEN_MUL_CUTOFF; the
 *               value is rounded down to a multiple of m4ri_radix and
 *               raised to at least m4ri_radix.
 *
 * Calls m4ri_die() on mismatching dimensions or a negative cutoff.
 *
 * \return C
 */
mzd_t *mzd_mul_mp(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
  if (A->ncols != B->nrows) {
    m4ri_die("mzd_mul_mp: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows);
  }

  if (cutoff < 0) {
    m4ri_die("mzd_mul_mp: cutoff must be >= 0.\n");
  }

  if (cutoff == 0) {
    cutoff = __M4RI_STRASSEN_MUL_CUTOFF;
  }

  /* round down to whole words, but never below a single word */
  cutoff = cutoff / m4ri_radix * m4ri_radix;
  if (cutoff < m4ri_radix) {
    cutoff = m4ri_radix;
  }

  if (C == NULL) {
    C = mzd_init(A->nrows, B->ncols);
  } else if (C->nrows != A->nrows || C->ncols != B->ncols) {
    m4ri_die("mzd_mul_mp: C (%d x %d) has wrong dimensions, expected (%d x %d)\n",
             C->nrows, C->ncols, A->nrows, B->ncols);
  }

  _mzd_mul_mp4(C, A, B, cutoff);
  return C;
}

/**
 * Checked entry point for C += A*B.
 *
 * \param C      Matrix of dimension A->nrows x B->ncols to add the
 *               product to, or NULL to have it allocated here.
 * \param A      Left factor.
 * \param B      Right factor; B->nrows must equal A->ncols.
 * \param cutoff Must be >= 0.  0 selects __M4RI_STRASSEN_MUL_CUTOFF; the
 *               value is rounded down to a multiple of m4ri_radix and
 *               raised to at least m4ri_radix.
 *
 * Calls m4ri_die() on mismatching dimensions or a negative cutoff.
 * If any of A->nrows, A->ncols or B->ncols is zero, C is returned
 * without calling the multiplication routine.
 *
 * \return C
 */
mzd_t *mzd_addmul_mp(mzd_t *C, mzd_t const *A, mzd_t const *B, int cutoff) {
  if (A->ncols != B->nrows) {
    m4ri_die("mzd_addmul_mp: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows);
  }

  if (cutoff < 0) {
    m4ri_die("mzd_addmul_mp: cutoff must be >= 0.\n");
  }

  if (cutoff == 0) {
    cutoff = __M4RI_STRASSEN_MUL_CUTOFF;
  }

  /* round down to whole words, but never below a single word */
  cutoff = cutoff / m4ri_radix * m4ri_radix;
  if (cutoff < m4ri_radix) {
    cutoff = m4ri_radix;
  }

  if (C == NULL) {
    C = mzd_init(A->nrows, B->ncols);
  } else if (C->nrows != A->nrows || C->ncols != B->ncols) {
    m4ri_die("mzd_addmul_mp: C (%d x %d) has wrong dimensions, expected (%d x %d)\n",
             C->nrows, C->ncols, A->nrows, B->ncols);
  }

  if (A->nrows == 0 || A->ncols == 0 || B->ncols == 0) {
    __M4RI_DD_MZD(C);
    return C;
  }

  C = _mzd_addmul_mp4(C, A, B, cutoff);
  __M4RI_DD_MZD(C);
  return C;
}

#endif //__M4RI_HAVE_OPENMP