Skip to content

Commit 6791d49

Browse files
committed
[ET Device Support] CUDA-native Qwen 3.5 MoE inference with device tensor pipeline
Pull Request resolved: #18788

Integrate the ET device tensor pipeline into the Qwen 3.5 MoE model to eliminate unnecessary H2D/D2H copies during inference.

- Export: Multi-method export (`forward` + `sample`) with device memory planning enabled and method-level H2D/D2H skipping.
- Runner: Custom CUDA-native inference loop that keeps logits on GPU between forward and sample, reuses CUDA tensors across iterations, and only copies the 8-byte token ID back to CPU for EOS checking.

ghstack-source-id: 392754115
@exported-using-ghexport

Differential Revision: [D100133933](https://our.internmc.facebook.com/intern/diff/D100133933/)
1 parent 1060d28 commit 6791d49

3 files changed

Lines changed: 109 additions & 14 deletions

File tree

examples/models/qwen3_5_moe/export.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,7 @@ def _export_cuda(model, config, args):
942942
)
943943
from executorch.exir.backend.compile_spec_schema import CompileSpec
944944
from executorch.exir.passes import MemoryPlanningPass
945+
from executorch.exir.passes.propagate_device_pass import PropagateDeviceConfig
945946
from torch.export import Dim, export
946947

947948
# Coordinate descent recompiles each kernel trying config perturbations,
@@ -1038,7 +1039,10 @@ def _export_cuda(model, config, args):
10381039
extract_delegate_segments=True,
10391040
do_quant_fusion_and_const_prop=True,
10401041
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
1041-
emit_mutable_buffer_names=True,
1042+
propagate_device_config=PropagateDeviceConfig(
1043+
skip_h2d_for_method_inputs=True,
1044+
skip_d2h_for_method_outputs=True,
1045+
),
10421046
),
10431047
)
10441048

examples/models/qwen3_5_moe/main.cpp

Lines changed: 102 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,18 @@
1313
#include <executorch/extension/llm/runner/util.h>
1414
#include <executorch/extension/module/module.h>
1515
#include <executorch/extension/tensor/tensor.h>
16+
#include <executorch/extension/tensor/tensor_ptr.h>
1617
#include <executorch/runtime/backend/interface.h>
1718
#include <executorch/runtime/backend/options.h>
19+
#include <executorch/runtime/core/portable_type/device.h>
20+
#include <executorch/runtime/platform/assert.h>
1821
#include <executorch/runtime/platform/log.h>
1922
#include <pytorch/tokenizers/hf_tokenizer.h>
2023

2124
#include <algorithm>
2225
#include <cinttypes>
2326
#include <fstream>
27+
#include <numeric>
2428
#include <string>
2529
#include <vector>
2630

@@ -51,14 +55,22 @@ using ::executorch::extension::Module;
5155
using ::executorch::extension::TensorPtr;
5256
using ::executorch::runtime::Error;
5357
using ::executorch::runtime::EValue;
58+
#ifdef EXECUTORCH_BUILD_CUDA
59+
using ::executorch::extension::clone_tensor_ptr_to;
60+
#endif
5461

5562
using SizesType = executorch::aten::SizesType;
5663

