Skip to content

Commit 49259c6

Browse files
l46kokcopybara-github
authored andcommitted
Perform field selections on lite messages by reading from the wire format
PiperOrigin-RevId: 748423924
1 parent 220312c commit 49259c6

File tree

4 files changed

+542
-2
lines changed

4 files changed

+542
-2
lines changed

common/src/main/java/dev/cel/common/values/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ java_library(
179179
"//protobuf:cel_lite_descriptor",
180180
"@maven//:com_google_errorprone_error_prone_annotations",
181181
"@maven//:com_google_guava_guava",
182+
"@maven//:com_google_protobuf_protobuf_java",
182183
"@maven//:org_jspecify_jspecify",
183184
"@maven_android//:com_google_protobuf_protobuf_javalite",
184185
],

common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,32 @@
1616

1717
import static com.google.common.base.Preconditions.checkNotNull;
1818

19+
import com.google.common.annotations.VisibleForTesting;
20+
import com.google.common.base.Defaults;
21+
import com.google.common.collect.ImmutableMap;
22+
import com.google.common.primitives.UnsignedLong;
1923
import com.google.errorprone.annotations.Immutable;
24+
import com.google.protobuf.ByteString;
25+
import com.google.protobuf.CodedInputStream;
26+
import com.google.protobuf.ExtensionRegistryLite;
2027
import com.google.protobuf.MessageLite;
28+
import com.google.protobuf.WireFormat;
2129
import dev.cel.common.annotations.Internal;
2230
import dev.cel.common.internal.CelLiteDescriptorPool;
2331
import dev.cel.common.internal.WellKnownProto;
32+
import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor;
33+
import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor.CelFieldValueType;
34+
import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor.JavaType;
2435
import dev.cel.protobuf.CelLiteDescriptor.MessageLiteDescriptor;
36+
import java.io.IOException;
37+
import java.util.AbstractMap;
38+
import java.util.ArrayList;
39+
import java.util.Collection;
40+
import java.util.Collections;
41+
import java.util.HashMap;
42+
import java.util.LinkedHashMap;
43+
import java.util.List;
44+
import java.util.Map;
2545

