0xDEADBEEF

RSS odkazy english edition

matrix-multiplication.c

25. 11. 2016 #kód
/*
 * Horizontal sum: return the sum of all eight float lanes of x.
 *
 * _mm256_hadd_ps adds adjacent pairs within each 128-bit half, so after two
 * passes lane 0 holds (x0+x1)+(x2+x3) and lane 4 holds (x4+x5)+(x6+x7).
 * The two halves are then combined with register intrinsics rather than by
 * reading the vector through a cast float pointer; that relies on type
 * punning and may force x out to memory.
 */
__attribute__((always_inline)) inline float hadd(__m256 x) {
  x = _mm256_hadd_ps(x, x);
  x = _mm256_hadd_ps(x, x);
  const __m128 lo = _mm256_castps256_ps128(x);    /* lanes 0..3 */
  const __m128 hi = _mm256_extractf128_ps(x, 1);  /* lanes 4..7 */
  return _mm_cvtss_f32(_mm_add_ss(lo, hi));       /* lane 0 + lane 4 */
}

/* Extra floats appended to every row of the `a` and `b` inputs of
 * square_mat_mul_tiered: their row stride is len + PAD floats.
 * NOTE(review): presumably chosen so padded rows stay 32-byte aligned for
 * _mm256_load_ps and/or to avoid cache-set conflicts between rows -- confirm. */
#define PAD 16

/*
 * Multiply two square float matrices with AVX, using three levels of loop
 * tiling and a 2x2 register-blocked micro-kernel.
 *
 * Computes, for 0 <= i, j < len:
 *   res[i*len + j] = sum_k a[i*(len+PAD) + k] * b[j*(len+PAD) + k]
 * i.e. res = A * B^T, where A and B are the len x len matrices stored
 * row-wise in `a` and `b` with a row stride of len + PAD floats.
 * NOTE(review): to obtain A * B the caller presumably passes B already
 * transposed -- confirm against callers.
 *
 * `res` is dense (row stride len, no padding). Every one of its len*len
 * entries is assigned exactly once, so it need not be cleared beforehand.
 *
 * Compile-time parameters (macros defined outside this function):
 *   TILE3, TILE2 -- outer tile edges;
 *   TILE1        -- micro-tile width in columns (micro-tile height is 2*TILE1);
 *   S            -- number of floats of the shared dimension per pass.
 *
 * Preconditions (not checked):
 *   - len % TILE3 == 0, TILE3 % TILE2 == 0, TILE2 % (2*TILE1) == 0, TILE1 even;
 *   - len % S == 0 and S % 8 == 0;
 *   - a and b are 32-byte aligned and (len + PAD) % 8 == 0, since
 *     _mm256_load_ps requires 32-byte-aligned addresses;
 *   - res does not overlap a or b (res is written while a and b are still
 *     being read for other tiles).
 */
void square_mat_mul_tiered(float *a, float *b, const int len, float *res) {

  const int tile1 = TILE1;
  const int tile2 = TILE2;
  const int tile3 = TILE3;
  const int segment = S;
  const int stride = len + PAD;  /* row stride of a and b, in floats */
  const int rows1 = tile1 * 2;   /* micro-tile height (rows of a / res) */

  /* Threads split the work by TILE3-row stripes of res; `sums` is declared
   * inside the loop body, so each thread has its own copy. */
  #pragma omp parallel for
  for (int tilei3 = 0; tilei3 < len; tilei3 += tile3) {
    /* One 8-lane partial-sum vector per (row, column) of the micro-tile. */
    __m256 sums[rows1 * tile1];

    for (int tilej3 = 0; tilej3 < len; tilej3 += tile3) {
      for (int tilei2 = tilei3; tilei2 < tilei3 + tile3; tilei2 += tile2) {
        for (int tilej2 = tilej3; tilej2 < tilej3 + tile3; tilej2 += tile2) {
          for (int tilei1 = tilei2; tilei1 < tilei2 + tile2; tilei1 += rows1) {
            for (int tilej1 = tilej2; tilej1 < tilej2 + tile2; tilej1 += tile1) {

              for (int s = 0; s < rows1 * tile1; s++) {
                sums[s] = _mm256_setzero_ps();
              }

              /* Walk the shared dimension in chunks of `segment` floats,
               * sweeping the whole micro-tile for each chunk. */
              for (int p = 0; p < len; p += segment) {
                for (int i = tilei1; i < tilei1 + rows1; i += 2) {
                  for (int j = tilej1; j < tilej1 + tile1; j += 2) {
                    const int ii = i - tilei1;
                    const int jj = j - tilej1;

                    /* 2x2 register block: rows i, i+1 of a against rows
                     * j, j+1 of b. */
                    const float *a0 = a + i * stride + p;
                    const float *a1 = a + (i + 1) * stride + p;
                    const float *b0 = b + j * stride + p;
                    const float *b1 = b + (j + 1) * stride + p;

                    __m256 s00 = sums[ ii      * tile1 + jj    ];
                    __m256 s01 = sums[ ii      * tile1 + jj + 1];
                    __m256 s10 = sums[(ii + 1) * tile1 + jj    ];
                    __m256 s11 = sums[(ii + 1) * tile1 + jj + 1];

                    /* Loaded vectors and accumulators deliberately use
                     * distinct names: shadowing an accumulator with a load
                     * would silently discard that accumulator's update. */
                    for (int k = 0; k < segment; k += 8) {
                      const __m256 va0 = _mm256_load_ps(a0 + k);
                      const __m256 va1 = _mm256_load_ps(a1 + k);
                      const __m256 vb0 = _mm256_load_ps(b0 + k);
                      const __m256 vb1 = _mm256_load_ps(b1 + k);

                      s00 = _mm256_add_ps(s00, _mm256_mul_ps(va0, vb0));
                      s01 = _mm256_add_ps(s01, _mm256_mul_ps(va0, vb1));
                      s10 = _mm256_add_ps(s10, _mm256_mul_ps(va1, vb0));
                      s11 = _mm256_add_ps(s11, _mm256_mul_ps(va1, vb1));
                    }

                    sums[ ii      * tile1 + jj    ] = s00;
                    sums[ ii      * tile1 + jj + 1] = s01;
                    sums[(ii + 1) * tile1 + jj    ] = s10;
                    sums[(ii + 1) * tile1 + jj + 1] = s11;
                  }
                }
              }

              /* Reduce each 8-lane partial sum to its single output value.
               * Output rows are dense (stride len), unlike a and b. */
              for (int ii = 0; ii < rows1; ii++) {
                for (int jj = 0; jj < tile1; jj++) {
                  res[(tilei1 + ii) * len + (tilej1 + jj)] =
                      hadd(sums[ii * tile1 + jj]);
                }
              }
            }
          }
        }
      }
    }
  }
}
píše k47 (@kaja47, k47)