Skip to content

Enhance Type Safety and Type Hinting in CrewAI Core Components #2830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions src/crewai/crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
Tuple,
Union,
cast,
TypedDict,
)

from crewai.types.crew_types import (
CrewConfig,
MemoryConfig,
AgentProtocol,
TaskProtocol,
validate_agent,
validate_task,
)

from pydantic import (
Expand Down Expand Up @@ -257,17 +267,35 @@ def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
@field_validator("config", mode="before")
@classmethod
def check_config_type(
cls, v: Union[Json, Dict[str, Any]]
) -> Union[Json, Dict[str, Any]]:
cls,
v: Union[Json, Dict[str, Any], str]
) -> Dict[str, Any]:
"""Validates that the config is a valid type.

Args:
v: The config to be validated.
v: The configuration to be validated. Can be a JSON string,
a JSON object, or a dictionary.

Returns:
The config if it is valid.
A validated configuration dictionary.

Raises:
json.JSONDecodeError: If the input is an invalid JSON string.
TypeError: If the input is not a supported configuration type.
"""

# TODO: Improve typing
return json.loads(v) if isinstance(v, Json) else v # type: ignore
if isinstance(v, str):
try:
return json.loads(v)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON configuration: {e}") from e

if isinstance(v, dict):
return v

if isinstance(v, Json):
return json.loads(v)

raise TypeError(f"Unsupported configuration type: {type(v)}")

@model_validator(mode="after")
def set_private_attrs(self) -> "Crew":
Expand Down
61 changes: 61 additions & 0 deletions src/crewai/types/crew_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't quite get why you added it or where it’s being used. Did you maybe forget to include some code in this PR?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes... Let me fix this!

Any,
Callable,
Dict,
List,
Optional,
Protocol,
TypedDict,
Union,
)

from crewai.agent import Agent
from crewai.task import Task
from crewai.tools.base_tool import BaseTool


class MemoryConfig(TypedDict, total=False):
"""Typed dictionary for memory configuration."""
user_memory: Dict[str, Any]
provider: str
long_term: bool
short_term: bool


class CrewConfig(TypedDict, total=False):
"""Typed dictionary for crew configuration."""
tasks: List[Task]
agents: List[Agent]
memory: bool
memory_config: MemoryConfig
max_rpm: int
verbose: bool


class AgentProtocol(Protocol):
"""Protocol for defining agent-like objects."""
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
role: str
goal: str
backstory: str
tools: List[BaseTool]


class TaskProtocol(Protocol):
"""Protocol for defining task-like objects."""
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
description: str
expected_output: str
agent: AgentProtocol


def validate_agent(agent: Any) -> bool:
"""Validate if an object conforms to the AgentProtocol."""
required_attrs = ['role', 'goal', 'backstory', 'tools']
return all(hasattr(agent, attr) for attr in required_attrs)


def validate_task(task: Any) -> bool:
"""Validate if an object conforms to the TaskProtocol."""
required_attrs = ['description', 'expected_output', 'agent']
return all(hasattr(task, attr) for attr in required_attrs)
Loading