
[Bug] TPU Support / Fixes for Scoring API (/v1/score) #727

@shireen-bean


Describe the bug

The Bug

Request: Enable the SGLang Scoring API (/v1/score) on TPUs with feature parity to the GPU implementation.

The scoring API is already working on GPUs (see PR #10979), but when using the /v1/score endpoint on sglang-jax (TPU), the server crashes during logits processing with AttributeError: 'list' object has no attribute 'tolist'. The crash occurs on the very first scoring request after successful server startup.

Use case: We need to run scoring workloads on TPUs for production LLM preference model (LPM) shelf-scoring, where we compute log probabilities for ranking candidate items. TPU support for the scoring API would enable cost-effective, high-throughput scoring at scale.

Expected behavior

The /v1/score endpoint should work on TPUs with the same functionality as the GPU implementation:

  • Accept scoring requests with query, items, and label_token_ids
  • Return log probabilities for the specified tokens
  • Handle both single and batched requests without crashing
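
For illustration, a minimal client call matching the behavior above, assuming the TPU endpoint mirrors the GPU implementation from PR #10979 (the exact response layout is an assumption, not confirmed from the TPU code):

import requests

# Hypothetical example; response fields are assumed to mirror the GPU path.
resp = requests.post(
    "http://localhost:8000/v1/score",
    json={
        "model": "Qwen3-0.6B",
        "query": "Rank this candidate item:",
        "items": ["item1", "item2"],
        "label_token_ids": [0, 1, 2],
        "apply_softmax": False,
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # expected: one score list per item, covering token ids 0, 1, 2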

Related

  • GPU Scoring API (working): sgl-project/sglang PR #10979 - This implementation works correctly on GPUs and is the reference for the expected behavior on TPUs.

What works (success behavior)

The server completes initialization successfully, and warmup requests to /generate work fine:

[2026-01-23 22:01:37] 'jit_jitted_run_model' took at least 0.00 seconds to compile (9.87s)
[2026-01-23 22:02:45] INFO:     Started server process [1]
[2026-01-23 22:02:45] INFO:     Application startup complete.
[2026-01-23 22:02:45] INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
[2026-01-23 22:02:46] INFO:     127.0.0.1:36034 - "GET /get_model_info HTTP/1.1" 200 OK
[2026-01-23 22:02:46] INFO:     127.0.0.1:36044 - "POST /generate HTTP/1.1" 200 OK  <-- Regular generation works!
[2026-01-23 22:02:46] The server is fired up and ready to roll!
[2026-01-23 22:02:50] INFO:     10.160.192.65:55458 - "GET /health HTTP/1.1" 200 OK

When the first /v1/score request arrives, prefill succeeds:

[2026-01-23 22:02:54] Receive: obj=GenerateReqInput(
    batch_size=1, 
    return_logprob=[True], 
    logprob_start_len=[-1], 
    token_ids_logprob=[[0, 1, 2]], 
    ...
)
[2026-01-23 22:02:54] Prefill batch. #new-seq: 1, #new-token: 2375, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0

What fails (failure behavior)

During batch processing, the scheduler crashes while constructing LogitsMetadata for the scoring request, inside forward_batch_generation():

[2026-01-23 22:02:54] Scheduler hit an exception: Traceback (most recent call last):
  File "/app/python/sgl_jax/srt/managers/scheduler.py", line 1496, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/app/python/sgl_jax/srt/managers/scheduler.py", line 484, in event_loop_normal
    result = self.run_batch(batch)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/app/python/sgl_jax/srt/managers/scheduler.py", line 1273, in run_batch
    self.tp_worker.forward_batch_generation(
  File "/app/python/sgl_jax/srt/managers/tp_worker.py", line 494, in forward_batch_generation
    logits_metadata=LogitsMetadata.from_model_worker_batch(model_worker_batch, self.mesh),
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/python/sgl_jax/srt/layers/logits_processor.py", line 207, in from_model_worker_batch
    batch.extend_logprob_start_lens.tolist()
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'list' object has no attribute 'tolist'

The server then terminates:

[2026-01-23 22:02:54] Received sigquit from a child process. It usually means the child failed.
[2026-01-23 22:02:54] Clearing JAX backend caches.
[2026-01-23 22:02:54] Dumping requests before crash. crash_dump_folder=None
[2026-01-23 22:02:54] INFO:     127.0.0.1:36056 - "POST /v1/score HTTP/1.1" 500 Internal Server Error
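
The immediate cause is visible in the traceback: on the sgl-jax path, batch.extend_logprob_start_lens is a plain Python list, while from_model_worker_batch assumes a numpy array with a .tolist() method. A minimal sketch of the kind of guard that would avoid the crash (the helper name is ours, not from the codebase):

import numpy as np

def _as_py_list(x):
    # extend_logprob_start_lens arrives as a numpy array on the GPU path
    # but as a plain Python list on the TPU (sgl-jax) path; accept both.
    return x.tolist() if isinstance(x, np.ndarray) else list(x)

# In LogitsMetadata.from_model_worker_batch, the failing call
#   batch.extend_logprob_start_lens.tolist()
# would become
#   _as_py_list(batch.extend_logprob_start_lens)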

Observed behavior

Running a client that sends requests at a controllable rate (RPS) against the server shows:

  • First request triggers the crash
  • Subsequent requests fail because the server becomes unreachable

Test run 1 (RPS=100):

Server logs observed during the run:

[2026-01-23 22:02:22] The server is fired up and ready to roll!
[2026-01-23 22:02:28] Receive: obj=GenerateReqInput(batch_size=1, ..., return_logprob=[True], ...)
AttributeError: 'list' object has no attribute 'tolist'
[2026-01-23 22:02:28] INFO:     "POST /v1/score HTTP/1.1" 500 Internal Server Error

Test run 2 (RPS=10):

[2026-01-26 15:02:22] The server is fired up and ready to roll!
[2026-01-26 15:02:28] Receive: obj=GenerateReqInput(batch_size=1, ..., return_logprob=[True], ...)
AttributeError: 'list' object has no attribute 'tolist'
[2026-01-26 15:02:28] INFO:     "POST /v1/score HTTP/1.1" 500 Internal Server Error

Benchmark client output (identical for both):

WARNING:llm_bench.bench:Backend request failure: Unexpected error: ConnectError('All connection attempts failed').
WARNING:llm_bench.bench:Backend request failure: Unexpected error: ConnectError('All connection attempts failed').
... (repeats for all benchmark requests)
INFO:llm_bench.bench:Benchmark run complete.

Full API call stack

File "/app/python/sgl_jax/srt/entrypoints/http_server.py", line 759, in v1_score_request
    return await raw_request.app.state.openai_serving_score.handle_request(request, raw_request)
File "/app/python/sgl_jax/srt/entrypoints/openai/serving_base.py", line 43, in handle_request
    return await self._handle_non_streaming_request(
File "/app/python/sgl_jax/srt/entrypoints/openai/serving_score.py", line 43, in _handle_non_streaming_request
    scores = await self.tokenizer_manager.score_request(
File "/app/python/sgl_jax/srt/managers/tokenizer_manager.py", line 1260, in score_request
    results = await self.generate_request(batch_request, request).__anext__()
File "/app/python/sgl_jax/srt/managers/tokenizer_manager.py", line 280, in generate_request
    async for response in self._handle_batch_request(obj, request, created_time):
File "/app/python/sgl_jax/srt/managers/tokenizer_manager.py", line 565, in _handle_batch_request
    outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))

Configurations tested

Is it a batching or concurrency issue? To find out, we tested:

Configuration                                          Result
Default batching                                       ❌ Crashes on 1st request
--max-running-requests=1 --disable-overlap-schedule    ❌ Crashes on 1st request

On GPUs, the scoring API production deployment uses --max-running-requests=1 and --disable-overlap-schedule as a workaround for a known batching bug in the scoring code path. We tested the same configuration on TPU, and the crash still occurs on the very first request, which rules out batching and concurrency and points at the scoring code path itself.

GPU production config (working):

--max-running-requests=1
--disable-overlap-schedule
--disable-cuda-graph
--chunked-prefill-size=-1

Reproduction

Model

Qwen3-0.6B

Server launch command

python -m sgl_jax.launch_server \
  --model-path=/model-repository/Qwen3-0.6B/Qwen3-0.6B \
  --mem-fraction-static=0.50 \
  --served-model-name=Qwen3-0.6B \
  --host=0.0.0.0 \
  --port=8000 \
  --disable-radix-cache \
  --chunked-prefill-size=-1 \
  --enable-metrics \
  --attention-backend=fa \
  --log-level=debug \
  --log-requests

Request that triggers the crash

curl -X POST http://localhost:8000/v1/score \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen3-0.6B",
    "query": "Any text prompt here",
    "items": ["item1"],
    "label_token_ids": [0, 1, 2],
    "apply_softmax": false
  }'
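
The same request as a self-contained Python script, for scripted reproduction (payload identical to the curl above; requests is the only dependency):

import requests

payload = {
    "model": "Qwen3-0.6B",
    "query": "Any text prompt here",
    "items": ["item1"],
    "label_token_ids": [0, 1, 2],
    "apply_softmax": False,
}

# The first request returns 500 and crashes the scheduler; any request after
# that fails with a connection error because the server process has exited.
resp = requests.post("http://localhost:8000/v1/score", json=payload, timeout=30)
print(resp.status_code, resp.text)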

Environment

Component           Specification
TPU Type            Google Cloud TPU v6e
TPU Topology        1x1
Memory              128Gi
sglang-jax Image    latest
Python              3.12
Platform            GKE
