Skip to content

Commit c070283

Browse files
l46kokcopybara-github
authored andcommitted
Fix unpacking any messages containing extension fields
PiperOrigin-RevId: 572685714
1 parent 3c3aa79 commit c070283

File tree

6 files changed

+84
-6
lines changed

6 files changed

+84
-6
lines changed

bundle/src/main/java/dev/cel/bundle/CelBuilder.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
2121
import com.google.protobuf.Descriptors.Descriptor;
2222
import com.google.protobuf.Descriptors.FileDescriptor;
23+
import com.google.protobuf.ExtensionRegistry;
2324
import com.google.protobuf.Message;
2425
import dev.cel.checker.ProtoTypeMask;
2526
import dev.cel.checker.TypeProvider;
@@ -282,6 +283,13 @@ public interface CelBuilder {
282283
@CanIgnoreReturnValue
283284
CelBuilder addRuntimeLibraries(Iterable<CelRuntimeLibrary> libraries);
284285

286+
/**
287+
* Sets a proto ExtensionRegistry to assist with unpacking Any messages containing a proto2
288+
extension field.
289+
*/
290+
@CanIgnoreReturnValue
291+
CelBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry);
292+
285293
/** Construct a new {@code Cel} instance from the provided configuration. */
286294
Cel build();
287295
}

bundle/src/main/java/dev/cel/bundle/CelImpl.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
2626
import com.google.protobuf.Descriptors.Descriptor;
2727
import com.google.protobuf.Descriptors.FileDescriptor;
28+
import com.google.protobuf.ExtensionRegistry;
2829
import com.google.protobuf.Message;
2930
import dev.cel.checker.CelCheckerBuilder;
3031
import dev.cel.checker.ProtoTypeMask;
@@ -339,6 +340,13 @@ public Builder addRuntimeLibraries(Iterable<CelRuntimeLibrary> libraries) {
339340
return this;
340341
}
341342

