@@ -4,97 +4,108 @@ use std::slice;
4
4
use tokenizers:: tokenizer:: Tokenizer ;
5
5
use tokenizers:: Encoding ;
6
6
7
+ // #[inline(always)] is used aggressively - Realistically we only have a few callsites.
8
+
9
+ #[ repr( C ) ]
10
+ union RawPointer < T >
11
+ {
12
+ mutable : * mut T ,
13
+ readonly : * const T ,
14
+ }
15
+
7
16
#[ repr( C ) ]
8
- pub struct Buffer < T >
17
+ pub struct NativeBuffer < T >
9
18
{
10
- pub ptr : * mut T ,
19
+ pub ptr : RawPointer < T > ,
11
20
pub length : usize ,
12
21
}
13
22
14
- impl < T > Buffer < T >
23
+ impl < T > NativeBuffer < T >
15
24
{
16
- pub fn new ( ptr : * mut T , length : usize ) -> Self
25
+ #[ inline( always) ]
26
+ pub fn wrap_mutable_ptr ( ptr : * mut T , length : usize ) -> Self
17
27
{
18
- Buffer
28
+ NativeBuffer
19
29
{
20
- ptr,
30
+ ptr : RawPointer { mutable : ptr } ,
21
31
length,
22
32
}
23
33
}
24
34
25
- pub fn from_slice ( slice : & mut [ T ] ) -> Self
35
+ #[ inline( always) ]
36
+ pub fn wrap_ptr ( ptr : * const T , length : usize ) -> Self
26
37
{
27
- Buffer
38
+ NativeBuffer
28
39
{
29
- ptr : slice . as_mut_ptr ( ) ,
30
- length : slice . len ( ) ,
40
+ ptr : RawPointer { readonly : ptr } ,
41
+ length,
31
42
}
32
43
}
33
44
34
- pub unsafe fn to_slice ( & self ) -> & mut [ T ]
45
+ #[ inline( always) ]
46
+ pub fn from_slice ( slice : & [ T ] ) -> Self
35
47
{
36
- return slice:: from_raw_parts_mut ( self . ptr , self . length )
48
+ NativeBuffer
49
+ {
50
+ ptr : RawPointer { readonly : slice. as_ptr ( ) } ,
51
+ length : slice. len ( ) ,
52
+ }
37
53
}
38
54
39
- pub fn empty ( ) -> Self
55
+ #[ inline( always) ]
56
+ pub fn from_mutable_slice ( slice : & mut [ T ] ) -> Self
40
57
{
41
- Buffer
58
+ NativeBuffer
42
59
{
43
- ptr : null_mut ( ) ,
44
- length : 0 ,
60
+ ptr : RawPointer { mutable : slice . as_mut_ptr ( ) } ,
61
+ length : slice . len ( ) ,
45
62
}
46
63
}
47
- }
48
-
49
- #[ repr( C ) ]
50
- pub struct ReadOnlyBuffer < T >
51
- {
52
- ptr : * const T ,
53
- pub length : usize ,
54
- }
55
64
56
- impl < T > ReadOnlyBuffer < T >
57
- {
58
- pub fn new ( ptr : * const T , length : usize ) -> Self
65
+ #[ inline( always) ]
66
+ pub unsafe fn as_slice ( & self ) -> & [ T ]
59
67
{
60
- ReadOnlyBuffer
61
- {
62
- ptr,
63
- length,
64
- }
68
+ return slice:: from_raw_parts ( self . ptr . readonly , self . length )
65
69
}
66
70
67
- pub fn from_slice ( slice : & [ T ] ) -> Self
71
+ #[ inline( always) ]
72
+ pub unsafe fn as_mutable_slice ( & self ) -> & mut [ T ]
68
73
{
69
- ReadOnlyBuffer
70
- {
71
- ptr : slice. as_ptr ( ) ,
72
- length : slice. len ( ) ,
73
- }
74
+ return slice:: from_raw_parts_mut ( self . ptr . mutable , self . length )
74
75
}
75
76
76
- pub unsafe fn as_slice ( & self ) -> & [ T ]
77
+ #[ inline( always) ]
78
+ pub fn from_vec ( vec : & Vec < T > ) -> Self
77
79
{
78
- return slice:: from_raw_parts ( self . ptr , self . length )
80
+ let ptr = vec. as_ptr ( ) ;
81
+ let length = vec. len ( ) ;
82
+
83
+ return NativeBuffer
84
+ {
85
+ ptr : RawPointer { readonly : ptr } ,
86
+ length,
87
+ }
79
88
}
80
89
81
- pub fn from_vec ( vec : & mut Vec < T > ) -> Self
90
+ #[ inline( always) ]
91
+ pub fn from_mutable_vec ( vec : & mut Vec < T > ) -> Self
82
92
{
83
93
let ptr = vec. as_mut_ptr ( ) ;
84
94
let length = vec. len ( ) ;
85
95
86
- ReadOnlyBuffer
96
+ return NativeBuffer
87
97
{
88
- ptr,
98
+ ptr : RawPointer { mutable : ptr } ,
89
99
length,
90
100
}
91
101
}
92
102
103
+ #[ inline( always) ]
93
104
pub fn empty ( ) -> Self
94
105
{
95
- ReadOnlyBuffer
106
+ NativeBuffer
96
107
{
97
- ptr : null ( ) ,
108
+ ptr : RawPointer { mutable : null_mut ( ) } ,
98
109
length : 0 ,
99
110
}
100
111
}
@@ -109,6 +120,7 @@ pub struct DropHandle<T=()>
109
120
110
121
impl < T > DropHandle < T >
111
122
{
123
+ #[ inline( always) ]
112
124
pub unsafe fn from_value_and_allocate_box ( value : T ) -> * mut DropHandle < T >
113
125
{
114
126
let val_box = Box :: new ( value) ;
@@ -131,6 +143,7 @@ impl <T> DropHandle<T>
131
143
return Box :: into_raw ( handle) ;
132
144
}
133
145
146
+ #[ inline( always) ]
134
147
pub unsafe fn from_handle ( handle : * mut DropHandle < T > ) -> Box < DropHandle < T > >
135
148
{
136
149
return Box :: from_raw ( handle) ;
@@ -140,10 +153,10 @@ impl <T> DropHandle<T>
140
153
#[ repr( C ) ]
141
154
pub struct TokenizeOutput
142
155
{
143
- pub ids : ReadOnlyBuffer < u32 > ,
144
- pub attention_mask : ReadOnlyBuffer < u32 > ,
145
- pub special_tokens_mask : ReadOnlyBuffer < u32 > ,
146
- pub overflowing_tokens : ReadOnlyBuffer < TokenizeOutputOverflowedToken > ,
156
+ pub ids : NativeBuffer < u32 > ,
157
+ pub attention_mask : NativeBuffer < u32 > ,
158
+ pub special_tokens_mask : NativeBuffer < u32 > ,
159
+ pub overflowing_tokens : NativeBuffer < TokenizeOutputOverflowedToken > ,
147
160
pub original_output_free_handle : * const DropHandle < Encoding > ,
148
161
pub overflowing_tokens_free_handle : * const DropHandle < Vec < TokenizeOutputOverflowedToken > > ,
149
162
}
@@ -155,13 +168,13 @@ impl TokenizeOutput
155
168
{
156
169
// println!("Offsets {:?}", encoded_tokens.get_offsets());
157
170
158
- let ids = ReadOnlyBuffer :: from_slice ( encoded_tokens. get_ids ( ) ) ;
159
- let attention_mask = ReadOnlyBuffer :: from_slice ( encoded_tokens. get_attention_mask ( ) ) ;
160
- let special_tokens_mask = ReadOnlyBuffer :: from_slice ( encoded_tokens. get_special_tokens_mask ( ) ) ;
171
+ let ids = NativeBuffer :: from_slice ( encoded_tokens. get_ids ( ) ) ;
172
+ let attention_mask = NativeBuffer :: from_slice ( encoded_tokens. get_attention_mask ( ) ) ;
173
+ let special_tokens_mask = NativeBuffer :: from_slice ( encoded_tokens. get_special_tokens_mask ( ) ) ;
161
174
162
175
let overflowing_tokens_slice = encoded_tokens. get_overflowing ( ) ;
163
176
164
- let overflowing_tokens: ReadOnlyBuffer < TokenizeOutputOverflowedToken > ;
177
+ let overflowing_tokens: NativeBuffer < TokenizeOutputOverflowedToken > ;
165
178
166
179
let overflowing_tokens_free_handle: * const DropHandle < Vec < TokenizeOutputOverflowedToken > > ;
167
180
@@ -175,7 +188,7 @@ impl TokenizeOutput
175
188
176
189
// println!("Overflowing tokens: {:?}", overflowing_tokens.as_slice().len());
177
190
178
- overflowing_tokens = ReadOnlyBuffer :: from_vec ( & mut overflowing_tokens_vec) ;
191
+ overflowing_tokens = NativeBuffer :: from_mutable_vec ( & mut overflowing_tokens_vec) ;
179
192
180
193
overflowing_tokens_free_handle = DropHandle :: from_value_and_allocate_box (
181
194
overflowing_tokens_vec
@@ -184,7 +197,7 @@ impl TokenizeOutput
184
197
185
198
else
186
199
{
187
- overflowing_tokens = ReadOnlyBuffer :: empty ( ) ;
200
+ overflowing_tokens = NativeBuffer :: empty ( ) ;
188
201
overflowing_tokens_free_handle = null ( ) ;
189
202
}
190
203
@@ -206,19 +219,19 @@ impl TokenizeOutput
206
219
#[ repr( C ) ]
207
220
pub struct TokenizeOutputOverflowedToken
208
221
{
209
- pub ids : ReadOnlyBuffer < u32 > ,
210
- pub attention_mask : ReadOnlyBuffer < u32 > ,
211
- pub special_tokens_mask : ReadOnlyBuffer < u32 > ,
222
+ pub ids : NativeBuffer < u32 > ,
223
+ pub attention_mask : NativeBuffer < u32 > ,
224
+ pub special_tokens_mask : NativeBuffer < u32 > ,
212
225
}
213
226
214
227
impl TokenizeOutputOverflowedToken
215
228
{
216
229
#[ inline( always) ]
217
230
pub unsafe fn from_overflowing_encoded_tokens ( encoded_tokens : & Encoding ) -> Self
218
231
{
219
- let ids = ReadOnlyBuffer :: from_slice ( encoded_tokens. get_ids ( ) ) ;
220
- let attention_mask = ReadOnlyBuffer :: from_slice ( encoded_tokens. get_attention_mask ( ) ) ;
221
- let special_tokens_mask = ReadOnlyBuffer :: from_slice ( encoded_tokens. get_special_tokens_mask ( ) ) ;
232
+ let ids = NativeBuffer :: from_slice ( encoded_tokens. get_ids ( ) ) ;
233
+ let attention_mask = NativeBuffer :: from_slice ( encoded_tokens. get_attention_mask ( ) ) ;
234
+ let special_tokens_mask = NativeBuffer :: from_slice ( encoded_tokens. get_special_tokens_mask ( ) ) ;
222
235
223
236
return TokenizeOutputOverflowedToken
224
237
{
@@ -231,7 +244,7 @@ impl TokenizeOutputOverflowedToken
231
244
232
245
#[ no_mangle]
233
246
pub unsafe extern "C" fn allocate_tokenizer (
234
- json_bytes : ReadOnlyBuffer < u8 > )
247
+ json_bytes : NativeBuffer < u8 > )
235
248
-> * mut Tokenizer
236
249
{
237
250
let json_bytes = json_bytes. as_slice ( ) ;
@@ -252,7 +265,7 @@ pub unsafe extern "C" fn free_tokenizer(tokenizer_handle: *mut Tokenizer)
252
265
#[ no_mangle]
253
266
pub unsafe extern "C" fn tokenizer_encode (
254
267
tokenizer_ptr : * mut Tokenizer ,
255
- text_buffer : ReadOnlyBuffer < u8 > ,
268
+ text_buffer : NativeBuffer < u8 > ,
256
269
add_special_tokens : bool )
257
270
-> TokenizeOutput
258
271
{
@@ -267,7 +280,7 @@ pub unsafe extern "C" fn tokenizer_encode(
267
280
#[ no_mangle]
268
281
pub unsafe extern "C" fn tokenizer_encode_non_truncating (
269
282
tokenizer_ptr : * mut Tokenizer ,
270
- text_buffer : ReadOnlyBuffer < u8 > ,
283
+ text_buffer : NativeBuffer < u8 > ,
271
284
add_special_tokens : bool )
272
285
-> TokenizeOutput
273
286
{
@@ -282,7 +295,7 @@ pub unsafe extern "C" fn tokenizer_encode_non_truncating(
282
295
#[ inline( always) ]
283
296
pub unsafe extern "C" fn tokenizer_encode_core (
284
297
tokenizer_ptr : * mut Tokenizer ,
285
- text_buffer : ReadOnlyBuffer < u8 > ,
298
+ text_buffer : NativeBuffer < u8 > ,
286
299
truncate : bool ,
287
300
add_special_tokens : bool )
288
301
-> TokenizeOutput
@@ -305,8 +318,8 @@ pub unsafe extern "C" fn tokenizer_encode_core(
305
318
#[ no_mangle]
306
319
pub unsafe extern "C" fn tokenizer_encode_batch (
307
320
tokenizer_ptr : * mut Tokenizer ,
308
- text_buffers : ReadOnlyBuffer < ReadOnlyBuffer < u8 > > ,
309
- output_buffer : Buffer < TokenizeOutput > ,
321
+ text_buffers : NativeBuffer < NativeBuffer < u8 > > ,
322
+ output_buffer : NativeBuffer < TokenizeOutput > ,
310
323
add_special_tokens : bool )
311
324
{
312
325
tokenizer_encode_batch_core (
@@ -321,8 +334,8 @@ pub unsafe extern "C" fn tokenizer_encode_batch(
321
334
#[ no_mangle]
322
335
pub unsafe extern "C" fn tokenizer_encode_batch_non_truncating (
323
336
tokenizer_ptr : * mut Tokenizer ,
324
- text_buffers : ReadOnlyBuffer < ReadOnlyBuffer < u8 > > ,
325
- output_buffer : Buffer < TokenizeOutput > ,
337
+ text_buffers : NativeBuffer < NativeBuffer < u8 > > ,
338
+ output_buffer : NativeBuffer < TokenizeOutput > ,
326
339
add_special_tokens : bool )
327
340
{
328
341
tokenizer_encode_batch_core (
@@ -337,8 +350,8 @@ pub unsafe extern "C" fn tokenizer_encode_batch_non_truncating(
337
350
#[ inline( always) ]
338
351
pub unsafe extern "C" fn tokenizer_encode_batch_core (
339
352
tokenizer_ptr : * mut Tokenizer ,
340
- text_buffers : ReadOnlyBuffer < ReadOnlyBuffer < u8 > > ,
341
- output_buffer : Buffer < TokenizeOutput > ,
353
+ text_buffers : NativeBuffer < NativeBuffer < u8 > > ,
354
+ output_buffer : NativeBuffer < TokenizeOutput > ,
342
355
truncate : bool ,
343
356
add_special_tokens : bool )
344
357
{
@@ -355,12 +368,10 @@ pub unsafe extern "C" fn tokenizer_encode_batch_core(
355
368
let encoded_tokens = match encoded_result
356
369
{
357
370
Ok ( encoded) => encoded,
358
- Err ( err ) => panic ! ( "{}" , err ) ,
371
+ Err ( error ) => panic ! ( "{}" , error ) ,
359
372
} ;
360
373
361
- let mut current_ptr = output_buffer. ptr ;
362
-
363
- // println!("{:?}", current_ptr);
374
+ let mut current_ptr = output_buffer. ptr . mutable ;
364
375
365
376
for encoded_token in encoded_tokens
366
377
{
@@ -373,7 +384,7 @@ pub unsafe extern "C" fn tokenizer_encode_batch_core(
373
384
#[ repr( C ) ]
374
385
pub struct DecodeOutput
375
386
{
376
- pub text_buffer : ReadOnlyBuffer < u8 > ,
387
+ pub text_buffer : NativeBuffer < u8 > ,
377
388
pub free_handle : * mut DropHandle < String >
378
389
}
379
390
@@ -384,7 +395,7 @@ impl DecodeOutput
384
395
{
385
396
let text_bytes = text. as_mut_vec ( ) ;
386
397
387
- let text_buffer = ReadOnlyBuffer :: from_vec ( text_bytes) ;
398
+ let text_buffer = NativeBuffer :: from_mutable_vec ( text_bytes) ;
388
399
389
400
let free_handle = DropHandle :: from_value_and_allocate_box ( text) ;
390
401
@@ -399,7 +410,7 @@ impl DecodeOutput
399
410
#[ no_mangle]
400
411
pub unsafe extern "C" fn tokenizer_decode (
401
412
tokenizer_ptr : * mut Tokenizer ,
402
- id_buffer : ReadOnlyBuffer < u32 > )
413
+ id_buffer : NativeBuffer < u32 > )
403
414
-> DecodeOutput
404
415
{
405
416
return tokenizer_decode_core ( tokenizer_ptr, id_buffer, false ) ;
@@ -408,7 +419,7 @@ pub unsafe extern "C" fn tokenizer_decode(
408
419
#[ no_mangle]
409
420
pub unsafe extern "C" fn tokenizer_decode_skip_special_tokens (
410
421
tokenizer_ptr : * mut Tokenizer ,
411
- id_buffer : ReadOnlyBuffer < u32 > )
422
+ id_buffer : NativeBuffer < u32 > )
412
423
-> DecodeOutput
413
424
{
414
425
return tokenizer_decode_core ( tokenizer_ptr, id_buffer, true ) ;
@@ -417,7 +428,7 @@ pub unsafe extern "C" fn tokenizer_decode_skip_special_tokens(
417
428
#[ inline( always) ]
418
429
pub unsafe extern "C" fn tokenizer_decode_core (
419
430
tokenizer_ptr : * mut Tokenizer ,
420
- id_buffer : ReadOnlyBuffer < u32 > ,
431
+ id_buffer : NativeBuffer < u32 > ,
421
432
skip_special_tokens : bool )
422
433
-> DecodeOutput
423
434
{
@@ -434,15 +445,15 @@ pub unsafe extern "C" fn free_with_handle(handle: *mut DropHandle<()>)
434
445
{
435
446
let free_data = DropHandle :: from_handle ( handle) ;
436
447
437
- // println!("Freeing memory at {:p}", free_data.ptr_to_box);
448
+ // println!("Freeing memory at {:p}", free_data.ptr_to_box);
438
449
439
450
let drop_callback = free_data. drop_callback ;
440
451
441
452
drop_callback ( free_data. ptr_to_box ) ;
442
453
}
443
454
444
455
#[ no_mangle]
445
- pub unsafe extern "C" fn free_with_multiple_handles ( handle : ReadOnlyBuffer < * mut DropHandle < ( ) > > )
456
+ pub unsafe extern "C" fn free_with_multiple_handles ( handle : NativeBuffer < * mut DropHandle < ( ) > > )
446
457
{
447
458
for free_data in handle. as_slice ( )
448
459
{
0 commit comments