Skip to content

Commit 288270c

Browse files
slin1237 authored and tonyluj committed
[model-gateway] multimodality initialization (sgl-project#13350)
1 parent 364e92b commit 288270c

File tree

9 files changed

+1351
-1
lines changed

9 files changed

+1351
-1
lines changed

sgl-router/Cargo.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ version = "0.2.3"
44
edition = "2021"
55

66
[features]
7-
default = []
7+
default = ["grpc-client"]
8+
grpc-client = []
9+
grpc-server = []
10+
811
vendored-openssl = ["openssl/vendored"]
912

1013
[lints.rust]
@@ -36,6 +39,7 @@ serde_json = { version = "1.0", default-features = false, features = [
3639
"std",
3740
"preserve_order",
3841
] }
42+
serde_bytes = "0.11"
3943
bytes = "1.8.0"
4044
rand = "0.9.2"
4145
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json", "rustls-tls"], default-features = false }
@@ -86,6 +90,9 @@ oracle = { version = "0.6.3", features = ["chrono"] }
8690
subtle = "2.6"
8791
rustpython-parser = "0.4.0"
8892
num-traits = "0.2"
93+
image = { version = "0.25.4", default-features = false, features = ["png", "jpeg", "gif", "bmp", "ico", "tiff", "webp"] }
94+
ndarray = "0.16"
95+
base64 = "0.22"
8996
openai-harmony = { git = "https://github.com/openai/harmony", tag = "v0.0.4" }
9097
openmetrics-parser = "0.4.4"
9198

sgl-router/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pub mod grpc_client;
88
pub mod mcp;
99
pub mod metrics;
1010
pub mod middleware;
11+
pub mod multimodal;
1112
pub mod policies;
1213
pub mod protocols;
1314
pub mod reasoning_parser;

sgl-router/src/multimodal/error.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
use std::time::Duration;
2+
3+
use thiserror::Error;
4+
5+
use super::types::Modality;
6+
7+
/// Convenience alias for a `Result` whose error type is [`MultiModalError`].
pub type MultiModalResult<T> = Result<T, MultiModalError>;
8+
9+
#[derive(Debug, Error)]
10+
pub enum MediaConnectorError {
11+
#[error("unsupported media scheme: {0}")]
12+
UnsupportedScheme(String),
13+
#[error("invalid media URL: {0}")]
14+
InvalidUrl(String),
15+
#[error("media domain '{0}' is not in the allow list")]
16+
DisallowedDomain(String),
17+
#[error("local media path is not allowed: {0}")]
18+
DisallowedLocalPath(String),
19+
#[error("HTTP error while fetching media: {0}")]
20+
Http(#[from] reqwest::Error),
21+
#[error("I/O error while reading media: {0}")]
22+
Io(#[from] std::io::Error),
23+
#[error("base64 decode error: {0}")]
24+
Base64Decode(#[from] base64::DecodeError),
25+
#[error("data URL parse error: {0}")]
26+
DataUrl(String),
27+
#[error("media decode task failed: {0}")]
28+
Blocking(#[from] tokio::task::JoinError),
29+
#[error("image decode error: {0}")]
30+
Image(#[from] image::ImageError),
31+
#[error("media fetch timed out after {0:?}")]
32+
Timeout(Duration),
33+
}
34+
35+
#[derive(Debug, Error)]
36+
pub enum MultiModalError {
37+
#[error(transparent)]
38+
Media(#[from] MediaConnectorError),
39+
#[error("unsupported content part: {0}")]
40+
UnsupportedContent(&'static str),
41+
#[error("too many {modality:?} items provided. limit={limit}")]
42+
ModalityLimit { modality: Modality, limit: usize },
43+
#[error("tracker task join error: {0}")]
44+
Join(#[from] tokio::task::JoinError),
45+
#[error("tracker validation error: {0}")]
46+
Validation(String),
47+
}

sgl-router/src/multimodal/media.rs

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
use std::{collections::HashSet, path::PathBuf, sync::Arc, time::Duration};
2+
3+
use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine};
4+
use bytes::Bytes;
5+
use image::DynamicImage;
6+
use reqwest::Client;
7+
use tokio::{fs, task};
8+
use url::Url;
9+
10+
use super::{
11+
error::MediaConnectorError,
12+
types::{ImageDetail, ImageFrame, ImageSource},
13+
};
14+
15+
/// Settings used to construct a `MediaConnector`.
///
/// Derives `Debug` so the configuration can be logged and inspected like the
/// other public configuration types in this module.
#[derive(Clone, Debug)]
pub struct MediaConnectorConfig {
    /// Hosts that remote media may be fetched from. `None` disables the
    /// allow-list check, so any host is accepted.
    pub allowed_domains: Option<Vec<String>>,
    /// Root directory that local media files must live under. `None` means
    /// local file access is rejected.
    pub allowed_local_media_path: Option<PathBuf>,
    /// Per-request timeout for remote fetches. A zero duration means no
    /// per-request timeout is set on the request.
    pub fetch_timeout: Duration,
}
21+
22+
impl Default for MediaConnectorConfig {
23+
fn default() -> Self {
24+
Self {
25+
allowed_domains: None,
26+
allowed_local_media_path: None,
27+
fetch_timeout: Duration::from_secs(10),
28+
}
29+
}
30+
}
31+
32+
#[derive(Clone, Copy, Debug)]
33+
pub struct ImageFetchConfig {
34+
pub detail: ImageDetail,
35+
}
36+
37+
impl Default for ImageFetchConfig {
38+
fn default() -> Self {
39+
Self {
40+
detail: ImageDetail::Auto,
41+
}
42+
}
43+
}
44+
45+
/// Where the bytes of a media item come from.
#[derive(Clone, Debug)]
pub enum MediaSource {
    /// A remote URL to fetch over the network.
    Url(String),
    /// A `data:` URL carrying the payload inline.
    DataUrl(String),
    /// Raw encoded bytes already held in memory.
    InlineBytes(Vec<u8>),
    /// A path on the local filesystem.
    File(PathBuf),
}
52+
53+
#[derive(Clone)]
54+
pub struct MediaConnector {
55+
client: Client,
56+
allowed_domains: Option<HashSet<String>>,
57+
allowed_local_media_path: Option<PathBuf>,
58+
fetch_timeout: Duration,
59+
}
60+
61+
impl MediaConnector {
62+
pub fn new(client: Client, config: MediaConnectorConfig) -> Result<Self, MediaConnectorError> {
63+
let allowed_domains = config.allowed_domains.map(|domains| {
64+
domains
65+
.into_iter()
66+
.map(|d| d.to_ascii_lowercase())
67+
.collect::<HashSet<_>>()
68+
});
69+
70+
let allowed_local_media_path = if let Some(path) = config.allowed_local_media_path {
71+
Some(std::fs::canonicalize(path)?)
72+
} else {
73+
None
74+
};
75+
76+
Ok(Self {
77+
client,
78+
allowed_domains,
79+
allowed_local_media_path,
80+
fetch_timeout: config.fetch_timeout,
81+
})
82+
}
83+
84+
pub async fn fetch_image(
85+
&self,
86+
source: MediaSource,
87+
cfg: ImageFetchConfig,
88+
) -> Result<Arc<ImageFrame>, MediaConnectorError> {
89+
match source {
90+
MediaSource::Url(url) => self.fetch_http_image(url, cfg).await,
91+
MediaSource::DataUrl(data_url) => self.fetch_data_url(data_url, cfg).await,
92+
MediaSource::InlineBytes(bytes) => {
93+
self.decode_image(bytes.into(), cfg.detail, ImageSource::InlineBytes)
94+
.await
95+
}
96+
MediaSource::File(path) => self.fetch_file(path, cfg).await,
97+
}
98+
}
99+
100+
async fn fetch_http_image(
101+
&self,
102+
url: String,
103+
cfg: ImageFetchConfig,
104+
) -> Result<Arc<ImageFrame>, MediaConnectorError> {
105+
let parsed = Url::parse(&url).map_err(|_| MediaConnectorError::InvalidUrl(url.clone()))?;
106+
self.ensure_domain_allowed(&parsed)?;
107+
108+
let mut req = self.client.get(parsed.as_str());
109+
if self.fetch_timeout > Duration::ZERO {
110+
req = req.timeout(self.fetch_timeout);
111+
}
112+
113+
let resp = req.send().await.map_err(|err| {
114+
if err.is_timeout() {
115+
MediaConnectorError::Timeout(self.fetch_timeout)
116+
} else {
117+
MediaConnectorError::Http(err)
118+
}
119+
})?;
120+
121+
let resp = resp.error_for_status()?;
122+
let bytes = resp.bytes().await?;
123+
self.decode_image(
124+
bytes,
125+
cfg.detail,
126+
ImageSource::Url {
127+
url: parsed.to_string(),
128+
},
129+
)
130+
.await
131+
}
132+
133+
async fn fetch_data_url(
134+
&self,
135+
data_url: String,
136+
cfg: ImageFetchConfig,
137+
) -> Result<Arc<ImageFrame>, MediaConnectorError> {
138+
let (metadata, data) = data_url
139+
.split_once(',')
140+
.ok_or_else(|| MediaConnectorError::DataUrl("missing comma in data url".into()))?;
141+
142+
if !metadata.ends_with(";base64") {
143+
return Err(MediaConnectorError::DataUrl(
144+
"only base64 encoded data URLs are supported".into(),
145+
));
146+
}
147+
148+
let data = data.trim();
149+
let decoded = BASE64_STANDARD.decode(data)?;
150+
self.decode_image(decoded.into(), cfg.detail, ImageSource::DataUrl)
151+
.await
152+
}
153+
154+
async fn fetch_file(
155+
&self,
156+
path: PathBuf,
157+
cfg: ImageFetchConfig,
158+
) -> Result<Arc<ImageFrame>, MediaConnectorError> {
159+
let allowed_root = self
160+
.allowed_local_media_path
161+
.as_ref()
162+
.ok_or_else(|| MediaConnectorError::DisallowedLocalPath(path.display().to_string()))?;
163+
164+
let canonical = fs::canonicalize(&path).await?;
165+
if !canonical.starts_with(allowed_root) {
166+
return Err(MediaConnectorError::DisallowedLocalPath(
167+
path.display().to_string(),
168+
));
169+
}
170+
171+
let bytes = fs::read(&canonical).await?;
172+
self.decode_image(
173+
bytes.into(),
174+
cfg.detail,
175+
ImageSource::File { path: canonical },
176+
)
177+
.await
178+
}
179+
180+
fn ensure_domain_allowed(&self, url: &Url) -> Result<(), MediaConnectorError> {
181+
if let Some(allowed) = &self.allowed_domains {
182+
let host = url
183+
.host_str()
184+
.map(|h| h.to_ascii_lowercase())
185+
.ok_or_else(|| MediaConnectorError::InvalidUrl(url.to_string()))?;
186+
if !allowed.contains(&host) {
187+
return Err(MediaConnectorError::DisallowedDomain(host));
188+
}
189+
}
190+
Ok(())
191+
}
192+
193+
async fn decode_image(
194+
&self,
195+
bytes: Bytes,
196+
detail: ImageDetail,
197+
source: ImageSource,
198+
) -> Result<Arc<ImageFrame>, MediaConnectorError> {
199+
let raw: Arc<Vec<u8>> = Arc::new(bytes.to_vec());
200+
let raw_clone = raw.clone();
201+
let image: DynamicImage =
202+
task::spawn_blocking(move || image::load_from_memory(&raw_clone)).await??;
203+
204+
Ok(Arc::new(ImageFrame::new(image, raw, detail, source)))
205+
}
206+
}

sgl-router/src/multimodal/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
pub mod error;
2+
pub mod media;
3+
pub mod registry;
4+
pub mod tracker;
5+
pub mod types;
6+
7+
pub use error::{MediaConnectorError, MultiModalError, MultiModalResult};
8+
pub use media::{ImageFetchConfig, MediaConnector, MediaConnectorConfig, MediaSource};
9+
pub use registry::{ModelProcessorSpec, ModelRegistry};
10+
pub use tracker::{AsyncMultiModalTracker, TrackerConfig, TrackerOutput};
11+
pub use types::*;

0 commit comments

Comments
 (0)