Skip to content

Commit 766e345

Browse files
authored
disallow name field in standard predictor (kserve#4535)
Signed-off-by: HutakiHare <sunniessheep@gmail.com>
1 parent 4edbb36 commit 766e345

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

pkg/apis/serving/v1beta1/inference_service_validation.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ func validateInferenceService(isvc *InferenceService) (admission.Warnings, error
125125
return allWarnings, err
126126
}
127127

128+
if err := validatePredictor(isvc); err != nil {
129+
return allWarnings, err
130+
}
131+
128132
for _, component := range []Component{
129133
&isvc.Spec.Predictor,
130134
isvc.Spec.Transformer,
@@ -146,6 +150,40 @@ func validateInferenceService(isvc *InferenceService) (admission.Warnings, error
146150
return allWarnings, nil
147151
}
148152

153+
func validatePredictor(isvc *InferenceService) error {
154+
predictor := isvc.Spec.Predictor
155+
156+
// log predictor
157+
validatorLogger.Info("Incoming predictor struct", "predictor", predictor)
158+
159+
// in most of the case, standard predictors will all be packed into `predictor.model`, and decide the backend process through `modelFormat.name``
160+
switch {
161+
case predictor.SKLearn != nil && predictor.SKLearn.Name != "":
162+
return errors.New("the 'name' field is not allowed in standard predictor")
163+
case predictor.XGBoost != nil && predictor.XGBoost.Name != "":
164+
return errors.New("the 'name' field is not allowed in standard predictor")
165+
case predictor.Tensorflow != nil && predictor.Tensorflow.Name != "":
166+
return errors.New("the 'name' field is not allowed in standard predictor")
167+
case predictor.PyTorch != nil && predictor.PyTorch.Name != "":
168+
return errors.New("the 'name' field is not allowed in standard predictor")
169+
case predictor.Triton != nil && predictor.Triton.Name != "":
170+
return errors.New("the 'name' field is not allowed in standard predictor")
171+
case predictor.ONNX != nil && predictor.ONNX.Name != "":
172+
return errors.New("the 'name' field is not allowed in standard predictor")
173+
case predictor.HuggingFace != nil && predictor.HuggingFace.Name != "":
174+
return errors.New("the 'name' field is not allowed in standard predictor")
175+
case predictor.PMML != nil && predictor.PMML.Name != "":
176+
return errors.New("the 'name' field is not allowed in standard predictor")
177+
case predictor.LightGBM != nil && predictor.LightGBM.Name != "":
178+
return errors.New("the 'name' field is not allowed in standard predictor")
179+
case predictor.Paddle != nil && predictor.Paddle.Name != "":
180+
return errors.New("the 'name' field is not allowed in standard predictor")
181+
case predictor.Model != nil && predictor.Model.Name != "":
182+
return errors.New("the 'name' field is not allowed in standard predictor")
183+
}
184+
return nil
185+
}
186+
149187
// validateMultiNodeVariables validates when there is workerSpec set in isvc
150188
func validateMultiNodeVariables(isvc *InferenceService) error {
151189
if isvc.Spec.Predictor.WorkerSpec != nil {

pkg/apis/serving/v1beta1/inference_service_validation_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package v1beta1
1818

1919
import (
2020
"fmt"
21+
"strings"
2122
"testing"
2223

2324
"github.com/kserve/kserve/pkg/constants"
@@ -32,6 +33,31 @@ import (
3233
"k8s.io/utils/ptr"
3334
)
3435

36+
func TestInvalidNameInSKLearnPredictor(t *testing.T) {
37+
isvc := InferenceService{
38+
ObjectMeta: metav1.ObjectMeta{
39+
Name: "test-isvc",
40+
},
41+
Spec: InferenceServiceSpec{
42+
Predictor: PredictorSpec{
43+
SKLearn: &SKLearnSpec{
44+
PredictorExtensionSpec: PredictorExtensionSpec{
45+
Container: corev1.Container{
46+
Name: "invalid-name",
47+
Image: "dummy-image",
48+
},
49+
StorageURI: proto.String("gs://kfserving-examples/models/sklearn/1.0/model"),
50+
},
51+
},
52+
},
53+
},
54+
}
55+
err := validatePredictor(&isvc)
56+
if err == nil || !strings.Contains(err.Error(), "not allowed") {
57+
t.Errorf("Expected error for name field in SKLearn predictor, got: %v", err)
58+
}
59+
}
60+
3561
func makeTestInferenceService() InferenceService {
3662
inferenceservice := InferenceService{
3763
ObjectMeta: metav1.ObjectMeta{

0 commit comments

Comments
 (0)