|
1 | | -<<<<<<< HEAD (114324 MC3-8755 sgl057 fused_moe_gate_opt op support glm5 config) |
2 | 1 | import os |
3 | 2 | import time |
4 | 3 | import sys |
@@ -641,315 +640,3 @@ def save_tensor_to_bin(tensor, string_name): |
641 | 640 | show_error(legacy_q_input, fused_q_input, "DIFF ERROR OF Q_INPUT") |
642 | 641 | show_error(legacy_k_input, fused_k_input, "DIFF ERROR OF K_INPUT") |
643 | 642 | show_error(legacy_v_input, fused_v_input, "DIFF ERROR OF V_INPUT") |
644 | | -======= |
645 | | -import os |
646 | | -import time |
647 | | -import torch |
648 | | -import torch.nn.functional as F |
649 | | -from torch import nn |
650 | | -from torch.profiler import profile, record_function, ProfilerActivity |
651 | | -import argparse |
652 | | -from sgl_kernel import fused_mla_absorb_rotary_emb |
653 | | - |
654 | | -# ============================================================ |
655 | | -# Standard GPT-J style rotary embedding (matching kernel implementation) |
656 | | -# ============================================================ |
def compute_cos_sin_cache(
    max_position_embeddings: int,
    head_dim: int,
    base: float = 10000.0,
    device: torch.device = torch.device("cuda"),
    dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Compute cos/sin cache for rotary position embedding.

    The cache layout is [max_position_embeddings, head_dim] where:
    - cache[:, :head_dim//2] contains cos values
    - cache[:, head_dim//2:] contains sin values

    This matches the format expected by the CUDA kernel.

    Args:
        max_position_embeddings: Number of positions (rows) to precompute.
        head_dim: Rotary dimension; must be even so that the cos and sin
            halves each hold head_dim // 2 frequencies.
        base: Base of the geometric frequency progression.
        device: Device on which the cache is created.
        dtype: Dtype used both for the computation and the returned cache.

    Returns:
        Tensor of shape [max_position_embeddings, head_dim].

    Raises:
        ValueError: If head_dim is odd. An odd value would produce
            head_dim + 1 columns and break the documented cos/sin split.
    """
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even, got {head_dim}")

    # One inverse frequency per rotated pair: base^(-2i / head_dim).
    exponents = torch.arange(0, head_dim, 2, dtype=dtype, device=device) / head_dim
    inv_freq = 1.0 / (base ** exponents)

    # Outer product position x frequency gives the rotation angle table.
    t = torch.arange(max_position_embeddings, dtype=dtype, device=device)
    freqs = torch.einsum("i,j->ij", t, inv_freq)

    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)
680 | | - |
681 | | - |
def torch_rotary_emb_gptj_style(
    x: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    positions: torch.Tensor
) -> torch.Tensor:
    """
    PyTorch reference implementation of GPT-J style rotary position embedding.

    GPT-J style rotates adjacent pairs of elements:
    - out[2i]   = x[2i]   * cos - x[2i+1] * sin
    - out[2i+1] = x[2i+1] * cos + x[2i]   * sin

    Args:
        x: Input tensor of shape [..., head_dim], e.g. [q_len, num_heads, head_dim].
            The leading dimension is indexed by position.
        cos_sin_cache: Cache of shape [max_pos, head_dim] holding cos values in
            the first half of the last dimension and sin values in the second.
        positions: Position indices of shape [q_len] used to gather cache rows.

    Returns:
        Rotated tensor with the same shape as x. Its dtype follows normal
        PyTorch type promotion between x and cos_sin_cache (e.g. a bf16 input
        with a float32 cache yields float32).
    """
    head_dim = x.shape[-1]
    half = head_dim // 2

    # Gather the per-position rows and split them into cos / sin halves.
    cos_sin = cos_sin_cache[positions]  # [q_len, head_dim]
    cos = cos_sin[..., :half]           # [q_len, head_dim//2]
    sin = cos_sin[..., half:]           # [q_len, head_dim//2]

    # Insert singleton axes after the position axis until cos/sin broadcast
    # against x, e.g. [q_len, 1, head_dim//2] for a [q_len, heads, head_dim] input.
    while cos.dim() < x.dim():
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    # De-interleave the input into even- and odd-indexed elements.
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]

    # 2-D rotation applied to each (even, odd) pair.
    out_even = x_even * cos - x_odd * sin
    out_odd = x_odd * cos + x_even * sin

    # Re-interleave: stack pairs on a new last axis, then flatten it away.
    return torch.stack([out_even, out_odd], dim=-1).flatten(-2)
728 | | - |
729 | | - |
class RMSNorm(nn.Module):
    """Root-mean-square layer norm with an optional fused residual add.

    The statistics are computed in float32; the normalized value is cast back
    to the input dtype before being scaled by a bfloat16 weight.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """Create the norm with an all-ones bfloat16 weight of size hidden_size."""
        super().__init__()
        self.variance_epsilon = eps  # Changed to 1e-6 to match kernel
        self.weight = nn.Parameter(torch.ones(hidden_size).to(torch.bfloat16))

    def forward(
        self,
        x: torch.Tensor,
        residual=None,
    ):
        """Normalize x over its last dimension.

        Args:
            x: Input tensor; normalized along dim=-1.
            residual: Optional tensor added to x (in float32) before
                normalization.

        Returns:
            The normalized tensor when residual is None; otherwise a tuple
            (normalized, updated_residual) where updated_residual is
            x + residual cast back to the dtype of x.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            # The pre-normalization sum is what gets handed back as the residual.
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight

        if residual is None:
            return x
        return x, residual
754 | | - |
def fused_forward_absorb(
    q: torch.Tensor,              # [bs, 128, 192], dtype=bf16
    w_kc: torch.Tensor,           # [128, 128, 512], dtype=bf16
    latent_cache: torch.Tensor,   # [bs, 576], dtype=bf16
    cos_sin_cache: torch.Tensor,  # [max_position_embeddings, 64], dtype=float32
    positions: torch.Tensor,      # [bs], dtype=int64
    norm_weight: torch.Tensor,    # [512], dtype=bf16
    q_input: torch.Tensor,        # [bs, 128, 576], dtype=bf16
    k_input: torch.Tensor,        # [bs, 1, 576], dtype=bf16
    v_input: torch.Tensor,        # [bs, 1, 512]
    q_len: int,                   # 16
    num_local_heads: int,         # 128,
    kv_lora_rank: int,            # 512
    qk_rope_head_dim: int,        # 64
    qk_nope_head_dim: int,        # 128
):
    """Invoke the fused `fused_mla_absorb_rotary_emb` kernel.

    All arguments are forwarded to the kernel unchanged and in order. The
    shape/dtype notes on the parameters are the example values recorded by the
    original author, not constraints enforced here.

    Returns:
        The (q_input, k_input, v_input) tensors that were passed in;
        presumably the kernel fills them in place -- confirm against the
        kernel implementation.
    """
    status = fused_mla_absorb_rotary_emb(
        q, w_kc, latent_cache, cos_sin_cache, positions, norm_weight,
        q_input, k_input, v_input,
        q_len, num_local_heads, kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim,
    )
    # A non-zero return is reported but not raised, so the caller still gets
    # the buffers back.
    # NOTE(review): on failure their contents may be stale.
    if status != 0:
        print("Failed to call fusedMLA.[fused_forward_absorb]")
    return q_input, k_input, v_input
775 | | - |
def mla_absorb_rotary_emb(
    kv_a_layernorm,
    cos_sin_cache,               # Standard format: [max_pos, head_dim] with [cos, sin]
    q: torch.Tensor,             # [bs, 128, 192], dtype=bf16
    w_kc: torch.Tensor,          # [128, 128, 512], dtype=bf16
    latent_cache: torch.Tensor,  # [bs, 576], dtype=bf16
    positions: torch.Tensor,     # [bs], dtype=int64
    q_input: torch.Tensor,       # [bs, 128, 576], dtype=bf16
    k_input: torch.Tensor,       # [bs, 1, 576], dtype=bf16
    v_input: torch.Tensor,       # [bs, 1, 512]
    q_len: int,                  # 16
    num_local_heads: int,        # 128,
    kv_lora_rank: int,           # 512
    qk_rope_head_dim: int,       # 64
    qk_nope_head_dim: int,       # 128
):
    """
    Reference PyTorch implementation that matches the CUDA kernel logic.

    This implementation uses the same cos_sin_cache format and GPT-J style
    rotary embedding as the kernel, ensuring numerical consistency.

    Side effects worth knowing about:
    - q_input is written in place.
    - The k_input and v_input arguments are NOT written; both names are
      rebound to new tensors, and those are what get returned.
    - latent_cache is modified in place, because the returned k_input is a
      view of it (its first kv_lora_rank columns receive the normalized
      values and the remaining columns the rotated ones).
    - q_len and num_local_heads are accepted for signature parity with the
      fused entry point but are not read.

    Returns:
        Tuple (q_input, k_input, v_input) as described above.
    """
    # Step 1: BMM - per-head projection of the non-rotary part of q.
    q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)
    q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)

    # Step 2: RMS norm on the leading kv_lora_rank columns of latent_cache.
    v_input = latent_cache[..., :kv_lora_rank]
    v_input = kv_a_layernorm(v_input.contiguous()).unsqueeze(1)

    # Step 3: k_input aliases latent_cache; this write mutates the caller's tensor.
    k_input = latent_cache.unsqueeze(1)
    k_input[..., :kv_lora_rank] = v_input

    # Step 4: GPT-J style rotary embedding (matching kernel) on the trailing columns.
    k_pe = k_input[..., kv_lora_rank:]
    q_pe_rotated = torch_rotary_emb_gptj_style(q_pe, cos_sin_cache, positions)
    k_pe_rotated = torch_rotary_emb_gptj_style(
        k_pe.squeeze(1), cos_sin_cache, positions
    ).unsqueeze(1)

    # Step 5: store rotated results next to the projected / normalized parts.
    q_input[..., kv_lora_rank:] = q_pe_rotated
    k_input[..., kv_lora_rank:] = k_pe_rotated

    return q_input, k_input, v_input
823 | | - |
# Module-level profiling toggle; nothing in the visible part of this script reads it.
with_profile = False
825 | | - |
826 | | -def show_error(golden, v, tag="DIFF ERROR"): |
827 | | - errors = torch.abs(golden - v) |
828 | | - |
829 | | - errors_max = torch.max(errors) |
830 | | - errors_ave = torch.sum(errors) / errors.numel() |
831 | | - |
832 | | - max_idx_flat = torch.argmax(errors) |
833 | | - max_idx = torch.unravel_index(max_idx_flat, errors.shape) |
834 | | - |
835 | | - golden_val = golden[max_idx] |
836 | | - v_val = v[max_idx] |
837 | | - |
838 | | - print(f"{tag}: error_max={errors_max}, error_ave={errors_ave}, max_error_idx={max_idx}") |
839 | | - print(f"golden[{max_idx}]={golden_val}, v[{max_idx}]={v_val}") |
840 | | - |
def print_profiler_summary(prof, max_key_len=50):
    """Print a per-event table of average CPU and device time from a profiler.

    Args:
        prof: Object exposing key_averages(), whose items provide key,
            count, cpu_time_total and device_time_total.
        max_key_len: Width of the name column; longer names are truncated
            and suffixed with '...'.

    Rows are ordered by average device time, largest first. Events with a
    count of zero are skipped. The final 'Total' row sums the per-event
    averages (not the raw totals).
    """
    def _avg_device_time(evt):
        # Guard against zero-count events so sorting never divides by zero.
        return (evt.device_time_total / evt.count) if evt.count > 0 else 0

    events = sorted(prof.key_averages(), key=_avg_device_time, reverse=True)

    print(f"{'Name':<{max_key_len}} | {'CPU Time Avg (us)':>20} | {'CUDA Time Avg (us)':>20} | {'Count':>10}")

    total_cpu_time = 0.0
    total_cuda_time = 0.0

    for evt in events:
        if evt.count == 0:
            continue

        cpu_time_avg = evt.cpu_time_total / evt.count
        cuda_time_avg = evt.device_time_total / evt.count

        key_str = evt.key
        if len(key_str) > max_key_len:
            # Leave room for the ellipsis so the column width is preserved.
            key_str = key_str[:max_key_len - 3] + '...'

        print(f"{key_str.ljust(max_key_len)} | {cpu_time_avg:20.2f} | {cuda_time_avg:20.2f} | {evt.count:10}")

        total_cpu_time += cpu_time_avg
        total_cuda_time += cuda_time_avg

    print("-" * (max_key_len + 55))
    print(f"{'Total'.ljust(max_key_len)} | {total_cpu_time:20.2f} | {total_cuda_time:20.2f} |")
868 | | - |
869 | | - |
if __name__ == "__main__":
    # Accuracy check: run the PyTorch reference and the fused kernel on the
    # same random inputs and print the element-wise differences.
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default="profile", choices=["profile", "acc"])
    # NOTE(review): --mode is parsed but nothing below reads args.mode.
    args = parser.parse_args()

    # MLA parameters (matching GLM5 configuration)
    num_local_heads = 4
    kv_lora_rank = 512
    qk_nope_head_dim = 192
    qk_rope_head_dim = 64
    hidden_size = 6144  # not referenced below; kept as a record of the config

    # RMS norm with eps=1e-6 to match kernel
    kv_a_layernorm = RMSNorm(kv_lora_rank, eps=1e-6).cuda()
    max_position_embeddings = 4096  # Standard size for testing
    rope_theta = 10000  # Standard base

    # Create standard format cos_sin_cache (matching kernel expectation)
    # Format: [max_position_embeddings, qk_rope_head_dim]
    # with cos in first half and sin in second half
    cos_sin_cache = compute_cos_sin_cache(
        max_position_embeddings,
        qk_rope_head_dim,
        base=rope_theta,
        device=torch.device("cuda"),
        dtype=torch.float32  # Kernel expects float32
    )
    print(f"cos_sin_cache shape: {cos_sin_cache.shape}, dtype: {cos_sin_cache.dtype}")
    print("cos_sin_cache format: cos[:32], sin[32:64] for each position")

    for q_len in [64]:
        print(f"\n\n================ Profiling q_len={q_len} ================")
        q = (torch.rand(q_len, num_local_heads, qk_nope_head_dim+qk_rope_head_dim, dtype=torch.bfloat16).cuda() - 0.5)/10
        w_kc = (torch.rand(num_local_heads, qk_nope_head_dim, kv_lora_rank, dtype=torch.bfloat16).cuda() - 0.5) / 10

        # Build latent_cache as an explicit strided view over flat storage so
        # the reference and the kernel get separate-but-identical buffers.
        shape = (q_len, kv_lora_rank + qk_rope_head_dim)
        strides = (shape[1], 1)  # contiguous row stride (576 for this config)
        storage_size = (shape[0] - 1) * strides[0] + (shape[1] - 1) * strides[1] + 1
        latent_cache_storage = (torch.rand(storage_size, dtype=torch.bfloat16).cuda() - 0.5) / 10
        latent_cache = torch.as_strided(latent_cache_storage, size=shape, stride=strides)
        # The reference path mutates latent_cache in place, so the kernel is
        # given its own copy of the untouched data.
        latent_cache2_storage = latent_cache_storage.clone().detach()
        latent_cache2 = torch.as_strided(latent_cache2_storage, size=shape, stride=strides)

        q_input = torch.zeros(q_len, num_local_heads, kv_lora_rank + qk_rope_head_dim, dtype=torch.bfloat16).cuda()
        k_input = torch.zeros(q_len, 1, kv_lora_rank + qk_rope_head_dim, dtype=torch.bfloat16).cuda()
        v_input = torch.zeros(q_len, 1, kv_lora_rank, dtype=torch.bfloat16).cuda()
        fused_q_input = torch.zeros(q_len, num_local_heads, kv_lora_rank + qk_rope_head_dim, dtype=torch.bfloat16).cuda()
        fused_k_input = torch.zeros(q_len, 1, kv_lora_rank + qk_rope_head_dim, dtype=torch.bfloat16).cuda()
        fused_v_input = torch.zeros(q_len, 1, kv_lora_rank, dtype=torch.bfloat16).cuda()

        positions = torch.arange(0, q_len, dtype=torch.int64).cuda()  # Start from 0

        print("w_kc stride:", w_kc.stride())
        print("latent_cache stride:", latent_cache.stride())
        print("latent_cache2 stride:", latent_cache2.stride())

        # Run PyTorch reference implementation (using standard GPT-J style rotary)
        legacy_q_input, legacy_k_input, legacy_v_input = mla_absorb_rotary_emb(
            kv_a_layernorm, cos_sin_cache,
            q, w_kc, latent_cache, positions,
            q_input, k_input, v_input,
            q_len, num_local_heads, kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim
        )

        # Run CUDA kernel
        fused_forward_absorb(
            q, w_kc, latent_cache2, cos_sin_cache, positions, kv_a_layernorm.weight,
            fused_q_input, fused_k_input, fused_v_input,
            q_len, num_local_heads, kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim
        )

        show_error(legacy_q_input, fused_q_input, "DIFF ERROR OF Q_INPUT")
        show_error(legacy_k_input, fused_k_input, "DIFF ERROR OF K_INPUT")
        show_error(legacy_v_input, fused_v_input, "DIFF ERROR OF V_INPUT")

        # Additional: Check BMM separately, accumulating in float32
        print("\n--- BMM Verification ---")
        q_nope = q[..., :qk_nope_head_dim]
        q_nope_out_torch = torch.bmm(q_nope.transpose(0, 1).float(), w_kc.float())
        q_nope_out_torch = q_nope_out_torch.transpose(0, 1).bfloat16()

        # Compare with CUDA result
        bmm_error = torch.abs(q_nope_out_torch - fused_q_input[..., :kv_lora_rank])
        print(f"BMM max error: {bmm_error.max().item()}, avg error: {bmm_error.mean().item()}")
955 | | ->>>>>>> CHANGE (4a11cf MC3-8615 fused_mla_absorb_rotary_emb support glm5 model) |
0 commit comments