2646
/**
2747
* {@code ProtoLiteCelValueConverter} handles bidirectional conversion between native Java and
@@ -43,6 +63,271 @@ public static ProtoLiteCelValueConverter newInstance(
4363
return new ProtoLiteCelValueConverter(celLiteDescriptorPool);
4464
}
4565

66+
private static Object readPrimitiveField(
67+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
68+
switch (fieldDescriptor.getProtoFieldType()) {
69+
case SINT32:
70+
return inputStream.readSInt32();
71+
case SINT64:
72+
return inputStream.readSInt64();
73+
case INT32:
74+
case ENUM:
75+
return inputStream.readInt32();
76+
case INT64:
77+
return inputStream.readInt64();
78+
case UINT32:
79+
return UnsignedLong.fromLongBits(inputStream.readUInt32());
80+
case UINT64:
81+
return UnsignedLong.fromLongBits(inputStream.readUInt64());
82+
case BOOL:
83+
return inputStream.readBool();
84+
case FLOAT:
85+
case FIXED32:
86+
case SFIXED32:
87+
return readFixed32BitField(inputStream, fieldDescriptor);
88+
case DOUBLE:
89+
case FIXED64:
90+
case SFIXED64:
91+
return readFixed64BitField(inputStream, fieldDescriptor);
92+
default:
93+
throw new IllegalStateException(
94+
"Unexpected field type: " + fieldDescriptor.getProtoFieldType());
95+
}
96+
}
97+
98+
private static Object readFixed32BitField(
99+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
100+
switch (fieldDescriptor.getProtoFieldType()) {
101+
case FLOAT:
102+
return inputStream.readFloat();
103+
case FIXED32:
104+
case SFIXED32:
105+
return inputStream.readRawLittleEndian32();
106+
default:
107+
throw new IllegalStateException(
108+
"Unexpected field type: " + fieldDescriptor.getProtoFieldType());
109+
}
110+
}
111+
112+
private static Object readFixed64BitField(
113+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
114+
switch (fieldDescriptor.getProtoFieldType()) {
115+
case DOUBLE:
116+
return inputStream.readDouble();
117+
case FIXED64:
118+
case SFIXED64:
119+
return inputStream.readRawLittleEndian64();
120+
default:
121+
throw new IllegalStateException(
122+
"Unexpected field type: " + fieldDescriptor.getProtoFieldType());
123+
}
124+
}
125+
126+
private Object readLengthDelimitedField(
127+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
128+
FieldLiteDescriptor.Type fieldType = fieldDescriptor.getProtoFieldType();
129+
130+
switch (fieldType) {
131+
case BYTES:
132+
return inputStream.readBytes();
133+
case MESSAGE:
134+
MessageLite.Builder builder =
135+
getDefaultMessageBuilder(fieldDescriptor.getFieldProtoTypeName());
136+
137+
inputStream.readMessage(builder, ExtensionRegistryLite.getEmptyRegistry());
138+
return builder.build();
139+
case STRING:
140+
return inputStream.readStringRequireUtf8();
141+
default:
142+
throw new IllegalStateException("Unexpected field type: " + fieldType);
143+
}
144+
}
145+
146+
private MessageLite.Builder getDefaultMessageBuilder(String protoTypeName) {
147+
return descriptorPool.getDescriptorOrThrow(protoTypeName).newMessageBuilder();
148+
}
149+
150+
CelValue getDefaultCelValue(String protoTypeName, String fieldName) {
151+
MessageLiteDescriptor messageDescriptor = descriptorPool.getDescriptorOrThrow(protoTypeName);
152+
FieldLiteDescriptor fieldDescriptor = messageDescriptor.getByFieldNameOrThrow(fieldName);
153+
154+
Object defaultValue = getDefaultValue(fieldDescriptor);
155+
if (defaultValue instanceof MessageLite) {
156+
return fromProtoMessageToCelValue(
157+
fieldDescriptor.getFieldProtoTypeName(), (MessageLite) defaultValue);
158+
} else {
159+
return fromJavaObjectToCelValue(getDefaultValue(fieldDescriptor));
160+
}
161+
}
162+
163+
private Object getDefaultValue(FieldLiteDescriptor fieldDescriptor) {
164+
FieldLiteDescriptor.CelFieldValueType celFieldValueType =
165+
fieldDescriptor.getCelFieldValueType();
166+
switch (celFieldValueType) {
167+
case LIST:
168+
return Collections.unmodifiableList(new ArrayList<>());
169+
case MAP:
170+
return Collections.unmodifiableMap(new HashMap<>());
171+
case SCALAR:
172+
return getScalarDefaultValue(fieldDescriptor);
173+
}
174+
throw new IllegalStateException("Unexpected cel field value type: " + celFieldValueType);
175+
}
176+
177+
private Object getScalarDefaultValue(FieldLiteDescriptor fieldDescriptor) {
178+
JavaType type = fieldDescriptor.getJavaType();
179+
switch (type) {
180+
case INT:
181+
return fieldDescriptor.getProtoFieldType().equals(FieldLiteDescriptor.Type.UINT32)
182+
? UnsignedLong.ZERO
183+
: Defaults.defaultValue(long.class);
184+
case LONG:
185+
return fieldDescriptor.getProtoFieldType().equals(FieldLiteDescriptor.Type.UINT64)
186+
? UnsignedLong.ZERO
187+
: Defaults.defaultValue(long.class);
188+
case ENUM:
189+
return Defaults.defaultValue(long.class);
190+
case FLOAT:
191+
return Defaults.defaultValue(float.class);
192+
case DOUBLE:
193+
return Defaults.defaultValue(double.class);
194+
case BOOLEAN:
195+
return Defaults.defaultValue(boolean.class);
196+
case STRING:
197+
return "";
198+
case BYTE_STRING:
199+
return ByteString.EMPTY;
200+
case MESSAGE:
201+
if (WellKnownProto.isWrapperType(fieldDescriptor.getFieldProtoTypeName())) {
202+
return NullValue.NULL_VALUE;
203+
}
204+
205+
return getDefaultMessageBuilder(fieldDescriptor.getFieldProtoTypeName()).build();
206+
}
207+
throw new IllegalStateException("Unexpected java type: " + type);
208+
}
209+
210+
private List<Object> readPackedRepeatedFields(
211+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
212+
int length = inputStream.readInt32();
213+
int oldLimit = inputStream.pushLimit(length);
214+
List<Object> repeatedFieldValues = new ArrayList<>();
215+
while (inputStream.getBytesUntilLimit() > 0) {
216+
Object value = readPrimitiveField(inputStream, fieldDescriptor);
217+
repeatedFieldValues.add(value);
218+
}
219+
inputStream.popLimit(oldLimit);
220+
return Collections.unmodifiableList(repeatedFieldValues);
221+
}
222+
223+
private Map.Entry<Object, Object> readSingleMapEntry(
224+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
225+
ImmutableMap<String, Object> singleMapEntry =
226+
readAllFields(inputStream.readByteArray(), fieldDescriptor.getFieldProtoTypeName());
227+
Object key = checkNotNull(singleMapEntry.get("key"));
228+
Object value = checkNotNull(singleMapEntry.get("value"));
229+
230+
return new AbstractMap.SimpleEntry<>(key, value);
231+
}
232+
233+
@VisibleForTesting
234+
ImmutableMap<String, Object> readAllFields(byte[] bytes, String protoTypeName)
235+
throws IOException {
236+
// TODO: Handle unknown fields by collecting them into a separate map.
237+
MessageLiteDescriptor messageDescriptor = descriptorPool.getDescriptorOrThrow(protoTypeName);
238+
CodedInputStream inputStream = CodedInputStream.newInstance(bytes);
239+
240+
ImmutableMap.Builder<String, Object> fieldValues = ImmutableMap.builder();
241+
Map<Integer, List<Object>> repeatedFieldValues = new LinkedHashMap<>();
242+
Map<Integer, Map<Object, Object>> mapFieldValues = new LinkedHashMap<>();
243+
for (int tag = inputStream.readTag(); tag != 0; tag = inputStream.readTag()) {
244+
int tagWireType = WireFormat.getTagWireType(tag);
245+
int fieldNumber = WireFormat.getTagFieldNumber(tag);
246+
FieldLiteDescriptor fieldDescriptor = messageDescriptor.getByFieldNumberOrThrow(fieldNumber);
247+
248+
Object payload;
249+
switch (tagWireType) {
250+
case WireFormat.WIRETYPE_VARINT:
251+
payload = readPrimitiveField(inputStream, fieldDescriptor);
252+
break;
253+
case WireFormat.WIRETYPE_FIXED32:
254+
payload = readFixed32BitField(inputStream, fieldDescriptor);
255+
break;
256+
case WireFormat.WIRETYPE_FIXED64:
257+
payload = readFixed64BitField(inputStream, fieldDescriptor);
258+
break;
259+
case WireFormat.WIRETYPE_LENGTH_DELIMITED:
260+
CelFieldValueType celFieldValueType = fieldDescriptor.getCelFieldValueType();
261+
switch (celFieldValueType) {
262+
case LIST:
263+
if (fieldDescriptor.getIsPacked()) {
264+
payload = readPackedRepeatedFields(inputStream, fieldDescriptor);
265+
} else {
266+
FieldLiteDescriptor.Type protoFieldType = fieldDescriptor.getProtoFieldType();
267+
boolean isLenDelimited =
268+
protoFieldType.equals(FieldLiteDescriptor.Type.MESSAGE)
269+
|| protoFieldType.equals(FieldLiteDescriptor.Type.STRING)
270+
|| protoFieldType.equals(FieldLiteDescriptor.Type.BYTES);
271+
if (!isLenDelimited) {
272+
throw new IllegalStateException(
273+
"Unexpected field type encountered for LEN-Delimited record: "
274+
+ protoFieldType);
275+
}
276+
277+
payload = readLengthDelimitedField(inputStream, fieldDescriptor);
278+
}
279+
break;
280+
case MAP:
281+
Map<Object, Object> fieldMap =
282+
mapFieldValues.computeIfAbsent(fieldNumber, (unused) -> new LinkedHashMap<>());
283+
Map.Entry<Object, Object> mapEntry = readSingleMapEntry(inputStream, fieldDescriptor);
284+
fieldMap.put(mapEntry.getKey(), mapEntry.getValue());
285+
payload = fieldMap;
286+
break;
287+
default:
288+
payload = readLengthDelimitedField(inputStream, fieldDescriptor);
289+
break;
290+
}
291+
break;
292+
case WireFormat.WIRETYPE_START_GROUP:
293+
case WireFormat.WIRETYPE_END_GROUP:
294+
// TODO: Support groups
295+
throw new UnsupportedOperationException("Groups are not supported");
296+
default:
297+
throw new IllegalArgumentException("Unexpected wire type: " + tagWireType);
298+
}
299+
300+
if (fieldDescriptor.getCelFieldValueType().equals(CelFieldValueType.LIST)) {
301+
String fieldName = fieldDescriptor.getFieldName();
302+
List<Object> repeatedValues =
303+
repeatedFieldValues.computeIfAbsent(
304+
fieldNumber,
305+
(unused) -> {
306+
List<Object> newList = new ArrayList<>();
307+
fieldValues.put(fieldName, newList);
308+
return newList;
309+
});
310+
311+
if (payload instanceof Collection) {
312+
repeatedValues.addAll((Collection<?>) payload);
313+
} else {
314+
repeatedValues.add(payload);
315+
}
316+
} else {
317+
fieldValues.put(fieldDescriptor.getFieldName(), payload);
318+
}
319+
}
320+
321+
// Protobuf encoding follows a "last one wins" semantics. This means for duplicated fields,
322+
// we accept the last value encountered.
323+
return fieldValues.buildKeepingLast();
324+
}
325+
326+
ImmutableMap<String, Object> readAllFields(MessageLite msg, String protoTypeName)
327+
throws IOException {
328+
return readAllFields(msg.toByteArray(), protoTypeName);
329+
}
330+
46331
@Override
47332
public CelValue fromProtoMessageToCelValue(String protoTypeName, MessageLite msg) {
48333
checkNotNull(msg);

common/src/main/java/dev/cel/common/values/ProtoMessageLiteValue.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
package dev.cel.common.values;
1616

1717
import com.google.auto.value.AutoValue;
18+
import com.google.auto.value.extension.memoized.Memoized;
1819
import com.google.common.base.Preconditions;
20+
import com.google.common.collect.ImmutableMap;
1921
import com.google.errorprone.annotations.Immutable;
2022
import com.google.protobuf.MessageLite;
2123
import dev.cel.common.types.CelType;
2224
import dev.cel.common.types.StructTypeReference;
25+
import java.io.IOException;
2326
import java.util.Optional;
2427

2528
/**
@@ -42,19 +45,32 @@ public abstract class ProtoMessageLiteValue extends StructValue<StringValue> {
4245

4346
abstract ProtoLiteCelValueConverter protoLiteCelValueConverter();
4447

48+
@Memoized
49+
ImmutableMap<String, Object> fieldValues() {
50+
try {
51+
return protoLiteCelValueConverter().readAllFields(value(), celType().name());
52+
} catch (IOException e) {
53+
throw new IllegalStateException("Unable to read message fields for " + celType().name(), e);
54+
}
55+
}
56+
4557
@Override
4658
public boolean isZeroValue() {
4759
return value().getDefaultInstanceForType().equals(value());
4860
}
4961

5062
@Override
5163
public CelValue select(StringValue field) {
52-
throw new UnsupportedOperationException("Not implemented yet");
64+
return find(field)
65+
.orElseGet(
66+
() -> protoLiteCelValueConverter().getDefaultCelValue(celType().name(), field.value()));
5367
}
5468

5569
@Override
5670
public Optional<CelValue> find(StringValue field) {
57-
throw new UnsupportedOperationException("Not implemented yet");
71+
Object fieldValue = fieldValues().get(field.value());
72+
return Optional.ofNullable(fieldValue)
73+
.map(value -> protoLiteCelValueConverter().fromJavaObjectToCelValue(fieldValue));
5874
}
5975

6076
public static ProtoMessageLiteValue create(

0 commit comments

Comments
 (0)