Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,16 @@ if ucx_dep.found() and cuda_dep.found() and nvcc_prog.found()
}, section: 'UCX GPU Device API', bool_yn: true)
endif

# UCX SGL (scatter-gather list) API detection
if ucx_dep.found()
have_ucx_sgl = cpp.has_type('ucp_dt_local_sgl_t',
prefix: '#include <ucp/api/ucp.h>',
dependencies: ucx_dep)
if have_ucx_sgl
add_project_arguments('-DHAVE_UCX_SGL_API', language: 'cpp')
endif
endif

if get_option('disable_gds_backend')
add_project_arguments('-DDISABLE_GDS_BACKEND', language: 'cpp')
endif
Expand Down
112 changes: 109 additions & 3 deletions src/plugins/ucx/ucx_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "common/nixl_log.h"
#include "serdes/serdes.h"
#include "common/backend.h"
#include "common/configuration.h"
#include "common/nixl_log.h"

#include <optional>
Expand All @@ -32,6 +33,10 @@
#include "absl/strings/str_split.h"
#include <asio.hpp>

// A transfer to a single endpoint posts at most three requests:
// one data request, one flush request, and one notification request.
constexpr size_t single_ep_request_count = 3;

/****************************************
* Backend request management
*****************************************/
Expand Down Expand Up @@ -76,6 +81,25 @@ class nixlUcxBackendReqH : public nixlBackendReqH {

std::optional<Notif> notif;

#ifdef HAVE_UCX_SGL_API
struct sglBuffers {
std::vector<void *> localAddrs;
std::vector<uint64_t> remoteAddrs;
std::vector<size_t> lengths;
std::vector<ucp_mem_h> memhs;
std::vector<ucp_rkey_h> rkeys;

void
resize(size_t count) {
localAddrs.resize(count);
remoteAddrs.resize(count);
lengths.resize(count);
memhs.resize(count);
rkeys.resize(count);
}
} sgl;
#endif

nixlUcxBackendReqH(nixlUcxWorker *worker, size_t worker_id)
: worker_(worker),
workerId_(worker_id) {}
Expand Down Expand Up @@ -826,6 +850,17 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params)
const auto engine_config =
nixl::getBackendParamDefaulted(custom_params, "engine_config", std::string());

sglEnabled_ = nixl::config::getValueOptional<bool>("NIXL_UCX_SGL_ENABLE").value_or(false);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...::getValueDefaulted("NIXL_UCX_SGL_ENABLE", false)

#ifdef HAVE_UCX_SGL_API
NIXL_DEBUG << "UCX SGL offload " << (sglEnabled_ ? "enabled" : "disabled");
#else
if (sglEnabled_) {
NIXL_WARN << "NIXL_UCX_SGL_ENABLE is set but NIXL was built without UCX SGL "
"support";
sglEnabled_ = false;
}
#endif

