Skip to content

Commit 21ced7d

Browse files
committed
feat: add GPU sharing settings to Node Template
1 parent 608f0eb commit 21ced7d

4 files changed

Lines changed: 213 additions & 19 deletions

File tree

castai/resource_node_template.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ const (
7979
FieldNodeTemplateCPULimitEnabled = "cpu_limit_enabled"
8080
FieldNodeTemplateCPULimitMaxCores = "cpu_limit_max_cores"
8181
FieldNodeTemplateBareMetal = "bare_metal"
82+
FieldNodeTemplateEnableTimeSharing = "enable_time_sharing"
83+
FieldNodeTemplateDefaultSharedClientsPerGpu = "default_shared_clients_per_gpu"
84+
FieldNodeTemplateSharingConfiguration = "sharing_configuration"
85+
FieldNodeTemplateSharedClientsPerGpu = "shared_clients_per_gpu"
86+
FieldNodeTemplateSharedGpuName = "gpu_name"
8287
)
8388

8489
const (
@@ -645,6 +650,45 @@ func resourceNodeTemplate() *schema.Resource {
645650
Description: "Marks whether custom instances with extended memory should be used when deciding which parts of inventory are available. " +
646651
"Custom instances are only supported in GCP.",
647652
},
653+
FieldNodeTemplateGpu: {
654+
Type: schema.TypeList,
655+
MaxItems: 1,
656+
Optional: true,
657+
Elem: &schema.Resource{
658+
Schema: map[string]*schema.Schema{
659+
FieldNodeTemplateEnableTimeSharing: {
660+
Type: schema.TypeBool,
661+
Optional: true,
662+
Description: "Enable/disable GPU time-sharing.",
663+
},
664+
FieldNodeTemplateDefaultSharedClientsPerGpu: {
665+
Type: schema.TypeInt,
666+
Optional: true,
667+
Default: 1,
668+
Description: "Defines default shared client per GPU.",
669+
},
670+
FieldNodeTemplateSharingConfiguration: {
671+
Type: schema.TypeList,
672+
Optional: true,
673+
Description: "Defines GPU sharing configurations for GPU devices.",
674+
Elem: &schema.Resource{
675+
Schema: map[string]*schema.Schema{
676+
FieldNodeTemplateSharedGpuName: {
677+
Type: schema.TypeString,
678+
Required: true,
679+
Description: "GPU name.",
680+
},
681+
FieldNodeTemplateSharedClientsPerGpu: {
682+
Type: schema.TypeInt,
683+
Required: true,
684+
Description: "Defines default shared clients per GPU.",
685+
},
686+
},
687+
},
688+
},
689+
},
690+
},
691+
},
648692
},
649693
}
650694
}
@@ -719,9 +763,52 @@ func resourceNodeTemplateRead(ctx context.Context, d *schema.ResourceData, meta
719763
return diag.FromErr(fmt.Errorf("setting custom instances with extended memory enabled: %w", err))
720764
}
721765

766+
if nodeTemplate.Gpu != nil {
767+
gpu, err := flattenGpuSettings(nodeTemplate.Gpu)
768+
if err != nil {
769+
return diag.FromErr(fmt.Errorf("flattening gpu settings: %w", err))
770+
}
771+
772+
if err := d.Set(FieldNodeTemplateGpu, gpu); err != nil {
773+
return diag.FromErr(fmt.Errorf("setting gpu settings: %w", err))
774+
}
775+
}
722776
return nil
723777
}
724778

