Skip to content

Commit cd11baf

Browse files
l46kokcopybara-github
authored andcommitted
Perform CSE on presence tests
PiperOrigin-RevId: 599886305
1 parent 5a7cbab commit cd11baf

File tree

6 files changed

+209
-12
lines changed

6 files changed

+209
-12
lines changed

common/src/main/java/dev/cel/common/ast/CelExprFormatter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ private void appendSelect(CelExpr.CelSelect celSelect) {
122122
indent();
123123
formatExpr(celSelect.operand());
124124
outdent();
125-
append(".");
126-
append(celSelect.field());
125+
appendWithoutIndent(".");
126+
appendWithoutIndent(celSelect.field());
127127
if (celSelect.testOnly()) {
128128
appendWithoutIndent("~presence_test");
129129
}

common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,4 +336,37 @@ public void comprehension() throws Exception {
336336
+ " }\n"
337337
+ "}");
338338
}
339+
340+
@Test
341+
public void ternaryWithPresenceTest() throws Exception {
342+
CelCompiler celCompiler =
343+
CelCompilerFactory.standardCelCompilerBuilder()
344+
.addMessageTypes(TestAllTypes.getDescriptor())
345+
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()))
346+
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
347+
.build();
348+
CelAbstractSyntaxTree ast =
349+
celCompiler.compile("has(msg.single_any) ? msg.single_any : 10").getAst();
350+
351+
String formattedExpr = CelExprFormatter.format(ast.getExpr());
352+
353+
assertThat(formattedExpr)
354+
.isEqualTo(
355+
"CALL [5] {\n"
356+
+ " function: _?_:_\n"
357+
+ " args: {\n"
358+
+ " SELECT [4] {\n"
359+
+ " IDENT [2] {\n"
360+
+ " name: msg\n"
361+
+ " }.single_any~presence_test\n"
362+
+ " }\n"
363+
+ " SELECT [7] {\n"
364+
+ " IDENT [6] {\n"
365+
+ " name: msg\n"
366+
+ " }.single_any\n"
367+
+ " }\n"
368+
+ " CONSTANT [8] { value: 10 }\n"
369+
+ " }\n"
370+
+ "}");
371+
}
339372
}

optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,24 @@ public interface CelAstOptimizer {
2626
CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel)
2727
throws CelOptimizationException;
2828

