463 lines
16 KiB
C
Raw Normal View History

2023-02-24 07:58:40 +00:00
/******************************************************************************
*
* M4RI: Linear Algebra over GF(2)
*
* Copyright (C) 2008 Martin Albrecht <malb@informatik.uni-bremen.de>
*
* Distributed under the terms of the GNU General Public License (GPL)
* version 2 or higher.
*
* This code is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* The full text of the GPL is available at:
*
* http://www.gnu.org/licenses/
******************************************************************************/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "mzp.h"
#include "mzd.h"
/**
 * Allocate a permutation of the given length, initialised to the identity.
 *
 * @param length  Number of entries.
 * @return        A new permutation whose entry k equals k; release it with
 *                mzp_free().
 */
mzp_t *mzp_init(rci_t length) {
  mzp_t *perm = (mzp_t *)m4ri_mm_malloc(sizeof(mzp_t));
  rci_t *vals = (rci_t *)m4ri_mm_malloc(sizeof(rci_t) * length);
  perm->values = vals;
  perm->length = length;
  /* identity: every index maps to itself */
  rci_t k = 0;
  while (k < length) {
    vals[k] = k;
    ++k;
  }
  return perm;
}
/**
 * Free a permutation created by mzp_init(): its value array first, then
 * the descriptor itself.
 */
void mzp_free(mzp_t *P) {
  rci_t *owned_values = P->values;
  m4ri_mm_free(owned_values);
  m4ri_mm_free(P);
}
/**
 * Create a window onto entries [begin, end) of P.
 *
 * The window shares P's value array (no copy is made), so writes through
 * the window are visible in P. Release it with mzp_free_window(), which
 * frees only the descriptor.
 *
 * @param P      Parent permutation.
 * @param begin  First entry of P covered by the window.
 * @param end    One past the last entry covered.
 * @return       Newly allocated window descriptor.
 */
mzp_t *mzp_init_window(mzp_t *P, rci_t begin, rci_t end){
  mzp_t *view = (mzp_t *)m4ri_mm_malloc(sizeof(mzp_t));
  view->length = end - begin;
  view->values = &P->values[begin];
  __M4RI_DD_MZP(view);
  return view;
}
/**
 * Release a window created by mzp_init_window().
 *
 * Only the descriptor is freed; the value array belongs to the parent
 * permutation and remains valid.
 */
void mzp_free_window(mzp_t *condemned){
  mzp_t *const descriptor = condemned;
  m4ri_mm_free(descriptor);
}
/**
 * Copy the entries of Q into P.
 *
 * @param P  Destination; if NULL, a new permutation of Q->length entries is
 *           allocated. A non-NULL P must have room for Q->length entries
 *           (not checked here).
 * @param Q  Source permutation.
 * @return   The destination permutation (P, or the newly allocated one).
 */
mzp_t *mzp_copy(mzp_t *P, const mzp_t *Q) {
  mzp_t *dst = (P == NULL) ? mzp_init(Q->length) : P;
  rci_t const *src = Q->values;
  rci_t *out = dst->values;
  /* forward order on purpose: windows of the same array may overlap */
  for (rci_t k = 0; k < Q->length; ++k)
    out[k] = src[k];
  return dst;
}
/**
 * Reset P to the identity permutation.
 *
 * @param P      Permutation to overwrite.
 * @param value  Must be 1 (checked with assert); no other value is supported.
 */
void mzp_set_ui(mzp_t *P, unsigned int value) {
  assert(value == 1);
  rci_t *vals = P->values;
  for (rci_t k = 0, n = P->length; k < n; ++k)
    vals[k] = k;
}
/**
 * Apply P to A from the left: row k is swapped with row P->values[k] for
 * k = 0, 1, ... in ascending order, over min(P->length, A->nrows) entries.
 * A matrix without columns is left unchanged.
 *
 * @param A  Matrix permuted in place.
 * @param P  Swap list; each entry must satisfy P->values[k] >= k (asserted).
 */
