Skip to content

Add initial bignum implementation #218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: trunk
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,11 @@ def _is_const(self, exp: Object) -> bool:
return True
return False

def _make_tag(self, tag: str, size_bytes: str) -> str:
return f"((({size_bytes}) << kBitsPerByte) | {tag})"

def _const_obj(self, type: str, tag: str, contents: str) -> str:
# TODO(max): Emulate make_tag here to encode size
result = self.gensym(f"const_{type}")
self.const_heap.append(f"CONST_HEAP struct {type} {result} = {{.HEAD.tag={tag}, {contents} }};")
return f"ptrto({result})"
Expand All @@ -328,8 +332,20 @@ def _emit_const(self, exp: Object) -> str:
if isinstance(exp, Hole):
return "hole()"
if isinstance(exp, Int):
# TODO(max): Bignum
return f"_mksmallint({exp.value})"
if -0x4000000000000000 <= exp.value <= 0x3FFFFFFFFFFFFFFF:
return f"_mksmallint({exp.value}ULL)"
# Divide number into 64-bit digits
if exp.value < 0:
# TODO(max): Handle negative largeint
raise NotImplementedError(f"negative largeint64({exp.value})")
value = exp.value
digits = []
while value:
digits.append(value & 0xFFFFFFFFFFFFFFFF)
value >>= 64
tag = self._make_tag("TAG_LARGEINT", f"sizeof(struct large_int)+{len(digits)}ULL*kLargeIntDigitSize")
parts = ", ".join(f"{digit}ULL" for digit in digits)
return self._const_obj("large_int", tag, f".digits={{ {parts} }}")
if isinstance(exp, List):
items = [self._emit_const(item) for item in exp.items]
result = "empty_list()"
Expand Down Expand Up @@ -469,6 +485,7 @@ def compile_to_string(program: Object, debug: bool) -> str:
("uword", "kPrimaryTagMask", "(1ULL << kPrimaryTagBits) - 1"),
("uword", "kImmediateTagMask", "(1ULL << kImmediateTagBits) - 1"),
("uword", "kWordSize", "sizeof(word)"),
("uword", "kLargeIntDigitSize", "sizeof(large_int_digit)"),
("uword", "kMaxSmallStringLength", "kWordSize - 1"),
("uword", "kBitsPerByte", 8),
# Up to the five least significant bits are used to tag the object's layout.
Expand Down Expand Up @@ -496,7 +513,6 @@ def compile_to_string(program: Object, debug: bool) -> str:
dirname = os.path.dirname(__file__)
with open(os.path.join(dirname, "runtime.c"), "r") as runtime:
print(runtime.read(), file=f)
print("#define OBJECT_HANDLE(name, exp) GC_HANDLE(struct object*, name, exp)", file=f)
if compiler.record_keys:
print("const char* record_keys[] = {", file=f)
for key in compiler.record_keys:
Expand Down
22 changes: 22 additions & 0 deletions compiler_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ def _run(self, code: str) -> str:
def test_int(self) -> None:
self.assertEqual(self._run("1"), "1\n")

def test_int_small_int_max(self) -> None:
self.assertEqual(self._run("4611686018427387903"), "4611686018427387903\n")

def test_int_small_int_min(self) -> None:
self.assertEqual(self._run("-4611686018427387904"), "-4611686018427387904\n")

def test_int_small_int_too_small(self) -> None:
with self.assertRaisesRegex(NotImplementedError, "negative largeint"):
self._run("-4611686018427387905")

def test_int_add_to_large_int(self) -> None:
self.assertEqual(self._run("4611686018427387903 + 1"), "largeint64(0x4000000000000000)\n")
self.assertEqual(self._run("4611686018427387904"), "largeint64(0x4000000000000000)\n")

def test_int_add_to_large_int_two_digits(self) -> None:
program = "4611686018427387903 + 4611686018427387903 + 4611686018427387903 + 4611686018427387903 + 4611686018427387903 + 4611686018427387903 + 4611686018427387903"
self.assertEqual(hex(eval(program)), "0x1bffffffffffffff9")
self.assertEqual(self._run(program), "largeint64(0x1, 0xbffffffffffffff9)\n")

def test_literal_positive_large_int(self) -> None:
self.assertEqual(self._run("340282366920938463463374607431768211456"), "largeint64(0x1, 0x0, 0x0)\n")

def test_small_string(self) -> None:
self.assertEqual(self._run('"hello"'), '"hello"\n')

Expand Down
218 changes: 208 additions & 10 deletions runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ const int kPointerSize = sizeof(void*);
typedef intptr_t word;
typedef uintptr_t uword;
typedef unsigned char byte;
typedef uint64_t large_int_digit;
const word kMinWord = INTPTR_MIN;
const word kMaxWord = INTPTR_MAX;
const uword kMaxUword = UINTPTR_MAX;

// Garbage collector core by Andy Wingo <[email protected]>.

