Skip to content
This repository was archived by the owner on Jul 30, 2025. It is now read-only.

Commit d571715

Browse files
authored
Merge pull request #202 from tazz4843/convert-to-string-helpers
Convert `full_get_*_*` methods to use internal helper instead of duplicating code
2 parents 37cba93 + 9a96b0e commit d571715

File tree

1 file changed

+77
-60
lines changed

1 file changed

+77
-60
lines changed

src/whisper_state.rs

Lines changed: 77 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -346,25 +346,43 @@ impl WhisperState {
346346
Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t1_from_state(self.ptr, segment) })
347347
}
348348

349+
fn full_get_segment_raw(&self, segment: c_int) -> Result<&CStr, WhisperError> {
350+
let ret =
351+
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
352+
if ret.is_null() {
353+
return Err(WhisperError::NullPointer);
354+
}
355+
unsafe { Ok(CStr::from_ptr(ret)) }
356+
}
357+
358+
/// Get the raw bytes of the specified segment.
359+
///
360+
/// # Arguments
361+
/// * segment: Segment index.
362+
///
363+
/// # Returns
364+
/// `Ok(Vec<u8>)` on success, with the returned bytes or
365+
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
366+
///
367+
/// # C++ equivalent
368+
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
369+
pub fn full_get_segment_bytes(&self, segment: c_int) -> Result<Vec<u8>, WhisperError> {
370+
Ok(self.full_get_segment_raw(segment)?.to_bytes().to_vec())
371+
}
372+
349373
/// Get the text of the specified segment.
350374
///
351375
/// # Arguments
352376
/// * segment: Segment index.
353377
///
354378
/// # Returns
355-
/// Ok(String) on success, Err(WhisperError) on failure.
379+
/// `Ok(String)` on success, with the UTF-8 validated string, or
380+
/// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
356381
///
357382
/// # C++ equivalent
358383
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
359384
pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> {
360-
let ret =
361-
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
362-
if ret.is_null() {
363-
return Err(WhisperError::NullPointer);
364-
}
365-
let c_str = unsafe { CStr::from_ptr(ret) };
366-
let r_str = c_str.to_str()?;
367-
Ok(r_str.to_string())
385+
Ok(self.full_get_segment_raw(segment)?.to_str()?.to_string())
368386
}
369387

370388
/// Get the text of the specified segment.
@@ -376,53 +394,69 @@ impl WhisperState {
376394
/// * segment: Segment index.
377395
///
378396
/// # Returns
379-
/// Ok(String) on success, Err(WhisperError) on failure.
397+
/// `Ok(String)` on success, or
398+
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
380399
///
381400
/// # C++ equivalent
382401
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
383402
pub fn full_get_segment_text_lossy(&self, segment: c_int) -> Result<String, WhisperError> {
384-
let ret =
385-
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
386-
if ret.is_null() {
387-
return Err(WhisperError::NullPointer);
388-
}
389-
let c_str = unsafe { CStr::from_ptr(ret) };
390-
Ok(c_str.to_string_lossy().to_string())
403+
Ok(self
404+
.full_get_segment_raw(segment)?
405+
.to_string_lossy()
406+
.to_string())
391407
}
392408

393-
/// Get the bytes of the specified segment.
409+
/// Get number of tokens in the specified segment.
394410
///
395411
/// # Arguments
396412
/// * segment: Segment index.
397413
///
398414
/// # Returns
399-
/// `Ok(Vec<u8>)` on success, `Err(WhisperError)` on failure.
415+
/// c_int
400416
///
401417
/// # C++ equivalent
402-
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
403-
pub fn full_get_segment_bytes(&self, segment: c_int) -> Result<Vec<u8>, WhisperError> {
404-
let ret =
405-
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
418+
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
419+
#[inline]
420+
pub fn full_n_tokens(&self, segment: c_int) -> Result<c_int, WhisperError> {
421+
Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) })
422+
}
423+
424+
fn full_get_token_raw(&self, segment: c_int, token: c_int) -> Result<&CStr, WhisperError> {
425+
let ret = unsafe {
426+
whisper_rs_sys::whisper_full_get_token_text_from_state(
427+
self.ctx.ctx,
428+
self.ptr,
429+
segment,
430+
token,
431+
)
432+
};
406433
if ret.is_null() {
407434
return Err(WhisperError::NullPointer);
408435
}
409-
let c_str = unsafe { CStr::from_ptr(ret) };
410-
Ok(c_str.to_bytes().to_vec())
436+
unsafe { Ok(CStr::from_ptr(ret)) }
411437
}
412438

