Skip to content

Aanuf/sdpa v fp8 #3485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
488cacc
Support scale estimation inside GPTQ
alexsu52 Jun 10, 2024
ee64877
fix for INT4_ASYM
alexsu52 Sep 4, 2024
f22e411
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Sep 23, 2024
51b4d7b
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Sep 26, 2024
f66cd1e
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Sep 30, 2024
7ce5a53
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Oct 2, 2024
f74d156
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Nov 11, 2024
5288c79
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Nov 11, 2024
1becf15
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Nov 14, 2024
047d7d9
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Dec 10, 2024
c0c7e57
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Dec 16, 2024
b74dea1
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Dec 27, 2024
26a9a77
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Jan 7, 2025
25fcc2c
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Feb 25, 2025
26d4887
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Mar 12, 2025
7748233
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Apr 1, 2025
df251b3
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Apr 8, 2025
4c134c4
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Apr 9, 2025
6147097
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Apr 14, 2025
2b94d28
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr May 7, 2025
b77d1d6
Enabled quantization of V in SDPA for FP8 type.
andreyanufr May 8, 2025
e04c939
Fix.
andreyanufr May 8, 2025
ebc3715
Merge remote-tracking branch 'upstream/develop' into aanuf/SDPA_V_fp8
andreyanufr May 20, 2025
e87c3a1
Fixed mypy test.
andreyanufr May 20, 2025
47a5194
Added test for fp8 SDPA.
andreyanufr May 21, 2025
4965782
Fixed for mypy.
andreyanufr May 21, 2025
08f35bb
Fixed for mypy.
andreyanufr May 21, 2025
edb8578
Apply suggestions.
andreyanufr May 27, 2025
9f198e2
rfc
alexsu52 May 28, 2025
8d8e0f5
next commit
alexsu52 May 28, 2025
6bbdfd3
target_inputs_port
alexsu52 May 28, 2025
fe7982a
1
alexsu52 May 28, 2025
05650ce
SDPA fp8 V using scope overrides.
andreyanufr May 28, 2025
6946e8e
Fixed bug.
andreyanufr May 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions nncf/common/quantization/quantizer_propagation/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,12 +1099,21 @@ def _setup_initial_quantizers_for_operator_node(
qconf_list = deepcopy(self.default_global_qconfig_list)
assert qconf_list is not None

nncf_node_name = next(
iter(quant_prop_graph.op_node_keys_to_underlying_nodes_mapping[operator_node_key])
).node_name
if not HWConfig.is_wildcard_quantization(qconf_list):
nncf_node_ref = next(iter(quant_prop_graph.op_node_keys_to_underlying_nodes_mapping[operator_node_key]))
qconf_list = self._filter_qconfigs_according_to_scope(qconf_list, nncf_node_ref.node_name)
qconf_list = self._filter_qconfigs_according_to_scope(qconf_list, nncf_node_name)
else:
qconf_list = [deepcopy(DEFAULT_QUANTIZER_CONFIG)]

op_override_params = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about introducing a separate function?

op_scope_overrides = self._scope_overrides.get("operations", {})
for overridden_scope, scoped_override_dict in op_scope_overrides.items():
if matches_any(nncf_node_name, overridden_scope):
op_override_params.update(scoped_override_dict)
target_input_ports = op_override_params.get("target_input_ports", metatype.target_input_ports)

is_unified_scale = metatype in self._unified_scales_operation_set
if is_unified_scale:
# Filtering out the per-channel cases in the unified scale scenario.
Expand Down Expand Up @@ -1147,7 +1156,7 @@ def _setup_initial_quantizers_for_operator_node(
if input_port_id in metatype.ignored_input_ports:
continue

if metatype.target_input_ports is not None and input_port_id not in metatype.target_input_ports:
if target_input_ports is not None and input_port_id not in target_input_ports:
continue

edge = quant_prop_graph.edges[pred_ip_key, operator_node_key]
Expand Down
11 changes: 9 additions & 2 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,14 @@ def _set_mode_based_params(self) -> None:
if getattr(self, self_name) is None:
setattr(self, self_name, default_value)

def _is_fp8(self):
    """Return True if the configured quantization mode is one of the FP8 variants."""
    fp8_modes = (QuantizationMode.FP8_E4M3, QuantizationMode.FP8_E5M2)
    return self._mode in fp8_modes

def _review_mode_based_params(self):
"""
Reviews parameter values because mode option doesn't support them.
"""
if self._mode in (QuantizationMode.FP8_E4M3, QuantizationMode.FP8_E5M2):
if self._is_fp8():
nncf_logger.warning(f"You're using experimental option mode with {self._mode} value.")

if self._preset != QuantizationPreset.PERFORMANCE:
Expand Down Expand Up @@ -635,10 +638,14 @@ def _get_scope_overrides(self, inference_nncf_graph: NNCFGraph) -> dict:
)
]

