Skip to content

Commit a057c2f

Browse files
committed
Expose base type of enumeration
1 parent 07f2467 commit a057c2f

7 files changed

Lines changed: 63 additions & 42 deletions

File tree

generate.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,14 @@ def generate_protocol(output: str) -> None:
6161
]
6262
)
6363

64+
enumerations = {e['name']: e for e in lsp_json['enumerations']}
65+
6466
content += '\n\n\n'
6567
content += '\n\n\n'.join(generate_enumerations(lsp_json['enumerations'], ENUM_OVERRIDES))
6668
content += '\n\n'
67-
content += '\n'.join(generate_type_aliases(lsp_json['typeAliases'], ALIAS_OVERRIDES))
69+
content += '\n'.join(generate_type_aliases(lsp_json['typeAliases'], ALIAS_OVERRIDES, enumerations))
6870
content += '\n\n\n'
69-
content += '\n\n\n'.join(generate_structures(lsp_json['structures']))
71+
content += '\n\n\n'.join(generate_structures(lsp_json['structures'], enumerations))
7072
content += '\n'
7173
content += '\n'.join(get_new_literal_structures())
7274

@@ -101,10 +103,12 @@ def generate_custom(output: str) -> None:
101103
requests = sorted(lsp_json['requests'], key=itemgetter('typeName'))
102104
notifications = sorted(lsp_json['notifications'], key=itemgetter('typeName'))
103105

106+
enumerations = {e['name']: e for e in lsp_json['enumerations']}
107+
104108
content += '\n\n\n'
105-
content += '\n\n\n'.join(generate_requests_and_responses(requests))
109+
content += '\n\n\n'.join(generate_requests_and_responses(requests, enumerations))
106110
content += '\n\n\n'
107-
content += '\n\n\n'.join(generate_notifications(notifications))
111+
content += '\n\n\n'.join(generate_notifications(notifications, enumerations))
108112
content += '\n'
109113

110114
# Remove trailing spaces.

generated/lsp_types.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,7 +1249,7 @@ class FoldingRange(TypedDict):
12491249
"""
12501250
endCharacter: NotRequired[Uint]
12511251
"""The zero-based character offset before the folded range ends. If not defined, defaults to the length of the end line."""
1252-
kind: NotRequired['FoldingRangeKind']
1252+
kind: NotRequired[Union[str, FoldingRangeKind]]
12531253
"""
12541254
Describes the kind of the folding range such as 'comment' or 'region'. The kind
12551255
is used to categorize folding ranges and used by commands like 'Fold all comments'.
@@ -3322,7 +3322,7 @@ class CodeAction(TypedDict):
33223322

33233323
title: str
33243324
"""A short, human-readable, title for this code action."""
3325-
kind: NotRequired['CodeActionKind']
3325+
kind: NotRequired[Union[str, CodeActionKind]]
33263326
"""
33273327
The kind of the code action.
33283328
@@ -3389,7 +3389,7 @@ class CodeActionRegistrationOptions(TypedDict):
33893389
A document selector to identify the scope of the registration. If set to null
33903390
the document selector provided on the client side will be used.
33913391
"""
3392-
codeActionKinds: NotRequired[List['CodeActionKind']]
3392+
codeActionKinds: NotRequired[List[Union[str, CodeActionKind]]]
33933393
"""
33943394
CodeActionKinds that this server may return.
33953395
@@ -4624,7 +4624,7 @@ class TextDocumentItem(TypedDict):
46244624

46254625
uri: DocumentUri
46264626
"""The text document's uri."""
4627-
languageId: 'LanguageKind'
4627+
languageId: Union[str, LanguageKind]
46284628
"""The text document's language identifier."""
46294629
version: int
46304630
"""
@@ -4804,7 +4804,7 @@ class ServerCapabilities(TypedDict):
48044804
server.
48054805
"""
48064806