void mzd_apply_p_left(mzd_t *A, mzp_t const *P) {
  if (A->ncols != 0) {
    rci_t const rows = (P->length < A->nrows) ? P->length : A->nrows;
    rci_t const *target = P->values;
    for (rci_t row = 0; row < rows; ++row) {
      assert(target[row] >= row);
      mzd_row_swap(A, row, target[row]);
    }
  }
}
/**
 * Apply P^T to A from the left: the same row swaps as mzd_apply_p_left(),
 * performed in descending order (last entry first), over
 * min(P->length, A->nrows) entries. A matrix without columns is left
 * unchanged.
 *
 * @param A  Matrix permuted in place.
 * @param P  Swap list; each entry must satisfy P->values[k] >= k (asserted).
 */
void mzd_apply_p_left_trans(mzd_t *A, mzp_t const *P) {
  if (A->ncols != 0) {
    rci_t const rows = (P->length < A->nrows) ? P->length : A->nrows;
    rci_t const *target = P->values;
    for (rci_t row = rows; row-- > 0; ) {
      assert(target[row] >= row);
      mzd_row_swap(A, row, target[row]);
    }
  }
}
/* optimised column swap operations */
/**
 * Gather permuted column bits from the scratch rows of B back into A.
 *
 * For each row r in [start_row, stop_row), row r - start_row of B is read
 * (the caller fills B with a copy of those rows of A). For every
 * destination column c < length, the bit at column permutation[c] of that
 * B row is ORed into column c of A->rows[r]. Word blocks whose write_mask
 * word is all ones are skipped entirely.
 *
 * Because bits are ORed in, the destination bits of A must already be
 * cleared by the caller.
 *
 * @param A            Destination matrix, modified in place.
 * @param B            Scratch matrix holding the source rows.
 * @param permutation  permutation[c] = source column for destination column c.
 * @param write_mask   One word per row word; an all-ones word means every
 *                     column in that word stays in place.
 * @param start_row    First row of A written.
 * @param stop_row     One past the last row of A written.
 * @param length       Number of leading columns processed.
 */
