Skip to content

Commit 2271f01

Browse files
committed
try to fix tts
1 parent b9560e8 commit 2271f01

File tree

2 files changed

+40
-27
lines changed

2 files changed

+40
-27
lines changed

core/http/endpoints/openai/realtime.go

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ type Model interface {
161161
Transcribe(ctx context.Context, in *proto.TranscriptRequest, opts ...grpc.CallOption) (*proto.TranscriptResult, error)
162162
Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error)
163163
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
164-
TTS(ctx context.Context, in *proto.TTSRequest, opts ...grpc.CallOption) (*proto.Result, error)
164+
TTS(ctx context.Context, in *proto.TTSRequest, opts ...grpc.CallOption) (*proto.Result, string, error)
165165
}
166166

167167
var upgrader = websocket.Upgrader{
@@ -581,7 +581,7 @@ func updateSession(session *Session, update *types.ClientSession, cl *config.Mod
581581
if update.Model != "" || update.Voice != "" || update.InputAudioTranscription != nil {
582582
pipeline := config.Pipeline{
583583
VAD: defaultVADModel,
584-
LLM: update.Model,
584+
LLM: session.Model,
585585
Transcription: session.InputAudioTranscription.Model,
586586
TTS: session.Voice,
587587
}
@@ -923,29 +923,12 @@ func generateResponse(config *config.ModelConfig, evaluator *templates.Evaluator
923923
}
924924
conv.Lock.Unlock()
925925

926-
f, err := os.CreateTemp("", "realtime-tts-*.wav")
927-
if err != nil {
928-
xlog.Error("failed to create temp file for TTS", "error", err)
929-
sendError(c, "tts_error", "Failed to create temp file for TTS", "", item.ID)
930-
return
931-
}
932-
defer os.Remove(f.Name())
933-
934-
modelWrapped, ok := session.ModelInterface.(*wrappedModel)
935-
if !ok {
936-
xlog.Error("model is not wrappedModel")
937-
sendError(c, "model_error", "Model is not wrappedModel", "", item.ID)
938-
return
939-
}
940-
941926
ttsReq := &proto.TTSRequest{
942927
Text: response,
943928
Voice: session.Voice,
944-
Model: modelWrapped.TTSConfig.Model,
945-
Dst: f.Name(),
946929
}
947930

948-
res, err := session.ModelInterface.TTS(context.TODO(), ttsReq)
931+
res, audioFilePath, err := session.ModelInterface.TTS(context.TODO(), ttsReq)
949932
if err != nil {
950933
xlog.Error("TTS failed", "error", err)
951934
sendError(c, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.ID)
@@ -956,8 +939,9 @@ func generateResponse(config *config.ModelConfig, evaluator *templates.Evaluator
956939
sendError(c, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.ID)
957940
return
958941
}
942+
defer os.Remove(audioFilePath)
959943

960-
audioBytes, err := os.ReadFile(f.Name())
944+
audioBytes, err := os.ReadFile(audioFilePath)
961945
if err != nil {
962946
xlog.Error("failed to read TTS file", "error", err)
963947
sendError(c, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.ID)

core/http/endpoints/openai/realtime_model.go

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@ package openai
33
import (
44
"context"
55
"fmt"
6+
"os"
7+
"path/filepath"
68

79
"github.com/mudler/LocalAI/core/backend"
810
"github.com/mudler/LocalAI/core/config"
911
grpcClient "github.com/mudler/LocalAI/pkg/grpc"
1012
"github.com/mudler/LocalAI/pkg/grpc/proto"
1113
model "github.com/mudler/LocalAI/pkg/model"
14+
"github.com/mudler/LocalAI/pkg/utils"
1215
"github.com/mudler/xlog"
1316
"google.golang.org/grpc"
1417
)
@@ -31,6 +34,7 @@ type wrappedModel struct {
3134

3235
VADConfig *config.ModelConfig
3336
VADClient grpcClient.Backend
37+
appConfig *config.ApplicationConfig
3438
}
3539

3640
// anyToAnyModel represent a model which supports Any-to-Any operations
@@ -49,6 +53,7 @@ type transcriptOnlyModel struct {
4953
TranscriptionClient grpcClient.Backend
5054
VADConfig *config.ModelConfig
5155
VADClient grpcClient.Backend
56+
appConfig *config.ApplicationConfig
5257
}
5358

5459
func (m *transcriptOnlyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
@@ -67,8 +72,8 @@ func (m *transcriptOnlyModel) PredictStream(ctx context.Context, in *proto.Predi
6772
return fmt.Errorf("predict stream operation not supported in transcript-only mode")
6873
}
6974

70-
func (m *transcriptOnlyModel) TTS(ctx context.Context, in *proto.TTSRequest, opts ...grpc.CallOption) (*proto.Result, error) {
71-
return nil, fmt.Errorf("TTS not supported in transcript-only mode")
75+
func (m *transcriptOnlyModel) TTS(ctx context.Context, in *proto.TTSRequest, opts ...grpc.CallOption) (*proto.Result, string, error) {
76+
return nil, "", fmt.Errorf("TTS not supported in transcript-only mode")
7277
}
7378

7479
func (m *wrappedModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
@@ -101,8 +106,28 @@ func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptio
101106
return m.LLMClient.PredictStream(ctx, in, f)
102107
}
103108

104-
func (m *wrappedModel) TTS(ctx context.Context, in *proto.TTSRequest, opts ...grpc.CallOption) (*proto.Result, error) {
105-
return m.TTSClient.TTS(ctx, in, opts...)
109+
func (m *wrappedModel) TTS(ctx context.Context, in *proto.TTSRequest, opts ...grpc.CallOption) (*proto.Result, string, error) {
110+
if m.appConfig != nil && m.appConfig.SystemState != nil {
111+
mp := filepath.Join(m.appConfig.SystemState.Model.ModelsPath, m.TTSConfig.Model)
112+
if _, err := os.Stat(mp); err == nil {
113+
if err := utils.VerifyPath(mp, m.appConfig.SystemState.Model.ModelsPath); err == nil {
114+
in.Model = mp
115+
}
116+
}
117+
}
118+
119+
if in.Dst == "" && m.appConfig != nil {
120+
audioDir := filepath.Join(m.appConfig.GeneratedContentDir, "audio")
121+
if err := os.MkdirAll(audioDir, 0750); err != nil {
122+
return nil, "", fmt.Errorf("failed creating audio directory: %s", err)
123+
}
124+
125+
fileName := utils.GenerateUniqueFileName(audioDir, "tts", ".wav")
126+
in.Dst = filepath.Join(audioDir, fileName)
127+
}
128+
129+
res, err := m.TTSClient.TTS(ctx, in, opts...)
130+
return res, in.Dst, err
106131
}
107132

108133
func (m *anyToAnyModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
@@ -113,8 +138,10 @@ func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOpti
113138
return m.LLMClient.PredictStream(ctx, in, f)
114139
}
115140

116-
func (m *anyToAnyModel) TTS(ctx context.Context, in *proto.TTSRequest, opts ...grpc.CallOption) (*proto.Result, error) {
117-
return m.LLMClient.TTS(ctx, in, opts...)
141+
func (m *anyToAnyModel) TTS(ctx context.Context, in *proto.TTSRequest, opts ...grpc.CallOption) (*proto.Result, string, error) {
142+
// TODO: Handle file generation if needed for anyToAnyModel
143+
res, err := m.LLMClient.TTS(ctx, in, opts...)
144+
return res, in.Dst, err
118145
}
119146

120147
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) {
@@ -155,6 +182,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfig
155182
VADClient: VADClient,
156183
TranscriptionConfig: cfgSST,
157184
TranscriptionClient: transcriptionClient,
185+
appConfig: appConfig,
158186
}, cfgSST, nil
159187
}
160188

@@ -266,5 +294,6 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
266294

267295
VADConfig: cfgVAD,
268296
VADClient: VADClient,
297+
appConfig: appConfig,
269298
}, nil
270299
}

0 commit comments

Comments
 (0)