1818import com .google .errorprone .annotations .Immutable ;
1919import dev .cel .checker .CelCheckerBuilder ;
2020import dev .cel .common .CelFunctionDecl ;
21+ import dev .cel .common .CelOptions ;
2122import dev .cel .common .CelOverloadDecl ;
22- import dev .cel .common .internal .ComparisonFunctions ;
23+ import dev .cel .common .internal .DefaultMessageFactory ;
24+ import dev .cel .common .internal .DynamicProto ;
2325import dev .cel .common .types .ListType ;
2426import dev .cel .common .types .SimpleType ;
2527import dev .cel .common .types .TypeParamType ;
2628import dev .cel .compiler .CelCompilerLibrary ;
2729import dev .cel .runtime .CelRuntime ;
2830import dev .cel .runtime .CelRuntimeBuilder ;
2931import dev .cel .runtime .CelRuntimeLibrary ;
32+ import dev .cel .runtime .RuntimeEquality ;
3033import java .util .Collection ;
3134import java .util .Iterator ;
32- import java .util .List ;
3335import java .util .Set ;
3436
3537/**
@@ -64,6 +66,9 @@ public final class CelSetsExtensions implements CelCompilerLibrary, CelRuntimeLi
6466 + " are unique, so size does not factor into the computation. If either list is empty,"
6567 + " the result will be false." ;
6668
69+ private static final RuntimeEquality RUNTIME_EQUALITY =
70+ new RuntimeEquality (DynamicProto .create (DefaultMessageFactory .INSTANCE ));
71+
6772 /** Denotes the set extension function. */
6873 public enum Function {
6974 CONTAINS (
@@ -74,12 +79,7 @@ public enum Function {
7479 SET_CONTAINS_OVERLOAD_DOC ,
7580 SimpleType .BOOL ,
7681 ListType .create (TypeParamType .create ("T" )),
77- ListType .create (TypeParamType .create ("T" )))),
78- CelRuntime .CelFunctionBinding .from (
79- "list_sets_contains_list" ,
80- Collection .class ,
81- Collection .class ,
82- CelSetsExtensions ::containsAll )),
82+ ListType .create (TypeParamType .create ("T" ))))),
8383 EQUIVALENT (
8484 CelFunctionDecl .newFunctionDeclaration (
8585 SET_EQUIVALENT_FUNCTION ,
@@ -88,12 +88,7 @@ public enum Function {
8888 SET_EQUIVALENT_OVERLOAD_DOC ,
8989 SimpleType .BOOL ,
9090 ListType .create (TypeParamType .create ("T" )),
91- ListType .create (TypeParamType .create ("T" )))),
92- CelRuntime .CelFunctionBinding .from (
93- "list_sets_equivalent_list" ,
94- Collection .class ,
95- Collection .class ,
96- (listA , listB ) -> containsAll (listA , listB ) && containsAll (listB , listA ))),
91+ ListType .create (TypeParamType .create ("T" ))))),
9792 INTERSECTS (
9893 CelFunctionDecl .newFunctionDeclaration (
9994 SET_INTERSECTS_FUNCTION ,
@@ -102,34 +97,29 @@ public enum Function {
10297 SET_INTERSECTS_OVERLOAD_DOC ,
10398 SimpleType .BOOL ,
10499 ListType .create (TypeParamType .create ("T" )),
105- ListType .create (TypeParamType .create ("T" )))),
106- CelRuntime .CelFunctionBinding .from (
107- "list_sets_intersects_list" ,
108- Collection .class ,
109- Collection .class ,
110- CelSetsExtensions ::setIntersects ));
100+ ListType .create (TypeParamType .create ("T" )))));
111101
112102 private final CelFunctionDecl functionDecl ;
113- private final ImmutableSet <CelRuntime .CelFunctionBinding > functionBindings ;
114103
115104 String getFunction () {
116105 return functionDecl .name ();
117106 }
118107
119- Function (CelFunctionDecl functionDecl , CelRuntime . CelFunctionBinding ... functionBindings ) {
108+ Function (CelFunctionDecl functionDecl ) {
120109 this .functionDecl = functionDecl ;
121- this .functionBindings = ImmutableSet .copyOf (functionBindings );
122110 }
123111 }
124112
125113 private final ImmutableSet <Function > functions ;
114+ private final CelOptions celOptions ;
126115
127- CelSetsExtensions () {
128- this (ImmutableSet .copyOf (Function .values ()));
116+ CelSetsExtensions (CelOptions celOptions ) {
117+ this (celOptions , ImmutableSet .copyOf (Function .values ()));
129118 }
130119
131- CelSetsExtensions (Set <Function > functions ) {
120+ CelSetsExtensions (CelOptions celOptions , Set <Function > functions ) {
132121 this .functions = ImmutableSet .copyOf (functions );
122+ this .celOptions = celOptions ;
133123 }
134124
135125 @ Override
@@ -139,7 +129,34 @@ public void setCheckerOptions(CelCheckerBuilder checkerBuilder) {
139129
140130 @ Override
141131 public void setRuntimeOptions (CelRuntimeBuilder runtimeBuilder ) {
142- functions .forEach (function -> runtimeBuilder .addFunctionBindings (function .functionBindings ));
132+ for (Function function : functions ) {
133+ switch (function ) {
134+ case CONTAINS :
135+ runtimeBuilder .addFunctionBindings (
136+ CelRuntime .CelFunctionBinding .from (
137+ "list_sets_contains_list" ,
138+ Collection .class ,
139+ Collection .class ,
140+ this ::containsAll ));
141+ break ;
142+ case EQUIVALENT :
143+ runtimeBuilder .addFunctionBindings (
144+ CelRuntime .CelFunctionBinding .from (
145+ "list_sets_equivalent_list" ,
146+ Collection .class ,
147+ Collection .class ,
148+ (listA , listB ) -> containsAll (listA , listB ) && containsAll (listB , listA )));
149+ break ;
150+ case INTERSECTS :
151+ runtimeBuilder .addFunctionBindings (
152+ CelRuntime .CelFunctionBinding .from (
153+ "list_sets_intersects_list" ,
154+ Collection .class ,
155+ Collection .class ,
156+ this ::setIntersects ));
157+ break ;
158+ }
159+ }
143160 }
144161
145162 /**
@@ -150,9 +167,9 @@ public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) {
150167 * <p>This is picked verbatim as implemented in the Java standard library
151168 * Collections.containsAll() method.
152169 *
153- * @see #contains(Object)
170+ * @see #contains(Object, Collection )
154171 */
155- private static <T > boolean containsAll (Collection <T > list , Collection <T > subList ) {
172+ private <T > boolean containsAll (Collection <T > list , Collection <T > subList ) {
156173 for (T e : subList ) {
157174 if (!contains (e , list )) {
158175 return false ;
@@ -171,7 +188,7 @@ private static <T> boolean containsAll(Collection<T> list, Collection<T> subList
171188 * <p>Source:
172189 * https://hg.openjdk.org/jdk8u/jdk8u-dev/jdk/file/c5d02f908fb2/src/share/classes/java/util/AbstractCollection.java#l98
173190 */
174- private static <T > boolean contains (Object o , Collection <T > list ) {
191+ private <T > boolean contains (Object o , Collection <T > list ) {
175192 Iterator <?> it = list .iterator ();
176193 if (o == null ) {
177194 while (it .hasNext ()) {
@@ -182,55 +199,19 @@ private static <T> boolean contains(Object o, Collection<T> list) {
182199 } else {
183200 while (it .hasNext ()) {
184201 Object item = it .next ();
185- if (objectsEquals (item , o )) { // TODO: Support Maps.
202+ if (objectsEquals (item , o )) {
186203 return true ;
187204 }
188205 }
189206 }
190207 return false ;
191208 }
192209
193- private static boolean objectsEquals (Object o1 , Object o2 ) {
194- if (o1 == o2 ) {
195- return true ;
196- }
197- if (o1 == null || o2 == null ) {
198- return false ;
199- }
200- if (isNumeric (o1 ) && isNumeric (o2 )) {
201- if (o1 .getClass ().equals (o2 .getClass ())) {
202- return o1 .equals (o2 );
203- }
204- return ComparisonFunctions .numericEquals ((Number ) o1 , (Number ) o2 );
205- }
206- if (isList (o1 ) && isList (o2 )) {
207- Collection <?> list1 = (Collection <?>) o1 ;
208- Collection <?> list2 = (Collection <?>) o2 ;
209- if (list1 .size () != list2 .size ()) {
210- return false ;
211- }
212- Iterator <?> iterator1 = list1 .iterator ();
213- Iterator <?> iterator2 = list2 .iterator ();
214- boolean result = true ;
215- while (iterator1 .hasNext () && iterator2 .hasNext ()) {
216- Object p1 = iterator1 .next ();
217- Object p2 = iterator2 .next ();
218- result = result && objectsEquals (p1 , p2 );
219- }
220- return result ;
221- }
222- return o1 .equals (o2 );
223- }
224-
225- private static boolean isNumeric (Object o ) {
226- return o instanceof Number ;
227- }
228-
229- private static boolean isList (Object o ) {
230- return o instanceof List ;
210+ private boolean objectsEquals (Object o1 , Object o2 ) {
211+ return RUNTIME_EQUALITY .objectEquals (o1 , o2 , celOptions );
231212 }
232213
233- private static <T > boolean setIntersects (Collection <T > listA , Collection <T > listB ) {
214+ private <T > boolean setIntersects (Collection <T > listA , Collection <T > listB ) {
234215 if (listA .isEmpty () || listB .isEmpty ()) {
235216 return false ;
236217 }
0 commit comments