Skip to content

Commit e90b90a

Browse files
committed
Update DiT block for LTX 2.3 + add self_attention_mask
1 parent 6c7e720 commit e90b90a

File tree

1 file changed

+118
-40
lines changed

1 file changed

+118
-40
lines changed

src/diffusers/models/transformers/transformer_ltx2.py

Lines changed: 118 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ def __call__(
178178
if encoder_hidden_states is None:
179179
encoder_hidden_states = hidden_states
180180

181+
if attn.to_gate_logits is not None:
182+
# Calculate gate logits on original hidden_states
183+
gate_logits = attn.to_gate_logits(hidden_states)
184+
181185
query = attn.to_q(hidden_states)
182186
key = attn.to_k(encoder_hidden_states)
183187
value = attn.to_v(encoder_hidden_states)
@@ -212,6 +216,13 @@ def __call__(
212216
hidden_states = hidden_states.flatten(2, 3)
213217
hidden_states = hidden_states.to(query.dtype)
214218

219+
if attn.to_gate_logits is not None:
220+
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
221+
# The factor of 2.0 is so that if the gate logits are zero-initialized, the initial gates are all 1
222+
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
223+
hidden_states = hidden_states * gates.unsqueeze(-1)
224+
hidden_states = hidden_states.flatten(2, 3)
225+
215226
hidden_states = attn.to_out[0](hidden_states)
216227
hidden_states = attn.to_out[1](hidden_states)
217228
return hidden_states
@@ -427,6 +438,10 @@ def __init__(
427438
audio_num_attention_heads: int,
428439
audio_attention_head_dim,
429440
audio_cross_attention_dim: int,
441+
video_gated_attn: bool = False,
442+
video_cross_attn_adaln: bool = False,
443+
audio_gated_attn: bool = False,
444+
audio_cross_attn_adaln: bool = False,
430445
qk_norm: str = "rms_norm_across_heads",
431446
activation_fn: str = "gelu-approximate",
432447
attention_bias: bool = True,
@@ -449,6 +464,8 @@ def __init__(
449464
out_bias=attention_out_bias,
450465
qk_norm=qk_norm,
451466
rope_type=rope_type,
467+
apply_gated_attention=video_gated_attn,
468+
processor=LTX2AudioVideoAttnProcessor(),
452469
)
453470

454471
self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -462,6 +479,8 @@ def __init__(
462479
out_bias=attention_out_bias,
463480
qk_norm=qk_norm,
464481
rope_type=rope_type,
482+
apply_gated_attention=audio_gated_attn,
483+
processor=LTX2AudioVideoAttnProcessor(),
465484
)
466485

467486
# 2. Prompt Cross-Attention
@@ -476,6 +495,8 @@ def __init__(
476495
out_bias=attention_out_bias,
477496
qk_norm=qk_norm,
478497
rope_type=rope_type,
498+
apply_gated_attention=video_gated_attn,
499+
processor=LTX2AudioVideoAttnProcessor(),
479500
)
480501

481502
self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -489,6 +510,8 @@ def __init__(
489510
out_bias=attention_out_bias,
490511
qk_norm=qk_norm,
491512
rope_type=rope_type,
513+
apply_gated_attention=audio_gated_attn,
514+
processor=LTX2AudioVideoAttnProcessor(),
492515
)
493516

494517
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -504,6 +527,8 @@ def __init__(
504527
out_bias=attention_out_bias,
505528
qk_norm=qk_norm,
506529
rope_type=rope_type,
530+
apply_gated_attention=video_gated_attn,
531+
processor=LTX2AudioVideoAttnProcessor(),
507532
)
508533

509534
# Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
@@ -518,6 +543,8 @@ def __init__(
518543
out_bias=attention_out_bias,
519544
qk_norm=qk_norm,
520545
rope_type=rope_type,
546+
apply_gated_attention=audio_gated_attn,
547+
processor=LTX2AudioVideoAttnProcessor(),
521548
)
522549

523550
# 4. Feedforward layers
@@ -528,14 +555,37 @@ def __init__(
528555
self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
529556

530557
# 5. Per-Layer Modulation Parameters
531-
# Self-Attention / Feedforward AdaLayerNorm-Zero mod params
532-
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
533-
self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
558+
# Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params
559+
# 6 base mod params for self-attn / feedforward AdaLN-Zero; if cross_attn_adaln, 3 extra mod params for the text cross-attn query
560+
self.video_cross_attn_adaln = video_cross_attn_adaln
561+
self.audio_cross_attn_adaln = audio_cross_attn_adaln
562+
video_mod_param_num = 9 if self.video_cross_attn_adaln else 6
563+
audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6
564+
self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5)
565+
self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5)
566+
567+
# Prompt cross-attn (attn2) additional modulation params
568+
self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln
569+
if self.cross_attn_adaln:
570+
self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
571+
self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
534572

535573
# Per-layer a2v, v2a Cross-Attention mod params
536574
self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
537575
self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
538576

577+
@staticmethod
578+
def get_mod_params(
579+
scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int
580+
) -> tuple[torch.Tensor, ...]:
581+
num_ada_params = scale_shift_table.shape[0]
582+
ada_values = (
583+
scale_shift_table[None, None].to(temb.device)
584+
+ temb.reshape(batch_size, temb.shape[1], num_ada_params, -1)
585+
)
586+
ada_params = ada_values.unbind(dim=2)
587+
return ada_params
588+
539589
def forward(
540590
self,
541591
hidden_states: torch.Tensor,
@@ -548,6 +598,8 @@ def forward(
548598
temb_ca_audio_scale_shift: torch.Tensor,
549599
temb_ca_gate: torch.Tensor,
550600
temb_ca_audio_gate: torch.Tensor,
601+
temb_prompt: torch.Tensor | None = None,
602+
temb_prompt_audio: torch.Tensor | None = None,
551603
video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
552604
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
553605
ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
@@ -560,13 +612,13 @@ def forward(
560612
batch_size = hidden_states.size(0)
561613

562614
# 1. Video and Audio Self-Attention
563-
norm_hidden_states = self.norm1(hidden_states)
615+
# 1.1. Video Self-Attention
616+
video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size)
617+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6]
618+
if self.video_cross_attn_adaln:
619+
shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9]
564620

565-
num_ada_params = self.scale_shift_table.shape[0]
566-
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
567-
batch_size, temb.size(1), num_ada_params, -1
568-
)
569-
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
621+
norm_hidden_states = self.norm1(hidden_states)
570622
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
571623

572624
attn_hidden_states = self.attn1(
@@ -576,15 +628,15 @@ def forward(
576628
)
577629
hidden_states = hidden_states + attn_hidden_states * gate_msa
578630

579-
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
580-
581-
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
582-
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
583-
batch_size, temb_audio.size(1), num_audio_ada_params, -1
584-
)
631+
# 1.2. Audio Self-Attention
632+
audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size)
585633
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
586-
audio_ada_values.unbind(dim=2)
634+
audio_ada_params[:6]
587635
)
636+
if self.audio_cross_attn_adaln:
637+
audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9]
638+
639+
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
588640
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
589641

