@@ -44,30 +44,42 @@ def _(
     bias: Optional[torch.Tensor] = None,
     dtype=torch.float16,
 ) -> torch.Tensor:
-    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
-    out = torch.ops.bitsandbytes.int8_mm_dequant(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
+    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
+    out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
     return out
 
 
-# Define op
-# TODO: mutable output arg as alias of return can be challenging;
-# consider a separate op without aliased return:
-# int8_linear_matmul_out(
-#     Tensor A, Tensor B, Tensor out, ScalarType dtype=int32
-# ) -> ()
-# return () instead of `None` for compatibility, see here: https://github.com/pytorch/pytorch/issues/125044
 torch.library.define(
     "bitsandbytes::int8_linear_matmul",
-    "(Tensor A, Tensor B, Tensor? out=None, ScalarType dtype=int32) -> Tensor",
+    "(Tensor A, Tensor B) -> Tensor",
 )
 
 
 @register_fake("bitsandbytes::int8_linear_matmul")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
+def _(A: torch.Tensor, B: torch.Tensor):
+    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
+    torch._check(B.dtype == torch.int8, lambda: "B must be int8")
     shapeC = (*A.shape[:-1], B.shape[0])
-    if out is None:
-        return torch.empty(shapeC, device=A.device, dtype=dtype)
-    return out
+    return torch.empty(shapeC, device=A.device, dtype=torch.int32)
+
+
+# More info on `out` overloads:
+# https://github.com/pytorch/pytorch/issues/125044
+torch.library.define(
+    "bitsandbytes::int8_linear_matmul.out",
+    "(Tensor A, Tensor B, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::int8_linear_matmul.out")
+def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
+    shapeC = (*A.shape[:-1], B.shape[0])
+
+    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
+    torch._check(B.dtype == torch.int8, lambda: "B must be int8")
+    torch._check(out.shape == shapeC, lambda: f"Expected out.shape == {shapeC}, got {out.shape}")
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == torch.int32, lambda: f"Expected out.dtype == int32, got {out.dtype}")
 
 
 torch.library.define(
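Note on the split above: the old single op took an optional `out=` argument, which meant its return value could alias an input; the new functional overload always allocates, while the `.out` overload marks its buffer as mutated via `Tensor!` and returns `()` (see the linked pytorch/pytorch#125044). A minimal calling sketch, assuming the real kernels for these ops are registered elsewhere in the library:

    import torch

    A = torch.randint(-128, 127, (4, 16), dtype=torch.int8, device="cuda")
    B = torch.randint(-128, 127, (8, 16), dtype=torch.int8, device="cuda")

    # Functional overload: allocates and returns a fresh int32 result.
    C = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)

    # Mutable overload: writes into a caller-provided buffer and returns
    # nothing, so no return value aliases the `out` argument.
    out = torch.empty((4, 8), dtype=torch.int32, device="cuda")
    torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out)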
@@ -107,7 +119,7 @@ def _(A: torch.Tensor, stats: torch.Tensor):
 
 torch.library.define(
     "bitsandbytes::int8_mm_dequant",
-    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? out=None, Tensor? bias=None) -> Tensor",
+    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor",
 )
 
 
@@ -117,7 +129,6 @@ def _(
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
     dtype=torch.float16,
-    out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: "A must be int32")
@@ -126,17 +137,13 @@ def _(
 
 torch.library.define(
     "bitsandbytes::int8_double_quant",
-    "(Tensor A, Tensor? col_stats, Tensor? row_stats, Tensor? out_col, Tensor? out_row, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
+    "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
 )
 
 
 @register_fake("bitsandbytes::int8_double_quant")
 def _(
     A: torch.Tensor,
-    col_stats: Optional[torch.Tensor] = None,
-    row_stats: Optional[torch.Tensor] = None,
-    out_col: Optional[torch.Tensor] = None,
-    out_row: Optional[torch.Tensor] = None,
     threshold=0.0,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     out_row = torch.empty_like(A, dtype=torch.int8)
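With the preallocated-output arguments gone, `int8_double_quant` always allocates its five results. A hedged calling sketch: the result names below mirror the fake implementation, and treating the optional fifth output as outlier columns produced only for `threshold > 0` is an assumption about the backend kernel:

    import torch

    A = torch.randn(4, 16, dtype=torch.float16, device="cuda")

    # Row-wise and column-wise int8 quantization in one call; every
    # output is allocated by the op itself.
    out_row, out_col, row_stats, col_stats, outlier_cols = torch.ops.bitsandbytes.int8_double_quant.default(
        A, threshold=6.0
    )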
@@ -156,12 +163,39 @@ def _(
 
 @register_fake("bitsandbytes::dequantize_4bit")
 def _(
-    A: torch.Tensor, absmax: torch.Tensor, blocksize: int, quant_type: str, shape: Sequence[int], dtype: torch.dtype
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
 ) -> torch.Tensor:
     torch._check_is_size(blocksize)
     return torch.empty(shape, dtype=dtype, device=A.device)
 
 
+torch.library.define(
+    "bitsandbytes::dequantize_4bit.out",
+    "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::dequantize_4bit.out")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+    out: torch.Tensor,
+) -> None:
+    torch._check_is_size(blocksize)
+    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+
+
 torch.library.define(
     "bitsandbytes::quantize_4bit",
     "(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)",
@@ -194,6 +228,23 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
     return torch.empty_like(A, dtype=dtype)
 
 
+torch.library.define(
+    "bitsandbytes::dequantize_blockwise.out",
+    "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::dequantize_blockwise.out")
+def _(
+    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
+):
+    torch._check_is_size(blocksize)
+    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+
+
 torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")
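The blockwise `.out` fake re-validates the buffer as well as the inputs, so a mismatched shape, device, or dtype surfaces at trace time instead of inside a kernel. Sketch on the meta device, shapes illustrative:

    import torch

    A = torch.empty(4, 64, dtype=torch.uint8, device="meta")
    absmax = torch.empty(4, dtype=torch.float32, device="meta")
    code = torch.empty(256, dtype=torch.float32, device="meta")

    # A wrong dtype on `out` trips the fake's torch._check before any kernel runs.
    bad = torch.empty(4, 64, dtype=torch.float32, device="meta")
    try:
        torch.ops.bitsandbytes.dequantize_blockwise.out(A, absmax, code, 64, torch.float16, bad)
    except RuntimeError as e:
        print(e)  # Expected out.dtype == torch.float16, got torch.float32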
@@ -229,3 +280,37 @@ def _(
     )
     shape = (*A.shape[:-1], shapeB[0])
     return torch.empty(shape, device=A.device, dtype=A.dtype)
+
+
+torch.library.define(
+    "bitsandbytes::gemv_4bit.out",
+    "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::gemv_4bit.out")
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    shapeB: Sequence[int],
+    absmax: torch.Tensor,
+    code: torch.Tensor,
+    blocksize: int,
+    out: torch.Tensor,
+) -> None:
+    torch._check_is_size(blocksize)
+    torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
+    torch._check(
+        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
+        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
+    )
+    torch._check(
+        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
+    )
+    torch._check(
+        out.shape == (*A.shape[:-1], shapeB[0]),
+        lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
+    )
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
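The gemv fake is the strictest of the new `.out` variants: `A` must be a single vector (all leading dimensions 1), `B` is typically uint8-backed packed 4-bit storage whose logical shape arrives via `shapeB`, and `out` must already match `(*A.shape[:-1], shapeB[0])` in shape, device, and dtype. The shapes below are illustrative and exercise only the fake on the meta device:

    import torch

    A = torch.empty(1, 1, 128, dtype=torch.float16, device="meta")  # leading dims of 1
    B = torch.empty(256 * 128 // 2, 1, dtype=torch.uint8, device="meta")  # packed 4-bit storage
    absmax = torch.empty(512, dtype=torch.float32, device="meta")  # one scale per 64-element block
    code = torch.empty(16, dtype=torch.float32, device="meta")
    out = torch.empty(1, 1, 256, dtype=torch.float16, device="meta")  # (*A.shape[:-1], shapeB[0])

    torch.ops.bitsandbytes.gemv_4bit.out(A, B, [256, 128], absmax, code, 64, out)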