1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added ability to try loading latest checkpoint from save folder using `--try_load_latest_save`.
- Added support for flash attention and gradient checkpointing to `hf_olmo`.

## [v0.5.0](https://github.com/allenai/OLMo/releases/tag/v0.5.0) - 2024-08-26

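The new CHANGELOG entry covers two user-facing features that go through standard Hugging Face entry points. A minimal usage sketch, closely following the tests added later in this PR (the checkpoint path and prompt are placeholders; flash attention additionally requires a CUDA device and the flash-attn package):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import hf_olmo  # noqa: F401  # registers OLMoConfig / OLMoForCausalLM with transformers

model_path = "test_fixtures/test-olmo-model"  # placeholder: any OLMo checkpoint converted for hf_olmo

# Flash attention is requested through the usual HF argument; the wrapper maps it
# onto the core ModelConfig.flash_attention flag.
model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="flash_attention_2")

# Gradient checkpointing uses the standard HF toggle, which drives the new
# OLMoForCausalLM.gradient_checkpointing property.
model.gradient_checkpointing_enable()

tokenizer = AutoTokenizer.from_pretrained(model_path)
input_ids = torch.tensor(tokenizer.encode("My name is OLMo!")).unsqueeze(0)
logits = model(input_ids).logits
```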
36 changes: 34 additions & 2 deletions hf_olmo/modeling_olmo.py
@@ -1,14 +1,14 @@
import logging
from dataclasses import fields
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

-from olmo.config import ModelConfig
+from olmo.config import ActivationCheckpointingStrategy, ModelConfig
from olmo.model import OLMo

from .configuration_olmo import OLMoConfig
@@ -26,6 +26,15 @@ def create_model_config_from_pretrained_config(config: OLMoConfig):
kwargs[field.name] = getattr(config, field.name)

model_config = ModelConfig(**kwargs)

# Handle flash attention settings
if config._attn_implementation == "flash_attention_2":
model_config.flash_attention = True
elif config._attn_implementation in ("eager", "sdpa"):
model_config.flash_attention = False
else:
raise ValueError(f"Unexpected _attn_implementation {config._attn_implementation}")

return model_config


@@ -37,10 +46,16 @@ class OLMoForCausalLM(PreTrainedModel):
config_class = OLMoConfig
base_model_prefix = "model"
_no_split_modules = ["OLMoBlock"]
_supports_flash_attn_2 = True
_supports_sdpa = True
supports_gradient_checkpointing = True

def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
super().__init__(config)

self._gradient_checkpointing_func: Optional[Callable] = None
self._gradient_checkpointing = False

if not model:
model_config = create_model_config_from_pretrained_config(config)
# Initialize model (always on CPU to start with so we don't run out of GPU memory).
@@ -49,6 +64,23 @@ def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params
else:
self.model = model

@property
def gradient_checkpointing(self) -> bool:
return self._gradient_checkpointing

@gradient_checkpointing.setter
def gradient_checkpointing(self, enabled: bool):
if self._gradient_checkpointing == enabled:
return

# HF does not specify a way to pass checkpointing strategies, so we pick
# whole layer as our strategy. We can make this configurable later if needed.
checkpointing_strategy = ActivationCheckpointingStrategy.whole_layer if enabled else None
self.model.set_activation_checkpointing(
checkpointing_strategy, checkpoint_func=self._gradient_checkpointing_func
)
self._gradient_checkpointing = enabled

def forward(
self,
input_ids: torch.LongTensor = None,
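For context on the new `gradient_checkpointing` property: on the Hugging Face side, `gradient_checkpointing_enable()` chooses a checkpoint function and then flips the `gradient_checkpointing` attribute on every module that exposes one, which is what invokes the setter above with `self._gradient_checkpointing_func` already populated. A simplified sketch of that third-party behavior (observed around transformers 4.4x; the exact internals vary by version and are not part of this diff):

```python
from functools import partial
from typing import Optional

import torch.utils.checkpoint


def gradient_checkpointing_enable(model, gradient_checkpointing_kwargs: Optional[dict] = None) -> None:
    # Roughly what transformers does: pick a checkpoint function, then toggle every
    # module exposing a `gradient_checkpointing` attribute. Setting the attribute on
    # OLMoForCausalLM triggers the property setter added in this PR, which selects the
    # whole_layer activation checkpointing strategy on the wrapped OLMo model.
    checkpoint_func = partial(torch.utils.checkpoint.checkpoint, **(gradient_checkpointing_kwargs or {}))
    for module in model.modules():
        if hasattr(module, "gradient_checkpointing"):
            module._gradient_checkpointing_func = checkpoint_func
            module.gradient_checkpointing = True
```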
30 changes: 20 additions & 10 deletions olmo/model.py
@@ -418,7 +418,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
self.__cache = cache
assert config.d_model % config.n_heads == 0

-self._activation_checkpoint_fn = None
+self._activation_checkpoint_fn: Optional[Callable] = None

# Dropout.
self.dropout = Dropout(config.residual_dropout)
@@ -500,9 +500,11 @@ def reset_parameters(self):
init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor)
init_normal(self.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor)

-def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+def set_activation_checkpointing(
+self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
+):
if strategy == ActivationCheckpointingStrategy.fine_grained:
-self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
+self._activation_checkpoint_fn = checkpoint_func or activation_checkpoint_function(self.config)
else:
self._activation_checkpoint_fn = None

@@ -980,7 +982,7 @@ class OLMoOutput(NamedTuple):
Attention keys and values from each block.
"""

-hidden_states: Optional[Tuple[torch.Tensor]]
+hidden_states: Optional[Tuple[torch.Tensor, ...]]
"""
Hidden states from each block.
"""
@@ -1050,10 +1052,12 @@ def reset_parameters(self):
for block in self:
block.reset_parameters()

-def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+def set_activation_checkpointing(
+self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
+):
self.activation_checkpointing_strategy = strategy
for block in self:
-block.set_activation_checkpointing(strategy)
+block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)


class OLMo(nn.Module):
@@ -1140,14 +1144,16 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))

-def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+def set_activation_checkpointing(
+self, strategy: Optional[ActivationCheckpointingStrategy], checkpoint_func: Optional[Callable] = None
+):
self.activation_checkpointing_strategy = strategy
if self.config.block_group_size != 1:
for block_group in self.transformer.block_groups:
-block_group.set_activation_checkpointing(strategy)
+block_group.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
else:
for block in self.transformer.blocks:
-block.set_activation_checkpointing(strategy)
+block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)

@property
def device(self) -> torch.device:
@@ -1445,7 +1451,11 @@ def forward(
if self.config.scale_logits:
logits.mul_(1 / math.sqrt(self.config.d_model))

-return OLMoOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
+return OLMoOutput(
+logits=logits,
+attn_key_values=attn_key_values,
+hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
+)

def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
if wrap_strategy is None:
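The extended `set_activation_checkpointing` signature can also be exercised directly on the core model, outside the HF wrapper. A minimal sketch, assuming `torch.utils.checkpoint.checkpoint` as the checkpoint function (the HF wrapper instead forwards whatever function transformers supplies through `_gradient_checkpointing_func`):

```python
import torch.utils.checkpoint

from olmo.config import ActivationCheckpointingStrategy
from olmo.model import OLMo


def enable_whole_layer_checkpointing(model: OLMo) -> None:
    # Propagates the strategy and checkpoint function down through block groups
    # (or individual blocks when block_group_size == 1), as in the diff above.
    model.set_activation_checkpointing(
        ActivationCheckpointingStrategy.whole_layer,
        checkpoint_func=torch.utils.checkpoint.checkpoint,
    )
```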
8 changes: 7 additions & 1 deletion test_fixtures/test-olmo-model/config.json
@@ -13,7 +13,9 @@
"block_type": "sequential",
"clip_qkv": null,
"d_model": 32,
"emb_init_std": null,
"embedding_dropout": 0.1,
"embedding_layer_norm": false,
"embedding_size": 50304,
"eos_token_id": 50256,
"flash_attention": false,
@@ -22,6 +24,7 @@
"init_device": null,
"init_fn": "normal",
"init_std": 0.02,
"layer_norm_eps": 1e-05,
"layer_norm_type": "default",
"layer_norm_with_affine": true,
"max_sequence_length": 1024,
@@ -32,13 +35,16 @@
"n_heads": 1,
"n_kv_heads": null,
"n_layers": 1,
"norm_after": false,
"pad_token_id": 50256,
"precision": null,
"residual_dropout": 0.1,
"rope": false,
"rope_full_precision": true,
"rope_theta": 10000,
"scale_emb_init": false,
"scale_logits": false,
"transformers_version": "4.40.2",
"transformers_version": "4.44.2",
"use_cache": true,
"vocab_size": 50257,
"weight_tying": true
81 changes: 80 additions & 1 deletion tests/hf_olmo/modeling_olmo_test.py
@@ -21,7 +21,86 @@ def test_olmo_model(model_path: str):
output = model(input_tensor)
hf_output = hf_model(input_tensor)

-torch.testing.assert_allclose(output.logits, hf_output.logits)
+torch.testing.assert_close(hf_output.logits, output.logits)


@pytest.mark.gpu
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA devices")
def test_flash_attention_2(model_path: str):
from transformers import AutoModelForCausalLM, AutoTokenizer

import hf_olmo # noqa: F401

hf_model = AutoModelForCausalLM.from_pretrained(model_path)
hf_model_flash_attn = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="flash_attention_2")

tokenizer = AutoTokenizer.from_pretrained(model_path)
encoded_input = tokenizer.encode("My name is OLMo!")
input_tensor = torch.tensor(encoded_input).unsqueeze(0)

hf_output = hf_model(input_tensor)
hf_output_flash_attn = hf_model_flash_attn(input_tensor)

torch.testing.assert_close(hf_output_flash_attn.logits, hf_output.logits)


def test_sdpa(model_path: str):
from transformers import AutoModelForCausalLM, AutoTokenizer

import hf_olmo # noqa: F401

hf_model = AutoModelForCausalLM.from_pretrained(model_path)
hf_model_sdpa = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="sdpa")

tokenizer = AutoTokenizer.from_pretrained(model_path)
encoded_input = tokenizer.encode("My name is OLMo!")
input_tensor = torch.tensor(encoded_input).unsqueeze(0)

hf_output = hf_model(input_tensor)
hf_output_sdpa = hf_model_sdpa(input_tensor)

torch.testing.assert_close(hf_output_sdpa.logits, hf_output.logits)


def test_gradient_checkpointing(model_path: str):
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

import hf_olmo # noqa: F401

hf_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_path)

tokenizer = AutoTokenizer.from_pretrained(model_path)
encoded_input = tokenizer.encode("My name is OLMo!")
input_tensor = torch.tensor(encoded_input).unsqueeze(0)

hf_output_no_checkpointing = hf_model(input_tensor)

hf_model.gradient_checkpointing_enable()

hf_output_checkpointing = hf_model(input_tensor)

torch.testing.assert_close(hf_output_checkpointing.logits, hf_output_no_checkpointing.logits)


def test_gradient_checkpointing_disable(model_path: str):
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

import hf_olmo # noqa: F401

hf_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_path)

tokenizer = AutoTokenizer.from_pretrained(model_path)
encoded_input = tokenizer.encode("My name is OLMo!")
input_tensor = torch.tensor(encoded_input).unsqueeze(0)

hf_output = hf_model(input_tensor)

hf_model.gradient_checkpointing_enable()
hf_model.gradient_checkpointing_disable()

hf_output_after_disable = hf_model(input_tensor)

torch.testing.assert_close(hf_output_after_disable.logits, hf_output.logits)


def test_save_pretrained(model_path: str):