static inline void mzd_write_col_to_rows_blockd(mzd_t *A, mzd_t const *B, rci_t const *permutation, word const *write_mask, rci_t const start_row, rci_t const stop_row, rci_t length) {
  for(rci_t i = 0; i < length; i += m4ri_radix) {
    /* optimisation for identity permutations: an all-ones mask word means
       no column in this word moves, so there is nothing to gather */
    if (write_mask[i / m4ri_radix] == m4ri_ffff)
      continue;
    /* columns handled in this word (the last word may be partial) */
    int const todo = MIN(m4ri_radix, length - i);
    wi_t const a_word = i / m4ri_radix;
    wi_t words[m4ri_radix];
    int bits[m4ri_radix];
    word bitmasks[m4ri_radix];
    /* we pre-compute bit access in advance: destination column i + k reads
       bit bits[k] of word words[k] of the source row */
    for(int k = 0; k < todo; ++k) {
      rci_t const colb = permutation[i + k];
      words[k] = colb / m4ri_radix;
      bits[k] = colb % m4ri_radix;
      bitmasks[k] = m4ri_one << bits[k];
    }
    for (rci_t r = start_row; r < stop_row; ++r) {
      word const *Brow = B->rows[r-start_row];
      word *Arow = A->rows[r];
      register word value = 0;
      /* we gather the bits in a register word.
       *
       * The switch is a hand-unrolled form of
       *   for (int k = 0; k < todo; ++k)
       *     value |= ((Brow[words[k]] & bitmasks[k]) >> bits[k]) << k;
       * Every case deliberately falls through to the next one, so entering
       * at case todo-1 collects bits todo-1 down to 0.
       * NOTE(review): the case labels 63..0 assume m4ri_radix == 64. */
      switch(todo-1) {
      case 63: value |= ((Brow[words[63]] & bitmasks[63]) >> bits[63]) << 63;
      case 62: value |= ((Brow[words[62]] & bitmasks[62]) >> bits[62]) << 62;
      case 61: value |= ((Brow[words[61]] & bitmasks[61]) >> bits[61]) << 61;
      case 60: value |= ((Brow[words[60]] & bitmasks[60]) >> bits[60]) << 60;
      case 59: value |= ((Brow[words[59]] & bitmasks[59]) >> bits[59]) << 59;
      case 58: value |= ((Brow[words[58]] & bitmasks[58]) >> bits[58]) << 58;
      case 57: value |= ((Brow[words[57]] & bitmasks[57]) >> bits[57]) << 57;
      case 56: value |= ((Brow[words[56]] & bitmasks[56]) >> bits[56]) << 56;
      case 55: value |= ((Brow[words[55]] & bitmasks[55]) >> bits[55]) << 55;
      case 54: value |= ((Brow[words[54]] & bitmasks[54]) >> bits[54]) << 54;
      case 53: value |= ((Brow[words[53]] & bitmasks[53]) >> bits[53]) << 53;
      case 52: value |= ((Brow[words[52]] & bitmasks[52]) >> bits[52]) << 52;
      case 51: value |= ((Brow[words[51]] & bitmasks[51]) >> bits[51]) << 51;
      case 50: value |= ((Brow[words[50]] & bitmasks[50]) >> bits[50]) << 50;
      case 49: value |= ((Brow[words[49]] & bitmasks[49]) >> bits[49]) << 49;
      case 48: value |= ((Brow[words[48]] & bitmasks[48]) >> bits[48]) << 48;
      case 47: value |= ((Brow[words[47]] & bitmasks[47]) >> bits[47]) << 47;
      case 46: value |= ((Brow[words[46]] & bitmasks[46]) >> bits[46]) << 46;
      case 45: value |= ((Brow[words[45]] & bitmasks[45]) >> bits[45]) << 45;
      case 44: value |= ((Brow[words[44]] & bitmasks[44]) >> bits[44]) << 44;
      case 43: value |= ((Brow[words[43]] & bitmasks[43]) >> bits[43]) << 43;
      case 42: value |= ((Brow[words[42]] & bitmasks[42]) >> bits[42]) << 42;
      case 41: value |= ((Brow[words[41]] & bitmasks[41]) >> bits[41]) << 41;
      case 40: value |= ((Brow[words[40]] & bitmasks[40]) >> bits[40]) << 40;
      case 39: value |= ((Brow[words[39]] & bitmasks[39]) >> bits[39]) << 39;
      case 38: value |= ((Brow[words[38]] & bitmasks[38]) >> bits[38]) << 38;
      case 37: value |= ((Brow[words[37]] & bitmasks[37]) >> bits[37]) << 37;
      case 36: value |= ((Brow[words[36]] & bitmasks[36]) >> bits[36]) << 36;
      case 35: value |= ((Brow[words[35]] & bitmasks[35]) >> bits[35]) << 35;
      case 34: value |= ((Brow[words[34]] & bitmasks[34]) >> bits[34]) << 34;
      case 33: value |= ((Brow[words[33]] & bitmasks[33]) >> bits[33]) << 33;
      case 32: value |= ((Brow[words[32]] & bitmasks[32]) >> bits[32]) << 32;
      case 31: value |= ((Brow[words[31]] & bitmasks[31]) >> bits[31]) << 31;
      case 30: value |= ((Brow[words[30]] & bitmasks[30]) >> bits[30]) << 30;
      case 29: value |= ((Brow[words[29]] & bitmasks[29]) >> bits[29]) << 29;
      case 28: value |= ((Brow[words[28]] & bitmasks[28]) >> bits[28]) << 28;
      case 27: value |= ((Brow[words[27]] & bitmasks[27]) >> bits[27]) << 27;
      case 26: value |= ((Brow[words[26]] & bitmasks[26]) >> bits[26]) << 26;
      case 25: value |= ((Brow[words[25]] & bitmasks[25]) >> bits[25]) << 25;
      case 24: value |= ((Brow[words[24]] & bitmasks[24]) >> bits[24]) << 24;
      case 23: value |= ((Brow[words[23]] & bitmasks[23]) >> bits[23]) << 23;
      case 22: value |= ((Brow[words[22]] & bitmasks[22]) >> bits[22]) << 22;
      case 21: value |= ((Brow[words[21]] & bitmasks[21]) >> bits[21]) << 21;
      case 20: value |= ((Brow[words[20]] & bitmasks[20]) >> bits[20]) << 20;
      case 19: value |= ((Brow[words[19]] & bitmasks[19]) >> bits[19]) << 19;
      case 18: value |= ((Brow[words[18]] & bitmasks[18]) >> bits[18]) << 18;
      case 17: value |= ((Brow[words[17]] & bitmasks[17]) >> bits[17]) << 17;
      case 16: value |= ((Brow[words[16]] & bitmasks[16]) >> bits[16]) << 16;
      case 15: value |= ((Brow[words[15]] & bitmasks[15]) >> bits[15]) << 15;
      case 14: value |= ((Brow[words[14]] & bitmasks[14]) >> bits[14]) << 14;
      case 13: value |= ((Brow[words[13]] & bitmasks[13]) >> bits[13]) << 13;
      case 12: value |= ((Brow[words[12]] & bitmasks[12]) >> bits[12]) << 12;
      case 11: value |= ((Brow[words[11]] & bitmasks[11]) >> bits[11]) << 11;
      case 10: value |= ((Brow[words[10]] & bitmasks[10]) >> bits[10]) << 10;
      case 9: value |= ((Brow[words[ 9]] & bitmasks[ 9]) >> bits[ 9]) << 9;
      case 8: value |= ((Brow[words[ 8]] & bitmasks[ 8]) >> bits[ 8]) << 8;
      case 7: value |= ((Brow[words[ 7]] & bitmasks[ 7]) >> bits[ 7]) << 7;
      case 6: value |= ((Brow[words[ 6]] & bitmasks[ 6]) >> bits[ 6]) << 6;
      case 5: value |= ((Brow[words[ 5]] & bitmasks[ 5]) >> bits[ 5]) << 5;
      case 4: value |= ((Brow[words[ 4]] & bitmasks[ 4]) >> bits[ 4]) << 4;
      case 3: value |= ((Brow[words[ 3]] & bitmasks[ 3]) >> bits[ 3]) << 3;
      case 2: value |= ((Brow[words[ 2]] & bitmasks[ 2]) >> bits[ 2]) << 2;
      case 1: value |= ((Brow[words[ 1]] & bitmasks[ 1]) >> bits[ 1]) << 1;
      case 0: value |= ((Brow[words[ 0]] & bitmasks[ 0]) >> bits[ 0]) << 0;
      default:
        break;
      }
      /* and write the word once; the caller has already cleared the
         destination bits in A, so ORing places the gathered bits */
      Arow[a_word] |= value;
    }
  }
  __M4RI_DD_MZD(A);
}
/**
 * Implements both apply_p_right and apply_p_right_trans for the rows
 * [start_row, A->nrows) of A.
 *
 * Instead of performing one column swap per entry of P, the swaps with
 * indices start_col .. min(P->length, A->ncols)-1 are first composed into a
 * single column map: permutation[c] is the original column whose bit ends up
 * in column c. Rows are then processed in batches. Each batch is copied into
 * the scratch matrix B, the columns that move are cleared in A, and the
 * permuted bits are gathered back word by word.
 *
 * @param A          Matrix permuted in place.
 * @param P          Swap list: entry i swaps column i with column P->values[i].
 * @param start_row  First row that is modified.
 * @param start_col  First swap index of P that is applied.
 * @param notrans    Non-zero: apply P (swaps in descending index order);
 *                   zero: apply P^T (swaps in ascending index order).
 */
