Skip to content

Commit 32d6764

Browse files
committed
Get rid of readonly variant of NativeBuffer.cs, it is kinda pointless and cumbersome to maintain both structures
1 parent e3ff261 commit 32d6764

10 files changed

+147
-212
lines changed

Native/src/lib.rs

Lines changed: 93 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -4,97 +4,108 @@ use std::slice;
44
use tokenizers::tokenizer::Tokenizer;
55
use tokenizers::Encoding;
66

7+
// #[inline(always)] is used aggressively - Realistically we only have a few callsites.
8+
9+
#[repr(C)]
10+
union RawPointer<T>
11+
{
12+
mutable: *mut T,
13+
readonly: *const T,
14+
}
15+
716
#[repr(C)]
8-
pub struct Buffer<T>
17+
pub struct NativeBuffer<T>
918
{
10-
pub ptr: *mut T,
19+
pub ptr: RawPointer<T>,
1120
pub length: usize,
1221
}
1322

14-
impl<T> Buffer<T>
23+
impl<T> NativeBuffer<T>
1524
{
16-
pub fn new(ptr: *mut T, length: usize) -> Self
25+
#[inline(always)]
26+
pub fn wrap_mutable_ptr(ptr: *mut T, length: usize) -> Self
1727
{
18-
Buffer
28+
NativeBuffer
1929
{
20-
ptr,
30+
ptr: RawPointer { mutable: ptr },
2131
length,
2232
}
2333
}
2434

25-
pub fn from_slice(slice: &mut [T]) -> Self
35+
#[inline(always)]
36+
pub fn wrap_ptr(ptr: *const T, length: usize) -> Self
2637
{
27-
Buffer
38+
NativeBuffer
2839
{
29-
ptr: slice.as_mut_ptr(),
30-
length: slice.len(),
40+
ptr: RawPointer { readonly: ptr },
41+
length,
3142
}
3243
}
3344

34-
pub unsafe fn to_slice(&self) -> &mut [T]
45+
#[inline(always)]
46+
pub fn from_slice(slice: &[T]) -> Self
3547
{
36-
return slice::from_raw_parts_mut(self.ptr, self.length)
48+
NativeBuffer
49+
{
50+
ptr: RawPointer { readonly: slice.as_ptr() },
51+
length: slice.len(),
52+
}
3753
}
3854

39-
pub fn empty() -> Self
55+
#[inline(always)]
56+
pub fn from_mutable_slice(slice: &mut [T]) -> Self
4057
{
41-
Buffer
58+
NativeBuffer
4259
{
43-
ptr: null_mut(),
44-
length: 0,
60+
ptr: RawPointer { mutable: slice.as_mut_ptr() },
61+
length: slice.len(),
4562
}
4663
}
47-
}
48-
49-
#[repr(C)]
50-
pub struct ReadOnlyBuffer<T>
51-
{
52-
ptr: *const T,
53-
pub length: usize,
54-
}
5564

