Describe the bug
Request: Enable the SGLang Scoring API (/v1/score) on TPUs with feature parity to the GPU implementation.
The scoring API is already working on GPUs (see PR #10979), but when using the /v1/score endpoint on sglang-jax (TPU), the server crashes during logits processing with AttributeError: 'list' object has no attribute 'tolist'. The crash occurs on the very first scoring request after successful server startup.
Use case: We need to run scoring workloads on TPUs for production LLM preference model (LPM) shelf-scoring, where we compute log probabilities for ranking candidate items. TPU support for the scoring API would enable cost-effective, high-throughput scoring at scale.
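The failure mode is a plain type mismatch, illustrated by this minimal sketch (the variable name start_lens is hypothetical; the actual field in the traceback below is extend_logprob_start_lens):

```python
import numpy as np

start_lens = [0]  # plain Python list, as apparently produced on the TPU path
try:
    start_lens.tolist()  # reproduces the crash: Python lists have no .tolist()
except AttributeError as e:
    print(e)  # prints: 'list' object has no attribute 'tolist'

# numpy arrays (presumably what the GPU path passes) do expose .tolist():
print(np.asarray(start_lens).tolist())  # prints: [0]
```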
Expected behavior
The /v1/score endpoint should work on TPUs with the same functionality as the GPU implementation:
- Accept scoring requests with `query`, `items`, and `label_token_ids`
- Return log probabilities for the specified tokens
- Handle both single and batched requests without crashing
Related
- GPU Scoring API (working): sgl-project/sglang PR #10979 - This implementation works correctly on GPUs and is the reference for the expected behavior on TPUs.
What works (success behavior)
The server successfully completes initialization, and warmup requests to /generate work fine:
[2026-01-23 22:01:37] 'jit_jitted_run_model' took at least 0.00 seconds to compile (9.87s)
[2026-01-23 22:02:45] INFO: Started server process [1]
[2026-01-23 22:02:45] INFO: Application startup complete.
[2026-01-23 22:02:45] INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
[2026-01-23 22:02:46] INFO: 127.0.0.1:36034 - "GET /get_model_info HTTP/1.1" 200 OK
[2026-01-23 22:02:46] INFO: 127.0.0.1:36044 - "POST /generate HTTP/1.1" 200 OK <-- Regular generation works!
[2026-01-23 22:02:46] The server is fired up and ready to roll!
[2026-01-23 22:02:50] INFO: 10.160.192.65:55458 - "GET /health HTTP/1.1" 200 OK
When the first /v1/score request arrives, prefill succeeds:
[2026-01-23 22:02:54] Receive: obj=GenerateReqInput(
batch_size=1,
return_logprob=[True],
logprob_start_len=[-1],
token_ids_logprob=[[0, 1, 2]],
...
)
[2026-01-23 22:02:54] Prefill batch. #new-seq: 1, #new-token: 2375, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
What fails (failure behavior)
During batch processing, the scheduler crashes when constructing LogitsMetadata for the scoring request. The crash happens in forward_batch_generation() when it tries to build logits metadata from the batch:
[2026-01-23 22:02:54] Scheduler hit an exception: Traceback (most recent call last):
File "/app/python/sgl_jax/srt/managers/scheduler.py", line 1496, in run_scheduler_process
scheduler.event_loop_normal()
File "/app/python/sgl_jax/srt/managers/scheduler.py", line 484, in event_loop_normal
result = self.run_batch(batch)
^^^^^^^^^^^^^^^^^^^^^
File "/app/python/sgl_jax/srt/managers/scheduler.py", line 1273, in run_batch
self.tp_worker.forward_batch_generation(
File "/app/python/sgl_jax/srt/managers/tp_worker.py", line 494, in forward_batch_generation
logits_metadata=LogitsMetadata.from_model_worker_batch(model_worker_batch, self.mesh),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/app/python/sgl_jax/srt/layers/logits_processor.py", line 207, in from_model_worker_batch
batch.extend_logprob_start_lens.tolist()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'list' object has no attribute 'tolist'
The server then terminates:
[2026-01-23 22:02:54] Received sigquit from a child process. It usually means the child failed.
[2026-01-23 22:02:54] Clearing JAX backend caches.
[2026-01-23 22:02:54] Dumping requests before crash. crash_dump_folder=None
[2026-01-23 22:02:54] INFO: 127.0.0.1:36056 - "POST /v1/score HTTP/1.1" 500 Internal Server Error
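Based on the traceback alone, one plausible mitigation is to normalize the field before .tolist() is called. This is an untested sketch, not a confirmed fix; as_int_list is a hypothetical helper:

```python
import numpy as np

def as_int_list(x):
    # Accept both a plain Python list (what the TPU scoring path appears to
    # pass) and a numpy/JAX array (which exposes .tolist()).
    return list(x) if isinstance(x, list) else np.asarray(x).tolist()

# In sgl_jax/srt/layers/logits_processor.py, from_model_worker_batch, replace
#   batch.extend_logprob_start_lens.tolist()
# with
#   as_int_list(batch.extend_logprob_start_lens)
```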
Observed behavior
Running a client that sends requests at a controllable QPS against the server shows:
- First request triggers the crash
- Subsequent requests fail because the server becomes unreachable
Test run 1 (RPS=100):
Server logs while our client sends requests:
[2026-01-23 22:02:22] The server is fired up and ready to roll!
[2026-01-23 22:02:28] Receive: obj=GenerateReqInput(batch_size=1, ..., return_logprob=[True], ...)
AttributeError: 'list' object has no attribute 'tolist'
[2026-01-23 22:02:28] INFO: "POST /v1/score HTTP/1.1" 500 Internal Server Error
Test run 2 (RPS=10):
[2026-01-26 15:02:22] The server is fired up and ready to roll!
[2026-01-26 15:02:28] Receive: obj=GenerateReqInput(batch_size=1, ..., return_logprob=[True], ...)
AttributeError: 'list' object has no attribute 'tolist'
[2026-01-26 15:02:28] INFO: "POST /v1/score HTTP/1.1" 500 Internal Server Error
Benchmark client output (identical for both):
WARNING:llm_bench.bench:Backend request failure: Unexpected error: ConnectError('All connection attempts failed').
WARNING:llm_bench.bench:Backend request failure: Unexpected error: ConnectError('All connection attempts failed').
... (repeats for all benchmark requests)
INFO:llm_bench.bench:Benchmark run complete.
Full API call stack
File "/app/python/sgl_jax/srt/entrypoints/http_server.py", line 759, in v1_score_request
return await raw_request.app.state.openai_serving_score.handle_request(request, raw_request)
File "/app/python/sgl_jax/srt/entrypoints/openai/serving_base.py", line 43, in handle_request
return await self._handle_non_streaming_request(
File "/app/python/sgl_jax/srt/entrypoints/openai/serving_score.py", line 43, in _handle_non_streaming_request
scores = await self.tokenizer_manager.score_request(
File "/app/python/sgl_jax/srt/managers/tokenizer_manager.py", line 1260, in score_request
results = await self.generate_request(batch_request, request).__anext__()
File "/app/python/sgl_jax/srt/managers/tokenizer_manager.py", line 280, in generate_request
async for response in self._handle_batch_request(obj, request, created_time):
File "/app/python/sgl_jax/srt/managers/tokenizer_manager.py", line 565, in _handle_batch_request
outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
Configurations tested
Is it a batching or concurrency issue? To find out, we tested:
| Configuration | Result |
|---|---|
| Default batching | ❌ Crashes on 1st request |
| `--max-running-requests=1 --disable-overlap-schedule` | ❌ Crashes on 1st request |
On GPUs, the scoring API production deployment uses --max-running-requests=1 and --disable-overlap-schedule as a workaround for a known batching bug in the scoring code path. We tested the same configuration on TPU to rule out batching as the cause; the crash still occurs on the very first request, so the bug is in the scoring code path itself rather than in batching or concurrency.
GPU production config (working):
--max-running-requests=1
--disable-overlap-schedule
--disable-cuda-graph
--chunked-prefill-size=-1
Reproduction
Model
Qwen3-0.6B
Server launch command
python -m sgl_jax.launch_server \
--model-path=/model-repository/Qwen3-0.6B/Qwen3-0.6B \
--mem-fraction-static=0.50 \
--served-model-name=Qwen3-0.6B \
--host=0.0.0.0 \
--port=8000 \
--disable-radix-cache \
--chunked-prefill-size=-1 \
--enable-metrics \
--attention-backend=fa \
--log-level=debug \
--log-requests
Request that triggers the crash
curl -X POST http://localhost:8000/v1/score \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen3-0.6B",
"query": "Any text prompt here",
"items": ["item1"],
"label_token_ids": [0, 1, 2],
"apply_softmax": false
}'
Environment
| Component | Specification |
|---|---|
| TPU Type | Google Cloud TPU v6e |
| TPU Topology | 1x1 |
| Memory | 128Gi |
| sglang-jax Image | latest |
| Python | 3.12 |
| Platform | GKE |