Skip to content

Commit 8803f97

Browse files
committed
add string replacment to account for sql variables
1 parent d913008 commit 8803f97

File tree

2 files changed

+158
-2
lines changed

2 files changed

+158
-2
lines changed

.lintrunner.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -339,13 +339,13 @@ init_command = [
339339
is_formatter = true
340340

341341
[[linter]]
342-
code = 'CLICKHOUSE'
342+
code = 'SQLFLUFF'
343343
include_patterns = ['torchci/clickhouse_queries/**/*.sql']
344344
exclude_patterns = [
345345
]
346346
command = [
347347
'python3',
348-
'tools/linter/adapters/clickhouse_sql_linter.py',
348+
'tools/linter/adapters/sqlfluff_linter.py',
349349
'@{{PATHSFILE}}',
350350
]
351351
init_command = [
+156
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import argparse
2+
import concurrent.futures
3+
import json
4+
import logging
5+
import os
6+
import re
7+
import subprocess
8+
import time
9+
from enum import Enum
10+
from typing import List, NamedTuple, Optional, Pattern
11+
12+
13+
LINTER_CODE = "SQLFLUFF"
14+
15+
16+
class LintSeverity(str, Enum):
17+
ERROR = "error"
18+
WARNING = "warning"
19+
ADVICE = "advice"
20+
DISABLED = "disabled"
21+
22+
23+
class LintMessage(NamedTuple):
24+
path: Optional[str]
25+
line: Optional[int]
26+
char: Optional[int]
27+
code: str
28+
severity: LintSeverity
29+
name: str
30+
original: Optional[str]
31+
replacement: Optional[str]
32+
description: Optional[str]
33+
34+
35+
RESULTS_RE: Pattern[str] = re.compile(
36+
r"""(?mx)
37+
^
38+
(?P<file>.*?):
39+
(?P<line>\d+):
40+
(?P<char>\d+):
41+
\s(?P<message>.*)
42+
\s(?P<code>\[.*\])
43+
$
44+
"""
45+
)
46+
47+
48+
def run_command(
49+
args: List[str],
50+
) -> "subprocess.CompletedProcess[bytes]":
51+
logging.debug("$ %s", " ".join(args))
52+
start_time = time.monotonic()
53+
try:
54+
return subprocess.run(
55+
args,
56+
capture_output=True,
57+
)
58+
finally:
59+
end_time = time.monotonic()
60+
logging.debug("took %dms", (end_time - start_time) * 1000)
61+
62+
63+
def check_file(
64+
filename: str,
65+
) -> List[LintMessage]:
66+
with open(filename, 'r') as f:
67+
original = f.read()
68+
original = original.replace('{', '\'{').replace('}', '}\'')
69+
with open(filename, 'w') as f:
70+
f.write(original)
71+
72+
try:
73+
# proc.run_command(sed -i -e "s/'{/{/g" -e "s/}'/}/g")
74+
proc = run_command(
75+
[
76+
"sqlfluff",
77+
"format",
78+
"--dialect",
79+
"clickhouse",
80+
filename,
81+
]
82+
)
83+
except OSError as err:
84+
return [
85+
LintMessage(
86+
path=None,
87+
line=None,
88+
char=None,
89+
code=LINTER_CODE,
90+
severity=LintSeverity.ERROR,
91+
name="command-failed",
92+
original=None,
93+
replacement=None,
94+
description=(f"Failed due to {err.__class__.__name__}:\n{err}"),
95+
)
96+
]
97+
98+
with open(filename, 'r') as f:
99+
final = f.read()
100+
final = final.replace('\'{', '{').replace('}\'', '}')
101+
with open(filename, 'w') as f:
102+
f.write(final)
103+
104+
lint_message = proc.stdout
105+
106+
107+
return [
108+
LintMessage(
109+
path=filename,
110+
line=None,
111+
char=None,
112+
code=LINTER_CODE,
113+
severity=LintSeverity.WARNING,
114+
name="format",
115+
original=None,
116+
replacement=None,
117+
description=lint_message.decode("utf-8"),
118+
)
119+
]
120+
121+
122+
def main() -> None:
123+
parser = argparse.ArgumentParser(
124+
description=f"sqlfluff format linter for sql queries.",
125+
fromfile_prefix_chars="@",
126+
)
127+
parser.add_argument(
128+
"filenames",
129+
nargs="+",
130+
help="paths to lint",
131+
)
132+
133+
args = parser.parse_args()
134+
135+
with concurrent.futures.ThreadPoolExecutor(
136+
max_workers=os.cpu_count(),
137+
thread_name_prefix="Thread",
138+
) as executor:
139+
futures = {
140+
executor.submit(
141+
check_file,
142+
filename,
143+
): filename
144+
for filename in args.filenames
145+
}
146+
for future in concurrent.futures.as_completed(futures):
147+
try:
148+
for lint_message in future.result():
149+
print(json.dumps(lint_message._asdict()), flush=True)
150+
except Exception:
151+
logging.critical('Failed at "%s".', futures[future])
152+
raise
153+
154+
155+
if __name__ == "__main__":
156+
main()

0 commit comments

Comments
 (0)