-
Notifications
You must be signed in to change notification settings - Fork 39
Expand file tree
/
Copy pathrun_inference.sh
More file actions
executable file
·47 lines (42 loc) · 1.16 KB
/
Copy pathrun_inference.sh
File metadata and controls
executable file
·47 lines (42 loc) · 1.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env bash
#
# Launch scripts/run_inference.py with a default set of inference options.
#
# Usage:
#   run_inference.sh [extra run_inference.py arguments...]
#
# Every setting below reads its value from the environment variable of the
# same name and falls back to the default shown when that variable is unset
# or empty, e.g.:
#   ENGINE=sglang PROMPT="Describe the image" ./run_inference.sh
#
# Command-line arguments given to this script are appended after the default
# options and handed to run_inference.py unchanged.
# NOTE(review): whether a repeated flag overrides the earlier default depends
# on how run_inference.py parses its arguments - confirm before relying on it.
set -euo pipefail

# The repository root is taken to be the parent of this script's directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"

# Interpreter used to run the inference script.
PYTHON_BIN="${PYTHON_BIN:-python}"

# Options sent for every engine.
ENGINE="${ENGINE:-hf}"
MODEL_PATH="${MODEL_PATH:-$REPO_DIR/model}"
IMAGE_PATH="${IMAGE_PATH:-$REPO_DIR/assets/image.png}"
PROMPT_TYPE="${PROMPT_TYPE:-text}"
PROMPT="${PROMPT:-}"
DEVICE="${DEVICE:-cuda}"
DTYPE="${DTYPE:-bfloat16}"
MAX_LENGTH="${MAX_LENGTH:-4096}"
GEN_LENGTH="${GEN_LENGTH:-1024}"
BLOCK_SIZE="${BLOCK_SIZE:-32}"
TEMPERATURE="${TEMPERATURE:-1.0}"
REMASK_STRATEGY="${REMASK_STRATEGY:-low_confidence_dynamic}"
DYNAMIC_THRESHOLD="${DYNAMIC_THRESHOLD:-0.95}"

# Options sent only when ENGINE is "sglang".
SGLANG_SERVER_URL="${SGLANG_SERVER_URL:-http://127.0.0.1:31002/v1/chat/completions}"
SGLANG_REQUEST_TIMEOUT="${SGLANG_REQUEST_TIMEOUT:-180}"

# Build the command line as an array so values containing spaces (paths,
# prompts) stay single arguments.
ARGS=(
  --engine "$ENGINE"
  --model-path "$MODEL_PATH"
  --image-path "$IMAGE_PATH"
  --prompt-type "$PROMPT_TYPE"
  --prompt "$PROMPT"
  --device "$DEVICE"
  --dtype "$DTYPE"
  --max-length "$MAX_LENGTH"
  --gen-length "$GEN_LENGTH"
  --block-size "$BLOCK_SIZE"
  --temperature "$TEMPERATURE"
  --remask-strategy "$REMASK_STRATEGY"
  --dynamic-threshold "$DYNAMIC_THRESHOLD"
)

if [[ "$ENGINE" == "sglang" ]]; then
  ARGS+=(
    --server-url "$SGLANG_SERVER_URL"
    --request-timeout "$SGLANG_REQUEST_TIMEOUT"
  )
fi

# Caller-supplied arguments go last, after all defaults.
"$PYTHON_BIN" "$REPO_DIR/scripts/run_inference.py" \
  "${ARGS[@]}" "$@"