-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathEthPlatform.py
More file actions
executable file
·237 lines (204 loc) · 8.12 KB
/
EthPlatform.py
File metadata and controls
executable file
·237 lines (204 loc) · 8.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""
This module contains an Ethereum testing platform for federated learning.
"""
import functools
import logging
from dataclasses import dataclass

from solcx import compile_source
from web3 import Web3
@dataclass
class ContractInfo:
    """
    Everything a client needs to talk to a deployed contract.

    Instances are produced by ``compileContract`` once deployment succeeds.
    """

    # Contract identifier as reported by the Solidity compiler output.
    contract_id: str
    # Application Binary Interface (ABI) of the compiled contract.
    abi: list
    # Address the contract was deployed to.
    address: str
def compileContract(w3, filename, *vargs):
    """
    Compile the Solidity file *filename* and deploy it from w3's default account.

    :param w3: web3.py instance used for deployment. The constructor
        transaction is sent without an explicit sender, so web3 uses
        ``w3.eth.default_account``.
    :param filename: path of the ``.sol`` source file to compile.
    :param vargs: extra positional arguments forwarded to the contract
        constructor.
    :returns: ``ContractInfo`` holding the compiled contract id, its ABI and
        the contract address taken from the deployment receipt.
    """
    # Read the file that was actually requested. This previously opened a
    # hard-coded "FL.sol" and silently ignored ``filename``.
    with open(filename, "r", encoding="utf-8") as f:
        solidity_code = f.read()
    compiled_sol = compile_source(solidity_code, output_values=["abi", "bin"])
    # dict.popitem() removes the most recently inserted entry, so when the
    # source defines several contracts this deploys the last one in the
    # compiler output.
    # NOTE(review): confirm that is the intended contract if the .sol file grows.
    contract_id, contract_interface = compiled_sol.popitem()
    bytecode = contract_interface["bin"]
    abi = contract_interface["abi"]
    contract = w3.eth.contract(abi=abi, bytecode=bytecode)
    # Submit the deployment transaction and block until it is mined.
    tx_hash = contract.constructor(*vargs).transact()
    tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)
    # The receipt of a contract-creation transaction carries the new address.
    return ContractInfo(contract_id, abi, tx_receipt.contractAddress)
def useAccount(func):
    """
    Decorate an ``Account`` method so it runs with ``self.account`` as web3's default account.

    Before delegating to *func*, the wrapper assigns ``self.account`` to
    ``EthPlatform.w3.eth.default_account``. Transactions the method sends
    without an explicit sender are therefore issued from that account. The
    previous default account is not restored afterwards.

    ``functools.wraps`` keeps the wrapped method's name, docstring and
    ``__wrapped__`` reference, which helps debugging and introspection.
    """
    @functools.wraps(func)
    def wrapper(self, *vargs, **kwargs):
        EthPlatform.w3.eth.default_account = self.account
        return func(self, *vargs, **kwargs)
    return wrapper
# Module-level logger for the event-processing helpers below. The original code
# called ``log.warning`` without ever defining ``log``, which raised NameError
# whenever a duplicate or stale event was encountered.
log = logging.getLogger(__name__)


class EthPlatform:
    """
    Ethereum testing platform for federated learning.

    State is shared through class attributes: one ``Web3`` instance (``w3``)
    and the ``ContractInfo`` of the deployed contract (``contractInfo``).
    Individual participants are represented by ``EthPlatform.Account`` objects.
    """

    # Solidity source file compiled and deployed by ``Account.deploy``.
    contractFilename = "FL.sol"
    # ContractInfo of the deployed contract; set by ``Account.deploy``.
    contractInfo = None
    # Shared Web3 instance; created by ``initAccounts``.
    w3 = None

    @staticmethod
    def initAccounts(amount: int):
        """
        Start a fresh ``EthereumTesterProvider`` chain and wrap its first accounts.

        Replaces ``EthPlatform.w3`` with a new ``Web3`` instance.

        :param amount: maximum number of accounts to return.
        :returns: a list of at most ``amount`` ``EthPlatform.Account`` objects.
            The list is shorter if the provider exposes fewer accounts, and
            empty if ``amount <= 0``.
        """
        EthPlatform.w3 = Web3(Web3.EthereumTesterProvider())
        # Query the account list once. zip() with range() caps the count at
        # ``amount``, so a separate min()/len() is unnecessary.
        accounts = EthPlatform.w3.eth.accounts
        return [EthPlatform.Account(account)
                for _, account in zip(range(amount), accounts)]

    class Account:
        """
        Wraps one chain account with helpers for interacting with the shared contract.
        """

        def __init__(self, account):
            # Account address. Every method decorated with ``useAccount``
            # installs it as ``w3.eth.default_account``.
            self.account = account

        @useAccount
        def deploy(self, *vargs):
            """
            Deploy ``EthPlatform.contractFilename`` from this account and bind to it.

            Positional arguments are forwarded to the contract constructor. The
            resulting ``ContractInfo`` is stored in ``EthPlatform.contractInfo``
            so that other accounts can call ``obtainContract``.
            """
            EthPlatform.contractInfo = compileContract(
                EthPlatform.w3, EthPlatform.contractFilename, *vargs)
            self.obtainContract()

        def obtainContract(self):
            """
            Set ``self.contract`` to a web3 contract object for the deployed contract.

            Must be called after some account has run ``deploy``, because it
            reads the address and ABI from ``EthPlatform.contractInfo``.
            """
            self.contract = EthPlatform.w3.eth.contract(
                address=EthPlatform.contractInfo.address,
                abi=EthPlatform.contractInfo.abi)

        # ----- private helpers ----------------------------------------------

        def _transact(self, functionName, *vargs):
            """Send contract function *functionName*(*vargs*) as a transaction; return the mined receipt."""
            contractFunction = getattr(self.contract.functions, functionName)
            tx_hash = contractFunction(*vargs).transact()
            return EthPlatform.w3.eth.wait_for_transaction_receipt(tx_hash)

        def _collectEvents(self, receipts, eventName, dataKey, kind, epoch=None):
            """
            Decode one *eventName* event per receipt into ``(size, data)`` pairs.

            :param receipts: transaction receipts. Each is expected to contain
                exactly one *eventName* event (enforced with ``assert``).
            :param eventName: name of the contract event to decode.
            :param dataKey: event argument that holds the payload.
            :param kind: human-readable label used in warning messages.
            :param epoch: if given, events whose ``epoch`` argument differs
                from it are skipped with a warning.
            :returns: list of ``(args["size"], args[dataKey])`` tuples in
                receipt order. Only the first accepted event per sender
                (``args["from"]``) is kept; later ones are skipped with a
                warning.
            """
            event = getattr(self.contract.events, eventName)
            events = []
            seenAddresses = set()
            for tx_receipt in receipts:
                # NOTE: web3.py v6 renamed processReceipt to process_receipt.
                logs = event().processReceipt(tx_receipt)
                assert len(logs) == 1, (
                    f"expected exactly one {eventName} event per receipt, got {len(logs)}")
                args = logs[0]["args"]
                address = args["from"]
                if epoch is not None:
                    # NOTE(review): assumes the event carries an ``epoch``
                    # argument. When it does not, the update is accepted as-is.
                    # This is checked before the duplicate test so that a stale
                    # update does not use up the sender's slot for this epoch.
                    updateEpoch = args.get("epoch", epoch)
                    if updateEpoch != epoch:
                        log.warning(f"Ignoring update with incorrect epoch {updateEpoch} from {address}")
                        continue
                if address in seenAddresses:
                    log.warning(f"Ignoring repeated {kind} from address {address}")
                    continue
                seenAddresses.add(address)
                events.append((args["size"], args[dataKey]))
            return events

        # ----- event decoding -----------------------------------------------

        @useAccount
        def getUpdateEvents(self, receipts):
            """
            Return ``(size, model)`` pairs from the ``LocalUpdate`` events in *receipts*.

            Repeated updates from the same sender are skipped with a warning.
            Updates whose ``epoch`` argument (if present) differs from the
            contract's current epoch are skipped as well.
            """
            # The epoch used to be compared with itself, so the stale-update
            # check never fired.
            return self._collectEvents(
                receipts, "LocalUpdate", "model", "update", epoch=self.getEpoch())

        @useAccount
        def getMeanEvents(self, receipts):
            """
            Return ``(size, data)`` pairs from the ``LocalMeans`` events in *receipts*.

            Repeated reports from the same sender are skipped with a warning.
            """
            return self._collectEvents(receipts, "LocalMeans", "data", "mean report")

        @useAccount
        def getStdEvents(self, receipts):
            """
            Return ``(size, data)`` pairs from the ``LocalStds`` events in *receipts*.

            Repeated reports from the same sender are skipped with a warning.
            """
            # The warning used to say "mean report" (copy-paste from getMeanEvents).
            return self._collectEvents(receipts, "LocalStds", "data", "std report")

        # ----- transactions (each returns the mined receipt) ----------------

        @useAccount
        def globalUpdate(self, modelBytes):
            """
            Update the global model after weight averaging.

            Should be called by the owner only.
            """
            return self._transact("globalUpdate", modelBytes)

        @useAccount
        def localUpdate(self, *vargs):
            """Trigger a local update event by sending ``localUpdate(*vargs)``."""
            return self._transact("localUpdate", *vargs)

        @useAccount
        def globalMeans(self, meanBytes):
            """
            Update the global means after mean averaging.

            Should be called by the owner only.
            """
            return self._transact("globalMeans", meanBytes)

        @useAccount
        def localMeans(self, *vargs):
            """Trigger a local means event by sending ``localMeans(*vargs)``."""
            return self._transact("localMeans", *vargs)

        @useAccount
        def globalStds(self, stdBytes):
            """
            Update the global stds after std averaging.

            Should be called by the owner only.
            """
            return self._transact("globalStds", stdBytes)

        @useAccount
        def localStds(self, *vargs):
            """Trigger a local stds event by sending ``localStds(*vargs)``."""
            return self._transact("localStds", *vargs)

        # ----- read-only accessors --------------------------------------------
        # These use .call(), which sends no transaction, so they don't need
        # ``useAccount``.

        def getModel(self):
            """Return the result of the contract's ``getModel()`` call."""
            return self.contract.functions.getModel().call()

        def getEpoch(self):
            """Return the result of the contract's ``getEpoch()`` call."""
            return self.contract.functions.getEpoch().call()

        def getDataSize(self):
            """Return the result of the contract's ``getDataSize()`` call."""
            return self.contract.functions.getDataSize().call()

        def getMeans(self):
            """Return the result of the contract's ``getMeans()`` call."""
            return self.contract.functions.getMeans().call()

        def getStds(self):
            """Return the result of the contract's ``getStds()`` call."""
            return self.contract.functions.getStds().call()