@@ -351,9 +351,7 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma
351351 {
352352 if (elempack == packn)
353353 {
354- const int q = (j + jj) / packn * packn;
355- const int r = (j + jj) % packn;
356- const unsigned short * p0 = (const unsigned short *)B + k * B_hstep + q * packn + r * packn;
354+ const unsigned short * p0 = (const unsigned short *)B + k * B_hstep + (j + jj) * packn;
357355
358356 int kk = 0 ;
359357 for (; kk + (packn - 1 ) < max_kk; kk += packn)
@@ -382,9 +380,7 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma
382380 {
383381 if (elempack == packn)
384382 {
385- const int q = (j + jj) / packn * packn;
386- const int r = (j + jj) % packn;
387- const unsigned short * p0 = (const unsigned short *)B + k * B_hstep + q * packn + r * packn;
383+ const unsigned short * p0 = (const unsigned short *)B + k * B_hstep + (j + jj) * packn;
388384
389385 int kk = 0 ;
390386 for (; kk + (packn - 1 ) < max_kk; kk += packn)
@@ -475,111 +471,6 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma
475471 }
476472}
477473
478- static void transpose_unpack_output_tile_bf16_fp16 (const Mat& topT, Mat& top_blob, int i, int max_ii, int j, int max_jj)
479- {
480- #if __riscv_vector
481- const int packn = csrr_vlenb () / 2 ;
482- const size_t vl = __riscv_vsetvl_e16m1 (packn);
483- #endif
484-
485- const int out_elempack = top_blob.elempack ;
486- const int out_hstep = top_blob.dims == 3 ? (int )top_blob.cstep : top_blob.w ;
487-
488- const unsigned short * pp = topT;
489-
490- int ii = 0 ;
491- #if __riscv_vector
492- for (; ii + (packn - 1 ) < max_ii; ii += packn)
493- {
494- if (out_elempack == packn)
495- {
496- unsigned short * p0 = (unsigned short *)top_blob + j * out_hstep + (i + ii) * packn;
497-
498- for (int jj = 0 ; jj + (packn - 1 ) < max_jj; jj += packn)
499- {
500- // transposeNxN
501- for (int l = 0 ; l < packn; l++)
502- {
503- __riscv_vsse16_v_u16m1 (p0 + l, packn * sizeof (unsigned short ), __riscv_vle16_v_u16m1 (pp, vl), vl);
504- pp += packn;
505- }
506- p0 += out_hstep * packn;
507- }
508- }
509- if (out_elempack == 1 )
510- {
511- unsigned short * p0 = (unsigned short *)top_blob + j * out_hstep + (i + ii);
512-
513- for (int jj = 0 ; jj < max_jj; jj += 1 )
514- {
515- vuint16m1_t _r0 = __riscv_vle16_v_u16m1 (pp, vl);
516- __riscv_vse16_v_u16m1 (p0, _r0, vl);
517- pp += packn;
518- p0 += out_hstep;
519- }
520- }
521- }
522- #endif // __riscv_vector
523- for (; ii + 1 < max_ii; ii += 2 )
524- {
525- #if __riscv_vector
526- if (out_elempack == packn)
527- {
528- unsigned short * p0 = (unsigned short *)top_blob + j * out_hstep + (i + ii) * packn;
529-
530- for (int jj = 0 ; jj + (packn - 1 ) < max_jj; jj += packn)
531- {
532- vuint16m1x2_t _s0 = __riscv_vlseg2e16_v_u16m1x2 (pp, vl);
533- __riscv_vse16_v_u16m1 (p0, __riscv_vget_v_u16m1x2_u16m1 (_s0, 0 ), vl);
534- __riscv_vse16_v_u16m1 (p0 + packn, __riscv_vget_v_u16m1x2_u16m1 (_s0, 1 ), vl);
535- pp += packn * 2 ;
536- p0 += out_hstep * packn;
537- }
538- }
539- #endif // __riscv_vector
540- if (out_elempack == 1 )
541- {
542- unsigned short * p0 = (unsigned short *)top_blob + j * out_hstep + (i + ii);
543-
544- for (int jj = 0 ; jj < max_jj; jj += 1 )
545- {
546- p0[0 ] = pp[0 ];
547- p0[1 ] = pp[1 ];
548- pp += 2 ;
549- p0 += out_hstep;
550- }
551- }
552- }
553- for (; ii < max_ii; ii += 1 )
554- {
555- #if __riscv_vector
556- if (out_elempack == packn)
557- {
558- unsigned short * p0 = (unsigned short *)top_blob + j * out_hstep + (i + ii) * packn;
559-
560- for (int jj = 0 ; jj + (packn - 1 ) < max_jj; jj += packn)
561- {
562- vuint16m1_t _r0 = __riscv_vle16_v_u16m1 (pp, vl);
563- __riscv_vse16_v_u16m1 (p0, _r0, vl);
564- pp += packn;
565- p0 += out_hstep * packn;
566- }
567- }
568- #endif // __riscv_vector
569- if (out_elempack == 1 )
570- {
571- unsigned short * p0 = (unsigned short *)top_blob + j * out_hstep + (i + ii);
572-
573- for (int jj = 0 ; jj < max_jj; jj += 1 )
574- {
575- p0[0 ] = pp[0 ];
576- pp += 1 ;
577- p0 += out_hstep;
578- }
579- }
580- }
581- }
582-
583474static void get_optimal_tile_mnk_bf16s_fp16s (int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int & TILE_M , int & TILE_N , int & TILE_K , int nT)
584475{
585476 // resolve optimal tile size from cache size
0 commit comments