Skip to content

Commit 7787e6e

Browse files
authored
Merge pull request #3036 from ProvableHQ/fix/content-type
[Fix] Do not automatically append API version
2 parents 6c9efe3 + f503cb5 commit 7787e6e

File tree

1 file changed

+36
-49
lines changed

1 file changed

+36
-49
lines changed

ledger/query/src/query/rest.rs

Lines changed: 36 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,12 @@ use std::str::FromStr;
3333
#[derive(Clone)]
3434
pub struct RestQuery<N: Network> {
3535
base_url: http::Uri,
36-
/// `true` if the api version is already contained in the base URL.
37-
has_api_version: bool,
3836
_marker: std::marker::PhantomData<N>,
3937
}
4038

4139
impl<N: Network> From<http::Uri> for RestQuery<N> {
4240
fn from(base_url: http::Uri) -> Self {
43-
// Avoid trailing slash when checking the version
44-
let path = base_url.path().strip_suffix('/').unwrap_or(base_url.path());
45-
46-
let has_api_version = path.ends_with(Self::API_V1) || path.ends_with(Self::API_V2);
47-
48-
Self { base_url, has_api_version, _marker: Default::default() }
41+
Self { base_url, _marker: Default::default() }
4942
}
5043
}
5144

@@ -161,9 +154,6 @@ impl<N: Network> QueryTrait<N> for RestQuery<N> {
161154
}
162155

