@@ -217,14 +217,113 @@ def __call__(
217217 return hidden_states
218218
219219
class LTX2PerturbedAttnProcessor:
    r"""
    Attention processor for LTX-2.X models supporting perturbation masking and optional per-head
    output gating.

    When every position is perturbed, the attention computation is skipped entirely and the value
    projection is returned directly; otherwise a standard scaled-dot-product attention is run and
    blended with the value projection according to ``perturbation_mask``.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        # Relies on torch.nn.functional.scaled_dot_product_attention, introduced in PyTorch 2.0.
        if is_torch_version("<", "2.0"):
            raise ValueError(
                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
            )

    def __call__(
        self,
        attn: "LTX2Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        perturbation_mask: torch.Tensor | None = None,
        all_perturbed: bool | None = None,
    ) -> torch.Tensor:
        # The key/value source determines the mask geometry (self- vs cross-attention).
        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        # Gate logits are computed from the *unmodified* query-side hidden states.
        gate_logits = attn.to_gate_logits(hidden_states) if attn.to_gate_logits is not None else None

        value = attn.to_v(encoder_hidden_states)

        # Derive the all-perturbed flag from the mask when the caller did not supply it.
        if all_perturbed is None:
            all_perturbed = False if perturbation_mask is None else torch.all(perturbation_mask == 0)

        if all_perturbed:
            # Every position is perturbed: bypass attention and pass the value projection through.
            hidden_states = value
        else:
            query = attn.to_q(hidden_states)
            key = attn.to_k(encoder_hidden_states)

            query = attn.norm_q(query)
            key = attn.norm_k(key)

            if query_rotary_emb is not None:
                # Cross-attention may carry distinct key-side rotary embeddings; fall back to the
                # query-side embeddings otherwise.
                effective_key_emb = key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
                if attn.rope_type == "interleaved":
                    query = apply_interleaved_rotary_emb(query, query_rotary_emb)
                    key = apply_interleaved_rotary_emb(key, effective_key_emb)
                elif attn.rope_type == "split":
                    query = apply_split_rotary_emb(query, query_rotary_emb)
                    key = apply_split_rotary_emb(key, effective_key_emb)

            # [B, S, C] -> [B, S, H, D]
            query = query.unflatten(2, (attn.heads, -1))
            key = key.unflatten(2, (attn.heads, -1))
            value = value.unflatten(2, (attn.heads, -1))

            hidden_states = dispatch_attention_fn(
                query,
                key,
                value,
                attn_mask=attention_mask,
                dropout_p=0.0,
                is_causal=False,
                backend=self._attention_backend,
                parallel_config=self._parallel_config,
            )
            hidden_states = hidden_states.flatten(2, 3)
            hidden_states = hidden_states.to(query.dtype)

            if perturbation_mask is not None:
                # Interpolate between the raw value projection (mask == 0, perturbed) and the
                # attention output (mask == 1, unperturbed).
                value = value.flatten(2, 3)
                hidden_states = torch.lerp(value, hidden_states, perturbation_mask)

        if gate_logits is not None:
            hidden_states = hidden_states.unflatten(2, (attn.heads, -1))  # [B, T, H, D]
            # The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
            gates = 2.0 * torch.sigmoid(gate_logits)  # [B, T, H]
            hidden_states = hidden_states * gates.unsqueeze(-1)
            hidden_states = hidden_states.flatten(2, 3)

        # Output projection followed by dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
318+
220319class LTX2Attention (torch .nn .Module , AttentionModuleMixin ):
221320 r"""
222321 Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key
223322 RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention.
224323 """
225324
226325 _default_processor_cls = LTX2AudioVideoAttnProcessor
227- _available_processors = [LTX2AudioVideoAttnProcessor ]
326+ _available_processors = [LTX2AudioVideoAttnProcessor , LTX2PerturbedAttnProcessor ]
228327
229328 def __init__ (
230329 self ,
@@ -240,6 +339,7 @@ def __init__(
240339 norm_eps : float = 1e-6 ,
241340 norm_elementwise_affine : bool = True ,
242341 rope_type : str = "interleaved" ,
342+ apply_gated_attention : bool = False ,
243343 processor = None ,
244344 ):
245345 super ().__init__ ()
@@ -266,6 +366,12 @@ def __init__(
266366 self .to_out .append (torch .nn .Linear (self .inner_dim , self .out_dim , bias = out_bias ))
267367 self .to_out .append (torch .nn .Dropout (dropout ))
268368
369+ if apply_gated_attention :
370+ # Per head gate values
371+ self .to_gate_logits = torch .nn .Linear (query_dim , heads , bias = True )
372+ else :
373+ self .to_gate_logits = None
374+
269375 if processor is None :
270376 processor = self ._default_processor_cls ()
271377 self .set_processor (processor )