-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathprovider.go
More file actions
475 lines (408 loc) · 15.9 KB
/
provider.go
File metadata and controls
475 lines (408 loc) · 15.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
// Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package recipe
import (
"embed"
"fmt"
"io/fs"
"log/slog"
"os"
"path/filepath"
"strings"
"sync"
aicrerrors "github.com/NVIDIA/aicr/pkg/errors"
"gopkg.in/yaml.v3"
)
// DataProvider is the read-only view of recipe data used by this package.
// Implementations may serve files from embedded data, from an external
// directory layered over it, or both. Paths are relative to the data directory.
type DataProvider interface {
	// ReadFile returns the contents of the file at path.
	ReadFile(path string) ([]byte, error)

	// WalkDir visits the tree rooted at root, invoking fn (an fs.WalkDirFunc)
	// for each entry.
	WalkDir(root string, fn fs.WalkDirFunc) error

	// Source describes, for debugging, which backing store serves path.
	Source(path string) string
}
// EmbeddedDataProvider serves recipe data from an embed.FS compiled into the
// binary; it provides the DataProvider methods ReadFile, WalkDir and Source.
type EmbeddedDataProvider struct {
	fs     embed.FS // embedded filesystem holding the recipe data
	prefix string   // directory within fs that holds the data (e.g. "data"); stripped from walked paths
}
// NewEmbeddedDataProvider returns a provider backed by efs. Paths handed to the
// provider are resolved relative to prefix inside efs ("" means the FS root).
func NewEmbeddedDataProvider(efs embed.FS, prefix string) *EmbeddedDataProvider {
	provider := new(EmbeddedDataProvider)
	provider.fs = efs
	provider.prefix = prefix
	return provider
}
// ReadFile returns the contents of path, resolved relative to the provider's
// prefix inside the embedded filesystem.
//
// io/fs (and therefore embed.FS) only accepts slash-separated paths. The
// result of filepath.Join is converted with filepath.ToSlash for that reason;
// on Windows, Join would otherwise produce backslash-separated names, which
// embed.FS rejects as invalid.
func (p *EmbeddedDataProvider) ReadFile(path string) ([]byte, error) {
	fullPath := filepath.ToSlash(filepath.Join(p.prefix, path))
	slog.Debug("reading file from embedded provider", "path", path, "fullPath", fullPath)
	return p.fs.ReadFile(fullPath)
}
// WalkDir walks the embedded filesystem starting at root (relative to the
// provider's prefix) and calls fn for each entry.
//
// Paths passed to fn have the prefix removed, so they are relative to the
// data directory. The walk root itself is reported as "" when a prefix is
// set, and as "." when it is not.
//
// io/fs paths are always slash-separated. The joined paths are therefore
// converted with filepath.ToSlash, because on Windows filepath.Join would
// yield backslashes that embed.FS rejects. The prefix is normalized the same
// way, so values such as "data/" strip correctly.
func (p *EmbeddedDataProvider) WalkDir(root string, fn fs.WalkDirFunc) error {
	// filepath.Join with a single argument cleans it and keeps "" as "".
	prefix := filepath.ToSlash(filepath.Join(p.prefix))
	fullRoot := filepath.ToSlash(filepath.Join(p.prefix, root))
	if fullRoot == "" {
		fullRoot = "." // embed.FS expects "." for root
	}
	slog.Debug("walking embedded filesystem", "root", root, "fullRoot", fullRoot)
	return fs.WalkDir(p.fs, fullRoot, func(path string, d fs.DirEntry, err error) error {
		// Strip the prefix before passing to callback.
		var relPath string
		switch {
		case prefix == "":
			relPath = path
		case path == prefix:
			// Compare the raw path, not the trimmed one, so that a file
			// literally named "<prefix>/<prefix>" is not mistaken for the root.
			relPath = ""
		default:
			relPath = strings.TrimPrefix(path, prefix+"/")
		}
		return fn(relPath, d, err)
	})
}
// Source reports sourceEmbedded regardless of the path asked about, since
// everything this provider serves is compiled into the binary.
func (p *EmbeddedDataProvider) Source(_ string) string {
	const label = sourceEmbedded
	return label
}
// LayeredDataProvider serves recipe data from an external directory layered
// over an EmbeddedDataProvider:
//   - registryFileName is merged: external components are combined with the
//     embedded ones, and external entries win on a name clash;
//   - any other file present externally replaces the embedded copy entirely.
//
// The struct contains a sync.Once, so it must not be copied after first use;
// work with the pointer returned by NewLayeredDataProvider.
type LayeredDataProvider struct {
	embedded    *EmbeddedDataProvider // fallback layer
	externalDir string                // root of the external overlay

	// Merged registry, built lazily on first access and reused afterwards.
	// A failure is cached too; the merge is not retried.
	mergedRegistryOnce sync.Once
	mergedRegistry     []byte
	mergedRegistryErr  error

	// Files found in externalDir when the provider was created; used to
	// route reads and to answer Source.
	externalFiles map[string]bool
}
// LayeredProviderConfig holds the options accepted by NewLayeredDataProvider.
type LayeredProviderConfig struct {
	// ExternalDir is the directory whose contents are overlaid on the
	// embedded recipe data.
	ExternalDir string

	// MaxFileSize caps the size, in bytes, of each file in ExternalDir.
	// Zero selects DefaultMaxFileSize (10MB).
	MaxFileSize int64

	// AllowSymlinks permits symbolic links inside ExternalDir. When false
	// (the default) any symlink found in the tree is rejected.
	AllowSymlinks bool
}
const (
	// DefaultMaxFileSize is the per-file size limit (10 MiB) used when
	// LayeredProviderConfig.MaxFileSize is zero.
	DefaultMaxFileSize = 10 << 20

	// Labels returned by the providers' Source methods.
	sourceEmbedded = "embedded" // data compiled into the binary
	sourceExternal = "external" // data read from the external directory

	// registryFileName is the component registry; the layered provider merges
	// this file instead of replacing it.
	registryFileName = "registry.yaml"
)
// NewLayeredDataProvider creates a provider that layers external data over embedded.
// Returns an error if:
// - External directory doesn't exist
// - External directory doesn't contain registryFileName
// - Path traversal is detected
// - File size exceeds limits
func NewLayeredDataProvider(embedded *EmbeddedDataProvider, config LayeredProviderConfig) (*LayeredDataProvider, error) {
slog.Debug("creating layered data provider",
"external_dir", config.ExternalDir,
"max_file_size", config.MaxFileSize,
"allow_symlinks", config.AllowSymlinks)
if config.MaxFileSize == 0 {
config.MaxFileSize = DefaultMaxFileSize
}
// Validate external directory exists
slog.Debug("validating external directory")
info, err := os.Stat(config.ExternalDir)
if err != nil {
return nil, aicrerrors.Wrap(aicrerrors.ErrCodeNotFound,
fmt.Sprintf("external data directory not found: %s", config.ExternalDir), err)
}
if !info.IsDir() {
return nil, aicrerrors.New(aicrerrors.ErrCodeInvalidRequest,
fmt.Sprintf("external data path is not a directory: %s", config.ExternalDir))
}
// Validate registryFileName exists in external directory
registryPath := filepath.Join(config.ExternalDir, registryFileName)
slog.Debug("checking for required registry file", "path", registryPath)
if _, statErr := os.Stat(registryPath); statErr != nil {
return nil, aicrerrors.New(aicrerrors.ErrCodeInvalidRequest,
fmt.Sprintf("%s is required in external data directory: %s", registryFileName, config.ExternalDir))
}
slog.Debug("registry file found", "path", registryPath)
// Validate external directory for security issues
slog.Debug("scanning external directory for security issues")
externalFiles := make(map[string]bool)
err = filepath.WalkDir(config.ExternalDir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to walk external directory", err)
}
if d.IsDir() {
return nil
}
// Get relative path
relPath, relErr := filepath.Rel(config.ExternalDir, path)
if relErr != nil {
return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to get relative path", relErr)
}
// Check for path traversal
if strings.Contains(relPath, "..") {
return aicrerrors.New(aicrerrors.ErrCodeInvalidRequest,
fmt.Sprintf("path traversal detected: %s", relPath))
}
// Check for symlinks
if !config.AllowSymlinks {
info, lstatErr := os.Lstat(path)
if lstatErr != nil {
return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to stat file", lstatErr)
}
if info.Mode()&os.ModeSymlink != 0 {
return aicrerrors.New(aicrerrors.ErrCodeInvalidRequest,
fmt.Sprintf("symlinks not allowed: %s", relPath))
}
}
// Check file size
info, statErr := d.Info()
if statErr != nil {
return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to get file info", statErr)
}
if info.Size() > config.MaxFileSize {
return aicrerrors.New(aicrerrors.ErrCodeInvalidRequest,
fmt.Sprintf("file too large (%d bytes, max %d): %s", info.Size(), config.MaxFileSize, relPath))
}
externalFiles[relPath] = true
slog.Debug("discovered external file",
"path", relPath,
"size", info.Size())
return nil
})
if err != nil {
return nil, aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "external directory validation failed", err)
}
slog.Info("layered data provider initialized",
"external_dir", config.ExternalDir,
"external_files", len(externalFiles))
// Log all external files at debug level for troubleshooting
for path := range externalFiles {
slog.Debug("external file registered", "path", path)
}
return &LayeredDataProvider{
embedded: embedded,
externalDir: config.ExternalDir,
externalFiles: externalFiles,
}, nil
}
// ReadFile returns the contents of path using the layered lookup rules:
// registryFileName yields the merged registry, a file registered from the
// external directory is read from disk, and anything else is served by the
// embedded provider.
func (p *LayeredDataProvider) ReadFile(path string) ([]byte, error) {
	slog.Debug("reading file from layered provider", "path", path)

	switch {
	case path == registryFileName:
		// The registry is merged rather than replaced.
		slog.Debug("reading merged registry file")
		return p.getMergedRegistry()
	case !p.externalFiles[path]:
		slog.Debug("falling back to embedded data", "path", path)
		return p.embedded.ReadFile(path)
	}

	data, readErr := os.ReadFile(filepath.Join(p.externalDir, path))
	if readErr != nil {
		return nil, aicrerrors.Wrap(aicrerrors.ErrCodeInternal, fmt.Sprintf("failed to read external file %s", path), readErr)
	}
	slog.Debug("read from external data directory", "path", path)
	return data, nil
}
// WalkDir walks root in both the external directory and the embedded data,
// calling fn once per path.
//
// Paths passed to fn are slash-separated and relative to the data directory
// (e.g. "components/foo.yaml" for root "components"). This matches what the
// embedded provider reports and what ReadFile accepts. The external tree is
// walked first; an embedded entry whose path was already reported externally
// is skipped, so external files take precedence.
//
// Walk failures abort the walk and are returned wrapped, so fn always receives
// a nil error. A root that exists only in the external directory is not an
// error.
func (p *LayeredDataProvider) WalkDir(root string, fn fs.WalkDirFunc) error {
	slog.Debug("walking layered data directory", "root", root)

	// Paths already reported from the external layer (to avoid duplicates).
	visited := make(map[string]bool)
	externalWalked := false

	// Walk external directory first
	externalRoot := filepath.Join(p.externalDir, root)
	if _, statErr := os.Stat(externalRoot); statErr == nil {
		externalWalked = true
		slog.Debug("walking external directory", "path", externalRoot)
		walkErr := filepath.WalkDir(externalRoot, func(path string, d fs.DirEntry, err error) error {
			if err != nil {
				return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to walk external directory", err)
			}
			relPath, relErr := filepath.Rel(p.externalDir, path)
			if relErr != nil {
				return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to compute relative path", relErr)
			}
			// Keep the full data-relative path (root is NOT stripped) in slash
			// form, so it matches the embedded walk's paths and ReadFile keys.
			relPath = filepath.ToSlash(relPath)
			visited[relPath] = true
			slog.Debug("visiting external file", "path", relPath, "isDir", d.IsDir())
			return fn(relPath, d, nil)
		})
		if walkErr != nil {
			return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to walk external directory tree", walkErr)
		}
	}

	slog.Debug("walking embedded directory", "root", root)
	// Walk embedded, skipping already-visited paths
	return p.embedded.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			// d is nil only when the walk root itself could not be stat'ed. If
			// root was supplied solely by the external layer, that is expected.
			// os.IsNotExist unwraps the *fs.PathError returned by embed.FS.
			if externalWalked && d == nil && os.IsNotExist(err) {
				slog.Debug("root not present in embedded data", "root", root)
				return nil
			}
			return aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to walk embedded directory", err)
		}
		if visited[path] {
			slog.Debug("skipping embedded file (external takes precedence)", "path", path)
			return nil // Skip, external takes precedence
		}
		slog.Debug("visiting embedded file", "path", path, "isDir", d.IsDir())
		return fn(path, d, nil)
	})
}
// Source returns "external" or "embedded" depending on where the file comes from.
func (p *LayeredDataProvider) Source(path string) string {
var source string
switch {
case path == registryFileName:
source = "merged (" + sourceEmbedded + " + " + sourceExternal + ")"
case p.externalFiles[path]:
source = sourceExternal
default:
source = sourceEmbedded
}
slog.Debug("resolved file source", "path", path, "source", source)
return source
}
// getMergedRegistry returns the serialized registry obtained by overlaying the
// external registryFileName on the embedded one. The merge runs at most once
// per provider (guarded by mergedRegistryOnce); the bytes and any error are
// cached and returned on every later call.
func (p *LayeredDataProvider) getMergedRegistry() ([]byte, error) {
	p.mergedRegistryOnce.Do(func() {
		p.mergedRegistry, p.mergedRegistryErr = p.buildMergedRegistry()
	})
	return p.mergedRegistry, p.mergedRegistryErr
}

// buildMergedRegistry loads and parses the embedded and external registries,
// merges them with mergeRegistries (external wins by component name) and
// returns the result re-encoded as YAML.
func (p *LayeredDataProvider) buildMergedRegistry() ([]byte, error) {
	slog.Debug("merging registry files")

	embeddedData, err := p.embedded.ReadFile(registryFileName)
	if err != nil {
		return nil, aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to read embedded registry", err)
	}
	var embeddedReg ComponentRegistry
	if err = yaml.Unmarshal(embeddedData, &embeddedReg); err != nil {
		return nil, aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to parse embedded registry", err)
	}

	externalData, err := os.ReadFile(filepath.Join(p.externalDir, registryFileName))
	if err != nil {
		return nil, aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to read external registry", err)
	}
	var externalReg ComponentRegistry
	if err = yaml.Unmarshal(externalData, &externalReg); err != nil {
		return nil, aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to parse external registry", err)
	}

	// A version mismatch is only reported, not rejected.
	if externalReg.APIVersion != "" && externalReg.APIVersion != embeddedReg.APIVersion {
		slog.Warn("external registry has different API version",
			"embedded", embeddedReg.APIVersion,
			"external", externalReg.APIVersion)
	}

	merged := mergeRegistries(&embeddedReg, &externalReg)

	out, err := yaml.Marshal(merged)
	if err != nil {
		return out, aicrerrors.Wrap(aicrerrors.ErrCodeInternal, "failed to serialize merged registry", err)
	}

	slog.Info("merged component registries",
		"embedded_components", len(embeddedReg.Components),
		"external_components", len(externalReg.Components),
		"merged_components", len(merged.Components))
	return out, nil
}
// mergeRegistries merges external into embedded and returns a new registry.
//
// The result keeps embedded's APIVersion and Kind. Embedded components keep
// their order; any that share a name with an external component are replaced
// by the external definition. External-only components are then appended in
// the order they first appear.
//
// If the external registry defines a name more than once, the last definition
// wins. This applies to overrides and additions alike, and each name is
// emitted once. Duplicates inside the embedded registry are left as they are.
// The inputs' Components slices are not modified.
func mergeRegistries(embedded, external *ComponentRegistry) *ComponentRegistry {
	slog.Debug("starting registry merge",
		"embedded_count", len(embedded.Components),
		"external_count", len(external.Components))

	result := &ComponentRegistry{
		APIVersion: embedded.APIVersion,
		Kind:       embedded.Kind,
		Components: make([]ComponentConfig, 0, len(embedded.Components)+len(external.Components)),
	}

	// Index external components by name (last definition wins).
	externalByName := make(map[string]*ComponentConfig, len(external.Components))
	for i := range external.Components {
		comp := &external.Components[i]
		externalByName[comp.Name] = comp
		slog.Debug("indexed external component", "name", comp.Name)
	}

	// Add embedded components, replacing with external if present.
	addedNames := make(map[string]bool, len(embedded.Components)+len(externalByName))
	for i := range embedded.Components {
		comp := &embedded.Components[i]
		if ext, found := externalByName[comp.Name]; found {
			result.Components = append(result.Components, *ext)
			slog.Debug("component overridden from external", "name", comp.Name)
		} else {
			result.Components = append(result.Components, *comp)
			slog.Debug("component retained from embedded", "name", comp.Name)
		}
		addedNames[comp.Name] = true
	}

	// Add new components from external that aren't in embedded, once per name.
	for i := range external.Components {
		name := external.Components[i].Name
		if addedNames[name] {
			continue
		}
		result.Components = append(result.Components, *externalByName[name])
		addedNames[name] = true
		slog.Debug("component added from external", "name", name)
	}

	return result
}
// Global data provider (defaults to embedded, can be set for layered).
// dataProviderMu guards both variables. This keeps the lazy default
// initialization in GetDataProvider, and concurrent Set/Get/Generation calls,
// free of data races.
var (
	dataProviderMu         sync.RWMutex
	globalDataProvider     DataProvider
	dataProviderGeneration int // Incremented when provider changes
)

