62 changes: 62 additions & 0 deletions sgl-router/scripts/generate_vision_golden.py
@@ -56,6 +56,11 @@
"processor_class": "Phi3VImageProcessor",
"description": "Dynamic HD transform with 336x336 tiles",
},
"phi4_vision": {
"model_id": "microsoft/Phi-4-multimodal-instruct",
"processor_class": "Phi4MMImageProcessor",
"description": "Dynamic HD transform with 448x448 tiles and SiGLIP encoder",
},
}

# Default test images
@@ -419,6 +424,62 @@ def generate_golden_phi3_vision(image_path: str, output_dir: str) -> dict:
return result


def generate_golden_phi4_vision(image_path: str, output_dir: str) -> dict:
"""Generate golden output for Phi4-Vision (Phi-4-multimodal).

    Phi4-Vision uses a Dynamic HD transform similar to Phi3's, but with:
- Base resolution: 448 (vs 336 in Phi3)
- Normalization: [0.5, 0.5, 0.5] mean/std (vs CLIP in Phi3)
- Default dynamic_hd: 36 (vs 16 num_crops in Phi3)
- Uses SiGLIP vision encoder (vs CLIP in Phi3)
- Has per-crop attention masks

Token count formula:
256 + 1 + mask_sum + mask_col0_sum + 16

    Note: Phi4 uses the 'input_image_embeds' key instead of 'pixel_values'
"""
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
"microsoft/Phi-4-multimodal-instruct", trust_remote_code=True
)
image = Image.open(image_path).convert("RGB")
original_size = image.size

# Process image using the image processor directly
outputs = processor.image_processor(images=image, return_tensors="np")

# Phi4 uses 'input_image_embeds' instead of 'pixel_values'
pixel_values = outputs.get("input_image_embeds")
pixel_attention_mask = outputs.get("image_attention_mask")
image_sizes = outputs.get("image_sizes")
num_img_tokens = outputs.get("num_img_tokens")

result = {
"pixel_values": pixel_values,
"original_size": original_size,
"processor_config": processor.image_processor.to_dict(),
}

if pixel_attention_mask is not None:
result["pixel_attention_mask"] = np.array(pixel_attention_mask)

if image_sizes is not None:
result["image_sizes"] = np.array(image_sizes)

if num_img_tokens is not None:
result["num_img_tokens"] = np.array(num_img_tokens)

# Add debug info
result["config_info"] = {
"dynamic_hd": getattr(processor.image_processor, "dynamic_hd", 36),
"base_resolution": 448,
}

return result


def generate_for_model(model_key: str, image_paths: list, output_dir: str):
"""Generate golden outputs for a specific model."""
print(f"\nGenerating golden outputs for {model_key}...")
@@ -430,6 +491,7 @@ def generate_for_model(model_key: str, image_paths: list, output_dir: str):
"qwen2_vl": generate_golden_qwen2_vl,
"qwen3_vl": generate_golden_qwen3_vl,
"phi3_vision": generate_golden_phi3_vision,
"phi4_vision": generate_golden_phi4_vision,
}.get(model_key)

if generator_fn is None:
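The token-count formula noted in the generate_golden_phi4_vision docstring (256 + 1 + mask_sum + mask_col0_sum + 16) can be cross-checked against the processor's own num_img_tokens output. A minimal sketch, not part of the diff above: the helper name expected_phi4_image_tokens is hypothetical, and it assumes the attention mask is already at token resolution.

import numpy as np

def expected_phi4_image_tokens(image_attention_mask) -> int:
    # Cross-check of the docstring formula; assumes the mask has one entry per
    # image token (crops x rows x cols), which may not hold for the raw
    # patch-level 'image_attention_mask' returned by the processor.
    mask = np.asarray(image_attention_mask)
    mask_sum = int(mask.sum())
    # 'mask_col0_sum' is read here as the sum over column 0 of every crop row;
    # that interpretation is an assumption, not stated in the docstring.
    mask_col0_sum = int(mask[..., 0].sum())
    return 256 + 1 + mask_sum + mask_col0_sum + 16

# Hypothetical usage against a generated golden output for a single image:
#   assert expected_phi4_image_tokens(result["pixel_attention_mask"]) == int(result["num_img_tokens"].sum())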
3 changes: 2 additions & 1 deletion sgl-router/src/multimodal/vision/mod.rs
@@ -40,6 +40,7 @@ pub use image_processor::{
};
pub use preprocessor_config::PreProcessorConfig;
pub use processors::{
LlavaNextProcessor, LlavaProcessor, Phi3VisionProcessor, Qwen2VLProcessor, Qwen3VLProcessor,
LlavaNextProcessor, LlavaProcessor, Phi3VisionProcessor, Phi4VisionProcessor, Qwen2VLProcessor,
Qwen3VLProcessor,
};
pub use transforms::TransformError;
4 changes: 4 additions & 0 deletions sgl-router/src/multimodal/vision/preprocessor_config.rs
@@ -96,6 +96,10 @@ pub struct PreProcessorConfig {
#[serde(default)]
pub num_crops: Option<usize>,

/// Phi4-Vision: dynamic HD max crops
#[serde(default)]
pub dynamic_hd: Option<usize>,

/// LLaMA-Vision: maximum image tiles
#[serde(default)]
pub max_image_tiles: Option<usize>,
3 changes: 3 additions & 0 deletions sgl-router/src/multimodal/vision/processors/mod.rs
@@ -11,14 +11,17 @@
//! - **Qwen2.5-VL** (`qwen2_vl`): Same processor as Qwen2-VL (identical preprocessing)
//! - **Qwen3-VL** (`qwen3_vl`): Similar to Qwen2-VL but with patch_size=16 and [0.5,0.5,0.5] normalization
//! - **Phi3-Vision** (`phi3_vision`): Dynamic HD transform with 336x336 tiles
//! - **Phi4-Vision** (`phi4_vision`): Dynamic HD transform with 448x448 tiles and SiGLIP encoder

pub mod llava;
pub mod phi3_vision;
pub mod phi4_vision;
pub mod qwen2_vl;
pub mod qwen3_vl;
pub mod qwen_vl_base;

pub use llava::{ImageAspectRatio, LlavaNextProcessor, LlavaProcessor};
pub use phi3_vision::Phi3VisionProcessor;
pub use phi4_vision::Phi4VisionProcessor;
pub use qwen2_vl::Qwen2VLProcessor;
pub use qwen3_vl::Qwen3VLProcessor;
57 changes: 8 additions & 49 deletions sgl-router/src/multimodal/vision/processors/phi3_vision.rs
@@ -135,7 +135,10 @@ impl Phi3VisionProcessor {
let new_w = (scale * TILE_SIZE as f64) as u32;
let new_h = (new_w as f64 / ratio) as u32;

// Resize using bilinear filter (matching HuggingFace's default)
// Resize using bilinear filter (matching torchvision's bilinear+antialias)
// HuggingFace uses torchvision.transforms.functional.resize with
// BILINEAR interpolation and antialias=True. PIL's BILINEAR includes
// implicit antialiasing that closely matches torchvision.
let resized = img.resize_exact(new_w, new_h, FilterType::Triangle);

// Pad height to multiple of 336
@@ -170,56 +173,12 @@
new_image
}

/// Create global image by bilinear interpolation to 336x336.
/// Create global image by bicubic interpolation to 336x336.
///
/// Uses PyTorch-compatible coordinate mapping with align_corners=False:
/// `src = (dst + 0.5) * (src_size / dst_size) - 0.5`
/// Uses the shared `bicubic_resize` which matches PyTorch's
/// `torch.nn.functional.interpolate(mode='bicubic', align_corners=False)`.
fn create_global_image(&self, tensor: &Array3<f32>) -> Array3<f32> {
// tensor is [C, H, W], we need to resize to [C, 336, 336]
let (_c, h, w) = (tensor.shape()[0], tensor.shape()[1], tensor.shape()[2]);

if h == TILE_SIZE as usize && w == TILE_SIZE as usize {
return tensor.clone();
}

let mut result = Array3::<f32>::zeros((3, TILE_SIZE as usize, TILE_SIZE as usize));

// PyTorch align_corners=False coordinate mapping
let scale_h = h as f32 / TILE_SIZE as f32;
let scale_w = w as f32 / TILE_SIZE as f32;

for c in 0..3 {
for y in 0..TILE_SIZE as usize {
for x in 0..TILE_SIZE as usize {
// PyTorch align_corners=False: src = (dst + 0.5) * scale - 0.5
let src_y = ((y as f32 + 0.5) * scale_h - 0.5).max(0.0);
let src_x = ((x as f32 + 0.5) * scale_w - 0.5).max(0.0);

// Bilinear interpolation
let y0 = src_y.floor() as usize;
let x0 = src_x.floor() as usize;
let y1 = (y0 + 1).min(h - 1);
let x1 = (x0 + 1).min(w - 1);

let fy = src_y - y0 as f32;
let fx = src_x - x0 as f32;

let v00 = tensor[[c, y0, x0]];
let v01 = tensor[[c, y0, x1]];
let v10 = tensor[[c, y1, x0]];
let v11 = tensor[[c, y1, x1]];

let value = v00 * (1.0 - fx) * (1.0 - fy)
+ v01 * fx * (1.0 - fy)
+ v10 * (1.0 - fx) * fy
+ v11 * fx * fy;

result[[c, y, x]] = value;
}
}
}

result
transforms::bicubic_resize(tensor, TILE_SIZE as usize, TILE_SIZE as usize)
}

/// Reshape HD image into tiles.
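Since create_global_image now delegates to the shared bicubic_resize, a PyTorch reference for the interpolation it is documented to match can be produced as follows. A minimal sketch, not part of the diff: the function name reference_global_image and the example input shape are assumptions.

import torch
import torch.nn.functional as F

def reference_global_image(tensor_chw: torch.Tensor, tile: int = 336) -> torch.Tensor:
    # PyTorch reference for the global-image resize: bicubic interpolation with
    # align_corners=False, mirroring the Rust doc comment above.
    resized = F.interpolate(
        tensor_chw.unsqueeze(0),  # interpolate expects [N, C, H, W]
        size=(tile, tile),
        mode="bicubic",
        align_corners=False,
    )
    return resized.squeeze(0)

# Example: generate a reference to compare against the Rust output for a
# random [3, H, W] input (shape chosen arbitrarily for illustration).
golden = reference_global_image(torch.rand(3, 700, 500))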