Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# under the License.
from __future__ import annotations

from collections.abc import Iterable
import time
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, cast

from airflow.providers.common.compat.sdk import BaseOperator
Expand All @@ -29,6 +30,13 @@

from airflow.providers.common.compat.sdk import Context

# Default set of Salesforce error ``statusCode`` values regarded as transient:
# they point at a temporary server-side condition, not at a permanent problem
# with the data, so a record failing with one of them may reasonably be
# re-submitted after a short delay.
_DEFAULT_TRANSIENT_ERROR_CODES: frozenset[str] = frozenset(
    ("UNABLE_TO_LOCK_ROW", "API_TEMPORARILY_UNAVAILABLE")
)


class SalesforceBulkOperator(BaseOperator):
"""
Expand All @@ -46,8 +54,18 @@ class SalesforceBulkOperator(BaseOperator):
:param batch_size: number of records to assign for each batch in the job
:param use_serial: Process batches in serial mode
:param salesforce_conn_id: The :ref:`Salesforce Connection id <howto/connection:salesforce>`.
:param max_retries: Number of times to re-submit records that failed with a
transient error code such as ``UNABLE_TO_LOCK_ROW`` or
``API_TEMPORARILY_UNAVAILABLE``. Set to ``0`` (the default) to disable
automatic retries.
:param retry_delay: Seconds to wait before each retry attempt. Defaults to ``5``.
:param transient_error_codes: Collection of Salesforce error ``statusCode``
values that should trigger a retry. Defaults to
``{"UNABLE_TO_LOCK_ROW", "API_TEMPORARILY_UNAVAILABLE"}``.
"""

template_fields: Sequence[str] = ("object_name", "payload", "external_id_field")

available_operations = ("insert", "update", "upsert", "delete", "hard_delete")

