diff --git a/pg_query.h b/pg_query.h
index def0254e..f57d8f85 100644
--- a/pg_query.h
+++ b/pg_query.h
@@ -166,6 +166,64 @@ void pg_query_exit(void);
 #define PG_VERSION "17.7"
 #define PG_VERSION_NUM 170007
 
+// Raw parse tree access (bypasses protobuf serialization)
+// Note: The returned tree uses PostgreSQL's memory context. The tree is only
+// valid until pg_query_free_raw_parse_result is called or pg_query_exit is called.
+
+// Forward declaration of PostgreSQL List type (defined in nodes/pg_list.h)
+struct List;
+// Forward declaration of PostgreSQL MemoryContextData type
+struct MemoryContextData;
+
+typedef struct {
+  struct List *tree;  // PostgreSQL parse tree (List of RawStmt nodes)
+  char* stderr_buffer;  // malloc-ed; freed by pg_query_free_raw_parse_result
+  PgQueryError* error;  // malloc-ed; set when parsing failed
+  struct MemoryContextData* context;  // Internal: Memory context for the tree (do not modify)
+} PgQueryRawParseResult;
+
+PgQueryRawParseResult pg_query_parse_raw(const char* input);
+PgQueryRawParseResult pg_query_parse_raw_opts(const char* input, int parser_options);
+void pg_query_free_raw_parse_result(PgQueryRawParseResult result);
+
+// Raw deparse (bypasses protobuf serialization)
+// Takes a raw parse result and converts it back to SQL
+PgQueryDeparseResult pg_query_deparse_raw(PgQueryRawParseResult parse_result);
+PgQueryDeparseResult pg_query_deparse_raw_opts(PgQueryRawParseResult parse_result, struct PostgresDeparseOpts opts);
+
+// Node building helpers for Rust (bypasses protobuf): build parse trees directly.
+// Enter a context first; pg_query_alloc_node returns NULL if size < sizeof(Node).
+void *pg_query_deparse_enter_context(void);
+void pg_query_deparse_exit_context(void *ctx);
+void *pg_query_alloc_node(size_t size, int tag);
+char *pg_query_pstrdup(const char *str);
+void *pg_query_list_make1(void *datum);
+void *pg_query_list_append(void *list, void *datum);
+PgQueryDeparseResult pg_query_deparse_nodes(void *stmts);
+
+// Raw scan (bypasses protobuf serialization)
+// Returns tokens directly without protobuf encoding
+
+typedef struct {
+  int start;
+  int end;
+  int token;         // Token type (matches Token enum in protobuf)
+  int keyword_kind;  // KeywordKind enum value
+} PgQueryRawScanToken;
+
+typedef struct {
+  PgQueryRawScanToken *tokens;
+  size_t n_tokens;
+  char* stderr_buffer;
+  PgQueryError* error;
+} PgQueryRawScanResult;
+
+PgQueryRawScanResult pg_query_scan_raw(const char* input);
+void pg_query_free_raw_scan_result(PgQueryRawScanResult result);
+
+// Raw fingerprint (works with raw parse result, bypasses re-parsing)
+PgQueryFingerprintResult pg_query_fingerprint_raw(PgQueryRawParseResult parse_result);
+
 // Deprecated APIs below
 
 void pg_query_init(void); // Deprecated as of 9.5-1.4.1, this is now run automatically as needed
diff --git a/pg_query_raw.h b/pg_query_raw.h
new file mode 100644
index 00000000..155525cc
--- /dev/null
+++ b/pg_query_raw.h
@@ -0,0 +1,17 @@
+// Wrapper header for raw parse tree access
+// This header includes the PostgreSQL internal types needed for direct parse tree access
+
+#ifndef PG_QUERY_RAW_H
+#define PG_QUERY_RAW_H
+
+#include "pg_query.h"
+
+// Include PostgreSQL headers needed for parse tree access
+#include "src/postgres/include/postgres.h"
+#include "src/postgres/include/nodes/nodes.h"
+#include "src/postgres/include/nodes/pg_list.h"
+#include "src/postgres/include/nodes/value.h"
+#include "src/postgres/include/nodes/primnodes.h"
+#include "src/postgres/include/nodes/parsenodes.h"
+
+#endif
diff --git a/src/pg_query_deparse.c b/src/pg_query_deparse.c
index 42a96374..e960382e 100644
--- a/src/pg_query_deparse.c
+++ b/src/pg_query_deparse.c
@@ -85,6 +85,217 @@ pg_query_free_deparse_result(PgQueryDeparseResult result)
 	free(result.query);
 }
 
