490 lines
21 KiB
C

/*******************************************************************
*
* M4RI: Linear Algebra over GF(2)
*
* Copyright (C) 2008 Clement Pernet <clement.pernet@gmail.com>
*
* Distributed under the terms of the GNU General Public License (GPL)
* version 2 or higher.
*
* This code is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* The full text of the GPL is available at:
*
* http://www.gnu.org/licenses/
*
********************************************************************/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include <stdio.h>
#include "triangular.h"
#include "triangular_russian.h"
#include "strassen.h"
#include "mzd.h"
#include "parity.h"
/*****************
* UPPER RIGHT
****************/
/*
 * Forward declaration of the base case of the upper-right solve; the
 * definition appears further down in this file.
 *
 * The base case only reads and writes word 0 of every row of U and B, so it
 * assumes that U->ncols <= m4ri_radix. _mzd_trsm_upper_right() dispatches to
 * it only when B->ncols <= m4ri_radix.
 */
void _mzd_trsm_upper_right_base(mzd_t const *U, mzd_t *B);
/**
 * Dimension-checked entry point for the upper-triangular right-hand solve.
 *
 * Two conditions are verified before any work is done: U->nrows must equal
 * B->ncols, and U must be square. A violation is reported through m4ri_die()
 * together with the offending dimensions. Afterwards the call is forwarded
 * to _mzd_trsm_upper_right(), which operates on B in place.
 *
 * \param U      Triangular input matrix; not modified.
 * \param B      Matrix handed to the solver; overwritten.
 * \param cutoff Passed on unchanged to _mzd_trsm_upper_right().
 */
void mzd_trsm_upper_right(mzd_t const *U, mzd_t *B, const int cutoff) {
  if (U->nrows != B->ncols) {
    m4ri_die("mzd_trsm_upper_right: U nrows (%d) need to match B ncols (%d).\n", U->nrows, B->ncols);
  }
  if (U->nrows != U->ncols) {
    m4ri_die("mzd_trsm_upper_right: U must be square and is found to be (%d) x (%d).\n", U->nrows, U->ncols);
  }

  _mzd_trsm_upper_right(U, B, cutoff);
}
/**
 * Variant of the upper-right solve that multiplies by a transformed copy of U
 * instead of substituting column by column.
 *
 * A fresh matrix is obtained from mzd_extract_u(NULL, U) (presumably the
 * upper-triangular part of U), handed to mzd_trtri_upper() which rewrites it
 * in place, and B is then replaced by the product B * (that matrix). Both
 * temporaries are released before returning.
 *
 * \param U Triangular input matrix; not modified.
 * \param B Overwritten with the product.
 */
void _mzd_trsm_upper_right_trtri(mzd_t const *U, mzd_t *B) {
  mzd_t *Uinv = mzd_extract_u(NULL, U);
  mzd_trtri_upper(Uinv);

  /* The product goes into a temporary first, then replaces B. */
  mzd_t *product = mzd_mul(NULL, B, Uinv, 0);
  mzd_copy(B, product);

  mzd_free(product);
  mzd_free(Uinv);
}
/**
 * Recursive driver of the upper-right solve; works on B in place and performs
 * no dimension checks of its own.
 *
 * Dispatch on the number of columns of B:
 *  - at most m4ri_radix columns: _mzd_trsm_upper_right_base();
 *  - at most __M4RI_MUL_BLOCKSIZE columns: _mzd_trsm_upper_right_trtri();
 *  - otherwise U and B are cut at a column index that is a multiple of
 *    m4ri_radix (roughly the middle):
 *
 * \verbatim
 *        [ U00  U01 ]
 *    U = [          ]        B = [ B0  B1 ]
 *        [      U11 ]
 * \endverbatim
 *
 *    B0 is solved against U00, B1 is then updated by
 *    mzd_addmul(B1, B0, U01, cutoff), and finally B1 is solved against U11.
 *
 * \param U      Triangular input matrix; not modified.
 * \param B      Overwritten with the result.
 * \param cutoff Forwarded to mzd_addmul() and to the recursive calls.
 */
