1616
1717import static com .google .common .base .Preconditions .checkNotNull ;
1818
19+ import com .google .common .annotations .VisibleForTesting ;
20+ import com .google .common .base .Defaults ;
21+ import com .google .common .collect .ImmutableList ;
22+ import com .google .common .collect .ImmutableMap ;
23+ import com .google .common .primitives .UnsignedLong ;
1924import com .google .errorprone .annotations .Immutable ;
25+ import com .google .protobuf .ByteString ;
26+ import com .google .protobuf .CodedInputStream ;
27+ import com .google .protobuf .ExtensionRegistryLite ;
2028import com .google .protobuf .MessageLite ;
29+ import com .google .protobuf .WireFormat ;
2130import dev .cel .common .annotations .Internal ;
2231import dev .cel .common .internal .CelLiteDescriptorPool ;
2332import dev .cel .common .internal .WellKnownProto ;
33+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor ;
34+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor .CelFieldValueType ;
35+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor .JavaType ;
2436import dev .cel .protobuf .CelLiteDescriptor .MessageLiteDescriptor ;
37+ import java .io .IOException ;
38+ import java .util .AbstractMap ;
39+ import java .util .ArrayList ;
40+ import java .util .Collection ;
41+ import java .util .LinkedHashMap ;
42+ import java .util .List ;
43+ import java .util .Map ;
2544
2645/**
2746 * {@code ProtoLiteCelValueConverter} handles bidirectional conversion between native Java and
@@ -43,6 +62,270 @@ public static ProtoLiteCelValueConverter newInstance(
4362 return new ProtoLiteCelValueConverter (celLiteDescriptorPool );
4463 }
4564
65+ private static Object readPrimitiveField (
66+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
67+ switch (fieldDescriptor .getProtoFieldType ()) {
68+ case SINT32 :
69+ return inputStream .readSInt32 ();
70+ case SINT64 :
71+ return inputStream .readSInt64 ();
72+ case INT32 :
73+ case ENUM :
74+ return inputStream .readInt32 ();
75+ case INT64 :
76+ return inputStream .readInt64 ();
77+ case UINT32 :
78+ return UnsignedLong .fromLongBits (inputStream .readUInt32 ());
79+ case UINT64 :
80+ return UnsignedLong .fromLongBits (inputStream .readUInt64 ());
81+ case BOOL :
82+ return inputStream .readBool ();
83+ case FLOAT :
84+ case FIXED32 :
85+ case SFIXED32 :
86+ return readFixed32BitField (inputStream , fieldDescriptor );
87+ case DOUBLE :
88+ case FIXED64 :
89+ case SFIXED64 :
90+ return readFixed64BitField (inputStream , fieldDescriptor );
91+ default :
92+ throw new IllegalStateException (
93+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
94+ }
95+ }
96+
97+ private static Object readFixed32BitField (
98+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
99+ switch (fieldDescriptor .getProtoFieldType ()) {
100+ case FLOAT :
101+ return inputStream .readFloat ();
102+ case FIXED32 :
103+ case SFIXED32 :
104+ return inputStream .readRawLittleEndian32 ();
105+ default :
106+ throw new IllegalStateException (
107+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
108+ }
109+ }
110+
111+ private static Object readFixed64BitField (
112+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
113+ switch (fieldDescriptor .getProtoFieldType ()) {
114+ case DOUBLE :
115+ return inputStream .readDouble ();
116+ case FIXED64 :
117+ case SFIXED64 :
118+ return inputStream .readRawLittleEndian64 ();
119+ default :
120+ throw new IllegalStateException (
121+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
122+ }
123+ }
124+
125+ private Object readLengthDelimitedField (
126+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
127+ FieldLiteDescriptor .Type fieldType = fieldDescriptor .getProtoFieldType ();
128+
129+ switch (fieldType ) {
130+ case BYTES :
131+ return inputStream .readBytes ();
132+ case MESSAGE :
133+ MessageLite .Builder builder =
134+ getDefaultMessageBuilder (fieldDescriptor .getFieldProtoTypeName ());
135+
136+ inputStream .readMessage (builder , ExtensionRegistryLite .getEmptyRegistry ());
137+ return builder .build ();
138+ case STRING :
139+ return inputStream .readStringRequireUtf8 ();
140+ default :
141+ throw new IllegalStateException ("Unexpected field type: " + fieldType );
142+ }
143+ }
144+
145+ private MessageLite .Builder getDefaultMessageBuilder (String protoTypeName ) {
146+ return descriptorPool .getDescriptorOrThrow (protoTypeName ).newMessageBuilder ();
147+ }
148+
149+ CelValue getDefaultCelValue (String protoTypeName , String fieldName ) {
150+ MessageLiteDescriptor messageDescriptor = descriptorPool .getDescriptorOrThrow (protoTypeName );
151+ FieldLiteDescriptor fieldDescriptor = messageDescriptor .getByFieldNameOrThrow (fieldName );
152+
153+ Object defaultValue = getDefaultValue (fieldDescriptor );
154+ if (defaultValue instanceof MessageLite ) {
155+ return fromProtoMessageToCelValue (
156+ fieldDescriptor .getFieldProtoTypeName (), (MessageLite ) defaultValue );
157+ } else {
158+ return fromJavaObjectToCelValue (getDefaultValue (fieldDescriptor ));
159+ }
160+ }
161+
162+ private Object getDefaultValue (FieldLiteDescriptor fieldDescriptor ) {
163+ FieldLiteDescriptor .CelFieldValueType celFieldValueType =
164+ fieldDescriptor .getCelFieldValueType ();
165+ switch (celFieldValueType ) {
166+ case LIST :
167+ return ImmutableList .of ();
168+ case MAP :
169+ return ImmutableMap .of ();
170+ case SCALAR :
171+ return getScalarDefaultValue (fieldDescriptor );
172+ }
173+ throw new IllegalStateException ("Unexpected cel field value type: " + celFieldValueType );
174+ }
175+
176+ private Object getScalarDefaultValue (FieldLiteDescriptor fieldDescriptor ) {
177+ JavaType type = fieldDescriptor .getJavaType ();
178+ switch (type ) {
179+ case INT :
180+ return fieldDescriptor .getProtoFieldType ().equals (FieldLiteDescriptor .Type .UINT32 )
181+ ? UnsignedLong .ZERO
182+ : Defaults .defaultValue (long .class );
183+ case LONG :
184+ return fieldDescriptor .getProtoFieldType ().equals (FieldLiteDescriptor .Type .UINT64 )
185+ ? UnsignedLong .ZERO
186+ : Defaults .defaultValue (long .class );
187+ case ENUM :
188+ return Defaults .defaultValue (long .class );
189+ case FLOAT :
190+ return Defaults .defaultValue (float .class );
191+ case DOUBLE :
192+ return Defaults .defaultValue (double .class );
193+ case BOOLEAN :
194+ return Defaults .defaultValue (boolean .class );
195+ case STRING :
196+ return "" ;
197+ case BYTE_STRING :
198+ return ByteString .EMPTY ;
199+ case MESSAGE :
200+ if (WellKnownProto .isWrapperType (fieldDescriptor .getFieldProtoTypeName ())) {
201+ return NullValue .NULL_VALUE ;
202+ }
203+
204+ return getDefaultMessageBuilder (fieldDescriptor .getFieldProtoTypeName ()).build ();
205+ }
206+ throw new IllegalStateException ("Unexpected java type: " + type );
207+ }
208+
209+ private ImmutableList <Object > readPackedRepeatedFields (
210+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
211+ int length = inputStream .readInt32 ();
212+ int oldLimit = inputStream .pushLimit (length );
213+ ImmutableList .Builder <Object > builder = ImmutableList .builder ();
214+ while (inputStream .getBytesUntilLimit () > 0 ) {
215+ builder .add (readPrimitiveField (inputStream , fieldDescriptor ));
216+ }
217+ inputStream .popLimit (oldLimit );
218+ return builder .build ();
219+ }
220+
221+ private Map .Entry <Object , Object > readSingleMapEntry (
222+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
223+ ImmutableMap <String , Object > singleMapEntry =
224+ readAllFields (inputStream .readByteArray (), fieldDescriptor .getFieldProtoTypeName ());
225+ Object key = checkNotNull (singleMapEntry .get ("key" ));
226+ Object value = checkNotNull (singleMapEntry .get ("value" ));
227+
228+ return new AbstractMap .SimpleEntry <>(key , value );
229+ }
230+
231+ @ VisibleForTesting
232+ ImmutableMap <String , Object > readAllFields (byte [] bytes , String protoTypeName )
233+ throws IOException {
234+ // TODO: Handle unknown fields by collecting them into a separate map.
235+ MessageLiteDescriptor messageDescriptor = descriptorPool .getDescriptorOrThrow (protoTypeName );
236+ CodedInputStream inputStream = CodedInputStream .newInstance (bytes );
237+
238+ ImmutableMap .Builder <String , Object > fieldValues = ImmutableMap .builder ();
239+ Map <Integer , List <Object >> repeatedFieldValues = new LinkedHashMap <>();
240+ Map <Integer , Map <Object , Object >> mapFieldValues = new LinkedHashMap <>();
241+ for (int tag = inputStream .readTag (); tag != 0 ; tag = inputStream .readTag ()) {
242+ int tagWireType = WireFormat .getTagWireType (tag );
243+ int fieldNumber = WireFormat .getTagFieldNumber (tag );
244+ FieldLiteDescriptor fieldDescriptor = messageDescriptor .getByFieldNumberOrThrow (fieldNumber );
245+
246+ Object payload ;
247+ switch (tagWireType ) {
248+ case WireFormat .WIRETYPE_VARINT :
249+ payload = readPrimitiveField (inputStream , fieldDescriptor );
250+ break ;
251+ case WireFormat .WIRETYPE_FIXED32 :
252+ payload = readFixed32BitField (inputStream , fieldDescriptor );
253+ break ;
254+ case WireFormat .WIRETYPE_FIXED64 :
255+ payload = readFixed64BitField (inputStream , fieldDescriptor );
256+ break ;
257+ case WireFormat .WIRETYPE_LENGTH_DELIMITED :
258+ CelFieldValueType celFieldValueType = fieldDescriptor .getCelFieldValueType ();
259+ switch (celFieldValueType ) {
260+ case LIST :
261+ if (fieldDescriptor .getIsPacked ()) {
262+ payload = readPackedRepeatedFields (inputStream , fieldDescriptor );
263+ } else {
264+ FieldLiteDescriptor .Type protoFieldType = fieldDescriptor .getProtoFieldType ();
265+ boolean isLenDelimited =
266+ protoFieldType .equals (FieldLiteDescriptor .Type .MESSAGE )
267+ || protoFieldType .equals (FieldLiteDescriptor .Type .STRING )
268+ || protoFieldType .equals (FieldLiteDescriptor .Type .BYTES );
269+ if (!isLenDelimited ) {
270+ throw new IllegalStateException (
271+ "Unexpected field type encountered for LEN-Delimited record: "
272+ + protoFieldType );
273+ }
274+
275+ payload = readLengthDelimitedField (inputStream , fieldDescriptor );
276+ }
277+ break ;
278+ case MAP :
279+ Map <Object , Object > fieldMap =
280+ mapFieldValues .computeIfAbsent (fieldNumber , (unused ) -> new LinkedHashMap <>());
281+ Map .Entry <Object , Object > mapEntry = readSingleMapEntry (inputStream , fieldDescriptor );
282+ fieldMap .put (mapEntry .getKey (), mapEntry .getValue ());
283+ payload = fieldMap ;
284+ break ;
285+ default :
286+ payload = readLengthDelimitedField (inputStream , fieldDescriptor );
287+ break ;
288+ }
289+ break ;
290+ case WireFormat .WIRETYPE_START_GROUP :
291+ case WireFormat .WIRETYPE_END_GROUP :
292+ // TODO: Support groups
293+ throw new UnsupportedOperationException ("Groups are not supported" );
294+ default :
295+ throw new IllegalArgumentException ("Unexpected wire type: " + tagWireType );
296+ }
297+
298+ if (fieldDescriptor .getCelFieldValueType ().equals (CelFieldValueType .LIST )) {
299+ String fieldName = fieldDescriptor .getFieldName ();
300+ List <Object > repeatedValues =
301+ repeatedFieldValues .computeIfAbsent (
302+ fieldNumber ,
303+ (unused ) -> {
304+ List <Object > newList = new ArrayList <>();
305+ fieldValues .put (fieldName , newList );
306+ return newList ;
307+ });
308+
309+ if (payload instanceof Collection ) {
310+ repeatedValues .addAll ((Collection <?>) payload );
311+ } else {
312+ repeatedValues .add (payload );
313+ }
314+ } else {
315+ fieldValues .put (fieldDescriptor .getFieldName (), payload );
316+ }
317+ }
318+
319+ // Protobuf encoding follows a "last one wins" semantics. This means for duplicated fields,
320+ // we accept the last value encountered.
321+ return fieldValues .buildKeepingLast ();
322+ }
323+
324+ ImmutableMap <String , Object > readAllFields (MessageLite msg , String protoTypeName )
325+ throws IOException {
326+ return readAllFields (msg .toByteArray (), protoTypeName );
327+ }
328+
46329 @ Override
47330 public CelValue fromProtoMessageToCelValue (String protoTypeName , MessageLite msg ) {
48331 checkNotNull (msg );
0 commit comments