Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/contract-interface/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub mod types {
Bls12381G1PublicKey, Bls12381G2PublicKey, Ed25519PublicKey, PublicKey, Secp256k1PublicKey,
};
pub use primitives::{AccountId, CkdAppId};
pub use updates::{ProposedUpdates, Update, UpdateHash};
pub use updates::{ProposedUpdates, UpdateHash};

mod attestation;
mod config;
Expand Down
29 changes: 3 additions & 26 deletions crates/contract-interface/src/types/updates.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use crate::types::primitives::AccountId;
use borsh::{BorshDeserialize, BorshSerialize};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;

type Sha256Digest = [u8; 32];

/// A vector of proposed updates
#[derive(
Debug,
Clone,
Expand All @@ -22,30 +21,8 @@ type Sha256Digest = [u8; 32];
all(feature = "abi", not(target_arch = "wasm32")),
derive(schemars::JsonSchema)
)]
pub struct ProposedUpdates(pub Vec<Update>);

/// A proposed update
#[derive(
Debug,
Clone,
Eq,
PartialEq,
Ord,
PartialOrd,
Hash,
Serialize,
Deserialize,
BorshSerialize,
BorshDeserialize,
)]
#[cfg_attr(
all(feature = "abi", not(target_arch = "wasm32")),
derive(schemars::JsonSchema)
)]
pub struct Update {
pub update_id: u64,
pub update_hash: UpdateHash,
pub votes: Vec<AccountId>,
pub struct ProposedUpdates {
pub updates: BTreeMap<u64, UpdateHash>,
}

/// An update hash
Expand Down
20 changes: 8 additions & 12 deletions crates/contract/src/dto_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,19 +425,15 @@ impl IntoInterfaceType<dtos::UpdateHash> for &Update {

impl IntoInterfaceType<dtos::ProposedUpdates> for &ProposedUpdates {
fn into_dto_type(self) -> dtos::ProposedUpdates {
let updates = self
.all_updates()
.iter()
.map(|(update_id, update, votes)| dtos::Update {
update_id: update_id.0,
update_hash: update.into_dto_type(),
votes: votes
.iter()
.map(|account_id| account_id.into_dto_type())
.collect(),
})
let all = self.all_updates();

let updates = all
.updates
.into_iter()
.map(|(update_id, update)| (update_id.0, update))
.collect();
dtos::ProposedUpdates(updates)

dtos::ProposedUpdates { updates }
}
}

Expand Down
214 changes: 152 additions & 62 deletions crates/contract/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -974,18 +974,23 @@ impl MpcContract {
let threshold = self.threshold()?;

let voter = self.voter_or_panic();
let Some(all_votes) = self.proposed_updates.vote(&id, voter) else {
if self.proposed_updates.vote(&id, voter).is_none() {
return Err(InvalidParameters::UpdateNotFound.into());
};
}

// Filter votes to only count current participants. This ensures correctness
// even if the cleanup promise in MpcContract::vote_reshared() fails.
// Filter votes to only count current participants voting for this specific update.
// This ensures correctness even if the cleanup promise in MpcContract::vote_reshared() fails.
let valid_votes_count = running_state
.parameters
.participants()
.participants()
.iter()
.filter(|(id, _, _)| all_votes.contains(id))
.filter(|(account_id, _, _)| {
self.proposed_updates
.vote_by_participant
.get(account_id)
.is_some_and(|voted_id| *voted_id == id)
})
.count();

// Not enough votes from current participants, wait for more.
Expand Down Expand Up @@ -2939,13 +2944,42 @@ mod tests {
expected_votes
}

fn propose_and_vote_code(expected_update_id: u64, contract: &mut MpcContract) -> dtos::Update {
/// Test helper struct that combines update metadata with its votes for convenient comparison.
/// Used to convert BTreeMap-based [`ProposedUpdates`] into a sortable vector format for assertions.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct TestUpdate {
update_id: u64,
update_hash: dtos::UpdateHash,
votes: Vec<dtos::AccountId>,
}

impl TestUpdate {
fn from_proposed_updates(
update_id: u64,
update_hash: dtos::UpdateHash,
proposed_updates: &dtos::ProposedUpdates,
) -> Self {
let votes: Vec<dtos::AccountId> = proposed_updates
.votes
.iter()
.filter(|(_, &uid)| uid == update_id)
.map(|(account, _)| account.clone())
.collect();
TestUpdate {
update_id,
update_hash,
votes,
}
}
}

fn propose_and_vote_code(expected_update_id: u64, contract: &mut MpcContract) -> TestUpdate {
let code: [u8; 1000] = std::array::from_fn(|_| rand::random());
let hash = Sha256::digest(code);
let update = Update::Contract(code.into());
let expected_update_hash = dtos::UpdateHash::Code(hash.into());
let expected_votes = propose_and_vote(contract, update, expected_update_id);
dtos::Update {
TestUpdate {
update_id: expected_update_id,
update_hash: expected_update_hash,
votes: expected_votes,
Expand All @@ -2961,45 +2995,71 @@ mod tests {
assert_eq!(actual_voters, *expected_voters);

let all_updates = proposed_updates.all_updates();
assert_eq!(all_updates.len(), 1);
let (actual_update_id, update, actual_votes) = &all_updates[0];
assert_eq!(*actual_update_id, expected_update_id);
assert!(matches!(update, Update::Contract(_)));
assert_eq!(**actual_votes, *expected_voters);
assert_eq!(all_updates.updates.len(), 1);

assert!(all_updates.updates.contains_key(&expected_update_id));
let update = all_updates.updates.get(&expected_update_id).unwrap();
assert!(matches!(update, dtos::UpdateHash::Code(_)));

let actual_voters: HashSet<_> = all_updates
.votes
.iter()
.filter(|(_, &update_id)| update_id == expected_update_id)
.map(|(account, _)| account.clone())
.collect();
assert_eq!(actual_voters, *expected_voters);
}

fn test_proposed_updates_case_given_state(protocol_contract_state: ProtocolContractState) {
let mut contract = MpcContract::new_from_protocol_state(protocol_contract_state);

assert_eq!(contract.proposed_updates(), dtos::ProposedUpdates(vec![]));
let empty_result = contract.proposed_updates();
assert_eq!(empty_result.votes, BTreeMap::new());
assert_eq!(empty_result.updates, BTreeMap::new());

let code_update = propose_and_vote_code(0, &mut contract);
// Propose and vote for code update
let code_update_id = 0;
let mut code_update = propose_and_vote_code(code_update_id, &mut contract);

let config_update = {
// Propose and vote for config update
let mut config_update = {
let update_config = dummy_config(1);
let hash = Sha256::digest(serde_json::to_vec(&update_config).unwrap());
let expected_update_hash = dtos::UpdateHash::Config(hash.into());

let update = Update::Config(update_config.clone());

let expected_update_id = 1;
let expected_votes = propose_and_vote(&mut contract, update, expected_update_id);
dtos::Update {
update_id: expected_update_id,
update_hash: expected_update_hash,
votes: expected_votes,
let config_hash = Sha256::digest(serde_json::to_vec(&update_config).unwrap());
let config_update_obj = Update::Config(update_config.clone());
let config_update_id = 1;
let config_votes = propose_and_vote(&mut contract, config_update_obj, config_update_id);
TestUpdate {
update_id: config_update_id,
update_hash: dtos::UpdateHash::Config(config_hash.into()),
votes: config_votes,
}
};

// Sort votes for consistent comparison
code_update.votes.sort();
config_update.votes.sort();

let mut expected = vec![code_update, config_update];
// sorting to have consistent order
expected.sort();

let mut res = contract.proposed_updates();
res.0.iter_mut().for_each(|update| update.votes.sort());
let res = contract.proposed_updates();

// Convert result to vector of TestUpdate for comparison
let mut actual: Vec<TestUpdate> = res
.updates
.iter()
.map(|(update_id, update_hash)| {
TestUpdate::from_proposed_updates(*update_id, update_hash.clone(), &res)
})
.collect();

// Sort votes within each update
actual.iter_mut().for_each(|update| update.votes.sort());
// sorting to have consistent order
res.0.sort();
actual.sort();

assert_eq!(dtos::ProposedUpdates(expected), res);
assert_eq!(expected, actual);
}

#[test]
Expand Down Expand Up @@ -3027,37 +3087,60 @@ mod tests {
let participants = running_state.parameters.participants().clone();
let protocol_contract_state = ProtocolContractState::Running(running_state);
let mut contract = MpcContract::new_from_protocol_state(protocol_contract_state);
let expected = propose_and_vote_code(0, &mut contract);

// Propose and vote for code update
let update_id_u64 = 0;
let test_update = propose_and_vote_code(update_id_u64, &mut contract);
let update_id = UpdateId::from(update_id_u64);

for (account_id, _, _) in participants.participants() {
contract
.proposed_updates
.vote(&UpdateId::from(expected.update_id), account_id.clone());
let mut expected_with_participant_vote = expected.clone();
expected_with_participant_vote
.votes
.push(account_id.into_dto_type());
expected_with_participant_vote.votes.sort();
let mut res = contract.proposed_updates();
res.0.iter_mut().for_each(|update| update.votes.sort());
.vote(&update_id, account_id.clone());

let proposed_updates = contract.proposed_updates();
assert_eq!(proposed_updates.updates.len(), 1);
assert_eq!(
res,
dtos::ProposedUpdates(vec![expected_with_participant_vote])
*proposed_updates.updates.get(&update_id.0).unwrap(),
test_update.update_hash
);

// Check that participant vote was added
let mut expected_voters: Vec<_> = test_update.votes.to_vec();
expected_voters.push(account_id.clone().into_dto_type());
let actual_voters: Vec<_> = proposed_updates
.votes
.iter()
.filter(|(_, &uid)| uid == update_id.0)
.map(|(voter, _)| voter.clone())
.collect();
assert_eq!(actual_voters.len(), expected_voters.len());
for voter in &actual_voters {
assert!(expected_voters.contains(voter));
}

// Remove the vote
testing_env!(VMContextBuilder::new()
.signer_account_id(account_id.as_v1_account_id())
.predecessor_account_id(account_id.as_v1_account_id())
.build());

contract.remove_update_vote();
let mut expected_without_participant_vote = expected.clone();
expected_without_participant_vote.votes.sort();
let mut res = contract.proposed_updates();
res.0.iter_mut().for_each(|update| update.votes.sort());
assert_eq!(
res,
dtos::ProposedUpdates(vec![expected_without_participant_vote])
);

let res = contract.proposed_updates();
assert_eq!(res.updates.len(), 1);

// Check that participant vote was removed
let actual_voters: Vec<_> = res
.votes
.iter()
.filter(|(_, &uid)| uid == update_id.0)
.map(|(voter, _)| voter.clone())
.collect();
assert_eq!(actual_voters.len(), test_update.votes.len());
for voter in &actual_voters {
assert!(test_update.votes.contains(voter));
}
}
}

Expand All @@ -3067,10 +3150,13 @@ mod tests {
let running_state = gen_running_state(2);
let protocol_contract_state = ProtocolContractState::Running(running_state);
let mut contract = MpcContract::new_from_protocol_state(protocol_contract_state);
let expected = propose_and_vote_code(0, &mut contract);

// Propose and vote for code update
let update_id = 0;
let test_update = propose_and_vote_code(update_id, &mut contract);

let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let account_id = expected.votes.choose(&mut rng).unwrap();
let account_id = test_update.votes.choose(&mut rng).unwrap();
let account_id: AccountId = account_id.0.parse().unwrap();
testing_env!(VMContextBuilder::new()
.signer_account_id(account_id.as_v1_account_id())
Expand Down Expand Up @@ -3172,9 +3258,16 @@ mod tests {
let mut contract =
MpcContract::new_from_protocol_state(ProtocolContractState::Running(running_state));

// Propose an update with 2 non-participant votes
let dto_update = propose_and_vote_code(0, &mut contract);
let update_id: UpdateId = dto_update.update_id.into();
// Propose an update with 2 non-participant votes (from propose_and_vote)
let update_id_u64 = 0;
let test_update = propose_and_vote_code(update_id_u64, &mut contract);
let update_id: UpdateId = update_id_u64.into();

let non_participants: HashSet<AccountId> = test_update
.votes
.iter()
.map(|dto_id| dto_id.0.parse().unwrap())
.collect();

// Add votes from 2 current participants
let participants = participants.participants();
Expand All @@ -3183,17 +3276,14 @@ mod tests {
contract.proposed_updates.vote(&update_id, p2.clone());

// Sanity check: verify exact set of voters before cleanup
let voters_before: HashSet<_> = contract.proposed_updates.voters().into_iter().collect();
let non_participants: HashSet<_> = dto_update
.votes
.iter()
.map(|dto_id| dto_id.0.parse().unwrap())
.collect();
let expected_voters_before: HashSet<_> = [p1.clone(), p2.clone()]
.into_iter()
.chain(non_participants)
.chain(non_participants.clone())
.collect();
assert_eq!(voters_before, expected_voters_before);

let actual_voters_before: HashSet<_> =
contract.proposed_updates.voters().into_iter().collect();
assert_eq!(actual_voters_before, expected_voters_before);

// verify the update entry reflects participant + non-participant votes
assert_proposed_update_has_expected_voters(
Expand Down
1 change: 1 addition & 0 deletions crates/contract/src/tee/tee_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ impl TeeState {
.get(tls_public_key)
.map(|(node_id, _)| node_id.clone())
}

/// Returns true if the caller has at least one participant entry
/// whose TLS key matches an attested node belonging to the caller account.
///
Expand Down
Loading
Loading