Commit f239dc2
Add doc for Orca
1 parent 40b7b4b

File tree

2 files changed, +347 -0 lines changed

doc/orca-shm-aes-bank-conflict.md

Lines changed: 192 additions & 0 deletions
# Orca Shared Memory AES: Per-Bank T-table Replication to Eliminate Bank Conflicts

## Background

Standard AES intermediate rounds (SubBytes + ShiftRows + MixColumns) can be
merged into four uint32_t lookups from a 256-entry T-table plus XOR. On a GPU,
shared memory has 32 banks. When threads in the same warp access different
addresses in the same bank, the accesses serialize (a bank conflict).

Orca (IEEE S&P 2024) proposes replicating the T-table 32 times
(`T0[256][32]`), so that each thread reads from column `threadIdx.x & 31`.
Because each 32-word row spans all 32 banks exactly once, column `j` always
maps to bank `j`; the 32 threads of a warp therefore land on 32 distinct
banks, and bank conflicts are eliminated.

This document records the implementation of `Aes128ShmSoft`, a device-only PRG
class that applies this optimization, and its benchmark comparison against the
existing `Aes128Soft`.
18+
19+
## Design
20+
21+
### Class interface
22+
23+
```cpp
24+
namespace fss::prg {
25+
template <int mul>
26+
class Aes128ShmSoft {
27+
public:
28+
struct ShmContext {
29+
uint32_t t0[256][32]; // T-table replicated 32 times, ~32 KB
30+
uint8_t sbox[256]; // last-round S-box, 256 B
31+
};
32+
33+
__device__ static void LoadShm(ShmContext &ctx);
34+
__device__ Aes128ShmSoft(const ShmContext &ctx, const uint8_t keys[][16]);
35+
__device__ cuda::std::array<int4, mul> Gen(int4 seed);
36+
};
37+
}
38+
```

### Comparison with `Aes128Soft`

| Aspect | `Aes128Soft` | `Aes128ShmSoft` |
|--------|--------------|-----------------|
| AES algorithm | T-table: 4x u32 lookup + rotation + XOR | Same, but from the 32-copy table |
| Table storage | Shared memory, single copy | Shared memory, 32 copies |
| Bank conflicts | Possible | Eliminated |
| Host support | `__host__ __device__` | `__device__` only |
| Construction | Needs external te0/sbox pointers | Needs external `ShmContext` |
| Shared memory | ~1 KB (te0[256] + sbox[256]) | ~33 KB (t0[256][32] + sbox[256]) |
| Byte rotation | Shift-based (`RotWord8/16/24`) | `__byte_perm` intrinsic |
| Byte swap | `reinterpret_cast` to `uint8_t*` | `__byte_perm(val, 0, 0x0123)` |

### Kernel usage pattern

```cuda
__global__ void ExampleKernel(...) {
  // 1. Declare and load shared memory (all threads participate).
  __shared__ fss::prg::Aes128ShmSoft<2>::ShmContext aes_shm;
  fss::prg::Aes128ShmSoft<2>::LoadShm(aes_shm);
  __syncthreads();  // must come before any early return

  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= kN) return;

  // 2. Per-thread PRG construction and use.
  fss::prg::Aes128ShmSoft<2> prg(aes_shm, kAesSoftKeys);
  fss::Dpf<in_bits, Group, fss::prg::Aes128ShmSoft<2>, uint> dpf{prg};
  // ...
}
```

`__syncthreads()` must precede the early-return guard: if it were placed after
the guard, threads that exit early would never reach the barrier, deadlocking
the remaining threads of the block.

### Shared memory budget

- `t0[256][32]` = 256 * 32 * 4 B = 32768 B
- `sbox[256]` = 256 B
- Total: 33024 B

This fits on all modern GPUs (V100: 96 KB, A6000/A100: 100+ KB, H100: 228 KB
per SM) but reduces occupancy compared to the ~1 KB footprint of
`Aes128Soft`.

## Implementation details

### T-table lookup (middle rounds)

Each of the 9 middle rounds computes four state words via:

```cpp
int wTid = threadIdx.x & 31;
uint32_t t0 = ctx_->t0[s0 >> 24][wTid]
            ^ RotRight8(ctx_->t0[(s1 >> 16) & 0xff][wTid])
            ^ RotRight16(ctx_->t0[(s2 >> 8) & 0xff][wTid])
            ^ RotRight24(ctx_->t0[s3 & 0xff][wTid])
            ^ rk[r * 4];
// ... similarly for t1, t2, t3
```

`RotRight8/16/24` use `__byte_perm(x, x, selector)`:

| Function | Selector | Effect |
|----------|----------|--------|
| `RotRight8` | `0x0321` | `(x >> 8) \| (x << 24)` |
| `RotRight16` | `0x1032` | `(x >> 16) \| (x << 16)` |
| `RotRight24` | `0x2103` | `(x >> 24) \| (x << 8)` |

### Byte order conversion

`int4` stores four little-endian ints. AES state columns are big-endian u32.
The conversion in both directions is `__byte_perm(val, 0, 0x0123)` (byte
reverse).

### Key expansion

Reuses the existing `aes_detail::KeyExpansion` (byte-level), then converts the
176-byte round key array into 44 big-endian `uint32_t` values at construction
time. This runs once per thread and is not on the hot path.

### Register usage

Compiled for sm_75 (local) and sm_52 (remote):

| Kernel | Registers | Stack | Spill |
|--------|-----------|-------|-------|
| `AesShmPrgTestKernel` (sm_75) | 72 | 624 B | 0 |
| `DpfEvalKernelAesShm` (sm_52) | 56 | 608 B | 0 |
| `DpfGenKernelAesShm` (sm_52) | 73 | 624 B | 0 |

No spills. The 624 B stack frame comes from `KeyExpansion`'s 176-byte
temporary buffer and the 352 B `round_keys_[2][44]` member.

## Correctness

A test (`src/aes128_shm_soft_test.cu`) runs `Aes128ShmSoft<2>::Gen()` on the
GPU for 1024 random seeds and compares each output byte-for-byte against
`Aes128Soft<2>::Gen()` on the host with the same keys. The test passes,
confirming identical AES encryption results.

## Benchmark results

Machine: 4x NVIDIA RTX A6000 (compute capability 8.6, Ampere), CUDA 12.6.
Block size 256, 1M DPF instances (in_bits=20, UintGroup).

```
BM_DpfEval_Uint_AesSoft/20      23.9 ms   43.8 M items/s
BM_DpfEval_Uint_AesShmSoft/20   23.7 ms   44.2 M items/s  (+0.7%)
BM_DpfGen_Uint_AesShmSoft/20    46.4 ms   22.6 M items/s
```

The 32-copy T-table shows only a marginal ~0.7% improvement for DPF Eval over
the single-copy version.

### Analysis

The lack of a significant speedup likely comes from several factors:

1. The single-copy `Aes128Soft` already uses shared memory for its T-table.
   With 256 threads per block, the 32 threads within each warp access T-table
   entries determined by AES state bytes, which are pseudo-random. Random
   accesses across 256 entries spread over 32 banks collide only occasionally
   (roughly 2-3 extra replay cycles per round on average), so the baseline
   already has limited bank conflict overhead.

2. The 33 KB shared memory footprint of `Aes128ShmSoft` reduces occupancy. On
   the A6000 (100 KB shared memory per SM), it allows at most 3 blocks per SM
   vs. potentially 4+ for the ~1 KB `Aes128Soft`. Lower occupancy reduces the
   GPU's ability to hide memory latency through warp switching.

3. The DPF computation is not purely AES-bound. Each DPF level calls
   `prg.Gen()` once but also does control-bit logic, XOR corrections, and
   memory loads for correction words, so AES throughput improvements are
   diluted by these surrounding operations.

4. Orca targets a different workload (large-batch AES encryption with many
   rounds per thread) where bank conflicts dominate. In DPF evaluation, each
   thread does 20 AES calls (one per tree level) with other work interleaved,
   which changes the performance profile.

## Files

