|
1 | 1 | package config |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "io" |
| 5 | + "math" |
| 6 | + "net/http" |
4 | 7 | "os" |
5 | 8 | "path/filepath" |
| 9 | + "strings" |
6 | 10 | "testing" |
| 11 | + "time" |
| 12 | + |
| 13 | + "github.com/GMWalletApp/epusdt/util/http_client" |
| 14 | + "github.com/go-resty/resty/v2" |
| 15 | + "github.com/spf13/viper" |
7 | 16 | ) |
8 | 17 |
|
| 18 | +func installSettingsGetter(t *testing.T, values map[string]string) { |
| 19 | + t.Helper() |
| 20 | + |
| 21 | + oldGetter := SettingsGetString |
| 22 | + SettingsGetString = func(key string) string { |
| 23 | + return values[key] |
| 24 | + } |
| 25 | + t.Cleanup(func() { |
| 26 | + SettingsGetString = oldGetter |
| 27 | + }) |
| 28 | +} |
| 29 | + |
| 30 | +type roundTripFunc func(*http.Request) (*http.Response, error) |
| 31 | + |
| 32 | +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { |
| 33 | + return f(req) |
| 34 | +} |
| 35 | + |
| 36 | +func installMockHTTPClient(t *testing.T, handler roundTripFunc) { |
| 37 | + t.Helper() |
| 38 | + |
| 39 | + oldFactory := http_client.ClientFactory |
| 40 | + http_client.ClientFactory = func() *resty.Client { |
| 41 | + client := resty.NewWithClient(&http.Client{Transport: handler}) |
| 42 | + client.SetTimeout(10 * time.Second) |
| 43 | + return client |
| 44 | + } |
| 45 | + t.Cleanup(func() { |
| 46 | + http_client.ClientFactory = oldFactory |
| 47 | + }) |
| 48 | +} |
| 49 | + |
9 | 50 | func TestNormalizeConfiguredPathUsesExplicitFile(t *testing.T) { |
10 | 51 | t.Helper() |
11 | 52 |
|
@@ -135,3 +176,127 @@ func TestResolveConfigFilePathPrefersExplicitOverEnv(t *testing.T) { |
135 | 176 | t.Fatalf("config path = %s, want %s", got, flagPath) |
136 | 177 | } |
137 | 178 | } |
| 179 | + |
| 180 | +func TestGetUsdtRatePrefersPositiveAdminOverride(t *testing.T) { |
| 181 | + viper.Reset() |
| 182 | + t.Cleanup(viper.Reset) |
| 183 | + t.Setenv("API_RATE_URL", "") |
| 184 | + |
| 185 | + apiCalled := false |
| 186 | + installMockHTTPClient(t, func(r *http.Request) (*http.Response, error) { |
| 187 | + apiCalled = true |
| 188 | + return &http.Response{ |
| 189 | + StatusCode: http.StatusInternalServerError, |
| 190 | + Status: http.StatusText(http.StatusInternalServerError), |
| 191 | + Header: make(http.Header), |
| 192 | + Body: io.NopCloser(strings.NewReader("")), |
| 193 | + Request: r, |
| 194 | + }, nil |
| 195 | + }) |
| 196 | + |
| 197 | + installSettingsGetter(t, map[string]string{ |
| 198 | + "rate.forced_usdt_rate": "7.25", |
| 199 | + "rate.api_url": "https://rate.example.test", |
| 200 | + }) |
| 201 | + |
| 202 | + got := GetUsdtRate() |
| 203 | + if got != 7.25 { |
| 204 | + t.Fatalf("GetUsdtRate() = %v, want 7.25", got) |
| 205 | + } |
| 206 | + if apiCalled { |
| 207 | + t.Fatalf("rate API should not be called when rate.forced_usdt_rate > 0") |
| 208 | + } |
| 209 | +} |
| 210 | + |
| 211 | +func TestGetUsdtRateUsesAPIWhenAdminOverrideIsNotPositive(t *testing.T) { |
| 212 | + viper.Reset() |
| 213 | + t.Cleanup(viper.Reset) |
| 214 | + t.Setenv("API_RATE_URL", "") |
| 215 | + |
| 216 | + installMockHTTPClient(t, func(r *http.Request) (*http.Response, error) { |
| 217 | + if r.URL.Path != "/cny.json" { |
| 218 | + t.Fatalf("rate api path = %s, want /cny.json", r.URL.Path) |
| 219 | + } |
| 220 | + return &http.Response{ |
| 221 | + StatusCode: http.StatusOK, |
| 222 | + Status: "200 OK", |
| 223 | + Header: http.Header{"Content-Type": []string{"application/json"}}, |
| 224 | + Body: io.NopCloser(strings.NewReader(`{"cny":{"usdt":0.14635}}`)), |
| 225 | + Request: r, |
| 226 | + }, nil |
| 227 | + }) |
| 228 | + |
| 229 | + installSettingsGetter(t, map[string]string{ |
| 230 | + "rate.forced_usdt_rate": "-1", |
| 231 | + "rate.api_url": "https://rate.example.test", |
| 232 | + }) |
| 233 | + |
| 234 | + got := GetUsdtRate() |
| 235 | + want := 1 / 0.14635 |
| 236 | + if math.Abs(got-want) > 1e-9 { |
| 237 | + t.Fatalf("GetUsdtRate() = %v, want %v", got, want) |
| 238 | + } |
| 239 | + |
| 240 | + rate := GetRateForCoin("usdt", "cny") |
| 241 | + if math.Abs(rate-0.14635) > 1e-9 { |
| 242 | + t.Fatalf("GetRateForCoin(usdt, cny) = %v, want 0.14635", rate) |
| 243 | + } |
| 244 | +} |
| 245 | + |
| 246 | +func TestGetUsdtRateReturnsZeroWhenAPIUnavailableWithoutAdminOverride(t *testing.T) { |
| 247 | + viper.Reset() |
| 248 | + t.Cleanup(viper.Reset) |
| 249 | + t.Setenv("API_RATE_URL", "") |
| 250 | + |
| 251 | + installMockHTTPClient(t, func(r *http.Request) (*http.Response, error) { |
| 252 | + return &http.Response{ |
| 253 | + StatusCode: http.StatusBadGateway, |
| 254 | + Status: "502 Bad Gateway", |
| 255 | + Header: make(http.Header), |
| 256 | + Body: io.NopCloser(strings.NewReader("")), |
| 257 | + Request: r, |
| 258 | + }, nil |
| 259 | + }) |
| 260 | + |
| 261 | + installSettingsGetter(t, map[string]string{ |
| 262 | + "rate.forced_usdt_rate": "0", |
| 263 | + "rate.api_url": "https://rate.example.test", |
| 264 | + }) |
| 265 | + |
| 266 | + if got := GetUsdtRate(); got != 0 { |
| 267 | + t.Fatalf("GetUsdtRate() = %v, want 0", got) |
| 268 | + } |
| 269 | + if got := GetRateForCoin("usdt", "cny"); got != 0 { |
| 270 | + t.Fatalf("GetRateForCoin(usdt, cny) = %v, want 0", got) |
| 271 | + } |
| 272 | +} |
| 273 | + |
| 274 | +func TestGetRateForCoinCallsRateAPIOnceForUsdtCnyFailure(t *testing.T) { |
| 275 | + viper.Reset() |
| 276 | + t.Cleanup(viper.Reset) |
| 277 | + t.Setenv("API_RATE_URL", "") |
| 278 | + |
| 279 | + callCount := 0 |
| 280 | + installMockHTTPClient(t, func(r *http.Request) (*http.Response, error) { |
| 281 | + callCount++ |
| 282 | + return &http.Response{ |
| 283 | + StatusCode: http.StatusBadGateway, |
| 284 | + Status: "502 Bad Gateway", |
| 285 | + Header: make(http.Header), |
| 286 | + Body: io.NopCloser(strings.NewReader("")), |
| 287 | + Request: r, |
| 288 | + }, nil |
| 289 | + }) |
| 290 | + |
| 291 | + installSettingsGetter(t, map[string]string{ |
| 292 | + "rate.forced_usdt_rate": "0", |
| 293 | + "rate.api_url": "https://rate.example.test", |
| 294 | + }) |
| 295 | + |
| 296 | + if got := GetRateForCoin("usdt", "cny"); got != 0 { |
| 297 | + t.Fatalf("GetRateForCoin(usdt, cny) = %v, want 0", got) |
| 298 | + } |
| 299 | + if callCount != 1 { |
| 300 | + t.Fatalf("rate api call count = %d, want 1", callCount) |
| 301 | + } |
| 302 | +} |
0 commit comments