diff --git a/hw/ip/otbn/README.md b/hw/ip/otbn/README.md index 99f12eeba3408..798cb11f7034d 100644 --- a/hw/ip/otbn/README.md +++ b/hw/ip/otbn/README.md @@ -109,8 +109,8 @@ CSRs can be accessed through dedicated instructions, {{#otbn-insn-ref CSRRS}} an Writes to read-only (RO) registers are ignored; they do not signal an error. All read-write (RW) CSRs are set to 0 when OTBN starts an operation (when 1 is written to [`CMD.start`](doc/registers.md#cmd)). - + @@ -565,6 +565,44 @@ All read-write (RW) CSRs are set to 0 when OTBN starts an operation (when 1 is w
+ + 0x7F0 + RW + MAI_CTRL + + The MAI control register. This is used to start MAI operations as well as configuring the accelerators. + + + + + + + + + + + + + + + + + + +
Bit Description
0 + MAI_START: Writing 1 to this bit starts the MAI operation. Writing it when MAI is busy will cause a MAI_ERROR software error. +
2:1 + The MAI_OPERATION field defines which accelerator is used for the next operation. Invalid values and writing to these bits when MAI is busy will cause a MAI_ERROR software error. +

Values:

    +
  • 0: A2B
  • +
  • 1: B2A
  • +
  • 2: secAdd
  • +
+
31:3 + Reserved. Any write is ignored. Always reads as 0. +
+ + 0xFC0 RO @@ -659,6 +697,39 @@ All read-write (RW) CSRs are set to 0 when OTBN starts an operation (when 1 is w + + 0xFE0 + RO + MAI_STATUS + + The MAI status register. + + + + + + + + + + + + + + + + + + +
Bit Description
0 + MAI_BUSY: This bit is set to 1 when an MAI operation is in progress. If reset, the MAI accepts new configuration values and a new execution can be started by writing to the MAI_START bit in the MAI_CTRL CSR. +
1 + MAI_READY: This bit is set to 1 when the MAI_INx_Sx WSRs are ready to accept new values for the next execution. +
31:2 + Reserved. Always reads as 0. +
+ + @@ -858,6 +929,68 @@ All read-write (RW) WSRs are set to 0 when OTBN starts an operation (when 1 is w + + 0xA + RO + MAI_RES_S0 + + This WSR holds share 0 of the masked results produced by the MAI. + The results are organized as eight 32-bit values. + Results are valid when MAI is not busy anymore. + These values are overwritten when the first results of the next execution are available (this depends on the selected accelerator's latency). + + + + 0xB + RO + MAI_RES_S1 + + This WSR holds share 1 of the masked results produced by the MAI. + The results are organized as eight 32-bit values. + Results are valid when MAI is not busy anymore. + These values are overwritten when the first results of the next execution are available (this depends on the selected accelerator's latency). + + + + 0xC + RW + MAI_IN0_S0 + + This WSR transfers share 0 of the first input secrets towards the MAI. + The inputs are considered as eight 32-bit values. + Writing to this WSR while MAI is not ready will cause a MAI_ERROR software error. + + + + 0xD + RW + MAI_IN0_S1 + + This WSR transfers share 1 of the first input secrets towards the MAI. + The inputs are considered as eight 32-bit values. + Writing to this WSR while MAI is not ready will cause a MAI_ERROR software error. + + + + 0xE + RW + MAI_IN1_S0 + + This WSR transfers share 0 of the second input secrets towards the MAI. + The inputs are considered as eight 32-bit values. + Writing to this WSR while MAI is not ready will cause a MAI_ERROR software error. + + + + 0xF + RW + MAI_IN1_S1 + + This WSR transfers share 1 of the second input secrets towards the MAI. + The inputs are considered as eight 32-bit values. + Writing to this WSR while MAI is not ready will cause a MAI_ERROR software error. 
+ + diff --git a/hw/ip/otbn/data/csr.yml b/hw/ip/otbn/data/csr.yml index 4dee642a92e21..dc9bf69e959b5 100644 --- a/hw/ip/otbn/data/csr.yml +++ b/hw/ip/otbn/data/csr.yml @@ -162,6 +162,24 @@ bits: 31-0: BYTE_STROBE is the KMAC byte strobe field. +- name: mai_ctrl + address: 0x7f0 + doc: | + The MAI control register. This is used to start MAI operations as well as configuring the accelerators. + bits: + 0: | + MAI_START: Writing 1 to this bit starts the MAI operation. + Writing it when MAI is busy will cause a MAI_ERROR software error. + 2-1: + doc: | + The MAI_OPERATION field defines which accelerator is used for the next operation. + Invalid values and writing to these bits when MAI is busy will cause a MAI_ERROR software error. + values: + 0: A2B + 1: B2A + 2: secAdd + 31-3: Reserved. Any write is ignored. Always reads as 0. + - name: rnd address: 0xfc0 read-only: true @@ -205,3 +223,16 @@ bits: 7-0: ERROR_CODE contains the error code coming directly from the KMAC HWIP. 31-8: Reserved. Always reads as 0. Any write is ignored. + +- name: mai_status + address: 0xfe0 + read-only: true + doc: | + The MAI status register. + bits: + 0: | + MAI_BUSY: This bit is set to 1 when an MAI operation is in progress. + If reset, the MAI accepts new configuration values and a new execution can be started by writing to the MAI_START bit in the MAI_CTRL CSR. + 1: | + MAI_READY: This bit is set to 1 when the MAI_INx_Sx WSRs are ready to accept new values for the next execution. + 31-2: Reserved. Always reads as 0. diff --git a/hw/ip/otbn/data/wsr.yml b/hw/ip/otbn/data/wsr.yml index 62f8b72a43948..36e61f21bb509 100644 --- a/hw/ip/otbn/data/wsr.yml +++ b/hw/ip/otbn/data/wsr.yml @@ -109,3 +109,49 @@ 255-64: | Write: Words 1-3 of the message share. Read: Returns `0`. Digest shares are read out via the least significant word only. + +- name: mai_res_s0 + address: 10 + read-only: true + doc: | + This WSR holds share 0 of the masked results produced by the MAI. 
+ The results are organized as eight 32-bit values. + Results are valid when MAI is not busy anymore. + These values are overwritten when the first results of the next execution are available (this depends on the selected accelerator's latency). + +- name: mai_res_s1 + address: 11 + read-only: true + doc: | + This WSR holds share 1 of the masked results produced by the MAI. + The results are organized as eight 32-bit values. + Results are valid when MAI is not busy anymore. + These values are overwritten when the first results of the next execution are available (this depends on the selected accelerator's latency). + +- name: mai_in0_s0 + address: 12 + doc: | + This WSR transfers share 0 of the first input secrets towards the MAI. + The inputs are considered as eight 32-bit values. + Writing to this WSR while MAI is not ready will cause a MAI_ERROR software error. + +- name: mai_in0_s1 + address: 13 + doc: | + This WSR transfers share 1 of the first input secrets towards the MAI. + The inputs are considered as eight 32-bit values. + Writing to this WSR while MAI is not ready will cause a MAI_ERROR software error. + +- name: mai_in1_s0 + address: 14 + doc: | + This WSR transfers share 0 of the second input secrets towards the MAI. + The inputs are considered as eight 32-bit values. + Writing to this WSR while MAI is not ready will cause a MAI_ERROR software error. + +- name: mai_in1_s1 + address: 15 + doc: | + This WSR transfers share 1 of the second input secrets towards the MAI. + The inputs are considered as eight 32-bit values. + Writing to this WSR while MAI is not ready will cause a MAI_ERROR software error. 
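The WSR descriptions above state that each 256-bit MAI WSR carries eight 32-bit values. A minimal sketch of that layout follows, assuming word index 0 occupies bits [31:0] (the ordering used by `set_32bit_unsigned` and `read_32bit_unsigned` in the simulator changes of this same diff). The helper names `pack_words` and `unpack_words` are hypothetical and are not part of the change.

```python
# Hypothetical helpers illustrating the "eight 32-bit values" layout of the
# 256-bit MAI WSRs. Assumption: word 0 sits in bits [31:0], word 1 in
# bits [63:32], and so on.
MASK32 = (1 << 32) - 1
NUM_WORDS = 8


def pack_words(words):
    """Pack eight 32-bit words into one 256-bit integer."""
    assert len(words) == NUM_WORDS
    value = 0
    for index, word in enumerate(words):
        assert 0 <= word <= MASK32
        value |= word << (32 * index)
    return value


def unpack_words(value):
    """Split a 256-bit integer into eight 32-bit words, word 0 first."""
    assert 0 <= value < (1 << 256)
    return [(value >> (32 * index)) & MASK32 for index in range(NUM_WORDS)]
```

Packing and then unpacking returns the original list, which is a quick way to check that a test vector for `MAI_IN0_S0` or `MAI_RES_S0` uses the intended word order.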
diff --git a/hw/ip/otbn/dv/otbnsim/sim/constants.py b/hw/ip/otbn/dv/otbnsim/sim/constants.py index 4c3b2cefba8cb..b781455d43ac1 100644 --- a/hw/ip/otbn/dv/otbnsim/sim/constants.py +++ b/hw/ip/otbn/dv/otbnsim/sim/constants.py @@ -32,6 +32,7 @@ class ErrBits(IntEnum): KEY_INVALID = 1 << 5 RND_REP_CHK_FAIL = 1 << 6 RND_FIPS_CHK_FAIL = 1 << 7 + MAI_ERROR = 1 << 8 IMEM_INTG_VIOLATION = 1 << 16 DMEM_INTG_VIOLATION = 1 << 17 REG_INTG_VIOLATION = 1 << 18 diff --git a/hw/ip/otbn/dv/otbnsim/sim/csr.py b/hw/ip/otbn/dv/otbnsim/sim/csr.py index 0c96acae1e3af..0ee101bfe4cee 100644 --- a/hw/ip/otbn/dv/otbnsim/sim/csr.py +++ b/hw/ip/otbn/dv/otbnsim/sim/csr.py @@ -2,6 +2,7 @@ # Licensed under the Apache License, Version 2.0, see LICENSE for details. # SPDX-License-Identifier: Apache-2.0 +from enum import IntEnum from typing import Any, Callable, Dict, List, Optional from .flags import FlagGroups from .ispr import DumbISPR @@ -45,6 +46,161 @@ def write_unsigned(self, value: int) -> None: self._write_func(value) +class MaiOperation(IntEnum): + A2B = 0 + B2A = 1 + SECADD = 2 + + +class MaiCtrlCSR(DumbISPR): + '''Models the MAI CTRL CSR''' + def __init__(self) -> None: + super().__init__("MAI_CTRL", 32) + self.START_BIT_MASK = 0x1 + self.START_BIT_OFFSET = 0 + self.OPERATION_MASK = 0x3 + self.OPERATION_OFFSET = 1 + + def on_start(self) -> None: + super().on_start() + # On start, the default operation is set. + self._value |= (MaiOperation.A2B & self.OPERATION_MASK) << self.OPERATION_OFFSET + + def read_start_bit(self) -> bool: + '''Get the start bit from the CSR.''' + bit = (self.read_unsigned() >> self.START_BIT_OFFSET) & self.START_BIT_MASK + return bit != 0 + + def would_set_start_bit(self, value: int) -> bool: + '''Return whether writing value would set the start bit.''' + return ((value >> self.START_BIT_OFFSET) & self.START_BIT_MASK) != 0 + + def set_start_bit(self, start: bool) -> None: + '''Set or clear the start bit in the CSR. + + This takes effect immediately. 
Note that we still report the change to generate a proper + trace. Any "simultaneous" write by an instruction will override this change. + ''' + val = self.read_unsigned() + if start: + val |= self.START_BIT_MASK << self.START_BIT_OFFSET + else: + val &= ~(self.START_BIT_MASK << self.START_BIT_OFFSET) + self._value = val + self._next_value = val + self._pending_write = True + + def read_operation(self) -> MaiOperation: + '''Get the operation field from the CSR.''' + op = (self.read_unsigned() >> self.OPERATION_OFFSET) & self.OPERATION_MASK + # The enum cast will fail if the current operation is not a valid option. If this happens + # something before went wrong as we only allow writing valid options. + return MaiOperation(op) + + def is_valid_operation(self, value: int) -> bool: + '''Returns whether the CSR value contains valid operation bits.''' + op = (value >> self.OPERATION_OFFSET) & self.OPERATION_MASK + # TODO: Clean this up once we have python 3.12+ + return op in MaiOperation._value2member_map_ + + def would_change_op(self, value: int) -> bool: + '''Return whether writing value to the CSR would change the operation field. + The value to be checked must specify a valid operation option. + ''' + new_op = MaiOperation((value >> self.OPERATION_OFFSET) & self.OPERATION_MASK) + curr_op = self.read_operation() + return new_op != curr_op + + +class MaiStatusCSR(DumbISPR): + '''Models the MAI STATUS CSR''' + def __init__(self) -> None: + super().__init__("MAI_STATUS", 32) + self.BUSY_BIT_MASK = 0x1 + self.BUSY_BIT_OFFSET = 0 + self.READY_BIT_MASK = 0x1 + self.READY_BIT_OFFSET = 1 + + def on_start(self) -> None: + super().on_start() + # On start, the MAI is not busy and is ready for new inputs. + self._value |= self.READY_BIT_MASK << self.READY_BIT_OFFSET + + def write_unsigned(self, value: int) -> None: + '''Ignore writes to the MAI STATUS CSR. + Note this is different from set_ methods. 
This is used by instructions and the set_ methods + are used directly by the MAI.''' + return + + def read_ready_bit(self) -> bool: + '''Get the ready bit from the CSR.''' + bit = (self.read_unsigned() >> self.READY_BIT_OFFSET) & self.READY_BIT_MASK + return bit != 0 + + def set_ready_bit(self, ready: bool) -> None: + '''Set or clear the ready bit in the CSR. + This takes effect immediately. Note that we still report the change to generate a proper + trace.''' + val = self.read_unsigned() + if ready: + val |= self.READY_BIT_MASK << self.READY_BIT_OFFSET + else: + val &= ~(self.READY_BIT_MASK << self.READY_BIT_OFFSET) + self._value = val + self._next_value = val + self._pending_write = True + + def write_ready_bit(self, ready: bool) -> None: + '''Set or clear the ready bit in the CSR. + This takes effect when committing.''' + # Check if any other bit manipulation is pending. If so, we must use the pending value to + # avoid discarding the previous write. + val = self._next_value if self._pending_write else self.read_unsigned() + # _pending_write should never be set if _next_value is None. + assert val is not None + if ready: + val |= (self.READY_BIT_MASK << self.READY_BIT_OFFSET) + else: + val &= ~(self.READY_BIT_MASK << self.READY_BIT_OFFSET) + # We cannot use self.write_unsigned here as any write from an instruction is ignored. + self._next_value = val + self._pending_write = True + + def read_busy_bit(self) -> bool: + '''Get the busy bit from the CSR.''' + bit = (self.read_unsigned() >> self.BUSY_BIT_OFFSET) & self.BUSY_BIT_MASK + return bit != 0 + + def set_busy_bit(self, busy: bool) -> None: + '''Set or clear the busy bit in the CSR. + This takes effect immediately. 
Note that we still report the change to generate a proper + trace.''' + val = self.read_unsigned() + if busy: + val |= (self.BUSY_BIT_MASK << self.BUSY_BIT_OFFSET) + else: + val &= ~(self.BUSY_BIT_MASK << self.BUSY_BIT_OFFSET) + self._value = val + self._next_value = val + self._pending_write = True + + def write_busy_bit(self, busy: bool) -> None: + '''Set or clear the busy bit in the CSR. + This takes effect when committing.''' + # Check if any other bit manipulation is pending. If so, we must use the pending value to + # avoid discarding the previous write. + val = self._next_value if self._pending_write else self.read_unsigned() + # _pending_write should never be set if _next_value is None. + assert val is not None + if busy: + val |= (self.BUSY_BIT_MASK << self.BUSY_BIT_OFFSET) + else: + val &= ~(self.BUSY_BIT_MASK << self.BUSY_BIT_OFFSET) + # We cannot use self.write_unsigned here as any write from an instruction is ignored. + self._next_value = val + self._pending_write = True + + class CSRFile: '''A model of the CSR file''' def __init__(self, wsrs: WSRFile) -> None: @@ -62,6 +218,8 @@ def __init__(self, wsrs: WSRFile) -> None: self.URND = WrapperCSR(read_func=wsrs.URND.read_u32) self.KMAC_CMD = KmacCommandCSR('KMAC_CMD', write_mask=0x3f) self.KMAC_BYTE_STROBE = DumbISPR('KMAC_BYTE_STROBE', width=32) + self.MAI_CTRL = MaiCtrlCSR() + self.MAI_STATUS = MaiStatusCSR() self._known_indices = { 0x7c0, # FG0 @@ -75,10 +233,12 @@ def __init__(self, wsrs: WSRFile) -> None: 0x7dc, # KMAC_MSG_SEND 0x7dd, # KMAC_CMD 0x7de, # KMAC_BYTE_STROBE + 0x7f0, # MAI_CTRL 0xfc0, # RND 0xfc1, # URND 0xfc2, # KMAC_STATUS 0xfc3, # KMAC_ERROR + 0xfe0, # MAI_STATUS } self._idx_to_csr: Dict[int, Any] = { @@ -92,12 +252,16 @@ def __init__(self, wsrs: WSRFile) -> None: 0x7dc: self.KMAC_MSG_SEND, 0x7dd: self.KMAC_CMD, 0x7de: self.KMAC_BYTE_STROBE, + 0x7f0: self.MAI_CTRL, 0xfc0: self.RND, 0xfc1: self.URND, 0xfc2: self.KMAC_STATUS, 0xfc3: self.KMAC_ERROR, + 0xfe0: self.MAI_STATUS, } + 
self.on_start() + @staticmethod def _get_field(field_idx: int, field_size: int, val: int) -> int: mask = (1 << field_size) - 1 @@ -111,6 +275,12 @@ def _set_field(field_idx: int, field_size: int, field_val: int, shift = field_size * field_idx return (old_val & ~(mask << shift)) | (field_val << shift) + def on_start(self) -> None: + '''Reset CSRs and flags when starting an operation''' + self.flags = FlagGroups() + self.MAI_CTRL.on_start() + self.MAI_STATUS.on_start() + def check_idx(self, idx: int) -> bool: '''Return True if idx points to a valid CSR; False otherwise.''' return idx in self._known_indices @@ -125,6 +295,10 @@ def read_unsigned(self, wsrs: WSRFile, idx: int) -> int: if csr is not None: return int(csr.read_unsigned()) + if idx == 0xfe0: + # MAI_STATUS register + return self.MAI_STATUS.read_unsigned() + raise RuntimeError('Unknown CSR index: {:#x}'.format(idx)) def write_unsigned(self, wsrs: WSRFile, idx: int, value: int) -> None: @@ -142,6 +316,10 @@ def write_unsigned(self, wsrs: WSRFile, idx: int, value: int) -> None: csr.write_unsigned(value) return + if idx == 0xfe0: + # MAI_STATUS register (which ignores writes) + return + raise RuntimeError('Unknown CSR index: {:#x}'.format(idx)) def commit(self) -> None: @@ -154,6 +332,8 @@ def commit(self) -> None: self.KMAC_MSG_SEND.commit() self.KMAC_CMD.commit() self.KMAC_BYTE_STROBE.commit() + self.MAI_CTRL.commit() + self.MAI_STATUS.commit() def abort(self) -> None: self.flags.abort() @@ -165,6 +345,9 @@ def abort(self) -> None: self.KMAC_MSG_SEND.abort() self.KMAC_CMD.abort() self.KMAC_BYTE_STROBE.abort() + self.MAI_CTRL.abort() + # The MAI_STATUS is always committed because only the MAI updates it. 
+ self.MAI_STATUS.commit() def changes(self) -> List[Trace]: ret: List[Trace] = [] @@ -177,7 +360,10 @@ def changes(self) -> List[Trace]: ret += self.KMAC_MSG_SEND.changes() ret += self.KMAC_CMD.changes() ret += self.KMAC_BYTE_STROBE.changes() + ret += self.MAI_CTRL.changes() + ret += self.MAI_STATUS.changes() return ret def wipe(self) -> None: self.flags.write_unsigned(0) + # TODO: Wipe or reset MAI CTRL/STATUS? diff --git a/hw/ip/otbn/dv/otbnsim/sim/insn.py b/hw/ip/otbn/dv/otbnsim/sim/insn.py index a9482a8c7452a..a0369625f51b1 100644 --- a/hw/ip/otbn/dv/otbnsim/sim/insn.py +++ b/hw/ip/otbn/dv/otbnsim/sim/insn.py @@ -434,11 +434,18 @@ def execute(self, state: OTBNState) -> Optional[Iterator[None]]: # There's a pending EDN request. Stall for a cycle. yield None - # At this point, the CSR is ready. Read, update and write back to grs1. + # At this point, the CSR is ready. Read it to grd. old_val = state.read_csr(self.csr) - new_val = old_val | bits_to_set state.gprs.get_reg(self.grd).write_unsigned(old_val) + + # If CSR should be updated, compute update, check if update is allowed + # and write it back. if self.grs1 != 0: + new_val = old_val | bits_to_set + if self.csr == 0x7f0: + if not state.mai.is_valid_ctrl_change(new_val): + state.stop_at_end_of_cycle(ErrBits.MAI_ERROR) + return None state.write_csr(self.csr, new_val) return None @@ -479,6 +486,12 @@ def execute(self, state: OTBNState) -> Optional[Iterator[None]]: old_val = state.read_csr(self.csr) state.gprs.get_reg(self.grd).write_unsigned(old_val) + # Check if the write to MAI_CTRL is allowed. + if self.csr == 0x7f0: + if not state.mai.is_valid_ctrl_change(new_val): + state.stop_at_end_of_cycle(ErrBits.MAI_ERROR) + return None + state.write_csr(self.csr, new_val) return None @@ -1272,6 +1285,13 @@ def execute(self, state: OTBNState) -> None: state.stop_at_end_of_cycle(ErrBits.ILLEGAL_INSN) return None + # Check if MAI is ready to accept new inputs. If not stop with MAI + # error. 
+ if self.wsr in [12, 13, 14, 15]: + if not state.mai.ready_for_inputs(): + state.stop_at_end_of_cycle(ErrBits.MAI_ERROR) + return None + val = state.wdrs.get_reg(self.wrs).read_unsigned() state.wsrs.write_at_idx(self.wsr, val) diff --git a/hw/ip/otbn/dv/otbnsim/sim/mai.py b/hw/ip/otbn/dv/otbnsim/sim/mai.py new file mode 100644 index 0000000000000..fbf2f2c9b5c70 --- /dev/null +++ b/hw/ip/otbn/dv/otbnsim/sim/mai.py @@ -0,0 +1,309 @@ +# Copyright lowRISC contributors (OpenTitan project). +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from collections import deque + +from .csr import CSRFile, MaiOperation +from .wsr import WSRFile, DumbWSR + +# The masking accelerator interface (MAI) emulates the behavior of the interface and the actual +# accelerators. + +# Enable or disable assertions which check that the inputs and outputs of the accelerators +# meet certain constraints (e.g., being smaller than the modulus). +CHECK_ACCELERATOR_CONSTRAINTS = False + + +class MaskingAccelerator: + '''Models a masking accelerator which has a simple pipeline. + New operations can be pushed to the accelerator, and results can be popped from it. + Each step of the simulation advances the pipeline by one stage. + ''' + + def __init__(self, latency: int, mod_wsr: DumbWSR) -> None: + # Latency of the masking accelerator in cycles. + self.latency = latency + + # The MOD WSR is used to get the current modulus for operations. + self.mod_wsr = mod_wsr + + # The pipeline contains the two result shares and is modeled with a deque where None + # indicates an empty slot. + self.pipeline: deque[Optional[Tuple[int, int]]] + self.pipeline = deque([None] * self.latency, self.latency) + + def push(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> bool: + '''Try to push an operation to the masking accelerator pipeline. 
+ + Returns True if the accelerator can accept it (free pipeline slot), False otherwise. + ''' + # This accelerator implementation features no backpressure, so we always accept new + # operations. Pop the leftmost pipeline slot and replace it with the new operation result. + # The result is computed immediately but will only be available after the full pipeline + # latency. + self.pipeline.popleft() + self.pipeline.appendleft(self._compute(in0_s0, in0_s1, in1_s0, in1_s1)) + return True + + def pop(self) -> Optional[Tuple[int, int]]: + '''Read the current output of the masking accelerator pipeline.''' + # We only peek at the pipeline here; advancing it is modelled in the step() method. + return self.pipeline[-1] + + def step(self) -> None: + '''Advance the pipeline by one stage if possible.''' + # This accelerator implementation features no backpressure, so we always advance the + # pipeline. We insert an unused pipeline slot which is replaced in case a new item is + # pushed. appendleft() will drop the rightmost item automatically. + self.pipeline.appendleft(None) + + def is_busy(self) -> bool: + '''Return True if the accelerator is busy (has pending operations), False otherwise.''' + # The accelerator is busy if there is at least one non-None item in the pipeline. + return any(slot is not None for slot in self.pipeline) + + def _modulus(self) -> int: + '''Return the current 32-bit modulus from the modulus WSR.''' + return self.mod_wsr.read_unsigned() & ((1 << 32) - 1) + + def _compute(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> Tuple[int, int]: + '''Compute the result of the masking operation.''' + raise NotImplementedError + + +class A2BAccelerator(MaskingAccelerator): + def __init__(self, mod_wsr: DumbWSR): + super().__init__(32, mod_wsr) + + def _compute(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> Tuple[int, int]: + # The current placeholder implementation removes the arithmetic mask and adds a new boolean + # mask. 
We use a fixed mask until the exact design is known. + # + # Input: (x - s mod q, s), (x - s) + s mod q = x, 0 <= s < q + # Output: (x XOR r, r), x XOR r XOR r = x, 0 <= x, r < q < 2^k + + # in1_s0 and in1_s1 are not used by the A2B accelerator + + s = in0_s1 + # We take a fixed mask which satisfies the constraints until the exact design is known. + r = self._modulus() // 3 + secret = (in0_s0 + s) % self._modulus() + masked_secret = (secret ^ r) + + # Optionally, we crash if the constraints are not met. + if CHECK_ACCELERATOR_CONSTRAINTS: + assert self._modulus() < 2**32 + assert 0 <= s < self._modulus() + assert 0 <= r < self._modulus() + assert 0 <= secret < self._modulus() + + # Limit results to 32 bits + masked_secret &= ((1 << 32) - 1) + r &= ((1 << 32) - 1) + return (masked_secret, r) + + +class B2AAccelerator(MaskingAccelerator): + def __init__(self, mod_wsr: DumbWSR): + super().__init__(32, mod_wsr) + + def _compute(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> Tuple[int, int]: + # The current placeholder implementation removes the boolean mask and adds a new arithmetic + # mask. We use a fixed mask until the exact design is known. + # + # Input: (x XOR r, r), 0 <= x, r < q < 2^k + # Output: (x - s mod q, s), (x - s) + s mod q = x, 0 <= s < q + + # in1_s0 and in1_s1 are not used by the B2A accelerator + + # We take a fixed mask which satisfies the constraints until the exact design is known. + s = self._modulus() // 3 + r = in0_s1 + + secret = in0_s0 ^ r + masked_secret = (secret - s) % self._modulus() + + # Optionally, we crash if the constraints are not met. 
+ if CHECK_ACCELERATOR_CONSTRAINTS: + assert self._modulus() < 2**32 + assert 0 <= in0_s0 < self._modulus() + assert 0 <= r < self._modulus() + assert 0 <= s < self._modulus() + + # Limit results to 32 bits + masked_secret &= ((1 << 32) - 1) + s &= ((1 << 32) - 1) + return (masked_secret, s) + + +class SecAddModkAccelerator(MaskingAccelerator): + def __init__(self, mod_wsr: DumbWSR): + super().__init__(32, mod_wsr) + + def _compute(self, in0_s0: int, in0_s1: int, in1_s0: int, in1_s1: int) -> Tuple[int, int]: + # The current placeholder implementation removes the boolean masks, adds the secrets + # modulo 2**32, and adds a new boolean mask. We use a fixed mask until the exact design is + # known. + # + # Input: (x XOR r1, r1), (y XOR s1, s1), 0 <= x, y, r1, s1 < q < 2^k + # Output: ((x + y mod 2^32) XOR t, t) + r1 = in0_s1 + s1 = in1_s1 + # We take a fixed mask until the exact design is known. + t = self._modulus() // 3 + + x = in0_s0 ^ r1 + y = in1_s0 ^ s1 + sum = (x + y) % 2**32 + masked_sum = sum ^ t + + if CHECK_ACCELERATOR_CONSTRAINTS: + assert self._modulus() < 2**32 + assert 0 <= x < self._modulus() + assert 0 <= y < self._modulus() + assert 0 <= r1 < self._modulus() + assert 0 <= s1 < self._modulus() + + # Limit results to 32 bits + masked_sum &= ((1 << 32) - 1) + t &= ((1 << 32) - 1) + return (masked_sum, t) + + +class MaskingAcceleratorInterface: + def __init__(self, csrs: CSRFile, wsrs: WSRFile) -> None: + + # The CSRs and WSRs + self.csrs = csrs + self.wsrs = wsrs + self.mai_ctrl = self.csrs.MAI_CTRL + self.mai_status = self.csrs.MAI_STATUS + self.mai_res_s0 = self.wsrs.MaiResS0 + self.mai_res_s1 = self.wsrs.MaiResS1 + self.mai_in0_s0 = self.wsrs.MaiIn0S0 + self.mai_in0_s1 = self.wsrs.MaiIn0S1 + self.mai_in1_s0 = self.wsrs.MaiIn1S0 + self.mai_in1_s1 = self.wsrs.MaiIn1S1 + + # All available accelerators are instantiated here in a dictionary. + # The currently active accelerator is selected based on the operation field in MAI_CTRL. 
+ # Changing the operation while an operation is ongoing is not allowed (see + # is_valid_ctrl_change). Thus, the step() method can simply read the operation field each + # cycle to get the current accelerator like this: + # self._all_accelerators[self.mai_ctrl.read_operation()] + self._all_accelerators = { + MaiOperation.A2B: A2BAccelerator(self.wsrs.MOD), + MaiOperation.B2A: B2AAccelerator(self.wsrs.MOD), + MaiOperation.SECADD: SecAddModkAccelerator(self.wsrs.MOD), + } + + # Dispatch related variables + # The dispatch logic is responsible for pushing inputs into the accelerator. + self._dispatch_idx = 0 + self.is_dispatching = False + + # Writeback related variables + # The writeback logic is responsible for receiving results from the accelerator into the + # output WSRs. + self._writeback_idx = 0 + + def _accelerator(self) -> MaskingAccelerator: + '''Return the currently selected masking accelerator based on the operation field.''' + return self._all_accelerators[self.mai_ctrl.read_operation()] + + def step(self) -> None: + '''Advance the MAI simulation by one cycle. + + This is expected to be called before the current instruction executes / steps. + ''' + ################### + # Writeback logic # + ################### + # Get the newest result and write it into the output WSRs. This is done before + # advancing the pipeline to model the fact that the result is available at + # the start of the cycle. + results = self._accelerator().pop() + if results is not None: + # Write to the output WSRs + self.mai_res_s0.set_32bit_unsigned(results[0], self._writeback_idx) + self.mai_res_s1.set_32bit_unsigned(results[1], self._writeback_idx) + self._writeback_idx += 1 + + # Detect if we finished writing back + if self._writeback_idx >= 8: + self._writeback_idx = 0 + # If we are finishing the writeback, clear the busy bit. The write_ methods update the + # bits on commit, so the current instruction still reads the old + # value. 
+ self.mai_status.write_busy_bit(False) + + ###################### + # Accelerator update # + ###################### + # Advance the accelerator pipeline. + self._accelerator().step() + + ################# + # Start logic # + ################# + # Start a new operation if start bit was set in last cycle + if self.mai_ctrl.read_start_bit(): + # Begin pushing inputs in the dispatch logic + self.is_dispatching = True + # Immediately set the busy bit such that the current instruction reads it as set. + self.mai_status.set_busy_bit(True) + # Immediately reset the ready bit such that the current instruction reads it as reset + # and any configuration change check does not allow changing the operation type. + self.mai_status.set_ready_bit(False) + # Immediately reset the start bit such that it always reads zero. + self.mai_ctrl.set_start_bit(False) + + ################## + # Dispatch logic # + ################## + if self.is_dispatching: + self._accelerator().push(self.mai_in0_s0.read_32bit_unsigned(self._dispatch_idx), + self.mai_in0_s1.read_32bit_unsigned(self._dispatch_idx), + self.mai_in1_s0.read_32bit_unsigned(self._dispatch_idx), + self.mai_in1_s1.read_32bit_unsigned(self._dispatch_idx)) + self._dispatch_idx += 1 + + # Detect if we have finished dispatching + if self._dispatch_idx >= 8: + self._dispatch_idx = 0 + self.is_dispatching = False + # Set the ready bit at the end of this cycle. This indicates that new inputs can be + # accepted. 
+ self.mai_status.write_ready_bit(True) + + def is_busy(self) -> bool: + '''Returns whether the MAI is currently busy processing an operation.''' + return self.mai_status.read_busy_bit() + + def is_ready(self) -> bool: + '''Returns whether the MAI is ready to accept new inputs.''' + return self.mai_status.read_ready_bit() + + def ready_for_inputs(self) -> bool: + return self.is_ready() + + def ready_to_start(self) -> bool: + return not self.is_busy() + + def is_valid_ctrl_change(self, value: int) -> bool: + '''Return whether writing value to the MAI_CTRL CSR is currently allowed.''' + # Starting is only allowed if MAI is ready. + if self.mai_ctrl.would_set_start_bit(value) and not self.ready_to_start(): + return False + + # We only allow setting the operation to valid options. + if not self.mai_ctrl.is_valid_operation(value): + return False + + # Changing the operation is only allowed if MAI is not busy / no operation is ongoing. + if self.mai_ctrl.would_change_op(value) and self.is_busy(): + return False + + return True diff --git a/hw/ip/otbn/dv/otbnsim/sim/state.py b/hw/ip/otbn/dv/otbnsim/sim/state.py index cfb661620657a..be62382ad1fd0 100644 --- a/hw/ip/otbn/dv/otbnsim/sim/state.py +++ b/hw/ip/otbn/dv/otbnsim/sim/state.py @@ -16,6 +16,7 @@ from .gpr import GPRs from .kmac import Kmac from .loop import LoopStack +from .mai import MaskingAcceleratorInterface from .reg import RegFile from .trace import Trace, TracePC from .wsr import WSRFile @@ -202,6 +203,9 @@ def __init__(self) -> None: # random data). 
self.edn_seen_running = False + # The masking accelerator interface (MAI) handles the accelerators + self.mai = MaskingAcceleratorInterface(self.csrs, self.wsrs) + def get_next_pc(self) -> int: if self._pc_next_override is not None: return self._pc_next_override @@ -311,6 +315,7 @@ def step(self, handle_injected_error: bool) -> None: self.ext_regs.step() self._urnd_client.step() self.kmac.step() + self.mai.step() def commit(self, sim_stalled: bool) -> None: if self._time_to_imem_invalidation is not None: diff --git a/hw/ip/otbn/dv/otbnsim/sim/trace.py b/hw/ip/otbn/dv/otbnsim/sim/trace.py index d83065c0428f1..0f71b22e083f5 100644 --- a/hw/ip/otbn/dv/otbnsim/sim/trace.py +++ b/hw/ip/otbn/dv/otbnsim/sim/trace.py @@ -31,7 +31,7 @@ def hex_value(value: Optional[int], bit_width: int) -> str: '''Render a hex value in the format expected by RTL tracing''' if bit_width == 32: if value is None: - return '0x' + 'x' * 8 + return '0x' + 'X' * 8 else: return '{:#010x}'.format(value) diff --git a/hw/ip/otbn/dv/otbnsim/sim/wsr.py b/hw/ip/otbn/dv/otbnsim/sim/wsr.py index 5afc82c907155..90757376e447a 100644 --- a/hw/ip/otbn/dv/otbnsim/sim/wsr.py +++ b/hw/ip/otbn/dv/otbnsim/sim/wsr.py @@ -240,6 +240,52 @@ def write_unsigned(self, value: int) -> None: return +class MaiOutputWSR(ISPR): + def __init__(self, name: str) -> None: + super().__init__(name, 256) + + def write_unsigned(self, value: int) -> None: + # Writes are ignored + return + + def set_unsigned(self, value: int) -> None: + '''Sets a value that can be read by a future `read_unsigned`. + + This takes effect immediately and is used to model a write from the + MAI. This is used by the simulation environment to provide a value + that is later read by `read_unsigned` and doesn't relate to instruction + execution (e.g. in RTL the MAI will update this register when a new + result is available). Note that we do still report the change until the + next commit. 
+ ''' + assert 0 <= value < (1 << 256) + self._value = value + self._next_value = value + self._pending_write = True + + def set_32bit_unsigned(self, value: int, index: int) -> None: + '''Sets the 32-bit chunk at the given index to the unsigned value. + The index 0 corresponds to bits [31:0], index 1 to bits [63:32], etc. + ''' + assert 0 <= value < (1 << 32) + assert 0 <= index < 8 + current = self.read_unsigned() + mask = ((1 << 32) - 1) << (index * 32) + new_value = (current & ~mask) | (value << (index * 32)) + assert 0 <= new_value < (1 << 256) + self.set_unsigned(new_value) + + +class MaiInputWSR(ISPR): + def __init__(self, name: str) -> None: + super().__init__(name, 256) + + def read_32bit_unsigned(self, index: int) -> int: + assert 0 <= index < 8 + mask = (1 << 32) - 1 + return (self.read_unsigned() >> (32 * index)) & mask + + class WSRFile: '''A model of the WSR file''' def __init__(self, ext_regs: OTBNExtRegs) -> None: @@ -255,6 +301,12 @@ def __init__(self, ext_regs: OTBNExtRegs) -> None: self.KeyS1L = KeyWSR('KeyS1L', 0, self.KeyS1) self.KeyS1H = KeyWSR('KeyS1H', 256, self.KeyS1) self.KMAC_DATA = KmacDataWSRs(['KMAC_DATA_S0', 'KMAC_DATA_S1']) + self.MaiResS0 = MaiOutputWSR('MaiResS0') + self.MaiResS1 = MaiOutputWSR('MaiResS1') + self.MaiIn0S0 = MaiInputWSR('MaiIn0S0') + self.MaiIn0S1 = MaiInputWSR('MaiIn0S1') + self.MaiIn1S0 = MaiInputWSR('MaiIn1S0') + self.MaiIn1S1 = MaiInputWSR('MaiIn1S1') self._by_idx = { 0: self.MOD, @@ -267,6 +319,12 @@ def __init__(self, ext_regs: OTBNExtRegs) -> None: 7: self.KeyS1H, 8: self.KMAC_DATA.shares[0], 9: self.KMAC_DATA.shares[1], + 10: self.MaiResS0, + 11: self.MaiResS1, + 12: self.MaiIn0S0, + 13: self.MaiIn0S1, + 14: self.MaiIn1S0, + 15: self.MaiIn1S1, } def on_start(self) -> None: @@ -329,6 +387,12 @@ def commit(self) -> None: self.KeyS0.commit() self.KeyS1.commit() self.KMAC_DATA.commit() + self.MaiResS0.commit() + self.MaiResS1.commit() + self.MaiIn0S0.commit() + self.MaiIn0S1.commit() + self.MaiIn1S0.commit() + 
self.MaiIn1S1.commit() def abort(self) -> None: self.MOD.abort() @@ -340,6 +404,15 @@ def abort(self) -> None: # instruction itself gets aborted. self.KeyS0.commit() self.KeyS1.commit() + # We commit changes to the MAI output registers from outside, even if + # the instruction itself gets aborted (there is never a write to these + # WSRs from an instruction). + self.MaiResS0.commit() + self.MaiResS1.commit() + self.MaiIn0S0.abort() + self.MaiIn0S1.abort() + self.MaiIn1S0.abort() + self.MaiIn1S1.abort() def changes(self) -> List[Trace]: ret: List[Trace] = [] @@ -349,6 +422,12 @@ def changes(self) -> List[Trace]: ret += self.KeyS0.changes() ret += self.KeyS1.changes() ret += self.KMAC_DATA.changes() + ret += self.MaiResS0.changes() + ret += self.MaiResS1.changes() + ret += self.MaiIn0S0.changes() + ret += self.MaiIn0S1.changes() + ret += self.MaiIn1S0.changes() + ret += self.MaiIn1S1.changes() return ret def set_sideload_keys(self, @@ -360,3 +439,9 @@ def set_sideload_keys(self, def wipe(self) -> None: self.MOD.write_invalid() self.ACC.write_invalid() + self.MaiResS0.write_invalid() + self.MaiResS1.write_invalid() + self.MaiIn0S0.write_invalid() + self.MaiIn0S1.write_invalid() + self.MaiIn1S0.write_invalid() + self.MaiIn1S1.write_invalid() diff --git a/sw/otbn/mai/mai_test.s b/sw/otbn/mai/mai_test.s new file mode 100644 index 0000000000000..33db8db5cb4ea --- /dev/null +++ b/sw/otbn/mai/mai_test.s @@ -0,0 +1,78 @@ +/* Copyright lowRISC contributors (OpenTitan project). */ +/* Licensed under the Apache License, Version 2.0, see LICENSE for details. */ +/* SPDX-License-Identifier: Apache-2.0 */ + +/** +* Simple program to showcase how to use the MAI interface. +*/ + +.section .text.start +main: + /* Load modulus */ + li x2, 0 + la x3, mod32 + bn.lid x2, 0(x3) + bn.wsrw MOD, w0 + + /* Configure MAI - select B2A */ + /* B2A requires the value 0x1 in the field operation. 
This field is in bits[2:1] */ + li x2, 0x2 + csrrs x0, MAI_CTRL, x2 + csrrs x3, MAI_STATUS, x0 /* Optional, just to populate the trace */ + + /* Write data into input WSRs */ + bn.wsrw MAI_IN0_S0, w1 + bn.wsrw MAI_IN0_S1, w2 + + /* Start conversion by writing the start bit */ + li x2, 0x1 + csrrs x3, MAI_STATUS, x0 /* Optional, just to populate the trace */ + csrrs x0, MAI_CTRL, x2 + /* Read the status register to populate the trace to see when the state changes. */ + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + csrrs x3, MAI_STATUS, x0 + /* Poll busy bit - use this in production code */ +_poll_busy: + csrrs x2, MAI_STATUS, x0 + andi x2, x2, 0x1 + bne x2, x0, _poll_busy + + /* Read results from output WSRs */ + bn.wsrr w20, MAI_RES_S0 + bn.wsrr w21, MAI_RES_S1 + + ecall + +.section .data + +/* + mod32 = 8380417 +*/ +mod32: + .word 0x007fe001 + .word 0x00000000 + .word 0x00000000 + .word 0x00000000 + .word 0x00000000 + .word 0x00000000 + .word 0x00000000 + .word 0x00000000
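
Note on the `MAI_CTRL` values used in `mai_test.s`: the constants `0x2` (select B2A) and `0x1` (start) follow from the field layout in the README (bit 0 is `MAI_START`, bits 2:1 are `MAI_OPERATION`, with A2B = 0, B2A = 1, secAdd = 2). The snippet below is only a minimal sketch of that encoding. The names `encode_mai_ctrl`, `decode_mai_ctrl` and the constants are hypothetical and are not part of this PR or of otbnsim.

```python
from typing import Tuple

# Field layout taken from the README section of this PR.
# All names below are hypothetical helpers, not otbnsim APIs.
MAI_START_BIT = 0   # bit 0: MAI_START
MAI_OP_SHIFT = 1    # bits 2:1: MAI_OPERATION
MAI_OP_MASK = 0b11

MAI_OPERATIONS = {"A2B": 0, "B2A": 1, "secAdd": 2}


def encode_mai_ctrl(operation: str, start: bool = False) -> int:
    '''Pack an operation name and the start flag into a MAI_CTRL value.'''
    op = MAI_OPERATIONS[operation]  # raises KeyError for unknown operations
    return (op << MAI_OP_SHIFT) | (int(start) << MAI_START_BIT)


def decode_mai_ctrl(value: int) -> Tuple[bool, int]:
    '''Split a MAI_CTRL value into (start flag, operation field).

    Bits 31:3 are reserved and are ignored here.
    '''
    start = bool((value >> MAI_START_BIT) & 1)
    op = (value >> MAI_OP_SHIFT) & MAI_OP_MASK
    return start, op
```

Because the test program uses `csrrs`, which ORs the source register into the CSR, the second write of `0x1` sets the start bit without clearing the operation selected by the earlier `0x2`. The combined value corresponds to `encode_mai_ctrl("B2A", start=True)`. Whether the start bit reads back as set afterwards is not stated in this diff.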