-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[Modular] Save Modular Pipeline weights to Hub #13168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |||||
| import importlib | ||||||
| import inspect | ||||||
| import os | ||||||
| import sys | ||||||
| import traceback | ||||||
| import warnings | ||||||
| from collections import OrderedDict | ||||||
|
|
@@ -28,10 +29,16 @@ | |||||
| from typing_extensions import Self | ||||||
|
|
||||||
| from ..configuration_utils import ConfigMixin, FrozenDict | ||||||
| from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj | ||||||
| from ..pipelines.pipeline_loading_utils import ( | ||||||
| LOADABLE_CLASSES, | ||||||
| _fetch_class_library_tuple, | ||||||
| _unwrap_model, | ||||||
| simple_get_class_obj, | ||||||
| ) | ||||||
| from ..utils import PushToHubMixin, is_accelerate_available, logging | ||||||
| from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code | ||||||
| from ..utils.hub_utils import load_or_create_model_card, populate_model_card | ||||||
| from ..utils.torch_utils import is_compiled_module | ||||||
| from .components_manager import ComponentsManager | ||||||
| from .modular_pipeline_utils import ( | ||||||
| MODULAR_MODEL_CARD_TEMPLATE, | ||||||
|
|
@@ -1819,29 +1826,111 @@ def from_pretrained( | |||||
| ) | ||||||
| return pipeline | ||||||
|
|
||||||
| def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs): | ||||||
| def save_pretrained( | ||||||
| self, | ||||||
| save_directory: str | os.PathLike, | ||||||
| safe_serialization: bool = True, | ||||||
| variant: str | None = None, | ||||||
| max_shard_size: int | str | None = None, | ||||||
| push_to_hub: bool = False, | ||||||
| **kwargs, | ||||||
| ): | ||||||
| """ | ||||||
| Save the pipeline to a directory. It does not save components, you need to save them separately. | ||||||
| Save the pipeline and all its components to a directory, so that it can be re-loaded using the | ||||||
| [`~ModularPipeline.from_pretrained`] class method. | ||||||
|
|
||||||
| Args: | ||||||
| save_directory (`str` or `os.PathLike`): | ||||||
| Path to the directory where the pipeline will be saved. | ||||||
| push_to_hub (`bool`, optional): | ||||||
| Whether to push the pipeline to the huggingface hub. | ||||||
| **kwargs: Additional arguments passed to `save_config()` method | ||||||
| """ | ||||||
| Directory to save the pipeline to. Will be created if it doesn't exist. | ||||||
| safe_serialization (`bool`, *optional*, defaults to `True`): | ||||||
| Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. | ||||||
| variant (`str`, *optional*): | ||||||
| If specified, weights are saved in the format `pytorch_model.<variant>.bin`. | ||||||
| max_shard_size (`int` or `str`, defaults to `None`): | ||||||
| The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size | ||||||
| lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`). | ||||||
| If expressed as an integer, the unit is bytes. | ||||||
| push_to_hub (`bool`, *optional*, defaults to `False`): | ||||||
| Whether to push the pipeline to the Hugging Face model hub after saving it. | ||||||
| **kwargs: Additional keyword arguments passed along to the push to hub method. | ||||||
| """ | ||||||
| overwrite_modular_index = kwargs.pop("overwrite_modular_index", False) | ||||||
| repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) | ||||||
|
|
||||||
| for component_name, component_spec in self._component_specs.items(): | ||||||
| sub_model = getattr(self, component_name, None) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (nit):
Suggested change
Not all components need to models. |
||||||
| if sub_model is None: | ||||||
| continue | ||||||
|
|
||||||
| model_cls = sub_model.__class__ | ||||||
| if is_compiled_module(sub_model): | ||||||
| sub_model = _unwrap_model(sub_model) | ||||||
| model_cls = sub_model.__class__ | ||||||
|
|
||||||
| save_method_name = None | ||||||
| for library_name, library_classes in LOADABLE_CLASSES.items(): | ||||||
| if library_name in sys.modules: | ||||||
| library = importlib.import_module(library_name) | ||||||
| else: | ||||||
| logger.info( | ||||||
| f"{library_name} is not installed. Cannot save {component_name} as {library_classes} from {library_name}" | ||||||
| ) | ||||||
| continue | ||||||
|
|
||||||
| for base_class, save_load_methods in library_classes.items(): | ||||||
| class_candidate = getattr(library, base_class, None) | ||||||
| if class_candidate is not None and issubclass(model_cls, class_candidate): | ||||||
| save_method_name = save_load_methods[0] | ||||||
| break | ||||||
| if save_method_name is not None: | ||||||
| break | ||||||
|
|
||||||
| if save_method_name is None: | ||||||
| logger.warning(f"self.{component_name}={sub_model} of type {type(sub_model)} cannot be saved.") | ||||||
| continue | ||||||
|
|
||||||
| save_method = getattr(sub_model, save_method_name) | ||||||
| save_method_signature = inspect.signature(save_method) | ||||||
| save_method_accept_safe = "safe_serialization" in save_method_signature.parameters | ||||||
| save_method_accept_variant = "variant" in save_method_signature.parameters | ||||||
| save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters | ||||||
|
|
||||||
| save_kwargs = {} | ||||||
| if save_method_accept_safe: | ||||||
| save_kwargs["safe_serialization"] = safe_serialization | ||||||
| if save_method_accept_variant: | ||||||
| save_kwargs["variant"] = variant | ||||||
| if save_method_accept_max_shard_size and max_shard_size is not None: | ||||||
| save_kwargs["max_shard_size"] = max_shard_size | ||||||
|
|
||||||
| save_method(os.path.join(save_directory, component_name), **save_kwargs) | ||||||
|
|
||||||
| if push_to_hub: | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would prefer to fully push the |
||||||
| commit_message = kwargs.pop("commit_message", None) | ||||||
| private = kwargs.pop("private", None) | ||||||
| create_pr = kwargs.pop("create_pr", False) | ||||||
| token = kwargs.pop("token", None) | ||||||
| repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where is this going? |
||||||
| repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id | ||||||
|
|
||||||
| # Generate modular pipeline card content | ||||||
| card_content = generate_modular_model_card_content(self.blocks) | ||||||
| if overwrite_modular_index: | ||||||
| for component_name, component_spec in self._component_specs.items(): | ||||||
| if component_spec.default_creation_method != "from_pretrained": | ||||||
| continue | ||||||
|
Comment on lines
+1917
to
+1918
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you explain what this is doing? |
||||||
| if component_name not in self.config: | ||||||
| continue | ||||||
|
|
||||||
| library, class_name, component_spec_dict = self.config[component_name] | ||||||
| component_spec_dict["pretrained_model_name_or_path"] = repo_id | ||||||
| component_spec_dict["subfolder"] = component_name | ||||||
| if variant is not None and "variant" in component_spec_dict: | ||||||
| component_spec_dict["variant"] = variant | ||||||
|
|
||||||
| self.register_to_config(**{component_name: (library, class_name, component_spec_dict)}) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not too sure about the objective of this block. What happens if its corresponding Or is this unrelated? |
||||||
|
|
||||||
| self.save_config(save_directory=save_directory) | ||||||
|
|
||||||
| # Create a new empty model card and eventually tag it | ||||||
| if push_to_hub: | ||||||
| card_content = generate_modular_model_card_content(self.blocks) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this conditioned on the above changes? If not, maybe we can keep it in the earlier position? |
||||||
| model_card = load_or_create_model_card( | ||||||
| repo_id, | ||||||
| token=token, | ||||||
|
|
@@ -1850,13 +1939,8 @@ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = | |||||
| is_modular=True, | ||||||
| ) | ||||||
| model_card = populate_model_card(model_card, tags=card_content["tags"]) | ||||||
|
|
||||||
| model_card.save(os.path.join(save_directory, "README.md")) | ||||||
|
|
||||||
| # YiYi TODO: maybe order the json file to make it more readable: configs first, then components | ||||||
| self.save_config(save_directory=save_directory) | ||||||
|
|
||||||
| if push_to_hub: | ||||||
| self._upload_folder( | ||||||
| save_directory, | ||||||
| repo_id, | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we document what is allowed in the
kwargs?overwrite_modular_indexdeserves some documentation IMO.