Skip to content

Commit 40a369d

Browse files
authored
fix(tui): restore sample media previews from dataset sources (#1224)
1 parent 23b5908 commit 40a369d

File tree

3 files changed

+530
-1
lines changed

3 files changed

+530
-1
lines changed

lmms_eval/evaluator.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import base64
12
import collections
23
import copy
34
import itertools
45
import json
6+
import mimetypes
57
import os
68
import random
79
import re
@@ -65,6 +67,79 @@
6567
simple_parse_args_string,
6668
)
6769

70+
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ".tif", ".tiff")
71+
72+
73+
def _looks_like_image_ref(value: str) -> bool:
74+
lowered = value.lower().strip()
75+
if lowered.startswith("data:image/"):
76+
return True
77+
if lowered.startswith(("http://", "https://", "file://")):
78+
return any(ext in lowered for ext in IMAGE_EXTENSIONS)
79+
return lowered.endswith(IMAGE_EXTENSIONS)
80+
81+
82+
def _guess_image_mime(path_hint: Optional[str]) -> str:
83+
if path_hint:
84+
guessed, _ = mimetypes.guess_type(path_hint)
85+
if guessed and guessed.startswith("image/"):
86+
return guessed
87+
return "image/png"
88+
89+
90+
def _append_image_source(target: list[str], source: str, seen: set[str], max_items: int) -> None:
91+
if not source or source in seen or len(target) >= max_items:
92+
return
93+
seen.add(source)
94+
target.append(source)
95+
96+
97+
def _extract_image_sources(value, out: list[str], seen: set[str], max_items: int = 4, max_inline_bytes: int = 300_000) -> None:
    """Recursively collect image references from ``value`` into ``out``.

    Args:
        value: Arbitrary nested data. Only ``str``, ``dict``, ``list`` and
            ``tuple`` values are inspected; everything else is ignored.
        out: Destination list, mutated in place.
        seen: De-duplication set shared across calls, mutated in place.
        max_items: Collection stops once ``out`` holds this many entries.
        max_inline_bytes: Largest raw ``bytes`` payload that is inlined as a
            base64 ``data:`` URI; larger or empty payloads are skipped.
    """
    if len(out) >= max_items:
        return

    if isinstance(value, str):
        if _looks_like_image_ref(value):
            _append_image_source(out, value, seen, max_items)
        return

    if isinstance(value, (list, tuple)):
        for element in value:
            _extract_image_sources(element, out, seen, max_items=max_items, max_inline_bytes=max_inline_bytes)
            if len(out) >= max_items:
                break
        return

    if not isinstance(value, dict):
        return

    # Well-known keys are checked first so their references are collected
    # ahead of whatever the generic walk over the values finds.
    mime_hint = None
    for field in ("url", "uri", "path", "image", "image_url", "image_path"):
        ref = value.get(field)
        if not isinstance(ref, str):
            continue
        if field == "path":
            # Remembered even when it is not an image ref: it still hints at
            # the MIME type of a sibling "bytes" payload.
            mime_hint = ref
        if _looks_like_image_ref(ref):
            _append_image_source(out, ref, seen, max_items)

    payload = value.get("bytes")
    small_enough = isinstance(payload, (bytes, bytearray)) and 0 < len(payload) <= max_inline_bytes
    if small_enough and len(out) < max_items:
        mime = _guess_image_mime(mime_hint)
        encoded = base64.b64encode(payload).decode("ascii")
        _append_image_source(out, f"data:{mime};base64,{encoded}", seen, max_items)

    for child in value.values():
        _extract_image_sources(child, out, seen, max_items=max_items, max_inline_bytes=max_inline_bytes)
