|
1 | | -// Copyright AGNTCY Contributors (https://github.com/agntcy) |
2 | | -// SPDX-License-Identifier: Apache-2.0 |
3 | | - |
4 | | -package main |
5 | | - |
6 | | -import ( |
7 | | - "context" |
8 | | - "encoding/json" |
9 | | - "fmt" |
10 | | - "io" |
11 | | - "net/http" |
12 | | - "net/url" |
13 | | - "path/filepath" |
14 | | - |
15 | | - "github.com/TykTechnologies/tyk/ctx" |
16 | | - "github.com/TykTechnologies/tyk/user" |
17 | | - "github.com/kelindar/search" |
18 | | -) |
19 | | - |
20 | | -const ( |
21 | | - DEFAULT_THRESHOLD = 0.5 |
22 | | -) |
23 | | - |
24 | | -type ACPPluginApiConfig struct { |
25 | | - APIName string `json:"name"` |
26 | | - Target string `json:"url"` |
27 | | - Utterances []string `json:"utterances"` |
28 | | -} |
29 | | - |
30 | | -type ACPPluginData struct { |
31 | | - ACPPluginServices map[string]ACPPluginApiConfig |
32 | | - ModelPath string |
33 | | - ModelEmbedder *search.Vectorizer |
34 | | - ModelIndex *search.Index[string] |
35 | | - StoreVersion int64 |
36 | | - MaxRequestLength int64 `json:"maxRequestLength"` // MaxRequestSize is the maximum size of the request in characters; default is -1 (no limit) |
37 | | -} |
38 | | - |
39 | | -var acpPluginData = ACPPluginData{ |
40 | | - ACPPluginServices: map[string]ACPPluginApiConfig{}, |
41 | | -} |
42 | | - |
43 | | -// SetContext updates the context of a request. |
44 | | -func SetContext(r *http.Request, ctx context.Context) { |
45 | | - r2 := r.WithContext(ctx) |
46 | | - *r = *r2 |
47 | | -} |
48 | | - |
49 | | -func ProcessACPQuery(rw http.ResponseWriter, r *http.Request) { |
50 | | - logger.Debug("[+] Inside ProcessACPQuery -->") |
51 | | - |
52 | | - if len(acpPluginData.ACPPluginServices) == 0 || acpPluginData.StoreVersion != storeVersion { |
53 | | - logger.Infof("[+] ACP plugin config is empty or store version has changed, reloading ...") |
54 | | - err := initACPPluginApiConfig() |
55 | | - if err != nil { |
56 | | - logger.Errorf("[+] Error while getting the ACP plugin config: %s", err) |
57 | | - http.Error(rw, INTERNAL_ERROR_MSG, http.StatusInternalServerError) |
58 | | - return |
59 | | - } |
60 | | - } |
61 | | - |
62 | | - // Only proceed for POST with Content-Type: application/nlq (parameters are allowed) |
63 | | - if r.Method != http.MethodPost || !isNLQContentType(r.Header.Get("Content-Type")) { |
64 | | - logger.Debugf("[+] Query is not POST or Content-Type is not %s, ignoring ...", CONTENT_TYPE_NLQ) |
65 | | - return |
66 | | - } |
67 | | - |
68 | | - if acpPluginData.MaxRequestLength > 0 && r.ContentLength > acpPluginData.MaxRequestLength { |
69 | | - logger.Debugf("[+] Query is too large, ignoring ...") |
70 | | - http.Error(rw, "Query is too large", http.StatusRequestEntityTooLarge) |
71 | | - return |
72 | | - } |
73 | | - |
74 | | - nlqBytes, err := io.ReadAll(r.Body) |
75 | | - if err != nil { |
76 | | - logger.Errorf("[+] Error while reading the body: %s", err) |
77 | | - http.Error(rw, INTERNAL_ERROR_MSG, http.StatusInternalServerError) |
78 | | - return |
79 | | - } |
80 | | - nlq := string(nlqBytes) |
81 | | - |
82 | | - session := &user.SessionState{ |
83 | | - MetaData: map[string]any{ |
84 | | - METADATA_NLQ: string(nlq), |
85 | | - METADATA_RESPONSE_TYPE: RESPONSE_TYPE_NL, |
86 | | - }, |
87 | | - } |
88 | | - ctx.SetSession(r, session, true) |
89 | | - |
90 | | - service, err := findACPServiceFromQuery(nlq) |
91 | | - if err != nil { |
92 | | - logger.Errorf("[+] Failed to find a service for query: %s", nlq) |
93 | | - http.Error(rw, INTERNAL_ERROR_MSG, http.StatusInternalServerError) |
94 | | - return |
95 | | - } |
96 | | - logger.Debugf("[+] Found a service (%v) for query=%v", service, nlq) |
97 | | - |
98 | | - u, err := url.Parse(service) |
99 | | - if err != nil { |
100 | | - logger.Errorf("[+] Error while parsing the service URL (%v): %s", service, err) |
101 | | - http.Error(rw, INTERNAL_ERROR_MSG, http.StatusInternalServerError) |
102 | | - return |
103 | | - } |
104 | | - logger.Debugf("[+] redirect query to: %v ", u) |
105 | | - |
106 | | - rctx := r.Context() |
107 | | - rctx = context.WithValue(rctx, ctx.UrlRewriteTarget, u) |
108 | | - SetContext(r, rctx) |
109 | | -} |
110 | | - |
111 | | -func initACPPluginApiConfig() error { |
112 | | - // Clear existing map |
113 | | - acpPluginData.ACPPluginServices = make(map[string]ACPPluginApiConfig) |
114 | | - |
115 | | - // save the current version of the store BEFORE retreiving the data |
116 | | - acpPluginData.StoreVersion = storeVersion |
117 | | - logger.Debugf("[+] Loading ACP plugin config version (%v) ...", acpPluginData.StoreVersion) |
118 | | - |
119 | | - acpPluginData.MaxRequestLength = int64(getEnvAsInt("MAX_REQUEST_SIZE", DEFAULT_MAX_REQUEST_SIZE)) |
120 | | - |
121 | | - // Get All APIs keys and values from Redis |
122 | | - apiKeysValues := agentBridgeStore.GetKeysAndValuesWithFilter("*") |
123 | | - if apiKeysValues == nil { |
124 | | - logger.Error("[+] Error while getting the keys and values from Redis") |
125 | | - return fmt.Errorf("error while getting the keys and values from Redis") |
126 | | - } |
127 | | - // Refresh config |
128 | | - for key, value := range apiKeysValues { |
129 | | - logger.Debugf("[+] Found key: '%s', with value: '%s'", key, value) |
130 | | - apiConfig := ACPPluginApiConfig{} |
131 | | - err := json.Unmarshal([]byte(value), &apiConfig) |
132 | | - if err != nil { |
133 | | - logger.Fatalf("[+] conversion error for acpPluginConfig: %s", err) |
134 | | - } |
135 | | - acpPluginData.ACPPluginServices[apiConfig.APIName] = apiConfig |
136 | | - } |
137 | | - |
138 | | - return nil |
139 | | -} |
140 | | - |
141 | | -func findACPServiceFromQuery(query string) (string, error) { |
142 | | - logger.Debugf("[+] Process query=%v <--", query) |
143 | | - |
144 | | - if acpPluginData.ModelEmbedder == nil { |
145 | | - var err error |
146 | | - acpPluginData.ModelPath = filepath.Join(DEFAULT_MODEL_EMBEDDINGS_PATH, DEFAULT_MODEL_EMBEDDINGS_MODEL) |
147 | | - acpPluginData.ModelEmbedder, err = search.NewVectorizer(acpPluginData.ModelPath, 1) |
148 | | - if err != nil { |
149 | | - return "", fmt.Errorf("[+] Unable to find embedding model %s: %s", acpPluginData.ModelPath, err) |
150 | | - } |
151 | | - acpPluginData.ModelIndex = search.NewIndex[string]() |
152 | | - for _, service := range acpPluginData.ACPPluginServices { |
153 | | - for _, utterance := range service.Utterances { |
154 | | - embedding, err := acpPluginData.ModelEmbedder.EmbedText(utterance) |
155 | | - if err != nil { |
156 | | - return "", fmt.Errorf("[+] embedding model %s failed for text \"%s\": %s", acpPluginData.ModelPath, utterance, err) |
157 | | - } |
158 | | - acpPluginData.ModelIndex.Add(embedding, service.Target) |
159 | | - } |
160 | | - } |
161 | | - } |
162 | | - if acpPluginData.ModelEmbedder == nil || acpPluginData.ModelIndex == nil { |
163 | | - return "", fmt.Errorf("[+] ModelEmbedder or ModelIndex is nil") |
164 | | - } |
165 | | - |
166 | | - embedding, err := acpPluginData.ModelEmbedder.EmbedText(query) |
167 | | - if err != nil { |
168 | | - return "", fmt.Errorf("[+] embedding model %s failed for query \"%s\": %s", acpPluginData.ModelPath, query, err) |
169 | | - } |
170 | | - results := acpPluginData.ModelIndex.Search(embedding, NBRESULT) |
171 | | - if len(results) == 0 { |
172 | | - return "", fmt.Errorf("[+] No service found for query \"%s\": %s", query, err) |
173 | | - } else if NBRESULT > 1 { |
174 | | - for index, result := range results { |
175 | | - logger.Debugf("Result %d: %v / %v\n", index, result.Value, result.Relevance) |
176 | | - } |
177 | | - } |
178 | | - if results[0].Relevance < DEFAULT_THRESHOLD { |
179 | | - return "", fmt.Errorf("[+] No valid service found for query \"%s\": %s", query, err) |
180 | | - } |
181 | | - |
182 | | - return results[0].Value, nil |
183 | | -} |
184 | | - |
185 | | -func init() { |
186 | | - logger.Infof("[+] Initializing API Bridge Agnt plugin (ACP)...") |
187 | | - |
188 | | - // Init Redis store, if needed |
189 | | - if agentBridgeStore == nil { |
190 | | - agentBridgeStore = getStorageForPlugin(context.TODO()) |
191 | | - } |
192 | | - |
193 | | -} |
| 1 | +// Copyright AGNTCY Contributors (https://github.com/agntcy) |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +package main |
| 5 | + |
| 6 | +import ( |
| 7 | + "context" |
| 8 | + "encoding/json" |
| 9 | + "fmt" |
| 10 | + "io" |
| 11 | + "net/http" |
| 12 | + "net/url" |
| 13 | + "path/filepath" |
| 14 | + |
| 15 | + "github.com/TykTechnologies/tyk/ctx" |
| 16 | + "github.com/TykTechnologies/tyk/user" |
| 17 | + "github.com/kelindar/search" |
| 18 | +) |
| 19 | + |
| 20 | +const ( |
| 21 | + DEFAULT_THRESHOLD = 0.5 |
| 22 | +) |
| 23 | + |
| 24 | +type ACPPluginApiConfig struct { |
| 25 | + APIName string `json:"name"` |
| 26 | + Target string `json:"url"` |
| 27 | + Utterances []string `json:"utterances"` |
| 28 | +} |
| 29 | + |
| 30 | +type ACPPluginData struct { |
| 31 | + ACPPluginServices map[string]ACPPluginApiConfig |
| 32 | + ModelPath string |
| 33 | + ModelEmbedder *search.Vectorizer |
| 34 | + ModelIndex *search.Index[string] |
| 35 | + StoreVersion int64 |
| 36 | + MaxRequestLength int64 `json:"maxRequestLength"` // MaxRequestSize is the maximum size of the request in characters; default is -1 (no limit) |
| 37 | +} |
| 38 | + |
| 39 | +var acpPluginData = ACPPluginData{ |
| 40 | + ACPPluginServices: map[string]ACPPluginApiConfig{}, |
| 41 | +} |
| 42 | + |
| 43 | +// SetContext updates the context of a request. |
| 44 | +func SetContext(r *http.Request, ctx context.Context) { |
| 45 | + r2 := r.WithContext(ctx) |
| 46 | + *r = *r2 |
| 47 | +} |
| 48 | + |
| 49 | +func ProcessACPQuery(rw http.ResponseWriter, r *http.Request) { |
| 50 | + logger.Debug("[+] Inside ProcessACPQuery -->") |
| 51 | + |
| 52 | + if len(acpPluginData.ACPPluginServices) == 0 || acpPluginData.StoreVersion != storeVersion { |
| 53 | + logger.Infof("[+] ACP plugin config is empty or store version has changed, reloading ...") |
| 54 | + err := initACPPluginApiConfig() |
| 55 | + if err != nil { |
| 56 | + logger.Errorf("[+] Error while getting the ACP plugin config: %s", err) |
| 57 | + http.Error(rw, INTERNAL_ERROR_MSG, http.StatusInternalServerError) |
| 58 | + return |
| 59 | + } |
| 60 | + } |
| 61 | + |
| 62 | + // Only proceed for POST with Content-Type: application/nlq (parameters are allowed) |
| 63 | + if r.Method != http.MethodPost || !isNLQContentType(r.Header.Get("Content-Type")) { |
| 64 | + logger.Debugf("[+] Query is not POST or Content-Type is not %s, ignoring ...", CONTENT_TYPE_NLQ) |
| 65 | + return |
| 66 | + } |
| 67 | + |
| 68 | + if acpPluginData.MaxRequestLength > 0 && r.ContentLength > acpPluginData.MaxRequestLength { |
| 69 | + logger.Debugf("[+] Query is too large, ignoring ...") |
| 70 | + http.Error(rw, "Query is too large", http.StatusRequestEntityTooLarge) |
| 71 | + return |
| 72 | + } |
| 73 | + |
| 74 | + nlqBytes, err := io.ReadAll(r.Body) |
| 75 | + if err != nil { |
| 76 | + logger.Errorf("[+] Error while reading the body: %s", err) |
| 77 | + http.Error(rw, INTERNAL_ERROR_MSG, http.StatusInternalServerError) |
| 78 | + return |
| 79 | + } |
| 80 | + nlq := string(nlqBytes) |
| 81 | + |
| 82 | + session := &user.SessionState{ |
| 83 | + MetaData: map[string]any{ |
| 84 | + METADATA_NLQ: string(nlq), |
| 85 | + METADATA_RESPONSE_TYPE: RESPONSE_TYPE_NL, |
| 86 | + }, |
| 87 | + } |
| 88 | + ctx.SetSession(r, session, true) |
| 89 | + |
| 90 | + service, err := findACPServiceFromQuery(nlq) |
| 91 | + if err != nil { |
| 92 | + logger.Errorf("[+] Failed to find a service for query: %s", nlq) |
| 93 | + http.Error(rw, INTERNAL_ERROR_MSG, http.StatusInternalServerError) |
| 94 | + return |
| 95 | + } |
| 96 | + logger.Debugf("[+] Found a service (%v) for query=%v", service, nlq) |
| 97 | + |
| 98 | + u, err := url.Parse(service) |
| 99 | + if err != nil { |
| 100 | + logger.Errorf("[+] Error while parsing the service URL (%v): %s", service, err) |
| 101 | + http.Error(rw, INTERNAL_ERROR_MSG, http.StatusInternalServerError) |
| 102 | + return |
| 103 | + } |
| 104 | + logger.Debugf("[+] redirect query to: %v ", u) |
| 105 | + |
| 106 | + rctx := r.Context() |
| 107 | + rctx = context.WithValue(rctx, ctx.UrlRewriteTarget, u) |
| 108 | + SetContext(r, rctx) |
| 109 | +} |
| 110 | + |
| 111 | +func initACPPluginApiConfig() error { |
| 112 | + // Clear existing map |
| 113 | + acpPluginData.ACPPluginServices = make(map[string]ACPPluginApiConfig) |
| 114 | + |
| 115 | + // save the current version of the store BEFORE retreiving the data |
| 116 | + acpPluginData.StoreVersion = storeVersion |
| 117 | + logger.Debugf("[+] Loading ACP plugin config version (%v) ...", acpPluginData.StoreVersion) |
| 118 | + |
| 119 | + acpPluginData.MaxRequestLength = int64(getEnvAsInt("MAX_REQUEST_SIZE", DEFAULT_MAX_REQUEST_SIZE)) |
| 120 | + |
| 121 | + // Get All APIs keys and values from Redis |
| 122 | + apiKeysValues := agentBridgeStore.GetKeysAndValuesWithFilter("*") |
| 123 | + if apiKeysValues == nil { |
| 124 | + logger.Error("[+] Error while getting the keys and values from Redis") |
| 125 | + return fmt.Errorf("error while getting the keys and values from Redis") |
| 126 | + } |
| 127 | + // Refresh config |
| 128 | + for key, value := range apiKeysValues { |
| 129 | + logger.Debugf("[+] Found key: '%s', with value: '%s'", key, value) |
| 130 | + apiConfig := ACPPluginApiConfig{} |
| 131 | + err := json.Unmarshal([]byte(value), &apiConfig) |
| 132 | + if err != nil { |
| 133 | + logger.Fatalf("[+] conversion error for acpPluginConfig: %s", err) |
| 134 | + } |
| 135 | + acpPluginData.ACPPluginServices[apiConfig.APIName] = apiConfig |
| 136 | + } |
| 137 | + |
| 138 | + return nil |
| 139 | +} |
| 140 | + |
| 141 | +func findACPServiceFromQuery(query string) (string, error) { |
| 142 | + logger.Debugf("[+] Process query=%v <--", query) |
| 143 | + |
| 144 | + if acpPluginData.ModelEmbedder == nil { |
| 145 | + var err error |
| 146 | + acpPluginData.ModelPath = filepath.Join(DEFAULT_MODEL_EMBEDDINGS_PATH, DEFAULT_MODEL_EMBEDDINGS_MODEL) |
| 147 | + acpPluginData.ModelEmbedder, err = search.NewVectorizer(acpPluginData.ModelPath, 1) |
| 148 | + if err != nil { |
| 149 | + return "", fmt.Errorf("[+] Unable to find embedding model %s: %s", acpPluginData.ModelPath, err) |
| 150 | + } |
| 151 | + acpPluginData.ModelIndex = search.NewIndex[string]() |
| 152 | + for _, service := range acpPluginData.ACPPluginServices { |
| 153 | + for _, utterance := range service.Utterances { |
| 154 | + embedding, err := acpPluginData.ModelEmbedder.EmbedText(utterance) |
| 155 | + if err != nil { |
| 156 | + return "", fmt.Errorf("[+] embedding model %s failed for text \"%s\": %s", acpPluginData.ModelPath, utterance, err) |
| 157 | + } |
| 158 | + acpPluginData.ModelIndex.Add(embedding, service.Target) |
| 159 | + } |
| 160 | + } |
| 161 | + } |
| 162 | + if acpPluginData.ModelEmbedder == nil || acpPluginData.ModelIndex == nil { |
| 163 | + return "", fmt.Errorf("[+] ModelEmbedder or ModelIndex is nil") |
| 164 | + } |
| 165 | + |
| 166 | + embedding, err := acpPluginData.ModelEmbedder.EmbedText(query) |
| 167 | + if err != nil { |
| 168 | + return "", fmt.Errorf("[+] embedding model %s failed for query \"%s\": %s", acpPluginData.ModelPath, query, err) |
| 169 | + } |
| 170 | + results := acpPluginData.ModelIndex.Search(embedding, NBRESULT) |
| 171 | + if len(results) == 0 { |
| 172 | + return "", fmt.Errorf("[+] No service found for query \"%s\": %s", query, err) |
| 173 | + } else if NBRESULT > 1 { |
| 174 | + for index, result := range results { |
| 175 | + logger.Debugf("Result %d: %v / %v\n", index, result.Value, result.Relevance) |
| 176 | + } |
| 177 | + } |
| 178 | + if results[0].Relevance < DEFAULT_THRESHOLD { |
| 179 | + return "", fmt.Errorf("[+] No valid service found for query \"%s\": %s", query, err) |
| 180 | + } |
| 181 | + |
| 182 | + return results[0].Value, nil |
| 183 | +} |
| 184 | + |
| 185 | +func init() { |
| 186 | + logger.Infof("[+] Initializing API Bridge Agnt plugin (ACP)...") |
| 187 | + |
| 188 | + // Init Redis store, if needed |
| 189 | + if agentBridgeStore == nil { |
| 190 | + agentBridgeStore = getStorageForPlugin(context.TODO()) |
| 191 | + } |
| 192 | + |
| 193 | +} |
0 commit comments