When doing LoRA fine-tuning with FSDP, I am seeing much higher memory usage than with transformers v4.49.0. The issue is specific to versions 4.50.0 and above. For example:
On 4 GPUs, I see the following memory usage with transformers==4.49.0:

Memory allocated after setup: 4.03 GB
Peak memory during training step: 5.36 GB

vs. with any higher version, e.g. transformers==4.54.0:

Memory allocated after setup: 4.03 GB
Peak memory during training step: 20.16 GB

The peak memory usage is roughly 4x higher.
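For reference, here is a minimal sketch of how numbers like these can be collected with the standard torch.cuda API (this is not the exact measurement code from my script, just an illustration):

```python
import torch

def log_memory(tag: str) -> None:
    # Current and peak GPU memory on the local device, in GB.
    allocated = torch.cuda.memory_allocated() / 1024**3
    peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"{tag}: allocated={allocated:.2f} GB, peak={peak:.2f} GB")

# Hypothetical placement, mirroring the numbers reported above:
#   torch.cuda.reset_peak_memory_stats()
#   log_memory("Memory allocated after setup")
#   ... run one training step ...
#   log_memory("Peak memory during training step")
```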
With all other library versions held constant, the bug appears only when transformers is upgraded above 4.49.0, which is why I have filed it here rather than in accelerate. Downgrading to transformers==4.49.0 fixes the issue.
The issue report ends here, but I will share some of my findings in case they are helpful:
I was able to reproduce this issue in other Llama-based models, too.
The bug only appears with FSDP + LoRA. Single GPU jobs don't seem to have the bug.
The memory explosion happens during the backward pass, specifically at: accelerator.backward(loss)
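For context, the step is the usual Accelerate training loop; a simplified sketch (not my exact sft.py) of where the spike shows up:

```python
from accelerate import Accelerator

def train_one_epoch(accelerator: Accelerator, model, optimizer, train_dataloader):
    # model, optimizer and train_dataloader are assumed to already be
    # wrapped via accelerator.prepare(...) with the FSDP plugin active.
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        # Peak memory jumps from ~5 GB to ~20 GB inside this call on >=4.50.0.
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```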
Looking at the memory profiling results, it seems like all attention projection weights (q_proj, v_proj) are somehow treated as trainable and memory is reserved for their optimizer states, which leads to the ~4x spike. I am also attaching screenshots from the memory profiling.
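One way to sanity-check this (before FSDP wrapping flattens the parameters) is to walk the PEFT model and count what is actually trainable; a small hypothetical diagnostic along these lines:

```python
def summarize_trainable(model) -> None:
    # Lists every trainable parameter so unexpected entries such as base
    # q_proj / v_proj weights (rather than only lora_A / lora_B) stand out.
    trainable, total = 0, 0
    for name, param in model.named_parameters():
        total += param.numel()
        if param.requires_grad:
            trainable += param.numel()
            print(f"trainable: {name} ({param.numel():,})")
    print(f"trainable params: {trainable:,} / {total:,} "
          f"({100 * trainable / total:.2f}%)")
```

PEFT's model.print_trainable_parameters() gives a similar one-line summary; the expectation is that only the LoRA adapter weights show up as trainable.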
For the FSDP config, I have tried both values of fsdp_cpu_ram_efficient_loading and fsdp_use_orig_params, and ran both with and without setting fsdp_transformer_layer_cls_to_wrap.
System Info
transformers version: 4.54.0

Who can help?
@zach-huggingface @SunMarc
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
sft.py
fsdp.yaml
Run
Expected behavior
Peak memory with FSDP + LoRA on transformers >= 4.50.0 should be comparable to transformers==4.49.0 (about 5.4 GB on 4 GPUs in my setup) instead of spiking to about 20 GB during the backward pass.