
Commit a50d9ab

Merge branch 'main' into patch-45
2 parents ef74428 + 7cbf2a3

File tree

10 files changed: +198 -163 lines


.ci/scripts/run-docs

+62-140
@@ -1,145 +1,67 @@
-# /bin/bash -x
+#!/bin/bash -x
 
-if [ "X$1" == "X" ]; then
+# Check if an argument was provided
+if [ -z "$1" ]; then
   echo "Must specify document to run"
   exit 1
 fi
 
-if [ "$1" == "readme" ]; then
-  echo "::group::Create script to run README"
-  python3 torchchat/utils/scripts/updown.py --create-sections --file README.md --replace 'llama3.1:stories15M,-l 3:-l 2' --suppress huggingface-cli,HF_TOKEN > ./run-readme.sh
-  # for good measure, if something happened to updown processor,
-  # and it did not error out, fail with an exit 1
-  echo "exit 1" >> ./run-readme.sh
-  echo "::endgroup::"
-
-  echo "::group::Run README"
-  echo "*******************************************"
-  cat ./run-readme.sh
-  echo "*******************************************"
-  bash -x ./run-readme.sh
-  echo "::endgroup::"
-
-  exit 0
-fi
-
-if [ "$1" == "quantization" ]; then
-  echo "::group::Create script to run quantization"
-  python3 torchchat/utils/scripts/updown.py --create-sections --file docs/quantization.md --replace llama3:stories15M --suppress huggingface-cli,HF_TOKEN > ./run-quantization.sh
-  # for good measure, if something happened to updown processor,
-  # and it did not error out, fail with an exit 1
-  echo "exit 1" >> ./run-quantization.sh
-  echo "::endgroup::"
-
-  echo "::group::Run quantization"
-  echo "*******************************************"
-  cat ./run-quantization.sh
-  echo "*******************************************"
-  bash -x ./run-quantization.sh
-  echo "::endgroup::"
-
-  exit 0
-fi
-
-if [ "$1" == "gguf" ]; then
-  echo "::group::Create script to run gguf"
-  python3 torchchat/utils/scripts/updown.py --file docs/GGUF.md --replace 'llama3:stories15M,-l 3:-l 2' --suppress huggingface-cli,HF_TOKEN > ./run-gguf.sh
-  # for good measure, if something happened to updown processor,
-  # and it did not error out, fail with an exit 1
-  echo "exit 1" >> ./run-gguf.sh
-  echo "::endgroup::"
-
-  echo "::group::Run gguf"
-  echo "*******************************************"
-  cat ./run-gguf.sh
-  echo "*******************************************"
-  bash -x ./run-gguf.sh
-  echo "::endgroup::"
-fi
-
-
-if [ "$1" == "advanced" ]; then
-  echo "::group::Create script to run advanced"
-  python3 torchchat/utils/scripts/updown.py --file docs/ADVANCED-USERS.md --replace 'llama3:stories15M,-l 3:-l 2' --suppress huggingface-cli,HF_TOKEN > ./run-advanced.sh
-  # for good measure, if something happened to updown processor,
-  # and it did not error out, fail with an exit 1
-  echo "exit 1" >> ./run-advanced.sh
-  echo "::endgroup::"
-
-  echo "::group::Run advanced"
-  echo "*******************************************"
-  cat ./run-advanced.sh
-  echo "*******************************************"
-  bash -x ./run-advanced.sh
-  echo "::endgroup::"
-fi
-
-if [ "$1" == "evaluation" ]; then
-  echo "::group::Create script to run evaluation"
-  python3 torchchat/utils/scripts/updown.py --file torchchat/utils/docs/evaluation.md --replace 'llama3:stories15M,-l 3:-l 2' --suppress huggingface-cli,HF_TOKEN > ./run-evaluation.sh
-  # for good measure, if something happened to updown processor,
-  # and it did not error out, fail with an exit 1
-  echo "exit 1" >> ./run-evaluation.sh
-  echo "::endgroup::"
-
-  echo "::group::Run evaluation"
-  echo "*******************************************"
-  cat ./run-evaluation.sh
-  echo "*******************************************"
-  bash -x ./run-evaluation.sh
-fi
-
-if [ "$1" == "multimodal" ]; then
-
-  # Expecting that this might fail this test as-is, because
-  # it's the first on-pr test depending on github secrets for access with HF token access
-
-  echo "::group::Create script to run multimodal"
-  python3 torchchat/utils/scripts/updown.py --file docs/multimodal.md > ./run-multimodal.sh
-  # for good measure, if something happened to updown processor,
-  # and it did not error out, fail with an exit 1
-  echo "exit 1" >> ./run-multimodal.sh
-  echo "::endgroup::"
-
-  echo "::group::Run multimodal"
-  echo "*******************************************"
-  cat ./run-multimodal.sh
-  echo "*******************************************"
-  bash -x ./run-multimodal.sh
-  echo "::endgroup::"
-fi
-
-if [ "$1" == "native" ]; then
-
-  echo "::group::Create script to run native-execution"
-  python3 torchchat/utils/scripts/updown.py --file docs/native-execution.md > ./run-native.sh
-  # for good measure, if something happened to updown processor,
-  # and it did not error out, fail with an exit 1
-  echo "exit 1" >> ./run-native.sh
-  echo "::endgroup::"
-
-  echo "::group::Run native-execution"
-  echo "*******************************************"
-  cat ./run-native.sh
-  echo "*******************************************"
-  bash -x ./run-native.sh
-  echo "::endgroup::"
-fi
-
-if [ "$1" == "distributed" ]; then
-
-  echo "::group::Create script to run distributed"
-  python3 torchchat/utils/scripts/updown.py --file docs/distributed.md --replace 'llama3.1:stories110M,-l 3:-l 2' --suppress huggingface-cli,HF_TOKEN > ./run-distributed.sh
-  python3 torchchat/utils/scripts/updown.py --file docs/distributed.md --suppress huggingface-cli,HF_TOKEN > ./run-distributed.sh
-  # for good measure, if something happened to updown processor,
-  # and it did not error out, fail with an exit 1
-  echo "exit 1" >> ./run-distributed.sh
-  echo "::endgroup::"
-
-  echo "::group::Run distributed"
-  echo "*******************************************"
-  cat ./run-distributed.sh
-  echo "*******************************************"
-  bash -x ./run-distributed.sh
-  echo "::endgroup::"
-fi
+# Pre-initialize variables
+filepath=""
+parameters="--replace 'llama3:stories15M,-l3:-l2' --suppress huggingface-cli,HF_TOKEN"
+script_name="./run-${1}.sh" # Dynamically initialize script name
+
+# Use a case statement to handle the $1 argument
+case "$1" in
+  "readme")
+    filepath="README.md"
+    ;;
+  "quantization")
+    filepath="docs/quantization.md"
+    ;;
+  "gguf")
+    filepath="docs/GGUF.md"
+    ;;
+  "advanced")
+    filepath="docs/ADVANCED-USERS.md"
+    ;;
+  "evaluation")
+    filepath="torchchat/utils/docs/evaluation.md"
+    ;;
+  "multimodal")
+    filepath="docs/multimodal.md"
+    parameters="" # Clear parameters
+    ;;
+  "native")
+    filepath="docs/native-execution.md"
+    parameters="" # Clear parameters
+    ;;
+  "distributed")
+    filepath="docs/distributed.md"
+    parameters="--replace 'llama3.1:stories110M,-l3:-l2' --suppress huggingface-cli,HF_TOKEN" # Use stories110M to avoid need for authentication
+    ;;
+  "local")
+    filepath="docs/local-model.md"
+    parameters="" # Clear parameters
+    ;;
+
+  *)
+    echo "Unknown option: $1"
+    exit 1
+    ;;
+esac
+
+# Generate the script
+echo "::group::Create script to run $1"
+python3 torchchat/utils/scripts/updown.py --file "$filepath" $parameters > "$script_name"
+# if something happened to updown processor, and it did not error out, fail with an exit 1
+echo "exit 1" >> "$script_name"
+echo "::endgroup::"
+
+# Run the script
+echo "::group::Run $1"
+echo "*******************************************"
+cat "$script_name"
+echo "*******************************************"
+bash -x "$script_name"
+echo "::endgroup::"

.github/workflows/run-readme-pr-mps.yml

+3-2
@@ -15,8 +15,8 @@ jobs:
         conda create -y -n test-readme-mps-macos python=3.10.11 llvm-openmp
         conda activate test-readme-mps-macos
         set -x
-        # NS: Remove previous installation of torch first
-        # as this script does not isntall anything into conda env but rather as system dep
+        # NS: Remove previous installation of torch first
+        # as this script does not install anything into conda env but rather as system dep
         pip3 uninstall -y torch || true
         set -eou pipefail
 
@@ -37,6 +37,7 @@ jobs:
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
     with:
       runner: macos-m1-14
+      timeout: 60
       script: |
         set -x
         conda create -y -n test-quantization-mps-macos python=3.10.11

README.md

+1-1
@@ -413,7 +413,7 @@ torchchat/utils/scripts/build_native.sh et
 
 Execute using the runner
 ```bash
-cmake-out/et_run llama3.1.pte -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
+cmake-out/et_run llama3.1.pte -z `python3 torchchat.py where llama3.1`/tokenizer.model -i "Once upon a time"
 ```
 
 </details>

docs/quantization.md

+2-2
@@ -182,7 +182,7 @@ OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so
 If you built the AOTI runner with link_torchao_ops as discussed in the setup section, you can also use the C++ runner:
 
 ```
-OMP_NUM_THREADS=6 ./cmake-out/aoti_run llama3_1.so -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l 3 -i "Once upon a time,"
+OMP_NUM_THREADS=6 ./cmake-out/aoti_run llama3_1.so -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -i "Once upon a time," # -l 3
 ```
 
 #### ExecuTorch
@@ -193,7 +193,7 @@ python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"e
 Note: only the ExecuTorch C++ runner in torchchat when built using the instructions in the setup can run the exported *.pte file. It will not work with the `python torchchat.py generate` command.
 
 ```
-./cmake-out/et_run llama3_1.pte -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l 3 -i "Once upon a time,"
+./cmake-out/et_run llama3_1.pte -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l3 -i "Once upon a time,"
 ```
 
 ## Experimental TorchAO MPS lowbit kernels

runner/run.cpp

+28-16
@@ -803,41 +803,53 @@ int main(int argc, char *argv[]) {
   } else {
     error_usage();
   }
-  for (int i = 2; i < argc; i += 2) {
+  for (int i = 2; i < argc; i += 1) {
     // do some basic validation
-    if (i + 1 >= argc) {
-      error_usage();
-    } // must have arg after flag
+    char *parm = argv[i+1];
+    // uniarg means the arg comes right after the letter in accordance with posix
+    int uniarg = strlen(argv[i]) > 2;
+
     if (argv[i][0] != '-') {
       error_usage();
     } // must start with dash
-    if (strlen(argv[i]) != 2) {
+
+    if (strlen(argv[i]) < 2) {
       error_usage();
-    } // must be -x (one dash, one letter)
+    } // must have at least dash '-' and option letter
+
+    if (uniarg) {
+      parm=&argv[i][2];
+    } else if (i + 1 >= argc) {
+      error_usage();
+    } // must have arg after option if flag is not contiguous to option
+
     // read in the args
     if (argv[i][1] == 't') {
-      temperature = atof(argv[i + 1]);
+      temperature = atof(parm);
    } else if (argv[i][1] == 'p') {
-      topp = atof(argv[i + 1]);
+      topp = atof(parm);
    } else if (argv[i][1] == 's') {
-      rng_seed = atoi(argv[i + 1]);
+      rng_seed = atoi(parm);
    } else if (argv[i][1] == 'n') {
-      steps = atoi(argv[i + 1]);
+      steps = atoi(parm);
    } else if (argv[i][1] == 'v') {
-      vocab_size = atoi(argv[i + 1]);
+      vocab_size = atoi(parm);
    } else if (argv[i][1] == 'i') {
-      prompt = argv[i + 1];
+      prompt = parm;
    } else if (argv[i][1] == 'z') {
-      tokenizer_path = argv[i + 1];
+      tokenizer_path = parm;
    } else if (argv[i][1] == 'm') {
-      mode = argv[i + 1];
+      mode = parm;
    } else if (argv[i][1] == 'y') {
-      system_prompt = argv[i + 1];
+      system_prompt = parm;
    } else if (argv[i][1] == 'l') {
-      llama_ver = atoi(argv[i + 1]);
+      llama_ver = atoi(parm);
    } else {
      error_usage();
    }
+
+    // account for parameter
+    i += (uniarg)?0:1;
  }
 
  if (model_path == NULL) {
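
Note: the loop now steps one argv slot at a time and accepts POSIX-style contiguous options. When a token is longer than two characters (`uniarg`), the value is read from the same token starting at its third character; otherwise the next argv entry is used and `i` is advanced past it. Both spellings below should therefore parse identically after this change (a sketch; the .pte and tokenizer paths are placeholders):

```bash
# Separated form: value is the next argv entry (the only form the old parser accepted)
cmake-out/et_run llama3_1.pte -z /path/to/tokenizer.model -l 3 -i "Once upon a time,"

# Contiguous POSIX form: value follows the option letter within the same token
cmake-out/et_run llama3_1.pte -z /path/to/tokenizer.model -l3 -i "Once upon a time,"
```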

torchchat/cli/builder.py

+33
@@ -56,6 +56,7 @@ class BuilderArgs:
     gguf_kwargs: Optional[Dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
     aoti_package_path: Optional[Union[Path, str]] = None
+    snapshot_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
@@ -87,6 +88,7 @@ def __post_init__(self):
             or (self.dso_path and Path(self.dso_path).is_file())
             or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
             or (self.pte_path and Path(self.pte_path).is_file())
+            or (self.snapshot_path and Path(self.snapshot_path).is_file())
         ):
             raise RuntimeError(
                 "need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI PACKAGE or PTE path"
@@ -142,6 +144,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
         aoti_package_path = getattr(args, "aoti_package_path", None)
+        snapshot_path = getattr(args, "snapshot_path", None)
 
         is_chat_model = False
         if args.is_chat_model:
@@ -169,6 +172,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         output_pte_path = getattr(args, "output_pte_path", None)
         output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
         output_dso_path = getattr(args, "output_dso_path", None)
+        output_snapshot_path = getattr(args, "output_snapshot_path", None)
         if output_pte_path and args.dtype.startswith("fast"):
             if args.dtype == "fast":
                 # As per Kimish, float32 should be faster on ET XNNPACK
@@ -206,6 +210,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             dso_path=dso_path,
             aoti_package_path=aoti_package_path,
             pte_path=pte_path,
+            snapshot_path=snapshot_path,
             device=args.device,
             precision=dtype,
             setup_caches=(
@@ -631,6 +636,34 @@ def do_nothing(max_batch_size, max_seq_length):
             model = PTEModel(config, builder_args.pte_path)
         except Exception:
             raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
+    elif builder_args.snapshot_path:
+        # Resolve ModelArgs for constructing the PTEModel
+        # If a manual params_path is provided, use that
+        if builder_args.params_path:
+            config: ModelArgs = ModelArgs.from_params(builder_args.params_path)
+        else:
+            # TODO: Instead of loading the whole model, refactor to call a
+            # helper that generate just model.config
+            with measure_time("Time to load model: {time:.02f} seconds"):
+                model = _load_model(builder_args)
+                device_sync(device=builder_args.device)
+            config = model.config
+            model = None
+        try:
+            model = torch.load(builder_args.snapshot_path, weights_only=False)
+        except Exception:
+            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+        # _active_backend() does not allow DSO & AOTI to be true.
+        # Choose either.
+        from torchchat.utils.build_utils import set_backend
+        set_backend (dso=True, pte=False, aoti_package=False)
+        if (model.config != config):
+            raise RuntimeError("loaded model architecture mismatch")
+        ##
+        ## import all libraries with custom kernels ans custom operators
+        ## that quantize may be pulling in
+        ##
+
     elif builder_args.distributed:
         pp_degree = builder_args.pp
         tp_degree = builder_args.tp
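
Note: the new branch reloads a whole serialized model with `torch.load(..., weights_only=False)`, forces the DSO backend, and verifies that the snapshot's `config` matches the freshly resolved `ModelArgs`. A hedged sketch of the intended round trip from the CLI; the `--snapshot-path` and `--output-snapshot-path` flag names are inferred from the argparse attributes read above (`snapshot_path`, `output_snapshot_path`) and are assumptions, not confirmed by this diff:

```bash
# Hypothetical flow (flag names inferred, not confirmed by this commit):
# export a model snapshot once, then generate from it on a later run.
python3 torchchat.py export stories15M --output-snapshot-path stories15M.tc
python3 torchchat.py generate stories15M --snapshot-path stories15M.tc --prompt "Once upon a time,"
```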
