Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@ else
-e HOST_GLIBC_VER="${HOST_GLIBC_VER}" \
-e UCCL_RETAG_TO_HOST_GLIBC="${UCCL_RETAG_TO_HOST_GLIBC:-0}" \
-e UCCL_LOCAL_VERSION="${UCCL_LOCAL_VERSION:-}" \
-e DISABLE_SM90_FEATURES="${DISABLE_SM90_FEATURES:-0}" \
-e DISABLE_AGGRESSIVE_PTX_INSTRS="${DISABLE_AGGRESSIVE_PTX_INSTRS:-0}" \
-w /io \
"$IMAGE_NAME" \
/bin/bash /io/build_inner.sh
Expand Down
10 changes: 9 additions & 1 deletion ep/include/ep_configs.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,19 @@ typedef uint16_t __hip_fp8x2_storage_t;
#ifndef DISABLE_SM90_FEATURES
#include <cuda_fp8.h>
#else
// Ampere does not support FP8 features
// Ampere does not support FP8 features, but CUDA 13+ always provides the
// header.
#if __CUDACC_VER_MAJOR__ >= 13
#include <cuda_fp8.h>
#else
#define __NV_E4M3 0
#define __NV_E5M2 1
typedef int __nv_fp8_interpretation_t;
typedef int __nv_fp8x4_e4m3;
typedef uint8_t __nv_fp8_storage_t;
typedef uint16_t __nv_fp8x2_storage_t;
#define __NV_SATFINITE 0
#define __nv_cvt_float2_to_fp8x2(a, b, c) ((uint16_t)0)
#endif
#endif
#endif
25 changes: 7 additions & 18 deletions ep/include/ep_launch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,19 @@
cfg.attrs = attr; \
cfg.numAttrs = 2
#else
#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
int __num_sms = (sms); \
int __num_threads = (threads); \
auto __stream = (stream)
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
cudaLaunchAttribute attr[1]; \
attr[0].id = cudaLaunchAttributeCooperative; \
attr[0].val.cooperative = 1; \
cfg.attrs = attr; \
cfg.numAttrs = 1
#endif
#endif

#ifndef LAUNCH_KERNEL
#ifndef DISABLE_SM90_FEATURES
#define LAUNCH_KERNEL(config, kernel, ...) \
CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))
#else
#define LAUNCH_KERNEL(config, kernel, ...) \
do { \
kernel<<<__num_sms, __num_threads, 0, __stream>>>(__VA_ARGS__); \
cudaError_t e = cudaGetLastError(); \
if (e != cudaSuccess) { \
EPException cuda_exception("CUDA", __FILE__, __LINE__, \
cudaGetErrorString(e)); \
fprintf(stderr, "%s\n", cuda_exception.what()); \
throw cuda_exception; \
} \
} while (0)
#endif
#endif

#ifndef SET_SHARED_MEMORY_FOR_TMA
Expand Down
36 changes: 36 additions & 0 deletions ep/include/ep_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,42 @@ __device__ __forceinline__ void st_release_sys_global(dtype_t const* ptr,
} // namespace amd
#endif

// Software grid barrier for the CUDA (non-SM90) path; CUDA-side counterpart of
// amd::grid_sync_then_zero.
//
// Every block of the grid must call this with the same `bar_ptr` and the same
// `num_blocks`, and all of those blocks must be resident at the same time
// (the barrier spins, so a block that is never scheduled deadlocks the grid).
//
//   bar_ptr    - global-memory counter shared by the whole grid. Must be 0
//                when no barrier is in progress; it is 0 again once the last
//                block has left, so the barrier is reusable (within a kernel
//                and across launches) without an extra clean kernel.
//   num_blocks - number of blocks taking part in the barrier (> 0).
//
// The counter runs through two phases per barrier episode:
//   arrive:  0 -> num_blocks       (each block adds 1, then waits for
//                                   counter >= num_blocks)
//   depart:  num_blocks -> 2 * num_blocks -> 0
//                                  (each block adds 1 on the way out; the last
//                                   one to leave resets the counter)
// A block that races ahead to the *next* barrier call first waits for the
// counter to drop below num_blocks, i.e. for the previous episode to be fully
// drained. Without that separation ("reset to 0, wait for 0") a fast block
// could re-increment the counter before a slow waiter had observed the reset,
// leaving the slow waiter spinning forever and the grid deadlocked.
#if !defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__)
__device__ __forceinline__ void cuda_grid_barrier(int* bar_ptr,
                                                  int num_blocks) {
  // All threads of this block must have issued their prior memory operations
  // before thread 0 signals arrival on behalf of the block.
  __syncthreads();
  if (threadIdx.x == 0) {
    unsigned int* counter = reinterpret_cast<unsigned int*>(bar_ptr);
    unsigned int volatile* counter_poll =
        reinterpret_cast<unsigned int volatile*>(bar_ptr);
    unsigned int const expected = static_cast<unsigned int>(num_blocks);

    // Do not join a new episode while the previous one is still draining
    // (counter in [num_blocks, 2 * num_blocks]). Once the counter has been
    // reset it stays below num_blocks until this block itself arrives, so the
    // check-then-add below cannot be invalidated by other blocks.
    while (*counter_poll >= expected)
      ;

    // Release side: CUDA atomics carry no ordering of their own, so the fence
    // is what makes this block's prior writes visible device-wide before the
    // arrival becomes observable.
    __threadfence();
    atomicAdd(counter, 1u);

    // Wait until every block has arrived. The counter cannot fall below
    // num_blocks again before this block departs, so the condition cannot be
    // missed.
    while (*counter_poll < expected)
      ;

    // Acquire side: the spin above is a plain volatile load with no ordering
    // semantics. __threadfence() is a full device-scope fence, and it is
    // required here so that reads issued after the barrier observe the writes
    // the other blocks published before their arrival.
    __threadfence();

    // Depart. The last block to leave restores the counter to 0 so the next
    // episode (or the next kernel launch) starts from a clean state.
    if (atomicAdd(counter, 1u) == 2u * expected - 1u) {
      atomicExch(counter, 0u);
    }
  }
  // Broadcast barrier completion from thread 0 to all threads in the block
  // before any thread proceeds past the barrier.
  __syncthreads();
}
#endif

