Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions python/sglang/multimodal_gen/configs/models/dits/zimage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Tuple

from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig


@dataclass
class ZImageArchConfig(DiTArchConfig):
    """Architecture hyperparameters for the Z-Image DiT backbone.

    Values mirror the upstream FastVideo Z-Image configuration. Uses builtin
    generic syntax (``tuple[...]``) consistently with the ``int | None``
    annotation already used below, instead of mixing in ``typing.Tuple``.
    """

    # Supported spatial / frame patchify sizes.
    all_patch_size: tuple[int, ...] = (2,)
    all_f_patch_size: tuple[int, ...] = (1,)
    in_channels: int = 16
    # Defaults to ``in_channels`` when left as None (resolved in __post_init__).
    out_channels: int | None = None
    dim: int = 3840
    num_layers: int = 30
    n_refiner_layers: int = 2
    num_attention_heads: int = 30
    n_kv_heads: int = 30
    norm_eps: float = 1e-5
    qk_norm: bool = True
    cap_feat_dim: int = 2560
    rope_theta: float = 256.0
    t_scale: float = 1000.0
    # Per-axis RoPE dimensions and maximum positional lengths.
    axes_dims: tuple[int, int, int] = (32, 48, 48)
    axes_lens: tuple[int, int, int] = (1024, 512, 512)

    def __post_init__(self) -> None:
        super().__post_init__()
        # Fall back to the input channel count when no explicit output size is set.
        self.out_channels = self.out_channels or self.in_channels
        self.num_channels_latents = self.in_channels
        self.hidden_size = self.dim


@dataclass
class ZImageDitConfig(DiTConfig):
    """Concrete DiT config wiring in the Z-Image architecture defaults."""

    arch_config: ZImageArchConfig = field(default_factory=ZImageArchConfig)

    # Name prefix identifying this transformer's weights/submodule
    # (presumably used for checkpoint key matching — confirm against DiTConfig).
    prefix: str = "zimage"
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
WanT2V480PConfig,
WanT2V720PConfig,
)
from sglang.multimodal_gen.configs.pipeline_configs.zimage import ZImagePipelineConfig

__all__ = [
"HunyuanConfig",
Expand All @@ -30,4 +31,5 @@
"WanI2V720PConfig",
"StepVideoT2VConfig",
"SelfForcingWanT2V480PConfig",
"ZImagePipelineConfig",
]
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def _prepare_image_ids(
return image_latent_ids


def flux_2_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
def flux2_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
hidden_states_layers: list[int] = [10, 20, 30]

out = torch.stack([outputs.hidden_states[k] for k in hidden_states_layers], dim=1)
Expand Down Expand Up @@ -412,7 +412,7 @@ class Flux2PipelineConfig(FluxPipelineConfig):
)

postprocess_text_funcs: tuple[Callable[[str], str], ...] = field(
default_factory=lambda: (flux_2_postprocess_text,)
default_factory=lambda: (flux2_postprocess_text,)
)
vae_config: VAEConfig = field(default_factory=Flux2VAEConfig)

Expand Down
74 changes: 74 additions & 0 deletions python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from dataclasses import dataclass, field
from typing import Callable

import torch

from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig
from sglang.multimodal_gen.configs.models.dits.zimage import ZImageDitConfig
from sglang.multimodal_gen.configs.models.encoders import (
BaseEncoderOutput,
TextEncoderConfig,
)
from sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig
from sglang.multimodal_gen.configs.pipeline_configs.base import (
ImagePipelineConfig,
ModelTaskType,
)


def zimage_preprocess_text(prompt: str):
    """Wrap a raw prompt string into a single-turn chat message list.

    The result is shaped for ``tokenizer.apply_chat_template``.
    """
    return [{"role": "user", "content": prompt}]


def zimage_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
    """Select the penultimate hidden layer and keep only non-padded tokens.

    NOTE(review): only batch element 0 is used — this assumes a single prompt
    per call; confirm against the caller.
    """
    penultimate = outputs.hidden_states[-2]
    mask = _text_inputs.attention_mask.to(penultimate.device).bool()
    return penultimate[0][mask[0]]


@dataclass
class TransformersModelConfig(EncoderConfig):
    """Encoder config for models loaded via HF ``transformers``.

    Bug fix: the ``@dataclass`` decorator was missing, so the ``field(...)``
    call was left as a raw ``dataclasses.Field`` object on the class instead
    of becoming a per-instance ``dict`` default.
    """

    # Extra keyword arguments forwarded to the tokenizer constructor.
    tokenizer_kwargs: dict = field(default_factory=dict)


