Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions CUDACore/src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,26 @@ function device_layout(@nospecialize(T))
size == sizeof(T) || return :mismatch
return (size, align)
end
# walk `T` and every type reachable from it through type parameters, union members and
# fields, returning `true` as soon as `bad(S)` holds for some reached type `S`. we must look
# through type parameters, not just fields: an aggregate with a mismatching layout is
# typically reached through a pointer (e.g. the element type of a `CuDeviceArray`, carried
# as a type parameter and never as a field), so inspecting only the argument's own fields
# would miss it and the kernel would silently read or write garbage.
#
# - `bad`:  predicate applied to every reached type.
# - `T`:    root of the walk; anything that is not a `Type` (e.g. an integer used as a type
#           parameter) is ignored.
# - `seen`: identity set of the types visited so far, shared by the whole walk so that each
#           type is inspected once and self-referential types terminate.
function layout_reaches(bad, @nospecialize(T), seen=Base.IdSet{Any}())
    (T isa Type && !(T in seen)) || return false
    push!(seen, T)
    bad(T) && return true
    if T isa Union
        # a `Union` is not a `DataType`, so it has neither parameters nor fields to walk,
        # but any of its members may be what ends up in memory: inspect each of them.
        # (`T.b` is itself a `Union` when there are more than two members.)
        return layout_reaches(bad, T.a, seen) || layout_reaches(bad, T.b, seen)
    end
    T isa DataType || return false
    any(p -> layout_reaches(bad, p, seen), T.parameters) && return true
    # only concrete types have a well-defined set of fields to inspect
    isconcretetype(T) || return false
    return any(i -> layout_reaches(bad, fieldtype(T, i), seen), 1:fieldcount(T))
end

# whether values of type `T` can be passed to the device as-is: `true` unless `T`, or a type
# reachable from it according to `layout_reaches`, has a `device_layout` of `:mismatch`.
device_compatible_layout(@nospecialize(T)) =
    # since Julia 1.12, host and device layouts are identical
    Base.datatype_alignment(Int128) == 16 ||
    !layout_reaches(S -> device_layout(S) === :mismatch, T)

# compile to executable machine code
function compile(@nospecialize(job::CompilerJob))
Expand Down Expand Up @@ -356,8 +373,7 @@ function compile(@nospecialize(job::CompilerJob))
end
for dt in argtypes
if !device_compatible_layout(dt)
error("""Kernel argument of type $dt contains Int128 fields whose layout differs between this version of Julia and the device.
Use Julia 1.12 or later, where 128-bit integers are aligned to 16 bytes, matching the device.""")
error("Kernel argument of type $dt references 128-bit integer fields. This is only supported on Julia 1.12 or later.")
end
end
param_usage = sum(aligned_sizeof, argtypes)
Expand Down
86 changes: 70 additions & 16 deletions test/core/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -742,24 +742,78 @@ end
end

@testset "argument layout" begin
    # the back-end aligns 128-bit integers to 16 bytes, but Julia only started doing so in
    # 1.12, so aggregates with (U)Int128 fields lay out differently on older hosts. such
    # types are rejected there (host==device on 1.12+, so everything is accepted).
    @eval struct Int128Wrapper; x::Int64; y::Int128; end
    @eval struct FloatWrapper; x::Int64; y::Float64; end  # control: no 128-bit integers
    host_ok = Base.datatype_alignment(Int128) == 16       # true on Julia 1.12+

    # -- the compatibility walk must look through type parameters (e.g. device-array
    #    element types), not just fields. this part is host-independent. --
    reaches_i128(T) = CUDACore.layout_reaches(S -> S === Int128 || S === UInt128, T)
    @test reaches_i128(Int128)
    @test reaches_i128(Int128Wrapper)                              # via a field
    @test reaches_i128(Tuple{Int64,Int128Wrapper})                 # via a tuple element
    @test reaches_i128(Ptr{Int128Wrapper})                         # via a pointer's pointee
    @test reaches_i128(CUDACore.CuDeviceArray{Int128Wrapper,1,1})  # via an element type
    @test !reaches_i128(Float64)
    @test !reaches_i128(FloatWrapper)
    @test !reaches_i128(CUDACore.CuDeviceArray{Float64,1,1})

    @test CUDACore.device_layout(Int128) == (16, 16)
    @test CUDACore.device_layout(FloatWrapper) == (16, 8)
    @test CUDACore.device_layout(Int128Wrapper) === (host_ok ? (32, 16) : :mismatch)
    @test CUDACore.device_compatible_layout(Int128Wrapper) == host_ok
    @test CUDACore.device_compatible_layout(CUDACore.CuDeviceArray{Int128Wrapper,1,1}) == host_ok
    @test CUDACore.device_compatible_layout(CUDACore.CuDeviceArray{Float64,1,1})

    # -- end-to-end: rejected on <1.12, compiled and correct on 1.12+ --

    # plain 128-bit integers occupy their own parameter slot / array element, and are fine
    # regardless of how the host aligns them
    setval(out, v) = (@inbounds out[1] = v; return)
    let out = CuArray{Int128}(undef, 1)
        @cuda setval(out, Int128(2)^100 + 7)
        @test Array(out)[1] == Int128(2)^100 + 7
    end

    # aggregate with a 128-bit field, passed as a kernel argument: read its field back so
    # the layout check (not an unrelated type error) is what fails on incompatible hosts
    gety(out, w) = (@inbounds out[1] = w.y; return)
    out = CuArray{Int128}(undef, 1)
    if host_ok
        @cuda gety(out, Int128Wrapper(42, Int128(2)^100 + 7))
        @test Array(out)[1] == Int128(2)^100 + 7
    else
        @test_throws "references 128-bit integer fields" @cuda gety(out, Int128Wrapper(42, Int128(2)^100 + 7))
    end

    # aggregate with a 128-bit field, reached as a device-array element (memory traffic;
    # this is only seen through a pointer, so it used to slip past the argument check)
    function readfields(out, A)
        i = threadIdx().x
        @inbounds begin
            out[2i-1] = A[i].x
            out[2i] = A[i].y % Int64
        end
        return
    end
    wrappers = [Int128Wrapper(1, 2), Int128Wrapper(3, 4)]
    # the inputs are the same on every host; only the expected outcome differs
    dA = CuArray(wrappers); out = CuArray{Int64}(undef, 4)
    if host_ok
        @cuda threads=2 readfields(out, dA)
        @test Array(out) == [1, 2, 3, 4]
    else
        @test_throws "references 128-bit integer fields" @cuda threads=2 readfields(out, dA)
    end

    # control: an aggregate without 128-bit integers is always accepted
    let out = CuArray{Float64}(undef, 1)
        @cuda gety(out, FloatWrapper(1, 3.5))
        @test Array(out)[1] == 3.5
    end
end

Expand Down