import unittest

import numpy as np
import torch
from PIL import Image
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM

from diffusers import (
    AutoencoderKLFlux2,
    FlowMatchEulerDiscreteScheduler,
    Flux2KleinPipeline,
    Flux2Transformer2DModel,
)

from ...testing_utils import torch_device
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist


class Flux2KleinPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = Flux2KleinPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
    batch_params = frozenset(["prompt"])

    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    supports_dduf = False
    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        torch.manual_seed(0)
        transformer = Flux2Transformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=16,
            timestep_guidance_channels=256,
            axes_dims_rope=[4, 4, 4, 4],
            guidance_embeds=False,
        )

        # Create minimal Qwen3 config
        config = Qwen3Config(
            intermediate_size=16,
            hidden_size=16,
            num_hidden_layers=2,
            num_attention_heads=2,
            num_key_value_heads=2,
            vocab_size=151936,
            max_position_embeddings=512,
        )
        torch.manual_seed(0)
        text_encoder = Qwen3ForCausalLM(config)

        # Use a simple tokenizer for testing
        tokenizer = Qwen2TokenizerFast.from_pretrained(
            "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
        )

        torch.manual_seed(0)
        vae = AutoencoderKLFlux2(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownEncoderBlock2D",),
            up_block_types=("UpDecoderBlock2D",),
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
        }

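    # The generator is created on CPU for reproducibility; the MPS backend does
    # not reliably support device-local generators, so the global seed is used
    # there instead.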
    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "a dog is dancing",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 4.0,
            "height": 8,
            "width": 8,
            "max_sequence_length": 64,
            "output_type": "np",
            "text_encoder_out_layers": (1,),
        }
        return inputs

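    # QKV fusion folds the separate query/key/value projections into a single
    # `to_qkv` linear layer. It is a pure re-parameterization of the attention
    # computation, so outputs before fusion, after fusion, and after unfusing
    # should all agree within tolerance.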
    def test_fused_qkv_projections(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        original_image_slice = image[0, -3:, -3:, -1]

        pipe.transformer.fuse_qkv_projections()
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            "Something is wrong with the fused attention layers. Expected all the attention projections to be fused.",
        )

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_fused = image[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

        self.assertTrue(
            np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
            "Fusing the QKV projections shouldn't affect the outputs.",
        )
        self.assertTrue(
            np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
            "Outputs with QKV projection fusion enabled shouldn't change when the fused projections are disabled.",
        )
        self.assertTrue(
            np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
            "Outputs after disabling the fused QKV projections should match the original outputs.",
        )

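    # The pipeline snaps the requested resolution down to the nearest multiple of
    # `pipe.vae_scale_factor * 2`. As a hypothetical example: if the factor were 2,
    # a requested 72x57 image would come back as 72x56, since 57 % 4 == 1. The
    # actual factor depends on the dummy VAE configured above.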
    def test_image_output_shape(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32), (72, 57)]
        for height, width in height_width_pairs:
            expected_height = height - height % (pipe.vae_scale_factor * 2)
            expected_width = width - width % (pipe.vae_scale_factor * 2)

            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            self.assertEqual(
                (output_height, output_width),
                (expected_height, expected_width),
                f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
            )

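    # Supplying an `image` input exercises the image-conditioning path; the
    # generated slice is checked against reference values from a known-good run.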
    def test_image_input(self):
        device = "cpu"
        pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
        inputs = self.get_dummy_inputs(device)

        inputs["image"] = Image.new("RGB", (64, 64))
        image = pipe(**inputs).images.flatten()
        generated_slice = np.concatenate([image[:8], image[-8:]])
        # fmt: off
        expected_slice = np.array(
            [0.8255048, 0.66054785, 0.6643694, 0.67462724, 0.5494932, 0.3480271, 0.52535003, 0.44510138, 0.23549396, 0.21372932, 0.21166152, 0.63198495, 0.49942136, 0.39147034, 0.49156153, 0.3713916]
        )
        # fmt: on
        assert np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4)

    @unittest.skip("Needs to be revisited")
    def test_encode_prompt_works_in_isolation(self):
        pass