Skip to content

Commit 4413753

Browse files
authored
Allow to set custom timeouts for InferenceGraph router (kserve#4218)
Signed-off-by: Jakub Filo <jakub.filo@customink.com>
1 parent 46434e1 commit 4413753

15 files changed

+686
-18
lines changed

cmd/router/main.go

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
66
You may obtain a copy of the License at
77
8-
http://www.apache.org/licenses/LICENSE-2.0
8+
http://www.apache.org/licenses/LICENSE-2.0
99
1010
Unless required by applicable law or agreed to in writing, software
1111
distributed under the License is distributed on an "AS IS" BASIS,
@@ -45,8 +45,6 @@ import (
4545
"github.com/kserve/kserve/pkg/constants"
4646
)
4747

48-
var log = logf.Log.WithName("InferenceGraphRouter")
49-
5048
// _isInMesh is an auxiliary global variable for isInIstioMesh function.
5149
var _isInMesh *bool
5250

@@ -147,7 +145,16 @@ func callService(serviceUrl string, input []byte, headers http.Header) ([]byte,
147145
if val := req.Header.Get("Content-Type"); val == "" {
148146
req.Header.Add("Content-Type", "application/json")
149147
}
150-
resp, err := http.DefaultClient.Do(req)
148+
149+
var client *http.Client
150+
if routerTimeouts == nil || routerTimeouts.ServiceClient == nil {
151+
client = http.DefaultClient
152+
} else {
153+
client = &http.Client{
154+
Timeout: time.Duration(*routerTimeouts.ServiceClient) * time.Second,
155+
}
156+
}
157+
resp, err := client.Do(req)
151158
if err != nil {
152159
log.Error(err, "An error has occurred while calling service", "service", serviceUrl)
153160
return nil, 500, err
@@ -373,8 +380,6 @@ func prepareErrorResponse(err error, errorMessage string) []byte {
373380
return errorResponseBytes
374381
}
375382

376-
var inferenceGraph *v1alpha1.InferenceGraphSpec
377-
378383
func graphHandler(w http.ResponseWriter, req *http.Request) {
379384
inputBytes, _ := io.ReadAll(req.Body)
380385
if response, statusCode, err := routeStep(v1alpha1.GraphRootNodeName, *inferenceGraph, inputBytes, req.Header); err != nil {
@@ -409,6 +414,35 @@ func compilePatterns(patterns []string) ([]*regexp.Regexp, error) {
409414
return compiled, goerrors.Join(allErrors...)
410415
}
411416

417+
func getTimeout(value, defaultValue *int64) *int64 {
418+
if value != nil {
419+
return value
420+
}
421+
return defaultValue
422+
}
423+
424+
func initTimeouts(graph v1alpha1.InferenceGraphSpec) {
425+
defaultServerRead := int64(constants.RouterTimeoutsServerRead)
426+
defaultServerWrite := int64(constants.RouterTimeoutServerWrite)
427+
defaultServerIdle := int64(constants.RouterTimeoutServerIdle)
428+
429+
timeouts := &v1alpha1.InfereceGraphRouterTimeouts{
430+
ServerRead: &defaultServerRead,
431+
ServerWrite: &defaultServerWrite,
432+
ServerIdle: &defaultServerIdle,
433+
ServiceClient: nil,
434+
}
435+
436+
if graph.RouterTimeouts != nil {
437+
timeouts.ServerRead = getTimeout(graph.RouterTimeouts.ServerRead, &defaultServerRead)
438+
timeouts.ServerWrite = getTimeout(graph.RouterTimeouts.ServerWrite, &defaultServerWrite)
439+
timeouts.ServerIdle = getTimeout(graph.RouterTimeouts.ServerIdle, &defaultServerIdle)
440+
timeouts.ServiceClient = getTimeout(graph.RouterTimeouts.ServiceClient, nil)
441+
}
442+
443+
routerTimeouts = timeouts
444+
}
445+
412446
// Mainly used for kubernetes readiness probe. It responds with "503 shutting down" if server is shutting down,
413447
// otherwise returns "200 OK".
414448
func readyHandler(w http.ResponseWriter, req *http.Request) {
@@ -420,10 +454,14 @@ func readyHandler(w http.ResponseWriter, req *http.Request) {
420454
}
421455

422456
var (
423-
jsonGraph = flag.String("graph-json", "", "serialized json graph def")
457+
jsonGraph = flag.String("graph-json", "", "serialized json graph def")
458+
inferenceGraph *v1alpha1.InferenceGraphSpec = nil
424459
compiledHeaderPatterns []*regexp.Regexp
425-
isShuttingDown = false
426-
drainSleepDuration = 30 * time.Second
460+
isShuttingDown = false
461+
drainSleepDuration = 30 * time.Second
462+
routerTimeouts *v1alpha1.InfereceGraphRouterTimeouts = nil
463+
log = logf.Log.WithName("InferenceGraphRouter")
464+
signalChan = make(chan os.Signal, 1)
427465
)
428466

429467
func main() {
@@ -438,22 +476,24 @@ func main() {
438476
log.Error(err, "Failed to compile some header patterns")
439477
}
440478
}
479+
441480
inferenceGraph = &v1alpha1.InferenceGraphSpec{}
442481
err := json.Unmarshal([]byte(*jsonGraph), inferenceGraph)
443482
if err != nil {
444483
log.Error(err, "failed to unmarshall inference graph json")
445484
os.Exit(1)
446485
}
486+
initTimeouts(*inferenceGraph)
447487

448488
http.HandleFunc("/", graphHandler)
449489
http.HandleFunc(constants.RouterReadinessEndpoint, readyHandler)
450490

451491
server := &http.Server{
452492
Addr: ":" + strconv.Itoa(constants.RouterPort),
453-
Handler: nil, // default server mux
454-
ReadTimeout: time.Minute, // https://medium.com/a-journey-with-go/go-understand-and-mitigate-slowloris-attack-711c1b1403f6
455-
WriteTimeout: time.Minute, // set the maximum duration before timing out writes of the response
456-
IdleTimeout: 3 * time.Minute, // set the maximum amount of time to wait for the next request when keep-alives are enabled
493+
Handler: nil, // default server mux
494+
ReadTimeout: time.Duration(*routerTimeouts.ServerRead) * time.Second, // set the maximum duration for reading the entire request, including the body
495+
WriteTimeout: time.Duration(*routerTimeouts.ServerWrite) * time.Second, // set the maximum duration before timing out writes of the response
496+
IdleTimeout: time.Duration(*routerTimeouts.ServerIdle) * time.Second, // set the maximum amount of time to wait for the next request when keep-alives are enabled
457497
}
458498

459499
go func() {
@@ -469,7 +509,6 @@ func main() {
469509
}
470510

471511
func handleSignals(server *http.Server) {
472-
signalChan := make(chan os.Signal, 1)
473512
signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)
474513

475514
sig := <-signalChan

cmd/router/main_test.go

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
66
You may obtain a copy of the License at
77
8-
http://www.apache.org/licenses/LICENSE-2.0
8+
http://www.apache.org/licenses/LICENSE-2.0
99
1010
Unless required by applicable law or agreed to in writing, software
1111
distributed under the License is distributed on an "AS IS" BASIS,
@@ -17,13 +17,17 @@ limitations under the License.
1717
package main
1818

1919
import (
20+
"bytes"
2021
"encoding/json"
2122
"fmt"
2223
"io"
2324
"net/http"
2425
"net/http/httptest"
2526
"regexp"
27+
"strconv"
28+
"syscall"
2629
"testing"
30+
"time"
2731

2832
"github.com/stretchr/testify/assert"
2933
"github.com/stretchr/testify/require"
@@ -32,12 +36,17 @@ import (
3236
"sigs.k8s.io/controller-runtime/pkg/log/zap"
3337

3438
"github.com/kserve/kserve/pkg/apis/serving/v1alpha1"
39+
"github.com/kserve/kserve/pkg/constants"
3540
)
3641

3742
func init() {
3843
logf.SetLogger(zap.New())
3944
}
4045

46+
func Int64Ptr(i int64) *int64 {
47+
return &i
48+
}
49+
4150
func TestSimpleModelChainer(t *testing.T) {
4251
// Start a local HTTP server
4352
model1 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@@ -883,3 +892,129 @@ func TestCallServiceWhenMultipleHeadersToPropagateUsingInvalidPattern(t *testing
883892
fmt.Printf("final response:%v\n", response)
884893
require.Equal(t, expectedResponse, response)
885894
}
895+
896+
func TestServerTimeout(t *testing.T) {
897+
testCases := []struct {
898+
name string
899+
serverTimeout *int64
900+
serviceStepDuration time.Duration
901+
expectError bool
902+
}{
903+
{
904+
name: "default",
905+
serverTimeout: nil,
906+
serviceStepDuration: 1 * time.Millisecond,
907+
expectError: false,
908+
},
909+
{
910+
name: "timeout",
911+
serverTimeout: Int64Ptr(1),
912+
serviceStepDuration: 500 * time.Millisecond,
913+
expectError: true,
914+
},
915+
{
916+
name: "success",
917+
serverTimeout: Int64Ptr(2),
918+
serviceStepDuration: 500 * time.Millisecond,
919+
expectError: false,
920+
},
921+
}
922+
923+
for _, testCase := range testCases {
924+
t.Run(testCase.name, func(t *testing.T) {
925+
drainSleepDuration = 0 * time.Millisecond // instant shutdown
926+
927+
// Setup and start dummy models
928+
model1 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
929+
_, err := io.ReadAll(req.Body)
930+
if err != nil {
931+
return
932+
}
933+
time.Sleep(testCase.serviceStepDuration)
934+
response := map[string]interface{}{"predictions": "1"}
935+
responseBytes, _ := json.Marshal(response)
936+
rw.Write(responseBytes)
937+
}))
938+
model1Url, err := apis.ParseURL(model1.URL)
939+
if err != nil {
940+
t.Fatalf("Failed to parse model url")
941+
}
942+
defer model1.Close()
943+
944+
model2 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
945+
_, err := io.ReadAll(req.Body)
946+
if err != nil {
947+
return
948+
}
949+
time.Sleep(testCase.serviceStepDuration)
950+
response := map[string]interface{}{"predictions": "2"}
951+
responseBytes, _ := json.Marshal(response)
952+
rw.Write(responseBytes)
953+
}))
954+
model2Url, err := apis.ParseURL(model2.URL)
955+
if err != nil {
956+
t.Fatalf("Failed to parse model url")
957+
}
958+
defer model2.Close()
959+
960+
// Create InferenceGraph
961+
graphSpec := v1alpha1.InferenceGraphSpec{
962+
Nodes: map[string]v1alpha1.InferenceRouter{
963+
"root": {
964+
RouterType: v1alpha1.Sequence,
965+
Steps: []v1alpha1.InferenceStep{
966+
{
967+
StepName: "model1",
968+
InferenceTarget: v1alpha1.InferenceTarget{
969+
ServiceURL: model1Url.String(),
970+
},
971+
},
972+
{
973+
StepName: "model2",
974+
InferenceTarget: v1alpha1.InferenceTarget{
975+
ServiceURL: model2Url.String(),
976+
},
977+
Data: "$response",
978+
},
979+
},
980+
},
981+
},
982+
}
983+
if testCase.serverTimeout != nil {
984+
timeout := *testCase.serverTimeout
985+
graphSpec.RouterTimeouts = &v1alpha1.InfereceGraphRouterTimeouts{
986+
ServerRead: &timeout,
987+
ServerWrite: &timeout,
988+
ServerIdle: &timeout,
989+
}
990+
}
991+
jsonBytes, _ := json.Marshal(graphSpec)
992+
*jsonGraph = string(jsonBytes)
993+
994+
// Start InferenceGraph router server in a separate goroutine
995+
go func() {
996+
main()
997+
}()
998+
t.Cleanup(func() {
999+
http.DefaultServeMux = http.NewServeMux() // reset http handlers
1000+
signalChan <- syscall.SIGTERM // shutdown the server
1001+
})
1002+
1003+
// Call the InferenceGraph
1004+
client := &http.Client{}
1005+
time.Sleep(1 * time.Second) // prevent race condition
1006+
req, _ := http.NewRequest(http.MethodPost, "http://localhost:"+strconv.Itoa(constants.RouterPort), bytes.NewBuffer(nil))
1007+
resp, err := client.Do(req)
1008+
if resp != nil {
1009+
defer resp.Body.Close()
1010+
}
1011+
1012+
if testCase.expectError {
1013+
assert.Contains(t, err.Error(), "EOF")
1014+
} else {
1015+
require.NoError(t, err)
1016+
assert.Equal(t, http.StatusOK, resp.StatusCode)
1017+
}
1018+
})
1019+
}
1020+
}

config/crd/full/serving.kserve.io_inferencegraphs.yaml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,21 @@ spec:
561561
x-kubernetes-int-or-string: true
562562
type: object
563563
type: object
564+
routerTimeouts:
565+
properties:
566+
serverIdle:
567+
format: int64
568+
type: integer
569+
serverRead:
570+
format: int64
571+
type: integer
572+
serverWrite:
573+
format: int64
574+
type: integer
575+
serviceClient:
576+
format: int64
577+
type: integer
578+
type: object
564579
scaleMetric:
565580
enum:
566581
- cpu

pkg/apis/serving/v1alpha1/inference_graph.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
66
You may obtain a copy of the License at
77
8-
http://www.apache.org/licenses/LICENSE-2.0
8+
http://www.apache.org/licenses/LICENSE-2.0
99
1010
Unless required by applicable law or agreed to in writing, software
1111
distributed under the License is distributed on an "AS IS" BASIS,
@@ -52,6 +52,8 @@ type InferenceGraphSpec struct {
5252
// TimeoutSeconds specifies the number of seconds to wait before timing out a request to the component.
5353
// +optional
5454
TimeoutSeconds *int64 `json:"timeout,omitempty"`
55+
// +optional
56+
RouterTimeouts *InfereceGraphRouterTimeouts `json:"routerTimeouts,omitempty"`
5557
// Minimum number of replicas, defaults to 1 but can be set to 0 to enable scale-to-zero.
5658
// +optional
5759
MinReplicas *int32 `json:"minReplicas,omitempty"`
@@ -115,6 +117,22 @@ const (
115117
GraphRootNodeName string = "root"
116118
)
117119

120+
// +k8s:openapi-gen=true
121+
type InfereceGraphRouterTimeouts struct {
122+
// ServerRead specifies the number of seconds to wait before timing out a request read by the server.
123+
// +optional
124+
ServerRead *int64 `json:"serverRead,omitempty"`
125+
// ServerWrite specifies the maximum duration in seconds before timing out writes of the response.
126+
// +optional
127+
ServerWrite *int64 `json:"serverWrite,omitempty"`
128+
// ServerIdle specifies the maximum amount of time in seconds to wait for the next request when keep-alives are enabled.
129+
// +optional
130+
ServerIdle *int64 `json:"serverIdle,omitempty"`
131+
// ServiceClient specifies a time limit in seconds for requests made to the graph components by HTTP client.
132+
// +optional
133+
ServiceClient *int64 `json:"serviceClient,omitempty"`
134+
}
135+
118136
// +k8s:openapi-gen=true
119137
// InferenceRouter defines the router for each InferenceGraph node with one or multiple steps
120138
//

0 commit comments

Comments
 (0)