Skip to content

Commit 58e97c2

Browse files
set priorityClassName for spark tasks (#7476)
Signed-off-by: madiyar-wayve <madiyar.aitzhanov@wayve.ai>
1 parent 65abdca commit 58e97c2

2 files changed

Lines changed: 45 additions & 0 deletions

File tree

flyteplugins/go/tasks/plugins/k8s/spark/spark.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont
234234
},
235235
}
236236
spec.sparkSpec.ServiceAccount = strPtr(serviceAccountName(nonInterruptibleTaskCtx.TaskExecutionMetadata()))
237+
spec.sparkSpec.PriorityClassName = strPtr(podSpec.PriorityClassName)
237238

238239
if cores, err := strconv.ParseInt(sparkConfig["spark.driver.cores"], 10, 32); err == nil {
239240
spec.sparkSpec.Cores = intPtr(int32(cores))
@@ -288,6 +289,7 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo
288289
},
289290
serviceAccountName,
290291
}
292+
spec.sparkSpec.PriorityClassName = strPtr(podSpec.PriorityClassName)
291293
if execCores, err := strconv.ParseInt(sparkConfig["spark.executor.cores"], 10, 32); err == nil {
292294
spec.sparkSpec.Cores = intPtr(int32(execCores))
293295
}

flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,6 +1058,49 @@ func TestBuildResourcePodTemplate(t *testing.T) {
10581058
assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
10591059
}
10601060

1061+
func TestBuildResourcePriorityClassName(t *testing.T) {
1062+
defaultConfig := defaultPluginConfig()
1063+
assert.NoError(t, config.SetK8sPluginConfig(defaultConfig))
1064+
1065+
const priorityClassName = "high-priority"
1066+
podSpec := dummyPodSpec()
1067+
podSpec.PriorityClassName = priorityClassName
1068+
taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec)
1069+
1070+
sparkResourceHandler := sparkResourceHandler{}
1071+
taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{})
1072+
resource, err := sparkResourceHandler.BuildResource(context.TODO(), taskCtx)
1073+
1074+
assert.Nil(t, err)
1075+
assert.NotNil(t, resource)
1076+
sparkApp, ok := resource.(*sparkOp.SparkApplication)
1077+
assert.True(t, ok)
1078+
1079+
assert.NotNil(t, sparkApp.Spec.Driver.PriorityClassName)
1080+
assert.Equal(t, priorityClassName, *sparkApp.Spec.Driver.PriorityClassName)
1081+
assert.NotNil(t, sparkApp.Spec.Executor.PriorityClassName)
1082+
assert.Equal(t, priorityClassName, *sparkApp.Spec.Executor.PriorityClassName)
1083+
}
1084+
1085+
func TestBuildResourceNoPriorityClassName(t *testing.T) {
1086+
defaultConfig := defaultPluginConfig()
1087+
assert.NoError(t, config.SetK8sPluginConfig(defaultConfig))
1088+
1089+
taskTemplate := dummySparkTaskTemplateContainer("blah-1", dummySparkConf)
1090+
sparkResourceHandler := sparkResourceHandler{}
1091+
taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{})
1092+
resource, err := sparkResourceHandler.BuildResource(context.TODO(), taskCtx)
1093+
1094+
assert.Nil(t, err)
1095+
assert.NotNil(t, resource)
1096+
sparkApp, ok := resource.(*sparkOp.SparkApplication)
1097+
assert.True(t, ok)
1098+
1099+
// When no priority class is set, the field should be left unset (nil) rather than an empty string.
1100+
assert.Nil(t, sparkApp.Spec.Driver.PriorityClassName)
1101+
assert.Nil(t, sparkApp.Spec.Executor.PriorityClassName)
1102+
}
1103+
10611104
func TestGetPropertiesSpark(t *testing.T) {
10621105
sparkResourceHandler := sparkResourceHandler{}
10631106
expected := k8s.PluginProperties{}

0 commit comments

Comments
 (0)