
Can I turn llama_flash_attn_monkey_patch off? #132

@itsjustfons


I see that the training pipeline uses a monkey patch to replace LlamaAttention.forward with a custom forward pass that uses flash_attn. My system, however, does not support flash_attn.

If I turned off the monkey patch, would the regular LlamaAttention.forward still train correctly and produce similar results?

For example:

# Need to call this before importing transformers.
from video_chatgpt.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

# replace_llama_attn_with_flash_attn()  # What if we skip this and train with the default LLaMA attention?

from video_chatgpt.train.train import train

if __name__ == "__main__":
    train()
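One possible alternative to commenting the call out is to apply the patch only when flash_attn is installed. This is a hedged sketch, not part of Video-ChatGPT: `flash_attn_available` and `maybe_patch_attention` are hypothetical helpers, and the module path is taken from the script in this issue. The patch module is imported lazily inside the branch because, if it imports flash_attn at module level, even importing it would fail on a system without flash_attn.

```python
import importlib
import importlib.util


def flash_attn_available() -> bool:
    # find_spec returns None when a top-level package is not installed.
    return importlib.util.find_spec("flash_attn") is not None


def maybe_patch_attention(use_flash_attn: bool = True) -> bool:
    """Hypothetical helper: apply the flash_attn patch only if requested and available.

    Returns True if the patch was applied, False if the default attention is kept.
    """
    if not (use_flash_attn and flash_attn_available()):
        # Leave LlamaAttention.forward untouched (default attention path).
        return False
    # Lazy import: the patch module is only loaded when flash_attn can be found.
    patch = importlib.import_module(
        "video_chatgpt.train.llama_flash_attn_monkey_patch"
    )
    patch.replace_llama_attn_with_flash_attn()
    return True
```

In the training script, `maybe_patch_attention()` would take the place of the unconditional `replace_llama_attn_with_flash_attn()` call, still placed before transformers is imported; passing `use_flash_attn=False` forces the default attention path regardless of what is installed.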
