Can I turn llama_flash_attn_monkey_patch off? #132

@itsjustfons

I see that the training pipeline here uses a monkey patch to replace LlamaAttention.forward with a custom forward pass that uses flash_attn. My system, however, does not support flash_attn.

If I turned off the monkey patch, would the regular LlamaAttention.forward be able to run training correctly and produce similar results?

e.g.

# Need to call this before importing transformers.
from video_chatgpt.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

# replace_llama_attn_with_flash_attn()  # What if we just turned this off and trained with the default attn function from LLaMA?

from video_chatgpt.train.train import train

if __name__ == "__main__":
    train()
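One way to avoid commenting the call out by hand is to guard the patch on whether flash_attn is importable. This is a hedged sketch, not code from the repo: `has_module` and `maybe_patch_flash_attn` are hypothetical helpers I'm introducing, and the video_chatgpt import path just mirrors the snippet above.

```python
import importlib.util


def has_module(name: str) -> bool:
    # True if `name` is importable in this environment.
    return importlib.util.find_spec(name) is not None


def maybe_patch_flash_attn() -> bool:
    # Apply the flash-attention monkey patch only when flash_attn is
    # actually installed; otherwise leave the stock
    # LlamaAttention.forward in place. Returns whether the patch ran.
    # Note: like the original snippet, this must run before
    # transformers is imported.
    if not has_module("flash_attn"):
        return False
    from video_chatgpt.train.llama_flash_attn_monkey_patch import (
        replace_llama_attn_with_flash_attn,
    )
    replace_llama_attn_with_flash_attn()
    return True
```

As for correctness: FlashAttention computes exact (not approximate) attention, so disabling the patch should change speed and memory use rather than the math; results should match up to minor floating-point differences, though training will likely be slower and may need a smaller batch size to fit in memory.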
