Skip to content

Commit f167c54

Browse files
authored
[diffusion] model: support z-image (sgl-project#14067)
1 parent 0f6e1c7 commit f167c54

File tree

15 files changed

+1051
-8
lines changed

15 files changed

+1051
-8
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
2+
3+
# SPDX-License-Identifier: Apache-2.0
4+
from dataclasses import dataclass, field
5+
from typing import Tuple
6+
7+
from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig
8+
9+
10+
@dataclass
class ZImageArchConfig(DiTArchConfig):
    """Architecture hyperparameters for the Z-Image diffusion transformer.

    Defaults presumably mirror the Tongyi-MAI/Z-Image-Turbo checkpoint
    registered for this pipeline — TODO confirm against the released
    model config.
    """

    # Candidate spatial patch sizes; presumably one entry is selected
    # downstream — TODO confirm how multiple entries are consumed.
    all_patch_size: Tuple[int, ...] = (2,)
    # Candidate frame ("f") patch sizes; 1 => no temporal patching.
    all_f_patch_size: Tuple[int, ...] = (1,)
    in_channels: int = 16  # latent channels fed to the DiT
    # None means "same as in_channels" (resolved in __post_init__).
    out_channels: int | None = None
    dim: int = 3840  # transformer hidden width (aliased to hidden_size below)
    num_layers: int = 30
    n_refiner_layers: int = 2
    num_attention_heads: int = 30
    n_kv_heads: int = 30  # matches num_attention_heads (no KV-head reduction by default)
    norm_eps: float = 1e-5
    qk_norm: bool = True  # presumably enables query/key normalization — confirm in model code
    cap_feat_dim: int = 2560  # caption/text-embedding feature dimension
    rope_theta: float = 256.0
    t_scale: float = 1000.0  # timestep scaling factor
    # Per-axis RoPE dims and max lengths; three axes — presumably
    # (text/sequence, height, width) — TODO confirm axis meaning.
    axes_dims: Tuple[int, int, int] = (32, 48, 48)
    axes_lens: Tuple[int, int, int] = (1024, 512, 512)

    def __post_init__(self):
        """Resolve derived/aliased fields after dataclass initialization."""
        super().__post_init__()
        # Default the output channel count to the input channel count.
        self.out_channels = self.out_channels or self.in_channels
        # Aliases expected by the shared DiT config interface.
        self.num_channels_latents = self.in_channels
        self.hidden_size = self.dim
34+
35+
36+
@dataclass
class ZImageDitConfig(DiTConfig):
    """DiT model config for Z-Image, wiring in ZImageArchConfig defaults."""

    arch_config: ZImageArchConfig = field(default_factory=ZImageArchConfig)

    # Name prefix for this submodule — presumably matches the checkpoint's
    # module naming; TODO confirm.
    prefix: str = "zimage"

python/sglang/multimodal_gen/configs/pipeline_configs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
WanT2V480PConfig,
1818
WanT2V720PConfig,
1919
)
20+
from sglang.multimodal_gen.configs.pipeline_configs.zimage import ZImagePipelineConfig
2021

2122
__all__ = [
2223
"HunyuanConfig",
@@ -30,4 +31,5 @@
3031
"WanI2V720PConfig",
3132
"StepVideoT2VConfig",
3233
"SelfForcingWanT2V480PConfig",
34+
"ZImagePipelineConfig",
3335
]

python/sglang/multimodal_gen/configs/pipeline_configs/flux.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def _prepare_image_ids(
321321
return image_latent_ids
322322

323323

324-
def flux_2_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
324+
def flux2_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
325325
hidden_states_layers: list[int] = [10, 20, 30]
326326

327327
out = torch.stack([outputs.hidden_states[k] for k in hidden_states_layers], dim=1)
@@ -412,7 +412,7 @@ class Flux2PipelineConfig(FluxPipelineConfig):
412412
)
413413

414414
postprocess_text_funcs: tuple[Callable[[str], str], ...] = field(
415-
default_factory=lambda: (flux_2_postprocess_text,)
415+
default_factory=lambda: (flux2_postprocess_text,)
416416
)
417417
vae_config: VAEConfig = field(default_factory=Flux2VAEConfig)
418418

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
2+
from dataclasses import dataclass, field
3+
from typing import Callable
4+
5+
import torch
6+
7+
from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig
8+
from sglang.multimodal_gen.configs.models.dits.zimage import ZImageDitConfig
9+
from sglang.multimodal_gen.configs.models.encoders import (
10+
BaseEncoderOutput,
11+
TextEncoderConfig,
12+
)
13+
from sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig
14+
from sglang.multimodal_gen.configs.pipeline_configs.base import (
15+
ImagePipelineConfig,
16+
ModelTaskType,
17+
)
18+
19+
20+
def zimage_preprocess_text(prompt: str):
    """Wrap a raw prompt string into a single-turn chat message list.

    The result is the format expected by the tokenizer's chat template
    (see ZImagePipelineConfig.tokenize_prompt).
    """
    return [{"role": "user", "content": prompt}]
25+
26+
27+
def zimage_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
    """Extract unpadded prompt embeddings from the text-encoder output.

    Takes the second-to-last hidden-state layer and keeps only the tokens
    covered by the attention mask. Only batch index 0 is read — assumes a
    single prompt per call; TODO confirm against the caller.
    """
    hidden = outputs.hidden_states[-2]
    keep = _text_inputs.attention_mask.to(hidden.device).bool()
    return hidden[0][keep[0]]
