Skip to content

Commit 988b336

Browse files
committed
tools/syz-declextract: refine arg types for syscall variants
Use scope-based dataflow analysis for syscall variants (including ioctls). As the result we only consider code that relates to a partiuclar command/ioctl, and can infer arguments/return types for each command/ioctl independently.
1 parent 16f995f commit 988b336

File tree

6 files changed

+444
-360
lines changed

6 files changed

+444
-360
lines changed

pkg/declextract/declextract.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -161,27 +161,35 @@ func (ctx *context) processSyscalls() {
161161
var syscalls []*Syscall
162162
for _, call := range ctx.Syscalls {
163163
ctx.processFields(call.Args, "", false)
164-
call.returnType = ctx.inferReturnType(call.Func, call.SourceFile)
165-
for i, arg := range call.Args {
166-
typ := ctx.inferArgType(call.Func, call.SourceFile, i)
167-
refineFieldType(arg, typ, false)
168-
}
169-
ctx.emitSyscall(&syscalls, call, "")
170-
for i := range call.Args {
171-
cmds := ctx.inferCommandVariants(call.Func, call.SourceFile, i)
164+
for varArg := range call.Args {
165+
cmds := ctx.inferCommandVariants(call.Func, call.SourceFile, varArg)
172166
for _, cmd := range cmds {
173167
variant := *call
174168
variant.Args = slices.Clone(call.Args)
175-
newArg := *variant.Args[i]
176-
newArg.syzType = fmt.Sprintf("const[%v]", cmd)
177-
variant.Args[i] = &newArg
169+
for i, oldArg := range variant.Args {
170+
arg := *oldArg
171+
if i == varArg {
172+
arg.syzType = fmt.Sprintf("const[%v]", cmd)
173+
} else {
174+
typ := ctx.inferArgType(call.Func, call.SourceFile, i, varArg, cmd)
175+
refineFieldType(&arg, typ, false)
176+
}
177+
variant.Args[i] = &arg
178+
}
179+
variant.returnType = ctx.inferReturnType(call.Func, call.SourceFile, varArg, cmd)
178180
suffix := cmd
179181
if call.Func == "__do_sys_ioctl" {
180182
suffix = ctx.uniqualize("ioctl cmd", cmd)
181183
}
182184
ctx.emitSyscall(&syscalls, &variant, "_"+suffix)
183185
}
184186
}
187+
call.returnType = ctx.inferReturnType(call.Func, call.SourceFile, -1, "")
188+
for i, arg := range call.Args {
189+
typ := ctx.inferArgType(call.Func, call.SourceFile, i, -1, "")
190+
refineFieldType(arg, typ, false)
191+
}
192+
ctx.emitSyscall(&syscalls, call, "")
185193
}
186194
ctx.Syscalls = sortAndDedupSlice(syscalls)
187195
}

pkg/declextract/entity.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ type FunctionScope struct {
4545
LOC int `json:"loc,omitempty"`
4646
Calls []string `json:"calls,omitempty"`
4747
Facts []*TypingFact `json:"facts,omitempty"`
48+
49+
fn *Function
4850
}
4951

5052
type ConstInfo struct {

pkg/declextract/fileops.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,15 @@ func (ctx *context) createFops(fops *FileOps, files []string) {
6161
}
6262

6363
func (ctx *context) createIoctls(fops *FileOps, suffix, fdt string) {
64-
const defaultArgType = "ptr[in, array[int8]]"
65-
cmds := ctx.inferCommandVariants(fops.Ioctl, fops.SourceFile, 1)
64+
const (
65+
cmdArg = 1
66+
argArg = 2
67+
defaultArgType = "ptr[in, array[int8]]"
68+
)
69+
cmds := ctx.inferCommandVariants(fops.Ioctl, fops.SourceFile, cmdArg)
6670
if len(cmds) == 0 {
67-
retType := ctx.inferReturnType(fops.Ioctl, fops.SourceFile)
68-
argType := ctx.inferArgType(fops.Ioctl, fops.SourceFile, 2)
71+
retType := ctx.inferReturnType(fops.Ioctl, fops.SourceFile, -1, "")
72+
argType := ctx.inferArgType(fops.Ioctl, fops.SourceFile, argArg, -1, "")
6973
if argType == "" {
7074
argType = defaultArgType
7175
}
@@ -80,10 +84,16 @@ func (ctx *context) createIoctls(fops *FileOps, suffix, fdt string) {
8084
Type: typ,
8185
}
8286
argType = ctx.fieldType(f, nil, "", false)
87+
} else {
88+
argType = ctx.inferArgType(fops.Ioctl, fops.SourceFile, argArg, cmdArg, cmd)
89+
if argType == "" {
90+
argType = defaultArgType
91+
}
8392
}
93+
retType := ctx.inferReturnType(fops.Ioctl, fops.SourceFile, cmdArg, cmd)
8494
name := ctx.uniqualize("ioctl cmd", cmd)
85-
ctx.fmt("ioctl%v_%v(fd %v, cmd const[%v], arg %v)\n",
86-
autoSuffix, name, fdt, cmd, argType)
95+
ctx.fmt("ioctl%v_%v(fd %v, cmd const[%v], arg %v) %v\n",
96+
autoSuffix, name, fdt, cmd, argType, retType)
8797
}
8898
}
8999

pkg/declextract/typing.go

Lines changed: 91 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,7 @@ import (
3434
// - Infer that pointers are file names (they should flow to some known function for path resolution).
3535
// - Use SSA analysis to track flow via local variables better. Potentiall we can just rename on every next use
3636
// and ignore backwards edges (it's unlikely that backwards edges are required for type inference).
37-
// - Infer ioctl commands in transitively called functions using data flow.
3837
// - Infer file_operations associated with an fd by tracking flow to alloc_file_pseudo and friends.
39-
// - Add context-sensitivity at least on switched arguments (ioctl commands).
40-
// - Infer other switched arguments besides ioctl commands.
4138
// - Infer netlink arg types by tracking flow from genl_info::attrs[ATTR_FOO].
4239
// - Infer simple constraints on arguments, e.g. "if (arg != 0) return -EINVAL".
4340
// - Use kernel typedefs for typing (e.g. pid_t). We can use them for uapi structs, but also for kernel
@@ -48,6 +45,10 @@ import (
4845
// For example, these cases lead to false inference of fd type for returned value:
4946
// https://elixir.bootlin.com/linux/v6.13-rc2/source/net/core/sock.c#L1870
5047
// https://elixir.bootlin.com/linux/v6.13-rc2/source/net/socket.c#L1742
48+
// - Use const[0] for unused arguments. If an arg is unused, or only flows to functions where it's unused,
49+
// we can consider it as unused.
50+
// - Detect common patterns for "must be 0" or "must be const" arguments, e.g.:
51+
// if (flags != 0) return -EINVAL;
5152

5253
var (
5354
// Refines types based on data flows...
@@ -96,7 +97,7 @@ type typingNode struct {
9697
id string
9798
fn *Function
9899
arg int
99-
flows [2]map[*typingNode]bool
100+
flows [2]map[*typingNode][]*FunctionScope
100101
}
101102

102103
const (
@@ -107,14 +108,16 @@ const (
107108
func (ctx *context) processTypingFacts() {
108109
for _, fn := range ctx.Functions {
109110
for _, scope := range fn.Scopes {
111+
scope.fn = fn
110112
for _, fact := range scope.Facts {
111113
src := ctx.canonicalNode(fn, fact.Src)
112114
dst := ctx.canonicalNode(fn, fact.Dst)
113115
if src == nil || dst == nil {
114116
continue
115117
}
116-
src.flows[flowTo][dst] = true
117-
dst.flows[flowFrom][src] = true
118+
119+
src.flows[flowTo][dst] = append(src.flows[flowTo][dst], scope)
120+
dst.flows[flowFrom][src] = append(dst.flows[flowFrom][src], scope)
118121
}
119122
}
120123
}
@@ -156,7 +159,7 @@ func (ctx *context) canonicalNode(fn *Function, ent *TypingEntity) *typingNode {
156159
arg: arg,
157160
}
158161
for i := range n.flows {
159-
n.flows[i] = make(map[*typingNode]bool)
162+
n.flows[i] = make(map[*typingNode][]*FunctionScope)
160163
}
161164
facts[id] = n
162165
return n
@@ -179,35 +182,43 @@ func (ent *TypingEntity) ID(fn *Function) (string, string) {
179182
}
180183
}
181184

182-
func (ctx *context) inferReturnType(name, file string) string {
183-
return ctx.inferFuncNode(name, file, "ret")
185+
func (ctx *context) inferReturnType(name, file string, scopeArg int, scopeVal string) string {
186+
return ctx.inferFuncNode(name, file, "ret", scopeArg, scopeVal)
187+
}
188+
189+
func (ctx *context) inferArgType(name, file string, arg, scopeArg int, scopeVal string) string {
190+
return ctx.inferFuncNode(name, file, fmt.Sprintf("arg%v", arg), scopeArg, scopeVal)
184191
}
185192

186-
func (ctx *context) inferArgType(name, file string, arg int) string {
187-
return ctx.inferFuncNode(name, file, fmt.Sprintf("arg%v", arg))
193+
type fnArg struct {
194+
fn *Function
195+
arg int
188196
}
189197

190-
func (ctx *context) inferFuncNode(name, file, node string) string {
198+
func (ctx *context) inferFuncNode(name, file, node string, scopeArg int, scopeVal string) string {
191199
fn := ctx.findFunc(name, file)
192200
if fn == nil {
193201
return ""
194202
}
195-
return ctx.inferNodeType(fn.facts[node], fmt.Sprintf("%v %v", name, node))
203+
scopeFnArgs := ctx.inferArgFlow(fnArg{fn, scopeArg})
204+
return ctx.inferNodeType(fn.facts[node], scopeFnArgs, scopeVal, fmt.Sprintf("%v %v", name, node))
196205
}
197206

198207
func (ctx *context) inferFieldType(structName, field string) string {
199208
name := fmt.Sprintf("%v.%v", structName, field)
200-
return ctx.inferNodeType(ctx.facts[name], name)
209+
return ctx.inferNodeType(ctx.facts[name], nil, "", name)
201210
}
202211

203-
func (ctx *context) inferNodeType(n *typingNode, what string) string {
212+
func (ctx *context) inferNodeType(n *typingNode, scopeFnArgs map[fnArg]bool, scopeVal, what string) string {
204213
if n == nil {
205214
return ""
206215
}
207216
ic := &inferContext{
208-
visited: make(map[*typingNode]bool),
209-
flowType: flowFrom,
210-
maxDepth: maxTraversalDepth,
217+
scopeFnArgs: scopeFnArgs,
218+
scopeVal: scopeVal,
219+
visited: make(map[*typingNode]bool),
220+
flowType: flowFrom,
221+
maxDepth: maxTraversalDepth,
211222
}
212223
ic.walk(n)
213224
ic.flowType = flowTo
@@ -220,13 +231,15 @@ func (ctx *context) inferNodeType(n *typingNode, what string) string {
220231
}
221232

222233
type inferContext struct {
223-
path []*typingNode
224-
visited map[*typingNode]bool
225-
result string
226-
resultPath []*typingNode
227-
resultFlow int
228-
flowType int
229-
maxDepth int
234+
path []*typingNode
235+
visited map[*typingNode]bool
236+
scopeFnArgs map[fnArg]bool
237+
scopeVal string
238+
result string
239+
resultPath []*typingNode
240+
resultFlow int
241+
flowType int
242+
maxDepth int
230243
}
231244

232245
func (ic *inferContext) walk(n *typingNode) {
@@ -246,13 +259,39 @@ func (ic *inferContext) walk(n *typingNode) {
246259
}
247260
}
248261
if len(ic.path) < ic.maxDepth {
249-
for e := range n.flows[ic.flowType] {
250-
ic.walk(e)
262+
for e, scopes := range n.flows[ic.flowType] {
263+
if ic.relevantScope(scopes) {
264+
ic.walk(e)
265+
}
251266
}
252267
}
253268
ic.path = ic.path[:len(ic.path)-1]
254269
}
255270

271+
func (ic *inferContext) relevantScope(scopes []*FunctionScope) bool {
272+
if ic.scopeFnArgs == nil {
273+
// We are not doing scope-limited walk, so all scopes are relevant.
274+
return true
275+
}
276+
for _, scope := range scopes {
277+
if scope.Arg == -1 {
278+
// Always use global scope.
279+
return true
280+
}
281+
if !ic.scopeFnArgs[fnArg{scope.fn, scope.Arg}] {
282+
// The scope argument is not related to the current scope.
283+
return true
284+
}
285+
// For the scope argument, check that it has the right value.
286+
for _, val := range scope.Values {
287+
if val == ic.scopeVal {
288+
return true
289+
}
290+
}
291+
}
292+
return false
293+
}
294+
256295
func refineFieldType(f *Field, typ string, preserveSize bool) {
257296
// If our manual heuristics have figured out a more precise fd subtype,
258297
// don't replace it with generic fd.
@@ -319,3 +358,28 @@ func (ctx *context) walkCommandVariants(n *typingNode, variants *[]string, visit
319358
ctx.walkCommandVariants(e, variants, visited, depth+1)
320359
}
321360
}
361+
362+
// inferArgFlow returns transitive closure of all function arguments that the given argument flows to.
363+
func (ctx *context) inferArgFlow(arg fnArg) map[fnArg]bool {
364+
n := arg.fn.facts[fmt.Sprintf("arg%v", arg.arg)]
365+
if n == nil {
366+
return nil
367+
}
368+
fnArgs := make(map[fnArg]bool)
369+
visited := make(map[*typingNode]bool)
370+
ctx.walkArgFlow(n, fnArgs, visited, 0)
371+
return fnArgs
372+
}
373+
374+
func (ctx *context) walkArgFlow(n *typingNode, fnArgs map[fnArg]bool, visited map[*typingNode]bool, depth int) {
375+
if visited[n] || depth >= 10 {
376+
return
377+
}
378+
visited[n] = true
379+
if n.arg >= 0 {
380+
fnArgs[fnArg{n.fn, n.arg}] = true
381+
}
382+
for e := range n.flows[flowTo] {
383+
ctx.walkArgFlow(e, fnArgs, visited, depth+1)
384+
}
385+
}

0 commit comments

Comments
 (0)