Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def _emit_task_span(ti, state):
"airflow.task_instance.try_number": ti.try_number,
"airflow.task_instance.map_index": ti.map_index if ti.map_index is not None else -1,
"airflow.task_instance.state": state,
"airflow.task_instance.id": ti.id,
"airflow.task_instance.id": str(ti.id),
}
)
status_code = StatusCode.OK if state == TaskInstanceState.SUCCESS else StatusCode.ERROR
Expand Down Expand Up @@ -524,7 +524,10 @@ def _create_ti_state_update_query_and_update_state(
ti_patch_payload.outlet_events,
session,
)
_emit_task_span(ti, state=updated_state)
try:
_emit_task_span(ti, state=updated_state)
except Exception:
log.warning("Failed to emit task span", exc_info=True)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down
15 changes: 12 additions & 3 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@

from airflow._shared.observability.metrics.dual_stats_manager import DualStatsManager
from airflow._shared.observability.metrics.stats import Stats
from airflow._shared.observability.traces import new_dagrun_trace_carrier, override_ids
from airflow._shared.observability.traces import (
TASK_SPAN_DETAIL_LEVEL_KEY,
new_dagrun_trace_carrier,
override_ids,
)
from airflow._shared.timezones import timezone
from airflow.callbacks.callback_requests import DagCallbackRequest, DagRunContext
from airflow.configuration import conf as airflow_conf
Expand Down Expand Up @@ -376,7 +380,9 @@ def __init__(
self.triggered_by = triggered_by
self.triggering_user_name = triggering_user_name
self.scheduled_by_job_id = None
self.context_carrier: dict[str, str] = new_dagrun_trace_carrier()
self.context_carrier: dict[str, str] = new_dagrun_trace_carrier(
task_span_detail_level=self.conf.get(TASK_SPAN_DETAIL_LEVEL_KEY, None)
)

if not isinstance(partition_key, str | None):
raise ValueError(
Expand Down Expand Up @@ -1268,7 +1274,10 @@ def recalculate(self) -> _UnfinishedStates:
self.data_interval_end,
)
session.flush()
self._emit_dagrun_span(state=self.state)
try:
self._emit_dagrun_span(state=self.state)
except Exception:
self.log.warning("Failed to emit dag run span", exc_info=True)

self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis)
self._emit_duration_stats_for_finished_state()
Expand Down
10 changes: 8 additions & 2 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@
from airflow import settings
from airflow._shared.observability.metrics.dual_stats_manager import DualStatsManager
from airflow._shared.observability.metrics.stats import Stats
from airflow._shared.observability.traces import new_dagrun_trace_carrier, new_task_run_carrier
from airflow._shared.observability.traces import (
TASK_SPAN_DETAIL_LEVEL_KEY,
new_dagrun_trace_carrier,
new_task_run_carrier,
)
from airflow._shared.timezones import timezone
from airflow.assets.manager import asset_manager
from airflow.configuration import conf
Expand Down Expand Up @@ -399,7 +403,9 @@ def clear_task_instances(
# Always update clear_number and queued_at when clearing tasks, regardless of state
dr.clear_number += 1
dr.queued_at = timezone.utcnow()
dr.context_carrier = new_dagrun_trace_carrier()
dr.context_carrier = new_dagrun_trace_carrier(
task_span_detail_level=dr.conf.get(TASK_SPAN_DETAIL_LEVEL_KEY) if dr.conf else None
)

_recalculate_dagrun_queued_at_deadlines(dr, dr.queued_at, session)

Expand Down
83 changes: 73 additions & 10 deletions airflow-core/tests/integration/otel/test_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def wait_for_otel_collector(host: str, port: int, timeout: int = 120) -> None:
)


def unpause_trigger_dag_and_get_run_id(dag_id: str) -> str:
def unpause_trigger_dag_and_get_run_id(dag_id: str, conf: dict | None = None) -> str:
unpause_command = ["airflow", "dags", "unpause", dag_id]

# Unpause the dag using the cli.
Expand All @@ -106,6 +106,11 @@ def unpause_trigger_dag_and_get_run_id(dag_id: str) -> str:
execution_date.isoformat(),
]

if conf:
import json

trigger_command += ["--conf", json.dumps(conf)]

# Trigger the dag using the cli.
subprocess.run(trigger_command, check=True, env=os.environ.copy())

Expand Down Expand Up @@ -436,7 +441,65 @@ def test_export_metrics_during_process_shutdown(self, capfd):
assert set(metrics_to_check).issubset(metrics_dict.keys())

