Skip to content

Commit 3cb66f5

Browse files
committed
[otbn,sim] Add masking accelerator interface simulator implementation
This adds the simulator implementation of the masking accelerator interface (MAI). The MAI allows the OTBN to offload A2B, B2A or secAdd computations to hardened accelerators. Signed-off-by: Pascal Etterli <[email protected]>
1 parent 9e03e28 commit 3cb66f5

File tree

4 files changed

+341
-2
lines changed

4 files changed

+341
-2
lines changed

hw/ip/otbn/dv/otbnsim/sim/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class ErrBits(IntEnum):
3232
KEY_INVALID = 1 << 5
3333
RND_REP_CHK_FAIL = 1 << 6
3434
RND_FIPS_CHK_FAIL = 1 << 7
35+
MAI_ERROR = 1 << 8
3536
IMEM_INTG_VIOLATION = 1 << 16
3637
DMEM_INTG_VIOLATION = 1 << 17
3738
REG_INTG_VIOLATION = 1 << 18

hw/ip/otbn/dv/otbnsim/sim/insn.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,11 +434,18 @@ def execute(self, state: OTBNState) -> Optional[Iterator[None]]:
434434
# There's a pending EDN request. Stall for a cycle.
435435
yield None
436436

437-
# At this point, the CSR is ready. Read, update and write back to grs1.
437+
# At this point, the CSR is ready. Read it to grd.
438438
old_val = state.read_csr(self.csr)
439-
new_val = old_val | bits_to_set
440439
state.gprs.get_reg(self.grd).write_unsigned(old_val)
440+
441+
# If CSR should be updated, compute update, check if update is allowed
442+
# and write it back.
441443
if self.grs1 != 0:
444+
new_val = old_val | bits_to_set
445+
if self.csr == 0x7f0:
446+
if not state.mai.is_valid_ctrl_change(new_val):
447+
state.stop_at_end_of_cycle(ErrBits.MAI_ERROR)
448+
return None
442449
state.write_csr(self.csr, new_val)
443450

444451
return None
@@ -479,6 +486,12 @@ def execute(self, state: OTBNState) -> Optional[Iterator[None]]:
479486
old_val = state.read_csr(self.csr)
480487
state.gprs.get_reg(self.grd).write_unsigned(old_val)
481488

489+
# Check if the write to MAI_CTRL is allowed.
490+
if self.csr == 0x7f0:
491+
if not state.mai.is_valid_ctrl_change(new_val):
492+
state.stop_at_end_of_cycle(ErrBits.MAI_ERROR)
493+
return None
494+
482495
state.write_csr(self.csr, new_val)
483496
return None
484497

@@ -1271,6 +1284,13 @@ def execute(self, state: OTBNState) -> None:
12711284
state.stop_at_end_of_cycle(ErrBits.ILLEGAL_INSN)
12721285
return None
12731286

1287+
# Check if MAI is ready to accept new inputs. If not stop with MAI
1288+
# error.
1289+
if self.wsr in [12, 13, 14, 15]:
1290+
if not state.mai.ready_for_inputs():
1291+
state.stop_at_end_of_cycle(ErrBits.MAI_ERROR)
1292+
return None
1293+
12741294
val = state.wdrs.get_reg(self.wrs).read_unsigned()
12751295
state.wsrs.write_at_idx(self.wsr, val)
12761296

