Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 60 additions & 46 deletions Sources/Libsql/Libsql.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,28 @@ public protocol Prepareable {
func prepare(_ sql: String) throws -> Statement
}

public extension Prepareable {
func execute(_ sql: String) throws -> Int {
extension Prepareable {
public func execute(_ sql: String) throws -> Int {
return try self.prepare(sql).execute()
}
func execute(_ sql: String, _ params: [String: ValueRepresentable]) throws -> Int {

public func execute(_ sql: String, _ params: [String: ValueRepresentable]) throws -> Int {
return try self.prepare(sql).bind(params).execute()
}
func execute(_ sql: String, _ params: [ValueRepresentable]) throws -> Int {

public func execute(_ sql: String, _ params: [ValueRepresentable]) throws -> Int {
return try self.prepare(sql).bind(params).execute()
}
func query(_ sql: String) throws -> Rows {

public func query(_ sql: String) throws -> Rows {
return try self.prepare(sql).query()
}
func query(_ sql: String, _ params: [String: ValueRepresentable]) throws -> Rows {

public func query(_ sql: String, _ params: [String: ValueRepresentable]) throws -> Rows {
return try self.prepare(sql).bind(params).query()
}
func query(_ sql: String, _ params: [ValueRepresentable]) throws -> Rows {

public func query(_ sql: String, _ params: [ValueRepresentable]) throws -> Rows {
return try self.prepare(sql).bind(params).query()
}
}
Expand All @@ -78,7 +78,7 @@ extension String? {
}

func errIf(_ err: OpaquePointer!) throws {
if (err != nil) {
if err != nil {
defer { libsql_error_deinit(err) }
throw LibsqlError.runtimeError(String(cString: libsql_error_message(err)!))
}
Expand All @@ -87,6 +87,7 @@ func errIf(_ err: OpaquePointer!) throws {
enum LibsqlError: Error {
case runtimeError(String)
case typeMismatch
case indexOutOfRange
}

public class Row {
Expand All @@ -99,11 +100,11 @@ public class Row {

self.inner = inner
}

public func get(_ index: Int32) throws -> Value {
let result = libsql_row_value(self.inner, index)
try errIf(result.err)

switch result.ok.type {
case LIBSQL_TYPE_BLOB:
let slice = result.ok.value.blob
Expand All @@ -125,28 +126,28 @@ public class Row {
}

public func getData(_ index: Int32) throws -> Data {
guard case let .blob(data) = try self.get(index) else {
guard case .blob(let data) = try self.get(index) else {
throw LibsqlError.typeMismatch
}
return data
}

public func getDouble(_ index: Int32) throws -> Double {
guard case let .real(double) = try self.get(index) else {
guard case .real(let double) = try self.get(index) else {
throw LibsqlError.typeMismatch
}
return double
}

public func getString(_ index: Int32) throws -> String {
guard case let .text(string) = try self.get(index) else {
guard case .text(let string) = try self.get(index) else {
throw LibsqlError.typeMismatch
}
return string
}

public func getInt(_ index: Int32) throws -> Int {
guard case let .integer(int) = try self.get(index) else {
guard case .integer(let int) = try self.get(index) else {
throw LibsqlError.typeMismatch
}
return Int(int)
Expand All @@ -167,13 +168,26 @@ public class Rows: Sequence, IteratorProtocol {
public func next() -> Row? {
let row = libsql_rows_next(self.inner)
try! errIf(row.err)

if libsql_row_empty(row) {
return nil
}

return Row(from: row)
}

public func columnCount() -> Int {
return Int(libsql_rows_column_count(self.inner))
}

public func columnName(_ index: Int32) throws -> String {
guard index >= 0 && index < Int32(columnCount()) else {
throw LibsqlError.indexOutOfRange
}
let slice = libsql_rows_column_name(self.inner, index)
defer { libsql_slice_deinit(slice) }
return String(cString: slice.ptr.assumingMemoryBound(to: UInt8.self))
}
}

public class Statement {
Expand All @@ -190,7 +204,7 @@ public class Statement {
public func execute() throws -> Int {
let exec = libsql_statement_execute(self.inner)
try errIf(exec.err)

return Int(exec.rows_changed)
}

Expand All @@ -200,7 +214,7 @@ public class Statement {

return Rows(from: rows)
}

public func bind(_ params: [String: ValueRepresentable]) throws -> Self {
for (name, value) in params {
switch value.toValue() {
Expand Down Expand Up @@ -246,8 +260,8 @@ public class Statement {
try errIf(bind.err)
}
}
return self;

return self
}

public func bind(_ params: [ValueRepresentable]) throws -> Self {
Expand All @@ -260,7 +274,7 @@ public class Statement {
)
try errIf(bind.err)
case .text(let text):

let len = text.utf8.count
try text.withCString { text in
let bind = libsql_statement_bind_value(
Expand Down Expand Up @@ -288,38 +302,38 @@ public class Statement {
try errIf(bind.err)
}
}
return self;

return self
}
}

public class Transaction: Prepareable {
var inner: libsql_transaction_t

public consuming func commit() {
libsql_transaction_commit(self.inner)
}

public consuming func rollback() {
libsql_transaction_rollback(self.inner)
}

fileprivate init(from inner: libsql_transaction_t) {
self.inner = inner
}

public func executeBatch(_ sql: String) throws {
let batch = libsql_transaction_batch(self.inner, sql)
try errIf(batch.err)
}

public func prepare(_ sql: String) throws -> Statement {
let stmt = libsql_transaction_prepare(self.inner, sql);
let stmt = libsql_transaction_prepare(self.inner, sql)
try errIf(stmt.err)

return Statement(from: stmt)
}

}

public class Connection: Prepareable {
Expand All @@ -332,23 +346,23 @@ public class Connection: Prepareable {
fileprivate init(from inner: libsql_connection_t) {
self.inner = inner
}

public func transaction() throws -> Transaction {
let tx = libsql_connection_transaction(self.inner)
try errIf(tx.err);
try errIf(tx.err)

return Transaction(from: tx)
}

public func executeBatch(_ sql: String) throws {
let batch = libsql_connection_batch(self.inner, sql)
try errIf(batch.err)
}

public func prepare(_ sql: String) throws -> Statement {
let stmt = libsql_connection_prepare(self.inner, sql);
let stmt = libsql_connection_prepare(self.inner, sql)
try errIf(stmt.err)

return Statement(from: stmt)
}
}
Expand All @@ -368,18 +382,18 @@ public class Database {
public func connect() throws -> Connection {
let conn = libsql_database_connect(self.inner)
try errIf(conn.err)

return Connection(from: conn)
}

public init(_ path: String) throws {
self.inner = try path.withCString { path in
var desc = libsql_database_desc_t()
desc.path = path

let db = libsql_database_init(desc)
try errIf(db.err)

return db
}
}
Expand All @@ -391,10 +405,10 @@ public class Database {
desc.url = url
desc.auth_token = authToken
desc.webpki = withWebpki

let db = libsql_database_init(desc)
try errIf(db.err)

return db
}
}
Expand Down Expand Up @@ -422,10 +436,10 @@ public class Database {
desc.disable_read_your_writes = !readYourWrites
desc.sync_interval = syncInterval
desc.webpki = withWebpki

let db = libsql_database_init(desc)
try errIf(db.err)

return db
}
}
Expand Down
Loading