- `include/fss/prg/aes128_shm_soft.cuh` -- implementation
- `src/aes128_shm_soft_test.cu` -- correctness test
- `src/bench_gpu.cu` -- benchmark (AesShmSoft kernels and registration)
- `doc/plans/2026-03-10-aes128-shm-soft.md` -- original implementation plan

## References

- N. Jawalkar, K. Gupta, A. Bhatia, N. Chandran, D. Gupta, R. Sharma.
  "Orca: FSS-based Secure Training and Inference with GPUs." IEEE S&P 2024.
- EzPC reference implementation:
  https://github.com/mpc-msri/EzPC/blob/master/GPU-MPC/fss/gpu_aes_shm.cu
Lines changed: 155 additions & 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b7d0dac..d392187 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -78,6 +78,10 @@ if(BUILD_TESTING)
   add_executable(grotto_dcf_test src/grotto_dcf_test.cu)
   target_link_libraries(grotto_dcf_test GTest::gtest_main fss)
   gtest_discover_tests(grotto_dcf_test)
+
+  add_executable(aes_shm_soft_test src/aes128_shm_soft_test.cu)
+  target_link_libraries(aes_shm_soft_test GTest::gtest_main fss)
+  gtest_discover_tests(aes_shm_soft_test)
 endif()

 option(BUILD_BENCH "Build benchmarks" OFF)
diff --git a/src/bench_gpu.cu b/src/bench_gpu.cu
index afe9690..8c81686 100644
--- a/src/bench_gpu.cu
+++ b/src/bench_gpu.cu
@@ -10,6 +10,7 @@
 #include <fss/group/uint.cuh>
 #include <fss/prg/chacha.cuh>
 #include <fss/prg/aes128_mmo_soft.cuh>
+#include <fss/prg/aes128_shm_soft.cuh>
 #include <fss/hash/blake3.cuh>

 constexpr int kN = 1 << 20;
@@ -116,6 +117,43 @@ __global__ void DpfEvalKernelAes(int4 *ys, bool party, const int4 *seeds,
   ys[tid] = dpf.Eval(party, seeds[tid], cws + tid * (in_bits + 1), xs[tid]);
 }

+// --- DPF Kernels (Aes128ShmSoft<2>) ---
+
+template <int in_bits, typename Group>
+__global__ void DpfGenKernelAesShm(
+    typename fss::Dpf<in_bits, Group, fss::prg::Aes128ShmSoft<2>, uint>::Cw *cws, const int4 *seeds,
+    const uint *alphas, const int4 *betas) {
+  __shared__ fss::prg::Aes128ShmSoft<2>::ShmContext aes_shm;
+  fss::prg::Aes128ShmSoft<2>::LoadShm(aes_shm);
+  __syncthreads();
+
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid >= kN) return;
+
+  fss::prg::Aes128ShmSoft<2> prg(aes_shm, kAesSoftKeys);
+  fss::Dpf<in_bits, Group, fss::prg::Aes128ShmSoft<2>, uint> dpf{prg};
+
+  int4 s[2] = {seeds[tid * 2], seeds[tid * 2 + 1]};
+  dpf.Gen(cws + tid * (in_bits + 1), s, alphas[tid], betas[tid]);
+}
+
+template <int in_bits, typename Group>
+__global__ void DpfEvalKernelAesShm(int4 *ys, bool party, const int4 *seeds,
+    const typename fss::Dpf<in_bits, Group, fss::prg::Aes128ShmSoft<2>, uint>::Cw *cws,
+    const uint *xs) {
+  __shared__ fss::prg::Aes128ShmSoft<2>::ShmContext aes_shm;
+  fss::prg::Aes128ShmSoft<2>::LoadShm(aes_shm);
+  __syncthreads();
+
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid >= kN) return;
+
+  fss::prg::Aes128ShmSoft<2> prg(aes_shm, kAesSoftKeys);
+  fss::Dpf<in_bits, Group, fss::prg::Aes128ShmSoft<2>, uint> dpf{prg};
+
+  ys[tid] = dpf.Eval(party, seeds[tid], cws + tid * (in_bits + 1), xs[tid]);
+}
+
 // --- DCF Kernels (ChaCha<4>) ---

 template <int in_bits, typename Group>
