
[Optimization] Support logprob overlap in speculative decoding#7600

Open
Deleter-D wants to merge 6 commits into PaddlePaddle:develop from Deleter-D:dev_logprob_overlap

Conversation

@Deleter-D
Collaborator

@Deleter-D Deleter-D commented Apr 23, 2026

Motivation

Improve performance of speculative decoding by overlapping logprob computation with token acceptance operations.

Modifications

  • Merged speculate_get_target_logits into speculate_get_accept_tokens_and_logits to enable kernel-level overlap
  • Added compute_cu_batch_offset_kernel for efficient batch offset calculation
  • Integrated accept tokens retrieval with target logits extraction in a single kernel
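The batch-offset step in the list above can be mirrored with a small CPU sketch: each batch's accepted tokens and logits are packed row-contiguously, so the per-batch start offsets are an exclusive prefix sum of the accept counts. This is a numpy illustration only; the function name and shapes are assumptions, not the actual `compute_cu_batch_offset_kernel` in the PR.

```python
import numpy as np

def compute_cu_batch_offset(accept_num):
    """Exclusive prefix sum of per-batch accepted-token counts.

    Illustrates, on the CPU, what a GPU-side batch-offset kernel would
    produce: the row offset at which each batch's accepted tokens and
    target logits start in a packed output buffer.
    """
    accept_num = np.asarray(accept_num)
    cu = np.zeros(len(accept_num) + 1, dtype=np.int64)
    np.cumsum(accept_num, out=cu[1:])  # cu[0] stays 0 (exclusive scan)
    return cu

# Example: 3 batches accepting 2, 1, and 3 draft tokens respectively.
offsets = compute_cu_batch_offset([2, 1, 3])
# batch i's rows live at [offsets[i], offsets[i+1]) in the packed buffer
```

Doing this scan inside the kernel (rather than a CPU-side cumsum) is what removes the device-to-host synchronization on `accept_num`.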

Usage or Command

Accuracy Tests

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot Bot commented Apr 23, 2026

Thanks for your contribution!


@codecov-commenter

codecov-commenter commented Apr 23, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@9236d0c). Learn more about missing BASE report.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7600   +/-   ##
==========================================
  Coverage           ?   71.68%           
==========================================
  Files              ?      419           
  Lines              ?    57849           
  Branches           ?     9077           
==========================================
  Hits               ?    41470           
  Misses             ?    13551           
  Partials           ?     2828           
| Flag | Coverage | Δ |
| --- | --- | --- |
| GPU | 71.68% <100.00%> | ? |



@Deleter-D Deleter-D changed the title [Optimization] support logprob overlap [Optimization] Support logprob overlap in speculative decoding Apr 24, 2026


@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 AI Code Review | 2026-04-24 17:44:55

📋 Review Summary

PR overview: merges speculate_get_target_logits with accepted-token extraction into a single CUDA kernel and computes cu_batch_token_offset on the GPU, overlapping logprob computation with token acceptance at the kernel level to improve speculative-decoding performance.
Scope of changes: custom_ops/gpu_ops/speculate_decoding/, fastdeploy/model_executor/layers/sample/
Impact tags: [OP], [Speculative Decoding]


Issues

| Severity | File | Summary |
| --- | --- | --- |
| 🔴 Bug | logprobs.py:216 | gather_logprobs operates on the full token_ids/raw_logprobs buffers, which still contain zero padding, producing incorrect logprob results |
| 🟡 Suggestion | speculate_logprob_utils.cu:238 | The token_ids write lacks a C++-side bounds assertion |
| 🟡 Suggestion | speculate_logprob_utils.cu:282 | No CUDA error check after the two kernel launches |

Core Bug Explanation

gather_logprobs operates on padded data, corrupting the logprobs (logprobs.py, lines 190-216):

output_logits and token_ids are allocated with real_bsz * max_draft_token_num_plus_1 rows, but the kernel only writes accept_num.sum() valid rows (accept_num.sum() ≤ real_bsz * max_draft_token_num_plus_1). In the remaining rows:

  • token_ids still holds fill_value=0 (i.e. token_id=0)
  • output_logits holds uninitialized (paddle.empty) data

gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) then operates on the full buffers, wrongly extracting logprobs for token_id=0, so the final logprobs contain invalid data.

Suggested fix: truncate to the valid length before calling gather_logprobs.
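The padding bug and the suggested truncation fix can be illustrated with a minimal numpy sketch. The gather_logprobs_demo helper, shapes, and values below are hypothetical stand-ins for the paddle code, not the actual fastdeploy API:

```python
import numpy as np

def gather_logprobs_demo(raw_logprobs, token_ids):
    # Toy stand-in for gather_logprobs: for each row, pick the
    # logprob at that row's sampled token id.
    rows = np.arange(len(token_ids))
    return raw_logprobs[rows, token_ids]

# Pre-allocated: real_bsz * max_draft_token_num_plus_1 = 6 rows,
# but the kernel only filled accept_num.sum() = 4 valid rows.
num_valid = 4
raw_logprobs = np.log(np.full((6, 5), 0.2))  # uniform over 5 tokens
token_ids = np.array([3, 1, 4, 2, 0, 0])     # last two rows: fill_value=0 padding

# Bug: operating on the full buffer also extracts logprobs for the
# padded token_id=0 rows, polluting the result with invalid entries.
bad = gather_logprobs_demo(raw_logprobs, token_ids)

# Fix: truncate both buffers to the valid length first.
good = gather_logprobs_demo(raw_logprobs[:num_valid], token_ids[:num_valid])
```

With paddle.empty the padded logprob rows are additionally uninitialized memory, so the extra entries are not merely redundant but garbage.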


Overall Assessment

The direction of this PR is sound: moving the CPU-side accept_num cumsum down into a GPU kernel reduces synchronization overhead. However, the Python call layer has a bug: output_logits/token_ids are over-allocated and passed to gather_logprobs without truncation, so the logprob results include padding noise. Recommend fixing before merging.

Comment thread fastdeploy/model_executor/layers/sample/logprobs.py
Comment thread custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu
Comment thread custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu


3 participants