131+
132+
133+
def _collect_input_media(doc: dict, request_args: list, max_items: int = 4) -> list[str]:
    """Gather up to ``max_items`` unique image sources for one sample.

    The dataset ``doc`` is scanned first, then each request argument in
    order, so references found in the document take precedence.

    Args:
        doc: The dataset document for the sample.
        request_args: Arguments passed to the model request; ``None`` is
            treated as empty.
        max_items: Upper bound on the number of sources returned. The same
            value is forwarded to every ``_extract_image_sources`` call so the
            early-exit check here and the extractor's own cap cannot drift
            apart.

    Returns:
        Image references (paths, URLs or ``data:`` URIs) in discovery order.
    """
    sources: list[str] = []
    seen: set[str] = set()
    _extract_image_sources(doc, sources, seen, max_items=max_items)
    for arg in request_args or ():
        if len(sources) >= max_items:
            break
        _extract_image_sources(arg, sources, seen, max_items=max_items)
    return sources
142+
68143

69144
@positional_deprecated
70145
def simple_evaluate(
@@ -1106,6 +1181,8 @@ def evaluate(
11061181
# else:
11071182
# filtered_arguments.append(_handle_non_serializable(value))
11081183

1184+
input_media = _collect_input_media(doc, filtered_arguments)
1185+
11091186
per_sample_tc = []
11101187
for req in requests:
11111188
if req.token_counts:
@@ -1131,6 +1208,8 @@ def evaluate(
11311208
)
11321209
),
11331210
}
1211+
if input_media:
1212+
example["input_media"] = input_media
11341213
example.update(metrics)
11351214
task_output.logged_samples.append(example)
11361215
for metric, value in metrics.items():

lmms_eval/tui/server.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from __future__ import annotations
66

77
import asyncio
8+
import io
89
import json
10+
import mimetypes
911
import os
1012
import platform
1113
import re
@@ -25,6 +27,9 @@
2527
from fastapi.staticfiles import StaticFiles
2628
from pydantic import BaseModel
2729

30+
from lmms_eval import utils as lmms_utils
31+
from lmms_eval.tasks import TaskManager
32+
from lmms_eval.tasks._task_utils.media_resolver import resolve_media_reference
2833
from lmms_eval.tui.discovery import get_discovery_cache
2934

3035
app = FastAPI(title="LMMs-Eval Web UI", version="0.1.0")
@@ -42,6 +47,180 @@
4247

4348
# In-memory job storage
4449
_jobs: dict[str, dict[str, Any]] = {}
50+
# Process-wide TaskManager, created lazily on first use by _get_task_manager().
_task_manager: TaskManager | None = None
51+
# Loaded datasets keyed by (dataset_path, dataset_name, split); filled by _get_dataset().
_dataset_cache: dict[tuple[str, str | None, str], Any] = {}
52+
53+
54+
def _get_task_manager() -> TaskManager:
    """Return the shared ``TaskManager``, constructing it on first call."""
    global _task_manager
    manager = _task_manager
    if manager is None:
        manager = TaskManager(verbosity="ERROR")
        _task_manager = manager
    return manager
59+
60+
61+
def _get_task_dataset_spec(task_name: str) -> tuple[str, str | None, str, dict[str, Any]]:
    """Look up the dataset location a task is configured to read from.

    Returns:
        ``(dataset_path, dataset_name, split, dataset_kwargs)``.
        ``dataset_name`` is ``None`` unless the config holds a string, and
        ``dataset_kwargs`` is ``{}`` unless the config holds a dict.

    Raises:
        HTTPException: 404 when the task, its YAML config, its dataset path
            or any split is missing; 500 when the loaded config is not a dict.
    """
    entry = _get_task_manager().task_index.get(task_name)
    if entry is None or entry.get("type") != "task":
        raise HTTPException(status_code=404, detail="Task not found")

    yaml_path = entry.get("yaml_path")
    if not yaml_path or yaml_path == -1:
        raise HTTPException(status_code=404, detail="Task config not found")

    config = lmms_utils.load_yaml_config(yaml_path, mode="full")
    if not isinstance(config, dict):
        raise HTTPException(status_code=500, detail="Task config is invalid")

    dataset_path = config.get("dataset_path")
    if not (isinstance(dataset_path, str) and dataset_path):
        raise HTTPException(status_code=404, detail="Task dataset is not configured")

    dataset_name = config.get("dataset_name")
    if not isinstance(dataset_name, str):
        dataset_name = None

    # First configured split wins, in this order of preference.
    configured = (config.get(key) for key in ("test_split", "validation_split", "train_split", "split"))
    split: str | None = next((item for item in configured if isinstance(item, str) and item), None)
    if split is None:
        raise HTTPException(status_code=404, detail="Task split is not configured")

    dataset_kwargs = config.get("dataset_kwargs")
    if not isinstance(dataset_kwargs, dict):
        dataset_kwargs = {}

    return dataset_path, dataset_name, split, dataset_kwargs