void _mzd_apply_p_right_even(mzd_t *A, mzp_t const *P, rci_t start_row, rci_t start_col, int notrans) {
  /* Empty row range: nothing to do (start_row > nrows previously produced a
     negative batch size). A matrix without columns has nothing to permute;
     its width (presumably 0) would also divide by zero below and make
     write_mask[width-1] index out of bounds. */
  if (start_row >= A->nrows || A->ncols == 0)
    return;

  rci_t const length = MIN(P->length, A->ncols);
  wi_t const width = A->width;
  /* rows per batch: roughly as many rows as fit into (L1 size / 8) words */
  int step_size = MIN(A->nrows - start_row, MAX((__M4RI_CPU_L1_CACHE >> 3) / A->width, 1));

  /* our temporary where we store the rows whose columns we move around */
  mzd_t *B = mzd_init(step_size, A->ncols);

  /* setup mathematical permutation: start from the identity */
  rci_t *permutation = (rci_t*)m4ri_mm_calloc(A->ncols, sizeof(rci_t));
  for (rci_t i = 0; i < A->ncols; ++i)
    permutation[i] = i;

  if (!notrans) {
    /* P^T: apply the swaps in ascending order start_col, ..., length-1 */
    for (rci_t i = start_col; i < length; ++i) {
      rci_t const t = permutation[i];
      permutation[i] = permutation[P->values[i]];
      permutation[P->values[i]] = t;
    }
  } else {
    /* P: apply the same swaps in descending order length-1, ..., start_col,
       i.e. exactly the reverse of the transposed branch */
    for (rci_t i = length - 1; i >= start_col; --i) {
      rci_t const t = permutation[i];
      permutation[i] = permutation[P->values[i]];
      permutation[P->values[i]] = t;
    }
  }

  /* bit k of write_mask[w] is set iff column w * m4ri_radix + k stays put */
  word *write_mask = (word*)m4ri_mm_calloc(width, sizeof(word));
  for (rci_t i = 0; i < A->ncols; i += m4ri_radix) {
    int const todo = MIN(m4ri_radix, A->ncols - i);
    for (int k = 0; k < todo; ++k) {
      if (permutation[i + k] == i + k)
        write_mask[i / m4ri_radix] |= m4ri_one << k;
    }
  }
  /* bits of the last word outside A->high_bitmask are never rewritten */
  write_mask[width - 1] |= ~A->high_bitmask;

  for (rci_t i = start_row; i < A->nrows; i += step_size) {
    step_size = MIN(step_size, A->nrows - i);
    for (int k = 0; k < step_size; ++k) {
      word *Arow = A->rows[i + k];
      word *Brow = B->rows[k];
      /* copy row & clear those values which will be overwritten */
      for (wi_t j = 0; j < width; ++j) {
        Brow[j] = Arow[j];
        Arow[j] &= write_mask[j];
      }
    }
    /* here we actually write out the permutation */
    mzd_write_col_to_rows_blockd(A, B, permutation, write_mask, i, i + step_size, length);
  }

  m4ri_mm_free(permutation);
  m4ri_mm_free(write_mask);
  mzd_free(B);
  __M4RI_DD_MZD(A);
}
/**
 * Apply P^T to A from the right using explicit column swaps.
 *
 * Rows are processed in batches sized from __M4RI_CPU_L1_CACHE. Within each
 * batch, column i is swapped with column P->values[i] for
 * i = 0, ..., min(P->length, A->ncols)-1 in ascending order.
 *
 * @param A  Matrix permuted in place.
 * @param P  Swap list; each entry must satisfy P->values[i] >= i (asserted).
 */