+/*
+ * Helper functions for building nodes from Rust (bypassing protobuf)
+ *
+ * These wrap PostgreSQL's internal functions to allow Rust to construct
+ * parse trees directly.
+ */
+
+/*
+ * Enter a libpg_query memory context and return it as an opaque handle.
+ * Hand the same pointer to pg_query_deparse_exit_context() when done.
+ */
+void *
+pg_query_deparse_enter_context(void)
+{
+	return (void *) pg_query_enter_memory_context();
+}
+
+/* Exit the memory context returned by pg_query_deparse_enter_context(). */
+void
+pg_query_deparse_exit_context(void *ctx)
+{
+	pg_query_exit_memory_context((MemoryContext) ctx);
+}
+
+/*
+ * Allocate a zero-filled node of `size` bytes with palloc0() and store `tag`
+ * in its NodeTag header.
+ *
+ * Returns NULL when `size` cannot even hold the Node header, because writing
+ * the tag would otherwise run past the end of the allocation.
+ */
+void *
+pg_query_alloc_node(size_t size, int tag)
+{
+	Node *result;
+
+	if (size < sizeof(Node))
+		return NULL;
+
+	result = (Node *) palloc0(size);
+	result->type = (NodeTag) tag;
+	return result;
+}
+
+/* pstrdup() wrapper that passes a NULL string through unchanged. */
+char *
+pg_query_pstrdup(const char *str)
+{
+	if (str == NULL)
+		return NULL;
+	return pstrdup(str);
+}
+
+/* Build a one-element List holding `datum` (wraps list_make1). */
+void *
+pg_query_list_make1(void *datum)
+{
+	return (void *) list_make1(datum);
+}
+
+/* Append `datum` to `list` and return the resulting List (wraps lappend). */
+void *
+pg_query_list_append(void *list, void *datum)
+{
+	return (void *) lappend((List *) list, datum);
+}
+
+/*
+ * Deparse a List of RawStmt nodes to SQL, joining statements with "; ".
+ *
+ * The caller must already have entered a memory context (see
+ * pg_query_deparse_enter_context). A NULL list yields an empty query string.
+ * If deparsing raises an error, result.error is set and result.query stays
+ * NULL. Both are allocated with malloc/strdup, not in the memory context.
+ */
+PgQueryDeparseResult
+pg_query_deparse_nodes(void *stmts_ptr)
+{
+	List *stmts = (List *) stmts_ptr;
+	PgQueryDeparseResult result = {0};
+	MemoryContext caller_context = CurrentMemoryContext;
+	StringInfoData str;
+	ListCell *lc;
+
+	if (stmts == NULL)
+	{
+		result.query = strdup("");
+		return result;
+	}
+
+	PG_TRY();
+	{
+		PostgresDeparseOpts opts;
+
+		MemSet(&opts, 0, sizeof(PostgresDeparseOpts));
+		initStringInfo(&str);
+
+		foreach(lc, stmts)
+		{
+			deparseRawStmtOpts(&str, castNode(RawStmt, lfirst(lc)), opts);
+			if (lnext(stmts, lc))
+				appendStringInfoString(&str, "; ");
+		}
+		result.query = strdup(str.data);
+	}
+	PG_CATCH();
+	{
+		ErrorData *error_data;
+		PgQueryError *error;
+
+		/*
+		 * An ERROR leaves ErrorContext current, and CopyErrorData() must not
+		 * run there. Return to the caller's context first, like the other
+		 * raw entry points do before copying the error.
+		 */
+		MemoryContextSwitchTo(caller_context);
+		error_data = CopyErrorData();
+
+		error = malloc(sizeof(PgQueryError));
+		error->message = strdup(error_data->message);
+		error->filename = strdup(error_data->filename);
+		error->funcname = strdup(error_data->funcname);
+		error->context = NULL;
+		error->lineno = error_data->lineno;
+		error->cursorpos = error_data->cursorpos;
+
+		result.error = error;
+		FlushErrorState();
+	}
+	PG_END_TRY();
+
+	return result;
+}
+
+/* Deparse a raw parse result using all-zero (default) deparse options. */
+PgQueryDeparseResult
+pg_query_deparse_raw(PgQueryRawParseResult parse_result)
+{
+	PostgresDeparseOpts opts;
+	MemSet(&opts, 0, sizeof(PostgresDeparseOpts));
+	return pg_query_deparse_raw_opts(parse_result, opts);
+}
+
+/*
+ * Deparse the statements in parse_result.tree to SQL, joined with "; ".
+ * A parse error is deep-copied into the result; a NULL tree yields "".
+ * Precondition: parse_result.context is still the current memory context.
+ */
+PgQueryDeparseResult
+pg_query_deparse_raw_opts(PgQueryRawParseResult parse_result, PostgresDeparseOpts opts)
+{
+	PgQueryDeparseResult result = {0};
+	StringInfoData str;
+	ListCell *lc;
+
+	if (parse_result.error != NULL)
+	{
+		PgQueryError *error = malloc(sizeof(PgQueryError));
+		error->message = parse_result.error->message ? strdup(parse_result.error->message) : NULL;
+		error->filename = parse_result.error->filename ? strdup(parse_result.error->filename) : NULL;
+		error->funcname = parse_result.error->funcname ? strdup(parse_result.error->funcname) : NULL;
+		error->context = parse_result.error->context ? strdup(parse_result.error->context) : NULL;
+		error->lineno = parse_result.error->lineno;
+		error->cursorpos = parse_result.error->cursorpos;
+		result.error = error;
+		return result;
+	}
+
+	if (parse_result.tree == NULL)
+	{
+		result.query = strdup("");
+		return result;
+	}
+
+	PG_TRY();
+	{
+		initStringInfo(&str);
+
+		foreach(lc, parse_result.tree)
+		{
+			deparseRawStmtOpts(&str, castNode(RawStmt, lfirst(lc)), opts);
+			if (lnext(parse_result.tree, lc))
+				appendStringInfoString(&str, "; ");
+		}
+		result.query = strdup(str.data);
+	}
+	PG_CATCH();
+	{
+		ErrorData *error_data;
+		PgQueryError *error;
+
+		/* Leave ErrorContext before CopyErrorData() */
+		MemoryContextSwitchTo(parse_result.context);
+		error_data = CopyErrorData();
+
+		error = malloc(sizeof(PgQueryError));
+		error->message = strdup(error_data->message);
+		error->filename = strdup(error_data->filename);
+		error->funcname = strdup(error_data->funcname);
+		error->context = NULL;
+		error->lineno = error_data->lineno;
+		error->cursorpos = error_data->cursorpos;
+
+		result.error = error;
+		FlushErrorState();
+	}
+	PG_END_TRY();
+
+	return result;
+}
+
 PgQueryDeparseCommentsResult
 pg_query_deparse_comments_for_query(const char *query)
 {
diff --git a/src/pg_query_fingerprint.c b/src/pg_query_fingerprint.c
index 8e15b821..93a5d9f4 100644
--- a/src/pg_query_fingerprint.c
+++ b/src/pg_query_fingerprint.c
@@ -405,3 +405,61 @@ void pg_query_free_fingerprint_result(PgQueryFingerprintResult result)
 	free(result.fingerprint_str);
 	free(result.stderr_buffer);
 }
