Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
408 changes: 276 additions & 132 deletions CMakeLists.txt

Large diffs are not rendered by default.

19 changes: 16 additions & 3 deletions cmake/utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ endmacro()
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
#
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})

# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
Expand Down Expand Up @@ -345,6 +345,17 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction()


#
# Selects the architecture-specific CUDA targets (9.0a and newer) that loosely
# intersect `TGT_CUDA_ARCHS` and stores the result in `OUT_CUDA_ARCHS` in the
# caller's scope.
#
# The candidate list depends on the CUDA compiler version:
#   CUDA >= 13.0 : "9.0a;10.0f;11.0f;12.0f"
#   otherwise    : "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a"
#
function(cuda_archs_sm90plus OUT_CUDA_ARCHS TGT_CUDA_ARCHS)
  # Pass the variable *name* rather than `${...}`: if the compiler version is
  # empty or unset, the expanded form collapses to
  # `if( VERSION_GREATER_EQUAL 13.0)`, which is a CMake syntax error. With the
  # bare name, if() dereferences it itself and the comparison is simply false.
  if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
    cuda_archs_loose_intersection(_archs "9.0a;10.0f;11.0f;12.0f" "${TGT_CUDA_ARCHS}")
  else()
    cuda_archs_loose_intersection(_archs "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${TGT_CUDA_ARCHS}")
  endif()
  set(${OUT_CUDA_ARCHS} ${_archs} PARENT_SCOPE)
endfunction()


#
# Override the GPU architectures detected by cmake/torch and filter them by
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
Expand Down Expand Up @@ -458,7 +469,9 @@ function (define_gpu_extension_target GPU_MOD_NAME)
target_compile_definitions(${GPU_MOD_NAME} PRIVATE
"-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")

target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
target_include_directories(${GPU_MOD_NAME} PRIVATE
csrc
csrc/libtorch_stable
${GPU_INCLUDE_DIRECTORIES})

target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})
Expand Down
64 changes: 62 additions & 2 deletions csrc/attention/dtype_bfloat16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>

Expand Down Expand Up @@ -81,23 +89,43 @@ struct FloatVec<bf16_8_t> {

// Utility functions for type conversions.
// Converts a __nv_bfloat162 into a float2 through __bfloat1622float2.
// Requires SM80+: when compiled for an older device architecture the body
// is replaced by assert(false).
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  return __bfloat1622float2(val);
#else
  assert(false);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Builds a __nv_bfloat162 from a single __nv_bfloat16 through
// __bfloat162bfloat162.
// Requires SM80+: when compiled for an older device architecture the body
// is replaced by assert(false).
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  return __bfloat162bfloat162(val);
#else
  assert(false);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Vector addition.
// Scalar bfloat16 addition.
//  - device architecture below SM80: not supported, asserts.
//  - USE_ROCM not defined: uses operator+ on __nv_bfloat16.
//  - USE_ROCM defined: uses __hadd.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#elif !defined(USE_ROCM)
  return a + b;
#else
  return __hadd(a, b);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Adds two __nv_bfloat162 values with __hadd2.
// Requires SM80+: when compiled for an older device architecture the body
// is replaced by assert(false).
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  return __hadd2(a, b);
#else
  assert(false);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

Expand Down Expand Up @@ -141,13 +169,21 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
// Vector multiplication.
// Specialization: multiplies two __nv_bfloat16 values with __hmul.
// Requires SM80+: when compiled for an older device architecture the body
// is replaced by assert(false).
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  return __hmul(a, b);
#else
  assert(false);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Specialization: multiplies two __nv_bfloat162 values with __hmul2.
// Requires SM80+: when compiled for an older device architecture the body
// is replaced by assert(false).
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  return __hmul2(a, b);
#else
  assert(false);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

Expand Down Expand Up @@ -254,13 +290,21 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
// Vector fused multiply-add.
// Fused multiply-add on __nv_bfloat162 operands: forwards (a, b, c) to
// __hfma2.
// Requires SM80+: when compiled for an older device architecture the body
// is replaced by assert(false).
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
                                     __nv_bfloat162 c) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  return __hfma2(a, b, c);
#else
  assert(false);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Fused multiply-add with a scalar first operand: `a` is first converted to
// a __nv_bfloat162 with bf162bf162, then (a, b, c) is forwarded to __hfma2.
// Requires SM80+: when compiled for an older device architecture the body
// is replaced by assert(false).
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
                                     __nv_bfloat162 c) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  const __nv_bfloat162 a_pair = bf162bf162(a);
  return __hfma2(a_pair, b, c);
