Skip to content

Commit 81ea911

Browse files
l46kokcopybara-github
authored andcommitted
Fix runtime equality behavior for sets extension
PiperOrigin-RevId: 662216130
1 parent 8073b79 commit 81ea911

File tree

4 files changed

+268
-86
lines changed

4 files changed

+268
-86
lines changed

extensions/src/main/java/dev/cel/extensions/BUILD.bazel

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,13 @@ java_library(
137137
deps = [
138138
"//checker:checker_builder",
139139
"//common:compiler_common",
140-
"//common/internal:comparison_functions",
140+
"//common:options",
141+
"//common/internal:default_message_factory",
142+
"//common/internal:dynamic_proto",
141143
"//common/types",
142144
"//compiler:compiler_builder",
143145
"//runtime",
146+
"//runtime:runtime_helper",
144147
"@maven//:com_google_errorprone_error_prone_annotations",
145148
"@maven//:com_google_guava_guava",
146149
],

extensions/src/main/java/dev/cel/extensions/CelExtensions.java

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ public final class CelExtensions {
3434
private static final CelProtoExtensions PROTO_EXTENSIONS = new CelProtoExtensions();
3535
private static final CelBindingsExtensions BINDINGS_EXTENSIONS = new CelBindingsExtensions();
3636
private static final CelEncoderExtensions ENCODER_EXTENSIONS = new CelEncoderExtensions();
37-
private static final CelSetsExtensions SET_EXTENSIONS = new CelSetsExtensions();
3837

3938
/**
4039
* Extended functions for string manipulation.
@@ -175,6 +174,14 @@ public static CelEncoderExtensions encoders() {
175174
return ENCODER_EXTENSIONS;
176175
}
177176

177+
/**
178+
* @deprecated Use {@link #sets(CelOptions)} instead.
179+
*/
180+
@Deprecated
181+
public static CelSetsExtensions sets() {
182+
return sets(CelOptions.DEFAULT);
183+
}
184+
178185
/**
179186
* Extended functions for Set manipulation.
180187
*
@@ -184,8 +191,8 @@ public static CelEncoderExtensions encoders() {
184191
* future additions. To expose only a subset of functions, use {@link
185192
* #sets(CelSetExtensions.Function...)} instead.
186193
*/
187-
public static CelSetsExtensions sets() {
188-
return SET_EXTENSIONS;
194+
public static CelSetsExtensions sets(CelOptions celOptions) {
195+
return new CelSetsExtensions(celOptions);
189196
}
190197

191198
/**
@@ -195,8 +202,9 @@ public static CelSetsExtensions sets() {
195202
*
196203
* <p>This will include only the specific functions denoted by {@link CelSetsExtensions.Function}.
197204
*/
198-
public static CelSetsExtensions sets(CelSetsExtensions.Function... functions) {
199-
return sets(ImmutableSet.copyOf(functions));
205+
public static CelSetsExtensions sets(
206+
CelOptions celOptions, CelSetsExtensions.Function... functions) {
207+
return sets(celOptions, ImmutableSet.copyOf(functions));
200208
}
201209

202210
/**
@@ -206,8 +214,9 @@ public static CelSetsExtensions sets(CelSetsExtensions.Function... functions) {
206214
*
207215
* <p>This will include only the specific functions denoted by {@link CelSetsExtensions.Function}.
208216
*/
209-
public static CelSetsExtensions sets(Set<CelSetsExtensions.Function> functions) {
210-
return new CelSetsExtensions(functions);
217+
public static CelSetsExtensions sets(
218+
CelOptions celOptions, Set<CelSetsExtensions.Function> functions) {
219+
return new CelSetsExtensions(celOptions, functions);
211220
}
212221

213222
/**

extensions/src/main/java/dev/cel/extensions/CelSetsExtensions.java

Lines changed: 51 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,20 @@
1818
import com.google.errorprone.annotations.Immutable;
1919
import dev.cel.checker.CelCheckerBuilder;
2020
import dev.cel.common.CelFunctionDecl;
21+
import dev.cel.common.CelOptions;
2122
import dev.cel.common.CelOverloadDecl;
22-
import dev.cel.common.internal.ComparisonFunctions;
23+
import dev.cel.common.internal.DefaultMessageFactory;
24+
import dev.cel.common.internal.DynamicProto;
2325
import dev.cel.common.types.ListType;
2426
import dev.cel.common.types.SimpleType;
2527
import dev.cel.common.types.TypeParamType;
2628
import dev.cel.compiler.CelCompilerLibrary;
2729
import dev.cel.runtime.CelRuntime;
2830
import dev.cel.runtime.CelRuntimeBuilder;
2931
import dev.cel.runtime.CelRuntimeLibrary;
32+
import dev.cel.runtime.RuntimeEquality;
3033
import java.util.Collection;
3134
import java.util.Iterator;
32-
import java.util.List;
3335
import java.util.Set;
3436

3537
/**
@@ -64,6 +66,9 @@ public final class CelSetsExtensions implements CelCompilerLibrary, CelRuntimeLi
6466
+ " are unique, so size does not factor into the computation. If either list is empty,"
6567
+ " the result will be false.";
6668

69+
private static final RuntimeEquality RUNTIME_EQUALITY =
70+
new RuntimeEquality(DynamicProto.create(DefaultMessageFactory.INSTANCE));
71+
6772
/** Denotes the set extension function. */
6873
public enum Function {
6974
CONTAINS(
@@ -74,12 +79,7 @@ public enum Function {
7479
SET_CONTAINS_OVERLOAD_DOC,
7580
SimpleType.BOOL,
7681
ListType.create(TypeParamType.create("T")),
77-
ListType.create(TypeParamType.create("T")))),
78-
CelRuntime.CelFunctionBinding.from(
79-
"list_sets_contains_list",
80-
Collection.class,
81-
Collection.class,
82-
CelSetsExtensions::containsAll)),
82+
ListType.create(TypeParamType.create("T"))))),
8383
EQUIVALENT(
8484
CelFunctionDecl.newFunctionDeclaration(
8585
SET_EQUIVALENT_FUNCTION,
@@ -88,12 +88,7 @@ public enum Function {
8888
SET_EQUIVALENT_OVERLOAD_DOC,
8989
SimpleType.BOOL,
9090
ListType.create(TypeParamType.create("T")),
91-
ListType.create(TypeParamType.create("T")))),
92-
CelRuntime.CelFunctionBinding.from(
93-
"list_sets_equivalent_list",
94-
Collection.class,
95-
Collection.class,
96-
(listA, listB) -> containsAll(listA, listB) && containsAll(listB, listA))),
91+
ListType.create(TypeParamType.create("T"))))),
9792
INTERSECTS(
9893
CelFunctionDecl.newFunctionDeclaration(
9994
SET_INTERSECTS_FUNCTION,
@@ -102,34 +97,29 @@ public enum Function {
10297
SET_INTERSECTS_OVERLOAD_DOC,
10398
SimpleType.BOOL,
10499
ListType.create(TypeParamType.create("T")),
105-
ListType.create(TypeParamType.create("T")))),
106-
CelRuntime.CelFunctionBinding.from(
107-
"list_sets_intersects_list",
108-
Collection.class,
109-
Collection.class,
110-
CelSetsExtensions::setIntersects));
100+
ListType.create(TypeParamType.create("T")))));
111101

