diff --git a/internal/copilot/copilot.go b/internal/copilot/copilot.go index 59200323..5f37daa0 100644 --- a/internal/copilot/copilot.go +++ b/internal/copilot/copilot.go @@ -7,9 +7,11 @@ import ( "fmt" "io" "net/http" + "net/url" "os" "path/filepath" "runtime" + "strconv" "strings" "time" @@ -17,9 +19,15 @@ import ( ) const ( - copilotChatAuthURL = "https://api.github.com/copilot_internal/v2/token" - copilotEditorVersion = "vscode/1.95.3" - copilotUserAgent = "curl/7.81.0" // Necessay to bypass the user-agent check + copilotAuthDeviceCodeURL = "https://github.com/login/device/code" + copilotAuthTokenURL = "https://github.com/login/oauth/access_token" // #nosec G101 + copilotChatAuthURL = "https://api.github.com/copilot_internal/v2/token" + copilotEditorVersion = "vscode/1.95.3" + copilotUserAgent = "curl/7.81.0" // Necessay to bypass the user-agent check + + // if you change this, don't forget to update the + // `OAuthToken` json struct tag + copilotClientID = "Iv1.b507a08c87ecfe98" ) // AccessToken response from GitHub Copilot's token endpoint. @@ -40,6 +48,36 @@ type AccessToken struct { } `json:"error_details,omitempty"` } +type DeviceCodeResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +type DeviceTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + Error string `json:"error,omitempty"` +} + +type FailedRequestResponse struct { + DocumentationURL string `json:"documentation_url"` + Message string `json:"message"` +} + +type OAuthTokenWrapper struct { + User string `json:"user"` + OAuthToken string `json:"oauth_token"` + GithubAppID string `json:"githubAppId"` +} + +type OAuthToken struct { + GithubWrapper OAuthTokenWrapper `json:"github.com:Iv1.b507a08c87ecfe98"` +} + // Client copilot client. type Client struct { client *http.Client @@ -83,30 +121,235 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { return httpResp, nil } -func getCopilotRefreshToken() (string, error) { +func Login(client *http.Client, configPath string) (string, error) { + data := strings.NewReader(fmt.Sprintf("client_id=%s&scope=copilot", copilotClientID)) + req, err := http.NewRequest("POST", copilotAuthDeviceCodeURL, data) + if err != nil { + return "", err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to get device code: %w", err) + } + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to decode device code response: %w", err) + } + + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil && err == nil { + err = fmt.Errorf("error closing response body: %w", closeErr) + } + }() + + deviceCodeResp := DeviceCodeResponse{} + + parsedData, err := url.ParseQuery(string(responseBody)) + if err != nil { + return "", fmt.Errorf("failed to parse device code response: %w", err) + } + + deviceCodeResp.UserCode = parsedData.Get("user_code") + deviceCodeResp.ExpiresIn, _ = strconv.Atoi(parsedData.Get("expires_in")) + deviceCodeResp.Interval, _ = strconv.Atoi(parsedData.Get("interval")) + deviceCodeResp.DeviceCode = parsedData.Get("device_code") + deviceCodeResp.VerificationURI = parsedData.Get("verification_uri") + + fmt.Printf("Please go to %s and enter the code %s\n", deviceCodeResp.VerificationURI, deviceCodeResp.UserCode) + oAuthToken, err := fetchRefreshToken(client, deviceCodeResp.DeviceCode, deviceCodeResp.Interval, deviceCodeResp.ExpiresIn) + + if err != nil { + return "", err + } + + err = saveOAuthToken( + OAuthToken{ + GithubWrapper: OAuthTokenWrapper{ + User: "", + OAuthToken: oAuthToken.AccessToken, + GithubAppID: copilotClientID, + }, + }, + configPath, + ) + + if err != nil { + return "", err + } + + return oAuthToken.AccessToken, nil +} + +func fetchRefreshToken(client *http.Client, deviceCode string, interval int, expiresIn int) (DeviceTokenResponse, error) { + var accessTokenResp DeviceTokenResponse + var errResp FailedRequestResponse + + // Adds a delay to give the user time to open + // the browser and type the code + time.Sleep(30 * time.Second) + + endTime := time.Now().Add(time.Duration(expiresIn) * time.Second) + ticker := time.NewTicker(time.Duration(interval) * time.Second) + + defer ticker.Stop() + + for range ticker.C { + if time.Now().After(endTime) { + return DeviceTokenResponse{}, fmt.Errorf("authorization polling timeout") + } + + fmt.Println("Trying to fetch token...") + data := strings.NewReader( + fmt.Sprintf( + "client_id=%s&device_code=%s&grant_type=urn:ietf:params:oauth:grant-type:device_code", + copilotClientID, + deviceCode, + ), + ) + req, err := http.NewRequest("POST", copilotAuthTokenURL, data) + if err != nil { + return DeviceTokenResponse{}, err + } + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return DeviceTokenResponse{}, err + } + + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil && err == nil { + err = fmt.Errorf("error closing response body: %w", closeErr) + } + }() + + isRequestFailed := resp.StatusCode != 200 + + if isRequestFailed { + if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil { + return DeviceTokenResponse{}, err + } + + return DeviceTokenResponse{}, fmt.Errorf( + "failed to check refresh token\n\tMessage: %s\n\tDocumentation: %s", + errResp.Message, + errResp.DocumentationURL, + ) + } + + if err := json.NewDecoder(resp.Body).Decode(&accessTokenResp); err != nil { + return DeviceTokenResponse{}, err + } + + if accessTokenResp.AccessToken != "" { + return accessTokenResp, nil + } + + if accessTokenResp.Error != "" { + // Handle errors like "authorization_pending" or "expired_token" appropriately + if accessTokenResp.Error != "authorization_pending" { + return DeviceTokenResponse{}, fmt.Errorf("token error: %s", accessTokenResp.Error) + } + } + } + + return DeviceTokenResponse{}, fmt.Errorf("authorization polling failed or timed out") +} + +// Registers `mods` as an application that uses copilot +// NOTE: Only if initial config not available. +// TODO: Add support for when the user already has an oAuthToken +func registerApp(versionsPath string) error { + versions := make(map[string]string) + + data, err := os.ReadFile(versionsPath) + if err == nil { + // File exists, unmarshal contents + if err := json.Unmarshal(data, &versions); err != nil { + return fmt.Errorf("error parsing versions file: %w", err) + } + } + + // Add/update our entry + // TODO: How can we import this? Create a `meta.go`? + //versions["mods"] = main.Version + + updatedData, err := json.Marshal(versions) + if err != nil { + return fmt.Errorf("error marshaling versions data: %w", err) + } + + return os.WriteFile(versionsPath, updatedData, 0o640) +} + +func saveOAuthToken(oAuthToken OAuthToken, configPath string) error { + fileContent, err := json.Marshal(oAuthToken) + + if err != nil { + return fmt.Errorf("error mashaling oAuthToken: %e", err) + } + + configDir := filepath.Dir(configPath) + if err = os.MkdirAll(configDir, 0o700); err != nil { + return fmt.Errorf("error creating config directory: %e", err) + } + + err = os.WriteFile(configPath, fileContent, 0o700) + if err != nil { + return fmt.Errorf("error writing oAuthToken to %s: %e", configPath, err) + } + + versionsPath := filepath.Join(filepath.Dir(configPath), "versions.json") + err = registerApp(versionsPath) + if err != nil { + return fmt.Errorf("error registering mods as copilot app %e", err) + } + + return nil +} + +func getOAuthToken(client *http.Client) (string, error) { configPath := filepath.Join(os.Getenv("HOME"), ".config/github-copilot") if runtime.GOOS == "windows" { configPath = filepath.Join(os.Getenv("LOCALAPPDATA"), "github-copilot") } + // Support both legacy and current config file locations + legacyConfigPath := filepath.Join(configPath, "hosts.json") + currentConfigPath := filepath.Join(configPath, "apps.json") + // Check both possible config file locations configFiles := []string{ - filepath.Join(configPath, "hosts.json"), - filepath.Join(configPath, "apps.json"), + legacyConfigPath, + currentConfigPath, } // Try to get token from config files for _, path := range configFiles { - token, err := extractCopilotTokenFromFile(path) + token, err := extractTokenFromFile(path) if err == nil && token != "" { return token, nil } } - return "", fmt.Errorf("no token found in %s", strings.Join(configFiles, ", ")) + // Try to login in into Copilot + token, err := Login(client, currentConfigPath) + if err != nil { + return "", fmt.Errorf("failed to login into Copilot: %w", err) + } + + if token != "" { + return token, nil + } + + return "", fmt.Errorf("empty token") } -func extractCopilotTokenFromFile(path string) (string, error) { +func extractTokenFromFile(path string) (string, error) { bytes, err := os.ReadFile(path) if err != nil { return "", fmt.Errorf("failed to read Copilot configuration file at %s: %w", path, err) @@ -145,9 +388,9 @@ func (c *Client) Auth() (AccessToken, error) { } } - refreshToken, err := getCopilotRefreshToken() + refreshToken, err := getOAuthToken(c.client) if err != nil { - return AccessToken{}, fmt.Errorf("failed to get refresh token: %w", err) + return AccessToken{}, fmt.Errorf("failed to get oAuth token: %w", err) } tokenReq, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, copilotChatAuthURL, nil)