From f0b721c9490ac4c07da703bb6d9eede5f1de0a4a Mon Sep 17 00:00:00 2001 From: James Benze Date: Fri, 23 Aug 2024 19:00:00 -0400 Subject: [PATCH] Initial Commit --- CODE_OF_CONDUCT.md | 74 ++++++++++++ LICENSE.txt | 21 ++++ README.md | 74 ++++++++++++ SECURITY.md | 31 +++++ SUPPORT.md | 11 ++ agent/service.go | 214 +++++++++++++++++++++++++++++++++++ config/info.go | 57 ++++++++++ copilot/endpoints.go | 76 +++++++++++++ copilot/messages.go | 44 +++++++ data/app_configuration.md | 4 + data/payload_verification.md | 3 + data/request_format.md | 18 +++ data/response_format.md | 13 +++ embedding/datasets.go | 88 ++++++++++++++ go.mod | 19 ++++ go.sum | 33 ++++++ main.go | 111 ++++++++++++++++++ oauth/handler.go | 93 +++++++++++++++ 18 files changed, 984 insertions(+) create mode 100644 CODE_OF_CONDUCT.md create mode 100644 LICENSE.txt create mode 100644 README.md create mode 100644 SECURITY.md create mode 100644 SUPPORT.md create mode 100644 agent/service.go create mode 100644 config/info.go create mode 100644 copilot/endpoints.go create mode 100644 copilot/messages.go create mode 100644 data/app_configuration.md create mode 100644 data/payload_verification.md create mode 100644 data/request_format.md create mode 100644 data/response_format.md create mode 100644 embedding/datasets.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 main.go create mode 100644 oauth/handler.go diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..a1f82f0 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,74 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, gender identity and expression, level of experience, +nationality, personal appearance, race, religion, or 
sexual identity and +orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. 
+ +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at [http://contributor-covenant.org/version/1/4][version] + +[homepage]: http://contributor-covenant.org +[version]: http://contributor-covenant.org/version/1/4/ \ No newline at end of file diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..89bc5e9 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright GitHub, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..7e6ff6a --- /dev/null +++ b/README.md @@ -0,0 +1,74 @@ +# RAG Extension Sample + +> [!NOTE] +> To use Copilot Extensions, you must be enrolled in the limited public beta. +> +> All enrolled users with a GitHub Copilot Individual subscription can use Copilot Extensions. +> +> For enrolled organizations or enterprises with a Copilot Business or Copilot Enterprise subscription, organization owners and enterprise administrators can grant access to Copilot Extensions. + +## Description +This project is a Go application that demonstrates how to use retrieval-augmented generation (RAG) in a GitHub Copilot Extension: it embeds the documents under `data/` and adds the best match as context for the chat completion. + +## Prerequisites + +- Go 1.21 or higher +- Set the following environment variables (example below): + +``` +export PORT=8080 +export CLIENT_ID=Iv1.0ae52273ad3193eb // the application id +export CLIENT_SECRET="your_client_secret" // generate a new client secret for your application +export FQDN=https://6de513480979.ngrok.app // use ngrok to expose a url +``` + +## Installation: +1. Clone the repository: + +``` +git clone git@github.com:copilot-extensions/rag-extension.git +cd rag-extension +``` + +2. Install dependencies: + +``` +go mod tidy +``` + +## Usage + +1. Start up ngrok with the port provided: + +``` +ngrok http http://localhost:8080 +``` + +2. Set the environment variables (use the ngrok generated url for the `FQDN`) +3. Run the application: + +``` +go run . +``` + +## Accessing the Agent in Chat: + +1. In the `Copilot` tab of your Application settings (`https://github.com/settings/apps//agent`) +- Set the URL that was set for your FQDN above with the endpoint `/agent` (e.g. 
`https://6de513480979.ngrok.app/agent`) +- Set the Pre-Authorization URL with the endpoint `/auth/authorization` (e.g. `https://6de513480979.ngrok.app/auth/authorization`) +2. In the `General` tab of your application settings (`https://github.com/settings/apps/`) +- Set the `Callback URL` with the `/auth/callback` endpoint (e.g. `https://6de513480979.ngrok.app/auth/callback`) +- Set the `Homepage URL` with the base ngrok endpoint (e.g. `https://6de513480979.ngrok.app`) +3. Ensure your permissions are enabled in `Permissions & events` > +- `Account Permissions` > `Copilot Chat` > `Access: Read Only` +4. Ensure you install your application at (`https://github.com/apps/`) +5. Now if you go to `https://github.com/copilot` you can `@` your agent using the name of your application. + +## What Can It Do + +Test out the agent with the following commands! + +| Description | Prompt | +| --- |--- | +| User asking `@agent` how to configure a Copilot extension | `@agent How do I configure a copilot extension?` | +| User asking `@agent` what a Copilot extension looks like | `@agent What is the response format for a copilot extension?` | diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..4279c87 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,31 @@ +Thanks for helping make GitHub safe for everyone. + +# Security + +GitHub takes the security of our software products and services seriously, including all of the open source code repositories managed through our GitHub organizations, such as [GitHub](https://github.com/GitHub). + +Even though [open source repositories are outside of the scope of our bug bounty program](https://bounty.github.com/index.html#scope) and therefore not eligible for bounty rewards, we will ensure that your finding gets passed along to the appropriate maintainers for remediation. 
+ +## Reporting Security Issues + +If you believe you have found a security vulnerability in any GitHub-owned repository, please report it to us through coordinated disclosure. + +**Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.** + +Instead, please send an email to opensource-security[@]github.com. + +Please include as much of the information listed below as you can to help us better understand and resolve the issue: + + * The type of issue (e.g., buffer overflow, SQL injection, or cross-site scripting) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +## Policy + +See [GitHub's Safe Harbor Policy](https://docs.github.com/en/site-policy/security-policies/github-bug-bounty-program-legal-safe-harbor#1-safe-harbor-terms) \ No newline at end of file diff --git a/SUPPORT.md b/SUPPORT.md new file mode 100644 index 0000000..80273bd --- /dev/null +++ b/SUPPORT.md @@ -0,0 +1,11 @@ +# Support + +## How to file issues and get help + +This project uses GitHub issues to track bugs and feature requests. Please search the existing issues before filing new issues to avoid duplicates. For new issues, file your bug or feature request as a new issue. + +- **THIS PROJECT NAME** is under active development and maintained by GitHub staff **AND THE COMMUNITY**. We will do our best to respond to support, feature requests, and community questions in a timely manner. + +## GitHub Support Policy + +Support for this project is limited to the resources listed above. 
\ No newline at end of file diff --git a/agent/service.go b/agent/service.go new file mode 100644 index 0000000..140970a --- /dev/null +++ b/agent/service.go @@ -0,0 +1,214 @@ +package agent + +import ( + "bufio" + "context" + "crypto/ecdsa" + "crypto/sha256" + "encoding/asn1" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "math/big" + "net/http" + "os" + "path/filepath" + "sync" + + "github.com/copilot-extensions/rag-extension/copilot" + "github.com/copilot-extensions/rag-extension/embedding" +) + +// Service provides and endpoint for this agent to perform chat completions +type Service struct { + pubKey *ecdsa.PublicKey + + // Singleton + datasets []*embedding.Dataset + datasetsInit *sync.Once +} + +func NewService(pubKey *ecdsa.PublicKey) *Service { + return &Service{ + pubKey: pubKey, + datasetsInit: &sync.Once{}, + } +} + +func (s *Service) ChatCompletion(w http.ResponseWriter, r *http.Request) { + sig := r.Header.Get("Github-Public-Key-Signature") + + body, err := io.ReadAll(r.Body) + if err != nil { + fmt.Println(fmt.Errorf("failed to read request body: %w", err)) + w.WriteHeader(http.StatusInternalServerError) + return + } + + // Make sure the payload matches the signature. 
In this way, you can be sure + // that an incoming request comes from github + isValid, err := validPayload(body, sig, s.pubKey) + if err != nil { + fmt.Printf("failed to validate payload signature: %v\n", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if !isValid { + http.Error(w, "invalid payload signature", http.StatusUnauthorized) + return + } + + apiToken := r.Header.Get("X-GitHub-Token") + integrationID := r.Header.Get("Copilot-Integration-Id") + + var req *copilot.ChatRequest + if err := json.Unmarshal(body, &req); err != nil { + fmt.Printf("failed to unmarshal request: %v\n", err) + w.WriteHeader(http.StatusBadRequest) + return + } + if err := s.generateCompletion(r.Context(), integrationID, apiToken, req, w); err != nil { + fmt.Printf("failed to execute agent: %v\n", err) + w.WriteHeader(http.StatusInternalServerError) + } +} + +func (s *Service) generateCompletion(ctx context.Context, integrationID, apiToken string, req *copilot.ChatRequest, w io.Writer) error { + // Initialize the datasets. 
In a real application, these would be generated + // ahead of time and stored in a database + var err error + s.datasetsInit.Do(func() { + var files []fs.DirEntry + files, err = os.ReadDir("data") + if err != nil { + err = fmt.Errorf("error reading files from \"data\" directory: %w", err) + return + } + + filenames := make([]string, len(files)) + for i, file := range files { + filenames[i] = filepath.Join("data", file.Name()) + } + + s.datasets, err = embedding.GenerateDatasets(integrationID, apiToken, filenames) + if err != nil { + err = fmt.Errorf("error generating datasets: %w", err) + return + } + }) + if err != nil { + return err + } + + var messages []copilot.ChatMessage + + // Create embeddings from user messages + for i := len(req.Messages) - 1; i >= 0; i++ { + msg := req.Messages[i] + if msg.Role != "user" { + continue + } + + // Filter empty messages + if msg.Content == "" { + continue + } + + emb, err := embedding.Create(ctx, integrationID, apiToken, msg.Content) + if err != nil { + return fmt.Errorf("error creating embedding for user message: %w", err) + } + + // Load most appropriate dataset + dataset, err := embedding.FindBestDataset(s.datasets, emb) + if err != nil { + return fmt.Errorf("error computing best dataset") + } + + if dataset == nil { + break + } + + fmt.Printf("loading dataset: %s\n", dataset.Filename) + + file, err := os.Open(dataset.Filename) + if err != nil { + return fmt.Errorf("failed to open documents: %w", err) + } + + fileContents, err := io.ReadAll(file) + if err != nil { + return fmt.Errorf("failed to read documents: %w", err) + } + + messages = append(messages, copilot.ChatMessage{ + Role: "system", + Content: "You are a helpful assistant that replies to user messages. Use the following context when responding to a message.\n" + + "Context: " + string(fileContents), + }) + + break + } + + messages = append(messages, req.Messages...) 
// asn1Signature is a struct for ASN.1 serializing/parsing signatures.
type asn1Signature struct {
	R *big.Int
	S *big.Int
}