+
+/*
+ * Fingerprint an already parsed statement list without re-parsing the SQL.
+ *
+ * A parse error carried by parse_result is deep-copied into the result and no
+ * fingerprint is computed. A NULL tree (e.g. comment-only or empty queries) is
+ * still fingerprinted, to match pg_query_fingerprint_with_opts. The captured
+ * stderr text is copied in both cases, so the result never aliases
+ * parse_result.
+ */
+PgQueryFingerprintResult pg_query_fingerprint_raw(PgQueryRawParseResult parse_result)
+{
+	PgQueryFingerprintResult result = {0};
+	FingerprintContext ctx;
+	XXH64_canonical_t chash;
+	int n;
+
+	if (parse_result.stderr_buffer != NULL) {
+		result.stderr_buffer = strdup(parse_result.stderr_buffer);
+	}
+
+	if (parse_result.error != NULL) {
+		const PgQueryError* src = parse_result.error;
+		PgQueryError* error = malloc(sizeof(PgQueryError));
+
+		error->message = src->message ? strdup(src->message) : NULL;
+		error->filename = src->filename ? strdup(src->filename) : NULL;
+		error->funcname = src->funcname ? strdup(src->funcname) : NULL;
+		error->context = src->context ? strdup(src->context) : NULL;
+		error->lineno = src->lineno;
+		error->cursorpos = src->cursorpos;
+		result.error = error;
+		return result;
+	}
+
+	_fingerprintInitContext(&ctx, NULL, false);
+	if (parse_result.tree != NULL) {
+		_fingerprintNode(&ctx, parse_result.tree, NULL, NULL, 0);
+	}
+	result.fingerprint = XXH3_64bits_digest(ctx.xxh_state);
+	_fingerprintFreeContext(&ctx);
+
+	XXH64_canonicalFromHash(&chash, result.fingerprint);
+	result.fingerprint_str = malloc(17 * sizeof(char));
+	n = snprintf(result.fingerprint_str, 17, "%02x%02x%02x%02x%02x%02x%02x%02x",
+				 chash.digest[0], chash.digest[1], chash.digest[2], chash.digest[3],
+				 chash.digest[4], chash.digest[5], chash.digest[6], chash.digest[7]);
+	if (n < 0 || n >= 17) {
+		// calloc: only the message is set; the other fields must be NULL/0 so
+		// that freeing the error never touches uninitialized pointers
+		PgQueryError* error = calloc(1, sizeof(PgQueryError));
+
+		error->message = strdup("Failed to output fingerprint string due to snprintf failure");
+		result.error = error;
+	}
+
+	return result;
+}
diff --git a/src/pg_query_parse.c b/src/pg_query_parse.c
index 3f2c111a..04286580 100644
--- a/src/pg_query_parse.c
+++ b/src/pg_query_parse.c
@@ -188,3 +188,47 @@ void pg_query_free_protobuf_parse_result(PgQueryProtobufParseResult result)
 	free(result.parse_tree.data);
 	free(result.stderr_buffer);
 }
