Skip to content

Commit 6df887c

Browse files
committed
Inf2 hf endpoints docker image
Signed-off-by: Raphael Glon <[email protected]>
1 parent c4ca4f5 commit 6df887c

File tree

3 files changed

+256
-1
lines changed

3 files changed

+256
-1
lines changed

dockerfiles/pytorch/Dockerfile.inf2

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,7 @@ COPY src/huggingface_inference_toolkit/webservice_starlette.py webservice_starle
9898

9999
# copy entrypoint and change permissions
100100
COPY --chmod=0755 scripts/entrypoint.sh entrypoint.sh
101+
COPY --chmod=0755 scripts/inf2_env.py inf2.env.py
102+
COPY --chmod=0755 scripts/inf2_entrypoint.sh inf2_entrypoint.sh
101103

102-
ENTRYPOINT ["bash", "-c", "./entrypoint.sh"]
104+
ENTRYPOINT ["bash", "-c", "./inf2_entrypoint.sh"]

scripts/inf2_entrypoint.sh

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/bin/bash
# Inferentia2 entrypoint wrapper: runs inf2_env.py to compute any missing
# HF_* env vars, sources them, then hands off to the regular entrypoint.
set -e -o pipefail -u

# Temp file where inf2_env.py writes its "export VAR=value" lines.
# mktemp already creates the file, so no extra touch is needed.
export ENV_FILEPATH=$(mktemp)

# Clean up the temp file on early exit. NOTE: exec below replaces the shell,
# so this trap does not fire on the success path — hence the explicit rm
# right before exec.
trap 'rm -f "${ENV_FILEPATH}"' EXIT

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

# Quote "$@" (the original unquoted $@ would word-split arguments that
# contain spaces); same for all path expansions.
"${SCRIPT_DIR}/inf2_env.py" "$@"

source "${ENV_FILEPATH}"

rm -f "${ENV_FILEPATH}"

exec "${SCRIPT_DIR}/entrypoint.sh" "$@"

scripts/inf2_env.py

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
#!/usr/bin/env python
2+
3+
"""
4+
This script is here to specify all missing environment variables that would be required to run some encoder models on
5+
inferentia2.
6+
"""
7+
8+
import argparse
9+
import logging
10+
import os
11+
import sys
12+
from typing import Any, Dict, List, Optional
13+
14+
from huggingface_hub import constants
15+
from transformers import AutoConfig
16+
17+
from optimum.neuron.utils import get_hub_cached_entries
18+
from optimum.neuron.utils.version_utils import get_neuronxcc_version
19+
20+
# Verbose logging is cheap here: this script runs once at container startup.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', force=True)
logger = logging.getLogger(__name__)

# Pairs of (environment variable, neuron config key) that must agree for a
# compiled/cached model to be usable.
env_config_peering = [
    ("HF_BATCH_SIZE", "static_batch_size"),
    ("HF_OPTIMUM_SEQUENCE_LENGTH", "static_sequence_length"),
]

# By the end of this script all env vars should be specified properly.
# (Comprehension instead of list(map(lambda ...)) — same result, idiomatic.)
env_vars = [env_var for env_var, _ in env_config_peering]

# Currently not used for encoder models
# available_cores = get_available_cores()

# Local neuronx compiler version, compared against the version recorded in
# each cached model's neuron config.
neuronxcc_version = get_neuronxcc_version()
35+
36+
37+
def parse_cmdline_and_set_env(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse TGI-style command line params and mirror them into env vars.

    All these flags are params normally passed to tgi and intercepted here.
    Unknown arguments are left untouched so the full TGI command line can be
    forwarded unchanged.

    :param argv: argument list to parse; falls back to ``sys.argv`` when
        ``None`` or empty.
    :return: the parsed (known) arguments.
    :raises Exception: when no model id is provided via ``--model-id`` or the
        ``MODEL_ID`` env var.
    """
    parser = argparse.ArgumentParser()
    if not argv:
        argv = sys.argv
    # Defaults come from the env; argparse applies type=int to string defaults,
    # so env-provided values are converted just like cmdline ones.
    parser.add_argument(
        "--batch-size",
        type=int,
        default=os.getenv("HF_BATCH_SIZE", os.getenv("BATCH_SIZE", 0)),
    )
    parser.add_argument(
        "--sequence-length",
        type=int,
        default=os.getenv("HF_OPTIMUM_SEQUENCE_LENGTH", os.getenv("SEQUENCE_LENGTH", 0)),
    )

    parser.add_argument("--model-id", type=str, default=os.getenv("MODEL_ID", os.getenv("HF_MODEL_DIR")))
    parser.add_argument("--revision", type=str, default=os.getenv("REVISION"))

    args = parser.parse_known_args(argv)[0]

    if not args.model_id:
        raise Exception(
            "No model id provided ! Either specify it using --model-id cmdline or MODEL_ID env var"
        )

    # Override env with cmdline params
    os.environ["MODEL_ID"] = args.model_id

    # Set all tgi router and tgi server values to consistent values as early as possible
    # from the order of the parser defaults, the tgi router value can override the tgi server ones
    if args.batch_size > 0:
        os.environ["HF_BATCH_SIZE"] = str(args.batch_size)

    if args.sequence_length > 0:
        os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = str(args.sequence_length)

    if args.revision:
        os.environ["REVISION"] = str(args.revision)

    return args
78+
79+
80+
def neuron_config_to_env(neuron_config):
    """Write the peered neuron config values as shell `export` lines.

    Emits one `export ENV_VAR=value` line per (env var, config key) pair to
    the file named by the ENV_FILEPATH env var, so the wrapping shell script
    can `source` it.
    """
    export_lines = [
        "export {}={}\n".format(env_var, neuron_config[config_key])
        for env_var, config_key in env_config_peering
    ]
    with open(os.environ["ENV_FILEPATH"], "w") as env_file:
        env_file.writelines(export_lines)
84+
85+
86+
def sort_neuron_configs(dictionary):
    """Sort key for neuron configs: larger static batch sizes sort first."""
    batch_size = dictionary["static_batch_size"]
    return -batch_size
88+
89+
90+
def lookup_compatible_cached_model(
    model_id: str, revision: Optional[str]
) -> Optional[Dict[str, Any]]:
    """Return the best hub-cached neuron config compatible with this setup.

    Reuses the same mechanic as the one in use to configure the tgi server
    part; the only difference here is that we stay as flexible as possible on
    the compatibility part. Returns None when no cached entry matches.
    """
    entries = get_hub_cached_entries(model_id, "inference")

    logger.debug(
        "Found %d cached entries for model %s, revision %s",
        len(entries),
        model_id,
        revision,
    )

    compatible_entries = [
        entry
        for entry in entries
        if check_env_and_neuron_config_compatibility(entry, check_compiler_version=True)
    ]

    if not compatible_entries:
        logger.debug(
            "No compatible cached entry found for model %s, env %s, neuronxcc version %s",
            model_id,
            get_env_dict(),
            neuronxcc_version,
        )
        return None

    logger.info("%d compatible neuron cached models found", len(compatible_entries))

    # min() under the same key as sorted(...)[0]: picks the entry with the
    # largest static batch size, first-wins on ties (both are stable).
    selected = min(compatible_entries, key=sort_neuron_configs)

    logger.info("Selected entry %s", selected)

    return selected
129+
130+
131+
def check_env_and_neuron_config_compatibility(
    neuron_config: Dict[str, Any], check_compiler_version: bool
) -> bool:
    """Tell whether a neuron config is usable with the local setup and env.

    Optionally checks the compiler version, then verifies every env var in
    `env_config_peering` is either unset or equal to the matching neuron
    config value.
    """
    logger.debug(
        "Checking the provided neuron config %s is compatible with the local setup and provided environment",
        neuron_config,
    )

    # Local setup compat checks
    # if neuron_config["num_cores"] > available_cores:
    #     logger.debug(
    #         "Not enough neuron cores available to run the provided neuron config"
    #     )
    #     return False

    version_mismatch = (
        check_compiler_version
        and neuron_config["compiler_version"] != neuronxcc_version
    )
    if version_mismatch:
        logger.debug(
            "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
            neuronxcc_version,
            neuron_config["compiler_version"],
        )
        return False

    for env_var, config_key in env_config_peering:
        if config_key not in neuron_config:
            logger.debug("No key %s found in neuron config %s", config_key, neuron_config)
            return False
        expected_value = str(neuron_config[config_key])
        # An unset env var counts as matching (it will be filled in later).
        env_value = os.getenv(env_var, expected_value)
        if env_value != expected_value:
            logger.debug(
                "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)",
                env_var,
                config_key,
                env_value,
                expected_value,
            )
            return False

    return True
175+
176+
177+
def get_env_dict() -> Dict[str, Optional[str]]:
    """Snapshot the managed env vars for logging.

    :return: mapping of env var name -> current value (None when unset; the
        original annotation Dict[str, str] was wrong since os.getenv returns
        None for missing vars).
    """
    return {env_var: os.getenv(env_var) for env_var in env_vars}
182+
183+
184+
def main():
    """
    This script determines proper default TGI env variables for the neuron precompiled models to
    work properly
    :return:
    """
    args = parse_cmdline_and_set_env()

    # for/else: the else branch runs only when no env var is missing,
    # in which case there is nothing left to compute.
    for env_var in env_vars:
        if not os.getenv(env_var):
            break
    else:
        logger.info(
            "All env vars %s already set, skipping, users know what they are doing",
            env_vars,
        )
        sys.exit(0)

    cache_dir = constants.HF_HUB_CACHE

    logger.info("Cache dir %s, model %s", cache_dir, args.model_id)

    config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)
    neuron_config = getattr(config, "neuron", None)
    if neuron_config is not None:
        # The model ships its own neuron config: it must match the env,
        # otherwise we cannot run it at all.
        compatible = check_env_and_neuron_config_compatibility(
            neuron_config, check_compiler_version=False
        )
        if not compatible:
            env_dict = get_env_dict()
            msg = (
                "Invalid neuron config and env. Config {}, env {}, neuronxcc version {}"
            ).format(neuron_config, env_dict, neuronxcc_version)
            logger.error(msg)
            raise Exception(msg)
    else:
        neuron_config = lookup_compatible_cached_model(args.model_id, args.revision)

    if not neuron_config:
        neuron_config = {'static_batch_size': 1, 'static_sequence_length': 128}
        # Bug fix: the format string had two placeholders for three args, so
        # the fallback config was silently dropped from the message.
        msg = (
            "No compatible neuron config found. Provided env {}, neuronxcc version {}. Falling back to default {}"
        ).format(get_env_dict(), neuronxcc_version, neuron_config)
        logger.info(msg)

    logger.info("Final neuron config %s", neuron_config)

    neuron_config_to_env(neuron_config)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)