Skip to content

Commit 1b1b164

Browse files
authored
fix: add patch for search-release v0.3.0 and v0.4.0 to fix uBatch value for model (#42)
* fix: add patch for search-release v0.3.0 and v0.4.0 to fix uBatch value for model Signed-off-by: ataldir <ataldir@cisco.com> * chore: add a MAX_REQUEST_LEN parameter Signed-off-by: ataldir <ataldir@cisco.com> * chore: rewrite two conditions Signed-off-by: ataldir <ataldir@cisco.com> --------- Signed-off-by: ataldir <ataldir@cisco.com>
1 parent 85300f9 commit 1b1b164

8 files changed

+80
-3
lines changed

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ search-release-$(SEARCH_VERSION)/README.md: download_models_for_semrouter
7070
cd "search-release-$(SEARCH_VERSION)" && git lfs install --local
7171
cd "search-release-$(SEARCH_VERSION)" && git submodule update --init --recursive
7272
cd "search-release-$(SEARCH_VERSION)" && git lfs pull
73+
cd "search-release-$(SEARCH_VERSION)" && git apply --ignore-space-change --ignore-whitespace --reject --quiet --whitespace=fix ../patches/0001-fix-ubatch-in-search-release-$(SEARCH_VERSION).patch || true
7374

7475
build_search_lib search-release-$(SEARCH_VERSION)/build/lib/$(SEARCH_LIB): search-release-$(SEARCH_VERSION)/README.md
7576
-mkdir -p "search-release-$(SEARCH_VERSION)/build"

go.work.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,8 @@ github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/
253253
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
254254
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
255255
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
256+
github.com/kelindar/search v0.3.0 h1:NY6JkbC8Wy3vujwIR/p9NW3xaVDZxFrE+JQECiFzwW0=
257+
github.com/kelindar/search v0.3.0/go.mod h1:7goLXnzQ6b0vMJr9SKWmABblTrJgPG3rVwp3yz2Lo5Q=
256258
github.com/kelindar/search v0.4.0 h1:mj3U26qB6BQJr9/6Q1vHS/I40CQ6Mhb5iJfmfI2e/mY=
257259
github.com/kelindar/search v0.4.0/go.mod h1:7goLXnzQ6b0vMJr9SKWmABblTrJgPG3rVwp3yz2Lo5Q=
258260
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
diff --git a/llama-go.cpp/llama-go.cpp b/llama-go.cpp/llama-go.cpp
2+
index 4670be1..53d2cec 100644
3+
--- a/llama-go.cpp/llama-go.cpp
4+
+++ b/llama-go.cpp/llama-go.cpp
5+
@@ -100,6 +100,8 @@ extern "C" {
6+
struct llama_context_params params = llama_context_default_params();
7+
params.n_ctx = ctx_size;
8+
params.embeddings = embeddings;
9+
+ printf("[+] Set context param n_ubatch to 2048\n");
10+
+ params.n_ubatch = 2048;
11+
return llama_new_context_with_model(model, params);
12+
}
13+
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
diff --git a/llama-go.cpp/llama-go.cpp b/llama-go.cpp/llama-go.cpp
2+
index 79c5ab3..c4e90d1 100644
3+
--- a/llama-go.cpp/llama-go.cpp
4+
+++ b/llama-go.cpp/llama-go.cpp
5+
@@ -100,7 +100,8 @@ extern "C" {
6+
struct llama_context_params params = llama_context_default_params();
7+
params.n_ctx = ctx_size;
8+
params.embeddings = embeddings; // Corrected field name
9+
-
10+
+ printf("[+] Set context param n_ubatch to 2048\n");
11+
+ params.n_ubatch = 2048;
12+
return llama_init_from_model(model, params); // Updated function
13+
}
14+

plugins/agent_bridge.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ func SelectAndRewrite(rw http.ResponseWriter, r *http.Request) {
8787
return
8888
}
8989

90+
if apiConfig.MaxRequestLength > 0 && r.ContentLength > apiConfig.MaxRequestLength {
91+
logger.Debugf("[+] Query is too large, ignoring ...")
92+
http.Error(rw, "Query is too large", http.StatusRequestEntityTooLarge)
93+
return
94+
}
95+
9096
nlqBytes, err := io.ReadAll(r.Body)
9197
if err != nil {
9298
logger.Errorf("[+] Error while reading the body: %s", err)

plugins/agent_bridge_acp.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ type ACPPluginData struct {
3333
ModelEmbedder *search.Vectorizer
3434
ModelIndex *search.Index[string]
3535
StoreVersion int64
36+
MaxRequestLength int64 `json:"maxRequestLength"` // MaxRequestLength is the maximum request size in bytes, compared against the request's Content-Length; values <= 0 disable the check (default is -1, no limit)
3637
}
3738

3839
var acpPluginData = ACPPluginData{
@@ -64,6 +65,12 @@ func ProcessACPQuery(rw http.ResponseWriter, r *http.Request) {
6465
return
6566
}
6667

68+
if acpPluginData.MaxRequestLength > 0 && r.ContentLength > acpPluginData.MaxRequestLength {
69+
logger.Debugf("[+] Query is too large, ignoring ...")
70+
http.Error(rw, "Query is too large", http.StatusRequestEntityTooLarge)
71+
return
72+
}
73+
6774
nlqBytes, err := io.ReadAll(r.Body)
6875
if err != nil {
6976
logger.Errorf("[+] Error while reading the body: %s", err)
@@ -104,9 +111,13 @@ func ProcessACPQuery(rw http.ResponseWriter, r *http.Request) {
104111
func initACPPluginApiConfig() error {
105112
// Clear existing map
106113
acpPluginData.ACPPluginServices = make(map[string]ACPPluginApiConfig)
114+
107115
// save the current version of the store BEFORE retrieving the data
108116
acpPluginData.StoreVersion = storeVersion
109117
logger.Debugf("[+] Loading ACP plugin config version (%v) ...", acpPluginData.StoreVersion)
118+
119+
acpPluginData.MaxRequestLength = int64(getEnvAsInt("MAX_REQUEST_SIZE", DEFAULT_MAX_REQUEST_SIZE))
120+
110121
// Get All APIs keys and values from Redis
111122
apiKeysValues := agentBridgeStore.GetKeysAndValuesWithFilter("*")
112123
if apiKeysValues == nil {

plugins/agent_bridge_config.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99
"os"
1010
"path/filepath"
11+
"strconv"
1112
"sync"
1213

1314
"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
@@ -24,8 +25,9 @@ const (
2425
DEFAULT_OPENAI_ENDPOINT = "https://api.openai.com/v1"
2526
DEFAULT_OPENAI_MODEL = "gpt-4o-mini"
2627

27-
MAX_UTERANCE_LENGTH = 1500
28-
VECTORIZER_GPU_LAYERS = 1
28+
MAX_UTERANCE_LENGTH = 1500
29+
VECTORIZER_GPU_LAYERS = 1
30+
DEFAULT_MAX_REQUEST_SIZE = -1 // in bytes (compared against the request Content-Length); -1 means no limit
2931
)
3032

3133
type AzureConfig struct {
@@ -63,6 +65,8 @@ type PluginDataConfig struct {
6365

6466
APIID string
6567
ListenPath string
68+
69+
MaxRequestLength int64 `json:"maxRequestLength"` // MaxRequestLength is the maximum size of the request in bytes (as reported by Content-Length); default is -1 (no limit)
6670
}
6771

6872
func getApiId(r *http.Request) (string, error) {
@@ -102,6 +106,17 @@ func getEnvOrDefault(value string, envKey string, defaultValue string) string {
102106
return value
103107
}
104108

109+
// getEnvAsInt returns the value of the environment variable envKey parsed as
// an int with strconv.Atoi.
//
// defaultValue is returned when the variable is unset or empty, and also when
// its value cannot be parsed as an int. A parse failure is not logged or
// returned to the caller.
func getEnvAsInt(envKey string, defaultValue int) int {
110+
valueStr := os.Getenv(envKey)
111+
if valueStr != "" {
112+
// Only a successfully parsed value is returned here; on a parse error
// control falls through to the default below.
if value, err := strconv.Atoi(valueStr); err == nil {
113+
return value
114+
}
115+
}
116+
117+
return defaultValue
118+
}
119+
105120
func parseConfigData(apiId string, configData map[string]any) (*PluginDataConfig, error) {
106121
logger.Debugf("[+] Parsing config for api id: %s", apiId)
107122

@@ -130,7 +145,8 @@ func parseConfigData(apiId string, configData map[string]any) (*PluginDataConfig
130145
SelectModelsPath: DEFAULT_MODEL_EMBEDDINGS_PATH,
131146
RelevanceThreshold: threshold,
132147

133-
APIID: apiId,
148+
APIID: apiId,
149+
MaxRequestLength: int64(getEnvAsInt("MAX_REQUEST_SIZE", DEFAULT_MAX_REQUEST_SIZE)),
134150
}
135151

136152
logger.Debugf("[+] Finished parsing config for api id: %s", apiId)

plugins/agent_bridge_endpoint_selection_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,17 @@ func TestEndpointSelection(t *testing.T) {
166166
}
167167

168168
}
169+
170+
func TestFixForUBatchBug(t *testing.T) {
171+
pluginDataConfig, err := loadApiSpecsForTests("tyk-github-id", "../configs/api.github.com.gist.deref.oas.json")
172+
assert.Nil(t, err)
173+
174+
modelPath := filepath.Join(pluginDataConfig.SelectModelsPath, DEFAULT_MODEL_EMBEDDINGS_MODEL)
175+
modelEmbedder, err := search.NewVectorizer(modelPath, 1)
176+
assert.Nil(t, err)
177+
178+
inputInFailure := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
179+
_, err = modelEmbedder.EmbedText(inputInFailure)
180+
assert.Nil(t, err)
181+
182+
}

0 commit comments

Comments
 (0)