diff --git a/compio-macros/README.md b/compio-macros/README.md index 5815faa56..cfae72f6b 100644 --- a/compio-macros/README.md +++ b/compio-macros/README.md @@ -39,3 +39,12 @@ async fn main() { println!("Hello from compio!"); } ``` + +You can customize the runtime through params: + +```rust +#[compio::main(event_interval = 4, with_proactor(capacity = 16))] +async fn main() { + println!("Hello from compio!"); +} +``` diff --git a/compio-macros/src/item_fn.rs b/compio-macros/src/item_fn.rs index 954919ebb..5e87e3922 100644 --- a/compio-macros/src/item_fn.rs +++ b/compio-macros/src/item_fn.rs @@ -1,65 +1,211 @@ use proc_macro2::TokenStream; -use quote::quote; +use quote::{ToTokens, TokenStreamExt, quote}; use syn::{ - Attribute, Expr, Lit, Meta, Signature, Visibility, parse::Parse, punctuated::Punctuated, + Attribute, Expr, ExprLit, Ident, Lit, Meta, Path, Signature, Token, Visibility, parse::Parse, + punctuated::Punctuated, }; -type AttributeArgs = Punctuated; +use crate::{retrieve_driver_mod, retrieve_runtime_mod}; +struct MetaPunctuated(Punctuated); + +impl Parse for MetaPunctuated { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(Self(Punctuated::parse_terminated(input)?)) + } +} + +pub(crate) struct BuilderMethod { + pub name: Ident, + pub value: Expr, +} + +#[derive(Default)] pub(crate) struct RawAttr { - pub inner_attrs: AttributeArgs, + pub runtime_methods: Vec, + pub proactor_methods: Vec, + pub crate_name: Option, + pub with_proactor_call: Option, } impl Parse for RawAttr { fn parse(input: syn::parse::ParseStream) -> syn::Result { - let inner_attrs = AttributeArgs::parse_terminated(input)?; - Ok(Self { inner_attrs }) + let items = Punctuated::::parse_terminated(input)?; + + let mut runtime_methods = Vec::new(); + let mut proactor_methods = Vec::new(); + let mut crate_name = None; + let mut with_proactor_call = None; + + for meta in items { + match meta { + Meta::List(list) => { + if list.path.is_ident("with_proactor") { + if !list.tokens.is_empty() { + let inner_items = syn::parse2::(list.tokens)?.0; + for inner_meta in inner_items { + if let Meta::NameValue(nv) = inner_meta { + let name = nv.path.require_ident()?.clone(); + proactor_methods.push(BuilderMethod { + name, + value: nv.value, + }); + } else { + return Err(syn::Error::new_spanned( + inner_meta, + "expected `name = value` inside `with_proactor`", + )); + } + } + with_proactor_call = Some(list.path.clone()); + } + } else { + return Err(syn::Error::new_spanned( + list.path, + "unknown key; use `name = value` for parameters or \ + `with_proactor(...)` for proactor config", + )); + } + } + Meta::NameValue(nv) => { + if nv.path.is_ident("crate") { + if let Expr::Lit(ExprLit { + lit: Lit::Str(s), .. + }) = &nv.value + { + crate_name = Some(s.parse::()?); + } else { + crate_name = Some(nv.value.into_token_stream()); + } + } else { + let name = nv.path.require_ident()?.clone(); + runtime_methods.push(BuilderMethod { + name, + value: nv.value, + }); + } + } + Meta::Path(path) => { + return Err(syn::Error::new_spanned( + path, + "expected `name = value` or `with_proactor(...)`", + )); + } + } + } + + Ok(Self { + runtime_methods, + proactor_methods, + crate_name, + with_proactor_call, + }) } } pub(crate) struct RawBodyItemFn { pub attrs: Vec, - pub args: AttributeArgs, + pub args: RawAttr, pub vis: Visibility, pub sig: Signature, pub body: TokenStream, + pub test: bool, } impl RawBodyItemFn { pub fn new(attrs: Vec, vis: Visibility, sig: Signature, body: TokenStream) -> Self { Self { attrs, - args: AttributeArgs::new(), + args: RawAttr::default(), vis, sig, body, + test: false, } } - pub fn set_args(&mut self, args: AttributeArgs) { + pub fn set_args(&mut self, args: RawAttr) { self.args = args; } - pub fn crate_name(&self) -> Option { - for attr in &self.args { - if let Meta::NameValue(name) = &attr { - let ident = name - .path - .get_ident() - .map(|ident| ident.to_string().to_lowercase()) - .unwrap_or_default(); - if ident == "crate" { - if let Expr::Lit(lit) = &name.value - && let Lit::Str(s) = &lit.lit - { - let crate_name = s.parse::().unwrap(); - return Some(quote!(#crate_name::runtime)); - } - } else { - panic!("Unsupported property {ident}"); - } + pub fn set_test(&mut self, test: bool) { + self.test = test; + } + + pub fn emit_fn_to_tokens(&self, tokens: &mut TokenStream) { + if self.test { + tokens.append_all(quote!(#[test])); + } + tokens.append_all( + self.attrs + .iter() + .filter(|a| matches!(a.style, syn::AttrStyle::Outer)), + ); + self.vis.to_tokens(tokens); + self.sig.to_tokens(tokens); + tokens.append_all(self.gen_runtime_block()); + } + + fn gen_runtime_block(&self) -> TokenStream { + let runtime_mod = match &self.args.crate_name { + Some(c) => { + let c = c.clone(); + quote!(#c::runtime) + } + None => retrieve_runtime_mod(), + }; + + let driver_mod = match &self.args.crate_name { + Some(c) => { + let c = c.clone(); + quote!(#c::driver) + } + None => retrieve_driver_mod(), + }; + + let block = &self.body; + + let mut builder = quote! { + #runtime_mod::Runtime::builder() + }; + + for method in &self.args.runtime_methods { + let name = &method.name; + let value = &method.value; + builder = quote! { + #builder.#name(#value) + }; + } + + if !self.args.proactor_methods.is_empty() { + let mut proactor_stmts: Vec = Vec::new(); + proactor_stmts.push(quote! { + let mut __compio_proactor_builder = #driver_mod::Proactor::builder(); + }); + for method in &self.args.proactor_methods { + let name = &method.name; + let value = &method.value; + proactor_stmts.push(quote! { + __compio_proactor_builder.#name(#value); + }); } + // Preserve the original token for the `with_proactor` call to make the language + // server work better. + let with_proactor_call = if let Some(path) = &self.args.with_proactor_call { + quote!(#path) + } else { + quote!(with_proactor) + }; + builder = quote! { + #builder.#with_proactor_call({ + #(#proactor_stmts)* + __compio_proactor_builder + }) + }; } - None + + quote!({ + #builder.build().expect("cannot create runtime").block_on(async move #block) + }) } } diff --git a/compio-macros/src/lib.rs b/compio-macros/src/lib.rs index ae7db9a8a..6e5923181 100644 --- a/compio-macros/src/lib.rs +++ b/compio-macros/src/lib.rs @@ -9,8 +9,6 @@ mod item_fn; mod main_fn; -mod test_fn; - use proc_macro::TokenStream; use proc_macro_crate::{FoundCrate, crate_name}; use proc_macro2::{Ident, Span}; @@ -27,8 +25,9 @@ pub fn main(args: TokenStream, item: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { - parse_macro_input!(item as test_fn::CompioTest) + parse_macro_input!(item as main_fn::CompioMain) .with_args(parse_macro_input!(args as item_fn::RawAttr)) + .with_test(true) .into_token_stream() .into() } @@ -50,3 +49,24 @@ fn retrieve_runtime_mod() -> proc_macro2::TokenStream { }, } } + +fn retrieve_driver_mod() -> proc_macro2::TokenStream { + match crate_name("compio-driver") { + Ok(FoundCrate::Itself) => quote!(crate), + Ok(FoundCrate::Name(name)) => { + let ident = Ident::new(&name, Span::call_site()); + quote!(::#ident) + } + Err(_) => match crate_name("compio") { + Ok(FoundCrate::Itself) => quote!(crate::driver), + Ok(FoundCrate::Name(name)) => { + let ident = Ident::new(&name, Span::call_site()); + quote!(::#ident::driver) + } + Err(_) => { + let ident = Ident::new("compio_driver", Span::call_site()); + quote!(::#ident) + } + }, + } +} diff --git a/compio-macros/src/main_fn.rs b/compio-macros/src/main_fn.rs index 8eb4b46fd..fae3c5688 100644 --- a/compio-macros/src/main_fn.rs +++ b/compio-macros/src/main_fn.rs @@ -1,17 +1,19 @@ use proc_macro2::TokenStream; -use quote::{ToTokens, TokenStreamExt, quote}; -use syn::{AttrStyle, Attribute, Signature, Visibility, parse::Parse}; +use quote::ToTokens; +use syn::{Attribute, Signature, Visibility, parse::Parse}; -use crate::{ - item_fn::{RawAttr, RawBodyItemFn}, - retrieve_runtime_mod, -}; +use crate::item_fn::{RawAttr, RawBodyItemFn}; pub(crate) struct CompioMain(pub RawBodyItemFn); impl CompioMain { pub fn with_args(mut self, args: RawAttr) -> Self { - self.0.set_args(args.inner_attrs); + self.0.set_args(args); + self + } + + pub fn with_test(mut self, test: bool) -> Self { + self.0.set_test(test); self } } @@ -37,18 +39,6 @@ impl Parse for CompioMain { impl ToTokens for CompioMain { fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.append_all( - self.0 - .attrs - .iter() - .filter(|a| matches!(a.style, AttrStyle::Outer)), - ); - self.0.vis.to_tokens(tokens); - self.0.sig.to_tokens(tokens); - let block = &self.0.body; - let runtime_mod = self.0.crate_name().unwrap_or_else(retrieve_runtime_mod); - tokens.append_all(quote!({ - #runtime_mod::Runtime::new().expect("cannot create runtime").block_on(async move #block) - })); + self.0.emit_fn_to_tokens(tokens); } } diff --git a/compio-macros/src/test_fn.rs b/compio-macros/src/test_fn.rs deleted file mode 100644 index 7fb17be3d..000000000 --- a/compio-macros/src/test_fn.rs +++ /dev/null @@ -1,54 +0,0 @@ -use proc_macro2::TokenStream; -use quote::{ToTokens, TokenStreamExt, quote}; -use syn::{AttrStyle, Attribute, Signature, Visibility, parse::Parse}; - -use crate::{ - item_fn::{RawAttr, RawBodyItemFn}, - retrieve_runtime_mod, -}; - -pub(crate) struct CompioTest(pub RawBodyItemFn); - -impl CompioTest { - pub fn with_args(mut self, args: RawAttr) -> Self { - self.0.set_args(args.inner_attrs); - self - } -} - -impl Parse for CompioTest { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let attrs = input.call(Attribute::parse_outer)?; - let vis: Visibility = input.parse()?; - let mut sig: Signature = input.parse()?; - let body: TokenStream = input.parse()?; - - if sig.asyncness.is_none() { - return Err(syn::Error::new_spanned( - sig.ident, - "the `async` keyword is missing from the function declaration", - )); - }; - sig.asyncness.take(); - Ok(Self(RawBodyItemFn::new(attrs, vis, sig, body))) - } -} - -impl ToTokens for CompioTest { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.append_all(quote!(#[test])); - tokens.append_all( - self.0 - .attrs - .iter() - .filter(|a| matches!(a.style, AttrStyle::Outer)), - ); - self.0.vis.to_tokens(tokens); - self.0.sig.to_tokens(tokens); - let block = &self.0.body; - let runtime_mod = self.0.crate_name().unwrap_or_else(retrieve_runtime_mod); - tokens.append_all(quote!({ - #runtime_mod::Runtime::new().expect("cannot create runtime").block_on(async move #block) - })); - } -} diff --git a/compio/tests/accept.rs b/compio/tests/accept.rs index 3c0dff23b..b09402e95 100644 --- a/compio/tests/accept.rs +++ b/compio/tests/accept.rs @@ -1,27 +1,18 @@ use compio::{ - driver::{DriverType, ProactorBuilder}, + driver::DriverType, net::{TcpListener, TcpStream}, - runtime::Runtime, }; use compio_runtime::ResumeUnwind; -#[test] -fn accept() { - let mut proactor_builder = ProactorBuilder::new(); - proactor_builder.driver_type(DriverType::Poll); - let runtime = Runtime::builder() - .with_proactor(proactor_builder) - .build() - .unwrap(); - runtime.block_on(async { - let listener = TcpListener::bind("localhost:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let task = compio_runtime::spawn(async move { - let (socket, _) = listener.accept().await.unwrap(); - socket - }); - let cli = TcpStream::connect(&addr).await.unwrap(); - let srv = task.await.resume_unwind().expect("shouldn't be cancelled"); - assert_eq!(cli.local_addr().unwrap(), srv.peer_addr().unwrap()); - }) +#[compio_macros::test(with_proactor(driver_type = DriverType::Poll))] +async fn accept() { + let listener = TcpListener::bind("localhost:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let task = compio_runtime::spawn(async move { + let (socket, _) = listener.accept().await.unwrap(); + socket + }); + let cli = TcpStream::connect(&addr).await.unwrap(); + let srv = task.await.resume_unwind().expect("shouldn't be cancelled"); + assert_eq!(cli.local_addr().unwrap(), srv.peer_addr().unwrap()); } diff --git a/compio/tests/runtime.rs b/compio/tests/runtime.rs index 91bd596ef..fe3cd6b42 100644 --- a/compio/tests/runtime.rs +++ b/compio/tests/runtime.rs @@ -6,7 +6,7 @@ use std::{net::Ipv4Addr, time::Duration}; use compio::fs::named_pipe::{ClientOptions, NamedPipeClient, NamedPipeServer, ServerOptions}; use compio::{ buf::*, - driver::{ErrorExt, ProactorBuilder}, + driver::ErrorExt, fs::File, io::{AsyncRead, AsyncReadAt, AsyncReadExt, AsyncWriteAt, AsyncWriteExt}, net::{TcpListener, TcpStream, UnixStream}, @@ -98,24 +98,16 @@ async fn drop_on_complete() { drop(file); } -#[test] -fn too_many_submissions() { - let mut proactor_builder = ProactorBuilder::new(); - proactor_builder.capacity(1).thread_pool_limit(1); - compio_runtime::Runtime::builder() - .with_proactor(proactor_builder) - .build() - .unwrap() - .block_on(async move { - let tempfile = tempfile(); - let mut file = File::create(tempfile.path()).await.unwrap(); - for _ in 0..600 { - poll_once(async { - file.write_at("hello world", 0).await.0.unwrap(); - }) - .await; - } +#[compio_macros::test(with_proactor(capacity = 1, thread_pool_limit = 1))] +async fn too_many_submissions() { + let tempfile = tempfile(); + let mut file = File::create(tempfile.path()).await.unwrap(); + for _ in 0..600 { + poll_once(async { + file.write_at("hello world", 0).await.0.unwrap(); }) + .await; + } } #[cfg(feature = "allocator_api")]