Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
a600fa4
Add tests for SalesforceBulkOperator transient-error retry
nagasrisai Apr 1, 2026
cd74153
Add transient-error retry to SalesforceBulkOperator
nagasrisai Apr 1, 2026
516eaa9
Merge branch 'main' into feat/salesforce-bulk-transient-retry
nagasrisai Apr 1, 2026
0d3d665
Merge branch 'main' into feat/salesforce-bulk-transient-retry
nagasrisai Apr 1, 2026
3250d1a
Fix lint: remove unused pytest import and dead variable assignments
nagasrisai Apr 2, 2026
cd7b6b0
Add input validation for max_retries, retry_delay, and transient_erro…
nagasrisai Apr 2, 2026
a117bcf
Merge branch 'main' into feat/salesforce-bulk-transient-retry
nagasrisai Apr 2, 2026
9ba6b41
Fix IndentationError in _validate_inputs: use consistent 8-space indent
nagasrisai Apr 2, 2026
419c916
Rename retry_delay → bulk_retry_delay to avoid collision with BaseOpe…
nagasrisai Apr 2, 2026
774c3e0
Update tests: retry_delay → bulk_retry_delay
nagasrisai Apr 2, 2026
38c9dcc
Fix: correct mock chain for hook conn.bulk; ruff format long dicts
nagasrisai Apr 2, 2026
2ba32a3
Fix: remove list() from _run_operation, add to retry call; fix ruff f…
nagasrisai Apr 2, 2026
6f3d72d
Apply ruff format: split long lines, wrap method signature and retry …
nagasrisai Apr 2, 2026
351fea5
Apply ruff format: split long dicts in test helpers
nagasrisai Apr 2, 2026
a448191
Fix ruff: reformat with line-length=110 (Airflow project standard)
nagasrisai Apr 2, 2026
49c75ce
Fix mypy: cast(list,...) in _run_operation; fix ruff: use line-length…
nagasrisai Apr 2, 2026
4896b4f
Fix ruff: use cast("list",...) string-quoted form (TC rule)
nagasrisai Apr 2, 2026
b97c8ac
Merge branch 'main' into feat/salesforce-bulk-transient-retry
nagasrisai Apr 2, 2026
400168f
Merge branch 'main' into feat/salesforce-bulk-transient-retry
nagasrisai Apr 2, 2026
bf41e32
Merge branch 'main' into feat/salesforce-bulk-transient-retry
nagasrisai Apr 3, 2026
6e067ea
Merge branch 'main' into feat/salesforce-bulk-transient-retry
nagasrisai Apr 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 101 additions & 24 deletions providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import time
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, cast

Expand All @@ -29,6 +30,13 @@

from airflow.providers.common.compat.sdk import Context

# Salesforce error ``statusCode`` values that indicate a transient server-side
# condition rather than a permanent data problem. Records that fail with one of
# these codes can reasonably be re-submitted after a short delay.
# Stored as a frozenset so the shared default cannot be mutated.
_DEFAULT_TRANSIENT_ERROR_CODES: frozenset[str] = frozenset(
    {"UNABLE_TO_LOCK_ROW", "API_TEMPORARILY_UNAVAILABLE"}
)


class SalesforceBulkOperator(BaseOperator):
"""
Expand All @@ -46,6 +54,14 @@ class SalesforceBulkOperator(BaseOperator):
:param batch_size: number of records to assign for each batch in the job
:param use_serial: Process batches in serial mode
:param salesforce_conn_id: The :ref:`Salesforce Connection id <howto/connection:salesforce>`.
:param max_retries: Number of times to re-submit records that failed with a
transient error code such as ``UNABLE_TO_LOCK_ROW`` or
``API_TEMPORARILY_UNAVAILABLE``. Set to ``0`` (the default) to disable
automatic retries.
:param bulk_retry_delay: Seconds to wait before each retry attempt within the Bulk API retry loop. Defaults to ``5``.
:param transient_error_codes: Collection of Salesforce error ``statusCode``
values that should trigger a retry. Defaults to
``{"UNABLE_TO_LOCK_ROW", "API_TEMPORARILY_UNAVAILABLE"}``.
"""

