Skip to content

Commit 7405a7d

Browse files
authored
revert #4330, fix: support streaming calls in ModelRetryInterceptor (#4398)
1 parent 99055ef commit 7405a7d

File tree

2 files changed

+64
-316
lines changed

2 files changed

+64
-316
lines changed

spring-ai-alibaba-agent-framework/src/main/java/com/alibaba/cloud/ai/graph/agent/interceptor/modelretry/ModelRetryInterceptor.java

Lines changed: 18 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,12 @@
1919
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelInterceptor;
2020
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelRequest;
2121
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelResponse;
22+
23+
import org.springframework.ai.chat.messages.Message;
2224
import org.slf4j.Logger;
2325
import org.slf4j.LoggerFactory;
24-
import org.springframework.ai.chat.messages.Message;
25-
import org.springframework.ai.chat.model.ChatResponse;
26-
import reactor.core.publisher.Flux;
27-
import reactor.core.publisher.Mono;
28-
import reactor.util.retry.Retry;
29-
30-
import java.time.Duration;
31-
import java.util.concurrent.atomic.AtomicBoolean;
32-
import java.util.concurrent.atomic.AtomicLong;
26+
27+
3328
import java.util.function.Predicate;
3429

3530
/**
@@ -54,15 +49,13 @@ public class ModelRetryInterceptor extends ModelInterceptor {
5449
private final long maxDelay;
5550
private final double backoffMultiplier;
5651
private final Predicate<Exception> retryableExceptionPredicate;
57-
private final Predicate<Throwable> retryableThrowablePredicate;
5852

5953
private ModelRetryInterceptor(Builder builder) {
6054
this.maxAttempts = builder.maxAttempts;
6155
this.initialDelay = builder.initialDelay;
6256
this.maxDelay = builder.maxDelay;
6357
this.backoffMultiplier = builder.backoffMultiplier;
6458
this.retryableExceptionPredicate = builder.retryableExceptionPredicate;
65-
this.retryableThrowablePredicate = this::isRetryableException;
6659
}
6760

6861
public static Builder builder() {
@@ -71,39 +64,17 @@ public static Builder builder() {
7164

7265
@Override
7366
public ModelResponse interceptModel(ModelRequest request, ModelCallHandler handler) {
74-
75-
// 1. Initial attempt to call the model
76-
ModelResponse response = handler.call(request);
77-
78-
// 2. Handle streaming calls (Flux)
79-
if (response.getMessage() instanceof Flux<?>) {
80-
return handleStreamRetry(request, handler, response);
81-
}
82-
83-
// 3. Handle blocking calls
84-
return handleBlockingRetry(request, handler, response);
85-
86-
}
87-
88-
89-
/**
90-
* Retry logic for blocking (non-streaming) model calls.
91-
*/
92-
private ModelResponse handleBlockingRetry(ModelRequest request, ModelCallHandler handler, ModelResponse initialResponse) {
9367
Exception lastException = null;
9468
long currentDelay = initialDelay;
9569

96-
ModelResponse currentResponse = initialResponse;
97-
9870
for (int attempt = 1; attempt <= maxAttempts; attempt++) {
9971
try {
100-
// If it is not the first iteration (i.e., this is a retry attempt), re-invoke the handler.
10172
if (attempt > 1) {
10273
log.info("Retry model call, on the {}th attempt (out of {} attempts).", attempt, maxAttempts);
103-
currentResponse = handler.call(request);
10474
}
10575

106-
Message message = (Message) currentResponse.getMessage();
76+
ModelResponse modelResponse = handler.call(request);
77+
Message message = (Message) modelResponse.getMessage();
10778

10879
// Check if the response contains any exception information (exceptions captured from AgentLlmNode).
10980
if (message != null && message.getText() != null && message.getText().startsWith("Exception:")) {
@@ -132,14 +103,14 @@ private ModelResponse handleBlockingRetry(ModelRequest request, ModelCallHandler
132103
}
133104

134105
// For non-retryable exceptions, return immediately.
135-
return currentResponse;
106+
return modelResponse;
136107
}
137108

138109
// Successful response
139110
if (attempt > 1) {
140111
log.info("The model call succeeded after the {}th attempt.", attempt);
141112
}
142-
return currentResponse;
113+
return modelResponse;
143114

144115
} catch (Exception e) {
145116
lastException = e;
@@ -176,69 +147,6 @@ private ModelResponse handleBlockingRetry(ModelRequest request, ModelCallHandler
176147
throw new RuntimeException("Model call failed, maximum number of retries reached. " + maxAttempts, lastException);
177148
}
178149

179-
/**
180-
* Streaming call retry logic using Reactor's retry mechanism.
181-
* Retries are only triggered for retryable exceptions and if no data has been emitted yet.
182-
*/
183-
private ModelResponse handleStreamRetry(ModelRequest request, ModelCallHandler handler, ModelResponse initialResponse) {
184-
// Flag to track whether the stream has emitted any data
185-
AtomicBoolean hasOutput = new AtomicBoolean(false);
186-
AtomicBoolean isFirstAttempt = new AtomicBoolean(true);
187-
AtomicLong currentDelay = new AtomicLong(initialDelay);
188-
189-
Flux<ChatResponse> retryableFlux = Flux.defer(() -> {
190-
// Re-invoke the handler on each subscription (including retries)
191-
if (isFirstAttempt.compareAndSet(true, false)) {
192-
return (Flux<ChatResponse>) initialResponse.getMessage();
193-
}
194-
// Subsequent subscriptions (i.e., retries) will trigger a fresh handler.call
195-
ModelResponse newResponse = handler.call(request);
196-
return (Flux<ChatResponse>) newResponse.getMessage();
197-
})
198-
.doOnNext(data -> {
199-
// Mark as output emitted once the first data chunk is received
200-
if (!hasOutput.get()) {
201-
hasOutput.set(true);
202-
}
203-
})
204-
.retryWhen(Retry.from(signals ->
205-
signals.flatMap(signal -> {
206-
// 1. Get current retry count (starts from 0, so +1 represents the current attempt number)
207-
long attempt = signal.totalRetries() + 1;
208-
Throwable throwable = signal.failure();
209-
210-
// 2. Check if the maximum number of attempts has been reached
211-
if (attempt >= maxAttempts) {
212-
log.error("The maximum number of retries has been reached ({}), and the model call has failed.", maxAttempts);
213-
return Mono.error(throwable);
214-
}
215-
216-
// 3. Business logic: Check if any data has already been emitted
217-
if (hasOutput.get()) {
218-
log.error("Stream failed after partial output. Cannot retry to avoid data duplication.");
219-
return Mono.error(throwable);
220-
}
221-
222-
// 4. Business logic: Determine if the exception is retryable (aligned with blocking retryableExceptionPredicate)
223-
if (!retryableThrowablePredicate.test(throwable)) {
224-
log.warn("Exception is non-retryable and will be thrown immediately: {}", throwable.getMessage());
225-
return Mono.error(throwable);
226-
}
227-
228-
// 5. Calculate aligned backoff delay (consistent with blocking retry logic)
229-
currentDelay.set(Math.min((long) (currentDelay.get() * backoffMultiplier), maxDelay));
230-
231-
log.info("Retrying model call, attempt {}/{}", attempt, maxAttempts);
232-
log.info("Wait for {} ms before next retry", currentDelay.get());
233-
234-
// 6. Generate a delay signal to trigger the retry
235-
return Mono.delay(Duration.ofMillis(currentDelay.get()));
236-
})
237-
));
238-
239-
return new ModelResponse(retryableFlux);
240-
}
241-
242150
/**
243151
* Determine if the exception message indicates a retryable error.
244152
*/
@@ -253,10 +161,6 @@ private boolean isRetryableExceptionMessage(String exceptionText) {
253161
lowerText.contains("socket");
254162
}
255163

256-
private boolean isRetryableException(Throwable e) {
257-
return isRetryableExceptionMessage(e.getMessage());
258-
}
259-
260164
@Override
261165
public String getName() {
262166
return "ModelRetry";
@@ -343,17 +247,17 @@ private static boolean isRetryableException(Exception e) {
343247

344248
// Network-related exceptions
345249
if (lowerMessage.contains("i/o error") ||
346-
lowerMessage.contains("remote host terminated") ||
347-
lowerMessage.contains("connection") ||
348-
lowerMessage.contains("timeout") ||
349-
lowerMessage.contains("handshake") ||
350-
lowerMessage.contains("socket")) {
250+
lowerMessage.contains("remote host terminated") ||
251+
lowerMessage.contains("connection") ||
252+
lowerMessage.contains("timeout") ||
253+
lowerMessage.contains("handshake") ||
254+
lowerMessage.contains("socket")) {
351255
return true;
352256
}
353257

354258
// Spring WebClient related exceptions
355259
if (e.getClass().getName().contains("ResourceAccessException") ||
356-
e.getClass().getName().contains("WebClientRequestException")) {
260+
e.getClass().getName().contains("WebClientRequestException")) {
357261
return true;
358262
}
359263

@@ -362,10 +266,10 @@ private static boolean isRetryableException(Exception e) {
362266
while (cause != null) {
363267
String causeClassName = cause.getClass().getName();
364268
if (causeClassName.contains("IOException") ||
365-
causeClassName.contains("SocketException") ||
366-
causeClassName.contains("ConnectException") ||
367-
causeClassName.contains("TimeoutException") ||
368-
causeClassName.contains("SSLException")) {
269+
causeClassName.contains("SocketException") ||
270+
causeClassName.contains("ConnectException") ||
271+
causeClassName.contains("TimeoutException") ||
272+
causeClassName.contains("SSLException")) {
369273
return true;
370274
}
371275
cause = cause.getCause();

0 commit comments

Comments
 (0)