Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add 32x support for SQ encoder on Faiss [#3193](https://github.com/opensearch-project/k-NN/pull/3193)
* Faiss SQ 1bit MOS changes [#3182](https://github.com/opensearch-project/k-NN/pull/3182)
* Support compression to 1 bit for Lucene's scalar quantizer [#3144](https://github.com/opensearch-project/k-NN/pull/3144)
* Enable cosine similarity to return original vectors rather than normalized vectors with Faiss Engine [#3083](https://github.com/opensearch-project/k-NN/issues/3083)

### Maintenance
* Improve unit tests by tightening asserts [#3112](https://github.com/opensearch-project/k-NN/pull/3112)
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,8 @@ public class KNNConstants {
public static final int BYTE_ALIGNMENT_MASK = 7; // Used for rounding up to nearest byte (Byte.SIZE - 1)
// Define here: https://github.com/opensearch-project/remote-vector-index-builder/blob/main/API.md#index-parameters
public static final int MIN_DOCS_FOR_REMOTE_INDEX_BUILD = 4;

// Prefix for the NumericDocValues field that stores the L2 norm of vectors before normalization.
// Used with derived source to restore original vectors when Faiss + cosinesimil is configured.
public static final String NORM_FIELD_PREFIX = "_knn_norm_";
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we already prevent user-defined mappings from colliding with this prefix? I think we should either reserve it explicitly

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the important callout. Unlike _source, this won't throw a duplicate mapping exception, so we need to fully prevent the collision. It seems like we'd need to add an explicit rejection on the OpenSearch core side — can you think of any other approaches?

}
31 changes: 31 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNVectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.AccessLevel;
import lombok.NoArgsConstructor;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -60,4 +61,34 @@ public static int[] intListToArray(final List<Integer> integerList) {
}
return intArray;
}

/**
* Get the doc values field name used to store the L2 norm for a given vector field.
*
* @param fieldName the vector field name
* @return the norm field name
*/
public static String getNormFieldName(String fieldName) {
return KNNConstants.NORM_FIELD_PREFIX + fieldName;
}

/**
* Denormalize a normalized vector by multiplying each element by the given L2 norm.
*
* @param vector the normalized vector
* @param norm the L2 norm to restore
* @param inplace if true, modifies the input array; if false, returns a new array
* @return the denormalized vector
*/
public static float[] denormalize(float[] vector, float norm, boolean inplace) {
Objects.requireNonNull(vector, "vector must not be null");
if (norm <= 0 || !Float.isFinite(norm)) {
throw new IllegalArgumentException("norm must be a positive finite value, got: " + norm);
}
float[] result = inplace ? vector : Arrays.copyOf(vector, vector.length);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add in checks here to prevent NPEs

public static float[] denormalize(float[] vector, float norm, boolean inplace) {
   Objects.requireNonNull(vector, "vector must not be null");
   if (norm <= 0 || !Float.isFinite(norm)) {
       throw new IllegalArgumentException("norm must be a positive finite value, got: " + norm);
   }
  ...
}

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added.

for (int i = 0; i < result.length; i++) {
result[i] *= norm;
}
return result;
Comment on lines +83 to +92
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use SIMD here rather than just plain for loop, otherwise the latency will be high.

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,23 @@ public class DerivedKnnFloatVectorField extends KnnFloatVectorField {
@Getter
private final boolean isDerivedEnabled;

@Getter
private final float vectorNorm;

/**
*
* @param name Name of the field
* @param vector vector for the field
* @param isDerivedEnabled boolean to indicate if derived source is enabled
*/
public DerivedKnnFloatVectorField(String name, float[] vector, boolean isDerivedEnabled) {
this(name, vector, isDerivedEnabled, 1.0f);
}

public DerivedKnnFloatVectorField(String name, float[] vector, boolean isDerivedEnabled, float norm) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we maintain the original signature with a 1.0 value?

public DerivedKnnFloatVectorField(String name, float[] vector, boolean isDerivedEnabled) {
    ...
    this.vectorNorm = 1.0f;
}

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added original signature.

super(name, vector);
this.isDerivedEnabled = isDerivedEnabled;
this.vectorNorm = norm;
}

/**
Expand All @@ -38,7 +46,12 @@ public DerivedKnnFloatVectorField(String name, float[] vector, boolean isDerived
* @param isDerivedEnabled boolean to indicate if derived source is enabled
*/
public DerivedKnnFloatVectorField(String name, float[] vector, FieldType fieldType, boolean isDerivedEnabled) {
this(name, vector, fieldType, isDerivedEnabled, 1.0f);
}

public DerivedKnnFloatVectorField(String name, float[] vector, FieldType fieldType, boolean isDerivedEnabled, float norm) {
super(name, vector, fieldType);
this.isDerivedEnabled = isDerivedEnabled;
this.vectorNorm = norm;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.CheckedSupplier;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.common.KNNVectorUtil;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;

Expand All @@ -28,14 +29,49 @@ protected Object formatVector(
FieldInfo fieldInfo,
CheckedSupplier<Object, IOException> vectorSupplier,
CheckedSupplier<Object, IOException> vectorCloneSupplier
) throws IOException {
return formatVector(fieldInfo, vectorSupplier, vectorCloneSupplier, () -> 1.0f);
}

/**
* Utility method for formatting the vector values based on the vector data type, with lazy denormalization.
* The norm is only read from doc values when the vector actually needs denormalization (Faiss engine),
* avoiding unnecessary I/O for Lucene engine vectors.
*
* @param fieldInfo fieldinfo for the vector field
* @param vectorSupplier supplies vector (without clone)
* @param vectorCloneSupplier supplies clone of vector.
* @param normSupplier lazily supplies the L2 norm. Only invoked when denormalization is needed.
* @return vector formatted based on the vector data type. Typically, this will be a float[] or int[].
* @throws IOException if unable to deserialize stored vector
*/
protected Object formatVector(
FieldInfo fieldInfo,
CheckedSupplier<Object, IOException> vectorSupplier,
CheckedSupplier<Object, IOException> vectorCloneSupplier,
CheckedSupplier<Float, IOException> normSupplier
) throws IOException {
Object vectorValue = vectorSupplier.get();
// If the vector value is a byte[], we must deserialize
// If the vector value is a byte[], we must deserialize.
// This is the Faiss/nmslib path where vectors are stored as BinaryDocValues.
if (vectorValue instanceof byte[]) {
BytesRef vectorBytesRef = new BytesRef((byte[]) vectorValue);
VectorDataType vectorDataType = FieldInfoExtractor.extractVectorDataType(fieldInfo);
return KNNVectorFieldMapperUtil.deserializeStoredVector(vectorBytesRef, vectorDataType);
Object deserialized = KNNVectorFieldMapperUtil.deserializeStoredVector(vectorBytesRef, vectorDataType);
if (deserialized instanceof float[] floatVector) {
denormalizeIfNeeded(floatVector, normSupplier);
}
return deserialized;
}
float[] vector = (float[]) vectorCloneSupplier.get();
denormalizeIfNeeded(vector, normSupplier);
return vector;
}

private void denormalizeIfNeeded(float[] vector, CheckedSupplier<Float, IOException> normSupplier) throws IOException {
float norm = normSupplier.get();
if (norm != 1.0f) {
KNNVectorUtil.denormalize(vector, norm, true);
}
return vectorCloneSupplier.get();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.index.mapper.SourceFieldMapper;
import org.opensearch.index.shard.IndexingOperationListener;
import org.opensearch.knn.common.KNNVectorUtil;
import org.opensearch.knn.index.DerivedKnnByteVectorField;
import org.opensearch.knn.index.DerivedKnnFloatVectorField;
import org.opensearch.knn.index.VectorDataType;
Expand Down Expand Up @@ -106,8 +107,12 @@ private Pair<Function<Map<String, Object>, Map<String, Object>>> createInjectTra
for (Iterator<IndexableField> it = document.iterator(); it.hasNext();) {
IndexableField indexableField = it.next();
if (indexableField instanceof DerivedKnnFloatVectorField knnVectorFieldType && knnVectorFieldType.isDerivedEnabled()) {
injectedVectors.computeIfAbsent(indexableField.name(), k -> new ArrayList<>())
.add(formatVector(VectorDataType.FLOAT, knnVectorFieldType.vectorValue()));
Object vector = formatVector(VectorDataType.FLOAT, knnVectorFieldType.vectorValue());
float norm = knnVectorFieldType.getVectorNorm();
if (norm != 1.0f && vector instanceof float[] floatVector) {
vector = KNNVectorUtil.denormalize(floatVector, norm, false);
}
injectedVectors.computeIfAbsent(indexableField.name(), k -> new ArrayList<>()).add(vector);
}

if (indexableField instanceof DerivedKnnByteVectorField knnByteVectorField && knnByteVectorField.isDerivedEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.derivedsource;

import org.apache.lucene.index.NumericDocValues;
import org.opensearch.common.CheckedSupplier;

import java.io.IOException;

/**
* Supplies the L2 norm for a given document. Used to denormalize vectors when reconstructing _source.
*/
@FunctionalInterface
public interface DerivedSourceNormSupplier {

/**
* A no-op supplier that always returns 1.0f (no denormalization).
*/
DerivedSourceNormSupplier UNIT = (docId) -> 1.0f;

/**
* Get the L2 norm for the given document.
*
* @param docId document ID to advance to
* @return L2 norm value. 1.0f means no denormalization needed.
* @throws IOException if an I/O error occurs
*/
float getNorm(int docId) throws IOException;

/**
* Create a DerivedSourceNormSupplier backed by NumericDocValues.
*
* @param supplier supplies a fresh NumericDocValues iterator on each call
* @return DerivedSourceNormSupplier that reads norm from doc values
*/
static DerivedSourceNormSupplier fromDocValues(CheckedSupplier<NumericDocValues, IOException> supplier) {
return (docId) -> {
NumericDocValues dv = supplier.get();
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Every call to getNorm(docId) creates a fresh NumericDocValues iterator. This is safe (no stale state) but potentially expensive. DocValuesProducer.getNumeric() typically returns a new iterator. In the read path, this is called once per document per field during _source reconstruction. For a GET or search with _source enabled, this is fine. But during a force merge that touches millions of docs, the overhead could add up. @navneet1v what do u think here?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

During merges, since shouldInject is set to false and doc values are not read, the overhead from this processing may be somewhat mitigated.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Every call to getNorm(docId) creates a fresh NumericDocValues iterator.

Why we are doing this? Even for search this is just extra latency we are adding. Why we are not going with the behavior where per search request/merge we just have 1 Norms producer.

if (!dv.advanceExact(docId)) {
return 1.0f;
}
return Float.intBitsToFloat((int) dv.longValue());
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,5 @@ public static void addDerivedVectorFieldsSegmentInfoAttribute(SegmentInfo segmen
String fieldName = isNested ? NESTED_DERIVED_SOURCE_FIELD : DERIVED_SOURCE_FIELD;
segmentInfo.putAttribute(fieldName, String.join(DELIMETER, fields));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ public DerivedSourceVectorTransformer(
PerFieldDerivedVectorTransformer perFieldDerivedVectorTransformer = PerFieldDerivedVectorTransformerFactory.create(
derivedFieldInfo.fieldInfo(),
derivedFieldInfo.isNested(),
derivedSourceReaders
derivedSourceReaders,
segmentReadState.fieldInfos
);
perFieldDerivedVectorTransformers.put(derivedFieldInfo.name(), perFieldDerivedVectorTransformer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,24 @@ public class NestedPerFieldDerivedVectorTransformer extends AbstractPerFieldDeri

private final FieldInfo childFieldInfo;
private final DerivedSourceReaders derivedSourceReaders;
private final DerivedSourceNormSupplier normSupplier;
private KNNVectorValues<?> vectorValues;
private int currentOffset;

/**
*
* @param childFieldInfo FieldInfo of the child field
* @param derivedSourceReaders Readers for access segment info
* @param normSupplier supplier for L2 norm values
*/
public NestedPerFieldDerivedVectorTransformer(FieldInfo childFieldInfo, DerivedSourceReaders derivedSourceReaders) {
public NestedPerFieldDerivedVectorTransformer(
FieldInfo childFieldInfo,
DerivedSourceReaders derivedSourceReaders,
DerivedSourceNormSupplier normSupplier
) {
this.childFieldInfo = childFieldInfo;
this.derivedSourceReaders = derivedSourceReaders;
this.normSupplier = normSupplier;
}

@Override
Expand All @@ -34,7 +42,12 @@ public Object apply(Object object) {
}

try {
Object vector = formatVector(childFieldInfo, vectorValues::getVector, vectorValues::conditionalCloneVector);
Object vector = formatVector(
childFieldInfo,
vectorValues::getVector,
vectorValues::conditionalCloneVector,
() -> normSupplier.getNorm(currentOffset)
);
vectorValues.nextDoc();
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currentOffset is set once in setCurrentDoc and never updated. for multi-valued nested fields, all nested vectors would use the first child's norm?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added offset incrementation.

return vector;
} catch (IOException e) {
Expand All @@ -50,5 +63,6 @@ public void setCurrentDoc(int offset, int docId) throws IOException {
derivedSourceReaders.getKnnVectorsReader()
);
vectorValues.advance(offset);
currentOffset = offset;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,34 @@
package org.opensearch.knn.index.codec.derivedsource;

import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.opensearch.knn.common.KNNVectorUtil;

public class PerFieldDerivedVectorTransformerFactory {

/**
* Create a {@link PerFieldDerivedVectorTransformer} instance based on information in field info.
*
* @param fieldInfo FieldInfo for the field to create the injector for
* @param isNested whether the field is nested
* @param derivedSourceReaders {@link DerivedSourceReaders} instance
* @return PerFieldDerivedVectorInjector instance
* @param fieldInfos FieldInfos to look up the norm field
* @return PerFieldDerivedVectorTransformer instance
*/
public static PerFieldDerivedVectorTransformer create(
FieldInfo fieldInfo,
boolean isNested,
DerivedSourceReaders derivedSourceReaders
DerivedSourceReaders derivedSourceReaders,
FieldInfos fieldInfos
) {
// Nested case
FieldInfo normFieldInfo = fieldInfos.fieldInfo(KNNVectorUtil.getNormFieldName(fieldInfo.name));
DerivedSourceNormSupplier normSupplier = normFieldInfo != null
? DerivedSourceNormSupplier.fromDocValues(() -> derivedSourceReaders.getDocValuesProducer().getNumeric(normFieldInfo))
: DerivedSourceNormSupplier.UNIT;

if (isNested) {
return new NestedPerFieldDerivedVectorTransformer(fieldInfo, derivedSourceReaders);
return new NestedPerFieldDerivedVectorTransformer(fieldInfo, derivedSourceReaders, normSupplier);
}

// Non-nested case
return new RootPerFieldDerivedVectorTransformer(fieldInfo, derivedSourceReaders);
return new RootPerFieldDerivedVectorTransformer(fieldInfo, derivedSourceReaders, normSupplier);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,32 @@ public class RootPerFieldDerivedVectorTransformer extends AbstractPerFieldDerive

private final FieldInfo fieldInfo;
private final CheckedSupplier<KNNVectorValues<?>, IOException> vectorValuesSupplier;
private final DerivedSourceNormSupplier normSupplier;
private KNNVectorValues<?> vectorValues;
private int currentDocId;

/**
* Constructor for RootPerFieldDerivedVectorTransformer.
*
* @param fieldInfo FieldInfo for the field to create the injector for
* @param derivedSourceReaders {@link DerivedSourceReaders} instance
* @param normSupplier supplier for L2 norm values
*/
public RootPerFieldDerivedVectorTransformer(FieldInfo fieldInfo, DerivedSourceReaders derivedSourceReaders) {
public RootPerFieldDerivedVectorTransformer(FieldInfo fieldInfo, DerivedSourceReaders derivedSourceReaders, DerivedSourceNormSupplier normSupplier) {
this.fieldInfo = fieldInfo;
this.vectorValuesSupplier = () -> KNNVectorValuesFactory.getVectorValues(
fieldInfo,
derivedSourceReaders.getDocValuesProducer(),
derivedSourceReaders.getKnnVectorsReader()
);
this.normSupplier = normSupplier;
}

@Override
public void setCurrentDoc(int offset, int docId) throws IOException {
vectorValues = vectorValuesSupplier.get();
vectorValues.advance(docId);
currentDocId = docId;
}

@Override
Expand All @@ -46,7 +51,12 @@ public Object apply(Object object) {
}

try {
return formatVector(fieldInfo, vectorValues::getVector, vectorValues::conditionalCloneVector);
return formatVector(
fieldInfo,
vectorValues::getVector,
vectorValues::conditionalCloneVector,
() -> normSupplier.getNorm(currentDocId)
);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Loading
Loading