Skip to content

Commit 9571fde

Browse files
committed
error message for forbidden output gates
1 parent ede04f8 commit 9571fde

File tree

2 files changed

+91
-69
lines changed

2 files changed

+91
-69
lines changed

src/abycore/circuit/circuit.cpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,21 @@ gate_specific Circuit::GetGateSpecificOutput(uint32_t gateid) {
9696
}
9797

9898
uint32_t Circuit::GetOutputGateValue(uint32_t gateid, UGATE_T*& outval) {
99-
assert(m_vGates[gateid].instantiated);
99+
//assert(m_vGates[gateid].instantiated);
100+
if(!m_vGates[gateid].instantiated){
101+
std::cerr << "Output not allowed for this role. Returned value will be wrong!" << std::endl;
102+
return 0;
103+
}
100104
outval = m_vGates[gateid].gs.val;
101105
return m_vGates[gateid].nvals;
102106
}
103107

104108
UGATE_T* Circuit::GetOutputGateValue(uint32_t gateid) {
105-
assert(m_vGates[gateid].instantiated);
109+
//assert(m_vGates[gateid].instantiated);
110+
if(!m_vGates[gateid].instantiated){
111+
std::cerr << "Output not allowed for this role! Returned value will be wrong!" << std::endl;
112+
return nullptr;
113+
}
106114
return m_vGates[gateid].gs.val;
107115
}
108116

src/abycore/circuit/share.cpp

