Skip to content

Expressions: coverage #133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 48 additions & 24 deletions datalog/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ func (e *Expression) Evaluate(values map[Variable]*Term, symbols *SymbolTable) (
return nil, fmt.Errorf("datalog: expressions: unknown variable %d", id.(Variable))
}
id = *idptr
default: // do nothing
}
err := s.Push(id)
if err != nil {
return nil, fmt.Errorf("datalog: expressions: stack overflow")
}
s.Push(id)
case OpTypeUnary:
v, err := s.Pop()
if err != nil {
Expand All @@ -45,7 +49,10 @@ func (e *Expression) Evaluate(values map[Variable]*Term, symbols *SymbolTable) (
if err != nil {
return nil, fmt.Errorf("datalog: expressions: unary eval failed: %w", err)
}
s.Push(res)
err = s.Push(res)
if err != nil {
return nil, fmt.Errorf("datalog: expressions: stack overflow")
}
case OpTypeBinary:
right, err := s.Pop()
if err != nil {
Expand All @@ -60,7 +67,10 @@ func (e *Expression) Evaluate(values map[Variable]*Term, symbols *SymbolTable) (
if err != nil {
return nil, fmt.Errorf("datalog: expressions: binary eval failed: %w", err)
}
s.Push(res)
err = s.Push(res)
if err != nil {
return nil, fmt.Errorf("datalog: expressions: stack overflow")
}
default:
return nil, fmt.Errorf("datalog: expressions: unsupported Op: %v", op.Type())
}
Expand All @@ -83,22 +93,31 @@ func (e *Expression) Print(symbols *SymbolTable) string {
id := op.(Value).ID
switch id.Type() {
case TermTypeString:
s.Push(fmt.Sprintf("\"%s\"", symbols.Str(id.(String))))
err := s.Push(fmt.Sprintf("\"%s\"", symbols.Str(id.(String))))
if err != nil {
return "<invalid expression: stack overflow>"
}
case TermTypeVariable:
s.Push(fmt.Sprintf("$%s", symbols.Var(id.(Variable))))
err := s.Push(fmt.Sprintf("$%s", symbols.Var(id.(Variable))))
if err != nil {
return "<invalid expression: stack overflow>"
}
default:
s.Push(id.String())
err := s.Push(id.String())
if err != nil {
return "<invalid expression: stack overflow>"
}
}
case OpTypeUnary:
v, err := s.Pop()
if err != nil {
return "<invalid expression: unary operation failed to pop value>"
}
res := op.(UnaryOp).Print(v)
err = s.Push(res)
if err != nil {
return "<invalid expression: binary operation failed to pop right value>"
return "<invalid expression: stack overflow>"
}
s.Push(res)
case OpTypeBinary:
right, err := s.Pop()
if err != nil {
Expand All @@ -109,7 +128,10 @@ func (e *Expression) Print(symbols *SymbolTable) string {
return "<invalid expression: binary operation failed to pop left value>"
}
res := op.(BinaryOp).Print(left, right)
s.Push(res)
err = s.Push(res)
if err != nil {
return "<invalid expression: stack overflow>"
}
default:
return fmt.Sprintf("<invalid expression: unsupported op type %v>", op.Type())
}
Expand Down Expand Up @@ -160,6 +182,8 @@ func (op UnaryOp) Print(value string) string {
out = fmt.Sprintf("!%s", value)
case UnaryParens:
out = fmt.Sprintf("(%s)", value)
case UnaryLength:
out = fmt.Sprintf("%s.length()", value)
default:
out = fmt.Sprintf("unknown(%s)", value)
}
Expand All @@ -186,7 +210,7 @@ type Negate struct{}
func (Negate) Type() UnaryOpType {
return UnaryNegate
}
func (Negate) Eval(value Term, symbols *SymbolTable) (Term, error) {
func (Negate) Eval(value Term, _ *SymbolTable) (Term, error) {
var out Term
switch value.Type() {
case TermTypeBool:
Expand All @@ -206,7 +230,7 @@ type Parens struct{}
func (Parens) Type() UnaryOpType {
return UnaryParens
}
func (Parens) Eval(value Term, symbols *SymbolTable) (Term, error) {
func (Parens) Eval(value Term, _ *SymbolTable) (Term, error) {
return value, nil
}

Expand All @@ -228,7 +252,7 @@ func (Length) Eval(value Term, symbols *SymbolTable) (Term, error) {
case TermTypeSet:
out = Integer(len(value.(Set)))
default:
return nil, fmt.Errorf("datalog: unexpected Negate value type: %d", value.Type())
return nil, fmt.Errorf("datalog: unexpected Length value type: %d", value.Type())
}
return out, nil
}
Expand Down Expand Up @@ -318,7 +342,7 @@ type LessThan struct{}
func (LessThan) Type() BinaryOpType {
return BinaryLessThan
}
func (LessThan) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (LessThan) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
if g, w := left.Type(), right.Type(); g != w {
return nil, fmt.Errorf("datalog: LessThan type mismatch: %d != %d", g, w)
}
Expand All @@ -344,7 +368,7 @@ type LessOrEqual struct{}
func (LessOrEqual) Type() BinaryOpType {
return BinaryLessOrEqual
}
func (LessOrEqual) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (LessOrEqual) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
if g, w := left.Type(), right.Type(); g != w {
return nil, fmt.Errorf("datalog: LessOrEqual type mismatch: %d != %d", g, w)
}
Expand All @@ -370,7 +394,7 @@ type GreaterThan struct{}
func (GreaterThan) Type() BinaryOpType {
return BinaryGreaterThan
}
func (GreaterThan) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (GreaterThan) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
if g, w := left.Type(), right.Type(); g != w {
return nil, fmt.Errorf("datalog: GreaterThan type mismatch: %d != %d", g, w)
}
Expand All @@ -396,7 +420,7 @@ type GreaterOrEqual struct{}
func (GreaterOrEqual) Type() BinaryOpType {
return BinaryGreaterOrEqual
}
func (GreaterOrEqual) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (GreaterOrEqual) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
if g, w := left.Type(), right.Type(); g != w {
return nil, fmt.Errorf("datalog: GreaterOrEqual type mismatch: %d != %d", g, w)
}
Expand All @@ -422,7 +446,7 @@ type Equal struct{}
func (Equal) Type() BinaryOpType {
return BinaryEqual
}
func (Equal) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Equal) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
if g, w := left.Type(), right.Type(); g != w {
return nil, fmt.Errorf("datalog: Equal type mismatch: %d != %d", g, w)
}
Expand Down Expand Up @@ -510,7 +534,7 @@ type Intersection struct{}
func (Intersection) Type() BinaryOpType {
return BinaryIntersection
}
func (Intersection) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Intersection) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
set, ok := left.(Set)
if !ok {
return nil, errors.New("datalog: Intersection left value must be a Set")
Expand All @@ -530,7 +554,7 @@ type Union struct{}
func (Union) Type() BinaryOpType {
return BinaryUnion
}
func (Union) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Union) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
set, ok := left.(Set)
if !ok {
return nil, errors.New("datalog: Union left value must be a Set")
Expand Down Expand Up @@ -654,7 +678,7 @@ type Sub struct{}
func (Sub) Type() BinaryOpType {
return BinarySub
}
func (Sub) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Sub) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
ileft, ok := left.(Integer)
if !ok {
return nil, fmt.Errorf("datalog: Sub requires left value to be an Integer, got %T", left)
Expand Down Expand Up @@ -682,7 +706,7 @@ type Mul struct{}
func (Mul) Type() BinaryOpType {
return BinaryMul
}
func (Mul) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Mul) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
ileft, ok := left.(Integer)
if !ok {
return nil, fmt.Errorf("datalog: Mul requires left value to be an Integer, got %T", left)
Expand Down Expand Up @@ -711,7 +735,7 @@ type Div struct{}
func (Div) Type() BinaryOpType {
return BinaryDiv
}
func (Div) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Div) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
ileft, ok := left.(Integer)
if !ok {
return nil, fmt.Errorf("datalog: Div requires left value to be an Integer, got %T", left)
Expand All @@ -735,7 +759,7 @@ type And struct{}
func (And) Type() BinaryOpType {
return BinaryAnd
}
func (And) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (And) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
bleft, ok := left.(Bool)
if !ok {
return nil, fmt.Errorf("datalog: And requires left value to be a Bool, got %T", left)
Expand All @@ -755,7 +779,7 @@ type Or struct{}
func (Or) Type() BinaryOpType {
return BinaryOr
}
func (Or) Eval(left Term, right Term, symbols *SymbolTable) (Term, error) {
func (Or) Eval(left Term, right Term, _ *SymbolTable) (Term, error) {
bleft, ok := left.(Bool)
if !ok {
return nil, fmt.Errorf("datalog: Or requires left value to be a Bool, got %T", left)
Expand Down
49 changes: 49 additions & 0 deletions datalog/expressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1188,3 +1188,52 @@ func TestBinaryOr(t *testing.T) {
})
}
}

func TestPrint(t *testing.T) {
syms := SymbolTable{}
syms.Insert("abc")
testCases := []struct {
desc string
expr Expression
res string
}{
{
desc: "number",
expr: Expression{Value{Integer(9)}},
res: "9",
},
{
desc: "string",
expr: Expression{Value{syms.Sym("abc")}},
res: "\"abc\"",
},
{
desc: "unary",
expr: Expression{Value{syms.Sym("abc")}, UnaryOp{Length{}}},
res: "\"abc\".length()",
},
{
desc: "binary",
expr: Expression{Value{Integer(9)}, Value{Integer(4)}, BinaryOp{Mul{}}},
res: "9 * 4",
},
{
desc: "parens",
expr: Expression{
Value{Integer(9)},
Value{Integer(3)},
BinaryOp{Add{}},
UnaryOp{Parens{}},
Value{Integer(4)},
BinaryOp{Div{}},
},
res: "(9 + 3) / 4",
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
p := tc.expr.Print(&syms)
require.Equal(t, tc.res, p)
})
}
}
Loading