-
Notifications
You must be signed in to change notification settings - Fork 742
[Cherry-Pick][Speculative Decoding][BugFix] overlap compute logprobs for speculative decoding (#7406) #7585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: release/2.6
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| // Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "helper.h" | ||
| #include "paddle/extension.h" | ||
|
|
||
| #ifndef PD_BUILD_STATIC_OP | ||
| #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) | ||
| #endif | ||
|
|
||
| template <typename T> | ||
| __global__ void BuildSamplingParamLogProbKernel( | ||
| T* output_params, | ||
| const T* input_params, | ||
| const int32_t* token_num_per_batch, | ||
| const int64_t token_num_output_cpu) { | ||
| const int bi = blockIdx.x; | ||
| const int tid = threadIdx.x; | ||
|
|
||
| // Compute start offset: sum of token_num_per_batch[0..bi-1] | ||
| int start_offset = 0; | ||
| for (int i = 0; i < bi; i++) { | ||
| start_offset += token_num_per_batch[i]; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 建议 当前实现中,每个 block(每个 建议改为使用 shared memory 避免重复计算: __shared__ int32_t s_start_offset;
if (tid == 0) {
int off = 0;
for (int i = 0; i < bi; i++) off += token_num_per_batch[i];
s_start_offset = off;
}
__syncthreads();
int start_offset = s_start_offset; |
||
| } | ||
|
|
||
This comment was marked as outdated.
Sorry, something went wrong. |
||
| int cur_token_num = token_num_per_batch[bi]; | ||
|
|
||
| if (cur_token_num <= 0) { | ||
| return; | ||
| } | ||
|
|
||
| // Read per-batch param into register | ||
| T val = input_params[bi]; | ||
|
|
||
| // Fill output_params with bounds check against total output size | ||
| for (int i = tid; i < cur_token_num; i += blockDim.x) { | ||
| int64_t idx = static_cast<int64_t>(start_offset) + i; | ||
| if (idx < token_num_output_cpu) { | ||
| output_params[idx] = val; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| std::vector<paddle::Tensor> BuildSamplingParamLogProb( | ||
| const paddle::Tensor& input_params, | ||
| const paddle::Tensor& token_num_per_batch, | ||
| const int64_t token_num_output_cpu) { | ||
| auto cu_stream = input_params.stream(); | ||
| // Initialize output to safe defaults for use as divisors: | ||
| // int32/float32 -> 1, bool -> false | ||
| paddle::Tensor output_params; | ||
| switch (input_params.dtype()) { | ||
| case paddle::DataType::BOOL: | ||
| output_params = paddle::full({token_num_output_cpu}, | ||
| false, | ||
| input_params.dtype(), | ||
| input_params.place()); | ||
| break; | ||
| case paddle::DataType::INT32: | ||
| output_params = paddle::full({token_num_output_cpu}, | ||
| 1, | ||
| input_params.dtype(), | ||
| input_params.place()); | ||
| break; | ||
| case paddle::DataType::FLOAT32: | ||
| output_params = paddle::full({token_num_output_cpu}, | ||
| 1.0f, | ||
| input_params.dtype(), | ||
| input_params.place()); | ||
| break; | ||
| default: | ||
| PD_THROW( | ||
| "Unsupported data type for BuildSamplingParamLogProb. " | ||
| "Only bool, int32, float32 are supported."); | ||
| } | ||
|
|
||
| int32_t num_blocks = token_num_per_batch.shape()[0]; | ||
| switch (input_params.dtype()) { | ||
| case paddle::DataType::BOOL: { | ||
| BuildSamplingParamLogProbKernel<bool><<<num_blocks, 256, 0, cu_stream>>>( | ||
| output_params.data<bool>(), | ||
| input_params.data<bool>(), | ||
| token_num_per_batch.data<int32_t>(), | ||
| token_num_output_cpu); | ||
| break; | ||
| } | ||
| case paddle::DataType::INT32: { | ||
| BuildSamplingParamLogProbKernel<int32_t> | ||
| <<<num_blocks, 256, 0, cu_stream>>>( | ||
| output_params.data<int32_t>(), | ||
| input_params.data<int32_t>(), | ||
| token_num_per_batch.data<int32_t>(), | ||
| token_num_output_cpu); | ||
| break; | ||
| } | ||
| case paddle::DataType::FLOAT32: { | ||
| BuildSamplingParamLogProbKernel<float><<<num_blocks, 256, 0, cu_stream>>>( | ||
| output_params.data<float>(), | ||
| input_params.data<float>(), | ||
| token_num_per_batch.data<int32_t>(), | ||
| token_num_output_cpu); | ||
| break; | ||
| } | ||
| default: { | ||
| PD_THROW( | ||
| "Unsupported data type for BuildSamplingParamLogProb. " | ||
| "Only bool, int32, float32 are supported."); | ||
| } | ||
| } | ||
|
|
||
| return {output_params}; | ||
| } | ||
|
|
||
// Register the custom op with PaddlePaddle. The output length is not
// inferable from the inputs, so the total token count is passed as the
// int64 attribute "token_num_output_cpu".
PD_BUILD_STATIC_OP(build_sampling_params_logprob)
    .Inputs({"input_params", "token_num_per_batch"})
    .Outputs({"output_params"})
    .Attrs({"token_num_output_cpu: int64_t"})
    .SetKernelFn(PD_KERNEL(BuildSamplingParamLogProb));
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -144,7 +144,8 @@ def build_output_logprobs( | |
| is_naive: bool = False, | ||
| logprobs_mode: str = "default", | ||
| compute_logprobs_fn: Optional[Callable] = None, | ||
| ) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor], Optional[paddle.Tensor]]: | ||
| real_bsz: int = 0, | ||
| ) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 建议 docstring 中的 Returns 描述与实际函数返回类型不一致。 函数签名已更新为 描述的是3元素元组。请同步更新 docstring。 |
||
| """ | ||
| Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes. | ||
|
|
||
|
|
@@ -170,21 +171,22 @@ def build_output_logprobs( | |
| logprobs_tensors = None | ||
| cu_batch_token_offset = None | ||
|
|
||
| real_bsz = share_inputs["seq_lens_this_time"].shape[0] | ||
| # NOTE(huicongyao) real_bsz is passed from _postprocess, remove this in future | ||
| max_occupied_slots = share_inputs["seq_lens_this_time"].shape[0] | ||
|
|
||
| if is_naive: | ||
| # NAIVE mode: one token per request, logits are already correct | ||
| output_logits = logits | ||
| token_ids = share_inputs["accept_tokens"][:real_bsz, 0] | ||
| token_ids = share_inputs["accept_tokens"][:max_occupied_slots, 0] | ||
| else: | ||
| # Speculative mode: extract target logits for accepted positions | ||
| from fastdeploy.model_executor.layers.sample.ops import ( | ||
| speculate_get_target_logits, | ||
| ) | ||
|
|
||
| batch_token_num = paddle.where( | ||
| share_inputs["seq_lens_encoder"][:real_bsz] != 0, | ||
| paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]), | ||
| share_inputs["seq_lens_encoder"][:max_occupied_slots] != 0, | ||
| paddle.ones_like(share_inputs["seq_lens_encoder"][:max_occupied_slots]), | ||
| share_inputs["seq_lens_this_time"], | ||
| ).flatten() | ||
|
|
||
|
|
@@ -194,12 +196,12 @@ def build_output_logprobs( | |
| "int32" | ||
| ) | ||
| cu_batch_token_offset = paddle.concat( | ||
| [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])] | ||
| [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:max_occupied_slots])] | ||
| ).astype("int32") | ||
| share_inputs["cu_batch_token_offset"] = cu_batch_token_offset | ||
|
|
||
| output_logits = paddle.empty( | ||
| [share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], | ||
| [share_inputs["accept_num"][:max_occupied_slots].sum(), logits.shape[1]], | ||
| dtype=logits.dtype, | ||
| ) | ||
| speculate_get_target_logits( | ||
|
|
@@ -222,7 +224,7 @@ def build_output_logprobs( | |
|
|
||
| # Compute logprobs with temperature scaling and top_p normalization | ||
| if logprobs_mode == "raw_logprobs": | ||
| raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata) | ||
| raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata, real_bsz) | ||
| elif logprobs_mode == "raw_logits": | ||
| raw_logprobs = output_logits.clone() | ||
| else: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,6 +54,7 @@ | |
| if current_platform.is_cuda(): | ||
| from fastdeploy.model_executor.ops.gpu import ( | ||
| build_sampling_params, | ||
| build_sampling_params_logprob, | ||
| naive_update_model_status, | ||
| ) | ||
|
|
||
|
|
@@ -833,6 +834,7 @@ def __init__(self, fd_config: FDConfig): | |
| self.spec_method = spec_config.method | ||
| self.verify_strategy = spec_config.verify_strategy | ||
| self.prefill_one_step_stop = fd_config.parallel_config.prefill_one_step_stop | ||
| self.num_speculative_tokens = spec_config.num_speculative_tokens | ||
|
|
||
| # Accept policy from config (can be overridden by function parameters) | ||
| self.config_accept_all = spec_config.accept_policy == "accept_all" | ||
|
|
@@ -858,55 +860,57 @@ def compute_logprobs( | |
| self, | ||
| logits: paddle.Tensor, | ||
| sampling_metadata: SamplingMetadata, | ||
| real_bsz: int = 0, | ||
| ) -> paddle.Tensor: | ||
| """compute logprobs""" | ||
| share_inputs = sampling_metadata.share_inputs | ||
| last_logits = logits | ||
| real_bsz = share_inputs["seq_lens_this_time"].shape[0] | ||
| batch_token_num = share_inputs["accept_num"][:real_bsz] | ||
|
|
||
| # NOTE(huicongyao): temporarily used to provide a max_sized input, remove in the future | ||
| num_tokens = real_bsz * (self.num_speculative_tokens + 1) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ❓ 疑问 当 请确认:
如果 |
||
| padded_logits = paddle.zeros(shape=[num_tokens, last_logits.shape[1]], dtype=last_logits.dtype) | ||
This comment was marked as outdated.
Sorry, something went wrong. |
||
| padded_logits[: logits.shape[0]] = last_logits | ||
| max_occupied_slots = share_inputs["seq_lens_this_time"].shape[0] | ||
This comment was marked as outdated.
Sorry, something went wrong. |
||
|
|
||
| batch_token_num = share_inputs["accept_num"][:max_occupied_slots] | ||
|
|
||
| temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs | ||
| top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs | ||
| if temp_scaled_logprobs is not None: | ||
| real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz] | ||
| temperature = sampling_metadata.temperature[:real_bsz] | ||
| real_bsz_temp_scaled = ( | ||
| real_bsz_temp_scaled.astype("int32").squeeze(1).repeat_interleave(batch_token_num).astype("bool") | ||
| ) | ||
| temperature = temperature.squeeze(1).repeat_interleave(batch_token_num) | ||
| real_bsz_temp_scaled = temp_scaled_logprobs[:max_occupied_slots] | ||
| temperature = sampling_metadata.temperature[:max_occupied_slots] | ||
| real_bsz_temp_scaled = build_sampling_params_logprob(real_bsz_temp_scaled, batch_token_num, num_tokens) | ||
| temperature = build_sampling_params_logprob(temperature, batch_token_num, num_tokens) | ||
| temp_temperature = paddle.where( | ||
| real_bsz_temp_scaled, temperature, paddle.ones_like(temperature) | ||
| ).unsqueeze(1) | ||
| last_logits = last_logits / temp_temperature | ||
| padded_logits = padded_logits / temp_temperature | ||
|
|
||
| last_logprobs = F.log_softmax(last_logits, axis=-1) | ||
| last_logprobs = F.log_softmax(padded_logits, axis=-1) | ||
| top_p_logprob = None | ||
| top_p_token_mask = None | ||
|
|
||
| if ( | ||
| top_p_normalized_logprobs is not None | ||
| and share_inputs is not None | ||
| and sampling_metadata.top_p_normalized_logprobs_flag | ||
| ): | ||
| real_token_top_p = ( | ||
| sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1) | ||
| ) | ||
| top_p_normalized_logprobs = ( | ||
| top_p_normalized_logprobs[:real_bsz] | ||
| .astype("int32") | ||
| .squeeze(1) | ||
| .repeat_interleave(batch_token_num) | ||
| .astype("bool") | ||
| .unsqueeze(1) | ||
| ) | ||
| real_token_top_p = build_sampling_params_logprob( | ||
| sampling_metadata.top_p[:max_occupied_slots].squeeze(1), batch_token_num, num_tokens | ||
| ).unsqueeze(1) | ||
| top_p_normalized_logprobs = build_sampling_params_logprob( | ||
| top_p_normalized_logprobs[:max_occupied_slots].squeeze(1), batch_token_num, num_tokens | ||
| ).unsqueeze(1) | ||
| top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0) | ||
This comment was marked as outdated.
Sorry, something went wrong. |
||
| if top_p_token_mask.any(): | ||
| probs = F.softmax(last_logits, axis=-1) | ||
| probs = top_p_normalize_probs_paddle(probs, real_token_top_p) | ||
| top_p_logprob = paddle.log(probs) | ||
|
|
||
| probs = F.softmax(padded_logits, axis=-1) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 建议 移除了 原代码有 |
||
| probs = top_p_normalize_probs_paddle(probs, real_token_top_p) | ||
| top_p_logprob = paddle.log(probs) | ||
| if top_p_logprob is not None: | ||
| last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs) | ||
| return last_logprobs | ||
|
|
||
| # NOTE(huicongyao) temporarily used for slice last_logprobs to its real shape, remove in the future | ||
| real_token_num = batch_token_num.sum().item() | ||
| return last_logprobs[:real_token_num] | ||
|
|
||
| def gather_logprobs( | ||
| self, | ||
|
|
@@ -1136,6 +1140,7 @@ def forward_cuda( | |
| increment_value: int, | ||
| accept_all_drafts: bool = False, | ||
| reject_all_drafts: bool = False, | ||
| real_bsz: int = 0, | ||
| ) -> SamplerOutput: | ||
| """ | ||
| Forward pass for speculative sampling. | ||
|
|
@@ -1229,6 +1234,7 @@ def forward_cuda( | |
| is_naive=is_naive, | ||
| logprobs_mode=self.logprobs_mode, | ||
| compute_logprobs_fn=self.compute_logprobs, | ||
| real_bsz=real_bsz, | ||
| ) | ||
| sampler_output.logprobs_tensors = logprobs_tensors | ||
| if cu_batch_token_offset is not None: | ||
|
|
||
This comment was marked as outdated.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.