diff --git a/internal/core/plugin_registry.go b/internal/core/plugin_registry.go index c55feab82e..d821814469 100644 --- a/internal/core/plugin_registry.go +++ b/internal/core/plugin_registry.go @@ -24,6 +24,7 @@ import ( "github.com/danielmiessler/fabric/internal/plugins/ai/dryrun" "github.com/danielmiessler/fabric/internal/plugins/ai/exolab" "github.com/danielmiessler/fabric/internal/plugins/ai/gemini" + "github.com/danielmiessler/fabric/internal/plugins/ai/llamacpp" "github.com/danielmiessler/fabric/internal/plugins/ai/lmstudio" "github.com/danielmiessler/fabric/internal/plugins/ai/ollama" "github.com/danielmiessler/fabric/internal/plugins/ai/openai" @@ -83,6 +84,7 @@ func NewPluginRegistry(db *fsdb.Db) (ret *PluginRegistry, err error) { gemini.NewClient(), anthropic.NewClient(), vertexai.NewClient(), + llamacpp.NewClient(), lmstudio.NewClient(), exolab.NewClient(), perplexity.NewClient(), diff --git a/internal/i18n/locales/en.json b/internal/i18n/locales/en.json index 71190e34c3..3a900ca370 100644 --- a/internal/i18n/locales/en.json +++ b/internal/i18n/locales/en.json @@ -323,6 +323,16 @@ "list_all_vendors": "List all vendors", "list_gemini_tts_voices": "List all available Gemini TTS voices", "list_transcription_models": "List all available transcription models", + "llamacpp_api_url_question": "Enter your llama.cpp server URL (default: %v)", + "llamacpp_error_reading_response": "error reading response: %w", + "llamacpp_failed_create_request": "failed to create request: %w", + "llamacpp_failed_decode_response": "failed to decode response: %w", + "llamacpp_failed_marshal_payload": "failed to marshal payload: %w", + "llamacpp_failed_send_request": "failed to send request: %w", + "llamacpp_invalid_response_missing_choices": "invalid response format: missing or empty choices", + "llamacpp_invalid_response_missing_content": "invalid response format: missing or non-string content in message", + "llamacpp_invalid_response_missing_message": "invalid response format: missing message in first choice", + "llamacpp_unexpected_status_code": "unexpected status code: %d", "lmstudio_api_url_question": "Enter your %v URL (as a reminder, it is usually %v)", "lmstudio_error_reading_response": "error reading response: %w", "lmstudio_failed_create_request": "failed to create request: %w", diff --git a/internal/plugins/ai/llamacpp/llamacpp.go b/internal/plugins/ai/llamacpp/llamacpp.go new file mode 100644 index 0000000000..0b8c5cde34 --- /dev/null +++ b/internal/plugins/ai/llamacpp/llamacpp.go @@ -0,0 +1,257 @@ +package llamacpp + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/danielmiessler/fabric/internal/chat" + "github.com/danielmiessler/fabric/internal/domain" + "github.com/danielmiessler/fabric/internal/i18n" + "github.com/danielmiessler/fabric/internal/plugins" +) + +const defaultBaseURL = "http://localhost:8080/v1" + +func NewClient() *Client { + ret := &Client{} + ret.PluginBase = plugins.NewVendorPluginBase("llama.cpp", ret.configure) + ret.ApiUrl = ret.AddSetupQuestionCustom("API URL", true, + fmt.Sprintf(i18n.T("llamacpp_api_url_question"), defaultBaseURL)) + ret.ApiKey = ret.AddSetupQuestion("API key", false) + ret.ApiUrl.Value = defaultBaseURL + return ret +} + +type Client struct { + *plugins.PluginBase + ApiUrl *plugins.SetupQuestion + ApiKey *plugins.SetupQuestion + HttpClient *http.Client +} + +func (c *Client) configure() error { + c.HttpClient = &http.Client{} + return nil +} + +func (c *Client) ListModels(_ context.Context) ([]string, error) { + url := fmt.Sprintf("%s/models", c.ApiUrl.Value) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf(i18n.T("llamacpp_failed_create_request"), err) + } + c.addAuthorizationHeader(req) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, fmt.Errorf(i18n.T("llamacpp_failed_send_request"), err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf(i18n.T("llamacpp_unexpected_status_code"), resp.StatusCode) + } + + var result struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf(i18n.T("llamacpp_failed_decode_response"), err) + } + + models := make([]string, len(result.Data)) + for i, model := range result.Data { + models[i] = model.ID + } + return models, nil +} + +func (c *Client) SendStream(_ context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) { + url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value) + + payload := map[string]any{ + "messages": msgs, + "model": opts.Model, + "stream": true, + "cache_prompt": true, // reuse KV cache across requests for the same prefix + "stream_options": map[string]any{ + "include_usage": true, + }, + } + + var jsonPayload []byte + if jsonPayload, err = json.Marshal(payload); err != nil { + err = fmt.Errorf(i18n.T("llamacpp_failed_marshal_payload"), err) + return + } + + var req *http.Request + if req, err = http.NewRequest("POST", url, bytes.NewBuffer(jsonPayload)); err != nil { + err = fmt.Errorf(i18n.T("llamacpp_failed_create_request"), err) + return + } + + req.Header.Set("Content-Type", "application/json") + c.addAuthorizationHeader(req) + + var resp *http.Response + if resp, err = c.HttpClient.Do(req); err != nil { + err = fmt.Errorf(i18n.T("llamacpp_failed_send_request"), err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + err = fmt.Errorf(i18n.T("llamacpp_unexpected_status_code"), resp.StatusCode) + return + } + + defer close(channel) + + reader := bufio.NewReader(resp.Body) + for { + var line []byte + if line, err = reader.ReadBytes('\n'); err != nil { + if err == io.EOF { + err = nil + break + } + err = fmt.Errorf(i18n.T("llamacpp_error_reading_response"), err) + return + } + + if len(line) == 0 { + continue + } + + if after, ok := bytes.CutPrefix(line, []byte("data: ")); ok { + line = after + } + + if string(bytes.TrimSpace(line)) == "[DONE]" { + break + } + + var result map[string]any + if err = json.Unmarshal(line, &result); err != nil { + continue + } + + if usage, ok := result["usage"].(map[string]any); ok { + var metadata domain.UsageMetadata + if val, ok := usage["prompt_tokens"].(float64); ok { + metadata.InputTokens = int(val) + } + if val, ok := usage["completion_tokens"].(float64); ok { + metadata.OutputTokens = int(val) + } + if val, ok := usage["total_tokens"].(float64); ok { + metadata.TotalTokens = int(val) + } + channel <- domain.StreamUpdate{Type: domain.StreamTypeUsage, Usage: &metadata} + } + + var choices []any + var ok bool + if choices, ok = result["choices"].([]any); !ok || len(choices) == 0 { + continue + } + + var delta map[string]any + if delta, ok = choices[0].(map[string]any)["delta"].(map[string]any); !ok { + continue + } + + var content string + if content, _ = delta["content"].(string); content != "" { + channel <- domain.StreamUpdate{Type: domain.StreamTypeContent, Content: content} + } + } + + return +} + +func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (content string, err error) { + url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value) + + payload := map[string]any{ + "messages": msgs, + "model": opts.Model, + "cache_prompt": true, + } + + var jsonPayload []byte + if jsonPayload, err = json.Marshal(payload); err != nil { + err = fmt.Errorf(i18n.T("llamacpp_failed_marshal_payload"), err) + return + } + + var req *http.Request + if req, err = http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload)); err != nil { + err = fmt.Errorf(i18n.T("llamacpp_failed_create_request"), err) + return + } + + req.Header.Set("Content-Type", "application/json") + c.addAuthorizationHeader(req) + + var resp *http.Response + if resp, err = c.HttpClient.Do(req); err != nil { + err = fmt.Errorf(i18n.T("llamacpp_failed_send_request"), err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + err = fmt.Errorf(i18n.T("llamacpp_unexpected_status_code"), resp.StatusCode) + return + } + + var result map[string]any + if err = json.NewDecoder(resp.Body).Decode(&result); err != nil { + err = fmt.Errorf(i18n.T("llamacpp_failed_decode_response"), err) + return + } + + var choices []any + var ok bool + if choices, ok = result["choices"].([]any); !ok || len(choices) == 0 { + err = errors.New(i18n.T("llamacpp_invalid_response_missing_choices")) + return + } + + var message map[string]any + if message, ok = choices[0].(map[string]any)["message"].(map[string]any); !ok { + err = errors.New(i18n.T("llamacpp_invalid_response_missing_message")) + return + } + + if content, ok = message["content"].(string); !ok { + err = errors.New(i18n.T("llamacpp_invalid_response_missing_content")) + return + } + + return +} + +func (c *Client) addAuthorizationHeader(req *http.Request) { + if c.ApiKey == nil { + return + } + apiKey := strings.TrimSpace(c.ApiKey.Value) + if apiKey == "" { + return + } + req.Header.Set("Authorization", "Bearer "+apiKey) +}