1616
1717import static com .google .common .base .Preconditions .checkNotNull ;
1818
19+ import com .google .common .annotations .VisibleForTesting ;
20+ import com .google .common .base .Defaults ;
21+ import com .google .common .collect .ImmutableMap ;
22+ import com .google .common .primitives .UnsignedLong ;
1923import com .google .errorprone .annotations .Immutable ;
24+ import com .google .protobuf .ByteString ;
25+ import com .google .protobuf .CodedInputStream ;
26+ import com .google .protobuf .ExtensionRegistryLite ;
2027import com .google .protobuf .MessageLite ;
28+ import com .google .protobuf .WireFormat ;
2129import dev .cel .common .annotations .Internal ;
2230import dev .cel .common .internal .CelLiteDescriptorPool ;
2331import dev .cel .common .internal .WellKnownProto ;
32+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor ;
33+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor .CelFieldValueType ;
34+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor .JavaType ;
2435import dev .cel .protobuf .CelLiteDescriptor .MessageLiteDescriptor ;
36+ import java .io .IOException ;
37+ import java .util .AbstractMap ;
38+ import java .util .ArrayList ;
39+ import java .util .Collection ;
40+ import java .util .Collections ;
41+ import java .util .HashMap ;
42+ import java .util .LinkedHashMap ;
43+ import java .util .List ;
44+ import java .util .Map ;
2545
2646/**
2747 * {@code ProtoLiteCelValueConverter} handles bidirectional conversion between native Java and
@@ -43,6 +63,271 @@ public static ProtoLiteCelValueConverter newInstance(
4363 return new ProtoLiteCelValueConverter (celLiteDescriptorPool );
4464 }
4565
66+ private static Object readPrimitiveField (
67+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
68+ switch (fieldDescriptor .getProtoFieldType ()) {
69+ case SINT32 :
70+ return inputStream .readSInt32 ();
71+ case SINT64 :
72+ return inputStream .readSInt64 ();
73+ case INT32 :
74+ case ENUM :
75+ return inputStream .readInt32 ();
76+ case INT64 :
77+ return inputStream .readInt64 ();
78+ case UINT32 :
79+ return UnsignedLong .fromLongBits (inputStream .readUInt32 ());
80+ case UINT64 :
81+ return UnsignedLong .fromLongBits (inputStream .readUInt64 ());
82+ case BOOL :
83+ return inputStream .readBool ();
84+ case FLOAT :
85+ case FIXED32 :
86+ case SFIXED32 :
87+ return readFixed32BitField (inputStream , fieldDescriptor );
88+ case DOUBLE :
89+ case FIXED64 :
90+ case SFIXED64 :
91+ return readFixed64BitField (inputStream , fieldDescriptor );
92+ default :
93+ throw new IllegalStateException (
94+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
95+ }
96+ }
97+
98+ private static Object readFixed32BitField (
99+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
100+ switch (fieldDescriptor .getProtoFieldType ()) {
101+ case FLOAT :
102+ return inputStream .readFloat ();
103+ case FIXED32 :
104+ case SFIXED32 :
105+ return inputStream .readRawLittleEndian32 ();
106+ default :
107+ throw new IllegalStateException (
108+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
109+ }
110+ }
111+
112+ private static Object readFixed64BitField (
113+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
114+ switch (fieldDescriptor .getProtoFieldType ()) {
115+ case DOUBLE :
116+ return inputStream .readDouble ();
117+ case FIXED64 :
118+ case SFIXED64 :
119+ return inputStream .readRawLittleEndian64 ();
120+ default :
121+ throw new IllegalStateException (
122+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
123+ }
124+ }
125+
126+ private Object readLengthDelimitedField (
127+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
128+ FieldLiteDescriptor .Type fieldType = fieldDescriptor .getProtoFieldType ();
129+
130+ switch (fieldType ) {
131+ case BYTES :
132+ return inputStream .readBytes ();
133+ case MESSAGE :
134+ MessageLite .Builder builder =
135+ getDefaultMessageBuilder (fieldDescriptor .getFieldProtoTypeName ());
136+
137+ inputStream .readMessage (builder , ExtensionRegistryLite .getEmptyRegistry ());
138+ return builder .build ();
139+ case STRING :
140+ return inputStream .readStringRequireUtf8 ();
141+ default :
142+ throw new IllegalStateException ("Unexpected field type: " + fieldType );
143+ }
144+ }
145+
146+ private MessageLite .Builder getDefaultMessageBuilder (String protoTypeName ) {
147+ return descriptorPool .getDescriptorOrThrow (protoTypeName ).newMessageBuilder ();
148+ }
149+
150+ CelValue getDefaultCelValue (String protoTypeName , String fieldName ) {
151+ MessageLiteDescriptor messageDescriptor = descriptorPool .getDescriptorOrThrow (protoTypeName );
152+ FieldLiteDescriptor fieldDescriptor = messageDescriptor .getByFieldNameOrThrow (fieldName );
153+
154+ Object defaultValue = getDefaultValue (fieldDescriptor );
155+ if (defaultValue instanceof MessageLite ) {
156+ return fromProtoMessageToCelValue (
157+ fieldDescriptor .getFieldProtoTypeName (), (MessageLite ) defaultValue );
158+ } else {
159+ return fromJavaObjectToCelValue (getDefaultValue (fieldDescriptor ));
160+ }
161+ }
162+
163+ private Object getDefaultValue (FieldLiteDescriptor fieldDescriptor ) {
164+ FieldLiteDescriptor .CelFieldValueType celFieldValueType =
165+ fieldDescriptor .getCelFieldValueType ();
166+ switch (celFieldValueType ) {
167+ case LIST :
168+ return Collections .unmodifiableList (new ArrayList <>());
169+ case MAP :
170+ return Collections .unmodifiableMap (new HashMap <>());
171+ case SCALAR :
172+ return getScalarDefaultValue (fieldDescriptor );
173+ }
174+ throw new IllegalStateException ("Unexpected cel field value type: " + celFieldValueType );
175+ }
176+
177+ private Object getScalarDefaultValue (FieldLiteDescriptor fieldDescriptor ) {
178+ JavaType type = fieldDescriptor .getJavaType ();
179+ switch (type ) {
180+ case INT :
181+ return fieldDescriptor .getProtoFieldType ().equals (FieldLiteDescriptor .Type .UINT32 )
182+ ? UnsignedLong .ZERO
183+ : Defaults .defaultValue (long .class );
184+ case LONG :
185+ return fieldDescriptor .getProtoFieldType ().equals (FieldLiteDescriptor .Type .UINT64 )
186+ ? UnsignedLong .ZERO
187+ : Defaults .defaultValue (long .class );
188+ case ENUM :
189+ return Defaults .defaultValue (long .class );
190+ case FLOAT :
191+ return Defaults .defaultValue (float .class );
192+ case DOUBLE :
193+ return Defaults .defaultValue (double .class );
194+ case BOOLEAN :
195+ return Defaults .defaultValue (boolean .class );
196+ case STRING :
197+ return "" ;
198+ case BYTE_STRING :
199+ return ByteString .EMPTY ;
200+ case MESSAGE :
201+ if (WellKnownProto .isWrapperType (fieldDescriptor .getFieldProtoTypeName ())) {
202+ return NullValue .NULL_VALUE ;
203+ }
204+
205+ return getDefaultMessageBuilder (fieldDescriptor .getFieldProtoTypeName ()).build ();
206+ }
207+ throw new IllegalStateException ("Unexpected java type: " + type );
208+ }
209+
210+ private List <Object > readPackedRepeatedFields (
211+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
212+ int length = inputStream .readInt32 ();
213+ int oldLimit = inputStream .pushLimit (length );
214+ List <Object > repeatedFieldValues = new ArrayList <>();
215+ while (inputStream .getBytesUntilLimit () > 0 ) {
216+ Object value = readPrimitiveField (inputStream , fieldDescriptor );
217+ repeatedFieldValues .add (value );
218+ }
219+ inputStream .popLimit (oldLimit );
220+ return Collections .unmodifiableList (repeatedFieldValues );
221+ }
222+
223+ private Map .Entry <Object , Object > readSingleMapEntry (
224+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
225+ ImmutableMap <String , Object > singleMapEntry =
226+ readAllFields (inputStream .readByteArray (), fieldDescriptor .getFieldProtoTypeName ());
227+ Object key = checkNotNull (singleMapEntry .get ("key" ));
228+ Object value = checkNotNull (singleMapEntry .get ("value" ));
229+
230+ return new AbstractMap .SimpleEntry <>(key , value );
231+ }
232+
233+ @ VisibleForTesting
234+ ImmutableMap <String , Object > readAllFields (byte [] bytes , String protoTypeName )
235+ throws IOException {
236+ // TODO: Handle unknown fields by collecting them into a separate map.
237+ MessageLiteDescriptor messageDescriptor = descriptorPool .getDescriptorOrThrow (protoTypeName );
238+ CodedInputStream inputStream = CodedInputStream .newInstance (bytes );
239+
240+ ImmutableMap .Builder <String , Object > fieldValues = ImmutableMap .builder ();
241+ Map <Integer , List <Object >> repeatedFieldValues = new LinkedHashMap <>();
242+ Map <Integer , Map <Object , Object >> mapFieldValues = new LinkedHashMap <>();
243+ for (int tag = inputStream .readTag (); tag != 0 ; tag = inputStream .readTag ()) {
244+ int tagWireType = WireFormat .getTagWireType (tag );
245+ int fieldNumber = WireFormat .getTagFieldNumber (tag );
246+ FieldLiteDescriptor fieldDescriptor = messageDescriptor .getByFieldNumberOrThrow (fieldNumber );
247+
248+ Object payload ;
249+ switch (tagWireType ) {
250+ case WireFormat .WIRETYPE_VARINT :
251+ payload = readPrimitiveField (inputStream , fieldDescriptor );
252+ break ;
253+ case WireFormat .WIRETYPE_FIXED32 :
254+ payload = readFixed32BitField (inputStream , fieldDescriptor );
255+ break ;
256+ case WireFormat .WIRETYPE_FIXED64 :
257+ payload = readFixed64BitField (inputStream , fieldDescriptor );
258+ break ;
259+ case WireFormat .WIRETYPE_LENGTH_DELIMITED :
260+ CelFieldValueType celFieldValueType = fieldDescriptor .getCelFieldValueType ();
261+ switch (celFieldValueType ) {
262+ case LIST :
263+ if (fieldDescriptor .getIsPacked ()) {
264+ payload = readPackedRepeatedFields (inputStream , fieldDescriptor );
265+ } else {
266+ FieldLiteDescriptor .Type protoFieldType = fieldDescriptor .getProtoFieldType ();
267+ boolean isLenDelimited =
268+ protoFieldType .equals (FieldLiteDescriptor .Type .MESSAGE )
269+ || protoFieldType .equals (FieldLiteDescriptor .Type .STRING )
270+ || protoFieldType .equals (FieldLiteDescriptor .Type .BYTES );
271+ if (!isLenDelimited ) {
272+ throw new IllegalStateException (
273+ "Unexpected field type encountered for LEN-Delimited record: "
274+ + protoFieldType );
275+ }
276+
277+ payload = readLengthDelimitedField (inputStream , fieldDescriptor );
278+ }
279+ break ;
280+ case MAP :
281+ Map <Object , Object > fieldMap =
282+ mapFieldValues .computeIfAbsent (fieldNumber , (unused ) -> new LinkedHashMap <>());
283+ Map .Entry <Object , Object > mapEntry = readSingleMapEntry (inputStream , fieldDescriptor );
284+ fieldMap .put (mapEntry .getKey (), mapEntry .getValue ());
285+ payload = fieldMap ;
286+ break ;
287+ default :
288+ payload = readLengthDelimitedField (inputStream , fieldDescriptor );
289+ break ;
290+ }
291+ break ;
292+ case WireFormat .WIRETYPE_START_GROUP :
293+ case WireFormat .WIRETYPE_END_GROUP :
294+ // TODO: Support groups
295+ throw new UnsupportedOperationException ("Groups are not supported" );
296+ default :
297+ throw new IllegalArgumentException ("Unexpected wire type: " + tagWireType );
298+ }
299+
300+ if (fieldDescriptor .getCelFieldValueType ().equals (CelFieldValueType .LIST )) {
301+ String fieldName = fieldDescriptor .getFieldName ();
302+ List <Object > repeatedValues =
303+ repeatedFieldValues .computeIfAbsent (
304+ fieldNumber ,
305+ (unused ) -> {
306+ List <Object > newList = new ArrayList <>();
307+ fieldValues .put (fieldName , newList );
308+ return newList ;
309+ });
310+
311+ if (payload instanceof Collection ) {
312+ repeatedValues .addAll ((Collection <?>) payload );
313+ } else {
314+ repeatedValues .add (payload );
315+ }
316+ } else {
317+ fieldValues .put (fieldDescriptor .getFieldName (), payload );
318+ }
319+ }
320+
321+ // Protobuf encoding follows a "last one wins" semantics. This means for duplicated fields,
322+ // we accept the last value encountered.
323+ return fieldValues .buildKeepingLast ();
324+ }
325+
326+ ImmutableMap <String , Object > readAllFields (MessageLite msg , String protoTypeName )
327+ throws IOException {
328+ return readAllFields (msg .toByteArray (), protoTypeName );
329+ }
330+
46331 @ Override
47332 public CelValue fromProtoMessageToCelValue (String protoTypeName , MessageLite msg ) {
48333 checkNotNull (msg );
0 commit comments