diff --git a/backend/backend.proto b/backend/backend.proto index 3dca83878396..454cc0230ef0 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -15,6 +15,7 @@ service Backend { rpc PredictStream(PredictOptions) returns (stream Reply) {} rpc Embedding(PredictOptions) returns (EmbeddingResult) {} rpc GenerateImage(GenerateImageRequest) returns (Result) {} + rpc UpscaleImage(UpscaleImageRequest) returns (Result) {} rpc GenerateVideo(GenerateVideoRequest) returns (Result) {} rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {} rpc AudioTranscriptionStream(TranscriptRequest) returns (stream TranscriptStreamResponse) {} @@ -516,6 +517,12 @@ message GenerateImageRequest { repeated string ref_images = 12; } +message UpscaleImageRequest { + string src = 1; // input image path + string dst = 2; // output image path + int32 scale = 3; // upscale factor (e.g. 2 or 4) +} + message GenerateVideoRequest { string prompt = 1; string negative_prompt = 2; // Negative prompt for video generation diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py index 00b292292949..4c20978d1a29 100755 --- a/backend/python/diffusers/backend.py +++ b/backend/python/diffusers/backend.py @@ -869,6 +869,34 @@ def GenerateImage(self, request, context): return backend_pb2.Result(message="Media generated", success=True) + def UpscaleImage(self, request, context): + try: + if not request.src: + return backend_pb2.Result(success=False, message="No source image provided") + if not request.dst: + return backend_pb2.Result(success=False, message="No destination path provided") + + scale = request.scale if request.scale > 0 else 2 + image = Image.open(request.src).convert("RGB") + + # If the loaded pipeline supports upscaling (e.g. StableDiffusionUpscalePipeline), + # use it; otherwise fall back to high-quality Lanczos resize. + if self.pipe is not None and self.PipelineType in ("StableDiffusionUpscalePipeline", "StableDiffusionLatentUpscalePipeline"): + print(f"UpscaleImage: using diffusers upscale pipeline ({self.PipelineType})", file=sys.stderr) + upscaled = self.pipe(prompt="", image=image).images[0] + else: + # Fallback: high-quality Lanczos resize + print(f"UpscaleImage: no upscale pipeline loaded, using Lanczos resize (scale={scale})", file=sys.stderr) + new_w = image.width * scale + new_h = image.height * scale + upscaled = image.resize((new_w, new_h), Image.LANCZOS) + + upscaled.save(request.dst) + return backend_pb2.Result(message="Image upscaled", success=True) + except Exception as e: + print(f"UpscaleImage error: {e}", file=sys.stderr) + return backend_pb2.Result(success=False, message=str(e)) + def GenerateVideo(self, request, context): try: prompt = request.prompt diff --git a/core/backend/upscale.go b/core/backend/upscale.go new file mode 100644 index 000000000000..c821f56cd92c --- /dev/null +++ b/core/backend/upscale.go @@ -0,0 +1,37 @@ +package backend + +import ( + "context" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/pkg/grpc/proto" + model "github.com/mudler/LocalAI/pkg/model" +) + +// ImageUpscale loads the model specified in modelConfig and calls UpscaleImage +// on the backend, writing the result to dst. +func ImageUpscale(ctx context.Context, src, dst string, scale int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) { + opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx)) + inferenceModel, err := loader.Load(opts...) + if err != nil { + recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) + return nil, err + } + + fn := func() error { + _, err := inferenceModel.UpscaleImage( + ctx, + &proto.UpscaleImageRequest{ + Src: src, + Dst: dst, + Scale: int32(scale), + }, + ) + return err + } + + return fn, nil +} + +// ImageUpscaleFunc is a test-friendly indirection. +var ImageUpscaleFunc = ImageUpscale diff --git a/core/http/auth/features.go b/core/http/auth/features.go index 77199580a7a5..8b0f0f4ccd5a 100644 --- a/core/http/auth/features.go +++ b/core/http/auth/features.go @@ -39,6 +39,8 @@ var RouteFeatureRegistry = []RouteFeature{ {"POST", "/images/generations", FeatureImages}, {"POST", "/v1/images/inpainting", FeatureImages}, {"POST", "/images/inpainting", FeatureImages}, + {"POST", "/v1/images/upscale", FeatureImages}, + {"POST", "/images/upscale", FeatureImages}, // Audio transcription {"POST", "/v1/audio/transcriptions", FeatureAudioTranscription}, diff --git a/core/http/endpoints/openai/upscale.go b/core/http/endpoints/openai/upscale.go new file mode 100644 index 000000000000..940cb515238a --- /dev/null +++ b/core/http/endpoints/openai/upscale.go @@ -0,0 +1,135 @@ +package openai + +import ( + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strconv" + "time" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/mudler/xlog" + + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" + model "github.com/mudler/LocalAI/pkg/model" +) + +// UpscaleEndpoint handles POST /v1/images/upscale +// +// @Summary Image upscaling +// @Description Upscale an image using a specified model (e.g. realesrgan). Accepts multipart/form-data. +// @Tags images +// @Accept multipart/form-data +// @Produce application/json +// @Param model formData string true "Upscaler model identifier (e.g. realesrgan)" +// @Param image formData file true "Input image file" +// @Param scale formData int false "Upscale factor: 2 or 4 (default 2)" +// @Success 200 {object} schema.OpenAIResponse +// @Failure 400 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /v1/images/upscale [post] +func UpscaleEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + modelName := c.FormValue("model") + scaleStr := c.FormValue("scale") + + if modelName == "" { + xlog.Error("Upscale Endpoint - missing model") + return echo.NewHTTPError(http.StatusBadRequest, "missing model") + } + + scale := 2 + if scaleStr != "" { + if v, err := strconv.Atoi(scaleStr); err == nil && (v == 2 || v == 4) { + scale = v + } + } + + // Read uploaded image + imageFile, err := c.FormFile("image") + if err != nil { + xlog.Error("Upscale Endpoint - missing image file", "error", err) + return echo.NewHTTPError(http.StatusBadRequest, "missing image file") + } + + imgSrc, err := imageFile.Open() + if err != nil { + return err + } + defer imgSrc.Close() + imgBytes, err := io.ReadAll(imgSrc) + if err != nil { + return err + } + + // Get model config from middleware context + cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + xlog.Error("Upscale Endpoint - model config not found in context") + return echo.ErrBadRequest + } + + tmpDir := appConfig.GeneratedContentDir + if err := os.MkdirAll(tmpDir, 0750); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "failed to prepare storage") + } + + // Write input image to a temp file + srcTmp, err := os.CreateTemp(tmpDir, "upscale_src_") + if err != nil { + return err + } + if _, err := srcTmp.Write(imgBytes); err != nil { + _ = srcTmp.Close() + _ = os.Remove(srcTmp.Name()) + return err + } + if err := srcTmp.Close(); err != nil { + xlog.Warn("Upscale Endpoint - failed to close src temp file", "error", err) + } + srcPath := srcTmp.Name() + defer os.Remove(srcPath) + + // Prepare output file path + id := uuid.New().String() + dstPath := filepath.Join(tmpDir, fmt.Sprintf("upscale_%s.png", id)) + defer func() { + // Only remove on error; success path keeps the file for serving + }() + + fn, err := backend.ImageUpscaleFunc(c.Request().Context(), srcPath, dstPath, scale, ml, *cfg, appConfig) + if err != nil { + return err + } + if err := fn(); err != nil { + _ = os.Remove(dstPath) + return err + } + + baseURL := middleware.BaseURL(c) + imgURL, err := url.JoinPath(baseURL, "generated-images", filepath.Base(dstPath)) + if err != nil { + _ = os.Remove(dstPath) + return err + } + + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Data: []schema.Item{{URL: imgURL}}, + Usage: &schema.OpenAIUsage{ + InputTokensDetails: &schema.InputTokensDetails{}, + }, + } + + return c.JSON(http.StatusOK, resp) + } +} diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index 8a13935aefbe..51212d56dd27 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -232,6 +232,11 @@ func RegisterOpenAIRoutes(app *echo.Echo, app.POST("/v1/images/inpainting", inpaintingHandler, imageMiddleware...) app.POST("/images/inpainting", inpaintingHandler, imageMiddleware...) + // upscale endpoint - reuse same middleware config as images + upscaleHandler := openai.UpscaleEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) + app.POST("/v1/images/upscale", upscaleHandler, imageMiddleware...) + app.POST("/images/upscale", upscaleHandler, imageMiddleware...) + // List models app.GET("/v1/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.AuthDB())) app.GET("/models", openai.ListModelsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.AuthDB())) diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index ead95d1952c5..33ca87a430af 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -49,6 +49,7 @@ type Backend interface { PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) + UpscaleImage(ctx context.Context, in *pb.UpscaleImageRequest, opts ...grpc.CallOption) (*pb.Result, error) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 24417e4c2914..e10a26c62bc9 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -58,6 +58,10 @@ func (llm *Base) GenerateVideo(*pb.GenerateVideoRequest) error { return fmt.Errorf("unimplemented") } +func (llm *Base) UpscaleImage(*pb.UpscaleImageRequest) error { + return fmt.Errorf("unimplemented") +} + func (llm *Base) AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error) { return pb.TranscriptResult{}, fmt.Errorf("unimplemented") } diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index b6a148186958..e5a2d56195ff 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -234,6 +234,24 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, return client.GenerateImage(ctx, in, opts...) } +func (c *Client) UpscaleImage(ctx context.Context, in *pb.UpscaleImageRequest, opts ...grpc.CallOption) (*pb.Result, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } + c.setBusy(true) + defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() + conn, err := c.dial() + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + return client.UpscaleImage(ctx, in, opts...) +} + func (c *Client) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error) { if !c.parallel { c.opMutex.Lock() diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index b9f08ddb42d8..7cc17ec07d8c 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -49,6 +49,10 @@ func (e *embedBackend) GenerateImage(ctx context.Context, in *pb.GenerateImageRe return e.s.GenerateImage(ctx, in) } +func (e *embedBackend) UpscaleImage(ctx context.Context, in *pb.UpscaleImageRequest, opts ...grpc.CallOption) (*pb.Result, error) { + return e.s.UpscaleImage(ctx, in) +} + func (e *embedBackend) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error) { return e.s.GenerateVideo(ctx, in) } diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 31b9ab26deb6..72960701940b 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -17,6 +17,7 @@ type AIModel interface { Free() error Embeddings(*pb.PredictOptions) ([]float32, error) GenerateImage(*pb.GenerateImageRequest) error + UpscaleImage(*pb.UpscaleImageRequest) error GenerateVideo(*pb.GenerateVideoRequest) error Detect(*pb.DetectOptions) (pb.DetectResponse, error) FaceVerify(*pb.FaceVerifyRequest) (pb.FaceVerifyResponse, error) diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 5be668497b77..85c176c2b7a6 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -87,6 +87,18 @@ func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) return &pb.Result{Message: "Image generated", Success: true}, nil } +func (s *server) UpscaleImage(ctx context.Context, in *pb.UpscaleImageRequest) (*pb.Result, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } + err := s.llm.UpscaleImage(in) + if err != nil { + return &pb.Result{Message: fmt.Sprintf("Error upscaling image: %s", err.Error()), Success: false}, err + } + return &pb.Result{Message: "Image upscaled", Success: true}, nil +} + func (s *server) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest) (*pb.Result, error) { if s.llm.Locking() { s.llm.Lock()