#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include "triangular_russian.h"
#include "graycode.h"
#include "brilliantrussian.h"
#include "ple_russian.h"
#include "xor.h"

/**
 * The number of lookup tables used per step by the TRSM routines below
 * (_mzd_trsm_upper_left_russian / _mzd_trsm_lower_left_russian).
 * Their unrolled switch statements handle at most 8 tables, and
 * __M4RI_TRSM_NTABLES * k must not exceed m4ri_radix (asserted there).
 **/
#define __M4RI_TRSM_NTABLES 8

/**
 * Back-substitution on the k x k diagonal block of U that starts at
 * (start_row, start_row), applied to the matching k rows of B.
 *
 * Rows of the block are visited from the bottom up. For each target row and
 * each row below it inside the block, if the corresponding bit of U is set,
 * that lower row of B is XORed into the target row of B. All words of the
 * row are added; the last word is masked with mask_end. Only bits strictly
 * above the diagonal of U are read.
 *
 * @param U          Upper triangular matrix.
 * @param B          Right-hand side, updated in place.
 * @param start_row  First row (and column) of the block.
 * @param k          Size of the block.
 * @param mask_end   Mask applied to the final word of each added row.
 */
void _mzd_trsm_upper_left_submatrix(mzd_t const *U, mzd_t *B, rci_t const start_row, int const k, word const mask_end) {
  wi_t const wide = B->width;

  for (rci_t done = 0; done < k; ++done) {
    /* bottom-most unsolved row of the block */
    rci_t const target = start_row + (k - 1 - done);

    for (rci_t off = 0; off < done; ++off) {
      /* already-solved rows sit directly below the target */
      rci_t const source = target + 1 + off;

      if (!mzd_read_bit(U, target, source) || wide == 0)
        continue;

      word *dst = B->rows[target];
      word const *src = B->rows[source];

      for (wi_t w = 0; w + 1 < wide; ++w)
        dst[w] ^= src[w];
      dst[wide - 1] ^= src[wide - 1] & mask_end;
    }
  }

  __M4RI_DD_MZD(B);
}

/**
 * Solve U * X = B for X, overwriting B with X, using the "Method of Four
 * Russians" with __M4RI_TRSM_NTABLES lookup tables of k rows each.
 *
 * Only bits strictly above the diagonal of U are read, so the diagonal is
 * treated as all ones. B is processed from the bottom up in chunks of
 * kk = __M4RI_TRSM_NTABLES * k rows:
 *  - the chunk is solved against its own diagonal block by
 *    _mzd_trsm_upper_left_submatrix;
 *  - one table is built from each k-row slice of the solved chunk;
 *  - the tables are combined into every row above the chunk.
 * Rows left over (fewer than kk) are handled with a single table of at most
 * k rows at a time.
 *
 * @param U  Upper triangular matrix. It is indexed with row/column indices
 *           below B->nrows, so it is assumed to be B->nrows x B->nrows
 *           (TODO confirm against callers).
 * @param B  Right-hand side; overwritten with the solution.
 * @param k  Rows per table. 0 selects a value from __M4RI_CPU_L2_CACHE and
 *           the dimensions of B, clamped to [2, 8]. An explicit k must
 *           satisfy __M4RI_TRSM_NTABLES * k <= m4ri_radix (asserted).
 */