void _mzd_apply_p_right_trans(mzd_t *A, mzp_t const *P) {
  /* No rows or no columns: nothing to swap. For ncols == 0 the width is
     (presumably) 0, and the step-size division below would divide by zero. */
  if (A->nrows == 0 || A->ncols == 0)
    return;
  rci_t const length = MIN(P->length, A->ncols);
  int const step_size = MAX((__M4RI_CPU_L1_CACHE >> 3) / A->width, 1);
  for (rci_t j = 0; j < A->nrows; j += step_size) {
    rci_t const stop_row = MIN(j + step_size, A->nrows);
    for (rci_t i = 0; i < length; ++i) {
      assert(P->values[i] >= i);
      mzd_col_swap_in_rows(A, i, P->values[i], j, stop_row);
    }
  }
  __M4RI_DD_MZD(A);
}
/**
 * Apply P to A from the right using explicit column swaps.
 *
 * Rows are processed in batches sized from __M4RI_CPU_L1_CACHE. Within each
 * batch, column i is swapped with column P->values[i] for
 * i = min(P->length, A->ncols)-1, ..., 0 in descending order.
 *
 * @param A  Matrix permuted in place.
 * @param P  Swap list; each entry must satisfy P->values[i] >= i (asserted).
 */
void _mzd_apply_p_right(mzd_t *A, mzp_t const *P) {
  /* No rows or no columns: nothing to swap. For ncols == 0 the width is
     (presumably) 0, and the step-size division below would divide by zero. */
  if (A->nrows == 0 || A->ncols == 0)
    return;
  /* never address columns beyond A, mirroring the transposed variant */
  rci_t const length = MIN(P->length, A->ncols);
  int const step_size = MAX((__M4RI_CPU_L1_CACHE >> 3) / A->width, 1);
  for (rci_t j = 0; j < A->nrows; j += step_size) {
    rci_t const stop_row = MIN(j + step_size, A->nrows);
    for (rci_t i = length - 1; i >= 0; --i) {
      assert(P->values[i] >= i);
      mzd_col_swap_in_rows(A, i, P->values[i], j, stop_row);
    }
  }
  __M4RI_DD_MZD(A);
}
/**
 * Apply P^T to all rows of A from the right (all swaps, ascending index
 * order). A matrix without rows is left unchanged.
 */