31+
32+
33+
@dataclass
class TransformersModelConfig(EncoderConfig):
    """Encoder config carrying extra tokenizer kwargs for a Transformers model.

    Fix: the original declared ``field(default_factory=...)`` on a class that
    was missing the ``@dataclass`` decorator, so ``tokenizer_kwargs`` was a
    bare ``dataclasses.Field`` object at class level instead of a
    per-instance dict field. Every sibling config class in this module is
    decorated; this one now is too.
    """

    # Extra kwargs forwarded to the tokenizer; a fresh dict per instance.
    tokenizer_kwargs: dict = field(default_factory=dict)
35+
36+
37+
@dataclass
class ZImagePipelineConfig(ImagePipelineConfig):
    """Pipeline configuration for Z-Image text-to-image generation.

    Wires together a single chat-style text encoder, the Flux VAE, and the
    Z-Image DiT backbone; runs without classifier-free guidance.
    """

    should_use_guidance: bool = False
    task_type: ModelTaskType = ModelTaskType.T2I

    dit_config: DiTConfig = field(default_factory=ZImageDitConfig)
    vae_config: VAEConfig = field(default_factory=FluxVAEConfig)
    text_encoder_configs: tuple[EncoderConfig, ...] = field(
        default_factory=lambda: (TextEncoderConfig(),)
    )

    preprocess_text_funcs: tuple[Callable, ...] = field(
        default_factory=lambda: (zimage_preprocess_text,)
    )

    postprocess_text_funcs: tuple[Callable, ...] = field(
        default_factory=lambda: (zimage_postprocess_text,)
    )

    def tokenize_prompt(self, prompts: list[str], tokenizer, tok_kwargs) -> dict:
        """Tokenize chat-formatted prompts via the tokenizer's chat template.

        Returns the tokenizer's dict of padded/truncated PyTorch tensors.

        NOTE(review): ``tok_kwargs`` is accepted but never used — confirm
        this is intentional and not silently dropping caller-supplied
        tokenizer options.
        """
        return tokenizer.apply_chat_template(
            prompts,
            tokenize=True,
            add_generation_prompt=True,
            enable_thinking=True,
            padding="max_length",
            max_length=512,  # TODO (yhyang201): set max length according to config
            truncation=True,
            return_tensors="pt",
            return_dict=True,
        )

    def post_denoising_loop(self, latents, batch):
        """Drop the singleton frame axis from (B, C, T, H, W) latents.

        Z-Image produces single images, so T is expected to be 1; ``view``
        raises if the element counts disagree, which doubles as a shape
        sanity check.
        """
        batch_size, channels, _num_frames, height, width = latents.shape
        return latents.view(batch_size, channels, height, width)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
2+
3+
# SPDX-License-Identifier: Apache-2.0
4+
from dataclasses import dataclass, field
5+
6+
from sglang.multimodal_gen.configs.sample.base import SamplingParams
7+
from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams
8+
9+
10+
@dataclass
class ZImageSamplingParams(SamplingParams):
    """Default sampling parameters for Z-Image text-to-image generation."""

    # Few-step sampling — presumably because the Turbo model is distilled;
    # TODO confirm the recommended step count.
    num_inference_steps: int = 9

    num_frames: int = 1  # single frame: image generation
    height: int = 720
    width: int = 1280
    fps: int = 24  # irrelevant for a single frame; kept for interface parity

    # No classifier-free guidance (consistent with the pipeline's
    # should_use_guidance=False).
    guidance_scale: float = 0.0

    # TeaCache acceleration settings; the coefficients look like a fitted
    # rescaling polynomial — presumably calibrated offline for Z-Image.
    # TODO confirm their source.
    teacache_params: TeaCacheParams = field(
        default_factory=lambda: TeaCacheParams(
            teacache_thresh=0.15,
            coefficients=[
                7.33226126e02,
                -4.01131952e02,
                6.75869174e01,
                -3.14987800e00,
                9.61237896e-02,
            ],
        )
    )

python/sglang/multimodal_gen/registry.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
WanI2V720PConfig,
2525
WanT2V480PConfig,
2626
WanT2V720PConfig,
27+
ZImagePipelineConfig,
2728
)
2829
from sglang.multimodal_gen.configs.pipeline_configs.base import PipelineConfig
2930
from sglang.multimodal_gen.configs.pipeline_configs.flux import Flux2PipelineConfig
@@ -56,6 +57,7 @@
5657
WanT2V_1_3B_SamplingParams,
5758
WanT2V_14B_SamplingParams,
5859
)
60+
from sglang.multimodal_gen.configs.sample.zimage import ZImageSamplingParams
5961
from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (
6062
ComposedPipelineBase,
6163
)
@@ -404,6 +406,15 @@ def _register_configs():
404406
],
405407
model_detectors=[lambda id: "flux.2" in id.lower()],
406408
)
409+
register_configs(
410+
model_name="Z-image",
411+
sampling_param_cls=ZImageSamplingParams,
412+
pipeline_config_cls=ZImagePipelineConfig,
413+
model_paths=[
414+
"Tongyi-MAI/Z-Image-Turbo",
415+
],
416+
model_detectors=[lambda id: "z-image" in id.lower()],
417+
)
407418

408419
# Qwen-Image
409420
register_configs(

0 commit comments

Comments
 (0)