Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1145,13 +1145,16 @@ void SpeculateInsertFirstToken(const paddle::Tensor& token_ids,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder);

void SpeculateGetTargetLogits(const paddle::Tensor& target_logits,
const paddle::Tensor& logits,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& ori_cu_batch_token_offset,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& accept_num);
void SpeculateGetAcceptTokensAndLogits(
const paddle::Tensor& token_ids,
const paddle::Tensor& target_logits,
const paddle::Tensor& logits,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& accept_num,
const paddle::Tensor& accept_tokens);

std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
const paddle::Tensor& ids_remove_padding,
Expand Down Expand Up @@ -1879,9 +1882,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&SpeculateInsertFirstToken,
"speculate_insert_first_token function");

m.def("speculate_get_target_logits",
&SpeculateGetTargetLogits,
"speculate_get_target_logits function");
m.def("speculate_get_accept_tokens_and_logits",
&SpeculateGetAcceptTokensAndLogits,
"speculate_get_accept_tokens_and_logits function");
#endif

m.def("update_attn_mask_offsets",
Expand Down
124 changes: 102 additions & 22 deletions custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -184,24 +184,65 @@ void SpeculateInsertFirstToken(const paddle::Tensor& token_ids,
real_bsz);
}

// Inclusive prefix-sum of accept_num[0..real_bsz) into cu_batch_token_offset:
//   cu_batch_token_offset[0]     = 0
//   cu_batch_token_offset[b + 1] = accept_num[0] + ... + accept_num[b]
// Uses cub::BlockScan, so this kernel must be launched as a SINGLE block with
// blockDim.x == BLOCK_DIM, and BLOCK_DIM * ITEMS_PER_THREAD must cover
// real_bsz — slots past that range are never written.
template <int BLOCK_DIM, int ITEMS_PER_THREAD>
__global__ void compute_cu_batch_offset_kernel(int* cu_batch_token_offset,
                                               const int* accept_num,
                                               const int real_bsz) {
  // Block-wide scan over a blocked arrangement: each thread owns
  // ITEMS_PER_THREAD consecutive elements starting at tid * ITEMS_PER_THREAD.
  using BlockScan = cub::BlockScan<int, BLOCK_DIM>;
  __shared__ typename BlockScan::TempStorage temp_storage;

  int tid = threadIdx.x;
  // Base of the cumulative offsets is always 0.
  if (tid == 0) cu_batch_token_offset[0] = 0;

  int thread_data[ITEMS_PER_THREAD];

  // Load this thread's slice; pad with 0 past real_bsz so out-of-range
  // elements do not perturb the scan of in-range ones.
  for (int i = 0; i < ITEMS_PER_THREAD; i++) {
    int batch_id = tid * ITEMS_PER_THREAD + i;
    thread_data[i] =
        batch_id < real_bsz ? accept_num[tid * ITEMS_PER_THREAD + i] : 0;
  }

  BlockScan(temp_storage).InclusiveSum(thread_data, thread_data);
  __syncthreads();

  // Scatter the inclusive sums: element b's running total lands at slot b + 1.
  for (int i = 0; i < ITEMS_PER_THREAD; i++) {
    int batch_id = tid * ITEMS_PER_THREAD + i;
    if (batch_id < real_bsz) {
      cu_batch_token_offset[batch_id + 1] = thread_data[i];
    }
  }
}

