Skip to content

Commit 99c2eed

Browse files
committed
feat: add dp4a device intrinsic
Add CUDACore.dp4a with the four signedness variants of the PTX dp4a instruction (packed 4-element int8/uint8 dot product with 32-bit accumulate), available on sm_61 and later. On LLVM 21 and later the implementation uses the @llvm.nvvm.idp4a.[us].[us] intrinsics added in LLVM 21; on older versions it falls back to inline PTX via @asmcall. Both paths verified on sm_75: identical dp4a instruction selection and bit-identical results against a byte-wise reference, on Julia 1.11 (LLVM 16, asm path) and nightly (LLVM 21, intrinsic path).
1 parent aa47d7a commit 99c2eed

2 files changed

Lines changed: 183 additions & 1 deletion

File tree

  • CUDACore/src/device/intrinsics
  • test/core/device/intrinsics

CUDACore/src/device/intrinsics/math.jl

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# we only use libdevice where needed. if possible, we go through LLVM instead,
44
# ideally relying on Julia's existing definitions.
55

6-
@public fma, rsqrt, saturate, byte_perm, assume
6+
@public fma, rsqrt, saturate, byte_perm, dp4a, assume
77
@public add_rn, add_rz, add_rm, add_rp
88
@public sub_rn, sub_rz, sub_rm, sub_rp
99
@public mul_rn, mul_rz, mul_rm, mul_rp
@@ -286,6 +286,60 @@ end
286286
ccall("extern __nv_byte_perm", llvmcall, Int32, (UInt32, UInt32, UInt32), x, y, z)
287287
end
288288

289+
"""
290+
dp4a(a, b, c)
291+
292+
Packed 4-element int8 (or uint8) dot product with 32-bit accumulation, mapped to a single
293+
PTX `dp4a` instruction on sm_61+.
294+
295+
The semantics depend on the signedness of `a` and `b`:
296+
297+
- `dp4a(a::Int32, b::Int32, c::Int32) -> Int32` — signed × signed
298+
- `dp4a(a::Int32, b::UInt32, c::Int32) -> Int32` — signed × unsigned
299+
- `dp4a(a::UInt32, b::Int32, c::Int32) -> Int32` — unsigned × signed
300+
- `dp4a(a::UInt32, b::UInt32, c::UInt32) -> UInt32` — unsigned × unsigned
301+
302+
Each 32-bit argument `a` and `b` is interpreted as four packed 8-bit integers. The result
303+
is `c + a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3]` where the individual byte
304+
extractions respect the signed/unsigned interpretation of each operand.
305+
306+
!!! note
307+
Requires compute capability sm_61 or higher.
308+
"""
309+
function dp4a end
310+
311+
@static if LLVM.version() >= v"21"
312+
# LLVM 21 added @llvm.nvvm.idp4a.[us].[us]; prefer the intrinsic over inline PTX so
313+
# the instruction participates in optimization and instruction selection.
314+
# signed × signed, lowered through the NVVM `idp4a.s.s` intrinsic.
@device_function function dp4a(a::Int32, b::Int32, c::Int32)
    return ccall("llvm.nvvm.idp4a.s.s", llvmcall, Int32,
                 (Int32, Int32, Int32), a, b, c)
end
316+
317+
# signed × unsigned, lowered through the NVVM `idp4a.s.u` intrinsic.
@device_function function dp4a(a::Int32, b::UInt32, c::Int32)
    return ccall("llvm.nvvm.idp4a.s.u", llvmcall, Int32,
                 (Int32, UInt32, Int32), a, b, c)
end
319+
320+
# unsigned × signed, lowered through the NVVM `idp4a.u.s` intrinsic.
@device_function function dp4a(a::UInt32, b::Int32, c::Int32)
    return ccall("llvm.nvvm.idp4a.u.s", llvmcall, Int32,
                 (UInt32, Int32, Int32), a, b, c)
