Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from peft.tuners._buffer_dict import BufferDict
from peft.tuners.tuners_utils import BaseTunerLayer, _get_in_out_features, check_adapters_to_merge
from peft.utils import ALLOWED_COMPUTE_DTYPES, UPCAST_DTYPES
from peft.utils.integrations import (
dequantize_module_weight,
gather_params_ctx,
Expand Down Expand Up @@ -1985,7 +1986,28 @@ def __init__(self, delta_weight):
super().__init__()
self.delta_weight = delta_weight

@staticmethod
def _low_prec_add(x, y):
    """Add ``y`` to ``x`` when ``x`` uses a low precision float dtype (e.g. float8).

    Addition is not directly supported for float8 dtypes, so ``x`` is upcast to the dtype of ``y``, the sum is
    computed, clamped to the representable range of ``x``'s original dtype, and cast back down.

    Args:
        x (`torch.Tensor`):
            Tensor in a low precision dtype (e.g. ``torch.float8_e4m3fn``).
        y (`torch.Tensor`):
            Tensor in one of ``ALLOWED_COMPUTE_DTYPES``; ``x`` is upcast to this dtype for the addition.

    Returns:
        `torch.Tensor`: ``x + y``, cast back to the original dtype of ``x``.

    Raises:
        RuntimeError: If the dtype of ``y`` is not one of ``ALLOWED_COMPUTE_DTYPES``.
    """
    orig_dtype = x.dtype
    upcast_dtype = y.dtype
    if upcast_dtype not in ALLOWED_COMPUTE_DTYPES:
        raise RuntimeError(
            f"There is an attempt to upcast the targeted parameter to {upcast_dtype} "
            f"but the only supported dtypes are: {ALLOWED_COMPUTE_DTYPES}."
        )

    # this upcast-and-add can be quite costly
    x = x.to(upcast_dtype)
    z = x + y
    # clamp to the valid range of the original dtype before casting down, so out-of-range sums saturate instead of
    # overflowing
    info = torch.finfo(orig_dtype)
    z = z.clamp(min=info.min, max=info.max)
    return z.to(orig_dtype)

def forward(self, W):
    """Return the base weight with the adapter delta added, handling low precision float dtypes."""
    # Check the weight's dtype against each candidate low precision dtype by name. getattr yields None for dtypes
    # absent on this torch version/platform, which never compares equal to a real dtype.
    needs_upcast = False
    for dtype_name in UPCAST_DTYPES:
        if getattr(torch, dtype_name, None) == W.dtype:
            needs_upcast = True
            break
    if needs_upcast:
        return self._low_prec_add(W, self.delta_weight)
    return W + self.delta_weight


Expand Down Expand Up @@ -2183,7 +2205,11 @@ def get_delta_weight(self, adapter_name, *args, **kwargs):

base_layer = self.get_base_layer()
param = self.get_param()
delta_weight = delta_weight.to(param.device, param.dtype)
if param.dtype in ALLOWED_COMPUTE_DTYPES:
delta_weight = delta_weight.to(param.device, param.dtype)
else:
# don't cast dW to weight dtype if it is in torch.float8_e4m3fn etc.
delta_weight = delta_weight.to(param.device)
return delta_weight

@contextmanager
Expand Down
9 changes: 7 additions & 2 deletions src/peft/tuners/tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from transformers.pytorch_utils import Conv1D

from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING
from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, UPCAST_DTYPES
from peft.utils.constants import (
DUMMY_MODEL_CONFIG,
DUMMY_TARGET_MODULES,
Expand Down Expand Up @@ -2111,7 +2111,7 @@ def cast_adapter_dtype(model: nn.Module, adapter_name: str, autocast_adapter_dty
"""
A helper method to cast the adapter weights to the correct dtype.

Currently, this only upcasts float16 and bfloat16 to float32.
Currently, this only upcasts float dtypes to float32.

Args:
adapter_name (`str`):
Expand All @@ -2123,6 +2123,11 @@ def cast_adapter_dtype(model: nn.Module, adapter_name: str, autocast_adapter_dty
return

dtypes_to_convert_to_fp32 = {torch.float16, torch.bfloat16}
# Upcast lower precision floats like float8_e4m3fn. NOTE(review): availability of these dtypes depends on torch
# version and platform, but the getattr() below has no default, so a missing name raises AttributeError instead of
# being skipped — to be truly defensive, use getattr(torch, name, None) and skip None results.
for name in UPCAST_DTYPES:
torch_dtype = getattr(torch, name)
dtypes_to_convert_to_fp32.add(torch_dtype)

for module in model.modules():
if not isinstance(module, BaseTunerLayer):
Expand Down
3 changes: 3 additions & 0 deletions src/peft/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .constants import ALLOWED_COMPUTE_DTYPES, UPCAST_DTYPES
from .integrations import map_cache_to_layer_device_map
from .loftq_utils import replace_lora_weights_loftq
from .other import (
Expand Down Expand Up @@ -72,6 +73,7 @@


__all__ = [
"ALLOWED_COMPUTE_DTYPES",
"CONFIG_NAME",
"INCLUDE_LINEAR_LAYERS_SHORTHAND",
"SAFETENSORS_WEIGHTS_NAME",
Expand Down Expand Up @@ -100,6 +102,7 @@
"TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING",
"UPCAST_DTYPES",
"WEIGHTS_NAME",
"AuxiliaryTrainingWrapper",
"ModulesToSaveWrapper",
Expand Down
4 changes: 4 additions & 0 deletions src/peft/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
# otherwise there is no point in optimizing and there is a small chance of bugs in the optimization algorithm, so no
# point in taking unnecessary risks. See #2045 for more context.
MIN_TARGET_MODULES_FOR_OPTIMIZATION = 20
# Dtypes in which adapter computation is allowed to take place.
ALLOWED_COMPUTE_DTYPES = (
    torch.float16,
    torch.bfloat16,
    torch.float32,
)
# Names of low precision float dtypes whose weights should be upcast in the adapter for computation. Kept as strings
# rather than dtype objects because their availability can vary with torch version and platform.
UPCAST_DTYPES = (
    "float8_e4m3fn",
    "float8_e4m3fnuz",
    "float8_e5m2",
    "float8_e5m2fnuz",
    "float8_e8m0fnu",
)
109 changes: 109 additions & 0 deletions tests/test_gpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
AutoTokenizer,
BitsAndBytesConfig,
DataCollatorForLanguageModeling,
FineGrainedFP8Config,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
Trainer,
Expand Down Expand Up @@ -5951,3 +5952,111 @@ def test_te_lora_backward(self, model_with_te_layers, tokenized_inputs):
assert torch.isfinite(loss), f"Loss is not finite: {loss.item()}"
loss.backward()
optimizer.step()


@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="Platform does not support torch.float8_e4m3fn")
@require_torch_gpu
@pytest.mark.single_gpu_tests
class TestDtypeFp8:
    """Tests that float8 models work.

    Note that at this time, these lower dtypes require a GPU, so these tests cannot be added to the standard CPU test
    suite.
    """

    @pytest.fixture(scope="class", autouse=True)
    def setup_cleanup(self):
        # free GPU memory once after all tests in this class ran
        yield
        clear_device_cache(garbage_collection=True)

    @pytest.fixture
    def model(self):
        model_id = "facebook/opt-125m"
        # only convert q_proj to fp8, otherwise we get nan results
        modules_not_to_convert = ["embed_tokens", "lm_head", "v_proj", "k_proj", "out_proj", "fc1", "fc2"]
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=0,
            quantization_config=FineGrainedFP8Config(
                modules_to_not_convert=modules_not_to_convert,
            ),
        )
        # sanity check
        assert model.model.decoder.layers[0].self_attn.q_proj.weight.dtype == torch.float8_e4m3fn
        return model

    @pytest.mark.parametrize(
        "config",
        [
            LoraConfig(target_modules=["q_proj", "v_proj"]),
            VeraConfig(target_modules=["q_proj", "v_proj"]),
            RoadConfig(target_modules=["q_proj", "v_proj"]),
        ],
        ids=lambda c: c.__class__.__name__,
    )
    def test_target_modules_float8_e4m3fn(self, model, config):
        # Test should work with all adapters, but only testing a few here to save time and resources.
        inputs = torch.arange(10).view(1, -1).to(model.device)
        with torch.inference_mode():
            output_base = model(inputs)
        # sanity check
        assert torch.isfinite(output_base.logits).all()

        model = get_peft_model(model, config)
        with torch.inference_mode():
            # check that there are no errors
            output_peft = model(inputs)
        # with default init, lora should be a no-op
        assert torch.allclose(output_peft.logits, output_base.logits)

    @pytest.mark.parametrize(
        "config",
        [
            LoraConfig(target_modules=["q_proj", "v_proj"]),
            VeraConfig(target_modules=["q_proj", "v_proj"]),
            RoadConfig(target_modules=["q_proj", "v_proj"]),
        ],
        ids=lambda c: c.__class__.__name__,
    )
    @pytest.mark.xfail(reason="Merging with float8 not supported (yet)", strict=True)
    def test_merge_with_float8_e4m3fn(self, model, config):
        # Test should work with all adapters, but only testing a few here to save time and resources.
        inputs = torch.arange(10).view(1, -1).to(model.device)
        with torch.inference_mode():
            output_base = model(inputs)
        # sanity check
        assert torch.isfinite(output_base.logits).all()

        model = get_peft_model(model, config)
        unloaded = model.merge_and_unload()
        with torch.inference_mode():
            # check that there are no errors; run the merged-and-unloaded model (not the PEFT-wrapped one, which was
            # previously used here by mistake, leaving `unloaded` unused)
            output_unloaded = unloaded(inputs)
        # with default init, lora should be a no-op
        assert torch.allclose(output_unloaded.logits, output_base.logits)

    def test_lora_target_parameters_float8_e4m3fn(self, model):
        # target_modules uses a different mechanism (return W + dW) so it gets its own test
        inputs = torch.arange(10).view(1, -1).to(model.device)
        with torch.inference_mode():
            output_base = model(inputs)
        # sanity check
        assert torch.isfinite(output_base.logits).all()

        config = LoraConfig(target_modules=["k_proj", "v_proj"], target_parameters=["q_proj.weight"])
        model = get_peft_model(model, config)
        with torch.inference_mode():
            # check that there are no errors
            output_lora = model(inputs)
        # with default init, lora should be a no-op
        assert torch.allclose(output_lora.logits, output_base.logits)

    def test_target_modules_no_autocast_preserves_e4m3fn(self, model):
        # ensure that users can choose to keep the adapter weights in the same dtype as the original weights by passing
        # autocast_adapter_dtype=False, even though the resulting model is not usable (no inference or training
        # possible)
        config = LoraConfig(target_modules=["q_proj", "v_proj"])
        model = get_peft_model(model, config, autocast_adapter_dtype=False)
        q_proj = model.base_model.model.model.decoder.layers[0].self_attn.q_proj
        assert q_proj.lora_A.default.weight.dtype == torch.float8_e4m3fn
        assert q_proj.lora_B.default.weight.dtype == torch.float8_e4m3fn
Loading