Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 38 additions & 16 deletions src/litai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,24 +425,32 @@ def list_conversations(self) -> List[str]:
raise ValueError("No model loaded")
return self._llm.list_conversations()

def if_(self, input: str, choice1: Optional[str] = None, choice2: Optional[str] = None) -> bool:
"""Returns True if the model selects the first choice, False otherwise.
Defaults to a yes/no binary decision.
"""
choice1 = (choice1 or "yes").strip().lower()
choice2 = (choice2 or "no").strip().lower()
def if_(self, input: str, question: str) -> bool:
"""Ask a yes no question and return a True or False. Perfect for making decisions.

prompt = f"Reply with only one of [{choice1!r}, {choice2!r}].\n\nInput: {input.strip()}\nAnswer:"
Example:
review = 'this TV is awful'
if llm.if_(review, "is this a positive review?"):
print("good sentiment")
else:
print("bad sentiment")
"""
prompt = f"""
Here is an input:
<input>
{input.strip()}
</input>

And a question:
<question>
{question.strip()}
</question>

Answer with only 'yes' or 'no'.
"""

response = self.chat(prompt).strip().lower()

if response == choice1:
return True
elif response == choice2:
return False
else:
# fallback: assume choice1 if unclear
return True
return "yes" in response

def classify(self, input: str, *choices: str) -> str:
"""Returns the label the model chooses from the given options.
Expand All @@ -451,12 +459,26 @@ def classify(self, input: str, *choices: str) -> str:
llm.classify("This product sucks.", "positive", "negative") → "negative"
"""
normalized_choices = [c.strip().lower() for c in choices]
prompt = f"Reply with only one of {normalized_choices!r}.\n\nInput: {input.strip()}\nAnswer:"
choices_str = ", ".join(normalized_choices)
prompt = f"""
You are given this input
<input>
{input}
</input>

And the following choices:
<choices>
{choices_str}
</choices>

Answer with only one of the choices
"""

response = self.chat(prompt).strip().lower()

if response in normalized_choices:
return response

# fallback: return first choice if not matched
return normalized_choices[0]

Expand Down
100 changes: 28 additions & 72 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_initialization_with_config_file(monkeypatch):
monkeypatch.setattr("litai.client.SDKLLM", mock_llm_instance)
LLM(model="openai/gpt-4", lightning_api_key="my-key", lightning_user_id="my-user-id")
assert os.getenv("LIGHTNING_API_KEY") == "my-key"
assert os.getenv("LIGHTNING_USER_ID") == "my-user_id"
assert os.getenv("LIGHTNING_USER_ID") == "my-user-id"


@patch("litai.client.SDKLLM")
Expand Down Expand Up @@ -290,96 +290,52 @@ def mock_auth_constructor():


@patch("litai.client.SDKLLM")
def test_llm_if_method(mock_llm_class):
def test_llm_if_method(mock_sdkllm_class):
"""Test the LLM if_ method."""
from litai.client import LLM as LLMCLIENT

LLMCLIENT._sdkllm_cache.clear()
mock_llm_instance = MagicMock()

# Test case where the condition is true
mock_llm_instance.chat.return_value = "yes"
mock_llm_class.return_value = mock_llm_instance

# Instantiate LLM first
llm = LLM(model="openai/gpt-4")
assert llm.if_("is it true?") is True
mock_llm_instance.chat.assert_called_with(
prompt="is it true?\n\nreply 'yes' if the answer is yes, otherwise reply 'no'.",
system_prompt=None,
max_completion_tokens=500,
images=None,
conversation=None,
metadata=None,
stream=False,
full_response=False,
)


# Get the actual mock instance used by llm
mock_sdkllm_instance = mock_sdkllm_class.return_value

# Test case where the condition is true
mock_sdkllm_instance.chat.side_effect = ["yes", "no", " Yes "] # Use side_effect for multiple calls
assert llm.if_("this review is great", "is this a positive review?") is True

# Test case where the condition is false
mock_llm_instance.chat.return_value = "no"
assert llm.if_("is it false?") is False
mock_llm_instance.chat.assert_called_with(
prompt="is it false?\n\nreply 'yes' if the answer is yes, otherwise reply 'no'.",
system_prompt=None,
max_completion_tokens=500,
images=None,
conversation=None,
metadata=None,
stream=False,
full_response=False,
)

assert llm.if_("this review is terrible", "is this a positive review?") is False

# Test case with different capitalization/spacing
mock_llm_instance.chat.return_value = " Yes "
assert llm.if_("is it a positive response?") is True
assert llm.if_("the product is amazing", "is it a positive response?") is True


@patch("litai.client.SDKLLM")
def test_llm_classify_method(mock_llm_class):
def test_llm_classify_method(mock_sdkllm_class):
"""Test the LLM classify method."""
from litai.client import LLM as LLMCLIENT

LLMCLIENT._sdkllm_cache.clear()
mock_llm_instance = MagicMock()

# Test a simple classification
mock_llm_instance.chat.return_value = "positive"
mock_llm_class.return_value = mock_llm_instance

llm = LLM(model="openai/gpt-4")

# Get the actual mock instance used by llm
mock_sdkllm_instance = mock_sdkllm_class.return_value

# Use side_effect to return different values for sequential calls
mock_sdkllm_instance.chat.side_effect = ["positive", "negative", "neutral"]

# Test simple classification
result = llm.classify("this movie was great!", "positive", "negative")
assert result == "positive"
mock_llm_instance.chat.assert_called_with(
prompt="this movie was great!\n\nclassify the input as one of these: positive, negative. reply with only the class.",
system_prompt=None,
max_completion_tokens=500,
images=None,
conversation=None,
metadata=None,
stream=False,
full_response=False,
)

# Test another classification
mock_llm_instance.chat.return_value = "negative"
result = llm.classify("this movie was awful.", "positive", "negative")
assert result == "negative"
mock_llm_instance.chat.assert_called_with(
prompt="this movie was awful.\n\nclassify the input as one of these: positive, negative. reply with only the class.",
system_prompt=None,
max_completion_tokens=500,
images=None,
conversation=None,
metadata=None,
stream=False,
full_response=False,
)

# Test with multiple classes
mock_llm_instance.chat.return_value = "neutral"
result = llm.classify("it was okay.", "positive", "negative", "neutral")
assert result == "neutral"
mock_llm_instance.chat.assert_called_with(
prompt="it was okay.\n\nclassify the input as one of these: positive, negative, neutral. reply with only the class.",
system_prompt=None,
max_completion_tokens=500,
images=None,
conversation=None,
metadata=None,
stream=False,
full_response=False,
)
Loading