Skip to content

Commit 5ed7d41

Browse files
authored
[GPU] Enable INT4 KV-cache compress (#33109)
### Description of the issue - Enable 4-bit KV-cache compression for peak memory reduction - Passes the WWB accuracy check, with no performance regression from this feature #### Reproduction step and snapshot - Run the WWB test with an ov-config that enables the 4-bit KV-cache: `--ov-config enable_u4.json` - Sample JSON file: `{ "KV_CACHE_PRECISION": "u4", "KEY_CACHE_QUANT_MODE": "BY_CHANNEL" }` - Execute WWB: `python wwb.py --genai --ov-config enable_u4.json --target-model minicpm4-0.5b/pytorch/ov/OV_FP16-INT8_ASYM --device GPU --gt-data gt-data/minicpm4-0.5b__NAT/reference.csv` #### Checklist - [x] Is it a proper fix? - [x] Did you include test case for this fix, if necessary? - [x] Did you review existing test that can be extended to cover this scenario? ### AI Assistance: - AI assistance used: yes. Summarized the validation report, analyzed failed unit-test cases, and generated a proper solution. ### Tickets: - CVS-169489, CVS-180645 --------- Signed-off-by: Min, Byungil <byungil.min@intel.com>
1 parent 5884aeb commit 5ed7d41

File tree

11 files changed

+1024
-170
lines changed

11 files changed

+1024
-170
lines changed

src/common/transformations/include/transformations/common_optimizations/convert_pagedattn_inputs.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class TRANSFORMATIONS_API ConvertPagedAttnInputs;
2020
class ConvertPagedAttnInputs : public ov::pass::MatcherPass {
2121
public:
2222
using UpdateShapeFunc = std::function<void(const ov::element::Type, const bool, const size_t, int64_t&, int64_t&)>;
23+
using UpdatePrecisionFunc = std::function<void(ov::element::Type&)>;
2324

2425
struct KVCacheConfig {
2526
ov::element::Type keyCachePrecision;
@@ -36,7 +37,9 @@ class ConvertPagedAttnInputs : public ov::pass::MatcherPass {
3637
};
3738

3839
OPENVINO_MATCHER_PASS_RTTI("ConvertPagedAttnInputs");
39-
ConvertPagedAttnInputs(const KVCacheConfig& config, UpdateShapeFunc update_shape_func);
40+
ConvertPagedAttnInputs(const KVCacheConfig& config,
41+
UpdateShapeFunc update_shape_func,
42+
UpdatePrecisionFunc update_precision_func = nullptr);
4043

4144
void setKVCacheConfig(const KVCacheConfig& config);
4245

@@ -45,6 +48,7 @@ class ConvertPagedAttnInputs : public ov::pass::MatcherPass {
4548
private:
4649
KVCacheConfig m_config;
4750
UpdateShapeFunc m_update_shape_func;
51+
UpdatePrecisionFunc m_update_precision_func;
4852
};
4953

5054
} // namespace pass

src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@ namespace v0 = ov::op::v0;
2222

2323
namespace ov::pass {
2424

25-
ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config, UpdateShapeFunc func)
25+
ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config,
26+
UpdateShapeFunc func,
27+
UpdatePrecisionFunc update_precision_func)
2628
: m_config(config),
27-
m_update_shape_func(std::move(func)) {
29+
m_update_shape_func(std::move(func)),
30+
m_update_precision_func(std::move(update_precision_func)) {
2831
MATCHER_SCOPE(ConvertPagedAttnInputs);
2932

3033
auto Q = pattern::any_input(pattern::has_static_rank());
@@ -87,6 +90,7 @@ ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config, Upda
8790
return cache_precision == ov::element::f16 && infer_precision == ov::element::bf16 ? infer_precision
8891
: cache_precision;
8992
};
93+
9094
auto init_cache_shape = [&](const size_t head_nums,
9195
const size_t head_size,
9296
const size_t block_size,
@@ -105,6 +109,7 @@ ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config, Upda
105109
}
106110
}
107111
size_t group_num = _head_size / _group_size;
112+
// Update head_size and block_size by precision and quantizing channel mode
108113
m_update_shape_func(precision, bychannel, group_num, _head_size, _block_size);
109114

110115
auto block_shape = ov::PartialShape::dynamic(4);
@@ -147,6 +152,13 @@ ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config, Upda
147152
status = false;
148153
}
149154

155+
if (m_update_precision_func) {
156+
m_update_precision_func(key_cache_precision);
157+
m_update_precision_func(value_cache_precision);
158+
key_cache->set_element_type(key_cache_precision);
159+
value_cache->set_element_type(value_cache_precision);
160+
}
161+
150162
key_cache->validate_and_infer_types();
151163
value_cache->validate_and_infer_types();
152164
return status;

src/common/transformations/tests/common_optimizations/convert_pagedattn_inputs.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ std::vector<std::vector<ov::element::Type>> get_cache_prec() {
351351
{ov::element::f16, ov::element::f16},
352352
{ov::element::u8, ov::element::u8},
353353
{ov::element::u8, ov::element::u4},
354+
{ov::element::u4, ov::element::u4},
354355
};
355356
}
356357

0 commit comments

Comments
 (0)