Skip to content

Commit 1c79809

Browse files
authored
Refactor user data bootstrapping logic for Flex (#26)
* refactor: refine flex user data bootstrapping logic * feat: bump to new version
1 parent 7dc9d6a commit 1c79809

6 files changed

Lines changed: 251 additions & 18 deletions

File tree

cli/internal/config/nodebootstrap/nodebootstrap.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,20 @@ const (
2121

2222
var r = configcmd.NewRouter("node-bootstrap", "Generate a node bootstrap config for a remote cloud")
2323
var Command *cobra.Command = r.Command()
24-
var flagHasGPU bool
24+
var flagEnableNvidiaGPURuntime bool
2525
var flagVariant string
26+
var flagArch string
27+
var flagKubeVersion string
2628

2729
func init() {
2830
r.Handle("ubuntu", writeUbuntuUserData)
2931
r.Handle("flex", writeFlexUserData)
3032

31-
Command.Flags().BoolVar(&flagHasGPU, "gpu", false, "Indicates whether the node has GPU. This may affect the generated userdata.")
33+
Command.Flags().BoolVar(&flagEnableNvidiaGPURuntime, "nvidia-gpu", false, "Enable Nvidia GPU runtime in containerd configuration.")
34+
Command.Flags().StringVar(&flagArch, "arch", "amd64",
35+
"CPU architecture for the flex node binary (e.g. amd64, arm64).")
36+
Command.Flags().StringVar(&flagKubeVersion, "k8s-version", "1.33.3",
37+
"Kubernetes version for the downloaded binaries (e.g. 1.33.3).")
3238
Command.Flags().StringVar(&flagVariant, "variant", variantCloudInit,
3339
fmt.Sprintf("Output variant: %q produces cloud-init YAML user data, %q produces an equivalent standalone bash script.", variantCloudInit, variantScript))
3440
}
@@ -56,7 +62,12 @@ func marshalUserData(ud *cloudinit.UserData, w io.Writer) error {
5662
}
5763

5864
func writeFlexUserData(ctx context.Context, w io.Writer) error {
59-
ud, err := flex.UserData(flagHasGPU, "1.33.3", configcmd.DefaultKubeadmConfig(ctx))
65+
ud, err := flex.UserData(
66+
flex.WithEnableNvidiaGPURuntime(flagEnableNvidiaGPURuntime),
67+
flex.WithArch(flagArch),
68+
flex.WithKubeVersion(flagKubeVersion),
69+
flex.WithKubeadmConfig(configcmd.DefaultKubeadmConfig(ctx)),
70+
)
6071
if err != nil {
6172
return fmt.Errorf("generating flex userdata: %w", err)
6273
}

plugin/pkg/services/agentpools/azure/ubuntu2404vmss/agentpools.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ func (srv *agentpoolsServer) CreateOrUpdate(ctx context.Context, req *api.Create
7575
// if err != nil {
7676
// return nil, err
7777
// }
78-
userData, err := flex.UserData(false, "1.33.3", kubeadmConfig)
78+
userData, err := flex.UserData(
79+
flex.WithKubeadmConfig(kubeadmConfig),
80+
)
7981
if err != nil {
8082
return nil, err
8183
}

plugin/pkg/services/agentpools/nebius/instance/agentpools.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ func (srv *agentPoolsServer) CreateOrUpdate(
7373
// TODO: get gpu info from spec (might need to infer from SKU)
7474
hasGPU := strings.Contains(apSpec.GetImageFamily(), "cuda")
7575
// TODO: get the k8s version from spec
76-
ud, err := flex.UserData(hasGPU, "1.33.3", kubeadmConfig)
76+
ud, err := flex.UserData(
77+
flex.WithEnableNvidiaGPURuntime(hasGPU),
78+
flex.WithKubeadmConfig(kubeadmConfig),
79+
)
7780
if err != nil {
7881
return nil, fmt.Errorf("failed to generate userdata: %w", err)
7982
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
mkdir -p /tmp/flex
2+
curl -L -o /tmp/flex/aks-flex-node-linux-{{ .Arch }}.tar.gz https://github.com/Azure/AKSFlexNode/releases/download/{{ .Version }}/aks-flex-node-linux-{{ .Arch }}.tar.gz
3+
tar -xzf /tmp/flex/aks-flex-node-linux-{{ .Arch }}.tar.gz -C /tmp/flex
4+
mv /tmp/flex/aks-flex-node-linux-{{ .Arch }} /tmp/flex/aks-flex-node
5+
chmod +x /tmp/flex/aks-flex-node
6+
/tmp/flex/aks-flex-node apply -f /tmp/flex-config.json
7+
rm -rf /tmp/flex

plugin/pkg/services/agentpools/userdata/flex/flex.go

Lines changed: 110 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package flex
22

33
import (
4+
"bytes"
5+
_ "embed"
46
"encoding/json"
7+
"fmt"
58
"maps"
69
"strings"
10+
"text/template"
711

812
"github.com/Azure/AKSFlexNode/components/api"
913
"github.com/Azure/AKSFlexNode/components/cri"
@@ -17,6 +21,82 @@ import (
1721
"github.com/Azure/aks-flex/plugin/pkg/util/cloudinit"
1822
)
1923

24+
//go:embed assets/bootstrap.sh.tmpl
25+
var bootstrapTmpl string
26+
27+
var bootstrapTemplate = template.Must(template.New("bootstrap.sh").Parse(bootstrapTmpl))
28+
29+
const (
30+
flexNodeVersion = "v0.0.13"
31+
defaultArch = "amd64"
32+
defaultKubeVer = "1.33.3"
33+
)
34+
35+
// Options configures how the flex node userdata is generated.
36+
type Options struct {
37+
EnableNvidiaGPURuntime bool
38+
KubeVersion string
39+
Arch string
40+
KubeadmConfig *kubeadmapi.Config
41+
}
42+
43+
// Option is a functional option for [UserData].
44+
type Option func(*Options)
45+
46+
// WithEnableNvidiaGPURuntime configures the containerd Nvidia GPU runtime.
47+
func WithEnableNvidiaGPURuntime(enable bool) Option {
48+
return func(o *Options) { o.EnableNvidiaGPURuntime = enable }
49+
}
50+
51+
// WithKubeVersion sets the Kubernetes version for the downloaded binaries.
52+
func WithKubeVersion(v string) Option {
53+
return func(o *Options) { o.KubeVersion = v }
54+
}
55+
56+
// WithArch sets the CPU architecture for the flex node binary (e.g. "amd64", "arm64").
57+
func WithArch(arch string) Option {
58+
return func(o *Options) { o.Arch = arch }
59+
}
60+
61+
// WithKubeadmConfig sets the kubeadm join configuration.
62+
func WithKubeadmConfig(cfg *kubeadmapi.Config) Option {
63+
return func(o *Options) { o.KubeadmConfig = cfg }
64+
}
65+
66+
func defaultOptions() *Options {
67+
return &Options{
68+
KubeVersion: defaultKubeVer,
69+
Arch: defaultArch,
70+
}
71+
}
72+
73+
// supportedArchs is the set of CPU architectures for which flex node binaries
74+
// are published.
75+
var supportedArchs = map[string]bool{
76+
"amd64": true,
77+
"arm64": true,
78+
}
79+
80+
// validate performs least-effort validation on the options. This is intentionally
81+
// minimal to catch obvious mistakes for ad-hoc values; callers should perform
82+
// more thorough validation beforehand.
83+
func (o *Options) validate() error {
84+
if !supportedArchs[o.Arch] {
85+
return fmt.Errorf("unsupported arch %q, supported: amd64, arm64", o.Arch)
86+
}
87+
o.KubeVersion = strings.TrimPrefix(o.KubeVersion, "v")
88+
if o.KubeVersion == "" {
89+
return fmt.Errorf("kube version must not be empty")
90+
}
91+
return nil
92+
}
93+
94+
// bootstrapParams holds the template parameters for the bootstrap script.
95+
type bootstrapParams struct {
96+
Arch string
97+
Version string
98+
}
99+
20100
func flexMetadata[T proto.Message](name string) *api.Metadata {
21101
var zero T
22102
typeName := string(zero.ProtoReflect().Descriptor().FullName())
@@ -27,12 +107,12 @@ func flexMetadata[T proto.Message](name string) *api.Metadata {
27107
}
28108

29109
func resolveFlexComponentConfigs(
30-
hasGPU bool,
110+
enableNvidiaGPURuntime bool,
31111
kubeVersion string,
32112
kubeadmConfig *kubeadmapi.Config,
33113
) ([]byte, error) {
34114
startCRISpecBuilder := cri.StartContainerdServiceSpec_builder{}
35-
if hasGPU {
115+
if enableNvidiaGPURuntime {
36116
startCRISpecBuilder.GpuConfig = cri.GPUConfig_builder{
37117
NvidiaRuntime: cri.NvidiaRuntime_builder{}.Build(),
38118
}.Build()
@@ -104,8 +184,33 @@ func resolveFlexComponentConfigs(
104184
return b, nil
105185
}
106186

107-
func UserData(hasGPU bool, kubeVersion string, kubeadmConfig *kubeadmapi.Config) (*cloudinit.UserData, error) {
108-
componentConfigsJSON, err := resolveFlexComponentConfigs(hasGPU, kubeVersion, kubeadmConfig)
187+
func renderBootstrapScript(arch string) (string, error) {
188+
var buf bytes.Buffer
189+
if err := bootstrapTemplate.Execute(&buf, bootstrapParams{
190+
Arch: arch,
191+
Version: flexNodeVersion,
192+
}); err != nil {
193+
return "", fmt.Errorf("rendering bootstrap script: %w", err)
194+
}
195+
return buf.String(), nil
196+
}
197+
198+
func UserData(opts ...Option) (*cloudinit.UserData, error) {
199+
o := defaultOptions()
200+
for _, opt := range opts {
201+
opt(o)
202+
}
203+
204+
if err := o.validate(); err != nil {
205+
return nil, err
206+
}
207+
208+
componentConfigsJSON, err := resolveFlexComponentConfigs(o.EnableNvidiaGPURuntime, o.KubeVersion, o.KubeadmConfig)
209+
if err != nil {
210+
return nil, err
211+
}
212+
213+
bootstrapScript, err := renderBootstrapScript(o.Arch)
109214
if err != nil {
110215
return nil, err
111216
}
@@ -124,15 +229,7 @@ func UserData(hasGPU bool, kubeVersion string, kubeadmConfig *kubeadmapi.Config)
124229
},
125230
RunCmd: []any{
126231
[]string{"set", "-e"},
127-
strings.Join([]string{
128-
"mkdir -p /tmp/flex",
129-
// TODO: this should be overridable
130-
"curl -L -o /tmp/flex/aks-flex-node-linux-amd64.tar.gz https://github.com/Azure/AKSFlexNode/releases/download/v0.0.12/aks-flex-node-linux-amd64.tar.gz",
131-
"tar -xzf /tmp/flex/aks-flex-node-linux-amd64.tar.gz -C /tmp/flex",
132-
"mv /tmp/flex/aks-flex-node-linux-amd64 /tmp/flex/aks-flex-node",
133-
"chmod +x /tmp/flex/aks-flex-node",
134-
"/tmp/flex/aks-flex-node apply -f /tmp/flex-config.json",
135-
}, "\n"),
232+
bootstrapScript,
136233
},
137234
}
138235

plugin/pkg/services/agentpools/userdata/flex/flex_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package flex
22

33
import (
44
"encoding/json"
5+
"strings"
56
"testing"
67

78
kubeadm "github.com/Azure/aks-flex/plugin/pkg/services/agentpools/api/features/kubeadm"
@@ -23,3 +24,115 @@ func Test_resolveFlexComponentConfigs_basic(t *testing.T) {
2324
t.Fatalf("failed to unmarshal generated config: %v", err)
2425
}
2526
}
27+
28+
func TestUserData_defaults(t *testing.T) {
29+
kubeadmSpec := kubeadm.Config_builder{}.Build()
30+
31+
ud, err := UserData(WithKubeadmConfig(kubeadmSpec))
32+
if err != nil {
33+
t.Fatalf("unexpected error: %v", err)
34+
}
35+
36+
b, err := ud.Marshal()
37+
if err != nil {
38+
t.Fatalf("failed to marshal userdata: %v", err)
39+
}
40+
41+
content := string(b)
42+
// defaults should produce amd64 binary URL
43+
if !strings.Contains(content, "amd64") {
44+
t.Error("expected default arch amd64 in userdata")
45+
}
46+
}
47+
48+
func TestUserData_arm64(t *testing.T) {
49+
kubeadmSpec := kubeadm.Config_builder{}.Build()
50+
51+
ud, err := UserData(
52+
WithArch("arm64"),
53+
WithKubeadmConfig(kubeadmSpec),
54+
)
55+
if err != nil {
56+
t.Fatalf("unexpected error: %v", err)
57+
}
58+
59+
b, err := ud.Marshal()
60+
if err != nil {
61+
t.Fatalf("failed to marshal userdata: %v", err)
62+
}
63+
64+
content := string(b)
65+
if !strings.Contains(content, "arm64") {
66+
t.Error("expected arm64 in userdata")
67+
}
68+
if strings.Contains(content, "amd64") {
69+
t.Error("unexpected amd64 in userdata when arm64 was specified")
70+
}
71+
}
72+
73+
func TestUserData_invalidArch(t *testing.T) {
74+
kubeadmSpec := kubeadm.Config_builder{}.Build()
75+
76+
_, err := UserData(
77+
WithArch("mips64"),
78+
WithKubeadmConfig(kubeadmSpec),
79+
)
80+
if err == nil {
81+
t.Fatal("expected error for unsupported arch")
82+
}
83+
if !strings.Contains(err.Error(), "unsupported arch") {
84+
t.Errorf("unexpected error message: %v", err)
85+
}
86+
}
87+
88+
func TestUserData_invalidKubeVersion(t *testing.T) {
89+
kubeadmSpec := kubeadm.Config_builder{}.Build()
90+
91+
_, err := UserData(
92+
WithKubeVersion(""),
93+
WithKubeadmConfig(kubeadmSpec),
94+
)
95+
if err == nil {
96+
t.Fatal("expected error for empty kube version")
97+
}
98+
if !strings.Contains(err.Error(), "must not be empty") {
99+
t.Errorf("unexpected error message: %v", err)
100+
}
101+
}
102+
103+
func TestUserData_trimsLeadingV(t *testing.T) {
104+
kubeadmSpec := kubeadm.Config_builder{}.Build()
105+
106+
ud, err := UserData(
107+
WithKubeVersion("v1.33.3"),
108+
WithKubeadmConfig(kubeadmSpec),
109+
)
110+
if err != nil {
111+
t.Fatalf("unexpected error: %v", err)
112+
}
113+
114+
b, err := ud.Marshal()
115+
if err != nil {
116+
t.Fatalf("failed to marshal userdata: %v", err)
117+
}
118+
119+
content := string(b)
120+
if strings.Contains(content, "v1.33.3") {
121+
t.Error("expected leading 'v' to be trimmed from kube version")
122+
}
123+
if !strings.Contains(content, "1.33.3") {
124+
t.Error("expected kube version 1.33.3 in userdata")
125+
}
126+
}
127+
128+
func TestUserData_preReleaseKubeVersion(t *testing.T) {
129+
kubeadmSpec := kubeadm.Config_builder{}.Build()
130+
131+
_, err := UserData(
132+
WithKubeVersion("1.33.0-rc.1"),
133+
WithKubeadmConfig(kubeadmSpec),
134+
)
135+
if err != nil {
136+
t.Fatalf("expected pre-release kube version to be valid, got: %v", err)
137+
}
138+
}

0 commit comments

Comments
 (0)