Skip to content

Commit 7e9c709

Browse files
Merge pull request #28 from nirupama-dev/main
Composer Scheduler: Dataproc multi tenant cluster enablement
2 parents fc65034 + 41b562d commit 7e9c709

File tree

6 files changed

+64
-7
lines changed

6 files changed

+64
-7
lines changed

scheduler_jupyter_plugin/dagTemplates/pysparkJobTemplate-v1.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ submit_pyspark_job = DataprocSubmitJobOperator(
157157
'args' : notebook_args
158158
},
159159
},
160+
{% if multi_tenant_service_account %}
161+
impersonation_chain=['{{multi_tenant_service_account}}'],
162+
{% endif %}
160163
gcp_conn_id='google_cloud_default', # Reference to the GCP connection
161164
dag=dag,
162165
)

scheduler_jupyter_plugin/services/executor.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
UTF8,
4242
PAYLOAD_JSON_FILE_PATH,
4343
HTTP_STATUS_OK,
44+
DATAPROC_SERVICE_NAME,
4445
)
4546
from scheduler_jupyter_plugin.models.models import DescribeJob
4647
from scheduler_jupyter_plugin.services import airflow
@@ -147,7 +148,50 @@ async def upload_to_gcs(
147148
self.log.exception(f"Error uploading file to GCS: {str(error)}")
148149
raise IOError(str(error))
149150

150-
def prepare_dag(self, job, gcs_dag_bucket, dag_file, project_id, region_id):
151+
async def get_cluster_details(self, cluster_name):
152+
try:
153+
dataproc_url = await urls.gcp_service_url(DATAPROC_SERVICE_NAME)
154+
api_endpoint = f"{dataproc_url}/v1/projects/{self.project_id}/regions/{self.region_id}/clusters/{cluster_name}"
155+
async with self.client_session.get(
156+
api_endpoint, headers=self.create_headers()
157+
) as response:
158+
if response.status == HTTP_STATUS_OK:
159+
resp = await response.json()
160+
return resp
161+
else:
162+
return {
163+
"error": f"Failed to fetch clusters: {response.status} {await response.text()}"
164+
}
165+
166+
except Exception as e:
167+
self.log.exception("Error fetching cluster list")
168+
return {"error": str(e)}
169+
170+
async def multi_tenant_user_service_account(self, cluster_name):
171+
cluster_data = await self.get_cluster_details(cluster_name)
172+
if cluster_data:
173+
multi_tenant = (
174+
cluster_data.get("config", {})
175+
.get("softwareConfig", {})
176+
.get("properties", {})
177+
.get("dataproc:dataproc.dynamic.multi.tenancy.enabled", "false")
178+
)
179+
if multi_tenant == "true":
180+
cmd = "config get account"
181+
process = await async_run_gcloud_subcommand(cmd)
182+
user_email = process.strip()
183+
service_account = (
184+
cluster_data.get("config", {})
185+
.get("securityConfig", {})
186+
.get("identityConfig", {})
187+
.get("userServiceAccountMapping", {})
188+
.get(user_email, "")
189+
)
190+
if service_account:
191+
return service_account
192+
return ""
193+
194+
async def prepare_dag(self, job, gcs_dag_bucket, dag_file, project_id, region_id):
151195
self.log.info("Generating dag file")
152196
DAG_TEMPLATE_CLUSTER_V1 = "pysparkJobTemplate-v1.txt"
153197
DAG_TEMPLATE_SERVERLESS_V1 = "pysparkBatchTemplate-v1.txt"
@@ -181,6 +225,11 @@ def prepare_dag(self, job, gcs_dag_bucket, dag_file, project_id, region_id):
181225
parameters = ""
182226
if job.local_kernel is False:
183227
if job.mode_selected == "cluster":
228+
multi_tenant_service_account = (
229+
await self.multi_tenant_user_service_account(
230+
cluster_name=job.cluster_name
231+
)
232+
)
184233
template = environment.get_template(DAG_TEMPLATE_CLUSTER_V1)
185234
if not job.input_filename.startswith(GCS):
186235
input_notebook = f"gs://{gcs_dag_bucket}/dataproc-notebooks/{job.name}/input_notebooks/{job.input_filename}"
@@ -198,6 +247,7 @@ def prepare_dag(self, job, gcs_dag_bucket, dag_file, project_id, region_id):
198247
start_date=start_date,
199248
parameters=parameters,
200249
time_zone=time_zone,
250+
multi_tenant_service_account=multi_tenant_service_account,
201251
)
202252
else:
203253
template = environment.get_template(DAG_TEMPLATE_SERVERLESS_V1)
@@ -402,7 +452,7 @@ async def execute(self, input_data, project_id, region_id):
402452
destination_dir=f"dataproc-notebooks/{job_name}/dag_details",
403453
)
404454