void _mzd_trsm_upper_right(mzd_t const *U, mzd_t *B, const int cutoff) {
  rci_t const nrows_b = B->nrows;
  rci_t const ncols_b = B->ncols;

  /* Small problems are delegated and never reach the splitting code. */
  if (ncols_b <= m4ri_radix) {
    _mzd_trsm_upper_right_base(U, B);
    return;
  }
  if (ncols_b <= __M4RI_MUL_BLOCKSIZE) {
    _mzd_trsm_upper_right_trtri(U, B);
    return;
  }

  /* Half of the number of words per row (rounded down), expressed in columns. */
  rci_t const split = (((ncols_b - 1) / m4ri_radix + 1) >> 1) * m4ri_radix;

  mzd_t *Bleft  = mzd_init_window(B, 0, 0, nrows_b, split);
  mzd_t *Bright = mzd_init_window(B, 0, split, nrows_b, ncols_b);
  mzd_t const *Utl = mzd_init_window_const(U, 0, 0, split, split);
  mzd_t const *Utr = mzd_init_window_const(U, 0, split, split, ncols_b);
  mzd_t const *Ubr = mzd_init_window_const(U, split, split, ncols_b, ncols_b);

  _mzd_trsm_upper_right(Utl, Bleft, cutoff);
  mzd_addmul(Bright, Bleft, Utr, cutoff);
  _mzd_trsm_upper_right(Ubr, Bright, cutoff);

  mzd_free_window(Bleft);
  mzd_free_window(Bright);
  mzd_free_window((mzd_t*)Utl);
  mzd_free_window((mzd_t*)Utr);
  mzd_free_window((mzd_t*)Ubr);

  __M4RI_DD_MZD(B);
}
/**
 * Base case of the upper-right solve: column-by-column substitution using
 * only word 0 of each row of U and B.
 *
 * For every column index c = 1 .. B->ncols-1, in increasing order, the entries
 * of column c of U that lie strictly above the diagonal are packed into one
 * word, and for every row of B the parity of (row & packed column) is XORed
 * into bit c of that row:
 *   X_c = B_c + X_{0..c-1} U_{0..c-1,c}
 * The diagonal of U is never read. Parities are evaluated 64 rows at a time
 * through m4ri_parity64().
 *
 * \param U Triangular input matrix; not modified.
 * \param B Overwritten with the result.
 */
void _mzd_trsm_upper_right_base(mzd_t const *U, mzd_t *B) {
  /* Number of words handed to one m4ri_parity64() call. */
  enum { PARITY_BATCH = 64 };

  rci_t const nrows_b = B->nrows;
  rci_t const ncols_b = B->ncols;

  for (rci_t col = 1; col < ncols_b; ++col) {
    /* Bit r of ucol mirrors entry (r, col) of U, for r < col. */
    word ucol = 0;
    for (rci_t r = 0; r < col; ++r) {
      if (__M4RI_GET_BIT(U->rows[r][0], col))
        __M4RI_SET_BIT(ucol, r);
    }

    word buf[PARITY_BATCH];
    rci_t first;

    /*
     * Full batches. The strict '<' leaves the last batch to the tail code
     * below, even when it is complete.
     */
    for (first = 0; first + m4ri_radix < nrows_b; first += m4ri_radix) {
      word **batch = B->rows + first;

      for (int r = 0; r < PARITY_BATCH; ++r)
        buf[r] = batch[r][0] & ucol;

      /* Bit r of the result is the parity of buf[r]. */
      word const parities = m4ri_parity64(buf);

      for (int r = 0; r < PARITY_BATCH; ++r)
        batch[r][0] ^= ((parities >> r) & m4ri_one) << col;
    }

    /* Tail: the remaining rows, padded with zero words up to a full batch. */
    for (int r = 0; first + r < nrows_b; ++r)
      buf[r] = B->rows[first + r][0] & ucol;
    for (int r = nrows_b - first; r < PARITY_BATCH; ++r)
      buf[r] = 0;

    word const parities = m4ri_parity64(buf);

    for (int r = 0; first + r < nrows_b; ++r)
      if (__M4RI_GET_BIT(parities, r))
        __M4RI_FLIP_BIT(B->rows[first + r][0], col);
  }

  __M4RI_DD_MZD(B);
}
/*****************
* LOWER RIGHT
****************/
/*
 * Forward declaration of the base case of the lower-right solve; the
 * definition appears further down in this file. It only reads and writes
 * word 0 of every row, and _mzd_trsm_lower_right() dispatches to it only when
 * B->ncols <= m4ri_radix.
 */
