Skip to content

Commit e86a6f4

Browse files
committed
Refactor and enhance .env loading functionality
Refactored .env loading to support custom configurations via `LoadWithConfig` and implement new utility functions for type-safe environment variable retrieval (`GetBool`, `GetFloat`, `MustGetString`). Introduced depth limit, validation for env variable names, handling of comments, quoted values, and error reporting for invalid formats.
1 parent b5408a5 commit e86a6f4

2 files changed

Lines changed: 582 additions & 60 deletions

File tree

goenv.go

Lines changed: 197 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,90 @@
11
package goenv
22

33
import (
4+
"bufio"
5+
"fmt"
46
"os"
57
"path/filepath"
68
"strconv"
79
"strings"
810
)
911

1012
const (
11-
EnvFile = ".env"
13+
DefaultEnvFile = ".env"
14+
MaxDepth = 10 // Prevent infinite recursion
1215
)
1316

14-
// getPaths converts an array of directory entries into an array of file paths by concatenating each
15-
// entry's name with a base path.
16-
func getPaths(newDirs []os.DirEntry, basePath string) []string {
17-
paths := make([]string, len(newDirs))
17+
type Config struct {
18+
EnvFiles []string
19+
MaxDepth int
20+
StopOnFirst bool
21+
Prefix string
22+
}
1823

19-
for i, dir := range newDirs {
20-
paths[i] = filepath.Join(basePath, dir.Name())
24+
func DefaultConfig() *Config {
25+
return &Config{
26+
EnvFiles: []string{".env"},
27+
MaxDepth: MaxDepth,
28+
StopOnFirst: true,
29+
Prefix: "",
2130
}
31+
}
2232

23-
return paths
33+
// LoadWithConfig loads environment variables with custom configuration
34+
func LoadWithConfig(config *Config) error {
35+
if config == nil {
36+
config = DefaultConfig()
37+
}
38+
39+
return loadFromDirectory(".", config, 0, make(map[string]bool))
2440
}
2541

26-
// loadVarsFromFile parses an .env file at the given path and loads its variables into the environment.
27-
func loadVarsFromFile(path string) error {
28-
fileData, err := os.ReadFile(path)
42+
// Load provides backward compatibility with default behavior
43+
func Load() error {
44+
return LoadWithConfig(nil)
45+
}
46+
47+
// loadFromDirectory recursively searches for .env files with proper error handling and cycle detection
48+
func loadFromDirectory(dir string, config *Config, depth int, visited map[string]bool) error {
49+
if depth > config.MaxDepth {
50+
return nil
51+
}
52+
53+
// Get absolute path to detect cycles
54+
absPath, err := filepath.Abs(dir)
2955
if err != nil {
30-
return err
56+
return fmt.Errorf("failed to get absolute path for %s: %w", dir, err)
3157
}
3258

33-
for _, line := range strings.Split(string(fileData), "\n") {
34-
if line == "" || strings.HasPrefix(line, "#") {
35-
continue
59+
if visited[absPath] {
60+
return nil // Skip already visited directories
61+
}
62+
visited[absPath] = true
63+
64+
// Check for .env files in current directory
65+
for _, envFile := range config.EnvFiles {
66+
envPath := filepath.Join(dir, envFile)
67+
if fileExists(envPath) {
68+
if err := loadVarsFromFile(envPath); err != nil {
69+
return fmt.Errorf("failed to load %s: %w", envPath, err)
70+
}
71+
if config.StopOnFirst {
72+
return nil
73+
}
3674
}
75+
}
76+
77+
// Read directory entries
78+
entries, err := os.ReadDir(dir)
79+
if err != nil {
80+
return fmt.Errorf("failed to read directory %s: %w", dir, err)
81+
}
3782

38-
parts := strings.Split(line, "=")
39-
if len(parts) == 2 {
40-
key, value := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
41-
if err := os.Setenv(key, value); err != nil {
83+
// Recursively search subdirectories
84+
for _, entry := range entries {
85+
if entry.IsDir() && !strings.HasPrefix(entry.Name(), ".") {
86+
subDir := filepath.Join(dir, entry.Name())
87+
if err := loadFromDirectory(subDir, config, depth+1, visited); err != nil {
4288
return err
4389
}
4490
}
@@ -47,56 +93,159 @@ func loadVarsFromFile(path string) error {
4793
return nil
4894
}
4995

50-
// Load recursively scans all directories of a project until it finds a .env file. Once found, it reads
51-
// the file and loads its values as environment variables.
52-
func Load() error {
53-
dirsQueue := []string{"./"}
96+
// fileExists checks if a file exists and is not a directory
97+
func fileExists(path string) bool {
98+
info, err := os.Stat(path)
99+
return err == nil && !info.IsDir()
100+
}
101+
102+
// loadVarsFromFile parses an .env file with improved error handling and format support
103+
func loadVarsFromFile(path string) error {
104+
file, err := os.Open(path)
105+
if err != nil {
106+
return err
107+
}
108+
defer file.Close()
109+
110+
scanner := bufio.NewScanner(file)
111+
lineNum := 0
54112

55-
for len(dirsQueue) > 0 {
56-
path := dirsQueue[0]
113+
for scanner.Scan() {
114+
lineNum++
115+
line := strings.TrimSpace(scanner.Text())
57116

58-
file, err := os.Stat(path)
59-
if err != nil {
117+
// Skip empty lines and comments
118+
if line == "" || strings.HasPrefix(line, "#") {
119+
continue
120+
}
121+
122+
// Parse key=value pairs
123+
if err := parseAndSetEnvVar(line, path, lineNum); err != nil {
60124
return err
61125
}
126+
}
62127

63-
if file.IsDir() {
64-
children, err := os.ReadDir(path)
65-
if err != nil {
66-
return err
67-
}
68-
dirsQueue = append(dirsQueue, getPaths(children, path)...)
69-
} else if file.Name() == EnvFile {
70-
return loadVarsFromFile(path)
128+
return scanner.Err()
129+
}
130+
131+
// parseAndSetEnvVar parses a single environment variable line
132+
func parseAndSetEnvVar(line, filePath string, lineNum int) error {
133+
parts := strings.SplitN(line, "=", 2)
134+
if len(parts) != 2 {
135+
return fmt.Errorf("invalid format in %s at line %d: %s", filePath, lineNum, line)
136+
}
137+
138+
key := strings.TrimSpace(parts[0])
139+
value := strings.TrimSpace(parts[1])
140+
141+
// Handle quoted values
142+
if len(value) >= 2 {
143+
if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) ||
144+
(strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) {
145+
value = value[1 : len(value)-1]
71146
}
147+
}
72148

73-
dirsQueue = dirsQueue[1:]
149+
// Validate key format
150+
if !isValidEnvKey(key) {
151+
return fmt.Errorf("invalid environment variable name in %s at line %d: %s", filePath, lineNum, key)
74152
}
75153

76-
return nil
154+
return os.Setenv(key, value)
155+
}
156+
157+
// isValidEnvKey validates environment variable key format
158+
func isValidEnvKey(key string) bool {
159+
if key == "" {
160+
return false
161+
}
162+
163+
for i, r := range key {
164+
if i == 0 {
165+
if !((r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || r == '_') {
166+
return false
167+
}
168+
} else {
169+
if !((r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_') {
170+
return false
171+
}
172+
}
173+
}
174+
175+
return true
77176
}
78177

79-
// GetString returns the value of an environment variable if it exists; otherwise, it returns the fallback value.
178+
// GetString returns the value of an environment variable with fallback
80179
func GetString(key, fallback string) string {
81-
val, ok := os.LookupEnv(key)
82-
if !ok {
180+
if val, exists := os.LookupEnv(key); exists {
181+
return val
182+
}
183+
return fallback
184+
}
185+
186+
// GetInt returns the integer value of an environment variable with fallback
187+
func GetInt(key string, fallback int) int {
188+
val, exists := os.LookupEnv(key)
189+
if !exists {
83190
return fallback
84191
}
85192

86-
return val
193+
if intVal, err := strconv.Atoi(val); err == nil {
194+
return intVal
195+
}
196+
197+
return fallback
87198
}
88199

89-
// GetInt returns the integer value of an environment variable if it exists; otherwise, it returns the fallback value.
90-
func GetInt(key string, fallback int) int {
91-
val, ok := os.LookupEnv(key)
92-
if !ok {
200+
// GetBool returns the boolean value of an environment variable with fallback
201+
func GetBool(key string, fallback bool) bool {
202+
val, exists := os.LookupEnv(key)
203+
if !exists {
93204
return fallback
94205
}
95206

96-
intVal, err := strconv.Atoi(val)
97-
if err != nil {
207+
if boolVal, err := strconv.ParseBool(val); err == nil {
208+
return boolVal
209+
}
210+
211+
return fallback
212+
}
213+
214+
// GetFloat returns the float64 value of an environment variable with fallback
215+
func GetFloat(key string, fallback float64) float64 {
216+
val, exists := os.LookupEnv(key)
217+
if !exists {
98218
return fallback
99219
}
100220

101-
return intVal
221+
if floatVal, err := strconv.ParseFloat(val, 64); err == nil {
222+
return floatVal
223+
}
224+
225+
return fallback
226+
}
227+
228+
// MustGetString returns the value of an environment variable or panics if not found
229+
func MustGetString(key string) string {
230+
val, exists := os.LookupEnv(key)
231+
if !exists {
232+
panic(fmt.Sprintf("required environment variable %s not found", key))
233+
}
234+
return val
235+
}
236+
237+
// LoadFile loads a specific .env file
238+
func LoadFile(path string) error {
239+
return loadVarsFromFile(path)
240+
}
241+
242+
// Unload removes all environment variables loaded from .env files
243+
// Note: This is a simplified implementation - tracking loaded vars would be better
244+
func Unload(keys []string) error {
245+
for _, key := range keys {
246+
if err := os.Unsetenv(key); err != nil {
247+
return err
248+
}
249+
}
250+
return nil
102251
}

0 commit comments

Comments
 (0)