
Checkpoint System Integration and Migration (Phase 3) #495

@JesperDramsch

Overview

Integrate all checkpoint components and provide migration utilities for legacy code. This is Phase 3 of the 5-PR implementation plan.

Parent Issue

Part of #248 - Checkpoint System Refactor

Dependencies

Purpose

This phase brings together all checkpoint system components into a unified API and provides migration paths for existing code.

Key Components

1. Unified API (training/src/anemoi/training/checkpoint/api.py)

from pathlib import Path
from typing import Any, Dict, Optional

from omegaconf import DictConfig
from torch import nn
from torch.optim import Optimizer


class CheckpointManager:
    """High-level API for checkpoint operations"""

    def __init__(self, config: DictConfig):
        self.config = config  # stored so _build_pipeline() can read it
        self.pipeline = self._build_pipeline()

    async def load_checkpoint(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        scheduler: Optional[Any] = None,
    ) -> "CheckpointContext":
        """Load checkpoint with configured pipeline"""
        pass

    def save_checkpoint(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        scheduler: Optional[Any] = None,
        epoch: int = 0,
        global_step: int = 0,
        metrics: Optional[Dict[str, Any]] = None,
    ) -> Path:
        """Save checkpoint with metadata"""
        pass

2. Legacy Migration (training/src/anemoi/training/checkpoint/migration.py)

class LegacyConfigMigrator:
    """Migrate legacy checkpoint configurations to new system"""

    @staticmethod
    def migrate_config(old_config: DictConfig) -> DictConfig:
        """Convert legacy config to new format with deprecation warnings"""
        pass


class LegacyAdapter:
    """Adapter to support legacy code using new system"""

    def __init__(self, manager: CheckpointManager):
        # Legacy calls are delegated to the new manager
        self.manager = manager

    def load_weights_only(self, model: nn.Module, checkpoint_path: str) -> nn.Module:
        """Legacy method for backwards compatibility"""
        pass

    def transfer_learning(self, model: nn.Module, checkpoint_path: str) -> nn.Module:
        """Legacy method for backwards compatibility"""
        pass
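
To make the config migration concrete, here is a minimal sketch of what `migrate_config` might do. It uses plain dicts rather than `DictConfig` to stay dependency-free. The legacy flag names and the mapping table are assumptions made for illustration; the issue does not define the legacy config schema.

```python
import warnings
from typing import Any, Dict

# Hypothetical mapping from legacy boolean flags to the new `loading.type` value.
LEGACY_FLAG_TO_LOADING_TYPE = {
    "load_weights_only": "weights_only",
    "transfer_learning": "transfer_learning",
}


def migrate_config(old_config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new-style config dict; warn for every legacy flag that is set."""
    # Copy everything except the legacy flags themselves.
    new_config = {k: v for k, v in old_config.items() if k not in LEGACY_FLAG_TO_LOADING_TYPE}
    checkpoint = dict(new_config.get("checkpoint", {}))
    loading = dict(checkpoint.get("loading", {}))

    for legacy_key, loading_type in LEGACY_FLAG_TO_LOADING_TYPE.items():
        if old_config.get(legacy_key):
            warnings.warn(
                f"'{legacy_key}' is deprecated; use checkpoint.loading.type='{loading_type}'",
                DeprecationWarning,
                stacklevel=2,
            )
            # An explicit new-style setting wins over a legacy flag.
            loading.setdefault("type", loading_type)

    if loading:
        checkpoint["loading"] = loading
    if checkpoint:
        new_config["checkpoint"] = checkpoint
    return new_config
```

Under these assumptions, `migrate_config({"load_weights_only": True, "lr": 0.1})` drops the legacy flag, emits one `DeprecationWarning`, and sets `checkpoint.loading.type` to `weights_only`.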

3. PyTorch Lightning Integration (training/src/anemoi/training/checkpoint/lightning.py)

import pytorch_lightning as pl


class CheckpointCallback(pl.Callback):
    """PyTorch Lightning callback for checkpoint handling"""

    def on_train_start(self, trainer, pl_module):
        """Load checkpoint at training start if configured"""
        pass

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        """Enhance checkpoint with additional metadata"""
        pass

4. CLI Integration (training/src/anemoi/training/checkpoint/cli.py)

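import click


@click.group()
def checkpoint():
    """Checkpoint management commands"""
    pass


@checkpoint.command()
@click.argument("checkpoint_path", type=click.Path(exists=True))
def inspect(checkpoint_path: str):
    """Inspect checkpoint contents"""
    pass


@checkpoint.command()
@click.argument("old_checkpoint", type=click.Path(exists=True))
@click.argument("new_checkpoint", type=click.Path())
def convert(old_checkpoint: str, new_checkpoint: str):
    """Convert checkpoint to new format"""
    pass


@checkpoint.command()
@click.argument("checkpoint1", type=click.Path(exists=True))
@click.argument("checkpoint2", type=click.Path(exists=True))
def diff(checkpoint1: str, checkpoint2: str):
    """Compare two checkpoints"""
    pass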
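
The comparison behind `diff` could be as simple as a set comparison over parameter names and shapes. The sketch below assumes the checkpoints have already been reduced to name-to-shape mappings; `diff_shapes` is a hypothetical helper, and reading the checkpoint files is out of scope here.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_shapes(a: Dict[str, Shape], b: Dict[str, Shape]) -> Dict[str, List[str]]:
    """Compare two name -> shape mappings and report how they differ."""
    shared = a.keys() & b.keys()
    return {
        "only_in_first": sorted(a.keys() - b.keys()),
        "only_in_second": sorted(b.keys() - a.keys()),
        "shape_mismatch": sorted(k for k in shared if a[k] != b[k]),
    }


# Example inputs (made-up parameter names and shapes).
first = {"encoder.weight": (64, 32), "decoder.weight": (32, 64)}
second = {"encoder.weight": (64, 32), "decoder.weight": (16, 64), "head.bias": (8,)}
report = diff_shapes(first, second)
```

Sorting each list keeps the report deterministic, which makes it easier to assert on in `tests/checkpoint/test_cli.py`.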

Full Integration Example

Configuration

# Complete checkpoint configuration
checkpoint:
  # Acquisition layer
  source:
    type: s3
    bucket: ecmwf-models
    key: anemoi/pretrained.ckpt
    cache_dir: /tmp/checkpoints
  
  # Loading layer
  loading:
    type: transfer_learning
    strict: false
    skip_mismatched: true
    freeze_loaded: false

# Model modification layer
model_modifier:
  modifiers:
    - type: freeze
      layers: [encoder, processor.0]
    - type: lora
      rank: 8
      target_modules: [decoder]

# Pipeline configuration
pipeline:
  async_execution: true
  max_retries: 3
  timeout: 600
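
The issue does not spell out how `max_retries` and `timeout` are interpreted. As one possible reading, the sketch below treats `timeout` as seconds per attempt and `max_retries` as the number of extra attempts after the first. `run_with_retries` is a hypothetical helper, not part of the proposed API.

```python
import asyncio
from typing import Awaitable, Callable, Optional, TypeVar

T = TypeVar("T")


async def run_with_retries(
    stage: Callable[[], Awaitable[T]],
    max_retries: int = 3,
    timeout: float = 600.0,
) -> T:
    """Run `stage`, retrying when an attempt fails or times out.

    Assumes `timeout` is seconds per attempt and `max_retries` is the number
    of extra attempts after the first one.
    """
    last_error: Optional[Exception] = None
    for _attempt in range(max_retries + 1):
        try:
            # wait_for cancels the attempt and raises a timeout error if it overruns.
            return await asyncio.wait_for(stage(), timeout=timeout)
        except Exception as exc:  # timeout errors are Exceptions too
            last_error = exc
    raise RuntimeError(f"stage failed after {max_retries + 1} attempts") from last_error
```

A real pipeline would probably retry only transient errors (network failures, timeouts) and add a backoff delay between attempts. Both are left out to keep the sketch short.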

Usage

# Initialize checkpoint manager
manager = CheckpointManager(config)

# Load checkpoint through pipeline (load_checkpoint is async, so call it from inside a coroutine)
context = await manager.load_checkpoint(model, optimizer, scheduler)

# Access results
model = context.model
training_state = context.metadata
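
Because `load_checkpoint` is declared `async`, a bare `await` only works inside a coroutine. The sketch below shows one way to call it from a synchronous training script. `StubManager` and this `CheckpointContext` dataclass are stand-ins invented for the example; only the `model` and `metadata` attributes come from the usage above.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class CheckpointContext:
    """Stand-in for the real context object."""

    model: Any
    metadata: Dict[str, Any] = field(default_factory=dict)


class StubManager:
    """Placeholder exposing the same async method shape as CheckpointManager."""

    async def load_checkpoint(
        self, model: Any, optimizer: Any = None, scheduler: Any = None
    ) -> CheckpointContext:
        await asyncio.sleep(0)  # stands in for download and load I/O
        return CheckpointContext(model=model, metadata={"epoch": 0})


def load_sync(manager: StubManager, model: Any) -> CheckpointContext:
    """Drive the coroutine to completion from synchronous code."""
    return asyncio.run(manager.load_checkpoint(model))
```

Note that `asyncio.run` raises a `RuntimeError` when an event loop is already running in the current thread (for example in a Jupyter notebook), so a synchronous wrapper like this suits plain scripts and CLI entry points.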

Migration Strategy

Phase 1: Deprecation Warnings

# Old code still works but shows warnings
model = load_weights_only(model, "checkpoint.pt")
# Warning: load_weights_only is deprecated. Use CheckpointManager with type='weights_only'
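
One common way to produce such a warning is a small decorator around the legacy function. This is a sketch using only the standard library; the `deprecated` helper and the placeholder function body are illustrative assumptions, not the existing implementation.

```python
import functools
import warnings
from typing import Any, Callable


def deprecated(replacement: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator that emits a DeprecationWarning each time the wrapped function is called."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            warnings.warn(
                f"{func.__name__} is deprecated. Use {replacement}",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@deprecated("CheckpointManager with type='weights_only'")
def load_weights_only(model: Any, checkpoint_path: str) -> Any:
    # Placeholder body: the real function would load weights into `model`.
    return model
```

Python's default warning filters hide `DeprecationWarning` unless it is triggered by code in `__main__`, so users calling the legacy function from their own modules may need `-W default` or a test runner to see it.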

Phase 2: Adapter Layer

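# Automatic migration: the adapter routes legacy calls through the new manager
checkpoint_manager = CheckpointManager(config)
legacy_adapter = LegacyAdapter(checkpoint_manager)
model = legacy_adapter.load_weights_only(model, "checkpoint.pt")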

Phase 3: Full Migration

# New API only
manager = CheckpointManager(config)
context = await manager.load_checkpoint(model)

Implementation Checklist

  • Unified CheckpointManager API
  • Legacy configuration migrator
  • Legacy adapter for backwards compatibility
  • PyTorch Lightning integration
  • CLI tools (inspect, convert, diff)
  • Testing utilities
  • Migration documentation
  • Integration tests
  • End-to-end tests
  • Performance benchmarks

Testing Strategy

Integration Tests (tests/checkpoint/test_full_integration.py)

  • Test complete pipeline end-to-end
  • Test all layer interactions
  • Test error propagation through pipeline
  • Test with real models and checkpoints

Migration Tests (tests/checkpoint/test_migration.py)

  • Test legacy config migration
  • Test backwards compatibility
  • Test deprecation warnings
  • Test adapter functionality

Lightning Integration (tests/checkpoint/test_lightning.py)

  • Test callback integration
  • Test state restoration in trainer
  • Test checkpoint saving with metadata

CLI Tests (tests/checkpoint/test_cli.py)

  • Test checkpoint inspection
  • Test format conversion
  • Test checkpoint comparison

Deprecation Timeline

  1. v1.0: Introduce new system with full backwards compatibility
  2. v1.1: Add deprecation warnings for legacy methods
  3. v1.2: Move legacy methods to separate module
  4. v2.0: Remove legacy methods

Branch

  • Working branch: feature/checkpoint-integration-migration

Success Criteria

  • Unified API working seamlessly
  • All components integrated
  • Legacy migration complete
  • PyTorch Lightning integration functional
  • CLI tools implemented
  • Full test coverage
  • Documentation complete
  • Performance benchmarks acceptable
  • Backwards compatibility maintained

Metadata

Assignees

Labels

enhancement: New feature or request

Type

No type

Projects

Status

To be triaged

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
