
Commit b13e0bc

Added support for overriding tensor buffer types
1 parent 99f2ebf commit b13e0bc

2 files changed: +128 −2 lines changed

llama_cpp/llama.py

Lines changed: 42 additions & 0 deletions
@@ -115,6 +115,7 @@ def __init__(
         # Misc
         spm_infill: bool = False,
         verbose: bool = True,
+        override_tensor: Optional[str] = None,
         # Extra Params
         **kwargs,  # type: ignore
     ):
@@ -187,6 +188,7 @@ def __init__(
             type_k: KV cache data type for K (default: f16)
             type_v: KV cache data type for V (default: f16)
             spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
+            override_tensor: <tensor name pattern>=<buffer type>,...

         Raises:
             ValueError: If the model path does not exist.
@@ -364,6 +366,46 @@ def __init__(

         self.spm_infill = spm_infill

+        self._c_tensor_buft_overrides = None
+        if override_tensor is not None:
+
+            buft_overrides = []
+            buft_list = {}
+            # Enumerate all devices and add their buffer types to the list
+            for i in range(llama_cpp.ggml_backend_dev_count()):
+                dev = llama_cpp.ggml_backend_dev_get(i)
+                buft = llama_cpp.ggml_backend_dev_buffer_type(dev)
+                if buft:
+                    buft_name = llama_cpp.ggml_backend_buft_name(buft).decode('utf-8')
+                    buft_list[buft_name] = buft
+
+            # Process overrides
+            for override in override_tensor.split(','):
+                pos = override.find('=')
+                if pos == -1:
+                    raise ValueError("invalid value")
+
+                tensor_name = override[:pos]
+                buffer_type = override[pos + 1:]
+
+                if buffer_type not in buft_list:
+                    print("Available buffer types:")
+                    for name in buft_list:
+                        print(f" {name}")
+                    raise ValueError("unknown buffer type")
+
+                buft_overrides.append(
+                    llama_cpp.llama_model_tensor_buft_override(
+                        pattern=tensor_name.encode('utf-8'),
+                        buft=buft_list[buffer_type]
+                    )
+                )
+            array_type = llama_cpp.llama_model_tensor_buft_override * (len(buft_overrides) + 1)
+            self._c_tensor_buft_overrides = array_type(
+                *buft_overrides
+            )
+            self.model_params.tensor_buft_overrides = self._c_tensor_buft_overrides
+
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")

llama_cpp/llama_cpp.py

Lines changed: 86 additions & 2 deletions
@@ -640,10 +640,94 @@ class llama_model_kv_override(ctypes.Structure):
         value: Union[int, float, bool, bytes]


+
+# struct ggml_backend_buffer_type_i {
+#     const char *          (*get_name)      (ggml_backend_buffer_type_t buft);
+#     // allocate a buffer of this type
+#     ggml_backend_buffer_t (*alloc_buffer)  (ggml_backend_buffer_type_t buft, size_t size);
+#     // tensor alignment
+#     size_t                (*get_alignment) (ggml_backend_buffer_type_t buft);
+#     // (optional) max buffer size that can be allocated (defaults to SIZE_MAX)
+#     size_t                (*get_max_size)  (ggml_backend_buffer_type_t buft);
+#     // (optional) data size needed to allocate the tensor, including padding (defaults to ggml_nbytes)
+#     size_t                (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
+#     // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false)
+#     bool                  (*is_host)       (ggml_backend_buffer_type_t buft);
+# };
+class ggml_backend_buffer_type_i(ctypes.Structure):
+    _fields_ = [
+        ("get_name", ctypes.c_void_p),        # NOTE: Unused
+        ("alloc_buffer", ctypes.c_void_p),    # NOTE: Unused
+        ("get_alignment", ctypes.c_void_p),   # NOTE: Unused
+        ("get_max_size", ctypes.c_void_p),    # NOTE: Unused
+        ("get_alloc_size", ctypes.c_void_p),  # NOTE: Unused
+        ("is_host", ctypes.c_void_p)          # NOTE: Unused
+    ]
+
+# typedef struct ggml_backend_device * ggml_backend_dev_t;
+ggml_backend_dev_t = ctypes.c_void_p  # NOTE: Unused
+
+# struct ggml_backend_buffer_type {
+#     struct ggml_backend_buffer_type_i  iface;
+#     ggml_backend_dev_t                 device;
+#     void *                             context;
+# };
+class ggml_backend_buffer_type(ctypes.Structure):
+    _fields_ = [
+        ("iface", ggml_backend_buffer_type_i),
+        ("device", ggml_backend_dev_t),
+        ("context", ctypes.c_void_p)
+    ]
+
+# typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+ggml_backend_buffer_type_t = ctypes.POINTER(ggml_backend_buffer_type)
+
 # struct llama_model_tensor_buft_override {
 #     const char * pattern;
 #     ggml_backend_buffer_type_t buft;
 # };
+class llama_model_tensor_buft_override(ctypes.Structure):
+    _fields_ = [
+        ("pattern", ctypes.c_char_p),
+        ("buft", ggml_backend_buffer_type_t),
+    ]
+
+
+# GGML_API size_t ggml_backend_dev_count(void);
+@ctypes_function(
+    "ggml_backend_dev_count",
+    [],
+    ctypes.c_size_t,
+)
+def ggml_backend_dev_count() -> int:
+    ...
+
+# GGML_API ggml_backend_dev_t ggml_backend_dev_get(size_t index);
+@ctypes_function(
+    "ggml_backend_dev_get",
+    [ctypes.c_size_t],
+    ggml_backend_dev_t,
+)
+def ggml_backend_dev_get(index: int, /) -> ggml_backend_dev_t:
+    ...
+
+# GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
+@ctypes_function(
+    "ggml_backend_dev_buffer_type",
+    [ggml_backend_dev_t],
+    ggml_backend_buffer_type_t,
+)
+def ggml_backend_dev_buffer_type(device: ggml_backend_dev_t, /) -> ggml_backend_buffer_type_t:
+    ...
+
+# GGML_API const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft);
+@ctypes_function(
+    "ggml_backend_buft_name",
+    [ggml_backend_buffer_type_t],
+    ctypes.c_char_p,
+)
+def ggml_backend_buft_name(buft: ggml_backend_buffer_type_t, /) -> bytes:
+    ...


 # struct llama_model_params {
@@ -703,7 +787,7 @@ class llama_model_params(ctypes.Structure):

     if TYPE_CHECKING:
         devices: CtypesArray[ctypes.c_void_p]  # NOTE: unused
-        tensor_buft_overrides: CtypesArray[llama_model_tensor_buft_override]  # NOTE: unused
+        tensor_buft_overrides: CtypesArray[llama_model_tensor_buft_override]
         n_gpu_layers: int
         split_mode: int
         main_gpu: int
@@ -718,7 +802,7 @@ class llama_model_params(ctypes.Structure):

     _fields_ = [
         ("devices", ctypes.c_void_p),  # NOTE: unnused
-        ("tensor_buft_overrides", ctypes.c_void_p),  # NOTE: unused
+        ("tensor_buft_overrides", ctypes.POINTER(llama_model_tensor_buft_override)),
         ("n_gpu_layers", ctypes.c_int32),
         ("split_mode", ctypes.c_int),
         ("main_gpu", ctypes.c_int32),
