
Commit 9a0546b

Merge pull request #3950 from wangzhaode/feature/sync_325
MNN:Sync: Sync Internal 3.2.5

2 parents 0138017 + 6064746


41 files changed: +3079 −1260 lines

docs/index.rst (1 addition, 0 deletions)

@@ -81,6 +81,7 @@
 tools/test
 tools/benchmark
 tools/compress
+tools/mnncompress
 tools/visual
 tools/python

docs/tools/compress.md (363 additions, 934 deletions)

Large diffs are not rendered by default.

docs/tools/mnncompress.md (837 additions, 0 deletions)

Large diffs are not rendered by default.

express/MathOp.cpp (4 additions, 4 deletions)

@@ -1140,7 +1140,7 @@ VARP _ScatterNd(VARP indices, VARP updates, VARP shape, int reducetion) {
     op->main.type = OpParameter_BinaryOp;
     op->type = OpType_ScatterNd;
     auto param = new BinaryOpT;
-    param->opType = reducetion;
+    param->opType = (BinaryOpOperation)reducetion;
     op->main.value = param;
     return (Variable::create(Expr::create(std::move(op), {indices, updates, shape})));
 }
@@ -1150,7 +1150,7 @@ VARP _ScatterNd(VARP indices, VARP updates, VARP shape, VARP input, int reduceti
     op->main.type = OpParameter_BinaryOp;
     op->type = OpType_ScatterNd;
     auto param = new BinaryOpT;
-    param->opType = reducetion;
+    param->opType = (BinaryOpOperation)reducetion;
     op->main.value = param;
     return (Variable::create(Expr::create(std::move(op), {indices, updates, shape, input})));
 }
@@ -1167,7 +1167,7 @@ VARP _ScatterElements(VARP data, VARP indices, VARP updates, int reducetion) {
     op->main.type = OpParameter_BinaryOp;
     op->type = OpType_ScatterElements;
     auto param = new BinaryOpT;
-    param->opType = reducetion;
+    param->opType = (BinaryOpOperation)reducetion;
     op->main.value = param;
     return (Variable::create(Expr::create(std::move(op), {data, indices, updates})));
 }
@@ -1177,7 +1177,7 @@ VARP _ScatterElements(VARP data, VARP indices, VARP updates, VARP axis, int redu
     op->main.type = OpParameter_BinaryOp;
     op->type = OpType_ScatterElements;
     auto param = new BinaryOpT;
-    param->opType = reducetion;
+    param->opType = (BinaryOpOperation)reducetion;
     op->main.value = param;
     return (Variable::create(Expr::create(std::move(op), {data, indices, updates, axis})));
 }
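
All four edits are the same fix: BinaryOpT::opType is now typed as the BinaryOpOperation enum rather than int32_t (see the schema changes below), and C++ does not implicitly convert int to an enum, so the raw reducetion parameter needs an explicit cast. A minimal sketch of the type error being fixed (stand-in types, abridged from the generated header; illustrative only):

    // Abbreviated stand-ins for the generated types; the real enum defines many more values.
    enum BinaryOpOperation : int { BinaryOpOperation_ADD = 0, BinaryOpOperation_SUB = 1, BinaryOpOperation_MUL = 2 };
    struct BinaryOpT { BinaryOpOperation opType = BinaryOpOperation_ADD; };

    void setReduction(BinaryOpT* param, int reducetion) {
        // param->opType = reducetion;                  // ill-formed: no implicit int -> enum conversion
        param->opType = (BinaryOpOperation)reducetion;  // explicit cast, as in the diff
    }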

include/MNN/MNNDefine.h (1 addition, 1 deletion)

@@ -78,6 +78,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
 #define STR(x) STR_IMP(x)
 #define MNN_VERSION_MAJOR 3
 #define MNN_VERSION_MINOR 2
-#define MNN_VERSION_PATCH 4
+#define MNN_VERSION_PATCH 5
 #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
 #endif /* MNNDefine_h */
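
The only functional change is the patch bump: this sync releases MNN 3.2.5. For context, MNN_VERSION is assembled with the usual two-step stringification idiom; STR_IMP's body sits outside this hunk, but assuming the conventional stringize definition, the expansion goes:

    #define STR_IMP(x) #x   // assumed definition: the standard stringize macro
    #define STR(x) STR_IMP(x)
    // STR(MNN_VERSION_MAJOR) first expands the argument to 3, then stringizes it to "3",
    // so MNN_VERSION becomes "3" "." "2" "." "5", which string-literal concatenation
    // folds into "3.2.5".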

schema/current/TensorflowOp_generated.h (37 additions, 36 deletions)

@@ -592,11 +592,11 @@ inline const char *EnumNamePadValueMode(PadValueMode e) {

 struct BinaryOpT : public flatbuffers::NativeTable {
   typedef BinaryOp TableType;
-  int32_t opType;
+  BinaryOpOperation opType;
   DataType T;
   int32_t activationType;
   BinaryOpT()
-      : opType(0),
+      : opType(BinaryOpOperation_ADD),
         T(DataType_DT_FLOAT),
         activationType(0) {
   }
@@ -607,8 +607,8 @@ struct BinaryOp FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   static const flatbuffers::TypeTable *MiniReflectTypeTable() {
     return BinaryOpTypeTable();
   }
-  int32_t opType() const {
-    return GetField<int32_t>(4, 0);
+  BinaryOpOperation opType() const {
+    return static_cast<BinaryOpOperation>(GetField<int32_t>(4, 0));
   }
   DataType T() const {
     return static_cast<DataType>(GetField<int32_t>(6, 1));
@@ -631,8 +631,8 @@ struct BinaryOp FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
 struct BinaryOpBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
-  void add_opType(int32_t opType) {
-    fbb_.AddElement<int32_t>(4, opType, 0);
+  void add_opType(BinaryOpOperation opType) {
+    fbb_.AddElement<int32_t>(4, static_cast<int32_t>(opType), 0);
   }
   void add_T(DataType T) {
     fbb_.AddElement<int32_t>(6, static_cast<int32_t>(T), 1);
@@ -654,7 +654,7 @@ struct BinaryOpBuilder {

 inline flatbuffers::Offset<BinaryOp> CreateBinaryOp(
     flatbuffers::FlatBufferBuilder &_fbb,
-    int32_t opType = 0,
+    BinaryOpOperation opType = BinaryOpOperation_ADD,
     DataType T = DataType_DT_FLOAT,
     int32_t activationType = 0) {
   BinaryOpBuilder builder_(_fbb);
@@ -4930,34 +4930,34 @@ inline flatbuffers::Offset<LSTMBlockCell> CreateLSTMBlockCell(flatbuffers::FlatB

 inline const flatbuffers::TypeTable *BinaryOpOperationTypeTable() {
   static const flatbuffers::TypeCode type_codes[] = {
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 },
-    { flatbuffers::ET_CHAR, 0, 0 }
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 0 }
   };
   static const flatbuffers::TypeFunction type_refs[] = {
     BinaryOpOperationTypeTable
@@ -5175,11 +5175,12 @@ inline const flatbuffers::TypeTable *PadValueModeTypeTable() {

 inline const flatbuffers::TypeTable *BinaryOpTypeTable() {
   static const flatbuffers::TypeCode type_codes[] = {
-    { flatbuffers::ET_INT, 0, -1 },
     { flatbuffers::ET_INT, 0, 0 },
+    { flatbuffers::ET_INT, 0, 1 },
     { flatbuffers::ET_INT, 0, -1 }
   };
   static const flatbuffers::TypeFunction type_refs[] = {
+    BinaryOpOperationTypeTable,
     DataTypeTypeTable
   };
   static const char * const names[] = {
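
The regenerated header keeps the wire format intact: opType is still read and written as an int32 at vtable slot 4, with a static_cast layered on top, so previously serialized models load unchanged; only the C++ surface becomes typed. A caller's-eye sketch of the typed builder API (illustrative usage, not part of the diff):

    #include "TensorflowOp_generated.h"  // the regenerated header above
    using namespace MNN;

    void buildBinaryOp(flatbuffers::FlatBufferBuilder& fbb) {
        // A named enum value is now expected where a bare int32_t was accepted before:
        auto op = CreateBinaryOp(fbb, BinaryOpOperation_MUL, DataType_DT_FLOAT, /*activationType=*/0);
        fbb.Finish(op);
    }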

schema/default/TensorflowOp.fbs (2 additions, 2 deletions)

@@ -1,7 +1,7 @@
 include "Tensor.fbs";
 namespace MNN;

-enum BinaryOpOperation : byte {
+enum BinaryOpOperation : int {
     ADD = 0,
     SUB = 1,
     MUL = 2,
@@ -33,7 +33,7 @@ enum BinaryOpOperation : byte {
 }

 table BinaryOp {
-    opType:int;
+    opType:BinaryOpOperation;
     T:DataType=DT_FLOAT;
     // 0 -> No Activation
     // 1 -> Relu
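
Widening the enum from byte to int is what makes retyping the field safe: a flatbuffers field's storage width follows its declared type, and the field was previously opType:int, so an int-backed enum keeps the same 4-byte slot (hence the unchanged GetField<int32_t>(4, 0) in the generated header), whereas a byte-backed enum would have changed the field's width, which is not a compatible schema evolution. A tiny sketch of the invariant (abridged enum; illustrative):

    #include <cstdint>

    // Mirrors the generated C++ enum after the schema change (abridged):
    enum BinaryOpOperation : int32_t { BinaryOpOperation_ADD = 0, BinaryOpOperation_SUB = 1 };

    // The enum's width matches the old opType:int slot exactly:
    static_assert(sizeof(BinaryOpOperation) == sizeof(int32_t), "opType slot width unchanged");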

source/backend/cpu/KVCacheManager.cpp (48 additions, 4 deletions)

@@ -264,7 +264,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
     } else if (mConfig.mQuantKey) {
         old_key.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(oldMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
     } else {
-        old_key.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(oldMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
+        old_key.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(oldMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
     }
     if (mConfig.mQuantValue) {
         old_value.reset(Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(oldMaxLength, lP), hP, lP}));
(whitespace-only change)

@@ -387,7 +387,7 @@ void KVCacheManager::onAlloc(int kv_seq_len) {
     } else {
         mPastValue.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP}));
     }
-    mBackend->onAcquireBuffer(mPastKey.get(), Backend::STATIC);
+    mBackend->onAcquireBuffer(mPastKey.get(), Backend::STATIC);
     mBackend->onAcquireBuffer(mPastValue.get(), Backend::STATIC);
     if (mHeadDim % lP) {
         memset(mPastKey->host<int8_t>(), 0, mPastKey->length(0) * mPastKey->stride(0) * mBytes);
(whitespace-only change)

@@ -486,6 +486,21 @@ void KVCacheManager::onRealloc(const KVMeta* meta) {
         mPastLength = start;
         return;
     }
+#if 1
+    auto dstIndex = start;
+    for (int n = 0; n < meta->n_reserve; ++n) {
+        auto begin = meta->reserve[2 * n];
+        auto size = meta->reserve[2 * n + 1];
+        auto srcIndex = start + begin;
+        if (mBytes == 2) {
+            moveKV<FLOAT16_T>(srcIndex, dstIndex, size);
+        } else {
+            moveKV<float>(srcIndex, dstIndex, size);
+        }
+        dstIndex += size;
+    }
+    mPastLength = dstIndex;
+#else
     // Don't support not align reserve
     auto align = hP;
     auto dstStart = start;
@@ -503,7 +518,7 @@
     }
     auto end = begin + start + size;
     auto endAlign = UP_DIV(end, align) * align;
-
+
     auto sizeUnit = (endAlign - startAlign) / align;
     auto dstStartAlign = UP_DIV(dstStart, align) * align;
(whitespace-only change)

@@ -539,6 +554,7 @@ void KVCacheManager::onRealloc(const KVMeta* meta) {
         lastValidSrcEnd = begin + start + size;
     }
     mPastLength = dstStart;
+#endif
 }

 void KVCacheManager::onClear() {
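
The onRealloc change above swaps compaction strategies: the #else branch is the old path, which worked on whole hP-aligned blocks and, per its own comment, didn't support unaligned reserve ranges; the new path walks each reserved range token by token through moveKV, so any [begin, begin + size) range works. Because each destination slot is at or below its source (dstIndex starts at start and begin is non-negative), the in-place copy never overwrites a token it still needs, assuming the reserve ranges are sorted, which the Metal implementation's TODO also calls out. A toy model of the compaction (standalone, illustrative values; printing stands in for moveKV):

    #include <cstdio>
    #include <vector>

    int main() {
        int start = 4;                            // tokens [0, start) are already kept
        std::vector<int> reserve = {2, 3, 8, 1};  // (begin, size) pairs, relative to start
        int dstIndex = start;
        for (size_t n = 0; n < reserve.size() / 2; ++n) {
            int begin = reserve[2 * n], size = reserve[2 * n + 1];
            int srcIndex = start + begin;
            for (int i = 0; i < size; ++i) {      // moveKV(srcIndex, dstIndex, size)
                std::printf("token %d -> slot %d\n", srcIndex + i, dstIndex + i);
            }
            dstIndex += size;
        }
        std::printf("new past length: %d\n", dstIndex);  // 8 for these values
        return 0;
    }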
@@ -551,7 +567,7 @@ void KVCacheManager::onClear() {
     } else {
         keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
     }
-    valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes);
+    valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes);
     unmapKVCache(keySize, valueSize);
     removeKVCacheFile();
     mKVCacheInDisk = false;
(whitespace-only change)

@@ -663,6 +679,34 @@ void KVCacheManager::pack_value(const Tensor* value, int seq_len, int kv_h) { //
     }
 }

+size_t KVCacheManager::keyIndex(int seq, int dim) const {
+    return (seq / hP) * ROUND_UP(mHeadDim, lP) * hP +
+           (dim / lP) * hP * lP +
+           (seq % hP) * lP +
+           (dim % lP);
+}
+
+size_t KVCacheManager::valueIndex(int seq, int dim) const {
+    return (dim / hP) * ROUND_UP(mMaxLength, lP) * hP +
+           (seq / lP) * hP * lP +
+           (dim % hP) * lP +
+           (seq % lP);
+}
+
+template <typename T>
+void KVCacheManager::moveKV(int src, int dst, int size) {
+    for (int h = 0; h < mKvNumHead; ++h) {
+        auto kPtr = reinterpret_cast<T*>(addrOfKey(h));
+        auto vPtr = reinterpret_cast<T*>(addrOfValue(h));
+        for (int i = 0; i < size; i++) {
+            for (int j = 0; j < mHeadDim; j++) {
+                kPtr[keyIndex(dst + i, j)] = kPtr[keyIndex(src + i, j)];
+                vPtr[valueIndex(dst + i, j)] = vPtr[valueIndex(src + i, j)];
+            }
+        }
+    }
+}
+
 void KVCacheManager::onPushBack(const Tensor * key, const Tensor * value, int add) {
     auto core = static_cast<CPUBackend*>(mBackend)->functions();
     int seq_len = add;
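
The two index helpers encode the tiled layouts: per head, keys are packed as [maxlen/hP][headdim/lP][hP][lP] and values as [headdim/hP][maxlen/lP][hP][lP], so a (seq, dim) element has to be located tile by tile rather than with a flat stride. A standalone sketch of the key arithmetic with toy tile sizes (hP, lP, headDim here are illustrative; the real values come from the CPU backend's packing functions):

    #include <cstdio>

    constexpr int hP = 4, lP = 4, headDim = 8;  // headDim is a multiple of lP here,
                                                // so ROUND_UP(headDim, lP) == headDim

    // Same arithmetic as KVCacheManager::keyIndex for [maxlen/hP][headDim/lP][hP][lP]:
    size_t keyIndex(int seq, int dim) {
        return (size_t)(seq / hP) * headDim * hP   // which hP-row tile block
             + (size_t)(dim / lP) * hP * lP        // which lP-column group inside it
             + (seq % hP) * lP + (dim % lP);       // position within the hP x lP tile
    }

    int main() {
        // One moveKV step: token 5's key row is copied into slot 2, element by element.
        for (int j = 0; j < headDim; ++j) {
            std::printf("k[%zu] = k[%zu]\n", keyIndex(2, j), keyIndex(5, j));
        }
        return 0;
    }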

source/backend/cpu/KVCacheManager.hpp (5 additions, 2 deletions)

@@ -39,7 +39,7 @@ class KVCacheManager : public NonCopyable{
 private:
     Backend * mBackend;
     KVCacheConfig mConfig;
-    std::shared_ptr<Tensor> mPastKey;    // {numhead, [maxlen/hP, headdim, hP]} or {numhead, [maxlen/hP8, headdim/lP8, hP8, lP8]}
+    std::shared_ptr<Tensor> mPastKey;    // {numhead, [maxlen/hP, headdim, hP]} or {numhead, [maxlen/hP8, headdim/lP8, hP8, lP8]}
     std::shared_ptr<Tensor> mPastValue;  // numhead, [headdim/hP, maxlen, hP]
     std::shared_ptr<Tensor> mKeyScale;   // {numhead, [maxlen/hP, hP]} or {numhead, [maxlen/hP8, hP8]}
    std::shared_ptr<Tensor> mKeyZeroPoint; // {numhead, [maxlen/hP, hP]} or {numhead, [maxlen/hP8, hP8]}
(whitespace-only change)

@@ -65,10 +65,13 @@ class KVCacheManager : public NonCopyable{
     void expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int oldValueSize, int keySize, int valueSize);
     template <typename T> void pack_key(const Tensor* key, int seq_len, int kv_h);
     template <typename T> void pack_value(const Tensor* value, int seq_len, int kv_h);
+    template <typename T> void moveKV(int src, int dst, int size);
+    size_t keyIndex(int seq, int dim) const;
+    size_t valueIndex(int seq, int dim) const;
 public:
     KVCacheManager(Backend * backend, KVCacheConfig & kvConfig) {
         mBackend = backend;
-        mConfig = kvConfig;
+        mConfig = kvConfig;
     }
     ~KVCacheManager() {
         onClear();
(the mConfig pair is a whitespace-only change)

source/backend/metal/MetalAttention.mm (2 additions, 1 deletion)

@@ -176,14 +176,15 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override {
     auto valueBuf = MetalBackend::getBuffer(mCache->mPastValue.get());
     auto value_ptr = (uint8_t*)[valueBuf.first contents] + valueBuf.second;

+    auto src_start = start;
     // TODO: need to ensure reserve info is sorted
     for (int n = 0; n < mMeta->n_reserve; ++n) {
         auto begin = mMeta->reserve[2 * n];
         auto length = mMeta->reserve[2 * n + 1];
         // past_key : [mCache->mPastLength, mKvNumHead, mHeadDim]
         // past_value : [mKvNumHead, mHeadDim, mCache->mMaxLength]

-        auto copy_src_index = start + begin;
+        auto copy_src_index = src_start + begin;
         auto copy_dst_index = start;
         for(int i = 0; i < length; i++) {
             ::memcpy(key_ptr + (copy_dst_index + i) * mKvNumHead * mHeadDim * byte, key_ptr + (copy_src_index + i) * mKvNumHead * mHeadDim * byte, mKvNumHead * mHeadDim * byte);
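
The fix snapshots the source base before the loop. The hunk doesn't show it, but copy_dst_index = start for every range only makes sense if start advances past each copied range further down; assuming that advance, reusing start as the source base would make every range after the first read from a shifted offset. A toy sketch of the drift (the start += length advance is an assumption about code outside the hunk):

    int start = 4;                     // compacted length so far; also the write cursor
    int src_start = start;             // the fix: pin the source base once
    int reserve[] = {2, 3, 8, 1};      // (begin, length) pairs, relative to the base
    for (int n = 0; n < 2; ++n) {
        int begin = reserve[2 * n], length = reserve[2 * n + 1];
        int copy_src_index = src_start + begin; // old code used start + begin, which for
                                                // n = 1 reads from 15 instead of 12
        int copy_dst_index = start;
        // ... memcpy `length` rows from copy_src_index to copy_dst_index ...
        start += length;                        // assumed advance (outside the hunk)
    }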
