diff --git a/kinesis/kinesis.go b/kinesis/kinesis.go index 8e7a569b..741bfe5e 100644 --- a/kinesis/kinesis.go +++ b/kinesis/kinesis.go @@ -212,6 +212,7 @@ func newPutRecordsClient(roleARN string, awsRegion string, kinesisEndpoint strin eksConfig.Credentials = creds eksConfig.Region = aws.String(awsRegion) eksConfig.HTTPClient = httpClient + eksConfig.EndpointResolver = endpoints.ResolverFunc(customResolverFn) svcConfig = eksConfig svcSess, err = session.NewSession(svcConfig) @@ -226,6 +227,7 @@ func newPutRecordsClient(roleARN string, awsRegion string, kinesisEndpoint strin stsConfig.Credentials = creds stsConfig.Region = aws.String(awsRegion) stsConfig.HTTPClient = httpClient + stsConfig.EndpointResolver = endpoints.ResolverFunc(customResolverFn) svcConfig = stsConfig svcSess, err = session.NewSession(svcConfig) diff --git a/kinesis/kinesis_test.go b/kinesis/kinesis_test.go index 6199d68e..18175f6d 100644 --- a/kinesis/kinesis_test.go +++ b/kinesis/kinesis_test.go @@ -271,15 +271,15 @@ func RandStringRunes(n int) string { } func TestCompressionTruncation(t *testing.T) { - deftlvl := logrus.GetLevel(); - logrus.SetLevel(0); + deftlvl := logrus.GetLevel() + logrus.SetLevel(0) rand.Seed(0) testData := []byte(RandStringRunes(4000)) testSuffix := "[truncate]" outputPlugin := OutputPlugin{ PluginID: 10, - stream: "MyStream", + stream: "MyStream", } var compressedOutput, err = compressThenTruncate(gzipCompress, testData, 200, []byte(testSuffix), outputPlugin) assert.Nil(t, err) @@ -290,15 +290,15 @@ func TestCompressionTruncation(t *testing.T) { } func TestCompressionTruncationFailureA(t *testing.T) { - deftlvl := logrus.GetLevel(); - logrus.SetLevel(0); + deftlvl := logrus.GetLevel() + logrus.SetLevel(0) rand.Seed(0) testData := []byte(RandStringRunes(4000)) testSuffix := "[truncate]" outputPlugin := OutputPlugin{ PluginID: 10, - stream: "MyStream", + stream: "MyStream", } var _, err = compressThenTruncate(gzipCompress, testData, 20, []byte(testSuffix), outputPlugin) assert.Contains(t, err.Error(), "no room for suffix") @@ -307,15 +307,15 @@ func TestCompressionTruncationFailureA(t *testing.T) { } func TestCompressionTruncationFailureB(t *testing.T) { - deftlvl := logrus.GetLevel(); - logrus.SetLevel(0); + deftlvl := logrus.GetLevel() + logrus.SetLevel(0) rand.Seed(0) testData := []byte{} testSuffix := "[truncate]" outputPlugin := OutputPlugin{ PluginID: 10, - stream: "MyStream", + stream: "MyStream", } var _, err = compressThenTruncate(gzipCompress, testData, 5, []byte(testSuffix), outputPlugin) assert.Contains(t, err.Error(), "compressed empty to large") @@ -403,3 +403,56 @@ func TestGetPartitionKey(t *testing.T) { assert.Equal(t, false, hasValue, "Should not find value") assert.Len(t, value, 0, "This should be an empty string") } + +// TestNewPutRecordsClient_CustomEndpointWithRoles tests that custom endpoint resolvers +// are preserved when using role-based authentication +func TestNewPutRecordsClient_CustomEndpointWithRoles(t *testing.T) { + // Save and restore environment variable + originalEKSRole := os.Getenv("EKS_POD_EXECUTION_ROLE") + defer func() { + if originalEKSRole == "" { + os.Unsetenv("EKS_POD_EXECUTION_ROLE") + } else { + os.Setenv("EKS_POD_EXECUTION_ROLE", originalEKSRole) + } + }() + + customEndpoint := "https://kinesis.custom-domain.local" + + testCases := []struct { + name string + roleARN string + eksRole string + }{ + {"no_roles", "", ""}, + {"with_role_arn", "arn:aws:iam::123456789012:role/test-role", ""}, + {"with_eks_role", "", "arn:aws:iam::123456789012:role/eks-role"}, + {"with_both_roles", "arn:aws:iam::123456789012:role/test-role", "arn:aws:iam::123456789012:role/eks-role"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Set EKS role environment variable if specified + if tc.eksRole != "" { + os.Setenv("EKS_POD_EXECUTION_ROLE", tc.eksRole) + } else { + os.Unsetenv("EKS_POD_EXECUTION_ROLE") + } + + client, err := newPutRecordsClient(tc.roleARN, "us-west-2", customEndpoint, "", 1, time.Second*30) + + if err != nil { + // Expected in test environment without credentials + t.Logf("Expected credential error: %v", err) + return + } + + // Verify the custom endpoint is preserved + if client != nil && client.Client != nil { + actualEndpoint := client.Client.ClientInfo.Endpoint + assert.Equal(t, customEndpoint, actualEndpoint, + "Custom endpoint should be preserved when using role-based authentication") + } + }) + } +}