Skip to content

Commit a7baef9

Browse files
committed
Allow assume role for AWS_MSK_IAM
1 parent f5e4d2d commit a7baef9

File tree

4 files changed

+45
-21
lines changed

4 files changed

+45
-21
lines changed

Diff for: README.md

+2
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,10 @@ You can launch a kafka-proxy container with auth-ldap plugin for trying it out w
191191
--proxy-listener-write-buffer-size int Sets the size of the operating system's transmit buffer associated with the connection. If zero, system default is used
192192
--proxy-request-buffer-size int Request buffer size pro tcp connection (default 4096)
193193
--proxy-response-buffer-size int Response buffer size pro tcp connection (default 4096)
194+
--sasl-aws-identity-lookup Verify AWS authentication identity
194195
--sasl-aws-profile string AWS profile
195196
--sasl-aws-region string Region for AWS IAM Auth
197+
--sasl-aws-role-arn string AWS Role ARN to assume
196198
--sasl-enable Connect using SASL
197199
--sasl-jaas-config-file string Location of JAAS config file with SASL username and password
198200
--sasl-method string SASL method to use (PLAIN, SCRAM-SHA-256, SCRAM-SHA-512, GSSAPI, AWS_MSK_IAM (default "PLAIN")

Diff for: cmd/kafka-proxy/server.go

+2
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ func initFlags() {
187187
// SASL AWS_MSK_IAM
188188
Server.Flags().StringVar(&c.Kafka.SASL.AWSConfig.Region, "sasl-aws-region", "", "Region for AWS IAM Auth")
189189
Server.Flags().StringVar(&c.Kafka.SASL.AWSConfig.Profile, "sasl-aws-profile", "", "AWS profile")
190+
Server.Flags().StringVar(&c.Kafka.SASL.AWSConfig.RoleArn, "sasl-aws-role-arn", "", "AWS Role ARN to assume")
191+
Server.Flags().BoolVar(&c.Kafka.SASL.AWSConfig.IdentityLookup, "sasl-aws-identity-lookup", false, "Verify AWS authentication identity")
190192

191193
// SASL by Proxy plugin
192194
Server.Flags().BoolVar(&c.Kafka.SASL.Plugin.Enable, "sasl-plugin-enable", false, "Use plugin for SASL authentication")

Diff for: config/config.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,10 @@ type GSSAPIConfig struct {
4848
}
4949

5050
type AWSConfig struct {
51-
Region string
52-
Profile string
51+
Region string
52+
Profile string
53+
RoleArn string
54+
IdentityLookup bool
5355
}
5456

5557
type Config struct {

Diff for: proxy/sasl_aws_msk_iam.go

+37-19
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,18 @@ import (
44
"bytes"
55
"context"
66
"encoding/binary"
7+
"errors"
78
"fmt"
89
"io"
910
"net"
1011
"time"
1112

13+
"github.com/aws/aws-sdk-go-v2/aws"
1214
"github.com/aws/aws-sdk-go-v2/config"
15+
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
16+
"github.com/aws/aws-sdk-go-v2/service/sts"
1317
proxyconfig "github.com/grepplabs/kafka-proxy/config"
1418
"github.com/grepplabs/kafka-proxy/proxy/protocol"
15-
"github.com/pkg/errors"
1619
"github.com/sirupsen/logrus"
1720
)
1821

@@ -44,6 +47,21 @@ func NewAwsMSKIamAuth(
4447
if err != nil {
4548
return nil, fmt.Errorf("loading aws config: %v", err)
4649
}
50+
if awsConfig.RoleArn != "" {
51+
stsClient := sts.NewFromConfig(cfg)
52+
assumeRoleProvider := stscreds.NewAssumeRoleProvider(stsClient, awsConfig.RoleArn)
53+
cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider)
54+
}
55+
if awsConfig.IdentityLookup {
56+
ctx := context.Background()
57+
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
58+
defer cancel()
59+
output, err := sts.NewFromConfig(cfg).GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
60+
if err != nil {
61+
return nil, fmt.Errorf("failed to get caller identity: %v", err)
62+
}
63+
logrus.Infof("AWS_MSK_IAM caller identity %s", aws.ToString(output.Arn))
64+
}
4765
return &AwsMSKIamAuth{
4866
clientID: clientId,
4967
signer: newMechanism(cfg),
@@ -56,11 +74,11 @@ func NewAwsMSKIamAuth(
5674
// sendAndReceiveSASLAuth handles the entire SASL authentication process
5775
func (a *AwsMSKIamAuth) sendAndReceiveSASLAuth(conn DeadlineReaderWriter, brokerString string) error {
5876
if err := a.saslHandshake(conn); err != nil {
59-
return errors.Wrap(err, "handshake failed")
77+
return fmt.Errorf("handshake failed: %w", err)
6078
}
6179

6280
if err := a.saslAuthenticate(conn, brokerString); err != nil {
63-
return errors.Wrap(err, "authenticate failed")
81+
return fmt.Errorf("authenticate failed: %w", err)
6482
}
6583

6684
return nil
@@ -76,21 +94,21 @@ func (a *AwsMSKIamAuth) saslHandshake(conn DeadlineReaderWriter) error {
7694
Body: rb,
7795
}
7896
if err := a.write(conn, req); err != nil {
79-
return errors.Wrap(err, "writing SASL handshake")
97+
return fmt.Errorf("writing SASL handshake: %w", err)
8098
}
8199

82100
payload, err := a.read(conn)
83101
if err != nil {
84-
return errors.Wrap(err, "reading SASL handshake")
102+
return fmt.Errorf("reading SASL handshake: %w", err)
85103
}
86104

87105
res := &protocol.SaslHandshakeResponseV0orV1{}
88106
if err := protocol.Decode(payload, res); err != nil {
89-
return errors.Wrap(err, "parsing SASL handshake response")
107+
return fmt.Errorf("parsing SASL handshake response: %w", err)
90108
}
91109

92-
if res.Err != protocol.ErrNoError {
93-
return errors.Wrap(res.Err, "sasl handshake protocol error")
110+
if !errors.Is(res.Err, protocol.ErrNoError) {
111+
return fmt.Errorf("sasl handshake protocol error: %w", res.Err)
94112
}
95113
logrus.Debugf("Successful IAM SASL handshake. Available mechanisms: %v", res.EnabledMechanisms)
96114
return nil
@@ -114,59 +132,59 @@ func (a *AwsMSKIamAuth) saslAuthenticate(conn DeadlineReaderWriter, brokerString
114132
Body: saslAuthReqV0,
115133
}
116134
if err := a.write(conn, req); err != nil {
117-
return errors.Wrap(err, "writing SASL authentication request")
135+
return fmt.Errorf("writing SASL authentication request: %w", err)
118136
}
119137

120138
payload, err := a.read(conn)
121139
if err != nil {
122-
return errors.Wrap(err, "reading SASL authentication response")
140+
return fmt.Errorf("reading SASL authentication response: %w", err)
123141
}
124142

125143
res := &protocol.SaslAuthenticateResponseV0{}
126144
err = protocol.Decode(payload, res)
127145
if err != nil {
128-
return errors.Wrap(err, "parsing SASL authentication response")
146+
return fmt.Errorf("parsing SASL authentication response: %w", err)
129147
}
130-
if res.Err != protocol.ErrNoError {
131-
return errors.Wrap(res.Err, "sasl authentication protocol error")
148+
if !errors.Is(res.Err, protocol.ErrNoError) {
149+
return fmt.Errorf("sasl authentication protocol error: %w", res.Err)
132150
}
133151
return nil
134152
}
135153

136154
func (a *AwsMSKIamAuth) write(conn DeadlineReaderWriter, req *protocol.Request) error {
137155
reqBuf, err := protocol.Encode(req)
138156
if err != nil {
139-
return errors.Wrap(err, "serializing request")
157+
return fmt.Errorf("serializing request: %w", err)
140158
}
141159

142160
sizeBuf := make([]byte, 4)
143161
binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqBuf)))
144162

145163
if err := conn.SetWriteDeadline(time.Now().Add(a.writeTimeout)); err != nil {
146-
return errors.Wrap(err, "setting write deadline")
164+
return fmt.Errorf("setting write deadline: %w", err)
147165
}
148166

149167
if _, err := conn.Write(bytes.Join([][]byte{sizeBuf, reqBuf}, nil)); err != nil {
150-
return errors.Wrap(err, "writing bytes")
168+
return fmt.Errorf("writing bytes: %w", err)
151169
}
152170
return nil
153171
}
154172

155173
func (a *AwsMSKIamAuth) read(conn DeadlineReaderWriter) ([]byte, error) {
156174
if err := conn.SetReadDeadline(time.Now().Add(a.readTimeout)); err != nil {
157-
return nil, errors.Wrap(err, "setting read deadline")
175+
return nil, fmt.Errorf("setting read deadline: %w", err)
158176
}
159177

160178
//wait for the handshake response
161179
header := make([]byte, 8) // response header
162180
if _, err := io.ReadFull(conn, header); err != nil {
163-
return nil, errors.Wrap(err, "reading header")
181+
return nil, fmt.Errorf("reading header: %w", err)
164182
}
165183

166184
length := binary.BigEndian.Uint32(header[:4])
167185
payload := make([]byte, length-4)
168186
if _, err := io.ReadFull(conn, payload); err != nil {
169-
return nil, errors.Wrap(err, "reading payload")
187+
return nil, fmt.Errorf("reading payload: %w", err)
170188
}
171189

172190
return payload, nil

0 commit comments

Comments
 (0)