Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 252 additions & 29 deletions crates/lib/src/fee/fee.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
use std::str::FromStr;

use crate::{
constant::{ESTIMATED_LAMPORTS_FOR_PAYMENT_INSTRUCTION, LAMPORTS_PER_SIGNATURE},
error::KoraError,
token::token::TokenType,
fee::price::PriceModel,
oracle::PriceSource,
token::{
spl_token_2022::Token2022Mint,
token::{TokenType, TokenUtil},
TokenState,
},
transaction::{
ParsedSPLInstructionData, ParsedSPLInstructionType, ParsedSystemInstructionData,
ParsedSystemInstructionType, VersionedTransactionResolved,
Expand All @@ -23,6 +31,17 @@ use solana_sdk::{pubkey::Pubkey, rent::Rent};
use spl_associated_token_account::get_associated_token_address;
use spl_token::state::Account as SplTokenAccountState;

#[derive(Debug, Clone)]
pub struct TotalFeeCalculation {
pub total_fee_lamports: u64,
pub base_fee: u64,
pub account_creation_fee: u64,
pub kora_signature_fee: u64,
pub fee_payer_outflow: u64,
pub payment_instruction_fee: u64,
pub transfer_fee_amount: u64,
}

pub struct FeeConfigUtil {}

impl FeeConfigUtil {
Expand Down Expand Up @@ -115,12 +134,43 @@ impl FeeConfigUtil {
Ok(total_lamports)
}

/// Helper function to check if a token transfer instruction is a payment to Kora
/// Returns Some(token_account_data) if it's a payment, None otherwise
async fn get_payment_instruction_info(
rpc_client: &RpcClient,
destination_address: &Pubkey,
payment_destination: &Pubkey,
skip_missing_accounts: bool,
) -> Result<Option<Box<dyn TokenState + Send + Sync>>, KoraError> {
// Get destination account - handle missing accounts based on skip_missing_accounts
let destination_account =
match CacheUtil::get_account(rpc_client, destination_address, false).await {
Ok(account) => account,
Err(_) if skip_missing_accounts => {
return Ok(None);
}
Err(e) => {
return Err(e);
}
};

let token_program = TokenType::get_token_program_from_owner(&destination_account.owner)?;
let token_account = token_program.unpack_token_account(&destination_account.data)?;

// Check if this is a payment to Kora
if token_account.owner() == *payment_destination {
Ok(Some(token_account))
} else {
Ok(None)
}
}

async fn has_payment_instruction(
resolved_transaction: &mut VersionedTransactionResolved,
rpc_client: &RpcClient,
fee_payer: &Pubkey,
) -> Result<u64, KoraError> {
let payment_destination = &get_config()?.kora.get_payment_address(fee_payer)?;
let payment_destination = get_config()?.kora.get_payment_address(fee_payer)?;

for instruction in resolved_transaction
.get_or_parse_spl_instructions()?
Expand All @@ -130,16 +180,15 @@ impl FeeConfigUtil {
if let ParsedSPLInstructionData::SplTokenTransfer { destination_address, .. } =
instruction
{
let destination_account =
CacheUtil::get_account(rpc_client, destination_address, false).await?;

let token_program =
TokenType::get_token_program_from_owner(&destination_account.owner)?;

let token_account =
token_program.unpack_token_account(&destination_account.data)?;

if token_account.owner() == *payment_destination {
if Self::get_payment_instruction_info(
rpc_client,
destination_address,
&payment_destination,
false, // Don't skip missing accounts for has_payment_instruction
)
.await?
.is_some()
{
return Ok(0);
}
}
Expand All @@ -149,12 +198,79 @@ impl FeeConfigUtil {
Ok(ESTIMATED_LAMPORTS_FOR_PAYMENT_INSTRUCTION)
}

pub async fn estimate_transaction_fee(
/// Calculate transfer fees for token transfers in the transaction
async fn calculate_transfer_fees(
rpc_client: &RpcClient,
transaction: &mut VersionedTransactionResolved,
fee_payer: &Pubkey,
is_payment_required: bool,
) -> Result<u64, KoraError> {
let config = get_config()?;
let payment_destination = config.kora.get_payment_address(fee_payer)?;

let parsed_spl_instructions = transaction.get_or_parse_spl_instructions()?;

for instruction in parsed_spl_instructions
.get(&ParsedSPLInstructionType::SplTokenTransfer)
.unwrap_or(&vec![])
{
if let ParsedSPLInstructionData::SplTokenTransfer {
mint,
amount,
is_2022,
destination_address,
..
} = instruction
{
// Check if this is a payment to Kora
// Skip if destination account doesn't exist (not a payment to existing Kora account)
if Self::get_payment_instruction_info(
rpc_client,
destination_address,
&payment_destination,
true, // Skip missing accounts for transfer fee calculation
)
.await?
.is_none()
{
continue;
}

if let Some(mint_pubkey) = mint {
// Get mint account to calculate transfer fees
let mint_account =
CacheUtil::get_account(rpc_client, mint_pubkey, true).await?;

let token_program =
TokenType::get_token_program_from_owner(&mint_account.owner)?;
let mint_state = token_program.unpack_mint(mint_pubkey, &mint_account.data)?;

if *is_2022 {
// For Token2022, check for transfer fees
if let Some(token2022_mint) =
mint_state.as_any().downcast_ref::<Token2022Mint>()
{
let current_epoch = rpc_client.get_epoch_info().await?.epoch;

if let Some(fee_amount) =
token2022_mint.calculate_transfer_fee(*amount, current_epoch)
{
return Ok(fee_amount);
}
}
}
}
}
}

Ok(0)
}

async fn estimate_transaction_fee(
rpc_client: &RpcClient,
transaction: &mut VersionedTransactionResolved,
fee_payer: &Pubkey,
is_payment_required: bool,
) -> Result<TotalFeeCalculation, KoraError> {
// Get base transaction fee using resolved transaction to handle lookup tables
let base_fee =
TransactionFeeUtil::get_estimate_fee_resolved(rpc_client, transaction).await?;
Expand Down Expand Up @@ -185,11 +301,106 @@ impl FeeConfigUtil {
0
};

Ok(base_fee
let transfer_fee_config_amount =
FeeConfigUtil::calculate_transfer_fees(rpc_client, transaction, fee_payer).await?;

let total_fee_lamports = base_fee
+ account_creation_fee
+ kora_signature_fee
+ fee_payer_outflow
+ fee_for_payment_instruction)
+ fee_for_payment_instruction
+ transfer_fee_config_amount;

Ok(TotalFeeCalculation {
total_fee_lamports,
base_fee,
account_creation_fee,
kora_signature_fee,
fee_payer_outflow,
payment_instruction_fee: fee_for_payment_instruction,
transfer_fee_amount: transfer_fee_config_amount,
})
}

/// Main entry point for fee calculation with Kora's price model applied
pub async fn estimate_kora_fee(
rpc_client: &RpcClient,
transaction: &mut VersionedTransactionResolved,
fee_payer: &Pubkey,
is_payment_required: bool,
price_source: Option<PriceSource>,
) -> Result<TotalFeeCalculation, KoraError> {
let config = get_config()?;

// Check if the price is free, so that we can return early (and skip expensive RPC calls / estimation)
if matches!(&config.validation.price.model, PriceModel::Free) {
return Ok(TotalFeeCalculation {
total_fee_lamports: 0,
base_fee: 0,
account_creation_fee: 0,
kora_signature_fee: 0,
fee_payer_outflow: 0,
payment_instruction_fee: 0,
transfer_fee_amount: 0,
});
}

// Get the raw transaction fees
let mut fee_calculation =
Self::estimate_transaction_fee(rpc_client, transaction, fee_payer, is_payment_required)
.await?;

// Apply Kora's price model
if let Some(price_source) = price_source {
let adjusted_fee = config
.validation
.price
.get_required_lamports(
Some(rpc_client),
Some(price_source),
fee_calculation.total_fee_lamports,
)
.await?;

// Update the total with the price model applied
fee_calculation.total_fee_lamports = adjusted_fee;
}

Ok(fee_calculation)
}

/// Calculate the fee in a specific token if provided
pub async fn calculate_fee_in_token(
rpc_client: &RpcClient,
fee_in_lamports: u64,
fee_token: Option<&str>,
) -> Result<Option<f64>, KoraError> {
if let Some(fee_token) = fee_token {
let token_mint = Pubkey::from_str(fee_token).map_err(|_| {
KoraError::InvalidTransaction("Invalid fee token mint address".to_string())
})?;

let config = get_config()?;
let validation_config = &config.validation;

if !validation_config.supports_token(fee_token) {
return Err(KoraError::InvalidRequest(format!(
"Token {fee_token} is not supported"
)));
}

let fee_value_in_token = TokenUtil::calculate_lamports_value_in_token(
fee_in_lamports,
&token_mint,
&validation_config.price_source,
rpc_client,
)
.await?;

Ok(Some(fee_value_in_token))
} else {
Ok(None)
}
}

/// Calculate the total outflow (SOL spending) that could occur for a fee payer account in a transaction.
Expand Down Expand Up @@ -991,7 +1202,7 @@ mod tests {
.unwrap();

// Should include base fee (5000) + fee payer outflow (100_000)
assert_eq!(result, 105_000, "Should return base fee + outflow");
assert_eq!(result.total_fee_lamports, 105_000, "Should return base fee + outflow");
}

#[tokio::test]
Expand Down Expand Up @@ -1021,7 +1232,11 @@ mod tests {
.unwrap();

// Should include base fee + kora signature fee since kora signer not in transaction signers
assert_eq!(result, 5000 + LAMPORTS_PER_SIGNATURE, "Should add Kora signature fee");
assert_eq!(
result.total_fee_lamports,
5000 + LAMPORTS_PER_SIGNATURE,
"Should add Kora signature fee"
);
}

#[tokio::test]
Expand Down Expand Up @@ -1055,7 +1270,10 @@ mod tests {

// Should include base fee + fee payer outflow + payment instruction fee
let expected = 5000 + 100_000 + ESTIMATED_LAMPORTS_FOR_PAYMENT_INSTRUCTION;
assert_eq!(result, expected, "Should include payment instruction fee when required");
assert_eq!(
result.total_fee_lamports, expected,
"Should include payment instruction fee when required"
);
}

#[tokio::test]
Expand Down Expand Up @@ -1104,6 +1322,7 @@ mod tests {
#[tokio::test]
async fn test_can_estimate_transaction_fees_on_transfers_with_uninitialized_atas() {
let _m = ConfigMockBuilder::new().build_and_setup();
let _signer = setup_or_get_test_signer();
let cache_ctx = CacheUtil::get_account_context();
cache_ctx.checkpoint();

Expand All @@ -1112,8 +1331,9 @@ mod tests {
let recipient = Keypair::new(); // This will be a newly generated wallet
let mint = Pubkey::new_unique();

// Mock RPC client that returns base fee
let mocked_rpc_client = RpcMockBuilder::new().with_fee_estimate(5000).build();
// Mock RPC client that returns base fee and handles epoch info
let mocked_rpc_client =
RpcMockBuilder::new().with_fee_estimate(5000).with_epoch_info_mock().build();

// Create ATA creation instruction for recipient (this is what triggers the fee calculation)
let recipient_ata = get_associated_token_address(&recipient.pubkey(), &mint);
Expand All @@ -1137,18 +1357,20 @@ mod tests {
let mut resolved_transaction =
TransactionUtil::new_unsigned_versioned_transaction_resolved(message);

// Setup cache responses for ATA creation instruction:
// 1. Recipient ATA (doesn't exist - AccountNotFound) - this is the case we're testing
// 2. Mint account exists (Ok) - needed to determine token program
// Setup cache responses - correct order based on estimate_transaction_fee execution:
// 1. ATA creation: Recipient ATA (doesn't exist - AccountNotFound) - this is expected
// 2. ATA creation: Mint account exists (Ok) - needed to determine token program
// 3. calculate_transfer_fees: Recipient ATA (doesn't exist - AccountNotFound) → skip
let responses = Arc::new(Mutex::new(VecDeque::from([
Err(KoraError::AccountNotFound(recipient_ata.to_string())), // recipient ATA doesn't exist
Err(KoraError::AccountNotFound(recipient_ata.to_string())), // ATA creation check
Ok(create_mock_spl_mint_account(6)), // mint exists
Err(KoraError::AccountNotFound(recipient_ata.to_string())), // calculate_transfer_fees -> skip
])));

let responses_clone = responses.clone();
cache_ctx
.expect()
.times(2)
.times(3)
.returning(move |_, _, _| responses_clone.lock().unwrap().pop_front().unwrap());

// This should succeed without throwing InternalServerError
Expand All @@ -1162,7 +1384,8 @@ mod tests {

assert!(
result.is_ok(),
"Fee estimation should succeed for transaction with uninitialized ATAs"
"Fee estimation should succeed for transaction with uninitialized ATAs: {:?}",
result.err()
);

let fee = result.unwrap();
Expand All @@ -1172,8 +1395,8 @@ mod tests {
let expected_min_fee = 5000 + expected_ata_rent;

assert_eq!(
fee, expected_min_fee,
"Fee should include base transaction fee plus ATA creation cost. Got: {fee}, Expected at least: {expected_min_fee}"
fee.total_fee_lamports, expected_min_fee,
"Fee should include base transaction fee plus ATA creation cost. Got: {}, Expected at least: {expected_min_fee}", fee.total_fee_lamports
);
}
}
Loading