diff --git a/aws_lambda_powertools/logging/logger.py b/aws_lambda_powertools/logging/logger.py index b0371b8a47d..89e1fb01da6 100644 --- a/aws_lambda_powertools/logging/logger.py +++ b/aws_lambda_powertools/logging/logger.py @@ -38,6 +38,7 @@ from aws_lambda_powertools.shared.functions import ( extract_event_from_common_models, get_tracer_id, + is_durable_context, resolve_env_var_choice, resolve_truthy_env_var_choice, ) @@ -520,13 +521,18 @@ def handler(event, context): @functools.wraps(lambda_handler) def decorate(event, context, *args, **kwargs): - lambda_context = build_lambda_context_model(context) + unwrapped_context = ( + build_lambda_context_model(context.lambda_context) + if is_durable_context(context) + else build_lambda_context_model(context) + ) + cold_start = _is_cold_start() if clear_state: - self.structure_logs(cold_start=cold_start, **lambda_context.__dict__) + self.structure_logs(cold_start=cold_start, **unwrapped_context.__dict__) else: - self.append_keys(cold_start=cold_start, **lambda_context.__dict__) + self.append_keys(cold_start=cold_start, **unwrapped_context.__dict__) if correlation_id_path: self.set_correlation_id( diff --git a/aws_lambda_powertools/metrics/base.py b/aws_lambda_powertools/metrics/base.py index ee7553148b1..2e5ea59df20 100644 --- a/aws_lambda_powertools/metrics/base.py +++ b/aws_lambda_powertools/metrics/base.py @@ -37,7 +37,7 @@ reset_cold_start_flag, # noqa: F401 # backwards compatibility ) from aws_lambda_powertools.shared import constants -from aws_lambda_powertools.shared.functions import resolve_env_var_choice +from aws_lambda_powertools.shared.functions import is_durable_context, resolve_env_var_choice if TYPE_CHECKING: from collections.abc import Callable, Generator @@ -430,12 +430,13 @@ def handler(event, context): @functools.wraps(lambda_handler) def decorate(event, context, *args, **kwargs): + unwrapped_context = context.lambda_context if is_durable_context(context) else context try: if default_dimensions: self.set_default_dimensions(**default_dimensions) - response = lambda_handler(event, context, *args, **kwargs) + response = lambda_handler(event, unwrapped_context, *args, **kwargs) if capture_cold_start_metric: - self._add_cold_start_metric(context=context) + self._add_cold_start_metric(context=unwrapped_context) finally: self.flush_metrics(raise_on_empty_metrics=raise_on_empty_metrics) diff --git a/aws_lambda_powertools/metrics/provider/base.py b/aws_lambda_powertools/metrics/provider/base.py index 3aab6e7561e..4db047eae45 100644 --- a/aws_lambda_powertools/metrics/provider/base.py +++ b/aws_lambda_powertools/metrics/provider/base.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any from aws_lambda_powertools.metrics.provider import cold_start +from aws_lambda_powertools.shared.functions import is_durable_context if TYPE_CHECKING: from aws_lambda_powertools.shared.types import AnyCallableT @@ -206,7 +207,8 @@ def decorate(event, context, *args, **kwargs): try: response = lambda_handler(event, context, *args, **kwargs) if capture_cold_start_metric: - self._add_cold_start_metric(context=context) + unwrapped_context = context.lambda_context if is_durable_context(context) else context + self._add_cold_start_metric(context=unwrapped_context) finally: self.flush_metrics(raise_on_empty_metrics=raise_on_empty_metrics) diff --git a/aws_lambda_powertools/shared/functions.py b/aws_lambda_powertools/shared/functions.py index ea29ccf3ea5..b02f99d7665 100644 --- a/aws_lambda_powertools/shared/functions.py +++ b/aws_lambda_powertools/shared/functions.py @@ -8,13 +8,15 @@ import warnings from binascii import Error as BinAsciiError from pathlib import Path -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, TypeGuard, overload from aws_lambda_powertools.shared import constants if TYPE_CHECKING: from collections.abc import Generator + from aws_lambda_powertools.utilities.typing import DurableContextProtocol + logger = logging.getLogger(__name__) @@ -307,3 +309,8 @@ def decode_header_bytes(byte_list): # Convert signed bytes to unsigned (0-255 range) unsigned_bytes = [(b & 0xFF) for b in byte_list] return bytes(unsigned_bytes) + + +def is_durable_context(context: Any) -> TypeGuard[DurableContextProtocol]: + """Check if context is a Step Functions durable context wrapping a Lambda context.""" + return hasattr(context, "state") and hasattr(context, "lambda_context") diff --git a/tests/functional/logger/required_dependencies/test_logger.py b/tests/functional/logger/required_dependencies/test_logger.py index 2a960582e3f..c0b0046cfed 100644 --- a/tests/functional/logger/required_dependencies/test_logger.py +++ b/tests/functional/logger/required_dependencies/test_logger.py @@ -48,6 +48,11 @@ def lambda_context(): return namedtuple("LambdaContext", lambda_context.keys())(*lambda_context.values()) +@pytest.fixture +def durable_context(lambda_context): + return namedtuple("DurableContext", ["state", "lambda_context"])(state={}, lambda_context=lambda_context) + + @pytest.fixture def lambda_event(): return {"greeting": "hello"} @@ -1578,3 +1583,20 @@ def test_child_logger_with_caplog(caplog): assert len(caplog.records) == 1 assert pytest_handler_existence is True + + +def test_logger_with_durable_context(lambda_context, durable_context, stdout, service_name): + # GIVEN Logger is initialized and a durable context wrapping the lambda context + logger = Logger(service=service_name, stream=stdout) + + @logger.inject_lambda_context + def handler(event, context): + logger.info("Hello") + + # WHEN handler is called with durable context + handler({}, durable_context) + + # THEN lambda contextual info should be extracted from durable context + log = capture_logging_output(stdout) + assert log["function_name"] == lambda_context.function_name + assert log["function_request_id"] == lambda_context.aws_request_id diff --git a/tests/functional/metrics/conftest.py b/tests/functional/metrics/conftest.py index f0b3766a57d..47c2d1b5f66 100644 --- a/tests/functional/metrics/conftest.py +++ b/tests/functional/metrics/conftest.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import namedtuple from typing import Any import pytest @@ -96,3 +97,13 @@ def a_hundred_metrics() -> list[dict[str, str]]: @pytest.fixture def a_hundred_metric_values() -> list[dict[str, str]]: return [{"name": "metric", "unit": "Count", "value": i} for i in range(100)] + + +@pytest.fixture +def lambda_context(): + return namedtuple("LambdaContext", "function_name")("example_fn") + + +@pytest.fixture +def durable_context(lambda_context): + return namedtuple("DurableContext", ["state", "lambda_context"])(state={}, lambda_context=lambda_context) diff --git a/tests/functional/metrics/required_dependencies/test_metrics_cloudwatch_emf.py b/tests/functional/metrics/required_dependencies/test_metrics_cloudwatch_emf.py index e32d7f3a880..834575e4754 100644 --- a/tests/functional/metrics/required_dependencies/test_metrics_cloudwatch_emf.py +++ b/tests/functional/metrics/required_dependencies/test_metrics_cloudwatch_emf.py @@ -19,6 +19,7 @@ SchemaValidationError, single_metric, ) +from aws_lambda_powertools.metrics.base import SingleMetric from aws_lambda_powertools.metrics.provider.cloudwatch_emf.cloudwatch import ( AmazonCloudWatchEMFProvider, ) @@ -1573,3 +1574,60 @@ def test_metrics_disabled_with_dev_mode_false_and_metrics_disabled_true(monkeypa # THEN no metrics should have been recorded captured = capsys.readouterr() assert not captured.out + + +def test_log_metrics_with_durable_context(capsys, metrics, dimensions, namespace, durable_context): + # GIVEN Metrics is initialized and a durable context wrapping the lambda context + my_metrics = Metrics(namespace=namespace) + for metric in metrics: + my_metrics.add_metric(**metric) + for dimension in dimensions: + my_metrics.add_dimension(**dimension) + + @my_metrics.log_metrics + def lambda_handler(evt, ctx): + pass + + # WHEN handler is called with durable context + lambda_handler({}, durable_context) + output = capture_metrics_output(capsys) + expected = serialize_metrics(metrics=metrics, dimensions=dimensions, namespace=namespace) + + # THEN metrics should be flushed correctly + remove_timestamp(metrics=[output, expected]) + assert expected == output + + +def test_log_metrics_capture_cold_start_metric_with_durable_context(capsys, namespace, service, durable_context): + # GIVEN Metrics is initialized and a durable context wrapping the lambda context + my_metrics = Metrics(service=service, namespace=namespace) + + @my_metrics.log_metrics(capture_cold_start_metric=True) + def lambda_handler(evt, context): + pass + + # WHEN handler is called with durable context + lambda_handler({}, durable_context) + output = capture_metrics_output(capsys) + + # THEN ColdStart metric should use function_name from unwrapped lambda context + assert output["ColdStart"] == [1.0] + assert output["function_name"] == "example_fn" + + +def test_single_metric_log_metrics_with_durable_context(capsys, namespace, durable_context): + # GIVEN SingleMetric is initialized with a durable context + metric = SingleMetric(namespace=namespace) + + @metric.log_metrics(capture_cold_start_metric=True) + def lambda_handler(evt, ctx): + metric.add_metric(name="TestMetric", unit=MetricUnit.Count, value=1) + + # WHEN handler is called with durable context + lambda_handler({}, durable_context) + output = capsys.readouterr().out.strip().split("\n") + + # THEN cold start metric should use function_name from unwrapped context + cold_start_output = json.loads(output[0]) + assert cold_start_output["ColdStart"] == [1.0] + assert cold_start_output["function_name"] == "example_fn" diff --git a/tests/functional/metrics/required_dependencies/test_metrics_provider.py b/tests/functional/metrics/required_dependencies/test_metrics_provider.py index 274d9a7c276..0b46470486b 100644 --- a/tests/functional/metrics/required_dependencies/test_metrics_provider.py +++ b/tests/functional/metrics/required_dependencies/test_metrics_provider.py @@ -78,3 +78,33 @@ def lambda_handler(evt, context, additional_arg, additional_kw_arg="default_valu # the wrapped function is passed additional arguments assert lambda_handler({}, {}, "arg_value", additional_kw_arg="kw_arg_value") == ("arg_value", "kw_arg_value") assert lambda_handler({}, {}, "arg_value") == ("arg_value", "default_value") + + +def test_log_metrics_with_durable_context(capsys, metric, durable_context): + provider = FakeMetricsProvider() + metrics = Metrics(provider=provider) + + @metrics.log_metrics + def lambda_handler(evt, ctx): + metrics.add_metric(**metric) + + lambda_handler({}, durable_context) + output = capture_metrics_output(capsys) + + assert output[0]["name"] == metric["name"] + assert output[0]["value"] == metric["value"] + + +def test_log_metrics_cold_start_with_durable_context(capsys, durable_context): + provider = FakeMetricsProvider() + metrics = Metrics(provider=provider) + + @metrics.log_metrics(capture_cold_start_metric=True) + def lambda_handler(evt, ctx): + return True + + lambda_handler({}, durable_context) + output = capture_metrics_output(capsys) + + assert output[0]["name"] == "ColdStart" + assert output[0]["value"] == 1