Labels: enhancement
Description
Overview
Integrate all checkpoint components and provide migration utilities for legacy code - Phase 3 of the 5-PR implementation plan.
Parent Issue
Part of #248 - Checkpoint System Refactor
Dependencies
- Requires: #493 (Checkpoint Pipeline Infrastructure, Phase 1) and #494 (Checkpoint Loading Orchestration, Phase 2)
- Integrates: #458 (Checkpoint Acquisition Layer: multi-source checkpoint loading from S3, HTTP, local, MLFlow), via PR #464 (Refactor restart of training), and #410 (Model Transformation Layer: post-loading modifications such as freezing, transfer learning, adapters), via PR #442 ([WIP] 410 refactor of model initialisation ie weight loading model freezing transfer learning)
Purpose
This phase brings together all checkpoint system components into a unified API and provides migration paths for existing code.
Key Components
1. Unified API (training/src/anemoi/training/checkpoint/api.py)
from pathlib import Path
from typing import Any, Dict, Optional

from omegaconf import DictConfig
from torch import nn
from torch.optim import Optimizer

# CheckpointContext is assumed to come from the pipeline infrastructure (#493).


class CheckpointManager:
    """High-level API for checkpoint operations"""

    def __init__(self, config: DictConfig):
        self.config = config  # kept so that _build_pipeline() (not shown) can read it
        self.pipeline = self._build_pipeline()

    async def load_checkpoint(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        scheduler: Optional[Any] = None,
    ) -> "CheckpointContext":
        """Load checkpoint with configured pipeline"""
        pass

    def save_checkpoint(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        scheduler: Optional[Any] = None,
        epoch: int = 0,
        global_step: int = 0,
        metrics: Optional[Dict[str, Any]] = None,
    ) -> Path:
        """Save checkpoint with metadata"""
        pass
2. Legacy Migration (training/src/anemoi/training/checkpoint/migration.py)
class LegacyConfigMigrator:
    """Migrate legacy checkpoint configurations to new system"""

    @staticmethod
    def migrate_config(old_config: DictConfig) -> DictConfig:
        """Convert legacy config to new format with deprecation warnings"""
        pass


class LegacyAdapter:
    """Adapter to support legacy code using new system"""

    def __init__(self, manager: "CheckpointManager"):
        # Matches the LegacyAdapter(checkpoint_manager) call in the migration strategy
        self.manager = manager

    def load_weights_only(self, model: nn.Module, checkpoint_path: str) -> nn.Module:
        """Legacy method for backwards compatibility"""
        pass

    def transfer_learning(self, model: nn.Module, checkpoint_path: str) -> nn.Module:
        """Legacy method for backwards compatibility"""
        pass
3. PyTorch Lightning Integration (training/src/anemoi/training/checkpoint/lightning.py)
import pytorch_lightning as pl


class CheckpointCallback(pl.Callback):
    """PyTorch Lightning callback for checkpoint handling"""

    def on_train_start(self, trainer, pl_module):
        """Load checkpoint at training start if configured"""
        pass

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        """Enhance checkpoint with additional metadata"""
        pass
4. CLI Integration (training/src/anemoi/training/checkpoint/cli.py)
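import click


@click.group()
def checkpoint():
    """Checkpoint management commands"""
    pass


# Each positional parameter needs a matching @click.argument,
# otherwise click calls the function without it.
@checkpoint.command()
@click.argument("checkpoint_path")
def inspect(checkpoint_path: str):
    """Inspect checkpoint contents"""
    pass


@checkpoint.command()
@click.argument("old_checkpoint")
@click.argument("new_checkpoint")
def convert(old_checkpoint: str, new_checkpoint: str):
    """Convert checkpoint to new format"""
    pass


@checkpoint.command()
@click.argument("checkpoint1")
@click.argument("checkpoint2")
def diff(checkpoint1: str, checkpoint2: str):
    """Compare two checkpoints"""
    pass
Full Integration Example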
Configuration
# Complete checkpoint configuration
checkpoint:
  # Acquisition layer
  source:
    type: s3
    bucket: ecmwf-models
    key: anemoi/pretrained.ckpt
    cache_dir: /tmp/checkpoints
  # Loading layer
  loading:
    type: transfer_learning
    strict: false
    skip_mismatched: true
    freeze_loaded: false
  # Model modification layer
  model_modifier:
    modifiers:
      - type: freeze
        layers: [encoder, processor.0]
      - type: lora
        rank: 8
        target_modules: [decoder]
  # Pipeline configuration
  pipeline:
    async_execution: true
    max_retries: 3
    timeout: 600
Usage
# Initialize checkpoint manager
manager = CheckpointManager(config)
# Load checkpoint through pipeline
# (load_checkpoint is a coroutine, so this must run inside an async function)
context = await manager.load_checkpoint(model, optimizer, scheduler)
# Access results
model = context.model
training_state = context.metadata
Migration Strategy
Phase 1: Deprecation Warnings
# Old code still works but shows warnings
model = load_weights_only(model, "checkpoint.pt")
# Warning: load_weights_only is deprecated. Use CheckpointManager with type='weights_only'
Phase 2: Adapter Layer
# Automatic migration
legacy_adapter = LegacyAdapter(checkpoint_manager)
model = legacy_adapter.load_weights_only(model, "checkpoint.pt")
Phase 3: Full Migration
# New API only
manager = CheckpointManager(config)
context = await manager.load_checkpoint(model)
Implementation Checklist
- Unified CheckpointManager API
- Legacy configuration migrator
- Legacy adapter for backwards compatibility
- PyTorch Lightning integration
- CLI tools (inspect, convert, diff)
- Testing utilities
- Migration documentation
- Integration tests
- End-to-end tests
- Performance benchmarks
Testing Strategy
Integration Tests (tests/checkpoint/test_full_integration.py)
- Test complete pipeline end-to-end
- Test all layer interactions
- Test error propagation through pipeline
- Test with real models and checkpoints
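The error-propagation case above could be tested along these lines. This is only a sketch under stated assumptions: `PipelineError`, `run_pipeline`, and the two stage functions are hypothetical stand-ins, not the API from #493. It assumes the pipeline awaits its stages in order over a shared context, and checks that a failure inside one stage reaches the caller with the failing stage identified.

```python
import asyncio


class PipelineError(RuntimeError):
    """Hypothetical error type that records which stage failed."""

    def __init__(self, stage_name: str, cause: Exception):
        super().__init__(f"stage '{stage_name}' failed: {cause}")
        self.stage_name = stage_name


async def run_pipeline(stages, context: dict) -> dict:
    """Await each stage in order; wrap the first failure in PipelineError."""
    for stage in stages:
        try:
            context = await stage(context)
        except Exception as exc:
            raise PipelineError(stage.__name__, exc) from exc
    return context


async def acquire(context: dict) -> dict:
    # Stand-in for the acquisition layer: pretend a checkpoint was fetched.
    return {**context, "path": "/tmp/checkpoints/pretrained.ckpt"}


async def load(context: dict) -> dict:
    # Stand-in for the loading layer: simulate a corrupt checkpoint.
    raise ValueError("corrupt checkpoint")


def test_error_propagates_with_stage_name() -> str:
    try:
        asyncio.run(run_pipeline([acquire, load], {}))
    except PipelineError as err:
        # The original exception stays reachable through the cause chain.
        assert isinstance(err.__cause__, ValueError)
        return err.stage_name
    raise AssertionError("expected PipelineError")


failed_stage = test_error_propagates_with_stage_name()
```

Chaining with `raise ... from exc` keeps the original traceback attached, which makes a failing integration test easier to diagnose.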
Migration Tests (tests/checkpoint/test_migration.py)
- Test legacy config migration
- Test backwards compatibility
- Test deprecation warnings
- Test adapter functionality
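A migration test might look roughly like the following sketch. It uses plain dictionaries instead of `DictConfig` to stay dependency-free, and the legacy flag names (`load_weights_only`, `transfer_learning` as top-level booleans) are assumptions for illustration; only the target `loading.type` values are taken from the configuration and warning text in this issue.

```python
import warnings

LEGACY_FLAGS = ("load_weights_only", "transfer_learning")  # assumed legacy keys


def migrate_config(old_config: dict) -> dict:
    """Map the assumed legacy flags onto a checkpoint.loading block."""
    new_config = {k: v for k, v in old_config.items() if k not in LEGACY_FLAGS}
    if old_config.get("transfer_learning"):
        loading_type = "transfer_learning"
    elif old_config.get("load_weights_only"):
        loading_type = "weights_only"
    else:
        return new_config  # nothing legacy to migrate, no warning
    warnings.warn(
        "load_weights_only/transfer_learning flags are deprecated; "
        "use checkpoint.loading.type instead",
        DeprecationWarning,
        stacklevel=2,
    )
    # Copy the nested block so the caller's config is not mutated.
    checkpoint = dict(new_config.get("checkpoint", {}))
    checkpoint["loading"] = {"type": loading_type}
    new_config["checkpoint"] = checkpoint
    return new_config


def test_legacy_flag_is_migrated_with_warning():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        migrated = migrate_config({"load_weights_only": True, "lr": 1e-3})
    assert migrated == {
        "lr": 1e-3,
        "checkpoint": {"loading": {"type": "weights_only"}},
    }
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```

The same pattern covers both the "legacy config migration" and "deprecation warnings" items: one assertion on the converted structure, one on the recorded warning.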
Lightning Integration (tests/checkpoint/test_lightning.py)
- Test callback integration
- Test state restoration in trainer
- Test checkpoint saving with metadata
CLI Tests (tests/checkpoint/test_cli.py)
- Test checkpoint inspection
- Test format conversion
- Test checkpoint comparison
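For the comparison test, one option is to keep the comparison logic separate from the CLI so it can be tested without real checkpoint files. The helper below, `diff_shapes`, is hypothetical and not part of the proposed CLI; it assumes each checkpoint has already been reduced to a mapping from parameter name to shape.

```python
from typing import Dict, List, Mapping, Tuple

Shapes = Mapping[str, Tuple[int, ...]]


def diff_shapes(first: Shapes, second: Shapes) -> Dict[str, List[str]]:
    """Compare two {parameter name: shape} mappings."""
    shared = first.keys() & second.keys()
    return {
        "only_in_first": sorted(first.keys() - second.keys()),
        "only_in_second": sorted(second.keys() - first.keys()),
        "shape_mismatch": sorted(k for k in shared if first[k] != second[k]),
    }


# Toy inputs standing in for two checkpoints' parameter shapes.
old = {"encoder.weight": (64, 32), "decoder.weight": (32, 64)}
new = {"encoder.weight": (64, 32), "decoder.weight": (16, 64), "lora.A": (8, 64)}
report = diff_shapes(old, new)
```

A `diff` command could then print this report, while the unit test asserts directly on the returned dictionary.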
Deprecation Timeline
- v1.0: Introduce new system with full backwards compatibility
- v1.1: Add deprecation warnings for legacy methods
- v1.2: Move legacy methods to separate module
- v2.0: Remove legacy methods
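One way the warning stage of this timeline could be implemented is a small decorator applied to each legacy function. The sketch below is an illustration, not the project's actual mechanism: the `deprecated` decorator and its parameters are hypothetical, and the body of `load_weights_only` is a placeholder.

```python
import functools
import warnings


def deprecated(replacement: str, removed_in: str = "2.0"):
    """Mark a legacy function as deprecated and name its replacement."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__} is deprecated and will be removed in "
                f"v{removed_in}. Use {replacement}.",
                DeprecationWarning,
                stacklevel=2,  # point the warning at the caller's line
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@deprecated("CheckpointManager with type='weights_only'")
def load_weights_only(model, checkpoint_path):
    # Placeholder body: a real version would delegate to the new system.
    return model
```

Carrying the removal version in the message lets users see how long they have to migrate, and `functools.wraps` keeps the legacy function's name and docstring intact for existing tooling.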
Branch
- Working branch: feature/checkpoint-integration-migration
Success Criteria
- Unified API working seamlessly
- All components integrated
- Legacy migration complete
- PyTorch Lightning integration functional
- CLI tools implemented
- Full test coverage
- Documentation complete
- Performance benchmarks acceptable
- Backwards compatibility maintained