Skip to content

Commit 317238b

Browse files
authored
Merge pull request #94 from jaredzhou/expose-eval
expose Eval function for custom partial eval
2 parents ccdc86f + c2a25ad commit 317238b

File tree

5 files changed

+451
-34
lines changed

5 files changed

+451
-34
lines changed

internal/eval/compile.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func (e *BoolEvaler) Eval(env Env) (types.Boolean, error) {
2626
func Compile(p *ast.Policy) BoolEvaler {
2727
p = foldPolicy(p)
2828
node := policyToNode(p).AsIsNode()
29-
return BoolEvaler{eval: toEval(node)}
29+
return BoolEvaler{eval: ToEval(node)}
3030
}
3131

3232
func policyToNode(p *ast.Policy) ast.Node {

internal/eval/convert.go

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,48 +8,48 @@ import (
88
"github.com/cedar-policy/cedar-go/x/exp/ast"
99
)
1010

11-
func toEval(n ast.IsNode) Evaler {
11+
func ToEval(n ast.IsNode) Evaler {
1212
switch v := n.(type) {
1313
case ast.NodeTypeAccess:
14-
return newAttributeAccessEval(toEval(v.Arg), v.Value)
14+
return newAttributeAccessEval(ToEval(v.Arg), v.Value)
1515
case ast.NodeTypeHas:
16-
return newHasEval(toEval(v.Arg), v.Value)
16+
return newHasEval(ToEval(v.Arg), v.Value)
1717
case ast.NodeTypeGetTag:
18-
return newGetTagEval(toEval(v.Left), toEval(v.Right))
18+
return newGetTagEval(ToEval(v.Left), ToEval(v.Right))
1919
case ast.NodeTypeHasTag:
20-
return newHasTagEval(toEval(v.Left), toEval(v.Right))
20+
return newHasTagEval(ToEval(v.Left), ToEval(v.Right))
2121
case ast.NodeTypeLike:
22-
return newLikeEval(toEval(v.Arg), v.Value)
22+
return newLikeEval(ToEval(v.Arg), v.Value)
2323
case ast.NodeTypeIfThenElse:
24-
return newIfThenElseEval(toEval(v.If), toEval(v.Then), toEval(v.Else))
24+
return newIfThenElseEval(ToEval(v.If), ToEval(v.Then), ToEval(v.Else))
2525
case ast.NodeTypeIs:
26-
return newIsEval(toEval(v.Left), v.EntityType)
26+
return newIsEval(ToEval(v.Left), v.EntityType)
2727
case ast.NodeTypeIsIn:
28-
return newIsInEval(toEval(v.Left), v.EntityType, toEval(v.Entity))
28+
return newIsInEval(ToEval(v.Left), v.EntityType, ToEval(v.Entity))
2929
case ast.NodeTypeExtensionCall:
3030
args := make([]Evaler, len(v.Args))
3131
for i, a := range v.Args {
32-
args[i] = toEval(a)
32+
args[i] = ToEval(a)
3333
}
3434
return newExtensionEval(v.Name, args)
3535
case ast.NodeValue:
3636
return newLiteralEval(v.Value)
3737
case ast.NodeTypeRecord:
3838
m := make(map[types.String]Evaler, len(v.Elements))
3939
for _, e := range v.Elements {
40-
m[e.Key] = toEval(e.Value)
40+
m[e.Key] = ToEval(e.Value)
4141
}
4242
return newRecordLiteralEval(m)
4343
case ast.NodeTypeSet:
4444
s := make([]Evaler, len(v.Elements))
4545
for i, e := range v.Elements {
46-
s[i] = toEval(e)
46+
s[i] = ToEval(e)
4747
}
4848
return newSetLiteralEval(s)
4949
case ast.NodeTypeNegate:
50-
return newNegateEval(toEval(v.Arg))
50+
return newNegateEval(ToEval(v.Arg))
5151
case ast.NodeTypeNot:
52-
return newNotEval(toEval(v.Arg))
52+
return newNotEval(ToEval(v.Arg))
5353
case ast.NodeTypeVariable:
5454
switch v.Name {
5555
case consts.Principal, consts.Action, consts.Resource, consts.Context:
@@ -58,37 +58,37 @@ func toEval(n ast.IsNode) Evaler {
5858
panic(fmt.Errorf("unknown variable: %v", v.Name))
5959
}
6060
case ast.NodeTypeIn:
61-
return newInEval(toEval(v.Left), toEval(v.Right))
61+
return newInEval(ToEval(v.Left), ToEval(v.Right))
6262
case ast.NodeTypeAnd:
63-
return newAndEval(toEval(v.Left), toEval(v.Right))
63+
return newAndEval(ToEval(v.Left), ToEval(v.Right))
6464
case ast.NodeTypeOr:
65-
return newOrEval(toEval(v.Left), toEval(v.Right))
65+
return newOrEval(ToEval(v.Left), ToEval(v.Right))
6666
case ast.NodeTypeEquals:
67-
return newEqualEval(toEval(v.Left), toEval(v.Right))
67+
return newEqualEval(ToEval(v.Left), ToEval(v.Right))
6868
case ast.NodeTypeNotEquals:
69-
return newNotEqualEval(toEval(v.Left), toEval(v.Right))
69+
return newNotEqualEval(ToEval(v.Left), ToEval(v.Right))
7070
case ast.NodeTypeGreaterThan:
71-
return newComparableValueGreaterThanEval(toEval(v.Left), toEval(v.Right))
71+
return newComparableValueGreaterThanEval(ToEval(v.Left), ToEval(v.Right))
7272
case ast.NodeTypeGreaterThanOrEqual:
73-
return newComparableValueGreaterThanOrEqualEval(toEval(v.Left), toEval(v.Right))
73+
return newComparableValueGreaterThanOrEqualEval(ToEval(v.Left), ToEval(v.Right))
7474
case ast.NodeTypeLessThan:
75-
return newComparableValueLessThanEval(toEval(v.Left), toEval(v.Right))
75+
return newComparableValueLessThanEval(ToEval(v.Left), ToEval(v.Right))
7676
case ast.NodeTypeLessThanOrEqual:
77-
return newComparableValueLessThanOrEqualEval(toEval(v.Left), toEval(v.Right))
77+
return newComparableValueLessThanOrEqualEval(ToEval(v.Left), ToEval(v.Right))
7878
case ast.NodeTypeSub:
79-
return newSubtractEval(toEval(v.Left), toEval(v.Right))
79+
return newSubtractEval(ToEval(v.Left), ToEval(v.Right))
8080
case ast.NodeTypeAdd:
81-
return newAddEval(toEval(v.Left), toEval(v.Right))
81+
return newAddEval(ToEval(v.Left), ToEval(v.Right))
8282
case ast.NodeTypeMult:
83-
return newMultiplyEval(toEval(v.Left), toEval(v.Right))
83+
return newMultiplyEval(ToEval(v.Left), ToEval(v.Right))
8484
case ast.NodeTypeContains:
85-
return newContainsEval(toEval(v.Left), toEval(v.Right))
85+
return newContainsEval(ToEval(v.Left), ToEval(v.Right))
8686
case ast.NodeTypeContainsAll:
87-
return newContainsAllEval(toEval(v.Left), toEval(v.Right))
87+
return newContainsAllEval(ToEval(v.Left), ToEval(v.Right))
8888
case ast.NodeTypeContainsAny:
89-
return newContainsAnyEval(toEval(v.Left), toEval(v.Right))
89+
return newContainsAnyEval(ToEval(v.Left), ToEval(v.Right))
9090
case ast.NodeTypeIsEmpty:
91-
return newIsEmptyEval(toEval(v.Arg))
91+
return newIsEmptyEval(ToEval(v.Arg))
9292
default:
9393
panic(fmt.Sprintf("unknown node type %T", v))
9494
}

internal/eval/convert_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ func TestToEval(t *testing.T) {
367367
tt := tt
368368
t.Run(tt.name, func(t *testing.T) {
369369
t.Parallel()
370-
e := toEval(tt.in.AsIsNode())
370+
e := ToEval(tt.in.AsIsNode())
371371
out, err := e.Eval(Env{
372372
Principal: types.NewEntityUID("Actor", "principal"),
373373
Action: types.NewEntityUID("Action", "test"),
@@ -389,13 +389,13 @@ func TestToEval(t *testing.T) {
389389
func TestToEvalPanic(t *testing.T) {
390390
t.Parallel()
391391
testutil.Panic(t, func() {
392-
_ = toEval(ast.Node{}.AsIsNode())
392+
_ = ToEval(ast.Node{}.AsIsNode())
393393
})
394394
}
395395

396396
func TestToEvalVariablePanic(t *testing.T) {
397397
t.Parallel()
398398
testutil.Panic(t, func() {
399-
_ = toEval(ast.NodeTypeVariable{Name: "bananas"})
399+
_ = ToEval(ast.NodeTypeVariable{Name: "bananas"})
400400
})
401401
}

x/exp/eval/eval.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Package eval provides a simple interface for evaluating a policy node in a given environment.
2+
package eval
3+
4+
import (
5+
"github.com/cedar-policy/cedar-go/internal/eval"
6+
"github.com/cedar-policy/cedar-go/types"
7+
"github.com/cedar-policy/cedar-go/x/exp/ast"
8+
)
9+
10+
// Env is the environment for evaluating a policy.
11+
type Env = eval.Env
12+
13+
// Eval evaluates a policy node in the given environment.
14+
func Eval(n ast.IsNode, env Env) (types.Value, error) {
15+
evaler := eval.ToEval(n)
16+
return evaler.Eval(env)
17+
}

0 commit comments

Comments
 (0)