template_fields: Sequence[str] = ("object_name", "payload", "external_id_field")
Expand All @@ -62,6 +78,9 @@ def __init__(
batch_size: int = 10000,
use_serial: bool = False,
salesforce_conn_id: str = "salesforce_default",
max_retries: int = 0,
bulk_retry_delay: float = 5.0,
transient_error_codes: Iterable[str] = _DEFAULT_TRANSIENT_ERROR_CODES,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -72,9 +91,25 @@ def __init__(
self.batch_size = batch_size
self.use_serial = use_serial
self.salesforce_conn_id = salesforce_conn_id
self.max_retries = max_retries
self.bulk_retry_delay = bulk_retry_delay
if isinstance(transient_error_codes, str):
raise ValueError(
"'transient_error_codes' must be a non-string iterable of strings, "
f"got {transient_error_codes!r}. Wrap it in a list: [{transient_error_codes!r}]"
)
self.transient_error_codes = frozenset(transient_error_codes)
self._validate_inputs()

def _validate_inputs(self) -> None:
if self.max_retries < 0:
raise ValueError(f"'max_retries' must be a non-negative integer, got {self.max_retries!r}.")

if self.bulk_retry_delay < 0:
raise ValueError(
f"'bulk_retry_delay' must be a non-negative number, got {self.bulk_retry_delay!r}."
)

if not self.object_name:
raise ValueError("The required parameter 'object_name' cannot have an empty value.")

Expand All @@ -84,6 +119,68 @@ def _validate_inputs(self) -> None:
f"Available operations are {self.available_operations}."
)

def _run_operation(self, bulk: SFBulkHandler, payload: list) -> list:
    """
    Dispatch *payload* to the Bulk API method named by ``self.operation``.

    :param bulk: Bulk handler from which the handler for ``self.object_name`` is looked up.
    :param payload: Records to submit.
    :return: The value returned by the Bulk API call, typed as a list.
    """
    sobject_handler = bulk.__getattr__(self.object_name)
    call_kwargs = {
        "data": payload,
        "batch_size": self.batch_size,
        "use_serial": self.use_serial,
    }
    # Only upsert takes the external id field; every other operation shares one signature.
    if self.operation == "upsert":
        call_kwargs["external_id_field"] = self.external_id_field
    bulk_method = getattr(sobject_handler, self.operation)
    return cast("list", bulk_method(**call_kwargs))

def _retry_transient_failures(self, bulk: SFBulkHandler, payload: list, result: list) -> list:
    """
    Re-submit records that failed with a transient error, up to *max_retries* times.

    This relies on Bulk API results being ordered identically to the input payload:
    failed records are located by index and their retry results are written back
    into the same positions.

    :param bulk: Bulk handler passed through to ``_run_operation`` for each retry.
    :param payload: The records originally submitted; indexed to rebuild each retry batch.
    :param result: Per-record results of the initial submission, in payload order.
    :return: A new list holding the latest result for every record; *result* is not modified.
    """
    final = list(result)
    transient_codes = self.transient_error_codes

    def _is_transient(record) -> bool:
        """Return True when *record* failed with at least one configured transient status code."""
        if record.get("success"):
            return False
        # ``errors`` may be missing or explicitly ``None``; treat both as "no error details".
        codes = {err.get("statusCode") for err in record.get("errors") or ()}
        return not codes.isdisjoint(transient_codes)

    # Every record is a candidate on the first pass. Afterwards only records that were just
    # retried can still be transient, so later passes re-examine those positions only.
    candidates: Iterable[int] = range(len(final))

    for attempt in range(1, self.max_retries + 1):
        retry_indices = [i for i in candidates if _is_transient(final[i])]

        if not retry_indices:
            break

        self.log.warning(
            "Salesforce Bulk API %s on %s: retrying %d record(s) with transient errors "
            "(attempt %d/%d, waiting %.1f second(s)).",
            self.operation,
            self.object_name,
            len(retry_indices),
            attempt,
            self.max_retries,
            self.bulk_retry_delay,
        )
        time.sleep(self.bulk_retry_delay)

        retry_result = list(self._run_operation(bulk, [payload[i] for i in retry_indices]))

        # Retry results come back in submission order; map each one to its original slot.
        for list_pos, original_idx in enumerate(retry_indices):
            final[original_idx] = retry_result[list_pos]

        candidates = retry_indices

    return final

def execute(self, context: Context):
"""
Make an HTTP request to Salesforce Bulk API.
Expand All @@ -95,30 +192,10 @@ def execute(self, context: Context):
conn = sf_hook.get_conn()
bulk: SFBulkHandler = cast("SFBulkHandler", conn.__getattr__("bulk"))

result: Iterable = []
if self.operation == "insert":
result = bulk.__getattr__(self.object_name).insert(
data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial
)
elif self.operation == "update":
result = bulk.__getattr__(self.object_name).update(
data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial
)
elif self.operation == "upsert":
result = bulk.__getattr__(self.object_name).upsert(
data=self.payload,
external_id_field=self.external_id_field,
batch_size=self.batch_size,
use_serial=self.use_serial,
)
elif self.operation == "delete":
result = bulk.__getattr__(self.object_name).delete(
data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial
)
elif self.operation == "hard_delete":
result = bulk.__getattr__(self.object_name).hard_delete(
data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial
)
result = self._run_operation(bulk, self.payload)

if self.max_retries > 0:
result = self._retry_transient_failures(bulk, self.payload, result)

if self.do_xcom_push and result:
return result
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest import mock

from airflow.providers.salesforce.operators.bulk import SalesforceBulkOperator


def _make_op(**kwargs):
    """Build a ``SalesforceBulkOperator`` with test defaults; *kwargs* override any default."""
    params = {
        "task_id": "test_task",
        "operation": "insert",
        "object_name": "Contact",
        "payload": [{"FirstName": "Ada"}, {"FirstName": "Grace"}],
        **kwargs,
    }
    return SalesforceBulkOperator(**params)


def _transient_failure(status_code="UNABLE_TO_LOCK_ROW"):
    """Return a fresh failed-record result carrying a single error with *status_code*."""
    error = {"statusCode": status_code, "message": "locked", "fields": []}
    return {"success": False, "errors": [error]}


def _permanent_failure():
    """Return a fresh failed-record result whose error code is ``REQUIRED_FIELD_MISSING``."""
    error = {
        "statusCode": "REQUIRED_FIELD_MISSING",
        "message": "missing",
        "fields": ["Name"],
    }
    return {"success": False, "errors": [error]}


def _success():
    """Return a fresh successful-record result with an empty error list."""
    return dict(success=True, errors=[])


class TestSalesforceBulkOperatorRetry:
    """Tests for the transient-error retry behaviour of ``SalesforceBulkOperator``."""

    def test_no_retry_when_max_retries_zero(self):
        """With ``max_retries=0`` the operation is submitted exactly once."""
        op = _make_op(max_retries=0)
        assert op.max_retries == 0

        bulk_mock = mock.MagicMock()
        bulk_mock.__getattr__("Contact").insert.return_value = [_success(), _success()]

        with mock.patch("airflow.providers.salesforce.operators.bulk.SalesforceHook") as hook_cls:
            hook_cls.return_value.get_conn.return_value.bulk = bulk_mock
            result = op.execute(context={})

        assert result == [_success(), _success()]
        assert bulk_mock.__getattr__("Contact").insert.call_count == 1

    def test_transient_failure_is_retried(self):
        """A record failing with a transient code is re-submitted alone and its result replaced."""
        op = _make_op(max_retries=2, bulk_retry_delay=0)
        payload = [{"FirstName": "Ada"}, {"FirstName": "Grace"}]
        initial_result = [_transient_failure(), _success()]

        # ``_run_operation`` is only invoked for retries: the initial submission's result is
        # handed to ``_retry_transient_failures`` directly. One side effect therefore models
        # exactly one retry; an unexpected second retry would exhaust the mock and fail.
        run_mock = mock.MagicMock(side_effect=[[_success()]])

        with mock.patch.object(op, "_run_operation", run_mock):
            with mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
                final = op._retry_transient_failures(
                    bulk=mock.MagicMock(),
                    payload=payload,
                    result=initial_result,
                )

        assert final[0] == _success()
        assert final[1] == _success()
        run_mock.assert_called_once_with(mock.ANY, [{"FirstName": "Ada"}])

    def test_permanent_failure_is_not_retried(self):
        """Failures whose status code is not transient never trigger a re-submission."""
        op = _make_op(max_retries=3, bulk_retry_delay=0)
        result = [_permanent_failure(), _success()]

        run_mock = mock.MagicMock()

        with mock.patch.object(op, "_run_operation", run_mock):
            final = op._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}, {"FirstName": "Grace"}],
                result=result,
            )

        run_mock.assert_not_called()
        assert final[0] == _permanent_failure()

    def test_retries_stop_after_max_retries(self):
        """A record that keeps failing transiently is retried ``max_retries`` times, then left failed."""
        op = _make_op(max_retries=2, bulk_retry_delay=0)

        always_transient = [_transient_failure()]
        run_mock = mock.MagicMock(return_value=always_transient)

        with mock.patch.object(op, "_run_operation", run_mock):
            with mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
                final = op._retry_transient_failures(
                    bulk=mock.MagicMock(),
                    payload=[{"FirstName": "Ada"}],
                    result=always_transient,
                )

        assert run_mock.call_count == 2
        assert final[0]["success"] is False

    def test_retry_delay_is_respected(self):
        """The configured ``bulk_retry_delay`` is what gets passed to ``time.sleep``."""
        op = _make_op(max_retries=1, bulk_retry_delay=30.0)

        run_mock = mock.MagicMock(return_value=[_success()])

        with mock.patch.object(op, "_run_operation", run_mock):
            with mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep") as sleep_mock:
                op._retry_transient_failures(
                    bulk=mock.MagicMock(),
                    payload=[{"FirstName": "Ada"}],
                    result=[_transient_failure()],
                )

        sleep_mock.assert_called_once_with(30.0)

    def test_custom_transient_error_codes(self):
        """User-supplied ``transient_error_codes`` replace the defaults and drive the retry decision."""
        op = _make_op(max_retries=1, bulk_retry_delay=0, transient_error_codes=["MY_CUSTOM_ERROR"])
        assert op.transient_error_codes == frozenset({"MY_CUSTOM_ERROR"})

        custom_failure = {
            "success": False,
            "errors": [{"statusCode": "MY_CUSTOM_ERROR", "message": "custom"}],
        }
        run_mock = mock.MagicMock(return_value=[_success()])

        with mock.patch.object(op, "_run_operation", run_mock):
            with mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
                final = op._retry_transient_failures(
                    bulk=mock.MagicMock(),
                    payload=[{"FirstName": "Ada"}],
                    result=[custom_failure],
                )

        run_mock.assert_called_once()
        assert final[0] == _success()

    def test_api_temporarily_unavailable_is_retried(self):
        """``API_TEMPORARILY_UNAVAILABLE`` is part of the default transient codes."""
        op = _make_op(max_retries=1, bulk_retry_delay=0)
        failure = _transient_failure("API_TEMPORARILY_UNAVAILABLE")
        run_mock = mock.MagicMock(return_value=[_success()])

        with mock.patch.object(op, "_run_operation", run_mock):
            with mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
                final = op._retry_transient_failures(
                    bulk=mock.MagicMock(),
                    payload=[{"FirstName": "Ada"}],
                    result=[failure],
                )

        run_mock.assert_called_once()
        assert final[0] == _success()

    def test_mixed_failures_only_retries_transient(self):
        """Only the transient failure is re-submitted; permanent failures and successes are untouched."""
        op = _make_op(max_retries=1, bulk_retry_delay=0)
        payload = [{"FirstName": "A"}, {"FirstName": "B"}, {"FirstName": "C"}]
        initial = [_transient_failure(), _permanent_failure(), _success()]

        run_mock = mock.MagicMock(return_value=[_success()])

        with mock.patch.object(op, "_run_operation", run_mock):
            with mock.patch("airflow.providers.salesforce.operators.bulk.time.sleep"):
                final = op._retry_transient_failures(
                    bulk=mock.MagicMock(),
                    payload=payload,
                    result=initial,
                )

        run_mock.assert_called_once_with(mock.ANY, [{"FirstName": "A"}])
        assert final[0] == _success()
        assert final[1] == _permanent_failure()
        assert final[2] == _success()
Loading