diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index bcf5043457..0a60152b8e 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -26,6 +26,7 @@
 from peft.tuners._buffer_dict import BufferDict
 from peft.tuners.tuners_utils import BaseTunerLayer, _get_in_out_features, check_adapters_to_merge
+from peft.utils import ALLOWED_COMPUTE_DTYPES, UPCAST_DTYPES
 from peft.utils.integrations import (
     dequantize_module_weight,
     gather_params_ctx,
@@ -1985,7 +1986,28 @@
     def __init__(self, delta_weight):
         super().__init__()
         self.delta_weight = delta_weight
 
+    @staticmethod
+    def _low_prec_add(x, y):
+        # addition in fp8 is not directly supported, need to use a higher precision
+        orig_dtype = x.dtype
+        upcast_dtype = y.dtype
+        if upcast_dtype not in ALLOWED_COMPUTE_DTYPES:
+            raise RuntimeError(
+                f"There is an attempt to upcast the targeted parameter to {upcast_dtype} "
+                f"but the only supported dtypes are: {ALLOWED_COMPUTE_DTYPES}."
+            )
+
+        # this operation can be quite costly
+        x = x.to(upcast_dtype)
+        z = x + y
+        # clamp to valid range before casting down
+        info = torch.finfo(orig_dtype)
+        z = z.clamp(min=info.min, max=info.max)
+        return z.to(orig_dtype)
+
     def forward(self, W):
+        if any(getattr(torch, dtype_name, None) == W.dtype for dtype_name in UPCAST_DTYPES):
+            return self._low_prec_add(W, self.delta_weight)
         return W + self.delta_weight
 
@@ -2183,7 +2205,11 @@
     def get_delta_weight(self, adapter_name, *args, **kwargs):
         base_layer = self.get_base_layer()
         param = self.get_param()
-        delta_weight = delta_weight.to(param.device, param.dtype)
+        if param.dtype in ALLOWED_COMPUTE_DTYPES:
+            delta_weight = delta_weight.to(param.device, param.dtype)
+        else:
+            # don't cast dW to weight dtype if it is in torch.float8_e4m3fn etc.
+            delta_weight = delta_weight.to(param.device)
         return delta_weight
 
 @contextmanager
diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py
index 1e90e3b328..cfd024f759 100644
--- a/src/peft/tuners/tuners_utils.py
+++ b/src/peft/tuners/tuners_utils.py
@@ -34,7 +34,7 @@
 from transformers.pytorch_utils import Conv1D
 
 from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING
-from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
+from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, UPCAST_DTYPES
 from peft.utils.constants import (
     DUMMY_MODEL_CONFIG,
     DUMMY_TARGET_MODULES,
@@ -2111,7 +2111,7 @@
     """
     A helper method to cast the adapter weights to the correct dtype.
 
-    Currently, this only upcasts float16 and bfloat16 to float32.
+    Currently, this only upcasts float dtypes to float32.
 
     Args:
         adapter_name (`str`):
@@ -2123,6 +2123,11 @@
         return
 
     dtypes_to_convert_to_fp32 = {torch.float16, torch.bfloat16}
+    # Upcast lower precision floats like float8_e4m3fn; defensively only include dtypes that are actually found, as this
+    # could depend on torch version and platform
+    for name in UPCAST_DTYPES:
+        if (torch_dtype := getattr(torch, name, None)) is not None:
+            dtypes_to_convert_to_fp32.add(torch_dtype)
 
     for module in model.modules():
         if not isinstance(module, BaseTunerLayer):
diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py
index 1e96f515e8..da8f360bd7 100644
--- a/src/peft/utils/__init__.py
+++ b/src/peft/utils/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .constants import ALLOWED_COMPUTE_DTYPES, UPCAST_DTYPES
 from .integrations import map_cache_to_layer_device_map
 from .loftq_utils import replace_lora_weights_loftq
 from .other import (
@@ -72,6 +73,7 @@
 __all__ = [
+    "ALLOWED_COMPUTE_DTYPES",
     "CONFIG_NAME",
     "INCLUDE_LINEAR_LAYERS_SHORTHAND",
     "SAFETENSORS_WEIGHTS_NAME",
@@ -100,6 +102,7 @@
     "TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING",
     "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING",
     "TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING",
+    "UPCAST_DTYPES",
     "WEIGHTS_NAME",
     "AuxiliaryTrainingWrapper",
     "ModulesToSaveWrapper",
diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py
index 24f53a9ff0..91368c5520 100644
--- a/src/peft/utils/constants.py
+++ b/src/peft/utils/constants.py
@@ -363,3 +363,7 @@
 # otherwise there is no point in optimizing and there is a small chance of bugs in the optimization algorithm, so no
 # point in taking unnecessary risks. See #2045 for more context.
 MIN_TARGET_MODULES_FOR_OPTIMIZATION = 20
+# dtypes that are allowed to be used for adapter computation
+ALLOWED_COMPUTE_DTYPES = (torch.float16, torch.bfloat16, torch.float32)
+# float dtypes that should be upcast in the adapter for computation
+UPCAST_DTYPES = ("float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz", "float8_e8m0fnu")
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index 46282a60e4..ec9c576949 100644
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -43,6 +43,7 @@
     AutoTokenizer,
     BitsAndBytesConfig,
     DataCollatorForLanguageModeling,
+    FineGrainedFP8Config,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     Trainer,
@@ -5951,3 +5952,111 @@
         assert torch.isfinite(loss), f"Loss is not finite: {loss.item()}"
         loss.backward()
         optimizer.step()
+
+
+@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="Platform does not support torch.float8_e4m3fn")
+@require_torch_gpu
+@pytest.mark.single_gpu_tests
+class TestDtypeFp8:
+    """Tests that float8 models work.
+
+    Note that at this time, these lower dtypes require a GPU, so these tests cannot be added to the standard CPU test
+    suite.
+    """
+
+    @pytest.fixture(scope="class", autouse=True)
+    def setup_cleanup(self):
+        yield
+        clear_device_cache(garbage_collection=True)
+
+    @pytest.fixture
+    def model(self):
+        model_id = "facebook/opt-125m"
+        # only convert q_proj to fp8, otherwise we get nan results
+        modules_not_to_convert = ["embed_tokens", "lm_head", "v_proj", "k_proj", "out_proj", "fc1", "fc2"]
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            device_map=0,
+            quantization_config=FineGrainedFP8Config(
+                modules_to_not_convert=modules_not_to_convert,
+            ),
+        )
+        # sanity check
+        assert model.model.decoder.layers[0].self_attn.q_proj.weight.dtype == torch.float8_e4m3fn
+        return model
+
+    @pytest.mark.parametrize(
+        "config",
+        [
+            LoraConfig(target_modules=["q_proj", "v_proj"]),
+            VeraConfig(target_modules=["q_proj", "v_proj"]),
+            RoadConfig(target_modules=["q_proj", "v_proj"]),
+        ],
+        ids=lambda c: c.__class__.__name__,
+    )
+    def test_target_modules_float8_e4m3fn(self, model, config):
+        # Test should work with all adapters, but only testing a few here to save time and resources.
+        inputs = torch.arange(10).view(1, -1).to(model.device)
+        with torch.inference_mode():
+            output_base = model(inputs)
+        # sanity check
+        assert torch.isfinite(output_base.logits).all()
+
+        model = get_peft_model(model, config)
+        with torch.inference_mode():
+            # check that there are no errors
+            output_peft = model(inputs)
+        # with default init, lora should be a no-op
+        assert torch.allclose(output_peft.logits, output_base.logits)
+
+    @pytest.mark.parametrize(
+        "config",
+        [
+            LoraConfig(target_modules=["q_proj", "v_proj"]),
+            VeraConfig(target_modules=["q_proj", "v_proj"]),
+            RoadConfig(target_modules=["q_proj", "v_proj"]),
+        ],
+        ids=lambda c: c.__class__.__name__,
+    )
+    @pytest.mark.xfail(reason="Merging with float8 not supported (yet)", strict=True)
+    def test_merge_with_float8_e4m3fn(self, model, config):
+        # Test should work with all adapters, but only testing a few here to save time and resources.
+        inputs = torch.arange(10).view(1, -1).to(model.device)
+        with torch.inference_mode():
+            output_base = model(inputs)
+        # sanity check
+        assert torch.isfinite(output_base.logits).all()
+
+        model = get_peft_model(model, config)
+        unloaded = model.merge_and_unload()
+        with torch.inference_mode():
+            # check that there are no errors
+            output_unloaded = unloaded(inputs)
+        # with default init, lora should be a no-op
+        assert torch.allclose(output_unloaded.logits, output_base.logits)
+
+    def test_lora_target_parameters_float8_e4m3fn(self, model):
+        # target_parameters uses a different mechanism (return W + dW) so it gets its own test
+        inputs = torch.arange(10).view(1, -1).to(model.device)
+        with torch.inference_mode():
+            output_base = model(inputs)
+        # sanity check
+        assert torch.isfinite(output_base.logits).all()
+
+        config = LoraConfig(target_modules=["k_proj", "v_proj"], target_parameters=["q_proj.weight"])
+        model = get_peft_model(model, config)
+        with torch.inference_mode():
+            # check that there are no errors
+            output_lora = model(inputs)
+        # with default init, lora should be a no-op
+        assert torch.allclose(output_lora.logits, output_base.logits)
+
+    def test_target_modules_no_autocast_preserves_e4m3fn(self, model):
+        # ensure that users can choose to keep the adapter weights in the same dtype as the original weights by passing
+        # autocast_adapter_dtype=False, even though the resulting model is not usable (no inference or training
+        # possible)
+        config = LoraConfig(target_modules=["q_proj", "v_proj"])
+        model = get_peft_model(model, config, autocast_adapter_dtype=False)
+        q_proj = model.base_model.model.model.decoder.layers[0].self_attn.q_proj
+        assert q_proj.lora_A.default.weight.dtype == torch.float8_e4m3fn
+        assert q_proj.lora_B.default.weight.dtype == torch.float8_e4m3fn