4807-
positionEncoding: NotRequired['PositionEncodingKind']
4807+
positionEncoding: NotRequired[Union[str, PositionEncodingKind]]
48084808
"""
48094809
The position encoding the server picked from the encodings offered
48104810
by the client via the client capability `general.positionEncodings`.
@@ -4987,7 +4987,7 @@ class FileSystemWatcher(TypedDict):
49874987
49884988
@since 3.17.0 support for relative patterns.
49894989
"""
4990-
kind: NotRequired['WatchKind']
4990+
kind: NotRequired[Union[Uint, WatchKind]]
49914991
"""
49924992
The kind of events of interest. If omitted it defaults
49934993
to WatchKind.Create | WatchKind.Change | WatchKind.Delete
@@ -5419,7 +5419,7 @@ class CodeActionContext(TypedDict):
54195419
that these accurately reflect the error state of the resource. The primary parameter
54205420
to compute code actions is the provided range.
54215421
"""
5422-
only: NotRequired[List['CodeActionKind']]
5422+
only: NotRequired[List[Union[str, CodeActionKind]]]
54235423
"""
54245424
Requested kind of actions to return.
54255425
@@ -5452,7 +5452,7 @@ class CodeActionDisabled(TypedDict):
54525452
class CodeActionOptions(TypedDict):
54535453
"""Provider options for a {@link CodeActionRequest}."""
54545454

5455-
codeActionKinds: NotRequired[List['CodeActionKind']]
5455+
codeActionKinds: NotRequired[List[Union[str, CodeActionKind]]]
54565456
"""
54575457
CodeActionKinds that this server may return.
54585458
@@ -6123,7 +6123,7 @@ class CodeActionKindDocumentation(TypedDict):
61236123
@proposed
61246124
"""
61256125

6126-
kind: 'CodeActionKind'
6126+
kind: Union[str, CodeActionKind]
61276127
"""
61286128
The kind of the code action being documented.
61296129
@@ -6513,7 +6513,7 @@ class GeneralClientCapabilities(TypedDict):
65136513
65146514
@since 3.16.0
65156515
"""
6516-
positionEncodings: NotRequired[List['PositionEncodingKind']]
6516+
positionEncodings: NotRequired[List[Union[str, PositionEncodingKind]]]
65176517
"""
65186518
The position encodings supported by the client. Client and server
65196519
have to agree on the same position encoding to ensure that offsets
@@ -7886,7 +7886,7 @@ class ClientCodeLensResolveOptions(TypedDict):
78867886
class ClientFoldingRangeKindOptions(TypedDict):
78877887
"""@since 3.18.0"""
78887888

7889-
valueSet: NotRequired[List['FoldingRangeKind']]
7889+
valueSet: NotRequired[List[Union[str, FoldingRangeKind]]]
78907890
"""
78917891
The folding range kind values the client supports. When this
78927892
property exists the client also guarantees that it will
@@ -8003,7 +8003,7 @@ class ClientSignatureParameterInformationOptions(TypedDict):
80038003
class ClientCodeActionKindOptions(TypedDict):
80048004
"""@since 3.18.0"""
80058005

8006-
valueSet: List['CodeActionKind']
8006+
valueSet: List[Union[str, CodeActionKind]]
80078007
"""
80088008
The code action kind values the client supports. When this
80098009
property exists the client also guarantees that it will

utils/generate_notifications.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,17 @@
55
from utils.helpers import indentation
66

77
if TYPE_CHECKING:
8+
from lsp_schema import Enumeration
89
from lsp_schema import Notification
910

1011

11-
def generate_notifications(notifications: list[Notification]) -> list[str]:
12+
def generate_notifications(notifications: list[Notification], enumerations: dict[str, Enumeration]) -> list[str]:
1213
client_notification_names: list[str] = []
1314
server_notification_names: list[str] = []
1415
definitions: list[str] = []
1516
for notification in notifications:
1617
message_direction = notification['messageDirection']
17-
name, definition = generate_notification(notification)
18+
name, definition = generate_notification(notification, enumerations)
1819
if message_direction == 'clientToServer':
1920
client_notification_names.append(name)
2021
elif message_direction == 'serverToClient':
@@ -32,14 +33,14 @@ def generate_notifications(notifications: list[Notification]) -> list[str]:
3233
]
3334

3435

35-
def generate_notification(notification: Notification) -> tuple[str, str]:
36+
def generate_notification(notification: Notification, enumerations: dict[str, Enumeration]) -> tuple[str, str]:
3637
method = notification['method']
3738
params = notification.get('params')
3839
name = notification['typeName']
3940
definition = f'class {name}(TypedDict):\n'
4041
definition += f"{indentation}method: Literal['{method}']\n"
4142
if params:
42-
definition += f'{indentation}params: {format_type(params, {"root_symbol_name": ""})}'
43+
definition += f'{indentation}params: {format_type(params, {"enumerations": enumerations})}'
4344
else:
4445
definition += f'{indentation}params: None'
4546
return (name, definition)

utils/generate_requests_and_responses.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from utils.helpers import indentation
66

77
if TYPE_CHECKING:
8+
from lsp_schema import Enumeration
89
from lsp_schema import Request
910

1011

