Skip to content

Commit 2d272f2

Browse files
Grain-YugrainyuLeoQuote
authored
新增对doris的支持 (#2536)
* 新增对doris的支持 支持doris的查询、上线审核 相关讨论:#2175 * 调整继承关系 改为从MysqlEngine类继承 * 删去重复方法 * 用black处理 * 删除重复函数 * reuse get_all_databases --------- Co-authored-by: grainyu <[email protected]> Co-authored-by: Leo Q <[email protected]>
1 parent 033964c commit 2d272f2

File tree

7 files changed

+257
-6
lines changed

7 files changed

+257
-6
lines changed

archery/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"phoenix",
5353
"odps",
5454
"cassandra",
55+
"doris",
5556
],
5657
),
5758
ENABLED_NOTIFIERS=(
@@ -99,6 +100,7 @@
99100
"mongo": {"path": "sql.engines.mongo:MongoEngine"},
100101
"phoenix": {"path": "sql.engines.phoenix:PhoenixEngine"},
101102
"odps": {"path": "sql.engines.odps:ODPSEngine"},
103+
"doris": {"path": "sql.engines.doris:DorisEngine"},
102104
}
103105

104106
ENABLED_NOTIFIERS = env("ENABLED_NOTIFIERS")

sql/engines/doris.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# -*- coding: UTF-8 -*-
2+
from sql.utils.sql_utils import get_syntax_type, remove_comments
3+
from sql.engines.mysql import MysqlEngine
4+
from .models import ResultSet, ReviewResult, ReviewSet
5+
from common.utils.timer import FuncTimer
6+
from common.config import SysConfig
7+
from MySQLdb.constants import FIELD_TYPE
8+
import traceback
9+
import MySQLdb
10+
import pymysql
11+
import sqlparse
12+
import logging
13+
import re
14+
15+
16+
# Module-level logger shared by this engine; obtained under the name "default".
logger = logging.getLogger("default")
17+
18+
19+
class DorisEngine(MysqlEngine):
    """Engine for Doris instances, implemented as a subclass of MysqlEngine.

    Connections are obtained through the inherited ``get_connection``; this
    class overrides version detection, querying, the pre-execution check and
    statement execution.
    """

    name = "Doris"
    info = "Doris engine"

    # NOTE(review): presumably read by the workflow layer to decide whether a
    # backup is taken after execution -- confirm against the callers.
    auto_backup = False

    # Database names filtered out of the list built by get_all_databases().
    forbidden_databases = [
        "__internal_schema",
        "INFORMATION_SCHEMA",
        "information_schema",
    ]

    @property
    def server_version(self):
        """Return the server version as a tuple of ints, e.g. ``(2, 1, 0)``.

        The version is read from the first row of ``show frontends``. When the
        result carries column names, the ``Version`` column is located by name;
        otherwise the last column is used. The first dotted number found in
        that value is parsed, so both ``2.1.0-doris`` and a prefixed form such
        as ``doris-2.0.3-rc06`` yield a version. At most three components are
        returned.

        Returns:
            A tuple of up to three ints, or an empty tuple when the query
            returned no rows or no version number could be found.
        """
        result = self.query(sql="show frontends")
        if not result.rows:
            return tuple()
        frontend = result.rows[0]
        if not frontend:
            return tuple()

        # Prefer a lookup by column name: the position of the version column
        # is not guaranteed to be last (TODO confirm per Doris release).
        column_names = [
            str(col).lower() for col in (getattr(result, "column_list", None) or [])
        ]
        version_idx = -1
        if "version" in column_names:
            candidate = column_names.index("version")
            if candidate < len(frontend):
                version_idx = candidate

        matched = re.search(r"\d+(?:\.\d+)*", str(frontend[version_idx]))
        if matched is None:
            return tuple()
        return tuple(int(part) for part in matched.group(0).split(".")[:3])

    def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
        """Run ``sql`` and return a ResultSet.

        Args:
            db_name: database passed to ``get_connection``.
            sql: statement to execute.
            limit_num: when greater than 0, at most this many rows are fetched;
                otherwise all rows are fetched.
            close_conn: close the engine connection before returning.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            A ResultSet with ``column_list``, ``rows`` and ``affected_rows``
            filled on success, or ``error`` filled on failure. Exceptions are
            logged and reported through ``error`` rather than raised.
        """
        result_set = ResultSet(full_sql=sql)
        try:
            conn = self.get_connection(db_name=db_name)
            cursor = conn.cursor()
            cursor.execute(sql)
            limit = int(limit_num)
            rows = cursor.fetchmany(size=limit) if limit > 0 else cursor.fetchall()
            fields = cursor.description

            result_set.column_list = [i[0] for i in fields] if fields else []
            result_set.rows = rows
            result_set.affected_rows = len(rows)
        except Exception as e:
            logger.warning(f"Doris语句执行报错,语句:{sql},错误信息{e}")
            # Keep only the text before any "Stack trace" section of the error.
            result_set.error = str(e).split("Stack trace")[0]
        finally:
            if close_conn:
                self.close()
        return result_set

    def execute_check(self, db_name=None, sql=""):
        """Check a workflow's SQL before execution and return a ReviewSet.

        Each statement produced by ``sqlparse.split`` is classified, after
        comment removal, as one of:

        * rejected: starts with ``select``, ``show`` or ``explain``;
        * rejected: matches the configured ``critical_ddl_regex``;
        * rejected: an ``update``/``delete`` that contains no ``where``;
        * accepted otherwise.

        All rules are evaluated against the stripped, lower-cased statement,
        and the missing-``where`` rule spans line breaks so that multi-line
        statements are checked as a whole.

        Args:
            db_name: unused by this implementation.
            sql: the full SQL text of the workflow.

        Returns:
            A ReviewSet with one ReviewResult per statement, ``syntax_type``
            set to 1 if any statement is DDL (otherwise 2), and the warning
            and error counters filled in.
        """
        check_result = ReviewSet(full_sql=sql)
        # Guard against a missing/None config value before compiling.
        critical_ddl_regex = self.config.get("critical_ddl_regex", "") or ""
        critical_pattern = (
            re.compile(critical_ddl_regex) if critical_ddl_regex else None
        )
        query_pattern = re.compile(r"^select|^show|^explain")
        # DOTALL lets "." cross newlines; without it a multi-line update/delete
        # lacking a where clause would never match and would slip through.
        no_where_pattern = re.compile(
            r"^update((?!where).)*$|^delete((?!where).)*$", re.DOTALL
        )
        check_result.syntax_type = 2  # TODO workflow type: 0 other, 1 DDL, 2 DML

        for line, statement in enumerate(sqlparse.split(sql), start=1):
            statement = sqlparse.format(statement, strip_comments=True)
            # Comment removal can leave surrounding whitespace behind.
            normalized = statement.strip().lower()
            if query_pattern.match(normalized):
                # Unsupported statement type.
                result = ReviewResult(
                    id=line,
                    errlevel=2,
                    stagestatus="驳回不支持语句",
                    errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!",
                    sql=statement,
                )
            elif critical_pattern is not None and critical_pattern.match(normalized):
                # High-risk statement according to the configured pattern.
                result = ReviewResult(
                    id=line,
                    errlevel=2,
                    stagestatus="驳回高危SQL",
                    errormessage="禁止提交匹配" + critical_ddl_regex + "条件的语句!",
                    sql=statement,
                )
            elif no_where_pattern.match(normalized):
                # Data modification without a where clause; a deliberate
                # full-table change must spell out "where 1=1".
                result = ReviewResult(
                    id=line,
                    errlevel=2,
                    stagestatus="驳回未带where数据修改",
                    errormessage="数据修改需带where条件!",
                    sql=statement,
                )
            else:
                result = ReviewResult(
                    id=line,
                    errlevel=0,
                    stagestatus="Audit completed",
                    errormessage="None",
                    sql=statement,
                    affected_rows=0,
                    execute_time=0,
                )
            # A single DDL statement marks the whole workflow as DDL.
            if get_syntax_type(statement) == "DDL":
                check_result.syntax_type = 1
            check_result.rows.append(result)

        # Tally warnings and errors.
        for r in check_result.rows:
            if r.errlevel == 1:
                check_result.warning_count += 1
            if r.errlevel == 2:
                check_result.error_count += 1
        return check_result

    def execute_workflow(self, workflow):
        """Execute the SQL content of ``workflow`` and return a ReviewSet."""
        return self.execute(
            db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content
        )

    def execute(self, db_name=None, sql="", close_conn=True):
        """Execute the statements in ``sql`` one by one and return a ReviewSet.

        Execution stops at the first failing statement. The failing statement
        is recorded with ``errlevel=2`` and the exception text; every
        statement after it is recorded as not executed.

        Args:
            db_name: database passed to ``get_connection``.
            sql: SQL text, split into statements with ``sqlparse.split``.
            close_conn: close the engine connection before returning.

        Returns:
            A ReviewSet with one ReviewResult per statement; ``error`` is set
            to the exception text when a statement failed.

        Raises:
            Whatever ``get_connection`` raises; connection errors are not
            caught here.
        """
        execute_result = ReviewSet(full_sql=sql)
        conn = self.get_connection(db_name=db_name)
        sql_list = sqlparse.split(sql)
        failed_index = None

        for index, statement in enumerate(sql_list):
            rowid = index + 1
            # Reset per statement so a failure never reports the affected rows
            # or timing of an earlier statement.
            effect_row = 0
            timer = None
            try:
                cursor = conn.cursor()
                with FuncTimer() as timer:
                    effect_row = cursor.execute(statement)
                    cursor.close()
                execute_result.rows.append(
                    ReviewResult(
                        id=rowid,
                        errlevel=0,
                        stagestatus="Execute Successfully",
                        errormessage="None",
                        sql=statement,
                        affected_rows=effect_row,
                        execute_time=timer.cost,
                    )
                )
            except Exception as e:
                logger.warning(
                    f"{self.name} 命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}"
                )
                execute_result.error = str(e)
                execute_result.rows.append(
                    ReviewResult(
                        id=rowid,
                        errlevel=2,
                        stagestatus="Execute Failed",
                        errormessage=f"异常信息:{e}",
                        sql=statement,
                        affected_rows=effect_row,
                        # The timer may not exist yet if creating the cursor
                        # failed; fall back to 0 instead of raising again.
                        execute_time=getattr(timer, "cost", 0),
                    )
                )
                failed_index = index
                break

        if failed_index is not None:
            # Record the statements that were skipped because of the failure.
            skipped = sql_list[failed_index + 1 :]
            for rowid, statement in enumerate(skipped, start=failed_index + 2):
                execute_result.rows.append(
                    ReviewResult(
                        id=rowid,
                        errlevel=2,
                        stagestatus="Audit Completed",
                        errormessage="前序语句失败, 未执行",
                        sql=statement,
                        affected_rows=0,
                        execute_time=0,
                    )
                )
        if close_conn:
            self.close()
        return execute_result

