@@ -251,12 +251,72 @@ def ops_not_in_model(self):
251251 return []
252252
253253
254+ class TestFusedDeepseekV4QnormRopeKvInsert (torch .nn .Module ):
255+ OP_REGISTERED = False
256+
257+ def __init__ (self ):
258+ super ().__init__ ()
259+ self .register_test_custom_op ()
260+
261+ @classmethod
262+ def register_test_custom_op (cls ):
263+ if not cls .OP_REGISTERED :
264+
265+ def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_impl (
266+ q : torch .Tensor ,
267+ kv : torch .Tensor ,
268+ k_cache : torch .Tensor ,
269+ ) -> None :
270+ q .add_ (kv )
271+ k_cache .add_ (kv )
272+
273+ def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake (
274+ q : torch .Tensor ,
275+ kv : torch .Tensor ,
276+ k_cache : torch .Tensor ,
277+ ) -> None :
278+ return None
279+
280+ direct_register_custom_op (
281+ op_name = "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert" ,
282+ op_func = fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_impl ,
283+ mutates_args = ["q" , "k_cache" ],
284+ fake_impl = fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake ,
285+ )
286+
287+ cls .OP_REGISTERED = True
288+
289+ def forward (
290+ self , q : torch .Tensor , kv : torch .Tensor , k_cache : torch .Tensor
291+ ) -> tuple [torch .Tensor , torch .Tensor ]:
292+ torch .ops .vllm .fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert (
293+ q , kv , k_cache
294+ )
295+ return q , k_cache
296+
297+ def example_inputs (self , num_tokens = 32 , hidden_size = 128 ):
298+ return (
299+ torch .rand (num_tokens , hidden_size ),
300+ torch .rand (num_tokens , hidden_size ),
301+ torch .rand (num_tokens , hidden_size ),
302+ )
303+
304+ def ops_in_model (self , do_fusion ):
305+ return [
306+ torch .ops .vllm .fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert .default
307+ ]
308+
309+ def ops_not_in_model (self ):
310+ return []
311+
312+
254313MODELS_AND_DO_FUSION = {
255314 TestSiluMul : [True , False ],
256315 TestFusedAddRMSNorm : [True , False ],
257316 TestRotaryEmbedding : [False ],
258317 TestRotaryEmbeddingSliceScatter : [False ],
259318 TestFunctionWithMutatedArgsAndReturn : [False ],
319+ TestFusedDeepseekV4QnormRopeKvInsert : [False ],
260320}
261321
262322
0 commit comments