Skip to content

Commit 7faf83b

Browse files
committed
feat: add extern implementation for popcount
This commit introduces a native C++ implementation for `BitVec.popcount` to significantly improve its performance, especially on large bitvectors. - The `mpz` class is extended with a `popcount` method. - A new extern function `lean_bitvec_popcount` is implemented in the runtime. It uses compiler intrinsics for hardware popcount instructions (e.g., `__builtin_popcountll`, `__popcnt64`) when available, and gracefully falls back to a generic implementation on other platforms. - `BitVec.zerocount` is refactored to be a cheap calculation based on the now-fast `popcount`, rather than a separate fold.
1 parent 620055c commit 7faf83b

File tree

4 files changed

+79
-12
lines changed

4 files changed

+79
-12
lines changed

src/Init/Data/BitVec/Count.lean

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,9 @@ Examples:
5252
* `(0b1111#4).popcount = 4`
5353
* `(0#8).popcount = 0`
5454
55-
Note: This implementation could be optimized with a native `@[extern]` implementation
56-
using efficient CPU instructions (e.g., GMP's `gmp_popcount` or x86's `POPCNT`).
57-
See https://github.com/leanprover/lean4/issues/7887 for discussion of native implementations.
55+
This function uses a native implementation with CPU popcount instructions when available.
5856
-/
57+
@[extern "lean_bitvec_popcount"]
5958
def popcount (x : BitVec w) : Nat :=
6059
x.countP id
6160

@@ -130,22 +129,27 @@ Count the number of `false` bits (zeros).
130129
This is the complement of `popcount`.
131130
-/
132131
def zerocount (x : BitVec w) : Nat :=
133-
x.countP not
132+
w - x.popcount
134133

135134
@[simp]
136135
theorem zerocount_nil : zerocount nil = 0 := by
137-
simp [zerocount, -ofNat_eq_ofNat]
136+
simp [zerocount]
138137

139138
@[simp]
140139
theorem zerocount_cons (b : Bool) (x : BitVec w) :
141140
zerocount (cons b x) = (!b).toNat + zerocount x := by
142-
cases b <;> simp +arith [zerocount, countP]
141+
cases b <;>
142+
simp +arith [zerocount, Nat.sub_add_comm (popcount_le_width _)]
143143

144-
theorem popcount_add_zerocount (x : BitVec w) :
145-
x.popcount + x.zerocount = w := by
144+
theorem zerocount_eq_countP (x : BitVec w) :
145+
x.zerocount = x.countP not := by
146146
induction x using BitVec.induction with
147147
| nil => simp [-ofNat_eq_ofNat]
148-
| cons _ b => cases b <;> simp_all +arith
148+
| cons _ b => cases b <;> simp_all
149+
150+
theorem popcount_add_zerocount (x : BitVec w) :
151+
x.popcount + x.zerocount = w := by
152+
simp +arith [zerocount, popcount_le_width]
149153

150154
@[simp]
151155
theorem zerocount_not {x : BitVec w} :
@@ -160,14 +164,14 @@ theorem popcount_not {x : BitVec w} :
160164

161165
@[simp]
162166
theorem zerocount_zero : zerocount 0#w = w := by
163-
simp [←popcount_add_zerocount 0#w, -ofNat_eq_ofNat]
167+
simp [zerocount]
164168

165169
@[simp]
166170
theorem zerocount_allOnes : zerocount (allOnes w) = 0 := by
167-
simp [←not_zero]
171+
simp [zerocount]
168172

169173
theorem zerocount_le_width {x : BitVec w} : zerocount x ≤ w := by
170-
simp [←popcount_add_zerocount x]
174+
simp [zerocount]
171175

172176

173177
/--

src/runtime/mpz.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,12 @@ size_t mpz::log2() const {
222222
return r - 1;
223223
}
224224

225+
size_t mpz::popcount() const {
226+
if (is_nonpos())
227+
return 0;
228+
return mpz_popcount(m_val);
229+
}
230+
225231
mpz & mpz::operator&=(mpz const & o) {
226232
mpz_and(m_val, m_val, o.m_val);
227233
return *this;
@@ -856,6 +862,20 @@ size_t mpz::log2() const {
856862
return (m_size - 1)*sizeof(mpn_digit)*8 + log2_uint(m_digits[m_size - 1]);
857863
}
858864

865+
size_t mpz::popcount() const {
866+
if (is_nonpos())
867+
return 0;
868+
size_t count = 0;
869+
for (size_t i = 0; i < m_size; i++) {
870+
mpn_digit d = m_digits[i];
871+
while (d) {
872+
count += d & 1;
873+
d >>= 1;
874+
}
875+
}
876+
return count;
877+
}
878+
859879
mpz & mpz::operator&=(mpz const & o) {
860880
digit_buffer r;
861881
size_t sz = std::max(m_size, o.m_size);

src/runtime/mpz.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,12 @@ class LEAN_EXPORT mpz {
284284
*/
285285
size_t log2() const;
286286

287+
/**
288+
\brief Return the population count (number of 1 bits).
289+
Return 0 if the number is negative
290+
*/
291+
size_t popcount() const;
292+
287293
friend void power(mpz & a, mpz const & b, unsigned k);
288294
friend void _power(mpz & a, mpz const & b, unsigned k) { power(a, b, k); }
289295
friend mpz pow(mpz a, unsigned k) { power(a, a, k); return a; }

src/runtime/object.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,43 @@ extern "C" LEAN_EXPORT lean_obj_res lean_nat_log2(b_lean_obj_arg a) {
15311531
}
15321532
}
15331533

1534+
extern "C" LEAN_EXPORT lean_obj_res lean_bitvec_popcount(b_lean_obj_arg /* w */, b_lean_obj_arg x) {
1535+
if (lean_is_scalar(x)) {
1536+
size_t n = lean_unbox(x);
1537+
unsigned count = 0;
1538+
bool builtin_used = false;
1539+
1540+
#if defined(__GNUC__) || defined(__clang__)
1541+
#if SIZE_MAX == UINT64_MAX
1542+
count = __builtin_popcountll(n);
1543+
builtin_used = true;
1544+
#elif SIZE_MAX == UINT32_MAX
1545+
count = __builtin_popcount(n);
1546+
builtin_used = true;
1547+
#endif
1548+
#elif defined(_MSC_VER)
1549+
#include <intrin.h>
1550+
#if SIZE_MAX == UINT64_MAX
1551+
count = __popcnt64(n);
1552+
builtin_used = true;
1553+
#elif SIZE_MAX == UINT32_MAX
1554+
count = __popcnt(n);
1555+
builtin_used = true;
1556+
#endif
1557+
#endif
1558+
1559+
if (!builtin_used) {
1560+
while (n) {
1561+
count += n & 1;
1562+
n >>= 1;
1563+
}
1564+
}
1565+
return lean_box(count);
1566+
} else {
1567+
return lean_box(mpz_value(x).popcount());
1568+
}
1569+
}
1570+
15341571
// =======================================
15351572
// Integers
15361573

0 commit comments

Comments
 (0)