Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions modules/models/sd3/other_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast

from modules import sd_hijack


#################################################################################################
### Core/Utility
Expand Down Expand Up @@ -110,9 +112,9 @@ def forward(self, x, mask=None, intermediate_output=None):


class CLIPEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"):
super().__init__()
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

def forward(self, input_tokens):
Expand All @@ -127,7 +129,7 @@ def __init__(self, config_dict, dtype, device):
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)

Expand Down
6 changes: 5 additions & 1 deletion modules/models/sd3/sd3_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __getitem__(self, key):
"intermediate_size": 5120,
"num_attention_heads": 20,
"num_hidden_layers": 32,
"textual_inversion_key": "clip_g",
}

T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
Expand Down Expand Up @@ -204,7 +205,10 @@ def before_load_weights(self, state_dict):
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

def encode_embedding_init_text(self, init_text, nvpt):
    """Encode the initialisation text for a new textual-inversion embedding.

    Delegates to ``self.model_lg.encode_embedding_init_text`` instead of
    returning the fixed ``[[0]]`` placeholder tensor, which previously
    short-circuited the method and left the real delegation unreachable.

    Args:
        init_text: text used to initialise the embedding; forwarded unchanged.
        nvpt: forwarded unchanged (presumably the number of vectors per
            token -- confirm against ``model_lg``).

    Returns:
        Whatever ``self.model_lg.encode_embedding_init_text`` returns.
    """
    return self.model_lg.encode_embedding_init_text(init_text, nvpt)

def tokenize(self, texts):
    """Tokenize ``texts`` by handing them to ``self.model_lg.tokenize``.

    The argument is forwarded untouched and the delegate's result is
    returned as-is.
    """
    delegate = self.model_lg
    return delegate.tokenize(texts)

def medvram_modules(self):
    """Return the ``clip_g``, ``clip_l`` and ``t5xxl`` attributes as a list, in that order."""
    modules = [
        self.clip_g,
        self.clip_l,
        self.t5xxl,
    ]
    return modules
Expand Down
17 changes: 16 additions & 1 deletion modules/sd_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,28 @@ def forward(self, input_ids):
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
emb = devices.cond_cast_unet(vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)

vecs.append(tensor)

return torch.stack(vecs)


class TextualInversionEmbeddings(torch.nn.Embedding):
    """``torch.nn.Embedding`` whose forward pass is routed through ``EmbeddingsWithFixes.forward``.

    The plain embedding lookup stays reachable through :attr:`wrapped`, while
    :meth:`forward` reuses ``EmbeddingsWithFixes.forward`` with this instance
    standing in as ``self``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
        """Build the embedding table and attach the attributes used for the fixes.

        Args:
            num_embeddings: forwarded to ``torch.nn.Embedding``.
            embedding_dim: forwarded to ``torch.nn.Embedding``.
            textual_inversion_key: key used by ``EmbeddingsWithFixes.forward``
                to pick a vector when an embedding's ``vec`` is a dict.
            **kwargs: remaining keyword arguments, forwarded to
                ``torch.nn.Embedding``.
        """
        super().__init__(num_embeddings, embedding_dim, **kwargs)

        self.textual_inversion_key = textual_inversion_key
        # Module-level hijack object; presumably the source of the pending
        # fixes read by EmbeddingsWithFixes.forward -- confirm against its head.
        self.embeddings = model_hijack

    @property
    def wrapped(self):
        """The parent class's bound ``forward``, i.e. the unmodified embedding lookup."""
        plain_forward = super().forward
        return plain_forward

    def forward(self, input_ids):
        """Run ``EmbeddingsWithFixes.forward`` on ``input_ids`` with this module as ``self``."""
        return EmbeddingsWithFixes.forward(self, input_ids)


def add_circular_option_to_conv_2d():
conv2d_constructor = torch.nn.Conv2d.__init__

Expand Down