Skip to content

support speculative decoding kernel in sgl-kernel#3373

Merged
zhyncs merged 8 commits intomainfrom
ying-spec
Feb 7, 2025
Merged

support speculative decoding kernel in sgl-kernel#3373
zhyncs merged 8 commits intomainfrom
ying-spec

Conversation

@zhyncs
Copy link
Collaborator

@zhyncs zhyncs commented Feb 7, 2025

Motivation

  • support build_tree_kernel build_tree_kernel_efficient tree_speculative_sampling_target_only by @Ying1123
  • bump v0.0.3.post2
sglang git:(ying-spec) python3 python/sglang/srt/speculative/build_eagle_tree.py
=========== build tree kernel ==========
position=tensor([ 5,  6,  6,  7,  7,  8,  8,  9, 10, 11, 12, 12, 12, 12, 13, 14],
       device='cuda:0')
retrive_index=tensor([[ 0, -1, -1, -1, -1, -1],
        [ 0,  2,  4,  6, -1, -1],
        [ 0,  1,  3,  5,  7, -1],
        [ 8, -1, -1, -1, -1, -1],
        [ 8,  9, 10, -1, -1, -1],
        [ 8,  9, 12, -1, -1, -1],
        [ 8,  9, 13, -1, -1, -1],
        [ 8,  9, 11, 14, 15, -1]], device='cuda:0')
retrive_cum_len=tensor([0, 3, 8], device='cuda:0', dtype=torch.int32)
draft_tokens=tensor([29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,    13,    13,
        22550,  4136, 16492,  8439, 29871, 29941], device='cuda:0')
=========== build tree kernel efficient ==========
position=tensor([ 5,  6,  6,  7,  7,  8,  8,  9, 10, 11, 12, 12, 12, 12, 13, 14],
       device='cuda:0')
retrive_index=tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15]], device='cuda:0')
retrive_next_token=tensor([[ 1,  3,  4,  5,  6,  7, -1, -1],
        [ 1,  2, -1,  6, -1, -1,  7, -1]], device='cuda:0')
retrive_next_sibling=tensor([[-1,  2, -1, -1, -1, -1, -1, -1],
        [-1, -1,  3,  4,  5, -1, -1, -1]], device='cuda:0')
draft_tokens=tensor([29974, 29896, 29906, 29889, 29974, 29946, 29896, 29946,    13,    13,
        22550,  4136, 16492,  8439, 29871, 29941], device='cuda:0')
sgl-kernel git:(ying-spec) python3 tests/test_speculative_sampling.py
candidates=tensor([[ 0,  1,  2,  3,  4,  5],
        [ 7,  8,  9, 10, 11, 12]], device='cuda:0', dtype=torch.int32)
retrive_index=tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]], device='cuda:0', dtype=torch.int32)
retrive_next_token=tensor([[ 1,  2, -1,  4,  5, -1],
        [ 4,  2,  3, -1,  5, -1]], device='cuda:0', dtype=torch.int32)
retrive_next_sibling=tensor([[-1,  3, -1, -1, -1, -1],
        [-1, -1, -1, -1,  1, -1]], device='cuda:0', dtype=torch.int32)
coins=tensor([[0.0536, 0.7639, 0.4346, 0.6656, 0.2928, 0.9630],
        [0.0203, 0.2049, 0.6945, 0.1469, 0.6007, 0.9195]], device='cuda:0')
predicts=tensor([ 3, -1, -1,  4,  5, 18, 11, -1, -1, -1, 12, 18], device='cuda:0',
       dtype=torch.int32)
accept_index=tensor([[ 0,  3,  4,  5],
        [ 6, 10, 11, -1]], device='cuda:0', dtype=torch.int32)
accept_token_num=tensor([3, 2], device='cuda:0', dtype=torch.int32)

Modifications

Checklist

  • Format your code according to the Code Formatting with Pre-Commit.
  • Add unit tests as outlined in the Running Unit Tests.
  • Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.

@zhyncs zhyncs merged commit f9905d5 into main Feb 7, 2025
22 of 25 checks passed
@zhyncs zhyncs deleted the ying-spec branch February 7, 2025 12:29
@zhyncs zhyncs mentioned this pull request Feb 10, 2025
13 tasks
timethink pushed a commit to timethink/sglang that referenced this pull request Mar 9, 2025
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants