[data][llm] Add pooling parameter (#59534)
Conversation
Code Review
This pull request adds support for pooling_params in vLLM embedding tasks, allowing users to specify parameters like token truncation and embedding normalization. The changes include updating the vLLM engine stage to handle these parameters and adding corresponding tests to validate the new functionality. My review includes a couple of suggestions for improving maintainability and aligning the implementation with the documented behavior.
@pytest.mark.parametrize(
    "pooling_params",
    [
        {"truncate_prompt_tokens": -1},
Can we also test None and {}?
Added a test case for the empty dict. None is not a possible value because we default to an empty dict when pooling_params is not provided.
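The defaulting rule described here can be sketched in isolation (`resolve_pooling_params` is a hypothetical helper for illustration, not the actual Ray code):

```python
def resolve_pooling_params(row: dict) -> dict:
    # Hypothetical helper illustrating the defaulting rule: if the row
    # does not carry pooling_params (or carries a falsy value), fall back
    # to an empty dict, so the engine stage never sees None.
    return row.get("pooling_params") or {}

# A missing key and an explicit empty dict both resolve to {}.
print(resolve_pooling_params({}))
print(resolve_pooling_params({"pooling_params": {"normalize": True}}))
```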
for key, expected_value in pooling_params.items():
    assert hasattr(request.params, key)
    actual_value = getattr(request.params, key)
    assert actual_value == expected_value
Critical: can we test some other property?
We basically want to test whether truncation or normalization is actually applied. It is not sufficient to check that request.params values match what was sent in.
Idea: we can test that, for an input x, the output differs between truncation=None and truncation=2; similarly, we can compare normalize=False vs. normalize=True.
Validating the difference of outputs is a good idea.
sampling_params do not apply to encode. encode is deterministic.
Added a couple more test cases:
- Compare truncation=None vs. truncation=3
- Compare normalize=False vs. normalize=True
- Validate that truncation is effective on long prompts
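The normalize comparison can be motivated with a minimal, vLLM-free sketch of what L2 normalization does to an embedding vector (pure Python; no Ray or vLLM imports):

```python
import math

def l2_normalize(vec):
    # L2-normalize a vector, which is what normalize=True asks the
    # pooling layer to do to the output embedding.
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

raw = [3.0, 4.0]
normalized = l2_normalize(raw)
# A non-unit vector changes under normalization, so comparing the
# normalize=False and normalize=True outputs is a meaningful test.
print(normalized)  # [0.6, 0.8]
```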
In the latest revision I introduced a small fix: PoolingParams' truncate_prompt_tokens is not respected by AsyncLLMEngine.encode(). I filed vllm-project/vllm#31012 in vLLM and have vllm-project/vllm#31013 open to address it. As a temporary workaround, prompt truncation is handled through the truncate_prompt_tokens argument passed to AsyncLLMEngine.encode. From a Ray user's perspective, any value provided via pooling_params is honored as expected.
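The workaround relies on vLLM's documented left-truncation semantics for truncate_prompt_tokens: a positive k keeps only the last k prompt tokens, while None disables truncation. A minimal sketch of that rule (standalone illustration, not the actual Ray or vLLM code; exact handling of sentinel values like -1 follows vLLM's own rules):

```python
def truncate_prompt(token_ids, truncate_prompt_tokens):
    # Keep only the last k tokens when a positive budget is given
    # (left truncation); otherwise return the prompt unchanged.
    k = truncate_prompt_tokens
    if k is None or k <= 0:
        return token_ids
    return token_ids[-k:]

print(truncate_prompt([101, 7592, 2088, 102], 2))  # [2088, 102]
print(truncate_prompt([101, 7592, 2088, 102], None))
```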
Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>
Signed-off-by: jasonwrwang <jasonwrwang@tencent.com>
Signed-off-by: peterxcli <peterxcli@gmail.com>
Description
The preprocessor forwards only `sampling_params` to the engine today. For `task_type="embed"`, however, we should also allow forwarding `pooling_params`, enabling features such as truncating the input prompt to a fixed token budget via `truncate_prompt_tokens` or normalizing the output embedding via `normalize`. See https://docs.vllm.ai/en/latest/api/vllm/#vllm.PoolingParams for a comprehensive list of supported attributes.
Related issues
Resolves #57805
Additional information
Updated `test_vllm_engine_stage` to validate that the pooling parameters are received by the engine.
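For illustration, a batch row carrying the new parameter might look like the following (the exact field names here are an assumption for illustration; the pooling_params keys mirror attributes of vllm.PoolingParams):

```python
# Hypothetical input row for an embedding task; the pooling_params dict
# is forwarded to the engine alongside the prompt.
row = {
    "prompt": "Ray Data is a scalable data processing library.",
    "pooling_params": {
        "truncate_prompt_tokens": 128,  # keep only the last 128 prompt tokens
        "normalize": True,              # L2-normalize the returned embedding
    },
}
print(sorted(row["pooling_params"]))
```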