Skip to content

Commit 00c068a

Browse files
authored
Merge pull request #114 from amikos-tech/codex/evaluate-concurrent-tokenization
Fix tokenizer lifecycle close races
2 parents 02a3420 + 53b9229 commit 00c068a

2 files changed

Lines changed: 304 additions & 12 deletions

File tree

tokenizer_lifecycle_test.go

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
package tokenizers
2+
3+
import (
4+
stderrors "errors"
5+
"sync"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func newLifecycleTestTokenizer(t *testing.T) *Tokenizer {
13+
t.Helper()
14+
15+
libpath := checkLibraryExists(t)
16+
tok, err := FromFile("./tokenizer.json", WithLibraryPath(libpath))
17+
require.NoError(t, err, "Failed to load tokenizer from file")
18+
return tok
19+
}
20+
21+
func TestCloseIsIdempotent(t *testing.T) {
22+
tok := newLifecycleTestTokenizer(t)
23+
24+
require.NoError(t, tok.Close())
25+
require.NoError(t, tok.Close())
26+
}
27+
28+
func TestConcurrentCloseIsIdempotent(t *testing.T) {
29+
tok := newLifecycleTestTokenizer(t)
30+
31+
const goroutines = 8
32+
errs := make(chan error, goroutines)
33+
34+
var wg sync.WaitGroup
35+
wg.Add(goroutines)
36+
for i := 0; i < goroutines; i++ {
37+
go func() {
38+
defer wg.Done()
39+
errs <- tok.Close()
40+
}()
41+
}
42+
43+
wg.Wait()
44+
close(errs)
45+
for err := range errs {
46+
require.NoError(t, err)
47+
}
48+
}
49+
50+
func TestTokenizerMethodsReturnErrTokenizerClosed(t *testing.T) {
51+
tok := newLifecycleTestTokenizer(t)
52+
53+
encoding, err := tok.Encode("Hello, world!")
54+
require.NoError(t, err)
55+
56+
require.NoError(t, tok.Close())
57+
58+
_, err = tok.Encode("Hello again")
59+
require.ErrorIs(t, err, ErrTokenizerClosed)
60+
61+
_, err = tok.EncodePairs([]string{"query"}, []string{"document"})
62+
require.ErrorIs(t, err, ErrTokenizerClosed)
63+
64+
_, err = tok.Decode(encoding.IDs, false)
65+
require.ErrorIs(t, err, ErrTokenizerClosed)
66+
67+
_, err = tok.VocabSize()
68+
require.ErrorIs(t, err, ErrTokenizerClosed)
69+
70+
require.Equal(t, "unknown", tok.GetLibraryVersion())
71+
}
72+
73+
func TestCloseWaitsForActiveOperations(t *testing.T) {
74+
tok := newLifecycleTestTokenizer(t)
75+
76+
tok.lifecycleMu.RLock()
77+
closeDone := make(chan error, 1)
78+
go func() {
79+
closeDone <- tok.Close()
80+
}()
81+
82+
select {
83+
case err := <-closeDone:
84+
t.Fatalf("Close returned before active operations finished: %v", err)
85+
case <-time.After(200 * time.Millisecond):
86+
}
87+
88+
tok.lifecycleMu.RUnlock()
89+
require.NoError(t, <-closeDone)
90+
91+
_, err := tok.Encode("Hello after close")
92+
require.ErrorIs(t, err, ErrTokenizerClosed)
93+
}
94+
95+
func TestConcurrentEncodeAndClose(t *testing.T) {
96+
tok := newLifecycleTestTokenizer(t)
97+
98+
const goroutines = 8
99+
const iterationsPerGoroutine = 200
100+
101+
start := make(chan struct{})
102+
errs := make(chan error, goroutines)
103+
104+
var wg sync.WaitGroup
105+
wg.Add(goroutines)
106+
for i := 0; i < goroutines; i++ {
107+
go func() {
108+
defer wg.Done()
109+
<-start
110+
for j := 0; j < iterationsPerGoroutine; j++ {
111+
result, err := tok.Encode("Concurrent lifecycle test text")
112+
if err != nil && !stderrors.Is(err, ErrTokenizerClosed) {
113+
errs <- err
114+
return
115+
}
116+
if err == nil && len(result.IDs) == 0 {
117+
errs <- stderrors.New("encode returned empty ids")
118+
return
119+
}
120+
}
121+
}()
122+
}
123+
124+
close(start)
125+
time.Sleep(10 * time.Millisecond)
126+
require.NoError(t, tok.Close())
127+
128+
wg.Wait()
129+
close(errs)
130+
for err := range errs {
131+
require.NoError(t, err)
132+
}
133+
}
134+
135+
func TestConcurrentMixedOperationsAndClose(t *testing.T) {
136+
tok := newLifecycleTestTokenizer(t)
137+
138+
baseline, err := tok.Encode("Mixed lifecycle test text")
139+
require.NoError(t, err)
140+
require.NotEmpty(t, baseline.IDs)
141+
pairBaseline, err := tok.EncodePairs([]string{"query"}, []string{"document"})
142+
require.NoError(t, err)
143+
require.Len(t, pairBaseline, 1)
144+
require.NotEmpty(t, pairBaseline[0].IDs)
145+
146+
const goroutines = 12
147+
const iterationsPerGoroutine = 150
148+
149+
start := make(chan struct{})
150+
errs := make(chan error, goroutines)
151+
152+
var wg sync.WaitGroup
153+
wg.Add(goroutines)
154+
for i := 0; i < goroutines; i++ {
155+
workerID := i
156+
go func() {
157+
defer wg.Done()
158+
<-start
159+
for j := 0; j < iterationsPerGoroutine; j++ {
160+
switch workerID % 4 {
161+
case 0:
162+
result, opErr := tok.Encode("Mixed lifecycle test text")
163+
if opErr == nil && len(result.IDs) == 0 {
164+
errs <- stderrors.New("mixed encode returned empty ids")
165+
return
166+
}
167+
if opErr != nil && !stderrors.Is(opErr, ErrTokenizerClosed) {
168+
errs <- opErr
169+
return
170+
}
171+
case 1:
172+
decoded, opErr := tok.Decode(baseline.IDs, false)
173+
if opErr == nil && decoded == "" {
174+
errs <- stderrors.New("mixed decode returned empty string")
175+
return
176+
}
177+
if opErr != nil && !stderrors.Is(opErr, ErrTokenizerClosed) {
178+
errs <- opErr
179+
return
180+
}
181+
case 2:
182+
size, opErr := tok.VocabSize()
183+
if opErr == nil && size == 0 {
184+
errs <- stderrors.New("mixed vocab size returned zero")
185+
return
186+
}
187+
if opErr != nil && !stderrors.Is(opErr, ErrTokenizerClosed) {
188+
errs <- opErr
189+
return
190+
}
191+
default:
192+
results, opErr := tok.EncodePairs([]string{"query"}, []string{"document"})
193+
if opErr == nil && (len(results) != 1 || len(results[0].IDs) == 0) {
194+
errs <- stderrors.New("mixed encode pairs returned empty ids")
195+
return
196+
}
197+
if opErr != nil && !stderrors.Is(opErr, ErrTokenizerClosed) {
198+
errs <- opErr
199+
return
200+
}
201+
}
202+
}
203+
}()
204+
}
205+
206+
close(start)
207+
time.Sleep(10 * time.Millisecond)
208+
require.NoError(t, tok.Close())
209+
210+
wg.Wait()
211+
close(errs)
212+
for err := range errs {
213+
require.NoError(t, err)
214+
}
215+
}

tokenizers.go

Lines changed: 89 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package tokenizers
33
import (
44
"math"
55
"os"
6+
"sync"
67
"unsafe"
78

89
"github.com/Masterminds/semver/v3"
@@ -28,6 +29,9 @@ const (
2829
ErrInvalidOptions = -14
2930
)
3031

32+
// ErrTokenizerClosed is returned when an operation is attempted on a closed tokenizer.
33+
var ErrTokenizerClosed = errors.New("tokenizer is closed")
34+
3135
// AbiCompatibilityConstraint defines the required version range for ABI compatibility.
3236
// The library version from Cargo.toml is used as the ABI version.
3337
// Update this constraint when making breaking changes to the FFI interface.
@@ -218,6 +222,8 @@ func WithPadding(enabled bool, strategy PaddingStrategy) TokenizerOption {
218222
}
219223

220224
type Tokenizer struct {
225+
lifecycleMu sync.RWMutex
226+
closed bool
221227
LibraryPath string // Path to the shared library
222228
libh uintptr
223229
tokenizerh unsafe.Pointer // Pointer to the tokenizer instance
@@ -384,19 +390,63 @@ To resolve this issue:
384390
return nil
385391
}
386392

387-
func (t *Tokenizer) Close() error {
388-
if t.tokenizerh != nil {
389-
t.freeTokenizer(t.tokenizerh)
390-
t.tokenizerh = nil
393+
func (t *Tokenizer) beginOperation() (func(), error) {
394+
t.lifecycleMu.RLock()
395+
if t.closed {
396+
t.lifecycleMu.RUnlock()
397+
return nil, ErrTokenizerClosed
391398
}
392-
err := closeLibrary(t.libh)
393-
if err != nil {
394-
return errors.Errorf("failed to close shared library: %s", err.Error())
399+
return t.lifecycleMu.RUnlock, nil
400+
}
401+
402+
func (t *Tokenizer) Close() (err error) {
403+
t.lifecycleMu.Lock()
404+
if t.closed {
405+
t.lifecycleMu.Unlock()
406+
return nil
395407
}
396-
return nil
408+
t.closed = true
409+
410+
tokenizerh := t.tokenizerh
411+
freeTokenizer := t.freeTokenizer
412+
libh := t.libh
413+
414+
t.tokenizerh = nil
415+
t.libh = 0
416+
t.fromFile = nil
417+
t.fromBytes = nil
418+
t.encode = nil
419+
t.encodeBatchPairs = nil
420+
t.freeTokenizer = nil
421+
t.freeBuffer = nil
422+
t.freeString = nil
423+
t.decode = nil
424+
t.vocabSize = nil
425+
t.getVersion = nil
426+
427+
t.lifecycleMu.Unlock()
428+
429+
if libh != 0 {
430+
defer func() {
431+
if closeErr := closeLibrary(libh); err == nil && closeErr != nil {
432+
err = closeErr
433+
}
434+
}()
435+
}
436+
437+
if tokenizerh != nil && freeTokenizer != nil {
438+
freeTokenizer(tokenizerh)
439+
}
440+
return
397441
}
398442

399443
func (t *Tokenizer) Encode(message string, opts ...EncodeOption) (*EncodeResult, error) {
444+
unlock, err := t.beginOperation()
445+
if err != nil {
446+
return nil, err
447+
}
448+
defer unlock()
449+
400450
if t.encode == nil || t.tokenizerh == nil {
401451
return nil, errors.New("encode function is not initialized or tokenizer is not loaded")
402452
}
@@ -454,6 +504,12 @@ func (t *Tokenizer) Encode(message string, opts ...EncodeOption) (*EncodeResult,
454504
// EncodePairs encodes multiple sequence pairs in parallel.
455505
// This is useful for reranking tasks where you need to encode query-document pairs.
456506
func (t *Tokenizer) EncodePairs(sequences []string, pairs []string, opts ...EncodeOption) ([]*EncodeResult, error) {
507+
unlock, err := t.beginOperation()
508+
if err != nil {
509+
return nil, err
510+
}
511+
defer unlock()
512+
457513
if t.encodeBatchPairs == nil || t.tokenizerh == nil {
458514
return nil, errors.New("encode_batch_pairs function is not initialized or tokenizer is not loaded")
459515
}
@@ -504,6 +560,11 @@ func (t *Tokenizer) EncodePairs(sequences []string, pairs []string, opts ...Enco
504560
lastError := getErrorForCode(rc)
505561
return nil, errors.Wrap(lastError, "failed to encode pairs")
506562
}
563+
defer func() {
564+
for i := range buffers {
565+
t.freeBuffer(&buffers[i])
566+
}
567+
}()
507568

508569
// Convert buffers to results
509570
results := make([]*EncodeResult, len(buffers))
@@ -547,9 +608,6 @@ func (t *Tokenizer) EncodePairs(sequences []string, pairs []string, opts ...Enco
547608
}
548609

549610
results[i] = result
550-
551-
// Free the buffer
552-
t.freeBuffer(buff)
553611
}
554612

555613
return results, nil
@@ -566,6 +624,12 @@ func (t *Tokenizer) EncodePair(sequence string, pair string, opts ...EncodeOptio
566624
}
567625

568626
func (t *Tokenizer) Decode(ids []uint32, skipSpecialTokens bool) (string, error) {
627+
unlock, err := t.beginOperation()
628+
if err != nil {
629+
return "", err
630+
}
631+
defer unlock()
632+
569633
if t.decode == nil || t.tokenizerh == nil {
570634
return "", errors.New("decode function is not initialized or tokenizer is not loaded")
571635
}
@@ -618,6 +682,12 @@ func goStringFromPtr(ptr unsafe.Pointer) string {
618682
return string(unsafe.Slice(p, n)) // #nosec G103 -- Converts validated null-terminated FFI buffer into a Go string.
619683
}
620684
func (t *Tokenizer) VocabSize() (uint32, error) {
685+
unlock, err := t.beginOperation()
686+
if err != nil {
687+
return 0, err
688+
}
689+
defer unlock()
690+
621691
if t.vocabSize == nil || t.tokenizerh == nil {
622692
return 0, errors.New("vocabSize function is not initialized or tokenizer is not loaded")
623693
}
@@ -668,8 +738,15 @@ func getErrorForCode(errCode int32) error {
668738
}
669739
}
670740

671-
// GetLibraryVersion returns the version of the tokenizer library
741+
// GetLibraryVersion returns the version of the tokenizer library.
742+
// It returns "unknown" when the version callback is unavailable or the tokenizer is closed.
672743
func (t *Tokenizer) GetLibraryVersion() string {
744+
unlock, err := t.beginOperation()
745+
if err != nil {
746+
return "unknown"
747+
}
748+
defer unlock()
749+
673750
if t.getVersion == nil {
674751
return "unknown"
675752
}

0 commit comments

Comments
 (0)