-
Notifications
You must be signed in to change notification settings - Fork 358
PLUGINS/UCX: add scatter-gather (SGL) put path #1835
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| #include "common/nixl_log.h" | ||
| #include "serdes/serdes.h" | ||
| #include "common/backend.h" | ||
| #include "common/configuration.h" | ||
| #include "common/nixl_log.h" | ||
|
|
||
| #include <optional> | ||
|
|
@@ -32,6 +33,10 @@ | |
| #include "absl/strings/str_split.h" | ||
| #include <asio.hpp> | ||
|
|
||
| // A transfer to a single endpoint posts at most three requests: | ||
| // one data request, one flush request, and one notification request. | ||
| constexpr size_t single_ep_request_count = 3; | ||
|
|
||
| /**************************************** | ||
| * Backend request management | ||
| *****************************************/ | ||
|
|
@@ -76,6 +81,25 @@ class nixlUcxBackendReqH : public nixlBackendReqH { | |
|
|
||
| std::optional<Notif> notif; | ||
|
|
||
| #ifdef HAVE_UCX_SGL_API | ||
// Scratch storage for one scatter-gather (SGL) transfer, kept as parallel
// arrays: slot i of every vector describes the same descriptor.
// sendXferSgl() fills the slots and passes the vectors' data() pointers to
// the UCX SGL descriptors.
//   localAddrs  - local buffer addresses
//   remoteAddrs - remote addresses
//   lengths     - per-entry length (used for both the local and remote side)
//   memhs       - local memory handles
//   rkeys       - remote keys
// NOTE(review): presumably stored on the request handle so the arrays outlive
// the posted request -- confirm the lifetime requirement against the UCX API.
| struct sglBuffers { | ||
| std::vector<void *> localAddrs; | ||
| std::vector<uint64_t> remoteAddrs; | ||
| std::vector<size_t> lengths; | ||
| std::vector<ucp_mem_h> memhs; | ||
| std::vector<ucp_rkey_h> rkeys; | ||
|
|
||
// Resize all five vectors to `count` so they keep the same number of slots.
| void | ||
| resize(size_t count) { | ||
| localAddrs.resize(count); | ||
| remoteAddrs.resize(count); | ||
| lengths.resize(count); | ||
| memhs.resize(count); | ||
| rkeys.resize(count); | ||
| } | ||
| } sgl; | ||
| #endif | ||
|
|
||
// Construct a request handle from the caller's worker pointer and worker
// index; only worker_ and workerId_ are set explicitly here.
| nixlUcxBackendReqH(nixlUcxWorker *worker, size_t worker_id) | ||
| : worker_(worker), | ||
| workerId_(worker_id) {} | ||
|
|
@@ -826,6 +850,17 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params) | |
| const auto engine_config = | ||
| nixl::getBackendParamDefaulted(custom_params, "engine_config", std::string()); | ||
|
|
||
| sglEnabled_ = nixl::config::getValueOptional<bool>("NIXL_UCX_SGL_ENABLE").value_or(false); | ||
| #ifdef HAVE_UCX_SGL_API | ||
| NIXL_DEBUG << "UCX SGL offload " << (sglEnabled_ ? "enabled" : "disabled"); | ||
| #else | ||
| if (sglEnabled_) { | ||
| NIXL_WARN << "NIXL_UCX_SGL_ENABLE is set but NIXL was built without UCX SGL " | ||
| "support"; | ||
| sglEnabled_ = false; | ||
| } | ||
| #endif | ||
|
|
||
| uc = std::make_unique<nixlUcxContext>(devs, | ||
| init_params.enableProgTh, | ||
| num_workers, | ||
|
|
@@ -1201,6 +1236,73 @@ nixlUcxEngine::sendXferRangeBatch(nixlUcxEp &ep, | |
| return result; | ||
| } | ||
|
|
||
| #ifdef HAVE_UCX_SGL_API | ||
// Post descriptors [start_idx, end_idx) of `local`/`remote` to `remote_agent`
// as one scatter-gather request, then request an endpoint flush after it.
//
//   handle    - backend request handle; cast to nixlUcxBackendReqH. It supplies
//               the worker id and the `sgl` scratch arrays, and collects the
//               posted requests.
//   start_idx - first descriptor index (inclusive); both lists use the same i.
//   end_idx   - last descriptor index (exclusive).
//
// Returns NIXL_ERR_NOT_FOUND when getConnection() yields no connection, the
// post or flush status when append() reports a non-success result, and
// NIXL_SUCCESS otherwise.
//
// NOTE(review): count is computed as end_idx - start_idx on size_t, so callers
// are presumably expected to pass start_idx <= end_idx -- confirm.
// NOTE(review): every entry goes through the single connection resolved from
// remote_agent and a single endpoint; the review comment embedded further down
// asks that SGL posts stay scoped to one UCX connection -- confirm the range
// cannot span several.
| nixl_status_t | ||
| nixlUcxEngine::sendXferSgl(const nixl_meta_dlist_t &local, | ||
| const nixl_meta_dlist_t &remote, | ||
| const std::string &remote_agent, | ||
| nixlBackendReqH *handle, | ||
| size_t start_idx, | ||
| size_t end_idx) const { | ||
| const auto int_handle = static_cast<nixlUcxBackendReqH *>(handle); | ||
| const size_t worker_id = int_handle->getWorkerId(); | ||
| const ucx_connection_ptr_t conn = getConnection(remote_agent); | ||
| if (!conn) { | ||
| NIXL_ERROR << "No connection found for remote agent: " << remote_agent; | ||
| return NIXL_ERR_NOT_FOUND; | ||
| } | ||
|
|
||
// Size the handle's scratch arrays to the number of descriptors in the range.
| const size_t count = end_idx - start_idx; | ||
| auto &sgl = int_handle->sgl; | ||
| sgl.resize(count); | ||
// Flatten each descriptor pair into slot (i - start_idx) of the parallel arrays.
| for (size_t i = start_idx; i < end_idx; ++i) { | ||
| const size_t out = i - start_idx; | ||
| const auto lmd = static_cast<nixlUcxPrivateMetadata *>(local[i].metadataP); | ||
| const auto rmd = static_cast<nixlUcxPublicMetadata *>(remote[i].metadataP); | ||
// Local and remote lengths must match: one lengths array serves both sides.
| NIXL_ASSERT(local[i].len == remote[i].len); | ||
|
|
||
| sgl.localAddrs[out] = reinterpret_cast<void *>(local[i].addr); | ||
| sgl.remoteAddrs[out] = static_cast<uint64_t>(remote[i].addr); | ||
| sgl.lengths[out] = local[i].len; | ||
| sgl.memhs[out] = lmd->getMem().getMemh(); | ||
// The remote key is looked up for this handle's worker.
| sgl.rkeys[out] = rmd->getRkey(worker_id).get(); | ||
| } | ||
|
|
||
// Both descriptors below hold pointers into sgl's vectors; nothing is copied here.
| const ucp_dt_local_sgl_t local_sgl = { | ||
| .field_mask = UCP_DT_LOCAL_SGL_FIELD_BUFFERS | UCP_DT_LOCAL_SGL_FIELD_LENGTHS | | ||
| UCP_DT_LOCAL_SGL_FIELD_MEMHS, | ||
| .buffers = sgl.localAddrs.data(), | ||
| .lengths = sgl.lengths.data(), | ||
| .memhs = sgl.memhs.data(), | ||
| }; | ||
| const ucp_dt_remote_sgl_t remote_sgl = { | ||
| .field_mask = UCP_DT_REMOTE_SGL_FIELD_REMOTE_ADDRS | UCP_DT_REMOTE_SGL_FIELD_LENGTHS | | ||
| UCP_DT_REMOTE_SGL_FIELD_RKEYS, | ||
| .remote_addrs = sgl.remoteAddrs.data(), | ||
| .lengths = sgl.lengths.data(), | ||
| .rkeys = sgl.rkeys.data(), | ||
| }; | ||
|
|
||
// Endpoint of this connection for the handle's worker.
| auto &ep = conn->getEp(worker_id); | ||
|
|
||
// Reserve room on the handle using the file-level single_ep_request_count constant.
| int_handle->reserve(single_ep_request_count); | ||
|
|
||
| nixlUcxReq req; | ||
| const nixl_status_t post_ret = ep->postSgl(local_sgl, remote_sgl, count, req); | ||
|
Comment on lines
+1249
to
+1293
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🗄️ Data Integrity & Integration | 🟠 Major | 🏗️ Heavy lift Keep SGL posts scoped to a single UCX connection.
🧰 Tools🪛 GitHub Actions: Clang Format Check / 0_clang-format.txt[error] 1269-1277: clang-format-diff-19 reported formatting changes required (field_mask line wrapping). Run clang-format-diff-19/clang-format to apply formatting. [error] 1278-1286: clang-format-diff-19 reported formatting changes required (field_mask line wrapping). Run clang-format-diff-19/clang-format to apply formatting. 🪛 GitHub Actions: Clang Format Check / clang-format[error] 1269-1276: clang-format-diff-19 reported formatting differences (clang format check failed). Run clang-format on this file to match project style. 🤖 Prompt for AI Agents |
||
// On a non-success append() result, the original post status is what is returned.
| if (int_handle->append(post_ret, req, conn) != NIXL_SUCCESS) { | ||
| return post_ret; | ||
| } | ||
|
|
||
// Flush the endpoint after the post; the flush request is tracked on the same handle.
| nixlUcxReq flush_req; | ||
| const nixl_status_t flush_ret = ep->flushEp(flush_req); | ||
| if (int_handle->append(flush_ret, flush_req, conn) != NIXL_SUCCESS) { | ||
| return flush_ret; | ||
| } | ||
|
|
||
| return NIXL_SUCCESS; | ||
| } | ||
| #endif | ||
|
|
||
| nixl_status_t | ||
| nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation, | ||
| const nixl_meta_dlist_t &local, | ||
|
|
@@ -1216,9 +1318,13 @@ nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation, | |
| return NIXL_ERR_INVALID_PARAM; | ||
| } | ||
|
|
||
| /* Assuming we have a single EP, we need 3 requests: one pending request, | ||
| * one flush request, and one notification request */ | ||
| int_handle->reserve(3); | ||
| #ifdef HAVE_UCX_SGL_API | ||
| if (sglEnabled_ && operation == NIXL_WRITE) { | ||
| return sendXferSgl(local, remote, remote_agent, handle, start_idx, end_idx); | ||
|
Comment on lines
+1323
to
+1325
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🗄️ Data Integrity & Integration | 🟠 Major | 🏗️ Heavy lift — Only route to SGL after confirming the range is connection-homogeneous. This branch bypasses the existing per-endpoint batching logic. Gate it on all [… the rest of this sentence is truncated in the captured page] 🤖 Prompt for AI Agents |
||
| } | ||
| #endif | ||
|
|
||
| int_handle->reserve(single_ep_request_count); | ||
|
|
||
| for (size_t i = start_idx; i < end_idx;) { | ||
| /* Send requests to a single EP */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -109,6 +109,14 @@ class nixlUcxEp { | |||||||||||||||||||||||||||||||||||||||
| nixl_status_t | ||||||||||||||||||||||||||||||||||||||||
| flushEp(nixlUcxReq &req); | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| #ifdef HAVE_UCX_SGL_API | ||||||||||||||||||||||||||||||||||||||||
// Post one scatter-gather request on this endpoint.
//   local  - local-side SGL descriptor
//   remote - remote-side SGL descriptor
//   count  - number of entries (sendXferSgl passes the descriptor count)
//   req    - receives the request object for the posted operation
// The returned status must not be ignored ([[nodiscard]]).
// NOTE(review): the arrays referenced by `local` and `remote` presumably have
// to stay valid until the request completes -- confirm against the UCX SGL API.
| [[nodiscard]] nixl_status_t | ||||||||||||||||||||||||||||||||||||||||
| postSgl(const ucp_dt_local_sgl_t &local, | ||||||||||||||||||||||||||||||||||||||||
| const ucp_dt_remote_sgl_t &remote, | ||||||||||||||||||||||||||||||||||||||||
| size_t count, | ||||||||||||||||||||||||||||||||||||||||
| nixlUcxReq &req); | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+112
to
+117
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win — Document the SGL buffer lifetime contract on the public endpoint API.
Suggested documentation:
#ifdef HAVE_UCX_SGL_API
+ /**
+ * Post a UCX put using local and remote scatter-gather descriptors.
+ *
+ * The arrays referenced by `local` and `remote` must remain valid until
+ * the returned UCX request completes.
+ */
[[nodiscard]] nixl_status_t
postSgl(const ucp_dt_local_sgl_t &local,
📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents · Source: Path instructions |
||||||||||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| [[nodiscard]] ucp_ep_h | ||||||||||||||||||||||||||||||||||||||||
| getEp() const noexcept { | ||||||||||||||||||||||||||||||||||||||||
| return eph; | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
...::getValueDefaulted("NIXL_UCX_SGL_ENABLE", false)