[WIP] 410: Refactor of model initialisation, i.e. weight loading, model freezing, transfer learning #442
Draft
JesperDramsch wants to merge 13 commits into main from
Conversation
Member
Author
I just saw that, in the meantime, there was this change, which I will have to address in a future commit:
Member
Thanks Jesper, I think this refactoring makes a lot of sense. Would it make sense to have another modifier "ResumeRun..." to which we could bring all the logic from
Member
Author
Possibly. I believe my original design around
Add an extensible checkpoint loading system that separates checkpoint source handling from model weight loading strategies.

Changes:
- Add CheckpointLoaderRegistry for extensible source loading (local, S3, HTTP, GCS, Azure support)
- Add ModelLoaderRegistry for weight loading strategies (standard, weights_only, transfer_learning)
- Implement registry pattern for future extensibility

This infrastructure enables:
- Remote checkpoint loading from cloud storage
- Modular loading strategies
- Clean separation of concerns
- Foundation for advanced features (quantization, PEFT)

Depends on: #458 branch infrastructure
Related: #422, #410
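The registry pattern described above might look roughly like the following sketch. All class and method names here (`CheckpointLoaderRegistry`, `register`, `create`, the loader classes) are illustrative assumptions, not the actual anemoi-training API:

```python
class CheckpointLoaderRegistry:
    """Maps a checkpoint source scheme (e.g. "s3", "http") to a loader class."""

    def __init__(self):
        self._loaders = {}

    def register(self, scheme):
        """Class decorator that registers a loader for a URI scheme."""
        def decorator(cls):
            self._loaders[scheme] = cls
            return cls
        return decorator

    def create(self, path):
        """Pick a loader based on the URI scheme; bare paths are local."""
        scheme = path.split("://", 1)[0] if "://" in path else "local"
        if scheme not in self._loaders:
            raise ValueError(f"No checkpoint loader registered for scheme {scheme!r}")
        return self._loaders[scheme](path)


registry = CheckpointLoaderRegistry()


@registry.register("local")
class LocalCheckpointLoader:
    """Loads a checkpoint from the local filesystem."""
    def __init__(self, path):
        self.path = path


@registry.register("s3")
class S3CheckpointLoader:
    """Would download the object from S3 before loading, in a real implementation."""
    def __init__(self, path):
        self.path = path
```

New sources (GCS, Azure, ...) then only need a decorated class, which is what makes the design extensible.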
Replace WeightsInitModelModifier with the new checkpoint loading architecture and integrate checkpoint loading into the training pipeline.

Changes:
- Remove WeightsInitModelModifier (functionality moved to checkpoint_loading)
- Update TransferLearningModelModifier to use the new model_loading system
- Add configurable strict and skip_mismatched parameters
- Integrate checkpoint loading in the training pipeline before model modifiers
- Add _load_checkpoint_if_configured method to the trainer

Benefits:
- Clean separation: checkpoint loading vs model transformation
- Better parameter control for transfer learning
- DRY principle: single checkpoint loading implementation
- Extensible: prepares for quantization and PEFT features

Related: #422, #410
Depends: checkpoint loading infrastructure
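A trainer hook like the one named above could be sketched as follows. Only the method name `_load_checkpoint_if_configured` comes from the commit message; the surrounding class, config shape, and loading logic are illustrative assumptions:

```python
import torch
import torch.nn as nn


class Trainer:
    """Minimal stand-in for the training pipeline; not the real trainer class."""

    def __init__(self, config: dict, model: nn.Module):
        self.config = config
        self.model = model

    def _load_checkpoint_if_configured(self) -> bool:
        """Load weights before model modifiers run; return True if a load happened."""
        cfg = self.config.get("checkpoint_loading")
        if cfg is None:
            return False  # nothing configured, keep random initialisation
        state_dict = torch.load(cfg["path"], map_location="cpu")
        self.model.load_state_dict(state_dict, strict=cfg.get("strict", True))
        return True
```

Running the hook before the model modifiers is what gives the clean separation the commit describes: loading produces the weights, and modifiers (freezing, transfer learning) only transform the already-loaded model.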
Introduce a new configuration schema and templates for the checkpoint loading system to replace legacy WeightsInitModelModifier configs.

Changes:
- Add checkpoint_loading field to the training schema with Pydantic validation
- Create checkpoint_loading config directory with templates:
  * weights_only.yml - load only model weights
  * transfer_learning.yml - load with size-mismatch handling
  * standard.yml - full Lightning checkpoint loading
- Update transfer_learning.yml with new parameters (strict, skip_mismatched)
- Add enhanced_fine_tuning.yml example combining transfer learning and freezing

Benefits:
- Schema validation for checkpoint loading configurations
- Pre-built templates for common use cases
- Flexible parameter configuration per loader type
- Clear separation from model modifier configs

Related: #422, #410
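A transfer-learning template along these lines might look as follows. The field names echo the parameters listed above (strict, skip_mismatched), but the exact keys and layout are assumptions, not the PR's actual schema:

```yaml
# Hypothetical transfer_learning.yml template (illustrative, not the real file)
checkpoint_loading:
  loader: transfer_learning        # which ModelLoaderRegistry strategy to use
  path: s3://my-bucket/checkpoints/pretrained.ckpt
  strict: false                    # tolerate missing/unexpected keys
  skip_mismatched: true            # skip weights whose shapes differ from the model
```

Pydantic validation of the checkpoint_loading field would then reject unknown loader names or malformed parameters at config-load time rather than mid-training.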
Refactor existing tests for the new architecture and add comprehensive test coverage for the checkpoint loading system.

Changes:
- Remove WeightsInitModelModifier tests (functionality moved)
- Update TransferLearningModelModifier tests for the new #458 integration
- Add checkpoint loading integration tests for the training pipeline
- Add comprehensive test suites from the #458 branch:
  * test_checkpoint_loaders.py - source loading tests (S3, HTTP, etc.)
  * test_model_loading.py - weight loading strategy tests
- Update integration tests for the new model modifier workflow
- Add configuration validation and error handling tests

Coverage:
- All loader types (weights_only, transfer_learning, standard)
- Remote checkpoint sources (S3, HTTP, GCS, Azure)
- Training pipeline integration
- Configuration validation and error scenarios
- Model modifier compatibility with the new system

Related: #422, #410
- Update documentation with comprehensive checkpoint loading sections
- Fix compatibility issues in forecaster and checkpoint utilities
- Update config templates to maintain backward compatibility
- Fix pre-commit hook issues (ruff RET504, docsig parameter mismatches)
- Add gradient validation during freezing to ensure parameters are truly frozen
- Implement optimized module lookup using PyTorch's get_submodule()
- Improve error messages with clear context for debugging
- Add comprehensive module-level documentation
- Enhance test coverage with improved documentation
- Add noqa comments for necessary complex methods

The gradient validation ensures frozen parameters don't accumulate gradients during training, providing runtime verification of the freezing mechanism. The optimized lookup improves performance for deep models by using O(1) access when possible.
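The freezing-plus-validation idea above can be sketched with `torch.nn.Module.get_submodule()`, which resolves a dotted module path directly instead of iterating over all named modules. The function name and error message here are illustrative, not the PR's actual code:

```python
import torch
import torch.nn as nn


def freeze_submodule(model: nn.Module, name: str) -> None:
    """Freeze all parameters of the submodule at dotted path `name`."""
    try:
        submodule = model.get_submodule(name)  # direct lookup by dotted path
    except AttributeError as err:
        # Clear error context instead of PyTorch's bare AttributeError
        raise ValueError(f"Cannot freeze {name!r}: no such submodule") from err
    for param in submodule.parameters():
        param.requires_grad = False


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
freeze_submodule(model, "0")  # freeze the first linear layer

# Runtime verification: after a backward pass, frozen parameters must have
# accumulated no gradient, while trainable ones must have one.
model(torch.randn(2, 4)).sum().backward()
assert model[0].weight.grad is None
assert model[2].weight.grad is not None
```

The two assertions at the end are the essence of the gradient validation: `requires_grad = False` alone is a promise, while checking `.grad` after a backward pass verifies the promise actually held during training.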
Description
Implements instantiatable model modifiers that can, e.g., load weights or freeze components.
What problem does this change solve?
It implements a modular and extensible system for the initialisation of models that gets rid of a stack of nested if statements, and enables extension in anticipation of #248.

What issue or task does this change relate to?
Closes #410
Prepares changes for #248
As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/
By opening this pull request, I affirm that all authors agree to the Contributor License Agreement.
📚 Documentation preview 📚: https://anemoi-training--442.org.readthedocs.build/en/442/
📚 Documentation preview 📚: https://anemoi-graphs--442.org.readthedocs.build/en/442/
📚 Documentation preview 📚: https://anemoi-models--442.org.readthedocs.build/en/442/