Description
Hi all,
I have noticed that the attention function of my GPT-2 (`_attn` in modeling_gpt2.py) receives float32 inputs despite using autocast to bf16 in the context manager. Everything works as expected when LoRA/PEFT is turned off.
Library versions:
peft: '0.4.0'
transformers '4.28.1'
torch '2.0.1+cu117'
Am I doing something wrong?
Cheers,
Martin
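One hypothesis for how a float32 tensor can slip through autocast (not verified against this exact setup): autocast only casts the inputs of selected ops such as matmul and linear; elementwise additions follow ordinary type promotion instead. So if an adapter branch produces float32 and is added to a bfloat16 activation, the sum is promoted to float32. A minimal torch-only sketch of that promotion rule:

```python
import torch

# Elementwise add is not on autocast's cast list, so normal type
# promotion applies: bfloat16 + float32 -> float32.
base = torch.zeros(4, dtype=torch.bfloat16)   # stands in for a base-model activation
delta = torch.zeros(4, dtype=torch.float32)   # stands in for an fp32 adapter-branch output

out = base + delta
print(out.dtype)  # torch.float32 — the fp32 operand wins the promotion
```

Everything downstream of such an add then runs in float32, which would match the observed dtype inside `_attn`.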
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder
- My own task or dataset (give details below)
Reproduction
Minimum working code:

```python
import torch
from peft import get_peft_model, LoraConfig
from transformers import AutoModel

if __name__ == "__main__":
    model = AutoModel.from_pretrained("gpt2").cuda().bfloat16()
    lora_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, bias="lora_only", task_type="lm")
    model = get_peft_model(model, lora_cfg)

    # prepare dummy inputs
    input_ids, labels = torch.randint(0, 1000, (2, 10)).cuda(), torch.randint(0, 1000, (10,)).cuda()

    # forward pass
    with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
        outputs = model(input_ids)
        loss, logits = outputs[:2]
```

Check the dtype when entering the `_attn` function:
```python
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    ...
```
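A way to check this without editing modeling_gpt2.py is a forward pre-hook that records the dtype of a module's input (a sketch; on the real model you would hook e.g. `model.h[0].attn` rather than the toy layers used here):

```python
import torch
import torch.nn as nn

seen = []

def log_input_dtype(module, inputs):
    # Forward pre-hook: record the dtype of the tensor entering this module.
    seen.append(inputs[0].dtype)

# Two stacked layers: the second layer's input is the first layer's output,
# so its recorded dtype reflects what autocast actually produced upstream.
net = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
net[1].register_forward_pre_hook(log_input_dtype)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    net(torch.randn(2, 8))

print(seen)  # [torch.bfloat16] under CPU bf16 autocast
```

With PEFT enabled on the real model, the same kind of hook on the attention module should show the float32 leak described above.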
```
>>> key.dtype  # with peft
torch.float32
>>> key.dtype  # without peft
torch.bfloat16
```

Expected behavior
autocast works: `_attn` receives bfloat16 tensors with PEFT enabled, just as it does without PEFT.
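A workaround reported for similar dtype mismatches (not verified against this exact setup) is to cast the injected LoRA parameters to the compute dtype after wrapping, so there is no fp32 branch left to promote. A sketch, using a hypothetical stand-in module instead of the real PEFT model:

```python
import torch
import torch.nn as nn

def cast_lora_to(model: nn.Module, dtype: torch.dtype) -> None:
    # Cast only the parameters injected by LoRA (their names contain
    # "lora_"), leaving the base weights untouched.
    for name, param in model.named_parameters():
        if "lora_" in name:
            param.data = param.data.to(dtype)

# Hypothetical stand-in for a PEFT-wrapped layer:
# bf16 base weight, fp32 LoRA factors.
class ToyLoraLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(8, 8).bfloat16()
        self.lora_A = nn.Linear(8, 4, bias=False)   # float32 by default
        self.lora_B = nn.Linear(4, 8, bias=False)

layer = ToyLoraLinear()
cast_lora_to(layer, torch.bfloat16)
print(layer.lora_A.weight.dtype)  # torch.bfloat16
```

On the real model the same loop would run over the PEFT-wrapped `model` after `get_peft_model`.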