Skip to content

Commit 5e3e4df

Browse files
committed
fix: remove unused safeOpenFile function and improve file operations security
1. Removed unused safeOpenFile function 2. Updated file operations to use safeReadFile 3. Simplified tree node finding and sorting logic 4. Improved generated file detection
1 parent ea29c14 commit 5e3e4df

File tree

1 file changed

+18
-84
lines changed

1 file changed

+18
-84
lines changed

reader.go

+18-84
Original file line numberDiff line numberDiff line change
@@ -98,47 +98,6 @@ func (r *DefaultReader) safeReadFile(path string) ([]byte, error) {
9898
return os.ReadFile(absPath)
9999
}
100100

101-
// safeOpenFile opens a file with security checks
102-
func (r *DefaultReader) safeOpenFile(path string) (*os.File, error) {
103-
if err := r.validatePath(path); err != nil {
104-
return nil, fmt.Errorf("invalid path: %w", err)
105-
}
106-
107-
// Get absolute path
108-
absPath := path
109-
if !filepath.IsAbs(path) {
110-
absPath = filepath.Join(r.workDir, path)
111-
}
112-
113-
// Clean the path
114-
absPath = filepath.Clean(absPath)
115-
116-
// Verify file exists and get info
117-
info, err := os.Stat(absPath)
118-
if err != nil {
119-
return nil, err
120-
}
121-
122-
// Check if it's a regular file
123-
if !info.Mode().IsRegular() {
124-
return nil, fmt.Errorf("not a regular file: %s", path)
125-
}
126-
127-
// Check file size
128-
if info.Size() > maxFileSize {
129-
return nil, fmt.Errorf("file too large: %s", path)
130-
}
131-
132-
// Check file extension for allowed types
133-
ext := strings.ToLower(filepath.Ext(path))
134-
if !isAllowedExtension(ext) {
135-
return nil, fmt.Errorf("unsupported file type: %s", ext)
136-
}
137-
138-
// Open file with read-only mode
139-
return os.OpenFile(absPath, os.O_RDONLY, 0)
140-
}
141-
142101
// GetFileTree returns the file tree starting from the given root
143102
func (r *DefaultReader) GetFileTree(ctx context.Context, root string, opts TreeOptions) (*FileTreeNode, error) {
144103
if err := r.validatePath(root); err != nil {
@@ -205,7 +164,7 @@ func (r *DefaultReader) GetFileTree(ctx context.Context, root string, opts TreeO
205164
return nil
206165
}
207166
case FileTypeGenerated:
208-
content, err := os.ReadFile(path)
167+
content, err := r.safeReadFile(path)
209168
if err != nil {
210169
return err
211170
}
@@ -237,19 +196,16 @@ func (r *DefaultReader) GetFileTree(ctx context.Context, root string, opts TreeO
237196
// Find parent node
238197
if path != absRoot {
239198
parentPath := filepath.Dir(relPath)
240-
parent := findParentNode(tree, parentPath)
241-
if parent != nil {
242-
parent.Children = append(parent.Children, node)
243-
sortTree(parent)
244-
return nil
199+
parentNode := findNode(tree, parentPath)
200+
if parentNode != nil {
201+
parentNode.Children = append(parentNode.Children, node)
202+
// Sort children by name
203+
sort.Slice(parentNode.Children, func(i, j int) bool {
204+
return parentNode.Children[i].Name < parentNode.Children[j].Name
205+
})
245206
}
246207
}
247208

248-
// If no parent found (should only happen for root), add to tree
249-
if path == absRoot {
250-
*tree = *node
251-
}
252-
253209
return nil
254210
})
255211

@@ -371,57 +327,35 @@ func (r *DefaultReader) ReadSourceFile(ctx context.Context, path string, opts Re
371327

372328
// isGeneratedFile checks if a file is generated based on its content
373329
func isGeneratedFile(content []byte) bool {
374-
// Common markers for generated files
330+
contentStr := string(content)
375331
markers := []string{
376-
"Code generated",
377-
"DO NOT EDIT",
332+
"Code generated", "DO NOT EDIT",
378333
"@generated",
379-
"Generated by",
334+
"// Generated by",
335+
"/* Generated by",
380336
}
381337

382-
contentStr := string(content)
383338
for _, marker := range markers {
384339
if strings.Contains(contentStr, marker) {
385340
return true
386341
}
387342
}
388-
389343
return false
390344
}
391345

392-
// findParentNode finds a parent node in the tree by path
393-
func findParentNode(root *FileTreeNode, parentPath string) *FileTreeNode {
394-
if root.Path == parentPath {
346+
// findNode finds a node in the tree by its path
347+
func findNode(root *FileTreeNode, path string) *FileTreeNode {
348+
if root.Path == path {
395349
return root
396350
}
351+
397352
for _, child := range root.Children {
398353
if child.Type == "directory" {
399-
if node := findParentNode(child, parentPath); node != nil {
354+
if node := findNode(child, path); node != nil {
400355
return node
401356
}
402357
}
403358
}
404-
return nil
405-
}
406359

407-
// sortTree sorts the children of a node by name
408-
func sortTree(node *FileTreeNode) {
409-
if node == nil || len(node.Children) == 0 {
410-
return
411-
}
412-
413-
sort.Slice(node.Children, func(i, j int) bool {
414-
// Directories come first
415-
if node.Children[i].Type != node.Children[j].Type {
416-
return node.Children[i].Type == "directory"
417-
}
418-
return node.Children[i].Name < node.Children[j].Name
419-
})
420-
421-
// Sort children recursively
422-
for _, child := range node.Children {
423-
if child.Type == "directory" {
424-
sortTree(child)
425-
}
426-
}
360+
return nil
427361
}

0 commit comments

Comments
 (0)