9 changes: 9 additions & 0 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -783,6 +783,11 @@ std::vector<paddle::Tensor> BuildSamplingParams(
    const int64_t token_num_output_cpu,
    const int64_t increment_value);

std::vector<paddle::Tensor> BuildSamplingParamLogProb(
    const paddle::Tensor& input_params,
    const paddle::Tensor& token_num_per_batch,
    int64_t token_num_output_cpu);

void SpecTokenPenaltyMultiScores(
    const paddle::Tensor& token_ids_all,
    const paddle::Tensor& prompt_lens,
@@ -1771,6 +1776,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
        &BuildSamplingParams,
        "build_sampling_params function");

  m.def("build_sampling_params_logprob",
        &BuildSamplingParamLogProb,
        "build_sampling_params_logprob function");

  m.def("speculate_get_token_penalty_multi_scores",
        &SpecTokenPenaltyMultiScores,
        "speculate_get_token_penalty_multi_scores function");
129 changes: 129 additions & 0 deletions custom_ops/gpu_ops/speculate_decoding/build_sampling_params_logprob.cu
@@ -0,0 +1,129 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

template <typename T>
__global__ void BuildSamplingParamLogProbKernel(
    T* output_params,
    const T* input_params,
    const int32_t* token_num_per_batch,
    const int64_t token_num_output_cpu) {
  const int bi = blockIdx.x;
  const int tid = threadIdx.x;

  // Compute start offset: sum of token_num_per_batch[0..bi-1]
  int start_offset = 0;
  for (int i = 0; i < bi; i++) {
    start_offset += token_num_per_batch[i];
  }
🟡 Suggestion: start_offset is recomputed in every block, which is redundant work.

In the current implementation, all 256 threads in each block (each bi) independently run the same O(bi) accumulation loop to compute start_offset. This is entirely redundant: thread tid==0 could compute the offset once and publish it through shared memory for the other threads to read, or the prefix sums could be precomputed on the C++ side and passed into the kernel.

The suggested change uses shared memory to avoid the repeated computation:

__shared__ int32_t s_start_offset;
if (tid == 0) {
  int off = 0;
  for (int i = 0; i < bi; i++) off += token_num_per_batch[i];
  s_start_offset = off;
}
__syncthreads();
int start_offset = s_start_offset;

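Alternatively, the prefix sum can be computed on the host and handed to the kernel as an extra input; a minimal Python sketch of that precomputation (it mirrors how cu_batch_token_offset is assembled in logprobs.py below):

import paddle

token_num_per_batch = paddle.to_tensor([3, 1, 2], dtype="int32")
# Exclusive prefix sum: the start offset of each batch's tokens in the flat output.
start_offsets = paddle.concat(
    [paddle.zeros([1], dtype="int32"), paddle.cumsum(token_num_per_batch)]
)
# start_offsets -> [0, 3, 4, 6]; block bi would read start_offsets[bi] directly.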

  int cur_token_num = token_num_per_batch[bi];

  if (cur_token_num <= 0) {
    return;
  }

  // Read per-batch param into register
  T val = input_params[bi];

  // Fill output_params with bounds check against total output size
  for (int i = tid; i < cur_token_num; i += blockDim.x) {
    int64_t idx = static_cast<int64_t>(start_offset) + i;
    if (idx < token_num_output_cpu) {
      output_params[idx] = val;
    }
  }
}

std::vector<paddle::Tensor> BuildSamplingParamLogProb(
    const paddle::Tensor& input_params,
    const paddle::Tensor& token_num_per_batch,
    const int64_t token_num_output_cpu) {
  auto cu_stream = input_params.stream();
  // Initialize output to safe defaults for use as divisors:
  // int32/float32 -> 1, bool -> false
  paddle::Tensor output_params;
  switch (input_params.dtype()) {
    case paddle::DataType::BOOL:
      output_params = paddle::full({token_num_output_cpu},
                                   false,
                                   input_params.dtype(),
                                   input_params.place());
      break;
    case paddle::DataType::INT32:
      output_params = paddle::full({token_num_output_cpu},
                                   1,
                                   input_params.dtype(),
                                   input_params.place());
      break;
    case paddle::DataType::FLOAT32:
      output_params = paddle::full({token_num_output_cpu},
                                   1.0f,
                                   input_params.dtype(),
                                   input_params.place());
      break;
    default:
      PD_THROW(
          "Unsupported data type for BuildSamplingParamLogProb. "
          "Only bool, int32, float32 are supported.");
  }

  int32_t num_blocks = token_num_per_batch.shape()[0];
  switch (input_params.dtype()) {
    case paddle::DataType::BOOL: {
      BuildSamplingParamLogProbKernel<bool><<<num_blocks, 256, 0, cu_stream>>>(
          output_params.data<bool>(),
          input_params.data<bool>(),
          token_num_per_batch.data<int32_t>(),
          token_num_output_cpu);
      break;
    }
    case paddle::DataType::INT32: {
      BuildSamplingParamLogProbKernel<int32_t>
          <<<num_blocks, 256, 0, cu_stream>>>(
              output_params.data<int32_t>(),
              input_params.data<int32_t>(),
              token_num_per_batch.data<int32_t>(),
              token_num_output_cpu);
      break;
    }
    case paddle::DataType::FLOAT32: {
      BuildSamplingParamLogProbKernel<float><<<num_blocks, 256, 0, cu_stream>>>(
          output_params.data<float>(),
          input_params.data<float>(),
          token_num_per_batch.data<int32_t>(),
          token_num_output_cpu);
      break;
    }
    default: {
      PD_THROW(
          "Unsupported data type for BuildSamplingParamLogProb. "
          "Only bool, int32, float32 are supported.");
    }
  }

  return {output_params};
}

PD_BUILD_STATIC_OP(build_sampling_params_logprob)
    .Inputs({"input_params", "token_num_per_batch"})
    .Outputs({"output_params"})
    .Attrs({"token_num_output_cpu: int64_t"})
    .SetKernelFn(PD_KERNEL(BuildSamplingParamLogProb));
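For reference, a minimal sketch of the new op's behavior from the Python side (assuming the extension is built and tensors live on a CUDA device; the values are illustrative):

import paddle

from fastdeploy.model_executor.ops.gpu import build_sampling_params_logprob

# Two requests that accepted 3 and 1 tokens; the output is padded to 8 slots,
# so the unfilled tail keeps the float32 default fill value of 1.0.
input_params = paddle.to_tensor([0.7, 1.0], dtype="float32")   # per-batch temperature
token_num_per_batch = paddle.to_tensor([3, 1], dtype="int32")  # accepted tokens per batch

out = build_sampling_params_logprob(input_params, token_num_per_batch, 8)
# out -> [0.7, 0.7, 0.7, 1.0, 1.0, 1.0, 1.0, 1.0]

In compute_logprobs below, this single fused launch replaces the previous squeeze(1)/repeat_interleave expansion of the per-batch sampling parameters.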
18 changes: 10 additions & 8 deletions fastdeploy/model_executor/layers/sample/logprobs.py
@@ -144,7 +144,8 @@ def build_output_logprobs(
    is_naive: bool = False,
    logprobs_mode: str = "default",
    compute_logprobs_fn: Optional[Callable] = None,
) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor], Optional[paddle.Tensor]]:
    real_bsz: int = 0,
) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]:
🟡 Suggestion: the Returns description in the docstring no longer matches the actual return type.

The function signature has been updated to Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]] (2 elements), but the docstring still reads:

Returns:
    tuple: (logprobs_tensors, cu_batch_token_offset, output_logits)

which describes a 3-element tuple. Please update the docstring to match.
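A matching Returns block would be (a sketch; the two names follow the tuple unpacked in sampler.py later in this diff):

Returns:
    tuple: (logprobs_tensors, cu_batch_token_offset)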

"""
Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes.

@@ -170,21 +171,22 @@
    logprobs_tensors = None
    cu_batch_token_offset = None

    real_bsz = share_inputs["seq_lens_this_time"].shape[0]
    # NOTE(huicongyao) real_bsz is passed from _postprocess, remove this in future
    max_occupied_slots = share_inputs["seq_lens_this_time"].shape[0]

    if is_naive:
        # NAIVE mode: one token per request, logits are already correct
        output_logits = logits
        token_ids = share_inputs["accept_tokens"][:real_bsz, 0]
        token_ids = share_inputs["accept_tokens"][:max_occupied_slots, 0]
    else:
        # Speculative mode: extract target logits for accepted positions
        from fastdeploy.model_executor.layers.sample.ops import (
            speculate_get_target_logits,
        )

        batch_token_num = paddle.where(
            share_inputs["seq_lens_encoder"][:real_bsz] != 0,
            paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
            share_inputs["seq_lens_encoder"][:max_occupied_slots] != 0,
            paddle.ones_like(share_inputs["seq_lens_encoder"][:max_occupied_slots]),
            share_inputs["seq_lens_this_time"],
        ).flatten()

@@ -194,12 +196,12 @@
            "int32"
        )
        cu_batch_token_offset = paddle.concat(
            [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
            [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:max_occupied_slots])]
        ).astype("int32")
        share_inputs["cu_batch_token_offset"] = cu_batch_token_offset

        output_logits = paddle.empty(
            [share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]],
            [share_inputs["accept_num"][:max_occupied_slots].sum(), logits.shape[1]],
            dtype=logits.dtype,
        )
        speculate_get_target_logits(
@@ -222,7 +224,7 @@

    # Compute logprobs with temperature scaling and top_p normalization
    if logprobs_mode == "raw_logprobs":
        raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata)
        raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata, real_bsz)
    elif logprobs_mode == "raw_logits":
        raw_logprobs = output_logits.clone()
    else:
60 changes: 33 additions & 27 deletions fastdeploy/model_executor/layers/sample/sampler.py
@@ -54,6 +54,7 @@
if current_platform.is_cuda():
    from fastdeploy.model_executor.ops.gpu import (
        build_sampling_params,
        build_sampling_params_logprob,
        naive_update_model_status,
    )

@@ -833,6 +834,7 @@ def __init__(self, fd_config: FDConfig):
        self.spec_method = spec_config.method
        self.verify_strategy = spec_config.verify_strategy
        self.prefill_one_step_stop = fd_config.parallel_config.prefill_one_step_stop
        self.num_speculative_tokens = spec_config.num_speculative_tokens

        # Accept policy from config (can be overridden by function parameters)
        self.config_accept_all = spec_config.accept_policy == "accept_all"
@@ -858,55 +860,57 @@ def compute_logprobs(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
        real_bsz: int = 0,
    ) -> paddle.Tensor:
        """compute logprobs"""
        share_inputs = sampling_metadata.share_inputs
        last_logits = logits
        real_bsz = share_inputs["seq_lens_this_time"].shape[0]
        batch_token_num = share_inputs["accept_num"][:real_bsz]

        # NOTE(huicongyao): temporarily used to provide a max_sized input, remove in the future
        num_tokens = real_bsz * (self.num_speculative_tokens + 1)
❓ Question: when real_bsz=0, num_tokens = 0 * (self.num_speculative_tokens + 1) = 0, and the subsequent paddle.zeros(shape=[0, last_logits.shape[1]], ...) creates an empty tensor.

Please confirm:

  1. that the CUDA kernel behind build_sampling_params_logprob(..., num_tokens=0) (i.e. token_num_output_cpu=0) does not perform out-of-bounds accesses;
  2. that the result of F.log_softmax on an empty padded_logits is handled correctly by the later [:real_token_num] slicing logic.

If the real_bsz=0 case is already guarded upstream by if token_num_cpu > 0: (as in mtp.py), consider adding an assertion or comment at the top of the function to make that explicit.
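If that upstream guard is what the code relies on, a one-line sketch of making it explicit at the top of compute_logprobs (the message text is illustrative):

assert real_bsz > 0, "real_bsz == 0 is expected to be filtered upstream (token_num_cpu > 0 guard in mtp.py)"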

        padded_logits = paddle.zeros(shape=[num_tokens, last_logits.shape[1]], dtype=last_logits.dtype)

        padded_logits[: logits.shape[0]] = last_logits
        max_occupied_slots = share_inputs["seq_lens_this_time"].shape[0]

        batch_token_num = share_inputs["accept_num"][:max_occupied_slots]

        temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
        top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
        if temp_scaled_logprobs is not None:
            real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
            temperature = sampling_metadata.temperature[:real_bsz]
            real_bsz_temp_scaled = (
                real_bsz_temp_scaled.astype("int32").squeeze(1).repeat_interleave(batch_token_num).astype("bool")
            )
            temperature = temperature.squeeze(1).repeat_interleave(batch_token_num)
            real_bsz_temp_scaled = temp_scaled_logprobs[:max_occupied_slots]
            temperature = sampling_metadata.temperature[:max_occupied_slots]
            real_bsz_temp_scaled = build_sampling_params_logprob(real_bsz_temp_scaled, batch_token_num, num_tokens)
            temperature = build_sampling_params_logprob(temperature, batch_token_num, num_tokens)
            temp_temperature = paddle.where(
                real_bsz_temp_scaled, temperature, paddle.ones_like(temperature)
            ).unsqueeze(1)
            last_logits = last_logits / temp_temperature
            padded_logits = padded_logits / temp_temperature

        last_logprobs = F.log_softmax(last_logits, axis=-1)
        last_logprobs = F.log_softmax(padded_logits, axis=-1)
        top_p_logprob = None
        top_p_token_mask = None

        if (
            top_p_normalized_logprobs is not None
            and share_inputs is not None
            and sampling_metadata.top_p_normalized_logprobs_flag
        ):
            real_token_top_p = (
                sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1)
            )
            top_p_normalized_logprobs = (
                top_p_normalized_logprobs[:real_bsz]
                .astype("int32")
                .squeeze(1)
                .repeat_interleave(batch_token_num)
                .astype("bool")
                .unsqueeze(1)
            )
            real_token_top_p = build_sampling_params_logprob(
                sampling_metadata.top_p[:max_occupied_slots].squeeze(1), batch_token_num, num_tokens
            ).unsqueeze(1)
            top_p_normalized_logprobs = build_sampling_params_logprob(
                top_p_normalized_logprobs[:max_occupied_slots].squeeze(1), batch_token_num, num_tokens
            ).unsqueeze(1)
            top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0)

            if top_p_token_mask.any():
                probs = F.softmax(last_logits, axis=-1)
                probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
                top_p_logprob = paddle.log(probs)

            probs = F.softmax(padded_logits, axis=-1)
🟡 Suggestion: the if top_p_token_mask.any(): guard has been removed, so three compute-intensive operations (F.softmax, top_p_normalize_probs_paddle, and paddle.log) now execute unconditionally, even when no request needs top_p scaling.

The original code's if top_p_token_mask.any(): guard was better for performance. If the guard was removed deliberately to fix a specific bug, please document the reason in a comment.
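A sketch of the guarded form this comment asks for, reusing the names in this diff (assuming the unconditional execution is not required for correctness):

if top_p_token_mask.any():  # skip the expensive path when no request applies top_p scaling
    probs = F.softmax(padded_logits, axis=-1)
    probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
    top_p_logprob = paddle.log(probs)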

            probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
            top_p_logprob = paddle.log(probs)
        if top_p_logprob is not None:
            last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs)
        return last_logprobs

        # NOTE(huicongyao) temporarily used for slice last_logprobs to its real shape, remove in the future
        real_token_num = batch_token_num.sum().item()
        return last_logprobs[:real_token_num]

    def gather_logprobs(
        self,
@@ -1136,6 +1140,7 @@ def forward_cuda(
        increment_value: int,
        accept_all_drafts: bool = False,
        reject_all_drafts: bool = False,
        real_bsz: int = 0,
    ) -> SamplerOutput:
        """
        Forward pass for speculative sampling.
@@ -1229,6 +1234,7 @@
                is_naive=is_naive,
                logprobs_mode=self.logprobs_mode,
                compute_logprobs_fn=self.compute_logprobs,
                real_bsz=real_bsz,
            )
            sampler_output.logprobs_tensors = logprobs_tensors
            if cu_batch_token_offset is not None:
6 changes: 3 additions & 3 deletions fastdeploy/model_executor/pre_and_post_process.py
@@ -478,7 +478,7 @@ def save_output_normal(
    share_inputs["last_preempted_idx"][:] = 0


def post_process_specualate(
def post_process_speculate(
    sampler_output: SamplerOutput,
    model_output: ModelOutputData,
    share_inputs: InputBatch,
@@ -570,7 +570,7 @@ def post_process_specualate(
    # so that async D2H of logz_per_batch has more time to complete.


def save_output_specualate(
def save_output_speculate(
    sampler_output: SamplerOutput,
    model_output: ModelOutputData,
    share_inputs: InputBatch,
@@ -764,7 +764,7 @@ def post_process(
        )
    else:
        if speculative_decoding:
            post_process_specualate(
            post_process_speculate(
                sampler_or_pooler_output,
                model_output,
                share_inputs,
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/xpu_pre_and_post_process.py
@@ -381,7 +381,7 @@ def xpu_post_process_normal(
    share_inputs["preempted_idx"][:] = 0


def xpu_post_process_specualate(
def xpu_post_process_speculate(
    sampler_output: SamplerOutput,
    model_output: ModelOutputData,
    share_inputs: Dict[str, paddle.Tensor],
2 changes: 1 addition & 1 deletion fastdeploy/spec_decode/mtp.py
@@ -882,7 +882,7 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F
            token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item()
        else:
            if substep == 0:
                token_num_cpu = real_bsz * (self.max_draft_token_num + 1)
                token_num_cpu = self.model_inputs["target_hidden_states"].shape[0]
            else:
                token_num_cpu = real_bsz
        if token_num_cpu > 0: