Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions src/io/rdma/backend_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,13 +738,33 @@ NotifManager::FlushDrainStats NotifManager::ProcessOneCqe(
struct ibv_recv_wr* bad = nullptr;
SYSCALL_RETURN_ZERO(ibv_post_recv(ep.local.ibvHandle.qp, &wr, &bad));
} else if (wc[i].opcode == IBV_WC_SEND) {
if (!IsNotifSendWrId(wc[i].wr_id)) {
MORI_IO_WARN(
"ProcessOneCqe: unexpected SEND completion with non-notification wr_id {}; "
"releasing 1 sqDepth under current SEND invariant",
wc[i].wr_id);
if (IsNotifSendWrId(wc[i].wr_id)) {
// Backward-compatible handling for any notification SENDs posted by
// older code that used tagged transfer IDs instead of ledger records.
if (ep.sqDepth) ep.sqDepth->fetch_sub(1, std::memory_order_relaxed);
} else {
int mergedBatchSize = 0;
auto meta = ep.ledger
? ep.ledger->ReleaseByCqe(wc[i].wr_id, ep.sqDepth.get(), &mergedBatchSize)
: nullptr;
if (meta) {
uint32_t finishedBefore = meta->finishedBatchSize.fetch_add(mergedBatchSize);
TransferStatus* statusPtr = meta->status;
if (statusPtr != nullptr &&
(finishedBefore + mergedBatchSize) == meta->totalBatchSize) {
statusPtr->Update(StatusCode::SUCCESS, ibv_wc_status_str(wc[i].status));
}
MORI_IO_TRACE(
"ProcessOneCqe: notification SEND CQE for task {} total={} finished={} cur={}",
meta->id, meta->totalBatchSize, finishedBefore, mergedBatchSize);
} else {
MORI_IO_WARN(
"ProcessOneCqe: notification SEND CQE has no ledger record for wr_id {}; "
"releasing 1 sqDepth under fallback path",
wc[i].wr_id);
if (ep.sqDepth) ep.sqDepth->fetch_sub(1, std::memory_order_relaxed);
}
}
if (ep.sqDepth) ep.sqDepth->fetch_sub(1, std::memory_order_relaxed);
} else {
// Batch path: wr_id carries a recordId from the SubmissionLedger.
uint64_t recordId = wc[i].wr_id;
Expand Down Expand Up @@ -1198,7 +1218,8 @@ void RdmaBackendSession::ReadWrite(size_t localOffset, size_t remoteOffset, size
TransferStatus* status, TransferUniqueId id, bool isRead) {
MORI_IO_FUNCTION_TIMER;
status->SetCode(StatusCode::IN_PROGRESS);
auto callbackMeta = std::make_shared<CqCallbackMeta>(status, id, 1);
const int notifBatchSize = config.enableNotification ? static_cast<int>(eps.size()) : 0;
auto callbackMeta = std::make_shared<CqCallbackMeta>(status, id, 1 + notifBatchSize);
internal::PublishCurrentIoCallDiagnostics(callbackMeta);

RdmaOpRet ret = RdmaBatchReadWrite(eps, localMrPerEp, remoteMrPerEp, {localOffset},
Expand All @@ -1211,7 +1232,7 @@ void RdmaBackendSession::ReadWrite(size_t localOffset, size_t remoteOffset, size
status->Update(ret.code, ret.message);
}
if (!ret.Failed() && config.enableNotification) {
RdmaOpRet notifRet = RdmaNotifyTransfer(eps, status, id);
RdmaOpRet notifRet = RdmaNotifyTransfer(eps, callbackMeta, id);
if (notifRet.Failed()) {
status->Update(notifRet.code, notifRet.message);
}
Expand All @@ -1223,7 +1244,9 @@ void RdmaBackendSession::BatchReadWrite(const SizeVec& localOffsets, const SizeV
TransferUniqueId id, bool isRead) {
MORI_IO_FUNCTION_TIMER;
status->SetCode(StatusCode::IN_PROGRESS);
auto callbackMeta = std::make_shared<CqCallbackMeta>(status, id, sizes.size());
const int notifBatchSize = config.enableNotification ? static_cast<int>(eps.size()) : 0;
auto callbackMeta =
std::make_shared<CqCallbackMeta>(status, id, static_cast<int>(sizes.size()) + notifBatchSize);
internal::PublishCurrentIoCallDiagnostics(callbackMeta);
RdmaOpRet ret;
if (executor) {
Expand All @@ -1241,7 +1264,7 @@ void RdmaBackendSession::BatchReadWrite(const SizeVec& localOffsets, const SizeV
status->Update(ret.code, ret.message);
}
if (!ret.Failed() && config.enableNotification) {
RdmaOpRet notifRet = RdmaNotifyTransfer(eps, status, id);
RdmaOpRet notifRet = RdmaNotifyTransfer(eps, callbackMeta, id);
if (notifRet.Failed()) {
status->Update(notifRet.code, notifRet.message);
}
Expand Down
25 changes: 19 additions & 6 deletions src/io/rdma/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,9 @@ static void ResetMergedWorkRequestPointers(MergedWorkRequest* wr) {
/* Rdma Utilities */
/* ---------------------------------------------------------------------------------------------- */

RdmaOpRet RdmaNotifyTransfer(const EpPairVec& eps, TransferStatus* status, TransferUniqueId id) {
RdmaOpRet RdmaNotifyTransfer(const EpPairVec& eps, std::shared_ptr<CqCallbackMeta> callbackMeta,
TransferUniqueId id) {
MORI_IO_FUNCTION_TIMER;
(void)status;

std::string reserveErr;
int reserved = 0;
Expand All @@ -388,6 +388,12 @@ RdmaOpRet RdmaNotifyTransfer(const EpPairVec& eps, TransferStatus* status, Trans
}

for (size_t i = 0; i < eps.size(); i++) {
if (!eps[i].ledger) {
for (int j = static_cast<int>(i); j < reserved; ++j) ReleaseSqDepth(eps[j], 1);
return {StatusCode::ERR_RDMA_OP,
"submission ledger is not initialized for notification SEND tracking"};
}

const application::RdmaEndpoint& ep = eps[i].local;
NotifMessage msg{id, static_cast<int>(i), static_cast<int>(eps.size())};

Expand All @@ -396,8 +402,10 @@ RdmaOpRet RdmaNotifyTransfer(const EpPairVec& eps, TransferStatus* status, Trans
sge.length = sizeof(NotifMessage);
sge.lkey = 0;

const uint64_t recordId = eps[i].ledger->Insert(1, true, callbackMeta, 1);

struct ibv_send_wr wr{};
wr.wr_id = MakeNotifSendWrId(id);
wr.wr_id = recordId;
wr.opcode = IBV_WR_SEND;
wr.send_flags = IBV_SEND_INLINE | IBV_SEND_SIGNALED;
wr.sg_list = &sge;
Expand All @@ -406,8 +414,11 @@ RdmaOpRet RdmaNotifyTransfer(const EpPairVec& eps, TransferStatus* status, Trans
struct ibv_send_wr* bad_wr = nullptr;
int ret = ibv_post_send(ep.ibvHandle.qp, &wr, &bad_wr);
if (ret != 0) {
// WR i was reserved but failed to post if bad_wr points at this WR.
if (bad_wr == &wr) ReleaseSqDepth(eps[i], 1);
// This call posts a single WR, so a non-zero return means the current
// notification was not accepted and no CQE will arrive for its ledger record.
ReleaseSqDepth(eps[i], 1);
int dummy = 0;
eps[i].ledger->ReleaseByCqe(recordId, nullptr, &dummy);
// Any remaining endpoints are reserved but not posted yet.
for (int j = i + 1; j < eps.size(); ++j) ReleaseSqDepth(eps[j], 1);
std::string message =
Expand Down Expand Up @@ -675,8 +686,10 @@ RdmaOpRet RdmaBatchReadWrite(const EpPairVec& eps,
if (mergedWrCount > static_cast<size_t>(std::numeric_limits<int>::max())) {
return {StatusCode::ERR_INVALID_ARGS, "final WR count exceeds int range"};
}
const int notifBatchSize =
std::max(0, callbackMeta->totalBatchSize - static_cast<int>(batchSize));
for (size_t k = 0; k < mergedWrCount; ++k) mergedWrs[k].mergedRequests = 1;
callbackMeta->totalBatchSize = static_cast<int>(mergedWrCount);
callbackMeta->totalBatchSize = static_cast<int>(mergedWrCount) + notifBatchSize;
}

size_t epNum = eps.size();
Expand Down
3 changes: 2 additions & 1 deletion src/io/rdma/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ struct RdmaOpRet {
bool Failed() { return code > StatusCode::ERR_BEGIN; }
};

RdmaOpRet RdmaNotifyTransfer(const EpPairVec& eps, TransferStatus* status, TransferUniqueId id);
RdmaOpRet RdmaNotifyTransfer(const EpPairVec& eps, std::shared_ptr<CqCallbackMeta> callbackMeta,
TransferUniqueId id);

RdmaOpRet RdmaBatchReadWrite(const EpPairVec& eps,
const std::vector<application::RdmaMemoryRegion>& localMrPerEp,
Expand Down
41 changes: 41 additions & 0 deletions tests/cpp/io/test_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,46 @@ void CaseSubmissionLedgerBasic() {
Require(sqDepth2.load(std::memory_order_relaxed) == 5, "sq depth after posted CQE release");
}

void CaseNotificationCompletionFanIn() {
TransferStatus status;
status.SetCode(StatusCode::IN_PROGRESS);
auto meta = std::make_shared<CqCallbackMeta>(&status, 42, 3);

uint32_t finishedBefore = meta->finishedBatchSize.fetch_add(2);
if (finishedBefore + 2 == meta->totalBatchSize) {
status.Update(StatusCode::SUCCESS, "data complete");
}
Require(status.InProgress(), "data completion must not finish before notification SEND CQE");

finishedBefore = meta->finishedBatchSize.fetch_add(1);
if (finishedBefore + 1 == meta->totalBatchSize) {
status.Update(StatusCode::SUCCESS, "notification complete");
}
Require(status.Succeeded(), "final notification completion should finish transfer");

TransferStatus outOfOrderStatus;
outOfOrderStatus.SetCode(StatusCode::IN_PROGRESS);
auto outOfOrderMeta = std::make_shared<CqCallbackMeta>(&outOfOrderStatus, 44, 5);
(void)outOfOrderMeta->finishedBatchSize.fetch_add(2); // notification SEND CQEs first
finishedBefore = outOfOrderMeta->finishedBatchSize.fetch_add(3);
if (finishedBefore + 3 == outOfOrderMeta->totalBatchSize) {
outOfOrderStatus.Update(StatusCode::SUCCESS, "data complete after notification");
}
Require(outOfOrderStatus.Succeeded(),
"notification-first completion order must still finish at the exact total");

TransferStatus failedStatus;
failedStatus.SetCode(StatusCode::IN_PROGRESS);
auto failedMeta = std::make_shared<CqCallbackMeta>(&failedStatus, 43, 2);
(void)failedMeta->finishedBatchSize.fetch_add(1);
failedStatus.Update(StatusCode::ERR_RDMA_OP, "notification failed");
finishedBefore = failedMeta->finishedBatchSize.fetch_add(1);
if (finishedBefore + 1 == failedMeta->totalBatchSize) {
failedStatus.Update(StatusCode::SUCCESS, "late success");
}
Require(failedStatus.Failed(), "notification failure must not be overwritten by late success");
}

void CaseWrIdNamespaceHelpers() {
const uint64_t taggedZero = MakeNotifSendWrId(0);
Require(taggedZero == kNotifSendWrIdTag, "tagged zero should only set the reserved high bit");
Expand Down Expand Up @@ -1451,6 +1491,7 @@ int main(int argc, char* argv[]) {
SetLogLevel("info");
std::vector<TestCase> cases = {
{"submission_ledger_basic", CaseSubmissionLedgerBasic},
{"notification_completion_fan_in", CaseNotificationCompletionFanIn},
{"wr_id_namespace_helpers", CaseWrIdNamespaceHelpers},
{"rdma_backend_config_chunking_fields", CaseRdmaBackendConfigChunkingFields},
{"resolve_requested_nics", CaseResolveRequestedNics},
Expand Down
Loading