+81-67
Original file line numberDiff line numberDiff line change
@@ -118,76 +118,82 @@ uint8_t* boolshare::get_clear_value_ptr() {
118118
uint32_t nvals = m_ccirc->GetNumVals(m_ngateids[0]);
119119
uint32_t bytelen = ceil_divide(m_ngateids.size(), 8);
120120

121-
out = (uint8_t*) calloc(ceil_divide(m_ngateids.size(), 8) * nvals, sizeof(uint8_t));
121+
out = (uint8_t*)calloc(ceil_divide(m_ngateids.size(), 8) * nvals, sizeof(uint8_t));
122122

123123
for (uint32_t i = 0, ibytes; i < m_ngateids.size(); i++) {
124124
assert(nvals == m_ccirc->GetNumVals(m_ngateids[i]));
125125
gatevals = m_ccirc->GetOutputGateValue(m_ngateids[i]);
126126

127-
ibytes = i / 8;
128-
for (uint32_t j = 0; j < nvals; j++) {
129-
out[j * bytelen + ibytes] += (((gatevals[j / 64] >> (j % 64)) & 0x01) << (i & 0x07));
127+
// only write sth if there are values (output might not be for this party)
128+
if (gatevals != nullptr) {
129+
ibytes = i / 8;
130+
for (uint32_t j = 0; j < nvals; j++) {
131+
out[j * bytelen + ibytes] += (((gatevals[j / 64] >> (j % 64)) & 0x01) << (i & 0x07));
132+
}
130133
}
131134
}
132135
return out;
133136
}
134137

135-
136138
//TODO This method will only work up to a bitlength of 32
137-
void boolshare::get_clear_value_vec(uint32_t** vec, uint32_t *bitlen, uint32_t *nvals) {
139+
void boolshare::get_clear_value_vec(uint32_t** vec, uint32_t* bitlen, uint32_t* nvals) {
138140
assert(m_ngateids.size() <= sizeof(uint32_t) * 8);
139141
UGATE_T* outvalptr;
140142
uint32_t gnvals;
141143

142-
*nvals = 1;
144+
*nvals = 0;
143145
*nvals = m_ccirc->GetOutputGateValue(m_ngateids[0], outvalptr);
144-
*vec = (uint32_t*) calloc(*nvals, sizeof(uint32_t));
145-
146-
for (uint32_t j = 0; j < *nvals; j++) {
147-
(*vec)[j] = (outvalptr[j / 64] >> (j % 64)) & 0x01;
148-
}
149146

150-
for (uint32_t i = 1; i < m_ngateids.size(); i++) {
151-
gnvals = m_ccirc->GetOutputGateValue(m_ngateids[i], outvalptr);
152-
assert(*nvals == gnvals);
147+
// only continue if there are values (output might not be for this party)
148+
if (*nvals > 0) {
149+
*vec = (uint32_t*)calloc(*nvals, sizeof(uint32_t));
153150

154151
for (uint32_t j = 0; j < *nvals; j++) {
155-
(*vec)[j] = (*vec)[j] + (((outvalptr[j / 64] >> (j % 64)) & 0x01) << i);
152+
(*vec)[j] = (outvalptr[j / 64] >> (j % 64)) & 0x01;
156153
}
154+
155+
for (uint32_t i = 1; i < m_ngateids.size(); i++) {
156+
gnvals = m_ccirc->GetOutputGateValue(m_ngateids[i], outvalptr);
157+
assert(*nvals == gnvals); //check that all wires have same nvals
158+
159+
for (uint32_t j = 0; j < *nvals; j++) {
160+
(*vec)[j] = (*vec)[j] + (((outvalptr[j / 64] >> (j % 64)) & 0x01) << i);
161+
}
162+
}
163+
*bitlen = m_ngateids.size();
157164
}
158-
*bitlen = m_ngateids.size();
159-
//return nvals;
160165
}
161166

162-
163167
//TODO: copied from 32 bits. Put template in and test later on!
164168
//TODO This method will only work up to a bitlength of 64
165-
void boolshare::get_clear_value_vec(uint64_t** vec, uint32_t *bitlen, uint32_t *nvals) {
169+
void boolshare::get_clear_value_vec(uint64_t** vec, uint32_t* bitlen, uint32_t* nvals) {
166170
assert(m_ngateids.size() <= sizeof(uint64_t) * 8);
167171
UGATE_T* outvalptr;
168172
uint32_t gnvals;
169173

170-
*nvals = 1;
174+
*nvals = 0;
171175
*nvals = m_ccirc->GetOutputGateValue(m_ngateids[0], outvalptr);
172-
*vec = (uint64_t*) calloc(*nvals, sizeof(uint64_t));
173-
174-
for (uint32_t j = 0; j < *nvals; j++) {
175-
(*vec)[j] = (outvalptr[j / 64] >> (j % 64)) & 0x01;
176-
}
177176

178-
for (uint32_t i = 1; i < m_ngateids.size(); i++) {
179-
gnvals = m_ccirc->GetOutputGateValue(m_ngateids[i], outvalptr);
180-
assert(*nvals == gnvals);
177+
// only continue if there are values (output might not be for this party)
178+
if (*nvals > 0) {
179+
*vec = (uint64_t*)calloc(*nvals, sizeof(uint64_t));
181180

182181
for (uint32_t j = 0; j < *nvals; j++) {
183-
(*vec)[j] = (*vec)[j] + (((outvalptr[j / 64] >> (j % 64)) & 0x01) << i);
182+
(*vec)[j] = (outvalptr[j / 64] >> (j % 64)) & 0x01;
183+
}
184+
185+
for (uint32_t i = 1; i < m_ngateids.size(); i++) {
186+
gnvals = m_ccirc->GetOutputGateValue(m_ngateids[i], outvalptr);
187+
assert(*nvals == gnvals); //check that all wires have same nvals
188+
189+
for (uint32_t j = 0; j < *nvals; j++) {
190+
(*vec)[j] = (*vec)[j] + (((outvalptr[j / 64] >> (j % 64)) & 0x01) << i);
191+
}
184192
}
193+
*bitlen = m_ngateids.size();
185194
}
186-
*bitlen = m_ngateids.size();
187-
//return nvals;
188195
}
189196

190-
191197
yao_fields* boolshare::get_internal_yao_keys() {
192198
yao_fields* out;
193199
uint32_t nvals = m_ccirc->GetNumVals(m_ngateids[0]);
@@ -219,35 +225,44 @@ yao_fields* boolshare::get_internal_yao_keys() {
219225
uint8_t* arithshare::get_clear_value_ptr() {
220226
UGATE_T* gate_val;
221227
uint32_t nvals = m_ccirc->GetOutputGateValue(m_ngateids[0], gate_val);
222-
uint8_t* out = (uint8_t*) malloc(nvals * sizeof(uint32_t));
223-
for (uint32_t i = 0; i < nvals; i++) {
224-
((uint32_t*) out)[i] = (uint32_t) gate_val[i];
228+
if (nvals > 0) {
229+
uint8_t* out = (uint8_t*)malloc(nvals * sizeof(uint32_t));
230+
for (uint32_t i = 0; i < nvals; i++) {
231+
((uint32_t*)out)[i] = (uint32_t)gate_val[i];
232+
}
233+
return out;
234+
}
235+
else{
236+
return nullptr;
225237
}
226-
return out;
227238
}
228239

229240
void arithshare::get_clear_value_vec(uint32_t** vec, uint32_t* bitlen, uint32_t* nvals) {
230241
//assert(m_ngateids.size() <= sizeof(uint32_t) * 8);
231242

232243
UGATE_T* gate_val;
233244
*nvals = 0;
234-
for(uint32_t i = 0; i < m_ngateids.size(); i++) {
235-
(*nvals) += m_ccirc->GetOutputGateValue(m_ngateids[i], gate_val);
236-
}
237-
uint32_t sharebytes = ceil_divide(m_ccirc->GetShareBitLen(), 8);
238245

239-
//*nvals = m_ccirc->GetOutputGateValue(m_ngateids[0], gate_val);
240-
*vec = (uint32_t*) calloc(*nvals, sizeof(uint32_t));
246+
// only continue if there are values (output might not be for this party)
247+
if (m_ccirc->GetOutputGateValue(m_ngateids[0], gate_val) > 0) {
248+
for (uint32_t i = 0; i < m_ngateids.size(); i++) {
249+
(*nvals) += m_ccirc->GetOutputGateValue(m_ngateids[i], gate_val);
250+
}
251+
uint32_t sharebytes = ceil_divide(m_ccirc->GetShareBitLen(), 8);
252+
253+
//*nvals = m_ccirc->GetOutputGateValue(m_ngateids[0], gate_val);
254+
*vec = (uint32_t*)calloc(*nvals, sizeof(uint32_t));
241255

242-
for(uint32_t i = 0, tmpctr=0, tmpnvals; i < m_ngateids.size(); i++) {
243-
tmpnvals = m_ccirc->GetOutputGateValue(m_ngateids[i], gate_val);
244-
//cout << m_ngateids[i] << " gateval = " << gate_val[0] << ", nvals = " << *nvals << ", sharebitlen = " << m_ccirc->GetShareBitLen() << endl;
245-
for(uint32_t j = 0; j < tmpnvals; j++, tmpctr++) {
246-
memcpy((*vec)+tmpctr, ((uint8_t*) gate_val)+(j*sharebytes), sharebytes);
256+
for (uint32_t i = 0, tmpctr = 0, tmpnvals; i < m_ngateids.size(); i++) {
257+
tmpnvals = m_ccirc->GetOutputGateValue(m_ngateids[i], gate_val);
258+
//cout << m_ngateids[i] << " gateval = " << gate_val[0] << ", nvals = " << *nvals << ", sharebitlen = " << m_ccirc->GetShareBitLen() << endl;
259+
for (uint32_t j = 0; j < tmpnvals; j++, tmpctr++) {
260+
memcpy((*vec) + tmpctr, ((uint8_t*)gate_val) + (j * sharebytes), sharebytes);
261+
}
247262
}
248-
}
249263

250-
*bitlen = m_ccirc->GetShareBitLen();
264+
*bitlen = m_ccirc->GetShareBitLen();
265+
}
251266
}
252267

253268
//TODO: copied from 32 bits. Put template in and test later on!
@@ -256,23 +271,26 @@ void arithshare::get_clear_value_vec(uint64_t** vec, uint32_t* bitlen, uint32_t*
256271

257272
UGATE_T* gate_val;
258273
*nvals = 0;
259-
for(uint32_t i = 0; i < m_ngateids.size(); i++) {
260-
(*nvals) += m_ccirc->GetOutputGateValue(m_ngateids[i], gate_val);
261-
}
262-
uint32_t sharebytes = ceil_divide(m_ccirc->GetShareBitLen(), 8);
274+
// only continue if there are values (output might not be for this party)
275+
if (m_ccirc->GetOutputGateValue(m_ngateids[0], gate_val) > 0) {
276+
for (uint32_t i = 0; i < m_ngateids.size(); i++) {
277+
(*nvals) += m_ccirc->GetOutputGateValue(m_ngateids[i], gate_val);
278+
}
279+
uint32_t sharebytes = ceil_divide(m_ccirc->GetShareBitLen(), 8);
263280

264-
//*nvals = m_ccirc->GetOutputGateValue(m_ngateids[0], gate_val);
265-
*vec = (uint64_t*) calloc(*nvals, sizeof(uint64_t));
281+
//*nvals = m_ccirc->GetOutputGateValue(m_ngateids[0], gate_val);
282+
*vec = (uint64_t*)calloc(*nvals, sizeof(uint64_t));
266283

267-
for(uint32_t i = 0, tmpctr=0, tmpnvals; i < m_ngateids.size(); i++) {
268-
tmpnvals = m_ccirc->GetOutputGateValue(m_ngateids[i], gate_val);
269-
//cout << m_ngateids[i] << " gateval = " << gate_val[0] << ", nvals = " << *nvals << ", sharebitlen = " << m_ccirc->GetShareBitLen() << endl;
270-
for(uint32_t j = 0; j < tmpnvals; j++, tmpctr++) {
271-
memcpy((*vec)+tmpctr, ((uint8_t*) gate_val)+(j*sharebytes), sharebytes);
284+
for (uint32_t i = 0, tmpctr = 0, tmpnvals; i < m_ngateids.size(); i++) {
285+
tmpnvals = m_ccirc->GetOutputGateValue(m_ngateids[i], gate_val);
286+
//cout << m_ngateids[i] << " gateval = " << gate_val[0] << ", nvals = " << *nvals << ", sharebitlen = " << m_ccirc->GetShareBitLen() << endl;
287+
for (uint32_t j = 0; j < tmpnvals; j++, tmpctr++) {
288+
memcpy((*vec) + tmpctr, ((uint8_t*)gate_val) + (j * sharebytes), sharebytes);
289+
}
272290
}
273-
}
274291

275-
*bitlen = m_ccirc->GetShareBitLen();
292+
*bitlen = m_ccirc->GetShareBitLen();
293+
}
276294
}
277295

278296
share* arithshare::get_share_from_wire_id(uint32_t shareid) {
@@ -288,7 +306,3 @@ share* boolshare::get_share_from_wire_id(uint32_t shareid) {
288306
new_shr->set_wire_id(shareid,get_wire_id(shareid));
289307
return new_shr;
290308
}
291-
292-
293-
294-

0 commit comments

Comments
 (0)