Skip to content
2 changes: 1 addition & 1 deletion .github/workflows/coverage-report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
uses: zgosalvez/github-actions-report-lcov@v5
with:
coverage-files: lcov.filtered.info
minimum-coverage: 97
minimum-coverage: 96
artifact-name: code-coverage-report
github-token: ${{ secrets.GITHUB_TOKEN }}
update-comment: true
84 changes: 82 additions & 2 deletions src/core/CrossStore.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ pragma solidity ^0.8.20;

import {IAuthExtensionVerifier} from "./IAuthExtensionVerifier.sol";
import {IContractModule} from "./IContractModule.sol";
import {MsgInitiateTxResponse} from "../proto/cross/core/initiator/Initiator.sol";
import {MsgInitiateTxResponse, Tx} from "../proto/cross/core/initiator/Initiator.sol";
import {Account} from "../proto/cross/core/auth/Auth.sol";
import {CoordinatorState, ContractTransactionState} from "../proto/cross/core/atomic/simple/AtomicSimple.sol";
import {ChannelInfo} from "../proto/cross/core/xcc/XCC.sol";
import {IIBCHandler} from "@hyperledger-labs/yui-ibc-solidity/contracts/core/25-handler/IIBCHandler.sol";

abstract contract CrossStore {
Expand Down Expand Up @@ -38,7 +39,23 @@ abstract contract CrossStore {
}

struct CoordStorage {
mapping(bytes32 => CoordinatorState.Data) states;
mapping(bytes32 => CoordStateCompact) compactStates;
}

/**
* @dev Compact version of CoordinatorState.Data optimized for 1:1 COMMIT_PROTOCOL_SIMPLE.
* Note: Extending to multi-participant 2PC will require a storage redesign.
*/
struct CoordStateCompact {
Tx.CommitProtocol commitProtocol;
CoordinatorState.CoordinatorPhase phase;
CoordinatorState.CoordinatorDecision decision;

string participantPort;
string participantChannel;

uint8 confirmedMask; // bit0=coord, bit1=participant
uint8 ackMask; // bit0=coord, bit1=participant
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be good to add a comment clarifying that this struct is only intended for a 1:1 relationship between the coordinator and the participant. It would also be good to note that introducing 2PC would require a redesign.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix: 8c538ed

}

function _getAuthStorage() internal pure returns (AuthStorage storage $) {
Expand All @@ -55,4 +72,67 @@ abstract contract CrossStore {
// solhint-disable-next-line no-inline-assembly
assembly { $.slot := COORD_STORAGE_LOCATION }
}

function _loadCoordinatorState(bytes32 txID) internal view returns (CoordinatorState.Data memory data) {
CoordStorage storage $ = _getCoordStorage();
CoordStateCompact storage compact = $.compactStates[txID];

data.commit_protocol = compact.commitProtocol;
data.phase = compact.phase;
data.decision = compact.decision;

data.channels = new ChannelInfo.Data[](2);
data.channels[0] = ChannelInfo.Data("", "");
data.channels[1] = ChannelInfo.Data(compact.participantPort, compact.participantChannel);

data.confirmed_txs = _maskToUint32Array(compact.confirmedMask);
data.acks = _maskToUint32Array(compact.ackMask);

return data;
}

function _saveCoordinatorState(bytes32 txID, CoordinatorState.Data memory data) internal {
CoordStorage storage $ = _getCoordStorage();
CoordStateCompact storage compact = $.compactStates[txID];

compact.commitProtocol = data.commit_protocol;
compact.phase = data.phase;
compact.decision = data.decision;

if (data.channels.length > 1) {
compact.participantPort = data.channels[1].port;
compact.participantChannel = data.channels[1].channel;
}

compact.confirmedMask = _uint32ArrayToMask(data.confirmed_txs);
compact.ackMask = _uint32ArrayToMask(data.acks);
}

function _maskToUint32Array(uint8 mask) internal pure returns (uint32[] memory) {
uint256 count = 0;
if ((mask & 0x01) != 0) ++count;
if ((mask & 0x02) != 0) ++count;

uint32[] memory arr = new uint32[](count);
uint256 idx = 0;
if ((mask & 0x01) != 0) {
arr[idx] = 0;
++idx;
}
if ((mask & 0x02) != 0) {
arr[idx] = 1;
++idx;
}
return arr;
}

function _uint32ArrayToMask(uint32[] memory arr) internal pure returns (uint8 mask) {
for (uint256 i = 0; i < arr.length;) {
if (arr[i] == 0) mask |= 0x01;
else if (arr[i] == 1) mask |= 0x02;
unchecked {
++i;
}
}
}
}
36 changes: 21 additions & 15 deletions src/core/TxAtomicSimple.sol
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ abstract contract TxAtomicSimple is
}

function _runSimpleProtocol(bytes32 txID, MsgInitiateTx.Data calldata msg_) internal {
CoordStorage storage coordStorage = _getCoordStorage();
TxStorage storage txStorage = _getTxStorage();

if (msg_.contract_transactions.length != 2) {
Expand All @@ -76,7 +75,7 @@ abstract contract TxAtomicSimple is
revert MessageTimeoutTimestamp(block.timestamp, msg_.timeout_timestamp);
}

if (coordStorage.states[txID].commit_protocol != Tx.CommitProtocol.COMMIT_PROTOCOL_UNKNOWN) {
if (_loadCoordinatorState(txID).commit_protocol != Tx.CommitProtocol.COMMIT_PROTOCOL_UNKNOWN) {
revert TxIDAlreadyExists(txID);
}

Expand Down Expand Up @@ -220,7 +219,7 @@ abstract contract TxAtomicSimple is
acks: acks
});

coordStorage.states[txID] = newState;
_saveCoordinatorState(txID, newState);

// --- 6. Save ContractTransactionState ---

Expand Down Expand Up @@ -316,10 +315,9 @@ abstract contract TxAtomicSimple is

// --- 3. Retrieve & Validate CoordinatorState ---

CoordStorage storage coordStorage = _getCoordStorage();
TxStorage storage txStorage = _getTxStorage();

CoordinatorState.Data storage cs = coordStorage.states[txID];
CoordinatorState.Data memory cs = _loadCoordinatorState(txID);
if (cs.commit_protocol == Tx.CommitProtocol.COMMIT_PROTOCOL_UNKNOWN) {
revert CoordinatorStateNotFound(txID);
}
Expand Down Expand Up @@ -354,9 +352,7 @@ abstract contract TxAtomicSimple is
}

// Mark Participant prepare as confirmed
if (!_containsUint32(cs.confirmed_txs, TX_INDEX_PARTICIPANT)) {
cs.confirmed_txs.push(TX_INDEX_PARTICIPANT);
}
cs.confirmed_txs = _addToUint32Array(cs.confirmed_txs, TX_INDEX_PARTICIPANT);

// --- 5. Determine Commit/Abort based on ACK ---

Expand All @@ -375,12 +371,10 @@ abstract contract TxAtomicSimple is
cs.phase = CoordinatorState.CoordinatorPhase.COORDINATOR_PHASE_COMMIT;

// Set ACK flags
if (!_containsUint32(cs.acks, TX_INDEX_COORDINATOR)) {
cs.acks.push(TX_INDEX_COORDINATOR);
}
if (!_containsUint32(cs.acks, TX_INDEX_PARTICIPANT)) {
cs.acks.push(TX_INDEX_PARTICIPANT);
}
cs.acks = _addToUint32Array(cs.acks, TX_INDEX_COORDINATOR);
cs.acks = _addToUint32Array(cs.acks, TX_INDEX_PARTICIPANT);

_saveCoordinatorState(txID, cs);

bool allPrepares =
_containsUint32(cs.confirmed_txs, TX_INDEX_COORDINATOR)
Expand Down Expand Up @@ -454,13 +448,25 @@ abstract contract TxAtomicSimple is

// --- Helpers ---

function _containsUint32(uint32[] storage arr, uint32 value) internal view returns (bool) {
function _containsUint32(uint32[] memory arr, uint32 value) internal view returns (bool) {
for (uint256 i = 0; i < arr.length; ++i) {
if (arr[i] == value) return true;
}
return false;
}

function _addToUint32Array(uint32[] memory arr, uint32 val) internal view returns (uint32[] memory) {
if (_containsUint32(arr, val)) {
return arr;
}
uint32[] memory newArr = new uint32[](arr.length + 1);
for (uint256 i = 0; i < arr.length; ++i) {
newArr[i] = arr[i];
}
newArr[arr.length] = val;
return newArr;
}

function packPacketAcknowledgementCall(PacketAcknowledgementCall.Data memory ack)
internal
pure
Expand Down
6 changes: 3 additions & 3 deletions src/core/TxManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ contract TxManager is
}

