Skip to content

Commit 7bf4fa9

Browse files
authored
feat: allow any spl paid token flag (#175)
* feat: allow any spl paid token flag * move spl paid token checks into a function, rebase, use mock rpc client * switch to enum in allowed_spl_paid_tokens * remove unused option from demo config * rename params in `with`, add test for atas init when all payment tokens are allowed * use mock client, make sure that signer exists in test initialize atas * rename token config to make it more generic
1 parent 088e93d commit 7bf4fa9

File tree

10 files changed

+230
-68
lines changed

10 files changed

+230
-68
lines changed

crates/lib/src/admin/token_util.rs

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use {crate::cache::CacheUtil, crate::state::get_config};
2222

2323
#[cfg(test)]
2424
use {
25-
crate::tests::config_mock::mock_state::get_config,
25+
crate::config::SplTokenConfig, crate::tests::config_mock::mock_state::get_config,
2626
crate::tests::redis_cache_mock::MockCacheUtil as CacheUtil,
2727
};
2828

@@ -289,7 +289,9 @@ mod tests {
289289
async fn test_find_missing_atas_no_spl_tokens() {
290290
let _m = ConfigMockBuilder::new()
291291
.with_validation(
292-
ValidationConfigBuilder::new().with_allowed_spl_paid_tokens(vec![]).build(),
292+
ValidationConfigBuilder::new()
293+
.with_allowed_spl_paid_tokens(SplTokenConfig::Allowlist(vec![]))
294+
.build(),
293295
)
294296
.build_and_setup();
295297

@@ -308,9 +310,9 @@ mod tests {
308310
let _m = ConfigMockBuilder::new()
309311
.with_validation(
310312
ValidationConfigBuilder::new()
311-
.with_allowed_spl_paid_tokens(
313+
.with_allowed_spl_paid_tokens(SplTokenConfig::Allowlist(
312314
allowed_spl_tokens.iter().map(|p| p.to_string()).collect(),
313-
)
315+
))
314316
.build(),
315317
)
316318
.build_and_setup();
@@ -397,4 +399,19 @@ mod tests {
397399
}
398400
}
399401
}
402+
403+
#[tokio::test]
404+
async fn test_initialize_atas_when_all_tokens_are_allowed() {
405+
let _m = ConfigMockBuilder::new()
406+
.with_allowed_spl_paid_tokens(SplTokenConfig::All)
407+
.build_and_setup();
408+
409+
let _ = setup_or_get_test_signer();
410+
411+
let rpc_client = RpcMockBuilder::new().build();
412+
413+
let result = initialize_atas(&rpc_client, None, None, None, None).await;
414+
415+
assert!(result.is_ok(), "Expected atas init to succeed");
416+
}
400417
}

crates/lib/src/config.rs

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,61 @@ impl Default for FeePayerBalanceMetricsConfig {
5858
}
5959
}
6060

61+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62+
pub enum SplTokenConfig {
63+
All,
64+
#[serde(untagged)]
65+
Allowlist(Vec<String>),
66+
}
67+
68+
impl Default for SplTokenConfig {
69+
fn default() -> Self {
70+
SplTokenConfig::Allowlist(vec![])
71+
}
72+
}
73+
74+
impl<'a> IntoIterator for &'a SplTokenConfig {
75+
type Item = &'a String;
76+
type IntoIter = std::slice::Iter<'a, String>;
77+
78+
fn into_iter(self) -> Self::IntoIter {
79+
match self {
80+
SplTokenConfig::All => [].iter(),
81+
SplTokenConfig::Allowlist(tokens) => tokens.iter(),
82+
}
83+
}
84+
}
85+
86+
impl SplTokenConfig {
87+
pub fn has_token(&self, token: &str) -> bool {
88+
match self {
89+
SplTokenConfig::All => true,
90+
SplTokenConfig::Allowlist(tokens) => tokens.iter().any(|s| s == token),
91+
}
92+
}
93+
94+
pub fn has_tokens(&self) -> bool {
95+
match self {
96+
SplTokenConfig::All => true,
97+
SplTokenConfig::Allowlist(tokens) => !tokens.is_empty(),
98+
}
99+
}
100+
101+
pub fn as_slice(&self) -> &[String] {
102+
match self {
103+
SplTokenConfig::All => &[],
104+
SplTokenConfig::Allowlist(v) => v.as_slice(),
105+
}
106+
}
107+
}
108+
61109
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
62110
pub struct ValidationConfig {
63111
pub max_allowed_lamports: u64,
64112
pub max_signatures: u64,
65113
pub allowed_programs: Vec<String>,
66114
pub allowed_tokens: Vec<String>,
67-
pub allowed_spl_paid_tokens: Vec<String>,
115+
pub allowed_spl_paid_tokens: SplTokenConfig,
68116
pub disallowed_accounts: Vec<String>,
69117
pub price_source: PriceSource,
70118
#[serde(default)] // Default for backward compatibility
@@ -81,7 +129,7 @@ impl ValidationConfig {
81129
}
82130

83131
pub fn supports_token(&self, token: &str) -> bool {
84-
self.allowed_spl_paid_tokens.iter().any(|s| s == token)
132+
self.allowed_spl_paid_tokens.has_token(token)
85133
}
86134
}
87135

