Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 63 additions & 64 deletions test/srt/test_eagle_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,79 +20,78 @@


class TestEAGLEEngine(unittest.TestCase):
BASE_CONFIG = {
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 8,
"speculative_num_draft_tokens": 64,
"mem_fraction_static": 0.7,
}

def setUp(self):
self.prompt = "Today is a sunny day and I like"
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}

def test_eagle_accuracy(self):
prompt1 = "Today is a sunny day and I like"
sampling_params1 = {"temperature": 0, "max_new_tokens": 8}

# Get the reference output
ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
ref_output = ref_engine.generate(prompt1, sampling_params1)["text"]
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
ref_engine.shutdown()

# Test cases with different configurations
def test_eagle_accuracy(self):
configs = [
# Original config
{
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 8,
"speculative_num_draft_tokens": 64,
"mem_fraction_static": 0.7,
},
# Config with CUDA graph disabled
{
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 8,
"speculative_num_draft_tokens": 64,
"mem_fraction_static": 0.7,
"disable_cuda_graph": True,
},
self.BASE_CONFIG,
{**self.BASE_CONFIG, "disable_cuda_graph": True},
]

for config in configs:
# Launch EAGLE engine
engine = sgl.Engine(**config)

# Case 1: Test the output of EAGLE engine is the same as normal engine
out1 = engine.generate(prompt1, sampling_params1)["text"]
print(f"{out1=}, {ref_output=}")
self.assertEqual(out1, ref_output)

# Case 2: Test the output of EAGLE engine does not contain unexpected EOS
prompt2 = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
sampling_params2 = {
"temperature": 0,
"max_new_tokens": 1024,
"skip_special_tokens": False,
}
with self.subTest(
cuda_graph=(
"enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
)
):
engine = sgl.Engine(**config)
try:
self._test_basic_generation(engine)
self._test_eos_token(engine)
self._test_batch_generation(engine)
finally:
engine.shutdown()

def _test_basic_generation(self, engine):
output = engine.generate(self.prompt, self.sampling_params)["text"]
print(f"{output=}, {self.ref_output=}")
self.assertEqual(output, self.ref_output)

def _test_eos_token(self, engine):
prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
params = {
"temperature": 0,
"max_new_tokens": 1024,
"skip_special_tokens": False,
}

tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
output = engine.generate(prompt, params)["text"]
print(f"{output=}")

tokens = tokenizer.encode(output, truncation=False)
self.assertNotIn(tokenizer.eos_token_id, tokens)

def _test_batch_generation(self, engine):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
params = {"temperature": 0, "max_new_tokens": 30}

tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
out2 = engine.generate(prompt2, sampling_params2)["text"]
print(f"{out2=}")
tokens = tokenizer.encode(out2, truncation=False)
assert tokenizer.eos_token_id not in tokens

# Case 3: Batched prompts
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params3 = {"temperature": 0, "max_new_tokens": 30}
outputs = engine.generate(prompts, sampling_params3)
for prompt, output in zip(prompts, outputs):
print("===============================")
print(f"Prompt: {prompt}\nGenerated text: {output['text']}")

# Shutdown the engine
engine.shutdown()
outputs = engine.generate(prompts, params)
for prompt, output in zip(prompts, outputs):
print(f"Prompt: {prompt}")
print(f"Generated: {output['text']}")
print("-" * 40)


prompts = [
Expand Down
Loading