|
1 | 1 | package callback |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "errors" |
4 | 5 | "net/http" |
5 | 6 |
|
| 7 | + "github.com/nexus-rpc/sdk-go/nexus" |
| 8 | + "go.temporal.io/api/serviceerror" |
6 | 9 | "go.temporal.io/server/common" |
7 | 10 | "go.temporal.io/server/common/cluster" |
8 | 11 | "go.temporal.io/server/common/log" |
9 | 12 | "go.temporal.io/server/common/log/tag" |
10 | | - "go.temporal.io/server/common/nexus" |
| 13 | + "go.temporal.io/server/common/namespace" |
| 14 | + commonnexus "go.temporal.io/server/common/nexus" |
11 | 15 | ) |
12 | 16 |
|
13 | 17 | // Header key used to identify callbacks that originate from and target the same cluster. |
14 | 18 | // Note: this is the nexusoperations.NexusCallbackSourceHeader stripped of Nexus-Callback- |
15 | 19 | const callbackSourceHeader = "source" |
16 | 20 |
|
| 21 | +// routeSystemCallbackRequest routes a system callback request to the appropriate frontend client |
| 22 | +// based on the callback token's namespace and active cluster. |
| 23 | +func routeSystemCallbackRequest( |
| 24 | + r *http.Request, |
| 25 | + clusterMetadata cluster.Metadata, |
| 26 | + namespaceRegistry namespace.Registry, |
| 27 | + httpClientCache *cluster.FrontendHTTPClientCache, |
| 28 | + callbackTokenGenerator *commonnexus.CallbackTokenGenerator, |
| 29 | + localClient *common.FrontendHTTPClient, |
| 30 | + logger log.Logger, |
| 31 | +) (*http.Response, error) { |
| 32 | + var frontendClient *common.FrontendHTTPClient |
| 33 | + if r.Header != nil { |
| 34 | + token, err := commonnexus.DecodeCallbackToken(r.Header.Get(commonnexus.CallbackTokenHeader)) |
| 35 | + if err != nil { |
| 36 | + logger.Error("failed to decode callback token", tag.Error(err)) |
| 37 | + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") |
| 38 | + } |
| 39 | + |
| 40 | + completion, err := callbackTokenGenerator.DecodeCompletion(token) |
| 41 | + if err != nil { |
| 42 | + logger.Error("failed to decode completion from token", tag.Error(err)) |
| 43 | + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") |
| 44 | + } |
| 45 | + ns, err := namespaceRegistry.GetNamespaceByID(namespace.ID(completion.NamespaceId)) |
| 46 | + if err != nil { |
| 47 | + logger.Error("failed to get namespace for nexus completion request", tag.WorkflowNamespaceID(completion.NamespaceId), tag.Error(err)) |
| 48 | + var nfe *serviceerror.NamespaceNotFound |
| 49 | + if errors.As(err, &nfe) { |
| 50 | + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace %q not found", completion.NamespaceId) |
| 51 | + } |
| 52 | + return nil, commonnexus.ConvertGRPCError(err, false) |
| 53 | + } |
| 54 | + clusterName := ns.ActiveClusterName(completion.GetWorkflowId()) |
| 55 | + if clusterMetadata.GetCurrentClusterName() == clusterName { |
| 56 | + frontendClient = localClient |
| 57 | + } else { |
| 58 | + fec, err := httpClientCache.Get(clusterName) |
| 59 | + if err != nil { |
| 60 | + logger.Warn( |
| 61 | + "HTTPCallerProvider unable to get FrontendHTTPClient for callback target cluster. Using local HTTP Client.", |
| 62 | + tag.SourceCluster(clusterMetadata.GetCurrentClusterName()), |
| 63 | + tag.TargetCluster(clusterName), |
| 64 | + tag.Error(err), |
| 65 | + ) |
| 66 | + frontendClient = localClient |
| 67 | + } else { |
| 68 | + frontendClient = fec |
| 69 | + } |
| 70 | + } |
| 71 | + } else { |
| 72 | + frontendClient = localClient |
| 73 | + } |
| 74 | + r.URL.Path = commonnexus.PathCompletionCallbackNoIdentifier |
| 75 | + r.URL.Scheme = frontendClient.Scheme |
| 76 | + r.URL.Host = frontendClient.Address |
| 77 | + r.Host = frontendClient.Address |
| 78 | + return frontendClient.Do(r) |
| 79 | +} |
| 80 | + |
17 | 81 | func routeRequest( |
18 | 82 | r *http.Request, |
19 | 83 | clusterMetadata cluster.Metadata, |
| 84 | + namespaceRegistry namespace.Registry, |
20 | 85 | httpClientCache *cluster.FrontendHTTPClientCache, |
| 86 | + callbackTokenGenerator *commonnexus.CallbackTokenGenerator, |
21 | 87 | defaultClient *http.Client, |
22 | 88 | localClient *common.FrontendHTTPClient, |
23 | 89 | logger log.Logger, |
24 | 90 | ) (*http.Response, error) { |
| 91 | + if r.URL.String() == commonnexus.SystemCallbackURL { |
| 92 | + return routeSystemCallbackRequest(r, clusterMetadata, namespaceRegistry, httpClientCache, callbackTokenGenerator, localClient, logger) |
| 93 | + } |
25 | 94 | // This source header is populated in nexusoperations/executors (via the ClientProvider) for worker targets |
26 | | - // if this header is not populated then we assume it's and external target. |
| 95 | + // if this header is not populated then we assume it's an external target. |
27 | 96 | if r.Header == nil || r.Header.Get(callbackSourceHeader) == "" { |
28 | 97 | return defaultClient.Do(r) |
29 | 98 | } |
@@ -61,9 +130,6 @@ func routeRequest( |
61 | 130 | frontendClient = localClient |
62 | 131 | } |
63 | 132 |
|
64 | | - if r.URL.String() == nexus.SystemCallbackURL { |
65 | | - r.URL.Path = nexus.PathCompletionCallbackNoIdentifier |
66 | | - } |
67 | 133 | r.URL.Scheme = frontendClient.Scheme |
68 | 134 | r.URL.Host = frontendClient.Address |
69 | 135 | r.Host = frontendClient.Address |
|
0 commit comments