Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,27 +80,29 @@ public NDList processInput(TranslatorContext ctx, String input) {
public float[] processOutput(TranslatorContext ctx, NDList list) {
NDArray embeddings;

// NONE pooling mode uses pre-pooled output directly if available
// NONE pooling mode uses pre-pooled output directly — early return to skip unnecessary work
if ("none".equals(pooling)) {
// Try to get pre-pooled output (sentence_embedding, pooler_output, etc.)
embeddings = list.get("sentence_embedding");
if (embeddings == null) {
embeddings = list.get("pooler_output");
}
if (embeddings == null && list.size() > 1) {
// Use second output if available
embeddings = list.get(1);
}
if (embeddings == null) {
// Fallback to first output
embeddings = list.get(0);
}
} else {
// For other pooling modes, use last_hidden_state or first output
embeddings = list.get("last_hidden_state");
if (embeddings == null) {
embeddings = list.get(0);
if (normalize) {
embeddings = embeddings.normalize(2, 0);
}
return embeddings.toFloatArray();
}

// For other pooling modes, use last_hidden_state or first output
embeddings = list.get("last_hidden_state");
if (embeddings == null) {
embeddings = list.get(0);
}
Encoding encoding = (Encoding) ctx.getAttachment("encoding");
long[] attentionMask = encoding.getAttentionMask();
Expand All @@ -123,10 +125,7 @@ public float[] processOutput(TranslatorContext ctx, NDList list) {
embeddings = embeddings.get(0);
break;
case "lasttoken":
embeddings = lastTokenPool(embeddings, inputAttentionMask);
break;
case "none":
// No pooling - use pre-pooled output as-is
embeddings = TextEmbeddingPoolingUtils.lastTokenPool(embeddings, inputAttentionMask);
break;
default:
throw new AssertionError("Unexpected pooling model: " + pooling);
Expand Down Expand Up @@ -172,18 +171,6 @@ private NDArray weightedMeanPool(NDArray embeddings, NDArray inputAttentionMask)
return embeddingSum.div(maskSum);
}

/**
 * Extracts the embedding of the last non-padding token. Used for decoder-only models
 * where the final token captures cumulative context through causal attention.
 *
 * @param embeddings the token embeddings (sequence_length x hidden_size)
 * @param inputAttentionMask the attention mask (sequence_length), 1 for real tokens, 0 for padding
 * @return the embedding at the last non-padding token position
 */
private NDArray lastTokenPool(NDArray embeddings, NDArray inputAttentionMask) {
    // Sum attention mask to get count of real tokens.
    // Read the sum as long directly: going through toFloatArray() and casting loses
    // precision for large counts and is inconsistent with the sibling implementation.
    long tokenCount = inputAttentionMask.sum().toLongArray()[0];
    // Last token index (0-based); clamp to 0 for an all-padding (empty) mask
    long lastTokenIdx = Math.max(tokenCount - 1, 0);
    return embeddings.get(lastTokenIdx);
}

/**
* Creates a builder to build a {@code TextEmbeddingTranslator}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
embeddings = embeddings.get(0);
break;
case LAST_TOKEN:
embeddings = lastTokenPool(embeddings, inputAttentionMask);
embeddings = TextEmbeddingPoolingUtils.lastTokenPool(embeddings, inputAttentionMask);
break;
case NONE:
// No pooling - use pre-pooled output as-is
Expand Down Expand Up @@ -187,18 +187,6 @@ private NDArray weightedMeanPool(NDArray embeddings, NDArray attentionMask) {
return embeddingSum.div(maskSum);
}

/**
 * Returns the embedding row for the final real (non-padding) token, as indicated by
 * the attention mask. Decoder-only models accumulate context into this position.
 *
 * @param embeddings token embeddings, one row per sequence position
 * @param attentionMask mask with 1 for real tokens and 0 for padding
 * @return the row of {@code embeddings} at the last real-token index
 */
private NDArray lastTokenPool(NDArray embeddings, NDArray attentionMask) {
    // Number of real tokens = sum of the 0/1 mask
    long realTokens = attentionMask.sum().toLongArray()[0];
    // Guard against an all-padding mask: fall back to position 0
    long index = realTokens > 0 ? realTokens - 1 : 0;
    return embeddings.get(index);
}

@Override
public void setArguments(Map<String, ?> arguments) {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.algorithms.text_embedding;

import ai.djl.ndarray.NDArray;

/**
* Shared pooling utility methods for text embedding translators.
*/
final class TextEmbeddingPoolingUtils {

    // Utility holder; never instantiated
    private TextEmbeddingPoolingUtils() {}

    /**
     * Extracts the embedding of the last non-padding token. Used for decoder-only models
     * where the final token captures cumulative context through causal attention.
     *
     * @param embeddings the token embeddings (sequence_length x hidden_size)
     * @param attentionMask the attention mask (sequence_length), 1 for real tokens, 0 for padding
     * @return the embedding at the last non-padding token position
     */
    static NDArray lastTokenPool(NDArray embeddings, NDArray attentionMask) {
        // Count real tokens by summing the 0/1 mask
        long realTokenCount = attentionMask.sum().toLongArray()[0];
        // All-padding input: degrade gracefully to the first position
        if (realTokenCount <= 0) {
            return embeddings.get(0);
        }
        return embeddings.get(realTokenCount - 1);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
Compatible with OpenSearch and OpenSearch Dashboards version 3.4.0

### Enhancements
* Add LAST_TOKEN pooling support for text embedding models ([#4709](https://github.com/opensearch-project/ml-commons/issues/4709))
* Declare credential and *.Authorization as sensitive param in create connector API ([#4308](https://github.com/opensearch-project/ml-commons/pull/4308))
* Pass resourceType instead of resourceIndex to resourceSharingClient ([#4333](https://github.com/opensearch-project/ml-commons/pull/4333))
* Allow higher maximum number of batch inference job tasks ([#4474](https://github.com/opensearch-project/ml-commons/pull/4474))

### Bug Fixes
* Add NONE pooling mode to support pre-pooled model outputs, fixing bug where MEAN pooling was applied by default ([#4708](https://github.com/opensearch-project/ml-commons/issues/4708))
* Fix agent type update ([#4341](https://github.com/opensearch-project/ml-commons/pull/4341))
* Handle edge case of empty values of tool configs ([#4479](https://github.com/opensearch-project/ml-commons/pull/4479))
* Fix OpenAI RAG integration tests: Replace Wikimedia image URL with Unsplash ([#4472](https://github.com/opensearch-project/ml-commons/pull/4472))
Expand Down
Loading