
Custom providers


Cody

See https://github.com/yetone/avante.nvim/pull/810/files. The helper module below maps avante's chat roles onto Cody's speaker/text message format and parses Cody's streamed responses:

local M = {}

M.role_map = {
  user = "human",
  assistant = "assistant",
  system = "system",
}

M.parse_messages = function(opts)
  -- Cody expects { speaker, text } pairs, so convert the system prompt too.
  local messages = {
    { speaker = "system", text = opts.system_prompt },
  }
  vim
    .iter(opts.messages)
    :each(function(msg) table.insert(messages, { speaker = M.role_map[msg.role], text = msg.content }) end)
  return messages
end
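
-- For illustration (hypothetical input), given
--   opts = { system_prompt = "Be brief.", messages = { { role = "user", content = "hi" } } }
-- parse_messages would return
--   { { speaker = "system", text = "Be brief." }, { speaker = "human", text = "hi" } }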

M.parse_response = function(data_stream, event_state, opts)
  -- event_state carries the SSE `event:` field; Cody signals the end of
  -- the stream with a "done" event.
  if event_state == "done" then
    opts.on_complete()
    return
  end

  if data_stream == nil or data_stream == "" then return end

  local json = vim.json.decode(data_stream)
  local delta = json.deltaText
  local stopReason = json.stopReason

  if stopReason == "end_turn" then return end

  opts.on_chunk(delta)
end
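
The provider definition itself goes under vendors in your setup. A minimal sketch of the surrounding call (the setup shape is standard avante config; keeping the M helpers above as an inline local is assumed):

require("avante").setup({
  provider = "my-custom-provider",
  vendors = {
    ["my-custom-provider"] = {
      -- the fields shown in full below
    },
  },
})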

---@type AvanteProvider
["my-custom-provider"] = {
  endpoint = "https://sourcegraph.com",
  model = "anthropic::2024-10-22::claude-3-5-sonnet-latest",
  api_key_name = "SRC_ACCESS_TOKEN",
  --- This function builds the cURL request for the provider.
  --- It receives the provider options as its first argument, followed by code_opts derived from the current buffer.
  --- code_opts includes:
  --- - question: the user's input
  --- - code_lang: the language of the code buffer
  --- - code_content: the content of the code buffer
  --- - selected_code_content: (optional) the code selected in visual mode, as context
  ---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
  parse_curl_args = function(opts, code_opts)
    local headers = {
      ["Content-Type"] = "application/json",
      ["Authorization"] = "token " .. os.getenv(opts.api_key_name),
    }

    return {
      url = opts.endpoint .. "/.api/completions/stream?api-version=2&client-name=web&client-version=0.0.1",
      timeout = opts.timeout, -- the provider-level timeout, if configured
      insecure = false,
      headers = headers,
      body = {
        model = opts.model,
        temperature = 0,
        topK = -1,
        topP = -1,
        maxTokensToSample = 4000,
        stream = true,
        messages = M.parse_messages(code_opts),
      },
    }
  end,
  ---@type fun(data_stream: string, event_state: string, opts: ResponseParser): nil
  parse_response = function(data_stream, event_state, opts) M.parse_response(data_stream, event_state, opts) end
}

Groq, Perplexity, Deepseek

vendors = {
  ---@type AvanteProvider
  perplexity = {
    endpoint = "https://api.perplexity.ai/chat/completions",
    model = "llama-3.1-sonar-large-128k-online",
    api_key_name = "cmd:bw get notes perplexity-api-key", -- NOTE: "cmd:" keys are resolved by avante itself; the os.getenv call below only works with a plain environment-variable name
    parse_curl_args = function(opts, code_opts)
      return {
        url = opts.endpoint,
        headers = {
          ["Accept"] = "application/json",
          ["Content-Type"] = "application/json",
          ["Authorization"] = "Bearer " .. os.getenv(opts.api_key_name),
        },
        body = {
          model = opts.model,
          messages = { -- you can construct the message list yourself, but this is advanced
            { role = "system", content = code_opts.system_prompt },
            { role = "user", content = require("avante.providers.openai").get_user_message(code_opts) },
          },
          temperature = 0,
          max_tokens = 8192,
          stream = true, -- this will be set by default.
        },
      }
    end,
    -- This function is only needed when the vendor's SSE spec differs from Claude's or OpenAI's.
    parse_response = function(data_stream, event_state, opts)
      require("avante.providers").openai.parse_response(data_stream, event_state, opts)
    end,
  },
  ---@type AvanteProvider
  groq = {
    endpoint = "https://api.groq.com/openai/v1/chat/completions",
    model = "llama-3.1-70b-versatile",
    api_key_name = "GROQ_API_KEY",
    parse_curl_args = function(opts, code_opts)
      return {
        url = opts.endpoint,
        headers = {
          ["Accept"] = "application/json",
          ["Content-Type"] = "application/json",
          ["Authorization"] = "Bearer " .. os.getenv(opts.api_key_name),
        },
        body = {
          model = opts.model,
          messages = { -- you can construct the message list yourself, but this is advanced
            { role = "system", content = code_opts.system_prompt },
            { role = "user", content = require("avante.providers.openai").get_user_message(code_opts) },
          },
          temperature = 0,
          max_tokens = 4096,
          stream = true, -- this will be set by default.
        },
      }
    end,
    parse_response = function(data_stream, event_state, opts)
      require("avante.providers").openai.parse_response(data_stream, event_state, opts)
    end,
  },
  ---@type AvanteProvider
  deepseek = {
    endpoint = "https://api.deepseek.com/chat/completions",
    model = "deepseek-coder",
    api_key_name = "DEEPSEEK_API_KEY",
    parse_curl_args = function(opts, code_opts)
      return {
        url = opts.endpoint,
        headers = {
          ["Accept"] = "application/json",
          ["Content-Type"] = "application/json",
          ["Authorization"] = "Bearer " .. os.getenv(opts.api_key_name),
        },
        body = {
          model = opts.model,
          messages = { -- you can construct the message list yourself, but this is advanced
            { role = "system", content = code_opts.system_prompt },
            { role = "user", content = require("avante.providers.openai").get_user_message(code_opts) },
          },
          temperature = 0,
          max_tokens = 4096,
          stream = true, -- this will be set by default.
        },
      }
    end,
    parse_response = function(data_stream, event_state, opts)
      require("avante.providers").openai.parse_response(data_stream, event_state, opts)
    end,
  },
}
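
To use one of these vendors, point provider at its key. A minimal sketch (assuming the vendors table above is passed to the same require("avante").setup call):

require("avante").setup({
  provider = "groq", -- or "perplexity" / "deepseek"
  vendors = vendors, -- the table defined above
})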

Custom stream-data parser [ADVANCED ONLY]

If a provider does not follow the SSE streaming spec, you may want to implement parse_stream_data for your custom provider.

See the parse_and_call implementation for more information.
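
As a minimal sketch (assuming parse_stream_data receives each raw chunk together with the same on_chunk/on_complete opts that parse_response gets; the newline-delimited JSON wire format and its done/text fields are hypothetical):

---@type AvanteProvider
["ndjson-provider"] = {
  -- endpoint, model, api_key_name, and parse_curl_args as usual ...
  parse_stream_data = function(data, opts)
    -- Decode each newline-delimited JSON object in the chunk.
    for line in data:gmatch("[^\r\n]+") do
      local ok, json = pcall(vim.json.decode, line)
      if ok then
        if json.done then
          opts.on_complete()
        elseif json.text then
          opts.on_chunk(json.text)
        end
      end
    end
  end,
},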

Local LLMs

If you want to use a local LLM that has an OpenAI-compatible server, set ["local"] = true:

       provider = "ollama",
       vendors = {
         ---@type AvanteProvider
         ollama = {
           ["local"] = true,
           endpoint = "127.0.0.1:11434/v1",
           model = "codegemma",
           parse_curl_args = function(opts, code_opts)
             return {
               url = opts.endpoint .. "/chat/completions",
               headers = {
                 ["Accept"] = "application/json",
                 ["Content-Type"] = "application/json",
               },
               body = {
                 model = opts.model,
                 messages = require("avante.providers").copilot.parse_messages(code_opts), -- you can make your own message, but this is very advanced
                 max_tokens = 2048,
                 stream = true,
               },
             }
           end,
           parse_response = function(data_stream, event_state, opts)
             require("avante.providers").copilot.parse_response(data_stream, event_state, opts)
           end,
         },
       },

You are responsible for starting the server yourself (e.g. ollama serve) before using it from Neovim.
