
fix: sdxl model invalid configuration after the hijack #12201

Merged
AUTOMATIC1111 merged 1 commit into AUTOMATIC1111:dev from AnyISalIn:dev on Aug 4, 2023

Conversation

@AnyISalIn
Contributor

@AnyISalIn AnyISalIn commented Jul 31, 2023

Description

Currently, we use the following code to detect whether the model is SDXL:

 if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
     return config_sdxl

However, after the hijack, conditioner.embedders.1.model.ln_final.weight is renamed to conditioner.embedders.1.wrapped.model.ln_final.weight. As a result, an invalid configuration is chosen for the SDXL model when shared.opts.sd_checkpoint_cache is enabled.

Maybe we can solve this problem by preventing the addition of wrapped?
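The detection failure can be reproduced with plain dictionaries; guess_config below is a simplified, hypothetical stand-in for the real detection code, not the webui implementation:

```python
def guess_config(sd):
    # simplified stand-in for the detection logic quoted above
    if sd.get('conditioner.embedders.1.model.ln_final.weight') is not None:
        return 'config_sdxl'
    return 'config_v1'  # wrong fallback for an SDXL checkpoint

# a state dict as loaded from disk vs. one taken after the hijack renamed keys
fresh = {'conditioner.embedders.1.model.ln_final.weight': object()}
hijacked = {'conditioner.embedders.1.wrapped.model.ln_final.weight': object()}

print(guess_config(fresh))     # detected as SDXL
print(guess_config(hijacked))  # cached, post-hijack dict falls back to v1
```

The cached copy carries the post-hijack key names, so the expected SDXL key is never found and the v1 config is chosen, which then fails with the size mismatches shown below.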

Screenshots/videos:

How to reproduce it:

When starting the webui, use the following command arguments: --xformers --listen --ckpt=/stable-diffusion-webui/models/Stable-diffusion/sd_xl_base_1.0.safetensors --enable-insecure-extension-access --api --port 7860. Set sd_checkpoint_cache to 3.

Then, test it using the following scripts. It will switch to 0.9, switch back to 1.0, and then switch back to 0.9:

curl -X 'POST' \
  'http://office-gpu:7860/sdapi/v1/options' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "sd_model_checkpoint": "sd_xl_base_0.9.safetensors"
}'


curl -X 'POST' \
  'http://office-gpu:7860/sdapi/v1/options' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "sd_model_checkpoint": "sd_xl_base_1.0.safetensors"
}'


curl -X 'POST' \
  'http://office-gpu:7860/sdapi/v1/options' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "sd_model_checkpoint": "sd_xl_base_0.9.safetensors"
}'

Here is the error output:

Startup time: 32.3s (prepare environment: 13.0s, launcher: 0.2s, import torch: 2.5s, import gradio: 1.5s, setup paths: 0.7s, other imports: 0.6s, list SD models: 0.2s, load scripts: 5.2s, create ui: 2.1s, gradio launch: 2.8s, add APIs: 0.1s, app_started_callback: 3.3s).
Loading weights [1f69731261] from /stable-diffusion-webui/models/Stable-diffusion/sd_xl_base_0.9.safetensors
Applying attention optimization: xformers... done.
Weights loaded in 3.4s (load weights from disk: 1.3s, apply weights to model: 0.9s, move model to device: 1.1s).
Loading weights [31e35c80fc] from cache
Applying attention optimization: xformers... done.
Weights loaded in 2.0s (apply weights to model: 0.9s, move model to device: 1.1s).
Loading weights [1f69731261] from cache
Creating model from config: /stable-diffusion-webui/configs/v1-inference.yaml
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
changing setting sd_model_checkpoint to sd_xl_base_0.9.safetensors: RuntimeError
Traceback (most recent call last):
  File "/stable-diffusion-webui/modules/shared.py", line 633, in set
    self.data_labels[key].onchange()
  File "/stable-diffusion-webui/modules/call_queue.py", line 14, in f
    res = func(*args, **kwargs)
  File "/stable-diffusion-webui/webui.py", line 238, in <lambda>
    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
  File "/stable-diffusion-webui/modules/sd_models.py", line 582, in reload_model_weights
    load_model(checkpoint_info, already_loaded_state_dict=state_dict)
  File "/stable-diffusion-webui/modules/sd_models.py", line 514, in load_model
    load_model_weights(sd_model, checkpoint_info, state_dict, timer)
  File "/stable-diffusion-webui/modules/sd_models.py", line 299, in load_model_weights
    model.load_state_dict(state_dict, strict=False)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for LatentDiffusion:
        size mismatch for model.diffusion_model.input_blocks.4.1.proj_in.weight: copying a param with shape torch.Size([640, 640]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
        size mismatch for model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([640, 2048]) from checkpoint, the shape in current model is torch.Size([640, 768]).
        size mismatch for model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([640, 2048]) from checkpoint, the shape in current model is torch.Size([640, 768]).
        size mismatch for model.diffusion_model.input_blocks.4.1.proj_out.weight: copying a param with shape torch.Size([640, 640]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
        size mismatch for model.diffusion_model.input_blocks.5.1.proj_in.weight: copying a param with shape torch.Size([640, 640]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
        size mismatch for model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([640, 2048]) from checkpoint, the shape in current model is torch.Size([640, 768]).
        size mismatch for model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([640, 2048]) from checkpoint, the shape in current model is torch.Size([640, 768]).
        size mismatch for model.diffusion_model.input_blocks.5.1.proj_out.weight: copying a param with shape torch.Size([640, 640]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
        size mismatch for model.diffusion_model.input_blocks.7.1.proj_in.weight: copying a param with shape torch.Size([1280, 1280]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 1, 1]).
        size mismatch for model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([1280, 2048]) from checkpoint, the shape in current model is torch.Size([1280, 768]).
        size mismatch for model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([1280, 2048]) from checkpoint, the shape in current model is torch.Size([1280, 768]).
        size mismatch for model.diffusion_model.input_blocks.7.1.proj_out.weight: copying a param with shape torch.Size([1280, 1280]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 1, 1]).

Checklist:

@AnyISalIn AnyISalIn requested a review from AUTOMATIC1111 as a code owner July 31, 2023 07:26
@AUTOMATIC1111
Owner

This is not the solution because in addition to incorrect detection, this will fail to load weights properly.

Maybe the right way to go is to change

    model.load_state_dict(state_dict, strict=False)
    del state_dict
    timer.record("apply weights to model")

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()

to

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = state_dict

    model.load_state_dict(state_dict, strict=False)
    del state_dict
    timer.record("apply weights to model")

i.e. store the state dict in the cache before it is applied to a model at all.
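The ordering point can be sketched as a toy model of the two cache strategies (hypothetical helper names, not the webui code; apply_hijack simulates the key rename):

```python
def apply_hijack(state_dict):
    # simulate the hijack inserting 'wrapped.' into the key path
    return {k.replace('embedders.1.', 'embedders.1.wrapped.'): v
            for k, v in state_dict.items()}

checkpoints_loaded = {}

def load_cached_after(info, state_dict):
    model_sd = apply_hijack(state_dict)        # weights applied, keys mutated
    checkpoints_loaded[info] = dict(model_sd)  # old order: caches post-hijack keys
    return model_sd

def load_cached_before(info, state_dict):
    checkpoints_loaded[info] = state_dict      # suggested order: caches raw keys
    return apply_hijack(state_dict)
```

With load_cached_before, the cached entry keeps the original on-disk key names, so SDXL detection still works when the checkpoint is later restored from cache.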

@AnyISalIn
Contributor Author

This is not the solution because in addition to incorrect detection, this will fail to load weights properly.

Maybe the right way to go is to change

    model.load_state_dict(state_dict, strict=False)
    del state_dict
    timer.record("apply weights to model")

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()

to

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = state_dict

    model.load_state_dict(state_dict, strict=False)
    del state_dict
    timer.record("apply weights to model")

i.e. store the state dict in the cache before it is applied to a model at all.

Yes, your solution was actually my first version, but I noticed that you submitted a new PR (#12227) that indirectly solves this problem.

@AUTOMATIC1111
Owner

Right, well, I do not want to outright remove checkpoint cache yet because there is a scenario where it is still useful, like keeping 5 checkpoints in CPU RAM and 2 models loaded on GPU.

That said, the checkpoint cache has a much bigger problem: it also caches the currently loaded model, so a value like 1 does nothing useful and just eats RAM.
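Why a cache size of 1 is useless can be shown with a toy FIFO cache (hypothetical helper, not the webui implementation):

```python
from collections import OrderedDict

checkpoints_loaded = OrderedDict()

def cache_checkpoint(name, state_dict, limit):
    # caching happens on every load, including the currently active model
    checkpoints_loaded[name] = state_dict
    while len(checkpoints_loaded) > limit:
        checkpoints_loaded.popitem(last=False)  # evict the oldest entry

# with sd_checkpoint_cache = 1:
cache_checkpoint('A', {}, limit=1)  # load A; cache holds A
cache_checkpoint('B', {}, limit=1)  # switch to B; A is evicted, cache holds B
# switching back to A must read from disk again: the single slot is
# occupied by B, the model that is already loaded anyway.
```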

@AnyISalIn
Contributor Author

Right, well, I do not want to outright remove checkpoint cache yet because there is a scenario where it is still useful, like keeping 5 checkpoints in CPU RAM and 2 models loaded on GPU.

That said, the checkpoint cache has a much bigger problem: it also caches the currently loaded model, so a value like 1 does nothing useful and just eats RAM.

Maybe we should only cache the checkpoint during the reload-weights stage; that would save RAM.

@AnyISalIn
Contributor Author

If you are not removing checkpoints_loaded, maybe I can commit the following change?

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = state_dict

    model.load_state_dict(state_dict, strict=False)
    del state_dict
    timer.record("apply weights to model")

@AUTOMATIC1111
Owner

Well that was my recommendation... I wanted you to test if it works well and commit it into the PR.

@catboxanon catboxanon added the sdxl Related to SDXL label Aug 3, 2023
Signed-off-by: AnyISalIn <anyisalin@gmail.com>
@AnyISalIn
Contributor Author

Well that was my recommendation... I wanted you to test if it works well and commit it into the PR.

Done, it works fine for me.

@AUTOMATIC1111 AUTOMATIC1111 merged commit c938579 into AUTOMATIC1111:dev Aug 4, 2023
brkirch pushed a commit to brkirch/stable-diffusion-webui that referenced this pull request Aug 4, 2023
fix: sdxl model invalid configuration after the hijack
