2 changes: 2 additions & 0 deletions docs/docs/SUMMARY.md
@@ -772,6 +772,8 @@ search:
- [RabbitBroker](api/faststream/rabbit/broker/RabbitBroker.md)
- broker
- [RabbitBroker](api/faststream/rabbit/broker/broker/RabbitBroker.md)
- connection
- [ConnectionManager](api/faststream/rabbit/broker/connection/ConnectionManager.md)
- logging
- [RabbitLoggingBroker](api/faststream/rabbit/broker/logging/RabbitLoggingBroker.md)
- registrator
11 changes: 11 additions & 0 deletions docs/docs/api/faststream/rabbit/broker/connection/ConnectionManager.md
@@ -0,0 +1,11 @@
---
# 0.5 - API
# 2 - Release
# 3 - Contributing
# 5 - Template Page
# 10 - Default
search:
boost: 0.5
---

::: faststream.rabbit.broker.connection.ConnectionManager
6 changes: 4 additions & 2 deletions faststream/confluent/broker/broker.py
@@ -126,10 +126,12 @@ def __init__(
] = SERVICE_NAME,
config: Annotated[
Optional[ConfluentConfig],
Doc("""
Doc(
"""
Extra configuration for the confluent-kafka-python
producer/consumer. See `confluent_kafka.Config <https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#kafka-client-configuration>`_.
"""),
"""
),
] = None,
# publisher args
acks: Annotated[
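The confluent hunk above only re-wraps the `Doc(...)` call, but the pattern it touches is worth spelling out. The sketch below is an illustrative, simplified example of documenting a keyword argument with `typing_extensions.Doc` inside `Annotated`; the function and parameter names are invented for the example.

```python
from typing import Optional

from typing_extensions import Annotated, Doc


def build_consumer(
    config: Annotated[
        Optional[dict],
        Doc(
            """
            Extra configuration forwarded to the underlying client.
            """
        ),
    ] = None,
) -> None:
    # Doc() attaches documentation metadata only; it has no runtime effect,
    # which is why the hunk above is purely a formatting change.
    ...
```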
6 changes: 0 additions & 6 deletions faststream/rabbit/annotations.py
@@ -1,4 +1,3 @@
from aio_pika import RobustChannel, RobustConnection
from typing_extensions import Annotated

from faststream.annotations import ContextRepo, Logger, NoCast
@@ -14,17 +13,12 @@
"RabbitMessage",
"RabbitBroker",
"RabbitProducer",
"Channel",
"Connection",
)

RabbitMessage = Annotated[RM, Context("message")]
RabbitBroker = Annotated[RB, Context("broker")]
RabbitProducer = Annotated[AioPikaFastProducer, Context("broker._producer")]

Channel = Annotated[RobustChannel, Context("broker._channel")]
Connection = Annotated[RobustConnection, Context("broker._connection")]

# NOTE: transaction is not for the public usage yet
# async def _get_transaction(connection: Connection) -> RabbitTransaction:
# async with connection.channel(publisher_confirms=False) as channel:
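The `Channel` and `Connection` annotations are removed above because the broker no longer keeps a single shared `broker._channel` / `broker._connection` object to resolve from the context. The remaining annotations still work through `Context`; a minimal, illustrative handler (queue name and payload type are made up) showing how `RabbitMessage` keeps being injected:

```python
from faststream import FastStream
from faststream.rabbit import RabbitBroker
from faststream.rabbit.annotations import RabbitMessage

broker = RabbitBroker("amqp://guest:guest@localhost:5672/")
app = FastStream(broker)


@broker.subscriber("in-queue")
async def handle(body: str, message: RabbitMessage) -> None:
    # RabbitMessage resolves to Context("message") and exposes delivery
    # metadata such as correlation_id and headers alongside the decoded body.
    print(message.correlation_id, body)
```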
106 changes: 41 additions & 65 deletions faststream/rabbit/broker/broker.py
@@ -1,32 +1,19 @@
import logging
from inspect import Parameter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Optional,
Type,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Type, Union, cast
from urllib.parse import urlparse

from aio_pika import connect_robust
from typing_extensions import Annotated, Doc, override

from faststream.__about__ import SERVICE_NAME
from faststream.broker.message import gen_cor_id
from faststream.exceptions import NOT_CONNECTED_YET
from faststream.rabbit.broker.connection import ConnectionManager
from faststream.rabbit.broker.logging import RabbitLoggingBroker
from faststream.rabbit.broker.registrator import RabbitRegistrator
from faststream.rabbit.helpers.declarer import RabbitDeclarer
from faststream.rabbit.publisher.producer import AioPikaFastProducer
from faststream.rabbit.schemas import (
RABBIT_REPLY,
RabbitExchange,
RabbitQueue,
)
from faststream.rabbit.schemas import RabbitExchange, RabbitQueue
from faststream.rabbit.security import parse_security
from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber
from faststream.rabbit.utils import build_url
@@ -37,8 +24,6 @@

from aio_pika import (
IncomingMessage,
RobustChannel,
RobustConnection,
RobustExchange,
RobustQueue,
)
@@ -48,10 +33,7 @@
from yarl import URL

from faststream.asyncapi import schema as asyncapi
from faststream.broker.types import (
BrokerMiddleware,
CustomCallable,
)
from faststream.broker.types import BrokerMiddleware, CustomCallable
from faststream.rabbit.types import AioPikaSendableMessage
from faststream.security import BaseSecurity
from faststream.types import AnyDict, Decorator, LoggerProto
@@ -67,7 +49,6 @@ class RabbitBroker(
_producer: Optional["AioPikaFastProducer"]

declarer: Optional[RabbitDeclarer]
_channel: Optional["RobustChannel"]

def __init__(
self,
@@ -213,6 +194,14 @@ def __init__(
Iterable["Decorator"],
Doc("Any custom decorator to apply to wrapped functions."),
] = (),
max_connection_pool_size: Annotated[
int,
Doc("Max connection pool size"),
] = 1,
max_channel_pool_size: Annotated[
int,
Doc("Max channel pool size"),
] = 1,
) -> None:
security_args = parse_security(security)

@@ -234,6 +223,8 @@ def __init__(
# respect asyncapi_url argument scheme
builded_asyncapi_url = urlparse(asyncapi_url)
self.virtual_host = builded_asyncapi_url.path
self.max_connection_pool_size = max_connection_pool_size
self.max_channel_pool_size = max_channel_pool_size
if protocol is None:
protocol = builded_asyncapi_url.scheme

@@ -273,13 +264,13 @@ def __init__(

self.app_id = app_id

self._channel = None
self.declarer = None

@property
def _subscriber_setup_extra(self) -> "AnyDict":
return {
**super()._subscriber_setup_extra,
"max_consumers": self._max_consumers,
"app_id": self.app_id,
"virtual_host": self.virtual_host,
"declarer": self.declarer,
@@ -350,7 +341,7 @@ async def connect( # type: ignore[override]
"when mandatory message will be returned"
),
] = Parameter.empty,
) -> "RobustConnection":
) -> "ConnectionManager":
"""Connect broker object to RabbitMQ.

To startup subscribers too you should use `broker.start()` after/instead this method.
@@ -405,65 +396,50 @@ async def _connect( # type: ignore[override]
channel_number: Optional[int],
publisher_confirms: bool,
on_return_raises: bool,
) -> "RobustConnection":
connection = cast(
"RobustConnection",
await connect_robust(
url,
timeout=timeout,
ssl_context=ssl_context,
),
)

if self._channel is None: # pragma: no branch
max_consumers = self._max_consumers
channel = self._channel = cast(
"RobustChannel",
await connection.channel(
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
),
) -> "ConnectionManager":
if self._max_consumers:
c = AsyncAPISubscriber.build_log_context(
None,
RabbitQueue(""),
RabbitExchange(""),
)
self._log(f"Set max consumers to {self._max_consumers}", extra=c)

declarer = self.declarer = RabbitDeclarer(channel)
await declarer.declare_queue(RABBIT_REPLY)
connection_manager = ConnectionManager(
url=url,
timeout=timeout,
ssl_context=ssl_context,
connection_pool_size=self.max_connection_pool_size,
channel_pool_size=self.max_channel_pool_size,
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
)

if self.declarer is None:
self.declarer = RabbitDeclarer(connection_manager)

if self._producer is None:
self._producer = AioPikaFastProducer(
declarer=declarer,
declarer=self.declarer,
decoder=self._decoder,
parser=self._parser,
)

if max_consumers:
c = AsyncAPISubscriber.build_log_context(
None,
RabbitQueue(""),
RabbitExchange(""),
)
self._log(f"Set max consumers to {max_consumers}", extra=c)
await channel.set_qos(prefetch_count=int(max_consumers))

return connection
return connection_manager

async def _close(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional["TracebackType"] = None,
) -> None:
if self._channel is not None:
if not self._channel.is_closed:
await self._channel.close()

self._channel = None
if self._connection is not None:
await self._connection.close()

self.declarer = None
self._producer = None

if self._connection is not None:
await self._connection.close()

await super()._close(exc_type, exc_val, exc_tb)

async def start(self) -> None:
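For reference, a hedged sketch of how the `max_connection_pool_size` / `max_channel_pool_size` arguments introduced in this diff would be passed, and of `connect()` now returning the `ConnectionManager` instead of a `RobustConnection`; the URL and pool sizes are placeholder values and the argument names may still change before release.

```python
import asyncio

from faststream.rabbit import RabbitBroker

broker = RabbitBroker(
    "amqp://guest:guest@localhost:5672/",
    max_connection_pool_size=2,   # upper bound on pooled robust connections
    max_channel_pool_size=10,     # upper bound on channels shared by publishers
)


async def main() -> None:
    # connect() now hands back the ConnectionManager, so callers interact
    # with the pools rather than with a single connection object.
    manager = await broker.connect()
    print(type(manager).__name__)  # -> "ConnectionManager"
    await broker.close()


asyncio.run(main())
```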
87 changes: 87 additions & 0 deletions faststream/rabbit/broker/connection.py
@@ -0,0 +1,87 @@
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncIterator, Optional, cast

from aio_pika import connect_robust
from aio_pika.pool import Pool

if TYPE_CHECKING:
from ssl import SSLContext

from aio_pika import (
RobustChannel,
RobustConnection,
)
from aio_pika.abc import TimeoutType


class ConnectionManager:
def __init__(
self,
*,
url: str,
timeout: "TimeoutType",
ssl_context: Optional["SSLContext"],
connection_pool_size: Optional[int],
channel_pool_size: Optional[int],
channel_number: Optional[int],
publisher_confirms: bool,
on_return_raises: bool,
) -> None:
self._connection_pool: "Pool[RobustConnection]" = Pool(
lambda: connect_robust(
url=url,
timeout=timeout,
ssl_context=ssl_context,
),
max_size=connection_pool_size,
)

self._channel_pool: "Pool[RobustChannel]" = Pool(
lambda: self._get_channel(
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
),
max_size=channel_pool_size,
)

async def get_connection(self) -> "RobustConnection":
return await self._connection_pool._get()

@asynccontextmanager
async def acquire_connection(self) -> AsyncIterator["RobustConnection"]:
async with self._connection_pool.acquire() as connection:
yield connection

async def get_channel(self) -> "RobustChannel":
return await self._channel_pool._get()

@asynccontextmanager
async def acquire_channel(self) -> AsyncIterator["RobustChannel"]:
async with self._channel_pool.acquire() as channel:
yield channel

async def _get_channel(
self,
channel_number: Optional[int] = None,
publisher_confirms: bool = True,
on_return_raises: bool = False,
) -> "RobustChannel":
async with self.acquire_connection() as connection:
channel = cast(
"RobustChannel",
await connection.channel(
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
),
)

return channel

async def close(self) -> None:
if not self._channel_pool.is_closed:
await self._channel_pool.close()

if not self._connection_pool.is_closed:
await self._connection_pool.close()
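A standalone usage sketch of the new `ConnectionManager`, mirroring the constructor signature above; all values are illustrative and the queue name is made up.

```python
import asyncio

from faststream.rabbit.broker.connection import ConnectionManager


async def main() -> None:
    manager = ConnectionManager(
        url="amqp://guest:guest@localhost:5672/",
        timeout=10,
        ssl_context=None,
        connection_pool_size=1,
        channel_pool_size=4,
        channel_number=None,
        publisher_confirms=True,
        on_return_raises=False,
    )

    # Channels are borrowed from and returned to the pool, so concurrent
    # publishers reuse a bounded set of channels instead of opening new ones.
    async with manager.acquire_channel() as channel:
        await channel.declare_queue("pool-demo", auto_delete=True)

    await manager.close()


asyncio.run(main())
```

Note that `aio_pika.pool.Pool` creates its items lazily, so neither pool opens a connection or channel until the first `acquire_*` call.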
5 changes: 3 additions & 2 deletions faststream/rabbit/broker/logging.py
@@ -2,16 +2,17 @@
from inspect import Parameter
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union

from aio_pika import IncomingMessage, RobustConnection
from aio_pika import IncomingMessage

from faststream.broker.core.usecase import BrokerUsecase
from faststream.log.logging import get_broker_logger
from faststream.rabbit.broker.connection import ConnectionManager

if TYPE_CHECKING:
from faststream.types import LoggerProto


class RabbitLoggingBroker(BrokerUsecase[IncomingMessage, RobustConnection]):
class RabbitLoggingBroker(BrokerUsecase[IncomingMessage, ConnectionManager]):
"""A class that extends the LoggingMixin class and adds additional functionality for logging RabbitMQ related information."""

_max_queue_len: int
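The only change in logging.py is the second generic parameter of `BrokerUsecase`: it is the type of the object that `_connect()` produces and the broker stores as `self._connection`, which is why `IncomingMessage` stays while `RobustConnection` becomes `ConnectionManager`. A simplified stand-in (not the real FastStream class) showing how that generic flows:

```python
from typing import Generic, Optional, TypeVar

MsgType = TypeVar("MsgType")
ConnectionType = TypeVar("ConnectionType")


class BrokerSketch(Generic[MsgType, ConnectionType]):
    """Toy broker base: the second type parameter is whatever connect() stores."""

    _connection: Optional[ConnectionType] = None

    async def connect(self) -> ConnectionType:
        conn = await self._connect()
        self._connection = conn
        return conn

    async def _connect(self) -> ConnectionType:
        raise NotImplementedError
```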