11-
def generate_requests_and_responses(requests: list[Request]) -> list[str]:
12+
def generate_requests_and_responses(requests: list[Request], enumerations: dict[str, Enumeration]) -> list[str]:
1213
client_request_names: list[str] = []
1314
server_request_names: list[str] = []
1415
client_response_names: list[str] = []
@@ -18,7 +19,7 @@ def generate_requests_and_responses(requests: list[Request]) -> list[str]:
1819
for request in requests:
1920
message_direction = request['messageDirection']
2021
# Requests
21-
req_name, req_definition = generate_request(request)
22+
req_name, req_definition = generate_request(request, enumerations)
2223
if message_direction == 'clientToServer':
2324
client_request_names.append(req_name)
2425
elif message_direction == 'serverToClient':
@@ -28,7 +29,7 @@ def generate_requests_and_responses(requests: list[Request]) -> list[str]:
2829
server_request_names.append(req_name)
2930
req_definitions.append(req_definition)
3031
# Responses
31-
res_name, res_definition = generate_response(request)
32+
res_name, res_definition = generate_response(request, enumerations)
3233
if message_direction == 'clientToServer':
3334
server_response_names.append(res_name)
3435
elif message_direction == 'serverToClient':
@@ -51,20 +52,20 @@ def generate_requests_and_responses(requests: list[Request]) -> list[str]:
5152
]
5253

5354

54-
def generate_request(request: Request) -> tuple[str, str]:
55+
def generate_request(request: Request, enumerations: dict[str, Enumeration]) -> tuple[str, str]:
5556
method = request['method']
5657
params = request.get('params')
5758
name = request['typeName']
5859
definition = f'class {name}(TypedDict):\n'
5960
definition += f"{indentation}method: Literal['{method}']\n"
6061
if params:
61-
definition += f'{indentation}params: {format_type(params, {"root_symbol_name": ""})}'
62+
definition += f'{indentation}params: {format_type(params, {"enumerations": enumerations})}'
6263
else:
6364
definition += f'{indentation}params: None'
6465
return (name, definition)
6566

6667

67-
def generate_response(request: Request) -> tuple[str, str]:
68+
def generate_response(request: Request, enumerations: dict[str, Enumeration]) -> tuple[str, str]:
6869
method = request['method']
6970
result = request['result']
7071
params = request.get('params')
@@ -73,7 +74,7 @@ def generate_response(request: Request) -> tuple[str, str]:
7374
definition = f'class {name}(TypedDict):\n'
7475
definition += f"{indentation}method: Literal['{method}']\n"
7576
if request['messageDirection'] == 'serverToClient':
76-
typ = format_type(params, {'root_symbol_name': ''}) if params else None
77+
typ = format_type(params, {'root_symbol_name': '', 'enumerations': enumerations}) if params else None
7778
definition += f'{indentation}params: {typ}\n'
78-
definition += f'{indentation}result: {format_type(result, {"root_symbol_name": ""})}'
79+
definition += f'{indentation}result: {format_type(result, {"enumerations": enumerations})}'
7980
return (name, definition)

utils/generate_structures.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,21 @@
1111
from utils.helpers import StructureKind
1212

1313
if TYPE_CHECKING:
14+
from lsp_schema import Enumeration
1415
from lsp_schema import Structure
1516

1617

17-
def generate_structures(structures: list[Structure]) -> list[str]:
18+
def generate_structures(structures: list[Structure], enumerations: dict[str, Enumeration]) -> list[str]:
1819
def to_string(structure: Structure) -> str:
1920
kind = StructureKind.Function if has_invalid_property_name(structure['properties']) else StructureKind.Class
20-
return generate_structure(structure, structures, kind)
21+
return generate_structure(structure, structures, kind, enumerations)
2122

2223
return [to_string(structure) for structure in structures if not structure['name'].startswith('_')]
2324

2425

