Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastdeploy/cache_manager/cache_transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def _init_storage(self, args):
* self.cache_item_bytes,
device_id=self.device,
dp_id=self.local_data_parallel_id,
splitwise_role=getattr(args, "splitwise_role", "mixed"),
)
logger.info("Initialized attention store successfully!")
elif args.kvcache_storage_backend == "file":
Expand Down
6 changes: 2 additions & 4 deletions fastdeploy/cache_manager/prefix_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,10 +1120,7 @@ def write_cache_to_storage(self, request: Request):
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()

if self.config.cache_config.enable_output_caching:
input_token_ids = token_ids + request.output_token_ids
else:
input_token_ids = token_ids
input_token_ids = token_ids + request.output_token_ids
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 移除 enable_output_caching 条件是否符合预期?

原代码在 enable_output_caching=False 时,input_token_ids 只包含 input tokens;现在无论该配置如何,总会追加 request.output_token_ids。

虽然后续的截断 input_token_ids[: len(keys) * block_size] 可能会将 output tokens 剪掉,但这取决于 keys 覆盖的范围。如果 enable_output_caching=False 且 keys 的数量意外覆盖了 output token 区间,缓存内容将包含本不应写入的 output tokens,可能污染下次 prefix cache 命中结果。

请确认:当 enable_output_caching=False 时,keys 的长度是否严格不超过 input 对应的 block 数?如果可以保证,建议在此处添加注释说明该假设。


req_id = request.request_id
keys = []
Expand All @@ -1136,6 +1133,7 @@ def write_cache_to_storage(self, request: Request):
return

gpu_block_ids = request.block_tables[: len(keys)]
input_token_ids = input_token_ids[: len(keys) * self.config.cache_config.block_size]
logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}")
write_storage_task = WriteStorageTask(
task_id=req_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
"""

import os
import time
import traceback
from dataclasses import dataclass
Expand Down Expand Up @@ -51,6 +52,7 @@ class AttentionStoreConfig:
bytes_per_shard_layer_per_block: int = 1024
device_id: int = 0
dp_id: int = 0
splitwise_role: str = "mixed"


class AttentionStore(KVCacheStorage):
Expand All @@ -62,6 +64,13 @@ def __init__(self, **args):
self.config = AttentionStoreConfig(**args)

try:
self.config.namespace = os.getenv("AS_NAMESPACE", self.config.namespace)
self.config.pod_name = os.getenv("AS_POD_NAME", self.config.pod_name)
if int(os.getenv("ENABLE_EP_DP_IN_FD", "1")):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 int(os.getenv(...)) 缺少异常处理,健壮性不足。

当环境变量 ENABLE_EP_DP_IN_FD 被设置为非整数字符串(如 "true" / "yes" / "on")时,int() 会直接抛出 ValueError,导致初始化过程崩溃。

建议修改为更健壮的写法:

if os.getenv("ENABLE_EP_DP_IN_FD", "1") != "0":

self.config.pod_name = (
self.config.pod_name + "_" + self.config.splitwise_role + "_" + str(self.config.dp_id)
)
self.config.model_version = os.getenv("AS_MODEL_VERSION", self.config.model_version)
logger.info(f"[INIT] Start initializing AttentionStoreSDK with config: {self.config}")
self.sdk = AttentionStoreSDK(
self.config.namespace,
Expand Down
2 changes: 2 additions & 0 deletions tests/cache_manager/test_cache_transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class Args:
kvcache_storage_backend = None
write_policy = "write_through"
model_path = "test_model"
splitwise_role = "mixed"


# ==========================
Expand Down Expand Up @@ -717,6 +718,7 @@ class LocalArgs(Args):
* manager.cache_item_bytes,
device_id=manager.device,
dp_id=manager.local_data_parallel_id,
splitwise_role=LocalArgs.splitwise_role,
)

def test_invalid_write_policy_raises(self):
Expand Down
Loading