Skip to content

Commit e9607c7

Browse files
committed
Improve type checking for $env
Fixes #462
1 parent 0354d1b commit e9607c7

File tree

3 files changed

+59
-17
lines changed

3 files changed

+59
-17
lines changed

ast/node.go

+18
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ type IdentifierNode struct {
5454
MethodIndex int // index of method, set only if Method is true
5555
}
5656

57+
func (n *IdentifierNode) SetFieldIndex(field []int) {
58+
n.FieldIndex = field
59+
}
60+
61+
func (n *IdentifierNode) SetMethodIndex(methodIndex int) {
62+
n.Method = true
63+
n.MethodIndex = methodIndex
64+
}
65+
5766
type IntegerNode struct {
5867
base
5968
Value int
@@ -111,6 +120,15 @@ type MemberNode struct {
111120
MethodIndex int
112121
}
113122

123+
func (n *MemberNode) SetFieldIndex(field []int) {
124+
n.FieldIndex = field
125+
}
126+
127+
func (n *MemberNode) SetMethodIndex(methodIndex int) {
128+
n.Method = true
129+
n.MethodIndex = methodIndex
130+
}
131+
114132
type SliceNode struct {
115133
base
116134
Node Node

checker/checker.go

+30-16
Original file line numberDiff line numberDiff line change
@@ -157,23 +157,34 @@ func (v *checker) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info)
157157
if node.Value == "$env" {
158158
return mapType, info{}
159159
}
160-
if fn, ok := v.config.Builtins[node.Value]; ok {
160+
return v.env(node, node.Value, true)
161+
}
162+
163+
type NodeWithIndexes interface {
164+
ast.Node
165+
SetFieldIndex(field []int)
166+
SetMethodIndex(methodIndex int)
167+
}
168+
169+
func (v *checker) env(node NodeWithIndexes, name string, strict bool) (reflect.Type, info) {
170+
if fn, ok := v.config.Builtins[name]; ok {
161171
return functionType, info{fn: fn}
162172
}
163-
if fn, ok := v.config.Functions[node.Value]; ok {
173+
if fn, ok := v.config.Functions[name]; ok {
164174
return functionType, info{fn: fn}
165175
}
166-
if t, ok := v.config.Types[node.Value]; ok {
176+
if t, ok := v.config.Types[name]; ok {
167177
if t.Ambiguous {
168-
return v.error(node, "ambiguous identifier %v", node.Value)
178+
return v.error(node, "ambiguous identifier %v", name)
179+
}
180+
node.SetFieldIndex(t.FieldIndex)
181+
if t.Method {
182+
node.SetMethodIndex(t.MethodIndex)
169183
}
170-
node.Method = t.Method
171-
node.MethodIndex = t.MethodIndex
172-
node.FieldIndex = t.FieldIndex
173184
return t.Type, info{method: t.Method}
174185
}
175-
if v.config.Strict {
176-
return v.error(node, "unknown name %v", node.Value)
186+
if v.config.Strict && strict {
187+
return v.error(node, "unknown name %v", name)
177188
}
178189
if v.config.DefaultType != nil {
179190
return v.config.DefaultType, info{}
@@ -433,12 +444,16 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) {
433444
prop, _ := v.visit(node.Property)
434445

435446
if an, ok := node.Node.(*ast.IdentifierNode); ok && an.Value == "$env" {
436-
// If the index is a constant string, can save some
437-
// cycles later by finding the type of its referent
438447
if name, ok := node.Property.(*ast.StringNode); ok {
439-
if t, ok := v.config.Types[name.Value]; ok {
440-
return t.Type, info{method: t.Method}
441-
} // No error if no type found; it may be added to env between compile and run
448+
strict := v.config.Strict
449+
if node.Optional {
450+
// If user explicitly set optional flag, then we should not
451+
// throw error if field is not found (as user trying to handle
452+
// this case). But if user did not set optional flag, then we
453+
// should throw error if field is not found & v.config.Strict.
454+
strict = false
455+
}
456+
return v.env(node, name.Value, strict)
442457
}
443458
return anyType, info{}
444459
}
@@ -460,8 +475,7 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) {
460475
// the same interface.
461476
return m.Type, info{}
462477
} else {
463-
node.Method = true
464-
node.MethodIndex = m.Index
478+
node.SetMethodIndex(m.Index)
465479
node.Name = name.Value
466480
return m.Type, info{method: true}
467481
}

expr_test.go

+11-1
Original file line numberDiff line numberDiff line change
@@ -1884,7 +1884,7 @@ func TestEnv_keyword(t *testing.T) {
18841884
{"$env[red + irect]", 10},
18851885
{"$env['String Map']?.five", ""},
18861886
{"$env.red", "n"},
1887-
{"$env?.blue", nil},
1887+
{"$env?.unknown", nil},
18881888
{"$env.mylist[1]", 2},
18891889
{"$env?.OtherMap?.a", "b"},
18901890
{"$env?.OtherMap?.d", ""},
@@ -2102,3 +2102,13 @@ func TestIssue461(t *testing.T) {
21022102
})
21032103
}
21042104
}
2105+
2106+
func TestIssue462(t *testing.T) {
2107+
env := map[string]any{
2108+
"foo": func() (string, error) {
2109+
return "bar", nil
2110+
},
2111+
}
2112+
_, err := expr.Compile(`$env.unknown(int())`, expr.Env(env))
2113+
require.Error(t, err)
2114+
}

0 commit comments

Comments
 (0)