@@ -375,6 +413,74 @@ static void BM_DpfEvalAes(benchmark::State &state) {
   cudaFree(d_cws);
 }

+// --- DPF Eval GPU benchmark (Aes128ShmSoft<2>) ---
+
+template <int in_bits, typename Group>
+static void BM_DpfEvalAesShm(benchmark::State &state) {
+  using DpfType = fss::Dpf<in_bits, Group, fss::prg::Aes128ShmSoft<2>, uint>;
+  GpuData data;
+
+  typename DpfType::Cw *d_cws;
+  CUDA_CHECK(cudaMalloc(&d_cws, sizeof(typename DpfType::Cw) * (in_bits + 1) * kN));
+
+  DpfGenKernelAesShm<in_bits, Group>
+      <<<kNumBlocks, kThreadsPerBlock>>>(d_cws, data.d_seeds, data.d_alphas, data.d_betas);
+  CUDA_CHECK(cudaDeviceSynchronize());
+
+  for (auto _ : state) {
+    cudaEvent_t start, stop;
+    cudaEventCreate(&start);
+    cudaEventCreate(&stop);
+    cudaEventRecord(start);
+
+    DpfEvalKernelAesShm<in_bits, Group>
+        <<<kNumBlocks, kThreadsPerBlock>>>(data.d_ys, false, data.d_seeds0, d_cws, data.d_xs);
+
+    cudaEventRecord(stop);
+    cudaEventSynchronize(stop);
+    float ms = 0;
+    cudaEventElapsedTime(&ms, start, stop);
+    state.SetIterationTime(ms / 1000.0);
+    cudaEventDestroy(start);
+    cudaEventDestroy(stop);
+  }
+  state.SetItemsProcessed(state.iterations() * kN);
+
+  cudaFree(d_cws);
+}
+
+// --- DPF Gen GPU benchmark (Aes128ShmSoft<2>) ---
+
+template <int in_bits, typename Group>
+static void BM_DpfGenAesShm(benchmark::State &state) {
+  using DpfType = fss::Dpf<in_bits, Group, fss::prg::Aes128ShmSoft<2>, uint>;
+  GpuData data;
+
+  typename DpfType::Cw *d_cws;
+  CUDA_CHECK(cudaMalloc(&d_cws, sizeof(typename DpfType::Cw) * (in_bits + 1) * kN));
+
+  for (auto _ : state) {
+    cudaEvent_t start, stop;
+    cudaEventCreate(&start);
+    cudaEventCreate(&stop);
+    cudaEventRecord(start);
+
+    DpfGenKernelAesShm<in_bits, Group>
+        <<<kNumBlocks, kThreadsPerBlock>>>(d_cws, data.d_seeds, data.d_alphas, data.d_betas);
+
+    cudaEventRecord(stop);
+    cudaEventSynchronize(stop);
+    float ms = 0;
+    cudaEventElapsedTime(&ms, start, stop);
+    state.SetIterationTime(ms / 1000.0);
+    cudaEventDestroy(start);
+    cudaEventDestroy(stop);
+  }
+  state.SetItemsProcessed(state.iterations() * kN);
+
+  cudaFree(d_cws);
+}
+
 // --- DCF Eval GPU benchmark (ChaCha<4>) ---

 template <int in_bits, typename Group>
@@ -608,6 +714,8 @@ BENCHMARK(BM_DpfEval<20, BytesGroup>)->Name("BM_DpfEval_Bytes/20")->UseManualTim

 // DPF other PRG
 BENCHMARK(BM_DpfEvalAes<20, UintGroup>)->Name("BM_DpfEval_Uint_AesSoft/20")->UseManualTime();
+BENCHMARK(BM_DpfEvalAesShm<20, UintGroup>)->Name("BM_DpfEval_Uint_AesShmSoft/20")->UseManualTime();
+BENCHMARK(BM_DpfGenAesShm<20, UintGroup>)->Name("BM_DpfGen_Uint_AesShmSoft/20")->UseManualTime();

 // DCF (ChaCha<4>)
 BENCHMARK(BM_DcfEval<20, UintGroup>)->Name("BM_DcfEval_Uint/20")->UseManualTime();
