|
21 | 21 | import com.google.auto.value.AutoValue; |
22 | 22 | import com.google.common.base.Preconditions; |
23 | 23 | import com.google.common.base.Strings; |
| 24 | +import com.google.common.collect.HashBasedTable; |
24 | 25 | import com.google.common.collect.ImmutableMap; |
| 26 | +import com.google.common.collect.Table; |
25 | 27 | import com.google.errorprone.annotations.Immutable; |
26 | 28 | import dev.cel.common.CelAbstractSyntaxTree; |
27 | 29 | import dev.cel.common.CelMutableAst; |
@@ -270,17 +272,22 @@ public CelMutableAst renumberIdsConsecutively(CelMutableAst mutableAst) { |
270 | 272 | * @param ast AST containing type-checked references |
271 | 273 | * @param newIterVarPrefix Prefix to use for new iteration variable identifier name. For example, |
272 | 274 | * providing @c will produce @c0:0, @c0:1, @c1:0, @c2:0... as new names. |
273 | | - * @param newResultPrefix Prefix to use for new comprehensin result identifier names. |
| 275 | + * @param newAccuVarPrefix Prefix to use for new accumulation variable identifier name. |
| 276 | + * @param incrementSerially If true, indices for the mangled variables are incremented serially |
| 277 | + * per occurrence regardless of their nesting level or its types. |
274 | 278 | */ |
275 | 279 | public MangledComprehensionAst mangleComprehensionIdentifierNames( |
276 | | - CelMutableAst ast, String newIterVarPrefix, String newResultPrefix) { |
| 280 | + CelMutableAst ast, |
| 281 | + String newIterVarPrefix, |
| 282 | + String newAccuVarPrefix, |
| 283 | + boolean incrementSerially) { |
277 | 284 | CelNavigableMutableAst navigableMutableAst = CelNavigableMutableAst.fromAst(ast); |
278 | 285 | Predicate<CelNavigableMutableExpr> comprehensionIdentifierPredicate = x -> true; |
279 | 286 | comprehensionIdentifierPredicate = |
280 | 287 | comprehensionIdentifierPredicate |
281 | 288 | .and(node -> node.getKind().equals(Kind.COMPREHENSION)) |
282 | 289 | .and(node -> !node.expr().comprehension().iterVar().startsWith(newIterVarPrefix)) |
283 | | - .and(node -> !node.expr().comprehension().accuVar().startsWith(newResultPrefix)); |
| 290 | + .and(node -> !node.expr().comprehension().accuVar().startsWith(newAccuVarPrefix)); |
284 | 291 |
|
285 | 292 | LinkedHashMap<CelNavigableMutableExpr, MangledComprehensionType> comprehensionsToMangle = |
286 | 293 | navigableMutableAst |
@@ -352,18 +359,43 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames( |
352 | 359 | // The map that we'll eventually return to the caller. |
353 | 360 | HashMap<MangledComprehensionName, MangledComprehensionType> mangledIdentNamesToType = |
354 | 361 | new HashMap<>(); |
| 362 | + // Intermediary table used for the purposes of generating a unique mangled variable name. |
| 363 | + Table<Integer, MangledComprehensionType, MangledComprehensionName> comprehensionLevelToType = |
| 364 | + HashBasedTable.create(); |
355 | 365 | CelMutableExpr mutatedComprehensionExpr = navigableMutableAst.getAst().expr(); |
356 | 366 | CelMutableSource newSource = navigableMutableAst.getAst().source(); |
357 | 367 | int iterCount = 0; |
358 | 368 | for (Entry<CelNavigableMutableExpr, MangledComprehensionType> comprehensionEntry : |
359 | 369 | comprehensionsToMangle.entrySet()) { |
360 | | - String mangledIterVarName = newIterVarPrefix + ":" + iterCount; |
361 | | - String mangledResultName = newResultPrefix + ":" + iterCount; |
362 | | - MangledComprehensionName mangledComprehensionName = |
363 | | - MangledComprehensionName.of(mangledIterVarName, mangledResultName); |
364 | | - mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntry.getValue()); |
| 370 | + CelNavigableMutableExpr comprehensionNode = comprehensionEntry.getKey(); |
| 371 | + MangledComprehensionType comprehensionEntryType = comprehensionEntry.getValue(); |
| 372 | + |
| 373 | + CelMutableExpr comprehensionExpr = comprehensionNode.expr(); |
| 374 | + MangledComprehensionName mangledComprehensionName; |
| 375 | + if (incrementSerially) { |
| 376 | + // In case of applying CSE via cascaded cel.binds, not only is mangling based on level/types |
| 377 | + // meaningless (because all comprehensions are nested anyways, thus all indices would be |
| 378 | + // uinque), |
| 379 | + // it can lead to an erroneous result due to extracting a common subexpr with accu_var at |
| 380 | + // the wrong scope. |
| 381 | + // Example: "[1].exists(k, k > 1) && [2].exists(l, l > 1). The loop step for both branches |
| 382 | + // are identical, but shouldn't be extracted. |
| 383 | + String mangledIterVarName = newIterVarPrefix + ":" + iterCount; |
| 384 | + String mangledResultName = newAccuVarPrefix + ":" + iterCount; |
| 385 | + mangledComprehensionName = |
| 386 | + MangledComprehensionName.of(mangledIterVarName, mangledResultName); |
| 387 | + mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntry.getValue()); |
| 388 | + } else { |
| 389 | + mangledComprehensionName = |
| 390 | + getMangledComprehensionName( |
| 391 | + newIterVarPrefix, |
| 392 | + newAccuVarPrefix, |
| 393 | + comprehensionNode, |
| 394 | + comprehensionLevelToType, |
| 395 | + comprehensionEntryType); |
| 396 | + } |
| 397 | + mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntryType); |
365 | 398 |
|
366 | | - CelMutableExpr comprehensionExpr = comprehensionEntry.getKey().expr(); |
367 | 399 | String iterVar = comprehensionExpr.comprehension().iterVar(); |
368 | 400 | String accuVar = comprehensionExpr.comprehension().accuVar(); |
369 | 401 | mutatedComprehensionExpr = |
@@ -396,6 +428,45 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames( |
396 | 428 | ImmutableMap.copyOf(mangledIdentNamesToType)); |
397 | 429 | } |
398 | 430 |
|
| 431 | + private static MangledComprehensionName getMangledComprehensionName( |
| 432 | + String newIterVarPrefix, |
| 433 | + String newResultPrefix, |
| 434 | + CelNavigableMutableExpr comprehensionNode, |
| 435 | + Table<Integer, MangledComprehensionType, MangledComprehensionName> comprehensionLevelToType, |
| 436 | + MangledComprehensionType comprehensionEntryType) { |
| 437 | + MangledComprehensionName mangledComprehensionName; |
| 438 | + int comprehensionNestingLevel = countComprehensionNestingLevel(comprehensionNode); |
| 439 | + if (comprehensionLevelToType.contains(comprehensionNestingLevel, comprehensionEntryType)) { |
| 440 | + mangledComprehensionName = |
| 441 | + comprehensionLevelToType.get(comprehensionNestingLevel, comprehensionEntryType); |
| 442 | + } else { |
| 443 | + // First time encountering the pair of <ComprehensionLevel, CelType>. Generate a unique |
| 444 | + // mangled variable name for this. |
| 445 | + int uniqueTypeIdx = comprehensionLevelToType.row(comprehensionNestingLevel).size(); |
| 446 | + String mangledIterVarName = |
| 447 | + newIterVarPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx; |
| 448 | + String mangledResultName = |
| 449 | + newResultPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx; |
| 450 | + mangledComprehensionName = MangledComprehensionName.of(mangledIterVarName, mangledResultName); |
| 451 | + comprehensionLevelToType.put( |
| 452 | + comprehensionNestingLevel, comprehensionEntryType, mangledComprehensionName); |
| 453 | + } |
| 454 | + return mangledComprehensionName; |
| 455 | + } |
| 456 | + |
| 457 | + private static int countComprehensionNestingLevel(CelNavigableMutableExpr comprehensionExpr) { |
| 458 | + int nestedLevel = 0; |
| 459 | + Optional<CelNavigableMutableExpr> maybeParent = comprehensionExpr.parent(); |
| 460 | + while (maybeParent.isPresent()) { |
| 461 | + if (maybeParent.get().getKind().equals(Kind.COMPREHENSION)) { |
| 462 | + nestedLevel++; |
| 463 | + } |
| 464 | + |
| 465 | + maybeParent = maybeParent.get().parent(); |
| 466 | + } |
| 467 | + return nestedLevel; |
| 468 | + } |
| 469 | + |
399 | 470 | /** |
400 | 471 | * Replaces a subtree in the given expression node. This operation is intended for AST |
401 | 472 | * optimization purposes. |
|
0 commit comments