diff --git a/ax/core/metric.py b/ax/core/metric.py
index 24240dfe4b6..ecb975c1981 100644
--- a/ax/core/metric.py
+++ b/ax/core/metric.py
@@ -88,11 +88,6 @@ class Metric(SortableBase, SerializationMixin):
         properties: Properties specific to a particular metric.
     """

-    # The set of exception types stored in a ``MetchFetchE.exception`` that are
-    # recoverable ``orchestrator._fetch_and_process_trials_data_results()``.
-    # Exception may be a subclass of any of these types. If you want your metric
-    # to never fail the trial, set this to ``{Exception}`` in your metric subclass.
-    recoverable_exceptions: set[type[Exception]] = set()
     has_map_data: bool = False

     def __init__(
@@ -164,17 +159,6 @@ def period_of_new_data_after_trial_completion(cls) -> timedelta:
         """
         return timedelta(0)

-    @classmethod
-    def is_recoverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool:
-        """Checks whether the given MetricFetchE is recoverable for this metric class
-        in ``orchestrator._fetch_and_process_trials_data_results``.
-        """
-        if metric_fetch_e.exception is None:
-            return False
-        return any(
-            isinstance(metric_fetch_e.exception, e) for e in cls.recoverable_exceptions
-        )
-
     # NOTE: This is rarely overridden –– oonly if you want to fetch data in groups
     # consisting of multiple different metric classes, for data to be fetched together.
     # This makes sense only if `fetch_trial_data_multi` or `fetch_experiment_data_multi`
diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py
index d1b6e1f827b..95e1672753b 100644
--- a/ax/orchestration/orchestrator.py
+++ b/ax/orchestration/orchestrator.py
@@ -32,6 +32,7 @@
 from ax.core.runner import Runner
 from ax.core.trial import Trial
 from ax.core.trial_status import TrialStatus
+from ax.core.utils import compute_metric_availability, MetricAvailability
 from ax.exceptions.core import (
     AxError,
     DataRequiredError,
@@ -80,13 +81,6 @@
     "of an optimization and if at least {min_failed} trials have been "
     "failed/abandoned, potentially automatically due to issues with the trial."
 )
