|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | 7 | import asyncio |
| 8 | +import io |
8 | 9 | import json |
| 10 | +import mimetypes |
9 | 11 | import os |
10 | 12 | import platform |
11 | 13 | import re |
|
25 | 27 | from fastapi.staticfiles import StaticFiles |
26 | 28 | from pydantic import BaseModel |
27 | 29 |
|
| 30 | +from lmms_eval import utils as lmms_utils |
| 31 | +from lmms_eval.tasks import TaskManager |
| 32 | +from lmms_eval.tasks._task_utils.media_resolver import resolve_media_reference |
28 | 33 | from lmms_eval.tui.discovery import get_discovery_cache |
29 | 34 |
|
30 | 35 | app = FastAPI(title="LMMs-Eval Web UI", version="0.1.0") |
|
42 | 47 |
|
43 | 48 | # In-memory job storage |
44 | 49 | _jobs: dict[str, dict[str, Any]] = {} |
# Lazily-created singleton TaskManager (built by _get_task_manager) so the
# task registry is scanned at most once per process.
_task_manager: TaskManager | None = None
# Cache of loaded HF dataset splits keyed by (dataset_path, dataset_name, split).
# NOTE(review): the key omits dataset_kwargs — two specs that differ only in
# kwargs (e.g. cache_dir/revision) would collide; confirm callers never vary
# kwargs for the same (path, name, split).
_dataset_cache: dict[tuple[str, str | None, str], Any] = {}
| 52 | + |
| 53 | + |
def _get_task_manager() -> TaskManager:
    """Return the process-wide TaskManager, creating it on first access."""
    global _task_manager
    if _task_manager is not None:
        return _task_manager
    # Keep the manager quiet; discovery noise is not useful in the web UI.
    _task_manager = TaskManager(verbosity="ERROR")
    return _task_manager
| 59 | + |
| 60 | + |
def _get_task_dataset_spec(task_name: str) -> tuple[str, str | None, str, dict[str, Any]]:
    """Resolve *task_name* to its Hugging Face dataset coordinates.

    Reads the task's YAML config and returns a tuple of
    ``(dataset_path, dataset_name, split, dataset_kwargs)``.

    Raises:
        HTTPException: 404 when the task, its config, its dataset path, or a
            usable split cannot be resolved; 500 when the config file does not
            parse to a mapping.
    """

    def _not_found(message: str) -> HTTPException:
        return HTTPException(status_code=404, detail=message)

    info = _get_task_manager().task_index.get(task_name)
    if info is None or info.get("type") != "task":
        raise _not_found("Task not found")

    yaml_path = info.get("yaml_path")
    # A falsy or -1 sentinel path means the task carries no loadable config.
    if not yaml_path or yaml_path == -1:
        raise _not_found("Task config not found")

    config = lmms_utils.load_yaml_config(yaml_path, mode="full")
    if not isinstance(config, dict):
        raise HTTPException(status_code=500, detail="Task config is invalid")

    dataset_path = config.get("dataset_path")
    if not (isinstance(dataset_path, str) and dataset_path):
        raise _not_found("Task dataset is not configured")

    raw_name = config.get("dataset_name")
    dataset_name = raw_name if isinstance(raw_name, str) else None

    # Pick the first declared split, preferring the most evaluation-relevant.
    split = next(
        (
            config[key]
            for key in ("test_split", "validation_split", "train_split", "split")
            if isinstance(config.get(key), str) and config.get(key)
        ),
        None,
    )
    if split is None:
        raise _not_found("Task split is not configured")

    dataset_kwargs = config.get("dataset_kwargs")
    if not isinstance(dataset_kwargs, dict):
        dataset_kwargs = {}

    return dataset_path, dataset_name, split, dataset_kwargs
