diff --git a/providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py b/providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py index 7b5d21030db02..720a3c6ad9d69 100644 --- a/providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py +++ b/providers/salesforce/src/airflow/providers/salesforce/operators/bulk.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import time from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, cast @@ -29,6 +30,13 @@ from airflow.providers.common.compat.sdk import Context +# Salesforce error statusCode values that indicate a transient server-side +# condition rather than a permanent data problem. Records that fail with one of +# these codes can reasonably be re-submitted after a short delay. +_DEFAULT_TRANSIENT_ERROR_CODES: frozenset[str] = frozenset( + {"UNABLE_TO_LOCK_ROW", "API_TEMPORARILY_UNAVAILABLE"} +) + class SalesforceBulkOperator(BaseOperator): """ @@ -46,6 +54,14 @@ class SalesforceBulkOperator(BaseOperator): :param batch_size: number of records to assign for each batch in the job :param use_serial: Process batches in serial mode :param salesforce_conn_id: The :ref:`Salesforce Connection id `. + :param max_retries: Number of times to re-submit records that failed with a + transient error code such as ``UNABLE_TO_LOCK_ROW`` or + ``API_TEMPORARILY_UNAVAILABLE``. Set to ``0`` (the default) to disable + automatic retries. + :param bulk_retry_delay: Seconds to wait before each retry attempt within the Bulk API retry loop. Defaults to ``5``. + :param transient_error_codes: Collection of Salesforce error ``statusCode`` + values that should trigger a retry. Defaults to + ``{"UNABLE_TO_LOCK_ROW", "API_TEMPORARILY_UNAVAILABLE"}``. 
""" template_fields: Sequence[str] = ("object_name", "payload", "external_id_field") @@ -62,6 +78,9 @@ def __init__( batch_size: int = 10000, use_serial: bool = False, salesforce_conn_id: str = "salesforce_default", + max_retries: int = 0, + bulk_retry_delay: float = 5.0, + transient_error_codes: Iterable[str] = _DEFAULT_TRANSIENT_ERROR_CODES, **kwargs, ) -> None: super().__init__(**kwargs) @@ -72,9 +91,25 @@ def __init__( self.batch_size = batch_size self.use_serial = use_serial self.salesforce_conn_id = salesforce_conn_id + self.max_retries = max_retries + self.bulk_retry_delay = bulk_retry_delay + if isinstance(transient_error_codes, str): + raise ValueError( + "'transient_error_codes' must be a non-string iterable of strings, " + f"got {transient_error_codes!r}. Wrap it in a list: [{transient_error_codes!r}]" + ) + self.transient_error_codes = frozenset(transient_error_codes) self._validate_inputs() def _validate_inputs(self) -> None: + if self.max_retries < 0: + raise ValueError(f"'max_retries' must be a non-negative integer, got {self.max_retries!r}.") + + if self.bulk_retry_delay < 0: + raise ValueError( + f"'bulk_retry_delay' must be a non-negative number, got {self.bulk_retry_delay!r}." + ) + if not self.object_name: raise ValueError("The required parameter 'object_name' cannot have an empty value.") @@ -84,6 +119,68 @@ def _validate_inputs(self) -> None: f"Available operations are {self.available_operations}." 
def _run_operation(self, bulk: SFBulkHandler, payload: list) -> list:
    """
    Submit *payload* through the configured Bulk API operation and return the result list.

    :param bulk: Handler whose ``__getattr__(object_name)`` yields the object exposing
        the bulk operations (``insert``, ``upsert``, ...).
    :param payload: Records to send in this call.
    :return: The value returned by the selected operation, typed as a list.
    """
    sobject = bulk.__getattr__(self.object_name)
    call_kwargs = {
        "data": payload,
        "batch_size": self.batch_size,
        "use_serial": self.use_serial,
    }
    if self.operation == "upsert":
        # upsert is the only operation that is given ``external_id_field``.
        call_kwargs["external_id_field"] = self.external_id_field
    return cast("list", getattr(sobject, self.operation)(**call_kwargs))


def _has_transient_error(self, record: dict) -> bool:
    """Return True when *record* failed and at least one of its errors has a transient status code."""
    if record.get("success"):
        return False
    # ``errors`` may be absent or explicitly ``None``; both mean "nothing to inspect".
    status_codes = {error.get("statusCode") for error in record.get("errors") or []}
    return not status_codes.isdisjoint(self.transient_error_codes)


def _retry_transient_failures(self, bulk: SFBulkHandler, payload: list, result: list) -> list:
    """
    Re-submit records that failed with a transient error, up to *max_retries* times.

    Results are matched to records purely by position: the i-th result is taken to
    describe the i-th submitted record, and each retry result is written back into
    the slot of the record that was re-submitted. When that positional mapping
    cannot be trusted (the initial result count differs from the payload size, or a
    retry returns fewer results than records sent) no further retries are made and
    the results gathered so far are returned.

    :param bulk: Handler passed through to ``_run_operation`` for each retry.
    :param payload: Records of the initial submission, in submission order.
    :param result: Per-record results of the initial submission.
    :return: A new list holding the latest known result for every record; *result*
        itself is left unmodified.
    """
    final = list(result)

    if len(final) != len(payload):
        # Without a one-to-one mapping the wrong record could be re-submitted.
        self.log.warning(
            "Salesforce Bulk API %s on %s: got %d result(s) for %d record(s); "
            "results cannot be matched to records, skipping retries.",
            self.operation,
            self.object_name,
            len(final),
            len(payload),
        )
        return final

    for attempt in range(1, self.max_retries + 1):
        retry_indices = [idx for idx, record in enumerate(final) if self._has_transient_error(record)]
        if not retry_indices:
            break

        self.log.warning(
            "Salesforce Bulk API %s on %s: retrying %d record(s) with transient errors "
            "(attempt %d/%d, waiting %.1f second(s)).",
            self.operation,
            self.object_name,
            len(retry_indices),
            attempt,
            self.max_retries,
            self.bulk_retry_delay,
        )
        time.sleep(self.bulk_retry_delay)

        retry_result = list(self._run_operation(bulk, [payload[idx] for idx in retry_indices]))

        if len(retry_result) < len(retry_indices):
            # Some re-submitted records have no result, so positions are ambiguous.
            # Keep the last consistent results rather than failing part-way through.
            self.log.error(
                "Salesforce Bulk API %s on %s: retry returned %d result(s) for %d record(s); "
                "keeping the previous results and stopping retries.",
                self.operation,
                self.object_name,
                len(retry_result),
                len(retry_indices),
            )
            break

        # zip() stops at the shorter sequence, so surplus results are ignored.
        for original_idx, outcome in zip(retry_indices, retry_result):
            final[original_idx] = outcome

    return final
@@ -95,30 +192,10 @@ def execute(self, context: Context): conn = sf_hook.get_conn() bulk: SFBulkHandler = cast("SFBulkHandler", conn.__getattr__("bulk")) - result: Iterable = [] - if self.operation == "insert": - result = bulk.__getattr__(self.object_name).insert( - data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial - ) - elif self.operation == "update": - result = bulk.__getattr__(self.object_name).update( - data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial - ) - elif self.operation == "upsert": - result = bulk.__getattr__(self.object_name).upsert( - data=self.payload, - external_id_field=self.external_id_field, - batch_size=self.batch_size, - use_serial=self.use_serial, - ) - elif self.operation == "delete": - result = bulk.__getattr__(self.object_name).delete( - data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial - ) - elif self.operation == "hard_delete": - result = bulk.__getattr__(self.object_name).hard_delete( - data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial - ) + result = self._run_operation(bulk, self.payload) + + if self.max_retries > 0: + result = self._retry_transient_failures(bulk, self.payload, result) if self.do_xcom_push and result: return result diff --git a/providers/salesforce/tests/unit/salesforce/operators/test_bulk_retry.py b/providers/salesforce/tests/unit/salesforce/operators/test_bulk_retry.py new file mode 100644 index 0000000000000..d373c208a3990 --- /dev/null +++ b/providers/salesforce/tests/unit/salesforce/operators/test_bulk_retry.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest import mock

from airflow.providers.salesforce.operators.bulk import SalesforceBulkOperator

# Patch targets, resolved in the namespace of the module under test.
_SLEEP_TARGET = "airflow.providers.salesforce.operators.bulk.time.sleep"
_HOOK_TARGET = "airflow.providers.salesforce.operators.bulk.SalesforceHook"


def _make_op(**kwargs):
    """Build an operator for a two-record ``Contact`` insert; *kwargs* override the defaults."""
    params = {
        "task_id": "test_task",
        "operation": "insert",
        "object_name": "Contact",
        "payload": [{"FirstName": "Ada"}, {"FirstName": "Grace"}],
        **kwargs,
    }
    return SalesforceBulkOperator(**params)


def _transient_failure(status_code="UNABLE_TO_LOCK_ROW"):
    """Return a failed record result whose single error carries *status_code*."""
    error = {"statusCode": status_code, "message": "locked", "fields": []}
    return {"success": False, "errors": [error]}


def _permanent_failure():
    """Return a failed record result with a status code outside the default transient set."""
    error = {"statusCode": "REQUIRED_FIELD_MISSING", "message": "missing", "fields": ["Name"]}
    return {"success": False, "errors": [error]}


def _success():
    """Return a successful record result."""
    return {"success": True, "errors": []}


class TestSalesforceBulkOperatorRetry:
    """Tests for the transient-error retry loop of ``SalesforceBulkOperator``."""

    def test_no_retry_when_max_retries_zero(self):
        operator = _make_op(max_retries=0)
        assert operator.max_retries == 0

        bulk_handler = mock.MagicMock()
        insert_mock = bulk_handler.__getattr__("Contact").insert
        insert_mock.return_value = [_success(), _success()]

        with mock.patch(_HOOK_TARGET) as hook_cls:
            hook_cls.return_value.get_conn.return_value.bulk = bulk_handler
            outcome = operator.execute(context={})

        assert outcome == [_success(), _success()]
        assert insert_mock.call_count == 1

    def test_execute_retries_transient_failure(self):
        # End-to-end: the initial submission and the retry both go through ``execute``.
        operator = _make_op(max_retries=2, bulk_retry_delay=0)

        bulk_handler = mock.MagicMock()
        insert_mock = bulk_handler.__getattr__("Contact").insert
        insert_mock.side_effect = [[_transient_failure(), _success()], [_success()]]

        with mock.patch(_HOOK_TARGET) as hook_cls, mock.patch(_SLEEP_TARGET):
            hook_cls.return_value.get_conn.return_value.bulk = bulk_handler
            outcome = operator.execute(context={})

        assert outcome == [_success(), _success()]
        assert insert_mock.call_count == 2
        # Only the record that failed transiently is re-submitted.
        assert insert_mock.call_args_list[1].kwargs["data"] == [{"FirstName": "Ada"}]

    def test_transient_failure_is_retried(self):
        operator = _make_op(max_retries=2, bulk_retry_delay=0)

        initial = [_transient_failure(), _success()]
        # The mock only ever sees retry submissions: the initial result is passed in
        # directly, so the single expected retry gets a single-record response.
        submit = mock.MagicMock(side_effect=[[_success()]])

        with mock.patch.object(operator, "_run_operation", submit), mock.patch(_SLEEP_TARGET):
            final = operator._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}, {"FirstName": "Grace"}],
                result=initial,
            )

        assert final[0] == _success()
        assert final[1] == _success()
        assert submit.call_count == 1
        assert submit.call_args_list[0] == mock.call(mock.ANY, [{"FirstName": "Ada"}])

    def test_permanent_failure_is_not_retried(self):
        operator = _make_op(max_retries=3, bulk_retry_delay=0)
        initial = [_permanent_failure(), _success()]
        submit = mock.MagicMock()

        with mock.patch.object(operator, "_run_operation", submit):
            final = operator._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}, {"FirstName": "Grace"}],
                result=initial,
            )

        submit.assert_not_called()
        assert final[0] == _permanent_failure()

    def test_retries_stop_after_max_retries(self):
        operator = _make_op(max_retries=2, bulk_retry_delay=0)
        always_transient = [_transient_failure()]
        submit = mock.MagicMock(return_value=always_transient)

        with mock.patch.object(operator, "_run_operation", submit), mock.patch(_SLEEP_TARGET):
            final = operator._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}],
                result=always_transient,
            )

        assert submit.call_count == 2
        assert final[0]["success"] is False

    def test_retry_delay_is_respected(self):
        operator = _make_op(max_retries=1, bulk_retry_delay=30.0)
        submit = mock.MagicMock(return_value=[_success()])

        with (
            mock.patch.object(operator, "_run_operation", submit),
            mock.patch(_SLEEP_TARGET) as sleep_mock,
        ):
            operator._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}],
                result=[_transient_failure()],
            )

        sleep_mock.assert_called_once_with(30.0)

    def test_custom_transient_error_codes(self):
        operator = _make_op(
            max_retries=1,
            bulk_retry_delay=0,
            transient_error_codes=["MY_CUSTOM_ERROR"],
        )
        assert operator.transient_error_codes == frozenset({"MY_CUSTOM_ERROR"})

        custom_failure = {
            "success": False,
            "errors": [{"statusCode": "MY_CUSTOM_ERROR", "message": "custom"}],
        }
        submit = mock.MagicMock(return_value=[_success()])

        with mock.patch.object(operator, "_run_operation", submit), mock.patch(_SLEEP_TARGET):
            final = operator._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}],
                result=[custom_failure],
            )

        submit.assert_called_once()
        assert final[0] == _success()

    def test_api_temporarily_unavailable_is_retried(self):
        operator = _make_op(max_retries=1, bulk_retry_delay=0)
        unavailable = _transient_failure("API_TEMPORARILY_UNAVAILABLE")
        submit = mock.MagicMock(return_value=[_success()])

        with mock.patch.object(operator, "_run_operation", submit), mock.patch(_SLEEP_TARGET):
            final = operator._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=[{"FirstName": "Ada"}],
                result=[unavailable],
            )

        submit.assert_called_once()
        assert final[0] == _success()

    def test_mixed_failures_only_retries_transient(self):
        operator = _make_op(max_retries=1, bulk_retry_delay=0)
        records = [{"FirstName": "A"}, {"FirstName": "B"}, {"FirstName": "C"}]
        initial = [_transient_failure(), _permanent_failure(), _success()]
        submit = mock.MagicMock(return_value=[_success()])

        with mock.patch.object(operator, "_run_operation", submit), mock.patch(_SLEEP_TARGET):
            final = operator._retry_transient_failures(
                bulk=mock.MagicMock(),
                payload=records,
                result=initial,
            )

        submit.assert_called_once_with(mock.ANY, [{"FirstName": "A"}])
        assert final[0] == _success()
        assert final[1] == _permanent_failure()
        assert final[2] == _success()