@@ -346,25 +346,43 @@ impl WhisperState {
346346 Ok ( unsafe { whisper_rs_sys:: whisper_full_get_segment_t1_from_state ( self . ptr , segment) } )
347347 }
348348
349+ fn full_get_segment_raw ( & self , segment : c_int ) -> Result < & CStr , WhisperError > {
350+ let ret =
351+ unsafe { whisper_rs_sys:: whisper_full_get_segment_text_from_state ( self . ptr , segment) } ;
352+ if ret. is_null ( ) {
353+ return Err ( WhisperError :: NullPointer ) ;
354+ }
355+ unsafe { Ok ( CStr :: from_ptr ( ret) ) }
356+ }
357+
358+ /// Get the raw bytes of the specified segment.
359+ ///
360+ /// # Arguments
361+ /// * segment: Segment index.
362+ ///
363+ /// # Returns
364+ /// `Ok(Vec<u8>)` on success, with the returned bytes or
365+ /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
366+ ///
367+ /// # C++ equivalent
368+ /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
369+ pub fn full_get_segment_bytes ( & self , segment : c_int ) -> Result < Vec < u8 > , WhisperError > {
370+ Ok ( self . full_get_segment_raw ( segment) ?. to_bytes ( ) . to_vec ( ) )
371+ }
372+
349373 /// Get the text of the specified segment.
350374 ///
351375 /// # Arguments
352376 /// * segment: Segment index.
353377 ///
354378 /// # Returns
355- /// Ok(String) on success, Err(WhisperError) on failure.
379+ /// `Ok(String)` on success, with the UTF-8 validated string, or
380+ /// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
356381 ///
357382 /// # C++ equivalent
358383 /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
359384 pub fn full_get_segment_text ( & self , segment : c_int ) -> Result < String , WhisperError > {
360- let ret =
361- unsafe { whisper_rs_sys:: whisper_full_get_segment_text_from_state ( self . ptr , segment) } ;
362- if ret. is_null ( ) {
363- return Err ( WhisperError :: NullPointer ) ;
364- }
365- let c_str = unsafe { CStr :: from_ptr ( ret) } ;
366- let r_str = c_str. to_str ( ) ?;
367- Ok ( r_str. to_string ( ) )
385+ Ok ( self . full_get_segment_raw ( segment) ?. to_str ( ) ?. to_string ( ) )
368386 }
369387
370388 /// Get the text of the specified segment.
@@ -376,53 +394,69 @@ impl WhisperState {
376394 /// * segment: Segment index.
377395 ///
378396 /// # Returns
379- /// Ok(String) on success, Err(WhisperError) on failure.
397+ /// `Ok(String)` on success, or
398+ /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
380399 ///
381400 /// # C++ equivalent
382401 /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
383402 pub fn full_get_segment_text_lossy ( & self , segment : c_int ) -> Result < String , WhisperError > {
384- let ret =
385- unsafe { whisper_rs_sys:: whisper_full_get_segment_text_from_state ( self . ptr , segment) } ;
386- if ret. is_null ( ) {
387- return Err ( WhisperError :: NullPointer ) ;
388- }
389- let c_str = unsafe { CStr :: from_ptr ( ret) } ;
390- Ok ( c_str. to_string_lossy ( ) . to_string ( ) )
403+ Ok ( self
404+ . full_get_segment_raw ( segment) ?
405+ . to_string_lossy ( )
406+ . to_string ( ) )
391407 }
392408
393- /// Get the bytes of the specified segment.
409+ /// Get number of tokens in the specified segment.
394410 ///
395411 /// # Arguments
396412 /// * segment: Segment index.
397413 ///
398414 /// # Returns
399- /// `Ok(Vec<u8>)` on success, `Err(WhisperError)` on failure.
415+ /// c_int
400416 ///
401417 /// # C++ equivalent
402- /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
403- pub fn full_get_segment_bytes ( & self , segment : c_int ) -> Result < Vec < u8 > , WhisperError > {
404- let ret =
405- unsafe { whisper_rs_sys:: whisper_full_get_segment_text_from_state ( self . ptr , segment) } ;
418+ /// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
419+ #[ inline]
420+ pub fn full_n_tokens ( & self , segment : c_int ) -> Result < c_int , WhisperError > {
421+ Ok ( unsafe { whisper_rs_sys:: whisper_full_n_tokens_from_state ( self . ptr , segment) } )
422+ }
423+
424+ fn full_get_token_raw ( & self , segment : c_int , token : c_int ) -> Result < & CStr , WhisperError > {
425+ let ret = unsafe {
426+ whisper_rs_sys:: whisper_full_get_token_text_from_state (
427+ self . ctx . ctx ,
428+ self . ptr ,
429+ segment,
430+ token,
431+ )
432+ } ;
406433 if ret. is_null ( ) {
407434 return Err ( WhisperError :: NullPointer ) ;
408435 }
409- let c_str = unsafe { CStr :: from_ptr ( ret) } ;
410- Ok ( c_str. to_bytes ( ) . to_vec ( ) )
436+ unsafe { Ok ( CStr :: from_ptr ( ret) ) }
411437 }
412438
413- /// Get number of tokens in the specified segment.
439+ /// Get the raw token bytes of the specified token in the specified segment.
440+ ///
441+ /// Useful if you're using a language for which whisper is known to split tokens
442+ /// away from UTF-8 character boundaries.
414443 ///
415444 /// # Arguments
416445 /// * segment: Segment index.
446+ /// * token: Token index.
417447 ///
418448 /// # Returns
419- /// c_int
449+ /// `Ok(Vec<u8>)` on success, with the returned bytes or
450+ /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
420451 ///
421452 /// # C++ equivalent
422- /// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
423- #[ inline]
424- pub fn full_n_tokens ( & self , segment : c_int ) -> Result < c_int , WhisperError > {
425- Ok ( unsafe { whisper_rs_sys:: whisper_full_n_tokens_from_state ( self . ptr , segment) } )
453+ /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
454+ pub fn full_get_token_bytes (
455+ & self ,
456+ segment : c_int ,
457+ token : c_int ,
458+ ) -> Result < Vec < u8 > , WhisperError > {
459+ Ok ( self . full_get_token_raw ( segment, token) ?. to_bytes ( ) . to_vec ( ) )
426460 }
427461
428462 /// Get the token text of the specified token in the specified segment.
@@ -432,7 +466,8 @@ impl WhisperState {
432466 /// * token: Token index.
433467 ///
434468 /// # Returns
435- /// Ok(String) on success, Err(WhisperError) on failure.
469+ /// `Ok(String)` on success, with the UTF-8 validated string, or
470+ /// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
436471 ///
437472 /// # C++ equivalent
438473 /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
@@ -441,20 +476,10 @@ impl WhisperState {
441476 segment : c_int ,
442477 token : c_int ,
443478 ) -> Result < String , WhisperError > {
444- let ret = unsafe {
445- whisper_rs_sys:: whisper_full_get_token_text_from_state (
446- self . ctx . ctx ,
447- self . ptr ,
448- segment,
449- token,
450- )
451- } ;
452- if ret. is_null ( ) {
453- return Err ( WhisperError :: NullPointer ) ;
454- }
455- let c_str = unsafe { CStr :: from_ptr ( ret) } ;
456- let r_str = c_str. to_str ( ) ?;
457- Ok ( r_str. to_string ( ) )
479+ Ok ( self
480+ . full_get_token_raw ( segment, token) ?
481+ . to_str ( ) ?
482+ . to_string ( ) )
458483 }
459484
460485 /// Get the token text of the specified token in the specified segment.
@@ -467,7 +492,8 @@ impl WhisperState {
467492 /// * token: Token index.
468493 ///
469494 /// # Returns
470- /// Ok(String) on success, Err(WhisperError) on failure.
495+ /// `Ok(String)` on success, or
496+ /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
471497 ///
472498 /// # C++ equivalent
473499 /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
@@ -476,19 +502,10 @@ impl WhisperState {
476502 segment : c_int ,
477503 token : c_int ,
478504 ) -> Result < String , WhisperError > {
479- let ret = unsafe {
480- whisper_rs_sys:: whisper_full_get_token_text_from_state (
481- self . ctx . ctx ,
482- self . ptr ,
483- segment,
484- token,
485- )
486- } ;
487- if ret. is_null ( ) {
488- return Err ( WhisperError :: NullPointer ) ;
489- }
490- let c_str = unsafe { CStr :: from_ptr ( ret) } ;
491- Ok ( c_str. to_string_lossy ( ) . to_string ( ) )
505+ Ok ( self
506+ . full_get_token_raw ( segment, token) ?
507+ . to_string_lossy ( )
508+ . to_string ( ) )
492509 }
493510
494511 /// Get the token ID of the specified token in the specified segment.
0 commit comments