Commit 24fb03a

sayakpaul authored and yiyixuxu committed
Klein tests (#2)
* tests
* up
* tests
* up
1 parent ba3aaef commit 24fb03a

File tree

2 files changed: +197 −1 lines changed


src/diffusers/pipelines/flux2/pipeline_flux2_klein.py

Lines changed: 11 additions & 1 deletion
@@ -417,6 +417,7 @@ def encode_prompt(
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.Tensor] = None,
         max_sequence_length: int = 512,
+        text_encoder_out_layers: Tuple[int] = (9, 18, 27),
     ):
         device = device or self._execution_device

@@ -432,6 +433,7 @@ def encode_prompt(
                 prompt=prompt,
                 device=device,
                 max_sequence_length=max_sequence_length,
+                hidden_states_layers=text_encoder_out_layers
             )

         batch_size, seq_len, _ = prompt_embeds.shape
@@ -604,6 +606,7 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
+        text_encoder_out_layers: Tuple[int] = (9, 18, 27)
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -666,6 +669,8 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+            text_encoder_out_layers (`Tuple[int]`):
+                Layer indices to use in the `text_encoder` to derive the final prompt embeddings.

         Examples:

@@ -706,15 +711,20 @@ def __call__(
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
+            text_encoder_out_layers=text_encoder_out_layers
         )

         if self.do_classifier_free_guidance:
+            negative_prompt = ""
+            if prompt is not None and isinstance(prompt, list):
+                negative_prompt = [negative_prompt] * len(prompt)
             negative_prompt_embeds, negative_text_ids = self.encode_prompt(
-                prompt="",
+                prompt=negative_prompt,
                 prompt_embeds=None,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
+                text_encoder_out_layers=text_encoder_out_layers,
             )

         # 4. process images
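
The net effect of this change is that callers of Flux2KleinPipeline can choose which hidden layers of the Qwen3 text encoder are combined into the prompt embeddings, instead of always using the defaults. A minimal usage sketch follows; the checkpoint id and the generation settings are placeholders for illustration and are not taken from this commit:

import torch

from diffusers import Flux2KleinPipeline

# Placeholder checkpoint id; substitute a real Flux2 Klein checkpoint.
pipe = Flux2KleinPipeline.from_pretrained("<flux2-klein-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="a dog is dancing",
    num_inference_steps=28,
    guidance_scale=4.0,
    max_sequence_length=512,
    # New in this commit: defaults to (9, 18, 27); indices must be valid hidden-state
    # layers of the text encoder.
    text_encoder_out_layers=(9, 18, 27),
).images[0]
image.save("klein.png")
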
Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM

from diffusers import (
    AutoencoderKLFlux2,
    FlowMatchEulerDiscreteScheduler,
    Flux2KleinPipeline,
    Flux2Transformer2DModel,
)

from ...testing_utils import (
    torch_device,
)
from ..test_pipelines_common import (
    PipelineTesterMixin,
    check_qkv_fused_layers_exist,
)


class Flux2KleinPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = Flux2KleinPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
    batch_params = frozenset(["prompt"])

    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    supports_dduf = False

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        torch.manual_seed(0)
        transformer = Flux2Transformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=16,
            timestep_guidance_channels=256,
            axes_dims_rope=[4, 4, 4, 4],
            guidance_embeds=False,
        )

        # Create minimal Qwen3 config
        config = Qwen3Config(
            intermediate_size=16,
            hidden_size=16,
            num_hidden_layers=2,
            num_attention_heads=2,
            num_key_value_heads=2,
            vocab_size=151936,
            max_position_embeddings=512,
        )
        torch.manual_seed(0)
        text_encoder = Qwen3ForCausalLM(config)

        # Use a simple tokenizer for testing
        tokenizer = Qwen2TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")

        torch.manual_seed(0)
        vae = AutoencoderKLFlux2(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownEncoderBlock2D",),
            up_block_types=("UpDecoderBlock2D",),
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "a dog is dancing",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 4.0,
            "height": 8,
            "width": 8,
            "max_sequence_length": 64,
            "output_type": "np",
            "text_encoder_out_layers": (1,),
        }
        return inputs

    def test_fused_qkv_projections(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        original_image_slice = image[0, -3:, -3:, -1]

        pipe.transformer.fuse_qkv_projections()
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_fused = image[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

        self.assertTrue(
            np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
            ("Fusion of QKV projections shouldn't affect the outputs."),
        )
        self.assertTrue(
            np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
            ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
        )
        self.assertTrue(
            np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
            ("Original outputs should match when fused QKV projections are disabled."),
        )

    def test_image_output_shape(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32), (72, 57)]
        for height, width in height_width_pairs:
            expected_height = height - height % (pipe.vae_scale_factor * 2)
            expected_width = width - width % (pipe.vae_scale_factor * 2)

            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            self.assertEqual(
                (output_height, output_width),
                (expected_height, expected_width),
                f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
            )

    def test_image_input(self):
        device = "cpu"
        pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
        inputs = self.get_dummy_inputs(device)

        inputs["image"] = Image.new("RGB", (64, 64))
        image = pipe(**inputs).images.flatten()
        generated_slice = np.concatenate([image[:8], image[-8:]])
        # fmt: off
        expected_slice = np.array(
            [
                0.8255048, 0.66054785, 0.6643694, 0.67462724, 0.5494932, 0.3480271, 0.52535003, 0.44510138, 0.23549396, 0.21372932, 0.21166152, 0.63198495, 0.49942136, 0.39147034, 0.49156153, 0.3713916
            ]
        )
        # fmt: on
        assert np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4)

    @unittest.skip("Needs to be revisited")
    def test_encode_prompt_works_in_isolation(self):
        pass
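
One detail worth calling out: the dummy inputs pass text_encoder_out_layers=(1,) because the dummy Qwen3 encoder has only two hidden layers, so the pipeline defaults of (9, 18, 27) would be out of range. The layer-selection logic itself lives in the pipeline's prompt-encoding helper and is not part of this diff; the sketch below is only a rough illustration of the underlying idea (picking specific entries from the encoder's hidden_states), not the actual implementation. The embed_prompt helper and the concatenation step are assumptions made for illustration.

import torch
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM

# Tiny encoder mirroring get_dummy_components(); the tokenizer is the same tiny
# test checkpoint used in the test above.
config = Qwen3Config(
    intermediate_size=16,
    hidden_size=16,
    num_hidden_layers=2,
    num_attention_heads=2,
    num_key_value_heads=2,
    vocab_size=151936,
    max_position_embeddings=512,
)
text_encoder = Qwen3ForCausalLM(config)
tokenizer = Qwen2TokenizerFast.from_pretrained(
    "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
)

def embed_prompt(prompt, hidden_states_layers=(1,), max_sequence_length=64):
    # Illustrative helper (not the pipeline's code): tokenize, run the encoder with
    # output_hidden_states=True, and keep only the requested layers.
    tokens = tokenizer(prompt, truncation=True, max_length=max_sequence_length, return_tensors="pt")
    with torch.no_grad():
        out = text_encoder(
            input_ids=tokens.input_ids,
            attention_mask=tokens.attention_mask,
            output_hidden_states=True,
        )
    # out.hidden_states has num_hidden_layers + 1 entries (embeddings plus one per layer).
    # One plausible reduction is to concatenate the selected layers along the feature dim;
    # the real pipeline may combine them differently.
    picked = [out.hidden_states[i] for i in hidden_states_layers]
    return torch.cat(picked, dim=-1)

embeds = embed_prompt("a dog is dancing")
print(embeds.shape)  # (batch, seq_len, hidden_size * number of selected layers)
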
