Merged
17 changes: 14 additions & 3 deletions docs/model_serving_framework/text_embedding_model_examples.md
@@ -293,9 +293,9 @@ POST /_plugins/_ml/models/zwla5YUB1qmVrJFlwzXJ/_unload

## 1.2 trace huggingface transformers model
Without [`sentence-transformers`](https://pypi.org/project/sentence-transformers/) installed, you can trace this model `AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')`.
But a model traced this way doesn't include post-processing, so the user has to specify the post-processing logic with `pooling_mode` and `normalize_result`.

Supported pooling methods: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`.
Supported pooling methods: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`, `lasttoken`.

The only difference is the model upload input; for load/predict/profile/unload, refer to ["1.1 trace sentence transformers model"](#11-trace-sentence-transformers-model).

@@ -322,7 +322,18 @@ POST /_plugins/_ml/models/_upload
Users can export a PyTorch model to ONNX, then upload and run it with the ml-commons APIs.
This example ONNX model also needs its post-processing logic specified with `pooling_mode` and `normalize_result`.

Supported pooling methods: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`.
Supported pooling methods: `mean`, `mean_sqrt_len`, `max`, `weightedmean`, `cls`, `lasttoken`.

### Pooling Methods

| Method | Description |
|--------|-------------|
| `mean` | Averages all token embeddings weighted by attention mask |
| `mean_sqrt_len` | Mean pooling divided by square root of sequence length |
| `max` | Takes maximum value across all token positions |
| `weightedmean` | Weighted average where later tokens have higher weights |
| `cls` | Uses the first token (CLS token) embedding |
| `lasttoken` | Uses the last non-padding token's embedding. Useful for decoder-only models where the final token captures cumulative context |
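To make the differences concrete, here is a rough pure-Python sketch of these pooling methods over one sequence's token embeddings and its 0/1 attention mask. The `pool` helper is hypothetical and stands in for the translator's actual NDArray math.

```python
import math

def pool(embeddings, attention_mask, mode):
    """Illustrative pooling over a list of token vectors (one sequence)."""
    dim = len(embeddings[0])
    # keep only real (non-padding) tokens, assuming right padding
    real = [e for e, m in zip(embeddings, attention_mask) if m == 1]
    n = max(len(real), 1)
    if mode == "mean":
        return [sum(v[i] for v in real) / n for i in range(dim)]
    if mode == "mean_sqrt_len":
        return [sum(v[i] for v in real) / math.sqrt(n) for i in range(dim)]
    if mode == "max":
        return [max(v[i] for v in real) for i in range(dim)]
    if mode == "weightedmean":
        # weight token t by (t + 1), so later tokens count more
        weights = [t + 1 for t in range(len(real))]
        wsum = sum(weights)
        return [sum(w * v[i] for w, v in zip(weights, real)) / wsum
                for i in range(dim)]
    if mode == "cls":
        return embeddings[0]           # first token's embedding
    if mode == "lasttoken":
        return real[-1] if real else embeddings[0]  # last real token
    raise ValueError(f"Invalid pooling mode: {mode}")
```

For example, with tokens `[[1.0, 0.0], [3.0, 2.0], [0.0, 0.0]]` and mask `[1, 1, 0]`, `mean` averages the first two vectors while `lasttoken` returns the second one.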

The only difference is the model upload input; for load/predict/profile/unload, refer to ["1.1 trace sentence transformers model"](#11-trace-sentence-transformers-model).
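For orientation, an upload request for this case might look roughly like the following, in the same style as the doc's other examples. The name, dimension, and other field values are placeholders, not taken from the PR; the URL is deliberately left elided.

```json
POST /_plugins/_ml/models/_upload
{
  "name": "sentence-transformers/all-MiniLM-L6-v2",
  "version": "1.0.0",
  "model_format": "ONNX",
  "model_config": {
    "model_type": "bert",
    "embedding_dimension": 384,
    "framework_type": "huggingface_transformers",
    "pooling_mode": "lasttoken",
    "normalize_result": true
  },
  "url": "..."
}
```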

@@ -102,6 +102,9 @@ public float[] processOutput(TranslatorContext ctx, NDList list) {
case "cls":
embeddings = embeddings.get(0);
break;
case "lasttoken":
embeddings = lastTokenPool(embeddings, inputAttentionMask);
break;
default:
throw new AssertionError("Unexpected pooling model: " + pooling);
}
@@ -146,6 +149,18 @@ private NDArray weightedMeanPool(NDArray embeddings, NDArray inputAttentionMask) {
return embeddingSum.div(maskSum);
}

private NDArray lastTokenPool(NDArray embeddings, NDArray inputAttentionMask) {
**Collaborator:** can this method be reused rather than defined twice?

**Contributor Author:** Addressed in #4744

// Sum attention mask to get count of real tokens
long tokenCount = (long) inputAttentionMask.sum().toFloatArray()[0];
// Last token index (0-based)
long lastTokenIdx = tokenCount - 1;
// Handle edge case
if (lastTokenIdx < 0) {
lastTokenIdx = 0;
}
return embeddings.get(lastTokenIdx);
}
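The method above handles a single sequence: the attention-mask sum gives the count of real tokens, so `count - 1` is the last real token's index (clamped to 0 for an all-padding edge case). For right-padded batches the same idea yields one index per row; here is a pure-Python sketch with a hypothetical helper, not code from the diff.

```python
def last_token_pool_batch(embeddings, attention_mask):
    """Pick each right-padded sequence's last real-token embedding.

    embeddings:     batch of [seq_len][dim] token vectors
    attention_mask: batch of [seq_len] 0/1 masks (1 = real token)
    """
    pooled = []
    for seq, mask in zip(embeddings, attention_mask):
        token_count = sum(mask)             # real tokens in this row
        last_idx = max(token_count - 1, 0)  # clamp, mirroring the edge case above
        pooled.append(seq[last_idx])
    return pooled
```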

/**
* Creates a builder to build a {@code TextEmbeddingTranslator}.
*
@@ -216,10 +231,10 @@ public HuggingfaceTextEmbeddingTranslator.Builder optPoolingMode(String poolingMode) {
&& !"max".equals(poolingMode)
&& !"cls".equals(poolingMode)
&& !"mean_sqrt_len".equals(poolingMode)
&& !"weightedmean".equals(poolingMode)) {
&& !"weightedmean".equals(poolingMode)
&& !"lasttoken".equals(poolingMode)) {
throw new IllegalArgumentException(
"Invalid pooling model, must be one of [mean_tokens, max_tokens,"
+ " cls_token, mean_sqrt_len_tokens, weightedmean_tokens]."
"Invalid pooling mode, must be one of [mean, max, cls, mean_sqrt_len, weightedmean, lasttoken]."
);
}
this.pooling = poolingMode;
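The chain of negated string comparisons above could equally be written as a set lookup, which keeps the allowed values and the error message in one place. A hypothetical sketch of the same validation:

```python
VALID_POOLING_MODES = {"mean", "max", "cls", "mean_sqrt_len", "weightedmean", "lasttoken"}

def validate_pooling_mode(mode):
    """Reject anything outside the supported pooling modes."""
    if mode not in VALID_POOLING_MODES:
        raise ValueError(
            f"Invalid pooling mode, must be one of {sorted(VALID_POOLING_MODES)}."
        )
    return mode
```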
@@ -112,6 +112,9 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
case CLS:
embeddings = embeddings.get(0);
break;
case LAST_TOKEN:
embeddings = lastTokenPool(embeddings, inputAttentionMask);
break;
default:
throw new IllegalArgumentException("Unsupported pooling method");
}
@@ -172,6 +175,18 @@ private NDArray weightedMeanPool(NDArray embeddings, NDArray attentionMask) {
return embeddingSum.div(maskSum);
}

private NDArray lastTokenPool(NDArray embeddings, NDArray attentionMask) {
// Sum attention mask to get count of real tokens
long tokenCount = attentionMask.sum().toLongArray()[0];
// Last token index (0-based)
long lastTokenIdx = tokenCount - 1;
// Handle edge case
if (lastTokenIdx < 0) {
lastTokenIdx = 0;
}
return embeddings.get(lastTokenIdx);
}

@Override
public void setArguments(Map<String, ?> arguments) {}
}
@@ -214,6 +214,28 @@ public void initModel_predict_ONNX_albert() throws URISyntaxException {
initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, 768);
}

@Test
public void initModel_predict_ONNX_LastTokenPooling() throws URISyntaxException {
String modelFile = "all-MiniLM-L6-v2_onnx.zip";
String modelType = "bert";
TextEmbeddingModelConfig.PoolingMode poolingMode = TextEmbeddingModelConfig.PoolingMode.LAST_TOKEN;
boolean normalize = true;
int modelMaxLength = 512;
MLModelFormat modelFormat = MLModelFormat.ONNX;
initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, dimension);
}

@Test
public void initModel_predict_TorchScript_Huggingface_LastTokenPooling() throws URISyntaxException {
String modelFile = "all-MiniLM-L6-v2_torchscript_huggingface.zip";
String modelType = "bert";
TextEmbeddingModelConfig.PoolingMode poolingMode = TextEmbeddingModelConfig.PoolingMode.LAST_TOKEN;
boolean normalize = true;
int modelMaxLength = 512;
MLModelFormat modelFormat = MLModelFormat.TORCH_SCRIPT;
initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMode, normalize, modelMaxLength, modelFormat, dimension);
}

private void initModel_predict_HuggingfaceModel(
String modelFile,
String modelType,
@@ -3,6 +3,7 @@
Compatible with OpenSearch and OpenSearch Dashboards version 3.4.0

### Enhancements
* Add LAST_TOKEN pooling support for text embedding models ([#4709](https://github.com/opensearch-project/ml-commons/issues/4709))
**Collaborator:** You can skip adding this. Let's remove this change, there is an auto generated release notes!

* Declare credential and *.Authorization as sensitive param in create connector API ([#4308](https://github.com/opensearch-project/ml-commons/pull/4308))
* Pass resourceType instead of resourceIndex to resourceSharingClient ([#4333](https://github.com/opensearch-project/ml-commons/pull/4333))
* allow higher maximum number of batch inference job tasks ([#4474](https://github.com/opensearch-project/ml-commons/pull/4474))
Expand Down