163156
impl<N: Network> RestQuery<N> {
164-
const API_V1: &str = "v1";
165-
const API_V2: &str = "v2";
166-
167157
/// Returns the transaction for the given transaction ID.
168158
pub fn get_transaction(&self, transaction_id: &N::TransactionID) -> Result<Transaction<N>> {
169159
self.get_request(&format!("transaction/{transaction_id}"))
@@ -195,15 +185,12 @@ impl<N: Network> RestQuery<N> {
195185
// This function is only called internally but check for additional sanity.
196186
ensure!(!route.starts_with('/'), "path cannot start with a slash");
197187

198-
// Add the API version if it is not already contained in the base URL.
199-
let api_version = if self.has_api_version { "" } else { "v2/" };
200-
201188
// Work around a bug in the `http` crate where empty paths will be set to '/' but other paths are not appended with a slash.
202189
// See [this issue](https://github.com/hyperium/http/issues/507).
203190
let path = if self.base_url.path().ends_with('/') {
204-
format!("{base_url}{api_version}{network}/{route}", base_url = self.base_url, network = N::SHORT_NAME)
191+
format!("{base_url}{network}/{route}", base_url = self.base_url, network = N::SHORT_NAME)
205192
} else {
206-
format!("{base_url}/{api_version}{network}/{route}", base_url = self.base_url, network = N::SHORT_NAME)
193+
format!("{base_url}/{network}/{route}", base_url = self.base_url, network = N::SHORT_NAME)
207194
};
208195

209196
Ok(path)
@@ -226,16 +213,17 @@ impl<N: Network> RestQuery<N> {
226213
if response.status().is_success() {
227214
response.body_mut().read_json().with_context(|| format!("Failed to parse JSON response from {endpoint}"))
228215
} else {
229-
let content_type = response
216+
// v2 will return the error in JSON format.
217+
let is_json = response
230218
.headers()
231-
.get("Content-Type")
232-
.ok_or_else(|| anyhow!("Endpoint return error without ContentType"))?
233-
.to_str()
234-
.with_context(|| "Endpoint returned invalid ContentType")?;
219+
.get(http::header::CONTENT_TYPE)
220+
.and_then(|ct| ct.to_str().ok())
221+
.map(|ct| ct.contains("json"))
222+
.unwrap_or(false);
235223

236224
// Convert returned error into an `anyhow::Error`.
237225
// Depending on the API version, the error is either encoded as a string or as a JSON.
238-
if content_type.contains("json") {
226+
if is_json {
239227
let error: RestError = response
240228
.body_mut()
241229
.read_json()
@@ -263,12 +251,26 @@ impl<N: Network> RestQuery<N> {
263251
if response.status().is_success() {
264252
response.json().await.with_context(|| format!("Failed to parse JSON response from {endpoint}"))
265253
} else {
266-
// Convert returned error into an `anyhow::Error`.
267-
let error: RestError = response
268-
.json()
269-
.await
270-
.with_context(|| format!("Failed to parse JSON error response from {endpoint}"))?;
271-
Err(error.parse().context(format!("Failed to fetch from {endpoint}")))
254+
// v2 will return the error in JSON format.
255+
let is_json = response
256+
.headers()
257+
.get(http::header::CONTENT_TYPE)
258+
.and_then(|ct| ct.to_str().ok())
259+
.map(|ct| ct.contains("json"))
260+
.unwrap_or(false);
261+
262+
if is_json {
263+
// Convert returned error into an `anyhow::Error`.
264+
let error: RestError = response
265+
.json()
266+
.await
267+
.with_context(|| format!("Failed to parse JSON error response from {endpoint}"))?;
268+
Err(error.parse().context(format!("Failed to fetch from {endpoint}")))
269+
} else {
270+
let error =
271+
response.text().await.with_context(|| format!("Failed to read error message {endpoint}"))?;
272+
Err(anyhow!(error).context(format!("Failed to fetch from {endpoint}")))
273+
}
272274
}
273275
}
274276
}
@@ -300,13 +302,13 @@ mod tests {
300302
let Query::REST(rest) = query else { panic!() };
301303
assert_eq!(rest.base_url.path_and_query().unwrap().to_string(), "/");
302304
assert_eq!(rest.base_url.to_string(), withslash);
303-
assert_eq!(rest.build_endpoint(route)?, format!("{noslash}/v2/testnet/{route}"));
305+
assert_eq!(rest.build_endpoint(route)?, format!("{noslash}/testnet/{route}"));
304306

305307
let query = withslash.parse::<CurrentQuery>().unwrap();
306308
let Query::REST(rest) = query else { panic!() };
307309
assert_eq!(rest.base_url.path_and_query().unwrap().to_string(), "/");
308310
assert_eq!(rest.base_url.to_string(), withslash);
309-
assert_eq!(rest.build_endpoint(route)?, format!("{noslash}/v2/testnet/{route}"));
311+
assert_eq!(rest.build_endpoint(route)?, format!("{noslash}/testnet/{route}"));
310312

311313
Ok(())
312314
}
@@ -323,34 +325,19 @@ mod tests {
323325

324326
#[test]
325327
fn test_rest_url_parse_with_suffix() -> Result<()> {
326-
let base = "http://localhost:3030/a/prefix";
328+
let base = "http://localhost:3030/a/prefix/v2";
327329
let route = "a/route";
328-
let query = base.parse::<CurrentQuery>().unwrap();
329-
let Query::REST(rest_query) = &query else { panic!() };
330-
assert!(!rest_query.has_api_version);
331330

332331
// Test without trailing slash.
332+
let query = base.parse::<CurrentQuery>().unwrap();
333333
let Query::REST(rest) = query else { panic!() };
334-
assert_eq!(rest.build_endpoint(route)?, format!("{base}/v2/testnet/{route}"));
334+
assert_eq!(rest.build_endpoint(route)?, format!("{base}/testnet/{route}"));
335335

336336
// Set again with trailing slash.
337337
let query = format!("{base}/").parse::<CurrentQuery>().unwrap();
338338
let Query::REST(rest) = query else { panic!() };
339-
assert_eq!(rest.build_endpoint(route)?, format!("{base}/v2/testnet/{route}"));
339+
assert_eq!(rest.build_endpoint(route)?, format!("{base}/testnet/{route}"));
340340

341341
Ok(())
342342
}
343-
344-
#[test]
345-
fn test_rest_url_parse_with_api_version() {
346-
let base = "http://localhost:3030/a/prefix/v1";
347-
let query = base.parse::<CurrentQuery>().unwrap();
348-
let Query::REST(rest_query) = &query else { panic!() };
349-
assert!(rest_query.has_api_version);
350-
351-
let base = "http://localhost:3030/a/prefix/v2/";
352-
let query = base.parse::<CurrentQuery>().unwrap();
353-
let Query::REST(rest_query) = &query else { panic!() };
354-
assert!(rest_query.has_api_version);
355-
}
356343
}

0 commit comments

Comments
 (0)