/* -*- mode: org -*-
* GEMM
- General Matrix-Matrix Multiplication: C_ij = sum_k A_ik * B_kj
- Options: Scale, Accumulate, Transpose, Submatrices
  - e.g. C = alpha * op_1(A) * op_2(B) + beta * C
- Core primitive of almost any linear algebra operation
  - GEMV being the other one
- O(N*M*K) flops, O(N*M + M*K + N*K) data
  - Compute bound, high arithmetic intensity
- Probably the most tuned operation in all of software
- Today we'll look at some tuning techniques
* Simplifications
- *This* computer (i7-5600U)
- No scaling, always accumulating, ld = dim
- Single threaded
  - We'll see where parallelisation comes in later, but not measure it
    (see the sketch at the end of this file)
- Square matrices
  - We'll see how to change this later, but not measure it
    (see the sketch at the end of this file)
- Double precision
* Vectorization
- A /regular/ (scalar) operation works on one element at a time, e.g. a += b
- The CPU also supports fixed-length vectors and can apply an operation to all lanes at once
  - e.g. [a0, a1, a2, a3] += [b0, b1, b2, b3]
- History (x86): first there was MMX, then SSE, then AVX, AVX2 and now AVX-512
  - SSE (128-bit wide; FP64 support arrived with SSE2)
  - AVX (256-bit wide)
  - AVX2 (same width, adds FMA support)
  - AVX-512 (512-bit wide, FMA support)
- Our CPU supports AVX2
- Another limiting factor: registers
  - SSE/AVX/AVX2: 16 vector registers, AVX-512: 32
* Hardware
- Your CPU has cores
- Your CPU has a number of caches (probably L3, L2, and separate L1 data and L1 instruction caches) [0]
- Your CPU is complex:
  - Each cycle, it will fetch, decode, rename and schedule more than one instruction
  - It has execution units to run the instructions on
    - Two vector FMA units
    - Two load units, one store unit
    - Two more misc. units
  - In our case (the i7-5600U is Broadwell, essentially a die shrink of Haswell), check out [1] for more info
  - L1: 32 KB, L2: 256 KB
* Peak Performance
- 4 (double precision numbers per vector)
- 2 (flops per FMA)
- 2 (FMA operations per cycle)
- 2.6ish GHz (clock rate, variable due to power mgmt/turbo boost)
-> 4 * 2 * 2 * 2.6 GHz ~= 41.6 GFLOPS
* A naive GEMM
- See gemm_0 below
- Issues: (1) poor cache usage, (2) poor vector usage
- Let's see how much reuse occurs
- Especially bad: skipping around in memory with large strides

0: https://www.7-cpu.com/cpu/Haswell.html
1: https://www.realworldtech.com/haswell-cpu/6/
*/
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <omp.h>
#include <immintrin.h>

int min(int a, int b) { return a > b ?
b : a; } void gemm_0(int N, double* A, double* B, double* C) { for (int i = 0; i < N; i++) { for (int j = 0; j < N; j++) { for (int k = 0; k < N; k++) { C[i + N * j] += A[i + N * k] * B[k + N * j]; } } } } #ifdef __cplusplus extern "C" #endif void dgemm_(char*, char*, int*, int*, int*, double*, double*, int*, double*, int*, double*, double*, int*); void gemm_blas(int N, double* A, double* B, double* C) { double c_1 = 1; dgemm_("N", "N", &N, &N, &N, &c_1, A, &N, B, &N, &c_1, C, &N); } void bench(int N, int R, double * A, double * B, double * C, double * Cref, void (*gemm)(int N, double*, double*, double*), const char * name) { double max_flops = 0, avg_flops = 0, max_error = 0; for (int r = 0; r < R; r++) { for (int i = 0; i < N * N; i++) { C[i] = 0; } double start = omp_get_wtime(); gemm(N, A, B, C); double end = omp_get_wtime(); double error = 0; for (int i = 0; i < N * N; i++) { double err_i = C[i] - Cref[i]; if (err_i < 0) err_i = -err_i; if (error < err_i) error = err_i; } double flops = 2.0 * N * N * N / (end - start) * 1e-9; max_error = fmax(max_error, error); max_flops = fmax(max_flops, flops); avg_flops += flops / R; } printf("%s %d %f/%f GFLOPS %e abs error\n", name, N, max_flops, avg_flops, max_error); } /* * Loop Transposition - Prioritize stride-1 indices - Still suboptimal caching */ void gemm_1(int N, double* A, double* B, double* C) { for (int j = 0; j < N; j++) { for (int i = 0; i < N; i++) { for (int k = 0; k < N; k++) { C[i + N * j] += A[i + N * k] * B[k + N * j]; } } } } void gemm_2(int N, double* A, double* B, double* C) { for (int j = 0; j < N; j++) { for (int k = 0; k < N; k++) { for (int i = 0; i < N; i++) { C[i + N * j] += A[i + N * k] * B[k + N * j]; } } } } /* * First Blocking - Constants from BLIS - M/N/K=I/J/K - Improves Reuse */ void gemm_3(int N, double* A, double* B, double* C) { const int BS_J = 4080; const int BS_K = 256; const int BS_I = 72; for (int jb = 0; jb < N; jb += BS_J) { for (int kb = 0; kb < N; kb += BS_K) { for (int ib = 0; ib < N; ib += BS_I) { for (int j = jb; j < jb + min(BS_J, N - jb); j++) { for (int k = kb; k < kb + min(BS_K, N - kb); k++) { for (int i = ib; i < ib + min(BS_I, N - ib); i++) { C[i + N * j] += A[i + N * k] * B[k + N * j]; } } } } } } } /* * Additional Tiling - Again BLIS constants - Prepares Packing & Vectorization - Allows us to place submatrices in caches */ void gemm_4(int N, double* A, double* B, double* C) { const int BS_J = 4080; const int BS_K = 256; const int BS_I = 72; const int BC_J = 8; const int BC_I = 6; for (int jb = 0; jb < N; jb += BS_J) { for (int kb = 0; kb < N; kb += BS_K) { // pack B for (int ib = 0; ib < N; ib += BS_I) { // pack A for (int jc = jb; jc < jb + min(BS_J, N - jb); jc += BC_J) { for (int ic = ib; ic < ib + min(BS_I, N - ib); ic += BC_I) { for (int j = jc; j < jc + min(BC_J, N - jc); j++) { for (int k = kb; k < kb + min(BS_K, N - kb); k++) { for (int i = ic; i < ic + min(BC_I, N - ic); i++) { C[i + N * j] += A[i + N * k] * B[k + N * j]; } } } } } } } } } /* * Packing - Removes TLB issues - Improves locality */ void pack_B(int BS_J, int BS_K, int BC_J, double* Bpack, double* B, int N, int jb, int kb) { int j_bound = min(BS_J, N - jb); for (int j = 0; j < j_bound / BC_J; j++) { for (int k = 0; k < min(BS_K, N - kb); k++) { for (int l = 0; l < BC_J; l++) { Bpack[j * BS_K * BC_J + k * BC_J + l] = B[(kb + k) + N * (jb + j * BC_J + l)]; } } } if (j_bound % BC_J != 0) { int j = j_bound / BC_J; for (int k = 0; k < min(BS_K, N - kb); k++) { for (int l = 0; l < j_bound % BC_J; l++) { Bpack[j * BS_K * 
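/* partial last panel of Bpack: only j_bound % BC_J real columns remain, so
   copy those and zero-pad the rest of the BC_J-wide panel; every panel then
   has the same BS_K x BC_J layout (BC_J contiguous doubles per k), which is
   what the vectorized micro-kernels below assume. */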
BC_J + k * BC_J + l] = B[(kb + k) + N * (jb + j * BC_J + l)]; } for (int l = j_bound % BC_J; l < BC_J; l++) { Bpack[j * BS_K * BC_J + k * BC_J + l] = 0; } } } } void pack_A(int BS_I, int BS_K, int BC_I, double* Apack, double* A, int N, int ib, int kb) { int i_bound = min(BS_I, N - ib); for (int i = 0; i < i_bound / BC_I; i++) { for (int l = 0; l < BC_I; l++) { for (int k = 0; k < min(BS_K, N - kb); k++) { Apack[i * BS_K * BC_I + k * BC_I + l] = A[(ib + i * BC_I + l) + N * (kb + k)]; } } } if (i_bound % BC_I != 0) { int i = i_bound / BC_I; for (int k = 0; k < min(BS_K, N - kb); k++) { for (int l = 0; l < i_bound % BC_I; l++) { Apack[i * BS_K * BC_I + k * BC_I + l] = A[(ib + i * BC_I + l) + N * (kb + k)]; } for (int l = i_bound % BC_I; l < BC_I; l++) { Apack[i * BS_K * BC_I + k * BC_I + l] = 0; } } } } void gemm_5(int N, double* A, double* B, double* C) { const int BS_J = 4080; const int BS_K = 256; const int BS_I = 72; const int BC_J = 8; const int BC_I = 6; double *Apack = _mm_malloc(BS_I * BS_K * sizeof(double), 64); double *Bpack = _mm_malloc(BS_J * BS_K * sizeof(double), 64); for (int jb = 0; jb < N; jb += BS_J) { for (int kb = 0; kb < N; kb += BS_K) { // pack B pack_B(BS_J, BS_K, BC_J, Bpack, B, N, jb, kb); for (int ib = 0; ib < N; ib += BS_I) { // pack A pack_A(BS_I, BS_K, BC_I, Apack, A, N, ib, kb); int jc_loc = 0; for (int jc = jb; jc < jb + min(BS_J, N - jb); jc += BC_J, jc_loc += 1) { double *Blocal = &Bpack[jc_loc * BS_K * BC_J]; int ic_loc = 0; for (int ic = ib; ic < ib + min(BS_I, N - ib); ic += BC_I, ic_loc += 1) { double *Alocal = &Apack[ic_loc * BS_K * BC_I]; int k_loc = 0; for (int k = kb; k < kb + min(BS_K, N - kb); k++, k_loc += 1) { int j_loc = 0; for (int j = jc; j < jc + BC_J; j++, j_loc += 1) { int i_loc = 0; for (int i = ic; i < ic + BC_I; i++, i_loc += 1) { C[i + N * j] += Alocal[k_loc * BC_I + i_loc] * Blocal[k_loc * BC_J + j_loc]; } } } } } } } } _mm_free(Apack); _mm_free(Bpack); } /* * Vectorization - Intrinsics - Approach: Write-back through mem - Avoids transpose code */ void gemm_6(int N, double* A, double* B, double* C) { const int BS_J = 4080; const int BS_K = 256; const int BS_I = 72; const int BC_J = 8; const int BC_I = 6; double *Apack = _mm_malloc(BS_I * BS_K * sizeof(double), 64); double *Bpack = _mm_malloc(BS_J * BS_K * sizeof(double), 64); for (int jb = 0; jb < N; jb += BS_J) { for (int kb = 0; kb < N; kb += BS_K) { // pack B pack_B(BS_J, BS_K, BC_J, Bpack, B, N, jb, kb); for (int ib = 0; ib < N; ib += BS_I) { // pack A pack_A(BS_I, BS_K, BC_I, Apack, A, N, ib, kb); int jc_loc = 0; for (int jc = jb; jc < jb + min(BS_J, N - jb); jc += BC_J, jc_loc += 1) { double *Blocal = &Bpack[jc_loc * BS_K * BC_J]; int ic_loc = 0; for (int ic = ib; ic < ib + min(BS_I, N - ib); ic += BC_I, ic_loc += 1) { double *Alocal = &Apack[ic_loc * BS_K * BC_I]; int j_loc = 0; for (int j = jc; j < jc + BC_J; j += 4, j_loc += 4) { int i_loc = 0; for (int i = ic; i < ic + BC_I; i++, i_loc += 1) { __m256d C_00 = _mm256_setzero_pd(); int k_loc = 0; for (int k = kb; k < kb + min(BS_K, N - kb); k++, k_loc += 1) { C_00 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + i_loc]), _mm256_load_pd(&Blocal[k_loc * BC_J + j_loc]), C_00); } double buffer[4] __attribute__((aligned(32))); _mm256_store_pd(buffer, C_00); for (int jj = 0; jj < min(4, N - j); jj++) { C[i + N * (j + jj)] += buffer[jj]; /* <--- "+=" instead of "=" mistake during lecture */ } } } } } } } } _mm_free(Apack); _mm_free(Bpack); } /* * Unrolling i, j - Also register blocking - First j, then i */ void gemm_7(int N, 
double* A, double* B, double* C) { const int BS_J = 4080; const int BS_K = 256; const int BS_I = 72; const int BC_J = 8; const int BC_I = 6; double *Apack = _mm_malloc(BS_I * BS_K * sizeof(double), 64); double *Bpack = _mm_malloc(BS_J * BS_K * sizeof(double), 64); for (int jb = 0; jb < N; jb += BS_J) { for (int kb = 0; kb < N; kb += BS_K) { // pack B pack_B(BS_J, BS_K, BC_J, Bpack, B, N, jb, kb); for (int ib = 0; ib < N; ib += BS_I) { // pack A pack_A(BS_I, BS_K, BC_I, Apack, A, N, ib, kb); int jc_loc = 0; for (int jc = jb; jc < jb + min(BS_J, N - jb); jc += BC_J, jc_loc += 1) { double *Blocal = &Bpack[jc_loc * BS_K * BC_J]; int ic_loc = 0; for (int ic = ib; ic < ib + min(BS_I, N - ib); ic += BC_I, ic_loc += 1) { double *Alocal = &Apack[ic_loc * BS_K * BC_I]; int i_loc = 0; for (int i = ic; i < ic + BC_I; i++, i_loc += 1) { __m256d C_00 = _mm256_setzero_pd(); __m256d C_01 = _mm256_setzero_pd(); int k_loc = 0; for (int k = kb; k < kb + min(BS_K, N - kb); k++, k_loc += 1) { C_00 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + i_loc]), _mm256_load_pd(&Blocal[k_loc * BC_J + 0]), C_00); C_01 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + i_loc]), _mm256_load_pd(&Blocal[k_loc * BC_J + 4]), C_01); } double buffer[BC_J] __attribute__((aligned(32))); _mm256_store_pd(&buffer[0], C_00); _mm256_store_pd(&buffer[4], C_01); for (int jj = 0; jj < min(BC_J, N - jc); jj++) { C[i + N * (jc + jj)] += buffer[jj]; } } } } } } } _mm_free(Apack); _mm_free(Bpack); } void gemm_8(int N, double* A, double* B, double* C) { const int BS_J = 4080; const int BS_K = 256; const int BS_I = 72; const int BC_J = 8; const int BC_I = 6; double *Apack = _mm_malloc(BS_I * BS_K * sizeof(double), 64); double *Bpack = _mm_malloc(BS_J * BS_K * sizeof(double), 64); for (int jb = 0; jb < N; jb += BS_J) { for (int kb = 0; kb < N; kb += BS_K) { // pack B pack_B(BS_J, BS_K, BC_J, Bpack, B, N, jb, kb); for (int ib = 0; ib < N; ib += BS_I) { // pack A pack_A(BS_I, BS_K, BC_I, Apack, A, N, ib, kb); int jc_loc = 0; for (int jc = jb; jc < jb + min(BS_J, N - jb); jc += BC_J, jc_loc += 1) { double *Blocal = &Bpack[jc_loc * BS_K * BC_J]; int ic_loc = 0; for (int ic = ib; ic < ib + min(BS_I, N - ib); ic += BC_I, ic_loc += 1) { double *Alocal = &Apack[ic_loc * BS_K * BC_I]; __m256d C_00 = _mm256_setzero_pd(); __m256d C_01 = _mm256_setzero_pd(); __m256d C_10 = _mm256_setzero_pd(); __m256d C_11 = _mm256_setzero_pd(); __m256d C_20 = _mm256_setzero_pd(); __m256d C_21 = _mm256_setzero_pd(); __m256d C_30 = _mm256_setzero_pd(); __m256d C_31 = _mm256_setzero_pd(); __m256d C_40 = _mm256_setzero_pd(); __m256d C_41 = _mm256_setzero_pd(); __m256d C_50 = _mm256_setzero_pd(); __m256d C_51 = _mm256_setzero_pd(); int k_loc = 0; for (int k = kb; k < kb + min(BS_K, N - kb); k++, k_loc += 1) { C_00 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + 0]), _mm256_load_pd(&Blocal[k_loc * BC_J + 0]), C_00); C_01 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + 0]), _mm256_load_pd(&Blocal[k_loc * BC_J + 4]), C_01); C_10 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + 1]), _mm256_load_pd(&Blocal[k_loc * BC_J + 0]), C_10); C_11 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + 1]), _mm256_load_pd(&Blocal[k_loc * BC_J + 4]), C_11); C_20 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + 2]), _mm256_load_pd(&Blocal[k_loc * BC_J + 0]), C_20); C_21 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + 2]), _mm256_load_pd(&Blocal[k_loc * BC_J + 4]), C_21); C_30 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + 3]), 
_mm256_load_pd(&Blocal[k_loc * BC_J + 0]), C_30); C_31 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + 3]), _mm256_load_pd(&Blocal[k_loc * BC_J + 4]), C_31); C_40 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + 4]), _mm256_load_pd(&Blocal[k_loc * BC_J + 0]), C_40); C_41 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + 4]), _mm256_load_pd(&Blocal[k_loc * BC_J + 4]), C_41); C_50 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + 5]), _mm256_load_pd(&Blocal[k_loc * BC_J + 0]), C_50); C_51 = _mm256_fmadd_pd(_mm256_set1_pd(Alocal[k_loc * BC_I + 5]), _mm256_load_pd(&Blocal[k_loc * BC_J + 4]), C_51); } double buffer[BC_I * BC_J] __attribute__((aligned(32))); _mm256_store_pd(&buffer[0 * BC_J + 0], C_00); _mm256_store_pd(&buffer[0 * BC_J + 4], C_01); _mm256_store_pd(&buffer[1 * BC_J + 0], C_10); _mm256_store_pd(&buffer[1 * BC_J + 4], C_11); _mm256_store_pd(&buffer[2 * BC_J + 0], C_20); _mm256_store_pd(&buffer[2 * BC_J + 4], C_21); _mm256_store_pd(&buffer[3 * BC_J + 0], C_30); _mm256_store_pd(&buffer[3 * BC_J + 4], C_31); _mm256_store_pd(&buffer[4 * BC_J + 0], C_40); _mm256_store_pd(&buffer[4 * BC_J + 4], C_41); _mm256_store_pd(&buffer[5 * BC_J + 0], C_50); _mm256_store_pd(&buffer[5 * BC_J + 4], C_51); for (int ii = 0; ii < min(BC_I, N - ic); ii++) { for (int jj = 0; jj < min(BC_J, N - jc); jj++) { C[ic + ii + N * (jc + jj)] += buffer[ii * BC_J + jj]; } } } } } } } _mm_free(Apack); _mm_free(Bpack); } void gemm_9(int N, double* A, double* B, double* C) { const int BS_J = 4080; const int BS_K = 256; const int BS_I = 72; const int BC_J = 8; const int BC_I = 6; double *Apack = _mm_malloc(BS_I * BS_K * sizeof(double), 64); double *Bpack = _mm_malloc(BS_J * BS_K * sizeof(double), 64); for (int jb = 0; jb < N; jb += BS_J) { for (int kb = 0; kb < N; kb += BS_K) { // pack B pack_B(BS_J, BS_K, BC_J, Bpack, B, N, jb, kb); for (int ib = 0; ib < N; ib += BS_I) { // pack A pack_A(BS_I, BS_K, BC_I, Apack, A, N, ib, kb); int jc_loc = 0; for (int jc = jb; jc < jb + min(BS_J, N - jb); jc += BC_J, jc_loc += 1) { double *Blocal = &Bpack[jc_loc * BS_K * BC_J]; int ic_loc = 0; for (int ic = ib; ic < ib + min(BS_I, N - ib); ic += BC_I, ic_loc += 1) { double *Alocal = &Apack[ic_loc * BS_K * BC_I]; __m256d C_00 = _mm256_setzero_pd(); __m256d C_01 = _mm256_setzero_pd(); __m256d C_10 = _mm256_setzero_pd(); __m256d C_11 = _mm256_setzero_pd(); __m256d C_20 = _mm256_setzero_pd(); __m256d C_21 = _mm256_setzero_pd(); __m256d C_30 = _mm256_setzero_pd(); __m256d C_31 = _mm256_setzero_pd(); __m256d C_40 = _mm256_setzero_pd(); __m256d C_41 = _mm256_setzero_pd(); __m256d C_50 = _mm256_setzero_pd(); __m256d C_51 = _mm256_setzero_pd(); __m256d A_0, A_1, B_0, B_1; for (int k_loc = 0; k_loc < min(BS_K, N - kb); k_loc += 1) { // Cleanup B_0 = _mm256_load_pd(&Blocal[k_loc * BC_J + 0]); B_1 = _mm256_load_pd(&Blocal[k_loc * BC_J + 4]); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 0]); C_00 = _mm256_fmadd_pd(A_0, B_0, C_00); C_01 = _mm256_fmadd_pd(A_0, B_1, C_01); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 1]); C_10 = _mm256_fmadd_pd(A_0, B_0, C_10); C_11 = _mm256_fmadd_pd(A_0, B_1, C_11); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 2]); C_20 = _mm256_fmadd_pd(A_0, B_0, C_20); C_21 = _mm256_fmadd_pd(A_0, B_1, C_21); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 3]); C_30 = _mm256_fmadd_pd(A_0, B_0, C_30); C_31 = _mm256_fmadd_pd(A_0, B_1, C_31); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 4]); C_40 = _mm256_fmadd_pd(A_0, B_0, C_40); C_41 = _mm256_fmadd_pd(A_0, B_1, C_41); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 5]); C_50 = 
_mm256_fmadd_pd(A_0, B_0, C_50); C_51 = _mm256_fmadd_pd(A_0, B_1, C_51); } double buffer[BC_I * BC_J] __attribute__((aligned(32))); _mm256_store_pd(&buffer[0 * BC_J + 0], C_00); _mm256_store_pd(&buffer[0 * BC_J + 4], C_01); _mm256_store_pd(&buffer[1 * BC_J + 0], C_10); _mm256_store_pd(&buffer[1 * BC_J + 4], C_11); _mm256_store_pd(&buffer[2 * BC_J + 0], C_20); _mm256_store_pd(&buffer[2 * BC_J + 4], C_21); _mm256_store_pd(&buffer[3 * BC_J + 0], C_30); _mm256_store_pd(&buffer[3 * BC_J + 4], C_31); _mm256_store_pd(&buffer[4 * BC_J + 0], C_40); _mm256_store_pd(&buffer[4 * BC_J + 4], C_41); _mm256_store_pd(&buffer[5 * BC_J + 0], C_50); _mm256_store_pd(&buffer[5 * BC_J + 4], C_51); for (int ii = 0; ii < min(BC_I, N - ic); ii++) { for (int jj = 0; jj < min(BC_J, N - jc); jj++) { C[ic + ii + N * (jc + jj)] += buffer[ii * BC_J + jj]; } } } } } } } _mm_free(Apack); _mm_free(Bpack); } /* * Unroll k */ void gemm_10(int N, double* A, double* B, double* C) { const int BS_J = 4080; const int BS_K = 256; const int BS_I = 72; const int BC_J = 8; const int BC_I = 6; double *Apack = _mm_malloc(BS_I * BS_K * sizeof(double), 64); double *Bpack = _mm_malloc(BS_J * BS_K * sizeof(double), 64); for (int jb = 0; jb < N; jb += BS_J) { for (int kb = 0; kb < N; kb += BS_K) { // pack B pack_B(BS_J, BS_K, BC_J, Bpack, B, N, jb, kb); for (int ib = 0; ib < N; ib += BS_I) { // pack A pack_A(BS_I, BS_K, BC_I, Apack, A, N, ib, kb); int jc_loc = 0; for (int jc = jb; jc < jb + min(BS_J, N - jb); jc += BC_J, jc_loc += 1) { double *Blocal = &Bpack[jc_loc * BS_K * BC_J]; int ic_loc = 0; for (int ic = ib; ic < ib + min(BS_I, N - ib); ic += BC_I, ic_loc += 1) { double *Alocal = &Apack[ic_loc * BS_K * BC_I]; __m256d C_00 = _mm256_setzero_pd(); __m256d C_01 = _mm256_setzero_pd(); __m256d C_10 = _mm256_setzero_pd(); __m256d C_11 = _mm256_setzero_pd(); __m256d C_20 = _mm256_setzero_pd(); __m256d C_21 = _mm256_setzero_pd(); __m256d C_30 = _mm256_setzero_pd(); __m256d C_31 = _mm256_setzero_pd(); __m256d C_40 = _mm256_setzero_pd(); __m256d C_41 = _mm256_setzero_pd(); __m256d C_50 = _mm256_setzero_pd(); __m256d C_51 = _mm256_setzero_pd(); __m256d A_0, A_1, B_0, B_1; int k_loc = 0; int k_total = min(BS_K, N - kb); int k_unroll = (k_total / 4) * 4; int k_remainder = k_total % 4; for (int k_loc = 0; k_loc < k_unroll; k_loc += 1) { B_0 = _mm256_load_pd(&Blocal[k_loc * BC_J + 0]); B_1 = _mm256_load_pd(&Blocal[k_loc * BC_J + 4]); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 0]); C_00 = _mm256_fmadd_pd(A_0, B_0, C_00); C_01 = _mm256_fmadd_pd(A_0, B_1, C_01); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 1]); C_10 = _mm256_fmadd_pd(A_0, B_0, C_10); C_11 = _mm256_fmadd_pd(A_0, B_1, C_11); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 2]); C_20 = _mm256_fmadd_pd(A_0, B_0, C_20); C_21 = _mm256_fmadd_pd(A_0, B_1, C_21); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 3]); C_30 = _mm256_fmadd_pd(A_0, B_0, C_30); C_31 = _mm256_fmadd_pd(A_0, B_1, C_31); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 4]); C_40 = _mm256_fmadd_pd(A_0, B_0, C_40); C_41 = _mm256_fmadd_pd(A_0, B_1, C_41); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 5]); C_50 = _mm256_fmadd_pd(A_0, B_0, C_50); C_51 = _mm256_fmadd_pd(A_0, B_1, C_51); k_loc += 1; B_0 = _mm256_load_pd(&Blocal[k_loc * BC_J + 0]); B_1 = _mm256_load_pd(&Blocal[k_loc * BC_J + 4]); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 0]); C_00 = _mm256_fmadd_pd(A_0, B_0, C_00); C_01 = _mm256_fmadd_pd(A_0, B_1, C_01); A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 1]); C_10 = _mm256_fmadd_pd(A_0, B_0, C_10); C_11 = _mm256_fmadd_pd(A_0, B_1, C_11); A_0 
= _mm256_set1_pd(Alocal[k_loc * BC_I + 2]);
C_20 = _mm256_fmadd_pd(A_0, B_0, C_20); C_21 = _mm256_fmadd_pd(A_0, B_1, C_21);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 3]);
C_30 = _mm256_fmadd_pd(A_0, B_0, C_30); C_31 = _mm256_fmadd_pd(A_0, B_1, C_31);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 4]);
C_40 = _mm256_fmadd_pd(A_0, B_0, C_40); C_41 = _mm256_fmadd_pd(A_0, B_1, C_41);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 5]);
C_50 = _mm256_fmadd_pd(A_0, B_0, C_50); C_51 = _mm256_fmadd_pd(A_0, B_1, C_51);
k_loc += 1;
B_0 = _mm256_load_pd(&Blocal[k_loc * BC_J + 0]);
B_1 = _mm256_load_pd(&Blocal[k_loc * BC_J + 4]);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 0]);
C_00 = _mm256_fmadd_pd(A_0, B_0, C_00); C_01 = _mm256_fmadd_pd(A_0, B_1, C_01);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 1]);
C_10 = _mm256_fmadd_pd(A_0, B_0, C_10); C_11 = _mm256_fmadd_pd(A_0, B_1, C_11);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 2]);
C_20 = _mm256_fmadd_pd(A_0, B_0, C_20); C_21 = _mm256_fmadd_pd(A_0, B_1, C_21);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 3]);
C_30 = _mm256_fmadd_pd(A_0, B_0, C_30); C_31 = _mm256_fmadd_pd(A_0, B_1, C_31);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 4]);
C_40 = _mm256_fmadd_pd(A_0, B_0, C_40); C_41 = _mm256_fmadd_pd(A_0, B_1, C_41);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 5]);
C_50 = _mm256_fmadd_pd(A_0, B_0, C_50); C_51 = _mm256_fmadd_pd(A_0, B_1, C_51);
k_loc += 1;
B_0 = _mm256_load_pd(&Blocal[k_loc * BC_J + 0]);
B_1 = _mm256_load_pd(&Blocal[k_loc * BC_J + 4]);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 0]);
C_00 = _mm256_fmadd_pd(A_0, B_0, C_00); C_01 = _mm256_fmadd_pd(A_0, B_1, C_01);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 1]);
C_10 = _mm256_fmadd_pd(A_0, B_0, C_10); C_11 = _mm256_fmadd_pd(A_0, B_1, C_11);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 2]);
C_20 = _mm256_fmadd_pd(A_0, B_0, C_20); C_21 = _mm256_fmadd_pd(A_0, B_1, C_21);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 3]);
C_30 = _mm256_fmadd_pd(A_0, B_0, C_30); C_31 = _mm256_fmadd_pd(A_0, B_1, C_31);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 4]);
C_40 = _mm256_fmadd_pd(A_0, B_0, C_40); C_41 = _mm256_fmadd_pd(A_0, B_1, C_41);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 5]);
C_50 = _mm256_fmadd_pd(A_0, B_0, C_50); C_51 = _mm256_fmadd_pd(A_0, B_1, C_51);
}
/* remainder iterations (k_total not a multiple of 4) continue from k_unroll */
for (int k_loc = k_unroll; k_loc < k_unroll + k_remainder; k_loc += 1) {
B_0 = _mm256_load_pd(&Blocal[k_loc * BC_J + 0]);
B_1 = _mm256_load_pd(&Blocal[k_loc * BC_J + 4]);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 0]);
C_00 = _mm256_fmadd_pd(A_0, B_0, C_00); C_01 = _mm256_fmadd_pd(A_0, B_1, C_01);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 1]);
C_10 = _mm256_fmadd_pd(A_0, B_0, C_10); C_11 = _mm256_fmadd_pd(A_0, B_1, C_11);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 2]);
C_20 = _mm256_fmadd_pd(A_0, B_0, C_20); C_21 = _mm256_fmadd_pd(A_0, B_1, C_21);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 3]);
C_30 = _mm256_fmadd_pd(A_0, B_0, C_30); C_31 = _mm256_fmadd_pd(A_0, B_1, C_31);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 4]);
C_40 = _mm256_fmadd_pd(A_0, B_0, C_40); C_41 = _mm256_fmadd_pd(A_0, B_1, C_41);
A_0 = _mm256_set1_pd(Alocal[k_loc * BC_I + 5]);
C_50 = _mm256_fmadd_pd(A_0, B_0, C_50); C_51 = _mm256_fmadd_pd(A_0, B_1, C_51);
}
double buffer[BC_I * BC_J] __attribute__((aligned(32)));
_mm256_store_pd(&buffer[0 * BC_J + 0], C_00); _mm256_store_pd(&buffer[0 * BC_J + 4], C_01);
_mm256_store_pd(&buffer[1 * BC_J + 0], C_10); _mm256_store_pd(&buffer[1 * BC_J + 4], C_11);
_mm256_store_pd(&buffer[2 * BC_J + 0], C_20); _mm256_store_pd(&buffer[2 * BC_J + 4], C_21);
_mm256_store_pd(&buffer[3 * BC_J
+ 0], C_30); _mm256_store_pd(&buffer[3 * BC_J + 4], C_31); _mm256_store_pd(&buffer[4 * BC_J + 0], C_40); _mm256_store_pd(&buffer[4 * BC_J + 4], C_41); _mm256_store_pd(&buffer[5 * BC_J + 0], C_50); _mm256_store_pd(&buffer[5 * BC_J + 4], C_51); for (int ii = 0; ii < min(BC_I, N - ic); ii++) { for (int jj = 0; jj < min(BC_J, N - jc); jj++) { C[ic + ii + N * (jc + jj)] += buffer[ii * BC_J + jj]; } } } } } } } _mm_free(Apack); _mm_free(Bpack); } /* * Assembly - Maximum control over register allocation - Maximum control over instruction scheduling */ void gemm_11(int N, double* A, double* B, double* C) { const int BS_J = 4080; const int BS_K = 256; const int BS_I = 72; const int BC_J = 8; const int BC_I = 6; double *Apack = _mm_malloc(BS_I * BS_K * sizeof(double), 64); double *Bpack = _mm_malloc(BS_J * BS_K * sizeof(double), 64); for (int jb = 0; jb < N; jb += BS_J) { for (int kb = 0; kb < N; kb += BS_K) { // pack B pack_B(BS_J, BS_K, BC_J, Bpack, B, N, jb, kb); for (int ib = 0; ib < N; ib += BS_I) { // pack A pack_A(BS_I, BS_K, BC_I, Apack, A, N, ib, kb); int jc_loc = 0; for (int jc = jb; jc < jb + min(BS_J, N - jb); jc += BC_J, jc_loc += 1) { double *Blocal = &Bpack[jc_loc * BS_K * BC_J]; int ic_loc = 0; for (int ic = ib; ic < ib + min(BS_I, N - ib); ic += BC_I, ic_loc += 1) { double *Alocal = &Apack[ic_loc * BS_K * BC_I]; double buffer[BC_I * BC_J] __attribute__((aligned(32))); double * buf_ptr = &buffer[0]; int kmax = min(BS_K, N - kb); int k_unroll = kmax / 4; int k_rest = kmax % 4; // 0-11: C, 12, 13: B, 14, 15: A __asm__ volatile ( "movq %0, %%rcx \n\t" "movq %1, %%rax \n\t" "movq %2, %%rbx \n\t" "movl %3, %%esi \n\t" "movl %4, %%edi \n\t" " \n\t" "vzeroall \n\t" " \n\t" "vmovapd 0 * 32(%%rbx), %%ymm12 \n\t" "vmovapd 1 * 32(%%rbx), %%ymm13 \n\t" " \n\t" "testl %%esi, %%esi \n\t" "jz .AFTERLOOP%= \n\t" " \n\t" ".p2align 4 \n\t" ".LOOP%=: \n\t" " \n\t" "vbroadcastsd 0 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 1 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm0 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm1 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm2 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm3 \n\t" " \n\t" "vbroadcastsd 2 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 3 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm4 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm5 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm6 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm7 \n\t" " \n\t" "vbroadcastsd 4 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 5 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm8 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm9 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm10 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm11 \n\t" " \n\t" "vmovapd 2 * 32(%%rbx), %%ymm12 \n\t" "vmovapd 3 * 32(%%rbx), %%ymm13 \n\t" " \n\t" "vbroadcastsd 6 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 7 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm0 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm1 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm2 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm3 \n\t" " \n\t" "vbroadcastsd 8 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 9 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm4 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm5 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm6 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm7 \n\t" " \n\t" "vbroadcastsd 10* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 11* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm8 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm9 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm10 \n\t" "vfmadd231pd %%ymm13, 
%%ymm15, %%ymm11 \n\t" " \n\t" "vmovapd 4 * 32(%%rbx), %%ymm12 \n\t" "vmovapd 5 * 32(%%rbx), %%ymm13 \n\t" " \n\t" "vbroadcastsd 12* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 13* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm0 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm1 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm2 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm3 \n\t" " \n\t" "vbroadcastsd 14* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 15* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm4 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm5 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm6 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm7 \n\t" " \n\t" "vbroadcastsd 16* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 17* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm8 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm9 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm10 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm11 \n\t" " \n\t" "vmovapd 6 * 32(%%rbx), %%ymm12 \n\t" "vmovapd 7 * 32(%%rbx), %%ymm13 \n\t" " \n\t" "subq $-8 * 32, %%rbx \n\t" " \n\t" "vbroadcastsd 18* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 19* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm0 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm1 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm2 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm3 \n\t" " \n\t" "vbroadcastsd 20* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 21* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm4 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm5 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm6 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm7 \n\t" " \n\t" "vbroadcastsd 22* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 23* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm8 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm9 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm10 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm11 \n\t" " \n\t" "subq $-24 * 8, %%rax \n\t" " \n\t" "vmovapd 0 * 32(%%rbx), %%ymm12 \n\t" "vmovapd 1 * 32(%%rbx), %%ymm13 \n\t" " \n\t" "decl %%esi \n\t" " \n\t" "jnz .LOOP%= \n\t" " \n\t" ".AFTERLOOP%=: \n\t" " \n\t" "testl %%edi, %%edi \n\t" "jz .AFTERREST%= \n\t" " \n\t" ".REST%=: \n\t" " \n\t" "vbroadcastsd 0 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 1 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm0 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm1 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm2 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm3 \n\t" " \n\t" "vbroadcastsd 2 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 3 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm4 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm5 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm6 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm7 \n\t" " \n\t" "vbroadcastsd 4 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 5 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm8 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm9 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm10 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm11 \n\t" " \n\t" "addq $2 * 32, %%rbx \n\t" "addq $6 * 8, %%rax \n\t" " \n\t" "vmovapd 0 * 32(%%rbx), %%ymm12 \n\t" "vmovapd 1 * 32(%%rbx), %%ymm13 \n\t" " \n\t" "decl %%edi \n\t" " \n\t" "jnz .REST%= \n\t" " \n\t" ".AFTERREST%=: \n\t" " \n\t" "vmovapd %%ymm0, 0 * 32(%%rcx) \n\t" "vmovapd %%ymm1, 1 * 32(%%rcx) \n\t" "vmovapd %%ymm2, 2 * 32(%%rcx) \n\t" "vmovapd %%ymm3, 3 * 32(%%rcx) \n\t" " \n\t" "vmovapd %%ymm4, 4 * 32(%%rcx) \n\t" "vmovapd %%ymm5, 5 * 32(%%rcx) \n\t" "vmovapd %%ymm6, 6 * 32(%%rcx) \n\t" "vmovapd %%ymm7, 7 * 32(%%rcx) \n\t" " \n\t" "vmovapd %%ymm8, 8 * 32(%%rcx) \n\t" "vmovapd %%ymm9, 9 * 32(%%rcx) \n\t" 
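/* end of the micro-kernel: ymm0-ymm11 hold the 6x8 accumulator tile (two
   registers per row of the micro-tile); they are stored to the 32-byte
   aligned buffer addressed by rcx, and the C loop below adds the tile
   into C. */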
"vmovapd %%ymm10, 10* 32(%%rcx) \n\t" "vmovapd %%ymm11, 11* 32(%%rcx) \n\t" : : "m" (buf_ptr), // 0 "m" (Alocal), // 1 "m" (Blocal), // 2 "m" (k_unroll), // 3 "m" (k_rest) // 4 : "rax", "rbx", "rcx", "esi", "edi", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15" ); for (int ii = 0; ii < min(BC_I, N - ic); ii++) { for (int jj = 0; jj < min(BC_J, N - jc); jj++) { C[ic + ii + N * (jc + jj)] += buffer[ii * BC_J + jj]; } } } } } } } _mm_free(Apack); _mm_free(Bpack); } /* * Prefetching - Instructions that preload memory locations - Placement is ``magic'', distance is ``magic'' */ void gemm_12(int N, double* A, double* B, double* C) { const int BS_J = 4080; const int BS_K = 256; const int BS_I = 72; const int BC_J = 8; const int BC_I = 6; double *Apack = _mm_malloc(BS_I * BS_K * sizeof(double), 64); double *Bpack = _mm_malloc(BS_J * BS_K * sizeof(double), 64); for (int jb = 0; jb < N; jb += BS_J) { for (int kb = 0; kb < N; kb += BS_K) { // pack B pack_B(BS_J, BS_K, BC_J, Bpack, B, N, jb, kb); for (int ib = 0; ib < N; ib += BS_I) { // pack A pack_A(BS_I, BS_K, BC_I, Apack, A, N, ib, kb); int jc_loc = 0; for (int jc = jb; jc < jb + min(BS_J, N - jb); jc += BC_J, jc_loc += 1) { double *Blocal = &Bpack[jc_loc * BS_K * BC_J]; int ic_loc = 0; for (int ic = ib; ic < ib + min(BS_I, N - ib); ic += BC_I, ic_loc += 1) { double *Alocal = &Apack[ic_loc * BS_K * BC_I]; double buffer[BC_I * BC_J] __attribute__((aligned(32))); double * buf_ptr = &buffer[0]; int kmax = min(BS_K, N - kb); int k_unroll = kmax / 4; int k_rest = kmax % 4; // 0-11: C, 12, 13: B, 14, 15: A __asm__ volatile ( "movq %0, %%rcx \n\t" "movq %1, %%rax \n\t" "movq %2, %%rbx \n\t" "movl %3, %%esi \n\t" "movl %4, %%edi \n\t" " \n\t" "vzeroall \n\t" " \n\t" "vmovapd 0 * 32(%%rbx), %%ymm12 \n\t" "vmovapd 1 * 32(%%rbx), %%ymm13 \n\t" " \n\t" "testl %%esi, %%esi \n\t" "jz .AFTERLOOP%= \n\t" " \n\t" ".p2align 4 \n\t" ".LOOP%=: \n\t" " \n\t" "vbroadcastsd 0 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 1 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm0 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm1 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm2 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm3 \n\t" " \n\t" "vbroadcastsd 2 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 3 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm4 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm5 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm6 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm7 \n\t" " \n\t" "prefetcht0 8 * 32(%%rbx) \n\t" " \n\t" "vbroadcastsd 4 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 5 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm8 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm9 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm10 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm11 \n\t" " \n\t" "vmovapd 2 * 32(%%rbx), %%ymm12 \n\t" "vmovapd 3 * 32(%%rbx), %%ymm13 \n\t" " \n\t" "vbroadcastsd 6 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 7 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm0 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm1 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm2 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm3 \n\t" " \n\t" "vbroadcastsd 8 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 9 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm4 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm5 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm6 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm7 \n\t" " \n\t" "prefetcht0 10 * 32(%%rbx) \n\t" " \n\t" "vbroadcastsd 10* 8(%%rax), %%ymm14 \n\t" 
"vbroadcastsd 11* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm8 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm9 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm10 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm11 \n\t" " \n\t" "vmovapd 4 * 32(%%rbx), %%ymm12 \n\t" "vmovapd 5 * 32(%%rbx), %%ymm13 \n\t" " \n\t" "vbroadcastsd 12* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 13* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm0 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm1 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm2 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm3 \n\t" " \n\t" "vbroadcastsd 14* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 15* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm4 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm5 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm6 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm7 \n\t" " \n\t" "prefetcht0 12 * 32(%%rbx) \n\t" " \n\t" "vbroadcastsd 16* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 17* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm8 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm9 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm10 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm11 \n\t" " \n\t" "vmovapd 6 * 32(%%rbx), %%ymm12 \n\t" "vmovapd 7 * 32(%%rbx), %%ymm13 \n\t" " \n\t" "subq $-8 * 32, %%rbx \n\t" " \n\t" "vbroadcastsd 18* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 19* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm0 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm1 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm2 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm3 \n\t" " \n\t" "vbroadcastsd 20* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 21* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm4 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm5 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm6 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm7 \n\t" " \n\t" "vbroadcastsd 22* 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 23* 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm8 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm9 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm10 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm11 \n\t" " \n\t" "subq $-24 * 8, %%rax \n\t" " \n\t" "vmovapd 0 * 32(%%rbx), %%ymm12 \n\t" "vmovapd 1 * 32(%%rbx), %%ymm13 \n\t" " \n\t" "decl %%esi \n\t" " \n\t" "jnz .LOOP%= \n\t" " \n\t" ".AFTERLOOP%=: \n\t" " \n\t" "testl %%edi, %%edi \n\t" "jz .AFTERREST%= \n\t" " \n\t" ".REST%=: \n\t" " \n\t" "vbroadcastsd 0 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 1 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm0 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm1 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm2 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm3 \n\t" " \n\t" "vbroadcastsd 2 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 3 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm4 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm5 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm6 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm7 \n\t" " \n\t" "vbroadcastsd 4 * 8(%%rax), %%ymm14 \n\t" "vbroadcastsd 5 * 8(%%rax), %%ymm15 \n\t" "vfmadd231pd %%ymm12, %%ymm14, %%ymm8 \n\t" "vfmadd231pd %%ymm13, %%ymm14, %%ymm9 \n\t" "vfmadd231pd %%ymm12, %%ymm15, %%ymm10 \n\t" "vfmadd231pd %%ymm13, %%ymm15, %%ymm11 \n\t" " \n\t" "addq $2 * 32, %%rbx \n\t" "addq $6 * 8, %%rax \n\t" " \n\t" "vmovapd 0 * 32(%%rbx), %%ymm12 \n\t" "vmovapd 1 * 32(%%rbx), %%ymm13 \n\t" " \n\t" "decl %%edi \n\t" " \n\t" "jnz .REST%= \n\t" " \n\t" ".AFTERREST%=: \n\t" " \n\t" "vmovapd %%ymm0, 0 * 32(%%rcx) \n\t" "vmovapd %%ymm1, 1 * 32(%%rcx) \n\t" "vmovapd %%ymm2, 2 * 32(%%rcx) \n\t" "vmovapd %%ymm3, 3 * 32(%%rcx) \n\t" " 
\n\t" "vmovapd %%ymm4, 4 * 32(%%rcx) \n\t" "vmovapd %%ymm5, 5 * 32(%%rcx) \n\t" "vmovapd %%ymm6, 6 * 32(%%rcx) \n\t" "vmovapd %%ymm7, 7 * 32(%%rcx) \n\t" " \n\t" "vmovapd %%ymm8, 8 * 32(%%rcx) \n\t" "vmovapd %%ymm9, 9 * 32(%%rcx) \n\t" "vmovapd %%ymm10, 10* 32(%%rcx) \n\t" "vmovapd %%ymm11, 11* 32(%%rcx) \n\t" : : "m" (buf_ptr), // 0 "m" (Alocal), // 1 "m" (Blocal), // 2 "m" (k_unroll), // 3 "m" (k_rest) // 4 : "rax", "rbx", "rcx", "esi", "edi", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15" ); for (int ii = 0; ii < min(BC_I, N - ic); ii++) { for (int jj = 0; jj < min(BC_J, N - jc); jj++) { C[ic + ii + N * (jc + jj)] += buffer[ii * BC_J + jj]; } } } } } } } _mm_free(Apack); _mm_free(Bpack); } /* * Main Function */ int main(int argc, char ** argv) { int N = atoi(argv[1]); int R = atoi(argv[2]); double* A = calloc(N, N * sizeof(double)); double* B = calloc(N, N * sizeof(double)); double* C = calloc(N, N * sizeof(double)); double* Cref = calloc(N, N * sizeof(double)); for (int i = 0; i < N * N; i++) { A[i] = rand() * 1.0 / RAND_MAX; B[i] = rand() * 1.0 / RAND_MAX; } gemm_blas(N, A, B, Cref); #ifndef REF bench(N, R, A, B, C, Cref, gemm_0, "gemm_0"); bench(N, R, A, B, C, Cref, gemm_1, "gemm_1"); bench(N, R, A, B, C, Cref, gemm_2, "gemm_2"); bench(N, R, A, B, C, Cref, gemm_3, "gemm_3"); bench(N, R, A, B, C, Cref, gemm_4, "gemm_4"); bench(N, R, A, B, C, Cref, gemm_5, "gemm_5"); bench(N, R, A, B, C, Cref, gemm_6, "gemm_6"); bench(N, R, A, B, C, Cref, gemm_7, "gemm_7"); bench(N, R, A, B, C, Cref, gemm_8, "gemm_8"); bench(N, R, A, B, C, Cref, gemm_9, "gemm_9"); bench(N, R, A, B, C, Cref, gemm_10, "gemm_10"); bench(N, R, A, B, C, Cref, gemm_11, "gemm_11"); bench(N, R, A, B, C, Cref, gemm_12, "gemm_12"); bench(N, R, A, B, C, Cref, gemm_blas, "gemm_b"); #else bench(N, R, A, B, C, Cref, gemm_blas, "gemm_b"); #endif free(A); free(B); free(C); }