Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions CUDACore/src/compiler/compilation.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
## gpucompiler interface implementation

Base.@kwdef struct CUDACompilerParams <: AbstractCompilerParams
abstract type AbstractCUDACompilerParams <: AbstractCompilerParams end

Base.@kwdef struct CUDACompilerParams <: AbstractCUDACompilerParams
cap::VersionNumber
ptx::VersionNumber
end
Expand All @@ -13,19 +15,19 @@ function Base.hash(params::CUDACompilerParams, h::UInt)
end

const CUDACompilerConfig = CompilerConfig{PTXCompilerTarget, CUDACompilerParams}
const CUDACompilerJob = CompilerJob{PTXCompilerTarget,CUDACompilerParams}
const AnyCUDAJob = CompilerJob{PTXCompilerTarget,<:AbstractCUDACompilerParams}

GPUCompiler.runtime_module(@nospecialize(job::CUDACompilerJob)) = CUDACore
GPUCompiler.runtime_module(@nospecialize(job::AnyCUDAJob)) = CUDACore

# filter out functions from libdevice and cudadevrt
GPUCompiler.isintrinsic(@nospecialize(job::CUDACompilerJob), fn::String) =
GPUCompiler.isintrinsic(@nospecialize(job::AnyCUDAJob), fn::String) =
invoke(GPUCompiler.isintrinsic,
Tuple{CompilerJob{PTXCompilerTarget}, typeof(fn)},
job, fn) ||
fn == "__nvvm_reflect" || startswith(fn, "cuda")

# link libdevice
function GPUCompiler.link_libraries!(@nospecialize(job::CUDACompilerJob), mod::LLVM.Module,
function GPUCompiler.link_libraries!(@nospecialize(job::AnyCUDAJob), mod::LLVM.Module,
undefined_fns::Vector{String})
# only link if there's undefined __nv_ functions
if !any(fn->startswith(fn, "__nv_"), undefined_fns)
Expand All @@ -49,11 +51,11 @@ function GPUCompiler.link_libraries!(@nospecialize(job::CUDACompilerJob), mod::L
return
end

GPUCompiler.method_table(@nospecialize(job::CUDACompilerJob)) = method_table
GPUCompiler.method_table(@nospecialize(job::AnyCUDAJob)) = method_table

GPUCompiler.kernel_state_type(job::CUDACompilerJob) = KernelState
GPUCompiler.kernel_state_type(job::AnyCUDAJob) = KernelState

function GPUCompiler.finish_module!(@nospecialize(job::CUDACompilerJob),
function GPUCompiler.finish_module!(@nospecialize(job::AnyCUDAJob),
mod::LLVM.Module, entry::LLVM.Function)
entry = invoke(GPUCompiler.finish_module!,
Tuple{CompilerJob{PTXCompilerTarget}, LLVM.Module, LLVM.Function},
Expand Down Expand Up @@ -130,7 +132,7 @@ function rewrite_ptx_header(asm, ptx, cap)
r"\.target sm_\d+\w*" => ".target sm_$(cap.major)$(cap.minor)")
end

function GPUCompiler.mcgen(@nospecialize(job::CUDACompilerJob), mod::LLVM.Module, format)
function GPUCompiler.mcgen(@nospecialize(job::AnyCUDAJob), mod::LLVM.Module, format)
@assert format == LLVM.API.LLVMAssemblyFile
asm = invoke(GPUCompiler.mcgen,
Tuple{CompilerJob{PTXCompilerTarget}, LLVM.Module, typeof(format)},
Expand All @@ -150,9 +152,10 @@ function GPUCompiler.mcgen(@nospecialize(job::CUDACompilerJob), mod::LLVM.Module
asm = replace(asm, r"(\.target .+), debug" => s"\1")
end

(; ptx, cap) = job.config.params
if job.config.target.ptx != ptx || job.config.target.cap != cap
asm = rewrite_ptx_header(asm, ptx, cap)
ptx_param = job.config.params.ptx
cap_param = job.config.params.cap
if job.config.target.ptx != ptx_param || job.config.target.cap != cap_param
asm = rewrite_ptx_header(asm, ptx_param, cap_param)
end

return asm
Expand Down