Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pg_test = []
[dependencies]
pgrx = "0.16.1"
regex = "1.11.1"
itertools = "0.14.0"

[dev-dependencies]
pgrx-tests = "0.16.1"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ that could cause dramatic performance degradation in production.
### Build the extension

For now, you need to build the extension locally:
`cargo build -r`
`cargo build --release`
That will generate the following files:

- control file: pg_no_seqscan.control
Expand Down
49 changes: 19 additions & 30 deletions src/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
use pgrx::pg_sys;
use pgrx::pg_sys::{List, Oid, get_namespace_name, get_rel_name, get_rel_namespace, rt_fetch};
use pgrx::pg_sys::{
GetUserId, GetUserNameFromId, List, MyDatabaseId, Oid, get_database_name, get_namespace_name,
get_rel_name, get_rel_namespace, rt_fetch,
};
use pgrx::{PgRelation, Spi};
use std::ffi::{CStr, CString, c_char};
use std::ffi::{CStr, c_char};

pub fn extract_comma_separated_setting(comma_separated_string: CString) -> Vec<String> {
pub fn comma_separated_list_contains(comma_separated_string: &CStr, value: &str) -> bool {
comma_separated_string
.to_str()
.unwrap_or_default()
.expect("comma_separated_list_contains: Invalid UTF-8 sequence")
.split(',')
.map(|s| s.trim().to_string())
.collect()
}
pub fn comma_separated_list_contains(comma_separated_string: CString, value: String) -> bool {
extract_comma_separated_setting(comma_separated_string).contains(&value)
.any(|s| s.trim() == value)
}

pub fn string_from_ptr(ptr: *const c_char) -> Option<String> {
match unsafe { CStr::from_ptr(ptr).to_str() } {
Ok(str_value) => Some(str_value.to_string()),
Err(_) => None,
fn ptr_to_option_string(ptr: *const c_char) -> Option<String> {
if ptr.is_null() {
None
} else {
unsafe { CStr::from_ptr(ptr).to_str().ok().map(String::from) }
}
}

Expand All @@ -27,33 +26,23 @@ pub fn scanned_table(scanrelid: u32, rtables: *mut List) -> Option<Oid> {
}

pub fn resolve_namespace_name(oid: Oid) -> Option<String> {
let namespace_name = unsafe { get_namespace_name(get_rel_namespace(oid)) };
if namespace_name.is_null() {
None
} else {
string_from_ptr(namespace_name)
}
ptr_to_option_string(unsafe { get_namespace_name(get_rel_namespace(oid)) })
}

pub fn resolve_table_name(table_oid: Oid) -> Option<String> {
let relname_ptr = unsafe { get_rel_name(table_oid) };
if relname_ptr.is_null() {
return None;
}

string_from_ptr(relname_ptr)
ptr_to_option_string(unsafe { get_rel_name(table_oid) })
}

pub fn current_db_name() -> String {
unsafe {
let db_oid = pg_sys::MyDatabaseId;
string_from_ptr(pg_sys::get_database_name(db_oid)).expect("Failed to get database name")
let db_oid = MyDatabaseId;
ptr_to_option_string(get_database_name(db_oid)).expect("Failed to get database name")
}
}

pub fn current_username() -> String {
let current_user = unsafe { pg_sys::GetUserNameFromId(pg_sys::GetUserId(), true) };
string_from_ptr(current_user).expect("Failed to get username")
let current_user = unsafe { GetUserNameFromId(GetUserId(), true) };
ptr_to_option_string(current_user).expect("Failed to get username")
}

pub fn get_parent_table_oid(table_oid: Oid) -> Option<Oid> {
Expand Down
Loading
Loading