diff --git a/autogpt_platform/backend/backend/data/credit.py b/autogpt_platform/backend/backend/data/credit.py index 1b0df6afc2fe..c8d29eb62fa2 100644 --- a/autogpt_platform/backend/backend/data/credit.py +++ b/autogpt_platform/backend/backend/data/credit.py @@ -5,7 +5,6 @@ from typing import Any, cast import stripe -from autogpt_libs.utils.cache import thread_cached from prisma import Json from prisma.enums import ( CreditRefundRequestStatus, @@ -32,14 +31,13 @@ TransactionHistory, UserTransaction, ) -from backend.data.notifications import NotificationEventDTO, RefundRequestData +from backend.data.notifications import NotificationEventModel, RefundRequestData from backend.data.user import get_user_by_id, get_user_email_by_id -from backend.notifications import NotificationManagerClient +from backend.notifications.notifications import queue_notification_async from backend.server.model import Pagination from backend.server.v2.admin.model import UserHistoryResponse from backend.util.exceptions import InsufficientBalanceError from backend.util.retry import func_retry -from backend.util.service import get_service_client from backend.util.settings import Settings settings = Settings() @@ -374,20 +372,17 @@ async def _add_transaction( class UserCredit(UserCreditBase): - @thread_cached - def notification_client(self) -> NotificationManagerClient: - return get_service_client(NotificationManagerClient) async def _send_refund_notification( self, notification_request: RefundRequestData, notification_type: NotificationType, ): - await self.notification_client().queue_notification_async( - NotificationEventDTO( + await queue_notification_async( + NotificationEventModel( user_id=notification_request.user_id, type=notification_type, - data=notification_request.model_dump(), + data=notification_request, ) ) diff --git a/autogpt_platform/backend/backend/data/notifications.py b/autogpt_platform/backend/backend/data/notifications.py index c8eb42312099..83c20b153582 100644 --- a/autogpt_platform/backend/backend/data/notifications.py +++ b/autogpt_platform/backend/backend/data/notifications.py @@ -189,26 +189,14 @@ class RefundRequestData(BaseNotificationData): ] -class NotificationEventDTO(BaseModel): - user_id: str +class BaseEventModel(BaseModel): type: NotificationType - data: dict - created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) - retry_count: int = 0 - - -class SummaryParamsEventDTO(BaseModel): user_id: str - type: NotificationType - data: dict created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) -class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]): - user_id: str - type: NotificationType +class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]): data: NotificationDataType_co - created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) @property def strategy(self) -> QueueType: @@ -225,11 +213,8 @@ def template(self) -> str: return NotificationTypeOverride(self.type).template -class SummaryParamsEventModel(BaseModel, Generic[SummaryParamsType_co]): - user_id: str - type: NotificationType +class SummaryParamsEventModel(BaseEventModel, Generic[SummaryParamsType_co]): data: SummaryParamsType_co - created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) def get_notif_data_type( diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index b55518ad23ae..5034aeccbe86 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -23,16 +23,16 @@ from backend.data.notifications import ( AgentRunData, LowBalanceData, - NotificationEventDTO, + NotificationEventModel, NotificationType, ) from backend.data.rabbitmq import SyncRabbitMQ from backend.executor.utils import create_execution_queue_config +from backend.notifications.notifications import queue_notification from backend.util.exceptions import InsufficientBalanceError if TYPE_CHECKING: from backend.executor import DatabaseManagerClient - from backend.notifications.notifications import NotificationManagerClient from autogpt_libs.utils.cache import thread_cached from prometheus_client import Gauge, start_http_server @@ -580,7 +580,6 @@ def on_graph_executor_start(cls): cls.db_client = get_db_client() cls.pool_size = settings.config.num_node_workers cls.pid = os.getpid() - cls.notification_service = get_notification_service() cls._init_node_executor_pool() logger.info(f"GraphExec {cls.pid} started with {cls.pool_size} node workers") @@ -905,21 +904,21 @@ def _handle_agent_run_notif( for output in outputs ] - event = NotificationEventDTO( - user_id=graph_exec.user_id, - type=NotificationType.AGENT_RUN, - data=AgentRunData( - outputs=named_outputs, - agent_name=metadata.name if metadata else "Unknown Agent", - credits_used=exec_stats.cost, - execution_time=exec_stats.walltime, - graph_id=graph_exec.graph_id, - node_count=exec_stats.node_count, - ).model_dump(), + queue_notification( + NotificationEventModel( + user_id=graph_exec.user_id, + type=NotificationType.AGENT_RUN, + data=AgentRunData( + outputs=named_outputs, + agent_name=metadata.name if metadata else "Unknown Agent", + credits_used=exec_stats.cost, + execution_time=exec_stats.walltime, + graph_id=graph_exec.graph_id, + node_count=exec_stats.node_count, + ), + ) ) - cls.notification_service.queue_notification(event) - @classmethod def _handle_low_balance_notif( cls, @@ -933,8 +932,8 @@ def _handle_low_balance_notif( base_url = ( settings.config.frontend_base_url or settings.config.platform_base_url ) - cls.notification_service.queue_notification( - NotificationEventDTO( + queue_notification( + NotificationEventModel( user_id=user_id, type=NotificationType.LOW_BALANCE, data=LowBalanceData( @@ -942,7 +941,7 @@ def _handle_low_balance_notif( billing_page_link=f"{base_url}/profile/credits", shortfall=shortfall, agent_name=metadata.name if metadata else "Unknown Agent", - ).model_dump(), + ), ) ) @@ -1139,14 +1138,6 @@ def get_db_client() -> "DatabaseManagerClient": return get_service_client(DatabaseManagerClient, health_check=False) -@thread_cached -def get_notification_service() -> "NotificationManagerClient": - from backend.notifications import NotificationManagerClient - - # Disable health check for the service client to avoid breaking process initializer. - return get_service_client(NotificationManagerClient, health_check=False) - - def send_execution_update(entry: GraphExecution | NodeExecutionResult | None): if entry is None: return diff --git a/autogpt_platform/backend/backend/notifications/notifications.py b/autogpt_platform/backend/backend/notifications/notifications.py index 0b9d20d45904..a697269b2d67 100644 --- a/autogpt_platform/backend/backend/notifications/notifications.py +++ b/autogpt_platform/backend/backend/notifications/notifications.py @@ -1,5 +1,6 @@ import logging import time +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone from typing import Callable @@ -7,20 +8,18 @@ from aio_pika.exceptions import QueueEmpty from autogpt_libs.utils.cache import thread_cached from prisma.enums import NotificationType -from pydantic import BaseModel from backend.data import rabbitmq from backend.data.notifications import ( + BaseEventModel, BaseSummaryData, BaseSummaryParams, DailySummaryData, DailySummaryParams, - NotificationEventDTO, NotificationEventModel, NotificationResult, NotificationTypeOverride, QueueType, - SummaryParamsEventDTO, SummaryParamsEventModel, WeeklySummaryData, WeeklySummaryParams, @@ -28,13 +27,19 @@ get_notif_data_type, get_summary_params_type, ) -from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig +from backend.data.rabbitmq import ( + AsyncRabbitMQ, + Exchange, + ExchangeType, + Queue, + RabbitMQConfig, + SyncRabbitMQ, +) from backend.data.user import generate_unsubscribe_link from backend.notifications.email import EmailSender from backend.util.service import ( AppService, AppServiceClient, - endpoint_to_async, expose, get_service_client, ) @@ -44,70 +49,66 @@ settings = Settings() -class NotificationEvent(BaseModel): - event: NotificationEventDTO - model: NotificationEventModel +NOTIFICATION_EXCHANGE = Exchange(name="notifications", type=ExchangeType.TOPIC) +DEAD_LETTER_EXCHANGE = Exchange(name="dead_letter", type=ExchangeType.TOPIC) +EXCHANGES = [NOTIFICATION_EXCHANGE, DEAD_LETTER_EXCHANGE] + +background_executor = ThreadPoolExecutor(max_workers=2) def create_notification_config() -> RabbitMQConfig: """Create RabbitMQ configuration for notifications""" - notification_exchange = Exchange(name="notifications", type=ExchangeType.TOPIC) - - dead_letter_exchange = Exchange(name="dead_letter", type=ExchangeType.TOPIC) queues = [ # Main notification queues Queue( name="immediate_notifications", - exchange=notification_exchange, + exchange=NOTIFICATION_EXCHANGE, routing_key="notification.immediate.#", arguments={ - "x-dead-letter-exchange": dead_letter_exchange.name, + "x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name, "x-dead-letter-routing-key": "failed.immediate", }, ), Queue( name="admin_notifications", - exchange=notification_exchange, + exchange=NOTIFICATION_EXCHANGE, routing_key="notification.admin.#", arguments={ - "x-dead-letter-exchange": dead_letter_exchange.name, + "x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name, "x-dead-letter-routing-key": "failed.admin", }, ), # Summary notification queues Queue( name="summary_notifications", - exchange=notification_exchange, + exchange=NOTIFICATION_EXCHANGE, routing_key="notification.summary.#", arguments={ - "x-dead-letter-exchange": dead_letter_exchange.name, + "x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name, "x-dead-letter-routing-key": "failed.summary", }, ), # Batch Queue Queue( name="batch_notifications", - exchange=notification_exchange, + exchange=NOTIFICATION_EXCHANGE, routing_key="notification.batch.#", arguments={ - "x-dead-letter-exchange": dead_letter_exchange.name, + "x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name, "x-dead-letter-routing-key": "failed.batch", }, ), # Failed notifications queue Queue( name="failed_notifications", - exchange=dead_letter_exchange, + exchange=DEAD_LETTER_EXCHANGE, routing_key="failed.#", ), ] return RabbitMQConfig( - exchanges=[ - notification_exchange, - dead_letter_exchange, - ], + exchanges=EXCHANGES, queues=queues, ) @@ -119,6 +120,86 @@ def get_db(): return get_service_client(DatabaseManagerClient) +@thread_cached +def get_notification_queue() -> SyncRabbitMQ: + client = SyncRabbitMQ(create_notification_config()) + client.connect() + return client + + +@thread_cached +async def get_async_notification_queue() -> AsyncRabbitMQ: + client = AsyncRabbitMQ(create_notification_config()) + await client.connect() + return client + + +def get_routing_key(event_type: NotificationType) -> str: + strategy = NotificationTypeOverride(event_type).strategy + """Get the appropriate routing key for an event""" + if strategy == QueueType.IMMEDIATE: + return f"notification.immediate.{event_type.value}" + elif strategy == QueueType.BACKOFF: + return f"notification.backoff.{event_type.value}" + elif strategy == QueueType.ADMIN: + return f"notification.admin.{event_type.value}" + elif strategy == QueueType.BATCH: + return f"notification.batch.{event_type.value}" + elif strategy == QueueType.SUMMARY: + return f"notification.summary.{event_type.value}" + return f"notification.{event_type.value}" + + +def queue_notification(event: NotificationEventModel) -> NotificationResult: + """Queue a notification - exposed method for other services to call""" + try: + logger.debug(f"Received Request to queue {event=}") + + exchange = "notifications" + routing_key = get_routing_key(event.type) + + queue = get_notification_queue() + queue.publish_message( + routing_key=routing_key, + message=event.model_dump_json(), + exchange=next(ex for ex in EXCHANGES if ex.name == exchange), + ) + + return NotificationResult( + success=True, + message=f"Notification queued with routing key: {routing_key}", + ) + + except Exception as e: + logger.exception(f"Error queueing notification: {e}") + return NotificationResult(success=False, message=str(e)) + + +async def queue_notification_async(event: NotificationEventModel) -> NotificationResult: + """Queue a notification - exposed method for other services to call""" + try: + logger.debug(f"Received Request to queue {event=}") + + exchange = "notifications" + routing_key = get_routing_key(event.type) + + queue = await get_async_notification_queue() + await queue.publish_message( + routing_key=routing_key, + message=event.model_dump_json(), + exchange=next(ex for ex in EXCHANGES if ex.name == exchange), + ) + + return NotificationResult( + success=True, + message=f"Notification queued with routing key: {routing_key}", + ) + + except Exception as e: + logger.exception(f"Error queueing notification: {e}") + return NotificationResult(success=False, message=str(e)) + + class NotificationManager(AppService): """Service for handling notifications with batching support""" @@ -146,23 +227,11 @@ def rabbit_config(self) -> rabbitmq.RabbitMQConfig: def get_port(cls) -> int: return settings.config.notification_service_port - def get_routing_key(self, event_type: NotificationType) -> str: - strategy = NotificationTypeOverride(event_type).strategy - """Get the appropriate routing key for an event""" - if strategy == QueueType.IMMEDIATE: - return f"notification.immediate.{event_type.value}" - elif strategy == QueueType.BACKOFF: - return f"notification.backoff.{event_type.value}" - elif strategy == QueueType.ADMIN: - return f"notification.admin.{event_type.value}" - elif strategy == QueueType.BATCH: - return f"notification.batch.{event_type.value}" - elif strategy == QueueType.SUMMARY: - return f"notification.summary.{event_type.value}" - return f"notification.{event_type.value}" - @expose def queue_weekly_summary(self): + background_executor.submit(self._queue_weekly_summary) + + def _queue_weekly_summary(self): """Process weekly summary for specified notification types""" try: logger.info("Processing weekly summary queuing operation") @@ -176,13 +245,13 @@ def queue_weekly_summary(self): for user in users: self._queue_scheduled_notification( - SummaryParamsEventDTO( + SummaryParamsEventModel( user_id=user, type=NotificationType.WEEKLY_SUMMARY, data=WeeklySummaryParams( start_date=start_time, end_date=current_time, - ).model_dump(), + ), ), ) processed_count += 1 @@ -194,6 +263,9 @@ def queue_weekly_summary(self): @expose def process_existing_batches(self, notification_types: list[NotificationType]): + background_executor.submit(self._process_existing_batches, notification_types) + + def _process_existing_batches(self, notification_types: list[NotificationType]): """Process existing batches for specified notification types""" try: processed_count = 0 @@ -312,66 +384,20 @@ def process_existing_batches(self, notification_types: list[NotificationType]): "timestamp": datetime.now(tz=timezone.utc).isoformat(), } - @expose - def queue_notification(self, event: NotificationEventDTO) -> NotificationResult: - """Queue a notification - exposed method for other services to call""" - try: - logger.info(f"Received Request to queue {event=}") - # Workaround for not being able to serialize generics over the expose bus - parsed_event = NotificationEventModel[ - get_notif_data_type(event.type) - ].model_validate(event.model_dump()) - routing_key = self.get_routing_key(parsed_event.type) - message = parsed_event.model_dump_json() - - logger.info(f"Received Request to queue {message=}") - - exchange = "notifications" - - # Publish to RabbitMQ - self.run_and_wait( - self.rabbit.publish_message( - routing_key=routing_key, - message=message, - exchange=next( - ex for ex in self.rabbit_config.exchanges if ex.name == exchange - ), - ) - ) - - return NotificationResult( - success=True, - message=f"Notification queued with routing key: {routing_key}", - ) - - except Exception as e: - logger.exception(f"Error queueing notification: {e}") - return NotificationResult(success=False, message=str(e)) - - def _queue_scheduled_notification(self, event: SummaryParamsEventDTO): + def _queue_scheduled_notification(self, event: SummaryParamsEventModel): """Queue a scheduled notification - exposed method for other services to call""" try: - logger.info(f"Received Request to queue scheduled notification {event=}") - - parsed_event = SummaryParamsEventModel[ - get_summary_params_type(event.type) - ].model_validate(event.model_dump()) - - routing_key = self.get_routing_key(event.type) - message = parsed_event.model_dump_json() - - logger.info(f"Received Request to queue {message=}") + logger.debug(f"Received Request to queue scheduled notification {event=}") exchange = "notifications" + routing_key = get_routing_key(event.type) # Publish to RabbitMQ self.run_and_wait( self.rabbit.publish_message( routing_key=routing_key, - message=message, - exchange=next( - ex for ex in self.rabbit_config.exchanges if ex.name == exchange - ), + message=event.model_dump_json(), + exchange=next(ex for ex in EXCHANGES if ex.name == exchange), ) ) @@ -497,13 +523,12 @@ def _should_batch( ) return False - def _parse_message(self, message: str) -> NotificationEvent | None: + def _parse_message(self, message: str) -> NotificationEventModel | None: try: - event = NotificationEventDTO.model_validate_json(message) - model = NotificationEventModel[ + event = BaseEventModel.model_validate_json(message) + return NotificationEventModel[ get_notif_data_type(event.type) ].model_validate_json(message) - return NotificationEvent(event=event, model=model) except Exception as e: logger.error(f"Error parsing message due to non matching schema {e}") return None @@ -511,14 +536,12 @@ def _parse_message(self, message: str) -> NotificationEvent | None: def _process_admin_message(self, message: str) -> bool: """Process a single notification, sending to an admin, returning whether to put into the failed queue""" try: - parsed = self._parse_message(message) - if not parsed: + event = self._parse_message(message) + if not event: return False - event = parsed.event - model = parsed.model - logger.debug(f"Processing notification for admin: {model}") + logger.debug(f"Processing notification for admin: {event}") recipient_email = settings.config.refund_notification_email - self.email_sender.send_templated(event.type, recipient_email, model) + self.email_sender.send_templated(event.type, recipient_email, event) return True except Exception as e: logger.exception(f"Error processing notification for admin queue: {e}") @@ -527,12 +550,10 @@ def _process_admin_message(self, message: str) -> bool: def _process_immediate(self, message: str) -> bool: """Process a single notification immediately, returning whether to put into the failed queue""" try: - parsed = self._parse_message(message) - if not parsed: + event = self._parse_message(message) + if not event: return False - event = parsed.event - model = parsed.model - logger.debug(f"Processing immediate notification: {model}") + logger.debug(f"Processing immediate notification: {event}") recipient_email = get_db().get_user_email_by_id(event.user_id) if not recipient_email: @@ -553,7 +574,7 @@ def _process_immediate(self, message: str) -> bool: self.email_sender.send_templated( notification=event.type, user_email=recipient_email, - data=model, + data=event, user_unsub_link=unsub_link, ) return True @@ -564,12 +585,10 @@ def _process_immediate(self, message: str) -> bool: def _process_batch(self, message: str) -> bool: """Process a single notification with a batching strategy, returning whether to put into the failed queue""" try: - parsed = self._parse_message(message) - if not parsed: + event = self._parse_message(message) + if not event: return False - event = parsed.event - model = parsed.model - logger.info(f"Processing batch notification: {model}") + logger.info(f"Processing batch notification: {event}") recipient_email = get_db().get_user_email_by_id(event.user_id) if not recipient_email: @@ -585,7 +604,7 @@ def _process_batch(self, message: str) -> bool: ) return True - should_send = self._should_batch(event.user_id, event.type, model) + should_send = self._should_batch(event.user_id, event.type, event) if not should_send: logger.info("Batch not old enough to send") @@ -627,7 +646,7 @@ def _process_summary(self, message: str) -> bool: """Process a single notification with a summary strategy, returning whether to put into the failed queue""" try: logger.info(f"Processing summary notification: {message}") - event = SummaryParamsEventDTO.model_validate_json(message) + event = BaseEventModel.model_validate_json(message) model = SummaryParamsEventModel[ get_summary_params_type(event.type) ].model_validate_json(message) @@ -764,7 +783,5 @@ class NotificationManagerClient(AppServiceClient): def get_service_type(cls): return NotificationManager - queue_notification_async = endpoint_to_async(NotificationManager.queue_notification) - queue_notification = NotificationManager.queue_notification process_existing_batches = NotificationManager.process_existing_batches queue_weekly_summary = NotificationManager.queue_weekly_summary