|
1 | | -from typing import Generator |
2 | | -from unittest.mock import MagicMock, patch |
| 1 | +# ruff: noqa: DTZ005 |
| 2 | + |
| 3 | +import uuid |
| 4 | +from datetime import datetime |
| 5 | +from unittest.mock import MagicMock |
3 | 6 |
|
4 | 7 | import pytest |
5 | | -from codex import Codex as _Codex |
| 8 | +from codex.types.project_return_schema import Config, ProjectReturnSchema |
| 9 | +from codex.types.users.myself.user_organizations_schema import UserOrganizationsSchema |
6 | 10 |
|
7 | 11 | from cleanlab_codex.codex import Codex |
| 12 | +from cleanlab_codex.internal.project import MissingProjectIdError |
| 13 | +from cleanlab_codex.types.entry import Entry, EntryCreate |
| 14 | +from cleanlab_codex.types.organization import Organization |
| 15 | +from cleanlab_codex.types.project import ProjectConfig |
| 16 | + |
| 17 | +FAKE_PROJECT_ID = 1 |
| 18 | +FAKE_USER_ID = "Test User" |
| 19 | +FAKE_ORGANIZATION_ID = "Test Organization" |
| 20 | +FAKE_PROJECT_NAME = "Test Project" |
| 21 | +FAKE_PROJECT_DESCRIPTION = "Test Description" |
| 22 | +DEFAULT_PROJECT_CONFIG = ProjectConfig() |
| 23 | + |
| 24 | + |
| 25 | +def test_list_organizations(mock_client: MagicMock): |
| 26 | + mock_client.users.myself.organizations.list.return_value = UserOrganizationsSchema( |
| 27 | + organizations=[ |
| 28 | + Organization( |
| 29 | + organization_id=FAKE_ORGANIZATION_ID, |
| 30 | + created_at=datetime.now(), |
| 31 | + updated_at=datetime.now(), |
| 32 | + user_id=FAKE_USER_ID, |
| 33 | + ) |
| 34 | + ], |
| 35 | + ) |
| 36 | + codex = Codex("") |
| 37 | + organizations = codex.list_organizations() |
| 38 | + assert len(organizations) == 1 |
| 39 | + assert organizations[0].organization_id == FAKE_ORGANIZATION_ID |
| 40 | + assert organizations[0].user_id == FAKE_USER_ID |
| 41 | + |
| 42 | + |
| 43 | +def test_create_project(mock_client: MagicMock): |
| 44 | + mock_client.projects.create.return_value = ProjectReturnSchema( |
| 45 | + id=FAKE_PROJECT_ID, |
| 46 | + config=Config(), |
| 47 | + created_at=datetime.now(), |
| 48 | + created_by_user_id=FAKE_USER_ID, |
| 49 | + name=FAKE_PROJECT_NAME, |
| 50 | + organization_id=FAKE_ORGANIZATION_ID, |
| 51 | + updated_at=datetime.now(), |
| 52 | + description=FAKE_PROJECT_DESCRIPTION, |
| 53 | + ) |
| 54 | + codex = Codex("") |
| 55 | + project_id = codex.create_project(FAKE_PROJECT_NAME, FAKE_ORGANIZATION_ID, FAKE_PROJECT_DESCRIPTION) |
| 56 | + mock_client.projects.create.assert_called_once_with( |
| 57 | + config=DEFAULT_PROJECT_CONFIG, |
| 58 | + organization_id=FAKE_ORGANIZATION_ID, |
| 59 | + name=FAKE_PROJECT_NAME, |
| 60 | + description=FAKE_PROJECT_DESCRIPTION, |
| 61 | + ) |
| 62 | + assert project_id == FAKE_PROJECT_ID |
8 | 63 |
|
9 | | -fake_project_id = 1 |
10 | 64 |
|
| 65 | +def test_add_entries(mock_client: MagicMock): |
| 66 | + answered_entry_create = EntryCreate( |
| 67 | + question="What is the capital of France?", |
| 68 | + answer="Paris", |
| 69 | + ) |
| 70 | + unanswered_entry_create = EntryCreate( |
| 71 | + question="What is the capital of Germany?", |
| 72 | + ) |
| 73 | + codex = Codex("") |
| 74 | + codex.add_entries([answered_entry_create, unanswered_entry_create], project_id=FAKE_PROJECT_ID) |
11 | 75 |
|
12 | | -@pytest.fixture |
13 | | -def mock_client() -> Generator[_Codex, None, None]: |
14 | | - with patch("cleanlab_codex.codex.init_codex_client", return_value=MagicMock()) as mock: |
15 | | - yield mock |
| 76 | + for call, entry in zip( |
| 77 | + mock_client.projects.entries.create.call_args_list, |
| 78 | + [answered_entry_create, unanswered_entry_create], |
| 79 | + ): |
| 80 | + assert call.args[0] == FAKE_PROJECT_ID |
| 81 | + assert call.kwargs["question"] == entry["question"] |
| 82 | + assert call.kwargs["answer"] == entry.get("answer") |
16 | 83 |
|
17 | 84 |
|
18 | | -def test_query_read_only(mock_client: _Codex): |
19 | | - mock_client.projects.entries.query.return_value = None # type: ignore |
| 85 | +def test_create_project_access_key(mock_client: MagicMock): |
20 | 86 | codex = Codex("") |
21 | | - res = codex.query("What is the capital of France?", read_only=True, project_id=fake_project_id) |
22 | | - mock_client.projects.entries.query.assert_called_once_with( # type: ignore |
23 | | - fake_project_id, "What is the capital of France?" |
| 87 | + access_key_name = "Test Access Key" |
| 88 | + access_key_description = "Test Access Key Description" |
| 89 | + codex.create_project_access_key(FAKE_PROJECT_ID, access_key_name, access_key_description) |
| 90 | + mock_client.projects.access_keys.create.assert_called_once_with( |
| 91 | + project_id=FAKE_PROJECT_ID, |
| 92 | + name=access_key_name, |
| 93 | + description=access_key_description, |
24 | 94 | ) |
25 | | - mock_client.projects.entries.add_question.assert_not_called() # type: ignore |
| 95 | + |
| 96 | + |
| 97 | +def test_query_no_project_id(mock_client: MagicMock): |
| 98 | + mock_client.access_key = None |
| 99 | + codex = Codex("") |
| 100 | + |
| 101 | + with pytest.raises(MissingProjectIdError): |
| 102 | + codex.query("What is the capital of France?") |
| 103 | + |
| 104 | + |
| 105 | +def test_query_read_only(mock_client: MagicMock): |
| 106 | + mock_client.access_key = None |
| 107 | + mock_client.projects.entries.query.return_value = None |
| 108 | + |
| 109 | + codex = Codex("") |
| 110 | + res = codex.query("What is the capital of France?", read_only=True, project_id=FAKE_PROJECT_ID) |
| 111 | + mock_client.projects.entries.query.assert_called_once_with( |
| 112 | + FAKE_PROJECT_ID, question="What is the capital of France?" |
| 113 | + ) |
| 114 | + mock_client.projects.entries.add_question.assert_not_called() |
26 | 115 | assert res == (None, None) |
| 116 | + |
| 117 | + |
| 118 | +def test_query_question_found_fallback_answer(mock_client: MagicMock): |
| 119 | + unanswered_entry = Entry( |
| 120 | + id=str(uuid.uuid4()), |
| 121 | + created_at=datetime.now(), |
| 122 | + question="What is the capital of France?", |
| 123 | + answer=None, |
| 124 | + ) |
| 125 | + mock_client.projects.entries.query.return_value = unanswered_entry |
| 126 | + codex = Codex("") |
| 127 | + res = codex.query("What is the capital of France?", project_id=FAKE_PROJECT_ID) |
| 128 | + assert res == (None, unanswered_entry) |
| 129 | + |
| 130 | + |
| 131 | +def test_query_question_not_found_fallback_answer(mock_client: MagicMock): |
| 132 | + mock_client.projects.entries.query.return_value = None |
| 133 | + mock_client.projects.entries.add_question.return_value = None |
| 134 | + |
| 135 | + codex = Codex("") |
| 136 | + res = codex.query("What is the capital of France?", fallback_answer="Paris") |
| 137 | + assert res == ("Paris", None) |
| 138 | + |
| 139 | + |
| 140 | +def test_query_answer_found(mock_client: MagicMock): |
| 141 | + answered_entry = Entry( |
| 142 | + id=str(uuid.uuid4()), |
| 143 | + created_at=datetime.now(), |
| 144 | + question="What is the capital of France?", |
| 145 | + answer="Paris", |
| 146 | + ) |
| 147 | + mock_client.projects.entries.query.return_value = answered_entry |
| 148 | + codex = Codex("") |
| 149 | + res = codex.query("What is the capital of France?", project_id=FAKE_PROJECT_ID) |
| 150 | + assert res == ("Paris", answered_entry) |
0 commit comments