@@ -77,7 +77,7 @@ def apply_penalty_multi_scores(
from fastdeploy.model_executor.ops.xpu import get_token_penalty_multi_scores

logits = get_token_penalty_multi_scores(
pre_token_ids,
token_ids_all,
logits,
repetition_penalties,
frequency_penalties,
@@ -179,7 +179,6 @@ def apply_speculative_penalty_multi_scores(
batch_id_per_token_output: paddle.Tensor,
cu_seqlens_q_output: paddle.Tensor,
max_len: int,
pre_token_ids: Optional[paddle.Tensor] = None, # used in xpu
):
"""
apply_speculative_penalty_multi_scores
@@ -213,7 +212,7 @@ def apply_speculative_penalty_multi_scores(
)

speculate_get_token_penalty_multi_scores(
pre_token_ids,
token_ids_all,
logits,
repetition_penalties,
frequency_penalties,
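For context on what these call sites compute, below is a minimal standalone sketch of the usual repetition/frequency/presence penalty math driven by a full token-history buffer such as token_ids_all. It is not the FastDeploy get_token_penalty_multi_scores kernel and does not use its real signature; the function name, the -1 padding convention, and all shapes are assumptions for illustration only.

import paddle
import paddle.nn.functional as F


def reference_token_penalties(token_ids_all, logits, repetition_penalties,
                              frequency_penalties, presence_penalties):
    # Assumed shapes: token_ids_all [batch, history_len] int64 padded with -1,
    # logits [batch, vocab], each penalty tensor [batch, 1] in logits' dtype.
    vocab_size = logits.shape[-1]
    valid = (token_ids_all >= 0).astype(logits.dtype)      # 1.0 for real tokens, 0.0 for padding
    safe_ids = paddle.clip(token_ids_all, min=0)           # map padding ids to 0; masked out below
    one_hot = F.one_hot(safe_ids, vocab_size).astype(logits.dtype)
    counts = (one_hot * valid.unsqueeze(-1)).sum(axis=1)   # per-sequence occurrence count per vocab id
    appeared_mask = counts > 0
    appeared = appeared_mask.astype(logits.dtype)
    # Repetition penalty: damp the logit of every token that already appeared.
    damped = paddle.where(logits > 0,
                          logits / repetition_penalties,
                          logits * repetition_penalties)
    logits = paddle.where(appeared_mask, damped, logits)
    # Frequency / presence penalties, OpenAI-style.
    return logits - counts * frequency_penalties - appeared * presence_penalties
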
4 changes: 1 addition & 3 deletions fastdeploy/model_executor/layers/sample/sampler.py
@@ -525,7 +525,6 @@ def forward_cuda(

for proc in sampling_metadata.logits_processors or []:
logits = proc.apply(logits)

logits = apply_penalty_multi_scores(
sampling_metadata.token_ids_all,
logits,
@@ -1174,6 +1173,7 @@ def forward_xpu(
accept_all_drafts: bool = False,
reject_all_drafts: bool = False,
) -> SamplerOutput:

logits = apply_speculative_penalty_multi_scores(
sampling_metadata.token_ids_all,
sampling_metadata.prompt_lens,
@@ -1191,7 +1191,6 @@ def forward_xpu(
share_inputs["batch_id_per_token_output"],
share_inputs["cu_seqlens_q_output"],
max_model_len,
sampling_metadata.pre_token_ids,
)

if self.enf_gen_phase_tag:
@@ -1436,7 +1435,6 @@ def forward_xpu(
share_inputs["batch_id_per_token_output"],
share_inputs["cu_seqlens_q_output"],
max_model_len,
sampling_metadata.pre_token_ids,
)
probs = F.softmax(logits)
next_tokens = paddle.argmax(probs, axis=-1)
180 changes: 18 additions & 162 deletions fastdeploy/model_executor/xpu_pre_and_post_process.py
@@ -21,7 +21,6 @@
import paddle

from fastdeploy import envs
from fastdeploy.config import SpeculativeConfig
from fastdeploy.model_executor.forward_meta import XPUForwardMeta
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
@@ -43,12 +42,7 @@
speculate_pre_process,
speculate_save_output,
speculate_set_stop_value_multi_seqs,
speculate_step_paddle,
speculate_step_reschedule,
speculate_step_system_cache,
step_paddle,
unified_update_model_status,
update_inputs,
update_inputs_v1,
)
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
@@ -367,33 +361,22 @@ def xpu_post_process_normal(

# 2. Update the input buffer of the model
with paddle.framework._no_check_dy2st_diff():
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
update_inputs_v1(
model_output.stop_flags,
model_output.not_need_stop,
model_output.seq_lens_this_time,
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
share_inputs["step_seq_lens_decoder"],
share_inputs["prompt_lens"],
sampled_token_ids,
model_output.input_ids,
share_inputs["block_tables"],
model_output.next_tokens,
model_output.is_block_step,
block_size,
)
else:
update_inputs(
model_output.stop_flags,
model_output.not_need_stop,
model_output.seq_lens_this_time,
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
model_output.input_ids,
sampled_token_ids,
model_output.is_block_step,
)
update_inputs_v1(
model_output.stop_flags,
model_output.not_need_stop,
model_output.seq_lens_this_time,
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
share_inputs["step_seq_lens_decoder"],
share_inputs["prompt_lens"],
sampled_token_ids,
model_output.input_ids,
share_inputs["block_tables"],
model_output.next_tokens,
model_output.is_block_step,
block_size,
)

# 3. Transmit the model's output and stop generation signal via message queue.
# In the future, we will abandon this approach.
if not skip_save_output:
@@ -464,7 +447,7 @@ def xpu_post_process_speculate(
speculate_set_stop_value_multi_seqs(
model_output.accept_tokens,
model_output.accept_num,
model_output.pre_ids,
model_output.token_ids_all,
model_output.step_idx,
model_output.stop_flags,
model_output.seq_lens_this_time,
@@ -486,7 +469,7 @@ def xpu_post_process_speculate(
model_output.seq_lens_this_time,
model_output.is_block_step,
model_output.mask_rollback,
model_output.pre_ids,
model_output.token_ids_all,
model_output.prompt_lens,
model_output.step_idx,
model_output.eos_token_id,
@@ -513,130 +496,3 @@

speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)
share_inputs["preempted_idx"][:] = 0


def step_xpu(
share_inputs: Dict[str, paddle.Tensor],
block_size: int,
enc_dec_block_num: int,
speculative_config: SpeculativeConfig,
enable_prefix_caching: bool = False,
) -> None:
if speculative_config.method is not None:
if DISABLE_RECOVER:
speculate_step_reschedule(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
share_inputs["accept_num"],
block_size,
enc_dec_block_num,
speculative_config.num_speculative_tokens,
)
else:
if enable_prefix_caching:
speculate_step_system_cache(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["step_seq_lens_decoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
share_inputs["accept_num"],
block_size,
enc_dec_block_num,
speculative_config.num_speculative_tokens,
)
else:
speculate_step_paddle(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
share_inputs["accept_num"],
block_size,
enc_dec_block_num,
speculative_config.num_speculative_tokens,
)
else:
# TODO(chenhuan09): add step system cache/reschedule support
step_paddle(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
block_size,
enc_dec_block_num,
)
2 changes: 1 addition & 1 deletion fastdeploy/spec_decode/mtp.py
@@ -1115,7 +1115,7 @@ def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = Fa
top_k=self.model_inputs["top_k"],
seed=self.model_inputs["infer_seed"],
step_idx=self.model_inputs["step_idx"],
pre_token_ids=self.model_inputs["pre_ids"],
token_ids_all=self.model_inputs["token_ids_all"],
frequency_penalties=self.model_inputs["frequency_score"],
presence_penalties=self.model_inputs["presence_score"],
repetition_penalties=self.model_inputs["penalty_score"],
3 changes: 2 additions & 1 deletion fastdeploy/worker/input_batch.py
@@ -763,7 +763,7 @@ def init_share_inputs(self):
self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"])
self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu")
self.not_need_stop_device = paddle.to_tensor([False], dtype="bool")
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_xpu():
self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
if "token_ids_all" in self.target_model_input_batch:
@@ -787,6 +787,7 @@ def init_share_inputs(self):
self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])

self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"])
self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"])
self.cu_seqlens_q = paddle.clone(self.target_model_input_batch["cu_seqlens_q"])