+
+// Parse `input` with the default parser options (see pg_query_parse_raw_opts).
+PgQueryRawParseResult pg_query_parse_raw(const char* input)
+{
+	return pg_query_parse_raw_opts(input, PG_QUERY_PARSE_DEFAULT);
+}
+
+// Parse `input` and return the PostgreSQL parse tree itself (no protobuf).
+//
+// The memory context entered here is intentionally NOT exited: the tree is
+// allocated in it. pg_query_free_raw_parse_result exits it, so the caller
+// must call that function when done with the result.
+PgQueryRawParseResult pg_query_parse_raw_opts(const char* input, int parser_options)
+{
+	PgQueryRawParseResult result = {0};
+	PgQueryInternalParsetreeAndError parsed;
+
+	result.context = pg_query_enter_memory_context();
+	parsed = pg_query_raw_parse(input, parser_options);
+
+	// The tree stays in result.context
+	result.tree = parsed.tree;
+
+	// These are malloc-ed and survive the memory context; the caller frees them
+	result.stderr_buffer = parsed.stderr_buffer;
+	result.error = parsed.error;
+
+	return result;
+}
+
+// Free the error, the captured stderr text and the parse tree of `result`.
+// The tree must not be used after this call.
+void pg_query_free_raw_parse_result(PgQueryRawParseResult result)
+{
+	if (result.error) {
+		pg_query_free_error(result.error);
+	}
+	free(result.stderr_buffer);
+
+	// Exit the memory context to free the parse tree
+	if (result.context) {
+		pg_query_exit_memory_context(result.context);
+	}
+}
diff --git a/src/pg_query_scan.c b/src/pg_query_scan.c
index 1d0e052e..1e7e97a7 100644
--- a/src/pg_query_scan.c
+++ b/src/pg_query_scan.c
@@ -51,7 +51,7 @@ PgQueryScanResult pg_query_scan(const char* input)
 	if (pipe(stderr_pipe) != 0) {
 		PgQueryError* error = malloc(sizeof(PgQueryError));
 
-		error->message = strdup("Failed to open pipe, too many open file descriptors")
+		error->message = strdup("Failed to open pipe, too many open file descriptors");
 
 		result.error = error;
 
@@ -172,3 +172,162 @@ void pg_query_free_scan_result(PgQueryScanResult result)
 	free(result.pbuf.data);
 	free(result.stderr_buffer);
 }