__forceinline__ __device__ int get_lane_id() {
int lane_id;
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
Expand Down
36 changes: 18 additions & 18 deletions ep/src/internode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ __global__ void __launch_bounds__(
};

// TMA stuffs
#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto tma_buffer = smem_tma_buffer + target_rank * kNumTMABytesPerWarp;
auto tma_mbarrier =
Expand Down Expand Up @@ -934,7 +934,7 @@ __global__ void __launch_bounds__(
}
__syncwarp();

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
// Release the transaction in the window
if (is_token_in_rank_uint64 != 0) {
// Acquire lock first
Expand Down Expand Up @@ -1347,7 +1347,7 @@ __global__ void __launch_bounds__(
reinterpret_cast<int4*>(dst_shifted),
reinterpret_cast<int4*>(shifted), ld_nc_global,
st_na_global);
#else
#elif !defined(DISABLE_SM90_FEATURES)
if (lane_id == 0) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_token,
false);
Expand All @@ -1366,7 +1366,7 @@ __global__ void __launch_bounds__(
if ((++num_tokens_sent) == num_max_rdma_chunked_send_tokens)
src_rdma_tail = i + 1;

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
tma_store_wait();
__syncwarp();
#endif
Expand Down Expand Up @@ -1552,7 +1552,7 @@ __global__ void __launch_bounds__(
reinterpret_cast<float*>(shifted + hidden_bytes),
ld_nc_global, st_na_global);

#else
#elif !defined(DISABLE_SM90_FEATURES)
if (lane_id == 0) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, tma_load_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier, tma_load_bytes);
Expand Down Expand Up @@ -1606,7 +1606,7 @@ __global__ void __launch_bounds__(
st_na_global(recv_topk_weights + recv_idx, weight_value);
}

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
// Wait TMA to be finished
tma_store_wait();
#endif
Expand Down Expand Up @@ -1789,7 +1789,7 @@ __global__ void cached_notify(
} else if (sm_id == 1) {
if (is_cached_dispatch) return;

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
EP_DEVICE_ASSERT(num_warps >= num_channels);
#endif
EP_DEVICE_ASSERT(num_rdma_ranks <= WARP_SIZE);
Expand Down Expand Up @@ -1829,7 +1829,7 @@ __global__ void cached_notify(
} else {
if (is_cached_dispatch) return;

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
EP_DEVICE_ASSERT(num_warps >= num_channels);
#endif
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and
Expand All @@ -1852,7 +1852,7 @@ __global__ void cached_notify(
EP_STATIC_ASSERT(num_bytes_per_token % 16 == 0,
"num_bytes_per_token should be divisible by 16");

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
// TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto tma_buffer = smem_tma_buffer + warp_id * kNumTMABytesPerWarp;
Expand Down Expand Up @@ -1890,7 +1890,7 @@ __global__ void cached_notify(
auto batch_start_idx =
max(token_start_idx, batch_end_idx - num_tokens_per_batch);

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
if (lane_id == 0) {
tma_load_1d(
tma_buffer,
Expand Down Expand Up @@ -1918,7 +1918,7 @@ __global__ void cached_notify(
} else {
last_head = current_head;
}
#else
#elif !defined(DISABLE_SM90_FEATURES)
auto current_head = reinterpret_cast<int*>(tma_buffer)
[(token_idx - batch_start_idx) * NUM_MAX_NVL_PEERS + lane_id];
if (current_head < 0) {
Expand All @@ -1933,7 +1933,7 @@ __global__ void cached_notify(
}
}

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
tma_store_fence();
__syncwarp();

Expand Down Expand Up @@ -2309,7 +2309,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1)
channel_id, num_channels, nvl_rank)
.advance_also(local_buffer_ptr);

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
// TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto tma_buffer =
Expand Down Expand Up @@ -2438,7 +2438,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1)
sizeof(SourceMeta) +
lane_id * sizeof(float)),
ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
#else
#elif !defined(DISABLE_SM90_FEATURES)
if (lane_id == 0) {
tma_store_wait();
tma_load_1d(tma_buffer, shifted_x, tma_mbarrier, hidden_bytes);
Expand Down Expand Up @@ -2473,7 +2473,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1)
}

// Move queue tail
#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
tma_store_wait();
#endif
__syncwarp();
Expand Down Expand Up @@ -2577,7 +2577,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1)
EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16,
"Barriers are not enough");

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
// TMA stuffs
constexpr int kNumStages = 2;
constexpr int kNumTMALoadBytes = sizeof(int4) * 32;
Expand Down Expand Up @@ -2745,7 +2745,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1)
nullptr, nullptr, num_max_nvl_chunked_recv_tokens_per_rdma,
get_addr_fn, recv_tw_fn, nullptr, dummy_tma_phases);

