Skip to content

Commit dc9e707

Browse files
committed
BulkWriter supports Struct
Signed-off-by: yhmo <[email protected]>
1 parent 45ae45d commit dc9e707

File tree

4 files changed

+290
-110
lines changed

4 files changed

+290
-110
lines changed

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/BulkWriter.java

Lines changed: 84 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@
2020
package io.milvus.bulkwriter;
2121

2222
import com.google.common.collect.Lists;
23-
import com.google.gson.JsonElement;
24-
import com.google.gson.JsonNull;
25-
import com.google.gson.JsonObject;
26-
import com.google.gson.JsonPrimitive;
23+
import com.google.gson.*;
2724
import io.milvus.bulkwriter.common.clientenum.BulkFileType;
2825
import io.milvus.bulkwriter.common.clientenum.TypeSize;
2926
import io.milvus.bulkwriter.common.utils.V2AdapterUtils;
@@ -281,54 +278,54 @@ protected Map<String, Object> verifyRow(JsonObject row) {
281278
}
282279
}
283280

284-
DataType dataType = field.getDataType();
285-
switch (dataType) {
286-
case BinaryVector:
287-
case FloatVector:
288-
case Float16Vector:
289-
case BFloat16Vector:
290-
case SparseFloatVector:
291-
case Int8Vector: {
292-
Pair<Object, Integer> objectAndSize = verifyVector(obj, field);
293-
rowValues.put(fieldName, objectAndSize.getLeft());
294-
rowSize += objectAndSize.getRight();
295-
break;
296-
}
297-
case VarChar:
298-
case Geometry:
299-
case Timestamptz: {
300-
Pair<Object, Integer> objectAndSize = verifyVarchar(obj, field);
301-
rowValues.put(fieldName, objectAndSize.getLeft());
302-
rowSize += objectAndSize.getRight();
303-
break;
304-
}
305-
case JSON: {
306-
Pair<Object, Integer> objectAndSize = verifyJSON(obj, field);
307-
rowValues.put(fieldName, objectAndSize.getLeft());
308-
rowSize += objectAndSize.getRight();
309-
break;
281+
Pair<Object, Integer> objectAndSize = verifyByDatatype(field, obj);
282+
if (objectAndSize != null) {
283+
rowValues.put(fieldName, objectAndSize.getLeft());
284+
rowSize += objectAndSize.getRight();
285+
}
286+
}
287+
288+
for (CreateCollectionReq.StructFieldSchema struct : collectionSchema.getStructFields()) {
289+
String structName = struct.getName();
290+
JsonArray structList = row.getAsJsonArray(structName);
291+
if (structList == null) {
292+
String msg = String.format("Value of struct field '%s' is not provided.", structName);
293+
ExceptionUtils.throwUnExpectedException(msg);
294+
}
295+
296+
List<Map<String, Object>> validList = new ArrayList<>();
297+
for (JsonElement st : structList.asList()) {
298+
if (!st.isJsonObject()) {
299+
String msg = String.format("Element of struct field '%s' must be JSON dict.", structName);
300+
ExceptionUtils.throwUnExpectedException(msg);
310301
}
311-
case Array: {
312-
Pair<Object, Integer> objectAndSize = verifyArray(obj, field);
313-
rowValues.put(fieldName, objectAndSize.getLeft());
314-
rowSize += objectAndSize.getRight();
315-
break;
302+
303+
JsonObject dict = st.getAsJsonObject();
304+
Map<String, Object> validStruct = new HashMap<>();
305+
for (CreateCollectionReq.FieldSchema subField : struct.getFields()) {
306+
String subFieldName = subField.getName();
307+
String combineName = String.format("%s[%s]", structName, subFieldName);
308+
boolean provided = dict.has(subFieldName);
309+
boolean isFunctionOutput = outputFieldNames.contains(combineName);
310+
if (provided && isFunctionOutput) {
311+
String msg = String.format("The field '%s' is function output, no need to provide", combineName);
312+
ExceptionUtils.throwUnExpectedException(msg);
313+
}
314+
if (!isFunctionOutput && !provided) {
315+
String msg = String.format("Value of field '%s' is not provided.", combineName);
316+
ExceptionUtils.throwUnExpectedException(msg);
317+
}
318+
319+
Pair<Object, Integer> objectAndSize = verifyByDatatype(subField, dict.get(subFieldName));
320+
if (objectAndSize != null) {
321+
validStruct.put(subFieldName, objectAndSize.getLeft());
322+
rowSize += objectAndSize.getRight();
323+
}
316324
}
317-
case Bool:
318-
case Int8:
319-
case Int16:
320-
case Int32:
321-
case Int64:
322-
case Float:
323-
case Double:
324-
Pair<Object, Integer> objectAndSize = verifyScalar(obj, field);
325-
rowValues.put(fieldName, objectAndSize.getLeft());
326-
rowSize += objectAndSize.getRight();
327-
break;
328-
default:
329-
String msg = String.format("Unsupported data type of field '%s', not implemented in BulkWriter.", fieldName);
330-
ExceptionUtils.throwUnExpectedException(msg);
325+
validList.add(validStruct);
331326
}
327+
328+
rowValues.put(structName, validList);
332329
}
333330

334331
// process dynamic values
@@ -361,6 +358,44 @@ protected Map<String, Object> verifyRow(JsonObject row) {
361358
return rowValues;
362359
}
363360

361+
private Pair<Object, Integer> verifyByDatatype(CreateCollectionReq.FieldSchema field, JsonElement obj) {
362+
DataType dataType = field.getDataType();
363+
String fieldName = field.getName();
364+
switch (dataType) {
365+
case BinaryVector:
366+
case FloatVector:
367+
case Float16Vector:
368+
case BFloat16Vector:
369+
case SparseFloatVector:
370+
case Int8Vector: {
371+
return verifyVector(obj, field);
372+
}
373+
case VarChar:
374+
case Geometry:
375+
case Timestamptz: {
376+
return verifyVarchar(obj, field);
377+
}
378+
case JSON: {
379+
return verifyJSON(obj, field);
380+
}
381+
case Array: {
382+
return verifyArray(obj, field);
383+
}
384+
case Bool:
385+
case Int8:
386+
case Int16:
387+
case Int32:
388+
case Int64:
389+
case Float:
390+
case Double:
391+
return verifyScalar(obj, field);
392+
default:
393+
String msg = String.format("Unsupported data type of field '%s', not implemented in BulkWriter.", fieldName);
394+
ExceptionUtils.throwUnExpectedException(msg);
395+
}
396+
return null;
397+
}
398+
364399
private Pair<Object, Integer> verifyVector(JsonElement object, CreateCollectionReq.FieldSchema field) {
365400
Object vector = DataUtils.checkFieldValue(field, object);
366401
io.milvus.v2.common.DataType dataType = field.getDataType();

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/ParquetUtils.java

Lines changed: 78 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import static io.milvus.param.Constant.DYNAMIC_FIELD_NAME;
3232

3333
public class ParquetUtils {
34-
private static void setMessageType(Types.MessageTypeBuilder builder,
34+
private static void setMessageType(Types.BaseGroupBuilder<?, ?> builder,
3535
PrimitiveType.PrimitiveTypeName primitiveName,
3636
LogicalTypeAnnotation logicType,
3737
CreateCollectionReq.FieldSchema field,
@@ -75,6 +75,29 @@ public static MessageType parseCollectionSchema(CreateCollectionReq.CollectionSc
7575
}
7676

7777
switch (field.getDataType()) {
78+
case Bool:
79+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BOOLEAN, null, field, false);
80+
break;
81+
case Int8:
82+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32,
83+
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, true), field, false);
84+
break;
85+
case Int16:
86+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32,
87+
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(16, true), field, false);
88+
break;
89+
case Int32:
90+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32, null, field, false);
91+
break;
92+
case Int64:
93+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT64, null, field, false);
94+
break;
95+
case Float:
96+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.FLOAT, null, field, false);
97+
break;
98+
case Double:
99+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.DOUBLE, null, field, false);
100+
break;
78101
case FloatVector:
79102
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.FLOAT, null, field, true);
80103
break;
@@ -89,10 +112,6 @@ public static MessageType parseCollectionSchema(CreateCollectionReq.CollectionSc
89112
case Array:
90113
fillArrayType(messageTypeBuilder, field);
91114
break;
92-
93-
case Int64:
94-
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT64, null, field, false);
95-
break;
96115
case VarChar:
97116
case Geometry:
98117
case Timestamptz:
@@ -101,30 +120,14 @@ public static MessageType parseCollectionSchema(CreateCollectionReq.CollectionSc
101120
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BINARY,
102121
LogicalTypeAnnotation.stringType(), field, false);
103122
break;
104-
case Int8:
105-
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32,
106-
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, true), field, false);
107-
break;
108-
case Int16:
109-
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32,
110-
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(16, true), field, false);
111-
break;
112-
case Int32:
113-
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32, null, field, false);
114-
break;
115-
case Float:
116-
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.FLOAT, null, field, false);
117-
break;
118-
case Double:
119-
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.DOUBLE, null, field, false);
120-
break;
121-
case Bool:
122-
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BOOLEAN, null, field, false);
123-
break;
124-
125123
}
126124
}
127125

