Skip to content

Commit 7e40813

Browse files
committed
Tests work
1 parent de2e095 commit 7e40813

3 files changed

Lines changed: 27 additions & 12 deletions

File tree

squeez/encoder/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
"""Encoder-based line classifier for tool output extraction."""
22

3-
from squeez.encoder.model import SqueezEncoderConfig, SqueezEncoderForLineClassification
4-
53
__all__ = ["SqueezEncoderConfig", "SqueezEncoderForLineClassification"]
4+
5+
6+
def __getattr__(name: str):
7+
"""Lazily import encoder model classes so lightweight helpers stay optional."""
8+
if name in __all__:
9+
from squeez.encoder.model import SqueezEncoderConfig, SqueezEncoderForLineClassification
10+
11+
return {
12+
"SqueezEncoderConfig": SqueezEncoderConfig,
13+
"SqueezEncoderForLineClassification": SqueezEncoderForLineClassification,
14+
}[name]
15+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

squeez/encoder/chunking.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,16 @@
88

99
from __future__ import annotations
1010

11-
from transformers import PreTrainedTokenizer
11+
from typing import Protocol
12+
13+
14+
class _TokenizerLike(Protocol):
15+
def __call__(self, text: str, **kwargs) -> dict:
16+
...
1217

1318

1419
def encode_text(
15-
tokenizer: PreTrainedTokenizer,
20+
tokenizer: _TokenizerLike,
1621
text: str,
1722
truncation: bool = False,
1823
max_length: int | None = None,
@@ -32,7 +37,7 @@ def encode_text(
3237

3338

3439
def chunk_output_lines(
35-
tokenizer: PreTrainedTokenizer,
40+
tokenizer: _TokenizerLike,
3641
output_lines: list[str],
3742
max_tokens_per_chunk: int,
3843
) -> tuple[list[list[int]], list[int]]:

tests/test_extractor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,19 @@ def test_format_prompt_truncates_long_task():
2020
prompt = _format_prompt(long_task, "output")
2121
assert len(long_task) > 3000
2222
assert "..." in prompt
23-
# Should be truncated to 3000 + "..."
24-
task_section = prompt.split("Task: ")[1].split("\n\n")[0]
23+
task_section = prompt.split("<task>\n", 1)[1].split("\n</task>", 1)[0]
2524
assert len(task_section) == 3003 # 3000 + "..."
2625

2726

2827
def test_format_prompt_empty_task():
2928
prompt = _format_prompt("", "some output")
30-
assert "Task: \n" in prompt
29+
assert "<task>\n\n</task>" in prompt
3130
assert "some output" in prompt
3231

3332

34-
def test_system_prompt_has_json_format():
33+
def test_system_prompt_has_relevant_lines_format():
3534
assert "relevant_lines" in SYSTEM_PROMPT
36-
assert "JSON" in SYSTEM_PROMPT
35+
assert "<relevant_lines>" in SYSTEM_PROMPT
3736

3837

3938
def test_load_config_returns_dict():
@@ -53,8 +52,9 @@ def test_assign_split(self):
5352
from squeez.data.sample_assembler import _assign_split
5453

5554
assert _assign_split("django__django") == "train"
56-
assert _assign_split("pydata__xarray") == "eval"
57-
assert _assign_split("pallets__flask") == "eval"
55+
assert _assign_split("pydata__xarray") == "test"
56+
assert _assign_split("pallets__flask") == "test"
57+
assert _assign_split("psf__requests") == "dev"
5858
assert _assign_split("scikit-learn__scikit-learn") == "train"
5959

6060
def test_format_prompt(self):

0 commit comments

Comments
 (0)