// validPayload reports whether sig is a valid signature of data under
// publicKey. sig is the base64 (standard encoding) form of an ASN.1 sequence
// holding the two ECDSA integers; the signed digest is SHA-256 of data.
//
// It returns (false, nil) when the signature is absent, carries trailing
// bytes, or simply does not verify. An error is returned only when the
// signature cannot be decoded at all or when no public key is available.
func validPayload(data []byte, sig string, publicKey *ecdsa.PublicKey) (bool, error) {
	if publicKey == nil {
		// ecdsa.Verify would dereference the nil key.
		return false, errors.New("no public key available to verify payload")
	}

	// An unsigned request is an invalid request, not a server failure.
	if sig == "" {
		return false, nil
	}

	asnSig, err := base64.StdEncoding.DecodeString(sig)
	if err != nil {
		return false, fmt.Errorf("failed to decode signature: %w", err)
	}

	var parsedSig asn1Signature
	rest, err := asn1.Unmarshal(asnSig, &parsedSig)
	if err != nil {
		return false, fmt.Errorf("failed to parse signature: %w", err)
	}
	// Bytes after the sequence mean this is not the signature that was issued.
	if len(rest) != 0 {
		return false, nil
	}

	// Verify the SHA256 encoded payload against the signature with GitHub's Key
	digest := sha256.Sum256(data)
	return ecdsa.Verify(publicKey, digest[:], parsedSig.R, parsedSig.S), nil
}
// Info holds the settings the application reads from its environment.
type Info struct {
	// Port is the local port on which the application will run
	Port string

	// FQDN (for Fully-Qualified Domain Name) is the internet facing host address
	// where application will live (e.g. https://example.com)
	FQDN string

	// ClientID comes from your configured GitHub app
	ClientID string

	// ClientSecret comes from your configured GitHub app
	ClientSecret string
}

