diff --git a/Cargo.toml b/Cargo.toml
index 4e8a0d4..4eaa6dd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,4 +20,7 @@ members = [
     "hdfs",
     "hdfs-examples",
     "hdfs-testing",
-]
\ No newline at end of file
+]
+
+[patch.crates-io]
+object_store = { git = "https://github.com/apache/arrow-rs.git", rev = "bff6155d38e19bfe62a776731b78b435560f2c8e" }
diff --git a/hdfs/Cargo.toml b/hdfs/Cargo.toml
index 1df6333..9ad7e7d 100644
--- a/hdfs/Cargo.toml
+++ b/hdfs/Cargo.toml
@@ -46,5 +46,5 @@ chrono = { version = "0.4" }
 fs-hdfs = { version = "^0.1.11", optional = true }
 fs-hdfs3 = { version = "^0.1.11", optional = true }
 futures = "0.3"
-object_store = "0.6.1"
+object_store = { version = "0.6", features = ["cloud"] }
 tokio = { version = "1.18", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] }
diff --git a/hdfs/src/object_store/hdfs.rs b/hdfs/src/object_store/hdfs.rs
index 5a45641..6b075b5 100644
--- a/hdfs/src/object_store/hdfs.rs
+++ b/hdfs/src/object_store/hdfs.rs
@@ -21,7 +21,7 @@ use std::collections::{BTreeSet, VecDeque};
 use std::fmt::{Display, Formatter};
 use std::ops::Range;
 use std::path::PathBuf;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use async_trait::async_trait;
 use bytes::Bytes;
@@ -29,6 +29,7 @@ use chrono::{DateTime, NaiveDateTime, Utc};
 use futures::{stream::BoxStream, StreamExt, TryStreamExt};
 use hdfs::hdfs::{get_hdfs_by_full_path, FileStatus, HdfsErr, HdfsFile, HdfsFs};
 use hdfs::walkdir::HdfsWalkDir;
+use object_store::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl};
 use object_store::{
     path::{self, Path},
     Error, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result,
@@ -111,6 +112,126 @@ impl Display for HadoopFileSystem {
     }
 }
 
+struct HdfsMultiPartUpload {
+    location: Path,
+    hdfs: Arc<HdfsFs>,
+    content: Arc<Mutex<Vec<Option<Vec<u8>>>>>,
+    first_unwritten_idx: Arc<Mutex<usize>>,
+    file_created: Arc<Mutex<bool>>,
+}
+
+impl HdfsMultiPartUpload {
+    fn create_file_if_necessary(&self) -> Result<()> {
+        let mut file_created = self.file_created.lock().unwrap();
+        if !*file_created {
+            let location = HadoopFileSystem::path_to_filesystem(&self.location.clone());
+            match self.hdfs.create_with_overwrite(&location, true) {
+                Ok(_) => {
+                    *file_created = true;
+                    Ok(())
+                }
+                Err(e) => Err(to_error(e)),
+            }
+        } else {
+            Ok(())
+        }
+    }
+}
+
+#[async_trait]
+impl CloudMultiPartUploadImpl for HdfsMultiPartUpload {
+    async fn put_multipart_part(
+        &self,
+        buf: Vec<u8>,
+        part_idx: usize,
+    ) -> Result<object_store::multipart::UploadPart, std::io::Error> {
+        {
+            let mut content = self.content.lock().unwrap();
+            while content.len() <= part_idx {
+                content.push(None);
+            }
+            content[part_idx] = Some(buf);
+        }
+
+        let location = HadoopFileSystem::path_to_filesystem(&self.location.clone());
+        let first_unwritten_idx = {
+            let guard = self.first_unwritten_idx.lock().unwrap();
+            *guard
+        };
+
+        self.create_file_if_necessary()?;
+
+        // Attempt to write all contiguous sequences of parts
+        if first_unwritten_idx <= part_idx {
+            let hdfs = self.hdfs.clone();
+            let content = self.content.clone();
+            let first_unwritten_idx = self.first_unwritten_idx.clone();
+
+            maybe_spawn_blocking(move || {
+                let file = hdfs.append(&location).map_err(to_error)?;
+                let mut content = content.lock().unwrap();
+
+                let mut first_unwritten_idx = first_unwritten_idx.lock().unwrap();
+
+                // Write all contiguous parts and free up the memory
+                while let Some(buf) = content.get_mut(*first_unwritten_idx).and_then(Option::take) {
+                    file.write(buf.as_slice()).map_err(to_error)?;
+                    *first_unwritten_idx += 1;
+                }
+
+                file.close().map_err(to_error)?;
+                Ok(())
+            })
+            .await
+            .map_err(to_io_error)?;
+        }
+
+        Ok(object_store::multipart::UploadPart {
+            content_id: part_idx.to_string(),
+        })
+    }
+
+    async fn complete(
+        &self,
+        completed_parts: Vec<object_store::multipart::UploadPart>,
+    ) -> Result<(), std::io::Error> {
+        let content = self.content.lock().unwrap();
+        if content.len() != completed_parts.len() {
+            return Err(std::io::Error::new(
+                std::io::ErrorKind::InvalidData,
+                format!(
+                    "Expected {} parts, but only {} parts were received",
+                    content.len(),
+                    completed_parts.len()
+                ),
+            ));
+        }
+
+        // check first_unwritten_idx
+        let first_unwritten_idx = self.first_unwritten_idx.lock().unwrap();
+        if *first_unwritten_idx != content.len() {
+            return Err(std::io::Error::new(
+                std::io::ErrorKind::InvalidData,
+                format!(
+                    "Expected to write {} parts, but only {} parts were written",
+                    content.len(),
+                    *first_unwritten_idx
+                ),
+            ));
+        }
+
+        // Last check: make sure all parts were written, since we change it to None after writing
+        if content.iter().any(Option::is_some) {
+            return Err(std::io::Error::new(
+                std::io::ErrorKind::InvalidData,
+                "Not all parts were written",
+            ));
+        }
+
+        Ok(())
+    }
+}
+
 #[async_trait]
 impl ObjectStore for HadoopFileSystem {
     // Current implementation is very simple due to missing configs,
@@ -138,13 +259,25 @@ impl ObjectStore for HadoopFileSystem {
 
     async fn put_multipart(
         &self,
-        _location: &Path,
+        location: &Path,
     ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
-        todo!()
+        let upload = HdfsMultiPartUpload {
+            location: location.clone(),
+            hdfs: self.hdfs.clone(),
+            content: Arc::new(Mutex::new(Vec::new())),
+            first_unwritten_idx: Arc::new(Mutex::new(0)),
+            file_created: Arc::new(Mutex::new(false)),
+        };
+
+        Ok((
+            MultipartId::default(),
+            Box::new(CloudMultiPartUpload::new(upload, 8)),
+        ))
     }
 
-    async fn abort_multipart(&self, _location: &Path, _multipart_id: &MultipartId) -> Result<()> {
-        todo!()
+    async fn abort_multipart(&self, location: &Path, _multipart_id: &MultipartId) -> Result<()> {
+        // remove the file if it exists
+        self.delete(location).await
     }
 
     async fn get(&self, location: &Path) -> Result<GetResult> {
@@ -620,6 +753,30 @@ fn to_error(err: HdfsErr) -> Error {
     }
 }
 
+fn to_io_error(err: Error) -> std::io::Error {
+    match err {
+        Error::Generic { store, source } => {
+            std::io::Error::new(std::io::ErrorKind::Other, format!("{}: {}", store, source))
+        }
+        Error::NotFound { path, source } => std::io::Error::new(
+            std::io::ErrorKind::NotFound,
+            format!("{}: {}", path, source),
+        ),
+        Error::AlreadyExists { path, source } => std::io::Error::new(
+            std::io::ErrorKind::AlreadyExists,
+            format!("{}: {}", path, source),
+        ),
+        Error::InvalidPath { source } => {
+            std::io::Error::new(std::io::ErrorKind::InvalidInput, source)
+        }
+
+        _ => std::io::Error::new(
+            std::io::ErrorKind::Other,
+            format!("HadoopFileSystem: {}", err),
+        ),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
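
For reviewers, a minimal usage sketch (not part of the patch) of how the new multipart path is driven through the public ObjectStore API. The file path and any store constructor are assumptions for illustration; only the `put_multipart` / `AsyncWrite` flow below is taken from the change itself.

// Hypothetical usage sketch, not part of this diff.
use object_store::{path::Path, ObjectStore};
use tokio::io::AsyncWriteExt;

async fn write_large_object(
    store: &dyn ObjectStore,
    data: &[u8],
) -> Result<(), Box<dyn std::error::Error>> {
    // Illustrative path only.
    let location = Path::from("tmp/multipart-example.bin");

    // Returns a MultipartId and an AsyncWrite handle backed by
    // CloudMultiPartUpload wrapping HdfsMultiPartUpload.
    let (_id, mut writer) = store.put_multipart(&location).await?;

    // Writes are buffered into parts and handed to put_multipart_part,
    // which appends contiguous parts to the HDFS file.
    writer.write_all(data).await?;

    // shutdown() flushes the final part and invokes complete(), which
    // verifies that every part reached HDFS.
    writer.shutdown().await?;
    Ok(())
}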