2 changes: 1 addition & 1 deletion integrations/llama_cpp/LICENSE.txt
@@ -58,7 +58,7 @@ APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

- Copyright [yyyy] [name of copyright owner]
+ Copyright 2024 deepset GmbH

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -1,3 +1,7 @@
+ # SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
+ #
+ # SPDX-License-Identifier: Apache-2.0
+
import json
from collections.abc import Iterator
from datetime import datetime, timezone
@@ -201,7 +205,7 @@ def __init__(
streaming_callback: StreamingCallbackT | None = None,
chat_handler_name: str | None = None,
model_clip_path: str | None = None,
- ):
+ ) -> None:
"""
:param model: The path of a quantized model for text generation, for example, "zephyr-7b-beta.Q4_0.gguf".
If the model path is also specified in the `model_kwargs`, this parameter will be ignored.
@@ -263,7 +267,7 @@ def __init__(
self.model_clip_path = model_clip_path
self._handler = handler

- def warm_up(self):
+ def warm_up(self) -> None:
if self._model is not None:
return

@@ -347,8 +351,7 @@ def run(
- `replies`: The responses from the model
"""
if self._model is None:
- error_msg = "The model has not been loaded. Please call warm_up() before running."
- raise RuntimeError(error_msg)
+ self.warm_up()

if not messages:
return {"replies": []}
@@ -381,7 +384,7 @@ def run(
)

if streaming_callback:
- response_stream = self._model.create_chat_completion(
+ response_stream = self._model.create_chat_completion(  # type: ignore[union-attr]
messages=formatted_messages, tools=llamacpp_tools, **updated_generation_kwargs, stream=True
)
return self._handle_streaming_response(
@@ -391,16 +394,19 @@
) # we know that response_stream is Iterator[CreateChatCompletionStreamResponse]
# because create_chat_completion was called with stream=True, but mypy doesn't know that

- response = self._model.create_chat_completion(
+ response = self._model.create_chat_completion(  # type: ignore[union-attr]
messages=formatted_messages, tools=llamacpp_tools, **updated_generation_kwargs
)
replies = []
if not isinstance(response, dict):
msg = f"Expected a dictionary response, got a different object: {response}"
raise ValueError(msg)

- for choice in response["choices"]:
- chat_message = self._convert_chat_completion_choice_to_chat_message(choice, response)
+ for choice in response["choices"]:  # type: ignore[index]
+ chat_message = self._convert_chat_completion_choice_to_chat_message(
+ choice,
+ response,  # type: ignore[arg-type]
+ )
replies.append(chat_message)
return {"replies": replies}

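With this change, a caller no longer needs to invoke `warm_up()` before `run()`: the chat generator now loads the model lazily on first use. A minimal usage sketch of the resulting call pattern (the GGUF filename below is only a placeholder, not something shipped with this PR):

```python
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator

# Point "model" at a local GGUF file; the filename here is only a placeholder.
generator = LlamaCppChatGenerator(model="openchat-3.5-1210.Q3_K_S.gguf", n_ctx=2048, n_batch=512)

# No explicit generator.warm_up() call is required anymore:
# run() calls warm_up() itself when the model has not been loaded yet.
result = generator.run([ChatMessage.from_user("What is the capital of China?")])
print(result["replies"][0].text)
```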
@@ -1,3 +1,7 @@
+ # SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
+ #
+ # SPDX-License-Identifier: Apache-2.0
+
from typing import Any

from haystack import component, logging
@@ -32,7 +36,7 @@ def __init__(
n_batch: int | None = 512,
model_kwargs: dict[str, Any] | None = None,
generation_kwargs: dict[str, Any] | None = None,
- ):
+ ) -> None:
"""
:param model: The path of a quantized model for text generation, for example, "zephyr-7b-beta.Q4_0.gguf".
If the model path is also specified in the `model_kwargs`, this parameter will be ignored.
@@ -64,7 +68,7 @@ def __init__(
self.generation_kwargs = generation_kwargs
self.model: Llama | None = None

- def warm_up(self):
+ def warm_up(self) -> None:
if self.model is None:
self.model = Llama(**self.model_kwargs)

@@ -84,16 +88,15 @@ def run(
- `meta`: metadata about the request.
"""
if self.model is None:
- error_msg = "The model has not been loaded. Please call warm_up() before running."
- raise RuntimeError(error_msg)
+ self.warm_up()

if not prompt:
return {"replies": []}

# merge generation kwargs from init method with those from run method
updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

- output = self.model.create_completion(prompt=prompt, **updated_generation_kwargs)
+ output = self.model.create_completion(prompt=prompt, **updated_generation_kwargs)  # type: ignore[union-attr]
if not isinstance(output, dict):
msg = f"Expected a dictionary response, got a different object: {output}"
raise ValueError(msg)
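The text generator gets the same treatment. A minimal sketch of the new call pattern (again with a placeholder GGUF filename); calling `warm_up()` up front remains valid and is simply a no-op once the model is loaded:

```python
from haystack_integrations.components.generators.llama_cpp import LlamaCppGenerator

# The filename is a placeholder for any local GGUF model.
generator = LlamaCppGenerator(model="openchat-3.5-1210.Q3_K_S.gguf", n_ctx=512, n_batch=128)

# warm_up() is optional now; run() triggers it when self.model is still None.
result = generator.run("Who created the Python programming language?")
print(result["replies"][0])
```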
17 changes: 4 additions & 13 deletions integrations/llama_cpp/tests/test_chat_generator.py
@@ -1,3 +1,7 @@
+ # SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
+ #
+ # SPDX-License-Identifier: Apache-2.0
+
import json
import os
import urllib.request
@@ -678,7 +682,6 @@ def generator(self, model_path, capsys):

model_path = str(model_path / filename)
generator = LlamaCppChatGenerator(model=model_path, n_ctx=8192, n_batch=512)
- generator.warm_up()
return generator

@pytest.fixture
@@ -898,14 +901,6 @@ def test_ignores_n_batch_if_specified_in_model_kwargs(self):
)
assert generator.model_kwargs["n_batch"] == 1024

- def test_raises_error_without_warm_up(self):
- """
- Test that the generator raises an error if warm_up() is not called before running.
- """
- generator = LlamaCppChatGenerator(model="test_model.gguf", n_ctx=512, n_batch=512)
- with pytest.raises(RuntimeError):
- generator.run("What is the capital of China?")
-
def test_run_with_empty_message(self, generator_mock):
"""
Test that an empty message returns an empty list of replies.
@@ -1179,7 +1174,6 @@ def generator(self, model_path, capsys):
"hf_tokenizer_path": hf_tokenizer_path,
},
)
- generator.warm_up()
return generator

@pytest.mark.integration
@@ -1260,7 +1254,6 @@ def generator(self, model_path, capsys):
"chat_format": "chatml-function-calling",
},
)
- generator.warm_up()
return generator

@pytest.mark.integration
@@ -1326,8 +1319,6 @@ def test_live_run_image_support(self, vision_language_model):
generation_kwargs={"max_tokens": 50, "temperature": 0.1},
)

- generator.warm_up()
-
result = generator.run(messages)

assert "replies" in result
13 changes: 4 additions & 9 deletions integrations/llama_cpp/tests/test_generator.py
@@ -1,3 +1,7 @@
+ # SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
+ #
+ # SPDX-License-Identifier: Apache-2.0
+
import os
import urllib.request
from pathlib import Path
@@ -42,7 +46,6 @@ def generator(self, model_path, capsys):

model_path = str(model_path / filename)
generator = LlamaCppGenerator(model=model_path, n_ctx=128, n_batch=128)
- generator.warm_up()
return generator

@pytest.fixture
@@ -106,14 +109,6 @@ def test_ignores_n_batch_if_specified_in_model_kwargs(self):
generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_batch": 1024})
assert generator.model_kwargs["n_batch"] == 1024

- def test_raises_error_without_warm_up(self):
- """
- Test that the generator raises an error if warm_up() is not called before running.
- """
- generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512)
- with pytest.raises(RuntimeError):
- generator.run("What is the capital of China?")
-
def test_run_with_empty_prompt(self, generator_mock):
"""
Test that an empty prompt returns an empty list of replies.
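The `pytest.raises(RuntimeError)` tests are dropped because that error can no longer occur. For illustration only, the new contract (run() warms the model up on demand) could be exercised with a mock roughly like this; the response dict mimics llama-cpp-python's `create_completion()` output and the whole test is a sketch under those assumptions, not code from this PR:

```python
from unittest.mock import MagicMock

from haystack_integrations.components.generators.llama_cpp import LlamaCppGenerator


def test_run_warms_up_lazily():
    generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512)

    # Fake model object that mimics llama-cpp-python's create_completion() dict response.
    fake_model = MagicMock()
    fake_model.create_completion.return_value = {
        "id": "cmpl-0",
        "object": "text_completion",
        "created": 0,
        "model": "test_model.gguf",
        "choices": [{"text": "Beijing", "index": 0, "logprobs": None, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": 7, "completion_tokens": 1, "total_tokens": 8},
    }

    def fake_warm_up():
        # Stand in for the real warm_up(): pretend the model was loaded.
        generator.model = fake_model

    generator.warm_up = MagicMock(side_effect=fake_warm_up)

    result = generator.run("What is the capital of China?")

    # run() should have triggered the lazy warm-up exactly once and produced replies.
    generator.warm_up.assert_called_once()
    assert "replies" in result and result["replies"]
```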