
Commit c7868c9

Author: maxtext authors (committed)

Merge pull request #1792 from AI-Hypercomputer:jacobplatin/fix-misc-inference-benchmark-issues

PiperOrigin-RevId: 765181154

2 parents (910888a + 2da3f75), commit c7868c9

File tree: 7 files changed, +45 -19 lines


MaxText/inference_mlperf/README.md

Lines changed: 15 additions & 5 deletions

@@ -13,14 +13,24 @@ source .env/bin/activate
 ```
 
 ### Install loadgen
+Note: this is taken from the MLCommons inference [README](https://github.com/mlcommons/inference/blob/master/loadgen/README_BUILD.md#quick-start) (as of May 2025).
 ```
-sudo apt-get install python3-dev
-sudo apt-get install build-essential -y
-git clone git@github.com:mlcommons/inference.git
-cd inference/
-cd loadgen/ && python3 -m pip install .
+pip install absl-py numpy
+git clone --recurse-submodules https://github.com/mlcommons/inference.git mlperf_inference
+cd mlperf_inference/loadgen
+CFLAGS="-std=c++14 -O3" python -m pip install .
 ```
 
+If you run into an issue like the following:
+
+```
+ImportError: venv/lib/libstdc++.so.6: version `GLIBCXX_3.4.30'
+not found (required by venv/lib/python3.10/site-packages/mlperf_loadgen.cpython-310-x86_64-linux-gnu.so)
+```
+
+Please try running `conda install -c conda-forge gcc_linux-64 gxx_linux-64 libstdcxx-ng` if you are using Conda, or `sudo apt install build-essential` if you are using venv, and then reinstalling `loadgen`.
+
+
 ### Download datasets
 
 ```
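
For convenience, the recovery steps described in the new README text can be run as one sequence; a minimal sketch assuming a Conda environment and the `mlperf_inference` checkout from the install step (`--force-reinstall` is an illustrative way to redo the `loadgen` install, not something the README prescribes):

```
# Workaround sketch for the GLIBCXX_3.4.30 import error (Conda case):
# pull a newer libstdc++ into the environment, then rebuild loadgen against it.
conda install -c conda-forge gcc_linux-64 gxx_linux-64 libstdcxx-ng
cd mlperf_inference/loadgen
CFLAGS="-std=c++14 -O3" python -m pip install --force-reinstall .
```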

MaxText/inference_mlperf/llama_offline_run.sh

Lines changed: 2 additions & 2 deletions

@@ -92,7 +92,7 @@ if [ -z "$MAXENGINE_ARGS" ];
 then
 CHECKPOINT="gs://msingh-bkt/checkpoints/quant_${MODEL_NAME}-chat/mlperf_070924/int8_"
 BASE_CFG="model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${CHECKPOINT}"
-QUANT_CFG="quantization=int8 quantize_kvcache=True checkpoint_is_quantized=True"
+QUANT_CFG="quantization=int8 quantize_kvcache=True checkpoint_is_quantized=True skip_jax_distributed_system=true"
 MAXENGINE_ARGS="${BASE_CFG} ${QUANT_CFG}"
 fi

@@ -117,7 +117,7 @@ else
 export DATASET_TYPE=full
 export DATASET_PATH=${DATA_DISK_DIR}/processed-data.pkl
 export TOTAL_SAMPLE_COUNT=24576
-export USER_CONFIG=user.conf
+export USER_CONFIG=user.conf # NOTE: you may need to change this path (e.g. `MaxText/inference_mlperf/user.conf`)
 fi

 # LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
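
Since the script only fills in `MAXENGINE_ARGS` when it is unset (the `if [ -z "$MAXENGINE_ARGS" ]` guard above), the new `skip_jax_distributed_system` flag can also be supplied from the caller's environment; a hedged sketch, with an illustrative run name and `TOKENIZER_PATH`/`CHECKPOINT` assumed to be set already:

```
# Sketch: provide engine args externally instead of relying on the script defaults.
export MAXENGINE_ARGS="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${CHECKPOINT} quantization=int8 quantize_kvcache=True \
  checkpoint_is_quantized=True skip_jax_distributed_system=true"
bash llama_offline_run.sh -r my_int8_offline_run   # -r run name, as used by the benchmark wrappers
```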

MaxText/inference_mlperf/offline_inference.py

Lines changed: 16 additions & 3 deletions

@@ -40,6 +40,7 @@
 
 DecodeState = Any
 Params = Any
+PRNGKeyType = Any
 
 log = logging.getLogger(__name__)
 

@@ -130,19 +131,20 @@ def process(
 input_true_length: int,
 max_length: int,
 prefill_done: Callable[[List[Tuple[engine_api.ResultTokens, int]], List[int], DecodeState], None],
+rng: PRNGKeyType,
 ) -> None:
 """Prefill helper process runner"""
 padded_length = len(input_tokens_padded)
 if self._type == "default":
 first_token, decode_state = self._processor.process(
-model_params, decode_state, decode_slot, input_tokens_padded, input_true_length
+model_params, decode_state, decode_slot, input_tokens_padded, input_true_length, rng
 )
 prefill_done([(first_token, decode_slot)], [input_id], decode_state)
 elif self._type == "batch":
 if padded_length == max_length:
 # fallback to default mode
 first_token, decode_state = self._processor.process(
-model_params, decode_state, decode_slot, input_tokens_padded, input_true_length
+model_params, decode_state, decode_slot, input_tokens_padded, input_true_length, rng
 )
 prefill_done([(first_token, decode_slot)], [input_id], decode_state)
 else:

@@ -249,6 +251,9 @@ def batch_inference_with_callback(
 counter = EventCounter(input=0, prefill=0, decode=0, detokenize=0)
 dummy_length = 1
 
+rng = jax.random.PRNGKey(1234)
+rng, _ = jax.random.split(rng)
+
 def prefill_done(prefill_result, ids, decode_state):
 nonlocal self
 nonlocal counter

@@ -345,7 +350,15 @@ def detokenize():
 
 # Do prefill when there are free slots
 self.prefill.process(
-self.params, self.decode_state, slot, row.id, row.tokens, row.true_length, self.max_prefill_length, prefill_done
+self.params,
+self.decode_state,
+slot,
+row.id,
+row.tokens,
+row.true_length,
+self.max_prefill_length,
+prefill_done,
+rng,
 )
 self.prefill.finalize(self.params, self.decode_state, prefill_done)
 
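
For context on the new `rng` plumbing: `jax.random` has no global RNG state, so a key has to be created explicitly and passed to any function that samples from it, which is why `batch_inference_with_callback` now builds a key and threads it down to the prefill helper's `process()`. A minimal standalone sketch of the same pattern (the `prefill` stand-in below is illustrative, not the MaxText engine API):

```
import jax

rng = jax.random.PRNGKey(1234)  # deterministic seed, as in the change above
rng, _ = jax.random.split(rng)  # split once; keep a fresh key to thread through calls

def prefill(params, tokens, rng):
  # Stand-in for the real prefill call; only the key plumbing is being shown.
  del params, tokens
  return jax.random.uniform(rng, (1,))

first_token = prefill(None, None, rng)
```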

MaxText/inference_mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh

Lines changed: 7 additions & 5 deletions

@@ -1,8 +1,10 @@
 #!/usr/bin/env bash
 
+# NOTE: please check the README located at MaxText/inference_mlperf/README.md for instructions on how
+# to set up the environment before running this script.
 # Run command:
 # bash benchmarks_llama2-70b-trillium_2x4.sh [-b benchmark_type]
-# benchmark_type can be: performance, audit, accuracy, or all (default)
+# benchmark_type can be: performance (default), audit, accuracy, or all
 
 run_name="trillium_llama2-70b"
 dry_run=false

@@ -84,21 +86,21 @@ if [[ -z ${CHECKPOINT} ]] ; then
 fi
 
 if [[ -z ${TOKENIZER_PATH} ]] ; then
-export TOKENIZER_PATH="/home/${USER}/maxtext/assets/tokenizer.llama2"
+export TOKENIZER_PATH="/home/${USER}/maxtext/assets/tokenizer.llama2" # NOTE: you may need to change this path for your VM
 fi
 
 if [ -z "$PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES" ];
 then
 PREFILL_LEN="1024"
-BATCH_SIZE_PER_DEVICE="64"
+BATCH_SIZE_PER_DEVICE="64"
 export PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES="${PREFILL_LEN},${BATCH_SIZE_PER_DEVICE}"
 fi
 
 
 BASE_CFG="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${CHECKPOINT}"
 QUANT_CFG="quantization=${QUANTIZATION} quant_cfg_path=${QUANT_PATH} checkpoint_is_quantized=True"
 KV_QUANT_CFG="quantize_kvcache=True kv_quant_dtype=${KV_QUANT_DTYPE}"
-export MAXENGINE_ARGS="${BASE_CFG} ${QUANT_CFG} ${KV_QUANT_CFG} optimize_mesh_for_tpu_v6e=True"
+export MAXENGINE_ARGS="${BASE_CFG} ${QUANT_CFG} ${KV_QUANT_CFG} optimize_mesh_for_tpu_v6e=True skip_jax_distributed_system=True"
 echo
 echo $MAXENGINE_ARGS
 echo

@@ -117,7 +119,7 @@ run_benchmark() {
 ;;
 "accuracy")
 export HF_CKPT="meta-llama/Llama-2-70b-chat-hf"
-$cmd bash llama_offline_run.sh ${RUN_OPTIONS} -r benchmarks_accuracy_${RUN_DESC} -a
+$cmd bash llama_offline_run.sh ${RUN_OPTIONS} -r benchmarks_accuracy_${RUN_DESC} -a
 ;;
 esac
 }
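
For reference, the invocation documented at the top of the script, with the now-default benchmark type spelled out (the `-b` values come from the updated header comment):

```
# Runs the performance benchmark (the default when -b is omitted).
bash benchmarks_llama2-70b-trillium_2x4.sh -b performance
# Other options per the header comment: -b audit, -b accuracy, -b all
```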

MaxText/inference_mlperf/trillium/microbenchmarks_llama2-70b-trillium_2x4.sh

Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
 # Run command:
 # bash microbenchmarks_llama2-70b-trillium_2x4.sh
-# Look at profiles:
-# tensorboard --logdir /tmp/mb/profiles/trillium_llama2_70b/tensorboard/prefill_insert_1024
+# Look at profiles:
+# tensorboard --logdir /tmp/mb/profiles/trillium_llama2_70b/tensorboard/prefill_insert_1024
 
 
 run_name="trillium_llama2-70b"

MaxText/prefill_packing.py

Lines changed: 2 additions & 1 deletion

@@ -114,11 +114,12 @@ def process(
 decode_slot: int,
 input_tokens_padded: jax.Array,
 input_true_length: int,
+rng: PRNGKeyType,
 ) -> Tuple[engine_api.ResultTokens, DecodeState]:
 """Process a new input."""
 
 process_fn = self._process_compiled(model_params, len(input_tokens_padded))
-return process_fn(model_params, input_tokens_padded, decode_slot, input_true_length, decode_state)
+return process_fn(model_params, input_tokens_padded, decode_slot, input_true_length, decode_state, rng)
 
 def _process_compiled(self, params: Params, padded_length: int):
 """Ahead-of-time compilation wrapper of _process()."""

requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ pyink
 pre-commit
 pytype
 pillow>=11.1.0
-sentencepiece==0.1.97
+sentencepiece==0.2.0
 tensorflow-text>=2.13.0
 tensorflow>=2.13.0
 tensorflow-datasets
