diff --git a/docs/changelog/83550.yaml b/docs/changelog/83550.yaml
new file mode 100644
index 0000000000000..51ab72f642fe6
--- /dev/null
+++ b/docs/changelog/83550.yaml
@@ -0,0 +1,5 @@
+pr: 83550
+summary: "Script: Fields API for Dense Vector"
+area: Infra/Scripting
+type: enhancement
+issues: []
diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/60_knn_and_binary_dv_fields_api.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/60_knn_and_binary_dv_fields_api.yml
new file mode 100644
index 0000000000000..b583a25738215
--- /dev/null
+++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/60_knn_and_binary_dv_fields_api.yml
@@ -0,0 +1,858 @@
+---
+# Scores each doc by vector[2] * size(); the doc without a vector returns size() and is expected to score 0.
+"size and isEmpty code works for any vector, including empty":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { match_all: {} }
+              script:
+                source: |
+                  def dv = field(params.field).get();
+                  if (dv.isEmpty()) {
+                    return dv.size();
+                  }
+                  return dv.vector[2] * dv.size()
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 3 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 2 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 1 }
+  - match: { hits.hits.3._id: "missing_vector" }
+  - match: { hits.hits.3._score: 0 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { match_all: {} }
+              script:
+                source: |
+                  def dv = field(params.field).get();
+                  if (dv.isEmpty()) {
+                    return dv.size();
+                  }
+                  return dv.vector[2] * dv.size()
+                params:
+                  field: knn
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 3 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 2 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 1 }
+  - match: { hits.hits.3._id: "missing_vector" }
+  - match: { hits.hits.3._score: 0 }
+
+---
+# get(null) yields null for the doc without a vector, which the script maps to a score of 1.
+"null can be used for default value":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { match_all: {} }
+              script:
+                source: |
+                  DenseVector dv = field(params.field).get(null);
+                  if (dv == null) {
+                    return 1;
+                  }
+                  return dv.vector[2];
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 3 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 2 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 1 }
+  - match: { hits.hits.3._id: "missing_vector" }
+  - match: { hits.hits.3._score: 1 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { match_all: {} }
+              script:
+                source: |
+                  DenseVector dv = field(params.field).get(null);
+                  if (dv == null) {
+                    return 1;
+                  }
+                  return dv.vector[2];
+                params:
+                  field: knn
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 3 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 2 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 1 }
+  - match: { hits.hits.3._id: "missing_vector" }
+  - match: { hits.hits.3._score: 1 }
+
+---
+# Reading a missing vector without a guard is expected to fail the search with bad_request.
+# In the try/catch scripts every accessor that throws adds its own fixed amount to the score.
+"empty dense vector throws for vector accesses":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      catch: bad_request
+      search:
+        body:
+          query:
+            script_score:
+              query: { "bool": { "must_not": { "exists": { "field": "bdv" } } } }
+              script:
+                source: |
+                  field(params.field).get().vector[2]
+                params:
+                  field: bdv
+
+  - match: { error.failed_shards.0.reason.caused_by.type: "illegal_argument_exception" }
+  - match: { error.failed_shards.0.reason.caused_by.reason: "Dense vector value missing for a field, use isEmpty() to check for a missing vector value" }
+
+  - do:
+      catch: bad_request
+      search:
+        body:
+          query:
+            script_score:
+              query: { "bool": { "must_not": { "exists": { "field": "bdv" } } } }
+              script:
+                source: |
+                  field(params.field).get().vector[2]
+                params:
+                  field: knn
+
+  - match: { error.failed_shards.0.reason.caused_by.type: "illegal_argument_exception" }
+  - match: { error.failed_shards.0.reason.caused_by.reason: "Dense vector value missing for a field, use isEmpty() to check for a missing vector value" }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "bool": { "must_not": { "exists": { "field": "bdv" } } } }
+              script:
+                source: |
+                  float[] q = new float[1];
+                  q[0] = 3;
+                  DenseVector dv = field(params.field).get();
+                  float score = 0;
+                  try { score += dv.magnitude } catch (IllegalArgumentException e) { score += 10; }
+                  try { score += dv.dotProduct(q) } catch (IllegalArgumentException e) { score += 200; }
+                  try { score += dv.l1Norm(q) } catch (IllegalArgumentException e) { score += 3000; }
+                  try { score += dv.l2Norm(q) } catch (IllegalArgumentException e) { score += 40000; }
+                  try { score += dv.cosineSimilarity(q) } catch (IllegalArgumentException e) { score += 200000; }
+                  try { score += dv.vector[0] } catch (IllegalArgumentException e) { score += 500000; }
+                  try { score += dv.dims } catch (IllegalArgumentException e) { score += 6000000; }
+                  return score;
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "missing_vector" }
+  - match: { hits.hits.0._score: 6743210 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "bool": { "must_not": { "exists": { "field": "bdv" } } } }
+              script:
+                source: |
+                  float[] q = new float[1];
+                  q[0] = 3;
+                  DenseVector dv = field(params.field).get();
+                  float score = 0;
+                  try { score += dv.magnitude } catch (IllegalArgumentException e) { score += 10; }
+                  try { score += dv.dotProduct(q) } catch (IllegalArgumentException e) { score += 200; }
+                  try { score += dv.l1Norm(q) } catch (IllegalArgumentException e) { score += 3000; }
+                  try { score += dv.l2Norm(q) } catch (IllegalArgumentException e) { score += 40000; }
+                  try { score += dv.cosineSimilarity(q) } catch (IllegalArgumentException e) { score += 200000; }
+                  try { score += dv.vector[0] } catch (IllegalArgumentException e) { score += 500000; }
+                  try { score += dv.dims } catch (IllegalArgumentException e) { score += 6000000; }
+                  return score;
+                params:
+                  field: knn
+
+  - match: { hits.hits.0._id: "missing_vector" }
+  - match: { hits.hits.0._score: 6743210 }
+
+---
+# dotProduct is called with the query vector given as a params list and as a float[] built in the script.
+"dot product works on dense vectors":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  field(params.field).get().dotProduct(params.query)
+                params:
+                  query: [4, 5, 6]
+                  field: bdv
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 27 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 21 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 15 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  field(params.field).get().dotProduct(query)
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 27 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 21 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 15 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  field(params.field).get().dotProduct(params.query)
+                params:
+                  query: [4, 5, 6]
+                  field: knn
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 27 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 21 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 15 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  field(params.field).get().dotProduct(query)
+                params:
+                  field: knn
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 27 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 21 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 15 }
+
+---
+# A for-each loop over a dense_vector field is expected to fail with unsupported_operation_exception.
+"iterator over dense vector values":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      catch: bad_request
+      search:
+        body:
+          query:
+            script_score:
+              query: { match_all: {} }
+              script:
+                source: |
+                  float sum = 0.0f;
+                  for (def v : field(params.field)) {
+                    sum += v;
+                  }
+                  return sum;
+                params:
+                  field: bdv
+
+  - match: { error.failed_shards.0.reason.caused_by.type: "unsupported_operation_exception" }
+  - match: { error.failed_shards.0.reason.caused_by.reason: "Cannot iterate over single valued dense_vector field, use get() instead" }
+
+  - do:
+      catch: bad_request
+      search:
+        body:
+          query:
+            script_score:
+              query: { match_all: {} }
+              script:
+                source: |
+                  float sum = 0.0f;
+                  for (def v : field(params.field)) {
+                    sum += v;
+                  }
+                  return sum;
+                params:
+                  field: knn
+
+  - match: { error.failed_shards.0.reason.caused_by.type: "unsupported_operation_exception" }
+  - match: { error.failed_shards.0.reason.caused_by.reason: "Cannot iterate over single valued dense_vector field, use get() instead"}
+
+---
+# l1Norm is called with the query vector given as a params list and as a float[] built in the script.
+"l1Norm works on dense vectors":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  field(params.field).get().l1Norm(params.query)
+                params:
+                  query: [4, 5, 6]
+                  field: bdv
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 12 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  field(params.field).get().l1Norm(query)
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 12 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  field(params.field).get().l1Norm(params.query)
+                params:
+                  query: [4, 5, 6]
+                  field: knn
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 12 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  field(params.field).get().l1Norm(query)
+                params:
+                  field: knn
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 12 }
+
+---
+# The scripts cast the l2Norm result to int before it is matched against the expected scores.
+"l2Norm works on dense vectors":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  (int) field(params.field).get().l2Norm(params.query)
+                params:
+                  query: [4, 5, 6]
+                  field: bdv
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 7 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 6 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 5 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  (int) field(params.field).get().l2Norm(query)
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 7 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 6 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 5 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  (int) field(params.field).get().l2Norm(params.query)
+                params:
+                  query: [4, 5, 6]
+                  field: knn
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 7 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 6 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 5 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  (int) field(params.field).get().l2Norm(query)
+                params:
+                  field: knn
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 7 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 6 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 5 }
+
+---
+# The scripts multiply cosineSimilarity by 100 and cast to int before the score is matched.
+"cosineSimilarity works on dense vectors":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  (int) (field(params.field).get().cosineSimilarity(query) * 100.0f)
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 98 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 97 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 92 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  (int) (field(params.field).get().cosineSimilarity(params.query) * 100.0f)
+                params:
+                  query: [4, 5, 6]
+                  field: knn
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 98 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 97 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 92 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": {
"field": "bdv" } }
+              script:
+                source: |
+                  (int) (field(params.field).get().cosineSimilarity(params.query) * 100.0f)
+                params:
+                  query: [4, 5, 6]
+                  field: bdv
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 98 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 97 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 92 }
+
+---
+"query vector of wrong type errors":
+  - skip:
+      version: " - 8.1.99"  # skip every version before 8.2, matching the reason below
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+
+  - do:
+      catch: bad_request
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  (int) field(params.field).get().l2Norm(params.query)
+                params:
+                  query: "one, two, three"
+                  field: bdv
+
+  - match: { error.failed_shards.0.reason.caused_by.type: "illegal_argument_exception" }
+  - match: { error.failed_shards.0.reason.caused_by.reason: "Cannot use vector [one, two, three] with class [java.lang.String] as query vector" }
+
+  - do:
+      catch: bad_request
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  (int) field(params.field).get().l2Norm(params.query)
+                params:
+                  query: "one, two, three"
+                  field: knn
+
+  - match: { error.failed_shards.0.reason.caused_by.type: "illegal_argument_exception" }
+  - match: { error.failed_shards.0.reason.caused_by.reason: "Cannot use vector [one, two, three] with class [java.lang.String] as query vector" }
diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVector.java
b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVector.java
new file mode 100644
index 0000000000000..785016bed097a
--- /dev/null
+++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVector.java
@@ -0,0 +1,151 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.vectors.query;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.Version;
+import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+
+/**
+ * {@link DenseVector} backed by the encoded bytes of one document's vector.
+ * The similarity functions read floats directly from those bytes; a {@code float[]}
+ * copy is only decoded when {@link #getVector()} is called.
+ */
+public class BinaryDenseVector implements DenseVector {
+    protected final BytesRef docVector;
+    protected final int dims;
+    protected final Version indexVersion;
+
+    /** Decoded copy of {@link #docVector}; stays {@code null} until {@link #getVector()} is first called. */
+    protected float[] decodedDocVector;
+
+    public BinaryDenseVector(BytesRef docVector, int dims, Version indexVersion) {
+        this.docVector = docVector;
+        this.indexVersion = indexVersion;
+        this.dims = dims;
+    }
+
+    /** Decodes the document vector on the first call and returns the same cached array afterwards. */
+    @Override
+    public float[] getVector() {
+        if (decodedDocVector == null) {
+            decodedDocVector = new float[dims];
+            VectorEncoderDecoder.decodeDenseVector(docVector, decodedDocVector);
+        }
+        return decodedDocVector;
+    }
+
+    @Override
+    public float getMagnitude() {
+        return VectorEncoderDecoder.getMagnitude(indexVersion, docVector);
+    }
+
+    @Override
+    public double dotProduct(float[] queryVector) {
+        ByteBuffer byteBuffer = wrap(docVector);
+
+        double dotProduct = 0;
+        for (float v : queryVector) {
+            dotProduct += byteBuffer.getFloat() * v;
+        }
+        return dotProduct;
+    }
+
+    @Override
+    public double dotProduct(List<Number> queryVector) {
+        ByteBuffer byteBuffer = wrap(docVector);
+
+        double dotProduct = 0;
+        for (Number number : queryVector) {
+            dotProduct += byteBuffer.getFloat() * number.floatValue();
+        }
+        return dotProduct;
+    }
+
+    @Override
+    public double l1Norm(float[] queryVector) {
+        ByteBuffer byteBuffer = wrap(docVector);
+
+        double l1norm = 0;
+        for (float v : queryVector) {
+            l1norm += Math.abs(v - byteBuffer.getFloat());
+        }
+        return l1norm;
+    }
+
+    @Override
+    public double l1Norm(List<Number> queryVector) {
+        ByteBuffer byteBuffer = wrap(docVector);
+
+        double l1norm = 0;
+        for (Number number : queryVector) {
+            l1norm += Math.abs(number.floatValue() - byteBuffer.getFloat());
+        }
+        return l1norm;
+    }
+
+    @Override
+    public double l2Norm(float[] queryVector) {
+        ByteBuffer byteBuffer = wrap(docVector);
+        double l2norm = 0;
+        for (float queryValue : queryVector) {
+            double diff = byteBuffer.getFloat() - queryValue;
+            l2norm += diff * diff;
+        }
+        return Math.sqrt(l2norm);
+    }
+
+    @Override
+    public double l2Norm(List<Number> queryVector) {
+        ByteBuffer byteBuffer = wrap(docVector);
+        double l2norm = 0;
+        for (Number number : queryVector) {
+            double diff = byteBuffer.getFloat() - number.floatValue();
+            l2norm += diff * diff;
+        }
+        return Math.sqrt(l2norm);
+    }
+
+    /** Divides the dot product by the document magnitude, and also by the query magnitude when {@code normalizeQueryVector} is set. */
+    @Override
+    public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
+        if (normalizeQueryVector) {
+            return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
+        }
+        return dotProduct(queryVector) / getMagnitude();
+    }
+
+    @Override
+    public double cosineSimilarity(List<Number> queryVector) {
+        return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
+    }
+
+    /** A binary dense vector always reports exactly one value. */
+    @Override
+    public int size() {
+        return 1;
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return false;
+    }
+
+    @Override
+    public int getDims() {
+        return dims;
+    }
+
+    /** Wraps only the {@code offset}/{@code length} window of the ref, not the whole backing array. */
+    private static ByteBuffer wrap(BytesRef dv) {
+        return ByteBuffer.wrap(dv.bytes, dv.offset, dv.length);
+    }
+}
diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorDocValuesField.java
b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorDocValuesField.java
new file mode 100644
index 0000000000000..ad1d016132547
--- /dev/null
+++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorDocValuesField.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.vectors.query;
+
+import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.Version;
+
+import java.io.IOException;
+
+/**
+ * Reads the current document's vector from {@link BinaryDocValues} and exposes it as a {@link DenseVector}.
+ */
+public class BinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
+
+    protected final BinaryDocValues input;
+    protected final Version indexVersion;
+    protected final int dims;
+    protected BytesRef value;
+
+    public BinaryDenseVectorDocValuesField(BinaryDocValues input, String name, int dims, Version indexVersion) {
+        super(name);
+        this.input = input;
+        this.indexVersion = indexVersion;
+        this.dims = dims;
+    }
+
+    @Override
+    public void setNextDocId(int docId) throws IOException {
+        // Hold the encoded bytes when the document has a value; null marks a missing vector.
+        value = input.advanceExact(docId) ? input.binaryValue() : null;
+    }
+
+    @Override
+    public DenseVectorScriptDocValues getScriptDocValues() {
+        return new DenseVectorScriptDocValues(this, dims);
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return value == null;
+    }
+
+    /** Returns the current document's vector, or {@link DenseVector#EMPTY} when it has none. */
+    @Override
+    public DenseVector get() {
+        return isEmpty() ? DenseVector.EMPTY : new BinaryDenseVector(value, dims, indexVersion);
+    }
+
+    /** Returns the current document's vector, or {@code defaultValue} when it has none. */
+    @Override
+    public DenseVector get(DenseVector defaultValue) {
+        return isEmpty() ? defaultValue : new BinaryDenseVector(value, dims, indexVersion);
+    }
+
+    @Override
+    public DenseVector getInternal() {
+        return get(null);
+    }
+}
diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorScriptDocValues.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorScriptDocValues.java
deleted file mode 100644
index 852b63500a9bf..0000000000000
--- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorScriptDocValues.java
+++ /dev/null
@@ -1,119 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.vectors.query;
-
-import org.apache.lucene.index.BinaryDocValues;
-import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.Version;
-import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
-public class BinaryDenseVectorScriptDocValues extends DenseVectorScriptDocValues {
-
-    public static class BinaryDenseVectorSupplier implements DenseVectorSupplier {
-
-        private final BinaryDocValues in;
-        private BytesRef value;
-
-        public BinaryDenseVectorSupplier(BinaryDocValues in) {
-            this.in = in;
-        }
-
-        @Override
-        public void setNextDocId(int docId) throws IOException {
-            if (in.advanceExact(docId)) {
-                value = in.binaryValue();
-            } else {
-                value = null;
-            }
-        }
-
-        @Override
-        public BytesRef getInternal(int index) {
-            throw new UnsupportedOperationException();
-        }
-
-        public BytesRef getInternal() {
-            return value;
-        }
-
-        @Override
-        public int size() {
-            if (value == null) {
-                return 0;
-            } else {
-                return 1;
-            }
-        }
-    }
-
-    private final BinaryDenseVectorSupplier bdvSupplier;
-    private final Version indexVersion;
-    private final float[] vector;
-
-    BinaryDenseVectorScriptDocValues(BinaryDenseVectorSupplier supplier, Version indexVersion, int dims)
{ - super(supplier, dims); - this.bdvSupplier = supplier; - this.indexVersion = indexVersion; - this.vector = new float[dims]; - } - - @Override - public int size() { - return supplier.size(); - } - - @Override - public float[] getVectorValue() { - VectorEncoderDecoder.decodeDenseVector(bdvSupplier.getInternal(), vector); - return vector; - } - - @Override - public float getMagnitude() { - return VectorEncoderDecoder.getMagnitude(indexVersion, bdvSupplier.getInternal()); - } - - @Override - public double dotProduct(float[] queryVector) { - BytesRef value = bdvSupplier.getInternal(); - ByteBuffer byteBuffer = ByteBuffer.wrap(value.bytes, value.offset, value.length); - - double dotProduct = 0; - for (float queryValue : queryVector) { - dotProduct += queryValue * byteBuffer.getFloat(); - } - return (float) dotProduct; - } - - @Override - public double l1Norm(float[] queryVector) { - BytesRef value = bdvSupplier.getInternal(); - ByteBuffer byteBuffer = ByteBuffer.wrap(value.bytes, value.offset, value.length); - - double l1norm = 0; - for (float queryValue : queryVector) { - l1norm += Math.abs(queryValue - byteBuffer.getFloat()); - } - return l1norm; - } - - @Override - public double l2Norm(float[] queryVector) { - BytesRef value = bdvSupplier.getInternal(); - ByteBuffer byteBuffer = ByteBuffer.wrap(value.bytes, value.offset, value.length); - double l2norm = 0; - for (float queryValue : queryVector) { - double diff = queryValue - byteBuffer.getFloat(); - l2norm += diff * diff; - } - return Math.sqrt(l2norm); - } -} diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVector.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVector.java new file mode 100644 index 0000000000000..4ffbccbd9e415 --- /dev/null +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVector.java @@ -0,0 +1,227 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.vectors.query; + +import java.util.List; + +/** + * DenseVector value type for the painless. + */ +/* dotProduct, l1Norm, l2Norm, cosineSimilarity have three flavors depending on the type of the queryVector + * 1) float[], this is for the ScoreScriptUtils class bindings which have converted a List based query vector into an array + * 2) List, A painless script will typically use Lists since they are easy to pass as params and have an easy + * literal syntax. Working with Lists directly, instead of converting to a float[], trades off runtime operations against + * memory pressure. Dense Vectors may have high dimensionality, up to 2048. Allocating a float[] per doc per script API + * call is prohibitively expensive. + * 3) Object, the whitelisted method for the painless API. Calls into the float[] or List version based on the + class of the argument and checks dimensionality. 
+ */ +public interface DenseVector { + float[] getVector(); + + float getMagnitude(); + + double dotProduct(float[] queryVector); + + double dotProduct(List queryVector); + + @SuppressWarnings("unchecked") + default double dotProduct(Object queryVector) { + if (queryVector instanceof float[] array) { + checkDimensions(getDims(), array.length); + return dotProduct(array); + + } else if (queryVector instanceof List list) { + checkDimensions(getDims(), list.size()); + return dotProduct((List) list); + } + + throw new IllegalArgumentException(badQueryVectorType(queryVector)); + } + + double l1Norm(float[] queryVector); + + double l1Norm(List queryVector); + + @SuppressWarnings("unchecked") + default double l1Norm(Object queryVector) { + if (queryVector instanceof float[] array) { + checkDimensions(getDims(), array.length); + return l1Norm(array); + + } else if (queryVector instanceof List list) { + checkDimensions(getDims(), list.size()); + return l1Norm((List) list); + } + + throw new IllegalArgumentException(badQueryVectorType(queryVector)); + } + + double l2Norm(float[] queryVector); + + double l2Norm(List queryVector); + + @SuppressWarnings("unchecked") + default double l2Norm(Object queryVector) { + if (queryVector instanceof float[] array) { + checkDimensions(getDims(), array.length); + return l2Norm(array); + + } else if (queryVector instanceof List list) { + checkDimensions(getDims(), list.size()); + return l2Norm((List) list); + } + + throw new IllegalArgumentException(badQueryVectorType(queryVector)); + } + + /** + * Get the cosine similarity with the un-normalized query vector + */ + default double cosineSimilarity(float[] queryVector) { + return cosineSimilarity(queryVector, true); + } + + /** + * Get the cosine similarity with the query vector + * @param normalizeQueryVector - normalize the query vector, does not change the contents of passed in query vector + */ + double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector); + + /** + * Get 
the cosine similarity with the un-normalized query vector + */ + double cosineSimilarity(List queryVector); + + /** + * Get the cosine similarity with the un-normalized query vector. Handles queryVectors of type float[] and List. + */ + @SuppressWarnings("unchecked") + default double cosineSimilarity(Object queryVector) { + if (queryVector instanceof float[] array) { + checkDimensions(getDims(), array.length); + return cosineSimilarity(array); + + } else if (queryVector instanceof List list) { + checkDimensions(getDims(), list.size()); + return cosineSimilarity((List) list); + } + + throw new IllegalArgumentException(badQueryVectorType(queryVector)); + } + + boolean isEmpty(); + + int getDims(); + + int size(); + + static float getMagnitude(float[] vector) { + double mag = 0.0f; + for (float elem : vector) { + mag += elem * elem; + } + return (float) Math.sqrt(mag); + } + + static float getMagnitude(List vector) { + double mag = 0.0f; + for (Number number : vector) { + float elem = number.floatValue(); + mag += elem * elem; + } + return (float) Math.sqrt(mag); + } + + static void checkDimensions(int dvDims, int qvDims) { + if (dvDims != qvDims) { + throw new IllegalArgumentException( + "The query vector has a different number of dimensions [" + qvDims + "] than the document vectors [" + dvDims + "]." 
+ ); + } + } + + private static String badQueryVectorType(Object queryVector) { + return "Cannot use vector [" + queryVector + "] with class [" + queryVector.getClass().getName() + "] as query vector"; + } + + DenseVector EMPTY = new DenseVector() { + public static final String MISSING_VECTOR_FIELD_MESSAGE = "Dense vector value missing for a field," + + " use isEmpty() to check for a missing vector value"; + + @Override + public float getMagnitude() { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public double dotProduct(float[] queryVector) { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public double dotProduct(List queryVector) { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public double l1Norm(List queryVector) { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public double l1Norm(float[] queryVector) { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public double l2Norm(List queryVector) { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public double l2Norm(float[] queryVector) { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public double cosineSimilarity(float[] queryVector) { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public double cosineSimilarity(List queryVector) { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public float[] getVector() { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public boolean isEmpty() { + return true; + } + + @Override + public int getDims() { + throw 
new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public int size() { + return 0; + } + }; +} diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorDocValuesField.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorDocValuesField.java new file mode 100644 index 0000000000000..dd4a00fef3af0 --- /dev/null +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorDocValuesField.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.vectors.query; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.script.field.DocValuesField; + +import java.util.Iterator; + +public abstract class DenseVectorDocValuesField implements DocValuesField, DenseVectorScriptDocValues.DenseVectorSupplier { + protected final String name; + + public DenseVectorDocValuesField(String name) { + this.name = name; + } + + @Override + public String getName() { + return name; + } + + @Override + public int size() { + return isEmpty() ? 0 : 1; + } + + @Override + public BytesRef getInternal(int index) { + throw new UnsupportedOperationException(); + } + + /** + * Get the DenseVector for a document if one exists, DenseVector.EMPTY otherwise + */ + public abstract DenseVector get(); + + public abstract DenseVector get(DenseVector defaultValue); + + public abstract DenseVectorScriptDocValues getScriptDocValues(); + + // DenseVector fields are single valued, so Iterable does not make sense. 
+ @Override + public Iterator iterator() { + throw new UnsupportedOperationException("Cannot iterate over single valued dense_vector field, use get() instead"); + } +} diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValues.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValues.java index 650ebca1d5ee5..43d04f5ccde7a 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValues.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValues.java @@ -10,24 +10,16 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.fielddata.ScriptDocValues; -public abstract class DenseVectorScriptDocValues extends ScriptDocValues { - - public interface DenseVectorSupplier extends Supplier { - - @Override - default BytesRef getInternal(int index) { - throw new UnsupportedOperationException(); - } - - T getInternal(); - } +public class DenseVectorScriptDocValues extends ScriptDocValues { public static final String MISSING_VECTOR_FIELD_MESSAGE = "A document doesn't have a value for a vector field!"; private final int dims; + protected final DenseVectorSupplier dvSupplier; - public DenseVectorScriptDocValues(DenseVectorSupplier supplier, int dims) { + public DenseVectorScriptDocValues(DenseVectorSupplier supplier, int dims) { super(supplier); + this.dvSupplier = supplier; this.dims = dims; } @@ -35,60 +27,58 @@ public int dims() { return dims; } + private DenseVector getCheckedVector() { + DenseVector vector = dvSupplier.getInternal(); + if (vector == null) { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + return vector; + } + /** * Get dense vector's value as an array of floats */ - public abstract float[] getVectorValue(); + public float[] getVectorValue() { + return getCheckedVector().getVector(); + } /** * Get dense vector's 
magnitude */ - public abstract float getMagnitude(); + public float getMagnitude() { + return getCheckedVector().getMagnitude(); + } - public abstract double dotProduct(float[] queryVector); + public double dotProduct(float[] queryVector) { + return getCheckedVector().dotProduct(queryVector); + } - public abstract double l1Norm(float[] queryVector); + public double l1Norm(float[] queryVector) { + return getCheckedVector().l1Norm(queryVector); + } - public abstract double l2Norm(float[] queryVector); + public double l2Norm(float[] queryVector) { + return getCheckedVector().l2Norm(queryVector); + } @Override public BytesRef get(int index) { throw new UnsupportedOperationException( - "accessing a vector field's value through 'get' or 'value' is not supported!" + "Use 'vectorValue' or 'magnitude' instead!'" + "accessing a vector field's value through 'get' or 'value' is not supported, use 'vectorValue' or 'magnitude' instead." ); } - public static DenseVectorScriptDocValues empty(DenseVectorSupplier supplier, int dims) { - return new DenseVectorScriptDocValues(supplier, dims) { - @Override - public float[] getVectorValue() { - throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); - } - - @Override - public float getMagnitude() { - throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); - } - - @Override - public double dotProduct(float[] queryVector) { - throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); - } - - @Override - public double l1Norm(float[] queryVector) { - throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); - } - - @Override - public double l2Norm(float[] queryVector) { - throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); - } - - @Override - public int size() { - return supplier.size(); - } - }; + @Override + public int size() { + return dvSupplier.getInternal() == null ? 
0 : 1; + } + + public interface DenseVectorSupplier extends Supplier { + @Override + default BytesRef getInternal(int index) { + throw new UnsupportedOperationException(); + } + + DenseVector getInternal(); } } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DocValuesWhitelistExtension.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DocValuesWhitelistExtension.java index c53d1379dc252..953044c3a5500 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DocValuesWhitelistExtension.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DocValuesWhitelistExtension.java @@ -19,7 +19,10 @@ public class DocValuesWhitelistExtension implements PainlessExtension { - private static final Whitelist WHITELIST = WhitelistLoader.loadFromResourceFiles(DocValuesWhitelistExtension.class, "whitelist.txt"); + private static final Whitelist WHITELIST = WhitelistLoader.loadFromResourceFiles( + DocValuesWhitelistExtension.class, + "org.elasticsearch.xpack.vectors.txt" + ); @Override public Map, List> getContextWhitelists() { diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVector.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVector.java new file mode 100644 index 0000000000000..1c240892ab2bd --- /dev/null +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVector.java @@ -0,0 +1,109 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.vectors.query; + +import org.apache.lucene.util.VectorUtil; + +import java.util.Arrays; +import java.util.List; + +public class KnnDenseVector implements DenseVector { + protected final float[] docVector; + + public KnnDenseVector(float[] docVector) { + this.docVector = docVector; + } + + @Override + public float[] getVector() { + // we need to copy the value, since {@link VectorValues} can reuse + // the underlying array across documents + return Arrays.copyOf(docVector, docVector.length); + } + + @Override + public float getMagnitude() { + return DenseVector.getMagnitude(docVector); + } + + @Override + public double dotProduct(float[] queryVector) { + return VectorUtil.dotProduct(docVector, queryVector); + } + + @Override + public double dotProduct(List queryVector) { + double dotProduct = 0; + for (int i = 0; i < docVector.length; i++) { + dotProduct += docVector[i] * queryVector.get(i).floatValue(); + } + return dotProduct; + } + + @Override + public double l1Norm(float[] queryVector) { + double result = 0.0; + for (int i = 0; i < docVector.length; i++) { + result += Math.abs(docVector[i] - queryVector[i]); + } + return result; + } + + @Override + public double l1Norm(List queryVector) { + double result = 0.0; + for (int i = 0; i < docVector.length; i++) { + result += Math.abs(docVector[i] - queryVector.get(i).floatValue()); + } + return result; + } + + @Override + public double l2Norm(float[] queryVector) { + return Math.sqrt(VectorUtil.squareDistance(docVector, queryVector)); + } + + @Override + public double l2Norm(List queryVector) { + double l2norm = 0; + for (int i = 0; i < docVector.length; i++) { + double diff = docVector[i] - queryVector.get(i).floatValue(); + l2norm += diff * diff; + } + return Math.sqrt(l2norm); + } + + @Override + public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) { + if (normalizeQueryVector) { + return dotProduct(queryVector) / 
(DenseVector.getMagnitude(queryVector) * getMagnitude()); + } + + return dotProduct(queryVector) / getMagnitude(); + } + + @Override + public double cosineSimilarity(List queryVector) { + return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude()); + } + + @Override + public boolean isEmpty() { + return false; + } + + @Override + public int getDims() { + return docVector.length; + } + + @Override + public int size() { + return 1; + } +} diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorDocValuesField.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorDocValuesField.java new file mode 100644 index 0000000000000..58b2e60a0fb80 --- /dev/null +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorDocValuesField.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.vectors.query; + +import org.apache.lucene.index.VectorValues; +import org.elasticsearch.core.Nullable; + +import java.io.IOException; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +public class KnnDenseVectorDocValuesField extends DenseVectorDocValuesField { + protected VectorValues input; // null if no vectors + protected float[] vector; + protected final int dims; + + public KnnDenseVectorDocValuesField(@Nullable VectorValues input, String name, int dims) { + super(name); + this.dims = dims; + this.input = input; + } + + @Override + public void setNextDocId(int docId) throws IOException { + if (input == null) { + return; + } + int currentDoc = input.docID(); + if (currentDoc == NO_MORE_DOCS || docId < currentDoc) { + vector = null; + } else if (docId == currentDoc) { + vector = input.vectorValue(); + } else { + currentDoc = input.advance(docId); + if (currentDoc == docId) { + vector = input.vectorValue(); + } else { + vector = null; + } + } + } + + @Override + public DenseVectorScriptDocValues getScriptDocValues() { + return new DenseVectorScriptDocValues(this, dims); + } + + public boolean isEmpty() { + return vector == null; + } + + @Override + public DenseVector get() { + if (isEmpty()) { + return DenseVector.EMPTY; + } + + return new KnnDenseVector(vector); + } + + @Override + public DenseVector get(DenseVector defaultValue) { + if (isEmpty()) { + return defaultValue; + } + + return new KnnDenseVector(vector); + } + + @Override + public DenseVector getInternal() { + return get(null); + } +} diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorScriptDocValues.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorScriptDocValues.java deleted file mode 100644 index fc6f1bdb59906..0000000000000 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorScriptDocValues.java 
+++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.vectors.query; - -import org.apache.lucene.index.VectorValues; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.VectorUtil; - -import java.io.IOException; -import java.util.Arrays; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - -public class KnnDenseVectorScriptDocValues extends DenseVectorScriptDocValues { - - public static class KnnDenseVectorSupplier implements DenseVectorSupplier { - - private final VectorValues in; - private float[] vector; - - public KnnDenseVectorSupplier(VectorValues in) { - this.in = in; - } - - @Override - public void setNextDocId(int docId) throws IOException { - int currentDoc = in.docID(); - if (currentDoc == NO_MORE_DOCS || docId < currentDoc) { - vector = null; - } else if (docId == currentDoc) { - vector = in.vectorValue(); - } else { - currentDoc = in.advance(docId); - if (currentDoc == docId) { - vector = in.vectorValue(); - } else { - vector = null; - } - } - } - - @Override - public BytesRef getInternal(int index) { - throw new UnsupportedOperationException(); - } - - public float[] getInternal() { - return vector; - } - - @Override - public int size() { - if (vector == null) { - return 0; - } else { - return 1; - } - } - } - - private final KnnDenseVectorSupplier kdvSupplier; - - KnnDenseVectorScriptDocValues(KnnDenseVectorSupplier supplier, int dims) { - super(supplier, dims); - this.kdvSupplier = supplier; - } - - private float[] getVectorChecked() { - if (kdvSupplier.getInternal() == null) { - throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); - } - return kdvSupplier.getInternal(); - } - - @Override - public float[] getVectorValue() { - 
float[] vector = getVectorChecked(); - // we need to copy the value, since {@link VectorValues} can reuse - // the underlying array across documents - return Arrays.copyOf(vector, vector.length); - } - - @Override - public float getMagnitude() { - float[] vector = getVectorChecked(); - double magnitude = 0.0f; - for (float elem : vector) { - magnitude += elem * elem; - } - return (float) Math.sqrt(magnitude); - } - - @Override - public double dotProduct(float[] queryVector) { - return VectorUtil.dotProduct(getVectorChecked(), queryVector); - } - - @Override - public double l1Norm(float[] queryVector) { - float[] vectorValue = getVectorChecked(); - double result = 0.0; - for (int i = 0; i < queryVector.length; i++) { - result += Math.abs(vectorValue[i] - queryVector[i]); - } - return result; - } - - @Override - public double l2Norm(float[] queryVector) { - return Math.sqrt(VectorUtil.squareDistance(getVectorValue(), queryVector)); - } - - @Override - public int size() { - return supplier.size(); - } -} diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java index e97daf4c2f397..24e74e4a93958 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java @@ -18,10 +18,10 @@ public class ScoreScriptUtils { public static class DenseVectorFunction { final ScoreScript scoreScript; final float[] queryVector; - final DenseVectorScriptDocValues docValues; + final DenseVectorDocValuesField field; - public DenseVectorFunction(ScoreScript scoreScript, List queryVector, String field) { - this(scoreScript, queryVector, field, false); + public DenseVectorFunction(ScoreScript scoreScript, List queryVector, String fieldName) { + this(scoreScript, queryVector, fieldName, false); } /** @@ 
-31,19 +31,10 @@ public DenseVectorFunction(ScoreScript scoreScript, List queryVector, St * @param queryVector The query vector. * @param normalizeQuery Whether the provided query should be normalized to unit length. */ - public DenseVectorFunction(ScoreScript scoreScript, List queryVector, String field, boolean normalizeQuery) { + public DenseVectorFunction(ScoreScript scoreScript, List queryVector, String fieldName, boolean normalizeQuery) { this.scoreScript = scoreScript; - this.docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(field); - - if (docValues.dims() != queryVector.size()) { - throw new IllegalArgumentException( - "The query vector has a different number of dimensions [" - + queryVector.size() - + "] than the document vectors [" - + docValues.dims() - + "]." - ); - } + this.field = (DenseVectorDocValuesField) scoreScript.field(fieldName); + DenseVector.checkDimensions(field.get().getDims(), queryVector.size()); this.queryVector = new float[queryVector.size()]; double queryMagnitude = 0.0; @@ -63,11 +54,11 @@ public DenseVectorFunction(ScoreScript scoreScript, List queryVector, St void setNextVector() { try { - docValues.getSupplier().setNextDocId(scoreScript._getDocId()); + field.setNextDocId(scoreScript._getDocId()); } catch (IOException e) { throw ExceptionsHelper.convertToElastic(e); } - if (docValues.size() == 0) { + if (field.isEmpty()) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } } @@ -82,7 +73,7 @@ public L1Norm(ScoreScript scoreScript, List queryVector, String field) { public double l1norm() { setNextVector(); - return docValues.l1Norm(queryVector); + return field.get().l1Norm(queryVector); } } @@ -95,7 +86,7 @@ public L2Norm(ScoreScript scoreScript, List queryVector, String field) { public double l2norm() { setNextVector(); - return docValues.l2Norm(queryVector); + return field.get().l2Norm(queryVector); } } @@ -108,7 +99,7 @@ public DotProduct(ScoreScript scoreScript, List 
queryVector, String fiel public double dotProduct() { setNextVector(); - return docValues.dotProduct(queryVector); + return field.get().dotProduct(queryVector); } } @@ -121,7 +112,8 @@ public CosineSimilarity(ScoreScript scoreScript, List queryVector, Strin public double cosineSimilarity() { setNextVector(); - return docValues.dotProduct(queryVector) / docValues.getMagnitude(); + // query vector normalized in constructor + return field.get().cosineSimilarity(queryVector, false); } } } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java index 1d8c45e9c60c2..a4789543ded43 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java @@ -15,18 +15,12 @@ import org.elasticsearch.Version; import org.elasticsearch.index.fielddata.LeafFieldData; import org.elasticsearch.index.fielddata.SortedBinaryDocValues; -import org.elasticsearch.script.field.DelegateDocValuesField; import org.elasticsearch.script.field.DocValuesField; -import org.elasticsearch.xpack.vectors.query.BinaryDenseVectorScriptDocValues.BinaryDenseVectorSupplier; -import org.elasticsearch.xpack.vectors.query.DenseVectorScriptDocValues.DenseVectorSupplier; -import org.elasticsearch.xpack.vectors.query.KnnDenseVectorScriptDocValues.KnnDenseVectorSupplier; import java.io.IOException; import java.util.Collection; import java.util.Collections; -import static org.elasticsearch.xpack.vectors.query.DenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE; - final class VectorDVLeafFieldData implements LeafFieldData { private final LeafReader reader; @@ -63,31 +57,15 @@ public DocValuesField getScriptField(String name) { try { if (indexed) { VectorValues values = reader.getVectorValues(field); - if 
(values == null || values == VectorValues.EMPTY) { - return new DelegateDocValuesField(DenseVectorScriptDocValues.empty(new DenseVectorSupplier() { - @Override - public float[] getInternal() { - throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); - } - - @Override - public void setNextDocId(int docId) throws IOException { - // do nothing - } - - @Override - public int size() { - return 0; - } - }, dims), name); + if (values == VectorValues.EMPTY) { + // There's no way for KnnDenseVectorDocValuesField to reliably differentiate between VectorValues.EMPTY and + // values that can be iterated through. Since VectorValues.EMPTY throws on docID(), pass a null instead. + values = null; } - return new DelegateDocValuesField(new KnnDenseVectorScriptDocValues(new KnnDenseVectorSupplier(values), dims), name); + return new KnnDenseVectorDocValuesField(values, name, dims); } else { BinaryDocValues values = DocValues.getBinary(reader, field); - return new DelegateDocValuesField( - new BinaryDenseVectorScriptDocValues(new BinaryDenseVectorSupplier(values), indexVersion, dims), - name - ); + return new BinaryDenseVectorDocValuesField(values, name, dims, indexVersion); } } catch (IOException e) { throw new IllegalStateException("Cannot load doc values for vector field!", e); diff --git a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/org.elasticsearch.xpack.vectors.txt similarity index 52% rename from x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt rename to x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/org.elasticsearch.xpack.vectors.txt index 86583d77264a2..bcf989933b04e 100644 --- a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt +++ 
b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/org.elasticsearch.xpack.vectors.txt @@ -11,6 +11,43 @@ class org.elasticsearch.xpack.vectors.query.DenseVectorScriptDocValues { class org.elasticsearch.script.ScoreScript @no_import { } +class org.elasticsearch.xpack.vectors.query.DenseVector { + DenseVector EMPTY + float getMagnitude() + + # handle List and float[] arguments + double dotProduct(Object) + double l1Norm(Object) + double l2Norm(Object) + double cosineSimilarity(Object) + + float[] getVector() + boolean isEmpty() + int getDims() + int size() +} + +# implementation of DenseVector +class org.elasticsearch.xpack.vectors.query.BinaryDenseVector { +} + +# implementation of DenseVector +class org.elasticsearch.xpack.vectors.query.KnnDenseVector { +} + +class org.elasticsearch.xpack.vectors.query.DenseVectorDocValuesField { + DenseVector get() + DenseVector get(DenseVector) +} + +# implementation of DenseVectorDocValuesField +class org.elasticsearch.xpack.vectors.query.KnnDenseVectorDocValuesField { +} + +# implementation of DenseVectorDocValuesField +class org.elasticsearch.xpack.vectors.query.BinaryDenseVectorDocValuesField { +} + static_import { double l1norm(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm double l2norm(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorScriptDocValuesTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorScriptDocValuesTests.java index 2761364e51505..ddd96ba9fd0a7 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorScriptDocValuesTests.java +++ 
b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorScriptDocValuesTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.Version; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; -import org.elasticsearch.xpack.vectors.query.BinaryDenseVectorScriptDocValues.BinaryDenseVectorSupplier; import java.io.IOException; import java.nio.ByteBuffer; @@ -29,24 +28,56 @@ public void testGetVectorValueAndGetMagnitude() throws IOException { for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) { BinaryDocValues docValues = wrap(vectors, indexVersion); - BinaryDenseVectorSupplier supplier = new BinaryDenseVectorSupplier(docValues); - DenseVectorScriptDocValues scriptDocValues = new BinaryDenseVectorScriptDocValues(supplier, indexVersion, dims); + BinaryDenseVectorDocValuesField field = new BinaryDenseVectorDocValuesField(docValues, "test", dims, indexVersion); + DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues(); for (int i = 0; i < vectors.length; i++) { - supplier.setNextDocId(i); + field.setNextDocId(i); + assertEquals(1, field.size()); + assertEquals(dims, scriptDocValues.dims()); assertArrayEquals(vectors[i], scriptDocValues.getVectorValue(), 0.0001f); assertEquals(expectedMagnitudes[i], scriptDocValues.getMagnitude(), 0.0001f); } } } + public void testMetadataAndIterator() throws IOException { + int dims = 3; + Version indexVersion = Version.CURRENT; + float[][] vectors = fill(new float[randomIntBetween(1, 5)][dims]); + BinaryDocValues docValues = wrap(vectors, indexVersion); + BinaryDenseVectorDocValuesField field = new BinaryDenseVectorDocValuesField(docValues, "test", dims, indexVersion); + for (int i = 0; i < vectors.length; i++) { + field.setNextDocId(i); + DenseVector dv = field.get(); + assertEquals(1, dv.size()); + assertFalse(dv.isEmpty()); + assertEquals(dims, dv.getDims()); + UnsupportedOperationException e = 
expectThrows(UnsupportedOperationException.class, field::iterator); + assertEquals("Cannot iterate over single valued dense_vector field, use get() instead", e.getMessage()); + } + field.setNextDocId(vectors.length); + DenseVector dv = field.get(); + assertEquals(dv, DenseVector.EMPTY); + } + + protected float[][] fill(float[][] vectors) { + for (float[] vector : vectors) { + for (int i = 0; i < vector.length; i++) { + vector[i] = randomFloat(); + } + } + return vectors; + } + public void testMissingValues() throws IOException { int dims = 3; float[][] vectors = { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }; BinaryDocValues docValues = wrap(vectors, Version.CURRENT); - BinaryDenseVectorSupplier supplier = new BinaryDenseVectorSupplier(docValues); - DenseVectorScriptDocValues scriptDocValues = new BinaryDenseVectorScriptDocValues(supplier, Version.CURRENT, dims); + BinaryDenseVectorDocValuesField field = new BinaryDenseVectorDocValuesField(docValues, "test", dims, Version.CURRENT); + DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues(); - supplier.setNextDocId(3); + field.setNextDocId(3); + assertEquals(0, field.size()); Exception e = expectThrows(IllegalArgumentException.class, scriptDocValues::getVectorValue); assertEquals("A document doesn't have a value for a vector field!", e.getMessage()); @@ -58,12 +89,17 @@ public void testGetFunctionIsNotAccessible() throws IOException { int dims = 3; float[][] vectors = { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }; BinaryDocValues docValues = wrap(vectors, Version.CURRENT); - BinaryDenseVectorSupplier supplier = new BinaryDenseVectorSupplier(docValues); - DenseVectorScriptDocValues scriptDocValues = new BinaryDenseVectorScriptDocValues(supplier, Version.CURRENT, dims); + BinaryDenseVectorDocValuesField field = new BinaryDenseVectorDocValuesField(docValues, "test", dims, Version.CURRENT); + DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues(); - supplier.setNextDocId(0); + 
field.setNextDocId(0); Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); - assertThat(e.getMessage(), containsString("accessing a vector field's value through 'get' or 'value' is not supported!")); + assertThat( + e.getMessage(), + containsString( + "accessing a vector field's value through 'get' or 'value' is not supported, use 'vectorValue' or 'magnitude' instead." + ) + ); } public void testSimilarityFunctions() throws IOException { @@ -73,10 +109,10 @@ public void testSimilarityFunctions() throws IOException { for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) { BinaryDocValues docValues = wrap(new float[][] { docVector }, indexVersion); - BinaryDenseVectorSupplier supplier = new BinaryDenseVectorSupplier(docValues); - DenseVectorScriptDocValues scriptDocValues = new BinaryDenseVectorScriptDocValues(supplier, Version.CURRENT, dims); + BinaryDenseVectorDocValuesField field = new BinaryDenseVectorDocValuesField(docValues, "test", dims, indexVersion); + DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues(); - supplier.setNextDocId(0); + field.setNextDocId(0); assertEquals( "dotProduct result is not equal to the expected value!", @@ -133,7 +169,7 @@ public long cost() { }; } - private static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) { + static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) { byte[] bytes = indexVersion.onOrAfter(Version.V_7_5_0) ? 
new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES] : new byte[VectorEncoderDecoder.INT_BYTES * values.length]; diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java index 0ecd26f08c20c..d40d7e3abd663 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java @@ -7,18 +7,16 @@ package org.elasticsearch.xpack.vectors.query; -import org.apache.lucene.index.BinaryDocValues; import org.elasticsearch.Version; import org.elasticsearch.script.ScoreScript; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.vectors.query.BinaryDenseVectorScriptDocValues.BinaryDenseVectorSupplier; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.CosineSimilarity; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProduct; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1Norm; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2Norm; +import java.io.IOException; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.function.Supplier; @@ -28,34 +26,72 @@ public class DenseVectorFunctionTests extends ESTestCase { - public void testVectorFunctions() { - String field = "vector"; + public void testVectorClassBindings() throws IOException { + String fieldName = "vector"; int dims = 5; float[] docVector = new float[] { 230.0f, 300.33f, -34.8988f, 15.555f, -200.0f }; List queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f); List invalidQueryVector = Arrays.asList(0.5, 111.3); - for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) { - BinaryDocValues docValues = 
BinaryDenseVectorScriptDocValuesTests.wrap(new float[][] { docVector }, indexVersion); - DenseVectorScriptDocValues scriptDocValues = new BinaryDenseVectorScriptDocValues( - new BinaryDenseVectorSupplier(docValues), - indexVersion, - dims - ); + List fields = List.of( + new BinaryDenseVectorDocValuesField( + BinaryDenseVectorScriptDocValuesTests.wrap(new float[][] { docVector }, Version.V_7_4_0), + "test", + dims, + Version.V_7_4_0 + ), + new BinaryDenseVectorDocValuesField( + BinaryDenseVectorScriptDocValuesTests.wrap(new float[][] { docVector }, Version.CURRENT), + "test", + dims, + Version.CURRENT + ), + new KnnDenseVectorDocValuesField(KnnDenseVectorScriptDocValuesTests.wrap(new float[][] { docVector }), "test", dims) + ); + for (DenseVectorDocValuesField field : fields) { + field.setNextDocId(0); ScoreScript scoreScript = mock(ScoreScript.class); - when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, scriptDocValues)); + when(scoreScript.field("vector")).thenAnswer(mock -> field); // Test cosine similarity explicitly, as it must perform special logic on top of the doc values - CosineSimilarity function = new CosineSimilarity(scoreScript, queryVector, field); - assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, function.cosineSimilarity(), 0.001); + CosineSimilarity function = new CosineSimilarity(scoreScript, queryVector, fieldName); + float cosineSimilarityExpected = 0.790f; + assertEquals( + "cosineSimilarity result is not equal to the expected value!", + cosineSimilarityExpected, + function.cosineSimilarity(), + 0.001 + ); + + // Test normalization for cosineSimilarity + float[] queryVectorArray = new float[queryVector.size()]; + for (int i = 0; i < queryVectorArray.length; i++) { + queryVectorArray[i] = queryVector.get(i).floatValue(); + } + assertEquals( + "cosineSimilarity result is not equal to the expected value!", + cosineSimilarityExpected, + field.getInternal().cosineSimilarity(queryVectorArray, 
true), + 0.001 + ); // Check each function rejects query vectors with the wrong dimension - assertDimensionMismatch(() -> new DotProduct(scoreScript, invalidQueryVector, field)); - assertDimensionMismatch(() -> new CosineSimilarity(scoreScript, invalidQueryVector, field)); - assertDimensionMismatch(() -> new L1Norm(scoreScript, invalidQueryVector, field)); - assertDimensionMismatch(() -> new L2Norm(scoreScript, invalidQueryVector, field)); + assertDimensionMismatch(() -> new DotProduct(scoreScript, invalidQueryVector, fieldName)); + assertDimensionMismatch(() -> new CosineSimilarity(scoreScript, invalidQueryVector, fieldName)); + assertDimensionMismatch(() -> new L1Norm(scoreScript, invalidQueryVector, fieldName)); + assertDimensionMismatch(() -> new L2Norm(scoreScript, invalidQueryVector, fieldName)); + + // Check scripting infrastructure integration + DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName); + assertEquals(65425.6249, dotProduct.dotProduct(), 0.001); + assertEquals(485.1837, new L1Norm(scoreScript, queryVector, fieldName).l1norm(), 0.001); + assertEquals(301.3614, new L2Norm(scoreScript, queryVector, fieldName).l2norm(), 0.001); + when(scoreScript._getDocId()).thenReturn(1); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct::dotProduct); + assertEquals("A document doesn't have a value for a vector field!", e.getMessage()); } + } private void assertDimensionMismatch(Supplier supplier) { diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorTests.java new file mode 100644 index 0000000000000..11078e4964920 --- /dev/null +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorTests.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.vectors.query; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; +import org.elasticsearch.test.ESTestCase; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; + +import static org.hamcrest.Matchers.containsString; + +public class DenseVectorTests extends ESTestCase { + public void testBadVectorType() { + DenseVector knn = new KnnDenseVector(new float[] { 1.0f, 2.0f, 3.5f }); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> knn.dotProduct(new HashMap<>())); + assertThat(e.getMessage(), containsString("Cannot use vector [")); + assertThat(e.getMessage(), containsString("] with class [java.util.HashMap] as query vector")); + + e = expectThrows(IllegalArgumentException.class, () -> knn.l1Norm(new HashMap<>())); + assertThat(e.getMessage(), containsString("Cannot use vector [")); + assertThat(e.getMessage(), containsString("] with class [java.util.HashMap] as query vector")); + + e = expectThrows(IllegalArgumentException.class, () -> knn.l2Norm(new HashMap<>())); + assertThat(e.getMessage(), containsString("Cannot use vector [")); + assertThat(e.getMessage(), containsString("] with class [java.util.HashMap] as query vector")); + + e = expectThrows(IllegalArgumentException.class, () -> knn.cosineSimilarity(new HashMap<>())); + assertThat(e.getMessage(), containsString("Cannot use vector [")); + assertThat(e.getMessage(), containsString("] with class [java.util.HashMap] as query vector")); + } + + public void testFloatVsListQueryVector() { + int dims = randomIntBetween(1, 16); + float[] docVector = new float[dims]; + float[] arrayQV = new float[dims]; + List listQV = new ArrayList<>(dims); + for (int i = 0; i < docVector.length; i++) { + 
docVector[i] = randomFloat(); + float q = randomFloat(); + arrayQV[i] = q; + listQV.add(q); + } + + KnnDenseVector knn = new KnnDenseVector(docVector); + assertEquals(knn.dotProduct(arrayQV), knn.dotProduct(listQV), 0.001f); + assertEquals(knn.dotProduct((Object) listQV), knn.dotProduct((Object) arrayQV), 0.001f); + + assertEquals(knn.l1Norm(arrayQV), knn.l1Norm(listQV), 0.001f); + assertEquals(knn.l1Norm((Object) listQV), knn.l1Norm((Object) arrayQV), 0.001f); + + assertEquals(knn.l2Norm(arrayQV), knn.l2Norm(listQV), 0.001f); + assertEquals(knn.l2Norm((Object) listQV), knn.l2Norm((Object) arrayQV), 0.001f); + + assertEquals(knn.cosineSimilarity(arrayQV), knn.cosineSimilarity(listQV), 0.001f); + assertEquals(knn.cosineSimilarity((Object) listQV), knn.cosineSimilarity((Object) arrayQV), 0.001f); + + for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) { + BytesRef value = BinaryDenseVectorScriptDocValuesTests.mockEncodeDenseVector(docVector, indexVersion); + BinaryDenseVector bdv = new BinaryDenseVector(value, dims, indexVersion); + + assertEquals(bdv.dotProduct(arrayQV), bdv.dotProduct(listQV), 0.001f); + assertEquals(bdv.dotProduct((Object) listQV), bdv.dotProduct((Object) arrayQV), 0.001f); + + assertEquals(bdv.l1Norm(arrayQV), bdv.l1Norm(listQV), 0.001f); + assertEquals(bdv.l1Norm((Object) listQV), bdv.l1Norm((Object) arrayQV), 0.001f); + + assertEquals(bdv.l2Norm(arrayQV), bdv.l2Norm(listQV), 0.001f); + assertEquals(bdv.l2Norm((Object) listQV), bdv.l2Norm((Object) arrayQV), 0.001f); + + assertEquals(bdv.cosineSimilarity(arrayQV), bdv.cosineSimilarity(listQV), 0.001f); + assertEquals(bdv.cosineSimilarity((Object) listQV), bdv.cosineSimilarity((Object) arrayQV), 0.001f); + } + } + +} diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorScriptDocValuesTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorScriptDocValuesTests.java index 
7005e4d7bd531..743fc2d8bb63e 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorScriptDocValuesTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorScriptDocValuesTests.java @@ -10,7 +10,6 @@ import org.apache.lucene.index.VectorValues; import org.apache.lucene.util.BytesRef; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.vectors.query.KnnDenseVectorScriptDocValues.KnnDenseVectorSupplier; import java.io.IOException; @@ -23,22 +22,52 @@ public void testGetVectorValueAndGetMagnitude() throws IOException { float[][] vectors = { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }; float[] expectedMagnitudes = { 1.7320f, 2.4495f, 3.3166f }; - KnnDenseVectorSupplier supplier = new KnnDenseVectorSupplier(wrap(vectors)); - DenseVectorScriptDocValues scriptDocValues = new KnnDenseVectorScriptDocValues(supplier, dims); + DenseVectorDocValuesField field = new KnnDenseVectorDocValuesField(wrap(vectors), "test", dims); + DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues(); for (int i = 0; i < vectors.length; i++) { - supplier.setNextDocId(i); + field.setNextDocId(i); + assertEquals(1, field.size()); + assertEquals(dims, scriptDocValues.dims()); assertArrayEquals(vectors[i], scriptDocValues.getVectorValue(), 0.0001f); assertEquals(expectedMagnitudes[i], scriptDocValues.getMagnitude(), 0.0001f); } } + public void testMetadataAndIterator() throws IOException { + int dims = 3; + float[][] vectors = fill(new float[randomIntBetween(1, 5)][dims]); + KnnDenseVectorDocValuesField field = new KnnDenseVectorDocValuesField(wrap(vectors), "test", dims); + for (int i = 0; i < vectors.length; i++) { + field.setNextDocId(i); + DenseVector dv = field.get(); + assertEquals(1, dv.size()); + assertFalse(dv.isEmpty()); + assertEquals(dims, dv.getDims()); + UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, field::iterator); + 
assertEquals("Cannot iterate over single valued dense_vector field, use get() instead", e.getMessage()); + } + assertEquals(1, field.size()); + field.setNextDocId(vectors.length); + DenseVector dv = field.get(); + assertEquals(dv, DenseVector.EMPTY); + } + + protected float[][] fill(float[][] vectors) { + for (float[] vector : vectors) { + for (int i = 0; i < vector.length; i++) { + vector[i] = randomFloat(); + } + } + return vectors; + } + public void testMissingValues() throws IOException { int dims = 3; float[][] vectors = { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }; - KnnDenseVectorSupplier supplier = new KnnDenseVectorSupplier(wrap(vectors)); - DenseVectorScriptDocValues scriptDocValues = new KnnDenseVectorScriptDocValues(supplier, dims); + DenseVectorDocValuesField field = new KnnDenseVectorDocValuesField(wrap(vectors), "test", dims); + DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues(); - supplier.setNextDocId(3); + field.setNextDocId(3); Exception e = expectThrows(IllegalArgumentException.class, () -> scriptDocValues.getVectorValue()); assertEquals("A document doesn't have a value for a vector field!", e.getMessage()); @@ -49,12 +78,17 @@ public void testMissingValues() throws IOException { public void testGetFunctionIsNotAccessible() throws IOException { int dims = 3; float[][] vectors = { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }; - KnnDenseVectorSupplier supplier = new KnnDenseVectorSupplier(wrap(vectors)); - DenseVectorScriptDocValues scriptDocValues = new KnnDenseVectorScriptDocValues(supplier, dims); + DenseVectorDocValuesField field = new KnnDenseVectorDocValuesField(wrap(vectors), "test", dims); + DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues(); - supplier.setNextDocId(0); + field.setNextDocId(0); Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); - assertThat(e.getMessage(), containsString("accessing a vector field's value through 'get' or 'value' is not 
supported!")); + assertThat( + e.getMessage(), + containsString( + "accessing a vector field's value through 'get' or 'value' is not supported, use 'vectorValue' or 'magnitude' instead." + ) + ); } public void testSimilarityFunctions() throws IOException { @@ -62,16 +96,30 @@ public void testSimilarityFunctions() throws IOException { float[] docVector = new float[] { 230.0f, 300.33f, -34.8988f, 15.555f, -200.0f }; float[] queryVector = new float[] { 0.5f, 111.3f, -13.0f, 14.8f, -156.0f }; - KnnDenseVectorSupplier supplier = new KnnDenseVectorSupplier(wrap(new float[][] { docVector })); - DenseVectorScriptDocValues scriptDocValues = new KnnDenseVectorScriptDocValues(supplier, dims); - supplier.setNextDocId(0); + DenseVectorDocValuesField field = new KnnDenseVectorDocValuesField(wrap(new float[][] { docVector }), "test", dims); + DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues(); + field.setNextDocId(0); assertEquals("dotProduct result is not equal to the expected value!", 65425.624, scriptDocValues.dotProduct(queryVector), 0.001); assertEquals("l1norm result is not equal to the expected value!", 485.184, scriptDocValues.l1Norm(queryVector), 0.001); assertEquals("l2norm result is not equal to the expected value!", 301.361, scriptDocValues.l2Norm(queryVector), 0.001); } - private static VectorValues wrap(float[][] vectors) { + public void testMissingVectorValues() throws IOException { + int dims = 7; + KnnDenseVectorDocValuesField emptyKnn = new KnnDenseVectorDocValuesField(null, "test", dims); + + emptyKnn.setNextDocId(0); + assertEquals(0, emptyKnn.getScriptDocValues().size()); + assertTrue(emptyKnn.getScriptDocValues().isEmpty()); + assertEquals(DenseVector.EMPTY, emptyKnn.get()); + assertNull(emptyKnn.get(null)); + assertNull(emptyKnn.getInternal()); + UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, emptyKnn::iterator); + assertEquals("Cannot iterate over single valued dense_vector field, use get() 
instead", e.getMessage()); + } + + static VectorValues wrap(float[][] vectors) { return new VectorValues() { int index = 0;