|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the MIT license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +"""Custom Windows wheel repair: bundle MKL runtime DLLs into the faiss package. |
| 7 | +
|
| 8 | +Copies MKL (and Intel OpenMP / TBB) DLLs from a staging directory into the |
| 9 | +faiss/ package directory inside the wheel so they are co-located with |
| 10 | +_swigfaiss.pyd and faiss.dll. On Python 3.8+ Windows, DLL dependencies are |
| 11 | +resolved from the directory containing the loading DLL, so co-location is |
| 12 | +sufficient — no .pth file or os.add_dll_directory() call is needed. |
| 13 | +
|
| 14 | +Usage: |
| 15 | + python repair_win_wheel.py <wheel> <dest_dir> [--dll-dir C:/mkl/bin] |
| 16 | +""" |
| 17 | + |
| 18 | +import argparse |
| 19 | +import base64 |
| 20 | +import glob |
| 21 | +import hashlib |
| 22 | +import os |
| 23 | +import shutil |
| 24 | +import tempfile |
| 25 | +import zipfile |
| 26 | + |
| 27 | + |
| 28 | +def repair(wheel_path, dest_dir, dll_dir): |
| 29 | + wheel_name = os.path.basename(wheel_path) |
| 30 | + tmpdir = tempfile.mkdtemp() |
| 31 | + |
| 32 | + try: |
| 33 | + # Unpack the wheel (it's a zip file). |
| 34 | + with zipfile.ZipFile(wheel_path) as zf: |
| 35 | + zf.extractall(tmpdir) |
| 36 | + |
| 37 | + # List DLLs/PYDs already in the wheel (diagnostic). |
| 38 | + print("Files in wheel:") |
| 39 | + for root, _dirs, files in os.walk(tmpdir): |
| 40 | + for f in sorted(files): |
| 41 | + if f.lower().endswith((".dll", ".pyd")): |
| 42 | + rel = os.path.relpath(os.path.join(root, f), tmpdir) |
| 43 | + size = os.path.getsize(os.path.join(root, f)) |
| 44 | + print(f" {rel} ({size:,} bytes)") |
| 45 | + |
| 46 | + # Locate the faiss package directory inside the wheel. |
| 47 | + faiss_dir = os.path.join(tmpdir, "faiss") |
| 48 | + if not os.path.isdir(faiss_dir): |
| 49 | + raise RuntimeError("faiss/ directory not found in wheel") |
| 50 | + |
| 51 | + # Find the RECORD file (in *.dist-info/). |
| 52 | + record_path = None |
| 53 | + for root, _dirs, files in os.walk(tmpdir): |
| 54 | + if root.endswith(".dist-info") and "RECORD" in files: |
| 55 | + record_path = os.path.join(root, "RECORD") |
| 56 | + break |
| 57 | + if not record_path: |
| 58 | + raise RuntimeError("RECORD file not found in wheel") |
| 59 | + |
| 60 | + # Copy runtime DLLs into faiss/ and update RECORD. |
| 61 | + new_records = [] |
| 62 | + for dll in sorted(glob.glob(os.path.join(dll_dir, "*.dll"))): |
| 63 | + dll_name = os.path.basename(dll) |
| 64 | + dst = os.path.join(faiss_dir, dll_name) |
| 65 | + if os.path.exists(dst): |
| 66 | + print(f" skip {dll_name} (already in wheel)") |
| 67 | + continue |
| 68 | + shutil.copy2(dll, dst) |
| 69 | + |
| 70 | + with open(dst, "rb") as f: |
| 71 | + digest = hashlib.sha256(f.read()).digest() |
| 72 | + b64 = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() |
| 73 | + size = os.path.getsize(dst) |
| 74 | + new_records.append(f"faiss/{dll_name},sha256={b64},{size}") |
| 75 | + print(f" bundled {dll_name} ({size:,} bytes)") |
| 76 | + |
| 77 | + with open(record_path, "a") as f: |
| 78 | + for rec in new_records: |
| 79 | + f.write(rec + "\n") |
| 80 | + |
| 81 | + # Repack the wheel. |
| 82 | + out_path = os.path.join(dest_dir, wheel_name) |
| 83 | + with zipfile.ZipFile(out_path, "w", zipfile.ZIP_DEFLATED) as zf: |
| 84 | + for root, _dirs, files in os.walk(tmpdir): |
| 85 | + for fname in sorted(files): |
| 86 | + fullpath = os.path.join(root, fname) |
| 87 | + arcname = os.path.relpath(fullpath, tmpdir) |
| 88 | + zf.write(fullpath, arcname) |
| 89 | + |
| 90 | + print(f" repaired wheel written to {out_path}") |
| 91 | + finally: |
| 92 | + shutil.rmtree(tmpdir, ignore_errors=True) |
| 93 | + |
| 94 | + |
| 95 | +if __name__ == "__main__": |
| 96 | + parser = argparse.ArgumentParser() |
| 97 | + parser.add_argument("wheel", help="Path to the input wheel") |
| 98 | + parser.add_argument("dest_dir", help="Directory to write the repaired wheel") |
| 99 | + parser.add_argument( |
| 100 | + "--dll-dir", |
| 101 | + default="C:/mkl/bin", |
| 102 | + help="Directory containing DLLs to bundle", |
| 103 | + ) |
| 104 | + args = parser.parse_args() |
| 105 | + repair(args.wheel, args.dest_dir, args.dll_dir) |
0 commit comments