Skip to content

Commit 09c2760

Browse files
authored
Update linear.h (#1963)
Update linear.h (#1963) Summary: Pull Request resolved: #1963 This diff updates the interface for universal kernels to mirror the update we made in KleidiAI a couple of weeks ago. The changes look large, but most of them are files being moved around. Changes: * Within cpu/aarch64/linear, there is a new folder channelwise_8bit_activation_groupwise_lowbit_weight that contains: * include.h (a new version of linear.h with shared activation-packing and weight-packing functions), the 3 kernel files (moved, but no changes made), pack_activations (moved, but no changes), and pack_weights (moved, but no changes). * The new shared embedding ops (this diff stack) need an include/namespace update to reflect these changes. Note that linear.h is preserved so as not to break existing code at the op level, but eventually it will be deleted and the op code will use the new interface. Reviewed By: digantdesai Differential Revision: D71357293
1 parent b948c4e commit 09c2760

10 files changed

+528
-593
lines changed

torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h

+15-14
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
#include <arm_neon.h>
1212
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
13-
#include <torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h>
13+
#include <torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h>
1414
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
1515
#include <cassert>
1616
#include <vector>
@@ -353,19 +353,20 @@ inline void shared_embedding(
353353
n_idx = n_idx * nr;
354354
int j = index - n_idx;
355355

356-
torchao::kernels::cpu::aarch64::linear::packing::
357-
unpack_weights_at_n_idx<weight_nbit, nr, kr, sr>(
358-
weight_qvals.data(),
359-
weight_scales.data(),
360-
has_weight_zeros ? weight_zeros.data() : nullptr,
361-
has_bias ? bias.data() : nullptr,
362-
n_idx,
363-
n,
364-
k,
365-
group_size,
366-
has_weight_zeros,
367-
has_bias,
368-
packed_weights);
356+
torchao::kernels::cpu::aarch64::linear::
357+
channelwise_8bit_activation_groupwise_lowbit_weight::weight_packing::
358+
unpack_weights_at_n_idx<weight_nbit, nr, kr, sr>(
359+
weight_qvals.data(),
360+
weight_scales.data(),
361+
has_weight_zeros ? weight_zeros.data() : nullptr,
362+
has_bias ? bias.data() : nullptr,
363+
n_idx,
364+
n,
365+
k,
366+
group_size,
367+
has_weight_zeros,
368+
has_bias,
369+
packed_weights);
369370

370371
// Dequantize and store to output (size k)
371372
int8x16_t qvals;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#if defined(__aarch64__) || defined(__ARM_NEON)
10+
11+
#include <arm_neon.h>
12+
#include <stddef.h>
13+
#include <torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h>
14+
#include <torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h>
15+
16+
#include <torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h>
17+
#include <torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h>
18+
#include <torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h>
19+
20+
namespace torchao::kernels::cpu::aarch64::linear::
21+
channelwise_8bit_activation_groupwise_lowbit_weight {
22+
23+
inline size_t packed_activations_size(
24+
int m,
25+
int k,
26+
int group_size,
27+
bool has_weight_zeros,
28+
int mr,
29+
int kr,
30+
int sr) {
31+
(void)mr; // unused
32+
(void)kr; // unused
33+
(void)sr; // unused
34+
return activation_packing::packed_activations_size(
35+
m, k, group_size, has_weight_zeros);
36+
}
37+
38+
inline size_t packed_activations_offset(
39+
int m_idx,
40+
int k,
41+
int group_size,
42+
bool has_weight_zeros,
43+
int mr,
44+
int kr,
45+
int sr) {
46+
assert(m_idx % mr == 0);
47+
auto packed_activations_size_mr_rows =
48+
packed_activations_size(mr, k, group_size, has_weight_zeros, mr, kr, sr);
49+
return (m_idx / mr) * packed_activations_size_mr_rows;
50+
}
51+
52+
template <int mr, int kr, int sr>
53+
void pack_activations(
54+
void* packed_activations,
55+
int m,
56+
int k,
57+
int group_size,
58+
const float* activations,
59+
bool has_weight_zeros) {
60+
activation_packing::pack_activations<mr, kr, sr>(
61+
packed_activations, m, k, group_size, activations, has_weight_zeros);
62+
}
63+
64+
inline size_t packed_weights_size(
65+
int n,
66+
int k,
67+
int group_size,
68+
int weight_nbit,
69+
bool has_weight_zeros,
70+
bool has_bias,
71+
int nr,
72+
int kr,
73+
int sr) {
74+
(void)kr; // unused
75+
(void)sr; // unused
76+
return weight_packing::packed_weights_size(
77+
n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr);
78+
}
79+
80+
inline size_t packed_weights_offset(
81+
int n_idx,
82+
int k,
83+
int group_size,
84+
int weight_nbit,
85+
bool has_weight_zeros,
86+
bool has_bias,
87+
int nr,
88+
int kr,
89+
int sr) {
90+
assert(n_idx % nr == 0);
91+
auto packed_weights_size_nr_cols = packed_weights_size(
92+
nr, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr, kr, sr);
93+
return (n_idx / nr) * packed_weights_size_nr_cols;
94+
}
95+
96+
template <int weight_nbit, int nr, int kr, int sr>
97+
void pack_weights(
98+
void* packed_weights,
99+
int n,
100+
int k,
101+
int group_size,
102+
const int8_t* weight_qvals,
103+
const float* weight_scales,
104+
const int8_t* weight_zeros,
105+
const float* bias) {
106+
weight_packing::pack_weights<weight_nbit, nr, kr, sr>(
107+
packed_weights,
108+
n,
109+
k,
110+
group_size,
111+
weight_qvals,
112+
weight_scales,
113+
weight_zeros,
114+
bias);
115+
}
116+
117+
template <int weight_nbit>
118+
void kernel_1x1x32_f32_neondot(
119+
// Outputs
120+
float32_t* output,
121+
// Inputs
122+
int output_m_stride,
123+
int m,
124+
int n,
125+
int k,
126+
int group_size,
127+
const void* packed_weights,
128+
const void* packed_activations,
129+
// Ignored if has_clamp = false
130+
float clamp_min,
131+
float clamp_max,
132+
bool has_weight_zeros,
133+
bool has_bias,
134+
bool has_clamp) {
135+
kernel::kernel_1x1x32_f32_neondot<weight_nbit>(
136+
output,
137+
output_m_stride,
138+
m,
139+
n,
140+
k,
141+
group_size,
142+
packed_weights,
143+
packed_activations,
144+
clamp_min,
145+
clamp_max,
146+
has_weight_zeros,
147+
has_bias,
148+
has_clamp);
149+
}
150+
151+
template <int weight_nbit>
152+
void kernel_1x4x16_f32_neondot(
153+
// Outputs
154+
float32_t* output,
155+
// Inputs
156+
int output_m_stride,
157+
int m,
158+
int n,
159+
int k,
160+
int group_size,
161+
const void* packed_weights,
162+
const void* packed_activations,
163+
// Ignored if has_clamp = false
164+
float clamp_min,
165+
float clamp_max,
166+
bool has_weight_zeros,
167+
bool has_bias,
168+
bool has_clamp) {
169+
kernel::kernel_1x4x16_f32_neondot<weight_nbit>(
170+
output,
171+
output_m_stride,
172+
m,
173+
n,
174+
k,
175+
group_size,
176+
packed_weights,
177+
packed_activations,
178+
clamp_min,
179+
clamp_max,
180+
has_weight_zeros,
181+
has_bias,
182+
has_clamp);
183+
}
184+
185+
template <int weight_nbit>
186+
void kernel_1x8x16_f32_neondot(
187+
// Outputs
188+
float32_t* output,
189+
// Inputs
190+
int output_m_stride,
191+
int m,
192+
int n,
193+
int k,
194+
int group_size,
195+
const void* packed_weights,
196+
const void* packed_activations,
197+
// Ignored if has_clamp = false
198+
float clamp_min,
199+
float clamp_max,
200+
bool has_weight_zeros,
201+
bool has_bias,
202+
bool has_clamp) {
203+
kernel::kernel_1x8x16_f32_neondot<weight_nbit>(
204+
output,
205+
output_m_stride,
206+
m,
207+
n,
208+
k,
209+
group_size,
210+
packed_weights,
211+
packed_activations,
212+
clamp_min,
213+
clamp_max,
214+
has_weight_zeros,
215+
has_bias,
216+
has_clamp);
217+
}
218+
219+
} // namespace
220+
// torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight
221+
222+
#endif // defined(__aarch64__) || defined(__ARM_NEON)

0 commit comments

Comments
 (0)