779+
func flattenGpuSettings(g *sdk.NodetemplatesV1GPU) ([]map[string]any, error) {
780+
if g == nil {
781+
return nil, nil
782+
}
783+
784+
out := make(map[string]any)
785+
786+
if g.EnableTimeSharing != nil {
787+
out[FieldNodeTemplateEnableTimeSharing] = g.EnableTimeSharing
788+
}
789+
790+
if g.DefaultSharedClientsPerGpu != nil {
791+
out[FieldNodeTemplateDefaultSharedClientsPerGpu] = g.DefaultSharedClientsPerGpu
792+
}
793+
794+
if g.SharingConfiguration != nil {
795+
sharingConfigurations := make([]map[string]any, 0)
796+
for gpuName, sc := range *g.SharingConfiguration {
797+
if sc.SharedClientsPerGpu != nil {
798+
sharingConfig := make(map[string]any)
799+
sharingConfig[FieldNodeTemplateSharedClientsPerGpu] = *sc.SharedClientsPerGpu
800+
sharingConfig[FieldNodeTemplateSharedGpuName] = gpuName
801+
802+
sharingConfigurations = append(sharingConfigurations, sharingConfig)
803+
}
804+
}
805+
if len(sharingConfigurations) > 0 {
806+
out[FieldNodeTemplateSharingConfiguration] = sharingConfigurations
807+
}
808+
}
809+
return []map[string]any{out}, nil
810+
}
811+
725812
func flattenConstraints(c *sdk.NodetemplatesV1TemplateConstraints) ([]map[string]any, error) {
726813
if c == nil {
727814
return nil, nil
@@ -981,6 +1068,12 @@ func updateNodeTemplate(ctx context.Context, d *schema.ResourceData, meta any, s
9811068
FieldNodeTemplateCustomInstancesWithExtendedMemoryEnabled,
9821069
FieldNodeTemplateConstraints,
9831070
FieldNodeTemplateIsEnabled,
1071+
FieldNodeTemplateGpu,
1072+
FieldNodeTemplateDefaultSharedClientsPerGpu,
1073+
FieldNodeTemplateEnableTimeSharing,
1074+
FieldNodeTemplateSharingConfiguration,
1075+
FieldNodeTemplateSharedGpuName,
1076+
FieldNodeTemplateSharedClientsPerGpu,
9841077
) {
9851078
log.Printf("[INFO] Nothing to update in node template")
9861079
return nil
@@ -1050,6 +1143,10 @@ func updateNodeTemplate(ctx context.Context, d *schema.ResourceData, meta any, s
10501143
req.CustomInstancesWithExtendedMemoryEnabled = lo.ToPtr(v.(bool))
10511144
}
10521145

1146+
if v, ok := d.Get(FieldNodeTemplateGpu).([]any); ok && len(v) > 0 {
1147+
req.Gpu = toTemplateGpu(v[0].(map[string]any))
1148+
}
1149+
10531150
resp, err := client.NodeTemplatesAPIUpdateNodeTemplateWithResponse(ctx, clusterID, name, req)
10541151
if checkErr := sdk.CheckOKResponse(resp, err); checkErr != nil {
10551152
return diag.FromErr(checkErr)
@@ -1123,6 +1220,10 @@ func resourceNodeTemplateCreate(ctx context.Context, d *schema.ResourceData, met
11231220
req.CustomInstancesEnabled = lo.ToPtr(v.(bool))
11241221
}
11251222

1223+
if v, ok := d.Get(FieldNodeTemplateGpu).([]any); ok && len(v) > 0 {
1224+
req.Gpu = toTemplateGpu(v[0].(map[string]any))
1225+
}
1226+
11261227
resp, err := client.NodeTemplatesAPICreateNodeTemplateWithResponse(ctx, clusterID, req)
11271228
if checkErr := sdk.CheckOKResponse(resp, err); checkErr != nil {
11281229
return diag.FromErr(checkErr)
@@ -1455,6 +1556,40 @@ func toTemplateConstraints(obj map[string]any) *sdk.NodetemplatesV1TemplateConst
14551556
return out
14561557
}
14571558

1559+
func toTemplateGpu(obj map[string]any) *sdk.NodetemplatesV1GPU {
1560+
if obj == nil {
1561+
return nil
1562+
}
1563+
1564+
out := &sdk.NodetemplatesV1GPU{}
1565+
if v, ok := obj[FieldNodeTemplateDefaultSharedClientsPerGpu].(int); ok {
1566+
out.DefaultSharedClientsPerGpu = toPtr(int32(v))
1567+
}
1568+
1569+
if v, ok := obj[FieldNodeTemplateEnableTimeSharing].(bool); ok {
1570+
out.EnableTimeSharing = toPtr(v)
1571+
}
1572+
1573+
if sharingConfiguration, ok := obj[FieldNodeTemplateSharingConfiguration].([]interface{}); ok {
1574+
outSharingConfiguration := make(map[string]sdk.NodetemplatesV1SharedGPU)
1575+
for _, configuration := range sharingConfiguration {
1576+
1577+
sharedGPUConfig := configuration.(map[string]interface{})
1578+
gpuName, gpuNameOk := sharedGPUConfig[FieldNodeTemplateSharedGpuName].(string)
1579+
sharedClientsPerGpu, sharedClientsPerGpuOk := sharedGPUConfig[FieldNodeTemplateSharedClientsPerGpu].(int)
1580+
if gpuNameOk && sharedClientsPerGpuOk {
1581+
outSharingConfiguration[gpuName] = sdk.NodetemplatesV1SharedGPU{
1582+
SharedClientsPerGpu: toPtr(int32(sharedClientsPerGpu)),
1583+
}
1584+
}
1585+
}
1586+
if len(outSharingConfiguration) > 0 {
1587+
out.SharingConfiguration = &outSharingConfiguration
1588+
}
1589+
}
1590+
return out
1591+
}
1592+
14581593
func toTemplateConstraintsInstanceFamilies(o map[string]any) *sdk.NodetemplatesV1TemplateConstraintsInstanceFamilyConstraints {
14591594
if o == nil {
14601595
return nil

castai/resource_node_template_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ func TestNodeTemplateResourceReadContext(t *testing.T) {
4646
"configurationId": "7dc4f922-29c9-4377-889c-0c8c5fb8d497",
4747
"configurationName": "default",
4848
"isEnabled": true,
49+
"gpu": {
50+
"enableTimeSharing": true,
51+
"defaultSharedClientsPerGpu": 10,
52+
"sharingConfiguration": {
53+
"A100": {
54+
"sharedClientsPerGpu": 5
55+
}
56+
}
57+
},
4958
"name": "gpu",
5059
"constraints": {
5160
"spot": false,
@@ -153,6 +162,7 @@ func TestNodeTemplateResourceReadContext(t *testing.T) {
153162
state.ID = "gpu"
154163

155164
data := resource.Data(state)
165+
//spew.Dump(data)
156166
result := resource.ReadContext(ctx, data, provider)
157167
r.Nil(result)
158168
r.False(result.HasError())
@@ -251,6 +261,12 @@ name = gpu
251261
rebalancing_config_min_nodes = 0
252262
should_taint = true
253263
Tainted = false
264+
gpu.# = 1
265+
gpu.0.default_shared_clients_per_gpu = 10
266+
gpu.0.enable_time_sharing = true
267+
gpu.0.sharing_configuration.# = 1
268+
gpu.0.sharing_configuration.0.gpu_name = A100
269+
gpu.0.sharing_configuration.0.shared_clients_per_gpu = 5
254270
`, "\n"),
255271
strings.Split(data.State().String(), "\n"),
256272
)

castai/sdk/api.gen.go

Lines changed: 42 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/resources/node_template.md

Lines changed: 20 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)