diff --git a/Makefile b/Makefile
index 52eff32..835a31f 100644
--- a/Makefile
+++ b/Makefile
@@ -8,7 +8,7 @@ ifndef KEYSTONE_SDK_DIR
 endif
 
 CFLAGS = -Wall -Werror -fPIC -fno-builtin -std=c11 -g $(OPTIONS_FLAGS)
-SRCS = aes.c sha256.c boot.c interrupt.c printf.c syscall.c string.c linux_wrap.c io_wrap.c rt_util.c mm.c env.c freemem.c paging.c sbi.c merkle.c page_swap.c vm.c
+SRCS = aes.c sha256.c boot.c interrupt.c printf.c syscall.c string.c linux_wrap.c io_wrap.c rt_util.c mm.c env.c freemem.c paging.c sbi.c hmac.c page_swap.c vm.c
 ASM_SRCS = entry.S
 RUNTIME = eyrie-rt
 LINK = $(CROSS_COMPILE)ld
diff --git a/bench/Makefile b/bench/Makefile
new file mode 100644
index 0000000..8c2035f
--- /dev/null
+++ b/bench/Makefile
@@ -0,0 +1,29 @@
+CFLAGS := -I . -g -O3 -Wall -Wextra
+LDFLAGS := -lm
+BENCH_FLAGS := -DUSE_FREEMEM -DUSE_PAGING -DUSE_PAGE_HASH -DUSE_PAGE_CRYPTO -D__riscv_xlen=64
+
+.PHONY: benches clean
+
+benches: hmac crypto page_swap
+
+MODULES := hmac.o sha256.o aes.o page_swap.o freemem.o bench/bencher.o bench/crypto.o bench/hmac.o bench/page_swap.o
+OBJ_PREFIX := objs/
+
+$(addprefix $(OBJ_PREFIX),$(MODULES)) :: $(OBJ_PREFIX)%.o : ../%.c
+	mkdir -p $(shell dirname $@)
+	$(CC) $(CFLAGS) $(BENCH_FLAGS) -c -o $@ $^
+
+HMAC_MODULES := hmac.o sha256.o bench/bencher.o bench/hmac.o
+hmac: $(addprefix $(OBJ_PREFIX),$(HMAC_MODULES))
+	$(CC) $(CFLAGS) $(BENCH_FLAGS) $^ $(LDFLAGS) -o $@
+
+CRYPTO_MODULES := sha256.o aes.o bench/bencher.o bench/crypto.o
+crypto: $(addprefix $(OBJ_PREFIX),$(CRYPTO_MODULES))
+	$(CC) $(CFLAGS) $(BENCH_FLAGS) $^ $(LDFLAGS) -o $@
+
+PAGE_SWAP_MODULES := sha256.o aes.o hmac.o freemem.o page_swap.o bench/bencher.o bench/page_swap.o
+page_swap: $(addprefix $(OBJ_PREFIX),$(PAGE_SWAP_MODULES))
+	$(CC) $(CFLAGS) $(BENCH_FLAGS) $^ $(LDFLAGS) -o $@
+
+clean:
+	rm -rf hmac crypto page_swap objs/
diff --git a/bench/bencher.c b/bench/bencher.c
new file mode 100644
index 0000000..96832fa
--- /dev/null
+++ b/bench/bencher.c
@@ -0,0 +1,281 @@
+#define _GNU_SOURCE
+
+#include "bencher.h"
+
+#include <argp.h>
+#include <assert.h>
+#include <math.h>
+#include <regex.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <time.h>
+
+#define ERROR(...) \
fprintf(stderr, "[ERROR] " __VA_ARGS__) + +static char* +precise_time_str(double seconds) { + double precision; + const char* prec_name; + if (seconds < 1e-6) { + prec_name = "ns"; + precision = 1e9; + } else if (seconds < 1e-3) { + prec_name = "us"; + precision = 1e6; + } else if (seconds < 1) { + prec_name = "ms"; + precision = 1e3; + } else { + prec_name = "s"; + precision = 1; + } + + char* out; + double adjusted = seconds * precision; + asprintf(&out, "%f%s", trunc(adjusted * 1000) / 1000, prec_name); + return out; +} + +static int +measure_iters( + struct bench* bench, void* ctx, clock_t clocks, size_t chunk_size, + size_t* niters) { + int err; + size_t i = 0; + + err = bench->init(ctx); + if (err) { + ERROR("Failed to initialize benchmark `%s`!\n", bench->name); + return err; + } + + clock_t end_clock = clock() + clocks; + while (clock() < end_clock) { + for (size_t ci = 0; ci < chunk_size; ci++, i++) { + err = bench->iter(ctx); + if (err) { + ERROR( + "Failed to run iteration `%zu` of benchmark `%s`!\n", i, + bench->name); + return err; + } + } + } + + assert(niters); + *niters = i; + + err = bench->deinit(ctx); + if (err) { + ERROR("Failed to deinitialize benchmark `%s`!\n", bench->name); + return err; + } + + return 0; +} + +static int +run_bench( + struct bench* bench, void* ctx, size_t nchunks, size_t chunk_size, + double* chunk_mean, double* chunk_std) { + int err; + + err = bench->init(ctx); + if (err) { + ERROR("Failed to initialize benchmark `%s`!\n", bench->name); + return err; + } + + clock_t chunk_start, chunk_end; + size_t chunk; + double mean = 0, m2 = 0; + + for (chunk = 0; chunk < nchunks; chunk++) { + chunk_start = clock(); + + for (size_t ci = 0; ci < chunk_size; ci++) { + err = bench->iter(ctx); + if (err) { + ERROR( + "Failed to run iteration `%zu` of benchmark `%s`!\n", + chunk * chunk_size + ci, bench->name); + return err; + } + } + + chunk_end = clock(); + + double chunk_time = (double)(chunk_end - chunk_start) / CLOCKS_PER_SEC; + double delta = chunk_time - mean; + mean += delta / (chunk + 1); + double delta2 = chunk_time - mean; + m2 += delta * delta2; + } + + double variance = m2 / chunk; + assert(chunk_mean); + assert(chunk_std); + *chunk_mean = mean; + *chunk_std = sqrt(variance); + + err = bench->deinit(ctx); + if (err) { + ERROR("Failed to deinitialize benchmark `%s`!\n", bench->name); + return err; + } + + return 0; +} + +static int +approximate_chunk_size( + struct bench_opts* opts, struct bench* bench, void* ctx, + size_t* chunk_size) { + double single_time, _single_std; + + if (opts->verbose) + printf("`%s`: Measuring single iteration... ", bench->name); + + int err = run_bench(bench, ctx, 1, 1, &single_time, &_single_std); + if (err) return err; + + if (opts->verbose) { + char* iter_time = precise_time_str(single_time); + assert(iter_time); + printf("<%s\n", iter_time); + free(iter_time); + } + + double approx_measured_iters = opts->measure_secs / single_time; + *chunk_size = pow(2.0, log2(approx_measured_iters) / 2); + return 0; +} + +static int +approximate_bench_chunks( + struct bench_opts* opts, struct bench* bench, void* ctx, size_t chunk_size, + size_t* nchunks) { + clock_t measure_clocks = opts->measure_secs * CLOCKS_PER_SEC; + size_t measured_iters; + + if (opts->verbose) + printf( + "`%s`: Counting num iterations in %ds... 
", bench->name,
+        opts->measure_secs);
+
+  int err =
+      measure_iters(bench, ctx, measure_clocks, chunk_size, &measured_iters);
+  if (err) return err;
+
+  if (opts->verbose) printf("%zu iterations.\n", measured_iters);
+
+  size_t bench_iters = (opts->bench_secs / opts->measure_secs) * measured_iters;
+  *nchunks = bench_iters / chunk_size;
+  return 0;
+}
+
+int
+run_benches(
+    struct bench_opts* opts, struct bench* benches, size_t nbenches,
+    void* ctx) {
+  regex_t filter;
+  if (opts->filter) {
+    int err = regcomp(&filter, opts->filter, REG_NOSUB);
+    if (err) {
+      ERROR("Bad regular expression: `%s`\n", opts->filter);
+      exit(err);
+    }
+  }
+
+  int measure_secs = opts->measure_secs;
+  int bench_secs = opts->bench_secs;
+  assert(bench_secs % measure_secs == 0);
+
+#define CHECK_ERR(err, ...) \
+  if (err) {                \
+    putchar('\n');          \
+    ERROR(__VA_ARGS__);     \
+    exit(err);              \
+  }
+
+  for (size_t i = 0; i < nbenches; i++) {
+    struct bench* bench = &benches[i];
+    if (opts->filter && regexec(&filter, bench->name, 0, NULL, 0) != 0)
+      continue;
+
+    char* iter_time;
+    int err;
+    size_t chunk_size;
+    size_t nchunks;
+
+    err = approximate_chunk_size(opts, bench, ctx, &chunk_size);
+    CHECK_ERR(err, "Failed to approximate chunk size for `%s`!\n", bench->name);
+
+    err = approximate_bench_chunks(opts, bench, ctx, chunk_size, &nchunks);
+    CHECK_ERR(
+        err, "Failed to approximate # benchmark iters for `%s`!\n",
+        bench->name);
+
+    double chunk_mean, chunk_std;
+
+    printf(
+        "`%s`: Benchmarking %zux%zu iterations (~%ds)...\n", bench->name,
+        nchunks, chunk_size, bench_secs);
+    err = run_bench(bench, ctx, nchunks, chunk_size, &chunk_mean, &chunk_std);
+    CHECK_ERR(err, "Failed to measure `%s`!\n", bench->name);
+
+    iter_time = precise_time_str(chunk_mean);
+    printf(" Chunk runtime: %s", iter_time);
+    free(iter_time);
+
+    iter_time = precise_time_str(chunk_std);
+    printf(" (+/- %s)\n", iter_time);
+    free(iter_time);
+
+    iter_time = precise_time_str(chunk_mean / chunk_size);
+    printf(" Time/Iter: %s\n", iter_time);
+    free(iter_time);
+  }
+
+  return 0;
+}
+
+static error_t
+bench_arg_parser(int key, char* val, struct argp_state* state) {
+  struct bench_opts* opts = (struct bench_opts*)state->input;
+  switch (key) {
+    case 'f':
+      opts->filter = strdup(val);
+      break;
+    case 'v':
+      opts->verbose = true;
+      break;
+    default:
+      return ARGP_ERR_UNKNOWN;
+  }
+  return 0;
+}
+
+struct bench_opts
+bench_argp(int argc, char** argv) {
+  struct bench_opts opts = {
+      .measure_secs = 2,
+      .bench_secs = 10,
+  };
+
+  struct argp_option argp_opts[] = {
+      {.name = "filter", .key = 'f', .arg = "REGEXP"},
+      {.name = "verbose", .key = 'v'},
+      {}};
+  struct argp argp = {
+      .options = argp_opts,
+      .parser = bench_arg_parser,
+  };
+  error_t err = argp_parse(&argp, argc, argv, 0, 0, &opts);
+  if (err) {
+    ERROR("Can't parse arguments!\n");
+  }
+
+  return opts;
+}
diff --git a/bench/bencher.h b/bench/bencher.h
new file mode 100644
index 0000000..b318569
--- /dev/null
+++ b/bench/bencher.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include <stdbool.h>
+#include <stddef.h>
+
+struct bench {
+  const char* name;
+  int (*init)(void* ctx);
+  int (*iter)(void* ctx);
+  int (*deinit)(void* ctx);
+};
+
+struct bench_opts {
+  bool verbose;
+  char* filter;
+  int measure_secs;
+  int bench_secs;
+};
+
+struct bench_opts
+bench_argp(int argc, char** argv);
+int
+run_benches(
+    struct bench_opts* opts, struct bench* benches, size_t nbenches, void* ctx);
diff --git a/bench/crypto.c b/bench/crypto.c
new file mode 100644
index 0000000..4118142
--- /dev/null
+++ b/bench/crypto.c
@@ -0,0 +1,101 @@
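+// Microbenchmarks of the runtime's SHA-256 at 1-byte, 32-byte, and
+// page-sized (4096-byte) inputs.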
+#define _GNU_SOURCE
+
+#include <../aes.h>
+#include <../sha256.h>
+#include <assert.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <time.h>
+
+#include "bencher.h"
+
+#define RAND_BOOK_LEN 1024
+
+typedef struct bench_ctx {
+  int* rand_book;
+  int rand_idx;
+} bench_ctx_t;
+
+void
+bench_ctx__init(bench_ctx_t* ctx) {
+  ctx->rand_book = (int*)malloc(sizeof(int) * RAND_BOOK_LEN);
+  for (int i = 0; i < RAND_BOOK_LEN; i++) {
+    ctx->rand_book[i] = rand();
+  }
+}
+
+int
+nop(void* _ctx) {
+  (void)_ctx;
+  return 0;
+}
+
+int
+bench_sha_1byte(void* _ctx) {
+  (void)_ctx;
+  SHA256_CTX sha;
+  volatile uint8_t byte = 0;
+  uint8_t hash[32];
+
+  sha256_init(&sha);
+  sha256_update(&sha, &byte, 1);
+  sha256_final(&sha, hash);
+
+  __asm__ __volatile__("" ::"r"(hash));
+  return 0;
+}
+
+int
+bench_sha_32byte(void* _ctx) {
+  SHA256_CTX sha;
+  uint8_t hash[32];
+  bench_ctx_t* ctx = (bench_ctx_t*)_ctx;
+
+  sha256_init(&sha);
+  sha256_update(&sha, (uint8_t*)ctx->rand_book, 32);
+  sha256_final(&sha, hash);
+
+  __asm__ __volatile__("" ::"r"(hash));
+  return 0;
+}
+
+int
+bench_sha_page(void* _ctx) {
+  SHA256_CTX sha;
+  uint8_t hash[32];
+  bench_ctx_t* ctx = (bench_ctx_t*)_ctx;
+  assert(sizeof(int) * RAND_BOOK_LEN >= 4096);
+
+  sha256_init(&sha);
+  sha256_update(&sha, (uint8_t*)ctx->rand_book, 4096);
+  sha256_final(&sha, hash);
+
+  __asm__ __volatile__("" ::"r"(hash));
+  return 0;
+}
+
+int
+main(int argc, char** argv) {
+  srand(time(NULL));
+  bench_ctx_t ctx = {};
+  bench_ctx__init(&ctx);
+
+  struct bench benches[] = {
+      {.name = "sha 1 byte",
+       .init = nop,
+       .deinit = nop,
+       .iter = bench_sha_1byte},
+      {.name = "sha 32 byte",
+       .init = nop,
+       .deinit = nop,
+       .iter = bench_sha_32byte},
+      {.name = "sha 4096 bytes",
+       .init = nop,
+       .deinit = nop,
+       .iter = bench_sha_page},
+  };
+  struct bench_opts opts = bench_argp(argc, argv);
+  run_benches(&opts, benches, sizeof(benches) / sizeof(struct bench), &ctx);
+}
diff --git a/bench/hmac.c b/bench/hmac.c
new file mode 100644
index 0000000..fdbf95e
--- /dev/null
+++ b/bench/hmac.c
@@ -0,0 +1,65 @@
+#define _GNU_SOURCE
+
+#include <../hmac.h>
+#include <assert.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <time.h>
+
+#include "bencher.h"
+
+#define RAND_BOOK_LEN (4 * 1024 * 1024)
+
+typedef struct bench_ctx {
+  int* rand_book;
+  uint8_t* key;
+  uint8_t* page;
+} bench_ctx_t;
+
+void
+bench_ctx__init(bench_ctx_t* ctx) {
+  ctx->rand_book = (int*)malloc(sizeof(int) * RAND_BOOK_LEN);
+  for (int i = 0; i < RAND_BOOK_LEN; i++) {
+    ctx->rand_book[i] = rand();
+  }
+}
+
+int
+hmac_page__init(void* _ctx) {
+  bench_ctx_t* ctx = (bench_ctx_t*)_ctx;
+  ctx->key = (uint8_t*)ctx->rand_book + rand() % (RAND_BOOK_LEN - 32);
+  ctx->page = (uint8_t*)ctx->rand_book + rand() % (RAND_BOOK_LEN - 4096);
+  return 0;
+}
+
+int
+hmac_page__destroy(void* _ctx) {
+  (void)_ctx;
+  return 0;
+}
+
+int
+hmac_page(void* _ctx) {
+  bench_ctx_t* ctx = (bench_ctx_t*)_ctx;
+  uint8_t hash[32];
+  int err = hmac(ctx->page, 4096, ctx->key, 32, hash);
+  __asm__ __volatile__("" ::"r"(hash));
+  return err;
+}
+
+int
+main(int argc, char** argv) {
+  srand(time(NULL));
+  bench_ctx_t ctx = {};
+  bench_ctx__init(&ctx);
+
+  struct bench benches[] = {
+      {.name = "hmac page",
+       .init = hmac_page__init,
+       .deinit = hmac_page__destroy,
+       .iter = hmac_page},
+  };
+  struct bench_opts opts = bench_argp(argc, argv);
+  run_benches(&opts, benches, sizeof(benches) / sizeof(struct bench), &ctx);
+}
diff --git a/bench/page_swap.c b/bench/page_swap.c
new file mode 100644
index 0000000..81aaf70
--- /dev/null
+++ b/bench/page_swap.c
@@ -0,0 +1,186 @@
+#define _GNU_SOURCE
+
+#include "../page_swap.h"
+
+#include <assert.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <time.h>
+
+#include "../freemem.h"
+#include "../paging.h"
+#include "../vm_defs.h"
+#include "bencher.h"
+
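+// Host-side stand-ins for the runtime facilities that page_swap.c expects
+// (SBI exit/randomness and the paging bounds/backing-region callbacks), so
+// the swap path can run as an ordinary user-space benchmark.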
+void
+sbi_exit_enclave(uintptr_t code) {
+  exit(code);
+}
+
+size_t
+rt_util_getrandom(void* vaddr, size_t buflen) {
+  uint8_t* charbuf = (uint8_t*)vaddr;
+  for (size_t i = 0; i < buflen; i++) charbuf[i] = rand();
+  return buflen;
+}
+
+uintptr_t
+sbi_random() {
+  uintptr_t out;
+  rt_util_getrandom(&out, sizeof out);
+  return out;
+}
+
+bool
+paging_epm_inbounds(uintptr_t addr) {
+  (void)addr;
+  return true;
+}
+
+static void* backing_region;
+#define BACKING_REGION_SIZE (2 * 1024 * 1024)
+
+bool
+paging_backpage_inbounds(uintptr_t addr) {
+  return (addr >= (uintptr_t)backing_region) &&
+         (addr < (uintptr_t)backing_region + BACKING_REGION_SIZE);
+}
+
+uintptr_t
+paging_backing_region() {
+  if (!backing_region) {
+    backing_region = mmap(
+        NULL, BACKING_REGION_SIZE, PROT_READ | PROT_WRITE,
+        MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+    assert(backing_region != MAP_FAILED);
+  }
+  return (uintptr_t)backing_region;
+}
+uintptr_t
+paging_backing_region_size() {
+  return BACKING_REGION_SIZE;
+}
+
+uintptr_t
+__va(uintptr_t pa) {
+  return pa;
+}
+
+uintptr_t
+paging_evict_and_free_one(uintptr_t swap_va) {
+  (void)swap_va;
+  assert(false);
+}
+
+bool
+spa_page_inbounds(uintptr_t page_addr) {
+  (void)page_addr;
+  return true;
+}
+
+#define VM_REGION_SIZE (2 * 1024 * 1024)
+#define VM_REGION_PAGES (VM_REGION_SIZE / RISCV_PAGE_SIZE)
+#define SWAPPABLE_PAGES 128
+#define RAND_BOOK_SIZE 4096
+
+typedef struct {
+  void* vm_region;
+  uint64_t swapped_out[SWAPPABLE_PAGES / 64];
+  uintptr_t swappable_pages_front[SWAPPABLE_PAGES];
+  uintptr_t swappable_pages_back[SWAPPABLE_PAGES];
+
+  int rand_book[RAND_BOOK_SIZE];
+  size_t rand_book_idx;
+} bench_ctx_t;
+
+static int
+bench_ctx__init(bench_ctx_t* ctx) {
+  ctx->vm_region = mmap(
+      NULL, VM_REGION_SIZE, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS,
+      -1, 0);
+  if (ctx->vm_region == MAP_FAILED) return -1;
+
+  for (size_t i = 0; i < RAND_BOOK_SIZE; i++) ctx->rand_book[i] = rand();
+  return 0;
+}
+
+static int
+bench_ctx__destroy(bench_ctx_t* ctx) {
+  return munmap(ctx->vm_region, VM_REGION_SIZE);
+}
+
+static int
+bench_ctx__rand(bench_ctx_t* ctx) {
+  int out = ctx->rand_book[ctx->rand_book_idx];
+  ctx->rand_book_idx = (ctx->rand_book_idx + 1) % RAND_BOOK_SIZE;
+  return out;
+}
+
+static int
+page_swap__init(void* _ctx) {
+  bench_ctx_t* ctx = (bench_ctx_t*)_ctx;
+
+  spa_init((uintptr_t)ctx->vm_region, VM_REGION_SIZE);
+  pswap_init();
+  memset(ctx->swapped_out, 0, sizeof ctx->swapped_out);
+
+  for (size_t i = 0; i < SWAPPABLE_PAGES; i++) {
+    ctx->swappable_pages_front[i] = spa_get();
+    ctx->swappable_pages_back[i] = paging_alloc_backing_page();
+    rt_util_getrandom((void*)ctx->swappable_pages_front[i], RISCV_PAGE_SIZE);
+  }
+
+  ctx->rand_book_idx = rand() % RAND_BOOK_SIZE;
+  return 0;
+}
+
+static int
+page_swap__destroy(void* _ctx) {
+  (void)_ctx;
+  return 0;
+}
+
+static int
+page_swap(void* _ctx) {
+  bench_ctx_t* ctx = (bench_ctx_t*)_ctx;
+
+  // Choose a page (excluding the first page, which is used by the SPA)
+  size_t which_page = bench_ctx__rand(ctx) % SWAPPABLE_PAGES;
+  uintptr_t front_page = ctx->swappable_pages_front[which_page];
+  uintptr_t back_page = ctx->swappable_pages_back[which_page];
+  uintptr_t swap_page;
+
+  if (ctx->swapped_out[which_page / 64] & (1ull << (which_page % 64))) {
+    // Have already swapped this page out
+    swap_page = back_page;
+  } else {
+    swap_page = 0;
+    ctx->swapped_out[which_page / 64] |= (1ull << 
(which_page % 64)); + } + + return page_swap_epm(back_page, front_page, swap_page); +} + +int +main(int argc, char** argv) { + srand(time(NULL)); + bench_ctx_t ctx = {}; + int err = bench_ctx__init(&ctx); + assert(!err); + + struct bench benches[] = { + {.name = "page swap", + .init = page_swap__init, + .deinit = page_swap__destroy, + .iter = page_swap}, + }; + struct bench_opts opts = bench_argp(argc, argv); + run_benches(&opts, benches, sizeof(benches) / sizeof(struct bench), &ctx); + + err = bench_ctx__destroy(&ctx); + assert(!err); +} \ No newline at end of file diff --git a/boot.c b/boot.c index ebf69da..a806c96 100644 --- a/boot.c +++ b/boot.c @@ -15,6 +15,17 @@ extern uintptr_t shared_buffer; extern uintptr_t shared_buffer_size; +uintptr_t load_pa_start; +uintptr_t +__va(uintptr_t pa) { + return (pa - load_pa_start) + EYRIE_LOAD_START; +} + +uintptr_t +__pa(uintptr_t va) { + return (va - EYRIE_LOAD_START) + load_pa_start; +} + /* initial memory layout */ uintptr_t utm_base; size_t utm_size; diff --git a/build.sh b/build.sh index c939bf4..3080787 100755 --- a/build.sh +++ b/build.sh @@ -13,8 +13,8 @@ PLUGINS[linux_syscall]="-DLINUX_SYSCALL_WRAPPING " PLUGINS[env_setup]="-DENV_SETUP " PLUGINS[strace_debug]="-DINTERNAL_STRACE " PLUGINS[paging]="-DUSE_PAGING -DUSE_FREEMEM " -PLUGINS[page_crypto]="-DPAGE_CRYPTO " -PLUGINS[page_hash]="-DPAGE_HASH " +PLUGINS[page_crypto]="-DUSE_PAGE_CRYPTO " +PLUGINS[page_hash]="-DUSE_PAGE_HASH " PLUGINS[debug]="-DDEBUG " #PLUGINS[dynamic_resizing]="-DDYN_ALLOCATION " diff --git a/freemem.c b/freemem.c index ef93db9..997129b 100644 --- a/freemem.c +++ b/freemem.c @@ -1,9 +1,10 @@ #ifdef USE_FREEMEM -#include "string.h" -#include "common.h" -#include "vm.h" #include "freemem.h" + +#include "common.h" #include "paging.h" +#include "string.h" +#include "vm_defs.h" /* This file implements a simple page allocator (SPA) * which stores the pages based on a linked list. @@ -16,8 +17,17 @@ * spa_free_pages will only hold the head and the tail pages so that * SPA can allocate/free a page in constant time. 
 */
+uintptr_t freemem_va_start;
+size_t freemem_size;
+
 static struct pg_list spa_free_pages;
 
+__attribute__((weak)) bool
+spa_page_inbounds(uintptr_t page_addr) {
+  return page_addr >= EYRIE_LOAD_START &&
+         page_addr < (freemem_va_start + freemem_size);
+}
+
 /* get a free page from the simple page allocator */
 uintptr_t
 __spa_get(bool zero)
@@ -48,7 +58,7 @@ __spa_get(bool zero)
   spa_free_pages.head = next;
   spa_free_pages.count--;
 
-  assert(free_page > EYRIE_LOAD_START && free_page < (freemem_va_start + freemem_size));
+  assert(spa_page_inbounds(free_page));
 
   if (zero)
     memset((void*)free_page, 0, RISCV_PAGE_SIZE);
@@ -66,7 +76,7 @@ spa_put(uintptr_t page_addr)
   uintptr_t prev;
 
   assert(IS_ALIGNED(page_addr, RISCV_PAGE_BITS));
-  assert(page_addr >= EYRIE_LOAD_START && page_addr < (freemem_va_start + freemem_size));
+  assert(spa_page_inbounds(page_addr));
 
   if (!LIST_EMPTY(spa_free_pages)) {
     prev = spa_free_pages.tail;
diff --git a/freemem.h b/freemem.h
index 9c88afa..31bc125 100644
--- a/freemem.h
+++ b/freemem.h
@@ -3,6 +3,9 @@
 #ifndef __FREEMEM_H__
 #define __FREEMEM_H__
 
+#include <stddef.h>
+#include <stdint.h>
+
 #define NEXT_PAGE(page) *((uintptr_t*)page)
 #define LIST_EMPTY(list) ((list).count == 0 || (list).head == 0)
 #define LIST_INIT(list) { (list).count = 0; (list).head = 0; (list).tail = 0; }
@@ -19,5 +22,8 @@ uintptr_t spa_get(void);
 uintptr_t spa_get_zero(void);
 void spa_put(uintptr_t page);
 unsigned int spa_available();
+
+extern uintptr_t freemem_va_start;
+extern size_t freemem_size;
 #endif
 #endif
diff --git a/hmac.c b/hmac.c
new file mode 100644
index 0000000..5c31f18
--- /dev/null
+++ b/hmac.c
@@ -0,0 +1,195 @@
+// Adapted from Yubico's HMAC implementation. Licensed under BSD 2-clause
+// license. Copyright (c) 2006-2013 Yubico AB
+
+/**************************** hmac.c ****************************/
+/******************** See RFC 4634 for details ******************/
+/*
+ * Description:
+ *   This file implements the HMAC algorithm (Keyed-Hashing for
+ *   Message Authentication, RFC2104), expressed in terms of the
+ *   various SHA algorithms.
+ */
+
+#ifdef USE_PAGE_HASH
+
+#include "hmac.h"
+
+/*
+ * hmac
+ *
+ * Description:
+ *   This function will compute an HMAC message digest.
+ *
+ * Parameters:
+ *   key: [in]
+ *     The secret shared key.
+ *   key_len: [in]
+ *     The length of the secret shared key.
+ *   message_array: [in]
+ *     An array of characters representing the message.
+ *   length: [in]
+ *     The length of the message in message_array
+ *   digest: [out]
+ *     Where the digest is returned.
+ *     NOTE: The length of the digest is determined by
+ *     the value of whichSha.
+ *
+ * Returns:
+ *   sha Error Code.
+ *
+ */
+int
+hmac(
+    const unsigned char* text, int text_len, const unsigned char* key,
+    int key_len, uint8_t digest[_USHAHashSize]) {
+  HMACContext ctx;
+  return hmacReset(&ctx, key, key_len) || hmacInput(&ctx, text, text_len) ||
+         hmacResult(&ctx, digest);
+}
+
+/*
+ * hmacReset
+ *
+ * Description:
+ *   This function will initialize the hmacContext in preparation
+ *   for computing a new HMAC message digest.
+ *
+ * Parameters:
+ *   context: [in/out]
+ *     The context to reset.
+ *   key: [in]
+ *     The secret shared key.
+ *   key_len: [in]
+ *     The length of the secret shared key.
+ *
+ * Returns:
+ *   sha Error Code.
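+ *   NOTE: if key_len exceeds the SHA-256 block size (64 bytes), the key is
+ *   first replaced by its own hash, key = SHA-256(key), per RFC 2104.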
+ * + */ +int +hmacReset(HMACContext* ctx, const unsigned char* key, int key_len) { + int i, blocksize, hashsize; + + /* inner padding - key XORd with ipad */ + unsigned char k_ipad[USHA_Message_Block_Size]; + + /* temporary buffer when keylen > blocksize */ + unsigned char tempkey[_USHAHashSize]; + + if (!ctx) return shaNull; + + blocksize = ctx->blockSize = USHABlockSize(); + hashsize = ctx->hashSize = USHAHashSize(); + + /* + * If key is longer than the hash blocksize, + * reset it to key = HASH(key). + */ + if (key_len > blocksize) { + USHAContext tctx; + int err = USHAReset(&tctx) || USHAInput(&tctx, key, key_len) || + USHAResult(&tctx, tempkey); + if (err != shaSuccess) return err; + + key = tempkey; + key_len = hashsize; + } + + /* + * The HMAC transform looks like: + * + * SHA(K XOR opad, SHA(K XOR ipad, text)) + * + * where K is an n byte key. + * ipad is the byte 0x36 repeated blocksize times + * opad is the byte 0x5c repeated blocksize times + * and text is the data being protected. + */ + + /* store key into the pads, XOR'd with ipad and opad values */ + for (i = 0; i < key_len; i++) { + k_ipad[i] = key[i] ^ 0x36; + ctx->k_opad[i] = key[i] ^ 0x5c; + } + /* remaining pad bytes are '\0' XOR'd with ipad and opad values */ + for (; i < blocksize; i++) { + k_ipad[i] = 0x36; + ctx->k_opad[i] = 0x5c; + } + + /* perform inner hash */ + /* init context for 1st pass */ + return USHAReset(&ctx->shaContext) || + /* and start with inner pad */ + USHAInput(&ctx->shaContext, k_ipad, blocksize); +} + +/* + * hmacInput + * + * Description: + * This function accepts an array of octets as the next portion + * of the message. + * + * Parameters: + * context: [in/out] + * The HMAC context to update + * message_array: [in] + * An array of characters representing the next portion of + * the message. + * length: [in] + * The length of the message in message_array + * + * Returns: + * sha Error Code. + * + */ +int +hmacInput(HMACContext* ctx, const unsigned char* text, int text_len) { + if (!ctx) return shaNull; + /* then text of datagram */ + return USHAInput(&ctx->shaContext, text, text_len); +} + +/* + * HMACResult + * + * Description: + * This function will return the N-byte message digest into the + * Message_Digest array provided by the caller. + * NOTE: The first octet of hash is stored in the 0th element, + * the last octet of hash in the Nth element. + * + * Parameters: + * context: [in/out] + * The context to use to calculate the HMAC hash. + * digest: [out] + * Where the digest is returned. + * NOTE 2: The length of the hash is determined by the value of + * whichSha that was passed to hmacReset(). + * + * Returns: + * sha Error Code. + * + */ +int +hmacResult(HMACContext* ctx, uint8_t* digest) { + if (!ctx) return shaNull; + + /* finish up 1st pass */ + /* (Use digest here as a temporary buffer.) */ + return USHAResult(&ctx->shaContext, digest) || + /* perform outer SHA */ + /* init context for 2nd pass */ + USHAReset(&ctx->shaContext) || + /* start with outer pad */ + USHAInput(&ctx->shaContext, ctx->k_opad, ctx->blockSize) || + /* then results of 1st hash */ + USHAInput(&ctx->shaContext, digest, ctx->hashSize) || + /* finish up 2nd pass */ + USHAResult(&ctx->shaContext, digest); +} + +#endif // USE_PAGE_HASH diff --git a/hmac.h b/hmac.h new file mode 100644 index 0000000..97e9721 --- /dev/null +++ b/hmac.h @@ -0,0 +1,77 @@ +// Adapted from Yubico's HMAC implementation. Licensed under BSD 2-clause +// license. 
Copyright (c) 2006-2013 Yubico AB
+
+#ifndef _SHA_H_
+#define _SHA_H_
+
+#include <stdint.h>
+
+#ifndef _SHA_enum_
+#define _SHA_enum_
+/*
+ * All SHA functions return one of these values.
+ */
+enum {
+  shaSuccess = 0,
+  shaNull,         /* Null pointer parameter */
+  shaInputTooLong, /* input data too long */
+  shaStateError,   /* called Input after FinalBits or Result */
+  shaBadParam      /* passed a bad parameter */
+};
+#endif /* _SHA_enum_ */
+
+#include "sha256.h"
+#define _USHAHashSize 32
+#define USHA_Message_Block_Size 64
+#define USHAReset(...) (sha256_init(__VA_ARGS__), 0)
+#define USHAInput(...) (sha256_update(__VA_ARGS__), 0)
+#define USHAResult(...) (sha256_final(__VA_ARGS__), 0)
+#define USHABlockSize() USHA_Message_Block_Size
+#define USHAHashSize() _USHAHashSize
+#define USHAHashSizeBits() (USHAHashSize() * 8)
+
+/*
+ * This structure will hold context information for the HMAC
+ * keyed hashing operation.
+ */
+typedef struct HMACContext {
+  int hashSize;          /* hash size of SHA being used */
+  int blockSize;         /* block size of SHA being used */
+  SHA256_CTX shaContext; /* SHA context */
+  unsigned char k_opad[USHA_Message_Block_Size];
+  /* outer padding - key XORd with opad */
+} HMACContext;
+
+typedef SHA256_CTX USHAContext;
+
+/*
+ * Function Prototypes
+ */
+
+/*
+ * HMAC Keyed-Hashing for Message Authentication, RFC2104,
+ * for all SHAs.
+ * This interface allows a fixed-length text input to be used.
+ */
+extern int
+hmac(
+    const unsigned char* text,      /* pointer to data stream */
+    int text_len,                   /* length of data stream */
+    const unsigned char* key,       /* pointer to authentication key */
+    int key_len,                    /* length of authentication key */
+    uint8_t digest[_USHAHashSize]); /* caller digest to fill in */
+
+/*
+ * HMAC Keyed-Hashing for Message Authentication, RFC2104,
+ * for all SHAs.
+ * This interface allows any length of text input to be used.
+ */
+extern int
+hmacReset(HMACContext* ctx, const unsigned char* key, int key_len);
+extern int
+hmacInput(HMACContext* ctx, const unsigned char* text, int text_len);
+
+extern int
+hmacResult(HMACContext* ctx, uint8_t digest[_USHAHashSize]);
+
+#endif /* _SHA_H_ */
diff --git a/merkle.c b/merkle.c
deleted file mode 100644
index 1796044..0000000
--- a/merkle.c
+++ /dev/null
@@ -1,342 +0,0 @@
-#if defined(USE_PAGE_HASH)
-
-#include "merkle.h"
-
-#include
-#include
-#include
-
-#include "paging.h"
-#include "sha256.h"
-#include "vm_defs.h"
-
-#ifndef MERK_SILENT
-#define MERK_LOG printf
-#else
-#define MERK_LOG(...)
-#endif - -_Static_assert(sizeof(merkle_node_t) == 64, "merkle_node_t is not 64 bytes!"); - -#define MERK_NODES_PER_PAGE (RISCV_PAGE_SIZE / sizeof(merkle_node_t)) - -typedef struct merkle_page_freelist { - uint64_t free[MERK_NODES_PER_PAGE / 64]; - uint16_t free_count; - bool in_freelist; - struct merkle_page_freelist* next; -} merkle_page_freelist_t; - -_Static_assert( - sizeof(merkle_page_freelist_t) <= sizeof(merkle_node_t), - "merkle_page_freelist_t does not fit in one merkle_node_t!"); - -static merkle_page_freelist_t* -merk_alloc_page(void) { - void* page = (void*)paging_alloc_backing_page(); - merkle_page_freelist_t* free_list = (merkle_page_freelist_t*)page; - memset(free_list, 0, sizeof(*free_list)); - - for (size_t i = 0; i < MERK_NODES_PER_PAGE; i += 64) { - size_t this_page_nodes = MERK_NODES_PER_PAGE - i; - free_list->free[i / 64] = - (this_page_nodes < 64) * (1ull << this_page_nodes) - 1; - } - free_list->free[0] &= ~(uint64_t)1; - free_list->free_count = MERK_NODES_PER_PAGE - 1; - - return free_list; -} - -static merkle_page_freelist_t* merk_free_list = NULL; - -static merkle_node_t* -merk_reserve_node_in_page(merkle_page_freelist_t* free_list) { - if (!free_list->free_count) return NULL; - - for (size_t i = 0; i < MERK_NODES_PER_PAGE / 64; i++) { - if (free_list->free[i]) { - size_t free_idx = __builtin_ctzll(free_list->free[i]); - free_list->free[i] &= ~(1ull << free_idx); - free_list->free_count--; - - merkle_node_t* page = (merkle_node_t*)free_list; - assert(free_idx != 0); - - return page + free_idx; - } - } - return NULL; -} - -static merkle_node_t* -merk_alloc_node(void) { - while (merk_free_list && merk_free_list->free_count == 0) { - // Clear out the unfree lists - merk_free_list->in_freelist = false; - merk_free_list = merk_free_list->next; - } - - if (!merk_free_list) { - merk_free_list = merk_alloc_page(); - merk_free_list->in_freelist = true; - } - - merkle_node_t* out = merk_reserve_node_in_page(merk_free_list); - return out; -} - -static void -merk_free_node(merkle_node_t* node) { - uintptr_t page = (uintptr_t)node & ~(RISCV_PAGE_SIZE - 1); - merkle_page_freelist_t* free_list = (merkle_page_freelist_t*)page; - size_t idx = node - (merkle_node_t*)page; - - assert(idx < MERK_NODES_PER_PAGE); - assert((free_list->free[idx / 64] & (1ull << (idx % 64))) == 0); - - free_list->free[idx / 64] |= (1ull << (idx % 64)); - free_list->free_count++; - - if (!free_list->in_freelist) { - free_list->next = merk_free_list; - merk_free_list = free_list; - free_list->in_freelist = true; - } -} - -static bool -merk_verify_single_node( - const merkle_node_t* node, const merkle_node_t* left, - const merkle_node_t* right) { - SHA256_CTX hasher; - uint8_t calculated_hash[32]; - - sha256_init(&hasher); - - if (left) { - sha256_update(&hasher, (uint8_t*)&left->ptr, sizeof right->ptr); - sha256_update(&hasher, left->hash, 32); - } - if (right) { - sha256_update(&hasher, (uint8_t*)&right->ptr, sizeof right->ptr); - sha256_update(&hasher, right->hash, 32); - } - - if (!left && !right) { - return true; - } - - sha256_final(&hasher, calculated_hash); - return memcmp(calculated_hash, node->hash, 32) == 0; -} - -static void -merk_hash_single_node( - merkle_node_t* node, const merkle_node_t* left, - const merkle_node_t* right) { - SHA256_CTX hasher; - sha256_init(&hasher); - if (left) { - sha256_update(&hasher, (uint8_t*)&left->ptr, sizeof left->ptr); - sha256_update(&hasher, left->hash, 32); - } - if (right) { - sha256_update(&hasher, (uint8_t*)&right->ptr, sizeof right->ptr); - 
sha256_update(&hasher, right->hash, 32); - } - sha256_final(&hasher, node->hash); -} - -bool -merk_verify( - volatile merkle_node_t* root, uintptr_t key, const uint8_t hash[32]) { - merkle_node_t node = *root; - if (!root->right) return false; - - merkle_node_t left; - merkle_node_t right = *root->right; - - // Verify root node - if (!merk_verify_single_node(&node, NULL, &right)) { - MERK_LOG("Error verifying root!\n"); - return false; - } - - node = right; - - for (int i = 0;; i++) { - // node is a leaf, so return its hash check - if (!node.left && !node.right) { - return memcmp(hash, node.hash, 32) == 0; - } - - // Load in the next layer. This is to prevent race conditions - if (node.left) left = *(volatile merkle_node_t*)node.left; - if (node.right) right = *(volatile merkle_node_t*)node.right; - - bool node_ok = merk_verify_single_node( - &node, node.left ? &left : NULL, node.right ? &right : NULL); - if (!node_ok) { - MERK_LOG("Error at node with ptr %zx in layer %d\n", node.ptr, i); - return false; - } - - // BST traversal - if (key < node.ptr) { - node = left; - } else { - node = right; - } - } -} - -// Insert a node at the leaf position. May insert a new intermediate node or -// overwrite an existing one. Returns the node modified. -static merkle_node_t* -merk_splice_node(merkle_node_t* leaf, merkle_node_t* node) { - if (node->ptr == leaf->ptr) { - // We've specified a key that already exists, so overwrite the old node. - merk_free_node(leaf); - return node; - } - - merkle_node_t* new_parent = merk_alloc_node(); - - if (node->ptr < leaf->ptr) { - *new_parent = (merkle_node_t){ - .ptr = leaf->ptr, - .left = node, - .right = leaf, - }; - merk_hash_single_node(new_parent, node, leaf); - } else { - *new_parent = (merkle_node_t){ - .ptr = node->ptr, - .left = leaf, - .right = node, - }; - merk_hash_single_node(new_parent, leaf, node); - } - - return new_parent; -} - -#define MERK_MAX_DEPTH 32 -static merkle_node_t* intermediate_nodes[MERK_MAX_DEPTH] = {}; - -int -merk_insert(merkle_node_t* root, uintptr_t key, const uint8_t hash[32]) { - merkle_node_t new_node_data = { - .ptr = key, - }; - memcpy(new_node_data.hash, hash, 32); - - merkle_node_t* new_node = merk_alloc_node(); - *(volatile merkle_node_t*)new_node = new_node_data; - - // The root never contains data, only a single pointer to the start - // of data on its right side. - // This is to better ensure a total split between the root and other - // nodes, as the root is merely a "guardian" which must reside in secure - // memory while others don't need to. - if (!root->right) { - merk_hash_single_node(root, NULL, &new_node_data); - root->right = new_node; - return 0; - } - - intermediate_nodes[0] = root; - int i; - - for (i = 0; i < MERK_MAX_DEPTH - 1; i++) { - // Walk down the BST to find an appropriate location to store our new node. - - merkle_node_t* parent = intermediate_nodes[i]; - - // Traverse the BST - - bool traverse_left = key < parent->ptr; - bool child_idx = traverse_left ^ 1; - intermediate_nodes[i + 1] = parent->children[child_idx]; - if (!intermediate_nodes[i + 1]) break; - } - - if (i == MERK_MAX_DEPTH) { - printf( - "Inserted to merkle tree with problematic key insertion order! " - "This has led to an unbalanced tree exceeding the depth capacity of " - "%d. " - "Aborting!", - MERK_MAX_DEPTH); - assert(false); - } - - merkle_node_t curr_node = *intermediate_nodes[i]; - - for (; i > 0; i--) { - // Here we walk back up the tree to percolate up our new hashes. 
- // We keep the previous iteration's merkle node, as well as the current - // iteration's merkle node, in secure memory to avoid any race conditions - // after writing to DRAM in the last step. - // - // We otherwise aren't concerned that an attacker will modify any data - // in the tree during this stage, as doing so will compromise the integrity - // of the tree such that the user will be alerted upon attempting any - // accesses to the compromised location. - // - // We also mark accesses to parent_ptr as volatile to ensure they get - // written and read with the correct access pattern. - - merkle_node_t* parent_ptr = intermediate_nodes[i - 1]; - merkle_node_t* node_ptr = intermediate_nodes[i]; - merkle_node_t parent = *(volatile merkle_node_t*)parent_ptr; - merkle_node_t sibling; - bool has_sibling = parent.left && parent.right; - int node_idx = parent.right == node_ptr; - - if (has_sibling) - sibling = *(volatile merkle_node_t*)parent.children[!node_idx]; - - // Check to see that the sibling we pull is valid. - // We don't care about node_ptr races here, because if it's been tampered - // with, verifying against parent will fail (eventually). And we already - // stored the node data we'll be using later on in curr_node, anyway. - if (has_sibling) { - const merkle_node_t* copied_children[2]; - copied_children[node_idx] = node_ptr; - copied_children[!node_idx] = &sibling; - if (!merk_verify_single_node( - &parent, copied_children[0], copied_children[1])) - return -1; - } - - // At the leaf, insert the new node. merk_splice_node will handle updating - // the hashes, so if we insert a new intermediate node we can treat that as - // the new "leaf" / bottom-layer node. - if (!curr_node.left && !curr_node.right) { - node_ptr = intermediate_nodes[i] = merk_splice_node(node_ptr, new_node); - curr_node = *node_ptr; - - parent.children[node_idx] = node_ptr; - } - - // Hash our data from the saved curr_node and sibling. - assert(node_ptr == parent.children[node_idx]); - const merkle_node_t* copied_children[2]; - copied_children[node_idx] = &curr_node; - copied_children[!node_idx] = has_sibling ? &sibling : NULL; - merk_hash_single_node(&parent, copied_children[0], copied_children[1]); - - // Writeback the previous changed node now - *(volatile merkle_node_t*)node_ptr = curr_node; - curr_node = parent; - } - - // Writeback the final changed node - *(volatile merkle_node_t*)root = curr_node; - - return 0; -} - -#endif diff --git a/merkle.h b/merkle.h deleted file mode 100644 index d35ea2f..0000000 --- a/merkle.h +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -#if defined(USE_FREEMEM) && defined(USE_PAGING) - -#include -#include -#include - -typedef union merkle_node { - struct { - uintptr_t ptr; - uint8_t hash[32]; - union { - struct { - union merkle_node *left, *right; - }; - union merkle_node* children[2]; - }; - }; - struct { - uint64_t raw_words[8]; - }; -} merkle_node_t; - -int -merk_insert(merkle_node_t* root, uintptr_t key, const uint8_t hash[32]); -bool -merk_verify( - volatile merkle_node_t* root, uintptr_t key, const uint8_t hash_out[32]); - -#endif diff --git a/page_swap.c b/page_swap.c index 2cb9aa6..ae023e1 100644 --- a/page_swap.c +++ b/page_swap.c @@ -2,27 +2,58 @@ #if defined(USE_FREEMEM) && defined(USE_PAGING) +#include #include #include #include "aes.h" -#include "merkle.h" +#include "hmac.h" #include "paging.h" #include "sbi.h" #include "sha256.h" #include "vm_defs.h" +typedef struct { + uint8_t val[32]; +} hash_t; + +// Specify the amount of pages of swap counters. 
This works out to 24*(4096/8)
+// or 12288 swappable pages; 50MB of memory. These counters are stored in
+// on-chip memory.
 #define NUM_CTR_INDIRECTS 24
-static uintptr_t ctr_indirect_ptrs[NUM_CTR_INDIRECTS];
+#define NUM_CTR_INTERNAL (RISCV_PAGE_SIZE / 8)
+typedef struct {
+  uint64_t buf[NUM_CTR_INTERNAL];
+} ctr_page_t;
+static ctr_page_t* ctr_indirect_ptrs[NUM_CTR_INDIRECTS];
+
+// Specify the number of pages of HMACs. This should cover the same number of
+// swappable pages as the ctrs.
+#define NUM_HMAC_INDIRECTS 96
+#define NUM_HMAC_INTERNAL (RISCV_PAGE_SIZE / sizeof(hash_t))
+typedef struct {
+  hash_t buf[NUM_HMAC_INTERNAL];
+} hmac_page_t;
+static hmac_page_t* hmac_indirect_ptrs[NUM_HMAC_INDIRECTS];
+
+static_assert(
+    NUM_CTR_INDIRECTS * NUM_CTR_INTERNAL ==
+        NUM_HMAC_INDIRECTS * NUM_HMAC_INTERNAL,
+    "Number of page swap counters does not match number of HMACs!");
+
+// Set-once global keys
+static uint8_t aes_key[32];
+static uint8_t hmac_key[32];
 
 static uintptr_t paging_next_backing_page_offset;
-static uintptr_t paging_inc_backing_page_offset_by;
-
+/// Provide a mechanism for allocating off-chip pages backing swapped ones.
+/// There is no way to deallocate after calling this function, except to
+/// reinitialize the page swap subsystem. This gives us a simple way to linearly
+/// allocate page counters and HMACs as well.
 uintptr_t
 paging_alloc_backing_page() {
-  uintptr_t offs_update =
-      (paging_next_backing_page_offset + paging_inc_backing_page_offset_by) %
-      paging_backing_region_size();
+  uintptr_t offs_update = (paging_next_backing_page_offset + RISCV_PAGE_SIZE) %
+                          paging_backing_region_size();
 
   /* no backing page available */
   if (offs_update == 0) {
@@ -45,94 +76,74 @@ paging_remaining_pages() {
       RISCV_PAGE_SIZE;
 }
 
-static uintptr_t
-gcd(uintptr_t a, uintptr_t b) {
-  while (b) {
-    uintptr_t tmp = b;
-    b = a % b;
-    a = tmp;
-  }
-  return a;
-}
-
-static uintptr_t
-find_coprime_of(uintptr_t n) {
-  uintptr_t res;
-  do {
-    res = n / 2 + sbi_random() % (n / 2);
-  } while (gcd(res, n) != 1);
-  return res;
-}
-
 void
 pswap_init(void) {
-  uintptr_t backing_pages = paging_backing_region_size() / RISCV_PAGE_SIZE;
-  uintptr_t inc = find_coprime_of(backing_pages);
+  paging_next_backing_page_offset = 0;
 
-  paging_inc_backing_page_offset_by = inc * RISCV_PAGE_SIZE;
-  warn("num_pages = %zx, pagesize_inc = %zx", backing_pages, inc);
+  rt_util_getrandom(aes_key, 32);
+  rt_util_getrandom(hmac_key, 32);
 
-  paging_next_backing_page_offset = 0;
+  memset(ctr_indirect_ptrs, 0, sizeof ctr_indirect_ptrs);
+  memset(hmac_indirect_ptrs, 0, sizeof hmac_indirect_ptrs);
 }
 
+/// Find the particular pageout counter for a given page. These counters are
+/// incremented on every page swap and are concatenated to the HMAC input to
+/// prevent replay attacks. To avoid needing to verify integrity of counters,
+/// they are stored in on-chip memory.
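+/// For a backing page at byte offset `off` in the backing region, the counter
+/// lives at ctr_indirect_ptrs[(off >> RISCV_PAGE_BITS) / NUM_CTR_INTERNAL]
+///     ->buf[(off >> RISCV_PAGE_BITS) % NUM_CTR_INTERNAL].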
 static uint64_t*
 pswap_pageout_ctr(uintptr_t page) {
   assert(paging_backpage_inbounds(page));
   size_t idx = (page - paging_backing_region()) >> RISCV_PAGE_BITS;
-  size_t indirect_idx = idx / (RISCV_PAGE_SIZE / 8);
-  size_t interior_idx = idx % (RISCV_PAGE_SIZE / 8);
+  size_t indirect_idx = idx / NUM_CTR_INTERNAL;
+  size_t interior_idx = idx % NUM_CTR_INTERNAL;
   assert(indirect_idx < NUM_CTR_INDIRECTS);
 
   if (!ctr_indirect_ptrs[indirect_idx]) {
-    ctr_indirect_ptrs[indirect_idx] = paging_alloc_backing_page();
+    ctr_indirect_ptrs[indirect_idx] = (ctr_page_t*)spa_get();
     // Fill ptr pages with random values so our counters start unpredictable
-    rt_util_getrandom((void*)ctr_indirect_ptrs[indirect_idx], RISCV_PAGE_SIZE);
+    rt_util_getrandom(ctr_indirect_ptrs[indirect_idx], RISCV_PAGE_SIZE);
   }
 
-  return (uint64_t*)(ctr_indirect_ptrs[indirect_idx]) + interior_idx;
+  return ctr_indirect_ptrs[indirect_idx]->buf + interior_idx;
 }
 
-#ifdef USE_PAGE_CRYPTO
-static volatile atomic_bool pswap_boot_key_reserved = false;
-static volatile atomic_bool pswap_boot_key_set = false;
-static uint8_t pswap_boot_key[32];
-
-static void
-pswap_establish_boot_key(void) {
-  uint8_t boot_key_tmp[32];
-
-  if (atomic_load(&pswap_boot_key_set)) {
-    // Key already set
-    return;
-  }
-
-  rt_util_getrandom(boot_key_tmp, 32);
+#ifdef USE_PAGE_HASH
+/// Find the particular HMAC for a given page. These HMACs are stored in
+/// off-chip memory, but that's acceptable because an attacker can neither forge
+/// it, nor replay it. Replay security comes from a counter concatenated to the
+/// HMAC input.
+static hash_t*
+pswap_pageout_hmac(uintptr_t page) {
+  assert(paging_backpage_inbounds(page));
+  size_t idx = (page - paging_backing_region()) >> RISCV_PAGE_BITS;
+  size_t indirect_idx = idx / NUM_HMAC_INTERNAL;
+  size_t interior_idx = idx % NUM_HMAC_INTERNAL;
+  assert(indirect_idx < NUM_HMAC_INDIRECTS);
 
-  if (atomic_flag_test_and_set(&pswap_boot_key_reserved)) {
-    // Lost the race; key already being set. Spin until finished.
-    while (!atomic_load(&pswap_boot_key_set))
-      ;
-    return;
+  if (!hmac_indirect_ptrs[indirect_idx]) {
+    hmac_indirect_ptrs[indirect_idx] =
+        (hmac_page_t*)paging_alloc_backing_page();
+    // Zero-initialize the HMAC slots; each is overwritten on first swap-out
+    memset(hmac_indirect_ptrs[indirect_idx], 0, RISCV_PAGE_SIZE);
   }
 
-  memcpy(pswap_boot_key, boot_key_tmp, 32);
-  atomic_store(&pswap_boot_key_set, true);
+  return hmac_indirect_ptrs[indirect_idx]->buf + interior_idx;
 }
-#endif  // USE_PAGE_CRYPTO
-
-#ifdef USE_PAGE_HASH
-static merkle_node_t paging_merk_root = {};
-#endif
+#endif  // USE_PAGE_HASH
 
+/// Copy page to destination, encrypting if USE_PAGE_CRYPTO is defined.
 static void
 pswap_encrypt(const void* addr, void* dst, uint64_t pageout_ctr) {
   size_t len = RISCV_PAGE_SIZE;
 
 #ifdef USE_PAGE_CRYPTO
-  pswap_establish_boot_key();
   uint8_t iv[32] = {0};
   WORD key_sched[80];
-  aes_key_setup(pswap_boot_key, key_sched, 256);
+  aes_key_setup(aes_key, key_sched, 256);
 
   memcpy(iv + 8, &pageout_ctr, 8);
 
@@ -142,15 +150,17 @@ pswap_encrypt(const void* addr, void* dst, uint64_t pageout_ctr) {
 #endif
 }
 
+/// Copy page to destination, decrypting if USE_PAGE_CRYPTO is defined.
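+/// The caller must pass the same pageout counter that was used when the page
+/// was encrypted, since the counter is embedded in the IV.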
 static void
 pswap_decrypt(const void* addr, void* dst, uint64_t pageout_ctr) {
   size_t len = RISCV_PAGE_SIZE;
 
 #ifdef USE_PAGE_CRYPTO
-  pswap_establish_boot_key();
   uint8_t iv[32] = {0};
   WORD key_sched[80];
-  aes_key_setup(pswap_boot_key, key_sched, 256);
+  aes_key_setup(aes_key, key_sched, 256);
 
   memcpy(iv + 8, &pageout_ctr, 8);
 
@@ -160,23 +168,24 @@ pswap_decrypt(const void* addr, void* dst, uint64_t pageout_ctr) {
 #endif
 }
 
-static void
-pswap_hash(uint8_t* hash, void* page_addr, uint64_t pageout_ctr) {
 #ifdef USE_PAGE_HASH
-  SHA256_CTX hasher;
+static void
+pswap_hmac(uint8_t* hash, void* page_addr, uint64_t pageout_ctr) {
+  HMACContext ctx;
 
-  sha256_init(&hasher);
-  sha256_update(&hasher, page_addr, RISCV_PAGE_SIZE);
-  sha256_update(&hasher, (uint8_t*)&pageout_ctr, sizeof(pageout_ctr));
-  sha256_final(&hasher, hash);
-#endif
+  hmacReset(&ctx, hmac_key, sizeof hmac_key);
+  hmacInput(&ctx, page_addr, RISCV_PAGE_SIZE);
+  hmacInput(&ctx, (uint8_t*)&pageout_ctr, sizeof(pageout_ctr));
+  hmacResult(&ctx, hash);
 }
+#endif
 
-/* evict a page from EPM and store it to the backing storage
+/* Evict a page from EPM and store it to the backing storage.
  * back_page (PA1) <-- epm_page (PA2) <-- swap_page (PA1)
- * if swap_page is 0, no need to write epm_page
+ * If swap_page is 0, no need to write epm_page. Otherwise, swap_page must
+ * point to the same address as back_page.
  */
-void
+int
 page_swap_epm(uintptr_t back_page, uintptr_t epm_page, uintptr_t swap_page) {
   assert(paging_epm_inbounds(epm_page));
   assert(paging_backpage_inbounds(back_page));
@@ -191,28 +200,33 @@ page_swap_epm(uintptr_t back_page, uintptr_t epm_page, uintptr_t swap_page) {
   uint64_t old_pageout_ctr = *pageout_ctr;
   uint64_t new_pageout_ctr = old_pageout_ctr + 1;
 
-  uint8_t new_hash[32];
-  pswap_hash(new_hash, (void*)epm_page, new_pageout_ctr);
+#ifdef USE_PAGE_HASH
+  hash_t* hmac = pswap_pageout_hmac(back_page);
+  uint8_t new_hmac[32];
+  pswap_hmac(new_hmac, (void*)epm_page, new_pageout_ctr);
+#endif
+
   pswap_encrypt((void*)epm_page, (void*)back_page, new_pageout_ctr);
 
   if (swap_page) {
-    uint8_t old_hash[32];
     pswap_decrypt((void*)buffer, (void*)epm_page, old_pageout_ctr);
-    pswap_hash(old_hash, (void*)epm_page, old_pageout_ctr);
 
 #ifdef USE_PAGE_HASH
-    bool ok = merk_verify(&paging_merk_root, back_page, old_hash);
-    assert(ok);
+    uint8_t old_hmac[32];
+    pswap_hmac(old_hmac, (void*)epm_page, old_pageout_ctr);
+    if (memcmp(hmac, old_hmac, sizeof(hash_t))) {
+      return -1;
+    }
 #endif
   }
 
 #ifdef USE_PAGE_HASH
-  merk_insert(&paging_merk_root, back_page, new_hash);
+  *hmac = *(hash_t*)new_hmac;
 #endif
 
   *pageout_ctr = new_pageout_ctr;
-  return;
+  return 0;
 }
 
 #endif
diff --git a/page_swap.h b/page_swap.h
index b5476e6..8bcff2f 100644
--- a/page_swap.h
+++ b/page_swap.h
@@ -5,5 +5,5 @@
 void
 pswap_init(void);
 
-void
+int
 page_swap_epm(uintptr_t back_page, uintptr_t epm_page, uintptr_t swap_page);
diff --git a/paging.c b/paging.c
index 0645bef..155b951 100644
--- a/paging.c
+++ b/paging.c
@@ -206,7 +206,8 @@ uintptr_t paging_evict_and_free_one(uintptr_t swap_va)
   assert(target_pte && (*target_pte & PTE_U));
   src_pa = pte_ppn(*target_pte) << RISCV_PAGE_BITS;
 
-  page_swap_epm(dest_va, __va(src_pa), swap_va);
+  int err = page_swap_epm(dest_va, __va(src_pa), swap_va);
+  assert(!err);
 
   /* invalidate target PTE */
   *target_pte = pte_create_invalid(ppn(__paging_pa(dest_va)),
diff --git a/slice.h b/slice.h
new file mode 100644
index 0000000..e4723df
--- /dev/null
+++ b/slice.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include <string.h>
+
+typedef struct {
+  void* buf;
+  size_t len;
+} slice_t;
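+// E.g., slice_copy(dst, SLICE(arr, 2, 8)) copies elements [2, 8) of arr into
+// dst; the byte length is scaled by sizeof *arr automatically.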
+
+#define SLICE(ptr, _start, end)                  \
+  ({                                             \
+    size_t start = _start;                       \
+    (slice_t){                                   \
+        .buf = (ptr) + start,                    \
+        .len = ((end)-start) * sizeof *(ptr),    \
+    };                                           \
+  })
+
+static inline void
+slice_move(void* dst, slice_t src) {
+  memmove(dst, src.buf, src.len);
+}
+static inline void
+slice_copy(void* dst, slice_t src) {
+  memcpy(dst, src.buf, src.len);
+}
+static inline void
+slice_fill(slice_t dst, char c) {
+  memset(dst.buf, c, dst.len);
+}
\ No newline at end of file
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 7fb62c0..227481a 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -9,12 +9,12 @@ include(AddCMockaTest)
 enable_testing()
 
 add_cmocka_test(test_string SOURCES string.c COMPILE_OPTIONS -I${CMAKE_BINARY_DIR}/cmocka/include LINK_LIBRARIES cmocka)
-add_cmocka_test(test_merkle
-  SOURCES merkle.c ../sha256.c
+add_cmocka_test(test_hmac
+  SOURCES hmac.c ../sha256.c
   COMPILE_OPTIONS -DUSE_PAGE_HASH -DUSE_PAGING -DUSE_FREEMEM -D__riscv_xlen=64 -I${CMAKE_BINARY_DIR}/cmocka/include -g
   LINK_LIBRARIES cmocka)
 add_cmocka_test(test_pageswap
-  SOURCES page_swap.c ../merkle.c ../sha256.c ../aes.c
+  SOURCES page_swap.c ../hmac.c ../sha256.c ../aes.c
   COMPILE_OPTIONS -DUSE_PAGE_HASH -DUSE_PAGE_CRYPTO -DUSE_PAGING -DUSE_FREEMEM -D__riscv_xlen=64 -I${CMAKE_BINARY_DIR}/cmocka/include -g
   LINK_LIBRARIES cmocka)
diff --git a/test/hmac.c b/test/hmac.c
new file mode 100644
index 0000000..3991bc7
--- /dev/null
+++ b/test/hmac.c
@@ -0,0 +1,33 @@
+#define _GNU_SOURCE
+
+#include "../hmac.c"
+
+#include <stdarg.h>
+#include <stddef.h>
+#include <setjmp.h>
+#include <string.h>
+#include <cmocka.h>
+
+#include "mock.h"
+
+void
+test_create_hmac() {
+  uint8_t key[] = "null";
+  uint8_t input[] = "null";
+  uint8_t digest[USHAHashSize()];
+  uint8_t digest_expected[] = {0xb4, 0xb1, 0x09, 0x3e, 0xb9, 0x42, 0x81, 0x6c,
+                               0x0d, 0x39, 0x99, 0xa9, 0x4b, 0xdf, 0x50, 0x08,
+                               0x2a, 0x4b, 0xe3, 0x02, 0x39, 0x62, 0x6e, 0xc6,
+                               0xf7, 0x65, 0xc7, 0x94, 0x39, 0x97, 0xfe, 0x89};
+  int res = hmac(input, strlen((char*)input), key, strlen((char*)key), digest);
+  assert_int_equal(res, 0);
+  assert_memory_equal(digest, digest_expected, sizeof(digest));
+}
+
+int
+main() {
+  const struct CMUnitTest tests[] = {
+      cmocka_unit_test(test_create_hmac),
+  };
+  return cmocka_run_group_tests(tests, NULL, NULL);
+}
\ No newline at end of file
diff --git a/test/merkle.c b/test/merkle.c
deleted file mode 100644
index 31ae805..0000000
--- a/test/merkle.c
+++ /dev/null
@@ -1,368 +0,0 @@
-#define _GNU_SOURCE
-
-#include "../merkle.h"
-
-#include
-#include
-#include
-#include
-
-#define MERK_SILENT
-#include "../merkle.c"
-#include "mock.h"
-
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) - -void -sbi_exit_enclave(uintptr_t code) { - exit(code); -} - -uintptr_t -paging_alloc_backing_page() { - void* out = mmap( - NULL, 4096, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - assert_int_not_equal(out, MAP_FAILED); - return (uintptr_t)out; -} - -#define RAND_REGION_ENTRIES 1000 -#define RAND_ENTRY_SIZE 64 - -const uint8_t* -random_region() { - static uint8_t* random_region_buf = NULL; - if (!random_region_buf) { - random_region_buf = (uint8_t*)malloc(RAND_REGION_ENTRIES * RAND_ENTRY_SIZE); - for (size_t i = 0; i < RAND_REGION_ENTRIES * RAND_ENTRY_SIZE; i++) { - random_region_buf[i] = (uint8_t)rand(); - } - } - return random_region_buf; -} - -size_t* -shuffled_idxs(size_t max) { - size_t* shuffled_idxs = (size_t*)malloc(sizeof(size_t) * max); - for (size_t i = 0; i < max; i++) shuffled_idxs[i] = i; - - for (size_t i = max - 1; i > 0; i--) { - size_t j = rand() % i + 1; - size_t tmp = shuffled_idxs[i]; - shuffled_idxs[i] = shuffled_idxs[j]; - shuffled_idxs[j] = tmp; - } - return shuffled_idxs; -} - -static merkle_node_t -random_region_insert(merkle_node_t* root) { - const uint8_t* region = random_region(); - size_t* idxs = shuffled_idxs(RAND_REGION_ENTRIES); - - SHA256_CTX sha; - - for (int i = 0; i < RAND_REGION_ENTRIES; i++) { - const uint8_t* subregion = region + idxs[i] * RAND_ENTRY_SIZE; - uint8_t hash[32]; - - sha256_init(&sha); - sha256_update(&sha, subregion, RAND_ENTRY_SIZE); - sha256_final(&sha, hash); - - int res = merk_insert(root, (uintptr_t)subregion, hash); - assert_int_equal(res, 0); - } - - free(idxs); -} - -static merkle_node_t -random_region_tree() { - merkle_node_t root = {}; - random_region_insert(&root); - return root; -} - -static size_t -count_verify_fails(merkle_node_t* tree) { - size_t total_verify_fails = 0; - SHA256_CTX sha; - - size_t* idxs = shuffled_idxs(RAND_REGION_ENTRIES); - - for (size_t ri = 0; ri < RAND_REGION_ENTRIES; ri++) { - const uint8_t* region = random_region() + idxs[ri] * RAND_ENTRY_SIZE; - uint8_t region_hash[32]; - sha256_init(&sha); - sha256_update(&sha, region, RAND_ENTRY_SIZE); - sha256_final(&sha, region_hash); - total_verify_fails += !merk_verify(tree, (uintptr_t)region, region_hash); - } - - free(idxs); - return total_verify_fails; -} - -struct merk_stats_s { - size_t max_depth, min_depth; - size_t elems, leaves; - double avg_depth; -}; - -struct merk_stats_s -merk_stats(const merkle_node_t* root) { - const merkle_node_t *left = root->left, *right = root->right; - - struct merk_stats_s out = { - .max_depth = 1, - .min_depth = 1, - .elems = 1, - .leaves = 0, - .avg_depth = 1, - }; - - if (!left && !right) { - out.leaves = 1; - return out; - } - - struct merk_stats_s lstats, rstats; - if (left) { - lstats = merk_stats(left); - out.max_depth = lstats.max_depth; - out.min_depth = lstats.min_depth; - out.elems += lstats.elems; - out.leaves += lstats.leaves; - out.avg_depth = lstats.avg_depth + 1; - } - if (right) { - rstats = merk_stats(right); - out.max_depth = rstats.max_depth; - out.min_depth = rstats.min_depth; - out.elems += rstats.elems; - out.leaves += rstats.leaves; - out.avg_depth = rstats.avg_depth + 1; - } - if (left && right) { - double both_elems = lstats.elems + rstats.elems; - double lweight = lstats.elems / both_elems; - double rweight = rstats.elems / both_elems; - out.max_depth = MAX(lstats.max_depth, rstats.max_depth) + 1; - out.min_depth = MIN(lstats.max_depth, rstats.max_depth) + 1; - out.avg_depth = lweight * lstats.avg_depth + rweight * rstats.avg_depth + 1; - } - return out; -} - 
-static void -test_verify_nonexistant() { - merkle_node_t root = {}; - uint8_t zeros[32] = {}; - assert_false(merk_verify(&root, 1, zeros)); -} - -static void -test_insert_and_verify_1() { - merkle_node_t root = {}; - const uint8_t* rand_hash = random_region(); - - int res = merk_insert(&root, 1, rand_hash); - assert_int_equal(res, 0); - assert_true(merk_verify(&root, 1, rand_hash)); -} - -static void -test_insert_and_verify_2() { - merkle_node_t root = {}; - const uint8_t* rand_hash_1 = random_region(); - const uint8_t* rand_hash_2 = random_region() + 32; - - int res = merk_insert(&root, 1, rand_hash_1); - assert_int_equal(res, 0); - res = merk_insert(&root, 2, rand_hash_2); - assert_int_equal(res, 0); - assert_true(merk_verify(&root, 1, rand_hash_1)); - assert_true(merk_verify(&root, 2, rand_hash_2)); -} - -static void -test_insert_and_verify_many() { - merkle_node_t root = random_region_tree(); - assert_int_equal(count_verify_fails(&root), 0); - struct merk_stats_s stats_0 = merk_stats(&root); - - random_region_insert(&root); - assert_int_equal(count_verify_fails(&root), 0); - struct merk_stats_s stats_1 = merk_stats(&root); - - assert_memory_equal(&stats_0, &stats_1, sizeof(struct merk_stats_s)); -} - -static void -test_random_insert_stats() { - merkle_node_t root = random_region_tree(); - struct merk_stats_s stats = merk_stats(&root); - assert_int_equal(stats.leaves, RAND_REGION_ENTRIES); - assert_in_range(stats.elems, stats.leaves, stats.leaves * 2); - assert_true(stats.min_depth > log2(RAND_REGION_ENTRIES)); - assert_true(stats.max_depth < 4 * log2(RAND_REGION_ENTRIES)); - assert_true(stats.avg_depth < 2 * log2(RAND_REGION_ENTRIES)); -} - -static void -test_poison_data() { - merkle_node_t root = random_region_tree(); - size_t poison_idx = rand() % RAND_REGION_ENTRIES; - const uint8_t* poison_ptr = random_region() + poison_idx * RAND_ENTRY_SIZE; - - SHA256_CTX sha; - uint8_t hash[32]; - sha256_init(&sha); - sha256_update(&sha, poison_ptr, RAND_ENTRY_SIZE); - sha256_final(&sha, hash); - - // Flip a random bit in the hash to simulate a tampered entry - hash[rand() & 31] ^= 1 << (rand() & 7); - - bool res = merk_verify(&root, (uintptr_t)poison_ptr, hash); - assert_false(res); -} - -static void -flip_random_bit(uint8_t* buf, size_t size) { - buf[rand() % size] ^= 1 << (rand() & 7); -} - -static void -test_poison_leaf() { - merkle_node_t root = random_region_tree(); - merkle_node_t* node = &root; - - // Randomly walk the tree until we get to a leaf - while (node->left || node->right) { - merkle_node_t* next[2]; - int num_next = 0; - next[num_next] = node->left; - num_next += !!node->left; - next[num_next] = node->right; - num_next += !!node->right; - - int taken = rand() % num_next; - node = next[taken]; - } - - uintptr_t key = node->ptr; - uint8_t* hash = node->hash; - // Simulate a tampered entry - flip_random_bit(hash, 32); - - bool res = merk_verify(&root, key, hash); - assert_false(res); -} - -static void -test_poison_root() { - merkle_node_t root = random_region_tree(); - flip_random_bit(root.hash, 32); - - size_t total_verify_fails = count_verify_fails(&root); - assert_int_equal(total_verify_fails, RAND_REGION_ENTRIES); -} - -static void -test_insert_corrupt_insert() { - merkle_node_t root = random_region_tree(); - - // Find a node that has a leaf on the right, and a sibling or nephew on the - // left - // TODO: not all trees may have this structure - merkle_node_t* node = &root; - assert_non_null(node->right); - assert_non_null(node->right->right); - - while (node->right->right) { 
-    node = node->right;
-  }
-
-  merkle_node_t* leaf = node->right;
-  // Find the position of the sibling/nephew leaf
-  merkle_node_t* sibling = node->left;
-
-  assert_non_null(leaf);
-  assert_non_null(sibling);
-
-  while (sibling->left) {
-    sibling = sibling->left;
-  }
-
-  assert_null(leaf->left);
-  assert_null(leaf->right);
-  assert_null(sibling->left);
-  assert_null(sibling->right);
-
-  // Check to make sure both start off okay
-  bool ok = merk_verify(&root, leaf->ptr, leaf->hash);
-  ok &= merk_verify(&root, sibling->ptr, sibling->hash);
-  assert_true(ok);
-
-  // When we corrupt the leaf hash, we expect the leaf check to fail
-  flip_random_bit(leaf->hash, 32);
-
-  merkle_node_t leaf_copy = *leaf, sibling_copy = *sibling;
-  ok = merk_verify(&root, leaf_copy.ptr, leaf_copy.hash);
-  assert_false(ok);
-
-  // Test that merk_insert doesn't incorrectly "validate" a hash that isn't the
-  // one we inserted
-  int res = merk_insert(&root, sibling_copy.ptr, sibling_copy.hash);
-  assert_int_not_equal(res, 0);
-  ok = merk_verify(&root, leaf_copy.ptr, leaf_copy.hash);
-  assert_false(ok);
-}
-
-static void
-test_corrupt_key() {
-  merkle_node_t root = {};
-  SHA256_CTX sha;
-
-  int res = merk_insert(&root, 1, random_region());
-  assert_int_equal(res, 0);
-  res = merk_insert(&root, 2, random_region() + 32);
-  assert_int_equal(res, 0);
-
-  assert_true(merk_verify(&root, 1, random_region()));
-  assert_true(merk_verify(&root, 2, random_region() + 32));
-
-  // Swap the keys for entries 1 and 2
-  assert_non_null(root.right);
-  merkle_node_t *first = root.right->left, *second = root.right->right;
-  assert_non_null(first);
-  assert_non_null(second);
-
-  assert_int_equal(first->ptr, 1);
-  first->ptr = 2;
-  assert_int_equal(second->ptr, 2);
-  second->ptr = 1;
-
-  assert_false(merk_verify(&root, 1, random_region()));
-  assert_false(merk_verify(&root, 2, random_region() + 32));
-}
-
-int
-main() {
-  const struct CMUnitTest tests[] = {
-      cmocka_unit_test(test_verify_nonexistant),
-      cmocka_unit_test(test_insert_and_verify_1),
-      cmocka_unit_test(test_insert_and_verify_2),
-      cmocka_unit_test(test_insert_and_verify_many),
-      cmocka_unit_test(test_random_insert_stats),
-      cmocka_unit_test(test_poison_data),
-      cmocka_unit_test(test_poison_leaf),
-      cmocka_unit_test(test_poison_root),
-      cmocka_unit_test(test_insert_corrupt_insert),
-      cmocka_unit_test(test_corrupt_key),
-  };
-  return cmocka_run_group_tests(tests, NULL, NULL);
-}
\ No newline at end of file
diff --git a/test/page_swap.c b/test/page_swap.c
index 413a2c5..dd7db18 100644
--- a/test/page_swap.c
+++ b/test/page_swap.c
@@ -5,6 +5,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 
 #include "mock.h"
@@ -73,21 +74,23 @@ pfree(uintptr_t page) {
   assert_int_equal(res, 0);
 }
 
-typedef struct {
-  uint8_t dat[32];
-} hash_s;
-static hash_s
+uintptr_t
+spa_get(void) {
+  return palloc();
+}
+
+static hash_t
 hash_page(uintptr_t page) {
-  hash_s out;
+  hash_t out;
   SHA256_CTX sha;
   sha256_init(&sha);
   sha256_update(&sha, (uint8_t*)page, RISCV_PAGE_SIZE);
-  sha256_final(&sha, out.dat);
+  sha256_final(&sha, out.val);
   return out;
 }
 
 static bool
-hash_eq(hash_s* h1, hash_s* h2) {
-  return !memcmp(h1, h2, sizeof(hash_s));
+hash_eq(hash_t* h1, hash_t* h2) {
+  return !memcmp(h1, h2, sizeof(hash_t));
 }
 
 static double
@@ -142,22 +145,24 @@ test_swap_out_in() {
   uintptr_t front_page = palloc();
   rt_util_getrandom((void*)front_page, RISCV_PAGE_SIZE);
 
-  hash_s back_hash = hash_page(back_page);
-  hash_s front_hash = hash_page(front_page);
+  hash_t back_hash = hash_page(back_page);
+  hash_t front_hash = hash_page(front_page);
 
-  page_swap_epm(back_page, front_page, 0);
+  int err = page_swap_epm(back_page, front_page, 0);
+  assert(!err);
 
-  hash_s back_swp_hash = hash_page(back_page);
-  hash_s front_swp_hash = hash_page(front_page);
+  hash_t back_swp_hash = hash_page(back_page);
+  hash_t front_swp_hash = hash_page(front_page);
 
   assert_false(hash_eq(&back_hash, &back_swp_hash));
   assert_true(hash_eq(&front_hash, &front_swp_hash));
 
   // Randomize front_page and then swap back in our old front_page
   rt_util_getrandom((void*)front_page, RISCV_PAGE_SIZE);
-  page_swap_epm(back_page, front_page, back_page);
+  err = page_swap_epm(back_page, front_page, back_page);
+  assert(!err);
 
-  hash_s back_swp_hash2 = hash_page(back_page);
-  hash_s front_swp_hash2 = hash_page(front_page);
+  hash_t back_swp_hash2 = hash_page(back_page);
+  hash_t front_swp_hash2 = hash_page(front_page);
 
   assert_false(hash_eq(&back_hash, &back_swp_hash2));
   assert_false(hash_eq(&back_swp_hash, &back_swp_hash2));
   assert_true(hash_eq(&front_hash, &front_swp_hash2));
@@ -165,11 +170,128 @@ test_swap_out_in() {
   pfree(front_page);
 }
 
+void
+test_corrupt_back_page() {
+  pswap_init();
+
+  uintptr_t back_page = paging_alloc_backing_page();
+  uintptr_t front_page = palloc();
+  rt_util_getrandom((void*)front_page, RISCV_PAGE_SIZE);
+
+  int err = page_swap_epm(back_page, front_page, 0);
+  assert(!err);
+
+  // Flip a random bit in the page contents
+  *(uint8_t*)(back_page + rand() % RISCV_PAGE_SIZE) ^= (1 << (rand() % 8));
+
+  // Now we should see an error swapping the page back in
+  err = page_swap_epm(back_page, front_page, back_page);
+  assert(err);
+
+  pfree(front_page);
+}
+
+void
+test_corrupt_back_page_hmac() {
+  pswap_init();
+
+  uintptr_t back_page = paging_alloc_backing_page();
+  uintptr_t front_page = palloc();
+  rt_util_getrandom((void*)front_page, RISCV_PAGE_SIZE);
+
+  int err = page_swap_epm(back_page, front_page, 0);
+  assert(!err);
+
+  // Flip a random bit in the hmac itself rather than in the page
+  hash_t* hmac = pswap_pageout_hmac(back_page);
+  ((uint8_t*)hmac)[rand() % sizeof(hash_t)] ^= (1 << (rand() % 8));
+
+  // Now we should see an error swapping the page back in
+  err = page_swap_epm(back_page, front_page, back_page);
+  assert(err);
+
+  pfree(front_page);
+}
+
+void
+test_replay_back_page() {
+  pswap_init();
+
+  uintptr_t back_page = paging_alloc_backing_page();
+  uintptr_t front_page = palloc();
+  rt_util_getrandom((void*)front_page, RISCV_PAGE_SIZE);
+
+  uint8_t* back_page_backup = (uint8_t*)palloc();
+  hash_t back_hmac_backup;
+
+  // Swap the page out
+  int err = page_swap_epm(back_page, front_page, 0);
+  assert(!err);
+  memcpy(back_page_backup, (void*)back_page, RISCV_PAGE_SIZE);
+  back_hmac_backup = *pswap_pageout_hmac(back_page);
+
+  // Swap again so the stored page and its hmac are replaced
+  err = page_swap_epm(back_page, front_page, back_page);
+  assert(!err);
+
+  // Restore the backups, replaying the stale page and hmac
+  memcpy((void*)back_page, back_page_backup, RISCV_PAGE_SIZE);
+  *pswap_pageout_hmac(back_page) = back_hmac_backup;
+
+  // Swapping in should fail
+  err = page_swap_epm(back_page, front_page, back_page);
+  assert(err);
+
+  pfree(front_page);
+  pfree((uintptr_t)back_page_backup);
+}
+
+void
+test_invasive_replay() {
+  pswap_init();
+
+  uintptr_t back_page = paging_alloc_backing_page();
+  uintptr_t front_page = palloc();
+  rt_util_getrandom((void*)front_page, RISCV_PAGE_SIZE);
+
+  uint8_t* back_page_backup = (uint8_t*)palloc();
+  hash_t back_hmac_backup;
+  uint64_t back_ctr_backup;
+
+  // Swap the page out
+  int err = page_swap_epm(back_page, front_page, 0);
+  assert(!err);
+  memcpy(back_page_backup, (void*)back_page, RISCV_PAGE_SIZE);
+  back_hmac_backup = *pswap_pageout_hmac(back_page);
+  back_ctr_backup = *pswap_pageout_ctr(back_page);
+
+  // Swap again so the stored page, hmac, and counter are replaced
+  err = page_swap_epm(back_page, front_page, back_page);
+  assert(!err);
+
+  // Restore the backups, including the pageout counter
+  memcpy((void*)back_page, back_page_backup, RISCV_PAGE_SIZE);
+  *pswap_pageout_hmac(back_page) = back_hmac_backup;
+  *pswap_pageout_ctr(back_page) = back_ctr_backup;
+
+  // Swapping in should succeed because we rolled the pageout counter back
+  // along with the page; a real attacker must not be able to do this
+  err = page_swap_epm(back_page, front_page, back_page);
+  assert(!err);
+
+  pfree(front_page);
+  pfree((uintptr_t)back_page_backup);
+}
+
 int
 main() {
   const struct CMUnitTest tests[] = {
       cmocka_unit_test(test_swapout_randomness),
       cmocka_unit_test(test_swap_out_in),
+      cmocka_unit_test(test_corrupt_back_page),
+      cmocka_unit_test(test_corrupt_back_page_hmac),
+      cmocka_unit_test(test_replay_back_page),
+      cmocka_unit_test(test_invasive_replay),
   };
   return cmocka_run_group_tests(tests, NULL, NULL);
 }
\ No newline at end of file
diff --git a/vm.c b/vm.c
index 4a02e0d..829cf14 100644
--- a/vm.c
+++ b/vm.c
@@ -2,7 +2,6 @@
 uintptr_t runtime_va_start;
 uintptr_t kernel_offset;
-uintptr_t load_pa_start;
 
 #ifdef USE_FREEMEM
 /* root page table */
@@ -17,9 +16,6 @@ pte load_l3_page_table[BIT(RISCV_PT_INDEX_BITS)] __attribute__((aligned(RISCV_PAGE_SIZE)))
 /* Program break */
 uintptr_t program_break;
 
-/* freemem */
-uintptr_t freemem_va_start;
-size_t freemem_size;
 #endif // USE_FREEMEM
 
 /* shared buffer */
diff --git a/vm.h b/vm.h
index af2a638..6470377 100644
--- a/vm.h
+++ b/vm.h
@@ -24,16 +24,6 @@ static inline uintptr_t kernel_va_to_pa(void* ptr)
   return (uintptr_t) ptr - kernel_offset;
 }
 
-static inline uintptr_t __va(uintptr_t pa)
-{
-  return (pa - load_pa_start) + EYRIE_LOAD_START;
-}
-
-static inline uintptr_t __pa(uintptr_t va)
-{
-  return (va - EYRIE_LOAD_START) + load_pa_start;
-}
-
 static inline pte pte_create(uintptr_t ppn, int type)
 {
   return (pte)((ppn << PTE_PPN_SHIFT) | PTE_V | (type & PTE_FLAG_MASK) );
diff --git a/vm_defs.h b/vm_defs.h
index e100e65..a8b7dd7 100644
--- a/vm_defs.h
+++ b/vm_defs.h
@@ -1,5 +1,7 @@
 #pragma once
 
+#include <stdint.h>
+
 #define BIT(n) (1ul << (n))
 #define MASK(n) (BIT(n) - 1ul)
 #define IS_ALIGNED(n, b) (!((n)&MASK(b)))
@@ -64,3 +66,7 @@
 #define PTE_PPN_SHIFT 10
 
 typedef uintptr_t pte;
+uintptr_t
+__va(uintptr_t pa);
+uintptr_t
+__pa(uintptr_t va);
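
For context on what the four new page_swap tests pin down: a swapped-out page
must fail to come back in if its contents or its stored HMAC are altered, or
if a stale (page, HMAC) pair is replayed; replay only succeeds when the
pageout counter is rolled back as well, which is exactly the state a real
attacker must not be able to reach. Below is a minimal sketch of such a
freshness check. Only hash_t's 32-byte val[], the per-page uint64_t counter,
and the accessor names above come from this patch; the hmac_sha256()
signature, the key handling, and the message layout are illustrative
assumptions, not the actual page_swap.c code.

#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define RISCV_PAGE_SIZE 4096

typedef struct {
  uint8_t val[32];
} hash_t;

/* Assumed one-shot HMAC-SHA256 primitive (hmac.c is added by this patch,
 * but its exact interface is not shown in these hunks). */
void hmac_sha256(
    const uint8_t* key, size_t key_len, const uint8_t* msg, size_t msg_len,
    uint8_t out[32]);

/* Sealing key; must never leave enclave-private memory. */
static uint8_t pswap_key[32];

/* MAC over (counter || page) so that neither the page image nor the
 * counter can be replayed independently of the other. */
static void
pageout_mac(uint64_t ctr, const uint8_t* page, hash_t* out) {
  uint8_t msg[sizeof(ctr) + RISCV_PAGE_SIZE];
  memcpy(msg, &ctr, sizeof(ctr));
  memcpy(msg + sizeof(ctr), page, RISCV_PAGE_SIZE);
  hmac_sha256(pswap_key, sizeof(pswap_key), msg, sizeof(msg), out->val);
}

/* Page-in check: recompute the MAC with the counter the enclave expects.
 * A replayed pair carries a stale counter, so the MACs differ (what
 * test_replay_back_page asserts); rolling the expected counter back too
 * defeats the check (what test_invasive_replay demonstrates), so that
 * counter has to live in memory the attacker cannot write. */
static int
pagein_check(
    uint64_t expected_ctr, const uint8_t* page, const hash_t* stored) {
  hash_t fresh;
  pageout_mac(expected_ctr, page, &fresh);
  return memcmp(fresh.val, stored->val, sizeof(fresh.val)) ? -1 : 0;
}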
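
The vm.c/vm.h/vm_defs.h hunks are the flip side of the same change: __va()
and __pa() used to be header inlines built on the load_pa_start global,
which this patch deletes, so they become plain declarations in vm_defs.h
with their definitions kept out of line elsewhere in the runtime (not part
of these hunks). A sketch of what such definitions could look like under
the old linear-mapping assumption; the epm_* globals are illustrative names
only and appear nowhere in this patch:

#include "vm_defs.h"

/* Illustrative stand-ins for wherever the runtime now records the
 * enclave's physical/virtual load window; not names from this patch. */
extern uintptr_t epm_load_pa_start;
extern uintptr_t epm_load_va_start;

uintptr_t
__va(uintptr_t pa) {
  return (pa - epm_load_pa_start) + epm_load_va_start;
}

uintptr_t
__pa(uintptr_t va) {
  return (va - epm_load_va_start) + epm_load_pa_start;
}

Keeping the pair out of line means vm_defs.h no longer drags address-space
globals into every includer, and the paging code can change the translation
without touching the header.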