Skip to content

Commit add4f87

Browse files
authored
feat: add Node Template option to configure user-managed GPU drivers (#639)
* feat: add Node Template option to configure user-managed GPU drivers * fix: only send user_managed_gpu_drivers when it is set to true * fix(tests): handle autoscaler policy version conflicts with wait.Backoff retry When node template and autoscaler policies are updated concurrently in the same terraform apply, the API rejects updates with: "node template has changed since the policies have been retrieved, refetch the policies and perform the update again" * fix: remove additional try of updatePolicies() fix: change constant name
1 parent 799638e commit add4f87

File tree

6 files changed

+438
-48
lines changed

6 files changed

+438
-48
lines changed

castai/resource_autoscaler.go

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"io"
99
"log"
1010
"net/http"
11+
"strings"
1112
"time"
1213

1314
jsonpatch "github.com/evanphx/json-patch"
@@ -17,6 +18,7 @@ import (
1718
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
1819
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
1920
"github.com/mitchellh/mapstructure"
21+
"k8s.io/apimachinery/pkg/util/wait"
2022

2123
"github.com/castai/terraform-provider-castai/castai/sdk"
2224
"github.com/castai/terraform-provider-castai/castai/types"
@@ -805,23 +807,59 @@ func updateAutoscalerPolicies(ctx context.Context, data *schema.ResourceData, me
805807
return nil
806808
}
807809

808-
policies, err := getChangedPolicies(ctx, data, meta, clusterId)
809-
if err != nil {
810-
return err
810+
// Define the update operation that will be executed with retry logic
811+
updatePolicies := func() error {
812+
policies, err := getChangedPolicies(ctx, data, meta, clusterId)
813+
if err != nil {
814+
return err
815+
}
816+
817+
if policies == nil {
818+
log.Printf("[DEBUG] changed policies json not calculated. Skipping autoscaler policies changes")
819+
return nil
820+
}
821+
822+
changedPoliciesJSON := string(policies)
823+
if changedPoliciesJSON == "" {
824+
log.Printf("[DEBUG] changed policies json not found. Skipping autoscaler policies changes")
825+
return nil
826+
}
827+
828+
return upsertPolicies(ctx, meta, clusterId, changedPoliciesJSON)
811829
}
812830

813-
if policies == nil {
814-
log.Printf("[DEBUG] changed policies json not calculated. Skipping autoscaler policies changes")
815-
return nil
831+
// Exponential backoff configuration
832+
backoff := wait.Backoff{
833+
Duration: 100 * time.Millisecond,
834+
Factor: 2.0,
835+
Jitter: 0.1,
836+
Steps: 5,
837+
Cap: 2 * time.Second,
816838
}
817839

818-
changedPoliciesJSON := string(policies)
819-
if changedPoliciesJSON == "" {
820-
log.Printf("[DEBUG] changed policies json not found. Skipping autoscaler policies changes")
821-
return nil
840+
retryErr := wait.ExponentialBackoffWithContext(ctx, backoff, func(ctx context.Context) (done bool, err error) {
841+
err = updatePolicies()
842+
if err == nil {
843+
return true, nil // Success - stop retrying
844+
}
845+
846+
// Check if error is retryable
847+
if !isNodeTemplateVersionConflict(err) {
848+
return false, err // Non-retryable error - stop with error
849+
}
850+
851+
log.Printf("[DEBUG] Retry failed with version conflict: %v", err)
852+
return false, nil // Retryable error - continue retrying
853+
})
854+
855+
if retryErr != nil {
856+
if wait.Interrupted(retryErr) {
857+
return fmt.Errorf("timeout waiting for autoscaler policy update after version conflicts: %w", retryErr)
858+
}
859+
return retryErr
822860
}
823861

824-
return upsertPolicies(ctx, meta, clusterId, changedPoliciesJSON)
862+
return nil
825863
}
826864

827865
func upsertPolicies(ctx context.Context, meta interface{}, clusterId string, changedPoliciesJSON string) error {
@@ -835,6 +873,15 @@ func upsertPolicies(ctx context.Context, meta interface{}, clusterId string, cha
835873
return nil
836874
}
837875

876+
// isNodeTemplateVersionConflict checks if the error is due to version mismatch
877+
func isNodeTemplateVersionConflict(err error) bool {
878+
if err == nil {
879+
return false
880+
}
881+
errMsg := err.Error()
882+
return strings.Contains(errMsg, "template has changed") || strings.Contains(errMsg, "refetch the policies")
883+
}
884+
838885
func readAutoscalerPolicies(ctx context.Context, data *schema.ResourceData, meta interface{}) error {
839886
log.Printf("[INFO] AUTOSCALER policies get call start")
840887
defer log.Printf("[INFO] AUTOSCALER policies get call end")

castai/resource_node_template.go

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ const (
9090
FieldNodeTemplateEdgeLocationIDs = "edge_location_ids"
9191
FieldNodeTemplatePriceAdjustmentConfiguration = "price_adjustment_configuration"
9292
FieldNodeTemplateInstanceTypeAdjustments = "instance_type_adjustments"
93+
FieldNodeTemplateUserManagedGPUDrivers = "user_managed_gpu_drivers"
9394
)
9495

9596
const (
@@ -710,6 +711,12 @@ func resourceNodeTemplate() *schema.Resource {
710711
},
711712
},
712713
},
714+
FieldNodeTemplateUserManagedGPUDrivers: {
715+
Type: schema.TypeBool,
716+
Optional: true,
717+
Default: nil,
718+
Description: "Enable/disable user-managed GPU drivers (for GKE clusters only).",
719+
},
713720
},
714721
},
715722
},
@@ -871,6 +878,10 @@ func flattenGpuSettings(g *sdk.NodetemplatesV1GPU) ([]map[string]any, error) {
871878

872879
out := make(map[string]any)
873880

881+
if g.UserManagedGpuDrivers != nil {
882+
out[FieldNodeTemplateUserManagedGPUDrivers] = g.UserManagedGpuDrivers
883+
}
884+
874885
if g.EnableTimeSharing != nil {
875886
out[FieldNodeTemplateEnableTimeSharing] = g.EnableTimeSharing
876887
}
@@ -1170,6 +1181,7 @@ func updateNodeTemplate(ctx context.Context, d *schema.ResourceData, meta any, s
11701181
FieldNodeTemplateClmEnabled,
11711182
FieldNodeTemplateEdgeLocationIDs,
11721183
FieldNodeTemplatePriceAdjustmentConfiguration,
1184+
FieldNodeTemplateUserManagedGPUDrivers,
11731185
) {
11741186
log.Printf("[INFO] Nothing to update in node template")
11751187
return nil
@@ -1680,6 +1692,11 @@ func toTemplateGpu(obj map[string]any) *sdk.NodetemplatesV1GPU {
16801692
return nil
16811693
}
16821694

1695+
var userManagedGPUDrivers bool
1696+
if v, ok := obj[FieldNodeTemplateUserManagedGPUDrivers].(bool); ok {
1697+
userManagedGPUDrivers = v
1698+
}
1699+
16831700
var defaultSharedClientsPerGpu int32
16841701
if v, ok := obj[FieldNodeTemplateDefaultSharedClientsPerGpu].(int); ok {
16851702
defaultSharedClientsPerGpu = int32(v)
@@ -1711,16 +1728,29 @@ func toTemplateGpu(obj map[string]any) *sdk.NodetemplatesV1GPU {
17111728
// terraform treats nil values as zero values
17121729
// this condition checks whether the whole gpu configuration is deleted
17131730
// and gpu configuration should be set to nil
1714-
if defaultSharedClientsPerGpu == 0 &&
1715-
!enableTimeSharing &&
1716-
len(sharingConfig) == 0 {
1731+
if defaultSharedClientsPerGpu == 0 && !enableTimeSharing && len(sharingConfig) == 0 && !userManagedGPUDrivers {
17171732
return nil
17181733
}
1719-
return &sdk.NodetemplatesV1GPU{
1720-
DefaultSharedClientsPerGpu: &defaultSharedClientsPerGpu,
1721-
EnableTimeSharing: &enableTimeSharing,
1722-
SharingConfiguration: &sharingConfig,
1734+
1735+
result := &sdk.NodetemplatesV1GPU{
1736+
EnableTimeSharing: &enableTimeSharing,
1737+
SharingConfiguration: &sharingConfig,
1738+
}
1739+
1740+
// Only set DefaultSharedClientsPerGpu if it's non-zero to avoid API validation errors
1741+
// as terraform treats nil values as zero values
1742+
// API requires it to be in range (0, 48] if present
1743+
if defaultSharedClientsPerGpu > 0 {
1744+
result.DefaultSharedClientsPerGpu = &defaultSharedClientsPerGpu
17231745
}
1746+
1747+
// Only set UserManagedGpuDrivers if explicitly true
1748+
//as this is optional and as terraform treats nil values as zero values
1749+
if userManagedGPUDrivers {
1750+
result.UserManagedGpuDrivers = &userManagedGPUDrivers
1751+
}
1752+
1753+
return result
17241754
}
17251755

17261756
func toTemplateConstraintsInstanceFamilies(o map[string]any) *sdk.NodetemplatesV1TemplateConstraintsInstanceFamilyConstraints {

0 commit comments

Comments
 (0)