sql/engines/mysql.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,24 +178,32 @@ def kill_connection(self, thread_id):
178178
"""终止数据库连接"""
179179
self.query(sql=f"kill {thread_id}")
180180

181+
# 禁止查询的数据库
182+
forbidden_databases = [
183+
"information_schema",
184+
"performance_schema",
185+
"mysql",
186+
"test",
187+
"sys",
188+
]
189+
181190
def get_all_databases(self):
182191
"""获取数据库列表, 返回一个ResultSet"""
183192
sql = "show databases"
184193
result = self.query(sql=sql)
185194
db_list = [
186-
row[0]
187-
for row in result.rows
188-
if row[0]
189-
not in ("information_schema", "performance_schema", "mysql", "test", "sys")
195+
row[0] for row in result.rows if row[0] not in self.forbidden_databases
190196
]
191197
result.rows = db_list
192198
return result
193199

200+
forbidden_tables = ["test"]
201+
194202
def get_all_tables(self, db_name, **kwargs):
195203
"""获取table 列表, 返回一个ResultSet"""
196204
sql = "show tables"
197205
result = self.query(db_name=db_name, sql=sql)
198-
tb_list = [row[0] for row in result.rows if row[0] not in ["test"]]
206+
tb_list = [row[0] for row in result.rows if row[0] not in self.forbidden_tables]
199207
result.rows = tb_list
200208
return result
201209