void _mzd_trsm_upper_left_russian(mzd_t const *U, mzd_t *B, int k) {
  wi_t const wide = B->width;

  word mask_end = __M4RI_LEFT_BITMASK(B->ncols % m4ri_radix);

  if(k == 0) {
    /* Size the tables to the L2 cache:
     * __M4RI_CPU_L2_CACHE == __M4RI_TRSM_NTABLES * 2^k * B->width * 8 */
    double const kcache = log2((__M4RI_CPU_L2_CACHE/8)/(double)B->width/(double)__M4RI_TRSM_NTABLES);

    /* Clamp before converting to int. kcache is +inf when B->width == 0, and
     * converting a non-finite or out-of-range double to int is undefined.
     * The final k is clamped to [2, 8] below, so pre-clamping does not change
     * the outcome for any finite kcache. */
    if (!(kcache >= 2.0))
      k = 2;
    else if (kcache > 8.0)
      k = 8;
    else
      k = (int)kcache;

    /* do not let tables grow out of proportion to the matrix */
    rci_t const klog = round(0.75 * log2_floor(MIN(B->nrows, B->ncols)));

    if(klog < k) k = klog;
    if (k<2)     k = 2;
    else if(k>8) k = 8;
  }


  int kk = __M4RI_TRSM_NTABLES * k;
  assert(kk <= m4ri_radix);

  mzd_t *T[__M4RI_TRSM_NTABLES];
  rci_t *L[__M4RI_TRSM_NTABLES];

#ifdef __M4RI_HAVE_SSE2
  mzd_t *Talign[__M4RI_TRSM_NTABLES];
  int b_align = (__M4RI_ALIGNMENT(B->rows[0], 16) == 8);
#endif

  for(int i=0; i<__M4RI_TRSM_NTABLES; i++) {
#ifdef __M4RI_HAVE_SSE2
    /* we make sure that T are aligned as B */
    Talign[i] = mzd_init(__M4RI_TWOPOW(k), B->ncols + m4ri_radix);
    T[i] = mzd_init_window(Talign[i], 0, b_align*m4ri_radix, Talign[i]->nrows, B->ncols + b_align*m4ri_radix);
#else
    T[i] = mzd_init(__M4RI_TWOPOW(k), B->ncols);
#endif
    L[i] = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
  }

  rci_t i = 0;
  for (; i < B->nrows - kk; i += kk) {

    /* solve rows B->nrows-i-kk .. B->nrows-i-1 against their diagonal block */
    _mzd_trsm_upper_left_submatrix(U, B, B->nrows-i-kk, kk, mask_end);

    /* one table per k-row slice of the chunk; cases intentionally fall through */
    switch(__M4RI_TRSM_NTABLES) {
    case 8:  mzd_make_table(B, B->nrows - i - 8*k, 0, k, T[7], L[7]);
    case 7:  mzd_make_table(B, B->nrows - i - 7*k, 0, k, T[6], L[6]);
    case 6:  mzd_make_table(B, B->nrows - i - 6*k, 0, k, T[5], L[5]);
    case 5:  mzd_make_table(B, B->nrows - i - 5*k, 0, k, T[4], L[4]);
    case 4:  mzd_make_table(B, B->nrows - i - 4*k, 0, k, T[3], L[3]);
    case 3:  mzd_make_table(B, B->nrows - i - 3*k, 0, k, T[2], L[2]);
    case 2:  mzd_make_table(B, B->nrows - i - 2*k, 0, k, T[1], L[1]);
    case 1:  mzd_make_table(B, B->nrows - i - 1*k, 0, k, T[0], L[0]);
      break;
    default:
      m4ri_die("__M4RI_TRSM_NTABLES must be <= 8 but got %d", __M4RI_TRSM_NTABLES);
    }


    /* eliminate the solved chunk from every row above it */
    for(rci_t j = 0; j < B->nrows - i - kk; ++j) {
      rci_t x;
      const word *t[__M4RI_TRSM_NTABLES];

      /* look up the table row for each k-bit slice of U; cases fall through */
      switch(__M4RI_TRSM_NTABLES) {
      case 8: x = L[7][ mzd_read_bits_int(U, j, B->nrows - i - 8*k, k) ]; t[7] = T[7]->rows[x];
      case 7: x = L[6][ mzd_read_bits_int(U, j, B->nrows - i - 7*k, k) ]; t[6] = T[6]->rows[x];
      case 6: x = L[5][ mzd_read_bits_int(U, j, B->nrows - i - 6*k, k) ]; t[5] = T[5]->rows[x];
      case 5: x = L[4][ mzd_read_bits_int(U, j, B->nrows - i - 5*k, k) ]; t[4] = T[4]->rows[x];
      case 4: x = L[3][ mzd_read_bits_int(U, j, B->nrows - i - 4*k, k) ]; t[3] = T[3]->rows[x];
      case 3: x = L[2][ mzd_read_bits_int(U, j, B->nrows - i - 3*k, k) ]; t[2] = T[2]->rows[x];
      case 2: x = L[1][ mzd_read_bits_int(U, j, B->nrows - i - 2*k, k) ]; t[1] = T[1]->rows[x];
      case 1: x = L[0][ mzd_read_bits_int(U, j, B->nrows - i - 1*k, k) ]; t[0] = T[0]->rows[x];
        break;
      default:
        m4ri_die("__M4RI_TRSM_NTABLES must be <= 8 but got %d", __M4RI_TRSM_NTABLES);
      }

      word *b = B->rows[j];
      switch(__M4RI_TRSM_NTABLES) {
      case 8: _mzd_combine_8(b, t, wide); break;
      case 7: _mzd_combine_7(b, t, wide); break;
      case 6: _mzd_combine_6(b, t, wide); break;
      case 5: _mzd_combine_5(b, t, wide); break;
      case 4: _mzd_combine_4(b, t, wide); break;
      case 3: _mzd_combine_3(b, t, wide); break;
      case 2: _mzd_combine_2(b, t, wide); break;
      case 1: _mzd_combine(b, t[0], wide);
        break;
      default:
        m4ri_die("__M4RI_TRSM_NTABLES must be <= 8 but got %d", __M4RI_TRSM_NTABLES);
      }
    }
  }

  /* handle stuff that doesn't fit in multiples of kk */
  for ( ;i < B->nrows; i += k) {
    if (i > B->nrows - k)
      k = B->nrows - i;

    _mzd_trsm_upper_left_submatrix(U, B, B->nrows-i-k, k, mask_end);

    mzd_make_table(B, B->nrows - i - 1*k, 0, k, T[0], L[0]);

    for(rci_t j = 0; j < B->nrows - i - k; ++j) {
      rci_t const x0 = L[0][ mzd_read_bits_int(U, j, B->nrows - i - 1*k, k) ];

      word *b = B->rows[j];
      word *t0 = T[0]->rows[x0];

      for (wi_t ii = 0; ii < wide; ++ii)
        b[ii] ^= t0[ii];
    }
  }
  for(int i=0; i<__M4RI_TRSM_NTABLES; i++) {
    mzd_free(T[i]);
#ifdef __M4RI_HAVE_SSE2
    mzd_free(Talign[i]);
#endif
    m4ri_mm_free(L[i]);
  }

  __M4RI_DD_MZD(B);
}

