Skip to content

Commit 6fd6a4c

Browse files
l46kokcopybara-github
authored andcommitted
Improve CSE for comprehensions by taking into their nesting level and types into account
PiperOrigin-RevId: 664956665
1 parent c1a05d6 commit 6fd6a4c

21 files changed

+4029
-7173
lines changed

optimizer/src/main/java/dev/cel/optimizer/AstMutator.java

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
import com.google.auto.value.AutoValue;
2222
import com.google.common.base.Preconditions;
2323
import com.google.common.base.Strings;
24+
import com.google.common.collect.HashBasedTable;
2425
import com.google.common.collect.ImmutableMap;
26+
import com.google.common.collect.Table;
2527
import com.google.errorprone.annotations.Immutable;
2628
import dev.cel.common.CelAbstractSyntaxTree;
2729
import dev.cel.common.CelMutableAst;
@@ -270,17 +272,22 @@ public CelMutableAst renumberIdsConsecutively(CelMutableAst mutableAst) {
270272
* @param ast AST containing type-checked references
271273
* @param newIterVarPrefix Prefix to use for new iteration variable identifier name. For example,
272274
* providing @c will produce @c0:0, @c0:1, @c1:0, @c2:0... as new names.
273-
* @param newResultPrefix Prefix to use for new comprehensin result identifier names.
275+
* @param newAccuVarPrefix Prefix to use for new accumulation variable identifier name.
276+
* @param incrementSerially If true, indices for the mangled variables are incremented serially
277+
* per occurrence regardless of their nesting level or its types.
274278
*/
275279
public MangledComprehensionAst mangleComprehensionIdentifierNames(
276-
CelMutableAst ast, String newIterVarPrefix, String newResultPrefix) {
280+
CelMutableAst ast,
281+
String newIterVarPrefix,
282+
String newAccuVarPrefix,
283+
boolean incrementSerially) {
277284
CelNavigableMutableAst navigableMutableAst = CelNavigableMutableAst.fromAst(ast);
278285
Predicate<CelNavigableMutableExpr> comprehensionIdentifierPredicate = x -> true;
279286
comprehensionIdentifierPredicate =
280287
comprehensionIdentifierPredicate
281288
.and(node -> node.getKind().equals(Kind.COMPREHENSION))
282289
.and(node -> !node.expr().comprehension().iterVar().startsWith(newIterVarPrefix))
283-
.and(node -> !node.expr().comprehension().accuVar().startsWith(newResultPrefix));
290+
.and(node -> !node.expr().comprehension().accuVar().startsWith(newAccuVarPrefix));
284291

285292
LinkedHashMap<CelNavigableMutableExpr, MangledComprehensionType> comprehensionsToMangle =
286293
navigableMutableAst
@@ -352,18 +359,43 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
352359
// The map that we'll eventually return to the caller.
353360
HashMap<MangledComprehensionName, MangledComprehensionType> mangledIdentNamesToType =
354361
new HashMap<>();
362+
// Intermediary table used for the purposes of generating a unique mangled variable name.
363+
Table<Integer, MangledComprehensionType, MangledComprehensionName> comprehensionLevelToType =
364+
HashBasedTable.create();
355365
CelMutableExpr mutatedComprehensionExpr = navigableMutableAst.getAst().expr();
356366
CelMutableSource newSource = navigableMutableAst.getAst().source();
357367
int iterCount = 0;
358368
for (Entry<CelNavigableMutableExpr, MangledComprehensionType> comprehensionEntry :
359369
comprehensionsToMangle.entrySet()) {
360-
String mangledIterVarName = newIterVarPrefix + ":" + iterCount;
361-
String mangledResultName = newResultPrefix + ":" + iterCount;
362-
MangledComprehensionName mangledComprehensionName =
363-
MangledComprehensionName.of(mangledIterVarName, mangledResultName);
364-
mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntry.getValue());
370+
CelNavigableMutableExpr comprehensionNode = comprehensionEntry.getKey();
371+
MangledComprehensionType comprehensionEntryType = comprehensionEntry.getValue();
372+
373+
CelMutableExpr comprehensionExpr = comprehensionNode.expr();
374+
MangledComprehensionName mangledComprehensionName;
375+
if (incrementSerially) {
376+
// In case of applying CSE via cascaded cel.binds, not only is mangling based on level/types
377+
// meaningless (because all comprehensions are nested anyways, thus all indices would be
378+
// uinque),
379+
// it can lead to an erroneous result due to extracting a common subexpr with accu_var at
380+
// the wrong scope.
381+
// Example: "[1].exists(k, k > 1) && [2].exists(l, l > 1). The loop step for both branches
382+
// are identical, but shouldn't be extracted.
383+
String mangledIterVarName = newIterVarPrefix + ":" + iterCount;
384+
String mangledResultName = newAccuVarPrefix + ":" + iterCount;
385+
mangledComprehensionName =
386+
MangledComprehensionName.of(mangledIterVarName, mangledResultName);
387+
mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntry.getValue());
388+
} else {
389+
mangledComprehensionName =
390+
getMangledComprehensionName(
391+
newIterVarPrefix,
392+
newAccuVarPrefix,
393+
comprehensionNode,
394+
comprehensionLevelToType,
395+
comprehensionEntryType);
396+
}
397+
mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntryType);
365398