def __init__(
Expand All @@ -60,6 +78,9 @@ def __init__(
batch_size: int = 10000,
use_serial: bool = False,
salesforce_conn_id: str = "salesforce_default",
max_retries: int = 0,
retry_delay: float = 5.0,
transient_error_codes: Iterable[str] = _DEFAULT_TRANSIENT_ERROR_CODES,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -70,6 +91,9 @@ def __init__(
self.batch_size = batch_size
self.use_serial = use_serial
self.salesforce_conn_id = salesforce_conn_id
self.max_retries = max_retries
self.retry_delay = retry_delay
self.transient_error_codes = frozenset(transient_error_codes)
self._validate_inputs()

def _validate_inputs(self) -> None:
Expand All @@ -82,6 +106,65 @@ def _validate_inputs(self) -> None:
f"Available operations are {self.available_operations}."
)

def _run_operation(self, bulk: SFBulkHandler, payload: list) -> list:
    """
    Submit *payload* through the configured Bulk API operation.

    The handler for ``self.object_name`` is looked up on *bulk*, the method named by
    ``self.operation`` is called on it, and whatever it yields is materialised into a list.

    :param bulk: Bulk handler the object-specific handler is fetched from.
    :param payload: Records to submit as the ``data`` argument.
    :return: The operation's results as a list.
    """
    sobject_handler = bulk.__getattr__(self.object_name)

    call_kwargs: dict = {"data": payload}
    if self.operation == "upsert":
        # Only upsert takes the external id field used to match existing records.
        call_kwargs["external_id_field"] = self.external_id_field
    call_kwargs["batch_size"] = self.batch_size
    call_kwargs["use_serial"] = self.use_serial

    submit = getattr(sobject_handler, self.operation)
    return list(submit(**call_kwargs))

def _retry_transient_failures(self, bulk: SFBulkHandler, payload: list, result: list) -> list:
    """
    Re-submit records that failed with a transient error, up to ``max_retries`` times.

    Results are matched to records purely by position: ``result[i]`` is treated as the
    outcome of ``payload[i]``, and each retry outcome is written back into the slot of the
    record it was submitted for. This relies on the Bulk API returning results in the same
    order as the submitted records.

    :param bulk: Bulk handler passed through to ``_run_operation`` for every retry.
    :param payload: Records originally submitted; indexed to rebuild each retry batch.
    :param result: Per-record results of the initial submission. The list is not modified.
    :return: A new list holding the latest result for every record.
    :raises ValueError: If a retry returns fewer results than the records submitted, because
        the outcomes can then no longer be matched to records by position.
    """
    final = list(result)

    def _is_transient(outcome) -> bool:
        """Tell whether *outcome* is a failure carrying at least one transient code."""
        if outcome.get("success"):
            return False
        # ``errors`` may be missing or ``None``; either way there is no code to match.
        codes = {error.get("statusCode") for error in outcome.get("errors") or ()}
        return not codes.isdisjoint(self.transient_error_codes)

    # Only records retried in the previous round can still be transient failures (all other
    # slots are unchanged), so each round narrows this set instead of re-scanning everything.
    pending = range(len(final))

    for attempt in range(1, self.max_retries + 1):
        pending = [idx for idx in pending if _is_transient(final[idx])]
        if not pending:
            break

        self.log.warning(
            "Salesforce Bulk API %s on %s: retrying %d record(s) with transient errors "
            "(attempt %d/%d, waiting %.1f second(s)).",
            self.operation,
            self.object_name,
            len(pending),
            attempt,
            self.max_retries,
            self.retry_delay,
        )
        time.sleep(self.retry_delay)

        retry_result = self._run_operation(bulk, [payload[idx] for idx in pending])
        if len(retry_result) < len(pending):
            # Fail with a clear message instead of an IndexError during the write-back.
            raise ValueError(
                f"Salesforce Bulk API {self.operation} on {self.object_name} returned "
                f"{len(retry_result)} result(s) for {len(pending)} retried record(s); "
                "results cannot be matched to records."
            )

        for idx, outcome in zip(pending, retry_result):
            final[idx] = outcome

    # ``pending`` is empty after an early exit; otherwise it holds the last retried records.
    if pending and self.max_retries > 0:
        unresolved = sum(1 for idx in pending if _is_transient(final[idx]))
        if unresolved:
            self.log.warning(
                "Salesforce Bulk API %s on %s: %d record(s) still failing with transient "
                "errors after %d retry attempt(s).",
                self.operation,
                self.object_name,
                unresolved,
                self.max_retries,
            )

    return final

def execute(self, context: Context):
"""
Make an HTTP request to Salesforce Bulk API.
Expand All @@ -93,30 +176,10 @@ def execute(self, context: Context):
conn = sf_hook.get_conn()
bulk: SFBulkHandler = cast("SFBulkHandler", conn.__getattr__("bulk"))

result: Iterable = []
if self.operation == "insert":
result = bulk.__getattr__(self.object_name).insert(
data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial
)
elif self.operation == "update":
result = bulk.__getattr__(self.object_name).update(
data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial
)
elif self.operation == "upsert":
result = bulk.__getattr__(self.object_name).upsert(
data=self.payload,
external_id_field=self.external_id_field,
batch_size=self.batch_size,
use_serial=self.use_serial,
)
elif self.operation == "delete":
result = bulk.__getattr__(self.object_name).delete(
data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial
)
elif self.operation == "hard_delete":
result = bulk.__getattr__(self.object_name).hard_delete(
data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial
)
result = self._run_operation(bulk, self.payload)

if self.max_retries > 0:
result = self._retry_transient_failures(bulk, self.payload, result)

if self.do_xcom_push and result:
return result
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest import mock

import pytest

from airflow.providers.salesforce.operators.bulk import SalesforceBulkOperator


def _make_op(**kwargs):
    """Build a ``SalesforceBulkOperator`` from test defaults, with *kwargs* overriding them."""
    params = {
        "task_id": "test_task",
        "operation": "insert",
        "object_name": "Contact",
        "payload": [{"FirstName": "Ada"}, {"FirstName": "Grace"}],
        **kwargs,
    }
    return SalesforceBulkOperator(**params)


def _transient_failure(status_code="UNABLE_TO_LOCK_ROW"):
    """Return a failed record result whose single error carries *status_code*."""
    error = {"statusCode": status_code, "message": "locked", "fields": []}
    return {"success": False, "errors": [error]}


def _permanent_failure():
    """Return a failed record result with a non-transient ``REQUIRED_FIELD_MISSING`` error."""
    error = {
        "statusCode": "REQUIRED_FIELD_MISSING",
        "message": "missing",
        "fields": ["Name"],
    }
    return {"success": False, "errors": [error]}


def _success():
    """Return a fresh successful record result with an empty error list."""
    return dict(success=True, errors=[])


class TestSalesforceBulkOperatorRetry:
    """Tests for the transient-error retry behaviour of ``SalesforceBulkOperator``."""

    # ``time.sleep`` as reached through the operator module; patched so no test really waits.
    SLEEP_TARGET = "airflow.providers.salesforce.operators.bulk.time.sleep"

    def test_no_retry_when_max_retries_zero(self):
        """With ``max_retries=0`` the operation is submitted exactly once."""
        op = _make_op(max_retries=0)
        assert op.max_retries == 0

        bulk_mock = mock.MagicMock()
        bulk_mock.Contact.insert.return_value = [_success(), _success()]

        with mock.patch("airflow.providers.salesforce.operators.bulk.SalesforceHook") as hook_cls:
            # ``conn.__getattr__("bulk")`` on a mock resolves to its ``bulk`` child attribute.
            # ``__getattr__`` itself is a bound method on mocks and cannot be given a
            # ``return_value``, so the child attribute is configured instead.
            hook_cls.return_value.get_conn.return_value.bulk = bulk_mock
            result = op.execute(context={})

        assert result == [_success(), _success()]
        bulk_mock.Contact.insert.assert_called_once()

    def test_transient_failure_is_retried(self):
        """Only the transiently failed record is re-submitted and its slot is updated."""
        op = _make_op(max_retries=2, retry_delay=0)
        initial = [_transient_failure(), _success()]

        # The mock only ever serves retry submissions; the initial result is passed in directly.
        run_mock = mock.MagicMock(return_value=[_success()])

        with mock.patch.object(op, "_run_operation", run_mock), mock.patch(self.SLEEP_TARGET):
            final = op._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}, {"FirstName": "Grace"}],
                result=initial,
            )

        assert final == [_success(), _success()]
        run_mock.assert_called_once_with(mock.ANY, [{"FirstName": "Ada"}])

    def test_permanent_failure_is_not_retried(self):
        """A failure without a transient status code never triggers a re-submission."""
        op = _make_op(max_retries=3, retry_delay=0)
        result = [_permanent_failure(), _success()]

        run_mock = mock.MagicMock()

        with mock.patch.object(op, "_run_operation", run_mock):
            final = op._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}, {"FirstName": "Grace"}],
                result=result,
            )

        run_mock.assert_not_called()
        assert final[0] == _permanent_failure()

    def test_retries_stop_after_max_retries(self):
        """A record that keeps failing is re-submitted ``max_retries`` times, then left failed."""
        op = _make_op(max_retries=2, retry_delay=0)

        always_transient = [_transient_failure()]
        run_mock = mock.MagicMock(return_value=always_transient)

        with mock.patch.object(op, "_run_operation", run_mock), mock.patch(self.SLEEP_TARGET):
            final = op._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}],
                result=always_transient,
            )

        assert run_mock.call_count == 2
        assert final[0]["success"] is False

    def test_retry_delay_is_respected(self):
        """The configured ``retry_delay`` is what gets passed to ``time.sleep``."""
        op = _make_op(max_retries=1, retry_delay=30.0)

        run_mock = mock.MagicMock(return_value=[_success()])

        with mock.patch.object(op, "_run_operation", run_mock), mock.patch(self.SLEEP_TARGET) as sleep_mock:
            op._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}],
                result=[_transient_failure()],
            )

        sleep_mock.assert_called_once_with(30.0)

    def test_custom_transient_error_codes(self):
        """User-supplied codes replace the defaults and trigger a retry."""
        op = _make_op(max_retries=1, retry_delay=0, transient_error_codes=["MY_CUSTOM_ERROR"])
        assert op.transient_error_codes == frozenset({"MY_CUSTOM_ERROR"})

        custom_failure = {"success": False, "errors": [{"statusCode": "MY_CUSTOM_ERROR", "message": "custom"}]}
        run_mock = mock.MagicMock(return_value=[_success()])

        with mock.patch.object(op, "_run_operation", run_mock), mock.patch(self.SLEEP_TARGET):
            final = op._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}],
                result=[custom_failure],
            )

        run_mock.assert_called_once()
        assert final[0] == _success()

    def test_api_temporarily_unavailable_is_retried(self):
        """``API_TEMPORARILY_UNAVAILABLE`` is one of the default transient codes."""
        op = _make_op(max_retries=1, retry_delay=0)
        failure = _transient_failure("API_TEMPORARILY_UNAVAILABLE")
        run_mock = mock.MagicMock(return_value=[_success()])

        with mock.patch.object(op, "_run_operation", run_mock), mock.patch(self.SLEEP_TARGET):
            final = op._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}],
                result=[failure],
            )

        run_mock.assert_called_once()
        assert final[0] == _success()

    def test_mixed_failures_only_retries_transient(self):
        """In a mixed batch only the transient failure is re-submitted; other slots are untouched."""
        op = _make_op(max_retries=1, retry_delay=0)
        payload = [{"FirstName": "A"}, {"FirstName": "B"}, {"FirstName": "C"}]
        initial = [_transient_failure(), _permanent_failure(), _success()]

        run_mock = mock.MagicMock(return_value=[_success()])

        with mock.patch.object(op, "_run_operation", run_mock), mock.patch(self.SLEEP_TARGET):
            final = op._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=payload,
                result=initial,
            )

        run_mock.assert_called_once_with(mock.ANY, [{"FirstName": "A"}])
        assert final[0] == _success()
        assert final[1] == _permanent_failure()
        assert final[2] == _success()