# Copyright lowRISC contributors (OpenTitan project).
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0

from collections import deque
from typing import Optional, Tuple

from .csr import CSRFile, MaiOperation
from .wsr import WSRFile, DumbWSR

# This module models the masking accelerator interface (MAI): it emulates the behavior of both
# the interface itself and the accelerators behind it.

# Enable or disable assertions which check that the inputs and outputs of the accelerators
# meet certain constraints (e.g., being smaller than the modulus).
CHECK_ACCELERATOR_CONSTRAINTS = False


class MaskingAccelerator:
    '''Models a masking accelerator which has a simple pipeline.

    New operations can be pushed to the accelerator, and results can be popped from it.
    Each step of the simulation advances the pipeline by one stage.

    Subclasses provide the actual operation by overriding _compute().
    '''

    def __init__(self, latency: int, mod_wsr: DumbWSR) -> None:
        '''Create an accelerator with an empty pipeline of `latency` stages.

        Raises ValueError if latency is smaller than one, because push() and pop() need at
        least one pipeline slot to operate on.
        '''
        if latency < 1:
            raise ValueError(f'latency must be at least 1, got {latency}')

        # Latency of the masking accelerator in cycles.
        self.latency = latency

        # The MOD WSR is used to get the current modulus for operations.
        self.mod_wsr = mod_wsr

        # The pipeline contains the two result shares and is modeled with a bounded deque where
        # None indicates an empty slot. Index 0 is the input side, index -1 is the output side.
        self.pipeline: deque[Optional[Tuple[int, int]]]
        self.pipeline = deque([None] * self.latency, self.latency)

    def push(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> bool:
        '''Push an operation to the masking accelerator pipeline.

        Returns True if the accelerator accepted the operation. This implementation features no
        backpressure, so the operation is always accepted and the result is always True.
        '''
        # Replace the leftmost pipeline slot with the new operation result. The result is computed
        # immediately but will only be available after the full pipeline latency. Pushing twice
        # without a step() in between overwrites the previously pushed result.
        # The slot is only written after _compute() returned, so a failing computation leaves
        # the pipeline untouched.
        self.pipeline[0] = self._compute(in0_s0, in0_s1, in1_s0, in1_s1)
        return True

    def pop(self) -> Optional[Tuple[int, int]]:
        '''Read the current output of the masking accelerator pipeline.

        Returns the two result shares in the rightmost slot, or None if that slot is empty.
        '''
        # We only peek at the pipeline as the pipeline advancing is modelled in the step() method.
        return self.pipeline[-1]

    def step(self) -> None:
        '''Advance the pipeline by one stage.'''
        # This accelerator implementation features no backpressure, so we always advance the
        # pipeline. We insert an unused pipeline slot which is replaced in case a new item is
        # pushed. appendleft() will drop the rightmost item automatically because the deque is
        # bounded to the pipeline latency.
        self.pipeline.appendleft(None)

    def is_busy(self) -> bool:
        '''Return True if the accelerator is busy (has pending operations), False otherwise.'''
        # The accelerator is busy if there is at least one non-None item in the pipeline.
        return any(slot is not None for slot in self.pipeline)

    def _modulus(self) -> int:
        '''Return the current 32-bit modulus from the modulus WSR.'''
        # Only the lowest 32 bits of the WSR value are used as modulus.
        return self.mod_wsr.read_unsigned() & ((1 << 32) - 1)

    def _compute(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> Tuple[int, int]:
        '''Compute the result of the masking operation.

        Must be overridden by subclasses; the base class raises NotImplementedError.
        '''
        raise NotImplementedError


class A2BAccelerator(MaskingAccelerator):
    '''Placeholder model of the accelerator converting an arithmetic into a boolean masking.'''

    def __init__(self, mod_wsr: DumbWSR):
        # The A2B accelerator is modelled with a pipeline latency of 8 cycles.
        super().__init__(8, mod_wsr)

    def _compute(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> Tuple[int, int]:
        '''Remove the arithmetic mask from the first input and apply a boolean mask.

        in0_s0 and in0_s1 are the two arithmetic input shares. in1_s0 and in1_s1 are not used by
        the A2B accelerator. Returns the two boolean output shares, each limited to 32 bits.
        '''
        # The current placeholder implementation removes the arithmetic mask and adds a new boolean
        # mask. We use a fixed mask until the exact design is known.
        #
        # Input: (x - s mod q, s), (x - s) + s mod q = x, 0 <= s < q
        # Output: (x XOR r, r), x XOR r XOR r = x, 0 <= x, r < q < 2^k
        mask32 = (1 << 32) - 1

        # Read the modulus once; it is used for the computation and all constraint checks.
        q = self._modulus()

        s = in0_s1
        # We take a zero mask for simplicity and to avoid assertion errors due to modulus being
        # smaller than this fixed constant.
        r = 0

        # NOTE(review): a modulus of 0 makes this raise ZeroDivisionError - confirm whether the
        # MOD WSR can still be 0 when an operation is started.
        secret = (in0_s0 + s) % q
        masked_secret = secret ^ r

        # Optionally, we crash if the constraints are not met.
        if CHECK_ACCELERATOR_CONSTRAINTS:
            assert q < 2**32
            assert 0 <= s < q
            assert 0 <= r < q
            assert 0 <= secret < q

        # Limit results to 32 bits
        return (masked_secret & mask32, r & mask32)


class B2AAccelerator(MaskingAccelerator):
    '''Placeholder model of the accelerator converting a boolean into an arithmetic masking.'''

    def __init__(self, mod_wsr: DumbWSR):
        # The B2A accelerator is modelled with a pipeline latency of 7 cycles.
        super().__init__(7, mod_wsr)

    def _compute(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> Tuple[int, int]:
        '''Remove the boolean mask from the first input and apply an arithmetic mask.

        in0_s0 and in0_s1 are the two boolean input shares. in1_s0 and in1_s1 are not used by
        the B2A accelerator. Returns the two arithmetic output shares, each limited to 32 bits.
        '''
        # The current placeholder implementation removes the boolean mask and adds a new arithmetic
        # mask. We use a fixed mask until the exact design is known.
        #
        # Input: (x XOR r, r), 0 <= x, r < q < 2^k
        # Output: (x - s mod q, s), (x - s) + s mod q = x, 0 <= s < q
        mask32 = (1 << 32) - 1

        # Read the modulus once; it is used for the computation and all constraint checks.
        q = self._modulus()

        # We take a zero mask for simplicity and to avoid assertion errors due to modulus being
        # smaller than this fixed constant.
        s = 0
        r = in0_s1

        secret = in0_s0 ^ r
        # NOTE(review): a modulus of 0 makes this raise ZeroDivisionError - confirm whether the
        # MOD WSR can still be 0 when an operation is started.
        masked_secret = (secret - s) % q

        # Optionally, we crash if the constraints are not met.
        if CHECK_ACCELERATOR_CONSTRAINTS:
            assert q < 2**32
            # The constraint applies to the unmasked secret x, not to the masked share x XOR r:
            # the XOR of two values below q can itself be larger than q.
            assert 0 <= secret < q
            assert 0 <= r < q
            assert 0 <= s < q

        # Limit results to 32 bits
        return (masked_secret & mask32, s & mask32)


class SecAddModqAccelerator(MaskingAccelerator):
    '''Placeholder model of the accelerator adding two boolean masked values modulo q.'''

    def __init__(self, mod_wsr: DumbWSR):
        # The secure addition accelerator is modelled with a pipeline latency of 9 cycles.
        super().__init__(9, mod_wsr)

    def _compute(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> Tuple[int, int]:
        '''Add the two boolean masked inputs modulo q and mask the sum again.

        in0_s0 and in0_s1 are the boolean shares of the first operand, in1_s0 and in1_s1 the
        boolean shares of the second operand. Returns the two boolean shares of the sum, each
        limited to 32 bits.
        '''
        # The current placeholder implementation removes the boolean masks, adds the secrets
        # modulo q, and adds a new boolean mask. We use a fixed mask until the exact design is
        # known.
        #
        # Input: (x XOR r1, r1), (y XOR s1, s1), 0 <= x, y, r1, s1 < q < 2^k
        # Output: ((x + y mod q) XOR t, t)
        mask32 = (1 << 32) - 1

        # Read the modulus once; it is used for the computation and all constraint checks.
        q = self._modulus()

        r1 = in0_s1
        s1 = in1_s1
        # We take a zero mask for simplicity and to avoid assertion errors due to modulus being
        # smaller than this fixed constant.
        t = 0

        x = in0_s0 ^ r1
        y = in1_s0 ^ s1
        # NOTE(review): a modulus of 0 makes this raise ZeroDivisionError - confirm whether the
        # MOD WSR can still be 0 when an operation is started.
        total = (x + y) % q
        masked_total = total ^ t

        # Optionally, we crash if the constraints are not met.
        if CHECK_ACCELERATOR_CONSTRAINTS:
            assert q < 2**32
            assert 0 <= x < q
            assert 0 <= y < q
            assert 0 <= r1 < q
            assert 0 <= s1 < q
            assert 0 <= t < q

        # Limit results to 32 bits
        return (masked_total & mask32, t & mask32)


class MaskingAcceleratorInterface:
    '''Models the MAI logic which connects the MAI CSRs and WSRs with the accelerators.

    Each call to step() simulates one cycle: it writes back an available result, advances the
    selected accelerator, handles a start request and dispatches the next inputs.
    '''

    # Number of 32-bit input words dispatched, and result words written back, per operation.
    # NOTE(review): presumably the number of 32-bit words in one WSR - confirm.
    _WORDS_PER_OPERATION = 8

    def __init__(self, csrs: CSRFile, wsrs: WSRFile) -> None:
        '''Bind the MAI to its CSRs and WSRs and instantiate all accelerators.'''
        # The CSRs and WSRs
        self.csrs = csrs
        self.wsrs = wsrs
        self.mai_ctrl = self.csrs.MaiCtrl
        self.mai_status = self.csrs.MaiStatus
        self.mai_res_s0 = self.wsrs.MaiResS0
        self.mai_res_s1 = self.wsrs.MaiResS1
        self.mai_in0_s0 = self.wsrs.MaiIn0S0
        self.mai_in0_s1 = self.wsrs.MaiIn0S1
        self.mai_in1_s0 = self.wsrs.MaiIn1S0
        self.mai_in1_s1 = self.wsrs.MaiIn1S1

        # All available accelerators are instantiated here in a dictionary.
        # The currently active accelerator is selected based on the operation field in MAI_CTRL.
        # Changing the operation while an operation is ongoing is not allowed (see
        # is_valid_ctrl_change). Thus, the step() method can simply read the operation field each
        # cycle to get the current accelerator like this:
        # self._all_accelerators[self.mai_ctrl.read_operation()]
        self._all_accelerators = {
            MaiOperation.A2B: A2BAccelerator(self.wsrs.MOD),
            MaiOperation.B2A: B2AAccelerator(self.wsrs.MOD),
            MaiOperation.SECADD: SecAddModqAccelerator(self.wsrs.MOD),
        }

        # Dispatch related variables
        # The dispatch logic is responsible for pushing inputs into the accelerator.
        self._dispatch_idx = 0
        self.is_dispatching = False

        # Writeback related variables
        # The writeback logic is responsible for receiving results from the accelerator into the
        # output WSRs.
        self._writeback_idx = 0

    def _accelerator(self) -> MaskingAccelerator:
        '''Return the currently selected masking accelerator based on the operation field.'''
        return self._all_accelerators[self.mai_ctrl.read_operation()]

    def step(self) -> None:
        '''Advance the MAI simulation by one cycle.

        This is expected to be called before the current instruction executes / steps.
        '''
        ###################
        # Writeback logic #
        ###################
        # Get the newest result and write it into the output WSRs. This is done before
        # advancing the pipeline to model the fact that the result is available at
        # the start of the cycle.
        results = self._accelerator().pop()
        if results is not None:
            # Write to the output WSRs
            self.mai_res_s0.set_32bit_unsigned(results[0], self._writeback_idx)
            self.mai_res_s1.set_32bit_unsigned(results[1], self._writeback_idx)
            self._writeback_idx += 1

            # Detect if we finished writing back
            if self._writeback_idx >= self._WORDS_PER_OPERATION:
                self._writeback_idx = 0
                # If we are finishing the writeback, reset the busy bit. The write method updates
                # the bit when committing the changes so the current instruction still reads the
                # old value.
                self.mai_status.write_busy_bit(False)

        ######################
        # Accelerator update #
        ######################
        # Advance the accelerator pipeline.
        self._accelerator().step()

        ###############
        # Start logic #
        ###############
        # Start a new operation if start bit was set in last cycle
        if self.mai_ctrl.read_start_bit():
            # Begin pushing inputs in the dispatch logic
            self.is_dispatching = True
            # Immediately set the busy bit such that the current instruction reads it as set.
            self.mai_status.set_busy_bit(True)
            # Immediately reset the ready bit such that the current instruction reads it as reset
            # and any configuration change check does not allow changing the operation type.
            self.mai_status.set_ready_bit(False)
            # Immediately reset the start bit such that it always reads zero.
            self.mai_ctrl.set_start_bit(False)

        ##################
        # Dispatch logic #
        ##################
        # Push one set of 32-bit input words per cycle, starting in the cycle the operation is
        # started.
        if self.is_dispatching:
            self._accelerator().push(self.mai_in0_s0.read_32bit_unsigned(self._dispatch_idx),
                                     self.mai_in0_s1.read_32bit_unsigned(self._dispatch_idx),
                                     self.mai_in1_s0.read_32bit_unsigned(self._dispatch_idx),
                                     self.mai_in1_s1.read_32bit_unsigned(self._dispatch_idx))
            self._dispatch_idx += 1

            # Detect if we have finished dispatching
            if self._dispatch_idx >= self._WORDS_PER_OPERATION:
                self._dispatch_idx = 0
                self.is_dispatching = False
                # Set the ready bit at the end of this cycle. This indicates that new inputs can
                # be accepted.
                self.mai_status.write_ready_bit(True)

    def is_busy(self) -> bool:
        '''Returns whether the MAI is currently busy processing an operation.'''
        return self.mai_status.read_busy_bit()

    def is_ready(self) -> bool:
        '''Returns whether the MAI is ready to accept new inputs.'''
        return self.mai_status.read_ready_bit()

    def ready_for_inputs(self) -> bool:
        '''Returns whether new inputs can be accepted, i.e. whether the ready bit is set.'''
        return self.is_ready()

    def ready_to_start(self) -> bool:
        '''Returns whether a new operation can be started, i.e. whether the MAI is not busy.'''
        return not self.is_busy()

    def is_valid_ctrl_change(self, value: int) -> bool:
        '''Return whether writing value to the MAI_CTRL CSR is currently allowed.'''
        # Starting is only allowed if the MAI is not busy.
        if self.mai_ctrl.would_set_start_bit(value) and not self.ready_to_start():
            return False

        # We only allow setting the operation to valid options.
        if not self.mai_ctrl.is_valid_operation(value):
            return False

        # Changing the operation is only allowed if MAI is not busy / no operation is ongoing.
        if self.mai_ctrl.would_change_op(value) and self.is_busy():
            return False

        return True