413-
/// Get number of tokens in the specified segment.
439+
/// Get the raw token bytes of the specified token in the specified segment.
440+
///
441+
/// Useful if you're using a language for which whisper is known to split tokens
442+
/// away from UTF-8 character boundaries.
414443
///
415444
/// # Arguments
416445
/// * segment: Segment index.
446+
/// * token: Token index.
417447
///
418448
/// # Returns
419-
/// c_int
449+
/// `Ok(Vec<u8>)` on success, with the returned bytes or
450+
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
420451
///
421452
/// # C++ equivalent
422-
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
423-
#[inline]
424-
pub fn full_n_tokens(&self, segment: c_int) -> Result<c_int, WhisperError> {
425-
Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) })
453+
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
454+
pub fn full_get_token_bytes(
455+
&self,
456+
segment: c_int,
457+
token: c_int,
458+
) -> Result<Vec<u8>, WhisperError> {
459+
Ok(self.full_get_token_raw(segment, token)?.to_bytes().to_vec())
426460
}
427461

428462
/// Get the token text of the specified token in the specified segment.
@@ -432,7 +466,8 @@ impl WhisperState {
432466
/// * token: Token index.
433467
///
434468
/// # Returns
435-
/// Ok(String) on success, Err(WhisperError) on failure.
469+
/// `Ok(String)` on success, with the UTF-8 validated string, or
470+
/// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
436471
///
437472
/// # C++ equivalent
438473
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
@@ -441,20 +476,10 @@ impl WhisperState {
441476
segment: c_int,
442477
token: c_int,
443478
) -> Result<String, WhisperError> {
444-
let ret = unsafe {
445-
whisper_rs_sys::whisper_full_get_token_text_from_state(
446-
self.ctx.ctx,
447-
self.ptr,
448-
segment,
449-
token,
450-
)
451-
};
452-
if ret.is_null() {
453-
return Err(WhisperError::NullPointer);
454-
}
455-
let c_str = unsafe { CStr::from_ptr(ret) };
456-
let r_str = c_str.to_str()?;
457-
Ok(r_str.to_string())
479+
Ok(self
480+
.full_get_token_raw(segment, token)?
481+
.to_str()?
482+
.to_string())
458483
}
459484

460485
/// Get the token text of the specified token in the specified segment.
@@ -467,7 +492,8 @@ impl WhisperState {
467492
/// * token: Token index.
468493
///
469494
/// # Returns
470-
/// Ok(String) on success, Err(WhisperError) on failure.
495+
/// `Ok(String)` on success, or
496+
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
471497
///
472498
/// # C++ equivalent
473499
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
@@ -476,19 +502,10 @@ impl WhisperState {
476502
segment: c_int,
477503
token: c_int,
478504
) -> Result<String, WhisperError> {
479-
let ret = unsafe {
480-
whisper_rs_sys::whisper_full_get_token_text_from_state(
481-
self.ctx.ctx,
482-
self.ptr,
483-
segment,
484-
token,
485-
)
486-
};
487-
if ret.is_null() {
488-
return Err(WhisperError::NullPointer);
489-
}
490-
let c_str = unsafe { CStr::from_ptr(ret) };
491-
Ok(c_str.to_string_lossy().to_string())
505+
Ok(self
506+
.full_get_token_raw(segment, token)?
507+
.to_string_lossy()
508+
.to_string())
492509
}
493510

494511
/// Get the token ID of the specified token in the specified segment.

0 commit comments

Comments
 (0)