1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
#include <immintrin.h>
#include <omp.h>

#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <memory>

#define BLOCKSIZE 128
#define AVX_F_CAPACITY 8
// Shape of the register tile handled by the AVX-512 micro-kernel: 8 rows of c,
// each kept in two 8-double vectors (16 columns), i.e. 16 accumulators.
static constexpr uint64_t kTileRows = 8;
static constexpr uint64_t kTileCols = 16;

// Returns the exclusive end of the cache block starting at `begin`, clamped to `limit`.
static inline uint64_t block_end(uint64_t begin, uint64_t limit) {
    return (limit - begin < BLOCKSIZE) ? limit : begin + BLOCKSIZE;
}

// Micro-kernel: c[row..row+8)[col..col+16) += a[row..row+8)[j_begin..j_end) * b[j_begin..j_end)[col..col+16).
// All matrices are row-major; n2 is the row length of a, n3 the row length of b and c.
// Unaligned load/store intrinsics are used so that rows need not be 64-byte
// aligned (n3 does not have to be a multiple of 8).
static inline void mul_tile_8x16(const double* a, const double* b, double* c,
                                 uint64_t n2, uint64_t n3,
                                 uint64_t row, uint64_t col,
                                 uint64_t j_begin, uint64_t j_end) {
    double* const c_tile = c + row * n3 + col;
    const double* const a_tile = a + row * n2;

    __m512d lo0 = _mm512_loadu_pd(c_tile);              __m512d hi0 = _mm512_loadu_pd(c_tile + 8);
    __m512d lo1 = _mm512_loadu_pd(c_tile + 1 * n3);     __m512d hi1 = _mm512_loadu_pd(c_tile + 1 * n3 + 8);
    __m512d lo2 = _mm512_loadu_pd(c_tile + 2 * n3);     __m512d hi2 = _mm512_loadu_pd(c_tile + 2 * n3 + 8);
    __m512d lo3 = _mm512_loadu_pd(c_tile + 3 * n3);     __m512d hi3 = _mm512_loadu_pd(c_tile + 3 * n3 + 8);
    __m512d lo4 = _mm512_loadu_pd(c_tile + 4 * n3);     __m512d hi4 = _mm512_loadu_pd(c_tile + 4 * n3 + 8);
    __m512d lo5 = _mm512_loadu_pd(c_tile + 5 * n3);     __m512d hi5 = _mm512_loadu_pd(c_tile + 5 * n3 + 8);
    __m512d lo6 = _mm512_loadu_pd(c_tile + 6 * n3);     __m512d hi6 = _mm512_loadu_pd(c_tile + 6 * n3 + 8);
    __m512d lo7 = _mm512_loadu_pd(c_tile + 7 * n3);     __m512d hi7 = _mm512_loadu_pd(c_tile + 7 * n3 + 8);

    // One rank-1 update per jj: a column slice of a (broadcast per row) times a
    // 16-wide row slice of b. jj ascends so every c element is accumulated in
    // ascending-j order with fused multiply-adds.
    for (uint64_t jj = j_begin; jj < j_end; ++jj) {
        const double* const b_row = b + jj * n3 + col;
        const __m512d b_lo = _mm512_loadu_pd(b_row);
        const __m512d b_hi = _mm512_loadu_pd(b_row + 8);
        __m512d s;

        s = _mm512_set1_pd(a_tile[jj]);
        lo0 = _mm512_fmadd_pd(s, b_lo, lo0);  hi0 = _mm512_fmadd_pd(s, b_hi, hi0);
        s = _mm512_set1_pd(a_tile[1 * n2 + jj]);
        lo1 = _mm512_fmadd_pd(s, b_lo, lo1);  hi1 = _mm512_fmadd_pd(s, b_hi, hi1);
        s = _mm512_set1_pd(a_tile[2 * n2 + jj]);
        lo2 = _mm512_fmadd_pd(s, b_lo, lo2);  hi2 = _mm512_fmadd_pd(s, b_hi, hi2);
        s = _mm512_set1_pd(a_tile[3 * n2 + jj]);
        lo3 = _mm512_fmadd_pd(s, b_lo, lo3);  hi3 = _mm512_fmadd_pd(s, b_hi, hi3);
        s = _mm512_set1_pd(a_tile[4 * n2 + jj]);
        lo4 = _mm512_fmadd_pd(s, b_lo, lo4);  hi4 = _mm512_fmadd_pd(s, b_hi, hi4);
        s = _mm512_set1_pd(a_tile[5 * n2 + jj]);
        lo5 = _mm512_fmadd_pd(s, b_lo, lo5);  hi5 = _mm512_fmadd_pd(s, b_hi, hi5);
        s = _mm512_set1_pd(a_tile[6 * n2 + jj]);
        lo6 = _mm512_fmadd_pd(s, b_lo, lo6);  hi6 = _mm512_fmadd_pd(s, b_hi, hi6);
        s = _mm512_set1_pd(a_tile[7 * n2 + jj]);
        lo7 = _mm512_fmadd_pd(s, b_lo, lo7);  hi7 = _mm512_fmadd_pd(s, b_hi, hi7);
    }

    _mm512_storeu_pd(c_tile, lo0);              _mm512_storeu_pd(c_tile + 8, hi0);
    _mm512_storeu_pd(c_tile + 1 * n3, lo1);     _mm512_storeu_pd(c_tile + 1 * n3 + 8, hi1);
    _mm512_storeu_pd(c_tile + 2 * n3, lo2);     _mm512_storeu_pd(c_tile + 2 * n3 + 8, hi2);
    _mm512_storeu_pd(c_tile + 3 * n3, lo3);     _mm512_storeu_pd(c_tile + 3 * n3 + 8, hi3);
    _mm512_storeu_pd(c_tile + 4 * n3, lo4);     _mm512_storeu_pd(c_tile + 4 * n3 + 8, hi4);
    _mm512_storeu_pd(c_tile + 5 * n3, lo5);     _mm512_storeu_pd(c_tile + 5 * n3 + 8, hi5);
    _mm512_storeu_pd(c_tile + 6 * n3, lo6);     _mm512_storeu_pd(c_tile + 6 * n3 + 8, hi6);
    _mm512_storeu_pd(c_tile + 7 * n3, lo7);     _mm512_storeu_pd(c_tile + 7 * n3 + 8, hi7);
}

