Skip to content

Commit e340b52

Browse files
committed
update
1 parent 71ce634 commit e340b52

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1854,6 +1854,9 @@ def save_pretrained(
18541854
Whether to push the pipeline to the Hugging Face model hub after saving it.
18551855
**kwargs: Additional keyword arguments passed along to the push to hub method.
18561856
"""
1857+
overwrite_modular_index = kwargs.pop("overwrite_modular_index", False)
1858+
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
1859+
18571860
for component_name, component_spec in self._component_specs.items():
18581861
sub_model = getattr(self, component_name, None)
18591862
if sub_model is None:
@@ -1902,16 +1905,33 @@ def save_pretrained(
19021905

19031906
save_method(os.path.join(save_directory, component_name), **save_kwargs)
19041907

1905-
self.save_config(save_directory=save_directory)
1906-
19071908
if push_to_hub:
19081909
commit_message = kwargs.pop("commit_message", None)
19091910
private = kwargs.pop("private", None)
19101911
create_pr = kwargs.pop("create_pr", False)
19111912
token = kwargs.pop("token", None)
1912-
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
19131913
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
19141914

1915+
if overwrite_modular_index:
1916+
for component_name, component_spec in self._component_specs.items():
1917+
if component_spec.default_creation_method != "from_pretrained":
1918+
continue
1919+
sub_model = getattr(self, component_name, None)
1920+
if sub_model is None:
1921+
continue
1922+
1923+
component_spec.pretrained_model_name_or_path = repo_id
1924+
component_spec.subfolder = component_name
1925+
if variant is not None and hasattr(component_spec, "variant"):
1926+
component_spec.variant = variant
1927+
1928+
library, class_name = _fetch_class_library_tuple(sub_model)
1929+
component_spec_dict = self._component_spec_to_dict(component_spec)
1930+
self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})
1931+
1932+
self.save_config(save_directory=save_directory)
1933+
1934+
if push_to_hub:
19151935
card_content = generate_modular_model_card_content(self.blocks)
19161936
model_card = load_or_create_model_card(
19171937
repo_id,

0 commit comments

Comments
 (0)