366-
CelMutableExpr comprehensionExpr = comprehensionEntry.getKey().expr();
367399
String iterVar = comprehensionExpr.comprehension().iterVar();
368400
String accuVar = comprehensionExpr.comprehension().accuVar();
369401
mutatedComprehensionExpr =
@@ -396,6 +428,45 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
396428
ImmutableMap.copyOf(mangledIdentNamesToType));
397429
}
398430

431+
private static MangledComprehensionName getMangledComprehensionName(
432+
String newIterVarPrefix,
433+
String newResultPrefix,
434+
CelNavigableMutableExpr comprehensionNode,
435+
Table<Integer, MangledComprehensionType, MangledComprehensionName> comprehensionLevelToType,
436+
MangledComprehensionType comprehensionEntryType) {
437+
MangledComprehensionName mangledComprehensionName;
438+
int comprehensionNestingLevel = countComprehensionNestingLevel(comprehensionNode);
439+
if (comprehensionLevelToType.contains(comprehensionNestingLevel, comprehensionEntryType)) {
440+
mangledComprehensionName =
441+
comprehensionLevelToType.get(comprehensionNestingLevel, comprehensionEntryType);
442+
} else {
443+
// First time encountering the pair of <ComprehensionLevel, CelType>. Generate a unique
444+
// mangled variable name for this.
445+
int uniqueTypeIdx = comprehensionLevelToType.row(comprehensionNestingLevel).size();
446+
String mangledIterVarName =
447+
newIterVarPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx;
448+
String mangledResultName =
449+
newResultPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx;
450+
mangledComprehensionName = MangledComprehensionName.of(mangledIterVarName, mangledResultName);
451+
comprehensionLevelToType.put(
452+
comprehensionNestingLevel, comprehensionEntryType, mangledComprehensionName);
453+
}
454+
return mangledComprehensionName;
455+
}
456+
457+
private static int countComprehensionNestingLevel(CelNavigableMutableExpr comprehensionExpr) {
458+
int nestedLevel = 0;
459+
Optional<CelNavigableMutableExpr> maybeParent = comprehensionExpr.parent();
460+
while (maybeParent.isPresent()) {
461+
if (maybeParent.get().getKind().equals(Kind.COMPREHENSION)) {
462+
nestedLevel++;
463+
}
464+
465+
maybeParent = maybeParent.get().parent();
466+
}
467+
return nestedLevel;
468+
}
469+
399470
/**
400471
* Replaces a subtree in the given expression node. This operation is intended for AST
401472
* optimization purposes.

optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
9292
private static final SubexpressionOptimizer INSTANCE =
9393
new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build());
9494
private static final String BIND_IDENTIFIER_PREFIX = "@r";
95-
private static final String MANGLED_COMPREHENSION_IDENTIFIER_PREFIX = "@c";
96-
private static final String MANGLED_COMPREHENSION_RESULT_PREFIX = "@x";
95+
private static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it";
96+
private static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac";
9797
private static final String CEL_BLOCK_FUNCTION = "cel.@block";
9898
private static final String BLOCK_INDEX_PREFIX = "@index";
9999
private static final Extension CEL_BLOCK_AST_EXTENSION_TAG =
@@ -138,8 +138,9 @@ private OptimizationResult optimizeUsingCelBlock(CelAbstractSyntaxTree ast, Cel
138138
MangledComprehensionAst mangledComprehensionAst =
139139
astMutator.mangleComprehensionIdentifierNames(
140140
astToModify,
141-
MANGLED_COMPREHENSION_IDENTIFIER_PREFIX,
142-
MANGLED_COMPREHENSION_RESULT_PREFIX);
141+
MANGLED_COMPREHENSION_ITER_VAR_PREFIX,
142+
MANGLED_COMPREHENSION_ACCU_VAR_PREFIX,
143+
/* incrementSerially= */ false);
143144
astToModify = mangledComprehensionAst.mutableAst();
144145
CelMutableSource sourceToModify = astToModify.source();
145146

@@ -339,8 +340,9 @@ private OptimizationResult optimizeUsingCelBind(CelAbstractSyntaxTree ast) {
339340
astMutator
340341
.mangleComprehensionIdentifierNames(
341342
astToModify,
342-
MANGLED_COMPREHENSION_IDENTIFIER_PREFIX,
343-
MANGLED_COMPREHENSION_RESULT_PREFIX)
343+
MANGLED_COMPREHENSION_ITER_VAR_PREFIX,
344+
MANGLED_COMPREHENSION_ACCU_VAR_PREFIX,
345+
/* incrementSerially= */ true)
344346
.mutableAst();
345347
CelMutableSource sourceToModify = astToModify.source();
346348

0 commit comments

Comments
 (0)