void _mzd_trsm_lower_right_base(mzd_t const *L, mzd_t *B);
/**
 * Dimension-checked entry point for the lower-triangular right-hand solve.
 *
 * L->nrows must equal B->ncols and L must be square; a violation is reported
 * through m4ri_die() together with the offending dimensions. Afterwards the
 * call is forwarded to _mzd_trsm_lower_right(), which operates on B in place.
 *
 * \param L      Triangular input matrix; not modified.
 * \param B      Matrix handed to the solver; overwritten.
 * \param cutoff Passed on unchanged to _mzd_trsm_lower_right().
 */
void mzd_trsm_lower_right(mzd_t const *L, mzd_t *B, const int cutoff) {
  if (L->nrows != B->ncols) {
    m4ri_die("mzd_trsm_lower_right: L nrows (%d) need to match B ncols (%d).\n", L->nrows, B->ncols);
  }
  if (L->nrows != L->ncols) {
    m4ri_die("mzd_trsm_lower_right: L must be square and is found to be (%d) x (%d).\n", L->nrows, L->ncols);
  }

  _mzd_trsm_lower_right(L, B, cutoff);
}
/**
 * Recursive driver of the lower-right solve; works on B in place and performs
 * no dimension checks of its own.
 *
 * When B has at most m4ri_radix columns the work is done by
 * _mzd_trsm_lower_right_base(). Otherwise L and B are cut at a column index
 * that is a multiple of m4ri_radix (roughly the middle):
 *
 * \verbatim
 *        [ L00      ]
 *    L = [          ]        B = [ B0  B1 ]
 *        [ L10  L11 ]
 * \endverbatim
 *
 * B1 is solved against L11 first, B0 is then updated by
 * mzd_addmul(B0, B1, L10, cutoff), and finally B0 is solved against L00.
 *
 * \param L      Triangular input matrix; not modified.
 * \param B      Overwritten with the result.
 * \param cutoff Forwarded to mzd_addmul() and to the recursive calls.
 */
void _mzd_trsm_lower_right(mzd_t const *L, mzd_t *B, const int cutoff) {
  rci_t const nrows_b = B->nrows;
  rci_t const ncols_b = B->ncols;

  if (ncols_b <= m4ri_radix) {
    _mzd_trsm_lower_right_base(L, B);
  } else {
    /* Half of the number of words per row (rounded down), expressed in columns. */
    rci_t const split = (((ncols_b - 1) / m4ri_radix + 1) >> 1) * m4ri_radix;

    mzd_t *Bleft  = mzd_init_window(B, 0, 0, nrows_b, split);
    mzd_t *Bright = mzd_init_window(B, 0, split, nrows_b, ncols_b);
    mzd_t const *Ltl = mzd_init_window_const(L, 0, 0, split, split);
    mzd_t const *Lbl = mzd_init_window_const(L, split, 0, ncols_b, split);
    mzd_t const *Lbr = mzd_init_window_const(L, split, split, ncols_b, ncols_b);

    /* Right half first: it feeds the update of the left half. */
    _mzd_trsm_lower_right(Lbr, Bright, cutoff);
    mzd_addmul(Bleft, Bright, Lbl, cutoff);
    _mzd_trsm_lower_right(Ltl, Bleft, cutoff);

    mzd_free_window(Bleft);
    mzd_free_window(Bright);
    mzd_free_window((mzd_t*)Ltl);
    mzd_free_window((mzd_t*)Lbl);
    mzd_free_window((mzd_t*)Lbr);

    __M4RI_DD_MZD(B);
  }
}
/**
 * Base case of the lower-right solve: column-by-column substitution using
 * only word 0 of each row of L and B.
 *
 * Columns are processed from c = B->ncols-1 down to 0. For each c the entries
 * of column c of L that lie strictly below the diagonal are packed into one
 * word, and for every row of B the parity of (row & packed column) is XORed
 * into bit c of that row:
 *   X_c = B_c + X_{c+1..n} L_{c+1..n,c}
 * The diagonal of L is never read. Parities are evaluated 64 rows at a time
 * through m4ri_parity64().
 *
 * \param L Triangular input matrix; not modified.
 * \param B Overwritten with the result.
 */
