Skip to content

Commit e68df85

Browse files
committed
fix(azure-flex): handle mismatched agentpool types in GC paths
1 parent 0a0bfb2 commit e68df85

4 files changed

Lines changed: 186 additions & 10 deletions

File tree

karpenter/pkg/cloudproviders/azure/api.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ func IsNotFound(err error) bool {
2626
return false
2727
}
2828

29+
// IsTypeMismatch returns true if err indicates the plugin returned an object of
30+
// a different concrete protobuf type than the caller expected.
31+
func IsTypeMismatch(err error) bool {
32+
if err == nil {
33+
return false
34+
}
35+
s, ok := status.FromError(err)
36+
return ok && s.Code() == codes.InvalidArgument && strings.Contains(s.Message(), "type mismatch")
37+
}
38+
2939
// IsQuotaError returns true if err signals an Azure quota / capacity exhaustion.
3040
// We classify both HTTP 429 and the well-known Azure ARM error codes.
3141
func IsQuotaError(err error) bool {
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package azure
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
7+
"google.golang.org/grpc/codes"
8+
"google.golang.org/grpc/status"
9+
)
10+
11+
func TestIsTypeMismatch(t *testing.T) {
12+
t.Parallel()
13+
14+
tests := []struct {
15+
name string
16+
err error
17+
want bool
18+
}{
19+
{
20+
name: "exact type mismatch status",
21+
err: status.Error(codes.InvalidArgument, "type mismatch"),
22+
want: true,
23+
},
24+
{
25+
name: "wrapped type mismatch status",
26+
err: fmt.Errorf("wrap: %w", status.Error(codes.InvalidArgument, "type mismatch")),
27+
want: true,
28+
},
29+
{
30+
name: "other invalid argument",
31+
err: status.Error(codes.InvalidArgument, "bad request"),
32+
want: false,
33+
},
34+
{
35+
name: "not found",
36+
err: status.Error(codes.NotFound, "not found"),
37+
want: false,
38+
},
39+
{
40+
name: "nil",
41+
err: nil,
42+
want: false,
43+
},
44+
}
45+
46+
for _, tc := range tests {
47+
tc := tc
48+
t.Run(tc.name, func(t *testing.T) {
49+
t.Parallel()
50+
if got := IsTypeMismatch(tc.err); got != tc.want {
51+
t.Fatalf("IsTypeMismatch(%v) = %v, want %v", tc.err, got, tc.want)
52+
}
53+
})
54+
}
55+
}

