Skip to content

Commit f6a4bd5

Browse files
committed
BulkWriter supports Struct
Signed-off-by: yhmo <[email protected]>
1 parent 45ae45d commit f6a4bd5

File tree

6 files changed

+731
-195
lines changed

6 files changed

+731
-195
lines changed

examples/src/main/java/io/milvus/v2/bulkwriter/BulkWriterRemoteExample.java

Lines changed: 175 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ private static void createConnection() {
167167

168168
private static void exampleSimpleCollection(List<BulkFileType> fileTypes) throws Exception {
169169
CreateCollectionReq.CollectionSchema collectionSchema = buildSimpleSchema();
170-
createCollection(SIMPLE_COLLECTION_NAME, collectionSchema, false);
170+
createCollection(SIMPLE_COLLECTION_NAME, collectionSchema);
171171

172172
for (BulkFileType fileType : fileTypes) {
173173
remoteWriter(collectionSchema, fileType);
@@ -182,7 +182,7 @@ private static void exampleAllTypesCollectionRemote(List<BulkFileType> fileTypes
182182
for (BulkFileType fileType : fileTypes) {
183183
CreateCollectionReq.CollectionSchema collectionSchema = buildAllTypesSchema();
184184
List<List<String>> batchFiles = allTypesRemoteWriter(collectionSchema, fileType, rows);
185-
createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, false);
185+
createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema);
186186
callBulkInsert(batchFiles);
187187
verifyImportData(collectionSchema, originalData);
188188
}
@@ -192,7 +192,7 @@ private static void exampleAllTypesCollectionRemote(List<BulkFileType> fileTypes
192192
// for (BulkFileType fileType : fileTypes) {
193193
// CreateCollectionReq.CollectionSchema collectionSchema = buildAllTypesSchema();
194194
// List<List<String>> batchFiles = allTypesRemoteWriter(collectionSchema, fileType, rows);
195-
// createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, false);
195+
// createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema);
196196
// callCloudImport(batchFiles, ALL_TYPES_COLLECTION_NAME, "");
197197
// verifyImportData(collectionSchema, originalData);
198198
// }
@@ -227,6 +227,20 @@ private static void remoteWriter(CreateCollectionReq.CollectionSchema collection
227227
}
228228
}
229229

230+
private static Map<String, Object> genOriginStruct(int seed) {
231+
Map<String, Object> st = new HashMap<>();
232+
st.put("st_bool", seed % 3 == 0);
233+
st.put("st_int8", seed % 128);
234+
st.put("st_int16", seed % 16384);
235+
st.put("st_int32", seed % 65536);
236+
st.put("st_int64", seed);
237+
st.put("st_float", (float) seed / 4);
238+
st.put("st_double", seed / 3);
239+
st.put("st_string", String.format("dummy_%d", seed));
240+
st.put("st_float_vector", CommonUtils.generateFloatVector(DIM));
241+
return st;
242+
}
243+
230244
private static List<Map<String, Object>> genOriginalData(int count) {
231245
List<Map<String, Object>> data = new ArrayList<>();
232246
for (int i = 0; i < count; ++i) {
@@ -245,7 +259,7 @@ private static List<Map<String, Object>> genOriginalData(int count) {
245259
// vector field
246260
row.put("float_vector", CommonUtils.generateFloatVector(DIM));
247261
row.put("binary_vector", CommonUtils.generateBinaryVector(DIM).array());
248-
row.put("int8_vector", CommonUtils.generateInt8Vector(DIM).array());
262+
// row.put("int8_vector", CommonUtils.generateInt8Vector(DIM).array());
249263
row.put("sparse_vector", CommonUtils.generateSparseVector());
250264

251265
// array field
@@ -258,6 +272,13 @@ private static List<Map<String, Object>> genOriginalData(int count) {
258272
row.put("array_float", GeneratorUtils.generatorFloatValue(9));
259273
row.put("array_double", GeneratorUtils.generatorDoubleValue(10));
260274

275+
// struct field
276+
List<Map<String, Object>> structList = new ArrayList<>();
277+
for (int k = 0; k < i % 4 + 1; k++) {
278+
structList.add(genOriginStruct(i + k));
279+
}
280+
row.put("struct_field", structList);
281+
261282
data.add(row);
262283
}
263284
// a special record with null/default values
@@ -277,7 +298,7 @@ private static List<Map<String, Object>> genOriginalData(int count) {
277298
// vector field
278299
row.put("float_vector", CommonUtils.generateFloatVector(DIM));
279300
row.put("binary_vector", CommonUtils.generateBinaryVector(DIM).array());
280-
row.put("int8_vector", CommonUtils.generateInt8Vector(DIM).array());
301+
// row.put("int8_vector", CommonUtils.generateInt8Vector(DIM).array());
281302
row.put("sparse_vector", CommonUtils.generateSparseVector());
282303

283304
// array field
@@ -290,6 +311,9 @@ private static List<Map<String, Object>> genOriginalData(int count) {
290311
row.put("array_float", GeneratorUtils.generatorFloatValue(4));
291312
row.put("array_double", null);
292313

314+
// struct field
315+
row.put("struct_field", Collections.singletonList(genOriginStruct(0)));
316+
293317
data.add(row);
294318
}
295319
return data;
@@ -322,7 +346,7 @@ private static List<JsonObject> genImportData(List<Map<String, Object>> original
322346
// vector field
323347
rowObject.add("float_vector", GSON_INSTANCE.toJsonTree(row.get("float_vector")));
324348
rowObject.add("binary_vector", GSON_INSTANCE.toJsonTree(row.get("binary_vector")));
325-
rowObject.add("int8_vector", GSON_INSTANCE.toJsonTree(row.get("int8_vector")));
349+
// rowObject.add("int8_vector", GSON_INSTANCE.toJsonTree(row.get("int8_vector")));
326350
rowObject.add("sparse_vector", GSON_INSTANCE.toJsonTree(row.get("sparse_vector")));
327351

328352
// array field
@@ -335,6 +359,9 @@ private static List<JsonObject> genImportData(List<Map<String, Object>> original
335359
rowObject.add("array_float", GSON_INSTANCE.toJsonTree(row.get("array_float")));
336360
rowObject.add("array_double", GSON_INSTANCE.toJsonTree(row.get("array_double")));
337361

362+
// struct field
363+
rowObject.add("struct_field", GSON_INSTANCE.toJsonTree(row.get("struct_field")));
364+
338365
// dynamic fields
339366
if (isEnableDynamicField) {
340367
rowObject.addProperty("dynamic", "dynamic_" + row.get("id"));
@@ -462,11 +489,10 @@ private static void callBulkInsert(List<List<String>> batchFiles) throws Interru
462489
JsonObject getImportProgressObject = convertJsonObject(getImportProgressResult);
463490
String state = getImportProgressObject.getAsJsonObject("data").get("state").getAsString();
464491
String progress = getImportProgressObject.getAsJsonObject("data").get("progress").getAsString();
465-
if ("Failed".equals(state)) {
492+
if ("Failed" .equals(state)) {
466493
String reason = getImportProgressObject.getAsJsonObject("data").get("reason").getAsString();
467-
System.out.printf("The job %s failed, reason: %s%n", jobId, reason);
468-
break;
469-
} else if ("Completed".equals(state)) {
494+
throw new RuntimeException(String.format("The job %s failed, reason: %s", jobId, reason));
495+
} else if ("Completed" .equals(state)) {
470496
System.out.printf("The job %s completed%n", jobId);
471497
break;
472498
} else {
@@ -477,9 +503,8 @@ private static void callBulkInsert(List<List<String>> batchFiles) throws Interru
477503

478504
/**
479505
* @param collectionSchema collection info
480-
* @param dropIfExist if collection already exist, will drop firstly and then create again
481506
*/
482-
private static void createCollection(String collectionName, CreateCollectionReq.CollectionSchema collectionSchema, boolean dropIfExist) {
507+
private static void createCollection(String collectionName, CreateCollectionReq.CollectionSchema collectionSchema) {
483508
System.out.println("\n===================== create collection ====================");
484509
checkMilvusClientIfExist();
485510

@@ -489,15 +514,8 @@ private static void createCollection(String collectionName, CreateCollectionReq.
489514
.consistencyLevel(ConsistencyLevel.BOUNDED)
490515
.build();
491516

492-
Boolean has = milvusClient.hasCollection(HasCollectionReq.builder().collectionName(collectionName).build());
493-
if (has) {
494-
if (dropIfExist) {
495-
milvusClient.dropCollection(DropCollectionReq.builder().collectionName(collectionName).build());
496-
milvusClient.createCollection(requestCreate);
497-
}
498-
} else {
499-
milvusClient.createCollection(requestCreate);
500-
}
517+
milvusClient.dropCollection(DropCollectionReq.builder().collectionName(collectionName).build());
518+
milvusClient.createCollection(requestCreate);
501519

502520
System.out.printf("Collection %s created%n", collectionName);
503521
}
@@ -558,13 +576,40 @@ private static void comparePrint(CreateCollectionReq.CollectionSchema collection
558576
}
559577
}
560578

561-
private static void verifyImportData(CreateCollectionReq.CollectionSchema collectionSchema, List<Map<String, Object>> rows) {
562-
createIndex();
579+
private static void compareStruct(CreateCollectionReq.CollectionSchema collectionSchema,
580+
Map<String, Object> expectedData, Map<String, Object> fetchedData,
581+
String fieldName) {
582+
CreateCollectionReq.StructFieldSchema field = collectionSchema.getStructField(fieldName);
583+
Object expectedValue = expectedData.get(fieldName);
584+
Object fetchedValue = fetchedData.get(fieldName);
585+
if (fetchedValue == null) {
586+
throw new RuntimeException(String.format("Struct field '%s' missed in fetched data", fieldName));
587+
}
588+
589+
List<Map<String, Object>> expectedList = (List<Map<String, Object>>) expectedValue;
590+
if (!(fetchedValue instanceof List<?>)) {
591+
throw new RuntimeException(String.format("Struct field '%s' value should be a list", fieldName));
592+
}
593+
594+
List<Map<String, Object>> fetchedList = (List<Map<String, Object>>) fetchedValue;
595+
if (expectedList.size() != fetchedList.size()) {
596+
throw new RuntimeException(String.format("Struct field '%s' list count unmatched", fieldName));
597+
}
598+
599+
for (int i = 0; i < expectedList.size(); i++) {
600+
Map<String, Object> expectedStruct = expectedList.get(i);
601+
Map<String, Object> fetchedStruct = fetchedList.get(i);
602+
if (expectedStruct.equals(fetchedStruct)) {
603+
throw new RuntimeException(String.format("Struct field '%s' value unmatched", fieldName));
604+
}
605+
}
606+
}
563607

608+
private static void verifyImportData(CreateCollectionReq.CollectionSchema collectionSchema, List<Map<String, Object>> rows) {
564609
List<Long> QUERY_IDS = Lists.newArrayList(1L, (long) rows.get(rows.size() - 1).get("id"));
565610
System.out.printf("Load collection and query items %s%n", QUERY_IDS);
611+
createIndex(collectionSchema);
566612
loadCollection();
567-
568613
String expr = String.format("id in %s", QUERY_IDS);
569614
System.out.println(expr);
570615

@@ -597,40 +642,74 @@ private static void verifyImportData(CreateCollectionReq.CollectionSchema collec
597642

598643
comparePrint(collectionSchema, originalEntity, fetchedEntity, "float_vector");
599644
comparePrint(collectionSchema, originalEntity, fetchedEntity, "binary_vector");
600-
comparePrint(collectionSchema, originalEntity, fetchedEntity, "int8_vector");
645+
// comparePrint(collectionSchema, originalEntity, fetchedEntity, "int8_vector");
601646
comparePrint(collectionSchema, originalEntity, fetchedEntity, "sparse_vector");
602647

648+
compareStruct(collectionSchema, originalEntity, fetchedEntity, "struct_field");
649+
603650
System.out.println(fetchedEntity);
604651
}
605652
System.out.println("Result is correct!");
606653
}
607654

608-
private static void createIndex() {
655+
private static void createIndex(CreateCollectionReq.CollectionSchema collectionSchema) {
609656
System.out.println("Create index...");
610657
checkMilvusClientIfExist();
611658

612659
List<IndexParam> indexes = new ArrayList<>();
613-
indexes.add(IndexParam.builder()
614-
.fieldName("float_vector")
615-
.indexType(IndexParam.IndexType.FLAT)
616-
.metricType(IndexParam.MetricType.L2)
617-
.build());
618-
indexes.add(IndexParam.builder()
619-
.fieldName("binary_vector")
620-
.indexType(IndexParam.IndexType.BIN_FLAT)
621-
.metricType(IndexParam.MetricType.HAMMING)
622-
.build());
623-
indexes.add(IndexParam.builder()
624-
.fieldName("int8_vector")
625-
.indexType(IndexParam.IndexType.AUTOINDEX)
626-
.metricType(IndexParam.MetricType.L2)
627-
.build());
628-
indexes.add(IndexParam.builder()
629-
.fieldName("sparse_vector")
630-
.indexType(IndexParam.IndexType.SPARSE_WAND)
631-
.metricType(IndexParam.MetricType.IP)
632-
.build());
660+
for (CreateCollectionReq.FieldSchema field : collectionSchema.getFieldSchemaList()) {
661+
IndexParam.IndexType indexType;
662+
IndexParam.MetricType metricType;
663+
switch (field.getDataType()) {
664+
case FloatVector:
665+
case Float16Vector:
666+
case BFloat16Vector:
667+
indexType = IndexParam.IndexType.IVF_FLAT;
668+
metricType = IndexParam.MetricType.L2;
669+
break;
670+
case BinaryVector:
671+
indexType = IndexParam.IndexType.BIN_FLAT;
672+
metricType = IndexParam.MetricType.HAMMING;
673+
break;
674+
case Int8Vector:
675+
indexType = IndexParam.IndexType.AUTOINDEX;
676+
metricType = IndexParam.MetricType.L2;
677+
break;
678+
case SparseFloatVector:
679+
indexType = IndexParam.IndexType.SPARSE_WAND;
680+
metricType = IndexParam.MetricType.IP;
681+
break;
682+
default:
683+
continue;
684+
}
685+
indexes.add(IndexParam.builder()
686+
.fieldName(field.getName())
687+
.indexName(String.format("index_%s", field.getName()))
688+
.indexType(indexType)
689+
.metricType(metricType)
690+
.build());
691+
}
633692

693+
for (CreateCollectionReq.StructFieldSchema struct : collectionSchema.getStructFields()) {
694+
for (CreateCollectionReq.FieldSchema subField : struct.getFields()) {
695+
IndexParam.IndexType indexType;
696+
IndexParam.MetricType metricType;
697+
switch (subField.getDataType()) {
698+
case FloatVector:
699+
indexType = IndexParam.IndexType.HNSW;
700+
metricType = IndexParam.MetricType.MAX_SIM_COSINE;
701+
break;
702+
default:
703+
continue;
704+
}
705+
indexes.add(IndexParam.builder()
706+
.fieldName(String.format("%s[%s]", struct.getName(), subField.getName()))
707+
.indexName(String.format("index_%s", subField.getName()))
708+
.indexType(indexType)
709+
.metricType(metricType)
710+
.build());
711+
}
712+
}
634713
milvusClient.createIndex(CreateIndexReq.builder()
635714
.collectionName(ALL_TYPES_COLLECTION_NAME)
636715
.indexParams(indexes)
@@ -660,6 +739,7 @@ private static List<QueryResp.QueryResult> query(String expr, List<String> outpu
660739
.collectionName(ALL_TYPES_COLLECTION_NAME)
661740
.filter(expr)
662741
.outputFields(outputFields)
742+
.consistencyLevel(ConsistencyLevel.STRONG)
663743
.build();
664744
QueryResp response = milvusClient.query(test);
665745
return response.getQueryResults();
@@ -798,11 +878,11 @@ private static CreateCollectionReq.CollectionSchema buildAllTypesSchema() {
798878
.dataType(DataType.BinaryVector)
799879
.dimension(DIM)
800880
.build());
801-
schemaV2.addField(AddFieldReq.builder()
802-
.fieldName("int8_vector")
803-
.dataType(DataType.Int8Vector)
804-
.dimension(DIM)
805-
.build());
881+
// schemaV2.addField(AddFieldReq.builder()
882+
// .fieldName("int8_vector")
883+
// .dataType(DataType.Int8Vector)
884+
// .dimension(DIM)
885+
// .build());
806886
schemaV2.addField(AddFieldReq.builder()
807887
.fieldName("sparse_vector")
808888
.dataType(DataType.SparseFloatVector)
@@ -860,6 +940,50 @@ private static CreateCollectionReq.CollectionSchema buildAllTypesSchema() {
860940
.elementType(DataType.Double)
861941
.isNullable(true)
862942
.build());
943+
schemaV2.addField(AddFieldReq.builder()
944+
.fieldName("struct_field")
945+
.dataType(DataType.Array)
946+
.elementType(DataType.Struct)
947+
.maxCapacity(100)
948+
.addStructField(AddFieldReq.builder()
949+
.fieldName("st_bool")
950+
.dataType(DataType.Bool)
951+
.build())
952+
.addStructField(AddFieldReq.builder()
953+
.fieldName("st_int8")
954+
.dataType(DataType.Int8)
955+
.build())
956+
.addStructField(AddFieldReq.builder()
957+
.fieldName("st_int16")
958+
.dataType(DataType.Int16)
959+
.build())
960+
.addStructField(AddFieldReq.builder()
961+
.fieldName("st_int32")
962+
.dataType(DataType.Int32)
963+
.build())
964+
.addStructField(AddFieldReq.builder()
965+
.fieldName("st_int64")
966+
.dataType(DataType.Int64)
967+
.build())
968+
.addStructField(AddFieldReq.builder()
969+
.fieldName("st_float")
970+
.dataType(DataType.Float)
971+
.build())
972+
.addStructField(AddFieldReq.builder()
973+
.fieldName("st_double")
974+
.dataType(DataType.Double)
975+
.build())
976+
.addStructField(AddFieldReq.builder()
977+
.fieldName("st_string")
978+
.dataType(DataType.VarChar)
979+
.maxLength(100)
980+
.build())
981+
.addStructField(AddFieldReq.builder()
982+
.fieldName("st_float_vector")
983+
.dataType(DataType.FloatVector)
984+
.dimension(DIM)
985+
.build())
986+
.build());
863987

864988
return schemaV2;
865989
}

0 commit comments

Comments
 (0)