@@ -378,7 +426,7 @@ mod tests {
378426
let config = ConfigBuilder::new()
379427
.with_programs(vec!["program1", "program2"])
380428
.with_tokens(vec!["token1", "token2"])
381-
.with_spl_paid_tokens(vec!["token3"])
429+
.with_spl_paid_tokens(SplTokenConfig::Allowlist(vec!["token3".to_string()]))
382430
.with_disallowed_accounts(vec!["account1"])
383431
.build_config()
384432
.unwrap();
@@ -387,7 +435,10 @@ mod tests {
387435
assert_eq!(config.validation.max_signatures, 10);
388436
assert_eq!(config.validation.allowed_programs, vec!["program1", "program2"]);
389437
assert_eq!(config.validation.allowed_tokens, vec!["token1", "token2"]);
390-
assert_eq!(config.validation.allowed_spl_paid_tokens, vec!["token3"]);
438+
assert_eq!(
439+
config.validation.allowed_spl_paid_tokens,
440+
SplTokenConfig::Allowlist(vec!["token3".to_string()])
441+
);
391442
assert_eq!(config.validation.disallowed_accounts, vec!["account1"]);
392443
assert_eq!(config.validation.price_source, PriceSource::Jupiter);
393444
assert_eq!(config.kora.rate_limit, 100);
@@ -400,7 +451,7 @@ mod tests {
400451
let config = ConfigBuilder::new()
401452
.with_programs(vec!["program1", "program2"])
402453
.with_tokens(vec!["token1", "token2"])
403-
.with_spl_paid_tokens(vec!["token3"])
454+
.with_spl_paid_tokens(SplTokenConfig::Allowlist(vec!["token3".to_string()]))
404455
.with_disallowed_accounts(vec!["account1"])
405456
.with_enabled_methods(&[
406457
("liveness", true),
@@ -441,6 +492,14 @@ mod tests {
441492
assert!(result.is_err());
442493
}
443494

495+
#[test]
496+
fn test_parse_spl_payment_config() {
497+
let config =
498+
ConfigBuilder::new().with_spl_paid_tokens(SplTokenConfig::All).build_config().unwrap();
499+
500+
assert_eq!(config.validation.allowed_spl_paid_tokens, SplTokenConfig::All);
501+
}
502+
444503
#[test]
445504
fn test_parse_margin_price_config() {
446505
let config = ConfigBuilder::new().with_margin_price(0.1).build_config().unwrap();

crates/lib/src/oracle/jupiter.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,4 +246,43 @@ mod tests {
246246
assert_eq!(price.price, 1.0);
247247
assert_eq!(price.source, PriceSource::Jupiter);
248248
}
249+
250+
#[tokio::test]
251+
async fn test_jupiter_price_fetch_when_no_price_data() {
252+
// No API key
253+
{
254+
let mut api_key_guard = GLOBAL_JUPITER_API_KEY.write();
255+
*api_key_guard = None;
256+
}
257+
258+
let mock_response = r#"{
259+
"So11111111111111111111111111111111111111112": {
260+
"usdPrice": 100.0,
261+
"blockId": 12345,
262+
"decimals": 9,
263+
"priceChange24h": 2.5
264+
}
265+
}"#;
266+
let mut server = Server::new_async().await;
267+
let _m = server
268+
.mock("GET", "/price/v3")
269+
.match_query(Matcher::Any)
270+
.with_status(200)
271+
.with_header("content-type", "application/json")
272+
.with_body(mock_response)
273+
.create();
274+
275+
let client = Client::new();
276+
// Test without API key - should use lite API
277+
let mut oracle = JupiterPriceOracle::new();
278+
oracle.lite_api_url = format!("{}/price/v3", server.url());
279+
280+
let result = oracle.get_price(&client, "JUPyiwrYJFskUPiHa7hkeR8VUtAeFoSYbKedZNsDvCN").await;
281+
282+
assert!(result.is_err());
283+
assert_eq!(
284+
result.err(),
285+
Some(KoraError::RpcError("No price data from Jupiter".to_string()))
286+
)
287+
}
249288
}

crates/lib/src/rpc_server/method/get_config.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ mod tests {
8080
response.validation_config.allowed_tokens[0],
8181
"4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU"
8282
); // USDC devnet
83-
assert_eq!(response.validation_config.allowed_spl_paid_tokens.len(), 1);
83+
assert_eq!(response.validation_config.allowed_spl_paid_tokens.as_slice().len(), 1);
8484
assert_eq!(
85-
response.validation_config.allowed_spl_paid_tokens[0],
85+
response.validation_config.allowed_spl_paid_tokens.as_slice()[0],
8686
"4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU"
8787
); // USDC devnet
8888
assert_eq!(response.validation_config.disallowed_accounts.len(), 0);

crates/lib/src/tests/config_mock.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use crate::{
22
config::{
33
AuthConfig, CacheConfig, Config, EnabledMethods, FeePayerBalanceMetricsConfig,
4-
FeePayerPolicy, KoraConfig, MetricsConfig, Token2022Config, ValidationConfig,
4+
FeePayerPolicy, KoraConfig, MetricsConfig, SplTokenConfig, Token2022Config,
5+
ValidationConfig,
56
},
67
fee::price::PriceConfig,
78
oracle::PriceSource,
@@ -85,9 +86,9 @@ impl ConfigMockBuilder {
8586
allowed_tokens: vec![
8687
"4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU".parse().unwrap(), // USDC devnet
8788
],
88-
allowed_spl_paid_tokens: vec![
89+
allowed_spl_paid_tokens: SplTokenConfig::Allowlist(vec![
8990
"4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU".parse().unwrap(), // USDC devnet
90-
],
91+
]),
9192
disallowed_accounts: vec![],
9293
price_source: PriceSource::Mock,
9394
fee_payer_policy: FeePayerPolicy::default(),
@@ -160,6 +161,11 @@ impl ConfigMockBuilder {
160161
self
161162
}
162163

164+
pub fn with_allowed_spl_paid_tokens(mut self, spl_payment_config: SplTokenConfig) -> Self {
165+
self.config.validation.allowed_spl_paid_tokens = spl_payment_config;
166+
self
167+
}
168+
163169
pub fn with_payment_address(mut self, payment_address: Option<String>) -> Self {
164170
self.config.kora.payment_address = payment_address;
165171
self
@@ -232,7 +238,7 @@ impl ValidationConfigBuilder {
232238
max_signatures: 10,
233239
allowed_programs: vec![],
234240
allowed_tokens: vec![],
235-
allowed_spl_paid_tokens: vec![],
241+
allowed_spl_paid_tokens: SplTokenConfig::Allowlist(vec![]),
236242
disallowed_accounts: vec![],
237243
price_source: PriceSource::Mock,
238244
fee_payer_policy: FeePayerPolicy::default(),
@@ -261,8 +267,8 @@ impl ValidationConfigBuilder {
261267
self
262268
}
263269

264-
pub fn with_allowed_spl_paid_tokens(mut self, tokens: Vec<String>) -> Self {
265-
self.config.allowed_spl_paid_tokens = tokens;
270+
pub fn with_allowed_spl_paid_tokens(mut self, spl_payment_config: SplTokenConfig) -> Self {
271+
self.config.allowed_spl_paid_tokens = spl_payment_config;
266272
self
267273
}
268274

crates/lib/src/tests/toml_mock.rs

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use std::fs;
22
use tempfile::NamedTempFile;
33

4-
use crate::{config::Config, error::KoraError};
4+
use crate::{
5+
config::{Config, SplTokenConfig},
6+
error::KoraError,
7+
};
58

69
/// TOML-specific configuration builder for testing TOML parsing and serialization
710
///
@@ -24,7 +27,7 @@ struct ValidationSection {
2427
max_signatures: u64,
2528
allowed_programs: Vec<String>,
2629
allowed_tokens: Vec<String>,
27-
allowed_spl_paid_tokens: Vec<String>,
30+
allowed_spl_paid_tokens: SplTokenConfig,
2831
disallowed_accounts: Vec<String>,
2932
price_source: String,
3033
price_config: Option<String>,
@@ -44,7 +47,7 @@ impl Default for ValidationSection {
4447
max_signatures: 10,
4548
allowed_programs: vec!["program1".to_string()],
4649
allowed_tokens: vec!["token1".to_string()],
47-
allowed_spl_paid_tokens: vec!["token2".to_string()],
50+
allowed_spl_paid_tokens: SplTokenConfig::Allowlist(vec!["token2".to_string()]),
4851
disallowed_accounts: vec![],
4952
price_source: "Jupiter".to_string(),
5053
price_config: None,
@@ -74,8 +77,8 @@ impl ConfigBuilder {
7477
self
7578
}
7679

77-
pub fn with_spl_paid_tokens(mut self, tokens: Vec<&str>) -> Self {
78-
self.validation.allowed_spl_paid_tokens = tokens.iter().map(|s| s.to_string()).collect();
80+
pub fn with_spl_paid_tokens(mut self, spl_payment_config: SplTokenConfig) -> Self {
81+
self.validation.allowed_spl_paid_tokens = spl_payment_config;
7982
self
8083
}
8184

@@ -170,13 +173,13 @@ impl ConfigBuilder {
170173
.collect::<Vec<_>>()
171174
.join(", ");
172175

173-
let spl_tokens_list = self
174-
.validation
175-
.allowed_spl_paid_tokens
176-
.iter()
177-
.map(|t| format!("\"{t}\""))
178-
.collect::<Vec<_>>()
179-
.join(", ");
176+
let spl_tokens_config = match self.validation.allowed_spl_paid_tokens {
177+
SplTokenConfig::Allowlist(ref tokens) => format!(
178+
"[{}]",
179+
tokens.iter().map(|t| format!("\"{t}\"")).collect::<Vec<_>>().join(", ")
180+
),
181+
SplTokenConfig::All => format!("\"{}\"", "All"),
182+
};
180183

181184
let disallowed_list = if self.validation.disallowed_accounts.is_empty() {
182185
"[]".to_string()
@@ -198,14 +201,14 @@ impl ConfigBuilder {
198201
max_signatures = {}\n\
199202
allowed_programs = [{}]\n\
200203
allowed_tokens = [{}]\n\
201-
allowed_spl_paid_tokens = [{}]\n\
204+
allowed_spl_paid_tokens = {}\n\
202205
disallowed_accounts = {}\n\
203206
price_source = \"{}\"\n\n",
204207
self.validation.max_allowed_lamports,
205208
self.validation.max_signatures,
206209
programs_list,
207210
tokens_list,
208-
spl_tokens_list,
211+
spl_tokens_config,
209212
disallowed_list,
210213
self.validation.price_source
211214
);

crates/lib/src/token/token.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ impl TokenUtil {
326326
(token_state.mint(), actual_amount)
327327
};
328328

329-
if !config.validation.allowed_spl_paid_tokens.contains(&mint_address.to_string()) {
329+
if !config.validation.supports_token(&mint_address.to_string()) {
330330
return Ok(false);
331331
}
332332

crates/lib/src/transaction/versioned_transaction.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ impl LookupTableUtil {
377377
#[cfg(test)]
378378
mod tests {
379379
use crate::{
380+
config::SplTokenConfig,
380381
tests::{
381382
common::RpcMockBuilder, config_mock::mock_state::setup_config_mock,
382383
toml_mock::ConfigBuilder,
@@ -403,7 +404,7 @@ mod tests {
403404
ConfigBuilder::new()
404405
.with_programs(vec![])
405406
.with_tokens(vec![])
406-
.with_spl_paid_tokens(vec![])
407+
.with_spl_paid_tokens(SplTokenConfig::Allowlist(vec![]))
407408
.with_free_price()
408409
.with_cache_config(None, false, 60, 30) // Disable cache for tests
409410
.build_config()

0 commit comments

Comments
 (0)