sql/engines/test_doris.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from pytest_mock import MockFixture
2+
3+
from sql.engines.doris import DorisEngine
4+
from sql.engines.models import ResultSet
5+
6+
7+
def test_doris_server_info(db_instance, mocker: MockFixture):
    """server_version turns the mocked ``show frontends`` row into (2, 1, 0)."""
    frontends = ResultSet(
        full_sql="show frontends", rows=[["foo", "bar", "2.1.0-doris"]]
    )
    mocker.patch.object(DorisEngine, "query", return_value=frontends)
    db_instance.db_type = "doris"

    doris = DorisEngine(instance=db_instance)

    assert doris.server_version == (2, 1, 0)
16+
17+
18+
def test_doris_query(db_instance, mocker: MockFixture):
    """query() copies column names and rows from the cursor into the ResultSet."""

    class DummyCursor:
        """Minimal cursor stand-in exposing only what query() touches."""

        description = [("foo",), ("bar",)]

        def execute(self, sql):
            pass

        def fetchall(self):
            return [("baz", "qux")]

    patched_connection = mocker.patch.object(DorisEngine, "get_connection")
    patched_connection.return_value.cursor.return_value = DummyCursor()
    db_instance.db_type = "doris"

    outcome = DorisEngine(instance=db_instance).query(sql="select * from foo")

    assert outcome.column_list == ["foo", "bar"]
    assert outcome.rows == [("baz", "qux")]
    assert outcome.affected_rows == 1
