-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Closed
Description
Currently, when doing inference using load_and_set_lora_ckpt() to load the lora weights everything is working as expected.
The only issue is that you have to re-load the base model to be able to do inference using different lora weights.
Would it be possible or is it possible to swap lora weights on a model that remains loaded into memory without corrupting it or how in most cases it's as if the lora itself is missing?
Minimal code to reproduce:
import os
import io
import json
import urllib.request
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraModel, LoraConfig, set_peft_model_state_dict
global pipe
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=torch.float16).to("cuda")
def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype):
with open(f"{ckpt_dir}{instance_prompt}_lora_config.json", "r") as f:
lora_config = json.load(f)
print(lora_config)
checkpoint = f"{ckpt_dir}{instance_prompt}_lora.pt"
lora_checkpoint_sd = torch.load(checkpoint)
unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
text_encoder_lora_ds = {
k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
}
unet_config = LoraConfig(**lora_config["peft_config"])
pipe.unet = LoraModel(unet_config, pipe.unet)
set_peft_model_state_dict(pipe.unet, unet_lora_ds)
if "text_encoder_peft_config" in lora_config:
text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)
if dtype in (torch.float16, torch.bfloat16):
pipe.unet.half()
pipe.text_encoder.half()
pipe.to(device)
return pipe
def inference():
global pipe
INSTANCE_PROMPT = "INSTANCE_PROMPT 1"
ckpt_dir = "/ckpt_dir 1"
pipe = load_and_set_lora_ckpt(pipe, ckpt_dir, INSTANCE_PROMPT, "cuda", torch.float16)
prompt = "Prompt for LoRA 1"
image = pipe(prompt=prompt).images
image[0].show()
#Loading the same lora or a different one a second will result in behaviour as if no lora is loaded
INSTANCE_PROMPT2 = "INSTANCE_PROMPT 2"
ckpt_dir = "/ckpt_dir 2"
pipe = load_and_set_lora_ckpt(pipe, ckpt_dir, INSTANCE_PROMPT, "cuda", torch.float16)
prompt = "Prompt for LoRA 2"
image = pipe(prompt=prompt).images
image[0].show()
inference()Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels