Skip to content

Commit b83d926

Browse files
authored
feat: add GPU sharing settings to Node Template (#496)
1 parent 253351b commit b83d926

File tree

4 files changed

+204
-1
lines changed

4 files changed

+204
-1
lines changed

castai/resource_node_template.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ const (
7979
FieldNodeTemplateCPULimitEnabled = "cpu_limit_enabled"
8080
FieldNodeTemplateCPULimitMaxCores = "cpu_limit_max_cores"
8181
FieldNodeTemplateBareMetal = "bare_metal"
82+
FieldNodeTemplateEnableTimeSharing = "enable_time_sharing"
83+
FieldNodeTemplateDefaultSharedClientsPerGpu = "default_shared_clients_per_gpu"
84+
FieldNodeTemplateSharingConfiguration = "sharing_configuration"
85+
FieldNodeTemplateSharedClientsPerGpu = "shared_clients_per_gpu"
86+
FieldNodeTemplateSharedGpuName = "gpu_name"
8287
)
8388

8489
const (
@@ -645,6 +650,47 @@ func resourceNodeTemplate() *schema.Resource {
645650
Description: "Marks whether custom instances with extended memory should be used when deciding which parts of inventory are available. " +
646651
"Custom instances are only supported in GCP.",
647652
},
653+
FieldNodeTemplateGpu: {
654+
Type: schema.TypeList,
655+
MaxItems: 1,
656+
Optional: true,
657+
Description: "GPU configuration.",
658+
Elem: &schema.Resource{
659+
Schema: map[string]*schema.Schema{
660+
FieldNodeTemplateEnableTimeSharing: {
661+
Type: schema.TypeBool,
662+
Optional: true,
663+
Default: false,
664+
Description: "Enable/disable GPU time-sharing.",
665+
},
666+
FieldNodeTemplateDefaultSharedClientsPerGpu: {
667+
Type: schema.TypeInt,
668+
Optional: true,
669+
Default: 1,
670+
Description: "Defines default number of shared clients per GPU.",
671+
},
672+
FieldNodeTemplateSharingConfiguration: {
673+
Type: schema.TypeList,
674+
Optional: true,
675+
Description: "Defines GPU sharing configurations for GPU devices.",
676+
Elem: &schema.Resource{
677+
Schema: map[string]*schema.Schema{
678+
FieldNodeTemplateSharedGpuName: {
679+
Type: schema.TypeString,
680+
Required: true,
681+
Description: "GPU name.",
682+
},
683+
FieldNodeTemplateSharedClientsPerGpu: {
684+
Type: schema.TypeInt,
685+
Required: true,
686+
Description: "Defines number of shared clients for specific GPU device.",
687+
},
688+
},
689+
},
690+
},
691+
},
692+
},
693+
},
648694
},
649695
}
650696
}
@@ -719,9 +765,52 @@ func resourceNodeTemplateRead(ctx context.Context, d *schema.ResourceData, meta
719765
return diag.FromErr(fmt.Errorf("setting custom instances with extended memory enabled: %w", err))
720766
}
721767

768+
if nodeTemplate.Gpu != nil {
769+
gpu, err := flattenGpuSettings(nodeTemplate.Gpu)
770+
if err != nil {
771+
return diag.FromErr(fmt.Errorf("flattening gpu settings: %w", err))
772+
}
773+
774+
if err := d.Set(FieldNodeTemplateGpu, gpu); err != nil {
775+
return diag.FromErr(fmt.Errorf("setting gpu settings: %w", err))
776+
}
777+
}
722778
return nil
723779
}
724780