// Names of the environment variables read by New.
const (
	portEnv         = "PORT"
	clientIDEnv     = "CLIENT_ID"
	clientSecretEnv = "CLIENT_SECRET"
	fqdnEnv         = "FQDN"
)

// New builds an Info from the environment. Every variable is mandatory; the
// returned error names the first one (checked in the order PORT, FQDN,
// CLIENT_ID, CLIENT_SECRET) that is unset or empty.
func New() (*Info, error) {
	port, err := requiredEnv(portEnv)
	if err != nil {
		return nil, err
	}

	fqdn, err := requiredEnv(fqdnEnv)
	if err != nil {
		return nil, err
	}

	clientID, err := requiredEnv(clientIDEnv)
	if err != nil {
		return nil, err
	}

	clientSecret, err := requiredEnv(clientSecretEnv)
	if err != nil {
		return nil, err
	}

	return &Info{
		Port:         port,
		FQDN:         fqdn,
		ClientID:     clientID,
		ClientSecret: clientSecret,
	}, nil
}

// requiredEnv returns the value of the named environment variable, or an error
// if it is unset or empty.
func requiredEnv(name string) (string, error) {
	value := os.Getenv(name)
	if value == "" {
		return "", fmt.Errorf("%s environment variable required", name)
	}
	return value, nil
}
// ChatCompletions posts req to the Copilot chat completions endpoint and
// returns the response body for the caller to read. The caller must close the
// returned body. integrationID is sent as the Copilot-Integration-Id header
// when non-empty; apiKey is sent as a bearer token.
func ChatCompletions(ctx context.Context, integrationID, apiKey string, req *ChatCompletionsRequest) (io.ReadCloser, error) {
	resp, err := postJSON(ctx, "https://api.githubcopilot.com/chat/completions", integrationID, apiKey, req)
	if err != nil {
		return nil, err
	}

	return resp.Body, nil
}

