Skip to content

Commit 6fac083

Browse files
authored
Resolve inference endpoint using runtime protocol when applicable (kserve#4527)
Signed-off-by: Edgar Hernández <23639005+israel-hdez@users.noreply.github.com>
1 parent 6530d1e commit 6fac083

File tree

3 files changed

+135
-4
lines changed

3 files changed

+135
-4
lines changed

pkg/controller/v1alpha1/inferencegraph/controller.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ func (r *InferenceGraphReconciler) Reconcile(ctx context.Context, req ctrl.Reque
160160
err := r.Client.Get(ctx, types.NamespacedName{Namespace: graph.Namespace, Name: route.ServiceName}, &isvc)
161161
if err == nil {
162162
if graph.Spec.Nodes[node].Steps[i].ServiceURL == "" {
163-
serviceUrl, err := isvcutils.GetPredictorEndpoint(&isvc)
163+
serviceUrl, err := isvcutils.GetPredictorEndpoint(ctx, r.Client, &isvc)
164164
if err == nil {
165165
graph.Spec.Nodes[node].Steps[i].ServiceURL = serviceUrl
166166
} else {

pkg/controller/v1beta1/inferenceservice/utils/utils.go

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ func GetModelName(isvc *v1beta1.InferenceService) string {
136136
}
137137

138138
// GetPredictorEndpoint returns the predictor endpoint if status.address.url is not nil; otherwise, it returns an empty string with an error.
139-
func GetPredictorEndpoint(isvc *v1beta1.InferenceService) (string, error) {
139+
func GetPredictorEndpoint(ctx context.Context, client client.Client, isvc *v1beta1.InferenceService) (string, error) {
140140
if isvc.Status.Address != nil && isvc.Status.Address.URL != nil {
141141
hostName := isvc.Status.Address.URL.String()
142142
path := ""
@@ -149,7 +149,47 @@ func GetPredictorEndpoint(isvc *v1beta1.InferenceService) (string, error) {
149149
path = constants.PredictPath(modelName, constants.ProtocolV2)
150150
}
151151
} else if !IsMMSPredictor(&isvc.Spec.Predictor) {
152-
protocol := isvc.Spec.Predictor.GetImplementation().GetProtocol()
152+
predictorImplementation := isvc.Spec.Predictor.GetImplementation()
153+
protocol := predictorImplementation.GetProtocol()
154+
155+
if modelSpec, ok := predictorImplementation.(*v1beta1.ModelSpec); ok {
156+
if modelSpec.Runtime != nil {
157+
// When a Runtime is specified, and there is no protocol specified
158+
// in the ISVC, the protocol cannot be assumed to be V1. The protocol
159+
// needs to be extracted from the Runtime.
160+
161+
runtime, err := GetServingRuntime(ctx, client, *modelSpec.Runtime, isvc.Namespace)
162+
if err != nil {
163+
return "", err
164+
}
165+
166+
// If the runtime has protocol versions, use the first one supported by IG.
167+
// Otherwise, assume Protocol V1.
168+
if len(runtime.ProtocolVersions) != 0 {
169+
found := false
170+
for _, pversion := range runtime.ProtocolVersions {
171+
if pversion == constants.ProtocolV1 || pversion == constants.ProtocolV2 {
172+
protocol = pversion
173+
found = true
174+
break
175+
}
176+
}
177+
178+
if !found {
179+
return "", errors.New("the runtime does not support a protocol compatible with Inference Graphs")
180+
}
181+
}
182+
}
183+
184+
// else {
185+
// Notice that when using auto-selection (i.e. Runtime is nil), the
186+
// ISVC is assumed to be protocol v1. Thus, for auto-select, a runtime
187+
// will only match if it lists protocol v1 as supported. In this case,
188+
// the code above (protocol := predictorImplementation.GetProtocol()) would
189+
// already get the right protocol to configure in the InferenceGraph.
190+
// }
191+
}
192+
153193
if protocol == constants.ProtocolV1 {
154194
path = constants.PredictPath(modelName, constants.ProtocolV1)
155195
} else if protocol == constants.ProtocolV2 {

pkg/controller/v1beta1/inferenceservice/utils/utils_test.go

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"testing"
2323

2424
"github.com/onsi/gomega/types"
25+
"k8s.io/utils/ptr"
2526
"knative.dev/pkg/apis"
2627
knativeV1 "knative.dev/pkg/apis/duck/v1"
2728

@@ -1561,6 +1562,34 @@ func TestGetPredictorEndpoint(t *testing.T) {
15611562
"cpu": resource.MustParse("90m"),
15621563
},
15631564
}
1565+
namespace := "default"
1566+
1567+
s := runtime.NewScheme()
1568+
err := v1alpha1.AddToScheme(s)
1569+
if err != nil {
1570+
t.Errorf("Failed to add v1alpha1 to scheme %s", err)
1571+
}
1572+
protocolV1Runtime := &v1alpha1.ServingRuntime{
1573+
TypeMeta: metav1.TypeMeta{},
1574+
ObjectMeta: metav1.ObjectMeta{
1575+
Name: "mocked-v1-runtime",
1576+
Namespace: namespace,
1577+
},
1578+
Spec: v1alpha1.ServingRuntimeSpec{
1579+
ProtocolVersions: []constants.InferenceServiceProtocol{"v1"},
1580+
},
1581+
}
1582+
protocolV2Runtime := &v1alpha1.ServingRuntime{
1583+
TypeMeta: metav1.TypeMeta{},
1584+
ObjectMeta: metav1.ObjectMeta{
1585+
Name: "mocked-v2-runtime",
1586+
Namespace: namespace,
1587+
},
1588+
Spec: v1alpha1.ServingRuntimeSpec{
1589+
ProtocolVersions: []constants.InferenceServiceProtocol{"v2"},
1590+
},
1591+
}
1592+
mockClient := fake.NewClientBuilder().WithScheme(s).WithObjects(protocolV1Runtime, protocolV2Runtime).Build()
15641593

15651594
scenarios := map[string]struct {
15661595
isvc InferenceService
@@ -1829,11 +1858,73 @@ func TestGetPredictorEndpoint(t *testing.T) {
18291858
expectedUrl: "",
18301859
expectedErr: gomega.MatchError("service sklearn is not ready"),
18311860
},
1861+
"NoProtocolWithRuntimeProtocolV1": {
1862+
isvc: InferenceService{
1863+
ObjectMeta: metav1.ObjectMeta{
1864+
Name: "sklearn",
1865+
Namespace: namespace,
1866+
},
1867+
Spec: InferenceServiceSpec{
1868+
Predictor: PredictorSpec{
1869+
Model: &ModelSpec{
1870+
Runtime: ptr.To("mocked-v1-runtime"),
1871+
ModelFormat: ModelFormat{
1872+
Name: "sklearn",
1873+
},
1874+
PredictorExtensionSpec: PredictorExtensionSpec{
1875+
StorageURI: proto.String("s3://test"),
1876+
},
1877+
},
1878+
},
1879+
},
1880+
Status: InferenceServiceStatus{
1881+
Address: &knativeV1.Addressable{
1882+
URL: &apis.URL{
1883+
Scheme: "http",
1884+
Host: "sklearn-predictor.default.svc.cluster.local",
1885+
},
1886+
},
1887+
},
1888+
},
1889+
expectedUrl: "http://sklearn-predictor.default.svc.cluster.local/v1/models/sklearn:predict",
1890+
expectedErr: gomega.BeNil(),
1891+
},
1892+
"NoProtocolWithRuntimeProtocolV2": {
1893+
isvc: InferenceService{
1894+
ObjectMeta: metav1.ObjectMeta{
1895+
Name: "sklearn",
1896+
Namespace: namespace,
1897+
},
1898+
Spec: InferenceServiceSpec{
1899+
Predictor: PredictorSpec{
1900+
Model: &ModelSpec{
1901+
Runtime: ptr.To("mocked-v2-runtime"),
1902+
ModelFormat: ModelFormat{
1903+
Name: "sklearn",
1904+
},
1905+
PredictorExtensionSpec: PredictorExtensionSpec{
1906+
StorageURI: proto.String("s3://test"),
1907+
},
1908+
},
1909+
},
1910+
},
1911+
Status: InferenceServiceStatus{
1912+
Address: &knativeV1.Addressable{
1913+
URL: &apis.URL{
1914+
Scheme: "http",
1915+
Host: "sklearn-predictor.default.svc.cluster.local",
1916+
},
1917+
},
1918+
},
1919+
},
1920+
expectedUrl: "http://sklearn-predictor.default.svc.cluster.local/v2/models/sklearn/infer",
1921+
expectedErr: gomega.BeNil(),
1922+
},
18321923
}
18331924

18341925
for name, scenario := range scenarios {
18351926
t.Run(name, func(t *testing.T) {
1836-
res, err := GetPredictorEndpoint(&scenario.isvc)
1927+
res, err := GetPredictorEndpoint(t.Context(), mockClient, &scenario.isvc)
18371928
g.Expect(err).To(scenario.expectedErr)
18381929
if !g.Expect(res).To(gomega.Equal(scenario.expectedUrl)) {
18391930
t.Errorf("got %s, want %s", res, scenario.expectedUrl)

0 commit comments

Comments
 (0)