Skip to content

Aanuf/sdpa v fp8 #3485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
488cacc
Support scale estimation inside GPTQ
alexsu52 Jun 10, 2024
ee64877
fix for INT4_ASYM
alexsu52 Sep 4, 2024
f22e411
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Sep 23, 2024
51b4d7b
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Sep 26, 2024
f66cd1e
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Sep 30, 2024
7ce5a53
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Oct 2, 2024
f74d156
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Nov 11, 2024
5288c79
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Nov 11, 2024
1becf15
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Nov 14, 2024
047d7d9
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Dec 10, 2024
c0c7e57
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Dec 16, 2024
b74dea1
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Dec 27, 2024
26a9a77
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Jan 7, 2025
25fcc2c
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Feb 25, 2025
26d4887
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Mar 12, 2025
7748233
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Apr 1, 2025
df251b3
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Apr 8, 2025
4c134c4
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Apr 9, 2025
6147097
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Apr 14, 2025
2b94d28
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr May 7, 2025
b77d1d6
Enabled quantization of V in SDPA for FP8 type.
andreyanufr May 8, 2025
e04c939
Fix.
andreyanufr May 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions nncf/common/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ def subtype_check(cls, metatype: type["OperatorMetatype"]) -> bool:

return any(subtype.subtype_check(metatype) for subtype in subtypes)

@classmethod
def get_target_input_ports(cls, is_fp8: bool = False) -> list[int] | None:
    """
    Returns the input ports of this operator that are targets for quantization.

    By default the regular `target_input_ports` list is returned. When FP8
    quantization is requested and the metatype declares a dedicated
    `target_input_ports_fp8` attribute, that list is returned instead (e.g.
    SDPA metatypes also quantize the V input, port 2, in FP8 mode).

    :param is_fp8: Whether quantization is performed in FP8 mode.
    :return: A list of target input port indices, or None if the metatype
        does not restrict the target ports.
    """
    if is_fp8:
        # Fall back to the regular ports for metatypes without an FP8 override.
        return getattr(cls, "target_input_ports_fp8", cls.target_input_ports)
    return cls.target_input_ports


class OperatorMetatypeRegistry(Registry):
"""
Expand Down
6 changes: 5 additions & 1 deletion nncf/common/quantization/quantizer_propagation/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def __init__(
post_processing_marker_metatypes: Optional[list[type[OperatorMetatype]]] = None,
metatypes_to_ignore: Optional[list[type[OperatorMetatype]]] = None,
scales_unification_map: Optional[dict[type[OperatorMetatype], list[type[OperatorMetatype]]]] = None,
is_fp8: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andreyanufr, @AlexanderDokuchaev, please provide a suggestion on how to avoid passing the is_fp8 parameter to the solver.

):
"""
Initializes the solver with parameters affecting the resulting quantizer setup.
Expand Down Expand Up @@ -386,6 +387,7 @@ def __init__(
which should be automatically ignored.
:param scales_unification_map: The framework-specific map with NNCF metatypes, which generating a quantizer
that can be unified if it so requires based on metatype.
:param is_fp8: Whether the quantization is done in FP8 mode.
"""
if default_trait_to_metatype_map is None:
self._default_trait_to_metatype_map = {}
Expand All @@ -409,6 +411,7 @@ def __init__(
self._weight_quantizable_node_names_vs_qconfigs = self._filter_by_weight_ignored_target_scopes(
quantizable_layer_nodes, weight_ignored_scopes, weight_target_scopes
)
self._is_fp8 = is_fp8

if scope_overrides is None:
self._scope_overrides: dict[str, Any] = {}
Expand Down Expand Up @@ -1147,7 +1150,8 @@ def _setup_initial_quantizers_for_operator_node(
if input_port_id in metatype.ignored_input_ports:
continue

if metatype.target_input_ports is not None and input_port_id not in metatype.target_input_ports:
target_input_ports = metatype.get_target_input_ports(self._is_fp8)
if target_input_ports is not None and input_port_id not in target_input_ports:
continue

edge = quant_prop_graph.edges[pred_ip_key, operator_node_key]
Expand Down
1 change: 1 addition & 0 deletions nncf/openvino/graph/metatypes/openvino_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,7 @@ class OVScaledDotProductAttentionMetatype(OVOpMetatype):
op_names = ["ScaledDotProductAttention"]
hw_config_names = [HWConfigOpName.SCALED_DOT_PRODUCT_ATTENTION]
target_input_ports = [0, 1]
target_input_ports_fp8 = [0, 1, 2]


@OV_OPERATOR_METATYPES.register()
Expand Down
6 changes: 5 additions & 1 deletion nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,14 @@ def _set_mode_based_params(self) -> None:
if getattr(self, self_name) is None:
setattr(self, self_name, default_value)

def _is_fp8(self) -> bool:
    """
    Returns True if the configured quantization mode is one of the FP8
    modes (E4M3 or E5M2), False otherwise.
    """
    return self._mode in (QuantizationMode.FP8_E4M3, QuantizationMode.FP8_E5M2)

def _review_mode_based_params(self):
"""
Reviews parameter values because mode option doesn't support them.
"""
if self._mode in (QuantizationMode.FP8_E4M3, QuantizationMode.FP8_E5M2):
if self._is_fp8():
nncf_logger.warning(f"You're using experimental option mode with {self._mode} value.")

if self._preset != QuantizationPreset.PERFORMANCE:
Expand Down Expand Up @@ -696,6 +699,7 @@ def _get_quantizer_setup(
metatypes_to_ignore=metatypes_to_ignore,
scales_unification_map=self._backend_entity.scales_unification_map,
scope_overrides=scope_overrides,
is_fp8=self._is_fp8(),
)

quantization_proposal = solver.run_on_ip_graph(ip_graph, self._backend_entity.elementwise_metatypes)
Expand Down
1 change: 1 addition & 0 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,7 @@ class PTScaledDotProductAttentionMetatype(PTOperatorMetatype):
}
hw_config_names = [HWConfigOpName.SCALED_DOT_PRODUCT_ATTENTION]
target_input_ports = [0, 1]
target_input_ports_fp8 = [0, 1, 2]


@PT_OPERATOR_METATYPES.register()
Expand Down
Loading