[Issue]: Performance degradation on mamba2's triton kernel #777

Description

@stephen-youn

Problem Description

Compared with Triton 3.1, the mamba2 module in mamba2-2.7b, which runs Triton kernels (mamba_chunk_scan_combined or mamba_split_conv1d_scan_combined), runs slower on the latest Triton 3.3. Specifically, at sequence length 4096, the combined forward and backward pass latency degrades by about 7% on MI300X.

| fwd+bwd latency (ms) | triton3.3 | triton3.1 |
| --- | --- | --- |
| mamba2-2.7b mixer module | 4.03 | 3.73 |
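
The size of the regression can be checked directly from the two latencies in the table. The snippet below is a minimal sketch of that arithmetic; the helper names `latency_increase` and `throughput_loss` are illustrative and not part of Triton or mamba. Measured against the triton3.1 baseline the latency rises by about 8%, while expressed as lost throughput it is about 7.4%, which is in line with the roughly 7% figure reported.

```python
def latency_increase(new_ms: float, old_ms: float) -> float:
    """Fractional latency increase of new_ms relative to the old_ms baseline."""
    return (new_ms - old_ms) / old_ms


def throughput_loss(new_ms: float, old_ms: float) -> float:
    """Fractional throughput drop, taking throughput as 1 / latency."""
    return 1.0 - old_ms / new_ms


# fwd+bwd latencies (ms) for the mamba2-2.7b mixer module, taken from the table.
TRITON_3_3_MS = 4.03
TRITON_3_1_MS = 3.73

print(f"latency increase: {latency_increase(TRITON_3_3_MS, TRITON_3_1_MS):.1%}")
print(f"throughput loss:  {throughput_loss(TRITON_3_3_MS, TRITON_3_1_MS):.1%}")
```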

Operating System

Linux tw028 5.15.0-116-generic #126-Ubuntu SMP Mon Jul 1 10:14:24 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux

CPU

AMD EPYC 9575F 64-Core Processor

GPU

MI300x

ROCm Version

rocm6.3

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response
