Skip to content

Commit 28cf01e

Browse files
committed
Inf2 hf endpoints docker image
Signed-off-by: Raphael Glon <[email protected]>
1 parent c4ca4f5 commit 28cf01e

File tree

3 files changed

+257
-1
lines changed

3 files changed

+257
-1
lines changed

dockerfiles/pytorch/Dockerfile.inf2

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,7 @@ COPY src/huggingface_inference_toolkit/webservice_starlette.py webservice_starle
9898

9999
# copy entrypoint and change permissions
100100
COPY --chmod=0755 scripts/entrypoint.sh entrypoint.sh
101+
COPY --chmod=0755 scripts/inf2_env.py inf2.env.py
102+
COPY --chmod=0755 scripts/inf2_entrypoint.sh inf2_entrypoint.sh
101103

102-
ENTRYPOINT ["bash", "-c", "./entrypoint.sh"]
104+
ENTRYPOINT ["bash", "-c", "./inf2_entrypoint.sh"]

scripts/inf2_entrypoint.sh

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/bin/bash
# Wrapper entrypoint for inferentia2 images: let inf2_env.py compute the
# missing HF_* env vars, source them, then delegate to the regular entrypoint.
set -e -o pipefail -u

# Separate assignment from export so a mktemp failure is not masked (SC2155).
ENV_FILEPATH="$(mktemp)"
export ENV_FILEPATH

# Single quotes: expand at trap time, not at set time (SC2064).
# NOTE: the EXIT trap does not fire across `exec` below, hence the explicit rm.
trap 'rm -f "${ENV_FILEPATH}"' EXIT

SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"

# The python helper writes `export VAR=...` lines into ENV_FILEPATH.
# "$@" (quoted) preserves arguments containing whitespace.
"${SCRIPT_DIR}/inf2_env.py" "$@"

# shellcheck disable=SC1090
source "${ENV_FILEPATH}"

rm -f "${ENV_FILEPATH}"

exec "${SCRIPT_DIR}/entrypoint.sh" "$@"

scripts/inf2_env.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
#!/usr/bin/env python

"""
This script is here to specify all missing environment variables that would be required to run some encoder models on
inferentia2.
"""

import argparse
import logging
import os
import sys
from typing import Any, Dict, List, Optional

from huggingface_hub import constants
from transformers import AutoConfig

from optimum.neuron.utils import get_hub_cached_entries
from optimum.neuron.utils.version_utils import get_neuronxcc_version

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', force=True)
logger = logging.getLogger(__name__)

# Each pair maps an env var consumed downstream to its key in a neuron config.
env_config_peering = [
    ("HF_BATCH_SIZE", "static_batch_size"),
    ("HF_OPTIMUM_SEQUENCE_LENGTH", "static_sequence_length"),
    ("HF_AUTO_CAST_TYPE", "auto_cast_type"),
]

# By the end of this script all env vars should be specified properly.
# Comprehension instead of list(map(lambda ...)) — same result, idiomatic.
env_vars = [env_var for env_var, _ in env_config_peering]

# Currently not used for encoder models
# available_cores = get_available_cores()

neuronxcc_version = get_neuronxcc_version()
36+
37+
38+
def parse_cmdline_and_set_env(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse TGI-style command line arguments and mirror them into env vars.

    These flags are normally consumed by the TGI launcher; they are
    intercepted here so the same values can be re-exported as environment
    variables for the rest of the stack.

    :param argv: argument list to parse; defaults to ``sys.argv[1:]``.
    :return: the parsed (known) arguments; unknown args are ignored.
    :raises ValueError: when no model id is provided via --model-id or MODEL_ID.
    """
    parser = argparse.ArgumentParser()
    if argv is None:
        # Only fall back when the caller did not supply a list: an
        # explicitly-passed empty list must stay empty ("is None", not "not argv").
        argv = sys.argv[1:]
    # All these are params passed to tgi and intercepted here.
    # NOTE: argparse applies `type=` only to command-line strings, never to
    # defaults, so env-provided defaults must be wrapped in int() explicitly —
    # otherwise `args.batch_size > 0` below raises TypeError (str vs int)
    # whenever HF_BATCH_SIZE / BATCH_SIZE is set in the environment.
    parser.add_argument(
        "--batch-size",
        type=int,
        default=int(os.getenv("HF_BATCH_SIZE", os.getenv("BATCH_SIZE", 0))),
    )
    parser.add_argument(
        "--sequence-length",
        type=int,
        default=int(os.getenv("HF_OPTIMUM_SEQUENCE_LENGTH", os.getenv("SEQUENCE_LENGTH", 0))),
    )

    parser.add_argument("--model-id", type=str, default=os.getenv("MODEL_ID", os.getenv("HF_MODEL_DIR")))
    parser.add_argument("--revision", type=str, default=os.getenv("REVISION"))

    args = parser.parse_known_args(argv)[0]

    if not args.model_id:
        # ValueError is a subclass of Exception, so existing callers that
        # catch Exception keep working.
        raise ValueError(
            "No model id provided ! Either specify it using --model-id cmdline or MODEL_ID env var"
        )

    # Override env with cmdline params
    os.environ["MODEL_ID"] = args.model_id

    # Set all tgi router and tgi server values to consistent values as early as possible
    # from the order of the parser defaults, the tgi router value can override the tgi server ones
    if args.batch_size > 0:
        os.environ["HF_BATCH_SIZE"] = str(args.batch_size)

    if args.sequence_length > 0:
        os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = str(args.sequence_length)

    if args.revision:
        os.environ["REVISION"] = str(args.revision)

    return args
79+
80+
81+
def neuron_config_to_env(neuron_config):
    """Dump the env var / neuron config pairs as shell ``export`` lines.

    Lines are written to the file pointed to by the ENV_FILEPATH env var so
    they can be sourced by the wrapping shell entrypoint.
    """
    exports = [
        f"export {env_var}={neuron_config[config_key]}\n"
        for env_var, config_key in env_config_peering
    ]
    with open(os.environ["ENV_FILEPATH"], "w") as env_file:
        env_file.writelines(exports)
85+
86+
87+
def sort_neuron_configs(dictionary):
    """Sort key for neuron configs: larger static batch sizes sort first."""
    return -dictionary["static_batch_size"]
89+
90+
91+
def lookup_compatible_cached_model(
    model_id: str, revision: Optional[str]
) -> Optional[Dict[str, Any]]:
    """Find a cached neuron model entry compatible with the current env.

    Reuses the same mechanic as the one in use to configure the tgi server
    part; the only difference here is that we stay as flexible as possible
    on the compatibility part.

    :return: the best matching cache entry, or None when nothing fits.
    """
    entries = get_hub_cached_entries(model_id, "inference")

    logger.debug(
        "Found %d cached entries for model %s, revision %s",
        len(entries),
        model_id,
        revision,
    )

    all_compatible = [
        candidate
        for candidate in entries
        if check_env_and_neuron_config_compatibility(candidate, check_compiler_version=True)
    ]

    if not all_compatible:
        logger.debug(
            "No compatible cached entry found for model %s, env %s, neuronxcc version %s",
            model_id,
            get_env_dict(),
            neuronxcc_version,
        )
        return None

    logger.info("%d compatible neuron cached models found", len(all_compatible))

    # Equivalent to sorting by the key and taking the head: the entry with
    # the largest static batch size wins; ties keep the original order.
    entry = min(all_compatible, key=sort_neuron_configs)

    logger.info("Selected entry %s", entry)

    return entry
130+
131+
132+
def check_env_and_neuron_config_compatibility(
    neuron_config: Dict[str, Any], check_compiler_version: bool
) -> bool:
    """Return True when the given neuron config matches the local setup and env.

    :param neuron_config: neuron config dict (from a cache entry or a model).
    :param check_compiler_version: also require an exact neuronxcc version match.
    """
    logger.debug(
        "Checking the provided neuron config %s is compatible with the local setup and provided environment",
        neuron_config,
    )

    # Local setup compat checks
    # if neuron_config["num_cores"] > available_cores:
    #     logger.debug(
    #         "Not enough neuron cores available to run the provided neuron config"
    #     )
    #     return False

    compiler_mismatch = (
        check_compiler_version
        and neuron_config["compiler_version"] != neuronxcc_version
    )
    if compiler_mismatch:
        logger.debug(
            "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
            neuronxcc_version,
            neuron_config["compiler_version"],
        )
        return False

    for env_var, config_key in env_config_peering:
        if config_key not in neuron_config:
            logger.debug("No key %s found in neuron config %s", config_key, neuron_config)
            return False
        expected = str(neuron_config[config_key])
        # Unset env vars default to the config value, i.e. they always match.
        actual = os.getenv(env_var, expected)
        if actual != expected:
            logger.debug(
                "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)",
                env_var,
                config_key,
                actual,
                expected,
            )
            return False

    return True
176+
177+
178+
def get_env_dict() -> Dict[str, Optional[str]]:
    """Snapshot of the relevant env vars; values are None when unset.

    Annotation fixed from Dict[str, str]: os.getenv returns None for
    variables that are not set, which happens routinely here.
    """
    return {name: os.getenv(name) for name in env_vars}
183+
184+
185+
def main():
    """
    This script determines proper default TGI env variables for the neuron precompiled models to
    work properly
    :return:
    """
    args = parse_cmdline_and_set_env()

    # If every env var is already set, trust the user and do nothing.
    for env_var in env_vars:
        if not os.getenv(env_var):
            break
    else:
        logger.info(
            "All env vars %s already set, skipping, user know what they are doing",
            env_vars,
        )
        sys.exit(0)

    cache_dir = constants.HF_HUB_CACHE

    logger.info("Cache dir %s, model %s", cache_dir, args.model_id)

    config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)
    neuron_config = getattr(config, "neuron", None)
    if neuron_config is not None:
        # Model is already neuron-compiled: the env only has to agree with its
        # embedded config (compiler version mismatches tolerated at this point).
        compatible = check_env_and_neuron_config_compatibility(
            neuron_config, check_compiler_version=False
        )
        if not compatible:
            env_dict = get_env_dict()
            msg = (
                "Invalid neuron config and env. Config {}, env {}, neuronxcc version {}"
            ).format(neuron_config, env_dict, neuronxcc_version)
            logger.error(msg)
            raise Exception(msg)
    else:
        neuron_config = lookup_compatible_cached_model(args.model_id, args.revision)

    if not neuron_config:
        # NOTE(review): this fallback lacks the 'auto_cast_type' key that
        # neuron_config_to_env reads via env_config_peering — confirm whether
        # a default should be provided to avoid a KeyError downstream.
        neuron_config = {'static_batch_size': 1, 'static_sequence_length': 128}
        # Fix: the original format string had only two placeholders for three
        # arguments, silently dropping the fallback config from the message.
        msg = (
            "No compatible neuron config found. Provided env {}, neuronxcc version {}. Falling back to default {}"
        ).format(get_env_dict(), neuronxcc_version, neuron_config)
        logger.info(msg)

    logger.info("Final neuron config %s", neuron_config)

    neuron_config_to_env(neuron_config)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)