Skip to content

Commit aaf33ed

Browse files
committed
conflicts
1 parent df8ed96 commit aaf33ed

4 files changed

Lines changed: 12 additions & 319 deletions

File tree

op/sglang/csrc/elementwise/fused_rotary_emb.cu

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,11 @@ int64_t fused_mla_absorb_rotary_emb(
9292

9393

9494

95-
// dim3 grid = dim3((q_len/4 +4)*(num_local_heads+1)-1, 1, 1);
95+
// BMM part: each thread block covers 256 N values (4 waves * 64 N per wave),
96+
// so the number of N-blocks per head is kv_lora_rank / 256.
97+
// Previously this used kv_lora_rank / 128, which launched 2x too many BMM
98+
// blocks; the extra blocks had hdx >= num_local_heads, read out-of-bounds
99+
// w_kc/q and overwrote valid q_input rows with garbage/NaN.
96100
dim3 grid = dim3((q_len + 15)/16 * kv_lora_rank/256 * num_local_heads + (q_len+3)/4 * num_local_heads + (q_len+3)/4, 1, 1);
97101
dim3 block = dim3(256, 1, 1);
98102
const int latent_cache_stride = latent_cache.stride(0);

op/sglang/include/fused_mla_impl.cuh

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,13 @@ __global__ void fused_absorb_mla(
8585
uint32_t bidx = blockIdx.x;
8686
uint32_t tid = threadIdx.x;
8787

88-
if (bidx < (Q_LEN + 15)/16*4*NUM_LOCAL_HEADS) {
88+
// BMM branch: (Q_LEN+15)/16 M-blocks * (KV_LORA_RANK/256) N-blocks per head.
89+
// Each do_bmm block covers 256 N values, so N-blocks per head = KV_LORA_RANK/256.
90+
if (bidx < (Q_LEN + 15)/16*(KV_LORA_RANK/256)*NUM_LOCAL_HEADS) {
8991
do_bmm<scalar_t, 1, QK_NOPE_HEAD_DIM/16, 4, NUM_LOCAL_HEADS, KV_LORA_RANK, QK_NOPE_HEAD_DIM, QK_ROPE_HEAD_DIM>(Q_LEN, q, w_kc, q_input, tid, bidx);
90-
} else if (bidx < ((Q_LEN+3)/4 + (Q_LEN + 15)/16*4) * NUM_LOCAL_HEADS) {
92+
} else if (bidx < ((Q_LEN+3)/4 + (Q_LEN + 15)/16*(KV_LORA_RANK/256)) * NUM_LOCAL_HEADS) {
9193
//do t1/t2
92-
bidx -= (Q_LEN + 15)/16*4*NUM_LOCAL_HEADS;
94+
bidx -= (Q_LEN + 15)/16*(KV_LORA_RANK/256)*NUM_LOCAL_HEADS;
9395
bidx = 4*bidx;
9496

9597
//#pragma unroll
@@ -113,7 +115,7 @@ __global__ void fused_absorb_mla(
113115
);
114116
}
115117
} else {
116-
bidx -= ((Q_LEN+3)/4 + (Q_LEN + 15)/16*4) * NUM_LOCAL_HEADS;
118+
bidx -= ((Q_LEN+3)/4 + (Q_LEN + 15)/16*(KV_LORA_RANK/256)) * NUM_LOCAL_HEADS;
117119
bidx *= 4;
118120

119121
uint32_t m = bidx + tid/QK_ROPE_HEAD_DIM;

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def compute_num_jobs(self):
151151
num_jobs = len(os.sched_getaffinity(0))
152152
except AttributeError:
153153
num_jobs = os.cpu_count()
154-
nvcc_threads = 1
154+
nvcc_threads = 10
155155
return num_jobs, nvcc_threads
156156

157157
#

unit_test/test_fused_mla_absorb_rope.py

Lines changed: 0 additions & 313 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
<<<<<<< HEAD (114324 MC3-8755 sgl057 fused_moe_gate_opt op support glm5 config)
21
import os
32
import time
43
import sys
@@ -641,315 +640,3 @@ def save_tensor_to_bin(tensor, string_name):
641640
show_error(legacy_q_input, fused_q_input, "DIFF ERROR OF Q_INPUT")
642641
show_error(legacy_k_input, fused_k_input, "DIFF ERROR OF K_INPUT")
643642
show_error(legacy_v_input, fused_v_input, "DIFF ERROR OF V_INPUT")
644-
=======
645-
import os
646-
import time
647-
import torch
648-
import torch.nn.functional as F
649-
from torch import nn
650-
from torch.profiler import profile, record_function, ProfilerActivity
651-
import argparse
652-
from sgl_kernel import fused_mla_absorb_rotary_emb
653-
654-
# ============================================================
655-
# Standard GPT-J style rotary embedding (matching kernel implementation)
656-
# ============================================================
657-
def compute_cos_sin_cache(
658-
max_position_embeddings: int,
659-
head_dim: int,
660-
base: float = 10000.0,
661-
device: torch.device = torch.device("cuda"),
662-
dtype: torch.dtype = torch.float32
663-
) -> torch.Tensor:
664-
"""
665-
Compute cos/sin cache for rotary position embedding.
666-
667-
The cache layout is [max_position_embeddings, head_dim] where:
668-
- cache[:, :head_dim//2] contains cos values
669-
- cache[:, head_dim//2:] contains sin values
670-
671-
This matches the format expected by the CUDA kernel.
672-
"""
673-
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=dtype, device=device) / head_dim))
674-
t = torch.arange(max_position_embeddings, dtype=dtype, device=device)
675-
freqs = torch.einsum("i,j->ij", t, inv_freq)
676-
cos = freqs.cos()
677-
sin = freqs.sin()
678-
cache = torch.cat([cos, sin], dim=-1)
679-
return cache
680-
681-
682-
def torch_rotary_emb_gptj_style(
683-
x: torch.Tensor,
684-
cos_sin_cache: torch.Tensor,
685-
positions: torch.Tensor
686-
) -> torch.Tensor:
687-
"""
688-
PyTorch reference implementation of GPT-J style rotary position embedding.
689-
690-
GPT-J style rotates pairs of elements:
691-
- out[2i] = x[2i] * cos - x[2i+1] * sin
692-
- out[2i+1] = x[2i+1] * cos + x[2i] * sin
693-
694-
Args:
695-
x: Input tensor of shape [..., head_dim], e.g., [q_len, num_heads, head_dim]
696-
cos_sin_cache: Cache of shape [max_pos, head_dim] containing [cos, sin]
697-
positions: Position indices of shape [q_len]
698-
699-
Returns:
700-
Rotated tensor of same shape as x
701-
"""
702-
head_dim = x.shape[-1]
703-
704-
# Get cos/sin for each position
705-
cos_sin = cos_sin_cache[positions] # [q_len, head_dim]
706-
cos = cos_sin[..., :head_dim // 2] # [q_len, head_dim//2]
707-
sin = cos_sin[..., head_dim // 2:] # [q_len, head_dim//2]
708-
709-
# Reshape cos/sin to broadcast with input tensor
710-
# x shape: [q_len, num_heads, head_dim]
711-
# We need cos/sin shape: [q_len, 1, head_dim//2] for proper broadcasting
712-
while cos.dim() < x.dim():
713-
cos = cos.unsqueeze(1)
714-
sin = sin.unsqueeze(1)
715-
716-
# Split x into its even- and odd-indexed elements (the pairs rotated in GPT-J style)
717-
# x1 = x[..., ::2], x2 = x[..., 1::2]
718-
x1 = x[..., ::2] # Even indices: [..., head_dim//2]
719-
x2 = x[..., 1::2] # Odd indices: [..., head_dim//2]
720-
721-
# Apply rotation
722-
o1 = x1 * cos - x2 * sin
723-
o2 = x2 * cos + x1 * sin
724-
725-
# Interleave output
726-
out = torch.stack([o1, o2], dim=-1).flatten(-2)
727-
return out
728-
729-
730-
class RMSNorm(nn.Module):
731-
def __init__(self, hidden_size, eps=1e-6):
732-
super().__init__()
733-
self.variance_epsilon = eps # Changed to 1e-6 to match kernel
734-
self.weight = nn.Parameter(torch.ones(hidden_size).to(torch.bfloat16))
735-
736-
def forward(
737-
self,
738-
x: torch.Tensor,
739-
residual= None,
740-
):
741-
orig_dtype = x.dtype
742-
x = x.to(torch.float32)
743-
if residual is not None:
744-
x = x + residual.to(torch.float32)
745-
residual = x.to(orig_dtype)
746-
747-
variance = x.pow(2).mean(dim=-1, keepdim=True)
748-
x = x * torch.rsqrt(variance + self.variance_epsilon)
749-
x = x.to(orig_dtype) * self.weight
750-
if residual is None:
751-
return x
752-
else:
753-
return x, residual
754-
755-
def fused_forward_absorb(
756-
q:torch.Tensor, # [bs, 128, 192], dtype=bf16
757-
w_kc:torch.Tensor, # [128, 128, 512], dtype=bf16
758-
latent_cache:torch.Tensor, # [bs, 576], dtype=bf16
759-
cos_sin_cache:torch.Tensor, # [max_position_embeddings, 64], dtype=float32
760-
positions:torch.Tensor, # [bs], dtype=int64
761-
norm_weight:torch.Tensor, # [512], dtype=bf16
762-
q_input:torch.Tensor, # [bs, 128, 576], dtype=bf16
763-
k_input:torch.Tensor, # [bs, 1, 576], dtype=bf16
764-
v_input:torch.Tensor, # [bs, 1, 512]
765-
q_len:int, #16
766-
num_local_heads:int, #128,
767-
kv_lora_rank:int, # 512
768-
qk_rope_head_dim:int, #64
769-
qk_nope_head_dim:int, #128
770-
):
771-
out = fused_mla_absorb_rotary_emb(q, w_kc, latent_cache, cos_sin_cache, positions, norm_weight, q_input, k_input, v_input, q_len, num_local_heads, kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim)
772-
if out != 0:
773-
print("Failed to call fusedMLA.[fused_forward_absorb]")
774-
return q_input, k_input, v_input
775-
776-
def mla_absorb_rotary_emb(
777-
kv_a_layernorm,
778-
cos_sin_cache, # Standard format: [max_pos, head_dim] with [cos, sin]
779-
q:torch.Tensor, # [bs, 128, 192], dtype=bf16
780-
w_kc:torch.Tensor, # [128, 128, 512], dtype=bf16
781-
latent_cache:torch.Tensor, # [bs, 576], dtype=bf16
782-
positions:torch.Tensor, # [bs], dtype=int64
783-
q_input:torch.Tensor, # [bs, 128, 576], dtype=bf16
784-
k_input:torch.Tensor, # [bs, 1, 576], dtype=bf16
785-
v_input:torch.Tensor, # [bs, 1, 512]
786-
q_len:int, # 16
787-
num_local_heads:int, # 128,
788-
kv_lora_rank:int, # 512
789-
qk_rope_head_dim:int, # 64
790-
qk_nope_head_dim:int, # 128
791-
):
792-
"""
793-
Reference PyTorch implementation that matches the CUDA kernel logic.
794-
795-
This implementation uses the same cos_sin_cache format and GPT-J style
796-
rotary embedding as the kernel, ensuring numerical consistency.
797-
"""
798-
# Step 1: BMM - Compute q_nope @ w_kc
799-
q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
800-
q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)
801-
q_input[..., : kv_lora_rank] = q_nope_out.transpose(0, 1)
802-
803-
# Step 2: RMS Norm on latent_cache
804-
v_input = latent_cache[..., : kv_lora_rank]
805-
v_input = kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
806-
807-
# Step 3: Prepare k_input
808-
k_input = latent_cache.unsqueeze(1)
809-
k_input[..., : kv_lora_rank] = v_input
810-
811-
# Step 4: Apply rotary embedding using GPT-J style (matching kernel)
812-
k_pe = k_input[..., kv_lora_rank :]
813-
814-
# Apply GPT-J style rotary embedding
815-
q_pe_rotated = torch_rotary_emb_gptj_style(q_pe, cos_sin_cache, positions)
816-
k_pe_rotated = torch_rotary_emb_gptj_style(k_pe.squeeze(1), cos_sin_cache, positions).unsqueeze(1)
817-
818-
# Step 5: Store rotated results
819-
q_input[..., kv_lora_rank :] = q_pe_rotated
820-
k_input[..., kv_lora_rank :] = k_pe_rotated
821-
822-
return q_input, k_input, v_input
823-
824-
with_profile=False
825-
826-
def show_error(golden, v, tag="DIFF ERROR"):
827-
errors = torch.abs(golden - v)
828-
829-
errors_max = torch.max(errors)
830-
errors_ave = torch.sum(errors) / errors.numel()
831-
832-
max_idx_flat = torch.argmax(errors)
833-
max_idx = torch.unravel_index(max_idx_flat, errors.shape)
834-
835-
golden_val = golden[max_idx]
836-
v_val = v[max_idx]
837-
838-
print(f"{tag}: error_max={errors_max}, error_ave={errors_ave}, max_error_idx={max_idx}")
839-
print(f"golden[{max_idx}]={golden_val}, v[{max_idx}]={v_val}")
840-
841-
def print_profiler_summary(prof, max_key_len=50):
842-
events = prof.key_averages()
843-
events = sorted(events, key=lambda x: (x.device_time_total / x.count) if x.count > 0 else 0, reverse=True)
844-
845-
print(f"{'Name':<{max_key_len}} | {'CPU Time Avg (us)':>20} | {'CUDA Time Avg (us)':>20} | {'Count':>10}")
846-
847-
total_cpu_time = 0.0
848-
total_cuda_time = 0.0
849-
850-
for evt in events:
851-
if evt.count == 0:
852-
continue
853-
854-
cpu_time_avg = evt.cpu_time_total / evt.count
855-
cuda_time_avg = evt.device_time_total / evt.count
856-
key_str = evt.key
857-
858-
if len(key_str) > max_key_len:
859-
key_str = key_str[:max_key_len-3] + '...'
860-
861-
print(f"{key_str.ljust(max_key_len)} | {cpu_time_avg:20.2f} | {cuda_time_avg:20.2f} | {evt.count:10}")
862-
863-
total_cpu_time += cpu_time_avg
864-
total_cuda_time += cuda_time_avg
865-
866-
print("-" * (max_key_len + 55))
867-
print(f"{'Total'.ljust(max_key_len)} | {total_cpu_time:20.2f} | {total_cuda_time:20.2f} |")
868-
869-
870-
if __name__ == "__main__":
871-
parser = argparse.ArgumentParser()
872-
parser.add_argument('--mode', type=str, default="profile", choices=["profile", "acc"])
873-
args = parser.parse_args()
874-
875-
# MLA parameters (matching GLM5 configuration)
876-
q_len = 32
877-
num_local_heads = 4
878-
kv_lora_rank = 512
879-
qk_nope_head_dim = 192
880-
qk_rope_head_dim = 64
881-
hidden_size = 6144
882-
883-
# RMS norm with eps=1e-6 to match kernel
884-
kv_a_layernorm = RMSNorm(kv_lora_rank, eps=1e-6).cuda()
885-
max_position_embeddings = 4096 # Standard size for testing
886-
rope_theta = 10000 # Standard base
887-
888-
# Create standard format cos_sin_cache (matching kernel expectation)
889-
# Format: [max_position_embeddings, qk_rope_head_dim]
890-
# with cos in first half and sin in second half
891-
cos_sin_cache = compute_cos_sin_cache(
892-
max_position_embeddings,
893-
qk_rope_head_dim,
894-
base=rope_theta,
895-
device=torch.device("cuda"),
896-
dtype=torch.float32 # Kernel expects float32
897-
)
898-
print(f"cos_sin_cache shape: {cos_sin_cache.shape}, dtype: {cos_sin_cache.dtype}")
899-
print(f"cos_sin_cache format: cos[:32], sin[32:64] for each position")
900-
901-
for q_len in [64]:
902-
print(f"\n\n================ Profiling q_len={q_len} ================")
903-
q = (torch.rand(q_len, num_local_heads, qk_nope_head_dim+qk_rope_head_dim, dtype=torch.bfloat16).cuda() - 0.5)/10
904-
w_kc = (torch.rand(num_local_heads, qk_nope_head_dim, kv_lora_rank, dtype=torch.bfloat16).cuda() - 0.5) / 10
905-
906-
shape = (q_len, kv_lora_rank+qk_rope_head_dim)
907-
strides = (576, 1) # contiguous stride for GLM5
908-
storage_size = (shape[0] - 1) * strides[0] + (shape[1] - 1) * strides[1] + 1
909-
latent_cache_storage = (torch.rand(storage_size, dtype=torch.bfloat16).cuda() - 0.5) / 10
910-
latent_cache = torch.as_strided(latent_cache_storage, size=shape, stride=strides)
911-
latent_cache2_storage = latent_cache_storage.clone().detach()
912-
latent_cache2 = torch.as_strided(latent_cache2_storage, size=shape, stride=strides)
913-
914-
q_input = torch.zeros(q_len, num_local_heads, kv_lora_rank + qk_rope_head_dim, dtype=torch.bfloat16).cuda()
915-
k_input = torch.zeros(q_len, 1, kv_lora_rank + qk_rope_head_dim, dtype=torch.bfloat16).cuda()
916-
v_input = torch.zeros(q_len, 1, kv_lora_rank, dtype=torch.bfloat16).cuda()
917-
fused_q_input = torch.zeros(q_len, num_local_heads, kv_lora_rank + qk_rope_head_dim, dtype=torch.bfloat16).cuda()
918-
fused_k_input = torch.zeros(q_len, 1, kv_lora_rank + qk_rope_head_dim, dtype=torch.bfloat16).cuda()
919-
fused_v_input = torch.zeros(q_len, 1, kv_lora_rank, dtype=torch.bfloat16).cuda()
920-
921-
positions = torch.arange(0, q_len, dtype=torch.int64).cuda() # Start from 0
922-
923-
print("w_kc stride:", w_kc.stride())
924-
print("latent_cache stride:", latent_cache.stride())
925-
print("latent_cache2 stride:", latent_cache2.stride())
926-
927-
# Run PyTorch reference implementation (using standard GPT-J style rotary)
928-
legacy_q_input, legacy_k_input, legacy_v_input = mla_absorb_rotary_emb(
929-
kv_a_layernorm, cos_sin_cache,
930-
q, w_kc, latent_cache, positions,
931-
q_input, k_input, v_input,
932-
q_len, num_local_heads, kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim
933-
)
934-
935-
# Run CUDA kernel
936-
fused_forward_absorb(
937-
q, w_kc, latent_cache2, cos_sin_cache, positions, kv_a_layernorm.weight,
938-
fused_q_input, fused_k_input, fused_v_input,
939-
q_len, num_local_heads, kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim
940-
)
941-
942-
show_error(legacy_q_input, fused_q_input, "DIFF ERROR OF Q_INPUT")
943-
show_error(legacy_k_input, fused_k_input, "DIFF ERROR OF K_INPUT")
944-
show_error(legacy_v_input, fused_v_input, "DIFF ERROR OF V_INPUT")
945-
946-
# Additional: Check BMM separately
947-
print("\n--- BMM Verification ---")
948-
q_nope = q[..., :qk_nope_head_dim]
949-
q_nope_out_torch = torch.bmm(q_nope.transpose(0, 1).float(), w_kc.float())
950-
q_nope_out_torch = q_nope_out_torch.transpose(0, 1).bfloat16()
951-
952-
# Compare with CUDA result
953-
bmm_error = torch.abs(q_nope_out_torch - fused_q_input[..., :kv_lora_rank])
954-
print(f"BMM max error: {bmm_error.max().item()}, avg error: {bmm_error.mean().item()}")
955-
>>>>>>> CHANGE (4a11cf MC3-8615 fused_mla_absorb_rotary_emb support glm5 model)

0 commit comments

Comments
 (0)