
Add LAST_TOKEN pooling implementation for text embedding models#4711

Merged
ylwu-amzn merged 2 commits into opensearch-project:main from aneesh-db:feature/last-token-pooling-oss
Mar 25, 2026

Conversation

@aneesh-db (Contributor) commented Mar 10, 2026

Description

Adds the implementation for LAST_TOKEN pooling in text embedding translators. The LAST_TOKEN enum value already exists in PoolingMode but had no actual implementation in the translators.

LAST_TOKEN pooling extracts the embedding of the last non-padding token, which is the correct pooling strategy for decoder-only models (GPT-style, Qwen3, etc.) where the final token captures cumulative context through causal attention.

How it works:

  1. Sum the attention mask to determine the count of real (non-padding) tokens
  2. Extract the embedding at the last non-padding token position (index = token_count - 1)
  3. Handle edge case of empty sequences (default to index 0)
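
The three steps above can be sketched with plain Java arrays — an illustrative sketch only (class and method names here are hypothetical; the actual PR operates on DJL NDArrays inside the translators):

```java
public class LastTokenPoolSketch {
    /**
     * Returns the embedding of the last non-padding token.
     * embeddings: [seqLen][hidden]; attentionMask: [seqLen] of 0/1 values.
     */
    public static float[] lastTokenPool(float[][] embeddings, long[] attentionMask) {
        // Step 1: sum the attention mask to count real (non-padding) tokens
        long tokenCount = 0;
        for (long m : attentionMask) {
            tokenCount += m;
        }
        // Step 2: the last non-padding position is tokenCount - 1 (0-based)
        // Step 3: clamp to index 0 for an empty sequence
        int lastTokenIdx = (int) Math.max(tokenCount - 1, 0);
        return embeddings[lastTokenIdx];
    }

    public static void main(String[] args) {
        float[][] emb = { {1f, 2f}, {3f, 4f}, {5f, 6f} };
        long[] mask = {1, 1, 0}; // last real token is at index 1
        float[] pooled = lastTokenPool(emb, mask);
        System.out.println(pooled[0] + " " + pooled[1]); // 3.0 4.0
    }
}
```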

Changes:

  • Add LAST_TOKEN case to ONNXSentenceTransformerTextEmbeddingTranslator with lastTokenPool() method (uses int64 attention mask via toLongArray())
  • Add lasttoken case to HuggingfaceTextEmbeddingTranslator with lastTokenPool() method (uses float32 attention mask via toFloatArray())
  • Update pooling mode validation to accept lasttoken
  • Add unit tests for both ONNX and TorchScript models
  • Update documentation with pooling method descriptions table
  • Add release notes entry

Validated with Qwen3-Embedding-0.6B, which produced correct 1024-dimensional normalized embeddings matching Python inference output.

Related Issues

Resolves #4709

Check List

  • New functionality includes testing.
  • New functionality has been documented.
  • API changes companion pull request created.
  • Commits are signed per the DCO using --signoff.
  • Public documentation issue/PR created.

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.

@opensearch-trigger-bot opensearch-trigger-bot bot added the documentation Improvements or additions to documentation label Mar 10, 2026

github-actions bot commented Mar 10, 2026

PR Reviewer Guide 🔍

(Review updated until commit cc36e5f)

Here are some key observations to aid the review process:

🧪 PR contains tests
🔒 No security concerns identified
✅ No TODO sections
🔀 No multiple PR themes
⚡ Recommended focus areas for review

Batch Input Handling

The lastTokenPool method sums the entire attention mask across all tokens and sequences to get a single tokenCount. For batched inputs (multiple sequences), this would sum all tokens across all sequences, producing an incorrect index. The method should handle per-sequence token counts if batched inference is supported.

private NDArray lastTokenPool(NDArray embeddings, NDArray inputAttentionMask) {
    // Sum attention mask to get count of real tokens
    long tokenCount = (long) inputAttentionMask.sum().toFloatArray()[0];
    // Last token index (0-based)
    long lastTokenIdx = tokenCount - 1;
    // Handle edge case
    if (lastTokenIdx < 0) {
        lastTokenIdx = 0;
    }
    return embeddings.get(lastTokenIdx);
}
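
A per-sequence variant that would address this concern computes one index per row of a [batch, seq_len] mask. The following is a plain-array sketch under that assumption (the class name is hypothetical; the PR itself operates on DJL NDArrays):

```java
public class BatchedLastTokenPool {
    /**
     * embeddings: [batch][seqLen][hidden]; attentionMask: [batch][seqLen] of 0/1.
     * Returns one pooled vector per sequence: [batch][hidden].
     */
    public static float[][] lastTokenPool(float[][][] embeddings, long[][] attentionMask) {
        float[][] pooled = new float[embeddings.length][];
        for (int b = 0; b < embeddings.length; b++) {
            // Count real tokens in this sequence only, not across the whole batch
            long tokenCount = 0;
            for (long m : attentionMask[b]) {
                tokenCount += m;
            }
            int lastIdx = (int) Math.max(tokenCount - 1, 0);
            pooled[b] = embeddings[b][lastIdx];
        }
        return pooled;
    }

    public static void main(String[] args) {
        float[][][] emb = { { {1f}, {2f}, {3f} }, { {4f}, {5f}, {6f} } };
        long[][] mask = { {1, 1, 1}, {1, 0, 0} }; // sequences of length 3 and 1
        float[][] pooled = lastTokenPool(emb, mask);
        System.out.println(pooled[0][0] + " " + pooled[1][0]); // 3.0 4.0
    }
}
```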
Batch Input Handling

Same issue as in HuggingfaceTextEmbeddingTranslator: the lastTokenPool method sums the entire attention mask to get a single tokenCount, which would be incorrect for batched inputs. If the embeddings tensor has shape [batch, seq_len, hidden], embeddings.get(lastTokenIdx) would index into the batch dimension rather than the sequence dimension.

private NDArray lastTokenPool(NDArray embeddings, NDArray attentionMask) {
    // Sum attention mask to get count of real tokens
    long tokenCount = attentionMask.sum().toLongArray()[0];
    // Last token index (0-based)
    long lastTokenIdx = tokenCount - 1;
    // Handle edge case
    if (lastTokenIdx < 0) {
        lastTokenIdx = 0;
    }
    return embeddings.get(lastTokenIdx);
}
Float Precision

In lastTokenPool, the attention mask sum is retrieved via toFloatArray()[0] and cast to long. Floating-point representation of integer values could theoretically introduce rounding errors (e.g., 3.9999 cast to long = 3). Using toLongArray() or rounding before casting would be safer.

long tokenCount = (long) inputAttentionMask.sum().toFloatArray()[0];
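
The rounding risk is only realistic for very large counts: every integer up to 2^24 is exactly representable in a float, so typical sequence lengths are safe. A small standalone demonstration of where the cast does lose precision (illustrative only):

```java
public class FloatCastDemo {
    public static void main(String[] args) {
        // 2^24 = 16,777,216 is the last integer exactly representable in a float
        System.out.println((long) (float) 16_777_216L); // 16777216
        // 2^24 + 1 is not representable and rounds down to 16,777,216,
        // so casting a float-valued sum back to long can silently lose 1
        System.out.println((long) (float) 16_777_217L); // 16777216
        // Typical sequence lengths are far below this threshold
        System.out.println((long) (float) 512L); // 512
    }
}
```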
Test Model Mismatch

The LAST_TOKEN pooling tests use all-MiniLM-L6-v2 (a BERT encoder-only model) to validate the implementation. LAST_TOKEN pooling is specifically designed for decoder-only (causal) models. While the test may pass functionally, it does not validate the correctness of the pooling strategy for its intended use case, and results may not be meaningful.

@Test
public void initModel_predict_ONNX_LastTokenPooling() throws URISyntaxException {
    String modelFile = "all-MiniLM-L6-v2_onnx.zip";
    String modelType = "bert";
    TextEmbeddingModelConfig.PoolingMode poolingMode = TextEmbeddingModelConfig.PoolingMode.LAST_TOKEN;
    boolean normalize = true;
    int modelMaxLength = 512;
    MLModelFormat modelFormat = MLModelFormat.ONNX;
    initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, dimension);
}

@Test
public void initModel_predict_TorchScript_Huggingface_LastTokenPooling() throws URISyntaxException {
    String modelFile = "all-MiniLM-L6-v2_torchscript_huggingface.zip";
    String modelType = "bert";
    TextEmbeddingModelConfig.PoolingMode poolingMode = TextEmbeddingModelConfig.PoolingMode.LAST_TOKEN;
    boolean normalize = true;
    int modelMaxLength = 512;
    MLModelFormat modelFormat = MLModelFormat.TORCH_SCRIPT;
    initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, dimension);
}


github-actions bot commented Mar 10, 2026

PR Code Suggestions ✨

Latest suggestions up to cc36e5f

Explore these optional code suggestions:

Possible issue
Fix incorrect last token index calculation

The current implementation sums the entire attention mask across all batches and
sequence positions, which is incorrect for batched inputs. For a batch of sequences,
the attention mask has shape [batch_size, seq_len], so summing all values gives the
total token count across all sequences, not per-sequence. The method should find the
last non-padding token index per sequence individually. Additionally, using
toFloatArray()[0] to get a long value may lose precision; the ONNX version correctly
uses toLongArray()[0].

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslator.java [152-162]

 private NDArray lastTokenPool(NDArray embeddings, NDArray inputAttentionMask) {
-    // Sum attention mask to get count of real tokens
-    long tokenCount = (long) inputAttentionMask.sum().toFloatArray()[0];
-    // Last token index (0-based)
-    long lastTokenIdx = tokenCount - 1;
-    // Handle edge case
-    if (lastTokenIdx < 0) {
-        lastTokenIdx = 0;
+    // Find the index of the last non-padding token per sequence
+    // inputAttentionMask shape: [seq_len], embeddings shape: [seq_len, hidden_size]
+    long[] maskValues = inputAttentionMask.toLongArray();
+    long lastTokenIdx = 0;
+    for (long i = maskValues.length - 1; i >= 0; i--) {
+        if (maskValues[(int) i] != 0) {
+            lastTokenIdx = i;
+            break;
+        }
     }
     return embeddings.get(lastTokenIdx);
 }
Suggestion importance[1-10]: 7


Why: The suggestion correctly identifies that using sum() as a proxy for the last token index only works when tokens are left-aligned and contiguous (no padding in the middle). However, for standard transformer tokenization with left-aligned padding, the sum approach is actually equivalent to finding the last non-zero index. The suggestion also correctly points out the toFloatArray()[0] precision issue vs toLongArray()[0]. The improved code is more robust and correct in the general case.

Medium
Fix last token index derivation from attention mask

Summing the entire attention mask to derive the last token index is incorrect when
the mask contains padding zeros interspersed or when the sequence doesn't start from
index 0. The sum gives the count of non-padding tokens, but this count equals the
last token index only if all non-padding tokens are contiguous and left-aligned. The
correct approach is to find the actual index of the last non-zero element in the
attention mask rather than using the sum as a proxy for the index.

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java [178-188]

 private NDArray lastTokenPool(NDArray embeddings, NDArray attentionMask) {
-    // Sum attention mask to get count of real tokens
-    long tokenCount = attentionMask.sum().toLongArray()[0];
-    // Last token index (0-based)
-    long lastTokenIdx = tokenCount - 1;
-    // Handle edge case
-    if (lastTokenIdx < 0) {
-        lastTokenIdx = 0;
+    // Find the index of the last non-padding token
+    long[] maskValues = attentionMask.toLongArray();
+    long lastTokenIdx = 0;
+    for (long i = maskValues.length - 1; i >= 0; i--) {
+        if (maskValues[(int) i] != 0) {
+            lastTokenIdx = i;
+            break;
+        }
     }
     return embeddings.get(lastTokenIdx);
 }
Suggestion importance[1-10]: 7


Why: Similar to suggestion 1, the sum-based approach works for left-aligned sequences but fails for right-padded sequences where padding may appear in the middle or the sequence isn't contiguous. The improved code iterates from the end to find the actual last non-zero index, which is more robust and correct for edge cases.

Medium
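
The difference between the sum-based index and the reverse scan can be seen with a left-padded mask (padding at the front, as some decoder-only tokenizers produce). This is a standalone plain-array demonstration with hypothetical helper names, not code from the PR:

```java
public class MaskIndexDemo {
    // Sum-based index: count of real tokens minus one (assumes left-aligned tokens)
    static int sumBasedIndex(long[] mask) {
        long count = 0;
        for (long m : mask) {
            count += m;
        }
        return (int) Math.max(count - 1, 0);
    }

    // Reverse scan: the actual index of the last non-zero mask entry
    static int lastNonZeroIndex(long[] mask) {
        for (int i = mask.length - 1; i >= 0; i--) {
            if (mask[i] != 0) {
                return i;
            }
        }
        return 0;
    }

    public static void main(String[] args) {
        long[] rightPadded = {1, 1, 1, 0, 0}; // both approaches agree: index 2
        long[] leftPadded = {0, 0, 1, 1, 1};  // sum gives 2, but the last token is at 4
        System.out.println(sumBasedIndex(rightPadded) + " " + lastNonZeroIndex(rightPadded)); // 2 2
        System.out.println(sumBasedIndex(leftPadded) + " " + lastNonZeroIndex(leftPadded));   // 2 4
    }
}
```

For right-padded input the two are equivalent, which is why the existing implementation passes; the reverse scan is simply robust to the padding side.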

Previous suggestions

Suggestions up to commit c144310
Possible issue
Add bounds check to prevent index out of range

The current implementation sums the entire attention mask across all batches and
sequences, which is incorrect for batched inputs. For a batch of sequences, the
attention mask has shape [batch_size, seq_len], so summing all values gives the
total token count across all sequences rather than the last real token index for
each sequence. This will produce wrong results for any batch size > 1. The method
should find the last non-zero position in the attention mask for the relevant
sequence dimension.

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslator.java [152-162]

 private NDArray lastTokenPool(NDArray embeddings, NDArray inputAttentionMask) {
-    // Sum attention mask to get count of real tokens
+    // Find the index of the last real (non-padding) token
+    // inputAttentionMask shape: [seq_len], sum gives count of real tokens
+    long[] maskShape = inputAttentionMask.getShape().getShape();
+    long seqLen = maskShape[0];
+    // Sum along sequence to get count of real tokens
     long tokenCount = (long) inputAttentionMask.sum().toFloatArray()[0];
-    // Last token index (0-based)
     long lastTokenIdx = tokenCount - 1;
-    // Handle edge case
     if (lastTokenIdx < 0) {
         lastTokenIdx = 0;
+    }
+    if (lastTokenIdx >= seqLen) {
+        lastTokenIdx = seqLen - 1;
     }
     return embeddings.get(lastTokenIdx);
 }
Suggestion importance[1-10]: 5


Why: Adding an upper-bound check on lastTokenIdx against the sequence length is a valid defensive measure, but in practice the attention mask sum should never exceed the sequence length since it's a binary mask. The improvement is minor but adds robustness.

Low
Add upper bounds check to prevent index overflow

There is no upper-bound check on lastTokenIdx against the actual sequence length
(first dimension of embeddings). If tokenCount somehow exceeds the sequence length
dimension of the embeddings tensor, calling embeddings.get(lastTokenIdx) will throw
an index-out-of-bounds error. Add a bounds check to clamp lastTokenIdx to the valid
range.

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java [178-188]

 private NDArray lastTokenPool(NDArray embeddings, NDArray attentionMask) {
     // Sum attention mask to get count of real tokens
     long tokenCount = attentionMask.sum().toLongArray()[0];
     // Last token index (0-based)
     long lastTokenIdx = tokenCount - 1;
-    // Handle edge case
+    // Handle edge case: no real tokens
     if (lastTokenIdx < 0) {
         lastTokenIdx = 0;
+    }
+    // Handle edge case: index exceeds embedding sequence length
+    long seqLen = embeddings.getShape().getShape()[0];
+    if (lastTokenIdx >= seqLen) {
+        lastTokenIdx = seqLen - 1;
     }
     return embeddings.get(lastTokenIdx);
 }
Suggestion importance[1-10]: 5

__

Why: Similar to the first suggestion, adding an upper-bound check on lastTokenIdx against seqLen is a valid defensive measure. The attention mask is binary so tokenCount shouldn't exceed sequence length, but the bounds check adds robustness against unexpected inputs.

Low
Suggestions up to commit 6714448
Possible issue
Fix attention mask summation for batched inputs

The current implementation sums the entire attention mask across all batches and
sequences, which is incorrect for batched inputs. The attention mask is typically a
2D tensor (batch_size × sequence_length), so summing all values gives the total
token count across all sequences rather than per-sequence counts. For a
single-sequence case, this may work, but the logic should use per-row last-token
indexing. Additionally, using toFloatArray()[0] to get a long value may lose
precision for large counts.

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslator.java [152-162]

 private NDArray lastTokenPool(NDArray embeddings, NDArray inputAttentionMask) {
-    // Sum attention mask to get count of real tokens
-    long tokenCount = (long) inputAttentionMask.sum().toFloatArray()[0];
-    // Last token index (0-based)
-    long lastTokenIdx = tokenCount - 1;
-    // Handle edge case
-    if (lastTokenIdx < 0) {
-        lastTokenIdx = 0;
-    }
+    // Sum attention mask along sequence dimension to get count of real tokens per sequence
+    NDArray tokenCounts = inputAttentionMask.sum(new int[]{1}); // shape: [batch_size]
+    // Last token index (0-based) for the first (or only) sequence
+    long tokenCount = tokenCounts.toLongArray()[0];
+    long lastTokenIdx = Math.max(tokenCount - 1, 0);
     return embeddings.get(lastTokenIdx);
 }
Suggestion importance[1-10]: 5

__

Why: The suggestion raises a valid concern about batched inputs and precision loss from toFloatArray()[0], but the improved code still only handles the first sequence ([0]) rather than properly handling all sequences in a batch. The fix is incomplete and the existing code in ONNXSentenceTransformerTextEmbeddingTranslator already uses toLongArray()[0] which avoids the float precision issue. The batched input concern is real but the proposed solution doesn't fully address it.

Low

The LAST_TOKEN enum value exists in PoolingMode but has no implementation
in the translators. This adds the actual pooling logic that extracts the
embedding of the last non-padding token, which is needed for decoder-only
models (GPT-style, Qwen3, etc.) where the final token captures cumulative
context through causal attention.

Changes:
- Add LAST_TOKEN case to ONNXSentenceTransformerTextEmbeddingTranslator
  with lastTokenPool() method using int64 attention mask
- Add lasttoken case to HuggingfaceTextEmbeddingTranslator with
  lastTokenPool() method using float32 attention mask
- Update pooling mode validation to include lasttoken
- Add unit tests for ONNX and TorchScript models
- Update documentation with pooling method descriptions
- Add release notes entry

Resolves opensearch-project#4709

Signed-off-by: Aneesh Nema <aneesh.nema@databricks.com>
@aneesh-db aneesh-db force-pushed the feature/last-token-pooling-oss branch from 6714448 to c144310 Compare March 11, 2026 08:19
@github-actions

Persistent review updated to latest commit c144310

@github-actions

Persistent review updated to latest commit cc36e5f

return embeddingSum.div(maskSum);
}

private NDArray lastTokenPool(NDArray embeddings, NDArray inputAttentionMask) {
Collaborator:

can this method be reused rather than defined twice?

Contributor Author:

Addressed in #4744

Compatible with OpenSearch and OpenSearch Dashboards version 3.4.0

### Enhancements
* Add LAST_TOKEN pooling support for text embedding models ([#4709](https://github.com/opensearch-project/ml-commons/issues/4709))
Collaborator:

You can skip adding this. Let's remove this change; release notes are auto-generated!

@ylwu-amzn ylwu-amzn merged commit fc7b333 into opensearch-project:main Mar 25, 2026
8 of 12 checks passed
@aneesh-db aneesh-db deleted the feature/last-token-pooling-oss branch March 27, 2026 10:12
aneesh-db added a commit to aneesh-db/ml-commons that referenced this pull request Mar 31, 2026
Release notes are auto-generated, removing entries added in opensearch-project#4710 and opensearch-project#4711.

Signed-off-by: Aneesh Nema <aneesh.nema@databricks.com>


Development

Successfully merging this pull request may close these issues.

[FEATURE] Add LAST_TOKEN pooling mode for decoder-only text embedding models

4 participants