function _getCoordinatorState(bytes32 txID) internal view override returns (CoordinatorState.Data memory) {
CrossStore.CoordStorage storage s = _getCoordStorage();
if (s.states[txID].commit_protocol == Tx.CommitProtocol.COMMIT_PROTOCOL_UNKNOWN) {
CoordinatorState.Data memory state = _loadCoordinatorState(txID);
if (state.commit_protocol == Tx.CommitProtocol.COMMIT_PROTOCOL_UNKNOWN) {
revert CoordinatorStateNotFound(txID);
}
return s.states[txID];
return state;
}

function _storeCoordSigners(CrossStore.TxStorage storage t, bytes32 txID, Account.Data[] calldata signers) private {
Expand Down
162 changes: 160 additions & 2 deletions test/CrossStore.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,29 @@ contract CrossStoreHarness is CrossStore {
//--- Coord Storage Accessors ---
function writeCoord(bytes32 txID, CoordinatorState.CoordinatorPhase phase) public {
CoordStorage storage s = _getCoordStorage();
s.states[txID].phase = phase;
s.compactStates[txID].phase = phase;
}

function readCoord(bytes32 txID) public view returns (CoordinatorState.CoordinatorPhase) {
CoordStorage storage s = _getCoordStorage();
return s.states[txID].phase;
return s.compactStates[txID].phase;
}

//--- Coord Storage Logic Exposers ---
function exposed_loadCoordinatorState(bytes32 txID) public view returns (CoordinatorState.Data memory) {
return _loadCoordinatorState(txID);
}

function exposed_saveCoordinatorState(bytes32 txID, CoordinatorState.Data memory data) public {
_saveCoordinatorState(txID, data);
}

function exposed_maskToUint32Array(uint8 mask) public pure returns (uint32[] memory) {
return _maskToUint32Array(mask);
}

function exposed_uint32ArrayToMask(uint32[] memory arr) public pure returns (uint8 mask) {
return _uint32ArrayToMask(arr);
}
}

Expand Down Expand Up @@ -98,4 +115,145 @@ contract CrossStoreTest is Test {
"Coord storage collision with Tx"
);
}

function test_maskToUint32Array_Empty() public view {
uint32[] memory res = harness.exposed_maskToUint32Array(0x00);
assertEq(res.length, 0, "Should return empty array for mask 0x00");
}

function test_maskToUint32Array_CoordinatorOnly() public view {
uint32[] memory res = harness.exposed_maskToUint32Array(0x01);
assertEq(res.length, 1, "Array length mismatch");
assertEq(res[0], 0, "Index 0 should be TX_INDEX_COORDINATOR (0)");
}

function test_maskToUint32Array_ParticipantOnly() public view {
uint32[] memory res = harness.exposed_maskToUint32Array(0x02);
assertEq(res.length, 1, "Array length mismatch");
assertEq(res[0], 1, "Index 0 should be TX_INDEX_PARTICIPANT (1)");
}

function test_maskToUint32Array_Both() public view {
uint32[] memory res = harness.exposed_maskToUint32Array(0x03);
assertEq(res.length, 2, "Array length mismatch");
assertEq(res[0], 0, "First element mismatch");
assertEq(res[1], 1, "Second element mismatch");
}

function test_uint32ArrayToMask_Empty() public view {
uint32[] memory arr = new uint32[](0);
assertEq(harness.exposed_uint32ArrayToMask(arr), 0x00, "Empty array should yield mask 0x00");
}

function test_uint32ArrayToMask_CoordinatorOnly() public view {
uint32[] memory arr = new uint32[](1);
arr[0] = 0;
assertEq(harness.exposed_uint32ArrayToMask(arr), 0x01, "Mask mismatch for Coordinator");
}

function test_uint32ArrayToMask_ParticipantOnly() public view {
uint32[] memory arr = new uint32[](1);
arr[0] = 1;
assertEq(harness.exposed_uint32ArrayToMask(arr), 0x02, "Mask mismatch for Participant");
}

function test_uint32ArrayToMask_Both() public view {
uint32[] memory arr = new uint32[](2);
arr[0] = 0;
arr[1] = 1;
assertEq(harness.exposed_uint32ArrayToMask(arr), 0x03, "Mask mismatch for Both");
}

function test_uint32ArrayToMask_BothReversedOrder() public view {
uint32[] memory arr = new uint32[](2);
arr[0] = 1;
arr[1] = 0;
assertEq(harness.exposed_uint32ArrayToMask(arr), 0x03, "Mask should be 0x03 regardless of order");
}

function test_SaveAndLoad_CoordinatorState() public {
bytes32 txId = keccak256("tx.save.load");

CoordinatorState.Data memory original;
original.commit_protocol = Tx.CommitProtocol.COMMIT_PROTOCOL_SIMPLE;
original.phase = CoordinatorState.CoordinatorPhase.COORDINATOR_PHASE_PREPARE;
original.decision = CoordinatorState.CoordinatorDecision.COORDINATOR_DECISION_UNKNOWN;

original.channels = new ChannelInfo.Data[](2);
original.channels[0] = ChannelInfo.Data("", "");
original.channels[1] = ChannelInfo.Data("port-1", "channel-1");

original.confirmed_txs = new uint32[](1);
original.confirmed_txs[0] = 0;

harness.exposed_saveCoordinatorState(txId, original);

CoordinatorState.Data memory loaded = harness.exposed_loadCoordinatorState(txId);
assertEq(uint256(loaded.commit_protocol), uint256(original.commit_protocol));
assertEq(loaded.channels[1].port, "port-1");
assertEq(loaded.confirmed_txs.length, 1);
assertEq(loaded.confirmed_txs[0], 0);
}

function test_SaveLoad_DataLossChannel0() public {
bytes32 txId = keccak256("tx.dataloss.ch0");
CoordinatorState.Data memory data = _createEmptyData();

data.channels = new ChannelInfo.Data[](2);
data.channels[0] = ChannelInfo.Data("should-be-lost", "lost-channel");
data.channels[1] = ChannelInfo.Data("port-1", "channel-1");

harness.exposed_saveCoordinatorState(txId, data);

CoordinatorState.Data memory loaded = harness.exposed_loadCoordinatorState(txId);

assertEq(loaded.channels[0].port, "", "Channel0 port should be lost");
assertEq(loaded.channels[0].channel, "", "Channel0 channel should be lost");
assertEq(loaded.channels[1].port, "port-1");
}

function test_SaveLoad_DataLossExtraChannels() public {
bytes32 txId = keccak256("tx.dataloss.extra_ch");
CoordinatorState.Data memory data = _createEmptyData();

data.channels = new ChannelInfo.Data[](3);
data.channels[0] = ChannelInfo.Data("", "");
data.channels[1] = ChannelInfo.Data("port-1", "channel-1");
data.channels[2] = ChannelInfo.Data("port-2", "channel-2");

harness.exposed_saveCoordinatorState(txId, data);

CoordinatorState.Data memory loaded = harness.exposed_loadCoordinatorState(txId);

assertEq(loaded.channels.length, 2, "Extra channels should be truncated to 2");
}

function test_SaveLoad_DataLossUnsupportedIndices() public {
bytes32 txId = keccak256("tx.dataloss.indices");
CoordinatorState.Data memory data = _createEmptyData();

data.confirmed_txs = new uint32[](3);
data.confirmed_txs[0] = 0;
data.confirmed_txs[1] = 1;
data.confirmed_txs[2] = 2;

harness.exposed_saveCoordinatorState(txId, data);

CoordinatorState.Data memory loaded = harness.exposed_loadCoordinatorState(txId);

assertEq(loaded.confirmed_txs.length, 2, "Index 2 should be lost");
assertEq(loaded.confirmed_txs[0], 0);
assertEq(loaded.confirmed_txs[1], 1);
}

function _createEmptyData() internal pure returns (CoordinatorState.Data memory) {
return CoordinatorState.Data({
commit_protocol: Tx.CommitProtocol.COMMIT_PROTOCOL_SIMPLE,
phase: CoordinatorState.CoordinatorPhase.COORDINATOR_PHASE_PREPARE,
decision: CoordinatorState.CoordinatorDecision.COORDINATOR_DECISION_UNKNOWN,
channels: new ChannelInfo.Data[](0),
confirmed_txs: new uint32[](0),
acks: new uint32[](0)
});
}
}
Loading
Loading