|
| 1 | +// |
| 2 | +// Copyright Amazon.com Inc. or its affiliates. |
| 3 | +// All Rights Reserved. |
| 4 | +// |
| 5 | +// SPDX-License-Identifier: Apache-2.0 |
| 6 | +// |
| 7 | + |
| 8 | +public struct RetryerMiddleware<Output: HttpResponseBinding, |
| 9 | + OutputError: HttpResponseBinding>: Middleware { |
| 10 | + |
| 11 | + public var id: String = "Retryer" |
| 12 | + |
| 13 | + let retryer: SDKRetryer |
| 14 | + |
| 15 | + public init(retryer: SDKRetryer) { |
| 16 | + self.retryer = retryer |
| 17 | + } |
| 18 | + |
| 19 | + public func handle<H>( |
| 20 | + context: Context, |
| 21 | + input: SdkHttpRequestBuilder, |
| 22 | + next: H |
| 23 | + ) async throws -> OperationOutput<Output> where |
| 24 | + H: Handler, |
| 25 | + Self.MInput == H.Input, |
| 26 | + Self.MOutput == H.Output, |
| 27 | + Self.Context == H.Context { |
| 28 | + |
| 29 | + // Select a partition ID to be used for throttling retry requests. Requests with the |
| 30 | + // same partition ID will be "pooled" together for throttling purposes. |
| 31 | + let partitionID: String |
| 32 | + if let customPartitionID = context.getPartitionID(), !customPartitionID.isEmpty { |
| 33 | + // use custom partition ID provided by context |
| 34 | + partitionID = customPartitionID |
| 35 | + } else if !input.host.isEmpty { |
| 36 | + // fall back to the hostname for partition ID, which is a "commonsense" default |
| 37 | + partitionID = input.host |
| 38 | + } else { |
| 39 | + throw SdkError<OutputError>.client(ClientError.unknownError("Partition ID could not be determined")) |
| 40 | + } |
| 41 | + |
| 42 | + do { |
| 43 | + let token = try await retryer.acquireToken(partitionId: partitionID) |
| 44 | + return try await tryRequest( |
| 45 | + token: token, |
| 46 | + partitionID: partitionID, |
| 47 | + context: context, |
| 48 | + input: input, |
| 49 | + next: next |
| 50 | + ) |
| 51 | + } catch { |
| 52 | + throw SdkError<OutputError>.client(ClientError.retryError(error)) |
| 53 | + } |
| 54 | + } |
| 55 | + |
| 56 | + func tryRequest<H>( |
| 57 | + token: RetryToken, |
| 58 | + errorType: RetryError? = nil, |
| 59 | + partitionID: String, |
| 60 | + context: Context, |
| 61 | + input: SdkHttpRequestBuilder, |
| 62 | + next: H |
| 63 | + ) async throws -> OperationOutput<Output> where |
| 64 | + H: Handler, |
| 65 | + Self.MInput == H.Input, |
| 66 | + Self.MOutput == H.Output, |
| 67 | + Self.Context == H.Context { |
| 68 | + |
| 69 | + do { |
| 70 | + let serviceResponse = try await next.handle(context: context, input: input) |
| 71 | + retryer.recordSuccess(token: token) |
| 72 | + return serviceResponse |
| 73 | + } catch let error as SdkError<OutputError> where retryer.isErrorRetryable(error: error) { |
| 74 | + let errorType = retryer.getErrorType(error: error) |
| 75 | + let newToken = try await retryer.scheduleRetry(token: token, error: errorType) |
| 76 | + // TODO: rewind the stream once streaming is properly implemented |
| 77 | + return try await tryRequest( |
| 78 | + token: newToken, |
| 79 | + partitionID: partitionID, |
| 80 | + context: context, |
| 81 | + input: input, |
| 82 | + next: next |
| 83 | + ) |
| 84 | + } |
| 85 | + } |
| 86 | + |
| 87 | + public typealias MInput = SdkHttpRequestBuilder |
| 88 | + public typealias MOutput = OperationOutput<Output> |
| 89 | + public typealias Context = HttpContext |
| 90 | +} |
0 commit comments