@@ -167,7 +167,7 @@ private static void createConnection() {
167167
168168 private static void exampleSimpleCollection (List <BulkFileType > fileTypes ) throws Exception {
169169 CreateCollectionReq .CollectionSchema collectionSchema = buildSimpleSchema ();
170- createCollection (SIMPLE_COLLECTION_NAME , collectionSchema , false );
170+ createCollection (SIMPLE_COLLECTION_NAME , collectionSchema );
171171
172172 for (BulkFileType fileType : fileTypes ) {
173173 remoteWriter (collectionSchema , fileType );
@@ -182,7 +182,7 @@ private static void exampleAllTypesCollectionRemote(List<BulkFileType> fileTypes
182182 for (BulkFileType fileType : fileTypes ) {
183183 CreateCollectionReq .CollectionSchema collectionSchema = buildAllTypesSchema ();
184184 List <List <String >> batchFiles = allTypesRemoteWriter (collectionSchema , fileType , rows );
185- createCollection (ALL_TYPES_COLLECTION_NAME , collectionSchema , false );
185+ createCollection (ALL_TYPES_COLLECTION_NAME , collectionSchema );
186186 callBulkInsert (batchFiles );
187187 verifyImportData (collectionSchema , originalData );
188188 }
@@ -192,7 +192,7 @@ private static void exampleAllTypesCollectionRemote(List<BulkFileType> fileTypes
192192// for (BulkFileType fileType : fileTypes) {
193193// CreateCollectionReq.CollectionSchema collectionSchema = buildAllTypesSchema();
194194// List<List<String>> batchFiles = allTypesRemoteWriter(collectionSchema, fileType, rows);
195- // createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema, false );
195+ // createCollection(ALL_TYPES_COLLECTION_NAME, collectionSchema);
196196// callCloudImport(batchFiles, ALL_TYPES_COLLECTION_NAME, "");
197197// verifyImportData(collectionSchema, originalData);
198198// }
@@ -227,6 +227,20 @@ private static void remoteWriter(CreateCollectionReq.CollectionSchema collection
227227 }
228228 }
229229
230+ private static Map <String , Object > genOriginStruct (int seed ) {
231+ Map <String , Object > st = new HashMap <>();
232+ st .put ("st_bool" , seed % 3 == 0 );
233+ st .put ("st_int8" , seed % 128 );
234+ st .put ("st_int16" , seed % 16384 );
235+ st .put ("st_int32" , seed % 65536 );
236+ st .put ("st_int64" , seed );
237+ st .put ("st_float" , (float ) seed / 4 );
238+ st .put ("st_double" , seed / 3 );
239+ st .put ("st_string" , String .format ("dummy_%d" , seed ));
240+ st .put ("st_float_vector" , CommonUtils .generateFloatVector (DIM ));
241+ return st ;
242+ }
243+
230244 private static List <Map <String , Object >> genOriginalData (int count ) {
231245 List <Map <String , Object >> data = new ArrayList <>();
232246 for (int i = 0 ; i < count ; ++i ) {
@@ -245,7 +259,7 @@ private static List<Map<String, Object>> genOriginalData(int count) {
245259 // vector field
246260 row .put ("float_vector" , CommonUtils .generateFloatVector (DIM ));
247261 row .put ("binary_vector" , CommonUtils .generateBinaryVector (DIM ).array ());
248- row .put ("int8_vector" , CommonUtils .generateInt8Vector (DIM ).array ());
262+ // row.put("int8_vector", CommonUtils.generateInt8Vector(DIM).array());
249263 row .put ("sparse_vector" , CommonUtils .generateSparseVector ());
250264
251265 // array field
@@ -258,6 +272,13 @@ private static List<Map<String, Object>> genOriginalData(int count) {
258272 row .put ("array_float" , GeneratorUtils .generatorFloatValue (9 ));
259273 row .put ("array_double" , GeneratorUtils .generatorDoubleValue (10 ));
260274
275+ // struct field
276+ List <Map <String , Object >> structList = new ArrayList <>();
277+ for (int k = 0 ; k < i % 4 + 1 ; k ++) {
278+ structList .add (genOriginStruct (i + k ));
279+ }
280+ row .put ("struct_field" , structList );
281+
261282 data .add (row );
262283 }
263284 // a special record with null/default values
@@ -277,7 +298,7 @@ private static List<Map<String, Object>> genOriginalData(int count) {
277298 // vector field
278299 row .put ("float_vector" , CommonUtils .generateFloatVector (DIM ));
279300 row .put ("binary_vector" , CommonUtils .generateBinaryVector (DIM ).array ());
280- row .put ("int8_vector" , CommonUtils .generateInt8Vector (DIM ).array ());
301+ // row.put("int8_vector", CommonUtils.generateInt8Vector(DIM).array());
281302 row .put ("sparse_vector" , CommonUtils .generateSparseVector ());
282303
283304 // array field
@@ -290,6 +311,9 @@ private static List<Map<String, Object>> genOriginalData(int count) {
290311 row .put ("array_float" , GeneratorUtils .generatorFloatValue (4 ));
291312 row .put ("array_double" , null );
292313
314+ // struct field
315+ row .put ("struct_field" , Collections .singletonList (genOriginStruct (0 )));
316+
293317 data .add (row );
294318 }
295319 return data ;
@@ -322,7 +346,7 @@ private static List<JsonObject> genImportData(List<Map<String, Object>> original
322346 // vector field
323347 rowObject .add ("float_vector" , GSON_INSTANCE .toJsonTree (row .get ("float_vector" )));
324348 rowObject .add ("binary_vector" , GSON_INSTANCE .toJsonTree (row .get ("binary_vector" )));
325- rowObject .add ("int8_vector" , GSON_INSTANCE .toJsonTree (row .get ("int8_vector" )));
349+ // rowObject.add("int8_vector", GSON_INSTANCE.toJsonTree(row.get("int8_vector")));
326350 rowObject .add ("sparse_vector" , GSON_INSTANCE .toJsonTree (row .get ("sparse_vector" )));
327351
328352 // array field
@@ -335,6 +359,9 @@ private static List<JsonObject> genImportData(List<Map<String, Object>> original
335359 rowObject .add ("array_float" , GSON_INSTANCE .toJsonTree (row .get ("array_float" )));
336360 rowObject .add ("array_double" , GSON_INSTANCE .toJsonTree (row .get ("array_double" )));
337361
362+ // struct field
363+ rowObject .add ("struct_field" , GSON_INSTANCE .toJsonTree (row .get ("struct_field" )));
364+
338365 // dynamic fields
339366 if (isEnableDynamicField ) {
340367 rowObject .addProperty ("dynamic" , "dynamic_" + row .get ("id" ));
@@ -462,11 +489,10 @@ private static void callBulkInsert(List<List<String>> batchFiles) throws Interru
462489 JsonObject getImportProgressObject = convertJsonObject (getImportProgressResult );
463490 String state = getImportProgressObject .getAsJsonObject ("data" ).get ("state" ).getAsString ();
464491 String progress = getImportProgressObject .getAsJsonObject ("data" ).get ("progress" ).getAsString ();
465- if ("Failed" .equals (state )) {
492+ if ("Failed" .equals (state )) {
466493 String reason = getImportProgressObject .getAsJsonObject ("data" ).get ("reason" ).getAsString ();
467- System .out .printf ("The job %s failed, reason: %s%n" , jobId , reason );
468- break ;
469- } else if ("Completed" .equals (state )) {
494+ throw new RuntimeException (String .format ("The job %s failed, reason: %s" , jobId , reason ));
495+ } else if ("Completed" .equals (state )) {
470496 System .out .printf ("The job %s completed%n" , jobId );
471497 break ;
472498 } else {
@@ -477,9 +503,8 @@ private static void callBulkInsert(List<List<String>> batchFiles) throws Interru
477503
478504 /**
479505 * @param collectionSchema collection info
480- * @param dropIfExist if collection already exist, will drop firstly and then create again
481506 */
482- private static void createCollection (String collectionName , CreateCollectionReq .CollectionSchema collectionSchema , boolean dropIfExist ) {
507+ private static void createCollection (String collectionName , CreateCollectionReq .CollectionSchema collectionSchema ) {
483508 System .out .println ("\n ===================== create collection ====================" );
484509 checkMilvusClientIfExist ();
485510
@@ -489,15 +514,8 @@ private static void createCollection(String collectionName, CreateCollectionReq.
489514 .consistencyLevel (ConsistencyLevel .BOUNDED )
490515 .build ();
491516
492- Boolean has = milvusClient .hasCollection (HasCollectionReq .builder ().collectionName (collectionName ).build ());
493- if (has ) {
494- if (dropIfExist ) {
495- milvusClient .dropCollection (DropCollectionReq .builder ().collectionName (collectionName ).build ());
496- milvusClient .createCollection (requestCreate );
497- }
498- } else {
499- milvusClient .createCollection (requestCreate );
500- }
517+ milvusClient .dropCollection (DropCollectionReq .builder ().collectionName (collectionName ).build ());
518+ milvusClient .createCollection (requestCreate );
501519
502520 System .out .printf ("Collection %s created%n" , collectionName );
503521 }
@@ -558,13 +576,40 @@ private static void comparePrint(CreateCollectionReq.CollectionSchema collection
558576 }
559577 }
560578
561- private static void verifyImportData (CreateCollectionReq .CollectionSchema collectionSchema , List <Map <String , Object >> rows ) {
562- createIndex ();
579+ private static void compareStruct (CreateCollectionReq .CollectionSchema collectionSchema ,
580+ Map <String , Object > expectedData , Map <String , Object > fetchedData ,
581+ String fieldName ) {
582+ CreateCollectionReq .StructFieldSchema field = collectionSchema .getStructField (fieldName );
583+ Object expectedValue = expectedData .get (fieldName );
584+ Object fetchedValue = fetchedData .get (fieldName );
585+ if (fetchedValue == null ) {
586+ throw new RuntimeException (String .format ("Struct field '%s' missed in fetched data" , fieldName ));
587+ }
588+
589+ List <Map <String , Object >> expectedList = (List <Map <String , Object >>) expectedValue ;
590+ if (!(fetchedValue instanceof List <?>)) {
591+ throw new RuntimeException (String .format ("Struct field '%s' value should be a list" , fieldName ));
592+ }
593+
594+ List <Map <String , Object >> fetchedList = (List <Map <String , Object >>) fetchedValue ;
595+ if (expectedList .size () != fetchedList .size ()) {
596+ throw new RuntimeException (String .format ("Struct field '%s' list count unmatched" , fieldName ));
597+ }
598+
599+ for (int i = 0 ; i < expectedList .size (); i ++) {
600+ Map <String , Object > expectedStruct = expectedList .get (i );
601+ Map <String , Object > fetchedStruct = fetchedList .get (i );
602+ if (expectedStruct .equals (fetchedStruct )) {
603+ throw new RuntimeException (String .format ("Struct field '%s' value unmatched" , fieldName ));
604+ }
605+ }
606+ }
563607
608+ private static void verifyImportData (CreateCollectionReq .CollectionSchema collectionSchema , List <Map <String , Object >> rows ) {
564609 List <Long > QUERY_IDS = Lists .newArrayList (1L , (long ) rows .get (rows .size () - 1 ).get ("id" ));
565610 System .out .printf ("Load collection and query items %s%n" , QUERY_IDS );
611+ createIndex (collectionSchema );
566612 loadCollection ();
567-
568613 String expr = String .format ("id in %s" , QUERY_IDS );
569614 System .out .println (expr );
570615
@@ -597,40 +642,74 @@ private static void verifyImportData(CreateCollectionReq.CollectionSchema collec
597642
598643 comparePrint (collectionSchema , originalEntity , fetchedEntity , "float_vector" );
599644 comparePrint (collectionSchema , originalEntity , fetchedEntity , "binary_vector" );
600- comparePrint (collectionSchema , originalEntity , fetchedEntity , "int8_vector" );
645+ // comparePrint(collectionSchema, originalEntity, fetchedEntity, "int8_vector");
601646 comparePrint (collectionSchema , originalEntity , fetchedEntity , "sparse_vector" );
602647
648+ compareStruct (collectionSchema , originalEntity , fetchedEntity , "struct_field" );
649+
603650 System .out .println (fetchedEntity );
604651 }
605652 System .out .println ("Result is correct!" );
606653 }
607654
608- private static void createIndex () {
655+ private static void createIndex (CreateCollectionReq . CollectionSchema collectionSchema ) {
609656 System .out .println ("Create index..." );
610657 checkMilvusClientIfExist ();
611658
612659 List <IndexParam > indexes = new ArrayList <>();
613- indexes .add (IndexParam .builder ()
614- .fieldName ("float_vector" )
615- .indexType (IndexParam .IndexType .FLAT )
616- .metricType (IndexParam .MetricType .L2 )
617- .build ());
618- indexes .add (IndexParam .builder ()
619- .fieldName ("binary_vector" )
620- .indexType (IndexParam .IndexType .BIN_FLAT )
621- .metricType (IndexParam .MetricType .HAMMING )
622- .build ());
623- indexes .add (IndexParam .builder ()
624- .fieldName ("int8_vector" )
625- .indexType (IndexParam .IndexType .AUTOINDEX )
626- .metricType (IndexParam .MetricType .L2 )
627- .build ());
628- indexes .add (IndexParam .builder ()
629- .fieldName ("sparse_vector" )
630- .indexType (IndexParam .IndexType .SPARSE_WAND )
631- .metricType (IndexParam .MetricType .IP )
632- .build ());
660+ for (CreateCollectionReq .FieldSchema field : collectionSchema .getFieldSchemaList ()) {
661+ IndexParam .IndexType indexType ;
662+ IndexParam .MetricType metricType ;
663+ switch (field .getDataType ()) {
664+ case FloatVector :
665+ case Float16Vector :
666+ case BFloat16Vector :
667+ indexType = IndexParam .IndexType .IVF_FLAT ;
668+ metricType = IndexParam .MetricType .L2 ;
669+ break ;
670+ case BinaryVector :
671+ indexType = IndexParam .IndexType .BIN_FLAT ;
672+ metricType = IndexParam .MetricType .HAMMING ;
673+ break ;
674+ case Int8Vector :
675+ indexType = IndexParam .IndexType .AUTOINDEX ;
676+ metricType = IndexParam .MetricType .L2 ;
677+ break ;
678+ case SparseFloatVector :
679+ indexType = IndexParam .IndexType .SPARSE_WAND ;
680+ metricType = IndexParam .MetricType .IP ;
681+ break ;
682+ default :
683+ continue ;
684+ }
685+ indexes .add (IndexParam .builder ()
686+ .fieldName (field .getName ())
687+ .indexName (String .format ("index_%s" , field .getName ()))
688+ .indexType (indexType )
689+ .metricType (metricType )
690+ .build ());
691+ }
633692
693+ for (CreateCollectionReq .StructFieldSchema struct : collectionSchema .getStructFields ()) {
694+ for (CreateCollectionReq .FieldSchema subField : struct .getFields ()) {
695+ IndexParam .IndexType indexType ;
696+ IndexParam .MetricType metricType ;
697+ switch (subField .getDataType ()) {
698+ case FloatVector :
699+ indexType = IndexParam .IndexType .HNSW ;
700+ metricType = IndexParam .MetricType .MAX_SIM_COSINE ;
701+ break ;
702+ default :
703+ continue ;
704+ }
705+ indexes .add (IndexParam .builder ()
706+ .fieldName (String .format ("%s[%s]" , struct .getName (), subField .getName ()))
707+ .indexName (String .format ("index_%s" , subField .getName ()))
708+ .indexType (indexType )
709+ .metricType (metricType )
710+ .build ());
711+ }
712+ }
634713 milvusClient .createIndex (CreateIndexReq .builder ()
635714 .collectionName (ALL_TYPES_COLLECTION_NAME )
636715 .indexParams (indexes )
@@ -660,6 +739,7 @@ private static List<QueryResp.QueryResult> query(String expr, List<String> outpu
660739 .collectionName (ALL_TYPES_COLLECTION_NAME )
661740 .filter (expr )
662741 .outputFields (outputFields )
742+ .consistencyLevel (ConsistencyLevel .STRONG )
663743 .build ();
664744 QueryResp response = milvusClient .query (test );
665745 return response .getQueryResults ();
@@ -798,11 +878,11 @@ private static CreateCollectionReq.CollectionSchema buildAllTypesSchema() {
798878 .dataType (DataType .BinaryVector )
799879 .dimension (DIM )
800880 .build ());
801- schemaV2 .addField (AddFieldReq .builder ()
802- .fieldName ("int8_vector" )
803- .dataType (DataType .Int8Vector )
804- .dimension (DIM )
805- .build ());
881+ // schemaV2.addField(AddFieldReq.builder()
882+ // .fieldName("int8_vector")
883+ // .dataType(DataType.Int8Vector)
884+ // .dimension(DIM)
885+ // .build());
806886 schemaV2 .addField (AddFieldReq .builder ()
807887 .fieldName ("sparse_vector" )
808888 .dataType (DataType .SparseFloatVector )
@@ -860,6 +940,50 @@ private static CreateCollectionReq.CollectionSchema buildAllTypesSchema() {
860940 .elementType (DataType .Double )
861941 .isNullable (true )
862942 .build ());
943+ schemaV2 .addField (AddFieldReq .builder ()
944+ .fieldName ("struct_field" )
945+ .dataType (DataType .Array )
946+ .elementType (DataType .Struct )
947+ .maxCapacity (100 )
948+ .addStructField (AddFieldReq .builder ()
949+ .fieldName ("st_bool" )
950+ .dataType (DataType .Bool )
951+ .build ())
952+ .addStructField (AddFieldReq .builder ()
953+ .fieldName ("st_int8" )
954+ .dataType (DataType .Int8 )
955+ .build ())
956+ .addStructField (AddFieldReq .builder ()
957+ .fieldName ("st_int16" )
958+ .dataType (DataType .Int16 )
959+ .build ())
960+ .addStructField (AddFieldReq .builder ()
961+ .fieldName ("st_int32" )
962+ .dataType (DataType .Int32 )
963+ .build ())
964+ .addStructField (AddFieldReq .builder ()
965+ .fieldName ("st_int64" )
966+ .dataType (DataType .Int64 )
967+ .build ())
968+ .addStructField (AddFieldReq .builder ()
969+ .fieldName ("st_float" )
970+ .dataType (DataType .Float )
971+ .build ())
972+ .addStructField (AddFieldReq .builder ()
973+ .fieldName ("st_double" )
974+ .dataType (DataType .Double )
975+ .build ())
976+ .addStructField (AddFieldReq .builder ()
977+ .fieldName ("st_string" )
978+ .dataType (DataType .VarChar )
979+ .maxLength (100 )
980+ .build ())
981+ .addStructField (AddFieldReq .builder ()
982+ .fieldName ("st_float_vector" )
983+ .dataType (DataType .FloatVector )
984+ .dimension (DIM )
985+ .build ())
986+ .build ());
863987
864988 return schemaV2 ;
865989 }
0 commit comments