// SetDataProvider sets the global data provider and increments the provider
// generation so caches consulting GetDataProviderGeneration know to reload.
//
// It should be called before any recipe operations when using external data.
// Callers should ensure it runs early in the application lifecycle, because
// data already cached elsewhere is only refreshed if that cache checks the
// generation. A nil provider makes the next GetDataProvider call fall back to
// the embedded default.
func SetDataProvider(provider DataProvider) {
	dataProviderMu.Lock()
	globalDataProvider = provider
	dataProviderGeneration++
	generation := dataProviderGeneration
	dataProviderMu.Unlock()

	slog.Info("data provider set", "generation", generation)
}

// GetDataProvider returns the global data provider.
// If none has been set, it creates and stores an embedded provider.
func GetDataProvider() DataProvider {
	dataProviderMu.RLock()
	provider := globalDataProvider
	dataProviderMu.RUnlock()
	if provider != nil {
		return provider
	}

	dataProviderMu.Lock()
	defer dataProviderMu.Unlock()
	// Re-check under the write lock: another goroutine may have initialized
	// or set the provider in the meantime.
	if globalDataProvider == nil {
		slog.Debug("initializing default embedded data provider")
		globalDataProvider = NewEmbeddedDataProvider(GetEmbeddedFS(), "")
	}
	return globalDataProvider
}

// GetDataProviderGeneration returns the current data provider generation.
// This is used by caches to detect when they need to reload.
func GetDataProviderGeneration() int {
	dataProviderMu.RLock()
	defer dataProviderMu.RUnlock()
	return dataProviderGeneration
}