-METRIC_FETCH_ERR_MESSAGE = (
-    "A majority of the trial failures encountered are due to metric fetching errors. "
-    "This could mean the metrics are flaky, broken, or misconfigured. Please check "
-    "that the trial processes/jobs are successfully producing the expected metrics and "
-    "that the metric is correctly configured."
-)
-
 EXPECTED_STAGED_MSG = (
     "Expected all trials to be in status {expected} after running or staging, "
     "found {t_idx_to_status}."
@@ -191,13 +185,6 @@ class Orchestrator(WithDBSettingsBase, BestPointMixin):
     # Saved as a property so that it can be accessed after optimization is complex (ex.
     # for global stopping saving calculation).
     _num_remaining_requested_trials: int = 0
-    # Total number of MetricFetchEs encountered during the course of optimization. Note
-    # this is different from and may be greater than the number of trials that have
-    # been marked either FAILED or ABANDONED due to metric fetching errors.
-    _num_metric_fetch_e_encountered: int = 0
-    # Number of trials that have been marked either FAILED or ABANDONED due to
-    # MetricFetchE being encountered during _fetch_and_process_trials_data_results
-    _num_trials_bad_due_to_err: int = 0
     # Keeps track of whether the allowed failure rate has been exceeded during
     # the optimization. If true, allows any pending trials to finish and raises
     # an error through self._complete_optimization.
@@ -1073,83 +1060,63 @@ def summarize_final_result(self) -> OptimizationResult:
         """
         return OptimizationResult()

-    def _check_if_failure_rate_exceeded(self, force_check: bool = False) -> bool:
-        """Checks if the failure rate (set in Orchestrator options) has been exceeded at
-        any point during the optimization.
+    def error_if_failure_rate_exceeded(self, force_check: bool = False) -> None:
+        """Raises an exception if the failure rate (set in Orchestrator options) has
+        been exceeded at any point during the optimization.

-        NOTE: Both FAILED and ABANDONED trial statuses count towards the failure rate.
+        The failure rate is computed as the ratio of "bad" trials to total trials
+        created by this orchestrator. "Bad" trials include:
+        - Execution failures: trials with FAILED or ABANDONED status.
+        - Metric-incomplete trials: COMPLETED trials whose metric data is not
+          fully available (as determined by ``compute_metric_availability``).

         Args:
             force_check: Indicates whether to force a failure-rate check
                 regardless of the number of trials that have been executed. If False
                 (default), the check will be skipped if the optimization has fewer than
-                five failed trials. If True, the check will be performed unless there
-                are 0 failures.
+                ``min_failed_trials_for_failure_rate_check`` bad trials. If True, the
+                check will be performed unless there are 0 bad trials.
+        """
+        # Count runner-level failures (FAILED + ABANDONED).
+        num_execution_failures = self._num_bad_in_orchestrator()

-        Effect on state:
-            If the failure rate has been exceeded, a warning is logged and the private
-            attribute `_failure_rate_has_been_exceeded` is set to True, which causes the
-            `_get_max_pending_trials` to return zero, so that no further trials are
-            scheduled and an error is raised at the end of the optimization.
+        # Count completed trials with incomplete metric availability.
+        num_metric_incomplete, missing_metrics_by_trial = (
+            self._get_metric_incomplete_trials()
+        )

-        Returns:
-            Boolean representing whether the failure rate has been exceeded.
-        """
-        if self._failure_rate_has_been_exceeded:
-            return True
+        num_bad = num_execution_failures + num_metric_incomplete

-        num_bad_in_orchestrator = self._num_bad_in_orchestrator()
-        # skip check if 0 failures
-        if num_bad_in_orchestrator == 0:
-            return False
+        if not self._failure_rate_has_been_exceeded:
+            # Skip check if 0 bad trials.
+            if num_bad == 0:
+                return

-        # skip check if fewer than min_failed_trials_for_failure_rate_check failures
-        # unless force_check is True
-        if (
-            num_bad_in_orchestrator
-            < self.options.min_failed_trials_for_failure_rate_check
-            and not force_check
-        ):
-            return False
+            # Skip check if fewer than min threshold unless force_check.
+            if (
+                num_bad < self.options.min_failed_trials_for_failure_rate_check
+                and not force_check
+            ):
+                return

-        num_ran_in_orchestrator = self._num_ran_in_orchestrator()
-        failure_rate_exceeded = (
-            num_bad_in_orchestrator / num_ran_in_orchestrator
-        ) > self.options.tolerated_trial_failure_rate
+            num_ran_in_orchestrator = self._num_ran_in_orchestrator()
+            failure_rate_exceeded = (
+                num_bad / num_ran_in_orchestrator
+            ) > self.options.tolerated_trial_failure_rate
+
+            if not failure_rate_exceeded:
+                return

-        if failure_rate_exceeded:
-            if self._num_trials_bad_due_to_err > num_bad_in_orchestrator / 2:
-                self.logger.warning(
-                    "MetricFetchE INFO: Sweep aborted due to an exceeded error rate, "
-                    "which was primarily caused by failure to fetch metrics. Please "
-                    "check if anything could cause your metrics to be flaky or "
-                    "broken."
-                )
             # NOTE: this private attribute causes `_get_max_pending_trials` to
             # return zero, which causes no further trials to be scheduled.
             self._failure_rate_has_been_exceeded = True
-            return True
-
-        return False
-
-    def error_if_failure_rate_exceeded(self, force_check: bool = False) -> None:
-        """Raises an exception if the failure rate (set in Orchestrator options) has
-        been exceeded at any point during the optimization.
-
-        NOTE: Both FAILED and ABANDONED trial statuses count towards the failure rate.
-
-        Args:
-            force_check: Indicates whether to force a failure-rate check
-                regardless of the number of trials that have been executed. If False
-                (default), the check will be skipped if the optimization has fewer than
-                five failed trials. If True, the check will be performed unless there
-                are 0 failures.
-        """
-        if self._check_if_failure_rate_exceeded(force_check=force_check):
-            raise self._get_failure_rate_exceeded_error(
-                num_bad_in_orchestrator=self._num_bad_in_orchestrator(),
-                num_ran_in_orchestrator=self._num_ran_in_orchestrator(),
-            )
+        raise self._get_failure_rate_exceeded_error(
+            num_execution_failures=num_execution_failures,
+            num_metric_incomplete=num_metric_incomplete,
+            num_ran_in_orchestrator=self._num_ran_in_orchestrator(),
+            missing_metrics_by_trial=missing_metrics_by_trial,
+        )

     def _error_if_status_quo_infeasible(self) -> None:
         """Raises an exception if the status-quo arm is infeasible and the
@@ -2032,9 +1999,13 @@ def _fetch_and_process_trials_data_results(
         self,
         trial_indices: Iterable[int],
     ) -> dict[int, dict[str, MetricFetchResult]]:
-        """
-        Fetches results from experiment and modifies trial statuses depending on
-        success or failure.
+        """Fetch trial data results and log any metric fetch errors.
+
+        Metric fetch errors are logged but do NOT change trial status.
+        ``MetricAvailability`` (computed via ``compute_metric_availability``)
+        tracks data completeness separately, and the failure rate check in
+        ``error_if_failure_rate_exceeded`` uses it to detect persistent
+        metric issues.
         """

         try:
@@ -2085,41 +2056,12 @@ def _fetch_and_process_trials_data_results(
                         f"Failed to fetch {metric_name} for trial {trial_index} with "
                         f"status {status}, found {metric_fetch_e}."
                     )
-                    self._num_metric_fetch_e_encountered += 1
                     self._report_metric_fetch_e(
                         trial=self.experiment.trials[trial_index],
                         metric_name=metric_name,
                         metric_fetch_e=metric_fetch_e,
                     )
-                    # If the fetch failure was for a metric in the optimization config (an
-                    # objective or constraint) mark the trial as failed
-                    optimization_config = self.experiment.optimization_config
-                    if (
-                        optimization_config is not None
-                        and metric_name in optimization_config.metric_names
-                        and not self.experiment.metrics[metric_name].is_recoverable_fetch_e(
-                            metric_fetch_e=metric_fetch_e
-                        )
-                    ):
-                        status = self._mark_err_trial_status(
-                            trial=self.experiment.trials[trial_index],
-                            metric_name=metric_name,
-                            metric_fetch_e=metric_fetch_e,
-                        )
-                        self.logger.warning(
-                            f"MetricFetchE INFO: Because {metric_name} is an objective, "
-                            f"marking trial {trial_index} as {status}."
-                        )
-                        self._num_trials_bad_due_to_err += 1
-                        continue
-
-                    self.logger.info(
-                        "MetricFetchE INFO: Continuing optimization even though "
-                        "MetricFetchE encountered."
-                    )
-                    continue
-
         return results

     def _report_metric_fetch_e(
@@ -2128,39 +2070,122 @@ def _fetch_and_process_trials_data_results(
         metric_name: str,
         metric_fetch_e: MetricFetchE,
     ) -> None:
+        """Hook for subclasses to react to metric fetch errors.
+
+        Called once per metric fetch error during
+        ``_fetch_and_process_trials_data_results``. The default
+        implementation is a no-op; override in subclasses to add custom
+        reporting (e.g., creating error tables or pastes).
+        """
         pass

-    def _mark_err_trial_status(
+    def _get_metric_incomplete_trials(
         self,
-        trial: BaseTrial,
-        metric_name: str | None = None,
-        metric_fetch_e: MetricFetchE | None = None,
-    ) -> TrialStatus:
-        trial.mark_abandoned(
-            reason=metric_fetch_e.message if metric_fetch_e else None, unsafe=True
+    ) -> tuple[int, dict[int, set[str]]]:
+        """Count completed trials with incomplete metric availability and identify
+        which metrics are missing for each.
+
+        Required metrics include optimization config metrics and any explicitly
+        defined early stopping strategy metrics.
+
+        Returns:
+            A tuple of (num_metric_incomplete, missing_metrics_by_trial) where
+            missing_metrics_by_trial maps trial index to the set of missing
+            metric names.
+        """
+        opt_config = self.experiment.optimization_config
+        if opt_config is None:
+            return 0, {}
+
+        completed_trial_indices = [
+            t.index
+            for t in self.experiment.trials.values()
+            if t.status == TrialStatus.COMPLETED
+            and t.index >= self._num_preexisting_trials
+        ]
+        if len(completed_trial_indices) == 0:
+            return 0, {}
+
+        required_metrics = set(opt_config.metric_names)
+
+        # Include explicitly defined early stopping strategy metrics.
+        # ESS stores metric *signatures*, which may differ from metric names,
+        # so we resolve them via experiment.signature_to_metric.
+        ess = self.options.early_stopping_strategy
+        ess_signatures = ess.metric_signatures if ess is not None else None
+        if ess_signatures is not None:
+            for sig in ess_signatures:
+                metric = self.experiment.signature_to_metric[sig]
+                required_metrics.add(metric.name)
+
+        metric_availabilities = compute_metric_availability(
+            experiment=self.experiment,
+            trial_indices=completed_trial_indices,
+            metric_names=required_metrics,
         )
-        return TrialStatus.ABANDONED
+
+        # Identify which specific metrics are missing per trial.
+        data = self.experiment.lookup_data(trial_indices=completed_trial_indices)
+        metrics_per_trial: dict[int, set[str]] = {}
+        if len(data.metric_names) > 0:
+            df = data.full_df
+            for trial_idx, group in df.groupby("trial_index")["metric_name"]:
+                metrics_per_trial[int(trial_idx)] = set(group.unique())
+
+        missing_metrics_by_trial: dict[int, set[str]] = {}
+        for idx, avail in metric_availabilities.items():
+            if avail != MetricAvailability.COMPLETE:
+                available = metrics_per_trial.get(idx, set())
+                missing_metrics_by_trial[idx] = required_metrics - available
+
+        return len(missing_metrics_by_trial), missing_metrics_by_trial

     def _get_failure_rate_exceeded_error(
         self,
-        num_bad_in_orchestrator: int,
+        num_execution_failures: int,
+        num_metric_incomplete: int,
         num_ran_in_orchestrator: int,
+        missing_metrics_by_trial: dict[int, set[str]],
     ) -> FailureRateExceededError:
-        return FailureRateExceededError(
-            (
-                f"{METRIC_FETCH_ERR_MESSAGE}\n"
-                if self._num_trials_bad_due_to_err > num_bad_in_orchestrator / 2
-                else ""
+        """Build an actionable error message describing why the failure rate was
+        exceeded, including runner failures, metric-incomplete trials, which
+        metrics are missing, and which trials are affected.
+ """ + num_bad = num_execution_failures + num_metric_incomplete + observed_rate = num_bad / num_ran_in_orchestrator + + parts: list[str] = [] + parts.append( + f"Failure rate exceeded: {num_bad} of {num_ran_in_orchestrator} " + f"trials were unsuccessful (observed rate: {observed_rate:.0%}, tolerance: " + f"{self.options.tolerated_trial_failure_rate:.0%}). " + f"Checks are triggered when at least " + f"{self.options.min_failed_trials_for_failure_rate_check} trials " + "are unsuccessful or at the end of the optimization." + ) + + if num_execution_failures > 0: + parts.append( + f"{num_execution_failures} trial(s) failed at the execution " + "level (FAILED or ABANDONED). Check any trial evaluation " + "processes/jobs to see why they are failing." ) - + " Orignal error message: " - + FAILURE_EXCEEDED_MSG.format( - f_rate=self.options.tolerated_trial_failure_rate, - n_failed=num_bad_in_orchestrator, - n_ran=num_ran_in_orchestrator, - min_failed=self.options.min_failed_trials_for_failure_rate_check, - observed_rate=float(num_bad_in_orchestrator) / num_ran_in_orchestrator, + + if num_metric_incomplete > 0: + all_missing: set[str] = set() + for missing in missing_metrics_by_trial.values(): + all_missing.update(missing) + affected_trials = sorted(missing_metrics_by_trial.keys()) + + parts.append( + f"{num_metric_incomplete} trial(s) have incomplete metric data. " + f"Missing metrics: {sorted(all_missing)}. " + f"Affected trials: {affected_trials}. " + "Check that your metric fetching infrastructure is healthy " + "and that the metrics are being logged correctly." ) - ) + + return FailureRateExceededError("\n".join(parts)) def _warn_if_non_terminal_trials(self) -> None: """Warns if there are any non-terminal trials on the experiment.""" diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 1532541ab4c..36779abf5c3 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -38,12 +38,7 @@ get_pending_observation_features_based_on_trial_status, ) from ax.early_stopping.strategies import BaseEarlyStoppingStrategy -from ax.exceptions.core import ( - AxError, - OptimizationComplete, - UnsupportedError, - UserInputError, -) +from ax.exceptions.core import OptimizationComplete, UnsupportedError, UserInputError from ax.exceptions.generation_strategy import AxGenerationException from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.generation_strategy.generation_strategy import ( @@ -1834,6 +1829,12 @@ def test_fetch_and_process_trials_data_results_failed_non_objective( ) def test_fetch_and_process_trials_data_results_failed_objective(self) -> None: + """Metric fetch errors on objective metrics do NOT change trial status. + + The trial remains COMPLETED, and MetricAvailability reflects the missing + data. The failure rate check uses MetricAvailability to detect persistent + metric issues. + """ gs = self.two_sobol_steps_GS orchestrator = Orchestrator( experiment=self.branin_experiment, @@ -1854,97 +1855,44 @@ def test_fetch_and_process_trials_data_results_failed_objective(self) -> None: ), self.assertLogs(logger="ax.orchestration.orchestrator") as lg, ): - # This trial will fail + # The trial completes but has incomplete metrics, triggering + # the failure rate check. 
             with self.assertRaises(FailureRateExceededError):
                 orchestrator.run_n_trials(max_trials=1)
-            self.assertTrue(
-                any(
-                    re.search(r"Failed to fetch (branin|m1) for trial 0", warning)
-                    is not None
-                    for warning in lg.output
-                )
-            )
+            # Verify the error was logged (not the old "marking trial as ABANDONED").
             self.assertTrue(
                 any(
                     re.search(
-                        r"Because (branin|m1) is an objective, marking trial 0 as "
-                        "TrialStatus.ABANDONED",
-                        warning,
-                    )
-                    is not None
-                    for warning in lg.output
-                )
-            )
-        self.assertEqual(
-            orchestrator.experiment.trials[0].status, TrialStatus.ABANDONED
-        )
-
-    def test_fetch_and_process_trials_data_results_failed_objective_but_recoverable(
-        self,
-    ) -> None:
-        gs = self.two_sobol_steps_GS
-        orchestrator = Orchestrator(
-            experiment=self.branin_experiment,
-            generation_strategy=gs,
-            options=OrchestratorOptions(
-                enforce_immutable_search_space_and_opt_config=False,
-                **self.orchestrator_options_kwargs,
-            ),
-            db_settings=self.db_settings_if_always_needed,
-        )
-        BraninMetric.recoverable_exceptions = {AxError, TypeError}
-        # we're throwing a recoverable exception because UserInputError
-        # is a subclass of AxError
-        with (
-            patch(
-                f"{BraninMetric.__module__}.BraninMetric.f",
-                side_effect=UserInputError("yikes!"),
-            ),
-            patch(
-                f"{BraninMetric.__module__}.BraninMetric.is_available_while_running",
-                return_value=False,
-            ),
-            self.assertLogs(logger="ax.orchestration.orchestrator") as lg,
-        ):
-            orchestrator.run_n_trials(max_trials=1)
-            self.assertTrue(
-                any(
-                    re.search(r"Failed to fetch (branin|m1) for trial 0", warning)
-                    is not None
-                    for warning in lg.output
-                ),
-                lg.output,
-            )
-            self.assertTrue(
-                any(
-                    re.search(
-                        "MetricFetchE INFO: Continuing optimization even though "
-                        "MetricFetchE encountered",
+                        r"Failed to fetch (branin|m1) for trial 0",
                         warning,
                     )
                     is not None
                     for warning in lg.output
                 )
             )
+            # Trial stays COMPLETED -- not ABANDONED.
             self.assertEqual(
                 orchestrator.experiment.trials[0].status, TrialStatus.COMPLETED
             )

-    def test_fetch_and_process_trials_data_results_failed_objective_not_recoverable(
-        self,
-    ) -> None:
+    def test_failure_rate_metric_incomplete(self) -> None:
+        """Failure rate check uses MetricAvailability to count metric-incomplete
+        trials and raises FailureRateExceededError with an actionable message
+        listing missing metrics and affected trials.
+ """ gs = self.two_sobol_steps_GS + tolerated_failure_rate = 0.5 + min_failed = 1 orchestrator = Orchestrator( experiment=self.branin_experiment, generation_strategy=gs, options=OrchestratorOptions( + tolerated_trial_failure_rate=tolerated_failure_rate, + min_failed_trials_for_failure_rate_check=min_failed, **self.orchestrator_options_kwargs, ), db_settings=self.db_settings_if_always_needed, ) - # we're throwing a unrecoverable exception because Exception is not subclass - # of either error type in recoverable_exceptions - BraninMetric.recoverable_exceptions = {AxError, TypeError} with ( patch( f"{BraninMetric.__module__}.BraninMetric.f", @@ -1954,33 +1902,39 @@ def test_fetch_and_process_trials_data_results_failed_objective_not_recoverable( f"{BraninMetric.__module__}.BraninMetric.is_available_while_running", return_value=False, ), - self.assertLogs(logger="ax.orchestration.orchestrator") as lg, ): - # This trial will fail - with self.assertRaises(FailureRateExceededError): + with self.assertRaises(FailureRateExceededError) as cm: orchestrator.run_n_trials(max_trials=1) - self.assertTrue( - any( - re.search(r"Failed to fetch (branin|m1) for trial 0", warning) - is not None - for warning in lg.output - ) - ) - self.assertTrue( - any( - re.search( - r"Because (branin|m1) is an objective, marking trial 0 as " - "TrialStatus.ABANDONED", - warning, - ) - is not None - for warning in lg.output - ) - ) + + # Trial stays COMPLETED -- metric fetch errors do not change status. self.assertEqual( - orchestrator.experiment.trials[0].status, TrialStatus.ABANDONED + orchestrator.experiment.trials[0].status, TrialStatus.COMPLETED ) + # Build the expected error message from orchestrator config values. + # 1 trial ran, 0 execution failures, 1 metric-incomplete trial. + opt_config = none_throws(orchestrator.experiment.optimization_config) + opt_metric_names = sorted(opt_config.metric_names) + expected_parts = [ + ( + f"Failure rate exceeded: 1 of 1 trials were unsuccessful " + f"(observed rate: 100%, tolerance: " + f"{tolerated_failure_rate:.0%}). " + f"Checks are triggered when at least " + f"{min_failed} trials " + f"are unsuccessful or at the end of the optimization." + ), + ( + f"1 trial(s) have incomplete metric data. " + f"Missing metrics: {opt_metric_names}. " + f"Affected trials: [0]. " + f"Check that your metric fetching infrastructure is healthy " + f"and that the metrics are being logged correctly." + ), + ] + expected_msg = "\n".join(expected_parts) + self.assertEqual(str(cm.exception), expected_msg) + def test_should_consider_optimization_complete(self) -> None: # Tests non-GSS parts of the completion criterion. gs = self.sobol_MBM_GS