diff --git a/src/crewai/crew.py b/src/crewai/crew.py index a0158f646f..6ee63d91f1 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -16,6 +16,16 @@ Tuple, Union, cast, + TypedDict, +) + +from crewai.types.crew_types import ( + CrewConfig, + MemoryConfig, + AgentProtocol, + TaskProtocol, + validate_agent, + validate_task, ) from pydantic import ( @@ -257,17 +267,35 @@ def _deny_user_set_id(cls, v: Optional[UUID4]) -> None: @field_validator("config", mode="before") @classmethod def check_config_type( - cls, v: Union[Json, Dict[str, Any]] - ) -> Union[Json, Dict[str, Any]]: + cls, + v: Union[Json, Dict[str, Any], str] + ) -> Dict[str, Any]: """Validates that the config is a valid type. + Args: - v: The config to be validated. + v: The configuration to be validated. Can be a JSON string, + a JSON object, or a dictionary. + Returns: - The config if it is valid. + A validated configuration dictionary. + + Raises: + json.JSONDecodeError: If the input is an invalid JSON string. + TypeError: If the input is not a supported configuration type. """ - - # TODO: Improve typing - return json.loads(v) if isinstance(v, Json) else v # type: ignore + if isinstance(v, str): + try: + return json.loads(v) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON configuration: {e}") from e + + if isinstance(v, dict): + return v + + if isinstance(v, Json): + return json.loads(v) + + raise TypeError(f"Unsupported configuration type: {type(v)}") @model_validator(mode="after") def set_private_attrs(self) -> "Crew": diff --git a/src/crewai/types/crew_types.py b/src/crewai/types/crew_types.py new file mode 100644 index 0000000000..a4713f222b --- /dev/null +++ b/src/crewai/types/crew_types.py @@ -0,0 +1,61 @@ +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Protocol, + TypedDict, + Union, +) + +from crewai.agent import Agent +from crewai.task import Task +from crewai.tools.base_tool import BaseTool + + +class MemoryConfig(TypedDict, total=False): + """Typed dictionary for memory configuration.""" + user_memory: Dict[str, Any] + provider: str + long_term: bool + short_term: bool + + +class CrewConfig(TypedDict, total=False): + """Typed dictionary for crew configuration.""" + tasks: List[Task] + agents: List[Agent] + memory: bool + memory_config: MemoryConfig + max_rpm: int + verbose: bool + + +class AgentProtocol(Protocol): + """Protocol for defining agent-like objects.""" + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... + role: str + goal: str + backstory: str + tools: List[BaseTool] + + +class TaskProtocol(Protocol): + """Protocol for defining task-like objects.""" + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... + description: str + expected_output: str + agent: AgentProtocol + + +def validate_agent(agent: Any) -> bool: + """Validate if an object conforms to the AgentProtocol.""" + required_attrs = ['role', 'goal', 'backstory', 'tools'] + return all(hasattr(agent, attr) for attr in required_attrs) + + +def validate_task(task: Any) -> bool: + """Validate if an object conforms to the TaskProtocol.""" + required_attrs = ['description', 'expected_output', 'agent'] + return all(hasattr(task, attr) for attr in required_attrs)