36+
37+
38+
def test_forbidden_db(db_instance, mocker: MockFixture):
    """get_all_databases() drops a database listed in forbidden_databases."""
    db_instance.db_type = "doris"
    listing = ResultSet(full_sql="show databases", rows=[["__internal_schema"]])
    mocker.patch.object(DorisEngine, "query", return_value=listing)

    databases = DorisEngine(instance=db_instance).get_all_databases()

    assert databases.rows == []

sql/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ class Meta:
133133
("clickhouse", "ClickHouse"),
134134
("goinception", "goInception"),
135135
("cassandra", "Cassandra"),
136+
("doris", "Doris"),
136137
)
137138

138139

sql/templates/sqlquery.html

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,6 +979,11 @@ <h4 class="modal-title text-danger">收藏语句</h4>
979979
if (sql === 'explain') {
980980
sqlContent = 'explain ' + sqlContent
981981
}
982+
} else if (optgroup === "Doris") {
983+
//查看执行计划
984+
if (sql === 'explain') {
985+
sqlContent = 'explain ' + sqlContent
986+
}
982987
}
983988
//提交请求
984989
$.ajax({

sql/views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def config(request):
503503
# 获取所有实例标签
504504
instance_tags = InstanceTag.objects.all()
505505
# 支持自动审核的数据库类型
506-
db_type = ["mysql", "oracle", "mongo", "clickhouse", "redis"]
506+
db_type = ["mysql", "oracle", "mongo", "clickhouse", "redis", "doris"]
507507
# 获取所有配置项
508508
all_config = Config.objects.all().values("item", "value")
509509
sys_config = {}

0 commit comments

Comments
 (0)