Merged
32 commits
a6c84e0
feat(checkpoint): add core checkpoint pipeline infrastructure
JesperDramsch Aug 21, 2025
ba49884
feat(checkpoint): add pipeline orchestrator for stage execution
JesperDramsch Aug 21, 2025
9dc5faa
feat(checkpoint): add component catalog with dynamic discovery
JesperDramsch Aug 21, 2025
fbf4c4a
feat(checkpoint): add utilities and comprehensive tests
JesperDramsch Aug 21, 2025
3d319c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2025
4d62b45
feat(checkpoint): add multi-format checkpoint support
JesperDramsch Aug 21, 2025
9879a03
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2025
24885a5
feat(checkpoint): add core pipeline base classes
JesperDramsch Sep 18, 2025
5a669df
feat(checkpoint): implement CheckpointContext with validation and Pip…
JesperDramsch Sep 19, 2025
8efa937
refactor(checkpoint): reduce complexity in ComponentCatalog error han…
JesperDramsch Sep 19, 2025
6ba4e87
feat(checkpoint): implement CheckpointPipeline orchestrator with smar…
JesperDramsch Sep 19, 2025
2170e76
fix(tests): resolve pytest hanging by fixing Hydra config paths
JesperDramsch Oct 16, 2025
75b0eb9
test(checkpoint): update catalog tests to expect CheckpointConfigError
JesperDramsch Oct 16, 2025
00cfd2c
chore(checkpoint): clean up module exports and apply formatting
JesperDramsch Oct 16, 2025
ec3eaaa
refactor(checkpoint): remove safetensors support from codebase
JesperDramsch Oct 29, 2025
e1e9566
docs: add logging statement for device on checkpoint load
JesperDramsch Nov 19, 2025
bba3547
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 26, 2025
55a93c1
fix: resolve ruff linting errors in checkpoint pipeline
JesperDramsch Nov 26, 2025
8ec1542
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 26, 2025
c336948
refactor(checkpoint): address PR #501 reviewer feedback
JesperDramsch Dec 5, 2025
0110f03
docs: add checkpoint pipeline documentation for readthedocs
JesperDramsch Dec 5, 2025
a830551
fix(checkpoint): resolve async test failures and validation errors
JesperDramsch Dec 23, 2025
ad1406d
fix(tests): resolve CI test failures in checkpoint tests
JesperDramsch Dec 23, 2025
2d7ebb0
fix(lint): resolve ruff errors in checkpoint tests and source files
JesperDramsch Dec 23, 2025
2a0d002
fix: type_checking
JesperDramsch Dec 23, 2025
de6b2b3
fix: lazy import
JesperDramsch Dec 23, 2025
1e0f1a2
fix: handle asyncio.TimeoutError for Python 3.10 compatibility
JesperDramsch Jan 29, 2026
3ce3f19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2026
ddbd706
fix: restore runtime import for Pydantic schema forward reference
JesperDramsch Feb 5, 2026
5241cda
Merge branch 'main' into feature/checkpoint-pipeline-infrastructure
JesperDramsch Feb 12, 2026
dbc56ad
Merge branch 'main' into feature/checkpoint-pipeline-infrastructure
JesperDramsch Feb 17, 2026
44fee31
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2026
256 changes: 256 additions & 0 deletions training/docs/checkpoint_integration.rst
@@ -0,0 +1,256 @@
.. _checkpoint_integration:

#################################
Checkpoint Pipeline Integration
#################################

This guide covers the checkpoint pipeline infrastructure for Anemoi
training. The pipeline provides a foundation for building checkpoint
loading workflows.

.. note::

   This documents Phase 1 (Pipeline Infrastructure). Sources, loaders,
   and modifiers are implemented in subsequent phases.

**************
Core Classes
**************

CheckpointContext
=================

The ``CheckpointContext`` carries state through pipeline stages:

.. code:: python

   from anemoi.training.checkpoint import CheckpointContext

   # Create context with a model
   context = CheckpointContext(
       model=my_model,
       config=my_config,  # Optional OmegaConf config
   )

   # Access and update metadata
   context.update_metadata(source="local", loaded=True)
   print(context.metadata)

**Attributes:**

- ``model``: PyTorch model
- ``optimizer``: Optional optimizer
- ``scheduler``: Optional learning rate scheduler
- ``checkpoint_path``: Path to checkpoint file
- ``checkpoint_data``: Loaded checkpoint dictionary
- ``metadata``: Dictionary for tracking state
- ``config``: Optional Hydra configuration
- ``checkpoint_format``: Detected format (``lightning``, ``pytorch``, or
  ``state_dict``)

PipelineStage
=============

Base class for implementing pipeline stages:

.. code:: python

   from anemoi.training.checkpoint import PipelineStage, CheckpointContext


   class MyCustomStage(PipelineStage):
       def __init__(self, param: str):
           self.param = param

       async def process(self, context: CheckpointContext) -> CheckpointContext:
           # Implement your logic here
           context.update_metadata(custom_param=self.param)
           return context
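
The ``async`` signature matters: the pipeline awaits each stage's
``process``. A self-contained sketch, with a hypothetical stand-in base
class since this example deliberately does not import Anemoi, shows how
such a stage runs under ``asyncio``:

```python
import asyncio
from abc import ABC, abstractmethod


class StageBase(ABC):
    """Stand-in for PipelineStage: one async processing step."""

    @abstractmethod
    async def process(self, context: dict) -> dict: ...


class TagStage(StageBase):
    """Records its parameter in the shared context, like the doc example."""

    def __init__(self, param: str):
        self.param = param

    async def process(self, context: dict) -> dict:
        context["custom_param"] = self.param
        return context


# Outside an event loop, asyncio.run drives the coroutine to completion
result = asyncio.run(TagStage("value").process({}))
print(result)  # {'custom_param': 'value'}
```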

CheckpointPipeline
==================

Orchestrates execution of multiple stages:

.. code:: python

   from anemoi.training.checkpoint import CheckpointPipeline, CheckpointContext

   # Build pipeline with stages
   pipeline = CheckpointPipeline(
       stages=[stage1, stage2, stage3],
       async_execution=True,
       continue_on_error=False,
   )

   # Execute
   context = CheckpointContext(model=my_model)
   result = await pipeline.execute(context)

**From Hydra configuration:**

.. code:: python

   from omegaconf import OmegaConf

   config = OmegaConf.create({
       "stages": [
           {"_target_": "my_module.MyStage", "param": "value"},
       ],
       "async_execution": True,
   })

   pipeline = CheckpointPipeline.from_config(config)

****************
Error Handling
****************

The checkpoint module provides a hierarchy of exceptions:

.. code:: python

   from anemoi.training.checkpoint import (
       CheckpointError,  # Base exception
       CheckpointNotFoundError,  # File not found
       CheckpointLoadError,  # Loading failed
       CheckpointValidationError,  # Validation failed
       CheckpointSourceError,  # Source fetch failed
       CheckpointTimeoutError,  # Operation timed out
       CheckpointConfigError,  # Configuration error
       CheckpointIncompatibleError,  # Model/checkpoint mismatch
   )

   try:
       result = await pipeline.execute(context)
   except CheckpointNotFoundError as e:
       print(f"Checkpoint not found: {e.path}")
   except CheckpointLoadError as e:
       print(f"Failed to load: {e}")
   except CheckpointError as e:
       print(f"Checkpoint error: {e}")
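
Because the exceptions form a hierarchy, a single
``except CheckpointError`` handler catches every checkpoint failure. A
minimal sketch of the pattern (simplified stand-ins, not the actual
class definitions) also shows why ``e.path`` is available in the
``CheckpointNotFoundError`` handler above:

```python
class CheckpointError(Exception):
    """Base class: catching this catches every checkpoint failure."""


class CheckpointNotFoundError(CheckpointError):
    def __init__(self, path: str):
        super().__init__(f"Checkpoint not found: {path}")
        self.path = path  # exposed so handlers can report the location


def load(path: str) -> dict:
    # Hypothetical loader that always fails, for demonstration
    raise CheckpointNotFoundError(path)


try:
    load("/missing.ckpt")
except CheckpointError as e:
    # The base class catches the subclass raised above
    print(type(e).__name__, e.path)
```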

*******************
Utility Functions
*******************

Format Detection
================

.. code:: python

   from anemoi.training.checkpoint.formats import (
       detect_checkpoint_format,
       load_checkpoint,
       extract_state_dict,
   )

   # Auto-detect format
   fmt = detect_checkpoint_format("/path/to/checkpoint.ckpt")
   # Returns: "lightning", "pytorch", or "state_dict"

   # Load checkpoint
   data = load_checkpoint("/path/to/checkpoint.ckpt")

   # Extract state dict from various formats
   state_dict = extract_state_dict(data)
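
One plausible way such detection can work is to inspect the top-level
keys of the loaded dictionary. The heuristic below is a sketch for
intuition only, not the library's actual logic (Lightning checkpoints do
embed a ``pytorch-lightning_version`` key; the rest is assumption):

```python
def guess_format(checkpoint: dict) -> str:
    """Key-based format heuristic (illustrative, not the real implementation)."""
    if "pytorch-lightning_version" in checkpoint:
        return "lightning"  # Lightning checkpoints record their version
    if "state_dict" in checkpoint or "model_state_dict" in checkpoint:
        return "pytorch"  # a wrapper dict around the weights
    return "state_dict"  # assume a bare parameter mapping


print(guess_format({"pytorch-lightning_version": "2.0"}))  # lightning
print(guess_format({"state_dict": {}, "epoch": 3}))        # pytorch
print(guess_format({"layer.weight": [0.0]}))               # state_dict
```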

Checkpoint Utilities
====================

.. code:: python

   from pathlib import Path

   from anemoi.training.checkpoint import (
       get_checkpoint_metadata,
       validate_checkpoint,
       calculate_checksum,
       compare_state_dicts,
       estimate_checkpoint_memory,
       format_size,
   )

   # Get metadata without loading the full checkpoint
   metadata = get_checkpoint_metadata(Path("model.ckpt"))

   # Validate checkpoint structure
   validate_checkpoint(checkpoint_data)

   # Calculate file checksum
   checksum = calculate_checksum(Path("model.ckpt"), algorithm="sha256")

   # Compare state dictionaries
   missing, unexpected, mismatched = compare_state_dicts(source_dict, target_dict)

   # Estimate memory usage
   bytes_needed = estimate_checkpoint_memory(checkpoint_data)
   print(format_size(bytes_needed))  # e.g., "1.5 GB"
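
To illustrate what a function like ``format_size`` does, here is a
minimal human-readable size formatter. It is a sketch of the general
technique, not the library's implementation:

```python
def human_size(num_bytes: float) -> str:
    """Render a byte count as a short human-readable string."""
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if num_bytes < 1024 or unit == "TB":
            return f"{num_bytes:.1f} {unit}"
        num_bytes /= 1024  # move to the next-larger unit


print(human_size(1_610_612_736))  # 1.5 GB
print(human_size(512))            # 512.0 B
```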

*********************
Component Discovery
*********************

The ``ComponentCatalog`` provides discovery of available pipeline
components:

.. code:: python

   from anemoi.training.checkpoint import ComponentCatalog

   # List available components
   print(ComponentCatalog.list_sources())  # Available source types
   print(ComponentCatalog.list_loaders())  # Available loading strategies
   print(ComponentCatalog.list_modifiers())  # Available model modifiers

   # Get Hydra target path for a component
   target = ComponentCatalog.get_source_target("local")
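
The catalog pattern boils down to a class-level registry mapping short
names to Hydra target paths. The sketch below uses hypothetical names
and targets to show the shape of the API, not the catalog's real
contents:

```python
class MiniCatalog:
    """Sketch of a component catalog: short name -> Hydra target path."""

    _sources = {
        "local": "my_module.sources.LocalSource",    # hypothetical targets
        "remote": "my_module.sources.RemoteSource",
    }

    @classmethod
    def list_sources(cls) -> list[str]:
        return sorted(cls._sources)

    @classmethod
    def get_source_target(cls, name: str) -> str:
        try:
            return cls._sources[name]
        except KeyError:
            raise ValueError(f"Unknown source: {name!r}") from None


print(MiniCatalog.list_sources())             # ['local', 'remote']
print(MiniCatalog.get_source_target("local"))
```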

********************
Configuration YAML
********************

Example pipeline configuration:

.. code:: yaml

   # config/training/checkpoint_pipeline.yaml
   training:
     checkpoint_pipeline:
       stages:
         # Each stage uses the Hydra _target_ pattern
         - _target_: my_module.sources.LocalSource
           path: /path/to/checkpoint.ckpt

         - _target_: my_module.loaders.WeightsOnlyLoader
           strict: false

       async_execution: true
       continue_on_error: false
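
Under the hood, Hydra's ``_target_`` convention is dotted-path
instantiation: import the module, look up the class, and call it with
the remaining keys as keyword arguments. The sketch below is a
simplification of what ``hydra.utils.instantiate`` does, demonstrated
with a stdlib class instead of a pipeline stage:

```python
from importlib import import_module


def instantiate(spec: dict):
    """Build an object from a {'_target_': 'pkg.mod.Class', ...} dict."""
    module_path, _, class_name = spec["_target_"].rpartition(".")
    cls = getattr(import_module(module_path), class_name)
    # Everything except _target_ becomes constructor keyword arguments
    kwargs = {k: v for k, v in spec.items() if k != "_target_"}
    return cls(**kwargs)


frac = instantiate(
    {"_target_": "fractions.Fraction", "numerator": 2, "denominator": 6}
)
print(frac)  # 1/3
```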

**Execution Patterns:**

The pipeline supports two execution approaches:

#. **Standalone (recommended)**: Execute during model initialization

   .. code:: python

      pipeline = CheckpointPipeline.from_config(config)
      context = CheckpointContext(model=model)
      result = await pipeline.execute(context)
      model = result.model

#. **Lightning callback**: Integrate with the PyTorch Lightning
   lifecycle for coordinated checkpoint operations.
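
The callback approach can be sketched without importing Lightning: a
callback object exposes a lifecycle hook (``on_fit_start`` below mirrors
Lightning's hook name) and bridges the async pipeline into the
synchronous hook via ``asyncio.run``. Everything here is a hypothetical
stand-in, not Anemoi's actual callback:

```python
import asyncio


class FakePipeline:
    """Stand-in pipeline: marks the context as loaded."""

    async def execute(self, context: dict) -> dict:
        context["loaded"] = True
        return context


class CheckpointCallback:
    """Sketch of a Lightning-style callback wrapping the pipeline."""

    def __init__(self, pipeline):
        self.pipeline = pipeline

    def on_fit_start(self, trainer, pl_module) -> None:
        # Lightning hooks are synchronous, so bridge into asyncio here
        context = {"model": pl_module}
        result = asyncio.run(self.pipeline.execute(context))
        trainer["checkpoint_loaded"] = result["loaded"]


trainer_state: dict = {}
CheckpointCallback(FakePipeline()).on_fit_start(trainer_state, pl_module="model")
print(trainer_state)  # {'checkpoint_loaded': True}
```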

************
Next Steps
************

This infrastructure enables subsequent phases:

- **Phase 2**: Loading strategies (weights-only, transfer learning,
  warm/cold start)
- **Phase 3**: Integration with model modifiers and legacy migration

See :ref:`checkpoint_pipeline_configuration` for configuration details
and :ref:`checkpoint_troubleshooting` for common issues.