// Embeddings posts req to the Copilot embeddings endpoint and returns the
// decoded response. The result is never nil when the error is nil.
func Embeddings(ctx context.Context, integrationID string, token string, req *EmbeddingsRequest) (*EmbeddingsResponse, error) {
	resp, err := postJSON(ctx, "https://api.githubcopilot.com/embeddings", integrationID, token, req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// Decode into a value so that a literal `null` body cannot produce a nil
	// response alongside a nil error.
	var embeddingsResponse EmbeddingsResponse
	if err := json.NewDecoder(resp.Body).Decode(&embeddingsResponse); err != nil {
		return nil, fmt.Errorf("failed to decode response body: %w", err)
	}

	return &embeddingsResponse, nil
}

// postJSON sends payload as a JSON POST to url with the Copilot headers set
// and returns the response when the status is 200 OK. For any other status the
// body is closed and the start of it is included in the returned error.
func postJSON(ctx context.Context, url, integrationID, token string, payload interface{}) (*http.Response, error) {
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Accept", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+token)
	if integrationID != "" {
		httpReq.Header.Set("Copilot-Integration-Id", integrationID)
	}

	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		// Best effort: the body is only used to make the error more useful,
		// so a read failure here is deliberately ignored.
		b, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return nil, fmt.Errorf("unexpected status code: %d: %s", resp.StatusCode, b)
	}

	return resp, nil
}

