Skip to content

Commit 4bf4245

Browse files
authored
feat: Add support for GPU sharingStrategy in Node Template (#647)
1 parent 173d194 commit 4bf4245

File tree

4 files changed

+306
-26
lines changed

4 files changed

+306
-26
lines changed

castai/resource_node_template.go

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ const (
8686
FieldNodeTemplateSharingConfiguration = "sharing_configuration"
8787
FieldNodeTemplateSharedClientsPerGpu = "shared_clients_per_gpu"
8888
FieldNodeTemplateSharedGpuName = "gpu_name"
89+
FieldNodeTemplateSharingStrategy = "sharing_strategy"
8990
FieldNodeTemplateClmEnabled = "clm_enabled"
9091
FieldNodeTemplateEdgeLocationIDs = "edge_location_ids"
9192
FieldNodeTemplatePriceAdjustmentConfiguration = "price_adjustment_configuration"
@@ -679,11 +680,44 @@ func resourceNodeTemplate() *schema.Resource {
679680
DiffSuppressFunc: compareLists,
680681
Elem: &schema.Resource{
681682
Schema: map[string]*schema.Schema{
683+
FieldNodeTemplateSharingStrategy: {
684+
Type: schema.TypeString,
685+
Optional: true,
686+
ValidateDiagFunc: validation.ToDiagFunc(validation.StringInSlice([]string{"time-slicing", "mps"}, false)),
687+
Description: "GPU sharing strategy. Supported values: `time-slicing`, `mps`.",
688+
DiffSuppressFunc: func(k, oldVal, newVal string, d *schema.ResourceData) bool {
689+
// Suppress diff when sharing_strategy is not configured but enable_time_sharing=true
690+
// implies time-slicing, so the API returning "time-slicing" is not a real change.
691+
//
692+
// The DiffSuppressFunc receives k — the full key path of the field that's diffing, which looks like:
693+
// gpu.0.sharing_strategy
694+
// We need to check the sibling field enable_time_sharing at the same path:
695+
// gpu.0.enable_time_sharing
696+
// So the code:
697+
// prefix := k[:len(k)-len(FieldNodeTemplateSharingStrategy)]
698+
// // k = "gpu.0.sharing_strategy"
699+
// // len("sharing_strategy") chars stripped from end
700+
// // prefix = "gpu.0."
701+
//
702+
//Then:
703+
// d.GetOk(prefix + FieldNodeTemplateEnableTimeSharing)
704+
// // = d.GetOk("gpu.0." + "enable_time_sharing")
705+
// // = d.GetOk("gpu.0.enable_time_sharing")
706+
if newVal == "" && oldVal == "time-slicing" {
707+
prefix := k[:len(k)-len(FieldNodeTemplateSharingStrategy)]
708+
if v, ok := d.GetOk(prefix + FieldNodeTemplateEnableTimeSharing); ok && v.(bool) {
709+
return true
710+
}
711+
}
712+
return false
713+
},
714+
},
682715
FieldNodeTemplateEnableTimeSharing: {
683716
Type: schema.TypeBool,
684717
Optional: true,
685718
Default: nil,
686-
Description: "Enable/disable GPU time-sharing.",
719+
Deprecated: "Use sharing_strategy instead.",
720+
Description: "Enable/disable GPU time-sharing. Deprecated: use sharing_strategy = \"time-slicing\" instead.",
687721
},
688722
FieldNodeTemplateDefaultSharedClientsPerGpu: {
689723
Type: schema.TypeInt,
@@ -882,6 +916,10 @@ func flattenGpuSettings(g *sdk.NodetemplatesV1GPU) ([]map[string]any, error) {
882916
out[FieldNodeTemplateUserManagedGPUDrivers] = g.UserManagedGpuDrivers
883917
}
884918

919+
if g.SharingStrategy != nil {
920+
out[FieldNodeTemplateSharingStrategy] = gpuSharingStrategyToTerraform(*g.SharingStrategy)
921+
}
922+
885923
if g.EnableTimeSharing != nil {
886924
out[FieldNodeTemplateEnableTimeSharing] = g.EnableTimeSharing
887925
}
@@ -1175,6 +1213,7 @@ func updateNodeTemplate(ctx context.Context, d *schema.ResourceData, meta any, s
11751213
FieldNodeTemplateGpu,
11761214
FieldNodeTemplateDefaultSharedClientsPerGpu,
11771215
FieldNodeTemplateEnableTimeSharing,
1216+
FieldNodeTemplateSharingStrategy,
11781217
FieldNodeTemplateSharingConfiguration,
11791218
FieldNodeTemplateSharedGpuName,
11801219
FieldNodeTemplateSharedClientsPerGpu,
@@ -1702,6 +1741,12 @@ func toTemplateGpu(obj map[string]any) *sdk.NodetemplatesV1GPU {
17021741
defaultSharedClientsPerGpu = int32(v)
17031742
}
17041743

1744+
var sharingStrategy *sdk.NodetemplatesV1GPUSharingStrategy
1745+
if v, ok := obj[FieldNodeTemplateSharingStrategy].(string); ok && v != "" {
1746+
s := gpuSharingStrategyToAPI(v)
1747+
sharingStrategy = &s
1748+
}
1749+
17051750
var enableTimeSharing bool
17061751
if v, ok := obj[FieldNodeTemplateEnableTimeSharing].(bool); ok {
17071752
enableTimeSharing = v
@@ -1728,13 +1773,14 @@ func toTemplateGpu(obj map[string]any) *sdk.NodetemplatesV1GPU {
17281773
// terraform treats nil values as zero values
17291774
// this condition checks whether the whole gpu configuration is deleted
17301775
// and gpu configuration should be set to nil
1731-
if defaultSharedClientsPerGpu == 0 && !enableTimeSharing && len(sharingConfig) == 0 && !userManagedGPUDrivers {
1776+
if defaultSharedClientsPerGpu == 0 && !enableTimeSharing && sharingStrategy == nil && len(sharingConfig) == 0 && !userManagedGPUDrivers {
17321777
return nil
17331778
}
17341779

17351780
result := &sdk.NodetemplatesV1GPU{
17361781
EnableTimeSharing: &enableTimeSharing,
17371782
SharingConfiguration: &sharingConfig,
1783+
SharingStrategy: sharingStrategy,
17381784
}
17391785

17401786
// Only set DefaultSharedClientsPerGpu if it's non-zero to avoid API validation errors
@@ -1925,3 +1971,27 @@ func compareLists(key, oldValue, newValue string, d *schema.ResourceData) bool {
19251971
}
19261972
return false
19271973
}
1974+
1975+
// gpuSharingStrategyToAPI converts a terraform-friendly strategy string to the API enum value.
1976+
func gpuSharingStrategyToAPI(s string) sdk.NodetemplatesV1GPUSharingStrategy {
1977+
switch s {
1978+
case "mps":
1979+
return sdk.GPUSHARINGSTRATEGYMPS
1980+
case "time-slicing":
1981+
return sdk.GPUSHARINGSTRATEGYTIMESLICING
1982+
default:
1983+
return sdk.GPUSHARINGSTRATEGYUNSPECIFIED
1984+
}
1985+
}
1986+
1987+
// gpuSharingStrategyToTerraform converts the API enum value to a terraform-friendly string.
1988+
func gpuSharingStrategyToTerraform(s sdk.NodetemplatesV1GPUSharingStrategy) string {
1989+
switch s {
1990+
case sdk.GPUSHARINGSTRATEGYMPS:
1991+
return "mps"
1992+
case sdk.GPUSHARINGSTRATEGYTIMESLICING:
1993+
return "time-slicing"
1994+
default:
1995+
return ""
1996+
}
1997+
}

castai/resource_node_template_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ edge_location_ids.1 = b2c3d4e5-f6a7-8901-bcde-f12345678901
273273
gpu.# = 1
274274
gpu.0.default_shared_clients_per_gpu = 10
275275
gpu.0.enable_time_sharing = true
276+
gpu.0.sharing_strategy =
276277
gpu.0.user_managed_gpu_drivers = true
277278
gpu.0.sharing_configuration.# = 1
278279
gpu.0.sharing_configuration.0.gpu_name = A100
@@ -1334,6 +1335,45 @@ func Test_toTemplateGpu(t *testing.T) {
13341335
// UserManagedGpuDrivers not set (nil) - only sent when explicitly true (GKE-only)
13351336
},
13361337
},
1338+
{
1339+
name: "sharing_strategy mps",
1340+
input: map[string]any{
1341+
FieldNodeTemplateSharingStrategy: "mps",
1342+
FieldNodeTemplateDefaultSharedClientsPerGpu: 4,
1343+
},
1344+
want: &sdk.NodetemplatesV1GPU{
1345+
DefaultSharedClientsPerGpu: lo.ToPtr(int32(4)),
1346+
EnableTimeSharing: lo.ToPtr(false),
1347+
SharingConfiguration: &map[string]sdk.NodetemplatesV1SharedGPU{},
1348+
SharingStrategy: lo.ToPtr(sdk.GPUSHARINGSTRATEGYMPS),
1349+
},
1350+
},
1351+
{
1352+
name: "sharing_strategy time-slicing",
1353+
input: map[string]any{
1354+
FieldNodeTemplateSharingStrategy: "time-slicing",
1355+
FieldNodeTemplateDefaultSharedClientsPerGpu: 8,
1356+
},
1357+
want: &sdk.NodetemplatesV1GPU{
1358+
DefaultSharedClientsPerGpu: lo.ToPtr(int32(8)),
1359+
EnableTimeSharing: lo.ToPtr(false),
1360+
SharingConfiguration: &map[string]sdk.NodetemplatesV1SharedGPU{},
1361+
SharingStrategy: lo.ToPtr(sdk.GPUSHARINGSTRATEGYTIMESLICING),
1362+
},
1363+
},
1364+
{
1365+
name: "sharing_strategy mps with user_managed_gpu_drivers",
1366+
input: map[string]any{
1367+
FieldNodeTemplateSharingStrategy: "mps",
1368+
FieldNodeTemplateUserManagedGPUDrivers: true,
1369+
},
1370+
want: &sdk.NodetemplatesV1GPU{
1371+
EnableTimeSharing: lo.ToPtr(false),
1372+
SharingConfiguration: &map[string]sdk.NodetemplatesV1SharedGPU{},
1373+
SharingStrategy: lo.ToPtr(sdk.GPUSHARINGSTRATEGYMPS),
1374+
UserManagedGpuDrivers: lo.ToPtr(true),
1375+
},
1376+
},
13371377
}
13381378

13391379
for _, tt := range tests {
@@ -1356,6 +1396,7 @@ func Test_toTemplateGpu(t *testing.T) {
13561396
}
13571397

13581398
require.Equal(t, tt.want.EnableTimeSharing, got.EnableTimeSharing)
1399+
require.Equal(t, tt.want.SharingStrategy, got.SharingStrategy)
13591400
require.Equal(t, tt.want.UserManagedGpuDrivers, got.UserManagedGpuDrivers)
13601401

13611402
if tt.want.SharingConfiguration != nil && got.SharingConfiguration != nil {
@@ -1471,6 +1512,34 @@ func Test_flattenGpuSettings(t *testing.T) {
14711512
},
14721513
wantErr: false,
14731514
},
1515+
{
1516+
name: "sharing_strategy mps",
1517+
input: &sdk.NodetemplatesV1GPU{
1518+
SharingStrategy: lo.ToPtr(sdk.GPUSHARINGSTRATEGYMPS),
1519+
DefaultSharedClientsPerGpu: lo.ToPtr(int32(4)),
1520+
},
1521+
want: []map[string]any{
1522+
{
1523+
FieldNodeTemplateSharingStrategy: "mps",
1524+
FieldNodeTemplateDefaultSharedClientsPerGpu: lo.ToPtr(int32(4)),
1525+
},
1526+
},
1527+
wantErr: false,
1528+
},
1529+
{
1530+
name: "sharing_strategy time-slicing",
1531+
input: &sdk.NodetemplatesV1GPU{
1532+
SharingStrategy: lo.ToPtr(sdk.GPUSHARINGSTRATEGYTIMESLICING),
1533+
DefaultSharedClientsPerGpu: lo.ToPtr(int32(8)),
1534+
},
1535+
want: []map[string]any{
1536+
{
1537+
FieldNodeTemplateSharingStrategy: "time-slicing",
1538+
FieldNodeTemplateDefaultSharedClientsPerGpu: lo.ToPtr(int32(8)),
1539+
},
1540+
},
1541+
wantErr: false,
1542+
},
14741543
}
14751544

14761545
for _, tt := range tests {
@@ -1491,6 +1560,7 @@ func Test_flattenGpuSettings(t *testing.T) {
14911560
// Compare all fields except sharing_configuration
14921561
require.Equal(t, tt.want[0][FieldNodeTemplateUserManagedGPUDrivers], got[0][FieldNodeTemplateUserManagedGPUDrivers])
14931562
require.Equal(t, tt.want[0][FieldNodeTemplateEnableTimeSharing], got[0][FieldNodeTemplateEnableTimeSharing])
1563+
require.Equal(t, tt.want[0][FieldNodeTemplateSharingStrategy], got[0][FieldNodeTemplateSharingStrategy])
14941564
require.Equal(t, tt.want[0][FieldNodeTemplateDefaultSharedClientsPerGpu], got[0][FieldNodeTemplateDefaultSharedClientsPerGpu])
14951565

14961566
// Compare sharing_configuration with ElementsMatch (order doesn't matter)

0 commit comments

Comments
 (0)