Skip to content

Commit 0b2ee51

Browse files
fix: Add vscale helpers
Because `vscale` by design does not consider factors like accumulation, it is straightforward to use shared helpers and avoid magic numbers. This change introduces the constant `KAI_VSCALE_MAX`, which is the maximum allowed `vscale` factor, and `kai_get_sme_vscale()`, which returns the SME `vscale` factor for the current CPU. Signed-off-by: Emil Ohlsson <emil.ohlsson@arm.com> Reviewed-by: James Gross <james.gross@arm.com> Reviewed-by: Felix Johnny Thomasmathibalan <felixjohnny.thomasmathibalan@arm.com> Approved-by: Felix Johnny Thomasmathibalan <felixjohnny.thomasmathibalan@arm.com>
1 parent 0008666 commit 0b2ee51

4 files changed

Lines changed: 23 additions & 9 deletions

kai/kai_common.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,17 @@ extern "C" {
118118
#define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b))
119119
#define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b))
120120

121-
/// Largest supported SME vector length in bytes
122-
#define KAI_SME_VEC_LENGTH_MAX_BYTES 256 // NOLINT(cppcoreguidelines-macro-to-enum,modernize-macro-to-enum)
121+
/// KleidiAI shared constants
122+
enum {
123+
/// Largest supported SME vector length in bytes
124+
KAI_SME_VEC_LENGTH_MAX_BYTES = 256,
125+
126+
/// Size of one vscale unit, in bytes
127+
KAI_VSCALE_UNIT_BYTES = 16,
128+
129+
/// Maximum possible VSCALE
130+
KAI_VSCALE_MAX = KAI_SME_VEC_LENGTH_MAX_BYTES / KAI_VSCALE_UNIT_BYTES,
131+
};
123132

124133
/// Gets the version of the project in the Major.Minor.Patch semantic versioning format.
125134
///
@@ -221,6 +230,11 @@ inline static uint64_t kai_get_sme_vector_length_u32(void) {
221230
return kai_get_sme_vector_length_u8() / 4;
222231
}
223232

233+
/// Gets the vscale scale factor for SME
234+
inline static uint64_t kai_get_sme_vscale(void) {
235+
return kai_get_sme_vector_length_u8() / KAI_VSCALE_UNIT_BYTES;
236+
}
237+
224238
/// Commit ZA to lazy save buffer
225239
void kai_commit_za(void);
226240
#endif // defined(__ARM_FEATURE_SVE2) || defined(_M_ARM64)
@@ -244,7 +258,7 @@ inline static uint64_t kai_get_sve_vector_length_u32(void) {
244258
/// @return the int8_t value with sign extended
245259
inline static int8_t kai_ext_sign_i8_i4(int8_t value) {
246260
// Make sure value holds correct int4 value
247-
KAI_ASSERT(value <= 0xF);
261+
KAI_ASSUME(value <= 0xF);
248262

249263
return (value ^ 0x8) - 8; // NOLINT(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
250264
}

kai/ukernels/matmul/pack/kai_matmul_pack_lhs_mxk_x32p4vsx1_x32_sme.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ enum {
1616
MR_VSCALE = 4,
1717
KR = 1,
1818

19-
MAX_MR = MR_VSCALE * KAI_SME_VEC_LENGTH_MAX_BYTES / 16,
19+
MAX_MR = MR_VSCALE * KAI_VSCALE_MAX,
2020
};
2121

2222
void kai_kernel_matmul_pack_lhs_mxk_x32p4vsx1_x32_sme(
2323
size_t height, size_t width, const void* in, size_t row_offset, void* out);
2424

2525
static size_t get_mr(void) {
26-
return MR_VSCALE * kai_get_sme_vector_length_u8() / 16;
26+
return MR_VSCALE * kai_get_sme_vscale();
2727
}
2828

2929
static size_t div_ceil(size_t a, size_t b) {

kai/ukernels/matmul/pack/kai_matmul_pack_rhs_kxn_x32p4vsx1bx32_x32_x32_sme.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ enum {
1515
NR_VSCALE = 4,
1616
KR = 1,
1717

18-
MAX_NR = NR_VSCALE * KAI_SME_VEC_LENGTH_MAX_BYTES / 16,
18+
MAX_NR = NR_VSCALE * KAI_VSCALE_MAX,
1919
};
2020

2121
struct uker_args_t {
@@ -31,7 +31,7 @@ struct uker_args_t {
3131
void kai_kernel_matmul_pack_rhs_kxn_x32p4vsx1bx32_x32_x32_sme(const struct uker_args_t* args);
3232

3333
static size_t get_nr(void) {
34-
return NR_VSCALE * kai_get_sme_vector_length_u8() / 16;
34+
return NR_VSCALE * kai_get_sme_vscale();
3535
}
3636

3737
static size_t div_ceil(size_t a, size_t b) {

kai/ukernels/matmul/pack/kai_matmul_pack_rhs_nxk_x32p4vsx1bx32_x32_x32_sme.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ enum {
1717
NR_VSCALE = 4,
1818
KR = 1,
1919

20-
MAX_NR = NR_VSCALE * KAI_SME_VEC_LENGTH_MAX_BYTES / 16,
20+
MAX_NR = NR_VSCALE * KAI_VSCALE_MAX,
2121
};
2222

2323
void kai_kernel_matmul_pack_rhs_nxk_x32p4vsx1bx32_x32_x32_sme(
2424
size_t height, size_t width, const void* in, size_t row_offset, void* out, const void* bias);
2525

2626
static size_t get_nr(void) {
27-
return NR_VSCALE * kai_get_sme_vector_length_u8() / 16;
27+
return NR_VSCALE * kai_get_sme_vscale();
2828
}
2929

3030
static size_t div_ceil(size_t a, size_t b) {

0 commit comments

Comments
 (0)