I see that the training pipeline of this project uses a monkey patch to replace `LLamaAttention.forward` with a custom forward pass that uses `flash_attn`. My system, however, does not support `flash_attn`.

If I turned off the monkey patch, would the regular `LLamaAttention.forward` be able to run training correctly and produce similar results?
e.g.:

```python
# Need to call this before importing transformers.
from video_chatgpt.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
# replace_llama_attn_with_flash_attn()  # What if we just turned this off and trained with the default attn function from LLaMA?

from video_chatgpt.train.train import train

if __name__ == "__main__":
    train()
```
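One way to avoid editing the script by hand is to apply the patch only when `flash_attn` is actually importable, and otherwise fall back to the default attention. The sketch below is a hypothetical adaptation of the snippet above: the availability check is standard Python, while the `video_chatgpt` import path is taken from the issue and assumed unchanged.

```python
# Hypothetical sketch: conditionally apply the flash-attn monkey patch.
# The flash_attn availability check uses only the standard library; the
# video_chatgpt import mirrors the snippet in this issue and is an
# assumption about the repo layout.
import importlib.util


def flash_attn_available() -> bool:
    """Return True if the flash_attn package can be imported on this system."""
    return importlib.util.find_spec("flash_attn") is not None


if flash_attn_available():
    # The patch must be applied before transformers is imported.
    from video_chatgpt.train.llama_flash_attn_monkey_patch import (
        replace_llama_attn_with_flash_attn,
    )
    replace_llama_attn_with_flash_attn()
# else: fall back to the default LlamaAttention.forward from transformers.
```

Whether the unpatched path reproduces the reported results still depends on the repo's training code tolerating the default (slower, more memory-hungry) attention implementation; numerically the two should be close, since FlashAttention computes the same attention output.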