Skip to content

Commit cf9bea7

Browse files
author
Michael Gschwind
committed
fix tests
1 parent 95da421 commit cf9bea7

File tree

2 files changed

+88
-4
lines changed

2 files changed

+88
-4
lines changed

.github/workflows/compile-bf16.yml .github/workflows/compile-dtype.yml

+1-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
run-tinystories:
1212
strategy:
1313
matrix:
14-
runner: [ubuntu-latest, macos-14, macos-12]
14+
runner: [ubuntu-latest, macos-14]
1515
runs-on: ${{matrix.runner}}
1616
steps:
1717
- name: Checkout repo
@@ -102,9 +102,6 @@ jobs:
102102
echo "******************************************"
103103
echo "******** INT4 group-wise quantized *******"
104104
echo "******************************************"
105-
if [ ${DTYPE} == float16 ]; then
106-
DTYPE=bfloat16
107-
fi
108105

109106
python generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
110107
cat ./output_eager

.github/workflows/eager-dtype.yml

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
name: Compile-dtype main
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
pull_request:
8+
workflow_dispatch:
9+
10+
jobs:
11+
run-tinystories:
12+
strategy:
13+
matrix:
14+
runner: [macos-12]
15+
runs-on: ${{matrix.runner}}
16+
steps:
17+
- name: Checkout repo
18+
uses: actions/checkout@v2
19+
- name: Setup Python
20+
uses: actions/setup-python@v2
21+
with:
22+
python-version: 3.11
23+
- name: Print machine info
24+
run: |
25+
uname -a
26+
if [ $(uname -s) == Darwin ]; then
27+
sysctl machdep.cpu.brand_string
28+
sysctl machdep.cpu.core_count
29+
fi
30+
- name: Install requirements
31+
run: |
32+
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
33+
pip install -r requirements.txt
34+
- name: Download checkpoints
35+
run: |
36+
mkdir -p checkpoints/stories15M
37+
pushd checkpoints/stories15M
38+
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
39+
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
40+
popd
41+
- name: Run inference
42+
run: |
43+
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
44+
export MODEL_NAME=stories15M
45+
export MODEL_DIR=/tmp
46+
for DTYPE in bfloat16 float16 float32; do
47+
# if [ $(uname -s) == Darwin ]; then
48+
# export DTYPE=float16
49+
# fi
50+
python generate.py --dtype ${DTYPE} --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
51+
cat ./output_eager
52+
53+
echo "******************************************"
54+
echo "******* Emb: channel-wise quantized ******"
55+
echo "******************************************"
56+
python generate.py --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
57+
cat ./output_eager
58+
59+
echo "******************************************"
60+
echo "******** Emb: group-wise quantized *******"
61+
echo "******************************************"
62+
python generate.py --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
63+
cat ./output_eager
64+
65+
echo "******************************************"
66+
echo "******* INT8 channel-wise quantized ******"
67+
echo "******************************************"
68+
python generate.py --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
69+
cat ./output_eager
70+
71+
echo "******************************************"
72+
echo "******** INT8 group-wise quantized *******"
73+
echo "******************************************"
74+
python generate.py --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
75+
cat ./output_eager
76+
77+
echo "******************************************"
78+
echo "******** INT4 group-wise quantized *******"
79+
echo "******************************************"
80+
81+
python generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
82+
cat ./output_eager
83+
84+
echo "tests complete for ${DTYPE}"
85+
done
86+
87+
echo "tests complete for all dtypes!"

0 commit comments

Comments
 (0)