diff --git a/pallets/subtensor/src/lib.rs b/pallets/subtensor/src/lib.rs index 197cd5f8f..9efaed4c8 100644 --- a/pallets/subtensor/src/lib.rs +++ b/pallets/subtensor/src/lib.rs @@ -2165,6 +2165,20 @@ where Self::get_priority_staking(who, hotkey, *amount_unstaked), ) } + Some(Call::unstake_all { hotkey }) => { + // Fully validate the user input + Self::result_to_validity( + Pallet::::validate_unstake_all(who, hotkey, false), + Self::get_priority_vanilla(), + ) + } + Some(Call::unstake_all_alpha { hotkey }) => { + // Fully validate the user input + Self::result_to_validity( + Pallet::::validate_unstake_all(who, hotkey, true), + Self::get_priority_vanilla(), + ) + } Some(Call::remove_stake_limit { hotkey, netuid, diff --git a/pallets/subtensor/src/staking/stake_utils.rs b/pallets/subtensor/src/staking/stake_utils.rs index b7010185f..84af459f1 100644 --- a/pallets/subtensor/src/staking/stake_utils.rs +++ b/pallets/subtensor/src/staking/stake_utils.rs @@ -995,6 +995,47 @@ impl Pallet { Ok(()) } + /// Validate if unstake_all can be executed + /// + pub fn validate_unstake_all( + coldkey: &T::AccountId, + hotkey: &T::AccountId, + only_alpha: bool, + ) -> Result<(), Error> { + // Get all netuids (filter out root) + let subnets: Vec = Self::get_all_subnet_netuids(); + + // Ensure that the hotkey account exists this is only possible through registration. + ensure!( + Self::hotkey_account_exists(hotkey), + Error::::HotKeyAccountNotExists + ); + + let mut unstaking_any = false; + for netuid in subnets.iter() { + if !SubtokenEnabled::::get(netuid) { + continue; + } + + if only_alpha && (*netuid == Self::get_root_netuid()) { + continue; + } + + // Get user's stake in this subnet + let alpha = Self::get_stake_for_hotkey_and_coldkey_on_subnet(hotkey, coldkey, *netuid); + + if Self::validate_remove_stake(coldkey, hotkey, *netuid, alpha, alpha, false).is_ok() + { + unstaking_any = true; + } + } + + // If no unstaking happens, return error + ensure!(unstaking_any, Error::::AmountTooLow); + + Ok(()) + } + /// Validate stake transition user input /// That works for move_stake, transfer_stake, and swap_stake /// diff --git a/pallets/subtensor/src/tests/batch_tx.rs b/pallets/subtensor/src/tests/batch_tx.rs index 512fa9b36..94480b973 100644 --- a/pallets/subtensor/src/tests/batch_tx.rs +++ b/pallets/subtensor/src/tests/batch_tx.rs @@ -1,5 +1,8 @@ use super::mock::*; -use frame_support::{assert_ok, traits::Currency}; +use frame_support::{ + assert_ok, + traits::{Contains, Currency}, +}; use frame_system::Config; use sp_core::U256; @@ -33,3 +36,29 @@ fn test_batch_txs() { assert_eq!(Balances::total_balance(&charlie), 2_000_000_000); }); } + +#[test] +fn test_cant_nest_batch_txs() { + let _alice = U256::from(0); + let bob = U256::from(1); + let charlie = U256::from(2); + + new_test_ext(1).execute_with(|| { + let call = RuntimeCall::Utility(pallet_utility::Call::batch { + calls: vec![ + RuntimeCall::Balances(BalanceCall::transfer_allow_death { + dest: bob, + value: 1_000_000_000, + }), + RuntimeCall::Utility(pallet_utility::Call::force_batch { + calls: vec![RuntimeCall::Balances(BalanceCall::transfer_allow_death { + dest: charlie, + value: 1_000_000_000, + })], + }), + ], + }); + + assert!(!::BaseCallFilter::contains(&call)); + }); +} diff --git a/pallets/subtensor/src/tests/mock.rs b/pallets/subtensor/src/tests/mock.rs index 221d802cc..d5d302d5c 100644 --- a/pallets/subtensor/src/tests/mock.rs +++ b/pallets/subtensor/src/tests/mock.rs @@ -2,11 +2,12 @@ use crate::utils::rate_limiting::TransactionType; use frame_support::derive_impl; use frame_support::dispatch::DispatchResultWithPostInfo; +use frame_support::traits::{Contains, Everything, InsideBoth}; use frame_support::weights::Weight; use frame_support::weights::constants::RocksDbWeight; use frame_support::{ assert_ok, parameter_types, - traits::{Everything, Hooks, PrivilegeCmp}, + traits::{Hooks, PrivilegeCmp}, }; use frame_system as system; use frame_system::{EnsureNever, EnsureRoot, RawOrigin, limits}; @@ -88,9 +89,31 @@ impl pallet_balances::Config for Test { type MaxFreezes = (); } +pub struct NoNestingCallFilter; + +impl Contains for NoNestingCallFilter { + fn contains(call: &RuntimeCall) -> bool { + match call { + RuntimeCall::Utility(inner) => { + let calls = match inner { + pallet_utility::Call::force_batch { calls } => calls, + pallet_utility::Call::batch { calls } => calls, + pallet_utility::Call::batch_all { calls } => calls, + _ => &Vec::new(), + }; + + !calls.iter().any(|call| { + matches!(call, RuntimeCall::Utility(inner) if matches!(inner, pallet_utility::Call::force_batch { .. } | pallet_utility::Call::batch_all { .. } | pallet_utility::Call::batch { .. })) + }) + } + _ => true, + } + } +} + #[derive_impl(frame_system::config_preludes::TestDefaultConfig)] impl system::Config for Test { - type BaseCallFilter = Everything; + type BaseCallFilter = InsideBoth; type BlockWeights = BlockWeights; type BlockLength = (); type DbWeight = RocksDbWeight; diff --git a/pallets/subtensor/src/tests/staking.rs b/pallets/subtensor/src/tests/staking.rs index 9a672676e..9220fa729 100644 --- a/pallets/subtensor/src/tests/staking.rs +++ b/pallets/subtensor/src/tests/staking.rs @@ -3841,6 +3841,57 @@ fn test_unstake_low_liquidity_validate() { }); } +#[test] +fn test_unstake_all_validate() { + // Testing the signed extension validate function + // correctly filters the `unstake_all` transaction. + + new_test_ext(0).execute_with(|| { + let subnet_owner_coldkey = U256::from(1001); + let subnet_owner_hotkey = U256::from(1002); + let hotkey = U256::from(2); + let coldkey = U256::from(3); + let amount_staked = DefaultMinStake::::get() * 10 + DefaultStakingFee::::get(); + + let netuid = add_dynamic_network(&subnet_owner_hotkey, &subnet_owner_coldkey); + SubtensorModule::create_account_if_non_existent(&coldkey, &hotkey); + SubtensorModule::add_balance_to_coldkey_account(&coldkey, amount_staked); + + // Simulate stake for hotkey + SubnetTAO::::insert(netuid, u64::MAX / 1000); + SubnetAlphaIn::::insert(netuid, u64::MAX / 1000); + SubtensorModule::stake_into_subnet(&hotkey, &coldkey, netuid, amount_staked, 0); + + // Set the liquidity at lowest possible value so that all staking requests fail + SubnetTAO::::insert( + netuid, + DefaultMinimumPoolLiquidity::::get().to_num::(), + ); + SubnetAlphaIn::::insert( + netuid, + DefaultMinimumPoolLiquidity::::get().to_num::(), + ); + + // unstake_all call + let call = RuntimeCall::SubtensorModule(SubtensorCall::unstake_all { hotkey }); + + let info: DispatchInfo = + DispatchInfoOf::<::RuntimeCall>::default(); + + let extension = SubtensorSignedExtension::::new(); + // Submit to the signed extension validate function + let result_no_stake = extension.validate(&coldkey, &call.clone(), &info, 10); + + // Should fail due to insufficient stake + assert_err!( + result_no_stake, + TransactionValidityError::Invalid(InvalidTransaction::Custom( + CustomTransactionError::StakeAmountTooLow.into() + )) + ); + }); +} + #[test] fn test_max_amount_add_root() { new_test_ext(0).execute_with(|| { diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 64937fd3b..2a6ce1ab8 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -12,7 +12,7 @@ pub mod check_nonce; mod migrations; use codec::{Compact, Decode, Encode}; -use frame_support::traits::Imbalance; +use frame_support::traits::{Imbalance, InsideBoth}; use frame_support::{ PalletId, dispatch::DispatchResultWithPostInfo, @@ -209,7 +209,7 @@ pub const VERSION: RuntimeVersion = RuntimeVersion { // `spec_version`, and `authoring_version` are the same between Wasm and native. // This value is set to 100 to notify Polkadot-JS App (https://polkadot.js.org/apps) to use // the compatible custom types. - spec_version: 271, + spec_version: 272, impl_version: 1, apis: RUNTIME_API_VERSIONS, transaction_version: 1, @@ -244,11 +244,33 @@ parameter_types! { pub const SS58Prefix: u8 = 42; } +pub struct NoNestingCallFilter; + +impl Contains for NoNestingCallFilter { + fn contains(call: &RuntimeCall) -> bool { + match call { + RuntimeCall::Utility(inner) => { + let calls = match inner { + pallet_utility::Call::force_batch { calls } => calls, + pallet_utility::Call::batch { calls } => calls, + pallet_utility::Call::batch_all { calls } => calls, + _ => &Vec::new(), + }; + + !calls.iter().any(|call| { + matches!(call, RuntimeCall::Utility(inner) if matches!(inner, pallet_utility::Call::force_batch { .. } | pallet_utility::Call::batch_all { .. } | pallet_utility::Call::batch { .. })) + }) + } + _ => true, + } + } +} + // Configure FRAME pallets to include in runtime. impl frame_system::Config for Runtime { // The basic call filter to use in dispatchable. - type BaseCallFilter = SafeMode; + type BaseCallFilter = InsideBoth; // Block & extrinsics weights: base values and limits. type BlockWeights = BlockWeights; // The maximum length of a block (in bytes).