diff --git a/c/driver/framework/base_driver_test.cc b/c/driver/framework/base_driver_test.cc index 1d8d61f60f..3117705628 100644 --- a/c/driver/framework/base_driver_test.cc +++ b/c/driver/framework/base_driver_test.cc @@ -20,6 +20,7 @@ #include #include "driver/framework/base_driver.h" +#include "driver/framework/client.h" #include "driver/framework/connection.h" #include "driver/framework/database.h" #include "driver/framework/statement.h" @@ -235,3 +236,79 @@ TEST(TestDriverBase, TestVoidDriverMethods) { ADBC_STATUS_INVALID_ARGUMENT); EXPECT_EQ(driver.StatementCancel(&statement, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); } + +class TestContext : public adbc::client::Context { + void Log(LogLevel level, std::string_view message) override { + GTEST_FAIL() << "Unexpected TestContext log message: " << message; + } +}; + +TEST(TestDriverBase, TestVoidDriverMethodsClient) { + using adbc::client::Connection; + using adbc::client::Database; + using adbc::client::Driver; + using adbc::client::Statement; + + Driver driver(std::make_shared()); + ASSERT_TRUE(driver.Load(VoidDriverInitFunc).ok()); + + auto maybe_database = driver.NewDatabase(); + ASSERT_TRUE(maybe_database.has_value()); + Database database = std::move(maybe_database.value()); + + // TODO: Test database methods + + auto maybe_connection = database.NewConnection(); + ASSERT_TRUE(maybe_connection.has_value()) << maybe_connection.status().message(); + Connection connection = std::move(maybe_connection.value()); + + // TODO: Test connection methods + + // EXPECT_EQ(driver.ConnectionCommit(&connection, nullptr), ADBC_STATUS_INVALID_STATE); + // EXPECT_EQ(driver.ConnectionGetInfo(&connection, nullptr, 0, nullptr, nullptr), + // ADBC_STATUS_INVALID_ARGUMENT); + // EXPECT_EQ(driver.ConnectionGetObjects(&connection, 0, nullptr, nullptr, 0, nullptr, + // nullptr, nullptr, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); + // EXPECT_EQ(driver.ConnectionGetTableSchema(&connection, nullptr, nullptr, nullptr, + // nullptr, nullptr), + // ADBC_STATUS_INVALID_ARGUMENT); + // EXPECT_EQ(driver.ConnectionGetTableTypes(&connection, nullptr, nullptr), + // ADBC_STATUS_INVALID_ARGUMENT); + // EXPECT_EQ(driver.ConnectionReadPartition(&connection, nullptr, 0, nullptr, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); + // EXPECT_EQ(driver.ConnectionRollback(&connection, nullptr), + // ADBC_STATUS_INVALID_STATE); EXPECT_EQ(driver.ConnectionCancel(&connection, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); EXPECT_EQ(driver.ConnectionGetStatistics(&connection, + // nullptr, nullptr, nullptr, 0, + // nullptr, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); + // EXPECT_EQ(driver.ConnectionGetStatisticNames(&connection, nullptr, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); + + auto maybe_statement = connection.NewStatement(); + ASSERT_TRUE(maybe_statement.has_value()); + Statement statement = std::move(maybe_statement.value()); + + // TODO: Test statement methods + // EXPECT_EQ(driver.StatementExecuteQuery(&statement, nullptr, nullptr, nullptr), + // ADBC_STATUS_INVALID_STATE); + // EXPECT_EQ(driver.StatementExecuteSchema(&statement, nullptr, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); + // EXPECT_EQ(driver.StatementPrepare(&statement, nullptr), ADBC_STATUS_INVALID_STATE); + // EXPECT_EQ(driver.StatementSetSqlQuery(&statement, "", nullptr), ADBC_STATUS_OK); + // EXPECT_EQ(driver.StatementSetSubstraitPlan(&statement, nullptr, 0, nullptr), + // ADBC_STATUS_NOT_IMPLEMENTED); + // EXPECT_EQ(driver.StatementBind(&statement, nullptr, nullptr, nullptr), + // ADBC_STATUS_INVALID_ARGUMENT); + // EXPECT_EQ(driver.StatementBindStream(&statement, nullptr, nullptr), + // ADBC_STATUS_INVALID_ARGUMENT); + // EXPECT_EQ(driver.StatementCancel(&statement, nullptr), ADBC_STATUS_NOT_IMPLEMENTED); + + ASSERT_EQ(statement.SetSqlQuery("").code(), ADBC_STATUS_OK); + + ASSERT_EQ(statement.Release().code(), ADBC_STATUS_OK); + ASSERT_EQ(connection.Release().code(), ADBC_STATUS_OK); + ASSERT_EQ(database.Release().code(), ADBC_STATUS_OK); + ASSERT_EQ(driver.Unload().code(), ADBC_STATUS_OK); +} diff --git a/c/driver/framework/client.h b/c/driver/framework/client.h new file mode 100644 index 0000000000..cc2124768b --- /dev/null +++ b/c/driver/framework/client.h @@ -0,0 +1,629 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include "driver/framework/status.h" + +namespace adbc::client { + +namespace internal { +class BaseDriver; +class BaseDatabase; +class BaseConnection; +class BaseStatement; +} // namespace internal + +using adbc::driver::Result; +using adbc::driver::Status; + +class Context { + public: + enum class LogLevel { kInfo, kWarn }; + + virtual ~Context() = default; + + virtual void OnUnreleasableStatement( + std::shared_ptr connection, AdbcStatement* statement, + const AdbcError* error) { + Log(LogLevel::kWarn, "leaking unreleasable statement"); + } + + virtual void OnUnreleasableConnection(std::shared_ptr database, + AdbcConnection* connection, + const AdbcError* error) { + Log(LogLevel::kWarn, "leaking unreleasable connection"); + } + + virtual void OnUnreleaseableDatabase(std::shared_ptr driver, + AdbcDatabase* database, const AdbcError* error) { + Log(LogLevel::kWarn, "leaking unreleasable database"); + } + + virtual void OnUnreleaseableDriver(AdbcDriver* driver, const AdbcError* error) { + Log(LogLevel::kWarn, "leaking unreleasable driver"); + } + + virtual void OnDeleteHandleWithoutClose( + const std::shared_ptr& statement) { + Log(LogLevel::kWarn, + "Leaking Statement handle; AdbcStatement will be auto-released when all child " + "readers are released. Use Statement::Release() to avoid this message."); + } + + virtual void OnDeleteHandleWithoutClose( + const std::shared_ptr& connection) { + Log(LogLevel::kWarn, + "Leaking Connection handle; AdbcConnection will be auto-released when all child " + "readers are released. Use Connection::Release() to avoid this message."); + } + + virtual void OnDeleteHandleWithoutClose( + const std::shared_ptr& database) { + Log(LogLevel::kWarn, + "Leaking Database handle; AdbcDatabase will be auto-released when all child " + "readers are released. Use Database::Release() to avoid this message."); + } + + virtual void Log(LogLevel level, std::string_view message) {} + + static std::shared_ptr Default() { return std::make_unique(); } +}; + +namespace internal { + +class BaseDriver { + public: + explicit BaseDriver(std::shared_ptr context = std::make_unique()) + : context_(context) {} + + BaseDriver(const BaseDriver& rhs) = delete; + + Status Load(AdbcDriverInitFunc init_func) { + return Status::FromAdbc(init_func(ADBC_VERSION_1_1_0, &driver_, &error_), error_); + } + + Status Unload() { + AdbcStatusCode code = driver_.release(&driver_, &error_); + return Status::FromAdbc(code, error_); + } + + AdbcDriver* driver() { return &driver_; } + + Context* context() { return context_.get(); } + + Status CheckValid() { + if (!driver_.release) { + return Status::InvalidState("Driver is released"); + } else { + return Status::Ok(); + } + } + + private: + std::shared_ptr context_; + AdbcDriver driver_{}; + AdbcError error_{ADBC_ERROR_INIT}; +}; + +class BaseObject { + public: + BaseObject() = default; + const std::shared_ptr& GetSharedDriver() { return driver_; } + + protected: + std::shared_ptr driver_; + AdbcError error_{ADBC_ERROR_INIT}; + + void NewBase(std::shared_ptr driver) { driver_ = driver; } + void ReleaseBase() { driver_.reset(); } + AdbcDriver* driver() { return driver_->driver(); } + Context* context() { return driver_->context(); } +}; + +#define WRAP_CALL(func, ...) (driver()->func(&database_, __VA_ARGS__, &error_)) +#define WRAP_CALL0(func) (driver()->func(&database_, &error_)) + +class BaseDatabase : public BaseObject { + public: + BaseDatabase() = default; + BaseDatabase(const BaseDatabase& rhs) = delete; + ~BaseDatabase() { + if (driver_ && database_.private_data) { + AdbcStatusCode code = WRAP_CALL0(DatabaseRelease); + if (code != ADBC_STATUS_OK) { + context()->OnUnreleaseableDatabase(driver_, &database_, &error_); + } + } + } + + AdbcDatabase* database() { return &database_; } + + Status New(std::shared_ptr parent) { + NewBase(std::move(parent)); + AdbcStatusCode code = WRAP_CALL0(DatabaseNew); + return Status::FromAdbc(code, error_); + } + + Status Init() { + UNWRAP_STATUS(CheckValid()); + AdbcStatusCode code = WRAP_CALL0(DatabaseInit); + return Status::FromAdbc(code, error_); + } + + Status Release() { + UNWRAP_STATUS(CheckValid()); + AdbcStatusCode code = WRAP_CALL0(DatabaseRelease); + if (code == ADBC_STATUS_OK) { + ReleaseBase(); + std::memset(&database_, 0, sizeof(database_)); + } + + return Status::FromAdbc(code, error_); + } + + Status CheckValid() { + if (!driver_ || !database_.private_data) { + return Status::InvalidState("BaseDatabase is released"); + } + + return driver_->CheckValid(); + } + + private: + AdbcDatabase database_{}; +}; + +#undef WRAP_CALL +#undef WRAP_CALL0 + +#define WRAP_CALL(func, ...) (driver()->func(&connection_, __VA_ARGS__, &error_)) +#define WRAP_CALL0(func) (driver()->func(&connection_, &error_)) + +class BaseConnection : public BaseObject { + public: + BaseConnection() = default; + BaseConnection(const BaseConnection& rhs) = delete; + ~BaseConnection() { + if (driver_ && connection_.private_data) { + AdbcStatusCode code = WRAP_CALL0(ConnectionRelease); + if (code != ADBC_STATUS_OK) { + context()->OnUnreleasableConnection(database_, &connection_, &error_); + } + } + } + + AdbcConnection* connection() { return &connection_; } + + Status New(std::shared_ptr database) { + UNWRAP_STATUS(database->CheckValid()); + NewBase(database->GetSharedDriver()); + AdbcStatusCode code = WRAP_CALL0(ConnectionNew); + if (code == ADBC_STATUS_OK) { + database_ = database; + } + + return Status::FromAdbc(code, error_); + } + + Status Init() { + UNWRAP_STATUS(CheckValid()); + AdbcStatusCode code = WRAP_CALL(ConnectionInit, database_->database()); + return Status::FromAdbc(code, error_); + } + + Status Release() { + UNWRAP_STATUS(CheckValid()); + AdbcStatusCode code = WRAP_CALL0(ConnectionRelease); + if (code == ADBC_STATUS_OK) { + ReleaseBase(); + database_.reset(); + std::memset(&connection_, 0, sizeof(connection_)); + } + + return Status::FromAdbc(code, error_); + } + + Status Cancel() { + UNWRAP_STATUS(CheckValid()); + AdbcStatusCode code = WRAP_CALL0(ConnectionCancel); + return Status::FromAdbc(code, error_); + } + + Status GetInfo(const uint32_t* info_codes, size_t n_info_codes, ArrowArrayStream* out) { + UNWRAP_STATUS(CheckValid()); + AdbcStatusCode code = WRAP_CALL(ConnectionGetInfo, info_codes, n_info_codes, out); + return Status::FromAdbc(code, error_); + } + + Status CheckValid() { + if (!driver_ || !database_ || !connection_.private_data) { + return Status::InvalidState("BaseConnection is released"); + } + + return database_->CheckValid(); + } + + private: + std::shared_ptr database_; + AdbcConnection connection_{}; +}; + +#undef WRAP_CALL +#undef WRAP_CALL0 + +#define WRAP_CALL(func, ...) (driver_->driver()->func(&statement_, __VA_ARGS__, &error_)) +#define WRAP_CALL0(func) (driver_->driver()->func(&statement_, &error_)) + +class BaseStatement : public BaseObject { + public: + BaseStatement() = default; + BaseStatement(const BaseStatement& rhs) = delete; + ~BaseStatement() { + if (driver_ && statement_.private_data) { + AdbcStatusCode code = WRAP_CALL0(StatementRelease); + if (code != ADBC_STATUS_OK) { + context()->OnUnreleasableStatement(connection_, &statement_, &error_); + } + } + } + + Status New(std::shared_ptr connection) { + NewBase(connection->GetSharedDriver()); + AdbcStatusCode code = + driver()->StatementNew(connection->connection(), &statement_, &error_); + if (code == ADBC_STATUS_OK) { + connection_ = connection; + } + + return Status::FromAdbc(code, error_); + } + + Status Release() { + UNWRAP_STATUS(CheckValid()); + AdbcStatusCode code = WRAP_CALL0(StatementRelease); + if (code == ADBC_STATUS_OK) { + ReleaseBase(); + connection_.reset(); + std::memset(&statement_, 0, sizeof(statement_)); + } + + return Status::FromAdbc(code, error_); + } + + Status SetSqlQuery(const char* query) { + UNWRAP_STATUS(CheckValid()); + AdbcStatusCode code = WRAP_CALL(StatementSetSqlQuery, query); + return Status::FromAdbc(code, error_); + } + + Status ExecuteQuery(ArrowArrayStream* stream, int64_t* affected_rows) { + UNWRAP_STATUS(CheckValid()); + AdbcStatusCode code = WRAP_CALL(StatementExecuteQuery, stream, affected_rows); + return Status::FromAdbc(code, error_); + } + + Status CheckValid() { + if (!driver_ || !connection_ || !statement_.private_data) { + return Status::InvalidState("BaseStatement is released"); + } + + return connection_->CheckValid(); + } + + private: + std::shared_ptr connection_; + AdbcStatement statement_{}; +}; + +#undef WRAP_CALL +#undef WRAP_CALL0 + +} // namespace internal + +template +class Stream { + public: + explicit Stream(Parent parent) : parent_(parent) {} + + Stream& operator=(const Stream& rhs) = delete; + Stream(const Stream& rhs) = delete; + + Stream(Stream&& rhs) : Stream(std::move(rhs.parent_)) { + std::memcpy(&stream_, &rhs.stream_, sizeof(ArrowArrayStream)); + std::memset(&rhs.stream_, 0, sizeof(ArrowArrayStream)); + rows_affected_ = rhs.rows_affected_; + } + + Stream& operator=(Stream&& rhs) { + parent_ = std::move(rhs.parent_); + std::memcpy(&stream_, &rhs.stream_, sizeof(ArrowArrayStream)); + std::memset(&rhs.stream_, 0, sizeof(ArrowArrayStream)); + rows_affected_ = rhs.rows_affected_; + return *this; + } + + ArrowArrayStream* stream() { return &stream_; } + + int64_t rows_affected() { return rows_affected_; } + + int64_t* mutable_rows_affected() { return &rows_affected_; } + + ~Stream() { + if (stream_.release) { + stream_.release(&stream_); + } + } + + void Export(ArrowArrayStream* out) { + Stream* instance = new Stream(); + instance->parent_ = std::move(parent_); + std::memcpy(&instance->stream_, &stream_, sizeof(ArrowArrayStream)); + std::memset(&stream_, 0, sizeof(ArrowArrayStream)); + instance->rows_affected_ = rows_affected_; + + out->get_schema = &CGetSchema; + out->get_next = &CGetNext; + out->get_last_error = &CGetLastError; + out->release = &CRelease; + out->private_data = instance; + } + + private: + Parent parent_; + ArrowArrayStream stream_{}; + int64_t rows_affected_{-1}; + // For the specific case of a stream whose parent is no longer valid, + // this lets us save the error message and return a const char* from + // get_last_error(). + Status last_status_; + + static int CGetSchema(ArrowArrayStream* stream, ArrowSchema* schema) { + auto private_data = reinterpret_cast(stream->private_data); + if (!private_data->parent_->CheckValid().ok()) { + return EADDRNOTAVAIL; + } + + return private_data->GetSchema(schema); + } + + static int CGetNext(ArrowArrayStream* stream, ArrowArray* array) { + auto private_data = reinterpret_cast(stream->private_data); + if (!private_data->parent_->CheckValid().ok()) { + return EADDRNOTAVAIL; + } + + return private_data->GetNext(array); + } + + static const char* CGetLastError(ArrowArrayStream* stream) { + auto private_data = reinterpret_cast(stream->private_data); + private_data->last_status_ = private_data->CheckValid(); + if (!private_data->last_status_.ok()) { + return private_data->last_status_.message(); + } + + return private_data->GetLastError(); + } + + static void CRelease(ArrowArrayStream* stream) { + delete reinterpret_cast(stream->private_data); + stream->release = nullptr; + stream->private_data = nullptr; + } +}; + +class Driver; +class Database; +class Connection; +class Statement; +using ConnectionStream = Stream>; +using StatementStream = Stream>; + +class Statement { + public: + Statement& operator=(const Statement&) = delete; + Statement(const Statement& rhs) = delete; + Statement(Statement&& rhs) : base_(std::move(rhs.base_)) {} + Statement& operator=(Statement&& rhs) { + base_ = std::move(rhs.base_); + return *this; + } + + ~Statement() { + if (base_ && base_->GetSharedDriver()) { + base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); + } + } + + Status Release() { + UNWRAP_STATUS(CheckValid()); + UNWRAP_STATUS(base_->Release()); + base_.reset(); + return Status::Ok(); + } + + Status SetSqlQuery(const std::string& query) { + UNWRAP_STATUS(CheckValid()); + return base_->SetSqlQuery(query.c_str()); + } + + Result ExecuteQuery() { + UNWRAP_STATUS(CheckValid()); + StatementStream out(base_); + UNWRAP_STATUS(base_->ExecuteQuery(out.stream(), out.mutable_rows_affected())); + return out; + } + + private: + std::shared_ptr base_; + + friend class Connection; + explicit Statement(std::shared_ptr base) + : base_(std::move(base)) {} + + Status CheckValid() { + if (!base_) { + return Status::InvalidState("Statement handle has been released"); + } else { + return Status::Ok(); + } + } +}; + +class Connection { + public: + Connection& operator=(const Connection&) = delete; + Connection(const Connection& rhs) = delete; + Connection(Connection&& rhs) : base_(std::move(rhs.base_)) {} + Connection& operator=(Connection&& rhs) { + base_ = std::move(rhs.base_); + return *this; + } + + ~Connection() { + if (base_ && base_->GetSharedDriver()) { + base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); + } + } + + Status Release() { + UNWRAP_STATUS(CheckValid()); + UNWRAP_STATUS(base_->Release()); + base_.reset(); + return Status::Ok(); + } + + Result NewStatement() { + UNWRAP_STATUS(CheckValid()); + auto child = std::make_shared(); + UNWRAP_STATUS(child->New(base_)); + return Statement(std::move(child)); + } + + Status Cancel() { + UNWRAP_STATUS(CheckValid()); + return base_->Cancel(); + } + + Result GetInfo(const std::vector& info_codes = {}) { + UNWRAP_STATUS(CheckValid()); + ConnectionStream out(base_); + UNWRAP_STATUS(base_->GetInfo(info_codes.data(), info_codes.size(), out.stream())); + return out; + } + + private: + std::shared_ptr base_; + + friend class Database; + explicit Connection(std::shared_ptr base) + : base_(std::move(base)) {} + + Status CheckValid() { + if (!base_) { + return Status::InvalidState("Connection handle has been released"); + } else { + return Status::Ok(); + } + } +}; + +class Database { + public: + Database& operator=(const Database&) = delete; + Database(const Database& rhs) = delete; + Database(Database&& rhs) : base_(std::move(rhs.base_)) {} + Database& operator=(Database&& rhs) { + base_ = std::move(rhs.base_); + return *this; + } + + ~Database() { + if (base_ && base_->GetSharedDriver()) { + base_->GetSharedDriver()->context()->OnDeleteHandleWithoutClose(base_); + } + } + + Status Release() { + UNWRAP_STATUS(CheckValid()); + UNWRAP_STATUS(base_->Release()); + base_.reset(); + return Status::Ok(); + } + + Result NewConnection() { + UNWRAP_STATUS(CheckValid()); + + auto child = std::make_shared(); + UNWRAP_STATUS(child->New(base_)); + UNWRAP_STATUS(child->Init()); + return Connection(std::move(child)); + } + + private: + std::shared_ptr base_; + + friend class Driver; + explicit Database(std::shared_ptr base) + : base_(std::move(base)) {} + + Status CheckValid() { + if (!base_) { + return Status::InvalidState("Database handle has been released"); + } else { + return Status::Ok(); + } + } +}; + +class Driver { + public: + explicit Driver(std::shared_ptr context = Context::Default()) + : base_(std::make_shared(std::move(context))) {} + + Driver(const Driver& rhs) = delete; + Driver(Driver&& rhs) : base_(std::move(rhs.base_)) {} + Driver& operator=(Driver&& rhs) { + base_ = std::move(rhs.base_); + return *this; + } + + Status Load(AdbcDriverInitFunc init_func) { return base_->Load(init_func); } + + Status Unload() { return base_->Unload(); } + + Result NewDatabase() { + auto child = std::make_shared(); + UNWRAP_STATUS(child->New(base_)); + UNWRAP_STATUS(child->Init()); + return Database(std::move(child)); + } + + private: + std::shared_ptr base_; +}; + +} // namespace adbc::client diff --git a/c/driver/framework/status.h b/c/driver/framework/status.h index cfdca6ebbe..b9e1da85db 100644 --- a/c/driver/framework/status.h +++ b/c/driver/framework/status.h @@ -63,6 +63,22 @@ class Status { /// \brief Check if this is an error or not. bool ok() const { return impl_ == nullptr; } + const char* message() const { + if (!impl_) { + return ""; + } else { + return impl_->message.c_str(); + } + } + + AdbcStatusCode code() const { + if (ok()) { + return ADBC_STATUS_OK; + } else { + return impl_->code; + } + } + /// \brief Add another error detail. void AddDetail(std::string key, std::string value) { assert(impl_ != nullptr); @@ -111,7 +127,6 @@ class Status { } static Status FromAdbc(AdbcStatusCode code, AdbcError& error) { - // not really meant to be used, just something we have for now while porting if (code == ADBC_STATUS_OK) { if (error.release) { error.release(&error);