""" KServe wrapper to handler inference in the kserve_predictor """
import json
import logging
import os
import kserve
from kserve.model_server import ModelServer
from TorchserveModel import TorchserveModel
from TSModelRepository import TSModelRepository
logging.basicConfig(level=kserve.constants.KSERVE_LOGLEVEL)
DEFAULT_MODEL_NAME = "model"
DEFAULT_INFERENCE_ADDRESS = DEFAULT_MANAGEMENT_ADDRESS = "http://127.0.0.1:8085"
DEFAULT_GRPC_INFERENCE_PORT = "7070"
DEFAULT_MODEL_STORE = "/mnt/models/model-store"
CONFIG_PATH = "/mnt/models/config/config.properties"
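
# Illustrative config.properties that the parser below expects (a sketch only;
# the model name "mnist" and the snapshot fields shown are assumptions, since
# in a real deployment this file is written by TorchServe's snapshot feature):
#
#   inference_address=http://127.0.0.1:8085
#   management_address=http://127.0.0.1:8085
#   grpc_inference_port=7070
#   model_store=/mnt/models/model-store
#   model_snapshot={"name":"startup.cfg","models":{"mnist":{"1.0":{}}}}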


def parse_config():
    """Parse the model snapshot from the config.properties file.

    Returns:
        model_names: The names of the models specified in config.properties
        inference_address: The address on which the inference endpoint is served
        management_address: The address on which models get registered
        grpc_inference_address: The host:port pair used for gRPC inference
        model_store: The path in which the .mar file resides
    """
    separator = "="
    keys = {}
    with open(CONFIG_PATH) as f:
        for line in f:
            if separator in line:
                # Split on the first separator only, since values
                # (e.g. URLs) may themselves contain "="
                name, value = line.split(separator, 1)
                # strip() removes whitespace from the ends of the strings
                keys[name.strip()] = value.strip()
keys["model_snapshot"] = json.loads(keys["model_snapshot"])
inference_address, management_address, grpc_inference_port, model_store = (
keys["inference_address"],
keys["management_address"],
keys["grpc_inference_port"],
keys["model_store"],
)
models = keys["model_snapshot"]["models"]
model_names = []
# Get all the model_names
for model, value in models.items():
model_names.append(model)
    if not model_names:
        model_names = [DEFAULT_MODEL_NAME]
    if not inference_address:
        inference_address = DEFAULT_INFERENCE_ADDRESS
    if not management_address:
        management_address = DEFAULT_MANAGEMENT_ADDRESS

    # Derive the gRPC address from the HTTP inference address. For example,
    # "http://127.0.0.1:8085".split(":") yields ["http", "//127.0.0.1", "8085"],
    # so inf_splits[1] plus the gRPC port gives "//127.0.0.1:7070", and
    # stripping the slashes leaves "127.0.0.1:7070".
    inf_splits = inference_address.split(":")
    if not grpc_inference_port:
        grpc_inference_address = inf_splits[1] + ":" + DEFAULT_GRPC_INFERENCE_PORT
    else:
        grpc_inference_address = inf_splits[1] + ":" + grpc_inference_port
    grpc_inference_address = grpc_inference_address.replace("/", "")

    if not model_store:
        model_store = DEFAULT_MODEL_STORE
    logging.info(
        "Wrapper: Model names %s, inference address %s, management address %s, "
        "grpc_inference_address %s, model store %s",
        model_names,
        inference_address,
        management_address,
        grpc_inference_address,
        model_store,
    )
    return (
        model_names,
        inference_address,
        management_address,
        grpc_inference_address,
        model_store,
    )
if __name__ == "__main__":
(
model_names,
inference_address,
management_address,
grpc_inference_address,
model_dir,
) = parse_config()
protocol = os.environ.get("PROTOCOL_VERSION")
models = []
for model_name in model_names:
model = TorchserveModel(
model_name,
inference_address,
management_address,
grpc_inference_address,
protocol,
model_dir,
)
        # By default, model.load() is only called on the first request. When
        # TorchServe is configured to load all models at startup (via
        # config.properties), the call below marks each model as loaded.
        # However, setting model.ready = True too early can cause problems if
        # the TorchServe handler has not yet finished its own loading steps
        # (e.g. downloading pretrained weights from remote storage), so the
        # ready status is ultimately determined through the API provided by
        # TorchServe.
        model.load()
        models.append(model)
    registered_models = TSModelRepository(
        inference_address,
        management_address,
        model_dir,
    )
    ModelServer(
        registered_models=registered_models,
        http_port=8080,
        grpc_port=8081,
    ).start(models)
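
# A minimal smoke test against the started server, assuming the KServe v1
# protocol and a model named "mnist" ("mnist" and input.json are examples,
# not part of this module):
#
#   curl -X POST http://localhost:8080/v1/models/mnist:predict \
#        -d @./input.json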