Skip to content

Commit 985d83c

Browse files
authored
Fix LTX-2 Inference when num_videos_per_prompt > 1 and CFG is Enabled (#13121)
Fix LTX-2 inference when num_videos_per_prompt > 1 and CFG is enabled
1 parent ed77a24 commit 985d83c

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

src/diffusers/models/transformers/transformer_ltx2.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,8 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten
5656
x_dtype = x.dtype
5757
needs_reshape = False
5858
if x.ndim != 4 and cos.ndim == 4:
59-
# cos is (#b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
60-
# The cos/sin batch dim may only be broadcastable, so take batch size from x
61-
b = x.shape[0]
62-
_, h, t, _ = cos.shape
59+
# cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
60+
b, h, t, _ = cos.shape
6361
x = x.reshape(b, t, h, -1).swapaxes(1, 2)
6462
needs_reshape = True
6563

src/diffusers/pipelines/ltx2/pipeline_ltx2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,10 @@ def __call__(
10811081
audio_coords = self.transformer.audio_rope.prepare_audio_coords(
10821082
audio_latents.shape[0], audio_num_frames, audio_latents.device
10831083
)
1084+
# Duplicate the positional ids as well if using CFG
1085+
if self.do_classifier_free_guidance:
1086+
video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim
1087+
audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1))
10841088

10851089
# 7. Denoising loop
10861090
with self.progress_bar(total=num_inference_steps) as progress_bar:

src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,10 @@ def __call__(
11391139
audio_coords = self.transformer.audio_rope.prepare_audio_coords(
11401140
audio_latents.shape[0], audio_num_frames, audio_latents.device
11411141
)
1142+
# Duplicate the positional ids as well if using CFG
1143+
if self.do_classifier_free_guidance:
1144+
video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim
1145+
audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1))
11421146

11431147
# 7. Denoising loop
11441148
with self.progress_bar(total=num_inference_steps) as progress_bar:

0 commit comments

Comments
 (0)