Skip to content

Commit 6c7e720

Browse files
committed
Initial implementation of perturbed attn processor for LTX 2.3
1 parent 8ec0a5c commit 6c7e720

File tree

1 file changed

+107
-1
lines changed

1 file changed

+107
-1
lines changed

src/diffusers/models/transformers/transformer_ltx2.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,14 +217,113 @@ def __call__(
217217
return hidden_states
218218

219219

220+
class LTX2PerturbedAttnProcessor:
    r"""
    Processor which implements attention with perturbation masking and per-head gating for LTX-2.X models.

    Compared to the default LTX-2 processor, this one additionally supports:

    - `perturbation_mask`: per-position interpolation (via `torch.lerp`) between the plain value
      projection (mask == 0) and the full attention output (mask == 1).
    - `all_perturbed`: when every position is perturbed, the attention computation is skipped entirely
      and the value projection is used directly.
    - Optional per-head output gating when the attention module defines `to_gate_logits`.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        # The attention dispatch relies on PyTorch 2.0+ scaled-dot-product attention.
        if is_torch_version("<", "2.0"):
            raise ValueError(
                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
            )

    def __call__(
        self,
        attn: "LTX2Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        perturbation_mask: torch.Tensor | None = None,
        all_perturbed: bool | None = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        if attn.to_gate_logits is not None:
            # Calculate gate logits on original hidden_states (before it is overwritten below).
            gate_logits = attn.to_gate_logits(hidden_states)

        value = attn.to_v(encoder_hidden_states)
        if all_perturbed is None:
            # Infer "everything perturbed" from the mask; with no mask, nothing is perturbed.
            all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False

        if all_perturbed:
            # Skip attention, use the value projection value
            hidden_states = value
        else:
            query = attn.to_q(hidden_states)
            key = attn.to_k(encoder_hidden_states)

            query = attn.norm_q(query)
            key = attn.norm_k(key)

            if query_rotary_emb is not None:
                # Cross-attention (a2v / v2a) may supply a separate key RoPE; otherwise the
                # query RoPE is reused for the keys.
                if attn.rope_type == "interleaved":
                    query = apply_interleaved_rotary_emb(query, query_rotary_emb)
                    key = apply_interleaved_rotary_emb(
                        key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
                    )
                elif attn.rope_type == "split":
                    query = apply_split_rotary_emb(query, query_rotary_emb)
                    key = apply_split_rotary_emb(
                        key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
                    )

            # [B, S, inner_dim] -> [B, S, H, D]
            query = query.unflatten(2, (attn.heads, -1))
            key = key.unflatten(2, (attn.heads, -1))
            value = value.unflatten(2, (attn.heads, -1))

            hidden_states = dispatch_attention_fn(
                query,
                key,
                value,
                attn_mask=attention_mask,
                dropout_p=0.0,
                is_causal=False,
                backend=self._attention_backend,
                parallel_config=self._parallel_config,
            )
            hidden_states = hidden_states.flatten(2, 3)
            hidden_states = hidden_states.to(query.dtype)

            if perturbation_mask is not None:
                # Blend per position: mask == 1 keeps the attention output, mask == 0 falls
                # back to the plain value projection.
                value = value.flatten(2, 3)
                hidden_states = torch.lerp(value, hidden_states, perturbation_mask)

        if attn.to_gate_logits is not None:
            hidden_states = hidden_states.unflatten(2, (attn.heads, -1))  # [B, T, H, D]
            # The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
            gates = 2.0 * torch.sigmoid(gate_logits)  # [B, T, H]
            hidden_states = hidden_states * gates.unsqueeze(-1)
            hidden_states = hidden_states.flatten(2, 3)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
317+
318+
220319
class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
221320
r"""
222321
Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key
223322
RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention.
224323
"""
225324

226325
_default_processor_cls = LTX2AudioVideoAttnProcessor
227-
_available_processors = [LTX2AudioVideoAttnProcessor]
326+
_available_processors = [LTX2AudioVideoAttnProcessor, LTX2PerturbedAttnProcessor]
228327

229328
def __init__(
230329
self,
@@ -240,6 +339,7 @@ def __init__(
240339
norm_eps: float = 1e-6,
241340
norm_elementwise_affine: bool = True,
242341
rope_type: str = "interleaved",
342+
apply_gated_attention: bool = False,
243343
processor=None,
244344
):
245345
super().__init__()
@@ -266,6 +366,12 @@ def __init__(
266366
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
267367
self.to_out.append(torch.nn.Dropout(dropout))
268368

369+
if apply_gated_attention:
370+
# Per head gate values
371+
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
372+
else:
373+
self.to_gate_logits = None
374+
269375
if processor is None:
270376
processor = self._default_processor_cls()
271377
self.set_processor(processor)

0 commit comments

Comments
 (0)