@@ -178,6 +178,10 @@ def __call__(
178178 if encoder_hidden_states is None :
179179 encoder_hidden_states = hidden_states
180180
181+ if attn .to_gate_logits is not None :
182+ # Calculate gate logits on original hidden_states
183+ gate_logits = attn .to_gate_logits (hidden_states )
184+
181185 query = attn .to_q (hidden_states )
182186 key = attn .to_k (encoder_hidden_states )
183187 value = attn .to_v (encoder_hidden_states )
@@ -212,6 +216,13 @@ def __call__(
212216 hidden_states = hidden_states .flatten (2 , 3 )
213217 hidden_states = hidden_states .to (query .dtype )
214218
219+ if attn .to_gate_logits is not None :
220+ hidden_states = hidden_states .unflatten (2 , (attn .heads , - 1 )) # [B, T, H, D]
221+ # The factor of 2.0 is so that if the gate logits are zero-initialized the initial gates are all 1
222+ gates = 2.0 * torch .sigmoid (gate_logits ) # [B, T, H]
223+ hidden_states = hidden_states * gates .unsqueeze (- 1 )
224+ hidden_states = hidden_states .flatten (2 , 3 )
225+
215226 hidden_states = attn .to_out [0 ](hidden_states )
216227 hidden_states = attn .to_out [1 ](hidden_states )
217228 return hidden_states
@@ -427,6 +438,10 @@ def __init__(
427438 audio_num_attention_heads : int ,
428439 audio_attention_head_dim ,
429440 audio_cross_attention_dim : int ,
441+ video_gated_attn : bool = False ,
442+ video_cross_attn_adaln : bool = False ,
443+ audio_gated_attn : bool = False ,
444+ audio_cross_attn_adaln : bool = False ,
430445 qk_norm : str = "rms_norm_across_heads" ,
431446 activation_fn : str = "gelu-approximate" ,
432447 attention_bias : bool = True ,
@@ -449,6 +464,8 @@ def __init__(
449464 out_bias = attention_out_bias ,
450465 qk_norm = qk_norm ,
451466 rope_type = rope_type ,
467+ apply_gated_attention = video_gated_attn ,
468+ processor = LTX2AudioVideoAttnProcessor (),
452469 )
453470
454471 self .audio_norm1 = RMSNorm (audio_dim , eps = eps , elementwise_affine = elementwise_affine )
@@ -462,6 +479,8 @@ def __init__(
462479 out_bias = attention_out_bias ,
463480 qk_norm = qk_norm ,
464481 rope_type = rope_type ,
482+ apply_gated_attention = audio_gated_attn ,
483+ processor = LTX2AudioVideoAttnProcessor (),
465484 )
466485
467486 # 2. Prompt Cross-Attention
@@ -476,6 +495,8 @@ def __init__(
476495 out_bias = attention_out_bias ,
477496 qk_norm = qk_norm ,
478497 rope_type = rope_type ,
498+ apply_gated_attention = video_gated_attn ,
499+ processor = LTX2AudioVideoAttnProcessor (),
479500 )
480501
481502 self .audio_norm2 = RMSNorm (audio_dim , eps = eps , elementwise_affine = elementwise_affine )
@@ -489,6 +510,8 @@ def __init__(
489510 out_bias = attention_out_bias ,
490511 qk_norm = qk_norm ,
491512 rope_type = rope_type ,
513+ apply_gated_attention = audio_gated_attn ,
514+ processor = LTX2AudioVideoAttnProcessor (),
492515 )
493516
494517 # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -504,6 +527,8 @@ def __init__(
504527 out_bias = attention_out_bias ,
505528 qk_norm = qk_norm ,
506529 rope_type = rope_type ,
530+ apply_gated_attention = video_gated_attn ,
531+ processor = LTX2AudioVideoAttnProcessor (),
507532 )
508533
509534 # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
@@ -518,6 +543,8 @@ def __init__(
518543 out_bias = attention_out_bias ,
519544 qk_norm = qk_norm ,
520545 rope_type = rope_type ,
546+ apply_gated_attention = audio_gated_attn ,
547+ processor = LTX2AudioVideoAttnProcessor (),
521548 )
522549
523550 # 4. Feedforward layers
@@ -528,14 +555,37 @@ def __init__(
528555 self .audio_ff = FeedForward (audio_dim , activation_fn = activation_fn )
529556
530557 # 5. Per-Layer Modulation Parameters
531- # Self-Attention / Feedforward AdaLayerNorm-Zero mod params
532- self .scale_shift_table = nn .Parameter (torch .randn (6 , dim ) / dim ** 0.5 )
533- self .audio_scale_shift_table = nn .Parameter (torch .randn (6 , audio_dim ) / audio_dim ** 0.5 )
558+ # Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params
559+ # 6 base mod params for self-attn/FF; if cross_attn_adaln, 3 additional mod params for the text cross-attn Q
560+ self .video_cross_attn_adaln = video_cross_attn_adaln
561+ self .audio_cross_attn_adaln = audio_cross_attn_adaln
562+ video_mod_param_num = 9 if self .video_cross_attn_adaln else 6
563+ audio_mod_param_num = 9 if self .audio_cross_attn_adaln else 6
564+ self .scale_shift_table = nn .Parameter (torch .randn (video_mod_param_num , dim ) / dim ** 0.5 )
565+ self .audio_scale_shift_table = nn .Parameter (torch .randn (audio_mod_param_num , audio_dim ) / audio_dim ** 0.5 )
566+
567+ # Prompt cross-attn (attn2) additional modulation params
568+ self .cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln
569+ if self .cross_attn_adaln :
570+ self .prompt_scale_shift_table = nn .Parameter (torch .randn (2 , dim ))
571+ self .audio_prompt_scale_shift_table = nn .Parameter (torch .randn (2 , dim ))
534572
535573 # Per-layer a2v, v2a Cross-Attention mod params
536574 self .video_a2v_cross_attn_scale_shift_table = nn .Parameter (torch .randn (5 , dim ))
537575 self .audio_a2v_cross_attn_scale_shift_table = nn .Parameter (torch .randn (5 , audio_dim ))
538576
577+ @staticmethod
578+ def get_mod_params (
579+ scale_shift_table : torch .Tensor , temb : torch .Tensor , batch_size : int
580+ ) -> tuple [torch .Tensor , ...]:
581+ num_ada_params = scale_shift_table .shape [0 ]
582+ ada_values = (
583+ scale_shift_table [None , None ].to (temb .device )
584+ + temb .reshape (batch_size , temb .shape [1 ], num_ada_params , - 1 )
585+ )
586+ ada_params = ada_values .unbind (dim = 2 )
587+ return ada_params
588+
539589 def forward (
540590 self ,
541591 hidden_states : torch .Tensor ,
@@ -548,6 +598,8 @@ def forward(
548598 temb_ca_audio_scale_shift : torch .Tensor ,
549599 temb_ca_gate : torch .Tensor ,
550600 temb_ca_audio_gate : torch .Tensor ,
601+ temb_prompt : torch .Tensor | None = None ,
602+ temb_prompt_audio : torch .Tensor | None = None ,
551603 video_rotary_emb : tuple [torch .Tensor , torch .Tensor ] | None = None ,
552604 audio_rotary_emb : tuple [torch .Tensor , torch .Tensor ] | None = None ,
553605 ca_video_rotary_emb : tuple [torch .Tensor , torch .Tensor ] | None = None ,
@@ -560,13 +612,13 @@ def forward(
560612 batch_size = hidden_states .size (0 )
561613
562614 # 1. Video and Audio Self-Attention
563- norm_hidden_states = self .norm1 (hidden_states )
615+ # 1.1. Video Self-Attention
616+ video_ada_params = self .get_mod_params (self .scale_shift_table , temb , batch_size )
617+ shift_msa , scale_msa , gate_msa , shift_mlp , scale_mlp , gate_mlp = video_ada_params [:6 ]
618+ if self .video_cross_attn_adaln :
619+ shift_text_q , scale_text_q , gate_text_q = video_ada_params [6 :9 ]
564620
565- num_ada_params = self .scale_shift_table .shape [0 ]
566- ada_values = self .scale_shift_table [None , None ].to (temb .device ) + temb .reshape (
567- batch_size , temb .size (1 ), num_ada_params , - 1
568- )
569- shift_msa , scale_msa , gate_msa , shift_mlp , scale_mlp , gate_mlp = ada_values .unbind (dim = 2 )
621+ norm_hidden_states = self .norm1 (hidden_states )
570622 norm_hidden_states = norm_hidden_states * (1 + scale_msa ) + shift_msa
571623
572624 attn_hidden_states = self .attn1 (
@@ -576,15 +628,15 @@ def forward(
576628 )
577629 hidden_states = hidden_states + attn_hidden_states * gate_msa
578630
579- norm_audio_hidden_states = self .audio_norm1 (audio_hidden_states )
580-
581- num_audio_ada_params = self .audio_scale_shift_table .shape [0 ]
582- audio_ada_values = self .audio_scale_shift_table [None , None ].to (temb_audio .device ) + temb_audio .reshape (
583- batch_size , temb_audio .size (1 ), num_audio_ada_params , - 1
584- )
631+ # 1.2. Audio Self-Attention
632+ audio_ada_params = self .get_mod_params (self .audio_scale_shift_table , temb_audio , batch_size )
585633 audio_shift_msa , audio_scale_msa , audio_gate_msa , audio_shift_mlp , audio_scale_mlp , audio_gate_mlp = (
586- audio_ada_values . unbind ( dim = 2 )
634+ audio_ada_params [: 6 ]
587635 )
636+ if self .audio_cross_attn_adaln :
637+ audio_shift_text_q , audio_scale_text_q , audio_gate_text_q = audio_ada_params [6 :9 ]
638+
639+ norm_audio_hidden_states = self .audio_norm1 (audio_hidden_states )
588640 norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa ) + audio_shift_msa
589641
590642 attn_audio_hidden_states = self .audio_attn1 (
@@ -594,63 +646,74 @@ def forward(
594646 )
595647 audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
596648
597- # 2. Video and Audio Cross-Attention with the text embeddings
649+ # 2. Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text)
650+ if self .cross_attn_adaln :
651+ video_prompt_ada_params = self .get_mod_params (self .prompt_scale_shift_table , temb_prompt , batch_size )
652+ shift_text_kv , scale_text_kv = video_prompt_ada_params
653+
654+ audio_prompt_ada_params = self .get_mod_params (self .audio_prompt_scale_shift_table , temb_prompt_audio , batch_size )
655+ audio_shift_text_kv , audio_scale_text_kv = audio_prompt_ada_params
656+
657 # 2.1. Video-Text Cross-Attention (Q: Video; K,V: Text)
598658 norm_hidden_states = self .norm2 (hidden_states )
659+ if self .video_cross_attn_adaln :
660+ norm_hidden_states = norm_hidden_states * (1 + scale_text_q ) + shift_text_q
661+ if self .cross_attn_adaln :
662+ encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv ) + shift_text_kv
663+
599664 attn_hidden_states = self .attn2 (
600665 norm_hidden_states ,
601666 encoder_hidden_states = encoder_hidden_states ,
602667 query_rotary_emb = None ,
603668 attention_mask = encoder_attention_mask ,
604669 )
670+ if self .video_cross_attn_adaln :
671+ attn_hidden_states = attn_hidden_states * gate_text_q
605672 hidden_states = hidden_states + attn_hidden_states
606673
674+ # 2.2. Audio-Text Cross-Attention
607675 norm_audio_hidden_states = self .audio_norm2 (audio_hidden_states )
676+ if self .audio_cross_attn_adaln :
677+ norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q ) + audio_shift_text_q
678+ if self .cross_attn_adaln :
679+ audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv ) + audio_shift_text_kv
680+
608681 attn_audio_hidden_states = self .audio_attn2 (
609682 norm_audio_hidden_states ,
610683 encoder_hidden_states = audio_encoder_hidden_states ,
611684 query_rotary_emb = None ,
612685 attention_mask = audio_encoder_attention_mask ,
613686 )
687+ if self .audio_cross_attn_adaln :
688+ attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q
614689 audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
615690
616691 # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
617692 norm_hidden_states = self .audio_to_video_norm (hidden_states )
618693 norm_audio_hidden_states = self .video_to_audio_norm (audio_hidden_states )
619694
620- # Combine global and per-layer cross attention modulation parameters
695+ # 3.1. Combine global and per-layer cross attention modulation parameters
621696 # Video
622697 video_per_layer_ca_scale_shift = self .video_a2v_cross_attn_scale_shift_table [:4 , :]
623698 video_per_layer_ca_gate = self .video_a2v_cross_attn_scale_shift_table [4 :, :]
624699
625- video_ca_scale_shift_table = (
626- video_per_layer_ca_scale_shift [:, :, ...].to (temb_ca_scale_shift .dtype )
627- + temb_ca_scale_shift .reshape (batch_size , temb_ca_scale_shift .shape [1 ], 4 , - 1 )
628- ).unbind (dim = 2 )
629- video_ca_gate = (
630- video_per_layer_ca_gate [:, :, ...].to (temb_ca_gate .dtype )
631- + temb_ca_gate .reshape (batch_size , temb_ca_gate .shape [1 ], 1 , - 1 )
632- ).unbind (dim = 2 )
700+ video_ca_ada_params = self .get_mod_params (video_per_layer_ca_scale_shift , temb_ca_scale_shift , batch_size )
701+ video_ca_gate_param = self .get_mod_params (video_per_layer_ca_gate , temb_ca_gate , batch_size )
633702
634- video_a2v_ca_scale , video_a2v_ca_shift , video_v2a_ca_scale , video_v2a_ca_shift = video_ca_scale_shift_table
635- a2v_gate = video_ca_gate [0 ].squeeze (2 )
703+ video_a2v_ca_scale , video_a2v_ca_shift , video_v2a_ca_scale , video_v2a_ca_shift = video_ca_ada_params
704+ a2v_gate = video_ca_gate_param [0 ].squeeze (2 )
636705
637706 # Audio
638707 audio_per_layer_ca_scale_shift = self .audio_a2v_cross_attn_scale_shift_table [:4 , :]
639708 audio_per_layer_ca_gate = self .audio_a2v_cross_attn_scale_shift_table [4 :, :]
640709
641- audio_ca_scale_shift_table = (
642- audio_per_layer_ca_scale_shift [:, :, ...].to (temb_ca_audio_scale_shift .dtype )
643- + temb_ca_audio_scale_shift .reshape (batch_size , temb_ca_audio_scale_shift .shape [1 ], 4 , - 1 )
644- ).unbind (dim = 2 )
645- audio_ca_gate = (
646- audio_per_layer_ca_gate [:, :, ...].to (temb_ca_audio_gate .dtype )
647- + temb_ca_audio_gate .reshape (batch_size , temb_ca_audio_gate .shape [1 ], 1 , - 1 )
648- ).unbind (dim = 2 )
710+ audio_ca_ada_params = self .get_mod_params (audio_per_layer_ca_scale_shift , temb_ca_audio_scale_shift , batch_size )
711+ audio_ca_gate_param = self .get_mod_params (audio_per_layer_ca_gate , temb_ca_audio_gate , batch_size )
649712
650- audio_a2v_ca_scale , audio_a2v_ca_shift , audio_v2a_ca_scale , audio_v2a_ca_shift = audio_ca_scale_shift_table
651- v2a_gate = audio_ca_gate [0 ].squeeze (2 )
713+ audio_a2v_ca_scale , audio_a2v_ca_shift , audio_v2a_ca_scale , audio_v2a_ca_shift = audio_ca_ada_params
714+ v2a_gate = audio_ca_gate_param [0 ].squeeze (2 )
652715
653- # Audio-to-Video Cross Attention: Q: Video; K,V: Audio
716+ # 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio
654717 mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale .squeeze (2 )) + video_a2v_ca_shift .squeeze (
655718 2
656719 )
@@ -668,7 +731,7 @@ def forward(
668731
669732 hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
670733
671- # Video-to-Audio Cross Attention: Q: Audio; K,V: Video
734+ # 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video
672735 mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale .squeeze (2 )) + video_v2a_ca_shift .squeeze (
673736 2
674737 )
@@ -1209,6 +1272,7 @@ def forward(
12091272 audio_timestep : torch .LongTensor | None = None ,
12101273 encoder_attention_mask : torch .Tensor | None = None ,
12111274 audio_encoder_attention_mask : torch .Tensor | None = None ,
1275+ self_attention_mask : torch .Tensor | None = None ,
12121276 num_frames : int | None = None ,
12131277 height : int | None = None ,
12141278 width : int | None = None ,
@@ -1241,6 +1305,8 @@ def forward(
12411305 Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
12421306 audio_encoder_attention_mask (`torch.Tensor`, *optional*):
12431307 Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling.
1308+ self_attention_mask (`torch.Tensor`, *optional*):
1309+ Optional multiplicative self-attention mask of shape `(batch_size, seq_len, seq_len)`.
12441310 num_frames (`int`, *optional*):
12451311 The number of latent video frames. Used if calculating the video coordinates for RoPE.
12461312 height (`int`, *optional*):
@@ -1281,6 +1347,18 @@ def forward(
12811347 audio_encoder_attention_mask = (1 - audio_encoder_attention_mask .to (audio_hidden_states .dtype )) * - 10000.0
12821348 audio_encoder_attention_mask = audio_encoder_attention_mask .unsqueeze (1 )
12831349
1350+ if self_attention_mask is not None and self_attention_mask .ndim == 3 :
1351+ # Convert to additive attention mask in log-space where 0 (masked) values get mapped to a large negative
1352+ # number and positive values are mapped to their logarithm.
1353+ dtype_finfo = torch .finfo (hidden_states .dtype )
1354+ additive_self_attn_mask = torch .full_like (self_attention_mask , dtype_finfo .min , dtype = hidden_states .dtype )
1355+ unmasked_entries = self_attention_mask > 0
1356+ if torch .any (unmasked_entries ):
1357+ additive_self_attn_mask [unmasked_entries ] = torch .log (
1358+ self_attention_mask [unmasked_entries ].clamp (min = dtype_finfo .tiny )
1359+ ).to (hidden_states .dtype )
1360+ self_attention_mask = additive_self_attn_mask .unsqueeze (1 ) # [batch_size, 1, seq_len, seq_len]
1361+
12841362 batch_size = hidden_states .size (0 )
12851363
12861364 # 1. Prepare RoPE positional embeddings