1313#include < executorch/extension/llm/runner/util.h>
1414#include < executorch/extension/module/module.h>
1515#include < executorch/extension/tensor/tensor.h>
16+ #include < executorch/extension/tensor/tensor_ptr.h>
1617#include < executorch/runtime/backend/interface.h>
1718#include < executorch/runtime/backend/options.h>
19+ #include < executorch/runtime/core/portable_type/device.h>
20+ #include < executorch/runtime/platform/assert.h>
1821#include < executorch/runtime/platform/log.h>
1922#include < pytorch/tokenizers/hf_tokenizer.h>
2023
2124#include < algorithm>
2225#include < cinttypes>
2326#include < fstream>
27+ #include < numeric>
2428#include < string>
2529#include < vector>
2630
@@ -51,14 +55,22 @@ using ::executorch::extension::Module;
5155using ::executorch::extension::TensorPtr;
5256using ::executorch::runtime::Error;
5357using ::executorch::runtime::EValue;
58+ #ifdef EXECUTORCH_BUILD_CUDA
59+ using ::executorch::extension::clone_tensor_ptr_to;
60+ #endif
5461
5562using SizesType = executorch::aten::SizesType;
5663
5764// Convert a model output tensor to the next sampled token id.
5865//
5966// On the CUDA build, the model fuses the sampler in (see sampler.py /
6067// Qwen35MoE.forward) and returns a single sampled token id as a [B, 1]
61- // float tensor; we just copy that scalar back from device.
68+ // int64 tensor that lives in CUDA device memory (skip_d2h keeps method
69+ // outputs on-device). We copy just that 8-byte scalar back to host — this
70+ // is the only device->host transfer per decode step, needed for EOS
71+ // detection and streaming detokenization. The token is fed to the next
72+ // step device->device (see the decode loop), so no host round-trip occurs
73+ // for the model input.
6274//
6375// On non-CUDA builds (Metal / MLX / CPU), the model returns raw logits
6476// of shape [B, T, V] in the model dtype (typically bf16). We sample on
@@ -72,10 +84,10 @@ static uint64_t read_token(const executorch::aten::Tensor& output) {
7284 bool on_device = cudaPointerGetAttributes (&attrs, ptr) == cudaSuccess &&
7385 attrs.type == cudaMemoryTypeDevice;
7486
75- float val;
87+ int64_t val;
7688 if (on_device) {
7789 cudaError_t err =
78- cudaMemcpy (&val, ptr, sizeof (float ), cudaMemcpyDeviceToHost);
90+ cudaMemcpy (&val, ptr, sizeof (int64_t ), cudaMemcpyDeviceToHost);
7991 if (err != cudaSuccess) {
8092 ET_LOG (
8193 Error,
@@ -84,7 +96,7 @@ static uint64_t read_token(const executorch::aten::Tensor& output) {
8496 return 0 ;
8597 }
8698 } else {
87- memcpy (&val, ptr, sizeof (float ));
99+ memcpy (&val, ptr, sizeof (int64_t ));
88100 }
89101 return static_cast <uint64_t >(val);
90102#else
@@ -272,10 +284,20 @@ int main(int argc, char** argv) {
272284 // a third input. Use a very small temperature for greedy to avoid
273285 // division by zero while keeping the Gumbel noise negligible relative
274286 // to logit differences.
287+ //
288+ // The export lowered this program with skip_h2d_for_method_inputs=True,
289+ // so the CUDA backend requires every method input to already live in
290+ // CUDA device memory (no host->device copy is inserted in the graph).
291+ // We therefore stage all inputs on-device via clone_tensor_ptr_to. The
292+ // temperature is constant, so it is cloned to the device exactly once
293+ // and reused for prefill and every decode step.
294+ auto cuda_device =
295+ executorch::aten::Device (executorch::aten::DeviceType::CUDA , 0 );
275296 float temp_val =
276297 FLAGS_temperature <= 0.0 ? 1e-6f : static_cast <float >(FLAGS_temperature);
277- auto temp_tensor =
278- from_blob (&temp_val, {1 }, executorch::aten::ScalarType::Float);
298+ auto temp_tensor = clone_tensor_ptr_to (
299+ from_blob (&temp_val, {1 }, executorch::aten::ScalarType::Float),
300+ cuda_device);
279301#endif
280302
281303 stats.inference_start_ms = llm::time_in_ms ();
@@ -298,14 +320,22 @@ int main(int argc, char** argv) {
298320 pos_data[i] = i;
299321 }
300322 std::vector<int64_t > token_data (prompt_tokens.begin (), prompt_tokens.end ());
301- auto tokens_tensor = from_blob (
323+ auto tokens_cpu = from_blob (
302324 token_data.data (),
303325 {1 , S (num_prompt_tokens)},
304326 executorch::aten::ScalarType::Long);
305- auto pos_tensor = from_blob (
327+ auto pos_cpu = from_blob (
306328 pos_data.data (),
307329 {S (num_prompt_tokens)},
308330 executorch::aten::ScalarType::Long);
331+ #ifdef EXECUTORCH_BUILD_CUDA
332+ // Stage prefill inputs in CUDA device memory (see temperature note above).
333+ auto tokens_tensor = clone_tensor_ptr_to (tokens_cpu, cuda_device);
334+ auto pos_tensor = clone_tensor_ptr_to (pos_cpu, cuda_device);
335+ #else
336+ auto tokens_tensor = tokens_cpu;
337+ auto pos_tensor = pos_cpu;
338+ #endif
309339
310340 std::vector<EValue> prefill_inputs;
311341 prefill_inputs.push_back (tokens_tensor);
@@ -348,14 +378,57 @@ int main(int argc, char** argv) {
348378
349379 std::vector<int64_t > decode_token_data = {static_cast <int64_t >(cur_token)};
350380 std::vector<int64_t > decode_pos_data = {pos};
351- auto decode_tokens = from_blob (
381+ auto decode_tokens_cpu = from_blob (
352382 decode_token_data.data (), {1 , 1 }, executorch::aten::ScalarType::Long);
353- auto decode_pos = from_blob (
383+ auto decode_pos_cpu = from_blob (
354384 decode_pos_data.data (), {1 }, executorch::aten::ScalarType::Long);
385+ #ifdef EXECUTORCH_BUILD_CUDA
386+ // Device-resident decode loop. The decode method's token input and its
387+ // fused sampled-token output are both int64 [1,1] living in CUDA memory
388+ // (skip_h2d on inputs, skip_d2h on outputs). We keep fixed device buffers
389+ // (CUDA graph requires stable input addresses) and feed each step's output
390+ // straight into the next step's token input with a device->device copy —
391+ // no host round-trip for the model I/O. The initial clone seeds
392+ // decode_tokens with the prefill-sampled token (one-time H2D at setup).
393+ auto decode_tokens = clone_tensor_ptr_to (decode_tokens_cpu, cuda_device);
394+ auto decode_pos = clone_tensor_ptr_to (decode_pos_cpu, cuda_device);
395+
396+ // Precompute every decode position on-device with a SINGLE H2D up front, so
397+ // the per-step position update becomes a device->device copy (no per-step
398+ // H2D). positions[k] = num_prompt_tokens + k.
399+ std::vector<int64_t > all_pos_data (FLAGS_max_new_tokens);
400+ std::iota (all_pos_data.begin (), all_pos_data.end (), pos);
401+ auto all_pos = clone_tensor_ptr_to (
402+ from_blob (
403+ all_pos_data.data (),
404+ {S (FLAGS_max_new_tokens)},
405+ executorch::aten::ScalarType::Long),
406+ cuda_device);
407+ const auto * all_pos_dev =
408+ static_cast <const int64_t *>(all_pos->const_data_ptr ());
409+ #else
410+ auto decode_tokens = decode_tokens_cpu;
411+ auto decode_pos = decode_pos_cpu;
412+ #endif
355413
356414 for (int32_t step = 0 ; step < FLAGS_max_new_tokens; step++) {
415+ #ifdef EXECUTORCH_BUILD_CUDA
416+ // Set this step's position via device->device copy from the precomputed
417+ // on-device array (no per-step H2D). The token input (decode_tokens)
418+ // already holds the token to feed: the prefill-sampled token on step 0,
419+ // and the previous step's output (copied in device->device at the end of
420+ // the prior iteration) on every later step.
421+ ET_CHECK_MSG (
422+ cudaMemcpy (
423+ decode_pos->mutable_data_ptr (),
424+ all_pos_dev + step,
425+ sizeof (int64_t ),
426+ cudaMemcpyDeviceToDevice) == cudaSuccess,
427+ " Failed to set decode position device-to-device" );
428+ #else
357429 decode_token_data[0 ] = static_cast <int64_t >(cur_token);
358430 decode_pos_data[0 ] = pos;
431+ #endif
359432
360433 std::vector<EValue> decode_inputs;
361434 decode_inputs.push_back (EValue (decode_tokens));
@@ -370,9 +443,27 @@ int main(int argc, char** argv) {
370443 return 1 ;
371444 }
372445 auto & decode_outputs = decode_result.get ();
446+ const auto & out_tensor = decode_outputs[0 ].toTensor ();
373447
374448 prev_token = cur_token;
375- cur_token = read_token (decode_outputs[0 ].toTensor ());
449+ // Single per-step device->host copy: the 8-byte sampled token id, needed
450+ // for EOS detection and streaming detokenization below.
451+ cur_token = read_token (out_tensor);
452+
453+ #ifdef EXECUTORCH_BUILD_CUDA
454+ // Feed this step's sampled token straight into the next step's token input
455+ // on-device (device->device). This replaces the old host re-upload (H2D)
456+ // and, together with read_token's D2H above, leaves exactly one 8-byte
457+ // D2H and zero H2D per decode step. read_token's synchronous D2H has
458+ // already forced the output to be ready, so the copy below is well-ordered.
459+ ET_CHECK_MSG (
460+ cudaMemcpy (
461+ decode_tokens->mutable_data_ptr (),
462+ out_tensor.const_data_ptr (),
463+ sizeof (int64_t ),
464+ cudaMemcpyDeviceToDevice) == cudaSuccess,
465+ " Failed to feed decode token device-to-device" );
466+ #endif
376467
377468 if (step == 0 ) {
378469 stats.first_token_ms = llm::time_in_ms ();
0 commit comments