Add new export LLM config #11028
base: gh/jackzhxng/10/base
Changes from 12 commits
@@ -0,0 +1,9 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
@@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""
Configurations for exporting Llama.

Uses dataclasses, which integrate with OmegaConf and Hydra.
"""

import re
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional


################################################################################
################################## BaseConfig ##################################
################################################################################


class ModelType(str, Enum):
    STORIES110M = "stories110m"
    LLAMA2 = "llama2"
    LLAMA3 = "llama3"
    LLAMA3_1 = "llama3_1"
    LLAMA3_2 = "llama3_2"
    LLAMA3_2_VISION = "llama3_2_vision"
    STATIC_LLAMA = "static_llama"
    QWEN2_5 = "qwen2_5"
    QWEN3_0_6B = "qwen3-0_6b"
    QWEN3_1_7B = "qwen3-1_7b"
    QWEN3_4B = "qwen3-4b"
    PHI_4_MINI = "phi_4_mini"
    SMOLLM2 = "smollm2"


class PreqMode(str, Enum):
    PREQ_8DA4W = "8da4w"
    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"


@dataclass
class BaseConfig:
    """
    These are specific to the particular model, e.g. whether it's Qwen3 0.6B or Phi-4-mini.
    For each of these different models, you can expect each of these fields to change.
""" | ||
|
||
model_class: ModelType = ModelType.LLAMA3 | ||
params: Optional[str] = None | ||
checkpoint: Optional[str] = None | ||
checkpoint_dir: Optional[str] = None # For sharded checkpoint. | ||
tokenizer_path: Optional[str] = None | ||
metadata: Optional[str] = None | ||
use_lora: bool = False | ||
fairseq2: bool = False # For legacy internal use cases. | ||
|
||
# Legacy pre-quantization options that happen during model weight loading. | ||
preq_mode: Optional[PreqMode] = None | ||
preq_group_size: int = 32 | ||
preq_embedding_quantize: str = "8,0" | ||
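As a quick illustration of the docstring above (an annotation, not part of the diff): two BaseConfig instances for different models would differ mainly in these fields. The import path is inferred from the base_module in targets.bzl, and the checkpoint/tokenizer paths are placeholders.

```python
from executorch.examples.models.llama.config.llm_config import BaseConfig, ModelType

# Placeholder paths -- not real artifacts.
qwen3_base = BaseConfig(
    model_class=ModelType.QWEN3_0_6B,
    params="qwen3_0_6b_params.json",
    checkpoint="qwen3_0_6b.pth",
)
llama3_base = BaseConfig(
    model_class=ModelType.LLAMA3_2,
    checkpoint="llama3_2_1b.pth",
    tokenizer_path="tokenizer.model",
)
```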


################################################################################
################################# ModelConfig ##################################
################################################################################


class DtypeOverride(str, Enum):
    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"


@dataclass
class ModelConfig:
    """
    These are not necessarily specific to the model, but are needed to finish off
    the rest of the model configuration in eager. You can think of these like
    optimizations / actual configurations. The same ModelConfig can be applied
    to different models.
    """

    dtype_override: DtypeOverride = DtypeOverride.FP32
    enable_dynamic_shape: bool = True
    use_shared_embedding: bool = False
    use_sdpa_with_kv_cache: bool = False
    expand_rope_table: bool = False
    use_attention_sink: Optional[str] = None
    output_prune_map: Optional[str] = None
    input_prune_map: Optional[str] = None

    # Below are config options relating to kv cache.
    use_kv_cache: bool = False
    quantize_kv_cache: bool = False
    local_global_attention: Optional[List[int]] = None


################################################################################
################################ ExportConfig ##################################
################################################################################


@dataclass
class ExportConfig:
    max_seq_length: int = 128
    max_context_length: int = 128
    output_dir: Optional[str] = None
    output_name: Optional[str] = None
    so_library: Optional[str] = None
    export_only: bool = False


################################################################################
################################# DebugConfig ##################################
################################################################################


@dataclass
class DebugConfig:
    profile_memory: bool = False
    profile_path: Optional[str] = None
    generate_etrecord: bool = False
    generate_full_logits: bool = False
    verbose: bool = False


################################################################################
############################# QuantizationConfig ###############################
################################################################################


class Pt2eQuantize(str, Enum):
    XNNPACK_DYNAMIC = "xnnpack_dynamic"
    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
    QNN_8A8W = "qnn_8a8w"
    QNN_16A16W = "qnn_16a16w"
    QNN_16A4W = "qnn_16a4w"
    COREML_C4W = "coreml_c4w"
    COREML_8A_C8W = "coreml_8a_c8w"
    COREML_8A_C4W = "coreml_8a_c4w"
    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
    VULKAN_8W = "vulkan_8w"


class SpinQuant(str, Enum):
    CUDA = "cuda"
    NATIVE = "native"


@dataclass
class QuantizationConfig:
    qmode: Optional[str] = None
    embedding_quantize: Optional[str] = None
    pt2e_quantize: Optional[Pt2eQuantize] = None
    group_size: Optional[int] = None
    use_spin_quant: Optional[SpinQuant] = None
    use_qat: bool = False
    calibration_tasks: Optional[List[str]] = None
    calibration_limit: Optional[int] = None
    calibration_seq_length: Optional[int] = None
    calibration_data: str = "Once upon a time"

    def __post_init__(self):
        if self.qmode:
            self._validate_qmode()

    def _validate_qmode(self) -> None:
        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
        patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]

        if self.qmode in choices:
            return

        for pattern in patterns:
            matches = re.findall(pattern, self.qmode)
            if len(matches) == 1:
                return

        raise ValueError(
            f"Got qmode {self.qmode}, but expected one of {choices}, or one of the regex patterns {patterns}."
        )
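A note on the validation above (an annotation, not part of the diff): literal choices pass, the two torchao regex forms pass, and anything else raises. A minimal sketch, assuming the module is importable under the base_module declared in targets.bzl:

```python
from executorch.examples.models.llama.config.llm_config import QuantizationConfig

QuantizationConfig(qmode="8da4w")           # one of the literal choices -> ok
QuantizationConfig(qmode="torchao:8da4w")   # matches r"torchao:8da(\d+)w" -> ok
QuantizationConfig(qmode="torchao:fpa6w")   # matches r"torchao:fpa(\d+)w" -> ok

try:
    QuantizationConfig(qmode="int4")        # neither a choice nor a pattern match
except ValueError as e:
    print(e)  # "Got qmode int4, but expected one of [...]"
```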


################################################################################
############################### BackendConfig ##################################
################################################################################


@dataclass
class XNNPackConfig:
    enabled: bool = False
    extended_ops: bool = False


class CoreMLQuantize(str, Enum):
    B4W = "b4w"
    C4W = "c4w"


class CoreMLComputeUnit(str, Enum):
    CPU_ONLY = "cpu_only"
    CPU_AND_GPU = "cpu_and_gpu"
    CPU_AND_NE = "cpu_and_ne"
    ALL = "all"


@dataclass
class CoreMLConfig:
    enabled: bool = False
    enable_state: bool = False
    preserve_sdpa: bool = False
    quantize: Optional[CoreMLQuantize] = None
    ios: int = 15
    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY

    def __post_init__(self):
        if self.ios not in (15, 16, 17, 18):
            raise ValueError(f"Invalid coreml ios version: {self.ios}")


@dataclass
class VulkanConfig:
    enabled: bool = False


@dataclass
class QNNConfig:
    enabled: bool = False
    use_sha: bool = False
    soc_model: str = "SM8650"
    use_qnn_sha: bool = False
    optimized_rotation_path: Optional[str] = None
    num_sharding: int = 0


@dataclass
class MPSConfig:
    enabled: bool = False


@dataclass
class BackendConfig:
    xnnpack: XNNPackConfig = field(default_factory=XNNPackConfig)
[Review thread on the backend configs]
Reviewer: Wdyt about making these optional configs, set to None? I think that could make it easier to tell when a backend is enabled or not, instead of looking through each backend config to check if enabled=False.
Author: Mind if I follow up with this change? I would like to keep the changes in this diff as small as possible (ie as close to the original args as possible) to make it easier to pass all the internal CI.
Reviewer: Sure, sounds good.
    coreml: CoreMLConfig = field(default_factory=CoreMLConfig)
    vulkan: VulkanConfig = field(default_factory=VulkanConfig)
    qnn: QNNConfig = field(default_factory=QNNConfig)
    mps: MPSConfig = field(default_factory=MPSConfig)


################################################################################
################################## LlmConfig ###################################
################################################################################


@dataclass
class LlmConfig:
    base: BaseConfig = field(default_factory=BaseConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    export: ExportConfig = field(default_factory=ExportConfig)
    debug: DebugConfig = field(default_factory=DebugConfig)
    quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
    backend: BackendConfig = field(default_factory=BackendConfig)
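The module docstring says these dataclasses integrate with OmegaConf and Hydra. As an annotation (not part of the diff), here is a minimal sketch of what that enables, assuming omegaconf is installed, that the module is importable under the base_module from targets.bzl, and with made-up override values:

```python
from omegaconf import OmegaConf

from executorch.examples.models.llama.config.llm_config import LlmConfig

# Defaults come straight from the dataclass definitions above.
default_cfg = OmegaConf.structured(LlmConfig)

# Dotted-key overrides, e.g. as they might be passed on a command line.
overrides = OmegaConf.from_dotlist(
    [
        "model.use_kv_cache=True",
        "export.max_seq_length=2048",
        "backend.xnnpack.enabled=True",
    ]
)

# Merging validates each override against the annotated field types.
merged = OmegaConf.merge(default_cfg, overrides)

# Convert back to a plain, typed LlmConfig instance.
llm_config: LlmConfig = OmegaConf.to_object(merged)
print(llm_config.export.max_seq_length, llm_config.backend.xnnpack.enabled)
```

Plain dataclass construction also works, e.g. LlmConfig(export=ExportConfig(max_seq_length=2048)), since every sub-config has defaults.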
@@ -0,0 +1,31 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") | ||
|
||
def define_common_targets(): | ||
runtime.python_library( | ||
name = "llm_config", | ||
srcs = [ | ||
"llm_config.py", | ||
], | ||
_is_external_target = True, | ||
base_module = "executorch.examples.models.llama.config", | ||
visibility = [ | ||
"//executorch/...", | ||
"@EXECUTORCH_CLIENTS", | ||
], | ||
) | ||
|
||
runtime.python_library( | ||
name = "llm_config_utils", | ||
srcs = [ | ||
"llm_config_utils.py", | ||
], | ||
_is_external_target = True, | ||
base_module = "executorch.examples.models.llama.config", | ||
visibility = [ | ||
"//executorch/...", | ||
"@EXECUTORCH_CLIENTS", | ||
], | ||
deps = [ | ||
":llm_config", | ||
], | ||
) |