
Checkpoint Loading Orchestration (Phase 2) #494

@JesperDramsch


Overview

Implement the strategies used to load checkpoints into models. This is Phase 2 of the 5-PR implementation plan.

Parent Issue

Part of #248 - Checkpoint System Refactor

Dependencies

Purpose

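This layer defines how a checkpoint is loaded into a model. Each training scenario (starting fresh from pretrained weights, fine-tuning a different architecture, resuming an interrupted run, or re-initializing training state) gets its own strategy, so the loading behavior is explicit and configurable.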

Loading Strategies to Implement

1. Weights-Only Loading

  • Load only model weights, discard optimizer/scheduler state
  • Use case: Starting fresh training with pretrained weights
  • Config: type: weights_only

2. Transfer Learning Loading

  • Flexible weight loading that handles shape mismatches
  • Optionally skip incompatible layers
  • Use case: Fine-tuning from a different model architecture
  • Config: type: transfer_learning, skip_mismatched: true
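The `skip_mismatched` option can be sketched as follows. This is a minimal illustration, assuming both state dicts are name → tensor mappings as returned by PyTorch's `Module.state_dict()`; the function name `select_compatible_weights` and the `shape_of` hook are hypothetical and not part of the planned API.

```python
from typing import Any, Callable, Dict, List, Mapping, Tuple


def select_compatible_weights(
    checkpoint_sd: Mapping[str, Any],
    model_sd: Mapping[str, Any],
    skip_mismatched: bool = True,
    shape_of: Callable[[Any], Tuple[int, ...]] = lambda t: tuple(t.shape),
) -> Tuple[Dict[str, Any], List[str]]:
    """Return (weights safe to load, names of skipped checkpoint entries).

    An entry is kept only if the model has a parameter with the same name
    and the same shape. With skip_mismatched=False a shape mismatch raises.
    """
    compatible: Dict[str, Any] = {}
    skipped: List[str] = []
    for name, value in checkpoint_sd.items():
        if name not in model_sd:
            # Parameter does not exist in the new architecture.
            skipped.append(name)
            continue
        ckpt_shape, model_shape = shape_of(value), shape_of(model_sd[name])
        if ckpt_shape != model_shape:
            if not skip_mismatched:
                raise ValueError(
                    f"Shape mismatch for '{name}': checkpoint {ckpt_shape} "
                    f"vs model {model_shape}"
                )
            skipped.append(name)
            continue
        compatible[name] = value
    return compatible, skipped
```

With real tensors, the result would then be applied with `model.load_state_dict(compatible, strict=False)`, so that skipped parameters keep their freshly initialized values.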

3. Warm Start Loading

  • Full state restoration including optimizer and scheduler
  • Resume training from the exact point where it stopped
  • Use case: Continuing interrupted training
  • Config: type: warm_start

4. Cold Start Loading

  • Load pretrained weights but reset training state
  • Optionally reset specific layers
  • Use case: New training run from a good initialization
  • Config: type: cold_start, reset_layers: [decoder]
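One way to pin down the semantics of the four strategies is to encode, as data, which parts of a checkpoint each one restores. The sketch below is our reading of the descriptions above, not a confirmed design; `LoadPlan`, `LOAD_PLANS` and the field names are hypothetical.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class LoadPlan:
    """Which parts of a checkpoint a loading strategy restores."""

    model_weights: bool
    optimizer_state: bool
    scheduler_state: bool
    training_progress: bool  # epoch, global step, best metric, ...
    skip_mismatched_shapes: bool = False  # transfer learning only
    supports_reset_layers: bool = False  # cold start only


LOAD_PLANS: Dict[str, LoadPlan] = {
    "weights_only": LoadPlan(True, False, False, False),
    "transfer_learning": LoadPlan(True, False, False, False, skip_mismatched_shapes=True),
    "warm_start": LoadPlan(True, True, True, True),
    "cold_start": LoadPlan(True, False, False, False, supports_reset_layers=True),
}
```

Keeping the plans in one table makes the differences between strategies easy to review and lets unit tests assert that each loader honours its plan.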

Key Components

1. Loading Strategy Interface (training/src/anemoi/training/checkpoint/loading/base.py)

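from abc import abstractmethod
from typing import Any, Dict

# PipelineStage and CheckpointContext are provided by the checkpoint pipeline infrastructure.

class LoadingStrategy(PipelineStage):
    """Base class for checkpoint loading strategies."""

    @abstractmethod
    async def process(self, context: CheckpointContext) -> CheckpointContext:
        """Load checkpoint data into the model, optimizer and scheduler."""
        ...

    def _extract_state_dict(self, checkpoint_data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the model state dict from various checkpoint formats."""
        # e.g. Lightning checkpoints nest weights under "state_dict"; some plain
        # torch.save() checkpoints use "model_state_dict"; otherwise assume the
        # data already is a state dict.
        for key in ("state_dict", "model_state_dict"):
            if key in checkpoint_data:
                return checkpoint_data[key]
        return checkpoint_data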
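To illustrate how a concrete strategy plugs into the async `process` interface, here is a self-contained sketch. `CheckpointContext` and `PipelineStage` are re-declared as minimal hypothetical stand-ins because their real definitions live in the pipeline layer; their field names (`checkpoint_data`, `model_state`, `optimizer_state`, `metadata`) and the `run_stages` helper are assumptions.

```python
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Sequence


# Hypothetical stand-ins for the pipeline types; real names and fields may differ.
@dataclass
class CheckpointContext:
    checkpoint_data: Dict[str, Any]
    model_state: Dict[str, Any] = field(default_factory=dict)
    optimizer_state: Optional[Dict[str, Any]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


class PipelineStage(ABC):
    @abstractmethod
    async def process(self, context: CheckpointContext) -> CheckpointContext:
        ...


class WeightsOnlyLoader(PipelineStage):
    """Keep the model weights from the checkpoint and drop everything else."""

    async def process(self, context: CheckpointContext) -> CheckpointContext:
        data = context.checkpoint_data
        # Lightning-style checkpoints nest the weights under "state_dict".
        context.model_state = dict(data.get("state_dict", data))
        # Optimizer/scheduler state is intentionally not carried over.
        context.optimizer_state = None
        context.metadata["loading_strategy"] = "weights_only"
        return context


async def run_stages(stages: Sequence[PipelineStage], context: CheckpointContext) -> CheckpointContext:
    """Run pipeline stages in order, threading the context through each one."""
    for stage in stages:
        context = await stage.process(context)
    return context


if __name__ == "__main__":
    ckpt = {"state_dict": {"encoder.weight": [0.1, 0.2]}, "optimizer_states": [{}]}
    result = asyncio.run(run_stages([WeightsOnlyLoader()], CheckpointContext(ckpt)))
    print(result.model_state, result.optimizer_state)
```

Because every stage takes and returns the same context object, loaders can be chained with other stages without knowing about them.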

2. Strategy Implementations (training/src/anemoi/training/checkpoint/loading/strategies.py)

  • WeightsOnlyLoader - Model weights only
  • TransferLearningLoader - Flexible loading with mismatch handling
  • WarmStartLoader - Full state restoration
  • ColdStartLoader - Fresh start with pretrained weights
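Selecting one of these loaders from the `type` key in the configuration could use a small registry. This is a sketch under assumptions: `LOADER_REGISTRY`, `register_loader` and `build_loader` are hypothetical names, and the `ColdStartLoader` here is a placeholder class, not the planned implementation.

```python
from typing import Any, Callable, Dict, Mapping, Sequence

# Hypothetical registry mapping the config `type` string to a loader class.
LOADER_REGISTRY: Dict[str, type] = {}


def register_loader(name: str) -> Callable[[type], type]:
    """Class decorator that makes a loader selectable via `type: <name>`."""

    def decorator(cls: type) -> type:
        if name in LOADER_REGISTRY:
            raise ValueError(f"Loader '{name}' is already registered")
        LOADER_REGISTRY[name] = cls
        return cls

    return decorator


def build_loader(config: Mapping[str, Any]) -> Any:
    """Instantiate the loader named by config['type'] with the remaining keys."""
    options = dict(config)
    name = options.pop("type")
    try:
        cls = LOADER_REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(LOADER_REGISTRY))
        raise ValueError(f"Unknown loading type '{name}'. Known types: {known}") from None
    return cls(**options)


@register_loader("cold_start")
class ColdStartLoader:  # placeholder for illustration only
    def __init__(self, strict: bool = False, reset_layers: Sequence[str] = ()) -> None:
        self.strict = strict
        self.reset_layers = list(reset_layers)
```

With Hydra, the `checkpoint.loading` node would first be converted with `OmegaConf.to_container(cfg, resolve=True)`. An alternative design is Hydra's own `hydra.utils.instantiate` with a `_target_` key, which trades the short `type` names for fully qualified class paths.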

3. State Management (training/src/anemoi/training/checkpoint/loading/state.py)

from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class TrainingState:
    """Encapsulates complete training state."""
    epoch: int = 0
    global_step: int = 0
    best_metric: Optional[float] = None
    metrics_history: List[Dict] = field(default_factory=list)
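A warm start needs to rebuild this state from a checkpoint, while a cold start simply uses `TrainingState()`. The sketch below adds a hypothetical `from_checkpoint` constructor. It assumes Lightning-style top-level `epoch` and `global_step` keys; the `best_metric` and `metrics_history` keys are assumptions, since where those values are stored is not specified here.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional


@dataclass
class TrainingState:
    """Encapsulates complete training state."""

    epoch: int = 0
    global_step: int = 0
    best_metric: Optional[float] = None
    metrics_history: List[Dict] = field(default_factory=list)

    @classmethod
    def from_checkpoint(cls, checkpoint: Mapping[str, Any]) -> "TrainingState":
        """Recover progress counters; missing keys fall back to the defaults."""
        return cls(
            epoch=int(checkpoint.get("epoch", 0)),
            global_step=int(checkpoint.get("global_step", 0)),
            best_metric=checkpoint.get("best_metric"),  # hypothetical key
            metrics_history=list(checkpoint.get("metrics_history", [])),  # hypothetical key
        )
```

Falling back to defaults keeps older checkpoints loadable, at the cost of silently starting counters at zero; a strict mode could raise instead.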

4. Loading Utilities (training/src/anemoi/training/checkpoint/loading/utils.py)

  • match_state_dict_keys - Fuzzy matching between state dicts
  • filter_state_dict - Filter by patterns
  • merge_state_dicts - Combine multiple checkpoints
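The three utilities might look roughly like this. It is a sketch with assumed signatures: filtering uses shell-style globs from Python's `fnmatch`, and "fuzzy" matching is reduced to stripping common wrapper prefixes (`module.` is added by PyTorch's `DataParallel`/`DistributedDataParallel`; `model.` is an assumed example).

```python
import fnmatch
from typing import Any, Dict, Iterable, Mapping, Tuple


def filter_state_dict(
    state_dict: Mapping[str, Any],
    include: Iterable[str] = ("*",),
    exclude: Iterable[str] = (),
) -> Dict[str, Any]:
    """Keep keys matching any include pattern and no exclude pattern."""
    include, exclude = list(include), list(exclude)
    return {
        key: value
        for key, value in state_dict.items()
        if any(fnmatch.fnmatchcase(key, p) for p in include)
        and not any(fnmatch.fnmatchcase(key, p) for p in exclude)
    }


def match_state_dict_keys(
    source: Mapping[str, Any],
    target_keys: Iterable[str],
    prefixes: Tuple[str, ...] = ("module.", "model."),
) -> Dict[str, Any]:
    """Map source entries onto target key names by stripping wrapper prefixes."""
    targets = set(target_keys)
    matched: Dict[str, Any] = {}
    for key, value in source.items():
        candidates = [key] + [key[len(p):] for p in prefixes if key.startswith(p)]
        for candidate in candidates:
            if candidate in targets:
                matched[candidate] = value
                break
    return matched


def merge_state_dicts(*state_dicts: Mapping[str, Any]) -> Dict[str, Any]:
    """Combine several state dicts; later ones win on key collisions."""
    merged: Dict[str, Any] = {}
    for state_dict in state_dicts:
        merged.update(state_dict)
    return merged
```

For example, a cold start with `reset_layers: [decoder]` could load `filter_state_dict(weights, exclude=["decoder.*"])` so the decoder keeps its fresh initialization.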

Configuration Examples

Weights-Only

checkpoint:
  loading:
    type: weights_only
    strict: false

Transfer Learning

checkpoint:
  loading:
    type: transfer_learning
    strict: false
    skip_mismatched: true
    freeze_loaded: false  # Optionally freeze loaded weights

Warm Start

checkpoint:
  loading:
    type: warm_start
    # No additional config needed

Cold Start

checkpoint:
  loading:
    type: cold_start
    reset_layers: [decoder]  # Optionally reset specific layers

Implementation Checklist

  • Base loading strategy interface
  • WeightsOnlyLoader implementation
  • TransferLearningLoader implementation
  • WarmStartLoader implementation
  • ColdStartLoader implementation
  • State management utilities
  • Loading utilities (matching, filtering)
  • Unit tests for each strategy
  • Integration tests
  • Configuration templates
  • Documentation

Design Decisions

1. Strategy Pattern

  • Each loading approach is a separate strategy
  • Strategies are composable and extensible
  • Clear separation of concerns

2. State Management

  • Explicit handling of optimizer/scheduler state
  • Clear semantics for each loading type
  • Metadata preservation where appropriate

3. Error Handling

  • Graceful handling of shape mismatches
  • Clear error messages for debugging
  • Optional strict/non-strict modes

4. Flexibility

  • Support for various checkpoint formats
  • Configurable behavior via Hydra
  • Extensible for custom strategies
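The strict/non-strict modes and the "clear error messages" goal can be combined in a single key check, similar in spirit to the `missing_keys`/`unexpected_keys` that PyTorch's `load_state_dict` reports. This is a hypothetical helper (`check_keys`, `KeyReport`), not the planned implementation.

```python
import logging
from typing import Iterable, List, NamedTuple

LOGGER = logging.getLogger(__name__)


class KeyReport(NamedTuple):
    missing: List[str]  # in the model but absent from the checkpoint
    unexpected: List[str]  # in the checkpoint but absent from the model


def check_keys(checkpoint_keys: Iterable[str], model_keys: Iterable[str], strict: bool) -> KeyReport:
    """Compare key sets; raise in strict mode, otherwise log a warning."""
    ckpt, model = set(checkpoint_keys), set(model_keys)
    report = KeyReport(sorted(model - ckpt), sorted(ckpt - model))
    if report.missing or report.unexpected:
        message = (
            f"Checkpoint/model key mismatch: {len(report.missing)} missing "
            f"(e.g. {report.missing[:3]}), {len(report.unexpected)} unexpected "
            f"(e.g. {report.unexpected[:3]})"
        )
        if strict:
            raise RuntimeError(message)
        LOGGER.warning(message)
    return report
```

Reporting counts plus a few example names keeps the message readable even when thousands of parameters differ.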

Testing Strategy

Unit Tests (tests/checkpoint/loading/test_strategies.py)

  • Test each loading strategy independently
  • Test parameter matching logic
  • Test state restoration
  • Test error conditions

Integration Tests (tests/checkpoint/loading/test_loading_integration.py)

  • Test with real checkpoint files
  • Test optimizer/scheduler restoration
  • Test transfer learning scenarios
  • Test compatibility with different model architectures

Migration from Legacy

  • load_weights_only=True → WeightsOnlyLoader
  • transfer_learning=True → TransferLearningLoader
  • resume_from_checkpoint=path → WarmStartLoader
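The mapping above could be automated with a small shim that turns legacy options into a `checkpoint.loading` block. This is a sketch: `migrate_legacy_flags` is a hypothetical name, the precedence when several flags are set is an assumption, and the checkpoint path itself is left out because the configuration examples do not show where it is stored.

```python
from typing import Any, Dict, Optional


def migrate_legacy_flags(
    load_weights_only: bool = False,
    transfer_learning: bool = False,
    resume_from_checkpoint: Optional[str] = None,
) -> Dict[str, Any]:
    """Translate legacy options into a `checkpoint.loading` config block.

    Assumed precedence: transfer_learning > load_weights_only > resume.
    """
    if transfer_learning:
        return {"type": "transfer_learning", "strict": False, "skip_mismatched": True}
    if load_weights_only:
        return {"type": "weights_only", "strict": False}
    if resume_from_checkpoint is not None:
        return {"type": "warm_start"}
    return {}  # no legacy flag set: nothing to migrate
```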

Branch

  • Working branch: feature/checkpoint-loading-orchestration

Success Criteria

  • All loading strategies implemented and tested
  • State management working correctly
  • Configuration templates created
  • Integration with pipeline infrastructure
  • Documentation complete
  • Ready for Phase 3 integration

Metadata

Labels: enhancement (New feature or request)
Status: To be triaged