Expand Down Expand Up @@ -360,7 +364,8 @@ static ALWAYS_INLINE ALLOCATOR struct object* allocate(struct gc_heap* heap,
TAG(TAG_CLOSURE) \
TAG(TAG_RECORD) \
TAG(TAG_STRING) \
TAG(TAG_VARIANT)
TAG(TAG_VARIANT) \
TAG(TAG_LARGEINT)

enum {
// All odd becase of the kNotForwardedBit
Expand Down Expand Up @@ -411,9 +416,15 @@ struct variant {
struct object* value;
} HEAP_ALIGNED;

struct large_int {
struct gc_obj HEAD;
large_int_digit digits[];
}; // Not HEAP_ALIGNED; digits is variable size

size_t heap_object_size(struct gc_obj* obj) {
size_t result = obj->tag >> kBitsPerByte;
assert(is_size_aligned(result));
// Size need not be aligned if the object is in the constant heap.
assert(in_const_heap(obj) || is_size_aligned(result));
return result;
}

Expand All @@ -435,6 +446,7 @@ size_t trace_heap_object(struct gc_obj* obj, struct gc_heap* heap,
}
break;
case TAG_STRING:
case TAG_LARGEINT:
break;
case TAG_VARIANT:
visit(&((struct variant*)obj)->value, heap);
Expand All @@ -458,12 +470,83 @@ struct object* mksmallint(word value) {
return _mksmallint(value);
}

static ALWAYS_INLINE bool is_large_int(struct object* obj) {
return is_heap_object(obj) && obj_has_tag(as_heap_object(obj), TAG_LARGEINT);
}

static ALWAYS_INLINE struct large_int* as_large_int(struct object* obj) {
assert(is_large_int(obj));
return (struct large_int*)as_heap_object(obj);
}

uword large_int_num_digits(struct object* obj) {
assert(is_large_int(obj));
size_t size = heap_object_size(as_heap_object(obj)) - sizeof(struct gc_obj);
return size / kLargeIntDigitSize;
}

uword num_digits(struct object* obj) {
if (is_small_int(obj)) {
return 1;
}
assert(is_large_int(obj));
return large_int_num_digits(obj);
}

large_int_digit large_int_digit_at(struct object* obj, uword index) {
assert(is_large_int(obj));
assert(index < large_int_num_digits(obj));
return as_large_int(obj)->digits[index];
}

void large_int_digit_at_put(struct object* obj, uword index,
large_int_digit digit) {
assert(is_large_int(obj));
assert(index < large_int_num_digits(obj));
as_large_int(obj)->digits[index] = digit;
}

word small_int_value(struct object* obj) {
assert(is_small_int(obj));
return ((word)obj) >> kSmallIntTagBits; // sign extend
}

uword digit_at(struct object* obj, uword index) {
if (is_small_int(obj)) {
assert(index == 0);
return small_int_value(obj);
}
assert(is_large_int(obj));
return large_int_digit_at(obj, index);
}

struct object* _mklarge_int_uninit_private(struct gc_heap* heap,
uword num_digits) {
uword digits_size = num_digits * kLargeIntDigitSize;
uword size = align_size(sizeof(struct large_int) + digits_size);
return allocate(heap, TAG_LARGEINT, size);
}

struct object* _mklarge_int(struct gc_heap* heap, uword num_digits,
large_int_digit* digits) {
struct object* result = _mklarge_int_uninit_private(heap, num_digits);
uword digits_size = num_digits * kLargeIntDigitSize;
memcpy(as_large_int(result)->digits, digits, digits_size);
return result;
}

struct object* mknum(struct gc_heap* heap, word value) {
(void)heap;
return mksmallint(value);
if (smallint_is_valid(value)) {
return _mksmallint(value);
}
assert(sizeof(word) == sizeof(large_int_digit));
large_int_digit digits[] = {value};
return _mklarge_int(heap, 1, digits);
}

bool is_num(struct object* obj) { return is_small_int(obj); }
bool is_num(struct object* obj) {
return is_small_int(obj) || is_large_int(obj);
}

bool is_num_equal_word(struct object* obj, word value) {
assert(smallint_is_valid(value));
Expand All @@ -472,7 +555,12 @@ bool is_num_equal_word(struct object* obj, word value) {

word num_value(struct object* obj) {
assert(is_num(obj));
return ((word)obj) >> 1; // sign extend
if (is_small_int(obj)) {
return small_int_value(obj);
}
assert(is_large_int(obj));
assert(large_int_num_digits(obj) == 1);
return large_int_digit_at(obj, 0);
}

bool is_list(struct object* obj) {
Expand Down Expand Up @@ -688,6 +776,7 @@ void pop_handles(void* local_handles) {
#define GC_HANDLE(type, name, val) \
type name = val; \
GC_PROTECT(name)
#define OBJECT_HANDLE(name, exp) GC_HANDLE(struct object*, name, exp)

void trace_roots(struct gc_heap* heap, VisitFn visit) {
for (struct object*** h = handle_stack; h != handles; h++) {
Expand All @@ -698,17 +787,115 @@ void trace_roots(struct gc_heap* heap, VisitFn visit) {
struct gc_heap heap_object;
struct gc_heap* heap = &heap_object;

struct object* num_add(struct object* a, struct object* b) {
// NB: doesn't use pointers after allocating
return mknum(heap, num_value(a) + num_value(b));
#ifndef __has_builtin
// Some versions of TCC don't have __has_builtin.
#define __has_builtin(x) 0
#endif

#if !__has_builtin(__builtin_uaddl_overflow)
// No version of TCC has __builtin_uaddl_overflow.
bool __builtin_uaddl_overflow(uword left, uword right, uword* result) {
*result = left + right;
return *result < left;
}
#endif

static uword add_with_carry(uword x, uword y, uword carry_in,
uword* carry_out) {
assert(carry_in <= 1 && "carry must be 0 or 1");
uword sum;
uword carry0 = __builtin_uaddl_overflow(x, y, &sum);
uword carry1 = __builtin_uaddl_overflow(sum, carry_in, &sum);
*carry_out = carry0 | carry1;
return sum;
}

struct object* normalize_large_int(struct gc_heap* heap, struct object* obj) {
(void)heap;
word num_digits = large_int_num_digits(obj);
word shrink_to_digits = num_digits;
for (word digit = large_int_digit_at(obj, shrink_to_digits - 1), next_digit;
shrink_to_digits > 1; shrink_to_digits--, digit = next_digit) {
next_digit = large_int_digit_at(obj, shrink_to_digits - 2);
// break if we have neither a redundant sign-extension nor a redundnant
// zero-extension.
if ((digit != -1 || next_digit >= 0) && (digit != 0 || next_digit < 0)) {
break;
}
}
if (shrink_to_digits == 1 && smallint_is_valid(large_int_digit_at(obj, 0))) {
return mksmallint(large_int_digit_at(obj, 0));
}
if (shrink_to_digits == num_digits) {
return obj;
}
HANDLES();
GC_PROTECT(obj);
OBJECT_HANDLE(result, _mklarge_int_uninit_private(heap, shrink_to_digits));
for (word i = 0; i < shrink_to_digits; i++) {
large_int_digit_at_put(result, i, large_int_digit_at(obj, i));
}
return result;
}

bool small_int_is_negative(struct object* obj) {
return small_int_value(obj) < 0;
}

bool large_int_is_negative(struct object* obj) {
return (word)large_int_digit_at(obj, large_int_num_digits(obj) - 1) < 0;
}

bool is_negative(struct object* obj) {
if (is_small_int(obj)) {
return small_int_is_negative(obj);
}
return large_int_is_negative(obj);
}

struct object* num_add(struct object* left, struct object* right) {
if (is_small_int(left) && is_small_int(right)) {
// Take a shortcut because we know the result fits in a word.
word result = num_value(left) + num_value(right);
return mknum(heap, result);
}
HANDLES();
uword left_digits = num_digits(left);
uword right_digits = num_digits(right);
GC_PROTECT(left);
GC_PROTECT(right);
OBJECT_HANDLE(longer, left_digits > right_digits ? left : right);
OBJECT_HANDLE(shorter, left_digits > right_digits ? right : left);
uword shorter_digits = num_digits(shorter);
uword longer_digits = num_digits(longer);
uword result_digits = longer_digits + 1;
OBJECT_HANDLE(result, _mklarge_int_uninit_private(heap, result_digits));
uword carry = 0;
for (uword i = 0; i < shorter_digits; i++) {
uword sum = add_with_carry(digit_at(longer, i), digit_at(shorter, i), carry,
&carry);
large_int_digit_at_put(result, i, sum);
}
uword shorter_sign_extension = is_negative(shorter) ? kMaxUword : 0;
for (uword i = shorter_digits; i < longer_digits; i++) {
uword sum = add_with_carry(digit_at(longer, i), shorter_sign_extension,
carry, &carry);
large_int_digit_at_put(result, i, sum);
}
uword longer_sign_extension = is_negative(longer) ? kMaxUword : 0;
uword high_digit = longer_sign_extension + shorter_sign_extension + carry;
large_int_digit_at_put(result, result_digits - 1, high_digit);
return normalize_large_int(heap, result);
}

struct object* num_sub(struct object* a, struct object* b) {
// TODO(max): Implement large_int subtraction
// NB: doesn't use pointers after allocating
return mknum(heap, num_value(a) - num_value(b));
}

struct object* num_mul(struct object* a, struct object* b) {
// TODO(max): Implement large_int multiplication
// NB: doesn't use pointers after allocating
return mknum(heap, num_value(a) * num_value(b));
}
Expand Down Expand Up @@ -793,8 +980,19 @@ extern const char* record_keys[];
extern const char* variant_names[];

struct object* print(struct object* obj) {
if (is_num(obj)) {
if (is_small_int(obj)) {
printf("%ld", num_value(obj));
} else if (is_large_int(obj)) {
printf("largeint%d(", kLargeIntDigitSize * kPointerSize);
uword num_digits = large_int_num_digits(obj);
for (uword i = 0; i < num_digits; i++) {
if (i > 0) {
fprintf(stdout, ", ");
}
fprintf(stdout, "0x%lx", large_int_digit_at(obj, num_digits - i - 1));
}
printf(")");
return obj;
} else if (is_list(obj)) {
putchar('[');
while (!is_empty_list(obj)) {
Expand Down
Loading