template <int VecSize>
__global__ void speculate_get_target_logits_kernel(
__global__ void speculate_get_accept_tokens_and_logits_kernel(
int64_t* token_ids,
float* target_logits,
const float* logits,
const int* cu_batch_token_offset,
const int* ori_cu_batch_token_offset,
const int* cu_seqlens_q_output,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* accept_num,
const int64_t* accept_tokens,
const int vocab_size,
const int max_draft_tokens,
const int real_bsz) {
AlignedVector<float, VecSize> src_vec;
const int bid = blockIdx.x;
const int tid = threadIdx.x;
if (bid < real_bsz) {
// get token_ids
if (tid == 0) {
auto* accept_tokens_now = accept_tokens + bid * max_draft_tokens;
for (int i = 0; i < accept_num[bid]; i++) {
token_ids[cu_batch_token_offset[bid] + i] = accept_tokens_now[i];

This comment was marked as outdated.

Comment thread
Deleter-D marked this conversation as resolved.
}
}

// get output_logits
auto* target_logits_now =
target_logits + cu_batch_token_offset[bid] * vocab_size;
auto* logits_now = logits + ori_cu_batch_token_offset[bid] * vocab_size;
auto* logits_now = logits + cu_seqlens_q_output[bid] * vocab_size;
for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) {
if (seq_lens_encoder[bid] > 0) {
Load<float, VecSize>(&logits_now[i], &src_vec);
Expand All @@ -217,31 +258,64 @@ __global__ void speculate_get_target_logits_kernel(
}
}

void SpeculateGetTargetLogits(const paddle::Tensor& target_logits,
const paddle::Tensor& logits,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& ori_cu_batch_token_offset,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& accept_num) {
void SpeculateGetAcceptTokensAndLogits(
const paddle::Tensor& token_ids,
const paddle::Tensor& target_logits,
const paddle::Tensor& logits,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& accept_num,
const paddle::Tensor& accept_tokens) {
auto cu_stream = seq_lens_this_time.stream();
const int vocab_size = logits.shape()[1];
const int real_bsz = seq_lens_this_time.shape()[0];
const int max_occupied_slots = seq_lens_this_time.shape()[0];
const int max_draft_tokens = accept_tokens.shape()[1];

Comment thread
Deleter-D marked this conversation as resolved.
const int BLOCK_DIM = 512;
PADDLE_ENFORCE_LE(max_occupied_slots,
2048,
phi::errors::InvalidArgument(
"Only support bsz <= 2048, but received bsz is ",
max_occupied_slots));
if (max_occupied_slots <= 512) {
Comment thread
Deleter-D marked this conversation as resolved.
compute_cu_batch_offset_kernel<BLOCK_DIM, 1>
<<<1, BLOCK_DIM, 0, cu_stream>>>(
const_cast<int*>(cu_batch_token_offset.data<int>()),
accept_num.data<int>(),
max_occupied_slots);
} else if (max_occupied_slots <= 1024) {
compute_cu_batch_offset_kernel<BLOCK_DIM, 2>
<<<1, BLOCK_DIM, 0, cu_stream>>>(
const_cast<int*>(cu_batch_token_offset.data<int>()),
accept_num.data<int>(),
max_occupied_slots);
} else if (max_occupied_slots <= 2048) {
compute_cu_batch_offset_kernel<BLOCK_DIM, 4>
<<<1, BLOCK_DIM, 0, cu_stream>>>(
const_cast<int*>(cu_batch_token_offset.data<int>()),
accept_num.data<int>(),
max_occupied_slots);
}
Comment thread
Deleter-D marked this conversation as resolved.

constexpr int PackSize = VEC_16B / sizeof(float);
dim3 grid_dim(real_bsz);
dim3 grid_dim(max_occupied_slots);
dim3 block_dim(128);
speculate_get_target_logits_kernel<PackSize>
speculate_get_accept_tokens_and_logits_kernel<PackSize>

This comment was marked as outdated.

<<<grid_dim, block_dim, 0, cu_stream>>>(
const_cast<int64_t*>(token_ids.data<int64_t>()),
const_cast<float*>(target_logits.data<float>()),
logits.data<float>(),
cu_batch_token_offset.data<int>(),
ori_cu_batch_token_offset.data<int>(),
cu_seqlens_q_output.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
accept_num.data<int>(),
accept_tokens.data<int64_t>(),
vocab_size,
real_bsz);
max_draft_tokens,
max_occupied_slots);
}

PD_BUILD_STATIC_OP(speculate_get_logits)
Expand Down Expand Up @@ -274,14 +348,20 @@ PD_BUILD_STATIC_OP(speculate_insert_first_token)
.SetInplaceMap({{"token_ids", "token_ids_out"}})
.SetKernelFn(PD_KERNEL(SpeculateInsertFirstToken));

PD_BUILD_STATIC_OP(speculate_get_target_logits)
.Inputs({"target_logits",
PD_BUILD_STATIC_OP(speculate_get_accept_tokens_and_logits)
.Inputs({"token_ids",
"target_logits",
"logits",
"cu_batch_token_offset",
"ori_cu_batch_token_offset",
"cu_seqlens_q_output",
"seq_lens_this_time",
"seq_lens_encoder",
"accept_num"})
.Outputs({"target_logits_out"})
.SetInplaceMap({{"target_logits", "target_logits_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetTargetLogits));
"accept_num",
"accept_tokens"})
.Outputs({"token_ids_out",
"target_logits_out",
"cu_batch_token_offset_out"})
.SetInplaceMap({{"token_ids", "token_ids_out"},
{"target_logits", "target_logits_out"},
{"cu_batch_token_offset", "cu_batch_token_offset_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetAcceptTokensAndLogits));
38 changes: 11 additions & 27 deletions fastdeploy/model_executor/layers/sample/logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def build_output_logprobs(
if num_logprobs is None:
return logprobs_tensors, cu_batch_token_offset

# NOTE(huicongyao) real_bsz is passed from _postprocess, remove this in future
max_draft_token_num_plus_1 = share_inputs["accept_tokens"].shape[1]
max_occupied_slots = share_inputs["seq_lens_this_time"].shape[0]

if is_naive:
Expand All @@ -184,43 +184,27 @@ def build_output_logprobs(
else:
# Speculative mode: extract target logits for accepted positions
from fastdeploy.model_executor.layers.sample.ops import (
speculate_get_target_logits,
speculate_get_accept_tokens_and_logits,
)

batch_token_num = paddle.where(
share_inputs["seq_lens_encoder"][:max_occupied_slots] != 0,
paddle.ones_like(share_inputs["seq_lens_encoder"][:max_occupied_slots]),
share_inputs["seq_lens_this_time"],
).flatten()

share_inputs["batch_token_num"] = batch_token_num

ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
"int32"
)
cu_batch_token_offset = paddle.concat(
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:max_occupied_slots])]
).astype("int32")
share_inputs["cu_batch_token_offset"] = cu_batch_token_offset

output_logits = paddle.empty(
[share_inputs["accept_num"][:max_occupied_slots].sum(), logits.shape[1]],
[real_bsz * max_draft_token_num_plus_1, logits.shape[1]],

This comment was marked as outdated.

Comment thread
Deleter-D marked this conversation as resolved.
dtype=logits.dtype,
)
speculate_get_target_logits(
token_ids = paddle.full([real_bsz * max_draft_token_num_plus_1], fill_value=0, dtype="int64")
Comment thread
Deleter-D marked this conversation as resolved.

speculate_get_accept_tokens_and_logits(
token_ids,
output_logits,
logits,
cu_batch_token_offset,
ori_cu_batch_token_offset,
share_inputs["cu_batch_token_offset"],
share_inputs["cu_seqlens_q_output"],
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
share_inputs["accept_num"],
share_inputs["accept_tokens"],
)

idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
mask = idx < share_inputs["accept_num"].unsqueeze(1)
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)

# Compute logprobs with temperature scaling and top_p normalization
if logprobs_mode == "raw_logprobs":
raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata, real_bsz)
Expand All @@ -231,4 +215,4 @@ def build_output_logprobs(

logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
Comment thread
Deleter-D marked this conversation as resolved.

return logprobs_tensors, cu_batch_token_offset
return logprobs_tensors, share_inputs["cu_batch_token_offset"]
4 changes: 2 additions & 2 deletions fastdeploy/model_executor/layers/sample/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
reasoning_phase_token_constraint,
)
from .speculate_logprob_utils import (
speculate_get_target_logits,
speculate_get_accept_tokens_and_logits,
speculate_insert_first_token,
)
from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling
Expand All @@ -31,6 +31,6 @@
"reasoning_phase_token_constraint",
"top_k_top_p_sampling",
"min_p_sampling",
"speculate_get_target_logits",
"speculate_get_accept_tokens_and_logits",
"speculate_insert_first_token",
]
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,35 @@
from fastdeploy.platforms import current_platform


def speculate_get_target_logits(
def speculate_get_accept_tokens_and_logits(
token_ids: paddle.Tensor,
target_logits: paddle.Tensor,
logits: paddle.Tensor,
cu_batch_token_offset: paddle.Tensor,
ori_cu_batch_token_offset: paddle.Tensor,
cu_seqlens_q_output: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
seq_lens_encoder: paddle.Tensor,
accept_num: paddle.Tensor,
accept_tokens: paddle.Tensor,
):
"""
speculate_get_target_logits
speculate_get_accept_tokens_and_logits
"""
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import speculate_get_target_logits
from fastdeploy.model_executor.ops.gpu import (
speculate_get_accept_tokens_and_logits,
)

speculate_get_target_logits(
speculate_get_accept_tokens_and_logits(
token_ids,
target_logits,
logits,
cu_batch_token_offset,
ori_cu_batch_token_offset,
cu_seqlens_q_output,
seq_lens_this_time,
seq_lens_encoder,
accept_num,
accept_tokens,
)
else:
raise NotImplementedError
Expand Down
8 changes: 4 additions & 4 deletions fastdeploy/model_executor/layers/sample/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,9 +729,7 @@ def compute_logprobs(
if top_p_logprob is not None:
last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs)

# NOTE(huicongyao) temporarily used for slice last_logprobs to its real shape, remove in the future
real_token_num = batch_token_num.sum().item()
return last_logprobs[:real_token_num]
return last_logprobs

def gather_logprobs(
self,
Expand Down Expand Up @@ -1059,7 +1057,9 @@ def forward_cuda(
)
sampler_output.logprobs_tensors = logprobs_tensors
if cu_batch_token_offset is not None:
sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
cu_batch_token_offset_cpu = paddle.empty_like(cu_batch_token_offset, device="cpu").pin_memory()
cu_batch_token_offset_cpu.copy_(cu_batch_token_offset, False)

This comment was marked as outdated.

sampler_output.cu_batch_token_offset = cu_batch_token_offset_cpu
return sampler_output

def _normal_sample_xpu(
Expand Down
2 changes: 1 addition & 1 deletion tests/layers/test_speculative_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def test_speculative_sampler():
increment_value = (max_draft_token_num + 1) * 4

sampler = SpeculativeSampler(fd_config)
sampler(logits, sampling_metadata, max_model_len, share_inputs, token_num_output_cpu, increment_value)
sampler(logits, sampling_metadata, max_model_len, share_inputs, token_num_output_cpu, increment_value, batch_size)


def test_speculative_sampler_logprobs():
Expand Down
Loading
Loading