Skip to content

Commit a66054f

Browse files
committed
deepseek-v4: defunctionalize fused MLA insert op
Signed-off-by: jasl <jasl9187@hotmail.com>
1 parent d6da156 commit a66054f

2 files changed

Lines changed: 76 additions & 0 deletions

File tree

tests/compile/passes/test_functionalization.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,12 +251,72 @@ def ops_not_in_model(self):
251251
return []
252252

253253

254+
class TestFusedDeepseekV4QnormRopeKvInsert(torch.nn.Module):
255+
OP_REGISTERED = False
256+
257+
def __init__(self):
258+
super().__init__()
259+
self.register_test_custom_op()
260+
261+
@classmethod
262+
def register_test_custom_op(cls):
263+
if not cls.OP_REGISTERED:
264+
265+
def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_impl(
266+
q: torch.Tensor,
267+
kv: torch.Tensor,
268+
k_cache: torch.Tensor,
269+
) -> None:
270+
q.add_(kv)
271+
k_cache.add_(kv)
272+
273+
def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake(
274+
q: torch.Tensor,
275+
kv: torch.Tensor,
276+
k_cache: torch.Tensor,
277+
) -> None:
278+
return None
279+
280+
direct_register_custom_op(
281+
op_name="fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
282+
op_func=fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_impl,
283+
mutates_args=["q", "k_cache"],
284+
fake_impl=fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake,
285+
)
286+
287+
cls.OP_REGISTERED = True
288+
289+
def forward(
290+
self, q: torch.Tensor, kv: torch.Tensor, k_cache: torch.Tensor
291+
) -> tuple[torch.Tensor, torch.Tensor]:
292+
torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
293+
q, kv, k_cache
294+
)
295+
return q, k_cache
296+
297+
def example_inputs(self, num_tokens=32, hidden_size=128):
298+
return (
299+
torch.rand(num_tokens, hidden_size),
300+
torch.rand(num_tokens, hidden_size),
301+
torch.rand(num_tokens, hidden_size),
302+
)
303+
304+
def ops_in_model(self, do_fusion):
305+
return [
306+
torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default
307+
]
308+
309+
def ops_not_in_model(self):
310+
return []
311+
312+
254313
MODELS_AND_DO_FUSION = {
255314
TestSiluMul: [True, False],
256315
TestFusedAddRMSNorm: [True, False],
257316
TestRotaryEmbedding: [False],
258317
TestRotaryEmbeddingSliceScatter: [False],
259318
TestFunctionWithMutatedArgsAndReturn: [False],
319+
TestFusedDeepseekV4QnormRopeKvInsert: [False],
260320
}
261321

262322

vllm/compilation/passes/utility/fix_functionalization.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,24 @@ def __call__(self, graph: torch.fx.Graph) -> None:
3939
count = 0
4040

4141
rope_targets = [torch.ops._C.rotary_embedding.default]
42+
fused_deepseek_v4_mla_targets = []
4243

4344
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
4445
rope_targets.append(
4546
torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
4647
)
48+
if hasattr(
49+
torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert"
50+
):
51+
fused_deepseek_v4_mla_targets.append(
52+
torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default
53+
)
54+
if hasattr(
55+
torch.ops.vllm, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert"
56+
):
57+
fused_deepseek_v4_mla_targets.append(
58+
torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default
59+
)
4760

4861
for node in graph.nodes:
4962
if not is_func(node, auto_functionalized):
@@ -181,6 +194,9 @@ def __call__(self, graph: torch.fx.Graph) -> None:
181194
2: "key",
182195
}
183196
self.defunctionalize(graph, node, mutated_args=mutated_args)
197+
elif at_target in fused_deepseek_v4_mla_targets:
198+
mutated_args = {1: "q", 2: "k_cache"}
199+
self.defunctionalize(graph, node, mutated_args)
184200
elif (
185201
hasattr(torch.ops.vllm, "fused_rope_unified_mla_kv_cache_update")
186202
and at_target

0 commit comments

Comments
 (0)