Skip to content

Commit 63f608d

Browse files
feat - improve classify to use list (#14)
* comments * added tests * fix prompt * fix tests * fix tests * lint * make input an array
1 parent e52d8d3 commit 63f608d

2 files changed

Lines changed: 11 additions & 10 deletions

File tree

src/litai/client.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -452,27 +452,28 @@ def if_(self, input: str, question: str) -> bool:
452452
response = self.chat(prompt).strip().lower()
453453
return "yes" in response
454454

455-
def classify(self, input: str, *choices: str) -> str:
456-
"""Returns the label the model chooses from the given options.
455+
def classify(self, input: str, choices: List[str]) -> str:
456+
"""Returns the label the model chooses from the given list of options.
457457
458458
Example:
459-
llm.classify("This product sucks.", "positive", "negative") → "negative"
459+
llm.classify("This product sucks.", ["positive", "negative", "neutral"]) → "negative"
460460
"""
461461
normalized_choices = [c.strip().lower() for c in choices]
462462
choices_str = ", ".join(normalized_choices)
463+
463464
prompt = f"""
464-
You are given this input
465+
You are given this input:
465466
<input>
466-
{input}
467+
{input.strip()}
467468
</input>
468469
469470
And the following choices:
470471
<choices>
471472
{choices_str}
472473
</choices>
473474
474-
Answer with only one of the choices
475-
"""
475+
Answer with only one of the choices.
476+
""".strip()
476477

477478
response = self.chat(prompt).strip().lower()
478479

tests/test_llm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,13 @@ def test_llm_classify_method(mock_sdkllm_class):
329329
mock_sdkllm_instance.chat.side_effect = ["positive", "negative", "neutral"]
330330

331331
# Test simple classification
332-
result = llm.classify("this movie was great!", "positive", "negative")
332+
result = llm.classify("this movie was great!", ["positive", "negative"])
333333
assert result == "positive"
334334

335335
# Test another classification
336-
result = llm.classify("this movie was awful.", "positive", "negative")
336+
result = llm.classify("this movie was awful.", ["positive", "negative"])
337337
assert result == "negative"
338338

339339
# Test with multiple classes
340-
result = llm.classify("it was okay.", "positive", "negative", "neutral")
340+
result = llm.classify("it was okay.", ["positive", "negative", "neutral"])
341341
assert result == "neutral"

0 commit comments

Comments
 (0)