karpenter/pkg/cloudproviders/azure/cloudprovider.go

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@ import (
1818
"github.com/Azure/karpenter-provider-azure/pkg/utils"
1919
"github.com/awslabs/operatorpkg/status"
2020
"google.golang.org/grpc"
21+
"google.golang.org/grpc/codes"
22+
grpcstatus "google.golang.org/grpc/status"
2123
"k8s.io/apimachinery/pkg/runtime/schema"
2224
"sigs.k8s.io/controller-runtime/pkg/client"
2325
"sigs.k8s.io/controller-runtime/pkg/log"
2426
v1 "sigs.k8s.io/karpenter/pkg/apis/v1"
2527
corecloudprovider "sigs.k8s.io/karpenter/pkg/cloudprovider"
2628

29+
pluginapi "github.com/Azure/aks-flex/plugin/api"
2730
stretchhelper "github.com/Azure/aks-flex/plugin/pkg/helper"
2831
stretchservices "github.com/Azure/aks-flex/plugin/pkg/services"
2932
agentpoolsapi "github.com/Azure/aks-flex/plugin/pkg/services/agentpools/api"
@@ -50,6 +53,8 @@ type CloudProvider struct {
5053
instanceTypeProvider *instancetype.Provider
5154
}
5255

56+
var flexAgentPoolTypeURL = "type.googleapis.com/" + string((&flexvm.AgentPool{}).ProtoReflect().Descriptor().FullName())
57+
5358
func newCloudProvider(
5459
stretchPluginConn *grpc.ClientConn,
5560
kubeClient client.Client,
@@ -177,11 +182,8 @@ func (c *CloudProvider) Delete(ctx context.Context, nodeClaim *v1.NodeClaim) err
177182

178183
// Per CloudProvider.Delete contract: signal NodeClaimNotFoundError if the
179184
// remote resource is already gone (so karpenter knows it's safe to drop).
180-
if _, err := stretchhelper.Get[*flexvm.AgentPool](
181-
c.stretchAgentPoolsClient.Get,
182-
ctx, nodeClaim.Name,
183-
); err != nil {
184-
if IsNotFound(err) {
185+
if _, err := c.getFlexAgentPool(ctx, nodeClaim.Name); err != nil {
186+
if IsNotFound(err) || IsTypeMismatch(err) {
185187
return corecloudprovider.NewNodeClaimNotFoundError(err)
186188
}
187189
// Non-NotFound get failure: log and proceed with delete in best effort.
@@ -192,6 +194,9 @@ func (c *CloudProvider) Delete(ctx context.Context, nodeClaim *v1.NodeClaim) err
192194
c.stretchAgentPoolsClient.Delete,
193195
ctx, nodeClaim.Name,
194196
); err != nil {
197+
if IsNotFound(err) || IsTypeMismatch(err) {
198+
return corecloudprovider.NewNodeClaimNotFoundError(err)
199+
}
195200
return fmt.Errorf("deleting azure-flex agent pool: %w", err)
196201
}
197202
logger.Info("deleted azure-flex agent pool", "nodeClaim", nodeClaim.Name)
@@ -203,12 +208,9 @@ func (c *CloudProvider) Get(ctx context.Context, providerID string) (*v1.NodeCla
203208
if err != nil {
204209
return nil, err
205210
}
206-
ap, err := stretchhelper.Get[*flexvm.AgentPool](
207-
c.stretchAgentPoolsClient.Get,
208-
ctx, name,
209-
)
211+
ap, err := c.getFlexAgentPool(ctx, name)
210212
if err != nil {
211-
if IsNotFound(err) {
213+
if IsNotFound(err) || IsTypeMismatch(err) {
212214
return nil, corecloudprovider.NewNodeClaimNotFoundError(err)
213215
}
214216
return nil, err
@@ -219,6 +221,26 @@ func (c *CloudProvider) Get(ctx context.Context, providerID string) (*v1.NodeCla
219221
return agentPoolToNodeClaim(ap, nil), nil
220222
}
221223

224+
func (c *CloudProvider) getFlexAgentPool(ctx context.Context, id string) (*flexvm.AgentPool, error) {
225+
req := &pluginapi.GetRequest{}
226+
req.SetId(id)
227+
resp, err := c.stretchAgentPoolsClient.Get(ctx, req)
228+
if err != nil {
229+
return nil, err
230+
}
231+
return flexAgentPoolFromGetResponse(resp)
232+
}
233+
234+
func flexAgentPoolFromGetResponse(resp *pluginapi.GetResponse) (*flexvm.AgentPool, error) {
235+
if resp == nil || resp.GetItem() == nil {
236+
return nil, grpcstatus.Error(codes.NotFound, "")
237+
}
238+
if resp.GetItem().GetTypeUrl() != flexAgentPoolTypeURL {
239+
return nil, grpcstatus.Error(codes.NotFound, "")
240+
}
241+
return stretchhelper.AnyTo[*flexvm.AgentPool](resp.GetItem())
242+
}
243+
222244
func (c *CloudProvider) List(ctx context.Context) ([]*v1.NodeClaim, error) {
223245
aps, err := stretchhelper.List[*flexvm.AgentPool](
224246
c.stretchAgentPoolsClient.List,
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package azure
2+
3+
import (
4+
"testing"
5+
6+
pluginapi "github.com/Azure/aks-flex/plugin/api"
7+
"github.com/Azure/aks-flex/plugin/pkg/services/agentpools/azure/flexvm"
8+
"github.com/Azure/aks-flex/plugin/pkg/services/agentpools/azure/ubuntu2404vmss"
9+
"google.golang.org/protobuf/proto"
10+
"google.golang.org/protobuf/types/known/anypb"
11+
)
12+
13+
func TestFlexAgentPoolFromGetResponse(t *testing.T) {
14+
t.Parallel()
15+
16+
mkMeta := func(id string) *pluginapi.Metadata {
17+
return pluginapi.Metadata_builder{Id: proto.String(id)}.Build()
18+
}
19+
mkFlexResp := func(id string) *pluginapi.GetResponse {
20+
item, err := anypb.New(flexvm.AgentPool_builder{
21+
Metadata: mkMeta(id),
22+
}.Build())
23+
if err != nil {
24+
t.Fatalf("building flex anypb: %v", err)
25+
}
26+
return pluginapi.GetResponse_builder{Item: item}.Build()
27+
}
28+
mkVMSSResp := func(id string) *pluginapi.GetResponse {
29+
item, err := anypb.New(ubuntu2404vmss.AgentPool_builder{
30+
Metadata: mkMeta(id),
31+
}.Build())
32+
if err != nil {
33+
t.Fatalf("building vmss anypb: %v", err)
34+
}
35+
return pluginapi.GetResponse_builder{Item: item}.Build()
36+
}
37+
38+
tests := []struct {
39+
name string
40+
resp *pluginapi.GetResponse
41+
wantID string
42+
wantErr bool
43+
}{
44+
{
45+
name: "nil response is not found",
46+
resp: nil,
47+
wantErr: true,
48+
},
49+
{
50+
name: "nil item is not found",
51+
resp: pluginapi.GetResponse_builder{}.Build(),
52+
wantErr: true,
53+
},
54+
{
55+
name: "wrong item type is not found",
56+
resp: mkVMSSResp("node-1"),
57+
wantErr: true,
58+
},
59+
{
60+
name: "flex agentpool item returns parsed object",
61+
resp: mkFlexResp("node-2"),
62+
wantID: "node-2",
63+
},
64+
}
65+
66+
for _, tc := range tests {
67+
tc := tc
68+
t.Run(tc.name, func(t *testing.T) {
69+
t.Parallel()
70+
71+
got, err := flexAgentPoolFromGetResponse(tc.resp)
72+
if tc.wantErr {
73+
if err == nil {
74+
t.Fatalf("expected error, got nil")
75+
}
76+
if !IsNotFound(err) {
77+
t.Fatalf("expected NotFound-style error, got: %v", err)
78+
}
79+
return
80+
}
81+
if err != nil {
82+
t.Fatalf("unexpected error: %v", err)
83+
}
84+
if got.GetMetadata().GetId() != tc.wantID {
85+
t.Fatalf("got id %q, want %q", got.GetMetadata().GetId(), tc.wantID)
86+
}
87+
})
88+
}
89+
}

0 commit comments

Comments
 (0)