[refactor] switch to torch stable api#301
Conversation
Signed-off-by: Hank <hcc.mayday@gmail.com>
There was a problem hiding this comment.
Code Review
This pull request migrates several CUDA and HIP kernels to a stable PyTorch ABI under the _C_stable_libtorch and _moe_C_stable_libtorch extension targets, while adding new optimized kernels like dsv3_fused_a_gemm, fp32_router_gemm, and cooperative_topk. It also introduces ROCm compatibility improvements across various files, including the memory allocator. Feedback on the changes highlights three critical issues: first, the VLLM_STABLE_EXT_SRC source list is missing from the _C_stable_libtorch target definition, which will break compilation; second, a virtual address space leak exists in cumem_allocator.cpp if memory allocation fails; and third, an out-of-bounds shared memory read vulnerability was identified in preprocessTopkIdKernel due to a missing boundary check on topk_id.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| define_gpu_extension_target( | ||
| _C | ||
| _C_stable_libtorch | ||
| DESTINATION vllm_metax | ||
| LANGUAGE ${VLLM_GPU_LANG} | ||
| SOURCES ${VLLM_EXT_SRC} |
There was a problem hiding this comment.
The VLLM_STABLE_EXT_SRC list of source files is never passed to define_gpu_extension_target for the _C_stable_libtorch target. This will cause compilation to fail due to missing source files and undefined symbols. Please append ${VLLM_STABLE_EXT_SRC} to the SOURCES argument.
define_gpu_extension_target(
_C_stable_libtorch
DESTINATION vllm_metax
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_EXT_SRC} ${VLLM_STABLE_EXT_SRC}
| if (p_memHandle[i] == nullptr) { | ||
| std::cerr << "ERROR: malloc failed for p_memHandle[" << i << "].\n"; | ||
| for (auto j = 0; j < i; ++j) { | ||
| free(p_memHandle[j]); | ||
| } | ||
| free(p_memHandle); | ||
| free(chunk_sizes); | ||
| return nullptr; | ||
| } |
There was a problem hiding this comment.
If the allocation of p_memHandle[i] fails, the reserved virtual address space d_mem is leaked because cuMemAddressFree is not called before returning nullptr. Please free the reserved address space on this failure path.
if (p_memHandle[i] == nullptr) {
std::cerr << "ERROR: malloc failed for p_memHandle[" << i << "].\n";
for (auto j = 0; j < i; ++j) {
free(p_memHandle[j]);
}
free(p_memHandle);
free(chunk_sizes);
(void)cuMemAddressFree(d_mem, alignedSize);
return nullptr;
}| __global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size, | ||
| const int* expert_map_ptr, | ||
| int num_experts) { | ||
| auto tidx = threadIdx.x; | ||
| auto bidx = blockIdx.x; | ||
| auto offset = bidx * blockDim.x; | ||
| auto bound = min(offset + blockDim.x, size); | ||
| extern __shared__ int smem_expert_map[]; | ||
| // store expert_map in smem | ||
| for (int i = tidx; i < num_experts; i += blockDim.x) { | ||
| smem_expert_map[i] = expert_map_ptr[i]; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // query global expert id in expert map. | ||
| // if global expert id = -1 in exert map, plus n_expert | ||
| // else set global expert id = exert map[global expert id] | ||
| if (offset + tidx < bound) { | ||
| auto topk_id = topk_id_ptr[offset + tidx]; | ||
| auto local_expert_idx = smem_expert_map[topk_id]; | ||
| if (local_expert_idx == -1) { | ||
| topk_id += num_experts; | ||
| } else { | ||
| topk_id = local_expert_idx; | ||
| } | ||
| __syncwarp(); | ||
| topk_id_ptr[offset + tidx] = topk_id; | ||
| } | ||
| } |
There was a problem hiding this comment.
In preprocessTopkIdKernel, if topk_id is negative or greater than or equal to num_experts (which can happen for padded or invalid entries), reading smem_expert_map[topk_id] will result in an out-of-bounds shared memory read. Please add a boundary check for topk_id before accessing the shared memory array.
__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
const int* expert_map_ptr,
int num_experts) {
auto tidx = threadIdx.x;
auto bidx = blockIdx.x;
auto offset = bidx * blockDim.x;
auto bound = min(offset + blockDim.x, size);
extern __shared__ int smem_expert_map[];
// store expert_map in smem
for (int i = tidx; i < num_experts; i += blockDim.x) {
smem_expert_map[i] = expert_map_ptr[i];
}
__syncthreads();
// query global expert id in expert map.
// if global expert id = -1 in exert map, plus n_expert
// else set global expert id = exert map[global expert id]
if (offset + tidx < bound) {
auto topk_id = topk_id_ptr[offset + tidx];
if (topk_id >= 0 && topk_id < num_experts) {
auto local_expert_idx = smem_expert_map[topk_id];
if (local_expert_idx == -1) {
topk_id += num_experts;
} else {
topk_id = local_expert_idx;
}
__syncwarp();
topk_id_ptr[offset + tidx] = topk_id;
}
}
}
Signed-off-by: Hank <hcc.mayday@gmail.com>
Switching to torch 2.10 and shifting the code to torch stable api