@@ -3,12 +3,15 @@ package openai
33import (
44 "context"
55 "fmt"
6+ "os"
7+ "path/filepath"
68
79 "github.com/mudler/LocalAI/core/backend"
810 "github.com/mudler/LocalAI/core/config"
911 grpcClient "github.com/mudler/LocalAI/pkg/grpc"
1012 "github.com/mudler/LocalAI/pkg/grpc/proto"
1113 model "github.com/mudler/LocalAI/pkg/model"
14+ "github.com/mudler/LocalAI/pkg/utils"
1215 "github.com/mudler/xlog"
1316 "google.golang.org/grpc"
1417)
@@ -31,6 +34,7 @@ type wrappedModel struct {
3134
3235 VADConfig * config.ModelConfig
3336 VADClient grpcClient.Backend
37+ appConfig * config.ApplicationConfig
3438}
3539
3640// anyToAnyModel represents a model which supports Any-to-Any operations
@@ -49,6 +53,7 @@ type transcriptOnlyModel struct {
4953 TranscriptionClient grpcClient.Backend
5054 VADConfig * config.ModelConfig
5155 VADClient grpcClient.Backend
56+ appConfig * config.ApplicationConfig
5257}
5358
5459func (m * transcriptOnlyModel ) VAD (ctx context.Context , in * proto.VADRequest , opts ... grpc.CallOption ) (* proto.VADResponse , error ) {
@@ -67,8 +72,8 @@ func (m *transcriptOnlyModel) PredictStream(ctx context.Context, in *proto.Predi
6772 return fmt .Errorf ("predict stream operation not supported in transcript-only mode" )
6873}
6974
70- func (m * transcriptOnlyModel ) TTS (ctx context.Context , in * proto.TTSRequest , opts ... grpc.CallOption ) (* proto.Result , error ) {
71- return nil , fmt .Errorf ("TTS not supported in transcript-only mode" )
75+ func (m * transcriptOnlyModel ) TTS (ctx context.Context , in * proto.TTSRequest , opts ... grpc.CallOption ) (* proto.Result , string , error ) {
76+ return nil , "" , fmt .Errorf ("TTS not supported in transcript-only mode" )
7277}
7378
7479func (m * wrappedModel ) VAD (ctx context.Context , in * proto.VADRequest , opts ... grpc.CallOption ) (* proto.VADResponse , error ) {
@@ -101,8 +106,28 @@ func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptio
101106 return m .LLMClient .PredictStream (ctx , in , f )
102107}
103108
104- func (m * wrappedModel ) TTS (ctx context.Context , in * proto.TTSRequest , opts ... grpc.CallOption ) (* proto.Result , error ) {
105- return m .TTSClient .TTS (ctx , in , opts ... )
109+ func (m * wrappedModel ) TTS (ctx context.Context , in * proto.TTSRequest , opts ... grpc.CallOption ) (* proto.Result , string , error ) {
110+ if m .appConfig != nil && m .appConfig .SystemState != nil {
111+ mp := filepath .Join (m .appConfig .SystemState .Model .ModelsPath , m .TTSConfig .Model )
112+ if _ , err := os .Stat (mp ); err == nil {
113+ if err := utils .VerifyPath (mp , m .appConfig .SystemState .Model .ModelsPath ); err == nil {
114+ in .Model = mp
115+ }
116+ }
117+ }
118+
119+ if in .Dst == "" && m .appConfig != nil {
120+ audioDir := filepath .Join (m .appConfig .GeneratedContentDir , "audio" )
121+ if err := os .MkdirAll (audioDir , 0750 ); err != nil {
122+ return nil , "" , fmt .Errorf ("failed creating audio directory: %s" , err )
123+ }
124+
125+ fileName := utils .GenerateUniqueFileName (audioDir , "tts" , ".wav" )
126+ in .Dst = filepath .Join (audioDir , fileName )
127+ }
128+
129+ res , err := m .TTSClient .TTS (ctx , in , opts ... )
130+ return res , in .Dst , err
106131}
107132
108133func (m * anyToAnyModel ) Predict (ctx context.Context , in * proto.PredictOptions , opts ... grpc.CallOption ) (* proto.Reply , error ) {
@@ -113,8 +138,10 @@ func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOpti
113138 return m .LLMClient .PredictStream (ctx , in , f )
114139}
115140
116- func (m * anyToAnyModel ) TTS (ctx context.Context , in * proto.TTSRequest , opts ... grpc.CallOption ) (* proto.Result , error ) {
117- return m .LLMClient .TTS (ctx , in , opts ... )
141+ func (m * anyToAnyModel ) TTS (ctx context.Context , in * proto.TTSRequest , opts ... grpc.CallOption ) (* proto.Result , string , error ) {
142+ // TODO: Handle file generation if needed for anyToAnyModel
143+ res , err := m .LLMClient .TTS (ctx , in , opts ... )
144+ return res , in .Dst , err
118145}
119146
120147func newTranscriptionOnlyModel (pipeline * config.Pipeline , cl * config.ModelConfigLoader , ml * model.ModelLoader , appConfig * config.ApplicationConfig ) (Model , * config.ModelConfig , error ) {
@@ -155,6 +182,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfig
155182 VADClient : VADClient ,
156183 TranscriptionConfig : cfgSST ,
157184 TranscriptionClient : transcriptionClient ,
185+ appConfig : appConfig ,
158186 }, cfgSST , nil
159187}
160188
@@ -266,5 +294,6 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
266294
267295 VADConfig : cfgVAD ,
268296 VADClient : VADClient ,
297+ appConfig : appConfig ,
269298 }, nil
270299}
0 commit comments