-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathhandler.go
More file actions
342 lines (300 loc) · 11.9 KB
/
handler.go
File metadata and controls
342 lines (300 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
// Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bundler
import (
"archive/zip"
"context"
"encoding/json"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strconv"
corev1 "k8s.io/api/core/v1"
"github.com/NVIDIA/aicr/pkg/bundler/config"
"github.com/NVIDIA/aicr/pkg/bundler/result"
"github.com/NVIDIA/aicr/pkg/defaults"
aicrerrors "github.com/NVIDIA/aicr/pkg/errors"
"github.com/NVIDIA/aicr/pkg/recipe"
"github.com/NVIDIA/aicr/pkg/server"
"github.com/NVIDIA/aicr/pkg/snapshotter"
)
// DefaultBundleTimeout is the deadline applied to each bundle-generation
// request handled by HandleBundles. It simply re-exports
// defaults.BundleHandlerTimeout and is kept only for backwards compatibility;
// new code should reference defaults.BundleHandlerTimeout directly.
const (
	DefaultBundleTimeout = defaults.BundleHandlerTimeout
)
// HandleBundles processes bundle generation requests.
// It accepts a POST request with a JSON body containing the recipe (RecipeResult).
// The body is limited to 10 MiB; larger bodies are rejected with
// "Invalid request body" before any bundling work starts.
//
// Supports query parameters:
//   - set: Value overrides in format "bundler:path.to.field=value" (can be repeated)
//   - system-node-selector: Node selectors for system components in format "key=value" (can be repeated)
//   - system-node-toleration: Tolerations for system components in format "key=value:effect" (can be repeated)
//   - accelerated-node-selector: Node selectors for GPU nodes in format "key=value" (can be repeated)
//   - accelerated-node-toleration: Tolerations for GPU nodes in format "key=value:effect" (can be repeated)
//   - workload-gate: Taint for skyhook-operator runtime required in format "key=value:effect" or "key:effect"
//   - workload-selector: Label selector for skyhook-customizations in format "key=value" (can be repeated)
//   - nodes: Estimated number of GPU nodes (sets estimatedNodeCount in skyhook-operator; 0 = unset)
//   - deployer: Deployer type, e.g. helm or argocd (defaults to helm when omitted)
//   - repo: Repository URL passed to the bundler config (used by the ArgoCD deployer)
//
// The response is a zip archive containing the Helm per-component bundle:
//   - README.md: Root deployment guide
//   - deploy.sh: Automation script
//   - recipe.yaml: Copy of the input recipe
//   - <component>/values.yaml: Helm values per component
//   - <component>/README.md: Component install/upgrade/uninstall
//   - checksums.txt: SHA256 checksums of generated files
//
// Example:
//
//	POST /v1/bundle?set=gpuoperator:gds.enabled=true
//	Content-Type: application/json
//	Body: { "apiVersion": "aicr.nvidia.com/v1alpha1", "kind": "Recipe", ... }
func (b *DefaultBundler) HandleBundles(w http.ResponseWriter, r *http.Request) {
	// maxBodyBytes caps the client-supplied JSON recipe. Without a cap a single
	// request could make the server read an arbitrarily large payload.
	const maxBodyBytes = 10 << 20 // 10 MiB

	if r.Method != http.MethodPost {
		w.Header().Set("Allow", http.MethodPost)
		server.WriteError(w, r, http.StatusMethodNotAllowed, aicrerrors.ErrCodeMethodNotAllowed,
			"Method not allowed", false, map[string]any{
				"method": r.Method,
			})
		return
	}

	// Add request-scoped timeout
	ctx, cancel := context.WithTimeout(r.Context(), DefaultBundleTimeout)
	defer cancel()

	// Parse all query parameters
	params, err := parseQueryParams(r)
	if err != nil {
		server.WriteErrorFromErr(w, r, err, "Invalid query parameters", nil)
		return
	}

	// Parse request body directly as RecipeResult. MaxBytesReader makes Decode
	// fail once more than maxBodyBytes have been read.
	var recipeResult recipe.RecipeResult
	body := http.MaxBytesReader(w, r.Body, maxBodyBytes)
	err = json.NewDecoder(body).Decode(&recipeResult)
	if err != nil {
		server.WriteError(w, r, http.StatusBadRequest, aicrerrors.ErrCodeInvalidRequest,
			"Invalid request body", false, map[string]any{
				"error": err.Error(),
			})
		return
	}

	// Validate recipe has component references
	if len(recipeResult.ComponentRefs) == 0 {
		server.WriteError(w, r, http.StatusBadRequest, aicrerrors.ErrCodeInvalidRequest,
			"Recipe must contain at least one component reference", false, nil)
		return
	}

	// Validate recipe criteria against allowlists (if configured)
	if b.AllowLists != nil && recipeResult.Criteria != nil {
		if validateErr := b.AllowLists.ValidateCriteria(recipeResult.Criteria); validateErr != nil {
			server.WriteErrorFromErr(w, r, validateErr, "Recipe criteria value not allowed", nil)
			return
		}
	}

	slog.Debug("bundle request received",
		"components", len(recipeResult.ComponentRefs),
		"value_overrides", len(params.valueOverrides),
		"system_node_selectors", len(params.systemNodeSelector),
		"accelerated_node_selectors", len(params.acceleratedNodeSelector),
	)

	// Create temporary directory for bundle output
	tempDir, err := os.MkdirTemp("", "aicr-bundle-*")
	if err != nil {
		server.WriteError(w, r, http.StatusInternalServerError, aicrerrors.ErrCodeInternal,
			"Failed to create temporary directory", true, nil)
		return
	}
	defer os.RemoveAll(tempDir) // Clean up on exit

	// Create a new bundler with configuration
	bundler, err := New(
		WithConfig(config.NewConfig(
			config.WithValueOverrides(params.valueOverrides),
			config.WithSystemNodeSelector(params.systemNodeSelector),
			config.WithSystemNodeTolerations(params.systemNodeTolerations),
			config.WithAcceleratedNodeSelector(params.acceleratedNodeSelector),
			config.WithAcceleratedNodeTolerations(params.acceleratedNodeTolerations),
			config.WithWorkloadGateTaint(params.workloadGateTaint),
			config.WithWorkloadSelector(params.workloadSelector),
			config.WithEstimatedNodeCount(params.estimatedNodeCount),
			config.WithDeployer(params.deployer),
			config.WithRepoURL(params.repoURL),
		)),
	)
	if err != nil {
		server.WriteError(w, r, http.StatusInternalServerError, aicrerrors.ErrCodeInternal,
			"Failed to create bundler", true, map[string]any{
				"error": err.Error(),
			})
		return
	}

	// Generate bundle
	output, err := bundler.Make(ctx, &recipeResult, tempDir)
	if err != nil {
		server.WriteErrorFromErr(w, r, err, "Failed to generate bundle", nil)
		return
	}

	// Check for bundle errors
	if output.HasErrors() {
		errorDetails := make([]map[string]any, 0, len(output.Errors))
		for _, be := range output.Errors {
			errorDetails = append(errorDetails, map[string]any{
				"bundler": be.BundlerType,
				"error":   be.Error,
			})
		}
		server.WriteError(w, r, http.StatusInternalServerError, aicrerrors.ErrCodeInternal,
			"Bundle generation failed", true, map[string]any{
				"errors": errorDetails,
			})
		return
	}

	// Stream zip response
	if err := streamZipResponse(w, tempDir, output); err != nil {
		// The response status and part of the body may already be on the wire,
		// so an error response can no longer be written; log it instead.
		slog.Error("failed to stream zip response", "error", err)
		return
	}
}
// streamZipResponse creates a zip archive from the output directory and streams it to the response.
//
// Response headers (content type, attachment filename, and the X-Bundle-*
// summary taken from output) are set before any body bytes are written.
// Every entry below dir is added using its path relative to dir. Entry names
// always use forward slashes, as the zip format requires. Directories become
// "name/" entries, and regular files are Deflate-compressed.
//
// Because the body is written directly to w, the HTTP status is committed once
// streaming starts, so the caller can only log a returned error. On a walk or
// copy failure the archive is deliberately left unfinalized (no central
// directory). The client then receives an unreadable zip rather than a
// well-formed but incomplete bundle.
func streamZipResponse(w http.ResponseWriter, dir string, output *result.Output) error {
	// Set response headers before writing body
	w.Header().Set("Content-Type", "application/zip")
	w.Header().Set("Content-Disposition", "attachment; filename=\"bundles.zip\"")
	w.Header().Set("X-Bundle-Files", strconv.Itoa(output.TotalFiles))
	w.Header().Set("X-Bundle-Size", strconv.FormatInt(output.TotalSize, 10))
	w.Header().Set("X-Bundle-Duration", output.TotalDuration.String())

	// Create zip writer directly to response
	zw := zip.NewWriter(w)

	// Walk the directory and add all files to zip
	walkErr := filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error {
		if err != nil {
			return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "walk error", err)
		}

		// Skip the root directory itself
		if path == dir {
			return nil
		}

		// Get relative path for zip entry
		relPath, err := filepath.Rel(dir, path)
		if err != nil {
			return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to get relative path", err)
		}

		info, err := d.Info()
		if err != nil {
			return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to stat file", err)
		}

		// Create zip file header
		header, err := zip.FileInfoHeader(info)
		if err != nil {
			return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to create file header", err)
		}
		// Zip entry names must use '/' separators regardless of the host OS.
		header.Name = filepath.ToSlash(relPath)

		// Preserve directory structure
		if d.IsDir() {
			header.Name += "/"
			if _, err := zw.CreateHeader(header); err != nil {
				return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to create zip directory entry", err)
			}
			return nil
		}

		// Use deflate compression
		header.Method = zip.Deflate
		writer, err := zw.CreateHeader(header)
		if err != nil {
			return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to create zip entry", err)
		}

		// Open and copy file content; the deferred Close runs when this
		// callback returns, i.e. once per file.
		file, err := os.Open(path)
		if err != nil {
			return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to open file", err)
		}
		defer file.Close()

		if _, err := io.Copy(writer, file); err != nil {
			return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to copy file content", err)
		}
		return nil
	})
	if walkErr != nil {
		return walkErr
	}

	// Close writes the central directory. Without it the archive is unreadable,
	// so its error must be returned instead of being dropped by a deferred Close.
	if err := zw.Close(); err != nil {
		return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to finalize zip archive", err)
	}
	return nil
}
// bundleParams holds parsed query parameters for bundle generation.
// It is populated by parseQueryParams and fed into config.NewConfig by
// HandleBundles; fields for absent parameters keep their zero value unless
// noted otherwise.
type bundleParams struct {
	// valueOverrides comes from the repeatable "set" parameter
	// (config.ParseValueOverrides). NOTE(review): presumably keyed by bundler
	// name, then by field path — confirm against config.ParseValueOverrides.
	valueOverrides map[string]map[string]string
	// systemNodeSelector comes from the repeatable "system-node-selector" parameter.
	systemNodeSelector map[string]string
	// systemNodeTolerations comes from the repeatable "system-node-toleration" parameter.
	systemNodeTolerations []corev1.Toleration
	// acceleratedNodeSelector comes from the repeatable "accelerated-node-selector" parameter.
	acceleratedNodeSelector map[string]string
	// acceleratedNodeTolerations comes from the repeatable "accelerated-node-toleration" parameter.
	acceleratedNodeTolerations []corev1.Toleration
	// workloadGateTaint comes from the single "workload-gate" parameter; nil when absent.
	workloadGateTaint *corev1.Taint
	// workloadSelector comes from the repeatable "workload-selector" parameter.
	workloadSelector map[string]string
	// estimatedNodeCount comes from "nodes"; always non-negative, 0 means unset.
	estimatedNodeCount int
	// deployer comes from "deployer"; set to config.DeployerHelm when the parameter is empty.
	deployer config.DeployerType
	// repoURL comes from "repo" (used by the ArgoCD deployer); empty when absent.
	repoURL string
}
// parseQueryParams reads every bundle-related query parameter from r and
// converts it into a bundleParams value.
//
// Parameters are checked in a fixed order. The first invalid one aborts
// parsing and is returned as an ErrCodeInvalidRequest error; on success the
// returned struct is fully populated, with deployer defaulting to
// config.DeployerHelm and estimatedNodeCount to 0 (unset).
func parseQueryParams(r *http.Request) (*bundleParams, error) {
	q := r.URL.Query()
	out := &bundleParams{}

	overrides, setErr := config.ParseValueOverrides(q["set"])
	if setErr != nil {
		return nil, aicrerrors.Wrap(aicrerrors.ErrCodeInvalidRequest, "Invalid set parameter", setErr)
	}
	out.valueOverrides = overrides

	// Node selectors: system first, then accelerated (order decides which
	// error wins when several are invalid).
	selectorSpecs := []struct {
		key string
		dst *map[string]string
		msg string
	}{
		{"system-node-selector", &out.systemNodeSelector, "Invalid system-node-selector"},
		{"accelerated-node-selector", &out.acceleratedNodeSelector, "Invalid accelerated-node-selector"},
	}
	for _, spec := range selectorSpecs {
		selectors, selErr := snapshotter.ParseNodeSelectors(q[spec.key])
		if selErr != nil {
			return nil, aicrerrors.Wrap(aicrerrors.ErrCodeInvalidRequest, spec.msg, selErr)
		}
		*spec.dst = selectors
	}

	// Tolerations: system first, then accelerated.
	tolerationSpecs := []struct {
		key string
		dst *[]corev1.Toleration
		msg string
	}{
		{"system-node-toleration", &out.systemNodeTolerations, "Invalid system-node-toleration"},
		{"accelerated-node-toleration", &out.acceleratedNodeTolerations, "Invalid accelerated-node-toleration"},
	}
	for _, spec := range tolerationSpecs {
		tolerations, tolErr := snapshotter.ParseTolerations(q[spec.key])
		if tolErr != nil {
			return nil, aicrerrors.Wrap(aicrerrors.ErrCodeInvalidRequest, spec.msg, tolErr)
		}
		*spec.dst = tolerations
	}

	// Deployer type (helm, argocd); helm unless explicitly requested otherwise.
	out.deployer = config.DeployerHelm
	if name := q.Get("deployer"); name != "" {
		deployer, depErr := config.ParseDeployerType(name)
		if depErr != nil {
			return nil, aicrerrors.Wrap(aicrerrors.ErrCodeInvalidRequest, "Invalid deployer parameter", depErr)
		}
		out.deployer = deployer
	}

	// Repository URL consumed by the ArgoCD deployer; passed through verbatim.
	out.repoURL = q.Get("repo")

	if gate := q.Get("workload-gate"); gate != "" {
		taint, taintErr := snapshotter.ParseTaint(gate)
		if taintErr != nil {
			return nil, aicrerrors.Wrap(aicrerrors.ErrCodeInvalidRequest, "Invalid workload-gate parameter", taintErr)
		}
		out.workloadGateTaint = taint
	}

	workloadSelector, wsErr := snapshotter.ParseNodeSelectors(q["workload-selector"])
	if wsErr != nil {
		return nil, aicrerrors.Wrap(aicrerrors.ErrCodeInvalidRequest, "Invalid workload-selector parameter", wsErr)
	}
	out.workloadSelector = workloadSelector

	// Estimated GPU node count; an absent parameter leaves it at 0 (unset).
	nodesStr := q.Get("nodes")
	if nodesStr == "" {
		return out, nil
	}
	count, convErr := strconv.Atoi(nodesStr)
	if convErr != nil || count < 0 {
		return nil, aicrerrors.New(aicrerrors.ErrCodeInvalidRequest, "nodes must be a non-negative integer")
	}
	out.estimatedNodeCount = count
	return out, nil
}