+
+/*
+ * Tokenize `input` with the PostgreSQL core scanner and return the tokens as
+ * a plain C array (no protobuf encoding).
+ *
+ * On success result.tokens holds result.n_tokens entries. If the scanner
+ * raises an error, result.error is set and no tokens are returned. Release
+ * the result with pg_query_free_raw_scan_result().
+ */
+PgQueryRawScanResult pg_query_scan_raw(const char* input)
+{
+	MemoryContext ctx = NULL;
+	PgQueryRawScanResult result = {0};
+	core_yyscan_t yyscanner;
+	core_yy_extra_type yyextra;
+	core_YYSTYPE yylval;
+	YYLTYPE yylloc;
+	// volatile: assigned inside PG_TRY and read in PG_CATCH after the longjmp
+	PgQueryRawScanToken *volatile tokens = NULL;
+	size_t token_count = 0;
+	size_t i;
+
+	ctx = pg_query_enter_memory_context();
+
+	MemoryContext parse_context = CurrentMemoryContext;
+
+	char stderr_buffer[STDERR_BUFFER_LEN + 1] = {0};
+#ifndef DEBUG
+	int stderr_global;
+	int stderr_pipe[2];
+#endif
+
+#ifndef DEBUG
+	// Setup pipe for stderr redirection
+	if (pipe(stderr_pipe) != 0) {
+		// calloc: only the message is set, the other fields must be NULL/0
+		PgQueryError* error = calloc(1, sizeof(PgQueryError));
+
+		error->message = strdup("Failed to open pipe, too many open file descriptors");
+
+		result.error = error;
+
+		// Do not leave the memory context entered above behind
+		pg_query_exit_memory_context(ctx);
+
+		return result;
+	}
+
+	fcntl(stderr_pipe[0], F_SETFL, fcntl(stderr_pipe[0], F_GETFL) | O_NONBLOCK);
+
+	// Redirect stderr to the pipe
+	stderr_global = dup(STDERR_FILENO);
+	dup2(stderr_pipe[1], STDERR_FILENO);
+	close(stderr_pipe[1]);
+#endif
+
+	PG_TRY();
+	{
+		// First pass: count tokens
+		yyscanner = scanner_init(input, &yyextra, &ScanKeywords, ScanKeywordTokens);
+		for (;; token_count++)
+		{
+			if (core_yylex(&yylval, &yylloc, yyscanner) == 0) break;
+		}
+		scanner_finish(yyscanner);
+
+		// Allocate output array (malloc, so it outlives the memory context)
+		tokens = malloc(sizeof(PgQueryRawScanToken) * token_count);
+		if (tokens == NULL && token_count > 0)
+			elog(ERROR, "out of memory");
+
+		// Second pass: fill in token data, never writing past the counted size
+		yyscanner = scanner_init(input, &yyextra, &ScanKeywords, ScanKeywordTokens);
+
+		for (i = 0; i < token_count; i++)
+		{
+			int tok;
+
+			tok = core_yylex(&yylval, &yylloc, yyscanner);
+			if (tok == 0) break;
+
+			tokens[i].start = yylloc;
+			if (tok == SCONST || tok == USCONST || tok == BCONST || tok == XCONST || tok == IDENT || tok == UIDENT || tok == C_COMMENT) {
+				tokens[i].end = yyextra.yyllocend;
+			} else {
+				tokens[i].end = yylloc + ((struct yyguts_t*) yyscanner)->yyleng_r;
+			}
+			tokens[i].token = tok;
+
+			switch (tok) {
+				#define PG_KEYWORD(a,b,c,d) case b: tokens[i].keyword_kind = c + 1; break;
+				#include "parser/kwlist.h"
+				#undef PG_KEYWORD
+				default: tokens[i].keyword_kind = 0;
+			}
+		}
+
+		scanner_finish(yyscanner);
+
+		// Publish the array only once it is completely filled in
+		result.tokens = tokens;
+		result.n_tokens = i;
+		tokens = NULL;
+
+#ifndef DEBUG
+		// Save stderr for result
+		read(stderr_pipe[0], stderr_buffer, STDERR_BUFFER_LEN);
+#endif
+
+		result.stderr_buffer = strdup(stderr_buffer);
+	}
+	PG_CATCH();
+	{
+		ErrorData* error_data;
+		PgQueryError* error;
+
+		MemoryContextSwitchTo(parse_context);
+		error_data = CopyErrorData();
+
+		// Note: This is intentionally malloc so exiting the memory context doesn't free this
+		error = malloc(sizeof(PgQueryError));
+		error->message = strdup(error_data->message);
+		error->filename = strdup(error_data->filename);
+		error->funcname = strdup(error_data->funcname);
+		error->context = NULL;
+		error->lineno = error_data->lineno;
+		error->cursorpos = error_data->cursorpos;
+
+		result.error = error;
+		FlushErrorState();
+
+		// Drop a partially filled token array instead of returning it
+		free(tokens);
+		tokens = NULL;
+	}
+	PG_END_TRY();
+
+#ifndef DEBUG
+	// Restore stderr, close pipe
+	dup2(stderr_global, STDERR_FILENO);
+	close(stderr_pipe[0]);
+	close(stderr_global);
+#endif
+
+	pg_query_exit_memory_context(ctx);
+
+	return result;
+}
+
+// Free the token array, the captured stderr text and the error of `result`.
+void pg_query_free_raw_scan_result(PgQueryRawScanResult result)
+{
+	if (result.error) {
+		pg_query_free_error(result.error);
+	}
+
+	free(result.tokens);
+	free(result.stderr_buffer);
+}