Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions Sources/Starknet/Crypto/KeyDerivation.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import BigInt
import CryptoKit
import Foundation

public enum StarkKeygenError: Error {
case rngFailure
case grindExceeded
}

private let STARK_EC_ORDER = StarknetCurve.curveOrder

public enum StarkKeygen {
/// Deterministically derives a valid Starknet private key.
///
/// - Parameters:
/// - seed: bytes used as a seed to compute the key.
/// - maxIters: maximum number of iterations to avoid infinite loop or dos.
/// - Returns: valid Starknet private key.
public static func grindKey(seed: Data, maxIters: Int = 100_000) throws -> BigUInt {
let mask = BigUInt(1) << 256
let limit = mask - (mask % STARK_EC_ORDER)

var i: UInt64 = 0
while i <= UInt64(maxIters) {
let iBytes = varIntBE(i)
var input = Data(capacity: seed.count + iBytes.count)
input.append(seed)
input.append(iBytes)

let digest = SHA256.hash(data: input)
let x = BigUInt(Data(digest))

if x < limit {
return x % STARK_EC_ORDER // in [0, n)
}

i &+= 1
}
throw StarkKeygenError.grindExceeded
}

/// Generates a random Starknet private key.
///
/// - Returns: a valid Starknet private key.
public static func randomPrivateKeyHex() throws -> String {
var seed = Data(count: 32)
let status = seed.withUnsafeMutableBytes { ptr in
SecRandomCopyBytes(kSecRandomDefault, ptr.count, ptr.baseAddress!)
}
guard status == errSecSuccess else { throw StarkKeygenError.rngFailure }

let sk = try grindKey(seed: seed)
return "0x" + String(sk, radix: 16).leftPadding(width: 64)
}
}

/// Big-endian variable-length encoding for UInt64 (no leading zeros).
private func varIntBE(_ x: UInt64) -> Data {
if x == 0 { return Data([0]) }
var bytes: [UInt8] = []
var v = x
while v > 0 {
bytes.append(UInt8(truncatingIfNeeded: v & 0xFF))
v >>= 8
}
return Data(bytes.reversed())
}