25-
def get_additional_properties(for_structure: Structure, structures: list[Structure]) -> list[FormattedProperty]:
26+
def get_additional_properties(
27+
for_structure: Structure, structures: list[Structure], enumerations: dict[str, Enumeration]
28+
) -> list[FormattedProperty]:
2629
"""Return properties from extended and mixin types."""
2730
result: list[FormattedProperty] = []
2831
additional_structures = for_structure.get('extends') or []
@@ -33,16 +36,21 @@ def get_additional_properties(for_structure: Structure, structures: list[Structu
3336
raise Exception(error, additional_structure['kind'])
3437
structure = next(structure for structure in structures if structure['name'] == additional_structure['name'])
3538
if structure:
36-
properties = get_formatted_properties(structure['properties'], structure['name'])
39+
properties = get_formatted_properties(structure['properties'], {'enumerations': enumerations})
3740
result.extend(properties)
3841
return result
3942

4043

41-
def generate_structure(structure: Structure, structures: list[Structure], structure_kind: StructureKind) -> str:
44+
def generate_structure(
45+
structure: Structure,
46+
structures: list[Structure],
47+
structure_kind: StructureKind,
48+
enumerations: dict[str, Enumeration],
49+
) -> str:
4250
result = ''
4351
symbol_name = structure['name']
44-
properties = get_formatted_properties(structure['properties'], structure['name'])
45-
additional_properties = get_additional_properties(structure, structures)
52+
properties = get_formatted_properties(structure['properties'], {'enumerations': enumerations})
53+
additional_properties = get_additional_properties(structure, structures, enumerations)
4654

4755
# add extended properties
4856
taken_property_names = [p['name'] for p in properties]

utils/generate_type_aliases.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,20 @@
55
from utils.helpers import format_type
66

77
if TYPE_CHECKING:
8+
from lsp_schema import Enumeration
89
from lsp_schema import TypeAlias
910

1011

11-
def generate_type_aliases(type_aliases: list[TypeAlias], overrides: dict[str, str]) -> list[str]:
12+
def generate_type_aliases(
13+
type_aliases: list[TypeAlias], overrides: dict[str, str], enumerations: dict[str, Enumeration]
14+
) -> list[str]:
1215
def to_string(type_alias: TypeAlias) -> str:
1316
symbol_name = type_alias['name']
1417
documentation = format_comment(type_alias.get('documentation'))
1518
if symbol_name in overrides:
1619
value = overrides[symbol_name]
1720
else:
18-
value = format_type(type_alias['type'], {'root_symbol_name': symbol_name})
21+
value = format_type(type_alias['type'], {'enumerations': enumerations})
1922
result = f"""
2023
{symbol_name}: TypeAlias = {value}"""
2124
if documentation:

utils/helpers.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from __future__ import annotations
22

33
from enum import Enum
4+
from lsp_schema import EnumerationType
45
from typing import Any
56
from typing import TYPE_CHECKING
67
from typing import TypedDict
78
import keyword
89

910
if TYPE_CHECKING:
1011
from lsp_schema import BaseType
12+
from lsp_schema import Enumeration
1113
from lsp_schema import EveryType
1214
from lsp_schema import MapKeyType
1315
from lsp_schema import Property
@@ -54,7 +56,7 @@ class StructureKind(Enum):
5456

5557

5658
class FormatTypeContext(TypedDict):
57-
root_symbol_name: str
59+
enumerations: dict[str, Enumeration]
5860

5961

6062
def format_type(typ: EveryType, context: FormatTypeContext) -> str:
@@ -63,13 +65,15 @@ def format_type(typ: EveryType, context: FormatTypeContext) -> str:
6365
return format_base_types(typ)
6466
if typ['kind'] == 'reference':
6567
literal_symbol_name = typ['name']
68+
if (enum := context['enumerations'].get(literal_symbol_name)) and enum.get('supportsCustomValues'):
69+
return f'Union[{format_type(enum["type"], context)}, {literal_symbol_name}]'
6670
return f"'{literal_symbol_name}'"
6771
if typ['kind'] == 'array':
6872
literal_symbol_name = format_type(typ['element'], context)
6973
return f'List[{literal_symbol_name}]'
7074
if typ['kind'] == 'map':
7175
key = format_base_types(typ['key'])
72-
value = format_type(typ['value'], {'root_symbol_name': key})
76+
value = format_type(typ['value'], {'enumerations': context['enumerations']})
7377
return f'Dict[{key}, {value}]'
7478
if typ['kind'] == 'and':
7579
pass
@@ -91,7 +95,7 @@ def format_type(typ: EveryType, context: FormatTypeContext) -> str:
9195
return result
9296

9397

94-
def format_base_types(base_type: BaseType | MapKeyType) -> str:
98+
def format_base_types(base_type: BaseType | MapKeyType | EnumerationType) -> str:
9599
mapping: dict[str, str] = {
96100
'integer': 'int',
97101
'uinteger': 'Uint',
@@ -111,11 +115,11 @@ class FormattedProperty(TypedDict):
111115
documentation: str
112116

113117

114-
def get_formatted_properties(properties: list[Property], root_symbol_name: str) -> list[FormattedProperty]:
118+
def get_formatted_properties(properties: list[Property], context: FormatTypeContext) -> list[FormattedProperty]:
115119
result: list[FormattedProperty] = []
116120
for p in properties:
117121
key = p['name']
118-
value = format_type(p['type'], {'root_symbol_name': root_symbol_name + '_' + key})
122+
value = format_type(p['type'], context)
119123
if p.get('optional'):
120124
value = f'NotRequired[{value}]'
121125
documentation = p.get('documentation') or ''

0 commit comments

Comments
 (0)