|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -import sys |
4 | 3 | import os
|
5 | 4 | import ctypes
|
6 |
| -import functools |
7 | 5 | import pathlib
|
8 | 6 |
|
9 | 7 | from typing import (
|
10 |
| - Any, |
11 | 8 | Callable,
|
12 |
| - List, |
13 | 9 | Union,
|
14 | 10 | NewType,
|
15 | 11 | Optional,
|
16 | 12 | TYPE_CHECKING,
|
17 |
| - TypeVar, |
18 |
| - Generic, |
19 | 13 | )
|
20 |
| -from typing_extensions import TypeAlias |
21 | 14 |
|
| 15 | +from llama_cpp._ctypes_extensions import ( |
| 16 | + load_shared_library, |
| 17 | + byref, |
| 18 | + ctypes_function_for_shared_library, |
| 19 | +) |
22 | 20 |
|
23 |
| -# Load the library |
24 |
| -def _load_shared_library(lib_base_name: str): |
25 |
| - # Construct the paths to the possible shared library names |
26 |
| - _base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" |
27 |
| - # Searching for the library in the current directory under the name "libllama" (default name |
28 |
| - # for llamacpp) and "llama" (default name for this repo) |
29 |
| - _lib_paths: List[pathlib.Path] = [] |
30 |
| - # Determine the file extension based on the platform |
31 |
| - if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): |
32 |
| - _lib_paths += [ |
33 |
| - _base_path / f"lib{lib_base_name}.so", |
34 |
| - ] |
35 |
| - elif sys.platform == "darwin": |
36 |
| - _lib_paths += [ |
37 |
| - _base_path / f"lib{lib_base_name}.so", |
38 |
| - _base_path / f"lib{lib_base_name}.dylib", |
39 |
| - ] |
40 |
| - elif sys.platform == "win32": |
41 |
| - _lib_paths += [ |
42 |
| - _base_path / f"{lib_base_name}.dll", |
43 |
| - _base_path / f"lib{lib_base_name}.dll", |
44 |
| - ] |
45 |
| - else: |
46 |
| - raise RuntimeError("Unsupported platform") |
47 |
| - |
48 |
| - if "LLAMA_CPP_LIB" in os.environ: |
49 |
| - lib_base_name = os.environ["LLAMA_CPP_LIB"] |
50 |
| - _lib = pathlib.Path(lib_base_name) |
51 |
| - _base_path = _lib.parent.resolve() |
52 |
| - _lib_paths = [_lib.resolve()] |
53 |
| - |
54 |
| - cdll_args = dict() # type: ignore |
55 |
| - |
56 |
| - # Add the library directory to the DLL search path on Windows (if needed) |
57 |
| - if sys.platform == "win32": |
58 |
| - os.add_dll_directory(str(_base_path)) |
59 |
| - os.environ["PATH"] = str(_base_path) + os.pathsep + os.environ["PATH"] |
60 |
| - |
61 |
| - if sys.platform == "win32" and sys.version_info >= (3, 8): |
62 |
| - os.add_dll_directory(str(_base_path)) |
63 |
| - if "CUDA_PATH" in os.environ: |
64 |
| - os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin")) |
65 |
| - os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib")) |
66 |
| - if "HIP_PATH" in os.environ: |
67 |
| - os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "bin")) |
68 |
| - os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "lib")) |
69 |
| - cdll_args["winmode"] = ctypes.RTLD_GLOBAL |
70 |
| - |
71 |
| - # Try to load the shared library, handling potential errors |
72 |
| - for _lib_path in _lib_paths: |
73 |
| - if _lib_path.exists(): |
74 |
| - try: |
75 |
| - return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore |
76 |
| - except Exception as e: |
77 |
| - raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}") |
78 |
| - |
79 |
| - raise FileNotFoundError( |
80 |
| - f"Shared library with base name '{lib_base_name}' not found" |
| 21 | +if TYPE_CHECKING: |
| 22 | + from llama_cpp._ctypes_extensions import ( |
| 23 | + CtypesCData, |
| 24 | + CtypesArray, |
| 25 | + CtypesPointer, |
| 26 | + CtypesVoidPointer, |
| 27 | + CtypesRef, |
| 28 | + CtypesPointerOrRef, |
| 29 | + CtypesFuncPointer, |
81 | 30 | )
|
82 | 31 |
|
83 | 32 |
|
84 | 33 | # Specify the base name of the shared library to load
|
85 | 34 | _lib_base_name = "llama"
|
86 |
| - |
| 35 | +_override_base_path = os.environ.get("LLAMA_CPP_LIB_PATH") |
| 36 | +_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _override_base_path is None else pathlib.Path(_override_base_path) |
87 | 37 | # Load the library
|
88 |
| -_lib = _load_shared_library(_lib_base_name) |
89 |
| - |
90 |
| - |
91 |
| -# ctypes sane type hint helpers |
92 |
| -# |
93 |
| -# - Generic Pointer and Array types |
94 |
| -# - PointerOrRef type with a type hinted byref function |
95 |
| -# |
96 |
| -# NOTE: Only use these for static type checking not for runtime checks |
97 |
| -# no good will come of that |
98 |
| - |
99 |
| -if TYPE_CHECKING: |
100 |
| - CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore |
101 |
| - |
102 |
| - CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore |
103 |
| - |
104 |
| - CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore |
105 |
| - |
106 |
| - CtypesVoidPointer: TypeAlias = ctypes.c_void_p |
107 |
| - |
108 |
| - class CtypesRef(Generic[CtypesCData]): |
109 |
| - pass |
110 |
| - |
111 |
| - CtypesPointerOrRef: TypeAlias = Union[ |
112 |
| - CtypesPointer[CtypesCData], CtypesRef[CtypesCData] |
113 |
| - ] |
114 |
| - |
115 |
| - CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore |
116 |
| - |
117 |
| -F = TypeVar("F", bound=Callable[..., Any]) |
118 |
| - |
119 |
| - |
120 |
| -def ctypes_function_for_shared_library(lib: ctypes.CDLL): |
121 |
| - def ctypes_function( |
122 |
| - name: str, argtypes: List[Any], restype: Any, enabled: bool = True |
123 |
| - ): |
124 |
| - def decorator(f: F) -> F: |
125 |
| - if enabled: |
126 |
| - func = getattr(lib, name) |
127 |
| - func.argtypes = argtypes |
128 |
| - func.restype = restype |
129 |
| - functools.wraps(f)(func) |
130 |
| - return func |
131 |
| - else: |
132 |
| - return f |
133 |
| - |
134 |
| - return decorator |
135 |
| - |
136 |
| - return ctypes_function |
137 |
| - |
| 38 | +_lib = load_shared_library(_lib_base_name, _base_path) |
138 | 39 |
|
139 | 40 | ctypes_function = ctypes_function_for_shared_library(_lib)
|
140 | 41 |
|
141 | 42 |
|
142 |
| -def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]: |
143 |
| - """Type-annotated version of ctypes.byref""" |
144 |
| - ... |
145 |
| - |
146 |
| - |
147 |
| -byref = ctypes.byref # type: ignore |
148 |
| - |
149 | 43 | # from ggml.h
|
150 | 44 | # // NOTE: always add types at the end of the enum to keep backward compatibility
|
151 | 45 | # enum ggml_type {
|
|
0 commit comments