diff --git a/src/main/java/io/github/doocs/im/model/response/BaseGenericResult.java b/src/main/java/io/github/doocs/im/model/response/BaseGenericResult.java new file mode 100644 index 0000000..2ad6009 --- /dev/null +++ b/src/main/java/io/github/doocs/im/model/response/BaseGenericResult.java @@ -0,0 +1,44 @@ +package io.github.doocs.im.model.response; + +import java.io.Serializable; + +public class BaseGenericResult extends GenericResult implements Serializable { + + private static final long serialVersionUID = -8713954419178432365L; + + + @Override + public String getActionStatus() { + return super.getActionStatus(); + } + + @Override + public void setActionStatus(String actionStatus) { + super.setActionStatus(actionStatus); + } + + @Override + public String getErrorInfo() { + return super.getErrorInfo(); + } + + @Override + public void setErrorInfo(String errorInfo) { + super.setErrorInfo(errorInfo); + } + + @Override + public Integer getErrorCode() { + return super.getErrorCode(); + } + + @Override + public void setErrorCode(Integer errorCode) { + super.setErrorCode(errorCode); + } + + @Override + public String toString() { + return super.toString(); + } +} diff --git a/src/main/java/io/github/doocs/im/util/HttpUtil.java b/src/main/java/io/github/doocs/im/util/HttpUtil.java index 6e4d54e..0fbac2c 100644 --- a/src/main/java/io/github/doocs/im/util/HttpUtil.java +++ b/src/main/java/io/github/doocs/im/util/HttpUtil.java @@ -1,6 +1,7 @@ package io.github.doocs.im.util; import io.github.doocs.im.ClientConfiguration; +import io.github.doocs.im.model.response.BaseGenericResult; import io.github.doocs.im.model.response.GenericResult; import okhttp3.*; @@ -32,7 +33,7 @@ public class HttpUtil { .writeTimeout(DEFAULT_CONFIG.getWriteTimeout(), TimeUnit.MILLISECONDS) .callTimeout(DEFAULT_CONFIG.getCallTimeout(), TimeUnit.MILLISECONDS) .retryOnConnectionFailure(false) - .addInterceptor(new RetryInterceptor(DEFAULT_CONFIG.getMaxRetries(), DEFAULT_CONFIG.getRetryIntervalMs(), DEFAULT_CONFIG.getBusinessRetryCodes(), DEFAULT_CONFIG.isEnableBusinessRetry())) + .addInterceptor(new RetryInterceptor(DEFAULT_CONFIG.getMaxRetries(), DEFAULT_CONFIG.getRetryIntervalMs(), DEFAULT_CONFIG.getBusinessRetryCodes(), DEFAULT_CONFIG.isEnableBusinessRetry(), BaseGenericResult.class)) .build(); private HttpUtil() { @@ -59,7 +60,7 @@ private static OkHttpClient getClient(ClientConfiguration config) { .writeTimeout(cfg.getWriteTimeout(), TimeUnit.MILLISECONDS) .callTimeout(cfg.getCallTimeout(), TimeUnit.MILLISECONDS) .retryOnConnectionFailure(false) - .addInterceptor(new RetryInterceptor(cfg.getMaxRetries(), cfg.getRetryIntervalMs(), DEFAULT_CONFIG.getBusinessRetryCodes(), DEFAULT_CONFIG.isEnableBusinessRetry())) + .addInterceptor(new RetryInterceptor(cfg.getMaxRetries(), cfg.getRetryIntervalMs(), cfg.getBusinessRetryCodes(), cfg.isEnableBusinessRetry(), BaseGenericResult.class)) .build()); } @@ -103,16 +104,20 @@ class RetryInterceptor implements Interceptor { Stream.of(408, 429, 500, 502, 503, 504).collect(Collectors.toSet()) ); private static final int MAX_DELAY_MS = 10000; + private static final int MAX_BODY_SIZE = 1 * 1024 * 1024; private final int maxRetries; private final long retryIntervalMs; private final Set businessRetryCodes; private final boolean enableBusinessRetry; + private final Class resultType; + private final Random random = new Random(); - public RetryInterceptor(int maxRetries, long retryIntervalMs, Set businessRetryCodes, boolean enableBusinessRetry) { - this.maxRetries = maxRetries; + public RetryInterceptor(int maxRetries, long retryIntervalMs, Set businessRetryCodes, boolean enableBusinessRetry, Class resultType) { + this.maxRetries = maxRetries + 1; this.retryIntervalMs = retryIntervalMs; this.businessRetryCodes = businessRetryCodes; this.enableBusinessRetry = enableBusinessRetry; + this.resultType = Objects.requireNonNull(resultType); } @Override @@ -120,72 +125,68 @@ public Response intercept(Chain chain) throws IOException { Request request = chain.request(); Response response = null; IOException exception = null; - for (int attempt = 0; attempt <= maxRetries; ++attempt) { - if (response != null) { + for (int attempt = 1; attempt <= maxRetries; attempt++) { + if (response != null) response.close(); - } try { response = chain.proceed(request); - if (response.isSuccessful() && !shouldRetry(response)) { + if (response.isSuccessful()) { + if (enableBusinessRetry && shouldRetryForBusiness(response)) { + waitForRetry(attempt); + continue; + } return response; - } - if (!shouldRetry(response)) { + } else { + if (shouldRetryForHttp(response)) { + waitForRetry(attempt); + continue; + } return response; } } catch (IOException e) { - if (attempt >= maxRetries) { - throw e; - } exception = e; - } - if (attempt < maxRetries) { + if (attempt == maxRetries) throw e; waitForRetry(attempt); } } - if (response != null) { - return response; - } - if (exception != null) { - throw exception; - } else { - throw new IOException("Failed to get a valid response after all retries and no exception was caught."); - } + if (exception != null) throw exception; + if (response != null) return response; + throw new IOException("Failed after all retries with no response"); } - private boolean shouldRetry(Response response) { - final int code = response.code(); - if (code >= 500 && code < 600) { - return true; - } - if (RETRYABLE_STATUS_CODES.contains(code)) { - return true; - } - if (enableBusinessRetry) { - return shouldRetryBasedOnBusinessCode(response); - } - return false; + private boolean shouldRetryForHttp(Response response) { + int code = response.code(); + return code >= 500 || RETRYABLE_STATUS_CODES.contains(code); } - private void waitForRetry(int attempt) { + private void waitForRetry(int attempt) throws IOException { try { - final long delayMs = Math.min(MAX_DELAY_MS, retryIntervalMs * (1L << attempt)); - TimeUnit.MILLISECONDS.sleep(delayMs); + long delay = calculateBackoff(attempt); + TimeUnit.MILLISECONDS.sleep(delay); } catch (InterruptedException e) { Thread.currentThread().interrupt(); + throw new IOException("Retry interrupted", e); } } - private boolean shouldRetryBasedOnBusinessCode(Response response) { + private long calculateBackoff(int attempt) { + double jitter = 0.8 + random.nextDouble() * 0.4; + long calculated = (long) (retryIntervalMs * Math.pow(2, attempt) * jitter); + return Math.min(calculated, MAX_DELAY_MS); + } + + private boolean shouldRetryForBusiness(Response response) { try { - if (businessRetryCodes == null || businessRetryCodes.isEmpty()) { + if (businessRetryCodes.isEmpty()) return false; + ResponseBody peekBody = response.peekBody(MAX_BODY_SIZE); + String responseBody = peekBody.source().readByteString().utf8(); + GenericResult result = JsonUtil.str2Obj(responseBody, resultType); + if (result == null || result.getErrorCode() == null) { return false; } - String responseBody = Objects.requireNonNull(response.body()).string(); - GenericResult genericResult = JsonUtil.str2Obj(responseBody, GenericResult.class); - int businessCode = genericResult.getErrorCode(); - return businessRetryCodes.contains(businessCode); - } catch (IOException | IllegalStateException e) { + return businessRetryCodes.contains(result.getErrorCode()); + } catch (Exception e) { return false; } } diff --git a/src/test/java/io/github/doocs/im/util/RetryInterceptorTest.java b/src/test/java/io/github/doocs/im/util/RetryInterceptorTest.java new file mode 100644 index 0000000..980c3f5 --- /dev/null +++ b/src/test/java/io/github/doocs/im/util/RetryInterceptorTest.java @@ -0,0 +1,313 @@ +package io.github.doocs.im.util; + +import io.github.doocs.im.model.response.BaseGenericResult; +import okhttp3.*; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.net.SocketTimeoutException; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.*; + +public class RetryInterceptorTest { + + private RetryInterceptor interceptor; + private TestChain testChain; + private Request request; + + @BeforeEach + public void setup() { + // Use virtual URL + request = new Request.Builder() + .url("http://example.com") + .build(); + + interceptor = new RetryInterceptor( + 2, 1000, + Collections.unmodifiableSet(new java.util.HashSet() {{ + add(10002); + add(20004); + add(20005); + }}), + true, + BaseGenericResult.class + ); + + // Constructor using default timeout parameter + testChain = new TestChain(request); + } + + // Custom Test Chain Implementation + private static class TestChain implements Interceptor.Chain { + private final Request request; + private final AtomicInteger callCount = new AtomicInteger(0); + private final java.util.function.IntFunction responseSupplier; + private final java.util.function.IntFunction exceptionSupplier; + private final int connectTimeoutMs; + private final int readTimeoutMs; + + public TestChain(Request request) { + this(request, i -> null, i -> null, 10000, 10000); + } + + public TestChain(Request request, + java.util.function.IntFunction responseSupplier, + java.util.function.IntFunction exceptionSupplier, + int connectTimeoutMs, + int readTimeoutMs) { + this.request = request; + this.responseSupplier = responseSupplier; + this.exceptionSupplier = exceptionSupplier; + this.connectTimeoutMs = connectTimeoutMs; + this.readTimeoutMs = readTimeoutMs; + } + + @Override + public Request request() { + return request; + } + + @Override + public Response proceed(Request request) throws IOException { + int count = callCount.incrementAndGet(); + IOException exception = exceptionSupplier.apply(count); + if (exception != null) { + throw exception; + } + Response response = responseSupplier.apply(count); + if (response == null) { + throw new AssertionError("No response configured for call " + count); + } + return response; + } + + @Override + public Connection connection() { + // No actual connection is required during testing, return null + return null; + } + + @Override + public Call call() { + // No actual Call object is required during testing, return null + return null; + } + + @Override + public int connectTimeoutMillis() { + return connectTimeoutMs; + } + + @Override + public Interceptor.Chain withConnectTimeout(int timeout, java.util.concurrent.TimeUnit unit) { + return new TestChain( + request, + responseSupplier, + exceptionSupplier, + (int) unit.toMillis(timeout), + readTimeoutMs + ); + } + + @Override + public int readTimeoutMillis() { + return readTimeoutMs; + } + + @Override + public Interceptor.Chain withReadTimeout(int timeout, java.util.concurrent.TimeUnit unit) { + return new TestChain( + request, + responseSupplier, + exceptionSupplier, + connectTimeoutMs, + (int) unit.toMillis(timeout) + ); + } + + @Override + public int writeTimeoutMillis() { + // Default write timeout + return 10000; + } + + @Override + public Interceptor.Chain withWriteTimeout(int timeout, java.util.concurrent.TimeUnit unit) { + // No need to implement during testing + return this; + } + + public int getCallCount() { + return callCount.get(); + } + } + + //---------------- Tool method: Create simulated response ----------------// + private Response createResponse(int code, String body) { + return new Response.Builder() + .request(request) + .protocol(Protocol.HTTP_1_1) + .code(code) + .message("") + .body(ResponseBody.create( + body, + MediaType.get("application/json") + )) + .build(); + } + + //---------------- Normal response test ----------------// + @Test + public void testNormalResponse_Http200() throws IOException { + // Simulate a successful response with a constructor that includes all parameters + testChain = new TestChain( + request, + i -> createResponse(200, "{ \"ErrorCode\": 0 }"), + i -> null, + 10000, + 10000 + ); + + Response response = interceptor.intercept(testChain); + assertEquals(200, response.code()); + // Verify a single request + assertEquals(1, testChain.getCallCount()); + } + + //---------------- HTTP error retry test ----------------// + @Test + public void testHttpRetry_SuccessAfterRetries() throws IOException { + // Use counters to control retry logic, including a constructor with all parameters + testChain = new TestChain( + request, + i -> i <= 2 ? createResponse(500, "") : createResponse(200, "{ \"ErrorCode\": 0 }"), + i -> null, + 10000, + 10000 + ); + + Response response = interceptor.intercept(testChain); + assertEquals(200, response.code()); + assertEquals(3, testChain.getCallCount()); // Verify the number of retries + } + + //---------------- Business error retry test ----------------// + @Test + public void testBusinessRetry_SuccessAfterRetries() throws IOException { + // Constructor containing all parameters + testChain = new TestChain( + request, + i -> createResponse(200, i < 3 ? "{ \"ErrorCode\": 10002 }" : "{ \"ErrorCode\": 0 }"), + i -> null, + 10000, + 10000 + ); + + Response response = interceptor.intercept(testChain); + assertEquals(200, response.code()); + assertEquals(3, testChain.getCallCount()); + } + + //---------------- Abnormal scenario testing ----------------// + @Test + public void testHttpRetry_MaxRetriesExceeded() { + // Constructor containing all parameters + testChain = new TestChain( + request, + i -> createResponse(500, ""), + i -> null, + 10000, + 10000 + ); + + assertThrows(IOException.class, () -> interceptor.intercept(testChain)); + assertEquals(3, testChain.getCallCount()); // Verify the number of retries + } + + //---------------- Timeout retry test ----------------// + @Test + public void testConnectTimeoutRetry() throws IOException { + // Simulate the first two connection timeouts and the third one is successful + testChain = new TestChain( + request, + i -> i > 2 ? createResponse(200, "{}") : null, + i -> i <= 2 ? new SocketTimeoutException("Connect timed out") : null, + 10000, + 10000 + ); + + Response response = interceptor.intercept(testChain); + assertEquals(200, response.code()); + assertEquals(3, testChain.getCallCount()); + } + + //---------------- ReadTime retry test ----------------// + @Test + public void testReadTimeoutRetry() throws IOException { + // Simulate the first two reads timed out, the third one succeeded + testChain = new TestChain( + request, + i -> i > 2 ? createResponse(200, "{}") : null, + i -> i <= 2 ? new SocketTimeoutException("Read timed out") : null, + 10000, + 10000 + ); + + Response response = interceptor.intercept(testChain); + assertEquals(200, response.code()); + assertEquals(3, testChain.getCallCount()); + } + + @Test + public void testConnectTimeoutExceedMaxRetries() { + // Simulate connection timeout every time, exceeding the maximum retry count + testChain = new TestChain( + request, + i -> null, + i -> new SocketTimeoutException("Connect timed out"), + 10000, + 10000 + ); + + assertThrows(SocketTimeoutException.class, () -> interceptor.intercept(testChain)); + assertEquals(3, testChain.getCallCount()); + } + + //---------------- Similar adjustments to other test cases (example)----------------// + @Test + public void testIOExceptionRetry() throws IOException { + // Constructor containing all parameters + testChain = new TestChain( + request, + i -> i > 2 ? createResponse(200, "{}") : null, + i -> i <= 2 ? new IOException("Simulate network errors") : null, + 10000, + 10000 + ); + + Response response = interceptor.intercept(testChain); + assertEquals(200, response.code()); + assertEquals(3, testChain.getCallCount()); + } + + @Test + public void testMaxRetriesZero() throws IOException { + interceptor = new RetryInterceptor(0, 1000, Collections.emptySet(), false, BaseGenericResult.class); + + // Constructor containing all parameters + testChain = new TestChain( + request, + i -> createResponse(500, ""), + i -> null, + 10000, + 10000 + ); + + Response response = interceptor.intercept(testChain); + assertEquals(500, response.code()); + assertEquals(1, testChain.getCallCount()); + } +}