Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions torchtune/training/checkpointing/_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
from torchtune import training

from torchtune.models import convert_weights
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_meta_to_tune,
llama3_vision_tune_to_meta,
)
from torchtune.models.phi3._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf
from torchtune.models.qwen2._convert_weights import qwen2_hf_to_tune, qwen2_tune_to_hf
from torchtune.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf
Expand Down Expand Up @@ -740,6 +736,10 @@ def load_checkpoint(self) -> Dict[str, Any]:
state_dict: Dict[str:Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_meta_to_tune,
)

state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
model_state_dict
)
Expand Down Expand Up @@ -784,6 +784,10 @@ def save_checkpoint(
if not adapter_only:
model_state_dict = state_dict[training.MODEL_KEY]
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_tune_to_meta,
)

state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
model_state_dict
)
Expand Down