@@ -77,9 +77,9 @@ Status Gelu<T>::Compute(OpKernelContext* context) const {
   T* output_data = output->MutableData<T>();

   concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
-  int64_t elem_count = input->Shape().Size();
-  constexpr int64_t length_per_task = 4096;  // this number comes from FastGelu.
-  int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
+  size_t elem_count = input->Shape().Size();
+  constexpr size_t length_per_task = 4096;  // this number comes from FastGelu.
+  size_t task_count = (elem_count + length_per_task - 1) / length_per_task;

   if (approximation_algorithm_ == "tanh") {
     // FastGelu allows optional bias. Here we split input data into chunks. Each chunk
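
For context on the chunking: (elem_count + length_per_task - 1) / length_per_task is integer ceiling division, so every task covers length_per_task elements except possibly the last. As a worked example, elem_count = 10000 yields task_count = 3: tasks 0 and 1 each process 4096 elements and task 2 processes the remaining 10000 - 2 * 4096 = 1808, which is exactly what std::min(length_per_task, elem_count - start) evaluates to inside each task below.
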
@@ -95,16 +95,16 @@ Status Gelu<T>::Compute(OpKernelContext* context) const {
           const auto start = task_idx * length_per_task;
           const T* p_input = input_data + start;
           T* p_output = output_data + start;
-          int64_t count = std::min(length_per_task, elem_count - start);
+          size_t count = std::min(length_per_task, elem_count - start);

-          for (int64_t i = 0; i < count; i++) {
+          for (size_t i = 0; i < count; i++) {
             T value = p_input[i];
             p_output[i] = value * (static_cast<T>(C) * value * value + static_cast<T>(B));
           }

           MlasComputeTanh(p_output, p_output, narrow<size_t>(count));

-          for (int64_t i = 0; i < count; i++) {
+          for (size_t i = 0; i < count; i++) {
             p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
           }
         },
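
This branch implements the tanh GELU approximation, GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A minimal scalar sketch of what the two loops plus MlasComputeTanh compute per chunk, assuming B and C hold the usual FastGelu constants (B = sqrt(2/pi), C = 0.044715 * sqrt(2/pi); they are defined outside the hunk shown here):

    #include <cmath>
    #include <cstddef>

    // Reference-only scalar version; the kernel above vectorizes the tanh via MLAS.
    void gelu_tanh_ref(const float* x, float* y, size_t n) {
      constexpr float B = 0.7978845608028654f;    // sqrt(2 / pi)
      constexpr float C = 0.035677408136300125f;  // 0.044715 * sqrt(2 / pi)
      for (size_t i = 0; i < n; ++i) {
        float t = std::tanh(x[i] * (C * x[i] * x[i] + B));  // inner polynomial, then tanh
        y[i] = 0.5f * x[i] * (t + 1.0f);                    // 0.5 * x * (1 + tanh(...))
      }
    }
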
@@ -117,16 +117,16 @@ Status Gelu<T>::Compute(OpKernelContext* context) const {
           const auto start = task_idx * length_per_task;
           const T* p_input = input_data + start;
           T* p_output = output_data + start;
-          int64_t count = std::min(length_per_task, elem_count - start);
+          size_t count = std::min(length_per_task, elem_count - start);

-          for (int64_t i = 0; i < count; i++) {
+          for (size_t i = 0; i < count; i++) {
             T value = p_input[i];
             p_output[i] = value * static_cast<T>(M_SQRT1_2);
           }

           MlasComputeErf(p_output, p_output, narrow<size_t>(count));

-          for (int64_t i = 0; i < count; i++) {
+          for (size_t i = 0; i < count; i++) {
             p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
           }
         },
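
The non-approximated branch computes exact GELU, GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))): the first loop scales by M_SQRT1_2 = 1/sqrt(2), MlasComputeErf applies erf elementwise, and the second loop finishes the formula. A scalar reference sketch (the constant is spelled out so the snippet does not rely on M_SQRT1_2 being defined):

    #include <cmath>
    #include <cstddef>

    void gelu_erf_ref(const float* x, float* y, size_t n) {
      constexpr float kSqrt1_2 = 0.70710678118654752f;  // 1 / sqrt(2), i.e. M_SQRT1_2
      for (size_t i = 0; i < n; ++i) {
        float e = std::erf(x[i] * kSqrt1_2);  // erf(x / sqrt(2))
        y[i] = 0.5f * x[i] * (e + 1.0f);      // 0.5 * x * (1 + erf(x / sqrt(2)))
      }
    }
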
@@ -143,9 +143,9 @@ Status Gelu<MLFloat16>::Compute(OpKernelContext* context) const {
   Tensor* output = context->Output(0, input->Shape());
   MLFloat16* output_data = output->MutableData<MLFloat16>();
   concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
-  int64_t elem_count = input->Shape().Size();
-  constexpr int64_t length_per_task = 4096;
-  int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
+  size_t elem_count = input->Shape().Size();
+  constexpr size_t length_per_task = 4096;
+  size_t task_count = (elem_count + length_per_task - 1) / length_per_task;

   if (approximation_algorithm_ != "tanh" && approximation_algorithm_ != "none") {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_);
@@ -178,7 +178,7 @@ Status Gelu<MLFloat16>::Compute(OpKernelContext* context) const {
           const auto start = task_idx * length_per_task;
           const MLFloat16* p_input = input_data + start;
           MLFloat16* p_output = output_data + start;
-          int64_t count = std::min(length_per_task, elem_count - start);
+          size_t count = std::min(length_per_task, elem_count - start);
           MLFloat16* p_temp = temp_fp16_aligned.get() + start;
           MlasComputeFP16Gelu(p_input, p_output, p_temp, count, approximation_algorithm_);
         },
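
One reason elem_count, length_per_task, and count have to change type together: std::min deduces a single template parameter, so both arguments must share a type, and mixing size_t with a leftover int64_t would fail to compile. The unsigned subtraction is also safe here, since start <= (task_count - 1) * length_per_task < elem_count for every task, so elem_count - start never wraps. A small sketch of the deduction constraint (the function name is illustrative only):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    size_t chunk_size(size_t elem_count, size_t start) {
      constexpr size_t length_per_task = 4096;
      // Compiles: both operands are size_t, so std::min<size_t> is deduced.
      return std::min(length_per_task, elem_count - start);
      // Would not compile: std::min(int64_t{4096}, elem_count - start)
      // leaves template deduction torn between int64_t and size_t.
    }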