Skip to content

Commit f4a94aa

Browse files
committed
feat(mirrors)!: added ratelimiter (hopefully less bans)
chore!: updated User-Agent to keep track of old users chore: limited "concurrent_downloads" up to 6
1 parent cb7b949 commit f4a94aa

File tree

11 files changed

+95
-49
lines changed

11 files changed

+95
-49
lines changed

src/collector/mod.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,6 @@ pub struct CollectionUploader {
3333
pub username: String,
3434
}
3535

36-
#[derive(Clone, Debug, Deserialize)]
37-
struct CollectionDate {}
38-
39-
#[derive(Clone, Debug, Deserialize)]
40-
struct CollectionModes {}
41-
4236
#[derive(Clone, Debug, Deserialize)]
4337
#[serde(rename_all = "lowercase")]
4438
enum Gamemode {

src/config.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ pub fn init() -> Config {
9090
let contents = fs::read_to_string("config.toml").expect("config.toml doesn't exist!");
9191
let mut config = toml::from_str::<Config>(&contents).unwrap();
9292

93+
if config.user.concurrent_downloads > 6 {
94+
panic!("It's highly recommended, that you won't use more than 6 \"threads\" to download maps, otherwise you will get banned from mirrors.");
95+
}
96+
9397
let osu_path = osu::find_game().unwrap();
9498
config.osu.songs_path = format!("{}\\Songs", osu_path);
9599
config.osu.collection_path = format!("{}\\collection.db", osu_path);

src/main.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::{
77
};
88

99
use clap::Parser;
10+
use mirrors::Ratelimiter;
1011
use osu_db::CollectionList;
1112
use sanitise_file_name::sanitise;
1213
use tokio::sync::{RwLock, Semaphore};
@@ -69,6 +70,7 @@ async fn main() {
6970
);
7071

7172
let downloaded = Arc::new(AtomicI32::new(1));
73+
let rate_limiter = Arc::new(Ratelimiter::default());
7274
let semaphore = Arc::new(Semaphore::new(CONFIG.user.concurrent_downloads));
7375

7476
for beatmapset in remote_collection_info.beatmapsets {
@@ -105,11 +107,13 @@ async fn main() {
105107
let downloaded = Arc::clone(&downloaded);
106108
let remote_collection_beatmaps = Arc::clone(&remote_collection_beatmaps);
107109
let semaphore = Arc::clone(&semaphore);
110+
let rate_limiter = Arc::clone(&rate_limiter);
108111

109112
tokio::task::spawn(async move {
110113
let _permit = semaphore.acquire().await.unwrap();
114+
let _rate_limiter: &Ratelimiter = &rate_limiter;
111115

112-
match mirror.get_file(beatmapset.id).await {
116+
match mirror.get_file(beatmapset.id, _rate_limiter).await {
113117
Ok(bytes) => {
114118
let beatmapset_entity = &remote_collection_beatmaps
115119
.beatmapsets
@@ -143,6 +147,8 @@ async fn main() {
143147
.find(|b| b.checksum == beatmap.checksum)
144148
.unwrap();
145149

150+
info!("requests left: {}", _rate_limiter.info.read().await.remaining);
151+
146152
info!(
147153
"({}/{}) {} - {} [{}]",
148154
downloaded.load(Ordering::SeqCst),

src/mirrors/beatconnect.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use serde::Deserialize;
22

3-
use super::Mirror;
3+
use super::{Mirror, Ratelimiter};
44

55
#[derive(Deserialize)]
66
pub struct Beatconnect;
@@ -15,15 +15,19 @@ impl Mirror for Beatconnect {
1515
"https://beatconnect.io/b"
1616
}
1717

18-
async fn get_file(&self, id: i32) -> Result<Vec<u8>, String> {
18+
async fn get_file(&self, id: i32, rate_limit: &Ratelimiter) -> Result<Vec<u8>, String> {
19+
rate_limit.wait_if_needed().await;
20+
1921
let client = reqwest::Client::new();
2022
let response = client
2123
.get(format!("{}/{}", self.get_base_url(), id))
22-
.header("User-Agent", "shockpast/osu-collector-cli: 1.0.0")
24+
.header("User-Agent", "shockpast/ecstasy: 1.1.2")
2325
.send()
2426
.await
2527
.unwrap();
2628

29+
rate_limit.update_rate_limit(response.headers()).await;
30+
2731
let content_type = response
2832
.headers()
2933
.get(reqwest::header::CONTENT_TYPE)

src/mirrors/catboy.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use serde::Deserialize;
22

3-
use super::Mirror;
3+
use super::{Mirror, Ratelimiter};
44

55
#[derive(Debug, Clone, Deserialize)]
66
struct ErrorResponse {
@@ -20,15 +20,19 @@ impl Mirror for Catboy {
2020
"https://catboy.best/d"
2121
}
2222

23-
async fn get_file(&self, id: i32) -> Result<Vec<u8>, String> {
23+
async fn get_file(&self, id: i32, rate_limit: &Ratelimiter) -> Result<Vec<u8>, String> {
24+
rate_limit.wait_if_needed().await;
25+
2426
let client = reqwest::Client::new();
2527
let response = client
2628
.get(format!("{}/{}", self.get_base_url(), id))
27-
.header("User-Agent", "shockpast/osu-collector-cli: 1.0.0")
29+
.header("User-Agent", "shockpast/ecstasy: 1.1.2")
2830
.send()
2931
.await
3032
.unwrap();
3133

34+
rate_limit.update_rate_limit(response.headers()).await;
35+
3236
let content_type = response
3337
.headers()
3438
.get(reqwest::header::CONTENT_TYPE)

src/mirrors/mod.rs

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,52 @@
1+
use std::{sync::Arc, time::{Duration, Instant}};
2+
3+
use reqwest::header::HeaderMap;
4+
use tokio::sync::RwLock;
5+
use tracing::warn;
6+
17
pub mod beatconnect;
28
pub mod catboy;
39
pub mod nerinyan;
410
pub mod osudirect;
511
pub mod sayobot;
612

13+
#[derive(Default)]
14+
pub struct RatelimitInfo {
15+
pub remaining: u32,
16+
pub reset_at: Option<Instant>,
17+
}
18+
19+
#[derive(Default)]
20+
pub struct Ratelimiter {
21+
pub info: Arc<RwLock<RatelimitInfo>>,
22+
}
23+
24+
impl Ratelimiter {
25+
async fn wait_if_needed(&self) {
26+
if let Some(reset_at) = self.info.read().await.reset_at {
27+
if reset_at > Instant::now() {
28+
let wait_duration = reset_at.duration_since(Instant::now());
29+
warn!("You've hit an rate-limit, chill out, and wait for 60 seconds.");
30+
31+
tokio::time::sleep(wait_duration).await;
32+
}
33+
}
34+
}
35+
36+
async fn update_rate_limit(&self, headers: &HeaderMap) {
37+
if let Some(remaining) = headers.get("x-ratelimit-remaining").and_then(|v| v.to_str().ok()) {
38+
self.info.write().await.remaining = remaining.parse().unwrap_or(0);
39+
}
40+
41+
if self.info.read().await.remaining <= 1 {
42+
self.info.write().await.reset_at = Some(Instant::now() + Duration::from_secs(70));
43+
}
44+
}
45+
}
46+
747
#[async_trait::async_trait]
848
pub trait Mirror {
949
fn get_name(&self) -> &'static str;
1050
fn get_base_url(&self) -> &'static str;
11-
async fn get_file(&self, id: i32) -> Result<Vec<u8>, String>;
12-
}
51+
async fn get_file(&self, id: i32, rate_limiter: &Ratelimiter) -> Result<Vec<u8>, String>;
52+
}

src/mirrors/nerinyan.rs

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
use reqwest::StatusCode;
21
use serde::Deserialize;
32

4-
use super::Mirror;
3+
use super::{Mirror, Ratelimiter};
54

65
#[derive(Deserialize)]
76
pub struct Nerinyan;
@@ -16,43 +15,28 @@ impl Mirror for Nerinyan {
1615
"https://api.nerinyan.moe/d"
1716
}
1817

19-
async fn get_file(&self, id: i32) -> Result<Vec<u8>, String> {
18+
async fn get_file(&self, id: i32, rate_limit: &Ratelimiter) -> Result<Vec<u8>, String> {
19+
rate_limit.wait_if_needed().await;
20+
2021
let client = reqwest::Client::new();
2122
let response = client
2223
.get(format!("{}/{}", self.get_base_url(), id))
23-
.header("User-Agent", "shockpast/osu-collector-cli: 1.0.0")
24+
.header("User-Agent", "shockpast/ecstasy: 1.1.2")
2425
.send()
2526
.await
2627
.unwrap();
2728

29+
rate_limit.update_rate_limit(response.headers()).await;
30+
2831
let content_type = response
2932
.headers()
3033
.get(reqwest::header::CONTENT_TYPE)
3134
.and_then(|v| v.to_str().ok())
3235
.map(|s| s.to_string())
3336
.unwrap_or_default();
34-
let status_code = response.status();
3537

3638
let bytes = response.bytes().await.map_err(|e| e.to_string())?;
3739

38-
if status_code.is_client_error() {
39-
match status_code {
40-
StatusCode::FORBIDDEN => {
41-
return Err(format!("{} possibly banned us.", self.get_name()));
42-
}
43-
_ => todo!(),
44-
};
45-
}
46-
47-
if status_code.is_server_error() {
48-
match status_code {
49-
StatusCode::BAD_GATEWAY => {
50-
panic!("{} is down, consider using other mirror.", self.get_name());
51-
}
52-
_ => todo!(),
53-
};
54-
}
55-
5640
if content_type.contains("application/json") {
5741
if let Ok(json) = serde_json::from_slice::<serde_json::Value>(&bytes) {
5842
return Err(json.to_string());

src/mirrors/osudirect.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use serde::Deserialize;
22

3-
use super::Mirror;
3+
use super::{Mirror, Ratelimiter};
44

55
#[derive(Debug, Clone, Deserialize)]
66
struct ErrorResponse {
@@ -20,15 +20,19 @@ impl Mirror for OsuDirect {
2020
"https://osu.direct/api/d"
2121
}
2222

23-
async fn get_file(&self, id: i32) -> Result<Vec<u8>, String> {
23+
async fn get_file(&self, id: i32, rate_limit: &Ratelimiter) -> Result<Vec<u8>, String> {
24+
rate_limit.wait_if_needed().await;
25+
2426
let client = reqwest::Client::new();
2527
let response = client
2628
.get(format!("{}/{}", self.get_base_url(), id))
27-
.header("User-Agent", "shockpast/osu-collector-cli: 1.0.0")
29+
.header("User-Agent", "shockpast/ecstasy: 1.1.2")
2830
.send()
2931
.await
3032
.unwrap();
3133

34+
rate_limit.update_rate_limit(response.headers()).await;
35+
3236
let content_type = response
3337
.headers()
3438
.get(reqwest::header::CONTENT_TYPE)

src/mirrors/sayobot.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use serde::Deserialize;
22

3-
use super::Mirror;
3+
use super::{Mirror, Ratelimiter};
44

55
#[derive(Debug, Clone, Deserialize)]
66
struct ErrorResponse {
@@ -20,15 +20,19 @@ impl Mirror for Sayobot {
2020
"https://txy1.sayobot.cn/beatmaps/download/full"
2121
}
2222

23-
async fn get_file(&self, id: i32) -> Result<Vec<u8>, String> {
23+
async fn get_file(&self, id: i32, rate_limit: &Ratelimiter) -> Result<Vec<u8>, String> {
24+
rate_limit.wait_if_needed().await;
25+
2426
let client = reqwest::Client::new();
2527
let response = client
26-
.get(format!("{}/{}?server=auto", self.get_base_url(), id))
27-
.header("User-Agent", "shockpast/osu-collector-cli: 1.0.0")
28+
.get(format!("{}/{}", self.get_base_url(), id))
29+
.header("User-Agent", "shockpast/ecstasy: 1.1.2")
2830
.send()
2931
.await
3032
.unwrap();
3133

34+
rate_limit.update_rate_limit(response.headers()).await;
35+
3236
let content_type = response
3337
.headers()
3438
.get(reqwest::header::CONTENT_TYPE)

src/utilities/osu.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pub fn find_game() -> Result<String, std::io::Error> {
4343
}
4444

4545
#[cfg(target_os = "linux")]
46-
pub async fn find_game() -> Option<&'static str> {
47-
todo!("Linux is currently not supported, since we don't have any experience with it.")
46+
pub async fn find_game() -> Result<String, std::io::Error> {
47+
let path = std::env::var("OSU_FOLDER").expect("'OSU_FOLDER' export is not defined (e.g.: '$HOME\\osu')");
48+
Ok(path)
4849
}

0 commit comments

Comments
 (0)