|
| 1 | +import numpy as np |
| 2 | + |
1 | 3 | from deepset_mcp.api.exceptions import UnexpectedAPIError |
2 | 4 | from deepset_mcp.api.protocols import AsyncClientProtocol |
| 5 | +from deepset_mcp.tools.component_helper import ( |
| 6 | + extract_component_info, |
| 7 | + extract_component_texts, |
| 8 | + format_io_info, |
| 9 | +) |
| 10 | +from deepset_mcp.tools.model_protocol import ModelProtocol |
3 | 11 |
|
4 | 12 |
|
5 | 13 | async def get_component_definition(client: AsyncClientProtocol, component_type: str) -> str: |
@@ -32,97 +40,76 @@ async def get_component_definition(client: AsyncClientProtocol, component_type: |
32 | 40 | if not component_def: |
33 | 41 | return f"Component not found: {component_type}" |
34 | 42 |
|
35 | | - # Extract relevant information |
36 | | - component_type_info = component_def["properties"]["type"] |
37 | | - init_params = component_def["properties"].get("init_parameters", {}).get("properties", {}) |
38 | | - |
39 | | - # Format the basic component information |
40 | | - parts = [ |
41 | | - f"Component: {component_type}", |
42 | | - f"Name: {component_def.get('title', 'Unknown')}", |
43 | | - f"Family: {component_type_info.get('family', 'Unknown')}", |
44 | | - f"Family Description: {component_type_info.get('family_description', 'No description available.')}", |
45 | | - f"\nDescription:\n{component_def.get('description', 'No description available.')}\n", |
46 | | - "\nInitialization Parameters:", |
47 | | - ] |
48 | | - |
49 | | - if not init_params: |
50 | | - parts.append(" No initialization parameters") |
51 | | - else: |
52 | | - for param_name, param_info in init_params.items(): |
53 | | - param_type = param_info.get("_annotation", param_info.get("type", "Unknown")) |
54 | | - param_desc = param_info.get("description", "No description available.") |
55 | | - default = f" (default: {param_info['default']})" if "default" in param_info else "" |
56 | | - parts.append(f" {param_name}: {param_type}{default}\n {param_desc}") |
| 43 | + # Get component information |
| 44 | + parts = [extract_component_info(components, component_def)] |
57 | 45 |
|
58 | 46 | # Fetch and add input/output information |
59 | 47 | try: |
60 | 48 | # Extract component name from the full path |
61 | 49 | component_name = component_type.split(".")[-1] |
62 | 50 | io_info = await haystack_service.get_component_input_output(component_name) |
63 | | - |
64 | | - # Add Input Schema |
65 | | - parts.append("\nInput Schema:") |
66 | | - if "input" in io_info: |
67 | | - input_props = io_info["input"].get("properties", {}) |
68 | | - if not input_props: |
69 | | - parts.append(" No input parameters") |
70 | | - else: |
71 | | - required = io_info["input"].get("required", []) |
72 | | - for param_name, param_info in input_props.items(): |
73 | | - req_marker = " (required)" if param_name in required else "" |
74 | | - param_type = param_info.get("_annotation", param_info.get("type", "Unknown")) |
75 | | - param_desc = param_info.get("description", "No description available.") |
76 | | - default = f" (default: {param_info['default']})" if "default" in param_info else "" |
77 | | - parts.append(f" {param_name}: {param_type}{req_marker}{default}\n {param_desc}") |
78 | | - else: |
79 | | - parts.append(" Input schema not available") |
80 | | - |
81 | | - # Add Output Schema |
82 | | - parts.append("\nOutput Schema:") |
83 | | - if "output" in io_info and isinstance(io_info["output"], dict): |
84 | | - output_info = io_info["output"] |
85 | | - if "properties" in output_info: |
86 | | - output_props = output_info.get("properties", {}) |
87 | | - if not output_props: |
88 | | - parts.append(" No output parameters") |
89 | | - else: |
90 | | - required = output_info.get("required", []) |
91 | | - for param_name, param_info in output_props.items(): |
92 | | - req_marker = " (required)" if param_name in required else "" |
93 | | - param_type = param_info.get("_annotation", param_info.get("type", "Unknown")) |
94 | | - param_desc = param_info.get("description", "No description available.") |
95 | | - default = f" (default: {param_info['default']})" if "default" in param_info else "" |
96 | | - parts.append(f" {param_name}: {param_type}{req_marker}{default}\n {param_desc}") |
97 | | - |
98 | | - # Include any definitions if they exist |
99 | | - if "definitions" in output_info: |
100 | | - parts.append("\n Definitions:") |
101 | | - for def_name, def_info in output_info["definitions"].items(): |
102 | | - parts.append(f"\n {def_name}:") |
103 | | - if "properties" in def_info: |
104 | | - def_required = def_info.get("required", []) |
105 | | - for prop_name, prop_info in def_info["properties"].items(): |
106 | | - req_marker = " (required)" if prop_name in def_required else "" |
107 | | - prop_type = prop_info.get("_annotation", prop_info.get("type", "Unknown")) |
108 | | - prop_desc = prop_info.get("description", "No description available.") |
109 | | - default = f" (default: {prop_info['default']})" if "default" in prop_info else "" |
110 | | - parts.append( |
111 | | - f" {prop_name}: {prop_type}{req_marker}{default}\n {prop_desc}" |
112 | | - ) |
113 | | - else: |
114 | | - # Simple output schema |
115 | | - desc = output_info.get("description", "No description available.") |
116 | | - output_type = output_info.get("type", "Unknown") |
117 | | - parts.append(f" Type: {output_type}\n {desc}") |
118 | | - else: |
119 | | - parts.append(" Output schema not available") |
| 51 | + parts.append(format_io_info(io_info)) |
120 | 52 | except Exception as e: |
121 | 53 | parts.append(f"\nFailed to fetch input/output schema: {str(e)}") |
122 | 54 |
|
123 | 55 | return "\n".join(parts) |
124 | 56 |
|
125 | 57 |
|
| 58 | +async def search_component_definition( |
| 59 | + client: AsyncClientProtocol, query: str, model: ModelProtocol, top_k: int = 5 |
| 60 | +) -> str: |
| 61 | + """Searches for components based on name or description using semantic similarity. |
| 62 | +
|
| 63 | + Args: |
| 64 | + client: The API client to use |
| 65 | + query: The search query |
| 66 | + model: The model to use for computing embeddings |
| 67 | + top_k: Maximum number of results to return (default: 5) |
| 68 | +
|
| 69 | + Returns: |
| 70 | + A formatted string containing the matched component definitions |
| 71 | + """ |
| 72 | + haystack_service = client.haystack_service() |
| 73 | + |
| 74 | + try: |
| 75 | + response = await haystack_service.get_component_schemas() |
| 76 | + except UnexpectedAPIError as e: |
| 77 | + return f"Failed to retrieve component schemas: {e}" |
| 78 | + |
| 79 | + components = response["component_schema"]["definitions"]["Components"] |
| 80 | + |
| 81 | + # Extract text for embedding from all components |
| 82 | + component_texts: list[tuple[str, str]] = [extract_component_texts(comp) for comp in components.values()] |
| 83 | + component_types: list[str] = [c[0] for c in component_texts] |
| 84 | + |
| 85 | + if not component_texts: |
| 86 | + return "No components found" |
| 87 | + |
| 88 | + # Compute embeddings |
| 89 | + query_embedding = model.encode(query) |
| 90 | + component_embeddings = model.encode([text for _, text in component_texts]) |
| 91 | + |
| 92 | + query_embedding_reshaped = query_embedding.reshape(1, -1) |
| 93 | + |
| 94 | + # Calculate dot product between target and all paths |
| 95 | + # This gives us a similarity score for each path |
| 96 | + similarities = np.dot(component_embeddings, query_embedding_reshaped.T).flatten() |
| 97 | + |
| 98 | + # Create (path, similarity) pairs |
| 99 | + component_similarities = list(zip(component_types, similarities, strict=False)) |
| 100 | + |
| 101 | + # Sort by similarity score in descending order |
| 102 | + component_similarities.sort(key=lambda x: x[1], reverse=True) |
| 103 | + |
| 104 | + top_components = component_similarities[:top_k] |
| 105 | + results = [] |
| 106 | + for component_type, sim in top_components: |
| 107 | + definition = await get_component_definition(client, component_type) |
| 108 | + results.append(f"Similarity Score: {sim:.3f}\n{definition}\n{'-' * 80}\n") |
| 109 | + |
| 110 | + return "\n".join(results) |
| 111 | + |
| 112 | + |
126 | 113 | async def list_component_families(client: AsyncClientProtocol) -> str: |
127 | 114 | """Lists all Haystack component families that are available on deepset.""" |
128 | 115 | haystack_service = client.haystack_service() |
|
0 commit comments