590642
attn_audio_hidden_states = self.audio_attn1(
@@ -594,63 +646,74 @@ def forward(
594646
)
595647
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
596648

597-
# 2. Video and Audio Cross-Attention with the text embeddings
649+
# 2. Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text)
650+
if self.cross_attn_adaln:
651+
video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size)
652+
shift_text_kv, scale_text_kv = video_prompt_ada_params
653+
654+
audio_prompt_ada_params = self.get_mod_params(self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size)
655+
audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params
656+
657+
# 2.1. Video-Text Cross-Attention (Q: Video; K,V: Text)
598658
norm_hidden_states = self.norm2(hidden_states)
659+
if self.video_cross_attn_adaln:
660+
norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q
661+
if self.cross_attn_adaln:
662+
encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv
663+
599664
attn_hidden_states = self.attn2(
600665
norm_hidden_states,
601666
encoder_hidden_states=encoder_hidden_states,
602667
query_rotary_emb=None,
603668
attention_mask=encoder_attention_mask,
604669
)
670+
if self.video_cross_attn_adaln:
671+
attn_hidden_states = attn_hidden_states * gate_text_q
605672
hidden_states = hidden_states + attn_hidden_states
606673

674+
# 2.2. Audio-Text Cross-Attention
607675
norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
676+
if self.audio_cross_attn_adaln:
677+
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q
678+
if self.cross_attn_adaln:
679+
audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv
680+
608681
attn_audio_hidden_states = self.audio_attn2(
609682
norm_audio_hidden_states,
610683
encoder_hidden_states=audio_encoder_hidden_states,
611684
query_rotary_emb=None,
612685
attention_mask=audio_encoder_attention_mask,
613686
)
687+
if self.audio_cross_attn_adaln:
688+
attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q
614689
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
615690