343+
@Override
344+
public CelBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry) {
345+
checkNotNull(extensionRegistry);
346+
runtimeBuilder.setExtensionRegistry(extensionRegistry);
347+
return this;
348+
}
349+
342350
@Override
343351
public Cel build() {
344352
return new CelImpl(

common/src/main/java/dev/cel/common/internal/DefaultDescriptorPool.java

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ public final class DefaultDescriptorPool implements CelDescriptorPool {
4646

4747
/** A DefaultDescriptorPool instance with just well known types loaded. */
4848
public static final DefaultDescriptorPool INSTANCE =
49-
new DefaultDescriptorPool(WELL_KNOWN_TYPE_DESCRIPTORS, ImmutableMultimap.of());
49+
new DefaultDescriptorPool(
50+
WELL_KNOWN_TYPE_DESCRIPTORS,
51+
ImmutableMultimap.of(),
52+
ExtensionRegistry.getEmptyRegistry());
5053

5154
// K: Fully qualified message type name, V: Message descriptor
5255
private final ImmutableMap<String, Descriptor> descriptorMap;
@@ -55,7 +58,15 @@ public final class DefaultDescriptorPool implements CelDescriptorPool {
5558
// V: Field descriptor for the extension message
5659
private final ImmutableMultimap<String, FieldDescriptor> extensionDescriptorMap;
5760

61+
@SuppressWarnings("Immutable") // ExtensionRegistry is immutable, just not marked as such.
62+
private final ExtensionRegistry extensionRegistry;
63+
5864
public static DefaultDescriptorPool create(CelDescriptors celDescriptors) {
65+
return create(celDescriptors, ExtensionRegistry.getEmptyRegistry());
66+
}
67+
68+
public static DefaultDescriptorPool create(
69+
CelDescriptors celDescriptors, ExtensionRegistry extensionRegistry) {
5970
Map<String, Descriptor> descriptorMap = new HashMap<>(); // Using a hashmap to allow deduping
6071
stream(WellKnownProto.values()).forEach(d -> descriptorMap.put(d.typeName(), d.descriptor()));
6172

@@ -64,7 +75,9 @@ public static DefaultDescriptorPool create(CelDescriptors celDescriptors) {
6475
}
6576

6677
return new DefaultDescriptorPool(
67-
ImmutableMap.copyOf(descriptorMap), celDescriptors.extensionDescriptors());
78+
ImmutableMap.copyOf(descriptorMap),
79+
celDescriptors.extensionDescriptors(),
80+
extensionRegistry);
6881
}
6982

7083
@Override
@@ -83,14 +96,15 @@ public Optional<FieldDescriptor> findExtensionDescriptor(
8396

8497
@Override
8598
public ExtensionRegistry getExtensionRegistry() {
86-
// TODO: Populate one from runtime builder.
87-
return ExtensionRegistry.getEmptyRegistry();
99+
return extensionRegistry;
88100
}
89101

90102
private DefaultDescriptorPool(
91103
ImmutableMap<String, Descriptor> descriptorMap,
92-
ImmutableMultimap<String, FieldDescriptor> extensionDescriptorMap) {
104+
ImmutableMultimap<String, FieldDescriptor> extensionDescriptorMap,
105+
ExtensionRegistry extensionRegistry) {
93106
this.descriptorMap = checkNotNull(descriptorMap);
94107
this.extensionDescriptorMap = checkNotNull(extensionDescriptorMap);
108+
this.extensionRegistry = checkNotNull(extensionRegistry);
95109
}
96110
}

extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@
1919

2020
import com.google.common.collect.ImmutableList;
2121
import com.google.common.collect.ImmutableMap;
22+
import com.google.protobuf.Any;
2223
import com.google.protobuf.Descriptors.FieldDescriptor;
24+
import com.google.protobuf.ExtensionRegistry;
2325
import com.google.testing.junit.testparameterinjector.TestParameter;
2426
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
2527
import com.google.testing.junit.testparameterinjector.TestParameters;
28+
import dev.cel.bundle.Cel;
29+
import dev.cel.bundle.CelFactory;
2630
import dev.cel.common.CelAbstractSyntaxTree;
2731
import dev.cel.common.CelFunctionDecl;
2832
import dev.cel.common.CelOverloadDecl;
@@ -277,6 +281,29 @@ public void getExt_nonProtoNamespace_success(String expr) throws Exception {
277281
assertThat(result).isTrue();
278282
}
279283

284+
@Test
285+
public void getExt_onAnyPackedExtensionField_success() throws Exception {
286+
ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance();
287+
MessagesProto2Extensions.registerAllExtensions(extensionRegistry);
288+
Cel cel =
289+
CelFactory.standardCelBuilder()
290+
.addCompilerLibraries(CelExtensions.protos())
291+
.addFileTypes(MessagesProto2Extensions.getDescriptor())
292+
.setExtensionRegistry(extensionRegistry)
293+
.addVar(
294+
"msg", StructTypeReference.create("dev.cel.testing.testdata.proto2.Proto2Message"))
295+
.build();
296+
CelAbstractSyntaxTree ast =
297+
cel.compile("proto.getExt(msg, dev.cel.testing.testdata.proto2.int32_ext)").getAst();
298+
Any anyMsg =
299+
Any.pack(
300+
Proto2Message.newBuilder().setExtension(MessagesProto2Extensions.int32Ext, 1).build());
301+
302+
Long result = (Long) cel.createProgram(ast).eval(ImmutableMap.of("msg", anyMsg));
303+
304+
assertThat(result).isEqualTo(1);
305+
}
306+
280307
private enum ParseErrorTestCase {
281308
FIELD_NOT_FULLY_QUALIFIED(
282309
"proto.getExt(Proto2ExtensionScopedMessage{}, int64_ext)",

runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
2020
import com.google.protobuf.Descriptors.Descriptor;
2121
import com.google.protobuf.Descriptors.FileDescriptor;
22+
import com.google.protobuf.ExtensionRegistry;
2223
import com.google.protobuf.Message;
2324
import dev.cel.common.CelOptions;
2425
import java.util.function.Function;
@@ -149,6 +150,13 @@ public interface CelRuntimeBuilder {
149150
@CanIgnoreReturnValue
150151
CelRuntimeBuilder addLibraries(Iterable<? extends CelRuntimeLibrary> libraries);
151152

153+
/**
154+
* Sets a proto ExtensionRegistry to assist with unpacking Any messages containing a proto2
155+
extension field.
156+
*/
157+
@CanIgnoreReturnValue
158+
CelRuntimeBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry);
159+
152160
/** Build a new instance of the {@code CelRuntime}. */
153161
@CheckReturnValue
154162
CelRuntime build();

runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
2626
import com.google.protobuf.Descriptors.Descriptor;
2727
import com.google.protobuf.Descriptors.FileDescriptor;
28+
import com.google.protobuf.ExtensionRegistry;
2829
import com.google.protobuf.Message;
2930
import dev.cel.common.CelAbstractSyntaxTree;
3031
import dev.cel.common.CelDescriptorUtil;
@@ -79,6 +80,7 @@ public static final class Builder implements CelRuntimeBuilder {
7980

8081
private boolean standardEnvironmentEnabled;
8182
private Function<String, Message.Builder> customTypeFactory;
83+
private ExtensionRegistry extensionRegistry;
8284

8385
@Override
8486
@CanIgnoreReturnValue
@@ -161,6 +163,14 @@ public Builder addLibraries(Iterable<? extends CelRuntimeLibrary> libraries) {
161163
return this;
162164
}
163165

166+
@Override
167+
@CanIgnoreReturnValue
168+
public Builder setExtensionRegistry(ExtensionRegistry extensionRegistry) {
169+
checkNotNull(extensionRegistry);
170+
this.extensionRegistry = extensionRegistry.getUnmodifiable();
171+
return this;
172+
}
173+
164174
/** Build a new {@code CelRuntimeLegacyImpl} instance from the builder config. */
165175
@Override
166176
@CanIgnoreReturnValue
@@ -171,6 +181,7 @@ public CelRuntimeLegacyImpl build() {
171181
CelDescriptorPool celDescriptorPool =
172182
newDescriptorPool(
173183
fileTypes.build(),
184+
extensionRegistry,
174185
options);
175186

176187
@SuppressWarnings("Immutable")
@@ -214,14 +225,15 @@ public CelRuntimeLegacyImpl build() {
214225

215226
private static CelDescriptorPool newDescriptorPool(
216227
ImmutableSet<FileDescriptor> fileTypeSet,
228+
ExtensionRegistry extensionRegistry,
217229
CelOptions celOptions) {
218230
CelDescriptors celDescriptors =
219231
CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(
220232
fileTypeSet, celOptions.resolveTypeDependencies());
221233

222234
ImmutableList.Builder<CelDescriptorPool> descriptorPools = new ImmutableList.Builder<>();
223235

224-
descriptorPools.add(DefaultDescriptorPool.create(celDescriptors));
236+
descriptorPools.add(DefaultDescriptorPool.create(celDescriptors, extensionRegistry));
225237

226238
return CombinedDescriptorPool.create(descriptorPools.build());
227239
}
@@ -241,6 +253,7 @@ private Builder() {
241253
this.fileTypes = ImmutableSet.builder();
242254
this.functionBindings = ImmutableMap.builder();
243255
this.celRuntimeLibraries = ImmutableSet.builder();
256+
this.extensionRegistry = ExtensionRegistry.getEmptyRegistry();
244257
this.customTypeFactory = null;
245258
}
246259
}

0 commit comments

Comments
 (0)