void mzd_apply_p_right_trans(mzd_t *A, mzp_t const *P) {
  if (A->nrows != 0)
    _mzd_apply_p_right_even(A, P, 0, 0, 0);
}
/**
 * Apply P to all rows of A from the right (all swaps, descending index
 * order). A matrix without rows is left unchanged.
 */
void mzd_apply_p_right(mzd_t *A, mzp_t const *P) {
  if (A->nrows != 0)
    _mzd_apply_p_right_even(A, P, 0, 0, 1);
}
/**
 * Apply P^T from the right to rows start_row.. of A, using only the swaps
 * with index >= start_col. A matrix without rows is left unchanged.
 */
void mzd_apply_p_right_trans_even_capped(mzd_t *A, mzp_t const *P, rci_t start_row, rci_t start_col) {
  if (A->nrows != 0)
    _mzd_apply_p_right_even(A, P, start_row, start_col, 0);
}
/**
 * Apply P from the right to rows start_row.. of A, using only the swaps
 * with index >= start_col. A matrix without rows is left unchanged.
 */
void mzd_apply_p_right_even_capped(mzd_t *A, mzp_t const *P, rci_t start_row, rci_t start_col) {
  if (A->nrows != 0)
    _mzd_apply_p_right_even(A, P, start_row, start_col, 1);
}
/**
 * Print the entries of P to stdout as "[ v0 v1 ... ]" (no trailing newline).
 */
void mzp_print(mzp_t const *P) {
  printf("[ ");
  for (rci_t i = 0; i < P->length; ++i) {
    /* %zd expects the signed counterpart of size_t, but a size_t was passed;
       cast to long and use the matching %ld so that any value (including a
       negative one) prints correctly */
    printf("%ld ", (long)P->values[i]);
  }
  printf("]");
}
/**
 * Apply P^T to A from the right, restricted to the part strictly above the
 * main diagonal.
 *
 * The swap of column i with column P->values[i] is performed only in rows
 * r < i. Swaps are applied in ascending order of i within row batches sized
 * from __M4RI_CPU_L1_CACHE.
 *
 * @param A  Matrix permuted in place.
 * @param P  Swap list with exactly A->ncols entries and P->values[i] >= i
 *           (both asserted).
 */
