Skip to content

Commit 4df5e15

Browse files
committed
Add search parameters
1 parent 8eec438 commit 4df5e15

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

olive/passes/onnx/model_builder.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from olive.passes import Pass
2626
from olive.passes.olive_pass import PassConfigParam
2727
from olive.passes.pass_config import BasePassConfig
28+
from olive.search.search_parameter import Boolean, Categorical
2829

2930
logger = logging.getLogger(__name__)
3031

@@ -83,11 +84,15 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
8384
"int4_block_size": PassConfigParam(
8485
type_=ModelBuilder.BlockSize,
8586
required=False,
87+
search_defaults=Categorical(
88+
[ModelBuilder.BlockSize.B32, ModelBuilder.BlockSize.B64, ModelBuilder.BlockSize.B128]
89+
),
8690
description="Specify the block_size for int4 quantization. Acceptable values: 16/32/64/128/256.",
8791
),
8892
"int4_is_symmetric": PassConfigParam(
8993
type_=bool,
9094
required=False,
95+
search_defaults=Boolean(),
9196
description="Specify whether symmetric or asymmetric INT4 quantization needs to be used.",
9297
),
9398
"int4_op_types_to_quantize": PassConfigParam(
@@ -106,6 +111,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
106111
"int4_algo_config": PassConfigParam(
107112
type_=str,
108113
required=False,
114+
search_defaults=Categorical(
115+
[
116+
"default",
117+
"rtn",
118+
"k_quant_mixed",
119+
"k_quant_last",
120+
]
121+
),
109122
description="Specify the INT4 quantization algorithm to use in GenAI Model Builder",
110123
),
111124
"use_qdq": PassConfigParam(

olive/passes/pytorch/selective_mixed_precision.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from olive.passes import Pass
2121
from olive.passes.pass_config import BasePassConfig, PassConfigParam
2222
from olive.passes.pytorch.train_utils import get_calibration_dataset, kl_div_loss, load_hf_base_model
23+
from olive.search.search_parameter import Categorical
2324

2425
if TYPE_CHECKING:
2526
from olive.hardware.accelerator import AcceleratorSpec
@@ -65,7 +66,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
6566
return {
6667
"algorithm": PassConfigParam(
6768
type_=SelectiveMixedPrecision.Algorithm,
68-
required=True,
69+
required=False,
70+
search_defaults=Categorical(
71+
[
72+
SelectiveMixedPrecision.Algorithm.K_QUANT_DOWN,
73+
SelectiveMixedPrecision.Algorithm.K_QUANT_MIXED,
74+
SelectiveMixedPrecision.Algorithm.K_QUANT_LAST,
75+
]
76+
),
6977
description="The algorithm to use for mixed precision.",
7078
),
7179
"bits": PassConfigParam(

0 commit comments

Comments
 (0)