1919import com .alibaba .cloud .ai .graph .agent .interceptor .ModelInterceptor ;
2020import com .alibaba .cloud .ai .graph .agent .interceptor .ModelRequest ;
2121import com .alibaba .cloud .ai .graph .agent .interceptor .ModelResponse ;
22+
23+ import org .springframework .ai .chat .messages .Message ;
2224import org .slf4j .Logger ;
2325import org .slf4j .LoggerFactory ;
24- import org .springframework .ai .chat .messages .Message ;
25- import org .springframework .ai .chat .model .ChatResponse ;
26- import reactor .core .publisher .Flux ;
27- import reactor .core .publisher .Mono ;
28- import reactor .util .retry .Retry ;
29-
30- import java .time .Duration ;
31- import java .util .concurrent .atomic .AtomicBoolean ;
32- import java .util .concurrent .atomic .AtomicLong ;
26+
27+
3328import java .util .function .Predicate ;
3429
3530/**
@@ -54,15 +49,13 @@ public class ModelRetryInterceptor extends ModelInterceptor {
5449 private final long maxDelay ;
5550 private final double backoffMultiplier ;
5651 private final Predicate <Exception > retryableExceptionPredicate ;
57- private final Predicate <Throwable > retryableThrowablePredicate ;
5852
5953 private ModelRetryInterceptor (Builder builder ) {
6054 this .maxAttempts = builder .maxAttempts ;
6155 this .initialDelay = builder .initialDelay ;
6256 this .maxDelay = builder .maxDelay ;
6357 this .backoffMultiplier = builder .backoffMultiplier ;
6458 this .retryableExceptionPredicate = builder .retryableExceptionPredicate ;
65- this .retryableThrowablePredicate = this ::isRetryableException ;
6659 }
6760
6861 public static Builder builder () {
@@ -71,39 +64,17 @@ public static Builder builder() {
7164
7265 @ Override
7366 public ModelResponse interceptModel (ModelRequest request , ModelCallHandler handler ) {
74-
75- // 1. Initial attempt to call the model
76- ModelResponse response = handler .call (request );
77-
78- // 2. Handle streaming calls (Flux)
79- if (response .getMessage () instanceof Flux <?>) {
80- return handleStreamRetry (request , handler , response );
81- }
82-
83- // 3. Handle blocking calls
84- return handleBlockingRetry (request , handler , response );
85-
86- }
87-
88-
89- /**
90- * Retry logic for blocking (non-streaming) model calls.
91- */
92- private ModelResponse handleBlockingRetry (ModelRequest request , ModelCallHandler handler , ModelResponse initialResponse ) {
9367 Exception lastException = null ;
9468 long currentDelay = initialDelay ;
9569
96- ModelResponse currentResponse = initialResponse ;
97-
9870 for (int attempt = 1 ; attempt <= maxAttempts ; attempt ++) {
9971 try {
100- // If it is not the first iteration (i.e., this is a retry attempt), re-invoke the handler.
10172 if (attempt > 1 ) {
10273 log .info ("Retry model call, on the {}th attempt (out of {} attempts)." , attempt , maxAttempts );
103- currentResponse = handler .call (request );
10474 }
10575
106- Message message = (Message ) currentResponse .getMessage ();
76+ ModelResponse modelResponse = handler .call (request );
77+ Message message = (Message ) modelResponse .getMessage ();
10778
10879 // Check if the response contains any exception information (exceptions captured from AgentLlmNode).
10980 if (message != null && message .getText () != null && message .getText ().startsWith ("Exception:" )) {
@@ -132,14 +103,14 @@ private ModelResponse handleBlockingRetry(ModelRequest request, ModelCallHandler
132103 }
133104
134105 // For non-retryable exceptions, return immediately.
135- return currentResponse ;
106+ return modelResponse ;
136107 }
137108
138109 // Successful response
139110 if (attempt > 1 ) {
140111 log .info ("The model call succeeded after the {}th attempt." , attempt );
141112 }
142- return currentResponse ;
113+ return modelResponse ;
143114
144115 } catch (Exception e ) {
145116 lastException = e ;
@@ -176,69 +147,6 @@ private ModelResponse handleBlockingRetry(ModelRequest request, ModelCallHandler
176147 throw new RuntimeException ("Model call failed, maximum number of retries reached. " + maxAttempts , lastException );
177148 }
178149
179- /**
180- * Streaming call retry logic using Reactor's retry mechanism.
181- * Retries are only triggered for retryable exceptions and if no data has been emitted yet.
182- */
183- private ModelResponse handleStreamRetry (ModelRequest request , ModelCallHandler handler , ModelResponse initialResponse ) {
184- // Flag to track whether the stream has emitted any data
185- AtomicBoolean hasOutput = new AtomicBoolean (false );
186- AtomicBoolean isFirstAttempt = new AtomicBoolean (true );
187- AtomicLong currentDelay = new AtomicLong (initialDelay );
188-
189- Flux <ChatResponse > retryableFlux = Flux .defer (() -> {
190- // Re-invoke the handler on each subscription (including retries)
191- if (isFirstAttempt .compareAndSet (true , false )) {
192- return (Flux <ChatResponse >) initialResponse .getMessage ();
193- }
194- // Subsequent subscriptions (i.e., retries) will trigger a fresh handler.call
195- ModelResponse newResponse = handler .call (request );
196- return (Flux <ChatResponse >) newResponse .getMessage ();
197- })
198- .doOnNext (data -> {
199- // Mark as output emitted once the first data chunk is received
200- if (!hasOutput .get ()) {
201- hasOutput .set (true );
202- }
203- })
204- .retryWhen (Retry .from (signals ->
205- signals .flatMap (signal -> {
206- // 1. Get current retry count (starts from 0, so +1 represents the current attempt number)
207- long attempt = signal .totalRetries () + 1 ;
208- Throwable throwable = signal .failure ();
209-
210- // 2. Check if the maximum number of attempts has been reached
211- if (attempt >= maxAttempts ) {
212- log .error ("The maximum number of retries has been reached ({}), and the model call has failed." , maxAttempts );
213- return Mono .error (throwable );
214- }
215-
216- // 3. Business logic: Check if any data has already been emitted
217- if (hasOutput .get ()) {
218- log .error ("Stream failed after partial output. Cannot retry to avoid data duplication." );
219- return Mono .error (throwable );
220- }
221-
222- // 4. Business logic: Determine if the exception is retryable (aligned with blocking retryableExceptionPredicate)
223- if (!retryableThrowablePredicate .test (throwable )) {
224- log .warn ("Exception is non-retryable and will be thrown immediately: {}" , throwable .getMessage ());
225- return Mono .error (throwable );
226- }
227-
228- // 5. Calculate aligned backoff delay (consistent with blocking retry logic)
229- currentDelay .set (Math .min ((long ) (currentDelay .get () * backoffMultiplier ), maxDelay ));
230-
231- log .info ("Retrying model call, attempt {}/{}" , attempt , maxAttempts );
232- log .info ("Wait for {} ms before next retry" , currentDelay .get ());
233-
234- // 6. Generate a delay signal to trigger the retry
235- return Mono .delay (Duration .ofMillis (currentDelay .get ()));
236- })
237- ));
238-
239- return new ModelResponse (retryableFlux );
240- }
241-
242150 /**
243151 * Determine if the exception message indicates a retryable error.
244152 */
@@ -253,10 +161,6 @@ private boolean isRetryableExceptionMessage(String exceptionText) {
253161 lowerText .contains ("socket" );
254162 }
255163
256- private boolean isRetryableException (Throwable e ) {
257- return isRetryableExceptionMessage (e .getMessage ());
258- }
259-
260164 @ Override
261165 public String getName () {
262166 return "ModelRetry" ;
@@ -343,17 +247,17 @@ private static boolean isRetryableException(Exception e) {
343247
344248 // Network-related exceptions
345249 if (lowerMessage .contains ("i/o error" ) ||
346- lowerMessage .contains ("remote host terminated" ) ||
347- lowerMessage .contains ("connection" ) ||
348- lowerMessage .contains ("timeout" ) ||
349- lowerMessage .contains ("handshake" ) ||
350- lowerMessage .contains ("socket" )) {
250+ lowerMessage .contains ("remote host terminated" ) ||
251+ lowerMessage .contains ("connection" ) ||
252+ lowerMessage .contains ("timeout" ) ||
253+ lowerMessage .contains ("handshake" ) ||
254+ lowerMessage .contains ("socket" )) {
351255 return true ;
352256 }
353257
354258 // Spring WebClient related exceptions
355259 if (e .getClass ().getName ().contains ("ResourceAccessException" ) ||
356- e .getClass ().getName ().contains ("WebClientRequestException" )) {
260+ e .getClass ().getName ().contains ("WebClientRequestException" )) {
357261 return true ;
358262 }
359263
@@ -362,10 +266,10 @@ private static boolean isRetryableException(Exception e) {
362266 while (cause != null ) {
363267 String causeClassName = cause .getClass ().getName ();
364268 if (causeClassName .contains ("IOException" ) ||
365- causeClassName .contains ("SocketException" ) ||
366- causeClassName .contains ("ConnectException" ) ||
367- causeClassName .contains ("TimeoutException" ) ||
368- causeClassName .contains ("SSLException" )) {
269+ causeClassName .contains ("SocketException" ) ||
270+ causeClassName .contains ("ConnectException" ) ||
271+ causeClassName .contains ("TimeoutException" ) ||
272+ causeClassName .contains ("SSLException" )) {
369273 return true ;
370274 }
371275 cause = cause .getCause ();
0 commit comments