@pytest.mark.execution_timeout(90)
def test_dag_execution_succeeds(self, capfd):
@pytest.mark.parametrize(
("task_span_detail_level", "expected_hierarchy"),
[
pytest.param(
None,
{
"dag_run.otel_test_dag": None,
"sub_span1": "worker.task1",
"task_run.task1": "dag_run.otel_test_dag",
"worker.task1": "task_run.task1",
},
id="default_spans",
),
pytest.param(
2,
{
"hook.on_starting": "startup",
"get_bundle": "parse",
"initialize": "parse",
"_verify_bundle_access": "parse",
"make BundleDagBag": "parse",
"parse": "startup",
"get_template_context": "startup",
"startup": "worker.task1",
"delete xcom": "run",
"get_template_env": "_prepare",
"render_templates": "_prepare",
"_serialize_rendered_fields": "_prepare",
"set_rendered_fields": "_prepare",
"set_rendered_map_index": "_prepare",
"_validate_task_inlets_and_outlets": "_prepare",
"listener.on_task_instance_running": "_prepare",
"_prepare": "run",
"prepare context": "_execute_task",
"pre-execute": "_execute_task",
"on_execute_callback": "_execute_task",
"execute": "_execute_task",
"post_execute_hook": "_execute_task",
"_execute_task": "run",
"render_map_index": "run",
"push xcom": "run",
"handle success": "run",
"handle_extra_links": "finalize",
"success_callback": "finalize",
"listener.on_task_instance_success": "finalize",
"listener.before_stopping": "finalize",
"finalize": "run",
"run": "worker.task1",
"close_socket": "worker.task1",
"sub_span1": "prepare context",
"dag_run.otel_test_dag": None,
"task_run.task1": "dag_run.otel_test_dag",
"worker.task1": "task_run.task1",
},
id="detail_spans",
),
],
)
def test_dag_execution_succeeds(self, capfd, task_span_detail_level, expected_hierarchy):
"""The same scheduler will start and finish the dag processing."""
scheduler_process = None
apiserver_process = None
Expand All @@ -452,7 +515,13 @@ def test_dag_execution_succeeds(self, capfd):

assert dag is not None

run_id = unpause_trigger_dag_and_get_run_id(dag_id=dag_id)
conf = None
if task_span_detail_level is not None:
from airflow_shared.observability.traces import TASK_SPAN_DETAIL_LEVEL_KEY

conf = {TASK_SPAN_DETAIL_LEVEL_KEY: task_span_detail_level}

run_id = unpause_trigger_dag_and_get_run_id(dag_id=dag_id, conf=conf)

# Skip the span_status check.
wait_for_dag_run(dag_id=dag_id, run_id=run_id, max_wait_time=90)
Expand Down Expand Up @@ -490,7 +559,6 @@ def test_dag_execution_succeeds(self, capfd):
service_name = os.environ.get("OTEL_SERVICE_NAME", "test")
r = requests.get(f"http://{host}:16686/api/traces?service={service_name}")
data = r.json()

trace = data["data"][-1]
spans = trace["spans"]

Expand All @@ -507,12 +575,7 @@ def get_parent_span_id(span):
return nested

nested = get_span_hierarchy()
assert nested == {
"dag_run.otel_test_dag": None,
"sub_span1": "worker.task1",
"task_run.task1": "dag_run.otel_test_dag",
"worker.task1": "task_run.task1",
}
assert nested == expected_hierarchy

def start_scheduler(self, capture_output: bool = False):
stdout = None if capture_output else subprocess.DEVNULL
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/tests/unit/listeners/test_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_listener_suppresses_exceptions(create_task_instance, session, cap_struc

ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED)
ti.run()
assert "error calling listener" in cap_structlog
assert "error calling on_task_instance_success listener" in cap_structlog