#else
#elif !defined(DISABLE_SM90_FEATURES)
combine_token<NUM_MAX_NVL_PEERS, false, dtype_t, NUM_MAX_NVL_PEERS,
true, kNumStages, kNumTMALoadBytes>(
expected_head >= 0, expected_head, lane_id, hidden_int4, num_topk,
Expand Down Expand Up @@ -3034,7 +3034,7 @@ void combine(cudaDataType_t type, void* combined_x,
constexpr int kNumTMABytesPerSenderWarp = 16384;
constexpr int kNumTMABytesPerForwarderWarp = 9248;

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
constexpr int smem_size =
std::max(kNumTMABytesPerSenderWarp * NUM_MAX_NVL_PEERS,
kNumTMABytesPerForwarderWarp * kNumCombineForwarderWarps);
Expand Down
15 changes: 13 additions & 2 deletions ep/src/internode_ll.cu
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
// Reset counter after sync so send-only launches (return_recv_hook) do not
// leave a stale value that deadlocks the next dispatch.
amd::grid_sync_then_zero(grid_sync_barrier_ptr, num_sms);
#elif defined(DISABLE_SM90_FEATURES)
cuda_grid_barrier(grid_sync_barrier_ptr, num_sms);
#else
cg::this_grid().sync();
#endif
Expand Down Expand Up @@ -464,6 +466,8 @@ LOW_LATENCY_DISPATCH_RECV:
if (phases & LOW_LATENCY_SEND_PHASE)
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
amd::grid_sync_then_zero(grid_sync_barrier_ptr, num_sms);
#elif defined(DISABLE_SM90_FEATURES)
cuda_grid_barrier(grid_sync_barrier_ptr, num_sms);
#else
cg::this_grid().sync();
#endif
Expand Down Expand Up @@ -811,7 +815,7 @@ __global__ __launch_bounds__(1024, 1) void combine(
int offset, num_tokens_to_send;
unpack2(layout, num_tokens_to_send, offset);

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
// TMA stuffs
constexpr int kNumTMABufferBytes = sizeof(int4) * WARP_SIZE * kNumUnrolls;
constexpr int kNumStages = 3;
Expand Down Expand Up @@ -899,6 +903,11 @@ __global__ __launch_bounds__(1024, 1) void combine(
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, cpy_dst_int4_ptr,
cpy_src_int4_ptr, ld_nc_global, st_na_global);

#elif defined(DISABLE_SM90_FEATURES)
// Non-SM90 NVIDIA path: simple warp copy (no TMA available)
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, cpy_dst_int4_ptr,
cpy_src_int4_ptr, ld_nc_global, st_na_global);

#else
// Prefetch
if (elect_one_sync(lane_id))
Expand Down Expand Up @@ -1010,7 +1019,7 @@ __global__ __launch_bounds__(1024, 1) void combine(
#endif
}

#if defined(__NVCC__)
#if defined(__NVCC__) && !defined(DISABLE_SM90_FEATURES)
// Flush all stores
tma_store_wait();
__syncwarp();
Expand Down Expand Up @@ -1144,6 +1153,8 @@ LOW_LATENCY_COMBINE_RECV:
}
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
amd::grid_sync_then_zero(grid_sync_barrier_ptr, num_sms);
#elif defined(DISABLE_SM90_FEATURES)
cuda_grid_barrier(grid_sync_barrier_ptr, num_sms);
#else
cg::this_grid().sync();
#endif
Expand Down
Loading