void _mzd_trsm_lower_right_base(mzd_t const *L, mzd_t *B) {
  /* Number of words handed to one m4ri_parity64() call. */
  enum { PARITY_BATCH = 64 };

  rci_t const nrows_b = B->nrows;
  rci_t const ncols_b = B->ncols;

  for (rci_t col = ncols_b - 1; col >= 0; --col) {
    /* Bit r of lcol mirrors entry (r, col) of L, for r > col. */
    word lcol = 0;
    for (rci_t r = col + 1; r < ncols_b; ++r) {
      if (__M4RI_GET_BIT(L->rows[r][0], col))
        __M4RI_SET_BIT(lcol, r);
    }

    word buf[PARITY_BATCH];
    rci_t first;

    /*
     * Full batches. The strict '<' leaves the last batch to the tail code
     * below, even when it is complete.
     */
    for (first = 0; first + m4ri_radix < nrows_b; first += m4ri_radix) {
      word **batch = B->rows + first;

      for (int r = 0; r < PARITY_BATCH; ++r)
        buf[r] = batch[r][0] & lcol;

      /* Bit r of the result is the parity of buf[r]. */
      word const parities = m4ri_parity64(buf);

      for (int r = 0; r < PARITY_BATCH; ++r)
        batch[r][0] ^= ((parities >> r) & m4ri_one) << col;
    }

    /* Tail: the remaining rows, padded with zero words up to a full batch. */
    for (int r = 0; first + r < nrows_b; ++r)
      buf[r] = B->rows[first + r][0] & lcol;
    for (int r = nrows_b - first; r < PARITY_BATCH; ++r)
      buf[r] = 0;

    word const parities = m4ri_parity64(buf);

    for (int r = 0; first + r < nrows_b; ++r)
      if (__M4RI_GET_BIT(parities, r))
        __M4RI_FLIP_BIT(B->rows[first + r][0], col);
  }

  __M4RI_DD_MZD(B);
}
/*****************
* LOWER LEFT
****************/
/**
 * Dimension-checked entry point for the lower-triangular left-hand solve.
 *
 * L->ncols must equal B->nrows and L must be square; a violation is reported
 * through m4ri_die() together with the offending dimensions. Afterwards the
 * call is forwarded to _mzd_trsm_lower_left(), which operates on B in place.
 *
 * \param L      Triangular input matrix; not modified.
 * \param B      Matrix handed to the solver; overwritten.
 * \param cutoff Passed on unchanged to _mzd_trsm_lower_left().
 */
void mzd_trsm_lower_left(mzd_t const *L, mzd_t *B, const int cutoff) {
  if (L->ncols != B->nrows) {
    m4ri_die("mzd_trsm_lower_left: L ncols (%d) need to match B nrows (%d).\n", L->ncols, B->nrows);
  }
  if (L->nrows != L->ncols) {
    m4ri_die("mzd_trsm_lower_left: L must be square and is found to be (%d) x (%d).\n", L->nrows, L->ncols);
  }

  _mzd_trsm_lower_left(L, B, cutoff);
}
/**
 * Recursive driver of the lower-left solve; works on B in place and performs
 * no dimension checks of its own.
 *
 * Dispatch on the number of rows of B:
 *  - at most m4ri_radix rows: forward substitution, row by row,
 *      X_i = B_i + L_{i,0..i-1} X_{0..i-1}
 *    reading only word 0 of each row of L (the diagonal of L is never read);
 *  - at most __M4RI_MUL_BLOCKSIZE rows: _mzd_trsm_lower_left_russian();
 *  - otherwise L and B are cut at a row index that is a multiple of
 *    m4ri_radix (roughly the middle):
 *
 * \verbatim
 *        [ L00      ]            [ B0 ]
 *    L = [          ]        B = [    ]
 *        [ L10  L11 ]            [ B1 ]
 * \endverbatim
 *
 *    B0 is solved against L00, B1 is then updated by
 *    mzd_addmul(B1, L10, B0, cutoff), and finally B1 is solved against L11.
 *
 * \param L      Triangular input matrix; not modified.
 * \param B      Overwritten with the result.
 * \param cutoff Forwarded to mzd_addmul() and to the recursive calls.
 */
