Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/eval/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (e *BoolEvaler) Eval(env Env) (types.Boolean, error) {
func Compile(p *ast.Policy) BoolEvaler {
p = foldPolicy(p)
node := policyToNode(p).AsIsNode()
return BoolEvaler{eval: toEval(node)}
return BoolEvaler{eval: ToEval(node)}
}

func policyToNode(p *ast.Policy) ast.Node {
Expand Down
60 changes: 30 additions & 30 deletions internal/eval/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,48 +8,48 @@ import (
"github.com/cedar-policy/cedar-go/x/exp/ast"
)

func toEval(n ast.IsNode) Evaler {
func ToEval(n ast.IsNode) Evaler {
switch v := n.(type) {
case ast.NodeTypeAccess:
return newAttributeAccessEval(toEval(v.Arg), v.Value)
return newAttributeAccessEval(ToEval(v.Arg), v.Value)
case ast.NodeTypeHas:
return newHasEval(toEval(v.Arg), v.Value)
return newHasEval(ToEval(v.Arg), v.Value)
case ast.NodeTypeGetTag:
return newGetTagEval(toEval(v.Left), toEval(v.Right))
return newGetTagEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeHasTag:
return newHasTagEval(toEval(v.Left), toEval(v.Right))
return newHasTagEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeLike:
return newLikeEval(toEval(v.Arg), v.Value)
return newLikeEval(ToEval(v.Arg), v.Value)
case ast.NodeTypeIfThenElse:
return newIfThenElseEval(toEval(v.If), toEval(v.Then), toEval(v.Else))
return newIfThenElseEval(ToEval(v.If), ToEval(v.Then), ToEval(v.Else))
case ast.NodeTypeIs:
return newIsEval(toEval(v.Left), v.EntityType)
return newIsEval(ToEval(v.Left), v.EntityType)
case ast.NodeTypeIsIn:
return newIsInEval(toEval(v.Left), v.EntityType, toEval(v.Entity))
return newIsInEval(ToEval(v.Left), v.EntityType, ToEval(v.Entity))
case ast.NodeTypeExtensionCall:
args := make([]Evaler, len(v.Args))
for i, a := range v.Args {
args[i] = toEval(a)
args[i] = ToEval(a)
}
return newExtensionEval(v.Name, args)
case ast.NodeValue:
return newLiteralEval(v.Value)
case ast.NodeTypeRecord:
m := make(map[types.String]Evaler, len(v.Elements))
for _, e := range v.Elements {
m[e.Key] = toEval(e.Value)
m[e.Key] = ToEval(e.Value)
}
return newRecordLiteralEval(m)
case ast.NodeTypeSet:
s := make([]Evaler, len(v.Elements))
for i, e := range v.Elements {
s[i] = toEval(e)
s[i] = ToEval(e)
}
return newSetLiteralEval(s)
case ast.NodeTypeNegate:
return newNegateEval(toEval(v.Arg))
return newNegateEval(ToEval(v.Arg))
case ast.NodeTypeNot:
return newNotEval(toEval(v.Arg))
return newNotEval(ToEval(v.Arg))
case ast.NodeTypeVariable:
switch v.Name {
case consts.Principal, consts.Action, consts.Resource, consts.Context:
Expand All @@ -58,37 +58,37 @@ func toEval(n ast.IsNode) Evaler {
panic(fmt.Errorf("unknown variable: %v", v.Name))
}
case ast.NodeTypeIn:
return newInEval(toEval(v.Left), toEval(v.Right))
return newInEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeAnd:
return newAndEval(toEval(v.Left), toEval(v.Right))
return newAndEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeOr:
return newOrEval(toEval(v.Left), toEval(v.Right))
return newOrEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeEquals:
return newEqualEval(toEval(v.Left), toEval(v.Right))
return newEqualEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeNotEquals:
return newNotEqualEval(toEval(v.Left), toEval(v.Right))
return newNotEqualEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeGreaterThan:
return newComparableValueGreaterThanEval(toEval(v.Left), toEval(v.Right))
return newComparableValueGreaterThanEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeGreaterThanOrEqual:
return newComparableValueGreaterThanOrEqualEval(toEval(v.Left), toEval(v.Right))
return newComparableValueGreaterThanOrEqualEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeLessThan:
return newComparableValueLessThanEval(toEval(v.Left), toEval(v.Right))
return newComparableValueLessThanEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeLessThanOrEqual:
return newComparableValueLessThanOrEqualEval(toEval(v.Left), toEval(v.Right))
return newComparableValueLessThanOrEqualEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeSub:
return newSubtractEval(toEval(v.Left), toEval(v.Right))
return newSubtractEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeAdd:
return newAddEval(toEval(v.Left), toEval(v.Right))
return newAddEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeMult:
return newMultiplyEval(toEval(v.Left), toEval(v.Right))
return newMultiplyEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeContains:
return newContainsEval(toEval(v.Left), toEval(v.Right))
return newContainsEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeContainsAll:
return newContainsAllEval(toEval(v.Left), toEval(v.Right))
return newContainsAllEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeContainsAny:
return newContainsAnyEval(toEval(v.Left), toEval(v.Right))
return newContainsAnyEval(ToEval(v.Left), ToEval(v.Right))
case ast.NodeTypeIsEmpty:
return newIsEmptyEval(toEval(v.Arg))
return newIsEmptyEval(ToEval(v.Arg))
default:
panic(fmt.Sprintf("unknown node type %T", v))
}
Expand Down
6 changes: 3 additions & 3 deletions internal/eval/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ func TestToEval(t *testing.T) {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
e := toEval(tt.in.AsIsNode())
e := ToEval(tt.in.AsIsNode())
out, err := e.Eval(Env{
Principal: types.NewEntityUID("Actor", "principal"),
Action: types.NewEntityUID("Action", "test"),
Expand All @@ -389,13 +389,13 @@ func TestToEval(t *testing.T) {
func TestToEvalPanic(t *testing.T) {
t.Parallel()
testutil.Panic(t, func() {
_ = toEval(ast.Node{}.AsIsNode())
_ = ToEval(ast.Node{}.AsIsNode())
})
}

func TestToEvalVariablePanic(t *testing.T) {
t.Parallel()
testutil.Panic(t, func() {
_ = toEval(ast.NodeTypeVariable{Name: "bananas"})
_ = ToEval(ast.NodeTypeVariable{Name: "bananas"})
})
}
17 changes: 17 additions & 0 deletions x/exp/eval/eval.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Package eval provides a simple interface for evaluating a policy node in a given environment.
package eval

import (
"github.com/cedar-policy/cedar-go/internal/eval"
"github.com/cedar-policy/cedar-go/types"
"github.com/cedar-policy/cedar-go/x/exp/ast"
)

// Env is the environment for evaluating a policy.
type Env = eval.Env

// Eval evaluates a policy node in the given environment.
func Eval(n ast.IsNode, env Env) (types.Value, error) {
evaler := eval.ToEval(n)
return evaler.Eval(env)
}
Loading