From d8bda58c7b1e5ec3dd7a495d9aeb3af026148385 Mon Sep 17 00:00:00 2001 From: David Li Date: Sun, 18 May 2025 12:44:18 +0900 Subject: [PATCH] GH-46481: [C++][Python] Allow nullable schema in FlightInfo --- cpp/src/arrow/flight/flight_internals_test.cc | 3 +++ cpp/src/arrow/flight/test_util.cc | 10 ++++++++ cpp/src/arrow/flight/test_util.h | 6 +++++ cpp/src/arrow/flight/types.cc | 25 ++++++++++++++++++- cpp/src/arrow/flight/types.h | 11 +++++++- python/pyarrow/_flight.pyx | 4 ++- python/pyarrow/src/arrow/python/flight.cc | 2 +- python/pyarrow/tests/test_flight.py | 24 +++++++++++++++--- 8 files changed, 78 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index ab2f8c7830786..bb14ddd6655e8 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -238,6 +238,7 @@ TEST(FlightTypes, FlightInfo) { MakeFlightInfo(schema1, desc1, {endpoint1}, -1, 42, true, ""), MakeFlightInfo(schema1, desc2, {endpoint1, endpoint2}, 64, -1, false, "\xDE\xAD\xC0\xDE"), + MakeFlightInfo(desc1, {}, -1, -1, false, ""), }; std::vector reprs = { " " @@ -257,6 +258,8 @@ TEST(FlightTypes, FlightInfo) { "locations=[grpc+tcp://localhost:1234] expiration_time=null " "app_metadata='CAFED00D'>] " "total_records=64 total_bytes=-1 ordered=false app_metadata='DEADC0DE'>", + " " + "endpoints=[] total_records=-1 total_bytes=-1 ordered=false app_metadata=''>", }; ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index aa10d9a7da822..e0b73ebb6eeb3 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -77,6 +77,16 @@ FlightInfo MakeFlightInfo(const Schema& schema, const FlightDescriptor& descript return info; } +FlightInfo MakeFlightInfo(const FlightDescriptor& descriptor, + const std::vector& endpoints, + int64_t total_records, int64_t total_bytes, bool ordered, + std::string app_metadata) { + EXPECT_OK_AND_ASSIGN(auto info, + FlightInfo::Make(nullptr, descriptor, endpoints, total_records, + total_bytes, ordered, std::move(app_metadata))); + return info; +} + NumberingStream::NumberingStream(std::unique_ptr stream) : counter_(0), stream_(std::move(stream)) {} diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h index 946caebcc2b5a..02963cd6996e9 100644 --- a/cpp/src/arrow/flight/test_util.h +++ b/cpp/src/arrow/flight/test_util.h @@ -182,6 +182,12 @@ FlightInfo MakeFlightInfo(const Schema& schema, const FlightDescriptor& descript int64_t total_records, int64_t total_bytes, bool ordered, std::string app_metadata); +ARROW_FLIGHT_EXPORT +FlightInfo MakeFlightInfo(const FlightDescriptor& descriptor, + const std::vector& endpoints, + int64_t total_records, int64_t total_bytes, bool ordered, + std::string app_metadata); + ARROW_FLIGHT_EXPORT Status ExampleTlsCertificates(std::vector* out); diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 65beec97d64df..5dbc119719528 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -280,10 +280,31 @@ arrow::Result FlightInfo::Make(const Schema& schema, return FlightInfo(std::move(data)); } +arrow::Result FlightInfo::Make(const std::shared_ptr& schema, + const FlightDescriptor& descriptor, + const std::vector& endpoints, + int64_t total_records, int64_t total_bytes, + bool ordered, std::string app_metadata) { + FlightInfo::Data data; + data.descriptor = descriptor; + data.endpoints = endpoints; + data.total_records = total_records; + data.total_bytes = total_bytes; + data.ordered = ordered; + data.app_metadata = std::move(app_metadata); + if (schema) { + RETURN_NOT_OK(internal::SchemaToString(*schema, &data.schema)); + } + return FlightInfo(std::move(data)); +} + arrow::Result> FlightInfo::GetSchema( ipc::DictionaryMemo* dictionary_memo) const { if (reconstructed_schema_) { return schema_; + } else if (data_.schema.empty()) { + reconstructed_schema_ = true; + return schema_; } // Create a non-owned Buffer to avoid copying io::BufferReader schema_reader(std::make_shared(data_.schema)); @@ -305,7 +326,9 @@ arrow::Status FlightInfo::Deserialize(std::string_view serialized, std::string FlightInfo::ToString() const { std::stringstream ss; ss << "ToString(); } else { ss << "(serialized)"; diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index b7df6191e4d3b..656cc00e67641 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -638,12 +638,21 @@ class ARROW_FLIGHT_EXPORT FlightInfo bool ordered = false, std::string app_metadata = ""); + /// \brief Factory method to construct a FlightInfo. + static arrow::Result Make(const std::shared_ptr& schema, + const FlightDescriptor& descriptor, + const std::vector& endpoints, + int64_t total_records, int64_t total_bytes, + bool ordered = false, + std::string app_metadata = ""); + /// \brief Deserialize the Arrow schema of the dataset. Populate any /// dictionary encoded fields into a DictionaryMemo for /// bookkeeping /// \param[in,out] dictionary_memo for dictionary bookkeeping, will /// be modified - /// \return Arrow result with the reconstructed Schema + /// \return Arrow result with the reconstructed Schema. Note that the schema + /// may be nullptr, as the schema is optional. arrow::Result> GetSchema( ipc::DictionaryMemo* dictionary_memo) const; diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index c9acb842642fe..fe2e1b3d67405 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -890,7 +890,7 @@ cdef class FlightInfo(_Weakrefable): Parameters ---------- - schema : Schema + schema : Schema, optional the schema of the data in this flight. descriptor : FlightDescriptor the descriptor for this flight. @@ -961,6 +961,8 @@ cdef class FlightInfo(_Weakrefable): CDictionaryMemo dummy_memo check_flight_status(self.info.get().GetSchema(&dummy_memo).Value(&schema)) + if schema.get() == NULL: + return None return pyarrow_wrap_schema(schema) @property diff --git a/python/pyarrow/src/arrow/python/flight.cc b/python/pyarrow/src/arrow/python/flight.cc index 2fda48b70b0fd..5ef8a1dd6b050 100644 --- a/python/pyarrow/src/arrow/python/flight.cc +++ b/python/pyarrow/src/arrow/python/flight.cc @@ -373,7 +373,7 @@ Status CreateFlightInfo(const std::shared_ptr& schema, const std::string& app_metadata, std::unique_ptr* out) { ARROW_ASSIGN_OR_RAISE(auto result, arrow::flight::FlightInfo::Make( - *schema, descriptor, endpoints, total_records, + schema, descriptor, endpoints, total_records, total_bytes, ordered, app_metadata)); *out = std::unique_ptr( new arrow::flight::FlightInfo(std::move(result))); diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index f830eacc4fabc..bd5ff12353ca9 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -620,9 +620,10 @@ def __init__(self, factory): def received_headers(self, headers): auth_header = case_insensitive_header_lookup(headers, 'Authorization') - self.factory.set_call_credential([ - b'authorization', - auth_header[0].encode("utf-8")]) + if auth_header: + self.factory.set_call_credential([ + b'authorization', + auth_header[0].encode("utf-8")]) class HeaderAuthServerMiddlewareFactory(ServerMiddlewareFactory): @@ -916,6 +917,23 @@ def test_repr(): assert repr(flight.SchemaResult(pa.schema([("int", "int64")]))) == \ "" assert repr(flight.Ticket(b"foo")) == ticket_repr + assert info.schema == pa.schema([]) + + info = flight.FlightInfo( + None, flight.FlightDescriptor.for_path(), [], + 1, 42, True, b"test app metadata" + ) + info_repr = ( + " " + "endpoints=[] " + "total_records=1 " + "total_bytes=42 " + "ordered=True " + "app_metadata=b'test app metadata'>") + assert repr(info) == info_repr + assert info.schema is None with pytest.raises(TypeError): flight.Action("foo", None)