29+
/**
30+
* Replaces a subtree in the given expression node. This operation is intended for AST
31+
* optimization purposes.
32+
*
33+
* <p>This is a very dangerous operation. Callers should re-typecheck the mutated AST and
34+
* additionally verify that the resulting AST is semantically valid.
35+
*
36+
* <p>All expression IDs will be renumbered in a stable manner to ensure there's no ID collision
37+
* between the nodes. The renumbering occurs even if the subtree was not replaced.
38+
*
39+
* @param celExpr Original expression node to rewrite.
40+
* @param newExpr New CelExpr to replace the subtree with.
41+
* @param exprIdToReplace Expression id of the subtree that is getting replaced.
42+
*/
43+
default CelExpr replaceSubtree(CelExpr celExpr, CelExpr newExpr, long exprIdToReplace) {
44+
return MutableAst.replaceSubtree(celExpr, newExpr, exprIdToReplace);
45+
}
46+
2947
/**
3048
* Replaces a subtree in the given AST. This operation is intended for AST optimization purposes.
3149
*

optimizer/src/main/java/dev/cel/optimizer/MutableAst.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,15 @@ static CelExpr clearExprIds(CelExpr celExpr) {
6666
return renumberExprIds((unused) -> 0, celExpr.toBuilder()).build();
6767
}
6868

69+
/** Mutates the given {@link CelExpr} by replacing a subtree at the given index. */
70+
static CelExpr replaceSubtree(CelExpr expr, CelExpr newExpr, long exprIdToReplace) {
71+
return replaceSubtree(
72+
CelAbstractSyntaxTree.newParsedAst(expr, CelSource.newBuilder().build()),
73+
CelAbstractSyntaxTree.newParsedAst(newExpr, CelSource.newBuilder().build()),
74+
exprIdToReplace)
75+
.getExpr();
76+
}
77+
6978
/**
7079
* Mutates the given AST by replacing a subtree at a given index.
7180
*

optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) {
9797
int bindIdentifierIndex = 0;
9898
int iterCount;
9999
for (iterCount = 0; iterCount < cseOptions.maxIterationLimit(); iterCount++) {
100-
CelNavigableExpr cseCandidate = findCseCandidate(astToModify).orElse(null);
100+
CelExpr cseCandidate = findCseCandidate(astToModify).map(CelNavigableExpr::expr).orElse(null);
101101
if (cseCandidate == null) {
102102
break;
103103
}
@@ -107,7 +107,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) {
107107

108108
// Using the CSE candidate, fetch all semantically equivalent subexpressions ahead of time.
109109
ImmutableList<CelExpr> allCseCandidates =
110-
getAllCseCandidatesStream(astToModify, cseCandidate.expr()).collect(toImmutableList());
110+
getAllCseCandidatesStream(astToModify, cseCandidate).collect(toImmutableList());
111111

112112
// Replace all CSE candidates with new bind identifier
113113
for (CelExpr semanticallyEqualNode : allCseCandidates) {
@@ -142,7 +142,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) {
142142
// Insert the new bind call
143143
astToModify =
144144
replaceSubtreeWithNewBindMacro(
145-
astToModify, bindIdentifier, cseCandidate.expr(), lca.expr(), lca.id());
145+
astToModify, bindIdentifier, cseCandidate, lca.expr(), lca.id());
146146

147147
// Retain the existing macro calls in case if the bind identifiers are replacing a subtree
148148
// that contains a comprehension.
@@ -224,8 +224,8 @@ private Optional<CelNavigableExpr> findCseCandidate(CelAbstractSyntaxTree ast) {
224224
.collect(toImmutableList());
225225

226226
for (CelNavigableExpr node : allNodes) {
227-
// Strip out all IDs to test equivalence
228-
CelExpr celExpr = clearExprIds(node.expr());
227+
// Normalize the expr to test semantic equivalence.
228+
CelExpr celExpr = normalizeForEquality(node.expr());
229229
if (encounteredNodes.contains(celExpr)) {
230230
return Optional.of(node);
231231
}
@@ -240,6 +240,7 @@ private static boolean canEliminate(CelNavigableExpr navigableExpr) {
240240
return !navigableExpr.getKind().equals(Kind.CONSTANT)
241241
&& !navigableExpr.getKind().equals(Kind.IDENT)
242242
&& !navigableExpr.expr().identOrDefault().name().startsWith(BIND_IDENTIFIER_PREFIX)
243+
&& !navigableExpr.expr().selectOrDefault().testOnly()
243244
&& isAllowedFunction(navigableExpr)
244245
&& isWithinInlineableComprehension(navigableExpr);
245246
}
@@ -271,7 +272,7 @@ private static boolean isWithinInlineableComprehension(CelNavigableExpr expr) {
271272
}
272273

273274
private boolean areSemanticallyEqual(CelExpr expr1, CelExpr expr2) {
274-
return clearExprIds(expr1).equals(clearExprIds(expr2));
275+
return normalizeForEquality(expr1).equals(normalizeForEquality(expr2));
275276
}
276277

277278
private static boolean isAllowedFunction(CelNavigableExpr navigableExpr) {
@@ -282,6 +283,47 @@ private static boolean isAllowedFunction(CelNavigableExpr navigableExpr) {
282283
return true;
283284
}
284285

286+
/**
287+
* Converts the {@link CelExpr} to make it suitable for performing semantically equals check in
288+
* {@link #areSemanticallyEqual(CelExpr, CelExpr)}.
289+
*
290+
* <p>Specifically, this will:
291+
*
292+
* <ul>
293+
* <li>Set all expr IDs in the expression tree to 0.
294+
* <li>Strip all presence tests (i.e: testOnly is marked as false on {@link
295+
* CelExpr.ExprKind.Kind#SELECT}
296+
* </ul>
297+
*/
298+
private CelExpr normalizeForEquality(CelExpr celExpr) {
299+
int iterCount;
300+
for (iterCount = 0; iterCount < cseOptions.maxIterationLimit(); iterCount++) {
301+
CelExpr presenceTestExpr =
302+
CelNavigableExpr.fromExpr(celExpr)
303+
.allNodes()
304+
.map(CelNavigableExpr::expr)
305+
.filter(expr -> expr.selectOrDefault().testOnly())
306+
.findAny()
307+
.orElse(null);
308+
if (presenceTestExpr == null) {
309+
break;
310+
}
311+
312+
CelExpr newExpr =
313+
presenceTestExpr.toBuilder()
314+
.setSelect(presenceTestExpr.select().toBuilder().setTestOnly(false).build())
315+
.build();
316+
317+
celExpr = replaceSubtree(celExpr, newExpr, newExpr.id());
318+
}
319+
320+
if (iterCount >= cseOptions.maxIterationLimit()) {
321+
throw new IllegalStateException("Max iteration count reached.");
322+
}
323+
324+
return clearExprIds(celExpr);
325+
}
326+
285327
/** Options to configure how Common Subexpression Elimination behave. */
286328
@AutoValue
287329
public abstract static class SubexpressionOptimizerOptions {

optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java

Lines changed: 99 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import dev.cel.common.CelOptions;
3232
import dev.cel.common.ast.CelConstant;
3333
import dev.cel.common.ast.CelExpr;
34+
import dev.cel.common.types.OptionalType;
3435
import dev.cel.common.types.SimpleType;
3536
import dev.cel.common.types.StructTypeReference;
3637
import dev.cel.extensions.CelExtensions;
@@ -75,7 +76,8 @@ public class SubexpressionOptimizerTest {
7576
.setSingleInt64(10L)
7677
.putMapInt32Int64(0, 1)
7778
.putMapInt32Int64(1, 5)
78-
.putMapInt32Int64(2, 2)))
79+
.putMapInt32Int64(2, 2)
80+
.putMapStringString("key", "A")))
7981
.build();
8082

