diff --git a/src/instana/__init__.py b/src/instana/__init__.py index 00de8627..c3849aad 100644 --- a/src/instana/__init__.py +++ b/src/instana/__init__.py @@ -189,6 +189,7 @@ def boot_agent() -> None: starlette, # noqa: F401 urllib3, # noqa: F401 spyne, # noqa: F401 + aio_pika, # noqa: F401 ) from instana.instrumentation.aiohttp import ( client as aiohttp_client, # noqa: F401 diff --git a/src/instana/instrumentation/aio_pika.py b/src/instana/instrumentation/aio_pika.py new file mode 100644 index 00000000..52dfb054 --- /dev/null +++ b/src/instana/instrumentation/aio_pika.py @@ -0,0 +1,117 @@ +# (c) Copyright IBM Corp. 2025 + +try: + import aio_pika + import wrapt + from typing import ( + TYPE_CHECKING, + Dict, + Any, + Callable, + Tuple, + Type, + Optional, + ) + + from instana.log import logger + from instana.propagators.format import Format + from instana.util.traceutils import get_tracer_tuple, tracing_is_off + from instana.singletons import tracer + + if TYPE_CHECKING: + from instana.span.span import InstanaSpan + from aio_pika.exchange import Exchange + from aiormq.abc import ConfirmationFrameType + from aio_pika.abc import ConsumerTag, AbstractMessage + from aio_pika.queue import Queue, QueueIterator + + def _extract_span_attributes( + span: "InstanaSpan", connection, sort: str, routing_key: str, exchange: str + ) -> None: + span.set_attribute("address", str(connection.url)) + + span.set_attribute("sort", sort) + span.set_attribute("key", routing_key) + span.set_attribute("exchange", exchange) + + @wrapt.patch_function_wrapper("aio_pika", "Exchange.publish") + async def publish_with_instana( + wrapped: Callable[..., Optional["ConfirmationFrameType"]], + instance: "Exchange", + args: Tuple[object], + kwargs: Dict[str, Any], + ) -> Optional["ConfirmationFrameType"]: + if tracing_is_off(): + return await wrapped(*args, **kwargs) + + tracer, parent_span, _ = get_tracer_tuple() + parent_context = parent_span.get_span_context() if parent_span else None + + with tracer.start_as_current_span( + "rabbitmq", span_context=parent_context + ) as span: + connection = instance.channel._connection + _extract_span_attributes( + span, connection, "publish", kwargs["routing_key"], instance.name + ) + + message = args[0] + tracer.inject( + span.context, + Format.HTTP_HEADERS, + message.properties.headers, + disable_w3c_trace_context=True, + ) + try: + response = await wrapped(*args, **kwargs) + except Exception as exc: + span.record_exception(exc) + else: + return response + + @wrapt.patch_function_wrapper("aio_pika", "Queue.consume") + async def consume_with_instana( + wrapped: Callable[..., "ConsumerTag"], + instance: Type["Queue"], + args: Tuple[object], + kwargs: Dict[str, Any], + ) -> "ConsumerTag": + connection = instance.channel._connection + callback = kwargs["callback"] if kwargs.get("callback") else args[0] + + @wrapt.decorator + async def callback_wrapper( + wrapped: Callable[[Type["AbstractMessage"]], Any], + instance: Type["QueueIterator"], + args: Tuple[Type["AbstractMessage"], ...], + kwargs: Dict[str, Any], + ) -> Callable[[Type["AbstractMessage"]], Any]: + message = args[0] + parent_context = tracer.extract( + Format.HTTP_HEADERS, message.headers, disable_w3c_trace_context=True + ) + with tracer.start_as_current_span( + "rabbitmq", span_context=parent_context + ) as span: + _extract_span_attributes( + span, connection, "consume", message.routing_key, message.exchange + ) + try: + response = await wrapped(*args, **kwargs) + except Exception as exc: + span.record_exception(exc) + else: + return response + + wrapped_callback = callback_wrapper(callback) + if kwargs.get("callback"): + kwargs["callback"] = wrapped_callback + else: + args = (wrapped_callback,) + args[1:] + + return await wrapped(*args, **kwargs) + + logger.debug("Instrumenting aio-pika") + +except ImportError: + pass diff --git a/tests/clients/test_aio_pika.py b/tests/clients/test_aio_pika.py new file mode 100644 index 00000000..ee49b545 --- /dev/null +++ b/tests/clients/test_aio_pika.py @@ -0,0 +1,171 @@ +# (c) Copyright IBM Corp. 2025 + +import pytest +from typing import Generator, TYPE_CHECKING +import asyncio +from aio_pika import Message, connect, connect_robust + +from instana.singletons import agent, tracer + +if TYPE_CHECKING: + from instana.span.readable_span import ReadableSpan + + +class TestAioPika: + @pytest.fixture(autouse=True) + def _resource(self) -> Generator[None, None, None]: + """SetUp and TearDown""" + # setup + self.recorder = tracer.span_processor + self.recorder.clear_spans() + + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + self.queue_name = "test.queue" + yield + # teardown + self.loop.run_until_complete(self.delete_queue()) + if self.loop.is_running(): + self.loop.close() + # Ensure that allow_exit_as_root has the default value + agent.options.allow_exit_as_root = False + + async def publish_message(self) -> None: + # Perform connection + connection = await connect() + + async with connection: + # Creating a channel + channel = await connection.channel() + + # Declaring queue + queue_name = self.queue_name + queue = await channel.declare_queue(queue_name) + + # Declaring exchange + exchange = await channel.declare_exchange("test.exchange") + await queue.bind(exchange, routing_key=queue_name) + + # Sending the message + await exchange.publish( + Message(f"Hello {queue_name}".encode()), + routing_key=queue_name, + ) + + async def delete_queue(self) -> None: + connection = await connect() + + async with connection: + channel = await connection.channel() + await channel.queue_delete(self.queue_name) + + async def consume_message(self, connect_method) -> None: + connection = await connect_method() + + async with connection: + # Creating channel + channel = await connection.channel() + + # Declaring queue + queue = await channel.declare_queue(self.queue_name) + + async with queue.iterator() as queue_iter: + async for message in queue_iter: + async with message.process(): + if queue.name in message.body.decode(): + break + + def test_basic_publish(self) -> None: + with tracer.start_as_current_span("test"): + self.loop.run_until_complete(self.publish_message()) + + spans = self.recorder.queued_spans() + assert len(spans) == 2 + + rabbitmq_span = spans[0] + test_span = spans[1] + + # Same traceId + assert test_span.t == rabbitmq_span.t + + # Parent relationships + assert rabbitmq_span.p == test_span.s + + # Error logging + assert not test_span.ec + assert not rabbitmq_span.ec + + # Span attributes + assert rabbitmq_span.data["rabbitmq"]["exchange"] == "test.exchange" + assert rabbitmq_span.data["rabbitmq"]["sort"] == "publish" + assert rabbitmq_span.data["rabbitmq"]["address"] + assert rabbitmq_span.data["rabbitmq"]["key"] == "test.queue" + assert rabbitmq_span.stack + assert isinstance(rabbitmq_span.stack, list) + assert len(rabbitmq_span.stack) > 0 + + def test_basic_publish_as_root_exit_span(self) -> None: + agent.options.allow_exit_as_root = True + self.loop.run_until_complete(self.publish_message()) + + spans = self.recorder.queued_spans() + assert len(spans) == 1 + + rabbitmq_span = spans[0] + + # Parent relationships + assert not rabbitmq_span.p + + # Error logging + assert not rabbitmq_span.ec + + # Span attributes + assert rabbitmq_span.data["rabbitmq"]["exchange"] == "test.exchange" + assert rabbitmq_span.data["rabbitmq"]["sort"] == "publish" + assert rabbitmq_span.data["rabbitmq"]["address"] + assert rabbitmq_span.data["rabbitmq"]["key"] == "test.queue" + assert rabbitmq_span.stack + assert isinstance(rabbitmq_span.stack, list) + assert len(rabbitmq_span.stack) > 0 + + @pytest.mark.parametrize( + "connect_method", + [connect, connect_robust], + ) + def test_basic_consume(self, connect_method) -> None: + with tracer.start_as_current_span("test"): + self.loop.run_until_complete(self.publish_message()) + self.loop.run_until_complete(self.consume_message(connect_method)) + + spans = self.recorder.queued_spans() + assert len(spans) == 3 + + rabbitmq_publisher_span = spans[0] + rabbitmq_consumer_span = spans[1] + test_span = spans[2] + + # Same traceId + assert test_span.t == rabbitmq_publisher_span.t + assert rabbitmq_publisher_span.t == rabbitmq_consumer_span.t + + # Parent relationships + assert rabbitmq_publisher_span.p == test_span.s + assert rabbitmq_consumer_span.p == rabbitmq_publisher_span.s + + # Error logging + assert not rabbitmq_publisher_span.ec + assert not rabbitmq_consumer_span.ec + assert not test_span.ec + + # Span attributes + def assert_span_info(rabbitmq_span: "ReadableSpan", sort: str) -> None: + assert rabbitmq_span.data["rabbitmq"]["exchange"] == "test.exchange" + assert rabbitmq_span.data["rabbitmq"]["sort"] == sort + assert rabbitmq_span.data["rabbitmq"]["address"] + assert rabbitmq_span.data["rabbitmq"]["key"] == "test.queue" + assert rabbitmq_span.stack + assert isinstance(rabbitmq_span.stack, list) + assert len(rabbitmq_span.stack) > 0 + + assert_span_info(rabbitmq_publisher_span, "publish") + assert_span_info(rabbitmq_consumer_span, "consume") diff --git a/tests/requirements-pre314.txt b/tests/requirements-pre314.txt index 57ecaa2d..57a1cb23 100644 --- a/tests/requirements-pre314.txt +++ b/tests/requirements-pre314.txt @@ -2,6 +2,7 @@ aioamqp>=0.15.0 aiofiles>=0.5.0 aiohttp>=3.8.3 +aio-pika>=9.5.2 boto3>=1.17.74 bottle>=0.12.25 celery>=5.2.7 diff --git a/tests/requirements.txt b/tests/requirements.txt index b4b257d8..ea441fdc 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -2,6 +2,7 @@ aioamqp>=0.15.0 aiofiles>=0.5.0 aiohttp>=3.8.3 +aio-pika>=9.5.2 boto3>=1.17.74 bottle>=0.12.25 celery>=5.2.7