126+
List<CreateCollectionReq.StructFieldSchema> structFields = collectionSchema.getStructFields();
127+
for (CreateCollectionReq.StructFieldSchema struct : structFields) {
128+
fillStructType(messageTypeBuilder, struct);
129+
}
130+
128131
if (collectionSchema.isEnableDynamicField()) {
129132
messageTypeBuilder.optional(PrimitiveType.PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType())
130133
.named(DYNAMIC_FIELD_NAME);
@@ -134,12 +137,8 @@ public static MessageType parseCollectionSchema(CreateCollectionReq.CollectionSc
134137

135138
private static void fillArrayType(Types.MessageTypeBuilder messageTypeBuilder, CreateCollectionReq.FieldSchema field) {
136139
switch (field.getElementType()) {
137-
case Int64:
138-
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT64, null, field, true);
139-
break;
140-
case VarChar:
141-
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BINARY,
142-
LogicalTypeAnnotation.stringType(), field, true);
140+
case Bool:
141+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BOOLEAN, null, field, true);
143142
break;
144143
case Int8:
145144
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32,
@@ -152,18 +151,63 @@ private static void fillArrayType(Types.MessageTypeBuilder messageTypeBuilder, C
152151
case Int32:
153152
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32, null, field, true);
154153
break;
154+
case Int64:
155+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT64, null, field, true);
156+
break;
155157
case Float:
156158
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.FLOAT, null, field, true);
157159
break;
158160
case Double:
159161
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.DOUBLE, null, field, true);
160162
break;
161-
case Bool:
162-
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BOOLEAN, null, field, true);
163+
case VarChar:
164+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BINARY,
165+
LogicalTypeAnnotation.stringType(), field, true);
163166
break;
164167
}
165168
}
166169

