diff --git a/CHANGELOG.md b/CHANGELOG.md
index d680618e8..e9752a733 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
 - Added ability to try loading latest checkpoint from save folder using `--try_load_latest_save`.
+- Added support for flash attention and gradient checkpointing to `hf_olmo`.
 
 ## [v0.5.0](https://github.com/allenai/OLMo/releases/tag/v0.5.0) - 2024-08-26
diff --git a/hf_olmo/modeling_olmo.py b/hf_olmo/modeling_olmo.py
index c24b9f2e3..12dca3f08 100644
--- a/hf_olmo/modeling_olmo.py
+++ b/hf_olmo/modeling_olmo.py
@@ -1,6 +1,6 @@
 import logging
 from dataclasses import fields
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 from transformers import PreTrainedModel
@@ -8,7 +8,7 @@
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.auto import AutoModelForCausalLM
 
-from olmo.config import ModelConfig
+from olmo.config import ActivationCheckpointingStrategy, ModelConfig
 from olmo.model import OLMo
 
 from .configuration_olmo import OLMoConfig
@@ -26,6 +26,15 @@ def create_model_config_from_pretrained_config(config: OLMoConfig):
             kwargs[field.name] = getattr(config, field.name)
 
     model_config = ModelConfig(**kwargs)
+
+    # Handle flash attention settings
+    if config._attn_implementation == "flash_attention_2":
+        model_config.flash_attention = True
+    elif config._attn_implementation in ("eager", "sdpa"):
+        model_config.flash_attention = False
+    else:
+        raise ValueError(f"Unexpected _attn_implementation {config._attn_implementation}")
+
     return model_config
 
 
@@ -37,10 +46,16 @@ class OLMoForCausalLM(PreTrainedModel):
     config_class = OLMoConfig
     base_model_prefix = "model"
     _no_split_modules = ["OLMoBlock"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    supports_gradient_checkpointing = True
 
     def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
         super().__init__(config)
 
+        self._gradient_checkpointing_func: Optional[Callable] = None
+        self._gradient_checkpointing = False
+
         if not model:
             model_config = create_model_config_from_pretrained_config(config)
             # Initialize model (always on CPU to start with so we don't run out of GPU memory).
@@ -49,6 +64,23 @@ def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params
         else:
             self.model = model
 
+    @property
+    def gradient_checkpointing(self) -> bool:
+        return self._gradient_checkpointing
+
+    @gradient_checkpointing.setter
+    def gradient_checkpointing(self, enabled: bool):
+        if self._gradient_checkpointing == enabled:
+            return
+
+        # HF does not specify a way to pass checkpointing strategies, so we pick
+        # whole layer as our strategy. We can make this configurable later if needed.
+        checkpointing_strategy = ActivationCheckpointingStrategy.whole_layer if enabled else None
+        self.model.set_activation_checkpointing(
+            checkpointing_strategy, checkpoint_func=self._gradient_checkpointing_func
+        )
+        self._gradient_checkpointing = enabled
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
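A minimal sketch of how the attention mapping added above is exercised from the user side (illustrative only, not part of the patch; `"path/to/olmo-checkpoint"` is a placeholder path):

```python
from transformers import AutoModelForCausalLM

import hf_olmo  # noqa: F401  # registers OLMoForCausalLM with the HF auto classes

# "flash_attention_2" maps to model_config.flash_attention = True in
# create_model_config_from_pretrained_config; "eager" and "sdpa" map to
# False, and any other value now raises a ValueError.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/olmo-checkpoint", attn_implementation="flash_attention_2"
)
```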
diff --git a/olmo/model.py b/olmo/model.py
index 7f6e56aa1..ab7ed9f0d 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -418,7 +418,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
         self.__cache = cache
         assert config.d_model % config.n_heads == 0
 
-        self._activation_checkpoint_fn = None
+        self._activation_checkpoint_fn: Optional[Callable] = None
 
         # Dropout.
         self.dropout = Dropout(config.residual_dropout)
@@ -500,9 +500,11 @@ def reset_parameters(self):
         init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor)
         init_normal(self.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor)
 
-    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+    def set_activation_checkpointing(
+        self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
+    ):
         if strategy == ActivationCheckpointingStrategy.fine_grained:
-            self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
+            self._activation_checkpoint_fn = checkpoint_func or activation_checkpoint_function(self.config)
         else:
             self._activation_checkpoint_fn = None
 
@@ -980,7 +982,7 @@ class OLMoOutput(NamedTuple):
     Attention keys and values from each block.
     """
 
-    hidden_states: Optional[Tuple[torch.Tensor]]
+    hidden_states: Optional[Tuple[torch.Tensor, ...]]
     """
     Hidden states from each block.
     """
@@ -1050,10 +1052,12 @@ def reset_parameters(self):
         for block in self:
             block.reset_parameters()
 
-    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+    def set_activation_checkpointing(
+        self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
+    ):
         self.activation_checkpointing_strategy = strategy
         for block in self:
-            block.set_activation_checkpointing(strategy)
+            block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
 
 
 class OLMo(nn.Module):
@@ -1140,14 +1144,16 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
             get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
             self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
 
-    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+    def set_activation_checkpointing(
+        self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
+    ):
         self.activation_checkpointing_strategy = strategy
         if self.config.block_group_size != 1:
             for block_group in self.transformer.block_groups:
-                block_group.set_activation_checkpointing(strategy)
+                block_group.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
         else:
             for block in self.transformer.blocks:
-                block.set_activation_checkpointing(strategy)
+                block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
 
     @property
     def device(self) -> torch.device:
@@ -1445,7 +1451,11 @@ def forward(
         if self.config.scale_logits:
             logits.mul_(1 / math.sqrt(self.config.d_model))
 
-        return OLMoOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None)  # type: ignore[arg-type]
+        return OLMoOutput(
+            logits=logits,
+            attn_key_values=attn_key_values,
+            hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
+        )
 
     def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
         if wrap_strategy is None:
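To make the new `checkpoint_func` parameter concrete: a hedged sketch of driving it directly on an `olmo.model.OLMo` instance (here `model` is assumed to already exist; this mirrors what `OLMoForCausalLM` now does with HF's `_gradient_checkpointing_func`):

```python
from functools import partial

import torch.utils.checkpoint

from olmo.config import ActivationCheckpointingStrategy

# `model` is assumed to be an already-constructed olmo.model.OLMo instance.
# Any callable with the interface of torch.utils.checkpoint.checkpoint works;
# it is forwarded through the block groups down to each OLMoBlock.
checkpoint_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
model.set_activation_checkpointing(
    ActivationCheckpointingStrategy.whole_layer, checkpoint_func=checkpoint_func
)

# Passing None as the strategy turns checkpointing back off.
model.set_activation_checkpointing(None)
```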
diff --git a/test_fixtures/test-olmo-model/config.json b/test_fixtures/test-olmo-model/config.json
index 136113148..f5ed8cc8b 100644
--- a/test_fixtures/test-olmo-model/config.json
+++ b/test_fixtures/test-olmo-model/config.json
@@ -13,7 +13,9 @@
   "block_type": "sequential",
   "clip_qkv": null,
   "d_model": 32,
+  "emb_init_std": null,
   "embedding_dropout": 0.1,
+  "embedding_layer_norm": false,
   "embedding_size": 50304,
   "eos_token_id": 50256,
   "flash_attention": false,
@@ -22,6 +24,7 @@
   "init_device": null,
   "init_fn": "normal",
   "init_std": 0.02,
+  "layer_norm_eps": 1e-05,
   "layer_norm_type": "default",
   "layer_norm_with_affine": true,
   "max_sequence_length": 1024,
@@ -32,13 +35,16 @@
   "n_heads": 1,
   "n_kv_heads": null,
   "n_layers": 1,
+  "norm_after": false,
   "pad_token_id": 50256,
   "precision": null,
   "residual_dropout": 0.1,
   "rope": false,
   "rope_full_precision": true,
+  "rope_theta": 10000,
+  "scale_emb_init": false,
   "scale_logits": false,
-  "transformers_version": "4.40.2",
+  "transformers_version": "4.44.2",
   "use_cache": true,
   "vocab_size": 50257,
   "weight_tying": true
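The tests below only compare forward logits; gradient checkpointing matters under backward, where activations are recomputed. A hedged sketch of exercising that path (assumes the same `model_path` fixture the tests use; not part of the patch):

```python
import torch
from transformers import AutoModelForCausalLM

import hf_olmo  # noqa: F401

model = AutoModelForCausalLM.from_pretrained(model_path)
model.train()
model.gradient_checkpointing_enable()  # flows through the new property setter

input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
logits = model(input_ids).logits
# Backward recomputes each checkpointed block's activations on the fly.
logits.sum().backward()
```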
diff --git a/tests/hf_olmo/modeling_olmo_test.py b/tests/hf_olmo/modeling_olmo_test.py
index e4bb02f54..f51ca26f1 100644
--- a/tests/hf_olmo/modeling_olmo_test.py
+++ b/tests/hf_olmo/modeling_olmo_test.py
@@ -21,7 +21,86 @@ def test_olmo_model(model_path: str):
     output = model(input_tensor)
     hf_output = hf_model(input_tensor)
 
-    torch.testing.assert_allclose(output.logits, hf_output.logits)
+    torch.testing.assert_close(hf_output.logits, output.logits)
+
+
+@pytest.mark.gpu
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA devices")
+def test_flash_attention_2(model_path: str):
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    import hf_olmo  # noqa: F401
+
+    hf_model = AutoModelForCausalLM.from_pretrained(model_path)
+    hf_model_flash_attn = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="flash_attention_2")
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    encoded_input = tokenizer.encode("My name is OLMo!")
+    input_tensor = torch.tensor(encoded_input).unsqueeze(0)
+
+    hf_output = hf_model(input_tensor)
+    hf_output_flash_attn = hf_model_flash_attn(input_tensor)
+
+    torch.testing.assert_close(hf_output_flash_attn.logits, hf_output.logits)
+
+
+def test_sdpa(model_path: str):
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    import hf_olmo  # noqa: F401
+
+    hf_model = AutoModelForCausalLM.from_pretrained(model_path)
+    hf_model_sdpa = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="sdpa")
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    encoded_input = tokenizer.encode("My name is OLMo!")
+    input_tensor = torch.tensor(encoded_input).unsqueeze(0)
+
+    hf_output = hf_model(input_tensor)
+    hf_output_sdpa = hf_model_sdpa(input_tensor)
+
+    torch.testing.assert_close(hf_output_sdpa.logits, hf_output.logits)
+
+
+def test_gradient_checkpointing(model_path: str):
+    from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
+
+    import hf_olmo  # noqa: F401
+
+    hf_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    encoded_input = tokenizer.encode("My name is OLMo!")
+    input_tensor = torch.tensor(encoded_input).unsqueeze(0)
+
+    hf_output_no_checkpointing = hf_model(input_tensor)
+
+    hf_model.gradient_checkpointing_enable()
+
+    hf_output_checkpointing = hf_model(input_tensor)
+
+    torch.testing.assert_close(hf_output_checkpointing.logits, hf_output_no_checkpointing.logits)
+
+
+def test_gradient_checkpointing_disable(model_path: str):
+    from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
+
+    import hf_olmo  # noqa: F401
+
+    hf_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    encoded_input = tokenizer.encode("My name is OLMo!")
+    input_tensor = torch.tensor(encoded_input).unsqueeze(0)
+
+    hf_output = hf_model(input_tensor)
+
+    hf_model.gradient_checkpointing_enable()
+    hf_model.gradient_checkpointing_disable()
+
+    hf_output_after_disable = hf_model(input_tensor)
+
+    torch.testing.assert_close(hf_output_after_disable.logits, hf_output.logits)
 
 
 def test_save_pretrained(model_path: str):
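Finally, a sketch combining both features (illustrative only; the checkpoint path is a placeholder, and flash attention additionally requires a CUDA device and half-precision weights, which the CPU-sized test fixture above does not cover):

```python
import torch
from transformers import AutoModelForCausalLM

import hf_olmo  # noqa: F401

# Assumes a real OLMo checkpoint and an available GPU.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/olmo-checkpoint",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda")
model.gradient_checkpointing_enable()  # whole-layer activation checkpointing
```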