| 97 | + |
| 98 | + |
def _get_dataset(dataset_path: str, dataset_name: str | None, split: str, dataset_kwargs: dict[str, Any]):
    """Load (or return a cached) dataset split via ``datasets.load_dataset``.

    The cache key includes a fingerprint of *dataset_kwargs*: previously the
    key was only (path, name, split), so two specs that differed solely in
    kwargs (e.g. ``cache_dir`` or ``revision``) would serve each other's
    cached dataset.

    Returns:
        The loaded ``datasets`` split object.
    """
    try:
        kwargs_fingerprint = json.dumps(dataset_kwargs, sort_keys=True, default=str)
    except (TypeError, ValueError):
        # Unserializable kwargs (non-str keys, exotic values): fall back to a
        # deterministic repr so caching still distinguishes them.
        kwargs_fingerprint = repr(sorted(dataset_kwargs.items(), key=repr))
    cache_key = (dataset_path, dataset_name, split, kwargs_fingerprint)
    if cache_key in _dataset_cache:
        return _dataset_cache[cache_key]

    # Imported lazily: `datasets` is heavy and only needed by media endpoints.
    from datasets import load_dataset

    kwargs = dict(dataset_kwargs)
    if dataset_name:
        dataset = load_dataset(dataset_path, dataset_name, split=split, **kwargs)
    else:
        dataset = load_dataset(dataset_path, split=split, **kwargs)
    _dataset_cache[cache_key] = dataset
    return dataset
| 113 | + |
| 114 | + |
| 115 | +def _serialize_pil_image(image) -> tuple[bytes, str]: |
| 116 | + fmt = getattr(image, "format", None) |
| 117 | + pil_format = fmt.upper() if isinstance(fmt, str) and fmt else "PNG" |
| 118 | + mime = f"image/{pil_format.lower()}" |
| 119 | + if pil_format == "JPG": |
| 120 | + pil_format = "JPEG" |
| 121 | + mime = "image/jpeg" |
| 122 | + |
| 123 | + buffer = io.BytesIO() |
| 124 | + image.save(buffer, format=pil_format) |
| 125 | + return buffer.getvalue(), mime |
| 126 | + |
| 127 | + |
| 128 | +def _extract_image_blob(value: Any) -> tuple[bytes, str] | None: |
| 129 | + if value is None: |
| 130 | + return None |
| 131 | + |
| 132 | + try: |
| 133 | + from PIL import Image |
| 134 | + |
| 135 | + if isinstance(value, Image.Image): |
| 136 | + return _serialize_pil_image(value) |
| 137 | + except ImportError: |
| 138 | + pass |
| 139 | + |
| 140 | + if isinstance(value, (bytes, bytearray)): |
| 141 | + return bytes(value), "image/png" |
| 142 | + |
| 143 | + if isinstance(value, dict): |
| 144 | + raw_bytes = value.get("bytes") |
| 145 | + if isinstance(raw_bytes, (bytes, bytearray)): |
| 146 | + path_hint = value.get("path") if isinstance(value.get("path"), str) else None |
| 147 | + guessed, _ = mimetypes.guess_type(path_hint or "") |
| 148 | + media_type = guessed if guessed and guessed.startswith("image/") else "image/png" |
| 149 | + return bytes(raw_bytes), media_type |
| 150 | + |
| 151 | + for candidate_key in ("image", "img", "picture"): |
| 152 | + if candidate_key in value: |
| 153 | + nested = _extract_image_blob(value[candidate_key]) |
| 154 | + if nested is not None: |
| 155 | + return nested |
| 156 | + |
| 157 | + for nested_value in value.values(): |
| 158 | + nested = _extract_image_blob(nested_value) |
| 159 | + if nested is not None: |
| 160 | + return nested |
| 161 | + |
| 162 | + if isinstance(value, (list, tuple)): |
| 163 | + for item in value: |
| 164 | + nested = _extract_image_blob(item) |
| 165 | + if nested is not None: |
| 166 | + return nested |
| 167 | + |
| 168 | + return None |
| 169 | + |
| 170 | + |
| 171 | +def _extract_video_path(value: Any, dataset_cache_dir: str | None) -> str | None: |
| 172 | + if value is None: |
| 173 | + return None |
| 174 | + |
| 175 | + if isinstance(value, str): |
| 176 | + resolved = resolve_media_reference(value, media_type="video", cache_dir=dataset_cache_dir) |
| 177 | + if isinstance(resolved, str) and Path(resolved).exists(): |
| 178 | + return resolved |
| 179 | + return None |
| 180 | + |
| 181 | + if isinstance(value, dict): |
| 182 | + for key in ("video", "video_path", "path", "file", "clip_path"): |
| 183 | + candidate = value.get(key) |
| 184 | + path = _extract_video_path(candidate, dataset_cache_dir) |
| 185 | + if path is not None: |
| 186 | + return path |
| 187 | + |
| 188 | + for nested in value.values(): |
| 189 | + path = _extract_video_path(nested, dataset_cache_dir) |
| 190 | + if path is not None: |
| 191 | + return path |
| 192 | + |
| 193 | + if isinstance(value, (list, tuple)): |
| 194 | + for item in value: |
| 195 | + path = _extract_video_path(item, dataset_cache_dir) |
| 196 | + if path is not None: |
| 197 | + return path |
| 198 | + |
| 199 | + return None |
| 200 | + |
| 201 | + |
def _resolve_dataset_media(task_name: str, doc_id: int) -> tuple[str, bytes | str, str]:
    """Fetch the media payload for one dataset sample.

    Returns:
        ``("bytes", raw_image_bytes, mime)`` for image samples, or
        ``("file", path_on_disk, mime)`` for video samples.

    Raises:
        HTTPException: 404 when *doc_id* is out of range or the sample holds
            no recognizable image/video.
    """
    spec = _get_task_dataset_spec(task_name)
    dataset = _get_dataset(*spec)

    try:
        record = dataset[doc_id]
    except Exception as exc:  # noqa: BLE001
        raise HTTPException(status_code=404, detail="Sample doc_id not found in dataset") from exc

    # Images take priority: a sample with both image and video serves the image.
    image = _extract_image_blob(record)
    if image is not None:
        payload, media_type = image
        return "bytes", payload, media_type

    dataset_kwargs = spec[3]
    cache_dir = dataset_kwargs.get("cache_dir")
    if not isinstance(cache_dir, str):
        cache_dir = None
    video_path = _extract_video_path(record, cache_dir)
    if video_path is not None:
        guessed = mimetypes.guess_type(video_path)[0]
        media_type = guessed if guessed and guessed.startswith("video/") else "video/mp4"
        return "file", video_path, media_type

    raise HTTPException(status_code=404, detail="No image/video found in dataset sample")
