Skip to content

Commit 666fc4c

Browse files
committed
feat: Add support for GPU sharingStrategy in Node Template
1 parent 173d194 commit 666fc4c

File tree

3 files changed

+212
-9
lines changed

3 files changed

+212
-9
lines changed

castai/resource_node_template.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ const (
8686
FieldNodeTemplateSharingConfiguration = "sharing_configuration"
8787
FieldNodeTemplateSharedClientsPerGpu = "shared_clients_per_gpu"
8888
FieldNodeTemplateSharedGpuName = "gpu_name"
89+
FieldNodeTemplateSharingStrategy = "sharing_strategy"
8990
FieldNodeTemplateClmEnabled = "clm_enabled"
9091
FieldNodeTemplateEdgeLocationIDs = "edge_location_ids"
9192
FieldNodeTemplatePriceAdjustmentConfiguration = "price_adjustment_configuration"
@@ -679,11 +680,18 @@ func resourceNodeTemplate() *schema.Resource {
679680
DiffSuppressFunc: compareLists,
680681
Elem: &schema.Resource{
681682
Schema: map[string]*schema.Schema{
683+
FieldNodeTemplateSharingStrategy: {
684+
Type: schema.TypeString,
685+
Optional: true,
686+
ValidateDiagFunc: validation.ToDiagFunc(validation.StringInSlice([]string{"time-slicing", "mps"}, false)),
687+
Description: "GPU sharing strategy. Supported values: `time-slicing`, `mps`.",
688+
},
682689
FieldNodeTemplateEnableTimeSharing: {
683690
Type: schema.TypeBool,
684691
Optional: true,
685692
Default: nil,
686-
Description: "Enable/disable GPU time-sharing.",
693+
Deprecated: "Use sharing_strategy instead.",
694+
Description: "Enable/disable GPU time-sharing. Deprecated: use sharing_strategy = \"time-slicing\" instead.",
687695
},
688696
FieldNodeTemplateDefaultSharedClientsPerGpu: {
689697
Type: schema.TypeInt,
@@ -882,6 +890,10 @@ func flattenGpuSettings(g *sdk.NodetemplatesV1GPU) ([]map[string]any, error) {
882890
out[FieldNodeTemplateUserManagedGPUDrivers] = g.UserManagedGpuDrivers
883891
}
884892

893+
if g.SharingStrategy != nil {
894+
out[FieldNodeTemplateSharingStrategy] = gpuSharingStrategyToTerraform(*g.SharingStrategy)
895+
}
896+
885897
if g.EnableTimeSharing != nil {
886898
out[FieldNodeTemplateEnableTimeSharing] = g.EnableTimeSharing
887899
}
@@ -1175,6 +1187,7 @@ func updateNodeTemplate(ctx context.Context, d *schema.ResourceData, meta any, s
11751187
FieldNodeTemplateGpu,
11761188
FieldNodeTemplateDefaultSharedClientsPerGpu,
11771189
FieldNodeTemplateEnableTimeSharing,
1190+
FieldNodeTemplateSharingStrategy,
11781191
FieldNodeTemplateSharingConfiguration,
11791192
FieldNodeTemplateSharedGpuName,
11801193
FieldNodeTemplateSharedClientsPerGpu,
@@ -1702,6 +1715,12 @@ func toTemplateGpu(obj map[string]any) *sdk.NodetemplatesV1GPU {
17021715
defaultSharedClientsPerGpu = int32(v)
17031716
}
17041717

1718+
var sharingStrategy *sdk.NodetemplatesV1GPUSharingStrategy
1719+
if v, ok := obj[FieldNodeTemplateSharingStrategy].(string); ok && v != "" {
1720+
s := gpuSharingStrategyToAPI(v)
1721+
sharingStrategy = &s
1722+
}
1723+
17051724
var enableTimeSharing bool
17061725
if v, ok := obj[FieldNodeTemplateEnableTimeSharing].(bool); ok {
17071726
enableTimeSharing = v
@@ -1728,13 +1747,14 @@ func toTemplateGpu(obj map[string]any) *sdk.NodetemplatesV1GPU {
17281747
// terraform treats nil values as zero values
17291748
// this condition checks whether the whole gpu configuration is deleted
17301749
// and gpu configuration should be set to nil
1731-
if defaultSharedClientsPerGpu == 0 && !enableTimeSharing && len(sharingConfig) == 0 && !userManagedGPUDrivers {
1750+
if defaultSharedClientsPerGpu == 0 && !enableTimeSharing && sharingStrategy == nil && len(sharingConfig) == 0 && !userManagedGPUDrivers {
17321751
return nil
17331752
}
17341753

17351754
result := &sdk.NodetemplatesV1GPU{
17361755
EnableTimeSharing: &enableTimeSharing,
17371756
SharingConfiguration: &sharingConfig,
1757+
SharingStrategy: sharingStrategy,
17381758
}
17391759

17401760
// Only set DefaultSharedClientsPerGpu if it's non-zero to avoid API validation errors
@@ -1925,3 +1945,27 @@ func compareLists(key, oldValue, newValue string, d *schema.ResourceData) bool {
19251945
}
19261946
return false
19271947
}
1948+
1949+
// gpuSharingStrategyToAPI converts a terraform-friendly strategy string to the API enum value.
1950+
func gpuSharingStrategyToAPI(s string) sdk.NodetemplatesV1GPUSharingStrategy {
1951+
switch s {
1952+
case "mps":
1953+
return sdk.GPUSHARINGSTRATEGYMPS
1954+
case "time-slicing":
1955+
return sdk.GPUSHARINGSTRATEGYTIMESLICING
1956+
default:
1957+
return sdk.GPUSHARINGSTRATEGYUNSPECIFIED
1958+
}
1959+
}
1960+
1961+
// gpuSharingStrategyToTerraform converts the API enum value to a terraform-friendly string.
1962+
func gpuSharingStrategyToTerraform(s sdk.NodetemplatesV1GPUSharingStrategy) string {
1963+
switch s {
1964+
case sdk.GPUSHARINGSTRATEGYMPS:
1965+
return "mps"
1966+
case sdk.GPUSHARINGSTRATEGYTIMESLICING:
1967+
return "time-slicing"
1968+
default:
1969+
return ""
1970+
}
1971+
}

castai/resource_node_template_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ edge_location_ids.1 = b2c3d4e5-f6a7-8901-bcde-f12345678901
273273
gpu.# = 1
274274
gpu.0.default_shared_clients_per_gpu = 10
275275
gpu.0.enable_time_sharing = true
276+
gpu.0.sharing_strategy =
276277
gpu.0.user_managed_gpu_drivers = true
277278
gpu.0.sharing_configuration.# = 1
278279
gpu.0.sharing_configuration.0.gpu_name = A100
@@ -1334,6 +1335,45 @@ func Test_toTemplateGpu(t *testing.T) {
13341335
// UserManagedGpuDrivers not set (nil) - only sent when explicitly true (GKE-only)
13351336
},
13361337
},
1338+
{
1339+
name: "sharing_strategy mps",
1340+
input: map[string]any{
1341+
FieldNodeTemplateSharingStrategy: "mps",
1342+
FieldNodeTemplateDefaultSharedClientsPerGpu: 4,
1343+
},
1344+
want: &sdk.NodetemplatesV1GPU{
1345+
DefaultSharedClientsPerGpu: lo.ToPtr(int32(4)),
1346+
EnableTimeSharing: lo.ToPtr(false),
1347+
SharingConfiguration: &map[string]sdk.NodetemplatesV1SharedGPU{},
1348+
SharingStrategy: lo.ToPtr(sdk.GPUSHARINGSTRATEGYMPS),
1349+
},
1350+
},
1351+
{
1352+
name: "sharing_strategy time-slicing",
1353+
input: map[string]any{
1354+
FieldNodeTemplateSharingStrategy: "time-slicing",
1355+
FieldNodeTemplateDefaultSharedClientsPerGpu: 8,
1356+
},
1357+
want: &sdk.NodetemplatesV1GPU{
1358+
DefaultSharedClientsPerGpu: lo.ToPtr(int32(8)),
1359+
EnableTimeSharing: lo.ToPtr(false),
1360+
SharingConfiguration: &map[string]sdk.NodetemplatesV1SharedGPU{},
1361+
SharingStrategy: lo.ToPtr(sdk.GPUSHARINGSTRATEGYTIMESLICING),
1362+
},
1363+
},
1364+
{
1365+
name: "sharing_strategy mps with user_managed_gpu_drivers",
1366+
input: map[string]any{
1367+
FieldNodeTemplateSharingStrategy: "mps",
1368+
FieldNodeTemplateUserManagedGPUDrivers: true,
1369+
},
1370+
want: &sdk.NodetemplatesV1GPU{
1371+
EnableTimeSharing: lo.ToPtr(false),
1372+
SharingConfiguration: &map[string]sdk.NodetemplatesV1SharedGPU{},
1373+
SharingStrategy: lo.ToPtr(sdk.GPUSHARINGSTRATEGYMPS),
1374+
UserManagedGpuDrivers: lo.ToPtr(true),
1375+
},
1376+
},
13371377
}
13381378

13391379
for _, tt := range tests {
@@ -1356,6 +1396,7 @@ func Test_toTemplateGpu(t *testing.T) {
13561396
}
13571397

13581398
require.Equal(t, tt.want.EnableTimeSharing, got.EnableTimeSharing)
1399+
require.Equal(t, tt.want.SharingStrategy, got.SharingStrategy)
13591400
require.Equal(t, tt.want.UserManagedGpuDrivers, got.UserManagedGpuDrivers)
13601401

13611402
if tt.want.SharingConfiguration != nil && got.SharingConfiguration != nil {
@@ -1471,6 +1512,34 @@ func Test_flattenGpuSettings(t *testing.T) {
14711512
},
14721513
wantErr: false,
14731514
},
1515+
{
1516+
name: "sharing_strategy mps",
1517+
input: &sdk.NodetemplatesV1GPU{
1518+
SharingStrategy: lo.ToPtr(sdk.GPUSHARINGSTRATEGYMPS),
1519+
DefaultSharedClientsPerGpu: lo.ToPtr(int32(4)),
1520+
},
1521+
want: []map[string]any{
1522+
{
1523+
FieldNodeTemplateSharingStrategy: "mps",
1524+
FieldNodeTemplateDefaultSharedClientsPerGpu: lo.ToPtr(int32(4)),
1525+
},
1526+
},
1527+
wantErr: false,
1528+
},
1529+
{
1530+
name: "sharing_strategy time-slicing",
1531+
input: &sdk.NodetemplatesV1GPU{
1532+
SharingStrategy: lo.ToPtr(sdk.GPUSHARINGSTRATEGYTIMESLICING),
1533+
DefaultSharedClientsPerGpu: lo.ToPtr(int32(8)),
1534+
},
1535+
want: []map[string]any{
1536+
{
1537+
FieldNodeTemplateSharingStrategy: "time-slicing",
1538+
FieldNodeTemplateDefaultSharedClientsPerGpu: lo.ToPtr(int32(8)),
1539+
},
1540+
},
1541+
wantErr: false,
1542+
},
14741543
}
14751544

14761545
for _, tt := range tests {
@@ -1491,6 +1560,7 @@ func Test_flattenGpuSettings(t *testing.T) {
14911560
// Compare all fields except sharing_configuration
14921561
require.Equal(t, tt.want[0][FieldNodeTemplateUserManagedGPUDrivers], got[0][FieldNodeTemplateUserManagedGPUDrivers])
14931562
require.Equal(t, tt.want[0][FieldNodeTemplateEnableTimeSharing], got[0][FieldNodeTemplateEnableTimeSharing])
1563+
require.Equal(t, tt.want[0][FieldNodeTemplateSharingStrategy], got[0][FieldNodeTemplateSharingStrategy])
14941564
require.Equal(t, tt.want[0][FieldNodeTemplateDefaultSharedClientsPerGpu], got[0][FieldNodeTemplateDefaultSharedClientsPerGpu])
14951565

14961566
// Compare sharing_configuration with ElementsMatch (order doesn't matter)

0 commit comments

Comments
 (0)