#else
  assert(false);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

Expand Down Expand Up @@ -374,19 +418,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
}

// Stores `src` into `dst` as a __nv_bfloat162 using __float22bfloat162_rn.
// Requires SM80+: when compiled for an older device architecture the body
// is replaced by assert(false) and `dst` is left untouched.
inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  dst = __float22bfloat162_rn(src);
#else
  assert(false);
#endif
}

// Converts the .x and .y members of `src` with __float22bfloat162_rn and
// stores them in the matching members of `dst`.
// Requires SM80+: when compiled for an older device architecture the body
// is replaced by assert(false) and `dst` is left untouched.
inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
#else
  assert(false);
#endif
}

// Converts the .x, .y, .z and .w members of `src` with __float22bfloat162_rn
// and stores them in the matching members of `dst`.
// Requires SM80+: when compiled for an older device architecture the body
// is replaced by assert(false) and `dst` is left untouched.
inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
#else
  assert(false);
#endif
}

// From bfloat16 to float32.
Expand All @@ -396,8 +452,12 @@ inline __device__ float to_float(__nv_bfloat16 u) {

// Zero-out a variable.
// Sets `dst` to the bfloat16 whose raw bit pattern is 0x0000.
// Requires SM80+: when compiled for an older device architecture the body
// is replaced by assert(false) and `dst` is left untouched.
inline __device__ void zero(__nv_bfloat16& dst) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
  dst = __ushort_as_bfloat16(static_cast<unsigned short>(0x0000U));
#else
  assert(false);
#endif
}

} // namespace vllm
29 changes: 15 additions & 14 deletions csrc/attention/dtype_float16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@

#include "attention_generic.cuh"
#include "dtype_float32.cuh"
#include "cuda_fp16.h"

#ifdef USE_ROCM
#include <hip/hip_fp16.h>
#endif

#include <stdint.h>

Expand Down Expand Up @@ -66,10 +69,13 @@ struct FloatVec<uint4> {

// Utility functions for type conversions.
// Replicates the 16-bit value `a` into both the low and the high half of a
// 32-bit word and returns that word.
//
// The block previously carried a second, union-based implementation after
// the unconditional `return`; it could never execute and has been dropped.
// The shift/OR form kept here does not depend on byte order.
inline __device__ uint32_t h0_h0(uint16_t a) {
  const uint32_t widened = a;
  return (widened << 16) | widened;
}

inline __device__ float half_to_float(uint16_t h) {
Expand All @@ -79,15 +85,15 @@ inline __device__ float half_to_float(uint16_t h) {
}

// Splits the 32-bit word `v` into two 16-bit values and converts each with
// half_to_float. Element 0 of the union's u16 array becomes .x of the
// result and element 1 becomes .y.
//
// The block previously carried a second copy of the conversion after the
// first `return`; it could never execute and has been dropped, together with
// the `lo`/`hi` temporaries.
inline __device__ float2 half2_to_float2(uint32_t v) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u32 = v;
  return make_float2(half_to_float(tmp.u16[0]), half_to_float(tmp.u16[1]));
}

inline __device__ uint16_t float_to_half(float f) {
Expand All @@ -105,13 +111,8 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
uint32_t u32;
uint16_t u16[2];
} tmp;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
__half2 __tmp = __half2(__float2half(f.x), __float2half(f.y));
tmp.u32 = *(uint32_t*)&__tmp;
#else
tmp.u16[0] = float_to_half(f.x);
tmp.u16[1] = float_to_half(f.y);
#endif
return tmp.u32;
}

Expand Down
3 changes: 2 additions & 1 deletion csrc/attention/dtype_fp8.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "attention_generic.cuh"
#include "torch_utils.h"

#include <stdint.h>
#ifdef ENABLE_FP8
Expand Down Expand Up @@ -30,7 +31,7 @@ inline Fp8KVCacheDataType get_fp8_kv_cache_data_type(
} else if (dtype_str == "fp8_e5m2") {
return Fp8KVCacheDataType::kFp8E5M2;
}
TORCH_CHECK(false, "Unsupported fp8 kv cache data type: ", dtype_str);
TORCH_UTILS_CHECK(false, "Unsupported fp8 kv cache data type: ", dtype_str);
}

// fp8 vector types for quantization of kv cache
Expand Down
Loading
Loading