target_input_ports = [0, 1, 2] if self._is_fp8() else [0, 1]

scope_overrides_activations = {}
scope_overrides_operations = {}
for node_name in scaled_dot_product_attention_node_names:
scope_overrides_activations[node_name] = {"mode": "symmetric"}
return {"activations": scope_overrides_activations}
scope_overrides_operations[node_name] = {"target_input_ports": target_input_ports}
return {"activations": scope_overrides_activations, "operations": scope_overrides_operations}

def _get_quantizer_setup(
self,
Expand Down
5 changes: 4 additions & 1 deletion tests/openvino/native/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ def _create_ov_model(self, weights_dtype: Optional[ov.Type] = None, activation_d


class ScaledDotProductAttentionModel(OVReferenceModel):
def _create_ov_model(self):
def _create_ov_model(self, with_weights=False):
input_ = opset.parameter([1, 1, 1, 64], name="Input_1")
attn_mask = opset.parameter([1, 1, 1, 1], name="Input_2")
x = opset.reshape(input_, [64], False)
Expand All @@ -898,6 +898,9 @@ def _create_ov_model(self):
for _ in range(3):
x_ = opset.reshape(x, [64], False)
x_ = opset.reshape(x_, [1, 1, 1, 64], False)
if with_weights:
w_ = opset.constant(self._rng.random((64, 64)), dtype=np.float32)
x_ = opset.matmul(x_, w_, transpose_a=False, transpose_b=False)
inputs.append(x_)

attn = opset.scaled_dot_product_attention(*inputs, attn_mask)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from tests.openvino.native.models import FPModel
from tests.openvino.native.models import LinearModel
from tests.openvino.native.models import MatMul2DModel
from tests.openvino.native.models import ScaledDotProductAttentionModel
from tests.openvino.native.models import UnifiedScalesModel
from tests.openvino.native.models import WeightsModel
from tests.openvino.native.models import get_torch_model_info
Expand Down Expand Up @@ -215,3 +216,21 @@ def test_fq_precision_orig_fp32model(const_dtype, input_dtype, inplace_statistic
fq_input_node = inp_node.get_source_output().get_node()
if fq_input_node.get_type_name() == "Constant":
assert op.get_element_type() == input_dtype


@pytest.mark.parametrize(
    "mode, num_quantizers, quantizer_name",
    (
        (QuantizationMode.FP8_E4M3, 7, "FakeConvert"),  # 3 for weights + 1 activation + 3 for SDPA
        (QuantizationMode.FP8_E5M2, 7, "FakeConvert"),  # 3 for weights + 1 activation + 3 for SDPA
        (None, 6, "FakeQuantize"),  # 3 for weights + 1 activation + 2 for SDPA
    ),
)
@pytest.mark.parametrize("model_creator_func", [ScaledDotProductAttentionModel])
def test_sdpa_layer(mode, num_quantizers, quantizer_name, model_creator_func):
    """Quantize an SDPA model (with weighted Q/K/V branches) in the given mode
    and check that the expected number of quantizer nodes is inserted."""
    reference_model = model_creator_func(with_weights=True)
    quantized = quantize_model(reference_model.ov_model, {"mode": mode})

    fq_stats = get_fq_nodes_stats_algo(quantized)

    assert len(fq_stats) == num_quantizers, f"Expected {num_quantizers} {quantizer_name}, but got {len(fq_stats)}"
Loading