Skip to content

Commit cf52800

Browse files
JkLondonJkLondon
andauthored
threadID check (#181)
* Add platform-specific thread ID retrieval and enforce strict thread checks * fix of win --------- Co-authored-by: JkLondon <ilya@mikheev.fun>
1 parent 3d3028c commit cf52800

File tree

6 files changed

+69
-1
lines changed

6 files changed

+69
-1
lines changed

go.mod

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
module github.com/erigontech/mdbx-go
22

3-
go 1.20
3+
go 1.23.0
4+
5+
toolchain go1.24.1
46

57
require github.com/ianlancetaylor/cgosymbolizer v0.0.0-20241129212102-9c50ad6b591e
8+
9+
require golang.org/x/sys v0.31.0 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
github.com/ianlancetaylor/cgosymbolizer v0.0.0-20241129212102-9c50ad6b591e h1:8AnObPi8WmIgjwcidUxaREhXMSpyUJeeSrIkZTXdabw=
22
github.com/ianlancetaylor/cgosymbolizer v0.0.0-20241129212102-9c50ad6b591e/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg=
3+
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
4+
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=

mdbx/threads/threads_darwin.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//go:build darwin
2+
3+
package threads
4+
5+
/*
6+
#include <pthread.h>
7+
#include <stdint.h>
8+
9+
// getThreadID uses the pthread API to get the thread ID.
10+
uint64_t getThreadID() {
11+
uint64_t tid;
12+
pthread_threadid_np(NULL, &tid);
13+
return tid;
14+
}
15+
*/
16+
import "C"
17+
18+
//TODO: maybe there's go func for that)
19+
20+
// CurrentThreadID returns the macOS thread ID.
21+
func CurrentThreadID() uint64 {
22+
return uint64(C.getThreadID())
23+
}

mdbx/threads/threads_linux.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
//go:build linux
2+
3+
package threads
4+
5+
import "syscall"
6+
7+
// CurrentThreadID returns the Linux thread ID.
8+
// Note: gettid() is not directly available in Go so we use a raw syscall.
9+
func CurrentThreadID() uint64 {
10+
tid, _, _ := syscall.RawSyscall(syscall.SYS_GETTID, 0, 0, 0)
11+
return uint64(tid)
12+
}

mdbx/threads/threads_windows.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
//go:build windows
2+
3+
package threads
4+
5+
import (
6+
"golang.org/x/sys/windows"
7+
)
8+
9+
// CurrentThreadID returns the Windows thread ID.
10+
func CurrentThreadID() uint64 {
11+
return uint64(windows.GetCurrentThreadId())
12+
}

mdbx/txn.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ package mdbx
88
import "C"
99

1010
import (
11+
"fmt"
12+
"github.com/erigontech/mdbx-go/mdbx/threads"
1113
"log"
1214
"time"
1315
"unsafe"
@@ -77,6 +79,8 @@ type Txn struct {
7779
// be paid. The id of a Txn cannot change over its life, even if it is
7880
// reset/renewed
7981
id uint64
82+
83+
tid uint64
8084
}
8185

8286
// beginTxn does not lock the OS thread which is a prerequisite for creating a
@@ -86,6 +90,7 @@ func beginTxn(env *Env, parent *Txn, flags uint) (*Txn, error) {
8690
readonly: flags&Readonly != 0,
8791
env: env,
8892
}
93+
txn.tid = threads.CurrentThreadID()
8994

9095
var ptxn *C.MDBX_txn
9196
if parent != nil {
@@ -243,6 +248,7 @@ type CommitLatencyGC struct {
243248

244249
func (txn *Txn) commit() (CommitLatency, error) {
245250
var _stat C.MDBX_commit_latency
251+
txn.strictThreadCheck()
246252
ret := C.mdbx_txn_commit_ex(txn._txn, &_stat)
247253
txn.clearTxn()
248254
s := CommitLatency{
@@ -297,6 +303,14 @@ func (txn *Txn) Abort() {
297303
txn.abort()
298304
}
299305

306+
func (txn *Txn) strictThreadCheck() {
307+
currentThread := threads.CurrentThreadID()
308+
if currentThread != txn.tid {
309+
msg := fmt.Sprintf("thread mismatch. not allowed current %d, open in %d", currentThread, txn.tid)
310+
panic(msg)
311+
}
312+
}
313+
300314
func (txn *Txn) abort() {
301315
if txn._txn == nil {
302316
return
@@ -305,6 +319,7 @@ func (txn *Txn) abort() {
305319
// Get a read-lock on the environment so we can abort txn if needed.
306320
// txn.env **should** terminate all readers otherwise when it closes.
307321
txn.env.closeLock.RLock()
322+
txn.strictThreadCheck()
308323
if txn.env._env != nil {
309324
C.mdbx_txn_abort(txn._txn)
310325
}

0 commit comments

Comments
 (0)