end
322+
323+
# unsigned × unsigned, lowered through the NVVM `idp4a.u.u` intrinsic. This is the only
# variant with an unsigned accumulator and result.
@device_function function dp4a(a::UInt32, b::UInt32, c::UInt32)
    return ccall("llvm.nvvm.idp4a.u.u", llvmcall, UInt32,
                 (UInt32, UInt32, UInt32), a, b, c)
end
325+
else
326+
# signed × signed, emitted as inline PTX (`$0` is the result, `$1`..`$3` are a, b, c).
@device_function function dp4a(a::Int32, b::Int32, c::Int32)
    return @asmcall("dp4a.s32.s32 \$0, \$1, \$2, \$3;", "=r,r,r,r", false,
                    Int32, Tuple{Int32, Int32, Int32}, a, b, c)
end
329+
330+
# signed × unsigned, emitted as inline PTX (`$0` is the result, `$1`..`$3` are a, b, c).
@device_function function dp4a(a::Int32, b::UInt32, c::Int32)
    return @asmcall("dp4a.s32.u32 \$0, \$1, \$2, \$3;", "=r,r,r,r", false,
                    Int32, Tuple{Int32, UInt32, Int32}, a, b, c)
end
333+
334+
# unsigned × signed, emitted as inline PTX (`$0` is the result, `$1`..`$3` are a, b, c).
@device_function function dp4a(a::UInt32, b::Int32, c::Int32)
    return @asmcall("dp4a.u32.s32 \$0, \$1, \$2, \$3;", "=r,r,r,r", false,
                    Int32, Tuple{UInt32, Int32, Int32}, a, b, c)
end
337+
338+
# unsigned × unsigned, emitted as inline PTX (`$0` is the result, `$1`..`$3` are a, b, c).
# This is the only variant with an unsigned accumulator and result.
@device_function function dp4a(a::UInt32, b::UInt32, c::UInt32)
    return @asmcall("dp4a.u32.u32 \$0, \$1, \$2, \$3;", "=r,r,r,r", false,
                    UInt32, Tuple{UInt32, UInt32, UInt32}, a, b, c)
end
341+
end
342+
289343

290344
## floating-point handling
291345

test/core/device/intrinsics/math.jl

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,134 @@ using SpecialFunctions
339339
end
340340
end
341341

342+
# dp4a requires sm_61+
343+
if capability(device()) >= v"6.1"
344+
@testset "dp4a" begin
345+
# Pure-Julia reference: unpack four bytes from a packed Int32/UInt32,
346+
# dot-product them (with the respective signed/unsigned semantics), and
347+
# add the accumulator.
348+
# Host reference for signed × signed: split both words into four `Int8` lanes, widen each
# lane to `Int32`, and accumulate the lane products onto `c` (Int32 arithmetic).
function ref_dp4a_ss(a::Int32, b::Int32, c::Int32)
    lanes_a = reinterpret(NTuple{4,Int8}, a)
    lanes_b = reinterpret(NTuple{4,Int8}, b)
    acc = c
    for (x, y) in zip(lanes_a, lanes_b)
        acc += Int32(x) * Int32(y)
    end
    return acc
end
354+
# Host reference for signed × unsigned: `a` is split into `Int8` lanes and `b` into
# `UInt8` lanes; each lane is widened to `Int32` before the products are added to `c`.
function ref_dp4a_su(a::Int32, b::UInt32, c::Int32)
    lanes_a = reinterpret(NTuple{4,Int8}, a)
    lanes_b = reinterpret(NTuple{4,UInt8}, b)
    acc = c
    for (x, y) in zip(lanes_a, lanes_b)
        acc += Int32(x) * Int32(y)
    end
    return acc