5764
// Convert a model output tensor to the next sampled token id.
5865
//
5966
// On the CUDA build, the model fuses the sampler in (see sampler.py /
6067
// Qwen35MoE.forward) and returns a single sampled token id as a [B, 1]
61-
// float tensor; we just copy that scalar back from device.
68+
// int64 tensor that lives in CUDA device memory (skip_d2h keeps method
69+
// outputs on-device). We copy just that 8-byte scalar back to host — this
70+
// is the only device->host transfer per decode step, needed for EOS
71+
// detection and streaming detokenization. The token is fed to the next
72+
// step device->device (see the decode loop), so no host round-trip occurs
73+
// for the model input.
6274
//
6375
// On non-CUDA builds (Metal / MLX / CPU), the model returns raw logits
6476
// of shape [B, T, V] in the model dtype (typically bf16). We sample on
@@ -72,10 +84,10 @@ static uint64_t read_token(const executorch::aten::Tensor& output) {
7284
bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess &&
7385
attrs.type == cudaMemoryTypeDevice;
7486

75-
float val;
87+
int64_t val;
7688
if (on_device) {
7789
cudaError_t err =
78-
cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost);
90+
cudaMemcpy(&val, ptr, sizeof(int64_t), cudaMemcpyDeviceToHost);
7991
if (err != cudaSuccess) {
8092
ET_LOG(
8193
Error,
@@ -84,7 +96,7 @@ static uint64_t read_token(const executorch::aten::Tensor& output) {
8496
return 0;
8597
}
8698
} else {
87-
memcpy(&val, ptr, sizeof(float));
99+
memcpy(&val, ptr, sizeof(int64_t));
88100
}
89101
return static_cast<uint64_t>(val);
90102
#else
@@ -272,10 +284,20 @@ int main(int argc, char** argv) {
272284
// a third input. Use a very small temperature for greedy to avoid
273285
// division by zero while keeping the Gumbel noise negligible relative
274286
// to logit differences.
287+
//
288+
// The export lowered this program with skip_h2d_for_method_inputs=True,
289+
// so the CUDA backend requires every method input to already live in
290+
// CUDA device memory (no host->device copy is inserted in the graph).
291+
// We therefore stage all inputs on-device via clone_tensor_ptr_to. The
292+
// temperature is constant, so it is cloned to the device exactly once
293+
// and reused for prefill and every decode step.
294+
auto cuda_device =
295+
executorch::aten::Device(executorch::aten::DeviceType::CUDA, 0);
275296
float temp_val =
276297
FLAGS_temperature <= 0.0 ? 1e-6f : static_cast<float>(FLAGS_temperature);
277-
auto temp_tensor =
278-
from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
298+
auto temp_tensor = clone_tensor_ptr_to(
299+
from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float),
300+
cuda_device);
279301
#endif
280302

281303
stats.inference_start_ms = llm::time_in_ms();
@@ -298,14 +320,22 @@ int main(int argc, char** argv) {
298320
pos_data[i] = i;
299321
}
300322
std::vector<int64_t> token_data(prompt_tokens.begin(), prompt_tokens.end());
301-
auto tokens_tensor = from_blob(
323+
auto tokens_cpu = from_blob(
302324
token_data.data(),
303325
{1, S(num_prompt_tokens)},
304326
executorch::aten::ScalarType::Long);
305-
auto pos_tensor = from_blob(
327+
auto pos_cpu = from_blob(
306328
pos_data.data(),
307329
{S(num_prompt_tokens)},
308330
executorch::aten::ScalarType::Long);
331+
#ifdef EXECUTORCH_BUILD_CUDA
332+
// Stage prefill inputs in CUDA device memory (see temperature note above).
333+
auto tokens_tensor = clone_tensor_ptr_to(tokens_cpu, cuda_device);
334+
auto pos_tensor = clone_tensor_ptr_to(pos_cpu, cuda_device);
335+
#else
336+
auto tokens_tensor = tokens_cpu;
337+
auto pos_tensor = pos_cpu;
338+
#endif
309339

310340
std::vector<EValue> prefill_inputs;
311341
prefill_inputs.push_back(tokens_tensor);
@@ -348,14 +378,57 @@ int main(int argc, char** argv) {
348378

349379
std::vector<int64_t> decode_token_data = {static_cast<int64_t>(cur_token)};
350380
std::vector<int64_t> decode_pos_data = {pos};
351-
auto decode_tokens = from_blob(
381+
auto decode_tokens_cpu = from_blob(
352382
decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long);
353-
auto decode_pos = from_blob(
383+
auto decode_pos_cpu = from_blob(
354384
decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long);
385+
#ifdef EXECUTORCH_BUILD_CUDA
386+
// Device-resident decode loop. The decode method's token input and its
387+
// fused sampled-token output are both int64 [1,1] living in CUDA memory
388+
// (skip_h2d on inputs, skip_d2h on outputs). We keep fixed device buffers
389+
// (CUDA graph requires stable input addresses) and feed each step's output
390+
// straight into the next step's token input with a device->device copy —
391+
// no host round-trip for the model I/O. The initial clone seeds
392+
// decode_tokens with the prefill-sampled token (one-time H2D at setup).
393+
auto decode_tokens = clone_tensor_ptr_to(decode_tokens_cpu, cuda_device);
394+
auto decode_pos = clone_tensor_ptr_to(decode_pos_cpu, cuda_device);
395+
396+
// Precompute every decode position on-device with a SINGLE H2D up front, so
397+
// the per-step position update becomes a device->device copy (no per-step
398+
// H2D). positions[k] = num_prompt_tokens + k.
399+
std::vector<int64_t> all_pos_data(FLAGS_max_new_tokens);
400+
std::iota(all_pos_data.begin(), all_pos_data.end(), pos);
401+
auto all_pos = clone_tensor_ptr_to(
402+
from_blob(
403+
all_pos_data.data(),
404+
{S(FLAGS_max_new_tokens)},
405+
executorch::aten::ScalarType::Long),
406+
cuda_device);
407+
const auto* all_pos_dev =
408+
static_cast<const int64_t*>(all_pos->const_data_ptr());
409+
#else
410+
auto decode_tokens = decode_tokens_cpu;
411+
auto decode_pos = decode_pos_cpu;
412+
#endif
355413

356414
for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) {
415+
#ifdef EXECUTORCH_BUILD_CUDA
416+
// Set this step's position via device->device copy from the precomputed
417+
// on-device array (no per-step H2D). The token input (decode_tokens)
418+
// already holds the token to feed: the prefill-sampled token on step 0,
419+
// and the previous step's output (copied in device->device at the end of
420+
// the prior iteration) on every later step.
421+
ET_CHECK_MSG(
422+
cudaMemcpy(
423+
decode_pos->mutable_data_ptr(),
424+
all_pos_dev + step,
425+
sizeof(int64_t),
426+
cudaMemcpyDeviceToDevice) == cudaSuccess,
427+
"Failed to set decode position device-to-device");
428+
#else
357429
decode_token_data[0] = static_cast<int64_t>(cur_token);
358430
decode_pos_data[0] = pos;
431+
#endif
359432

360433
std::vector<EValue> decode_inputs;
361434
decode_inputs.push_back(EValue(decode_tokens));
@@ -370,9 +443,27 @@ int main(int argc, char** argv) {
370443
return 1;
371444
}
372445
auto& decode_outputs = decode_result.get();
446+
const auto& out_tensor = decode_outputs[0].toTensor();
373447

374448
prev_token = cur_token;
375-
cur_token = read_token(decode_outputs[0].toTensor());
449+
// Single per-step device->host copy: the 8-byte sampled token id, needed
450+
// for EOS detection and streaming detokenization below.
451+
cur_token = read_token(out_tensor);
452+
453+
#ifdef EXECUTORCH_BUILD_CUDA
454+
// Feed this step's sampled token straight into the next step's token input
455+
// on-device (device->device). This replaces the old host re-upload (H2D)
456+
// and, together with read_token's D2H above, leaves exactly one 8-byte
457+
// D2H and zero H2D per decode step. read_token's synchronous D2H has
458+
// already forced the output to be ready, so the copy below is well-ordered.
459+
ET_CHECK_MSG(
460+
cudaMemcpy(
461+
decode_tokens->mutable_data_ptr(),
462+
out_tensor.const_data_ptr(),
463+
sizeof(int64_t),
464+
cudaMemcpyDeviceToDevice) == cudaSuccess,
465+
"Failed to feed decode token device-to-device");
466+
#endif
376467

377468
if (step == 0) {
378469
stats.first_token_ms = llm::time_in_ms();

examples/models/qwen3_5_moe/sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def sample(
4242
sampler returns the unmodified ``logits`` tensor.
4343
4444
Returns:
45-
``[B, 1]`` float32 tensor of sampled token IDs, or the unmodified
45+
``[B, 1]`` int64 tensor of sampled token IDs, or the unmodified
4646
``logits`` tensor when ``temperature`` is ``None``.
4747
"""
4848
# No sampling configured — return raw logits.
@@ -57,4 +57,4 @@ def sample(
5757
# float32 note in the docstring.
5858
noise = torch.rand_like(logits)
5959
gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
60-
return (logits + gumbel).argmax(dim=-1, keepdim=True).float()
60+
return (logits + gumbel).argmax(dim=-1, keepdim=True)

0 commit comments

Comments
 (0)