/**
 * Forward substitution on the k x k diagonal block of L that starts at
 * (start_row, start_row), applied to the matching k rows of B.
 *
 * Rows of the block are visited from the top down. For each target row and
 * each earlier row of the block, if the corresponding bit of L is set, that
 * earlier row of B is XORed into the target row of B. All words of the row
 * are added; the last word is masked with mask_end. Only bits strictly below
 * the diagonal of L are read.
 *
 * @param L          Lower triangular matrix.
 * @param B          Right-hand side, updated in place.
 * @param start_row  First row (and column) of the block.
 * @param k          Size of the block.
 * @param mask_end   Mask applied to the final word of each added row.
 */
void _mzd_trsm_lower_left_submatrix(mzd_t const *L, mzd_t *B, rci_t const start_row, int const k, word const mask_end) {
  wi_t const wide = B->width;
  rci_t const stop = start_row + k;

  for (rci_t target = start_row; target < stop; ++target) {
    for (rci_t source = start_row; source < target; ++source) {
      if (!mzd_read_bit(L, target, source) || wide == 0)
        continue;

      word *dst = B->rows[target];
      word const *src = B->rows[source];

      for (wi_t w = 0; w + 1 < wide; ++w)
        dst[w] ^= src[w];
      dst[wide - 1] ^= src[wide - 1] & mask_end;
    }
  }

  __M4RI_DD_MZD(B);
}

