diff --git a/CHANGELOG.md b/CHANGELOG.md index 7fe756fde..e44719a24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ This changelog is based on [Keep a Changelog](https://keepachangelog.com/en/1.1. ## [Unreleased] ### Src -- +- Add global `set_default_max_transaction_fee()` to Client. ([#2000](https://github.com/hiero-ledger/hiero-sdk-python/issues/2000)) ### Tests - Refactor `mock_server` setup for network level TLS handling and added thread safety diff --git a/src/hiero_sdk_python/client/client.py b/src/hiero_sdk_python/client/client.py index 928c2e475..f7d47bf1b 100644 --- a/src/hiero_sdk_python/client/client.py +++ b/src/hiero_sdk_python/client/client.py @@ -73,6 +73,8 @@ def __init__(self, network: Network = None) -> None: self.logger: Logger = Logger(LogLevel.from_env(), "hiero_sdk_python") + self.default_max_transaction_fee: Hbar | None = None + @classmethod def from_env(cls, network: Optional[NetworkName] = None) -> "Client": """ @@ -455,6 +457,29 @@ def set_max_backoff(self, max_backoff: Union[int, float]) -> "Client": self._max_backoff = float(max_backoff) return self + def set_default_max_transaction_fee( + self, default_max_transaction_fee: Hbar + ) -> "Client": + """ + Set the maximum fee allowed per transaction. + + Args: + default_max_transaction_fee (Hbar): Maximum fee allowed per transaction. + + Returns: + Client: This client instance for fluent chaining. + """ + if not isinstance(default_max_transaction_fee, Hbar): + raise TypeError( + f"default_max_transaction_fee must be of type Hbar, got {(type(default_max_transaction_fee).__name__)}" + ) + + if default_max_transaction_fee.to_tinybars() < 0: + raise ValueError("default_max_transaction_fee must be >= 0") + + self.default_max_transaction_fee = default_max_transaction_fee + return self + def update_network(self) -> "Client": """ Refresh the network node list from the mirror node. diff --git a/src/hiero_sdk_python/transaction/transaction.py b/src/hiero_sdk_python/transaction/transaction.py index 22a721491..6c894e14e 100644 --- a/src/hiero_sdk_python/transaction/transaction.py +++ b/src/hiero_sdk_python/transaction/transaction.py @@ -1,5 +1,5 @@ import hashlib -from typing import Literal, Optional, overload +from typing import Literal, Optional, overload, Union from typing import TYPE_CHECKING @@ -8,9 +8,17 @@ from hiero_sdk_python.client.client import Client from hiero_sdk_python.exceptions import PrecheckError from hiero_sdk_python.executable import _Executable, _ExecutionState -from hiero_sdk_python.hapi.services import (basic_types_pb2, transaction_pb2, transaction_contents_pb2) -from hiero_sdk_python.hapi.services.schedulable_transaction_body_pb2 import SchedulableTransactionBody -from hiero_sdk_python.hapi.services.transaction_response_pb2 import (TransactionResponse as TransactionResponseProto) +from hiero_sdk_python.hapi.services import ( + basic_types_pb2, + transaction_pb2, + transaction_contents_pb2, +) +from hiero_sdk_python.hapi.services.schedulable_transaction_body_pb2 import ( + SchedulableTransactionBody, +) +from hiero_sdk_python.hapi.services.transaction_response_pb2 import ( + TransactionResponse as TransactionResponseProto, +) from hiero_sdk_python.hbar import Hbar from hiero_sdk_python.response_code import ResponseCode from hiero_sdk_python.transaction.transaction_id import TransactionId @@ -25,6 +33,7 @@ ) from hiero_sdk_python.transaction.custom_fee_limit import CustomFeeLimit +TransactionFee = Union[int, Hbar] class Transaction(_Executable): """ @@ -48,8 +57,8 @@ def __init__(self) -> None: super().__init__() self.transaction_id = None - self.transaction_fee: int | None = None - self.transaction_valid_duration = 120 + self.transaction_fee: TransactionFee | None = None + self.transaction_valid_duration = 120 self.generate_record = False self.memo = "" self.custom_fee_limits: list[CustomFeeLimit] = [] @@ -59,14 +68,14 @@ def __init__(self) -> None: # Each transaction body has the AccountId of the node it's being submitted to. # If these do not match `INVALID_NODE_ACCOUNT` error will occur. self._transaction_body_bytes: dict[AccountId, bytes] = {} - + # Maps transaction body bytes to their associated signatures # This allows us to maintain the signatures for each unique transaction # and ensures that the correct signatures are used when submitting transactions self._signature_map: dict[bytes, basic_types_pb2.SignatureMap] = {} - # changed from int: 2_000_000 to Hbar: 0.02 - self._default_transaction_fee = Hbar(0.02) - self.operator_account_id = None + # changed from int: 2_000_000 to Hbar: 2 + self._default_transaction_fee = Hbar(2) + self.operator_account_id = None self.batch_key: Optional[Key] = None def _make_request(self): @@ -81,11 +90,7 @@ def _make_request(self): """ return self._to_proto() - def _map_response( - self, - response, - node_id, - proto_request): + def _map_response(self, response, node_id, proto_request): """ Implements the Executable._map_response method to create a TransactionResponse. @@ -130,7 +135,9 @@ def _should_retry(self, response): _ExecutionState: The execution state indicating what to do next """ if not isinstance(response, TransactionResponseProto): - raise ValueError(f"Expected TransactionResponseProto but got {type(response)}") + raise ValueError( + f"Expected TransactionResponseProto but got {type(response)}" + ) status = response.nodeTransactionPrecheckCode @@ -139,7 +146,7 @@ def _should_retry(self, response): ResponseCode.PLATFORM_TRANSACTION_NOT_CREATED, ResponseCode.PLATFORM_NOT_ACTIVE, ResponseCode.BUSY, - ResponseCode.INVALID_NODE_ACCOUNT + ResponseCode.INVALID_NODE_ACCOUNT, } if status in retryable_statuses: @@ -165,7 +172,7 @@ def _map_status_error(self, response): """ error_code = response.nodeTransactionPrecheckCode tx_id = self.transaction_id - + return PrecheckError(error_code, tx_id) def sign(self, private_key: "PrivateKey") -> "Transaction": @@ -183,7 +190,7 @@ def sign(self, private_key: "PrivateKey") -> "Transaction": """ # We require the transaction to be frozen before signing self._require_frozen() - + # We sign the bodies for each node in case we need to switch nodes during execution. for body_bytes in self._transaction_body_bytes.values(): signature = private_key.sign(body_bytes) @@ -192,13 +199,11 @@ def sign(self, private_key: "PrivateKey") -> "Transaction": if private_key.is_ed25519(): sig_pair = basic_types_pb2.SignaturePair( - pubKeyPrefix=public_key_bytes, - ed25519=signature + pubKeyPrefix=public_key_bytes, ed25519=signature ) else: sig_pair = basic_types_pb2.SignaturePair( - pubKeyPrefix=public_key_bytes, - ECDSA_secp256k1=signature + pubKeyPrefix=public_key_bytes, ECDSA_secp256k1=signature ) # We initialize the signature map for this body_bytes if it doesn't exist yet @@ -206,7 +211,7 @@ def sign(self, private_key: "PrivateKey") -> "Transaction": # Append the signature pair to the signature map for this transaction body self._signature_map[body_bytes].sigPair.append(sig_pair) - + return self def _to_proto(self): @@ -224,7 +229,9 @@ def _to_proto(self): body_bytes = self._transaction_body_bytes.get(self.node_account_id) if body_bytes is None: - raise ValueError(f"No transaction body found for node {self.node_account_id}") + raise ValueError( + f"No transaction body found for node {self.node_account_id}" + ) # Get signature map, or create empty one if transaction is not signed sig_map = self._signature_map.get(body_bytes) @@ -232,8 +239,7 @@ def _to_proto(self): sig_map = basic_types_pb2.SignatureMap() signed_transaction = transaction_contents_pb2.SignedTransaction( - bodyBytes=body_bytes, - sigMap=sig_map + bodyBytes=body_bytes, sigMap=sig_map ) return transaction_pb2.Transaction( @@ -260,24 +266,32 @@ def freeze(self): """ if self._transaction_body_bytes: return self - + if self.transaction_id is None: - raise ValueError("Transaction ID must be set before freezing. Use freeze_with(client) or set_transaction_id().") - + raise ValueError( + "Transaction ID must be set before freezing. Use freeze_with(client) or set_transaction_id()." + ) + if self.node_account_id is None and len(self.node_account_ids) == 0: - raise ValueError("Node account ID must be set before freezing. Use freeze_with(client) or manually set node_account_ids.") - + raise ValueError( + "Node account ID must be set before freezing. Use freeze_with(client) or manually set node_account_ids." + ) + # Populate node_account_ids for backward compatibility if self.node_account_id: self.set_node_account_id(self.node_account_id) - self._transaction_body_bytes[self.node_account_id] = self.build_transaction_body().SerializeToString() + self._transaction_body_bytes[self.node_account_id] = ( + self.build_transaction_body().SerializeToString() + ) return self # Build the transaction body for the single node for node_account_id in self.node_account_ids: self.node_account_id = node_account_id - self._transaction_body_bytes[node_account_id] = self.build_transaction_body().SerializeToString() - + self._transaction_body_bytes[node_account_id] = ( + self.build_transaction_body().SerializeToString() + ) + return self def freeze_with(self, client): @@ -295,49 +309,62 @@ def freeze_with(self, client): """ if self._transaction_body_bytes: return self - + if self.transaction_id is None: self.transaction_id = client.generate_transaction_id() - + + if self.transaction_fee is None: + if client.default_max_transaction_fee is not None: + self.transaction_fee = client.default_max_transaction_fee + else: + self.transaction_fee = self._default_transaction_fee + # We iterate through every node in the client's network # For each node, set the node_account_id and build the transaction body # This allows the transaction to be submitted to any node in the network if self.batch_key: # For Inner Transaction of batch transaction node_account_id=0.0.0 - self.node_account_id = AccountId(0,0,0) - self._transaction_body_bytes[AccountId(0,0,0)] = self.build_transaction_body().SerializeToString() + self.node_account_id = AccountId(0, 0, 0) + self._transaction_body_bytes[AccountId(0, 0, 0)] = ( + self.build_transaction_body().SerializeToString() + ) return self - + # Single node if self.node_account_id: self.set_node_account_id(self.node_account_id) - self._transaction_body_bytes[self.node_account_id] = self.build_transaction_body().SerializeToString() + self._transaction_body_bytes[self.node_account_id] = ( + self.build_transaction_body().SerializeToString() + ) return self - + # Multiple node if len(self.node_account_ids) > 0: for node_account_id in self.node_account_ids: self.node_account_id = node_account_id - self._transaction_body_bytes[node_account_id] = self.build_transaction_body().SerializeToString() + self._transaction_body_bytes[node_account_id] = ( + self.build_transaction_body().SerializeToString() + ) else: # Use all nodes from client network for node in client.network.nodes: self.node_account_id = node._account_id - self._transaction_body_bytes[node._account_id] = self.build_transaction_body().SerializeToString() + self._transaction_body_bytes[node._account_id] = ( + self.build_transaction_body().SerializeToString() + ) return self - + @overload def execute( self, client: "Client", timeout: int | float | None = None, wait_for_receipt: Literal[True] = True, - validate_status: bool = False - ) -> "TransactionReceipt": - ... + validate_status: bool = False, + ) -> "TransactionReceipt": ... @overload def execute( @@ -345,16 +372,15 @@ def execute( client: "Client", timeout: int | float | None = None, wait_for_receipt: Literal[False] = False, - validate_status: bool = False - ) -> "TransactionResponse": - ... + validate_status: bool = False, + ) -> "TransactionResponse": ... def execute( - self, - client: "Client", - timeout: int | float | None = None, + self, + client: "Client", + timeout: int | float | None = None, wait_for_receipt: bool = True, - validate_status: bool = False + validate_status: bool = False, ) -> TransactionReceipt | TransactionResponse: """ Executes the transaction on the Hedera network using the provided client. @@ -378,8 +404,11 @@ def execute( ReceiptStatusError: If the query fails with a receipt status error """ from hiero_sdk_python.transaction.batch_transaction import BatchTransaction + if self.batch_key and not isinstance(self, (BatchTransaction)): - raise ValueError("Cannot execute batchified transaction outside of BatchTransaction.") + raise ValueError( + "Cannot execute batchified transaction outside of BatchTransaction." + ) if not self._transaction_body_bytes: self.freeze_with(client) @@ -398,8 +427,10 @@ def execute( response.transaction_id = self.transaction_id if wait_for_receipt: - return response.get_receipt(client, timeout=timeout, validate_status=validate_status) - + return response.get_receipt( + client, timeout=timeout, validate_status=validate_status + ) + return response def is_signed_by(self, public_key): @@ -413,12 +444,14 @@ def is_signed_by(self, public_key): bool: True if signed by the given public key, False otherwise. """ public_key_bytes = public_key.to_bytes_raw() - - sig_map = self._signature_map.get(self._transaction_body_bytes.get(self.node_account_id)) - + + sig_map = self._signature_map.get( + self._transaction_body_bytes.get(self.node_account_id) + ) + if sig_map is None: return False - + for sig_pair in sig_map.sigPair: if sig_pair.pubKeyPrefix == public_key_bytes: return True @@ -465,13 +498,15 @@ def build_base_transaction_body(self) -> transaction_pb2.TransactionBody: ValueError: If required IDs are not set. """ if self.transaction_id is None: - if self.operator_account_id is None: - raise ValueError("Operator account ID is not set.") - self.transaction_id = TransactionId.generate(self.operator_account_id) + if self.operator_account_id is None: + raise ValueError("Operator account ID is not set.") + self.transaction_id = TransactionId.generate(self.operator_account_id) transaction_id_proto = self.transaction_id._to_proto() - selected_node = self.node_account_id or (self.node_account_ids[0] if self.node_account_ids else None) + selected_node = self.node_account_id or ( + self.node_account_ids[0] if self.node_account_ids else None + ) if selected_node is None: raise ValueError("Node account ID is not set.") @@ -485,10 +520,14 @@ def build_base_transaction_body(self) -> transaction_pb2.TransactionBody: else: transaction_body.transactionFee = int(fee) - transaction_body.transactionValidDuration.seconds = self.transaction_valid_duration + transaction_body.transactionValidDuration.seconds = ( + self.transaction_valid_duration + ) transaction_body.generateRecord = self.generate_record transaction_body.memo = self.memo - custom_fee_limits = [custom_fee._to_proto() for custom_fee in self.custom_fee_limits] + custom_fee_limits = [ + custom_fee._to_proto() for custom_fee in self.custom_fee_limits + ] transaction_body.max_custom_fees.extend(custom_fee_limits) if self.batch_key: @@ -513,7 +552,9 @@ def build_base_scheduled_body(self) -> SchedulableTransactionBody: schedulable_body.transactionFee = int(fee) schedulable_body.memo = self.memo - custom_fee_limits = [custom_fee._to_proto() for custom_fee in self.custom_fee_limits] + custom_fee_limits = [ + custom_fee._to_proto() for custom_fee in self.custom_fee_limits + ] schedulable_body.max_custom_fees.extend(custom_fee_limits) return schedulable_body @@ -617,6 +658,37 @@ def set_transaction_id(self, transaction_id: TransactionId): self.transaction_id = transaction_id return self + def set_transaction_fee(self, transaction_fee: TransactionFee) -> "Transaction": + """ + Sets the transaction fee for the transaction + + Args: + transaction_fee (TransactionFee): The transaction fee to set. + + Returns: + Transaction: The current transaction instance for method chaining. + + Raises: + Exception: If the fee values aren't valid. + """ + self._require_not_frozen() + + if not isinstance(transaction_fee, TransactionFee): + raise TypeError( + f"transaction_fee must be an int or Hbar, got {type(transaction_fee).__name__}" + ) + + if isinstance(transaction_fee, int): + if transaction_fee < 0: + raise ValueError("transaction_fee must be >= 0") + + if isinstance(transaction_fee, Hbar): + if transaction_fee < Hbar(0): + raise ValueError("transaction_fee must be >= 0") + + self.transaction_fee = transaction_fee + return self + def to_bytes(self) -> bytes: """ Serializes the frozen transaction into its protobuf-encoded byte representation. @@ -653,10 +725,10 @@ def to_bytes(self) -> bytes: Exception: If the transaction has not been frozen yet. """ self._require_frozen() - + # Get the transaction protobuf transaction_proto = self._to_proto() - + # Serialize to bytes return transaction_proto.SerializeToString() @@ -823,7 +895,7 @@ def _get_transaction_class(transaction_type: str): "tokenReject": "hiero_sdk_python.tokens.token_reject_transaction.TokenRejectTransaction", "tokenAirdrop": "hiero_sdk_python.tokens.token_airdrop_transaction.TokenAirdropTransaction", "tokenCancelAirdrop": "hiero_sdk_python.tokens.token_cancel_airdrop_transaction.TokenCancelAirdropTransaction", - "atomic_batch": "hiero_sdk_python.transaction.batch_transaction.BatchTransaction" + "atomic_batch": "hiero_sdk_python.transaction.batch_transaction.BatchTransaction", } class_path = transaction_type_map.get(transaction_type) @@ -836,7 +908,9 @@ def _get_transaction_class(transaction_type: str): module = __import__(module_path, fromlist=[class_name]) return getattr(module, class_name) except (ImportError, AttributeError) as e: - raise ValueError(f"Failed to import transaction class for type '{transaction_type}': {e}") + raise ValueError( + f"Failed to import transaction class for type '{transaction_type}': {e}" + ) @classmethod def _from_protobuf(cls, transaction_body, body_bytes: bytes, sig_map): @@ -857,32 +931,42 @@ def _from_protobuf(cls, transaction_body, body_bytes: bytes, sig_map): transaction = cls() if transaction_body.HasField("transactionID"): - transaction.transaction_id = TransactionId._from_proto(transaction_body.transactionID) + transaction.transaction_id = TransactionId._from_proto( + transaction_body.transactionID + ) if transaction_body.HasField("nodeAccountID"): - transaction.node_account_id = AccountId._from_proto(transaction_body.nodeAccountID) + transaction.node_account_id = AccountId._from_proto( + transaction_body.nodeAccountID + ) transaction.transaction_fee = transaction_body.transactionFee - transaction.transaction_valid_duration = transaction_body.transactionValidDuration.seconds + transaction.transaction_valid_duration = ( + transaction_body.transactionValidDuration.seconds + ) transaction.generate_record = transaction_body.generateRecord transaction.memo = transaction_body.memo if transaction_body.max_custom_fees: from hiero_sdk_python.transaction.custom_fee_limit import CustomFeeLimit + transaction.custom_fee_limits = [ - CustomFeeLimit._from_proto(fee) for fee in transaction_body.max_custom_fees + CustomFeeLimit._from_proto(fee) + for fee in transaction_body.max_custom_fees ] if transaction.node_account_id: # restore for the original frozen node transaction.set_node_account_id(transaction.node_account_id) - transaction._transaction_body_bytes[transaction.node_account_id] = body_bytes + transaction._transaction_body_bytes[transaction.node_account_id] = ( + body_bytes + ) if sig_map and sig_map.sigPair: transaction._signature_map[body_bytes] = sig_map return transaction - + def set_batch_key(self, key: Key): """ Set the batch key required for batch transaction. @@ -891,12 +975,12 @@ def set_batch_key(self, key: Key): batch_key (Key): Key to use as batch key (accepts both PrivateKey and PublicKey). Returns: - Transaction: A reconstructed transaction instance of the appropriate subclass. + Transaction: A reconstructed transaction instance of the appropriate subclass. """ self._require_not_frozen() self.batch_key = key return self - + def batchify(self, client: Client, batch_key: Key): """ Marks the current transaction as an inner (batched) transaction. @@ -904,7 +988,7 @@ def batchify(self, client: Client, batch_key: Key): Args: client (Client): The client instance to use for setting defaults. batch_key (Key): Key to use as batch key (accepts both PrivateKey and PublicKey). - + Returns: Transaction: A reconstructed transaction instance of the appropriate subclass. """ diff --git a/tests/integration/account_update_transaction_e2e_test.py b/tests/integration/account_update_transaction_e2e_test.py index 41605fe08..e65bd39a4 100644 --- a/tests/integration/account_update_transaction_e2e_test.py +++ b/tests/integration/account_update_transaction_e2e_test.py @@ -3,7 +3,6 @@ """ import pytest -import datetime from hiero_sdk_python.account.account_create_transaction import AccountCreateTransaction from hiero_sdk_python.account.account_id import AccountId @@ -16,6 +15,7 @@ from hiero_sdk_python.response_code import ResponseCode from hiero_sdk_python.timestamp import Timestamp from tests.integration.utils import env +from hiero_sdk_python.exceptions import PrecheckError @pytest.mark.integration @@ -269,19 +269,6 @@ def test_integration_account_update_transaction_invalid_auto_renew_period(env): info_after = AccountInfoQuery(account_id).execute(env.client) assert info_after.expiration_time == original_info.expiration_time -def _apply_tiny_max_fee_if_supported(tx, client) -> bool: - # Try tx-level setters - for attr in ("set_max_transaction_fee", "set_max_fee", "set_transaction_fee"): - if hasattr(tx, attr): - getattr(tx, attr)(Hbar.from_tinybars(1)) - return True - # Try client-level default - for attr in ("set_default_max_transaction_fee", "set_max_transaction_fee", - "set_default_max_fee", "setMaxTransactionFee"): - if hasattr(client, attr): - getattr(client, attr)(Hbar.from_tinybars(1)) - return True - return False @pytest.mark.integration def test_account_update_insufficient_fee_with_valid_expiration_bump(env): @@ -307,20 +294,21 @@ def test_account_update_insufficient_fee_with_valid_expiration_bump(env): AccountUpdateTransaction() .set_account_id(account_id) .set_expiration_time(new_expiry) + .set_transaction_fee(Hbar.from_tinybars(1)) ) - if not _apply_tiny_max_fee_if_supported(tx, env.client): - pytest.skip("SDK lacks a max-fee API; cannot deterministically trigger INSUFFICIENT_TX_FEE.") + with pytest.raises(PrecheckError) as exc_info: + tx.execute(env.client) - receipt = tx.execute(env.client) - assert receipt.status == ResponseCode.INSUFFICIENT_TX_FEE, ( - f"Expected INSUFFICIENT_TX_FEE but got {ResponseCode(receipt.status).name}" - ) + assert ( + exc_info.value.status == ResponseCode.INSUFFICIENT_TX_FEE + ), f"Expected INSUFFICIENT_TX_FEE but got {ResponseCode(exc_info.value.status).name}" # Confirm expiration time did not change info_after = AccountInfoQuery(account_id).execute(env.client) assert int(info_after.expiration_time.seconds) == base_expiry_secs + @pytest.mark.integration def test_integration_account_update_transaction_with_only_account_id(env): """Test that AccountUpdateTransaction can execute with only account ID set.""" @@ -351,7 +339,9 @@ def test_integration_account_update_transaction_with_only_account_id(env): @pytest.mark.integration -def test_integration_account_update_transaction_with_max_automatic_token_associations(env): +def test_integration_account_update_transaction_with_max_automatic_token_associations( + env +): """Test updating max_automatic_token_associations and verifying it persists.""" # Create initial account receipt = ( @@ -423,9 +413,15 @@ def test_integration_account_update_transaction_with_staking_fields(env): # Verify staking info reflects the updated values info = AccountInfoQuery(account_id).execute(env.client) - assert info.staking_info.staked_account_id == staked_account_id, "Staked account ID should match" - assert info.staking_info.staked_node_id is None, "Staked node ID should be cleared when staking to an account" - assert info.staking_info.decline_reward is True, "Decline staking reward should be true" + assert ( + info.staking_info.staked_account_id == staked_account_id + ), "Staked account ID should match" + assert ( + info.staking_info.staked_node_id is None + ), "Staked node ID should be cleared when staking to an account" + assert ( + info.staking_info.decline_reward is True + ), "Decline staking reward should be true" @pytest.mark.integration @@ -459,4 +455,4 @@ def test_integration_account_update_transaction_with_staked_node_id(env): info = AccountInfoQuery(account_id).execute(env.client) assert info.staking_info is not None, "Staking info should be set" assert info.staking_info.staked_node_id == 0 - assert info.staking_info.staked_account_id is None \ No newline at end of file + assert info.staking_info.staked_account_id is None diff --git a/tests/unit/client_test.py b/tests/unit/client_test.py index f2c0fc41f..0aab1c272 100644 --- a/tests/unit/client_test.py +++ b/tests/unit/client_test.py @@ -5,7 +5,7 @@ from decimal import Decimal import os import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import patch from hiero_sdk_python.client import client as client_module @@ -456,6 +456,52 @@ def test_set_max_backoff_less_than_min_backoff(): assert returned is client +def test_set_default_max_transaction_fee(): + """Test that set_default_max_transaction_fee updates default value of default_max_transaction_fee.""" + client = Client.for_testnet() + assert client.default_max_transaction_fee is None # default default_max_transaction_fee == None + + setted_default_max_transaction_fee = Hbar(2) + client.set_default_max_transaction_fee(setted_default_max_transaction_fee) + assert client.default_max_transaction_fee == setted_default_max_transaction_fee + + returned = client.set_default_max_transaction_fee(Hbar(1)) + + assert returned is client + + client.close() + + +@pytest.mark.parametrize( + "invalid_default_max_transaction_fee", + [ + "Hi from Anto :D", + True, + 18, + float(18), + ["True", 1, float(1)], + {"default_max_transaction_fee": Hbar(2)}, + ], +) +def test_set_invalid_type_default_max_transaction_fee(invalid_default_max_transaction_fee): + """Test that set invalid Type with set_max_transaction fee.""" + client = Client.for_testnet() + + with pytest.raises( + TypeError, + match=f"default_max_transaction_fee must be of type Hbar, got {(type(invalid_default_max_transaction_fee).__name__)}", + ): + client.set_default_max_transaction_fee(invalid_default_max_transaction_fee) + + +def test_set_invalid_value_default_max_transaction_fee(): + """Test that set invalid value with set_max_transaction fee.""" + client = Client.for_testnet() + + with pytest.raises(ValueError, match="default_max_transaction_fee must be >= 0"): + client.set_default_max_transaction_fee(Hbar(-1)) + + # Test update_network def test_update_network_refreshes_nodes_and_returns_self(): """Test that update_network refreshes network nodes and returns the client.""" @@ -469,6 +515,7 @@ def test_update_network_refreshes_nodes_and_returns_self(): client.close() + def test_warning_when_grpc_deadline_exceeds_request_timeout(): """Warn when grpc_deadline is greater than request_timeout.""" client = Client.for_testnet() diff --git a/tests/unit/transaction_freeze_and_bytes_test.py b/tests/unit/transaction_freeze_and_bytes_test.py index 2ccf188b8..948e7e7c6 100644 --- a/tests/unit/transaction_freeze_and_bytes_test.py +++ b/tests/unit/transaction_freeze_and_bytes_test.py @@ -9,8 +9,13 @@ import pytest from hiero_sdk_python.account.account_id import AccountId from hiero_sdk_python.crypto.private_key import PrivateKey +from hiero_sdk_python.hbar import Hbar from hiero_sdk_python.transaction.transfer_transaction import TransferTransaction +from hiero_sdk_python.transaction.transaction import Transaction from hiero_sdk_python.transaction.transaction_id import TransactionId +from hiero_sdk_python.hapi.services.schedulable_transaction_body_pb2 import ( + SchedulableTransactionBody, +) from hiero_sdk_python.hapi.services.transaction_response_pb2 import ( TransactionResponse as TransactionResponseProto, @@ -20,6 +25,17 @@ pytestmark = pytest.mark.unit +class _DummyTransaction(Transaction): + def build_transaction_body(self): + return self.build_base_transaction_body() + + def build_scheduled_body(self) -> SchedulableTransactionBody: + return SchedulableTransactionBody() + + def _get_method(self, channel): + raise NotImplementedError + + def test_freeze_without_transaction_id_raises_error(): """Test that freeze() raises ValueError when transaction_id is not set.""" transaction = TransferTransaction() @@ -617,6 +633,82 @@ def test_unsigned_transaction_can_be_signed_after_to_bytes(): assert unsigned_bytes != signed_bytes assert isinstance(signed_bytes, bytes) + +@pytest.mark.parametrize("value", [0, 1, 100_000_000, Hbar(1)]) +def test_set_transaction_fee(value): + """set_transaction_fee() stores valid integer and hbar fees.""" + transaction = TransferTransaction() + + returned = transaction.set_transaction_fee(value) + + assert returned is transaction + assert transaction.transaction_fee == value + + +@pytest.mark.parametrize( + "value", + [ + "hello from Anto :D", + 1.5, + [100_000_000], + {"transaction_fee": 100_000_000}, + ], +) +def test_set_transaction_fee_invalid_type(value): + """set_transaction_fee() rejects non-integer fees.""" + transaction = TransferTransaction() + + with pytest.raises( + TypeError, + match=f"transaction_fee must be an int or Hbar, got {type(value).__name__}", + ): + transaction.set_transaction_fee(value) + + +@pytest.mark.parametrize("value", [-1, -100_000_000, Hbar(-1)]) +def test_set_transaction_fee_invalid_value(value): + """set_transaction_fee() rejects negative fees.""" + transaction = TransferTransaction() + + with pytest.raises( + ValueError, + match="transaction_fee must be >= 0", + ): + transaction.set_transaction_fee(value) + + +def test_freeze_with_uses_explicit_transaction_fee(mock_client): + """freeze_with() preserves an explicitly-set transaction fee.""" + transaction = TransferTransaction().set_transaction_fee(123_456) + mock_client.default_max_transaction_fee = Hbar(7) + + transaction.freeze_with(mock_client) + + assert transaction.transaction_fee == 123_456 + + +def test_freeze_with_uses_client_default_max_transaction_fee(mock_client): + """freeze_with() uses the client's default max transaction fee when unset.""" + transaction = _DummyTransaction() + mock_client.default_max_transaction_fee = Hbar(7) + + transaction.freeze_with(mock_client) + + assert transaction.transaction_fee == Hbar(7) + + +def test_freeze_with_uses_transaction_default_fee_when_client_default_is_unset( + mock_client, +): + """freeze_with() falls back to the transaction default fee when no fee is set.""" + transaction = _DummyTransaction() + mock_client.default_max_transaction_fee = None + + transaction.freeze_with(mock_client) + + assert transaction.transaction_fee == transaction._default_transaction_fee + assert transaction.transaction_fee == Hbar(2) + def test_transaction_freeze_with_node_ids(mock_client): """ Test freeze_with() correctly initializes transaction bytes using provided node_account_id(s).