112102
private final CelFunctionDecl functionDecl;
113-
private final ImmutableSet<CelRuntime.CelFunctionBinding> functionBindings;
114103

115104
String getFunction() {
116105
return functionDecl.name();
117106
}
118107

119-
Function(CelFunctionDecl functionDecl, CelRuntime.CelFunctionBinding... functionBindings) {
108+
Function(CelFunctionDecl functionDecl) {
120109
this.functionDecl = functionDecl;
121-
this.functionBindings = ImmutableSet.copyOf(functionBindings);
122110
}
123111
}
124112

125113
private final ImmutableSet<Function> functions;
114+
private final CelOptions celOptions;
126115

127-
CelSetsExtensions() {
128-
this(ImmutableSet.copyOf(Function.values()));
116+
CelSetsExtensions(CelOptions celOptions) {
117+
this(celOptions, ImmutableSet.copyOf(Function.values()));
129118
}
130119

131-
CelSetsExtensions(Set<Function> functions) {
120+
CelSetsExtensions(CelOptions celOptions, Set<Function> functions) {
132121
this.functions = ImmutableSet.copyOf(functions);
122+
this.celOptions = celOptions;
133123
}
134124

135125
@Override
@@ -139,7 +129,34 @@ public void setCheckerOptions(CelCheckerBuilder checkerBuilder) {
139129

140130
@Override
141131
public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) {
142-
functions.forEach(function -> runtimeBuilder.addFunctionBindings(function.functionBindings));
132+
for (Function function : functions) {
133+
switch (function) {
134+
case CONTAINS:
135+
runtimeBuilder.addFunctionBindings(
136+
CelRuntime.CelFunctionBinding.from(
137+
"list_sets_contains_list",
138+
Collection.class,
139+
Collection.class,
140+
this::containsAll));
141+
break;
142+
case EQUIVALENT:
143+
runtimeBuilder.addFunctionBindings(
144+
CelRuntime.CelFunctionBinding.from(
145+
"list_sets_equivalent_list",
146+
Collection.class,
147+
Collection.class,
148+
(listA, listB) -> containsAll(listA, listB) && containsAll(listB, listA)));
149+
break;
150+
case INTERSECTS:
151+
runtimeBuilder.addFunctionBindings(
152+
CelRuntime.CelFunctionBinding.from(
153+
"list_sets_intersects_list",
154+
Collection.class,
155+
Collection.class,
156+
this::setIntersects));
157+
break;
158+
}
159+
}
143160
}
144161

145162
/**
@@ -150,9 +167,9 @@ public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) {
150167
* <p>This is picked verbatim as implemented in the Java standard library
151168
* Collections.containsAll() method.
152169
*
153-
* @see #contains(Object)
170+
* @see #contains(Object, Collection)
154171
*/
155-
private static <T> boolean containsAll(Collection<T> list, Collection<T> subList) {
172+
private <T> boolean containsAll(Collection<T> list, Collection<T> subList) {
156173
for (T e : subList) {
157174
if (!contains(e, list)) {
158175
return false;
@@ -171,7 +188,7 @@ private static <T> boolean containsAll(Collection<T> list, Collection<T> subList
171188
* <p>Source:
172189
* https://hg.openjdk.org/jdk8u/jdk8u-dev/jdk/file/c5d02f908fb2/src/share/classes/java/util/AbstractCollection.java#l98
173190
*/
174-
private static <T> boolean contains(Object o, Collection<T> list) {
191+
private <T> boolean contains(Object o, Collection<T> list) {
175192
Iterator<?> it = list.iterator();
176193
if (o == null) {
177194
while (it.hasNext()) {
@@ -182,55 +199,19 @@ private static <T> boolean contains(Object o, Collection<T> list) {
182199
} else {
183200
while (it.hasNext()) {
184201
Object item = it.next();
185-
if (objectsEquals(item, o)) { // TODO: Support Maps.
202+
if (objectsEquals(item, o)) {
186203
return true;
187204
}
188205
}
189206
}
190207
return false;
191208
}
192209

193-
private static boolean objectsEquals(Object o1, Object o2) {
194-
if (o1 == o2) {
195-
return true;
196-
}
197-
if (o1 == null || o2 == null) {
198-
return false;
199-
}
200-
if (isNumeric(o1) && isNumeric(o2)) {
201-
if (o1.getClass().equals(o2.getClass())) {
202-
return o1.equals(o2);
203-
}
204-
return ComparisonFunctions.numericEquals((Number) o1, (Number) o2);
205-
}
206-
if (isList(o1) && isList(o2)) {
207-
Collection<?> list1 = (Collection<?>) o1;
208-
Collection<?> list2 = (Collection<?>) o2;
209-
if (list1.size() != list2.size()) {
210-
return false;
211-
}
212-
Iterator<?> iterator1 = list1.iterator();
213-
Iterator<?> iterator2 = list2.iterator();
214-
boolean result = true;
215-
while (iterator1.hasNext() && iterator2.hasNext()) {
216-
Object p1 = iterator1.next();
217-
Object p2 = iterator2.next();
218-
result = result && objectsEquals(p1, p2);
219-
}
220-
return result;
221-
}
222-
return o1.equals(o2);
223-
}
224-
225-
private static boolean isNumeric(Object o) {
226-
return o instanceof Number;
227-
}
228-
229-
private static boolean isList(Object o) {
230-
return o instanceof List;
210+
private boolean objectsEquals(Object o1, Object o2) {
211+
return RUNTIME_EQUALITY.objectEquals(o1, o2, celOptions);
231212
}
232213

233-
private static <T> boolean setIntersects(Collection<T> listA, Collection<T> listB) {
214+
private <T> boolean setIntersects(Collection<T> listA, Collection<T> listB) {
234215
if (listA.isEmpty() || listB.isEmpty()) {
235216
return false;
236217
}

0 commit comments

Comments
 (0)