Skip to content

Commit 88f9e9b

Browse files
committed
Improve type inference
1 parent 70ac55c commit 88f9e9b

File tree

5 files changed

+41
-14
lines changed

5 files changed

+41
-14
lines changed

server/script/core/completion.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ local function checkFieldThen(name, src, word, start, offset, parent, oop, resul
449449
else
450450
kind = define.CompletionItemKind.Variable
451451
end
452-
elseif not infer and config.config.intelliSense.searchDepth > 0 then
452+
elseif not infer and (config.config.intelliSense.searchDepth > 0 or value.type == "typeof") then
453453
local infers = vm.getInfers(value, 0, {searchAll = true})
454454
for _, infer in ipairs(infers) do
455455
if infer.source then

server/script/core/guide.lua

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,6 +1339,9 @@ end
13391339
local function buildSimpleList(obj, max)
13401340
local list = {}
13411341
local cur = obj
1342+
if obj.type == "type.typeof" then
1343+
cur = obj.value
1344+
end
13421345
local limit = max and (max + 1) or 11
13431346
for i = 1, max or limit do
13441347
if i == limit then
@@ -1408,6 +1411,11 @@ local function buildSimpleList(obj, max)
14081411
elseif cur.type == "type.assert" then
14091412
list[i] = cur
14101413
break
1414+
elseif cur.type == 'type.field'
1415+
or cur.type == 'type.index'
1416+
or cur.type == 'type.library' then
1417+
list[i] = cur
1418+
break
14111419
else
14121420
return nil
14131421
end
@@ -1443,6 +1451,10 @@ function m.getSimple(obj, max)
14431451
or obj.type == 'doc.type.name'
14441452
or obj.type == 'doc.see.name'
14451453
or obj.type == 'doc.see.field'
1454+
or obj.type == 'type.typeof'
1455+
or obj.type == 'type.field'
1456+
or obj.type == 'type.index'
1457+
or obj.type == 'type.library'
14461458
or obj.type == 'type.assert' then
14471459
simpleList = buildSimpleList(obj, max)
14481460
elseif obj.type == 'field'
@@ -1459,7 +1471,7 @@ function m.getVisibleRefs(obj, status)
14591471
if not status.main then
14601472
return obj.ref
14611473
end
1462-
local searchFrom = status.refMain[obj] or status.searchFrom or status.main
1474+
local searchFrom = status.searchFrom or status.main
14631475
local root = m.getRoot(obj)
14641476
if root ~= m.getRoot(searchFrom) then
14651477
if root.returns then
@@ -1479,9 +1491,9 @@ function m.getVisibleRefs(obj, status)
14791491
end
14801492
local refFunc = m.getParentFunction(ref)
14811493
local mainFunc, range = mainFunc, range
1482-
if status.refMain[refFunc] and refFunc ~= mainFunc then
1494+
if status.funcMain[refFunc] and refFunc ~= mainFunc then
14831495
mainFunc = refFunc
1484-
range = select(2, m.getRange(status.refMain[refFunc]))
1496+
range = select(2, m.getRange(status.funcMain[refFunc]))
14851497
end
14861498
if refFunc == mainFunc then
14871499
if ref.start > range and not blockTypes[searchFrom.type] then
@@ -1593,7 +1605,7 @@ function m.status(parentStatus, main, interface, deep, options)
15931605
},
15941606
main = main or parentStatus and parentStatus.main,
15951607
searchFrom = parentStatus and parentStatus.searchFrom or options and options.searchFrom,
1596-
refMain = parentStatus and parentStatus.refMain or {},
1608+
funcMain = parentStatus and parentStatus.funcMain or {},
15971609
depth = parentStatus and (parentStatus.depth + 1) or 0,
15981610
searchDeep = parentStatus and parentStatus.searchDeep or deep or -999,
15991611
interface = parentStatus and parentStatus.interface or {},
@@ -3036,6 +3048,15 @@ function m.getFullType(status, tp, mark)
30363048
end
30373049
tp = tp.exp
30383050
end
3051+
-- if tp.type == "type.typeof" then
3052+
-- local newStatus = m.status(status)
3053+
-- m.searchRefs(newStatus, tp.value, 'def')
3054+
-- for _, def in ipairs(newStatus.results) do
3055+
-- if m.isTypeAnn(def) then
3056+
-- return m.getFullType(status, def, mark)
3057+
-- end
3058+
-- end
3059+
-- end
30393060
local typeAlias = m.getTypeAlias(status, tp)
30403061
if typeAlias then
30413062
local generics = tp.generics
@@ -3134,6 +3155,11 @@ local function checkSameSimpleAndMergeTypeAnnReturns(status, results, source, in
31343155
end
31353156
end
31363157
end
3158+
elseif source.type == "type.typeof" then
3159+
for _, result in ipairs(m.checkSameSimpleInCallInSameFile(status, source.value, args, index)) do
3160+
results[#results+1] = result
3161+
end
3162+
return true
31373163
end
31383164
end
31393165
if #returns == 0 then
@@ -3198,8 +3224,8 @@ function m.checkSameSimpleInCallInSameFile(status, func, args, index)
31983224
end
31993225
end
32003226
end
3201-
cache[index] = results
32023227
end
3228+
cache[index] = results
32033229
return results
32043230
end
32053231

@@ -3250,9 +3276,9 @@ function m.checkSameSimpleInCall(status, ref, start, pushQueue, mode)
32503276
local newStatus = m.status(status)
32513277
if status.main then
32523278
local parentFunc = m.getParentFunction(obj)
3253-
if parentFunc and parentFunc ~= m.getParentFunction(status.searchFrom or status.main) then
3254-
status.refMain[parentFunc] = obj
3255-
status.searchFrom = obj
3279+
if parentFunc and not m.hasParent(status.searchFrom or status.main, parentFunc) then
3280+
status.funcMain[parentFunc] = obj
3281+
status .searchFrom = obj
32563282
end
32573283
end
32583284
m.searchRefs(newStatus, obj, mode)
@@ -4917,6 +4943,9 @@ function m.inferCheckTypeAnn(status, source)
49174943
or typeAnn.type == "type.index" then
49184944
typeAnn = typeAnn.value
49194945
end
4946+
if typeAnn.type == "type.typeof" then
4947+
return false
4948+
end
49204949
if status.options.fullType then
49214950
typeAnn = m.getFullType(status, typeAnn)
49224951
end

server/script/core/hover/init.lua

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,7 @@ end
190190

191191
local function getHoverAsTypeName(source)
192192
local label
193-
local typeAlias = source.typeAlias
194-
if source.parent.type == "type.module" then
195-
typeAlias = vm.getModuleTypeAlias(source.parent)
196-
end
193+
local typeAlias = vm.getTypeAlias(source)
197194
if typeAlias then
198195
label = getHoverAsTypeAlias(typeAlias.name).label
199196
else

server/script/core/signature.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ local function makeSignatures(call, pos)
131131
index = 1
132132
end
133133
local signs = {}
134-
local defs = vm.getDefs(node, 0, {onlyDef = true})
134+
local defs = vm.getDefs(node, 0, {onlyDef = true, fullType = true})
135135
local mark = {}
136136
for _, src in ipairs(defs) do
137137
src = guide.getObjectValue(src) or src

server/script/core/type-checking.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,7 @@ function m.check(uri)
13751375

13761376
m.ast = ast.ast
13771377
m.init()
1378+
vm.flushCache()
13781379

13791380
local results = {}
13801381

0 commit comments

Comments
 (0)