8183
private static CelBuilder newCelBuilder() {
@@ -92,6 +94,7 @@ private static CelBuilder newCelBuilder() {
9294
"custom_func",
9395
newGlobalOverload("custom_func_overload", SimpleType.INT, SimpleType.INT)))
9496
.addVar("x", SimpleType.DYN)
97+
.addVar("opt_x", OptionalType.create(SimpleType.DYN))
9598
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()));
9699
}
97100

@@ -314,6 +317,13 @@ private enum CseTestCase {
314317
"[1,2,3].map(i, [1, 2, 3].map(i, i + 1)) == [[2, 3, 4], [2, 3, 4], [2, 3, 4]]",
315318
"cel.bind(@r0, [1, 2, 3], @r0.map(@c0, @r0.map(@c1, @c1 + 1))) == "
316319
+ "cel.bind(@r1, [2, 3, 4], [@r1, @r1, @r1])"),
320+
INCLUSION_LIST(
321+
"1 in [1,2,3] && 2 in [1,2,3] && 3 in [3, [1,2,3]] && 1 in [1,2,3]",
322+
"cel.bind(@r0, [1, 2, 3], cel.bind(@r1, 1 in @r0, @r1 && 2 in @r0 && 3 in [3, @r0] &&"
323+
+ " @r1))"),
324+
INCLUSION_MAP(
325+
"2 in {'a': 1, 2: {true: false}, 3: {true: false}}",
326+
"2 in cel.bind(@r0, {true: false}, {\"a\": 1, 2: @r0, 3: @r0})"),
317327
MACRO_SHADOWED_VARIABLE(
318328
"[x - 1 > 3 ? x - 1 : 5].exists(x, x - 1 > 3) || x - 1 > 3",
319329
"cel.bind(@r0, x - 1, cel.bind(@r1, @r0 > 3, [@r1 ? @r0 : 5].exists(@c0, @c0 - 1 > 3) ||"
@@ -322,6 +332,86 @@ private enum CseTestCase {
322332
"size([\"foo\", \"bar\"].map(x, [x + x, x + x]).map(x, [x + x, x + x])) == 2",
323333
"size([\"foo\", \"bar\"].map(@c1, cel.bind(@r0, @c1 + @c1, [@r0, @r0]))"
324334
+ ".map(@c0, cel.bind(@r1, @c0 + @c0, [@r1, @r1]))) == 2"),
335+
PRESENCE_TEST(
336+
"has({'a': true}.a) && {'a':true}['a']",
337+
"cel.bind(@r0, {\"a\": true}, has(@r0.a) && @r0[\"a\"])"),
338+
PRESENCE_TEST_WITH_TERNARY(
339+
"(has(msg.oneof_type.payload) ? msg.oneof_type.payload.single_int64 : 0) == 10",
340+
"cel.bind(@r0, msg.oneof_type, has(@r0.payload) ? @r0.payload.single_int64 : 0) == 10"),
341+
PRESENCE_TEST_WITH_TERNARY_2(
342+
"(has(msg.oneof_type.payload) ? msg.oneof_type.payload.single_int64 :"
343+
+ " msg.oneof_type.payload.single_int64 * 0) == 10",
344+
"cel.bind(@r0, msg.oneof_type, cel.bind(@r1, @r0.payload.single_int64, has(@r0.payload) ?"
345+
+ " @r1 : (@r1 * 0))) == 10"),
346+
PRESENCE_TEST_WITH_TERNARY_3(
347+
"(has(msg.oneof_type.payload.single_int64) ? msg.oneof_type.payload.single_int64 :"
348+
+ " msg.oneof_type.payload.single_int64 * 0) == 10",
349+
"cel.bind(@r0, msg.oneof_type.payload, cel.bind(@r1, @r0.single_int64,"
350+
+ " has(@r0.single_int64) ? @r1 : (@r1 * 0))) == 10"),
351+
/**
352+
* Input:
353+
*
354+
* <pre>{@code
355+
* (
356+
* has(msg.oneof_type) &&
357+
* has(msg.oneof_type.payload) &&
358+
* has(msg.oneof_type.payload.single_int64)
359+
* ) ?
360+
* (
361+
* (
362+
* has(msg.oneof_type.payload.map_string_string) &&
363+
* has(msg.oneof_type.payload.map_string_string.key)
364+
* ) ?
365+
* msg.oneof_type.payload.map_string_string.key == "A"
366+
* : false
367+
* )
368+
* : false
369+
* }</pre>
370+
*
371+
* Unparsed:
372+
*
373+
* <pre>{@code
374+
* cel.bind(
375+
* @r0, msg.oneof_type,
376+
* cel.bind(
377+
* @r1, @r0.payload,
378+
* has(msg.oneof_type) && has(@r0.payload) && has(@r1.single_int64) ?
379+
* cel.bind(
380+
* @r2, @r1.map_string_string,
381+
* has(@r1.map_string_string) && has(@r2.key) ? @r2.key == "A" : false,
382+
* )
383+
* : false,
384+
* ),
385+
* )
386+
* }</pre>
387+
*/
388+
PRESENCE_TEST_WITH_TERNARY_NESTED(
389+
"(has(msg.oneof_type) && has(msg.oneof_type.payload) &&"
390+
+ " has(msg.oneof_type.payload.single_int64)) ?"
391+
+ " ((has(msg.oneof_type.payload.map_string_string) &&"
392+
+ " has(msg.oneof_type.payload.map_string_string.key)) ?"
393+
+ " msg.oneof_type.payload.map_string_string.key == 'A' : false) : false",
394+
"cel.bind(@r0, msg.oneof_type, cel.bind(@r1, @r0.payload, (has(msg.oneof_type) &&"
395+
+ " has(@r0.payload) && has(@r1.single_int64)) ? cel.bind(@r2, @r1.map_string_string,"
396+
+ " (has(@r1.map_string_string) && has(@r2.key)) ? (@r2.key == \"A\") : false) :"
397+
+ " false))"),
398+
OPTIONAL_LIST(
399+
"[10, ?optional.none(), [?optional.none(), ?opt_x], [?optional.none(), ?opt_x]] == [10,"
400+
+ " [5], [5]]",
401+
"cel.bind(@r0, [?optional.none(), ?opt_x], [10, ?optional.none(), @r0, @r0]) =="
402+
+ " cel.bind(@r1, [5], [10, @r1, @r1])"),
403+
OPTIONAL_MAP(
404+
"{?'hello': optional.of('hello')}['hello'] + {?'hello': optional.of('hello')}['hello'] =="
405+
+ " 'hellohello'",
406+
"cel.bind(@r0, {?\"hello\": optional.of(\"hello\")}[\"hello\"], @r0 + @r0) =="
407+
+ " \"hellohello\""),
408+
OPTIONAL_MESSAGE(
409+
"TestAllTypes{?single_int64: optional.ofNonZeroValue(1), ?single_int32:"
410+
+ " optional.of(4)}.single_int32 + TestAllTypes{?single_int64:"
411+
+ " optional.ofNonZeroValue(1), ?single_int32: optional.of(4)}.single_int64 == 5",
412+
"cel.bind(@r0, TestAllTypes{"
413+
+ "?single_int64: optional.ofNonZeroValue(1), ?single_int32: optional.of(4)}, "
414+
+ "@r0.single_int32 + @r0.single_int64) == 5"),
325415
;
326416

327417
private final String source;
@@ -342,7 +432,9 @@ public void cse_withMacroMapPopulated_success(@TestParameter CseTestCase testCas
342432

343433
assertThat(
344434
CEL.createProgram(optimizedAst)
345-
.eval(ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L)))
435+
.eval(
436+
ImmutableMap.of(
437+
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))))
346438
.isEqualTo(true);
347439
assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(testCase.unparsed);
348440
}
@@ -366,7 +458,9 @@ public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) thr
366458
assertThat(
367459
celWithoutMacroMap
368460
.createProgram(optimizedAst)
369-
.eval(ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L)))
461+
.eval(
462+
ImmutableMap.of(
463+
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))))
370464
.isEqualTo(true);
371465
}
372466

@@ -384,7 +478,8 @@ public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) thr
384478
@TestParameters("{source: 'custom_func(1) + custom_func(1)'}")
385479
// Duplicated but nested calls.
386480
@TestParameters("{source: 'int(timestamp(int(timestamp(1000000000))))'}")
387-
// Ternary with presence test is not supported yet.
481+
// This cannot be optimized. Extracting the common subexpression would presence test
482+
// the bound identifier (e.g: has(@r0)), which is not valid.
388483
@TestParameters("{source: 'has(msg.single_any) ? msg.single_any : 10'}")
389484
public void cse_noop(String source) throws Exception {
390485
CelAbstractSyntaxTree ast = CEL.compile(source).getAst();

0 commit comments

Comments
 (0)