change dtype of output after passing through lora_A #1172

@huseyinatahaninan

Description

System Info

peft 0.6.2

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

I am not sure if this is really a bug, but my question is: after passing the input x through lora_A, should we cast the result to lora_B.weight.dtype, just as we cast x to lora_A.weight.dtype in the first place?

I am talking about this line:

output = lora_B(lora_A(dropout(x)))

Instead of output = lora_B(lora_A(dropout(x))), I was wondering whether the following should be done: output = lora_B(lora_A(dropout(x)).to(lora_B.weight.dtype)). Otherwise, in mixed-precision training for instance, x starts out as fp32, but after passing through lora_A it becomes bf16 before being fed into lora_B. So I was wondering whether we should cast it back to fp32.
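A minimal sketch of the dtype flow described above, assuming bf16 LoRA weights and an fp32 input (the layer sizes and dropout rate here are hypothetical, not taken from the PEFT source):

```python
import torch
import torch.nn as nn

# Hypothetical LoRA adapter pair stored in bf16, as in mixed-precision setups.
lora_A = nn.Linear(16, 4, bias=False).to(torch.bfloat16)
lora_B = nn.Linear(4, 16, bias=False).to(torch.bfloat16)
dropout = nn.Dropout(p=0.0)

x = torch.randn(2, 16)            # input arrives in fp32
x = x.to(lora_A.weight.dtype)     # the existing cast to lora_A's dtype

h = lora_A(dropout(x))            # intermediate is now bf16, not fp32
print(h.dtype)                    # torch.bfloat16

# Proposed variant: cast the intermediate to lora_B's dtype explicitly
# before the second matmul (a no-op when both adapters share a dtype).
out = lora_B(h.to(lora_B.weight.dtype))
print(out.dtype)
```

When lora_A and lora_B share a dtype the extra .to() call changes nothing, so the question only matters when the two adapter weights (or the autocast context) disagree on precision.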

Thanks very much for your help in advance!

Expected behavior

N/A