// ChatRequest is the payload an agent receives from the Copilot platform.
type ChatRequest struct {
	Messages []ChatMessage `json:"messages"`
}

// ChatMessage is a single message of a conversation.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// Model is the name of a model as sent in the "model" field of a request.
type Model string

const (
	ModelGPT35      Model = "gpt-3.5-turbo"
	ModelGPT4       Model = "gpt-4"
	ModelEmbeddings Model = "text-embedding-ada-002"
)

// ChatCompletionsRequest is the body sent by ChatCompletions.
type ChatCompletionsRequest struct {
	Messages []ChatMessage `json:"messages"`
	Model    Model         `json:"model"`
	Stream   bool          `json:"stream"`
}

// EmbeddingsRequest is the body sent by Embeddings.
type EmbeddingsRequest struct {
	Model Model    `json:"model"`
	Input []string `json:"input"`
}

// EmbeddingsResponse is the decoded reply of the embeddings endpoint.
type EmbeddingsResponse struct {
	Data  []*EmbeddingsResponseData `json:"data"`
	Usage *EmbeddingsResponseUsage  `json:"usage"`
}

// EmbeddingsResponseData is one entry of the "data" array of the reply.
type EmbeddingsResponseData struct {
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// EmbeddingsResponseUsage is the "usage" object of the reply.
type EmbeddingsResponseUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
+ +To verify the signature, you can compare the signature provided in the Github-Public-Key-Signature header with a signed copy of the request body, using a public key found at https://api.github.com/meta/public_keys/copilot_api. The public signature verification docs is a great reference for how to do this and provides a few language examples. Below is a simple example in typescript. diff --git a/data/request_format.md b/data/request_format.md new file mode 100644 index 0000000..f9450f0 --- /dev/null +++ b/data/request_format.md @@ -0,0 +1,18 @@ +Requests to your agent will be `application/json` formatted. An example curl request to your agent may look like this: + +``` +curl --request POST \ + --url $AGENT_URL \ + --header 'Accept: application/json' \ + --header 'Content-Type: application/json' \ + --header "X-GitHub-Token: $RUNTIME_GENERATED_TOKEN" \ + --data '{ + "messages": [ + { + "role": "user", + "content": "What is a closure in javascript?", + "copilot_references": [] + } + ] + }' +``` diff --git a/data/response_format.md b/data/response_format.md new file mode 100644 index 0000000..bea087e --- /dev/null +++ b/data/response_format.md @@ -0,0 +1,13 @@ +Responses back to the copilot platform should be in the form of server-side-events streamed from your agent. Some example events that could be streamed back are: + +``` +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]} + +.... 
+ +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + +data: [DONE] +``` diff --git a/embedding/datasets.go b/embedding/datasets.go new file mode 100644 index 0000000..ac35f81 --- /dev/null +++ b/embedding/datasets.go @@ -0,0 +1,88 @@ +package embedding + +import ( + "context" + "fmt" + "io" + "math" + "os" + + "github.com/copilot-extensions/rag-extension/copilot" +) + +func Create(ctx context.Context, integrationID, apiToken string, content string) ([]float32, error) { + resp, err := copilot.Embeddings(ctx, integrationID, apiToken, &copilot.EmbeddingsRequest{ + Model: copilot.ModelEmbeddings, + Input: []string{content}, + }) + + if err != nil { + return nil, fmt.Errorf("error fetching embeddings: %w", err) + } + + for _, data := range resp.Data { + return data.Embedding, nil + } + + return nil, fmt.Errorf("no embeddings found in response") +} + +type Dataset struct { + Embedding []float32 + Filename string +} + +func GenerateDatasets(integrationID, apiToken string, filenames []string) ([]*Dataset, error) { + datasets := make([]*Dataset, len(filenames)) + for i, filename := range filenames { + file, err := os.Open(filename) + if err != nil { + return nil, fmt.Errorf("error reading in file %s: %w", filename, err) + } + + fileContent, err := io.ReadAll(file) + + embedding, err := Create(context.Background(), integrationID, apiToken, string(fileContent)) + if err != nil { + return nil, fmt.Errorf("error creating embedding for file %s: %w", filename, err) + } + + datasets[i] = &Dataset{ + Embedding: embedding, + Filename: filename, + } + } + + return datasets, nil +} + +func FindBestDataset(datasets []*Dataset, target []float32) (*Dataset, error) { + var bestDataset *Dataset + var bestScore float32 + + var targetMagnitude float32 + for i := 0; i < len(target); i++ { + targetMagnitude += target[i] 
* target[i] + } + + for _, dataset := range datasets { + // Score similarity using Cosine Similarity + if len(target) != len(dataset.Embedding) { + return nil, fmt.Errorf("embeddings are different length, cannot compare") + } + + var docMagnitude, dotProduct float32 + for i := 0; i < len(target); i++ { + docMagnitude += dataset.Embedding[i] * dataset.Embedding[i] + dotProduct += target[i] * dataset.Embedding[i] + } + + dotProduct /= float32(math.Sqrt(float64(targetMagnitude)) * math.Sqrt(float64(docMagnitude))) + if dotProduct > bestScore { + bestDataset = dataset + bestScore = dotProduct + } + } + + return bestDataset, nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..22872e7 --- /dev/null +++ b/go.mod @@ -0,0 +1,19 @@ +module github.com/copilot-extensions/rag-extension + +go 1.21.6 + +require ( + github.com/google/go-github/v57 v57.0.0 + github.com/google/uuid v1.6.0 + github.com/invopop/jsonschema v0.12.0 + github.com/wk8/go-ordered-map/v2 v2.1.8 + golang.org/x/oauth2 v0.22.0 +) + +require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/google/go-querystring v1.1.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7ab7f98 --- /dev/null +++ b/go.sum @@ -0,0 +1,33 @@ +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.5.2/go.mod 
h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-github/v57 v57.0.0 h1:L+Y3UPTY8ALM8x+TV0lg+IEBI+upibemtBD8Q9u7zHs= +github.com/google/go-github/v57 v57.0.0/go.mod h1:s0omdnye0hvK/ecLvpsGfJMiRt85PimQh4oygmLIxHw= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI= +github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +golang.org/x/oauth2 v0.22.0 h1:BzDx2FehcG7jJwgWLELCdmLuxk2i+x9UDpSiss2u0ZA= +golang.org/x/oauth2 v0.22.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod 
h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go new file mode 100644 index 0000000..84218f1 --- /dev/null +++ b/main.go @@ -0,0 +1,111 @@ +package main + +import ( + "crypto/ecdsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "net/http" + "net/url" + "os" + "strings" + + "github.com/copilot-extensions/rag-extension/agent" + "github.com/copilot-extensions/rag-extension/config" + "github.com/copilot-extensions/rag-extension/oauth" +) + +func main() { + if err := run(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func run() error { + pubKey, err := fetchPublicKey() + if err != nil { + return fmt.Errorf("failed to fetch public key: %w", err) + } + + config, err := config.New() + if err != nil { + return fmt.Errorf("error fetching config: %w", err) + } + + me, err := url.Parse(config.FQDN) + if err != nil { + return fmt.Errorf("unable to parse HOST environment variable: %w", err) + } + + me.Path = "auth/callback" + + oauthService := oauth.NewService(config.ClientID, config.ClientSecret, me.String()) + http.HandleFunc("/auth/authorization", oauthService.PreAuth) + http.HandleFunc("/auth/callback", oauthService.PostAuth) + + agentService := agent.NewService(pubKey) + + http.HandleFunc("/agent", agentService.ChatCompletion) + + fmt.Println("Listening on port", config.Port) + return http.ListenAndServe(":"+config.Port, nil) +} + +// fetchPublicKey fetches the keys used to sign messages from copilot. Checking +// the signature with one of these keys verifies that the request to the +// completions API comes from GitHub and not elsewhere on the internet. 
func fetchPublicKey() (*ecdsa.PublicKey, error) {
	// NOTE(review): http.Get uses http.DefaultClient, which has no timeout, so
	// a stalled connection blocks startup indefinitely. Consider a dedicated
	// http.Client with an explicit Timeout.
	resp, err := http.Get("https://api.github.com/meta/public_keys/copilot_api")
	if err != nil {
		return nil, fmt.Errorf("requesting GitHub public keys: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("requesting GitHub public keys: unexpected status %s", resp.Status)
	}

	var respBody struct {
		PublicKeys []struct {
			Key       string `json:"key"`
			IsCurrent bool   `json:"is_current"`
		} `json:"public_keys"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
		return nil, fmt.Errorf("decoding GitHub public keys response: %w", err)
	}

	// Use the first entry flagged as current that actually carries key
	// material.
	for _, pk := range respBody.PublicKeys {
		if pk.IsCurrent && pk.Key != "" {
			return parseECDSAPublicKey(pk.Key)
		}
	}
	return nil, fmt.Errorf("could not find current public key")
}