97+
98+
99+
def _get_dataset(dataset_path: str, dataset_name: str | None, split: str, dataset_kwargs: dict[str, Any]):
    """Load a dataset split via ``datasets.load_dataset``, memoised in ``_dataset_cache``.

    ``dataset_name`` is forwarded as the second positional argument only when
    it is truthy; ``dataset_kwargs`` are forwarded as keyword arguments.
    """
    # NOTE(review): the cache key does not include dataset_kwargs, so two
    # callers that differ only in kwargs share one cached dataset - confirm
    # this is intended.
    key = (dataset_path, dataset_name, split)
    try:
        return _dataset_cache[key]
    except KeyError:
        pass

    # Imported lazily, so `datasets` is only needed once a dataset is requested.
    from datasets import load_dataset

    positional = (dataset_path, dataset_name) if dataset_name else (dataset_path,)
    loaded = load_dataset(*positional, split=split, **dict(dataset_kwargs))
    _dataset_cache[key] = loaded
    return loaded
113+
114+
115+
def _serialize_pil_image(image) -> tuple[bytes, str]:
116+
fmt = getattr(image, "format", None)
117+
pil_format = fmt.upper() if isinstance(fmt, str) and fmt else "PNG"
118+
mime = f"image/{pil_format.lower()}"
119+
if pil_format == "JPG":
120+
pil_format = "JPEG"
121+
mime = "image/jpeg"
122+
123+
buffer = io.BytesIO()
124+
image.save(buffer, format=pil_format)
125+
return buffer.getvalue(), mime
126+
127+
128+
def _extract_image_blob(value: Any) -> tuple[bytes, str] | None:
129+
if value is None:
130+
return None
131+
132+
try:
133+
from PIL import Image
134+
135+
if isinstance(value, Image.Image):
136+
return _serialize_pil_image(value)
137+
except ImportError:
138+
pass
139+
140+
if isinstance(value, (bytes, bytearray)):
141+
return bytes(value), "image/png"
142+
143+
if isinstance(value, dict):
144+
raw_bytes = value.get("bytes")
145+
if isinstance(raw_bytes, (bytes, bytearray)):
146+
path_hint = value.get("path") if isinstance(value.get("path"), str) else None
147+
guessed, _ = mimetypes.guess_type(path_hint or "")
148+
media_type = guessed if guessed and guessed.startswith("image/") else "image/png"
149+
return bytes(raw_bytes), media_type
150+
151+
for candidate_key in ("image", "img", "picture"):
152+
if candidate_key in value:
153+
nested = _extract_image_blob(value[candidate_key])
154+
if nested is not None:
155+
return nested
156+
157+
for nested_value in value.values():
158+
nested = _extract_image_blob(nested_value)
159+
if nested is not None:
160+
return nested
161+
162+
if isinstance(value, (list, tuple)):
163+
for item in value:
164+
nested = _extract_image_blob(item)
165+
if nested is not None:
166+
return nested
167+
168+
return None
169+
170+
171+
def _extract_video_path(value: Any, dataset_cache_dir: str | None) -> str | None:
172+
if value is None:
173+
return None
174+
175+
if isinstance(value, str):
176+
resolved = resolve_media_reference(value, media_type="video", cache_dir=dataset_cache_dir)
177+
if isinstance(resolved, str) and Path(resolved).exists():
178+
return resolved
179+
return None
180+
181+
if isinstance(value, dict):
182+
for key in ("video", "video_path", "path", "file", "clip_path"):
183+
candidate = value.get(key)
184+
path = _extract_video_path(candidate, dataset_cache_dir)
185+
if path is not None:
186+
return path
187+
188+
for nested in value.values():
189+
path = _extract_video_path(nested, dataset_cache_dir)
190+
if path is not None:
191+
return path
192+
193+
if isinstance(value, (list, tuple)):
194+
for item in value:
195+
path = _extract_video_path(item, dataset_cache_dir)
196+
if path is not None:
197+
return path
198+
199+
return None
200+
201+
202+
def _resolve_dataset_media(task_name: str, doc_id: int) -> tuple[str, bytes | str, str]:
    """Fetch the preview media for one dataset sample of a task.

    The sample is read from the task's configured dataset at index
    ``doc_id``. An image payload is preferred; otherwise a video file path is
    looked up.

    Returns:
        ``("bytes", image_bytes, mime_type)`` for an image, or
        ``("file", video_path, mime_type)`` for a video (MIME type defaults to
        ``video/mp4`` when it cannot be guessed).

    Raises:
        HTTPException: 404 when indexing the dataset fails or the sample holds
            no image/video; errors from resolving the task spec propagate.
    """
    dataset_path, dataset_name, split, dataset_kwargs = _get_task_dataset_spec(task_name)
    dataset = _get_dataset(dataset_path, dataset_name, split, dataset_kwargs)

    try:
        record = dataset[doc_id]
    except Exception as exc:  # noqa: BLE001
        raise HTTPException(status_code=404, detail="Sample doc_id not found in dataset") from exc

    blob = _extract_image_blob(record)
    if blob is not None:
        image_bytes, media_type = blob
        return "bytes", image_bytes, media_type

    cache_dir = dataset_kwargs.get("cache_dir")
    if not isinstance(cache_dir, str):
        cache_dir = None

    video_path = _extract_video_path(record, cache_dir)
    if video_path is None:
        raise HTTPException(status_code=404, detail="No image/video found in dataset sample")

    guessed = mimetypes.guess_type(video_path)[0]
    media_type = guessed if guessed and guessed.startswith("video/") else "video/mp4"
    return "file", video_path, media_type