45 | 224 |
|
46 | 225 |
|
47 | 226 | def get_version() -> str: |
@@ -712,6 +891,37 @@ async def get_log_run_samples( |
712 | 891 | return LogSamplesResponse(samples=samples, total=total, offset=offset, limit=limit) |
713 | 892 |
|
714 | 893 |
|
@app.get("/logs/runs/{run_id:path}/samples/{task_name}/media/{doc_id}")
async def get_log_run_sample_media(
    run_id: str,
    task_name: str,
    doc_id: int,
    logs_path: str = Query("./logs/"),
):
    """Serve the image/video belonging to one logged sample.

    The media is re-resolved from the task's source dataset (via
    _resolve_dataset_media), not read from the log files; the run-results
    checks below only gate the endpoint to runs that exist under *logs_path*.
    """
    run_path = _resolve_run_results_path(logs_path, run_id)
    # Both failure modes return the same 404 detail.
    if not run_path.name.endswith("_results.json"):
        raise HTTPException(status_code=404, detail="Run results not found")
    if not run_path.exists() or not run_path.is_file():
        raise HTTPException(status_code=404, detail="Run results not found")
    # Reject path separators so task_name cannot act as a path component.
    if "/" in task_name or "\\" in task_name:
        raise HTTPException(status_code=400, detail="Invalid task name")
    if doc_id < 0:
        raise HTTPException(status_code=400, detail="Invalid doc_id")

    # Dataset loading/decoding is blocking; run it off the event loop.
    mode, payload, media_type = await asyncio.to_thread(_resolve_dataset_media, task_name, doc_id)
    if mode == "file":
        # Video samples are served straight from the resolved file on disk.
        return FileResponse(
            path=str(payload),
            media_type=media_type,
            headers={"Cache-Control": "public, max-age=3600"},
        )
    # Image samples arrive as raw bytes; stream them from memory.
    return StreamingResponse(
        io.BytesIO(payload),
        media_type=media_type,
        headers={"Cache-Control": "public, max-age=3600"},
    )
| 923 | + |
| 924 | + |
715 | 925 | if STATIC_DIR.exists(): |
716 | 926 | app.mount("/assets", StaticFiles(directory=STATIC_DIR / "assets"), name="assets") |
717 | 927 |
|
|
0 commit comments