forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcausal_mask_preprocess.cpp
More file actions
183 lines (158 loc) · 7.02 KB
/
causal_mask_preprocess.cpp
File metadata and controls
183 lines (158 loc) · 7.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
// Copyright (C) 2018-2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "causal_mask_preprocess.h"
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <oneapi/dnnl/dnnl_common.hpp>
#include <string>
#include <vector>
#include "cpu_parallel.hpp"
#include "cpu_types.h"
#include "graph_context.h"
#include "memory_desc/cpu_memory_desc.h"
#include "node.h"
#include "onednn/iml_type_mapper.h"
#include "openvino/core/except.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/type.hpp"
#include "openvino/core/type/bfloat16.hpp"
#include "openvino/core/type/element_type.hpp"
#include "shape_inference/shape_inference_internal_dyn.hpp"
#include "transformations/cpu_opset/common/op/causal_mask_preprocess.hpp"
#include "utils/debug_capabilities.h"
#include "utils/plain_tensor.hpp"
namespace ov::intel_cpu::node {
/*
CausalMaskPreprocess:
inputs:
    0: attention_mask : i64[N, kv_len] in the model; this node requests it as i32
                        0 means mask-out, 1 means attends to
    1: batch_size (size_Gather) : i32[1]
    2: cache_positions : i32[q_len]
    3: kvLen : i32[1]
outputs:
    0: causal mask for SDPA : f32 (or bf16) [batch_size, 1, q_len, kvLen]
The functionality is equivalent to the following Python code:
##### preprocess
min_dtype = torch.finfo(dtype).min
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
causal_mask = causal_mask.to(dtype=dtype, device=device)
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
##### when being used will be further sliced
causal_mask = attention_mask
if attention_mask is not None and cache_position is not None:
causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
*/
/**
 * Executor that writes the additive causal mask into output port 0.
 *
 * Each output element is either 0 (position may be attended to) or the lowest
 * finite value of T (position is masked out).
 *
 * @tparam T element type of the output mask; this file instantiates it with
 *           float and ov::bfloat16.
 */
template <typename T>
struct CausalMaskPreprocess::ExecutorCausalMaskPreprocess : public CausalMaskPreprocess::Executor {
    /**
     * Computes the mask of shape [batch_size, 1, qLen, kvLen].
     *
     * @param strm         unused.
     * @param pnode        node whose source ports 0..3 are read and whose output port 0 is
     *                     redefined and written.
     * @param cpu_parallel parallel runner used to iterate over (batch, query row).
     * @param config       only referenced by the debug log.
     */
    void execute([[maybe_unused]] const dnnl::stream& strm,
                 intel_cpu::Node* pnode,
                 const CpuParallelPtr& cpu_parallel,
                 [[maybe_unused]] const intel_cpu::CausalMaskPreprocessNode::Config& config) override {
        ov::intel_cpu::PlainTensor t_attention_mask(pnode->getSrcMemoryAtPort(0));
        ov::intel_cpu::PlainTensor t_batch_size(pnode->getSrcMemoryAtPort(1));
        ov::intel_cpu::PlainTensor t_cache_positions(pnode->getSrcMemoryAtPort(2));
        ov::intel_cpu::PlainTensor t_kvLen(pnode->getSrcMemoryAtPort(3));

        const auto mask_length = static_cast<size_t>(t_attention_mask.size(-1));
        const auto batch_size = static_cast<size_t>(*t_batch_size.ptr<int32_t>(0));
        const auto kvLen = static_cast<size_t>(*t_kvLen.ptr<int32_t>(0));
        const auto qLen = t_cache_positions.size(0);

        // The output dims depend on input *values*, so the output memory is redefined
        // here before it is wrapped.
        VectorDims newDims{batch_size, 1, qLen, kvLen};
        pnode->redefineOutputMemory({newDims});
        ov::intel_cpu::PlainTensor t_dst(pnode->getDstMemoryAtPort(0));

        DEBUG_LOG("CausalMaskPreprocess::execute",
                  config.type,
                  "  batch_size=",
                  batch_size,
                  "  qLen=",
                  qLen,
                  "  kvLen=",
                  kvLen);
        DEBUG_LOG("CausalMaskPreprocess::execute attention_mask=", t_attention_mask);
        DEBUG_LOG("CausalMaskPreprocess::execute cache_positions=", t_cache_positions);

        // raw_causal_mask is already ensured to be triu by transformation
        const auto* prow = t_cache_positions.ptr<int32_t>(0);
        const T min_dtype = std::numeric_limits<T>::lowest();

        // An output row holds exactly kvLen elements. The attention mask may be longer
        // than that, so the attention-mask pass is clamped to avoid writing past the row.
        const size_t masked_len = mask_length < kvLen ? mask_length : kvLen;

        cpu_parallel->parallel_for2d(batch_size, qLen, [&](size_t n, size_t i) {
            const auto* pamask = t_attention_mask.ptr<int32_t>(n, 0);
            auto* pdst = t_dst.ptr<T>(n, 0, i);
            const auto row = static_cast<size_t>(prow[i]);

            size_t j = 0;
            // Columns covered by attention_mask: masked when the key lies beyond the
            // query's cache position (j > row) or attention_mask marks it as 0.
            for (; j < masked_len; j++) {
                const bool masked = (j > row) || (pamask[j] == 0);
                pdst[j] = masked ? min_dtype : static_cast<T>(0);
            }
            // Remaining columns (if any) have no attention_mask entry: causal rule only.
            for (; j < kvLen; j++) {
                pdst[j] = (j <= row) ? static_cast<T>(0) : min_dtype;
            }
        });

        DEBUG_LOG("CausalMaskPreprocess::execute dst=", t_dst);
    }
};
/**
 * Creates the CPU node from the given op and copies the op's config into m_config.
 *
 * Raises via OPENVINO_THROW_NOT_IMPLEMENTED (with the message produced by
 * isSupportedOperation) when the op is rejected.
 */
CausalMaskPreprocess::CausalMaskPreprocess(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context)
    : Node(op, context, InternalDynShapeInferFactory()) {
    if (std::string reason; !isSupportedOperation(op, reason)) {
        OPENVINO_THROW_NOT_IMPLEMENTED(reason);
    }

    // isSupportedOperation() accepted the op, so this cast is the one it already verified.
    m_config = ov::as_type_ptr<const intel_cpu::CausalMaskPreprocessNode>(op)->get_config();
}
/**
 * Reports whether `op` is a CausalMaskPreprocessNode.
 *
 * @param op           operation to inspect.
 * @param errorMessage set to an explanation when the op is of a different type;
 *                     left untouched if an exception is caught.
 * @return true when the op can be cast to CausalMaskPreprocessNode, false otherwise
 *         (including when any exception is thrown during the check).
 */
bool CausalMaskPreprocess::isSupportedOperation(const std::shared_ptr<const ov::Node>& op,
                                                std::string& errorMessage) noexcept {
    try {
        if (ov::as_type_ptr<const intel_cpu::CausalMaskPreprocessNode>(op)) {
            return true;
        }
        errorMessage = "Only CausalMaskPreprocessNode operation is supported";
        return false;
    } catch (...) {
        return false;
    }
}
void CausalMaskPreprocess::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty()) {
return;
}
std::vector<ov::element::Type> iprecs = getOriginalInputPrecisions();
std::vector<ov::element::Type> oprecs = getOriginalOutputPrecisions();
// precision preferences
if (m_config.type == "CausalMaskPreprocess") {
if (oprecs[0] == ov::element::bf16) {
m_executor = std::make_shared<ExecutorCausalMaskPreprocess<ov::bfloat16>>();
} else {
// fallback to default precision
m_executor = std::make_shared<ExecutorCausalMaskPreprocess<float>>();
oprecs[0] = ov::element::f32;
}
// all input precisions must be int32
for (auto& prec : iprecs) {
prec = ov::element::i32;
}
} else {
CPU_NODE_THROW("type not supported : " + m_config.type);
}
std::vector<PortConfigurator> inPortConfigs;
for (size_t i = 0; i < getOriginalInputsNumber(); i++) {
inPortConfigs.emplace_back(LayoutType::ncsp, iprecs[i], getInputShapeAtPort(i), false, -1);
}
std::vector<PortConfigurator> outPortConfigs;
for (size_t i = 0; i < getOriginalOutputsNumber(); i++) {
outPortConfigs.emplace_back(LayoutType::ncsp, oprecs[i], getOutputShapeAtPort(i), false, -1);
}
addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
}
/**
 * Runs the executor chosen in initSupportedPrimitiveDescriptors() on this node,
 * passing the stream, the context's parallel runner and the stored config.
 */
void CausalMaskPreprocess::execute(const dnnl::stream& strm) {
    const auto& cpuParallel = context->getCpuParallel();
    m_executor->execute(strm, this, cpuParallel, m_config);
}
} // namespace ov::intel_cpu::node