Status: Open
Labels: enhancement (New feature or request)
Description
Overview
Implement different strategies for loading checkpoints into models - Phase 2 of the 5-PR implementation plan.
Parent Issue
Part of #248 - Checkpoint System Refactor
Dependencies
- Requires: #493 Checkpoint Pipeline Infrastructure (Phase 1)
- Works with: #458 Checkpoint Acquisition Layer (multi-source checkpoint loading: S3, HTTP, local, MLFlow) and #410 Model Transformation Layer (post-loading modifications: freezing, transfer learning, adapters)
Purpose
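This layer implements the strategies that determine how checkpoint contents are loaded into models, encapsulating the differences between training scenarios: starting from pretrained weights, fine-tuning a different architecture, and resuming interrupted training.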
Loading Strategies to Implement
1. Weights-Only Loading
- Load only model weights, discard optimizer/scheduler state
- Use case: Starting fresh training with pretrained weights
- Config: `type: weights_only`
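The sketch below shows what weights-only extraction could look like, assuming PyTorch Lightning-style checkpoints, which store model weights under `state_dict` next to `optimizer_states`, `lr_schedulers`, `epoch` and `global_step`. The helper name `extract_weights` and its optional `prefix` argument are hypothetical, not the planned API.

```python
from typing import Any, Dict, Mapping


def extract_weights(checkpoint: Mapping[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Return only the model weights from a checkpoint (hypothetical helper).

    Lightning checkpoints nest weights under "state_dict" next to optimizer
    and scheduler state; a bare ``model.state_dict()`` file is used as-is.
    Everything else (optimizer_states, lr_schedulers, epoch, ...) is dropped.
    """
    state = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    if not prefix:
        return dict(state)
    # Keep only keys under `prefix` and strip it, e.g. "model.enc.w" -> "enc.w".
    return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
```

With PyTorch, the result would be passed to `model.load_state_dict(weights, strict=...)`, and the optimizer and scheduler would be created fresh, as in any new run.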
2. Transfer Learning Loading
- Flexible weight loading with shape mismatch handling
- Skip incompatible layers optionally
- Use case: Fine-tuning from different model architecture
- Config: `type: transfer_learning`, `skip_mismatched: true`
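A minimal sketch of the shape-mismatch handling, assuming state dicts whose values expose a `.shape` (as `torch.Tensor` does). `select_transferable` and its `shape_of` hook are hypothetical names; the hook exists only so the logic can be exercised without tensors.

```python
from typing import Any, Callable, Dict, List, Mapping, Tuple


def select_transferable(
    source: Mapping[str, Any],
    target: Mapping[str, Any],
    skip_mismatched: bool = True,
    shape_of: Callable[[Any], Tuple[int, ...]] = lambda t: tuple(t.shape),
) -> Tuple[Dict[str, Any], List[str]]:
    """Split checkpoint weights into loadable entries and skipped keys (hypothetical helper).

    A key is loadable when the target model has the same key with the same shape.
    Keys absent from the target are reported as skipped; shape mismatches are
    skipped when ``skip_mismatched`` is true and raise ValueError otherwise.
    """
    loadable: Dict[str, Any] = {}
    skipped: List[str] = []
    for key, value in source.items():
        if key not in target:
            skipped.append(key)
            continue
        src_shape, tgt_shape = shape_of(value), shape_of(target[key])
        if src_shape != tgt_shape:
            if not skip_mismatched:
                raise ValueError(f"Shape mismatch for '{key}': checkpoint {src_shape} vs model {tgt_shape}")
            skipped.append(key)
            continue
        loadable[key] = value
    return loadable, skipped
```

The loadable subset would then be applied with `model.load_state_dict(loadable, strict=False)`, since the skipped parameters are absent from the loaded dict; logging `skipped` provides the debugging information called for under Error Handling.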
3. Warm Start Loading
- Full state restoration including optimizer and scheduler
- Resume training from exact point
- Use case: Continuing interrupted training
- Config: `type: warm_start`
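For warm starts, Lightning can already resume everything through `Trainer.fit(model, ckpt_path=...)`. If the strategy restores state explicitly instead, a sketch could look like the following. It assumes Lightning's checkpoint keys and one saved state per optimizer and scheduler; `warm_start` and `REQUIRED_KEYS` are hypothetical names.

```python
from typing import Any, Mapping, Sequence, Tuple

# Keys written by PyTorch Lightning checkpointing that an exact resume relies on.
REQUIRED_KEYS = ("state_dict", "optimizer_states", "lr_schedulers", "epoch", "global_step")


def warm_start(
    checkpoint: Mapping[str, Any],
    model: Any,
    optimizers: Sequence[Any],
    schedulers: Sequence[Any],
) -> Tuple[int, int]:
    """Restore model, optimizer and scheduler state; return (epoch, global_step).

    Hypothetical helper: each argument only needs a ``load_state_dict`` method,
    which torch modules, optimizers and LR schedulers all provide.
    """
    missing = [key for key in REQUIRED_KEYS if key not in checkpoint]
    if missing:
        raise KeyError(f"Checkpoint cannot be used for a warm start; missing {missing}")
    if len(optimizers) != len(checkpoint["optimizer_states"]):
        raise ValueError("Number of optimizers differs from the checkpoint")
    if len(schedulers) != len(checkpoint["lr_schedulers"]):
        raise ValueError("Number of schedulers differs from the checkpoint")

    model.load_state_dict(checkpoint["state_dict"])  # strict: an exact resume expects identical keys
    for optimizer, state in zip(optimizers, checkpoint["optimizer_states"]):
        optimizer.load_state_dict(state)
    for scheduler, state in zip(schedulers, checkpoint["lr_schedulers"]):
        scheduler.load_state_dict(state)
    return checkpoint["epoch"], checkpoint["global_step"]
```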
4. Cold Start Loading
- Load pretrained weights but reset training state
- Optionally reset specific layers
- Use case: New training run from good initialization
- Config: `type: cold_start`, `reset_layers: [decoder]`
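Cold start can reuse weights-only loading and simply drop the entries for layers listed in `reset_layers` before loading. This sketch assumes layer names are top-level prefixes of state-dict keys (e.g. `decoder` matching `decoder.weight`); nested models may need a longer prefix in practice. `drop_reset_layers` is a hypothetical name.

```python
from typing import Any, Dict, Iterable, Mapping


def drop_reset_layers(state: Mapping[str, Any], reset_layers: Iterable[str]) -> Dict[str, Any]:
    """Remove checkpoint entries for layers that should keep their fresh initialization.

    Hypothetical helper: a layer name such as "decoder" matches the key "decoder"
    itself and any key under "decoder.", but not "decoder_norm.weight".
    """
    prefixes = tuple(reset_layers)
    return {
        key: value
        for key, value in state.items()
        if not any(key == p or key.startswith(p + ".") for p in prefixes)
    }
```

Loading the filtered dict with `strict=False` leaves the dropped layers at their fresh initialization, while the optimizer, scheduler and a default `TrainingState` (epoch 0, step 0) start from scratch.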
Key Components
1. Loading Strategy Interface (training/src/anemoi/training/checkpoint/loading/base.py)
```python
from abc import abstractmethod
from typing import Any, Dict

# PipelineStage and CheckpointContext are provided by the Phase 1
# pipeline infrastructure (#493).


class LoadingStrategy(PipelineStage):
    """Base class for checkpoint loading strategies."""

    @abstractmethod
    async def process(self, context: CheckpointContext) -> CheckpointContext:
        """Load checkpoint data into the model, optimizer and scheduler."""

    def _extract_state_dict(self, checkpoint_data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the model state dict from various checkpoint formats."""
```
2. Strategy Implementations (training/src/anemoi/training/checkpoint/loading/strategies.py)
- WeightsOnlyLoader: model weights only
- TransferLearningLoader: flexible loading with mismatch handling
- WarmStartLoader: full state restoration
- ColdStartLoader: fresh start with pretrained weights
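One way to map the `type` field of the configuration examples onto the four classes is a small registry, which also keeps the design open to custom strategies. Everything here is a hypothetical sketch: the `register`/`build_loader` names and the dataclass fields and defaults are assumptions, and the real loaders would subclass `LoadingStrategy`.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Mapping, Type

# Hypothetical registry mapping the config `type` field to a loader class.
LOADERS: Dict[str, Type] = {}


def register(name: str) -> Callable[[Type], Type]:
    def decorator(cls: Type) -> Type:
        if name in LOADERS:
            raise ValueError(f"Loading strategy '{name}' is already registered")
        LOADERS[name] = cls
        return cls
    return decorator


@register("weights_only")
@dataclass
class WeightsOnlyLoader:
    strict: bool = True  # assumed default


@register("transfer_learning")
@dataclass
class TransferLearningLoader:
    strict: bool = False
    skip_mismatched: bool = True
    freeze_loaded: bool = False


@register("warm_start")
@dataclass
class WarmStartLoader:
    pass


@register("cold_start")
@dataclass
class ColdStartLoader:
    reset_layers: List[str] = field(default_factory=list)


def build_loader(config: Mapping[str, Any]):
    """Instantiate the strategy named by `type`; remaining keys become constructor arguments."""
    options = dict(config)  # copy so the caller's config is not mutated
    name = options.pop("type")
    if name not in LOADERS:
        raise ValueError(f"Unknown loading strategy '{name}'; expected one of {sorted(LOADERS)}")
    return LOADERS[name](**options)
```

Hydra's `hydra.utils.instantiate` with a `_target_` key would be an alternative; a registry keeps the short `type` names used in the configuration.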
3. State Management (training/src/anemoi/training/checkpoint/loading/state.py)
```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class TrainingState:
    """Encapsulates the complete training state."""

    epoch: int = 0
    global_step: int = 0
    best_metric: Optional[float] = None
    metrics_history: List[Dict] = field(default_factory=list)
```
4. Loading Utilities (training/src/anemoi/training/checkpoint/loading/utils.py)
- match_state_dict_keys: fuzzy matching between state dicts
- filter_state_dict: filter by patterns
- merge_state_dicts: combine multiple checkpoints
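Possible implementations of the three utilities using only the standard library. Glob patterns via `fnmatch` and fuzzy key matching via `difflib.get_close_matches` are choices made for this sketch, not decisions from the plan, and the signatures are hypothetical.

```python
import difflib
import fnmatch
from typing import Any, Dict, Iterable, Mapping, Optional


def filter_state_dict(
    state: Mapping[str, Any],
    include: Iterable[str] = ("*",),
    exclude: Iterable[str] = (),
) -> Dict[str, Any]:
    """Keep keys matching any `include` glob and no `exclude` glob."""
    include, exclude = list(include), list(exclude)
    return {
        k: v
        for k, v in state.items()
        if any(fnmatch.fnmatchcase(k, p) for p in include)
        and not any(fnmatch.fnmatchcase(k, p) for p in exclude)
    }


def merge_state_dicts(*states: Mapping[str, Any]) -> Dict[str, Any]:
    """Combine checkpoints; later state dicts win on duplicate keys."""
    merged: Dict[str, Any] = {}
    for state in states:
        merged.update(state)
    return merged


def match_state_dict_keys(
    source_keys: Iterable[str], target_keys: Iterable[str], cutoff: float = 0.8
) -> Dict[str, Optional[str]]:
    """Map each source key to its closest target key, or None if nothing is similar enough."""
    targets = list(target_keys)
    mapping: Dict[str, Optional[str]] = {}
    for key in source_keys:
        if key in targets:
            mapping[key] = key  # an exact match takes priority over fuzzy matching
            continue
        close = difflib.get_close_matches(key, targets, n=1, cutoff=cutoff)
        mapping[key] = close[0] if close else None
    return mapping
```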
Configuration Examples
Weights-Only
```yaml
checkpoint:
  loading:
    type: weights_only
    strict: false
```
Transfer Learning
```yaml
checkpoint:
  loading:
    type: transfer_learning
    strict: false
    skip_mismatched: true
    freeze_loaded: false  # Optionally freeze loaded weights
```
Warm Start
```yaml
checkpoint:
  loading:
    type: warm_start
    # No additional config needed
```
Cold Start
```yaml
checkpoint:
  loading:
    type: cold_start
    reset_layers: [decoder]  # Optionally reset specific layers
```
Implementation Checklist
- Base loading strategy interface
- WeightsOnlyLoader implementation
- TransferLearningLoader implementation
- WarmStartLoader implementation
- ColdStartLoader implementation
- State management utilities
- Loading utilities (matching, filtering)
- Unit tests for each strategy
- Integration tests
- Configuration templates
- Documentation
Design Decisions
1. Strategy Pattern
- Each loading approach is a separate strategy
- Strategies are composable and extensible
- Clear separation of concerns
2. State Management
- Explicit handling of optimizer/scheduler state
- Clear semantics for each loading type
- Metadata preservation where appropriate
3. Error Handling
- Graceful handling of shape mismatches
- Clear error messages for debugging
- Optional strict/non-strict modes
4. Flexibility
- Support for various checkpoint formats
- Configurable behavior via Hydra
- Extensible for custom strategies
Testing Strategy
Unit Tests (tests/checkpoint/loading/test_strategies.py)
- Test each loading strategy independently
- Test parameter matching logic
- Test state restoration
- Test error conditions
Integration Tests (tests/checkpoint/loading/test_loading_integration.py)
- Test with real checkpoint files
- Test optimizer/scheduler restoration
- Test transfer learning scenarios
- Test compatibility with different model architectures
Migration from Legacy
- load_weights_only=True → WeightsOnlyLoader
- transfer_learning=True → TransferLearningLoader
- resume_from_checkpoint=path → WarmStartLoader
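A sketch of translating the legacy flags into the new `checkpoint.loading` block. The function name is hypothetical, and raising on conflicting flags is an assumption of this sketch rather than documented legacy behavior; the checkpoint path itself is not part of the loading block and would presumably be handled by the acquisition layer (#458).

```python
from typing import Any, Dict, Optional


def legacy_to_loading_config(
    load_weights_only: bool = False,
    transfer_learning: bool = False,
    resume_from_checkpoint: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Translate legacy flags into a `checkpoint.loading` config block (hypothetical helper).

    Returns None when no legacy flag is set. Setting more than one flag is
    treated as an error rather than silently picking a precedence.
    """
    selected = [
        name
        for name, enabled in (
            ("weights_only", load_weights_only),
            ("transfer_learning", transfer_learning),
            ("warm_start", resume_from_checkpoint is not None),
        )
        if enabled
    ]
    if not selected:
        return None
    if len(selected) > 1:
        raise ValueError(f"Conflicting legacy checkpoint options: {selected}")
    config: Dict[str, Any] = {"type": selected[0]}
    if selected[0] == "transfer_learning":
        # Mirrors the Transfer Learning configuration example.
        config.update(strict=False, skip_mismatched=True)
    return config
```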
Branch
- Working branch: feature/checkpoint-loading-orchestration
Success Criteria
- All loading strategies implemented and tested
- State management working correctly
- Configuration templates created
- Integration with pipeline infrastructure
- Documentation complete
- Ready for Phase 3 integration
Project status: To be triaged