Skip to content

Commit 6385d62

Browse files
authored
feat: add custom get env functions (#4)
1 parent 4c30469 commit 6385d62

7 files changed

Lines changed: 1290 additions & 530 deletions

File tree

any.go

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ package goenvconf
22

33
import (
44
"encoding/json"
5-
"errors"
6-
"fmt"
75
"os"
86
)
97

@@ -46,10 +44,6 @@ func (ev *EnvAny) UnmarshalJSON(b []byte) error {
4644
return err
4745
}
4846

49-
if rawValue.Variable != nil && *rawValue.Variable == "" {
50-
return fmt.Errorf("EnvAny: %w", ErrEnvironmentVariableRequired)
51-
}
52-
5347
*ev = EnvAny(rawValue)
5448

5549
return nil
@@ -63,11 +57,7 @@ func (ev EnvAny) IsZero() bool {
6357

6458
// Get gets literal value or from system environment.
6559
func (ev EnvAny) Get() (any, error) {
66-
if ev.Variable != nil {
67-
if *ev.Variable == "" {
68-
return nil, fmt.Errorf("EnvAny: %w", ErrEnvironmentVariableRequired)
69-
}
70-
60+
if ev.Variable != nil && *ev.Variable != "" {
7161
rawValue := os.Getenv(*ev.Variable)
7262
if rawValue != "" {
7363
var result any
@@ -81,16 +71,22 @@ func (ev EnvAny) Get() (any, error) {
8171
return ev.Value, nil
8272
}
8373

84-
// GetOrDefault returns the default value if the environment value is empty.
85-
func (ev EnvAny) GetOrDefault(defaultValue any) (any, error) {
86-
result, err := ev.Get()
87-
if err != nil {
88-
if errors.Is(err, ErrEnvironmentVariableValueRequired) {
89-
return defaultValue, nil
74+
// GetCustom gets literal value or from system environment by a custom function.
75+
func (ev EnvAny) GetCustom(getFunc GetEnvFunc) (any, error) {
76+
if ev.Variable != nil && *ev.Variable != "" {
77+
rawValue, err := getFunc(*ev.Variable)
78+
if err != nil {
79+
return nil, err
9080
}
9181

92-
return false, err
82+
if rawValue != "" {
83+
var result any
84+
85+
err := json.Unmarshal([]byte(rawValue), &result)
86+
87+
return result, err
88+
}
9389
}
9490

95-
return result, nil
91+
return ev.Value, nil
9692
}

any_test.go

Lines changed: 114 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package goenvconf
22

33
import (
4-
"encoding/json"
4+
"errors"
55
"fmt"
66
"testing"
77
)
@@ -34,12 +34,6 @@ func TestEnvAny(t *testing.T) {
3434
Input: NewEnvAny("SOME_FOO_2", "baz"),
3535
Expected: "baz",
3636
},
37-
{
38-
Input: EnvAny{
39-
Variable: toPtr(""),
40-
},
41-
ErrorMsg: ErrEnvironmentVariableRequired.Error(),
42-
},
4337
}
4438

4539
for i, tc := range testCases {
@@ -56,12 +50,118 @@ func TestEnvAny(t *testing.T) {
5650
assertDeepEqual(t, tc.Input.IsZero(), tc.Expected == nil)
5751
})
5852
}
53+
}
54+
55+
// mockGetEnvFuncForAny creates a mock GetEnvFunc for EnvAny tests
56+
func mockGetEnvFuncForAny(values map[string]string, returnError bool) GetEnvFunc {
57+
return func(key string) (string, error) {
58+
if returnError {
59+
return "", errors.New("mock error: failed to get environment variable")
60+
}
61+
if val, ok := values[key]; ok {
62+
return val, nil
63+
}
64+
return "", nil
65+
}
66+
}
67+
68+
func TestEnvAny_GetCustom(t *testing.T) {
69+
testCases := []struct {
70+
Name string
71+
Input EnvAny
72+
GetFunc GetEnvFunc
73+
Expected any
74+
ErrorMsg string
75+
}{
76+
{
77+
Name: "literal_string_value",
78+
Input: NewEnvAnyValue("hello"),
79+
GetFunc: mockGetEnvFuncForAny(map[string]string{}, false),
80+
Expected: "hello",
81+
},
82+
{
83+
Name: "literal_number_value",
84+
Input: NewEnvAnyValue(42.5),
85+
GetFunc: mockGetEnvFuncForAny(map[string]string{}, false),
86+
Expected: 42.5,
87+
},
88+
{
89+
Name: "literal_map_value",
90+
Input: NewEnvAnyValue(map[string]any{"key": "value"}),
91+
GetFunc: mockGetEnvFuncForAny(map[string]string{}, false),
92+
Expected: map[string]any{"key": "value"},
93+
},
94+
{
95+
Name: "variable_from_custom_func_string",
96+
Input: NewEnvAnyVariable("CUSTOM_VAR"),
97+
GetFunc: mockGetEnvFuncForAny(map[string]string{"CUSTOM_VAR": `"test_string"`}, false),
98+
Expected: "test_string",
99+
},
100+
{
101+
Name: "variable_from_custom_func_number",
102+
Input: NewEnvAnyVariable("CUSTOM_VAR"),
103+
GetFunc: mockGetEnvFuncForAny(map[string]string{"CUSTOM_VAR": "123.45"}, false),
104+
Expected: 123.45,
105+
},
106+
{
107+
Name: "variable_from_custom_func_json_object",
108+
Input: NewEnvAnyVariable("CUSTOM_VAR"),
109+
GetFunc: mockGetEnvFuncForAny(map[string]string{"CUSTOM_VAR": `{"foo":"bar","num":42}`}, false),
110+
Expected: map[string]any{"foo": "bar", "num": float64(42)},
111+
},
112+
{
113+
Name: "variable_from_custom_func_json_array",
114+
Input: NewEnvAnyVariable("CUSTOM_VAR"),
115+
GetFunc: mockGetEnvFuncForAny(map[string]string{"CUSTOM_VAR": `[1,2,3]`}, false),
116+
Expected: []any{float64(1), float64(2), float64(3)},
117+
},
118+
{
119+
Name: "variable_with_fallback_value",
120+
Input: NewEnvAny("CUSTOM_VAR", "fallback"),
121+
GetFunc: mockGetEnvFuncForAny(map[string]string{"CUSTOM_VAR": `"custom"`}, false),
122+
Expected: "custom",
123+
},
124+
{
125+
Name: "empty_variable_uses_fallback",
126+
Input: NewEnvAny("EMPTY_VAR", "fallback"),
127+
GetFunc: mockGetEnvFuncForAny(map[string]string{"EMPTY_VAR": ""}, false),
128+
Expected: "fallback",
129+
},
130+
{
131+
Name: "nil_value_and_no_variable",
132+
Input: EnvAny{},
133+
GetFunc: mockGetEnvFuncForAny(map[string]string{}, false),
134+
Expected: nil,
135+
},
136+
{
137+
Name: "custom_func_error",
138+
Input: NewEnvAnyVariable("SOME_VAR"),
139+
GetFunc: mockGetEnvFuncForAny(map[string]string{}, true),
140+
ErrorMsg: "mock error",
141+
},
142+
{
143+
Name: "invalid_json_format",
144+
Input: NewEnvAnyVariable("INVALID_VAR"),
145+
GetFunc: mockGetEnvFuncForAny(map[string]string{"INVALID_VAR": `{invalid json`}, false),
146+
ErrorMsg: "invalid character",
147+
},
148+
{
149+
Name: "missing_variable_returns_nil",
150+
Input: NewEnvAnyVariable("MISSING_VAR"),
151+
GetFunc: mockGetEnvFuncForAny(map[string]string{}, false),
152+
Expected: nil,
153+
},
154+
}
59155

60-
t.Run("json_decode", func(t *testing.T) {
61-
var ev EnvAny
62-
assertNilError(t, json.Unmarshal([]byte(`{"env": "SOME_FOO"}`), &ev))
63-
result, err := ev.GetOrDefault(0)
64-
assertNilError(t, err)
65-
assertDeepEqual(t, float64(2.2), result)
66-
})
156+
for _, tc := range testCases {
157+
t.Run(tc.Name, func(t *testing.T) {
158+
result, err := tc.Input.GetCustom(tc.GetFunc)
159+
if tc.ErrorMsg != "" {
160+
assertErrorContains(t, err, tc.ErrorMsg)
161+
} else {
162+
assertNilError(t, err)
163+
assertDeepEqual(t, tc.Expected, result)
164+
}
165+
})
166+
}
67167
}

0 commit comments

Comments
 (0)