Skip to content

Commit f97e5fa

Browse files
l46kokcopybara-github
authored andcommitted
Support pruning aggregate literals in optionals
PiperOrigin-RevId: 936809712
1 parent 534f651 commit f97e5fa

2 files changed

Lines changed: 63 additions & 14 deletions

File tree

optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
import java.util.HashSet;
5959
import java.util.List;
6060
import java.util.Map;
61-
import java.util.Map.Entry;
6261
import java.util.Optional;
6362

6463
/**
@@ -309,8 +308,12 @@ private Optional<CelMutableAst> maybeFold(
309308
return maybeRewriteOptional(optResult, mutableAst, node.expr());
310309
}
311310

312-
return maybeAdaptEvaluatedResult(result)
313-
.map(celExpr -> astMutator.replaceSubtree(mutableAst, celExpr, node.id()));
311+
CelMutableExpr adaptedResult = maybeAdaptEvaluatedResult(result).orElse(null);
312+
if (adaptedResult == null) {
313+
return Optional.empty();
314+
}
315+
316+
return Optional.of(astMutator.replaceSubtree(mutableAst, adaptedResult, node.id()));
314317
}
315318

316319
private Optional<CelMutableExpr> maybeAdaptEvaluatedResult(Object result) {
@@ -331,7 +334,7 @@ private Optional<CelMutableExpr> maybeAdaptEvaluatedResult(Object result) {
331334
} else if (result instanceof Map<?, ?>) {
332335
Map<?, ?> map = (Map<?, ?>) result;
333336
List<CelMutableMap.Entry> mapEntries = new ArrayList<>();
334-
for (Entry<?, ?> entry : map.entrySet()) {
337+
for (Map.Entry<?, ?> entry : map.entrySet()) {
335338
CelMutableExpr adaptedKey = maybeAdaptEvaluatedResult(entry.getKey()).orElse(null);
336339
if (adaptedKey == null) {
337340
return Optional.empty();
@@ -384,16 +387,15 @@ private Optional<CelMutableAst> maybeRewriteOptional(
384387
return Optional.empty();
385388
}
386389

387-
if (!CelConstant.isConstantValue(unwrappedResult)) {
388-
// Evaluated result is not a constant. Leave the optional as is.
390+
CelMutableExpr adaptedResult = maybeAdaptEvaluatedResult(unwrappedResult).orElse(null);
391+
if (adaptedResult == null) {
392+
// Evaluated result is not an adaptable constant. Leave the optional as is.
389393
return Optional.empty();
390394
}
391395

392396
CelMutableExpr newOptionalOfCall =
393397
CelMutableExpr.ofCall(
394-
CelMutableCall.create(
395-
Function.OPTIONAL_OF.getFunction(),
396-
CelMutableExpr.ofConstant(CelConstant.ofObjectValue(unwrappedResult))));
398+
CelMutableCall.create(Function.OPTIONAL_OF.getFunction(), adaptedResult));
397399

398400
return Optional.of(astMutator.replaceSubtree(mutableAst, newOptionalOfCall, expr.id()));
399401
}
@@ -530,6 +532,37 @@ private Optional<CelMutableAst> maybeShortCircuitCall(
530532
"Folding variadic logical operator is not supported yet.");
531533
}
532534

535+
private boolean isFoldedAggregateLiteral(CelMutableExpr expr) {
536+
if (expr.getKind().equals(Kind.CONSTANT)) {
537+
return true;
538+
}
539+
if (expr.getKind().equals(Kind.LIST)) {
540+
for (CelMutableExpr child : expr.list().elements()) {
541+
if (!isFoldedAggregateLiteral(child)) {
542+
return false;
543+
}
544+
}
545+
return true;
546+
}
547+
if (expr.getKind().equals(Kind.MAP)) {
548+
for (CelMutableExpr.CelMutableMap.Entry entry : expr.map().entries()) {
549+
if (!isFoldedAggregateLiteral(entry.key()) || !isFoldedAggregateLiteral(entry.value())) {
550+
return false;
551+
}
552+
}
553+
return true;
554+
}
555+
if (expr.getKind().equals(Kind.STRUCT)) {
556+
for (CelMutableExpr.CelMutableStruct.Entry entry : expr.struct().entries()) {
557+
if (!isFoldedAggregateLiteral(entry.value())) {
558+
return false;
559+
}
560+
}
561+
return true;
562+
}
563+
return false;
564+
}
565+
533566
private CelMutableAst pruneOptionalElements(CelMutableAst ast) {
534567
ImmutableList<CelMutableExpr> aggregateLiterals =
535568
CelNavigableMutableExpr.fromExpr(ast.expr())
@@ -588,7 +621,7 @@ private CelMutableAst pruneOptionalListElements(CelMutableAst mutableAst, CelMut
588621
continue;
589622
} else if (call.function().equals(Function.OPTIONAL_OF.getFunction())) {
590623
CelMutableExpr arg = call.args().get(0);
591-
if (arg.getKind().equals(Kind.CONSTANT)) {
624+
if (isFoldedAggregateLiteral(arg)) {
592625
updatedElemBuilder.add(call.args().get(0));
593626
continue;
594627
}
@@ -629,7 +662,7 @@ private CelMutableAst pruneOptionalMapElements(CelMutableAst ast, CelMutableExpr
629662
continue;
630663
} else if (call.function().equals(Function.OPTIONAL_OF.getFunction())) {
631664
CelMutableExpr arg = call.args().get(0);
632-
if (arg.getKind().equals(Kind.CONSTANT)) {
665+
if (isFoldedAggregateLiteral(arg)) {
633666
modified = true;
634667
entry.setOptionalEntry(false);
635668
entry.setValue(call.args().get(0));
@@ -670,7 +703,7 @@ private CelMutableAst pruneOptionalStructElements(CelMutableAst ast, CelMutableE
670703
continue;
671704
} else if (call.function().equals(Function.OPTIONAL_OF.getFunction())) {
672705
CelMutableExpr arg = call.args().get(0);
673-
if (arg.getKind().equals(Kind.CONSTANT)) {
706+
if (isFoldedAggregateLiteral(arg)) {
674707
modified = true;
675708
entry.setOptionalEntry(false);
676709
entry.setValue(call.args().get(0));

optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,24 @@ private static Cel setupEnv(CelBuilder celBuilder) {
255255
@TestParameters(
256256
"{source: 'has({\"req\": \"Avail\"}.opt) ? ({\"req\": \"Avail\"}.req + \" \" +"
257257
+ " {\"req\": \"Avail\"}.opt) : {\"req\": \"Avail\"}.req', expected: '\"Avail\"'}")
258+
@TestParameters("{source: 'true || optional.none().hasValue()', expected: 'true'}")
259+
@TestParameters("{source: 'false && map_var[?\"missing\"].hasValue()', expected: 'false'}")
260+
@TestParameters("{source: '{\"hello\": [1, 2]}.?hello', expected: 'optional.of([1, 2])'}")
261+
@TestParameters(
262+
"{source: '{?\"key\": optional.of({\"a\": 1})}', expected: '{\"key\": {\"a\": 1}}'}")
263+
@TestParameters(
264+
"{source: 'TestAllTypes{?repeated_int32: optional.of([1, 2])}',"
265+
+ " expected: 'cel.expr.conformance.proto3.TestAllTypes{repeated_int32: [1, 2]}'}")
266+
@TestParameters("{source: '[?optional.of([1, x])]', expected: '[?optional.of([1, x])]'}")
267+
@TestParameters("{source: '[?optional.of({\"a\": x})]', expected: '[?optional.of({\"a\": x})]'}")
268+
@TestParameters("{source: '[?optional.of({x: 1})]', expected: '[?optional.of({x: 1})]'}")
269+
@TestParameters(
270+
"{source: '[?optional.of(TestAllTypes{single_int32: x})]', expected:"
271+
+ " '[?optional.of(cel.expr.conformance.proto3.TestAllTypes{single_int32: x})]'}")
272+
@TestParameters(
273+
"{source: '[?optional.of(TestAllTypes{single_int32: 1})]', expected:"
274+
+ " '[cel.expr.conformance.proto3.TestAllTypes{single_int32: 1}]'}")
275+
@TestParameters("{source: '[?optional.of(x)]', expected: '[?optional.of(x)]'}")
258276
// TODO: Support folding lists with mixed types. This requires mutable lists.
259277
// @TestParameters("{source: 'dyn([1]) + [1.0]'}")
260278
public void constantFold_success(String source, String expected) throws Exception {
@@ -560,6 +578,4 @@ public void iterationLimitReached_throws() throws Exception {
560578
assertThrows(CelOptimizationException.class, () -> optimizer.optimize(ast));
561579
assertThat(e).hasMessageThat().contains("Optimization failure: Max iteration count reached.");
562580
}
563-
564-
565581
}

0 commit comments

Comments
 (0)