781+
func flattenGpuSettings(g *sdk.NodetemplatesV1GPU) ([]map[string]any, error) {
782+
if g == nil {
783+
return nil, nil
784+
}
785+
786+
out := make(map[string]any)
787+
788+
if g.EnableTimeSharing != nil {
789+
out[FieldNodeTemplateEnableTimeSharing] = g.EnableTimeSharing
790+
}
791+
792+
if g.DefaultSharedClientsPerGpu != nil {
793+
out[FieldNodeTemplateDefaultSharedClientsPerGpu] = g.DefaultSharedClientsPerGpu
794+
}
795+
796+
if g.SharingConfiguration != nil {
797+
sharingConfigurations := make([]map[string]any, 0)
798+
for gpuName, sc := range *g.SharingConfiguration {
799+
if sc.SharedClientsPerGpu != nil {
800+
sharingConfig := make(map[string]any)
801+
sharingConfig[FieldNodeTemplateSharedClientsPerGpu] = *sc.SharedClientsPerGpu
802+
sharingConfig[FieldNodeTemplateSharedGpuName] = gpuName
803+
804+
sharingConfigurations = append(sharingConfigurations, sharingConfig)
805+
}
806+
}
807+
if len(sharingConfigurations) > 0 {
808+
out[FieldNodeTemplateSharingConfiguration] = sharingConfigurations
809+
}
810+
}
811+
return []map[string]any{out}, nil
812+
}
813+
725814
func flattenConstraints(c *sdk.NodetemplatesV1TemplateConstraints) ([]map[string]any, error) {
726815
if c == nil {
727816
return nil, nil
@@ -981,6 +1070,12 @@ func updateNodeTemplate(ctx context.Context, d *schema.ResourceData, meta any, s
9811070
FieldNodeTemplateCustomInstancesWithExtendedMemoryEnabled,
9821071
FieldNodeTemplateConstraints,
9831072
FieldNodeTemplateIsEnabled,
1073+
FieldNodeTemplateGpu,
1074+
FieldNodeTemplateDefaultSharedClientsPerGpu,
1075+
FieldNodeTemplateEnableTimeSharing,
1076+
FieldNodeTemplateSharingConfiguration,
1077+
FieldNodeTemplateSharedGpuName,
1078+
FieldNodeTemplateSharedClientsPerGpu,
9841079
) {
9851080
log.Printf("[INFO] Nothing to update in node template")
9861081
return nil
@@ -1050,6 +1145,10 @@ func updateNodeTemplate(ctx context.Context, d *schema.ResourceData, meta any, s
10501145
req.CustomInstancesWithExtendedMemoryEnabled = lo.ToPtr(v.(bool))
10511146
}
10521147

1148+
if v, ok := d.Get(FieldNodeTemplateGpu).([]any); ok && len(v) > 0 {
1149+
req.Gpu = toTemplateGpu(v[0].(map[string]any))
1150+
}
1151+
10531152
resp, err := client.NodeTemplatesAPIUpdateNodeTemplateWithResponse(ctx, clusterID, name, req)
10541153
if checkErr := sdk.CheckOKResponse(resp, err); checkErr != nil {
10551154
return diag.FromErr(checkErr)
@@ -1123,6 +1222,10 @@ func resourceNodeTemplateCreate(ctx context.Context, d *schema.ResourceData, met
11231222
req.CustomInstancesEnabled = lo.ToPtr(v.(bool))
11241223
}
11251224

1225+
if v, ok := d.Get(FieldNodeTemplateGpu).([]any); ok && len(v) > 0 {
1226+
req.Gpu = toTemplateGpu(v[0].(map[string]any))
1227+
}
1228+
11261229
resp, err := client.NodeTemplatesAPICreateNodeTemplateWithResponse(ctx, clusterID, req)
11271230
if checkErr := sdk.CheckOKResponse(resp, err); checkErr != nil {
11281231
return diag.FromErr(checkErr)
@@ -1455,6 +1558,40 @@ func toTemplateConstraints(obj map[string]any) *sdk.NodetemplatesV1TemplateConst
14551558
return out
14561559
}
14571560

1561+
func toTemplateGpu(obj map[string]any) *sdk.NodetemplatesV1GPU {
1562+
if obj == nil {
1563+
return nil
1564+
}
1565+
1566+
out := &sdk.NodetemplatesV1GPU{}
1567+
if v, ok := obj[FieldNodeTemplateDefaultSharedClientsPerGpu].(int); ok {
1568+
out.DefaultSharedClientsPerGpu = toPtr(int32(v))
1569+
}
1570+
1571+
if v, ok := obj[FieldNodeTemplateEnableTimeSharing].(bool); ok {
1572+
out.EnableTimeSharing = toPtr(v)
1573+
}
1574+
1575+
if sharingConfiguration, ok := obj[FieldNodeTemplateSharingConfiguration].([]interface{}); ok {
1576+
outSharingConfiguration := make(map[string]sdk.NodetemplatesV1SharedGPU)
1577+
for _, configuration := range sharingConfiguration {
1578+
1579+
sharedGPUConfig := configuration.(map[string]interface{})
1580+
gpuName, gpuNameOk := sharedGPUConfig[FieldNodeTemplateSharedGpuName].(string)
1581+
sharedClientsPerGpu, sharedClientsPerGpuOk := sharedGPUConfig[FieldNodeTemplateSharedClientsPerGpu].(int)
1582+
if gpuNameOk && sharedClientsPerGpuOk {
1583+
outSharingConfiguration[gpuName] = sdk.NodetemplatesV1SharedGPU{
1584+
SharedClientsPerGpu: toPtr(int32(sharedClientsPerGpu)),
1585+
}
1586+
}
1587+
}
1588+
if len(outSharingConfiguration) > 0 {
1589+
out.SharingConfiguration = &outSharingConfiguration
1590+
}
1591+
}
1592+
return out
1593+
}
1594+
14581595
func toTemplateConstraintsInstanceFamilies(o map[string]any) *sdk.NodetemplatesV1TemplateConstraintsInstanceFamilyConstraints {
14591596
if o == nil {
14601597
return nil

castai/resource_node_template_test.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ func TestNodeTemplateResourceReadContext(t *testing.T) {
4646
"configurationId": "7dc4f922-29c9-4377-889c-0c8c5fb8d497",
4747
"configurationName": "default",
4848
"isEnabled": true,
49+
"gpu": {
50+
"enableTimeSharing": true,
51+
"defaultSharedClientsPerGpu": 10,
52+
"sharingConfiguration": {
53+
"A100": {
54+
"sharedClientsPerGpu": 5
55+
}
56+
}
57+
},
4958
"name": "gpu",
5059
"constraints": {
5160
"spot": false,
@@ -153,6 +162,7 @@ func TestNodeTemplateResourceReadContext(t *testing.T) {
153162
state.ID = "gpu"
154163

155164
data := resource.Data(state)
165+
//spew.Dump(data)
156166
result := resource.ReadContext(ctx, data, provider)
157167
r.Nil(result)
158168
r.False(result.HasError())
@@ -251,6 +261,12 @@ name = gpu
251261
rebalancing_config_min_nodes = 0
252262
should_taint = true
253263
Tainted = false
264+
gpu.# = 1
265+
gpu.0.default_shared_clients_per_gpu = 10
266+
gpu.0.enable_time_sharing = true
267+
gpu.0.sharing_configuration.# = 1
268+
gpu.0.sharing_configuration.0.gpu_name = A100
269+
gpu.0.sharing_configuration.0.shared_clients_per_gpu = 5
254270
`, "\n"),
255271
strings.Split(data.State().String(), "\n"),
256272
)
@@ -666,6 +682,8 @@ func TestAccResourceNodeTemplate_basic(t *testing.T) {
666682
resource.TestCheckResourceAttr(resourceName, "constraints.0.resource_limits.#", "1"),
667683
resource.TestCheckResourceAttr(resourceName, "constraints.0.resource_limits.0.cpu_limit_enabled", "false"),
668684
resource.TestCheckResourceAttr(resourceName, "constraints.0.resource_limits.0.cpu_limit_max_cores", "0"),
685+
resource.TestCheckResourceAttr(resourceName, "gpu.0.default_shared_clients_per_gpu", "1"),
686+
resource.TestCheckResourceAttr(resourceName, "gpu.0.enable_time_sharing", "false"),
669687
),
670688
},
671689
{
@@ -742,7 +760,8 @@ func TestAccResourceNodeTemplate_basic(t *testing.T) {
742760
resource.TestCheckResourceAttr(resourceName, "constraints.0.resource_limits.#", "1"),
743761
resource.TestCheckResourceAttr(resourceName, "constraints.0.resource_limits.0.cpu_limit_enabled", "true"),
744762
resource.TestCheckResourceAttr(resourceName, "constraints.0.resource_limits.0.cpu_limit_max_cores", "50"),
745-
resource.TestCheckResourceAttr(resourceName, "constraints.0.bare_metal", "false"),
763+
resource.TestCheckResourceAttr(resourceName, "gpu.0.default_shared_clients_per_gpu", "1"),
764+
resource.TestCheckResourceAttr(resourceName, "gpu.0.enable_time_sharing", "false"),
746765
),
747766
},
748767
},
@@ -789,6 +808,16 @@ func testAccNodeTemplateConfig(rName, clusterName string) string {
789808
key = "%[1]s-taint-key-4"
790809
}
791810
811+
gpu {
812+
default_shared_clients_per_gpu = 1
813+
enable_time_sharing = false
814+
815+
sharing_configuration {
816+
gpu_name = "L4"
817+
shared_clients_per_gpu = 8
818+
}
819+
}
820+
792821
constraints {
793822
fallback_restore_rate_seconds = 1800
794823
spot = true
@@ -854,6 +883,11 @@ func testNodeTemplateUpdated(rName, clusterName string) string {
854883
effect = "NoSchedule"
855884
}
856885
886+
gpu {
887+
default_shared_clients_per_gpu = 1
888+
enable_time_sharing = false
889+
}
890+
857891
constraints {
858892
use_spot_fallbacks = true
859893
spot = true

castai/sdk/api.gen.go

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/resources/node_template.md

Lines changed: 20 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)