feat: Checkpoint pipeline infrastructure (Phase 1)#501
Conversation
Phase 1: Checkpoint Pipeline Infrastructure - Progress UpdateSummaryThis PR implements the foundational infrastructure for the new checkpoint architecture, establishing the core abstractions, pipeline pattern, and utilities that all checkpoint operations will build upon. Related Issues: #493 (Phase 1), Part of 5-phase checkpoint architecture refactoring 🎯 MotivationThe current checkpoint handling in anemoi-training is monolithic and difficult to extend. This PR establishes a clean, three-layer pipeline architecture: 📦 Changes IncludedNew Files (VERIFIED)Modified Files (VERIFIED)Pending Additions (Not Yet in PR)🔑 Key Features1. Pipeline PatternClean, composable stages for checkpoint processing: # Example usage (programmatic)
pipeline = CheckpointPipeline(
stages=[
CheckpointSource(...), # Phase 2
LoadingStrategy(...), # Phase 2
ModelModifier(...), # Phase 2
],
async_execution=True
)
context = await pipeline.execute(initial_context)2. Component CatalogAutomatic discovery of pipeline components using reflection: # No manual registration needed!
ComponentCatalog.list_sources() # Auto-discovers CheckpointSource subclasses
ComponentCatalog.list_loaders() # Auto-discovers LoadingStrategy subclasses
ComponentCatalog.list_modifiers() # Auto-discovers ModelModifier subclasses3. Multi-Format SupportSeamless handling of different checkpoint formats:
Auto-detection and conversion included. 4. Comprehensive Error HandlingRich exception hierarchy for debugging: try:
context = await pipeline.execute(context)
except CheckpointNotFoundError as e:
logger.error(f"Checkpoint not found: {e}")
except CheckpointIncompatibleError as e:
logger.error(f"Incompatible checkpoint: {e}")5. Async-First DesignEfficient async operations with sync compatibility: # Async mode (default)
context = await pipeline.execute(context)
# Sync mode (when needed)
pipeline = CheckpointPipeline(stages, async_execution=False)
context = pipeline.execute(context)🧪 Testing StrategyCoverage Metrics (VERIFIED)
Test Execution# Run all checkpoint tests
pytest training/tests/checkpoint/ -v
# Run with coverage
pytest training/tests/checkpoint/ --cov=anemoi.training.checkpoint --cov-report=html
# Quick verification
pytest training/tests/checkpoint/ -q
# Output: 211 passed, 2 skipped in 46.46sTest Coverage by Module
🔄 Integration PointsWith PR #464 (Checkpoint Acquisition)# PR #464 will implement CheckpointSource subclasses
class S3Source(PipelineStage): # Uses base.py abstractions
async def process(self, context: CheckpointContext) -> CheckpointContext:
# Load from S3, populate context.checkpoint_data
...With PR #494 (Loading Orchestration)# PR #494 will implement LoadingStrategy subclasses
class TransferLearningLoader(PipelineStage): # Uses base.py abstractions
async def process(self, context: CheckpointContext) -> CheckpointContext:
# Apply checkpoint to model with flexibility
...With PR #442 (Model Modifiers)# PR #442 will implement ModelModifier as PipelineStage
class FreezingModifier(PipelineStage): # Uses base.py abstractions
async def process(self, context: CheckpointContext) -> CheckpointContext:
# Freeze model layers post-loading
...
|
training/src/anemoi/training/diagnostics/callbacks/checkpoint_pipeline.py
Outdated
Show resolved
Hide resolved
training/src/anemoi/training/diagnostics/callbacks/checkpoint_pipeline.py
Outdated
Show resolved
Hide resolved
|
Thank you @anaprietonem and @sahahner for the thorough review! Here's how I'm addressing your feedback: Already Addressed
Will Address Before Merge
Architectural Decisions
Future Work
Let me know if you'd like any of these addressed differently! |
Reverts TYPE_CHECKING import reorganization and noqa comment removals in 9 files that had no functional changes related to the checkpoint pipeline infrastructure (Phase 1 - Issue #493). These linting changes were introduced by pre-commit hooks running on the entire codebase. Reverting them reduces PR scope as requested in PR #501 review. Files reverted: - diagnostics/callbacks/plot.py - diagnostics/mlflow/logger.py - schemas/base_schema.py - schemas/dataloader.py - schemas/graphs/node_schemas.py - schemas/hardware.py - schemas/models/models.py - schemas/graphs/base_graph.py - train/forecaster.py All checkpoint pipeline implementation and integration code preserved. Note: This revert temporarily breaks linting rules (FA102) because origin/main uses PEP 604 union syntax without the required future import. This will be resolved when origin/main is updated or when this feature branch is rebased.
6b5b854 to
5cc5303
Compare
Changes based on @anaprietonem's Nov 25 review comments: - Remove GCS/Azure references from catalog.py (to be added in a future phase after proper discussion) - Make aiohttp an optional dependency in pyproject.toml (install with: pip install anemoi-training[remote]) - Add conditional import handling for aiohttp with HAS_AIOHTTP flag - Add execution patterns documentation to pipeline.py module docstring (explains standalone vs callback integration patterns) - Add ComponentCatalog relationship docs explaining connection to anemoi.utils.registry
- Add CheckpointContext dataclass for carrying state through pipeline - Add PipelineStage abstract base class for all pipeline stages - Add comprehensive exception hierarchy for checkpoint operations - Establish foundation for three-layer checkpoint architecture Part of Phase 1 checkpoint pipeline infrastructure (#493)
…dling - Extract _handle_unknown_loader() for loader error handling - Extract _handle_unknown_modifier() for modifier error handling - Add _build_loader_error_message() for detailed error messages - Add _build_modifier_error_message() for modifier errors - Add _get_loader_type_descriptions() for loader documentation - Add _get_modifier_type_descriptions() for modifier documentation - Add _find_similar_names() for smart suggestions - Use list comprehensions for better performance - Preserve all helpful error messages and suggestions - Reduce McCabe complexity from 11 to acceptable levels
…t validation - Add CheckpointPipeline class for stage orchestration - Implement async and sync execution modes - Add Hydra-based configuration support with instantiation - Implement smart pipeline composition validation - Check source-loader-modifier ordering automatically - Detect duplicate stages and provide warnings - Suggest missing stages based on pipeline composition - Add pre-execution validation with health checks - Support dynamic stage management (add/remove/clear) - Include comprehensive error handling with context - Track stage execution in metadata for debugging - Support continue_on_error for resilient pipelines
The conftest.py used hardcoded relative path "../src/anemoi/training/config" which only worked from training/tests/ directory, causing Hydra to hang when tests were run from training/ root. Changes: - Add _get_config_path() helper to dynamically locate config directory - Supports running tests from any directory (training/ or training/tests/) - Add lazy import of AnemoiDatasetsDataModule for performance - Tests now run in 40s instead of hanging for 2+ minutes Fixes issue where `cd training && pytest tests/checkpoint/` would timeout.
Updated error handling tests to expect the proper CheckpointConfigError exception instead of generic ValueError, matching the actual implementation in catalog.py. Changes: - test_get_source_target_when_empty: expect CheckpointConfigError - test_get_loader_target_when_empty: expect CheckpointConfigError - test_get_modifier_target_when_empty: expect CheckpointConfigError - Update assertions to match actual error message format
Remove undefined placeholder functions from __all__ that don't exist yet: - create_error_context - ensure_checkpoint_error - log_checkpoint_error - map_pytorch_error_to_checkpoint_error - validate_checkpoint_keys Apply ruff auto-fixes: - Sort __all__ alphabetically within sections - Add trailing commas where missing Fixes 5 F822 ruff errors (undefined name in __all__).
Remove safetensors format support to simplify checkpoint handling: Source changes: - Remove safetensors optional dependency from pyproject.toml - Remove safetensors import and HAS_SAFETENSORS flag from formats.py - Update checkpoint_format type hints to exclude "safetensors" - Remove safetensors loading/saving logic and helper functions - Remove .safetensors from supported file extensions in exceptions - Update documentation to reflect supported formats only Test changes: - Remove safetensors fixture from conftest.py - Remove 8 safetensors test methods from test_formats.py - Remove safetensors skip logic from test_utils.py - Remove unused unittest.mock.patch import Net changes: -55 lines from source, -157 lines from tests Supported formats after this change: lightning, pytorch, state_dict
for more information, see https://pre-commit.ci
- Restore missing CheckpointSourceError and CheckpointTimeoutError - Fix TRY003/EM102: Extract exception messages to variables before raising - Fix G004: Convert logging f-strings to % formatting - Fix TRY300: Add explicit else blocks for clarity - Fix BLE001: Specify exception types (OSError, RuntimeError) instead of blind catch - Fix TRY400: Use logging.exception() for automatic traceback inclusion - Fix TC002/TC003: Remove incorrect type checking noqa comments All core checkpoint tests passing (43/43 in test_formats.py, 31/31 in test_base.py and test_pipeline.py).
for more information, see https://pre-commit.ci
Changes based on @anaprietonem's Nov 25 review comments: - Remove GCS/Azure references from catalog.py (to be added in a future phase after proper discussion) - Make aiohttp an optional dependency in pyproject.toml (install with: pip install anemoi-training[remote]) - Add conditional import handling for aiohttp with HAS_AIOHTTP flag - Add execution patterns documentation to pipeline.py module docstring (explains standalone vs callback integration patterns) - Add ComponentCatalog relationship docs explaining connection to anemoi.utils.registry
Add comprehensive documentation for the Phase 1 checkpoint pipeline: - checkpoint_integration.rst: Core API reference (CheckpointContext, PipelineStage, CheckpointPipeline), error handling, utility functions, component discovery, and configuration examples - checkpoint_pipeline_configuration.rst: Configuration guide covering pipeline structure, stage types (sources, loaders, modifiers), complete examples, migration from legacy config, and best practices - checkpoint_troubleshooting.rst: Diagnostic guide for common issues including configuration errors, environment setup, file/path problems, loading compatibility, network issues, and memory optimization All documentation reflects currently implemented features. Planned features for subsequent phases are clearly marked with notes.
- Add pytest-asyncio dependency and configure asyncio_mode=auto - Include validation errors in CheckpointValidationError message - Make nested tensor validation recursive for deep structures - Fix "infinite values" message to match test expectations - Catch all exceptions in get_checkpoint_metadata for corrupted files - Fix test assertions for source_path and state_dict validation
Fix test failures that occurred in GitHub Actions CI: - test_pipeline.py: Avoid Hydra _target_ instantiation with test module paths that aren't installed as packages in CI. Changed tests to use pre-instantiated MockStage objects passed directly to pipeline. - test_formats.py: Fix PicklingError in test_detect_format_non_dict_checkpoint by saving state_dict instead of raw model object (recommended approach). - test_exceptions.py: Fix constructor signature mismatches for CheckpointSourceError tests, remove non-existent logging tests, and adjust picklability test to use base CheckpointError class. All 192 checkpoint tests now pass with 7 skipped.
- Add return type annotations (-> None) to all test methods - Fix isinstance syntax to use X | Y instead of (X, Y) - Move Path import to TYPE_CHECKING block in utils.py - Add else block for TRY300 compliance in formats.py - Fix PT017: refactor try/except assertions to pytest.raises() - Fix SIM117: combine nested with statements - Fix EM101/EM102: extract exception messages to variables - Rename unused parameters with underscore prefix - Rename fixture without return value with underscore prefix (PT004)
In Python 3.10, asyncio.TimeoutError is distinct from the builtin TimeoutError. This fix catches both exception types to ensure proper timeout handling across all supported Python versions. Also restores marshmallow warning filters in pytest.ini for compatibility with marshmallow 3.x installations while keeping asyncio test configuration.
cad403d to
1e0f1a2
Compare
for more information, see https://pre-commit.ci
PreprocessorSchema must be imported at runtime for Pydantic v2 model building. Reverts the TYPE_CHECKING guard from commit 2a0d002 that caused 16 CI test failures across Python 3.11/3.12/3.13.
anaprietonem
left a comment
There was a problem hiding this comment.
Thanks Jesper for addressing the comments and implementing this. Nice work!! LGTM and can be merged.
Summary
This PR implements Phase 1 of the checkpoint pipeline infrastructure, establishing the core abstractions and pipeline pattern for flexible checkpoint handling in Anemoi.
Changes
Technical Details
The component catalog uses a hybrid approach for identifying abstract classes:
This ensures true dynamic discovery without hardcoded component lists, making the system easily extensible.
Testing
Related Issues
Closes #493
Next Steps
This is Phase 1 of a multi-phase implementation:
📚 Documentation preview 📚: https://anemoi-training--501.org.readthedocs.build/en/501/
📚 Documentation preview 📚: https://anemoi-graphs--501.org.readthedocs.build/en/501/
📚 Documentation preview 📚: https://anemoi-models--501.org.readthedocs.build/en/501/