|
16 | 16 | from contextlib import contextmanager |
17 | 17 |
|
18 | 18 | import mysql.connector |
| 19 | +from mysql.connector.constants import FieldType |
19 | 20 |
|
20 | 21 | import dbt.exceptions |
21 | 22 | import dbt_common.exceptions |
|
28 | 29 | ) |
29 | 30 | from dbt.adapters.sql import SQLConnectionManager |
30 | 31 | from dbt.adapters.events.logging import AdapterLogger |
31 | | -from typing import Optional |
| 32 | +from typing import Optional, Union |
32 | 33 |
|
33 | 34 | logger = AdapterLogger("starrocks") |
34 | 35 |
|
@@ -108,6 +109,44 @@ def _parse_version(result): |
108 | 109 |
|
109 | 110 | class StarRocksConnectionManager(SQLConnectionManager): |
110 | 111 | TYPE = 'starrocks' |
| 112 | + TYPE_CODE_TO_NAME = { |
| 113 | + FieldType.DECIMAL: "decimal", |
| 114 | + FieldType.NEWDECIMAL: "decimal", |
| 115 | + FieldType.TINY: "tinyint", |
| 116 | + FieldType.SHORT: "smallint", |
| 117 | + FieldType.LONG: "int", |
| 118 | + FieldType.INT24: "int", |
| 119 | + FieldType.LONGLONG: "bigint", |
| 120 | + FieldType.FLOAT: "float", |
| 121 | + FieldType.DOUBLE: "double", |
| 122 | + FieldType.DATE: "date", |
| 123 | + FieldType.DATETIME: "datetime", |
| 124 | + FieldType.TIMESTAMP: "datetime", |
| 125 | + FieldType.TIME: "varchar", |
| 126 | + FieldType.YEAR: "smallint", |
| 127 | + FieldType.VARCHAR: "varchar", |
| 128 | + FieldType.VAR_STRING: "varchar", |
| 129 | + FieldType.STRING: "varchar", |
| 130 | + FieldType.BLOB: "varbinary", |
| 131 | + FieldType.BIT: "tinyint", |
| 132 | + FieldType.JSON: "json", |
| 133 | + } |
| 134 | + |
| 135 | + @classmethod |
| 136 | + def data_type_code_to_name(cls, type_code: Union[int, str]) -> str: |
| 137 | + if isinstance(type_code, str): |
| 138 | + return type_code.lower() |
| 139 | + |
| 140 | + data_type = cls.TYPE_CODE_TO_NAME.get(type_code) |
| 141 | + if data_type: |
| 142 | + return data_type |
| 143 | + |
| 144 | + mysql_type_name = FieldType.get_info(type_code) |
| 145 | + if mysql_type_name: |
| 146 | + return mysql_type_name.lower() |
| 147 | + |
| 148 | + logger.warning("Unknown StarRocks data type code: %s", type_code) |
| 149 | + return str(type_code) |
111 | 150 |
|
112 | 151 | @classmethod |
113 | 152 | def open(cls, connection): |
|
0 commit comments