Skip to content

Commit eb01ad6

Browse files
authored
fix: clear llama memory cache on subsequent embeddings to prevent context overflow (#73)
* fix(llama): prevent context overflow and filter invalid utf8/null bytes * fix(embed): clear llama KV cache memory between embeddings * fix(lint): handle or ignore error returns in embedder --------- Co-authored-by: uchebnick <uchebnick@users.noreply.github.com>
1 parent aeb3554 commit eb01ad6

1 file changed

Lines changed: 19 additions & 10 deletions

File tree

internal/embed/llama/embedder.go

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ func New(cfg Config) (*Embedder, error) {
116116

117117
ctxParams := llama.ContextDefaultParams()
118118
ctxParams.NCtx = uint32(cfg.ContextSize)
119+
ctxParams.NBatch = uint32(cfg.ContextSize)
120+
ctxParams.NUbatch = uint32(cfg.ContextSize)
119121
ctxParams.PoolingType = cfg.Pooling
120122
ctxParams.Embeddings = 1
121123

@@ -185,28 +187,35 @@ func (e *Embedder) Dim() int {
185187
}
186188

187189
func (e *Embedder) Embed(text string) ([]float32, error) {
188-
if e == nil {
189-
return nil, fmt.Errorf("nil embedder")
190-
}
190+
e.mu.Lock()
191+
defer e.mu.Unlock()
191192

192193
text = normalizeText(text)
193-
if text == "" {
194-
return nil, fmt.Errorf("empty text")
194+
195+
tokens := llama.Tokenize(e.vocab, text, true, false)
196+
if len(tokens) == 0 {
197+
return nil, nil // Return empty if text results in zero tokens
195198
}
196199

197-
e.mu.Lock()
198-
defer e.mu.Unlock()
200+
// Truncate tokens if they exceed ContextSize
201+
if len(tokens) > e.contextSize {
202+
tokens = tokens[:e.contextSize]
203+
}
199204

200-
tokens := llama.Tokenize(e.vocab, text, true, true)
201-
if len(tokens) == 0 {
202-
return nil, fmt.Errorf("tokenize returned zero tokens")
205+
// Clear memory before processing new tokens
206+
mem, err := llama.GetMemory(e.ctx)
207+
if err == nil {
208+
_ = llama.MemoryClear(mem, true)
203209
}
204210

205211
if len(tokens) > e.contextSize {
206212
tokens = tokens[:e.contextSize]
207213
}
208214

209215
batch := llama.BatchGetOne(tokens)
216+
defer func() {
217+
_ = llama.BatchFree(batch)
218+
}()
210219

211220
ret, err := llama.Decode(e.ctx, batch)
212221
if err != nil {

0 commit comments

Comments
 (0)