end
360+
# Host reference for unsigned × signed: `a` is split into `UInt8` lanes and `b` into
# `Int8` lanes; each lane is widened to `Int32` before the products are added to `c`.
function ref_dp4a_us(a::UInt32, b::Int32, c::Int32)
    lanes_a = reinterpret(NTuple{4,UInt8}, a)
    lanes_b = reinterpret(NTuple{4,Int8}, b)
    acc = c
    for (x, y) in zip(lanes_a, lanes_b)
        acc += Int32(x) * Int32(y)
    end
    return acc
end
366+
# Host reference for unsigned × unsigned: both words are split into `UInt8` lanes, each
# lane is widened to `UInt32`, and the products are added to `c` (UInt32 arithmetic).
function ref_dp4a_uu(a::UInt32, b::UInt32, c::UInt32)
    lanes_a = reinterpret(NTuple{4,UInt8}, a)
    lanes_b = reinterpret(NTuple{4,UInt8}, b)
    acc = c
    for (x, y) in zip(lanes_a, lanes_b)
        acc += UInt32(x) * UInt32(y)
    end
    return acc
end
372+
373+
# Kernels: each stores a single dp4a result into `out`. Every test case gets its own
# single-thread launch, which keeps the kernel signatures simple.
375+
# Device kernel used by the signed × signed tests: evaluates `CUDA.dp4a` on the operands
# and stores the value in `out`.
function kernel_ss(out, a, b, c)
    acc = CUDA.dp4a(a, b, c)
    out[] = acc
    return nothing
end
379+
# Device kernel used by the signed × unsigned tests: evaluates `CUDA.dp4a` on the operands
# and stores the value in `out`.
function kernel_su(out, a, b, c)
    acc = CUDA.dp4a(a, b, c)
    out[] = acc
    return nothing
end
383+
# Device kernel used by the unsigned × signed tests: evaluates `CUDA.dp4a` on the operands
# and stores the value in `out`.
function kernel_us(out, a, b, c)
    acc = CUDA.dp4a(a, b, c)
    out[] = acc
    return nothing
end
387+
# Device kernel used by the unsigned × unsigned tests: evaluates `CUDA.dp4a` on the
# operands and stores the value in `out`.
function kernel_uu(out, a, b, c)
    acc = CUDA.dp4a(a, b, c)
    out[] = acc
    return nothing
end
391+
392+
# Helper: pack four Int8/UInt8 values (little-endian: b0 in bits 7:0).
393+
# Use reinterpret(Int32/UInt32, NTuple{4,Int8/UInt8}) — portable and avoids
394+
# integer-width pitfalls in the shift+or approach.
395+
# Truncate each argument to `Int8` (via `%`, so out-of-range values wrap) and reinterpret
# the four lanes as one `Int32`.
function pack_s(b0, b1, b2, b3)
    lanes = map(x -> x % Int8, (b0, b1, b2, b3))
    return reinterpret(Int32, lanes)
end
396+
# Truncate each argument to `UInt8` (via `%`, so out-of-range values wrap) and reinterpret
# the four lanes as one `UInt32`.
function pack_u(b0, b1, b2, b3)
    lanes = map(x -> x % UInt8, (b0, b1, b2, b3))
    return reinterpret(UInt32, lanes)
end
397+
398+
@testset "ss — signed × signed" begin
    # Operand triples `(a, b, c)`. Each triple is evaluated on the device and compared
    # against the byte-wise host reference.
    inputs = [
        (Int32(0), Int32(0), Int32(0)),                            # all zeros
        (pack_s(1,2,3,4), pack_s(5,6,7,8), Int32(10)),             # 1*5+2*6+3*7+4*8+10 = 80
        (pack_s(127,127,127,127), pack_s(1,1,1,1), Int32(0)),      # largest positive bytes
        (pack_s(-128,-128,-128,-128), pack_s(1,1,1,1), Int32(0)),  # most-negative bytes
        (pack_s(-1,-1,-1,-1), pack_s(-1,-1,-1,-1), Int32(0)),      # negative × negative
        (Int32(-1), Int32(-1), Int32(100)),                        # every byte is 0xFF
    ]
    for operands in inputs
        a, b, c = operands
        want = ref_dp4a_ss(a, b, c)
        out = CuArray{Int32}(undef, 1)
        @cuda threads=1 kernel_ss(out, a, b, c)
        got = Array(out)[]
        @test got == want
    end
