diff --git a/Cargo.lock b/Cargo.lock index 429b3366dd..c972ff27a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -226,7 +226,7 @@ dependencies = [ "memchr", "pin-project-lite", "tokio", - "zstd 0.13.2", + "zstd 0.13.3", "zstd-safe 7.2.0", ] @@ -2899,6 +2899,7 @@ name = "libsql" version = "0.9.2" dependencies = [ "anyhow", + "arc-swap", "async-stream", "async-trait", "base64 0.21.7", @@ -2913,6 +2914,7 @@ dependencies = [ "http 0.2.12", "hyper", "hyper-rustls 0.25.0", + "lazy_static", "libsql-hrana", "libsql-sqlite3-parser", "libsql-sys", @@ -2937,6 +2939,7 @@ dependencies = [ "uuid", "worker", "zerocopy", + "zstd 0.13.3", ] [[package]] @@ -6703,9 +6706,9 @@ dependencies = [ [[package]] name = "zstd" -version = "0.13.2" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" dependencies = [ "zstd-safe 7.2.0", ] diff --git a/libsql/Cargo.toml b/libsql/Cargo.toml index c02bbe229f..572b2cdd30 100644 --- a/libsql/Cargo.toml +++ b/libsql/Cargo.toml @@ -14,7 +14,7 @@ thiserror = "1.0.40" futures = { version = "0.3.28", optional = true } libsql-sys = { workspace = true, optional = true, default-features = true } libsql-hrana = { workspace = true, optional = true } -tokio = { version = "1.29.1", features = ["sync"], optional = true } +tokio = { version = "1.29.1", features = ["rt-multi-thread", "sync"], optional = true } tokio-util = { version = "0.7", features = ["io-util", "codec"], optional = true } parking_lot = { version = "0.12.1", optional = true } hyper = { version = "0.14", features = ["client", "http1", "http2", "stream", "runtime"], optional = true } @@ -46,6 +46,10 @@ async-stream = { version = "0.3.5", optional = true } crc32fast = { version = "1", optional = true } chrono = { version = "0.4", optional = true } +zstd = "0.13.3" +rand = "0.8.5" +lazy_static = "1.5.0" +arc-swap = "1.7.1" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports", "async", "async_futures", "async_tokio"] } @@ -57,7 +61,7 @@ tempfile = { version = "3.7.0" } rand = "0.8.5" [features] -default = ["core", "replication", "remote", "sync", "tls"] +default = ["core", "replication", "remote", "sync", "tls", "lazy"] core = [ "libsql-sys", "dep:bitflags", @@ -116,6 +120,27 @@ sync = [ "dep:uuid", "tokio/fs" ] +lazy = [ + "core", + "parser", + "serde", + "stream", + "remote", + "replication", + "dep:tower", + "dep:hyper", + "dep:http", + "dep:tokio", + "dep:zerocopy", + "dep:bytes", + "dep:tokio", + "dep:futures", + "dep:serde_json", + "dep:crc32fast", + "dep:chrono", + "dep:uuid", + "tokio/fs" +] hrana = [ "parser", "serde", diff --git a/libsql/src/database.rs b/libsql/src/database.rs index 838eeb267f..85b74fbe92 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -76,7 +76,7 @@ cfg_sync! { } } -enum DbType { +pub(crate) enum DbType { #[cfg(feature = "core")] Memory { db: crate::local::Database }, #[cfg(feature = "core")] @@ -107,6 +107,13 @@ enum DbType { connector: crate::util::ConnectorService, version: Option, }, + #[cfg(feature = "lazy")] + Lazy { + db: crate::local::Database, + url: String, + auth_token: String, + connector: crate::util::ConnectorService, + }, } impl fmt::Debug for DbType { @@ -131,10 +138,10 @@ impl fmt::Debug for DbType { /// A struct that knows how to build [`Connection`]'s, this type does /// not do much work until the [`Database::connect`] fn is called. 
pub struct Database { - db_type: DbType, + pub(crate) db_type: DbType, /// The maximum replication index returned from a write performed using any connection created using this Database object. #[allow(dead_code)] - max_write_replication_index: std::sync::Arc, + pub(crate) max_write_replication_index: std::sync::Arc, } cfg_core! { @@ -727,7 +734,7 @@ impl Database { all(feature = "tls", feature = "remote"), all(feature = "tls", feature = "sync") ))] -fn connector() -> Result> { +pub(crate) fn connector() -> Result> { let mut http = hyper::client::HttpConnector::new(); http.enforce_http(false); http.set_nodelay(true); diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs index 6048b79c17..01ebc81ee9 100644 --- a/libsql/src/database/builder.rs +++ b/libsql/src/database/builder.rs @@ -2,6 +2,9 @@ cfg_core! { use crate::EncryptionConfig; } +use anyhow::anyhow; +use hyper::server::conn; + use crate::{Database, Result}; use super::DbType; @@ -126,6 +129,18 @@ impl Builder<()> { } } } + + cfg_lazy! { + pub fn new_lazy( + path: impl AsRef, + url: String, + auth_token: String, + ) -> Builder { + Builder { + inner: LazyReplica { path: todo!(), remote: todo!(), read_your_writes: todo!(), sync_interval: todo!() } + } + } + } } cfg_replication_or_remote_or_sync! { @@ -531,6 +546,84 @@ cfg_replication! { } } +pub struct LazyReplica { + path: std::path::PathBuf, + remote: Remote, + read_your_writes: bool, + sync_interval: Option, +} + +impl Builder { + /// Set weather you want writes to be visible locally before the write query returns. This + /// means that you will be able to read your own writes if this is set to `true`. + /// + /// # Default + /// + /// This defaults to `true`. + pub fn read_your_writes(mut self, read_your_writes: bool) -> Builder { + self.inner.read_your_writes = read_your_writes; + self + } + + /// Set the duration at which the replicator will automatically call `sync` in the + /// background. The sync will continue for the duration that the resulted `Database` + /// type is alive for, once it is dropped the background task will get dropped and stop. + pub fn sync_interval(mut self, duration: std::time::Duration) -> Builder { + self.inner.sync_interval = Some(duration); + self + } + + /// Build the remote embedded replica database. + pub async fn build(self) -> Result { + let LazyReplica { + path, + remote: + Remote { + url, + auth_token, + connector, + version, + }, + read_your_writes, + sync_interval, + } = self.inner; + + let connector = if let Some(connector) = connector { + connector + } else { + let https = super::connector()?; + use tower::ServiceExt; + + let svc = https + .map_err(|e| e.into()) + .map_response(|s| Box::new(s) as Box); + + crate::util::ConnectorService::new(svc) + }; + + let path = path + .to_str() + .ok_or(anyhow!("unable to convert path to string"))? + .to_owned(); + Ok(Database { + db_type: DbType::Lazy { + db: crate::local::Database::open_local_lazy( + connector.clone(), + path, + crate::OpenFlags::default(), + url.clone(), + auth_token.clone(), + ) + .await?, + url, + auth_token, + connector, + }, + max_write_replication_index: Default::default(), + }) + } +} + cfg_sync! { /// Remote replica configuration type in [`Builder`]. pub struct SyncedDatabase { @@ -694,7 +787,7 @@ cfg_remote! { } cfg_replication_or_remote_or_sync! 
{ - fn wrap_connector(connector: C) -> crate::util::ConnectorService +pub(crate) fn wrap_connector(connector: C) -> crate::util::ConnectorService where C: tower::Service + Send + Clone + Sync + 'static, C::Response: crate::util::Socket, diff --git a/libsql/src/errors.rs b/libsql/src/errors.rs index 8d3ed0e581..2a88de7279 100644 --- a/libsql/src/errors.rs +++ b/libsql/src/errors.rs @@ -57,6 +57,8 @@ pub enum Error { InvalidBlobSize(usize), #[error("sync error: {0}")] Sync(crate::BoxError), + #[error("lazy error: {0}")] + Lazy(crate::BoxError), #[error("WAL frame insert conflict")] WalConflict, } @@ -75,6 +77,13 @@ impl From for Error { } } +#[cfg(feature = "lazy")] +impl From for Error { + fn from(e: anyhow::Error) -> Self { + Error::Lazy(e.into()) + } +} + impl From for Error { fn from(_: std::convert::Infallible) -> Self { unreachable!() diff --git a/libsql/src/lazy/lazy.rs b/libsql/src/lazy/lazy.rs new file mode 100644 index 0000000000..30598e5451 --- /dev/null +++ b/libsql/src/lazy/lazy.rs @@ -0,0 +1,408 @@ +use std::path::Path; +use std::sync::Arc; + +use anyhow::anyhow; +use arc_swap::access::Access; +use arc_swap::ArcSwap; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use http::HeaderValue; +use hyper::Body; +use serde::{Deserialize, Serialize}; +use zstd::bulk::Decompressor; + +use crate::local::Connection; +use crate::params::Params; +use crate::sync::atomic_write; +use crate::util::ConnectorService; + +use super::vfs::{register_vfs, RegisteredVfs}; +use super::vfs_default::{get_default_vfs, Sqlite3Vfs}; +use super::vfs_lazy::LazyVfs; + +const PULL_PAGES_CHUNK_SIZE: usize = 10; +const PULL_PROTOCOL_RETRIES: usize = 3; +pub const LAZY_VFS_NAME: &[u8] = b"turso-vfs-lazy\0"; + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct PullPlanReqBody { + pub start_revision: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct PullPlanRespBody { + pub steps: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct PullPlanRespStep { + pub end_revision: String, + pub pages: Vec, + pub size_after_in_pages: usize, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct PullPagesReqBody { + pub start_revision: Option, + pub end_revision: String, + pub server_pages: Vec, + pub client_pages: Vec, + pub accept_encodings: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct LazyMetadata { + pub revision: Option, +} + +pub trait PageServer { + fn get_revision(&self) -> String; + async fn set_revision(&self, revision: String) -> anyhow::Result<()>; + async fn pull_plan(&self, request: &PullPlanReqBody) -> anyhow::Result; + async fn pull_pages(&self, request: &PullPagesReqBody) -> anyhow::Result>; +} + +pub struct TursoPageServer { + pub endpoint: String, + pub auth_token: Option, + pub client: hyper::Client, + pub revision: ArcSwap, +} + +impl TursoPageServer { + async fn send_request( + &self, + mut uri: String, + body: Bytes, + max_retries: usize, + ) -> anyhow::Result { + let mut retries = 0; + loop { + let mut request = http::Request::post(uri.clone()); + match &self.auth_token { + Some(auth_token) => { + let headers = request.headers_mut().expect("valid http request"); + headers.insert("Authorization", HeaderValue::from_str(&auth_token)?); + } + None => {} + } + let request = request.body(body.clone().into()).expect("valid body"); + let response = self.client.request(request).await?; + + if response.status().is_success() { + return Ok(response.into_body()); + } + + if response.status().is_redirection() { + uri = match 
response.headers().get(hyper::header::LOCATION) { + Some(loc) => loc.to_str()?.to_string(), + None => return Err(anyhow!("unable to parse location redirect header")), + }; + } + + // If we've retried too many times or the error is not a server error, + // return the error. + if retries > max_retries || !response.status().is_server_error() { + let status = response.status(); + let body = hyper::body::to_bytes(response.into_body()).await?; + let msg = String::from_utf8_lossy(&body[..]); + return Err(anyhow!( + "request failed: url={}, status={}, body={}", + uri, + status, + msg + )); + } + + let delay = std::time::Duration::from_millis(100 * (1 << retries)); + tokio::time::sleep(delay).await; + retries += 1; + } + } +} + +const PAGE_BATCH_ENCODING_RAW: u32 = 0; +const PAGE_BATCH_ENCODING_ZSTD: u32 = 1; + +impl PageServer for TursoPageServer { + fn get_revision(&self) -> String { + let revision = self.revision.load_full().to_string(); + tracing::info!("get_revision: {}", revision); + revision + } + async fn set_revision(&self, revision: String) -> anyhow::Result<()> { + tracing::info!("set_revision: {}", revision); + self.revision.store(Arc::new(revision)); + Ok(()) + } + async fn pull_plan(&self, request: &PullPlanReqBody) -> anyhow::Result { + let body = serde_json::to_vec(&request)?; + let response = self + .send_request( + format!("{}/pull-plan", self.endpoint), + Bytes::from(body), + PULL_PROTOCOL_RETRIES, + ) + .await?; + + let response_bytes = hyper::body::to_bytes(response).await?; + tracing::info!( + "pull_plan(start_revision={:?}): response_bytes={}", + request.start_revision, + response_bytes.len() + ); + let response = serde_json::from_slice(&response_bytes)?; + Ok(response) + } + + async fn pull_pages(&self, request: &PullPagesReqBody) -> anyhow::Result> { + let body = serde_json::to_vec(&request)?; + let response = self + .send_request( + format!("{}/pull-pages", self.endpoint), + Bytes::from(body), + PULL_PROTOCOL_RETRIES, + ) + .await?; + let mut response_bytes = hyper::body::to_bytes(response).await?; + tracing::info!( + "pull_pages(start_revision={:?}, end_revision={}): response_bytes={}", + request.start_revision, + request.end_revision, + response_bytes.len() + ); + let encoded_meta_length = response_bytes.get_u32_le(); + let encoded_pages_length = response_bytes.get_u32_le(); + let pages_count = response_bytes.get_u32_le(); + let page_size = response_bytes.get_u32_le(); + let (mut meta, mut pages) = response_bytes.split_at(encoded_meta_length as usize); + assert!(pages.len() == encoded_pages_length as usize); + let encoding_type = meta.get_u32_le(); + let mut result = Vec::new(); + match encoding_type { + PAGE_BATCH_ENCODING_RAW => { + while !meta.is_empty() { + let page_no = meta.get_u32_le(); + let page; + (page, pages) = pages.split_at(page_size as usize); + result.push((page_no, Bytes::from(page.to_vec()))); + } + } + PAGE_BATCH_ENCODING_ZSTD => { + let dictionary_pages_count = meta.get_u32_le(); + assert!(dictionary_pages_count == 0); + let mut zstd = Decompressor::new()?; + let pages = zstd.decompress(&pages, (pages_count * page_size) as usize)?; + for i in 0..pages_count { + let page_no = meta.get_u32_le(); + let page = Bytes::from( + pages[(i * page_size) as usize..((i + 1) * page_size) as usize].to_vec(), + ); + result.push((page_no, page)); + } + } + _ => return Err(anyhow!("unexpected encoding type: {}", encoding_type)), + } + assert!(result.len() == pages_count as usize); + Ok(result) + } +} + +pub struct LazyContext { + db_path: String, + meta_path: String, + 
metadata: LazyMetadata, + encoding: String, + page_server: Arc
, + vfs: RegisteredVfs>, +} + +async fn read_metadata(meta_path: &String) -> anyhow::Result> { + let exists = Path::new(&meta_path).try_exists()?; + if !exists { + tracing::debug!("no metadata info file found"); + return Ok(None); + } + + let contents = tokio::fs::read(&meta_path).await?; + let metadata = serde_json::from_slice::(&contents)?; + + tracing::debug!( + "read lazy metadata for meta_path={:?}, metadata={:?}", + meta_path, + metadata + ); + Ok(Some(metadata)) +} + +async fn write_metadata(meta_path: &String, metadata: &LazyMetadata) -> anyhow::Result<()> { + let contents = serde_json::to_vec(metadata)?; + atomic_write(&meta_path, &contents).await?; + Ok(()) +} + +impl LazyContext { + pub async fn new( + db_path: String, + meta_path: String, + encoding: String, + page_server: Arc, + ) -> anyhow::Result { + let metadata = read_metadata(&meta_path).await?; + let metadata = metadata.unwrap_or(LazyMetadata { revision: None }); + + if let Some(revision) = &metadata.revision { + page_server.set_revision(revision.clone()).await?; + } + + let vfs_default = get_default_vfs("turso-vfs-default"); + let vfs_lazy = LazyVfs::new("turso-vfs-lazy", vfs_default, page_server.clone()); + let vfs = register_vfs(vfs_lazy)?; + let ctx = Self { + db_path, + meta_path, + page_server, + encoding, + metadata, + vfs, + }; + Ok(ctx) + } + async fn pull(&mut self, conn: &Connection) -> anyhow::Result<()> { + let start_revision = self.metadata.revision.clone(); + let pull_plan_request = PullPlanReqBody { + start_revision: start_revision.clone(), + }; + let pull_plan = self.page_server.pull_plan(&pull_plan_request).await?; + tracing::debug!( + "pull_plan(start_revision={:?}): steps={}", + pull_plan_request.start_revision, + pull_plan.steps.len() + ); + for step in pull_plan.steps { + tracing::debug!( + "pull_plan(start_revision={:?}): next step, pages={}", + pull_plan_request.start_revision, + step.pages.len() + ); + let frames_count = conn.wal_frame_count(); + + let insert_handle = conn.wal_insert_handle()?; + + let mut frame_buffer = BytesMut::new(); + let mut received_pages = 0; + for chunk in step.pages.chunks(PULL_PAGES_CHUNK_SIZE) { + let pull_pages_request = PullPagesReqBody { + start_revision: start_revision.clone(), + end_revision: step.end_revision.clone(), + server_pages: chunk.to_vec(), + client_pages: vec![], + accept_encodings: vec![self.encoding.clone()], + }; + let pages = self.page_server.pull_pages(&pull_pages_request).await?; + assert!(pages.len() == chunk.len()); + for (page_no, page) in pages { + received_pages += 1; + let size_after = if received_pages == step.pages.len() { + step.size_after_in_pages as u32 + } else { + 0 + }; + + tracing::trace!("pull: insert page={}, size_after={}", page_no, size_after); + frame_buffer.clear(); + frame_buffer.put_u32(page_no); + frame_buffer.put_u32(size_after); + frame_buffer.put_u32(0); + frame_buffer.put_u32(0); + frame_buffer.put_u32(0); + frame_buffer.put_u32(0); + frame_buffer.extend_from_slice(&page); + + // insert_handle.insert(&frame_buffer)?; + } + } + + insert_handle.end()?; + + // assert!(conn.wal_frame_count() == frames_count + step.pages.len() as u32); + + let next_revision = Some(step.end_revision.clone()); + let next_metadata = LazyMetadata { + revision: next_revision, + }; + write_metadata(&self.meta_path, &next_metadata).await?; + self.page_server.set_revision(step.end_revision).await?; + self.metadata = next_metadata; + } + + // checkpoint local DB just to reduce storage overhead + conn.wal_checkpoint(true)?; + let rows = 
conn.query("SELECT COUNT(*) FROM t", Params::None)?.unwrap(); + let row = rows.next().unwrap().unwrap(); + tracing::info!("rows: {:?}", row.get::(0)); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use arc_swap::ArcSwap; + use std::{sync::Arc, time::Duration}; + use tempfile::tempdir; + use tokio::sync::Mutex; + + use crate::{ + database::connector, + lazy::lazy::{LazyContext, TursoPageServer}, + local::{Connection, Database}, + }; + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn lazy_pull_test() { + tracing_subscriber::fmt::init(); + let temp_dir = tempdir().unwrap(); + let db_path = temp_dir.path().join("test.db"); + tracing::info!("db_path: {:?}", db_path); + let db_path = db_path.to_str().unwrap().to_owned(); + + let connector = connector().unwrap(); + let svc = connector + .map_err(|e| e.into()) + .map_response(|s| Box::new(s) as Box); + let svc = crate::util::ConnectorService::new(svc); + use tower::ServiceExt; + let svc = svc + .map_err(|e| e.into()) + .map_response(|s| Box::new(s) as Box); + let connector = crate::util::ConnectorService::new(svc); + let client = hyper::client::Client::builder().build::<_, hyper::Body>(connector); + + let page_server = Arc::new(TursoPageServer { + endpoint: "http://c--ap--b.localhost:8080".into(), + auth_token: None, + client, + revision: ArcSwap::new(Arc::new("".into())), + }); + let lazy_ctx = LazyContext::new( + db_path.clone(), + format!("{}-metadata", db_path), + "zstd".into(), + page_server, + ) + .await + .unwrap(); + let lazy_ctx = Arc::new(Mutex::new(lazy_ctx)); + let conn = Connection::connect(&Database { + db_path: db_path.clone(), + flags: crate::OpenFlags::default(), + replication_ctx: None, + sync_ctx: None, + lazy_ctx: Some(lazy_ctx.clone()), + }) + .unwrap(); + let mut lazy_ctx = lazy_ctx.lock().await; + lazy_ctx.pull(&conn).await.unwrap(); + std::thread::sleep(Duration::from_secs(100000)); + } +} diff --git a/libsql/src/lazy/mod.rs b/libsql/src/lazy/mod.rs new file mode 100644 index 0000000000..d31dd004d2 --- /dev/null +++ b/libsql/src/lazy/mod.rs @@ -0,0 +1,4 @@ +pub mod lazy; +pub mod vfs; +pub mod vfs_default; +pub mod vfs_lazy; diff --git a/libsql/src/lazy/vfs.rs b/libsql/src/lazy/vfs.rs new file mode 100644 index 0000000000..9d7ebf3402 --- /dev/null +++ b/libsql/src/lazy/vfs.rs @@ -0,0 +1,547 @@ +use std::{ + marker::PhantomData, + time::{Duration, SystemTime}, +}; + +use anyhow::anyhow; +use libsql_sys::ffi; +use rand::RngCore; + +/// sqlite3 can return extended code on success (like SQLITE_OK_SYMLINK) +/// in order to be fully compatible we must return full code on success too +#[derive(Debug)] +pub struct VfsResult { + pub value: T, + pub rc: i32, +} + +#[derive(thiserror::Error, Debug)] +pub enum VfsError { + #[error("sqlite: rc={rc}, message={message}")] + Sqlite { rc: i32, message: String }, + #[error("anyhow: msg={0}")] + Anyhow(anyhow::Error), +} + +pub trait VfsFile { + fn path(&self) -> Option<&str>; + + // See https://www.sqlite.org/c3ref/io_methods.html for IO methods reference + fn close(&mut self) -> Result, VfsError>; + fn read(&mut self, buf: &mut [u8], offset: i64) -> Result, VfsError>; + fn write(&mut self, buf: &[u8], offset: i64) -> Result, VfsError>; + fn truncate(&mut self, size: i64) -> Result, VfsError>; + fn sync(&mut self, flags: i32) -> Result, VfsError>; + fn file_size(&mut self) -> Result, VfsError>; + + fn lock(&mut self, upgrade_to: i32) -> Result, VfsError>; + fn unlock(&mut self, downgrade_to: i32) -> Result, VfsError>; + fn check_reserved_lock(&mut self) -> Result, 
VfsError>; + + fn file_control( + &mut self, + op: i32, + arg: *mut std::ffi::c_void, + ) -> Result, VfsError>; + fn sector_size(&mut self) -> i32; + fn device_characteristics(&mut self) -> i32; + + fn shm_map( + &mut self, + region: i32, + region_size: i32, + extend: bool, + ) -> Result, VfsError>; + fn shm_unmap(&mut self, delete: bool) -> Result, VfsError>; + fn shm_lock(&mut self, offset: i32, count: i32, flags: i32) -> Result, VfsError>; + fn shm_barrier(&mut self); +} + +// VFS represents *shared* instance of virtual file system manager entry point +// As it is shared - it implements Clone trait and can be safely cloned and passed around +// Internally, VFS wraps internal structures with Rc which made cloning cheap and zero-alloc +pub trait Vfs: Clone +where + Self: Sized, +{ + // File type must have C data layout with (*mut ffi::sqlite3_io_methods) pointer as a first field of the struct + type File: VfsFile; + + // zero-terminated static string which will be used to register VFS in sqlite3 DB + fn name(&self) -> &std::ffi::CStr; + fn max_pathname(&self) -> i32; + + // See https://www.sqlite.org/c3ref/vfs.html for VFS methods reference + fn open( + &self, + filename: Option<&std::ffi::CStr>, + flags: i32, + file: &mut Self::File, + ) -> Result, VfsError>; + fn delete(&self, filename: &std::ffi::CStr, sync_dir: bool) -> Result, VfsError>; + fn access(&self, filename: &std::ffi::CStr, flags: i32) -> Result, VfsError>; + fn full_pathname( + &self, + filename: &std::ffi::CStr, + full_buffer: &mut [u8], + ) -> Result, VfsError>; + fn sleep(&self, duration: Duration) -> Result, VfsError>; +} + +pub type VfsName = std::ffi::CString; + +/// Struct that holds all necessary objects after successfull registration +/// Instance of this struct must be valid for the lifetime of the sqlite3 db after successful registration +pub struct RegisteredVfs { + native_vfs_struct: Box, + _phantom: PhantomData, +} + +impl Drop for RegisteredVfs { + fn drop(&mut self) { + unsafe { + ffi::sqlite3_vfs_unregister(&mut *self.native_vfs_struct as *mut ffi::sqlite3_vfs); + let _ = Box::from_raw(self.native_vfs_struct.pAppData as *mut V); + } + } +} + +pub fn register_vfs(vfs: V) -> anyhow::Result> { + let vfs = Box::new(vfs); + let mut native_vfs_struct = sqlite3_vfs(vfs); + let rc = unsafe { ffi::sqlite3_vfs_register(&mut *native_vfs_struct, 0) }; + if rc != ffi::SQLITE_OK { + Err(anyhow!("register failed: {}", rc)) + } else { + Ok(RegisteredVfs { + native_vfs_struct, + _phantom: PhantomData, + }) + } +} + +fn sqlite3_vfs(vfs: Box) -> Box { + Box::new(ffi::sqlite3_vfs { + iVersion: 3, + szOsFile: size_of::() as i32, + mxPathname: vfs.max_pathname(), + pNext: std::ptr::null_mut(), + zName: vfs.name().as_ptr() as *const std::ffi::c_char, + pAppData: Box::into_raw(vfs) as *mut std::ffi::c_void, + xOpen: Some(xOpen::), + xDelete: Some(xDelete::), + xAccess: Some(xAccess::), + xFullPathname: Some(xFullPathname::), + xCurrentTime: Some(xCurrentTime::), + xCurrentTimeInt64: Some(xCurrentTimeInt64::), + + xSleep: Some(xSleep::), + xGetLastError: Some(xGetLastError::), + xRandomness: Some(xRandomness::), + + // It's fine to omit these methods as these are "Interfaces for opening a shared library" + xDlOpen: None, + xDlError: None, + xDlSym: None, + xDlClose: None, + + // It's fine to omit these methods as "The xSetSystemCall(), xGetSystemCall(), and xNestSystemCall() interfaces are not used by the SQLite core" + xSetSystemCall: None, + xGetSystemCall: None, + xNextSystemCall: None, + }) +} + +fn sqlite3_io_methods() -> 
&'static ffi::sqlite3_io_methods { + &ffi::sqlite3_io_methods { + iVersion: 2, + xClose: Some(xClose::), + xRead: Some(xRead::), + xWrite: Some(xWrite::), + xTruncate: Some(xTruncate::), + xSync: Some(xSync::), + xFileSize: Some(xFileSize::), + xLock: Some(xLock::), + xUnlock: Some(xUnlock::), + xCheckReservedLock: Some(xCheckReservedLock::), + xFileControl: Some(xFileControl::), + xSectorSize: Some(xSectorSize::), + xDeviceCharacteristics: Some(xDeviceCharacteristics::), + xShmMap: Some(xShmMap::), + xShmLock: Some(xShmLock::), + xShmBarrier: Some(xShmBarrier::), + xShmUnmap: Some(xShmUnmap::), + xFetch: None, + xUnfetch: None, + } +} + +fn convert_err_to_rc(e: &VfsError) -> i32 { + match e { + VfsError::Sqlite { rc, .. } => *rc, + VfsError::Anyhow(_) => { + tracing::error!("vfs error: {:?}", e); + ffi::SQLITE_IOERR + } + } +} + +pub fn convert_rc_result(rc: i32, value: T, message: &str) -> Result, VfsError> { + // "The least significant 8 bits of the result code define a broad category and are called the "primary result code" + // see https://www.sqlite.org/rescode.html#primary_result_codes_versus_extended_result_codes + if (rc & 0xff) != ffi::SQLITE_OK { + let message = message.to_string(); + Err(VfsError::Sqlite { rc, message }) + } else { + Ok(VfsResult { rc, value }) + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xOpen( + arg1: *mut ffi::sqlite3_vfs, + zName: ffi::sqlite3_filename, + arg2: *mut ffi::sqlite3_file, + flags: std::os::raw::c_int, + pOutFlags: *mut std::os::raw::c_int, +) -> std::os::raw::c_int { + // SQLite expects that the sqlite3_file.pMethods element will be valid after xOpen returns regardless of the success or failure of the xOpen call + (*arg2).pMethods = std::ptr::null(); + + let vfs = (*arg1).pAppData as *mut V; + let vfs_file = arg2 as *mut V::File; + let name = if !zName.is_null() { + Some(std::ffi::CStr::from_ptr(zName)) + } else { + None + }; + + match (*vfs).open(name, flags, &mut *vfs_file) { + Ok(VfsResult { value, rc }) => { + if !pOutFlags.is_null() { + (*pOutFlags) = value; + } + (*arg2).pMethods = + std::ptr::from_ref(sqlite3_io_methods::()) as *mut ffi::sqlite3_io_methods; + + rc + } + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xDelete( + arg1: *mut ffi::sqlite3_vfs, + zName: *const ::std::os::raw::c_char, + syncDir: std::os::raw::c_int, +) -> std::os::raw::c_int { + let vfs = (*arg1).pAppData as *mut V; + let name = std::ffi::CStr::from_ptr(zName); + + match (*vfs).delete(name, syncDir != 0) { + Ok(VfsResult { rc, .. 
}) => rc, + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xAccess( + arg1: *mut ffi::sqlite3_vfs, + zName: *const ::std::os::raw::c_char, + flags: std::os::raw::c_int, + pResOut: *mut std::os::raw::c_int, +) -> std::os::raw::c_int { + let vfs = (*arg1).pAppData as *mut V; + let name = std::ffi::CStr::from_ptr(zName); + + match (*vfs).access(name, flags) { + Ok(VfsResult { value, rc }) => { + *pResOut = if value { 1 } else { 0 }; + rc + } + Err(e) => convert_err_to_rc(&e), + } +} + +// This logic ported from SQLite implementation: https://github.com/sqlite/sqlite/blob/98772d6e75f4033373c806e4e44f675971e55e38/src/os_unix.c#L6908 +// The logic here is to calculate current time and date as a Julian Day number +#[allow(non_snake_case)] +unsafe extern "C" fn xCurrentTime( + _arg1: *mut ffi::sqlite3_vfs, + arg2: *mut f64, +) -> ::std::os::raw::c_int { + let mut now: ffi::sqlite3_int64 = 0; + let rc = xCurrentTimeInt64::(_arg1, &mut now); + if rc == ffi::SQLITE_OK { + *arg2 = (now as f64) / 86400000.0; + } + rc +} + +// This logic ported from SQLite implementation: https://github.com/sqlite/sqlite/blob/98772d6e75f4033373c806e4e44f675971e55e38/src/os_unix.c#L6876 +// The logic here is to calculate the number of milliseconds since the Julian epoch of noon in Greenwich on November 24, 4714 B.C according to the proleptic Gregorian calendar +#[allow(non_snake_case)] +#[allow(clippy::extra_unused_type_parameters)] +unsafe extern "C" fn xCurrentTimeInt64( + _arg1: *mut ffi::sqlite3_vfs, + arg2: *mut ffi::sqlite3_int64, +) -> ::std::os::raw::c_int { + const JULIAN_EPOCH_OFFSET: i64 = 24405875i64 * 8640000i64; + let Ok(unix_epoch_time) = SystemTime::now().duration_since(std::time::UNIX_EPOCH) else { + return ffi::SQLITE_ERROR; + }; + *arg2 = JULIAN_EPOCH_OFFSET + unix_epoch_time.as_millis() as i64; + ffi::SQLITE_OK +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xSleep( + arg1: *mut ffi::sqlite3_vfs, + microseconds: ::std::os::raw::c_int, +) -> ::std::os::raw::c_int { + let vfs = (*arg1).pAppData as *mut V; + match (*vfs).sleep(Duration::from_micros(microseconds as u64)) { + Ok(VfsResult { value: (), rc }) => rc, + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +#[allow(clippy::extra_unused_type_parameters)] +unsafe extern "C" fn xGetLastError( + _arg1: *mut ffi::sqlite3_vfs, + _arg2: ::std::os::raw::c_int, + _arg3: *mut ::std::os::raw::c_char, +) -> ::std::os::raw::c_int { + std::io::Error::last_os_error().raw_os_error().unwrap_or(0) +} + +#[allow(non_snake_case)] +#[allow(clippy::extra_unused_type_parameters)] +unsafe extern "C" fn xRandomness( + _arg1: *mut ffi::sqlite3_vfs, + nByte: ::std::os::raw::c_int, + zOut: *mut ::std::os::raw::c_char, +) -> ::std::os::raw::c_int { + // this is unexpected as RANDOM() seems to not use xRandomness from VFS + rand::thread_rng().fill_bytes(std::slice::from_raw_parts_mut( + zOut as *mut u8, + nByte as usize, + )); + ffi::SQLITE_OK +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xFullPathname( + arg1: *mut ffi::sqlite3_vfs, + zName: *const std::os::raw::c_char, + nOut: std::os::raw::c_int, + zOut: *mut std::os::raw::c_char, +) -> std::os::raw::c_int { + let vfs = (*arg1).pAppData as *mut V; + let name = std::ffi::CStr::from_ptr(zName); + let buffer = unsafe { std::slice::from_raw_parts_mut::(zOut as *mut u8, nOut as usize) }; + match (*vfs).full_pathname(name, buffer) { + Ok(VfsResult { rc, .. 
}) => rc, + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xClose(arg1: *mut ffi::sqlite3_file) -> std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + match (*vfs_file).close() { + Ok(VfsResult { rc, .. }) => rc, + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xRead( + arg1: *mut ffi::sqlite3_file, + arg2: *mut ::std::os::raw::c_void, + iAmt: std::os::raw::c_int, + iOfst: ffi::sqlite3_int64, +) -> std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + let buffer = unsafe { std::slice::from_raw_parts_mut::(arg2 as *mut u8, iAmt as usize) }; + match (*vfs_file).read(buffer, iOfst) { + Ok(VfsResult { rc, .. }) => rc, + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xWrite( + arg1: *mut ffi::sqlite3_file, + arg2: *const ::std::os::raw::c_void, + iAmt: ::std::os::raw::c_int, + iOfst: ffi::sqlite3_int64, +) -> std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + let buffer = unsafe { std::slice::from_raw_parts_mut::(arg2 as *mut u8, iAmt as usize) }; + match (*vfs_file).write(buffer, iOfst) { + Ok(VfsResult { rc, .. }) => ffi::SQLITE_OK, + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xTruncate( + arg1: *mut ffi::sqlite3_file, + size: ffi::sqlite3_int64, +) -> std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + match (*vfs_file).truncate(size) { + Ok(VfsResult { rc, .. }) => rc, + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xSync( + arg1: *mut ffi::sqlite3_file, + flags: ::std::os::raw::c_int, +) -> std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + match (*vfs_file).sync(flags) { + Ok(VfsResult { rc, .. }) => rc, + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xFileSize( + arg1: *mut ffi::sqlite3_file, + pSize: *mut ffi::sqlite3_int64, +) -> ::std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + match (*vfs_file).file_size() { + Ok(VfsResult { value, rc }) => { + *pSize = value; + rc + } + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xLock( + arg1: *mut ffi::sqlite3_file, + arg2: ::std::os::raw::c_int, +) -> std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + match (*vfs_file).lock(arg2) { + Ok(VfsResult { rc, .. }) => rc, + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xUnlock( + arg1: *mut ffi::sqlite3_file, + arg2: ::std::os::raw::c_int, +) -> std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + match (*vfs_file).unlock(arg2) { + Ok(VfsResult { rc, .. }) => rc, + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xCheckReservedLock( + arg1: *mut ffi::sqlite3_file, + pResOut: *mut ::std::os::raw::c_int, +) -> std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + match (*vfs_file).check_reserved_lock() { + Ok(VfsResult { value, rc }) => { + *pResOut = if value { 1 } else { 0 }; + rc + } + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xFileControl( + arg1: *mut ffi::sqlite3_file, + op: ::std::os::raw::c_int, + pArg: *mut ::std::os::raw::c_void, +) -> ::std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + match (*vfs_file).file_control(op, pArg) { + Ok(VfsResult { rc, .. 
}) => rc, + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xSectorSize(arg1: *mut ffi::sqlite3_file) -> ::std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + let sector_size = (*vfs_file).sector_size(); + sector_size +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xDeviceCharacteristics( + arg1: *mut ffi::sqlite3_file, +) -> ::std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + let device_characteristics = (*vfs_file).device_characteristics(); + device_characteristics +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xShmMap( + arg1: *mut ffi::sqlite3_file, + iPg: ::std::os::raw::c_int, + pgsz: ::std::os::raw::c_int, + arg2: ::std::os::raw::c_int, + arg3: *mut *mut ::std::os::raw::c_void, +) -> std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + match (*vfs_file).shm_map(iPg, pgsz, arg2 != 0) { + Ok(VfsResult { rc, value }) => { + if !arg3.is_null() { + unsafe { *arg3 = value }; + } + rc + } + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xShmLock( + arg1: *mut ffi::sqlite3_file, + offset: ::std::os::raw::c_int, + n: ::std::os::raw::c_int, + flags: ::std::os::raw::c_int, +) -> std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + match (*vfs_file).shm_lock(offset, n, flags) { + Ok(VfsResult { rc, .. }) => rc, + Err(e) => convert_err_to_rc(&e), + } +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xShmBarrier(arg1: *mut ffi::sqlite3_file) { + let vfs_file: *mut ::File = arg1 as *mut V::File; + (*vfs_file).shm_barrier(); +} + +#[allow(non_snake_case)] +unsafe extern "C" fn xShmUnmap( + arg1: *mut ffi::sqlite3_file, + deleteFlag: ::std::os::raw::c_int, +) -> std::os::raw::c_int { + let vfs_file = arg1 as *mut V::File; + match (*vfs_file).shm_unmap(deleteFlag != 0) { + Ok(VfsResult { rc, .. 
}) => rc, + Err(e) => convert_err_to_rc(&e), + } +} diff --git a/libsql/src/lazy/vfs_default.rs b/libsql/src/lazy/vfs_default.rs new file mode 100644 index 0000000000..7eda2ac186 --- /dev/null +++ b/libsql/src/lazy/vfs_default.rs @@ -0,0 +1,270 @@ +use std::{rc::Rc, sync::Arc}; + +use libsql_sys::ffi; +use tokio::sync::Mutex; + +use super::vfs::{convert_rc_result, Vfs, VfsError, VfsFile, VfsResult}; + +static DEFAULT_MAX_PATH_LENGTH: i32 = 1024; + +#[repr(C)] +struct Sqlite3VfsInner { + vfs: *mut ffi::sqlite3_vfs, + name: std::ffi::CString, +} + +pub struct Sqlite3Vfs { + inner: Arc, +} + +impl Clone for Sqlite3Vfs { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +unsafe impl Send for Sqlite3Vfs {} +unsafe impl Sync for Sqlite3Vfs {} + +// Wrapper around VFS returned by SQLite which help this crate to get full control over VFS execution +// So, we will install our custom io_methods which will be proxied to pMethods of inner sqlite_file +// Created by original VFS xOpen method +#[repr(C)] +pub struct Sqlite3VfsFile { + io_methods: *mut ffi::sqlite3_io_methods, + z_filename: *const std::ffi::c_char, + inner: *mut ffi::sqlite3_file, + inner_layout: std::alloc::Layout, +} + +impl VfsFile for Sqlite3VfsFile { + fn path(&self) -> Option<&str> { + if self.z_filename.is_null() { + None + } else { + unsafe { std::ffi::CStr::from_ptr(self.z_filename).to_str() }.ok() + } + } + + fn close(&mut self) -> Result, VfsError> { + let rc = unsafe { (*(*self.inner).pMethods).xClose.unwrap()(self.inner) }; + unsafe { std::alloc::dealloc(self.inner as *mut u8, self.inner_layout) }; + + convert_rc_result(rc, (), "sqlite3_file::close failed") + } + + fn read(&mut self, buf: &mut [u8], offset: i64) -> Result, VfsError> { + let rc = unsafe { + (*(*self.inner).pMethods).xRead.unwrap()( + self.inner, + buf.as_mut_ptr() as *mut std::ffi::c_void, + buf.len() as i32, + offset, + ) + }; + convert_rc_result(rc, (), "sqlite3_file::read failed") + } + + fn write(&mut self, buf: &[u8], offset: i64) -> Result, VfsError> { + let rc = unsafe { + (*(*self.inner).pMethods).xWrite.unwrap()( + self.inner, + buf.as_ptr() as *mut std::ffi::c_void, + buf.len() as i32, + offset, + ) + }; + convert_rc_result(rc, (), "sqlite3_file::write failed") + } + + fn truncate(&mut self, size: i64) -> Result, VfsError> { + let rc = unsafe { (*(*self.inner).pMethods).xTruncate.unwrap()(self.inner, size) }; + convert_rc_result(rc, (), "sqlite3_file::truncate failed") + } + + fn sync(&mut self, flags: i32) -> Result, VfsError> { + let rc = unsafe { (*(*self.inner).pMethods).xSync.unwrap()(self.inner, flags) }; + convert_rc_result(rc, (), "sqlite3_file::sync failed") + } + + fn file_size(&mut self) -> Result, VfsError> { + let mut result: i64 = 0; + let rc = unsafe { (*(*self.inner).pMethods).xFileSize.unwrap()(self.inner, &mut result) }; + convert_rc_result(rc, result, "sqlite3_file::file_size failed") + } + + fn lock(&mut self, upgrade_to: i32) -> Result, VfsError> { + let rc = unsafe { (*(*self.inner).pMethods).xLock.unwrap()(self.inner, upgrade_to) }; + convert_rc_result(rc, (), "sqlite3_file::lock failed") + } + + fn unlock(&mut self, downgrade_to: i32) -> Result, VfsError> { + let rc = unsafe { (*(*self.inner).pMethods).xUnlock.unwrap()(self.inner, downgrade_to) }; + convert_rc_result(rc, (), "sqlite3_file::unlock failed") + } + + fn check_reserved_lock(&mut self) -> Result, VfsError> { + let mut result: i32 = 0; + let rc = unsafe { + (*(*self.inner).pMethods).xCheckReservedLock.unwrap()(self.inner, &mut result) + 
}; + convert_rc_result(rc, result != 0, "sqlite3_file::check_reserved_lock failed") + } + + #[allow(clippy::not_unsafe_ptr_arg_deref)] + fn file_control( + &mut self, + op: i32, + arg: *mut std::ffi::c_void, + ) -> Result, VfsError> { + let rc = unsafe { (*(*self.inner).pMethods).xFileControl.unwrap()(self.inner, op, arg) }; + convert_rc_result(rc, (), "sqlite3_file::file_control failed") + } + + fn sector_size(&mut self) -> i32 { + unsafe { (*(*self.inner).pMethods).xSectorSize.unwrap()(self.inner) } + } + + fn device_characteristics(&mut self) -> i32 { + unsafe { (*(*self.inner).pMethods).xDeviceCharacteristics.unwrap()(self.inner) } + } + + fn shm_map( + &mut self, + region: i32, + region_size: i32, + extend: bool, + ) -> Result, VfsError> { + let mut mapped: *mut std::ffi::c_void = std::ptr::null_mut(); + let rc = unsafe { + (*(*self.inner).pMethods).xShmMap.unwrap()( + self.inner, + region, + region_size, + if extend { 1 } else { 0 }, + &mut mapped, + ) + }; + convert_rc_result(rc, mapped, "sqlite3_file::shm_map failed") + } + + fn shm_unmap(&mut self, delete: bool) -> Result, VfsError> { + let rc = unsafe { + (*(*self.inner).pMethods).xShmUnmap.unwrap()(self.inner, if delete { 1 } else { 0 }) + }; + convert_rc_result(rc, (), "sqlite3_file::shm_unmap failed") + } + + fn shm_lock(&mut self, offset: i32, count: i32, flags: i32) -> Result, VfsError> { + let rc = unsafe { + (*(*self.inner).pMethods).xShmLock.unwrap()(self.inner, offset, count, flags) + }; + convert_rc_result(rc, (), "sqlite3_file::shm_lock failed") + } + + fn shm_barrier(&mut self) { + unsafe { (*(*self.inner).pMethods).xShmBarrier.unwrap()(self.inner) }; + } +} + +impl Vfs for Sqlite3Vfs { + type File = Sqlite3VfsFile; + + fn name(&self) -> &std::ffi::CStr { + &self.inner.name + } + + fn max_pathname(&self) -> i32 { + DEFAULT_MAX_PATH_LENGTH + } + + fn open( + &self, + filename: Option<&std::ffi::CStr>, + flags: i32, + file: &mut Self::File, + ) -> Result, VfsError> { + let sz_os_file = unsafe { (*self.inner.vfs).szOsFile as usize }; + let filename_ptr = filename.map(|x| x.as_ptr()).unwrap_or(std::ptr::null()); + let file_layout = std::alloc::Layout::from_size_align(sz_os_file, 8).unwrap(); + let file_ptr = unsafe { std::alloc::alloc(file_layout) } as *mut ffi::sqlite3_file; + + let mut result: i32 = 0; + let rc = unsafe { + (*self.inner.vfs).xOpen.unwrap()( + self.inner.vfs, + filename_ptr, + file_ptr, + flags, + &mut result, + ) + }; + if rc == ffi::SQLITE_OK { + file.inner = file_ptr; + file.inner_layout = file_layout; + file.z_filename = filename_ptr; + } else { + unsafe { std::alloc::dealloc(file_ptr as *mut u8, file_layout) }; + } + convert_rc_result(rc, result, "sqlite3_vfs::open failed") + } + + fn delete(&self, filename: &std::ffi::CStr, sync_dir: bool) -> Result, VfsError> { + let rc = unsafe { + (*self.inner.vfs).xDelete.unwrap()( + self.inner.vfs, + filename.as_ptr(), + if sync_dir { 1 } else { 0 }, + ) + }; + convert_rc_result(rc, (), "sqlite3_vfs::delete failed") + } + + fn access(&self, filename: &std::ffi::CStr, flags: i32) -> Result, VfsError> { + let mut result: i32 = 0; + let rc = unsafe { + (*self.inner.vfs).xAccess.unwrap()( + self.inner.vfs, + filename.as_ptr(), + flags, + &mut result, + ) + }; + convert_rc_result(rc, result != 0, "sqlite3_vfs::access failed") + } + + fn full_pathname( + &self, + filename: &std::ffi::CStr, + full_buffer: &mut [u8], + ) -> Result, VfsError> { + let rc = unsafe { + (*self.inner.vfs).xFullPathname.unwrap()( + self.inner.vfs, + filename.as_ptr(), + full_buffer.len() as 
i32, + full_buffer.as_mut_ptr() as *mut std::ffi::c_char, + ) + }; + convert_rc_result(rc, (), "sqlite3_vfs::full_pathname failed") + } + + fn sleep(&self, duration: std::time::Duration) -> Result, VfsError> { + let rc = unsafe { + (*self.inner.vfs).xSleep.unwrap()(self.inner.vfs, duration.as_micros() as i32) + }; + convert_rc_result(rc, (), "sqlite3_vfs::sleep failed") + } +} + +pub fn get_default_vfs(name: &str) -> Sqlite3Vfs { + let vfs = unsafe { ffi::sqlite3_vfs_find(std::ptr::null()) }; + assert!(!vfs.is_null(), "default vfs is not found"); + let inner = Arc::new(Sqlite3VfsInner { + vfs, + name: std::ffi::CString::new(name).unwrap(), + }); + Sqlite3Vfs { inner } +} diff --git a/libsql/src/lazy/vfs_lazy.rs b/libsql/src/lazy/vfs_lazy.rs new file mode 100644 index 0000000000..22fa27df54 --- /dev/null +++ b/libsql/src/lazy/vfs_lazy.rs @@ -0,0 +1,221 @@ +use std::{mem::ManuallyDrop, sync::Arc}; + +use anyhow::anyhow; +use lazy_static::lazy_static; +use libsql_sys::ffi; +use tokio::runtime::Runtime; + +use super::{ + lazy::{PageServer, PullPagesReqBody}, + vfs::{Vfs, VfsError, VfsFile, VfsResult}, +}; + +struct LazyVfsInner { + vfs: V, + name: std::ffi::CString, + page_server: Arc, +} + +pub struct LazyVfs { + inner: Arc>, +} + +impl Clone for LazyVfs { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +#[repr(C)] +pub struct LazyVfsFile { + file: ManuallyDrop, + inner: ManuallyDrop>>, +} + +impl LazyVfs { + pub fn new(name: &str, vfs: V, page_server: Arc
) -> Self { + let name = std::ffi::CString::new(name).unwrap(); + Self { + inner: Arc::new(LazyVfsInner { + vfs, + name, + page_server, + }), + } + } +} + +const PAGE_SIZE: i64 = 4096; + +lazy_static! { + static ref RT: Runtime = tokio::runtime::Runtime::new().unwrap(); +} + +impl VfsFile for LazyVfsFile { + fn path(&self) -> Option<&str> { + self.file.path() + } + + fn close(&mut self) -> Result, VfsError> { + self.file.close() + } + + fn read(&mut self, buf: &mut [u8], offset: i64) -> Result, VfsError> { + tracing::info!("read: {}", offset); + if self.file.path().is_some() && self.file.path().unwrap().ends_with(".db") { + assert!(buf.len() > 0); + let start_page_no = offset / PAGE_SIZE; + let end_page_no = (offset + buf.len() as i64 - 1) / PAGE_SIZE; + assert!(start_page_no == end_page_no); + + let page_no = (start_page_no as usize) + 1; + let page_server = self.inner.page_server.clone(); + let revision = page_server.get_revision(); + if revision != "" { + let pages = futures::executor::block_on({ + page_server.pull_pages(&PullPagesReqBody { + start_revision: None, + end_revision: revision.clone(), + server_pages: vec![page_no], + client_pages: vec![], + accept_encodings: vec!["zstd".into()], + }) + }) + .map_err(|e| VfsError::Anyhow(e.into()))?; + + tracing::info!( + "got pages(revision={}): {} {}", + revision, + pages[0].0, + pages[0].1.len() + ); + buf.copy_from_slice(&pages[0].1); + return Ok(VfsResult { + value: (), + rc: ffi::SQLITE_OK, + }); + } + } + self.file.read(buf, offset) + } + + fn write(&mut self, buf: &[u8], offset: i64) -> Result, VfsError> { + tracing::info!("write"); + self.file.write(buf, offset) + } + + fn truncate(&mut self, size: i64) -> Result, VfsError> { + tracing::info!("truncate"); + self.file.truncate(size) + } + + fn sync(&mut self, flags: i32) -> Result, VfsError> { + self.file.sync(flags) + } + + fn file_size(&mut self) -> Result, VfsError> { + if self.inner.page_server.get_revision() != "" { + return Ok(VfsResult { + value: 533 * 4096, + rc: ffi::SQLITE_OK, + }); + } + let file_size = self.file.file_size(); + tracing::info!("file_size: {:?}", file_size); + file_size + } + + fn lock(&mut self, upgrade_to: i32) -> Result, VfsError> { + self.file.lock(upgrade_to) + } + + fn unlock(&mut self, downgrade_to: i32) -> Result, VfsError> { + self.file.unlock(downgrade_to) + } + + fn check_reserved_lock(&mut self) -> Result, VfsError> { + self.file.check_reserved_lock() + } + + #[allow(clippy::not_unsafe_ptr_arg_deref)] + fn file_control( + &mut self, + op: i32, + arg: *mut std::ffi::c_void, + ) -> Result, VfsError> { + self.file.file_control(op, arg) + } + + fn sector_size(&mut self) -> i32 { + self.file.sector_size() + } + + fn device_characteristics(&mut self) -> i32 { + self.file.device_characteristics() + } + + fn shm_map( + &mut self, + region: i32, + region_size: i32, + extend: bool, + ) -> Result, VfsError> { + self.file.shm_map(region, region_size, extend) + } + + fn shm_unmap(&mut self, delete: bool) -> Result, VfsError> { + self.file.shm_unmap(delete) + } + + fn shm_lock(&mut self, offset: i32, count: i32, flags: i32) -> Result, VfsError> { + self.file.shm_lock(offset, count, flags) + } + + fn shm_barrier(&mut self) { + self.file.shm_barrier(); + } +} + +impl Vfs for LazyVfs { + type File = LazyVfsFile; + + fn name(&self) -> &std::ffi::CStr { + &self.inner.name + } + + fn max_pathname(&self) -> i32 { + self.inner.vfs.max_pathname() + } + + fn open( + &self, + filename: Option<&std::ffi::CStr>, + flags: i32, + file: &mut Self::File, + ) -> Result, 
VfsError> { + file.inner = ManuallyDrop::new(self.inner.clone()); + self.inner.vfs.open(filename, flags, &mut file.file) + } + + fn delete(&self, filename: &std::ffi::CStr, sync_dir: bool) -> Result, VfsError> { + self.inner.vfs.delete(filename, sync_dir) + } + + fn access(&self, filename: &std::ffi::CStr, flags: i32) -> Result, VfsError> { + self.inner.vfs.access(filename, flags) + } + + fn full_pathname( + &self, + filename: &std::ffi::CStr, + full_buffer: &mut [u8], + ) -> Result, VfsError> { + self.inner.vfs.full_pathname(filename, full_buffer) + } + + fn sleep(&self, duration: std::time::Duration) -> Result, VfsError> { + self.inner.vfs.sleep(duration) + } +} diff --git a/libsql/src/lib.rs b/libsql/src/lib.rs index 823ab89c84..dfe83243e8 100644 --- a/libsql/src/lib.rs +++ b/libsql/src/lib.rs @@ -134,6 +134,10 @@ cfg_sync! { pub use database::SyncProtocol; } +cfg_lazy! { + mod lazy; +} + cfg_replication! { pub mod replication; } diff --git a/libsql/src/local/connection.rs b/libsql/src/local/connection.rs index b9a2e20150..60974f0fc5 100644 --- a/libsql/src/local/connection.rs +++ b/libsql/src/local/connection.rs @@ -1,5 +1,6 @@ #![allow(dead_code)] +use crate::lazy::lazy::LAZY_VFS_NAME; use crate::local::rows::BatchedRows; use crate::params::Params; use crate::{connection::BatchRows, errors}; @@ -41,6 +42,11 @@ impl Connection { let mut raw = std::ptr::null_mut(); let db_path = db.db_path.clone(); let err = unsafe { + let mut vfs_ptr = std::ptr::null(); + #[cfg(feature = "lazy")] + if let Some(_) = db.lazy_ctx { + vfs_ptr = LAZY_VFS_NAME.as_ptr() as *const i8; + } ffi::sqlite3_open_v2( std::ffi::CString::new(db_path.as_str()) .unwrap() @@ -48,7 +54,7 @@ impl Connection { .as_ptr() as *const _, &mut raw, db.flags.bits() as c_int, - std::ptr::null(), + vfs_ptr, ) }; match err { @@ -75,6 +81,16 @@ impl Connection { ffi::libsql_wal_disable_checkpoint(conn.raw); } } + #[cfg(feature = "lazy")] + if let Some(_) = db.lazy_ctx { + // We need to make sure database is in WAL mode with checkpointing + // disabled so that we can sync our changes back to a remote + // server. + conn.query("PRAGMA journal_mode = WAL", Params::None)?; + unsafe { + ffi::libsql_wal_disable_checkpoint(conn.raw); + } + } Ok(conn) } @@ -459,7 +475,15 @@ impl Connection { } pub(crate) fn wal_checkpoint(&self, truncate: bool) -> Result<()> { - let rc = unsafe { libsql_sys::ffi::sqlite3_wal_checkpoint_v2(self.handle(), std::ptr::null(), truncate as i32, std::ptr::null_mut(), std::ptr::null_mut()) }; + let rc = unsafe { + libsql_sys::ffi::sqlite3_wal_checkpoint_v2( + self.handle(), + std::ptr::null(), + truncate as i32, + std::ptr::null_mut(), + std::ptr::null_mut(), + ) + }; if rc != 0 { let err_msg = unsafe { libsql_sys::ffi::sqlite3_errmsg(self.handle()) }; let err_msg = unsafe { std::ffi::CStr::from_ptr(err_msg) }; @@ -566,13 +590,16 @@ impl Connection { pub(crate) fn wal_insert_handle(&self) -> Result> { self.wal_insert_begin()?; - Ok(WalInsertHandle { conn: self, in_session: RefCell::new(true) }) + Ok(WalInsertHandle { + conn: self, + in_session: RefCell::new(true), + }) } } pub(crate) struct WalInsertHandle<'a> { conn: &'a Connection, - in_session: RefCell + in_session: RefCell, } impl WalInsertHandle<'_> { diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs index 3b157e715d..61504e6430 100644 --- a/libsql/src/local/database.rs +++ b/libsql/src/local/database.rs @@ -24,7 +24,13 @@ cfg_sync! 
{ use std::sync::Arc; } -use crate::{database::OpenFlags, local::connection::Connection, Error::ConnectionFailed, Result}; +use crate::{ + database::OpenFlags, + lazy::lazy::{LazyContext, TursoPageServer}, + local::connection::Connection, + Error::ConnectionFailed, + Result, +}; use libsql_sys::ffi; // A libSQL database. @@ -35,6 +41,8 @@ pub struct Database { pub replication_ctx: Option, #[cfg(feature = "sync")] pub sync_ctx: Option>>, + #[cfg(feature = "lazy")] + pub lazy_ctx: Option>>>, } impl Database { @@ -74,6 +82,8 @@ impl Database { replication_ctx: None, #[cfg(feature = "sync")] sync_ctx: None, + #[cfg(feature = "lazy")] + lazy_ctx: None, }) } } @@ -228,6 +238,39 @@ impl Database { Ok(db) } + #[cfg(feature = "lazy")] + #[doc(hidden)] + pub async fn open_local_lazy( + connector: crate::util::ConnectorService, + db_path: impl Into, + flags: OpenFlags, + endpoint: String, + auth_token: String, + ) -> Result { + use arc_swap::ArcSwap; + + use crate::lazy::lazy::LazyContext; + + let db_path = db_path.into(); + let endpoint = if endpoint.starts_with("libsql:") { + endpoint.replace("libsql:", "https:") + } else { + endpoint + }; + let mut db = Database::open(&db_path, flags)?; + let meta_path = format!("{}-metadata", db_path); + let encoding = "raw".into(); + let page_server = Arc::new(TursoPageServer { + endpoint, + auth_token: Some(auth_token), + client: hyper::client::Client::builder().build::<_, hyper::Body>(connector), + revision: ArcSwap::new(Arc::new("".to_string())), + }); + let lazy_ctx = LazyContext::new(db_path, meta_path, encoding, page_server).await?; + db.lazy_ctx = Some(Arc::new(Mutex::new(lazy_ctx))); + Ok(db) + } + #[cfg(feature = "replication")] pub async fn open_local_sync( db_path: impl Into, @@ -336,6 +379,8 @@ impl Database { replication_ctx: None, #[cfg(feature = "sync")] sync_ctx: None, + #[cfg(feature = "lazy")] + lazy_ctx: None, } } diff --git a/libsql/src/macros.rs b/libsql/src/macros.rs index b7d0df31c0..1864ab4310 100644 --- a/libsql/src/macros.rs +++ b/libsql/src/macros.rs @@ -13,8 +13,8 @@ macro_rules! cfg_core { macro_rules! cfg_replication_or_remote_or_sync { ($($item:item)*) => { $( - #[cfg(any(feature = "replication", feature = "sync", feature = "remote"))] - #[cfg_attr(docsrs, doc(cfg(any(feature = "replication", feature = "sync", feature = "remote"))))] + #[cfg(any(feature = "replication", feature = "sync", feature = "remote", feature = "lazy"))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "replication", feature = "sync", feature = "remote", feature = "lazy"))))] $item )* } @@ -50,11 +50,21 @@ macro_rules! cfg_sync { } } +macro_rules! cfg_lazy { + ($($item:item)*) => { + $( + #[cfg(feature = "lazy")] + #[cfg_attr(docsrs, doc(cfg(feature = "lazy")))] + $item + )* + } +} + macro_rules! 
cfg_replication_or_sync { ($($item:item)*) => { $( - #[cfg(any(feature = "replication", feature = "sync"))] - #[cfg_attr(docsrs, doc(cfg(any(feature = "replication", feature = "sync"))))] + #[cfg(any(feature = "replication", feature = "sync", feature = "lazy"))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "replication", feature = "sync", feature = "lazy"))))] $item )* } diff --git a/libsql/src/sync.rs b/libsql/src/sync.rs index a1d6b742f9..ae94877ff4 100644 --- a/libsql/src/sync.rs +++ b/libsql/src/sync.rs @@ -495,7 +495,7 @@ impl MetadataJson { } } -async fn atomic_write>(path: P, data: &[u8]) -> Result<()> { +pub async fn atomic_write>(path: P, data: &[u8]) -> Result<()> { // Create a temporary file in the same directory as the target file let directory = path.as_ref().parent().ok_or_else(|| { SyncError::io("parent path")(std::io::Error::other(
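
For reference, the `pull-pages` response parsed in `TursoPageServer::pull_pages` uses a small binary envelope: four little-endian `u32`s (meta length, payload length, page count, page size), followed by a meta block (encoding tag, a dictionary-page count for zstd, then one page number per page) and the page payload, which is either raw or a single zstd-compressed blob. The sketch below restates that decoding as a standalone function; the layout and the `decode_pull_pages` name are inferred from the code in this diff, not from a protocol spec.

```rust
use bytes::{Buf, Bytes};

const PAGE_BATCH_ENCODING_RAW: u32 = 0;
const PAGE_BATCH_ENCODING_ZSTD: u32 = 1;

/// Decode a pull-pages response body into (page_no, page) pairs, mirroring
/// the parsing in `TursoPageServer::pull_pages`.
fn decode_pull_pages(mut body: Bytes) -> anyhow::Result<Vec<(u32, Bytes)>> {
    let meta_len = body.get_u32_le() as usize;    // length of the encoded meta block
    let payload_len = body.get_u32_le() as usize; // length of the (possibly compressed) page payload
    let pages_count = body.get_u32_le() as usize; // number of pages in this batch
    let page_size = body.get_u32_le() as usize;   // size of one page in bytes

    let (mut meta, payload) = body.split_at(meta_len);
    anyhow::ensure!(payload.len() == payload_len, "payload length mismatch");

    let raw_pages = match meta.get_u32_le() {
        PAGE_BATCH_ENCODING_RAW => payload.to_vec(),
        PAGE_BATCH_ENCODING_ZSTD => {
            let dictionary_pages = meta.get_u32_le();
            anyhow::ensure!(dictionary_pages == 0, "dictionaries are not supported");
            zstd::bulk::Decompressor::new()?.decompress(payload, pages_count * page_size)?
        }
        other => anyhow::bail!("unexpected encoding type: {other}"),
    };

    // The meta block ends with one little-endian page number per page, in the
    // same order as the pages appear in the payload.
    let mut result = Vec::with_capacity(pages_count);
    for i in 0..pages_count {
        let page_no = meta.get_u32_le();
        let page = Bytes::copy_from_slice(&raw_pages[i * page_size..(i + 1) * page_size]);
        result.push((page_no, page));
    }
    Ok(result)
}
```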
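`LazyContext::pull` stages each fetched page as a SQLite WAL frame: a 24-byte header (page number, database size in pages on the final frame of the batch and zero otherwise, then four zeroed words) followed by the raw page image. The helper below restates that layout; the `wal_frame` name is illustrative, and reading the four trailing words as salt-1/salt-2/checksum-1/checksum-2 left for the WAL insert machinery to fill is an assumption — note that the actual `insert_handle.insert(&frame_buffer)` call is still commented out in this diff.

```rust
use bytes::{BufMut, BytesMut};

/// Build the 24-byte WAL frame header plus page image for `page_no`.
/// `size_after_in_pages` is non-zero only on the final frame of a batch,
/// which marks it as a commit frame.
fn wal_frame(page_no: u32, size_after_in_pages: u32, page: &[u8]) -> BytesMut {
    let mut frame = BytesMut::with_capacity(24 + page.len());
    frame.put_u32(page_no);             // page number (big-endian, as in the WAL format)
    frame.put_u32(size_after_in_pages); // database size after commit, 0 for non-commit frames
    frame.put_u32(0);                   // salt-1   -- left zero; assumed to be filled by
    frame.put_u32(0);                   // salt-2      the WAL insert path
    frame.put_u32(0);                   // checksum-1
    frame.put_u32(0);                   // checksum-2
    frame.extend_from_slice(page);      // raw page image
    frame
}
```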
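Finally, a minimal sketch of how the new `lazy` feature is meant to be consumed through the `Builder::new_lazy` entry point added in `libsql/src/database/builder.rs`. `new_lazy` is still stubbed with `todo!()` in this diff and the `DbType::Lazy` branch of `Database::connect` is not shown, so the endpoint, auth token, and query below are placeholders illustrating the intended shape rather than a working example.

```rust
use libsql::Builder;

#[tokio::main]
async fn main() -> libsql::Result<()> {
    // Placeholder endpoint and token; `new_lazy` takes the local path plus the
    // remote URL and auth token, mirroring the other embedded-replica builders.
    let db = Builder::new_lazy(
        "local.db",
        "libsql://example.turso.io".to_string(),
        "AUTH_TOKEN".to_string(),
    )
    .build()
    .await?;

    // Pages are fetched on demand through the lazy VFS instead of being
    // replicated up front, so the first queries pull only what they touch.
    let conn = db.connect()?;
    let mut rows = conn.query("SELECT count(*) FROM sqlite_master", ()).await?;
    while let Some(row) = rows.next().await? {
        println!("objects: {}", row.get::<i64>(0)?);
    }
    Ok(())
}
```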