89 changes: 45 additions & 44 deletions python/sgl_jax/srt/layers/logits_processor.py
@@ -29,40 +29,39 @@ class LogitsProcessorOutput:
# The logprobs of the next tokens. shape: [#seq]
next_token_logprobs: jax.Array | None = None
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
next_token_top_logprobs_val: list | None = None
next_token_top_logprobs_idx: list | None = None
next_token_top_logprobs_val: jax.Array | None = None
next_token_top_logprobs_idx: jax.Array | None = None
# The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
next_token_token_ids_logprobs_val: list | None = None
next_token_token_ids_logprobs_idx: list | None = None
next_token_token_ids_logprobs_val: jax.Array | None = None
next_token_token_ids_logprobs_idx: jax.Array | None = None

## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
# The logprobs of input tokens. shape: [#token]
input_token_logprobs: jax.Array | None = None
# The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
input_top_logprobs_val: list = None
input_top_logprobs_idx: list = None
input_top_logprobs_val: jax.Array | None = None
input_top_logprobs_idx: jax.Array | None = None
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
input_token_ids_logprobs_val: list | None = None
input_token_ids_logprobs_idx: list | None = None
input_token_ids_logprobs_val: jax.Array | None = None
input_token_ids_logprobs_idx: jax.Array | None = None

def tree_flatten(self):
children = (
self.next_token_logits,
self.hidden_states,
self.next_token_logprobs,
self.input_token_logprobs,
self.next_token_top_logprobs_val,
self.next_token_top_logprobs_idx,
self.next_token_token_ids_logprobs_val,
self.next_token_token_ids_logprobs_idx,
self.input_top_logprobs_val,
self.input_top_logprobs_idx,
self.input_token_ids_logprobs_val,
self.input_token_ids_logprobs_idx,
)

aux_data = {
"next_token_top_logprobs_val": self.next_token_top_logprobs_val,
"next_token_top_logprobs_idx": self.next_token_top_logprobs_idx,
"next_token_token_ids_logprobs_val": self.next_token_token_ids_logprobs_val,
"next_token_token_ids_logprobs_idx": self.next_token_token_ids_logprobs_idx,
"input_top_logprobs_val": self.input_top_logprobs_val,
"input_top_logprobs_idx": self.input_top_logprobs_idx,
"input_token_ids_logprobs_val": self.input_token_ids_logprobs_val,
"input_token_ids_logprobs_idx": self.input_token_ids_logprobs_idx,
}
aux_data = {}
return (children, aux_data)

@classmethod
Expand All @@ -74,14 +73,14 @@ def tree_unflatten(cls, aux_data, children):
obj.next_token_logprobs = children[2]
obj.input_token_logprobs = children[3]

obj.next_token_top_logprobs_val = aux_data["next_token_top_logprobs_val"]
obj.next_token_top_logprobs_idx = aux_data["next_token_top_logprobs_idx"]
obj.next_token_token_ids_logprobs_val = aux_data["next_token_token_ids_logprobs_val"]
obj.next_token_token_ids_logprobs_idx = aux_data["next_token_token_ids_logprobs_idx"]
obj.input_top_logprobs_val = aux_data["input_top_logprobs_val"]
obj.input_top_logprobs_idx = aux_data["input_top_logprobs_idx"]
obj.input_token_ids_logprobs_val = aux_data["input_token_ids_logprobs_val"]
obj.input_token_ids_logprobs_idx = aux_data["input_token_ids_logprobs_idx"]
obj.next_token_top_logprobs_val = children[4]
obj.next_token_top_logprobs_idx = children[5]
obj.next_token_token_ids_logprobs_val = children[6]
obj.next_token_token_ids_logprobs_idx = children[7]
obj.input_top_logprobs_val = children[8]
obj.input_top_logprobs_idx = children[9]
obj.input_token_ids_logprobs_val = children[10]
obj.input_token_ids_logprobs_idx = children[11]

return obj
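Note: moving these fields out of aux_data and into children is what makes them jit-compatible — JAX requires aux_data to be static, hashable metadata that is compared at trace time, while traced jax.Array values must travel as children. A minimal sketch of the same pattern (illustrative class and field names, not the real sgl_jax types):

import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class Output:
    def __init__(self, logits, top_vals=None):
        self.logits = logits      # dynamic leaf: fine to trace under jit
        self.top_vals = top_vals  # dynamic leaf: must be a child, not aux_data

    def tree_flatten(self):
        # Array-valued fields go in children; aux_data stays empty because it
        # must be hashable and is checked for equality when tracing.
        return (self.logits, self.top_vals), {}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@jax.jit
def double(out: Output) -> jax.Array:
    return out.logits * 2

double(Output(jnp.ones(4), jnp.zeros(4)))  # arrays flow through jit as leaves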

@@ -196,7 +195,9 @@ def from_model_worker_batch(cls, batch: ModelWorkerBatch, mesh: Mesh = None):
extend_seq_lens=device_array(batch.extend_seq_lens, sharding=sharding),
extend_seq_lens_cpu=extend_seq_lens_cpu,
extend_logprob_start_lens_cpu=(
batch.extend_logprob_start_lens if batch.return_logprob else None
batch.extend_logprob_start_lens.tolist()
if batch.return_logprob and batch.extend_logprob_start_lens is not None
else None
),
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
top_logprobs_nums=batch.top_logprobs_nums,
@@ -241,18 +242,12 @@ def __call__(
input_logprob_indices_pt = 0
input_logprob_indices = []
pt, pruned_states = 0, []

for extend_logprob_start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
):
if extend_len == 0:
break

start_len = extend_logprob_start_len

# We always need at least 1 token to sample because that's required
# by a caller.
assert extend_len > start_len
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
pt += extend_len
sample_index_pt += extend_len - start_len
@@ -338,14 +333,16 @@ def __call__(
(
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
) = self.get_token_ids_logprobs(input_logprobs, logits_metadata, self.mesh)
else:
input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None

input_token_logprobs = input_logprobs[
device_array(np.arange(input_logprobs.shape[0])),
out_sharding = NamedSharding(self.mesh, P(None))
indices = (
np.arange(input_logprobs.shape[0]),
logits_metadata.extend_input_logprob_token_ids_device,
]
)
input_token_logprobs = input_logprobs.at[indices].get(out_sharding=out_sharding)

return LogitsProcessorOutput(
next_token_logits=sampled_logits,
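Note: the gather above now goes through .at[indices].get(out_sharding=...), stating where the result should live instead of leaving placement to the compiler; this relies on JAX's out_sharding support on indexed reads, which this PR already depends on. A standalone sketch under an assumed single-axis mesh (mesh and axis names are illustrative):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("tp",))
sharded = NamedSharding(mesh, P(None, "tp"))   # vocab dim split across devices
replicated = NamedSharding(mesh, P(None))

logprobs = jax.device_put(jnp.zeros((6, 1024)), sharded)
rows = np.arange(6)
token_ids = np.array([1, 5, 9, 2, 0, 7])

# One logprob per row; out_sharding pins the gathered vector's layout
# (replicated here) instead of leaving it to compiler inference.
vals = logprobs.at[rows, token_ids].get(out_sharding=replicated)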
@@ -358,7 +355,10 @@ def __call__(
)

@staticmethod
def get_token_ids_logprobs(all_logprobs: jax.Array, logits_metadata: LogitsMetadata):
def get_token_ids_logprobs(
all_logprobs: jax.Array, logits_metadata: LogitsMetadata, mesh: Mesh
):
out_sharding = NamedSharding(mesh, P(None))
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
pt = 0
for token_ids, pruned_len in zip(
@@ -371,19 +371,20 @@ def get_token_ids_logprobs(all_logprobs: jax.Array, logits_metadata: LogitsMetad
continue

input_token_ids_logprobs_val.append(
[all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
[
all_logprobs.at[pt + j, token_ids].get(out_sharding=out_sharding)
for j in range(pruned_len)
]
)
input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
pt += pruned_len

return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
return jnp.array(input_token_ids_logprobs_val), jnp.array(input_token_ids_logprobs_idx)

@staticmethod
def get_top_logprobs(all_logprobs: jax.Array, logits_metadata: LogitsMetadata):
max_k = max(logits_metadata.top_logprobs_nums)
values, indices = jax.lax.top_k(all_logprobs, max_k)
values = values.tolist()
indices = indices.tolist()

input_top_logprobs_val, input_top_logprobs_idx = [], []

@@ -401,7 +402,7 @@ def get_top_logprobs(all_logprobs: jax.Array, logits_metadata: LogitsMetadata):
input_top_logprobs_idx.append([indices[pt + j][:k] for j in range(pruned_len)])
pt += pruned_len

return input_top_logprobs_val, input_top_logprobs_idx
return jnp.array(input_top_logprobs_val), jnp.array(input_top_logprobs_idx)

def compute_temp_top_p_normalized_logprobs(
self, last_logits: jax.Array, logits_metadata: LogitsMetadata
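Note on get_top_logprobs: a single jax.lax.top_k over the whole batch at max_k, followed by per-request slices down to each request's own k, keeps the expensive part as one fused device op. The trimming pattern in isolation (toy shapes; the final jnp.array(...) stack only works when the per-request shapes agree):

import jax
import jax.numpy as jnp

logprobs = jax.nn.log_softmax(jnp.arange(12.0).reshape(3, 4))
top_logprobs_nums = [2, 2, 2]                    # per-request k
max_k = max(top_logprobs_nums)

values, indices = jax.lax.top_k(logprobs, max_k)  # [#seq, max_k]
vals = [values[i, :k] for i, k in enumerate(top_logprobs_nums)]
idxs = [indices[i, :k] for i, k in enumerate(top_logprobs_nums)]
vals, idxs = jnp.array(vals), jnp.array(idxs)     # stacking requires equal k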
51 changes: 33 additions & 18 deletions python/sgl_jax/srt/layers/sampler.py
@@ -70,30 +70,49 @@ def _process_logprob_results(self, operands):
"""Process logprob results when return_logprob=True"""
logits_output, sampling_metadata, batch_next_token_ids, logprobs = operands

(
next_token_logprobs,
next_token_top_logprobs_val,
next_token_top_logprobs_idx,
next_token_token_ids_logprobs_val,
next_token_token_ids_logprobs_idx,
) = (None, None, None, None, None)
# Set next_token_logprobs
out_sharding = NamedSharding(self.mesh, P(None))
indices = (np.arange(len(batch_next_token_ids)), batch_next_token_ids)
logits_output.next_token_logprobs = logprobs.at[indices].get(out_sharding=out_sharding)
next_token_logprobs = logprobs.at[indices].get(out_sharding=out_sharding)

# Set top_logprobs if needed
if sampling_metadata.top_logprobs_nums is not None and any(
x > 0 for x in sampling_metadata.top_logprobs_nums
):
(
logits_output.next_token_top_logprobs_val,
logits_output.next_token_top_logprobs_idx,
next_token_top_logprobs_val,
next_token_top_logprobs_idx,
) = get_top_logprobs(logprobs, sampling_metadata.top_logprobs_nums)

# Set token_ids_logprobs if needed
if sampling_metadata.token_ids_logprobs is not None and any(
x is not None for x in sampling_metadata.token_ids_logprobs
):
(
logits_output.next_token_token_ids_logprobs_val,
logits_output.next_token_token_ids_logprobs_idx,
next_token_token_ids_logprobs_val,
next_token_token_ids_logprobs_idx,
) = get_token_ids_logprobs(logprobs, sampling_metadata.token_ids_logprobs, self.mesh)

return None
return LogitsProcessorOutput(
next_token_logits=logits_output.next_token_logits,
next_token_logprobs=next_token_logprobs,
next_token_top_logprobs_val=next_token_top_logprobs_val,
next_token_top_logprobs_idx=next_token_top_logprobs_idx,
next_token_token_ids_logprobs_val=next_token_token_ids_logprobs_val,
next_token_token_ids_logprobs_idx=next_token_token_ids_logprobs_idx,
input_token_logprobs=logits_output.input_token_logprobs,
input_top_logprobs_val=logits_output.input_top_logprobs_val,
input_top_logprobs_idx=logits_output.input_top_logprobs_idx,
input_token_ids_logprobs_val=logits_output.input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=logits_output.input_token_ids_logprobs_idx,
)
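Note: returning a fresh LogitsProcessorOutput here (instead of assigning onto logits_output as before) matters if this path runs under jit — inside a traced function the incoming pytree is a local reconstruction, so in-place attribute writes never reach the caller; only return values do. A toy illustration of that boundary behavior:

import jax
import jax.numpy as jnp

@jax.jit
def process(out: dict) -> dict:
    # 'out' is rebuilt from traced leaves at the jit boundary; mutating it
    # only changes this local copy. Route results through the return value.
    return dict(out, logprobs=jax.nn.log_softmax(out["logits"]))

result = process({"logits": jnp.ones((2, 3))})
# result carries "logprobs"; the caller's original dict is untouched.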

def _apply_linear_penalty(self, operands):
"""
@@ -150,6 +169,7 @@ def __call__(
sampling_metadata: Metadata for sampling
use_sort_for_toppk_minp: whether to use sorting when applying top_k, top_p, and min_p.
"""

# Apply penalties before sampling
logits = lax.cond(
sampling_metadata.do_penalties,
@@ -182,28 +202,23 @@ def __call__(
batch_next_token_ids,
logprobs,
)
lax.cond(
sampling_metadata.return_logprob,
self._process_logprob_results,
lambda operands: None,
logprob_operands,
)
new_logits_output = None
if sampling_metadata.return_logprob:
new_logits_output = self._process_logprob_results(logprob_operands)

return batch_next_token_ids
return batch_next_token_ids, new_logits_output
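Note: the lax.cond removal above is required, not just stylistic — cond branches must return pytrees with exactly matching structure, so one branch yielding a LogitsProcessorOutput and the other None cannot compile. Branching in Python works because return_logprob is known before tracing; a toy version (assuming the flag is a static argument):

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnames=("return_logprob",))
def sample(logits, return_logprob: bool):
    ids = jnp.argmax(logits, axis=-1)
    # A lax.cond over (ids, None) vs (ids, logprobs) would fail: branch
    # outputs must have identical pytree structure. A Python-level branch
    # on a static argument simply traces one path or the other.
    extra = jax.nn.log_softmax(logits) if return_logprob else None
    return ids, extra

ids, extra = sample(jnp.ones((2, 5)), return_logprob=True)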


def get_top_logprobs(logprobs: jax.Array, top_logprobs_nums: list[int]):
max_k = max(top_logprobs_nums)
values, indices = jax.lax.top_k(logprobs, max_k)
values = values.tolist()
indices = indices.tolist()

output_top_logprobs_val = []
output_top_logprobs_idx = []
for i, k in enumerate(top_logprobs_nums):
output_top_logprobs_val.append(values[i][:k])
output_top_logprobs_idx.append(indices[i][:k])
return output_top_logprobs_val, output_top_logprobs_idx
return jnp.array(output_top_logprobs_val), jnp.array(output_top_logprobs_idx)


def get_token_ids_logprobs(logprobs: jax.Array, token_ids_logprobs: list[list[int]], mesh: Mesh):
@@ -213,14 +228,14 @@ def get_token_ids_logprobs(logprobs: jax.Array, token_ids_logprobs: list[list[in
for i, token_ids in enumerate(token_ids_logprobs):
if token_ids is not None:
output_token_ids_logprobs_val.append(
logprobs.at[i, token_ids].get(out_sharding=out_sharding).tolist()
logprobs.at[i, token_ids].get(out_sharding=out_sharding)
)
output_token_ids_logprobs_idx.append(token_ids)
else:
output_token_ids_logprobs_val.append([])
output_token_ids_logprobs_idx.append([])

return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
return jnp.array(output_token_ids_logprobs_val), jnp.array(output_token_ids_logprobs_idx)


def multinomial(
2 changes: 1 addition & 1 deletion python/sgl_jax/srt/managers/io_struct.py
@@ -153,7 +153,7 @@ class GenerateReqInput:
# If return logprobs, the token ids to return logprob for.
token_ids_logprob: list[list[int]] | list[int] | None = None
# Whether to detokenize tokens in text in the returned logprobs.
return_text_in_logprobs: bool = False
return_text_in_logprobs: bool = True

def _normalize_rid(self, num):
"""Normalize request IDs for batch processing."""
6 changes: 6 additions & 0 deletions python/sgl_jax/srt/managers/schedule_batch.py
@@ -1225,6 +1225,12 @@ def get_model_worker_batch(
[0] * bs_padding_size, dtype=extend_seq_lens.dtype
)
extend_seq_lens = np.concat([extend_seq_lens, invalid_extend_seq_lens], axis=0)

invalid_extend_logprob_start_lens = np.array([0] * bs_padding_size, dtype=np.int32)
extend_logprob_start_lens = np.concat(
[extend_logprob_start_lens, invalid_extend_logprob_start_lens], axis=0
)

else:
invalid_extend_start_loc = np.array(
[len(seq_lens_cpu)] * bs_padding_size, dtype=extend_start_loc.dtype
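Note: padding extend_logprob_start_lens alongside the other per-request arrays keeps every field of the batch at the padded batch size, so traced shapes stay stable across steps. The pattern in isolation (illustrative helper name):

import numpy as np

def pad_to_batch(arr: np.ndarray, bs_padding_size: int, fill=0) -> np.ndarray:
    # Extend per-request metadata with neutral entries so every array in the
    # batch shares the padded batch size and jit avoids a recompile.
    if bs_padding_size <= 0:
        return arr
    return np.concatenate([arr, np.full(bs_padding_size, fill, dtype=arr.dtype)])

pad_to_batch(np.array([3, 0, 5], dtype=np.int32), 2)   # -> [3, 0, 5, 0, 0]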
@@ -66,7 +66,6 @@ def process_batch_result_prefill(
logits_output.input_token_logprobs = tuple(
jax.device_get(logits_output.input_token_logprobs).astype(float)
)

# Check finish conditions
logprob_pt = 0
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
@@ -317,8 +316,6 @@ def add_input_logprob_return_values(
assert req.input_token_logprobs_val is not None
return

# Important for the performance.
assert isinstance(output.input_token_logprobs, tuple)
input_token_logprobs: tuple[int] = output.input_token_logprobs
input_token_logprobs = input_token_logprobs[logprob_pt : logprob_pt + num_input_logprobs]
req.input_token_logprobs.extend(input_token_logprobs)
30 changes: 29 additions & 1 deletion python/sgl_jax/srt/managers/tp_worker.py
@@ -453,11 +453,39 @@ def forward_batch_generation(
self._update_grammar_vocab_mask(model_worker_batch, sampling_metadata)

with jtu.count_pjit_cpp_cache_miss() as count:
next_token_ids_device = self.model_runner.sample(
next_token_ids_device, new_logits_output = self.model_runner.sample(
logits_output,
sampling_metadata,
)
cache_miss_count += count()
if new_logits_output is not None:
logits_output = new_logits_output
if logits_output.next_token_top_logprobs_val is not None:
logits_output.next_token_top_logprobs_val = (
logits_output.next_token_top_logprobs_val.astype(jnp.float32).tolist()
)
logits_output.next_token_top_logprobs_idx = (
logits_output.next_token_top_logprobs_idx.tolist()
)
if logits_output.next_token_token_ids_logprobs_val is not None:
logits_output.next_token_token_ids_logprobs_val = (
logits_output.next_token_token_ids_logprobs_val.astype(jnp.float32).tolist()
)
logits_output.next_token_token_ids_logprobs_idx = (
logits_output.next_token_token_ids_logprobs_idx.tolist()
)
if logits_output.input_token_ids_logprobs_val is not None:
logits_output.input_token_ids_logprobs_val = (
logits_output.input_token_ids_logprobs_val.astype(jnp.float32).tolist()
)
logits_output.input_token_ids_logprobs_idx = (
logits_output.input_token_ids_logprobs_idx.tolist()
)
if logits_output.input_top_logprobs_val is not None:
logits_output.input_top_logprobs_val = logits_output.input_top_logprobs_val.astype(
jnp.float32
).tolist()
logits_output.input_top_logprobs_idx = logits_output.input_top_logprobs_idx.tolist()

return (
logits_output,
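Note: the repeated astype(jnp.float32).tolist() blocks above do the device-to-host hop in one place, after sampling. Casting first looks deliberate: on bfloat16 arrays, tolist() would yield ml_dtypes scalars rather than plain Python floats, which downstream serialization may not accept (that rationale is our reading, not stated in the PR). A small helper capturing the pattern (hypothetical name):

import jax
import jax.numpy as jnp

def logprobs_to_host(x: jax.Array | None) -> list | None:
    # float32 cast -> plain Python floats from tolist(); this is also the
    # single device-to-host transfer for the array.
    return None if x is None else x.astype(jnp.float32).tolist()

# e.g. logits_output.next_token_top_logprobs_val = logprobs_to_host(
#     logits_output.next_token_top_logprobs_val)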
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -200,6 +200,7 @@ def run_one_file(filename):
TestFile("test/srt/openai_server/features/test_json_mode.py", 2),
TestFile("test/srt/openai_server/features/test_structural_tag.py", 2),
TestFile("test/srt/test_srt_engine.py", 1),
TestFile("test/srt/test_logprobs.py", 3),
],
"e2e-test-tpu-v6e-4": [
TestFile("test/srt/openai_server/basic/test_tool_calls.py", 3),