Skip to content

Commit a096376

Browse files
authored
Merge pull request #21 from MbolotSuse/new-product
Adding support for second product id
2 parents c0dc459 + 608cf2d commit a096376

File tree

4 files changed

+270
-12
lines changed

4 files changed

+270
-12
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ replace (
3131
)
3232

3333
require (
34-
github.com/aws/aws-sdk-go-v2 v1.16.2
3534
github.com/aws/aws-sdk-go-v2/config v1.15.3
3635
github.com/aws/aws-sdk-go-v2/service/licensemanager v1.15.3
3736
github.com/aws/aws-sdk-go-v2/service/sts v1.16.3
@@ -49,6 +48,7 @@ require (
4948
)
5049

5150
require (
51+
github.com/aws/aws-sdk-go-v2 v1.16.2 // indirect
5252
github.com/aws/aws-sdk-go-v2/credentials v1.11.2 // indirect
5353
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3 // indirect
5454
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9 // indirect

pkg/clients/aws/client.go

+46-11
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"fmt"
88
"strconv"
99

10-
"github.com/aws/aws-sdk-go-v2/aws"
1110
"github.com/aws/aws-sdk-go-v2/config"
1211
lm "github.com/aws/aws-sdk-go-v2/service/licensemanager"
1312
"github.com/aws/aws-sdk-go-v2/service/licensemanager/types"
@@ -30,12 +29,22 @@ type Client interface {
3029
// GetNumberOfAvailableEntitlements gets the number of RKE_NODE_SUPP entitlements available on license
3130
GetNumberOfAvailableEntitlements(ctx context.Context, license types.GrantedLicense) (int, error)
3231
}
32+
type licenseManagerClient interface {
33+
ListReceivedLicenses(ctx context.Context, params *lm.ListReceivedLicensesInput, optFns ...func(*lm.Options)) (*lm.ListReceivedLicensesOutput, error)
34+
CheckoutLicense(ctx context.Context, params *lm.CheckoutLicenseInput, optFns ...func(*lm.Options)) (*lm.CheckoutLicenseOutput, error)
35+
CheckInLicense(ctx context.Context, params *lm.CheckInLicenseInput, optFns ...func(*lm.Options)) (*lm.CheckInLicenseOutput, error)
36+
ExtendLicenseConsumption(ctx context.Context, params *lm.ExtendLicenseConsumptionInput, optFns ...func(*lm.Options)) (*lm.ExtendLicenseConsumptionOutput, error)
37+
GetLicenseUsage(ctx context.Context, params *lm.GetLicenseUsageInput, optFns ...func(*lm.Options)) (*lm.GetLicenseUsageOutput, error)
38+
}
39+
40+
type stsClient interface {
41+
GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error)
42+
}
3343

3444
type client struct {
3545
acctNum string
36-
cfg aws.Config
37-
sts *sts.Client
38-
lm *lm.Client
46+
sts stsClient
47+
lm licenseManagerClient
3948
}
4049

4150
func NewClient(ctx context.Context) (Client, error) {
@@ -47,7 +56,6 @@ func NewClient(ctx context.Context) (Client, error) {
4756
logrus.Debugf("aws config region: %+v", cfg.Region)
4857

4958
c := &client{
50-
cfg: cfg,
5159
sts: sts.NewFromConfig(cfg),
5260
lm: lm.NewFromConfig(cfg),
5361
}
@@ -84,18 +92,32 @@ func (c *client) getAccountNumber(ctx context.Context) (string, error) {
8492
}
8593

8694
var (
87-
productSKUField = "ProductSKU"
88-
rancherProductSKU = "0b87d4fa-d1fe-41d8-830b-67d4ec381549"
89-
maxResults int32 = 1
95+
productSKUField = "ProductSKU"
96+
rancherProductSKUNonEmea = "0b87d4fa-d1fe-41d8-830b-67d4ec381549"
97+
rancherProductSKUEmea = "a303097d-1dc2-4548-8ea6-f46bb9842e21"
98+
maxResults int32 = 1
9099
)
91100

92101
func (c *client) GetRancherLicense(ctx context.Context) (*types.GrantedLicense, error) {
102+
license, err := c.getLicenseForProductID(ctx, rancherProductSKUNonEmea)
103+
if err != nil {
104+
// if we could not get the original license, attempt to retrieve the license for Emea countries
105+
license, newErr := c.getLicenseForProductID(ctx, rancherProductSKUEmea)
106+
if newErr != nil {
107+
return nil, fmt.Errorf("unable to get license for non-emea: %s, unable to get license for emea: %s", err.Error(), newErr.Error())
108+
}
109+
return license, nil
110+
}
111+
return license, nil
112+
}
113+
114+
func (c *client) getLicenseForProductID(ctx context.Context, productID string) (*types.GrantedLicense, error) {
93115
// per aws engineering, there should only ever be at most one license for a given product sku.
94116
input := &lm.ListReceivedLicensesInput{
95117
Filters: []types.Filter{
96118
{
97119
Name: &productSKUField,
98-
Values: []string{rancherProductSKU},
120+
Values: []string{productID},
99121
},
100122
},
101123
MaxResults: &maxResults,
@@ -106,7 +128,17 @@ func (c *client) GetRancherLicense(ctx context.Context) (*types.GrantedLicense,
106128
return nil, err
107129
}
108130

109-
return &res.Licenses[0], nil
131+
if len(res.Licenses) == 0 {
132+
return nil, fmt.Errorf("unable to find license for product id %s", productID)
133+
}
134+
135+
license := &res.Licenses[0]
136+
if license.ProductSKU == nil {
137+
// we expect this value to be set, but given that the value is a pointer we can't be sure
138+
license.ProductSKU = &productID
139+
}
140+
141+
return license, nil
110142
}
111143

112144
var (
@@ -119,6 +151,9 @@ const (
119151

120152
func (c *client) CheckoutRancherLicense(ctx context.Context, l types.GrantedLicense, entitlementAmt int) (*lm.CheckoutLicenseOutput, error) {
121153
if l.Issuer == nil || l.Issuer.KeyFingerprint == nil {
154+
if l.LicenseArn == nil {
155+
return nil, fmt.Errorf("license is missing arn and KeyFingerprint/Issuer")
156+
}
122157
return nil, fmt.Errorf("license %s must have a KeyFingerprint for checkout", *l.LicenseArn)
123158
}
124159

@@ -127,7 +162,7 @@ func (c *client) CheckoutRancherLicense(ctx context.Context, l types.GrantedLice
127162
res, err := c.lm.CheckoutLicense(ctx, &lm.CheckoutLicenseInput{
128163
CheckoutType: types.CheckoutTypeProvisional,
129164
ClientToken: &token,
130-
ProductSKU: &rancherProductSKU,
165+
ProductSKU: l.ProductSKU,
131166
KeyFingerprint: l.Issuer.KeyFingerprint,
132167
Entitlements: []types.EntitlementData{
133168
{

pkg/clients/aws/client_test.go

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package aws
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
const fakeAccountNum = "123456789101"
11+
12+
func TestGetRancherLicense(t *testing.T) {
13+
tests := []struct {
14+
name string // name of the test, to be displayed on failure
15+
hasNonEmeaLicense bool // if the account has a license for the non-EMEA product sku
16+
hasEmeaLicense bool // if the account has a licensed for the EMEA product sku
17+
includeProductSku bool // if the return from aws should include or exclude a product sku
18+
desiredLicense string // which license our client should pick - emea, non-emea, or nothing
19+
errDesired bool // if we wanted an error for this test case
20+
}{
21+
{
22+
name: "test non-emea license",
23+
hasNonEmeaLicense: true,
24+
hasEmeaLicense: false,
25+
includeProductSku: true,
26+
desiredLicense: rancherProductSKUNonEmea,
27+
errDesired: false,
28+
},
29+
{
30+
name: "test emea license",
31+
hasNonEmeaLicense: false,
32+
hasEmeaLicense: true,
33+
includeProductSku: true,
34+
desiredLicense: rancherProductSKUEmea,
35+
errDesired: false,
36+
},
37+
{
38+
name: "test non-emea + emea license - should not occur in reality",
39+
hasNonEmeaLicense: true,
40+
hasEmeaLicense: true,
41+
includeProductSku: true,
42+
desiredLicense: rancherProductSKUNonEmea,
43+
errDesired: false,
44+
},
45+
{
46+
name: "test no valid license",
47+
hasNonEmeaLicense: false,
48+
hasEmeaLicense: false,
49+
includeProductSku: false,
50+
desiredLicense: "",
51+
errDesired: true,
52+
},
53+
{
54+
name: "test no product sku non-emea",
55+
hasNonEmeaLicense: true,
56+
hasEmeaLicense: false,
57+
includeProductSku: false,
58+
desiredLicense: rancherProductSKUNonEmea,
59+
errDesired: false,
60+
},
61+
{
62+
name: "test no product sku emea",
63+
hasNonEmeaLicense: false,
64+
hasEmeaLicense: true,
65+
includeProductSku: false,
66+
desiredLicense: rancherProductSKUEmea,
67+
errDesired: false,
68+
},
69+
}
70+
for _, test := range tests {
71+
test := test
72+
t.Run(test.name, func(t *testing.T) {
73+
mockLMClient := mockLicenseManagerClient{}
74+
client := &client{
75+
acctNum: fakeAccountNum,
76+
lm: &mockLMClient,
77+
sts: &mockSTSClient{accountNumber: fakeAccountNum},
78+
}
79+
if test.hasNonEmeaLicense {
80+
mockLMClient.AddLicenseForSku(rancherProductSKUNonEmea, fakeAccountNum, test.includeProductSku)
81+
}
82+
if test.hasEmeaLicense {
83+
mockLMClient.AddLicenseForSku(rancherProductSKUEmea, fakeAccountNum, test.includeProductSku)
84+
}
85+
86+
license, err := client.GetRancherLicense(context.Background())
87+
if test.errDesired {
88+
assert.Error(t, err, "expected an error but err was nil")
89+
} else {
90+
assert.NoError(t, err, "no error was expected, but got an error")
91+
assert.NotNil(t, license, "expected a valid license but was nil")
92+
assert.Equal(t, *license.ProductSKU, test.desiredLicense, "received unexpected product sku")
93+
}
94+
})
95+
}
96+
}

pkg/clients/aws/mocks_test.go

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package aws
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"time"
7+
8+
lm "github.com/aws/aws-sdk-go-v2/service/licensemanager"
9+
"github.com/aws/aws-sdk-go-v2/service/licensemanager/types"
10+
"github.com/aws/aws-sdk-go-v2/service/sts"
11+
)
12+
13+
const timeFormat = time.RFC3339
14+
15+
type mockLicenseManagerClient struct {
16+
licenses map[string]types.GrantedLicense
17+
checkedOutLicenses map[string]licenseInfo
18+
licenseCounter int
19+
}
20+
21+
type mockSTSClient struct {
22+
accountNumber string
23+
}
24+
25+
type licenseInfo struct {
26+
checkOutInput lm.CheckoutLicenseInput
27+
expiryTime time.Time
28+
}
29+
30+
func (m *mockLicenseManagerClient) AddLicenseForSku(productSku string, accountNumber string, includeSkuInReturn bool) {
31+
if m.licenses == nil {
32+
m.licenses = map[string]types.GrantedLicense{}
33+
}
34+
licenseArn := fmt.Sprintf("arn:aws:license-manager::%s:license:l-%06d", accountNumber, m.licenseCounter)
35+
m.licenseCounter++
36+
license := types.GrantedLicense{
37+
LicenseArn: &licenseArn,
38+
}
39+
if includeSkuInReturn {
40+
license.ProductSKU = &productSku
41+
}
42+
m.licenses[productSku] = license
43+
}
44+
45+
func (m *mockLicenseManagerClient) Clear() {
46+
m.licenses = map[string]types.GrantedLicense{}
47+
m.checkedOutLicenses = map[string]licenseInfo{}
48+
m.licenseCounter = 0
49+
}
50+
51+
func (m *mockLicenseManagerClient) ListReceivedLicenses(ctx context.Context, params *lm.ListReceivedLicensesInput, optFns ...func(*lm.Options)) (*lm.ListReceivedLicensesOutput, error) {
52+
var productIDs []string
53+
for _, filter := range params.Filters {
54+
if *filter.Name == productSKUField {
55+
productIDs = filter.Values
56+
}
57+
}
58+
var licenses []types.GrantedLicense
59+
for _, productID := range productIDs {
60+
if license, ok := m.licenses[productID]; ok {
61+
licenses = append(licenses, license)
62+
}
63+
}
64+
return &lm.ListReceivedLicensesOutput{
65+
Licenses: licenses,
66+
}, nil
67+
}
68+
func (m *mockLicenseManagerClient) CheckoutLicense(ctx context.Context, params *lm.CheckoutLicenseInput, optFns ...func(*lm.Options)) (*lm.CheckoutLicenseOutput, error) {
69+
consumptionToken := params.ClientToken
70+
if consumptionToken == nil {
71+
return nil, fmt.Errorf("unable to checkout license, no consumption token provided")
72+
}
73+
expiryTime := time.Now().Add(time.Hour * 24)
74+
expiryTS := expiryTime.Format(timeFormat)
75+
m.checkedOutLicenses[*consumptionToken] = licenseInfo{
76+
checkOutInput: *params,
77+
expiryTime: expiryTime,
78+
}
79+
return &lm.CheckoutLicenseOutput{
80+
LicenseConsumptionToken: consumptionToken,
81+
Expiration: &expiryTS,
82+
}, nil
83+
}
84+
func (m *mockLicenseManagerClient) CheckInLicense(ctx context.Context, params *lm.CheckInLicenseInput, optFns ...func(*lm.Options)) (*lm.CheckInLicenseOutput, error) {
85+
if params.LicenseConsumptionToken == nil {
86+
return nil, fmt.Errorf("can't check in license without consumption token")
87+
}
88+
if _, ok := m.checkedOutLicenses[*params.LicenseConsumptionToken]; ok {
89+
return nil, fmt.Errorf("no license checked in for consumption token")
90+
}
91+
delete(m.checkedOutLicenses, *params.LicenseConsumptionToken)
92+
return nil, nil
93+
}
94+
func (m *mockLicenseManagerClient) ExtendLicenseConsumption(ctx context.Context, params *lm.ExtendLicenseConsumptionInput, optFns ...func(*lm.Options)) (*lm.ExtendLicenseConsumptionOutput, error) {
95+
token := params.LicenseConsumptionToken
96+
if token == nil {
97+
return nil, fmt.Errorf("no token provided, cannot extend checkout")
98+
}
99+
if _, ok := m.checkedOutLicenses[*token]; ok {
100+
return nil, fmt.Errorf("no license checked in for consumption token")
101+
}
102+
return nil, nil
103+
}
104+
func (m *mockLicenseManagerClient) GetLicenseUsage(ctx context.Context, params *lm.GetLicenseUsageInput, optFns ...func(*lm.Options)) (*lm.GetLicenseUsageOutput, error) {
105+
licenseArn := params.LicenseArn
106+
if licenseArn == nil {
107+
return nil, fmt.Errorf("license arn is missing but is required")
108+
}
109+
var entitlementUsage []types.EntitlementUsage
110+
for _, value := range m.checkedOutLicenses {
111+
// find only the check-outs for this license
112+
if license, ok := m.licenses[*value.checkOutInput.ProductSKU]; ok {
113+
if *license.LicenseArn == *licenseArn {
114+
// add each usage separately, no need to combine into one
115+
for _, data := range value.checkOutInput.Entitlements {
116+
entitlementUsage = append(entitlementUsage, types.EntitlementUsage{Name: data.Name, ConsumedValue: data.Value, Unit: data.Unit})
117+
}
118+
}
119+
}
120+
}
121+
return &lm.GetLicenseUsageOutput{
122+
LicenseUsage: &types.LicenseUsage{EntitlementUsages: entitlementUsage}}, nil
123+
}
124+
125+
func (m *mockSTSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) {
126+
return &sts.GetCallerIdentityOutput{Account: &m.accountNumber}, nil
127+
}

0 commit comments

Comments
 (0)