diff --git a/python/sglang/multimodal_gen/docs/contributing.md b/python/sglang/multimodal_gen/docs/contributing.md index 33a4699b8fa2..d937b7ef40c8 100644 --- a/python/sglang/multimodal_gen/docs/contributing.md +++ b/python/sglang/multimodal_gen/docs/contributing.md @@ -29,12 +29,12 @@ For PRs that impact **latency**, **throughput**, or **memory usage**, you **shou 1. **Baseline**: run the benchmark (for a single generation task) ```bash - $ sglang generate --model-path --prompt "A benchmark prompt" --perf-dump-path baseline.json + $ sglang generate --model-path --prompt "A benchmark prompt" --perf-dump-path baseline.json [--warmup] ``` 2. **New**: run the same benchmark, without modifying any server_args or sampling_params ```bash - $ sglang generate --model-path --prompt "A benchmark prompt" --perf-dump-path new.json + $ sglang generate --model-path --prompt "A benchmark prompt" --perf-dump-path new.json [--warmup] ``` 3. **Compare**: run the compare script, which will print a Markdown table to the console diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py b/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py index 0aad34db258d..0a9d03e9f2b7 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py @@ -8,6 +8,8 @@ import os from typing import cast +from tqdm import tqdm + from sglang.multimodal_gen import DiffGenerator from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams from sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand @@ -41,6 +43,11 @@ def add_multimodal_gen_generate_args(parser: argparse.ArgumentParser): required=False, help="Path to dump the performance metrics (JSON) for the run.", ) + parser.add_argument( + "--warmup", + action="store_true", + help="Run a warm-up phase to exclude compilation overhead from performance measurements, if needed. Requires --perf-dump-path to be set.", + ) parser = ServerArgs.add_cli_args(parser) parser = SamplingParams.add_cli_args(parser) @@ -86,6 +93,9 @@ def maybe_dump_performance(args: argparse.Namespace, server_args, prompt: str, r def generate_cmd(args: argparse.Namespace): """The entry point for the generate command.""" + if args.warmup and args.perf_dump_path is None: + raise ValueError("--warmup requires --perf-dump-path to be specified") + args.request_id = "mocked_fake_id_for_offline_generate" server_args = ServerArgs.from_cli_args(args) @@ -94,6 +104,12 @@ def generate_cmd(args: argparse.Namespace): model_path=server_args.model_path, server_args=server_args ) + if args.warmup: + for _ in tqdm(range(10)): + generator.generate( + sampling_params_kwargs={**sampling_params_kwargs, "save_output": False} + ) + results = generator.generate(sampling_params_kwargs=sampling_params_kwargs) prompt = sampling_params_kwargs.get("prompt", None)