diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py index 0ecb8370de7c..3819cc87a94a 100644 --- a/python/sglang/test/run_eval.py +++ b/python/sglang/test/run_eval.py @@ -115,7 +115,11 @@ def run_eval(args): # VLM MMMU evaluation with fixed 100 examples by default from sglang.test.simple_eval_mmmu_vlm import MMMUVLMEval - eval_obj = MMMUVLMEval(args.num_examples, args.num_threads) + eval_obj = MMMUVLMEval( + args.num_examples, + args.num_threads, + response_answer_regex=getattr(args, "response_answer_regex", None), + ) else: raise ValueError(f"Invalid eval name: {args.eval_name}") diff --git a/python/sglang/test/simple_eval_mmmu_vlm.py b/python/sglang/test/simple_eval_mmmu_vlm.py index f13cfd68793c..de0cfbebb62f 100644 --- a/python/sglang/test/simple_eval_mmmu_vlm.py +++ b/python/sglang/test/simple_eval_mmmu_vlm.py @@ -7,6 +7,7 @@ import base64 import io +import re from typing import List, Optional, Tuple from datasets import concatenate_datasets, load_dataset @@ -53,7 +54,11 @@ class MMMUVLMEval(Eval): } def __init__( - self, num_examples: Optional[int] = 100, num_threads: int = 32, seed: int = 42 + self, + num_examples: Optional[int] = 100, + num_threads: int = 32, + seed: int = 42, + response_answer_regex: str = None, ): """Create MMMU VLM eval (Math subset, 100 fixed samples by default).""" self.num_examples = num_examples @@ -61,6 +66,10 @@ def __init__( self.seed = seed # Prepare samples deterministically across all MMMU subjects (validation split) self.samples = self._prepare_mmmu_samples(self.num_examples) + # For example, the pattern "<\|begin_of_box\|>(.*)<\|end_of_box\|>" extracts "foo" from a response containing "<|begin_of_box|>foo<|end_of_box|>"; the pattern must contain a capture group, because group(1) of the match is taken as the answer + self.response_answer_regex = ( + response_answer_regex if response_answer_regex is not None else "(.*)" + ) @staticmethod def _to_data_uri(image: Image.Image) -> str: @@ -205,6 +214,14 @@ def fn(sample: dict): # Sample response_text = sampler(prompt_messages) response_text = response_text or "" + match = ( 
+ re.search(self.response_answer_regex, response_text) + if response_text is not None + else None + ) + response_text = ( + match.group(1).strip() if match is not None else response_text + ) # Parse and score gold = sample["answer"] diff --git a/test/srt/test_pp_single_node.py b/test/srt/test_pp_single_node.py index 037cdb2e0884..86adddaec0a4 100644 --- a/test/srt/test_pp_single_node.py +++ b/test/srt/test_pp_single_node.py @@ -336,6 +336,8 @@ def setUpClass(cls): "--chunked-prefill-size", 8192, "--enable-multimodal", + "--reasoning-parser", + "glm45", ], ) @@ -350,10 +352,12 @@ def test_mmmu(self): eval_name="mmmu", num_examples=None, num_threads=32, + response_answer_regex="<\|begin_of_box\|>(.*)<\|end_of_box\|>", ) + metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.55) + self.assertGreater(metrics["score"], 0.45) if __name__ == "__main__":