Skip to content

Support for swapping Dreambooth LoRAs when model is loaded in memory. #266

@Dentoty

Description

@Dentoty

Currently, when doing inference using load_and_set_lora_ckpt() to load the lora weights everything is working as expected.
The only issue is that you have to re-load the base model to be able to do inference using different lora weights.

Would it be possible to swap LoRA weights on a model that remains loaded in memory, without corrupting it? Currently, when I try this, in most cases the pipeline behaves as if the LoRA itself is missing.

Minimal code to reproduce:

import os
import io
import json
import urllib.request
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraModel, LoraConfig, set_peft_model_state_dict

# NOTE(review): `global` at module scope is a no-op; the statement is kept
# as it appeared in the original report.
global pipe
# The base model is loaded exactly once; the issue is about re-applying
# different LoRA weights to this already-loaded pipeline.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=torch.float16).to("cuda")

def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype):
    """Load a Dreambooth LoRA checkpoint and inject it into *pipe*.

    Expects two files inside ``ckpt_dir``:
      * ``{instance_prompt}_lora_config.json`` — serialized PEFT config(s)
      * ``{instance_prompt}_lora.pt``          — combined UNet / text-encoder state dict

    Args:
        pipe: StableDiffusionPipeline whose ``unet`` (and optionally
            ``text_encoder``) get wrapped in a ``LoraModel``.
        ckpt_dir: Directory containing the checkpoint files.
        instance_prompt: Filename prefix of the checkpoint files.
        device: Target device the finished pipeline is moved to.
        dtype: Target dtype; fp16/bf16 triggers a ``.half()`` re-cast.

    Returns:
        The same pipeline object, with LoRA layers injected and loaded.

    NOTE(review): calling this twice wraps an already-wrapped model in a
    second ``LoraModel``, which is presumably why swapping adapters appears
    to do nothing (the behavior reported in this issue). A real fix needs to
    unload/replace the previous adapter instead of wrapping again — confirm
    against the peft API.
    """
    # os.path.join tolerates ckpt_dir both with and without a trailing
    # separator; the original raw f-string concatenation silently produced
    # a broken path ("/ckpt_dir 1INSTANCE_...") when the slash was missing.
    with open(os.path.join(ckpt_dir, f"{instance_prompt}_lora_config.json"), "r") as f:
        lora_config = json.load(f)
    print(lora_config)

    checkpoint = os.path.join(ckpt_dir, f"{instance_prompt}_lora.pt")
    # map_location="cpu" makes loading robust regardless of the device the
    # checkpoint was saved from; weights are moved by pipe.to(device) below.
    # SECURITY NOTE: torch.load unpickles arbitrary code — only load trusted
    # checkpoints.
    lora_checkpoint_sd = torch.load(checkpoint, map_location="cpu")
    # Split the combined state dict: keys containing "text_encoder_" belong
    # to the text encoder; everything else belongs to the UNet.
    unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
    text_encoder_lora_ds = {
        k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
    }

    unet_config = LoraConfig(**lora_config["peft_config"])
    pipe.unet = LoraModel(unet_config, pipe.unet)
    set_peft_model_state_dict(pipe.unet, unet_lora_ds)

    # The text-encoder LoRA is optional in the checkpoint.
    if "text_encoder_peft_config" in lora_config:
        text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
        pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
        set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)

    # Re-cast after injection: freshly created LoRA layers default to
    # float32 even when the base model is fp16/bf16.
    if dtype in (torch.float16, torch.bfloat16):
        pipe.unet.half()
        pipe.text_encoder.half()

    pipe.to(device)
    return pipe


def inference():
    """Run two generations with different LoRA checkpoints on one pipeline.

    Minimal repro for the reported problem: applying a second LoRA to an
    already-wrapped pipeline behaves as if no LoRA is loaded.
    """
    global pipe

    # First LoRA checkpoint.
    INSTANCE_PROMPT = "INSTANCE_PROMPT 1"
    ckpt_dir = "/ckpt_dir 1"

    pipe = load_and_set_lora_ckpt(pipe, ckpt_dir, INSTANCE_PROMPT, "cuda", torch.float16)
    prompt = "Prompt for LoRA 1"
    image = pipe(prompt=prompt).images
    image[0].show()

    # Loading the same LoRA (or a different one) a second time results in
    # behaviour as if no LoRA is loaded at all.
    # BUG FIX: the original passed INSTANCE_PROMPT (the first prompt) here,
    # so the second call reloaded checkpoint 1 instead of checkpoint 2 and
    # never actually exercised the second LoRA.
    INSTANCE_PROMPT2 = "INSTANCE_PROMPT 2"
    ckpt_dir = "/ckpt_dir 2"

    pipe = load_and_set_lora_ckpt(pipe, ckpt_dir, INSTANCE_PROMPT2, "cuda", torch.float16)
    prompt = "Prompt for LoRA 2"
    image = pipe(prompt=prompt).images
    image[0].show()

# Entry point of the minimal repro (runs on import; no __main__ guard in
# the original report).
inference()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions