4545)
4646from great_expectations .core .expectation_suite import ExpectationSuite
4747from great_expectations .core .validation_definition import ValidationDefinition
48+ from great_expectations .datasource .fluent .sql_server_datasource import (
49+ SQLServerAuthConnectionDetails ,
50+ )
4851from great_expectations .execution_engine .sqlalchemy_dialect import (
4952 DIALECT_IDENTIFIER_QUOTE_STRINGS ,
5053 GXSqlDialect ,
6467 SQLDatasource ,
6568 SqliteDatasource ,
6669 )
70+ from great_expectations .datasource .fluent .sql_server_datasource import (
71+ SQLServerDatasource ,
72+ )
6773 from great_expectations .execution_engine import SqlAlchemyExecutionEngine
6874
6975TERMINAL_WIDTH : Final = shutil .get_terminal_size ().columns
8389# sqlite db files should be using fresh tmp_path on every test
8490DO_NOT_DROP_TABLES : set [str ] = {"sqlite" }
8591
86- DatabaseType : TypeAlias = Literal ["postgres" , "sqlite" , "trino" ]
92+ DatabaseType : TypeAlias = Literal ["postgres" , "sqlite" , "trino" , "mssql" ]
8793TableNameCase : TypeAlias = Literal [
8894 "quoted_lower" ,
8995 "quoted_mixed" ,
122128 "quoted_mixed" : f'"{ TEST_TABLE_NAME .title ()} "' ,
123129 "unquoted_mixed" : TEST_TABLE_NAME .title (),
124130 },
131+ "mssql" : {
132+ "unquoted_lower" : TEST_TABLE_NAME .lower (),
133+ "quoted_lower" : f"[{ TEST_TABLE_NAME .lower ()} ]" ,
134+ "unquoted_upper" : TEST_TABLE_NAME .upper (),
135+ "quoted_upper" : f"[{ TEST_TABLE_NAME .upper ()} ]" ,
136+ "quoted_mixed" : f"[{ TEST_TABLE_NAME .title ()} ]" ,
137+ "unquoted_mixed" : TEST_TABLE_NAME .title (),
138+ },
125139}
126140
127141# column names
@@ -343,6 +357,24 @@ def __call__(
343357 ) -> None : ...
344358
345359
360+ def _create_schema_ddl (schema : str , is_mssql : bool ) -> str :
361+ if is_mssql :
362+ return (
363+ f"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{ schema } ')"
364+ f" EXEC('CREATE SCHEMA { schema } ')"
365+ )
366+ return f"CREATE SCHEMA IF NOT EXISTS { schema } "
367+
368+
369+ def _create_table_ddl (qualified_table_name : str , table_columns : str , is_mssql : bool ) -> str :
370+ if is_mssql :
371+ return (
372+ f"IF OBJECT_ID(N'{ qualified_table_name } ', N'U') IS NULL"
373+ f" CREATE TABLE { qualified_table_name } { table_columns } "
374+ )
375+ return f"CREATE TABLE IF NOT EXISTS { qualified_table_name } { table_columns } "
376+
377+
346378@pytest .fixture (
347379 scope = "class" ,
348380)
@@ -382,6 +414,7 @@ def _table_factory(
382414 )
383415 dialect = GXSqlDialect (sa_engine .dialect .name )
384416 created_tables : list [dict [Literal ["table_name" , "schema" ], str | None ]] = []
417+ is_mssql = dialect == GXSqlDialect .MSSQL
385418
386419 with gx_engine .get_connection () as conn :
387420 quoted_upper_col : str = quote_str (QUOTED_UPPER_COL , dialect = dialect )
@@ -390,18 +423,18 @@ def _table_factory(
390423 quoted_mixed_case : str = quote_str (QUOTED_MIXED_CASE , dialect = dialect )
391424
392425 if schema :
393- conn .execute (TextClause (f"CREATE SCHEMA IF NOT EXISTS { schema } " ))
426+ conn .execute (TextClause (_create_schema_ddl ( schema , is_mssql ) ))
394427 for name in table_names :
395428 qualified_table_name = f"{ schema } .{ name } " if schema else name
396- # TODO: use dialect specific quotes
397- create_tables : str = (
398- f"CREATE TABLE IF NOT EXISTS { qualified_table_name } "
429+ table_columns : str = (
399430 " (id INTEGER, name VARCHAR(255),"
400431 f" { quoted_upper_col } VARCHAR(255), { quoted_lower_col } VARCHAR(255),"
401432 f" { UNQUOTED_UPPER_COL } VARCHAR(255), { UNQUOTED_LOWER_COL } VARCHAR(255),"
402433 f" { quoted_mixed_case } VARCHAR(255), { quoted_w_dots } VARCHAR(255))"
403434 )
404- conn .execute (TextClause (create_tables ))
435+ conn .execute (
436+ TextClause (_create_table_ddl (qualified_table_name , table_columns , is_mssql ))
437+ )
405438 if data :
406439 insert_data = (
407440 f"INSERT INTO { qualified_table_name } "
@@ -470,6 +503,24 @@ def sqlite_ds(context: EphemeralDataContext, tmp_path: pathlib.Path) -> SqliteDa
470503 return ds
471504
472505
506+ @pytest .fixture
507+ def mssql_ds (context : EphemeralDataContext ) -> SQLServerDatasource :
508+ ds = context .data_sources .add_sql_server (
509+ "mssql" ,
510+ connection_string = SQLServerAuthConnectionDetails (
511+ host = "127.0.0.1" ,
512+ port = 1433 ,
513+ database = "test_ci" ,
514+ schema = "dbo" ,
515+ username = "sa" ,
516+ password = "ReallyStrongPwd1234%^&*" ,
517+ driver = "ODBC Driver 18 for SQL Server" ,
518+ encrypt = "Optional" ,
519+ ),
520+ )
521+ return ds
522+
523+
473524@pytest .fixture (
474525 params = [
475526 param (
@@ -551,6 +602,24 @@ def test_sqlite(
551602
552603 sqlite_ds .add_table_asset (asset_name , table_name = table_name )
553604
605+ @pytest .mark .mssql
606+ def test_mssql (
607+ self ,
608+ mssql_ds : SQLServerDatasource ,
609+ asset_name : TableNameCase ,
610+ table_factory : TableFactory ,
611+ ):
612+ table_name = TABLE_NAME_MAPPING ["mssql" ].get (asset_name )
613+ if not table_name :
614+ pytest .skip (f"no '{ asset_name } ' table_name for mssql" )
615+ # create table
616+ table_factory (gx_engine = mssql_ds .get_execution_engine (), table_names = {table_name })
617+
618+ table_names : list [str ] = inspect (mssql_ds .get_engine ()).get_table_names ()
619+ print (f"mssql tables:\n { pf (table_names )} ))" )
620+
621+ mssql_ds .add_table_asset (asset_name , table_name = table_name )
622+
554623 @pytest .mark .filterwarnings ( # snowflake `add_table_asset` raises warning on passing a schema
555624 "once::great_expectations.datasource.fluent.GxDatasourceWarning"
556625 )
@@ -560,6 +629,7 @@ def test_sqlite(
560629 param ("trino" , None , marks = [pytest .mark .trino ]),
561630 param ("postgres" , None , marks = [pytest .mark .postgresql ]),
562631 param ("sqlite" , None , marks = [pytest .mark .sqlite ]),
632+ param ("mssql" , None , marks = [pytest .mark .mssql ]),
563633 ],
564634 )
565635 def test_checkpoint_run (
@@ -625,10 +695,10 @@ def _is_quote_char_dialect_mismatch(
625695 dialect : GXSqlDialect ,
626696 column_name : str | quoted_name ,
627697) -> bool :
628- quote_char = column_name [0 ] if column_name [0 ] in ("'" , '"' , "`" ) else None
698+ quote_char = column_name [0 ] if column_name [0 ] in ("'" , '"' , "`" , "[" ) else None
629699 if quote_char :
630- dialect_quote_char = DIALECT_IDENTIFIER_QUOTE_STRINGS [dialect ]
631- if quote_char != dialect_quote_char :
700+ expected = DIALECT_IDENTIFIER_QUOTE_STRINGS [dialect ][ 0 ]
701+ if quote_char != expected :
632702 return True
633703 return False
634704
@@ -639,8 +709,12 @@ def _raw_query_check_column_exists(
639709 gx_execution_engine : SqlAlchemyExecutionEngine ,
640710) -> bool :
641711 """Use a simple 'SELECT {column_name_param} from {qualified_table_name};' query to check if the column exists.'""" # noqa: E501 # FIXME CoP
642- with gx_execution_engine .get_connection () as connection :
712+ dialect_name = gx_execution_engine .engine .dialect .name
713+ if dialect_name == "mssql" :
714+ query = f"""SELECT TOP 1 { column_name_param } FROM { qualified_table_name } ;"""
715+ else :
643716 query = f"""SELECT { column_name_param } FROM { qualified_table_name } LIMIT 1;"""
717+ with gx_execution_engine .get_connection () as connection :
644718 print (f"query:\n { query } " )
645719 # an exception will be raised if the column does not exist
646720 try :
0 commit comments