void mzd_apply_p_right_trans_tri(mzd_t *A, mzp_t const *P) {
  assert(P->length == A->ncols);
  /* no columns: nothing to swap, and the width (presumably 0) would make the
     step-size computation below divide by zero */
  if (A->ncols == 0)
    return;
  int const step_size = MAX((__M4RI_CPU_L1_CACHE >> 2) / A->width, 1);
  for (rci_t r = 0; r < A->nrows; r += step_size) {
    rci_t const row_bound = MIN(r + step_size, A->nrows);
    for (rci_t i = 0; i < A->ncols; ++i) {
      assert(P->values[i] >= i);
      /* only rows [r, min(row_bound, i)) lie strictly above the diagonal */
      mzd_col_swap_in_rows(A, i, P->values[i], r, MIN(row_bound, i));
    }
  }
  __M4RI_DD_MZD(A);
}
void _mzd_compress_l(mzd_t *A, rci_t r1, rci_t n1, rci_t r2) {
  /**
   * We are compressing this matrix
   \verbatim
   r1 n1
   ------------------------------------------
   | \ \____|___ | A01 |
   | \ | \ | |
   r1------------------------------------------
   | | | | \ \_____ |
   | L1| | | \ \________|
   | | | | L2| |
   ------------------------------------------
   \endverbatim
   *
   * to this matrix
   *
   \verbatim
   r1 n1
   ------------------------------------------
   | \ \____|___ | A01 |
   | \ | \ | |
   r1------------------------------------------
   | \ | | \_____ |
   | \ | | \________|
   | | | | |
   ------------------------------------------
   \endverbatim
   */
  /* What the code below does:
   * - For rows r1 <= i < r1+r2, columns i and n1+(i-r1) are swapped in
   *   rows [i, r1+r2).
   * - For every row i >= r1+r2, the bits starting at column n1 are copied
   *   so they start at column r1. The columns from r1+r2 up to n1+r2 are
   *   then cleared (whole words, possibly past n1+r2).
   * - Rows < r1 are not touched.
   *
   * NOTE(review): the word-wise copy below shifts by `rest`. That equals
   * the bit offset of source column n1 + rest only when n1 is a multiple
   * of m4ri_radix, so callers presumably pass a word-aligned n1. TODO
   * confirm. */
  /* L2 already starts right after L1: nothing to move */
  if (r1 == n1)
    return;
#if 0
  /* disabled alternative: express the shift as a permutation */
  mzp_t *shift = mzp_init(A->ncols);
  for (rci_t i=r1,j=n1;i<r1+r2;i++,j++){
    mzd_col_swap_in_rows(A, i, j, i, r1+r2);
    shift->values[i] = j;
  }
  mzd_apply_p_right_trans_even_capped(A, shift, r1+r2, 0);
  mzp_free(shift);
#else
  /* the r2 x r2 diagonal block: swap column-by-column, only from the
     diagonal row downwards */
  for (rci_t i = r1, j = n1; i < r1 + r2; ++i, ++j){
    mzd_col_swap_in_rows(A, i, j, i, r1 + r2);
  }
  word tmp;
  wi_t block;
  for(rci_t i = r1 + r2; i < A->nrows; ++i) {
    rci_t j = r1;
    /* first we deal with the rest of the current word we need to
       write; rest is in 1..m4ri_radix */
    int const rest = m4ri_radix - (j % m4ri_radix);
    tmp = mzd_read_bits(A, i, n1, rest);
    mzd_clear_bits(A, i, j, rest);
    mzd_xor_bits(A, i, j, rest, tmp);
    j += rest;
    /* now each write is simply a word write; j is word-aligned here and
       block is the word holding source column n1 + (j - r1) */
    block = (n1 + j - r1) / m4ri_radix;
    if (rest % m4ri_radix == 0) {
      /* r1 was word-aligned: plain word copies */
      for( ; j + m4ri_radix <= r1 + r2; j += m4ri_radix, ++block) {
        tmp = A->rows[i][block];
        A->rows[i][j / m4ri_radix] = tmp;
      }
    } else {
      /* source straddles two words: splice them together */
      for(; j + m4ri_radix <= r1 + r2; j += m4ri_radix, ++block) {
        tmp = (A->rows[i][block] >> rest) | ( A->rows[i][block + 1] << (m4ri_radix - rest));
        A->rows[i][j / m4ri_radix] = tmp;
      }
    }
    /* We deal with the remaining bits. We could write past the end of
       r1+r2 here, but we have no guarantee that we can read past the end
       of n1+r2. */
    if (j < r1 + r2) {
      tmp = mzd_read_bits(A, i, n1 + j - r1, r1 + r2 - j);
      A->rows[i][j / m4ri_radix] = tmp;
    }
    /* now clear the rest of L2: first up to the next word boundary ... */
    j = r1 + r2;
    mzd_clear_bits(A, i, j, m4ri_radix - (j % m4ri_radix));
    j += m4ri_radix - (j % m4ri_radix);
    /* ... then whole words. It's okay to write the full word, i.e. past
       n1+r2, because everything is zero there anyway. Thus, we can omit
       the code which deals with the last few bits. */
    for(; j < n1 + r2; j += m4ri_radix) {
      A->rows[i][j / m4ri_radix] = 0;
    }
  }
#endif
  __M4RI_DD_MZD(A);
}