Skip to content

Commit 7dd0467

Browse files
authored
feat: SVD distillation with CLI (#98)
1 parent 35653d2 commit 7dd0467

5 files changed

Lines changed: 129 additions & 10 deletions

File tree

analog_svd_distill.pt

4.83 MB
Binary file not shown.

analog_svd_distill.text_encoder.pt

1.15 MB
Binary file not shown.

lora_diffusion/cli_svd.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import fire
2+
from diffusers import StableDiffusionPipeline
3+
import torch
4+
import torch.nn as nn
5+
6+
from .lora import save_all, _find_modules
7+
8+
9+
def _text_lora_path(path: str) -> str:
10+
assert path.endswith(".pt"), "Only .pt files are supported"
11+
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
12+
13+
14+
def _ti_lora_path(path: str) -> str:
15+
assert path.endswith(".pt"), "Only .pt files are supported"
16+
return ".".join(path.split(".")[:-1] + ["ti", "pt"])
17+
18+
19+
def extract_linear_weights(model, target_replace_module):
    """Collect the weight tensors of every ``nn.Linear`` found under the
    module classes named in ``target_replace_module``.

    Order follows ``_find_modules`` traversal, so two structurally identical
    models yield weights in matching order.
    """
    return [
        child.weight
        for _parent, _name, child in _find_modules(
            model, target_replace_module, search_class=[nn.Linear]
        )
    ]
27+
28+
29+
def _clamped_low_rank_factors(diffs, rank, clamp_quantile):
    """Return ``[U0, Vh0, U1, Vh1, ...]`` rank-``rank`` SVD factors of ``diffs``.

    For each difference matrix the top ``rank`` singular triplets are kept,
    the singular values are folded into U (so ``U @ Vh`` approximates the
    matrix), and both factors are clamped to the symmetric range given by
    the ``clamp_quantile`` quantile of their combined entries to suppress
    outliers in the distilled weights.
    """
    factors = []
    with torch.no_grad():
        for mat in diffs:
            mat = mat.float()

            # Reduced SVD: only the leading `rank` components are used, so
            # computing the full square U/Vh would be wasted work.
            U, S, Vh = torch.linalg.svd(mat, full_matrices=False)

            U = U[:, :rank]
            S = S[:rank]
            U = U @ torch.diag(S)
            Vh = Vh[:rank, :]

            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp_quantile)
            low_val = -hi_val

            factors.append(U.clamp(low_val, hi_val))
            factors.append(Vh.clamp(low_val, hi_val))
    return factors


def svd_distill(
    target_model: str,
    base_model: str,
    rank: int = 4,
    clamp_quantile: float = 0.99,
    device: str = "cuda:0",
    save_path: str = "svd_distill.pt",
):
    """Distill the weight delta between a fine-tuned and a base Stable
    Diffusion model into low-rank LoRA-style factors via truncated SVD.

    Loads both pipelines, takes the per-layer difference of the linear
    weights in the attention (and GEGLU) modules of the UNet and the text
    encoder, SVD-truncates each difference to ``rank``, and saves the
    factor lists with ``torch.save``: UNet factors to ``save_path``, text
    encoder factors to the derived ``*.text_encoder.pt`` path.

    Args:
        target_model: HF hub id or local path of the fine-tuned pipeline.
        base_model: HF hub id or local path of the base pipeline.
        rank: number of singular components to keep per layer.
        clamp_quantile: quantile used to clamp factor outliers.
        device: torch device string for loading both pipelines.
        save_path: output ``.pt`` path for the UNet factors.
    """
    pipe_base = StableDiffusionPipeline.from_pretrained(
        base_model, torch_dtype=torch.float16
    ).to(device)

    pipe_tuned = StableDiffusionPipeline.from_pretrained(
        target_model, torch_dtype=torch.float16
    ).to(device)

    unet_targets = ["CrossAttention", "Attention", "GEGLU"]
    clip_targets = ["CLIPAttention"]

    ori_unet = extract_linear_weights(pipe_base.unet, unet_targets)
    ori_clip = extract_linear_weights(pipe_base.text_encoder, clip_targets)

    tuned_unet = extract_linear_weights(pipe_tuned.unet, unet_targets)
    tuned_clip = extract_linear_weights(pipe_tuned.text_encoder, clip_targets)

    # Per-layer weight deltas; same traversal order on both models, so a
    # positional zip pairs corresponding layers.
    diffs_unet = [tuned - ori for ori, tuned in zip(ori_unet, tuned_unet)]
    diffs_clip = [tuned - ori for ori, tuned in zip(ori_clip, tuned_clip)]

    uds_unet = _clamped_low_rank_factors(diffs_unet, rank, clamp_quantile)
    uds_clip = _clamped_low_rank_factors(diffs_clip, rank, clamp_quantile)

    torch.save(uds_unet, save_path)
    torch.save(uds_clip, _text_lora_path(save_path))
112+
113+
114+
def main():
    """Console entry point: expose ``svd_distill`` as a Fire CLI command."""
    fire.Fire(svd_distill)

lora_diffusion/lora.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -654,22 +654,25 @@ def save_all(
654654
placeholder_tokens,
655655
save_path,
656656
save_lora=True,
657+
save_ti=True,
657658
target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
658659
target_replace_module_unet=DEFAULT_TARGET_REPLACE,
659660
):
660661

661662
# save ti
662-
ti_path = _ti_lora_path(save_path)
663-
learned_embeds_dict = {}
664-
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
665-
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
666-
print(
667-
f"Current Learned Embeddings for {tok}:, id {tok_id} ", learned_embeds[:4]
668-
)
669-
learned_embeds_dict[tok] = learned_embeds.detach().cpu()
663+
if save_ti:
664+
ti_path = _ti_lora_path(save_path)
665+
learned_embeds_dict = {}
666+
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
667+
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
668+
print(
669+
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
670+
learned_embeds[:4],
671+
)
672+
learned_embeds_dict[tok] = learned_embeds.detach().cpu()
670673

671-
torch.save(learned_embeds_dict, ti_path)
672-
print("Ti saved to ", ti_path)
674+
torch.save(learned_embeds_dict, ti_path)
675+
print("Ti saved to ", ti_path)
673676

674677
# save text encoder
675678
if save_lora:

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"console_scripts": [
1515
"lora_add = lora_diffusion.cli_lora_add:main",
1616
"lora_pti = lora_diffusion.cli_lora_pti:main",
17+
"lora_distill = lora_diffusion.cli_svd:main",
1718
],
1819
},
1920
install_requires=[

0 commit comments

Comments
 (0)