616691
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
617692
norm_hidden_states = self.audio_to_video_norm(hidden_states)
618693
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
619694

620-
# Combine global and per-layer cross attention modulation parameters
695+
# 3.1. Combine global and per-layer cross attention modulation parameters
621696
# Video
622697
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
623698
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
624699

625-
video_ca_scale_shift_table = (
626-
video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype)
627-
+ temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
628-
).unbind(dim=2)
629-
video_ca_gate = (
630-
video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype)
631-
+ temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
632-
).unbind(dim=2)
700+
video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size)
701+
video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size)
633702

634-
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
635-
a2v_gate = video_ca_gate[0].squeeze(2)
703+
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params
704+
a2v_gate = video_ca_gate_param[0].squeeze(2)
636705

637706
# Audio
638707
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
639708
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
640709

641-
audio_ca_scale_shift_table = (
642-
audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype)
643-
+ temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
644-
).unbind(dim=2)
645-
audio_ca_gate = (
646-
audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype)
647-
+ temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
648-
).unbind(dim=2)
710+
audio_ca_ada_params = self.get_mod_params(audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size)
711+
audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size)
649712

650-
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
651-
v2a_gate = audio_ca_gate[0].squeeze(2)
713+
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params
714+
v2a_gate = audio_ca_gate_param[0].squeeze(2)
652715

653-
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
716+
# 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio
654717
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(
655718
2
656719
)
@@ -668,7 +731,7 @@ def forward(
668731

669732
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
670733

671-
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
734+
# 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video
672735
mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(
673736
2
674737
)
@@ -1209,6 +1272,7 @@ def forward(
12091272
audio_timestep: torch.LongTensor | None = None,
12101273
encoder_attention_mask: torch.Tensor | None = None,
12111274
audio_encoder_attention_mask: torch.Tensor | None = None,
1275+
self_attention_mask: torch.Tensor | None = None,
12121276
num_frames: int | None = None,
12131277
height: int | None = None,
12141278
width: int | None = None,
@@ -1241,6 +1305,8 @@ def forward(
12411305
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
12421306
audio_encoder_attention_mask (`torch.Tensor`, *optional*):
12431307
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling.
1308+
self_attention_mask (`torch.Tensor`, *optional*):
1309+
Optional multiplicative self-attention mask of shape `(batch_size, seq_len, seq_len)`.
12441310
num_frames (`int`, *optional*):
12451311
The number of latent video frames. Used if calculating the video coordinates for RoPE.
12461312
height (`int`, *optional*):
@@ -1281,6 +1347,18 @@ def forward(
12811347
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
12821348
audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)
12831349

1350+
if self_attention_mask is not None and self_attention_mask.ndim == 3:
1351+
# Convert to additive attention mask in log-space where 0 (masked) values get mapped to a large negative
1352+
# number and positive values are mapped to their logarithm.
1353+
dtype_finfo = torch.finfo(hidden_states.dtype)
1354+
additive_self_attn_mask = torch.full_like(self_attention_mask, dtype_finfo.min, dtype=hidden_states.dtype)
1355+
unmasked_entries = self_attention_mask > 0
1356+
if torch.any(unmasked_entries):
1357+
additive_self_attn_mask[unmasked_entries] = torch.log(
1358+
self_attention_mask[unmasked_entries].clamp(min=dtype_finfo.tiny)
1359+
).to(hidden_states.dtype)
1360+
self_attention_mask = additive_self_attn_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len]
1361+
12841362
batch_size = hidden_states.size(0)
12851363

12861364
# 1. Prepare RoPE positional embeddings

0 commit comments

Comments
 (0)