Skip to content

Commit fedeb72

Browse files
committed
Add search parameters
1 parent 8e3a769 commit fedeb72

File tree

2 files changed, with 18 additions and 1 deletion.

olive/passes/onnx/model_builder.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
  25   25   from olive.passes import Pass
  26   26   from olive.passes.olive_pass import PassConfigParam
  27   27   from olive.passes.pass_config import BasePassConfig
       28 + from olive.search.search_parameter import Boolean, Categorical
  28   29
  29   30   logger = logging.getLogger(__name__)
  30   31
@@ -83,11 +84,16 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
  83   84   "int4_block_size": PassConfigParam(
  84   85       type_=ModelBuilder.BlockSize,
  85   86       required=False,
       87 +     search_defaults=Categorical([
       88 +         ModelBuilder.BlockSize.B32,
       89 +         ModelBuilder.BlockSize.B64,
       90 +         ModelBuilder.BlockSize.B128]),
  86   91       description="Specify the block_size for int4 quantization. Acceptable values: 16/32/64/128/256.",
  87   92   ),
  88   93   "int4_is_symmetric": PassConfigParam(
  89   94       type_=bool,
  90   95       required=False,
       96 +     search_defaults=Boolean(),
  91   97       description="Specify whether symmetric or asymmetric INT4 quantization needs to be used.",
  92   98   ),
  93   99   "int4_op_types_to_quantize": PassConfigParam(
@@ -106,6 +112,12 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
 106  112   "int4_algo_config": PassConfigParam(
 107  113       type_=str,
 108  114       required=False,
      115 +     search_defaults=Categorical([
      116 +         "default",
      117 +         "rtn",
      118 +         "k_quant_mixed",
      119 +         "k_quant_last",
      120 +     ]),
 109  121       description="Specify the INT4 quantization algorithm to use in GenAI Model Builder",
 110  122   ),
 111  123   "use_qdq": PassConfigParam(

olive/passes/pytorch/selective_mixed_precision.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
  20   20   from olive.passes import Pass
  21   21   from olive.passes.pass_config import BasePassConfig, PassConfigParam
  22   22   from olive.passes.pytorch.train_utils import get_calibration_dataset, kl_div_loss, load_hf_base_model
       23 + from olive.search.search_parameter import Boolean, Categorical
  23   24
  24   25   if TYPE_CHECKING:
  25   26       from olive.hardware.accelerator import AcceleratorSpec
@@ -65,7 +66,11 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
  65   66   return {
  66   67       "algorithm": PassConfigParam(
  67   68           type_=SelectiveMixedPrecision.Algorithm,
  68      -         required=True,
       69 +         required=False,
       70 +         search_defaults=Categorical([
       71 +             SelectiveMixedPrecision.Algorithm.K_QUANT_DOWN,
       72 +             SelectiveMixedPrecision.Algorithm.K_QUANT_MIXED,
       73 +             SelectiveMixedPrecision.Algorithm.K_QUANT_LAST]),
  69   74           description="The algorithm to use for mixed precision.",
  70   75       ),
  71   76       "bits": PassConfigParam(

0 commit comments

Comments
 (0)