/**
 * Solve L * X = B for X, overwriting B with X, using the "Method of Four
 * Russians" with __M4RI_TRSM_NTABLES lookup tables of k rows each.
 *
 * Only bits strictly below the diagonal of L are read, so the diagonal is
 * treated as all ones. B is processed from the top down in chunks of
 * kk = __M4RI_TRSM_NTABLES * k rows:
 *  - the chunk is solved against its own diagonal block by
 *    _mzd_trsm_lower_left_submatrix;
 *  - one table is built from each k-row slice of the solved chunk;
 *  - the tables are combined into every row below the chunk.
 * Rows left over (fewer than kk) are handled with a single table of at most
 * k rows at a time.
 *
 * @param L  Lower triangular matrix. It is indexed with row/column indices
 *           below B->nrows, so it is assumed to be B->nrows x B->nrows
 *           (TODO confirm against callers).
 * @param B  Right-hand side; overwritten with the solution.
 * @param k  Rows per table. 0 selects a value from __M4RI_CPU_L2_CACHE and
 *           the dimensions of B, clamped to [2, 8]. An explicit k must
 *           satisfy __M4RI_TRSM_NTABLES * k <= m4ri_radix (asserted).
 */
void _mzd_trsm_lower_left_russian(mzd_t const *L, mzd_t *B, int k) {
  wi_t const wide = B->width;

  if(k == 0) {
    /* Size the tables to the L2 cache:
     * __M4RI_CPU_L2_CACHE == __M4RI_TRSM_NTABLES * 2^k * B->width * 8 */
    double const kcache = log2((__M4RI_CPU_L2_CACHE/8)/(double)B->width/(double)__M4RI_TRSM_NTABLES);

    /* Clamp before converting to int. kcache is +inf when B->width == 0, and
     * converting a non-finite or out-of-range double to int is undefined.
     * The final k is clamped to [2, 8] below, so pre-clamping does not change
     * the outcome for any finite kcache. */
    if (!(kcache >= 2.0))
      k = 2;
    else if (kcache > 8.0)
      k = 8;
    else
      k = (int)kcache;

    /* do not let tables grow out of proportion to the matrix */
    rci_t const klog = round(0.75 * log2_floor(MIN(B->nrows, B->ncols)));

    if(klog < k) k = klog;
    if (k<2)     k = 2;
    else if(k>8) k = 8;
  }
  int kk = __M4RI_TRSM_NTABLES * k;
  assert(kk <= m4ri_radix);

  mzd_t *T[__M4RI_TRSM_NTABLES];
  rci_t *J[__M4RI_TRSM_NTABLES];

#ifdef __M4RI_HAVE_SSE2
    /* we make sure that T are aligned as B, this is dirty, we need a function for this */
  mzd_t *Talign[__M4RI_TRSM_NTABLES];
  int b_align = (__M4RI_ALIGNMENT(B->rows[0], 16) == 8);
#endif

  for(int i=0; i<__M4RI_TRSM_NTABLES; i++) {
#ifdef __M4RI_HAVE_SSE2
    Talign[i] = mzd_init(__M4RI_TWOPOW(k), B->ncols + m4ri_radix);
    T[i] = mzd_init_window(Talign[i], 0, b_align*m4ri_radix, Talign[i]->nrows, B->ncols + b_align*m4ri_radix);
#else
    T[i] = mzd_init(__M4RI_TWOPOW(k), B->ncols);
#endif
    J[i] = (rci_t*)m4ri_mm_calloc(__M4RI_TWOPOW(k), sizeof(rci_t));
  }

  /* selects one k-bit slice out of the kk bits read from L below */
  const word mask = __M4RI_LEFT_BITMASK(k);
  rci_t i = 0;
  for (; i < B->nrows - kk; i += kk) {

    /* solve rows i .. i+kk-1 against their diagonal block */
    _mzd_trsm_lower_left_submatrix(L, B, i, kk, B->high_bitmask);

    /* one table per k-row slice of the chunk; cases intentionally fall through */
    switch(__M4RI_TRSM_NTABLES) {
    case 8:  mzd_make_table(B, i + 7*k, 0, k, T[7], J[7]);
    case 7:  mzd_make_table(B, i + 6*k, 0, k, T[6], J[6]);
    case 6:  mzd_make_table(B, i + 5*k, 0, k, T[5], J[5]);
    case 5:  mzd_make_table(B, i + 4*k, 0, k, T[4], J[4]);
    case 4:  mzd_make_table(B, i + 3*k, 0, k, T[3], J[3]);
    case 3:  mzd_make_table(B, i + 2*k, 0, k, T[2], J[2]);
    case 2:  mzd_make_table(B, i + 1*k, 0, k, T[1], J[1]);
    case 1:  mzd_make_table(B, i + 0*k, 0, k, T[0], J[0]);
      break;
    default:
      m4ri_die("__M4RI_TRSM_NTABLES must be <= 8 but got %d", __M4RI_TRSM_NTABLES);
    }


    /* eliminate the solved chunk from every row below it */
    for(rci_t j = i+kk; j < B->nrows; ++j) {
      const word *t[__M4RI_TRSM_NTABLES];

      word tmp = mzd_read_bits(L, j, i, kk);

      /* look up the table row for each k-bit slice; cases fall through */
      switch(__M4RI_TRSM_NTABLES) {
      case 8: t[7] = T[7]->rows[ J[7][ (tmp >> (7*k)) & mask ] ];
      case 7: t[6] = T[6]->rows[ J[6][ (tmp >> (6*k)) & mask ] ];
      case 6: t[5] = T[5]->rows[ J[5][ (tmp >> (5*k)) & mask ] ];
      case 5: t[4] = T[4]->rows[ J[4][ (tmp >> (4*k)) & mask ] ];
      case 4: t[3] = T[3]->rows[ J[3][ (tmp >> (3*k)) & mask ] ];
      case 3: t[2] = T[2]->rows[ J[2][ (tmp >> (2*k)) & mask ] ];
      case 2: t[1] = T[1]->rows[ J[1][ (tmp >> (1*k)) & mask ] ];
      case 1: t[0] = T[0]->rows[ J[0][ (tmp >> (0*k)) & mask ] ];
        break;
      default:
        m4ri_die("__M4RI_TRSM_NTABLES must be <= 8 but got %d", __M4RI_TRSM_NTABLES);
      }

      word *b = B->rows[j];
      switch(__M4RI_TRSM_NTABLES) {
      case 8: _mzd_combine_8(b, t, wide); break;
      case 7: _mzd_combine_7(b, t, wide); break;
      case 6: _mzd_combine_6(b, t, wide); break;
      case 5: _mzd_combine_5(b, t, wide); break;
      case 4: _mzd_combine_4(b, t, wide); break;
      case 3: _mzd_combine_3(b, t, wide); break;
      case 2: _mzd_combine_2(b, t, wide); break;
      case 1: _mzd_combine(b, t[0], wide);
        break;
      default:
        m4ri_die("__M4RI_TRSM_NTABLES must be <= 8 but got %d", __M4RI_TRSM_NTABLES);
      }
    }
  }

  /* handle stuff that doesn't fit in multiples of kk */
  for ( ;i < B->nrows; i += k) {
    if (i > B->nrows - k)
      k = B->nrows - i;

    _mzd_trsm_lower_left_submatrix(L, B, i, k, B->high_bitmask);

    mzd_make_table(B, i + 0*k, 0, k, T[0], J[0]);

    /* bounded by B->nrows like the main loop, so B->rows[j] stays in range */
    for(rci_t j = i+k; j < B->nrows; ++j) {
      rci_t const x0 = J[0][ mzd_read_bits_int(L, j, i, k) ];

      word *b = B->rows[j];
      word *t0 = T[0]->rows[x0];

      for (wi_t ii = 0; ii < wide; ++ii)
        b[ii] ^= t0[ii];
    }
  }
  for(int i=0; i<__M4RI_TRSM_NTABLES; i++) {
    mzd_free(T[i]);
#ifdef __M4RI_HAVE_SSE2
    mzd_free(Talign[i]);
#endif
    m4ri_mm_free(J[i]);
  }

  __M4RI_DD_MZD(B);
}