// Scalar fallback for the parts of a block that do not fill a whole 8x16 tile:
// c[row][col] += sum over jj in [j_begin, j_end) of a[row][jj] * b[jj][col].
// std::fma is used so each step is a fused multiply-add, like the vector kernel.
static inline void mul_edge_scalar(const double* a, const double* b, double* c,
                                   uint64_t n2, uint64_t n3,
                                   uint64_t row_begin, uint64_t row_end,
                                   uint64_t col_begin, uint64_t col_end,
                                   uint64_t j_begin, uint64_t j_end) {
    for (uint64_t row = row_begin; row < row_end; ++row) {
        for (uint64_t col = col_begin; col < col_end; ++col) {
            double acc = c[row * n3 + col];
            for (uint64_t jj = j_begin; jj < j_end; ++jj) {
                acc = std::fma(a[row * n2 + jj], b[jj * n3 + col], acc);
            }
            c[row * n3 + col] = acc;
        }
    }
}

/// Accumulates the matrix product into c:  c += a * b.
///
/// All matrices are dense, row-major arrays of doubles:
///   a is n1 x n2, b is n2 x n3, c is n1 x n3 (c must hold the initial values
///   to accumulate onto, e.g. zeros).
///
/// The iteration space is cut into BLOCKSIZE-sized cache blocks; inside a block
/// full 8x16 tiles go through the AVX-512 micro-kernel and the leftover rows and
/// columns go through a scalar path, so the dimensions may be arbitrary
/// (including zero) and the buffers need no particular alignment.
/// Row blocks are distributed over OpenMP threads; each thread writes a disjoint
/// set of rows of c.
void mul(double* a, double* b, double* c, uint64_t n1, uint64_t n2, uint64_t n3) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < n1; i += BLOCKSIZE) {
        const uint64_t i_end = block_end(i, n1);
        // Last row (exclusive) that still starts a complete 8-row tile.
        const uint64_t i_tiled_end = i + (i_end - i) / kTileRows * kTileRows;

        for (uint64_t j = 0; j < n2; j += BLOCKSIZE) {
            const uint64_t j_end = block_end(j, n2);

            for (uint64_t k = 0; k < n3; k += BLOCKSIZE) {
                const uint64_t k_end = block_end(k, n3);
                // Last column (exclusive) that still starts a complete 16-column tile.
                const uint64_t k_tiled_end = k + (k_end - k) / kTileCols * kTileCols;

                for (uint64_t ii = i; ii < i_tiled_end; ii += kTileRows) {
                    for (uint64_t kk = k; kk < k_tiled_end; kk += kTileCols) {
                        mul_tile_8x16(a, b, c, n2, n3, ii, kk, j, j_end);
                    }
                }

                // Right edge: tiled rows, columns past the last full tile.
                mul_edge_scalar(a, b, c, n2, n3, i, i_tiled_end, k_tiled_end, k_end, j, j_end);
                // Bottom edge: rows past the last full tile, all columns of the block.
                mul_edge_scalar(a, b, c, n2, n3, i_tiled_end, i_end, k, k_end, j, j_end);
            }
        }
    }
}