56-
impl<T> ReadOnlyBuffer<T>
57-
{
58-
pub fn new(ptr: *const T, length: usize) -> Self
65+
#[inline(always)]
66+
pub unsafe fn as_slice(&self) -> &[T]
5967
{
60-
ReadOnlyBuffer
61-
{
62-
ptr,
63-
length,
64-
}
68+
return slice::from_raw_parts(self.ptr.readonly, self.length)
6569
}
6670

67-
pub fn from_slice(slice: &[T]) -> Self
71+
#[inline(always)]
72+
pub unsafe fn as_mutable_slice(&self) -> &mut [T]
6873
{
69-
ReadOnlyBuffer
70-
{
71-
ptr: slice.as_ptr(),
72-
length: slice.len(),
73-
}
74+
return slice::from_raw_parts_mut(self.ptr.mutable, self.length)
7475
}
7576

76-
pub unsafe fn as_slice(&self) -> &[T]
77+
#[inline(always)]
78+
pub fn from_vec(vec: &Vec<T>) -> Self
7779
{
78-
return slice::from_raw_parts(self.ptr, self.length)
80+
let ptr = vec.as_ptr();
81+
let length = vec.len();
82+
83+
return NativeBuffer
84+
{
85+
ptr: RawPointer { readonly: ptr },
86+
length,
87+
}
7988
}
8089

81-
pub fn from_vec(vec: &mut Vec<T>) -> Self
90+
#[inline(always)]
91+
pub fn from_mutable_vec(vec: &mut Vec<T>) -> Self
8292
{
8393
let ptr = vec.as_mut_ptr();
8494
let length = vec.len();
8595

86-
ReadOnlyBuffer
96+
return NativeBuffer
8797
{
88-
ptr,
98+
ptr: RawPointer { mutable: ptr },
8999
length,
90100
}
91101
}
92102

103+
#[inline(always)]
93104
pub fn empty() -> Self
94105
{
95-
ReadOnlyBuffer
106+
NativeBuffer
96107
{
97-
ptr: null(),
108+
ptr: RawPointer { mutable: null_mut() },
98109
length: 0,
99110
}
100111
}
@@ -109,6 +120,7 @@ pub struct DropHandle<T=()>
109120

110121
impl <T> DropHandle<T>
111122
{
123+
#[inline(always)]
112124
pub unsafe fn from_value_and_allocate_box(value: T) -> *mut DropHandle<T>
113125
{
114126
let val_box = Box::new(value);
@@ -131,6 +143,7 @@ impl <T> DropHandle<T>
131143
return Box::into_raw(handle);
132144
}
133145

146+
#[inline(always)]
134147
pub unsafe fn from_handle(handle: *mut DropHandle<T>) -> Box<DropHandle<T>>
135148
{
136149
return Box::from_raw(handle);
@@ -140,10 +153,10 @@ impl <T> DropHandle<T>
140153
#[repr(C)]
141154
pub struct TokenizeOutput
142155
{
143-
pub ids: ReadOnlyBuffer<u32>,
144-
pub attention_mask: ReadOnlyBuffer<u32>,
145-
pub special_tokens_mask: ReadOnlyBuffer<u32>,
146-
pub overflowing_tokens: ReadOnlyBuffer<TokenizeOutputOverflowedToken>,
156+
pub ids: NativeBuffer<u32>,
157+
pub attention_mask: NativeBuffer<u32>,
158+
pub special_tokens_mask: NativeBuffer<u32>,
159+
pub overflowing_tokens: NativeBuffer<TokenizeOutputOverflowedToken>,
147160
pub original_output_free_handle: *const DropHandle<Encoding>,
148161
pub overflowing_tokens_free_handle: *const DropHandle<Vec<TokenizeOutputOverflowedToken>>,
149162
}
@@ -155,13 +168,13 @@ impl TokenizeOutput
155168
{
156169
// println!("Offsets {:?}", encoded_tokens.get_offsets());
157170

158-
let ids = ReadOnlyBuffer::from_slice(encoded_tokens.get_ids());
159-
let attention_mask = ReadOnlyBuffer::from_slice(encoded_tokens.get_attention_mask());
160-
let special_tokens_mask = ReadOnlyBuffer::from_slice(encoded_tokens.get_special_tokens_mask());
171+
let ids = NativeBuffer::from_slice(encoded_tokens.get_ids());
172+
let attention_mask = NativeBuffer::from_slice(encoded_tokens.get_attention_mask());
173+
let special_tokens_mask = NativeBuffer::from_slice(encoded_tokens.get_special_tokens_mask());
161174

162175
let overflowing_tokens_slice = encoded_tokens.get_overflowing();
163176

164-
let overflowing_tokens: ReadOnlyBuffer<TokenizeOutputOverflowedToken>;
177+
let overflowing_tokens: NativeBuffer<TokenizeOutputOverflowedToken>;
165178

166179
let overflowing_tokens_free_handle: *const DropHandle<Vec<TokenizeOutputOverflowedToken>>;
167180

@@ -175,7 +188,7 @@ impl TokenizeOutput
175188

176189
// println!("Overflowing tokens: {:?}", overflowing_tokens.as_slice().len());
177190

178-
overflowing_tokens = ReadOnlyBuffer::from_vec(&mut overflowing_tokens_vec);
191+
overflowing_tokens = NativeBuffer::from_mutable_vec(&mut overflowing_tokens_vec);
179192

180193
overflowing_tokens_free_handle = DropHandle::from_value_and_allocate_box(
181194
overflowing_tokens_vec
@@ -184,7 +197,7 @@ impl TokenizeOutput
184197

185198
else
186199
{
187-
overflowing_tokens = ReadOnlyBuffer::empty();
200+
overflowing_tokens = NativeBuffer::empty();
188201
overflowing_tokens_free_handle = null();
189202
}
190203

@@ -206,19 +219,19 @@ impl TokenizeOutput
206219
#[repr(C)]
207220
pub struct TokenizeOutputOverflowedToken
208221
{
209-
pub ids: ReadOnlyBuffer<u32>,
210-
pub attention_mask: ReadOnlyBuffer<u32>,
211-
pub special_tokens_mask: ReadOnlyBuffer<u32>,
222+
pub ids: NativeBuffer<u32>,
223+
pub attention_mask: NativeBuffer<u32>,
224+
pub special_tokens_mask: NativeBuffer<u32>,
212225
}
213226

214227
impl TokenizeOutputOverflowedToken
215228
{
216229
#[inline(always)]
217230
pub unsafe fn from_overflowing_encoded_tokens(encoded_tokens: &Encoding) -> Self
218231
{
219-
let ids = ReadOnlyBuffer::from_slice(encoded_tokens.get_ids());
220-
let attention_mask = ReadOnlyBuffer::from_slice(encoded_tokens.get_attention_mask());
221-
let special_tokens_mask = ReadOnlyBuffer::from_slice(encoded_tokens.get_special_tokens_mask());
232+
let ids = NativeBuffer::from_slice(encoded_tokens.get_ids());
233+
let attention_mask = NativeBuffer::from_slice(encoded_tokens.get_attention_mask());
234+
let special_tokens_mask = NativeBuffer::from_slice(encoded_tokens.get_special_tokens_mask());
222235

223236
return TokenizeOutputOverflowedToken
224237
{
@@ -231,7 +244,7 @@ impl TokenizeOutputOverflowedToken
231244

232245
#[no_mangle]
233246
pub unsafe extern "C" fn allocate_tokenizer(
234-
json_bytes: ReadOnlyBuffer<u8>)
247+
json_bytes: NativeBuffer<u8>)
235248
-> *mut Tokenizer
236249
{
237250
let json_bytes = json_bytes.as_slice();
@@ -252,7 +265,7 @@ pub unsafe extern "C" fn free_tokenizer(tokenizer_handle: *mut Tokenizer)
252265
#[no_mangle]
253266
pub unsafe extern "C" fn tokenizer_encode(
254267
tokenizer_ptr: *mut Tokenizer,
255-
text_buffer: ReadOnlyBuffer<u8>,
268+
text_buffer: NativeBuffer<u8>,
256269
add_special_tokens: bool)
257270
-> TokenizeOutput
258271
{
@@ -267,7 +280,7 @@ pub unsafe extern "C" fn tokenizer_encode(
267280
#[no_mangle]
268281
pub unsafe extern "C" fn tokenizer_encode_non_truncating(
269282
tokenizer_ptr: *mut Tokenizer,
270-
text_buffer: ReadOnlyBuffer<u8>,
283+
text_buffer: NativeBuffer<u8>,
271284
add_special_tokens: bool)
272285
-> TokenizeOutput
273286
{
@@ -282,7 +295,7 @@ pub unsafe extern "C" fn tokenizer_encode_non_truncating(
282295
#[inline(always)]
283296
pub unsafe extern "C" fn tokenizer_encode_core(
284297
tokenizer_ptr: *mut Tokenizer,
285-
text_buffer: ReadOnlyBuffer<u8>,
298+
text_buffer: NativeBuffer<u8>,
286299
truncate: bool,
287300
add_special_tokens: bool)
288301
-> TokenizeOutput
@@ -305,8 +318,8 @@ pub unsafe extern "C" fn tokenizer_encode_core(
305318
#[no_mangle]
306319
pub unsafe extern "C" fn tokenizer_encode_batch(
307320
tokenizer_ptr: *mut Tokenizer,
308-
text_buffers: ReadOnlyBuffer<ReadOnlyBuffer<u8>>,
309-
output_buffer: Buffer<TokenizeOutput>,
321+
text_buffers: NativeBuffer<NativeBuffer<u8>>,
322+
output_buffer: NativeBuffer<TokenizeOutput>,
310323
add_special_tokens: bool)
311324
{
312325
tokenizer_encode_batch_core(
@@ -321,8 +334,8 @@ pub unsafe extern "C" fn tokenizer_encode_batch(
321334
#[no_mangle]
322335
pub unsafe extern "C" fn tokenizer_encode_batch_non_truncating(
323336
tokenizer_ptr: *mut Tokenizer,
324-
text_buffers: ReadOnlyBuffer<ReadOnlyBuffer<u8>>,
325-
output_buffer: Buffer<TokenizeOutput>,
337+
text_buffers: NativeBuffer<NativeBuffer<u8>>,
338+
output_buffer: NativeBuffer<TokenizeOutput>,
326339
add_special_tokens: bool)
327340
{
328341
tokenizer_encode_batch_core(
@@ -337,8 +350,8 @@ pub unsafe extern "C" fn tokenizer_encode_batch_non_truncating(
337350
#[inline(always)]
338351
pub unsafe extern "C" fn tokenizer_encode_batch_core(
339352
tokenizer_ptr: *mut Tokenizer,
340-
text_buffers: ReadOnlyBuffer<ReadOnlyBuffer<u8>>,
341-
output_buffer: Buffer<TokenizeOutput>,
353+
text_buffers: NativeBuffer<NativeBuffer<u8>>,
354+
output_buffer: NativeBuffer<TokenizeOutput>,
342355
truncate: bool,
343356
add_special_tokens: bool)
344357
{
@@ -355,12 +368,10 @@ pub unsafe extern "C" fn tokenizer_encode_batch_core(
355368
let encoded_tokens = match encoded_result
356369
{
357370
Ok(encoded) => encoded,
358-
Err(err) => panic!("{}", err),
371+
Err(error) => panic!("{}", error),
359372
};
360373

361-
let mut current_ptr = output_buffer.ptr;
362-
363-
// println!("{:?}", current_ptr);
374+
let mut current_ptr = output_buffer.ptr.mutable;
364375

365376
for encoded_token in encoded_tokens
366377
{
@@ -373,7 +384,7 @@ pub unsafe extern "C" fn tokenizer_encode_batch_core(
373384
#[repr(C)]
374385
pub struct DecodeOutput
375386
{
376-
pub text_buffer: ReadOnlyBuffer<u8>,
387+
pub text_buffer: NativeBuffer<u8>,
377388
pub free_handle: *mut DropHandle<String>
378389
}
379390

@@ -384,7 +395,7 @@ impl DecodeOutput
384395
{
385396
let text_bytes = text.as_mut_vec();
386397

387-
let text_buffer = ReadOnlyBuffer::from_vec(text_bytes);
398+
let text_buffer = NativeBuffer::from_mutable_vec(text_bytes);
388399

389400
let free_handle = DropHandle::from_value_and_allocate_box(text);
390401

@@ -399,7 +410,7 @@ impl DecodeOutput
399410
#[no_mangle]
400411
pub unsafe extern "C" fn tokenizer_decode(
401412
tokenizer_ptr: *mut Tokenizer,
402-
id_buffer: ReadOnlyBuffer<u32>)
413+
id_buffer: NativeBuffer<u32>)
403414
-> DecodeOutput
404415
{
405416
return tokenizer_decode_core(tokenizer_ptr, id_buffer, false);
@@ -408,7 +419,7 @@ pub unsafe extern "C" fn tokenizer_decode(
408419
#[no_mangle]
409420
pub unsafe extern "C" fn tokenizer_decode_skip_special_tokens(
410421
tokenizer_ptr: *mut Tokenizer,
411-
id_buffer: ReadOnlyBuffer<u32>)
422+
id_buffer: NativeBuffer<u32>)
412423
-> DecodeOutput
413424
{
414425
return tokenizer_decode_core(tokenizer_ptr, id_buffer, true);
@@ -417,7 +428,7 @@ pub unsafe extern "C" fn tokenizer_decode_skip_special_tokens(
417428
#[inline(always)]
418429
pub unsafe extern "C" fn tokenizer_decode_core(
419430
tokenizer_ptr: *mut Tokenizer,
420-
id_buffer: ReadOnlyBuffer<u32>,
431+
id_buffer: NativeBuffer<u32>,
421432
skip_special_tokens: bool)
422433
-> DecodeOutput
423434
{
@@ -434,15 +445,15 @@ pub unsafe extern "C" fn free_with_handle(handle: *mut DropHandle<()>)
434445
{
435446
let free_data = DropHandle::from_handle(handle);
436447

437-
// println!("Freeing memory at {:p}", free_data.ptr_to_box);
448+
// println!("Freeing memory at {:p}", free_data.ptr_to_box);
438449

439450
let drop_callback = free_data.drop_callback;
440451

441452
drop_callback(free_data.ptr_to_box);
442453
}
443454

444455
#[no_mangle]
445-
pub unsafe extern "C" fn free_with_multiple_handles(handle: ReadOnlyBuffer<*mut DropHandle<()>>)
456+
pub unsafe extern "C" fn free_with_multiple_handles(handle: NativeBuffer<*mut DropHandle<()>>)
446457
{
447458
for free_data in handle.as_slice()
448459
{

0 commit comments

Comments
 (0)