Skip to content

Commit 152955a

Browse files
authored
Fix credential extraction; never try to await non-awaitable object (#345)
* Fix credential extraction * elide
1 parent f5988bb commit 152955a

File tree

5 files changed

+56
-19
lines changed

5 files changed

+56
-19
lines changed

pyo3-object_store/src/aws/credentials.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use pyo3::intern;
99
use pyo3::prelude::*;
1010

1111
use crate::aws::store::PyAmazonS3Config;
12-
use crate::credentials::{TemporaryToken, TokenCache};
12+
use crate::credentials::{is_awaitable, TemporaryToken, TokenCache};
1313

1414
/// A wrapper around an [AwsCredential] that includes an optional expiry timestamp.
1515
struct PyAwsCredential {
@@ -138,10 +138,10 @@ impl PyCredentialProviderResult {
138138

139139
impl<'py> FromPyObject<'py> for PyCredentialProviderResult {
140140
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
141-
if let Ok(credentials) = ob.extract() {
142-
Ok(Self::Sync(credentials))
143-
} else {
141+
if is_awaitable(ob)? {
144142
Ok(Self::Async(ob.clone().unbind()))
143+
} else {
144+
Ok(Self::Sync(ob.extract()?))
145145
}
146146
}
147147
}
@@ -164,8 +164,8 @@ impl PyAWSCredentialProvider {
164164
let credential = self
165165
.call()
166166
.await
167-
.map_err(|err| object_store::Error::Generic {
168-
store: "External AWS credential provider",
167+
.map_err(|err| object_store::Error::Unauthenticated {
168+
path: "External AWS credential provider".to_string(),
169169
source: Box::new(err),
170170
})?;
171171

pyo3-object_store/src/azure/credentials.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use pyo3::prelude::*;
1111
use pyo3::pybacked::PyBackedStr;
1212

1313
use crate::azure::error::Error;
14-
use crate::credentials::{TemporaryToken, TokenCache};
14+
use crate::credentials::{is_awaitable, TemporaryToken, TokenCache};
1515
use crate::PyObjectStoreError;
1616

1717
struct PyAzureAccessKey {
@@ -200,10 +200,10 @@ impl PyCredentialProviderResult {
200200

201201
impl<'py> FromPyObject<'py> for PyCredentialProviderResult {
202202
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
203-
if let Ok(credentials) = ob.extract() {
204-
Ok(Self::Sync(credentials))
205-
} else {
203+
if is_awaitable(ob)? {
206204
Ok(Self::Async(ob.clone().unbind()))
205+
} else {
206+
Ok(Self::Sync(ob.extract()?))
207207
}
208208
}
209209
}
@@ -223,8 +223,8 @@ impl PyAzureCredentialProvider {
223223
let credential = self
224224
.call()
225225
.await
226-
.map_err(|err| object_store::Error::Generic {
227-
store: "External Azure credential provider",
226+
.map_err(|err| object_store::Error::Unauthenticated {
227+
path: "External Azure credential provider".to_string(),
228228
source: Box::new(err),
229229
})?;
230230

pyo3-object_store/src/credentials.rs

+12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use chrono::Utc;
22
use chrono::{DateTime, TimeDelta};
3+
use pyo3::intern;
4+
use pyo3::prelude::*;
5+
use pyo3::types::PyTuple;
36
use std::future::Future;
47
use tokio::sync::Mutex;
58

@@ -87,3 +90,12 @@ impl<T: Clone + Send> TokenCache<T> {
8790
Ok(token)
8891
}
8992
}
93+
94+
/// Check whether a Python object is awaitable
95+
pub(crate) fn is_awaitable(ob: &Bound<PyAny>) -> PyResult<bool> {
96+
let py = ob.py();
97+
let inspect_mod = py.import(intern!(py, "inspect"))?;
98+
inspect_mod
99+
.call_method1(intern!(py, "isawaitable"), PyTuple::new(py, [ob])?)?
100+
.extract::<bool>()
101+
}

pyo3-object_store/src/gcp/credentials.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use pyo3::exceptions::PyTypeError;
88
use pyo3::intern;
99
use pyo3::prelude::*;
1010

11-
use crate::credentials::{TemporaryToken, TokenCache};
11+
use crate::credentials::{is_awaitable, TemporaryToken, TokenCache};
1212

1313
/// Ref https://github.com/apache/arrow-rs/pull/6638
1414
const DEFAULT_GCP_MIN_TTL: TimeDelta = TimeDelta::minutes(4);
@@ -114,10 +114,10 @@ impl PyCredentialProviderResult {
114114

115115
impl<'py> FromPyObject<'py> for PyCredentialProviderResult {
116116
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
117-
if let Ok(credentials) = ob.extract() {
118-
Ok(Self::Sync(credentials))
119-
} else {
117+
if is_awaitable(ob)? {
120118
Ok(Self::Async(ob.clone().unbind()))
119+
} else {
120+
Ok(Self::Sync(ob.extract()?))
121121
}
122122
}
123123
}
@@ -140,8 +140,8 @@ impl PyGcpCredentialProvider {
140140
let credential = self
141141
.call()
142142
.await
143-
.map_err(|err| object_store::Error::Generic {
144-
store: "External GCP credential provider",
143+
.map_err(|err| object_store::Error::Unauthenticated {
144+
path: "External GCP credential provider".to_string(),
145145
source: Box::new(err),
146146
})?;
147147

tests/store/test_s3.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# ruff: noqa: PGH003
22

33
import pickle
4+
from datetime import UTC, datetime
45

56
import pytest
67

78
import obstore as obs
8-
from obstore.exceptions import BaseError
9+
from obstore.exceptions import BaseError, UnauthenticatedError
910
from obstore.store import S3Store, from_url
1011

1112

@@ -97,3 +98,27 @@ def test_config_round_trip():
9798
assert store.prefix == new_store.prefix
9899
assert store.client_options == new_store.client_options
99100
assert store.retry_config == new_store.retry_config
101+
102+
103+
def test_invalid_credential_provider():
104+
"""Test that passing an invalid synchronous credential provider raises an error.
105+
106+
instead of trying to await the value.
107+
"""
108+
109+
def credential_provider():
110+
return {"access_key_id": "str", "expires_at": datetime.now(UTC)}
111+
112+
store = S3Store("bucket", credential_provider=credential_provider) # type: ignore
113+
with pytest.raises(UnauthenticatedError):
114+
obs.list(store).collect()
115+
116+
117+
@pytest.mark.asyncio
118+
async def test_invalid_credential_provider_async():
119+
async def credential_provider():
120+
return {"access_key_id": "str", "expires_at": datetime.now(UTC)}
121+
122+
store = S3Store("bucket", credential_provider=credential_provider) # type: ignore
123+
with pytest.raises(UnauthenticatedError):
124+
await obs.list(store).collect_async()

0 commit comments

Comments
 (0)