
[WIP] 410 refactor of model initialisation ie weight loading model freezing transfer learning#442

Draft
JesperDramsch wants to merge 13 commits into main from
410-refactor-of-model-initialisation-ie-weight-loading-model-freezing-transfer-learning

Conversation

@JesperDramsch
Member

@JesperDramsch JesperDramsch commented Jul 29, 2025

Description

Implements instantiable model modifiers that can, for example, load weights or freeze components.
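To make the idea concrete, here is a minimal sketch of what such an instantiable modifier could look like. The class names (`ModelModifier`, `FreezeModifier`) and the `apply` method are illustrative assumptions, not the actual Anemoi API:

```python
# Hypothetical modifier interface; names are assumptions for illustration.
import torch.nn as nn


class ModelModifier:
    """Base class: a modifier transforms a model after instantiation."""

    def apply(self, model: nn.Module) -> nn.Module:
        raise NotImplementedError


class FreezeModifier(ModelModifier):
    """Freeze the parameters of the named submodules."""

    def __init__(self, submodule_names: list[str]) -> None:
        self.submodule_names = submodule_names

    def apply(self, model: nn.Module) -> nn.Module:
        for name in self.submodule_names:
            # get_submodule resolves a dotted path like "encoder.layer1".
            for param in model.get_submodule(name).parameters():
                param.requires_grad = False
        return model
```

Because each modifier is a self-contained object, a list of them can be instantiated from config and applied in sequence, replacing the nested if statements.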

What problem does this change solve?

It implements a modular, extensible system for model initialisation that replaces a stack of nested if statements and enables extension in anticipation of #248.

What issue or task does this change relate to?

Closes #410
Prepares changes for #248

As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/

By opening this pull request, I affirm that all authors agree to the Contributor License Agreement.


📚 Documentation preview 📚: https://anemoi-training--442.org.readthedocs.build/en/442/


📚 Documentation preview 📚: https://anemoi-graphs--442.org.readthedocs.build/en/442/


📚 Documentation preview 📚: https://anemoi-models--442.org.readthedocs.build/en/442/

@JesperDramsch JesperDramsch added this to the Fine-Tuning milestone Jul 29, 2025
@JesperDramsch JesperDramsch self-assigned this Jul 29, 2025
@github-project-automation github-project-automation bot moved this to Now In Progress in Anemoi-dev Jul 29, 2025
@github-actions github-actions bot added training enhancement New feature or request labels Jul 29, 2025
@JesperDramsch
Member Author

I just saw that in the meantime, there was this change, which I will have to address in a future commit:

        model.data_indices = self.data_indices
        # check data indices in original checkpoint and current data indices are the same
        self.data_indices.compare_variables(model._ckpt_model_name_to_index, self.data_indices.name_to_index)

@mchantry mchantry added the ATS Approval Needed Approval needed by ATS label Jul 30, 2025
@JPXKQX
Member

JPXKQX commented Jul 31, 2025

Thanks Jesper, I think this refactoring makes a lot of sense. Would it make sense to have another modifier "ResumeRun..." to which we could bring all the logic from run_id, fork_run_id, load_only_weights and warm_start?

@JesperDramsch
Member Author

> Thanks Jesper, I think this refactoring makes a lot of sense. Would it make sense to have another modifier "ResumeRun..." to which we could bring all the logic from run_id, fork_run_id, load_only_weights and warm_start?

Possibly. I believe my original design around fork_run_id ended up confusing most people, so we could take a look at whether this design could be used to fix that.

Add extensible checkpoint loading system that separates checkpoint
source handling from model weight loading strategies.

Changes:
- Add CheckpointLoaderRegistry for extensible source loading
  (local, S3, HTTP, GCS, Azure support)
- Add ModelLoaderRegistry for weight loading strategies
  (standard, weights_only, transfer_learning)
- Implement registry pattern for future extensibility

This infrastructure enables:
- Remote checkpoint loading from cloud storage
- Modular loading strategies
- Clean separation of concerns
- Foundation for advanced features (quantization, PEFT)

Depends on: #458 branch infrastructure
Related: #422, #410
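The registry pattern described above can be sketched in a few lines. The class name `CheckpointLoaderRegistry` comes from the commit message, but the decorator-based registration API and the scheme keys shown here are assumptions about how such a registry might work, not the actual implementation:

```python
# Illustrative registry sketch; registration API and scheme keys are assumed.
from typing import Callable, Dict


class CheckpointLoaderRegistry:
    """Map a checkpoint source scheme (local, s3, http, ...) to a loader."""

    _loaders: Dict[str, Callable[[str], str]] = {}

    @classmethod
    def register(cls, scheme: str) -> Callable:
        def decorator(fn: Callable[[str], str]) -> Callable[[str], str]:
            cls._loaders[scheme] = fn
            return fn
        return decorator

    @classmethod
    def load(cls, uri: str) -> str:
        # Plain paths without a scheme fall back to the local loader.
        scheme = uri.split("://", 1)[0] if "://" in uri else "local"
        if scheme not in cls._loaders:
            raise KeyError(f"No checkpoint loader registered for '{scheme}'")
        return cls._loaders[scheme](uri)


@CheckpointLoaderRegistry.register("local")
def load_local(uri: str) -> str:
    # The real loader would return a local path or open file handle.
    return uri


@CheckpointLoaderRegistry.register("s3")
def load_s3(uri: str) -> str:
    # Placeholder: a real implementation would download to a temp file.
    return "/tmp/downloaded-" + uri.rsplit("/", 1)[-1]
```

New backends (GCS, Azure, ...) then only need to register a loader function; nothing in the training pipeline changes.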

Replace WeightsInitModelModifier with new checkpoint loading architecture
and integrate checkpoint loading into training pipeline.

Changes:
- Remove WeightsInitModelModifier (functionality moved to checkpoint_loading)
- Update TransferLearningModelModifier to use new model_loading system
- Add configurable strict and skip_mismatched parameters
- Integrate checkpoint loading in training pipeline before model modifiers
- Add _load_checkpoint_if_configured method to trainer

Benefits:
- Clean separation: checkpoint loading vs model transformation
- Better parameter control for transfer learning
- DRY principle: single checkpoint loading implementation
- Extensible: prepare for quantization, PEFT features

Related: #422, #410
Depends: checkpoint loading infrastructure
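The `strict` and `skip_mismatched` parameters mentioned above can be understood with a small sketch. The function name is hypothetical; only the parameter semantics (tolerate missing keys, drop shape mismatches for transfer learning) follow the commit message:

```python
# Hedged sketch of a configurable weight-loading strategy; not the real API.
import torch
import torch.nn as nn


def load_model_weights(
    model: nn.Module,
    ckpt_state: dict,
    strict: bool = True,
    skip_mismatched: bool = False,
) -> nn.Module:
    """Load a checkpoint state dict, optionally dropping shape mismatches."""
    if skip_mismatched:
        own = model.state_dict()
        ckpt_state = {
            k: v for k, v in ckpt_state.items()
            if k in own and own[k].shape == v.shape
        }
        strict = False  # dropped keys would otherwise fail strict loading
    model.load_state_dict(ckpt_state, strict=strict)
    return model
```

For transfer learning this lets a checkpoint trained with, say, a different input width still initialise every compatible tensor.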

Introduce new configuration schema and templates for the checkpoint
loading system to replace legacy WeightsInitModelModifier configs.

Changes:
- Add checkpoint_loading field to training schema with Pydantic validation
- Create checkpoint_loading config directory with templates:
  * weights_only.yml - Load only model weights
  * transfer_learning.yml - Load with size mismatch handling
  * standard.yml - Full Lightning checkpoint loading
- Update transfer_learning.yml with new parameters (strict, skip_mismatched)
- Add enhanced_fine_tuning.yml example combining transfer learning + freezing

Benefits:
- Schema validation for checkpoint loading configurations
- Pre-built templates for common use cases
- Flexible parameter configuration per loader type
- Clear separation from model modifier configs

Related: #422, #410
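The kind of validation the schema provides can be sketched as follows. The real schema uses Pydantic; this dataclass stand-in, the field names, and the loader-name set are assumptions chosen to mirror the templates listed above:

```python
# Dataclass stand-in for the Pydantic schema; field names are assumed.
from dataclasses import dataclass

VALID_LOADERS = {"standard", "weights_only", "transfer_learning"}


@dataclass
class CheckpointLoadingConfig:
    path: str
    loader: str = "standard"
    strict: bool = True
    skip_mismatched: bool = False

    def __post_init__(self) -> None:
        if self.loader not in VALID_LOADERS:
            raise ValueError(
                f"unknown loader '{self.loader}', "
                f"expected one of {sorted(VALID_LOADERS)}"
            )
        if self.loader == "transfer_learning":
            # Transfer learning typically tolerates missing/extra keys.
            self.strict = False
```

Validation at config-parse time surfaces a bad loader name immediately instead of failing deep inside checkpoint loading.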

Refactor existing tests for new architecture and add comprehensive
test coverage for the checkpoint loading system.

Changes:
- Remove WeightsInitModelModifier tests (functionality moved)
- Update TransferLearningModelModifier tests for new #458 integration
- Add checkpoint loading integration tests for training pipeline
- Add comprehensive test suites from #458 branch:
  * test_checkpoint_loaders.py - Source loading tests (S3, HTTP, etc.)
  * test_model_loading.py - Weight loading strategy tests
- Update integration tests for new model modifier workflow
- Add configuration validation and error handling tests

Coverage:
- All loader types (weights_only, transfer_learning, standard)
- Remote checkpoint sources (S3, HTTP, GCS, Azure)
- Training pipeline integration
- Configuration validation and error scenarios
- Model modifier compatibility with new system

Related: #422, #410

- Update documentation with comprehensive checkpoint loading sections
- Fix compatibility issues in forecaster and checkpoint utilities
- Update config templates to maintain backward compatibility
- Fix pre-commit hook issues (ruff RET504, docsig parameter mismatches)
- Add gradient validation during freezing to ensure parameters are truly frozen
- Implement optimized module lookup using PyTorch's get_submodule()
- Improve error messages with clear context for debugging
- Add comprehensive module-level documentation
- Enhance test coverage with improved documentation
- Add noqa comments for necessary complex methods

The gradient validation ensures frozen parameters don't accumulate gradients during
training, providing runtime verification of the freezing mechanism. The optimized
lookup improves performance for deep models by using O(1) access when possible.
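A minimal sketch of that gradient validation, assuming the submodule path `"0"` and a toy model (the real check lives inside the freezing mechanism and will differ in detail):

```python
# Illustrative gradient-validation check; not the actual Anemoi code.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 1))

# Freeze the first submodule, resolved via nn.Module.get_submodule().
for p in model.get_submodule("0").parameters():
    p.requires_grad = False

model(torch.randn(2, 3)).sum().backward()

# Runtime verification: frozen parameters must not accumulate gradients.
for name, p in model.named_parameters():
    is_frozen = name.startswith("0.")
    assert (p.grad is None) == is_frozen, f"{name}: unexpected gradient state"
```

The same assertion run during training catches modifiers that silently fail to freeze a submodule.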

Labels

ATS Approval Needed Approval needed by ATS enhancement New feature or request training

Projects

Status: Reviewers needed

Development

Successfully merging this pull request may close these issues.

Model Transformation Layer - Post-loading modifications (freezing, transfer learning, adapters)

3 participants