namespace {

// Closes a FILE* when its owning unique_ptr goes out of scope.
struct FileCloser {
    void operator()(std::FILE* f) const { std::fclose(f); }
};

// Releases memory obtained from _mm_malloc.
struct AlignedFree {
    void operator()(double* p) const { _mm_free(p); }
};

using FilePtr = std::unique_ptr<std::FILE, FileCloser>;
using MatrixPtr = std::unique_ptr<double, AlignedFree>;

// Computes rows * cols * sizeof(double) into `bytes`; returns false if the
// product does not fit in size_t.
bool matrix_bytes(uint64_t rows, uint64_t cols, size_t& bytes) {
    if (rows != 0 && cols > SIZE_MAX / sizeof(double) / rows) {
        return false;
    }
    bytes = static_cast<size_t>(rows * cols) * sizeof(double);
    return true;
}

// Reads exactly `bytes` bytes into dst; returns false on a short read.
bool read_exact(void* dst, size_t bytes, std::FILE* f) {
    return bytes == 0 || std::fread(dst, 1, bytes, f) == bytes;
}

}  // namespace

/// Reads n1, n2, n3 (three 8-byte integers) followed by matrix a (n1 x n2) and
/// matrix b (n2 x n3) as raw doubles from "conf.data", computes c = a * b,
/// prints the multiplication time in milliseconds to stdout and writes c as raw
/// doubles to "out.data". Returns 0 on success, 1 after reporting an I/O,
/// size or allocation failure on stderr.
int main() {
    FilePtr in(std::fopen("conf.data", "rb"));
    if (!in) {
        std::perror("conf.data");
        return 1;
    }

    uint64_t n1 = 0;
    uint64_t n2 = 0;
    uint64_t n3 = 0;
    if (!read_exact(&n1, sizeof n1, in.get()) ||
        !read_exact(&n2, sizeof n2, in.get()) ||
        !read_exact(&n3, sizeof n3, in.get())) {
        std::fprintf(stderr, "conf.data: truncated header\n");
        return 1;
    }

    size_t a_bytes = 0;
    size_t b_bytes = 0;
    size_t c_bytes = 0;
    if (!matrix_bytes(n1, n2, a_bytes) ||
        !matrix_bytes(n2, n3, b_bytes) ||
        !matrix_bytes(n1, n3, c_bytes)) {
        std::fprintf(stderr, "conf.data: matrix dimensions too large\n");
        return 1;
    }

    MatrixPtr a(static_cast<double*>(_mm_malloc(a_bytes, 64)));
    MatrixPtr b(static_cast<double*>(_mm_malloc(b_bytes, 64)));
    MatrixPtr c(static_cast<double*>(_mm_malloc(c_bytes, 64)));
    // A null result is only an error for a non-empty request.
    if ((a_bytes != 0 && !a) || (b_bytes != 0 && !b) || (c_bytes != 0 && !c)) {
        std::fprintf(stderr, "out of memory\n");
        return 1;
    }

    if (!read_exact(a.get(), a_bytes, in.get()) ||
        !read_exact(b.get(), b_bytes, in.get())) {
        std::fprintf(stderr, "conf.data: truncated matrix data\n");
        return 1;
    }
    in.reset();  // input is fully consumed; close it before the timed section

    // mul accumulates into c, so it has to start from zero.
    const size_t c_count = c_bytes / sizeof(double);
    for (size_t idx = 0; idx < c_count; ++idx) {
        c.get()[idx] = 0.0;
    }

    const auto t_start = std::chrono::steady_clock::now();
    mul(a.get(), b.get(), c.get(), n1, n2, n3);
    const auto t_end = std::chrono::steady_clock::now();
    const auto elapsed_ms =
        std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
    std::printf("%lld\n", static_cast<long long>(elapsed_ms));

    FilePtr out(std::fopen("out.data", "wb"));
    if (!out) {
        std::perror("out.data");
        return 1;
    }
    const bool wrote = c_bytes == 0 || std::fwrite(c.get(), 1, c_bytes, out.get()) == c_bytes;
    // Close explicitly so that a failed flush is reported instead of ignored.
    const bool closed = std::fclose(out.release()) == 0;
    if (!wrote || !closed) {
        std::fprintf(stderr, "out.data: write failed\n");
        return 1;
    }

    return 0;
}
|