diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 8a80dd519..67246b344 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -37,7 +37,7 @@ jobs: run: uv run python3 -m pytest --runslow . - name: Check Python Types - run: uvx ty check + run: uv run ty check - name: Build Core run: uv build diff --git a/.github/workflows/format_and_lint.yml b/.github/workflows/format_and_lint.yml index e1d6aabc8..c81156cf8 100644 --- a/.github/workflows/format_and_lint.yml +++ b/.github/workflows/format_and_lint.yml @@ -40,12 +40,12 @@ jobs: - name: Lint with ruff run: | - uvx ruff check + uv run ruff check - name: Format with ruff run: | - uvx ruff format --check . + uv run ruff format --check . - name: Typecheck with ty run: | - uvx ty check + uv run ty check diff --git a/.gitignore b/.gitignore index a4b8f5f71..ce76a0841 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,5 @@ libs/server/build dist/ .mcp.json + +test_output/ diff --git a/app/desktop/WinInnoSetup.iss b/app/desktop/WinInnoSetup.iss index 085633dc3..d9c0b14d6 100644 --- a/app/desktop/WinInnoSetup.iss +++ b/app/desktop/WinInnoSetup.iss @@ -3,7 +3,7 @@ #define MyAppPath "build\dist\Kiln" #define MyAppName "Kiln" -#define MyAppVersion "0.24.0" +#define MyAppVersion "0.25.0" #define MyAppPublisher "Chesterfield Laboratories Inc" #define MyAppURL "https://kiln.tech" #define MyAppExeName "Kiln.exe" diff --git a/app/desktop/pyproject.toml b/app/desktop/pyproject.toml index 3f8a06ead..1de7056d5 100644 --- a/app/desktop/pyproject.toml +++ b/app/desktop/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "kiln-studio-desktop" -version = "0.24.0" +version = "0.25.0" description = "The Kiln Desktop App. 
Download from https://kiln.tech" requires-python = ">=3.10" dependencies = [ diff --git a/app/desktop/studio_server/test_copilot_api.py b/app/desktop/studio_server/test_copilot_api.py index 19fa0822a..e9982a232 100644 --- a/app/desktop/studio_server/test_copilot_api.py +++ b/app/desktop/studio_server/test_copilot_api.py @@ -13,11 +13,13 @@ from app.desktop.studio_server.copilot_api import connect_copilot_api from fastapi import FastAPI from fastapi.testclient import TestClient +from kiln_server.custom_errors import connect_custom_errors @pytest.fixture def app(): app = FastAPI() + connect_custom_errors(app) connect_copilot_api(app) return app @@ -123,7 +125,7 @@ def test_clarify_spec_no_api_key(self, client, clarify_spec_input): response = client.post("/api/copilot/clarify_spec", json=clarify_spec_input) assert response.status_code == 401 - assert "API key not configured" in response.json()["detail"] + assert "API key not configured" in response.json()["message"] def test_clarify_spec_success(self, client, clarify_spec_input, mock_api_key): mock_output = MagicMock(spec=ClarifySpecOutput) @@ -194,7 +196,7 @@ def test_clarify_spec_no_response(self, client, clarify_spec_input, mock_api_key ): response = client.post("/api/copilot/clarify_spec", json=clarify_spec_input) assert response.status_code == 500 - assert "Failed to analyze spec" in response.json()["detail"] + assert "Failed to analyze spec" in response.json()["message"] def test_clarify_spec_validation_error( self, client, clarify_spec_input, mock_api_key @@ -210,7 +212,7 @@ def test_clarify_spec_validation_error( ): response = client.post("/api/copilot/clarify_spec", json=clarify_spec_input) assert response.status_code == 422 - assert "Validation error from server" in response.json()["detail"] + assert "Validation error from server" in response.json()["message"] class TestRefineSpec: @@ -223,7 +225,7 @@ def test_refine_spec_no_api_key(self, client, refine_spec_input): response = client.post("/api/copilot/refine_spec", json=refine_spec_input) assert response.status_code == 401 - assert "API key not configured" in response.json()["detail"] + assert "API key not configured" in response.json()["message"] def test_refine_spec_success(self, client, refine_spec_input, mock_api_key): mock_output = MagicMock(spec=RefineSpecApiOutput) @@ -259,7 +261,7 @@ def test_refine_spec_no_response(self, client, refine_spec_input, mock_api_key): ): response = client.post("/api/copilot/refine_spec", json=refine_spec_input) assert response.status_code == 500 - assert "Failed to refine spec" in response.json()["detail"] + assert "Failed to refine spec" in response.json()["message"] def test_refine_spec_validation_error( self, client, refine_spec_input, mock_api_key @@ -275,7 +277,7 @@ def test_refine_spec_validation_error( ): response = client.post("/api/copilot/refine_spec", json=refine_spec_input) assert response.status_code == 422 - assert "Validation error from server" in response.json()["detail"] + assert "Validation error from server" in response.json()["message"] class TestGenerateBatch: @@ -290,7 +292,7 @@ def test_generate_batch_no_api_key(self, client, generate_batch_input): "/api/copilot/generate_batch", json=generate_batch_input ) assert response.status_code == 401 - assert "API key not configured" in response.json()["detail"] + assert "API key not configured" in response.json()["message"] def test_generate_batch_success(self, client, generate_batch_input, mock_api_key): mock_output = MagicMock(spec=GenerateBatchOutput) @@ -328,7 +330,7 @@ def 
test_generate_batch_no_response( "/api/copilot/generate_batch", json=generate_batch_input ) assert response.status_code == 500 - assert "Failed to generate synthetic data" in response.json()["detail"] + assert "Failed to generate synthetic data" in response.json()["message"] def test_generate_batch_validation_error( self, client, generate_batch_input, mock_api_key @@ -346,4 +348,4 @@ def test_generate_batch_validation_error( "/api/copilot/generate_batch", json=generate_batch_input ) assert response.status_code == 422 - assert "Validation error from server" in response.json()["detail"] + assert "Validation error from server" in response.json()["message"] diff --git a/app/desktop/studio_server/test_data_gen_api.py b/app/desktop/studio_server/test_data_gen_api.py index 4d6fec692..bc9e4f995 100644 --- a/app/desktop/studio_server/test_data_gen_api.py +++ b/app/desktop/studio_server/test_data_gen_api.py @@ -3,6 +3,7 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient +from kiln_server.custom_errors import connect_custom_errors from kiln_ai.datamodel import ( DataSource, DataSourceType, @@ -31,6 +32,7 @@ @pytest.fixture def app(): app = FastAPI() + connect_custom_errors(app) connect_data_gen_api(app) return app diff --git a/app/desktop/studio_server/test_eval_api.py b/app/desktop/studio_server/test_eval_api.py index 1e3d6befc..c6bc12059 100644 --- a/app/desktop/studio_server/test_eval_api.py +++ b/app/desktop/studio_server/test_eval_api.py @@ -15,6 +15,7 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from fastapi.testclient import TestClient +from kiln_server.custom_errors import connect_custom_errors from kiln_ai.adapters.ml_model_list import ModelProviderName from kiln_ai.datamodel import ( DataSource, @@ -55,6 +56,7 @@ @pytest.fixture def app(): app = FastAPI() + connect_custom_errors(app) connect_evals_api(app) return app @@ -189,7 +191,7 @@ def test_get_eval_not_found(client, mock_task, mock_task_from_id): response = client.get("/api/projects/project1/tasks/task1/eval/non_existent") assert response.status_code == 404 - assert response.json()["detail"] == "Eval not found. ID: non_existent" + assert response.json()["message"] == "Eval not found. ID: non_existent" @pytest.fixture @@ -513,7 +515,7 @@ async def test_run_eval_config_no_run_configs_error( assert response.status_code == 400 assert ( - response.json()["detail"] + response.json()["message"] == "No run config ids provided. At least one run config id is required." ) @@ -785,7 +787,7 @@ def test_update_run_config_prompt_name_no_prompt( json={"prompt_name": "New Name"}, ) assert response.status_code == 400 - assert "no frozen prompt" in response.json()["detail"].lower() + assert "no frozen prompt" in response.json()["message"].lower() @pytest.fixture @@ -1354,7 +1356,7 @@ def test_delete_eval_not_found(client): # Verify the response assert response.status_code == 404 - assert response.json()["detail"] == "Eval not found. ID: nonexistent_eval" + assert response.json()["message"] == "Eval not found. 
ID: nonexistent_eval" async def test_create_eval_then_delete_on_spec_failure( @@ -1477,7 +1479,7 @@ def test_update_eval_train_set_filter_id_when_already_set( assert response.status_code == 400 assert ( "Train set filter is already set and cannot be changed" - in response.json()["detail"] + in response.json()["message"] ) @@ -1527,7 +1529,7 @@ def test_update_eval_not_found(client): ) assert response.status_code == 404 - assert "Eval not found" in response.json()["detail"] + assert "Eval not found" in response.json()["message"] def test_update_eval_empty_request(client, mock_task_from_id, mock_eval, mock_task): @@ -1759,7 +1761,7 @@ async def test_get_eval_progress_not_found(client, mock_task_from_id, mock_task) # Verify the response assert response.status_code == 404 - assert response.json()["detail"] == "Eval not found. ID: non_existent" + assert response.json()["message"] == "Eval not found. ID: non_existent" mock_eval_from_id.assert_called_once_with("project1", "task1", "non_existent") @@ -1810,7 +1812,7 @@ async def test_set_current_eval_config_not_found( # Verify the response assert response.status_code == 400 - assert response.json()["detail"] == "Eval config not found." + assert response.json()["message"] == "Eval config not found." @pytest.mark.parametrize( @@ -1901,7 +1903,7 @@ async def test_create_task_run_config_invalid_temperature_values( }, ) assert response.status_code == 422 - error_detail = response.json()["detail"] + error_detail = response.json()["message"] assert "temperature must be between 0 and 2" in str(error_detail) # Test temperature above 2 @@ -1919,7 +1921,7 @@ async def test_create_task_run_config_invalid_temperature_values( }, ) assert response.status_code == 422 - error_detail = response.json()["detail"] + error_detail = response.json()["message"] assert "temperature must be between 0 and 2" in str(error_detail) @@ -1945,7 +1947,7 @@ async def test_create_task_run_config_invalid_top_p_values( }, ) assert response.status_code == 422 - error_detail = response.json()["detail"] + error_detail = response.json()["message"] assert "top_p must be between 0 and 1" in str(error_detail) # Test top_p above 1 @@ -1963,7 +1965,7 @@ async def test_create_task_run_config_invalid_top_p_values( }, ) assert response.status_code == 422 - error_detail = response.json()["detail"] + error_detail = response.json()["message"] assert "top_p must be between 0 and 1" in str(error_detail) @@ -2226,7 +2228,7 @@ def test_get_eval_configs_score_summary_no_filter_id( assert response.status_code == 400 assert ( - response.json()["detail"] + response.json()["message"] == "No eval configs filter id set, cannot get eval configs score summary." 
) mock_eval_from_id.assert_called_once_with("project1", "task1", "eval1") diff --git a/app/desktop/studio_server/test_finetune_api.py b/app/desktop/studio_server/test_finetune_api.py index d10385ccc..73f69f9a4 100644 --- a/app/desktop/studio_server/test_finetune_api.py +++ b/app/desktop/studio_server/test_finetune_api.py @@ -4,6 +4,18 @@ import httpx import pytest +from app.desktop.studio_server.finetune_api import ( + CreateDatasetSplitRequest, + CreateFinetuneRequest, + DatasetSplitType, + FinetuneProviderModel, + compute_finetune_tag_info, + connect_fine_tune_api, + data_strategies_from_finetune_id, + fetch_fireworks_finetune_models, + infer_data_strategies_for_model, + thinking_instructions_from_request, +) from fastapi import FastAPI from fastapi.testclient import TestClient from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter @@ -39,21 +51,9 @@ KilnAgentRunConfigProperties, ToolsRunConfig, ) +from kiln_server.custom_errors import connect_custom_errors from pydantic import BaseModel -from app.desktop.studio_server.finetune_api import ( - CreateDatasetSplitRequest, - CreateFinetuneRequest, - DatasetSplitType, - FinetuneProviderModel, - compute_finetune_tag_info, - connect_fine_tune_api, - data_strategies_from_finetune_id, - fetch_fireworks_finetune_models, - infer_data_strategies_for_model, - thinking_instructions_from_request, -) - @pytest.fixture def test_task(tmp_path): @@ -173,6 +173,7 @@ def mock_task_from_id_disk_backed(test_task, monkeypatch): @pytest.fixture def client(): app = FastAPI() + connect_custom_errors(app) connect_fine_tune_api(app) return TestClient(app) @@ -391,7 +392,7 @@ def test_get_finetune_hyperparameters_invalid_provider(client, mock_finetune_reg assert response.status_code == 400 assert ( - response.json()["detail"] == "Fine tune provider 'invalid_provider' not found" + response.json()["message"] == "Fine tune provider 'invalid_provider' not found" ) @@ -660,7 +661,7 @@ def test_create_finetune_invalid_provider(client, mock_task_from_id_disk_backed) assert response.status_code == 400 assert ( - response.json()["detail"] == "Fine tune provider 'invalid_provider' not found" + response.json()["message"] == "Fine tune provider 'invalid_provider' not found" ) @@ -688,7 +689,7 @@ def test_create_finetune_invalid_dataset( assert response.status_code == 404 assert ( - response.json()["detail"] + response.json()["message"] == "Dataset split with ID 'invalid_split_id' not found" ) @@ -756,7 +757,7 @@ def test_create_finetune_no_system_message( assert response.status_code == 400 assert ( - response.json()["detail"] + response.json()["message"] == "System message generator or custom system message is required" ) @@ -864,7 +865,7 @@ def test_create_finetune_prompt_builder_error( assert response.status_code == 400 assert ( - response.json()["detail"] + response.json()["message"] == "Error generating system message using generator: test_prompt_builder. 
Source error: Invalid prompt configuration" ) @@ -952,7 +953,7 @@ def test_download_dataset_jsonl_invalid_format( ) assert response.status_code == 400 - assert response.json()["detail"] == "Dataset format 'invalid_format' not found" + assert response.json()["message"] == "Dataset format 'invalid_format' not found" def test_download_dataset_jsonl_data_strategy_invalid( @@ -966,7 +967,7 @@ def test_download_dataset_jsonl_data_strategy_invalid( assert response.status_code == 400 assert ( - response.json()["detail"] == "Data strategy 'invalid_data_strategy' not found" + response.json()["message"] == "Data strategy 'invalid_data_strategy' not found" ) @@ -981,7 +982,7 @@ def test_download_dataset_jsonl_invalid_dataset( assert response.status_code == 404 assert ( - response.json()["detail"] == "Dataset split with ID 'invalid_split' not found" + response.json()["message"] == "Dataset split with ID 'invalid_split' not found" ) @@ -996,7 +997,8 @@ def test_download_dataset_jsonl_invalid_split( assert response.status_code == 404 assert ( - response.json()["detail"] == "Dataset split with name 'invalid_split' not found" + response.json()["message"] + == "Dataset split with name 'invalid_split' not found" ) @@ -1071,7 +1073,7 @@ def test_get_finetune_not_found(client, mock_task_from_id_disk_backed): response = client.get("/api/projects/project1/tasks/task1/finetunes/nonexistent") assert response.status_code == 404 - assert response.json()["detail"] == "Finetune with ID 'nonexistent' not found" + assert response.json()["message"] == "Finetune with ID 'nonexistent' not found" mock_task_from_id_disk_backed.assert_called_once_with("project1", "task1") diff --git a/app/desktop/studio_server/test_import_api.py b/app/desktop/studio_server/test_import_api.py index bd6aee4c1..f40488039 100644 --- a/app/desktop/studio_server/test_import_api.py +++ b/app/desktop/studio_server/test_import_api.py @@ -4,6 +4,7 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient +from kiln_server.custom_errors import connect_custom_errors from app.desktop.studio_server.import_api import _show_file_dialog, connect_import_api @@ -11,6 +12,7 @@ @pytest.fixture def app(): app = FastAPI() + connect_custom_errors(app) return app @@ -36,14 +38,14 @@ def test_select_kiln_file_no_tk_root(client_no_tk): """Test that the endpoint raises HTTPException when tk_root is None""" response = client_no_tk.get("/api/select_kiln_file") assert response.status_code == 400 - assert "Not running in app mode" in response.json()["detail"] + assert "Not running in app mode" in response.json()["message"] def test_select_kiln_file_with_custom_title(client_no_tk): """Test that custom title parameter is handled when tk_root is None""" response = client_no_tk.get("/api/select_kiln_file?title=Custom Title") assert response.status_code == 400 - assert "Not running in app mode" in response.json()["detail"] + assert "Not running in app mode" in response.json()["message"] @patch("app.desktop.studio_server.import_api.filedialog.askopenfilename") diff --git a/app/desktop/studio_server/test_prompt_api.py b/app/desktop/studio_server/test_prompt_api.py index 0b1ccf678..004eff566 100644 --- a/app/desktop/studio_server/test_prompt_api.py +++ b/app/desktop/studio_server/test_prompt_api.py @@ -5,6 +5,7 @@ # Create a FastAPI app and connect the prompt_api from fastapi import FastAPI from fastapi.testclient import TestClient +from kiln_server.custom_errors import connect_custom_errors from kiln_ai.adapters.prompt_builders import BasePromptBuilder 
from kiln_ai.datamodel import Task @@ -14,6 +15,7 @@ @pytest.fixture def client(): app = FastAPI() + connect_custom_errors(app) connect_prompt_api(app) return TestClient(app) diff --git a/app/desktop/studio_server/test_prompt_optimization_job_api.py b/app/desktop/studio_server/test_prompt_optimization_job_api.py index c75b35dce..79d234c22 100644 --- a/app/desktop/studio_server/test_prompt_optimization_job_api.py +++ b/app/desktop/studio_server/test_prompt_optimization_job_api.py @@ -36,6 +36,7 @@ ) from fastapi import FastAPI, HTTPException from fastapi.testclient import TestClient +from kiln_server.custom_errors import connect_custom_errors from kiln_ai.cli.commands.package_project import PackageForTrainingConfig from kiln_ai.datamodel import Project, PromptOptimizationJob, Task from kiln_ai.datamodel.datamodel_enums import ModelProviderName, StructuredOutputMode @@ -68,6 +69,7 @@ def _make_sdk_response( @pytest.fixture def app(): app = FastAPI() + connect_custom_errors(app) connect_prompt_optimization_job_api(app) return app @@ -152,7 +154,7 @@ def test_get_prompt_optimization_job_result_no_output(client, mock_api_key): response = client.get(f"/api/prompt_optimization_jobs/{job_id}/result") assert response.status_code == 500 - assert "has no output" in response.json()["detail"] + assert "has no output" in response.json()["message"] def test_get_prompt_optimization_job_result_api_error(client, mock_api_key): @@ -168,7 +170,7 @@ def test_get_prompt_optimization_job_result_api_error(client, mock_api_key): assert response.status_code == 500 assert ( - "Failed to get Prompt Optimization job result" in response.json()["detail"] + "Failed to get Prompt Optimization job result" in response.json()["message"] ) @@ -223,7 +225,7 @@ def test_get_prompt_optimization_job_result_no_api_key(client): response = client.get(f"/api/prompt_optimization_jobs/{job_id}/result") assert response.status_code == 401 - assert "API key not configured" in response.json()["detail"] + assert "API key not configured" in response.json()["message"] def test_get_prompt_optimization_job_result_validation_error(client, mock_api_key): @@ -1450,7 +1452,7 @@ def test_check_run_config_server_none_response(client, mock_api_key, tmp_path): ) assert response.status_code == 500 - assert "unknown error" in response.json()["detail"].lower() + assert "unknown error" in response.json()["message"].lower() def test_check_run_config_exception(client, mock_api_key, tmp_path): @@ -1479,7 +1481,7 @@ def test_check_run_config_exception(client, mock_api_key, tmp_path): ) assert response.status_code == 500 - assert "Failed to check run config" in response.json()["detail"] + assert "Failed to check run config" in response.json()["message"] def test_check_eval_no_current_config(client, mock_api_key, tmp_path): @@ -1748,7 +1750,7 @@ def test_check_eval_server_none_response(client, mock_api_key, tmp_path): ) assert response.status_code == 500 - assert "unknown error" in response.json()["detail"].lower() + assert "unknown error" in response.json()["message"].lower() def test_check_eval_exception(client, mock_api_key, tmp_path): @@ -1777,7 +1779,7 @@ def test_check_eval_exception(client, mock_api_key, tmp_path): ) assert response.status_code == 500 - assert "Failed to check eval" in response.json()["detail"] + assert "Failed to check eval" in response.json()["message"] @pytest.mark.parametrize( @@ -1865,7 +1867,7 @@ def test_start_prompt_optimization_job_no_parent_project( ) assert response.status_code == 404 - assert "Project not found" in 
response.json()["detail"] + assert "Project not found" in response.json()["message"] def test_start_prompt_optimization_job_with_tools_in_run_config( @@ -1909,8 +1911,8 @@ def test_start_prompt_optimization_job_with_tools_in_run_config( ) assert response.status_code == 400 - assert "does not support" in response.json()["detail"] - assert "tools" in response.json()["detail"] + assert "does not support" in response.json()["message"] + assert "tools" in response.json()["message"] def test_start_prompt_optimization_job_server_not_authenticated( @@ -1957,7 +1959,7 @@ def test_start_prompt_optimization_job_server_not_authenticated( ) assert response.status_code == 500 - assert "not authenticated" in response.json()["detail"] + assert "not authenticated" in response.json()["message"] def test_start_prompt_optimization_job_server_validation_error( @@ -2012,7 +2014,7 @@ def test_start_prompt_optimization_job_server_validation_error( ) assert response.status_code == 422 - assert "Upstream validation error" in response.json()["detail"] + assert "Upstream validation error" in response.json()["message"] def test_start_prompt_optimization_job_server_none_response( @@ -2064,7 +2066,7 @@ def test_start_prompt_optimization_job_server_none_response( ) assert response.status_code == 500 - assert "unknown error" in response.json()["detail"].lower() + assert "unknown error" in response.json()["message"].lower() def test_start_prompt_optimization_job_connection_error(client, mock_api_key, tmp_path): @@ -2117,8 +2119,8 @@ class ReadError(Exception): ) assert response.status_code == 500 - assert "Connection error" in response.json()["detail"] - assert "too large" in response.json()["detail"] + assert "Connection error" in response.json()["message"] + assert "too large" in response.json()["message"] def test_start_prompt_optimization_job_timeout_error(client, mock_api_key, tmp_path): @@ -2168,7 +2170,7 @@ def test_start_prompt_optimization_job_timeout_error(client, mock_api_key, tmp_p ) assert response.status_code == 500 - assert "Connection error" in response.json()["detail"] + assert "Connection error" in response.json()["message"] def test_start_prompt_optimization_job_general_exception( @@ -2215,7 +2217,7 @@ def test_start_prompt_optimization_job_general_exception( ) assert response.status_code == 500 - assert "Failed to start Prompt Optimization job" in response.json()["detail"] + assert "Failed to start Prompt Optimization job" in response.json()["message"] def test_prompt_optimization_job_creates_run_config_on_success( diff --git a/app/desktop/studio_server/test_provider_api.py b/app/desktop/studio_server/test_provider_api.py index 4f67897c2..242135c25 100644 --- a/app/desktop/studio_server/test_provider_api.py +++ b/app/desktop/studio_server/test_provider_api.py @@ -38,6 +38,7 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import JSONResponse from fastapi.testclient import TestClient +from kiln_server.custom_errors import connect_custom_errors from kiln_ai.adapters.ml_embedding_model_list import ( EmbeddingModelName, KilnEmbeddingModel, @@ -65,6 +66,7 @@ @pytest.fixture def app(): app = FastAPI() + connect_custom_errors(app) connect_provider_api(app) return app @@ -162,7 +164,7 @@ def test_connect_api_key_kiln_copilot_empty_key(client): ) assert response.status_code == 400 - assert response.json() == {"detail": "API Key not found"} + assert response.json() == {"message": "API Key not found"} @patch("app.desktop.studio_server.provider_api.httpx.AsyncClient.get") @@ -1602,7 +1604,7 @@ 
async def test_save_openai_compatible_providers_duplicate_name(client): ) assert response.status_code == 400 - assert response.json() == {"detail": "Provider with this name already exists"} + assert response.json() == {"message": "Provider with this name already exists"} @pytest.mark.asyncio @@ -3415,6 +3417,7 @@ async def test_get_available_reranker_models(app, client): def test_add_user_model_success(): """Test adding a user model successfully""" app = FastAPI() + connect_custom_errors(app) connect_provider_api(app) client = TestClient(app) @@ -3447,6 +3450,7 @@ def test_add_user_model_success(): def test_add_user_model_exact_duplicate_rejected(): """Test that an exact duplicate (same provider_id, model_id, name, overrides) is rejected""" app = FastAPI() + connect_custom_errors(app) connect_provider_api(app) client = TestClient(app) @@ -3479,12 +3483,13 @@ def test_add_user_model_exact_duplicate_rejected(): ) assert response.status_code == 400 - assert "already exists" in response.json()["detail"] + assert "already exists" in response.json()["message"] def test_add_user_model_same_model_id_different_name_allowed(): """Test that same model_id with different display name is allowed (not a duplicate)""" app = FastAPI() + connect_custom_errors(app) connect_provider_api(app) client = TestClient(app) @@ -3524,6 +3529,7 @@ def test_add_user_model_same_model_id_different_name_allowed(): def test_add_user_model_same_model_id_different_overrides_allowed(): """Test that same model_id with different overrides is allowed (not a duplicate)""" app = FastAPI() + connect_custom_errors(app) connect_provider_api(app) client = TestClient(app) @@ -3564,6 +3570,7 @@ def test_add_user_model_same_model_id_different_overrides_allowed(): def test_add_user_model_same_model_id_different_provider_allowed(): """Test that same model_id on a different provider is allowed""" app = FastAPI() + connect_custom_errors(app) connect_provider_api(app) client = TestClient(app) @@ -3603,6 +3610,7 @@ def test_add_user_model_same_model_id_different_provider_allowed(): def test_add_user_model_invalid_custom_provider(): """Test that adding a model for a non-existent custom provider fails""" app = FastAPI() + connect_custom_errors(app) connect_provider_api(app) client = TestClient(app) @@ -3624,12 +3632,13 @@ def test_add_user_model_invalid_custom_provider(): ) assert response.status_code == 400 - assert "not found" in response.json()["detail"] + assert "not found" in response.json()["message"] def test_delete_user_model_by_id(): """Test deleting a user model by its ID (new format)""" app = FastAPI() + connect_custom_errors(app) connect_provider_api(app) client = TestClient(app) @@ -3666,6 +3675,7 @@ def test_delete_user_model_by_id(): def test_delete_user_model_by_tuple_from_registry(): """Test deleting a user model by provider_type/provider_id/model_id tuple (legacy format)""" app = FastAPI() + connect_custom_errors(app) connect_provider_api(app) client = TestClient(app) @@ -3696,6 +3706,7 @@ def test_delete_user_model_by_tuple_from_registry(): def test_delete_user_model_by_tuple_from_legacy_custom_models(): """Test deleting a legacy model from custom_models by tuple (legacy format)""" app = FastAPI() + connect_custom_errors(app) connect_provider_api(app) client = TestClient(app) @@ -3721,6 +3732,7 @@ def test_delete_user_model_by_tuple_from_legacy_custom_models(): def test_delete_user_model_not_found(): """Test deleting a non-existent model returns 404""" app = FastAPI() + connect_custom_errors(app) connect_provider_api(app) 
client = TestClient(app) @@ -3735,12 +3747,13 @@ def test_delete_user_model_not_found(): response = client.delete("/api/settings/user_models?id=non-existent-id") assert response.status_code == 404 - assert "not found" in response.json()["detail"] + assert "not found" in response.json()["message"] def test_delete_user_model_bad_request_no_params(): """Test deleting without required parameters returns 400""" app = FastAPI() + connect_custom_errors(app) connect_provider_api(app) client = TestClient(app) @@ -3755,4 +3768,4 @@ def test_delete_user_model_bad_request_no_params(): response = client.delete("/api/settings/user_models") assert response.status_code == 400 - assert "Must specify" in response.json()["detail"] + assert "Must specify" in response.json()["message"] diff --git a/app/desktop/studio_server/test_repair_api.py b/app/desktop/studio_server/test_repair_api.py index 6b9490839..32f4ec4ce 100644 --- a/app/desktop/studio_server/test_repair_api.py +++ b/app/desktop/studio_server/test_repair_api.py @@ -4,6 +4,7 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient +from kiln_server.custom_errors import connect_custom_errors from kiln_ai.datamodel import ( DataSource, DataSourceType, @@ -24,6 +25,7 @@ @pytest.fixture def app(): app = FastAPI() + connect_custom_errors(app) connect_repair_api(app) return app @@ -201,7 +203,7 @@ def test_repair_run_missing_model_info( # Assert assert response.status_code == 422 - assert response.json()["detail"] == "Model name and provider must be specified." + assert response.json()["message"] == "Model name and provider must be specified." def test_repair_run_human_source( diff --git a/app/desktop/studio_server/test_run_config_api.py b/app/desktop/studio_server/test_run_config_api.py index 13a21768c..392e1d0cc 100644 --- a/app/desktop/studio_server/test_run_config_api.py +++ b/app/desktop/studio_server/test_run_config_api.py @@ -16,6 +16,7 @@ ) from fastapi import FastAPI from fastapi.testclient import TestClient +from kiln_server.custom_errors import connect_custom_errors from kiln_ai.datamodel import Project, Task from kiln_ai.datamodel.basemodel import string_to_valid_name from kiln_ai.tools.mcp_server_tool import MCPServerTool @@ -24,6 +25,7 @@ @pytest.fixture def app(): app = FastAPI() + connect_custom_errors(app) connect_run_config_api(app) return app @@ -788,7 +790,7 @@ def test_tasks_compatible_with_tool_invalid_tool(client, project_and_tasks): ) assert response.status_code == 400 - assert response.json()["detail"] == "Tool selected is not an MCP tool." + assert response.json()["message"] == "Tool selected is not an MCP tool." 
def test_tasks_compatible_with_tool_empty_project(client, tmp_path): diff --git a/app/desktop/studio_server/test_settings_api.py b/app/desktop/studio_server/test_settings_api.py index 9e8c22c4d..ebb83453e 100644 --- a/app/desktop/studio_server/test_settings_api.py +++ b/app/desktop/studio_server/test_settings_api.py @@ -9,6 +9,7 @@ from app.desktop.studio_server.settings_api import connect_settings from fastapi import FastAPI from fastapi.testclient import TestClient +from kiln_server.custom_errors import connect_custom_errors from kiln_ai.utils.config import Config @@ -21,6 +22,7 @@ def temp_home(tmp_path): @pytest.fixture def app(): app = FastAPI() + connect_custom_errors(app) connect_settings(app) return app @@ -186,7 +188,7 @@ def test_check_entitlements_no_api_key(self, client): "/api/check_entitlements?feature_codes=prompt-optimization" ) assert response.status_code == 401 - assert "API key not configured" in response.json()["detail"] + assert "API key not configured" in response.json()["message"] def test_check_entitlements_single_feature_true(self, client, mock_api_key): mock_response = MagicMock( @@ -278,4 +280,4 @@ def test_check_entitlements_api_error_response(self, client, mock_api_key): "/api/check_entitlements?feature_codes=prompt-optimization" ) assert response.status_code == 403 - assert "Forbidden: Invalid API key" in response.json()["detail"] + assert "Forbidden: Invalid API key" in response.json()["message"] diff --git a/app/desktop/studio_server/test_tool_api.py b/app/desktop/studio_server/test_tool_api.py index f9a42ea53..720f21342 100644 --- a/app/desktop/studio_server/test_tool_api.py +++ b/app/desktop/studio_server/test_tool_api.py @@ -3,6 +3,13 @@ from unittest.mock import AsyncMock, Mock, patch import pytest +from app.desktop.studio_server.tool_api import ( + ExternalToolApiDescription, + available_mcp_tools, + connect_tool_servers_api, + tool_server_from_id, + validate_tool_server_connectivity, +) from fastapi import FastAPI, HTTPException from fastapi.testclient import TestClient from kiln_ai.datamodel.datamodel_enums import StructuredOutputMode @@ -13,17 +20,11 @@ from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties from kiln_ai.datamodel.task import Task, TaskRunConfig from kiln_ai.datamodel.tool_id import KILN_TASK_TOOL_ID_PREFIX +from kiln_ai.tools.mcp_session_manager import KilnMCPError from kiln_ai.utils.config import MCP_SECRETS_KEY +from kiln_server.custom_errors import connect_custom_errors from mcp.types import ListToolsResult, Tool -from app.desktop.studio_server.tool_api import ( - ExternalToolApiDescription, - available_mcp_tools, - connect_tool_servers_api, - tool_server_from_id, - validate_tool_server_connectivity, -) - @pytest.fixture def mock_project_from_id(test_project): @@ -85,7 +86,7 @@ async def mock_mcp_success(tools=None): @asynccontextmanager async def mock_mcp_connection_error(error_message="Connection failed"): """Context manager for MCP connection errors.""" - error = Exception(error_message) + error = KilnMCPError(error_message) patch_obj, mock_client = create_mcp_session_manager_patch(connection_error=error) with patch_obj as mock_session_manager_shared: @@ -98,7 +99,7 @@ async def mock_mcp_connection_error(error_message="Connection failed"): @asynccontextmanager async def mock_mcp_list_tools_error(error_message="list_tools failed"): """Context manager for MCP list_tools errors.""" - error = Exception(error_message) + error = KilnMCPError(error_message) patch_obj, mock_client = 
create_mcp_session_manager_patch(list_tools_error=error) with patch_obj as mock_session_manager_shared: @@ -111,6 +112,7 @@ async def mock_mcp_list_tools_error(error_message="list_tools failed"): @pytest.fixture def app(): test_app = FastAPI() + connect_custom_errors(test_app) connect_tool_servers_api(test_app) return test_app @@ -416,7 +418,7 @@ async def test_get_tool_server_success(client, test_project): async def test_get_tool_server_mcp_error_handling(client, test_project): - """Test that MCP server errors are handled gracefully and return empty tools""" + """Test that MCP server errors are surfaced with HTTP 503 and the error message""" # First create a tool server tool_data = { "name": "failing_mcp_tool", @@ -443,12 +445,13 @@ async def test_get_tool_server_mcp_error_handling(client, test_project): # Mock retrieval with list_tools error async with mock_mcp_list_tools_error("Connection failed"): - # The API should handle the exception gracefully - # For now, let's test that it raises the exception since that's the current behavior - with pytest.raises(Exception, match="Connection failed"): - client.get( - f"/api/projects/{test_project.id}/tool_servers/{tool_server_id}" - ) + # The API should surface the error as HTTP 503 with the error message in detail + response = client.get( + f"/api/projects/{test_project.id}/tool_servers/{tool_server_id}" + ) + assert response.status_code == 503 + result = response.json() + assert "Connection failed" in result["message"] def test_get_tool_server_not_found(client, test_project): @@ -464,7 +467,61 @@ def test_get_tool_server_not_found(client, test_project): assert response.status_code == 404 result = response.json() - assert result["detail"] == "Tool server not found" + assert result["message"] == "Tool server not found" + + +async def test_get_tool_server_config_returns_file_data_without_connecting( + client, test_project +): + """The /config endpoint must return tool server data without attempting + a connection, so the edit form loads even when the MCP server is unreachable.""" + tool_data = { + "name": "config_test_tool", + "server_url": "https://example.com/api", + "headers": {}, + "description": "Tool for config test", + "is_archived": False, + } + + with patch( + "app.desktop.studio_server.tool_api.project_from_id" + ) as mock_project_from_id: + mock_project_from_id.return_value = test_project + + async with mock_mcp_success(): + create_response = client.post( + f"/api/projects/{test_project.id}/connect_remote_mcp", + json=tool_data, + ) + assert create_response.status_code == 200 + tool_server_id = create_response.json()["id"] + + # Call /config while the MCP connection is broken — must still succeed + async with mock_mcp_list_tools_error("Connection failed"): + response = client.get( + f"/api/projects/{test_project.id}/tool_servers/{tool_server_id}/config" + ) + + assert response.status_code == 200 + result = response.json() + assert result["name"] == "config_test_tool" + assert result["type"] == "remote_mcp" + assert result["available_tools"] == [] + assert result["missing_secrets"] == [] + + +def test_get_tool_server_config_not_found(client, test_project): + with patch( + "app.desktop.studio_server.tool_api.project_from_id" + ) as mock_project_from_id: + mock_project_from_id.return_value = test_project + + response = client.get( + f"/api/projects/{test_project.id}/tool_servers/nonexistent-id/config" + ) + + assert response.status_code == 404 + assert response.json()["message"] == "Tool server not found" def 
test_get_available_tools_empty(client, test_project): @@ -1620,7 +1677,7 @@ def test_delete_tool_server_not_found(client, test_project): f"/api/projects/{test_project.id}/tool_servers/non-existent-id" ) assert response.status_code == 404 - assert "Tool server not found" in response.json()["detail"] + assert "Tool server not found" in response.json()["message"] def test_delete_tool_server_project_not_found(client): @@ -1630,7 +1687,7 @@ def test_delete_tool_server_project_not_found(client): "/api/projects/non-existent-project/tool_servers/some-tool-id" ) assert response.status_code == 404 - assert "Project not found" in response.json()["detail"] + assert "Project not found" in response.json()["message"] async def test_delete_tool_server_affects_available_servers_list(client, test_project): @@ -3035,7 +3092,7 @@ async def test_edit_local_mcp_404(client, test_project, edit_local_server_data): json=edit_local_server_data, ) assert response.status_code == 404 - assert response.json() == {"detail": "Tool server not found"} + assert response.json() == {"message": "Tool server not found"} @pytest.fixture @@ -3085,7 +3142,7 @@ async def test_edit_local_mcp_wrong_type( ) assert response.status_code == 400 assert response.json() == { - "detail": "Existing tool server is not a local MCP server. You can't edit a non-local MCP server with this endpoint." + "message": "Existing tool server is not a local MCP server. You can't edit a non-local MCP server with this endpoint." } @@ -3165,7 +3222,7 @@ async def test_edit_remote_mcp_404(client, test_project, edit_remote_server_data json=edit_remote_server_data, ) assert response.status_code == 404 - assert response.json() == {"detail": "Tool server not found"} + assert response.json() == {"message": "Tool server not found"} async def test_edit_remote_mcp_wrong_type( @@ -3183,7 +3240,7 @@ async def test_edit_remote_mcp_wrong_type( ) assert response.status_code == 400 assert response.json() == { - "detail": "Existing tool server is not a remote MCP server. You can't edit a non-remote MCP server with this endpoint." + "message": "Existing tool server is not a remote MCP server. You can't edit a non-remote MCP server with this endpoint." 
} @@ -4108,7 +4165,7 @@ async def test_add_kiln_task_tool_validation_task_not_found(client, test_project ) assert response.status_code == 404 - assert "Task not found" in response.json()["detail"] + assert "Task not found" in response.json()["message"] @pytest.mark.asyncio @@ -4157,7 +4214,7 @@ async def test_add_kiln_task_tool_validation_run_config_not_found(client, test_p assert response.status_code == 400 assert ( - "Run config not found for the specified task" in response.json()["detail"] + "Run config not found for the specified task" in response.json()["message"] ) @@ -4278,7 +4335,7 @@ async def test_edit_kiln_task_tool_validation_task_not_found(client, test_projec ) assert response.status_code == 404 - assert "Task not found" in response.json()["detail"] + assert "Task not found" in response.json()["message"] @pytest.mark.asyncio @@ -4346,7 +4403,7 @@ async def test_edit_kiln_task_tool_validation_run_config_not_found( assert response.status_code == 400 assert ( - "Run config not found for the specified task" in response.json()["detail"] + "Run config not found for the specified task" in response.json()["message"] ) @@ -4463,8 +4520,8 @@ async def test_get_tool_definition_tool_not_found(client, test_project): assert response.status_code == 404 result = response.json() - assert "Tool not found or could not be instantiated" in result["detail"] - assert invalid_tool_id in result["detail"] + assert "Tool not found or could not be instantiated" in result["message"] + assert invalid_tool_id in result["message"] @pytest.mark.asyncio @@ -4484,4 +4541,4 @@ async def test_get_tool_definition_task_not_found(client, test_project): ) assert response.status_code == 404 - assert "Task not found" in response.json()["detail"] + assert "Task not found" in response.json()["message"] diff --git a/app/desktop/studio_server/tool_api.py b/app/desktop/studio_server/tool_api.py index 549a4759d..5874fb9e0 100644 --- a/app/desktop/studio_server/tool_api.py +++ b/app/desktop/studio_server/tool_api.py @@ -1,3 +1,4 @@ +import logging from datetime import datetime from enum import Enum from typing import Any, Dict, List @@ -10,6 +11,7 @@ LocalServerProperties, RemoteServerProperties, ToolServerType, + tool_server_type_to_string, ) from kiln_ai.datamodel.tool_id import ( MCP_LOCAL_TOOL_ID_PREFIX, @@ -20,7 +22,10 @@ build_rag_tool_id, ) from kiln_ai.tools.kiln_task_tool import KilnTaskTool -from kiln_ai.tools.mcp_session_manager import MCPSessionManager +from kiln_ai.tools.mcp_session_manager import ( + KilnMCPError, + MCPSessionManager, +) from kiln_ai.tools.tool_registry import tool_from_id from kiln_ai.utils.config import Config from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error @@ -29,6 +34,8 @@ from mcp.types import Tool as MCPTool from pydantic import BaseModel, Field +logger = logging.getLogger(__name__) + class KilnToolServerDescription(BaseModel): """ @@ -428,15 +435,53 @@ async def get_tool_server( available_tools = [] match tool_server.type: case ToolServerType.remote_mcp | ToolServerType.local_mcp: - async with MCPSessionManager.shared().mcp_client( - tool_server - ) as session: - tools_result = await session.list_tools() - - available_tools = [ - ExternalToolApiDescription.tool_from_mcp_tool(tool) - for tool in tools_result.tools + try: + async with MCPSessionManager.shared().mcp_client( + tool_server + ) as session: + tools_result = await session.list_tools() + + available_tools = [ + ExternalToolApiDescription.tool_from_mcp_tool(tool) + for tool in tools_result.tools + ] + except 
(KilnMCPError, ValueError) as e: + context_lines = [ + "MCP call failed: list_tools", + f"Tool server: {tool_server.name} ({tool_server.id})", + f"Type: {tool_server_type_to_string(tool_server.type)}", ] + props = tool_server.properties or {} + if tool_server.type == ToolServerType.local_mcp: + command = props.get("command") + args = props.get("args") + command_text = ( + command if isinstance(command, str) else "" + ) + args_text = ( + " ".join(args) + if isinstance(args, list) + and all(isinstance(arg, str) for arg in args) + else "" + ) + context_lines.append(f"Command: {command_text} {args_text}") + elif tool_server.type == ToolServerType.remote_mcp: + server_url = props.get("server_url") + server_url_text = ( + server_url if isinstance(server_url, str) else "" + ) + context_lines.append(f"Server URL: {server_url_text}") + + detail = "\n".join(context_lines) + f"\n\nError: {e}" + stderr_text = e.stderr if isinstance(e, KilnMCPError) else "" + if stderr_text: + # Truncate the error to 4kb + if len(stderr_text) > 4096: + stderr_text = stderr_text[:4096] + "\n... (truncated)" + detail += f"\n\nMCP server stderr:\n{stderr_text}" + + logger.error("MCP list_tools failed:\n%s", e, exc_info=True) + raise HTTPException(status_code=503, detail=detail) from e case ToolServerType.kiln_task: available_tools = [ await ExternalToolApiDescription.tool_from_kiln_task_tool( @@ -460,6 +505,23 @@ async def get_tool_server( missing_secrets=[], ) + @app.get("/api/projects/{project_id}/tool_servers/{tool_server_id}/config") + async def get_tool_server_config( + project_id: str, tool_server_id: str + ) -> ExternalToolServerApiDescription: + tool_server = tool_server_from_id(project_id, tool_server_id) + return ExternalToolServerApiDescription( + id=tool_server.id, + name=tool_server.name, + type=tool_server.type, + description=tool_server.description, + created_at=tool_server.created_at, + created_by=tool_server.created_by, + properties=tool_server.properties, + available_tools=[], + missing_secrets=[], + ) + @app.post("/api/projects/{project_id}/connect_remote_mcp") async def connect_remote_mcp( project_id: str, tool_data: ExternalToolServerCreationRequest diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts index d7bb96796..754759f8a 100644 --- a/app/web_ui/src/lib/api_schema.d.ts +++ b/app/web_ui/src/lib/api_schema.d.ts @@ -2150,6 +2150,23 @@ export interface paths { patch?: never; trace?: never; }; + "/api/projects/{project_id}/tool_servers/{tool_server_id}/config": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** Get Tool Server Config */ + get: operations["get_tool_server_config_api_projects__project_id__tool_servers__tool_server_id__config_get"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/projects/{project_id}/connect_remote_mcp": { parameters: { query?: never; @@ -6232,11 +6249,6 @@ export interface components { } | unknown[] | null; /** Tags */ tags?: string[] | null; - /** - * Task Run Id - * @description When set, continue an existing session. The new message is appended to the run's trace. 
- */ - task_run_id?: string | null; }; /** * SampleApi @@ -12075,6 +12087,38 @@ export interface operations { }; }; }; + get_tool_server_config_api_projects__project_id__tool_servers__tool_server_id__config_get: { + parameters: { + query?: never; + header?: never; + path: { + project_id: string; + tool_server_id: string; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["ExternalToolServerApiDescription"]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; connect_remote_mcp_api_projects__project_id__connect_remote_mcp_post: { parameters: { query?: never; diff --git a/app/web_ui/src/lib/ui/error_details_block.svelte b/app/web_ui/src/lib/ui/error_details_block.svelte new file mode 100644 index 000000000..ec870ee78 --- /dev/null +++ b/app/web_ui/src/lib/ui/error_details_block.svelte @@ -0,0 +1,51 @@ + + +
+<div class="flex flex-col gap-2">
+  <div class="font-medium">{title}</div>
+
+  {#if troubleshooting_steps.length > 0}
+    <div class="mt-2">
+      <div class="font-medium">Troubleshooting Steps</div>
+      <ol class="list-decimal list-inside">
+        {#each troubleshooting_steps as step}
+          <li>
+            {#if markdown && trusted}
+              {@html step}
+            {:else}
+              {step}
+            {/if}
+          </li>
+        {/each}
+      </ol>
+    </div>
+  {/if}
+
+  {#if show_logs}
+    <div class="mt-2">
+      <div class="font-medium">Error Details</div>
+      <Output raw_output={error?.getMessage() || "Unknown error"} />
+      <button class="link text-sm" on:click={view_logs}>View Logs</button>
+    </div>
+  {/if}
+</div>
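The test fixtures throughout this PR register `connect_custom_errors(app)` and then assert on `response.json()["message"]` instead of FastAPI's default `detail` field. As a rough sketch only — not the actual `kiln_server.custom_errors` implementation — a handler producing that response shape could look like:

```python
# Hypothetical sketch: the real connect_custom_errors lives in
# kiln_server.custom_errors. This only illustrates the {"message": ...}
# response shape the tests above assert against.
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse


def connect_custom_errors(app: FastAPI) -> None:
    @app.exception_handler(HTTPException)
    async def http_exception_handler(
        request: Request, exc: HTTPException
    ) -> JSONResponse:
        # Rewrite FastAPI's default {"detail": ...} payload into {"message": ...}
        return JSONResponse(
            status_code=exc.status_code,
            content={"message": str(exc.detail)},
        )
```

Registering the handler inside each test's app fixture keeps the assertions aligned with what the production app actually returns.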
diff --git a/app/web_ui/src/lib/utils/update.ts b/app/web_ui/src/lib/utils/update.ts index 42eaca30a..2e746f0cf 100644 --- a/app/web_ui/src/lib/utils/update.ts +++ b/app/web_ui/src/lib/utils/update.ts @@ -1,7 +1,7 @@ import { createKilnError, KilnError } from "$lib/utils/error_handlers" import { writable } from "svelte/store" -export const app_version = "0.24.0" +export const app_version = "0.25.0" export type UpdateCheckResult = { has_update: boolean diff --git a/app/web_ui/src/routes/(app)/optimize/[project_id]/[task_id]/+page.svelte b/app/web_ui/src/routes/(app)/optimize/[project_id]/[task_id]/+page.svelte index ce26bfdd6..fdcaea657 100644 --- a/app/web_ui/src/routes/(app)/optimize/[project_id]/[task_id]/+page.svelte +++ b/app/web_ui/src/routes/(app)/optimize/[project_id]/[task_id]/+page.svelte @@ -494,7 +494,7 @@ {#if isKilnAgentRunConfig(config.run_config_properties)} - Standard + Agent {:else} MCP {/if} diff --git a/app/web_ui/src/routes/(app)/prompts/[project_id]/[task_id]/+page.svelte b/app/web_ui/src/routes/(app)/prompts/[project_id]/[task_id]/+page.svelte index dc4e1b491..ef04f134b 100644 --- a/app/web_ui/src/routes/(app)/prompts/[project_id]/[task_id]/+page.svelte +++ b/app/web_ui/src/routes/(app)/prompts/[project_id]/[task_id]/+page.svelte @@ -14,6 +14,7 @@ import type { Task } from "$lib/types" import { createKilnError, KilnError } from "$lib/utils/error_handlers" import { getPromptType } from "./prompt_generators/prompt_generators" + import InfoTooltip from "$lib/ui/info_tooltip.svelte" import Banner from "$lib/ui/banner.svelte" import Float from "$lib/ui/float.svelte" @@ -92,11 +93,19 @@ label: string sortable: boolean sortKey?: SortableColumn + tooltip?: string } const tableColumns: TableColumn[] = [ { key: "name", label: "Name", sortable: true, sortKey: "name" }, - { key: "type", label: "Type", sortable: true, sortKey: "type" }, + { + key: "type", + label: "Type", + sortable: true, + sortKey: "type", + tooltip: + "How the prompt was created. 'Frozen' means it's a fixed snapshot of the task prompt from a saved run configuration.", + }, { key: "prompt_preview", label: "Prompt Preview", sortable: false }, { key: "created_at", @@ -213,6 +222,14 @@ class="hover:bg-base-200 cursor-pointer" > {column.label} + {#if column.tooltip} + + + + {/if} {sortColumn === sortKey ? sortDirection === "asc" @@ -222,7 +239,17 @@ {:else} - {column.label} + + {column.label} + {#if column.tooltip} + + + + {/if} + {/if} {/each} diff --git a/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/add_tools/local_mcp/edit_local_tool.svelte b/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/add_tools/local_mcp/edit_local_tool.svelte index aee0dae0f..242a55ab5 100644 --- a/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/add_tools/local_mcp/edit_local_tool.svelte +++ b/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/add_tools/local_mcp/edit_local_tool.svelte @@ -12,8 +12,7 @@ import { uncache_available_tools } from "$lib/stores" import type { ExternalToolServerApiDescription } from "$lib/types" import posthog from "posthog-js" - import { view_logs } from "$lib/utils/logs" - import Output from "$lib/ui/output.svelte" + import ErrorDetailsBlock from "$lib/ui/error_details_block.svelte" $: project_id = $page.params.project_id! @@ -374,53 +373,16 @@ {#if error} -
-      <div class="font-medium text-error">Could Not Connect to MCP Server</div>
-      <div class="mt-4">
-        <div class="font-medium">Troubleshooting Steps</div>
-        <ol class="list-decimal list-inside">
-          <li>Check the Error Details below for information about the issue.</li>
-          <li>
-            Check the server's documentation for the correct setup
-            (dependencies, etc.).
-          </li>
-          <li>
-            Ensure your command <code>{command} {args}</code> runs in your
-            terminal. If you had to install libraries or dependencies,
-            restart the Kiln app before trying again.
-          </li>
-          <li>Check Kiln logs for additional details.</li>
-        </ol>
-      </div>
-      <div class="mt-4">
-        <div class="font-medium">Error Details</div>
-        <Output raw_output={error.getMessage() || "Unknown error"} />
-      </div>
+ {/if} diff --git a/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/add_tools/remote_mcp/edit_remote_tool.svelte b/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/add_tools/remote_mcp/edit_remote_tool.svelte index 0373cbf90..182e745a1 100644 --- a/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/add_tools/remote_mcp/edit_remote_tool.svelte +++ b/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/add_tools/remote_mcp/edit_remote_tool.svelte @@ -12,6 +12,7 @@ import type { ExternalToolServerApiDescription } from "$lib/types" import Warning from "$lib/ui/warning.svelte" import posthog from "posthog-js" + import ErrorDetailsBlock from "$lib/ui/error_details_block.svelte" $: project_id = $page.params.project_id! @@ -230,7 +231,6 @@ + + {#if error} + + {/if} diff --git a/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/edit_tool_server/[tool_server_id]/+page.svelte b/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/edit_tool_server/[tool_server_id]/+page.svelte index 67802aaca..0d087b352 100644 --- a/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/edit_tool_server/[tool_server_id]/+page.svelte +++ b/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/edit_tool_server/[tool_server_id]/+page.svelte @@ -34,7 +34,7 @@ // Fetch the specific tool by ID const { data, error: fetch_error } = await client.GET( - "/api/projects/{project_id}/tool_servers/{tool_server_id}", + "/api/projects/{project_id}/tool_servers/{tool_server_id}/config", { params: { path: { diff --git a/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/tool_servers/[tool_server_id]/+page.svelte b/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/tool_servers/[tool_server_id]/+page.svelte index 731d25e63..232ca18e8 100644 --- a/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/tool_servers/[tool_server_id]/+page.svelte +++ b/app/web_ui/src/routes/(app)/settings/manage_tools/[project_id]/tool_servers/[tool_server_id]/+page.svelte @@ -19,6 +19,7 @@ import { selected_tool_for_task } from "$lib/stores/tools_store" import TableButton from "../../../../../generate/[project_id]/[task_id]/table_button.svelte" import Float from "$lib/ui/float.svelte" + import ErrorDetailsBlock from "$lib/ui/error_details_block.svelte" $: project_id = $page.params.project_id! $: tool_server_id = $page.params.tool_server_id! @@ -409,7 +410,7 @@ action_buttons={[ { label: "Edit", - href: `/settings/manage_tools/${project_id}/edit_tool_server/${tool_server?.id}`, + href: `/settings/manage_tools/${project_id}/edit_tool_server/${tool_server_id}`, }, { label: is_archived ? "Unarchive" : "Archive", @@ -489,17 +490,17 @@
      {:else if loading_error}
-       <div class="text-error font-medium">Error Loading Tool</div>
-       <div class="text-sm">
-         {loading_error.getMessage() || "An unknown error occurred"}
-       </div>
+       <ErrorDetailsBlock
+         title="Error Loading Tool"
+         error={loading_error}
+       />
      {:else if tool_server}
diff --git a/app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/select_template/+page.svelte b/app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/select_template/+page.svelte index 70d32303b..235e1df4c 100644 --- a/app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/select_template/+page.svelte +++ b/app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/select_template/+page.svelte @@ -41,6 +41,7 @@ task_id, ) current_params.set("tool_function_name", tool_function_name) + current_params.set("tool_id", selected_tool) } goto( `/specs/${project_id}/${task_id}/spec_builder?${current_params.toString()}`, diff --git a/app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/spec_builder/+page.svelte b/app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/spec_builder/+page.svelte index b7faed935..1506bdbab 100644 --- a/app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/spec_builder/+page.svelte +++ b/app/web_ui/src/routes/(app)/specs/[project_id]/[task_id]/spec_builder/+page.svelte @@ -78,6 +78,9 @@ let initial_property_values: Record = {} let evaluate_full_trace = false + // Tool use spec: tool_id is not a form field but is required in the saved properties + let selected_tool_id: string | null = null + // Copilot availability let has_kiln_copilot = false let default_run_config_has_tools = false @@ -237,13 +240,14 @@ property_values = values initial_property_values = { ...values } - // Override tool_function_name if provided in URL + // Override tool fields if provided in URL const tool_function_name_param = $page.url.searchParams.get("tool_function_name") if (tool_function_name_param) { property_values["tool_function_name"] = tool_function_name_param initial_property_values["tool_function_name"] = tool_function_name_param } + selected_tool_id = $page.url.searchParams.get("tool_id") } catch (e) { loading_error = createKilnError(e) } finally { @@ -379,6 +383,7 @@ const properties = { spec_type: spec_type, ...filteredValues, + ...(selected_tool_id ? { tool_id: selected_tool_id } : {}), } as SpecProperties // Call the appropriate endpoint based on whether copilot is being used diff --git a/checks.sh b/checks.sh index bb53d64cf..f753b19e2 100755 --- a/checks.sh +++ b/checks.sh @@ -31,14 +31,14 @@ echo $PWD headerStart="\n\033[4;34m=== " headerEnd=" ===\033[0m\n" -echo "${headerStart}Checking Python: uvx ruff check ${headerEnd}" -uvx ruff check +echo "${headerStart}Checking Python: uv run ruff check ${headerEnd}" +uv run ruff check -echo "${headerStart}Checking Python: uvx ruff format --check ${headerEnd}" -uvx ruff format --check . +echo "${headerStart}Checking Python: uv run ruff format --check ${headerEnd}" +uv run ruff format --check . -echo "${headerStart}Checking Python Types: uvx ty check${headerEnd}" -uvx ty check +echo "${headerStart}Checking Python Types: uv run ty check${headerEnd}" +uv run ty check echo "${headerStart}Checking for Misspellings${headerEnd}" if command -v misspell >/dev/null 2>&1; then diff --git a/hooks_mcp.yaml b/hooks_mcp.yaml index ff383c2dc..8374e6997 100644 --- a/hooks_mcp.yaml +++ b/hooks_mcp.yaml @@ -12,23 +12,23 @@ actions: - name: "lint_python" description: "Lint the python source code, checking for errors and warnings" - command: "uvx ruff check" + command: "uv run ruff check" - name: "lint_fix_python" description: "Lint the pythong source code, fixing errors and warnings which it can fix. Not all errors can be fixed automatically." 
- command: "uvx ruff check --fix" + command: "uv run ruff check --fix" - name: "check_format_python" description: "Check if the python source code is formatted correctly" - command: "uvx ruff format --check ." + command: "uv run ruff format --check ." - name: "format_python" description: "Format the python source code" - command: "uvx ruff format ." + command: "uv run ruff format ." - name: "typecheck_python" description: "Typecheck the source code" - command: "uvx ty check" + command: "uv run ty check" - name: "test_file_python" description: "Run tests in a specific python file or directory" diff --git a/libs/core/README.md b/libs/core/README.md index 76930c83b..a4ea0b7eb 100644 --- a/libs/core/README.md +++ b/libs/core/README.md @@ -48,6 +48,7 @@ The library has a [comprehensive set of docs](https://kiln-ai.github.io/Kiln/kil - [Load a LlamaIndex Vector Store](#load-a-llamaindex-vector-store) - [Example: LanceDB Cloud](#example-lancedb-cloud) - [Deploy RAG without LlamaIndex](#deploy-rag-without-llamaindex) +- [Connecting to Telemetry / MLOps](#connecting-to-telemetry--mlops-opentelemetry-braintrust-etc) - [Full API Reference](#full-api-reference) ## Installation @@ -432,6 +433,18 @@ After export, query your data using [LlamaIndex](https://developers.llamaindex.a While Kiln is designed for deploying to LlamaIndex, you don't need to use it. The `iter_llama_index_nodes` returns a `TextNode` object which includes all the data you need to build a RAG index in any stack: embedding, text, document name, chunk ID, etc. +## Connecting to Telemetry / MLOps (OpenTelemetry, Braintrust, etc) + +You can connect the Kiln SDK to a wide range of MLOps, trace tracking, and telemetry platforms. + +Kiln uses the LiteLLM SDK, which supports all major telemetry providers. See LiteLLM's documentation for how to set up each provider: + +- [OpenTelemetry](https://docs.litellm.ai/docs/observability/opentelemetry_integration) +- [Braintrust](https://docs.litellm.ai/docs/observability/braintrust) +- [MLFlow](https://docs.litellm.ai/docs/observability/mlflow) +- [LangFuse](https://docs.litellm.ai/docs/observability/langfuse_otel_integration) +- Many more (see LiteLLM documentation) + ## Full API Reference The library can do a lot more than the examples we've shown here. 
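The new README telemetry section above only links out to LiteLLM's provider docs. As a minimal sketch of what wiring one provider up can look like — assuming the callback names and environment variables from LiteLLM's documentation, none of which are defined in this diff; verify against the linked pages:

import os

import litellm

# Langfuse via LiteLLM's success callback. The env var names come from
# LiteLLM's Langfuse docs; the key values here are placeholders.
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-..."  # placeholder
os.environ["LANGFUSE_SECRET_KEY"] = "sk-..."  # placeholder
litellm.success_callback = ["langfuse"]

# Or OpenTelemetry: LiteLLM's "otel" callback pointed at an OTLP
# collector (endpoint is a placeholder).
# os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "http://localhost:4317"
# litellm.callbacks = ["otel"]

# Kiln adapters call litellm.acompletion under the hood, so Kiln runs
# executed after this setup should be traced by the configured provider.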
diff --git a/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py b/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py index e0138a954..2903b6eee 100644 --- a/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py +++ b/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py @@ -146,6 +146,50 @@ def test_multiturn_formatter_next_turn(): assert formatter.next_turn("assistant response") is None +def test_multiturn_formatter_preserves_tool_call_messages(): + prior_trace = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "4"}, + { + "role": "assistant", + "content": "", + "reasoning_content": "Let me multiply 4 by 7.\n", + "tool_calls": [ + { + "id": "call_abc123", + "function": {"arguments": '{"a": 4, "b": 7}', "name": "multiply"}, + "type": "function", + } + ], + }, + { + "content": "28", + "role": "tool", + "tool_call_id": "call_abc123", + "kiln_task_tool_data": None, + }, + { + "role": "assistant", + "content": "4 multiplied by 7 is 28.", + "reasoning_content": "Done.\n", + }, + ] + formatter = MultiturnFormatter(prior_trace=prior_trace, user_input="now double it") + initial = formatter.initial_messages() + assert initial == prior_trace + assert initial[2]["tool_calls"][0]["id"] == "call_abc123" + assert initial[2]["tool_calls"][0]["function"]["name"] == "multiply" + assert initial[3]["role"] == "tool" + assert initial[3]["tool_call_id"] == "call_abc123" + + first = formatter.next_turn() + assert first is not None + assert len(first.messages) == 1 + assert first.messages[0].role == "user" + assert first.messages[0].content == "now double it" + assert first.final_call + + def test_format_user_message(): # String assert format_user_message("test input") == "test input" diff --git a/libs/core/kiln_ai/adapters/litellm_utils/__init__.py b/libs/core/kiln_ai/adapters/litellm_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py b/libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py new file mode 100644 index 000000000..68b29dd62 --- /dev/null +++ b/libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import Any, AsyncIterator, Optional, Union + +import litellm +from litellm.types.utils import ( + ModelResponse, + ModelResponseStream, + TextCompletionResponse, +) + + +class StreamingCompletion: + """ + Async iterable wrapper around ``litellm.acompletion`` with streaming. + + Yields ``ModelResponseStream`` chunks as they arrive. After iteration + completes, the assembled ``ModelResponse`` is available via the + ``.response`` property. + + Usage:: + + stream = StreamingCompletion(model=..., messages=...) + async for chunk in stream: + # handle chunk however you like (print, log, send over WS, …) + pass + final = stream.response # fully assembled ModelResponse + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs = dict(kwargs) + kwargs.pop("stream", None) + self._args = args + self._kwargs = kwargs + self._response: Optional[Union[ModelResponse, TextCompletionResponse]] = None + self._iterated: bool = False + + @property + def response(self) -> Optional[Union[ModelResponse, TextCompletionResponse]]: + """The final assembled response. Only available after iteration.""" + if not self._iterated: + raise RuntimeError( + "StreamingCompletion has not been iterated yet. 
" + "Use 'async for chunk in stream:' before accessing .response" + ) + return self._response + + async def __aiter__(self) -> AsyncIterator[ModelResponseStream]: + self._response = None + self._iterated = False + + chunks: list[ModelResponseStream] = [] + stream = await litellm.acompletion(*self._args, stream=True, **self._kwargs) + + async for chunk in stream: + chunks.append(chunk) + yield chunk + + self._response = litellm.stream_chunk_builder(chunks) + self._iterated = True diff --git a/libs/core/kiln_ai/adapters/litellm_utils/test_litellm_streaming.py b/libs/core/kiln_ai/adapters/litellm_utils/test_litellm_streaming.py new file mode 100644 index 000000000..e35a51982 --- /dev/null +++ b/libs/core/kiln_ai/adapters/litellm_utils/test_litellm_streaming.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, List +from unittest.mock import MagicMock, patch + +import pytest + +from kiln_ai.adapters.litellm_utils.litellm_streaming import StreamingCompletion + + +def _make_chunk(content: str | None = None, finish_reason: str | None = None) -> Any: + """Build a minimal chunk object matching litellm's streaming shape.""" + delta = SimpleNamespace(content=content, role="assistant") + choice = SimpleNamespace(delta=delta, finish_reason=finish_reason, index=0) + return SimpleNamespace(choices=[choice], id="chatcmpl-test", model="test-model") + + +async def _async_iter(items: List[Any]): + """Turn a plain list into an async iterator.""" + for item in items: + yield item + + +@pytest.fixture +def mock_acompletion(): + with patch("litellm.acompletion") as mock: + yield mock + + +@pytest.fixture +def mock_chunk_builder(): + with patch("litellm.stream_chunk_builder") as mock: + yield mock + + +class TestStreamingCompletion: + async def test_yields_all_chunks(self, mock_acompletion, mock_chunk_builder): + chunks = [_make_chunk("Hello"), _make_chunk(" world"), _make_chunk("!")] + mock_acompletion.return_value = _async_iter(chunks) + mock_chunk_builder.return_value = MagicMock(name="final_response") + + stream = StreamingCompletion(model="test", messages=[]) + received = [chunk async for chunk in stream] + + assert received == chunks + + async def test_response_available_after_iteration( + self, mock_acompletion, mock_chunk_builder + ): + chunks = [_make_chunk("hi")] + mock_acompletion.return_value = _async_iter(chunks) + sentinel = MagicMock(name="final_response") + mock_chunk_builder.return_value = sentinel + + stream = StreamingCompletion(model="test", messages=[]) + async for _ in stream: + pass + + assert stream.response is sentinel + + async def test_response_raises_before_iteration(self): + stream = StreamingCompletion(model="test", messages=[]) + with pytest.raises(RuntimeError, match="not been iterated"): + _ = stream.response + + async def test_stream_kwarg_is_stripped(self, mock_acompletion, mock_chunk_builder): + mock_acompletion.return_value = _async_iter([]) + mock_chunk_builder.return_value = None + + stream = StreamingCompletion(model="test", messages=[], stream=False) + async for _ in stream: + pass + + _, call_kwargs = mock_acompletion.call_args + assert call_kwargs["stream"] is True + + async def test_passes_args_and_kwargs_through( + self, mock_acompletion, mock_chunk_builder + ): + mock_acompletion.return_value = _async_iter([]) + mock_chunk_builder.return_value = None + + stream = StreamingCompletion( + model="gpt-4", messages=[{"role": "user", "content": "hi"}], temperature=0.5 + ) + async for _ in stream: + pass + + _, 
call_kwargs = mock_acompletion.call_args + assert call_kwargs["model"] == "gpt-4" + assert call_kwargs["messages"] == [{"role": "user", "content": "hi"}] + assert call_kwargs["temperature"] == 0.5 + assert call_kwargs["stream"] is True + + async def test_chunks_passed_to_builder(self, mock_acompletion, mock_chunk_builder): + chunks = [_make_chunk("a"), _make_chunk("b")] + mock_acompletion.return_value = _async_iter(chunks) + mock_chunk_builder.return_value = MagicMock() + + stream = StreamingCompletion(model="test", messages=[]) + async for _ in stream: + pass + + mock_chunk_builder.assert_called_once_with(chunks) + + async def test_re_iteration_resets_state( + self, mock_acompletion, mock_chunk_builder + ): + first_chunks = [_make_chunk("first")] + second_chunks = [_make_chunk("second")] + first_response = MagicMock(name="first_response") + second_response = MagicMock(name="second_response") + + mock_acompletion.side_effect = [ + _async_iter(first_chunks), + _async_iter(second_chunks), + ] + mock_chunk_builder.side_effect = [first_response, second_response] + + stream = StreamingCompletion(model="test", messages=[]) + + async for _ in stream: + pass + assert stream.response is first_response + + async for _ in stream: + pass + assert stream.response is second_response + + async def test_empty_stream(self, mock_acompletion, mock_chunk_builder): + mock_acompletion.return_value = _async_iter([]) + mock_chunk_builder.return_value = None + + stream = StreamingCompletion(model="test", messages=[]) + received = [chunk async for chunk in stream] + + assert received == [] + assert stream.response is None diff --git a/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py b/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py new file mode 100644 index 000000000..db4c4642a --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +import copy +import json +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, AsyncIterator + +from litellm.types.utils import ( + ChatCompletionMessageToolCall, + Choices, + ModelResponse, +) + +from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM +from kiln_ai.adapters.chat.chat_formatter import ChatFormatter +from kiln_ai.adapters.litellm_utils.litellm_streaming import StreamingCompletion +from kiln_ai.adapters.ml_model_list import KilnModelProvider +from kiln_ai.adapters.model_adapters.stream_events import ( + AdapterStreamEvent, + ToolCallEvent, + ToolCallEventType, +) +from kiln_ai.adapters.run_output import RunOutput +from kiln_ai.datamodel import Usage + +if TYPE_CHECKING: + from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter + +MAX_CALLS_PER_TURN = 10 +MAX_TOOL_CALLS_PER_TURN = 30 + +logger = logging.getLogger(__name__) + + +@dataclass +class AdapterStreamResult: + run_output: RunOutput + usage: Usage + + +class AdapterStream: + """ + Orchestrates a full task execution as an async iterator, + composing StreamingCompletion instances across chat turns and tool-call rounds. + + Yields ``ModelResponseStream`` chunks from each LLM call and + ``ToolCallEvent`` instances between tool-call rounds. + + After iteration completes the ``result`` property provides the + ``AdapterStreamResult`` with the final ``RunOutput`` and ``Usage``. 
+ """ + + def __init__( + self, + adapter: LiteLlmAdapter, + provider: KilnModelProvider, + chat_formatter: ChatFormatter, + initial_messages: list[ChatCompletionMessageIncludingLiteLLM], + top_logprobs: int | None, + ) -> None: + self._adapter = adapter + self._provider = provider + self._chat_formatter = chat_formatter + self._messages = initial_messages + self._top_logprobs = top_logprobs + self._result: AdapterStreamResult | None = None + self._iterated = False + + @property + def result(self) -> AdapterStreamResult: + if not self._iterated: + raise RuntimeError( + "AdapterStream has not been iterated yet. " + "Use 'async for event in stream:' before accessing .result" + ) + if self._result is None: + raise RuntimeError("AdapterStream completed without producing a result") + return self._result + + async def __aiter__(self) -> AsyncIterator[AdapterStreamEvent]: + self._result = None + self._iterated = False + + usage = Usage() + prior_output: str | None = None + final_choice: Choices | None = None + turns = 0 + + while True: + turns += 1 + if turns > MAX_CALLS_PER_TURN: + raise RuntimeError( + f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens." + ) + + turn = self._chat_formatter.next_turn(prior_output) + if turn is None: + break + + for message in turn.messages: + if message.content is None: + raise ValueError("Empty message content isn't allowed") + self._messages.append( + {"role": message.role, "content": message.content} # type: ignore[arg-type] + ) + + skip_response_format = not turn.final_call + turn_top_logprobs = self._top_logprobs if turn.final_call else None + + async for event in self._stream_model_turn( + skip_response_format, turn_top_logprobs + ): + if isinstance(event, _ModelTurnComplete): + usage += event.usage + prior_output = event.assistant_message + final_choice = event.model_choice + else: + yield event + + if not prior_output: + raise RuntimeError("No assistant message/output returned from model") + + logprobs = self._adapter._extract_and_validate_logprobs(final_choice) + + intermediate_outputs = self._chat_formatter.intermediate_outputs() + self._adapter._extract_reasoning_to_intermediate_outputs( + final_choice, intermediate_outputs + ) + + if not isinstance(prior_output, str): + raise RuntimeError(f"assistant message is not a string: {prior_output}") + + trace = self._adapter.all_messages_to_trace(self._messages) + self._result = AdapterStreamResult( + run_output=RunOutput( + output=prior_output, + intermediate_outputs=intermediate_outputs, + output_logprobs=logprobs, + trace=trace, + ), + usage=usage, + ) + self._iterated = True + + async def _stream_model_turn( + self, + skip_response_format: bool, + top_logprobs: int | None, + ) -> AsyncIterator[AdapterStreamEvent | _ModelTurnComplete]: + usage = Usage() + tool_calls_count = 0 + + while tool_calls_count < MAX_TOOL_CALLS_PER_TURN: + completion_kwargs = await self._adapter.build_completion_kwargs( + self._provider, + copy.deepcopy(self._messages), + top_logprobs, + skip_response_format, + ) + + stream = StreamingCompletion(**completion_kwargs) + async for chunk in stream: + yield chunk + + response, response_choice = _validate_response(stream.response) + usage += self._adapter.usage_from_response(response) + + content = response_choice.message.content + tool_calls = response_choice.message.tool_calls + if not content and not tool_calls: + raise ValueError( + "Model returned an assistant message, but no content or tool calls. This is not supported." 
+ ) + + self._messages.append(response_choice.message) + + if tool_calls and len(tool_calls) > 0: + async for event in self._handle_tool_calls(tool_calls): + yield event + + assistant_msg = self._extract_task_response(tool_calls) + if assistant_msg is not None: + yield _ModelTurnComplete( + assistant_message=assistant_msg, + model_choice=response_choice, + usage=usage, + ) + return + + tool_calls_count += 1 + continue + + if content: + yield _ModelTurnComplete( + assistant_message=content, + model_choice=response_choice, + usage=usage, + ) + return + + raise RuntimeError( + "Model returned neither content nor tool calls. It must return at least one of these." + ) + + raise RuntimeError( + f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens." + ) + + async def _handle_tool_calls( + self, + tool_calls: list[ChatCompletionMessageToolCall], + ) -> AsyncIterator[AdapterStreamEvent]: + real_tool_calls = [ + tc for tc in tool_calls if tc.function.name != "task_response" + ] + + for tc in real_tool_calls: + try: + parsed_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError): + parsed_args = None + + yield ToolCallEvent( + event_type=ToolCallEventType.INPUT_AVAILABLE, + tool_call_id=tc.id, + tool_name=tc.function.name or "unknown", + arguments=parsed_args, + error=( + f"Failed to parse arguments: {tc.function.arguments}" + if parsed_args is None + else None + ), + ) + + _, tool_msgs = await self._adapter.process_tool_calls(tool_calls) + + for tool_msg in tool_msgs: + tc_id = tool_msg["tool_call_id"] + tc_name = _find_tool_name(tool_calls, tc_id) + content = tool_msg["content"] + yield ToolCallEvent( + event_type=ToolCallEventType.OUTPUT_AVAILABLE, + tool_call_id=tc_id, + tool_name=tc_name, + result=str(content) if content is not None else None, + ) + + self._messages.extend(tool_msgs) + + @staticmethod + def _extract_task_response( + tool_calls: list[ChatCompletionMessageToolCall], + ) -> str | None: + for tc in tool_calls: + if tc.function.name == "task_response": + return tc.function.arguments + return None + + +@dataclass +class _ModelTurnComplete: + """Internal sentinel yielded when a model turn finishes.""" + + assistant_message: str + model_choice: Choices | None + usage: Usage + + +def _validate_response( + response: Any, +) -> tuple[ModelResponse, Choices]: + if ( + not isinstance(response, ModelResponse) + or not response.choices + or len(response.choices) == 0 + or not isinstance(response.choices[0], Choices) + ): + raise RuntimeError( + f"Expected ModelResponse with Choices, got {type(response)}." 
+ ) + return response, response.choices[0] + + +def _find_tool_name( + tool_calls: list[ChatCompletionMessageToolCall], tool_call_id: str +) -> str: + for tc in tool_calls: + if tc.id == tool_call_id: + return tc.function.name or "unknown" + return "unknown" diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index 7cda8c6dc..d9633ac81 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -1,7 +1,12 @@ +from __future__ import annotations + import json +import uuid from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Dict, Tuple +from typing import TYPE_CHECKING, AsyncIterator, Dict, Tuple + +from litellm.types.utils import ModelResponseStream from kiln_ai.adapters.chat.chat_formatter import ( ChatFormatter, @@ -13,6 +18,13 @@ StructuredOutputMode, default_structured_output_mode_for_model_provider, ) +from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStreamResult +from kiln_ai.adapters.model_adapters.stream_events import ( + AiSdkEventType, + AiSdkStreamConverter, + AiSdkStreamEvent, + ToolCallEvent, +) from kiln_ai.adapters.parsers.json_parser import parse_json_string from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id @@ -49,6 +61,9 @@ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam +if TYPE_CHECKING: + from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream + @dataclass class AdapterConfig: @@ -127,10 +142,10 @@ async def invoke( self, input: InputType, input_source: DataSource | None = None, - existing_run: TaskRun | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, ) -> TaskRun: run_output, _ = await self.invoke_returning_run_output( - input, input_source, existing_run + input, input_source, prior_trace ) return run_output @@ -138,7 +153,7 @@ async def _run_returning_run_output( self, input: InputType, input_source: DataSource | None = None, - existing_run: TaskRun | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, ) -> Tuple[TaskRun, RunOutput]: # validate input, allowing arrays if self.input_schema is not None: @@ -149,14 +164,7 @@ async def _run_returning_run_output( require_object=False, ) - if existing_run is not None and ( - not existing_run.trace or len(existing_run.trace) == 0 - ): - raise ValueError( - "Run has no trace. Cannot continue session without conversation history." - ) - - prior_trace = existing_run.trace if existing_run else None + prior_trace = prior_trace if prior_trace else None # Format model input for model call (we save the original input in the task without formatting) formatted_input = input @@ -215,28 +223,9 @@ async def _run_returning_run_output( "Reasoning is required for this model, but no reasoning was returned." 
) - # Create the run and output - merge if there is an existing run - if existing_run is not None: - merged_output = RunOutput( - output=parsed_output.output, - intermediate_outputs=parsed_output.intermediate_outputs - or run_output.intermediate_outputs, - output_logprobs=parsed_output.output_logprobs - or run_output.output_logprobs, - trace=run_output.trace, - ) - run = self.generate_run( - input, - input_source, - merged_output, - usage, - run_output.trace, - existing_run=existing_run, - ) - else: - run = self.generate_run( - input, input_source, parsed_output, usage, run_output.trace - ) + run = self.generate_run( + input, input_source, parsed_output, usage, run_output.trace + ) # Save the run if configured to do so, and we have a path to save to if ( @@ -245,7 +234,7 @@ async def _run_returning_run_output( and self.task.path is not None ): run.save_to_file() - elif existing_run is None: + else: # Clear the ID to indicate it's not persisted run.id = None @@ -255,7 +244,7 @@ async def invoke_returning_run_output( self, input: InputType, input_source: DataSource | None = None, - existing_run: TaskRun | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, ) -> Tuple[TaskRun, RunOutput]: # Determine if this is the root agent (no existing run context) is_root_agent = get_agent_run_id() is None @@ -266,7 +255,7 @@ async def invoke_returning_run_output( try: return await self._run_returning_run_output( - input, input_source, existing_run + input, input_source, prior_trace ) finally: if is_root_agent: @@ -277,6 +266,134 @@ async def invoke_returning_run_output( finally: clear_agent_run_id() + def invoke_openai_stream( + self, + input: InputType, + input_source: DataSource | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> OpenAIStreamResult: + """Stream raw OpenAI-protocol chunks for the task execution. + + Returns an async-iterable that yields ``ModelResponseStream`` chunks + as they arrive from the model. After the iterator is exhausted the + run has been validated and saved (when configured). The resulting + ``TaskRun`` is available via the ``.task_run`` property. + + Tool-call rounds happen internally and are not surfaced; use + ``invoke_ai_sdk_stream`` if you need tool-call events. + """ + return OpenAIStreamResult(self, input, input_source, prior_trace) + + def invoke_ai_sdk_stream( + self, + input: InputType, + input_source: DataSource | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> AiSdkStreamResult: + """Stream AI SDK protocol events for the task execution. + + Returns an async-iterable that yields ``AiSdkStreamEvent`` instances + covering text, reasoning, tool-call lifecycle, step boundaries, and + control events. After the iterator is exhausted the resulting + ``TaskRun`` is available via the ``.task_run`` property. + """ + return AiSdkStreamResult(self, input, input_source, prior_trace) + + def _prepare_stream( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None, + ) -> AdapterStream: + if self.input_schema is not None: + validate_schema_with_value_error( + input, + self.input_schema, + "This task requires a specific input schema. While the model produced JSON, that JSON didn't meet the schema. 
Search 'Troubleshooting Structured Data Issues' in our docs for more information.", + require_object=False, + ) + + prior_trace = prior_trace if prior_trace else None + + formatted_input = input + formatter_id = self.model_provider().formatter + if formatter_id is not None: + formatter = request_formatter_from_id(formatter_id) + formatted_input = formatter.format_input(input) + + return self._create_run_stream(formatted_input, prior_trace) + + def _finalize_stream( + self, + adapter_stream: AdapterStream, + input: InputType, + input_source: DataSource | None, + prior_trace: list[ChatCompletionMessageParam] | None, + ) -> TaskRun: + """Streaming invocations are only concerned with passing through events as they come in. + At the end of the stream, we still need to validate the output, create a run and everything + else that a non-streaming invocation would do. + """ + + result: AdapterStreamResult = adapter_stream.result + run_output = result.run_output + usage = result.usage + + provider = self.model_provider() + parser = model_parser_from_id(provider.parser) + parsed_output = parser.parse_output(original_output=run_output) + + if self.output_schema is not None: + if isinstance(parsed_output.output, str): + parsed_output.output = parse_json_string(parsed_output.output) + if not isinstance(parsed_output.output, dict): + raise RuntimeError( + f"structured response is not a dict: {parsed_output.output}" + ) + validate_schema_with_value_error( + parsed_output.output, + self.output_schema, + "This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information.", + ) + else: + if not isinstance(parsed_output.output, str): + raise RuntimeError( + f"response is not a string for non-structured task: {parsed_output.output}" + ) + + trace_has_toolcalls = parsed_output.trace is not None and any( + message.get("role", None) == "tool" for message in parsed_output.trace + ) + if ( + provider.reasoning_capable + and ( + not parsed_output.intermediate_outputs + or "reasoning" not in parsed_output.intermediate_outputs + ) + and not ( + provider.reasoning_optional_for_structured_output + and self.has_structured_output() + ) + and not trace_has_toolcalls + ): + raise RuntimeError( + "Reasoning is required for this model, but no reasoning was returned." + ) + + run = self.generate_run( + input, input_source, parsed_output, usage, run_output.trace + ) + + if ( + self.base_adapter_config.allow_saving + and Config.shared().autosave_runs + and self.task.path is not None + ): + run.save_to_file() + else: + run.id = None + + return run + def has_structured_output(self) -> bool: return self.output_schema is not None @@ -292,6 +409,14 @@ async def _run( ) -> Tuple[RunOutput, Usage | None]: pass + def _create_run_stream( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> AdapterStream: + """Create a stream for the adapter. 
Implementations must override this method to support streaming.""" + raise NotImplementedError("Streaming is not supported for this adapter type") + def build_prompt(self) -> str: if self.prompt_builder is None: raise ValueError("Prompt builder is not available for MCP run config") @@ -372,7 +497,6 @@ def generate_run( run_output: RunOutput, usage: Usage | None = None, trace: list[ChatCompletionMessageParam] | None = None, - existing_run: TaskRun | None = None, ) -> TaskRun: output_str = ( json.dumps(run_output.output, ensure_ascii=False) @@ -395,26 +519,6 @@ def generate_run( ), ) - if existing_run is not None: - accumulated_usage = existing_run.usage - if usage is not None: - if accumulated_usage is not None: - accumulated_usage = accumulated_usage + usage - else: - accumulated_usage = usage - - merged_intermediate = dict(existing_run.intermediate_outputs or {}) - if run_output.intermediate_outputs: - for k, v in run_output.intermediate_outputs.items(): - merged_intermediate[k] = v - - existing_run.output = new_output - existing_run.trace = trace - existing_run.usage = accumulated_usage - existing_run.intermediate_outputs = merged_intermediate - - return existing_run - # Convert input and output to JSON strings if they aren't strings input_str = ( input if isinstance(input, str) else json.dumps(input, ensure_ascii=False) @@ -503,3 +607,137 @@ async def available_tools(self) -> list[KilnToolInterface]: ) return tools + + +class OpenAIStreamResult: + """Async-iterable wrapper around the OpenAI streaming flow. + + Yields ``ModelResponseStream`` chunks. After iteration the resulting + ``TaskRun`` is available via the ``.task_run`` property. + """ + + def __init__( + self, + adapter: BaseAdapter, + input: InputType, + input_source: DataSource | None, + prior_trace: list[ChatCompletionMessageParam] | None, + ) -> None: + self._adapter = adapter + self._input = input + self._input_source = input_source + self._prior_trace = prior_trace + self._task_run: TaskRun | None = None + + @property + def task_run(self) -> TaskRun: + if self._task_run is None: + raise RuntimeError( + "Stream has not been fully consumed yet. " + "Iterate over the stream before accessing .task_run" + ) + return self._task_run + + async def __aiter__(self) -> AsyncIterator[ModelResponseStream]: + self._task_run = None + is_root_agent = get_agent_run_id() is None + if is_root_agent: + set_agent_run_id(generate_agent_run_id()) + + try: + adapter_stream = self._adapter._prepare_stream( + self._input, self._prior_trace + ) + + async for event in adapter_stream: + if isinstance(event, ModelResponseStream): + yield event + + self._task_run = self._adapter._finalize_stream( + adapter_stream, self._input, self._input_source, self._prior_trace + ) + finally: + if is_root_agent: + try: + run_id = get_agent_run_id() + if run_id: + await MCPSessionManager.shared().cleanup_session(run_id) + finally: + clear_agent_run_id() + + +class AiSdkStreamResult: + """Async-iterable wrapper around the AI SDK streaming flow. + + Yields ``AiSdkStreamEvent`` instances. After iteration the resulting + ``TaskRun`` is available via the ``.task_run`` property. 
+ """ + + def __init__( + self, + adapter: BaseAdapter, + input: InputType, + input_source: DataSource | None, + prior_trace: list[ChatCompletionMessageParam] | None, + ) -> None: + self._adapter = adapter + self._input = input + self._input_source = input_source + self._prior_trace = prior_trace + self._task_run: TaskRun | None = None + + @property + def task_run(self) -> TaskRun: + if self._task_run is None: + raise RuntimeError( + "Stream has not been fully consumed yet. " + "Iterate over the stream before accessing .task_run" + ) + return self._task_run + + async def __aiter__(self) -> AsyncIterator[AiSdkStreamEvent]: + self._task_run = None + is_root_agent = get_agent_run_id() is None + if is_root_agent: + set_agent_run_id(generate_agent_run_id()) + + try: + adapter_stream = self._adapter._prepare_stream( + self._input, self._prior_trace + ) + + message_id = f"msg-{uuid.uuid4().hex}" + converter = AiSdkStreamConverter() + + yield AiSdkStreamEvent(AiSdkEventType.START, {"messageId": message_id}) + yield AiSdkStreamEvent(AiSdkEventType.START_STEP) + + last_event_was_tool_call = False + async for event in adapter_stream: + if isinstance(event, ModelResponseStream): + if last_event_was_tool_call: + converter.reset_for_next_step() + last_event_was_tool_call = False + for ai_event in converter.convert_chunk(event): + yield ai_event + elif isinstance(event, ToolCallEvent): + last_event_was_tool_call = True + for ai_event in converter.convert_tool_event(event): + yield ai_event + + for ai_event in converter.finalize(): + yield ai_event + + yield AiSdkStreamEvent(AiSdkEventType.FINISH_STEP) + + self._task_run = self._adapter._finalize_stream( + adapter_stream, self._input, self._input_source, self._prior_trace + ) + finally: + if is_root_agent: + try: + run_id = get_agent_run_id() + if run_id: + await MCPSessionManager.shared().cleanup_session(run_id) + finally: + clear_agent_run_id() diff --git a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py index 131be097c..cebccdf57 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py @@ -25,6 +25,7 @@ ModelProviderName, StructuredOutputMode, ) +from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream from kiln_ai.adapters.model_adapters.base_adapter import ( AdapterConfig, BaseAdapter, @@ -261,6 +262,28 @@ async def _run( return output, usage + def _create_run_stream( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> AdapterStream: + provider = self.model_provider() + if not provider.model_id: + raise ValueError("Model ID is required for OpenAI compatible models") + + chat_formatter = self.build_chat_formatter(input, prior_trace) + initial_messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy( + chat_formatter.initial_messages() + ) + + return AdapterStream( + adapter=self, + provider=provider, + chat_formatter=chat_formatter, + initial_messages=initial_messages, + top_logprobs=self.base_adapter_config.top_logprobs, + ) + def _extract_and_validate_logprobs( self, final_choice: Choices | None ) -> ChoiceLogprobs | None: @@ -297,9 +320,10 @@ def _extract_reasoning_to_intermediate_outputs( intermediate_outputs["reasoning"] = stripped_reasoning_content async def acompletion_checking_response( - self, **kwargs + self, **kwargs: Any ) -> Tuple[ModelResponse, Choices]: response = await litellm.acompletion(**kwargs) + if ( 
not isinstance(response, ModelResponse) or not response.choices diff --git a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py index 45aabc53e..c488e7fc0 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py @@ -1,7 +1,10 @@ import json from typing import Tuple -from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig, BaseAdapter +from kiln_ai.adapters.model_adapters.base_adapter import ( + AdapterConfig, + BaseAdapter, +) from kiln_ai.adapters.parsers.json_parser import parse_json_string from kiln_ai.adapters.run_output import RunOutput from kiln_ai.datamodel import DataSource, Task, TaskRun, Usage @@ -85,16 +88,16 @@ async def invoke( self, input: InputType, input_source: DataSource | None = None, - existing_run: TaskRun | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, ) -> TaskRun: - if existing_run is not None: + if prior_trace: raise NotImplementedError( "Session continuation is not supported for MCP adapter. " "MCP tools are single-turn and do not maintain conversation state." ) run_output, _ = await self.invoke_returning_run_output( - input, input_source, existing_run + input, input_source, prior_trace ) return run_output @@ -102,13 +105,13 @@ async def invoke_returning_run_output( self, input: InputType, input_source: DataSource | None = None, - existing_run: TaskRun | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Runs the task and returns both the persisted TaskRun and raw RunOutput. If this call is the root of a run, it creates an agent run context, ensures MCP tool calls have a valid session scope, and cleans up the session/context on completion. """ - if existing_run is not None: + if prior_trace: raise NotImplementedError( "Session continuation is not supported for MCP adapter. " "MCP tools are single-turn and do not maintain conversation state." 
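The base_adapter and litellm_adapter changes above replace the old existing_run continuation parameter with prior_trace and add two streaming entry points. A minimal sketch of consuming the OpenAI-protocol stream, assuming an already-constructed adapter whose class implements _create_run_stream (e.g. the LiteLlmAdapter wired up above) and a plain-string task input; only names defined in this diff are used:

from litellm.types.utils import ModelResponseStream


async def stream_task(adapter, task_input: str) -> None:
    # Raw OpenAI-protocol chunks; tool-call rounds happen internally and
    # are not surfaced on this stream (use invoke_ai_sdk_stream for those).
    stream = adapter.invoke_openai_stream(task_input)
    async for chunk in stream:
        if isinstance(chunk, ModelResponseStream) and chunk.choices:
            delta = chunk.choices[0].delta
            if delta is not None and delta.content:
                print(delta.content, end="", flush=True)
    # The TaskRun is only populated once the stream is fully consumed.
    run = stream.task_run
    print("\nsaved run id:", run.id)

# For AI SDK protocol events (text/reasoning/tool-call lifecycle), iterate
# adapter.invoke_ai_sdk_stream(task_input) instead; each AiSdkStreamEvent
# exposes its wire shape via model_dump(). Both entry points also accept
# the same prior_trace list as invoke(), to continue a prior session.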
diff --git a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py new file mode 100644 index 000000000..b2784484a --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from litellm.types.utils import ModelResponseStream + + +class AiSdkEventType(str, Enum): + START = "start" + FINISH = "finish" + ERROR = "error" + ABORT = "abort" + + TEXT_START = "text-start" + TEXT_DELTA = "text-delta" + TEXT_END = "text-end" + + REASONING_START = "reasoning-start" + REASONING_DELTA = "reasoning-delta" + REASONING_END = "reasoning-end" + + TOOL_INPUT_START = "tool-input-start" + TOOL_INPUT_DELTA = "tool-input-delta" + TOOL_INPUT_AVAILABLE = "tool-input-available" + TOOL_INPUT_ERROR = "tool-input-error" + + TOOL_OUTPUT_AVAILABLE = "tool-output-available" + TOOL_OUTPUT_ERROR = "tool-output-error" + + START_STEP = "start-step" + FINISH_STEP = "finish-step" + + METADATA = "metadata" + SOURCE_URL = "source-url" + SOURCE_DOCUMENT = "source-document" + FILE = "file" + + +@dataclass +class AiSdkStreamEvent: + type: AiSdkEventType + payload: dict[str, Any] = field(default_factory=dict) + + def model_dump(self) -> dict[str, Any]: + return { + "type": self.type.value, + **self.payload, + } + + +class ToolCallEventType(str, Enum): + INPUT_AVAILABLE = "input_available" + OUTPUT_AVAILABLE = "output_available" + OUTPUT_ERROR = "output_error" + + +@dataclass +class ToolCallEvent: + event_type: ToolCallEventType + tool_call_id: str + tool_name: str + arguments: dict[str, Any] | None = None + result: str | None = None + error: str | None = None + + +AdapterStreamEvent = ModelResponseStream | ToolCallEvent + + +class AiSdkStreamConverter: + """Stateful converter from OpenAI streaming chunks to AI SDK events.""" + + def __init__(self) -> None: + self._text_started = False + self._text_id = f"text-{uuid.uuid4().hex[:12]}" + self._reasoning_started = False + self._reasoning_id = f"reasoning-{uuid.uuid4().hex[:12]}" + self._reasoning_block_count = 0 + self._tool_calls_state: dict[int, dict[str, Any]] = {} + self._finish_reason: str | None = None + self._usage_data: Any = None + + def convert_chunk(self, chunk: ModelResponseStream) -> list[AiSdkStreamEvent]: + events: list[AiSdkStreamEvent] = [] + + for choice in chunk.choices: + if choice.finish_reason is not None: + self._finish_reason = choice.finish_reason + + delta = choice.delta + if delta is None: + continue + + reasoning_content = getattr(delta, "reasoning_content", None) + if reasoning_content is not None: + if not self._reasoning_started: + self._reasoning_block_count += 1 + self._reasoning_id = f"reasoning-{uuid.uuid4().hex[:12]}" + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_START, + {"id": self._reasoning_id}, + ) + ) + self._reasoning_started = True + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_DELTA, + {"id": self._reasoning_id, "delta": reasoning_content}, + ) + ) + + if delta.content: + if self._reasoning_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_END, + {"id": self._reasoning_id}, + ) + ) + self._reasoning_started = False + + if not self._text_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TEXT_START, + {"id": self._text_id}, + ) + ) + self._text_started = True + events.append( + AiSdkStreamEvent( + AiSdkEventType.TEXT_DELTA, + 
{"id": self._text_id, "delta": delta.content}, + ) + ) + + if delta.tool_calls: + if self._reasoning_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_END, + {"id": self._reasoning_id}, + ) + ) + self._reasoning_started = False + + for tc_delta in delta.tool_calls: + idx = tc_delta.index + tc_state = self._tool_calls_state.setdefault( + idx, + { + "id": None, + "name": None, + "arguments": "", + "started": False, + }, + ) + + if tc_delta.id is not None: + tc_state["id"] = tc_delta.id + + func = getattr(tc_delta, "function", None) + if func is not None: + if func.name is not None: + tc_state["name"] = func.name + if func.arguments: + tc_state["arguments"] += func.arguments + + if tc_state["id"] and tc_state["name"] and not tc_state["started"]: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_INPUT_START, + { + "toolCallId": tc_state["id"], + "toolName": tc_state["name"], + }, + ) + ) + tc_state["started"] = True + + if func and func.arguments and tc_state["id"]: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_INPUT_DELTA, + { + "toolCallId": tc_state["id"], + "inputTextDelta": func.arguments, + }, + ) + ) + + if not chunk.choices: + usage = getattr(chunk, "usage", None) + if usage is not None: + self._usage_data = usage + + return events + + def convert_tool_event(self, event: ToolCallEvent) -> list[AiSdkStreamEvent]: + events: list[AiSdkStreamEvent] = [] + + if event.event_type == ToolCallEventType.INPUT_AVAILABLE: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_INPUT_AVAILABLE, + { + "toolCallId": event.tool_call_id, + "toolName": event.tool_name, + "input": event.arguments or {}, + }, + ) + ) + elif event.event_type == ToolCallEventType.OUTPUT_AVAILABLE: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_OUTPUT_AVAILABLE, + { + "toolCallId": event.tool_call_id, + "output": event.result, + }, + ) + ) + elif event.event_type == ToolCallEventType.OUTPUT_ERROR: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_OUTPUT_ERROR, + { + "toolCallId": event.tool_call_id, + "errorText": event.error or "Unknown error", + }, + ) + ) + + return events + + def finalize(self) -> list[AiSdkStreamEvent]: + events: list[AiSdkStreamEvent] = [] + + if self._reasoning_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_END, + {"id": self._reasoning_id}, + ) + ) + self._reasoning_started = False + + if self._text_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TEXT_END, + {"id": self._text_id}, + ) + ) + self._text_started = False + + finish_payload: dict[str, Any] = {} + if self._finish_reason is not None: + finish_payload["finishReason"] = self._finish_reason.replace("_", "-") + + if self._usage_data is not None: + usage_payload: dict[str, Any] = { + "promptTokens": self._usage_data.prompt_tokens, + "completionTokens": self._usage_data.completion_tokens, + } + total = getattr(self._usage_data, "total_tokens", None) + if total is not None: + usage_payload["totalTokens"] = total + finish_payload["usage"] = usage_payload + + if finish_payload: + events.append( + AiSdkStreamEvent( + AiSdkEventType.FINISH, + {"messageMetadata": finish_payload}, + ) + ) + else: + events.append(AiSdkStreamEvent(AiSdkEventType.FINISH)) + + return events + + def reset_for_next_step(self) -> None: + """Reset per-step state between LLM calls in a multi-step flow.""" + self._tool_calls_state = {} + self._finish_reason = None diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py 
b/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py new file mode 100644 index 000000000..25715645c --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py @@ -0,0 +1,372 @@ +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from litellm.types.utils import ( + ChatCompletionMessageToolCall, + Choices, + Delta, + Function, + ModelResponse, + ModelResponseStream, + StreamingChoices, +) +from litellm.types.utils import Message as LiteLLMMessage + +from kiln_ai.adapters.chat import ChatFormatter +from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream +from kiln_ai.adapters.model_adapters.stream_events import ( + ToolCallEvent, + ToolCallEventType, +) +from kiln_ai.datamodel import Usage + + +def _make_streaming_chunk( + content: str | None = None, + finish_reason: str | None = None, +) -> ModelResponseStream: + delta = Delta(content=content) + choice = StreamingChoices( + index=0, + delta=delta, + finish_reason=finish_reason, + ) + return ModelResponseStream(id="test-stream", choices=[choice]) + + +def _make_model_response( + content: str = "Hello", + tool_calls: list[ChatCompletionMessageToolCall] | None = None, +) -> ModelResponse: + message = LiteLLMMessage(content=content, role="assistant") + if tool_calls is not None: + message.tool_calls = tool_calls + choice = Choices( + index=0, + message=message, + finish_reason="stop" if tool_calls is None else "tool_calls", + ) + return ModelResponse(id="test-response", choices=[choice]) + + +def _make_tool_call( + call_id: str = "call_1", + name: str = "add", + arguments: dict[str, Any] | None = None, +) -> ChatCompletionMessageToolCall: + args = json.dumps(arguments or {"a": 1, "b": 2}) + return ChatCompletionMessageToolCall( + id=call_id, + type="function", + function=Function(name=name, arguments=args), + ) + + +class FakeChatFormatter(ChatFormatter): + """A simple chat formatter that returns a single turn then None.""" + + def __init__(self, num_turns: int = 1): + self._turn_count = 0 + self._num_turns = num_turns + + def next_turn(self, prior_output: str | None): + if self._turn_count >= self._num_turns: + return None + self._turn_count += 1 + turn = MagicMock() + turn.messages = [MagicMock(role="user", content="test input")] + turn.final_call = self._turn_count == self._num_turns + return turn + + def intermediate_outputs(self): + return {} + + +class FakeStreamingCompletion: + """Mocks StreamingCompletion: yields chunks, then exposes .response""" + + def __init__( + self, + model_response: ModelResponse, + chunks: list[ModelResponseStream] | None = None, + ): + self._chunks = chunks or [ + _make_streaming_chunk(content="Hel"), + _make_streaming_chunk(content="lo"), + _make_streaming_chunk(finish_reason="stop"), + ] + self._response = model_response + + @property + def response(self): + return self._response + + async def __aiter__(self): + for chunk in self._chunks: + yield chunk + + +@pytest.fixture +def mock_adapter(): + adapter = MagicMock() + adapter.build_completion_kwargs = AsyncMock(return_value={"model": "test"}) + adapter.usage_from_response = MagicMock(return_value=Usage()) + adapter.process_tool_calls = AsyncMock(return_value=(None, [])) + adapter._extract_and_validate_logprobs = MagicMock(return_value=None) + adapter._extract_reasoning_to_intermediate_outputs = MagicMock() + adapter.all_messages_to_trace = MagicMock(return_value=[]) + adapter.base_adapter_config = MagicMock() + 
adapter.base_adapter_config.top_logprobs = None + return adapter + + +@pytest.fixture +def mock_provider(): + provider = MagicMock() + provider.model_id = "test-model" + return provider + + +class TestAdapterStreamSimple: + @pytest.mark.asyncio + async def test_simple_content_response(self, mock_adapter, mock_provider): + response = _make_model_response(content="Hello world") + fake_stream = FakeStreamingCompletion(response) + formatter = FakeChatFormatter() + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + return_value=fake_stream, + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=formatter, + initial_messages=[], + top_logprobs=None, + ) + + events = [] + async for event in stream: + events.append(event) + + chunks = [e for e in events if isinstance(e, ModelResponseStream)] + assert len(chunks) == 3 + + result = stream.result + assert result.run_output.output == "Hello world" + + @pytest.mark.asyncio + async def test_result_not_available_before_iteration( + self, mock_adapter, mock_provider + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + with pytest.raises(RuntimeError, match="not been iterated"): + _ = stream.result + + +class TestAdapterStreamToolCalls: + @pytest.mark.asyncio + async def test_tool_call_yields_events(self, mock_adapter, mock_provider): + tool_call = _make_tool_call( + call_id="call_1", name="add", arguments={"a": 1, "b": 2} + ) + tool_response = _make_model_response(content=None, tool_calls=[tool_call]) + final_response = _make_model_response(content="The answer is 3") + + tool_stream = FakeStreamingCompletion( + tool_response, + [_make_streaming_chunk(finish_reason="tool_calls")], + ) + final_stream = FakeStreamingCompletion( + final_response, + [ + _make_streaming_chunk(content="The answer is 3"), + _make_streaming_chunk(finish_reason="stop"), + ], + ) + + streams_iter = iter([tool_stream, final_stream]) + + mock_adapter.process_tool_calls = AsyncMock( + return_value=( + None, + [{"role": "tool", "tool_call_id": "call_1", "content": "3"}], + ) + ) + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + side_effect=lambda **kw: next(streams_iter), + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + + events = [] + async for event in stream: + events.append(event) + + tool_events = [e for e in events if isinstance(e, ToolCallEvent)] + assert len(tool_events) == 2 + + input_event = next( + e for e in tool_events if e.event_type == ToolCallEventType.INPUT_AVAILABLE + ) + assert input_event.tool_call_id == "call_1" + assert input_event.tool_name == "add" + assert input_event.arguments == {"a": 1, "b": 2} + + output_event = next( + e for e in tool_events if e.event_type == ToolCallEventType.OUTPUT_AVAILABLE + ) + assert output_event.tool_call_id == "call_1" + assert output_event.result == "3" + + assert stream.result.run_output.output == "The answer is 3" + + @pytest.mark.asyncio + async def test_task_response_tool_call(self, mock_adapter, mock_provider): + task_response_call = _make_tool_call( + call_id="call_tr", name="task_response", arguments={"result": "42"} + ) + response = _make_model_response(content=None, tool_calls=[task_response_call]) + + fake_stream = FakeStreamingCompletion( + response, + 
[_make_streaming_chunk(finish_reason="tool_calls")], + ) + + mock_adapter.process_tool_calls = AsyncMock( + return_value=('{"result": "42"}', []) + ) + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + return_value=fake_stream, + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + + events = [] + async for event in stream: + events.append(event) + + tool_events = [e for e in events if isinstance(e, ToolCallEvent)] + assert len(tool_events) == 0 + + assert stream.result.run_output.output == '{"result": "42"}' + + @pytest.mark.asyncio + async def test_too_many_tool_calls_raises(self, mock_adapter, mock_provider): + tool_call = _make_tool_call() + response = _make_model_response(content=None, tool_calls=[tool_call]) + + mock_adapter.process_tool_calls = AsyncMock( + return_value=( + None, + [{"role": "tool", "tool_call_id": "call_1", "content": "ok"}], + ) + ) + + def make_stream(**kw): + return FakeStreamingCompletion( + response, + [_make_streaming_chunk(finish_reason="tool_calls")], + ) + + with ( + patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + side_effect=make_stream, + ), + patch( + "kiln_ai.adapters.model_adapters.adapter_stream.MAX_TOOL_CALLS_PER_TURN", + 2, + ), + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + + with pytest.raises(RuntimeError, match="Too many tool calls"): + async for _ in stream: + pass + + @pytest.mark.asyncio + async def test_unparseable_tool_call_arguments(self, mock_adapter, mock_provider): + bad_tool_call = ChatCompletionMessageToolCall( + id="call_bad", + type="function", + function=Function(name="add", arguments="not json"), + ) + response = _make_model_response(content=None, tool_calls=[bad_tool_call]) + final_response = _make_model_response(content="fallback") + + tool_stream = FakeStreamingCompletion( + response, + [_make_streaming_chunk(finish_reason="tool_calls")], + ) + final_stream = FakeStreamingCompletion( + final_response, + [ + _make_streaming_chunk(content="fallback"), + _make_streaming_chunk(finish_reason="stop"), + ], + ) + + streams_iter = iter([tool_stream, final_stream]) + + mock_adapter.process_tool_calls = AsyncMock( + return_value=( + None, + [{"role": "tool", "tool_call_id": "call_bad", "content": "error"}], + ) + ) + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + side_effect=lambda **kw: next(streams_iter), + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + + events = [] + async for event in stream: + events.append(event) + + input_events = [ + e + for e in events + if isinstance(e, ToolCallEvent) + and e.event_type == ToolCallEventType.INPUT_AVAILABLE + ] + assert len(input_events) == 1 + assert input_events[0].arguments is None + assert "Failed to parse" in (input_events[0].error or "") diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py index 8b150c68e..b000df81f 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py @@ -1,6 +1,13 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest 
+from litellm.types.utils import ( + ChatCompletionDeltaToolCall, + Delta, + Function, + ModelResponseStream, + StreamingChoices, +) from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode from kiln_ai.adapters.model_adapters.base_adapter import ( @@ -8,14 +15,13 @@ BaseAdapter, RunOutput, ) -from kiln_ai.adapters.prompt_builders import BasePromptBuilder -from kiln_ai.datamodel import ( - DataSource, - DataSourceType, - Task, - TaskOutput, - TaskRun, +from kiln_ai.adapters.model_adapters.stream_events import ( + AiSdkEventType, + ToolCallEvent, + ToolCallEventType, ) +from kiln_ai.adapters.prompt_builders import BasePromptBuilder +from kiln_ai.datamodel import Task, TaskRun from kiln_ai.datamodel.datamodel_enums import ChatStrategy, ModelProviderName from kiln_ai.datamodel.project import Project from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties, ToolsRunConfig @@ -26,7 +32,7 @@ class MockAdapter(BaseAdapter): """Concrete implementation of BaseAdapter for testing""" - async def _run(self, input, prior_trace=None): + async def _run(self, input, **kwargs): return None, None def adapter_name(self) -> str: @@ -239,7 +245,7 @@ async def test_input_formatting( # Mock the _run method to capture the input captured_input = None - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): nonlocal captured_input captured_input = input return RunOutput(output="test output", intermediate_outputs={}), None @@ -437,7 +443,7 @@ def test_build_chat_formatter_with_prior_trace_returns_multiturn_formatter(adapt @pytest.mark.asyncio -async def test_existing_run_without_trace_raises(base_project): +async def test_invoke_with_prior_trace_none_starts_fresh(base_project): task = Task( name="test_task", instruction="test_instruction", @@ -452,29 +458,44 @@ async def test_existing_run_without_trace_raises(base_project): structured_output_mode=StructuredOutputMode.json_schema, ), ) - run_without_trace = TaskRun( - parent=task, - input="hi", - input_source=None, - output=TaskOutput( - output="hello", - source=DataSource( - type=DataSourceType.synthetic, - properties={ - "model_name": "gpt_4o", - "model_provider": "openai", - "adapter_name": "test", - }, + adapter._run = AsyncMock( + return_value=( + RunOutput(output="ok", intermediate_outputs=None, trace=None), + None, + ) + ) + with ( + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", + return_value=MagicMock( + parse_output=MagicMock( + return_value=RunOutput( + output="ok", intermediate_outputs=None, trace=None + ) + ) ), ), - trace=None, - ) - with pytest.raises(ValueError, match="no trace"): - await adapter.invoke("input", existing_run=run_without_trace) + patch( + "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", + ), + patch.object( + adapter, + "model_provider", + return_value=MagicMock( + parser="default", + formatter=None, + reasoning_capable=False, + ), + ), + ): + run = await adapter.invoke("input", prior_trace=None) + assert run.output.output == "ok" + adapter._run.assert_called_once() + assert adapter._run.call_args[1].get("prior_trace") is None @pytest.mark.asyncio -async def test_invoke_returning_run_output_passes_existing_run_to_run( +async def test_invoke_returning_run_output_passes_prior_trace_to_run( adapter, mock_parser, tmp_path ): project = Project(name="proj", path=tmp_path / "proj.kiln") @@ -491,25 +512,12 @@ async def test_invoke_returning_run_output_passes_existing_run_to_run( {"role": "user", "content": "hi"}, 
{"role": "assistant", "content": "hello"}, ] - initial_run = adapter.generate_run( - input="hi", - input_source=None, - run_output=RunOutput( - output="hello", - intermediate_outputs=None, - trace=trace, - ), - trace=trace, - ) - initial_run.save_to_file() - run_id = initial_run.id - assert run_id is not None captured_prior_trace = None - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): nonlocal captured_prior_trace - captured_prior_trace = prior_trace + captured_prior_trace = kwargs.get("prior_trace") return RunOutput(output="ok", intermediate_outputs=None, trace=trace), None adapter._run = mock_run @@ -532,7 +540,7 @@ async def mock_run(input, prior_trace=None): "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", ), ): - await adapter.invoke_returning_run_output("follow-up", existing_run=initial_run) + await adapter.invoke_returning_run_output("follow-up", prior_trace=trace) assert captured_prior_trace == trace @@ -798,7 +806,7 @@ async def test_invoke_sets_run_context(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): # Check that run ID is set during _run run_id = get_agent_run_id() assert run_id is not None @@ -838,7 +846,7 @@ async def test_invoke_clears_run_context_after(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): return RunOutput(output="test output", intermediate_outputs={}), None adapter._run = mock_run @@ -876,7 +884,7 @@ async def test_invoke_clears_run_context_on_error(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method to raise an error - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): # Run ID should be set even when error occurs run_id = get_agent_run_id() assert run_id is not None @@ -913,7 +921,7 @@ async def test_sub_agent_inherits_run(self, adapter, clear_context): set_agent_run_id(parent_run_id) # Mock the _run method to check inherited run ID - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): # Sub-agent should see parent's run ID run_id = get_agent_run_id() assert run_id == parent_run_id @@ -962,7 +970,7 @@ async def test_sub_agent_does_not_create_new_run(self, adapter, clear_context): run_id_during_run = None # Mock the _run method to capture run ID - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): nonlocal run_id_during_run run_id_during_run = get_agent_run_id() return RunOutput(output="test output", intermediate_outputs={}), None @@ -1002,7 +1010,7 @@ async def test_cleanup_session_called_on_completion(self, adapter, clear_context from kiln_ai.adapters.run_output import RunOutput # Mock the _run method - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): return RunOutput(output="test output", intermediate_outputs={}), None adapter._run = mock_run @@ -1045,3 +1053,185 @@ async def mock_run(input, prior_trace=None): assert call_args is not None run_id = call_args[0][0] if call_args[0] else call_args[1]["run_id"] assert run_id.startswith("run_") + + +class TestStreamMethods: + """Tests for the streaming methods on BaseAdapter.""" + + @pytest.fixture + def stream_adapter(self, base_task): + return MockAdapter( + task=base_task, + 
run_config=KilnAgentRunConfigProperties( + model_name="test_model", + model_provider_name="openai", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + ) + + @pytest.mark.asyncio + async def test_invoke_openai_stream_raises_for_unsupported_adapter( + self, stream_adapter + ): + """MockAdapter does not implement _create_run_stream.""" + provider = MagicMock() + provider.formatter = None + stream_adapter.model_provider = MagicMock(return_value=provider) + + with pytest.raises(NotImplementedError, match="Streaming is not supported"): + async for _chunk in stream_adapter.invoke_openai_stream("test input"): + pass + + @pytest.mark.asyncio + async def test_invoke_ai_sdk_stream_raises_for_unsupported_adapter( + self, stream_adapter + ): + """MockAdapter does not implement _create_run_stream.""" + provider = MagicMock() + provider.formatter = None + stream_adapter.model_provider = MagicMock(return_value=provider) + + with pytest.raises(NotImplementedError, match="Streaming is not supported"): + async for _event in stream_adapter.invoke_ai_sdk_stream("test input"): + pass + + @pytest.mark.asyncio + async def test_invoke_ai_sdk_stream_resets_converter_between_tool_rounds( + self, stream_adapter + ): + """tool-input-start must be emitted for a new tool call at index 0 after a tool round.""" + + def _make_tool_chunk(call_id: str, name: str) -> ModelResponseStream: + func = Function(name=name, arguments='{"x":1}') + tc = ChatCompletionDeltaToolCall(index=0, function=func) + tc.id = call_id + delta = Delta(tool_calls=[tc]) + choice = StreamingChoices(index=0, delta=delta, finish_reason=None) + return ModelResponseStream(id="test", choices=[choice]) + + round1_chunk = _make_tool_chunk("call_r1", "tool_a") + round2_chunk = _make_tool_chunk("call_r2", "tool_b") + + fake_events = [ + round1_chunk, + ToolCallEvent( + event_type=ToolCallEventType.INPUT_AVAILABLE, + tool_call_id="call_r1", + tool_name="tool_a", + arguments={"x": 1}, + ), + ToolCallEvent( + event_type=ToolCallEventType.OUTPUT_AVAILABLE, + tool_call_id="call_r1", + tool_name="tool_a", + result="done", + ), + round2_chunk, + ] + + class FakeAdapterStream: + result = MagicMock() + + async def __aiter__(self): + for event in fake_events: + yield event + + with ( + patch.object( + stream_adapter, + "_prepare_stream", + return_value=FakeAdapterStream(), + ), + patch.object(stream_adapter, "_finalize_stream"), + ): + events = [] + async for event in stream_adapter.invoke_ai_sdk_stream("test input"): + events.append(event) + + tool_input_starts = [ + e for e in events if e.type == AiSdkEventType.TOOL_INPUT_START + ] + assert len(tool_input_starts) == 2, ( + "tool-input-start must fire once per tool-call round" + ) + assert tool_input_starts[0].payload["toolCallId"] == "call_r1" + assert tool_input_starts[1].payload["toolCallId"] == "call_r2" + + @pytest.mark.asyncio + async def test_openai_stream_exposes_task_run_after_iteration(self, stream_adapter): + fake_chunk = ModelResponseStream( + id="test", + choices=[ + StreamingChoices( + index=0, + delta=Delta(content="hi"), + finish_reason=None, + ) + ], + ) + + class FakeAdapterStream: + result = MagicMock() + + async def __aiter__(self): + yield fake_chunk + + expected_run = MagicMock(spec=TaskRun) + + with ( + patch.object( + stream_adapter, + "_prepare_stream", + return_value=FakeAdapterStream(), + ), + patch.object(stream_adapter, "_finalize_stream", return_value=expected_run), + ): + stream = stream_adapter.invoke_openai_stream("test input") + + with 
pytest.raises(RuntimeError, match="not been fully consumed"): + _ = stream.task_run + + async for _chunk in stream: + pass + + assert stream.task_run is expected_run + + @pytest.mark.asyncio + async def test_ai_sdk_stream_exposes_task_run_after_iteration(self, stream_adapter): + fake_chunk = ModelResponseStream( + id="test", + choices=[ + StreamingChoices( + index=0, + delta=Delta(content="hi"), + finish_reason=None, + ) + ], + ) + + class FakeAdapterStream: + result = MagicMock() + + async def __aiter__(self): + yield fake_chunk + + expected_run = MagicMock(spec=TaskRun) + + with ( + patch.object( + stream_adapter, + "_prepare_stream", + return_value=FakeAdapterStream(), + ), + patch.object(stream_adapter, "_finalize_stream", return_value=expected_run), + ): + stream = stream_adapter.invoke_ai_sdk_stream("test input") + + with pytest.raises(RuntimeError, match="not been fully consumed"): + _ = stream.task_run + + async for _event in stream: + pass + + assert stream.task_run is expected_run diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py index 3b44812a6..6ec73e6b6 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py @@ -1212,7 +1212,11 @@ async def test_array_input_converted_to_json(tmp_path, config): mock_config_obj.user_id = "test_user" with ( - patch("litellm.acompletion", new=AsyncMock(return_value=mock_response)), + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock(return_value=(mock_response, mock_response.choices[0])), + ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config_obj), ): array_input = [1, 2, 3, 4, 5] @@ -1282,7 +1286,11 @@ async def test_dict_input_converted_to_json(tmp_path, config): mock_config_obj.user_id = "test_user" with ( - patch("litellm.acompletion", new=AsyncMock(return_value=mock_response)), + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock(return_value=(mock_response, mock_response.choices[0])), + ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config_obj), ): dict_input = {"x": 10, "y": 20} @@ -1362,3 +1370,121 @@ async def mock_run_model_turn( assert run_output.trace[1]["content"] == "hello" assert run_output.trace[2]["content"] == "follow-up" assert run_output.trace[3]["content"] == "How can I help?" 
+ + +@pytest.mark.asyncio +async def test_run_with_prior_trace_preserves_tool_calls(mock_task): + """Prior trace containing tool calls should be passed through to the model and preserved in the output trace.""" + config = LiteLlmConfig( + base_url="https://api.test.com", + run_config_properties=KilnAgentRunConfigProperties( + model_name="test-model", + model_provider_name="openai_compatible", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + default_headers={"X-Test": "test"}, + additional_body_options={"api_key": "test_key"}, + ) + + prior_trace = [ + {"role": "system", "content": "Use the math tools."}, + {"role": "user", "content": "4"}, + { + "role": "assistant", + "content": "", + "reasoning_content": "Let me multiply 4 by 7.\n", + "tool_calls": [ + { + "id": "call_abc123", + "function": {"arguments": '{"a": 4, "b": 7}', "name": "multiply"}, + "type": "function", + } + ], + }, + { + "content": "28", + "role": "tool", + "tool_call_id": "call_abc123", + "kiln_task_tool_data": None, + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Now add 144.\n", + "tool_calls": [ + { + "id": "call_def456", + "function": {"arguments": '{"a": 28, "b": 144}', "name": "add"}, + "type": "function", + } + ], + }, + { + "content": "172", + "role": "tool", + "tool_call_id": "call_def456", + "kiln_task_tool_data": None, + }, + { + "role": "assistant", + "content": "There were 172 distinct species of giant tortoises.", + "reasoning_content": "Now I have 172.\n", + }, + ] + adapter = LiteLlmAdapter(config=config, kiln_task=mock_task) + + captured_messages = [] + + async def mock_run_model_turn( + provider, prior_messages, top_logprobs, skip_response_format + ): + captured_messages.extend(prior_messages) + extended = list(prior_messages) + extended.append({"role": "assistant", "content": '{"test": "response"}'}) + return ModelTurnResult( + assistant_message='{"test": "response"}', + all_messages=extended, + model_response=None, + model_choice=None, + usage=Usage(), + ) + + adapter._run_model_turn = mock_run_model_turn + + run_output, _ = await adapter._run("what else?", prior_trace=prior_trace) + + assert run_output.trace is not None + # 7 prior trace messages + 1 new user + 1 new assistant = 9 + assert len(run_output.trace) == 9 + + # Verify tool call messages are preserved in the trace + assistant_with_tools = run_output.trace[2] + assert assistant_with_tools["role"] == "assistant" + assert assistant_with_tools["tool_calls"][0]["id"] == "call_abc123" + assert assistant_with_tools["tool_calls"][0]["function"]["name"] == "multiply" + assert assistant_with_tools["reasoning_content"] == "Let me multiply 4 by 7.\n" + + tool_response = run_output.trace[3] + assert tool_response["role"] == "tool" + assert tool_response["tool_call_id"] == "call_abc123" + assert tool_response["content"] == "28" + + second_tool_call = run_output.trace[4] + assert second_tool_call["tool_calls"][0]["id"] == "call_def456" + assert second_tool_call["tool_calls"][0]["function"]["name"] == "add" + + second_tool_response = run_output.trace[5] + assert second_tool_response["role"] == "tool" + assert second_tool_response["tool_call_id"] == "call_def456" + assert second_tool_response["content"] == "172" + + # Verify the tool call messages were passed to _run_model_turn (i.e., sent to the model) + assert any( + m.get("tool_calls") is not None + for m in captured_messages + if isinstance(m, dict) + ) + assert any( + m.get("role") == "tool" for m in captured_messages if isinstance(m, dict) + ) 
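The trace layout asserted above is also the input contract for session continuation: the complete message list from a finished run, tool calls and tool results included, is handed back as prior_trace. A minimal sketch of that call pattern, assuming a configured adapter like the ones in these tests (the follow-up input is illustrative):

    # Sketch: continue a session by replaying the prior trace.
    initial_run = await adapter.invoke(input="123 + 321 = ?")
    follow_up = await adapter.invoke(
        input="What was the result?",
        prior_trace=initial_run.trace,  # full history, tool calls included
    )
    # A new TaskRun is returned; the prior messages stay at the head of its trace.

The paid streaming test test_invoke_openai_stream_with_prior_trace below exercises the same pattern over invoke_openai_stream.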
diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py new file mode 100644 index 000000000..65c540376 --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py @@ -0,0 +1,308 @@ +import json +import logging +import re +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Callable + +import litellm +import pytest +from litellm.types.utils import ChatCompletionDeltaToolCall + +from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode +from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter +from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig +from kiln_ai.adapters.model_adapters.stream_events import ( + AiSdkEventType, + AiSdkStreamEvent, +) +from kiln_ai.datamodel import Project, PromptGenerators, Task +from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties, ToolsRunConfig +from kiln_ai.datamodel.tool_id import KilnBuiltInToolId + +logger = logging.getLogger(__name__) + +STREAMING_MODELS = [ + ("claude_sonnet_4_5", ModelProviderName.openrouter), + ("claude_sonnet_4_5", ModelProviderName.anthropic), + ("claude_sonnet_4_6", ModelProviderName.openrouter), + ("claude_sonnet_4_6", ModelProviderName.anthropic), + ("claude_opus_4_5", ModelProviderName.openrouter), + ("claude_opus_4_5", ModelProviderName.anthropic), + ("claude_opus_4_6", ModelProviderName.openrouter), + ("claude_opus_4_6", ModelProviderName.anthropic), + ("minimax_m2_5", ModelProviderName.openrouter), + ("claude_4_5_haiku", ModelProviderName.openrouter), + ("claude_4_5_haiku", ModelProviderName.anthropic), +] + +STREAMING_MODELS_NO_HAIKU = [m for m in STREAMING_MODELS if "haiku" not in m[0]] + +PAID_TEST_OUTPUT_DIR = Path(__file__).resolve().parents[5] / "test_output" + + +def _serialize_for_dump(obj: Any) -> Any: + if hasattr(obj, "model_dump"): + return obj.model_dump(mode="json") + if isinstance(obj, list): + if not obj: + return [] + first = obj[0] + if hasattr(first, "type") and hasattr(first, "payload"): + return [{"type": e.type.value, "payload": e.payload} for e in obj] + if hasattr(first, "model_dump"): + return [item.model_dump(mode="json") for item in obj] + return [_serialize_for_dump(x) for x in obj] + return obj + + +def _dump_paid_test_output(request: pytest.FixtureRequest, **payloads: Any) -> Path: + test_name = re.sub(r"[^\w\-]", "_", request.node.name) + param_id = "default" + if hasattr(request.node, "callspec") and request.node.callspec is not None: + id_attr = getattr(request.node.callspec, "id", None) + if id_attr is not None: + param_id = re.sub(r"[^\w\-]", "_", str(id_attr)) + timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S") + out_dir = PAID_TEST_OUTPUT_DIR / test_name / param_id / timestamp + out_dir.mkdir(parents=True, exist_ok=True) + for filename, data in payloads.items(): + if data is None: + continue + if not filename.endswith(".json"): + filename = f"{filename}.json" + serialized = _serialize_for_dump(data) + (out_dir / filename).write_text( + json.dumps(serialized, indent=2, default=str), encoding="utf-8" + ) + return out_dir + + +@pytest.fixture +def task(tmp_path): + project_path: Path = tmp_path / "test_project" / "project.kiln" + project_path.parent.mkdir() + + project = Project(name="Test Project", path=project_path) + project.save_to_file() + + task = Task( + name="Streaming Test Task", + instruction="Think 
about it hard! Solve the math problem provided by the user, in a step by step manner. Use the tools provided to solve the math problem. Then use the result in a short sentence about a cat going to the mall. Remember to use the tools for math even if the operation looks easy.", + parent=project, + ) + task.save_to_file() + return task + + +@pytest.fixture +def adapter_factory(task: Task) -> Callable[[str, ModelProviderName], LiteLlmAdapter]: + def create_adapter( + model_id: str, provider_name: ModelProviderName + ) -> LiteLlmAdapter: + return LiteLlmAdapter( + kiln_task=task, + config=LiteLlmConfig( + run_config_properties=KilnAgentRunConfigProperties( + model_name=model_id, + model_provider_name=provider_name, + prompt_id=PromptGenerators.SIMPLE, + structured_output_mode=StructuredOutputMode.unknown, + tools_config=ToolsRunConfig( + tools=[ + KilnBuiltInToolId.ADD_NUMBERS, + KilnBuiltInToolId.SUBTRACT_NUMBERS, + KilnBuiltInToolId.MULTIPLY_NUMBERS, + KilnBuiltInToolId.DIVIDE_NUMBERS, + ], + ), + ) + ), + ) + + return create_adapter + + +@pytest.mark.paid +@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS) +async def test_invoke_openai_stream_chunks( + request: pytest.FixtureRequest, + model_id: str, + provider_name: ModelProviderName, + adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], +): + """Collect all OpenAI-protocol chunks via invoke_openai_stream and verify we got reasoning, content, and tool call data.""" + adapter = adapter_factory(model_id, provider_name) + + chunks: list[litellm.ModelResponseStream] = [] + async for chunk in adapter.invoke_openai_stream(input="123 + 321 = ?"): + chunks.append(chunk) + + _dump_paid_test_output(request, chunks=chunks) + assert len(chunks) > 0, "No chunks collected" + + reasoning_contents: list[str] = [] + contents: list[str] = [] + tool_calls: list[ChatCompletionDeltaToolCall | Any] = [] + + for chunk in chunks: + if chunk.choices[0].finish_reason is not None: + continue + delta = chunk.choices[0].delta + if delta is None: + continue + if delta.tool_calls is not None: + tool_calls.extend(delta.tool_calls) + elif getattr(delta, "reasoning_content", None) is not None: + text = getattr(delta, "reasoning_content", None) + if text is not None: + reasoning_contents.append(text) + elif delta.content is not None: + contents.append(delta.content) + + assert len(reasoning_contents) > 0, "No reasoning content in chunks" + assert len(contents) > 0, "No content in chunks" + assert len(tool_calls) > 0, "No tool calls in chunks" + assert not all(r.strip() == "" for r in reasoning_contents), ( + "All reasoning content in chunks is empty" + ) + assert not all(c.strip() == "" for c in contents), "All content in chunks is empty" + + tool_call_function_names = [ + tc.function.name for tc in tool_calls if tc.function.name is not None + ] + assert len(tool_call_function_names) == 1, ( + "Expected exactly one tool call function name" + ) + assert tool_call_function_names[0] == "add", "Tool call function name is not 'add'" + + tool_call_args_chunks = "".join( + tc.function.arguments for tc in tool_calls if tc.function.arguments is not None + ) + tool_call_args = json.loads(tool_call_args_chunks) + assert tool_call_args == {"a": 123, "b": 321} or tool_call_args == { + "a": 321, + "b": 123, + }, f"Tool call arguments not as expected: {tool_call_args}" + + +@pytest.mark.paid +@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS) +async def test_invoke_ai_sdk_stream( + request: pytest.FixtureRequest, + model_id: str, + 
provider_name: ModelProviderName, + adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], +): + """Collect AI SDK events and verify the full protocol lifecycle including tool events.""" + adapter = adapter_factory(model_id, provider_name) + + events: list[AiSdkStreamEvent] = [] + async for event in adapter.invoke_ai_sdk_stream(input="123 + 321 = ?"): + events.append(event) + logger.info(f"AI SDK event: {event.type.value} {event.payload}") + + _dump_paid_test_output(request, events=events) + assert len(events) > 0, "No events collected" + + event_types = [e.type for e in events] + + assert event_types[0] == AiSdkEventType.START, "First event should be START" + assert event_types[1] == AiSdkEventType.START_STEP, ( + "Second event should be START_STEP" + ) + + assert AiSdkEventType.FINISH_STEP in event_types, "Should have FINISH_STEP" + assert AiSdkEventType.FINISH in event_types, "Should have FINISH" + + assert AiSdkEventType.REASONING_START in event_types, "Should have REASONING_START" + assert AiSdkEventType.REASONING_DELTA in event_types, "Should have REASONING_DELTA" + + assert AiSdkEventType.TEXT_START in event_types, "Should have TEXT_START" + assert AiSdkEventType.TEXT_DELTA in event_types, "Should have TEXT_DELTA" + assert AiSdkEventType.TEXT_END in event_types, "Should have TEXT_END" + + assert AiSdkEventType.TOOL_INPUT_START in event_types, ( + "Should have TOOL_INPUT_START" + ) + assert AiSdkEventType.TOOL_INPUT_AVAILABLE in event_types, ( + "Should have TOOL_INPUT_AVAILABLE" + ) + assert AiSdkEventType.TOOL_OUTPUT_AVAILABLE in event_types, ( + "Should have TOOL_OUTPUT_AVAILABLE" + ) + + text_deltas = [ + e.payload.get("delta", "") + for e in events + if e.type == AiSdkEventType.TEXT_DELTA + ] + full_text = "".join(text_deltas) + assert len(full_text) > 0, "Text content is empty" + + tool_input_available = [ + e for e in events if e.type == AiSdkEventType.TOOL_INPUT_AVAILABLE + ] + assert len(tool_input_available) >= 1, ( + "Should have at least one tool-input-available" + ) + tool_input = tool_input_available[0].payload.get("input", {}) + assert "a" in tool_input and "b" in tool_input, ( + f"Tool input should have a and b keys: {tool_input}" + ) + + tool_output_available = [ + e for e in events if e.type == AiSdkEventType.TOOL_OUTPUT_AVAILABLE + ] + assert len(tool_output_available) >= 1, ( + "Should have at least one tool-output-available" + ) + assert tool_output_available[0].payload.get("output") is not None, ( + "Tool output should not be None" + ) + + +@pytest.mark.paid +@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS_NO_HAIKU) +async def test_invoke_openai_stream_non_streaming_still_works( + request: pytest.FixtureRequest, + model_id: str, + provider_name: ModelProviderName, + adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], +): + """Verify the non-streaming invoke() still works after the refactor.""" + adapter = adapter_factory(model_id, provider_name) + task_run = await adapter.invoke(input="123 + 321 = ?") + + _dump_paid_test_output(request, task_run=task_run) + assert task_run.trace is not None, "Task run trace is None" + assert len(task_run.trace) > 0, "Task run trace is empty" + assert "444" in task_run.output.output, ( + f"Expected 444 in output: {task_run.output.output}" + ) + + +@pytest.mark.paid +@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS_NO_HAIKU) +async def test_invoke_openai_stream_with_prior_trace( + request: pytest.FixtureRequest, + model_id: str, + provider_name: 
ModelProviderName, + adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], +): + """Test that streaming works when continuing an existing run (session continuation).""" + adapter = adapter_factory(model_id, provider_name) + + initial_run = await adapter.invoke(input="123 + 321 = ?") + assert initial_run.trace is not None + assert len(initial_run.trace) > 0 + + continuation_chunks: list[litellm.ModelResponseStream] = [] + async for chunk in adapter.invoke_openai_stream( + input="What was the result? Reply in one short sentence.", + prior_trace=initial_run.trace, + ): + continuation_chunks.append(chunk) + + _dump_paid_test_output(request, continuation_chunks=continuation_chunks) + assert len(continuation_chunks) > 0, "No continuation chunks collected" diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py index fb7bc4c21..3674fe5b3 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from litellm.types.utils import ModelResponse @@ -287,10 +287,17 @@ async def test_tools_simplied_mocked(tmp_path): mock_config.open_ai_api_key = "mock_api_key" mock_config.user_id = "test_user" + responses = [mock_response_1, mock_response_2] + + async def mock_acompletion_checking_response(self, **kwargs): + response = responses.pop(0) + return response, response.choices[0] + with ( - patch( - "litellm.acompletion", - side_effect=[mock_response_1, mock_response_2], + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=mock_acompletion_checking_response, ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config), ): @@ -386,9 +393,16 @@ async def test_tools_mocked(tmp_path): mock_config.user_id = "test_user" with ( - patch( - "litellm.acompletion", - side_effect=[mock_response_1, mock_response_2, mock_response_3], + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock( + side_effect=[ + (mock_response_1, mock_response_1.choices[0]), + (mock_response_2, mock_response_2.choices[0]), + (mock_response_3, mock_response_3.choices[0]), + ] + ), ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config), ): diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py index ad5826d8c..cb0a3e94b 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py @@ -334,7 +334,7 @@ async def test_mcp_adapter_sets_and_clears_run_context( async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( project_with_local_mcp_server, local_mcp_tool_id ): - """Session continuation (existing_run) is not supported for MCP adapter.""" + """Session continuation (prior_trace) is not supported for MCP adapter.""" project, _ = project_with_local_mcp_server task = Task( name="Test Task", @@ -352,7 +352,9 @@ async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( existing_run.trace = [{"role": "user", "content": "hi"}] with pytest.raises(NotImplementedError) as exc_info: - await adapter.invoke_returning_run_output("input", existing_run=existing_run) + await adapter.invoke_returning_run_output( + 
"input", prior_trace=existing_run.trace + ) assert "Session continuation is not supported" in str(exc_info.value) assert "MCP adapter" in str(exc_info.value) @@ -362,7 +364,7 @@ async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( async def test_mcp_adapter_rejects_multiturn_invoke( project_with_local_mcp_server, local_mcp_tool_id ): - """invoke with existing_run raises NotImplementedError for MCP adapter.""" + """invoke with prior_trace raises NotImplementedError for MCP adapter.""" project, _ = project_with_local_mcp_server task = Task( name="Test Task", @@ -380,7 +382,7 @@ async def test_mcp_adapter_rejects_multiturn_invoke( existing_run.trace = [{"role": "user", "content": "hi"}] with pytest.raises(NotImplementedError) as exc_info: - await adapter.invoke("input", existing_run=existing_run) + await adapter.invoke("input", prior_trace=existing_run.trace) assert "Session continuation is not supported" in str(exc_info.value) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py index f96c70a41..db20fe5ea 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -10,11 +10,7 @@ class MockAdapter(BaseAdapter): - async def _run( - self, - input: InputType, - prior_trace=None, - ) -> tuple[RunOutput, Usage | None]: + async def _run(self, input: InputType, **kwargs) -> tuple[RunOutput, Usage | None]: return RunOutput(output="Test output", intermediate_outputs=None), None def adapter_name(self) -> str: @@ -239,7 +235,7 @@ async def test_autosave_true(test_task, adapter): @pytest.mark.asyncio async def test_invoke_continue_session(test_task, adapter): - """Test that invoke with task_run_id continues a session and updates the run.""" + """Test that invoke with prior_trace continues a session and creates a new run.""" with patch("kiln_ai.utils.config.Config.shared") as mock_shared: mock_config = mock_shared.return_value mock_config.autosave_runs = True @@ -249,21 +245,9 @@ async def test_invoke_continue_session(test_task, adapter): {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}, ] - initial_run = adapter.generate_run( - input="Hello", - input_source=None, - run_output=RunOutput( - output="Hi there!", - intermediate_outputs=None, - trace=trace, - ), - trace=trace, - ) - initial_run.save_to_file() - run_id = initial_run.id - assert run_id is not None - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): + prior_trace = kwargs.get("prior_trace") if prior_trace is not None: extended_trace = [ *prior_trace, @@ -309,14 +293,14 @@ async def mock_run(input, prior_trace=None): ) mock_parser_from_id.return_value = mock_parser - updated_run = await adapter.invoke("Tell me more", existing_run=initial_run) + new_run = await adapter.invoke("Tell me more", prior_trace=trace) - assert updated_run.id == run_id - assert updated_run.input == "Hello" - assert updated_run.output.output == "How can I help?" - assert len(updated_run.trace) == 4 - assert updated_run.trace[-2]["content"] == "Tell me more" - assert updated_run.trace[-1]["content"] == "How can I help?" + assert new_run.id is not None + assert new_run.input == "Tell me more" + assert new_run.output.output == "How can I help?" 
+ assert len(new_run.trace) == 4 + assert new_run.trace[-2]["content"] == "Tell me more" + assert new_run.trace[-1]["content"] == "How can I help?" reloaded = Task.load_from_file(test_task.path) runs = reloaded.runs() @@ -325,36 +309,55 @@ async def mock_run(input, prior_trace=None): @pytest.mark.asyncio -async def test_invoke_continue_run_without_trace(test_task, adapter): - """Test that invoke with existing_run that has no trace raises ValueError.""" +async def test_invoke_with_empty_prior_trace_starts_fresh(test_task, adapter): + """Test that invoke with prior_trace=[] starts a fresh conversation (no error).""" with patch("kiln_ai.utils.config.Config.shared") as mock_shared: mock_config = mock_shared.return_value mock_config.autosave_runs = True mock_config.user_id = "test_user" - run_without_trace = adapter.generate_run( - input="Hello", - input_source=None, - run_output=RunOutput( - output="Hi", - intermediate_outputs=None, - trace=None, - ), + adapter._run = AsyncMock( + return_value=( + RunOutput(output="Fresh reply", intermediate_outputs=None, trace=None), + None, + ) ) - run_without_trace.save_to_file() - - with pytest.raises(ValueError, match="no trace"): - await adapter.invoke("Follow up", existing_run=run_without_trace) + with ( + patch.object( + adapter, + "model_provider", + return_value=MagicMock( + parser="default", + formatter=None, + reasoning_capable=False, + ), + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", + return_value=MagicMock( + parse_output=MagicMock( + return_value=RunOutput( + output="Fresh reply", + intermediate_outputs=None, + trace=None, + ) + ) + ), + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", + ), + ): + run = await adapter.invoke("Follow up", prior_trace=[]) + assert run.output.output == "Fresh reply" -def test_generate_run_with_existing_run_merges_usage_and_intermediate_outputs( - test_task, adapter -): +def test_generate_run_always_creates_new_task_run(test_task, adapter): trace = [ {"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}, ] - initial_run = adapter.generate_run( + run1 = adapter.generate_run( input="hi", input_source=None, run_output=RunOutput( @@ -370,8 +373,8 @@ def test_generate_run_with_existing_run_merges_usage_and_intermediate_outputs( {"role": "user", "content": "follow-up"}, {"role": "assistant", "content": "ok"}, ] - result = adapter.generate_run( - input="hi", + run2 = adapter.generate_run( + input="follow-up", input_source=None, run_output=RunOutput( output="ok", @@ -380,16 +383,12 @@ def test_generate_run_with_existing_run_merges_usage_and_intermediate_outputs( ), usage=Usage(input_tokens=5, output_tokens=10), trace=extended_trace, - existing_run=initial_run, ) - assert result is initial_run - assert result.usage.input_tokens == 15 - assert result.usage.output_tokens == 30 - assert result.intermediate_outputs == { - "chain_of_thought": "old", - "new_key": "new_val", - } - assert result.output.output == "ok" + assert run2 is not run1 + assert run2.usage is not None and run2.usage.input_tokens == 5 + assert run2.usage.output_tokens == 10 + assert run2.intermediate_outputs == {"new_key": "new_val"} + assert run2.output.output == "ok" def test_properties_for_task_output_custom_values(test_task): diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py new file mode 100644 index 000000000..d4f3b3e43 --- /dev/null +++ 
b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py @@ -0,0 +1,267 @@ +from litellm.types.utils import ( + ChatCompletionDeltaToolCall, + Delta, + Function, + ModelResponseStream, + StreamingChoices, +) + +from kiln_ai.adapters.model_adapters.stream_events import ( + AiSdkEventType, + AiSdkStreamConverter, + AiSdkStreamEvent, + ToolCallEvent, + ToolCallEventType, +) + + +def _make_tool_call_delta( + index: int = 0, + call_id: str | None = None, + name: str | None = None, + arguments: str | None = None, +) -> ChatCompletionDeltaToolCall: + func = Function(name=name, arguments=arguments or "") + tc = ChatCompletionDeltaToolCall(index=index, function=func) + if call_id is not None: + tc.id = call_id + return tc + + +def _make_chunk( + content: str | None = None, + reasoning_content: str | None = None, + tool_calls: list[ChatCompletionDeltaToolCall] | None = None, + finish_reason: str | None = None, +) -> ModelResponseStream: + delta = Delta(content=content, tool_calls=tool_calls) + if reasoning_content is not None: + delta.reasoning_content = reasoning_content + choice = StreamingChoices( + index=0, + delta=delta, + finish_reason=finish_reason, + ) + return ModelResponseStream(id="test", choices=[choice]) + + +class TestAiSdkStreamEvent: + def test_model_dump(self): + event = AiSdkStreamEvent(AiSdkEventType.START, {"messageId": "msg-123"}) + dump = event.model_dump() + assert dump["type"] == "start" + assert dump["messageId"] == "msg-123" + + +class TestAiSdkStreamConverter: + def test_text_start_and_delta(self): + converter = AiSdkStreamConverter() + events = converter.convert_chunk(_make_chunk(content="Hello")) + types = [e.type for e in events] + assert AiSdkEventType.TEXT_START in types + assert AiSdkEventType.TEXT_DELTA in types + assert events[-1].payload["delta"] == "Hello" + + def test_text_delta_no_duplicate_start(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(content="Hello")) + events = converter.convert_chunk(_make_chunk(content=" world")) + types = [e.type for e in events] + assert AiSdkEventType.TEXT_START not in types + assert AiSdkEventType.TEXT_DELTA in types + + def test_reasoning_start_and_delta(self): + converter = AiSdkStreamConverter() + events = converter.convert_chunk(_make_chunk(reasoning_content="Thinking...")) + types = [e.type for e in events] + assert AiSdkEventType.REASONING_START in types + assert AiSdkEventType.REASONING_DELTA in types + + def test_reasoning_ends_when_content_starts(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(reasoning_content="Thinking...")) + events = converter.convert_chunk(_make_chunk(content="Answer")) + types = [e.type for e in events] + assert AiSdkEventType.REASONING_END in types + assert AiSdkEventType.TEXT_START in types + + def test_reasoning_ends_when_tool_calls_start(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(reasoning_content="Thinking...")) + + tc_delta = _make_tool_call_delta( + index=0, call_id="call_1", name="add", arguments='{"a":1}' + ) + events = converter.convert_chunk(_make_chunk(tool_calls=[tc_delta])) + types = [e.type for e in events] + assert AiSdkEventType.REASONING_END in types + + def test_tool_call_input_start_and_delta(self): + converter = AiSdkStreamConverter() + + tc_delta = _make_tool_call_delta( + index=0, call_id="call_1", name="add", arguments='{"a":' + ) + events = converter.convert_chunk(_make_chunk(tool_calls=[tc_delta])) + types = [e.type for e in events] + assert 
AiSdkEventType.TOOL_INPUT_START in types + assert AiSdkEventType.TOOL_INPUT_DELTA in types + + start_event = next( + e for e in events if e.type == AiSdkEventType.TOOL_INPUT_START + ) + assert start_event.payload["toolCallId"] == "call_1" + assert start_event.payload["toolName"] == "add" + + def test_finalize_closes_open_blocks(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(content="text")) + events = converter.finalize() + types = [e.type for e in events] + assert AiSdkEventType.TEXT_END in types + assert AiSdkEventType.FINISH in types + + def test_finalize_closes_reasoning(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(reasoning_content="thinking")) + events = converter.finalize() + types = [e.type for e in events] + assert AiSdkEventType.REASONING_END in types + + def test_convert_tool_event_input_available(self): + converter = AiSdkStreamConverter() + event = ToolCallEvent( + event_type=ToolCallEventType.INPUT_AVAILABLE, + tool_call_id="call_1", + tool_name="add", + arguments={"a": 1, "b": 2}, + ) + events = converter.convert_tool_event(event) + assert len(events) == 1 + assert events[0].type == AiSdkEventType.TOOL_INPUT_AVAILABLE + assert events[0].payload["toolCallId"] == "call_1" + assert events[0].payload["input"] == {"a": 1, "b": 2} + + def test_convert_tool_event_output_available(self): + converter = AiSdkStreamConverter() + event = ToolCallEvent( + event_type=ToolCallEventType.OUTPUT_AVAILABLE, + tool_call_id="call_1", + tool_name="add", + result="3", + ) + events = converter.convert_tool_event(event) + assert len(events) == 1 + assert events[0].type == AiSdkEventType.TOOL_OUTPUT_AVAILABLE + assert events[0].payload["output"] == "3" + + def test_convert_tool_event_output_error(self): + converter = AiSdkStreamConverter() + event = ToolCallEvent( + event_type=ToolCallEventType.OUTPUT_ERROR, + tool_call_id="call_1", + tool_name="add", + error="Something went wrong", + ) + events = converter.convert_tool_event(event) + assert len(events) == 1 + assert events[0].type == AiSdkEventType.TOOL_OUTPUT_ERROR + assert events[0].payload["errorText"] == "Something went wrong" + + def test_reasoning_not_interrupted_by_empty_content(self): + # Minimax and similar models send chunks with both reasoning_content and + # delta.content="" simultaneously. Empty content must not close reasoning + # blocks or emit useless text-delta events. 
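+ # Each chunk below carries one reasoning token plus content="", mimicking that wire format.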
+ converter = AiSdkStreamConverter() + + chunk1 = _make_chunk(reasoning_content="The", content="") + chunk2 = _make_chunk(reasoning_content=" user", content="") + chunk3 = _make_chunk(reasoning_content=" is", content="") + + events1 = converter.convert_chunk(chunk1) + events2 = converter.convert_chunk(chunk2) + events3 = converter.convert_chunk(chunk3) + + all_types1 = [e.type for e in events1] + all_types2 = [e.type for e in events2] + all_types3 = [e.type for e in events3] + + # First chunk opens the reasoning block + assert AiSdkEventType.REASONING_START in all_types1 + assert AiSdkEventType.REASONING_DELTA in all_types1 + # No text events from empty content + assert AiSdkEventType.TEXT_START not in all_types1 + assert AiSdkEventType.TEXT_DELTA not in all_types1 + + # Subsequent chunks must NOT re-open reasoning (no start) and must NOT + # close reasoning with reasoning-end + assert AiSdkEventType.REASONING_START not in all_types2 + assert AiSdkEventType.REASONING_END not in all_types2 + assert AiSdkEventType.REASONING_DELTA in all_types2 + assert AiSdkEventType.TEXT_DELTA not in all_types2 + + assert AiSdkEventType.REASONING_START not in all_types3 + assert AiSdkEventType.REASONING_END not in all_types3 + assert AiSdkEventType.REASONING_DELTA in all_types3 + assert AiSdkEventType.TEXT_DELTA not in all_types3 + + def test_reset_for_next_step(self): + converter = AiSdkStreamConverter() + converter._finish_reason = "tool_calls" + converter._tool_calls_state = { + 0: {"id": "x", "name": "y", "arguments": "", "started": True} + } + converter.reset_for_next_step() + assert converter._tool_calls_state == {} + assert converter._finish_reason is None + + def test_finish_reason_in_finalize(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(content="done", finish_reason="stop")) + events = converter.finalize() + finish_events = [e for e in events if e.type == AiSdkEventType.FINISH] + assert len(finish_events) == 1 + meta = finish_events[0].payload.get("messageMetadata", {}) + assert meta.get("finishReason") == "stop" + + def test_tool_input_start_reemitted_after_reset(self): + """After reset_for_next_step, tool-input-start must fire again for index 0.""" + converter = AiSdkStreamConverter() + + tc_round1 = _make_tool_call_delta( + index=0, call_id="call_r1", name="search", arguments='{"q":"hi"}' + ) + events_r1 = converter.convert_chunk(_make_chunk(tool_calls=[tc_round1])) + starts_r1 = [e for e in events_r1 if e.type == AiSdkEventType.TOOL_INPUT_START] + assert len(starts_r1) == 1 + assert starts_r1[0].payload["toolCallId"] == "call_r1" + + converter.reset_for_next_step() + + tc_round2 = _make_tool_call_delta( + index=0, call_id="call_r2", name="search", arguments='{"q":"world"}' + ) + events_r2 = converter.convert_chunk(_make_chunk(tool_calls=[tc_round2])) + starts_r2 = [e for e in events_r2 if e.type == AiSdkEventType.TOOL_INPUT_START] + assert len(starts_r2) == 1, ( + "tool-input-start must be re-emitted for index 0 after reset" + ) + assert starts_r2[0].payload["toolCallId"] == "call_r2" + + def test_tool_input_start_not_reemitted_without_reset(self): + """Without reset, a second tool call at index 0 must NOT re-emit tool-input-start.""" + converter = AiSdkStreamConverter() + + tc_round1 = _make_tool_call_delta( + index=0, call_id="call_r1", name="search", arguments='{"q":"hi"}' + ) + converter.convert_chunk(_make_chunk(tool_calls=[tc_round1])) + + tc_round2 = _make_tool_call_delta( + index=0, call_id="call_r2", name="search", arguments='{"q":"world"}' + ) + 
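# Same index (0) but a new call id: without reset_for_next_step, the + # converter still has started=True for index 0, so no duplicate + # tool-input-start event should fire. +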
events_r2 = converter.convert_chunk(_make_chunk(tool_calls=[tc_round2])) + starts_r2 = [e for e in events_r2 if e.type == AiSdkEventType.TOOL_INPUT_START] + assert len(starts_r2) == 0, ( + "Without reset, started=True blocks duplicate tool-input-start" + ) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py index 92d9a1f4d..46fa88aa7 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py @@ -1,7 +1,7 @@ import json from pathlib import Path from typing import Dict -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from litellm.types.utils import ModelResponse @@ -10,6 +10,7 @@ from kiln_ai.adapters.adapter_registry import adapter_for_task from kiln_ai.adapters.ml_model_list import built_in_models from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput, Usage +from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter from kiln_ai.adapters.ollama_tools import ollama_online from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers from kiln_ai.datamodel import PromptId @@ -53,11 +54,7 @@ def __init__(self, kiln_task: datamodel.Task, response: InputType | None): ) self.response = response - async def _run( - self, - input: str, - prior_trace=None, - ) -> tuple[RunOutput, Usage | None]: + async def _run(self, input: str, **kwargs) -> tuple[RunOutput, Usage | None]: return RunOutput(output=self.response, intermediate_outputs=None), None def adapter_name(self) -> str: @@ -351,9 +348,10 @@ async def test_all_built_in_models_structured_input_mocked(tmp_path): mock_config.groq_api_key = "mock_api_key" with ( - patch( - "litellm.acompletion", - side_effect=[mock_response], + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock(return_value=(mock_response, mock_response.choices[0])), ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config), ): @@ -406,9 +404,15 @@ async def test_structured_input_cot_prompt_builder_mocked(tmp_path): mock_config.groq_api_key = "mock_api_key" with ( - patch( - "litellm.acompletion", - side_effect=[mock_response_1, mock_response_2], + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock( + side_effect=[ + (mock_response_1, mock_response_1.choices[0]), + (mock_response_2, mock_response_2.choices[0]), + ] + ), ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config), ): diff --git a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py index 23dd15b0d..187de1ff2 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py +++ b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py @@ -1,9 +1,9 @@ import os from pathlib import Path -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest -from litellm.utils import ModelResponse +from litellm.types.utils import ModelResponse import kiln_ai.datamodel as datamodel from kiln_ai.adapters.adapter_registry import adapter_for_task @@ -113,13 +113,15 @@ async def test_amazon_bedrock(tmp_path): async def test_mock_returning_run(tmp_path): task = build_test_task(tmp_path) - with patch("litellm.acompletion") as mock_acompletion: - # Configure the mock to return a properly structured response - mock_acompletion.return_value = ModelResponse( - 
model="custom_model", - choices=[{"message": {"content": "mock response"}}], - ) - + mock_response = ModelResponse( + model="custom_model", + choices=[{"message": {"content": "mock response"}}], + ) + with patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock(return_value=(mock_response, mock_response.choices[0])), + ): run_config = KilnAgentRunConfigProperties( model_name="custom_model", model_provider_name=ModelProviderName.ollama, diff --git a/libs/core/kiln_ai/adapters/test_prompt_builders.py b/libs/core/kiln_ai/adapters/test_prompt_builders.py index 25cadd99d..7a67ff5c9 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_builders.py +++ b/libs/core/kiln_ai/adapters/test_prompt_builders.py @@ -58,11 +58,7 @@ def test_simple_prompt_builder(tmp_path): class MockAdapter(BaseAdapter): - async def _run( - self, - input: InputType, - prior_trace=None, - ) -> tuple[RunOutput, Usage | None]: + async def _run(self, input: InputType, **kwargs) -> tuple[RunOutput, Usage | None]: return RunOutput(output="mock response", intermediate_outputs=None), None def adapter_name(self) -> str: diff --git a/libs/core/kiln_ai/datamodel/external_tool_server.py b/libs/core/kiln_ai/datamodel/external_tool_server.py index b01f6acb0..ad6a6ebdb 100644 --- a/libs/core/kiln_ai/datamodel/external_tool_server.py +++ b/libs/core/kiln_ai/datamodel/external_tool_server.py @@ -25,6 +25,18 @@ class ToolServerType(str, Enum): kiln_task = "kiln_task" +def tool_server_type_to_string(server_type: ToolServerType) -> str: + match server_type: + case ToolServerType.remote_mcp: + return "Remote MCP" + case ToolServerType.local_mcp: + return "Local MCP" + case ToolServerType.kiln_task: + return "Kiln Task" + case _: + raise_exhaustive_enum_error(server_type) + + class LocalServerProperties(TypedDict, total=True): command: str args: NotRequired[list[str]] diff --git a/libs/core/kiln_ai/datamodel/test_basemodel.py b/libs/core/kiln_ai/datamodel/test_basemodel.py index 03c6dfeaf..897a3749f 100644 --- a/libs/core/kiln_ai/datamodel/test_basemodel.py +++ b/libs/core/kiln_ai/datamodel/test_basemodel.py @@ -862,7 +862,7 @@ def individual_lookups(): class MockAdapter(BaseAdapter): """Implementation of BaseAdapter for testing""" - async def _run(self, input, prior_trace=None): + async def _run(self, input, **kwargs): return RunOutput(output="test output", intermediate_outputs=None), None def adapter_name(self) -> str: diff --git a/libs/core/kiln_ai/tools/mcp_server_tool.py b/libs/core/kiln_ai/tools/mcp_server_tool.py index 10c1d7820..af93eb949 100644 --- a/libs/core/kiln_ai/tools/mcp_server_tool.py +++ b/libs/core/kiln_ai/tools/mcp_server_tool.py @@ -50,9 +50,19 @@ async def run( ) -> ToolCallResult: result = await self._call_tool(**kwargs) + # MCP tool returned an application-level error - return it to the agent + # instead of crashing the run. This allows the agent to respond gracefully. 
if result.isError: - raise ValueError( - f"Tool {await self.name()} returned an error: {result.content}" + error_text = ( + " ".join( + block.text + for block in result.content + if isinstance(block, TextContent) + ) + or "Unknown error" + ) + return ToolCallResult( + output=json.dumps({"isError": True, "error": error_text}) ) # If the tool returns structured content, return it as a JSON string diff --git a/libs/core/kiln_ai/tools/mcp_session_manager.py b/libs/core/kiln_ai/tools/mcp_session_manager.py index 29ac57c9e..6cde815b1 100644 --- a/libs/core/kiln_ai/tools/mcp_session_manager.py +++ b/libs/core/kiln_ai/tools/mcp_session_manager.py @@ -21,10 +21,17 @@ logger = logging.getLogger(__name__) -LOCAL_MCP_ERROR_INSTRUCTION = "Please verify your command, arguments, and environment variables, and consult the server's documentation for the correct setup." MCP_SESSION_CACHE_KEY_DELIMITER = "::" +class KilnMCPError(RuntimeError): + """Wraps MCP connection failures. Unwraps ExceptionGroup; attaches stderr.""" + + def __init__(self, message: str, stderr: str = ""): + super().__init__(message) + self.stderr = stderr + + class MCPSessionManager: """ This class is a singleton that manages MCP sessions for remote MCP servers. @@ -410,27 +417,18 @@ def _handle_remote_mcp_error(self, e: Exception) -> NoReturn: e: The exception to handle Raises: - ValueError: If the server rejected the request with an HTTP error - RuntimeError: If connection to the server failed + KilnMCPError: Always, with the raw library error message """ - http_error = self._extract_first_exception(e, httpx.HTTPStatusError) - if http_error and isinstance(http_error, httpx.HTTPStatusError): - raise ValueError( - f"The MCP server rejected the request. " - f"Status {http_error.response.status_code}. " - f"Response from server:\n{http_error.response.reason_phrase}" - ) from e - - connection_error_types = (ConnectionError, OSError, httpx.RequestError) - connection_error = self._extract_first_exception(e, connection_error_types) - if connection_error and isinstance(connection_error, connection_error_types): - raise RuntimeError( - f"Unable to connect to MCP server. Please verify the configurations are correct, the server is running, and your network connection is working. Original error: {connection_error}" - ) from e - - raise RuntimeError( - f"Failed to connect to the MCP Server. Check the server's docs for troubleshooting. Original error: {e}" - ) from e + for exc_type in ( + httpx.HTTPStatusError, + ConnectionError, + OSError, + httpx.RequestError, + ): + found = self._extract_first_exception(e, exc_type) + if found: + raise KilnMCPError(str(found)) from e + raise KilnMCPError(str(e)) from e def _handle_local_mcp_error(self, e: Exception, stderr: str) -> NoReturn: """Shared error handling for local MCP connection failures. @@ -440,26 +438,13 @@ def _handle_local_mcp_error(self, e: Exception, stderr: str) -> NoReturn: stderr: The stderr content from the MCP server Raises: - RuntimeError: Always, with a friendly error message + KilnMCPError: Always, with the raw library error message """ - mcp_error = self._extract_first_exception(e, McpError) - if mcp_error and isinstance(mcp_error, McpError): - self._raise_local_mcp_error(mcp_error, stderr) - - self._raise_local_mcp_error(e, stderr) - - def _raise_local_mcp_error(self, e: Exception, stderr: str) -> NoReturn: - """ - Raise a RuntimeError with a friendlier message for local MCP errors. 
- """ - error_msg = f"'{e}'" - - if stderr: - error_msg += f"\nMCP server error: {stderr}" - - error_msg += f"\n{LOCAL_MCP_ERROR_INSTRUCTION}" - - raise RuntimeError(error_msg) from e + for exc_type in (FileNotFoundError, OSError, McpError): + found = self._extract_first_exception(e, exc_type) + if found: + raise KilnMCPError(str(found), stderr=stderr) from e + raise KilnMCPError(str(e), stderr=stderr) from e def _get_path(self) -> str: """ diff --git a/libs/core/kiln_ai/tools/test_mcp_server_tool.py b/libs/core/kiln_ai/tools/test_mcp_server_tool.py index ab113f600..12fe5aaad 100644 --- a/libs/core/kiln_ai/tools/test_mcp_server_tool.py +++ b/libs/core/kiln_ai/tools/test_mcp_server_tool.py @@ -1,3 +1,4 @@ +import json from unittest.mock import AsyncMock, patch import pytest @@ -173,11 +174,50 @@ async def test_run_non_text_content_error( with pytest.raises(ValueError, match="First block must be a text block"): await tool.run() + @pytest.mark.asyncio + @patch("kiln_ai.tools.mcp_server_tool.get_agent_run_id") + @patch("kiln_ai.tools.mcp_server_tool.MCPSessionManager") + async def test_run_error_result_no_text_content( + self, mock_session_manager, mock_get_run_id + ): + """Test run() returns Unknown error when isError=True with no TextContent.""" + mock_get_run_id.return_value = "test_run_123" + mock_session = AsyncMock() + mock_session_manager.shared.return_value.get_or_create_session = AsyncMock( + return_value=mock_session + ) + + # Return ImageContent instead of TextContent + result_content = [ + ImageContent(type="image", data="base64data", mimeType="image/png") + ] + call_result = CallToolResult( + content=list[ContentBlock](result_content), + isError=True, # type: ignore + ) + mock_session.call_tool.return_value = call_result + + server = ExternalToolServer( + name="test_server", + type=ToolServerType.remote_mcp, + properties={ + "server_url": "https://example.com", + "is_archived": False, + }, + ) + tool = MCPServerTool(server, "test_tool") + + result = await tool.run() + + # Should return structured error with "Unknown error" + expected_output = json.dumps({"isError": True, "error": "Unknown error"}) + assert result.output == expected_output + @pytest.mark.asyncio @patch("kiln_ai.tools.mcp_server_tool.get_agent_run_id") @patch("kiln_ai.tools.mcp_server_tool.MCPSessionManager") async def test_run_error_result(self, mock_session_manager, mock_get_run_id): - """Test run() raises error when tool returns isError=True.""" + """Test run() returns structured error when tool returns isError=True.""" mock_get_run_id.return_value = "test_run_123" mock_session = AsyncMock() mock_session_manager.shared.return_value.get_or_create_session = AsyncMock( @@ -201,8 +241,11 @@ async def test_run_error_result(self, mock_session_manager, mock_get_run_id): ) tool = MCPServerTool(server, "test_tool") - with pytest.raises(ValueError, match="Tool test_tool returned an error"): - await tool.run() + result = await tool.run() + + # Should return structured error, not raise an exception + expected_output = json.dumps({"isError": True, "error": "Error occurred"}) + assert result.output == expected_output @pytest.mark.asyncio @patch("kiln_ai.tools.mcp_server_tool.get_agent_run_id") diff --git a/libs/core/kiln_ai/tools/test_mcp_session_manager.py b/libs/core/kiln_ai/tools/test_mcp_session_manager.py index 07c4bc74f..bbf105841 100644 --- a/libs/core/kiln_ai/tools/test_mcp_session_manager.py +++ b/libs/core/kiln_ai/tools/test_mcp_session_manager.py @@ -13,8 +13,8 @@ ToolServerType, ) from 
kiln_ai.tools.mcp_session_manager import ( - LOCAL_MCP_ERROR_INSTRUCTION, MCP_SESSION_CACHE_KEY_DELIMITER, + KilnMCPError, MCPSessionManager, build_mcp_session_cache_key, parse_mcp_session_cache_session_id, @@ -345,7 +345,7 @@ async def test_session_with_empty_headers(self, mock_client, basic_remote_server async def test_remote_mcp_http_status_errors( self, mock_client, status_code, reason_phrase, basic_remote_server ): - """Test remote MCP session handles various HTTP status errors with simplified message.""" + """Test remote MCP session handles various HTTP status errors with KilnMCPError.""" # Create HTTP error with specific status code response = MagicMock() response.status_code = status_code @@ -359,14 +359,15 @@ async def test_remote_mcp_http_status_errors( manager = MCPSessionManager.shared() - # All HTTP errors should now use the simplified message format - expected_pattern = f"The MCP server rejected the request. Status {status_code}. Response from server:\n{reason_phrase}" - with pytest.raises( - ValueError, match=expected_pattern.replace("(", r"\(").replace(")", r"\)") - ): + # All HTTP errors should now raise KilnMCPError with the raw library message + with pytest.raises(KilnMCPError) as exc_info: async with manager.mcp_client(basic_remote_server): pass + # Verify the error message contains the raw library error (the reason phrase) + # httpx.HTTPStatusError string representation includes the reason phrase + assert reason_phrase in str(exc_info.value) + @pytest.mark.parametrize( "connection_error_type,error_message", [ @@ -380,7 +381,7 @@ async def test_remote_mcp_http_status_errors( async def test_remote_mcp_connection_errors( self, mock_client, connection_error_type, error_message, basic_remote_server ): - """Test remote MCP session handles various connection errors with simplified message.""" + """Test remote MCP session handles various connection errors with KilnMCPError.""" # Create connection error if connection_error_type == httpx.RequestError: connection_error = connection_error_type(error_message, request=MagicMock()) @@ -394,8 +395,8 @@ async def test_remote_mcp_connection_errors( manager = MCPSessionManager.shared() - # All connection errors should use the simplified message format - with pytest.raises(RuntimeError, match="Unable to connect to MCP server"): + # All connection errors should now raise KilnMCPError + with pytest.raises(KilnMCPError): async with manager.mcp_client(basic_remote_server): pass @@ -424,13 +425,14 @@ def __init__(self, exceptions): manager = MCPSessionManager.shared() - # Should extract the HTTP error from the nested structure - with pytest.raises( - ValueError, match=r"The MCP server rejected the request. 
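These assertions rely on str() of an httpx error being exactly the message it was constructed with, which KilnMCPError(str(found)) passes through unchanged. A small illustration; the request/response objects here are stand-ins, not values from the tests:

import httpx

from kiln_ai.tools.mcp_session_manager import KilnMCPError

request = httpx.Request("POST", "https://example.com/mcp")
response = httpx.Response(status_code=401, request=request)

# httpx.HTTPStatusError stringifies to its message argument...
http_error = httpx.HTTPStatusError("Unauthorized", request=request, response=response)

# ...so wrapping str(found) keeps the reason phrase visible to callers.
wrapped = KilnMCPError(str(http_error))
assert "Unauthorized" in str(wrapped)
assert isinstance(wrapped, RuntimeError)  # KilnMCPError subclasses RuntimeError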
Status 401" - ): + # Should extract the HTTP error from the nested structure and raise KilnMCPError + with pytest.raises(KilnMCPError) as exc_info: async with manager.mcp_client(basic_remote_server): pass + # Verify the error message contains the raw library error (the reason phrase) + assert "Unauthorized" in str(exc_info.value) + @patch("kiln_ai.tools.mcp_session_manager.streamablehttp_client") async def test_remote_mcp_connection_error_in_nested_exceptions( self, mock_client, basic_remote_server @@ -451,8 +453,8 @@ def __init__(self, exceptions): manager = MCPSessionManager.shared() - # Should extract the connection error from the nested structure - with pytest.raises(RuntimeError, match="Unable to connect to MCP server"): + # Should extract the connection error from the nested structure and raise KilnMCPError + with pytest.raises(KilnMCPError): async with manager.mcp_client(basic_remote_server): pass @@ -460,18 +462,21 @@ def __init__(self, exceptions): async def test_remote_mcp_unknown_error_fallback( self, mock_client, basic_remote_server ): - """Test remote MCP session handles unknown errors with fallback message.""" + """Test remote MCP session handles unknown errors with KilnMCPError.""" # Mock client to raise an unknown error type unknown_error = RuntimeError("Unexpected error") mock_client.return_value.__aenter__.side_effect = unknown_error manager = MCPSessionManager.shared() - # Should use the fallback error message - with pytest.raises(RuntimeError, match="Failed to connect to the MCP Server"): + # Should raise KilnMCPError with the raw library message + with pytest.raises(KilnMCPError) as exc_info: async with manager.mcp_client(basic_remote_server): pass + # Verify the error message contains the raw error + assert "Unexpected error" in str(exc_info.value) + @patch("kiln_ai.tools.mcp_session_manager.streamablehttp_client") @patch("kiln_ai.utils.config.Config.shared") async def test_session_with_secret_headers( @@ -1103,10 +1108,10 @@ def mock_get_value(key): ], ) @patch("kiln_ai.tools.mcp_session_manager.stdio_client") - async def test_local_mcp_various_errors_use_simplified_message( + async def test_local_mcp_various_errors_use_raw_message( self, mock_client, error_type, error_message, basic_local_server ): - """Test local MCP session handles various errors with simplified message.""" + """Test local MCP session handles various errors with KilnMCPError using raw library messages.""" # Create the appropriate error if error_type == McpError: error_data = ErrorData(code=-1, message=error_message) @@ -1119,11 +1124,14 @@ async def test_local_mcp_various_errors_use_simplified_message( manager = MCPSessionManager.shared() - # All local errors should now use the simplified message format - with pytest.raises(RuntimeError, match=LOCAL_MCP_ERROR_INSTRUCTION): + # All error types should raise KilnMCPError with the raw library message + with pytest.raises(KilnMCPError) as exc_info: async with manager.mcp_client(basic_local_server): pass + # Verify the error message contains the raw library error message + assert error_message in str(exc_info.value) + @patch("kiln_ai.tools.mcp_session_manager.stdio_client") async def test_local_mcp_mcp_error_in_nested_exceptions(self, mock_client): """Test local MCP session extracts McpError from nested exceptions.""" @@ -1155,33 +1163,13 @@ def __init__(self, exceptions): manager = MCPSessionManager.shared() - # Should extract the McpError from the nested structure and use simplified message - with pytest.raises(RuntimeError, 
match=LOCAL_MCP_ERROR_INSTRUCTION): + # Should extract the McpError from the nested structure and raise KilnMCPError + with pytest.raises(KilnMCPError) as exc_info: async with manager.mcp_client(tool_server): pass - def test_raise_local_mcp_error_method(self): - """Test the _raise_local_mcp_error helper method.""" - manager = MCPSessionManager() - - # Test with different exception types - test_exceptions = [ - ValueError("test value error"), - FileNotFoundError("file not found"), - RuntimeError("runtime error"), - Exception("generic exception"), - ] - - for original_error in test_exceptions: - with pytest.raises(RuntimeError) as exc_info: - manager._raise_local_mcp_error(original_error, "") - - # Check that the error message contains expected text - assert LOCAL_MCP_ERROR_INSTRUCTION in str(exc_info.value) - assert str(original_error) in str(exc_info.value) - - # Check that the original exception is chained - assert exc_info.value.__cause__ is original_error + # Verify the error message contains the raw library error + assert "Server startup failed" in str(exc_info.value) @patch("kiln_ai.tools.mcp_session_manager.stdio_client") @patch("kiln_ai.utils.config.Config.shared") @@ -1562,6 +1550,29 @@ async def test_list_tools_with_real_local_mcp_server(self): assert len(tools.tools) > 0 assert "firecrawl_scrape" in [tool.name for tool in tools.tools] + async def test_local_mcp_nonexistent_command_raises_kiln_mcp_error(self): + """Real integration test (no mocks): confirms that spawning a local MCP server + with a nonexistent command always raises KilnMCPError, not FileNotFoundError + or OSError. Validates that callers can safely catch (KilnMCPError, RuntimeError, ValueError) + to handle all local MCP failure modes.""" + server = ExternalToolServer( + name="bad_command_server", + type=ToolServerType.local_mcp, + description="Server with nonexistent command", + properties={ + "command": "this_command_does_not_exist_kiln_test", + "args": [], + "env_vars": {}, + "secret_env_var_keys": [], + "is_archived": False, + }, + ) + + # Should now raise KilnMCPError (subclass of RuntimeError) + with pytest.raises(KilnMCPError): + async with MCPSessionManager.shared().mcp_client(server) as session: + await session.list_tools() + @pytest.fixture def mock_mcp_streams(): diff --git a/libs/core/kiln_ai/utils/config.py b/libs/core/kiln_ai/utils/config.py index 704c9f87f..2fe8f0f73 100644 --- a/libs/core/kiln_ai/utils/config.py +++ b/libs/core/kiln_ai/utils/config.py @@ -20,6 +20,7 @@ def __init__( default_lambda: Optional[Callable[[], Any]] = None, sensitive: bool = False, sensitive_keys: Optional[List[str]] = None, + in_memory: bool = False, ): self.type = type_ self.default = default @@ -27,6 +28,7 @@ def __init__( self.default_lambda = default_lambda self.sensitive = sensitive self.sensitive_keys = sensitive_keys + self.in_memory = in_memory class Config: @@ -43,6 +45,7 @@ def __init__(self, properties: Dict[str, ConfigProperty] | None = None): bool, env_var="KILN_AUTOSAVE_RUNS", default=True, + in_memory=True, ), "open_ai_api_key": ConfigProperty( str, @@ -198,6 +201,7 @@ def __init__(self, properties: Dict[str, ConfigProperty] | None = None): ), } self._lock = threading.Lock() + self._in_memory_settings: Dict[str, Any] = {} self._settings = self.load_settings() @classmethod @@ -221,10 +225,14 @@ def __getattr__(self, name: str) -> Any: property_config = self._properties[name] - # Check if the value is in settings - if name in self._settings: - value = self._settings[name] - return value if value is None else 
property_config.type(value)
+        if property_config.in_memory:
+            if name in self._in_memory_settings:
+                value = self._in_memory_settings[name]
+                return value if value is None else property_config.type(value)
+        else:
+            if name in self._settings:
+                value = self._settings[name]
+                return value if value is None else property_config.type(value)

         # Check environment variable
         if property_config.env_var and property_config.env_var in os.environ:
@@ -240,10 +248,14 @@ def __getattr__(self, name: str) -> Any:
         return None if value is None else property_config.type(value)

     def __setattr__(self, name, value):
-        if name in ("_properties", "_settings", "_lock"):
+        if name in ("_properties", "_settings", "_lock", "_in_memory_settings"):
             super().__setattr__(name, value)
         elif name in self._properties:
-            self.update_settings({name: value})
+            if self._properties[name].in_memory:
+                with self._lock:
+                    self._in_memory_settings[name] = value
+            else:
+                self.update_settings({name: value})
         else:
             raise AttributeError(f"Config has no attribute '{name}'")
@@ -268,14 +280,22 @@ def load_settings(cls):
         return settings

     def settings(self, hide_sensitive=False) -> Dict[str, Any]:
+        with self._lock:
+            filtered_disk = {
+                k: v
+                for k, v in self._settings.items()
+                if k not in self._properties or not self._properties[k].in_memory
+            }
+            combined = {**filtered_disk, **self._in_memory_settings}
+
         if not hide_sensitive:
-            return self._settings
+            return combined

         settings = {
             k: "[hidden]"
             if k in self._properties and self._properties[k].sensitive
             else copy.deepcopy(v)
-            for k, v in self._settings.items()
+            for k, v in combined.items()
         }
         # Hide sensitive keys in lists. Could generalize this if we ever have more types, but right now it's only needed for root elements of lists
         for key, value in settings.items():
@@ -293,18 +313,32 @@ def save_setting(self, name: str, value: Any):
         self.update_settings({name: value})

     def update_settings(self, new_settings: Dict[str, Any]):
-        # Lock to prevent race conditions in multi-threaded scenarios
         with self._lock:
-            # Fresh load to avoid clobbering changes from other instances
-            current_settings = self.load_settings()
-            current_settings.update(new_settings)
-            # remove None values
-            current_settings = {
-                k: v for k, v in current_settings.items() if v is not None
+            in_memory_updates = {
+                k: v
+                for k, v in new_settings.items()
+                if k in self._properties and self._properties[k].in_memory
             }
-            with open(self.settings_path(), "w") as f:
-                yaml.dump(current_settings, f)
-            self._settings = current_settings
+            disk_updates = {
+                k: v
+                for k, v in new_settings.items()
+                if k not in self._properties or not self._properties[k].in_memory
+            }
+
+            if in_memory_updates:
+                self._in_memory_settings.update(in_memory_updates)
+
+            if disk_updates:
+                # Fresh load to avoid clobbering changes from other instances
+                current_settings = self.load_settings()
+                current_settings.update(disk_updates)
+                # remove None values
+                current_settings = {
+                    k: v for k, v in current_settings.items() if v is not None
+                }
+                with open(self.settings_path(), "w") as f:
+                    yaml.dump(current_settings, f)
+                self._settings = current_settings


 def _get_user_id():
diff --git a/libs/core/kiln_ai/utils/test_config.py b/libs/core/kiln_ai/utils/test_config.py
index 75891fe88..5da2302e1 100644
--- a/libs/core/kiln_ai/utils/test_config.py
+++ b/libs/core/kiln_ai/utils/test_config.py
@@ -494,3 +494,206 @@ def test_mcp_secrets_type_validation():
     assert config.mcp_secrets == mixed_types or config.mcp_secrets == {
         "server::key": "123"
     }
+
+
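The net effect of the routing above: reads and writes of an in_memory=True property bypass settings.yaml entirely, while the merged settings() view still reflects them. A short usage sketch against the shared singleton; it mutates only process-local state, and is illustrative rather than part of the diff:

from kiln_ai.utils.config import Config

config = Config.shared()

# autosave_runs is declared with in_memory=True, so this write lands in
# _in_memory_settings and is never flushed to settings.yaml.
config.autosave_runs = False

# settings() overlays in-memory values on the filtered on-disk dict,
# so the change is still observable here.
assert config.settings()["autosave_runs"] is False

# A process restart discards the value: a fresh Config falls back to the
# KILN_AUTOSAVE_RUNS environment variable, or the default of True.

Presumably this is what lets a session-scoped toggle like autosave avoid dirtying the user's saved configuration. The tests that follow pin each of these behaviors down:

+def 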
test_in_memory_property_default_value(mock_yaml_file): + """Test that in-memory property returns default value""" + with patch( + "kiln_ai.utils.config.Config.settings_path", + return_value=mock_yaml_file, + ): + config = Config( + properties={ + "in_memory_prop": ConfigProperty(bool, default=True, in_memory=True), + } + ) + assert config.in_memory_prop is True + + +def test_in_memory_property_env_var(mock_yaml_file): + """Test that in-memory property respects env var""" + with ( + patch( + "kiln_ai.utils.config.Config.settings_path", + return_value=mock_yaml_file, + ), + patch.dict(os.environ, {"IN_MEMORY_TEST_VAR": "custom_env_value"}), + ): + config = Config( + properties={ + "in_memory_prop": ConfigProperty( + str, + default="default_value", + env_var="IN_MEMORY_TEST_VAR", + in_memory=True, + ), + } + ) + assert config.in_memory_prop == "custom_env_value" + + +def test_in_memory_property_setter_works(mock_yaml_file): + """Test that in-memory property setter works""" + with patch( + "kiln_ai.utils.config.Config.settings_path", + return_value=mock_yaml_file, + ): + config = Config( + properties={ + "in_memory_prop": ConfigProperty(bool, default=True, in_memory=True), + } + ) + assert config.in_memory_prop is True + config.in_memory_prop = False + assert config.in_memory_prop is False + + +def test_in_memory_property_not_persisted_to_yaml(mock_yaml_file): + """Test that in-memory property is not persisted to YAML""" + with patch( + "kiln_ai.utils.config.Config.settings_path", + return_value=mock_yaml_file, + ): + config = Config( + properties={ + "in_memory_prop": ConfigProperty(bool, default=True, in_memory=True), + "persisted_prop": ConfigProperty(str, default="default"), + } + ) + config.in_memory_prop = False + config.persisted_prop = "persisted_value" + + with open(mock_yaml_file, "r") as f: + saved_settings = yaml.safe_load(f) + + assert "in_memory_prop" not in saved_settings + assert saved_settings["persisted_prop"] == "persisted_value" + + +def test_in_memory_property_included_in_settings_output(mock_yaml_file): + """Test that in-memory property is included in settings() output""" + with patch( + "kiln_ai.utils.config.Config.settings_path", + return_value=mock_yaml_file, + ): + config = Config( + properties={ + "in_memory_prop": ConfigProperty(bool, default=True, in_memory=True), + "persisted_prop": ConfigProperty(str, default="default"), + } + ) + config.in_memory_prop = False + config.persisted_prop = "persisted_value" + + settings = config.settings() + assert settings["in_memory_prop"] is False + assert settings["persisted_prop"] == "persisted_value" + + +def test_in_memory_property_priority(mock_yaml_file): + """Test that setter takes precedence over env var for in-memory properties""" + with ( + patch( + "kiln_ai.utils.config.Config.settings_path", + return_value=mock_yaml_file, + ), + patch.dict(os.environ, {"IN_MEMORY_TEST_VAR": "env_value"}), + ): + config = Config( + properties={ + "in_memory_prop": ConfigProperty( + str, + default="default_value", + env_var="IN_MEMORY_TEST_VAR", + in_memory=True, + ), + } + ) + assert config.in_memory_prop == "env_value" + + config.in_memory_prop = "set_value" + assert config.in_memory_prop == "set_value" + + +def test_autosave_runs_is_in_memory(): + """Test that autosave_runs is an in-memory property""" + config = Config.shared() + assert config._properties["autosave_runs"].in_memory is True + assert config.autosave_runs is True + + +def test_update_settings_filters_in_memory_properties(mock_yaml_file): + """Test that update_settings 
routes in-memory properties correctly""" + with patch( + "kiln_ai.utils.config.Config.settings_path", + return_value=mock_yaml_file, + ): + config = Config( + properties={ + "in_memory_prop": ConfigProperty(bool, default=True, in_memory=True), + "persisted_prop": ConfigProperty(str, default="default"), + } + ) + + config.update_settings( + { + "in_memory_prop": False, + "persisted_prop": "updated_value", + } + ) + + assert config.in_memory_prop is False + assert config.persisted_prop == "updated_value" + + with open(mock_yaml_file, "r") as f: + saved_settings = yaml.safe_load(f) + + assert "in_memory_prop" not in saved_settings + assert saved_settings["persisted_prop"] == "updated_value" + + +def test_save_setting_filters_in_memory_properties(mock_yaml_file): + """Test that save_setting routes in-memory properties correctly""" + with patch( + "kiln_ai.utils.config.Config.settings_path", + return_value=mock_yaml_file, + ): + config = Config( + properties={ + "in_memory_prop": ConfigProperty(bool, default=True, in_memory=True), + } + ) + + config.save_setting("in_memory_prop", False) + + assert config.in_memory_prop is False + + assert not os.path.exists(mock_yaml_file) + + +def test_in_memory_property_sensitive_hidden(mock_yaml_file): + """Test that sensitive in-memory properties are hidden in settings()""" + with patch( + "kiln_ai.utils.config.Config.settings_path", + return_value=mock_yaml_file, + ): + config = Config( + properties={ + "in_memory_secret": ConfigProperty( + str, default="default", in_memory=True, sensitive=True + ), + "in_memory_public": ConfigProperty( + str, default="default", in_memory=True + ), + } + ) + config.in_memory_secret = "secret_value" + config.in_memory_public = "public_value" + + settings = config.settings(hide_sensitive=True) + assert settings["in_memory_secret"] == "[hidden]" + assert settings["in_memory_public"] == "public_value" + + settings_visible = config.settings(hide_sensitive=False) + assert settings_visible["in_memory_secret"] == "secret_value" + assert settings_visible["in_memory_public"] == "public_value" diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 72d673630..d96bd594a 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "kiln-ai" -version = "0.24.0" +version = "0.25.0" requires-python = ">=3.10" readme = "README.md" description = 'Kiln AI' @@ -41,8 +41,8 @@ dependencies = [ "pillow>=11.1.0", "llama-index-vector-stores-lancedb>=0.4.2", "mcp[cli]>=1.10.1", - "litellm>=1.80.9", "typer>=0.9.0", + "litellm>=1.81.16", ] [project.scripts] diff --git a/libs/server/kiln_server/run_api.py b/libs/server/kiln_server/run_api.py index ea13104d2..f123d421e 100644 --- a/libs/server/kiln_server/run_api.py +++ b/libs/server/kiln_server/run_api.py @@ -56,10 +56,6 @@ class RunTaskRequest(BaseModel): plaintext_input: str | None = None structured_input: StructuredInputType | None = None tags: list[str] | None = None - task_run_id: str | None = Field( - default=None, - description="When set, continue an existing session. The new message is appended to the run's trace.", - ) # Allows use of the model_name field (usually pydantic will reserve model_*) model_config = ConfigDict(protected_namespaces=()) @@ -285,28 +281,7 @@ async def run_task( detail="No input provided. 
Ensure your provided the proper format (plaintext or structured).", ) - existing_run: TaskRun | None = None - if request.task_run_id is not None: - if task.path is None: - raise HTTPException( - status_code=400, - detail="Cannot continue session: task has no path. Save the task first.", - ) - existing_run = TaskRun.from_id_and_parent_path( - request.task_run_id, task.path - ) - if existing_run is None: - raise HTTPException( - status_code=404, - detail="Run not found. Cannot continue session.", - ) - if not existing_run.trace or len(existing_run.trace) == 0: - raise HTTPException( - status_code=400, - detail="Run has no trace. Cannot continue session without conversation history.", - ) - - return await adapter.invoke(input, existing_run=existing_run) + return await adapter.invoke(input) @app.patch("/api/projects/{project_id}/tasks/{task_id}/runs/{run_id}") async def update_run( diff --git a/libs/server/kiln_server/test_run_api.py b/libs/server/kiln_server/test_run_api.py index 608a03a94..974f1d433 100644 --- a/libs/server/kiln_server/test_run_api.py +++ b/libs/server/kiln_server/test_run_api.py @@ -143,135 +143,6 @@ async def test_run_task_success(client, task_run_setup): assert res["id"] is not None -@pytest.mark.asyncio -async def test_run_task_with_task_run_id_continues_session(client, task_run_setup): - """Test that run_task with task_run_id passes it to adapter.invoke for session continuation.""" - project = task_run_setup["project"] - task = task_run_setup["task"] - task_run = task_run_setup["task_run"] - - run_task_request = { - "run_config_properties": { - "model_name": "gpt_4o", - "model_provider_name": "ollama", - "prompt_id": "simple_prompt_builder", - "structured_output_mode": "json_schema", - }, - "plaintext_input": "Follow-up message", - "task_run_id": task_run.id, - } - - continued_run = TaskRun( - parent=task, - input=task_run.input, - input_source=task_run.input_source, - output=TaskOutput( - output="Continued response", - source=task_run.output.source, - ), - ) - continued_run.id = task_run.id - - with ( - patch("kiln_server.run_api.task_from_id") as mock_task_from_id, - patch.object(LiteLlmAdapter, "invoke", new_callable=AsyncMock) as mock_invoke, - patch("kiln_ai.utils.config.Config.shared") as MockConfig, - ): - mock_task_from_id.return_value = task - mock_invoke.return_value = continued_run - - mock_config_instance = MockConfig.return_value - mock_config_instance.ollama_base_url = "http://localhost:11434/v1" - - response = client.post( - f"/api/projects/{project.id}/tasks/{task.id}/run", json=run_task_request - ) - - assert response.status_code == 200 - mock_invoke.assert_called_once() - call_kwargs = mock_invoke.call_args[1] - assert call_kwargs["existing_run"].id == task_run.id - assert mock_invoke.call_args[0][0] == "Follow-up message" - res = response.json() - assert res["output"]["output"] == "Continued response" - - -@pytest.mark.asyncio -async def test_run_task_task_run_id_not_found_returns_404(client, task_run_setup): - """Test that run_task with nonexistent task_run_id returns 404.""" - project = task_run_setup["project"] - task = task_run_setup["task"] - - run_task_request = { - "run_config_properties": { - "model_name": "gpt_4o", - "model_provider_name": "ollama", - "prompt_id": "simple_prompt_builder", - "structured_output_mode": "json_schema", - }, - "plaintext_input": "Follow-up", - "task_run_id": "nonexistent-run-id", - } - - with patch("kiln_server.run_api.task_from_id") as mock_task_from_id: - mock_task_from_id.return_value = task - response = 
client.post( - f"/api/projects/{project.id}/tasks/{task.id}/run", json=run_task_request - ) - - assert response.status_code == 404 - assert "Run not found" in response.json()["message"] - - -@pytest.mark.asyncio -async def test_run_task_task_run_id_no_trace_returns_400(client, task_run_setup): - """Test that run_task with task_run_id for run without trace returns 400.""" - project = task_run_setup["project"] - task = task_run_setup["task"] - - task_run_no_trace = TaskRun( - parent=task, - input="Hello", - input_source=DataSource( - type=DataSourceType.human, properties={"created_by": "Test User"} - ), - output=TaskOutput( - output="Hi", - source=DataSource( - type=DataSourceType.synthetic, - properties={ - "model_name": "gpt_4o", - "model_provider": "ollama", - "adapter_name": "kiln_langchain_adapter", - "prompt_id": "simple_prompt_builder", - }, - ), - ), - trace=None, - ) - task_run_no_trace.save_to_file() - - run_task_request = { - "run_config_properties": { - "model_name": "gpt_4o", - "model_provider_name": "ollama", - "prompt_id": "simple_prompt_builder", - "structured_output_mode": "json_schema", - }, - "plaintext_input": "Follow-up", - "task_run_id": task_run_no_trace.id, - } - - with patch("kiln_server.run_api.task_from_id") as mock_task_from_id: - mock_task_from_id.return_value = task - response = client.post( - f"/api/projects/{project.id}/tasks/{task.id}/run", json=run_task_request - ) - - assert response.status_code == 400 - assert "no trace" in response.json()["message"].lower() - - @pytest.mark.asyncio async def test_run_task_structured_output(client, task_run_setup): task = task_run_setup["task"] @@ -1918,7 +1789,7 @@ def _assert_math_tools_response(res: dict, expected_in_output: str) -> None: async def test_run_task_adapter_sanity_math_tools( client, adapter_sanity_check_math_tools_setup ): - """Multi-turn run with built-in Kiln math tools. Test that tools + continue session work as expected.""" + """Multiple runs with built-in Kiln math tools. Test that tools work across independent runs.""" if not os.environ.get("OPENROUTER_API_KEY"): pytest.skip("OPENROUTER_API_KEY required for this test") @@ -1950,19 +1821,16 @@ async def test_run_task_adapter_sanity_math_tools( assert response1.status_code == 200 res1 = response1.json() _assert_math_tools_response(res1, "4") - task_run_id = res1["id"] response2 = client.post( f"/api/projects/{project.id}/tasks/{task.id}/run", json={ "run_config_properties": run_config, "plaintext_input": "What is 3 times 4? Use the tools to calculate.", - "task_run_id": task_run_id, }, ) assert response2.status_code == 200 res2 = response2.json() - assert res2["id"] == task_run_id _assert_math_tools_response(res2, "12") response3 = client.post( @@ -1970,24 +1838,19 @@ async def test_run_task_adapter_sanity_math_tools( json={ "run_config_properties": run_config, "plaintext_input": "What is 7 times 8 plus 3? Use the tools to calculate.", - "task_run_id": task_run_id, }, ) assert response3.status_code == 200 res3 = response3.json() - assert res3["id"] == task_run_id _assert_math_tools_response(res3, "59") - # now ask it to list out all the previous results in an array response4 = client.post( f"/api/projects/{project.id}/tasks/{task.id}/run", json={ "run_config_properties": run_config, - "plaintext_input": "List all the previous results in an array - e.g. [55, 81, 7].", - "task_run_id": task_run_id, + "plaintext_input": "What is 10 minus 3? 
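With task_run_id removed from RunTaskRequest, every POST to the run endpoint is a self-contained run and no conversation history is threaded between calls, which is why this reworked sanity test asks four independent math questions. The request body now reduces to the shape below; a sketch mirroring the tests, where client, project, and task stand in for the usual fixtures:

# Sketch only: `client` is a FastAPI TestClient and `project`/`task` come
# from fixtures, as in the surrounding tests.
payload = {
    "run_config_properties": {
        "model_name": "gpt_4o",
        "model_provider_name": "ollama",
        "prompt_id": "simple_prompt_builder",
        "structured_output_mode": "json_schema",
    },
    "plaintext_input": "What is 2 plus 2? Use the tools to calculate.",
    # "task_run_id" is no longer accepted; each call starts a fresh run.
}
response = client.post(f"/api/projects/{project.id}/tasks/{task.id}/run", json=payload)
assert response.status_code == 200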
Use the tools to calculate.", }, ) assert response4.status_code == 200 res4 = response4.json() - assert res4["id"] == task_run_id - assert res4["output"]["output"] == "[4, 12, 59]" + _assert_math_tools_response(res4, "7") diff --git a/libs/server/pyproject.toml b/libs/server/pyproject.toml index 8514c529f..a653f650a 100644 --- a/libs/server/pyproject.toml +++ b/libs/server/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "kiln-server" -version = "0.24.0" +version = "0.25.0" requires-python = ">=3.10" description = 'Kiln AI Server' readme = "README.md" diff --git a/pyproject.toml b/pyproject.toml index 652c995d2..2c9a1b632 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dev = [ "ruff>=0.15.0", "watchfiles>=1.1.0", "scalar-fastapi>=1.4.3", - "ty>=0.0.2", + "ty==0.0.8", ] [tool.uv] diff --git a/uv.lock b/uv.lock index e74173fea..4dd1f62fc 100644 --- a/uv.lock +++ b/uv.lock @@ -1295,31 +1295,34 @@ wheels = [ [[package]] name = "hf-xet" -version = "1.2.0" +version = "1.3.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/5e/6e/0f11bacf08a67f7fb5ee09740f2ca54163863b07b70d579356e9222ce5d8/hf_xet-1.2.0.tar.gz", hash = "sha256:a8c27070ca547293b6890c4bf389f713f80e8c478631432962bb7f4bc0bd7d7f", size = 506020, upload-time = "2025-10-24T19:04:32.129Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/a5/85ef910a0aa034a2abcfadc360ab5ac6f6bc4e9112349bd40ca97551cff0/hf_xet-1.2.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:ceeefcd1b7aed4956ae8499e2199607765fbd1c60510752003b6cc0b8413b649", size = 2861870, upload-time = "2025-10-24T19:04:11.422Z" }, - { url = "https://files.pythonhosted.org/packages/ea/40/e2e0a7eb9a51fe8828ba2d47fe22a7e74914ea8a0db68a18c3aa7449c767/hf_xet-1.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b70218dd548e9840224df5638fdc94bd033552963cfa97f9170829381179c813", size = 2717584, upload-time = "2025-10-24T19:04:09.586Z" }, - { url = "https://files.pythonhosted.org/packages/a5/7d/daf7f8bc4594fdd59a8a596f9e3886133fdc68e675292218a5e4c1b7e834/hf_xet-1.2.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d40b18769bb9a8bc82a9ede575ce1a44c75eb80e7375a01d76259089529b5dc", size = 3315004, upload-time = "2025-10-24T19:04:00.314Z" }, - { url = "https://files.pythonhosted.org/packages/b1/ba/45ea2f605fbf6d81c8b21e4d970b168b18a53515923010c312c06cd83164/hf_xet-1.2.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd3a6027d59cfb60177c12d6424e31f4b5ff13d8e3a1247b3a584bf8977e6df5", size = 3222636, upload-time = "2025-10-24T19:03:58.111Z" }, - { url = "https://files.pythonhosted.org/packages/4a/1d/04513e3cab8f29ab8c109d309ddd21a2705afab9d52f2ba1151e0c14f086/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6de1fc44f58f6dd937956c8d304d8c2dea264c80680bcfa61ca4a15e7b76780f", size = 3408448, upload-time = "2025-10-24T19:04:20.951Z" }, - { url = "https://files.pythonhosted.org/packages/f0/7c/60a2756d7feec7387db3a1176c632357632fbe7849fce576c5559d4520c7/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f182f264ed2acd566c514e45da9f2119110e48a87a327ca271027904c70c5832", size = 3503401, upload-time = "2025-10-24T19:04:22.549Z" }, - { url = "https://files.pythonhosted.org/packages/4e/64/48fffbd67fb418ab07451e4ce641a70de1c40c10a13e25325e24858ebe5a/hf_xet-1.2.0-cp313-cp313t-win_amd64.whl", hash = "sha256:293a7a3787e5c95d7be1857358a9130694a9c6021de3f27fa233f37267174382", size = 2900866, upload-time = "2025-10-24T19:04:33.461Z" }, - 
{ url = "https://files.pythonhosted.org/packages/e2/51/f7e2caae42f80af886db414d4e9885fac959330509089f97cccb339c6b87/hf_xet-1.2.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:10bfab528b968c70e062607f663e21e34e2bba349e8038db546646875495179e", size = 2861861, upload-time = "2025-10-24T19:04:19.01Z" }, - { url = "https://files.pythonhosted.org/packages/6e/1d/a641a88b69994f9371bd347f1dd35e5d1e2e2460a2e350c8d5165fc62005/hf_xet-1.2.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2a212e842647b02eb6a911187dc878e79c4aa0aa397e88dd3b26761676e8c1f8", size = 2717699, upload-time = "2025-10-24T19:04:17.306Z" }, - { url = "https://files.pythonhosted.org/packages/df/e0/e5e9bba7d15f0318955f7ec3f4af13f92e773fbb368c0b8008a5acbcb12f/hf_xet-1.2.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30e06daccb3a7d4c065f34fc26c14c74f4653069bb2b194e7f18f17cbe9939c0", size = 3314885, upload-time = "2025-10-24T19:04:07.642Z" }, - { url = "https://files.pythonhosted.org/packages/21/90/b7fe5ff6f2b7b8cbdf1bd56145f863c90a5807d9758a549bf3d916aa4dec/hf_xet-1.2.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:29c8fc913a529ec0a91867ce3d119ac1aac966e098cf49501800c870328cc090", size = 3221550, upload-time = "2025-10-24T19:04:05.55Z" }, - { url = "https://files.pythonhosted.org/packages/6f/cb/73f276f0a7ce46cc6a6ec7d6c7d61cbfe5f2e107123d9bbd0193c355f106/hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e159cbfcfbb29f920db2c09ed8b660eb894640d284f102ada929b6e3dc410a", size = 3408010, upload-time = "2025-10-24T19:04:28.598Z" }, - { url = "https://files.pythonhosted.org/packages/b8/1e/d642a12caa78171f4be64f7cd9c40e3ca5279d055d0873188a58c0f5fbb9/hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9c91d5ae931510107f148874e9e2de8a16052b6f1b3ca3c1b12f15ccb491390f", size = 3503264, upload-time = "2025-10-24T19:04:30.397Z" }, - { url = "https://files.pythonhosted.org/packages/17/b5/33764714923fa1ff922770f7ed18c2daae034d21ae6e10dbf4347c854154/hf_xet-1.2.0-cp314-cp314t-win_amd64.whl", hash = "sha256:210d577732b519ac6ede149d2f2f34049d44e8622bf14eb3d63bbcd2d4b332dc", size = 2901071, upload-time = "2025-10-24T19:04:37.463Z" }, - { url = "https://files.pythonhosted.org/packages/96/2d/22338486473df5923a9ab7107d375dbef9173c338ebef5098ef593d2b560/hf_xet-1.2.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:46740d4ac024a7ca9b22bebf77460ff43332868b661186a8e46c227fdae01848", size = 2866099, upload-time = "2025-10-24T19:04:15.366Z" }, - { url = "https://files.pythonhosted.org/packages/7f/8c/c5becfa53234299bc2210ba314eaaae36c2875e0045809b82e40a9544f0c/hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:27df617a076420d8845bea087f59303da8be17ed7ec0cd7ee3b9b9f579dff0e4", size = 2722178, upload-time = "2025-10-24T19:04:13.695Z" }, - { url = "https://files.pythonhosted.org/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3651fd5bfe0281951b988c0facbe726aa5e347b103a675f49a3fa8144c7968fd", size = 3320214, upload-time = "2025-10-24T19:04:03.596Z" }, - { url = "https://files.pythonhosted.org/packages/46/92/3f7ec4a1b6a65bf45b059b6d4a5d38988f63e193056de2f420137e3c3244/hf_xet-1.2.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d06fa97c8562fb3ee7a378dd9b51e343bc5bc8190254202c9771029152f5e08c", size = 3229054, upload-time = "2025-10-24T19:04:01.949Z" }, - { url = 
"https://files.pythonhosted.org/packages/0b/dd/7ac658d54b9fb7999a0ccb07ad863b413cbaf5cf172f48ebcd9497ec7263/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4c1428c9ae73ec0939410ec73023c4f842927f39db09b063b9482dac5a3bb737", size = 3413812, upload-time = "2025-10-24T19:04:24.585Z" }, - { url = "https://files.pythonhosted.org/packages/92/68/89ac4e5b12a9ff6286a12174c8538a5930e2ed662091dd2572bbe0a18c8a/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865", size = 3508920, upload-time = "2025-10-24T19:04:26.927Z" }, - { url = "https://files.pythonhosted.org/packages/cb/44/870d44b30e1dcfb6a65932e3e1506c103a8a5aea9103c337e7a53180322c/hf_xet-1.2.0-cp37-abi3-win_amd64.whl", hash = "sha256:e6584a52253f72c9f52f9e549d5895ca7a471608495c4ecaa6cc73dba2b24d69", size = 2905735, upload-time = "2025-10-24T19:04:35.928Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/a6/d0/73454ef7ca885598a3194d07d5c517d91a840753c5b35d272600d7907f64/hf_xet-1.3.1.tar.gz", hash = "sha256:513aa75f8dc39a63cc44dbc8d635ccf6b449e07cdbd8b2e2d006320d2e4be9bb", size = 641393, upload-time = "2026-02-25T00:57:56.701Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/79/9b6a5614230d7a871442d8d8e1c270496821638ba3a9baac16a5b9166200/hf_xet-1.3.1-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:08b231260c68172c866f7aa7257c165d0c87887491aafc5efeee782731725366", size = 3759716, upload-time = "2026-02-25T00:57:41.052Z" }, + { url = "https://files.pythonhosted.org/packages/d4/de/72acb8d7702b3cf9b36a68e8380f3114bf04f9f21cf9e25317457fe31f00/hf_xet-1.3.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:0810b69c64e96dee849036193848007f665dca2311879c9ea8693f4fc37f1795", size = 3518075, upload-time = "2026-02-25T00:57:39.605Z" }, + { url = "https://files.pythonhosted.org/packages/1d/5c/ed728d8530fec28da88ee882b522fccf00dc98e9d7bae4cdb0493070cb17/hf_xet-1.3.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ecd38f98e7f0f41108e30fd4a9a5553ec30cf726df7473dd3e75a1b6d56728c2", size = 4174369, upload-time = "2026-02-25T00:57:32.697Z" }, + { url = "https://files.pythonhosted.org/packages/3c/db/785a0e20aa3086948a26573f1d4ff5c090e63564bf0a52d32eb5b4d82e8d/hf_xet-1.3.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:65411867d46700765018b1990eb1604c3bf0bf576d9e65fc57fdcc10797a2eb9", size = 3953249, upload-time = "2026-02-25T00:57:30.096Z" }, + { url = "https://files.pythonhosted.org/packages/c4/6a/51b669c1e3dbd9374b61356f554e8726b9e1c1d6a7bee5d727d3913b10ad/hf_xet-1.3.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:1684c840c60da12d76c2a031ba40e4b154fdbf9593836fcf5ff090d95a033c61", size = 4152989, upload-time = "2026-02-25T00:57:48.308Z" }, + { url = "https://files.pythonhosted.org/packages/df/31/de07e26e396f46d13a09251df69df9444190e93e06a9d30d639e96c8a0ed/hf_xet-1.3.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:b3012c0f2ce1f0863338491a2bc0fd3f84aded0e147ab25f230da1f5249547fd", size = 4390709, upload-time = "2026-02-25T00:57:49.845Z" }, + { url = "https://files.pythonhosted.org/packages/e3/c1/fcb010b54488c2c112224f55b71f80e44d1706d9b764a0966310b283f86e/hf_xet-1.3.1-cp313-cp313t-win_amd64.whl", hash = "sha256:4eb432e1aa707a65a7e1f8455e40c5b47431d44fe0fb1b0c5d53848c27469398", size = 3634142, upload-time = "2026-02-25T00:57:59.063Z" }, + { url = 
"https://files.pythonhosted.org/packages/da/a6/9ef49cc601c68209979661b3e0b6659fc5a47bfb40f3ebf29eae9ee09e5c/hf_xet-1.3.1-cp313-cp313t-win_arm64.whl", hash = "sha256:e56104c84b2a88b9c7b23ba11a2d7ed0ccbe96886b3f985a50cedd2f0e99853f", size = 3494918, upload-time = "2026-02-25T00:57:57.654Z" }, + { url = "https://files.pythonhosted.org/packages/e7/f5/66adbb1f54a1b3c6da002fa36d4405901ddbcb7d927d780db17ce18ab99d/hf_xet-1.3.1-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:6517a245e41df3eae5adc5f9e8c86fa52abd548de798cbcd989f0082152860aa", size = 3759781, upload-time = "2026-02-25T00:57:47.017Z" }, + { url = "https://files.pythonhosted.org/packages/1e/75/189d91a90480c142cc710c1baa35ece20e8652d5fe5c9b2364a13573d827/hf_xet-1.3.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:4a322d506c513f98fdc1aa2aaa825daefd535b686e80ca789e6d33fcb146f524", size = 3517533, upload-time = "2026-02-25T00:57:45.812Z" }, + { url = "https://files.pythonhosted.org/packages/c6/52/52dd1ab6c29661e29585f3c10d14572e2535a3a472f27a0a46215b0f4659/hf_xet-1.3.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8f16ec9d26badec46334a798e01b5d86af536924789c95b1a1ec6a05f26523e0", size = 4174082, upload-time = "2026-02-25T00:57:38.171Z" }, + { url = "https://files.pythonhosted.org/packages/14/03/460add181c79e2ea1527d2ad27788ecccaee1d5a82563f9402e25ee627e4/hf_xet-1.3.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:e1f5d72bd5b73e61530fff573bcff34bdb64af2bf4862cdd516e6c1dab4dc75b", size = 3952874, upload-time = "2026-02-25T00:57:36.942Z" }, + { url = "https://files.pythonhosted.org/packages/01/56/bf78f18890dfc8caa907830e95424dce0887d5c45efde13f23c9ebbaa8ef/hf_xet-1.3.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:4bc71afd853508b2ddf123b8fc9de71b0afa4c956ec730b69fb76103781e94cd", size = 4152325, upload-time = "2026-02-25T00:57:54.081Z" }, + { url = "https://files.pythonhosted.org/packages/3c/94/91685c6a4a7f513097a6a73b1e879024304cd0eae78080e3d737622f2fd9/hf_xet-1.3.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:541b4b00ed294ae6cfd9416de9506e58971013714d7316189c9638ed54e362d4", size = 4390499, upload-time = "2026-02-25T00:57:55.258Z" }, + { url = "https://files.pythonhosted.org/packages/79/1b/1e72c8ea1f31ef94640d1f265630d35b97b2ef31fe12696bbcc32dbcdc95/hf_xet-1.3.1-cp314-cp314t-win_amd64.whl", hash = "sha256:f85480b4fe3e8e4cdbc59ef1d235152b732fd57ca439cc983c291892945ae818", size = 3634352, upload-time = "2026-02-25T00:58:04.749Z" }, + { url = "https://files.pythonhosted.org/packages/cf/61/b59e87a7a10b95c4578a6ce555339b2f002035569dfd366662b9f59975a8/hf_xet-1.3.1-cp314-cp314t-win_arm64.whl", hash = "sha256:83a8830160392ef4bea78d443ea2cf1febe65783b3843a8f12c64b368981e7e2", size = 3494371, upload-time = "2026-02-25T00:58:03.422Z" }, + { url = "https://files.pythonhosted.org/packages/75/f8/c2da4352c0335df6ae41750cf5bab09fdbfc30d3b4deeed9d621811aa835/hf_xet-1.3.1-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:581d1809a016f7881069d86a072168a8199a46c839cf394ff53970a47e4f1ca1", size = 3761755, upload-time = "2026-02-25T00:57:43.621Z" }, + { url = "https://files.pythonhosted.org/packages/c0/e5/a2f3eaae09da57deceb16a96ebe9ae1f6f7b9b94145a9cd3c3f994e7782a/hf_xet-1.3.1-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:329c80c86f2dda776bafd2e4813a46a3ee648dce3ac0c84625902c70d7a6ddba", size = 3523677, upload-time = "2026-02-25T00:57:42.3Z" }, + { url = 
"https://files.pythonhosted.org/packages/61/cd/acbbf9e51f17d8cef2630e61741228e12d4050716619353efc1ac119f902/hf_xet-1.3.1-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2973c3ff594c3a8da890836308cae1444c8af113c6f10fe6824575ddbc37eca7", size = 4178557, upload-time = "2026-02-25T00:57:35.399Z" }, + { url = "https://files.pythonhosted.org/packages/df/4f/014c14c4ae3461d9919008d0bed2f6f35ba1741e28b31e095746e8dac66f/hf_xet-1.3.1-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ed4bfd2e6d10cb86c9b0f3483df1d7dd2d0220f75f27166925253bacbc1c2dbe", size = 3958975, upload-time = "2026-02-25T00:57:34.004Z" }, + { url = "https://files.pythonhosted.org/packages/86/50/043f5c5a26f3831c3fa2509c17fcd468fd02f1f24d363adc7745fbe661cb/hf_xet-1.3.1-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:713913387cc76e300116030705d843a9f15aee86158337eeffb9eb8d26f47fcd", size = 4158298, upload-time = "2026-02-25T00:57:51.14Z" }, + { url = "https://files.pythonhosted.org/packages/08/9c/b667098a636a88358dbeb2caf90e3cb9e4b961f61f6c55bb312793424def/hf_xet-1.3.1-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e5063789c9d21f51e9ed4edbee8539655d3486e9cad37e96b7af967da20e8b16", size = 4395743, upload-time = "2026-02-25T00:57:52.783Z" }, + { url = "https://files.pythonhosted.org/packages/70/37/4db0e4e1534270800cfffd5a7e0b338f2137f8ceb5768000147650d34ea9/hf_xet-1.3.1-cp37-abi3-win_amd64.whl", hash = "sha256:607d5bbc2730274516714e2e442a26e40e3330673ac0d0173004461409147dee", size = 3638145, upload-time = "2026-02-25T00:58:02.167Z" }, + { url = "https://files.pythonhosted.org/packages/4e/46/1ba8d36f8290a4b98f78898bdce2b0e8fe6d9a59df34a1399eb61a8d877f/hf_xet-1.3.1-cp37-abi3-win_arm64.whl", hash = "sha256:851b1be6597a87036fe7258ce7578d5df3c08176283b989c3b165f94125c5097", size = 3500490, upload-time = "2026-02-25T00:58:00.667Z" }, ] [[package]] @@ -1361,7 +1364,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "1.1.5" +version = "1.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -1370,14 +1373,13 @@ dependencies = [ { name = "httpx" }, { name = "packaging" }, { name = "pyyaml" }, - { name = "shellingham" }, { name = "tqdm" }, - { name = "typer-slim" }, + { name = "typer" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fb/02/c3d534d7498ba2792da1d2ce56b5d38bbcbcbbba62071c90ee289b408e8d/huggingface_hub-1.1.5.tar.gz", hash = "sha256:40ba5c9a08792d888fde6088920a0a71ab3cd9d5e6617c81a797c657f1fd9968", size = 607199, upload-time = "2025-11-20T15:49:32.809Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/76/b5efb3033d8499b17f9386beaf60f64c461798e1ee16d10bc9c0077beba5/huggingface_hub-1.5.0.tar.gz", hash = "sha256:f281838db29265880fb543de7a23b0f81d3504675de82044307ea3c6c62f799d", size = 695872, upload-time = "2026-02-26T15:35:32.745Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/35/f4/124858007ddf3c61e9b144107304c9152fa80b5b6c168da07d86fe583cc1/huggingface_hub-1.1.5-py3-none-any.whl", hash = "sha256:e88ecc129011f37b868586bbcfae6c56868cae80cd56a79d61575426a3aa0d7d", size = 516000, upload-time = "2025-11-20T15:49:30.926Z" }, + { url = "https://files.pythonhosted.org/packages/ec/74/2bc951622e2dbba1af9a460d93c51d15e458becd486e62c29cc0ccb08178/huggingface_hub-1.5.0-py3-none-any.whl", hash = "sha256:c9c0b3ab95a777fc91666111f3b3ede71c0cdced3614c553a64e98920585c4ee", size = 596261, upload-time = "2026-02-26T15:35:31.1Z" }, ] [[package]] @@ -1391,14 +1393,14 @@ wheels = [ 
[[package]] name = "importlib-metadata" -version = "8.7.0" +version = "8.7.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "zipp" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/76/66/650a33bd90f786193e4de4b3ad86ea60b53c89b669a5c7be931fac31cdb0/importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000", size = 56641, upload-time = "2025-04-27T15:29:01.736Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/49/3b30cad09e7771a4982d9975a8cbf64f00d4a1ececb53297f1d9a7be1b10/importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb", size = 57107, upload-time = "2025-12-21T10:00:19.278Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656, upload-time = "2025-04-27T15:29:00.214Z" }, + { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, ] [[package]] @@ -1575,7 +1577,7 @@ wheels = [ [[package]] name = "kiln-ai" -version = "0.24.0" +version = "0.25.0" source = { editable = "libs/core" } dependencies = [ { name = "anyio" }, @@ -1625,7 +1627,7 @@ requires-dist = [ { name = "google-genai", specifier = ">=1.21.1" }, { name = "jsonschema", specifier = ">=4.23.0" }, { name = "lancedb", specifier = ">=0.24.2" }, - { name = "litellm", specifier = ">=1.80.9" }, + { name = "litellm", specifier = ">=1.81.16" }, { name = "llama-index", specifier = ">=0.13.3" }, { name = "llama-index-vector-stores-lancedb", specifier = ">=0.4.2" }, { name = "mcp", extras = ["cli"], specifier = ">=1.10.1" }, @@ -1697,13 +1699,13 @@ dev = [ { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "ruff", specifier = ">=0.15.0" }, { name = "scalar-fastapi", specifier = ">=1.4.3" }, - { name = "ty", specifier = ">=0.0.2" }, + { name = "ty", specifier = "==0.0.8" }, { name = "watchfiles", specifier = ">=1.1.0" }, ] [[package]] name = "kiln-server" -version = "0.24.0" +version = "0.25.0" source = { editable = "libs/server" } dependencies = [ { name = "fastapi" }, @@ -1750,7 +1752,7 @@ dev = [ [[package]] name = "kiln-studio-desktop" -version = "0.24.0" +version = "0.25.0" source = { editable = "app/desktop" } dependencies = [ { name = "kiln-server" }, @@ -1825,7 +1827,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.81.7" +version = "1.81.16" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -1841,9 +1843,9 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/77/69/cfa8a1d68cd10223a9d9741c411e131aece85c60c29c1102d762738b3e5c/litellm-1.81.7.tar.gz", hash = "sha256:442ff38708383ebee21357b3d936e58938172bae892f03bc5be4019ed4ff4a17", size = 14039864, upload-time = "2026-02-03T19:43:10.633Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/36/3cbb22d6ef88c10f3fa4f04664c2a37e93a2e6f9c51899cd9fd025cb0a50/litellm-1.81.16.tar.gz", hash = "sha256:264a3868942e722cd6c19c2d625524fe624a1b6961c37c22d299dc7ea99823b3", size = 16668405, upload-time = "2026-02-26T13:01:48.429Z" } wheels 
= [ - { url = "https://files.pythonhosted.org/packages/60/95/8cecc7e6377171e4ac96f23d65236af8706d99c1b7b71a94c72206672810/litellm-1.81.7-py3-none-any.whl", hash = "sha256:58466c88c3289c6a3830d88768cf8f307581d9e6c87861de874d1128bb2de90d", size = 12254178, upload-time = "2026-02-03T19:43:08.035Z" }, + { url = "https://files.pythonhosted.org/packages/1f/1e/0022cde913bac87a493e4a182b8768f75e7ae90b64d4e11acb009b18311f/litellm-1.81.16-py3-none-any.whl", hash = "sha256:d6bcc13acbd26719e07bfa6b9923740e88409cbf1f9d626d85fc9ae0e0eec88c", size = 14774277, upload-time = "2026-02-26T13:01:45.652Z" }, ] [[package]] @@ -3749,27 +3751,32 @@ dependencies = [ [[package]] name = "tokenizers" -version = "0.22.1" +version = "0.22.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "huggingface-hub" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1c/46/fb6854cec3278fbfa4a75b50232c77622bc517ac886156e6afbfa4d8fc6e/tokenizers-0.22.1.tar.gz", hash = "sha256:61de6522785310a309b3407bac22d99c4db5dba349935e99e4d15ea2226af2d9", size = 363123, upload-time = "2025-09-19T09:49:23.424Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/bf/33/f4b2d94ada7ab297328fc671fed209368ddb82f965ec2224eb1892674c3a/tokenizers-0.22.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:59fdb013df17455e5f950b4b834a7b3ee2e0271e6378ccb33aa74d178b513c73", size = 3069318, upload-time = "2025-09-19T09:49:11.848Z" }, - { url = "https://files.pythonhosted.org/packages/1c/58/2aa8c874d02b974990e89ff95826a4852a8b2a273c7d1b4411cdd45a4565/tokenizers-0.22.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:8d4e484f7b0827021ac5f9f71d4794aaef62b979ab7608593da22b1d2e3c4edc", size = 2926478, upload-time = "2025-09-19T09:49:09.759Z" }, - { url = "https://files.pythonhosted.org/packages/1e/3b/55e64befa1e7bfea963cf4b787b2cea1011362c4193f5477047532ce127e/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19d2962dd28bc67c1f205ab180578a78eef89ac60ca7ef7cbe9635a46a56422a", size = 3256994, upload-time = "2025-09-19T09:48:56.701Z" }, - { url = "https://files.pythonhosted.org/packages/71/0b/fbfecf42f67d9b7b80fde4aabb2b3110a97fac6585c9470b5bff103a80cb/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:38201f15cdb1f8a6843e6563e6e79f4abd053394992b9bbdf5213ea3469b4ae7", size = 3153141, upload-time = "2025-09-19T09:48:59.749Z" }, - { url = "https://files.pythonhosted.org/packages/17/a9/b38f4e74e0817af8f8ef925507c63c6ae8171e3c4cb2d5d4624bf58fca69/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1cbe5454c9a15df1b3443c726063d930c16f047a3cc724b9e6e1a91140e5a21", size = 3508049, upload-time = "2025-09-19T09:49:05.868Z" }, - { url = "https://files.pythonhosted.org/packages/d2/48/dd2b3dac46bb9134a88e35d72e1aa4869579eacc1a27238f1577270773ff/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7d094ae6312d69cc2a872b54b91b309f4f6fbce871ef28eb27b52a98e4d0214", size = 3710730, upload-time = "2025-09-19T09:49:01.832Z" }, - { url = "https://files.pythonhosted.org/packages/93/0e/ccabc8d16ae4ba84a55d41345207c1e2ea88784651a5a487547d80851398/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:afd7594a56656ace95cdd6df4cca2e4059d294c5cfb1679c57824b605556cb2f", size = 3412560, upload-time = "2025-09-19T09:49:03.867Z" }, - { url = 
"https://files.pythonhosted.org/packages/d0/c6/dc3a0db5a6766416c32c034286d7c2d406da1f498e4de04ab1b8959edd00/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2ef6063d7a84994129732b47e7915e8710f27f99f3a3260b8a38fc7ccd083f4", size = 3250221, upload-time = "2025-09-19T09:49:07.664Z" }, - { url = "https://files.pythonhosted.org/packages/d7/a6/2c8486eef79671601ff57b093889a345dd3d576713ef047776015dc66de7/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ba0a64f450b9ef412c98f6bcd2a50c6df6e2443b560024a09fa6a03189726879", size = 9345569, upload-time = "2025-09-19T09:49:14.214Z" }, - { url = "https://files.pythonhosted.org/packages/6b/16/32ce667f14c35537f5f605fe9bea3e415ea1b0a646389d2295ec348d5657/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:331d6d149fa9c7d632cde4490fb8bbb12337fa3a0232e77892be656464f4b446", size = 9271599, upload-time = "2025-09-19T09:49:16.639Z" }, - { url = "https://files.pythonhosted.org/packages/51/7c/a5f7898a3f6baa3fc2685c705e04c98c1094c523051c805cdd9306b8f87e/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:607989f2ea68a46cb1dfbaf3e3aabdf3f21d8748312dbeb6263d1b3b66c5010a", size = 9533862, upload-time = "2025-09-19T09:49:19.146Z" }, - { url = "https://files.pythonhosted.org/packages/36/65/7e75caea90bc73c1dd8d40438adf1a7bc26af3b8d0a6705ea190462506e1/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a0f307d490295717726598ef6fa4f24af9d484809223bbc253b201c740a06390", size = 9681250, upload-time = "2025-09-19T09:49:21.501Z" }, - { url = "https://files.pythonhosted.org/packages/30/2c/959dddef581b46e6209da82df3b78471e96260e2bc463f89d23b1bf0e52a/tokenizers-0.22.1-cp39-abi3-win32.whl", hash = "sha256:b5120eed1442765cd90b903bb6cfef781fd8fe64e34ccaecbae4c619b7b12a82", size = 2472003, upload-time = "2025-09-19T09:49:27.089Z" }, - { url = "https://files.pythonhosted.org/packages/b3/46/e33a8c93907b631a99377ef4c5f817ab453d0b34f93529421f42ff559671/tokenizers-0.22.1-cp39-abi3-win_amd64.whl", hash = "sha256:65fd6e3fb11ca1e78a6a93602490f134d1fdeb13bcef99389d5102ea318ed138", size = 2674684, upload-time = "2025-09-19T09:49:24.953Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/73/6f/f80cfef4a312e1fb34baf7d85c72d4411afde10978d4657f8cdd811d3ccc/tokenizers-0.22.2.tar.gz", hash = "sha256:473b83b915e547aa366d1eee11806deaf419e17be16310ac0a14077f1e28f917", size = 372115, upload-time = "2026-01-05T10:45:15.988Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/92/97/5dbfabf04c7e348e655e907ed27913e03db0923abb5dfdd120d7b25630e1/tokenizers-0.22.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:544dd704ae7238755d790de45ba8da072e9af3eea688f698b137915ae959281c", size = 3100275, upload-time = "2026-01-05T10:41:02.158Z" }, + { url = "https://files.pythonhosted.org/packages/2e/47/174dca0502ef88b28f1c9e06b73ce33500eedfac7a7692108aec220464e7/tokenizers-0.22.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:1e418a55456beedca4621dbab65a318981467a2b188e982a23e117f115ce5001", size = 2981472, upload-time = "2026-01-05T10:41:00.276Z" }, + { url = "https://files.pythonhosted.org/packages/d6/84/7990e799f1309a8b87af6b948f31edaa12a3ed22d11b352eaf4f4b2e5753/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2249487018adec45d6e3554c71d46eb39fa8ea67156c640f7513eb26f318cec7", size = 3290736, upload-time = "2026-01-05T10:40:32.165Z" }, + { url = 
"https://files.pythonhosted.org/packages/78/59/09d0d9ba94dcd5f4f1368d4858d24546b4bdc0231c2354aa31d6199f0399/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25b85325d0815e86e0bac263506dd114578953b7b53d7de09a6485e4a160a7dd", size = 3168835, upload-time = "2026-01-05T10:40:38.847Z" }, + { url = "https://files.pythonhosted.org/packages/47/50/b3ebb4243e7160bda8d34b731e54dd8ab8b133e50775872e7a434e524c28/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bfb88f22a209ff7b40a576d5324bf8286b519d7358663db21d6246fb17eea2d5", size = 3521673, upload-time = "2026-01-05T10:40:56.614Z" }, + { url = "https://files.pythonhosted.org/packages/e0/fa/89f4cb9e08df770b57adb96f8cbb7e22695a4cb6c2bd5f0c4f0ebcf33b66/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c774b1276f71e1ef716e5486f21e76333464f47bece56bbd554485982a9e03e", size = 3724818, upload-time = "2026-01-05T10:40:44.507Z" }, + { url = "https://files.pythonhosted.org/packages/64/04/ca2363f0bfbe3b3d36e95bf67e56a4c88c8e3362b658e616d1ac185d47f2/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:df6c4265b289083bf710dff49bc51ef252f9d5be33a45ee2bed151114a56207b", size = 3379195, upload-time = "2026-01-05T10:40:51.139Z" }, + { url = "https://files.pythonhosted.org/packages/2e/76/932be4b50ef6ccedf9d3c6639b056a967a86258c6d9200643f01269211ca/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:369cc9fc8cc10cb24143873a0d95438bb8ee257bb80c71989e3ee290e8d72c67", size = 3274982, upload-time = "2026-01-05T10:40:58.331Z" }, + { url = "https://files.pythonhosted.org/packages/1d/28/5f9f5a4cc211b69e89420980e483831bcc29dade307955cc9dc858a40f01/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:29c30b83d8dcd061078b05ae0cb94d3c710555fbb44861139f9f83dcca3dc3e4", size = 9478245, upload-time = "2026-01-05T10:41:04.053Z" }, + { url = "https://files.pythonhosted.org/packages/6c/fb/66e2da4704d6aadebf8cb39f1d6d1957df667ab24cff2326b77cda0dcb85/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:37ae80a28c1d3265bb1f22464c856bd23c02a05bb211e56d0c5301a435be6c1a", size = 9560069, upload-time = "2026-01-05T10:45:10.673Z" }, + { url = "https://files.pythonhosted.org/packages/16/04/fed398b05caa87ce9b1a1bb5166645e38196081b225059a6edaff6440fac/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:791135ee325f2336f498590eb2f11dc5c295232f288e75c99a36c5dbce63088a", size = 9899263, upload-time = "2026-01-05T10:45:12.559Z" }, + { url = "https://files.pythonhosted.org/packages/05/a1/d62dfe7376beaaf1394917e0f8e93ee5f67fea8fcf4107501db35996586b/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:38337540fbbddff8e999d59970f3c6f35a82de10053206a7562f1ea02d046fa5", size = 10033429, upload-time = "2026-01-05T10:45:14.333Z" }, + { url = "https://files.pythonhosted.org/packages/fd/18/a545c4ea42af3df6effd7d13d250ba77a0a86fb20393143bbb9a92e434d4/tokenizers-0.22.2-cp39-abi3-win32.whl", hash = "sha256:a6bf3f88c554a2b653af81f3204491c818ae2ac6fbc09e76ef4773351292bc92", size = 2502363, upload-time = "2026-01-05T10:45:20.593Z" }, + { url = "https://files.pythonhosted.org/packages/65/71/0670843133a43d43070abeb1949abfdef12a86d490bea9cd9e18e37c5ff7/tokenizers-0.22.2-cp39-abi3-win_amd64.whl", hash = "sha256:c9ea31edff2968b44a88f97d784c2f16dc0729b8b143ed004699ebca91f05c48", size = 2747786, upload-time = "2026-01-05T10:45:18.411Z" }, + { url = 
"https://files.pythonhosted.org/packages/72/f4/0de46cfa12cdcbcd464cc59fde36912af405696f687e53a091fb432f694c/tokenizers-0.22.2-cp39-abi3-win_arm64.whl", hash = "sha256:9ce725d22864a1e965217204946f830c37876eee3b2ba6fc6255e8e903d5fcbc", size = 2612133, upload-time = "2026-01-05T10:45:17.232Z" }, + { url = "https://files.pythonhosted.org/packages/84/04/655b79dbcc9b3ac5f1479f18e931a344af67e5b7d3b251d2dcdcd7558592/tokenizers-0.22.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:753d47ebd4542742ef9261d9da92cd545b2cacbb48349a1225466745bb866ec4", size = 3282301, upload-time = "2026-01-05T10:40:34.858Z" }, + { url = "https://files.pythonhosted.org/packages/46/cd/e4851401f3d8f6f45d8480262ab6a5c8cb9c4302a790a35aa14eeed6d2fd/tokenizers-0.22.2-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e10bf9113d209be7cd046d40fbabbaf3278ff6d18eb4da4c500443185dc1896c", size = 3161308, upload-time = "2026-01-05T10:40:40.737Z" }, + { url = "https://files.pythonhosted.org/packages/6f/6e/55553992a89982cd12d4a66dddb5e02126c58677ea3931efcbe601d419db/tokenizers-0.22.2-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64d94e84f6660764e64e7e0b22baa72f6cd942279fdbb21d46abd70d179f0195", size = 3718964, upload-time = "2026-01-05T10:40:46.56Z" }, + { url = "https://files.pythonhosted.org/packages/59/8c/b1c87148aa15e099243ec9f0cf9d0e970cc2234c3257d558c25a2c5304e6/tokenizers-0.22.2-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f01a9c019878532f98927d2bacb79bbb404b43d3437455522a00a30718cdedb5", size = 3373542, upload-time = "2026-01-05T10:40:52.803Z" }, ] [[package]] @@ -3825,27 +3832,27 @@ wheels = [ [[package]] name = "ty" -version = "0.0.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/47/e5/15b6aceefcd64b53997fe2002b6fa055f0b1afd23ff6fc3f55f3da944530/ty-0.0.2.tar.gz", hash = "sha256:e02dc50b65dc58d6cb8e8b0d563833f81bf03ed8a7d0b15c6396d486489a7e1d", size = 4762024, upload-time = "2025-12-16T20:13:41.07Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6b/86/65d4826677d966cf226662767a4a597ebb4b02c432f413673c8d5d3d1ce8/ty-0.0.2-py3-none-linux_armv6l.whl", hash = "sha256:0954a0e0b6f7e06229dd1da3a9989ee9b881a26047139a88eb7c134c585ad22e", size = 9771409, upload-time = "2025-12-16T20:13:28.964Z" }, - { url = "https://files.pythonhosted.org/packages/d4/bc/6ab06b7c109cec608c24ea182cc8b4714e746a132f70149b759817092665/ty-0.0.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d6044b491d66933547033cecc87cb7eb599ba026a3ef347285add6b21107a648", size = 9580025, upload-time = "2025-12-16T20:13:34.507Z" }, - { url = "https://files.pythonhosted.org/packages/54/de/d826804e304b2430f17bb27ae15bcf02380e7f67f38b5033047e3d2523e6/ty-0.0.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fbca7f08e671a35229f6f400d73da92e2dc0a440fba53a74fe8233079a504358", size = 9098660, upload-time = "2025-12-16T20:13:01.278Z" }, - { url = "https://files.pythonhosted.org/packages/b7/8e/5cd87944ceee02bb0826f19ced54e30c6bb971e985a22768f6be6b1a042f/ty-0.0.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3abd61153dac0b93b284d305e6f96085013a25c3a7ab44e988d24f0a5fcce729", size = 9567693, upload-time = "2025-12-16T20:13:12.559Z" }, - { url = "https://files.pythonhosted.org/packages/c6/b1/062aab2c62c5ae01c05d27b97ba022d9ff66f14a3cb9030c5ad1dca797ec/ty-0.0.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:21a9f28caafb5742e7d594104e2fe2ebd64590da31aed4745ae8bc5be67a7b85", size = 9556471, upload-time = "2025-12-16T20:13:07.771Z" }, - { url = "https://files.pythonhosted.org/packages/0e/07/856f6647a9dd6e36560d182d35d3b5fb21eae98a8bfb516cd879d0e509f3/ty-0.0.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d3ec63fd23ab48e0f838fb54a47ec362a972ee80979169a7edfa6f5c5034849d", size = 9971914, upload-time = "2025-12-16T20:13:18.852Z" }, - { url = "https://files.pythonhosted.org/packages/2e/82/c2e3957dbf33a23f793a9239cfd8bd04b6defd999bd0f6e74d6a5afb9f42/ty-0.0.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:e5e2e0293a259c9a53f668c9c13153cc2f1403cb0fe2b886ca054be4ac76517c", size = 10840905, upload-time = "2025-12-16T20:13:37.098Z" }, - { url = "https://files.pythonhosted.org/packages/3b/17/49bd74e3d577e6c88b8074581b7382f532a9d40552cc7c48ceaa83f1d950/ty-0.0.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fd2511ac02a83d0dc45d4570c7e21ec0c919be7a7263bad9914800d0cde47817", size = 10570251, upload-time = "2025-12-16T20:13:10.319Z" }, - { url = "https://files.pythonhosted.org/packages/2b/9b/26741834069722033a1a0963fcbb63ea45925c6697357e64e361753c6166/ty-0.0.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c482bfbfb8ad18b2e62427d02a0c934ac510c414188a3cf00e16b8acc35482f0", size = 10369078, upload-time = "2025-12-16T20:13:20.851Z" }, - { url = "https://files.pythonhosted.org/packages/94/fc/1d34ec891900d9337169ff9f8252fcaa633ae5c4d36b67effd849ed4f9ac/ty-0.0.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb514711eed3f56d7a130d4885f4b5d8e490fdcd2adac098e5cf175573a0dda3", size = 10121064, upload-time = "2025-12-16T20:13:23.095Z" }, - { url = "https://files.pythonhosted.org/packages/e5/02/e640325956172355ef8deb9b08d991f229230bf9d07f1dbda8c6665a3a43/ty-0.0.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b2c37fa26c39e9fbed7c73645ba721968ab44f28b2bfe2f79a4e15965a1c426f", size = 9553817, upload-time = "2025-12-16T20:13:27.057Z" }, - { url = "https://files.pythonhosted.org/packages/35/13/c93d579ece84895da9b0aae5d34d84100bbff63ad9f60c906a533a087175/ty-0.0.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:13b264833ac5f3b214693fca38e380e78ee7327e09beaa5ff2e47d75fcab9692", size = 9577512, upload-time = "2025-12-16T20:13:16.956Z" }, - { url = "https://files.pythonhosted.org/packages/85/53/93ab1570adc799cd9120ea187d5b4c00d821e86eca069943b179fe0d3e83/ty-0.0.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:08658d6dbbf8bdef80c0a77eda56a22ab6737002ba129301b7bbd36bcb7acd75", size = 9692726, upload-time = "2025-12-16T20:13:31.169Z" }, - { url = "https://files.pythonhosted.org/packages/9a/07/5fff5335858a14196776207d231c32e23e48a5c912a7d52c80e7a3fa6f8f/ty-0.0.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4a21b5b012061cb13d47edfff6be70052694308dba633b4c819b70f840e6c158", size = 10213996, upload-time = "2025-12-16T20:13:14.606Z" }, - { url = "https://files.pythonhosted.org/packages/a0/d3/896b1439ab765c57a8d732f73c105ec41142c417a582600638385c2bee85/ty-0.0.2-py3-none-win32.whl", hash = "sha256:d773fdad5d2b30f26313204e6b191cdd2f41ab440a6c241fdb444f8c6593c288", size = 9204906, upload-time = "2025-12-16T20:13:25.099Z" }, - { url = "https://files.pythonhosted.org/packages/5d/0a/f30981e7d637f78e3d08e77d63b818752d23db1bc4b66f9e82e2cb3d34f8/ty-0.0.2-py3-none-win_amd64.whl", hash = "sha256:d1c9ac78a8aa60d0ce89acdccf56c3cc0fcb2de07f1ecf313754d83518e8e8c5", size = 10066640, upload-time = 
"2025-12-16T20:13:04.045Z" }, - { url = "https://files.pythonhosted.org/packages/5a/c4/97958503cf62bfb7908d2a77b03b91a20499a7ff405f5a098c4989589f34/ty-0.0.2-py3-none-win_arm64.whl", hash = "sha256:fbdef644ade0cd4420c4ec14b604b7894cefe77bfd8659686ac2f6aba9d1a306", size = 9572022, upload-time = "2025-12-16T20:13:39.189Z" }, +version = "0.0.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/9d/59e955cc39206a0d58df5374808785c45ec2a8a2a230eb1638fbb4fe5c5d/ty-0.0.8.tar.gz", hash = "sha256:352ac93d6e0050763be57ad1e02087f454a842887e618ec14ac2103feac48676", size = 4828477, upload-time = "2025-12-29T13:50:07.193Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/2b/dd61f7e50a69c72f72c625d026e9ab64a0db62b2dd32e7426b520e2429c6/ty-0.0.8-py3-none-linux_armv6l.whl", hash = "sha256:a289d033c5576fa3b4a582b37d63395edf971cdbf70d2d2e6b8c95638d1a4fcd", size = 9853417, upload-time = "2025-12-29T13:50:08.979Z" }, + { url = "https://files.pythonhosted.org/packages/90/72/3f1d3c64a049a388e199de4493689a51fc6aa5ff9884c03dea52b4966657/ty-0.0.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:788ea97dc8153a94e476c4d57b2551a9458f79c187c4aba48fcb81f05372924a", size = 9657890, upload-time = "2025-12-29T13:50:27.867Z" }, + { url = "https://files.pythonhosted.org/packages/71/d1/08ac676bd536de3c2baba0deb60e67b3196683a2fabebfd35659d794b5e9/ty-0.0.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1b5f1f3d3e230f35a29e520be7c3d90194a5229f755b721e9092879c00842d31", size = 9180129, upload-time = "2025-12-29T13:50:22.842Z" }, + { url = "https://files.pythonhosted.org/packages/af/93/610000e2cfeea1875900f73a375ba917624b0a008d4b8a6c18c894c8dbbc/ty-0.0.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6da9ed377fbbcec0a3b60b2ca5fd30496e15068f47cef2344ba87923e78ba996", size = 9683517, upload-time = "2025-12-29T13:50:18.658Z" }, + { url = "https://files.pythonhosted.org/packages/05/04/bef50ba7d8580b0140be597de5cc0ba9a63abe50d3f65560235f23658762/ty-0.0.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7d0a2bdce5e701d19eb8d46d9da0fe31340f079cecb7c438f5ac6897c73fc5ba", size = 9676279, upload-time = "2025-12-29T13:50:25.207Z" }, + { url = "https://files.pythonhosted.org/packages/aa/b9/2aff1ef1f41b25898bc963173ae67fc8f04ca666ac9439a9c4e78d5cc0ff/ty-0.0.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef9078799d26d3cc65366e02392e2b78f64f72911b599e80a8497d2ec3117ddb", size = 10073015, upload-time = "2025-12-29T13:50:35.422Z" }, + { url = "https://files.pythonhosted.org/packages/df/0e/9feb6794b6ff0a157c3e6a8eb6365cbfa3adb9c0f7976e2abdc48615dd72/ty-0.0.8-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:54814ac39b4ab67cf111fc0a236818155cf49828976152378347a7678d30ee89", size = 10961649, upload-time = "2025-12-29T13:49:58.717Z" }, + { url = "https://files.pythonhosted.org/packages/f4/3b/faf7328b14f00408f4f65c9d01efe52e11b9bcc4a79e06187b370457b004/ty-0.0.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4baf0a80398e8b6c68fa36ff85045a50ede1906cd4edb41fb4fab46d471f1d4", size = 10676190, upload-time = "2025-12-29T13:50:01.11Z" }, + { url = "https://files.pythonhosted.org/packages/64/a5/cfeca780de7eeab7852c911c06a84615a174d23e9ae08aae42a645771094/ty-0.0.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ac8e23c3faefc579686799ef1649af8d158653169ad5c3a7df56b152781eeb67", size = 10438641, upload-time = "2025-12-29T13:50:29.664Z" }, 
+ { url = "https://files.pythonhosted.org/packages/0e/8d/8667c7e0ac9f13c461ded487c8d7350f440cd39ba866d0160a8e1b1efd6c/ty-0.0.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b558a647a073d0c25540aaa10f8947de826cb8757d034dd61ecf50ab8dbd77bf", size = 10214082, upload-time = "2025-12-29T13:50:31.531Z" }, + { url = "https://files.pythonhosted.org/packages/f8/11/e563229870e2c1d089e7e715c6c3b7605a34436dddf6f58e9205823020c2/ty-0.0.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:8c0104327bf480508bd81f320e22074477df159d9eff85207df39e9c62ad5e96", size = 9664364, upload-time = "2025-12-29T13:50:05.443Z" }, + { url = "https://files.pythonhosted.org/packages/b1/ad/05b79b778bf5237bcd7ee08763b226130aa8da872cbb151c8cfa2e886203/ty-0.0.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:496f1cb87261dd1a036a5609da80ee13de2e6ee4718a661bfa2afb91352fe528", size = 9679440, upload-time = "2025-12-29T13:50:11.289Z" }, + { url = "https://files.pythonhosted.org/packages/12/b5/23ba887769c4a7b8abfd1b6395947dc3dcc87533fbf86379d3a57f87ae8f/ty-0.0.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:2c488031f92a075ae39d13ac6295fdce2141164ec38c5d47aa8dc24ee3afa37e", size = 9808201, upload-time = "2025-12-29T13:50:21.003Z" }, + { url = "https://files.pythonhosted.org/packages/f8/90/5a82ac0a0707db55376922aed80cd5fca6b2e6d6e9bcd8c286e6b43b4084/ty-0.0.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:90d6f08c5982fa3e802b8918a32e326153519077b827f91c66eea4913a86756a", size = 10313262, upload-time = "2025-12-29T13:50:03.306Z" }, + { url = "https://files.pythonhosted.org/packages/14/f7/ff97f37f0a75db9495ddbc47738ec4339837867c4bfa145bdcfbd0d1eb2f/ty-0.0.8-py3-none-win32.whl", hash = "sha256:d7f460ad6fc9325e9cc8ea898949bbd88141b4609d1088d7ede02ce2ef06e776", size = 9254675, upload-time = "2025-12-29T13:50:33.35Z" }, + { url = "https://files.pythonhosted.org/packages/af/51/eba5d83015e04630002209e3590c310a0ff1d26e1815af204a322617a42e/ty-0.0.8-py3-none-win_amd64.whl", hash = "sha256:1641fb8dedc3d2da43279d21c3c7c1f80d84eae5c264a1e8daa544458e433c19", size = 10131382, upload-time = "2025-12-29T13:50:13.719Z" }, + { url = "https://files.pythonhosted.org/packages/38/1c/0d8454ff0f0f258737ecfe84f6e508729191d29663b404832f98fa5626b7/ty-0.0.8-py3-none-win_arm64.whl", hash = "sha256:ec74f022f315bede478ecae1277a01ab618e6500c1d68450d7883f5cd6ed554a", size = 9636374, upload-time = "2025-12-29T13:50:16.344Z" }, ] [[package]] @@ -3863,19 +3870,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/fc/5b29fea8cee020515ca82cc68e3b8e1e34bb19a3535ad854cac9257b414c/typer-0.15.2-py3-none-any.whl", hash = "sha256:46a499c6107d645a9c13f7ee46c5d5096cae6f5fc57dd11eccbbb9ae3e44ddfc", size = 45061, upload-time = "2025-02-27T19:17:32.111Z" }, ] -[[package]] -name = "typer-slim" -version = "0.20.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/8e/45/81b94a52caed434b94da65729c03ad0fb7665fab0f7db9ee54c94e541403/typer_slim-0.20.0.tar.gz", hash = "sha256:9fc6607b3c6c20f5c33ea9590cbeb17848667c51feee27d9e314a579ab07d1a3", size = 106561, upload-time = "2025-10-20T17:03:46.642Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5e/dd/5cbf31f402f1cc0ab087c94d4669cfa55bd1e818688b910631e131d74e75/typer_slim-0.20.0-py3-none-any.whl", hash = "sha256:f42a9b7571a12b97dddf364745d29f12221865acef7a2680065f9bb29c7dc89d", size = 47087, upload-time = "2025-10-20T17:03:44.546Z" }, -] - 
[[package]] name = "typing-extensions" version = "4.12.2"