Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,27 @@ jobs:
done
echo "sqld is ready!"

- name: Install sqlean
run: |
SQLEAN_VERSION=0.27.2
case "${{ runner.os }}" in
Linux) OS_SLUG=linux-x86 ;;
macOS) OS_SLUG=macos-x86 ;;
esac
curl -sL -o sqlean-${OS_SLUG}.zip \
https://github.com/nalgeon/sqlean/releases/download/${SQLEAN_VERSION}/sqlean-${OS_SLUG}.zip
unzip -q sqlean-${OS_SLUG}.zip -d .
echo "$PWD" >> $GITHUB_PATH

- name: Set extension env
run: |
case "${{ runner.os }}" in
Linux) ext=sqlean.so ;;
macOS) ext=sqlean.dylib ;;
esac
echo "LIBSQL_TEST_EXTENSION=$PWD/$ext" >> $GITHUB_ENV
echo "LIBSQL_TEST_EXTENSION_ENTRY=sqlite3_sqlean_init" >> $GITHUB_ENV

- name: Build
run: go build -v ./...

Expand Down
14 changes: 14 additions & 0 deletions libsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,20 @@ type conn struct {
nativePtr C.libsql_connection_t
}

func (c *conn) LoadExtension(lib string, entry string) error {
libCString := C.CString(lib)
defer C.free(unsafe.Pointer(libCString))
entryCString := C.CString(entry)
defer C.free(unsafe.Pointer(entryCString))

var errMsg *C.char
statusCode := C.libsql_load_extension(c.nativePtr, libCString, entryCString, &errMsg)
if statusCode != 0 {
return libsqlError(fmt.Sprintf("failed to load extension %s with entry point %s", lib, entry), statusCode, errMsg)
}
return nil
}

func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}
Expand Down
68 changes: 68 additions & 0 deletions libsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1359,3 +1359,71 @@ func TestErrorRowsNext(t *testing.T) {
}
})
}

// To run this, set LIBSQL_TEST_EXTENSION to the full path of a valid SQLite extension
// and (optionally) LIBSQL_TEST_EXTENSION_ENTRY to its init symbol (defaults to "sqlite3_extension_init").
func TestLoadExtension_Existing(t *testing.T) {
extPath := os.Getenv("LIBSQL_TEST_EXTENSION")
if extPath == "" {
t.Skip("LIBSQL_TEST_EXTENSION not set; skipping existing‐extension load test")
}
entryPoint := os.Getenv("LIBSQL_TEST_EXTENSION_ENTRY")
if entryPoint == "" {
entryPoint = "sqlite3_extension_init"
}

db, err := sql.Open("libsql", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()

ctx := context.Background()
sqlConn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer sqlConn.Close()

err = sqlConn.Raw(func(driverConn any) error {
cImpl, ok := driverConn.(*conn)
if !ok {
return fmt.Errorf("unexpected driverConn type %T", driverConn)
}
return cImpl.LoadExtension(extPath, entryPoint)
})

if err != nil {
t.Fatalf("failed to load existing extension %q (entry %q): %v", extPath, entryPoint, err)
}
}

func TestLoadExtension_Nonexistent(t *testing.T) {
db, err := sql.Open("libsql", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()

ctx := context.Background()
sqlConn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer sqlConn.Close()

err = sqlConn.Raw(func(driverConn any) error {
cImpl, ok := driverConn.(*conn)
if !ok {
return fmt.Errorf("unexpected driverConn type %T", driverConn)
}
return cImpl.LoadExtension("nonexistent_extension.so", "entry_point")
})

if err == nil {
t.Fatal("expected error loading nonexistent extension, got nil")
}
if !strings.Contains(err.Error(), "failed to load extension") {
t.Fatalf("unexpected error loading extension: %v", err)
}
}
Loading