Skip to content

Commit d1615d2

Browse files
authored
Merge pull request #629 from yuerqiqi/feat/concat-slice
[Ascend] Implement Concat and Slice operators
2 parents 9fd8762 + a562bbd commit d1615d2

8 files changed

Lines changed: 430 additions & 1 deletion

File tree

mllm/backends/ascend/AscendBackend.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414
#include "mllm/backends/ascend/ops/AscendViewOp.hpp"
1515
#include "mllm/backends/ascend/ops/AscendMatMulOp.hpp"
1616
#include "mllm/backends/ascend/ops/AscendSoftmaxOp.hpp"
17+
#include "mllm/backends/ascend/ops/AscendConcatOp.hpp"
18+
#include "mllm/backends/ascend/ops/AscendSliceOp.hpp"
1719

1820
namespace mllm::ascend {
1921

2022
AscendBackend::AscendBackend() : Backend(kAscend, createAscendAllocator()) {
2123
regOpFactory<AscendAddOpFactory,AscendSubOpFactory,AscendMulOpFactory,AscendX2XOpFactory,AscendSiLUOpFactory,
22-
AscendLinearOpFactory,AscendRMSNormOpFactory,AscendViewOpFactory,AscendMatMulOpFactory,AscendSoftmaxOpFactory>();
24+
AscendLinearOpFactory,AscendRMSNormOpFactory,AscendViewOpFactory,AscendMatMulOpFactory,AscendSoftmaxOpFactory,
25+
AscendConcatOpFactory, AscendSliceOpFactory>();
2326
auto& devices = AscendDeviceMetaInfo::instance().devices;
2427
for (const auto& device : devices) {
2528
const auto bytes_to_mb = [](size_t bytes) { return bytes / (1024.0 * 1024.0); };
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Copyright (c) MLLM Team.
2+
// Licensed under the MIT License.
3+
4+
#include "mllm/backends/ascend/ops/AscendConcatOp.hpp"
5+
6+
#include <iostream>
7+
#include <acl/acl.h>
8+
#include <atb/atb_infer.h>
9+
#include <atb/types.h>
10+
#include <atb/utils.h>
11+
#include <atb/infer_op_params.h>
12+
13+
#include "mllm/utils/Common.hpp"
14+
#include "mllm/core/DataTypes.hpp"
15+
#include "mllm/core/Tensor.hpp"
16+
#include "mllm/backends/ascend/memory/AscendMemoryManager.hpp"
17+
#include "mllm/backends/ascend/AscendCommon.hpp"
18+
19+
namespace mllm::ascend {
20+
21+
// Builds the Ascend concat op by handing the options straight to the aops::ConcatOp base.
AscendConcatOp::AscendConcatOp(const aops::ConcatOpOptions& options)
    : aops::ConcatOp(options) {
}
22+
23+
// Setup hook: this op adds nothing of its own and defers entirely to BaseOp::setup.
void AscendConcatOp::setup(const std::vector<Tensor>& inputs,
                           std::vector<Tensor>& outputs) {
  BaseOp::setup(inputs, outputs);
}
26+
27+
void AscendConcatOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
28+
MLLM_RT_ASSERT(inputs.size() >= 1);
29+
MLLM_RT_ASSERT_EQ(outputs.size(), 1);
30+
31+
if (inputs.size() == 1) {
32+
const size_t data_size = inputs[0].bytes();
33+
const void* src_data = inputs[0].ptr<void>();
34+
void* dst_data = outputs[0].ptr<void>();
35+
36+
if (src_data != dst_data) {
37+
auto ret = aclrtMemcpy(dst_data, data_size, src_data, data_size, ACL_MEMCPY_DEVICE_TO_DEVICE);
38+
if (ret != ACL_SUCCESS) {
39+
MLLM_ACL_CHECK(ret);
40+
}
41+
syncGlobalAtbStream();
42+
}
43+
return;
44+
}
45+
46+
int32_t concat_dim = options().dim;
47+
if (concat_dim < 0) {
48+
concat_dim += static_cast<int32_t>(inputs[0].rank());
49+
}
50+
51+
auto run_concat = [&](const Tensor& left, const Tensor& right, Tensor& out) {
52+
atb::infer::ConcatParam param;
53+
param.concatDim = concat_dim;
54+
55+
atb::Operation* op = nullptr;
56+
auto st = atb::CreateOperation(param, &op);
57+
if (st != atb::NO_ERROR || op == nullptr) {
58+
MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB CreateOperation(Concat) failed, status={}", static_cast<int>(st));
59+
}
60+
61+
atb::Context* atb_ctx = getGlobalAtbContext();
62+
63+
atb::SVector<atb::Tensor> inTensors;
64+
atb::Tensor atb_left;
65+
atb::Tensor atb_right;
66+
fillAtbTensor(left, atb_left);
67+
fillAtbTensor(right, atb_right);
68+
inTensors.push_back(atb_left);
69+
inTensors.push_back(atb_right);
70+
71+
atb::Tensor atb_out;
72+
fillAtbTensor(out, atb_out);
73+
atb::SVector<atb::Tensor> outTensors;
74+
outTensors.push_back(atb_out);
75+
76+
atb::VariantPack vp;
77+
vp.inTensors = inTensors;
78+
vp.outTensors = outTensors;
79+
80+
uint64_t workspaceSize = 0;
81+
st = op->Setup(vp, workspaceSize, atb_ctx);
82+
if (st != atb::NO_ERROR) {
83+
MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB ConcatOp Setup failed, status={}", static_cast<int>(st));
84+
}
85+
86+
void* workspace = nullptr;
87+
int workspace_block_id = -1;
88+
if (workspaceSize > 0) {
89+
auto& mem_mgr = getAscendMemoryManager();
90+
mem_mgr.allocateBlock(static_cast<uint32_t>(workspaceSize), workspace_block_id);
91+
mem_mgr.getBlockPtr(workspace_block_id, workspace);
92+
}
93+
94+
{
95+
ASCEND_TIME_SCOPE("AscendConcatOp::forward");
96+
st = op->Execute(vp, reinterpret_cast<uint8_t*>(workspace), workspaceSize, atb_ctx);
97+
}
98+
99+
if (st != atb::NO_ERROR) {
100+
MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB ConcatOp Execute failed, status={}", static_cast<int>(st));
101+
}
102+
103+
syncGlobalAtbStream();
104+
105+
if (workspace_block_id != -1) {
106+
auto& mem_mgr = getAscendMemoryManager();
107+
mem_mgr.freeBlock(workspace_block_id);
108+
}
109+
110+
atb::DestroyOperation(op);
111+
};
112+
113+
std::vector<int32_t> current_shape = inputs[0].shape();
114+
Tensor current = inputs[0];
115+
116+
for (size_t i = 1; i < inputs.size(); ++i) {
117+
current_shape[concat_dim] += inputs[i].shape()[concat_dim];
118+
119+
if (i == inputs.size() - 1) {
120+
run_concat(current, inputs[i], outputs[0]);
121+
} else {
122+
Tensor temp = Tensor::empty(current_shape, outputs[0].dtype(), outputs[0].device()).alloc();
123+
run_concat(current, inputs[i], temp);
124+
current = temp;
125+
}
126+
}
127+
}
128+
129+
} // namespace mllm::ascend
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) MLLM Team.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "mllm/core/BaseOp.hpp"
7+
#include "mllm/core/aops/ConcatOp.hpp"
8+
#include "mllm/core/OpTypes.hpp"
9+
10+
namespace mllm::ascend {
11+
12+
class AscendConcatOp final : public aops::ConcatOp {
13+
public:
14+
explicit AscendConcatOp(const aops::ConcatOpOptions& options);
15+
16+
void setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
17+
void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
18+
};
19+
20+
/// Factory bound to OpTypes::kConcat that produces AscendConcatOp instances.
class AscendConcatOpFactory final
    : public TypedOpFactory<OpTypes::kConcat, aops::ConcatOpOptions> {
 public:
  /// Creates a new AscendConcatOp configured with `options`.
  std::shared_ptr<BaseOp> createOpImpl(const aops::ConcatOpOptions& options) override {
    auto op = std::make_shared<AscendConcatOp>(options);
    return op;
  }
};
26+
27+
} // namespace mllm::ascend
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// Copyright (c) MLLM Team.
2+
// Licensed under the MIT License.
3+
4+
#include "mllm/backends/ascend/ops/AscendSliceOp.hpp"
5+
6+
#include <acl/acl.h>
7+
#include <atb/atb_infer.h>
8+
#include <atb/types.h>
9+
#include <atb/utils.h>
10+
#include <atb/infer_op_params.h>
11+
12+
#include "mllm/utils/Common.hpp"
13+
#include "mllm/core/DataTypes.hpp"
14+
#include "mllm/core/Tensor.hpp"
15+
#include "mllm/backends/ascend/memory/AscendMemoryManager.hpp"
16+
#include "mllm/backends/ascend/AscendCommon.hpp"
17+
18+
namespace mllm::ascend {
19+
20+
// Builds the Ascend slice op by handing the options straight to the aops::SliceOp base.
AscendSliceOp::AscendSliceOp(const aops::SliceOpOptions& options)
    : aops::SliceOp(options) {
}
21+
22+
// Setup hook: this op adds nothing of its own and defers entirely to BaseOp::setup.
void AscendSliceOp::setup(const std::vector<Tensor>& inputs,
                          std::vector<Tensor>& outputs) {
  BaseOp::setup(inputs, outputs);
}
25+
26+
void AscendSliceOp::reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
27+
auto& input = inputs[0];
28+
auto shape = input.shape();
29+
auto slice_index = options().indices_;
30+
31+
MLLM_RT_ASSERT_EQ(slice_index.size(), shape.size());
32+
33+
std::vector<int> out_shape;
34+
for (size_t i = 0; i < shape.size(); ++i) {
35+
const auto& pair = slice_index[i];
36+
int32_t start = pair.start_;
37+
int32_t end = pair.end_;
38+
39+
if (start == kAll) { start = 0; }
40+
if (end == kAll) { end = shape[i]; }
41+
42+
if (start < 0) { start = start + shape[i]; }
43+
if (end < 0) { end = end + shape[i]; }
44+
45+
start = std::max(0, std::min(start, static_cast<int>(shape[i])));
46+
end = std::max(0, std::min(end, static_cast<int>(shape[i])));
47+
48+
int len = std::max(0, end - start);
49+
out_shape.push_back(len);
50+
}
51+
52+
outputs.emplace_back(Tensor::empty(out_shape, input.dtype(), input.device()));
53+
}
54+
55+
void AscendSliceOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
56+
atb::infer::SliceParam param;
57+
auto& input = inputs[0];
58+
auto shape = input.shape();
59+
auto slice_index = options().indices_;
60+
61+
for(size_t i=0; i<shape.size(); ++i) {
62+
int32_t start = slice_index[i].start_;
63+
int32_t end = slice_index[i].end_;
64+
int32_t dim_size = shape[i];
65+
66+
if (start == kAll) start = 0;
67+
if (end == kAll) end = dim_size;
68+
69+
if (start < 0) start += dim_size;
70+
if (end < 0) end += dim_size;
71+
72+
start = std::max(0, std::min(start, dim_size));
73+
end = std::max(0, std::min(end, dim_size));
74+
75+
param.offsets.push_back(start);
76+
param.size.push_back(std::max(0, end - start));
77+
}
78+
79+
atb::Operation* op = nullptr;
80+
auto st = atb::CreateOperation(param, &op);
81+
if (st != atb::NO_ERROR || op == nullptr) {
82+
MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB CreateOperation(Slice) failed, status={}", static_cast<int>(st));
83+
}
84+
85+
atb::Context* atb_ctx = getGlobalAtbContext();
86+
87+
atb::SVector<atb::Tensor> inTensors;
88+
std::vector<atb::Tensor> atb_inputs(inputs.size());
89+
for (size_t i = 0; i < inputs.size(); ++i) {
90+
fillAtbTensor(inputs[i], atb_inputs[i]);
91+
inTensors.push_back(atb_inputs[i]);
92+
}
93+
94+
atb::Tensor atb_output;
95+
fillAtbTensor(outputs[0], atb_output);
96+
atb::SVector<atb::Tensor> outTensors;
97+
outTensors.push_back(atb_output);
98+
99+
atb::VariantPack vp;
100+
vp.inTensors = inTensors;
101+
vp.outTensors = outTensors;
102+
103+
uint64_t workspaceSize = 0;
104+
st = op->Setup(vp, workspaceSize, atb_ctx);
105+
if (st != atb::NO_ERROR) {
106+
MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB SliceOp Setup failed, status={}", static_cast<int>(st));
107+
}
108+
109+
void* workspace = nullptr;
110+
int workspace_block_id = -1;
111+
if (workspaceSize > 0) {
112+
auto& mem_mgr = getAscendMemoryManager();
113+
mem_mgr.allocateBlock(static_cast<uint32_t>(workspaceSize), workspace_block_id);
114+
mem_mgr.getBlockPtr(workspace_block_id, workspace);
115+
}
116+
117+
{
118+
ASCEND_TIME_SCOPE("AscendSliceOp::forward");
119+
st = op->Execute(vp, reinterpret_cast<uint8_t*>(workspace), workspaceSize, atb_ctx);
120+
}
121+
122+
if (st != atb::NO_ERROR) {
123+
MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB SliceOp Execute failed, status={}", static_cast<int>(st));
124+
}
125+
126+
syncGlobalAtbStream();
127+
128+
if (workspace_block_id != -1) {
129+
auto& mem_mgr = getAscendMemoryManager();
130+
mem_mgr.freeBlock(workspace_block_id);
131+
}
132+
133+
atb::DestroyOperation(op);
134+
}
135+
136+
} // namespace mllm::ascend
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright (c) MLLM Team.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "mllm/core/BaseOp.hpp"
7+
#include "mllm/core/aops/SliceOp.hpp"
8+
#include "mllm/core/OpTypes.hpp"
9+
10+
namespace mllm::ascend {
11+
12+
class AscendSliceOp final : public aops::SliceOp {
13+
public:
14+
explicit AscendSliceOp(const aops::SliceOpOptions& options);
15+
16+
void setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
17+
void reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
18+
void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
19+
};
20+
21+
/// Factory bound to OpTypes::kSlice that produces AscendSliceOp instances.
class AscendSliceOpFactory final
    : public TypedOpFactory<OpTypes::kSlice, aops::SliceOpOptions> {
 public:
  /// Creates a new AscendSliceOp configured with `options`.
  std::shared_ptr<BaseOp> createOpImpl(const aops::SliceOpOptions& options) override {
    auto op = std::make_shared<AscendSliceOp>(options);
    return op;
  }
};
27+
28+
} // namespace mllm::ascend
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright (c) MLLM Team.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "mllm/mllm.hpp"
7+
#include "mllm/core/Tensor.hpp"
8+
#include "mllm/nn/Functional.hpp"
9+
#include "KernelTestHelper.hpp" // Has KernelTest base class
10+
11+
class AscendConcatKernelTest : public KernelTest {
12+
public:
13+
bool ConcatFloat16Test(const std::vector<mllm::Tensor::shape_t>& input_shapes, int dim) {
14+
using namespace mllm;
15+
16+
std::vector<Tensor> inputs_cpu;
17+
for (const auto& shape : input_shapes) {
18+
inputs_cpu.push_back(Tensor::random(shape, -1.0, 1.0, kFloat16, kCPU));
19+
}
20+
21+
// CPU Reference
22+
auto out_cpu = nn::functional::concat(inputs_cpu, dim);
23+
24+
// Ascend
25+
std::vector<Tensor> inputs_ascend;
26+
for (auto& t : inputs_cpu) {
27+
inputs_ascend.push_back(t.to(kAscend));
28+
}
29+
30+
auto out_ascend = nn::functional::concat(inputs_ascend, dim);
31+
auto out_back = out_ascend.to(kCPU);
32+
33+
auto result = test::allClose(out_back, out_cpu, 1e-2, 1e-2);
34+
if (!result.is_close) {
35+
std::cout << "[ConcatTest] FAILED! dim=" << dim << std::endl;
36+
return false;
37+
}
38+
std::cout << "[ConcatTest] PASSED dim=" << dim << std::endl;
39+
return true;
40+
}
41+
};

0 commit comments

Comments
 (0)