-
Notifications
You must be signed in to change notification settings - Fork 198
Store vector L2 norm to restore original vectors for Faiss cosine with derived source #3216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
bc28c2d
9bbf8d2
8753743
bd0c4fa
8de56cc
350ce6a
1270ba5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
| import lombok.AccessLevel; | ||
| import lombok.NoArgsConstructor; | ||
|
|
||
| import java.util.Arrays; | ||
| import java.util.List; | ||
| import java.util.Objects; | ||
|
|
||
|
|
@@ -60,4 +61,34 @@ public static int[] intListToArray(final List<Integer> integerList) { | |
| } | ||
| return intArray; | ||
| } | ||
|
|
||
| /** | ||
| * Get the doc values field name used to store the L2 norm for a given vector field. | ||
| * | ||
| * @param fieldName the vector field name | ||
| * @return the norm field name | ||
| */ | ||
| public static String getNormFieldName(String fieldName) { | ||
| return KNNConstants.NORM_FIELD_PREFIX + fieldName; | ||
| } | ||
|
|
||
| /** | ||
| * Denormalize a normalized vector by multiplying each element by the given L2 norm. | ||
| * | ||
| * @param vector the normalized vector | ||
| * @param norm the L2 norm to restore | ||
| * @param inplace if true, modifies the input array; if false, returns a new array | ||
| * @return the denormalized vector | ||
| */ | ||
| public static float[] denormalize(float[] vector, float norm, boolean inplace) { | ||
| Objects.requireNonNull(vector, "vector must not be null"); | ||
| if (norm <= 0 || !Float.isFinite(norm)) { | ||
| throw new IllegalArgumentException("norm must be a positive finite value, got: " + norm); | ||
| } | ||
| float[] result = inplace ? vector : Arrays.copyOf(vector, vector.length); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: add in checks here to prevent NPEs
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added. |
||
| for (int i = 0; i < result.length; i++) { | ||
| result[i] *= norm; | ||
| } | ||
| return result; | ||
|
Comment on lines
+83
to
+92
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please use SIMD here rather than a plain for loop; otherwise the latency will be high. |
||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,15 +19,23 @@ public class DerivedKnnFloatVectorField extends KnnFloatVectorField { | |
| @Getter | ||
| private final boolean isDerivedEnabled; | ||
|
|
||
| @Getter | ||
| private final float vectorNorm; | ||
|
|
||
| /** | ||
| * | ||
| * @param name Name of the field | ||
| * @param vector vector for the field | ||
| * @param isDerivedEnabled boolean to indicate if derived source is enabled | ||
| */ | ||
| public DerivedKnnFloatVectorField(String name, float[] vector, boolean isDerivedEnabled) { | ||
| this(name, vector, isDerivedEnabled, 1.0f); | ||
| } | ||
|
|
||
| public DerivedKnnFloatVectorField(String name, float[] vector, boolean isDerivedEnabled, float norm) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we maintain the original signature with a 1.0 value?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added original signature. |
||
| super(name, vector); | ||
| this.isDerivedEnabled = isDerivedEnabled; | ||
| this.vectorNorm = norm; | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -38,7 +46,12 @@ public DerivedKnnFloatVectorField(String name, float[] vector, boolean isDerived | |
| * @param isDerivedEnabled boolean to indicate if derived source is enabled | ||
| */ | ||
| public DerivedKnnFloatVectorField(String name, float[] vector, FieldType fieldType, boolean isDerivedEnabled) { | ||
| this(name, vector, fieldType, isDerivedEnabled, 1.0f); | ||
| } | ||
|
|
||
| public DerivedKnnFloatVectorField(String name, float[] vector, FieldType fieldType, boolean isDerivedEnabled, float norm) { | ||
| super(name, vector, fieldType); | ||
| this.isDerivedEnabled = isDerivedEnabled; | ||
| this.vectorNorm = norm; | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.knn.index.codec.derivedsource; | ||
|
|
||
| import org.apache.lucene.index.NumericDocValues; | ||
| import org.opensearch.common.CheckedSupplier; | ||
|
|
||
| import java.io.IOException; | ||
|
|
||
| /** | ||
| * Supplies the L2 norm for a given document. Used to denormalize vectors when reconstructing _source. | ||
| */ | ||
| @FunctionalInterface | ||
| public interface DerivedSourceNormSupplier { | ||
|
|
||
| /** | ||
| * A no-op supplier that always returns 1.0f (no denormalization). | ||
| */ | ||
| DerivedSourceNormSupplier UNIT = (docId) -> 1.0f; | ||
|
|
||
| /** | ||
| * Get the L2 norm for the given document. | ||
| * | ||
| * @param docId document ID to advance to | ||
| * @return L2 norm value. 1.0f means no denormalization needed. | ||
| * @throws IOException if an I/O error occurs | ||
| */ | ||
| float getNorm(int docId) throws IOException; | ||
|
|
||
| /** | ||
| * Create a DerivedSourceNormSupplier backed by NumericDocValues. | ||
| * | ||
| * @param supplier supplies a fresh NumericDocValues iterator on each call | ||
| * @return DerivedSourceNormSupplier that reads norm from doc values | ||
| */ | ||
| static DerivedSourceNormSupplier fromDocValues(CheckedSupplier<NumericDocValues, IOException> supplier) { | ||
| return (docId) -> { | ||
| NumericDocValues dv = supplier.get(); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Every call to getNorm(docId) creates a fresh NumericDocValues iterator. This is safe (no stale state) but potentially expensive. DocValuesProducer.getNumeric() typically returns a new iterator. In the read path, this is called once per document per field during _source reconstruction. For a GET or search with _source enabled, this is fine. But during a force merge that touches millions of docs, the overhead could add up. @navneet1v what do you think here?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. During merges, since shouldInject is set to false and doc values are not read, the overhead from this processing may be somewhat mitigated.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Why are we doing this? Even for search, this is just extra latency we are adding. Why are we not going with the behavior where, per search request/merge, we have just one norms producer? |
||
| if (!dv.advanceExact(docId)) { | ||
| return 1.0f; | ||
| } | ||
| return Float.intBitsToFloat((int) dv.longValue()); | ||
| }; | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,16 +15,24 @@ public class NestedPerFieldDerivedVectorTransformer extends AbstractPerFieldDeri | |
|
|
||
| private final FieldInfo childFieldInfo; | ||
| private final DerivedSourceReaders derivedSourceReaders; | ||
| private final DerivedSourceNormSupplier normSupplier; | ||
| private KNNVectorValues<?> vectorValues; | ||
| private int currentOffset; | ||
|
|
||
| /** | ||
| * | ||
| * @param childFieldInfo FieldInfo of the child field | ||
| * @param derivedSourceReaders Readers for access segment info | ||
| * @param normSupplier supplier for L2 norm values | ||
| */ | ||
| public NestedPerFieldDerivedVectorTransformer(FieldInfo childFieldInfo, DerivedSourceReaders derivedSourceReaders) { | ||
| public NestedPerFieldDerivedVectorTransformer( | ||
| FieldInfo childFieldInfo, | ||
| DerivedSourceReaders derivedSourceReaders, | ||
| DerivedSourceNormSupplier normSupplier | ||
| ) { | ||
| this.childFieldInfo = childFieldInfo; | ||
| this.derivedSourceReaders = derivedSourceReaders; | ||
| this.normSupplier = normSupplier; | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -34,7 +42,12 @@ public Object apply(Object object) { | |
| } | ||
|
|
||
| try { | ||
| Object vector = formatVector(childFieldInfo, vectorValues::getVector, vectorValues::conditionalCloneVector); | ||
| Object vector = formatVector( | ||
| childFieldInfo, | ||
| vectorValues::getVector, | ||
| vectorValues::conditionalCloneVector, | ||
| () -> normSupplier.getNorm(currentOffset) | ||
| ); | ||
| vectorValues.nextDoc(); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. currentOffset is set once in setCurrentDoc and never updated. For multi-valued nested fields, would all nested vectors use the first child's norm?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added offset incrementation. |
||
| return vector; | ||
| } catch (IOException e) { | ||
|
|
@@ -50,5 +63,6 @@ public void setCurrentDoc(int offset, int docId) throws IOException { | |
| derivedSourceReaders.getKnnVectorsReader() | ||
| ); | ||
| vectorValues.advance(offset); | ||
| currentOffset = offset; | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we already prevent user-defined mappings from colliding with this prefix? I think we should reserve it explicitly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the important callout. Unlike _source, this won't throw a duplicate mapping exception, so we need to fully prevent the collision. It seems like we'd need to add an explicit rejection on the OpenSearch core side — can you think of any other approaches?