Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions bench/gemm_rs_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# Keep the release default, but allow parity tests against the experiment
# harness, which leaves this unset.
os.environ.setdefault("MKERNEL_BIND_RETAINED_HANDLE", "1")
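# team_v14: skip the host-side reset_arrival_flags + cudaDeviceSynchronize in
# commit_epoch. Safe because the gemm_rs kernel now resets the arrival region
# on-device at iter-end via the dedicated reduce CTAs (mirrors gemm_ar). This
# only holds when a non-zero arrival_total_words is passed to gemm_rs_fused:
# with 0 the kernel skips its on-device reset, so the host-side reset must not
# be skipped in that configuration.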
os.environ.setdefault("MKERNEL_COMMIT_EPOCH_SKIP_ARRIVAL_RESET", "1")

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -201,6 +205,10 @@ def ready_entries(m_):
fifo = mod.get_fifo_handles()
arrival_ptr = mod.get_arrival_flags_ptr()
recv_ptr = mod.get_recv_buf_ptr()
# Total u32 words (count + tail_count) in the arrival region. Passed to
# the kernel so the dedicated reduce CTAs can perform the iter-end
# on-device reset (paired with MKERNEL_COMMIT_EPOCH_SKIP_ARRIVAL_RESET=1).
arrival_total_words = mod.get_arrival_flags_total_words()

epoch = 1
mod.set_epoch(epoch)
Expand All @@ -224,6 +232,7 @@ def run_once():
ready_chunk,
staging_dbuf,
num_nodes=NUM_NODES,
arrival_total_words=arrival_total_words,
)

for wi in range(args.warmup):
Expand Down
3 changes: 3 additions & 0 deletions bench/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ run_one_2node() {
if [[ -n "${MKERNEL_PREP_EPOCH_FAST:-}" ]]; then
env_str="$env_str MKERNEL_PREP_EPOCH_FAST=$MKERNEL_PREP_EPOCH_FAST"
fi
if [[ -n "${MKERNEL_COMMIT_EPOCH_SKIP_ARRIVAL_RESET:-}" ]]; then
env_str="$env_str MKERNEL_COMMIT_EPOCH_SKIP_ARRIVAL_RESET=$MKERNEL_COMMIT_EPOCH_SKIP_ARRIVAL_RESET"
fi
# Allow per-shape SM-split overrides for gemm_ar tuning sweeps. These
# take precedence over the COMMON_ENV defaults because later assignments
# in the env_str win.
Expand Down
8 changes: 8 additions & 0 deletions include/comm/internode/session_py.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ inline int64_t get_arrival_flags_ptr(Session* session) {
return reinterpret_cast<int64_t>(internode::get_arrival_device_ptr(session));
}

// Reports the size of the arrival region in u32 words, computed as the sum
// of the session's arrival.count and arrival.tail_count. A null session
// yields 0. Kernels that zero the arrival region on-device at iter-end
// (paired with MKERNEL_COMMIT_EPOCH_SKIP_ARRIVAL_RESET=1 in commit_epoch) use
// this value as the extent to clear.
inline int64_t get_arrival_flags_total_words(Session* session) {
    if (session == nullptr) {
        return 0;
    }
    // Widen each term before adding so the sum is formed in 64 bits.
    const int64_t count_words = static_cast<int64_t>(session->arrival.count);
    const int64_t tail_words = static_cast<int64_t>(session->arrival.tail_count);
    return count_words + tail_words;
}

// Returns whatever internode::get_recv_buf_ptr(session) yields, reinterpreted
// as a 64-bit integer so it can be handed across the binding layer as a plain
// number. No null check is performed here; the session is passed through
// unchanged.
inline int64_t get_recv_buf_ptr(Session* session) {
    const auto recv_buf = internode::get_recv_buf_ptr(session);
    return reinterpret_cast<int64_t>(recv_buf);
}
Expand Down
11 changes: 11 additions & 0 deletions include/operators/gemm_ar/gemm_ar.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,12 @@ struct fused_globals {
int remote_queue_stride;
int defer_final_multicast_finish;
int work_steal_enabled;
// MKERNEL_GEMM_AR_SKIP_FINAL_BARRIER: when 1, skip the iter-end
// hierarchical_xnode_barrier entirely (no NOTIFY push, no spin). For
// experimental measurement of barrier cost. Correctness only holds if
// host-side dist.barrier()/synchronize keeps ranks aligned per iter
// (i.e., legacy-sync mode), or the workload tolerates >=1 iter of drift.
int skip_final_barrier;
// When true, intra-AR xdev barrier waits use acquire loads instead of
// relaxed loads. The branch is outside the spin body.
bool use_acquire_poll = false;
Expand Down Expand Up @@ -1220,6 +1226,7 @@ __host__ inline fused_globals gemm_ar_make_globals(
.remote_queue_stride = scratch.remote_queue_stride,
.defer_final_multicast_finish = 0,
.work_steal_enabled = 0,
.skip_final_barrier = 0,
.total_chunks = scratch.total_chunks,
.total_tiles_per_device = scratch.slice_tiles,
.chunk_tiles = scratch.chunk_tiles,
Expand Down Expand Up @@ -1318,6 +1325,10 @@ void entrypoint(
const char* ws_env = std::getenv("GEMM_AR_WORK_STEAL");
G.work_steal_enabled = (ws_env != nullptr && ws_env[0] == '1') ? 1 : 0;
}
{
const char* sb_env = std::getenv("MKERNEL_GEMM_AR_SKIP_FINAL_BARRIER");
G.skip_final_barrier = (sb_env != nullptr && sb_env[0] == '1') ? 1 : 0;
}
{
const char* r8_env = std::getenv("GEMM_AR_R8_WARP_SPEC");
G.r8_warp_spec = (r8_env != nullptr && r8_env[0] == '1') ? 1u : 0u;
Expand Down
59 changes: 57 additions & 2 deletions include/operators/gemm_rs/gemm_rs.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ struct fused_globals {
uint8_t use_intra_rs_dual_write;
uint8_t _pad0[2];
uint32_t reduce_poll_sleep_ns;
// Total number of u32 words in the arrival region (count + tail_count).
// Populated from session at entry; used by the on-device iter-end reset
// to zero the entire arrival region cooperatively. When 0, the kernel
// skips the on-device reset and the host commit_epoch path is responsible.
uint32_t arrival_total_words;
};

intra_globals intra;
Expand Down Expand Up @@ -385,6 +390,50 @@ __device__ __forceinline__ int gemm_rs_send_ready_bitmap_region_base(
}


// Cooperative iter-end clear of the arrival region (arrival flags plus the
// per-queue tails that follow them). Pairs with
// MKERNEL_COMMIT_EPOCH_SKIP_ARRIVAL_RESET=1 in the host commit_epoch path:
// that env var elides the host-side memset over the arrival region and the
// cudaDeviceSynchronize that follows, and this routine does the equivalent
// zeroing on-device.
//
// Participation: only CTAs whose blockIdx.x lies in
//   [participating_block_start,
//    participating_block_start + participating_block_count)
// do any work. All other CTAs return immediately, and so does every CTA when
// Rt.arrival_total_words is 0 (on-device reset disabled; the host path is
// then responsible for the reset).
//
// Ordering: thread 0 of each participating CTA spins with an acquire load
// until *Rt.chunks_processed reaches
// row_blocks_per_slice * chunks_per_row; the remaining threads wait at the
// block barrier that follows. reduce_tiles_ws bumps that counter once per
// finished chunk, so reaching the total means no reducer is still polling
// arrival_flags for this iteration. Merely leaving the work-stealing loop is
// not enough, because another CTA may still be polling a chunk it claimed
// earlier.
//
// Clearing: the participating CTAs split the Rt.arrival_total_words words
// between them with a stride of participating_block_count * blockDim.x,
// store 0 through a volatile pointer, then issue __threadfence_system()
// followed by a block barrier.
//
// NOTE(review): nothing in this routine orders the clear against a peer's
// next-epoch RDMA writes. gemm_ar follows its equivalent reset with a
// cross-node barrier; confirm the gemm_rs caller provides comparable gating.
// NOTE(review): the spin has no backoff, and this routine never rewinds
// *Rt.chunks_processed; presumably it is re-zeroed elsewhere before the next
// launch. Verify against the host entrypoint.
__device__ inline void gemm_rs_iter_end_reset_arrival_flags(
    const fused_globals::runtime_state& Rt,
    int participating_block_start, int participating_block_count
) {
    // Rank of this CTA inside the participating range (negative if before it).
    const int cta_rank = blockIdx.x - participating_block_start;
    const bool reset_enabled = (Rt.arrival_total_words != 0u);
    const bool cta_participates =
        (cta_rank >= 0) && (cta_rank < participating_block_count);
    // Both conditions are uniform across the block, so the early return
    // cannot split the barriers below.
    if (!reset_enabled || !cta_participates) {
        return;
    }

    const unsigned int expected_chunks =
        static_cast<unsigned int>(Rt.row_blocks_per_slice) *
        static_cast<unsigned int>(Rt.chunks_per_row);

    // One poller per CTA; everyone else parks at the barrier.
    if (threadIdx.x == 0) {
        for (;;) {
            if (comm::atomic_u32::acquire_load_gpu(Rt.chunks_processed) >= expected_chunks) {
                break;
            }
        }
    }
    __syncthreads();

    volatile uint32_t* const words = Rt.arrival_flags;
    const int word_count = (int)Rt.arrival_total_words;
    const int step = participating_block_count * blockDim.x;
    int idx = cta_rank * blockDim.x + threadIdx.x;
    while (idx < word_count) {
        words[idx] = 0u;
        idx += step;
    }

    // System-scope fence after the zero stores, then rejoin the block.
    __threadfence_system();
    __syncthreads();
}

// ============================================================================
// Host entrypoint
// ============================================================================
Expand Down Expand Up @@ -475,8 +524,13 @@ void entrypoint_fused(
dist::ParallelBuffer &ready_chunk,
// Staging DistBuffer used as the chunk-major intra-RS atomic-add target.
pybind11::object staging_obj,
int num_nodes = 2 // Total node count (>= 2). N == 2 reproduces the
// legacy 2-node behavior bit-for-bit.
int num_nodes = 2, // Total node count (>= 2). N == 2 reproduces the
// legacy 2-node behavior bit-for-bit.
int arrival_total_words = 0 // Total u32 words in the arrival region
// (session arrival.count + arrival.tail_count).
// When > 0, the kernel performs an on-device
// iter-end reset; pairs with
// MKERNEL_COMMIT_EPOCH_SKIP_ARRIVAL_RESET=1.
) {
const int dev_idx = output.local_rank_;
c10::cuda::CUDAGuard device_guard(dev_idx);
Expand Down Expand Up @@ -703,6 +757,7 @@ void entrypoint_fused(
.use_intra_rs_dual_write = (uint8_t)(use_intra_rs_dual_write_rt ? 1u : 0u),
._pad0 = {0, 0},
.reduce_poll_sleep_ns = (uint32_t)(reduce_poll_sleep_ns > 0 ? reduce_poll_sleep_ns : 100),
.arrival_total_words = (uint32_t)(arrival_total_words > 0 ? arrival_total_words : 0),
};
cudaMemcpyAsync(g_fused_runtime[dev_idx], &rt, sizeof(rt),
cudaMemcpyHostToDevice, stream);
Expand Down
7 changes: 6 additions & 1 deletion include/operators/gemm_rs/session.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, int> get_fifo_handles_py() {
return internode::py::get_fifo_handles(g_session);
}
// Binding-layer wrapper: forwards g_session to
// internode::py::get_arrival_flags_ptr and returns its result unchanged.
int64_t get_arrival_flags_ptr_py() {
    const int64_t arrival_ptr = internode::py::get_arrival_flags_ptr(g_session);
    return arrival_ptr;
}
// Binding-layer wrapper: forwards g_session to
// internode::py::get_arrival_flags_total_words and returns its result
// unchanged.
int64_t get_arrival_flags_total_words_py() {
    const int64_t total_words =
        internode::py::get_arrival_flags_total_words(g_session);
    return total_words;
}
// Binding-layer wrapper: forwards g_session to
// internode::py::get_recv_buf_ptr and returns its result unchanged.
int64_t get_recv_buf_ptr_py() {
    const int64_t recv_ptr = internode::py::get_recv_buf_ptr(g_session);
    return recv_ptr;
}

#include <torch/csrc/utils/pybind.h>
Expand All @@ -96,6 +99,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("set_epoch", &set_epoch_py);
m.def("get_fifo_handles", &get_fifo_handles_py);
m.def("get_arrival_flags_ptr", &get_arrival_flags_ptr_py);
m.def("get_arrival_flags_total_words", &get_arrival_flags_total_words_py);
m.def("get_recv_buf_ptr", &get_recv_buf_ptr_py);
m.def("gemm_rs_fused", &gemm_rs_multinode::entrypoint_fused,
pybind11::arg("A"),
Expand All @@ -122,5 +126,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::arg("reduce_poll_sleep_ns") = (int64_t)100,
pybind11::arg("ready_chunk"),
pybind11::arg("staging") = pybind11::none(),
pybind11::arg("num_nodes") = 2);
pybind11::arg("num_nodes") = 2,
pybind11::arg("arrival_total_words") = 0);
}
8 changes: 6 additions & 2 deletions src/gemm_ar.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,9 @@ __device__ __forceinline__ void fused_kernel(const fused_globals& G) {
// Reset arrival flags on-stream BEFORE the barrier — the barrier then
// gates peer's next-iter send, so no clobber race with peer RDMA.
gemm_ar_iter_end_reset_arrival_flags(G, reduce_base, G.num_inter_reduce_store_sms);
gemm_ar_hierarchical_xnode_barrier(G, reduce_base);
if (!G.skip_final_barrier) {
gemm_ar_hierarchical_xnode_barrier(G, reduce_base);
}
}
}
}
Expand All @@ -1026,7 +1028,9 @@ __device__ inline void fused_epilogue_kernel(const fused_globals& G) {
// Reset arrival flags on-stream BEFORE the barrier; gating peer's
// next-iter RDMA writes behind our push means no clobber race.
gemm_ar_iter_end_reset_arrival_flags(G, reduce_base, G.num_inter_reduce_store_sms);
gemm_ar_hierarchical_xnode_barrier(G, reduce_base);
if (!G.skip_final_barrier) {
gemm_ar_hierarchical_xnode_barrier(G, reduce_base);
}
}
}

Expand Down
25 changes: 24 additions & 1 deletion src/gemm_rs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,16 @@ __device__ inline void reduce_tiles_ws(const G &Gv) {
*reinterpret_cast<uint4*>(Rt.output_local + si) = ov;
}
}
// Publish per-chunk completion. Used by the iter-end on-device
// arrival-flag reset (gemm_rs_iter_end_reset_arrival_flags) to wait
// for ALL claimed chunks to drain — including those still being
// polled by a peer reduce CTA after this one has exited via the
// chunk_id >= total_chunks branch.
__syncthreads();
if (threadIdx.x == 0) {
__threadfence();
atomicAdd(Rt.chunks_processed, 1u);
}
}

}
Expand Down Expand Up @@ -547,7 +557,20 @@ __device__ inline void fused_kernel(const fused_globals &G) {
reduce_tiles_ws<fused_globals>(G);
}


// On-device iter-end reset of arrival flags. Mirrors gemm_ar's
// gemm_ar_iter_end_reset_arrival_flags pattern: after leaving
// reduce_tiles_ws, the dedicated reduce CTAs cooperatively zero the arrival
// region. Leaving reduce_tiles_ws does not by itself mean every chunk's flag
// has been polled (another CTA may still be working on a chunk it claimed),
// so the reset routine first waits for Rt.chunks_processed to reach the
// total chunk count before it writes. Pairs with
// MKERNEL_COMMIT_EPOCH_SKIP_ARRIVAL_RESET=1 in commit_epoch, which then
// skips the host-side memset + cudaDeviceSynchronize.
if (G.rt != nullptr) {
const int reduce_base = I.num_comp_sms + I.num_comm_sms + G.num_send_sms;
const int reduce_count = (int)gridDim.x - reduce_base;
if (reduce_count > 0 && (int)blockIdx.x >= reduce_base) {
gemm_rs_iter_end_reset_arrival_flags(*G.rt, reduce_base, reduce_count);
}
}

// Match the split intra kernel when we intentionally launch only the
// compute + intranode CTAs for debugging/reuse checks.
Expand Down