170+
private static void fillStructType(Types.MessageTypeBuilder messageTypeBuilder, CreateCollectionReq.StructFieldSchema struct) {
171+
Types.BaseListBuilder.GroupElementBuilder<?, ?> groupBuilder = messageTypeBuilder.optionalList().optionalGroupElement();
172+
for (CreateCollectionReq.FieldSchema subField : struct.getFields()) {
173+
switch (subField.getDataType()) {
174+
case Bool:
175+
setMessageType(groupBuilder, PrimitiveType.PrimitiveTypeName.BOOLEAN, null, subField, false);
176+
break;
177+
case Int8:
178+
setMessageType(groupBuilder, PrimitiveType.PrimitiveTypeName.INT32,
179+
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, true), subField, false);
180+
break;
181+
case Int16:
182+
setMessageType(groupBuilder, PrimitiveType.PrimitiveTypeName.INT32,
183+
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(16, true), subField, false);
184+
break;
185+
case Int32:
186+
setMessageType(groupBuilder, PrimitiveType.PrimitiveTypeName.INT32, null, subField, false);
187+
break;
188+
case Int64:
189+
setMessageType(groupBuilder, PrimitiveType.PrimitiveTypeName.INT64, null, subField, false);
190+
break;
191+
case Float:
192+
setMessageType(groupBuilder, PrimitiveType.PrimitiveTypeName.FLOAT, null, subField, false);
193+
break;
194+
case Double:
195+
setMessageType(groupBuilder, PrimitiveType.PrimitiveTypeName.DOUBLE, null, subField, false);
196+
break;
197+
case VarChar:
198+
setMessageType(groupBuilder, PrimitiveType.PrimitiveTypeName.BINARY,
199+
LogicalTypeAnnotation.stringType(), subField, false);
200+
break;
201+
case FloatVector:
202+
setMessageType(groupBuilder, PrimitiveType.PrimitiveTypeName.FLOAT, null, subField, true);
203+
break;
204+
default:
205+
break;
206+
}
207+
}
208+
groupBuilder.named(struct.getName());
209+
}
210+
167211
public static Configuration getParquetConfiguration() {
168212
// set fs.file.impl.disable.cache to true for this issue: https://github.com/milvus-io/milvus-sdk-java/issues/1381
169213
Configuration configuration = new Configuration();

0 commit comments

Comments
 (0)