-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Add Logprobs unit test with a loose threshold #10230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
37 commits
Select commit
Hold shift + click to select a range
883a91c
add test script
PrinsYin 10e4edf
Merge branch 'sgl-project:main' into logprobs
PrinsYin 20fd382
deleted generate
PrinsYin c9dd1a8
added github summary, changed to prompt with 200 average len
PrinsYin bd4d4ba
1
PrinsYin c70105e
Merge branch 'main' into logprobs
PrinsYin d4cf08f
1
PrinsYin 3f6562d
add AMD CI
yushengsu-thu 5d462b0
fix model name
PrinsYin 8785103
Merge branch 'main' into logprobs
zhaochenyang20 e6423cc
Merge branch 'main' into logprobs
zhaochenyang20 a015def
added only ask for logprobs in some of the requests
PrinsYin 8f7328b
1
PrinsYin 1d6d13d
Merge branch 'main' into logprobs
PrinsYin d7aa5d7
fix lint
PrinsYin 89650bf
added to run suite
PrinsYin 5f7546c
cleanup
PrinsYin 209b330
cleanup
PrinsYin 6d228d7
precommit
PrinsYin ad37403
Merge branch 'main' into logprobs
PrinsYin 8ccc321
Update test_logprobs.py
zhaochenyang20 3996faa
Update test_logprobs.py
zhaochenyang20 b16fe4e
Update test_logprobs.py
zhaochenyang20 01d571d
1
PrinsYin e75c092
Merge branch 'main' into logprobs
PrinsYin 5ae3701
threashold, return original
PrinsYin 3c8d2da
Merge branch 'main' into logprobs
zhaochenyang20 fe67652
naming
a6c0c8b
1
2537ae4
Merge branch 'main' into logprobs
PrinsYin fbf987f
1
PrinsYin 595da84
Merge branch 'main' into logprobs
zhaochenyang20 bd6ed55
precommit
2f3f267
Merge branch 'main' into logprobs
zhaochenyang20 9c2f74b
Update test_logprobs.py
zhaochenyang20 334c812
Merge branch 'main' into logprobs
zhaochenyang20 20f3291
Update run_suite.py
zhaochenyang20 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,265 @@ | ||
| import io | ||
| import os | ||
| import pickle | ||
| import random | ||
| import time | ||
| import unittest | ||
|
|
||
| import numpy as np | ||
| import requests | ||
| import torch | ||
|
|
||
| import sglang as sgl | ||
| from sglang.test.test_utils import ( | ||
| DEFAULT_SMALL_MODEL_NAME_FOR_TEST, | ||
| write_github_step_summary, | ||
| ) | ||
|
|
||
| # Dense model configuration | ||
| DENSE_MODEL_NAME = DEFAULT_SMALL_MODEL_NAME_FOR_TEST | ||
| if torch.version.hip is not None: | ||
| print("Running on AMD ROCm GPU") | ||
| DENSE_INPUT_PKL_URL = "https://huggingface.co/datasets/yushengsu/logprobs/resolve/main/sglang_baseline_2000_amd.pkl" | ||
| DENSE_TOLERANCE_MAX_DIFF = 1.4 | ||
| DENSE_TOLERANCE_MEAN_DIFF = 0.1 | ||
| elif torch.version.cuda is not None: | ||
| print("Running on NVIDIA CUDA GPU") | ||
| DENSE_INPUT_PKL_URL = "https://huggingface.co/datasets/font-info/logprobs/resolve/main/sglang_baseline_2000.pkl" | ||
| DENSE_TOLERANCE_MAX_DIFF = 1.5 | ||
| DENSE_TOLERANCE_MEAN_DIFF = 0.1 | ||
| else: | ||
| print("No GPU backend (CPU only)") | ||
|
|
||
| # Common configuration | ||
| TOP_K = 20 | ||
| MAX_RETRIES = 3 | ||
| RETRY_DELAY = 2 | ||
| NUM_SAMPLES = 1000 | ||
| LOGPROB_SAMPLE_RATIO = 0.5 | ||
| TEMPERATURE = 1.0 | ||
|
|
||
|
|
||
| class TestLogprobsDense(unittest.TestCase): | ||
|
|
||
| @classmethod | ||
| def setUpClass(cls): | ||
| """Set up the test class - initialize the engine once for all tests.""" | ||
| print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...") | ||
| cls.engine = sgl.Engine( | ||
| model_path=DENSE_MODEL_NAME, | ||
| random_seed=42, | ||
| skip_tokenizer_init=True, | ||
| mem_fraction_static=0.85, | ||
| ) | ||
|
|
||
| @classmethod | ||
| def tearDownClass(cls): | ||
| """Clean up after all tests - shutdown the engine.""" | ||
| cls.engine.shutdown() | ||
| torch.cuda.empty_cache() | ||
|
|
||
| def load_test_data(self): | ||
| """Load test data from Hugging Face dataset with retry mechanism.""" | ||
| print(f"Loading data from {DENSE_INPUT_PKL_URL}...") | ||
|
|
||
| for attempt in range(MAX_RETRIES): | ||
| try: | ||
| response = requests.get(DENSE_INPUT_PKL_URL, timeout=30) | ||
| response.raise_for_status() | ||
|
|
||
| with io.BytesIO(response.content) as f: | ||
| records = pickle.load(f) | ||
|
|
||
| if not records: | ||
| raise ValueError("Empty dataset") | ||
|
|
||
| print(f"Successfully loaded {len(records)} records") | ||
| return records | ||
|
|
||
| except Exception as e: | ||
| print(f"Attempt {attempt + 1}/{MAX_RETRIES} failed: {e}") | ||
| if attempt == MAX_RETRIES - 1: | ||
| raise Exception( | ||
| f"Failed to load data after {MAX_RETRIES} attempts: {e}" | ||
| ) | ||
| time.sleep(RETRY_DELAY) | ||
|
|
||
| def compare_meta(self, baseline_meta, sglang_meta): | ||
| """Compare metadata between two outputs and return max and mean differences.""" | ||
| diffs = [] | ||
| for key in ["input_top_logprobs", "output_top_logprobs"]: | ||
| baseline_logprobs, sglang_logprobs = baseline_meta[key], sglang_meta[key] | ||
| self.assertEqual( | ||
| len(baseline_logprobs), | ||
| len(sglang_logprobs), | ||
| f"Length of {key} is not equal, sglang did not return the correct number of log probs(should be top 20)", | ||
| ) | ||
| for baseline_entry, sglang_entry in zip(baseline_logprobs, sglang_logprobs): | ||
| if not baseline_entry or not sglang_entry: | ||
| continue | ||
| baseline_token_map = {tid: lp for lp, tid, _ in baseline_entry} | ||
| sglang_token_map = {tid: lp for lp, tid, _ in sglang_entry} | ||
| common_tokens = baseline_token_map.keys() & sglang_token_map.keys() | ||
| self.assertGreaterEqual( | ||
| len(common_tokens), | ||
| TOP_K / 2, | ||
| f"there are only {len(common_tokens)} common topk tokens that matches", | ||
| ) | ||
| for token_id in common_tokens: | ||
| diffs.append( | ||
| abs(baseline_token_map[token_id] - sglang_token_map[token_id]) | ||
| ) | ||
| return max(diffs), float(np.mean(diffs)) | ||
|
|
||
| def test_logprobs_comparison(self): | ||
| """Test the logprobs comparison functionality with different parameter combinations.""" | ||
| # Load test data with retry mechanism | ||
| records = self.load_test_data() | ||
|
|
||
| with self.subTest( | ||
| config={ | ||
| "num_samples": NUM_SAMPLES, | ||
| "logprob_sample_ratio": LOGPROB_SAMPLE_RATIO, | ||
| "temperature": TEMPERATURE, | ||
| } | ||
| ): | ||
|
|
||
| # Sample records for this config | ||
| test_records = random.sample(records, k=min(NUM_SAMPLES, len(records))) | ||
| random.shuffle(test_records) | ||
|
|
||
| # Calculate how many samples should return logprobs | ||
| logprob_count = int(len(test_records) * LOGPROB_SAMPLE_RATIO) | ||
| print( | ||
| f"Testing with {len(test_records)} samples, temperature={TEMPERATURE}" | ||
| ) | ||
| print( | ||
| f"Will return logprobs for {logprob_count} samples (ratio: {LOGPROB_SAMPLE_RATIO})" | ||
| ) | ||
|
|
||
| all_max, all_mean = [], [] | ||
| logprob_returned_count = 0 | ||
|
|
||
| # Process all records at once | ||
| input_ids = [rec["ids"] for rec in test_records] | ||
| logprob_start_lens = [rec["start_pos"] for rec in test_records] | ||
|
|
||
| # Determine which samples should return logprobs (randomly selected) | ||
| logprob_indices = set( | ||
| random.sample(range(len(test_records)), logprob_count) | ||
| ) | ||
| return_logprob_array = [ | ||
| sample_idx in logprob_indices for sample_idx in range(len(test_records)) | ||
| ] | ||
|
|
||
| # Sampling param per request | ||
| sampling_params = [ | ||
| { | ||
| "temperature": TEMPERATURE, | ||
| "top_p": 1.0, | ||
| "top_k": TOP_K, | ||
| "max_new_tokens": 1, | ||
| } | ||
| for _ in test_records | ||
| ] | ||
|
|
||
| outputs = self.engine.generate( | ||
| input_ids=input_ids, | ||
| sampling_params=sampling_params, | ||
| return_logprob=return_logprob_array, | ||
| logprob_start_len=logprob_start_lens, | ||
| top_logprobs_num=TOP_K, | ||
| ) | ||
|
|
||
| for sample_idx, (rec, output) in enumerate(zip(test_records, outputs)): | ||
| # Only compare logprobs for samples that should have them | ||
| if sample_idx in logprob_indices: | ||
| # Safe access to meta_info and input_top_logprobs | ||
| meta_info = output.get("meta_info") | ||
| input_top_logprobs = ( | ||
| meta_info.get("input_top_logprobs") if meta_info else None | ||
| ) | ||
|
|
||
| self.assertIsNotNone( | ||
| input_top_logprobs, | ||
| f"return_logprob enabled on this sample, but input_top_logprobs is None (length: {len(input_top_logprobs) if input_top_logprobs is not None else 'N/A'})", | ||
| ) | ||
| baseline_meta = rec["meta"] | ||
| sglang_meta = meta_info | ||
|
|
||
| max_diff, mean_diff = self.compare_meta(baseline_meta, sglang_meta) | ||
| all_max.append(max_diff) | ||
| all_mean.append(mean_diff) | ||
| logprob_returned_count += 1 | ||
| else: | ||
| # Verify that logprobs were not returned for this sample | ||
| meta_info = output.get("meta_info") | ||
| input_top_logprobs = ( | ||
| meta_info.get("input_top_logprobs") if meta_info else None | ||
| ) | ||
| output_token_ids_logprobs = ( | ||
| meta_info.get("output_token_ids_logprobs") | ||
| if meta_info | ||
| else None | ||
| ) | ||
|
|
||
| self.assertFalse( | ||
| input_top_logprobs, | ||
| f"return_logprob is disabled on this sample, Sample {sample_idx} should not have logprobs, content: {output_token_ids_logprobs}", | ||
| ) | ||
|
|
||
| max_of_max = max(all_max) if all_max else 0.0 | ||
| mean_of_mean = np.mean(all_mean) if all_mean else 0.0 | ||
|
|
||
| print(f"max Δ={max_of_max:.6g}") | ||
| print(f"mean Δ={mean_of_mean:.6g}") | ||
| print( | ||
| f"logprobs returned for {logprob_returned_count} samples (expected: {logprob_count})" | ||
| ) | ||
|
|
||
| # Verify correct number of logprobs returned | ||
| self.assertEqual( | ||
| logprob_returned_count, | ||
| logprob_count, | ||
| f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}", | ||
| ) | ||
|
|
||
| # Write results to GitHub summary | ||
| summary_content = f""" | ||
| - **Configuration**: {{"num_samples": {NUM_SAMPLES}, "logprob_sample_ratio": {LOGPROB_SAMPLE_RATIO}, "temperature": {TEMPERATURE}}} | ||
| - **Max of max Δ**: {max_of_max:.6g} | ||
| - **Mean of mean Δ**: {mean_of_mean:.6g} | ||
| - **Status**: {'✅ Passed' if max_of_max <= DENSE_TOLERANCE_MAX_DIFF and mean_of_mean <= DENSE_TOLERANCE_MEAN_DIFF else '❌ Failed'} | ||
| """ | ||
| write_github_step_summary(summary_content) | ||
|
|
||
| # Basic validation | ||
| self.assertIsInstance(all_max, list) | ||
| self.assertIsInstance(all_mean, list) | ||
| self.assertGreater( | ||
| len(all_max), | ||
| 0, | ||
| f"No test samples processed for config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}}", | ||
| ) | ||
|
|
||
| # Tolerance checks with clear error messages | ||
| failed_samples = [] | ||
| for sample_idx, (max_diff, mean_diff) in enumerate(zip(all_max, all_mean)): | ||
| if max_diff > DENSE_TOLERANCE_MAX_DIFF: | ||
| failed_samples.append( | ||
| f"Sample {sample_idx}: max_diff={max_diff:.6g} > {DENSE_TOLERANCE_MAX_DIFF}" | ||
| ) | ||
| if mean_diff > DENSE_TOLERANCE_MEAN_DIFF: | ||
| failed_samples.append( | ||
| f"Sample {sample_idx}: mean_diff={mean_diff:.6g} > {DENSE_TOLERANCE_MEAN_DIFF}" | ||
| ) | ||
|
|
||
| if failed_samples: | ||
| self.fail( | ||
| f"Config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}} - Tolerance exceeded in {len(failed_samples)} samples:\n" | ||
| + "\n".join(failed_samples[:5]) | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.