Skip to content

Commit 4b94957

Browse files
committed
w
1 parent e0a6655 commit 4b94957

4 files changed

Lines changed: 665 additions & 266 deletions

File tree

src/layer/riscv/gemm_bf16s_fp16s.h

Lines changed: 2 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,7 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma
351351
{
352352
if (elempack == packn)
353353
{
354-
const int q = (j + jj) / packn * packn;
355-
const int r = (j + jj) % packn;
356-
const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + q * packn + r * packn;
354+
const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * packn;
357355

358356
int kk = 0;
359357
for (; kk + (packn - 1) < max_kk; kk += packn)
@@ -382,9 +380,7 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma
382380
{
383381
if (elempack == packn)
384382
{
385-
const int q = (j + jj) / packn * packn;
386-
const int r = (j + jj) % packn;
387-
const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + q * packn + r * packn;
383+
const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * packn;
388384

389385
int kk = 0;
390386
for (; kk + (packn - 1) < max_kk; kk += packn)
@@ -475,111 +471,6 @@ static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int ma
475471
}
476472
}
477473

478-
static void transpose_unpack_output_tile_bf16_fp16(const Mat& topT, Mat& top_blob, int i, int max_ii, int j, int max_jj)
479-
{
480-
#if __riscv_vector
481-
const int packn = csrr_vlenb() / 2;
482-
const size_t vl = __riscv_vsetvl_e16m1(packn);
483-
#endif
484-
485-
const int out_elempack = top_blob.elempack;
486-
const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w;
487-
488-
const unsigned short* pp = topT;
489-
490-
int ii = 0;
491-
#if __riscv_vector
492-
for (; ii + (packn - 1) < max_ii; ii += packn)
493-
{
494-
if (out_elempack == packn)
495-
{
496-
unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * packn;
497-
498-
for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn)
499-
{
500-
// transposeNxN
501-
for (int l = 0; l < packn; l++)
502-
{
503-
__riscv_vsse16_v_u16m1(p0 + l, packn * sizeof(unsigned short), __riscv_vle16_v_u16m1(pp, vl), vl);
504-
pp += packn;
505-
}
506-
p0 += out_hstep * packn;
507-
}
508-
}
509-
if (out_elempack == 1)
510-
{
511-
unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii);
512-
513-
for (int jj = 0; jj < max_jj; jj += 1)
514-
{
515-
vuint16m1_t _r0 = __riscv_vle16_v_u16m1(pp, vl);
516-
__riscv_vse16_v_u16m1(p0, _r0, vl);
517-
pp += packn;
518-
p0 += out_hstep;
519-
}
520-
}
521-
}
522-
#endif // __riscv_vector
523-
for (; ii + 1 < max_ii; ii += 2)
524-
{
525-
#if __riscv_vector
526-
if (out_elempack == packn)
527-
{
528-
unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * packn;
529-
530-
for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn)
531-
{
532-
vuint16m1x2_t _s0 = __riscv_vlseg2e16_v_u16m1x2(pp, vl);
533-
__riscv_vse16_v_u16m1(p0, __riscv_vget_v_u16m1x2_u16m1(_s0, 0), vl);
534-
__riscv_vse16_v_u16m1(p0 + packn, __riscv_vget_v_u16m1x2_u16m1(_s0, 1), vl);
535-
pp += packn * 2;
536-
p0 += out_hstep * packn;
537-
}
538-
}
539-
#endif // __riscv_vector
540-
if (out_elempack == 1)
541-
{
542-
unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii);
543-
544-
for (int jj = 0; jj < max_jj; jj += 1)
545-
{
546-
p0[0] = pp[0];
547-
p0[1] = pp[1];
548-
pp += 2;
549-
p0 += out_hstep;
550-
}
551-
}
552-
}
553-
for (; ii < max_ii; ii += 1)
554-
{
555-
#if __riscv_vector
556-
if (out_elempack == packn)
557-
{
558-
unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * packn;
559-
560-
for (int jj = 0; jj + (packn - 1) < max_jj; jj += packn)
561-
{
562-
vuint16m1_t _r0 = __riscv_vle16_v_u16m1(pp, vl);
563-
__riscv_vse16_v_u16m1(p0, _r0, vl);
564-
pp += packn;
565-
p0 += out_hstep * packn;
566-
}
567-
}
568-
#endif // __riscv_vector
569-
if (out_elempack == 1)
570-
{
571-
unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii);
572-
573-
for (int jj = 0; jj < max_jj; jj += 1)
574-
{
575-
p0[0] = pp[0];
576-
pp += 1;
577-
p0 += out_hstep;
578-
}
579-
}
580-
}
581-
}
582-
583474
static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int& TILE_M, int& TILE_N, int& TILE_K, int nT)
584475
{
585476
// resolve optimal tile size from cache size

0 commit comments

Comments
 (0)