/**
 * Build the lookup table used by mzd_trtri_upper_russian from rows
 * r .. r + 2^k-related rows of M (row indices come from m4ri_codebook[k]->inc).
 *
 * For i = 1 .. 2^k - 1, the words of row i of Tb->T are computed from word
 * c / m4ri_radix onward as row i-1 of Tb->T XOR row
 * r + m4ri_codebook[k]->inc[i-1] of M. Word startcol / m4ri_radix of each such
 * row is zeroed first; it may be one word before c's word, and zeroing it
 * lets callers safely add from there. Row 0 of Tb->T is never written and is
 * assumed to be zero (NOTE(review): relies on the table having been
 * zero-initialised; confirm).
 *
 * Tb->E is filled so that Tb->E[m4ri_codebook[k]->ord[i]] == i.
 * Afterwards, the k bits at column c of each row i >= 1 are XORed with
 * m4ri_codebook[k]->ord[i], and Tb->B[i] caches the m4ri_radix bits of row i
 * starting at column startcol (Tb->B[0] = 0).
 *
 * @param M         Source rows.
 * @param r         First source row.
 * @param c         Column from which table rows are computed.
 * @param k         log2 of the number of table rows.
 * @param Tb        Table to fill (Tb->T, Tb->E, Tb->B).
 * @param startcol  Column from which callers will add table rows; its word
 *                  must be at most one word before c's (asserted).
 */
