7
7
import struct
8
8
import sys
9
9
import re
10
+ import requests
10
11
from knack .log import get_logger
11
12
from azure .mgmt .core .tools import parse_resource_id
12
13
from azure .cli .core import telemetry
46
47
# For db(mysqlFlex/psql/psqlFlex/sql) linker with auth type=systemAssignedIdentity, enable Microsoft Entra auth and create db user on data plane
47
48
# For other linker, ignore the steps
48
49
def get_enable_mi_for_db_linker_func (yes = False , new = False ):
49
- def enable_mi_for_db_linker (cmd , source_id , target_id , auth_info , client_type , connection_name , * args , ** kwargs ):
50
+ def enable_mi_for_db_linker (cmd , source_id , target_id , auth_info , client_type , connection_name , connstr_props , * args , ** kwargs ):
50
51
# return if connection is not for db mi
51
52
if auth_info ['auth_type' ] not in [AUTHTYPES [AUTH_TYPE .SystemIdentity ],
52
53
AUTHTYPES [AUTH_TYPE .UserIdentity ],
@@ -61,7 +62,7 @@ def enable_mi_for_db_linker(cmd, source_id, target_id, auth_info, client_type, c
61
62
if source_handler is None :
62
63
return None
63
64
target_handler = getTargetHandler (
64
- cmd , target_id , target_type , auth_info , client_type , connection_name , skip_prompt = yes , new_user = new )
65
+ cmd , target_id , target_type , auth_info , client_type , connection_name , connstr_props , skip_prompt = yes , new_user = new )
65
66
if target_handler is None :
66
67
return None
67
68
target_handler .check_db_existence ()
@@ -88,7 +89,7 @@ def enable_mi_for_db_linker(cmd, source_id, target_id, auth_info, client_type, c
88
89
source_object_id = source_handler .get_identity_pid ()
89
90
target_handler .identity_object_id = source_object_id
90
91
try :
91
- if target_type in [RESOURCE .Sql ]:
92
+ if target_type in [RESOURCE .Sql , RESOURCE . FabricSql ]:
92
93
target_handler .identity_name = source_handler .get_identity_name ()
93
94
elif target_type in [RESOURCE .Postgres , RESOURCE .MysqlFlexible ]:
94
95
identity_info = run_cli_cmd (
@@ -149,7 +150,7 @@ def enable_mi_for_db_linker(cmd, source_id, target_id, auth_info, client_type, c
149
150
150
151
151
152
# pylint: disable=unused-argument, too-many-instance-attributes
152
- def getTargetHandler (cmd , target_id , target_type , auth_info , client_type , connection_name , skip_prompt , new_user ):
153
+ def getTargetHandler (cmd , target_id , target_type , auth_info , client_type , connection_name , connstr_props , skip_prompt , new_user ):
153
154
if target_type in {RESOURCE .Sql }:
154
155
return SqlHandler (cmd , target_id , target_type , auth_info , connection_name , skip_prompt , new_user )
155
156
if target_type in {RESOURCE .Postgres }:
@@ -158,6 +159,8 @@ def getTargetHandler(cmd, target_id, target_type, auth_info, client_type, connec
158
159
return PostgresFlexHandler (cmd , target_id , target_type , auth_info , connection_name , skip_prompt , new_user )
159
160
if target_type in {RESOURCE .MysqlFlexible }:
160
161
return MysqlFlexibleHandler (cmd , target_id , target_type , auth_info , connection_name , skip_prompt , new_user )
162
+ if target_type in {RESOURCE .FabricSql }:
163
+ return FabricSqlHandler (cmd , target_id , target_type , auth_info , connection_name , connstr_props , skip_prompt , new_user )
161
164
return None
162
165
163
166
@@ -960,6 +963,89 @@ def get_create_query(self):
960
963
]
961
964
962
965
966
+ class FabricSqlHandler (SqlHandler ):
967
+ def __init__ (self , cmd , target_id , target_type , auth_info , connection_name , connstr_props , skip_prompt , new_user ):
968
+ super ().__init__ (cmd , target_id , target_type ,
969
+ auth_info , connection_name , skip_prompt , new_user )
970
+
971
+ self .target_id = target_id
972
+
973
+ if not connstr_props :
974
+ raise CLIInternalError ("Missing additional connection string properties for Fabric SQL target." )
975
+
976
+ Server = connstr_props .get ('Server' ) or connstr_props .get ('Data Source' )
977
+ Database = connstr_props .get ('Database' ) or connstr_props .get ('Initial Catalog' )
978
+ if not Server or not Database :
979
+ raise CLIInternalError ("Missing 'Server' or 'Database' in additonal connection string properties keys."
980
+ "Use --connstr_props 'Server=xxx' 'Database=xxx' to provide the values." )
981
+
982
+ # Construct the ODBC connection string
983
+ self .ODBCConnectionString = self .construct_odbc_connection_string (Server , Database )
984
+ logger .warning ("ODBC connection string: %s" , self .ODBCConnectionString )
985
+
986
+ def check_db_existence (self ):
987
+ fabric_token = self .get_fabric_access_token ()
988
+ headers = {"Authorization" : "Bearer {}" .format (fabric_token )}
989
+ response = requests .get (self .target_id , headers = headers )
990
+
991
+ if response :
992
+ response_json = response .json ()
993
+ if response_json ["id" ]:
994
+ return
995
+
996
+ e = ResourceNotFoundError ("No database found with name {}" .format (self .dbname ))
997
+ telemetry .set_exception (e , "No-Db" )
998
+ raise e
999
+
1000
+ def construct_odbc_connection_string (self , server , database ):
1001
+ # Map fields to ODBC fields
1002
+ odbc_dict = {
1003
+ 'Driver' : '{driver}' ,
1004
+ 'Server' : server ,
1005
+ 'Database' : database ,
1006
+ }
1007
+
1008
+ odbc_connection_string = ';' .join ([f'{ key } ={ value } ' for key , value in odbc_dict .items ()])
1009
+ return odbc_connection_string
1010
+
1011
+ def create_aad_user (self ):
1012
+ query_list = self .get_create_query ()
1013
+ connection_args = self .get_connection_string ()
1014
+
1015
+ logger .warning ("Connecting to database..." )
1016
+ self .create_aad_user_in_sql (connection_args , query_list )
1017
+
1018
+ def get_fabric_access_token (self ):
1019
+ return run_cli_cmd ('az account get-access-token --output json --resource https://api.fabric.microsoft.com/' ).get ('accessToken' )
1020
+
1021
+ def set_user_admin (self , user_object_id , ** kwargs ):
1022
+ return
1023
+
1024
+ def get_connection_string (self , dbname = "" ):
1025
+ token_bytes = self .get_fabric_access_token ().encode ('utf-16-le' )
1026
+
1027
+ token_struct = struct .pack (
1028
+ f'<I{ len (token_bytes )} s' , len (token_bytes ), token_bytes )
1029
+ # This connection option is defined by microsoft in msodbcsql.h
1030
+ SQL_COPT_SS_ACCESS_TOKEN = 1256
1031
+ conn_string = self .ODBCConnectionString
1032
+ return {'connection_string' : conn_string , 'attrs_before' : {SQL_COPT_SS_ACCESS_TOKEN : token_struct }}
1033
+
1034
+ def get_create_query (self ):
1035
+ if self .auth_type in [AUTHTYPES [AUTH_TYPE .SystemIdentity ], AUTHTYPES [AUTH_TYPE .UserIdentity ]]:
1036
+ self .aad_username = self .identity_name
1037
+ else :
1038
+ raise CLIInternalError ("Unsupported auth type: " + self .auth_type )
1039
+
1040
+ delete_q = "DROP USER IF EXISTS \" {}\" ;" .format (self .aad_username )
1041
+ role_q = "CREATE USER \" {}\" FROM EXTERNAL PROVIDER;" .format (self .aad_username )
1042
+ grant_q1 = "ALTER ROLE db_datareader ADD MEMBER \" {}\" " .format (self .aad_username )
1043
+ grant_q2 = "ALTER ROLE db_datawriter ADD MEMBER \" {}\" " .format (self .aad_username )
1044
+ grant_q3 = "ALTER ROLE db_ddladmin ADD MEMBER \" {}\" " .format (self .aad_username )
1045
+
1046
+ return [delete_q , role_q , grant_q1 , grant_q2 , grant_q3 ]
1047
+
1048
+
963
1049
def getSourceHandler (source_id , source_type ):
964
1050
if source_type in {RESOURCE .WebApp , RESOURCE .FunctionApp }:
965
1051
return WebappHandler (source_id , source_type )
0 commit comments