diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 5f5073c916b68..45f92ad32cb96 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -468,7 +468,7 @@ def _emit_task_span(ti, state): "airflow.task_instance.try_number": ti.try_number, "airflow.task_instance.map_index": ti.map_index if ti.map_index is not None else -1, "airflow.task_instance.state": state, - "airflow.task_instance.id": ti.id, + "airflow.task_instance.id": str(ti.id), } ) status_code = StatusCode.OK if state == TaskInstanceState.SUCCESS else StatusCode.ERROR @@ -524,7 +524,10 @@ def _create_ti_state_update_query_and_update_state( ti_patch_payload.outlet_events, session, ) - _emit_task_span(ti, state=updated_state) + try: + _emit_task_span(ti, state=updated_state) + except Exception: + log.warning("Failed to emit task span", exc_info=True) elif isinstance(ti_patch_payload, TIDeferredStatePayload): # Calculate timeout if it was passed timeout = None diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 459bb12eeaafd..26e1b5683715b 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -62,7 +62,11 @@ from airflow._shared.observability.metrics.dual_stats_manager import DualStatsManager from airflow._shared.observability.metrics.stats import Stats -from airflow._shared.observability.traces import new_dagrun_trace_carrier, override_ids +from airflow._shared.observability.traces import ( + TASK_SPAN_DETAIL_LEVEL_KEY, + new_dagrun_trace_carrier, + override_ids, +) from airflow._shared.timezones import timezone from airflow.callbacks.callback_requests import DagCallbackRequest, DagRunContext from airflow.configuration import conf as airflow_conf @@ -376,7 +380,9 @@ def 
__init__( self.triggered_by = triggered_by self.triggering_user_name = triggering_user_name self.scheduled_by_job_id = None - self.context_carrier: dict[str, str] = new_dagrun_trace_carrier() + self.context_carrier: dict[str, str] = new_dagrun_trace_carrier( + task_span_detail_level=self.conf.get(TASK_SPAN_DETAIL_LEVEL_KEY, None) + ) if not isinstance(partition_key, str | None): raise ValueError( @@ -1268,7 +1274,10 @@ def recalculate(self) -> _UnfinishedStates: self.data_interval_end, ) session.flush() - self._emit_dagrun_span(state=self.state) + try: + self._emit_dagrun_span(state=self.state) + except Exception: + self.log.warning("Failed to emit dag run span", exc_info=True) self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis) self._emit_duration_stats_for_finished_state() diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 09044b5aaf2cc..550094c5b4694 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -68,7 +68,11 @@ from airflow import settings from airflow._shared.observability.metrics.dual_stats_manager import DualStatsManager from airflow._shared.observability.metrics.stats import Stats -from airflow._shared.observability.traces import new_dagrun_trace_carrier, new_task_run_carrier +from airflow._shared.observability.traces import ( + TASK_SPAN_DETAIL_LEVEL_KEY, + new_dagrun_trace_carrier, + new_task_run_carrier, +) from airflow._shared.timezones import timezone from airflow.assets.manager import asset_manager from airflow.configuration import conf @@ -399,7 +403,9 @@ def clear_task_instances( # Always update clear_number and queued_at when clearing tasks, regardless of state dr.clear_number += 1 dr.queued_at = timezone.utcnow() - dr.context_carrier = new_dagrun_trace_carrier() + dr.context_carrier = new_dagrun_trace_carrier( + task_span_detail_level=dr.conf.get(TASK_SPAN_DETAIL_LEVEL_KEY) if dr.conf else 
None + ) _recalculate_dagrun_queued_at_deadlines(dr, dr.queued_at, session) diff --git a/airflow-core/tests/integration/otel/test_otel.py b/airflow-core/tests/integration/otel/test_otel.py index 6852b47af0402..b589d86d16c96 100644 --- a/airflow-core/tests/integration/otel/test_otel.py +++ b/airflow-core/tests/integration/otel/test_otel.py @@ -86,7 +86,7 @@ def wait_for_otel_collector(host: str, port: int, timeout: int = 120) -> None: ) -def unpause_trigger_dag_and_get_run_id(dag_id: str) -> str: +def unpause_trigger_dag_and_get_run_id(dag_id: str, conf: dict | None = None) -> str: unpause_command = ["airflow", "dags", "unpause", dag_id] # Unpause the dag using the cli. @@ -106,6 +106,11 @@ def unpause_trigger_dag_and_get_run_id(dag_id: str) -> str: execution_date.isoformat(), ] + if conf: + import json + + trigger_command += ["--conf", json.dumps(conf)] + # Trigger the dag using the cli. subprocess.run(trigger_command, check=True, env=os.environ.copy()) @@ -436,7 +441,65 @@ def test_export_metrics_during_process_shutdown(self, capfd): assert set(metrics_to_check).issubset(metrics_dict.keys()) @pytest.mark.execution_timeout(90) - def test_dag_execution_succeeds(self, capfd): + @pytest.mark.parametrize( + ("task_span_detail_level", "expected_hierarchy"), + [ + pytest.param( + None, + { + "dag_run.otel_test_dag": None, + "sub_span1": "worker.task1", + "task_run.task1": "dag_run.otel_test_dag", + "worker.task1": "task_run.task1", + }, + id="default_spans", + ), + pytest.param( + 2, + { + "hook.on_starting": "startup", + "get_bundle": "parse", + "initialize": "parse", + "_verify_bundle_access": "parse", + "make BundleDagBag": "parse", + "parse": "startup", + "get_template_context": "startup", + "startup": "worker.task1", + "delete xcom": "run", + "get_template_env": "_prepare", + "render_templates": "_prepare", + "_serialize_rendered_fields": "_prepare", + "set_rendered_fields": "_prepare", + "set_rendered_map_index": "_prepare", + "_validate_task_inlets_and_outlets": 
"_prepare", + "listener.on_task_instance_running": "_prepare", + "_prepare": "run", + "prepare context": "_execute_task", + "pre-execute": "_execute_task", + "on_execute_callback": "_execute_task", + "execute": "_execute_task", + "post_execute_hook": "_execute_task", + "_execute_task": "run", + "render_map_index": "run", + "push xcom": "run", + "handle success": "run", + "handle_extra_links": "finalize", + "success_callback": "finalize", + "listener.on_task_instance_success": "finalize", + "listener.before_stopping": "finalize", + "finalize": "run", + "run": "worker.task1", + "close_socket": "worker.task1", + "sub_span1": "prepare context", + "dag_run.otel_test_dag": None, + "task_run.task1": "dag_run.otel_test_dag", + "worker.task1": "task_run.task1", + }, + id="detail_spans", + ), + ], + ) + def test_dag_execution_succeeds(self, capfd, task_span_detail_level, expected_hierarchy): """The same scheduler will start and finish the dag processing.""" scheduler_process = None apiserver_process = None @@ -452,7 +515,13 @@ def test_dag_execution_succeeds(self, capfd): assert dag is not None - run_id = unpause_trigger_dag_and_get_run_id(dag_id=dag_id) + conf = None + if task_span_detail_level is not None: + from airflow_shared.observability.traces import TASK_SPAN_DETAIL_LEVEL_KEY + + conf = {TASK_SPAN_DETAIL_LEVEL_KEY: task_span_detail_level} + + run_id = unpause_trigger_dag_and_get_run_id(dag_id=dag_id, conf=conf) # Skip the span_status check. 
wait_for_dag_run(dag_id=dag_id, run_id=run_id, max_wait_time=90) @@ -490,7 +559,6 @@ def test_dag_execution_succeeds(self, capfd): service_name = os.environ.get("OTEL_SERVICE_NAME", "test") r = requests.get(f"http://{host}:16686/api/traces?service={service_name}") data = r.json() - trace = data["data"][-1] spans = trace["spans"] @@ -507,12 +575,7 @@ def get_parent_span_id(span): return nested nested = get_span_hierarchy() - assert nested == { - "dag_run.otel_test_dag": None, - "sub_span1": "worker.task1", - "task_run.task1": "dag_run.otel_test_dag", - "worker.task1": "task_run.task1", - } + assert nested == expected_hierarchy def start_scheduler(self, capture_output: bool = False): stdout = None if capture_output else subprocess.DEVNULL diff --git a/airflow-core/tests/unit/listeners/test_listeners.py b/airflow-core/tests/unit/listeners/test_listeners.py index aad2ea7b6e863..8f17e6cb94458 100644 --- a/airflow-core/tests/unit/listeners/test_listeners.py +++ b/airflow-core/tests/unit/listeners/test_listeners.py @@ -118,7 +118,7 @@ def test_listener_suppresses_exceptions(create_task_instance, session, cap_struc ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED) ti.run() - assert "error calling listener" in cap_structlog + assert "error calling on_task_instance_success listener" in cap_structlog @provide_session diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index dd34c2d10e7ea..242b4b3cf1113 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -3562,3 +3562,23 @@ def test_emit_dagrun_span_with_none_or_empty_carrier(self, dag_maker, session, c assert spans[0].name == f"dag_run.{dr.dag_id}" else: assert len(spans) == 0 + + @pytest.mark.db_test + def test_context_carrier_includes_detail_level_from_conf(self, dag_maker): + """DagRun created with TASK_SPAN_DETAIL_LEVEL_KEY in conf should encode the level in trace state.""" + from 
opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + + from airflow._shared.observability.traces import ( + TASK_SPAN_DETAIL_LEVEL_KEY, + get_task_span_detail_level, + ) + + with dag_maker("test_tracing_detail_level"): + EmptyOperator(task_id="t1") + dr = dag_maker.create_dagrun(conf={TASK_SPAN_DETAIL_LEVEL_KEY: 2}) + + ctx = TraceContextTextMapPropagator().extract(dr.context_carrier) + from opentelemetry import trace + + span = trace.get_current_span(ctx) + assert get_task_span_detail_level(span) == 2 diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index cd1dd3c337a12..403478b78d837 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -3560,3 +3560,25 @@ def test_clear_task_instances_resets_context_carrier(dag_maker, session): assert ti.context_carrier["traceparent"] != original_ti_traceparent assert dag_run.context_carrier["traceparent"] != original_dr_traceparent + + +@pytest.mark.db_test +def test_clear_task_instances_preserves_detail_level(dag_maker, session): + """clear_task_instances should produce a new context_carrier that keeps the detail level from dag run conf.""" + from airflow._shared.observability.traces import ( + TASK_SPAN_DETAIL_LEVEL_KEY, + get_task_span_detail_level, + ) + + with dag_maker("test_clear_preserves_level"): + EmptyOperator(task_id="t1") + dag_run = dag_maker.create_dagrun(conf={TASK_SPAN_DETAIL_LEVEL_KEY: 2}) + ti = dag_run.get_task_instance("t1", session=session) + ti.state = TaskInstanceState.SUCCESS + session.flush() + + clear_task_instances([ti], session) + + new_ctx = TraceContextTextMapPropagator().extract(dag_run.context_carrier) + span = otel_trace.get_current_span(new_ctx) + assert get_task_span_detail_level(span) == 2 diff --git a/shared/observability/src/airflow_shared/observability/traces/__init__.py 
b/shared/observability/src/airflow_shared/observability/traces/__init__.py index dc3532262d15e..04746f2023204 100644 --- a/shared/observability/src/airflow_shared/observability/traces/__init__.py +++ b/shared/observability/src/airflow_shared/observability/traces/__init__.py @@ -28,7 +28,7 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter from opentelemetry.sdk.trace.id_generator import RandomIdGenerator -from opentelemetry.trace import NonRecordingSpan, SpanContext, TraceFlags +from opentelemetry.trace import NonRecordingSpan, Span, SpanContext, TraceFlags, TraceState from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator if TYPE_CHECKING: @@ -56,14 +56,20 @@ def generate_trace_id(self): return super().generate_trace_id() -def new_dagrun_trace_carrier() -> dict[str, str]: +TASK_SPAN_DETAIL_LEVEL_KEY = "airflow/task_span_detail_level" +DEFAULT_TASK_SPAN_DETAIL_LEVEL = 1 + + +def new_dagrun_trace_carrier(task_span_detail_level=None) -> dict[str, str]: """Generate a fresh W3C traceparent carrier without creating a recordable span.""" gen = RandomIdGenerator() + trace_state_entries = build_trace_state_entries(task_span_detail_level) span_ctx = SpanContext( trace_id=gen.generate_trace_id(), span_id=gen.generate_span_id(), is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED), + trace_state=TraceState(entries=trace_state_entries), ) ctx = trace.set_span_in_context(NonRecordingSpan(span_ctx)) carrier: dict[str, str] = {} @@ -82,6 +88,28 @@ def new_task_run_carrier(dag_run_context_carrier): return carrier +def build_trace_state_entries(task_span_detail_level) -> list[tuple[str, str]]: + trace_state_entries = [] + if task_span_detail_level is not None: + try: + level = int(task_span_detail_level) + except Exception: + level = None + if level: + trace_state_entries.append((TASK_SPAN_DETAIL_LEVEL_KEY, str(level))) + return trace_state_entries + + +def 
get_task_span_detail_level(span: Span): + span_ctx = span.get_span_context() + trace_state = span_ctx.trace_state + try: + return int(trace_state.get(TASK_SPAN_DETAIL_LEVEL_KEY, default=DEFAULT_TASK_SPAN_DETAIL_LEVEL)) + except Exception: + log.warning("%s config in dag run conf must be integer.", TASK_SPAN_DETAIL_LEVEL_KEY) + return DEFAULT_TASK_SPAN_DETAIL_LEVEL + + @contextmanager def override_ids(trace_id, span_id, ctx=None): ctx = context.set_value(OVERRIDE_TRACE_ID_KEY, trace_id, context=ctx) diff --git a/shared/observability/tests/observability/test_traces.py b/shared/observability/tests/observability/test_traces.py new file mode 100644 index 0000000000000..b21cc3c876173 --- /dev/null +++ b/shared/observability/tests/observability/test_traces.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from opentelemetry.trace import NonRecordingSpan, SpanContext, TraceFlags, TraceState +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + +from airflow_shared.observability.traces import ( + DEFAULT_TASK_SPAN_DETAIL_LEVEL, + TASK_SPAN_DETAIL_LEVEL_KEY, + build_trace_state_entries, + get_task_span_detail_level, + new_dagrun_trace_carrier, +) + + +class TestBuildTraceStateEntries: + def test_with_integer_level(self): + entries = build_trace_state_entries(2) + assert entries == [(TASK_SPAN_DETAIL_LEVEL_KEY, "2")] + + def test_with_string_level(self): + entries = build_trace_state_entries("3") + assert entries == [(TASK_SPAN_DETAIL_LEVEL_KEY, "3")] + + def test_with_none(self): + assert build_trace_state_entries(None) == [] + + def test_with_zero(self): + # 0 is falsy — treated as no detail level + assert build_trace_state_entries(0) == [] + + def test_with_invalid_string(self): + # Non-integer string should not raise; returns empty + assert build_trace_state_entries("not-a-number") == [] + + +class TestNewDagrunTraceCarrier: + def test_with_detail_level_embeds_level_in_trace_state(self): + carrier = new_dagrun_trace_carrier(task_span_detail_level=2) + ctx = TraceContextTextMapPropagator().extract(carrier) + from opentelemetry import trace + + span_ctx = trace.get_current_span(ctx).get_span_context() + assert span_ctx.trace_state.get(TASK_SPAN_DETAIL_LEVEL_KEY) == "2" + + def test_without_detail_level_has_empty_trace_state(self): + carrier = new_dagrun_trace_carrier() + ctx = TraceContextTextMapPropagator().extract(carrier) + from opentelemetry import trace + + span_ctx = trace.get_current_span(ctx).get_span_context() + assert span_ctx.trace_state.get(TASK_SPAN_DETAIL_LEVEL_KEY) is None + + +class TestGetTaskSpanDetailLevel: + def _make_span_with_trace_state(self, entries: list[tuple[str, str]]) -> NonRecordingSpan: + from opentelemetry.sdk.trace.id_generator import RandomIdGenerator + + gen = 
RandomIdGenerator() + span_ctx = SpanContext( + trace_id=gen.generate_trace_id(), + span_id=gen.generate_span_id(), + is_remote=False, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + trace_state=TraceState(entries=entries), + ) + return NonRecordingSpan(span_ctx) + + def test_returns_default_when_no_trace_state(self): + span = self._make_span_with_trace_state([]) + assert get_task_span_detail_level(span) == DEFAULT_TASK_SPAN_DETAIL_LEVEL + + def test_reads_level_from_trace_state(self): + span = self._make_span_with_trace_state([(TASK_SPAN_DETAIL_LEVEL_KEY, "2")]) + assert get_task_span_detail_level(span) == 2 + + def test_fallback_on_invalid_value(self): + span = self._make_span_with_trace_state([(TASK_SPAN_DETAIL_LEVEL_KEY, "bad")]) + assert get_task_span_detail_level(span) == DEFAULT_TASK_SPAN_DETAIL_LEVEL + + def test_roundtrip_via_carrier(self): + """Level set in new_dagrun_trace_carrier is readable by get_task_span_detail_level.""" + carrier = new_dagrun_trace_carrier(task_span_detail_level=3) + ctx = TraceContextTextMapPropagator().extract(carrier) + from opentelemetry import trace + + span = trace.get_current_span(ctx) + assert get_task_span_detail_level(span) == 3 diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 649c076126c65..3048947b0596c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -21,6 +21,7 @@ import contextvars import functools +import inspect import os import sys import time @@ -36,12 +37,14 @@ import lazy_object_proxy import structlog from opentelemetry import trace +from opentelemetry.trace import Status, StatusCode from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from 
airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.sdk._shared.observability.metrics.stats import Stats +from airflow.sdk._shared.observability.traces import get_task_span_detail_level from airflow.sdk._shared.template_rendering import truncate_rendered_value from airflow.sdk.api.client import get_hostname, getuser from airflow.sdk.api.datamodels._generated import ( @@ -142,6 +145,38 @@ tracer = trace.get_tracer(__name__) +class detail_span: + """Context manager and decorator that creates a child span when detail level > 1.""" + + def __init__(self, *args, **kwargs): + self._args = args + self._kwargs = kwargs + self._ctx = None + + def _make_ctx(self): + parent_span = trace.get_current_span() + config_level = get_task_span_detail_level(span=parent_span) + if config_level > 1: + return tracer.start_as_current_span(*self._args, **self._kwargs) + return trace.INVALID_SPAN + + def __enter__(self): + self._ctx = self._make_ctx() + return self._ctx.__enter__() + + def __exit__(self, *exc_info): + return self._ctx.__exit__(*exc_info) + + def __call__(self, f): + @functools.wraps(f) + def wrapper(*inner_args, **inner_kwargs): + with self._make_ctx(): + return f(*inner_args, **inner_kwargs) + + wrapper.__signature__ = inspect.signature(f) + return wrapper + + @contextmanager def _make_task_span(msg: StartupDetails): parent_context = ( @@ -210,6 +245,7 @@ def __rich_repr__(self): __rich_repr__.angular = True # type: ignore[attr-defined] + @detail_span("get_template_context") def get_template_context(self) -> Context: # TODO: Move this to `airflow.sdk.execution_time.context` # once we port the entire context logic from airflow/utils/context.py ? 
@@ -305,6 +341,7 @@ def get_template_context(self) -> Context: return self._cached_template_context + @detail_span("render_templates") def render_templates( self, context: Context | None = None, jinja_env: jinja2.Environment | None = None ) -> BaseOperator: @@ -765,27 +802,31 @@ def _maybe_reschedule_startup_failure( ) +@detail_span("parse") def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: # TODO: Task-SDK: # Using BundleDagBag here is about 98% wrong, but it'll do for now from airflow.dag_processing.dagbag import BundleDagBag bundle_info = what.bundle_info - bundle_instance = DagBundlesManager().get_bundle( - name=bundle_info.name, - version=bundle_info.version, - ) - bundle_instance.initialize() + with detail_span("get_bundle"): + bundle_instance = DagBundlesManager().get_bundle( + name=bundle_info.name, + version=bundle_info.version, + ) + with detail_span("initialize"): + bundle_instance.initialize() _verify_bundle_access(bundle_instance, log) dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path)) - bag = BundleDagBag( - dag_folder=dag_absolute_path, - safe_mode=False, - load_op_links=False, - bundle_path=bundle_instance.path, - bundle_name=bundle_info.name, - ) + with detail_span("make BundleDagBag"): + bag = BundleDagBag( + dag_folder=dag_absolute_path, + safe_mode=False, + load_op_links=False, + bundle_path=bundle_instance.path, + bundle_name=bundle_info.name, + ) if TYPE_CHECKING: assert what.ti.dag_id @@ -849,6 +890,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: # 3. Shutdown and report status +@detail_span("_verify_bundle_access") def _verify_bundle_access(bundle_instance: BaseDagBundle, log: Logger) -> None: """ Verify bundle is accessible by the current user. 
@@ -911,6 +953,7 @@ def get_startup_details() -> StartupDetails: return msg +@detail_span("startup") def startup(msg: StartupDetails) -> tuple[RuntimeTaskInstance, Context, Logger]: # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 os_type = sys.platform @@ -921,10 +964,11 @@ def startup(msg: StartupDetails) -> tuple[RuntimeTaskInstance, Context, Logger]: setproctitle(f"airflow worker -- {msg.ti.id}") - try: - get_listener_manager().hook.on_starting(component=TaskRunnerMarker()) - except Exception: - log.exception("error calling listener") + with detail_span("hook.on_starting"): + try: + get_listener_manager().hook.on_starting(component=TaskRunnerMarker()) + except Exception: + log.exception("error calling on_starting listener") with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): ti = parse(msg, log) @@ -958,7 +1002,9 @@ def startup(msg: StartupDetails) -> tuple[RuntimeTaskInstance, Context, Logger]: # ideally, we should never reach here, but if we do, we should return None, None, None return None, None, None - return ti, ti.get_template_context(), log + template_context = ti.get_template_context() + + return ti, template_context, log def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float: @@ -1040,6 +1086,7 @@ def _fallback_serialization(obj): return template_field +@detail_span("_serialize_rendered_fields") def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]: from airflow.sdk._shared.secrets_masker import redact @@ -1108,6 +1155,7 @@ def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[d yield attrs.asdict(alias_event) +@detail_span("_prepare") def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None: ti.hostname = get_hostname() ti.task = ti.task.prepare_for_execution() @@ -1115,40 +1163,45 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> 
ToSuperv # update the value of the task that is sent from there context["task"] = ti.task - jinja_env = ti.task.dag.get_template_env() + with detail_span("get_template_env"): + jinja_env = ti.task.dag.get_template_env() ti.render_templates(context=context, jinja_env=jinja_env) if rendered_fields := _serialize_rendered_fields(ti.task): # so that we do not call the API unnecessarily - SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields)) + with detail_span("set_rendered_fields"): + SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields)) - # Try to render map_index_template early with available context (will be re-rendered after execution) - # This provides a partial label during task execution for templates using pre-execution context - # If rendering fails here, we suppress the error since it will be re-rendered after execution - try: - if rendered_map_index := _render_map_index(context, ti=ti, log=log): - ti.rendered_map_index = rendered_map_index - log.debug("Sending early rendered map index", length=len(rendered_map_index)) - SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index)) - except Exception: - log.debug( - "Early rendering of map_index_template failed, will retry after task execution", exc_info=True - ) + with detail_span("set_rendered_map_index"): + # Try to render map_index_template early with available context (will be re-rendered after execution) + # This provides a partial label during task execution for templates using pre-execution context + # If rendering fails here, we suppress the error since it will be re-rendered after execution + try: + if rendered_map_index := _render_map_index(context, ti=ti, log=log): + ti.rendered_map_index = rendered_map_index + log.debug("Sending early rendered map index", length=len(rendered_map_index)) + SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index)) + except Exception: + log.debug( + "Early rendering of 
map_index_template failed, will retry after task execution", exc_info=True + ) _validate_task_inlets_and_outlets(ti=ti, log=log) - try: - # TODO: Call pre execute etc. - get_listener_manager().hook.on_task_instance_running( - previous_state=TaskInstanceState.QUEUED, task_instance=ti - ) - except Exception: - log.exception("error calling listener") + with detail_span("listener.on_task_instance_running"): + try: + # TODO: Call pre execute etc. + get_listener_manager().hook.on_task_instance_running( + previous_state=TaskInstanceState.QUEUED, task_instance=ti + ) + except Exception: + log.exception("error calling on_task_instance_running listener") # No error, carry on and execute the task return None +@detail_span("_validate_task_inlets_and_outlets") def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None: if not ti.task.inlets and not ti.task.outlets: return @@ -1247,16 +1300,19 @@ def _on_term(signum, frame): try: # First, clear the xcom data sent from server - if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear): - for x in keys_to_delete: - log.debug("Clearing XCom with key", key=x) - XCom.delete( - key=x, - dag_id=ti.dag_id, - task_id=ti.task_id, - run_id=ti.run_id, - map_index=ti.map_index, - ) + with detail_span("delete xcom"): + if ti._ti_context_from_server and ( + keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear + ): + for x in keys_to_delete: + log.debug("Clearing XCom with key", key=x) + XCom.delete( + key=x, + dag_id=ti.dag_id, + task_id=ti.task_id, + run_id=ti.run_id, + map_index=ti.map_index, + ) with set_current_context(context): # This is the earliest that we can render templates -- as if it excepts for any reason we need to @@ -1282,15 +1338,19 @@ def _on_term(signum, frame): ) raise else: # If the task succeeded, render normally to let rendering error bubble up. 
- previous_rendered_map_index = ti.rendered_map_index - ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) - # Send update only if value changed (e.g., user set context variables during execution) - if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index: - SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)) - - _push_xcom_if_needed(result, ti, log) + with detail_span("render_map_index"): + previous_rendered_map_index = ti.rendered_map_index + ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) + # Send update only if value changed (e.g., user set context variables during execution) + if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index: + SUPERVISOR_COMMS.send( + msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index) + ) - msg, state = _handle_current_task_success(context, ti) + with detail_span("push xcom"): + _push_xcom_if_needed(result, ti, log) + with detail_span("handle success"): + msg, state = _handle_current_task_success(context, ti) except DownstreamTasksSkipped as skip: log.info("Skipping downstream tasks.") tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks] @@ -1627,67 +1687,74 @@ def _send_error_email_notification( log.exception("Failed to send email notification") +@detail_span("_execute_task") def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): """Execute Task (optionally with a Timeout) and push Xcom results.""" - task = ti.task - execute = task.execute + with detail_span("prepare context"): + task = ti.task + execute = task.execute - if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method): - from airflow.sdk.serde import deserialize + if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method): + from airflow.sdk.serde import deserialize - next_kwargs_data = ti._ti_context_from_server.next_kwargs or {} - 
try: - if TYPE_CHECKING: - assert isinstance(next_kwargs_data, dict) - kwargs = deserialize(next_kwargs_data) - except (ImportError, KeyError, AttributeError, TypeError): - from airflow.serialization.serialized_objects import BaseSerialization + next_kwargs_data = ti._ti_context_from_server.next_kwargs or {} + try: + if TYPE_CHECKING: + assert isinstance(next_kwargs_data, dict) + kwargs = deserialize(next_kwargs_data) + except (ImportError, KeyError, AttributeError, TypeError): + from airflow.serialization.serialized_objects import BaseSerialization - kwargs = BaseSerialization.deserialize(next_kwargs_data) + kwargs = BaseSerialization.deserialize(next_kwargs_data) - if TYPE_CHECKING: - assert isinstance(kwargs, dict) - execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs) + if TYPE_CHECKING: + assert isinstance(kwargs, dict) + execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs) - ctx = contextvars.copy_context() - # Populate the context var so ExecutorSafeguard doesn't complain - ctx.run(ExecutorSafeguard.tracker.set, task) + ctx = contextvars.copy_context() + # Populate the context var so ExecutorSafeguard doesn't complain + ctx.run(ExecutorSafeguard.tracker.set, task) - # Export context in os.environ to make it available for operators to use. - airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) - os.environ.update(airflow_context_vars) + # Export context in os.environ to make it available for operators to use. 
+ airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) + os.environ.update(airflow_context_vars) - outlet_events = context_get_outlet_events(context) + outlet_events = context_get_outlet_events(context) - if (pre_execute_hook := task._pre_execute_hook) is not None: - create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context) - if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute: - create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context) + with detail_span("pre-execute"): + if (pre_execute_hook := task._pre_execute_hook) is not None: + create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context) + if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute: + create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context) - _run_task_state_change_callbacks(task, "on_execute_callback", context, log) + with detail_span("on_execute_callback"): + _run_task_state_change_callbacks(task, "on_execute_callback", context, log) - if task.execution_timeout: - from airflow.sdk.execution_time.timeout import timeout + with detail_span("execute") as span: + if task.execution_timeout: + from airflow.sdk.execution_time.timeout import timeout - # TODO: handle timeout in case of deferral - timeout_seconds = task.execution_timeout.total_seconds() - try: - # It's possible we're already timed out, so fast-fail if true - if timeout_seconds <= 0: - raise AirflowTaskTimeout() - # Run task in timeout wrapper - with timeout(timeout_seconds): - result = ctx.run(execute, context=context) - except AirflowTaskTimeout: - task.on_kill() - raise - else: - result = ctx.run(execute, context=context) + # TODO: handle timeout in case of deferral + timeout_seconds = task.execution_timeout.total_seconds() + try: + # It's possible we're already timed out, so fast-fail if true + if timeout_seconds <= 0: + 
raise AirflowTaskTimeout() + # Run task in timeout wrapper + with timeout(timeout_seconds): + result = ctx.run(execute, context=context) + except AirflowTaskTimeout: + span.add_event("task.execute.timeout") + task.on_kill() + raise + else: + result = ctx.run(execute, context=context) - if (post_execute_hook := task._post_execute_hook) is not None: - create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result) - if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute: - create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context) + with detail_span("post_execute_hook"): + if (post_execute_hook := task._post_execute_hook) is not None: + create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result) + if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute: + create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context) return result @@ -1748,6 +1815,7 @@ def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger): _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length) +@detail_span("finalize") def finalize( ti: RuntimeTaskInstance, state: TaskInstanceState, @@ -1765,69 +1833,87 @@ def finalize( task = ti.task # Pushing xcom for each operator extra links defined on the operator only. 
- for oe in task.operator_extra_links: - try: - link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key # type: ignore[arg-type] - log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key) - _xcom_push_to_db(ti, key=xcom_key, value=link) - except Exception: - log.exception( - "Failed to push an xcom for task operator extra link", - link_name=oe.name, - xcom_key=oe.xcom_key, - ti=ti, - ) - - if getattr(ti.task, "overwrite_rtif_after_execution", False): - log.debug("Overwriting Rendered template fields.") - if ti.task.template_fields: + with detail_span("handle_extra_links"): + for oe in task.operator_extra_links: try: - SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) + link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key # type: ignore[arg-type] + log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key) + _xcom_push_to_db(ti, key=xcom_key, value=link) except Exception: - log.exception("Failed to set rendered fields during finalization", ti=ti, task=ti.task) + log.exception( + "Failed to push an xcom for task operator extra link", + link_name=oe.name, + xcom_key=oe.xcom_key, ti=ti, + ) + + if getattr(ti.task, "overwrite_rtif_after_execution", False): + with detail_span("overwrite_rtif"): + log.debug("Overwriting Rendered template fields.") + if ti.task.template_fields: + try: + SUPERVISOR_COMMS.send( + SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)) + ) + except Exception: + log.exception( + "Failed to set rendered fields during finalization", + task_id=ti.task_id, + dag_id=ti.dag_id, + ) log.debug("Running finalizers", ti=ti) if state == TaskInstanceState.SUCCESS: - _run_task_state_change_callbacks(task, "on_success_callback", context, log) - try: - get_listener_manager().hook.on_task_instance_success( - previous_state=TaskInstanceState.RUNNING, task_instance=ti - ) - except Exception: - log.exception("error calling listener") + with
detail_span("success_callback"): + _run_task_state_change_callbacks(task, "on_success_callback", context, log) + with detail_span("listener.on_task_instance_success"): + try: + get_listener_manager().hook.on_task_instance_success( + previous_state=TaskInstanceState.RUNNING, task_instance=ti + ) + except Exception: + log.exception("error calling on_task_instance_success listener") elif state == TaskInstanceState.SKIPPED: - _run_task_state_change_callbacks(task, "on_skipped_callback", context, log) - try: - get_listener_manager().hook.on_task_instance_skipped( - previous_state=TaskInstanceState.RUNNING, task_instance=ti - ) - except Exception: - log.exception("error calling listener") + with detail_span("skipped_callback"): + _run_task_state_change_callbacks(task, "on_skipped_callback", context, log) + with detail_span("listener.skipped_callback"): + try: + get_listener_manager().hook.on_task_instance_skipped( + previous_state=TaskInstanceState.RUNNING, task_instance=ti + ) + except Exception: + log.exception("error calling on_task_instance_skipped listener") elif state == TaskInstanceState.UP_FOR_RETRY: - _run_task_state_change_callbacks(task, "on_retry_callback", context, log) - try: - get_listener_manager().hook.on_task_instance_failed( - previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error - ) - except Exception: - log.exception("error calling listener") + with detail_span("retry_callback"): + _run_task_state_change_callbacks(task, "on_retry_callback", context, log) + with detail_span("listener.retry_callback"): + try: + get_listener_manager().hook.on_task_instance_failed( + previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error + ) + except Exception: + log.exception("error calling on_task_instance_failed listener") if error and task.email_on_retry and task.email: - _send_error_email_notification(task, ti, context, error, log) + with detail_span("email_notif"): + _send_error_email_notification(task, ti, context, error, log) 
elif state == TaskInstanceState.FAILED: - _run_task_state_change_callbacks(task, "on_failure_callback", context, log) - try: - get_listener_manager().hook.on_task_instance_failed( - previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error - ) - except Exception: - log.exception("error calling listener") + with detail_span("failure_callback"): + _run_task_state_change_callbacks(task, "on_failure_callback", context, log) + with detail_span("listener.failure_callback"): + try: + get_listener_manager().hook.on_task_instance_failed( + previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error + ) + except Exception: + log.exception("error calling on_task_instance_failed listener") if error and task.email_on_failure and task.email: - _send_error_email_notification(task, ti, context, error, log) + with detail_span("send_error_email"): + _send_error_email_notification(task, ti, context, error, log) - try: - get_listener_manager().hook.before_stopping(component=TaskRunnerMarker()) - except Exception: - log.exception("error calling listener") + with detail_span("listener.before_stopping"): + try: + get_listener_manager().hook.before_stopping(component=TaskRunnerMarker()) + except Exception: + log.exception("error calling before_stopping listener") @contextmanager @@ -1858,8 +1944,8 @@ def main(): try: try: startup_details = get_startup_details() - span = _make_task_span(msg=startup_details) - stack.enter_context(span) + span_ctx_mgr = _make_task_span(msg=startup_details) + span = stack.enter_context(span_ctx_mgr) ti, context, log = startup(msg=startup_details) except AirflowRescheduleException as reschedule: log.warning("Rescheduling task during startup, marking task as UP_FOR_RESCHEDULE") @@ -1869,26 +1955,43 @@ def main(): end_date=datetime.now(tz=timezone.utc), ) ) + span.record_exception(reschedule) + span.set_status( + Status(StatusCode.ERROR, description=f"Exception: {type(reschedule).__name__}") + ) sys.exit(0) - with BundleVersionLock( - 
bundle_name=ti.bundle_instance.name, - bundle_version=ti.bundle_instance.version, - ): - state, _, error = run(ti, context, log) - context["exception"] = error - finalize(ti, state, context, log, error) - except KeyboardInterrupt: + + with detail_span("run") as run_span: + with BundleVersionLock( + bundle_name=ti.bundle_instance.name, + bundle_version=ti.bundle_instance.version, + ): + state, _, error = run(ti, context, log) + if error: + run_span.record_exception(error) + run_span.set_status( + Status(StatusCode.ERROR, description=f"Exception: {type(error).__name__}") + ) + context["exception"] = error + run_span.set_attribute("state", state.value if state else "unknown") + finalize(ti, state, context, log, error) + except KeyboardInterrupt as e: log.exception("Ctrl-c hit") + span.record_exception(e) + span.set_status(Status(StatusCode.ERROR, description=f"Exception: {type(e).__name__}")) sys.exit(2) - except Exception: + except Exception as e: log.exception("Top level error") + span.record_exception(e) + span.set_status(Status(StatusCode.ERROR, description=f"Exception: {type(e).__name__}")) sys.exit(1) finally: # Ensure the request socket is closed on the child side in all circumstances # before the process fully terminates.
- if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket: - with suppress(Exception): - SUPERVISOR_COMMS.socket.close() + with detail_span("close_socket"): + if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket: + with suppress(Exception): + SUPERVISOR_COMMS.socket.close() def reinit_supervisor_comms() -> None: diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 5aeb009bd33da..925ba06a9c25f 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -32,7 +32,7 @@ import pandas as pd import pytest -from opentelemetry import trace as otel_trace +from opentelemetry import trace from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter @@ -40,7 +40,11 @@ from task_sdk import FAKE_BUNDLE from uuid6 import uuid7 -from airflow._shared.observability.traces import OverrideableRandomIdGenerator, new_task_run_carrier +from airflow._shared.observability.traces import ( + OverrideableRandomIdGenerator, + new_dagrun_trace_carrier, + new_task_run_carrier, +) from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span from airflow.listeners import hookimpl from airflow.providers.standard.operators.python import PythonOperator @@ -139,6 +143,7 @@ _make_task_span, _push_xcom_if_needed, _xcom_push, + detail_span, finalize, get_startup_details, parse, @@ -441,7 +446,7 @@ def test_task_span_is_child_of_dag_run_span(make_ti_context): ti_carrier = new_task_run_carrier(dag_run_carrier) # Extract the parent task span context (the stable span ID stored in ti_carrier). 
- parent_task_span_ctx = otel_trace.get_current_span( + parent_task_span_ctx = trace.get_current_span( context=TraceContextTextMapPropagator().extract(ti_carrier) ).get_span_context() @@ -4722,6 +4727,111 @@ def test_operator_failures_metrics_emitted(self, create_runtime_ti, mock_supervi mock_stats.incr.assert_any_call("ti_failures", tags=stats_tags) +class TestDetailSpan: + """Tests for the detail_span decorator / context manager.""" + + def _make_provider_with_detail_level(self, level: int): + """Return (provider, tracer, carrier) where the carrier encodes the given detail level.""" + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + t = provider.get_tracer("test") + carrier = new_dagrun_trace_carrier(task_span_detail_level=level) + return provider, t, exporter, carrier + + def test_level_1_no_child_span_as_context_manager(self): + """At detail level 1, entering detail_span should not create a real recorded span.""" + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + t = provider.get_tracer("test") + carrier = new_dagrun_trace_carrier(task_span_detail_level=1) + parent_ctx = TraceContextTextMapPropagator().extract(carrier) + + with mock.patch("airflow.sdk.execution_time.task_runner.tracer", t): + with t.start_as_current_span("parent", context=parent_ctx): + with detail_span("child") as span: + assert span is trace.INVALID_SPAN + + # Only the "parent" span should be recorded; no "child". 
+ names = [s.name for s in exporter.get_finished_spans()] + assert "child" not in names + + def test_level_2_creates_child_span_as_context_manager(self): + """At detail level 2, detail_span should create a real recorded child span.""" + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + t = provider.get_tracer("test") + carrier = new_dagrun_trace_carrier(task_span_detail_level=2) + parent_ctx = TraceContextTextMapPropagator().extract(carrier) + + with mock.patch("airflow.sdk.execution_time.task_runner.tracer", t): + with t.start_as_current_span("parent", context=parent_ctx): + with detail_span("child"): + pass + + names = [s.name for s in exporter.get_finished_spans()] + assert "child" in names + + def test_decorator_at_level_1_does_not_create_span(self): + """@detail_span at level 1 should not produce a recorded span.""" + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + t = provider.get_tracer("test") + carrier = new_dagrun_trace_carrier(task_span_detail_level=1) + parent_ctx = TraceContextTextMapPropagator().extract(carrier) + + @detail_span("decorated") + def my_func(): + return 42 + + with mock.patch("airflow.sdk.execution_time.task_runner.tracer", t): + with t.start_as_current_span("parent", context=parent_ctx): + result = my_func() + + assert result == 42 + names = [s.name for s in exporter.get_finished_spans()] + assert "decorated" not in names + + def test_decorator_at_level_2_creates_span_and_preserves_return_value(self): + """@detail_span at level 2 creates a span and the wrapped function's return value is preserved.""" + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + t = provider.get_tracer("test") + carrier = new_dagrun_trace_carrier(task_span_detail_level=2) + parent_ctx = TraceContextTextMapPropagator().extract(carrier) 
+ + @detail_span("decorated") + def my_func(x): + return x * 2 + + with mock.patch("airflow.sdk.execution_time.task_runner.tracer", t): + with t.start_as_current_span("parent", context=parent_ctx): + result = my_func(7) + + assert result == 14 + names = [s.name for s in exporter.get_finished_spans()] + assert "decorated" in names + + def test_exception_in_context_manager_propagates(self): + """Exceptions inside `with detail_span(...)` propagate normally.""" + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + t = provider.get_tracer("test") + carrier = new_dagrun_trace_carrier(task_span_detail_level=2) + parent_ctx = TraceContextTextMapPropagator().extract(carrier) + + with mock.patch("airflow.sdk.execution_time.task_runner.tracer", t): + with t.start_as_current_span("parent", context=parent_ctx): + with pytest.raises(ValueError, match="boom"): + with detail_span("child"): + raise ValueError("boom") + + def test_dag_add_result(create_runtime_ti, mock_supervisor_comms): with DAG(dag_id="test_dag_add_result") as dag: task = PythonOperator(task_id="t", python_callable=lambda: 123)