From cf4a706a521df06508fce2b1f21d70c47d11a288 Mon Sep 17 00:00:00 2001
From: zhyncs
Date: Sat, 8 Feb 2025 17:29:42 -0800
Subject: [PATCH] minor: cleanup test_eagle_infer

---
 test/srt/test_eagle_infer.py | 127 +++++++++++++++++------------------
 1 file changed, 63 insertions(+), 64 deletions(-)

diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index b04b132110b5..4a617032092d 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -20,79 +20,78 @@
 class TestEAGLEEngine(unittest.TestCase):
+    BASE_CONFIG = {
+        "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+        "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+        "speculative_algorithm": "EAGLE",
+        "speculative_num_steps": 5,
+        "speculative_eagle_topk": 8,
+        "speculative_num_draft_tokens": 64,
+        "mem_fraction_static": 0.7,
+    }
+
+    def setUp(self):
+        self.prompt = "Today is a sunny day and I like"
+        self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
 
-    def test_eagle_accuracy(self):
-        prompt1 = "Today is a sunny day and I like"
-        sampling_params1 = {"temperature": 0, "max_new_tokens": 8}
-
-        # Get the reference output
         ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-        ref_output = ref_engine.generate(prompt1, sampling_params1)["text"]
+        self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
         ref_engine.shutdown()
 
-        # Test cases with different configurations
+    def test_eagle_accuracy(self):
         configs = [
-            # Original config
-            {
-                "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
-                "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-                "speculative_algorithm": "EAGLE",
-                "speculative_num_steps": 5,
-                "speculative_eagle_topk": 8,
-                "speculative_num_draft_tokens": 64,
-                "mem_fraction_static": 0.7,
-            },
-            # Config with CUDA graph disabled
-            {
-                "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
-                "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-                "speculative_algorithm": "EAGLE",
-                "speculative_num_steps": 5,
-                "speculative_eagle_topk": 8,
-                "speculative_num_draft_tokens": 64,
-                "mem_fraction_static": 0.7,
-                "disable_cuda_graph": True,
-            },
+            self.BASE_CONFIG,
+            {**self.BASE_CONFIG, "disable_cuda_graph": True},
         ]
 
         for config in configs:
-            # Launch EAGLE engine
-            engine = sgl.Engine(**config)
-
-            # Case 1: Test the output of EAGLE engine is the same as normal engine
-            out1 = engine.generate(prompt1, sampling_params1)["text"]
-            print(f"{out1=}, {ref_output=}")
-            self.assertEqual(out1, ref_output)
-
-            # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
-            prompt2 = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
-            sampling_params2 = {
-                "temperature": 0,
-                "max_new_tokens": 1024,
-                "skip_special_tokens": False,
-            }
+            with self.subTest(
+                cuda_graph=(
+                    "enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
+                )
+            ):
+                engine = sgl.Engine(**config)
+                try:
+                    self._test_basic_generation(engine)
+                    self._test_eos_token(engine)
+                    self._test_batch_generation(engine)
+                finally:
+                    engine.shutdown()
+
+    def _test_basic_generation(self, engine):
+        output = engine.generate(self.prompt, self.sampling_params)["text"]
+        print(f"{output=}, {self.ref_output=}")
+        self.assertEqual(output, self.ref_output)
+
+    def _test_eos_token(self, engine):
+        prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
+        params = {
+            "temperature": 0,
+            "max_new_tokens": 1024,
+            "skip_special_tokens": False,
+        }
+
+        tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+        output = engine.generate(prompt, params)["text"]
+        print(f"{output=}")
+
+        tokens = tokenizer.encode(output, truncation=False)
+        self.assertNotIn(tokenizer.eos_token_id, tokens)
+
+    def _test_batch_generation(self, engine):
+        prompts = [
+            "Hello, my name is",
+            "The president of the United States is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        params = {"temperature": 0, "max_new_tokens": 30}
 
-            tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-            out2 = engine.generate(prompt2, sampling_params2)["text"]
-            print(f"{out2=}")
-            tokens = tokenizer.encode(out2, truncation=False)
-            assert tokenizer.eos_token_id not in tokens
-
-            # Case 3: Batched prompts
-            prompts = [
-                "Hello, my name is",
-                "The president of the United States is",
-                "The capital of France is",
-                "The future of AI is",
-            ]
-            sampling_params3 = {"temperature": 0, "max_new_tokens": 30}
-            outputs = engine.generate(prompts, sampling_params3)
-            for prompt, output in zip(prompts, outputs):
-                print("===============================")
-                print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
-
-            # Shutdown the engine
-            engine.shutdown()
+        outputs = engine.generate(prompts, params)
+        for prompt, output in zip(prompts, outputs):
+            print(f"Prompt: {prompt}")
+            print(f"Generated: {output['text']}")
+            print("-" * 40)
 
 
 prompts = [