Skip to content

Commit 6517392

Browse files
kouroshHakhagemini-code-assist[bot]
authored andcommitted
[Data][LLM] Add support for classification and scoring models in Ray Data LLM (ray-project#59499)
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: peterxcli <peterxcli@gmail.com>
1 parent 84cee1f commit 6517392

File tree

4 files changed

+160
-4
lines changed

4 files changed

+160
-4
lines changed

doc/source/data/doc_code/working-with-llms/basic_llm_example.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,41 @@ def create_embedding_processor():
197197

198198
# __embedding_config_example_end__
199199

200+
# __classification_config_example_start__
# Sequence classification model configuration.
# Use task_type="classify" for classification models (e.g., sentiment, quality scoring)
# Use task_type="score" for cross-encoder scoring models
classification_config = vLLMEngineProcessorConfig(
    model_source="nvidia/nemocurator-fineweb-nemotron-4-edu-classifier",
    task_type="classify",
    engine_kwargs={
        # Classifier inputs are short; a small context window keeps memory low.
        "max_model_len": 512,
        "enforce_eager": True,
    },
    batch_size=8,
    concurrency=1,
    apply_chat_template=False,
    detokenize=False,
)
216+
217+
218+
# Example usage for classification
def create_classification_processor():
    """Build a Ray Data LLM processor wired to ``classification_config``."""

    def _preprocess(row):
        # The classifier takes the raw text directly as the prompt.
        return dict(prompt=row["text"])

    def _postprocess(row):
        # Classification models return logits in the 'embeddings' field
        logits = row.get("embeddings")
        score = float(logits[0]) if logits is not None and len(logits) > 0 else None
        return {
            "text": row["prompt"],
            "score": score,
        }

    return build_processor(
        classification_config,
        preprocess=_preprocess,
        postprocess=_postprocess,
    )


# __classification_config_example_end__
234+
200235
# __shared_vllm_engine_config_example_start__
201236
import ray
202237
from ray import serve
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
Documentation example and test for classification model batch inference.

This example demonstrates how to use Ray Data LLM with sequence classification
models like educational content classifiers.
"""

import subprocess
import sys

# Install the Ray LLM extras and pin numpy before the example imports them.
for _pip_args in (["--upgrade", "ray[llm]"], ["numpy==1.26.4"]):
    subprocess.check_call([sys.executable, "-m", "pip", "install", *_pip_args])
13+
14+
15+
def run_classification_example():
    """Run batch inference with a sequence classification model on Ray Data."""
    # __classification_example_start__
    import ray
    from ray.data.llm import vLLMEngineProcessorConfig, build_processor

    # Configure vLLM for a sequence classification model
    config = vLLMEngineProcessorConfig(
        model_source="nvidia/nemocurator-fineweb-nemotron-4-edu-classifier",
        task_type="classify",  # Use 'classify' for sequence classification models
        engine_kwargs=dict(
            max_model_len=512,
            enforce_eager=True,
        ),
        batch_size=8,
        concurrency=1,
        apply_chat_template=False,
        detokenize=False,
    )

    def _to_prompt(row):
        # Pass the raw text through as the model prompt.
        return dict(prompt=row["text"])

    def _to_record(row):
        # Classification models return logits in the 'embeddings' field
        logits = row.get("embeddings")
        score = float(logits[0]) if logits is not None and len(logits) > 0 else None
        return {
            "text": row["prompt"],
            "edu_score": score,
        }

    processor = build_processor(
        config,
        preprocess=_to_prompt,
        postprocess=_to_record,
    )

    # Sample texts with varying educational quality
    samples = [
        "lol that was so funny haha",
        "Photosynthesis converts light energy into chemical energy.",
        "Newton's laws describe the relationship between forces and motion.",
    ]
    dataset = ray.data.from_items([{"text": sample} for sample in samples])

    classified = processor(dataset)
    classified.show(limit=3)
    # __classification_example_end__
57+
58+
59+
if __name__ == "__main__":
    # Best-effort entry point: the example needs torch with a CUDA device;
    # on any failure (missing deps, engine errors) we skip rather than fail CI.
    try:
        import torch

        if not torch.cuda.is_available():
            print("Skipping classification example (no GPU available)")
        else:
            run_classification_example()
    except Exception as e:
        print(f"Skipping classification example: {e}")
69+

doc/source/data/working-with-llms.rst

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ This guide shows you how to use :ref:`ray.data.llm <llm-ref>` to:
1111
* :ref:`Perform batch inference with LLMs <batch_inference_llm>`
1212
* :ref:`Configure vLLM for LLM inference <vllm_llm>`
1313
* :ref:`Batch inference with embedding models <embedding_models>`
14+
* :ref:`Batch inference with classification models <classification_models>`
1415
* :ref:`Query deployed models with an OpenAI compatible API endpoint <openai_compatible_api_endpoint>`
1516

1617
.. _vllm_quickstart:
@@ -221,6 +222,39 @@ For a complete embedding configuration example, see:
221222
:start-after: __embedding_config_example_start__
222223
:end-before: __embedding_config_example_end__
223224

225+
.. _classification_models:

Batch inference with classification models
------------------------------------------

Ray Data LLM supports batch inference with sequence classification models, such as content classifiers and sentiment analyzers:

.. literalinclude:: doc_code/working-with-llms/classification_example.py
    :language: python
    :start-after: __classification_example_start__
    :end-before: __classification_example_end__

.. testoutput::
    :options: +MOCK

    {'text': 'lol that was so funny haha', 'edu_score': -0.05}
    {'text': 'Photosynthesis converts light energy...', 'edu_score': 1.73}
    {'text': "Newton's laws describe...", 'edu_score': 2.52}

Key differences for classification models:

- Set ``task_type="classify"`` (or ``task_type="score"`` for scoring models)
- Set ``apply_chat_template=False`` and ``detokenize=False``
- Use direct ``prompt`` input instead of ``messages``
- Access classification logits through ``row["embeddings"]``

For a complete classification configuration example, see:

.. literalinclude:: doc_code/working-with-llms/basic_llm_example.py
    :language: python
    :start-after: __classification_config_example_start__
    :end-before: __classification_config_example_end__
257+
224258
.. _openai_compatible_api_endpoint:
225259

226260
Batch inference with an OpenAI-compatible endpoint

python/ray/llm/_internal/batch/stages/vllm_engine_stage.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ class vLLMTaskType(str, Enum):
6363
"""Generate embeddings."""
6464
EMBED = "embed"
6565

66+
"""Classification (e.g., sequence classification models)."""
67+
CLASSIFY = "classify"
68+
69+
"""Scoring (e.g., cross-encoder models)."""
70+
SCORE = "score"
71+
6672

6773
class vLLMEngineRequest(BaseModel):
6874
"""A request to the vLLM engine."""
@@ -255,7 +261,11 @@ def __init__(
255261
) from e
256262

257263
# Construct PoolerConfig if override_pooler_config is specified.
258-
if self.task_type == vLLMTaskType.EMBED and "override_pooler_config" in kwargs:
264+
if (
265+
self.task_type
266+
in {vLLMTaskType.EMBED, vLLMTaskType.CLASSIFY, vLLMTaskType.SCORE}
267+
and "override_pooler_config" in kwargs
268+
):
259269
kwargs["override_pooler_config"] = vllm.config.PoolerConfig(
260270
**kwargs["override_pooler_config"]
261271
)
@@ -375,7 +385,11 @@ async def _prepare_llm_request(self, row: Dict[str, Any]) -> vLLMEngineRequest:
375385
**maybe_convert_ndarray_to_list(sampling_params),
376386
structured_outputs=structured_outputs,
377387
)
378-
elif self.task_type == vLLMTaskType.EMBED:
388+
elif self.task_type in (
389+
vLLMTaskType.EMBED,
390+
vLLMTaskType.CLASSIFY,
391+
vLLMTaskType.SCORE,
392+
):
379393
params = vllm.PoolingParams(task=self.task_type.value)
380394
else:
381395
raise ValueError(f"Unsupported task type: {self.task_type}")
@@ -456,8 +470,12 @@ async def _generate_async(self, request: vLLMEngineRequest) -> Any:
456470
)
457471

458472
# Send the request to the LLM engine.
459-
# vLLM 0.12.0 uses encode() for pooling/embedding tasks, generate() for text generation
460-
if self.task_type == vLLMTaskType.EMBED:
473+
# vLLM 0.12.0 uses encode() for pooling/embedding/classification tasks, generate() for text generation
474+
if self.task_type in (
475+
vLLMTaskType.EMBED,
476+
vLLMTaskType.CLASSIFY,
477+
vLLMTaskType.SCORE,
478+
):
461479
stream = self.engine.encode(
462480
request_id=str(request.request_id),
463481
prompt=llm_prompt,

0 commit comments

Comments
 (0)