void mzd_make_table_trtri(mzd_t const *M, rci_t r, rci_t c, int k, ple_table_t *Tb, rci_t startcol) {
  mzd_t *T = Tb->T;
  rci_t *L = Tb->E;

  assert(!(T->flags & mzd_flag_multiple_blocks));
  wi_t const blockoffset  = c / m4ri_radix;
  wi_t const blockoffset0 = startcol / m4ri_radix;

  assert(blockoffset - blockoffset0 <= 1);

  int const twokay = __M4RI_TWOPOW(k);
  /* number of words from c's word to the end of a table row */
  wi_t const wide = T->width - blockoffset;

  L[0] = 0;
  for (int i = 1; i < twokay; ++i) {
    T->rows[i][blockoffset0] = 0; /* we make sure that we can safely add from blockoffset0 */
    rci_t const rowneeded = r + m4ri_codebook[k]->inc[i - 1];

    word const *m   = M->rows[rowneeded] + blockoffset;
    word const *ti1 = T->rows[i - 1] + blockoffset;
    word *ti        = T->rows[i] + blockoffset;

    /* Indexing through T->rows[] does not depend on rows being laid out
     * rowstride apart, and a plain loop writes nothing when wide == 0. */
    for (wi_t w = 0; w < wide; ++w)
      ti[w] = m[w] ^ ti1[w];

    L[m4ri_codebook[k]->ord[i]] = i;
  }
  Tb->B[0] = 0;
  for(int i=1; i<twokay; ++i) {
    mzd_xor_bits(T, i, c, k, (word)m4ri_codebook[k]->ord[i]);
    Tb->B[i] = mzd_read_bits(T, i, startcol, m4ri_radix);
  }
}

/** The number of lookup tables used per step by mzd_trtri_upper_russian.
 *  That function applies the tables with _mzd_process_rows_ple_4, so this
 *  value is expected to stay 4. **/
#define __M4RI_TRTRI_NTABLES 4

/**
 * Reduce rows of A against a block of k pivot columns.
 *
 * For every pivot column col in [pivot_r, pivot_r + k) and every row in
 * [elim_r, col) whose bit in column col is set, this calls
 * mzd_row_add_offset(A, row, col, col + 1). Presumably that adds row col into
 * that row from column col + 1 onward (TODO confirm the argument order).
 * Pivots in the last column of A are skipped, since nothing lies to their
 * right.
 */
static inline void _mzd_trtri_upper_submatrix(mzd_t *A, rci_t pivot_r, rci_t elim_r, const int k) {
  rci_t const pivot_end = pivot_r + k;

  for (rci_t col = pivot_r; col < pivot_end; ++col) {
    rci_t const next = col + 1;
    if (next >= A->ncols)
      continue;

    for (rci_t row = elim_r; row < col; ++row) {
      if (mzd_read_bit(A, row, col))
        mzd_row_add_offset(A, row, col, next);
    }
  }
}