405-
file_path = self.prepare_dag(
455+
file_path = await self.prepare_dag(
406456
job, gcs_dag_bucket, dag_file, project_id, region_id
407457
)
408458
await self.upload_to_gcs(

scheduler_jupyter_plugin/tests/test_dataproc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ async def test_list_clusters(monkeypatch, jp_fetch):
3333
payload = json.loads(response.body)
3434
assert (
3535
payload["api_endpoint"]
36-
== f"https://dataproc.googleapis.com//v1/projects/credentials-project/regions/{mock_region_id}/clusters?pageSize={mock_page_size}&pageToken={mock_page_token}"
36+
== f"https://dataproc.googleapis.com//v1/projects/{mock_project_id}/regions/{mock_region_id}/clusters?pageSize={mock_page_size}&pageToken={mock_page_token}"
3737
)
3838
assert payload["headers"]["Authorization"] == f"Bearer mock-token"
3939

src/controls/RegionDropdown.tsx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ type Props = {
3838
/** Initial loading flag for region */
3939
loaderRegion?: boolean;
4040
setLoaderRegion?: (value: boolean) => void;
41+
label?: string;
4142
};
4243

4344
/**
@@ -53,7 +54,8 @@ export function RegionDropdown(props: Props) {
5354
regionDisable,
5455
fromPage,
5556
loaderRegion,
56-
setLoaderRegion
57+
setLoaderRegion,
58+
label
5759
} = props;
5860
let regionStrList: string[] = [];
5961

@@ -77,7 +79,7 @@ export function RegionDropdown(props: Props) {
7779
renderInput={params => (
7880
<TextField
7981
{...params}
80-
label={'Region*'}
82+
label={label || 'Region*'}
8183
InputProps={{
8284
...params.InputProps,
8385
endAdornment: (

src/scheduler/composer/CreateNotebookScheduler.tsx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ const CreateNotebookScheduler = ({
717717
setProjectId(projectId ?? '')
718718
}
719719
fetchFunc={projectListAPI}
720-
label="Project ID*"
720+
label="Composer Project ID*"
721721
// Always show the clear indicator and hide the dropdown arrow
722722
// make it very clear that this is an autocomplete.
723723
sx={{
@@ -740,6 +740,7 @@ const CreateNotebookScheduler = ({
740740
editMode={editMode}
741741
loaderRegion={loaderRegion}
742742
setLoaderRegion={setLoaderRegion}
743+
label={'Composer Region*'}
743744
/>
744745
</div>
745746
{!region && <ErrorMessage message="Region is required" />}

src/scheduler/composer/ListNotebookScheduler.tsx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ function ListNotebookScheduler({
829829
handleProjectIdChange(projectId);
830830
}}
831831
fetchFunc={projectListAPI}
832-
label="Project ID*"
832+
label="Composer Project ID*"
833833
// Always show the clear indicator and hide the dropdown arrow
834834
// make it very clear that this is an autocomplete.
835835
sx={{
@@ -855,6 +855,7 @@ function ListNotebookScheduler({
855855
onRegionChange={region => handleRegionChange(region)}
856856
loaderRegion={loaderRegion}
857857
setLoaderRegion={setLoaderRegion}
858+
label={'Composer Region*'}
858859
/>
859860
</div>
860861
{!region && (

0 commit comments

Comments (0)