Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -155,29 +155,86 @@ def write(
start_write_block_idx: int,
timeout: float = 30.0,
) -> int:
logger.debug(
f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_write_block_idx: {start_write_block_idx} timeout: {timeout}"
)
tokens = Tokens(token_ids, self.config.block_token_size)
k_data_ptrs = [k.data_ptr() for k in key_cache]
v_data_ptrs = [v.data_ptr() for v in val_cache]
num = 0
try:
num = self.sdk.write(
list(range(self.config.layer_num)),
tokens,
start_write_block_idx,
k_data_ptrs,
v_data_ptrs,
gpu_block_ids,
timeout,
layer_ids = list(range(self.config.layer_num))
block_token_size = self.config.block_token_size

total_timeout = float(os.getenv("AS_WRITE_TOTAL_TIMEOUT", str(timeout)))
slice_block_num = int(os.getenv("AS_WRITE_SLICE_BLOCK_NUM", "500"))
slice_timeout = float(os.getenv("AS_WRITE_SLICE_TIMEOUT", "10"))
logger.debug(
f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids}"

This comment was marked as outdated.

f"start_write_block_idx: {start_write_block_idx} timeout: {total_timeout}"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 [WRITE BEGIN] 日志两段 f-string 直接拼接,中间缺少分隔符

f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids}"
f"start_write_block_idx: {start_write_block_idx} timeout: {total_timeout}"

Python 字符串字面量拼接不会自动插入空格,实际输出会是 ...gpu_block_ids: [...]start_write_block_idx: ...(两段紧挨在一起),可读性差。

建议修复为:

f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} "
f"start_write_block_idx: {start_write_block_idx} timeout: {total_timeout}"

)
total_blocks = len(gpu_block_ids)

This comment was marked as outdated.

total_written = 0
overall_start = time.time()

for slice_start in range(0, total_blocks, slice_block_num):
elapsed = time.time() - overall_start
remaining_timeout = total_timeout - elapsed
if remaining_timeout <= 0:
logger.warning(
f"[WRITE TIMEOUT] task_id: {task_id} total timeout {total_timeout}s reached, "
f"written {total_written}/{total_blocks} blocks"
)
break

slice_end = min(slice_start + slice_block_num, total_blocks)
slice_gpu_block_ids = gpu_block_ids[slice_start:slice_end]
slice_write_block_idx = start_write_block_idx + slice_start
slice_token_ids = token_ids[: (start_write_block_idx + slice_end) * block_token_size]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug token_ids 切片上界计算可能越界或截断错误

slice_token_ids = token_ids[: (start_write_block_idx + slice_end) * block_token_size]

这里假设 token_ids 包含 start_write_block_idx 之前所有 block 的 tokens,即调用方始终传入完整的累积 token 序列。若调用方只传入本次需要写入的 token 片段(不含前缀偏移),则上界会超出 token_ids 实际长度,Python 会静默截断,导致少写 token 或 Tokens 对象被构造为错误长度。

建议在循环外增加断言或日志,验证 len(token_ids) >= (start_write_block_idx + total_blocks) * block_token_size,或者在文档/注释中明确说明调用方的约定。

slice_tokens = Tokens(slice_token_ids, block_token_size)

effective_timeout = (
remaining_timeout if total_blocks <= slice_block_num else min(slice_timeout, remaining_timeout)
)
logger.debug(f"[WRITE END] task_id: {task_id} written_blocks: {num}")
except AttentionStoreSDKError:
logger.error(
f"[WRITE ERROR] failed to execute sdk write, task_id: {task_id}, traceback:\n{traceback.format_exc()}"
logger.debug(
f"[WRITE SLICE BEGIN] task_id: {task_id} slice [{slice_start}:{slice_end}] "
f"block_idx={slice_write_block_idx}, blocks={len(slice_gpu_block_ids)}, "
f"token_ids_len={len(slice_token_ids)}, timeout={effective_timeout:.2f}s"
)
return num
slice_start_time = time.time()
try:
written = self.sdk.write(
layer_ids,
slice_tokens,
slice_write_block_idx,
k_data_ptrs,
v_data_ptrs,
slice_gpu_block_ids,
effective_timeout,
)
except AttentionStoreSDKError:
logger.error(
f"[WRITE ERROR] task_id: {task_id} slice [{slice_start}:{slice_end}], "
f"traceback:\n{traceback.format_exc()}"
)
written = 0
slice_cost = time.time() - slice_start_time
total_written += written
Comment thread
jackyYang6 marked this conversation as resolved.

if written < len(slice_gpu_block_ids):
logger.warning(
f"[WRITE INCOMPLETE] task_id: {task_id} slice [{slice_start}:{slice_end}] "
f"({written}/{len(slice_gpu_block_ids)}), cost={slice_cost:.6f}s, "
f"total written {total_written}/{total_blocks}, "
f"prefix cache continuity broken, skip remaining slices"
)
break

logger.debug(
f"[WRITE SLICE] task_id: {task_id} slice [{slice_start}:{slice_end}] "
f"written={written}, cost={slice_cost:.6f}s"
)

total_cost = time.time() - overall_start
logger.info(

This comment was marked as outdated.

f"[WRITE END] task_id: {task_id} total cost={total_cost:.6f}s, "
f"written {total_written}/{total_blocks} blocks"
)
return total_written

def query(self, task_id: str, token_ids: List[int], start_match_block_idx: int, timeout: float = 10.0):
"""
Expand Down
Loading