uc = std::make_unique<nixlUcxContext>(devs,
init_params.enableProgTh,
num_workers,
Expand Down Expand Up @@ -1201,6 +1236,73 @@ nixlUcxEngine::sendXferRangeBatch(nixlUcxEp &ep,
return result;
}

#ifdef HAVE_UCX_SGL_API
nixl_status_t
nixlUcxEngine::sendXferSgl(const nixl_meta_dlist_t &local,
const nixl_meta_dlist_t &remote,
const std::string &remote_agent,
nixlBackendReqH *handle,
size_t start_idx,
size_t end_idx) const {
const auto int_handle = static_cast<nixlUcxBackendReqH *>(handle);
const size_t worker_id = int_handle->getWorkerId();
const ucx_connection_ptr_t conn = getConnection(remote_agent);
if (!conn) {
NIXL_ERROR << "No connection found for remote agent: " << remote_agent;
return NIXL_ERR_NOT_FOUND;
}

const size_t count = end_idx - start_idx;
auto &sgl = int_handle->sgl;
sgl.resize(count);
for (size_t i = start_idx; i < end_idx; ++i) {
const size_t out = i - start_idx;
const auto lmd = static_cast<nixlUcxPrivateMetadata *>(local[i].metadataP);
const auto rmd = static_cast<nixlUcxPublicMetadata *>(remote[i].metadataP);
NIXL_ASSERT(local[i].len == remote[i].len);

sgl.localAddrs[out] = reinterpret_cast<void *>(local[i].addr);
sgl.remoteAddrs[out] = static_cast<uint64_t>(remote[i].addr);
sgl.lengths[out] = local[i].len;
sgl.memhs[out] = lmd->getMem().getMemh();
sgl.rkeys[out] = rmd->getRkey(worker_id).get();
}

const ucp_dt_local_sgl_t local_sgl = {
.field_mask = UCP_DT_LOCAL_SGL_FIELD_BUFFERS | UCP_DT_LOCAL_SGL_FIELD_LENGTHS |
UCP_DT_LOCAL_SGL_FIELD_MEMHS,
.buffers = sgl.localAddrs.data(),
.lengths = sgl.lengths.data(),
.memhs = sgl.memhs.data(),
};
const ucp_dt_remote_sgl_t remote_sgl = {
.field_mask = UCP_DT_REMOTE_SGL_FIELD_REMOTE_ADDRS | UCP_DT_REMOTE_SGL_FIELD_LENGTHS |
UCP_DT_REMOTE_SGL_FIELD_RKEYS,
.remote_addrs = sgl.remoteAddrs.data(),
.lengths = sgl.lengths.data(),
.rkeys = sgl.rkeys.data(),
};

auto &ep = conn->getEp(worker_id);

int_handle->reserve(single_ep_request_count);

nixlUcxReq req;
const nixl_status_t post_ret = ep->postSgl(local_sgl, remote_sgl, count, req);
Comment on lines +1249 to +1293

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🗄️ Data Integrity & Integration | 🟠 Major | 🏗️ Heavy lift

Keep SGL posts scoped to a single UCX connection.

sendXferSgl() posts the entire range through getConnection(remote_agent), but the non-SGL path batches by each descriptor’s rmd->conn and flushes all distinct connections. If a range spans multiple connections, this sends remote addresses/rkeys through the wrong endpoint. Split the SGL path per rmd->conn or fall back unless all descriptors share the same connection.

🧰 Tools
🪛 GitHub Actions: Clang Format Check / 0_clang-format.txt

[error] 1269-1277: clang-format-diff-19 reported formatting changes required (field_mask line wrapping). Run clang-format-diff-19/clang-format to apply formatting.


[error] 1278-1286: clang-format-diff-19 reported formatting changes required (field_mask line wrapping). Run clang-format-diff-19/clang-format to apply formatting.

🪛 GitHub Actions: Clang Format Check / clang-format

[error] 1269-1276: clang-format-diff-19 reported formatting differences (clang format check failed). Run clang-format on this file to match project style.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/plugins/ucx/ucx_backend.cpp` around lines 1249 - 1293, The sendXferSgl()
path is using a single getConnection(remote_agent) for the whole range, which
can post remote addrs/rkeys through the wrong UCX endpoint when descriptors span
multiple connections. Update ucx_backend.cpp so the SGL send is scoped to one
connection by grouping descriptors by rmd->conn (like the non-SGL batching
logic) and posting each group on its own conn->getEp(worker_id), or detect mixed
connections and fall back to the non-SGL path. Keep the existing
local_sgl/remote_sgl setup but build it per connection instead of once for the
whole range.

if (int_handle->append(post_ret, req, conn) != NIXL_SUCCESS) {
return post_ret;
}

nixlUcxReq flush_req;
const nixl_status_t flush_ret = ep->flushEp(flush_req);
if (int_handle->append(flush_ret, flush_req, conn) != NIXL_SUCCESS) {
return flush_ret;
}

return NIXL_SUCCESS;
}
#endif

nixl_status_t
nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation,
const nixl_meta_dlist_t &local,
Expand All @@ -1216,9 +1318,13 @@ nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation,
return NIXL_ERR_INVALID_PARAM;
}

/* Assuming we have a single EP, we need 3 requests: one pending request,
* one flush request, and one notification request */
int_handle->reserve(3);
#ifdef HAVE_UCX_SGL_API
if (sglEnabled_ && operation == NIXL_WRITE) {
return sendXferSgl(local, remote, remote_agent, handle, start_idx, end_idx);
Comment on lines +1323 to +1325

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🗄️ Data Integrity & Integration | 🟠 Major | 🏗️ Heavy lift

Only route to SGL after confirming the range is connection-homogeneous.

This branch bypasses the existing per-endpoint batching logic. Gate it on all remote[start_idx:end_idx] metadata resolving to one connection, or make sendXferSgl() perform the same batching internally.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/plugins/ucx/ucx_backend.cpp` around lines 1323 - 1325, The SGL fast path
in ucx_backend.cpp currently skips the per-endpoint batching logic in the send
path. Update the NIXL_WRITE branch in the transfer routine (the block that calls
sendXferSgl) so it only takes this route when the full remote range
remote[start_idx:end_idx] resolves to a single connection, or move the
batching/connection-homogeneity check into sendXferSgl itself. Use the existing
transfer flow and metadata helpers in this backend to verify the range before
returning the SGL path.

}
#endif

int_handle->reserve(single_ep_request_count);

for (size_t i = start_idx; i < end_idx;) {
/* Send requests to a single EP */
Expand Down
11 changes: 11 additions & 0 deletions src/plugins/ucx/ucx_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,16 @@ class nixlUcxEngine : public nixlBackendEngine {
size_t start_idx,
size_t end_idx);

#ifdef HAVE_UCX_SGL_API
nixl_status_t
sendXferSgl(const nixl_meta_dlist_t &local,
const nixl_meta_dlist_t &remote,
const std::string &remote_agent,
nixlBackendReqH *handle,
size_t start_idx,
size_t end_idx) const;
#endif

/**
* Get the worker ID from the optional arguments.
* Returns std::nullopt if the 'worker_id' option extraction fails.
Expand All @@ -292,6 +302,7 @@ class nixlUcxEngine : public nixlBackendEngine {
std::vector<std::unique_ptr<nixlUcxWorker>> uws;
std::string workerAddr;
mutable std::atomic<size_t> sharedWorkerIndex_;
bool sglEnabled_ = false;

// Map of agent name to saved nixlUcxConnection info
std::unordered_map<std::string, ucx_connection_ptr_t> remoteConnMap;
Expand Down
32 changes: 32 additions & 0 deletions src/plugins/ucx/ucx_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,38 @@ nixlUcxEp::flushEp(nixlUcxReq &req) {
return nixl::ucx::ucsToNixlStatus(UCS_PTR_STATUS(request));
}

#ifdef HAVE_UCX_SGL_API
nixl_status_t
nixlUcxEp::postSgl(const ucp_dt_local_sgl_t &local,
const ucp_dt_remote_sgl_t &remote,
size_t count,
nixlUcxReq &req) {
const nixl_status_t status = checkTxState();
if (status != NIXL_SUCCESS) {
return status;
}

const ucp_request_param_t param = {
.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_REMOTE_DATATYPE |
UCP_OP_ATTR_FIELD_REMOTE | UCP_OP_ATTR_FIELD_REMOTE_COUNT,
.datatype = ucp_dt_make_sgl(),
.remote_datatype = ucp_dt_make_sgl(),
.remote = &remote,
.remote_count = count,
};

const ucs_status_ptr_t request =
ucp_put_nbx(eph, &local, count, UCP_REMOTE_ADDR_INVALID, UCP_RKEY_INVALID, &param);
if (UCS_PTR_IS_PTR(request)) {
req = static_cast<nixlUcxReq>(request);
return NIXL_IN_PROG;
}

req = nullptr;
return nixl::ucx::ucsToNixlStatus(UCS_PTR_STATUS(request));
}
#endif

bool
nixlUcxMtLevelIsSupported(const nixl::ucx::mt_mode_t mt_type) noexcept {
ucp_lib_attr_t attr;
Expand Down
8 changes: 8 additions & 0 deletions src/plugins/ucx/ucx_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ class nixlUcxEp {
nixl_status_t
flushEp(nixlUcxReq &req);

#ifdef HAVE_UCX_SGL_API
[[nodiscard]] nixl_status_t
postSgl(const ucp_dt_local_sgl_t &local,
const ucp_dt_remote_sgl_t &remote,
size_t count,
nixlUcxReq &req);
Comment on lines +112 to +117

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Document the SGL buffer lifetime contract on the public endpoint API.

postSgl() stores UCX pointers into caller-owned SGL arrays, so callers must know those arrays must outlive request completion. Add a Doxygen block for the new public API. As per path instructions, “Use Doxygen block comments (/** ... */) for public APIs.”

Suggested documentation
 `#ifdef` HAVE_UCX_SGL_API
+    /**
+     * Post a UCX put using local and remote scatter-gather descriptors.
+     *
+     * The arrays referenced by `local` and `remote` must remain valid until
+     * the returned UCX request completes.
+     */
     [[nodiscard]] nixl_status_t
     postSgl(const ucp_dt_local_sgl_t &local,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
#ifdef HAVE_UCX_SGL_API
[[nodiscard]] nixl_status_t
postSgl(const ucp_dt_local_sgl_t &local,
const ucp_dt_remote_sgl_t &remote,
size_t count,
nixlUcxReq &req);
`#ifdef` HAVE_UCX_SGL_API
/**
* Post a UCX put using local and remote scatter-gather descriptors.
*
* The arrays referenced by `local` and `remote` must remain valid until
* the returned UCX request completes.
*/
[[nodiscard]] nixl_status_t
postSgl(const ucp_dt_local_sgl_t &local,
const ucp_dt_remote_sgl_t &remote,
size_t count,
nixlUcxReq &req);
`#endif`
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/plugins/ucx/ucx_utils.h` around lines 112 - 117, The public endpoint API
`postSgl()` in `ucx_utils.h` needs a Doxygen block documenting the SGL lifetime
contract. Add a `/** ... */` comment above `postSgl()` that clearly states the
caller-owned `ucp_dt_local_sgl_t` and `ucp_dt_remote_sgl_t` arrays must remain
valid until the associated request completes, so users know the stored UCX
pointers must outlive completion. Use the function name `postSgl()` and the
`nixlUcxReq` request type in the comment to make the ownership and completion
requirement explicit.

Source: Path instructions

#endif

[[nodiscard]] ucp_ep_h
getEp() const noexcept {
return eph;
Expand Down
Loading