@dataclass
class ZImagePipelineConfig(ImagePipelineConfig):
    """Pipeline configuration for Z-Image text-to-image generation.

    Fix: removes GitHub review-comment text that was pasted into the middle
    of ``tokenize_prompt`` and made the class syntactically invalid.
    """

    # Classifier-free guidance is disabled for this pipeline.
    should_use_guidance: bool = False
    task_type: ModelTaskType = ModelTaskType.T2I

    dit_config: DiTConfig = field(default_factory=ZImageDitConfig)
    vae_config: VAEConfig = field(default_factory=FluxVAEConfig)
    text_encoder_configs: tuple[EncoderConfig, ...] = field(
        default_factory=lambda: (TextEncoderConfig(),)
    )

    preprocess_text_funcs: tuple[Callable, ...] = field(
        default_factory=lambda: (zimage_preprocess_text,)
    )

    postprocess_text_funcs: tuple[Callable, ...] = field(
        default_factory=lambda: (zimage_postprocess_text,)
    )

    def tokenize_prompt(self, prompts: list[str], tokenizer, tok_kwargs) -> dict:
        """Tokenize chat-formatted prompts via the model's chat template.

        Returns the tokenizer output dict (``input_ids``, ``attention_mask``,
        ...) as ``pt`` tensors, padded/truncated to a fixed length.
        """
        inputs = tokenizer.apply_chat_template(
            prompts,
            tokenize=True,
            add_generation_prompt=True,
            enable_thinking=True,
            padding="max_length",
            # TODO (yhyang201): make max_length configurable instead of
            # hardcoding it here.
            max_length=512,
            truncation=True,
            return_tensors="pt",
            return_dict=True,
        )
        return inputs

    def post_denoising_loop(self, latents, batch):
        """Drop the singleton frame axis: (B, C, T, H, W) -> (B, C, H, W)."""
        bs, channels, num_frames, height, width = latents.shape
        # view() would raise an opaque size-mismatch error for multi-frame
        # input; fail loudly with a clear message instead.
        if num_frames != 1:
            raise ValueError(f"expected a single frame, got {num_frames}")
        return latents.view(bs, channels, height, width)
32 changes: 32 additions & 0 deletions python/sglang/multimodal_gen/configs/sample/zimage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

from sglang.multimodal_gen.configs.sample.base import SamplingParams
from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams


@dataclass
class ZImageSamplingParams(SamplingParams):
    """Default sampling parameters for Z-Image generation."""

    # Low step count — presumably suited to the few-step "Turbo" checkpoint
    # this config is registered for; confirm against the model card.
    num_inference_steps: int = 9

    # Single-frame (still image) output at 720x1280; fps is likely unused for
    # stills but kept for interface compatibility — confirm vs SamplingParams.
    num_frames: int = 1
    height: int = 720
    width: int = 1280
    fps: int = 24

    # No guidance by default — confirm this matches the pipeline's
    # should_use_guidance setting.
    guidance_scale: float = 0.0

    # TeaCache acceleration defaults (threshold + polynomial coefficients),
    # copied from the upstream FastVideo configuration.
    teacache_params: TeaCacheParams = field(
        default_factory=lambda: TeaCacheParams(
            teacache_thresh=0.15,
            coefficients=[
                7.33226126e02,
                -4.01131952e02,
                6.75869174e01,
                -3.14987800e00,
                9.61237896e-02,
            ],
        )
    )
11 changes: 11 additions & 0 deletions python/sglang/multimodal_gen/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
WanI2V720PConfig,
WanT2V480PConfig,
WanT2V720PConfig,
ZImagePipelineConfig,
)
from sglang.multimodal_gen.configs.pipeline_configs.base import PipelineConfig
from sglang.multimodal_gen.configs.pipeline_configs.flux import Flux2PipelineConfig
Expand Down Expand Up @@ -56,6 +57,7 @@
WanT2V_1_3B_SamplingParams,
WanT2V_14B_SamplingParams,
)
from sglang.multimodal_gen.configs.sample.zimage import ZImageSamplingParams
from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (
ComposedPipelineBase,
)
Expand Down Expand Up @@ -404,6 +406,15 @@ def _register_configs():
],
model_detectors=[lambda id: "flux.2" in id.lower()],
)
register_configs(
model_name="Z-image",
sampling_param_cls=ZImageSamplingParams,
pipeline_config_cls=ZImagePipelineConfig,
model_paths=[
"Tongyi-MAI/Z-Image-Turbo",
],
model_detectors=[lambda id: "z-image" in id.lower()],
)

# Qwen-Image
register_configs(
Expand Down
Loading
Loading