Skip to content

Commit 26f5c66

Browse files
Return correct type when invoking GetFieldType and GetProviderSpecificFieldType for vector float32 column (#4105)
This commit returns correct type for GetFieldType and GetProviderSpecificFieldType
1 parent 27999d9 commit 26f5c66

2 files changed

Lines changed: 59 additions & 3 deletions

File tree

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,6 +1307,10 @@ private Type GetFieldTypeInternal(_SqlMetaData metaData)
13071307
Connection.CheckGetExtendedUDTInfo(metaData, false);
13081308
fieldType = metaData.udt?.Type;
13091309
}
1310+
else if (metaData.type == SqlDbTypeExtensions.Vector)
1311+
{
1312+
fieldType = GetVectorFieldType(metaData.scale);
1313+
}
13101314
else
13111315
{ // For all other types, including Xml - use data in MetaType.
13121316
if (metaData.cipherMD != null)
@@ -1329,6 +1333,19 @@ private Type GetFieldTypeInternal(_SqlMetaData metaData)
13291333
return fieldType;
13301334
}
13311335

1336+
#if !NETFRAMEWORK
1337+
[return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)]
1338+
#endif
1339+
private static Type GetVectorFieldType(byte vectorElementType)
1340+
{
1341+
MetaType.SqlVectorElementType elementType = (MetaType.SqlVectorElementType)vectorElementType;
1342+
return elementType switch
1343+
{
1344+
MetaType.SqlVectorElementType.Float32 => typeof(SqlVector<float>),
1345+
_ => throw SQL.VectorTypeNotSupported(elementType.ToString()),
1346+
};
1347+
}
1348+
13321349
virtual internal int GetLocaleId(int i)
13331350
{
13341351
_SqlMetaData sqlMetaData = MetaData[i];
@@ -1422,6 +1439,10 @@ private Type GetProviderSpecificFieldTypeInternal(_SqlMetaData metaData)
14221439
Connection.CheckGetExtendedUDTInfo(metaData, false);
14231440
providerSpecificFieldType = metaData.udt?.Type;
14241441
}
1442+
else if (metaData.type == SqlDbTypeExtensions.Vector)
1443+
{
1444+
providerSpecificFieldType = GetVectorFieldType(metaData.scale);
1445+
}
14251446
else
14261447
{
14271448
// For all other types, including Xml - use data in MetaType.

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,19 @@ public static IEnumerable<object[]> GetVectorFloat32TestData()
2929
yield return new object[] { 3, new SqlVector<float>(testData), testData, vectorColumnLength };
3030
yield return new object[] { 4, new SqlVector<float>(testData), testData, vectorColumnLength };
3131

32-
// Pattern 1-4 with SqlVector<float>(n)
32+
// Pattern 1-4 with SqlVector<float>(n)
3333
yield return new object[] { 1, SqlVector<float>.CreateNull(vectorColumnLength), Array.Empty<float>(), vectorColumnLength };
3434
yield return new object[] { 2, SqlVector<float>.CreateNull(vectorColumnLength), Array.Empty<float>(), vectorColumnLength };
3535
yield return new object[] { 3, SqlVector<float>.CreateNull(vectorColumnLength), Array.Empty<float>(), vectorColumnLength };
3636
yield return new object[] { 4, SqlVector<float>.CreateNull(vectorColumnLength), Array.Empty<float>(), vectorColumnLength };
3737

38-
// Pattern 1-4 with DBNull
38+
// Pattern 1-4 with DBNull
3939
yield return new object[] { 1, DBNull.Value, Array.Empty<float>(), vectorColumnLength };
4040
yield return new object[] { 2, DBNull.Value, Array.Empty<float>(), vectorColumnLength };
4141
yield return new object[] { 3, DBNull.Value, Array.Empty<float>(), vectorColumnLength };
4242
yield return new object[] { 4, DBNull.Value, Array.Empty<float>(), vectorColumnLength };
4343

44-
// Pattern 1-4 with SqlVector<float>.Null
44+
// Pattern 1-4 with SqlVector<float>.Null
4545
yield return new object[] { 1, SqlVector<float>.Null, Array.Empty<float>(), vectorColumnLength };
4646

4747
// Following scenario is not supported in SqlClient.
@@ -561,6 +561,41 @@ public async Task TestBulkCopyFromSqlTableAsync(int bulkCopySourceMode)
561561
Assert.Equal(VectorFloat32TestData.testData.Length, vector.Length);
562562
}
563563

564+
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))]
565+
public void TestGetFieldTypeReturnsSqlVectorForVectorColumn()
566+
{
567+
using var connection = new SqlConnection(s_connectionString);
568+
connection.Open();
569+
570+
// Insert a row so we can query it
571+
using (var insertCmd = new SqlCommand(s_insertCmdString, connection))
572+
{
573+
var param = insertCmd.Parameters.Add(s_vectorParamName, SqlDbTypeExtensions.Vector);
574+
param.Value = new SqlVector<float>(VectorFloat32TestData.testData);
575+
insertCmd.ExecuteNonQuery();
576+
}
577+
578+
using var selectCmd = new SqlCommand(s_selectCmdString, connection);
579+
using var reader = selectCmd.ExecuteReader();
580+
581+
// Verify GetFieldType returns SqlVector<float> for the vector column
582+
Assert.Equal(typeof(SqlVector<float>), reader.GetFieldType(0));
583+
584+
// Verify GetProviderSpecificFieldType also returns SqlVector<float>
585+
Assert.Equal(typeof(SqlVector<float>), reader.GetProviderSpecificFieldType(0));
586+
587+
// Verify that GetValue returns an instance consistent with GetFieldType
588+
Assert.True(reader.Read(), "No data found in the table.");
589+
object value = reader.GetValue(0);
590+
Assert.IsType<SqlVector<float>>(value);
591+
Assert.Equal(VectorFloat32TestData.testData, ((SqlVector<float>)value).Memory.ToArray());
592+
593+
// Verify GetFieldValue<SqlVector<float>> returns the correct typed value
594+
SqlVector<float> typedValue = reader.GetFieldValue<SqlVector<float>>(0);
595+
Assert.IsType<SqlVector<float>>(typedValue);
596+
Assert.Equal(VectorFloat32TestData.testData, typedValue.Memory.ToArray());
597+
}
598+
564599
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))]
565600
public void TestInsertVectorsFloat32WithPrepare()
566601
{

0 commit comments

Comments
 (0)