/**
 * Invert the square upper triangular matrix A in place with the "Method of
 * Four Russians", using __M4RI_TRTRI_NTABLES lookup tables of k rows each.
 *
 * A is processed from the top-left in diagonal chunks of
 * kk = __M4RI_TRTRI_NTABLES * k rows. Each k-row slice of a chunk goes
 * through three steps:
 *  - it is reduced against the chunk's earlier rows by
 *    _mzd_trtri_upper_submatrix;
 *  - it is extracted into U[t] by _mzd_ple_to_e;
 *  - it is turned into a table by mzd_make_table_trtri.
 * _mzd_process_rows_ple_4 then applies the tables to rows 0 .. r-1. Rows left
 * over (fewer than kk) are handled one table of at most k rows at a time.
 *
 * @param A  Square upper triangular matrix; overwritten and returned.
 * @param k  Rows per table. 0 selects a value from m4ri_opt_k and
 *           __M4RI_CPU_L3_CACHE (at most 7, at least 1). An explicit k must
 *           be positive.
 * @return   A.
 */
mzd_t *mzd_trtri_upper_russian(mzd_t *A, int k) {
  assert(A->nrows == A->ncols);

#if defined(__M4RI_TRTRI_NTABLES) && __M4RI_TRTRI_NTABLES != 4
#error "mzd_trtri_upper_russian applies its tables with _mzd_process_rows_ple_4 and needs exactly 4"
#endif

  if (k == 0) {
    k = m4ri_opt_k(A->nrows, A->ncols, 0);
    if (k >= 7)
      k = 7;
    if (0.75 * __M4RI_TWOPOW(k) *A->ncols > __M4RI_CPU_L3_CACHE / 2.0)
      k -= 1;
    /* both loops below advance r by multiples of k; k == 0 would never end */
    if (k < 1)
      k = 1;
  }
  assert(k > 0);

  const int kk = __M4RI_TRTRI_NTABLES*k;

  int k_[__M4RI_TRTRI_NTABLES];
  for (int i=0; i<__M4RI_TRTRI_NTABLES; i++)
    k_[i] = k;

  ple_table_t *T[__M4RI_TRTRI_NTABLES];
  mzd_t *U[__M4RI_TRTRI_NTABLES];
  for(int i=0; i<__M4RI_TRTRI_NTABLES; i++) {
    T[i] = ple_table_init(k, A->ncols);
    U[i] = mzd_init(k, A->ncols);
  }

  /** dummy offsets table for _mzd_ple_to_e**/
  rci_t id[m4ri_radix];
  for(int i=0; i<m4ri_radix; i++) id[i] = i;

  rci_t r = 0;
  while(r+kk <= A->nrows) {

    /***
     * ----------------------------
     * [  ....................... ]
     * [  ... U00 U01 U02 U03 ... ]
     * [  ...     U11 U12 U13 ... ]
     * ---------------------------- r
     * [  ...         U22 U23 ... ]
     * [  ...             U33 ... ]
     * ----------------------------
     *
     * Assume [ U00 U01 ] was already inverted and multiplied with [ U02 U03 ... ]
     *        [     U11 ]                                          [ U12 U13 ... ]
     *
     * We then invert U22 and construct a table for [U22 U23 ... ], then we
     * invert [U33] and multiply it with [U23]. Then we construct a table for [U23 ... ]
     **/

    for (int t = 0; t < __M4RI_TRTRI_NTABLES; ++t) {
      rci_t const pivot = r + t*k;
      _mzd_trtri_upper_submatrix(A, pivot, r, k);
      _mzd_ple_to_e(U[t], A, pivot, pivot, k, id);
      mzd_make_table_trtri(U[t], 0, pivot, k, T[t], r);
    }

    _mzd_process_rows_ple_4(A, 0, r, r, k_, (const ple_table_t** const)T);
    r += kk;
  }

  /** deal with the rest **/
  while(r < A->nrows) {
    if (A->nrows - r < k)
      k = A->nrows - r;

    _mzd_trtri_upper_submatrix(A, r, r, k);
    _mzd_ple_to_e(U[0], A, r, r, k, id);
    mzd_make_table_trtri(U[0], 0, r, k, T[0], r);

    mzd_process_rows(A, 0, r, r, k, T[0]->T, T[0]->E);
    r += k;
  }

  for(int i=0; i<__M4RI_TRTRI_NTABLES; i++) {
    ple_table_free(T[i]);
    mzd_free(U[i]);
  }
  __M4RI_DD_MZD(A);
  return A;
}