diff --git a/credentials/providers/cli_profile.go b/credentials/providers/cli_profile.go index 5bf503a..fc5e378 100644 --- a/credentials/providers/cli_profile.go +++ b/credentials/providers/cli_profile.go @@ -309,6 +309,24 @@ func (provider *CLIProfileCredentialsProvider) GetProviderName() string { return "cli_profile" } +// findSourceOAuthProfile 递归查找 OAuth source profile +func (conf *configuration) findSourceOAuthProfile(profileName string) (*profile, error) { + profile, err := conf.getProfile(profileName) + if err != nil { + return nil, fmt.Errorf("unable to get profile with name '%s' from cli credentials file: %v", profileName, err) + } + + if profile.Mode == "OAuth" { + return profile, nil + } + + if profile.SourceProfile != "" { + return conf.findSourceOAuthProfile(profile.SourceProfile) + } + + return nil, fmt.Errorf("unable to get OAuth profile with name '%s' from cli credentials file", profileName) +} + // updateOAuthTokens 更新OAuth令牌并写回配置文件 func (provider *CLIProfileCredentialsProvider) updateOAuthTokens(refreshToken, accessToken, accessKey, secret, securityToken string, accessTokenExpire, stsExpire int64) error { provider.fileMutex.Lock() @@ -321,19 +339,28 @@ func (provider *CLIProfileCredentialsProvider) updateOAuthTokens(refreshToken, a } profileName := provider.profileName - profile, err := conf.getProfile(profileName) + if profileName == "" { + profileName = conf.Current + } + if profileName == "" { + return fmt.Errorf("unable to get profile to update") + } + + // 递归查找真正的 OAuth source profile + sourceProfile, err := conf.findSourceOAuthProfile(profileName) if err != nil { - return fmt.Errorf("failed to get profile %s: %v", profileName, err) + return fmt.Errorf("failed to find OAuth source profile: %v", err) } - // update - profile.OauthRefreshToken = refreshToken - profile.OauthAccessToken = accessToken - profile.OauthAccessTokenExpire = accessTokenExpire - profile.AccessKeyID = accessKey - profile.AccessKeySecret = secret - profile.SecurityToken = securityToken - profile.StsExpire = stsExpire + // update OAuth tokens + sourceProfile.OauthRefreshToken = refreshToken + sourceProfile.OauthAccessToken = accessToken + sourceProfile.OauthAccessTokenExpire = accessTokenExpire + // update STS credentials + sourceProfile.AccessKeyID = accessKey + sourceProfile.AccessKeySecret = secret + sourceProfile.SecurityToken = securityToken + sourceProfile.StsExpire = stsExpire // write back with file lock return provider.writeConfigurationToFileWithLock(cfgPath, conf) diff --git a/credentials/providers/cli_profile_test.go b/credentials/providers/cli_profile_test.go index 3284d75..20e5093 100644 --- a/credentials/providers/cli_profile_test.go +++ b/credentials/providers/cli_profile_test.go @@ -152,7 +152,7 @@ func TestCLIProfileCredentialsProvider_getCredentialsProvider(t *testing.T) { Mode: "CloudSSO", SignInUrl: "url", AccessToken: "token", - AccessTokenExpire: time.Now().Unix() + 1000, + AccessTokenExpire: time.Now().Unix() + 2000, AccessConfig: "config", AccountId: "uid", }, @@ -162,7 +162,7 @@ func TestCLIProfileCredentialsProvider_getCredentialsProvider(t *testing.T) { OauthSiteType: "CN", OauthRefreshToken: "refresh_token", OauthAccessToken: "access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, { Mode: "Unsupported", @@ -249,7 +249,7 @@ func TestCLIProfileCredentialsProvider_OAuthProfile(t *testing.T) { OauthSiteType: "CN", OauthRefreshToken: "refresh_token", OauthAccessToken: "access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, { Mode: "OAuth", @@ -257,7 +257,7 @@ func TestCLIProfileCredentialsProvider_OAuthProfile(t *testing.T) { OauthSiteType: "INTL", OauthRefreshToken: "refresh_token", OauthAccessToken: "access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, { Mode: "OAuth", @@ -265,7 +265,7 @@ func TestCLIProfileCredentialsProvider_OAuthProfile(t *testing.T) { OauthSiteType: "INVALID", OauthRefreshToken: "refresh_token", OauthAccessToken: "access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -313,7 +313,7 @@ func TestCLIProfileCredentialsProvider_updateOAuthTokens(t *testing.T) { OauthSiteType: "CN", OauthRefreshToken: "old_refresh_token", OauthAccessToken: "old_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -375,7 +375,7 @@ func TestCLIProfileCredentialsProvider_writeConfigurationToFile(t *testing.T) { OauthSiteType: "CN", OauthRefreshToken: "test_refresh_token", OauthAccessToken: "test_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -429,7 +429,7 @@ func TestCLIProfileCredentialsProvider_writeConfigurationToFile_Error(t *testing OauthSiteType: "CN", OauthRefreshToken: "test_refresh_token", OauthAccessToken: "test_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -464,7 +464,7 @@ func TestCLIProfileCredentialsProvider_writeConfigurationToFileWithLock(t *testi OauthSiteType: "CN", OauthRefreshToken: "test_refresh_token", OauthAccessToken: "test_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -518,7 +518,7 @@ func TestCLIProfileCredentialsProvider_writeConfigurationToFileWithLock_Error(t OauthSiteType: "CN", OauthRefreshToken: "test_refresh_token", OauthAccessToken: "test_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -553,7 +553,7 @@ func TestCLIProfileCredentialsProvider_getOAuthTokenUpdateCallback(t *testing.T) OauthSiteType: "CN", OauthRefreshToken: "test_refresh_token", OauthAccessToken: "test_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -647,7 +647,7 @@ func TestCLIProfileCredentialsProvider_updateOAuthTokens_ProfileNotFound(t *test OauthSiteType: "CN", OauthRefreshToken: "test_refresh_token", OauthAccessToken: "test_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, // 改为 2000秒(>1200秒) }, }, } @@ -676,7 +676,7 @@ func TestCLIProfileCredentialsProvider_updateOAuthTokens_ProfileNotFound(t *test err = provider.updateOAuthTokens(newRefreshToken, newAccessToken, newAccessKey, newSecret, newSecurityToken, newExpireTime, newStsExpire) assert.NotNil(t, err) - assert.Contains(t, err.Error(), "failed to get profile NonExistentProfile") + assert.Contains(t, err.Error(), "failed to find OAuth source profile") } func TestCLIProfileCredentialsProvider_ConcurrentUpdate(t *testing.T) { @@ -696,7 +696,7 @@ func TestCLIProfileCredentialsProvider_ConcurrentUpdate(t *testing.T) { OauthSiteType: "CN", OauthRefreshToken: "initial_refresh_token", OauthAccessToken: "initial_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -766,7 +766,7 @@ func TestCLIProfileCredentialsProvider_FileLock(t *testing.T) { OauthSiteType: "CN", OauthRefreshToken: "initial_refresh_token", OauthAccessToken: "initial_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -919,7 +919,7 @@ func TestCLIProfileCredentialsProvider_writeConfigurationToFile_RenameError(t *t OauthSiteType: "CN", OauthRefreshToken: "test_refresh_token", OauthAccessToken: "test_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -958,7 +958,7 @@ func TestCLIProfileCredentialsProvider_writeConfigurationToFileWithLock_RenameEr OauthSiteType: "CN", OauthRefreshToken: "test_refresh_token", OauthAccessToken: "test_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -997,7 +997,7 @@ func TestCLIProfileCredentialsProvider_updateOAuthTokens_WriteError(t *testing.T OauthSiteType: "CN", OauthRefreshToken: "test_refresh_token", OauthAccessToken: "test_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -1050,7 +1050,7 @@ func TestCLIProfileCredentialsProvider_GetCredentials_WithOAuthProfile(t *testin OauthSiteType: "CN", OauthRefreshToken: "test_refresh_token", OauthAccessToken: "test_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, // 改为 2000秒(>1200秒),避免触发刷新 }, }, } @@ -1092,7 +1092,7 @@ func TestCLIProfileCredentialsProvider_FileLock_ConcurrentAccess(t *testing.T) { OauthSiteType: "CN", OauthRefreshToken: "test_refresh_token", OauthAccessToken: "test_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -1177,7 +1177,7 @@ func TestCLIProfileCredentialsProvider_ProfileName_Empty(t *testing.T) { OauthSiteType: "CN", OauthRefreshToken: "test_refresh_token", OauthAccessToken: "test_access_token", - OauthAccessTokenExpire: time.Now().Unix() + 1000, + OauthAccessTokenExpire: time.Now().Unix() + 2000, }, }, } @@ -1384,7 +1384,7 @@ func TestCLIProfileCredentialsProvider_UpdateOAuthTokens_ErrorScenarios(t *testi err = provider.updateOAuthTokens("refresh", "access", "ak", "sk", "token", 1234567890, 1234567890) assert.NotNil(t, err) - assert.Contains(t, err.Error(), "failed to get profile nonexistent") + assert.Contains(t, err.Error(), "failed to find OAuth source profile") // 测试4: 配置文件写入失败 - 通过创建只读目录来模拟 (仅在Unix上测试) if runtime.GOOS != "windows" { @@ -2127,3 +2127,166 @@ func TestExternalCredentialsProvider_ConcurrentAccess(t *testing.T) { } } +func TestCLIProfileFindSourceOAuthProfile(t *testing.T) { + // 测试 findSourceOAuthProfile 递归查找功能 + conf := &configuration{ + Current: "chainable_oauth", + Profiles: []*profile{ + { + Name: "oauth_source", + Mode: "OAuth", + OauthRefreshToken: "initial_refresh_token", + OauthAccessToken: "initial_access_token", + }, + { + Name: "chainable_oauth", + Mode: "ChainableRamRoleArn", + SourceProfile: "oauth_source", + RoleArn: "acs:ram::123456789012:role/test-role", + }, + }, + } + + // 测试从 chainable profile 查找到 OAuth source profile + sourceProfile, err := conf.findSourceOAuthProfile("chainable_oauth") + assert.Nil(t, err) + assert.NotNil(t, sourceProfile) + assert.Equal(t, "oauth_source", sourceProfile.Name) + assert.Equal(t, "OAuth", sourceProfile.Mode) + + // 测试直接查找 OAuth profile + sourceProfile, err = conf.findSourceOAuthProfile("oauth_source") + assert.Nil(t, err) + assert.NotNil(t, sourceProfile) + assert.Equal(t, "oauth_source", sourceProfile.Name) + + // 测试查找不存在的 profile + _, err = conf.findSourceOAuthProfile("nonexistent") + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "unable to get profile with name 'nonexistent'") + + // 测试 profile 链中没有 OAuth + conf.Profiles = append(conf.Profiles, &profile{ + Name: "no_oauth_chain", + Mode: "AK", + AccessKeyID: "akid", + AccessKeySecret: "secret", + }) + _, err = conf.findSourceOAuthProfile("no_oauth_chain") + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "unable to get OAuth profile") +} + +func TestCLIProfileNestedChainableRamRoleWithOAuthSource(t *testing.T) { + // 测试多层嵌套 ChainableRamRoleArn 使用 OAuth source profile + conf := &configuration{ + Current: "chainable_level2", + Profiles: []*profile{ + { + Name: "oauth_source", + Mode: "OAuth", + OauthRefreshToken: "initial_refresh_token", + OauthAccessToken: "initial_access_token", + }, + { + Name: "chainable_level1", + Mode: "ChainableRamRoleArn", + SourceProfile: "oauth_source", + RoleArn: "acs:ram::123456789012:role/level1-role", + }, + { + Name: "chainable_level2", + Mode: "ChainableRamRoleArn", + SourceProfile: "chainable_level1", + RoleArn: "acs:ram::123456789012:role/level2-role", + }, + }, + } + + // 测试从最外层 chainable profile 递归查找到 OAuth source profile + sourceProfile, err := conf.findSourceOAuthProfile("chainable_level2") + assert.Nil(t, err) + assert.NotNil(t, sourceProfile) + assert.Equal(t, "oauth_source", sourceProfile.Name) + assert.Equal(t, "OAuth", sourceProfile.Mode) + + // 测试从中间层 chainable profile 查找 + sourceProfile, err = conf.findSourceOAuthProfile("chainable_level1") + assert.Nil(t, err) + assert.NotNil(t, sourceProfile) + assert.Equal(t, "oauth_source", sourceProfile.Name) +} + +func TestCLIProfileUpdateOAuthTokensWithChainable(t *testing.T) { + // 测试 ChainableRamRoleArn 场景下 OAuth token 更新到正确的 source profile + wd, _ := os.Getwd() + tmpDir := path.Join(wd, ".tmp_test") + os.MkdirAll(tmpDir, 0755) + defer os.RemoveAll(tmpDir) + + configPath := path.Join(tmpDir, "config_chainable_oauth.json") + + // 创建测试配置:OAuth profile <- ChainableRamRoleArn profile + testConfig := &configuration{ + Current: "chainable_oauth", + Profiles: []*profile{ + { + Name: "oauth_source", + Mode: "OAuth", + OauthRefreshToken: "initial_refresh_token", + OauthAccessToken: "initial_access_token", + OauthAccessTokenExpire: time.Now().Unix() + 3600, + }, + { + Name: "chainable_oauth", + Mode: "ChainableRamRoleArn", + SourceProfile: "oauth_source", + RoleArn: "acs:ram::123456789012:role/test-role", + }, + }, + } + + // 写入配置文件 + data, _ := json.MarshalIndent(testConfig, "", " ") + ioutil.WriteFile(configPath, data, 0644) + + provider, err := NewCLIProfileCredentialsProviderBuilder(). + WithProfileFile(configPath). + WithProfileName("chainable_oauth"). + Build() + assert.Nil(t, err) + + // 更新 OAuth tokens + newRefreshToken := "new_refresh_token" + newAccessToken := "new_access_token" + newAccessKey := "new_access_key" + newSecret := "new_secret" + newSecurityToken := "new_security_token" + newExpire := time.Now().Unix() + 7200 + newStsExpire := time.Now().Unix() + 10800 + + err = provider.updateOAuthTokens(newRefreshToken, newAccessToken, newAccessKey, newSecret, newSecurityToken, newExpire, newStsExpire) + assert.Nil(t, err) + + // 读取更新后的配置 + updatedConf, err := newConfigurationFromPath(configPath) + assert.Nil(t, err) + + // 验证 OAuth source profile 被正确更新 + oauthProfile, err := updatedConf.getProfile("oauth_source") + assert.Nil(t, err) + assert.Equal(t, newRefreshToken, oauthProfile.OauthRefreshToken) + assert.Equal(t, newAccessToken, oauthProfile.OauthAccessToken) + assert.Equal(t, newAccessKey, oauthProfile.AccessKeyID) + assert.Equal(t, newSecret, oauthProfile.AccessKeySecret) + assert.Equal(t, newSecurityToken, oauthProfile.SecurityToken) + assert.Equal(t, newExpire, oauthProfile.OauthAccessTokenExpire) + assert.Equal(t, newStsExpire, oauthProfile.StsExpire) + + // 验证 chainable profile 没有被错误更新 + chainableProfile, err := updatedConf.getProfile("chainable_oauth") + assert.Nil(t, err) + assert.Equal(t, "", chainableProfile.OauthRefreshToken) + assert.Equal(t, "", chainableProfile.OauthAccessToken) +} + diff --git a/credentials/providers/oauth.go b/credentials/providers/oauth.go index 145c43b..1d05c81 100644 --- a/credentials/providers/oauth.go +++ b/credentials/providers/oauth.go @@ -113,7 +113,9 @@ func (b *OAuthCredentialsProviderBuilder) Build() (provider *OAuthCredentialsPro func (provider *OAuthCredentialsProvider) getCredentials() (session *sessionCredentials, err error) { - if provider.accessToken == "" || provider.accessTokenExpire == 0 || provider.accessTokenExpire-time.Now().Unix() <= 180 { + // OAuth token 必须提前足够时间刷新,确保有效期 >= 15分钟用于后续 exchange 操作 + // 设置为20分钟(1200秒)提前量,留有5分钟余量 + if provider.accessToken == "" || provider.accessTokenExpire == 0 || provider.accessTokenExpire-time.Now().Unix() <= 1200 { err = provider.tryRefreshOauthToken() if err != nil { return nil, err diff --git a/credentials/providers/oauth_test.go b/credentials/providers/oauth_test.go index 14a2df8..06ed916 100644 --- a/credentials/providers/oauth_test.go +++ b/credentials/providers/oauth_test.go @@ -32,7 +32,7 @@ func TestNewOAuthCredentialsProvider(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) assert.Equal(t, "clientId", p.clientId) @@ -50,7 +50,7 @@ func TestOAuthCredentialsProvider_getCredentials(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("token"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). // 改为 2000秒(>1200秒),避免触发刷新 Build() assert.Nil(t, err) @@ -154,7 +154,7 @@ func TestOAuthCredentialsProviderGetCredentials(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("token"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). // 改为 2000秒(>1200秒),避免触发刷新 WithHttpOptions(&HttpOptions{ ConnectTimeout: 10000, }). @@ -213,7 +213,7 @@ func TestOAuthCredentialsProviderGetCredentialsWithHttpOptions(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("token"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). WithHttpOptions(&HttpOptions{ ConnectTimeout: 1000, ReadTimeout: 1000, @@ -233,7 +233,7 @@ func TestOAuthCredentialsProviderGetProviderName(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("token"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) assert.Equal(t, "oauth", p.GetProviderName()) @@ -251,7 +251,7 @@ func TestOAuthCredentialsProviderWithHttpOptions(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("token"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). WithHttpOptions(httpOptions). Build() @@ -271,7 +271,7 @@ func TestOAuthCredentialsProviderCredentialCaching(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("token"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) @@ -302,7 +302,7 @@ func TestOAuthCredentialsProviderNeedUpdateCredential(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("token"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) @@ -331,7 +331,7 @@ func TestOAuthCredentialsProviderTryRefreshOauthToken(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) @@ -472,7 +472,7 @@ func TestOAuthCredentialsProvider_TokenUpdateCallback(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). WithTokenUpdateCallback(callback). Build() assert.Nil(t, err) @@ -522,7 +522,7 @@ func TestOAuthCredentialsProvider_TokenUpdateCallback_Error(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). WithTokenUpdateCallback(callback). Build() assert.Nil(t, err) @@ -561,7 +561,7 @@ func TestOAuthCredentialsProvider_WithoutTokenUpdateCallback(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) assert.Nil(t, p.tokenUpdateCallback) @@ -608,7 +608,7 @@ func TestOAuthCredentialsProvider_TryRefreshOauthToken_WithCallback(t *testing.T WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). WithTokenUpdateCallback(callback). Build() assert.Nil(t, err) @@ -677,7 +677,7 @@ func TestOAuthCredentialsProvider_GetCredentials_WithEmptyAccessToken(t *testing WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken(""). // empty access token - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) @@ -788,7 +788,7 @@ func TestOAuthCredentialsProvider_TryRefreshOauthToken_InvalidURL(t *testing.T) WithSignInUrl("invalid-url"). // invalid URL WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) @@ -806,7 +806,7 @@ func TestOAuthCredentialsProvider_TryRefreshOauthToken_NetworkError(t *testing.T WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) @@ -829,7 +829,7 @@ func TestOAuthCredentialsProvider_TryRefreshOauthToken_Non200Status(t *testing.T WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) @@ -856,7 +856,7 @@ func TestOAuthCredentialsProvider_TryRefreshOauthToken_InvalidJSON(t *testing.T) WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) @@ -883,7 +883,7 @@ func TestOAuthCredentialsProvider_TryRefreshOauthToken_EmptyTokens(t *testing.T) WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) @@ -907,7 +907,7 @@ func TestOAuthCredentialsProvider_NeedUpdateCredential_EdgeCases(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) @@ -931,7 +931,7 @@ func TestOAuthCredentialsProvider_HttpOptions_EdgeCases(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). WithHttpOptions(&HttpOptions{ ConnectTimeout: 0, ReadTimeout: 0, @@ -947,7 +947,7 @@ func TestOAuthCredentialsProvider_HttpOptions_EdgeCases(t *testing.T) { WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). WithHttpOptions(&HttpOptions{ ConnectTimeout: -1000, ReadTimeout: -2000, @@ -967,7 +967,7 @@ func TestOAuthCredentialsProvider_GetCredentials_CachedCredentials(t *testing.T) WithSignInUrl("https://oauth.aliyun.com"). WithRefreshToken("refreshToken"). WithAccessToken("accessToken"). - WithAccessTokenExpire(time.Now().Unix() + 1000). + WithAccessTokenExpire(time.Now().Unix() + 2000). Build() assert.Nil(t, err) @@ -996,3 +996,199 @@ func TestOAuthCredentialsProvider_GetCredentials_CachedCredentials(t *testing.T) assert.Equal(t, cc1.AccessKeySecret, cc2.AccessKeySecret) assert.Equal(t, cc1.SecurityToken, cc2.SecurityToken) } + +func TestOAuthTokenRefreshTimingSufficientTime(t *testing.T) { + // 测试当 OAuth token 剩余时间 > 1200秒时,不触发刷新 + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() + + callCount := 0 + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + callCount++ + // 只应该调用一次 /v1/exchange,不应该调用 /v1/token + if req.Path == "/v1/token" { + t.Error("Should not refresh token when remaining time > 1200s") + } + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"accessKeyId":"akid","accessKeySecret":"aksecret","securityToken":"token","expiration":"2030-12-31T23:59:59Z"}`), + } + return + } + + // access_token 还有 25 分钟(1500秒)过期,大于 1200秒阈值 + p, err := NewOAuthCredentialsProviderBuilder(). + WithClientId("clientId"). + WithSignInUrl("https://oauth.aliyun.com"). + WithRefreshToken("refreshToken"). + WithAccessToken("validToken"). + WithAccessTokenExpire(time.Now().Unix() + 1500). + Build() + assert.Nil(t, err) + + _, err = p.getCredentials() + assert.Nil(t, err) + // 应该只调用了一次(/v1/exchange) + assert.Equal(t, 1, callCount) + assert.Equal(t, "validToken", p.accessToken) +} + +func TestOAuthTokenRefreshTimingInsufficientTime(t *testing.T) { + // 测试当 OAuth token 剩余时间 <= 1200秒时,触发刷新 + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() + + callCount := 0 + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + callCount++ + if req.Path == "/v1/token" { + // Token 刷新请求 + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"access_token":"newToken","refresh_token":"newRefreshToken","expires_in":3600}`), + } + } else if req.Path == "/v1/exchange" { + // 凭据交换请求 + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"accessKeyId":"akid","accessKeySecret":"aksecret","securityToken":"token","expiration":"2030-12-31T23:59:59Z"}`), + } + } + return + } + + // access_token 还有 15 分钟(900秒)过期,小于 1200秒阈值 + p, err := NewOAuthCredentialsProviderBuilder(). + WithClientId("clientId"). + WithSignInUrl("https://oauth.aliyun.com"). + WithRefreshToken("oldRefreshToken"). + WithAccessToken("oldToken"). + WithAccessTokenExpire(time.Now().Unix() + 900). + Build() + assert.Nil(t, err) + + _, err = p.getCredentials() + assert.Nil(t, err) + // 应该调用了两次:1. /v1/token(刷新),2. /v1/exchange(交换) + assert.Equal(t, 2, callCount) + assert.Equal(t, "newToken", p.accessToken) + assert.Equal(t, "newRefreshToken", p.refreshToken) +} + +func TestOAuthTokenRefreshTimingExactlyThreshold(t *testing.T) { + // 测试当 OAuth token 剩余时间正好等于 1200秒时,触发刷新 + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() + + refreshCalled := false + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + if req.Path == "/v1/token" { + refreshCalled = true + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"access_token":"newToken","refresh_token":"newRefreshToken","expires_in":3600}`), + } + } else if req.Path == "/v1/exchange" { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"accessKeyId":"akid","accessKeySecret":"aksecret","securityToken":"token","expiration":"2030-12-31T23:59:59Z"}`), + } + } + return + } + + // access_token 还有正好 1200秒(20分钟)过期 + p, err := NewOAuthCredentialsProviderBuilder(). + WithClientId("clientId"). + WithSignInUrl("https://oauth.aliyun.com"). + WithRefreshToken("oldRefreshToken"). + WithAccessToken("oldToken"). + WithAccessTokenExpire(time.Now().Unix() + 1200). + Build() + assert.Nil(t, err) + + _, err = p.getCredentials() + assert.Nil(t, err) + // 应该触发了刷新 + assert.True(t, refreshCalled) + assert.Equal(t, "newToken", p.accessToken) +} + +func TestOAuthTokenRefreshTimingZeroExpire(t *testing.T) { + // 测试边界情况:access_token_expire 为 0 时触发刷新 + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() + + refreshCalled := false + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + if req.Path == "/v1/token" { + refreshCalled = true + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"access_token":"newToken","refresh_token":"newRefreshToken","expires_in":3600}`), + } + } else if req.Path == "/v1/exchange" { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"accessKeyId":"akid","accessKeySecret":"aksecret","securityToken":"token","expiration":"2030-12-31T23:59:59Z"}`), + } + } + return + } + + // access_token_expire 为 0 + p, err := NewOAuthCredentialsProviderBuilder(). + WithClientId("clientId"). + WithSignInUrl("https://oauth.aliyun.com"). + WithRefreshToken("oldRefreshToken"). + WithAccessToken("oldToken"). + WithAccessTokenExpire(0). + Build() + assert.Nil(t, err) + + _, err = p.getCredentials() + assert.Nil(t, err) + // 应该触发了刷新 + assert.True(t, refreshCalled) + assert.Equal(t, "newToken", p.accessToken) +} + +func TestOAuthTokenRefreshTimingEmptyToken(t *testing.T) { + // 测试边界情况:access_token 为空时触发刷新 + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() + + refreshCalled := false + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + if req.Path == "/v1/token" { + refreshCalled = true + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"access_token":"newToken","refresh_token":"newRefreshToken","expires_in":3600}`), + } + } else if req.Path == "/v1/exchange" { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"accessKeyId":"akid","accessKeySecret":"aksecret","securityToken":"token","expiration":"2030-12-31T23:59:59Z"}`), + } + } + return + } + + // access_token 为空 + p, err := NewOAuthCredentialsProviderBuilder(). + WithClientId("clientId"). + WithSignInUrl("https://oauth.aliyun.com"). + WithRefreshToken("refreshToken"). + WithAccessToken(""). + WithAccessTokenExpire(time.Now().Unix() + 3600). + Build() + assert.Nil(t, err) + + _, err = p.getCredentials() + assert.Nil(t, err) + // 应该触发了刷新 + assert.True(t, refreshCalled) + assert.Equal(t, "newToken", p.accessToken) +} +