Skip to content

Commit 9b03095

Browse files
Support data type code conversion (#110)
Signed-off-by: Rami Prilutsky <rami.p@taboola.com>
1 parent 6c89aac commit 9b03095

3 files changed

Lines changed: 70 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
- `get_catalog` macro (#89)
1919
- Multiple indexes issue (#97)
2020
- View rename to `__dbt_backup` fails during `--full-refresh` when upstream view columns changed
21+
- Implement StarRocks DBAPI type-code conversion for dbt query schema inference
2122

2223
## [1.11.0] - 2025-10-16
2324

dbt/adapters/starrocks/connections.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from contextlib import contextmanager
1717

1818
import mysql.connector
19+
from mysql.connector.constants import FieldType
1920

2021
import dbt.exceptions
2122
import dbt_common.exceptions
@@ -28,7 +29,7 @@
2829
)
2930
from dbt.adapters.sql import SQLConnectionManager
3031
from dbt.adapters.events.logging import AdapterLogger
31-
from typing import Optional
32+
from typing import Optional, Union
3233

3334
logger = AdapterLogger("starrocks")
3435

@@ -108,6 +109,44 @@ def _parse_version(result):
108109

109110
class StarRocksConnectionManager(SQLConnectionManager):
110111
TYPE = 'starrocks'
112+
TYPE_CODE_TO_NAME = {
113+
FieldType.DECIMAL: "decimal",
114+
FieldType.NEWDECIMAL: "decimal",
115+
FieldType.TINY: "tinyint",
116+
FieldType.SHORT: "smallint",
117+
FieldType.LONG: "int",
118+
FieldType.INT24: "int",
119+
FieldType.LONGLONG: "bigint",
120+
FieldType.FLOAT: "float",
121+
FieldType.DOUBLE: "double",
122+
FieldType.DATE: "date",
123+
FieldType.DATETIME: "datetime",
124+
FieldType.TIMESTAMP: "datetime",
125+
FieldType.TIME: "varchar",
126+
FieldType.YEAR: "smallint",
127+
FieldType.VARCHAR: "varchar",
128+
FieldType.VAR_STRING: "varchar",
129+
FieldType.STRING: "varchar",
130+
FieldType.BLOB: "varbinary",
131+
FieldType.BIT: "tinyint",
132+
FieldType.JSON: "json",
133+
}
134+
135+
@classmethod
136+
def data_type_code_to_name(cls, type_code: Union[int, str]) -> str:
137+
if isinstance(type_code, str):
138+
return type_code.lower()
139+
140+
data_type = cls.TYPE_CODE_TO_NAME.get(type_code)
141+
if data_type:
142+
return data_type
143+
144+
mysql_type_name = FieldType.get_info(type_code)
145+
if mysql_type_name:
146+
return mysql_type_name.lower()
147+
148+
logger.warning("Unknown StarRocks data type code: %s", type_code)
149+
return str(type_code)
111150

112151
@classmethod
113152
def open(cls, connection):

tests/unit/test_connections.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from mysql.connector.constants import FieldType
2+
3+
from dbt.adapters.starrocks.connections import StarRocksConnectionManager
4+
5+
6+
def test_data_type_code_to_name_maps_mysql_connector_type_codes():
7+
assert StarRocksConnectionManager.data_type_code_to_name(FieldType.DECIMAL) == "decimal"
8+
assert StarRocksConnectionManager.data_type_code_to_name(FieldType.DATE) == "date"
9+
assert StarRocksConnectionManager.data_type_code_to_name(FieldType.DATETIME) == "datetime"
10+
assert StarRocksConnectionManager.data_type_code_to_name(FieldType.TIMESTAMP) == "datetime"
11+
assert StarRocksConnectionManager.data_type_code_to_name(FieldType.LONGLONG) == "bigint"
12+
assert StarRocksConnectionManager.data_type_code_to_name(FieldType.NEWDECIMAL) == "decimal"
13+
assert StarRocksConnectionManager.data_type_code_to_name(FieldType.TIME) == "varchar"
14+
assert StarRocksConnectionManager.data_type_code_to_name(FieldType.YEAR) == "smallint"
15+
assert StarRocksConnectionManager.data_type_code_to_name(FieldType.VAR_STRING) == "varchar"
16+
assert StarRocksConnectionManager.data_type_code_to_name(FieldType.BLOB) == "varbinary"
17+
assert StarRocksConnectionManager.data_type_code_to_name(FieldType.BIT) == "tinyint"
18+
19+
20+
def test_data_type_code_to_name_normalizes_string_type_codes():
21+
assert StarRocksConnectionManager.data_type_code_to_name("VARCHAR") == "varchar"
22+
23+
24+
def test_data_type_code_to_name_falls_back_to_mysql_connector_name():
25+
assert StarRocksConnectionManager.data_type_code_to_name(FieldType.GEOMETRY) == "geometry"
26+
27+
28+
def test_data_type_code_to_name_logs_unknown_numeric_code():
29+
assert StarRocksConnectionManager.data_type_code_to_name(99999) == "99999"

0 commit comments

Comments
 (0)