diff --git a/internal/eval/compile.go b/internal/eval/compile.go index 599c349c..41c57b51 100644 --- a/internal/eval/compile.go +++ b/internal/eval/compile.go @@ -26,7 +26,7 @@ func (e *BoolEvaler) Eval(env Env) (types.Boolean, error) { func Compile(p *ast.Policy) BoolEvaler { p = foldPolicy(p) node := policyToNode(p).AsIsNode() - return BoolEvaler{eval: toEval(node)} + return BoolEvaler{eval: ToEval(node)} } func policyToNode(p *ast.Policy) ast.Node { diff --git a/internal/eval/convert.go b/internal/eval/convert.go index 27420504..86dca61f 100644 --- a/internal/eval/convert.go +++ b/internal/eval/convert.go @@ -8,28 +8,28 @@ import ( "github.com/cedar-policy/cedar-go/x/exp/ast" ) -func toEval(n ast.IsNode) Evaler { +func ToEval(n ast.IsNode) Evaler { switch v := n.(type) { case ast.NodeTypeAccess: - return newAttributeAccessEval(toEval(v.Arg), v.Value) + return newAttributeAccessEval(ToEval(v.Arg), v.Value) case ast.NodeTypeHas: - return newHasEval(toEval(v.Arg), v.Value) + return newHasEval(ToEval(v.Arg), v.Value) case ast.NodeTypeGetTag: - return newGetTagEval(toEval(v.Left), toEval(v.Right)) + return newGetTagEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeHasTag: - return newHasTagEval(toEval(v.Left), toEval(v.Right)) + return newHasTagEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeLike: - return newLikeEval(toEval(v.Arg), v.Value) + return newLikeEval(ToEval(v.Arg), v.Value) case ast.NodeTypeIfThenElse: - return newIfThenElseEval(toEval(v.If), toEval(v.Then), toEval(v.Else)) + return newIfThenElseEval(ToEval(v.If), ToEval(v.Then), ToEval(v.Else)) case ast.NodeTypeIs: - return newIsEval(toEval(v.Left), v.EntityType) + return newIsEval(ToEval(v.Left), v.EntityType) case ast.NodeTypeIsIn: - return newIsInEval(toEval(v.Left), v.EntityType, toEval(v.Entity)) + return newIsInEval(ToEval(v.Left), v.EntityType, ToEval(v.Entity)) case ast.NodeTypeExtensionCall: args := make([]Evaler, len(v.Args)) for i, a := range v.Args { - args[i] = toEval(a) + args[i] = ToEval(a) } return newExtensionEval(v.Name, args) case ast.NodeValue: @@ -37,19 +37,19 @@ func toEval(n ast.IsNode) Evaler { case ast.NodeTypeRecord: m := make(map[types.String]Evaler, len(v.Elements)) for _, e := range v.Elements { - m[e.Key] = toEval(e.Value) + m[e.Key] = ToEval(e.Value) } return newRecordLiteralEval(m) case ast.NodeTypeSet: s := make([]Evaler, len(v.Elements)) for i, e := range v.Elements { - s[i] = toEval(e) + s[i] = ToEval(e) } return newSetLiteralEval(s) case ast.NodeTypeNegate: - return newNegateEval(toEval(v.Arg)) + return newNegateEval(ToEval(v.Arg)) case ast.NodeTypeNot: - return newNotEval(toEval(v.Arg)) + return newNotEval(ToEval(v.Arg)) case ast.NodeTypeVariable: switch v.Name { case consts.Principal, consts.Action, consts.Resource, consts.Context: @@ -58,37 +58,37 @@ func toEval(n ast.IsNode) Evaler { panic(fmt.Errorf("unknown variable: %v", v.Name)) } case ast.NodeTypeIn: - return newInEval(toEval(v.Left), toEval(v.Right)) + return newInEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeAnd: - return newAndEval(toEval(v.Left), toEval(v.Right)) + return newAndEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeOr: - return newOrEval(toEval(v.Left), toEval(v.Right)) + return newOrEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeEquals: - return newEqualEval(toEval(v.Left), toEval(v.Right)) + return newEqualEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeNotEquals: - return newNotEqualEval(toEval(v.Left), toEval(v.Right)) + return newNotEqualEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeGreaterThan: - return newComparableValueGreaterThanEval(toEval(v.Left), toEval(v.Right)) + return newComparableValueGreaterThanEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeGreaterThanOrEqual: - return newComparableValueGreaterThanOrEqualEval(toEval(v.Left), toEval(v.Right)) + return newComparableValueGreaterThanOrEqualEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeLessThan: - return newComparableValueLessThanEval(toEval(v.Left), toEval(v.Right)) + return newComparableValueLessThanEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeLessThanOrEqual: - return newComparableValueLessThanOrEqualEval(toEval(v.Left), toEval(v.Right)) + return newComparableValueLessThanOrEqualEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeSub: - return newSubtractEval(toEval(v.Left), toEval(v.Right)) + return newSubtractEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeAdd: - return newAddEval(toEval(v.Left), toEval(v.Right)) + return newAddEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeMult: - return newMultiplyEval(toEval(v.Left), toEval(v.Right)) + return newMultiplyEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeContains: - return newContainsEval(toEval(v.Left), toEval(v.Right)) + return newContainsEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeContainsAll: - return newContainsAllEval(toEval(v.Left), toEval(v.Right)) + return newContainsAllEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeContainsAny: - return newContainsAnyEval(toEval(v.Left), toEval(v.Right)) + return newContainsAnyEval(ToEval(v.Left), ToEval(v.Right)) case ast.NodeTypeIsEmpty: - return newIsEmptyEval(toEval(v.Arg)) + return newIsEmptyEval(ToEval(v.Arg)) default: panic(fmt.Sprintf("unknown node type %T", v)) } diff --git a/internal/eval/convert_test.go b/internal/eval/convert_test.go index 318aa7a3..8004ad75 100644 --- a/internal/eval/convert_test.go +++ b/internal/eval/convert_test.go @@ -367,7 +367,7 @@ func TestToEval(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - e := toEval(tt.in.AsIsNode()) + e := ToEval(tt.in.AsIsNode()) out, err := e.Eval(Env{ Principal: types.NewEntityUID("Actor", "principal"), Action: types.NewEntityUID("Action", "test"), @@ -389,13 +389,13 @@ func TestToEval(t *testing.T) { func TestToEvalPanic(t *testing.T) { t.Parallel() testutil.Panic(t, func() { - _ = toEval(ast.Node{}.AsIsNode()) + _ = ToEval(ast.Node{}.AsIsNode()) }) } func TestToEvalVariablePanic(t *testing.T) { t.Parallel() testutil.Panic(t, func() { - _ = toEval(ast.NodeTypeVariable{Name: "bananas"}) + _ = ToEval(ast.NodeTypeVariable{Name: "bananas"}) }) } diff --git a/x/exp/eval/eval.go b/x/exp/eval/eval.go new file mode 100644 index 00000000..889f1ff4 --- /dev/null +++ b/x/exp/eval/eval.go @@ -0,0 +1,17 @@ +// Package eval provides a simple interface for evaluating a policy node in a given environment. +package eval + +import ( + "github.com/cedar-policy/cedar-go/internal/eval" + "github.com/cedar-policy/cedar-go/types" + "github.com/cedar-policy/cedar-go/x/exp/ast" +) + +// Env is the environment for evaluating a policy. +type Env = eval.Env + +// Eval evaluates a policy node in the given environment. +func Eval(n ast.IsNode, env Env) (types.Value, error) { + evaler := eval.ToEval(n) + return evaler.Eval(env) +} diff --git a/x/exp/eval/eval_test.go b/x/exp/eval/eval_test.go new file mode 100644 index 00000000..ef9630dd --- /dev/null +++ b/x/exp/eval/eval_test.go @@ -0,0 +1,400 @@ +package eval + +import ( + "net/netip" + "testing" + "time" + + "github.com/cedar-policy/cedar-go/internal/testutil" + "github.com/cedar-policy/cedar-go/types" + "github.com/cedar-policy/cedar-go/x/exp/ast" +) + +func TestToEval(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in ast.Node + out types.Value + err func(testutil.TB, error) + }{ + { + "access", + ast.Value(types.NewRecord(types.RecordMap{"key": types.Long(42)})).Access("key"), + types.Long(42), + testutil.OK, + }, + { + "has", + ast.Value(types.NewRecord(types.RecordMap{"key": types.Long(42)})).Has("key"), + types.True, + testutil.OK, + }, + { + "getTag", + ast.EntityUID("T", "ID").GetTag(ast.String("key")), + types.Long(42), + testutil.OK, + }, + { + "hasTag", + ast.EntityUID("T", "ID").HasTag(ast.String("key")), + types.True, + testutil.OK, + }, + { + "like", + ast.String("test").Like(types.Pattern{}), + types.False, + testutil.OK, + }, + { + "if", + ast.IfThenElse(ast.True(), ast.Long(42), ast.Long(43)), + types.Long(42), + testutil.OK, + }, + { + "is", + ast.EntityUID("T", "42").Is("T"), + types.True, + testutil.OK, + }, + { + "isIn", + ast.EntityUID("T", "42").IsIn("T", ast.EntityUID("T", "42")), + types.True, + testutil.OK, + }, + { + "value", + ast.Long(42), + types.Long(42), + testutil.OK, + }, + { + "record", + ast.Record(ast.Pairs{{Key: "key", Value: ast.Long(42)}}), + types.NewRecord(types.RecordMap{"key": types.Long(42)}), + testutil.OK, + }, + { + "set", + ast.Set(ast.Long(42)), + types.NewSet(types.Long(42)), + testutil.OK, + }, + { + "negate", + ast.Negate(ast.Long(42)), + types.Long(-42), + testutil.OK, + }, + { + "not", + ast.Not(ast.True()), + types.False, + testutil.OK, + }, + { + "principal", + ast.Principal(), + types.NewEntityUID("Actor", "principal"), + testutil.OK, + }, + { + "action", + ast.Action(), + types.NewEntityUID("Action", "test"), + testutil.OK, + }, + { + "resource", + ast.Resource(), + types.NewEntityUID("Resource", "database"), + testutil.OK, + }, + { + "context", + ast.Context(), + types.Record{}, + testutil.OK, + }, + { + "in", + ast.EntityUID("T", "42").In(ast.EntityUID("T", "43")), + types.False, + testutil.OK, + }, + { + "and", + ast.True().And(ast.False()), + types.False, + testutil.OK, + }, + { + "or", + ast.True().Or(ast.False()), + types.True, + testutil.OK, + }, + { + "equals", + ast.Long(42).Equal(ast.Long(43)), + types.False, + testutil.OK, + }, + { + "notEquals", + ast.Long(42).NotEqual(ast.Long(43)), + types.True, + testutil.OK, + }, + { + "greaterThan", + ast.Long(42).GreaterThan(ast.Long(43)), + types.False, + testutil.OK, + }, + { + "greaterThanOrEqual", + ast.Long(42).GreaterThanOrEqual(ast.Long(43)), + types.False, + testutil.OK, + }, + { + "lessThan", + ast.Long(42).LessThan(ast.Long(43)), + types.True, + testutil.OK, + }, + { + "lessThanOrEqual", + ast.Long(42).LessThanOrEqual(ast.Long(43)), + types.True, + testutil.OK, + }, + { + "sub", + ast.Long(42).Subtract(ast.Long(2)), + types.Long(40), + testutil.OK, + }, + { + "add", + ast.Long(40).Add(ast.Long(2)), + types.Long(42), + testutil.OK, + }, + { + "mult", + ast.Long(6).Multiply(ast.Long(7)), + types.Long(42), + testutil.OK, + }, + { + "contains", + ast.Value(types.NewSet(types.Long(42))).Contains(ast.Long(42)), + types.True, + testutil.OK, + }, + { + "containsAll", + ast.Value(types.NewSet(types.Long(42), types.Long(43), types.Long(44))).ContainsAll(ast.Value(types.NewSet(types.Long(42), types.Long(43)))), + types.True, + testutil.OK, + }, + { + "containsAny", + ast.Value(types.NewSet(types.Long(42), types.Long(43), types.Long(44))).ContainsAny(ast.Value(types.NewSet(types.Long(1), types.Long(42)))), + types.True, + testutil.OK, + }, + { + "isEmpty", + ast.Value(types.NewSet(types.Long(42), types.Long(43), types.Long(44))).IsEmpty(), + types.False, + testutil.OK, + }, + { + "ip", + ast.ExtensionCall("ip", ast.String("127.0.0.42/16")), + types.IPAddr(netip.MustParsePrefix("127.0.0.42/16")), + testutil.OK, + }, + { + "decimal", + ast.ExtensionCall("decimal", ast.String("42.42")), + testutil.Must(types.NewDecimal(4242, -2)), + testutil.OK, + }, + { + "datetime", + ast.ExtensionCall("datetime", ast.String("1970-01-01T00:00:00.001Z")), + types.NewDatetime(time.UnixMilli(1)), + testutil.OK, + }, + { + "duration", + ast.ExtensionCall("duration", ast.String("1ms")), + types.NewDuration(1 * time.Millisecond), + testutil.OK, + }, + { + "toDate", + ast.ExtensionCall("toDate", ast.Value(types.NewDatetime(time.UnixMilli(1)))), + types.NewDatetime(time.UnixMilli(0)), + testutil.OK, + }, + { + "toTime", + ast.ExtensionCall("toTime", ast.Value(types.NewDatetime(time.UnixMilli(1)))), + types.NewDuration(1 * time.Millisecond), + testutil.OK, + }, + { + "toDays", + ast.ExtensionCall("toDays", ast.Value(types.NewDuration(time.Duration(0)))), + types.Long(0), + testutil.OK, + }, + { + "toHours", + ast.ExtensionCall("toHours", ast.Value(types.NewDuration(time.Duration(0)))), + types.Long(0), + testutil.OK, + }, + { + "toMinutes", + ast.ExtensionCall("toMinutes", ast.Value(types.NewDuration(time.Duration(0)))), + types.Long(0), + testutil.OK, + }, + { + "toSeconds", + ast.ExtensionCall("toSeconds", ast.Value(types.NewDuration(time.Duration(0)))), + types.Long(0), + testutil.OK, + }, + { + "toMilliseconds", + ast.ExtensionCall("toMilliseconds", ast.Value(types.NewDuration(time.Duration(0)))), + types.Long(0), + testutil.OK, + }, + { + "offset", + ast.ExtensionCall("offset", ast.Value(types.NewDatetime(time.UnixMilli(0))), ast.Value(types.NewDuration(1*time.Millisecond))), + types.NewDatetime(time.UnixMilli(1)), + testutil.OK, + }, + { + "durationSince", + ast.ExtensionCall("durationSince", ast.Value(types.NewDatetime(time.UnixMilli(1))), ast.Value(types.NewDatetime(time.UnixMilli(1)))), + types.NewDuration(time.Duration(0)), + testutil.OK, + }, + + { + "lessThan", + ast.ExtensionCall("lessThan", ast.Value(testutil.Must(types.NewDecimal(42, 0))), ast.Value(testutil.Must(types.NewDecimalFromInt(43)))), + types.True, + testutil.OK, + }, + { + "lessThanOrEqual", + ast.ExtensionCall("lessThanOrEqual", ast.Value(testutil.Must(types.NewDecimal(42, 0))), ast.Value(testutil.Must(types.NewDecimalFromInt(43)))), + types.True, + testutil.OK, + }, + { + "greaterThan", + ast.ExtensionCall("greaterThan", ast.Value(testutil.Must(types.NewDecimal(42, 0))), ast.Value(testutil.Must(types.NewDecimalFromInt(43)))), + types.False, + testutil.OK, + }, + { + "greaterThanOrEqual", + ast.ExtensionCall("greaterThanOrEqual", ast.Value(testutil.Must(types.NewDecimal(42, 0))), ast.Value(testutil.Must(types.NewDecimalFromInt(43)))), + types.False, + testutil.OK, + }, + { + "isIpv4", + ast.ExtensionCall("isIpv4", ast.IPAddr(netip.MustParsePrefix("127.0.0.42/16"))), + types.True, + testutil.OK, + }, + { + "isIpv6", + ast.ExtensionCall("isIpv6", ast.IPAddr(netip.MustParsePrefix("::1/16"))), + types.True, + testutil.OK, + }, + { + "isLoopback", + ast.ExtensionCall("isLoopback", ast.IPAddr(netip.MustParsePrefix("127.0.0.1/32"))), + types.True, + testutil.OK, + }, + { + "isMulticast", + ast.ExtensionCall("isMulticast", ast.IPAddr(netip.MustParsePrefix("239.255.255.255/32"))), + types.True, + testutil.OK, + }, + { + "isInRange", + ast.ExtensionCall("isInRange", ast.IPAddr(netip.MustParsePrefix("127.0.0.42/32")), ast.IPAddr(netip.MustParsePrefix("127.0.0.0/16"))), + types.True, + testutil.OK, + }, + { + "extUnknown", + ast.ExtensionCall("unknown", ast.String("hello")), + nil, + testutil.Error, + }, + { + "extArgs", + ast.ExtensionCall("ip", ast.String("1"), ast.String("2")), + nil, + testutil.Error, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + out, err := Eval(tt.in.AsIsNode(), Env{ + Principal: types.NewEntityUID("Actor", "principal"), + Action: types.NewEntityUID("Action", "test"), + Resource: types.NewEntityUID("Resource", "database"), + Context: types.Record{}, + Entities: types.EntityMap{ + types.NewEntityUID("T", "ID"): types.Entity{ + Tags: types.NewRecord(types.RecordMap{"key": types.Long(42)}), + }, + }, + }) + tt.err(t, err) + testutil.Equals(t, out, tt.out) + }) + } + +} + +func TestToEvalPanic(t *testing.T) { + t.Parallel() + testutil.Panic(t, func() { + _, _ = Eval(ast.Node{}.AsIsNode(), Env{}) + }) +} + +func TestToEvalVariablePanic(t *testing.T) { + t.Parallel() + testutil.Panic(t, func() { + _, _ = Eval(ast.NodeTypeVariable{Name: "bananas"}, Env{}) + }) +}