@provide_session
Expand Down
20 changes: 20 additions & 0 deletions airflow-core/tests/unit/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -3562,3 +3562,23 @@ def test_emit_dagrun_span_with_none_or_empty_carrier(self, dag_maker, session, c
assert spans[0].name == f"dag_run.{dr.dag_id}"
else:
assert len(spans) == 0

@pytest.mark.db_test
def test_context_carrier_includes_detail_level_from_conf(self, dag_maker):
    """DagRun created with TASK_SPAN_DETAIL_LEVEL_KEY in conf should encode the level in trace state."""
    from opentelemetry import trace
    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

    from airflow._shared.observability.traces import (
        TASK_SPAN_DETAIL_LEVEL_KEY,
        get_task_span_detail_level,
    )

    with dag_maker("test_tracing_detail_level"):
        EmptyOperator(task_id="t1")
    dag_run = dag_maker.create_dagrun(conf={TASK_SPAN_DETAIL_LEVEL_KEY: 2})

    # The carrier is a plain dict; extracting it yields a context whose
    # current span carries the trace state written at dag-run creation.
    extracted_ctx = TraceContextTextMapPropagator().extract(dag_run.context_carrier)
    carrier_span = trace.get_current_span(extracted_ctx)
    assert get_task_span_detail_level(carrier_span) == 2
22 changes: 22 additions & 0 deletions airflow-core/tests/unit/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3560,3 +3560,25 @@ def test_clear_task_instances_resets_context_carrier(dag_maker, session):

assert ti.context_carrier["traceparent"] != original_ti_traceparent
assert dag_run.context_carrier["traceparent"] != original_dr_traceparent


@pytest.mark.db_test
def test_clear_task_instances_preserves_detail_level(dag_maker, session):
    """clear_task_instances should produce a new context_carrier that keeps the detail level from dag run conf."""
    from airflow._shared.observability.traces import (
        TASK_SPAN_DETAIL_LEVEL_KEY,
        get_task_span_detail_level,
    )

    with dag_maker("test_clear_preserves_level"):
        EmptyOperator(task_id="t1")
    run = dag_maker.create_dagrun(conf={TASK_SPAN_DETAIL_LEVEL_KEY: 2})
    task_instance = run.get_task_instance("t1", session=session)
    task_instance.state = TaskInstanceState.SUCCESS
    session.flush()

    # Clearing regenerates the dag run's context carrier; the configured
    # detail level must survive that regeneration.
    clear_task_instances([task_instance], session)

    regenerated_ctx = TraceContextTextMapPropagator().extract(run.context_carrier)
    carrier_span = otel_trace.get_current_span(regenerated_ctx)
    assert get_task_span_detail_level(carrier_span) == 2
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import NonRecordingSpan, SpanContext, TraceFlags
from opentelemetry.trace import NonRecordingSpan, Span, SpanContext, TraceFlags, TraceState
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

if TYPE_CHECKING:
Expand Down Expand Up @@ -56,14 +56,20 @@ def generate_trace_id(self):
return super().generate_trace_id()


def new_dagrun_trace_carrier() -> dict[str, str]:
# Key used both as the dag-run conf entry and as the W3C ``tracestate`` entry
# that carries the requested task-span detail level across processes.
TASK_SPAN_DETAIL_LEVEL_KEY = "airflow/task_span_detail_level"
# Level assumed when the key is absent or its value is not a valid integer.
DEFAULT_TASK_SPAN_DETAIL_LEVEL = 1


def new_dagrun_trace_carrier(task_span_detail_level=None) -> dict[str, str]:
"""Generate a fresh W3C traceparent carrier without creating a recordable span."""
gen = RandomIdGenerator()
trace_state_entries = build_trace_state_entries(task_span_detail_level)
span_ctx = SpanContext(
trace_id=gen.generate_trace_id(),
span_id=gen.generate_span_id(),
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(entries=trace_state_entries),
)
ctx = trace.set_span_in_context(NonRecordingSpan(span_ctx))
carrier: dict[str, str] = {}
Expand All @@ -82,6 +88,28 @@ def new_task_run_carrier(dag_run_context_carrier):
return carrier


def build_trace_state_entries(task_span_detail_level) -> list[tuple[str, str]]:
    """Build W3C ``tracestate`` entries encoding the task span detail level.

    :param task_span_detail_level: Requested detail level; anything coercible
        to ``int``. ``None``, non-numeric values, and ``0`` produce no entry,
        leaving the reader to fall back to its default level.
    :return: ``[(TASK_SPAN_DETAIL_LEVEL_KEY, "<level>")]`` for a truthy
        integer level, otherwise an empty list.
    """
    if task_span_detail_level is None:
        return []
    try:
        level = int(task_span_detail_level)
    except (TypeError, ValueError):
        # Only conversion failures mean "invalid config"; anything else
        # should propagate rather than be silently swallowed.
        return []
    if not level:
        # Level 0 (or other falsy result) means "use the default" — omit the entry.
        return []
    return [(TASK_SPAN_DETAIL_LEVEL_KEY, str(level))]


def get_task_span_detail_level(span: Span) -> int:
    """Read the task span detail level from *span*'s W3C trace state.

    :param span: Span whose ``SpanContext.trace_state`` may carry a
        ``TASK_SPAN_DETAIL_LEVEL_KEY`` entry (written by
        ``build_trace_state_entries``).
    :return: The level as an ``int``; ``DEFAULT_TASK_SPAN_DETAIL_LEVEL`` when
        the entry is absent or its value is not a valid integer.
    """
    trace_state = span.get_span_context().trace_state
    raw = trace_state.get(TASK_SPAN_DETAIL_LEVEL_KEY, default=DEFAULT_TASK_SPAN_DETAIL_LEVEL)
    try:
        # Guard only the int() conversion: a bad configured value should warn
        # and fall back, while unrelated errors must still surface.
        return int(raw)
    except (TypeError, ValueError):
        log.warning("%s config in dag run conf must be integer.", TASK_SPAN_DETAIL_LEVEL_KEY)
        return DEFAULT_TASK_SPAN_DETAIL_LEVEL


@contextmanager
def override_ids(trace_id, span_id, ctx=None):
ctx = context.set_value(OVERRIDE_TRACE_ID_KEY, trace_id, context=ctx)
Expand Down
Loading
Loading