From a2769ffefc80c15608fa8b64b763a6fa7a5769ec Mon Sep 17 00:00:00 2001 From: a4z Date: Sun, 4 Jan 2026 14:14:50 +0100 Subject: [PATCH] Add way to access column names of query result rows Fixess #5 --- Sources/Libsql/Libsql.swift | 106 ++++++++++++++++------------ Tests/LibsqlTests/LibsqlTests.swift | 82 ++++++++++++++++----- 2 files changed, 123 insertions(+), 65 deletions(-) diff --git a/Sources/Libsql/Libsql.swift b/Sources/Libsql/Libsql.swift index a29983d..23a3bcb 100644 --- a/Sources/Libsql/Libsql.swift +++ b/Sources/Libsql/Libsql.swift @@ -41,28 +41,28 @@ public protocol Prepareable { func prepare(_ sql: String) throws -> Statement } -public extension Prepareable { - func execute(_ sql: String) throws -> Int { +extension Prepareable { + public func execute(_ sql: String) throws -> Int { return try self.prepare(sql).execute() } - - func execute(_ sql: String, _ params: [String: ValueRepresentable]) throws -> Int { + + public func execute(_ sql: String, _ params: [String: ValueRepresentable]) throws -> Int { return try self.prepare(sql).bind(params).execute() } - - func execute(_ sql: String, _ params: [ValueRepresentable]) throws -> Int { + + public func execute(_ sql: String, _ params: [ValueRepresentable]) throws -> Int { return try self.prepare(sql).bind(params).execute() } - - func query(_ sql: String) throws -> Rows { + + public func query(_ sql: String) throws -> Rows { return try self.prepare(sql).query() } - - func query(_ sql: String, _ params: [String: ValueRepresentable]) throws -> Rows { + + public func query(_ sql: String, _ params: [String: ValueRepresentable]) throws -> Rows { return try self.prepare(sql).bind(params).query() } - - func query(_ sql: String, _ params: [ValueRepresentable]) throws -> Rows { + + public func query(_ sql: String, _ params: [ValueRepresentable]) throws -> Rows { return try self.prepare(sql).bind(params).query() } } @@ -78,7 +78,7 @@ extension String? { } func errIf(_ err: OpaquePointer!) throws { - if (err != nil) { + if err != nil { defer { libsql_error_deinit(err) } throw LibsqlError.runtimeError(String(cString: libsql_error_message(err)!)) } @@ -87,6 +87,7 @@ func errIf(_ err: OpaquePointer!) throws { enum LibsqlError: Error { case runtimeError(String) case typeMismatch + case indexOutOfRange } public class Row { @@ -99,11 +100,11 @@ public class Row { self.inner = inner } - + public func get(_ index: Int32) throws -> Value { let result = libsql_row_value(self.inner, index) try errIf(result.err) - + switch result.ok.type { case LIBSQL_TYPE_BLOB: let slice = result.ok.value.blob @@ -125,28 +126,28 @@ public class Row { } public func getData(_ index: Int32) throws -> Data { - guard case let .blob(data) = try self.get(index) else { + guard case .blob(let data) = try self.get(index) else { throw LibsqlError.typeMismatch } return data } public func getDouble(_ index: Int32) throws -> Double { - guard case let .real(double) = try self.get(index) else { + guard case .real(let double) = try self.get(index) else { throw LibsqlError.typeMismatch } return double } public func getString(_ index: Int32) throws -> String { - guard case let .text(string) = try self.get(index) else { + guard case .text(let string) = try self.get(index) else { throw LibsqlError.typeMismatch } return string } public func getInt(_ index: Int32) throws -> Int { - guard case let .integer(int) = try self.get(index) else { + guard case .integer(let int) = try self.get(index) else { throw LibsqlError.typeMismatch } return Int(int) @@ -167,13 +168,26 @@ public class Rows: Sequence, IteratorProtocol { public func next() -> Row? { let row = libsql_rows_next(self.inner) try! errIf(row.err) - + if libsql_row_empty(row) { return nil } - + return Row(from: row) } + + public func columnCount() -> Int { + return Int(libsql_rows_column_count(self.inner)) + } + + public func columnName(_ index: Int32) throws -> String { + guard index >= 0 && index < Int32(columnCount()) else { + throw LibsqlError.indexOutOfRange + } + let slice = libsql_rows_column_name(self.inner, index) + defer { libsql_slice_deinit(slice) } + return String(cString: slice.ptr.assumingMemoryBound(to: UInt8.self)) + } } public class Statement { @@ -190,7 +204,7 @@ public class Statement { public func execute() throws -> Int { let exec = libsql_statement_execute(self.inner) try errIf(exec.err) - + return Int(exec.rows_changed) } @@ -200,7 +214,7 @@ public class Statement { return Rows(from: rows) } - + public func bind(_ params: [String: ValueRepresentable]) throws -> Self { for (name, value) in params { switch value.toValue() { @@ -246,8 +260,8 @@ public class Statement { try errIf(bind.err) } } - - return self; + + return self } public func bind(_ params: [ValueRepresentable]) throws -> Self { @@ -260,7 +274,7 @@ public class Statement { ) try errIf(bind.err) case .text(let text): - + let len = text.utf8.count try text.withCString { text in let bind = libsql_statement_bind_value( @@ -288,18 +302,18 @@ public class Statement { try errIf(bind.err) } } - - return self; + + return self } } public class Transaction: Prepareable { var inner: libsql_transaction_t - + public consuming func commit() { libsql_transaction_commit(self.inner) } - + public consuming func rollback() { libsql_transaction_rollback(self.inner) } @@ -307,19 +321,19 @@ public class Transaction: Prepareable { fileprivate init(from inner: libsql_transaction_t) { self.inner = inner } - + public func executeBatch(_ sql: String) throws { let batch = libsql_transaction_batch(self.inner, sql) try errIf(batch.err) } public func prepare(_ sql: String) throws -> Statement { - let stmt = libsql_transaction_prepare(self.inner, sql); + let stmt = libsql_transaction_prepare(self.inner, sql) try errIf(stmt.err) - + return Statement(from: stmt) } - + } public class Connection: Prepareable { @@ -332,23 +346,23 @@ public class Connection: Prepareable { fileprivate init(from inner: libsql_connection_t) { self.inner = inner } - + public func transaction() throws -> Transaction { let tx = libsql_connection_transaction(self.inner) - try errIf(tx.err); + try errIf(tx.err) return Transaction(from: tx) } - + public func executeBatch(_ sql: String) throws { let batch = libsql_connection_batch(self.inner, sql) try errIf(batch.err) } public func prepare(_ sql: String) throws -> Statement { - let stmt = libsql_connection_prepare(self.inner, sql); + let stmt = libsql_connection_prepare(self.inner, sql) try errIf(stmt.err) - + return Statement(from: stmt) } } @@ -368,7 +382,7 @@ public class Database { public func connect() throws -> Connection { let conn = libsql_database_connect(self.inner) try errIf(conn.err) - + return Connection(from: conn) } @@ -376,10 +390,10 @@ public class Database { self.inner = try path.withCString { path in var desc = libsql_database_desc_t() desc.path = path - + let db = libsql_database_init(desc) try errIf(db.err) - + return db } } @@ -391,10 +405,10 @@ public class Database { desc.url = url desc.auth_token = authToken desc.webpki = withWebpki - + let db = libsql_database_init(desc) try errIf(db.err) - + return db } } @@ -422,10 +436,10 @@ public class Database { desc.disable_read_your_writes = !readYourWrites desc.sync_interval = syncInterval desc.webpki = withWebpki - + let db = libsql_database_init(desc) try errIf(db.err) - + return db } } diff --git a/Tests/LibsqlTests/LibsqlTests.swift b/Tests/LibsqlTests/LibsqlTests.swift index 0851b07..0b35dff 100644 --- a/Tests/LibsqlTests/LibsqlTests.swift +++ b/Tests/LibsqlTests/LibsqlTests.swift @@ -8,7 +8,7 @@ final class LibsqlTests: XCTestCase { let db = try Database(":memory:") let _ = try db.connect() } - + func testOpenDbFile() throws { let db = try Database("test.db") let _ = try db.connect() @@ -19,7 +19,7 @@ final class LibsqlTests: XCTestCase { let conn = try db.connect() _ = try conn.execute("create table test (i integer, s text)") _ = try conn.execute("insert into test values (?, ?)", [1, "lorem ipsum"]) - let row = try conn.query("select * from test").next()!; + let row = try conn.query("select * from test").next()! XCTAssertEqual(try row.getInt(0), 1) XCTAssertEqual(try row.getString(1), "lorem ipsum") @@ -28,11 +28,12 @@ final class LibsqlTests: XCTestCase { func testExecuteBatch() throws { let db = try Database(":memory:") let conn = try db.connect() - _ = try conn.executeBatch(""" - create table test (i integer, s text); - insert into test values (1, 'lorem ipsum'); - """) - let row = try conn.query("select * from test").next()!; + _ = try conn.executeBatch( + """ + create table test (i integer, s text); + insert into test values (1, 'lorem ipsum'); + """) + let row = try conn.query("select * from test").next()! XCTAssertEqual(try row.getInt(0), 1) XCTAssertEqual(try row.getString(1), "lorem ipsum") @@ -53,35 +54,35 @@ final class LibsqlTests: XCTestCase { let stmt = try conn.prepare("select ?").bind([1]) XCTAssertEqual(try stmt.query().next()!.getInt(0), 1) } - + func testTransaction() throws { let db = try Database(":memory:") let conn = try db.connect() - + do { let tx = try conn.transaction() defer { tx.commit() } - + _ = try tx.execute("create table test (i integer)") - _ = try tx.execute("insert into test values (:v)", [ ":v": 1 ]) + _ = try tx.execute("insert into test values (:v)", [":v": 1]) } - + XCTAssertEqual(try conn.query("select * from test").next()!.getInt(0), 1) } - + func testTransactionRollback() throws { let db = try Database(":memory:") let conn = try db.connect() - + _ = try conn.execute("create table test (i integer)") - + do { let tx = try conn.transaction() defer { tx.rollback() } - - _ = try tx.execute("insert into test values (:v)", [ ":v": 1 ]) + + _ = try tx.execute("insert into test values (:v)", [":v": 1]) } - + XCTAssert(try conn.query("select * from test").next() == nil) } @@ -96,7 +97,7 @@ final class LibsqlTests: XCTestCase { for i in range { _ = try conn.execute( "insert into test values (?, ?, ?, ?)", - [ i, "\(i)", exp(Double(i)), Data([UInt8(i)]) ] + [i, "\(i)", exp(Double(i)), Data([UInt8(i)])] ) } @@ -107,4 +108,47 @@ final class LibsqlTests: XCTestCase { XCTAssertEqual(try row.getData(3), Data([UInt8(i)])) } } + + func testColumnCountAndNames() throws { + let db = try Database(":memory:") + let conn = try db.connect() + + _ = try conn.execute("create table test (id integer, name text, value real)") + _ = try conn.execute("insert into test values (1, 'test', 3.14)") + + let rows = try conn.query("select * from test") + + XCTAssertEqual(rows.columnCount(), 3) + XCTAssertEqual(try rows.columnName(0), "id") + XCTAssertEqual(try rows.columnName(1), "name") + XCTAssertEqual(try rows.columnName(2), "value") + } + + func testColumnNameOutOfRange() throws { + let db = try Database(":memory:") + let conn = try db.connect() + + _ = try conn.execute("create table test (id integer)") + _ = try conn.execute("insert into test values (1)") + + let rows = try conn.query("select * from test") + + XCTAssertThrowsError(try rows.columnName(-1)) { error in + XCTAssert(error is LibsqlError) + if case LibsqlError.indexOutOfRange = error { + } else { + XCTFail("Expected indexOutOfRange error") + } + } + + XCTAssertNoThrow(try rows.columnName(0)) + + XCTAssertThrowsError(try rows.columnName(1)) { error in + XCTAssert(error is LibsqlError) + if case LibsqlError.indexOutOfRange = error { + } else { + XCTFail("Expected indexOutOfRange error") + } + } + } }