diff --git a/.release-notes/add-row-streaming.md b/.release-notes/add-row-streaming.md new file mode 100644 index 0000000..fae15a5 --- /dev/null +++ b/.release-notes/add-row-streaming.md @@ -0,0 +1,23 @@ +## Add row streaming support + +Row streaming delivers query results in fixed-size batches instead of buffering all rows before delivery. This enables pull-based paged result consumption with bounded memory, ideal for large result sets. + +A new `StreamingResultReceiver` interface provides three callbacks: `pg_stream_batch` delivers each batch of rows, `pg_stream_complete` signals all rows have been delivered, and `pg_stream_failed` reports errors. Three new `Session` methods control the flow: + +```pony +// Start streaming with a window size of 100 rows per batch +session.stream( + PreparedQuery("SELECT * FROM big_table", + recover val Array[(String | None)] end), + 100, my_receiver) + +// In the receiver: +be pg_stream_batch(session: Session, rows: Rows) => + // Process this batch + session.fetch_more() // Pull the next batch + +be pg_stream_complete(session: Session) => + // All rows delivered +``` + +Call `session.close_stream()` to end streaming early. Only `PreparedQuery` and `NamedPreparedQuery` are supported — streaming uses the extended query protocol's `Execute(max_rows)` + `PortalSuspended` mechanism. diff --git a/CLAUDE.md b/CLAUDE.md index 86bbb13..0df424a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -61,17 +61,18 @@ This design makes illegal state transitions call `_IllegalState()` (panic) by de ### Query Execution Flow -1. Client calls `session.execute(query, ResultReceiver)` where query is `SimpleQuery`, `PreparedQuery`, or `NamedPreparedQuery`; or `session.prepare(name, sql, PrepareReceiver)` to create a named statement; or `session.close_statement(name)` to destroy one; or `session.copy_in(sql, CopyInReceiver)` to start a COPY FROM STDIN operation; or `session.copy_out(sql, CopyOutReceiver)` to start a COPY TO STDOUT operation -2. 
`_SessionLoggedIn` queues operations as `_QueueItem` — a union of `_QueuedQuery` (execute), `_QueuedPrepare` (prepare), `_QueuedCloseStatement` (close_statement), `_QueuedCopyIn` (copy_in), and `_QueuedCopyOut` (copy_out) +1. Client calls `session.execute(query, ResultReceiver)` where query is `SimpleQuery`, `PreparedQuery`, or `NamedPreparedQuery`; or `session.prepare(name, sql, PrepareReceiver)` to create a named statement; or `session.close_statement(name)` to destroy one; or `session.copy_in(sql, CopyInReceiver)` to start a COPY FROM STDIN operation; or `session.copy_out(sql, CopyOutReceiver)` to start a COPY TO STDOUT operation; or `session.stream(query, window_size, StreamingResultReceiver)` to start a streaming query +2. `_SessionLoggedIn` queues operations as `_QueueItem` — a union of `_QueuedQuery` (execute), `_QueuedPrepare` (prepare), `_QueuedCloseStatement` (close_statement), `_QueuedCopyIn` (copy_in), `_QueuedCopyOut` (copy_out), and `_QueuedStreamingQuery` (stream) 3. The `_QueryState` sub-state machine manages operation lifecycle: - `_QueryNotReady`: initial state after auth, before the first ReadyForQuery arrives - - `_QueryReady`: server is idle, `try_run_query` dispatches based on queue item type — `SimpleQuery` transitions to `_SimpleQueryInFlight`, `PreparedQuery` and `NamedPreparedQuery` transition to `_ExtendedQueryInFlight`, `_QueuedPrepare` transitions to `_PrepareInFlight`, `_QueuedCloseStatement` transitions to `_CloseStatementInFlight`, `_QueuedCopyIn` transitions to `_CopyInInFlight`, `_QueuedCopyOut` transitions to `_CopyOutInFlight` + - `_QueryReady`: server is idle, `try_run_query` dispatches based on queue item type — `SimpleQuery` transitions to `_SimpleQueryInFlight`, `PreparedQuery` and `NamedPreparedQuery` transition to `_ExtendedQueryInFlight`, `_QueuedPrepare` transitions to `_PrepareInFlight`, `_QueuedCloseStatement` transitions to `_CloseStatementInFlight`, `_QueuedCopyIn` transitions to `_CopyInInFlight`, `_QueuedCopyOut` 
transitions to `_CopyOutInFlight`, `_QueuedStreamingQuery` transitions to `_StreamingQueryInFlight` - `_SimpleQueryInFlight`: owns per-query accumulation data (`_data_rows`, `_row_description`), delivers results on `CommandComplete` - `_ExtendedQueryInFlight`: same data accumulation and result delivery as `_SimpleQueryInFlight` (duplicated because Pony traits can't have iso fields). Entered after sending Parse+Bind+Describe(portal)+Execute+Sync (unnamed) or Bind+Describe(portal)+Execute+Sync (named) - `_PrepareInFlight`: handles Parse+Describe(statement)+Sync cycle. Notifies `PrepareReceiver` on success/failure via `ReadyForQuery` - `_CloseStatementInFlight`: handles Close(statement)+Sync cycle. Fire-and-forget (no callback); errors silently absorbed - `_CopyInInFlight`: handles COPY FROM STDIN data transfer. Sends the COPY query via simple query protocol, receives `CopyInResponse`, then uses pull-based flow: calls `pg_copy_ready` on the `CopyInReceiver` to request data. Client calls `send_copy_data` (sends CopyData + pulls again), `finish_copy` (sends CopyDone), or `abort_copy` (sends CopyFail). Server responds with CommandComplete+ReadyForQuery on success, or ErrorResponse+ReadyForQuery on failure - `_CopyOutInFlight`: handles COPY TO STDOUT data reception. Sends the COPY query via simple query protocol, receives `CopyOutResponse` (silently consumed), then receives server-pushed `CopyData` messages (each delivered via `pg_copy_data` to the `CopyOutReceiver`), `CopyDone` (silently consumed), and finally `CommandComplete` (stores row count) + `ReadyForQuery` (delivers `pg_copy_complete`). On error, `ErrorResponse` delivers `pg_copy_failed` and the session remains usable + - `_StreamingQueryInFlight`: handles streaming row delivery. Entered after sending Parse+Bind+Describe(portal)+Execute(max_rows)+Flush (unnamed) or Bind+Describe(portal)+Execute(max_rows)+Flush (named). Uses Flush instead of Sync to keep the portal alive between batches. 
`PortalSuspended` triggers batch delivery via `pg_stream_batch`. Client calls `fetch_more()` (sends Execute+Flush) or `close_stream()` (sends Sync). `CommandComplete` delivers final batch and sends Sync. `ReadyForQuery` delivers `pg_stream_complete` and dequeues. On error, sends Sync (required because no Sync is pending during streaming) and delivers `pg_stream_failed` 4. Response data arrives: `_RowDescriptionMessage` sets column metadata, `_DataRowMessage` accumulates rows 5. `_CommandCompleteMessage` triggers result delivery to receiver 6. `_ReadyForQueryMessage` dequeues completed operation, transitions to `_QueryReady` @@ -83,13 +84,13 @@ Only one operation is in-flight at a time. The queue serializes execution. `quer ### Protocol Layer **Frontend (client → server):** -- `_FrontendMessage` primitive: `startup()`, `password()`, `query()`, `parse()`, `bind()`, `describe_portal()`, `describe_statement()`, `execute_msg()`, `close_statement()`, `sync()`, `ssl_request()`, `cancel_request()`, `terminate()`, `sasl_initial_response()`, `sasl_response()`, `copy_data()`, `copy_done()`, `copy_fail()` — builds raw byte arrays with big-endian wire format +- `_FrontendMessage` primitive: `startup()`, `password()`, `query()`, `parse()`, `bind()`, `describe_portal()`, `describe_statement()`, `execute_msg()`, `close_statement()`, `sync()`, `flush()`, `ssl_request()`, `cancel_request()`, `terminate()`, `sasl_initial_response()`, `sasl_response()`, `copy_data()`, `copy_done()`, `copy_fail()` — builds raw byte arrays with big-endian wire format **Backend (server → client):** - `_ResponseParser` primitive: incremental parser consuming from a `Reader` buffer. Returns one parsed message per call, `None` if incomplete, errors on junk. - `_ResponseMessageParser` primitive: routes parsed messages to the current session state's callbacks. 
Processes messages synchronously within a query cycle (looping until `ReadyForQuery` or buffer exhaustion), then yields via `s._process_again()` between cycles. This prevents behaviors like `close()` from interleaving between result delivery and query dequeuing. If a callback triggers shutdown during the loop, `on_shutdown` clears the read buffer, causing the next parse to return `None` and exit the loop. -**Supported message types:** AuthenticationOk, AuthenticationMD5Password, AuthenticationSASL, AuthenticationSASLContinue, AuthenticationSASLFinal, BackendKeyData, CommandComplete, CopyInResponse, CopyOutResponse, CopyData, CopyDone, DataRow, EmptyQueryResponse, ErrorResponse, NoticeResponse, NotificationResponse, ParameterStatus, ReadyForQuery, RowDescription, ParseComplete, BindComplete, NoData, CloseComplete, ParameterDescription, PortalSuspended. BackendKeyData is parsed and stored in `_SessionLoggedIn` (`backend_pid`, `backend_secret_key`) for future query cancellation. NotificationResponse is parsed into `_NotificationResponseMessage` and routed to `_SessionLoggedIn.on_notification()`, which delivers `pg_notification` to `SessionStatusNotify`. NoticeResponse is parsed into `NoticeResponseMessage` (using shared `_ResponseFieldBuilder` / `_parse_response_fields` with ErrorResponse) and routed via `on_notice()` to `SessionStatusNotify.pg_notice()`. Notices are delivered in all connected states (including during authentication) since PostgreSQL can send them at any time. ParameterStatus is parsed into `_ParameterStatusMessage` and routed via `on_parameter_status()` to `SessionStatusNotify.pg_parameter_status()`, which delivers a `ParameterStatus` val. Like notices, parameter status messages are delivered in all connected states. Extended query acknowledgment messages (ParseComplete, BindComplete, NoData, etc.) 
are parsed but silently consumed — they fall through the `_ResponseMessageParser` match without routing since the state machine tracks query lifecycle through data-carrying messages only. +**Supported message types:** AuthenticationOk, AuthenticationMD5Password, AuthenticationSASL, AuthenticationSASLContinue, AuthenticationSASLFinal, BackendKeyData, CommandComplete, CopyInResponse, CopyOutResponse, CopyData, CopyDone, DataRow, EmptyQueryResponse, ErrorResponse, NoticeResponse, NotificationResponse, ParameterStatus, ReadyForQuery, RowDescription, ParseComplete, BindComplete, NoData, CloseComplete, ParameterDescription, PortalSuspended. BackendKeyData is parsed and stored in `_SessionLoggedIn` (`backend_pid`, `backend_secret_key`) for future query cancellation. NotificationResponse is parsed into `_NotificationResponseMessage` and routed to `_SessionLoggedIn.on_notification()`, which delivers `pg_notification` to `SessionStatusNotify`. NoticeResponse is parsed into `NoticeResponseMessage` (using shared `_ResponseFieldBuilder` / `_parse_response_fields` with ErrorResponse) and routed via `on_notice()` to `SessionStatusNotify.pg_notice()`. Notices are delivered in all connected states (including during authentication) since PostgreSQL can send them at any time. ParameterStatus is parsed into `_ParameterStatusMessage` and routed via `on_parameter_status()` to `SessionStatusNotify.pg_parameter_status()`, which delivers a `ParameterStatus` val. Like notices, parameter status messages are delivered in all connected states. PortalSuspended is parsed into `_PortalSuspendedMessage` and routed to `s.state.on_portal_suspended(s)` for streaming batch delivery. Extended query acknowledgment messages (ParseComplete, BindComplete, NoData, etc.) are parsed but silently consumed — they fall through the `_ResponseMessageParser` match without routing since the state machine tracks query lifecycle through data-carrying messages only. 
### Public API Types @@ -109,6 +110,7 @@ Only one operation is in-flight at a time. The queue serializes execution. `quer - `PrepareReceiver` interface (tag) — `pg_statement_prepared(Session, name)`, `pg_prepare_failed(Session, name, (ErrorResponseMessage | ClientQueryError))` - `CopyInReceiver` interface (tag) — `pg_copy_ready(Session)`, `pg_copy_complete(Session, count)`, `pg_copy_failed(Session, (ErrorResponseMessage | ClientQueryError))`. Pull-based: session calls `pg_copy_ready` after `copy_in` and after each `send_copy_data`, letting the client control data flow - `CopyOutReceiver` interface (tag) — `pg_copy_data(Session, Array[U8] val)`, `pg_copy_complete(Session, count)`, `pg_copy_failed(Session, (ErrorResponseMessage | ClientQueryError))`. Push-based: server drives the flow, delivering data chunks via `pg_copy_data` and signaling completion via `pg_copy_complete` +- `StreamingResultReceiver` interface (tag) — `pg_stream_batch(Session, Rows)`, `pg_stream_complete(Session)`, `pg_stream_failed(Session, (PreparedQuery | NamedPreparedQuery), (ErrorResponseMessage | ClientQueryError))`. Pull-based: session delivers batches via `pg_stream_batch`; client calls `fetch_more()` for the next batch or `close_stream()` to end early - `ClientQueryError` trait — `SessionNeverOpened`, `SessionClosed`, `SessionNotAuthenticated`, `DataError` - `DatabaseConnectInfo` — val class grouping database authentication parameters (user, password, database). Passed to `Session.create()` alongside `ServerConnectInfo`. - `ServerConnectInfo` — val class grouping connection parameters (auth, host, service, ssl_mode). Passed to `Session.create()` as the first parameter. Also used by `_CancelSender`. @@ -143,7 +145,7 @@ Tests live in the main `postgres/` package (private test classes), organized acr **Conventions**: `_test.pony` contains shared helpers (`_ConnectionTestConfiguration` for env vars, `_ConnectTestNotify`/`_AuthenticateTestNotify` reused by other files). 
`_test_response_parser.pony` contains `_Incoming*TestMessage` builder classes that construct raw protocol bytes for mock servers across all test files. `_test_mock_message_reader.pony` contains `_MockMessageReader` for extracting complete PostgreSQL frontend messages from TCP data in mock servers. -**Ports**: Mock server tests use ports in the 7669–7701 range and 9667–9668. **Port 7680 is reserved by Windows** (Update Delivery Optimization) and will fail to bind on WSL2 — do not use it. +**Ports**: Mock server tests use ports in the 7669–7706 range and 9667–9668. **Port 7680 is reserved by Windows** (Update Delivery Optimization) and will fail to bind on WSL2 — do not use it. ## Supported PostgreSQL Features @@ -151,7 +153,7 @@ Tests live in the main `postgres/` package (private test classes), organized acr **Authentication:** MD5 password and SCRAM-SHA-256. No SCRAM-SHA-256-PLUS (channel binding), Kerberos, GSS, or certificate auth. Design: [discussion #83](https://github.com/ponylang/postgres/discussions/83). -**Protocol:** Simple query and extended query (parameterized via unnamed and named prepared statements). Parameters are text-format only; type OIDs are inferred by the server. LISTEN/NOTIFY, NoticeResponse, ParameterStatus, COPY FROM STDIN (pull-based), COPY TO STDOUT (push-based). No function calls. Full feature roadmap: [discussion #72](https://github.com/ponylang/postgres/discussions/72). +**Protocol:** Simple query and extended query (parameterized via unnamed and named prepared statements). Parameters are text-format only; type OIDs are inferred by the server. LISTEN/NOTIFY, NoticeResponse, ParameterStatus, COPY FROM STDIN (pull-based), COPY TO STDOUT (push-based), row streaming (windowed batch delivery via Execute(max_rows)+PortalSuspended). No function calls. Full feature roadmap: [discussion #72](https://github.com/ponylang/postgres/discussions/72). 
**CI containers:** Stock `postgres:14.5` for plain (port 5432, SCRAM-SHA-256 default) and `ghcr.io/ponylang/postgres-ci-pg-ssl:latest` for SSL (port 5433, SSL + md5user); built via `build-ci-image.yml` workflow dispatch or locally via `.ci-dockerfiles/pg-ssl/build-and-push.bash`. MD5 integration tests connect to the SSL container (without using SSL) because only that container has the md5user. diff --git a/examples/README.md b/examples/README.md index 168f56e..8a04636 100644 --- a/examples/README.md +++ b/examples/README.md @@ -42,6 +42,10 @@ Bulk data loading using `COPY ... FROM STDIN`. Creates a table, loads three rows Bulk data export using `COPY ... TO STDOUT`. Creates a table, inserts three rows, exports them via `Session.copy_out()`, and prints the received data. Demonstrates the push-based `CopyOutReceiver` interface: the server drives the flow, calling `pg_copy_data` for each chunk, then `pg_copy_complete` when finished. +## streaming + +Row streaming using `Session.stream()` with windowed batch delivery. Creates a table with 7 rows, streams them with `window_size=3` (producing batches of 3, 3, and 1), then drops the table. Demonstrates the pull-based `StreamingResultReceiver` interface: `pg_stream_batch` delivers each batch, `fetch_more()` requests the next, and `pg_stream_complete` signals completion. + ## notice Server notice handling using `pg_notice`. Executes `DROP TABLE IF EXISTS` on a nonexistent table, which triggers a PostgreSQL `NoticeResponse`, and prints the notice fields (severity, code, message). Shows how `SessionStatusNotify.pg_notice` delivers non-fatal informational messages from the server. 
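The streaming example below always pulls to completion with `fetch_more()`; ending early is a one-call change. A minimal hedged sketch of a receiver that stops after the first batch (the actor name is illustrative; it assumes the `StreamingResultReceiver` interface and `close_stream()` behavior described in this change):

```pony
// Illustrative sketch only — not part of this change. Stops a stream
// after the first batch via close_stream(); pg_stream_complete still
// fires once the server acknowledges with ReadyForQuery.
actor FirstBatchOnly is StreamingResultReceiver
  let _out: OutStream

  new create(out: OutStream) =>
    _out = out

  be pg_stream_batch(session: Session, rows: Rows) =>
    _out.print(rows.size().string() + " rows; stopping early")
    session.close_stream() // sends Sync instead of pulling another batch

  be pg_stream_complete(session: Session) =>
    _out.print("Stream ended")

  be pg_stream_failed(session: Session,
    query: (PreparedQuery | NamedPreparedQuery),
    failure: (ErrorResponseMessage | ClientQueryError))
  =>
    _out.print("Stream failed")
```

Because `close_stream()` sends Sync, the portal is closed server-side and the session returns to `_QueryReady` as usual, so queued operations continue normally.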
diff --git a/examples/streaming/streaming-example.pony b/examples/streaming/streaming-example.pony new file mode 100644 index 0000000..ff26a8e --- /dev/null +++ b/examples/streaming/streaming-example.pony @@ -0,0 +1,147 @@ +use "cli" +use "collections" +use lori = "lori" +// in your code this `use` statement would be: +// use "postgres" +use "../../postgres" + +actor Main + new create(env: Env) => + let server_info = ServerInfo(env.vars) + let auth = lori.TCPConnectAuth(env.root) + + let client = Client(auth, server_info, env.out) + +// This example demonstrates row streaming for pull-based paged result +// consumption. It creates a table with 7 rows, streams them with a window +// size of 3 (producing batches of 3, 3, and 1), then drops the table. +// +// The streaming protocol uses a pull-based flow: the session delivers a +// batch via pg_stream_batch, then the client calls fetch_more() to request +// the next batch. When no more rows remain, pg_stream_complete fires. +actor Client is + (SessionStatusNotify & ResultReceiver & StreamingResultReceiver) + let _session: Session + let _out: OutStream + var _phase: USize = 0 + + new create(auth: lori.TCPConnectAuth, info: ServerInfo, out: OutStream) => + _out = out + _session = Session( + ServerConnectInfo(auth, info.host, info.port), + DatabaseConnectInfo(info.username, info.password, info.database), + this) + + be close() => + _session.close() + + be pg_session_authenticated(session: Session) => + _out.print("Authenticated.") + _phase = 0 + session.execute( + SimpleQuery("DROP TABLE IF EXISTS streaming_example"), this) + + be pg_session_authentication_failed( + s: Session, + reason: AuthenticationFailureReason) + => + _out.print("Failed to authenticate.") + + be pg_query_result(session: Session, result: Result) => + _phase = _phase + 1 + match _phase + | 1 => + _out.print("Creating table...") + _session.execute( + SimpleQuery( + """ + CREATE TABLE streaming_example ( + id INT NOT NULL, + name VARCHAR(50) NOT NULL + 
) + """), + this) + | 2 => + _out.print("Inserting rows...") + _session.execute( + SimpleQuery( + """ + INSERT INTO streaming_example VALUES + (1, 'alpha'), (2, 'bravo'), (3, 'charlie'), + (4, 'delta'), (5, 'echo'), (6, 'foxtrot'), (7, 'golf') + """), + this) + | 3 => + _out.print("Starting stream with window_size=3...") + _session.stream( + PreparedQuery( + "SELECT id, name FROM streaming_example ORDER BY id", + recover val Array[(String | None)] end), + 3, this) + | 5 => + _out.print("Done.") + close() + end + + be pg_query_failed(session: Session, query: Query, + failure: (ErrorResponseMessage | ClientQueryError)) + => + match failure + | let e: ErrorResponseMessage => + _out.print("Query failed: [" + e.severity + "] " + e.code + ": " + + e.message) + | let e: ClientQueryError => + _out.print("Query failed: client error") + end + close() + + be pg_stream_batch(session: Session, rows: Rows) => + _out.print(" Batch (" + rows.size().string() + " rows):") + for row in rows.values() do + _out.write(" ") + for field in row.fields.values() do + _out.write(" " + field.name + "=") + match field.value + | let v: String => _out.write(v) + | let v: I32 => _out.write(v.string()) + | None => _out.write("NULL") + end + end + _out.print("") + end + session.fetch_more() + + be pg_stream_complete(session: Session) => + _out.print("Stream complete.") + _out.print("Dropping table...") + _phase = _phase + 1 + _session.execute( + SimpleQuery("DROP TABLE streaming_example"), this) + + be pg_stream_failed(session: Session, + query: (PreparedQuery | NamedPreparedQuery), + failure: (ErrorResponseMessage | ClientQueryError)) + => + match failure + | let e: ErrorResponseMessage => + _out.print("Stream failed: [" + e.severity + "] " + e.code + ": " + + e.message) + | let e: ClientQueryError => + _out.print("Stream failed: client error") + end + close() + +class val ServerInfo + let host: String + let port: String + let username: String + let password: String + let database: String + + new 
val create(vars: (Array[String] val | None)) => + let e = EnvVars(vars) + host = try e("POSTGRES_HOST")? else "127.0.0.1" end + port = try e("POSTGRES_PORT")? else "5432" end + username = try e("POSTGRES_USERNAME")? else "postgres" end + password = try e("POSTGRES_PASSWORD")? else "postgres" end + database = try e("POSTGRES_DATABASE")? else "postgres" end diff --git a/postgres/_frontend_message.pony b/postgres/_frontend_message.pony index 778a4e1..8694237 100644 --- a/postgres/_frontend_message.pony +++ b/postgres/_frontend_message.pony @@ -537,6 +537,29 @@ primitive _FrontendMessage [] end + fun flush(): Array[U8] val => + """ + Build a Flush message. Forces the server to deliver any pending output + without ending the current query cycle or producing ReadyForQuery. + + Format: Byte1('H') Int32(4) + """ + try + recover val + let msg: Array[U8] = Array[U8].init(0, 5) + msg.update_u8(0, 'H')? + ifdef bigendian then + msg.update_u32(1, U32(4))? + else + msg.update_u32(1, U32(4).bswap())? + end + msg + end + else + _Unreachable() + [] + end + fun terminate(): Array[U8] val => """ Build a Terminate message. Sent before closing the TCP connection to diff --git a/postgres/_response_message_parser.pony b/postgres/_response_message_parser.pony index 5d270ef..02267f8 100644 --- a/postgres/_response_message_parser.pony +++ b/postgres/_response_message_parser.pony @@ -69,6 +69,8 @@ primitive _ResponseMessageParser s.state.on_parameter_status(s, msg) | let msg: _EmptyQueryResponseMessage => s.state.on_empty_query_response(s) + | _PortalSuspendedMessage => + s.state.on_portal_suspended(s) | None => // No complete message was found. Stop parsing for now. 
      return
diff --git a/postgres/_test.pony b/postgres/_test.pony
index d051660..a62bd1c 100644
--- a/postgres/_test.pony
+++ b/postgres/_test.pony
@@ -66,6 +66,7 @@ actor \nodoc\ Main is TestList
     test(_TestFrontendMessageCloseStatement)
     test(_TestFrontendMessageSync)
     test(_TestFrontendMessageSSLRequest)
+    test(_TestFrontendMessageFlush)
     test(_TestFrontendMessageTerminate)
     test(_TestTerminateSentOnClose)
     test(_TestSSLNegotiationRefused)
@@ -129,6 +130,7 @@ actor \nodoc\ Main is TestList
     test(_TestResponseParserMultipleMessagesChainCloseStatementSequence)
     test(_TestResponseParserMultipleMessagesChainSASLFullSequence)
     test(_TestResponseParserMultipleMessagesChainRemainingTypes)
+    test(_TestResponseParserMultipleMessagesChainStreamingQuerySequence)
     test(_TestFrontendMessageSASLInitialResponse)
     test(_TestFrontendMessageSASLResponse)
     test(_TestScramSha256MessageBuilders)
@@ -176,6 +178,13 @@ actor \nodoc\ Main is TestList
     test(_TestCopyOutShutdownDrainsCopyQueue)
     test(_TestCopyOutAfterSessionClosed)
     test(_TestCopyOutExport)
+    test(_TestStreamingSuccess)
+    test(_TestStreamingEmpty)
+    test(_TestStreamingEarlyStop)
+    test(_TestStreamingServerError)
+    test(_TestStreamingShutdownDrainsQueue)
+    test(_TestStreamingQueryResults)
+    test(_TestStreamingAfterSessionClosed)
 
 class \nodoc\ iso _TestAuthenticate is UnitTest
   """
diff --git a/postgres/_test_frontend_message.pony b/postgres/_test_frontend_message.pony
index dc09b37..d487261 100644
--- a/postgres/_test_frontend_message.pony
+++ b/postgres/_test_frontend_message.pony
@@ -323,6 +323,19 @@ class \nodoc\ iso _TestFrontendMessageCopyFail is UnitTest
     h.assert_array_eq[U8](expected, _FrontendMessage.copy_fail("err"))
 
+class \nodoc\ iso _TestFrontendMessageFlush is UnitTest
+  fun name(): String =>
+    "FrontendMessage/Flush"
+
+  fun apply(h: TestHelper) =>
+    // Flush: Byte1('H') Int32(4) = 5 bytes. The wire format is big-endian
+    // on every platform (the builder byte-swaps on little-endian hosts),
+    // so the expected bytes do not depend on host endianness.
+    let expected: Array[U8] = [ 'H'; 0; 0; 0; 4 ]
+
+    h.assert_array_eq[U8](expected,
+ _FrontendMessage.flush()) + class \nodoc\ iso _TestFrontendMessageTerminate is UnitTest fun name(): String => "FrontendMessage/Terminate" diff --git a/postgres/_test_response_parser.pony b/postgres/_test_response_parser.pony index fad9317..b72e835 100644 --- a/postgres/_test_response_parser.pony +++ b/postgres/_test_response_parser.pony @@ -2118,6 +2118,144 @@ class \nodoc\ iso _TestResponseParserMultipleMessagesChainRemainingTypes h.fail("Buffer not fully consumed.") end +class \nodoc\ iso _TestResponseParserMultipleMessagesChainStreamingQuerySequence + is UnitTest + """ + Verify correct buffer advancement across a streaming query response: + RowDescription + DataRow + PortalSuspended + DataRow + DataRow + + PortalSuspended + DataRow + CommandComplete + ReadyForQuery. + + Streaming uses Execute(max_rows) which produces PortalSuspended after each + batch until the final batch ends with CommandComplete. + """ + fun name(): String => + "ResponseParser/MultipleMessages/Chain/StreamingQuerySequence" + + fun apply(h: TestHelper) ? 
=> + let columns: Array[(String, String)] val = recover val + [("id", "int4"); ("name", "text")] + end + let r: Reader = Reader + r.append(_IncomingRowDescriptionTestMessage(columns)?.bytes()) + // First batch: 1 row + PortalSuspended + r.append(_IncomingDataRowTestMessage( + recover val [as (String | None): "1"; "Alice"] end).bytes()) + r.append(_IncomingPortalSuspendedTestMessage.bytes()) + // Second batch: 2 rows + PortalSuspended + r.append(_IncomingDataRowTestMessage( + recover val [as (String | None): "2"; "Bob"] end).bytes()) + r.append(_IncomingDataRowTestMessage( + recover val [as (String | None): "3"; "Carol"] end).bytes()) + r.append(_IncomingPortalSuspendedTestMessage.bytes()) + // Final batch: 1 row + CommandComplete (no more rows) + r.append(_IncomingDataRowTestMessage( + recover val [as (String | None): "4"; "Dave"] end).bytes()) + r.append(_IncomingCommandCompleteTestMessage("SELECT 4").bytes()) + r.append(_IncomingReadyForQueryTestMessage('I').bytes()) + + match _ResponseParser(r)? + | let m: _RowDescriptionMessage => + h.assert_eq[USize](2, m.columns.size()) + h.assert_eq[String]("id", m.columns(0)?._1) + h.assert_eq[U32](23, m.columns(0)?._2) + h.assert_eq[String]("name", m.columns(1)?._1) + h.assert_eq[U32](25, m.columns(1)?._2) + else + h.fail("Wrong message for RowDescription.") + return + end + + // First batch + match _ResponseParser(r)? + | let m: _DataRowMessage => + h.assert_eq[USize](2, m.columns.size()) + match m.columns(0)? + | "1" => None + else + h.fail("Batch 1 row col 0 not parsed correctly.") + return + end + else + h.fail("Wrong message for first DataRow.") + return + end + + if _ResponseParser(r)? isnt _PortalSuspendedMessage then + h.fail("Wrong message for first PortalSuspended.") + return + end + + // Second batch + match _ResponseParser(r)? + | let m: _DataRowMessage => + h.assert_eq[USize](2, m.columns.size()) + match m.columns(0)? 
+ | "2" => None + else + h.fail("Batch 2 row 1 col 0 not parsed correctly.") + return + end + else + h.fail("Wrong message for second DataRow.") + return + end + + match _ResponseParser(r)? + | let m: _DataRowMessage => + h.assert_eq[USize](2, m.columns.size()) + match m.columns(0)? + | "3" => None + else + h.fail("Batch 2 row 2 col 0 not parsed correctly.") + return + end + else + h.fail("Wrong message for third DataRow.") + return + end + + if _ResponseParser(r)? isnt _PortalSuspendedMessage then + h.fail("Wrong message for second PortalSuspended.") + return + end + + // Final batch + match _ResponseParser(r)? + | let m: _DataRowMessage => + h.assert_eq[USize](2, m.columns.size()) + match m.columns(0)? + | "4" => None + else + h.fail("Batch 3 row col 0 not parsed correctly.") + return + end + else + h.fail("Wrong message for fourth DataRow.") + return + end + + match _ResponseParser(r)? + | let m: _CommandCompleteMessage => + h.assert_eq[String]("SELECT", m.id) + h.assert_eq[USize](4, m.value) + else + h.fail("Wrong message for CommandComplete.") + return + end + + match _ResponseParser(r)? + | let m: _ReadyForQueryMessage => + h.assert_is[TransactionStatus](TransactionIdle, + m.transaction_status()) + else + h.fail("Wrong message for ReadyForQuery.") + return + end + + if _ResponseParser(r)? isnt None then + h.fail("Buffer not fully consumed.") + end + primitive WriterToByteArray fun apply(writer: Writer): Array[U8] val => recover val diff --git a/postgres/_test_streaming.pony b/postgres/_test_streaming.pony new file mode 100644 index 0000000..6433055 --- /dev/null +++ b/postgres/_test_streaming.pony @@ -0,0 +1,1074 @@ +use "collections" +use lori = "lori" +use "pony_test" + +class \nodoc\ iso _TestStreamingSuccess is UnitTest + """ + Verifies the complete streaming success path: authenticate, send a + streaming query, receive two batches via PortalSuspended, then a final + batch via CommandComplete. Verify pg_stream_batch x3 + pg_stream_complete. 
+ """ + fun name(): String => + "Streaming/Success" + + fun apply(h: TestHelper) => + let host = "127.0.0.1" + let port = "7702" + + let listener = _StreamingSuccessTestListener( + lori.TCPListenAuth(h.env.root), + host, + port, + h) + + h.dispose_when_done(listener) + h.long_test(5_000_000_000) + +actor \nodoc\ _StreamingSuccessTestClient is + (SessionStatusNotify & StreamingResultReceiver) + let _h: TestHelper + var _batches: USize = 0 + var _total_rows: USize = 0 + var _session: (Session | None) = None + + new create(h: TestHelper) => + _h = h + + be pg_session_connection_failed(s: Session) => + _h.fail("Unable to establish connection.") + _h.complete(false) + + be pg_session_authenticated(session: Session) => + _session = session + session.stream( + PreparedQuery("SELECT id FROM t", recover val Array[(String | None)] end), + 2, this) + + be pg_session_authentication_failed( + session: Session, + reason: AuthenticationFailureReason) + => + _h.fail("Unable to authenticate.") + _h.complete(false) + + be pg_stream_batch(session: Session, rows: Rows) => + _batches = _batches + 1 + _total_rows = _total_rows + rows.size() + if _batches <= 2 then + session.fetch_more() + end + // Third batch arrives from CommandComplete — no fetch_more needed. 
+ + be pg_stream_complete(session: Session) => + if (_batches == 3) and (_total_rows == 5) then + _close_and_complete(true) + else + _h.fail("Expected 3 batches with 5 total rows but got " + + _batches.string() + " batches with " + _total_rows.string() + + " rows") + _close_and_complete(false) + end + + be pg_stream_failed(session: Session, + query: (PreparedQuery | NamedPreparedQuery), + failure: (ErrorResponseMessage | ClientQueryError)) + => + _h.fail("Unexpected stream failure.") + _close_and_complete(false) + + fun ref _close_and_complete(success: Bool) => + match _session + | let s: Session => s.close() + end + _h.complete(success) + +actor \nodoc\ _StreamingSuccessTestListener is lori.TCPListenerActor + var _tcp_listener: lori.TCPListener = lori.TCPListener.none() + let _server_auth: lori.TCPServerAuth + let _h: TestHelper + let _host: String + let _port: String + + new create(listen_auth: lori.TCPListenAuth, + host: String, + port: String, + h: TestHelper) + => + _host = host + _port = port + _h = h + _server_auth = lori.TCPServerAuth(listen_auth) + _tcp_listener = lori.TCPListener(listen_auth, host, port, this) + + fun ref _listener(): lori.TCPListener => + _tcp_listener + + fun ref _on_accept(fd: U32): _StreamingSuccessTestServer => + _StreamingSuccessTestServer(_server_auth, fd) + + fun ref _on_listening() => + Session( + ServerConnectInfo(lori.TCPConnectAuth(_h.env.root), _host, _port), + DatabaseConnectInfo("postgres", "postgres", "postgres"), + _StreamingSuccessTestClient(_h)) + + fun ref _on_listen_failure() => + _h.fail("Unable to listen") + _h.complete(false) + +actor \nodoc\ _StreamingSuccessTestServer + is (lori.TCPConnectionActor & lori.ServerLifecycleEventReceiver) + """ + Mock server that authenticates, responds to an extended query pipeline + with RowDescription, then simulates streaming: first Execute returns + 2 DataRows + PortalSuspended, second Execute returns 2 DataRows + + PortalSuspended, third Execute returns 1 DataRow + 
CommandComplete + + ReadyForQuery (after Sync). + """ + var _tcp_connection: lori.TCPConnection = lori.TCPConnection.none() + var _state: U8 = 0 + var _execute_count: U8 = 0 + let _reader: _MockMessageReader = _MockMessageReader + + new create(auth: lori.TCPServerAuth, fd: U32) => + _tcp_connection = lori.TCPConnection.server(auth, fd, this, this) + + fun ref _connection(): lori.TCPConnection => + _tcp_connection + + fun ref _on_received(data: Array[U8] iso) => + _reader.append(consume data) + _process() + + fun ref _process() => + if _state == 0 then + // Startup message + match _reader.read_startup_message() + | let _: Array[U8] val => + let auth_ok = _IncomingAuthenticationOkTestMessage.bytes() + let ready = _IncomingReadyForQueryTestMessage('I').bytes() + _tcp_connection.send(auth_ok) + _tcp_connection.send(ready) + _state = 1 + _process() + end + elseif _state == 1 then + // Read all extended query pipeline messages: Parse+Bind+Describe+Execute+Flush + // We read them one at a time until we get Execute ('E') + match _reader.read_message() + | let msg: Array[U8] val => + try + if msg(0)? == 'E' then + // First Execute received. Send RowDescription + 2 DataRows + + // PortalSuspended + _execute_count = _execute_count + 1 + try + let columns: Array[(String, String)] val = recover val + [("id", "int4")] + end + let row_desc = + _IncomingRowDescriptionTestMessage(columns)?.bytes() + _tcp_connection.send(row_desc) + end + _send_data_rows(2, (_execute_count.usize() - 1) * 2) + let ps = _IncomingPortalSuspendedTestMessage.bytes() + _tcp_connection.send(ps) + _state = 2 + else + // Parse, Bind, Describe, Flush — consume and continue + _process() + end + end + end + elseif _state == 2 then + // Subsequent Execute+Flush messages + match _reader.read_message() + | let msg: Array[U8] val => + try + if msg(0)? 
== 'E' then + _execute_count = _execute_count + 1 + if _execute_count == 3 then + // Final batch: 1 DataRow + CommandComplete + _send_data_rows(1, (_execute_count.usize() - 1) * 2) + let cmd_complete = + _IncomingCommandCompleteTestMessage("SELECT 5").bytes() + _tcp_connection.send(cmd_complete) + // Don't send ReadyForQuery yet — wait for Sync + _state = 3 + else + // 2 DataRows + PortalSuspended + _send_data_rows(2, (_execute_count.usize() - 1) * 2) + let ps = _IncomingPortalSuspendedTestMessage.bytes() + _tcp_connection.send(ps) + end + else + // Flush — consume and continue + _process() + end + end + end + elseif _state == 3 then + // Wait for Sync — skip non-Sync messages (e.g. leftover Flush) + match _reader.read_message() + | let msg: Array[U8] val => + try + if msg(0)? == 'S' then + let ready = _IncomingReadyForQueryTestMessage('I').bytes() + _tcp_connection.send(ready) + else + _process() + end + end + end + end + + fun ref _send_data_rows(count: USize, start: USize) => + for i in Range(0, count) do + let data_row_cols: Array[(String | None)] val = recover val + let v: String = (start + i + 1).string() + [as (String | None): v] + end + let data_row = _IncomingDataRowTestMessage(data_row_cols).bytes() + _tcp_connection.send(data_row) + end + +class \nodoc\ iso _TestStreamingEmpty is UnitTest + """ + Verifies that streaming a query returning zero rows delivers + pg_stream_complete without any pg_stream_batch calls. 
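+
+  Sketch of the expected exchange (the mock sends no DataRow and no
+  PortalSuspended for an empty result):
+
+  ```text
+  Execute(2) -> RowDescription, CommandComplete("SELECT 0")
+  Sync       -> ReadyForQuery
+  ```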
+ """ + fun name(): String => + "Streaming/Empty" + + fun apply(h: TestHelper) => + let host = "127.0.0.1" + let port = "7703" + + let listener = _StreamingEmptyTestListener( + lori.TCPListenAuth(h.env.root), + host, + port, + h) + + h.dispose_when_done(listener) + h.long_test(5_000_000_000) + +actor \nodoc\ _StreamingEmptyTestClient is + (SessionStatusNotify & StreamingResultReceiver) + let _h: TestHelper + var _batches: USize = 0 + var _session: (Session | None) = None + + new create(h: TestHelper) => + _h = h + + be pg_session_connection_failed(s: Session) => + _h.fail("Unable to establish connection.") + _h.complete(false) + + be pg_session_authenticated(session: Session) => + _session = session + session.stream( + PreparedQuery("SELECT id FROM empty", recover val Array[(String | None)] end), + 2, this) + + be pg_session_authentication_failed( + session: Session, + reason: AuthenticationFailureReason) + => + _h.fail("Unable to authenticate.") + _h.complete(false) + + be pg_stream_batch(session: Session, rows: Rows) => + _batches = _batches + 1 + + be pg_stream_complete(session: Session) => + if _batches == 0 then + _close_and_complete(true) + else + _h.fail("Expected 0 batches but got " + _batches.string()) + _close_and_complete(false) + end + + be pg_stream_failed(session: Session, + query: (PreparedQuery | NamedPreparedQuery), + failure: (ErrorResponseMessage | ClientQueryError)) + => + _h.fail("Unexpected stream failure.") + _close_and_complete(false) + + fun ref _close_and_complete(success: Bool) => + match _session + | let s: Session => s.close() + end + _h.complete(success) + +actor \nodoc\ _StreamingEmptyTestListener is lori.TCPListenerActor + var _tcp_listener: lori.TCPListener = lori.TCPListener.none() + let _server_auth: lori.TCPServerAuth + let _h: TestHelper + let _host: String + let _port: String + + new create(listen_auth: lori.TCPListenAuth, + host: String, + port: String, + h: TestHelper) + => + _host = host + _port = port + _h = h + 
_server_auth = lori.TCPServerAuth(listen_auth) + _tcp_listener = lori.TCPListener(listen_auth, host, port, this) + + fun ref _listener(): lori.TCPListener => + _tcp_listener + + fun ref _on_accept(fd: U32): _StreamingEmptyTestServer => + _StreamingEmptyTestServer(_server_auth, fd) + + fun ref _on_listening() => + Session( + ServerConnectInfo(lori.TCPConnectAuth(_h.env.root), _host, _port), + DatabaseConnectInfo("postgres", "postgres", "postgres"), + _StreamingEmptyTestClient(_h)) + + fun ref _on_listen_failure() => + _h.fail("Unable to listen") + _h.complete(false) + +actor \nodoc\ _StreamingEmptyTestServer + is (lori.TCPConnectionActor & lori.ServerLifecycleEventReceiver) + """ + Mock server that responds to a streaming query with RowDescription + + CommandComplete("SELECT 0") — zero rows, no PortalSuspended. + """ + var _tcp_connection: lori.TCPConnection = lori.TCPConnection.none() + var _state: U8 = 0 + let _reader: _MockMessageReader = _MockMessageReader + + new create(auth: lori.TCPServerAuth, fd: U32) => + _tcp_connection = lori.TCPConnection.server(auth, fd, this, this) + + fun ref _connection(): lori.TCPConnection => + _tcp_connection + + fun ref _on_received(data: Array[U8] iso) => + _reader.append(consume data) + _process() + + fun ref _process() => + if _state == 0 then + match _reader.read_startup_message() + | let _: Array[U8] val => + let auth_ok = _IncomingAuthenticationOkTestMessage.bytes() + let ready = _IncomingReadyForQueryTestMessage('I').bytes() + _tcp_connection.send(auth_ok) + _tcp_connection.send(ready) + _state = 1 + _process() + end + elseif _state == 1 then + match _reader.read_message() + | let msg: Array[U8] val => + try + if msg(0)? 
== 'E' then + // Execute received — send RowDescription + CommandComplete + try + let columns: Array[(String, String)] val = recover val + [("id", "int4")] + end + let row_desc = + _IncomingRowDescriptionTestMessage(columns)?.bytes() + let cmd_complete = + _IncomingCommandCompleteTestMessage("SELECT 0").bytes() + _tcp_connection.send(row_desc) + _tcp_connection.send(cmd_complete) + end + _state = 2 + else + _process() + end + end + end + elseif _state == 2 then + // Wait for Sync — skip non-Sync messages (e.g. leftover Flush) + match _reader.read_message() + | let msg: Array[U8] val => + try + if msg(0)? == 'S' then + let ready = _IncomingReadyForQueryTestMessage('I').bytes() + _tcp_connection.send(ready) + else + _process() + end + end + end + end + +class \nodoc\ iso _TestStreamingEarlyStop is UnitTest + """ + Verifies that close_stream() ends streaming early. Client receives one + batch, calls close_stream() instead of fetch_more(), and verifies + pg_stream_complete fires. + """ + fun name(): String => + "Streaming/EarlyStop" + + fun apply(h: TestHelper) => + let host = "127.0.0.1" + let port = "7704" + + let listener = _StreamingEarlyStopTestListener( + lori.TCPListenAuth(h.env.root), + host, + port, + h) + + h.dispose_when_done(listener) + h.long_test(5_000_000_000) + +actor \nodoc\ _StreamingEarlyStopTestClient is + (SessionStatusNotify & StreamingResultReceiver) + let _h: TestHelper + var _batches: USize = 0 + var _session: (Session | None) = None + + new create(h: TestHelper) => + _h = h + + be pg_session_connection_failed(s: Session) => + _h.fail("Unable to establish connection.") + _h.complete(false) + + be pg_session_authenticated(session: Session) => + _session = session + session.stream( + PreparedQuery("SELECT id FROM t", recover val Array[(String | None)] end), + 2, this) + + be pg_session_authentication_failed( + session: Session, + reason: AuthenticationFailureReason) + => + _h.fail("Unable to authenticate.") + _h.complete(false) + + be 
pg_stream_batch(session: Session, rows: Rows) => + _batches = _batches + 1 + // Close instead of fetching more + session.close_stream() + + be pg_stream_complete(session: Session) => + if _batches == 1 then + _close_and_complete(true) + else + _h.fail("Expected 1 batch but got " + _batches.string()) + _close_and_complete(false) + end + + be pg_stream_failed(session: Session, + query: (PreparedQuery | NamedPreparedQuery), + failure: (ErrorResponseMessage | ClientQueryError)) + => + _h.fail("Unexpected stream failure.") + _close_and_complete(false) + + fun ref _close_and_complete(success: Bool) => + match _session + | let s: Session => s.close() + end + _h.complete(success) + +actor \nodoc\ _StreamingEarlyStopTestListener is lori.TCPListenerActor + var _tcp_listener: lori.TCPListener = lori.TCPListener.none() + let _server_auth: lori.TCPServerAuth + let _h: TestHelper + let _host: String + let _port: String + + new create(listen_auth: lori.TCPListenAuth, + host: String, + port: String, + h: TestHelper) + => + _host = host + _port = port + _h = h + _server_auth = lori.TCPServerAuth(listen_auth) + _tcp_listener = lori.TCPListener(listen_auth, host, port, this) + + fun ref _listener(): lori.TCPListener => + _tcp_listener + + fun ref _on_accept(fd: U32): _StreamingEarlyStopTestServer => + _StreamingEarlyStopTestServer(_server_auth, fd) + + fun ref _on_listening() => + Session( + ServerConnectInfo(lori.TCPConnectAuth(_h.env.root), _host, _port), + DatabaseConnectInfo("postgres", "postgres", "postgres"), + _StreamingEarlyStopTestClient(_h)) + + fun ref _on_listen_failure() => + _h.fail("Unable to listen") + _h.complete(false) + +actor \nodoc\ _StreamingEarlyStopTestServer + is (lori.TCPConnectionActor & lori.ServerLifecycleEventReceiver) + """ + Mock server that authenticates, responds to first Execute with 2 DataRows + + PortalSuspended, then responds to Sync (from close_stream) with + ReadyForQuery. 
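+
+  Sketch of the exchange this mock implements (anything the client sends
+  before the Sync, e.g. a leftover Flush, is skipped):
+
+  ```text
+  Execute(2) -> RowDescription, DataRow("1"), DataRow("2"), PortalSuspended
+  Sync       -> ReadyForQuery
+  ```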
+ """ + var _tcp_connection: lori.TCPConnection = lori.TCPConnection.none() + var _state: U8 = 0 + let _reader: _MockMessageReader = _MockMessageReader + + new create(auth: lori.TCPServerAuth, fd: U32) => + _tcp_connection = lori.TCPConnection.server(auth, fd, this, this) + + fun ref _connection(): lori.TCPConnection => + _tcp_connection + + fun ref _on_received(data: Array[U8] iso) => + _reader.append(consume data) + _process() + + fun ref _process() => + if _state == 0 then + match _reader.read_startup_message() + | let _: Array[U8] val => + let auth_ok = _IncomingAuthenticationOkTestMessage.bytes() + let ready = _IncomingReadyForQueryTestMessage('I').bytes() + _tcp_connection.send(auth_ok) + _tcp_connection.send(ready) + _state = 1 + _process() + end + elseif _state == 1 then + match _reader.read_message() + | let msg: Array[U8] val => + try + if msg(0)? == 'E' then + try + let columns: Array[(String, String)] val = recover val + [("id", "int4")] + end + let row_desc = + _IncomingRowDescriptionTestMessage(columns)?.bytes() + _tcp_connection.send(row_desc) + end + let data_row_cols: Array[(String | None)] val = recover val + [as (String | None): "1"] + end + let data_row = _IncomingDataRowTestMessage(data_row_cols).bytes() + _tcp_connection.send(data_row) + let data_row_cols2: Array[(String | None)] val = recover val + [as (String | None): "2"] + end + let data_row2 = _IncomingDataRowTestMessage(data_row_cols2).bytes() + _tcp_connection.send(data_row2) + let ps = _IncomingPortalSuspendedTestMessage.bytes() + _tcp_connection.send(ps) + _state = 2 + else + _process() + end + end + end + elseif _state == 2 then + // Wait for Sync (from close_stream) — skip non-Sync messages + // (e.g. leftover Flush) + match _reader.read_message() + | let msg: Array[U8] val => + try + if msg(0)? 
== 'S' then + let ready = _IncomingReadyForQueryTestMessage('I').bytes() + _tcp_connection.send(ready) + else + _process() + end + end + end + end + +class \nodoc\ iso _TestStreamingServerError is UnitTest + """ + Verifies that an ErrorResponse during streaming delivers pg_stream_failed + and the session remains usable for subsequent queries. + """ + fun name(): String => + "Streaming/ServerError" + + fun apply(h: TestHelper) => + let host = "127.0.0.1" + let port = "7705" + + let listener = _StreamingServerErrorTestListener( + lori.TCPListenAuth(h.env.root), + host, + port, + h) + + h.dispose_when_done(listener) + h.long_test(5_000_000_000) + +actor \nodoc\ _StreamingServerErrorTestClient is + (SessionStatusNotify & StreamingResultReceiver & ResultReceiver) + let _h: TestHelper + var _session: (Session | None) = None + var _stream_failed: Bool = false + + new create(h: TestHelper) => + _h = h + + be pg_session_connection_failed(s: Session) => + _h.fail("Unable to establish connection.") + _h.complete(false) + + be pg_session_authenticated(session: Session) => + _session = session + session.stream( + PreparedQuery("SELECT id FROM bad", recover val Array[(String | None)] end), + 2, this) + + be pg_session_authentication_failed( + session: Session, + reason: AuthenticationFailureReason) + => + _h.fail("Unable to authenticate.") + _h.complete(false) + + be pg_stream_batch(session: Session, rows: Rows) => + _h.fail("Unexpected stream batch.") + _close_and_complete(false) + + be pg_stream_complete(session: Session) => + _h.fail("Unexpected stream complete.") + _close_and_complete(false) + + be pg_stream_failed(session: Session, + query: (PreparedQuery | NamedPreparedQuery), + failure: (ErrorResponseMessage | ClientQueryError)) + => + _stream_failed = true + // Verify session is still usable with a follow-up query + session.execute(SimpleQuery("SELECT 1"), this) + + be pg_query_result(session: Session, result: Result) => + if _stream_failed then + 
_close_and_complete(true) + else + _h.fail("Unexpected query result before stream failure.") + _close_and_complete(false) + end + + be pg_query_failed(session: Session, query: Query, + failure: (ErrorResponseMessage | ClientQueryError)) + => + _h.fail("Follow-up query failed.") + _close_and_complete(false) + + fun ref _close_and_complete(success: Bool) => + match _session + | let s: Session => s.close() + end + _h.complete(success) + +actor \nodoc\ _StreamingServerErrorTestListener is lori.TCPListenerActor + var _tcp_listener: lori.TCPListener = lori.TCPListener.none() + let _server_auth: lori.TCPServerAuth + let _h: TestHelper + let _host: String + let _port: String + + new create(listen_auth: lori.TCPListenAuth, + host: String, + port: String, + h: TestHelper) + => + _host = host + _port = port + _h = h + _server_auth = lori.TCPServerAuth(listen_auth) + _tcp_listener = lori.TCPListener(listen_auth, host, port, this) + + fun ref _listener(): lori.TCPListener => + _tcp_listener + + fun ref _on_accept(fd: U32): _StreamingServerErrorTestServer => + _StreamingServerErrorTestServer(_server_auth, fd) + + fun ref _on_listening() => + Session( + ServerConnectInfo(lori.TCPConnectAuth(_h.env.root), _host, _port), + DatabaseConnectInfo("postgres", "postgres", "postgres"), + _StreamingServerErrorTestClient(_h)) + + fun ref _on_listen_failure() => + _h.fail("Unable to listen") + _h.complete(false) + +actor \nodoc\ _StreamingServerErrorTestServer + is (lori.TCPConnectionActor & lori.ServerLifecycleEventReceiver) + """ + Mock server that authenticates, responds to the streaming query pipeline + with ErrorResponse (before any data), then responds to the follow-up + simple query normally. 
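+
+  Sketch of the exchange this mock implements:
+
+  ```text
+  Execute           -> ErrorResponse (SQLSTATE 42P01)
+  Sync              -> ReadyForQuery
+  Query("SELECT 1") -> RowDescription, DataRow("1"),
+                       CommandComplete("SELECT 1"), ReadyForQuery
+  ```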
+ """ + var _tcp_connection: lori.TCPConnection = lori.TCPConnection.none() + var _state: U8 = 0 + var _error_sent: Bool = false + let _reader: _MockMessageReader = _MockMessageReader + + new create(auth: lori.TCPServerAuth, fd: U32) => + _tcp_connection = lori.TCPConnection.server(auth, fd, this, this) + + fun ref _connection(): lori.TCPConnection => + _tcp_connection + + fun ref _on_received(data: Array[U8] iso) => + _reader.append(consume data) + _process() + + fun ref _process() => + if _state == 0 then + match _reader.read_startup_message() + | let _: Array[U8] val => + let auth_ok = _IncomingAuthenticationOkTestMessage.bytes() + let ready = _IncomingReadyForQueryTestMessage('I').bytes() + _tcp_connection.send(auth_ok) + _tcp_connection.send(ready) + _state = 1 + _process() + end + elseif _state == 1 then + // Read pipeline messages until we get Execute + match _reader.read_message() + | let msg: Array[U8] val => + try + if msg(0)? == 'E' then + // Send ErrorResponse immediately + let err = _IncomingErrorResponseTestMessage( + "ERROR", "42P01", "relation does not exist").bytes() + _tcp_connection.send(err) + _error_sent = true + _state = 2 + else + _process() + end + end + end + elseif _state == 2 then + // Wait for Sync from on_error_response, respond with ReadyForQuery + match _reader.read_message() + | let msg: Array[U8] val => + try + if msg(0)? == 'S' then + let ready = _IncomingReadyForQueryTestMessage('I').bytes() + _tcp_connection.send(ready) + _state = 3 + end + end + _process() + end + elseif _state == 3 then + // Follow-up simple query + match _reader.read_message() + | let msg: Array[U8] val => + try + if msg(0)? 
== 'Q' then + try + let columns: Array[(String, String)] val = recover val + [("?column?", "text")] + end + let row_desc = + _IncomingRowDescriptionTestMessage(columns)?.bytes() + let data_row_cols: Array[(String | None)] val = recover val + [as (String | None): "1"] + end + let data_row = + _IncomingDataRowTestMessage(data_row_cols).bytes() + let cmd_complete = + _IncomingCommandCompleteTestMessage("SELECT 1").bytes() + let ready = _IncomingReadyForQueryTestMessage('I').bytes() + _tcp_connection.send(row_desc) + _tcp_connection.send(data_row) + _tcp_connection.send(cmd_complete) + _tcp_connection.send(ready) + end + end + end + end + end + +class \nodoc\ iso _TestStreamingShutdownDrainsQueue is UnitTest + """ + Verifies that when a session shuts down, pending stream() calls receive + pg_stream_failed with SessionClosed. Uses a misbehaving server that + authenticates but never sends ReadyForQuery. + """ + fun name(): String => + "Streaming/ShutdownDrainsQueue" + + fun apply(h: TestHelper) => + let host = "127.0.0.1" + let port = "7706" + + let listener = _StreamingShutdownTestListener( + lori.TCPListenAuth(h.env.root), + host, + port, + h) + + h.dispose_when_done(listener) + h.long_test(5_000_000_000) + +actor \nodoc\ _StreamingShutdownTestClient is + (SessionStatusNotify & StreamingResultReceiver) + let _h: TestHelper + var _pending: USize = 0 + + new create(h: TestHelper) => + _h = h + + be pg_session_connection_failed(s: Session) => + _h.fail("Unable to establish connection.") + _h.complete(false) + + be pg_session_authenticated(session: Session) => + _pending = 2 + session.stream( + PreparedQuery("SELECT 1", recover val Array[(String | None)] end), + 2, this) + session.stream( + PreparedQuery("SELECT 2", recover val Array[(String | None)] end), + 2, this) + session.close() + + be pg_session_authentication_failed( + session: Session, + reason: AuthenticationFailureReason) + => + _h.fail("Unable to authenticate.") + _h.complete(false) + + be 
pg_stream_batch(session: Session, rows: Rows) => + _h.fail("Unexpected stream batch.") + _h.complete(false) + + be pg_stream_complete(session: Session) => + _h.fail("Unexpected stream complete.") + _h.complete(false) + + be pg_stream_failed(session: Session, + query: (PreparedQuery | NamedPreparedQuery), + failure: (ErrorResponseMessage | ClientQueryError)) + => + match failure + | SessionClosed => + _pending = _pending - 1 + if _pending == 0 then + _h.complete(true) + end + else + _h.fail("Got an incorrect stream failure reason.") + _h.complete(false) + end + +actor \nodoc\ _StreamingShutdownTestListener is lori.TCPListenerActor + var _tcp_listener: lori.TCPListener = lori.TCPListener.none() + let _server_auth: lori.TCPServerAuth + let _h: TestHelper + let _host: String + let _port: String + + new create(listen_auth: lori.TCPListenAuth, + host: String, + port: String, + h: TestHelper) + => + _host = host + _port = port + _h = h + _server_auth = lori.TCPServerAuth(listen_auth) + _tcp_listener = lori.TCPListener(listen_auth, host, port, this) + + fun ref _listener(): lori.TCPListener => + _tcp_listener + + fun ref _on_accept(fd: U32): _DoesntAnswerTestServer => + _DoesntAnswerTestServer(_server_auth, fd) + + fun ref _on_listening() => + Session( + ServerConnectInfo(lori.TCPConnectAuth(_h.env.root), _host, _port), + DatabaseConnectInfo("postgres", "postgres", "postgres"), + _StreamingShutdownTestClient(_h)) + + fun ref _on_listen_failure() => + _h.fail("Unable to listen") + _h.complete(false) + +class \nodoc\ iso _TestStreamingQueryResults is UnitTest + """ + Integration test: create table, insert 5 rows, stream with window_size=2. + Verify 3 batches (2+2+1 rows), all rows received, pg_stream_complete. 
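+
+  With window_size=2, the exchange on the wire is expected to follow the
+  extended query protocol's Execute(max_rows) + PortalSuspended mechanism,
+  roughly:
+
+  ```text
+  Execute(2) -> DataRow x2, PortalSuspended  (batch 1; client calls fetch_more)
+  Execute(2) -> DataRow x2, PortalSuspended  (batch 2; client calls fetch_more)
+  Execute(2) -> DataRow x1, CommandComplete  (batch 3; pg_stream_complete)
+  ```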
+ """ + fun name(): String => + "integration/Streaming/QueryResults" + + fun apply(h: TestHelper) => + let info = _ConnectionTestConfiguration(h.env.vars) + + let session = Session( + ServerConnectInfo(lori.TCPConnectAuth(h.env.root), info.host, info.port), + DatabaseConnectInfo(info.username, info.password, info.database), + _StreamingQueryResultsNotify(h)) + + h.dispose_when_done(session) + h.long_test(10_000_000_000) + +actor \nodoc\ _StreamingQueryResultsNotify is + (SessionStatusNotify & ResultReceiver & StreamingResultReceiver) + let _h: TestHelper + var _phase: USize = 0 + var _batches: USize = 0 + var _total_rows: USize = 0 + var _session: (Session | None) = None + + new create(h: TestHelper) => + _h = h + + be pg_session_authenticated(session: Session) => + _session = session + // Phase 0: drop table if exists + session.execute( + SimpleQuery("DROP TABLE IF EXISTS streaming_test"), this) + + be pg_session_authentication_failed( + session: Session, + reason: AuthenticationFailureReason) + => + _h.fail("Unable to authenticate.") + _h.complete(false) + + be pg_query_result(session: Session, result: Result) => + _phase = _phase + 1 + match _phase + | 1 => + // Table dropped. Create it. + session.execute( + SimpleQuery("CREATE TABLE streaming_test (id INT NOT NULL)"), this) + | 2 => + // Table created. Insert 5 rows. + session.execute( + SimpleQuery( + "INSERT INTO streaming_test VALUES (1),(2),(3),(4),(5)"), this) + | 3 => + // Rows inserted. Start streaming. + session.stream( + PreparedQuery("SELECT id FROM streaming_test ORDER BY id", + recover val Array[(String | None)] end), + 2, this) + | 5 => + // Table dropped after streaming. Done. 
+ _close_and_complete(true) + end + + be pg_query_failed(session: Session, query: Query, + failure: (ErrorResponseMessage | ClientQueryError)) + => + _h.fail("Query failed.") + _close_and_complete(false) + + be pg_stream_batch(session: Session, rows: Rows) => + _batches = _batches + 1 + _total_rows = _total_rows + rows.size() + session.fetch_more() + + be pg_stream_complete(session: Session) => + if (_batches == 3) and (_total_rows == 5) then + // Drop the table (phase becomes 4, then pg_query_result gets phase 5) + _phase = _phase + 1 + session.execute( + SimpleQuery("DROP TABLE streaming_test"), this) + else + _h.fail("Expected 3 batches with 5 total rows but got " + + _batches.string() + " batches with " + _total_rows.string() + + " rows") + _close_and_complete(false) + end + + be pg_stream_failed(session: Session, + query: (PreparedQuery | NamedPreparedQuery), + failure: (ErrorResponseMessage | ClientQueryError)) + => + _h.fail("Unexpected stream failure.") + _close_and_complete(false) + + fun ref _close_and_complete(success: Bool) => + match _session + | let s: Session => s.close() + end + _h.complete(success) + +class \nodoc\ iso _TestStreamingAfterSessionClosed is UnitTest + """ + Verifies that calling stream() after the session has been closed delivers + pg_stream_failed with SessionClosed. 
+ """ + fun name(): String => + "integration/Streaming/AfterSessionClosed" + + fun apply(h: TestHelper) => + let info = _ConnectionTestConfiguration(h.env.vars) + + let session = Session( + ServerConnectInfo(lori.TCPConnectAuth(h.env.root), info.host, info.port), + DatabaseConnectInfo(info.username, info.password, info.database), + _StreamingAfterSessionClosedNotify(h)) + + h.dispose_when_done(session) + h.long_test(5_000_000_000) + +actor \nodoc\ _StreamingAfterSessionClosedNotify is + (SessionStatusNotify & StreamingResultReceiver) + let _h: TestHelper + + new create(h: TestHelper) => + _h = h + + be pg_session_authenticated(session: Session) => + session.close() + + be pg_session_authentication_failed( + session: Session, + reason: AuthenticationFailureReason) + => + _h.fail("Unexpected authentication failure") + + be pg_session_shutdown(session: Session) => + session.stream( + PreparedQuery("SELECT 1", recover val Array[(String | None)] end), + 2, this) + + be pg_stream_batch(session: Session, rows: Rows) => + _h.fail("Unexpected stream batch.") + _h.complete(false) + + be pg_stream_complete(session: Session) => + _h.fail("Unexpected stream complete.") + _h.complete(false) + + be pg_stream_failed(session: Session, + query: (PreparedQuery | NamedPreparedQuery), + failure: (ErrorResponseMessage | ClientQueryError)) + => + if failure is SessionClosed then + _h.complete(true) + else + _h.fail("Expected SessionClosed but got a different failure.") + _h.complete(false) + end diff --git a/postgres/postgres.pony b/postgres/postgres.pony index 7faa7b4..b6a8c44 100644 --- a/postgres/postgres.pony +++ b/postgres/postgres.pony @@ -201,6 +201,35 @@ be pg_copy_complete(session: Session, count: USize) => _env.out.print("Exported " + count.string() + " rows") ``` +## Row Streaming + +`session.stream()` delivers rows in windowed batches using the +extended query protocol's portal suspension mechanism. 
Unlike + `execute()`, which buffers all rows before delivery, streaming enables + pull-based paged result consumption with bounded memory: + + ```pony + be pg_session_authenticated(session: Session) => + session.stream( + PreparedQuery("SELECT * FROM big_table", + recover val Array[(String | None)] end), + 100, this) // window_size = 100 rows per batch + + be pg_stream_batch(session: Session, rows: Rows) => + // Process this batch of up to 100 rows + for row in rows.values() do + // ... + end + session.fetch_more() // Pull the next batch + + be pg_stream_complete(session: Session) => + _env.out.print("All rows processed") + ``` + + Call `session.close_stream()` to end streaming early. Only + `PreparedQuery` and `NamedPreparedQuery` are supported — streaming + requires the extended query protocol. + ## Query Cancellation @@ -227,6 +256,7 @@ supported. * NoticeResponse delivery (non-fatal server messages) * COPY FROM STDIN (bulk data loading) * COPY TO STDOUT (bulk data export) +* Row streaming (windowed batch delivery) * Query cancellation * ParameterStatus tracking (server runtime parameters) """ diff --git a/postgres/session.pony b/postgres/session.pony index 4719265..0fdea7c 100644 --- a/postgres/session.pony +++ b/postgres/session.pony @@ -113,6 +113,39 @@ actor Session is (lori.TCPConnectionActor & lori.ClientLifecycleEventReceiver) """ state.copy_out(this, sql, receiver) + be stream(query: (PreparedQuery | NamedPreparedQuery), + window_size: U32, receiver: StreamingResultReceiver) + => + """ + Start a streaming query that delivers rows in windowed batches via + `StreamingResultReceiver`. Each batch contains up to `window_size` rows. + Call `fetch_more()` from `pg_stream_batch` to pull the next batch, or + `close_stream()` to end early.
+ + Only `PreparedQuery` and `NamedPreparedQuery` are supported — streaming + uses the extended query protocol's `Execute(max_rows)` + `PortalSuspended` + mechanism which requires a prepared statement. + """ + state.stream(this, query, window_size, receiver) + + be fetch_more() => + """ + Request the next batch of rows during a streaming query. The next + `pg_stream_batch` callback delivers the rows. Safe to call at any + time — no-op if no streaming query is active, if the stream has + already completed naturally, or if the stream has already failed. + """ + state.fetch_more(this) + + be close_stream() => + """ + End a streaming query early. The `pg_stream_complete` callback fires + when the server acknowledges the close. Safe to call at any time — + no-op if no streaming query is active, if the stream has already + completed naturally, or if the stream has already failed. + """ + state.close_stream(this) + be close() => """ Close the connection. Sends a Terminate message to the server before @@ -185,6 +218,12 @@ class ref _SessionUnopened is _ConnectableState => receiver.pg_copy_failed(s, SessionNeverOpened) + fun ref stream(s: Session ref, + query: (PreparedQuery | NamedPreparedQuery), + window_size: U32, receiver: StreamingResultReceiver) + => + receiver.pg_stream_failed(s, query, SessionNeverOpened) + fun ref close_statement(s: Session ref, name: String) => None @@ -219,6 +258,12 @@ class ref _SessionClosed is (_NotConnectableState & _UnconnectedState) => receiver.pg_copy_failed(s, SessionClosed) + fun ref stream(s: Session ref, + query: (PreparedQuery | NamedPreparedQuery), + window_size: U32, receiver: StreamingResultReceiver) + => + receiver.pg_stream_failed(s, query, SessionClosed) + fun ref close_statement(s: Session ref, name: String) => None @@ -310,6 +355,12 @@ class ref _SessionSSLNegotiating => receiver.pg_copy_failed(s, SessionNotAuthenticated) + fun ref stream(s: Session ref, + query: (PreparedQuery | NamedPreparedQuery), + window_size: U32, 
receiver: StreamingResultReceiver) + => + receiver.pg_stream_failed(s, query, SessionNotAuthenticated) + fun ref close_statement(s: Session ref, name: String) => None @@ -322,6 +373,12 @@ class ref _SessionSSLNegotiating fun ref abort_copy(s: Session ref, reason: String) => None + fun ref fetch_more(s: Session ref) => + None + + fun ref close_stream(s: Session ref) => + None + fun ref on_notice(s: Session ref, msg: NoticeResponseMessage) => _IllegalState() @@ -330,6 +387,9 @@ class ref _SessionSSLNegotiating => _IllegalState() + fun ref on_portal_suspended(s: Session ref) => + _IllegalState() + fun ref cancel(s: Session ref) => None @@ -382,6 +442,12 @@ class ref _SessionConnected is _AuthenticableState => receiver.pg_copy_failed(s, SessionNotAuthenticated) + fun ref stream(s: Session ref, + query: (PreparedQuery | NamedPreparedQuery), + window_size: U32, receiver: StreamingResultReceiver) + => + receiver.pg_stream_failed(s, query, SessionNotAuthenticated) + fun ref close_statement(s: Session ref, name: String) => None @@ -541,6 +607,12 @@ class ref _SessionSCRAMAuthenticating is (_ConnectedState & _NotAuthenticated) => receiver.pg_copy_failed(s, SessionNotAuthenticated) + fun ref stream(s: Session ref, + query: (PreparedQuery | NamedPreparedQuery), + window_size: U32, receiver: StreamingResultReceiver) + => + receiver.pg_stream_failed(s, query, SessionNotAuthenticated) + fun ref close_statement(s: Session ref, name: String) => None @@ -593,12 +665,28 @@ class val _QueuedCopyOut sql = sql' receiver = receiver' +class val _QueuedStreamingQuery + """ + A queued streaming query operation waiting to be dispatched. 
+ """ + let query: (PreparedQuery | NamedPreparedQuery) + let window_size: U32 + let receiver: StreamingResultReceiver + + new val create(query': (PreparedQuery | NamedPreparedQuery), + window_size': U32, receiver': StreamingResultReceiver) + => + query = query' + window_size = window_size' + receiver = receiver' + type _QueueItem is ( _QueuedQuery | _QueuedPrepare | _QueuedCloseStatement | _QueuedCopyIn - | _QueuedCopyOut ) + | _QueuedCopyOut + | _QueuedStreamingQuery ) class _SessionLoggedIn is _AuthenticatedState """ @@ -648,6 +736,9 @@ class _SessionLoggedIn is _AuthenticatedState fun ref on_row_description(s: Session ref, msg: _RowDescriptionMessage) => query_state.on_row_description(s, this, msg) + fun ref on_portal_suspended(s: Session ref) => + query_state.on_portal_suspended(s, this) + fun ref cancel(s: Session ref) => match query_state | let _: _QueryReady => None @@ -682,6 +773,23 @@ class _SessionLoggedIn is _AuthenticatedState query_queue.push(_QueuedCopyOut(sql, receiver)) query_state.try_run_query(s, this) + fun ref stream(s: Session ref, + query: (PreparedQuery | NamedPreparedQuery), + window_size: U32, receiver: StreamingResultReceiver) + => + query_queue.push(_QueuedStreamingQuery(query, window_size, receiver)) + query_state.try_run_query(s, this) + + fun ref fetch_more(s: Session ref) => + match query_state + | let sq: _StreamingQueryInFlight => sq.fetch_more(s, this) + end + + fun ref close_stream(s: Session ref) => + match query_state + | let sq: _StreamingQueryInFlight => sq.close_stream(s) + end + fun ref send_copy_data(s: Session ref, data: Array[U8] val) => match query_state | let c: _CopyInInFlight => c.send_copy_data(s, this, data) @@ -730,6 +838,8 @@ class _SessionLoggedIn is _AuthenticatedState ci.receiver.pg_copy_failed(s, SessionClosed) | let co: _QueuedCopyOut => co.receiver.pg_copy_failed(s, SessionClosed) + | let sq: _QueuedStreamingQuery => + sq.receiver.pg_stream_failed(s, sq.query, SessionClosed) end end query_queue.clear() @@ 
-768,6 +878,10 @@ interface _QueryState fun ref on_copy_data(s: Session ref, li: _SessionLoggedIn ref, msg: _CopyDataMessage) fun ref on_copy_done(s: Session ref, li: _SessionLoggedIn ref) + fun ref on_portal_suspended(s: Session ref, li: _SessionLoggedIn ref) + """ + Called when a portal is suspended during streaming (more rows available). + """ fun ref try_run_query(s: Session ref, li: _SessionLoggedIn ref) fun ref drain_in_flight(s: Session ref, li: _SessionLoggedIn ref) @@ -795,6 +909,8 @@ trait _QueryNoQueryInFlight is _QueryState msg: _CopyDataMessage) => li.shutdown(s) fun ref on_copy_done(s: Session ref, li: _SessionLoggedIn ref) => li.shutdown(s) + fun ref on_portal_suspended(s: Session ref, li: _SessionLoggedIn ref) => + li.shutdown(s) fun ref try_run_query(s: Session ref, li: _SessionLoggedIn ref) => None fun ref drain_in_flight(s: Session ref, li: _SessionLoggedIn ref) => None @@ -904,6 +1020,52 @@ class _QueryReady is _QueryNoQueryInFlight | let co: _QueuedCopyOut => li.query_state = _CopyOutInFlight s._connection().send(_FrontendMessage.query(co.sql)) + | let sq: _QueuedStreamingQuery => + li.query_state = _StreamingQueryInFlight.create() + match sq.query + | let pq: PreparedQuery => + let parse = _FrontendMessage.parse("", pq.string, + recover val Array[U32] end) + let bind = _FrontendMessage.bind("", "", pq.params) + let describe = _FrontendMessage.describe_portal("") + let execute = _FrontendMessage.execute_msg("", sq.window_size) + let flush_msg = _FrontendMessage.flush() + let combined = recover val + let total = parse.size() + bind.size() + describe.size() + + execute.size() + flush_msg.size() + let buf = Array[U8](total) + buf.copy_from(parse, 0, 0, parse.size()) + buf.copy_from(bind, 0, parse.size(), bind.size()) + buf.copy_from(describe, 0, + parse.size() + bind.size(), describe.size()) + buf.copy_from(execute, 0, + parse.size() + bind.size() + describe.size(), execute.size()) + buf.copy_from(flush_msg, 0, + parse.size() + bind.size() + 
describe.size() + execute.size(), + flush_msg.size()) + buf + end + s._connection().send(consume combined) + | let nq: NamedPreparedQuery => + let bind = _FrontendMessage.bind("", nq.name, nq.params) + let describe = _FrontendMessage.describe_portal("") + let execute = _FrontendMessage.execute_msg("", sq.window_size) + let flush_msg = _FrontendMessage.flush() + let combined = recover val + let total = bind.size() + describe.size() + + execute.size() + flush_msg.size() + let buf = Array[U8](total) + buf.copy_from(bind, 0, 0, bind.size()) + buf.copy_from(describe, 0, bind.size(), describe.size()) + buf.copy_from(execute, 0, + bind.size() + describe.size(), execute.size()) + buf.copy_from(flush_msg, 0, + bind.size() + describe.size() + execute.size(), + flush_msg.size()) + buf + end + s._connection().send(consume combined) + end end end else @@ -1037,6 +1199,10 @@ class _SimpleQueryInFlight is _QueryState fun ref on_copy_done(s: Session ref, li: _SessionLoggedIn ref) => li.shutdown(s) + fun ref on_portal_suspended(s: Session ref, li: _SessionLoggedIn ref) => + li.shutdown(s) + + fun ref drain_in_flight(s: Session ref, li: _SessionLoggedIn ref) => if not _error then try @@ -1188,6 +1354,10 @@ class _ExtendedQueryInFlight is _QueryState fun ref on_copy_done(s: Session ref, li: _SessionLoggedIn ref) => li.shutdown(s) + fun ref on_portal_suspended(s: Session ref, li: _SessionLoggedIn ref) => + li.shutdown(s) + + fun ref drain_in_flight(s: Session ref, li: _SessionLoggedIn ref) => if not _error then try @@ -1292,6 +1462,10 @@ class _PrepareInFlight is _QueryState fun ref on_copy_done(s: Session ref, li: _SessionLoggedIn ref) => li.shutdown(s) + fun ref on_portal_suspended(s: Session ref, li: _SessionLoggedIn ref) => + li.shutdown(s) + + fun ref drain_in_flight(s: Session ref, li: _SessionLoggedIn ref) => if not _error then try @@ -1371,6 +1545,10 @@ class _CloseStatementInFlight is _QueryState fun ref on_copy_done(s: Session ref, li: _SessionLoggedIn ref) => li.shutdown(s) 
+ fun ref on_portal_suspended(s: Session ref, li: _SessionLoggedIn ref) => + li.shutdown(s) + + fun ref drain_in_flight(s: Session ref, li: _SessionLoggedIn ref) => try li.query_queue.shift()? @@ -1473,6 +1651,10 @@ class _CopyInInFlight is _QueryState fun ref on_copy_done(s: Session ref, li: _SessionLoggedIn ref) => li.shutdown(s) + fun ref on_portal_suspended(s: Session ref, li: _SessionLoggedIn ref) => + li.shutdown(s) + + fun ref try_run_query(s: Session ref, li: _SessionLoggedIn ref) => None fun ref drain_in_flight(s: Session ref, li: _SessionLoggedIn ref) => @@ -1573,6 +1755,10 @@ class _CopyOutInFlight is _QueryState => li.shutdown(s) + fun ref on_portal_suspended(s: Session ref, li: _SessionLoggedIn ref) => + li.shutdown(s) + + fun ref try_run_query(s: Session ref, li: _SessionLoggedIn ref) => None fun ref drain_in_flight(s: Session ref, li: _SessionLoggedIn ref) => @@ -1590,6 +1776,179 @@ class _CopyOutInFlight is _QueryState _Unreachable() end +class _StreamingQueryInFlight is _QueryState + """ + Streaming query in progress. Delivers rows in windowed batches via + `StreamingResultReceiver`. Uses Execute(max_rows > 0) + Flush to keep the + portal alive between batches, with Sync sent only on completion or error + to trigger ReadyForQuery. `_completing` guards against `fetch_more` and + `close_stream` sending messages after `on_command_complete` has already + sent Sync — the receiver may call `fetch_more()` in response to the + final `pg_stream_batch` before `ReadyForQuery` arrives. 
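+
+  A sketch of the expected wire exchange (message names as used in this
+  file; the initial `Parse` is skipped for `NamedPreparedQuery`, and the
+  final batch is only delivered if rows remain):
+
+  ```
+  -> Parse, Bind, Describe, Execute(window_size), Flush
+  <- RowDescription, DataRow*, PortalSuspended   => pg_stream_batch
+  -> Execute(window_size), Flush                    (fetch_more)
+  <- DataRow*, CommandComplete                   => final pg_stream_batch
+  -> Sync
+  <- ReadyForQuery                               => pg_stream_complete
+  ```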
+ """ + var _data_rows: Array[Array[(String|None)] val] iso + var _row_description: (Array[(String, U32)] val | None) + var _error: Bool = false + var _completing: Bool = false + + new create() => + _data_rows = recover iso Array[Array[(String|None)] val] end + _row_description = None + + fun ref try_run_query(s: Session ref, li: _SessionLoggedIn ref) => None + + fun ref on_data_row(s: Session ref, li: _SessionLoggedIn ref, + msg: _DataRowMessage) + => + _data_rows.push(msg.columns) + + fun ref on_row_description(s: Session ref, li: _SessionLoggedIn ref, + msg: _RowDescriptionMessage) + => + _row_description = msg.columns + + fun ref on_portal_suspended(s: Session ref, li: _SessionLoggedIn ref) => + try + let sq = li.query_queue(0)? as _QueuedStreamingQuery + let rows = _data_rows = recover iso + Array[Array[(String|None)] val].create() + end + match _row_description + | let desc: Array[(String, U32)] val => + let rows_object = _RowsBuilder(consume rows, desc)? + sq.receiver.pg_stream_batch(s, rows_object) + else + _Unreachable() + end + else + _Unreachable() + end + + fun ref fetch_more(s: Session ref, li: _SessionLoggedIn ref) => + // After CommandComplete or ErrorResponse, the portal is destroyed by + // Sync. The receiver may still call fetch_more() in response to the + // final pg_stream_batch before ReadyForQuery arrives — silently ignore. + if _completing or _error then return end + try + let sq = li.query_queue(0)? 
as _QueuedStreamingQuery + let execute = _FrontendMessage.execute_msg("", sq.window_size) + let flush_msg = _FrontendMessage.flush() + let combined = recover val + let total = execute.size() + flush_msg.size() + let buf = Array[U8](total) + buf.copy_from(execute, 0, 0, execute.size()) + buf.copy_from(flush_msg, 0, execute.size(), flush_msg.size()) + buf + end + s._connection().send(consume combined) + else + _Unreachable() + end + + fun ref close_stream(s: Session ref) => + if (not _error) and (not _completing) then + s._connection().send(_FrontendMessage.sync()) + end + + fun ref on_command_complete(s: Session ref, li: _SessionLoggedIn ref, + msg: _CommandCompleteMessage) + => + // Final batch — deliver any remaining accumulated rows. + try + let sq = li.query_queue(0)? as _QueuedStreamingQuery + let rows = _data_rows = recover iso + Array[Array[(String|None)] val].create() + end + if rows.size() > 0 then + match _row_description + | let desc: Array[(String, U32)] val => + let rows_object = _RowsBuilder(consume rows, desc)? + sq.receiver.pg_stream_batch(s, rows_object) + else + _Unreachable() + end + end + else + _Unreachable() + end + // Send Sync to trigger ReadyForQuery and destroy the portal. + // _completing prevents close_stream() from sending a duplicate Sync if it + // arrives between this point and ReadyForQuery. + _completing = true + s._connection().send(_FrontendMessage.sync()) + + fun ref on_ready_for_query(s: Session ref, li: _SessionLoggedIn ref) => + if not _error then + try + let sq = li.query_queue(0)? as _QueuedStreamingQuery + sq.receiver.pg_stream_complete(s) + else + _Unreachable() + end + end + try + li.query_queue.shift()? + else + _Unreachable() + end + li.query_state = _QueryReady + li.query_state.try_run_query(s, li) + + fun ref on_error_response(s: Session ref, li: _SessionLoggedIn ref, + msg: ErrorResponseMessage) + => + _error = true + try + let sq = li.query_queue(0)? 
as _QueuedStreamingQuery + _data_rows = recover iso Array[Array[(String|None)] val] end + _row_description = None + sq.receiver.pg_stream_failed(s, sq.query, msg) + else + _Unreachable() + end + // Sync is required because streaming uses Flush (not Sync) to keep the + // portal alive. Without a pending Sync, the server waits indefinitely + // after ErrorResponse, deadlocking the session. + s._connection().send(_FrontendMessage.sync()) + + fun ref on_empty_query_response(s: Session ref, + li: _SessionLoggedIn ref) + => + li.shutdown(s) + + fun ref on_copy_in_response(s: Session ref, li: _SessionLoggedIn ref, + msg: _CopyInResponseMessage) + => + li.shutdown(s) + + fun ref on_copy_out_response(s: Session ref, li: _SessionLoggedIn ref, + msg: _CopyOutResponseMessage) + => + li.shutdown(s) + + fun ref on_copy_data(s: Session ref, li: _SessionLoggedIn ref, + msg: _CopyDataMessage) + => + li.shutdown(s) + + fun ref on_copy_done(s: Session ref, li: _SessionLoggedIn ref) => + li.shutdown(s) + + fun ref drain_in_flight(s: Session ref, li: _SessionLoggedIn ref) => + if not _error then + try + let sq = li.query_queue(0)? as _QueuedStreamingQuery + sq.receiver.pg_stream_failed(s, sq.query, SessionClosed) + else + _Unreachable() + end + end + try + li.query_queue.shift()? + else + _Unreachable() + end + interface _SessionState fun on_connected(s: Session ref) """ @@ -1697,6 +2056,25 @@ interface _SessionState Called when the server sends a CopyDone message, indicating the end of the COPY TO STDOUT data stream. """ + fun ref on_portal_suspended(s: Session ref) + """ + Called when the server sends a PortalSuspended message during a streaming + query, indicating more rows are available for the current portal. + """ + fun ref stream(s: Session ref, + query: (PreparedQuery | NamedPreparedQuery), + window_size: U32, receiver: StreamingResultReceiver) + """ + Called when a client requests a streaming query execution. 
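+
+    A minimal client-side sketch (query text and window size are
+    illustrative):
+
+    ```pony
+    session.stream(
+      PreparedQuery("SELECT * FROM big_table",
+        recover val Array[(String | None)] end),
+      100, my_receiver)
+    ```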
+ """ + fun ref fetch_more(s: Session ref) + """ + Called when a client requests the next batch of streaming rows. + """ + fun ref close_stream(s: Session ref) + """ + Called when a client requests early termination of a streaming query. + """ fun ref on_ready_for_query(s: Session ref, msg: _ReadyForQueryMessage) """ Called when the server sends a "ready for query" message @@ -1867,6 +2245,12 @@ trait _ConnectedState is _NotConnectableState fun ref abort_copy(s: Session ref, reason: String) => None + fun ref fetch_more(s: Session ref) => + None + + fun ref close_stream(s: Session ref) => + None + fun ref close(s: Session ref) => shutdown(s) @@ -1928,6 +2312,12 @@ trait _UnconnectedState is (_NotAuthenticableState & _NotAuthenticated) fun ref abort_copy(s: Session ref, reason: String) => None + fun ref fetch_more(s: Session ref) => + None + + fun ref close_stream(s: Session ref) => + None + fun ref close(s: Session ref) => None @@ -2091,3 +2481,6 @@ trait _NotAuthenticated fun ref on_copy_done(s: Session ref) => _IllegalState() + + fun ref on_portal_suspended(s: Session ref) => + _IllegalState() diff --git a/postgres/streaming_result_receiver.pony b/postgres/streaming_result_receiver.pony new file mode 100644 index 0000000..41d0407 --- /dev/null +++ b/postgres/streaming_result_receiver.pony @@ -0,0 +1,33 @@ +interface tag StreamingResultReceiver + """ + Receives results from a `Session.stream()` call. Unlike `ResultReceiver` + which buffers all rows before delivery, streaming delivers rows in + fixed-size batches as they arrive from the server. + + The flow is pull-based: after each `pg_stream_batch`, call + `session.fetch_more()` to request the next batch. When the server has + no more rows, `pg_stream_complete` fires. Call `session.close_stream()` + to end streaming early. + """ + + be pg_stream_batch(session: Session, rows: Rows) + """ + Called when a batch of rows is available. The batch size is at most the + `window_size` passed to `session.stream()`. 
After processing the batch, + call `session.fetch_more()` to request the next batch. + """ + + be pg_stream_complete(session: Session) + """ + Called when all rows have been delivered and the streaming query is + finished. No further batches will arrive. + """ + + be pg_stream_failed(session: Session, + query: (PreparedQuery | NamedPreparedQuery), + failure: (ErrorResponseMessage | ClientQueryError)) + """ + Called when the streaming query fails. The failure is either a server + error (ErrorResponseMessage) or a client-side error (ClientQueryError) + such as the session being closed or not yet authenticated. + """
+    """