void _mzd_trsm_lower_left(mzd_t const *L, mzd_t *B, const int cutoff) {
  rci_t const mb = B->nrows;
  rci_t const nb = B->ncols;

  if(mb <= m4ri_radix) {
    /* base case */

    /*
     * A B without any words per row has nothing to update; skipping it also
     * keeps the last-word index below from becoming -1.
     */
    if(B->width > 0) {
      int const nbrest = nb % m4ri_radix;
      /* Only the bits selected by mask_end are changed in the last word. */
      word const mask_end = __M4RI_LEFT_BITMASK(nbrest);
      wi_t const last = B->width - 1;

      for(rci_t i = 1; i < mb; ++i) {
        /* Computes X_i = B_i + L_{i,0..i-1} X_{0..i-1} */
        word const *Lrow = L->rows[i];
        word *Brow = B->rows[i];

        for(rci_t k = 0; k < i; ++k) {
          if(__M4RI_GET_BIT(Lrow[0], k)) {
            /* Row k was finalised in an earlier iteration (k < i). */
            word const *Xrow = B->rows[k];
            for(wi_t j = 0; j < last; ++j)
              Brow[j] ^= Xrow[j];
            Brow[last] ^= Xrow[last] & mask_end;
          }
        }
      }
    }
  } else if(mb <= __M4RI_MUL_BLOCKSIZE) {
    _mzd_trsm_lower_left_russian(L, B, 0);
  } else {
    /* Half of the number of 'radix'-sized row groups (rounded down), in rows. */
    rci_t const mb1 = (((mb - 1) / m4ri_radix + 1) >> 1) * m4ri_radix;

    mzd_t *B0 = mzd_init_window(B, 0, 0, mb1, nb);
    mzd_t *B1 = mzd_init_window(B, mb1, 0, mb, nb);
    mzd_t const *L00 = mzd_init_window_const(L, 0, 0, mb1, mb1);
    mzd_t const *L10 = mzd_init_window_const(L, mb1, 0, mb, mb1);
    mzd_t const *L11 = mzd_init_window_const(L, mb1, mb1, mb, mb);

    _mzd_trsm_lower_left(L00, B0, cutoff);
    mzd_addmul(B1, L10, B0, cutoff);
    _mzd_trsm_lower_left(L11, B1, cutoff);

    mzd_free_window(B0);
    mzd_free_window(B1);
    mzd_free_window((mzd_t*)L00);
    mzd_free_window((mzd_t*)L10);
    mzd_free_window((mzd_t*)L11);
  }

  __M4RI_DD_MZD(B);
}
/*****************
* UPPER LEFT
****************/
/**
 * Dimension-checked entry point for the upper-triangular left-hand solve.
 *
 * U->ncols must equal B->nrows and U must be square; a violation is reported
 * through m4ri_die() together with the offending dimensions. Afterwards the
 * call is forwarded to _mzd_trsm_upper_left(), which operates on B in place.
 *
 * \param U      Triangular input matrix; not modified.
 * \param B      Matrix handed to the solver; overwritten.
 * \param cutoff Passed on unchanged to _mzd_trsm_upper_left().
 */
void mzd_trsm_upper_left(mzd_t const *U, mzd_t *B, const int cutoff) {
  if (U->ncols != B->nrows) {
    m4ri_die("mzd_trsm_upper_left: U ncols (%d) need to match B nrows (%d).\n", U->ncols, B->nrows);
  }
  if (U->nrows != U->ncols) {
    m4ri_die("mzd_trsm_upper_left: U must be square and is found to be (%d) x (%d).\n", U->nrows, U->ncols);
  }

  _mzd_trsm_upper_left(U, B, cutoff);
}
/**
 * Recursive driver of the upper-left solve; works on B in place and performs
 * no dimension checks of its own.
 *
 * Dispatch on the number of rows of B:
 *  - at most m4ri_radix rows: backward substitution, from the second-to-last
 *    row up to row 0,
 *      X_i = B_i + U_{i,i+1..mb} X_{i+1..mb}
 *    reading only word 0 of each row of U (the diagonal of U is never read);
 *  - at most __M4RI_MUL_BLOCKSIZE rows: _mzd_trsm_upper_left_russian();
 *  - otherwise U and B are cut at a row index that is a multiple of
 *    m4ri_radix (roughly the middle):
 *
 * \verbatim
 *        [ U00  U01 ]            [ B0 ]
 *    U = [          ]        B = [    ]
 *        [      U11 ]            [ B1 ]
 * \endverbatim
 *
 *    B1 is solved against U11 first, B0 is then updated by
 *    _mzd_addmul(B0, U01, B1, cutoff), and finally B0 is solved against U00.
 *
 * \param U      Triangular input matrix; not modified.
 * \param B      Overwritten with the result.
 * \param cutoff Forwarded to _mzd_addmul() and to the recursive calls.
 */
