
Commit 2d5b2cc

Implement out kwarg overloads for custom ops
1 parent 23eba7a commit 2d5b2cc

5 files changed (+342, -126 lines)

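The pattern in the diff below: the old schemas took a `Tensor? out=None` argument and returned it, which the now-deleted TODO comment already flagged as problematic (a mutable output argument aliasing the return value). This commit keeps a purely functional overload for each op and adds a separate `.out` overload whose schema marks the buffer as mutated (`Tensor! out`) and returns `()`, per the linked PyTorch issue. A minimal sketch of the same pattern with a made-up `mylib::scale` op (not part of bitsandbytes; assumes a PyTorch release that provides `torch.library.register_fake`):

```python
import torch
from torch.library import register_fake

# Functional overload: allocates and returns its result.
torch.library.define("mylib::scale", "(Tensor A, float factor) -> Tensor")
# Out-variant overload: mutates a caller-provided buffer and returns ().
torch.library.define("mylib::scale.out", "(Tensor A, float factor, Tensor! out) -> ()")


@register_fake("mylib::scale")
def _(A, factor):
    return torch.empty_like(A)


@register_fake("mylib::scale.out")
def _(A, factor, out):
    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
    torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")


@torch.library.impl("mylib::scale", "cpu")
def _(A, factor):
    return A * factor


@torch.library.impl("mylib::scale.out", "cpu")
def _(A, factor, out):
    torch.mul(A, factor, out=out)


A = torch.randn(4)
buf = torch.empty_like(A)
torch.ops.mylib.scale.out(A, 2.0, buf)  # writes into the preallocated buffer
assert torch.equal(buf, torch.ops.mylib.scale.default(A, 2.0))
```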

bitsandbytes/_ops.py (+107, -22)
@@ -44,30 +44,42 @@ def _(
     bias: Optional[torch.Tensor] = None,
     dtype=torch.float16,
 ) -> torch.Tensor:
-    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
-    out = torch.ops.bitsandbytes.int8_mm_dequant(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
+    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
+    out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
     return out
 
 
-# Define op
-# TODO: mutable output arg as alias of return can be challenging;
-# consider a separate op without aliased return:
-#   int8_linear_matmul_out(
-#       Tensor A, Tensor B, Tensor out, ScalarType dtype=int32
-#   ) -> ()
-# return () instead of `None` for compatibility, see here: https://github.com/pytorch/pytorch/issues/125044
 torch.library.define(
     "bitsandbytes::int8_linear_matmul",
-    "(Tensor A, Tensor B, Tensor? out=None, ScalarType dtype=int32) -> Tensor",
+    "(Tensor A, Tensor B) -> Tensor",
 )
 
 
 @register_fake("bitsandbytes::int8_linear_matmul")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
+def _(A: torch.Tensor, B: torch.Tensor):
+    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
+    torch._check(B.dtype == torch.int8, lambda: "B must be int8")
     shapeC = (*A.shape[:-1], B.shape[0])
-    if out is None:
-        return torch.empty(shapeC, device=A.device, dtype=dtype)
-    return out
+    return torch.empty(shapeC, device=A.device, dtype=torch.int32)
+
+
+# More info on `out` overloads:
+# https://github.com/pytorch/pytorch/issues/125044
+torch.library.define(
+    "bitsandbytes::int8_linear_matmul.out",
+    "(Tensor A, Tensor B, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::int8_linear_matmul.out")
+def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
+    shapeC = (*A.shape[:-1], B.shape[0])
+
+    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
+    torch._check(B.dtype == torch.int8, lambda: "B must be int8")
+    torch._check(out.shape == shapeC, lambda: f"Expected out.shape == {shapeC}, got {out.shape}")
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == torch.int32, lambda: f"Expected out.dtype == int32, got {out.dtype}")
 
 
 torch.library.define(
@@ -107,7 +119,7 @@ def _(A: torch.Tensor, stats: torch.Tensor):
 
 torch.library.define(
     "bitsandbytes::int8_mm_dequant",
-    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? out=None, Tensor? bias=None) -> Tensor",
+    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor",
 )
 
 
@@ -117,7 +129,6 @@ def _(
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
     dtype=torch.float16,
-    out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: "A must be int32")
@@ -126,17 +137,13 @@ def _(
 
 torch.library.define(
     "bitsandbytes::int8_double_quant",
-    "(Tensor A, Tensor? col_stats, Tensor? row_stats, Tensor? out_col, Tensor? out_row, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
+    "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
 )
 
 
 @register_fake("bitsandbytes::int8_double_quant")
 def _(
     A: torch.Tensor,
-    col_stats: Optional[torch.Tensor] = None,
-    row_stats: Optional[torch.Tensor] = None,
-    out_col: Optional[torch.Tensor] = None,
-    out_row: Optional[torch.Tensor] = None,
     threshold=0.0,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     out_row = torch.empty_like(A, dtype=torch.int8)
@@ -156,12 +163,39 @@ def _(
 
 @register_fake("bitsandbytes::dequantize_4bit")
 def _(
-    A: torch.Tensor, absmax: torch.Tensor, blocksize: int, quant_type: str, shape: Sequence[int], dtype: torch.dtype
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
 ) -> torch.Tensor:
     torch._check_is_size(blocksize)
     return torch.empty(shape, dtype=dtype, device=A.device)
 
 
+torch.library.define(
+    "bitsandbytes::dequantize_4bit.out",
+    "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::dequantize_4bit.out")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+    out: torch.Tensor,
+) -> None:
+    torch._check_is_size(blocksize)
+    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+
+
 torch.library.define(
     "bitsandbytes::quantize_4bit",
     "(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)",
@@ -194,6 +228,23 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
     return torch.empty_like(A, dtype=dtype)
 
 
+torch.library.define(
+    "bitsandbytes::dequantize_blockwise.out",
+    "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::dequantize_blockwise.out")
+def _(
+    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
+):
+    torch._check_is_size(blocksize)
+    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+
+
 torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")
 
 
@@ -229,3 +280,37 @@ def _(
     )
     shape = (*A.shape[:-1], shapeB[0])
     return torch.empty(shape, device=A.device, dtype=A.dtype)
+
+
+torch.library.define(
+    "bitsandbytes::gemv_4bit.out",
+    "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::gemv_4bit.out")
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    shapeB: Sequence[int],
+    absmax: torch.Tensor,
+    code: torch.Tensor,
+    blocksize: int,
+    out: torch.Tensor,
+) -> None:
+    torch._check_is_size(blocksize)
+    torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
+    torch._check(
+        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
+        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
+    )
+    torch._check(
+        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
+    )
+    torch._check(
+        out.shape == (*A.shape[:-1], shapeB[0]),
+        lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
+    )
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
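With these schemas registered, call sites that previously passed `out=` to the functional op instead pick an overload explicitly: `.default` when a freshly allocated result is fine, `.out` when writing into a preallocated buffer. A hedged usage sketch for `int8_linear_matmul` (shapes are illustrative; it assumes that importing bitsandbytes registers these ops together with a backend kernel such as the CPU one in the next file):

```python
import torch
import bitsandbytes  # noqa: F401  -- assumed to register the bitsandbytes custom ops

A = torch.randint(-128, 128, (16, 64), dtype=torch.int8)
B = torch.randint(-128, 128, (32, 64), dtype=torch.int8)

# Functional overload: allocates and returns an int32 tensor of shape (16, 32).
C = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)

# Out-variant overload: writes into a caller-provided int32 buffer and returns ().
out = torch.empty((16, 32), dtype=torch.int32)
torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out)
assert torch.equal(out, C)
```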

bitsandbytes/backends/cpu/ops.py (+11, -1)
@@ -10,7 +10,17 @@
 
 
 @register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
+def _(A: torch.Tensor, B: torch.Tensor):
+    return _int8_linear_matmul_impl(A, B)
+
+
+@register_kernel("bitsandbytes::int8_linear_matmul.out", "cpu")
+def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
+    torch._check(out.dtype == torch.int32)
+    _int8_linear_matmul_impl(A, B, out)
+
+
+def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None):
     # Naive implementation: perform matmul in fp32
     result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
     if out is not None:
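One possible sanity check for the two CPU registrations above is `torch.library.opcheck` (available in recent PyTorch releases), which validates an op's schema and fake-tensor registration against the real kernel. A sketch under the same assumptions as the previous example; the restricted `test_utils` selection is a choice made here, not something this commit prescribes:

```python
import torch
from torch.library import opcheck

A = torch.randint(-128, 128, (8, 32), dtype=torch.int8)
B = torch.randint(-128, 128, (16, 32), dtype=torch.int8)
out = torch.empty((8, 16), dtype=torch.int32)

# Validate schema annotations and fake-tensor behavior of both overloads on CPU.
opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B),
        test_utils=("test_schema", "test_faketensor"))
opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out),
        test_utils=("test_schema", "test_faketensor"))
```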

0 commit comments
