4141 UTF8 ,
4242 PAYLOAD_JSON_FILE_PATH ,
4343 HTTP_STATUS_OK ,
44+ DATAPROC_SERVICE_NAME ,
4445)
4546from scheduler_jupyter_plugin .models .models import DescribeJob
4647from scheduler_jupyter_plugin .services import airflow
@@ -147,7 +148,50 @@ async def upload_to_gcs(
147148 self .log .exception (f"Error uploading file to GCS: { str (error )} " )
148149 raise IOError (str (error ))
149150
150- def prepare_dag (self , job , gcs_dag_bucket , dag_file , project_id , region_id ):
151+ async def get_cluster_details (self , cluster_name ):
152+ try :
153+ dataproc_url = await urls .gcp_service_url (DATAPROC_SERVICE_NAME )
154+ api_endpoint = f"{ dataproc_url } /v1/projects/{ self .project_id } /regions/{ self .region_id } /clusters/{ cluster_name } "
155+ async with self .client_session .get (
156+ api_endpoint , headers = self .create_headers ()
157+ ) as response :
158+ if response .status == HTTP_STATUS_OK :
159+ resp = await response .json ()
160+ return resp
161+ else :
162+ return {
163+ "error" : f"Failed to fetch clusters: { response .status } { await response .text ()} "
164+ }
165+
166+ except Exception as e :
167+ self .log .exception ("Error fetching cluster list" )
168+ return {"error" : str (e )}
169+
170+ async def multi_tenant_user_service_account (self , cluster_name ):
171+ cluster_data = await self .get_cluster_details (cluster_name )
172+ if cluster_data :
173+ multi_tenant = (
174+ cluster_data .get ("config" , {})
175+ .get ("softwareConfig" , {})
176+ .get ("properties" , {})
177+ .get ("dataproc:dataproc.dynamic.multi.tenancy.enabled" , "false" )
178+ )
179+ if multi_tenant == "true" :
180+ cmd = "config get account"
181+ process = await async_run_gcloud_subcommand (cmd )
182+ user_email = process .strip ()
183+ service_account = (
184+ cluster_data .get ("config" , {})
185+ .get ("securityConfig" , {})
186+ .get ("identityConfig" , {})
187+ .get ("userServiceAccountMapping" , {})
188+ .get (user_email , "" )
189+ )
190+ if service_account :
191+ return service_account
192+ return ""
193+
194+ async def prepare_dag (self , job , gcs_dag_bucket , dag_file , project_id , region_id ):
151195 self .log .info ("Generating dag file" )
152196 DAG_TEMPLATE_CLUSTER_V1 = "pysparkJobTemplate-v1.txt"
153197 DAG_TEMPLATE_SERVERLESS_V1 = "pysparkBatchTemplate-v1.txt"
@@ -181,6 +225,11 @@ def prepare_dag(self, job, gcs_dag_bucket, dag_file, project_id, region_id):
181225 parameters = ""
182226 if job .local_kernel is False :
183227 if job .mode_selected == "cluster" :
228+ multi_tenant_service_account = (
229+ await self .multi_tenant_user_service_account (
230+ cluster_name = job .cluster_name
231+ )
232+ )
184233 template = environment .get_template (DAG_TEMPLATE_CLUSTER_V1 )
185234 if not job .input_filename .startswith (GCS ):
186235 input_notebook = f"gs://{ gcs_dag_bucket } /dataproc-notebooks/{ job .name } /input_notebooks/{ job .input_filename } "
@@ -198,6 +247,7 @@ def prepare_dag(self, job, gcs_dag_bucket, dag_file, project_id, region_id):
198247 start_date = start_date ,
199248 parameters = parameters ,
200249 time_zone = time_zone ,
250+ multi_tenant_service_account = multi_tenant_service_account ,
201251 )
202252 else :
203253 template = environment .get_template (DAG_TEMPLATE_SERVERLESS_V1 )
@@ -402,7 +452,7 @@ async def execute(self, input_data, project_id, region_id):
402452 destination_dir = f"dataproc-notebooks/{ job_name } /dag_details" ,
403453 )
404454
405- file_path = self .prepare_dag (
455+ file_path = await self .prepare_dag (
406456 job , gcs_dag_bucket , dag_file , project_id , region_id
407457 )
408458 await self .upload_to_gcs (
0 commit comments