Skip to content

refactor(backend): Clear out Notification Service code blockage #9915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions autogpt_platform/backend/backend/data/credit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any, cast

import stripe
from autogpt_libs.utils.cache import thread_cached
from prisma import Json
from prisma.enums import (
CreditRefundRequestStatus,
Expand All @@ -32,14 +31,13 @@
TransactionHistory,
UserTransaction,
)
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications import NotificationManagerClient
from backend.notifications.notifications import queue_notification_async
from backend.server.model import Pagination
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.retry import func_retry
from backend.util.service import get_service_client
from backend.util.settings import Settings

settings = Settings()
Expand Down Expand Up @@ -374,20 +372,17 @@ async def _add_transaction(


class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManagerClient:
return get_service_client(NotificationManagerClient)

async def _send_refund_notification(
self,
notification_request: RefundRequestData,
notification_type: NotificationType,
):
await self.notification_client().queue_notification_async(
NotificationEventDTO(
await queue_notification_async(
NotificationEventModel(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
data=notification_request,
)
)

Expand Down
21 changes: 3 additions & 18 deletions autogpt_platform/backend/backend/data/notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,26 +189,14 @@ class RefundRequestData(BaseNotificationData):
]


class NotificationEventDTO(BaseModel):
user_id: str
class BaseEventModel(BaseModel):
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
retry_count: int = 0


class SummaryParamsEventDTO(BaseModel):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))


class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
user_id: str
type: NotificationType
class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]):
data: NotificationDataType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))

@property
def strategy(self) -> QueueType:
Expand All @@ -225,11 +213,8 @@ def template(self) -> str:
return NotificationTypeOverride(self.type).template


class SummaryParamsEventModel(BaseModel, Generic[SummaryParamsType_co]):
user_id: str
type: NotificationType
class SummaryParamsEventModel(BaseEventModel, Generic[SummaryParamsType_co]):
data: SummaryParamsType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))


def get_notif_data_type(
Expand Down
45 changes: 18 additions & 27 deletions autogpt_platform/backend/backend/executor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
NotificationEventDTO,
NotificationEventModel,
NotificationType,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.utils import create_execution_queue_config
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError

if TYPE_CHECKING:
from backend.executor import DatabaseManagerClient
from backend.notifications.notifications import NotificationManagerClient

from autogpt_libs.utils.cache import thread_cached
from prometheus_client import Gauge, start_http_server
Expand Down Expand Up @@ -580,7 +580,6 @@ def on_graph_executor_start(cls):
cls.db_client = get_db_client()
cls.pool_size = settings.config.num_node_workers
cls.pid = os.getpid()
cls.notification_service = get_notification_service()
cls._init_node_executor_pool()
logger.info(f"GraphExec {cls.pid} started with {cls.pool_size} node workers")

Expand Down Expand Up @@ -905,21 +904,21 @@ def _handle_agent_run_notif(
for output in outputs
]

event = NotificationEventDTO(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
).model_dump(),
queue_notification(
NotificationEventModel(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
),
)
)

cls.notification_service.queue_notification(event)

@classmethod
def _handle_low_balance_notif(
cls,
Expand All @@ -933,16 +932,16 @@ def _handle_low_balance_notif(
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
cls.notification_service.queue_notification(
NotificationEventDTO(
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
current_balance=exec_stats.cost,
billing_page_link=f"{base_url}/profile/credits",
shortfall=shortfall,
agent_name=metadata.name if metadata else "Unknown Agent",
).model_dump(),
),
)
)

Expand Down Expand Up @@ -1139,14 +1138,6 @@ def get_db_client() -> "DatabaseManagerClient":
return get_service_client(DatabaseManagerClient, health_check=False)


@thread_cached
def get_notification_service() -> "NotificationManagerClient":
from backend.notifications import NotificationManagerClient

# Disable health check for the service client to avoid breaking process initializer.
return get_service_client(NotificationManagerClient, health_check=False)


def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
if entry is None:
return
Expand Down
Loading