1717#include " GatherLayer.h"
1818
1919#include " OperationUtils.h"
20- #include " GGMLHelper.h"
2120#include " ../KernelGenerator.h"
2221#include " ../Validator.h"
2322
2625namespace onert ::backend::cpu
2726{
2827
29- void Validator::visit (const ir::operation::Gather &) { _supported = true ; }
28+ void Validator::visit (const ir::operation::Gather &node)
29+ {
30+ using ir::operation::Gather;
31+
32+ const auto input_index{node.getInputs ().at (Gather::Input::INPUT)};
33+ const auto input_node = &_graph.operands ().at (input_index);
34+
35+ _supported = false ;
36+
37+ if (input_node->typeInfo ().type () == ir::DataType::QUANT_GGML_Q4_0)
38+ return ;
39+
40+ _supported = true ;
41+ }
3042
3143void KernelGenerator::visit (const ir::operation::Gather &node)
3244{
@@ -43,7 +55,7 @@ void KernelGenerator::visit(const ir::operation::Gather &node)
4355
4456 auto fn = std::make_unique<ops::GatherLayer>();
4557
46- fn->configure (input_tensor, indices_tensor, output_tensor, axis, _external_context. get () );
58+ fn->configure (input_tensor, indices_tensor, output_tensor, axis);
4759
4860 _return_fn = std::move (fn);
4961}
@@ -54,16 +66,12 @@ namespace onert::backend::cpu::ops
5466{
5567
5668void GatherLayer::configure (const IPortableTensor *input, const IPortableTensor *indices,
57- IPortableTensor *output, int32_t axis, ExternalContext *ctx )
69+ IPortableTensor *output, int32_t axis)
5870{
5971 _input = input;
6072 _indices = indices;
6173 _axis = axis;
6274 _output = output;
63- _ctx = ctx;
64-
65- if (_input->data_type () == OperandType::QUANT_GGML_Q4_0)
66- ctx->initGgmlContext ();
6775}
6876
6977template <typename InputType> void GatherLayer::runByInputType ()
@@ -97,53 +105,6 @@ template <typename InputType> void GatherLayer::runByInputType()
97105 }
98106}
99107
100- void GatherLayer::runByGGMLQuantInputType ()
101- {
102- // Supporting condition
103- // Input: rank 2
104- // Indice: rank < 4 or rank 4 with dim(0) = 1, INT32
105- // Axis: 0
106- if (getShape (_input).DimensionsCount () != 2 )
107- throw std::runtime_error (" Gather: block quantized input tensor must be rank 2" );
108-
109- if (getShape (_indices).DimensionsCount () >= 4 &&
110- (getShape (_indices).DimensionsCount () != 4 || getShape (_indices).Dims (0 ) != 1 ))
111- throw std::runtime_error (" Gather: invalid indices tensor shape" );
112-
113- if (_indices->data_type () != ir::DataType::INT32)
114- throw std::runtime_error (" Gather: indices tensor must be int32 type" );
115-
116- if (_axis != 0 )
117- throw std::runtime_error (" Gather: axis must be 0" );
118-
119- // convert tensor
120- auto input = getGGMLTensor (_input);
121- auto indices = getGGMLTensor (_indices);
122- auto output = getGGMLTensor (_output);
123- {
124- output.op = GGML_OP_GET_ROWS;
125- output.src [0 ] = &input;
126- output.src [1 ] = &indices;
127- }
128- auto *nodes = &output;
129-
130- // create graph
131- struct ggml_cgraph graph;
132- {
133- memset (&graph, 0 , sizeof (graph));
134- graph.n_nodes = 1 ;
135- graph.nodes = &nodes;
136- }
137-
138- // get cplan
139- auto cplan = ggml_graph_plan (&graph, _ctx->maxNumThreads ());
140- std::vector<uint8_t > buf (cplan.work_size );
141- cplan.work_data = buf.data ();
142-
143- // compute
144- ggml_graph_compute (&graph, &cplan);
145- }
146-
147108void GatherLayer::run ()
148109{
149110 switch (_input->data_type ())
@@ -157,9 +118,6 @@ void GatherLayer::run()
157118 case OperandType::INT32:
158119 runByInputType<int32_t >();
159120 break ;
160- case OperandType::QUANT_GGML_Q4_0:
161- runByGGMLQuantInputType ();
162- break ;
163121 case OperandType::BOOL8:
164122 runByInputType<bool >();
165123 break ;
0 commit comments