private extension String {
func leftPadding(width: Int) -> String {
if count >= width { return self }
return String(repeating: "0", count: width - count) + self
}
}
4 changes: 2 additions & 2 deletions Sources/Starknet/Data/Transaction/TransactionTrace.swift
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ enum StarknetTransactionTraceWrapper: Decodable {
case deployAccount(StarknetDeployAccountTransactionTrace)
case l1Handler(StarknetL1HandlerTransactionTrace)

public var transactionTrace: any StarknetTransactionTrace {
var transactionTrace: any StarknetTransactionTrace {
switch self {
case let .invoke(txTrace):
txTrace
Expand All @@ -162,7 +162,7 @@ enum StarknetTransactionTraceWrapper: Decodable {
}
}

public init(from decoder: Decoder) throws {
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: Keys.self)

let type = try container.decode(StarknetTransactionType.self, forKey: Keys.type)
Expand Down
2 changes: 1 addition & 1 deletion Tests/StarknetTests/Accounts/AccountTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ final class AccountTests: XCTestCase {
}

func testGetNonce() async throws {
let _ = await (try? provider.send(request: account.getNonce()))
_ = await (try? provider.send(request: account.getNonce()))
}

func testExecuteV3() async throws {
Expand Down
34 changes: 34 additions & 0 deletions Tests/StarknetTests/Crypto/KeyDerivationTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import BigInt
@testable import Starknet
import XCTest

final class KeyDerivationTests: XCTestCase {
private let curveOrder = StarknetCurve.curveOrder

func testGrindKeyIsDeterministicAndInRange() throws {
let seed = Data([0xDE, 0xAD, 0xBE, 0xEF])
let k1 = try StarkKeygen.grindKey(seed: seed)
let k2 = try StarkKeygen.grindKey(seed: seed)
XCTAssertEqual(k1, k2)
XCTAssertTrue(k1 < curveOrder)
XCTAssertNotEqual(k1, 0)
}

func testRandomPrivateKeyLooksValid() throws {
let hex = try StarkKeygen.randomPrivateKeyHex()
XCTAssertTrue(hex.hasPrefix("0x"))
XCTAssertEqual(hex.count, 2 + 64)
let sk = BigUInt(hex.dropFirst(2), radix: 16)!
XCTAssertTrue(sk < curveOrder)
XCTAssertNotEqual(sk, 0)
let hex2 = try StarkKeygen.randomPrivateKeyHex()
XCTAssertNotEqual(hex, hex2)
}
}

private extension String {
func leftPadding(width: Int) -> String {
if count >= width { return self }
return String(repeating: "0", count: width - count) + self
}
}
14 changes: 7 additions & 7 deletions Tests/StarknetTests/Providers/ProviderTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,10 @@ final class ProviderTests: XCTestCase {
func testGetTransactionByHash() async throws {
let previousResult = try await provider.send(request: RequestBuilder.getTransactionBy(blockId: .tag(.latest), index: 0))

let _ = try await provider.send(request: RequestBuilder.getTransactionBy(hash: previousResult.transaction.hash!))
_ = try await provider.send(request: RequestBuilder.getTransactionBy(hash: previousResult.transaction.hash!))

do {
let _ = try await provider.send(request: RequestBuilder.getTransactionBy(hash: "0x123"))
_ = try await provider.send(request: RequestBuilder.getTransactionBy(hash: "0x123"))
XCTFail("Fetching transaction with nonexistent hash should fail")
} catch {}
}
Expand Down Expand Up @@ -236,12 +236,12 @@ final class ProviderTests: XCTestCase {
let params2 = StarknetInvokeParamsV3(nonce: Felt(nonce.value + 1)!, resourceBounds: StarknetResourceBoundsMapping.zero)
let tx2 = try account.signV3(calls: [call, call2], params: params2, forFeeEstimation: true)

let _ = try await provider.send(request: RequestBuilder.estimateFee(for: [tx1, tx2], simulationFlags: []))
_ = try await provider.send(request: RequestBuilder.estimateFee(for: [tx1, tx2], simulationFlags: []))

let tx1WithoutSignature = StarknetInvokeTransactionV3(senderAddress: tx1.senderAddress, calldata: tx1.calldata, signature: [], resourceBounds: tx1.resourceBounds, nonce: nonce, forFeeEstimation: true)
let tx2WithoutSignature = StarknetInvokeTransactionV3(senderAddress: tx2.senderAddress, calldata: tx2.calldata, signature: [], resourceBounds: tx2.resourceBounds, nonce: Felt(nonce.value + 1)!, forFeeEstimation: true)

let _ = try await provider.send(request: RequestBuilder.estimateFee(for: [tx1WithoutSignature, tx2WithoutSignature], simulationFlags: [.skipValidate]))
_ = try await provider.send(request: RequestBuilder.estimateFee(for: [tx1WithoutSignature, tx2WithoutSignature], simulationFlags: [.skipValidate]))
}

func testEstimateDeployAccountV3Fee() async throws {
Expand All @@ -259,11 +259,11 @@ final class ProviderTests: XCTestCase {

let tx = try newAccount.signDeployAccountV3(classHash: accountContractClassHash, calldata: [newPublicKey], salt: .zero, params: params, forFeeEstimation: true)

let _ = try await provider.send(request: RequestBuilder.estimateFee(for: tx))
_ = try await provider.send(request: RequestBuilder.estimateFee(for: tx))

let txWithoutSignature = StarknetDeployAccountTransactionV3(signature: [], resourceBounds: tx.resourceBounds, nonce: tx.nonce, contractAddressSalt: tx.contractAddressSalt, constructorCalldata: tx.constructorCalldata, classHash: tx.classHash, forFeeEstimation: true)

let _ = try await provider.send(request: RequestBuilder.estimateFee(for: txWithoutSignature, simulationFlags: [.skipValidate]))
_ = try await provider.send(request: RequestBuilder.estimateFee(for: txWithoutSignature, simulationFlags: [.skipValidate]))
}

func testEstimateMessageFee() async throws {
Expand Down Expand Up @@ -370,7 +370,7 @@ final class ProviderTests: XCTestCase {
XCTAssertEqual(try transactionsResponse[0].get().transaction.hash, invokeTx.transaction.hash!)

do {
let _ = try transactionsResponse[1].get().transaction.hash
_ = try transactionsResponse[1].get().transaction.hash
XCTFail("Fetching transaction with nonexistent hash should fail")
} catch let error as StarknetProviderError {
switch error {
Expand Down
30 changes: 15 additions & 15 deletions Tests/StarknetTests/Utils/DevnetClient/DevnetClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func makeDevnetClient() -> DevnetClientProtocol {
accountDirectory = URL(string: tmpPath)!
}

public func start() async throws {
func start() async throws {
guard !self.devnetPath.isEmpty, !self.scarbPath.isEmpty, !self.snCastPath.isEmpty else {
throw DevnetClientError.environmentVariablesNotSet
}
Expand Down Expand Up @@ -287,7 +287,7 @@ func makeDevnetClient() -> DevnetClientProtocol {
try fileManager.copyItem(at: accountsResourcePath, to: newAccountsPath)
}

public func close() {
func close() {
guard devnetProcess != nil else {
return
}
Expand All @@ -302,7 +302,7 @@ func makeDevnetClient() -> DevnetClientProtocol {
self.devnetProcess = nil
}

public func prefundAccount(address: Felt, amount: BigUInt, unit: StarknetPriceUnit) async throws {
func prefundAccount(address: Felt, amount: BigUInt, unit: StarknetPriceUnit) async throws {
try guardDevnetIsRunning()

let url = URL(string: mintUrl)!
Expand Down Expand Up @@ -347,7 +347,7 @@ func makeDevnetClient() -> DevnetClientProtocol {
}
}

public func createDeployAccount(
func createDeployAccount(
name: String,
classHash: Felt = DevnetClientConstants.accountContractClassHash,
salt: Felt? = nil
Expand All @@ -365,7 +365,7 @@ func makeDevnetClient() -> DevnetClientProtocol {
)
}

public func createAccount(
func createAccount(
name: String,
classHash: Felt = DevnetClientConstants.accountContractClassHash,
salt: Felt? = nil,
Expand Down Expand Up @@ -400,7 +400,7 @@ func makeDevnetClient() -> DevnetClientProtocol {
)
}

public func deployAccount(
func deployAccount(
name: String,
classHash _: Felt = DevnetClientConstants.accountContractClassHash,
prefund: Bool = true
Expand Down Expand Up @@ -437,7 +437,7 @@ func makeDevnetClient() -> DevnetClientProtocol {
return result
}

public func declareDeployContract(
func declareDeployContract(
contractName: String,
constructorCalldata: [Felt] = [],
salt: Felt? = nil,
Expand All @@ -459,7 +459,7 @@ func makeDevnetClient() -> DevnetClientProtocol {
)
}

public func declareContract(contractName: String) async throws -> DeclareContractResult {
func declareContract(contractName: String) async throws -> DeclareContractResult {
try guardDevnetIsRunning()

if let result = declaredContractsAtName[contractName] {
Expand All @@ -486,7 +486,7 @@ func makeDevnetClient() -> DevnetClientProtocol {
return result
}

public func deployContract(
func deployContract(
classHash: Felt,
constructorCalldata: [Felt] = [],
salt: Felt? = nil,
Expand Down Expand Up @@ -531,7 +531,7 @@ func makeDevnetClient() -> DevnetClientProtocol {
return result
}

public func invokeContract(
func invokeContract(
contractAddress: Felt,
function: String,
calldata: [Felt] = []
Expand Down Expand Up @@ -559,7 +559,7 @@ func makeDevnetClient() -> DevnetClientProtocol {
return InvokeContractResult(transactionHash: response.transactionHash)
}

public func isRunning() -> Bool {
func isRunning() -> Bool {
if let devnetProcess, devnetProcess.isRunning {
return true
}
Expand Down Expand Up @@ -626,7 +626,7 @@ func makeDevnetClient() -> DevnetClientProtocol {

typealias AccountDetailsResponse = [String: [String: AccountDetails]]

public func readAccountDetails(accountName: String) throws -> AccountDetails {
func readAccountDetails(accountName: String) throws -> AccountDetails {
let filename = "\(accountDirectory)/starknet_open_zeppelin_accounts.json"

let contents = try String(contentsOfFile: filename)
Expand All @@ -651,19 +651,19 @@ func makeDevnetClient() -> DevnetClientProtocol {
try await Task.sleep(nanoseconds: seconds * UInt64(Double(NSEC_PER_SEC)))
}

public func assertTransactionSucceeded(transactionHash: Felt) async throws {
func assertTransactionSucceeded(transactionHash: Felt) async throws {
guard try await isTransactionSuccessful(transactionHash: transactionHash) == true else {
throw DevnetClientError.transactionFailed
}
}

public func assertTransactionFailed(transactionHash: Felt) async throws {
func assertTransactionFailed(transactionHash: Felt) async throws {
guard try await isTransactionSuccessful(transactionHash: transactionHash) == false else {
throw DevnetClientError.transactionSucceeded
}
}

public func isTransactionSuccessful(transactionHash: Felt) async throws -> Bool {
func isTransactionSuccessful(transactionHash: Felt) async throws -> Bool {
let params = GetTransactionByHashParams(hash: transactionHash)
let rpcPayload = JsonRpcPayload(method: .getTransactionReceipt, params: .getTransactionByHash(params))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct AccountDetails: Codable {
self.salt = salt
}

public init(from decoder: Decoder) throws {
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)

self.privateKey = try container.decode(Felt.self, forKey: .privateKey)
Expand Down Expand Up @@ -85,7 +85,7 @@ struct PrefundPayload: Codable {
case unit
}

public func encode(to encoder: Encoder) throws {
func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(address, forKey: .address)
try container.encode(amount.description, forKey: .amount)
Expand All @@ -107,7 +107,7 @@ struct DevnetReceipt: Decodable {
case finalityStatus = "finality_status"
}

public var isSuccessful: Bool {
var isSuccessful: Bool {
switch status {
case nil:
executionStatus == .succeeded && (finalityStatus == .acceptedL1 || finalityStatus == .acceptedL2)
Expand All @@ -116,7 +116,7 @@ struct DevnetReceipt: Decodable {
}
}

public init(from decoder: Decoder) throws {
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)

self.status = try container.decodeIfPresent(StarknetTransactionStatus.self, forKey: .status)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ enum SnCastResponseWrapper: Decodable {
case deploy(DeploySnCastResponse)
case invoke(InvokeSnCastResponse)

public var response: any SnCastResponse {
var response: any SnCastResponse {
switch self {
case let .accountCreate(res):
res
Expand All @@ -29,7 +29,7 @@ enum SnCastResponseWrapper: Decodable {
}
}

public init(from decoder: Decoder) throws {
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: Keys.self)
let command = try container.decode(SnCastCommand.self, forKey: .command)
let error = try container.decodeIfPresent(String.self, forKey: .error)
Expand Down