@@ -4,15 +4,18 @@ import (
4
4
"bytes"
5
5
"context"
6
6
"encoding/binary"
7
+ "errors"
7
8
"fmt"
8
9
"io"
9
10
"net"
10
11
"time"
11
12
13
+ "github.com/aws/aws-sdk-go-v2/aws"
12
14
"github.com/aws/aws-sdk-go-v2/config"
15
+ "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
16
+ "github.com/aws/aws-sdk-go-v2/service/sts"
13
17
proxyconfig "github.com/grepplabs/kafka-proxy/config"
14
18
"github.com/grepplabs/kafka-proxy/proxy/protocol"
15
- "github.com/pkg/errors"
16
19
"github.com/sirupsen/logrus"
17
20
)
18
21
@@ -44,6 +47,21 @@ func NewAwsMSKIamAuth(
44
47
if err != nil {
45
48
return nil , fmt .Errorf ("loading aws config: %v" , err )
46
49
}
50
+ if awsConfig .RoleArn != "" {
51
+ stsClient := sts .NewFromConfig (cfg )
52
+ assumeRoleProvider := stscreds .NewAssumeRoleProvider (stsClient , awsConfig .RoleArn )
53
+ cfg .Credentials = aws .NewCredentialsCache (assumeRoleProvider )
54
+ }
55
+ if awsConfig .IdentityLookup {
56
+ ctx := context .Background ()
57
+ ctx , cancel := context .WithTimeout (ctx , 15 * time .Second )
58
+ defer cancel ()
59
+ output , err := sts .NewFromConfig (cfg ).GetCallerIdentity (ctx , & sts.GetCallerIdentityInput {})
60
+ if err != nil {
61
+ return nil , fmt .Errorf ("failed to get caller identity: %v" , err )
62
+ }
63
+ logrus .Infof ("AWS_MSK_IAM caller identity %s" , aws .ToString (output .Arn ))
64
+ }
47
65
return & AwsMSKIamAuth {
48
66
clientID : clientId ,
49
67
signer : newMechanism (cfg ),
@@ -56,11 +74,11 @@ func NewAwsMSKIamAuth(
56
74
// sendAndReceiveSASLAuth handles the entire SASL authentication process
57
75
func (a * AwsMSKIamAuth ) sendAndReceiveSASLAuth (conn DeadlineReaderWriter , brokerString string ) error {
58
76
if err := a .saslHandshake (conn ); err != nil {
59
- return errors . Wrap ( err , "handshake failed" )
77
+ return fmt . Errorf ( "handshake failed: %w" , err )
60
78
}
61
79
62
80
if err := a .saslAuthenticate (conn , brokerString ); err != nil {
63
- return errors . Wrap ( err , "authenticate failed" )
81
+ return fmt . Errorf ( "authenticate failed: %w" , err )
64
82
}
65
83
66
84
return nil
@@ -76,21 +94,21 @@ func (a *AwsMSKIamAuth) saslHandshake(conn DeadlineReaderWriter) error {
76
94
Body : rb ,
77
95
}
78
96
if err := a .write (conn , req ); err != nil {
79
- return errors . Wrap ( err , "writing SASL handshake" )
97
+ return fmt . Errorf ( "writing SASL handshake: %w" , err )
80
98
}
81
99
82
100
payload , err := a .read (conn )
83
101
if err != nil {
84
- return errors . Wrap ( err , "reading SASL handshake" )
102
+ return fmt . Errorf ( "reading SASL handshake: %w" , err )
85
103
}
86
104
87
105
res := & protocol.SaslHandshakeResponseV0orV1 {}
88
106
if err := protocol .Decode (payload , res ); err != nil {
89
- return errors . Wrap ( err , "parsing SASL handshake response" )
107
+ return fmt . Errorf ( "parsing SASL handshake response: %w" , err )
90
108
}
91
109
92
- if res .Err != protocol .ErrNoError {
93
- return errors . Wrap ( res . Err , "sasl handshake protocol error" )
110
+ if ! errors . Is ( res .Err , protocol .ErrNoError ) {
111
+ return fmt . Errorf ( "sasl handshake protocol error: %w" , res . Err )
94
112
}
95
113
logrus .Debugf ("Successful IAM SASL handshake. Available mechanisms: %v" , res .EnabledMechanisms )
96
114
return nil
@@ -114,59 +132,59 @@ func (a *AwsMSKIamAuth) saslAuthenticate(conn DeadlineReaderWriter, brokerString
114
132
Body : saslAuthReqV0 ,
115
133
}
116
134
if err := a .write (conn , req ); err != nil {
117
- return errors . Wrap ( err , "writing SASL authentication request" )
135
+ return fmt . Errorf ( "writing SASL authentication request: %w" , err )
118
136
}
119
137
120
138
payload , err := a .read (conn )
121
139
if err != nil {
122
- return errors . Wrap ( err , "reading SASL authentication response" )
140
+ return fmt . Errorf ( "reading SASL authentication response: %w" , err )
123
141
}
124
142
125
143
res := & protocol.SaslAuthenticateResponseV0 {}
126
144
err = protocol .Decode (payload , res )
127
145
if err != nil {
128
- return errors . Wrap ( err , "parsing SASL authentication response" )
146
+ return fmt . Errorf ( "parsing SASL authentication response: %w" , err )
129
147
}
130
- if res .Err != protocol .ErrNoError {
131
- return errors . Wrap ( res . Err , "sasl authentication protocol error" )
148
+ if ! errors . Is ( res .Err , protocol .ErrNoError ) {
149
+ return fmt . Errorf ( "sasl authentication protocol error: %w" , res . Err )
132
150
}
133
151
return nil
134
152
}
135
153
136
154
func (a * AwsMSKIamAuth ) write (conn DeadlineReaderWriter , req * protocol.Request ) error {
137
155
reqBuf , err := protocol .Encode (req )
138
156
if err != nil {
139
- return errors . Wrap ( err , "serializing request" )
157
+ return fmt . Errorf ( "serializing request: %w" , err )
140
158
}
141
159
142
160
sizeBuf := make ([]byte , 4 )
143
161
binary .BigEndian .PutUint32 (sizeBuf , uint32 (len (reqBuf )))
144
162
145
163
if err := conn .SetWriteDeadline (time .Now ().Add (a .writeTimeout )); err != nil {
146
- return errors . Wrap ( err , "setting write deadline" )
164
+ return fmt . Errorf ( "setting write deadline: %w" , err )
147
165
}
148
166
149
167
if _ , err := conn .Write (bytes .Join ([][]byte {sizeBuf , reqBuf }, nil )); err != nil {
150
- return errors . Wrap ( err , "writing bytes" )
168
+ return fmt . Errorf ( "writing bytes: %w" , err )
151
169
}
152
170
return nil
153
171
}
154
172
155
173
func (a * AwsMSKIamAuth ) read (conn DeadlineReaderWriter ) ([]byte , error ) {
156
174
if err := conn .SetReadDeadline (time .Now ().Add (a .readTimeout )); err != nil {
157
- return nil , errors . Wrap ( err , "setting read deadline" )
175
+ return nil , fmt . Errorf ( "setting read deadline: %w" , err )
158
176
}
159
177
160
178
//wait for the handshake response
161
179
header := make ([]byte , 8 ) // response header
162
180
if _ , err := io .ReadFull (conn , header ); err != nil {
163
- return nil , errors . Wrap ( err , "reading header" )
181
+ return nil , fmt . Errorf ( "reading header: %w" , err )
164
182
}
165
183
166
184
length := binary .BigEndian .Uint32 (header [:4 ])
167
185
payload := make ([]byte , length - 4 )
168
186
if _ , err := io .ReadFull (conn , payload ); err != nil {
169
- return nil , errors . Wrap ( err , "reading payload" )
187
+ return nil , fmt . Errorf ( "reading payload: %w" , err )
170
188
}
171
189
172
190
return payload , nil
0 commit comments