end
415+
416+
@testset "su — signed × unsigned" begin
    # Operand triples `(a, b, c)`: `a` carries signed bytes, `b` unsigned bytes.
    inputs = [
        (Int32(0), UInt32(0), Int32(0)),                           # all zeros
        (pack_s(1,2,3,4), pack_u(5,6,7,8), Int32(10)),             # 1*5+…+10 = 80
        (pack_s(127,0,-128,1), pack_u(255,128,1,0), Int32(5)),     # mixed extremes
        (pack_s(-1,-1,-1,-1), pack_u(255,255,255,255), Int32(0)),  # -1 * 255 * 4 = -1020
    ]
    for operands in inputs
        a, b, c = operands
        want = ref_dp4a_su(a, b, c)
        out = CuArray{Int32}(undef, 1)
        @cuda threads=1 kernel_su(out, a, b, c)
        got = Array(out)[]
        @test got == want
    end
end
430+
431+
@testset "us — unsigned × signed" begin
    # Operand triples `(a, b, c)`: `a` carries unsigned bytes, `b` signed bytes.
    inputs = [
        (UInt32(0), Int32(0), Int32(0)),                           # all zeros
        (pack_u(1,2,3,4), pack_s(5,6,7,8), Int32(10)),             # small positive bytes
        (pack_u(255,128,0,1), pack_s(-1,1,-128,127), Int32(0)),    # mixed extremes
    ]
    for operands in inputs
        a, b, c = operands
        want = ref_dp4a_us(a, b, c)
        out = CuArray{Int32}(undef, 1)
        @cuda threads=1 kernel_us(out, a, b, c)
        got = Array(out)[]
        @test got == want
    end
end
444+
445+
@testset "uu — unsigned × unsigned" begin
    # Operand triples `(a, b, c)`: everything, including the accumulator, is `UInt32`.
    inputs = [
        (UInt32(0), UInt32(0), UInt32(0)),                                # all zeros
        (pack_u(1,2,3,4), pack_u(5,6,7,8), UInt32(10)),                   # 80
        (pack_u(255,255,255,255), pack_u(1,1,1,1), UInt32(0)),            # 4*255 = 1020
        (pack_u(255,255,255,255), pack_u(255,255,255,255), UInt32(0)),    # 4*255^2 = 260100
    ]
    for operands in inputs
        a, b, c = operands
        want = ref_dp4a_uu(a, b, c)
        out = CuArray{UInt32}(undef, 1)
        @cuda threads=1 kernel_uu(out, a, b, c)
        got = Array(out)[]
        @test got == want
    end
end
459+
460+
@testset "PTX instruction selection" begin
    # Verify the backend emits the actual dp4a instruction, not a software emulation
    # sequence. Match the complete `dp4a.s32.s32` mnemonic rather than the bare
    # substring "dp4a": the shorter pattern would also be satisfied by any symbol in
    # the listing whose name happens to contain "dp4a", and it would not notice a
    # wrong signedness variant being selected for the Int32/Int32 method.
    buf = CuArray{Int32}(undef, 1)
    z = Int32(0)
    ptx = sprint() do io
        @device_code_ptx io=io @cuda launch=false kernel_ss(buf, z, z, z)
    end
    @test occursin("dp4a.s32.s32", ptx)
end
467+
end
468+
end # capability >= v"6.1"
469+
342470
@testset "@fastmath sincos" begin
343471
# JuliaGPU/CUDA.jl#1606: FastMath.sincos fell back to regular sin/cos
344472
@test @filecheck CUDA.code_ptx(NTuple{3,CuDeviceArray{Float32,1,AS.Global}}) do a, b, c

0 commit comments

Comments
 (0)