hw/ip/otbn/dv/otbnsim/sim/mai.py

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
# Copyright lowRISC contributors (OpenTitan project).
2+
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from typing import Optional, Tuple
6+
from collections import deque
7+
8+
from .csr import CSRFile, MaiOperation
9+
from .wsr import WSRFile, DumbWSR
10+
11+
# The masking accelerator interface (MAI) emulates the behavior of the interface and the actual
12+
# accelerators.
13+
14+
# Enable or disable assertions which check that the inputs and outputs of the accelerators
15+
# meet certain constraints (e.g., being smaller than the modulus).
16+
CHECK_ACCELERATOR_CONSTRAINTS = False
17+
18+
19+
class MaskingAccelerator:
    '''Base model of a masking accelerator with a simple fixed-length pipeline.

    Operations are pushed into the pipeline, every call to step() advances the pipeline by
    one stage, and pop() exposes whatever currently sits in the last stage. Subclasses
    provide the actual operation by overriding _compute().
    '''

    def __init__(self, latency: int, mod_wsr: DumbWSR) -> None:
        '''Create an accelerator with an empty pipeline.

        latency is the latency of the accelerator in cycles (the number of pipeline slots)
        and mod_wsr is the MOD WSR from which the current modulus is read.
        '''
        self.latency = latency
        self.mod_wsr = mod_wsr

        # Every slot holds the two result shares of one operation, or None if the slot is
        # empty. The deque is bounded to the latency, so inserting on the left drops the
        # rightmost slot automatically.
        self.pipeline: deque[Optional[Tuple[int, int]]] = deque(
            [None for _ in range(latency)], maxlen=latency)

    def push(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> bool:
        '''Try to push an operation to the masking accelerator pipeline.

        Returns True if the accelerator accepted the operation. This model features no
        backpressure, so the operation is always accepted.
        '''
        # Drop the leftmost slot first and put the result of the new operation in its
        # place. The result is computed immediately but only becomes visible through pop()
        # after it has travelled through the whole pipeline.
        self.pipeline.popleft()
        result = self._compute(in0_s0, in0_s1, in1_s0, in1_s1)
        self.pipeline.appendleft(result)
        return True

    def pop(self) -> Optional[Tuple[int, int]]:
        '''Return the current output of the pipeline (None if the last slot is empty).'''
        # This only peeks at the last slot. Advancing the pipeline is modelled in step().
        return self.pipeline[-1]

    def step(self) -> None:
        '''Advance the pipeline by one stage.'''
        # There is no backpressure, so the pipeline always advances. The empty slot
        # inserted here is replaced if a new operation is pushed afterwards; the bounded
        # deque drops the rightmost slot on insertion.
        self.pipeline.appendleft(None)

    def is_busy(self) -> bool:
        '''Return True if at least one pipeline slot holds a pending result.'''
        for slot in self.pipeline:
            if slot is not None:
                return True
        return False

    def _modulus(self) -> int:
        '''Return the lower 32 bits of the value currently held by the MOD WSR.'''
        return self.mod_wsr.read_unsigned() & 0xFFFFFFFF

    def _compute(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> Tuple[int, int]:
        '''Compute the two result shares of the masking operation.

        Must be overridden by the concrete accelerators.
        '''
        raise NotImplementedError
74+
75+
76+
class A2BAccelerator(MaskingAccelerator):
    '''Placeholder model of the arithmetic-to-boolean (A2B) mask conversion accelerator.

    The accelerator is instantiated with a pipeline latency of 8.
    '''

    def __init__(self, mod_wsr: DumbWSR):
        super().__init__(8, mod_wsr)

    def _compute(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> Tuple[int, int]:
        '''Remove the arithmetic mask from the first operand and apply a boolean mask.

        Input:  (x - s mod q, s), (x - s) + s mod q = x, 0 <= s < q
        Output: (x XOR r, r),     x XOR r XOR r = x,     0 <= x, r < q < 2^k

        This is a placeholder until the exact design is known. in1_s0 and in1_s1 are not
        used by the A2B accelerator.
        '''
        word_mask = (1 << 32) - 1

        arith_mask = in0_s1
        # The new boolean mask is fixed to zero for simplicity. This also avoids assertion
        # errors which a non-zero constant would cause whenever the modulus is smaller than
        # that constant.
        bool_mask = 0

        unreduced = in0_s0 + arith_mask
        modulus = self._modulus()
        # NOTE(review): the reduction raises ZeroDivisionError if the modulus reads as
        # zero - confirm whether that can happen here.
        secret = unreduced % modulus
        masked_secret = secret ^ bool_mask

        # Optionally, crash if the constraints are not met.
        if CHECK_ACCELERATOR_CONSTRAINTS:
            assert modulus < 2**32
            assert 0 <= arith_mask < modulus
            assert 0 <= bool_mask < modulus
            assert 0 <= secret < modulus

        # Both result shares are limited to 32 bits.
        return (masked_secret & word_mask, bool_mask & word_mask)
107+
108+
109+
class B2AAccelerator(MaskingAccelerator):
    '''Placeholder model of the boolean-to-arithmetic (B2A) mask conversion accelerator.

    The accelerator is instantiated with a pipeline latency of 7.
    '''

    def __init__(self, mod_wsr: DumbWSR):
        super().__init__(7, mod_wsr)

    def _compute(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> Tuple[int, int]:
        '''Remove the boolean mask from the first operand and apply an arithmetic mask.

        Input:  (x XOR r, r),     0 <= x, r < q < 2^k
        Output: (x - s mod q, s), (x - s) + s mod q = x, 0 <= s < q

        This is a placeholder until the exact design is known. in1_s0 and in1_s1 are not
        used by the B2A accelerator.
        '''
        word_mask = (1 << 32) - 1

        # The new arithmetic mask is fixed to zero for simplicity. This also avoids
        # assertion errors which a non-zero constant would cause whenever the modulus is
        # smaller than that constant.
        arith_mask = 0
        bool_mask = in0_s1

        secret = in0_s0 ^ bool_mask
        difference = secret - arith_mask
        modulus = self._modulus()
        # NOTE(review): the reduction raises ZeroDivisionError if the modulus reads as
        # zero - confirm whether that can happen here.
        masked_secret = difference % modulus

        # Optionally, crash if the constraints are not met.
        if CHECK_ACCELERATOR_CONSTRAINTS:
            assert modulus < 2**32
            assert 0 <= in0_s0 < modulus
            assert 0 <= bool_mask < modulus
            assert 0 <= arith_mask < modulus

        # Both result shares are limited to 32 bits.
        return (masked_secret & word_mask, arith_mask & word_mask)
141+
142+
143+
class SecAddModqAccelerator(MaskingAccelerator):
    '''Placeholder model of the secure addition modulo q accelerator.

    The accelerator is instantiated with a pipeline latency of 9.
    '''

    def __init__(self, mod_wsr: DumbWSR):
        super().__init__(9, mod_wsr)

    def _compute(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> Tuple[int, int]:
        '''Unmask both boolean-masked operands, add them modulo q and re-mask the sum.

        Input:  (x XOR r1, r1), (y XOR s1, s1), 0 <= x, y, r1, s1 < q < 2^k
        Output: ((x + y mod q) XOR t, t)

        This is a placeholder until the exact design is known.
        '''
        word_mask = (1 << 32) - 1

        mask_x = in0_s1
        mask_y = in1_s1
        # The new boolean mask is fixed to zero for simplicity. This also avoids assertion
        # errors which a non-zero constant would cause whenever the modulus is smaller than
        # that constant.
        mask_out = 0

        x = in0_s0 ^ mask_x
        y = in1_s0 ^ mask_y
        unreduced = x + y
        modulus = self._modulus()
        # NOTE(review): the reduction raises ZeroDivisionError if the modulus reads as
        # zero - confirm whether that can happen here.
        total = unreduced % modulus
        masked_total = total ^ mask_out

        # Optionally, crash if the constraints are not met.
        if CHECK_ACCELERATOR_CONSTRAINTS:
            assert modulus < 2**32
            assert 0 <= x < modulus
            assert 0 <= y < modulus
            assert 0 <= mask_x < modulus
            assert 0 <= mask_y < modulus
            assert 0 <= mask_out < modulus

        # Both result shares are limited to 32 bits.
        return (masked_total & word_mask, mask_out & word_mask)
177+
178+
179+
class MaskingAcceleratorInterface:
    '''Simulation model of the masking accelerator interface (MAI).

    The interface feeds the words of the MAI input WSRs into the accelerator selected by
    the operation field of MAI_CTRL and collects the results in the MAI result WSRs. The
    busy and ready bits of MAI_STATUS are maintained along the way.
    '''

    # Number of 32-bit words which are dispatched to, and written back from, the
    # accelerator per operation. The dispatch and writeback indices wrap at this value.
    _WORDS_PER_OPERATION = 8

    def __init__(self, csrs: CSRFile, wsrs: WSRFile) -> None:
        '''Bind the interface to the CSR and WSR files and create all accelerators.'''
        self.csrs = csrs
        self.wsrs = wsrs

        # Control and status CSRs
        self.mai_ctrl = csrs.MaiCtrl
        self.mai_status = csrs.MaiStatus

        # Result WSRs
        self.mai_res_s0 = wsrs.MaiResS0
        self.mai_res_s1 = wsrs.MaiResS1

        # Input WSRs
        self.mai_in0_s0 = wsrs.MaiIn0S0
        self.mai_in0_s1 = wsrs.MaiIn0S1
        self.mai_in1_s0 = wsrs.MaiIn1S0
        self.mai_in1_s1 = wsrs.MaiIn1S1

        # All available accelerators live in this dictionary. The active one is selected
        # by the operation field of MAI_CTRL. Changing the operation while an operation is
        # ongoing is not allowed (see is_valid_ctrl_change), so the operation field can
        # simply be read each cycle to find the current accelerator (see _accelerator).
        mod_wsr = wsrs.MOD
        self._all_accelerators = {
            MaiOperation.A2B: A2BAccelerator(mod_wsr),
            MaiOperation.B2A: B2AAccelerator(mod_wsr),
            MaiOperation.SECADD: SecAddModqAccelerator(mod_wsr),
        }

        # Dispatch state: the dispatch logic pushes the inputs into the accelerator.
        self._dispatch_idx = 0
        self.is_dispatching = False

        # Writeback state: the writeback logic moves the results from the accelerator
        # into the output WSRs.
        self._writeback_idx = 0

    def _accelerator(self) -> MaskingAccelerator:
        '''Return the accelerator selected by the operation field of MAI_CTRL.'''
        operation = self.mai_ctrl.read_operation()
        return self._all_accelerators[operation]

    def step(self) -> None:
        '''Advance the MAI simulation by one cycle.

        This is expected to be called before the current instruction executes / steps.
        '''
        # The order of the phases matters. The writeback happens before the pipeline
        # advances to model that a result is available at the start of the cycle.
        self._step_writeback()
        self._accelerator().step()
        self._step_start()
        self._step_dispatch()

    def _step_writeback(self) -> None:
        '''Move the newest accelerator result (if any) into the output WSRs.'''
        result = self._accelerator().pop()
        if result is not None:
            share0, share1 = result
            word_idx = self._writeback_idx
            self.mai_res_s0.set_32bit_unsigned(share0, word_idx)
            self.mai_res_s1.set_32bit_unsigned(share1, word_idx)
            self._writeback_idx = word_idx + 1

        if self._writeback_idx >= self._WORDS_PER_OPERATION:
            # The writeback is complete, so the busy bit gets reset. The write method
            # updates the bit when the changes are committed, so the current instruction
            # still reads the old value.
            self._writeback_idx = 0
            self.mai_status.write_busy_bit(False)

    def _step_start(self) -> None:
        '''Start a new operation if the start bit was set in the last cycle.'''
        if not self.mai_ctrl.read_start_bit():
            return

        # Let the dispatch logic begin pushing inputs.
        self.is_dispatching = True
        # Set the busy bit immediately such that the current instruction reads it as set.
        self.mai_status.set_busy_bit(True)
        # Reset the ready bit immediately such that the current instruction reads it as
        # reset and any configuration change check does not allow changing the operation.
        self.mai_status.set_ready_bit(False)
        # Reset the start bit immediately such that it always reads as zero.
        self.mai_ctrl.set_start_bit(False)

    def _step_dispatch(self) -> None:
        '''Push the next word of the input WSRs into the accelerator while dispatching.'''
        if self.is_dispatching:
            word_idx = self._dispatch_idx
            accelerator = self._accelerator()
            input_wsrs = (self.mai_in0_s0, self.mai_in0_s1, self.mai_in1_s0, self.mai_in1_s1)
            operands = [wsr.read_32bit_unsigned(word_idx) for wsr in input_wsrs]
            accelerator.push(*operands)
            self._dispatch_idx = word_idx + 1

        if self._dispatch_idx >= self._WORDS_PER_OPERATION:
            # All inputs are dispatched. The ready bit is set at the end of this cycle to
            # indicate that new inputs can be accepted.
            self._dispatch_idx = 0
            self.is_dispatching = False
            self.mai_status.write_ready_bit(True)

    def is_busy(self) -> bool:
        '''Return whether the MAI is currently busy processing an operation.'''
        return self.mai_status.read_busy_bit()

    def is_ready(self) -> bool:
        '''Return whether the MAI is ready to accept new inputs.'''
        return self.mai_status.read_ready_bit()

    def ready_for_inputs(self) -> bool:
        '''Return whether new inputs can be accepted (the ready bit is set).'''
        return self.is_ready()

    def ready_to_start(self) -> bool:
        '''Return whether a new operation can be started (the busy bit is not set).'''
        return not self.is_busy()

    def is_valid_ctrl_change(self, value: int) -> bool:
        '''Return whether writing value to the MAI_CTRL CSR is currently allowed.'''
        ctrl = self.mai_ctrl

        # Starting is only allowed while no operation is ongoing (MAI not busy).
        if ctrl.would_set_start_bit(value) and not self.ready_to_start():
            return False

        # The operation can only be set to one of the valid options.
        if not ctrl.is_valid_operation(value):
            return False

        # Changing the operation is only allowed if no operation is ongoing.
        return not (ctrl.would_change_op(value) and self.is_busy())

hw/ip/otbn/dv/otbnsim/sim/state.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .flags import FlagReg
1616
from .gpr import GPRs
1717
from .loop import LoopStack
18+
from .mai import MaskingAcceleratorInterface
1819
from .reg import RegFile
1920
from .trace import Trace, TracePC
2021
from .wsr import WSRFile
@@ -200,6 +201,9 @@ def __init__(self) -> None:
200201
# random data).
201202
self.edn_seen_running = False
202203

204+
# The masking accelerator interface (MAI) handles the accelerators
205+
self.mai = MaskingAcceleratorInterface(self.csrs, self.wsrs)
206+
203207
def get_next_pc(self) -> int:
204208
if self._pc_next_override is not None:
205209
return self._pc_next_override
@@ -308,6 +312,7 @@ def step(self, handle_injected_error: bool) -> None:
308312
self.take_injected_err_bits()
309313
self.ext_regs.step()
310314
self._urnd_client.step()
315+
self.mai.step()
311316

312317
def commit(self, sim_stalled: bool) -> None:
313318
if self._time_to_imem_invalidation is not None:

0 commit comments

Comments
 (0)