// parseECDSAPublicKey decodes a PEM-encoded PKIX public key and returns it as
// an ECDSA public key. Literal backslash-n sequences in rawKey are turned into
// real newlines before PEM decoding. It returns an error when no PEM block is
// found, when the key bytes cannot be parsed, or when the key is not ECDSA.
func parseECDSAPublicKey(rawKey string) (*ecdsa.PublicKey, error) {
	pemText := strings.ReplaceAll(rawKey, "\\n", "\n")

	block, _ := pem.Decode([]byte(pemText))
	if block == nil {
		return nil, fmt.Errorf("error parsing PEM block with GitHub public key")
	}

	key, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("parsing GitHub public key: %w", err)
	}

	// ParsePKIXPublicKey can return several key types; only ECDSA is usable
	// by the caller.
	ecdsaKey, ok := key.(*ecdsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("GitHub key is not ECDSA (got %T)", key)
	}
	return ecdsaKey, nil
}

// Service provides endpoints to allow this agent to be authorized.
+type Service struct { + conf *oauth2.Config +} + +func NewService(clientID, clientSecret, callback string) *Service { + return &Service{ + conf: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: callback, + Endpoint: oauth2.Endpoint{ + AuthURL: "https://github.com/login/oauth/authorize", + TokenURL: "https://github.com/login/oauth/access_token", + }, + }, + } +} + +const ( + STATE_COOKIE = "oauth_state" +) + +// PreAuth is the landing page that the user arrives at when they first attempt +// to use the agent while unauthorized. You can do anything you want here, +// including making sure the user has an account on your side. At some point, +// you'll probably want to make a call to the authorize endpoint to authorize +// the app. +func (s *Service) PreAuth(w http.ResponseWriter, r *http.Request) { + // In our example, we're not doing anything except going through the + // authorization flow. This is standard Oauth2. + + verifier := oauth2.GenerateVerifier() + state := uuid.New() + + url := s.conf.AuthCodeURL(state.String(), oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier)) + stateCookie := &http.Cookie{ + Name: STATE_COOKIE, + Value: state.String(), + MaxAge: 10 * 60, // 10 minutes in seconds + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + } + + http.SetCookie(w, stateCookie) + w.Header().Set("location", url) + w.WriteHeader(http.StatusFound) +} + +// PostAuth is the landing page where the user lads after authorizing. As +// above, you can do anything you want here. A common thing you might do is +// get the user information and then perform some sort of account linking in +// your database. 
+func (s *Service) PostAuth(w http.ResponseWriter, r *http.Request) { + state := r.URL.Query().Get("state") + code := r.URL.Query().Get("code") + + stateCookie, err := r.Cookie(STATE_COOKIE) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("state cookie not found")) + return + } + + // Important: Compare the state! This prevents CSRF attacks + if state != stateCookie.Value { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("invalid state")) + return + } + + _, err = s.conf.Exchange(r.Context(), code) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("error exchange code for token: %v", err))) + return + } + + // Response contains an access token, now the world is your oyster. Get user information and perform account linking, or do whatever you want from here. + + w.WriteHeader(http.StatusOK) + w.Write([]byte("All done! Please return to the app")) +}