Skip to content

Commit efd6ecc

Browse files
authored
New provider and models API and CLI (#13865)
### What problem does this PR solve? As title. ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com>
1 parent 68b4287 commit efd6ecc

File tree

13 files changed

+524
-95
lines changed

13 files changed

+524
-95
lines changed

cmd/server_main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,10 @@ func startServer(config *server.Config) {
193193
searchHandler := handler.NewSearchHandler(searchService, userService)
194194
fileHandler := handler.NewFileHandler(fileService, userService)
195195
memoryHandler := handler.NewMemoryHandler(memoryService)
196+
providerHandler := handler.NewProviderHandler(userService)
196197

197198
// Initialize router
198-
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler)
199+
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, providerHandler)
199200

200201
// Create Gin engine
201202
ginEngine := gin.New()

internal/admin/handler.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ import (
2121
"fmt"
2222
"net/http"
2323
"ragflow/internal/common"
24+
"ragflow/internal/dao"
2425
"ragflow/internal/logger"
2526
"ragflow/internal/server"
2627
"ragflow/internal/service"
2728
"ragflow/internal/utility"
2829
"strconv"
30+
"strings"
2931
"time"
3032

3133
"github.com/gin-gonic/gin"
@@ -794,6 +796,115 @@ func (h *Handler) RestartService(c *gin.Context) {
794796
success(c, result, "")
795797
}
796798

799+
func (h *Handler) ListProviders(c *gin.Context) {
800+
801+
keywords := ""
802+
if queryKeywords := c.Query("available"); queryKeywords != "" {
803+
keywords = queryKeywords
804+
}
805+
806+
// convert keywords to small case
807+
keywords = strings.ToLower(keywords)
808+
if keywords == "true" {
809+
// list pool providers
810+
providers, err := dao.GetModelProviderManager().ListProviders()
811+
if err != nil {
812+
c.JSON(http.StatusOK, gin.H{
813+
"code": common.CodeNotFound,
814+
"message": err.Error(),
815+
})
816+
return
817+
}
818+
819+
c.JSON(http.StatusOK, gin.H{
820+
"code": 0,
821+
"message": "success",
822+
"data": providers,
823+
})
824+
}
825+
}
826+
827+
func (h *Handler) ShowProvider(c *gin.Context) {
828+
providerName := c.Param("provider_name")
829+
if providerName == "" {
830+
c.JSON(http.StatusBadRequest, gin.H{
831+
"code": 400,
832+
"message": "Provider name is required",
833+
})
834+
return
835+
}
836+
837+
provider, err := dao.GetModelProviderManager().GetProviderByName(providerName)
838+
if err != nil {
839+
c.JSON(http.StatusOK, gin.H{
840+
"code": common.CodeNotFound,
841+
"message": err.Error(),
842+
})
843+
return
844+
}
845+
c.JSON(http.StatusOK, gin.H{
846+
"code": 0,
847+
"message": "success",
848+
"data": provider,
849+
})
850+
}
851+
852+
func (h *Handler) ListModels(c *gin.Context) {
853+
providerName := c.Param("provider_name")
854+
if providerName == "" {
855+
c.JSON(http.StatusBadRequest, gin.H{
856+
"code": 400,
857+
"message": "Provider name is required",
858+
})
859+
return
860+
}
861+
models, err := dao.GetModelProviderManager().ListModels(providerName)
862+
if err != nil {
863+
c.JSON(http.StatusOK, gin.H{
864+
"code": common.CodeNotFound,
865+
"message": err.Error(),
866+
})
867+
return
868+
}
869+
c.JSON(http.StatusOK, gin.H{
870+
"code": 0,
871+
"message": "success",
872+
"data": models,
873+
})
874+
}
875+
876+
func (h *Handler) ShowModel(c *gin.Context) {
877+
providerName := c.Param("provider_name")
878+
if providerName == "" {
879+
c.JSON(http.StatusBadRequest, gin.H{
880+
"code": 400,
881+
"message": "Provider name is required",
882+
})
883+
return
884+
}
885+
modelName := c.Param("model_name")
886+
if modelName == "" {
887+
c.JSON(http.StatusBadRequest, gin.H{
888+
"code": 400,
889+
"message": "Model name is required",
890+
})
891+
return
892+
}
893+
model, err := dao.GetModelProviderManager().GetModelByName(providerName, modelName)
894+
if err != nil {
895+
c.JSON(http.StatusOK, gin.H{
896+
"code": common.CodeNotFound,
897+
"message": err.Error(),
898+
})
899+
return
900+
}
901+
c.JSON(http.StatusOK, gin.H{
902+
"code": 0,
903+
"message": "success",
904+
"data": model,
905+
})
906+
}
907+
797908
// GetVariables handle get variables
798909
// Python logic: if request body is empty, list all variables; otherwise get single variable by var_name from body
799910
func (h *Handler) GetVariables(c *gin.Context) {

internal/admin/router.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
package admin
1818

1919
import (
20-
"ragflow/internal/handler"
21-
2220
"github.com/gin-gonic/gin"
2321
)
2422

@@ -48,15 +46,6 @@ func (r *Router) Setup(engine *gin.Engine) {
4846

4947
admin.POST("/reports", r.handler.Reports)
5048

51-
// provider pool route group
52-
provider := admin.Group("providers")
53-
{
54-
provider.GET("/", handler.ListPoolProviders)
55-
provider.GET("/:provider_name", handler.ShowPoolProvider)
56-
provider.GET("/:provider_name/models", handler.ListPoolModels)
57-
provider.GET("/:provider_name/models/:model_name", handler.ShowPoolModel)
58-
}
59-
6049
// Protected routes
6150
protected := admin.Group("")
6251
protected.Use(r.handler.AuthMiddleware())
@@ -136,6 +125,14 @@ func (r *Router) Setup(engine *gin.Engine) {
136125
// Log level
137126
protected.GET("/log_level", r.handler.GetLogLevel)
138127
protected.PUT("/log_level", r.handler.SetLogLevel)
128+
129+
provider := protected.Group("/providers")
130+
{
131+
provider.GET("/", r.handler.ListProviders)
132+
provider.GET("/:provider_name", r.handler.ShowProvider)
133+
provider.GET("/:provider_name/models", r.handler.ListModels)
134+
provider.GET("/:provider_name/models/:model_name", r.handler.ShowModel)
135+
}
139136
}
140137
}
141138

internal/cli/admin_parser.go

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,10 @@ func (p *Parser) parseAdminListCommand() (*Command, error) {
166166
return p.parseAdminListModelProviders()
167167
case TokenDefault:
168168
return p.parseAdminListDefaultModels()
169-
case TokenPool:
170-
return p.parseCommonListPoolModels()
169+
case TokenAvailable:
170+
return p.parseCommonListProviders()
171+
case TokenModels:
172+
return p.parseListModelsOfProvider()
171173
case TokenChats:
172174
p.nextToken()
173175
// Semicolon is optional for SHOW TOKEN
@@ -273,31 +275,38 @@ func (p *Parser) parseAdminListDefaultModels() (*Command, error) {
273275
return NewCommand("list_user_default_models"), nil
274276
}
275277

276-
func (p *Parser) parseCommonListPoolModels() (*Command, error) {
277-
p.nextToken() // consume POOL
278-
if p.curToken.Type == TokenProviders {
279-
return NewCommand("list_pool_providers"), nil
280-
} else if p.curToken.Type == TokenModels {
281-
p.nextToken()
282-
if p.curToken.Type != TokenFrom {
283-
return nil, fmt.Errorf("expected FROM")
284-
}
285-
p.nextToken()
286-
providerName, err := p.parseQuotedString()
287-
if err != nil {
288-
return nil, err
289-
}
290-
cmd := NewCommand("list_pool_models")
291-
cmd.Params["provider_name"] = providerName
278+
func (p *Parser) parseListModelsOfProvider() (*Command, error) {
279+
if p.curToken.Type != TokenModels {
280+
return nil, fmt.Errorf("expected MODELS")
281+
}
282+
283+
p.nextToken()
284+
if p.curToken.Type != TokenFrom {
285+
return nil, fmt.Errorf("expected FROM")
286+
}
287+
p.nextToken()
288+
providerName, err := p.parseQuotedString()
289+
if err != nil {
290+
return nil, err
291+
}
292+
cmd := NewCommand("list_provider_models")
293+
cmd.Params["provider_name"] = providerName
294+
p.nextToken()
295+
// Semicolon is optional for UNSET TOKEN
296+
if p.curToken.Type == TokenSemicolon {
292297
p.nextToken()
293-
// Semicolon is optional for UNSET TOKEN
294-
if p.curToken.Type == TokenSemicolon {
295-
p.nextToken()
296-
}
297-
return cmd, nil
298-
} else {
299-
return nil, fmt.Errorf("expected PROVIDERS or MODELS")
300298
}
299+
return cmd, nil
300+
}
301+
302+
func (p *Parser) parseCommonListProviders() (*Command, error) {
303+
p.nextToken() // consume AVAILABLE
304+
305+
if p.curToken.Type != TokenProviders {
306+
return nil, fmt.Errorf("expected PROVIDERS")
307+
}
308+
309+
return NewCommand("list_available_providers"), nil
301310
}
302311

303312
func (p *Parser) parseCommonShowPoolModel() (*Command, error) {
@@ -409,8 +418,10 @@ func (p *Parser) parseAdminShowCommand() (*Command, error) {
409418
return p.parseShowVariable()
410419
case TokenService:
411420
return p.parseShowService()
412-
case TokenPool:
413-
return p.parseCommonShowPoolModel()
421+
case TokenProvider:
422+
return p.parseShowProvider()
423+
case TokenModel:
424+
return p.parseShowModel()
414425
default:
415426
return nil, fmt.Errorf("unknown SHOW target: %s", p.curToken.Value)
416427
}

internal/cli/cli.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,15 +1006,17 @@ Commands (User Mode):
10061006
LIST MODEL PROVIDERS; - List model providers
10071007
LIST DEFAULT MODELS; - List default models
10081008
LIST TOKENS; - List API tokens
1009+
LIST PROVIDERS; - List available LLM providers
10091010
CREATE TOKEN; - Create new API token
1011+
CREATE PROVIDER 'name'; - Create a provider without API key
1012+
CREATE PROVIDER 'name' 'api_key'; - Create a provider with API key
10101013
DROP TOKEN 'token_value'; - Delete an API token
1014+
DROP PROVIDER 'name'; - Delete a provider
10111015
SET TOKEN 'token_value'; - Set and validate API token
10121016
SHOW TOKEN; - Show current API token
1017+
SHOW PROVIDER 'name'; - Show provider details
10131018
UNSET TOKEN; - Remove current API token
1014-
CREATE INDEX FOR DATASET 'name' VECTOR_SIZE N; - Create index for dataset
1015-
DROP INDEX FOR DATASET 'name'; - Drop index for dataset
1016-
CREATE INDEX DOC_META; - Create doc meta index
1017-
DROP INDEX DOC_META; - Drop doc meta index
1019+
ALTER PROVIDER 'name' NAME 'new_name'; - Rename a provider
10181020
10191021
Context Engine Commands (no quotes):
10201022
ls [path] - List resources

internal/cli/client.go

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,14 @@ func (c *RAGFlowClient) ExecuteAdminCommand(cmd *Command) (ResponseIf, error) {
152152
return c.ListAdminTokens(cmd)
153153
case "drop_token":
154154
return c.DropAdminToken(cmd)
155-
case "list_pool_providers":
156-
return c.ListPoolProviders(cmd)
157-
case "show_pool_provider":
158-
return c.ShowPoolProvider(cmd)
159-
case "list_pool_models":
160-
return c.ListPoolModels(cmd)
161-
case "show_pool_model":
162-
return c.ShowPoolModel(cmd)
155+
case "list_available_providers":
156+
return c.ListAvailableProviders(cmd)
157+
case "show_provider":
158+
return c.ShowProvider(cmd)
159+
case "list_provider_models":
160+
return c.ListModels(cmd)
161+
case "show_model":
162+
return c.ShowModel(cmd)
163163
// TODO: Implement other commands
164164
default:
165165
return nil, fmt.Errorf("command '%s' would be executed with API", cmd.Type)
@@ -204,13 +204,20 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) {
204204
case "drop_doc_meta_index":
205205
return c.DropDocMetaIndex(cmd)
206206
case "list_pool_providers":
207-
return c.ListPoolProviders(cmd)
208-
case "show_pool_provider":
209-
return c.ShowPoolProvider(cmd)
210-
case "list_pool_models":
211-
return c.ListPoolModels(cmd)
212-
case "show_pool_model":
213-
return c.ShowPoolModel(cmd)
207+
return c.ListAvailableProviders(cmd)
208+
case "show_provider":
209+
return c.ShowProvider(cmd)
210+
case "list_provider_models":
211+
return c.ListModels(cmd)
212+
case "show_model":
213+
return c.ShowModel(cmd)
214+
// Provider commands
215+
case "create_provider":
216+
return c.CreateProvider(cmd)
217+
case "list_providers":
218+
return c.ListProviders(cmd)
219+
case "drop_provider":
220+
return c.DropProvider(cmd)
214221
// ContextEngine commands
215222
case "ce_ls":
216223
return c.CEList(cmd)

internal/cli/common_command.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,13 +237,13 @@ func (c *RAGFlowClient) Logout() (ResponseIf, error) {
237237
return &result, nil
238238
}
239239

240-
func (c *RAGFlowClient) ListPoolProviders(cmd *Command) (ResponseIf, error) {
240+
func (c *RAGFlowClient) ListAvailableProviders(cmd *Command) (ResponseIf, error) {
241241

242242
var endPoint string
243243
if c.ServerType == "admin" {
244-
endPoint = fmt.Sprintf("/admin/providers")
244+
endPoint = fmt.Sprintf("/admin/providers?available=true")
245245
} else {
246-
endPoint = fmt.Sprintf("/providers")
246+
endPoint = fmt.Sprintf("/providers?available=true")
247247
}
248248

249249
resp, err := c.HTTPClient.Request("GET", endPoint, true, "web", nil, nil)
@@ -267,7 +267,7 @@ func (c *RAGFlowClient) ListPoolProviders(cmd *Command) (ResponseIf, error) {
267267
return &result, nil
268268
}
269269

270-
func (c *RAGFlowClient) ShowPoolProvider(cmd *Command) (ResponseIf, error) {
270+
func (c *RAGFlowClient) ShowProvider(cmd *Command) (ResponseIf, error) {
271271
providerName, ok := cmd.Params["provider_name"].(string)
272272
if !ok {
273273
return nil, fmt.Errorf("provider_name not provided")
@@ -301,7 +301,7 @@ func (c *RAGFlowClient) ShowPoolProvider(cmd *Command) (ResponseIf, error) {
301301
return &result, nil
302302
}
303303

304-
func (c *RAGFlowClient) ListPoolModels(cmd *Command) (ResponseIf, error) {
304+
func (c *RAGFlowClient) ListModels(cmd *Command) (ResponseIf, error) {
305305

306306
providerName, ok := cmd.Params["provider_name"].(string)
307307
if !ok {
@@ -336,7 +336,7 @@ func (c *RAGFlowClient) ListPoolModels(cmd *Command) (ResponseIf, error) {
336336
return &result, nil
337337
}
338338

339-
func (c *RAGFlowClient) ShowPoolModel(cmd *Command) (ResponseIf, error) {
339+
func (c *RAGFlowClient) ShowModel(cmd *Command) (ResponseIf, error) {
340340
providerName, ok := cmd.Params["provider_name"].(string)
341341
if !ok {
342342
return nil, fmt.Errorf("provider_name not provided")

internal/cli/lexer.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,10 @@ func (l *Lexer) lookupIdent(ident string) Token {
301301
return Token{Type: TokenVectorSize, Value: ident}
302302
case "DOC_META":
303303
return Token{Type: TokenDocMeta, Value: ident}
304-
case "POOL":
305-
return Token{Type: TokenPool, Value: ident}
304+
case "AVAILABLE":
305+
return Token{Type: TokenAvailable, Value: ident}
306+
case "NAME":
307+
return Token{Type: TokenName, Value: ident}
306308
default:
307309
return Token{Type: TokenIdentifier, Value: ident}
308310
}

0 commit comments

Comments
 (0)