@@ -1381,28 +1381,38 @@ def gen_expr(depth: int) -> exp.Expression:
1381
1381
def test_spark_annotators (self ):
1382
1382
"""Test Spark annotators, mainly built-in string/binary functions"""
1383
1383
1384
- schema = {"tbl" : {"bin_col" : "BINARY" , "str_col" : "STRING" }}
1384
+ spark_schema = {"tbl" : {"bin_col" : "BINARY" , "str_col" : "STRING" }}
1385
+
1386
+ from sqlglot .dialects import Dialect
1385
1387
1386
1388
def _assert_func_return_type (func : str , dialect : str , target_type : str ):
1387
1389
ast = parse_one (f"SELECT { func } FROM tbl" , read = dialect )
1388
- optimized = optimizer .optimize (ast , schema = schema , dialect = dialect )
1390
+ annotators = Dialect .get_or_raise (dialect ).ANNOTATORS
1391
+ annotated = annotate_types (ast , annotators = annotators , schema = spark_schema )
1392
+
1389
1393
self .assertEqual (
1390
- optimized .expressions [0 ].type .sql (dialect ),
1394
+ annotated .expressions [0 ].type .sql (dialect ),
1391
1395
exp .DataType .build (target_type ).sql (dialect ),
1392
1396
)
1393
1397
1398
+ str_col , bin_col = "tbl.str_col" , "tbl.bin_col"
1399
+
1394
1400
# In Spark hierarchy, SUBSTRING result type is dependent on input expr type
1395
1401
for dialect in ("spark2" , "spark" , "databricks" ):
1396
- _assert_func_return_type ("SUBSTRING(str_col, 0, 0)" , dialect , "STRING" )
1397
- _assert_func_return_type ("SUBSTRING(bin_col, 0, 0)" , dialect , "BINARY" )
1402
+ _assert_func_return_type (f"SUBSTRING({ str_col } , 0, 0)" , dialect , "STRING" )
1403
+ _assert_func_return_type (f"SUBSTRING({ bin_col } , 0, 0)" , dialect , "BINARY" )
1404
+
1405
+ _assert_func_return_type (f"CONCAT({ bin_col } , { bin_col } )" , dialect , "BINARY" )
1406
+ _assert_func_return_type (f"CONCAT({ bin_col } , { str_col } )" , dialect , "STRING" )
1407
+ _assert_func_return_type (f"CONCAT({ str_col } , { bin_col } )" , dialect , "STRING" )
1408
+ _assert_func_return_type (f"CONCAT({ str_col } , { str_col } )" , dialect , "STRING" )
1398
1409
1399
- _assert_func_return_type ("CONCAT(bin_col, bin_col)" , dialect , "BINARY" )
1400
- _assert_func_return_type ("CONCAT(bin_col, str_col)" , dialect , "STRING" )
1401
- _assert_func_return_type ("CONCAT(str_col, bin_col)" , dialect , "STRING" )
1402
- _assert_func_return_type ("CONCAT(str_col, str_col)" , dialect , "STRING" )
1410
+ _assert_func_return_type (f"CONCAT({ str_col } , foo)" , dialect , "STRING" )
1411
+ _assert_func_return_type (f"CONCAT({ bin_col } , bar)" , dialect , "UNKNOWN" )
1412
+ _assert_func_return_type ("CONCAT(foo, bar)" , dialect , "UNKNOWN" )
1403
1413
1404
1414
for func in ("LPAD" , "RPAD" ):
1405
- _assert_func_return_type (f"{ func } (bin_col, 1, bin_col)" , dialect , "BINARY" )
1406
- _assert_func_return_type (f"{ func } (bin_col, 1, str_col)" , dialect , "STRING" )
1407
- _assert_func_return_type (f"{ func } (str_col, 1, bin_col)" , dialect , "STRING" )
1408
- _assert_func_return_type (f"{ func } (str_col, 1, str_col)" , dialect , "STRING" )
1415
+ _assert_func_return_type (f"{ func } ({ bin_col } , 1, { bin_col } )" , dialect , "BINARY" )
1416
+ _assert_func_return_type (f"{ func } ({ bin_col } , 1, { str_col } )" , dialect , "STRING" )
1417
+ _assert_func_return_type (f"{ func } ({ str_col } , 1, { bin_col } )" , dialect , "STRING" )
1418
+ _assert_func_return_type (f"{ func } ({ str_col } , 1, { str_col } )" , dialect , "STRING" )
0 commit comments