Skip to content

Commit a2ac28e

Browse files
l46kokcopybara-github
authored andcommitted
Enforce strictness of type function
This ensures that: - type(unknown) -> unknown - type(error) -> error PiperOrigin-RevId: 686325313
1 parent 7e6578c commit a2ac28e

File tree

4 files changed

+46
-5
lines changed

4 files changed

+46
-5
lines changed

runtime/src/main/java/dev/cel/runtime/CelUnknownSet.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public static CelUnknownSet create(ImmutableSet<CelAttribute> attributes) {
4646
return create(attributes, ImmutableSet.of());
4747
}
4848

49-
static CelUnknownSet create(Long... unknownExprIds) {
49+
public static CelUnknownSet create(Long... unknownExprIds) {
5050
return create(ImmutableSet.copyOf(unknownExprIds));
5151
}
5252

runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,14 @@ private IntermediateResult evalInternal(ExecutionFrame frame, CelExpr expr)
222222
}
223223
}
224224

225-
private boolean isUnknownValue(Object value) {
225+
private static boolean isUnknownValue(Object value) {
226226
return value instanceof CelUnknownSet || InterpreterUtil.isUnknown(value);
227227
}
228228

229+
private static boolean isUnknownOrError(Object value) {
230+
return isUnknownValue(value) || value instanceof Exception;
231+
}
232+
229233
private Object evalConstant(
230234
ExecutionFrame unusedFrame, CelExpr unusedExpr, CelConstant constExpr) {
231235
switch (constExpr.getKind()) {
@@ -593,6 +597,10 @@ private IntermediateResult evalType(ExecutionFrame frame, CelCall callExpr)
593597
throws InterpreterException {
594598
CelExpr typeExprArg = callExpr.args().get(0);
595599
IntermediateResult argResult = evalInternal(frame, typeExprArg);
600+
// Type is a strict function. Early return if the argument is an error or an unknown.
601+
if (isUnknownOrError(argResult.value())) {
602+
return argResult;
603+
}
596604

597605
CelType checkedType =
598606
ast.getType(typeExprArg.id())
@@ -682,9 +690,7 @@ private IntermediateResult evalBoolean(ExecutionFrame frame, CelExpr expr, boole
682690
throws InterpreterException {
683691
IntermediateResult value = strict ? evalInternal(frame, expr) : evalNonstrictly(frame, expr);
684692

685-
if (!(value.value() instanceof Boolean)
686-
&& !isUnknownValue(value.value())
687-
&& !(value.value() instanceof Exception)) {
693+
if (!(value.value() instanceof Boolean) && !isUnknownOrError(value.value())) {
688694
throw new InterpreterException.Builder("expected boolean value, found: %s", value.value())
689695
.setErrorCode(CelErrorCode.INVALID_ARGUMENT)
690696
.setLocation(metadata, expr.id())

runtime/src/test/resources/unknownResultSet.baseline

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,30 @@ result: unknown {
515515
exprs: 3
516516
exprs: 6
517517
}
518+
519+
520+
Source: type(x.single_int32)
521+
declare x {
522+
value google.api.expr.test.v1.proto3.TestAllTypes
523+
}
524+
declare f {
525+
function f int.(int) -> bool
526+
}
527+
=====>
528+
bindings: {}
529+
result: unknown {
530+
exprs: 2
531+
}
532+
533+
534+
Source: type(1 / 0 > 2)
535+
declare x {
536+
value google.api.expr.test.v1.proto3.TestAllTypes
537+
}
538+
declare f {
539+
function f int.(int) -> bool
540+
}
541+
=====>
542+
bindings: {}
543+
error: evaluation error: / by zero
544+
error_code: DIVIDE_BY_ZERO

testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,6 +1298,14 @@ public void unknownResultSet() {
12981298
// message with multiple unknowns => unknownSet
12991299
source = "TestAllTypes{single_int32: x.single_int32, single_int64: x.single_int64}";
13001300
runTest();
1301+
1302+
// type(unknown) -> unknown
1303+
source = "type(x.single_int32)";
1304+
runTest();
1305+
1306+
// type(error) -> error
1307+
source = "type(1 / 0 > 2)";
1308+
runTest();
13011309
}
13021310

13031311
@Test

0 commit comments

Comments
 (0)