45224

46225

47226
def get_version() -> str:
@@ -712,6 +891,37 @@ async def get_log_run_samples(
712891
return LogSamplesResponse(samples=samples, total=total, offset=offset, limit=limit)
713892

714893

894+
@app.get("/logs/runs/{run_id:path}/samples/{task_name}/media/{doc_id}")
async def get_log_run_sample_media(
    run_id: str,
    task_name: str,
    doc_id: int,
    logs_path: str = Query("./logs/"),
):
    """Serve the image or video that belongs to one logged sample.

    The run's results file must exist and be named ``*_results.json``; the
    media itself is re-read from the task's dataset, not from the log
    directory.

    Raises:
        HTTPException: 404 when the run results file is missing, 400 for a
            task name containing a path separator or a negative ``doc_id``,
            plus whatever ``_resolve_dataset_media`` raises.
    """
    run_path = _resolve_run_results_path(logs_path, run_id)
    results_file_ok = run_path.name.endswith("_results.json") and run_path.exists() and run_path.is_file()
    if not results_file_ok:
        raise HTTPException(status_code=404, detail="Run results not found")
    if any(separator in task_name for separator in ("/", "\\")):
        raise HTTPException(status_code=400, detail="Invalid task name")
    if doc_id < 0:
        raise HTTPException(status_code=400, detail="Invalid doc_id")

    # Dataset loading and image encoding are blocking, so they run in a worker thread.
    mode, payload, media_type = await asyncio.to_thread(_resolve_dataset_media, task_name, doc_id)

    cache_headers = {"Cache-Control": "public, max-age=3600"}
    if mode == "file":
        return FileResponse(path=str(payload), media_type=media_type, headers=cache_headers)
    return StreamingResponse(io.BytesIO(payload), media_type=media_type, headers=cache_headers)
923+
924+
715925
if STATIC_DIR.exists():
716926
app.mount("/assets", StaticFiles(directory=STATIC_DIR / "assets"), name="assets")
717927

0 commit comments

Comments
 (0)