void _mzd_trsm_upper_left(mzd_t const *U, mzd_t *B, const int cutoff) {
  rci_t const mb = B->nrows;
  rci_t const nb = B->ncols;

  if(mb <= m4ri_radix) {
    /* base case */

    /*
     * A B without any words per row has nothing to update; skipping it also
     * keeps the last-word index below from becoming -1.
     */
    if(B->width > 0) {
      /* Only the bits selected by mask_end are changed in the last word. */
      word const mask_end = B->high_bitmask;
      wi_t const last = B->width - 1;

      /* The last row needs no work, so start at mb - 2. */
      for(rci_t i = mb - 2; i >= 0; --i) {
        /* Computes X_i = B_i + U_{i,i+1..mb} X_{i+1..mb} */
        word const *Urow = U->rows[i];
        word *Brow = B->rows[i];

        for(rci_t k = i + 1; k < mb; ++k) {
          if(__M4RI_GET_BIT(Urow[0], k)) {
            /* Row k was finalised in an earlier iteration (k > i). */
            word const *Xrow = B->rows[k];
            for(wi_t j = 0; j < last; ++j)
              Brow[j] ^= Xrow[j];
            Brow[last] ^= Xrow[last] & mask_end;
          }
        }
      }
    }
  } else if(mb <= __M4RI_MUL_BLOCKSIZE) {
    _mzd_trsm_upper_left_russian(U, B, 0);
  } else {
    /* Half of the number of 'radix'-sized row groups (rounded down), in rows. */
    rci_t const mb1 = (((mb - 1) / m4ri_radix + 1) >> 1) * m4ri_radix;

    mzd_t *B0 = mzd_init_window(B, 0, 0, mb1, nb);
    mzd_t *B1 = mzd_init_window(B, mb1, 0, mb, nb);
    mzd_t const *U00 = mzd_init_window_const(U, 0, 0, mb1, mb1);
    mzd_t const *U01 = mzd_init_window_const(U, 0, mb1, mb1, mb);
    mzd_t const *U11 = mzd_init_window_const(U, mb1, mb1, mb, mb);

    _mzd_trsm_upper_left(U11, B1, cutoff);
    _mzd_addmul(B0, U01, B1, cutoff);
    _mzd_trsm_upper_left(U00, B0, cutoff);

    mzd_free_window(B0);
    mzd_free_window(B1);
    mzd_free_window((mzd_t*)U00);
    mzd_free_window((mzd_t*)U01);
    mzd_free_window((mzd_t*)U11);
  }

  __M4RI_DD_MZD(B);
}
/**
 * Rewrites the upper-triangular matrix U in place and returns it. The
 * recursive step below is the block formula for a triangular inverse over
 * GF(2).
 *
 * Matrices with fewer than 2 * __M4RI_CPU_L3_CACHE entries are handed to
 * mzd_trtri_upper_russian(). Larger ones are cut at a multiple of m4ri_radix
 * (an even multiple when __M4RI_HAVE_SSE2 is set):
 *
 * \verbatim
 *        [ U00  U01 ]
 *    U = [          ]
 *        [      U11 ]
 * \endverbatim
 *
 * U01 is first passed through _mzd_trsm_upper_left() with U00 and then
 * through _mzd_trsm_upper_right() with U11; afterwards U00 and U11 are
 * processed recursively.
 *
 * The size test is carried out in unsigned long long: nrows * ncols does not
 * fit in an int once the dimensions exceed roughly 46340, and a wrapped
 * product could wrongly select the non-recursive path.
 *
 * \param U Matrix to process; modified in place.
 * \return  U.
 */
mzd_t *mzd_trtri_upper(mzd_t *U) {
  unsigned long long const entries =
      (unsigned long long)U->nrows * (unsigned long long)U->ncols;

  if (entries < ((unsigned long long)(__M4RI_CPU_L3_CACHE) << 1)) {
    mzd_trtri_upper_russian(U, 0);
  } else {
    rci_t const n = U->nrows;
    rci_t n2 = (((n - 1) / m4ri_radix + 1) >> 1);
#if __M4RI_HAVE_SSE2
    if (n2 % 2) n2 += 1;
#endif
    n2 *= m4ri_radix;
    assert(n2 < n);

    mzd_t *U00 = mzd_init_window(U, 0, 0, n2, n2);
    mzd_t *U01 = mzd_init_window(U, 0, n2, n2, n);
    mzd_t *U11 = mzd_init_window(U, n2, n2, n, n);

    /* Both solves must see the original U00 and U11, so they run first. */
    _mzd_trsm_upper_left(U00, U01, 0);
    _mzd_trsm_upper_right(U11, U01, 0);
    mzd_trtri_upper(U00);
    mzd_trtri_upper(U11);

    mzd_free_window(U00);
    mzd_free_window(U01);
    mzd_free_window(U11);
  }
  return U;
}