@@ -10,17 +10,21 @@ class _TranscriptionThread(QThread):
1010 transcription_done = Signal (str )
1111 error_occurred = Signal (str )
1212
13- def __init__ (self , model , expected_id : int , audio_file : str | Path ) -> None :
13+ def __init__ (self , model , expected_id : int , audio_file : str | Path , task_mode : str = "transcribe" ) -> None :
1414 super ().__init__ ()
1515 self .model = model
1616 self .expected_id = expected_id
1717 self .audio_file = str (audio_file )
18+ self .task_mode = task_mode
1819
1920 def run (self ) -> None :
2021 try :
2122 if id (self .model ) != self .expected_id or self .isInterruptionRequested ():
2223 return
23- segments , _ = self .model .transcribe (self .audio_file )
24+ segments , _ = self .model .transcribe (
25+ self .audio_file ,
26+ task = self .task_mode
27+ )
2428 if self .isInterruptionRequested ():
2529 return
2630 self .transcription_done .emit ("\n " .join (s .text for s in segments ))
@@ -37,9 +41,10 @@ class TranscriptionService(QObject):
3741 transcription_completed = Signal (str )
3842 transcription_error = Signal (str )
3943
40- def __init__ (self , curate_text_enabled : bool = False ):
44+ def __init__ (self , curate_text_enabled : bool = False , task_mode : str = "transcribe" ):
4145 super ().__init__ ()
4246 self .curate_enabled = curate_text_enabled
47+ self .task_mode = task_mode
4348 self ._transcription_thread : Optional [_TranscriptionThread ] = None
4449
4550 def transcribe_file (self , model , expected_id : int , audio_file : str | Path ) -> None :
@@ -48,7 +53,7 @@ def transcribe_file(self, model, expected_id: int, audio_file: str | Path) -> No
4853 return
4954
5055 self ._transcription_thread = _TranscriptionThread (
51- model , expected_id , str (audio_file )
56+ model , expected_id , str (audio_file ), self . task_mode
5257 )
5358 self ._transcription_thread .transcription_done .connect (self ._on_transcription_done )
5459 self ._transcription_thread .error_occurred .connect (self .transcription_error )
@@ -66,6 +71,9 @@ def _on_transcription_done(self, text: str) -> None:
6671 text = "\n " .join (line .lstrip () for line in text .splitlines ())
6772 self .transcription_completed .emit (text )
6873
74+ def set_task_mode (self , mode : str ) -> None :
75+ self .task_mode = mode
76+
6977 def set_curation_enabled (self , enabled : bool ) -> None :
7078 self .curate_enabled = enabled
7179
0 commit comments