feat: Checkpoint pipeline infrastructure (Phase 1) #501
Merged: JesperDramsch merged 32 commits into main from feature/checkpoint-pipeline-infrastructure on Feb 17, 2026.
Commits (32):

- a6c84e0 feat(checkpoint): add core checkpoint pipeline infrastructure (JesperDramsch)
- ba49884 feat(checkpoint): add pipeline orchestrator for stage execution (JesperDramsch)
- 9dc5faa feat(checkpoint): add component catalog with dynamic discovery (JesperDramsch)
- fbf4c4a feat(checkpoint): add utilities and comprehensive tests (JesperDramsch)
- 3d319c4 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 4d62b45 feat(checkpoint): add multi-format checkpoint support (JesperDramsch)
- 9879a03 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 24885a5 feat(checkpoint): add core pipeline base classes (JesperDramsch)
- 5a669df feat(checkpoint): implement CheckpointContext with validation and Pip… (JesperDramsch)
- 8efa937 refactor(checkpoint): reduce complexity in ComponentCatalog error han… (JesperDramsch)
- 6ba4e87 feat(checkpoint): implement CheckpointPipeline orchestrator with smar… (JesperDramsch)
- 2170e76 fix(tests): resolve pytest hanging by fixing Hydra config paths (JesperDramsch)
- 75b0eb9 test(checkpoint): update catalog tests to expect CheckpointConfigError (JesperDramsch)
- 00cfd2c chore(checkpoint): clean up module exports and apply formatting (JesperDramsch)
- ec3eaaa refactor(checkpoint): remove safetensors support from codebase (JesperDramsch)
- e1e9566 docs: add logging statement for device on checkpoint load (JesperDramsch)
- bba3547 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 55a93c1 fix: resolve ruff linting errors in checkpoint pipeline (JesperDramsch)
- 8ec1542 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- c336948 refactor(checkpoint): address PR #501 reviewer feedback (JesperDramsch)
- 0110f03 docs: add checkpoint pipeline documentation for readthedocs (JesperDramsch)
- a830551 fix(checkpoint): resolve async test failures and validation errors (JesperDramsch)
- ad1406d fix(tests): resolve CI test failures in checkpoint tests (JesperDramsch)
- 2d7ebb0 fix(lint): resolve ruff errors in checkpoint tests and source files (JesperDramsch)
- 2a0d002 fix: type_checking (JesperDramsch)
- de6b2b3 fix: lazy import (JesperDramsch)
- 1e0f1a2 fix: handle asyncio.TimeoutError for Python 3.10 compatibility (JesperDramsch)
- 3ce3f19 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- ddbd706 fix: restore runtime import for Pydantic schema forward reference (JesperDramsch)
- 5241cda Merge branch 'main' into feature/checkpoint-pipeline-infrastructure (JesperDramsch)
- dbc56ad Merge branch 'main' into feature/checkpoint-pipeline-infrastructure (JesperDramsch)
- 44fee31 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
.. _checkpoint_integration:

#################################
Checkpoint Pipeline Integration
#################################

This guide covers the checkpoint pipeline infrastructure for Anemoi
training. The pipeline provides a foundation for building checkpoint
loading workflows.

.. note::

   This documents Phase 1 (Pipeline Infrastructure). Sources, loaders,
   and modifiers are implemented in subsequent phases.

**************
Core Classes
**************

CheckpointContext
=================

The ``CheckpointContext`` carries state through pipeline stages:

.. code:: python

   from anemoi.training.checkpoint import CheckpointContext

   # Create context with a model
   context = CheckpointContext(
       model=my_model,
       config=my_config,  # Optional OmegaConf config
   )

   # Access and update metadata
   context.update_metadata(source="local", loaded=True)
   print(context.metadata)

**Attributes:**

- ``model``: PyTorch model
- ``optimizer``: Optional optimizer
- ``scheduler``: Optional learning rate scheduler
- ``checkpoint_path``: Path to the checkpoint file
- ``checkpoint_data``: Loaded checkpoint dictionary
- ``metadata``: Dictionary for tracking state
- ``config``: Optional Hydra configuration
- ``checkpoint_format``: Detected format (``lightning``, ``pytorch``,
  or ``state_dict``)

PipelineStage
=============

Base class for implementing pipeline stages:

.. code:: python

   from anemoi.training.checkpoint import CheckpointContext, PipelineStage


   class MyCustomStage(PipelineStage):
       def __init__(self, param: str):
           self.param = param

       async def process(self, context: CheckpointContext) -> CheckpointContext:
           # Implement your logic here
           context.update_metadata(custom_param=self.param)
           return context
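To try the stage-chaining control flow without Anemoi installed, here is a minimal standalone sketch of the same pattern. The ``Context``, ``TagStage``, and ``MiniPipeline`` names are illustrative stand-ins, not the Anemoi API:

```python
import asyncio


class Context:
    """Stand-in for CheckpointContext: a bag of metadata passed between stages."""

    def __init__(self):
        self.metadata = {}

    def update_metadata(self, **kwargs):
        self.metadata.update(kwargs)


class TagStage:
    """A trivial stage that records a tag in the context metadata."""

    def __init__(self, tag):
        self.tag = tag

    async def process(self, context):
        context.update_metadata(**{self.tag: True})
        return context


class MiniPipeline:
    """Runs stages sequentially, threading the context through each one."""

    def __init__(self, stages):
        self.stages = stages

    async def execute(self, context):
        for stage in self.stages:
            context = await stage.process(context)
        return context


ctx = asyncio.run(
    MiniPipeline([TagStage("fetched"), TagStage("loaded")]).execute(Context())
)
print(ctx.metadata)  # {'fetched': True, 'loaded': True}
```

Each stage receives the context produced by its predecessor, which is why stages must return the (possibly replaced) context object.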
CheckpointPipeline
==================

Orchestrates execution of multiple stages:

.. code:: python

   from anemoi.training.checkpoint import CheckpointContext, CheckpointPipeline

   # Build pipeline with stages
   pipeline = CheckpointPipeline(
       stages=[stage1, stage2, stage3],
       async_execution=True,
       continue_on_error=False,
   )

   # Execute
   context = CheckpointContext(model=my_model)
   result = await pipeline.execute(context)

**From Hydra configuration:**

.. code:: python

   from omegaconf import OmegaConf

   config = OmegaConf.create({
       "stages": [
           {"_target_": "my_module.MyStage", "param": "value"},
       ],
       "async_execution": True,
   })

   pipeline = CheckpointPipeline.from_config(config)

****************
Error Handling
****************

The checkpoint module provides a hierarchy of exceptions:

.. code:: python

   from anemoi.training.checkpoint import (
       CheckpointError,  # Base exception
       CheckpointNotFoundError,  # File not found
       CheckpointLoadError,  # Loading failed
       CheckpointValidationError,  # Validation failed
       CheckpointSourceError,  # Source fetch failed
       CheckpointTimeoutError,  # Operation timed out
       CheckpointConfigError,  # Configuration error
       CheckpointIncompatibleError,  # Model/checkpoint mismatch
   )

   try:
       result = await pipeline.execute(context)
   except CheckpointNotFoundError as e:
       print(f"Checkpoint not found: {e.path}")
   except CheckpointLoadError as e:
       print(f"Failed to load: {e}")
   except CheckpointError as e:
       print(f"Checkpoint error: {e}")

*******************
Utility Functions
*******************

Format Detection
================

.. code:: python

   from anemoi.training.checkpoint.formats import (
       detect_checkpoint_format,
       extract_state_dict,
       load_checkpoint,
   )

   # Auto-detect format
   fmt = detect_checkpoint_format("/path/to/checkpoint.ckpt")
   # Returns: "lightning", "pytorch", or "state_dict"

   # Load checkpoint
   data = load_checkpoint("/path/to/checkpoint.ckpt")

   # Extract state dict from various formats
   state_dict = extract_state_dict(data)
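The detection logic itself is not reproduced in this guide. A common heuristic, shown here only as an illustration (not Anemoi's actual implementation), keys off the checkpoint dictionary's structure: Lightning checkpoints typically carry a ``pytorch-lightning_version`` key alongside ``state_dict``, plain PyTorch training checkpoints nest a ``state_dict``, and a bare mapping of tensors is treated as a raw state dict:

```python
def guess_format(data: dict) -> str:
    """Illustrative format heuristic over an already-loaded checkpoint dict."""
    if "pytorch-lightning_version" in data:
        return "lightning"
    if "state_dict" in data:
        return "pytorch"
    return "state_dict"


print(guess_format({"pytorch-lightning_version": "2.1", "state_dict": {}}))  # lightning
print(guess_format({"state_dict": {}, "optimizer": {}}))  # pytorch
print(guess_format({"layer.weight": [0.0]}))  # state_dict
```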
Checkpoint Utilities
====================

.. code:: python

   from pathlib import Path

   from anemoi.training.checkpoint import (
       calculate_checksum,
       compare_state_dicts,
       estimate_checkpoint_memory,
       format_size,
       get_checkpoint_metadata,
       validate_checkpoint,
   )

   # Get metadata without loading the full checkpoint
   metadata = get_checkpoint_metadata(Path("model.ckpt"))

   # Validate checkpoint structure
   validate_checkpoint(checkpoint_data)

   # Calculate file checksum
   checksum = calculate_checksum(Path("model.ckpt"), algorithm="sha256")

   # Compare state dictionaries
   missing, unexpected, mismatched = compare_state_dicts(source_dict, target_dict)

   # Estimate memory usage
   bytes_needed = estimate_checkpoint_memory(checkpoint_data)
   print(format_size(bytes_needed))  # e.g., "1.5 GB"
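A helper like ``calculate_checksum`` typically wraps the standard-library chunked-hash pattern; the sketch below shows that pattern standalone (it is an illustration, not the Anemoi implementation). Reading in chunks matters because multi-gigabyte checkpoints should never be held fully in memory just to hash them:

```python
import hashlib
import tempfile
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in 1 MiB chunks so large checkpoints stream through."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


with tempfile.TemporaryDirectory() as tmp:
    p = Path(tmp) / "model.ckpt"
    p.write_bytes(b"fake checkpoint bytes")
    checksum = sha256_of(p)
    # Chunked hashing matches hashing the same bytes in one shot
    assert checksum == hashlib.sha256(b"fake checkpoint bytes").hexdigest()
    print(checksum)
```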
*********************
Component Discovery
*********************

The ``ComponentCatalog`` provides discovery of available pipeline
components:

.. code:: python

   from anemoi.training.checkpoint import ComponentCatalog

   # List available components
   print(ComponentCatalog.list_sources())  # Available source types
   print(ComponentCatalog.list_loaders())  # Available loading strategies
   print(ComponentCatalog.list_modifiers())  # Available model modifiers

   # Get the Hydra target path for a component
   target = ComponentCatalog.get_source_target("local")
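One common way to back such a catalog is a class-level registry mapping component names to Hydra target paths. The sketch below is hypothetical throughout: the ``MiniCatalog`` class and its single ``"local"`` entry are stand-ins for illustration, not Anemoi's actual tables or discovery mechanism:

```python
class MiniCatalog:
    """Hypothetical name -> Hydra-target registry, for illustration only."""

    _sources = {
        # Hypothetical target path, mirroring the YAML examples in this guide
        "local": "my_module.sources.LocalSource",
    }

    @classmethod
    def list_sources(cls):
        return sorted(cls._sources)

    @classmethod
    def get_source_target(cls, name: str) -> str:
        try:
            return cls._sources[name]
        except KeyError:
            # Listing known names makes configuration typos easy to spot
            raise KeyError(f"Unknown source {name!r}; known: {sorted(cls._sources)}") from None


print(MiniCatalog.list_sources())  # ['local']
print(MiniCatalog.get_source_target("local"))  # my_module.sources.LocalSource
```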
********************
Configuration YAML
********************

Example pipeline configuration:

.. code:: yaml

   # config/training/checkpoint_pipeline.yaml
   training:
     checkpoint_pipeline:
       stages:
         # Each stage uses the Hydra _target_ pattern
         - _target_: my_module.sources.LocalSource
           path: /path/to/checkpoint.ckpt

         - _target_: my_module.loaders.WeightsOnlyLoader
           strict: false

       async_execution: true
       continue_on_error: false

**Execution Patterns:**

The pipeline supports two execution approaches:

#. **Standalone (recommended)**: Execute during model initialization

   .. code:: python

      pipeline = CheckpointPipeline.from_config(config)
      context = CheckpointContext(model=model)
      result = await pipeline.execute(context)
      model = result.model

#. **Lightning callback**: Integrate with the PyTorch Lightning
   lifecycle for coordinated checkpoint operations.
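As a rough sketch of the callback approach: a real integration would subclass ``lightning.pytorch.Callback`` and receive the ``trainer`` and ``pl_module`` from Lightning, but everything below is simulated with stand-ins (including ``FakePipeline`` and the ``CheckpointLoadCallback`` name) so the control flow can run without Lightning installed:

```python
import asyncio


class FakePipeline:
    """Stand-in for CheckpointPipeline: one async step that marks the load done."""

    async def execute(self, context):
        context["loaded"] = True
        return context


class CheckpointLoadCallback:  # a real version would inherit lightning.pytorch.Callback
    def __init__(self, pipeline):
        self.pipeline = pipeline

    def on_fit_start(self, trainer, pl_module):
        # Lightning hooks are synchronous, so drive the async pipeline
        # to completion inside the hook
        context = {"model": pl_module, "loaded": False}
        result = asyncio.run(self.pipeline.execute(context))
        print("checkpoint loaded:", result["loaded"])
        return result  # returned here only so the sketch is easy to inspect


cb = CheckpointLoadCallback(FakePipeline())
result = cb.on_fit_start(trainer=None, pl_module=object())
```

The design point is the same as the standalone pattern: the pipeline mutates and returns the context, and the callback decides at which lifecycle hook that happens.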
************
Next Steps
************

This infrastructure enables subsequent phases:

- **Phase 2**: Loading strategies (weights-only, transfer learning,
  warm/cold start)
- **Phase 3**: Integration with model modifiers and legacy migration

See :ref:`checkpoint_pipeline_configuration` for configuration details
and :ref:`checkpoint_troubleshooting` for common issues.