Skip to content

Much needed TLC #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ on:
branches:
- master

workflow_dispatch:

env:
CARGO_TERM_COLOR: always

Expand Down
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ harness = false
required-features = ["async-std-runtime", "testing"]

[dependencies]
async-channel = { version = "1.6", optional = true }
clap = { version = "3.2", optional = true }
async-channel = { version = "2.1", optional = true }
clap = { version = "4.4", optional = true }
futures = "0.3"
inventory = { version = "0.3", optional = true }
once_cell = "1.14"
once_cell = "1.19"
pin-project-lite = "0.2"
pyo3 = "0.20"
pyo3-asyncio-macros = { path = "pyo3-asyncio-macros", version = "=0.20.0", optional = true }
Expand All @@ -128,6 +128,6 @@ features = ["unstable"]
optional = true

[dependencies.tokio]
version = "1.13"
version = "1.36"
features = ["rt", "rt-multi-thread", "time"]
optional = true
2 changes: 1 addition & 1 deletion pyo3-asyncio-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1"
syn = { version = "1", features = ["full"] }
syn = { version = "2" }
2 changes: 1 addition & 1 deletion pyo3-asyncio-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ pub fn async_std_main(_attr: TokenStream, item: TokenStream) -> TokenStream {
#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
#[proc_macro_attribute]
pub fn tokio_main(args: TokenStream, item: TokenStream) -> TokenStream {
tokio::main(args, item, true)
tokio::main(args.into(), item.into(), true).into()
}

/// Registers an `async-std` test with the `pyo3-asyncio` test harness.
Expand Down
58 changes: 36 additions & 22 deletions pyo3-asyncio-macros/src/tokio.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::spanned::Spanned;
use syn::{parse::Parser, spanned::Spanned};

#[derive(Clone, Copy, PartialEq)]
enum RuntimeFlavor {
Expand Down Expand Up @@ -47,7 +46,7 @@ impl Configuration {
}
}

fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> {
fn set_flavor(&mut self, runtime: &syn::Lit, span: Span) -> Result<(), syn::Error> {
if self.flavor.is_some() {
return Err(syn::Error::new(span, "`flavor` set multiple times."));
}
Expand All @@ -61,7 +60,7 @@ impl Configuration {

fn set_worker_threads(
&mut self,
worker_threads: syn::Lit,
worker_threads: &syn::Lit,
span: Span,
) -> Result<(), syn::Error> {
if self.worker_threads.is_some() {
Expand Down Expand Up @@ -107,7 +106,7 @@ impl Configuration {
}
}

fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
fn parse_int(int: &syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
match int {
syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
Ok(value) => Ok(value),
Expand All @@ -123,7 +122,7 @@ fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error
}
}

fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
fn parse_string(int: &syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
match int {
syn::Lit::Str(s) => Ok(s.value()),
syn::Lit::Verbatim(s) => Ok(s.to_string()),
Expand All @@ -136,7 +135,7 @@ fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::E

fn parse_knobs(
input: syn::ItemFn,
args: syn::AttributeArgs,
args: syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>,
is_test: bool,
rt_multi_thread: bool,
) -> Result<TokenStream, syn::Error> {
Expand All @@ -156,18 +155,22 @@ fn parse_knobs(

for arg in args {
match arg {
syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) => {
syn::Meta::NameValue(namevalue) => {
let ident = namevalue.path.get_ident();
if ident.is_none() {
let msg = "Must have specified ident";
return Err(syn::Error::new_spanned(namevalue, msg));
}
let lit = match &namevalue.value {
syn::Expr::Lit(syn::ExprLit { lit, .. }) => lit,
expr => return Err(syn::Error::new_spanned(expr, "Must be a literal")),
};
match ident.unwrap().to_string().to_lowercase().as_str() {
"worker_threads" => {
config.set_worker_threads(namevalue.lit.clone(), namevalue.span())?;
config.set_worker_threads(lit, Spanned::span(lit))?;
}
"flavor" => {
config.set_flavor(namevalue.lit.clone(), namevalue.span())?;
config.set_flavor(lit, Spanned::span(lit))?;
}
"core_threads" => {
let msg = "Attribute `core_threads` is renamed to `worker_threads`";
Expand All @@ -179,7 +182,7 @@ fn parse_knobs(
}
}
}
syn::NestedMeta::Meta(syn::Meta::Path(path)) => {
syn::Meta::Path(path) => {
let ident = path.get_ident();
if ident.is_none() {
let msg = "Must have specified ident";
Expand Down Expand Up @@ -273,20 +276,31 @@ fn parse_knobs(
}
};

Ok(result.into())
Ok(result)
}

#[cfg(not(test))] // Work around for rust-lang/rust#62127
pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
let input = syn::parse_macro_input!(item as syn::ItemFn);
let args = syn::parse_macro_input!(args as syn::AttributeArgs);
let input = match syn::parse2::<syn::ItemFn>(item) {
Ok(input) => {
if input.sig.ident == "main" && !input.sig.inputs.is_empty() {
let msg = "the main function cannot accept arguments";
return syn::Error::new_spanned(&input.sig.ident, msg)
.to_compile_error()
.into();
}

if input.sig.ident == "main" && !input.sig.inputs.is_empty() {
let msg = "the main function cannot accept arguments";
return syn::Error::new_spanned(&input.sig.ident, msg)
.to_compile_error()
.into();
}
input
}
Err(e) => return e.to_compile_error().into(),
};

let args =
syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated.parse2(args);

parse_knobs(input, args, false, rt_multi_thread).unwrap_or_else(|e| e.to_compile_error().into())
match args {
Ok(args) => parse_knobs(input, args, false, rt_multi_thread)
.unwrap_or_else(|e| e.to_compile_error().into()),
Err(e) => return e.to_compile_error().into(),
}
}
19 changes: 0 additions & 19 deletions pytests/test_async_std_asyncio.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
mod common;

use std::{
rc::Rc,
sync::{Arc, Mutex},
time::Duration,
};
Expand Down Expand Up @@ -112,24 +111,6 @@ async fn test_panic() -> PyResult<()> {
}
}

#[pyo3_asyncio::async_std::test]
async fn test_local_future_into_py() -> PyResult<()> {
Python::with_gil(|py| {
let non_send_secs = Rc::new(1);

#[allow(deprecated)]
let py_future = pyo3_asyncio::async_std::local_future_into_py(py, async move {
async_std::task::sleep(Duration::from_secs(*non_send_secs)).await;
Ok(())
})?;

pyo3_asyncio::async_std::into_future(py_future)
})?
.await?;

Ok(())
}

#[pyo3_asyncio::async_std::test]
async fn test_cancel() -> PyResult<()> {
let completed = Arc::new(Mutex::new(false));
Expand Down
82 changes: 0 additions & 82 deletions pytests/tokio_asyncio/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::{
rc::Rc,
sync::{Arc, Mutex},
time::Duration,
};
Expand All @@ -9,7 +8,6 @@ use pyo3::{
types::{IntoPyDict, PyType},
wrap_pyfunction, wrap_pymodule,
};
use pyo3_asyncio::TaskLocals;

#[cfg(feature = "unstable-streams")]
use futures::{StreamExt, TryStreamExt};
Expand Down Expand Up @@ -86,33 +84,6 @@ async fn test_other_awaitables() -> PyResult<()> {
.await
}

#[pyo3_asyncio::tokio::test]
fn test_local_future_into_py(event_loop: PyObject) -> PyResult<()> {
tokio::task::LocalSet::new().block_on(pyo3_asyncio::tokio::get_runtime(), async {
Python::with_gil(|py| {
let non_send_secs = Rc::new(1);

#[allow(deprecated)]
let py_future = pyo3_asyncio::tokio::local_future_into_py_with_locals(
py,
TaskLocals::new(event_loop.as_ref(py)),
async move {
tokio::time::sleep(Duration::from_secs(*non_send_secs)).await;
Ok(())
},
)?;

pyo3_asyncio::into_future_with_locals(
&TaskLocals::new(event_loop.as_ref(py)),
py_future,
)
})?
.await?;

Ok(())
})
}

#[pyo3_asyncio::tokio::test]
async fn test_panic() -> PyResult<()> {
let fut = Python::with_gil(|py| -> PyResult<_> {
Expand Down Expand Up @@ -175,59 +146,6 @@ async fn test_cancel() -> PyResult<()> {
Ok(())
}

#[pyo3_asyncio::tokio::test]
#[allow(deprecated)]
fn test_local_cancel(event_loop: PyObject) -> PyResult<()> {
let locals = Python::with_gil(|py| -> PyResult<TaskLocals> {
Ok(TaskLocals::new(event_loop.as_ref(py)).copy_context(py)?)
})?;

tokio::task::LocalSet::new().block_on(
pyo3_asyncio::tokio::get_runtime(),
pyo3_asyncio::tokio::scope_local(locals, async {
let completed = Arc::new(Mutex::new(false));
let py_future = Python::with_gil(|py| -> PyResult<PyObject> {
let completed = Arc::clone(&completed);

#[allow(deprecated)]
Ok(pyo3_asyncio::tokio::local_future_into_py(py, async move {
tokio::time::sleep(Duration::from_secs(1)).await;
*completed.lock().unwrap() = true;
Ok(())
})?
.into())
})?;

if let Err(e) = Python::with_gil(|py| -> PyResult<_> {
py_future.as_ref(py).call_method0("cancel")?;
pyo3_asyncio::tokio::into_future(py_future.as_ref(py))
})?
.await
{
Python::with_gil(|py| -> PyResult<()> {
assert!(e.value(py).is_instance(
py.import("asyncio")?
.getattr("CancelledError")?
.downcast::<PyType>()
.unwrap()
)?);
Ok(())
})?;
} else {
panic!("expected CancelledError");
}

tokio::time::sleep(Duration::from_secs(1)).await;

if *completed.lock().unwrap() {
panic!("future still completed")
}

Ok(())
}),
)
}

/// This module is implemented in Rust.
#[pymodule]
fn test_mod(_py: Python, m: &PyModule) -> PyResult<()> {
Expand Down
Loading