[sharktank] Add toy Deepseek IREE Perplexity Tests #1507

Open · wants to merge 45 commits into main

Changes from all commits (45 commits)
050525b
Add torch PPL tests for Deepseek
Alex-Vasile May 7, 2025
5f5f96c
Remove resharding dataset
Alex-Vasile May 14, 2025
d4e5f4b
add pipeline parallelism support
Alex-Vasile May 14, 2025
b0b260f
Add PP and TP deepseek eager vs unsharded
Alex-Vasile May 14, 2025
e0ef7f7
REVERT ME: Update perplexity scores
Alex-Vasile May 20, 2025
a9fed06
Enable deepseek PP perplexity IREE test
Alex-Vasile May 21, 2025
f12f826
REVERT ME: Update perpelxity values
Alex-Vasile May 21, 2025
26d2d49
Update --run-quick-test in ci_eval_short.yaml
archana-ramalingam May 27, 2025
0af287b
Update --run-quick-test in ci-llama-quick-tests.yaml
archana-ramalingam May 27, 2025
9d52c15
Update --run-nightly-tests in ci-llama-large-tests.yaml
archana-ramalingam May 27, 2025
abb017b
Update --run-nightly-tests in ci-sharktank-nightly.yml
archana-ramalingam May 27, 2025
77a754f
Add toy deepseek ppl test to pre-submit CI
archana-ramalingam May 27, 2025
9b3eb47
Make toy deepseek test a pre-submit+nightly
archana-ramalingam May 27, 2025
1a2b3db
Merge branch 'main' into deepseek_ppl
archana-ramalingam May 27, 2025
39aac93
Fix unshard for replicated tensor
archana-ramalingam May 27, 2025
161c27c
Merge branch 'main' into deepseek_ppl
archana-ramalingam May 27, 2025
57783ac
Add is_pre_submit flag
archana-ramalingam May 28, 2025
567f551
Merge branch 'deepseek_ppl' of https://github.com/Alex-Vasile/shark-a…
archana-ramalingam May 28, 2025
d243414
Fix is_pre_submit_nightly flag
archana-ramalingam May 28, 2025
369284b
Add toy deepseek to ppl CIs
archana-ramalingam May 28, 2025
f7f71b8
Fix pre-commit issue
archana-ramalingam May 28, 2025
6e98f96
Merge branch 'main' into deepseek_ppl
archana-ramalingam May 28, 2025
28b6b4a
Fix ppl CIs flags
archana-ramalingam May 28, 2025
2eacd74
Fix ppl CIs flags
archana-ramalingam May 28, 2025
ea23dcd
fix ppl flags
archana-ramalingam May 28, 2025
b6d9295
fix ppl flags
archana-ramalingam May 28, 2025
132545b
Merge branch 'main' into deepseek_ppl
archana-ramalingam May 28, 2025
5bb2704
Update tp+iree issue link
archana-ramalingam May 28, 2025
0adf073
Enable eager mode pre-submit ppl
archana-ramalingam May 28, 2025
5f0b72b
Merge branch 'deepseek_ppl' of https://github.com/Alex-Vasile/shark-a…
archana-ramalingam May 28, 2025
6eb8628
Merge branch 'main' into deepseek_ppl
archana-ramalingam May 28, 2025
e9bc368
Add device flag for eager
archana-ramalingam May 28, 2025
49b5eac
Merge branch 'deepseek_ppl' of https://github.com/Alex-Vasile/shark-a…
archana-ramalingam May 28, 2025
445576f
Install torch+rocm
archana-ramalingam May 28, 2025
50aca13
Update test name
archana-ramalingam May 28, 2025
40725ad
Update toy ppl changes and ppl numbers
archana-ramalingam May 29, 2025
e5ae697
Merge branch 'main' into deepseek_ppl
archana-ramalingam May 29, 2025
182f770
breakup eager mode ppl tests to separate pr
archana-ramalingam May 29, 2025
9e7c4c6
Merge branch 'deepseek_ppl' of https://github.com/Alex-Vasile/shark-a…
archana-ramalingam May 29, 2025
dd01865
Remove toy test dependency on tokenizer
archana-ramalingam May 29, 2025
1607d98
Disable eager tests due to numeric regression
archana-ramalingam May 29, 2025
968bad9
Move all args to same device for eager ppl
archana-ramalingam May 29, 2025
b26e607
Refactor test fixtures
archana-ramalingam May 29, 2025
f437a41
Run all ppl tests in a single command
archana-ramalingam May 29, 2025
b9f2151
Add toy model flag to test
archana-ramalingam May 29, 2025
2 changes: 1 addition & 1 deletion .github/workflows/ci-llama-large-tests.yaml
@@ -70,7 +70,7 @@ jobs:
source ${VENV_DIR}/bin/activate
pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py \
-v -s \
--run-nightly-llama-tests \
--run-nightly-tests \
--iree-hip-target=gfx942 \
--iree-device=hip://0 \
--html=out/llm/llama/benchmark/index.html
2 changes: 1 addition & 1 deletion .github/workflows/ci-llama-quick-tests.yaml
@@ -72,7 +72,7 @@ jobs:
-v -s \
--iree-hip-target=gfx942 \
--iree-device=hip://0 \
--run-quick-llama-test
--run-quick-test

- name: Upload llama executable files
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
11 changes: 7 additions & 4 deletions .github/workflows/ci-sharktank-nightly.yml
@@ -130,12 +130,14 @@ jobs:
- name: Run perplexity test with IREE
run: |
source ${VENV_DIR}/bin/activate
mkdir perplexity_ci_artifacts
python -m sharktank.models.deepseek.toy_deepseek -o "perplexity_ci_artifacts/toy_deepseek.irpa"
pytest \
-n 8 \
-v \
-s \
sharktank/tests/evaluate/perplexity_iree_test.py \
--run-nightly-llama-tests \
--run-nightly-tests \
--bs=128 \
--iree-device=hip://0 \
--iree-hip-target=gfx942 \
@@ -144,6 +146,8 @@
--llama3-8b-f16-tp2-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/tp2/llama3.1_8b_instruct_fp16_tp2.irpa \
--llama3-8b-f8-model-path=/shark-dev/8b/fp8/native_fp8_e4m3fnuz_llama3_8b.irpa \
--llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json \
--deepseek-v3-model-path=perplexity_ci_artifacts/toy_deepseek.irpa \
--deepseek-v3-tokenizer-path=/shark-dev/data/deepseekv3/weights/fp16/tokenizer_config.json \
--html=out/llm/llama/perplexity/iree_perplexity/index.html \
--log-cli-level=INFO
ls -lha ${{ github.workspace }}/perplexity_ci_artifacts
@@ -200,12 +204,11 @@ jobs:
- name: Run Torch perplexity for fp16
run: |
source ${VENV_DIR}/bin/activate
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py -k test_llama3_8B_f16 --run-nightly-llama-tests --bs=32 --device='cuda:0' --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index_f16.html --log-cli-level=INFO

pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py -k test_llama3_8B_f16 --run-nightly-tests --bs=32 --device='cuda:0' --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index_f16.html --log-cli-level=INFO
- name: Run Torch perplexity for fp8
run: |
source ${VENV_DIR}/bin/activate
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py -k test_llama3_8B_f8 --run-nightly-llama-tests --bs=32 --device='cuda:0' --llama3-8b-f8-model-path=/shark-dev/8b/fp8/native_fp8_e4m3fnuz_llama3_8b.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index_f8.html --log-cli-level=INFO
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py -k test_llama3_8B_f8 --run-nightly-tests --bs=32 --device='cuda:0' --llama3-8b-f8-model-path=/shark-dev/8b/fp8/native_fp8_e4m3fnuz_llama3_8b.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index_f8.html --log-cli-level=INFO

- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
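The nightly job now generates the toy Deepseek IRPA on the runner before invoking pytest, so the Deepseek perplexity test does not depend on a pre-staged checkpoint. A minimal sketch of the equivalent local setup, reusing the module path and `-o` flag verbatim from the workflow step above (the output directory name is an arbitrary choice):

```python
# Sketch only: mirror the CI's artifact-generation step on a local machine.
import pathlib
import subprocess

artifacts = pathlib.Path("perplexity_ci_artifacts")
artifacts.mkdir(exist_ok=True)

# Build the tiny Deepseek checkpoint that the IREE perplexity test consumes
# via --deepseek-v3-model-path.
subprocess.run(
    [
        "python", "-m", "sharktank.models.deepseek.toy_deepseek",
        "-o", str(artifacts / "toy_deepseek.irpa"),
    ],
    check=True,
)
```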
18 changes: 13 additions & 5 deletions .github/workflows/ci_eval_short.yaml
@@ -23,7 +23,7 @@ concurrency:

jobs:
test_perplexity_iree:
name: "IREE Perplexity"
name: "Perplexity tests"
strategy:
matrix:
version: [3.11]
@@ -53,24 +53,32 @@ jobs:

# Note: We install in three steps in order to satisfy requirements
# from non default locations first.
pip install --no-compile -r pytorch-cpu-requirements.txt
pip install --no-compile -r pytorch-rocm-requirements.txt
pip install -r requirements-iree-pinned.txt
pip install --no-compile \
-r sharktank/requirements-tests.txt \
-e sharktank/

pip freeze

- name: Run perplexity test with vmfb
- name: Run Perplexity tests
run: |
source ${VENV_DIR}/bin/activate
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py \
mkdir perplexity_ci_artifacts
python -m sharktank.models.deepseek.toy_deepseek -o "perplexity_ci_artifacts/toy_deepseek.irpa"
pytest \
-n 8 \
-v \
-s \
sharktank/tests/evaluate/ \
--run-quick-test \
--bs=4 \
--device='cuda:0' \
--iree-device=hip://0 \
--iree-hip-target=gfx942 \
--iree-hal-target-device=hip \
--llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa \
--llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json \
--run-quick-llama-test \
--deepseek-v3-model-path=perplexity_ci_artifacts/toy_deepseek.irpa \
--log-cli-level=INFO
ls -lha ${{ github.workspace }}/perplexity_ci_artifacts
39 changes: 33 additions & 6 deletions sharktank/conftest.py
@@ -70,18 +70,18 @@ def pytest_addoption(parser):
help="List a torch device, (e.g., 'cuda:0')",
)
parser.addoption(
"--run-quick-llama-test",
"--run-quick-test",
action="store_true",
dest="run-quick-llama-test",
dest="run-quick-test",
default=False,
help="Run large llama tests if passed",
help="Enable all quick tests",
)
parser.addoption(
"--run-nightly-llama-tests",
"--run-nightly-tests",
action="store_true",
dest="run-nightly-llama-tests",
dest="run-nightly-tests",
default=False,
help="Enable all llama benchmarking tests",
help="Enable all nightly tests",
)

parser.addoption(
@@ -173,6 +173,24 @@ def pytest_addoption(parser):
default=None,
help="Llama3.1 405b f8 model path",
)
parser.addoption(
"--deepseek-v3-tokenizer-path",
type=Path,
action="store",
help="Deepkseek v3 tokenizer path",
)
parser.addoption(
"--deepseek-v3-model-path",
type=Path,
action="store",
help="Deepseek v3 unsharded model path",
)
parser.addoption(
"--deepseek-v3-tp8-model-path",
type=Path,
action="store",
help="Deepseek v3 tp8 sharded model path",
)

# To obtain a T5 GGUF file you can use llama.cpp's convert_hf_to_gguf.py.
# https://github.com/ggerganov/llama.cpp/blob/9abe9eeae98b11fa93b82632b264126a010225ff/convert_hf_to_gguf.py
@@ -350,6 +368,15 @@ def get_model_artifacts(request: FixtureRequest):
model_path["llama3_405b_f8_model_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-f8-model-path", "llama3_405b_f8_model"
)
model_path["deepseek_v3_tokenizer_path"] = set_fixture_from_cli_option(
request, "--deepseek-v3-tokenizer-path", "deepseek_v3_tokenizer"
)
model_path["deepseek_v3_model_path"] = set_fixture_from_cli_option(
request, "--deepseek-v3-model-path", "deepseek_v3_model"
)
model_path["deepseek_v3_tp8_model_path"] = set_fixture_from_cli_option(
request, "--deepseek-v3-tp8-model-path", "deepseek_v3_tp8_model"
)
model_path["google__t5_v1_1_small_f32_model_path"] = set_fixture_from_cli_option(
request,
"--google-t5-v1-1-small-f32-model-path",
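For context, here is a hypothetical consumer of the renamed flags and the new Deepseek path options (not part of this diff). It assumes `get_model_artifacts` exposes each CLI value as a class attribute named after the key registered above (`deepseek_v3_model`, `deepseek_v3_tokenizer`); adjust to whatever the real fixtures provide.

```python
# Hypothetical test sketch; fixture and attribute names are assumptions.
import pytest


@pytest.mark.usefixtures("get_model_artifacts")
class TestToyDeepseekPerplexity:
    def test_toy_deepseek_iree(self, request: pytest.FixtureRequest):
        # Run only when the pre-submit or nightly flag from conftest is set.
        if not (
            request.config.getoption("run-quick-test")
            or request.config.getoption("run-nightly-tests")
        ):
            pytest.skip("pass --run-quick-test or --run-nightly-tests")
        model_path = self.deepseek_v3_model  # populated from --deepseek-v3-model-path
        assert model_path is not None
```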
95 changes: 68 additions & 27 deletions sharktank/sharktank/evaluate/perplexity_iree.py
@@ -62,6 +62,7 @@ def __init__(
use_attention_mask,
use_hf,
weight_path_str: str,
use_toy_model: bool = False,
):
self.torch_device = torch_device
self.iree_devices = iree_devices
@@ -79,11 +80,24 @@ def __init__(
self.use_attention_mask = use_attention_mask
self.use_hf = use_hf
self.weight_path_str = weight_path_str
self.use_toy_model = use_toy_model
self.vm_context: iree.runtime.VmContext = None
self.cache_state: None | list[ireert.DeviceArray] = None
self.page_cache_size = 128
# Add context to improve perplexity by starting at 10th token
self.start = 10

def print_token_comparison(self, i: int):
if i <= self.max_prompt_length:
if self.use_toy_model and i <= self.max_prompt_length:
batch_predicted_token_id = [[i[-1]] for i in self.batch.results]
logger.debug(f"Predicted:")
logger.debug(f"{batch_predicted_token_id}")

expected_token_id = self.token_ids[:, i + 1 : i + 2].tolist()
logger.debug(f"Expected:")
logger.debug(f"{expected_token_id}")

elif i <= self.max_prompt_length:
batch_predicted_token_id = [[i[-1]] for i in self.batch.results]
batch_predicted_token = self.generator.tokenizer.decode(
batch_predicted_token_id
@@ -133,7 +147,9 @@ def compile_model(
)
self.output_vmfb = export_artifacts.get_artifacts()

def load_model(self, dataset: Dataset, tokenizer: InferenceTokenizer):
def load_model(
self, dataset: Dataset, tokenizer: Optional[InferenceTokenizer] = None
):
hp = configs.LlamaHParams.from_gguf_props(dataset.properties)

pp = self.pipeline_parallelism_size
@@ -195,8 +211,10 @@ def assemble_batch(self, token_batch: torch.tensor, devices) -> torch.tensor:
def prefill_vmfb(
self, token_batch: torch.tensor, i: int, devices: list[iree.runtime.HalDevice]
) -> torch.tensor:
logger.debug(f"Prefill input:")
logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
if not self.use_toy_model:
logger.debug(
f"Prefill input:\n{self.generator.tokenizer.decode(token_batch)}"
)

token_batch = self.assemble_batch(token_batch, devices)

@@ -243,8 +261,9 @@ def prefill_vmfb(
def decode_vmfb(
self, token_batch: torch.tensor, i: int, devices: list[iree.runtime.HalDevice]
) -> torch.tensor:
logger.debug("Decode input:")
logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
logger.debug(f"Decode input:")
if not self.use_toy_model:
logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
logger.debug(f"{token_batch.tolist()}")

start_positions = [self.batch.seq_lens.clone()]
@@ -295,8 +314,6 @@ def decode_vmfb(

@timeit
def get_logits(self, skip_decode: bool) -> torch.Tensor:
# Add context to improve perplexity by starting at 10th token
self.start = 10
shard_count = self.tensor_parallelism_size

vm_instance = ireert.VmInstance()
@@ -366,27 +383,35 @@ def run_iree_module(devices: list[iree.runtime.HalDevice]):
return with_iree_device_context(run_iree_module, devices)

def get_perplexity(
self, test_prompts: list[str], skip_decode: bool
self, test_prompts: list[str], token_ids: list[list[int]], skip_decode: bool
) -> dict[str, Any]:

token_ids, seq_lens = self.generator.tokenizer.encode(
test_prompts,
pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
)
if self.use_toy_model:
self.token_ids = token_ids
self.seq_lens = [len(t) for t in self.token_ids]
self.start = 5

logger.debug(f" Prompts for Evaluation:")
for idx, prompt in enumerate(test_prompts):
logger.debug(
f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n"
logger.debug(f" Token ids for Evaluation: \n{self.token_ids}\n")

else:
self.token_ids, self.seq_lens = self.generator.tokenizer.encode(
test_prompts,
pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
)

self.page_cache_size = (
len(token_ids[0]) // self.generator.model.config.block_seq_stride
) * len(test_prompts) + 1
logger.debug(f" Prompts for Evaluation:")
for idx, prompt in enumerate(test_prompts):
logger.debug(
f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {self.token_ids[idx]}\n"
)

self.max_prompt_length = max(seq_lens)
self.page_cache_size = (
len(self.token_ids[0]) // self.generator.model.config.block_seq_stride
) * len(test_prompts) + 1

self.token_ids = torch.as_tensor(token_ids, device=self.torch_device)
self.max_prompt_length = max(self.seq_lens)

self.token_ids = torch.as_tensor(self.token_ids, device=self.torch_device)

out_logits = self.get_logits(skip_decode)

Expand All @@ -406,7 +431,15 @@ def run_perplexity_iree(
) -> dict[str, Any]:
start = time.time()

test_prompts = args.prompt_list or get_prompts(num_prompts=args.num_prompts)
token_ids = None
test_prompts = None

if args.use_toy_model:
token_ids = get_token_ids()
bs = len(token_ids)
else:
test_prompts = args.prompt_list or get_prompts(num_prompts=args.num_prompts)
bs = len(test_prompts)

perplexity = PerplexityIree(
torch_device=torch_device,
@@ -422,18 +455,24 @@
attention_dtype=args.attention_dtype,
kv_cache_dtype=args.kv_cache_dtype,
use_hf=args.use_hf,
bs=len(test_prompts),
bs=bs,
weight_path_str=str(args.irpa_file),
use_toy_model=args.use_toy_model,
)

perplexity.compile_model(
output_mlir=args.output_mlir,
output_config=args.output_config,
output_vmfb=args.output_vmfb,
)
perplexity.load_model(dataset=dataset, tokenizer=tokenizer)
perplexity.load_model(
dataset=dataset,
tokenizer=tokenizer,
)
perplexity_batch = perplexity.get_perplexity(
test_prompts, skip_decode=args.skip_decode
test_prompts=test_prompts,
token_ids=token_ids,
skip_decode=args.skip_decode,
)

end = time.time()
@@ -463,7 +502,9 @@ def main(argv):

args = cli.parse(parser, args=argv)
dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)
tokenizer = None
if not args.use_toy_model:
tokenizer = cli.get_tokenizer(args)

logger.setLevel(args.loglevel)
torch_device = torch.device(args.device) if args.device else None
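The `self.start` offset (10 normally, 5 for the toy model) trims the low-context tokens at the start of each prompt before scoring. Once `get_logits` has gathered the model outputs, perplexity follows the standard cross-entropy formulation; a minimal sketch under assumed tensor shapes (the actual sharktank helper may differ):

```python
# Illustrative only: per-prompt perplexity from logits and reference token ids.
import torch


def perplexity_from_logits(
    logits: torch.Tensor,     # [batch, seq_len, vocab]
    token_ids: torch.Tensor,  # [batch, seq_len]
    start: int = 10,          # 5 for the toy model, per the diff above
) -> torch.Tensor:
    # Log-probability the model assigns to each actual next token after `start`.
    log_probs = torch.log_softmax(logits[:, start:-1, :].float(), dim=-1)
    targets = token_ids[:, start + 1 :].unsqueeze(-1)
    nll = -log_probs.gather(dim=-1, index=targets).squeeze(-1)
    # Perplexity is the exponential of the mean negative log-likelihood.
    return torch.exp(nll.mean(dim=-1))
```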