@@ -37,7 +37,7 @@ static void pack_A_tile_bf16(const Mat& A, Mat& AT, int i, int max_ii, int k, in
3737 {
3838 int kk = 0;
3939#if __AVX512BF16__
40- __m512i _idx = _mm512_set_epi16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21 , 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0 );
40+ __m512i _idx = _mm512_setr_epi32(0 | (16 << 16), 1 | (17 << 16), 2 | (18 << 16), 3 | (19 << 16), 4 | (20 << 16) , 5 | (21 << 16), 6 | (22 << 16), 7 | (23 << 16), 8 | (24 << 16), 9 | (25 << 16), 10 | (26 << 16), 11 | (27 << 16), 12 | (28 << 16), 13 | (29 << 16), 14 | (30 << 16), 15 | (31 << 16) );
4141 for (; kk + 1 < max_kk; kk += 2)
4242 {
4343 __m512i _p = _mm512_loadu_si512((const __m512i*)p0);
@@ -60,7 +60,7 @@ static void pack_A_tile_bf16(const Mat& A, Mat& AT, int i, int max_ii, int k, in
6060
6161 int kk = 0;
6262#if __AVX512BF16__
63- __m512i _idx = _mm512_set_epi16(31, 23, 30, 22, 29, 21, 28, 20, 27, 19, 26, 18, 25, 17, 24, 16, 15, 7, 14, 6, 13 , 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0 );
63+ __m512i _idx = _mm512_setr_epi32(0 | (8 << 16), 1 | (9 << 16), 2 | (10 << 16), 3 | (11 << 16), 4 | (12 << 16) , 5 | (13 << 16), 6 | (14 << 16), 7 | (15 << 16), 16 | (24 << 16), 17 | (25 << 16), 18 | (26 << 16), 19 | (27 << 16), 20 | (28 << 16), 21 | (29 << 16), 22 | (30 << 16), 23 | (31 << 16) );
6464 for (; kk + 1 < max_kk; kk += 2)
6565 {
6666 __m256i _a = _mm256_loadu_si256((const __m256i*)p0);
@@ -597,8 +597,8 @@ static void transpose_pack_A_tile_bf16(const Mat& A, Mat& AT, int i, int max_ii,
597597 __m512i _p0 = _mm512_permutex2var_epi32(_r0, idx_lo, _r1);
598598 __m512i _p1 = _mm512_permutex2var_epi32(_r0, idx_hi, _r1);
599599#else // __AVX512BF16__
600- __m512i idx_lo = _mm512_set_epi16(61, 57, 53, 49, 45, 41, 37, 33, 29, 25, 21, 17, 13, 9, 5, 1, 60, 56, 52, 48, 44 , 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0 );
601- __m512i idx_hi = _mm512_set_epi16(63, 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 62, 58, 54, 50, 46 , 42, 38, 34, 30, 26, 22, 18, 14, 10, 6, 2 );
600+ __m512i idx_lo = _mm512_setr_epi32(0 | (4 << 16), 8 | (12 << 16), 16 | (20 << 16), 24 | (28 << 16), 32 | (36 << 16) , 40 | (44 << 16), 48 | (52 << 16), 56 | (60 << 16), 1 | (5 << 16), 9 | (13 << 16), 17 | (21 << 16), 25 | (29 << 16), 33 | (37 << 16), 41 | (45 << 16), 49 | (53 << 16), 57 | (61 << 16) );
601+ __m512i idx_hi = _mm512_setr_epi32(2 | (6 << 16), 10 | (14 << 16), 18 | (22 << 16), 26 | (30 << 16), 34 | (38 << 16) , 42 | (46 << 16), 50 | (54 << 16), 58 | (62 << 16), 3 | (7 << 16), 11 | (15 << 16), 19 | (23 << 16), 27 | (31 << 16), 35 | (39 << 16), 43 | (47 << 16), 51 | (55 << 16), 59 | (63 << 16) );
602602 __m512i _p0 = _mm512_permutex2var_epi16(_r0, idx_lo, _r1);
603603 __m512i _p1 = _mm512_permutex2var_epi16(_r0, idx_hi, _r1);
604604#endif // __AVX512BF16__
@@ -1050,7 +1050,7 @@ static void pack_B_tile_bf16(const Mat& B, Mat& BT, int j, int max_jj, int k, in
10501050 {
10511051 int kk = 0;
10521052#if __AVX512BF16__
1053- __m512i _idx = _mm512_set_epi16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21 , 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0 );
1053+ __m512i _idx = _mm512_setr_epi32(0 | (16 << 16), 1 | (17 << 16), 2 | (18 << 16), 3 | (19 << 16), 4 | (20 << 16) , 5 | (21 << 16), 6 | (22 << 16), 7 | (23 << 16), 8 | (24 << 16), 9 | (25 << 16), 10 | (26 << 16), 11 | (27 << 16), 12 | (28 << 16), 13 | (29 << 16), 14 | (30 << 16), 15 | (31 << 16) );
10541054 for (; kk + 1 < max_kk; kk += 2)
10551055 {
10561056 __m512i _p = _mm512_loadu_si512((const __m512i*)p0);
@@ -1073,7 +1073,7 @@ static void pack_B_tile_bf16(const Mat& B, Mat& BT, int j, int max_jj, int k, in
10731073
10741074 int kk = 0;
10751075#if __AVX512BF16__
1076- __m512i _idx = _mm512_set_epi16(31, 23, 30, 22, 29, 21, 28, 20, 27, 19, 26, 18, 25, 17, 24, 16, 15, 7, 14, 6, 13 , 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0 );
1076+ __m512i _idx = _mm512_setr_epi32(0 | (8 << 16), 1 | (9 << 16), 2 | (10 << 16), 3 | (11 << 16), 4 | (12 << 16) , 5 | (13 << 16), 6 | (14 << 16), 7 | (15 << 16), 16 | (24 << 16), 17 | (25 << 16), 18 | (26 << 16), 19 | (27 << 16), 20 | (28 << 16), 21 | (29 << 16), 22 | (30 << 16), 23 | (31 << 16) );
10771077 for (; kk + 1 < max_kk; kk += 2)
10781078 {
10791079 __m256i _a = _mm256_loadu_si256((const __m256i*)p0);
@@ -1701,8 +1701,8 @@ static void transpose_pack_B_tile_bf16(const Mat& B, Mat& BT, int j, int max_jj,
17011701 __m512i _p0 = _mm512_permutex2var_epi32(_r0, idx_lo, _r1);
17021702 __m512i _p1 = _mm512_permutex2var_epi32(_r0, idx_hi, _r1);
17031703#else // __AVX512BF16__
1704- __m512i idx_lo = _mm512_set_epi16(61, 57, 53, 49, 45, 41, 37, 33, 29, 25, 21, 17, 13, 9, 5, 1, 60, 56, 52, 48, 44 , 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0 );
1705- __m512i idx_hi = _mm512_set_epi16(63, 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 62, 58, 54, 50, 46 , 42, 38, 34, 30, 26, 22, 18, 14, 10, 6, 2 );
1704+ __m512i idx_lo = _mm512_setr_epi32(0 | (4 << 16), 8 | (12 << 16), 16 | (20 << 16), 24 | (28 << 16), 32 | (36 << 16) , 40 | (44 << 16), 48 | (52 << 16), 56 | (60 << 16), 1 | (5 << 16), 9 | (13 << 16), 17 | (21 << 16), 25 | (29 << 16), 33 | (37 << 16), 41 | (45 << 16), 49 | (53 << 16), 57 | (61 << 16) );
1705+ __m512i idx_hi = _mm512_setr_epi32(2 | (6 << 16), 10 | (14 << 16), 18 | (22 << 16), 26 | (30 << 16), 34 | (38 << 16) , 42 | (46 << 16), 50 | (54 << 16), 58 | (62 << 16), 3 | (7 << 16), 11 | (15 << 16), 19 | (23 << 16), 27 | (31 << 16), 35 | (39 << 16), 43 | (47 << 16), 51 | (55 << 16), 59 | (63 << 16) );
17061706 __m512i _p0 = _mm512_permutex2var_epi16(_r0, idx_lo, _r1);
17071707 __m512i _p1 = _mm512_permutex2var_epi16(_r0, idx_hi, _r1);
17081708#endif // __AVX512BF16__
@@ -8445,8 +8445,8 @@ static void unpack_output_tile_fp32_to_bf16(const Mat& topT, const Mat& C, Mat&
84458445 }
84468446 if (out_elempack == 1)
84478447 {
8448- __m512i _idx_r0r1 = _mm512_set_epi16(61, 45, 29, 13, 57, 41, 25, 9, 53, 37, 21, 5, 49, 33, 17, 1, 60, 44, 28, 12, 56 , 40, 24, 8, 52, 36, 20, 4, 48, 32, 16, 0 );
8449- __m512i _idx_r2r3 = _mm512_set_epi16(63, 47, 31, 15, 59, 43, 27, 11, 55, 39, 23, 7, 51, 35, 19, 3, 62, 46, 30, 14, 58 , 42, 26, 10, 54, 38, 22, 6, 50, 34, 18, 2 );
8448+ __m512i _idx_r0r1 = _mm512_setr_epi32(0 | (16 << 16), 32 | (48 << 16), 4 | (20 << 16), 36 | (52 << 16), 8 | (24 << 16) , 40 | (56 << 16), 12 | (28 << 16), 44 | (60 << 16), 1 | (17 << 16), 33 | (49 << 16), 5 | (21 << 16), 37 | (53 << 16), 9 | (25 << 16), 41 | (57 << 16), 13 | (29 << 16), 45 | (61 << 16) );
8449+ __m512i _idx_r2r3 = _mm512_setr_epi32(2 | (18 << 16), 34 | (50 << 16), 6 | (22 << 16), 38 | (54 << 16), 10 | (26 << 16) , 42 | (58 << 16), 14 | (30 << 16), 46 | (62 << 16), 3 | (19 << 16), 35 | (51 << 16), 7 | (23 << 16), 39 | (55 << 16), 11 | (27 << 16), 43 | (59 << 16), 15 | (31 << 16), 47 | (63 << 16) );
84508450
84518451 __m512i _bf01 = combine8x2_epi32(_bf0, _bf1);
84528452 __m512i _bf23 = combine8x2_epi32(_bf2, _bf3);
0 commit comments