From 398bc3534486fe7abebf6340c568e596a45c1e07 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Tue, 14 Apr 2026 11:37:49 -0700 Subject: [PATCH] Decouple metric fetch errors from trial status in Orchestrator (#5119) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/5119 Design doc: D98741656 When `fetch_trials_data_results` returned a `MetricFetchE` for an optimization config metric, the orchestrator marked the trial as ABANDONED. This discarded good data, inflated the failure rate, and was inconsistent with the Client layer which keeps trials COMPLETED with incomplete metrics via `MetricAvailability` (D93924193). This diff removes the trial abandonment behavior. Metric fetch errors are now logged (with traceback via `logger.exception`) but trial status is unchanged. `MetricAvailability` tracks data completeness, and the failure rate check uses it to detect persistent metric issues. Changes: - `_fetch_and_process_trials_data_results`: Removed the branch that marked trials ABANDONED for metric fetch errors and the separate `is_available_while_running` branch. All metric fetch errors are now simply logged and the method continues. The `_report_metric_fetch_e` hook is still called so subclasses (e.g. `AxSweepOrchestrator`) can react to errors (create pastes, build error tables, etc.). - `error_if_failure_rate_exceeded`: Merged `_check_if_failure_rate_exceeded` into this method to avoid duplicate computation. Now counts both runner failures (FAILED/ABANDONED) and metric-incomplete trials (via `compute_metric_availability`) toward the failure rate. - `_get_failure_rate_exceeded_error`: Rewritten with an actionable error message listing runner failures, metric-incomplete trials, missing metrics, and affected trial indices. - Removed dead code: `_mark_err_trial_status`, `_num_trials_bad_due_to_err`, `_num_metric_fetch_e_encountered`, `_check_if_failure_rate_exceeded`, `METRIC_FETCH_ERR_MESSAGE`. - Kept `_report_metric_fetch_e` as a no-op hook so subclasses like `AxSweepOrchestrator` can still react to metric fetch errors. - Updated telemetry (`OrchestratorCompletedRecord`) to use `_count_metric_incomplete_trials` (via `compute_metric_availability`) for both `num_metric_fetch_e_encountered` and `num_trials_bad_due_to_err`. - Updated `AxSweepOrchestrator` test assertions: trials now stay COMPLETED (not ABANDONED) after metric fetch errors. - `Metric.recoverable_exceptions` and `Metric.is_recoverable_fetch_e` are kept for now since `pts/` metrics still reference them; cleanup will follow in a separate diff. Differential Revision: D98924467 --- ax/orchestration/orchestrator.py | 283 +++++++++++--------- ax/orchestration/tests/test_orchestrator.py | 144 ++++------ 2 files changed, 203 insertions(+), 224 deletions(-) diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index d1b6e1f827b..95e1672753b 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -32,6 +32,7 @@ from ax.core.runner import Runner from ax.core.trial import Trial from ax.core.trial_status import TrialStatus +from ax.core.utils import compute_metric_availability, MetricAvailability from ax.exceptions.core import ( AxError, DataRequiredError, @@ -80,13 +81,6 @@ "of an optimization and if at least {min_failed} trials have been " "failed/abandoned, potentially automatically due to issues with the trial." ) -METRIC_FETCH_ERR_MESSAGE = ( - "A majority of the trial failures encountered are due to metric fetching errors. " - "This could mean the metrics are flaky, broken, or misconfigured. Please check " - "that the trial processes/jobs are successfully producing the expected metrics and " - "that the metric is correctly configured." -) - EXPECTED_STAGED_MSG = ( "Expected all trials to be in status {expected} after running or staging, " "found {t_idx_to_status}." @@ -191,13 +185,6 @@ class Orchestrator(WithDBSettingsBase, BestPointMixin): # Saved as a property so that it can be accessed after optimization is complex (ex. # for global stopping saving calculation). _num_remaining_requested_trials: int = 0 - # Total number of MetricFetchEs encountered during the course of optimization. Note - # this is different from and may be greater than the number of trials that have - # been marked either FAILED or ABANDONED due to metric fetching errors. - _num_metric_fetch_e_encountered: int = 0 - # Number of trials that have been marked either FAILED or ABANDONED due to - # MetricFetchE being encountered during _fetch_and_process_trials_data_results - _num_trials_bad_due_to_err: int = 0 # Keeps track of whether the allowed failure rate has been exceeded during # the optimization. If true, allows any pending trials to finish and raises # an error through self._complete_optimization. @@ -1073,83 +1060,63 @@ def summarize_final_result(self) -> OptimizationResult: """ return OptimizationResult() - def _check_if_failure_rate_exceeded(self, force_check: bool = False) -> bool: - """Checks if the failure rate (set in Orchestrator options) has been exceeded at - any point during the optimization. + def error_if_failure_rate_exceeded(self, force_check: bool = False) -> None: + """Raises an exception if the failure rate (set in Orchestrator options) has + been exceeded at any point during the optimization. - NOTE: Both FAILED and ABANDONED trial statuses count towards the failure rate. + The failure rate is computed as the ratio of "bad" trials to total trials + created by this orchestrator. "Bad" trials include: + - Execution failures: trials with FAILED or ABANDONED status. + - Metric-incomplete trials: COMPLETED trials whose metric data is not + fully available (as determined by ``compute_metric_availability``). Args: force_check: Indicates whether to force a failure-rate check regardless of the number of trials that have been executed. If False (default), the check will be skipped if the optimization has fewer than - five failed trials. If True, the check will be performed unless there - are 0 failures. + ``min_failed_trials_for_failure_rate_check`` bad trials. If True, the + check will be performed unless there are 0 bad trials. + """ + # Count runner-level failures (FAILED + ABANDONED). + num_execution_failures = self._num_bad_in_orchestrator() - Effect on state: - If the failure rate has been exceeded, a warning is logged and the private - attribute `_failure_rate_has_been_exceeded` is set to True, which causes the - `_get_max_pending_trials` to return zero, so that no further trials are - scheduled and an error is raised at the end of the optimization. + # Count completed trials with incomplete metric availability. + num_metric_incomplete, missing_metrics_by_trial = ( + self._get_metric_incomplete_trials() + ) - Returns: - Boolean representing whether the failure rate has been exceeded. - """ - if self._failure_rate_has_been_exceeded: - return True + num_bad = num_execution_failures + num_metric_incomplete - num_bad_in_orchestrator = self._num_bad_in_orchestrator() - # skip check if 0 failures - if num_bad_in_orchestrator == 0: - return False + if not self._failure_rate_has_been_exceeded: + # Skip check if 0 bad trials. + if num_bad == 0: + return - # skip check if fewer than min_failed_trials_for_failure_rate_check failures - # unless force_check is True - if ( - num_bad_in_orchestrator - < self.options.min_failed_trials_for_failure_rate_check - and not force_check - ): - return False + # Skip check if fewer than min threshold unless force_check. + if ( + num_bad < self.options.min_failed_trials_for_failure_rate_check + and not force_check + ): + return - num_ran_in_orchestrator = self._num_ran_in_orchestrator() - failure_rate_exceeded = ( - num_bad_in_orchestrator / num_ran_in_orchestrator - ) > self.options.tolerated_trial_failure_rate + num_ran_in_orchestrator = self._num_ran_in_orchestrator() + failure_rate_exceeded = ( + num_bad / num_ran_in_orchestrator + ) > self.options.tolerated_trial_failure_rate + + if not failure_rate_exceeded: + return - if failure_rate_exceeded: - if self._num_trials_bad_due_to_err > num_bad_in_orchestrator / 2: - self.logger.warning( - "MetricFetchE INFO: Sweep aborted due to an exceeded error rate, " - "which was primarily caused by failure to fetch metrics. Please " - "check if anything could cause your metrics to be flaky or " - "broken." - ) # NOTE: this private attribute causes `_get_max_pending_trials` to # return zero, which causes no further trials to be scheduled. self._failure_rate_has_been_exceeded = True - return True - - return False - - def error_if_failure_rate_exceeded(self, force_check: bool = False) -> None: - """Raises an exception if the failure rate (set in Orchestrator options) has - been exceeded at any point during the optimization. - NOTE: Both FAILED and ABANDONED trial statuses count towards the failure rate. - - Args: - force_check: Indicates whether to force a failure-rate check - regardless of the number of trials that have been executed. If False - (default), the check will be skipped if the optimization has fewer than - five failed trials. If True, the check will be performed unless there - are 0 failures. - """ - if self._check_if_failure_rate_exceeded(force_check=force_check): - raise self._get_failure_rate_exceeded_error( - num_bad_in_orchestrator=self._num_bad_in_orchestrator(), - num_ran_in_orchestrator=self._num_ran_in_orchestrator(), - ) + raise self._get_failure_rate_exceeded_error( + num_execution_failures=num_execution_failures, + num_metric_incomplete=num_metric_incomplete, + num_ran_in_orchestrator=self._num_ran_in_orchestrator(), + missing_metrics_by_trial=missing_metrics_by_trial, + ) def _error_if_status_quo_infeasible(self) -> None: """Raises an exception if the status-quo arm is infeasible and the @@ -2032,9 +1999,13 @@ def _fetch_and_process_trials_data_results( self, trial_indices: Iterable[int], ) -> dict[int, dict[str, MetricFetchResult]]: - """ - Fetches results from experiment and modifies trial statuses depending on - success or failure. + """Fetch trial data results and log any metric fetch errors. + + Metric fetch errors are logged but do NOT change trial status. + ``MetricAvailability`` (computed via ``compute_metric_availability``) + tracks data completeness separately, and the failure rate check in + ``error_if_failure_rate_exceeded`` uses it to detect persistent + metric issues. """ try: @@ -2085,41 +2056,12 @@ def _fetch_and_process_trials_data_results( f"Failed to fetch {metric_name} for trial {trial_index} with " f"status {status}, found {metric_fetch_e}." ) - self._num_metric_fetch_e_encountered += 1 self._report_metric_fetch_e( trial=self.experiment.trials[trial_index], metric_name=metric_name, metric_fetch_e=metric_fetch_e, ) - # If the fetch failure was for a metric in the optimization config (an - # objective or constraint) mark the trial as failed - optimization_config = self.experiment.optimization_config - if ( - optimization_config is not None - and metric_name in optimization_config.metric_names - and not self.experiment.metrics[metric_name].is_recoverable_fetch_e( - metric_fetch_e=metric_fetch_e - ) - ): - status = self._mark_err_trial_status( - trial=self.experiment.trials[trial_index], - metric_name=metric_name, - metric_fetch_e=metric_fetch_e, - ) - self.logger.warning( - f"MetricFetchE INFO: Because {metric_name} is an objective, " - f"marking trial {trial_index} as {status}." - ) - self._num_trials_bad_due_to_err += 1 - continue - - self.logger.info( - "MetricFetchE INFO: Continuing optimization even though " - "MetricFetchE encountered." - ) - continue - return results def _report_metric_fetch_e( @@ -2128,39 +2070,122 @@ def _report_metric_fetch_e( metric_name: str, metric_fetch_e: MetricFetchE, ) -> None: + """Hook for subclasses to react to metric fetch errors. + + Called once per metric fetch error during + ``_fetch_and_process_trials_data_results``. The default + implementation is a no-op; override in subclasses to add custom + reporting (e.g., creating error tables or pastes). + """ pass - def _mark_err_trial_status( + def _get_metric_incomplete_trials( self, - trial: BaseTrial, - metric_name: str | None = None, - metric_fetch_e: MetricFetchE | None = None, - ) -> TrialStatus: - trial.mark_abandoned( - reason=metric_fetch_e.message if metric_fetch_e else None, unsafe=True + ) -> tuple[int, dict[int, set[str]]]: + """Count completed trials with incomplete metric availability and identify + which metrics are missing for each. + + Required metrics include optimization config metrics and any explicitly + defined early stopping strategy metrics. + + Returns: + A tuple of (num_metric_incomplete, missing_metrics_by_trial) where + missing_metrics_by_trial maps trial index to the set of missing + metric names. + """ + opt_config = self.experiment.optimization_config + if opt_config is None: + return 0, {} + + completed_trial_indices = [ + t.index + for t in self.experiment.trials.values() + if t.status == TrialStatus.COMPLETED + and t.index >= self._num_preexisting_trials + ] + if len(completed_trial_indices) == 0: + return 0, {} + + required_metrics = set(opt_config.metric_names) + + # Include explicitly defined early stopping strategy metrics. + # ESS stores metric *signatures*, which may differ from metric names, + # so we resolve them via experiment.signature_to_metric. + ess = self.options.early_stopping_strategy + ess_signatures = ess.metric_signatures if ess is not None else None + if ess_signatures is not None: + for sig in ess_signatures: + metric = self.experiment.signature_to_metric[sig] + required_metrics.add(metric.name) + + metric_availabilities = compute_metric_availability( + experiment=self.experiment, + trial_indices=completed_trial_indices, + metric_names=required_metrics, ) - return TrialStatus.ABANDONED + + # Identify which specific metrics are missing per trial. + data = self.experiment.lookup_data(trial_indices=completed_trial_indices) + metrics_per_trial: dict[int, set[str]] = {} + if len(data.metric_names) > 0: + df = data.full_df + for trial_idx, group in df.groupby("trial_index")["metric_name"]: + metrics_per_trial[int(trial_idx)] = set(group.unique()) + + missing_metrics_by_trial: dict[int, set[str]] = {} + for idx, avail in metric_availabilities.items(): + if avail != MetricAvailability.COMPLETE: + available = metrics_per_trial.get(idx, set()) + missing_metrics_by_trial[idx] = required_metrics - available + + return len(missing_metrics_by_trial), missing_metrics_by_trial def _get_failure_rate_exceeded_error( self, - num_bad_in_orchestrator: int, + num_execution_failures: int, + num_metric_incomplete: int, num_ran_in_orchestrator: int, + missing_metrics_by_trial: dict[int, set[str]], ) -> FailureRateExceededError: - return FailureRateExceededError( - ( - f"{METRIC_FETCH_ERR_MESSAGE}\n" - if self._num_trials_bad_due_to_err > num_bad_in_orchestrator / 2 - else "" + """Build an actionable error message describing why the failure rate was + exceeded, including runner failures, metric-incomplete trials, which + metrics are missing, and which trials are affected. + """ + num_bad = num_execution_failures + num_metric_incomplete + observed_rate = num_bad / num_ran_in_orchestrator + + parts: list[str] = [] + parts.append( + f"Failure rate exceeded: {num_bad} of {num_ran_in_orchestrator} " + f"trials were unsuccessful (observed rate: {observed_rate:.0%}, tolerance: " + f"{self.options.tolerated_trial_failure_rate:.0%}). " + f"Checks are triggered when at least " + f"{self.options.min_failed_trials_for_failure_rate_check} trials " + "are unsuccessful or at the end of the optimization." + ) + + if num_execution_failures > 0: + parts.append( + f"{num_execution_failures} trial(s) failed at the execution " + "level (FAILED or ABANDONED). Check any trial evaluation " + "processes/jobs to see why they are failing." ) - + " Orignal error message: " - + FAILURE_EXCEEDED_MSG.format( - f_rate=self.options.tolerated_trial_failure_rate, - n_failed=num_bad_in_orchestrator, - n_ran=num_ran_in_orchestrator, - min_failed=self.options.min_failed_trials_for_failure_rate_check, - observed_rate=float(num_bad_in_orchestrator) / num_ran_in_orchestrator, + + if num_metric_incomplete > 0: + all_missing: set[str] = set() + for missing in missing_metrics_by_trial.values(): + all_missing.update(missing) + affected_trials = sorted(missing_metrics_by_trial.keys()) + + parts.append( + f"{num_metric_incomplete} trial(s) have incomplete metric data. " + f"Missing metrics: {sorted(all_missing)}. " + f"Affected trials: {affected_trials}. " + "Check that your metric fetching infrastructure is healthy " + "and that the metrics are being logged correctly." ) - ) + + return FailureRateExceededError("\n".join(parts)) def _warn_if_non_terminal_trials(self) -> None: """Warns if there are any non-terminal trials on the experiment.""" diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 1532541ab4c..36779abf5c3 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -38,12 +38,7 @@ get_pending_observation_features_based_on_trial_status, ) from ax.early_stopping.strategies import BaseEarlyStoppingStrategy -from ax.exceptions.core import ( - AxError, - OptimizationComplete, - UnsupportedError, - UserInputError, -) +from ax.exceptions.core import OptimizationComplete, UnsupportedError, UserInputError from ax.exceptions.generation_strategy import AxGenerationException from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.generation_strategy.generation_strategy import ( @@ -1834,6 +1829,12 @@ def test_fetch_and_process_trials_data_results_failed_non_objective( ) def test_fetch_and_process_trials_data_results_failed_objective(self) -> None: + """Metric fetch errors on objective metrics do NOT change trial status. + + The trial remains COMPLETED, and MetricAvailability reflects the missing + data. The failure rate check uses MetricAvailability to detect persistent + metric issues. + """ gs = self.two_sobol_steps_GS orchestrator = Orchestrator( experiment=self.branin_experiment, @@ -1854,97 +1855,44 @@ def test_fetch_and_process_trials_data_results_failed_objective(self) -> None: ), self.assertLogs(logger="ax.orchestration.orchestrator") as lg, ): - # This trial will fail + # The trial completes but has incomplete metrics, triggering + # the failure rate check. with self.assertRaises(FailureRateExceededError): orchestrator.run_n_trials(max_trials=1) - self.assertTrue( - any( - re.search(r"Failed to fetch (branin|m1) for trial 0", warning) - is not None - for warning in lg.output - ) - ) + # Verify the error was logged (not the old "marking trial as ABANDONED"). self.assertTrue( any( re.search( - r"Because (branin|m1) is an objective, marking trial 0 as " - "TrialStatus.ABANDONED", - warning, - ) - is not None - for warning in lg.output - ) - ) - self.assertEqual( - orchestrator.experiment.trials[0].status, TrialStatus.ABANDONED - ) - - def test_fetch_and_process_trials_data_results_failed_objective_but_recoverable( - self, - ) -> None: - gs = self.two_sobol_steps_GS - orchestrator = Orchestrator( - experiment=self.branin_experiment, - generation_strategy=gs, - options=OrchestratorOptions( - enforce_immutable_search_space_and_opt_config=False, - **self.orchestrator_options_kwargs, - ), - db_settings=self.db_settings_if_always_needed, - ) - BraninMetric.recoverable_exceptions = {AxError, TypeError} - # we're throwing a recoverable exception because UserInputError - # is a subclass of AxError - with ( - patch( - f"{BraninMetric.__module__}.BraninMetric.f", - side_effect=UserInputError("yikes!"), - ), - patch( - f"{BraninMetric.__module__}.BraninMetric.is_available_while_running", - return_value=False, - ), - self.assertLogs(logger="ax.orchestration.orchestrator") as lg, - ): - orchestrator.run_n_trials(max_trials=1) - self.assertTrue( - any( - re.search(r"Failed to fetch (branin|m1) for trial 0", warning) - is not None - for warning in lg.output - ), - lg.output, - ) - self.assertTrue( - any( - re.search( - "MetricFetchE INFO: Continuing optimization even though " - "MetricFetchE encountered", + r"Failed to fetch (branin|m1) for trial 0", warning, ) is not None for warning in lg.output ) ) + # Trial stays COMPLETED -- not ABANDONED. self.assertEqual( orchestrator.experiment.trials[0].status, TrialStatus.COMPLETED ) - def test_fetch_and_process_trials_data_results_failed_objective_not_recoverable( - self, - ) -> None: + def test_failure_rate_metric_incomplete(self) -> None: + """Failure rate check uses MetricAvailability to count metric-incomplete + trials and raises FailureRateExceededError with an actionable message + listing missing metrics and affected trials. + """ gs = self.two_sobol_steps_GS + tolerated_failure_rate = 0.5 + min_failed = 1 orchestrator = Orchestrator( experiment=self.branin_experiment, generation_strategy=gs, options=OrchestratorOptions( + tolerated_trial_failure_rate=tolerated_failure_rate, + min_failed_trials_for_failure_rate_check=min_failed, **self.orchestrator_options_kwargs, ), db_settings=self.db_settings_if_always_needed, ) - # we're throwing a unrecoverable exception because Exception is not subclass - # of either error type in recoverable_exceptions - BraninMetric.recoverable_exceptions = {AxError, TypeError} with ( patch( f"{BraninMetric.__module__}.BraninMetric.f", @@ -1954,33 +1902,39 @@ def test_fetch_and_process_trials_data_results_failed_objective_not_recoverable( f"{BraninMetric.__module__}.BraninMetric.is_available_while_running", return_value=False, ), - self.assertLogs(logger="ax.orchestration.orchestrator") as lg, ): - # This trial will fail - with self.assertRaises(FailureRateExceededError): + with self.assertRaises(FailureRateExceededError) as cm: orchestrator.run_n_trials(max_trials=1) - self.assertTrue( - any( - re.search(r"Failed to fetch (branin|m1) for trial 0", warning) - is not None - for warning in lg.output - ) - ) - self.assertTrue( - any( - re.search( - r"Because (branin|m1) is an objective, marking trial 0 as " - "TrialStatus.ABANDONED", - warning, - ) - is not None - for warning in lg.output - ) - ) + + # Trial stays COMPLETED -- metric fetch errors do not change status. self.assertEqual( - orchestrator.experiment.trials[0].status, TrialStatus.ABANDONED + orchestrator.experiment.trials[0].status, TrialStatus.COMPLETED ) + # Build the expected error message from orchestrator config values. + # 1 trial ran, 0 execution failures, 1 metric-incomplete trial. + opt_config = none_throws(orchestrator.experiment.optimization_config) + opt_metric_names = sorted(opt_config.metric_names) + expected_parts = [ + ( + f"Failure rate exceeded: 1 of 1 trials were unsuccessful " + f"(observed rate: 100%, tolerance: " + f"{tolerated_failure_rate:.0%}). " + f"Checks are triggered when at least " + f"{min_failed} trials " + f"are unsuccessful or at the end of the optimization." + ), + ( + f"1 trial(s) have incomplete metric data. " + f"Missing metrics: {opt_metric_names}. " + f"Affected trials: [0]. " + f"Check that your metric fetching infrastructure is healthy " + f"and that the metrics are being logged correctly." + ), + ] + expected_msg = "\n".join(expected_parts) + self.assertEqual(str(cm.exception), expected_msg) + def test_should_consider_optimization_complete(self) -> None: # Tests non-GSS parts of the completion criterion. gs = self.sobol_MBM_GS