Skip to content

Commit 457a316

Browse files
committed
int8 vec support
1 parent 83f94ce commit 457a316

1 file changed

Lines changed: 17 additions & 7 deletions

File tree

backends/cuda/runtime/shims/int8_plain_mm.cuh

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@ __host__ __forceinline__ int32_t log2_pow2_i8(int32_t v) {
5858
// blocks, NATURAL order — qs[k] holds the quantized value for element k).
5959
// ---------------------------------------------------------------------------
6060

61-
struct Q8BlockNat {
61+
// alignas(16) pads sizeof(Q8BlockNat) 36->48 so each block (and its two 16-byte
62+
// qs halves) is 16-byte aligned. This lets the matvec load 16 int8 activations
63+
// with one vectorized uint4 load instead of four scalar int32 loads, cutting
64+
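// activation load instructions ~4x. The uint4 load is only naturally aligned if the base of the Q8BlockNat array is itself 16-byte aligned.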
65+
struct alignas(16) Q8BlockNat {
6266
int8_t qs[Q8_NAT_BLOCK_SIZE];
6367
float d; // scale
6468
};
@@ -135,6 +139,17 @@ __global__ void __launch_bounds__(MV8_THREADS) int8_w8a8_matvec_kernel(
135139
int32_t k_base = i * 16;
136140
uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w};
137141

142+
// One uint4 (16 int8 weights) maps to exactly one 16-byte half of a Q8
143+
// activation block (16 activations): block i>>1, byte offset 0 (i even) or
144+
// 16 (i odd). Load those 16 int8 activations with a single vectorized uint4
145+
// load (+ one scale load) instead of four scalar int32 loads + four scale
146+
// reloads. av.{x,y,z,w} == qs[off+0:4],[4:8],[8:12],[12:16] == a_word for
147+
// w=0..3 -> bit-identical to the scalar path, provided Q8_NAT_BLOCK_SIZE == 32 (the i>>1 / offset-16 mapping hard-codes 32-element blocks).
148+
const Q8BlockNat* qb = &q8_row[i >> 1];
149+
uint4 av = *reinterpret_cast<const uint4*>(qb->qs + ((i & 1) ? 16 : 0));
150+
float a_scale = qb->d;
151+
const uint32_t a_words[4] = {av.x, av.y, av.z, av.w};
152+
138153
#pragma unroll
139154
for (int32_t w = 0; w < 4; w++) {
140155
int32_t k_word = k_base + w * 4; // 4 int8 weights start here
@@ -147,15 +162,10 @@ __global__ void __launch_bounds__(MV8_THREADS) int8_w8a8_matvec_kernel(
147162
}
148163

149164
int32_t w_word = static_cast<int32_t>(words[w]);
150-
151-
int32_t q8_block_idx = k_word / Q8_NAT_BLOCK_SIZE;
152-
int32_t q8_offset = k_word % Q8_NAT_BLOCK_SIZE;
153-
const Q8BlockNat* qb = &q8_row[q8_block_idx];
154-
int32_t a_word = *reinterpret_cast<const int32_t*>(qb->qs + q8_offset);
165+
int32_t a_word = static_cast<int32_t>(a_words[w]);
155166

156167
int32_t dp = __dp4a(w_word, a_word, 0);
157168
int32_t a_sum = __dp4a(0x01010101, a_word, 0);
158-
float a_scale = qb->d;
159169

160170
sum += ws * a_scale *
161171
(static_cast<float>(dp) - wz * static_cast<float>(a_sum));

0 commit comments

Comments
 (0)