Skip to content

Commit 4f25833

Browse files
author
JkLondon
committed
Add platform-specific thread ID retrieval and enforce strict thread checks
1 parent 3d3028c commit 4f25833

File tree

4 files changed

+60
-0
lines changed

4 files changed

+60
-0
lines changed

mdbx/threads/threads_darwin.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//go:build darwin
2+
3+
package threads
4+
5+
/*
6+
#include <pthread.h>
7+
#include <stdint.h>
8+
9+
// getThreadID uses the pthread API to get the thread ID.
10+
uint64_t getThreadID() {
11+
uint64_t tid;
12+
pthread_threadid_np(NULL, &tid);
13+
return tid;
14+
}
15+
*/
16+
import "C"
17+
18+
//TODO: maybe there's go func for that)
19+
20+
// CurrentThreadID returns the macOS thread ID.
21+
func CurrentThreadID() uint64 {
22+
return uint64(C.getThreadID())
23+
}

mdbx/threads/threads_linux.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
//go:build linux
2+
3+
package threads
4+
5+
import "syscall"
6+
7+
// CurrentThreadID returns the Linux thread ID.
8+
// Note: gettid() is not directly available in Go so we use a raw syscall.
9+
func CurrentThreadID() uint64 {
10+
tid, _, _ := syscall.RawSyscall(syscall.SYS_GETTID, 0, 0, 0)
11+
return uint64(tid)
12+
}

mdbx/threads/threads_windows.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//go:build windows
2+
3+
package threads
4+
5+
import "syscall"
6+
7+
// CurrentThreadID returns the Windows thread ID.
8+
func CurrentThreadID() uint64 {
9+
return uint64(syscall.GetCurrentThreadId())
10+
}

mdbx/txn.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ package mdbx
88
import "C"
99

1010
import (
11+
"fmt"
12+
"github.com/erigontech/mdbx-go/mdbx/threads"
1113
"log"
1214
"time"
1315
"unsafe"
@@ -77,6 +79,8 @@ type Txn struct {
7779
// be paid. The id of a Txn cannot change over its life, even if it is
7880
// reset/renewed
7981
id uint64
82+
83+
tid uint64
8084
}
8185

8286
// beginTxn does not lock the OS thread which is a prerequisite for creating a
@@ -86,6 +90,7 @@ func beginTxn(env *Env, parent *Txn, flags uint) (*Txn, error) {
8690
readonly: flags&Readonly != 0,
8791
env: env,
8892
}
93+
txn.tid = threads.CurrentThreadID()
8994

9095
var ptxn *C.MDBX_txn
9196
if parent != nil {
@@ -243,6 +248,7 @@ type CommitLatencyGC struct {
243248

244249
func (txn *Txn) commit() (CommitLatency, error) {
245250
var _stat C.MDBX_commit_latency
251+
txn.strictThreadCheck()
246252
ret := C.mdbx_txn_commit_ex(txn._txn, &_stat)
247253
txn.clearTxn()
248254
s := CommitLatency{
@@ -297,6 +303,14 @@ func (txn *Txn) Abort() {
297303
txn.abort()
298304
}
299305

306+
func (txn *Txn) strictThreadCheck() {
307+
currentThread := threads.CurrentThreadID()
308+
if currentThread != txn.tid {
309+
msg := fmt.Sprintf("thread mismatch. not allowed current %d, open in %d", currentThread, txn.tid)
310+
panic(msg)
311+
}
312+
}
313+
300314
func (txn *Txn) abort() {
301315
if txn._txn == nil {
302316
return
@@ -305,6 +319,7 @@ func (txn *Txn) abort() {
305319
// Get a read-lock on the environment so we can abort txn if needed.
306320
// txn.env **should** terminate all readers otherwise when it closes.
307321
txn.env.closeLock.RLock()
322+
txn.strictThreadCheck()
308323
if txn.env._env != nil {
309324
C.mdbx_txn_abort(txn._txn)
310325
}

0 commit comments

Comments
 (0)