diff --git a/LiteDB.Tests/Query/VectorIndex_Tests.cs b/LiteDB.Tests/Query/VectorIndex_Tests.cs index 8d99f324f..17012d18a 100644 --- a/LiteDB.Tests/Query/VectorIndex_Tests.cs +++ b/LiteDB.Tests/Query/VectorIndex_Tests.cs @@ -264,6 +264,52 @@ public void TopKNear_UsesVectorIndex() results.Select(x => x.Id).Should().Equal(new[] { 1, 3 }); } + [Fact] + public void OrderBy_VectorSimilarity_WithCompositeOrdering_UsesVectorIndex() + { + using var db = new LiteDatabase(":memory:"); + var collection = db.GetCollection("vectors"); + + collection.Insert(new[] + { + new VectorDocument { Id = 1, Embedding = new[] { 1f, 0f }, Flag = true }, + new VectorDocument { Id = 2, Embedding = new[] { 1f, 0f }, Flag = false }, + new VectorDocument { Id = 3, Embedding = new[] { 0f, 1f }, Flag = true } + }); + + collection.EnsureIndex( + "embedding_idx", + BsonExpression.Create("$.Embedding"), + new VectorIndexOptions(2, VectorDistanceMetric.Cosine)); + + var similarity = BsonExpression.Create("VECTOR_SIM($.Embedding, [1.0, 0.0])"); + + var query = (LiteQueryable)collection.Query() + .OrderBy(similarity, Query.Ascending) + .ThenBy(x => x.Flag); + + var queryField = typeof(LiteQueryable).GetField("_query", BindingFlags.NonPublic | BindingFlags.Instance); + var definition = (Query)queryField.GetValue(query); + + definition.OrderBy.Should().HaveCount(2); + definition.OrderBy[0].Expression.Type.Should().Be(BsonExpressionType.VectorSim); + + definition.VectorField = "$.Embedding"; + definition.VectorTarget = new[] { 1f, 0f }; + definition.VectorMaxDistance = double.MaxValue; + + var plan = query.GetPlan(); + + plan["index"]["mode"].AsString.Should().Be("VECTOR INDEX SEARCH"); + plan["index"]["expr"].AsString.Should().Be("$.Embedding"); + plan.ContainsKey("orderBy").Should().BeFalse(); + + var results = query.ToArray(); + + results.Should().HaveCount(3); + results.Select(x => x.Id).Should().BeEquivalentTo(new[] { 1, 2, 3 }); + } + [Fact] public void WhereNear_DotProductHonorsMinimumSimilarity() { diff --git a/LiteDB/Engine/Query/QueryOptimization.cs b/LiteDB/Engine/Query/QueryOptimization.cs index 5ef601777..c67b42e66 100644 --- a/LiteDB/Engine/Query/QueryOptimization.cs +++ b/LiteDB/Engine/Query/QueryOptimization.cs @@ -326,12 +326,16 @@ private bool TrySelectVectorIndex(out VectorIndexQuery index, out BsonExpression } } - if (expression == null && _query.OrderBy != null) + if (expression == null && _query.OrderBy.Count > 0) { - if (this.TryParseVectorExpression(_query.OrderBy, out expression, out target)) + foreach (var order in _query.OrderBy) { - matchedFromOrderBy = true; - maxDistance = double.MaxValue; + if (this.TryParseVectorExpression(order.Expression, out expression, out target)) + { + matchedFromOrderBy = true; + maxDistance = double.MaxValue; + break; + } } } @@ -340,7 +344,7 @@ private bool TrySelectVectorIndex(out VectorIndexQuery index, out BsonExpression expression = NormalizeVectorField(_query.VectorField); target = _query.VectorTarget?.ToArray(); maxDistance = _query.VectorMaxDistance; - matchedFromOrderBy = matchedFromOrderBy || (_query.OrderBy != null && _query.OrderBy.Type == BsonExpressionType.VectorSim); + matchedFromOrderBy = matchedFromOrderBy || (_query.OrderBy.Any(order => order.Expression?.Type == BsonExpressionType.VectorSim)); } if (expression == null || target == null)