diff --git a/docs/model_serving_framework/text_embedding_model_examples.md b/docs/model_serving_framework/text_embedding_model_examples.md index c13a1ab085..c0682ed992 100644 --- a/docs/model_serving_framework/text_embedding_model_examples.md +++ b/docs/model_serving_framework/text_embedding_model_examples.md @@ -293,9 +293,9 @@ POST /_plugins/_ml/models/zwla5YUB1qmVrJFlwzXJ/_unload ## 1.2 trace huggingface transformers model Without [`sentence-transformers`](https://pypi.org/project/sentence-transformers/) installed, you can trace this model `AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')`. -But model traced this way doesn't include post-processing. So user have to specify post-process logic with `pooling_mode` and `normalize_result`. +But model traced this way doesn't include post-processing. So user have to specify post-process logic with `pooling_mode` and `normalize_result`. -Supported pooling method: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`. +Supported pooling method: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`, `lasttoken`. The only difference is the uploading model input, for load/predict/profile/unload model, you can refer to ["1.1 trace sentence transformers model"](#11-trace-sentence-transformers-model). @@ -322,7 +322,18 @@ POST /_plugins/_ml/models/_upload User can export Pytorch model to ONNX, then upload and run it with ml-commons APIs. This example ONNX model also needs to specify post-process logic with `pooling_mode` and `normalize_result`. -Supported pooling method: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`. +Supported pooling method: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`, `lasttoken`. 
+ +### Pooling Methods + +| Method | Description | +|--------|-------------| +| `mean` | Averages all token embeddings weighted by attention mask | +| `mean_sqrt_len` | Sums the token embeddings over non-padding positions, then divides by the square root of the token count | +| `max` | Takes maximum value across all token positions | +| `weightedmean` | Weighted average where later tokens have higher weights | +| `cls` | Uses the first token (CLS token) embedding | +| `lasttoken` | Uses the last non-padding token's embedding. Useful for decoder-only models where the final token captures cumulative context | The only difference is the uploading model input, for load/predict/profile/unload model, you can refer to ["1.1 trace sentence transformers model"](#11-trace-sentence-transformers-model). diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslator.java index 55fea64f9a..286e3655dc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslator.java @@ -102,6 +102,9 @@ public float[] processOutput(TranslatorContext ctx, NDList list) { case "cls": embeddings = embeddings.get(0); break; + case "lasttoken": + embeddings = lastTokenPool(embeddings, inputAttentionMask); + break; default: throw new AssertionError("Unexpected pooling model: " + pooling); } @@ -146,6 +149,18 @@ private NDArray weightedMeanPool(NDArray embeddings, NDArray inputAttentionMask) return embeddingSum.div(maskSum); } + private NDArray lastTokenPool(NDArray embeddings, NDArray inputAttentionMask) { + // Sum attention mask to get count of real tokens. NOTE(review): toFloatArray requires a float32 mask — confirm inputAttentionMask dtype here; the ONNX translator's variant uses toLongArray + long tokenCount = (long) inputAttentionMask.sum().toFloatArray()[0]; + // Last token index (0-based) + long lastTokenIdx = 
tokenCount - 1; + // Clamp so an all-padding (zero-token) input still yields a valid index + if (lastTokenIdx < 0) { + lastTokenIdx = 0; + } + return embeddings.get(lastTokenIdx); + } + /** * Creates a builder to build a {@code TextEmbeddingTranslator}. * @@ -216,10 +231,10 @@ public HuggingfaceTextEmbeddingTranslator.Builder optPoolingMode(String poolingM && !"max".equals(poolingMode) && !"cls".equals(poolingMode) && !"mean_sqrt_len".equals(poolingMode) - && !"weightedmean".equals(poolingMode)) { + && !"weightedmean".equals(poolingMode) + && !"lasttoken".equals(poolingMode)) { throw new IllegalArgumentException( - "Invalid pooling model, must be one of [mean_tokens, max_tokens," - + " cls_token, mean_sqrt_len_tokens, weightedmean_tokens]." + "Invalid pooling mode, must be one of [mean, max, cls, mean_sqrt_len, weightedmean, lasttoken]." ); } this.pooling = poolingMode; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java index 03b993c7cb..10a2391f49 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java @@ -112,6 +112,9 @@ public Output processOutput(TranslatorContext ctx, NDList list) { case CLS: embeddings = embeddings.get(0); break; + case LAST_TOKEN: + embeddings = lastTokenPool(embeddings, inputAttentionMask); + break; default: throw new IllegalArgumentException("Unsupported pooling method"); } @@ -172,6 +175,18 @@ private NDArray weightedMeanPool(NDArray embeddings, NDArray attentionMask) { return embeddingSum.div(maskSum); } + private NDArray lastTokenPool(NDArray embeddings, NDArray attentionMask) { + // Sum attention mask to get count of real tokens + long 
tokenCount = attentionMask.sum().toLongArray()[0]; + // Last token index (0-based) + long lastTokenIdx = tokenCount - 1; + // Clamp so an all-padding (zero-token) input still yields a valid index + if (lastTokenIdx < 0) { + lastTokenIdx = 0; + } + return embeddings.get(lastTokenIdx); + } + @Override public void setArguments(Map arguments) {} } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java index 9546710580..7d8a6e9234 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java @@ -214,6 +214,28 @@ public void initModel_predict_ONNX_albert() throws URISyntaxException { initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, 768); } + @Test + public void initModel_predict_ONNX_LastTokenPooling() throws URISyntaxException { + String modelFile = "all-MiniLM-L6-v2_onnx.zip"; + String modelType = "bert"; + TextEmbeddingModelConfig.PoolingMode poolingMode = TextEmbeddingModelConfig.PoolingMode.LAST_TOKEN; + boolean normalize = true; + int modelMaxLength = 512; + MLModelFormat modelFormat = MLModelFormat.ONNX; + initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, dimension); + } + + @Test + public void initModel_predict_TorchScript_Huggingface_LastTokenPooling() throws URISyntaxException { + String modelFile = "all-MiniLM-L6-v2_torchscript_huggingface.zip"; + String modelType = "bert"; + TextEmbeddingModelConfig.PoolingMode poolingMode = TextEmbeddingModelConfig.PoolingMode.LAST_TOKEN; + boolean normalize = true; + int modelMaxLength = 512; + MLModelFormat modelFormat = MLModelFormat.TORCH_SCRIPT; + initModel_predict_HuggingfaceModel(modelFile, 
modelType, poolingMode, normalize, modelMaxLength, modelFormat, dimension); + } + private void initModel_predict_HuggingfaceModel( String modelFile, String modelType, diff --git a/release-notes/opensearch-ml-commons.release-notes-3.4.0.0.md b/release-notes/opensearch-ml-commons.release-notes-3.4.0.0.md index 87ffd3da65..bf5084c9e9 100644 --- a/release-notes/opensearch-ml-commons.release-notes-3.4.0.0.md +++ b/release-notes/opensearch-ml-commons.release-notes-3.4.0.0.md @@ -3,6 +3,7 @@ Compatible with OpenSearch and OpenSearch Dashboards version 3.4.0 ### Enhancements +* Add LAST_TOKEN pooling support for text embedding models ([#4709](https://github.com/opensearch-project/ml-commons/issues/4709)) * Declare credential and *.Authorization as sensitive param in create connector API ([#4308](https://github.com/opensearch-project/ml-commons/pull/4308)) * Pass resourceType instead of resourceIndex to resourceSharingClient ([#4333](https://github.com/opensearch-project/ml-commons/pull/4333)) * allow higher maximum number of batch inference job tasks ([#4474](https://github.com/opensearch-project/ml-commons/pull/4474))