|
| 1 | +# Orca Shared Memory AES: Per-Bank T-table Replication to Eliminate Bank Conflicts |
| 2 | + |
| 3 | +## Background |
| 4 | + |
| 5 | +Standard AES intermediate rounds (SubBytes + ShiftRows + MixColumns) can be |
| 6 | +merged into four uint32_t lookups from a 256-entry T-table plus XOR. On a GPU, |
| 7 | +shared memory has 32 banks. When threads in the same warp access different |
| 8 | +addresses in the same bank, accesses serialize (bank conflict). |
| 9 | + |
| 10 | +Orca (IEEE S&P 2024) proposes replicating the T-table 32 times |
| 11 | +(`T0[256][32]`), so that each thread reads from column `threadIdx.x & 31`. |
| 12 | +Since 32 threads always land on 32 distinct banks, bank conflicts are |
| 13 | +eliminated. |
| 14 | + |
| 15 | +This document records the implementation of `Aes128ShmSoft`, a device-only PRG |
| 16 | +class that applies this optimization, and its benchmark comparison against the |
| 17 | +existing `Aes128Soft`. |
| 18 | + |
| 19 | +## Design |
| 20 | + |
| 21 | +### Class interface |
| 22 | + |
| 23 | +```cpp |
| 24 | +namespace fss::prg { |
| 25 | +template <int mul> |
| 26 | +class Aes128ShmSoft { |
| 27 | +public: |
| 28 | + struct ShmContext { |
| 29 | + uint32_t t0[256][32]; // T-table replicated 32 times, ~32 KB |
| 30 | + uint8_t sbox[256]; // last-round S-box, 256 B |
| 31 | + }; |
| 32 | + |
| 33 | + __device__ static void LoadShm(ShmContext &ctx); |
| 34 | + __device__ Aes128ShmSoft(const ShmContext &ctx, const uint8_t keys[][16]); |
| 35 | + __device__ cuda::std::array<int4, mul> Gen(int4 seed); |
| 36 | +}; |
| 37 | +} |
| 38 | +``` |
| 39 | +
|
| 40 | +### Comparison with `Aes128Soft` |
| 41 | +
|
| 42 | +| Aspect | Aes128Soft | Aes128ShmSoft | |
| 43 | +|--------|-----------|---------------| |
| 44 | +| AES algorithm | T-table: 4x u32 lookup + rotation + XOR | Same, but from 32-copy table | |
| 45 | +| Table storage | Shared memory, single copy | Shared memory, 32 copies | |
| 46 | +| Bank conflict | Possible | Eliminated | |
| 47 | +| Host support | `__host__ __device__` | `__device__` only | |
| 48 | +| Construction | Needs external te0/sbox pointers | Needs external ShmContext | |
| 49 | +| Shared memory | ~1 KB (te0[256] + sbox[256]) | ~33 KB (t0[256][32] + sbox[256]) | |
| 50 | +| Byte rotation | Shift-based (`RotWord8/16/24`) | `__byte_perm` intrinsic | |
| 51 | +| Byte swap | `reinterpret_cast` to `uint8_t*` | `__byte_perm(val, 0, 0x0123)` | |
| 52 | +
|
| 53 | +### Kernel usage pattern |
| 54 | +
|
| 55 | +```cuda |
| 56 | +__global__ void ExampleKernel(...) { |
| 57 | + // 1. Declare and load shared memory (all threads participate) |
| 58 | + __shared__ fss::prg::Aes128ShmSoft<2>::ShmContext aes_shm; |
| 59 | + fss::prg::Aes128ShmSoft<2>::LoadShm(aes_shm); |
| 60 | + __syncthreads(); // must be before any early return |
| 61 | +
|
| 62 | + int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 63 | + if (tid >= kN) return; |
| 64 | +
|
| 65 | + // 2. Per-thread PRG construction and use |
| 66 | + fss::prg::Aes128ShmSoft<2> prg(aes_shm, kAesSoftKeys); |
| 67 | + fss::Dpf<in_bits, Group, fss::prg::Aes128ShmSoft<2>, uint> dpf{prg}; |
| 68 | + // ... |
| 69 | +} |
| 70 | +``` |
| 71 | + |
| 72 | +`__syncthreads()` must precede the early return guard. If placed after it, |
| 73 | +threads that exit early cause the remaining threads to deadlock. |
| 74 | + |
| 75 | +### Shared memory budget |
| 76 | + |
| 77 | +- `t0[256][32]` = 256 * 32 * 4 B = 32768 B |
| 78 | +- `sbox[256]` = 256 B |
| 79 | +- Total: 33024 B |
| 80 | + |
| 81 | +This fits on all modern GPUs (V100: 96 KB, A6000/A100: 100+ KB, H100: 228 KB) |
| 82 | +but reduces occupancy compared to the 1 KB footprint of `Aes128Soft`. |
| 83 | + |
| 84 | +## Implementation details |
| 85 | + |
| 86 | +### T-table lookup (middle rounds) |
| 87 | + |
| 88 | +Each of the 9 middle rounds computes four state words via: |
| 89 | + |
| 90 | +```cpp |
| 91 | +int wTid = threadIdx.x & 31; |
| 92 | +uint32_t t0 = ctx_->t0[s0 >> 24][wTid] |
| 93 | + ^ RotRight8(ctx_->t0[(s1 >> 16) & 0xff][wTid]) |
| 94 | + ^ RotRight16(ctx_->t0[(s2 >> 8) & 0xff][wTid]) |
| 95 | + ^ RotRight24(ctx_->t0[s3 & 0xff][wTid]) |
| 96 | + ^ rk[r * 4]; |
| 97 | +// ... similarly for t1, t2, t3 |
| 98 | +``` |
| 99 | + |
| 100 | +`RotRight8/16/24` use `__byte_perm(x, x, selector)`: |
| 101 | + |
| 102 | +| Function | Selector | Effect | |
| 103 | +|----------|----------|--------| |
| 104 | +| RotRight8 | `0x0321` | `(x >> 8) \| (x << 24)` | |
| 105 | +| RotRight16 | `0x1032` | `(x >> 16) \| (x << 16)` | |
| 106 | +| RotRight24 | `0x2103` | `(x >> 24) \| (x << 8)` | |
| 107 | + |
| 108 | +### Byte order conversion |
| 109 | + |
| 110 | +`int4` stores four little-endian ints. AES state columns are big-endian u32. |
| 111 | +The conversion in both directions is `__byte_perm(val, 0, 0x0123)` (byte |
| 112 | +reverse). |
| 113 | + |
| 114 | +### Key expansion |
| 115 | + |
| 116 | +Reuses the existing `aes_detail::KeyExpansion` (byte-level), then converts the |
| 117 | +176-byte round key array into 44 big-endian `uint32_t` values at construction |
| 118 | +time. This runs once per thread and is not on the hot path. |
| 119 | + |
| 120 | +### Register usage |
| 121 | + |
| 122 | +Compiled for sm_75 (local) and sm_52 (remote): |
| 123 | + |
| 124 | +| Kernel | Registers | Stack | Spill | |
| 125 | +|--------|-----------|-------|-------| |
| 126 | +| `AesShmPrgTestKernel` (sm_75) | 72 | 624 B | 0 | |
| 127 | +| `DpfEvalKernelAesShm` (sm_52) | 56 | 608 B | 0 | |
| 128 | +| `DpfGenKernelAesShm` (sm_52) | 73 | 624 B | 0 | |
| 129 | + |
| 130 | +No spills. The 624 B stack frame comes from `KeyExpansion`'s 176-byte |
| 131 | +temporary buffer and the `round_keys_[2][44]` member (352 B). |
| 132 | + |
| 133 | +## Correctness |
| 134 | + |
| 135 | +A test (`src/aes128_shm_soft_test.cu`) runs `Aes128ShmSoft<2>::Gen()` on GPU |
| 136 | +for 1024 random seeds and compares each output byte-for-byte against |
| 137 | +`Aes128Soft<2>::Gen()` on host with the same keys. The test passes, |
| 138 | +confirming identical AES encryption results. |
| 139 | + |
| 140 | +## Benchmark results |
| 141 | + |
| 142 | +Machine: 4x NVIDIA RTX A6000 (compute capability 8.6, Ampere), CUDA 12.6. |
| 143 | +Block size 256, 1M DPF instances (in_bits=20, UintGroup). |
| 144 | + |
| 145 | +``` |
| 146 | +BM_DpfEval_Uint_AesSoft/20 23.9 ms 43.8 M items/s |
| 147 | +BM_DpfEval_Uint_AesShmSoft/20 23.7 ms 44.2 M items/s (+0.7%) |
| 148 | +BM_DpfGen_Uint_AesShmSoft/20 46.4 ms 22.6 M items/s |
| 149 | +``` |
| 150 | + |
| 151 | +The 32-copy T-table shows a marginal ~0.7% improvement for DPF Eval over the |
| 152 | +single-copy version. |
| 153 | + |
| 154 | +### Analysis |
| 155 | + |
| 156 | +The lack of significant speedup likely comes from several factors: |
| 157 | + |
| 158 | +1. The single-copy `Aes128Soft` already uses shared memory for its T-table. |
| 159 | + With 256 threads per block, the 32 threads within each warp access T-table |
| 160 | + entries determined by AES state bytes, which are pseudo-random. Random |
| 161 | + accesses across 256 entries in 32 banks have a low collision probability by |
| 162 | + chance (~2-3 conflicts per round on average), so the baseline already has |
| 163 | + limited bank conflict overhead. |
| 164 | + |
| 165 | +2. The 33 KB shared memory footprint of `Aes128ShmSoft` reduces occupancy. On |
| 166 | + A6000 (100 KB shared memory per SM), this allows at most 3 blocks per SM |
| 167 | + vs. potentially 4+ for the 1 KB `Aes128Soft`. Lower occupancy reduces the |
| 168 | + GPU's ability to hide memory latency through warp switching. |
| 169 | + |
| 170 | +3. The DPF computation is not purely AES-bound. Each DPF level calls |
| 171 | + `prg.Gen()` once but also does control-bit logic, XOR corrections, and |
| 172 | + memory loads for correction words. AES throughput improvements are diluted |
| 173 | + by these surrounding operations. |
| 174 | + |
| 175 | +4. Orca targets a different workload (large-batch AES encryption with many |
| 176 | + rounds per thread) where bank conflicts dominate. In DPF evaluation, each |
| 177 | + thread does 20 AES calls (one per tree level) with other work interleaved, |
| 178 | + which changes the performance profile. |
| 179 | + |
| 180 | +## Files |
| 181 | + |
| 182 | +- `include/fss/prg/aes128_shm_soft.cuh` -- implementation |
| 183 | +- `src/aes128_shm_soft_test.cu` -- correctness test |
| 184 | +- `src/bench_gpu.cu` -- benchmark (AesShmSoft kernels and registration) |
| 185 | +- `doc/plans/2026-03-10-aes128-shm-soft.md` -- original implementation plan |
| 186 | + |
| 187 | +## References |
| 188 | + |
| 189 | +- Orca: N. Jawalkar, K. Gupta, A. Bhatia, N. Chandran, D. Gupta, R. Sharma. |
| 190 | + "Orca: FSS-based Secure Training and Inference with GPUs." IEEE S&P 2024. |
| 191 | +- EzPC reference implementation: |
| 192 | + https://github.com/mpc-msri/EzPC/blob/master/GPU-MPC/fss/gpu_aes_shm.cu |
0 commit comments